import io
import os

import soundfile as sf
import torch
from fastapi import FastAPI, File, Response, UploadFile
from llama_cpp import Llama
from pydantic import BaseModel
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoTokenizer,
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
)

app = FastAPI()

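# All models are loaded once at module import time; the first startup may
# download checkpoints from the Hugging Face Hub if no local copies exist.
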
# Load TTS model
if os.path.exists("./models/tts_model"):
    tts_model = AutoModelForSpeechSeq2Seq.from_pretrained("./models/tts_model")
    tts_tokenizer = AutoTokenizer.from_pretrained("./models/tts_model")
else:
    tts_model = AutoModelForSpeechSeq2Seq.from_pretrained("facebook/tts_transformer-en-ljspeech")
    tts_tokenizer = AutoTokenizer.from_pretrained("facebook/tts_transformer-en-ljspeech")
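# Note: this assumes the checkpoint loads with AutoModelForSpeechSeq2Seq and
# that generate() returns a waveform tensor that can be written out as audio;
# adjust the loading/synthesis calls if your TTS checkpoint exposes a
# different API.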

# Load SST (speech-to-text) model
if os.path.exists("./models/sst_model"):
    sst_model = Wav2Vec2ForCTC.from_pretrained("./models/sst_model")
    sst_processor = Wav2Vec2Processor.from_pretrained("./models/sst_model")
else:
    sst_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
    sst_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
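# wav2vec2-base-960h expects 16 kHz mono audio; uploads recorded at other
# sampling rates should be resampled before transcription.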

# Load LLM (llama.cpp GGUF checkpoint)
if os.path.exists("./models/llama.gguf"):
    llm = Llama("./models/llama.gguf")
else:
    raise FileNotFoundError("Please upload llama.gguf to models/ directory")
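# Note: Llama() is called with defaults here (including the context window);
# pass n_ctx and other parameters explicitly if longer prompts are expected.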

# Request models (unchanged)
class TTSRequest(BaseModel):
    text: str

class LLMRequest(BaseModel):
    prompt: str

@app.post("/tts")
async def tts_endpoint(request: TTSRequest):
    text = request.text
    inputs = tts_tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        audio = tts_model.generate(**inputs)
    audio = audio.squeeze().cpu().numpy()
    buffer = io.BytesIO()
    sf.write(buffer, audio, 22050, format="WAV")
    buffer.seek(0)
    return Response(content=buffer.getvalue(), media_type="audio/wav")

# SST and LLM endpoints remain unchanged
@app.post("/sst")
async def sst_endpoint(file: UploadFile = File(...)):
    audio_bytes = await file.read()
    audio, sr = sf.read(io.BytesIO(audio_bytes))
    inputs = sst_processor(audio, sampling_rate=sr, return_tensors="pt")
    with torch.no_grad():
        logits = sst_model(inputs.input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = sst_processor.batch_decode(predicted_ids)[0]
    return {"text": transcription}

@app.post("/llm")
async def llm_endpoint(request: LLMRequest):
    prompt = request.prompt
    output = llm(prompt, max_tokens=50)
    return {"text": output["choices"][0]["text"]}