Model Description
This model takes in text from a news article and outputs an embedding representing that article. These output embeddings have been trained such that the cosine similarity between articles aligns with overall article similarity. The model was trained using data from the 2022 SemEval Task-8 News Article Similarity challenge, and achieves the second-highest score when evaluated using the test set from the challenge. Designed for speed and scalability, this model is ideal for embedding many news articles (or similar text) and using fast cosine similarity calculations for pairwise similarity over very large corpora.
- Developed by: Ben Litterer, David Jurgens, Dallas Card
- Finetuned from model: all-mpnet-base-v2
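The pairwise cosine similarity mentioned above reduces to a single matrix product once each embedding is scaled to unit length. The following is a minimal sketch using NumPy; the helper name `pairwise_cosine` and the stand-in vectors are illustrative assumptions, not part of the released code.

```python
import numpy as np

def pairwise_cosine(embeddings):
    """Cosine similarity between every pair of rows in an (n, d) array."""
    emb = np.asarray(embeddings, dtype=np.float64)
    norms = np.linalg.norm(emb, axis=1, keepdims=True)
    # Clip the norms so an all-zero row does not cause division by zero.
    unit = emb / np.clip(norms, 1e-12, None)
    # After normalisation, dot products are cosine similarities.
    return unit @ unit.T

# Stand-in vectors; real inputs would be the model's article embeddings.
demo = np.array([[1.0, 0.0], [0.0, 2.0], [3.0, 0.0]])
sims = pairwise_cosine(demo)
```

Here `sims[i, j]` holds the similarity between article `i` and article `j`. For very large corpora the full n-by-n matrix may not fit in memory, in which case the same product can be computed block by block.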
Uses
This model is well suited to embedding large corpora of text and calculating pairwise similarity scores. Note that during training, each article's headline was first concatenated with the full article text. The first 288 tokens and the last 96 tokens of the result were then joined (288 + 96 = 384 tokens in total) to fit in the all-mpnet-base-v2 context window, so inputs prepared the same way at inference time match what the model saw during training.
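The head-and-tail truncation described above can be sketched as follows. This is an illustration under stated assumptions rather than the authors' preprocessing code: `head_tail_truncate` is a hypothetical helper that works on any list of token ids, and how special tokens were counted against the 288 and 96 limits is not stated in this card.

```python
def head_tail_truncate(token_ids, head=288, tail=96):
    """Keep the first `head` and last `tail` ids; return short inputs unchanged."""
    if len(token_ids) <= head + tail:
        return list(token_ids)
    # Slice from an explicit start index so that tail=0 yields an empty tail.
    tail_start = len(token_ids) - tail
    return list(token_ids[:head]) + list(token_ids[tail_start:])

# Stand-in ids; real ids would come from the all-mpnet-base-v2 tokenizer.
ids = list(range(1000))
short = head_tail_truncate(ids)
```

In practice the ids could come from tokenizing the headline and article text together with `add_special_tokens=False`. The separator placed between headline and body is also an assumption, since the card does not specify one.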
How to Get Started with the Model
Use the code below to get started with the model. All you need are the weights in state_dict.tar.
import torch
from transformers import AutoTokenizer, AutoModel
MODEL_PATH = "/my/path/to/state_dict.tar"
# declare model class, inheriting from nn.Module
class BiModel(torch.nn.Module):
    def __init__(self):
        super(BiModel, self).__init__()
        # note: the base model is created in training mode; call .eval() before inference
        self.model = AutoModel.from_pretrained('sentence-transformers/all-mpnet-base-v2').to(device).train()
        self.cos = torch.nn.CosineSimilarity(dim=1, eps=1e-4)

    # pool token-level embeddings, ignoring padding positions
    def mean_pooling(self, token_embeddings, attention_mask):
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    # Note that here we expect only one batch of input ids and attention masks
    def encode(self, input_ids, attention_mask):
        encoding = self.model(input_ids.squeeze(1), attention_mask=attention_mask.squeeze(1))[0]
        meanPooled = self.mean_pooling(encoding, attention_mask.squeeze(1))
        return meanPooled

    # NOTE: here we expect a list of two that we then unpack
    def forward(self, input_ids, attention_mask):
        input_ids_a = input_ids[0].to(device)
        input_ids_b = input_ids[1].to(device)
        attention_a = attention_mask[0].to(device)
        attention_b = attention_mask[1].to(device)

        # encode each article and get its mean-pooled representation
        encoding1 = self.model(input_ids_a, attention_mask=attention_a)[0]  # all token embeddings
        encoding2 = self.model(input_ids_b, attention_mask=attention_b)[0]

        meanPooled1 = self.mean_pooling(encoding1, attention_a)
        meanPooled2 = self.mean_pooling(encoding2, attention_b)

        pred = self.cos(meanPooled1, meanPooled2)
        return pred
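# set device as needed, initialize model, load weights
device = torch.device("cpu")
trainedModel = BiModel()
sDict = torch.load(MODEL_PATH, map_location=device)
# depending on the library version, the saved position_ids buffer may not be
# expected by the model; drop it only in that case
if "model.embeddings.position_ids" not in trainedModel.state_dict():
    sDict.pop("model.embeddings.position_ids", None)
# initialize tokenizer for all-mpnet-base-v2
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')
# load the fine-tuned weights into the model
trainedModel.load_state_dict(sDict)
# switch to evaluation mode (disables dropout) before computing embeddings
trainedModel.eval()
# trainedModel is now ready to use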
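Once the weights are loaded, article embeddings can be obtained along the following lines. This is a minimal sketch, not the authors' pipeline: `embed_articles` is a hypothetical helper, it relies on the `encode` method of the class defined above, it assumes the model sits on the CPU, and it uses the tokenizer's ordinary truncation to 384 tokens rather than the head-and-tail scheme used during training.

```python
import torch

def embed_articles(model, tokenizer, texts, max_length=384):
    """Return one embedding per text as a (len(texts), dim) tensor."""
    # Standard tokenizer truncation; swap in head-and-tail truncation if desired.
    batch = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    model.eval()  # disable dropout so embeddings are deterministic
    with torch.no_grad():  # no gradients are needed for inference
        return model.encode(batch["input_ids"], batch["attention_mask"])

# Example usage with the objects created above (article_a and article_b are
# placeholder strings holding headline plus article text):
# embeddings = embed_articles(trainedModel, tokenizer, [article_a, article_b])
# similarity = torch.nn.functional.cosine_similarity(embeddings[0:1], embeddings[1:2])
```

Embedding each article once and comparing the stored vectors afterwards avoids running the transformer on every pair, which is what makes this approach practical for large corpora.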