import torch
import torch.nn as nn
import numpy as np
import streamlit as st
from transformers import DistilBertModel, DistilBertTokenizerFast

TARGET_IND2LABEL = {
    0: 'Computer Science',
    1: 'Economics',
    2: 'Electrical Engineering and Systems Science',
    3: 'Mathematics',
    4: 'Physics',
    5: 'Quantitative Biology',
    6: 'Quantitative Finance',
    7: 'Statistics',
}


class DistilBERTClassifier(nn.Module):
    def __init__(self, num_classes=8):
        super().__init__()
        self.encoder = DistilBertModel.from_pretrained("distilbert-base-cased")
        self.pre_classifier = nn.Linear(768, 768)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, input_ids, attention_mask, labels=None):
        # `labels` is accepted (and ignored) so the same call signature
        # works for both training and inference.
        output = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        hidden_state = output[0]        # (batch, seq_len, 768)
        pooler = hidden_state[:, 0]     # embedding at the [CLS] position
        pooler = self.dropout(self.gelu(self.pre_classifier(pooler)))
        preds = self.classifier(pooler)  # raw logits, (batch, num_classes)
        return preds


@st.cache_resource
def load_tokenizer():
    return DistilBertTokenizerFast.from_pretrained('distilbert-base-cased')


@st.cache_resource
def load_model(device):
    # The checkpoint pickles the whole module, so DistilBERTClassifier must be
    # importable when it is unpickled. (On PyTorch >= 2.6 this call may also
    # need weights_only=False, since weights-only loading became the default.)
    model = torch.load('model.pt', map_location=torch.device('cpu')).to(device)
    model.eval()
    return model


def get_verdict(preds):
    # Report the most probable classes, highest first, until their
    # cumulative probability reaches 0.95.
    inds = np.argsort(preds)[::-1]
    sum_prob = 0.0
    verdict = []
    for ind in inds:
        prob = preds[ind]
        sum_prob += prob
        verdict.append(f"{TARGET_IND2LABEL[ind]}: {prob:.3f}")
        if sum_prob >= 0.95:
            break
    return "\n\n".join(verdict)


def get_preds(text, model, tokenizer, device):
    tokens = tokenizer(text, padding=True, truncation=True, return_tensors='pt')
    tokens['input_ids'] = tokens['input_ids'].to(device)
    tokens['attention_mask'] = tokens['attention_mask'].to(device)
    tokens['labels'] = None  # kept for training convenience; ignored by forward()
    with torch.no_grad():
        # Logits for the single input -> probability distribution over classes.
        preds = torch.softmax(model(**tokens)[0], 0).cpu().numpy()
    return preds
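

# --- Hypothetical usage sketch (an assumption, not part of the original app) ---
# A minimal illustration of how the helpers above could drive a Streamlit page.
# The page title, widget labels, and the device-selection line are invented
# for illustration; only the function names come from the code above.
if __name__ == "__main__":
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    tokenizer = load_tokenizer()
    model = load_model(device)

    st.title('arXiv subject classifier')  # assumed title
    text = st.text_area('Paste a paper title and/or abstract:')
    if st.button('Classify') and text.strip():
        preds = get_preds(text, model, tokenizer, device)
        st.markdown(get_verdict(preds))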