Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
import numpy as np | |
import streamlit as st | |
from transformers import DistilBertModel, DistilBertTokenizerFast | |
# Mapping from classifier output position (logit index) to human-readable
# subject label; get_verdict looks labels up by the argsort indices of the
# probability vector, so the list order below must match the model's outputs.
TARGET_IND2LABEL = dict(enumerate([
    'Computer Science',
    'Economics',
    'Electrical Engineering and Systems Science',
    'Mathematics',
    'Physics',
    'Quantitative Biology',
    'Quantitative Finance',
    'Statistics',
]))
class DistilBERTClassifier(nn.Module):
    """DistilBERT encoder followed by a small MLP classification head.

    The head takes the encoder's hidden state at sequence position 0, applies
    Linear -> GELU -> Dropout, and projects to ``num_classes`` logits.

    Args:
        num_classes: number of output logits (default 8).
        pretrained_name: Hugging Face checkpoint used to build the encoder.
        dropout: dropout probability applied before the final projection.
    """

    def __init__(self, num_classes=8, pretrained_name="distilbert-base-cased",
                 dropout=0.1):
        super().__init__()
        self.encoder = DistilBertModel.from_pretrained(pretrained_name)
        # Size the head from the encoder config rather than hard-coding 768,
        # so a different DistilBERT checkpoint still gets matching shapes
        # (distilbert-base-cased has dim == 768, so the default is unchanged).
        hidden_dim = self.encoder.config.dim
        self.pre_classifier = nn.Linear(hidden_dim, hidden_dim)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, input_ids, attention_mask, labels=None):
        """Return unnormalized class logits, one row per input sequence.

        Args:
            input_ids: token id tensor as produced by the tokenizer.
            attention_mask: mask tensor matching ``input_ids``.
            labels: ignored; accepted so training loops can pass a batch dict
                with a ``labels`` key. Optional so inference callers may omit it.

        Returns:
            Tensor whose last dimension has size ``num_classes``.
        """
        output = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # output[0] is the encoder's last hidden state; indexing [:, 0] takes
        # the first token of every sequence as the pooled representation.
        hidden_state = output[0]
        pooler = hidden_state[:, 0]
        pooler = self.dropout(self.gelu(self.pre_classifier(pooler)))
        return self.classifier(pooler)
def load_tokenizer():
    """Build the fast tokenizer for the same checkpoint the encoder uses."""
    checkpoint = 'distilbert-base-cased'
    tokenizer = DistilBertTokenizerFast.from_pretrained(checkpoint)
    return tokenizer
def load_model(device, path='model.pt'):
    """Load a whole pickled model from ``path``, move it to ``device``, set eval mode.

    Args:
        device: target device (anything accepted by ``nn.Module.to``).
        path: checkpoint file; defaults to ``'model.pt'``.

    Returns:
        The loaded module, in eval mode, on ``device``.
    """
    cpu = torch.device('cpu')
    # The checkpoint is presumably a full pickled nn.Module (the result is
    # called with .to()/.eval()), not a state_dict. PyTorch >= 2.6 defaults to
    # weights_only=True, which refuses such files, so opt out explicitly.
    # NOTE(security): weights_only=False runs arbitrary pickle code; only load
    # checkpoints from a trusted source.
    try:
        model = torch.load(path, map_location=cpu, weights_only=False)
    except TypeError:
        # torch < 1.13 has no weights_only keyword; its default already
        # unpickles full objects.
        model = torch.load(path, map_location=cpu)
    model = model.to(device)
    model.eval()
    return model
def get_verdict(preds, threshold=0.95, labels=None):
    """Format the most probable classes until they cover ``threshold`` of the mass.

    Classes are listed in descending probability order as ``"label: prob"``;
    listing stops right after the entry that brings the running sum to at
    least ``threshold`` (or when all classes are listed).

    Args:
        preds: 1-D array/sequence of class probabilities indexed by class id.
        threshold: cumulative-probability cutoff (default 0.95).
        labels: mapping from class index to display name; defaults to
            ``TARGET_IND2LABEL``.

    Returns:
        The entries joined by blank lines; an empty string for empty ``preds``.
    """
    if labels is None:
        # Resolved at call time so the module constant is only needed by
        # callers relying on the default.
        labels = TARGET_IND2LABEL
    order = np.argsort(preds)[::-1]
    cumulative = 0.0
    verdict = []
    for ind in order:
        # Keep the original scalar type so the printed probability text is
        # unchanged (converting e.g. float32 to float would add digits).
        prob = preds[ind]
        cumulative += prob
        verdict.append(f"{labels[int(ind)]}: {prob}")
        if cumulative >= threshold:
            break
    return "\n\n".join(verdict)
def get_preds(text, model, tokenizer, device):
    """Return class probabilities for ``text`` as a 1-D NumPy array.

    Args:
        text: the input string to classify.
        model: callable accepting ``input_ids``, ``attention_mask`` and
            ``labels`` keyword arguments and returning a logits tensor.
        tokenizer: callable returning a mapping with ``'input_ids'`` and
            ``'attention_mask'`` tensors when called with ``return_tensors='pt'``.
        device: device the input tensors are moved to before the forward pass.

    Returns:
        Softmax of the logits for the first sequence of the batch, on CPU.
    """
    encoded = tokenizer(text, padding=True, truncation=True, return_tensors='pt')
    # Forward only the tensors the model accepts instead of splatting every
    # key the tokenizer produced (e.g. token_type_ids would break the call).
    input_ids = encoded['input_ids'].to(device)
    attention_mask = encoded['attention_mask'].to(device)
    with torch.no_grad():
        # labels=None: the classifier's forward takes a labels argument for
        # training convenience but does not use it at inference time.
        logits = model(input_ids=input_ids, attention_mask=attention_mask,
                       labels=None)
        # Only row 0 is scored; a list of texts would be reduced to its first.
        probs = torch.softmax(logits[0], dim=-1)
    return probs.cpu().numpy()