akshatOP commited on
Commit
81875a2
·
1 Parent(s): fa48fc0

Switch to facebook/tts_transformer-en-ljspeech for TTS

Browse files
Files changed (1) hide show
  1. app.py +11 -10
app.py CHANGED
@@ -1,5 +1,4 @@
1
- from fastapi import FastAPI, File, UploadFile, Response
2
- from transformers import ParlerTTSForConditionalGeneration, AutoTokenizer
3
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
4
  from llama_cpp import Llama
5
  import torch
@@ -7,31 +6,32 @@ import soundfile as sf
7
  import io
8
  import os
9
  from pydantic import BaseModel
10
-
11
  app = FastAPI()
12
 
13
- # Load models
14
  if os.path.exists("./models/tts_model"):
15
- tts_model = ParlerTTSForConditionalGeneration.from_pretrained("./models/tts_model")
16
  tts_tokenizer = AutoTokenizer.from_pretrained("./models/tts_model")
17
  else:
18
- tts_model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-mini-v1")
19
- tts_tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-v1")
20
 
21
- # SST and LLM loading remains unchanged
22
  if os.path.exists("./models/sst_model"):
23
  sst_model = Wav2Vec2ForCTC.from_pretrained("./models/sst_model")
24
  sst_processor = Wav2Vec2Processor.from_pretrained("./models/sst_model")
25
  else:
26
  sst_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
27
- sst_processor = Wav2Vec2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
28
 
 
29
  if os.path.exists("./models/llama.gguf"):
30
  llm = Llama("./models/llama.gguf")
31
  else:
32
  raise FileNotFoundError("Please upload llama.gguf to models/ directory")
33
 
34
- # Request models and endpoints remain unchanged
35
  class TTSRequest(BaseModel):
36
  text: str
37
 
@@ -50,6 +50,7 @@ async def tts_endpoint(request: TTSRequest):
50
  buffer.seek(0)
51
  return Response(content=buffer.getvalue(), media_type="audio/wav")
52
 
 
53
  @app.post("/sst")
54
  async def sst_endpoint(file: UploadFile = File(...)):
55
  audio_bytes = await file.read()
 
1
+ from transformers import AutoModelForSpeechSeq2Seq, AutoTokenizer
 
2
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
3
  from llama_cpp import Llama
4
  import torch
 
6
  import io
7
  import os
8
  from pydantic import BaseModel
9
+ from fastapi import FastAPI, File, UploadFile, Response
10
  app = FastAPI()
11
 
12
+ # Load TTS model
13
  if os.path.exists("./models/tts_model"):
14
+ tts_model = AutoModelForSpeechSeq2Seq.from_pretrained("./models/tts_model")
15
  tts_tokenizer = AutoTokenizer.from_pretrained("./models/tts_model")
16
  else:
17
+ tts_model = AutoModelForSpeechSeq2Seq.from_pretrained("facebook/tts_transformer-en-ljspeech")
18
+ tts_tokenizer = AutoTokenizer.from_pretrained("facebook/tts_transformer-en-ljspeech")
19
 
20
+ # Load SST model
21
  if os.path.exists("./models/sst_model"):
22
  sst_model = Wav2Vec2ForCTC.from_pretrained("./models/sst_model")
23
  sst_processor = Wav2Vec2Processor.from_pretrained("./models/sst_model")
24
  else:
25
  sst_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
26
+ sst_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
27
 
28
+ # Load LLM model
29
  if os.path.exists("./models/llama.gguf"):
30
  llm = Llama("./models/llama.gguf")
31
  else:
32
  raise FileNotFoundError("Please upload llama.gguf to models/ directory")
33
 
34
+ # Request models (unchanged)
35
  class TTSRequest(BaseModel):
36
  text: str
37
 
 
50
  buffer.seek(0)
51
  return Response(content=buffer.getvalue(), media_type="audio/wav")
52
 
53
+ # SST and LLM endpoints remain unchanged
54
  @app.post("/sst")
55
  async def sst_endpoint(file: UploadFile = File(...)):
56
  audio_bytes = await file.read()