from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch


class EndpointHandler:
    def __init__(self, path="google/flan-t5-large"):
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(path)
        self.model.eval()

    def __call__(self, data):
        """
        Args:
            data (dict): A dictionary with an "inputs" key containing the text to
                process, and an optional "parameters" key with generation overrides.
        """
        inputs = data.pop("inputs", data)

        # Default parameters for text generation
        parameters = {
            "max_length": 512,
            "min_length": 32,
            "temperature": 0.9,
            "top_p": 0.95,
            "top_k": 50,
            "do_sample": True,
            "num_return_sequences": 1,
        }

        # Override the defaults with any generation parameters provided in the request
        parameters.update(data.pop("parameters", {}))

        # Tokenize the input text
        input_ids = self.tokenizer(inputs, return_tensors="pt").input_ids

        # Generate the response without tracking gradients
        with torch.no_grad():
            outputs = self.model.generate(input_ids, **parameters)

        # Decode the generated tokens back into text
        generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        return {"generated_text": generated_text}
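

# A minimal local smoke test, assuming this file is deployed as a Hugging Face
# Inference Endpoints custom handler (handler.py). The prompt text and the
# "parameters" overrides below are illustrative values, not part of the handler.
if __name__ == "__main__":
    handler = EndpointHandler()  # downloads google/flan-t5-large on first run
    payload = {
        "inputs": "Explain why the sky is blue in one sentence.",
        "parameters": {"max_length": 64, "do_sample": False},
    }
    print(handler(payload))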