import json
import logging
from typing import Dict, List

from pydantic.dataclasses import dataclass
from tokenizers.decoders import Decoder
from transformers import PreTrainedTokenizerFast

logger = logging.getLogger(__name__)

# fmt: off
# https://huggingface.co/docs/transformers/main/en/chat_templating
#
# Jinja chat template that renders a ``messages`` mapping (looked up by the
# keys 'rating', 'copyright', 'character', 'length' and 'general') into one
# prompt string. A key that is missing or ``None`` falls back to the default
# shown in its ``{% if %}`` branch.
# NOTE(review): the ``{{ '' }}`` segments below render as empty strings; they
# look like they may once have held section delimiter tokens -- TODO confirm
# against the tokenizer's added vocabulary.
PROMPT_TEMPLATE = (
    "{{ '<|bos|>' }}"

    "{{ '' }}"
    "{% if 'rating' not in messages or messages['rating'] is none %}"
    "{{ 'rating:sfw, rating:general' }}"
    "{% else %}"
    "{{ messages['rating'] }}"
    "{% endif %}"
    "{{ '' }}"

    "{{ '' }}"
    "{% if 'copyright' not in messages or messages['copyright'] is none %}"
    "{{ '' }}"
    "{% else %}"
    "{{ messages['copyright'] }}"
    "{% endif %}"
    "{{ '' }}"

    "{{ '' }}"
    "{% if 'character' not in messages or messages['character'] is none %}"
    "{{ '' }}"
    "{% else %}"
    "{{ messages['character'] }}"
    "{% endif %}"
    "{{ '' }}"

    "{{ '' }}"
    # length token
    "{% if 'length' not in messages or messages['length'] is none %}"
    "{{ '<|long|>' }}"
    "{% else %}"
    "{{ messages['length'] }}"
    "{% endif %}"

    # general token
    "{% if 'general' not in messages or messages['general'] is none %}"
    "{{ '' }}"
    "{% else %}"
    "{{ messages['general'] }}"
    "{% endif %}"

    "{{ '<|input_end|>' }}"
).strip()
# fmt: on


@dataclass
class Category:
    """A named tag category together with its ``bos``/``eos`` token ids."""

    name: str
    bos_token_id: int
    eos_token_id: int


@dataclass
class TagCategoryConfig:
    """Tag-category configuration as loaded by ``load_tag_category_config``."""

    # Category entries keyed by a string id (presumably the category id --
    # verify against the JSON producer). pydantic converts nested dicts
    # into ``Category`` instances.
    categories: Dict[str, Category]
    # Token ids belonging to each category, keyed by a string id.
    category_to_token_ids: Dict[str, List[int]]


def load_tag_category_config(config_json: str) -> TagCategoryConfig:
    """Load a ``TagCategoryConfig`` from the JSON file at *config_json*.

    The file's top-level JSON object is passed as keyword arguments to
    ``TagCategoryConfig``, so its keys must match that dataclass's fields.
    """
    # Read raw bytes; ``json.loads`` detects the UTF-8/16/32 encoding itself.
    with open(config_json, "rb") as file:
        config: TagCategoryConfig = TagCategoryConfig(**json.loads(file.read()))
    return config


class DartDecoder:
    """Custom decoder that joins ordinary tag tokens with ``", "``.

    A token is emitted as-is when it is the first token, when it is a
    special token, or when it directly follows a special token. Every other
    token is prefixed with ``", "`` so that consecutive tags become a
    comma-separated list.
    """

    def __init__(self, special_tokens: List[str]):
        # Copied into a list so later mutation of the caller's iterable does
        # not affect this decoder.
        self.special_tokens = list(special_tokens)

    def decode_chain(self, tokens: List[str]) -> List[str]:
        """Return *tokens* with ``", "`` inserted between adjacent normal tags.

        Args:
            tokens: Decoded token strings, in order.

        Returns:
            A list of the same length as *tokens*; concatenating it yields the
            final decoded text.
        """
        # Build the lookup set once per call: O(1) membership per token
        # instead of scanning the list for every token. Building it here,
        # rather than caching it, keeps it in sync with ``special_tokens``.
        special = set(self.special_tokens)

        new_tokens: List[str] = []
        prev_is_special = False
        for i, token in enumerate(tokens):
            is_special = token in special
            # first token, or this/previous token is special -> no separator
            if i == 0 or is_special or prev_is_special:
                new_tokens.append(token)
            else:
                new_tokens.append(f", {token}")
            prev_is_special = is_special
        return new_tokens


class DartTokenizer(PreTrainedTokenizerFast):
    """Dart tokenizer"""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Replace the backend decoder so that decoding joins tags with ", ".
        # Every added-vocabulary token is treated as special by DartDecoder.
        self._tokenizer.decoder = Decoder.custom(  # type: ignore
            DartDecoder(list(self.get_added_vocab().keys()))
        )

    @property
    def default_chat_template(self):
        """
        Danbooru Tags Transformer uses special format prompt to generate danbooru tags.
        """
        # NOTE(review): newer transformers releases stopped consulting
        # class-level default chat templates; confirm the pinned version
        # still reads this property (or set ``chat_template`` explicitly).
        return PROMPT_TEMPLATE