# transformer_devops/model.py
import torch
import torch.nn as nn
import numpy as np
import streamlit as st
from transformers import DistilBertModel, DistilBertTokenizerFast

# Mapping from classifier output index to a human-readable subject area.
TARGET_IND2LABEL = {
    0: 'Computer Science',
    1: 'Economics',
    2: 'Electrical Engineering and Systems Science',
    3: 'Mathematics',
    4: 'Physics',
    5: 'Quantitative Biology',
    6: 'Quantitative Finance',
    7: 'Statistics',
}


class DistilBERTClassifier(nn.Module):
    """DistilBERT encoder with a GELU + dropout classification head."""

    def __init__(self, num_classes=8):
        super().__init__()
        self.encoder = DistilBertModel.from_pretrained("distilbert-base-cased")
        self.pre_classifier = nn.Linear(768, 768)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, input_ids, attention_mask, labels=None):
        # `labels` is accepted (and ignored) so the same signature works in a
        # training loop; only the raw logits are returned here.
        output = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        hidden_state = output[0]     # (batch, seq_len, 768)
        pooler = hidden_state[:, 0]  # [CLS] token representation
        pooler = self.dropout(self.gelu(self.pre_classifier(pooler)))
        preds = self.classifier(pooler)
        return preds


@st.cache_resource
def load_tokenizer():
    return DistilBertTokenizerFast.from_pretrained('distilbert-base-cased')


@st.cache_resource
def load_model(device):
    # Load the pickled checkpoint on CPU first, then move it to the target
    # device; eval() disables dropout for inference.
    model = torch.load('model.pt', map_location=torch.device('cpu')).to(device)
    model.eval()
    return model
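
# Note: because the checkpoint is a fully pickled nn.Module, torch.load must
# be able to import DistilBERTClassifier at unpickling time. Saving and
# loading a state_dict (model.load_state_dict(...)) would be the more
# portable pattern, assuming the training side were changed to save one.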


def get_verdict(preds):
    # List the most probable classes, highest first, stopping once their
    # cumulative probability reaches 0.95.
    inds = np.argsort(preds)[::-1]
    sum_prob = 0.0
    verdict = []
    for ind in inds:
        prob = preds[ind]
        sum_prob += prob
        verdict.append(f"{TARGET_IND2LABEL[ind]}: {prob:.3f}")
        if sum_prob >= 0.95:
            break
    return "\n\n".join(verdict)


def get_preds(text, model, tokenizer, device):
    tokens = tokenizer(text, padding=True, truncation=True, return_tensors='pt')
    tokens['input_ids'] = tokens['input_ids'].to(device)
    tokens['attention_mask'] = tokens['attention_mask'].to(device)
    tokens['labels'] = None  # kept for training-loop convenience; unused at inference
    with torch.no_grad():
        # Single input: take the first (only) row of logits and softmax over classes.
        preds = torch.softmax(model(**tokens)[0], 0).cpu().numpy()
    return preds
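

# A minimal usage sketch (an assumed entry point, not part of the original
# app). It presumes a pickled checkpoint `model.pt` next to this file and
# will download the DistilBERT weights on first run; the sample abstract
# below is purely illustrative.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = load_tokenizer()
    model = load_model(device)
    abstract = "We prove a new bound on the mixing time of random walks on expander graphs."
    preds = get_preds(abstract, model, tokenizer, device)
    print(get_verdict(preds))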