# VerySmolTextGen / app.py
# Gradio demo app for SmolLM2 text generation.
import os
import torch
import gradio as gr
from transformers import AutoTokenizer
from config import SmolLM2Config
from model import SmolLM2Lightning
def load_model(checkpoint_path):
    """Load a trained SmolLM2 model from a Lightning checkpoint.

    Builds the model config from ``config.yaml``, restores the weights,
    switches to eval mode, and moves the model to GPU when one is
    available.  Returns the model, or ``None`` when loading fails.
    """
    try:
        cfg = SmolLM2Config("config.yaml")
        loaded = SmolLM2Lightning.load_from_checkpoint(checkpoint_path, config=cfg)
        loaded.eval()
        if not torch.cuda.is_available():
            print("Model loaded on CPU")
            return loaded
        loaded = loaded.cuda()
        print(f"Model loaded on GPU: {torch.cuda.get_device_name()}")
        return loaded
    except Exception as e:
        # Best-effort loading: report the failure and signal it via None
        # so the UI can show a friendly message instead of crashing.
        print(f"Error loading model: {str(e)}")
        return None
def generate_text(prompt, max_length=100, temperature=0.7, top_p=0.9, top_k=50):
    """Generate text from *prompt* with the globally loaded model.

    Args:
        prompt: Input text to continue.
        max_length: Total sequence length budget passed to ``generate``.
        temperature: Sampling temperature (higher = more random).
        top_p: Nucleus-sampling probability mass.
        top_k: Top-k sampling cutoff.

    Returns:
        The decoded generated text, or a human-readable error string if
        the model is missing or generation fails.
    """
    try:
        if model is None:
            return "Model not loaded. Please check if checkpoint exists."
        inputs = model.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=model.config.model.max_position_embeddings,
            padding=True,
        )
        if torch.cuda.is_available():
            # Move tensors to the GPU. Note this turns the BatchEncoding
            # into a plain dict, so fields below must be accessed by key.
            inputs = {k: v.cuda() for k, v in inputs.items()}
        with torch.no_grad():
            outputs = model.generate(
                # BUG FIX: `inputs.input_ids` raised AttributeError on GPU
                # because the dict comprehension above replaces the
                # BatchEncoding with a plain dict; subscript access works
                # for both types.
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_length=max_length,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                do_sample=True,
                pad_token_id=model.tokenizer.pad_token_id,
                bos_token_id=model.tokenizer.bos_token_id,
                eos_token_id=model.tokenizer.eos_token_id,
            )
        return model.tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        # Surface errors in the UI text box instead of crashing the app.
        return f"Error generating text: {str(e)}"
# Load the model once at import time so the Gradio callback can reuse it.
print("Loading model...")
checkpoint_path = "checkpoints/smol-lm2-final.ckpt"
model = None
if os.path.exists(checkpoint_path):
    model = load_model(checkpoint_path)
else:
    # Missing checkpoint is non-fatal: the UI still starts and reports it.
    print(f"Warning: Checkpoint not found at {checkpoint_path}")
    print("Please train the model first or specify correct checkpoint path")
# Create Gradio interface: a prompt box plus sampling-parameter sliders.
_generation_controls = [
    gr.Textbox(label="Prompt", placeholder="Enter your prompt here..."),
    gr.Slider(minimum=10, maximum=200, value=100, step=1, label="Max Length"),
    gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"),
    gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.1, label="Top-p"),
    gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-k"),
]

# Ready-made example rows: (prompt, max_length, temperature, top_p, top_k).
_example_rows = [
    ["Explain what machine learning is:", 100, 0.7, 0.9, 50],
    ["Once upon a time", 150, 0.8, 0.9, 40],
    ["The best way to learn programming is", 120, 0.7, 0.9, 50],
]

demo = gr.Interface(
    fn=generate_text,
    inputs=_generation_controls,
    outputs=gr.Textbox(label="Generated Text"),
    title="SmolLM2 Text Generation",
    description="Enter a prompt and adjust generation parameters to create text with SmolLM2",
    examples=_example_rows,
)
if __name__ == "__main__":
print("Starting Gradio interface...")
# Simple launch configuration
demo.launch(
server_port=7860,
share=True
)