# Benchmark script: measures initialization time, generation speed
# (tokens/sec), and peak memory usage of LLAMA2_WRAPPER across backends.
import os
import time
import argparse
from dotenv import load_dotenv
from distutils.util import strtobool
from memory_profiler import memory_usage
from tqdm import tqdm
from llama2_wrapper import LLAMA2_WRAPPER
def run_iteration(
    llama2_wrapper, prompt_example, DEFAULT_SYSTEM_PROMPT, DEFAULT_MAX_NEW_TOKENS
):
    """Run one timed, memory-profiled generation.

    Args:
        llama2_wrapper: LLAMA2_WRAPPER instance used for generation.
        prompt_example: Prompt string fed to the model.
        DEFAULT_SYSTEM_PROMPT: System prompt passed to the wrapper.
        DEFAULT_MAX_NEW_TOKENS: Maximum number of new tokens to generate.

    Returns:
        Tuple of (generation_time seconds, tokens_per_second,
        peak memory usage in MiB, final model response text).
    """

    def generation():
        generator = llama2_wrapper.run(
            prompt_example,
            [],
            DEFAULT_SYSTEM_PROMPT,
            DEFAULT_MAX_NEW_TOKENS,
            1,
            0.95,
            50,
        )
        # The generator yields progressively longer partial responses;
        # keep only the final (complete) one.
        # BUGFIX: the previous version consumed the first yield separately
        # into an unused variable, so a single-yield generator left
        # model_response as None and get_token_length(None) was called.
        model_response = None
        for model_response in generator:
            pass
        return llama2_wrapper.get_token_length(model_response), model_response

    tic = time.perf_counter()
    # memory_usage runs `generation` and reports its peak memory (MiB)
    # alongside the function's return value (retval=True).
    mem_usage, (output_token_length, model_response) = memory_usage(
        (generation,), max_usage=True, retval=True
    )
    toc = time.perf_counter()
    generation_time = toc - tic
    tokens_per_second = output_token_length / generation_time
    return generation_time, tokens_per_second, mem_usage, model_response
def main():
    """Benchmark LLAMA2_WRAPPER: init time, tokens/sec, and peak memory.

    Configuration comes from CLI flags, which override values loaded from
    the environment (.env via python-dotenv).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--iter", type=int, default=5, help="Number of iterations")
    parser.add_argument("--model_path", type=str, default="", help="model path")
    parser.add_argument(
        "--backend_type",
        type=str,
        default="",
        help="Backend options: llama.cpp, gptq, transformers",
    )
    parser.add_argument(
        "--load_in_8bit",
        # BUGFIX: argparse `type=bool` treats ANY non-empty string as True
        # ("--load_in_8bit False" parsed as True). Parse truthy/falsy
        # strings explicitly; a bare "--load_in_8bit" also means True.
        type=lambda v: bool(strtobool(v)),
        nargs="?",
        const=True,
        default=False,
        help="Whether to use bitsandbytes 8 bit.",
    )
    args = parser.parse_args()

    load_dotenv()
    DEFAULT_SYSTEM_PROMPT = os.getenv("DEFAULT_SYSTEM_PROMPT", "")
    MAX_MAX_NEW_TOKENS = int(os.getenv("MAX_MAX_NEW_TOKENS", 2048))
    DEFAULT_MAX_NEW_TOKENS = int(os.getenv("DEFAULT_MAX_NEW_TOKENS", 1024))
    MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", 4000))
    MODEL_PATH = os.getenv("MODEL_PATH")
    assert MODEL_PATH is not None, f"MODEL_PATH is required, got: {MODEL_PATH}"
    BACKEND_TYPE = os.getenv("BACKEND_TYPE")
    assert BACKEND_TYPE is not None, f"BACKEND_TYPE is required, got: {BACKEND_TYPE}"
    LOAD_IN_8BIT = bool(strtobool(os.getenv("LOAD_IN_8BIT", "True")))

    # CLI arguments override environment configuration.
    if args.model_path != "":
        MODEL_PATH = args.model_path
    if args.backend_type != "":
        BACKEND_TYPE = args.backend_type
    if args.load_in_8bit:
        LOAD_IN_8BIT = True

    # Initialization
    init_tic = time.perf_counter()
    llama2_wrapper = LLAMA2_WRAPPER(
        model_path=MODEL_PATH,
        backend_type=BACKEND_TYPE,
        max_tokens=MAX_INPUT_TOKEN_LENGTH,
        load_in_8bit=LOAD_IN_8BIT,
        # verbose=True,
    )
    init_toc = time.perf_counter()
    initialization_time = init_toc - init_tic

    total_time = 0
    total_tokens_per_second = 0
    total_memory_gen = 0
    prompt_example = (
        "Can you explain briefly to me what is the Python programming language?"
    )

    # Cold run (warm-up): excluded from the timed statistics.
    print("Performing cold run...")
    run_iteration(
        llama2_wrapper, prompt_example, DEFAULT_SYSTEM_PROMPT, DEFAULT_MAX_NEW_TOKENS
    )

    # Timed runs
    print(f"Performing {args.iter} timed runs...")
    # BUGFIX: track the number of *successful* runs; the old code divided by
    # i + 1, which counted a failed iteration in the averages, and crashed
    # with NameError on --iter 0 (loop variable never bound).
    num_runs = 0
    model_response = None
    for _ in tqdm(range(args.iter)):
        try:
            gen_time, tokens_per_sec, mem_gen, model_response = run_iteration(
                llama2_wrapper,
                prompt_example,
                DEFAULT_SYSTEM_PROMPT,
                DEFAULT_MAX_NEW_TOKENS,
            )
        # BUGFIX: a bare `except:` also swallowed KeyboardInterrupt/SystemExit;
        # catch Exception so Ctrl-C still aborts the benchmark.
        except Exception:
            break
        total_time += gen_time
        total_tokens_per_second += tokens_per_sec
        total_memory_gen += mem_gen
        num_runs += 1

    if num_runs == 0:
        print("No successful timed runs; cannot report averages.")
        print(f"Initialization time: {initialization_time:0.4f} seconds.")
        return

    avg_time = total_time / num_runs
    avg_tokens_per_second = total_tokens_per_second / num_runs
    avg_memory_gen = total_memory_gen / num_runs
    print(f"Last model response: {model_response}")
    print(f"Initialization time: {initialization_time:0.4f} seconds.")
    print(
        f"Average generation time over {num_runs} iterations: {avg_time:0.4f} seconds."
    )
    print(
        f"Average speed over {num_runs} iterations: {avg_tokens_per_second:0.4f} tokens/sec."
    )
    print(f"Average memory usage during generation: {avg_memory_gen:.2f} MiB")


if __name__ == "__main__":
    main()
|