"""
Use FastChat with Hugging Face generation APIs.

Usage:
python3 -m fastchat.serve.huggingface_api --model lmsys/vicuna-7b-v1.5
python3 -m fastchat.serve.huggingface_api --model lmsys/fastchat-t5-3b-v1.0

NOTE(review): the module path above is the upstream FastChat one; this copy
imports its helpers from ``src.model``, so invoke it via this file's own
module path — confirm.
"""
import argparse

import torch

from src.model import load_model, get_conversation_template, add_model_args


def _generation_kwargs(temperature, repetition_penalty, max_new_tokens):
    """Build the decoding keyword arguments for ``model.generate``.

    Sampling is enabled only when ``temperature`` is above 1e-5; otherwise
    greedy decoding is used. ``temperature`` is forwarded only when sampling,
    because it has no effect on greedy decoding and Hugging Face's generation
    config validation warns when it is set alongside ``do_sample=False``.

    Args:
        temperature: Sampling temperature requested on the command line.
        repetition_penalty: Value passed through as ``repetition_penalty``.
        max_new_tokens: Value passed through as ``max_new_tokens``.

    Returns:
        A dict of keyword arguments to splat into ``model.generate``.
    """
    do_sample = temperature > 1e-5
    kwargs = {
        "do_sample": do_sample,
        "repetition_penalty": repetition_penalty,
        "max_new_tokens": max_new_tokens,
    }
    if do_sample:
        kwargs["temperature"] = temperature
    return kwargs


@torch.inference_mode()
def main(args):
    """Load a model, generate one reply to ``args.message``, and print both.

    Args:
        args: Parsed command-line namespace. Reads ``model_path``, ``device``,
            ``num_gpus``, ``max_gpu_memory``, ``load_8bit``,
            ``cpu_offloading`` and ``revision`` (presumably registered by
            ``add_model_args`` — confirm), plus ``debug``, ``message``,
            ``temperature``, ``repetition_penalty`` and ``max_new_tokens``
            registered in the ``__main__`` block below.
    """
    # Load model
    model, tokenizer = load_model(
        args.model_path,
        device=args.device,
        num_gpus=args.num_gpus,
        max_gpu_memory=args.max_gpu_memory,
        load_8bit=args.load_8bit,
        cpu_offloading=args.cpu_offloading,
        revision=args.revision,
        debug=args.debug,
    )

    # Build the prompt with a conversation template. The trailing ``None``
    # message for the second role presumably leaves the prompt open for the
    # model's reply — confirm against the template implementation.
    msg = args.message
    conv = get_conversation_template(args.model_path)
    conv.append_message(conv.roles[0], msg)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    # Run inference
    inputs = tokenizer([prompt], return_tensors="pt").to(args.device)
    output_ids = model.generate(
        **inputs,
        **_generation_kwargs(
            args.temperature, args.repetition_penalty, args.max_new_tokens
        ),
    )

    # Only one prompt was batched, so take the first sequence. For
    # decoder-only models the returned sequence starts with the prompt
    # tokens, which are stripped; encoder-decoder outputs are kept whole.
    output_ids = output_ids[0]
    if not model.config.is_encoder_decoder:
        output_ids = output_ids[len(inputs["input_ids"][0]):]

    outputs = tokenizer.decode(
        output_ids, skip_special_tokens=True, spaces_between_special_tokens=False
    )

    # Print results
    print(f"{conv.roles[0]}: {msg}")
    print(f"{conv.roles[1]}: {outputs}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    add_model_args(parser)
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--repetition_penalty", type=float, default=1.0)
    parser.add_argument("--max-new-tokens", type=int, default=1024)
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--message", type=str, default="Hello! Who are you?")
    args = parser.parse_args()

    # Reset default repetition penalty for T5 models (only when the user kept
    # the 1.0 default). NOTE(review): this is a plain case-sensitive substring
    # test, so it also matches unrelated names such as "gpt5" — consider a
    # stricter check.
    if "t5" in args.model_path and args.repetition_penalty == 1.0:
        args.repetition_penalty = 1.2

    main(args)