BK-Lee committed on
Commit
a29da54
·
1 Parent(s): 9edaf8c
Files changed (1) hide show
  1. app.py +20 -10
app.py CHANGED
@@ -24,23 +24,21 @@ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENT
24
  # accel
25
  accel = Accelerator()
26
 
27
- # model selection
28
- link = "TroL-7B" # [Select One] 'TroL-1.8B' | 'TroL-3.8B' | 'TroL-7B'
29
-
30
  # User prompt
31
  prompt_type="with_image" # Select one option "text_only", "with_image"
32
  img_path='figures/demo.png'
33
  question="What is the troll doing? Provide the detail in the image and imagine what the event happens."
34
 
35
  # loading model
36
- model, tokenizer = load_trol(link=link)
 
 
 
37
 
38
- # cpu -> gpu
39
- for param in model.parameters():
40
- if not param.is_cuda:
41
- param.data = param.to('cuda:0')
42
 
43
- def threading_function(inputs, image_token_number, streamer, device, temperature, new_max_token, top_p):
44
 
45
  # propagation
46
  _inputs = model.eval_process(inputs=inputs,
@@ -60,6 +58,16 @@ def threading_function(inputs, image_token_number, streamer, device, temperature
60
  @spaces.GPU
61
  def bot_streaming(message, history, link, temperature, new_max_token, top_p):
62
 
 
 
 
 
 
 
 
 
 
 
63
  try:
64
  # prompt type -> input prompt
65
  image_token_number = None
@@ -83,6 +91,8 @@ def bot_streaming(message, history, link, temperature, new_max_token, top_p):
83
  thread = Thread(target=threading_function, kwargs=dict(inputs=inputs,
84
  image_token_number=image_token_number,
85
  streamer=streamer,
 
 
86
  device=accel.device,
87
  temperature=temperature,
88
  new_max_token=new_max_token,
@@ -115,7 +125,7 @@ def bot_streaming(message, history, link, temperature, new_max_token, top_p):
115
  yield buffer
116
 
117
  demo = gr.ChatInterface(fn=bot_streaming,
118
- additional_inputs = [gr.Slider(0, 1, 0.9, label="temperature"), gr.Slider(1, 1024, 128, label="new_max_token"), gr.Slider(0, 1, 0.95, label="top_p")],
119
  additional_inputs_accordion="Generation Hyperparameters",
120
  theme=gr.themes.Soft(),
121
  title="TroL",
 
24
  # accel
25
  accel = Accelerator()
26
 
 
 
 
27
  # User prompt
28
  prompt_type="with_image" # Select one option "text_only", "with_image"
29
  img_path='figures/demo.png'
30
  question="What is the troll doing? Provide the detail in the image and imagine what the event happens."
31
 
32
  # loading model
33
+ model_1_8, tokenizer_1_8 = load_trol(link='TroL-1.8B')
34
+
35
+ # loading model
36
+ model_3_8, tokenizer_3_8 = load_trol(link='TroL-3.8B')
37
 
38
+ # loading model
39
+ model_7, tokenizer_7 = load_trol(link='TroL-7B')
 
 
40
 
41
+ def threading_function(inputs, image_token_number, streamer, device, model, tokenizer, temperature, new_max_token, top_p):
42
 
43
  # propagation
44
  _inputs = model.eval_process(inputs=inputs,
 
58
  @spaces.GPU
59
  def bot_streaming(message, history, link, temperature, new_max_token, top_p):
60
 
61
+ if "1.8B" in link:
62
+ model = model_1_8
63
+ tokenizer = tokenizer_1_8
64
+ elif "3.8B" in link:
65
+ model = model_3_8
66
+ tokenizer = tokenizer_3_8
67
+ elif "7B" in link:
68
+ model = model_7
69
+ tokenizer = tokenizer_7
70
+
71
  try:
72
  # prompt type -> input prompt
73
  image_token_number = None
 
91
  thread = Thread(target=threading_function, kwargs=dict(inputs=inputs,
92
  image_token_number=image_token_number,
93
  streamer=streamer,
94
+ model=model,
95
+ tokenizer=tokenizer,
96
  device=accel.device,
97
  temperature=temperature,
98
  new_max_token=new_max_token,
 
125
  yield buffer
126
 
127
  demo = gr.ChatInterface(fn=bot_streaming,
128
+ additional_inputs = [gr.Radio(["1.8B", "3.8B", "7B"], label="Size", info="Select one model size"), gr.Slider(0, 1, 0.9, label="temperature"), gr.Slider(1, 1024, 128, label="new_max_token"), gr.Slider(0, 1, 0.95, label="top_p")],
129
  additional_inputs_accordion="Generation Hyperparameters",
130
  theme=gr.themes.Soft(),
131
  title="TroL",