nuera / app.py
akshatOP's picture
Switch to facebook/tts_transformer-en-ljspeech for TTS
81875a2
raw
history blame
2.49 kB
from transformers import AutoModelForSpeechSeq2Seq, AutoTokenizer
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from llama_cpp import Llama
import torch
import soundfile as sf
import io
import os
from pydantic import BaseModel
from fastapi import FastAPI, File, UploadFile, Response
# Application and model initialization (runs once at import time).
app = FastAPI()

# --- Text-to-speech model: prefer a locally cached copy, else download. ---
# NOTE(review): facebook/tts_transformer-en-ljspeech is a fairseq-trained
# checkpoint; confirm it actually loads via AutoModelForSpeechSeq2Seq —
# fairseq TTS checkpoints are typically not transformers-compatible.
if os.path.exists("./models/tts_model"):
    tts_model = AutoModelForSpeechSeq2Seq.from_pretrained("./models/tts_model")
    tts_tokenizer = AutoTokenizer.from_pretrained("./models/tts_model")
else:
    tts_model = AutoModelForSpeechSeq2Seq.from_pretrained("facebook/tts_transformer-en-ljspeech")
    tts_tokenizer = AutoTokenizer.from_pretrained("facebook/tts_transformer-en-ljspeech")
tts_model.eval()  # inference only: disable dropout etc.

# --- Speech-to-text model (named "SST" throughout this file). ---
if os.path.exists("./models/sst_model"):
    sst_model = Wav2Vec2ForCTC.from_pretrained("./models/sst_model")
    sst_processor = Wav2Vec2Processor.from_pretrained("./models/sst_model")
else:
    sst_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
    sst_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
sst_model.eval()  # inference only: disable dropout etc.

# --- Local GGUF LLM via llama.cpp; no remote fallback, so fail fast. ---
if os.path.exists("./models/llama.gguf"):
    llm = Llama("./models/llama.gguf")
else:
    raise FileNotFoundError("Please upload llama.gguf to models/ directory")
# Request models (unchanged)
class TTSRequest(BaseModel):
    # Request body for POST /tts.
    text: str  # plain text to synthesize into speech
class LLMRequest(BaseModel):
    # Request body for POST /llm.
    prompt: str  # prompt passed verbatim to the llama.cpp model
@app.post("/tts")
async def tts_endpoint(request: TTSRequest):
    """Synthesize speech for the request text and return it as a WAV response.

    Tokenizes the text, generates a waveform with the TTS model under
    no-grad, and streams it back as 22050 Hz WAV bytes.
    """
    token_batch = tts_tokenizer(request.text, return_tensors="pt")
    with torch.no_grad():
        generated = tts_model.generate(**token_batch)
    samples = generated.squeeze().cpu().numpy()
    wav_buffer = io.BytesIO()
    sf.write(wav_buffer, samples, 22050, format="WAV")
    wav_buffer.seek(0)
    return Response(content=wav_buffer.getvalue(), media_type="audio/wav")
# SST and LLM endpoints remain unchanged
@app.post("/sst")
async def sst_endpoint(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file (speech-to-text).

    Reads the whole upload into memory, decodes it with soundfile, runs
    wav2vec2 CTC greedy decoding, and returns ``{"text": transcription}``.
    """
    audio_bytes = await file.read()
    audio, sr = sf.read(io.BytesIO(audio_bytes))
    # Robustness fix: soundfile returns shape (frames, channels) for
    # multi-channel input, but wav2vec2 expects a mono 1-D waveform —
    # downmix by averaging the channels.
    if audio.ndim > 1:
        audio = audio.mean(axis=1)
    # NOTE(review): wav2vec2-base-960h expects 16 kHz audio; the processor
    # will reject other rates — consider resampling before this call.
    inputs = sst_processor(audio, sampling_rate=sr, return_tensors="pt")
    with torch.no_grad():
        logits = sst_model(inputs.input_values).logits
    # Greedy CTC decode: most likely token per frame, then collapse.
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = sst_processor.batch_decode(predicted_ids)[0]
    return {"text": transcription}
@app.post("/llm")
async def llm_endpoint(request: LLMRequest):
    """Generate up to 50 tokens of completion for the prompt via llama.cpp."""
    completion = llm(request.prompt, max_tokens=50)
    first_choice = completion["choices"][0]
    return {"text": first_choice["text"]}