pouchedfox commited on
Commit
25d443b
·
1 Parent(s): 5047d71

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +201 -0
app.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import csv
3
+ from typing import Optional
4
+ from urllib.request import urlopen
5
+ import gradio as gr
6
+
7
+
8
+ class SentimentTransform():
9
+ def __init__(
10
+ self,
11
+ model_name: str = "cardiffnlp/twitter-roberta-base-sentiment",
12
+ highlight: bool = False,
13
+ positive_sentiment_name: str = "positive",
14
+ max_number_of_shap_documents: Optional[int] = None,
15
+ min_abs_score: float = 0.1,
16
+ sensitivity: float = 0,
17
+ **kwargs,
18
+ ):
19
+ """
20
+ Sentiment Ops.
21
+ Parameters
22
+ -------------
23
+ model_name: str
24
+ The name of the model
25
+ sensitivity: float
26
+ How confident it is about being `neutral`. If you are dealing with news sources,
27
+ you probably want less sensitivity
28
+ """
29
+ self.model_name = model_name
30
+ self.highlight = highlight
31
+ self.positive_sentiment_name = positive_sentiment_name
32
+ self.max_number_of_shap_documents = max_number_of_shap_documents
33
+ self.min_abs_score = min_abs_score
34
+ self.sensitivity = sensitivity
35
+ for k, v in kwargs.items():
36
+ setattr(self, k, v)
37
+
38
+ def preprocess(self, text: str):
39
+ new_text = []
40
+ for t in text.split(" "):
41
+ t = "@user" if t.startswith("@") and len(t) > 1 else t
42
+ t = "http" if t.startswith("http") else t
43
+ new_text.append(t)
44
+ return " ".join(new_text)
45
+
46
+ @property
47
+ def classifier(self):
48
+ if not hasattr(self, "_classifier"):
49
+ import transformers
50
+
51
+ self._classifier = transformers.pipeline(
52
+ return_all_scores=True,
53
+ model=self.model_name,
54
+ )
55
+ return self._classifier
56
+
57
+ def _get_label_mapping(self, task: str):
58
+ # Note: this is specific to the current model
59
+ labels = []
60
+ mapping_link = f"https://raw.githubusercontent.com/cardiffnlp/tweeteval/main/datasets/{task}/mapping.txt"
61
+ with urlopen(mapping_link) as f:
62
+ html = f.read().decode("utf-8").split("\n")
63
+ csvreader = csv.reader(html, delimiter="\t")
64
+ labels = [row[1] for row in csvreader if len(row) > 1]
65
+ return labels
66
+
67
+ @property
68
+ def label_mapping(self):
69
+ return {"LABEL_0": "negative", "LABEL_1": "neutral", "LABEL_2": "positive"}
70
+
71
+ def analyze_sentiment(
72
+ self,
73
+ text,
74
+ highlight: bool = False,
75
+ positive_sentiment_name: str = "positive",
76
+ max_number_of_shap_documents: Optional[int] = None,
77
+ min_abs_score: float = 0.1,
78
+ ):
79
+ if text is None:
80
+ return None
81
+ labels = self.classifier([str(text)], truncation=True, max_length=512)
82
+ ind_max = np.argmax([l["score"] for l in labels[0]])
83
+ sentiment = labels[0][ind_max]["label"]
84
+ max_score = labels[0][ind_max]["score"]
85
+ sentiment = self.label_mapping.get(sentiment, sentiment)
86
+ if sentiment.lower() == "neutral" and max_score > self.sensitivity:
87
+ overall_sentiment = 1e-5
88
+ elif sentiment.lower() == "neutral":
89
+ # get the next highest score
90
+ new_labels = labels[0][:ind_max] + labels[0][(ind_max + 1):]
91
+ new_ind_max = np.argmax([l["score"] for l in new_labels])
92
+ new_max_score = new_labels[new_ind_max]["score"]
93
+ new_sentiment = new_labels[new_ind_max]["label"]
94
+ new_sentiment = self.label_mapping.get(new_sentiment, new_sentiment)
95
+ overall_sentiment = self._calculate_overall_sentiment(
96
+ new_max_score, new_sentiment
97
+ )
98
+
99
+ else:
100
+ overall_sentiment = self._calculate_overall_sentiment(max_score, sentiment)
101
+ # Adjust to avoid bug
102
+ if overall_sentiment == 0:
103
+ overall_sentiment = 1e-5
104
+ if not highlight:
105
+ return {
106
+ "sentiment": sentiment,
107
+ "overall_sentiment_score": overall_sentiment,
108
+ }
109
+ shap_documents = self.get_shap_values(
110
+ text,
111
+ sentiment_ind=ind_max,
112
+ max_number_of_shap_documents=max_number_of_shap_documents,
113
+ min_abs_score=min_abs_score,
114
+ )
115
+ return {
116
+ "sentiment": sentiment,
117
+ "score": max_score,
118
+ "overall_sentiment": overall_sentiment,
119
+ "highlight_chunk_": shap_documents,
120
+ }
121
+
122
+ def _calculate_overall_sentiment(self, score: float, sentiment: str):
123
+ if sentiment.lower().strip() == self.positive_sentiment_name:
124
+ return score
125
+ else:
126
+ return -score
127
+
128
+ # def explainer(self):
129
+ # if hasattr(self, "_explainer"):
130
+ # return self._explainer
131
+ # else:
132
+ # try:
133
+ # import shap
134
+ # except ModuleNotFoundError:
135
+ # raise MissingPackageError("shap")
136
+ # self._explainer = shap.Explainer(self.classifier)
137
+ # return self._explainer
138
+
139
+ def get_shap_values(
140
+ self,
141
+ text: str,
142
+ sentiment_ind: int = 2,
143
+ max_number_of_shap_documents: Optional[int] = None,
144
+ min_abs_score: float = 0.1,
145
+ ):
146
+ """Get SHAP values"""
147
+ shap_values = self.explainer([text])
148
+ cohorts = {"": shap_values}
149
+ cohort_labels = list(cohorts.keys())
150
+ cohort_exps = list(cohorts.values())
151
+ features = cohort_exps[0].data
152
+ feature_names = cohort_exps[0].feature_names
153
+ values = np.array([cohort_exps[i].values for i in range(len(cohort_exps))])
154
+ shap_docs = [
155
+ {"text": v, "score": f}
156
+ for f, v in zip(
157
+ [x[sentiment_ind] for x in values[0][0].tolist()], feature_names[0]
158
+ )
159
+ ]
160
+ if max_number_of_shap_documents is not None:
161
+ sorted_scores = sorted(shap_docs, key=lambda x: x["score"], reverse=True)
162
+ else:
163
+ sorted_scores = sorted(shap_docs, key=lambda x: x["score"], reverse=True)[
164
+ :max_number_of_shap_documents
165
+ ]
166
+ return [d for d in sorted_scores if abs(d["score"]) > min_abs_score]
167
+
168
+ def transform(self, text):
169
+ # # For each document, update the field
170
+ # sentiment_docs = [{"_id": d["_id"]} for d in documents]
171
+ # for i, t in enumerate(self.text_fields):
172
+ # if self.output_fields is not None:
173
+ # output_field = self.output_fields[i]
174
+ # else:
175
+ # output_field = self._get_output_field(t)
176
+ sentiment = self.analyze_sentiment(
177
+ text,
178
+ highlight=self.highlight,
179
+ max_number_of_shap_documents=self.max_number_of_shap_documents,
180
+ min_abs_score=self.min_abs_score, )
181
+ return sentiment
182
+
183
+
184
+ def sentiment_classifier(text, model_type, sensitivity):
185
+ if model_type == 'Social Media Model':
186
+ model_name = "cardiffnlp/twitter-roberta-base-sentiment"
187
+ elif model_type == 'Survey Model':
188
+ model_name = "j-hartmann/sentiment-roberta-large-english-3-classes"
189
+ else:
190
+ model_name = "j-hartmann/sentiment-roberta-large-english-3-classes"
191
+ model = SentimentTransform(model_name=model_name, sensitivity=sensitivity)
192
+ res_dict = model.transform(text)
193
+ return res_dict['sentiment'], res_dict['overall_sentiment_score']
194
+
195
+
196
+ demo = gr.Interface(
197
+ fn=sentiment_classifier,
198
+ inputs=[gr.Textbox(placeholder="Put the text here and click 'submit' to predict its sentiment", label="Input Text"), gr.Dropdown(["Social Media Model", "Survey Model"], value="Survey Model", label="Select the Model that you want to use."), gr.Slider(0, 1, step = 0.01, label="Sensitivity (How confident it is about being `neutral`. If you are dealing with news sources, you probably want less sensitivity.)")],
199
+ outputs=[gr.Textbox(label='Sentiment'), gr.Textbox(label='Sentiment Score')],
200
+ )
201
+ demo.launch(debug=True)