EzekielMW committed
Commit 1558a49 · verified · 1 Parent(s): a51e729

Create app.py

Files changed (1)
  1. app.py +192 -0
app.py ADDED
@@ -0,0 +1,192 @@
+ from fastapi import FastAPI, Request
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+ from fastapi.middleware.cors import CORSMiddleware
+ import torch
+ import os
+ import yaml
+ import transformers
+
+ app = FastAPI()
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],  # Adjust this as needed
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # Load the model and tokenizer
+ tokenizer = AutoTokenizer.from_pretrained("EzekielMW/Eksl_dataset")
+ model = AutoModelForSeq2SeqLM.from_pretrained("EzekielMW/Eksl_dataset")
+
+ # Where should output files be stored locally?
+ drive_folder = "./serverlogs"
+
+ if not os.path.exists(drive_folder):
+     os.makedirs(drive_folder)
+
+
+ # Large batch sizes generally give good results for translation
+ effective_train_batch_size = 480
+ train_batch_size = 6
+ eval_batch_size = train_batch_size
+
+ gradient_accumulation_steps = int(effective_train_batch_size / train_batch_size)
+
+ # Everything in one yaml string, so that it can all be logged.
+ yaml_config = '''
+ training_args:
+     output_dir: "{drive_folder}"
+     eval_strategy: steps
+     eval_steps: 100
+     save_steps: 100
+     gradient_accumulation_steps: {gradient_accumulation_steps}
+     learning_rate: 3.0e-4  # Include decimal point to parse as float
+     # optim: adafactor
+     per_device_train_batch_size: {train_batch_size}
+     per_device_eval_batch_size: {eval_batch_size}
+     weight_decay: 0.01
+     save_total_limit: 3
+     max_steps: 500
+     predict_with_generate: True
+     fp16: True
+     logging_dir: "{drive_folder}"
+     load_best_model_at_end: True
+     metric_for_best_model: loss
+     seed: 123
+     push_to_hub: False
+
+ max_input_length: 128
+ eval_pretrained_model: False
+ early_stopping_patience: 4
+ data_dir: .
+
+ # Use a 600M parameter model here, which is easier to train on a free Colab
+ # instance. Bigger models work better, however: results will be improved
+ # if able to train on nllb-200-1.3B instead.
+ model_checkpoint: facebook/nllb-200-distilled-600M
+
+ datasets:
+     train:
+         huggingface_load:
+             # We will load two datasets here: English/KSL Gloss, and also SALT
+             # Swahili/English, so that we can try out multi-way translation.
+
+             - path: EzekielMW/Eksl_dataset
+               split: train[:-1000]
+             - path: sunbird/salt
+               name: text-all
+               split: train
+         source:
+             # This is a text translation only, no audio.
+             type: text
+             # The source text can be any of English, KSL or Swahili.
+             language: [eng,ksl,swa]
+             preprocessing:
+                 # The models are case sensitive, so if the training text is all
+                 # capitals, then it will only learn to translate capital letters and
+                 # won't understand lower case. Make everything lower case for now.
+                 - lower_case
+                 # We can also augment the spelling of the input text, which makes the
+                 # model more robust to spelling errors.
+                 - augment_characters
+         target:
+             type: text
+             # The target text can be any of English, KSL or Swahili.
+             language: [eng,ksl,swa]
+             # The models are case sensitive: make everything lower case for now.
+             preprocessing:
+                 - lower_case
+
+         shuffle: True
+         allow_same_src_and_tgt_language: False
+
+     validation:
+         huggingface_load:
+             # Use the last 1000 of the KSL examples for validation.
+             - path: EzekielMW/Eksl_dataset
+               split: train[-1000:]
+             # Add some Swahili validation text.
+             - path: sunbird/salt
+               name: text-all
+               split: dev
+         source:
+             type: text
+             language: [swa,ksl,eng]
+             preprocessing:
+                 - lower_case
+         target:
+             type: text
+             language: [swa,ksl,eng]
+             preprocessing:
+                 - lower_case
+         allow_same_src_and_tgt_language: False
+ '''
+
+ yaml_config = yaml_config.format(
+     drive_folder=drive_folder,
+     train_batch_size=train_batch_size,
+     eval_batch_size=eval_batch_size,
+     gradient_accumulation_steps=gradient_accumulation_steps,
+ )
+
+ config = yaml.safe_load(yaml_config)
+
+ training_settings = transformers.Seq2SeqTrainingArguments(
+     **config["training_args"])
+
+ # The pre-trained model that we use has support for some African languages, but
+ # we need to adapt the tokenizer to languages that it wasn't trained with,
+ # such as KSL. Here we reuse the token from a different language.
+ LANGUAGE_CODES = ["eng", "swa", "ksl"]
+
+ code_mapping = {
+     # Exact/close mapping
+     'eng': 'eng_Latn',
+     'swa': 'swh_Latn',
+     # Random mapping
+     'ksl': 'ace_Latn',
+ }
+ tokenizer = transformers.NllbTokenizer.from_pretrained(
+     config['model_checkpoint'],
+     src_lang='eng_Latn',
+     tgt_lang='eng_Latn')
+
+ offset = tokenizer.sp_model_size + tokenizer.fairseq_offset
+
+ for code in LANGUAGE_CODES:
+     i = tokenizer.convert_tokens_to_ids(code_mapping[code])
+     tokenizer._added_tokens_encoder[code] = i
+
+ # Define a translation function
+ def translate(text, source_language, target_language):
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     inputs = tokenizer(text.lower(), return_tensors="pt").to(device)
+     inputs['input_ids'][0][0] = tokenizer.convert_tokens_to_ids(source_language)
+     translated_tokens = model.to(device).generate(
+         **inputs,
+         forced_bos_token_id=tokenizer.convert_tokens_to_ids(target_language),
+         max_length=100,
+         num_beams=5,
+     )
+     result = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
+
+     if target_language == 'ksl':
+         result = result.upper()
+
+     return result
+
+ @app.post("/translate")
+ async def translate_text(request: Request):
+     data = await request.json()
+     text = data.get("text")
+     source_language = data.get("source_language")
+     target_language = data.get("target_language")
+
+     translation = translate(text, source_language, target_language)
+     return {"translation": translation}
+
+
+
+
+
+
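
For reference, a minimal client sketch for the /translate endpoint added above. It assumes the app is served with uvicorn (for example: uvicorn app:app --host 0.0.0.0 --port 7860); the host, port, and example sentence are illustrative and not part of this commit.

import requests  # assumed client-side dependency, not required by app.py

# The endpoint expects JSON with "text", "source_language" and "target_language";
# the language codes are the ones registered on the tokenizer: "eng", "swa", "ksl".
response = requests.post(
    "http://localhost:7860/translate",   # assumed local URL for the running app
    json={
        "text": "How are you today?",    # illustrative input
        "source_language": "eng",
        "target_language": "ksl",
    },
)
print(response.json())  # {"translation": "..."}

Note that when the target language is "ksl", the server upper-cases the model output before returning it (see the translate function above), so the returned gloss will be in capitals even though inputs are lower-cased for the model.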