v1
app.py CHANGED
@@ -24,23 +24,21 @@ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENT
 # accel
 accel = Accelerator()
 
-# model selection
-link = "TroL-7B" # [Select One] 'TroL-1.8B' | 'TroL-3.8B' | 'TroL-7B'
-
 # User prompt
 prompt_type="with_image" # Select one option "text_only", "with_image"
 img_path='figures/demo.png'
 question="What is the troll doing? Provide the detail in the image and imagine what the event happens."
 
 # loading model
-model, tokenizer = load_trol(link=link)
+model_1_8, tokenizer_1_8 = load_trol(link='TroL-1.8B')
+
+# loading model
+model_3_8, tokenizer_3_8 = load_trol(link='TroL-3.8B')
 
-# cpu -> gpu
-for param in model.parameters():
-    if not param.is_cuda:
-        param.data = param.to('cuda:0')
+# loading model
+model_7, tokenizer_7 = load_trol(link='TroL-7B')
 
-def threading_function(inputs, image_token_number, streamer, device, temperature, new_max_token, top_p):
+def threading_function(inputs, image_token_number, streamer, device, model, tokenizer, temperature, new_max_token, top_p):
 
     # propagation
     _inputs = model.eval_process(inputs=inputs,
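This hunk stops hard-coding a single checkpoint (`link = "TroL-7B"`) and instead pre-loads all three TroL sizes at import time, so the request handler can switch between them per call. A minimal sketch of the same idea with one lookup table instead of three parallel variables, assuming only what the diff shows (that `load_trol` returns a `(model, tokenizer)` pair):

    # Sketch, not the committed code: keep every checkpoint in one dict,
    # keyed by the size string that bot_streaming later receives as `link`.
    MODELS = {
        size: load_trol(link=f'TroL-{size}')  # assumed (model, tokenizer) pair
        for size in ('1.8B', '3.8B', '7B')
    }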
@@ -60,6 +58,16 @@ def threading_function(inputs, image_token_number, streamer, device, temperature
 @spaces.GPU
 def bot_streaming(message, history, link, temperature, new_max_token, top_p):
 
+    if "1.8B" in link:
+        model = model_1_8
+        tokenizer = tokenizer_1_8
+    elif "3.8B" in link:
+        model = model_3_8
+        tokenizer = tokenizer_3_8
+    elif "7B" in link:
+        model = model_7
+        tokenizer = tokenizer_7
+
     try:
         # prompt type -> input prompt
         image_token_number = None
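With a table like the sketch above, the added `if`/`elif` chain collapses to one lookup, and a missing or unknown selection fails with a clear message instead of leaving `model` unbound. A hypothetical variant of the function's opening lines:

    def bot_streaming(message, history, link, temperature, new_max_token, top_p):
        # `link` carries the Radio value ("1.8B", "3.8B", or "7B").
        try:
            model, tokenizer = MODELS[link]
        except KeyError:
            raise gr.Error("Please select a model size first.")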
@@ -83,6 +91,8 @@ def bot_streaming(message, history, link, temperature, new_max_token, top_p):
         thread = Thread(target=threading_function, kwargs=dict(inputs=inputs,
                                                                image_token_number=image_token_number,
                                                                streamer=streamer,
+                                                               model=model,
+                                                               tokenizer=tokenizer,
                                                                device=accel.device,
                                                                temperature=temperature,
                                                                new_max_token=new_max_token,
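The selected `model` and `tokenizer` now travel into the worker thread as keyword arguments; the thread writes generated text into `streamer` while the Gradio callback drains it. The consumption side implied by the `yield buffer` context line in the final hunk typically looks like this sketch, assuming `streamer` is a `transformers.TextIteratorStreamer` (iterable, yielding decoded text chunks):

    thread.start()
    buffer = ""
    for new_text in streamer:  # blocks until the worker pushes the next chunk
        buffer += new_text
        yield buffer           # ChatInterface re-renders the partial answer
    thread.join()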
@@ -115,7 +125,7 @@ def bot_streaming(message, history, link, temperature, new_max_token, top_p):
             yield buffer
 
 demo = gr.ChatInterface(fn=bot_streaming,
-                        additional_inputs = [gr.Slider(0, 1, 0.9, label="temperature"), gr.Slider(1, 1024, 128, label="new_max_token"), gr.Slider(0, 1, 0.95, label="top_p")],
+                        additional_inputs = [gr.Radio(["1.8B", "3.8B", "7B"], label="Size", info="Select one model size"), gr.Slider(0, 1, 0.9, label="temperature"), gr.Slider(1, 1024, 128, label="new_max_token"), gr.Slider(0, 1, 0.95, label="top_p")],
                         additional_inputs_accordion="Generation Hyperparameters",
                         theme=gr.themes.Soft(),
                         title="TroL",
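`gr.ChatInterface` forwards `additional_inputs` values to `fn` positionally after `(message, history)`, so the new `gr.Radio` must sit first in the list to land in the `link` parameter, with the three sliders then mapping to `temperature`, `new_max_token`, and `top_p` in order. A self-contained sketch of just that wiring (`echo_args` is a stand-in for `bot_streaming`):

    import gradio as gr

    def echo_args(message, history, link, temperature, new_max_token, top_p):
        # Order of additional_inputs == order of parameters after (message, history).
        return f"link={link}, temperature={temperature}, new_max_token={new_max_token}, top_p={top_p}"

    demo = gr.ChatInterface(fn=echo_args,
                            additional_inputs=[gr.Radio(["1.8B", "3.8B", "7B"], label="Size"),
                                               gr.Slider(0, 1, 0.9, label="temperature"),
                                               gr.Slider(1, 1024, 128, label="new_max_token"),
                                               gr.Slider(0, 1, 0.95, label="top_p")])
    demo.launch()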
|