import io
import os

import soundfile as sf
import torch
from fastapi import FastAPI, File, Response, UploadFile
from llama_cpp import Llama
from pydantic import BaseModel
from transformers import AutoTokenizer, VitsModel, Wav2Vec2ForCTC, Wav2Vec2Processor

app = FastAPI()
|
# Text-to-speech model: use a locally cached copy if present, otherwise download
# from the Hub. facebook/mms-tts-eng is a VITS checkpoint that transformers can
# load and run directly.
if os.path.exists("./models/tts_model"):
    tts_model = VitsModel.from_pretrained("./models/tts_model")
    tts_tokenizer = AutoTokenizer.from_pretrained("./models/tts_model")
else:
    tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
    tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
|
# Speech-to-text model (wav2vec2 + CTC): local copy if present, Hub otherwise.
if os.path.exists("./models/stt_model"):
    stt_model = Wav2Vec2ForCTC.from_pretrained("./models/stt_model")
    stt_processor = Wav2Vec2Processor.from_pretrained("./models/stt_model")
else:
    stt_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
    stt_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
|
# Local LLM served through llama.cpp; a GGUF model file must be provided manually.
if os.path.exists("./models/llama.gguf"):
    llm = Llama(model_path="./models/llama.gguf")
else:
    raise FileNotFoundError("Please place llama.gguf in the models/ directory")
|
class TTSRequest(BaseModel):
    text: str


class LLMRequest(BaseModel):
    prompt: str
|
@app.post("/tts")
async def tts_endpoint(request: TTSRequest):
    # Tokenize the text and synthesize a waveform; VITS produces audio from a
    # plain forward pass rather than generate().
    inputs = tts_tokenizer(request.text, return_tensors="pt")
    with torch.no_grad():
        audio = tts_model(**inputs).waveform
    audio = audio.squeeze().cpu().numpy()
    # Encode as WAV at the model's native sampling rate and return the raw bytes.
    buffer = io.BytesIO()
    sf.write(buffer, audio, tts_model.config.sampling_rate, format="WAV")
    buffer.seek(0)
    return Response(content=buffer.getvalue(), media_type="audio/wav")
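
# Example call (host and port are assumptions for a local run):
#   curl -X POST http://localhost:8000/tts \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Hello from the demo server"}' \
#        --output speech.wav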
|
@app.post("/stt")
async def stt_endpoint(file: UploadFile = File(...)):
    audio_bytes = await file.read()
    audio, sr = sf.read(io.BytesIO(audio_bytes))
    # wav2vec2-base-960h expects 16 kHz mono audio; downmix stereo uploads and
    # pass the file's sample rate through so a mismatch is reported instead of
    # silently producing garbage.
    if audio.ndim > 1:
        audio = audio.mean(axis=1)
    inputs = stt_processor(audio, sampling_rate=sr, return_tensors="pt")
    with torch.no_grad():
        logits = stt_model(inputs.input_values).logits
    # Greedy CTC decoding: most likely token per frame, then collapse repeats and blanks.
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = stt_processor.batch_decode(predicted_ids)[0]
    return {"text": transcription}
|
@app.post("/llm")
async def llm_endpoint(request: LLMRequest):
    # Single completion through llama.cpp; max_tokens caps the length of the reply.
    output = llm(request.prompt, max_tokens=50)
    return {"text": output["choices"][0]["text"]}