import os
import signal
import threading
import time
import subprocess
import asyncio

OLLAMA = os.path.expanduser("~/ollama")
process = None
OLLAMA_SERVICE_THREAD = None

# Download the Ollama binary on first run and make it executable.
if not os.path.exists(OLLAMA):
    subprocess.run(f"curl -fL https://ollama.com/download/ollama-linux-amd64 -o {OLLAMA}", shell=True, check=True)
    os.chmod(OLLAMA, 0o755)

def ollama_service_thread():
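    """Run `ollama serve` in its own process group so it can be torn down cleanly later."""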
    global process
    process = subprocess.Popen("~/ollama serve", shell=True, preexec_fn=os.setsid)
    process.wait()
    
def terminate():
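    """Kill the Ollama server's process group and join the service thread."""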
    global process, OLLAMA_SERVICE_THREAD
    if process:
        os.killpg(os.getpgid(process.pid), signal.SIGTERM)
    if OLLAMA_SERVICE_THREAD:
        OLLAMA_SERVICE_THREAD.join()
    process = None
    OLLAMA_SERVICE_THREAD = None
    print("Ollama service stopped.")

# Optionally pre-pull a model at startup: uncomment and set the model you want.
# model = "moondream"
# model = os.environ.get("MODEL")

# subprocess.run(f"~/ollama pull {model}", shell=True)

import ollama
import gradio as gr
from ollama import AsyncClient
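# The local Ollama server listens on port 11434 by default.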
client = AsyncClient(host='http://localhost:11434', timeout=120)

HF_TOKEN = os.environ.get("HF_TOKEN", None)  # currently unused

TITLE = "<h1><center>Ollama Chat</center></h1>"

DESCRIPTION = """
<center>
<p>Feel free to test models with Ollama.
<br>
On first run, type <em>/init</em> to start the service.
<br>
Type <em>/pull model_name</em> to pull a model.
</p>
</center>
"""

INIT_SIGN = ""

def init():
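    """Start the Ollama server and flag initialization as complete."""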
    launch()
    print("Giving ollama serve a moment")
    time.sleep(10)
    global INIT_SIGN
    INIT_SIGN = "FINISHED"

def ollama_func(command):
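    """Dispatch slash commands such as /init, /pull <model>, /list, and /bye."""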
    if " " in command:
        c1, c2 = command.split(" ")
    else:
        c1 = command
        c2 = ""
    function_map = {
        "/init": init,
        "/pull": lambda: ollama.pull(c2),
        "/list": ollama.list,
        "/bye": terminate,
    }
    if c1 in function_map:
        function_map[c1]()
        return "Running..."
    else:
        return "Unsupported command."

def launch():
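    """Start `ollama serve` on a background thread without blocking the caller."""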
    global OLLAMA_SERVICE_THREAD
    OLLAMA_SERVICE_THREAD = threading.Thread(target=ollama_service_thread)
    OLLAMA_SERVICE_THREAD.start()


async def stream_chat(message: str, history: list, model: str, temperature: float, max_new_tokens: int, top_p: float, top_k: int, penalty: float):
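    """Stream a chat completion from Ollama, handling slash commands inline."""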
    print(f"message: {message}")
    conversation = []
    for prompt, answer in history:
        conversation.extend([
            {"role": "user", "content": prompt}, 
            {"role": "assistant", "content": answer},
        ])
    conversation.append({"role": "user", "content": message})
                
    print(f"Conversation is -\n{conversation}")
    
    if message.startswith("/"):
        resp = ollama_func(message)
        yield resp
    else:
        if not INIT_SIGN:
            yield "Please initialize Ollama"
        else:
            if not process:
                launch()
                print("Giving ollama serve a moment")
                # Sleep asynchronously so the event loop keeps serving other requests.
                await asyncio.sleep(10)

            buffer = ""
            async for part in await client.chat(
                model=model,
                stream=True,
                messages=conversation,
                keep_alive="60s",
                options={
                    'num_predict': max_new_tokens,
                    'temperature': temperature,
                    'top_p': top_p,
                    'top_k': top_k,
                    'repeat_penalty': penalty,
                    'low_vram': True,
                    },
                ):
                buffer += part['message']['content']
                yield buffer

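# Assemble the Gradio UI: a chat window plus sliders for the sampling parameters.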
chatbot = gr.Chatbot(height=600, placeholder=DESCRIPTION)

with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
    gr.HTML(TITLE)
    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            gr.Textbox(
                value="qwen2:0.5b",
                label="Model",
                render=False,
            ),
            gr.Slider(
                minimum=0,
                maximum=1,
                step=0.1,
                value=0.8,
                label="Temperature",
                render=False,
            ),
            gr.Slider(
                minimum=128,
                maximum=2048,
                step=1,
                value=1024,
                label="Max New Tokens",
                render=False,
            ),
            gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.1,
                value=0.8,
                label="top_p",
                render=False,
            ),
            gr.Slider(
                minimum=1,
                maximum=20,
                step=1,
                value=20,
                label="top_k",
                render=False,
            ),
            gr.Slider(
                minimum=0.0,
                maximum=2.0,
                step=0.1,
                value=1.0,
                label="Repetition penalty",
                render=False,
            ),
        ],
        examples=[
            ["Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."],
            ["What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter."],
            ["Tell me a random fun fact about the Roman Empire."],
            ["Show me a code snippet of a website's sticky header in CSS and JavaScript."],
        ],
        cache_examples=False,
    )


if __name__ == "__main__":
    demo.launch()