# mistral-new / handler.py
# (Hugging Face page header removed: "Update handler.py", commit 807f045, 813 Bytes)
import runpod
import os
import time
import torch
sleep_time = int(os.environ.get('SLEEP_TIME', 3))
# Use a pipeline as a high-level helper
from transformers import pipeline, AutoTokenizer
# Load model directly
tokenizer=AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1",skip_special_tokens=False)
pipe = pipeline("text-generation", model="mistralai/Mistral-7B-v0.1",tokenizer=tokenizer, device=0,torch_dtype=torch.bfloat16)
## load your model(s) into vram here
def handler(event):
inp = event["input"]
prompt=inp['prompt']
sampling_params=inp['sampling_params']
max=sampling_params['max_new_tokens']
temp=sampling_params['temperature']
return pipe(prompt,max_new_tokens=max,temperature=temp,pad_token_id = 50256,do_sample=True)
runpod.serverless.start({
"handler": handler
})