"""Gradio app for Indic-language translation using AI4Bharat's IndicTrans2.

Accepts raw text or an uploaded UTF-8 text file, splits it into
sentence-aligned batches, translates batch-by-batch, and returns both the
translated text and a downloadable file containing it.
"""
import tempfile

import gradio as gr
import nltk
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from IndicTransToolkit import IndicProcessor

# Sentence-tokenizer data required by nltk.sent_tokenize (no-op if cached).
nltk.download('punkt_tab')

# Load IndicTrans2 model
model_name = "ai4bharat/indictrans2-indic-indic-dist-320M"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
ip = IndicProcessor(inference=True)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model.to(DEVICE)
model.eval()  # inference-only app; ensure dropout is disabled


def split_text_into_batches(text, max_tokens_per_batch):
    """Greedily pack whole sentences into batches under a size budget.

    NOTE(review): despite the parameter name, the budget is measured in
    *characters* (length of the accumulated string), not model tokens —
    confirm whether a token-based budget was intended. A single sentence
    longer than the budget still becomes its own batch.

    Args:
        text: Input text to split.
        max_tokens_per_batch: Approximate per-batch character budget.

    Returns:
        List of batch strings, each holding one or more whole sentences.
    """
    sentences = nltk.sent_tokenize(text)  # Tokenize text into sentences
    batches = []
    current_batch = ""
    for sentence in sentences:
        # +1 accounts for the joining space appended below.
        if len(current_batch) + len(sentence) + 1 <= max_tokens_per_batch:
            current_batch += sentence + " "  # Add sentence to current batch
        else:
            # FIX: only flush a non-empty batch — previously, a first
            # sentence longer than the budget appended an empty string,
            # which was then sent through the translation pipeline.
            if current_batch:
                batches.append(current_batch.strip())
            current_batch = sentence + " "  # Start a new batch
    if current_batch:
        batches.append(current_batch.strip())  # Add the last batch
    return batches


def run_translation(file_uploader, input_text, source_language, target_language):
    """Translate text (or an uploaded file) between the supported languages.

    Args:
        file_uploader: Optional Gradio file upload; when present, its UTF-8
            contents replace ``input_text``.
        input_text: Text to translate when no file is uploaded.
        source_language: Human-readable source-language name (dropdown value);
            must be a key of the language map below.
        target_language: Human-readable target-language name (dropdown value).

    Returns:
        Tuple of (translated text, path to a downloadable .txt file).
    """
    if file_uploader is not None:
        with open(file_uploader.name, "r", encoding="utf-8") as file:
            input_text = file.read()

    # Language mapping: UI name -> IndicTrans2 (FLORES-style) language code.
    lang_code_map = {
        "Hindi": "hin_Deva",
        "Punjabi": "pan_Guru",
        "English": "eng_Latn",
    }
    src_lang = lang_code_map[source_language]
    tgt_lang = lang_code_map[target_language]

    max_tokens_per_batch = 256
    batches = split_text_into_batches(input_text, max_tokens_per_batch)

    translated_parts = []
    for batch in batches:
        batch_preprocessed = ip.preprocess_batch(
            [batch], src_lang=src_lang, tgt_lang=tgt_lang
        )
        inputs = tokenizer(
            batch_preprocessed,
            truncation=True,
            padding="longest",
            return_tensors="pt",
            return_attention_mask=True,
        ).to(DEVICE)
        with torch.no_grad():
            generated_tokens = model.generate(
                **inputs,
                use_cache=True,
                min_length=0,
                max_length=256,
                num_beams=5,
                num_return_sequences=1,
            )
        # Decode under target-tokenizer settings, then let IndicProcessor
        # restore script/entity normalization for the target language.
        with tokenizer.as_target_tokenizer():
            decoded_tokens = tokenizer.batch_decode(
                generated_tokens.detach().cpu().tolist(),
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )
        translations = ip.postprocess_batch(decoded_tokens, lang=tgt_lang)
        translated_parts.append(" ".join(translations))

    output = " ".join(translated_parts).strip()

    # FIX: write each result to its own temp file instead of a shared
    # "result.txt" so concurrent Gradio sessions cannot clobber each
    # other's downloads.
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".txt", delete=False, encoding="utf-8"
    ) as out_file:
        out_file.write(output)
        output_path = out_file.name
    return output, output_path


# Define Gradio UI
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            file_uploader = gr.File(label="Upload a text file (Optional)")
            input_text = gr.Textbox(
                label="Input text", lines=5, placeholder="Enter text here..."
            )
            source_language = gr.Dropdown(
                label="Source language",
                choices=["Hindi", "Punjabi", "English"],
                value="Hindi",
            )
            target_language = gr.Dropdown(
                label="Target language",
                choices=["Hindi", "Punjabi", "English"],
                value="English",
            )
            btn = gr.Button("Translate")
        with gr.Column():
            output_text = gr.Textbox(label="Translated text", lines=5)
            output_file = gr.File(label="Translated text file")
    btn.click(
        fn=run_translation,
        inputs=[file_uploader, input_text, source_language, target_language],
        outputs=[output_text, output_file],
    )

if __name__ == "__main__":
    demo.launch(debug=True)