Llama-3.2-1B / handler.py
import numpy as np
from llama_cpp import Llama, LogitsProcessorList  # GGUF model handling
from typing import Any, Dict, List

class FixedVocabLogitsProcessor:
    """
    A logits processor that restricts sampling to a fixed set of token IDs.
    llama-cpp-python calls logits processors with (input_ids, scores) and
    expects the modified scores array back, so this class implements
    __call__ with that signature.
    """
    def __init__(self, allowed_ids: set[int], fill_value: float = float("-inf")):
        self.allowed_ids = allowed_ids
        self.fill_value = fill_value
    def __call__(self, input_ids: np.ndarray, scores: np.ndarray) -> np.ndarray:
        """
        Mask every token ID outside `allowed_ids` with `fill_value`.
        """
        masked = np.full_like(scores, self.fill_value)
        allowed = list(self.allowed_ids)
        masked[allowed] = scores[allowed]
        return masked
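
# Illustration (hypothetical values): with allowed_ids={1, 3}, only indices
# 1 and 3 keep their scores; everything else is masked to -inf, so sampling
# can never pick a disallowed token:
#
#     proc = FixedVocabLogitsProcessor({1, 3})
#     proc(np.array([0]), np.array([0.2, 1.5, -0.3, 2.0, 0.7], dtype=np.float32))
#     # -> array([-inf, 1.5, -inf, 2.0, -inf], dtype=float32)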

class EndpointHandler:
    def __init__(self, path: str = ""):
        """
        Initialize the GGUF model handler.
        Args:
            path (str): Path to the repository directory (unused here; the
                GGUF file location is hardcoded for the endpoint image).
        """
        # Llama exposes tokenize()/detokenize() directly, so no separate
        # tokenizer object is needed.
        self.model = Llama(model_path="/repository/model.gguf")
    def __call__(self, data: Any) -> List[Dict[str, str]]:
        """
        Handle a request, generating text restricted to a caller-supplied vocabulary.
        Args:
            data (Any): Request payload containing "inputs" (the prompt),
                optional "parameters", and a required "vocab_list" of
                allowed words.
        Returns:
            List[Dict[str, str]]: Generated output.
        """
# Extract inputs and parameters
inputs = data.pop("inputs", data)
parameters = data.pop("parameters", {})
vocab_list = data.pop("vocab_list", None)
if not vocab_list:
raise ValueError("You must provide a 'vocab_list' to define allowed tokens.")
        # Build the allowed-token set from the vocabulary words. Note that
        # masking is per token ID, so any interleaving of the words' subword
        # pieces can be sampled, not only whole words.
        allowed_ids = set()
        for word in vocab_list:
            # tokenize() expects bytes; skip the BOS token for vocab words.
            for tid in self.model.tokenize(word.encode("utf-8"), add_bos=False):
                allowed_ids.add(tid)
        # Keep EOS reachable so generation can terminate naturally.
        allowed_ids.add(self.model.token_eos())
        # Tokenize the prompt (tokenize() expects bytes)
        input_ids = self.model.tokenize(inputs.encode("utf-8"))
        # Perform inference. Llama.generate() yields one token ID at a time,
        # so collect tokens until EOS or the length cap is reached.
        max_tokens = parameters.get("max_length", 30)
        output_ids = []
        processors = LogitsProcessorList([FixedVocabLogitsProcessor(allowed_ids)])
        for token_id in self.model.generate(input_ids, logits_processor=processors):
            if token_id == self.model.token_eos() or len(output_ids) >= max_tokens:
                break
            output_ids.append(token_id)
        # Decode the output (detokenize() returns bytes)
        generated_text = self.model.detokenize(output_ids).decode("utf-8", errors="ignore")
        return [{"generated_text": generated_text}]
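
# Minimal local smoke test (hypothetical payload; assumes a GGUF file is
# present at /repository/model.gguf). On a deployed Inference Endpoint the
# handler is invoked by the serving stack instead of run directly.
if __name__ == "__main__":
    handler = EndpointHandler()
    output = handler({
        "inputs": "The weather today is",
        "parameters": {"max_length": 10},
        "vocab_list": ["sunny", " and", " very", " warm"],
    })
    print(output)  # e.g. [{'generated_text': ' sunny and very warm'}]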