# fastapi_t5 / main.py
# Exported from a Hugging Face Space (author: streetyogi, commit 7d22c1d, 969 bytes).
import torch
from transformers import RobertaForMaskedLM, RobertaTokenizer
from fastapi import FastAPI, HTTPException
# FastAPI application instance; the /predict route handler below registers on it.
app = FastAPI()
# Load the pre-trained model and tokenizer.
# NOTE(review): despite the file being named "fastapi_t5", this loads RoBERTa
# (a masked LM, not a seq2seq/generative model) — confirm which is intended.
model = RobertaForMaskedLM.from_pretrained('roberta-base')
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
# Load the training corpus. RoBERTa has a 512-position limit, so the encoded
# text is truncated to the model's maximum length below; encoding the whole
# file unbounded (as the original did) crashes on any non-trivial corpus.
with open("cyberpunk_lore.txt", "r", encoding="utf-8") as f:
    dataset = f.read()

# Fine-tune the model on the corpus for a single optimization step.
# NOTE(review): the original called loss.backward() with no optimizer and no
# step, so gradients were computed but the weights were never updated.
input_ids = torch.tensor(
    [tokenizer.encode(dataset, add_special_tokens=True,
                      truncation=True, max_length=512)]
)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
model.train()
model.zero_grad()
outputs = model(input_ids, labels=input_ids)
loss, logits = outputs[:2]
loss.backward()
optimizer.step()  # actually apply the gradient update
model.eval()      # disable dropout before serving inference requests
# Serve the model via FastAPI
@app.post("/predict")
def predict(prompt: str):
    """Return the model's most likely token at every position of *prompt*.

    Note: RoBERTa is a *masked* language model, not a causal generator, so
    the response echoes/denoises the prompt rather than continuing it.

    Raises:
        HTTPException: 400 if the prompt is empty.
    """
    if not prompt:
        raise HTTPException(status_code=400, detail="prompt must be non-empty")
    # Truncate to the model's 512-position limit to avoid a runtime error
    # on long prompts.
    input_ids = torch.tensor(
        [tokenizer.encode(prompt, add_special_tokens=True,
                          truncation=True, max_length=512)]
    )
    with torch.no_grad():  # inference only — no gradient tracking needed
        outputs = model(input_ids)
    # outputs[0] is logits with shape (1, seq_len, vocab_size). The argmax
    # must be over the vocabulary axis; the original used dim=1 (the
    # sequence axis), which decoded position indices instead of tokens.
    predicted_ids = outputs[0].argmax(dim=-1)
    generated_text = tokenizer.decode(predicted_ids[0].tolist())
    return {"generated_text": generated_text}