|
import logging |
|
import json |
|
from typing import Dict, List |
|
from pydantic.dataclasses import dataclass |
|
|
|
from transformers import PreTrainedTokenizerFast |
|
from tokenizers.decoders import Decoder |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
|
# Jinja2 chat template used by DartTokenizer.default_chat_template.
# Serializes a `messages` mapping into the Dart structured prompt:
#   <|bos|><rating>...</rating><copyright>...</copyright>
#   <character>...</character><general>{length}{general}<|input_end|>
# Each section falls back to a default when its key is missing or None:
# rating -> "rating:sfw, rating:general", length -> "<|long|>",
# copyright/character/general -> empty string.
PROMPT_TEMPLATE = (
    "{{ '<|bos|>' }}"
    # Rating section — defaults to safe-for-work tags when unspecified.
    "{{ '<rating>' }}"
    "{% if 'rating' not in messages or messages['rating'] is none %}"
    "{{ 'rating:sfw, rating:general' }}"
    "{% else %}"
    "{{ messages['rating'] }}"
    "{% endif %}"
    "{{ '</rating>' }}"
    # Copyright section — empty when unspecified.
    "{{ '<copyright>' }}"
    "{% if 'copyright' not in messages or messages['copyright'] is none %}"
    "{{ '' }}"
    "{% else %}"
    "{{ messages['copyright'] }}"
    "{% endif %}"
    "{{ '</copyright>' }}"
    # Character section — empty when unspecified.
    "{{ '<character>' }}"
    "{% if 'character' not in messages or messages['character'] is none %}"
    "{{ '' }}"
    "{% else %}"
    "{{ messages['character'] }}"
    "{% endif %}"
    "{{ '</character>' }}"
    # General section opens here; note there is no closing </general> tag —
    # generation is expected to continue after <|input_end|>.
    "{{ '<general>' }}"
    # Length control token — defaults to '<|long|>' when unspecified.
    "{% if 'length' not in messages or messages['length'] is none %}"
    "{{ '<|long|>' }}"
    "{% else %}"
    "{{ messages['length'] }}"
    "{% endif %}"
    # User-provided general tags (the prompt condition) — empty when unspecified.
    "{% if 'general' not in messages or messages['general'] is none %}"
    "{{ '' }}"
    "{% else %}"
    "{{ messages['general'] }}"
    "{% endif %}"
    "{{ '<|input_end|>' }}"
).strip()
|
|
|
|
|
|
|
@dataclass
class Category:
    """A tag category together with its delimiting special-token ids."""

    # Category name (key used to look the category up in config mappings).
    name: str
    # Token id of the special token that opens this category's span
    # (e.g. the id of an opening tag like <rating> — presumed from naming;
    # confirm against the config file).
    bos_token_id: int
    # Token id of the special token that closes this category's span.
    eos_token_id: int
|
|
|
|
|
@dataclass
class TagCategoryConfig:
    """Parsed tag-category configuration (see load_tag_category_config)."""

    # Category definitions keyed by category id/name.
    categories: Dict[str, Category]
    # For each category key, the vocabulary token ids belonging to it.
    category_to_token_ids: Dict[str, List[int]]
|
|
|
|
|
def load_tag_category_config(config_json: str) -> TagCategoryConfig:
    """Load a :class:`TagCategoryConfig` from a JSON file.

    Args:
        config_json: Path to the JSON configuration file.

    Returns:
        The parsed config; pydantic coerces the nested ``categories``
        dicts into :class:`Category` instances.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Open in text mode with an explicit encoding and let json.load
    # stream the file instead of reading raw bytes into memory first.
    with open(config_json, "r", encoding="utf-8") as file:
        config = TagCategoryConfig(**json.load(file))

    return config
|
|
|
|
|
class DartDecoder:
    """Custom token-decoding step that rebuilds comma-separated tag lists.

    Ordinary tokens are joined with ", "; no separator is inserted before
    the first token or on either side of a special token.
    """

    def __init__(self, special_tokens: List[str]):
        # Defensive copy so later mutation of the caller's list has no effect.
        self.special_tokens = list(special_tokens)

    def decode_chain(self, tokens: List[str]) -> List[str]:
        """Return *tokens* with ", " prefixed to every token that is not
        first and not adjacent to a special token."""
        decoded: List[str] = []
        prev_was_special = False

        for position, token in enumerate(tokens):
            current_is_special = token in self.special_tokens

            # Separator only between two consecutive non-special tokens.
            needs_separator = (
                position > 0 and not current_is_special and not prev_was_special
            )
            decoded.append(f", {token}" if needs_separator else token)

            prev_was_special = current_is_special

        return decoded
|
|
|
|
|
class DartTokenizer(PreTrainedTokenizerFast):
    """Fast tokenizer for Danbooru Tags Transformer (Dart) prompts."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Swap in a custom decoder so that decoded tags come back as a
        # comma-separated list, with added (special) tokens left unjoined.
        added_tokens = list(self.get_added_vocab().keys())
        self._tokenizer.decoder = Decoder.custom(DartDecoder(added_tokens))

    @property
    def default_chat_template(self):
        """Jinja chat template producing the Dart structured prompt format
        used to generate danbooru tags."""
        return PROMPT_TEMPLATE
|
|