# transcriptionV3 / app.py
# Committed by KIMOSSINO - d30be26 (verified): "Rename APP.PY to app.py"
# File size: 3.06 kB
import os
import subprocess
import tempfile
from pathlib import Path

import uvicorn
import whisper
from fastapi import FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from transformers import MarianMTModel, MarianTokenizer
# ASGI application object; every endpoint below registers on it via @app.post.
app: FastAPI = FastAPI()
# Load models
def load_models():
    """Load the Whisper speech model and the MarianMT translation models.

    Side effects: sets the module-level ``whisper_model`` and
    ``translation_models`` globals.

    Returns:
        ``(translation_models, translation_tokenizers)``: two dicts keyed by
        the *source* language code, mapping to the MarianMT model and the
        matching tokenizer for that language.
    """
    global whisper_model, translation_models
    whisper_model = whisper.load_model("base")  # Whisper model

    # One checkpoint per source language; the name encodes the direction
    # (opus-mt-{src}-{tgt}). Model and tokenizer MUST come from the same
    # checkpoint: deriving tokenizer names as "opus-mt-{lang}-en" asked for
    # the non-existent "opus-mt-en-en" for "en" and crashed at startup.
    checkpoints = {
        "en": "Helsinki-NLP/opus-mt-en-es",
        "es": "Helsinki-NLP/opus-mt-es-en",
        "fr": "Helsinki-NLP/opus-mt-fr-en",
        "ar": "Helsinki-NLP/opus-mt-ar-en",
    }
    translation_models = {
        lang: MarianMTModel.from_pretrained(name)
        for lang, name in checkpoints.items()
    }
    translation_tokenizers = {
        lang: MarianTokenizer.from_pretrained(name)
        for lang, name in checkpoints.items()
    }
    return translation_models, translation_tokenizers


translation_models, translation_tokenizers = load_models()
# Whisper endpoint
@app.post("/transcribe")
async def transcribe(file: UploadFile = File(...), language: str = "en"):
    """Transcribe an uploaded audio file with Whisper.

    The upload is written to a uniquely named file under ``temp/``. That file
    is removed afterwards whether or not transcription succeeds.

    Args:
        file: the uploaded audio.
        language: language code passed to ``whisper_model.transcribe``.

    Returns:
        ``{"success": True, "transcription": <text>}`` on success, otherwise
        ``{"success": False, "error": <message>}``.
    """
    temp_path = None
    try:
        Path("temp").mkdir(parents=True, exist_ok=True)
        # Never build the path from the client-supplied filename: "../" in it
        # would escape temp/, and concurrent uploads with the same name would
        # overwrite each other. Only the extension is kept, in case the audio
        # decoder relies on it.
        suffix = Path(file.filename or "").suffix
        fd, temp_path = tempfile.mkstemp(suffix=suffix, dir="temp")
        with os.fdopen(fd, "wb") as f:
            f.write(await file.read())
        # Whisper inference is blocking; run it off the event loop so other
        # requests are still served meanwhile.
        result = await run_in_threadpool(
            whisper_model.transcribe, temp_path, language=language
        )
        return {"success": True, "transcription": result["text"]}
    except Exception as e:
        return {"success": False, "error": str(e)}
    finally:
        # Clean up even when transcription failed.
        if temp_path is not None:
            Path(temp_path).unlink(missing_ok=True)
# Translation endpoint
@app.post("/translate")
async def translate(text: str, source_lang: str, target_lang: str):
    """Translate ``text`` from ``source_lang`` to ``target_lang`` with MarianMT.

    A request is accepted only if the model loaded for ``source_lang``
    actually translates into ``target_lang``. The direction is read from the
    model's checkpoint name (``...opus-mt-{src}-{tgt}``). This rejects e.g.
    en->en, which would otherwise run the en->es model and return Spanish.

    Returns:
        ``{"success": True, "translation": <text>}`` on success, otherwise
        ``{"success": False, "error": <message>}``.
    """
    try:
        model = translation_models.get(source_lang)
        direction = f"-{source_lang}-{target_lang}"
        checkpoint = str(getattr(model, "name_or_path", ""))
        if model is None or not checkpoint.endswith(direction):
            return {"success": False, "error": "Unsupported language."}
        tokenizer = translation_tokenizers[source_lang]

        def _generate():
            # Tokenize, translate, and decode the first (only) sequence.
            inputs = tokenizer(text, return_tensors="pt", padding=True)
            translated_tokens = model.generate(**inputs)
            return tokenizer.decode(translated_tokens[0], skip_special_tokens=True)

        # generate() is blocking model inference; keep it off the event loop.
        translated_text = await run_in_threadpool(_generate)
        return {"success": True, "translation": translated_text}
    except Exception as e:
        return {"success": False, "error": str(e)}
# TTS endpoint
@app.post("/tts")
async def text_to_speech(text: str, speaker: str = "male", speed: str = "normal"):
    """Synthesize ``text`` into ``output.wav`` using the Coqui ``tts`` CLI.

    ``speaker`` and ``speed`` are accepted but currently ignored: the command
    always uses the fixed model named below.

    NOTE(review): every request writes the same ``output.wav`` in the working
    directory, so concurrent requests overwrite each other's audio.

    Returns:
        ``{"success": True, "audio_file": "output.wav"}`` when the CLI exits
        with status 0, otherwise ``{"success": False, "error": <message>}``.
    """
    try:
        output_file = "output.wav"
        # Coqui TTS command. List form with the default shell=False: the text
        # is passed as a single argv entry and never interpreted by a shell.
        tts_command = [
            "tts",
            f"--text={text}",
            "--model_name=tts_models/en/ljspeech/tacotron2-DCA",
            f"--out_path={output_file}",
        ]
        # The subprocess blocks until synthesis finishes; keep it off the loop.
        completed = await run_in_threadpool(
            subprocess.run,
            tts_command,
            capture_output=True,
            text=True,
            errors="replace",
        )
        if completed.returncode != 0:
            # Previously a failed run still reported success; surface the
            # last stderr line (usually the exception message) instead.
            stderr_lines = completed.stderr.strip().splitlines()
            detail = (
                stderr_lines[-1]
                if stderr_lines
                else f"tts exited with status {completed.returncode}"
            )
            return {"success": False, "error": detail}
        return {"success": True, "audio_file": output_file}
    except Exception as e:
        return {"success": False, "error": str(e)}
if __name__ == "__main__":
    # Started directly (python app.py): serve on every interface at port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")