infidea commited on
Commit
92758a0
·
1 Parent(s): 3097d1f

fix generation code

Browse files
Files changed (3) hide show
  1. __pycache__/app.cpython-310.pyc +0 -0
  2. app.py +8 -8
  3. requirements.txt +2 -0
__pycache__/app.cpython-310.pyc ADDED
Binary file (1.6 kB). View file
 
app.py CHANGED
@@ -3,14 +3,13 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
3
  import torch
4
  from pydantic import BaseModel, Field
5
 
6
- class Request(BaseModel):
7
  prompt: str
8
- response: str = Field(description="")
9
- do_sample: bool=True,
10
- top_k: int =1,
11
- temperature: float=0.9,
12
- max_new_tokens: int=500,
13
- repetition_penalty: float=1.5,
14
 
15
  app = FastAPI()
16
 
@@ -19,8 +18,9 @@ def greet_json():
19
  return {"Hello": "World!"}
20
 
21
  @app.post("/generate")
22
- def generate(req: Request):
23
  model_name_or_id = "AI4Chem/ChemLLM-7B-Chat"
 
24
 
25
  model = AutoModelForCausalLM.from_pretrained(model_name_or_id,trust_remote_code=True)
26
  tokenizer = AutoTokenizer.from_pretrained(model_name_or_id,trust_remote_code=True)
 
3
  import torch
4
  from pydantic import BaseModel, Field
5
 
6
+ class RequestGenerate(BaseModel):
7
  prompt: str
8
+ do_sample: bool = Field(default=True, example=True)
9
+ top_k: int = Field(default=1, example=1),
10
+ temperature: float = Field(default=0.9, example=0.9),
11
+ max_new_tokens: int = Field(default=500, example=500),
12
+ repetition_penalty: float = Field(default=1.5, example=1.5),
 
13
 
14
  app = FastAPI()
15
 
 
18
  return {"Hello": "World!"}
19
 
20
  @app.post("/generate")
21
+ def generate(req: RequestGenerate):
22
  model_name_or_id = "AI4Chem/ChemLLM-7B-Chat"
23
+ # model_name_or_id = "AI4Chem/CHEMLLM-2b-1_5"
24
 
25
  model = AutoModelForCausalLM.from_pretrained(model_name_or_id,trust_remote_code=True)
26
  tokenizer = AutoTokenizer.from_pretrained(model_name_or_id,trust_remote_code=True)
requirements.txt CHANGED
@@ -2,3 +2,5 @@ fastapi
2
  uvicorn[standard]
3
  transformers
4
  torch
 
 
 
2
  uvicorn[standard]
3
  transformers
4
  torch
5
+ einops
6
+ sentencepiece