import logging
from typing import List
from transformers import PreTrainedTokenizerFast
from tokenizers.decoders import Decoder
logger = logging.getLogger(__name__)
# fmt: off
# https://huggingface.co./docs/transformers/main/en/chat_templating
# Jinja chat template consumed via transformers' chat-templating machinery.
# NOTE(review): unlike typical chat templates, `messages` is indexed as a
# mapping here (messages['rating'] etc.), i.e. a dict of optional prompt
# fields: rating / copyright / character / length / general. Any field that
# is absent or None falls back to the default emitted in its `if` branch.
PROMPT_TEMPLATE = (
"{{ '<|bos|>' }}"
"{{ '' }}"
# rating section: defaults to "rating:sfw, rating:general" when unspecified
"{% if 'rating' not in messages or messages['rating'] is none %}"
"{{ 'rating:sfw, rating:general' }}"
"{% else %}"
"{{ messages['rating'] }}"
"{% endif %}"
"{{ '' }}"
"{{ '' }}"
# copyright section: empty when unspecified
"{% if 'copyright' not in messages or messages['copyright'] is none %}"
"{{ '' }}"
"{% else %}"
"{{ messages['copyright'] }}"
"{% endif %}"
"{{ '' }}"
"{{ '' }}"
# character section: empty when unspecified
"{% if 'character' not in messages or messages['character'] is none %}"
"{{ '' }}"
"{% else %}"
"{{ messages['character'] }}"
"{% endif %}"
"{{ '' }}"
"{{ '' }}"
# length token: defaults to the <|long|> control token when unspecified
"{% if 'length' not in messages or messages['length'] is none %}"
"{{ '<|long|>' }}"
"{% else %}"
"{{ messages['length'] }}"
"{% endif %}"
# general (tag) section: empty when unspecified
"{% if 'general' not in messages or messages['general'] is none %}"
"{{ '' }}"
"{% else %}"
"{{ messages['general'] }}"
"{% endif %}"
"{{ '<|input_end|>' }}"
).strip()
# fmt: on
class DartDecoder:
    """Custom `tokenizers` decoder step that re-inserts ", " separators
    between ordinary tag tokens while leaving special tokens untouched.

    A separator is prefixed to every token whose predecessor is also an
    ordinary (non-special) token; the first token and any token adjacent
    to a special token are emitted unchanged.
    """

    def __init__(self, special_tokens: List[str]):
        # Defensive copy so later mutation of the caller's list does not
        # silently change decoding behavior.
        self.special_tokens = list(special_tokens)

    def decode_chain(self, tokens: List[str]) -> List[str]:
        """Return `tokens` with ", " prefixed where two ordinary tokens meet.

        Args:
            tokens: token strings produced by the fast tokenizer.

        Returns:
            The same tokens, with ", " prepended to each ordinary token
            that directly follows another ordinary token.
        """
        # Build the lookup set once per call: O(1) membership tests instead
        # of an O(m) list scan per token, while still reflecting any
        # post-construction edits to `self.special_tokens`.
        specials = set(self.special_tokens)
        new_tokens: List[str] = []
        prev_is_special = False
        for i, token in enumerate(tokens):
            is_special = token in specials
            # No separator for the first token or around a special token.
            if i == 0 or is_special or prev_is_special:
                new_tokens.append(token)
            else:
                new_tokens.append(f", {token}")
            prev_is_special = is_special
        return new_tokens
class DartTokenizer(PreTrainedTokenizerFast):
    """Dart tokenizer.

    Fast tokenizer whose decoder is swapped for :class:`DartDecoder`, so
    decoded output is rendered as comma-separated tags rather than
    space-joined tokens.
    """
    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        # Replace the backing fast tokenizer's decoder with our custom one.
        # The added vocabulary (special tokens) is passed so the decoder can
        # skip inserting ", " around those tokens.
        # NOTE(review): `Decoder.custom` wraps a plain Python object that
        # provides `decode_chain`; the `type: ignore` silences the stub
        # mismatch in `tokenizers`.
        self._tokenizer.decoder = Decoder.custom( # type: ignore
            DartDecoder(list(self.get_added_vocab().keys()))
        )
    @property
    def default_chat_template(self) -> str:
        """
        Danbooru Tags Transformer uses special format prompt to generate danbooru tags.

        Returns the module-level ``PROMPT_TEMPLATE`` Jinja string.
        NOTE(review): `default_chat_template` is deprecated in newer
        transformers releases in favor of the `chat_template` attribute —
        confirm against the pinned transformers version.
        """
        return PROMPT_TEMPLATE