infidea committed
Commit 3097d1f · 1 Parent(s): ac9d7ff

add generation code

Files changed (2):
  1. app.py +34 -3
  2. requirements.txt +2 -0
app.py CHANGED
@@ -1,4 +1,16 @@
 from fastapi import FastAPI
+from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
+import torch
+from pydantic import BaseModel, Field
+
+class Request(BaseModel):
+    prompt: str
+    response: str = Field(default="", description="unused by the endpoint")
+    do_sample: bool = True
+    top_k: int = 1
+    temperature: float = 0.9
+    max_new_tokens: int = 500
+    repetition_penalty: float = 1.5
 
 app = FastAPI()
 
@@ -6,6 +18,25 @@ app = FastAPI()
 def greet_json():
     return {"Hello": "World!"}
 
-@app.post("/")
-def greet_json():
-    return {"Hello": "World!"}
+@app.post("/generate")
+def generate(req: Request):
+    model_name_or_id = "AI4Chem/ChemLLM-7B-Chat"
+
+    # note: the weights are reloaded from the Hub cache on every request
+    model = AutoModelForCausalLM.from_pretrained(model_name_or_id, trust_remote_code=True)
+    tokenizer = AutoTokenizer.from_pretrained(model_name_or_id, trust_remote_code=True)
+
+    inputs = tokenizer(req.prompt, return_tensors="pt")
+
+    generation_config = GenerationConfig(
+        do_sample=req.do_sample,
+        top_k=req.top_k,
+        temperature=req.temperature,
+        max_new_tokens=req.max_new_tokens,
+        repetition_penalty=req.repetition_penalty,
+        pad_token_id=tokenizer.eos_token_id,
+    )
+
+    outputs = model.generate(**inputs, generation_config=generation_config)
+
+    return {"text": tokenizer.decode(outputs[0], skip_special_tokens=True)}
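The new endpoint can be exercised with a plain HTTP client. Below is a minimal sketch using only the Python standard library; the URL assumes the app is served locally with `uvicorn app:app` on the default port 8000, and the prompt and sampling values are illustrative, not from the commit.

# Minimal client sketch for POST /generate.
# Assumes the server runs locally via `uvicorn app:app` on port 8000;
# adjust the URL for your deployment.
import json
import urllib.request

payload = {
    "prompt": "Give the molecular formula of caffeine.",
    "do_sample": True,
    "top_k": 50,  # the endpoint's default of 1 is effectively greedy
    "temperature": 0.9,
    "max_new_tokens": 500,
    "repetition_penalty": 1.5,
}

request = urllib.request.Request(
    "http://127.0.0.1:8000/generate",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
    method="POST",
)

with urllib.request.urlopen(request) as resp:
    print(json.load(resp)["text"])

One design note: generate() reloads AI4Chem/ChemLLM-7B-Chat on every call. Loading the model and tokenizer once at module import time, or in a FastAPI lifespan hook, would pay the multi-gigabyte load cost only once instead of per request.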
requirements.txt CHANGED
@@ -1,2 +1,4 @@
 fastapi
 uvicorn[standard]
+transformers
+torch
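With the two new dependencies added, a fresh environment can be set up with `pip install -r requirements.txt` and the app served with `uvicorn app:app`. Note that `torch` is a large download, and the 7B-parameter checkpoint pulled by the endpoint needs a correspondingly large amount of RAM or VRAM; exact requirements depend on the deployment and are not stated in this commit.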