vtiw committed
Commit cc5d0c7 · verified · 1 Parent(s): dc1eeb7

Update app.py

Files changed (1)
  1. app.py +60 -21
app.py CHANGED
@@ -14,19 +14,21 @@ from happytransformer import HappyTextToText, TTSettings
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM,logging
 from transformers.integrations import deepspeed
 import re
+from IndicTransToolkit import IndicProcessor
 import torch
-from lang_list import (
-    LANGUAGE_NAME_TO_CODE,
-    T2TT_TARGET_LANGUAGE_NAMES,
-    TEXT_SOURCE_LANGUAGE_NAMES,
-)
+import torch
+
 logging.set_verbosity_error()
 
 DEFAULT_TARGET_LANGUAGE = "English"
-from transformers import SeamlessM4TForTextToText
-from transformers import AutoProcessor
-model = SeamlessM4TForTextToText.from_pretrained("facebook/hf-seamless-m4t-large")
-processor = AutoProcessor.from_pretrained("facebook/hf-seamless-m4t-large")
+# Load IndicTrans2 model
+model_name = "ai4bharat/indictrans2-indic-indic-dist-320M"
+tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
+ip = IndicProcessor(inference=True)
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+model.to(DEVICE)
+
 
 
 import pytesseract as pt
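
Taken together, this hunk swaps the SeamlessM4T text-to-text stack for AI4Bharat's distilled indic-indic IndicTrans2 checkpoint plus its IndicProcessor pre/post-processor. As a condensed sketch of how the objects loaded here compose into a single-sentence translation — using only calls that appear later in this commit; the helper name and sample sentence are illustrative, not part of the repo:

def translate_one(text: str, src_lang: str, tgt_lang: str) -> str:
    # Add the source/target language tags IndicTrans2 expects.
    batch = ip.preprocess_batch([text], src_lang=src_lang, tgt_lang=tgt_lang)
    inputs = tokenizer(batch, truncation=True, padding="longest",
                       return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        tokens = model.generate(**inputs, max_length=256, num_beams=5)
    with tokenizer.as_target_tokenizer():
        decoded = tokenizer.batch_decode(tokens.detach().cpu().tolist(),
                                         skip_special_tokens=True)
    # Undo the tag preprocessing and restore the target script.
    return ip.postprocess_batch(decoded, lang=tgt_lang)[0]

print(translate_one("यह एक परीक्षण वाक्य है।", "hin_Deva", "pan_Guru"))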
@@ -174,22 +176,59 @@ def split_text_into_batches(text, max_tokens_per_batch):
 @spaces.GPU(duration=60)
 def run_t2tt(file_uploader, input_text: str, source_language: str, target_language: str) -> (str, bytes):
     if file_uploader is not None:
-        with open(file_uploader, 'r') as file:
-            input_text=file.read()
-    source_language_code = LANGUAGE_NAME_TO_CODE[source_language]
-    target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
-    max_tokens_per_batch= 2048
+        with open(file_uploader.name, "r", encoding="utf-8") as file:
+            input_text = file.read()
+
+    # Language mapping
+    lang_code_map = {
+        "Hindi": "hin_Deva",
+        "Punjabi": "pan_Guru",
+        "English": "eng_Latn",
+    }
+
+    src_lang = lang_code_map[source_language]
+    tgt_lang = lang_code_map[target_language]
+
+    max_tokens_per_batch = 256
     batches = split_text_into_batches(input_text, max_tokens_per_batch)
     translated_text = ""
+
     for batch in batches:
-        text_inputs = processor(text=batch, src_lang=source_language_code, return_tensors="pt")
-        output_tokens = model.generate(**text_inputs, tgt_lang=target_language_code)
-        translated_batch = processor.decode(output_tokens[0].tolist(), skip_special_tokens=True)
-        translated_text += translated_batch + " "
-    output=translated_text.strip()
+        batch_preprocessed = ip.preprocess_batch([batch], src_lang=src_lang, tgt_lang=tgt_lang)
+        inputs = tokenizer(
+            batch_preprocessed,
+            truncation=True,
+            padding="longest",
+            return_tensors="pt",
+            return_attention_mask=True,
+        ).to(DEVICE)
+
+        with torch.no_grad():
+            generated_tokens = model.generate(
+                **inputs,
+                use_cache=True,
+                min_length=0,
+                max_length=256,
+                num_beams=5,
+                num_return_sequences=1,
+            )
+
+        with tokenizer.as_target_tokenizer():
+            decoded_tokens = tokenizer.batch_decode(
+                generated_tokens.detach().cpu().tolist(),
+                skip_special_tokens=True,
+                clean_up_tokenization_spaces=True,
+            )
+
+        translations = ip.postprocess_batch(decoded_tokens, lang=tgt_lang)
+        translated_text += " ".join(translations) + " "
+
+    output = translated_text.strip()
     _output_name = "result.txt"
-    open(_output_name, 'w').write(output)
-    return str(output), _output_name
+    with open(_output_name, "w", encoding="utf-8") as out_file:
+        out_file.write(output)
+
+    return output, _output_name
 
 with gr.Blocks() as demo_t2tt:
     with gr.Row():
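
The hunk header above references split_text_into_batches, which this commit leaves untouched. For readers without the full app.py, the contract the new code relies on is: split input_text into chunks whose token counts stay under max_tokens_per_batch (now 256, matching the generate max_length). A hypothetical sketch of such a splitter — an assumption for illustration, not the app's actual implementation:

def split_text_into_batches(text: str, max_tokens_per_batch: int) -> list[str]:
    # Hypothetical sketch; the real implementation in app.py is not shown in this diff.
    batches, current, count = [], [], 0
    # Split on sentence enders, including the Devanagari danda.
    for sentence in re.split(r"(?<=[.!?।])\s+", text):
        n = len(tokenizer.tokenize(sentence))  # assumed per-sentence token count
        if current and count + n > max_tokens_per_batch:
            batches.append(" ".join(current))
            current, count = [], 0
        current.append(sentence)
        count += n
    if current:
        batches.append(" ".join(current))
    return batches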
 
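Finally, an illustrative call of the updated run_t2tt (assuming no file upload, so input_text is used directly; "Hindi" and "Punjabi" are keys of the new lang_code_map, and the sample sentence is a placeholder):

translated, result_path = run_t2tt(None, "यह एक परीक्षण वाक्य है।", "Hindi", "Punjabi")
print(translated)    # Punjabi translation of the input
print(result_path)   # "result.txt", also written to disk by the function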