Spaces:
Sleeping
Sleeping
from transformers_interpret import SequenceClassificationExplainer | |
from typing import List, Tuple, Union | |
import torch | |
class xBitterT5_explainer(SequenceClassificationExplainer):
    """Sequence-classification explainer whose reference input omits a CLS token.

    Only ``_make_input_reference_pair`` is overridden. The reference
    (baseline) sequence is built from ``ref_token_id`` and ``sep_token_id``
    alone, because -- per the original author's note -- the pretrained
    BioT5+ tokenizer does not provide a ``cls_token_id``.
    """

    def _make_input_reference_pair(
        self, text: Union[List, str]
    ) -> Tuple[torch.Tensor, torch.Tensor, int]:
        """Build the input ids and the matching reference ids for ``text``.

        Args:
            text: The text to explain. A ``list`` is rejected.

        Returns:
            A 3-tuple ``(input_ids, ref_input_ids, text_length)``. Both ids
            are wrapped in an outer list before being turned into tensors on
            ``self.device``; ``text_length`` is ``len(self.encode(text))``.

        Raises:
            NotImplementedError: If ``text`` is a list.
        """
        if isinstance(text, list):
            raise NotImplementedError("Lists of text are not currently supported.")

        # NOTE(review): ``self.encode`` is inherited; presumably it tokenizes
        # without special tokens, which is what makes the length comparison
        # below meaningful -- confirm against the base class.
        text_ids = self.encode(text)
        input_ids = self.tokenizer.encode(text, add_special_tokens=True)

        # Every text position is replaced by the reference token.
        ref_input_ids = [self.ref_token_id] * len(text_ids)

        # If the tokenizer added special tokens, close the reference with the
        # separator token. No CLS token is prepended (BioT5+ has none), so the
        # previous throwaway assignment that used ``cls_token_id`` is gone.
        # NOTE(review): assumes the only added token is a single trailing one,
        # as the original code did -- confirm for the tokenizer in use.
        if len(input_ids) != len(text_ids):
            ref_input_ids.append(self.sep_token_id)

        return (
            torch.tensor([input_ids], device=self.device),
            torch.tensor([ref_input_ids], device=self.device),
            len(text_ids),
        )