v1

Changed files:
- app.py +9 -9
- trol/arch_internlm2/modeling_internlm2.py +1 -1
- trol/arch_internlm2/modeling_trol.py +1 -1
- trol/load_trol.py +21 -7
app.py
CHANGED

@@ -1,5 +1,5 @@
 # A100 Zero GPU
-
+import spaces

 # TroL Package
 import torch
@@ -33,10 +33,10 @@ question="What is the troll doing? Provide the detail in the image and imagine w
 model_1_8, tokenizer_1_8 = load_trol(link='TroL-1.8B')

 # loading model
-
+model_3_8, tokenizer_3_8 = load_trol(link='TroL-3.8B')

 # loading model
-
+model_7, tokenizer_7 = load_trol(link='TroL-7B')

 def threading_function(inputs, image_token_number, streamer, device, model, tokenizer, temperature, new_max_token, top_p):

@@ -55,7 +55,7 @@ def threading_function(inputs, image_token_number, streamer, device, model, tokenizer, temperature, new_max_token, top_p):
     generation_kwargs.update({'use_cache': True})
     return model.generate(**generation_kwargs)

-
+@spaces.GPU
 def bot_streaming(message, history, link, temperature, new_max_token, top_p):

     # model selection
@@ -70,9 +70,9 @@ def bot_streaming(message, history, link, temperature, new_max_token, top_p):
         tokenizer = tokenizer_7

     # cpu -> gpu
-
-
-
+    for param in model.parameters():
+        if not param.is_cuda:
+            param.data = param.to(accel.device)

     # prompt type -> input prompt
     image_token_number = None
@@ -131,11 +131,11 @@ def bot_streaming(message, history, link, temperature, new_max_token, top_p):
     buffer = ""
     for character in response:
         buffer += character
-        time.sleep(0.
+        time.sleep(0.012)
         yield buffer

 demo = gr.ChatInterface(fn=bot_streaming,
-                        additional_inputs = [gr.Radio(["1.8B"], label="Size", info="Select one model size", value="
+                        additional_inputs = [gr.Radio(["1.8B", "3.8B", "7B"], label="Size", info="Select one model size", value="7B"), gr.Slider(0, 1, 0.9, label="temperature"), gr.Slider(1, 1024, 128, label="new_max_token"), gr.Slider(0, 1, 0.95, label="top_p")],
                         additional_inputs_accordion="Generation Hyperparameters",
                         theme=gr.themes.Soft(),
                         title="TroL",
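Note on the app.py changes: @spaces.GPU is the standard Hugging Face ZeroGPU decorator, which attaches a GPU only while the decorated handler runs; that is why the checkpoints stay on CPU at load time and are moved device-side lazily inside bot_streaming. A minimal sketch of the same pattern, with a toy module standing in for the TroL checkpoints (nothing here is the app's real model or device object):

import spaces
import torch

model = torch.nn.Linear(8, 8)  # toy stand-in for a checkpoint kept on CPU

@spaces.GPU  # ZeroGPU: a GPU is available only for the duration of this call
def run(x):
    # Mirror the app's cpu -> gpu loop: move any parameter not yet on CUDA.
    for param in model.parameters():
        if not param.is_cuda:
            param.data = param.to("cuda")
    return model(x.to("cuda"))
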
trol/arch_internlm2/modeling_internlm2.py
CHANGED

@@ -857,7 +857,7 @@ class InternLM2Model(InternLM2PreTrainedModel):
         self.vocab_size = config.vocab_size
         self.config = config

-        self.tok_embeddings = nn.Embedding(config.vocab_size,
+        self.tok_embeddings = nn.Embedding(config.vocab_size+1,
                                            config.hidden_size,
                                            self.padding_idx)
         self.layers = nn.ModuleList([
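Why config.vocab_size+1: load_trol.py (below) registers an image special token on the config, and its tokenizer id can presumably land one past the original vocabulary, so the embedding table needs an extra row or the lookup fails. A toy illustration of just that indexing rule (sizes are made up, not InternLM2's):

import torch
import torch.nn as nn

vocab_size = 4
emb_plain = nn.Embedding(vocab_size, 8)      # rows 0..3 only
emb_grown = nn.Embedding(vocab_size + 1, 8)  # one extra row for a new token id

image_token_id = torch.tensor([vocab_size])  # id 4, one past the original vocab
print(emb_grown(image_token_id).shape)       # torch.Size([1, 8]) -- row exists
# emb_plain(image_token_id)                  # IndexError: index out of range in self
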
trol/arch_internlm2/modeling_trol.py
CHANGED

@@ -30,7 +30,7 @@ class TroLForCausalLM(InternLM2PreTrainedModel):
         # Model
         self.model = InternLM2Model(config)
         self.vocab_size = config.vocab_size
-        self.output = nn.Linear(config.hidden_size, config.vocab_size
+        self.output = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
         self.max_length = config.max_length

         # Initialize weights and apply final processing
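Two notes on the output head: dropping the bias is the conventional choice for LM heads (logits are just hidden @ W.T), and it keeps the module loadable from checkpoints saved without a bias tensor, since load_state_dict is strict by default. A small illustration with toy sizes (not TroL's real dimensions):

import torch.nn as nn

hidden_size, vocab_size = 16, 100
saved = nn.Linear(hidden_size, vocab_size, bias=False).state_dict()  # no 'bias' key

head_no_bias = nn.Linear(hidden_size, vocab_size, bias=False)
head_no_bias.load_state_dict(saved)        # loads cleanly

head_with_bias = nn.Linear(hidden_size, vocab_size)
# head_with_bias.load_state_dict(saved)    # RuntimeError: Missing key(s): "bias"
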
trol/load_trol.py
CHANGED

@@ -1,11 +1,16 @@
 import torch
 import warnings
 from config import *
-from peft import LoraConfig
 from transformers import BitsAndBytesConfig

 warnings.filterwarnings(action='ignore')

+def setting_trol_config(trol, tok_trol, image_special_token):
+    trol.config.image_token_index = tok_trol.convert_tokens_to_ids(image_special_token)
+    trol.config.ignore_index = -100
+    trol.config.pad_token_id = tok_trol.eos_token_id
+    trol.config.eos_token_id = tok_trol.eos_token_id
+
 def load_trol(link):

     """
@@ -16,21 +21,24 @@ def load_trol(link):
         from .arch_internlm2.tokenization_internlm2 import InternLM2Tokenizer as TroLTokenizer
         bits = 4
         path = TROL_1_8B
-
+        image_special_token = "<image>"
+        bit_quant_skip = ["vit", "vision_proj", "ffn", "output", "trol_gating"]

     elif link == 'TroL-3.8B':
         from trol.arch_phi3.modeling_trol import TroLForCausalLM
         from transformers import LlamaTokenizerFast as TroLTokenizer
         bits = 8
         path = TROL_3_8B
-
+        image_special_token = "<IMG_CONTEXT>"
+        bit_quant_skip = ["vision_model", "vision_proj", "lm_head", "trol_gating"]

     elif link == 'TroL-7B':
         from .arch_internlm2.modeling_trol import TroLForCausalLM
         from .arch_internlm2.tokenization_internlm2 import InternLM2Tokenizer as TroLTokenizer
         bits = 4
         path = TROL_7B
-
+        image_special_token = "<image>"
+        bit_quant_skip = ["vit", "vision_proj", "ffn", "output", "trol_gating"]
     else:
         raise Exception("Unsupported Link")

@@ -68,10 +76,16 @@
     except:
         del huggingface_config["attn_implementation"]
         trol = TroLForCausalLM.from_pretrained(path, **huggingface_config)
+    trol.config.llm_config.use_cache = False
+
+    # setting config
+    setting_trol_config(trol, tok_trol, image_special_token)
+

-    #
+    # trol gating load
+    from huggingface_hub import hf_hub_download
     try:
-        trol =
+        trol.model.trol_gating.load_state_dict(torch.load(hf_hub_download(repo_id=path, filename="trol_gating.pt")))
     except:
-
+        trol.language_model.model.trol_gating.load_state_dict(torch.load(hf_hub_download(repo_id=path, filename="trol_gating.pt")))
     return trol, tok_trol
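End-to-end, the new load path is: pick the architecture from link, quantize with BitsAndBytesConfig while skipping the modules listed in bit_quant_skip, patch the token ids via setting_trol_config, then restore the trol_gating weights from the Hub (the try/except covers the two architectures' different module paths). A hedged usage sketch, assuming the import path app.py uses; the TROL_* repo ids live in config.py, which is not part of this diff:

from trol.load_trol import load_trol

# Loads the 4-bit 1.8B variant, registers "<image>" on the model config,
# and pulls trol_gating.pt via hf_hub_download under the hood.
trol, tok_trol = load_trol(link='TroL-1.8B')

print(trol.config.image_token_index)                      # id of "<image>" in the tokenizer
print(trol.config.pad_token_id == tok_trol.eos_token_id)  # True after setting_trol_config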