from pathlib import Path

import tensorrt_llm
import torch
from tensorrt_llm.runtime import ModelRunner, ModelRunnerCpp

from modules import shared
from modules.logging_colors import logger
from modules.text_generation import (
    get_max_prompt_length,
    get_reply_from_output_ids
)
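

# Wrapper around TensorRT-LLM's ModelRunner / ModelRunnerCpp engines, exposing
# generate() and generate_with_streaming() on top of a compiled engine directory.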
class TensorRTLLMModel:
    def __init__(self):
        pass

    @classmethod
    def from_pretrained(self, path_to_model):
        path_to_model = Path(f'{shared.args.model_dir}') / Path(path_to_model)
        runtime_rank = tensorrt_llm.mpi_rank()

        # Define model settings
        runner_kwargs = dict(
            engine_dir=str(path_to_model),
            lora_dir=None,
            rank=runtime_rank,
            debug_mode=False,
            lora_ckpt_source="hf",
        )
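
        # The C++ runner needs fixed batch/input/output sizes at load time;
        # 512 tokens of the context window are reserved for the output here.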
        if shared.args.cpp_runner:
            logger.info("TensorRT-LLM: Using \"ModelRunnerCpp\"")
            runner_kwargs.update(
                max_batch_size=1,
                max_input_len=shared.args.max_seq_len - 512,
                max_output_len=512,
                max_beam_width=1,
                max_attention_window_size=None,
                sink_token_length=None,
            )
        else:
            logger.info("TensorRT-LLM: Using \"ModelRunner\"")

        # Load the model
        runner_cls = ModelRunnerCpp if shared.args.cpp_runner else ModelRunner
        runner = runner_cls.from_dir(**runner_kwargs)
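
        # from_pretrained is a classmethod, so `self` is the class here: build the
        # wrapper instance and attach the loaded runner to it.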
        result = self()
        result.model = runner
        result.runtime_rank = runtime_rank

        return result

    def generate_with_streaming(self, prompt, state):
        batch_input_ids = []
        input_ids = shared.tokenizer.encode(
            prompt,
            add_special_tokens=True,
            truncation=False,
        )
        input_ids = torch.tensor(input_ids, dtype=torch.int32)
        input_ids = input_ids[-get_max_prompt_length(state):]  # Apply truncation_length
        batch_input_ids.append(input_ids)
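
        # Decide how many new tokens to request: the C++ runner was loaded with
        # max_output_len=512, so cap the request; otherwise follow the requested
        # max_new_tokens / auto_max_new_tokens settings.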
        if shared.args.cpp_runner:
            max_new_tokens = min(512, state['max_new_tokens'])
        elif state['auto_max_new_tokens']:
            max_new_tokens = state['truncation_length'] - input_ids.shape[-1]
        else:
            max_new_tokens = state['max_new_tokens']
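
        # Run generation. With streaming=True (Python runner) this returns a
        # generator of partial results; with the C++ runner it returns a single
        # result dict containing the full output.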
        with torch.no_grad():
            generator = self.model.generate(
                batch_input_ids,
                max_new_tokens=max_new_tokens,
                max_attention_window_size=None,
                sink_token_length=None,
                end_id=shared.tokenizer.eos_token_id if not state['ban_eos_token'] else -1,
                pad_id=shared.tokenizer.pad_token_id or shared.tokenizer.eos_token_id,
                temperature=state['temperature'],
                top_k=state['top_k'],
                top_p=state['top_p'],
                num_beams=1,
                length_penalty=1.0,
                repetition_penalty=state['repetition_penalty'],
                presence_penalty=state['presence_penalty'],
                frequency_penalty=state['frequency_penalty'],
                stop_words_list=None,
                bad_words_list=None,
                lora_uids=None,
                prompt_table_path=None,
                prompt_tasks=None,
                streaming=not shared.args.cpp_runner,
                output_sequence_lengths=True,
                return_dict=True,
                medusa_choices=None
            )

        torch.cuda.synchronize()

        cumulative_reply = ''
        starting_from = batch_input_ids[0].shape[-1]
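
        # The C++ runner returns the whole completion at once; the Python runner
        # streams it, so decode only the tokens generated since the previous step.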
        if shared.args.cpp_runner:
            sequence_length = generator['sequence_lengths'][0].item()
            output_ids = generator['output_ids'][0][0][:sequence_length].tolist()

            cumulative_reply += get_reply_from_output_ids(output_ids, state, starting_from=starting_from)
            starting_from = sequence_length
            yield cumulative_reply
        else:
            for curr_outputs in generator:
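                # Abort streaming when a stop has been requested elsewhere
                # (shared.stop_everything is set by the rest of the app).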
                if shared.stop_everything:
                    break

                sequence_length = curr_outputs['sequence_lengths'][0].item()
                output_ids = curr_outputs['output_ids'][0][0][:sequence_length].tolist()

                cumulative_reply += get_reply_from_output_ids(output_ids, state, starting_from=starting_from)
                starting_from = sequence_length
                yield cumulative_reply
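
    # Blocking variant: exhaust the streaming generator and return only the final text.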
    def generate(self, prompt, state):
        output = ''
        for output in self.generate_with_streaming(prompt, state):
            pass

        return output