BK-Lee committed
Commit c75a3b7 · 1 Parent(s): a56928d
Files changed (2)
  1. app.py +50 -51
  2. trol/load_trol.py +4 -4
app.py CHANGED
@@ -1,5 +1,5 @@
  # A100 Zero GPU
- import spaces
+ # import spaces

  # TroL Package
  import torch
@@ -18,8 +18,8 @@ from transformers import TextIteratorStreamer
  from torchvision.transforms.functional import pil_to_tensor

  # flash attention
- import subprocess
- subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
+ # import subprocess
+ # subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)

  # accel
  accel = Accelerator()
@@ -55,7 +55,7 @@ def threading_function(inputs, image_token_number, streamer, device, model, toke
      generation_kwargs.update({'use_cache': True})
      return model.generate(**generation_kwargs)

- @spaces.GPU
+ # @spaces.GPU
  def bot_streaming(message, history, link, temperature, new_max_token, top_p):

      # model selection
@@ -70,53 +70,52 @@ def bot_streaming(message, history, link, temperature, new_max_token, top_p):
          tokenizer = tokenizer_7

      # cpu -> gpu
-     for param in model.parameters():
-         if not param.is_cuda:
-             param.data = param.to(accel.device)
-
-     try:
-         # prompt type -> input prompt
-         image_token_number = None
-         if len(message['files']) == 1:
-             # Image Load
-             image = pil_to_tensor(Image.open(message['files'][0]).convert("RGB"))
-             if "3.8B" not in link:
-                 image_token_number = 1225
-                 image = F.interpolate(image.unsqueeze(0), size=(490, 490), mode='bicubic').squeeze(0)
-             inputs = [{'image': image.to(accel.device), 'question': message['text']}]
-         elif len(message['files']) > 1:
-             raise Exception("No way!")
-         else:
-             inputs = [{'question': message['text']}]
-
-         # Text Generation
-         with torch.inference_mode():
-             # kwargs
-             streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
-
-             # Threading generation
-             thread = Thread(target=threading_function, kwargs=dict(inputs=inputs,
-                                                                    image_token_number=image_token_number,
-                                                                    streamer=streamer,
-                                                                    model=model,
-                                                                    tokenizer=tokenizer,
-                                                                    device=accel.device,
-                                                                    temperature=temperature,
-                                                                    new_max_token=new_max_token,
-                                                                    top_p=top_p))
-             thread.start()
-
-             # generated text
-             generated_text = ""
-             for new_text in streamer:
-                 generated_text += new_text
-             generated_text
-
-         # Text decoding
-         response = output_filtering(generated_text, model)
-
-     except:
-         response = "There may be unsupported format: ex) pdf, video, sound. Only supported is a single image in this version."
+     # for param in model.parameters():
+     #     if not param.is_cuda:
+     #         param.data = param.to(accel.device)
+
+     # prompt type -> input prompt
+     image_token_number = None
+     if len(message['files']) == 1:
+         # Image Load
+         image = pil_to_tensor(Image.open(message['files'][0]).convert("RGB"))
+         if "3.8B" not in link:
+             image_token_number = 1225
+             image = F.interpolate(image.unsqueeze(0), size=(490, 490), mode='bicubic').squeeze(0)
+         inputs = [{'image': image.to(accel.device), 'question': message['text']}]
+     elif len(message['files']) > 1:
+         raise Exception("No way!")
+     else:
+         inputs = [{'question': message['text']}]
+
+     # Text Generation
+     with torch.inference_mode():
+         # kwargs
+         streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
+
+         # Threading generation
+         thread = Thread(target=threading_function, kwargs=dict(inputs=inputs,
+                                                                image_token_number=image_token_number,
+                                                                streamer=streamer,
+                                                                model=model,
+                                                                tokenizer=tokenizer,
+                                                                device=accel.device,
+                                                                temperature=temperature,
+                                                                new_max_token=new_max_token,
+                                                                top_p=top_p))
+         thread.start()
+
+         # generated text
+         generated_text = ""
+         for new_text in streamer:
+             generated_text += new_text
+         generated_text
+
+     # Text decoding
+     response = output_filtering(generated_text, model)
+
+     # except:
+     #     response = "There may be unsupported format: ex) pdf, video, sound. Only supported is a single image in this version."

      # private log print
      text = message['text']
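
Not part of the commit, for context only: the streaming path that this change keeps follows the usual transformers pattern of running generate() in a worker thread while the main thread drains a TextIteratorStreamer. A minimal, self-contained sketch of that pattern, using a placeholder gpt2 checkpoint and prompt text rather than the TroL models and Gradio inputs:

# Sketch only: placeholder checkpoint, not the TroL weights used in app.py.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")      # placeholder model
model = AutoModelForCausalLM.from_pretrained("gpt2")   # placeholder model

inputs = tokenizer("Describe the image:", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)

generation_kwargs = dict(**inputs,
                         streamer=streamer,
                         max_new_tokens=64,
                         do_sample=True,
                         temperature=0.9,
                         top_p=0.95,
                         use_cache=True,
                         pad_token_id=tokenizer.eos_token_id)

# generate() blocks, so it runs in a background thread while this thread streams text
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

generated_text = ""
for new_text in streamer:   # yields decoded text chunks as they are produced
    generated_text += new_text
print(generated_text)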
trol/load_trol.py CHANGED
@@ -14,14 +14,14 @@ def load_trol(link):
      if link == 'TroL-1.8B':
          from .arch_internlm2.modeling_trol import TroLForCausalLM
          from .arch_internlm2.tokenization_internlm2 import InternLM2Tokenizer as TroLTokenizer
-         bits = 4
+         bits = 16
          path = TROL_1_8B
          bit_quant_skip = ["vit", "vision_proj", "ffn", "output"]

      elif link == 'TroL-3.8B':
          from trol.arch_phi3.modeling_trol import TroLForCausalLM
          from transformers import LlamaTokenizerFast as TroLTokenizer
-         bits = 8
+         bits = 16
          path = TROL_3_8B
          bit_quant_skip = ["vision_model", "vision_proj", "lm_head"]

@@ -64,8 +64,8 @@ def load_trol(link):
      # Loading tokenizer & Loading backbone model (error -> then delete flash attention)
      tok_trol = TroLTokenizer.from_pretrained(path, padding_side='left')
      try:
-         trol = TroLForCausalLM.from_pretrained(path, **huggingface_config)
+         trol = TroLForCausalLM.from_pretrained(path, **huggingface_config).cuda()
      except:
          del huggingface_config["attn_implementation"]
-         trol = TroLForCausalLM.from_pretrained(path, **huggingface_config)
+         trol = TroLForCausalLM.from_pretrained(path, **huggingface_config).cuda()
      return trol, tok_trol
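
Not part of the commit, only a reading of it: with bits forced to 16 here and the flash-attn pip install dropped in app.py, loading comes down to a 16-bit (non-quantized) model moved onto the GPU at load time, with a retry that removes the attn_implementation override when flash-attention is unavailable. A rough sketch of that load path under those assumptions, using generic transformers classes and a hypothetical checkpoint path in place of the TroL-specific ones:

# Sketch only: "trol-checkpoint" is a hypothetical path, and the real code uses
# TroLForCausalLM / TroLTokenizer rather than the Auto classes shown here.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

path = "trol-checkpoint"
huggingface_config = dict(torch_dtype=torch.float16,                 # 16-bit load, no 4/8-bit quantization
                          attn_implementation="flash_attention_2")   # preferred attention backend

tokenizer = AutoTokenizer.from_pretrained(path, padding_side="left")
try:
    model = AutoModelForCausalLM.from_pretrained(path, **huggingface_config).cuda()
except Exception:
    # flash-attn is no longer installed by app.py, so fall back to the default attention
    del huggingface_config["attn_implementation"]
    model = AutoModelForCausalLM.from_pretrained(path, **huggingface_config).cuda()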