from fastapi import FastAPI
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

app = FastAPI()

MODEL_NAME = "sbintuitions/sarashina2.2-0.5b"

# Load the tokenizer and model once at startup; without explicit device
# placement, everything runs on the CPU.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
|
|
@app.get("/")
def home():
    return {"message": "Sarashina Model API is running on CPU!"}
|
|
@app.get("/generate")
def generate(text: str, max_length: int = 100):
    inputs = tokenizer(text, return_tensors="pt")

    # Run generation without gradient tracking; note that max_length
    # counts the prompt tokens as well as the newly generated ones.
    with torch.no_grad():
        output = model.generate(**inputs, max_length=max_length)

    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    return {"input": text, "output": generated_text}
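
# --- Running and querying the API (a minimal sketch) ---
# Assuming this file is saved as main.py (the filename is an assumption,
# not stated above) and uvicorn is installed (pip install "uvicorn[standard]"),
# the server can be started with:
#
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
# The /generate endpoint reads its arguments from query parameters, so a
# request can look like:
#
#   curl "http://localhost:8000/generate?text=Hello&max_length=50"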