# Hugging Face Spaces demo app (the "Spaces: Sleeping" lines were UI status
# badges captured in a copy-paste, not part of the program).
import gradio as gr
import torch
from tiktoken import get_encoding
from model import GPT, GPTConfig  # project-local model definition

# GPT-2 byte-pair-encoding tokenizer from tiktoken.
tokenizer = get_encoding("gpt2")

# Load the custom model weights onto CPU and switch to inference mode.
model_path = "model.pth"  # path to the trained weights
model = GPT(GPTConfig())
model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
model.eval()  # disable dropout/batch-norm training behavior
def generate_text(prompt, max_length=50):
    """Tokenize *prompt* and generate a text continuation with the model.

    Args:
        prompt: Input text to condition generation on.
        max_length: Maximum output length passed to ``model.generate``.
            Coerced to ``int`` because Gradio sliders may deliver floats.

    Returns:
        The generated token sequence decoded back into a string.
    """
    input_ids = tokenizer.encode(prompt)
    input_tensor = torch.tensor([input_ids])  # add batch dimension

    # Inference only — no gradient tracking needed.
    with torch.no_grad():
        # NOTE(review): assumes the custom GPT exposes a HF-style
        # generate(input, max_length=...) method — confirm against model.py.
        output_ids = model.generate(input_tensor, max_length=int(max_length))

    # Decode the first (and only) batch element back to text.
    return tokenizer.decode(output_ids[0].tolist())
# Gradio interface: prompt + length controls feeding generate_text on click.
with gr.Blocks() as demo:
    gr.Markdown("# Custom Transformer Text Generation")
    gr.Markdown("Provide an input text prompt, and the model will generate text based on it.")

    # Layout reconstructed from statement order (original indentation was
    # lost in the paste): prompt and slider share a row; output and button
    # sit below.
    with gr.Row():
        input_text = gr.Textbox(label="Input Prompt", placeholder="Enter your text here...", lines=2)
        max_len = gr.Slider(label="Max Output Length", minimum=10, maximum=100, value=50, step=5)

    output_text = gr.Textbox(label="Generated Text", lines=5)
    generate_button = gr.Button("Generate")
    generate_button.click(generate_text, inputs=[input_text, max_len], outputs=output_text)
# Run the app only when executed as a script (not when imported).
if __name__ == "__main__":
    demo.launch()