from threading import Thread

import torch
from beam import Image, Volume, GpuType, asgi
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
    PreTrainedTokenizerFast,
    PreTrainedModel,
    StoppingCriteriaList,
)

from utils import MaxPostsStoppingCriteria, Body, fallback

SETTINGS = {
    "model_name": "Error410/JVCGPT-Medium",
    "beam_volume_path": "./cached_models",
}

# @see https://huggingface.co./docs/transformers/generation_strategies#customize-text-generation
DEFAULTS = {
    "max_length": 2048,  # 512
    "temperature": 0.9,  # 1
    "top_p": 1,  # 0.95
    "top_k": 0,  # 40
    "repetition_penalty": 1.0,  # 1.0
    "no_repeat_ngram_size": 0,  # 0
    "do_sample": True,  # True
}


def load_models():
    tokenizer = AutoTokenizer.from_pretrained(
        SETTINGS["model_name"],
        cache_dir=SETTINGS["beam_volume_path"],
    )
    # The model has no dedicated pad token; reuse EOS for padding.
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        SETTINGS["model_name"],
        device_map="auto",
        torch_dtype=torch.float16,
        cache_dir=SETTINGS["beam_volume_path"],
    )
    return model, tokenizer


def stream(model: PreTrainedModel, tokenizer: PreTrainedTokenizerFast, body: Body):
    # Per-request parameters, falling back to DEFAULTS when a field is unset.
    generate_args = {
        "max_length": fallback(body.max_length, DEFAULTS["max_length"]),
        "temperature": fallback(body.temperature, DEFAULTS["temperature"]),
        "top_p": fallback(body.top_p, DEFAULTS["top_p"]),
        "top_k": fallback(body.top_k, DEFAULTS["top_k"]),
        "repetition_penalty": fallback(body.repetition_penalty, DEFAULTS["repetition_penalty"]),
        "no_repeat_ngram_size": fallback(body.no_repeat_ngram_size, DEFAULTS["no_repeat_ngram_size"]),
        "do_sample": fallback(body.do_sample, DEFAULTS["do_sample"]),
        "use_cache": True,
        "eos_token_id": tokenizer.eos_token_id,
        "pad_token_id": tokenizer.pad_token_id,
    }

    inputs = tokenizer(body.prompt, return_tensors="pt", padding=True)
    input_ids = inputs["input_ids"].to("cuda")
    attention_mask = inputs["attention_mask"].to("cuda")

    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=False, timeout=240
    )

    # Run generation in a background thread; the streamer yields decoded
    # chunks as they are produced. No torch.no_grad() needed here:
    # model.generate() already disables gradients internally.
    thread = Thread(
        target=model.generate,
        kwargs={
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "streamer": streamer,
            "stopping_criteria": StoppingCriteriaList(
                [MaxPostsStoppingCriteria(tokenizer, body.posts_count)]
            ),
            **generate_args,
        },
    )
    thread.start()

    for token in streamer:
        yield token
        # Alternative explicit framing, currently unused:
        # if len(token) > 0:
        #     yield f"DATA {token}"
    # yield "EOS"


@asgi(
    name="jvcgpt",
    on_start=load_models,
    cpu=2.0,
    memory="16Gi",
    gpu=GpuType.A100_40,
    gpu_count=1,
    timeout=5 * 60,  # Time to load the model and start the server
    keep_warm_seconds=5 * 60,
    image=Image(
        python_version="python3.12",
        python_packages=[
            "fastapi",
            "torch",
            "transformers",
            "accelerate",
            "huggingface_hub[hf-transfer]",
        ],
        env_vars=["HF_HUB_ENABLE_HF_TRANSFER=1"],
    ),
    volumes=[
        Volume(
            name="cached_models",
            mount_path=SETTINGS["beam_volume_path"],
        )
    ],
)
def server(context):
    model, tokenizer = context.on_start_value
    app = FastAPI()

    @app.post("/stream")
    async def stream_endpoint(body: Body) -> StreamingResponse:
        return StreamingResponse(
            stream(model, tokenizer, body),
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache"},
        )

    return app
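
# ---------------------------------------------------------------------------
# Example request (a hypothetical sketch, not part of the deployment). The
# URL is assigned by Beam on deploy, and the JSON fields are inferred from
# how `Body` is used in stream() above; the real schema lives in utils.py.
#
#   curl -N -X POST "https://<deployment-url>/stream" \
#     -H "Content-Type: application/json" \
#     -d '{"prompt": "<thread so far>", "posts_count": 1, "temperature": 0.9}'
#
# The endpoint streams decoded tokens back as they are generated (media type
# text/event-stream), stopping once MaxPostsStoppingCriteria has observed
# `posts_count` posts in the output.
# ---------------------------------------------------------------------------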