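"""FastAPI service exposing text-to-speech (TTS), speech-to-text (STT),
and LLM completion endpoints backed by local or Hub-hosted checkpoints."""
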
from transformers import AutoModelForSpeechSeq2Seq, AutoTokenizer
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from llama_cpp import Llama
import torch
import soundfile as sf
import io
import os
from pydantic import BaseModel
from fastapi import FastAPI, File, UploadFile, Response

app = FastAPI()

# Load TTS model: prefer a local checkpoint under ./models, otherwise pull from the Hub.
# NOTE: facebook/tts_transformer-en-ljspeech is published as a fairseq checkpoint;
# loading it via AutoModelForSpeechSeq2Seq assumes a transformers-compatible export.
if os.path.exists("./models/tts_model"):
    tts_model = AutoModelForSpeechSeq2Seq.from_pretrained("./models/tts_model")
    tts_tokenizer = AutoTokenizer.from_pretrained("./models/tts_model")
else:
    tts_model = AutoModelForSpeechSeq2Seq.from_pretrained("facebook/tts_transformer-en-ljspeech")
    tts_tokenizer = AutoTokenizer.from_pretrained("facebook/tts_transformer-en-ljspeech")

# Load STT model
if os.path.exists("./models/stt_model"):
    stt_model = Wav2Vec2ForCTC.from_pretrained("./models/stt_model")
    stt_processor = Wav2Vec2Processor.from_pretrained("./models/stt_model")
else:
    stt_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
    stt_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")

# Load LLM (GGUF weights served through llama.cpp)
if os.path.exists("./models/llama.gguf"):
    llm = Llama("./models/llama.gguf")
else:
    raise FileNotFoundError("Please upload llama.gguf to the models/ directory")

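# Expected models/ layout (inferred from the paths above):
#   models/tts_model/    optional local TTS checkpoint (otherwise fetched from the Hub)
#   models/stt_model/    optional local wav2vec2 checkpoint (otherwise fetched from the Hub)
#   models/llama.gguf    required GGUF weights for llama.cpp (no fallback)
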
# Request bodies
class TTSRequest(BaseModel):
    text: str

class LLMRequest(BaseModel):
    prompt: str

@app.post("/tts")
async def tts_endpoint(request: TTSRequest):
text = request.text
inputs = tts_tokenizer(text, return_tensors="pt")
with torch.no_grad():
audio = tts_model.generate(**inputs)
audio = audio.squeeze().cpu().numpy()
buffer = io.BytesIO()
sf.write(buffer, audio, 22050, format="WAV")
buffer.seek(0)
return Response(content=buffer.getvalue(), media_type="audio/wav")
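# Example request (assuming the server runs on localhost:8000):
#   curl -X POST http://localhost:8000/tts \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Hello, world."}' --output hello.wav
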
@app.post("/stt")
async def stt_endpoint(file: UploadFile = File(...)):
    audio_bytes = await file.read()
    audio, sr = sf.read(io.BytesIO(audio_bytes))
    if audio.ndim > 1:
        audio = audio.mean(axis=1)  # downmix stereo to mono; wav2vec2 expects 1-D input
    # wav2vec2-base-960h was trained on 16 kHz audio; the processor rejects other
    # sample rates, so resample uploads to 16 kHz before calling this endpoint
    inputs = stt_processor(audio, sampling_rate=sr, return_tensors="pt")
    with torch.no_grad():
        logits = stt_model(inputs.input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = stt_processor.batch_decode(predicted_ids)[0]
    return {"text": transcription}

@app.post("/llm")
async def llm_endpoint(request: LLMRequest):
prompt = request.prompt
output = llm(prompt, max_tokens=50)
return {"text": output["choices"][0]["text"]} |