File size: 2,803 Bytes
632029f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# coding=utf-8
import torch
from transformers.generation.logits_process import LogitsProcessor
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from transformers import AutoTokenizer
import re

class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
    """Subtracts a per-token frequency penalty from next-token logits.

    For each sequence, counts how often every vocabulary token appears in
    ``penalty_dialog`` (pre-encoded history tokens) concatenated with the
    tokens generated so far (everything past ``input_length`` in the running
    ``input_ids``), then subtracts ``penalty * count`` from that token's score.

    Args:
        penalty: Non-negative penalty weight. ``0`` disables the penalty.
            (The original validation required a strictly positive float, which
            made the ``penalty == 0.0`` fast path in ``__call__`` unreachable;
            ints are now accepted too and coerced to float.)
        penalty_dialog: 1-D LongTensor of history token ids to penalize.
        input_length: Length of the prompt; tokens at or beyond this index in
            ``input_ids`` are treated as generated and counted.

    Raises:
        ValueError: If ``penalty`` is negative or not a number.
    """

    def __init__(self, penalty: float, penalty_dialog: torch.LongTensor, input_length: int):
        if not isinstance(penalty, (int, float)) or penalty < 0:
            raise ValueError(f"`penalty` has to be a non-negative float, but is {penalty}")

        self.penalty = float(penalty)
        self.input_length = input_length
        self.penalty_dialog = penalty_dialog

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Fast path: zero penalty leaves the scores untouched.
        if self.penalty == 0.0:
            return scores

        vocab_size = scores.size(-1)  # hoisted: loop-invariant
        new_scores = []
        for seq, score in zip(input_ids, scores):
            # History tokens + tokens generated after the prompt.
            generated = torch.cat((self.penalty_dialog, seq[self.input_length:]), dim=-1)
            # Per-token occurrence counts over the whole vocabulary.
            # NOTE(review): assumes all token ids are < vocab_size, otherwise
            # bincount returns a longer tensor and the subtraction would fail.
            frequency = torch.bincount(generated, minlength=vocab_size).to(scores.device)
            new_scores.append(score - self.penalty * frequency)

        return torch.stack(new_scores).float()



class LlamaForConditionalGeneration(LlamaForCausalLM):
    """Llama causal-LM with an optional frequency penalty over dialog history.

    ``generate`` accepts three extra keyword arguments on top of the normal
    ``LlamaForCausalLM.generate`` interface:

    * ``history_penalty`` (float, default ``0.0``): penalty weight; ``0`` means
      plain generation.
    * ``penalty_turns`` (int, default ``0``): how many of the most recent
      assistant turns to penalize.
    * ``messages`` (list of ``{"role", "content"}`` dicts, default ``[]``):
      the chat history the penalized turns are taken from.
    """

    # Hub id of the tokenizer used to encode the penalized history text.
    PENALTY_TOKENIZER_ID = "Collective-Ai/collective-v0.1-chinese-roleplay-8b"
    # Lazily loaded and shared across instances; the original re-downloaded
    # the tokenizer from the hub on every generate() call.
    _penalty_tokenizer = None

    def __init__(self, config):
        super().__init__(config)

    @classmethod
    def _get_penalty_tokenizer(cls):
        """Load the penalty tokenizer once and cache it at class level."""
        if cls._penalty_tokenizer is None:
            cls._penalty_tokenizer = AutoTokenizer.from_pretrained(cls.PENALTY_TOKENIZER_ID)
        return cls._penalty_tokenizer

    def generate(self, **kwargs):
        history_penalty = kwargs.pop("history_penalty", 0.0)
        penalty_turns = kwargs.pop("penalty_turns", 0)
        messages = kwargs.pop("messages", [])

        # Guard clause: no penalty requested -> plain generation.
        if history_penalty == 0.0 or penalty_turns < 0:
            return super().generate(**kwargs)

        input_ids = kwargs.get("input_ids", torch.tensor([[]]))
        input_length = input_ids.size(-1)

        # Prior assistant replies are the text eligible for the penalty.
        dialogs = [m["content"] for m in messages if m["role"] == "assistant"]

        # Take the last `penalty_turns` replies, oldest first, stripping
        # parentheses so bracketed asides tokenize as plain words.
        # NOTE(review): the original chained the same ASCII replacements
        # twice — full-width （） may have been intended; confirm with callers.
        penalty_dialog = []
        for i in range(penalty_turns, 0, -1):
            if i <= len(dialogs):
                penalty_dialog.append(dialogs[-i].replace("(", " ").replace(")", " "))

        tokenizer = self._get_penalty_tokenizer()
        penalty_token = torch.LongTensor(
            tokenizer.encode(" ".join(penalty_dialog))
        ).to(input_ids.device)

        processor = FrequencyPenaltyLogitsProcessor(
            penalty=history_penalty,
            penalty_dialog=penalty_token,
            input_length=input_length,
        )
        return super().generate(logits_processor=[processor], **kwargs)