import os

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from torch.serialization import add_safe_globals

from train_get2_8_init import GPT, GPTConfig, generate_text, TrainingConfig
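
# The checkpoint presumably pickles a GPTConfig instance, so allow-list it for
# torch.load's weights_only unpickling (the default behaviour in PyTorch >= 2.6).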
add_safe_globals([GPTConfig])


def load_trained_model():
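    """Rebuild the GPT architecture from the training config and load the trained weights."""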
    config = TrainingConfig()
    model_config = GPTConfig(
        block_size=config.block_size,
        n_layer=config.n_layer,
        n_head=config.n_head,
        n_embd=config.n_embd,
        dropout=config.dropout
    )

    model = GPT(model_config)
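
    # Fetch the checkpoint from the Hugging Face Hub; the HF_TOKEN environment
    # variable is passed in case the repository requires authentication.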
    model_path = hf_hub_download(
        repo_id="padmanabhbosamia/Short_Shakesphere",
        filename="best_model_compressed.pt",
        token=os.getenv('HF_TOKEN')
    )
    checkpoint = torch.load(model_path, map_location=config.device)
    model.load_state_dict(checkpoint['model_state_dict'])
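
    # Move the model to the target device and switch to inference mode (disables dropout).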
    model.to(config.device)
    model.eval()
    return model


def create_gradio_interface():
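    """Load the model once and expose generate_text through a simple Gradio UI."""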
    model = load_trained_model()

    def predict(prompt, max_length, temperature=0.7):
        return generate_text(model, prompt, max_length, temperature)
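
    # The three input components below are passed to predict() in order:
    # prompt text, maximum length, and sampling temperature.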
    interface = gr.Interface(
        fn=predict,
        inputs=[
            gr.Textbox(
                lines=3,
                label="Enter your prompt",
                placeholder="Start typing here..."
            ),
            gr.Slider(
                minimum=10,
                maximum=500,
                value=100,
                step=10,
                label="Maximum Length"
            ),
            gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.7,
                step=0.1,
                label="Temperature (Higher = more creative)"
            )
        ],
        outputs=gr.Textbox(lines=5, label="Generated Text"),
        title="Custom GPT Text Generator (124M) Trained on Shakespeare",
        description="A GPT-style language model with 124M parameters, trained on the works of Shakespeare."
    )
    return interface


if __name__ == "__main__":
    interface = create_gradio_interface()
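    # launch() serves the app locally; pass share=True if a temporary public link is needed.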
    interface.launch()