from transformers import AutoModelForSpeechSeq2Seq, AutoTokenizer
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from llama_cpp import Llama
import torch
import soundfile as sf
import io
import os
from pydantic import BaseModel
from fastapi import FastAPI, File, UploadFile, Response

app = FastAPI()

# Load TTS model: prefer a locally fine-tuned checkpoint, fall back to the hub model.
if os.path.exists("./models/tts_model"):
    tts_model = AutoModelForSpeechSeq2Seq.from_pretrained("./models/tts_model")
    tts_tokenizer = AutoTokenizer.from_pretrained("./models/tts_model")
else:
    tts_model = AutoModelForSpeechSeq2Seq.from_pretrained("facebook/tts_transformer-en-ljspeech")
    tts_tokenizer = AutoTokenizer.from_pretrained("facebook/tts_transformer-en-ljspeech")

# Load SST (speech-to-text) model: same local-first pattern.
if os.path.exists("./models/sst_model"):
    sst_model = Wav2Vec2ForCTC.from_pretrained("./models/sst_model")
    sst_processor = Wav2Vec2Processor.from_pretrained("./models/sst_model")
else:
    sst_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
    sst_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")

# Load LLM: there is no hub fallback here, so the GGUF file must exist locally.
if os.path.exists("./models/llama.gguf"):
    llm = Llama("./models/llama.gguf")
else:
    raise FileNotFoundError("Place llama.gguf in the models/ directory")

# Request models (unchanged)
class TTSRequest(BaseModel):
    text: str


class LLMRequest(BaseModel):
    prompt: str


@app.post("/tts")
async def tts_endpoint(request: TTSRequest):
    """Synthesize speech from text and return it as a WAV response."""
    text = request.text
    inputs = tts_tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        audio = tts_model.generate(**inputs)
    audio = audio.squeeze().cpu().numpy()
    buffer = io.BytesIO()
    # 22050 Hz matches the native sampling rate of the LJSpeech-trained model.
    sf.write(buffer, audio, 22050, format="WAV")
    buffer.seek(0)
    return Response(content=buffer.getvalue(), media_type="audio/wav")


# SST and LLM endpoints remain unchanged
@app.post("/sst")
async def sst_endpoint(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file with wav2vec2 CTC decoding."""
    audio_bytes = await file.read()
    audio, sr = sf.read(io.BytesIO(audio_bytes))
    # wav2vec2-base-960h expects 16 kHz mono audio; average stereo channels to mono.
    if audio.ndim > 1:
        audio = audio.mean(axis=1)
    inputs = sst_processor(audio, sampling_rate=sr, return_tensors="pt")
    with torch.no_grad():
        logits = sst_model(inputs.input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = sst_processor.batch_decode(predicted_ids)[0]
    return {"text": transcription}


@app.post("/llm")
async def llm_endpoint(request: LLMRequest):
    """Run a single completion against the local GGUF model."""
    prompt = request.prompt
    output = llm(prompt, max_tokens=50)
    return {"text": output["choices"][0]["text"]}
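

# --- Usage sketch: a minimal way to run and exercise the three endpoints above,
# assuming this file is saved as main.py and uvicorn is installed (the file name,
# port, and example payloads below are illustrative assumptions).
#
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
#   curl -X POST http://localhost:8000/tts -H "Content-Type: application/json" \
#        -d '{"text": "Hello there"}' --output reply.wav
#   curl -X POST http://localhost:8000/sst -F "file=@reply.wav"
#   curl -X POST http://localhost:8000/llm -H "Content-Type: application/json" \
#        -d '{"prompt": "Say hello in one sentence."}'

if __name__ == "__main__":
    # Optional entry point so the server can also be started with `python main.py`.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)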