lora_model / handler.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Model identifier on the Hugging Face Hub; model and tokenizer are loaded globally
MODEL_NAME = "abhijsrwala/lora_model"

def load_model():
    # Load the model weights and tokenizer from the Hugging Face Hub
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    return model, tokenizer

# Load model once to avoid reloading on every request
model, tokenizer = load_model()
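
# Optional GPU placement: a minimal sketch, assuming the serving environment
# may expose a CUDA device (this block is an addition, not part of the
# original handler). handle_request below reads model.device, so the inputs
# follow the model automatically whether or not a GPU is present.
if torch.cuda.is_available():
    model.to("cuda")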

def handle_request(input_data):
    """
    Handle a single inference request.

    Args:
        input_data (str): The input text prompt.

    Returns:
        str: The model's generated response.
    """
    # Tokenize the input (including the attention mask) and move it to the model's device
    inputs = tokenizer(input_data, return_tensors="pt").to(model.device)
    # Generate text without tracking gradients, since this is inference only
    with torch.no_grad():
        outputs = model.generate(**inputs, max_length=200, num_return_sequences=1)
    # Decode the generated token IDs back into a string
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response
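
# Minimal usage sketch (an addition for illustration; the prompt below is a
# hypothetical example input, not part of the original file).
if __name__ == "__main__":
    print(handle_request("Explain what a LoRA adapter does in one sentence."))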