# mistral-base / handler.py
# Last commit f4e017b ("Update handler.py") by rwitz; file size 701 bytes.
import os
import time

import runpod
import torch
# Use a pipeline as a high-level helper
from transformers import pipeline
# Direct model/tokenizer loaders; not used in the visible code (the pipeline
# below loads both itself).
from transformers import AutoTokenizer, AutoModelForCausalLM

# Read from the SLEEP_TIME environment variable, default 3. Not referenced in
# the visible code -- presumably a delay in seconds; confirm before relying on it.
sleep_time = int(os.getenv('SLEEP_TIME', 3))

# Load the model(s) into VRAM once, at import time, so every job handled by
# this worker reuses the same pipeline: CUDA device 0, bfloat16 weights.
pipe = pipeline(
    "text-generation",
    model="rwitz/go-bruins-v2",
    device=0,
    torch_dtype=torch.bfloat16,
)
def handler(event):
    """Run one text-generation job through the module-level ``pipe``.

    Args:
        event: RunPod job payload. ``event["input"]["prompt"]`` is required.
            ``event["input"]["sampling_params"]`` is an optional dict; its
            ``max_new_tokens`` and ``temperature`` keys are forwarded to the
            pipeline when present. Otherwise the pipeline/model defaults apply.

    Returns:
        The pipeline output for ``prompt`` unchanged. For a single string
        prompt this is presumably a list of ``{"generated_text": ...}`` dicts;
        confirm against the transformers version in use.

    Raises:
        KeyError: if the event has no ``input`` or the input has no ``prompt``.
    """
    inp = event["input"]
    prompt = inp["prompt"]
    # Tolerate a missing or null sampling_params block.
    sampling_params = inp.get("sampling_params") or {}

    gen_kwargs = {}
    if "max_new_tokens" in sampling_params:
        gen_kwargs["max_new_tokens"] = sampling_params["max_new_tokens"]

    temperature = sampling_params.get("temperature")
    if temperature is not None:
        # Temperature only has an effect when sampling is enabled; under
        # transformers' default greedy decoding (do_sample=False) it is
        # ignored. Enable sampling explicitly for a positive temperature.
        # Treat <= 0 as a request for greedy decoding, because sampling
        # requires a strictly positive temperature.
        if temperature > 0:
            gen_kwargs["do_sample"] = True
            gen_kwargs["temperature"] = temperature
        else:
            gen_kwargs["do_sample"] = False

    return pipe(prompt, **gen_kwargs)
# Hand control to the RunPod serverless worker loop, registering `handler`
# as the function that processes each incoming job.
runpod.serverless.start(dict(handler=handler))