v1
app.py CHANGED
@@ -38,6 +38,8 @@ model_3_8, tokenizer_3_8 = load_trol(link='TroL-3.8B')
 # loading model
 model_7, tokenizer_7 = load_trol(link='TroL-7B')
 
+print()
+
 def threading_function(inputs, image_token_number, streamer, device, model, tokenizer, temperature, new_max_token, top_p):
 
     # propagation
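For context on how `threading_function` and its `streamer` argument usually fit together: generation blocks, so it runs in a background thread while the Gradio handler drains the streamer and yields partial text. The body of `threading_function` is not shown in this diff, so the following is only a minimal sketch of the standard Hugging Face pattern using `TextIteratorStreamer` from `transformers`; names like `stream_reply` are illustrative, not TroL's actual code.

# Minimal sketch of the thread + streamer pattern the code above relies on.
# Assumes a plain transformers model/tokenizer; TroL's real inputs differ.
from threading import Thread
from transformers import TextIteratorStreamer

def stream_reply(prompt, model, tokenizer, new_max_token, temperature, top_p):
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True,
                                    skip_special_tokens=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # generate() blocks, so it runs in a worker thread while we consume tokens
    Thread(target=model.generate,
           kwargs=dict(**inputs, streamer=streamer, do_sample=True,
                       max_new_tokens=new_max_token,
                       temperature=temperature, top_p=top_p)).start()
    text = ""
    for chunk in streamer:  # yields decoded text as soon as tokens arrive
        text += chunk
        yield text          # gr.ChatInterface re-renders the partial answer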
@@ -58,6 +60,7 @@ def threading_function(inputs, image_token_number, streamer, device, model, tokenizer, temperature, new_max_token, top_p):
 @spaces.GPU
 def bot_streaming(message, history, link, temperature, new_max_token, top_p):
 
+    # model selection
     if "1.8B" in link:
         model = model_1_8
         tokenizer = tokenizer_1_8
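The `@spaces.GPU` decorator is what makes the CPU-to-GPU shuffle in the next hunk necessary: on a ZeroGPU Space, a GPU is attached only for the duration of a decorated call, so nothing can live on the GPU at import time. A minimal sketch of the pattern (the `duration` argument is optional and shown only as an example):

# ZeroGPU pattern: the GPU exists only while a @spaces.GPU function runs.
import spaces

@spaces.GPU  # or @spaces.GPU(duration=120) to request a longer slot
def bot_streaming(message, history, link, temperature, new_max_token, top_p):
    ...  # move weights to the GPU here, then generate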
@@ -67,14 +70,19 @@ def bot_streaming(message, history, link, temperature, new_max_token, top_p):
     elif "7B" in link:
         model = model_7
         tokenizer = tokenizer_7
-
+
+    # cpu -> gpu
+    for param in model.parameters():
+        if not param.is_cuda:
+            param.data = param.to(accel.device)
+
     try:
         # prompt type -> input prompt
         image_token_number = None
         if len(message['files']) != 0:
             # Image Load
             image = pil_to_tensor(Image.open(message['files'][0]).convert("RGB"))
-            if
+            if "3.8B" not in link:
                 image_token_number = 1225
                 image = F.interpolate(image.unsqueeze(0), size=(490, 490), mode='bicubic').squeeze(0)
             inputs = [{'image': image, 'question': message['text']}]
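The added `# cpu -> gpu` loop is the lazy-migration half of the ZeroGPU pattern: weights stay on CPU until the first request, then every parameter still off-device is moved to the just-scheduled GPU. A self-contained sketch, assuming `accel` is an `accelerate.Accelerator` created at module scope (which the diff's use of `accel.device` suggests):

# Sketch of the lazy CPU -> GPU migration added in this commit.
import torch
from accelerate import Accelerator

accel = Accelerator()  # assumption: app.py defines something equivalent

def ensure_on_device(model: torch.nn.Module) -> None:
    # Only touch parameters still on CPU; repeated calls are cheap no-ops.
    for param in model.parameters():
        if not param.is_cuda:
            param.data = param.to(accel.device)

As for the hard-coded `image_token_number = 1225`: if the vision encoder patchifies with 14-pixel patches (typical for CLIP-style backbones, though that detail is an assumption here), resizing to 490×490 gives (490 / 14)² = 35² = 1225 image tokens. The 3.8B branch is excluded, presumably because that variant preprocesses images differently.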
@@ -129,7 +137,8 @@ demo = gr.ChatInterface(fn=bot_streaming,
                         additional_inputs_accordion="Generation Hyperparameters",
                         theme=gr.themes.Soft(),
                         title="TroL",
-                        description="TroL is efficient 1.8B, 3.8B, and 7B size Large Language and Vision Models built on new propagation strategy. "
-                                    "Its inference speed highly depends on assinging non-scheduled GPU. (Therefore, once all GPUs are busy, then inference may be taken in infinity)",
+                        description="TroL is a family of efficient Large Language and Vision Models in 1.8B, 3.8B, and 7B sizes, built on a new layer-traversing propagation strategy. "
+                                    "Inference speed depends heavily on being assigned a non-scheduled GPU, so while all GPUs are busy, inference may stall indefinitely. "
+                                    "Note that history-based conversation (referring to previous dialogue) is not supported.",
                         stop_btn="Stop Generation", multimodal=True)
 demo.launch()
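The hunk above only shows the tail of the `gr.ChatInterface` call; the `additional_inputs` that feed `link`, `temperature`, `new_max_token`, and `top_p` into `bot_streaming` sit just above it, outside the diff. A minimal sketch of how such a multimodal interface is typically wired (control names and ranges are illustrative, not TroL's exact values):

# Illustrative wiring for a multimodal ChatInterface like the one above.
import gradio as gr

demo = gr.ChatInterface(
    fn=bot_streaming,            # generator function -> streamed replies
    multimodal=True,             # textbox also accepts image uploads
    additional_inputs=[
        gr.Dropdown(["TroL-1.8B", "TroL-3.8B", "TroL-7B"],
                    value="TroL-7B", label="link"),
        gr.Slider(0.1, 2.0, value=0.9, label="temperature"),
        gr.Slider(1, 1024, value=256, step=1, label="new_max_token"),
        gr.Slider(0.1, 1.0, value=0.95, label="top_p"),
    ],
    additional_inputs_accordion="Generation Hyperparameters",
    theme=gr.themes.Soft(),
    title="TroL",
    stop_btn="Stop Generation",
)
demo.launch()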
|