# mistral-base / handler.py
# Last commit 326fa19 by rwitz: "Update handler.py" (640 bytes).
import runpod
import os
import time
import torch
# Value of the SLEEP_TIME environment variable as an int, defaulting to 3
# (presumably seconds — TODO confirm). NOTE(review): not referenced anywhere
# in this file; likely a leftover from the RunPod worker template.
sleep_time = int(os.getenv('SLEEP_TIME', 3))

from transformers import pipeline

# Build the text-generation pipeline once, at module import, so the weights
# are loaded a single time (on GPU ordinal 0, in bfloat16) before the worker
# is registered below, instead of on every handler() call.
pipe = pipeline(
    "text-generation",
    model="rwitz/go-bruins-v2",
    device=0,
    torch_dtype=torch.bfloat16,
)
def handler(event):
    """Generate text for a single RunPod job using the module-level ``pipe``.

    Args:
        event: The RunPod job payload. ``event["input"]`` must contain a
            ``"prompt"``. It may also contain a ``"sampling_params"`` dict with
            optional ``"max_new_tokens"`` and ``"temperature"`` entries; any
            that are absent fall back to the model's own generation defaults.

    Returns:
        Whatever ``pipe(prompt, ...)`` returns (for a transformers
        text-generation pipeline, typically a list of dicts carrying a
        ``"generated_text"`` key).

    Raises:
        KeyError: If ``event["input"]`` or ``event["input"]["prompt"]`` is
            missing.
    """
    inp = event["input"]
    prompt = inp["prompt"]
    sampling_params = inp.get("sampling_params") or {}

    # Forward only the generation options the caller actually supplied, so a
    # missing key no longer raises KeyError and the model's defaults apply.
    gen_kwargs = {}

    max_new_tokens = sampling_params.get("max_new_tokens")
    if max_new_tokens is not None:
        gen_kwargs["max_new_tokens"] = max_new_tokens

    temperature = sampling_params.get("temperature")
    if temperature is not None:
        if temperature > 0:
            # transformers only applies `temperature` when sampling is on;
            # without do_sample=True it may be silently ignored (greedy).
            gen_kwargs["do_sample"] = True
            gen_kwargs["temperature"] = temperature
        else:
            # Sampling rejects temperature <= 0; treat it as a request for
            # deterministic greedy decoding instead.
            gen_kwargs["do_sample"] = False

    return pipe(prompt, **gen_kwargs)
# Register handler() with the RunPod serverless runtime. This executes at
# module level, after `pipe` has been constructed earlier in the file.
_worker_config = {"handler": handler}
runpod.serverless.start(_worker_config)