Fixed GIL issue
Fixes a race condition between CoreML inference and the causal_mask update.
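The core of the fix is to build the causal mask once, up front, and only slice it per token afterwards, so the full mask tensor is never rebuilt while CoreML predict() calls are in flight. The sketch below is illustrative only: build_causal_mask and the loop are stand-ins, not the chat.py functions; in the real diff this is done by make_causal_mask/initialize_causal_mask, and the per-token slice is fed to the model as the 'causal_mask' input.

import torch

def build_causal_mask(context_length: int) -> torch.Tensor:
    # Additive attention mask: 0 where attention is allowed, -inf elsewhere.
    mask = torch.full((1, 1, context_length, context_length), float('-inf'),
                      dtype=torch.float16)
    allowed = torch.tril(torch.ones(context_length, context_length)) > 0
    mask.masked_fill_(allowed, 0.0)
    return mask

context_length = 512
causal_mask = build_causal_mask(context_length)   # created once, before any inference

for pos in range(1, 4):
    # Each decode step only takes a [1, 1, 1, context_length] slice of the
    # precomputed mask; nothing is reallocated while inference is running.
    single_causal_mask = causal_mask[:, :, pos - 1:pos, :]
    assert single_causal_mask.shape == (1, 1, 1, context_length)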
chat.py
CHANGED
@@ -243,18 +243,28 @@ def load_metadata(model,args):
     else:
         ctx_len = args.context_length
 
-        # Use defaults
+        # Use defaults or values from args
         metadata['context_length'] = ctx_len
         metadata['state_length'] = ctx_len
-
+        # Get batch size from args or use default
+        metadata['batch_size'] = getattr(args, 'batch_size', 64)
         metadata['lut_bits'] = 4
-        metadata['num_chunks'] = 4
-        print("\nUsing
+        metadata['num_chunks'] = getattr(args, 'num_chunks', 4)
+        print("\nUsing parameters:")
         print(f" Context Length: {metadata['context_length']}")
         print(f" State Length: {metadata['state_length']}")
         print(f" Prefill Batch Size: {metadata['batch_size']}")
         print(f" LUT Bits: {metadata['lut_bits']}")
         print(f" Number of Chunks: {metadata['num_chunks']}")
+
+    # Override with values from args if they exist
+    if hasattr(args, 'batch_size') and args.batch_size is not None:
+        metadata['batch_size'] = args.batch_size
+        print(f"\nOverriding batch size from args: {args.batch_size}")
+    if hasattr(args, 'num_chunks') and args.num_chunks is not None:
+        metadata['num_chunks'] = args.num_chunks
+        print(f"\nOverriding num chunks from args: {args.num_chunks}")
+
     return metadata
 
 def load_models(args,metadata):
@@ -376,11 +386,19 @@ def make_causal_mask(length, start):
     mask[:, :, col_indices <= (row_indices + start)] = 0
     return mask
 
-def
-    """
-    # Create causal mask
+def initialize_causal_mask(context_length):
+    """Initialize causal mask for transformer attention."""
     causal_mask = make_causal_mask(context_length, 0)
     causal_mask = torch.tensor(causal_mask, dtype=torch.float16)
+    print(f"\nInitialized causal mask for context length {context_length}")
+    return causal_mask
+
+def run_prefill(embed_model, ffn_models, input_ids, context_pos, context_length, batch_size=64, state=None, causal_mask=None):
+    """Run prefill on the input sequence."""
+    # Use provided causal mask or create one if not provided
+    if causal_mask is None:
+        causal_mask = make_causal_mask(context_length, 0)
+        causal_mask = torch.tensor(causal_mask, dtype=torch.float16)
 
     # Process in batches
     batch_pos = 0
@@ -423,7 +441,7 @@ def run_prefill(embed_model, ffn_models, input_ids, context_pos, context_length,
 
     return torch.tensor([context_pos], dtype=torch.int32)
 
-def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, context_length, state=None, temperature=0.0):
+def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, context_length, state=None, causal_mask=None, temperature=0.0):
     """Generate the next token."""
     # Get current token
     current_token = input_ids[:, pos-1:pos] # [1, 1]
@@ -437,8 +455,13 @@ def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, c
     update_mask = torch.zeros((1, 1, context_length, 1), dtype=torch.float16)
     update_mask[0, 0, pos-1, 0] = 1.0
     position_ids = torch.tensor([pos-1], dtype=torch.int32) # [1]
-
-
+
+    # Use provided causal mask or create one if not provided
+    if causal_mask is None:
+        causal_mask_data = make_causal_mask(context_length, 0)
+        single_causal_mask = torch.tensor(causal_mask_data[:, :, pos-1:pos, :], dtype=torch.float16) # [1, 1, 1, context_length]
+    else:
+        single_causal_mask = causal_mask[:, :, pos-1:pos, :]
 
     # Run through FFN chunks with state
     for ffn_model in ffn_models:
@@ -447,7 +470,7 @@ def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, c
             'hidden_states': hidden_states.numpy(),
             'update_mask': update_mask.numpy(),
             'position_ids': position_ids.numpy(),
-            'causal_mask':
+            'causal_mask': single_causal_mask.numpy(),
             'current_pos': position_ids.numpy()
         }
         output = ffn_model['infer'].predict(inputs, state)
@@ -493,7 +516,7 @@ def create_unified_state(ffn_models, context_length):
     print("\nCreated unified transformer state")
     return state
 
-def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state, auto_prompt=None, warmup=False):
+def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state, causal_mask=None, auto_prompt=None, warmup=False):
     """Interactive chat loop."""
     context_length = metadata.get('context_length')
     batch_size = metadata.get('batch_size', 64)
@@ -567,7 +590,7 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
             # Start prefill timing
             prefill_start = time.time()
 
-            # Run prefill with state
+            # Run prefill with state and causal mask
             current_pos = run_prefill(
                 embed_model,
                 ffn_models,
@@ -575,7 +598,8 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
                 context_pos,
                 context_length,
                 batch_size,
-                state
+                state,
+                causal_mask
             )
 
             # Calculate prefill timing
@@ -590,7 +614,7 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
             inference_tokens = 0
 
             while pos < context_length - 1:
-                # Generate next token
+                # Generate next token with causal mask
                 next_token = generate_next_token(
                     embed_model,
                     ffn_models,
@@ -598,7 +622,8 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
                     input_ids,
                     pos,
                     context_length,
-                    state
+                    state,
+                    causal_mask
                 )
 
                 # Add token to sequence
@@ -657,7 +682,7 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
         traceback.print_exc()
 
 def parse_args():
-    parser = argparse.ArgumentParser(description='Chat with CoreML LLaMA (c) 2025 Anemll')
+    parser = argparse.ArgumentParser(description='Chat with CoreML LLaMA, gil resolved (c) 2025 Anemll')
 
     # Add meta.yaml option
     parser.add_argument('--meta', type=str, help='Path to meta.yaml to load all parameters')
@@ -678,9 +703,15 @@ def parse_args():
     parser.add_argument('--prompt', type=str,
                         help='If specified, run once with this prompt and exit')
 
+    # Add no-warmup flag
+    parser.add_argument('--nw', action='store_true',
+                        help='Skip warmup phase')
+
     # Model configuration
     parser.add_argument('--context-length', type=int,
                         help='Context length for the model (default: 512), if not provided, it will be detected from the model directory name ctxNUMBER')
+    parser.add_argument('--batch-size', type=int,
+                        help='Batch size for prefill (default: 64)')
 
     args = parser.parse_args()
 
@@ -711,9 +742,11 @@ def parse_args():
     if not args.tokenizer:
         args.tokenizer = args.d
 
-    # Set other parameters
-    args.context_length
-
+    # Set other parameters if not overridden by command line
+    if args.context_length is None:
+        args.context_length = int(params['context_length'])
+    if args.batch_size is None:
+        args.batch_size = int(params['batch_size'])
     args.num_chunks = num_chunks
 
     print(f"\nLoaded parameters from {args.meta}:")
@@ -782,18 +815,23 @@ def main():
     # Create unified state once
     state = create_unified_state(ffn_models, metadata['context_length'])
 
+    # Initialize causal mask once
+    causal_mask = initialize_causal_mask(metadata['context_length'])
+
     # Warmup runs to prevent Python GIL issues with CoreML !
-
-
-
-
-
-
-
-
-
-
-
+    if not args.nw:
+        for i in range(2):
+            chat_loop(
+                embed_model=embed_model,
+                ffn_models=ffn_models,
+                lmhead_model=lmhead_model,
+                tokenizer=tokenizer,
+                metadata=metadata,
+                state=state,
+                causal_mask=causal_mask, # Pass the causal mask
+                warmup=True,
+                auto_prompt="who are you?"
+            )
 
     # Main run
     chat_loop(
@@ -803,6 +841,7 @@ def main():
         tokenizer=tokenizer,
         metadata=metadata,
         state=state,
+        causal_mask=causal_mask, # Pass the causal mask
         warmup=False,
         auto_prompt=args.prompt
     )
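For completeness, the two new command-line options added in parse_args() can be checked on their own. The snippet below is a standalone argparse sketch mirroring the flags from the diff, not part of the commit; the example path in the comment is illustrative.

import argparse

parser = argparse.ArgumentParser(description='Chat with CoreML LLaMA, gil resolved (c) 2025 Anemll')
parser.add_argument('--nw', action='store_true', help='Skip warmup phase')
parser.add_argument('--batch-size', type=int, help='Batch size for prefill (default: 64)')

# e.g. `python chat.py --meta <path-to-meta.yaml> --nw --batch-size 128` would skip
# the two warmup chat_loop() passes and override the prefill batch size.
args = parser.parse_args(['--nw', '--batch-size', '128'])
print(args.nw, args.batch_size)   # True 128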