Update tasks/text.py
tasks/text.py CHANGED (+33 -29)
@@ -20,7 +20,7 @@ from .utils.emissions import tracker, clean_emissions_data, get_space_info
 
 router = APIRouter()
 
-DESCRIPTION = "
+DESCRIPTION = "Submission 2: SBERT+MLP"
 ROUTE = "/text"
 
 
@@ -28,19 +28,18 @@ if torch.cuda.is_available():
     device = torch.device("cuda")
 else:
     device = torch.device("cpu")
-print(device)
 
 
-MODEL = "
+MODEL = "mlp" #sk, mlp, ct, modern-base, modern-large, gte-base, gte-large
 
-class
+class ConspiracyClassification768(
     nn.Module,
     PyTorchModelHubMixin,
     # optionally, you can add metadata which gets pushed to the model card
 ):
-    def __init__(self, num_classes):
+    def __init__(self, num_classes=8):
         super().__init__()
-        self.h1 = nn.Linear(
+        self.h1 = nn.Linear(768, 100)
         self.h2 = nn.Linear(100, 100)
         self.h3 = nn.Linear(100, 100)
         self.h4 = nn.Linear(100, 50)
@@ -71,7 +70,7 @@ class CTBERT(
     PyTorchModelHubMixin,
     # optionally, you can add metadata which gets pushed to the model card
 ):
-    def __init__(self, num_classes):
+    def __init__(self, num_classes=8):
         super().__init__()
         self.bert = BertForPreTraining.from_pretrained('digitalepidemiologylab/covid-twitter-bert-v2')
         self.bert.cls.seq_relationship = nn.Linear(1024, num_classes)
@@ -87,7 +86,7 @@ class conspiracyModelBase(
     PyTorchModelHubMixin,
     # optionally, you can add metadata which gets pushed to the model card
 ):
-    def __init__(self, num_classes):
+    def __init__(self, num_classes=8):
         super().__init__()
         self.n_classes = num_classes
         self.bert = ModernBertForSequenceClassification.from_pretrained('answerdotai/ModernBERT-base', num_labels=num_classes)
@@ -102,7 +101,7 @@ class conspiracyModelLarge(
     PyTorchModelHubMixin,
     # optionally, you can add metadata which gets pushed to the model card
 ):
-    def __init__(self, num_classes):
+    def __init__(self, num_classes=8):
         super().__init__()
         self.n_classes = num_classes
         self.bert = ModernBertForSequenceClassification.from_pretrained('answerdotai/ModernBERT-large', num_labels=num_classes)
@@ -117,12 +116,10 @@ class gteModelLarge(
     PyTorchModelHubMixin,
     # optionally, you can add metadata which gets pushed to the model card
 ):
-    def __init__(self, num_classes):
+    def __init__(self, num_classes=8):
         super().__init__()
         self.n_classes = num_classes
-        #self.bert = ModernBertForSequenceClassification.from_pretrained('answerdotai/ModernBERT-large', num_labels=num_classes)
         self.gte = AutoModel.from_pretrained('Alibaba-NLP/gte-large-en-v1.5', trust_remote_code=True)
-        #self.cls = nn.Linear(768, num_classes)
         self.cls = nn.Linear(1024, num_classes)
 
     def forward(self, input_ids, input_mask, input_type_ids):
@@ -136,20 +133,17 @@ class gteModel(
     PyTorchModelHubMixin,
     # optionally, you can add metadata which gets pushed to the model card
 ):
-    def __init__(self, num_classes):
+    def __init__(self, num_classes=8):
         super().__init__()
         self.n_classes = num_classes
-        #self.bert = ModernBertForSequenceClassification.from_pretrained('answerdotai/ModernBERT-large', num_labels=num_classes)
         self.gte = AutoModel.from_pretrained('Alibaba-NLP/gte-base-en-v1.5', trust_remote_code=True)
         self.cls = nn.Linear(768, num_classes)
-        #self.cls = nn.Linear(1024, num_classes)
 
     def forward(self, input_ids, input_mask, input_type_ids):
         outputs = self.gte(input_ids = input_ids, attention_mask = input_mask, token_type_ids = input_type_ids)
         embeddings = outputs.last_hidden_state[:, 0]
         logits = self.cls(embeddings)
-        return logits
-
+        return logits
 
 @router.post(ROUTE, tags=["Text Task"],
              description=DESCRIPTION)
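Review note: the `num_classes=8` defaults added to every `__init__` above matter because `PyTorchModelHubMixin.from_pretrained` re-instantiates the class from its stored config, and a default keeps the classes constructible without arguments. A minimal sketch of that save/load round trip, using an illustrative class and local path that are not part of this repo (recent huggingface_hub versions record the `__init__` arguments in config.json automatically):

import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin

class TinyClassifier(nn.Module, PyTorchModelHubMixin):  # illustrative, not in this repo
    def __init__(self, num_classes=8):
        super().__init__()
        self.cls = nn.Linear(768, num_classes)

    def forward(self, x):
        return self.cls(x)

model = TinyClassifier()                     # default num_classes=8
model.save_pretrained("./tiny-classifier")   # writes config.json + weights
reloaded = TinyClassifier.from_pretrained("./tiny-classifier")

The same `from_pretrained` call works with a Hub repo id, which is how the `ypesk/...` checkpoints below are loaded.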
@@ -187,20 +181,31 @@ async def evaluate_text(request: TextEvaluationRequest):
     test_dataset = dataset["test"]
 
     if MODEL =="mlp":
-        model =
-        model = model.to(device)
-        emb_model = SentenceTransformer("
+        model = ConspiracyClassification768.from_pretrained("ypesk/frugal-ai-EURECOM-mlp-768-fullset")
+        model = model.to(device)
+        emb_model = SentenceTransformer("sentence-transformers/sentence-t5-large")
         batch_size = 6
 
         test_texts = torch.Tensor(emb_model.encode([t['quote'] for t in test_dataset]))
         test_data = TensorDataset(test_texts)
         test_sampler = SequentialSampler(test_data)
         test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size=batch_size)
-
+    elif MODEL == "sk":
+        emb_model = SentenceTransformer("sentence-transformers/sentence-t5-large")
+        batch_size = 512
+
+        test_texts = torch.Tensor(emb_model.encode([t['quote'] for t in test_dataset]))
+        test_data = TensorDataset(test_texts)
+        test_sampler = SequentialSampler(test_data)
+        test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size=batch_size)
+
+        model = pickle.load(open('../svm.pkl', "rb"))
+
+
     elif MODEL == "ct":
-        model = CTBERT.from_pretrained("ypesk/frugal-ai-ct-bert-baseline")
+        model = CTBERT.from_pretrained("ypesk/frugal-ai-EURECOM-ct-bert-baseline")
         model = model.to(device)
-        tokenizer = AutoTokenizer.from_pretrained('digitalepidemiologylab/covid-twitter-bert')
+        tokenizer = AutoTokenizer.from_pretrained('digitalepidemiologylab/covid-twitter-bert-fullset')
 
         test_texts = [t['quote'] for t in test_dataset]
 
@@ -220,7 +225,7 @@ async def evaluate_text(request: TextEvaluationRequest):
         test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size=batch_size)
 
     elif MODEL == "modern-base":
-        model = conspiracyModelBase.from_pretrained("ypesk/frugal-ai-modern-base-
+        model = conspiracyModelBase.from_pretrained("ypesk/frugal-ai-EURECOM-modern-base-fullset")
         model = model.to(device)
         tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
 
@@ -241,7 +246,7 @@ async def evaluate_text(request: TextEvaluationRequest):
         test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size=batch_size)
 
     elif MODEL == "modern-large":
-        model = conspiracyModelLarge.from_pretrained(
+        model = conspiracyModelLarge.from_pretrained('ypesk/frugal-ai-EURECOM-modern-large-fullset')
         model = model.to(device)
         tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-large")
 
@@ -262,7 +267,7 @@ async def evaluate_text(request: TextEvaluationRequest):
         test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size=batch_size)
 
     elif MODEL == "gte-base":
-        model = gteModel.from_pretrained("ypesk/frugal-ai-gte-base-
+        model = gteModel.from_pretrained("ypesk/frugal-ai-EURECOM-gte-base-fullset")
         model = model.to(device)
         tokenizer = AutoTokenizer.from_pretrained('Alibaba-NLP/gte-base-en-v1.5')
 
@@ -284,7 +289,7 @@ async def evaluate_text(request: TextEvaluationRequest):
         test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size=batch_size)
 
     elif MODEL == "gte-large":
-        model =
+        model = gteModelLarge.from_pretrained("ypesk/frugal-ai-EURECOM-gte-large-fullset")
         model = model.to(device)
         tokenizer = AutoTokenizer.from_pretrained('Alibaba-NLP/gte-large-en-v1.5')
 
@@ -333,8 +338,7 @@ async def evaluate_text(request: TextEvaluationRequest):
             logits = model(b_input_ids, b_input_mask, b_token_type_ids)
 
         logits = logits.detach().cpu().numpy()
-        predictions.extend(logits.argmax(1))
-
+        predictions.extend(logits.argmax(1))
 
     true_labels = test_dataset["label"]
     # Make random predictions (placeholder for actual model inference)
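For reviewers, the headline change is the "mlp" branch (matching the new DESCRIPTION): quotes are embedded with sentence-t5-large, whose 768-dim output matches the 768 input of ConspiracyClassification768's first layer, and the small MLP head classifies the embeddings. A sketch of that inference path follows; the stand-in head below is an assumption, since the model's forward(), activations, and output layer are outside the hunks shown:

import torch
import torch.nn as nn
from sentence_transformers import SentenceTransformer

class MLPHeadSketch(nn.Module):
    # Stand-in for ConspiracyClassification768: layer sizes follow the
    # visible __init__ (768 -> 100 -> 100 -> 100 -> 50); the ReLUs and the
    # final 50 -> num_classes layer are assumptions.
    def __init__(self, num_classes=8):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(768, 100), nn.ReLU(),
            nn.Linear(100, 100), nn.ReLU(),
            nn.Linear(100, 100), nn.ReLU(),
            nn.Linear(100, 50), nn.ReLU(),
            nn.Linear(50, num_classes),
        )

    def forward(self, x):
        return self.net(x)

emb_model = SentenceTransformer("sentence-transformers/sentence-t5-large")
quotes = ["example quote to classify"]            # placeholder inputs
embeddings = torch.Tensor(emb_model.encode(quotes))  # shape (n, 768)

model = MLPHeadSketch().eval()
with torch.no_grad():
    predictions = model(embeddings).argmax(1)     # class indices 0..7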
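One more observation on the new "sk" branch: it keeps the TensorDataset/DataLoader scaffolding, but the pickled model is a scikit-learn estimator, so the PyTorch batching is optional and the estimator can predict on the full embedding matrix at once. A leaner sketch of the same inference, assuming '../svm.pkl' holds a fitted sklearn classifier (as the diff implies) and that the unpickling environment's sklearn version matches the one used for training:

import pickle
import numpy as np
from sentence_transformers import SentenceTransformer

emb_model = SentenceTransformer("sentence-transformers/sentence-t5-large")
test_texts = ["example quote to classify"]          # placeholder inputs
embeddings = np.asarray(emb_model.encode(test_texts))  # shape (n, 768)

# '../svm.pkl' is the path used in the diff; pickle.load only works if the
# file was written by a compatible scikit-learn version.
with open("../svm.pkl", "rb") as f:
    svm = pickle.load(f)

predictions = svm.predict(embeddings)               # one class index per quote

Calling predict on the whole matrix yields the same labels as looping over DataLoader batches, with less code; the batched path in the diff mainly bounds memory for the encode step.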