# beam-app/app.py
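"""Beam deployment of a token-streaming text-generation service.

Serves the Error410/JVCGPT-Medium model behind a FastAPI `/stream` endpoint
on a single A100 GPU, streaming generated tokens back as they are produced.
"""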
from threading import Thread
import torch
from beam import Image, Volume, GpuType, asgi
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from transformers import (
AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer,
PreTrainedTokenizerFast, PreTrainedModel, StoppingCriteriaList
)
from utils import MaxPostsStoppingCriteria, Body, fallback
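
# `utils` is not shown in this file. From how the helpers are used below, their
# shapes are roughly as follows (a hypothetical sketch, not the real
# implementations):
#
#   def fallback(value, default):
#       # use the request value unless the client omitted it
#       return default if value is None else value
#
#   class MaxPostsStoppingCriteria(StoppingCriteria):
#       # presumably stops generation once `posts_count` posts have been emitted
#       def __init__(self, tokenizer, posts_count): ...
#
#   class Body(BaseModel):  # pydantic request model consumed by FastAPI
#       prompt: str
#       posts_count: int
#       ...  # optional sampling fields mirroring DEFAULTS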
SETTINGS = {
"model_name": "Error410/JVCGPT-Medium",
"beam_volume_path": "./cached_models",
}
# @see https://huggingface.co./docs/transformers/generation_strategies#customize-text-generation
DEFAULTS = {
"max_length": 2048, # 512
"temperature": 0.9, # 1
"top_p": 1, # 0.95
"top_k": 0, # 40
"repetition_penalty": 1.0, # 1.0
"no_repeat_ngram_size": 0, # 0
"do_sample": True, # True
}
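
# These defaults are passed straight to model.generate(). Note that max_length
# counts prompt tokens plus generated tokens; max_new_tokens would bound only
# the completion.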
def load_models():
tokenizer = AutoTokenizer.from_pretrained(
SETTINGS["model_name"],
cache_dir=SETTINGS["beam_volume_path"]
)
    tokenizer.pad_token = tokenizer.eos_token  # the checkpoint defines no pad token; reuse EOS so padding works
model = AutoModelForCausalLM.from_pretrained(
SETTINGS["model_name"],
device_map="auto",
torch_dtype=torch.float16,
cache_dir=SETTINGS["beam_volume_path"],
)
return model, tokenizer
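
# model.generate() blocks until the whole completion is done, so it runs on a
# background thread while TextIteratorStreamer feeds decoded tokens to the
# generator below as they arrive.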
def stream(model: PreTrainedModel, tokenizer: PreTrainedTokenizerFast, body: Body):
generate_args = {
"max_length": fallback(body.max_length, DEFAULTS["max_length"]),
"temperature": fallback(body.temperature, DEFAULTS["temperature"]),
"top_p": fallback(body.top_p, DEFAULTS["top_p"]),
"top_k": fallback(body.top_k, DEFAULTS["top_k"]),
"repetition_penalty": fallback(body.repetition_penalty, DEFAULTS["repetition_penalty"]),
"no_repeat_ngram_size": fallback(body.no_repeat_ngram_size, DEFAULTS["no_repeat_ngram_size"]),
"do_sample": fallback(body.do_sample, DEFAULTS["do_sample"]),
"use_cache": True,
"eos_token_id": tokenizer.eos_token_id,
"pad_token_id": tokenizer.pad_token_id,
}
    inputs = tokenizer(body.prompt, return_tensors="pt", padding=True)
    # The model was loaded with device_map="auto", so follow its device rather
    # than hard-coding "cuda".
    input_ids = inputs["input_ids"].to(model.device)
    attention_mask = inputs["attention_mask"].to(model.device)
    # timeout: seconds the iterator waits for the next token before raising
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False, timeout=240)
    # No explicit torch.no_grad() needed: model.generate() already runs without gradients.
thread = Thread(
target=model.generate,
kwargs={
"input_ids": input_ids,
"attention_mask": attention_mask,
"streamer": streamer,
"stopping_criteria": StoppingCriteriaList([MaxPostsStoppingCriteria(tokenizer, body.posts_count)]),
**generate_args,
}
)
thread.start()
    # Tokens are yielded as raw text chunks, with no SSE "data:" framing despite
    # the text/event-stream media type set by the endpoint (an earlier variant
    # framed each chunk as "DATA {token}" and ended with "EOS").
    for token in streamer:
        yield token

@asgi(
name="jvcgpt",
on_start=load_models,
cpu=2.0,
memory="16Gi",
gpu=GpuType.A100_40,
gpu_count=1,
    timeout=900,  # allow time to load the model and start the server
image=Image(
python_version="python3.12",
python_packages=[
"fastapi",
"torch",
"transformers",
"accelerate",
"huggingface_hub[hf-transfer]",
],
env_vars=["HF_HUB_ENABLE_HF_TRANSFER=1"],
),
volumes=[
Volume(
name="cached_models",
mount_path=SETTINGS["beam_volume_path"],
)
],
)
def server(context):
    model, tokenizer = context.on_start_value  # (model, tokenizer) returned by load_models
app = FastAPI()
@app.post("/stream")
async def stream_endpoint(body: Body) -> StreamingResponse:
return StreamingResponse(
stream(model, tokenizer, body),
media_type='text/event-stream',
headers={"Cache-Control": "no-cache"},
)
return app
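
# Example client (hypothetical deployment URL; a sketch, not part of the app):
#
#   import requests
#
#   with requests.post(
#       "https://<your-beam-endpoint>/stream",
#       json={"prompt": "Bonjour", "posts_count": 2},
#       stream=True,
#   ) as resp:
#       for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
#           print(chunk, end="", flush=True)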