BK-Lee committed
Commit c9961ab · 1 Parent(s): 74321a7
Files changed (1)
  1. app.py +13 -4
app.py CHANGED
@@ -38,6 +38,8 @@ model_3_8, tokenizer_3_8 = load_trol(link='TroL-3.8B')
 # loading model
 model_7, tokenizer_7 = load_trol(link='TroL-7B')
 
+print()
+
 def threading_function(inputs, image_token_number, streamer, device, model, tokenizer, temperature, new_max_token, top_p):
 
     # propagation
@@ -58,6 +60,7 @@ def threading_function(inputs, image_token_number, streamer, device, model, toke
 @spaces.GPU
 def bot_streaming(message, history, link, temperature, new_max_token, top_p):
 
+    # model selection
     if "1.8B" in link:
         model = model_1_8
         tokenizer = tokenizer_1_8
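The `# model selection` comment labels the existing `if`/`elif` chain that picks one of the three preloaded checkpoints by substring of the `link` string. A minimal sketch of the same idea as a lookup table (hypothetical, not part of this commit; it assumes the `model_*` / `tokenizer_*` globals that app.py builds earlier with `load_trol`):

def select_model(link):
    # Hypothetical lookup-table version of the "# model selection" chain.
    table = {
        "1.8B": (model_1_8, tokenizer_1_8),
        "3.8B": (model_3_8, tokenizer_3_8),
        "7B":   (model_7, tokenizer_7),
    }
    # Return the first (model, tokenizer) pair whose size tag appears in the link.
    for size, pair in table.items():
        if size in link:
            return pair
    raise ValueError(f"Unknown model link: {link}")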
@@ -67,14 +70,19 @@ def bot_streaming(message, history, link, temperature, new_max_token, top_p):
     elif "7B" in link:
         model = model_7
         tokenizer = tokenizer_7
-
+
+    # cpu -> gpu
+    for param in model.parameters():
+        if not param.is_cuda:
+            param.data = param.to(accel.device)
+
     try:
         # prompt type -> input prompt
         image_token_number = None
         if len(message['files']) != 0:
             # Image Load
             image = pil_to_tensor(Image.open(Image.open(message['files'][0]).convert("RGB")).convert("RGB"))
-            if not "3.8B" in link:
+            if "3.8B" not in link:
                 image_token_number = 1225
                 image = F.interpolate(image.unsqueeze(0), size=(490, 490), mode='bicubic').squeeze(0)
             inputs = [{'image': image, 'question': message['text']}]
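The main functional change is the `# cpu -> gpu` loop, which moves any parameters still on the CPU onto the GPU assigned by ZeroGPU before generation. A standalone sketch of that pattern (hypothetical helper name; `accel.device` in the diff is assumed to come from an `accelerate.Accelerator` created elsewhere in app.py):

import torch

def move_params_to_device(model, device):
    # Mirror the loop added in this commit: only parameters that are not
    # already on a CUDA device are copied, so repeated calls stay cheap.
    for param in model.parameters():
        if not param.is_cuda:
            param.data = param.data.to(device)

# Example usage (assumes a CUDA device is available, as inside @spaces.GPU):
# move_params_to_device(model, torch.device("cuda:0"))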
@@ -129,7 +137,8 @@ demo = gr.ChatInterface(fn=bot_streaming,
                         additional_inputs_accordion="Generation Hyperparameters",
                         theme=gr.themes.Soft(),
                         title="TroL",
-                        description="TroL is efficient 1.8B, 3.8B, and 7B size Large Language and Vision Models built on new propagation strategy\n"
-                                    "Its inference speed highly depends on assinging non-scheduled GPU. (Therefore, once all GPUs are busy, then inference may be taken in infinity)",
+                        description="TroL is a family of efficient 1.8B, 3.8B, and 7B Large Language and Vision Models built on a new propagation strategy. "
+                                    "Its inference speed depends heavily on being assigned a non-scheduled GPU; once all GPUs are busy, inference may hang indefinitely. "
+                                    "Note that we do not support history-based conversation referring to previous dialogue.",
                         stop_btn="Stop Generation", multimodal=True)
 demo.launch()
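For context, a minimal, self-contained sketch of the gr.ChatInterface wiring that the last hunk edits, with a hypothetical echo function standing in for bot_streaming (the keyword arguments mirror those in app.py):

import gradio as gr

def echo_bot(message, history):
    # With multimodal=True, `message` is a dict holding "text" and "files".
    return f"You said: {message['text']}"

demo = gr.ChatInterface(fn=echo_bot,
                        title="TroL",
                        description="Demo description shown above the chat box.",
                        theme=gr.themes.Soft(),
                        stop_btn="Stop Generation",
                        multimodal=True)

if __name__ == "__main__":
    demo.launch()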
 