import torch
from transformers import RobertaForMaskedLM, RobertaTokenizer
from fastapi import FastAPI, HTTPException

app = FastAPI()

# Load the pre-trained model and tokenizer.
model = RobertaForMaskedLM.from_pretrained('roberta-base')
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

# RoBERTa's learned positional embeddings cap the input length; anything longer
# must be truncated or the forward pass raises an index error.
MAX_LEN = tokenizer.model_max_length  # 512 for roberta-base

# Load the fine-tuning corpus ("cyberpunk_lore.txt").
with open("cyberpunk_lore.txt", "r", encoding="utf-8") as f:
    dataset = f.read()

# --- One fine-tuning step on the corpus -------------------------------------
# BUG FIX: the original called loss.backward() but never created an optimizer
# or called optimizer.step(), so the computed gradients were thrown away and
# the model weights were never updated. It also encoded the entire file in a
# single call with no truncation, which crashes once the corpus exceeds the
# model's 512-token position limit.
input_ids = tokenizer.encode(
    dataset,
    add_special_tokens=True,
    truncation=True,
    max_length=MAX_LEN,
    return_tensors="pt",
)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
model.train()
optimizer.zero_grad()
outputs = model(input_ids, labels=input_ids)
loss = outputs[0]
loss.backward()
optimizer.step()  # actually apply the gradient update

# Switch to inference mode for serving.
model.eval()


@app.post("/predict")
def predict(prompt: str):
    """Return the model's most-likely token at every position of *prompt*.

    NOTE(review): RobertaForMaskedLM is a masked language model, not an
    autoregressive generator — this endpoint fills in the argmax token per
    input position rather than continuing the prompt. Confirm this is the
    intended behavior; for open-ended generation a causal LM (e.g. GPT-2)
    would be the usual choice.
    """
    if not prompt:
        raise HTTPException(status_code=400, detail="prompt must be non-empty")
    ids = tokenizer.encode(
        prompt,
        add_special_tokens=True,
        truncation=True,
        max_length=MAX_LEN,
        return_tensors="pt",
    )
    # No gradients needed at inference time — saves memory and time.
    with torch.no_grad():
        logits = model(ids)[0]  # shape: (1, seq_len, vocab_size)
    # BUG FIX: the original used argmax(dim=1), which reduces over the
    # *sequence* axis and yields one id per vocabulary entry (garbage when
    # decoded). The vocabulary axis is the last one.
    predicted_ids = logits.argmax(dim=-1)[0].tolist()
    generated_text = tokenizer.decode(predicted_ids, skip_special_tokens=True)
    return {"generated_text": generated_text}