Chris4K commited on
Commit
cc1edab
·
1 Parent(s): 1c02c6e

Update sentiment_analysis.py

Browse files
Files changed (1) hide show
  1. sentiment_analysis.py +24 -24
sentiment_analysis.py CHANGED
@@ -13,30 +13,30 @@ class SentimentAnalysisTool:
13
  def __call__(self, inputs: str):
14
  return SentimentAnalysisTool.predicto(str)
15
 
16
- model_id_1 = "nlptown/bert-base-multilingual-uncased-sentiment"
17
- model_id_2 = "microsoft/deberta-xlarge-mnli"
18
- model_id_3 = "distilbert-base-uncased-finetuned-sst-2-english"
19
- model_id_4 = "lordtt13/emo-mobilebert"
20
- model_id_5 = "juliensimon/reviews-sentiment-analysis"
21
- model_id_6 = "sbcBI/sentiment_analysis_model"
22
- model_id_7 = "models/oliverguhr/german-sentiment-bert"
23
-
24
- def parse_output(output_json):
25
- list_pred=[]
26
- for i in range(len(output_json[0])):
27
- label = output_json[0][i]['label']
28
- score = output_json[0][i]['score']
29
- list_pred.append((label, score))
30
- return list_pred
31
-
32
- def get_prediction(model_id):
33
- classifier = pipeline("text-classification", model=model_id, return_all_scores=True)
34
-
35
- def predicto(review):
36
- classifier = SentimentAnalysisTool.get_prediction(SentimentAnalysisTool.model_id_7)
37
- prediction = classifier(review)
38
- print(prediction)
39
- return SentimentAnalysisTool.parse_output(prediction)
40
 
41
 
42
 
 
13
  def __call__(self, inputs: str):
14
  return SentimentAnalysisTool.predicto(str)
15
 
16
+ model_id_1 = "nlptown/bert-base-multilingual-uncased-sentiment"
17
+ model_id_2 = "microsoft/deberta-xlarge-mnli"
18
+ model_id_3 = "distilbert-base-uncased-finetuned-sst-2-english"
19
+ model_id_4 = "lordtt13/emo-mobilebert"
20
+ model_id_5 = "juliensimon/reviews-sentiment-analysis"
21
+ model_id_6 = "sbcBI/sentiment_analysis_model"
22
+ model_id_7 = "models/oliverguhr/german-sentiment-bert"
23
+
24
+ def parse_output(output_json):
25
+ list_pred=[]
26
+ for i in range(len(output_json[0])):
27
+ label = output_json[0][i]['label']
28
+ score = output_json[0][i]['score']
29
+ list_pred.append((label, score))
30
+ return list_pred
31
+
32
+ def get_prediction(model_id):
33
+ classifier = pipeline("text-classification", model=model_id, return_all_scores=True)
34
+
35
+ def predicto(review):
36
+ classifier = SentimentAnalysisTool.get_prediction(SentimentAnalysisTool.model_id_7)
37
+ prediction = classifier(review)
38
+ print(prediction)
39
+ return SentimentAnalysisTool.parse_output(prediction)
40
 
41
 
42