import os
from typing import Any, List, Dict

import torch

try:
    from llama_cpp import Llama  # Library for GGUF model handling
except ImportError as _err:
    # Defer the failure to EndpointHandler construction so the pure-Python
    # FixedVocabLogitsProcessor stays importable without the native library.
    Llama = None
    _LLAMA_IMPORT_ERROR = _err
else:
    _LLAMA_IMPORT_ERROR = None

# Model file used when the handler is constructed without a path (the location
# the handler has always loaded from).
_DEFAULT_MODEL_PATH = "/repository/model.gguf"
# File name looked up inside a model directory passed as ``path``.
_MODEL_FILENAME = "model.gguf"
# Maps request "parameters" keys (Hugging Face naming) onto the keyword
# arguments accepted by ``Llama.generate``.
_SAMPLING_PARAM_MAP = {
    "temperature": "temp",
    "top_k": "top_k",
    "top_p": "top_p",
    "repetition_penalty": "repeat_penalty",
}


class FixedVocabLogitsProcessor:
    """Restrict next-token logits to a fixed set of allowed token IDs.

    Every logit whose index is not in ``allowed_ids`` is overwritten with
    ``fill_value`` (``-inf`` by default), so the sampler can only pick an
    allowed token. The logits vector is modified in place and returned.

    Instances are also callable as ``processor(input_ids, scores)``, the
    logits-processor signature used by llama-cpp-python, so one can be passed
    directly as ``logits_processor`` to ``Llama.generate``.
    """

    def __init__(self, allowed_ids: set[int], fill_value=float('-inf')):
        self.allowed_ids = allowed_ids
        self.fill_value = fill_value
        # Disallowed indices are cached under (vocab_size, frozenset(allowed_ids))
        # so the O(vocab) scan runs once per generation, not once per token.
        self._cache_key = None
        self._cache_indices: List[int] = []

    def _disallowed_indices(self, vocab_size: int) -> List[int]:
        """Return the indices in ``range(vocab_size)`` that are not allowed."""
        key = (vocab_size, frozenset(self.allowed_ids))
        # getattr: tolerate instances created without running __init__.
        if getattr(self, "_cache_key", None) != key:
            allowed = key[1]
            self._cache_indices = [i for i in range(vocab_size) if i not in allowed]
            self._cache_key = key
        return self._cache_indices

    def apply(self, logits: torch.FloatTensor):
        """Mask every disallowed position of a 1-D ``logits`` vector in place.

        Args:
            logits: 1-D indexable logits (a NumPy array, a torch tensor, or a
                Python list); its length is taken as the vocabulary size.

        Returns:
            The same ``logits`` object with disallowed entries set to
            ``fill_value``.
        """
        disallowed = self._disallowed_indices(len(logits))
        if not disallowed:
            return logits
        if isinstance(logits, list):
            # Plain lists do not support assignment through an index list.
            for token_id in disallowed:
                logits[token_id] = self.fill_value
        else:
            # One vectorized write instead of a Python-level loop per token.
            logits[disallowed] = self.fill_value
        return logits

    def __call__(self, input_ids, scores):
        """Logits-processor hook for llama-cpp-python; ``input_ids`` is unused."""
        return self.apply(scores)


class EndpointHandler:
    """Inference handler that generates text from a GGUF model while
    restricting every generated token to a caller-supplied vocabulary."""

    def __init__(self, path=""):
        """
        Initialize the GGUF model handler.

        Args:
            path (str): Path to the GGUF file, or to a directory containing
                ``model.gguf``. When empty, ``/repository/model.gguf`` is used.

        Raises:
            ImportError: If ``llama_cpp`` could not be imported.
        """
        if Llama is None:
            raise ImportError(
                "llama-cpp-python (llama_cpp) is required to load GGUF models"
            ) from _LLAMA_IMPORT_ERROR
        self.model = Llama(model_path=self._resolve_model_path(path))
        # NOTE(review): in llama-cpp-python ``Llama.tokenizer`` appears to be a
        # method, so this stores a bound method; kept for compatibility only.
        # Tokenization below goes through ``self.model`` directly.
        self.tokenizer = self.model.tokenizer

    @staticmethod
    def _resolve_model_path(path: str) -> str:
        """Map the ``path`` constructor argument to a concrete GGUF file path."""
        if not path:
            return _DEFAULT_MODEL_PATH
        if os.path.isdir(path):
            return os.path.join(path, _MODEL_FILENAME)
        return path

    def _tokenize(self, text: str, add_bos: bool = True) -> List[int]:
        """Tokenize ``text``; llama-cpp-python's tokenizer takes UTF-8 bytes."""
        return self.model.tokenize(text.encode("utf-8"), add_bos=add_bos)

    def __call__(self, data: Any) -> List[Dict[str, str]]:
        """
        Handle the request, performing inference with a restricted vocabulary.

        Args:
            data (Any): Request dict. Keys read (and popped from ``data``):
                ``inputs`` (str): the prompt.
                ``vocab_list`` (list[str]): words whose tokens may be
                    generated; a single string is treated as one word.
                ``parameters`` (dict, optional): ``max_new_tokens`` or
                    ``max_length`` (number of tokens to generate, default 30),
                    plus optional ``temperature``, ``top_k``, ``top_p`` and
                    ``repetition_penalty``.

        Returns:
            List[Dict[str, str]]: ``[{"generated_text": <decoded text>}]``.

        Raises:
            ValueError: If ``vocab_list`` is missing/empty or ``inputs`` is
                not a string.
        """
        # Extract inputs and parameters
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {}) or {}
        vocab_list = data.pop("vocab_list", None)
        if not vocab_list:
            raise ValueError("You must provide a 'vocab_list' to define allowed tokens.")
        if not isinstance(inputs, str):
            raise ValueError("'inputs' must be a prompt string.")
        if isinstance(vocab_list, str):
            # Avoid iterating a bare string character by character.
            vocab_list = [vocab_list]

        # Allowed tokens: every token of every vocab word. BOS is excluded
        # (add_bos=False) so it does not leak into the allowed set.
        allowed_ids = {
            tid for word in vocab_list for tid in self._tokenize(word, add_bos=False)
        }
        # Keep EOS reachable; otherwise the mask forces generation to run to
        # the token budget every time.
        eos_id = self.model.token_eos()
        allowed_ids.add(eos_id)

        max_new_tokens = int(
            parameters.get("max_new_tokens", parameters.get("max_length", 30))
        )
        sampling = {
            dst: parameters[src]
            for src, dst in _SAMPLING_PARAM_MAP.items()
            if src in parameters
        }

        output_ids: List[int] = []
        if max_new_tokens > 0:
            tokens = self.model.generate(
                self._tokenize(inputs),
                logits_processor=FixedVocabLogitsProcessor(allowed_ids),
                **sampling,
            )
            # generate() yields tokens without its own length limit, so stop
            # explicitly on EOS or once the token budget is spent.
            for token in tokens:
                if token == eos_id:
                    break
                output_ids.append(token)
                if len(output_ids) >= max_new_tokens:
                    break

        # detokenize returns bytes; decode so the response is JSON-serializable.
        generated_text = self.model.detokenize(output_ids).decode("utf-8", errors="ignore")
        return [{"generated_text": generated_text}]