StarCycle's picture
initial commit
5d6d601
raw
history blame
20.6 kB
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import re
import sys
import torch
from huggingface_hub import snapshot_download
from peft import PeftModel
from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer,
BitsAndBytesConfig, AutoImageProcessor,
Dinov2Model, GenerationConfig)
from xtuner.dataset.utils import expand2square, load_image
from xtuner.model.utils import prepare_inputs_labels_for_multimodal
from xtuner.tools.utils import get_stop_criteria, get_streamer
from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX,
PROMPT_TEMPLATE, SYSTEM_TEMPLATE)
# Maps each accepted `--torch-dtype` command-line choice to the value that is
# passed as `torch_dtype` when the models are loaded.
TORCH_DTYPE_MAP = {
    'fp16': torch.float16,
    'bf16': torch.bfloat16,
    'fp32': torch.float32,
    'auto': 'auto',
}
def remove_prefix(state_dict, prefix):
    """Return a copy of ``state_dict`` with ``prefix`` stripped from its keys.

    Args:
        state_dict: Mapping whose keys are strings.
        prefix: Leading substring to drop from every key that starts with it.

    Returns:
        A new dict. Keys that start with ``prefix`` lose it; all other keys
        are kept as they are. Values are carried over unchanged and the
        input mapping is not modified.
    """
    cut = len(prefix)
    return {
        (name[cut:] if name.startswith(prefix) else name): item
        for name, item in state_dict.items()
    }
