import torch

from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type
from transformers import PreTrainedTokenizerBase

from text_generation_server.models.types import Batch, GeneratedText
from text_generation_server.pb.generate_pb2 import InfoResponse

B = TypeVar("B", bound=Batch)


class Model(ABC):
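    """Abstract base class shared by the server's model implementations.

    Concrete subclasses declare the batch type they operate on (`batch_type`) and
    implement a single decoding step over such a batch (`generate_token`).
    """
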
    def __init__(
        self,
        tokenizer: PreTrainedTokenizerBase,
        requires_padding: bool,
        dtype: torch.dtype,
        device: torch.device,
        decode_buffer: int = 3,
    ):
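        # `decode_buffer` is the number of trailing tokens that `decode_token` re-decodes
        # together; a buffer larger than 1 keeps incremental decoding correct for
        # tokenizers whose tokens do not decode to valid text in isolation
        # (e.g. byte-level BPE).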
        if decode_buffer < 1:
            raise ValueError("decode_buffer must be >= 1")

        self.tokenizer = tokenizer
        self.all_special_ids = set(tokenizer.all_special_ids)
        self.requires_padding = requires_padding
        self.dtype = dtype
        self.device = device
        self.decode_buffer = decode_buffer

    @property
    def info(self) -> InfoResponse:
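        """Static model metadata used to build the server's gRPC Info response."""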
        return InfoResponse(
            requires_padding=self.requires_padding,
            dtype=str(self.dtype),
            device_type=self.device.type,
        )

    @property
    @abstractmethod
    def batch_type(self) -> Type[B]:
        raise NotImplementedError

    @abstractmethod
    def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
        raise NotImplementedError

    def decode_token(
        self,
        all_input_ids: List[int],
        offset: Optional[int] = None,
        token_offset: Optional[int] = None,
    ) -> Tuple[str, Optional[int], Optional[int]]:
        """Hack to hopefully support generate_stream for the maximum number of tokenizers"""
        if all_input_ids[-1] in self.all_special_ids:
            return (
                self.tokenizer.decode(all_input_ids[-1], skip_special_tokens=False),
                None,
                None,
            )

        if token_offset is None:
            # Use the last `decode_buffer` tokens as a decoding buffer, clamped so the
            # offset never points before the start of the sequence
            token_offset = max(len(all_input_ids) - self.decode_buffer, 0)
            if self.decode_buffer > 1:
                # Decode the buffered tokens twice: without the last token and with it,
                # so the new token's text can be isolated by slicing
                raw_texts = self.tokenizer.batch_decode(
                    [all_input_ids[token_offset:-1], all_input_ids[token_offset:]],
                    skip_special_tokens=False,
                )

                # offset marks where the last token's text begins in the buffered text
                offset = len(raw_texts[0])
                sequence_text = raw_texts[1]
            else:
                # Only decode the last token without using a token buffer
                sequence_text = self.tokenizer.decode(
                    all_input_ids[-1], skip_special_tokens=False
                )
                # no offset in this case
                offset = 0
        else:
            assert offset is not None
            sequence_text = self.tokenizer.decode(
                all_input_ids[token_offset:],
                skip_special_tokens=False,
            )

        # Slice off the part of the text that was already emitted
        token_text = sequence_text[offset:]

        # The text is only complete if it does not end with the Unicode replacement
        # character, which signals a partially decoded multi-byte sequence
        if token_text and token_text[-1] != "�":
            return token_text, None, None
        else:
            # Incomplete: emit nothing and carry the offsets over to the next call
            return "", offset, token_offset