---
license: cc-by-4.0
datasets:
  - AnnaWegmann/Dialog-Paraphrase
language:
  - en
base_model: microsoft/deberta-v3-large
---

This model was created as described in https://arxiv.org/abs/2404.06670; it is the best-performing DeBERTa AGGREGATED model. See also the GitHub repository.
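The model loads with the standard `transformers` token-classification classes. As a quick sanity check (a minimal sketch, assuming the default label names that the full example below also relies on), you can confirm the label mapping first:

```python
from transformers import AutoModelForTokenClassification

model = AutoModelForTokenClassification.from_pretrained(
    "AnnaWegmann/Highlight-Paraphrases-in-Dialog"
)
# The token-level head distinguishes two labels; LABEL_1 marks paraphrased tokens
print(model.config.id2label)  # expected: {0: 'LABEL_0', 1: 'LABEL_1'}
```

The example below wraps the model in a small helper class that uppercases words whose paraphrase probability is at least 0.5 and prints per-word probabilities for both utterances: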

```python
from transformers import AutoTokenizer, AutoModelForTokenClassification
import torch

class ParaphraseHighlighter:
    def __init__(self, model_name="AnnaWegmann/Highlight-Paraphrases-in-Dialog"):
        # Load the tokenizer and model
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
        self.model = AutoModelForTokenClassification.from_pretrained(model_name)
        self.model.eval()  # inference mode: disables dropout
        
        # Get the label id for 'LABEL_1'
        self.label2id = self.model.config.label2id
        self.label_id = self.label2id['LABEL_1']
    
    def highlight_paraphrase(self, text1, text2):
        # Tokenize the inputs with the tokenizer
        encoding = self.tokenizer(text1, text2, return_tensors="pt", padding=True, truncation=True)
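        # The pair is encoded as one sequence: [CLS] text1 [SEP] text2 [SEP]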
        
        with torch.no_grad():  # inference only, no gradients needed
            outputs = self.model(**encoding)
        logits = outputs.logits  # Shape: (batch_size, sequence_length, num_labels)
        # Apply softmax over the label dimension to get per-token probabilities
        probs = torch.nn.functional.softmax(logits, dim=-1)  # Shape: (batch_size, sequence_length, num_labels)
        
        # Convert token IDs back to tokens
        tokens = self.tokenizer.convert_ids_to_tokens(encoding["input_ids"][0])        
        # Get word IDs to map tokens to words
        word_ids = encoding.word_ids(batch_index=0)
        # Get sequence IDs to know which text the token belongs to
        sequence_ids = encoding.sequence_ids(batch_index=0)
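        # sequence_ids is 0 for text1 tokens, 1 for text2 tokens, None for special tokens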
        
        # Collect words and probabilities for each text
        words_text1 = []
        words_text2 = []
        probs_text1 = []
        probs_text2 = []
        
        # Word IDs restart at 0 for the second text, so track (sequence, word) pairs
        previous_key = None
        
        # For determining if there are high-probability words in both texts
        has_high_prob_text1 = False
        has_high_prob_text2 = False
        
        for idx, (word_idx, seq_id) in enumerate(zip(word_ids, sequence_ids)):
            if word_idx is None:
                # Skip special tokens like [CLS], [SEP], [PAD]
                continue

            if (seq_id, word_idx) != previous_key:
                # Start of a new word
                word_tokens = [tokens[idx]]

                # Get the probability for LABEL_1 for the first token of the word
                prob_LABEL_1 = probs[0][idx][self.label_id].item()

                # Collect subsequent tokens belonging to the same word
                j = idx + 1
                while j < len(word_ids) and word_ids[j] == word_idx:
                    word_tokens.append(tokens[j])
                    j += 1

                # Reconstruct the word
                word = self.tokenizer.convert_tokens_to_string(word_tokens).strip()

                # Check if probability >= 0.5 to uppercase
                if prob_LABEL_1 >= 0.5:
                    word_display = word.upper()
                    if seq_id == 0:
                        has_high_prob_text1 = True
                    elif seq_id == 1:
                        has_high_prob_text2 = True
                else:
                    word_display = word

                # Append the word and probability to the appropriate list
                if seq_id == 0:
                    words_text1.append(word_display)
                    probs_text1.append(prob_LABEL_1)
                elif seq_id == 1:
                    words_text2.append(word_display)
                    probs_text2.append(prob_LABEL_1)
                else:
                    # Should not happen
                    pass

            previous_key = (seq_id, word_idx)
        
        # Determine if there are words in both texts with prob >= 0.5
        if has_high_prob_text1 and has_high_prob_text2:
            print("is a paraphrase")
        else:
            print("is not a paraphrase")
        
        # Function to format and align words and probabilities
        def print_aligned(words, probs):
            # Determine the maximum length of words for formatting
            max_word_length = max(len(word) for word in words)
            # Create format string for alignment
            format_str = f'{{:<{max_word_length}}}'
            # Print words
            for word in words:
                print(format_str.format(word), end=' ')
            print()
            # Print probabilities aligned below words
            for prob in probs:
                prob_str = f"{prob:.2f}"
                print(format_str.format(prob_str), end=' ')
            print('\n')
        
        # Print text1's words and probabilities aligned
        print("\nSpeaker 1:")
        print_aligned(words_text1, probs_text1)
        
        # Print text2's words and probabilities aligned
        print("Speaker 2:")
        print_aligned(words_text2, probs_text2)
        
# Example usage
highlighter = ParaphraseHighlighter()
text1 = "And it will be my 20th time in doing it as a television commentator from Rome so."
text2 = "Yes, you've been doing this for a while now."
highlighter.highlight_paraphrase(text1, text2)
```

Running the example should print:

```
is a paraphrase

Speaker 1:
And         IT          will        BE          MY          20TH        TIME        IN          DOING       IT          as          a           TELEVISION  COMMENTATOR from        Rome        so.         
0.15        0.54        0.49        0.56        0.74        0.83        0.77        0.75        0.78        0.76        0.44        0.45        0.52        0.52        0.30        0.37        0.21        

Speaker 2:
Yes,   YOU'VE BEEN   DOING  THIS   FOR    A      WHILE  NOW.   
0.12   0.79   0.78   0.82   0.82   0.69   0.70   0.72   0.66   
```
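The 0.5 cutoff and the printed layout are choices of this example, not properties of the model. If you need the scores programmatically, a hypothetical helper (sketch below; `get_word_probs` and its use of the fast-tokenizer `word_to_chars`/`word_to_tokens` lookups are illustrative, not part of the released code) can return per-word probabilities instead:

```python
import torch

def get_word_probs(highlighter, text1, text2):
    """Return two lists of (word, P(paraphrase)) pairs, one per speaker."""
    enc = highlighter.tokenizer(text1, text2, return_tensors="pt", truncation=True)
    with torch.no_grad():
        logits = highlighter.model(**enc).logits
    probs = torch.nn.functional.softmax(logits, dim=-1)

    results = ([], [])  # results[0]: speaker 1, results[1]: speaker 2
    seen = set()
    for word_idx, seq_id in zip(enc.word_ids(0), enc.sequence_ids(0)):
        if word_idx is None or (seq_id, word_idx) in seen:
            continue  # skip special tokens and word continuations
        seen.add((seq_id, word_idx))
        # Recover the word from the original string via its character span
        span = enc.word_to_chars(word_idx, sequence_index=seq_id)
        source = text1 if seq_id == 0 else text2
        # As above, read the LABEL_1 probability at the word's first token
        first_token = enc.word_to_tokens(word_idx, sequence_index=seq_id).start
        p = probs[0, first_token, highlighter.label_id].item()
        results[seq_id].append((source[span.start:span.end], p))
    return results

words1, words2 = get_word_probs(highlighter, text1, text2)
print([(w, round(p, 2)) for w, p in words2])
```

For the example pair above, `words2` should contain entries like `("you've", 0.79)`, matching the printed table.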

For comments or questions, reach out to Anna (a.m.wegmann @ uu.nl) or raise an issue on GitHub.

If you find this model helpful, consider citing our paper:

```bibtex
@article{wegmann2024,
  title={What's Mine becomes Yours: Defining, Annotating and Detecting Context-Dependent Paraphrases in News Interview Dialogs},
  author={Wegmann, Anna and Broek, Tijs van den and Nguyen, Dong},
  journal={arXiv preprint arXiv:2404.06670},
  year={2024}
}
```