import gc
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login
import os

# Log in to the Hugging Face Hub if a token is provided
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)

# Define models
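# Each entry maps a variant key to its display name, the available sizes
# (Hugging Face repo IDs), and the environment variable holding its system prompt.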
MODELS = {
    "atlas-flash-1215": {
        "name": "🦁 Atlas-Flash 1215",
        "sizes": {
            "1.5B": "Spestly/Atlas-Flash-1.5B-Preview",
        },
        "emoji": "🦁",
        "experimental": True,
        "is_vision": False,
        "system_prompt_env": "ATLAS_FLASH_1215",
    },
    "atlas-pro-0403": {
        "name": "πŸ† Atlas-Pro 0403",
        "sizes": {
            "1.5B": "Spestly/Atlas-Pro-1.5B-Preview",
        },
        "emoji": "πŸ†",
        "experimental": True,
        "is_vision": False,
        "system_prompt_env": "ATLAS_PRO_0403",
    },
}

# Clear memory
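# Free GPU cache and collect Python garbage between model loads and generations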
def clear_memory():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

# Load model
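# Load the selected checkpoint into the module-level `current_model` dict,
# releasing any previously loaded weights first.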
def load_model(model_key, model_size):
    try:
        clear_memory()
        
        # Release the previously loaded model, if any, before loading a new one
        global current_model
        if current_model.get("model") is not None:
            current_model["model"] = None
            current_model["tokenizer"] = None
            clear_memory()

        model_path = MODELS[model_key]["sizes"][model_size]
        
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        # Some causal-LM tokenizers ship without a pad token; fall back to EOS
        # so padding during generation works
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map="auto",  # place weights on GPU automatically when available
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            trust_remote_code=True,
            low_cpu_mem_usage=True
        )

        current_model.update({
            "tokenizer": tokenizer,
            "model": model,
            "config": {
                "name": f"{MODELS[model_key]['name']} {model_size}",
                "path": model_path,
                "system_prompt": os.getenv(MODELS[model_key]["system_prompt_env"], "Default system prompt"),
            }
        })
        return f"βœ… {MODELS[model_key]['name']} {model_size} loaded successfully!"
    except Exception as e:
        return f"❌ Error: {str(e)}"

# Respond to input
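# Build an instruction-style prompt, generate with the loaded model, and return
# only the newly generated text.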
def respond(prompt, max_tokens, temperature, top_p, top_k):
    if not current_model["model"] or not current_model["tokenizer"]:
        return "⚠️ Please select and load a model first"

    try:
        system_prompt = current_model["config"]["system_prompt"]
        if not system_prompt:
            return "⚠️ System prompt not found for the selected model."

        full_prompt = f"{system_prompt}\n\n### Instruction:\n{prompt}\n\n### Response:"

        inputs = current_model["tokenizer"](
            full_prompt,
            return_tensors="pt",
            max_length=512,
            truncation=True,
            padding=True
        ).to(current_model["model"].device)  # keep inputs on the same device as the model
        with torch.no_grad():
            output = current_model["model"].generate(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                do_sample=True,
                pad_token_id=current_model["tokenizer"].pad_token_id,
                eos_token_id=current_model["tokenizer"].eos_token_id,
            )
            # Decode only the newly generated tokens so the prompt isn't echoed back
            response = current_model["tokenizer"].decode(
                output[0][inputs.input_ids.shape[1]:],
                skip_special_tokens=True
            ).strip()

        return response
    except Exception as e:
        return f"⚠️ Generation Error: {str(e)}"
    finally:
        clear_memory()

# Module-level holder for the currently loaded tokenizer, model, and config
current_model = {"tokenizer": None, "model": None, "config": None}

# UI for Gradio
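# Build the Gradio Blocks interface: model selection, load button, and generation controls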
def gradio_ui():
    def load_and_set_model(model_key, model_size):
        return load_model(model_key, model_size)
    
    with gr.Blocks() as app:
        gr.Markdown("## 🦁 Atlas Inference Platform - Experimental πŸ§ͺ")
        
        with gr.Row():
            model_key_dropdown = gr.Dropdown(
                choices=list(MODELS.keys()),
                value=list(MODELS.keys())[0],
                label="Select Model Variant",
                interactive=True
            )
            model_size_dropdown = gr.Dropdown(
                choices=list(MODELS[list(MODELS.keys())[0]]["sizes"].keys()),
                value="1.5B",
                label="Select Model Size",
                interactive=True
            )
            load_button = gr.Button("Load Model")

        load_status = gr.Textbox(label="Model Load Status", interactive=False)

        load_button.click(
            load_and_set_model,
            inputs=[model_key_dropdown, model_size_dropdown],
            outputs=load_status,
        )

        with gr.Row():
            prompt_input = gr.Textbox(label="Input Prompt", lines=4)
            max_tokens_slider = gr.Slider(10, 512, value=256, step=10, label="Max Tokens")
            temperature_slider = gr.Slider(0.1, 2.0, value=0.4, step=0.1, label="Temperature")
            top_p_slider = gr.Slider(0.1, 1.0, value=0.9, step=0.1, label="Top-P")
            top_k_slider = gr.Slider(1, 100, value=50, step=1, label="Top-K")
        
        generate_button = gr.Button("Generate Response")
        response_output = gr.Textbox(label="Model Response", lines=6, interactive=False)

        generate_button.click(
            respond,
            inputs=[prompt_input, max_tokens_slider, temperature_slider, top_p_slider, top_k_slider],
            outputs=response_output,
        )

    return app

if __name__ == "__main__":
    gradio_ui().launch()