Update optimum_encoder.py
optimum_encoder.py CHANGED (+27 -26)
@@ -60,11 +60,11 @@ class OptimumEncoder(BaseEncoder):
             **self.tokenizer_kwargs,
         )
 
-        provider_options = {
-            "trt_engine_cache_enable": True,
-            "trt_engine_cache_path": os.getenv('HF_HOME'),
-            "trt_fp16_enable": True
-        }
+        #provider_options = {
+        #    "trt_engine_cache_enable": True,
+        #    "trt_engine_cache_path": os.getenv('HF_HOME'),
+        #    "trt_fp16_enable": True
+        #}
 
         session_options = ort.SessionOptions()
         session_options.log_severity_level = 0
@@ -73,33 +73,34 @@ class OptimumEncoder(BaseEncoder):
             model_id=self.name,
             file_name='model_fp16.onnx',
             subfolder='onnx',
-            provider='
-            provider_options=provider_options,
+            provider='CUDAExecutionProvider',
+            use_io_binding=True,
+            #provider_options=provider_options,
             session_options=session_options,
             **self.model_kwargs
         )
 
-        print("Building engine for a short sequence...")
-        short_text = ["short"]
-        short_encoded_input = tokenizer(
-            short_text, padding=True, truncation=True, return_tensors="pt"
-        ).to(self.device)
-        short_output = ort_model(**short_encoded_input)
+        # print("Building engine for a short sequence...")
+        # short_text = ["short"]
+        # short_encoded_input = tokenizer(
+        #    short_text, padding=True, truncation=True, return_tensors="pt"
+        # ).to(self.device)
+        # short_output = ort_model(**short_encoded_input)
 
-        print("Building engine for a long sequence...")
-        long_text = ["a very long input just for demo purpose, this is very long" * 10]
-        long_encoded_input = tokenizer(
-            long_text, padding=True, truncation=True, return_tensors="pt"
-        ).to(self.device)
-        long_output = ort_model(**long_encoded_input)
-
-        text = ["Replace me by any text you'd like."]
-        encoded_input = tokenizer(
-            text, padding=True, truncation=True, return_tensors="pt"
-        ).to(self.device)
+        # print("Building engine for a long sequence...")
+        # long_text = ["a very long input just for demo purpose, this is very long" * 10]
+        # long_encoded_input = tokenizer(
+        #    long_text, padding=True, truncation=True, return_tensors="pt"
+        # ).to(self.device)
+        # long_output = ort_model(**long_encoded_input)
+
+        # text = ["Replace me by any text you'd like."]
+        # encoded_input = tokenizer(
+        #    text, padding=True, truncation=True, return_tensors="pt"
+        # ).to(self.device)
 
-        for i in range(3):
-            output = ort_model(**encoded_input)
+        # for i in range(3):
+        #    output = ort_model(**encoded_input)
 
         return tokenizer, ort_model
 
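In short, the commit comments out the TensorRT-related provider_options (the trt_* settings) and the engine warm-up passes, and instead loads the ONNX model on the plain CUDAExecutionProvider with IO binding enabled. Below is a minimal, self-contained sketch of that loading pattern with optimum.onnxruntime. It is not the file's exact code: the ORTModelForFeatureExtraction class, the stand-in model id, export=True, and the hard-coded "cuda" device are assumptions for illustration, since the surrounding code (self.name, self.device, self.model_kwargs) is outside this hunk.

# Minimal sketch, assuming ORTModelForFeatureExtraction and a stand-in model id.
# Requires onnxruntime-gpu and a CUDA device.
import onnxruntime as ort
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForFeatureExtraction

model_id = "sentence-transformers/all-MiniLM-L6-v2"  # hypothetical stand-in for self.name

tokenizer = AutoTokenizer.from_pretrained(model_id)

# Mirror the diff: keep ONNX Runtime logging verbose via session options.
session_options = ort.SessionOptions()
session_options.log_severity_level = 0

# After the change: plain CUDA execution provider with IO binding,
# instead of trt_* provider options and warm-up inference passes.
ort_model = ORTModelForFeatureExtraction.from_pretrained(
    model_id,
    export=True,  # export to ONNX on the fly; the diffed code points at a prebuilt onnx/model_fp16.onnx instead
    provider="CUDAExecutionProvider",
    use_io_binding=True,
    session_options=session_options,
)

encoded = tokenizer(
    ["Replace me by any text you'd like."],
    padding=True, truncation=True, return_tensors="pt",
).to("cuda")
output = ort_model(**encoded)
print(output.last_hidden_state.shape)

With use_io_binding=True, optimum binds inputs and outputs directly to device buffers, which avoids extra host-device copies on each forward call.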