anttip committed
Commit ca9a50b
1 Parent(s): 0012f15

Create translate.py

Files changed (1):
  1. translate.py +487 -0

translate.py ADDED
@@ -0,0 +1,487 @@
import ctranslate2
import functools

try:
    from transformers import AutoTokenizer

    autotokenizer_ok = True
except ImportError:
    AutoTokenizer = object
    autotokenizer_ok = False

try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal

from typing import Any, Union, List
import os

from hf_hub_ctranslate2.util import utils as _utils


class CTranslate2ModelfromHuggingfaceHub:
    """CTranslate2 compatibility class for Translator and Generator"""

    def __init__(
        self,
        model_name_or_path: str,
        device: Literal["cpu", "cuda"] = "cuda",
        device_index=0,
        compute_type: Literal["int8_float16", "int8"] = "int8_float16",
        tokenizer: Union[AutoTokenizer, None] = None,
        hub_kwargs: dict = {},
        **kwargs: Any,
    ):
        # adapted from https://github.com/guillaumekln/faster-whisper
        if os.path.isdir(model_name_or_path):
            model_path = model_name_or_path
        else:
            try:
                # prefer a copy already in the local hub cache
                model_path = _utils._download_model(
                    model_name_or_path, hub_kwargs=hub_kwargs, local_files_only=True,
                )
            except Exception:
                # not cached yet -> download from the hub
                model_path = _utils._download_model(
                    model_name_or_path, hub_kwargs=hub_kwargs, local_files_only=False,
                )
        self.model = self.ctranslate_class(
            model_path,
            device=device,
            device_index=device_index,
            compute_type=compute_type,
            **kwargs,
        )

        if tokenizer is not None:
            self.tokenizer = tokenizer
        else:
            if "tokenizer.json" in os.listdir(model_path):
                if not autotokenizer_ok:
                    raise ValueError(
                        "Loading the AutoTokenizer requires `pip install transformers`."
                    )
                self.tokenizer = AutoTokenizer.from_pretrained(model_path, fast=True)
            else:
                raise ValueError(
                    "No suitable tokenizer found. "
                    "Please set one via the tokenizer=AutoTokenizer.from_pretrained(..) argument."
                )

    def _forward(self, *args: Any, **kwds: Any) -> Any:
        raise NotImplementedError

    def tokenize_encode(self, text, *args, **kwargs):
        return [
            self.tokenizer.convert_ids_to_tokens(self.tokenizer.encode(p)) for p in text
        ]

    def tokenize_decode(self, tokens_out, *args, **kwargs):
        raise NotImplementedError

    def generate(
        self,
        text: Union[str, List[str]],
        encode_kwargs={},
        decode_kwargs={},
        *forward_args,
        **forward_kwds: Any,
    ):
        # encode -> model forward -> decode; a single str round-trips to a str
        orig_type = list
        if isinstance(text, str):
            orig_type = str
            text = [text]
        token_list = self.tokenize_encode(text, **encode_kwargs)
        tokens_out = self._forward(token_list, *forward_args, **forward_kwds)
        texts_out = self.tokenize_decode(tokens_out, **decode_kwargs)
        if orig_type == str:
            return texts_out[0]
        else:
            return texts_out


class TranslatorCT2fromHfHub(CTranslate2ModelfromHuggingfaceHub):
    def __init__(
        self,
        model_name_or_path: str,
        device: Literal["cpu", "cuda"] = "cuda",
        device_index=0,
        compute_type: Literal["int8_float16", "int8"] = "int8_float16",
        tokenizer: Union[AutoTokenizer, None] = None,
        hub_kwargs={},
        **kwargs: Any,
    ):
        """For ctranslate2.Translator models, in particular m2m-100.

        Args:
            model_name_or_path (str): Hugging Face repo id or local path of the converted model.
            device (Literal["cpu", "cuda"], optional): Device to load the model on. Defaults to "cuda".
            device_index (int, optional): Index of the device to use. Defaults to 0.
            compute_type (Literal["int8_float16", "int8"], optional): Quantization scheme for inference. Defaults to "int8_float16".
            tokenizer (Union[AutoTokenizer, None], optional): Tokenizer to use; loaded from the model directory if None. Defaults to None.
            hub_kwargs (dict, optional): Additional kwargs for the hub download. Defaults to {}.
            **kwargs (Any, optional): Any additional arguments for ctranslate2.Translator.
        """
        self.ctranslate_class = ctranslate2.Translator
        super().__init__(
            model_name_or_path,
            device,
            device_index,
            compute_type,
            tokenizer,
            hub_kwargs,
            **kwargs,
        )

    def _forward(self, *args, **kwds):
        return self.model.translate_batch(*args, **kwds)

    def tokenize_decode(self, tokens_out, *args, **kwargs):
        return [
            self.tokenizer.decode(
                self.tokenizer.convert_tokens_to_ids(tokens_out[i].hypotheses[0]),
                *args,
                **kwargs,
            )
            for i in range(len(tokens_out))
        ]

    def generate(
        self,
        text: Union[str, List[str]],
        encode_tok_kwargs={},
        decode_tok_kwargs={},
        *forward_args,
        **forward_kwds: Any,
    ):
        """Translate the input text(s).

        Args:
            text (Union[str, List[str]]): Input texts
            encode_tok_kwargs (dict, optional): additional kwargs for the tokenizer encode step
            decode_tok_kwargs (dict, optional): additional kwargs for the tokenizer decode step
            The remaining keyword arguments are forwarded to ctranslate2.Translator.translate_batch:
            max_batch_size (int, optional): Maximum batch size. Defaults to 0.
            batch_type (str, optional): Count max_batch_size in "examples" or "tokens". Defaults to "examples".
            asynchronous (bool, optional): Only False supported. Defaults to False.
            beam_size (int, optional): Beam size for beam search. Defaults to 2.
            patience (float, optional): Beam search patience factor. Defaults to 1.
            num_hypotheses (int, optional): Number of hypotheses to return. Defaults to 1.
            length_penalty (float, optional): Exponential length penalty during beam search. Defaults to 1.
            coverage_penalty (float, optional): Coverage penalty weight. Defaults to 0.
            repetition_penalty (float, optional): Penalty applied to repeated tokens. Defaults to 1.
            no_repeat_ngram_size (int, optional): Prevent repetition of ngrams of this size. Defaults to 0.
            disable_unk (bool, optional): Disable generation of the unknown token. Defaults to False.
            suppress_sequences (Optional[List[List[str]]], optional): Disable generation of these sequences.
                Defaults to None.
            end_token (Optional[Union[str, List[str], List[int]]], optional): Stop decoding on these tokens.
                Defaults to None.
            return_end_token (bool, optional): Include the end token in the result. Defaults to False.
            prefix_bias_beta (float, optional): Bias towards the target prefix. Defaults to 0.
            max_input_length (int, optional): Truncate inputs after this many tokens. Defaults to 1024.
            max_decoding_length (int, optional): Maximum prediction length. Defaults to 256.
            min_decoding_length (int, optional): Minimum prediction length. Defaults to 1.
            use_vmap (bool, optional): Use the vocabulary map saved in the model directory. Defaults to False.
            return_scores (bool, optional): Include the scores in the output. Defaults to False.
            return_attention (bool, optional): Include the attention vectors in the output. Defaults to False.
            return_alternatives (bool, optional): Return alternatives at the first unconstrained position. Defaults to False.
            min_alternative_expansion_prob (float, optional): Minimum probability to expand an alternative. Defaults to 0.
            sampling_topk (int, optional): Sample randomly from the top k candidates. Defaults to 1 (greedy).
            sampling_temperature (float, optional): Sampling temperature. Defaults to 1.
            replace_unknowns (bool, optional): Replace unknown target tokens by the source token with the highest attention. Defaults to False.
            callback (Callable, optional): Function called on each translated example. Defaults to None.
        Returns:
            Union[str, List[str]]: translated output; a list of the same length as the input, or a single str for str input
        """
        return super().generate(
            text,
            encode_kwargs=encode_tok_kwargs,
            decode_kwargs=decode_tok_kwargs,
            *forward_args,
            **forward_kwds,
        )

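
# Usage sketch for TranslatorCT2fromHfHub. The repo id below is a hypothetical
# example (any CTranslate2-converted seq2seq model with a tokenizer.json would
# work the same way); keyword arguments are forwarded to translate_batch.
def _example_translator():
    model = TranslatorCT2fromHfHub(
        model_name_or_path="michaelfeil/ct2fast-flan-alpaca-base",  # hypothetical repo id
        device="cpu",
        compute_type="int8",
    )
    # a str input returns a str; a List[str] returns a List[str] of equal length
    return model.generate(text=["How do you call a fast Flamingo?"], beam_size=2)
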
class MultiLingualTranslatorCT2fromHfHub(CTranslate2ModelfromHuggingfaceHub):
    def __init__(
        self,
        model_name_or_path: str,
        device: Literal["cpu", "cuda"] = "cuda",
        device_index=0,
        compute_type: Literal["int8_float16", "int8"] = "int8_float16",
        tokenizer: Union[AutoTokenizer, None] = None,
        hub_kwargs={},
        **kwargs: Any,
    ):
        """For multilingual ctranslate2.Translator models, e.g. m2m-100.

        Args:
            model_name_or_path (str): Hugging Face repo id or local path of the converted model.
            device (Literal["cpu", "cuda"], optional): Device to load the model on. Defaults to "cuda".
            device_index (int, optional): Index of the device to use. Defaults to 0.
            compute_type (Literal["int8_float16", "int8"], optional): Quantization scheme for inference. Defaults to "int8_float16".
            tokenizer (Union[AutoTokenizer, None], optional): Tokenizer to use; loaded from the model directory if None. Defaults to None.
            hub_kwargs (dict, optional): Additional kwargs for the hub download. Defaults to {}.
            **kwargs (Any, optional): Any additional arguments for ctranslate2.Translator.
        """
        self.ctranslate_class = ctranslate2.Translator
        super().__init__(
            model_name_or_path,
            device,
            device_index,
            compute_type,
            tokenizer,
            hub_kwargs,
            **kwargs,
        )

    def _forward(self, *args, **kwds):
        # map target language codes to language tokens,
        # e.g. target_prefix=[['__de__'], ['__fr__']]
        target_prefix = [
            [self.tokenizer.lang_code_to_token[lng]] for lng in kwds.pop("tgt_lang")
        ]
        return self.model.translate_batch(*args, **kwds, target_prefix=target_prefix)

    def tokenize_encode(self, text, *args, **kwargs):
        tokens = []
        src_lang = kwargs.pop("src_lang")
        for t, src_language in zip(text, src_lang):
            # the source language changes how the tokenizer encodes the text
            self.tokenizer.src_lang = src_language
            tokens.append(
                self.tokenizer.convert_ids_to_tokens(self.tokenizer.encode(t))
            )
        return tokens

    def tokenize_decode(self, tokens_out, *args, **kwargs):
        return [
            self.tokenizer.decode(
                # [1:] drops the leading target-language token
                self.tokenizer.convert_tokens_to_ids(tokens_out[i].hypotheses[0][1:]),
                *args,
                **kwargs,
            )
            for i in range(len(tokens_out))
        ]

    def generate(
        self,
        text: Union[str, List[str]],
        src_lang: Union[str, List[str]],
        tgt_lang: Union[str, List[str]],
        *forward_args,
        **forward_kwds: Any,
    ):
        """Translate input text(s) with per-example source and target languages.

        Args:
            text (Union[str, List[str]]): Input texts
            src_lang (Union[str, List[str]]): source language of the input texts
            tgt_lang (Union[str, List[str]]): target language for the outputs
            The remaining keyword arguments are forwarded to ctranslate2.Translator.translate_batch:
            max_batch_size (int, optional): Maximum batch size. Defaults to 0.
            batch_type (str, optional): Count max_batch_size in "examples" or "tokens". Defaults to "examples".
            asynchronous (bool, optional): Only False supported. Defaults to False.
            beam_size (int, optional): Beam size for beam search. Defaults to 2.
            patience (float, optional): Beam search patience factor. Defaults to 1.
            num_hypotheses (int, optional): Number of hypotheses to return. Defaults to 1.
            length_penalty (float, optional): Exponential length penalty during beam search. Defaults to 1.
            coverage_penalty (float, optional): Coverage penalty weight. Defaults to 0.
            repetition_penalty (float, optional): Penalty applied to repeated tokens. Defaults to 1.
            no_repeat_ngram_size (int, optional): Prevent repetition of ngrams of this size. Defaults to 0.
            disable_unk (bool, optional): Disable generation of the unknown token. Defaults to False.
            suppress_sequences (Optional[List[List[str]]], optional): Disable generation of these sequences.
                Defaults to None.
            end_token (Optional[Union[str, List[str], List[int]]], optional): Stop decoding on these tokens.
                Defaults to None.
            return_end_token (bool, optional): Include the end token in the result. Defaults to False.
            prefix_bias_beta (float, optional): Bias towards the target prefix. Defaults to 0.
            max_input_length (int, optional): Truncate inputs after this many tokens. Defaults to 1024.
            max_decoding_length (int, optional): Maximum prediction length. Defaults to 256.
            min_decoding_length (int, optional): Minimum prediction length. Defaults to 1.
            use_vmap (bool, optional): Use the vocabulary map saved in the model directory. Defaults to False.
            return_scores (bool, optional): Include the scores in the output. Defaults to False.
            return_attention (bool, optional): Include the attention vectors in the output. Defaults to False.
            return_alternatives (bool, optional): Return alternatives at the first unconstrained position. Defaults to False.
            min_alternative_expansion_prob (float, optional): Minimum probability to expand an alternative. Defaults to 0.
            sampling_topk (int, optional): Sample randomly from the top k candidates. Defaults to 1 (greedy).
            sampling_temperature (float, optional): Sampling temperature. Defaults to 1.
            replace_unknowns (bool, optional): Replace unknown target tokens by the source token with the highest attention. Defaults to False.
            callback (Callable, optional): Function called on each translated example. Defaults to None.
        Returns:
            Union[str, List[str]]: translated output; a list of the same length as the input, or a single str for str input
        """
        if not len(text) == len(src_lang) == len(tgt_lang):
            raise ValueError(
                f"unequal len: text={len(text)} src_lang={len(src_lang)} tgt_lang={len(tgt_lang)}"
            )
        forward_kwds["tgt_lang"] = tgt_lang
        return super().generate(
            text, *forward_args, **forward_kwds, encode_kwargs={"src_lang": src_lang}
        )

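
# Usage sketch for MultiLingualTranslatorCT2fromHfHub. The repo id below is a
# hypothetical example; the class assumes an m2m-100-style tokenizer exposing
# `lang_code_to_token` and a settable `src_lang`.
def _example_multilingual_translator():
    model = MultiLingualTranslatorCT2fromHfHub(
        model_name_or_path="michaelfeil/ct2fast-m2m100_418M",  # hypothetical repo id
        device="cpu",
        compute_type="int8",
    )
    # text, src_lang and tgt_lang must all have the same length
    return model.generate(
        ["How do you call a fast Flamingo?", "Wie geht es dir?"],
        src_lang=["en", "de"],
        tgt_lang=["de", "fr"],
    )
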
class EncoderCT2fromHfHub(CTranslate2ModelfromHuggingfaceHub):
    def __init__(
        self,
        model_name_or_path: str,
        device: Literal["cpu", "cuda"] = "cuda",
        device_index=0,
        compute_type: Literal["int8_float16", "int8"] = "int8_float16",
        tokenizer: Union[AutoTokenizer, None] = None,
        hub_kwargs={},
        **kwargs: Any,
    ):
        """For ctranslate2.Encoder models, e.g. sentence encoders.

        Args:
            model_name_or_path (str): Hugging Face repo id or local path of the converted model.
            device (Literal["cpu", "cuda"], optional): Device to load the model on. Defaults to "cuda".
            device_index (int, optional): Index of the device to use. Defaults to 0.
            compute_type (Literal["int8_float16", "int8"], optional): Quantization scheme for inference. Defaults to "int8_float16".
            tokenizer (Union[AutoTokenizer, None], optional): Tokenizer to use; loaded from the model directory if None. Defaults to None.
            hub_kwargs (dict, optional): Additional kwargs for the hub download. Defaults to {}.
            **kwargs (Any, optional): Any additional arguments for ctranslate2.Encoder.
        """
        self.ctranslate_class = ctranslate2.Encoder
        super().__init__(
            model_name_or_path,
            device,
            device_index,
            compute_type,
            tokenizer,
            hub_kwargs,
            **kwargs,
        )
        self.device = device
        if device == "cuda":
            try:
                import torch
            except ImportError:
                raise ValueError(
                    "Decoding the encoder StorageView on CUDA requires torch."
                )
            # on CUDA, outputs are converted to torch tensors on the same device
            self.tensor_decode_method = functools.partial(
                torch.as_tensor, device=device
            )
            self.input_dtype = torch.int32
        else:
            try:
                import numpy as np
            except ImportError:
                raise ValueError(
                    "Decoding the encoder StorageView on CPU requires numpy."
                )
            # on CPU, outputs are converted to numpy arrays
            self.tensor_decode_method = np.asarray

    def _forward(self, features, *args, **kwds):
        input_ids = features["input_ids"]
        outputs_raw = self.model.forward_batch(input_ids, *args, **kwds)
        outputs = dict(
            pooler_output=self.tensor_decode_method(outputs_raw.pooler_output),
            last_hidden_state=self.tensor_decode_method(outputs_raw.last_hidden_state),
            attention_mask=features["attention_mask"],
        )
        return outputs

    def tokenize_encode(self, text, *args, **kwargs):
        return self.tokenizer(text)

    def tokenize_decode(self, tokens_out, *args, **kwargs):
        # the forward output dict is returned as-is
        return tokens_out

    def generate(
        self,
        text: Union[str, List[str]],
        encode_tok_kwargs={},
        decode_tok_kwargs={},
        *forward_args,
        **forward_kwds: Any,
    ):
        return super().generate(
            text,
            encode_kwargs=encode_tok_kwargs,
            decode_kwargs=decode_tok_kwargs,
            *forward_args,
            **forward_kwds,
        )

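
# Usage sketch for EncoderCT2fromHfHub. The repo id below is a hypothetical
# example; `generate` returns a dict with `pooler_output`, `last_hidden_state`
# and the tokenizer's `attention_mask` (torch tensors on CUDA, numpy on CPU).
def _example_encoder():
    model = EncoderCT2fromHfHub(
        model_name_or_path="michaelfeil/ct2fast-e5-small-v2",  # hypothetical repo id
        device="cpu",
        compute_type="int8",
    )
    outputs = model.generate(text=["I like soccer", "I like tennis"])
    return outputs["pooler_output"], outputs["last_hidden_state"]
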
class GeneratorCT2fromHfHub(CTranslate2ModelfromHuggingfaceHub):
    def __init__(
        self,
        model_name_or_path: str,
        device: Literal["cpu", "cuda"] = "cuda",
        device_index=0,
        compute_type: Literal["int8_float16", "int8"] = "int8_float16",
        tokenizer: Union[AutoTokenizer, None] = None,
        hub_kwargs={},
        **kwargs: Any,
    ):
        """For ctranslate2.Generator models, i.e. decoder-only language models.

        Args:
            model_name_or_path (str): Hugging Face repo id or local path of the converted model.
            device (Literal["cpu", "cuda"], optional): Device to load the model on. Defaults to "cuda".
            device_index (int, optional): Index of the device to use. Defaults to 0.
            compute_type (Literal["int8_float16", "int8"], optional): Quantization scheme for inference. Defaults to "int8_float16".
            tokenizer (Union[AutoTokenizer, None], optional): Tokenizer to use; loaded from the model directory if None. Defaults to None.
            hub_kwargs (dict, optional): Additional kwargs for the hub download. Defaults to {}.
            **kwargs (Any, optional): Any additional arguments for ctranslate2.Generator.
        """
        self.ctranslate_class = ctranslate2.Generator
        super().__init__(
            model_name_or_path,
            device,
            device_index,
            compute_type,
            tokenizer,
            hub_kwargs,
            **kwargs,
        )

    def _forward(self, *args, **kwds):
        return self.model.generate_batch(*args, **kwds)

    def tokenize_decode(self, tokens_out, *args, **kwargs):
        return [
            self.tokenizer.decode(tokens_out[i].sequences_ids[0], *args, **kwargs)
            for i in range(len(tokens_out))
        ]

    def generate(
        self,
        text: Union[str, List[str]],
        encode_tok_kwargs={},
        decode_tok_kwargs={},
        *forward_args,
        **forward_kwds: Any,
    ):
        """Generate a continuation for the input text(s).

        Args:
            text (Union[str, List[str]]): Input texts
            encode_tok_kwargs (dict, optional): additional kwargs for the tokenizer encode step
            decode_tok_kwargs (dict, optional): additional kwargs for the tokenizer decode step
            The remaining keyword arguments are forwarded to ctranslate2.Generator.generate_batch:
            max_batch_size (int, optional): Maximum batch size. Defaults to 0.
            batch_type (str, optional): Count max_batch_size in "examples" or "tokens". Defaults to "examples".
            asynchronous (bool, optional): Only False supported. Defaults to False.
            beam_size (int, optional): Beam size for beam search. Defaults to 1.
            patience (float, optional): Beam search patience factor. Defaults to 1.
            num_hypotheses (int, optional): Number of hypotheses to return. Defaults to 1.
            length_penalty (float, optional): Exponential length penalty during beam search. Defaults to 1.
            repetition_penalty (float, optional): Penalty applied to repeated tokens. Defaults to 1.
            no_repeat_ngram_size (int, optional): Prevent repetition of ngrams of this size. Defaults to 0.
            disable_unk (bool, optional): Disable generation of the unknown token. Defaults to False.
            suppress_sequences (Optional[List[List[str]]], optional): Disable generation of these sequences.
                Defaults to None.
            end_token (Optional[Union[str, List[str], List[int]]], optional): Stop decoding on these tokens.
                Defaults to None.
            return_end_token (bool, optional): Include the end token in the result. Defaults to False.
            max_length (int, optional): Maximum generation length. Defaults to 512.
            min_length (int, optional): Minimum generation length. Defaults to 0.
            include_prompt_in_result (bool, optional): Include the prompt in the result. Defaults to True.
            return_scores (bool, optional): Include the scores in the output. Defaults to False.
            return_alternatives (bool, optional): Return alternatives at the first unconstrained position. Defaults to False.
            min_alternative_expansion_prob (float, optional): Minimum probability to expand an alternative. Defaults to 0.
            sampling_topk (int, optional): Sample randomly from the top k candidates. Defaults to 1 (greedy).
            sampling_temperature (float, optional): Sampling temperature. Defaults to 1.
        Returns:
            Union[str, List[str]]: generated output; a list of the same length as the input, or a single str for str input
        """
        return super().generate(
            text,
            encode_kwargs=encode_tok_kwargs,
            decode_kwargs=decode_tok_kwargs,
            *forward_args,
            **forward_kwds,
        )
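
# Usage sketch for GeneratorCT2fromHfHub (decoder-only generation). The repo id
# below is a hypothetical example; the sampling parameters shown are forwarded
# to ctranslate2.Generator.generate_batch.
def _example_generator():
    model = GeneratorCT2fromHfHub(
        model_name_or_path="michaelfeil/ct2fast-gpt2",  # hypothetical repo id
        device="cpu",
        compute_type="int8",
    )
    return model.generate(
        text=["Hello, my name is"],
        max_length=64,
        sampling_topk=10,
        sampling_temperature=0.8,
    )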