import os
import torch
import gradio as gr
from config import SmolLM2Config
from model import SmolLM2Lightning

def load_model(checkpoint_path):
    """Load the trained model from checkpoint"""
    try:
        config = SmolLM2Config("config.yaml")
        model = SmolLM2Lightning.load_from_checkpoint(checkpoint_path, config=config)
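        # eval() disables dropout and other train-time behaviour for inference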
        model.eval()
        
        if torch.cuda.is_available():
            model = model.cuda()
            print(f"Model loaded on GPU: {torch.cuda.get_device_name()}")
        else:
            print("Model loaded on CPU")
            
        return model
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        return None

def generate_text(prompt, max_length=100, temperature=0.7, top_p=0.9, top_k=50):
    """Generate text from prompt"""
    try:
        if model is None:
            return "Model not loaded. Please check if checkpoint exists."
            
        # Batch of one prompt: padding is a no-op here, and requesting it
        # would error for tokenizers that define no pad token
        inputs = model.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=model.config.model.max_position_embeddings
        )
        
        if torch.cuda.is_available():
            # .to() moves the whole BatchEncoding and preserves attribute
            # access (inputs.input_ids), unlike a plain dict comprehension
            inputs = inputs.to("cuda")
        
        with torch.no_grad():
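            # Note: max_length bounds prompt + generated tokens together;
            # pass max_new_tokens instead to bound only the continuation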
            outputs = model.generate(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_length=max_length,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                do_sample=True,
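                # If the tokenizer defines no pad token (common for causal LMs),
                # a usual fallback is pad_token_id=model.tokenizer.eos_token_id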
                pad_token_id=model.tokenizer.pad_token_id,
                bos_token_id=model.tokenizer.bos_token_id,
                eos_token_id=model.tokenizer.eos_token_id
            )
        
        return model.tokenizer.decode(outputs[0], skip_special_tokens=True)
        
    except Exception as e:
        return f"Error generating text: {str(e)}"

# Load the model
print("Loading model...")
checkpoint_path = "checkpoints/smol-lm2-final.ckpt"
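# The path could also be made overridable, e.g.:
#   checkpoint_path = os.environ.get("CHECKPOINT_PATH", checkpoint_path)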
if not os.path.exists(checkpoint_path):
    print(f"Warning: Checkpoint not found at {checkpoint_path}")
    print("Please train the model first or specify correct checkpoint path")
    model = None
else:
    model = load_model(checkpoint_path)

# Create Gradio interface
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Prompt", placeholder="Enter your prompt here..."),
        gr.Slider(minimum=10, maximum=200, value=100, step=1, label="Max Length"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.1, label="Top-p"),
        gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-k")
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="SmolLM2 Text Generation",
    description="Enter a prompt and adjust generation parameters to create text with SmolLM2",
    examples=[
        ["Explain what machine learning is:", 100, 0.7, 0.9, 50],
        ["Once upon a time", 150, 0.8, 0.9, 40],
        ["The best way to learn programming is", 120, 0.7, 0.9, 50]
    ]
)

if __name__ == "__main__":
    print("Starting Gradio interface...")
    # Simple launch configuration
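    # share=True requests a temporary public *.gradio.live URL;
    # set share=False to keep the demo local-only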
    demo.launch(
        server_port=7860,
        share=True
    )