from fastapi import FastAPI
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline

app = FastAPI()

model_name = "Qwen/Qwen2.5-0.5B"

# 4-bit quantization requires the bitsandbytes package. Passing load_in_4bit
# directly to from_pretrained is deprecated; the supported way is to pass a
# BitsAndBytesConfig via quantization_config.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quantization_config,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Build the text-generation pipeline once at startup instead of on every
# request, so each call only runs inference.
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
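
# Optional sanity check (a quick sketch, safe to remove):
# get_memory_footprint() reports the size of the loaded weights in bytes,
# which should be roughly a quarter of the fp16 size after 4-bit quantization.
print(f"Model memory footprint: {model.get_memory_footprint() / 1e6:.1f} MB")
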
@app.get("/")
def home():
    return {"message": "Qwen2.5-0.5B API is running with 4-bit quantization"}

@app.post("/generate")
def generate_text(prompt: str, max_length: int = 50):
    result = generator(
        prompt,
        max_length=max_length,
        do_sample=True,
        pad_token_id=tokenizer.pad_token_id,
    )
    return {"generated_text": result[0]["generated_text"]}
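
# A quick way to exercise the API (assuming this file is saved as app.py and
# uvicorn is installed):
#
#   uvicorn app:app --host 0.0.0.0 --port 8000
#   curl -X POST "http://localhost:8000/generate?prompt=Hello&max_length=60"
#
# FastAPI exposes the scalar arguments of generate_text as query parameters,
# so prompt and max_length are passed in the URL.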