gumi / app.py
TKgumi's picture
Create app.py
02955e7 verified
raw
history blame
854 Bytes
import torch
from fastapi import FastAPI
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)
# Create the FastAPI application.
app = FastAPI()

# Load the model with 4-bit quantization. The request is expressed as a
# BitsAndBytesConfig passed through `quantization_config`, because handing
# `load_in_4bit=True` directly to `from_pretrained` is deprecated in
# transformers.
model_name = "Qwen/Qwen2.5-0.5B"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# API endpoints
@app.get("/")
def home():
    """Return a fixed status message showing that the service is up."""
    status_message = "Qwen2.5-0.5B API is running with 4-bit quantization"
    return {"message": status_message}
# Text-generation pipeline, built on first use and then shared by all requests
# so that it is not reconstructed for every call.
_text_generator = None


def _get_text_generator():
    """Return the shared text-generation pipeline, creating it on first call."""
    global _text_generator
    if _text_generator is None:
        _text_generator = pipeline(
            "text-generation", model=model, tokenizer=tokenizer
        )
    return _text_generator


@app.post("/generate")
def generate_text(prompt: str, max_length: int = 50):
    """Generate a sampled continuation of ``prompt``.

    Args:
        prompt: Text handed to the text-generation pipeline.
        max_length: Forwarded unchanged as the pipeline's ``max_length``
            argument (defaults to 50).

    Returns:
        A dict whose ``"generated_text"`` entry is the ``generated_text``
        field of the first pipeline result.
    """
    # Fall back to the EOS id when the tokenizer defines no pad token, so
    # generation always receives a concrete padding id.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    outputs = _get_text_generator()(
        prompt,
        max_length=max_length,
        do_sample=True,
        pad_token_id=pad_token_id,
    )
    return {"generated_text": outputs[0]["generated_text"]}