import os

# Set the Hugging Face cache to a writable directory *before* importing
# transformers; the cache location is resolved when the library loads.
os.environ["HF_HOME"] = "/app/cache"
os.environ["TRANSFORMERS_CACHE"] = "/app/cache"

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from IndicTransToolkit import IndicProcessor
from typing import List
MODEL_NAME = "ai4bharat/indictrans2-en-indic-1B"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load the tokenizer alongside the model; both need trust_remote_code
# because IndicTrans2 ships custom model/tokenizer code.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, trust_remote_code=True).to(DEVICE)
model.eval()

# IndicProcessor handles pre/postprocessing (language tagging, normalization,
# entity replacement) around the model.
ip = IndicProcessor(inference=True)
app = FastAPI()
# Define the request body with Pydantic
class InputData(BaseModel):
    sentences: List[str]
    target_lang: str
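# Example payload (values are illustrative): the model translates from
# English, so target_lang should be a FLORES-style code such as "hin_Deva"
# (Hindi, Devanagari script):
#   {"sentences": ["How are you?"], "target_lang": "hin_Deva"}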
# API endpoint to receive input and return predictions
@app.post("/translate/")
async def predict(input_data: InputData):
    try:
        src_lang, tgt_lang = "eng_Latn", input_data.target_lang
        # Tag and normalize the batch for IndicTrans2
        batch = ip.preprocess_batch(
            input_data.sentences,
            src_lang=src_lang,
            tgt_lang=tgt_lang,
        )
        # Tokenize the sentences and generate input encodings
        inputs = tokenizer(
            batch,
            truncation=True,
            padding="longest",
            return_tensors="pt",
            return_attention_mask=True,
        ).to(DEVICE)
        # Generate translations using the model
        with torch.no_grad():
            generated_tokens = model.generate(
                **inputs,
                use_cache=True,
                min_length=0,
                max_length=256,
                num_beams=5,
                num_return_sequences=1,
            )
        # Decode the generated tokens into text
        with tokenizer.as_target_tokenizer():
            generated_tokens = tokenizer.batch_decode(
                generated_tokens.detach().cpu().tolist(),
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )
        # Postprocess the translations, including entity replacement
        translations = ip.postprocess_batch(generated_tokens, lang=tgt_lang)
        return {"output": translations}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
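# --- Usage sketch (host/port are assumptions; adjust to your deployment) ---
# Start the server:
#   uvicorn app:app --host 0.0.0.0 --port 7860
# Then request a translation:
#   curl -X POST http://localhost:7860/translate/ \
#     -H "Content-Type: application/json" \
#     -d '{"sentences": ["How are you?"], "target_lang": "hin_Deva"}'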