Text Classification
Transformers
PyTorch
bert
Inference Endpoints
rttl commited on
Commit
4d90ca7
1 Parent(s): 97a1414

Upload pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +3 -3
pipeline.py CHANGED
@@ -12,8 +12,8 @@ class PreTrainedPipeline():
12
  """
13
  self.bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
14
  self.model = FoodyBertForSequenceClassification.from_pretrained("rttl-ai/foody-bert")
15
- def __call__(self, inputs: str) -> List[float]:
16
-
17
  """
18
  Args:
19
  inputs (:obj:`str`):
@@ -27,4 +27,4 @@ class PreTrainedPipeline():
27
  predicted_class_id = self.model(X).logits.argmax().item()
28
  labels = ['positive','neutral','negative']
29
  reps = labels[predicted_class_id]
30
- return predicted_class_id
 
12
  """
13
  self.bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
14
  self.model = FoodyBertForSequenceClassification.from_pretrained("rttl-ai/foody-bert")
15
+ #def __call__(self, inputs: str) -> List[float]:
16
+ def __call__(self, inputs: str) -> str:
17
  """
18
  Args:
19
  inputs (:obj:`str`):
 
27
  predicted_class_id = self.model(X).logits.argmax().item()
28
  labels = ['positive','neutral','negative']
29
  reps = labels[predicted_class_id]
30
+ return reps