import os
import time

import runpod
import torch
from transformers import pipeline

# Optional artificial delay (seconds), configurable via the SLEEP_TIME
# environment variable; defined here but not currently used by the handler.
sleep_time = int(os.environ.get("SLEEP_TIME", 3))

# Load the model into VRAM once at import time, so every request reuses it.
pipe = pipeline(
    "text-generation",
    model="rwitz/go-bruins-v2",
    device=0,
    torch_dtype=torch.bfloat16,
)


def handler(event):
    """Handle one serverless request: generate text from the given prompt."""
    inp = event["input"]
    prompt = inp["prompt"]
    sampling_params = inp["sampling_params"]
    # Descriptive names; the original `max` shadowed the built-in max().
    max_new_tokens = sampling_params["max_new_tokens"]
    temperature = sampling_params["temperature"]
    # do_sample=True is required for temperature to have any effect;
    # without it, transformers ignores the temperature setting.
    return pipe(
        prompt,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=True,
    )


runpod.serverless.start({"handler": handler})
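
# A minimal local smoke test (hypothetical payload, matching the keys the
# handler reads above). One option, assuming this file is saved as handler.py
# and using the runpod SDK's local test input support:
#
#   python handler.py --test_input \
#     '{"input": {"prompt": "Hello!", "sampling_params": {"max_new_tokens": 64, "temperature": 0.7}}}'
#
# Or call the handler directly from a Python shell:
#
#   handler({"input": {"prompt": "Hello!",
#                      "sampling_params": {"max_new_tokens": 64,
#                                          "temperature": 0.7}}})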