File size: 813 Bytes
a5cd593
 
 
c81cf4c
a5cd593
 
f2c19e5
de1c5bf
f2c19e5
 
a5cd593
 
 
 
 
ccf9578
a5cd593
 
807f045
a5cd593
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import runpod
import os
import time  # NOTE(review): `time` is imported but not used in the visible code
import torch
# Worker sleep interval from the environment (default 3s).
# NOTE(review): `sleep_time` is never read in this file — possibly dead, or
# used by code outside this view; confirm before removing.
sleep_time = int(os.environ.get('SLEEP_TIME', 3))
# Use a pipeline as a high-level helper
from transformers import pipeline, AutoTokenizer
# Load model directly
# NOTE(review): `skip_special_tokens` is a decode()-time kwarg, not an
# AutoTokenizer.from_pretrained init parameter — it is likely ignored here;
# verify against the transformers tokenizer API.
tokenizer=AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1",skip_special_tokens=False)
# Build the text-generation pipeline on GPU 0 in bfloat16.
# This downloads/loads ~14 GB of weights at import time; device=0 requires
# a CUDA-capable GPU to be present.
pipe = pipeline("text-generation", model="mistralai/Mistral-7B-v0.1",tokenizer=tokenizer, device=0,torch_dtype=torch.bfloat16)
## load your model(s) into vram here

def handler(event):
    """RunPod serverless handler: run text generation on the Mistral pipeline.

    Expects ``event["input"]`` to contain:
        prompt (str): the generation prompt.
        sampling_params (dict): must hold ``max_new_tokens`` (int) and
            ``temperature`` (float).

    Returns the raw pipeline output (a list of ``{'generated_text': ...}``
    dicts). Raises ``KeyError`` if any expected key is missing, matching the
    original behavior.
    """
    inp = event["input"]
    prompt = inp['prompt']
    sampling_params = inp['sampling_params']
    # Renamed from `max`, which shadowed the builtin.
    max_new_tokens = sampling_params['max_new_tokens']
    temperature = sampling_params['temperature']
    # BUG FIX: pad_token_id was hard-coded to 50256 (GPT-2's EOS id), which
    # is out of range for Mistral-7B's 32000-token vocabulary. Use the loaded
    # tokenizer's own EOS id as the pad id instead.
    return pipe(
        prompt,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
    )

# Register `handler` with the RunPod serverless runtime. This call hands
# control to RunPod's event loop, which dispatches incoming job events to
# the handler; it does not return under normal operation.
runpod.serverless.start({
    "handler": handler
})