ypesk commited on
Commit
5414c47
·
verified ·
1 Parent(s): 0ae53cb

Update tasks/text.py

Browse files
Files changed (1) hide show
  1. tasks/text.py +62 -3
tasks/text.py CHANGED
@@ -9,7 +9,7 @@ from .utils.emissions import tracker, clean_emissions_data, get_space_info
9
 
10
  router = APIRouter()
11
 
12
- DESCRIPTION = "Random Baseline"
13
  ROUTE = "/text"
14
 
15
  @router.post(ROUTE, tags=["Text Task"],
@@ -55,10 +55,69 @@ async def evaluate_text(request: TextEvaluationRequest):
55
  # YOUR MODEL INFERENCE CODE HERE
56
  # Update the code below to replace the random baseline by your model inference within the inference pass where the energy consumption and emissions are tracked.
57
  #--------------------------------------------------------------------------------------------
 
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  # Make random predictions (placeholder for actual model inference)
60
- true_labels = test_dataset["label"]
61
- predictions = [random.randint(0, 7) for _ in range(len(true_labels))]
62
 
63
  #--------------------------------------------------------------------------------------------
64
  # YOUR MODEL INFERENCE STOPS HERE
 
9
 
10
  router = APIRouter()
11
 
12
+ DESCRIPTION = "First Baseline"
13
  ROUTE = "/text"
14
 
15
  @router.post(ROUTE, tags=["Text Task"],
 
55
  # YOUR MODEL INFERENCE CODE HERE
56
  # Update the code below to replace the random baseline by your model inference within the inference pass where the energy consumption and emissions are tracked.
57
  #--------------------------------------------------------------------------------------------
58
+ class CovidTwitterBertClassifier(nn.Module):
59
 
60
+ def __init__(self, n_classes):
61
+ super().__init__()
62
+ self.n_classes = n_classes
63
+ self.bert = BertForPreTraining.from_pretrained('digitalepidemiologylab/covid-twitter-bert-v2')
64
+ self.bert.cls.seq_relationship = nn.Linear(1024, n_classes)
65
+
66
+ self.sigmoid = nn.Sigmoid()
67
+
68
+ def forward(self, input_ids, token_type_ids, input_mask):
69
+ outputs = self.bert(input_ids = input_ids, token_type_ids = token_type_ids, attention_mask = input_mask)
70
+
71
+ logits = outputs[1]
72
+
73
+ return logits
74
+
75
+ model = CovidTwitterBertClassifier(8)
76
+
77
+ model.to(device)
78
+ model.load_state_dict(torch.load('model.pth'))
79
+ model.eval()
80
+
81
+
82
+ tokenizer = AutoTokenizer.from_pretrained('digitalepidemiologylab/covid-twitter-bert')
83
+
84
+ test_texts = [t['quote'] for t in data_test]
85
+
86
+ MAX_LEN = 128 #1024 # < m some tweets will be truncated
87
+
88
+ tokenized_test = tokenizer(test_texts, max_length=MAX_LEN, padding='max_length', truncation=True)
89
+ test_input_ids, test_token_type_ids, test_attention_mask = tokenized_test['input_ids'], tokenized_test['token_type_ids'], tokenized_test['attention_mask']
90
+ test_token_type_ids = torch.tensor(test_token_type_ids)
91
+
92
+ test_input_ids = torch.tensor(test_input_ids)
93
+ test_attention_mask = torch.tensor(test_attention_mask)
94
+
95
+ batch_size = 8 #
96
+ test_data = TensorDataset(test_input_ids, test_attention_mask, test_token_type_ids)
97
+
98
+ test_sampler = SequentialSampler(test_data)
99
+ test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size=batch_size)
100
+
101
+ predictions = []
102
+ for step, batch in enumerate(test_dataloader):
103
+
104
+ # Add batch to GPU
105
+ batch = tuple(t.to(device) for t in batch)
106
+
107
+ b_input_ids, b_input_mask, b_token_type_ids = batch
108
+ with torch.no_grad():
109
+ logits = model(b_input_ids, b_token_type_ids, b_input_mask)
110
+
111
+ logits = logits.detach().cpu().numpy()
112
+ predictions.extend(logits.argmax(1))
113
+ for l in ground_truth:
114
+ labels_sep.append(l)
115
+
116
+
117
+ true_labels = test_dataset["label"]
118
  # Make random predictions (placeholder for actual model inference)
119
+ #true_labels = test_dataset["label"]
120
+ #predictions = [random.randint(0, 7) for _ in range(len(true_labels))]
121
 
122
  #--------------------------------------------------------------------------------------------
123
  # YOUR MODEL INFERENCE STOPS HERE