import torch
from pydantic import BaseModel
from transformers import PreTrainedTokenizerFast, StoppingCriteria


def fallback(value, fallback_value):
    """Return `value` unless it is None, in which case return `fallback_value`."""
    if value is None:
        return fallback_value
    return value


class Body(BaseModel):
    """Request body for the generation endpoint; unset fields fall back to server defaults."""
    prompt: str
    posts_count: int
    max_length: int | None = None
    temperature: float | None = None
    top_p: float | None = None
    top_k: int | None = None
    repetition_penalty: float | None = None
    no_repeat_ngram_size: int | None = None
    do_sample: bool | None = None


class MaxPostsStoppingCriteria(StoppingCriteria):
    """Stop generation once `posts_count` occurrences of <|end_of_post|> have been produced."""

    def __init__(self, tokenizer: PreTrainedTokenizerFast, posts_count: int):
        # encode() returns a list of token IDs; the marker may span several tokens.
        self.end_of_post_token_ids = tokenizer.encode(
            "<|end_of_post|>", add_special_tokens=False
        )
        self.posts_count = posts_count
        self.counter = 0

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Check whether any sequence now ends with the <|end_of_post|> token IDs.
        for sequence in input_ids:
            if sequence[-len(self.end_of_post_token_ids):].tolist() == self.end_of_post_token_ids:
                self.counter += 1
                if self.counter >= self.posts_count:
                    return True
        return False
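
# Usage sketch, assuming these pieces are wired into a transformers generate()
# call: the checkpoint name, the prompt, and the default values passed to
# fallback() below are illustrative assumptions, not fixed by this module.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList

    # Hypothetical checkpoint; replace with the model this service actually serves.
    tokenizer = AutoTokenizer.from_pretrained("your-model")
    model = AutoModelForCausalLM.from_pretrained("your-model")

    body = Body(prompt="Hello", posts_count=3)
    inputs = tokenizer(body.prompt, return_tensors="pt")

    # fallback() substitutes a default for every field the client left unset;
    # the stopping criteria list halts decoding after `posts_count` post markers.
    output_ids = model.generate(
        **inputs,
        max_length=fallback(body.max_length, 512),
        temperature=fallback(body.temperature, 1.0),
        top_p=fallback(body.top_p, 0.9),
        top_k=fallback(body.top_k, 50),
        repetition_penalty=fallback(body.repetition_penalty, 1.0),
        no_repeat_ngram_size=fallback(body.no_repeat_ngram_size, 0),
        do_sample=fallback(body.do_sample, True),
        stopping_criteria=StoppingCriteriaList(
            [MaxPostsStoppingCriteria(tokenizer, body.posts_count)]
        ),
    )
    print(tokenizer.decode(output_ids[0], skip_special_tokens=False))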