test-pair-classification / pair_classification_tool.py
sgugger's picture
Upload tool
e27ee71
raw
history blame
920 Bytes
import torch
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer
from transformers.tools import PipelineTool
class TextPairClassificationTool(PipelineTool):
    """Pipeline tool that labels a pair of English texts as equivalent or not.

    The tool tokenizes both texts together as one sequence pair, runs the
    sequence-classification model, and returns the label name that the model
    configuration associates with the highest-scoring class.
    """

    # Checkpoint, tokenizer class and model class the tool is built from.
    default_checkpoint = "sgugger/bert-finetuned-mrpc"
    pre_processor_class = AutoTokenizer
    model_class = AutoModelForSequenceClassification

    # Natural-language description of the tool and its two inputs.
    description = (
        "This is a tool that classifies if two texts in English are similar or not using the labels 'equivalent' and "
        "'not_equivalent'. It takes two inputs named `text` and `second_text` which should be in English and returns "
        "the predicted label."
    )

    def encode(self, text, second_text):
        """Tokenize the two texts as a single sequence pair.

        Args:
            text: The first text of the pair.
            second_text: The second text of the pair.

        Returns:
            The tokenizer output as PyTorch tensors (`return_tensors="pt"`).
        """
        # Truncate over-long pairs to the tokenizer's maximum length so the
        # model is never fed more tokens than it accepts; without this, very
        # long inputs would presumably fail inside the model instead of
        # producing a label.
        return self.pre_processor(
            text,
            second_text,
            return_tensors="pt",
            truncation=True,
        )

    def decode(self, outputs):
        """Convert the model outputs into the predicted label name.

        Args:
            outputs: Model output object exposing a `logits` attribute.

        Returns:
            The entry of `self.model.config.id2label` for the class with the
            highest logit in the first row of `outputs.logits`.
        """
        # Only the first row is read: `encode` builds a single text pair.
        first_row_logits = outputs.logits[0]
        predicted_id = torch.argmax(first_row_logits).item()
        return self.model.config.id2label[predicted_id]