def parse_args():
    """Parse the command-line arguments of the chat script.

    Arguments are read from ``sys.argv``. ``--adapter`` and ``--llava`` are
    mutually exclusive, and so are ``--system`` and ``--system-template``.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = argparse.ArgumentParser(description='Chat with a HF model')
    parser.add_argument(
        'model_name_or_path', help='Hugging Face model name or path')
    adapter_group = parser.add_mutually_exclusive_group()
    adapter_group.add_argument(
        '--adapter', default=None, help='adapter name or path')
    adapter_group.add_argument(
        '--llava', default=None, help='llava name or path')
    parser.add_argument(
        '--visual-encoder', default=None, help='visual encoder name or path')
    # `type=int` is required: the value is used to index the tuple of hidden
    # states of the visual encoder, and without it a user-supplied value
    # would arrive as a string.
    parser.add_argument(
        '--visual-select-layer',
        type=int,
        default=-2,
        help='visual select layer')
    parser.add_argument('--image', default=None, help='image')
    parser.add_argument(
        '--torch-dtype',
        default='fp16',
        choices=TORCH_DTYPE_MAP.keys(),
        help='Override the default `torch.dtype` and load the model under '
        'a specific `dtype`.')
    parser.add_argument(
        '--prompt-template',
        choices=PROMPT_TEMPLATE.keys(),
        default=None,
        help='Specify a prompt template')
    system_group = parser.add_mutually_exclusive_group()
    system_group.add_argument(
        '--system', default=None, help='Specify the system text')
    system_group.add_argument(
        '--system-template',
        choices=SYSTEM_TEMPLATE.keys(),
        default=None,
        help='Specify a system template')
    parser.add_argument(
        '--bits',
        type=int,
        choices=[4, 8, None],
        default=None,
        help='LLM bits')
    parser.add_argument(
        '--bot-name', type=str, default='BOT', help='Name for Bot')
    parser.add_argument(
        '--with-plugins',
        nargs='+',
        choices=['calculate', 'solve', 'search'],
        help='Specify plugins to use')
    parser.add_argument(
        '--no-streamer', action='store_true', help='Whether to with streamer')
    parser.add_argument(
        '--lagent', action='store_true', help='Whether to use lagent')
    parser.add_argument(
        '--stop-words', nargs='+', type=str, default=[], help='Stop words')
    parser.add_argument(
        '--offload-folder',
        default=None,
        help='The folder in which to offload the model weights (or where the '
        'model weights are already offloaded).')
    parser.add_argument(
        '--max-new-tokens',
        type=int,
        default=2048,
        help='Maximum number of new tokens allowed in generated text')
    parser.add_argument(
        '--temperature',
        type=float,
        default=0.1,
        help='The value used to modulate the next token probabilities.')
    parser.add_argument(
        '--top-k',
        type=int,
        default=40,
        help='The number of highest probability vocabulary tokens to '
        'keep for top-k-filtering.')
    parser.add_argument(
        '--top-p',
        type=float,
        default=0.75,
        help='If set to float < 1, only the smallest set of most probable '
        'tokens with probabilities that add up to top_p or higher are '
        'kept for generation.')
    parser.add_argument(
        '--repetition-penalty',
        type=float,
        default=1.0,
        help='The parameter for repetition penalty. 1.0 means no penalty.')
    parser.add_argument(
        '--seed',
        type=int,
        default=0,
        help='Random seed for reproducible text generation')
    args = parser.parse_args()
    return args
def get_input():
    """Read one multi-line message from the user.

    Lines are collected with ``input`` until an empty line is entered; the
    collected lines are returned joined by newlines. If reading raises
    ``UnicodeDecodeError``, a notice is printed and the prompt is shown
    again.
    """
    prompt = ('\ndouble enter to end input (EXIT: exit chat, '
              'RESET: reset history) >>> ')
    while True:
        print(prompt, end='')
        try:
            # iter(callable, sentinel): keep calling input() until it
            # returns the empty string.
            return '\n'.join(iter(input, ''))
        except UnicodeDecodeError:
            print('Invalid characters detected. Please enter again.')
def main():
    """Run an interactive command-line chat with a Hugging Face LLM.

    The function parses the command line, seeds torch, and then follows one
    of two paths:

    * ``--lagent``: wrap the model in a lagent ``ReAct`` agent equipped with
      a Google search tool (needs the ``SERPER_API_KEY`` environment
      variable) and chat through it.
    * otherwise: load the LLM and tokenizer, optionally a PEFT adapter or a
      LLaVA bundle (visual encoder, optional adapters, projector),
      optionally encode one image, and run a prompt/generate loop with
      optional ``moss_sft`` plugins.

    In both loops the user can type ``RESET`` to clear the history and
    ``EXIT`` to terminate the process.

    Raises:
        ValueError: If ``--image`` is given without ``--llava``.
        SystemExit: On ``EXIT`` or when ``SERPER_API_KEY`` is missing.
    """
    args = parse_args()
    torch.manual_seed(args.seed)

    # build llm
    quantization_config = None
    load_in_8bit = False
    if args.bits == 4:
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            load_in_8bit=False,
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4')
    elif args.bits == 8:
        load_in_8bit = True
    model_kwargs = {
        'quantization_config': quantization_config,
        'load_in_8bit': load_in_8bit,
        'device_map': 'auto',
        'offload_folder': args.offload_folder,
        'trust_remote_code': True,
        'torch_dtype': TORCH_DTYPE_MAP[args.torch_dtype]
    }
    if args.lagent:
        from lagent.actions import ActionExecutor, GoogleSearch
        from lagent.agents import (CALL_PROTOCOL_CN, FORCE_STOP_PROMPT_CN,
                                   ReAct, ReActProtocol)
        from lagent.llms import HFTransformerCasualLM

        try:
            SERPER_API_KEY = os.environ['SERPER_API_KEY']
        except KeyError:
            print('Please obtain the `SERPER_API_KEY` from https://serper.dev '
                  'and set it using `export SERPER_API_KEY=xxx`.')
            sys.exit(1)

        model_kwargs.pop('trust_remote_code')
        llm = HFTransformerCasualLM(
            args.model_name_or_path, model_kwargs=model_kwargs)
        if args.adapter is not None:
            print(f'Loading adapter from {args.adapter}...')
            llm.model = PeftModel.from_pretrained(
                llm.model,
                args.adapter,
                offload_folder=args.offload_folder,
                trust_remote_code=True)
        search_tool = GoogleSearch(api_key=SERPER_API_KEY)
        chatbot = ReAct(
            llm=llm,
            action_executor=ActionExecutor(actions=[search_tool]),
            protocol=ReActProtocol(
                call_protocol=CALL_PROTOCOL_CN,
                force_stop=FORCE_STOP_PROMPT_CN))
        while True:
            text = get_input()
            while text.strip() == 'RESET':
                print('Log: History responses have been removed!')
                chatbot._session_history = []
                text = get_input()
            if text.strip() == 'EXIT':
                print('Log: Exit!')
                sys.exit(0)
            response = chatbot.chat(text)
            print(response.response)
    else:
        # The visual encoder, image processor and projector are only built
        # from `--llava`; fail early instead of hitting an undefined name
        # after the LLM has already been loaded.
        if args.image is not None and args.llava is None:
            raise ValueError(
                '`--image` requires `--llava`, which provides the visual '
                'encoder and the projector.')

        if args.with_plugins is None:
            inner_thoughts_open = False
            calculate_open = False
            solve_open = False
            search_open = False
        else:
            assert args.prompt_template == args.system_template == 'moss_sft'
            from plugins import plugins_api
            inner_thoughts_open = True
            calculate_open = 'calculate' in args.with_plugins
            solve_open = 'solve' in args.with_plugins
            search_open = 'search' in args.with_plugins
            # pre-import for api and model preparation
            if calculate_open:
                from plugins import calculate  # noqa: F401
            if solve_open:
                from plugins import solve  # noqa: F401
            if search_open:
                from plugins import search  # noqa: F401

        # build llm
        llm = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,
                                                   **model_kwargs)
        tokenizer = AutoTokenizer.from_pretrained(
            args.model_name_or_path,
            trust_remote_code=True,
            encode_special_tokens=True)
        print(f'Load LLM from {args.model_name_or_path}')
        if args.adapter is not None:
            llm = PeftModel.from_pretrained(
                llm,
                args.adapter,
                offload_folder=args.offload_folder,
                trust_remote_code=True)
            print(f'Load adapter from {args.adapter}')
        if args.llava is not None:
            # Use a local directory as is; otherwise download the repo.
            llava_path = snapshot_download(
                repo_id=args.llava) if not osp.isdir(
                    args.llava) else args.llava

            # build visual_encoder
            if 'visual_encoder' in os.listdir(llava_path):
                assert args.visual_encoder is None, (
                    "Please don't specify the `--visual-encoder` since passed "
                    '`--llava` contains a visual encoder!')
                visual_encoder_path = osp.join(llava_path, 'visual_encoder')
            else:
                assert args.visual_encoder is not None, (
                    'Please specify the `--visual-encoder`!')
                visual_encoder_path = args.visual_encoder
            visual_encoder = Dinov2Model.from_pretrained(
                visual_encoder_path,
                torch_dtype=TORCH_DTYPE_MAP[args.torch_dtype])
            image_processor = AutoImageProcessor.from_pretrained(
                visual_encoder_path)
            print(f'Load visual_encoder from {visual_encoder_path}')

            # load adapter
            if 'llm_adapter' in os.listdir(llava_path):
                adapter_path = osp.join(llava_path, 'llm_adapter')
                llm = PeftModel.from_pretrained(
                    llm,
                    adapter_path,
                    offload_folder=args.offload_folder,
                    trust_remote_code=True)
                print(f'Load LLM adapter from {args.llava}')
            if 'visual_encoder_adapter' in os.listdir(llava_path):
                adapter_path = osp.join(llava_path, 'visual_encoder_adapter')
                visual_encoder = PeftModel.from_pretrained(
                    visual_encoder,
                    adapter_path,
                    offload_folder=args.offload_folder)
                print(f'Load visual_encoder adapter from {args.llava}')

            # build projector
            projector_path = osp.join(llava_path, 'projector')
            projector = AutoModel.from_pretrained(
                projector_path,
                torch_dtype=TORCH_DTYPE_MAP[args.torch_dtype],
                trust_remote_code=True)
            print(f'Load projector from {args.llava}')

            projector.cuda()
            projector.eval()
            visual_encoder.cuda()
            visual_encoder.eval()

        llm.eval()

        if args.image is not None:
            image = load_image(args.image)
            image = expand2square(
                image, tuple(int(x * 255) for x in image_processor.image_mean))
            image = image_processor.preprocess(
                image, return_tensors='pt')['pixel_values'][0]
            image = image.cuda().unsqueeze(0)
            # Inference only: the image features are computed once and
            # reused, so there is no need to record an autograd graph.
            with torch.no_grad():
                visual_outputs = visual_encoder(
                    image, output_hidden_states=True)
                selected = visual_outputs.hidden_states[
                    args.visual_select_layer]
                # The first position along dim 1 is dropped before
                # projection.
                pixel_values = projector(selected[:, 1:])

        # Copy so that extending with the template's stop words does not
        # mutate the list owned by the parsed arguments.
        stop_words = list(args.stop_words)
        sep = ''
        if args.prompt_template:
            template = PROMPT_TEMPLATE[args.prompt_template]
            stop_words += template.get('STOP_WORDS', [])
            sep = template.get('SEP', '')
        stop_criteria = get_stop_criteria(
            tokenizer=tokenizer, stop_words=stop_words)

        if args.no_streamer:
            Streamer = None
        else:
            Streamer = get_streamer(llm)

        gen_config = GenerationConfig(
            max_new_tokens=args.max_new_tokens,
            do_sample=args.temperature > 0,
            temperature=args.temperature,
            top_p=args.top_p,
            top_k=args.top_k,
            repetition_penalty=args.repetition_penalty,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id
            if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
        )

        n_turn = 0
        inputs = ''
        while True:
            text = get_input()
            while text.strip() == 'RESET':
                print('Log: History responses have been removed!')
                n_turn = 0
                inputs = ''
                text = get_input()
            if text.strip() == 'EXIT':
                print('Log: Exit!')
                sys.exit(0)

            # The image placeholder is inserted only in the first turn.
            if args.image is not None and n_turn == 0:
                text = DEFAULT_IMAGE_TOKEN + '\n' + text

            if args.prompt_template:
                prompt_text = ''
                template = PROMPT_TEMPLATE[args.prompt_template]
                if 'SYSTEM' in template and n_turn == 0:
                    system_text = None
                    if args.system_template is not None:
                        system_text = SYSTEM_TEMPLATE[
                            args.system_template].format(
                                round=n_turn + 1, bot_name=args.bot_name)
                    elif args.system is not None:
                        system_text = args.system
                    if system_text is not None:
                        prompt_text += template['SYSTEM'].format(
                            system=system_text,
                            round=n_turn + 1,
                            bot_name=args.bot_name)
                prompt_text += template['INSTRUCTION'].format(
                    input=text, round=n_turn + 1, bot_name=args.bot_name)
                if args.prompt_template == args.system_template == 'moss_sft':
                    # `str.replace` returns a new string, so the result has
                    # to be assigned back for the switch to take effect.
                    if not inner_thoughts_open:
                        prompt_text = prompt_text.replace(
                            '- Inner thoughts: enabled.',
                            '- Inner thoughts: disabled.')
                    if not calculate_open:
                        prompt_text = prompt_text.replace(
                            ('- Calculator: enabled. API: '
                             'Calculate(expression)'),
                            '- Calculator: disabled.')
                    if not solve_open:
                        prompt_text = prompt_text.replace(
                            '- Equation solver: enabled. API: Solve(equation)',
                            '- Equation solver: disabled.')
                    if not search_open:
                        prompt_text = prompt_text.replace(
                            '- Web search: enabled. API: Search(query)',
                            '- Web search: disabled.')
            else:
                prompt_text = text
            inputs += prompt_text
            if args.image is None:
                # Special tokens are added by the tokenizer only for the
                # first turn.
                if n_turn == 0:
                    ids = tokenizer.encode(inputs, return_tensors='pt')
                else:
                    ids = tokenizer.encode(
                        inputs, return_tensors='pt', add_special_tokens=False)
                streamer = Streamer(
                    tokenizer) if Streamer is not None else None
                if args.with_plugins is not None:
                    generate_output = llm.generate(
                        inputs=ids.cuda(),
                        generation_config=gen_config,
                        streamer=streamer,
                        stopping_criteria=stop_criteria).cpu()
                    generate_output_text = tokenizer.decode(
                        generate_output[0][len(ids[0]):])
                    if streamer is None:
                        # `endswith` also copes with an empty generation.
                        end = '' if generate_output_text.endswith(
                            '\n') else '\n'
                        print(generate_output_text, end=end)
                    # Collect the plugin commands emitted by the model and
                    # execute them.
                    pattern = r'<\|Commands\|>:(.*?)<eoc>'
                    command_text = ', '.join(
                        re.findall(pattern, generate_output_text))
                    extent_text = plugins_api(
                        command_text,
                        calculate_open=calculate_open,
                        solve_open=solve_open,
                        search_open=search_open)
                    end = '' if extent_text.endswith('\n') else '\n'
                    print(extent_text, end=end)
                    extent_text_ids = tokenizer.encode(
                        extent_text,
                        return_tensors='pt',
                        add_special_tokens=False)
                    # Second pass: continue generating after the plugin
                    # results have been appended.
                    new_ids = torch.cat((generate_output, extent_text_ids),
                                        dim=1)
                    new_streamer = Streamer(
                        tokenizer) if Streamer is not None else None
                    generate_output = llm.generate(
                        inputs=new_ids.cuda(),
                        generation_config=gen_config,
                        streamer=new_streamer,
                        stopping_criteria=stop_criteria)
                    if new_streamer is None:
                        output_text = tokenizer.decode(
                            generate_output[0][len(new_ids[0]):])
                        end = '' if output_text.endswith('\n') else '\n'
                        print(output_text, end=end)
                else:
                    generate_output = llm.generate(
                        inputs=ids.cuda(),
                        generation_config=gen_config,
                        streamer=streamer,
                        stopping_criteria=stop_criteria)
                    if streamer is None:
                        output_text = tokenizer.decode(
                            generate_output[0][len(ids[0]):])
                        end = '' if output_text.endswith('\n') else '\n'
                        print(output_text, end=end)
                # The history becomes the decoded full output sequence.
                inputs = tokenizer.decode(generate_output[0])
            else:
                # Tokenize the text before and after the image placeholder
                # separately; the placeholder itself is represented by
                # IMAGE_TOKEN_INDEX.
                chunk_encode = []
                for idx, chunk in enumerate(inputs.split(DEFAULT_IMAGE_TOKEN)):
                    if idx == 0 and n_turn == 0:
                        cur_encode = tokenizer.encode(chunk)
                    else:
                        cur_encode = tokenizer.encode(
                            chunk, add_special_tokens=False)
                    chunk_encode.append(cur_encode)
                # Exactly one image placeholder is expected in the history.
                assert len(chunk_encode) == 2
                ids = []
                for idx, cur_chunk_encode in enumerate(chunk_encode):
                    ids.extend(cur_chunk_encode)
                    if idx != len(chunk_encode) - 1:
                        ids.append(IMAGE_TOKEN_INDEX)
                ids = torch.tensor(ids).cuda().unsqueeze(0)
                mm_inputs = prepare_inputs_labels_for_multimodal(
                    llm=llm, input_ids=ids, pixel_values=pixel_values)

                streamer = Streamer(
                    tokenizer) if Streamer is not None else None
                generate_output = llm.generate(
                    **mm_inputs,
                    generation_config=gen_config,
                    streamer=streamer,
                    bos_token_id=tokenizer.bos_token_id,
                    stopping_criteria=stop_criteria)
                if streamer is None:
                    output_text = tokenizer.decode(generate_output[0])
                    end = '' if output_text.endswith('\n') else '\n'
                    print(output_text, end=end)
                inputs += tokenizer.decode(generate_output[0])
            n_turn += 1
            inputs += sep
            # Drop the history once the last `generate` output is at least
            # `--max-new-tokens` tokens long.
            if len(generate_output[0]) >= args.max_new_tokens:
                print(
                    'Remove the memory of history responses, since '
                    f'it exceeds the length limitation {args.max_new_tokens}.')
                n_turn = 0
                inputs = ''
# Script entry point: start the chat session only when this file is run
# directly, not when it is imported.
if __name__ == '__main__':
    main()