# dart-v1-sft / tokenization_dart.py
# Uploaded by p1atdev (commit 0df470a, verified); ~2.99 kB.
import logging
import json
from typing import Dict, List
from pydantic.dataclasses import dataclass
from transformers import PreTrainedTokenizerFast
from tokenizers.decoders import Decoder
logger = logging.getLogger(__name__)
# fmt: off
# Jinja chat template (see https://huggingface.co./docs/transformers/main/en/chat_templating).
# The `messages` object is rendered into a structured tag prompt:
# <|bos|><rating>...</rating><copyright>...</copyright><character>...</character>
# <general><length token><general tags><|input_end|>
# Each section falls back to a default when its key is absent or None.
PROMPT_TEMPLATE = (
    "{{ '<|bos|>' }}"
    # rating section: defaults to "rating:sfw, rating:general" when unset
    "{{ '<rating>' }}"
    "{% if 'rating' not in messages or messages['rating'] is none %}"
    "{{ 'rating:sfw, rating:general' }}"
    "{% else %}"
    "{{ messages['rating'] }}"
    "{% endif %}"
    "{{ '</rating>' }}"
    # copyright section: defaults to empty when unset
    "{{ '<copyright>' }}"
    "{% if 'copyright' not in messages or messages['copyright'] is none %}"
    "{{ '' }}"
    "{% else %}"
    "{{ messages['copyright'] }}"
    "{% endif %}"
    "{{ '</copyright>' }}"
    # character section: defaults to empty when unset
    "{{ '<character>' }}"
    "{% if 'character' not in messages or messages['character'] is none %}"
    "{{ '' }}"
    "{% else %}"
    "{{ messages['character'] }}"
    "{% endif %}"
    "{{ '</character>' }}"
    "{{ '<general>' }}"
    # length token: defaults to '<|long|>' when unset
    "{% if 'length' not in messages or messages['length'] is none %}"
    "{{ '<|long|>' }}"
    "{% else %}"
    "{{ messages['length'] }}"
    "{% endif %}"
    # general token: defaults to empty when unset; note no closing </general> —
    # the model is expected to continue generating after <|input_end|>
    "{% if 'general' not in messages or messages['general'] is none %}"
    "{{ '' }}"
    "{% else %}"
    "{{ messages['general'] }}"
    "{% endif %}"
    "{{ '<|input_end|>' }}"
).strip()
# fmt: on
@dataclass
class Category:
    """One tag category with its delimiting token ids (validated by pydantic)."""

    # Category name (presumably e.g. "copyright"/"character" — see PROMPT_TEMPLATE; confirm against config JSON).
    name: str
    # Token id that opens this category's section.
    bos_token_id: int
    # Token id that closes this category's section.
    eos_token_id: int
@dataclass
class TagCategoryConfig:
    """Configuration mapping tag categories to their token ids (validated by pydantic)."""

    # Category key -> Category record (name + bos/eos token ids).
    categories: Dict[str, Category]
    # Category key -> ids of all vocabulary tokens belonging to that category.
    category_to_token_ids: Dict[str, List[int]]
def load_tag_category_config(config_json: str) -> TagCategoryConfig:
    """Load a :class:`TagCategoryConfig` from a JSON file.

    Args:
        config_json: Path to the JSON configuration file.

    Returns:
        The parsed (pydantic-validated) ``TagCategoryConfig``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Text mode with explicit UTF-8 + json.load instead of manually
    # reading raw bytes and decoding via json.loads.
    with open(config_json, "r", encoding="utf-8") as file:
        config = TagCategoryConfig(**json.load(file))
    return config
class DartDecoder:
    """Custom decoder chain that re-joins decoded tag tokens with ", ".

    A token is emitted verbatim when it is the first token, is itself a
    special token, or directly follows a special token; otherwise it is
    prefixed with ", " so the decoded output reads as a comma-separated
    tag list.
    """

    def __init__(self, special_tokens: List[str]):
        # Public attribute kept as a list for backward compatibility.
        self.special_tokens = list(special_tokens)
        # Set copy for O(1) membership tests in decode_chain (the list
        # lookup was O(n) per token).
        self._special_token_set = set(special_tokens)

    def decode_chain(self, tokens: List[str]) -> List[str]:
        """Return *tokens* with ", " prepended to every ordinary token
        that is not adjacent to a special token.
        """
        new_tokens: List[str] = []
        prev_special = False  # whether the previous token was special
        for i, token in enumerate(tokens):
            is_special = token in self._special_token_set
            # First token, or this/previous token is special: no separator.
            if i == 0 or is_special or prev_special:
                new_tokens.append(token)
            else:
                new_tokens.append(f", {token}")
            prev_special = is_special
        return new_tokens
class DartTokenizer(PreTrainedTokenizerFast):
    """Dart tokenizer: a fast tokenizer whose decoder re-joins decoded
    tags with ", " via :class:`DartDecoder`.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Replace the backing tokenizer's decoder with a DartDecoder so
        # decoded output is a comma-separated tag string. All added-vocab
        # tokens are treated as special (never comma-joined).
        self._tokenizer.decoder = Decoder.custom(  # type: ignore
            DartDecoder(list(self.get_added_vocab().keys()))
        )

    @property
    def default_chat_template(self):
        """
        Danbooru Tags Transformer uses a special prompt format (see
        PROMPT_TEMPLATE above) to generate danbooru tags.
        """
        return PROMPT_TEMPLATE