cstr commited on
Commit
0bb0d8e
·
verified ·
1 Parent(s): 95d86d0

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +601 -0
app.py ADDED
@@ -0,0 +1,601 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #python app.py
2
+ import gradio as gr
3
+ import os
4
+ import pandas as pd
5
+ import requests
6
+ from pathlib import Path
7
+ import ctranslate2
8
+ import time
9
+ import logging
10
+ import transformers
11
+ import json
12
+ from tqdm import tqdm
13
+ import subprocess
14
+ from huggingface_hub import snapshot_download, upload_file
15
+
16
+ # Function to download a Parquet file from a specified URL
17
+ def download_parquet(url, local_path):
18
+ response = requests.get(url, stream=True)
19
+ if response.status_code == 200:
20
+ with open(local_path, 'wb') as file:
21
+ for chunk in response.iter_content(chunk_size=1024):
22
+ file.write(chunk)
23
+ print("File downloaded successfully.")
24
+ else:
25
+ print(f"Failed to download file, status code: {response.status_code}")
26
+
27
+ # Function to convert Parquet files to JSONL format
28
+ def convert_parquet_to_jsonl_polars(input_file, output_dir, override=False):
29
+ output_dir_path = Path(output_dir)
30
+ output_dir_path.mkdir(parents=True, exist_ok=True)
31
+
32
+ input_path = Path(input_file)
33
+ output_file_path = output_dir_path / input_path.with_suffix(".jsonl").name
34
+
35
+ if output_file_path.exists() and not override:
36
+ print(f"Skipping because output exists already: {output_file_path}")
37
+ else:
38
+ df = pl.read_parquet(input_path)
39
+ df.write_ndjson(output_file_path)
40
+ print(f"Data written to {output_file_path}")
41
+
42
+ def convert_parquet_to_jsonl(parquet_filename, jsonl_filename):
43
+ # Read the parquet file
44
+ df = pd.read_parquet(parquet_filename)
45
+
46
+ # Convert the dataframe to a JSON string and handle Unicode characters and forward slashes
47
+ json_str = df.to_json(orient='records', lines=True, force_ascii=False)
48
+
49
+ # Replace escaped forward slashes if needed
50
+ json_str = json_str.replace('\\/', '/')
51
+
52
+ # Write the modified JSON string to the JSONL file
53
+ with open(jsonl_filename, 'w', encoding='utf-8') as file:
54
+ file.write(json_str)
55
+
56
+ print(f"Data saved to {jsonl_filename}")
57
+
58
+ # Function to count lines in a JSONL file
59
+ def count_lines_in_jsonl(file_path):
60
+ with open(file_path, 'r', encoding='utf-8') as file:
61
+ line_count = sum(1 for _ in file)
62
+ return line_count
63
+
64
+ def parse_range_specification(range_specification, file_length):
65
+ line_indices = []
66
+ ranges = range_specification.split(',')
67
+ for r in ranges:
68
+ if '-' in r:
69
+ parts = r.split('-')
70
+ start = int(parts[0]) - 1 if parts[0] else 0
71
+ end = int(parts[1]) - 1 if parts[1] else file_length - 1
72
+ if start < 0 or end >= file_length:
73
+ logging.error(f"Range {r} is out of bounds.")
74
+ continue # Skip ranges that are out of bounds
75
+ line_indices.extend(range(start, end + 1))
76
+ else:
77
+ single_line = int(r) - 1
78
+ if single_line < 0 or single_line >= file_length:
79
+ logging.error(f"Line number {r} is out of bounds.")
80
+ continue # Skip line numbers that are out of bounds
81
+ line_indices.append(single_line)
82
+ return line_indices
83
+
84
+ def translate_text(text, translator, tokenizer):
85
+ """
86
+ Translates the given text from English to German using CTranslate2 and the WMT21 model,
87
+ with special handling for newlines and segmenting text longer than 500 characters.
88
+ Ensures sequences of newlines (\n\n, \n\n\n, etc.) are accurately reproduced.
89
+ """
90
+ try:
91
+ segments = []
92
+ newline_sequences = [] # To store sequences of newlines
93
+ segment = ""
94
+
95
+ i = 0
96
+ while i < len(text):
97
+ # Collect sequences of newlines
98
+ if text[i] == '\n':
99
+ newline_sequence = '\n'
100
+ while i + 1 < len(text) and text[i + 1] == '\n':
101
+ newline_sequence += '\n'
102
+ i += 1
103
+ if segment:
104
+ segments.append(segment) # Add the preceding text segment
105
+ segment = ""
106
+ newline_sequences.append(newline_sequence) # Store the newline sequence
107
+ else:
108
+ segment += text[i]
109
+ # If segment exceeds 500 characters, or if we reach the end of the text, process it
110
+ if len(segment) >= 500 or i == len(text) - 1:
111
+ end_index = max(segment.rfind('.', 0, 500), segment.rfind('?', 0, 500), segment.rfind('!', 0, 500))
112
+ if end_index != -1 and len(segment) > 500:
113
+ # Split at the last punctuation within the first 500 characters
114
+ segments.append(segment[:end_index+1])
115
+ segment = segment[end_index+1:].lstrip()
116
+ else:
117
+ # No suitable punctuation or end of text, add the whole segment
118
+ segments.append(segment)
119
+ segment = ""
120
+ i += 1
121
+
122
+ # Translate the collected text segments
123
+ translated_segments = []
124
+ for segment in segments:
125
+ source = tokenizer.convert_ids_to_tokens(tokenizer.encode(segment))
126
+ target_prefix = [tokenizer.lang_code_to_token["de"]]
127
+ results = translator.translate_batch([source], target_prefix=[target_prefix])
128
+ target = results[0].hypotheses[0][1:]
129
+ translated_segment = tokenizer.decode(tokenizer.convert_tokens_to_ids(target))
130
+ translated_segments.append(translated_segment)
131
+
132
+ # Reassemble the translated text with original newline sequences
133
+ translated_text = ""
134
+ for i, segment in enumerate(translated_segments):
135
+ translated_text += segment
136
+ if i < len(newline_sequences):
137
+ translated_text += newline_sequences[i] # Insert the newline sequence
138
+
139
+ return translated_text.strip()
140
+
141
+ except Exception as e:
142
+ logging.error(f"An error occurred during translation: {e}")
143
+ return None
144
+
145
+ def translate_item_ufb(item, raw_file_path, translator, tokenizer):
146
+ try:
147
+ # Translate the prompt directly since it's a string
148
+ translated_prompt = translate_text(item['prompt'], translator, tokenizer)
149
+
150
+ # Translate the chosen and rejected contents
151
+ translated_chosen = []
152
+ for choice in item['chosen']:
153
+ translated_content = translate_text(choice['content'], translator, tokenizer)
154
+ translated_chosen.append({'content': translated_content, 'role': choice['role']})
155
+
156
+ translated_rejected = []
157
+ for choice in item['rejected']:
158
+ translated_content = translate_text(choice['content'], translator, tokenizer)
159
+ translated_rejected.append({'content': translated_content, 'role': choice['role']})
160
+
161
+ # Write the raw response to a backup file
162
+ with open(raw_file_path, 'a', encoding='utf-8') as raw_file:
163
+ raw_file.write(f"Prompt: {translated_prompt}\n")
164
+ raw_file.write(f"Chosen: {json.dumps(translated_chosen, ensure_ascii=False)}\n")
165
+ raw_file.write(f"Rejected: {json.dumps(translated_rejected, ensure_ascii=False)}\n\n")
166
+
167
+ logging.info("Translation request successful.")
168
+ # Update the original item with the translated fields
169
+ item['prompt'] = translated_prompt
170
+ item['chosen'] = translated_chosen
171
+ item['rejected'] = translated_rejected
172
+ return item
173
+
174
+ except Exception as e:
175
+ logging.error(f"An error occurred during translation: {e}")
176
+ return None
177
+
178
+ def validate_item_ufb(item):
179
+ # Check basic required fields including 'prompt' as a simple string
180
+ required_fields = ['source', 'prompt', 'chosen', 'rejected']
181
+ for field in required_fields:
182
+ if field not in item:
183
+ logging.warning(f"Missing required field: {field}")
184
+ return False
185
+ if field == 'prompt' and not isinstance(item['prompt'], str):
186
+ logging.warning("Prompt must be a string.")
187
+ return False
188
+
189
+ # Check 'chosen' and 'rejected' which should be lists of dictionaries
190
+ for field in ['chosen', 'rejected']:
191
+ if not isinstance(item[field], list) or not item[field]:
192
+ logging.warning(f"No entries or incorrect type for section: {field}")
193
+ return False
194
+ for idx, message in enumerate(item[field]):
195
+ if 'content' not in message or 'role' not in message:
196
+ logging.warning(f"Missing 'content' or 'role' field in {field} at index {idx}")
197
+ return False
198
+ if not isinstance(message['content'], str) or not isinstance(message['role'], str):
199
+ logging.warning(f"Invalid type for 'content' or 'role' field in {field} at index {idx}")
200
+ return False
201
+
202
+ return True
203
+
204
+
205
+
206
+ def translate_item_mix(item, raw_file_path, translator, tokenizer):
207
+ """
208
+ Translates the relevant fields in the given item from English to German using CTranslate2 and the WMT21 model,
209
+ and saves the raw response to a backup file.
210
+ """
211
+ #print ("translating:", item)
212
+ try:
213
+ # Translate each part of the prompt separately and preserve the order
214
+ translated_prompts = []
215
+ for message in item['prompt']:
216
+ translated_content = translate_text(message['content'], translator, tokenizer)
217
+ translated_prompts.append({'content': translated_content, 'role': message['role']})
218
+
219
+ # Translate the chosen and rejected contents
220
+ translated_chosen_content = translate_text(item['chosen'][0]['content'], translator, tokenizer)
221
+ translated_rejected_content = translate_text(item['rejected'][0]['content'], translator, tokenizer)
222
+
223
+ # Write the raw response to a backup file
224
+ with open(raw_file_path, 'a', encoding='utf-8') as raw_file:
225
+ raw_file.write("Prompt content:\n")
226
+ for translated_prompt in translated_prompts:
227
+ raw_file.write(f"{translated_prompt['role']}: {translated_prompt['content']}\n")
228
+ raw_file.write(f"Chosen content: {translated_chosen_content}\n")
229
+ raw_file.write(f"Rejected content: {translated_rejected_content}\n\n")
230
+
231
+ logging.info("Translation request successful.")
232
+ except Exception as e:
233
+ logging.error(f"An error occurred during translation: {e}")
234
+ return None
235
+
236
+ # Update the original item with the translated fields
237
+ item['prompt'] = translated_prompts
238
+ item['chosen'][0]['content'] = translated_chosen_content
239
+ item['rejected'][0]['content'] = translated_rejected_content
240
+
241
+ logging.info("Translation processing successful.")
242
+ return item
243
+
244
+ def validate_item_mix(item):
245
+ """
246
+ Validates the structure, presence, and content of required fields in the given item,
247
+ allowing for multiple elements in the 'prompt' field for multi-turn conversations.
248
+ """
249
+ required_fields = ['dataset', 'prompt', 'chosen', 'rejected']
250
+ for field in required_fields:
251
+ if field not in item:
252
+ logging.warning(f"Missing required field: {field}")
253
+ return False
254
+
255
+ # Check for at least one element in 'prompt' and exactly one element in 'chosen' and 'rejected'
256
+ if len(item['prompt']) < 1 or len(item['chosen']) != 1 or len(item['rejected']) != 1:
257
+ logging.warning("Invalid number of elements in 'prompt', 'chosen', or 'rejected' field.")
258
+ return False
259
+
260
+ # Validate 'content' and 'role' fields in all messages of 'prompt', and single elements of 'chosen' and 'rejected'
261
+ for choice in item['prompt'] + item['chosen'] + item['rejected']:
262
+ if 'content' not in choice or 'role' not in choice:
263
+ logging.warning("Missing 'content' or 'role' field in choice.")
264
+ return False
265
+ if not isinstance(choice['content'], str) or not isinstance(choice['role'], str):
266
+ logging.warning("Invalid type for 'content' or 'role' field in choice.")
267
+ return False
268
+
269
+ return True
270
+
271
+ def translate_item_orpo(item, raw_file_path, translator, tokenizer):
272
+ try:
273
+ translated_texts = {} # Cache to store translated texts
274
+
275
+ # Translate the prompt if necessary (which is a user input and can appear again)
276
+ if item['prompt'] not in translated_texts:
277
+ translated_prompt = translate_text(item['prompt'], translator, tokenizer)
278
+ translated_texts[item['prompt']] = translated_prompt
279
+ else:
280
+ translated_prompt = translated_texts[item['prompt']]
281
+
282
+ # Helper function to handle content translation with caching
283
+ def get_translated_content(content):
284
+ if content not in translated_texts:
285
+ translated_texts[content] = translate_text(content, translator, tokenizer)
286
+ return translated_texts[content]
287
+
288
+ # Process translations for chosen and rejected sections
289
+ def translate_interactions(interactions):
290
+ translated_interactions = []
291
+ for interaction in interactions:
292
+ translated_content = get_translated_content(interaction['content'])
293
+ translated_interactions.append({'content': translated_content, 'role': interaction['role']})
294
+ return translated_interactions
295
+
296
+ translated_chosen = translate_interactions(item['chosen'])
297
+ translated_rejected = translate_interactions(item['rejected'])
298
+
299
+ # Write the raw response to a backup file
300
+ with open(raw_file_path, 'a', encoding='utf-8') as raw_file:
301
+ raw_file.write(f"Prompt: {translated_prompt}\n")
302
+ raw_file.write(f"Chosen: {json.dumps(translated_chosen, ensure_ascii=False)}\n")
303
+ raw_file.write(f"Rejected: {json.dumps(translated_rejected, ensure_ascii=False)}\n\n")
304
+
305
+ logging.info("Translation request successful.")
306
+ # Update the original item with the translated fields
307
+ item['prompt'] = translated_prompt
308
+ item['chosen'] = translated_chosen
309
+ item['rejected'] = translated_rejected
310
+ return item
311
+
312
+ except Exception as e:
313
+ logging.error(f"An error occurred during translation: {e}")
314
+ return None
315
+
316
+ def validate_item_orpo(item):
317
+ # Check basic required fields
318
+ required_fields = ['source', 'prompt', 'chosen', 'rejected']
319
+ for field in required_fields:
320
+ if field not in item:
321
+ logging.warning(f"Missing required field: {field}")
322
+ return False
323
+
324
+ # Ensure 'prompt' is a string
325
+ if not isinstance(item['prompt'], str):
326
+ logging.warning("Prompt must be a string.")
327
+ return False
328
+
329
+ # Check 'chosen' and 'rejected' which should be lists of dictionaries
330
+ for field in ['chosen', 'rejected']:
331
+ if not isinstance(item[field], list) or not item[field]:
332
+ logging.warning(f"No entries or incorrect type for section: {field}")
333
+ return False
334
+ for idx, message in enumerate(item[field]):
335
+ if 'content' not in message or 'role' not in message:
336
+ logging.warning(f"Missing 'content' or 'role' field in {field} at index {idx}")
337
+ return False
338
+ if not isinstance(message['content'], str) or not isinstance(message['role'], str):
339
+ logging.warning(f"Invalid type for 'content' or 'role' field in {field} at index {idx}")
340
+ return False
341
+
342
+ return True
343
+
344
+ def process_file(input_file_path, output_file_path, raw_file_path, line_indices, translator, tokenizer, model_type):
345
+ try:
346
+ # Assigning validation and translation functions based on model_type
347
+ if model_type == "mix":
348
+ print ("translating a mix-style model...")
349
+ validate_item = validate_item_mix
350
+ translate_item = translate_item_mix
351
+ elif model_type == "orpo":
352
+ print ("translating an orpo-style model...")
353
+ validate_item = validate_item_orpo
354
+ translate_item = translate_item_orpo # def translate_item_ufb(item, raw_file_path, translator, tokenizer):
355
+ elif model_type == "ufb":
356
+ print ("translating an ultrafeedback-style model...")
357
+ validate_item = validate_item_ufb
358
+ translate_item = translate_item_ufb # def translate_item_ufb(item, raw_file_path, translator, tokenizer):
359
+ else:
360
+ raise ValueError(f"Unsupported model_type: {model_type}")
361
+
362
+ with open(input_file_path, 'r', encoding='utf-8') as file:
363
+ data_points = [json.loads(line) for line in file]
364
+
365
+ failed_items = []
366
+ failed_items_indices = []
367
+
368
+ for index in tqdm(line_indices, desc="Processing lines", unit="item"):
369
+ item = data_points[index]
370
+
371
+ # Validate the item structure
372
+ if not validate_item(item):
373
+ logging.warning("Skipping item due to invalid structure.")
374
+ failed_items.append(item)
375
+ continue
376
+
377
+ # Translate the relevant fields in the item
378
+ translated_item = None
379
+ retry_count = 0
380
+ while translated_item is None and retry_count < 3:
381
+ print ("going to translate the item...")
382
+ translated_item = translate_item(item, raw_file_path, translator, tokenizer)
383
+ retry_count += 1
384
+ if translated_item is None:
385
+ logging.warning(f"Translation failed for item. Retry attempt: {retry_count}")
386
+ time.sleep(1)
387
+
388
+ if translated_item is not None:
389
+ translated_item['index'] = index
390
+ with open(output_file_path, 'a', encoding='utf-8') as file:
391
+ file.write(json.dumps(translated_item, ensure_ascii=False) + "\n")
392
+ else:
393
+ failed_items_indices.append(index)
394
+ failed_items.append(item)
395
+ logging.error("Translation failed after multiple attempts. Skipping item.")
396
+
397
+ # Validate the translated item structure
398
+ if not validate_item(translated_item):
399
+ logging.warning("Skipping translated item due to invalid structure.")
400
+ failed_items.append(item)
401
+ continue
402
+
403
+ with open('failed_items.jsonl', 'w', encoding='utf-8') as file:
404
+ for item in failed_items:
405
+ file.write(json.dumps(item, ensure_ascii=False) + "\n")
406
+
407
+ failed_items_str = generate_failed_items_str(failed_items_indices)
408
+ with open('failed_items_index.txt', 'w', encoding='utf-8') as f:
409
+ f.write(failed_items_str)
410
+
411
+ logging.info("Translation completed successfully.")
412
+
413
+ except Exception as e:
414
+ logging.error(f"An error occurred: {e}")
415
+
416
+ def generate_failed_items_str(indices):
417
+ """
418
+ Converts a list of failed item indices into a string.
419
+ """
420
+ if not indices:
421
+ return ""
422
+
423
+ # Sort the list of indices and initialize the first range
424
+ indices.sort()
425
+ range_start = indices[0]
426
+ current = range_start
427
+ ranges = []
428
+
429
+ for i in indices[1:]:
430
+ if i == current + 1:
431
+ current = i
432
+ else:
433
+ if range_start == current:
434
+ ranges.append(f"{range_start}")
435
+ else:
436
+ ranges.append(f"{range_start}-{current}")
437
+ range_start = current = i
438
+
439
+ # Add the last range
440
+ if range_start == current:
441
+ ranges.append(f"{range_start}")
442
+ else:
443
+ ranges.append(f"{range_start}-{current}")
444
+
445
+ return ",".join(ranges)
446
+
447
+ # Function to upload the output file to Hugging Face
448
+ def upload_output_to_huggingface(output_file_path, repo_name, token):
449
+ upload_file(
450
+ path_or_fileobj=output_file_path,
451
+ path_in_repo=output_file_path,
452
+ repo_id=repo_name,
453
+ repo_type="dataset",
454
+ token=token
455
+ )
456
+ print(f"Uploaded {output_file_path} to Hugging Face repository: {repo_name}")
457
+
458
+ def translate_dataset(train_url, local_parquet_path, input_file_path, output_file_path, raw_file_path, range_specification, model_type, output_dir, output_repo_name, token, translator, tokenizer):
459
+ try:
460
+ # Download the Parquet file
461
+ download_parquet(train_url, local_parquet_path)
462
+ except Exception as e:
463
+ logging.error(f"Failed to download the Parquet file from {train_url}: {e}")
464
+ return
465
+
466
+ try:
467
+ # Convert the downloaded Parquet file to JSONL
468
+ convert_parquet_to_jsonl(local_parquet_path, output_dir)
469
+ except Exception as e:
470
+ logging.error(f"Failed to convert Parquet to JSONL: {e}")
471
+ return
472
+
473
+ try:
474
+ # Rename the JSONL file using subprocess to ensure correct handling
475
+ subprocess.run(["mv", f"{output_dir}/train.jsonl", input_file_path], check=True)
476
+ except subprocess.CalledProcessError as e:
477
+ logging.error(f"Failed to rename the file from 'train.jsonl' to {input_file_path}: {e}")
478
+ return
479
+
480
+ try:
481
+ # Count lines in the JSONL file to validate contents
482
+ line_count = count_lines_in_jsonl(input_file_path)
483
+ logging.info(f"Number of lines in the file: {line_count}")
484
+ except Exception as e:
485
+ logging.error(f"Failed to count lines in {input_file_path}: {e}")
486
+ return
487
+
488
+ try:
489
+ # Parse the range specification for processing specific lines
490
+ line_indices = parse_range_specification(range_specification, file_length=line_count)
491
+ if not line_indices:
492
+ logging.error("No valid line indices to process. Please check the range specifications.")
493
+ return
494
+ except Exception as e:
495
+ logging.error(f"Error parsing range specification '{range_specification}': {e}")
496
+ return
497
+
498
+ try:
499
+ # Process the file with specified model type and line indices
500
+ process_file(input_file_path, output_file_path, raw_file_path, line_indices, translator, tokenizer, model_type)
501
+ except Exception as e:
502
+ logging.error(f"Failed to process the file {input_file_path}: {e}")
503
+ return
504
+
505
+ try:
506
+ # Upload the output file to Hugging Face repository
507
+ upload_output_to_huggingface(output_file_path, output_repo_name, token)
508
+ except Exception as e:
509
+ logging.error(f"Failed to upload {output_file_path} to Hugging Face: {e}")
510
+
511
+ # Setup logging configuration
512
+ logging.basicConfig(level=logging.INFO, filename='translation.log', filemode='a',
513
+ format='%(asctime)s - %(levelname)s - %(message)s')
514
+
515
+ def main(model_id, dataset_url, model_type, output_dataset_name):
516
+ try:
517
+ # Login to Hugging Face
518
+ token = login()
519
+ if token:
520
+ logging.info("Logged in to Hugging Face")
521
+
522
+ # Configuration and paths
523
+ tokenizer_name = "facebook/wmt21-dense-24-wide-en-x"
524
+ model_repo_name = "cstr/wmt21ct2_int8" # Repository to download the model from
525
+
526
+ # Download the model snapshot from Hugging Face
527
+ model_path = snapshot_download(repo_id=model_repo_name, token=token)
528
+ logging.info(f"Model downloaded to: {model_path}")
529
+
530
+ # Load the CTranslate2 model
531
+ translator = ctranslate2.Translator(model_path, device="auto")
532
+ logging.info("CTranslate2 model loaded successfully.")
533
+
534
+ # Load the tokenizer
535
+ tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_name)
536
+ tokenizer.src_lang = "en"
537
+ logging.info("Tokenizer loaded successfully.")
538
+
539
+ # Define the task based on user input
540
+ task = {
541
+ "url": dataset_url,
542
+ "local_path": "train.parquet",
543
+ "input_file": f"{model_type}_en.jsonl",
544
+ "output_file": f"{model_type}_de.jsonl",
545
+ "raw_file": f"{model_type}_de_raw.jsonl",
546
+ "range_spec": "1-",
547
+ "model_type": model_type
548
+ }
549
+
550
+ # Call the translate_dataset function with the provided parameters
551
+ translate_dataset(
552
+ train_url=task["url"],
553
+ local_parquet_path=task["local_path"],
554
+ input_file_path=task["input_file"],
555
+ output_file_path=task["output_file"],
556
+ output_dir=".",
557
+ output_repo_name=output_dataset_name,
558
+ raw_file_path=task["raw_file"],
559
+ token=token,
560
+ range_specification=task["range_spec"],
561
+ model_type=task["model_type"],
562
+ translator=translator,
563
+ tokenizer=tokenizer,
564
+ )
565
+ return "Dataset translation completed!"
566
+ else:
567
+ return "Login failed. Please try again."
568
+ except Exception as e:
569
+ logging.error(f"An error occurred in the main function: {e}")
570
+ return f"An error occurred: {e}"
571
+
572
+ # Gradio interface setup
573
+ gradio_title = "🧐 WMT21 Dataset Translation"
574
+ gradio_desc = """This tool translates datasets using the WMT21 translation model.
575
+ ## 💭 What Does This Tool Do:
576
+ - Translates datasets based on the selected model type.
577
+ - Uploads the translated dataset to Hugging Face.
578
+ ## 🛠️ Backend:
579
+ The translation backend runs on the Hugging Face Hub API.
580
+ """
581
+
582
+ with gr.Blocks() as demo:
583
+ gr.HTML(f"""<h1 align="center" id="space-title">{gradio_title}</h1>""")
584
+ gr.Markdown(gradio_desc)
585
+
586
+ with gr.Row(equal_height=False):
587
+ with gr.Column():
588
+ model_id = gr.Textbox(label="Model ID or URL", lines=1)
589
+ dataset_url = gr.Textbox(label="Dataset URL", lines=1)
590
+ model_type = gr.Dropdown(choices=["mix", "orpo", "ufb"], label="Model Type")
591
+ output_dataset_name = gr.Textbox(label="Output Dataset Name", lines=1)
592
+ login_button = gr.Button("Login to Hugging Face")
593
+
594
+ with gr.Column():
595
+ output = gr.Textbox(label="Output", lines=1)
596
+ logout_button = gr.Button("Logout")
597
+
598
+ submit_btn = gr.Button("Translate Dataset", variant="primary")
599
+ submit_btn.click(main, inputs=[model_id, dataset_url, model_type, output_dataset_name], outputs=output)
600
+
601
+ demo.launch()