anemll committed · Commit 3315517 · verified · 1 Parent(s): deacd1a

Fixed GIL issue

Fixes a race condition between CoreML inference and the causal_mask update: the mask is now built once up front and passed into chat_loop, run_prefill, and generate_next_token instead of being recreated on every call.
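
The idea behind the fix, as a minimal standalone sketch (hedged: only the final masking line of make_causal_mask appears in the diff, so the -inf fill and index setup below are assumptions, and context_length/pos are illustrative values): the full additive mask is built once on the main thread, and the per-token decode path only takes read-only slices of it, so Python code is no longer rebuilding the mask while a CoreML predict() call is in flight.

import numpy as np
import torch

def make_causal_mask(length, start):
    # Reconstructed helper: only the masking line below is shown in the diff;
    # the -inf fill and index setup are assumptions about the rest of the body.
    mask = np.full((1, 1, length, length), -np.inf, dtype=np.float16)
    row_indices = np.arange(length).reshape(length, 1)
    col_indices = np.arange(length).reshape(1, length)
    mask[:, :, col_indices <= (row_indices + start)] = 0
    return mask

# Build the full mask once, up front (what the new initialize_causal_mask() does).
context_length = 512  # illustrative value
causal_mask = torch.tensor(make_causal_mask(context_length, 0), dtype=torch.float16)

# Inside the decode loop: slice only, never rebuild.
pos = 10  # illustrative value
single_causal_mask = causal_mask[:, :, pos-1:pos, :]  # [1, 1, 1, context_length]

In the patch this corresponds to the new initialize_causal_mask() being called once in main() and the resulting tensor being passed through chat_loop() into run_prefill() and generate_next_token().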

Files changed (1): chat.py (+70 -31)
chat.py CHANGED
@@ -243,18 +243,28 @@ def load_metadata(model,args):
     else:
         ctx_len = args.context_length
 
-    # Use defaults
+    # Use defaults or values from args
     metadata['context_length'] = ctx_len
     metadata['state_length'] = ctx_len
-    metadata['batch_size'] = 64
+    # Get batch size from args or use default
+    metadata['batch_size'] = getattr(args, 'batch_size', 64)
     metadata['lut_bits'] = 4
-    metadata['num_chunks'] = 4
-    print("\nUsing default parameters:")
+    metadata['num_chunks'] = getattr(args, 'num_chunks', 4)
+    print("\nUsing parameters:")
     print(f" Context Length: {metadata['context_length']}")
     print(f" State Length: {metadata['state_length']}")
     print(f" Prefill Batch Size: {metadata['batch_size']}")
     print(f" LUT Bits: {metadata['lut_bits']}")
     print(f" Number of Chunks: {metadata['num_chunks']}")
+
+    # Override with values from args if they exist
+    if hasattr(args, 'batch_size') and args.batch_size is not None:
+        metadata['batch_size'] = args.batch_size
+        print(f"\nOverriding batch size from args: {args.batch_size}")
+    if hasattr(args, 'num_chunks') and args.num_chunks is not None:
+        metadata['num_chunks'] = args.num_chunks
+        print(f"\nOverriding num chunks from args: {args.num_chunks}")
+
     return metadata
 
 def load_models(args,metadata):
@@ -376,11 +386,19 @@ def make_causal_mask(length, start):
     mask[:, :, col_indices <= (row_indices + start)] = 0
     return mask
 
-def run_prefill(embed_model, ffn_models, input_ids, context_pos, context_length, batch_size=64, state=None):
-    """Run prefill on the input sequence."""
-    # Create causal mask
+def initialize_causal_mask(context_length):
+    """Initialize causal mask for transformer attention."""
     causal_mask = make_causal_mask(context_length, 0)
     causal_mask = torch.tensor(causal_mask, dtype=torch.float16)
+    print(f"\nInitialized causal mask for context length {context_length}")
+    return causal_mask
+
+def run_prefill(embed_model, ffn_models, input_ids, context_pos, context_length, batch_size=64, state=None, causal_mask=None):
+    """Run prefill on the input sequence."""
+    # Use provided causal mask or create one if not provided
+    if causal_mask is None:
+        causal_mask = make_causal_mask(context_length, 0)
+        causal_mask = torch.tensor(causal_mask, dtype=torch.float16)
 
     # Process in batches
     batch_pos = 0
@@ -423,7 +441,7 @@ def run_prefill(embed_model, ffn_models, input_ids, context_pos, context_length,
 
     return torch.tensor([context_pos], dtype=torch.int32)
 
-def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, context_length, state=None, temperature=0.0):
+def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, context_length, state=None, causal_mask=None, temperature=0.0):
     """Generate the next token."""
     # Get current token
     current_token = input_ids[:, pos-1:pos]  # [1, 1]
@@ -437,8 +455,13 @@ def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, c
     update_mask = torch.zeros((1, 1, context_length, 1), dtype=torch.float16)
     update_mask[0, 0, pos-1, 0] = 1.0
     position_ids = torch.tensor([pos-1], dtype=torch.int32)  # [1]
-    causal_mask = make_causal_mask(context_length, 0)
-    causal_mask = torch.tensor(causal_mask[:, :, pos-1:pos, :], dtype=torch.float16)  # [1, 1, 1, context_length]
+
+    # Use provided causal mask or create one if not provided
+    if causal_mask is None:
+        causal_mask_data = make_causal_mask(context_length, 0)
+        single_causal_mask = torch.tensor(causal_mask_data[:, :, pos-1:pos, :], dtype=torch.float16)  # [1, 1, 1, context_length]
+    else:
+        single_causal_mask = causal_mask[:, :, pos-1:pos, :]
 
     # Run through FFN chunks with state
     for ffn_model in ffn_models:
@@ -447,7 +470,7 @@ def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, c
                 'hidden_states': hidden_states.numpy(),
                 'update_mask': update_mask.numpy(),
                 'position_ids': position_ids.numpy(),
-                'causal_mask': causal_mask.numpy(),
+                'causal_mask': single_causal_mask.numpy(),
                 'current_pos': position_ids.numpy()
             }
             output = ffn_model['infer'].predict(inputs, state)
@@ -493,7 +516,7 @@ def create_unified_state(ffn_models, context_length):
     print("\nCreated unified transformer state")
     return state
 
-def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state, auto_prompt=None, warmup=False):
+def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state, causal_mask=None, auto_prompt=None, warmup=False):
    """Interactive chat loop."""
    context_length = metadata.get('context_length')
    batch_size = metadata.get('batch_size', 64)
@@ -567,7 +590,7 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
             # Start prefill timing
             prefill_start = time.time()
 
-            # Run prefill with state
+            # Run prefill with state and causal mask
             current_pos = run_prefill(
                 embed_model,
                 ffn_models,
@@ -575,7 +598,8 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
                 context_pos,
                 context_length,
                 batch_size,
-                state
+                state,
+                causal_mask
             )
 
             # Calculate prefill timing
@@ -590,7 +614,7 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
             inference_tokens = 0
 
             while pos < context_length - 1:
-                # Generate next token
+                # Generate next token with causal mask
                 next_token = generate_next_token(
                     embed_model,
                     ffn_models,
@@ -598,7 +622,8 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
                     input_ids,
                     pos,
                     context_length,
-                    state
+                    state,
+                    causal_mask
                 )
 
                 # Add token to sequence
@@ -657,7 +682,7 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
         traceback.print_exc()
 
 def parse_args():
-    parser = argparse.ArgumentParser(description='Chat with CoreML LLaMA (c) 2025 Anemll')
+    parser = argparse.ArgumentParser(description='Chat with CoreML LLaMA, gil resolved (c) 2025 Anemll')
 
     # Add meta.yaml option
     parser.add_argument('--meta', type=str, help='Path to meta.yaml to load all parameters')
@@ -678,9 +703,15 @@ def parse_args():
     parser.add_argument('--prompt', type=str,
                        help='If specified, run once with this prompt and exit')
 
+    # Add no-warmup flag
+    parser.add_argument('--nw', action='store_true',
+                       help='Skip warmup phase')
+
     # Model configuration
     parser.add_argument('--context-length', type=int,
                        help='Context length for the model (default: 512), if not provided, it will be detected from the model directory name ctxNUMBER')
+    parser.add_argument('--batch-size', type=int,
+                       help='Batch size for prefill (default: 64)')
 
     args = parser.parse_args()
 
@@ -711,9 +742,11 @@ def parse_args():
         if not args.tokenizer:
             args.tokenizer = args.d
 
-        # Set other parameters
-        args.context_length = int(params['context_length'])
-        args.batch_size = int(params['batch_size'])
+        # Set other parameters if not overridden by command line
+        if args.context_length is None:
+            args.context_length = int(params['context_length'])
+        if args.batch_size is None:
+            args.batch_size = int(params['batch_size'])
         args.num_chunks = num_chunks
 
         print(f"\nLoaded parameters from {args.meta}:")
@@ -782,18 +815,23 @@ def main():
     # Create unified state once
     state = create_unified_state(ffn_models, metadata['context_length'])
 
+    # Initialize causal mask once
+    causal_mask = initialize_causal_mask(metadata['context_length'])
+
     # Warmup runs to prevent Python GIL issues with CoreML !
-    for i in range(2):
-        chat_loop(
-            embed_model=embed_model,
-            ffn_models=ffn_models,
-            lmhead_model=lmhead_model,
-            tokenizer=tokenizer,
-            metadata=metadata,
-            state=state,
-            warmup=True,
-            auto_prompt="who are you?"
-        )
+    if not args.nw:
+        for i in range(2):
+            chat_loop(
+                embed_model=embed_model,
+                ffn_models=ffn_models,
+                lmhead_model=lmhead_model,
+                tokenizer=tokenizer,
+                metadata=metadata,
+                state=state,
+                causal_mask=causal_mask,  # Pass the causal mask
+                warmup=True,
+                auto_prompt="who are you?"
+            )
 
     # Main run
     chat_loop(
@@ -803,6 +841,7 @@ def main():
        tokenizer=tokenizer,
        metadata=metadata,
        state=state,
+       causal_mask=causal_mask,  # Pass the causal mask
        warmup=False,
        auto_prompt=args.prompt
    )