padmanabhbosamia's picture
Update app.py
b792a8c verified
import os
import gradio as gr
import torch
from train_get2_8_init import GPT, GPTConfig, generate_text, TrainingConfig
from huggingface_hub import hf_hub_download
from torch.serialization import add_safe_globals
# Register GPTConfig with torch's allow-list of globals that the restricted
# (weights_only) unpickler behind torch.load is permitted to reconstruct.
# NOTE(review): presumably the checkpoint embeds a GPTConfig instance, which
# is why this registration is needed — confirm against the training script.
add_safe_globals([GPTConfig])
def load_trained_model():
    """Build the GPT model and populate it with the published checkpoint.

    The architecture hyper-parameters and the target device are read from a
    default-constructed ``TrainingConfig``. The checkpoint file is fetched
    from the Hugging Face Hub (authenticated with the ``HF_TOKEN``
    environment variable when it is set), and its ``'model_state_dict'``
    entry is loaded into the model.

    Returns:
        The ``GPT`` instance, moved to ``TrainingConfig().device`` and
        switched to evaluation mode.
    """
    train_cfg = TrainingConfig()
    device = train_cfg.device

    # Copy only the architecture-related fields over to the model config.
    arch_fields = ("block_size", "n_layer", "n_head", "n_embd", "dropout")
    arch_kwargs = {field: getattr(train_cfg, field) for field in arch_fields}
    net = GPT(GPTConfig(**arch_kwargs))

    weights_file = hf_hub_download(
        repo_id="padmanabhbosamia/Short_Shakesphere",
        filename="best_model_compressed.pt",
        token=os.getenv("HF_TOKEN"),
    )

    # NOTE(review): torch.load unpickles a file downloaded from the network.
    # Whether the restricted weights_only unpickler is used depends on the
    # installed torch version's default; consider passing weights_only=True
    # explicitly once the checkpoint contents are confirmed to be compatible.
    saved = torch.load(weights_file, map_location=device)
    net.load_state_dict(saved["model_state_dict"])

    net.to(device)
    net.eval()
    return net
def create_gradio_interface():
    """Create the Gradio UI that serves text generation from the trained model.

    The model is loaded once, when this function is called, and is shared by
    every request through the ``predict`` closure.

    Returns:
        A ``gr.Interface`` with a prompt textbox, a maximum-length slider
        (10-500, step 10), a temperature slider (0.1-1.0, step 0.1) and a
        textbox for the generated text.
    """
    model = load_trained_model()

    def predict(prompt, max_length, temperature=0.7):
        """Generate text for ``prompt`` with the loaded model.

        Args:
            prompt: Text entered by the user.
            max_length: Value of the "Maximum Length" slider.
            temperature: Value of the "Temperature" slider.

        Returns:
            Whatever ``generate_text`` returns for these arguments.
        """
        # Gradio documents Slider values as floats; coerce so that
        # generate_text always receives an integer length.
        # NOTE(review): assumes generate_text expects an int count — confirm.
        length = int(max_length)

        # Inference only: avoid building an autograd graph for each request.
        # This is harmless if generate_text already disables gradients.
        with torch.no_grad():
            return generate_text(model, prompt, length, temperature)

    interface = gr.Interface(
        fn=predict,
        inputs=[
            gr.Textbox(
                lines=3,
                label="Enter your prompt",
                placeholder="Start typing here...",
            ),
            gr.Slider(
                minimum=10,
                maximum=500,
                value=100,
                step=10,
                label="Maximum Length",
            ),
            gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.7,
                step=0.1,
                label="Temperature (Higher = more creative)",
            ),
        ],
        outputs=gr.Textbox(lines=5, label="Generated Text"),
        title="Custom GPT Text Generator (124M) based on Shakespeare",
        description="A GPT-style language model trained on custom data by Shakespeare with 124M parameters",
    )
    return interface
# Script entry point (the original author notes this is for Hugging Face Spaces).
if __name__ == "__main__":
    # Building the interface also loads the model; launch() then starts serving.
    interface = create_gradio_interface()
    interface.launch()