"""FastAPI service exposing text generation with a Hugging Face causal LM."""

import functools
import threading

from fastapi import FastAPI
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import torch
from pydantic import BaseModel, Field

# Hugging Face Hub id of the model served by POST /generate.
# A smaller alternative previously referenced here: "AI4Chem/CHEMLLM-2b-1_5".
MODEL_NAME_OR_ID = "AI4Chem/ChemLLM-7B-Chat"

# Guards the one-time model load: FastAPI runs plain ``def`` endpoints in a
# thread pool, so two concurrent first requests could otherwise both start
# loading the (large) model.
_LOAD_LOCK = threading.Lock()


class RequestGenerate(BaseModel):
    """Request body for ``POST /generate``: the prompt plus generation options."""

    prompt: str
    # NOTE: no trailing commas after ``Field(...)`` -- ``x: int = Field(...),``
    # would make the default a one-element tuple rather than a field definition.
    do_sample: bool = Field(default=True, example=True)
    top_k: int = Field(default=1, example=1)
    temperature: float = Field(default=0.9, example=0.9)
    max_new_tokens: int = Field(default=500, example=500)
    repetition_penalty: float = Field(default=1.5, example=1.5)


app = FastAPI()


@functools.lru_cache(maxsize=1)
def _load_model_and_tokenizer(model_name_or_id):
    """Load and return ``(model, tokenizer)`` for *model_name_or_id*; cached."""
    # trust_remote_code=True executes Python code shipped in the Hub
    # repository; only use it with repositories you trust.
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_id, trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_id, trust_remote_code=True
    )
    return model, tokenizer


def _get_model_and_tokenizer():
    """Return the shared model/tokenizer pair, loading it on first use only."""
    with _LOAD_LOCK:
        return _load_model_and_tokenizer(MODEL_NAME_OR_ID)


@app.get("/")
def greet_json():
    """Root endpoint returning a fixed greeting."""
    return {"Hello": "World!"}


@app.post("/generate")
def generate(req: RequestGenerate):
    """Generate text for ``req.prompt`` and return it as ``{"text": ...}``.

    The model and tokenizer are loaded once and reused across requests
    instead of being reloaded per call. The returned text is
    ``tokenizer.decode(outputs[0], skip_special_tokens=True)``; for a causal
    LM, ``generate`` output starts with the prompt tokens, so the text is
    expected to include the prompt followed by the continuation.
    """
    model, tokenizer = _get_model_and_tokenizer()
    inputs = tokenizer(req.prompt, return_tensors="pt")
    generation_config = GenerationConfig(
        do_sample=req.do_sample,
        top_k=req.top_k,
        temperature=req.temperature,
        max_new_tokens=req.max_new_tokens,
        repetition_penalty=req.repetition_penalty,
        # Reuse EOS as padding so generation does not warn about a missing
        # pad token.
        pad_token_id=tokenizer.eos_token_id,
    )
    outputs = model.generate(**inputs, generation_config=generation_config)
    return {"text": tokenizer.decode(outputs[0], skip_special_tokens=True)}