Upload pipeline.py
Browse files- pipeline.py +3 -3
pipeline.py
CHANGED
@@ -12,8 +12,8 @@ class PreTrainedPipeline():
|
|
12 |
"""
|
13 |
self.bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
14 |
self.model = FoodyBertForSequenceClassification.from_pretrained("rttl-ai/foody-bert")
|
15 |
-
def __call__(self, inputs: str) -> List[float]:
|
16 |
-
|
17 |
"""
|
18 |
Args:
|
19 |
inputs (:obj:`str`):
|
@@ -27,4 +27,4 @@ class PreTrainedPipeline():
|
|
27 |
predicted_class_id = self.model(X).logits.argmax().item()
|
28 |
labels = ['positive','neutral','negative']
|
29 |
reps = labels[predicted_class_id]
|
30 |
-
return
|
|
|
12 |
"""
|
13 |
self.bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
14 |
self.model = FoodyBertForSequenceClassification.from_pretrained("rttl-ai/foody-bert")
|
15 |
+
#def __call__(self, inputs: str) -> List[float]:
|
16 |
+
def __call__(self, inputs: str) -> str:
|
17 |
"""
|
18 |
Args:
|
19 |
inputs (:obj:`str`):
|
|
|
27 |
predicted_class_id = self.model(X).logits.argmax().item()
|
28 |
labels = ['positive','neutral','negative']
|
29 |
reps = labels[predicted_class_id]
|
30 |
+
return reps
|