# Source: Hugging Face Space (page status when captured: "Sleeping").
import functools

import torch
from fastapi import FastAPI
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
class RequestGenerate(BaseModel):
    """Request body for text generation.

    ``prompt`` is the text to continue. The remaining fields are decoding
    options that the generation handler forwards to ``GenerationConfig``.
    """

    prompt: str
    do_sample: bool = Field(default=True, example=True)
    # NOTE(review): top_k=1 restricts sampling to the single most likely token,
    # so with this default do_sample=True has little practical effect.
    # Confirm this default is intended.
    top_k: int = Field(default=1, example=1)
    temperature: float = Field(default=0.9, example=0.9)
    max_new_tokens: int = Field(default=500, example=500)
    repetition_penalty: float = Field(default=1.5, example=1.5)
app = FastAPI()


@app.get("/")
def greet_json():
    """Return a fixed JSON greeting at the root path.

    Handy as a quick check that the server is up.
    """
    return {"Hello": "World!"}
# Hugging Face model repository served by this app.
# Alternative checkpoint that was tried earlier: "AI4Chem/CHEMLLM-2b-1_5".
MODEL_NAME_OR_ID = "AI4Chem/ChemLLM-7B-Chat"


@functools.lru_cache(maxsize=1)
def _load_model_and_tokenizer(model_name_or_id=MODEL_NAME_OR_ID):
    """Load the model and tokenizer once and reuse them for later requests.

    Loading a multi-billion-parameter checkpoint is far too slow to repeat
    per request, so the pair is memoized.

    NOTE(review): lru_cache does not stop two threads from loading at the
    same time on the very first calls. FastAPI runs sync handlers in a
    threadpool, so consider warming this up at startup if that matters.

    Security: ``trust_remote_code=True`` executes Python code shipped in the
    model repository. Only use it with repositories you trust.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_id, trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_id, trust_remote_code=True
    )
    return model, tokenizer


# NOTE(review): no route decorator was visible for this handler. The path
# "/generate" is assumed here; confirm it matches what clients call.
@app.post("/generate")
def generate(req: RequestGenerate):
    """Generate text for ``req.prompt`` using the requested decoding settings.

    Args:
        req: The prompt plus decoding options, forwarded to ``GenerationConfig``.

    Returns:
        A dict ``{"text": ...}`` with the first generated sequence decoded
        and special tokens stripped. For a causal LM, ``generate`` returns the
        prompt tokens followed by the new tokens, so the text includes the
        prompt.
    """
    model, tokenizer = _load_model_and_tokenizer()
    inputs = tokenizer(req.prompt, return_tensors="pt")
    generation_config = GenerationConfig(
        do_sample=req.do_sample,
        top_k=req.top_k,
        temperature=req.temperature,
        max_new_tokens=req.max_new_tokens,
        repetition_penalty=req.repetition_penalty,
        # Use EOS as the pad token; the tokenizer may not define one.
        pad_token_id=tokenizer.eos_token_id,
    )
    outputs = model.generate(**inputs, generation_config=generation_config)
    return {"text": tokenizer.decode(outputs[0], skip_special_tokens=True)}