Spaces:
Runtime error
import torch
from fastapi import FastAPI, HTTPException
from transformers import RobertaForMaskedLM, RobertaTokenizer

app = FastAPI()

# Pre-trained masked-language-model checkpoint and its matching tokenizer.
# Both must come from the same checkpoint so token ids line up.
MODEL_NAME = "roberta-base"
model = RobertaForMaskedLM.from_pretrained(MODEL_NAME)
tokenizer = RobertaTokenizer.from_pretrained(MODEL_NAME)
# Fine-tune the model on the local corpus "cyberpunk_lore.txt".
#
# NOTE(review): this runs at import time, so app startup blocks until one
# pass over the corpus finishes; for a large corpus, fine-tune offline and
# load the saved checkpoint instead.
LEARNING_RATE = 5e-5  # conventional fine-tuning rate for RoBERTa-sized models
MLM_PROBABILITY = 0.15  # share of tokens selected for prediction (BERT/RoBERTa recipe)

with open("cyberpunk_lore.txt", "r", encoding="utf-8") as f:
    dataset = f.read()


def _mask_tokens(input_ids, mlm_probability=MLM_PROBABILITY):
    """Return ``(masked_ids, labels)`` for one masked-LM training step.

    Each non-special token is selected with probability *mlm_probability*.
    Of the selected tokens, 80% are replaced by ``<mask>``, 10% by a random
    token and 10% are left unchanged. Unselected positions get label -100,
    which the model's loss ignores, so only selected tokens are learned.

    Args:
        input_ids: Integer tensor of token ids, shape (batch, seq_len),
            already including special tokens.
        mlm_probability: Selection probability per non-special token.

    Returns:
        Tuple of (model inputs with masking applied, loss labels).
    """
    labels = input_ids.clone()
    special = torch.tensor(
        [
            tokenizer.get_special_tokens_mask(row, already_has_special_tokens=True)
            for row in input_ids.tolist()
        ],
        dtype=torch.bool,
    )
    select_probs = torch.full(labels.shape, mlm_probability)
    select_probs.masked_fill_(special, 0.0)  # never train on <s> / </s>
    selected = torch.bernoulli(select_probs).bool()
    labels[~selected] = -100

    masked_ids = input_ids.clone()
    to_mask = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & selected
    masked_ids[to_mask] = tokenizer.mask_token_id
    # Half of the remaining 20% get a random token; the rest keep their id.
    to_randomize = (
        torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & selected & ~to_mask
    )
    random_ids = torch.randint(len(tokenizer), labels.shape, dtype=input_ids.dtype)
    masked_ids[to_randomize] = random_ids[to_randomize]
    return masked_ids, labels


# The model only accepts tokenizer.model_max_length tokens (512 for
# roberta-base) including <s> and </s>; feeding the whole corpus as a single
# sequence overflows the position embeddings. Tokenize once, then train on
# consecutive windows that fit once the special tokens are added back.
corpus_ids = tokenizer.encode(dataset, add_special_tokens=False)
window = tokenizer.model_max_length - tokenizer.num_special_tokens_to_add()

optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
model.train()
optimizer.zero_grad()
for start in range(0, len(corpus_ids), window):
    chunk = tokenizer.build_inputs_with_special_tokens(corpus_ids[start:start + window])
    input_ids = torch.tensor([chunk])
    # Labels must differ from inputs: predicting unmasked ids is a copy task.
    masked_ids, labels = _mask_tokens(input_ids)
    if not (labels != -100).any():
        # No token selected (likely for a very short final window): the mean
        # loss over zero targets is NaN and would corrupt the weights.
        continue
    outputs = model(masked_ids, labels=labels)
    loss = outputs.loss
    loss.backward()
    optimizer.step()  # apply the gradients; backward() alone changes nothing
    optimizer.zero_grad()

# Disable dropout before serving.
model.eval()
# Serve the model via FastAPI.
@app.post("/predict")
def predict(prompt: str):
    """Run the masked LM over *prompt* and return its per-token predictions.

    RoBERTa is a masked language model, not a left-to-right generator: the
    returned text is the model's most likely token at every position of the
    prompt, so ``<mask>`` tokens in the prompt are filled in and the other
    tokens are (usually) reproduced.

    Args:
        prompt: Input text; may contain ``<mask>`` tokens.

    Returns:
        ``{"generated_text": str}`` with special tokens stripped.

    Raises:
        HTTPException: 422 if the encoded prompt is longer than the model
            accepts (``tokenizer.model_max_length`` tokens).
    """
    token_ids = tokenizer.encode(prompt, add_special_tokens=True)
    if len(token_ids) > tokenizer.model_max_length:
        # Longer inputs overflow the position embeddings and crash the model.
        raise HTTPException(
            status_code=422,
            detail=(
                f"Prompt is {len(token_ids)} tokens; the model accepts at most "
                f"{tokenizer.model_max_length}."
            ),
        )
    model.eval()  # disable dropout so predictions are deterministic
    with torch.no_grad():  # inference only: don't build an autograd graph
        logits = model(torch.tensor([token_ids])).logits
    # logits is (batch=1, seq_len, vocab): pick the best token per position
    # (last axis), not the best position per token.
    predicted_ids = logits.argmax(dim=-1)[0].tolist()
    generated_text = tokenizer.decode(predicted_ids, skip_special_tokens=True)
    return {"generated_text": generated_text}