# Copyright (c) OpenMMLab. All rights reserved.
"""Interactive command-line chat with a Hugging Face causal LM.

Optionally loads a PEFT adapter, a LLaVA-style bundle (visual encoder,
projector and adapters) for chatting about a single image, MOSS-style
plugins, or a lagent ReAct agent backed by Google search.
"""
import argparse
import os
import os.path as osp
import re
import sys

import torch
from huggingface_hub import snapshot_download
from peft import PeftModel
from transformers import (AutoImageProcessor, AutoModel, AutoModelForCausalLM,
                          AutoTokenizer, BitsAndBytesConfig, Dinov2Model,
                          GenerationConfig)

from xtuner.dataset.utils import expand2square, load_image
from xtuner.model.utils import prepare_inputs_labels_for_multimodal
from xtuner.tools.utils import get_stop_criteria, get_streamer
from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX,
                          PROMPT_TEMPLATE, SYSTEM_TEMPLATE)

TORCH_DTYPE_MAP = dict(
    fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto')


def remove_prefix(state_dict, prefix):
    """Return a copy of ``state_dict`` with ``prefix`` stripped from its keys.

    Keys that do not start with ``prefix`` are copied unchanged; values are
    shared, not copied.
    """
    new_state_dict = {}
    for key, value in state_dict.items():
        if key.startswith(prefix):
            new_key = key[len(prefix):]
            new_state_dict[new_key] = value
        else:
            new_state_dict[key] = value
    return new_state_dict


def parse_args():
    """Parse the command-line options of the chat tool."""
    parser = argparse.ArgumentParser(description='Chat with a HF model')
    parser.add_argument(
        'model_name_or_path', help='Hugging Face model name or path')
    adapter_group = parser.add_mutually_exclusive_group()
    adapter_group.add_argument(
        '--adapter', default=None, help='adapter name or path')
    adapter_group.add_argument(
        '--llava', default=None, help='llava name or path')
    parser.add_argument(
        '--visual-encoder', default=None, help='visual encoder name or path')
    # `type=int` is required: the value indexes `hidden_states`, and without
    # it a value given on the command line would arrive as a string.
    parser.add_argument(
        '--visual-select-layer',
        type=int,
        default=-2,
        help='visual select layer')
    parser.add_argument('--image', default=None, help='image')
    parser.add_argument(
        '--torch-dtype',
        default='fp16',
        choices=TORCH_DTYPE_MAP.keys(),
        help='Override the default `torch.dtype` and load the model under '
        'a specific `dtype`.')
    parser.add_argument(
        '--prompt-template',
        choices=PROMPT_TEMPLATE.keys(),
        default=None,
        help='Specify a prompt template')
    system_group = parser.add_mutually_exclusive_group()
    system_group.add_argument(
        '--system', default=None, help='Specify the system text')
    system_group.add_argument(
        '--system-template',
        choices=SYSTEM_TEMPLATE.keys(),
        default=None,
        help='Specify a system template')
    parser.add_argument(
        '--bits',
        type=int,
        choices=[4, 8, None],
        default=None,
        help='LLM bits')
    parser.add_argument(
        '--bot-name', type=str, default='BOT', help='Name for Bot')
    parser.add_argument(
        '--with-plugins',
        nargs='+',
        choices=['calculate', 'solve', 'search'],
        help='Specify plugins to use')
    parser.add_argument(
        '--no-streamer',
        action='store_true',
        help='Whether to disable the streamer')
    parser.add_argument(
        '--lagent', action='store_true', help='Whether to use lagent')
    parser.add_argument(
        '--stop-words', nargs='+', type=str, default=[], help='Stop words')
    parser.add_argument(
        '--offload-folder',
        default=None,
        help='The folder in which to offload the model weights (or where the '
        'model weights are already offloaded).')
    parser.add_argument(
        '--max-new-tokens',
        type=int,
        default=2048,
        help='Maximum number of new tokens allowed in generated text')
    parser.add_argument(
        '--temperature',
        type=float,
        default=0.1,
        help='The value used to modulate the next token probabilities.')
    parser.add_argument(
        '--top-k',
        type=int,
        default=40,
        help='The number of highest probability vocabulary tokens to '
        'keep for top-k-filtering.')
    parser.add_argument(
        '--top-p',
        type=float,
        default=0.75,
        help='If set to float < 1, only the smallest set of most probable '
        'tokens with probabilities that add up to top_p or higher are '
        'kept for generation.')
    parser.add_argument(
        '--repetition-penalty',
        type=float,
        default=1.0,
        help='The parameter for repetition penalty. 1.0 means no penalty.')
    parser.add_argument(
        '--seed',
        type=int,
        default=0,
        help='Random seed for reproducible text generation')
    args = parser.parse_args()
    return args


def get_input():
    """Read one (possibly multi-line) message from the user.

    Lines are collected until an empty line (a double Enter) and joined with
    newlines. Undecodable input causes a re-prompt. End of input (Ctrl-D or a
    closed stdin) is reported as ``'EXIT'``, so the chat loops terminate
    cleanly instead of crashing with ``EOFError``.
    """
    sentinel = ''  # ends when this string is seen
    result = None
    while result is None:
        print(('\ndouble enter to end input (EXIT: exit chat, '
               'RESET: reset history) >>> '),
              end='')
        try:
            result = '\n'.join(iter(input, sentinel))
        except UnicodeDecodeError:
            print('Invalid characters detected. Please enter again.')
        except EOFError:
            result = 'EXIT'
    return result


def _print_with_newline(text):
    """Print ``text``, appending a newline only if it lacks a trailing one.

    ``str.endswith`` is used instead of ``text[-1]`` so an empty generation
    does not raise ``IndexError``.
    """
    end = '' if text.endswith('\n') else '\n'
    print(text, end=end)


def main():
    """Build the model(s) from the CLI options and run the chat loop.

    Two modes are supported:

    * ``--lagent``: a lagent ReAct agent with a Google-search tool (requires
      the ``SERPER_API_KEY`` environment variable).
    * default: a plain generation loop, optionally with MOSS plugins
      (``--with-plugins``) or a LLaVA bundle plus one image
      (``--llava`` / ``--image``).

    Type ``RESET`` to clear the history and ``EXIT`` to quit.
    """
    args = parse_args()
    torch.manual_seed(args.seed)

    # build llm
    quantization_config = None
    load_in_8bit = False
    if args.bits == 4:
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            load_in_8bit=False,
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4')
    elif args.bits == 8:
        load_in_8bit = True
    model_kwargs = {
        'quantization_config': quantization_config,
        'load_in_8bit': load_in_8bit,
        'device_map': 'auto',
        'offload_folder': args.offload_folder,
        'trust_remote_code': True,
        'torch_dtype': TORCH_DTYPE_MAP[args.torch_dtype]
    }

    if args.lagent:
        from lagent.actions import ActionExecutor, GoogleSearch
        from lagent.agents import (CALL_PROTOCOL_CN, FORCE_STOP_PROMPT_CN,
                                   ReAct, ReActProtocol)
        from lagent.llms import HFTransformerCasualLM

        try:
            serper_api_key = os.environ['SERPER_API_KEY']
        except KeyError:
            print('Please obtain the `SERPER_API_KEY` from https://serper.dev '
                  'and set it using `export SERPER_API_KEY=xxx`.')
            sys.exit(1)

        # HFTransformerCasualLM receives these kwargs without the
        # `trust_remote_code` entry.
        model_kwargs.pop('trust_remote_code')
        llm = HFTransformerCasualLM(
            args.model_name_or_path, model_kwargs=model_kwargs)
        if args.adapter is not None:
            print(f'Loading adapter from {args.adapter}...')
            llm.model = PeftModel.from_pretrained(
                llm.model,
                args.adapter,
                offload_folder=args.offload_folder,
                trust_remote_code=True)
        search_tool = GoogleSearch(api_key=serper_api_key)
        chatbot = ReAct(
            llm=llm,
            action_executor=ActionExecutor(actions=[search_tool]),
            protocol=ReActProtocol(
                call_protocol=CALL_PROTOCOL_CN,
                force_stop=FORCE_STOP_PROMPT_CN))
        while True:
            text = get_input()
            while text.strip() == 'RESET':
                print('Log: History responses have been removed!')
                chatbot._session_history = []
                text = get_input()
            if text.strip() == 'EXIT':
                print('Log: Exit!')
                sys.exit(0)
            response = chatbot.chat(text)
            print(response.response)
    else:
        if args.with_plugins is None:
            inner_thoughts_open = False
            calculate_open = False
            solve_open = False
            search_open = False
        else:
            assert args.prompt_template == args.system_template == 'moss_sft'
            from plugins import plugins_api
            inner_thoughts_open = True
            calculate_open = 'calculate' in args.with_plugins
            solve_open = 'solve' in args.with_plugins
            search_open = 'search' in args.with_plugins
            # pre-import for api and model preparation
            if calculate_open:
                from plugins import calculate  # noqa: F401
            if solve_open:
                from plugins import solve  # noqa: F401
            if search_open:
                from plugins import search  # noqa: F401

        # The image path needs the visual encoder, image processor and
        # projector, which are only loaded from `--llava`; fail early with a
        # clear message instead of a NameError after loading the LLM.
        if args.image is not None and args.llava is None:
            print('Please specify `--llava` when passing `--image`; the '
                  'visual encoder and projector are loaded from it.')
            sys.exit(1)

        # build llm
        llm = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,
                                                   **model_kwargs)
        tokenizer = AutoTokenizer.from_pretrained(
            args.model_name_or_path,
            trust_remote_code=True,
            encode_special_tokens=True)
        print(f'Load LLM from {args.model_name_or_path}')
        if args.adapter is not None:
            llm = PeftModel.from_pretrained(
                llm,
                args.adapter,
                offload_folder=args.offload_folder,
                trust_remote_code=True)
            print(f'Load adapter from {args.adapter}')
        if args.llava is not None:
            # `--llava` may be a local directory or a Hub repo id.
            if osp.isdir(args.llava):
                llava_path = args.llava
            else:
                llava_path = snapshot_download(repo_id=args.llava)

            # build visual_encoder
            if 'visual_encoder' in os.listdir(llava_path):
                assert args.visual_encoder is None, (
                    "Please don't specify the `--visual-encoder` since passed "
                    '`--llava` contains a visual encoder!')
                visual_encoder_path = osp.join(llava_path, 'visual_encoder')
            else:
                assert args.visual_encoder is not None, (
                    'Please specify the `--visual-encoder`!')
                visual_encoder_path = args.visual_encoder
            visual_encoder = Dinov2Model.from_pretrained(
                visual_encoder_path,
                torch_dtype=TORCH_DTYPE_MAP[args.torch_dtype])
            image_processor = AutoImageProcessor.from_pretrained(
                visual_encoder_path)
            print(f'Load visual_encoder from {visual_encoder_path}')

            # load adapter
            if 'llm_adapter' in os.listdir(llava_path):
                adapter_path = osp.join(llava_path, 'llm_adapter')
                llm = PeftModel.from_pretrained(
                    llm,
                    adapter_path,
                    offload_folder=args.offload_folder,
                    trust_remote_code=True)
                print(f'Load LLM adapter from {args.llava}')
            if 'visual_encoder_adapter' in os.listdir(llava_path):
                adapter_path = osp.join(llava_path, 'visual_encoder_adapter')
                visual_encoder = PeftModel.from_pretrained(
                    visual_encoder,
                    adapter_path,
                    offload_folder=args.offload_folder)
                print(f'Load visual_encoder adapter from {args.llava}')

            # build projector
            projector_path = osp.join(llava_path, 'projector')
            projector = AutoModel.from_pretrained(
                projector_path,
                torch_dtype=TORCH_DTYPE_MAP[args.torch_dtype],
                trust_remote_code=True)
            print(f'Load projector from {args.llava}')

            projector.cuda()
            projector.eval()
            visual_encoder.cuda()
            visual_encoder.eval()

        llm.eval()

        if args.image is not None:
            image = load_image(args.image)
            # Pad to a square using the processor's mean colour (scaled to
            # 0-255) before the processor resizes/normalizes it.
            image = expand2square(
                image, tuple(int(x * 255) for x in image_processor.image_mean))
            image = image_processor.preprocess(
                image, return_tensors='pt')['pixel_values'][0]
            image = image.cuda().unsqueeze(0)
            # Inference only: no autograd graph is needed for the encoder
            # and projector forward passes.
            with torch.no_grad():
                visual_outputs = visual_encoder(
                    image, output_hidden_states=True)
                selected = visual_outputs.hidden_states[
                    args.visual_select_layer]
                # `[:, 1:]` drops the first token position (presumably the
                # CLS token) before projection.
                pixel_values = projector(selected[:, 1:])

        # Copy so that appending template stop words does not mutate
        # `args.stop_words`.
        stop_words = list(args.stop_words)
        sep = ''
        if args.prompt_template:
            template = PROMPT_TEMPLATE[args.prompt_template]
            stop_words += template.get('STOP_WORDS', [])
            sep = template.get('SEP', '')
        stop_criteria = get_stop_criteria(
            tokenizer=tokenizer, stop_words=stop_words)

        if args.no_streamer:
            streamer_cls = None
        else:
            streamer_cls = get_streamer(llm)

        gen_config = GenerationConfig(
            max_new_tokens=args.max_new_tokens,
            do_sample=args.temperature > 0,
            temperature=args.temperature,
            top_p=args.top_p,
            top_k=args.top_k,
            repetition_penalty=args.repetition_penalty,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id
            if tokenizer.pad_token_id is not None else
            tokenizer.eos_token_id,
        )

        n_turn = 0
        inputs = ''
        while True:
            text = get_input()
            while text.strip() == 'RESET':
                print('Log: History responses have been removed!')
                n_turn = 0
                inputs = ''
                text = get_input()
            if text.strip() == 'EXIT':
                print('Log: Exit!')
                sys.exit(0)

            # The image placeholder is only inserted in the first turn.
            if args.image is not None and n_turn == 0:
                text = DEFAULT_IMAGE_TOKEN + '\n' + text

            if args.prompt_template:
                prompt_text = ''
                template = PROMPT_TEMPLATE[args.prompt_template]
                if 'SYSTEM' in template and n_turn == 0:
                    system_text = None
                    if args.system_template is not None:
                        system_text = SYSTEM_TEMPLATE[
                            args.system_template].format(
                                round=n_turn + 1, bot_name=args.bot_name)
                    elif args.system is not None:
                        system_text = args.system
                    if system_text is not None:
                        prompt_text += template['SYSTEM'].format(
                            system=system_text,
                            round=n_turn + 1,
                            bot_name=args.bot_name)
                prompt_text += template['INSTRUCTION'].format(
                    input=text, round=n_turn + 1, bot_name=args.bot_name)
                if args.prompt_template == args.system_template == 'moss_sft':
                    # `str.replace` returns a new string; the result must be
                    # assigned back or the switches have no effect.
                    if not inner_thoughts_open:
                        prompt_text = prompt_text.replace(
                            '- Inner thoughts: enabled.',
                            '- Inner thoughts: disabled.')
                    if not calculate_open:
                        prompt_text = prompt_text.replace(
                            ('- Calculator: enabled. API: '
                             'Calculate(expression)'),
                            '- Calculator: disabled.')
                    if not solve_open:
                        prompt_text = prompt_text.replace(
                            '- Equation solver: enabled. API: Solve(equation)',
                            '- Equation solver: disabled.')
                    if not search_open:
                        prompt_text = prompt_text.replace(
                            '- Web search: enabled. API: Search(query)',
                            '- Web search: disabled.')
            else:
                prompt_text = text
            inputs += prompt_text

            if args.image is None:
                # From the second turn on, `inputs` is the decoded previous
                # output, which already contains the special tokens.
                if n_turn == 0:
                    ids = tokenizer.encode(inputs, return_tensors='pt')
                else:
                    ids = tokenizer.encode(
                        inputs, return_tensors='pt', add_special_tokens=False)

                streamer = streamer_cls(
                    tokenizer) if streamer_cls is not None else None
                if args.with_plugins is not None:
                    generate_output = llm.generate(
                        inputs=ids.cuda(),
                        generation_config=gen_config,
                        streamer=streamer,
                        stopping_criteria=stop_criteria).cpu()
                    generate_output_text = tokenizer.decode(
                        generate_output[0][len(ids[0]):])
                    if streamer is None:
                        _print_with_newline(generate_output_text)
                    # MOSS commands are terminated by `<eoc>`; without the
                    # terminator the lazy group would always match ''.
                    pattern = r'<\|Commands\|>:(.*?)<eoc>'
                    command_text = ', '.join(
                        re.findall(pattern, generate_output_text))
                    extent_text = plugins_api(
                        command_text,
                        calculate_open=calculate_open,
                        solve_open=solve_open,
                        search_open=search_open)
                    _print_with_newline(extent_text)
                    extent_text_ids = tokenizer.encode(
                        extent_text,
                        return_tensors='pt',
                        add_special_tokens=False)
                    new_ids = torch.cat((generate_output, extent_text_ids),
                                        dim=1)
                    new_streamer = streamer_cls(
                        tokenizer) if streamer_cls is not None else None
                    generate_output = llm.generate(
                        inputs=new_ids.cuda(),
                        generation_config=gen_config,
                        streamer=new_streamer,
                        stopping_criteria=stop_criteria)
                    if streamer is None:
                        output_text = tokenizer.decode(
                            generate_output[0][len(new_ids[0]):])
                        _print_with_newline(output_text)
                else:
                    generate_output = llm.generate(
                        inputs=ids.cuda(),
                        generation_config=gen_config,
                        streamer=streamer,
                        stopping_criteria=stop_criteria)
                    if streamer is None:
                        output_text = tokenizer.decode(
                            generate_output[0][len(ids[0]):])
                        _print_with_newline(output_text)
                inputs = tokenizer.decode(generate_output[0])
            else:
                # NOTE(review): only the first chunk of the first turn gets
                # special tokens. In later turns `inputs` holds only decoded
                # *new* tokens (generation from embeddings), so the leading
                # BOS may be missing — confirm this is intended.
                chunk_encode = []
                for idx, chunk in enumerate(inputs.split(DEFAULT_IMAGE_TOKEN)):
                    if idx == 0 and n_turn == 0:
                        cur_encode = tokenizer.encode(chunk)
                    else:
                        cur_encode = tokenizer.encode(
                            chunk, add_special_tokens=False)
                    chunk_encode.append(cur_encode)
                assert len(chunk_encode) == 2
                # Join the text chunks with IMAGE_TOKEN_INDEX where the image
                # placeholder was.
                ids = []
                for idx, cur_chunk_encode in enumerate(chunk_encode):
                    ids.extend(cur_chunk_encode)
                    if idx != len(chunk_encode) - 1:
                        ids.append(IMAGE_TOKEN_INDEX)
                ids = torch.tensor(ids).cuda().unsqueeze(0)
                mm_inputs = prepare_inputs_labels_for_multimodal(
                    llm=llm, input_ids=ids, pixel_values=pixel_values)

                streamer = streamer_cls(
                    tokenizer) if streamer_cls is not None else None
                generate_output = llm.generate(
                    **mm_inputs,
                    generation_config=gen_config,
                    streamer=streamer,
                    bos_token_id=tokenizer.bos_token_id,
                    stopping_criteria=stop_criteria)
                if streamer is None:
                    output_text = tokenizer.decode(generate_output[0])
                    _print_with_newline(output_text)
                inputs += tokenizer.decode(generate_output[0])
            n_turn += 1
            inputs += sep
            if len(generate_output[0]) >= args.max_new_tokens:
                print(
                    'Remove the memory of history responses, since '
                    f'it exceeds the length limitation {args.max_new_tokens}.')
                n_turn = 0
                inputs = ''


if __name__ == '__main__':
    main()