TKgumi commited on
Commit
848832b
·
verified ·
1 Parent(s): 93cb2ca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -2
app.py CHANGED
@@ -1,7 +1,23 @@
1
  from fastapi import FastAPI
 
 
2
 
3
  app = FastAPI()
4
 
 
 
 
 
 
 
 
 
 
5
  @app.get("/generate")
6
- def generate(text: str):
7
- return {"output": text}
 
 
 
 
 
 
1
  from fastapi import FastAPI
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
 
5
  app = FastAPI()
6
 
7
+ MODEL_NAME = "sbintuitions/sarashina2.2-0.5b"
8
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
9
+ # CPUでモデルをロード(デフォルトは torch.float32)
10
+ model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
11
+
12
+ @app.get("/")
13
+ def home():
14
+ return {"message": "Sarashina Model API is running on CPU!"}
15
+
16
  @app.get("/generate")
17
+ def generate(text: str, max_length: int = 100):
18
+ inputs = tokenizer(text, return_tensors="pt")
19
+ # CPU環境ではモデルも自動的にCPUで動作します
20
+ with torch.no_grad():
21
+ output = model.generate(**inputs, max_length=max_length)
22
+ generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
23
+ return {"input": text, "output": generated_text}