anemll committed
Commit 6a310be · verified · 1 Parent(s): 716cf94

Fixed GIL issue


Fixes a race condition between CoreML inference and the causal_mask update.
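
The fix follows a simple pattern: build the causal mask once, up front, and hand the same read-only tensor to run_prefill and generate_next_token, instead of rebuilding it with make_causal_mask inside those functions while CoreML predictions may still be in flight on another thread. Below is a minimal sketch of that pattern; the body of make_causal_mask is reconstructed from the two lines visible in the diff (the -inf fill is an assumption), and the context length and slice indices are illustrative only.

import numpy as np
import torch

def make_causal_mask(length, start):
    # Additive attention mask: 0 where a position may attend, -inf where it may not.
    mask = np.full((1, 1, length, length), -np.inf, dtype=np.float16)
    row_indices = np.arange(length).reshape(length, 1)
    col_indices = np.arange(length).reshape(1, length)
    mask[:, :, col_indices <= (row_indices + start)] = 0
    return mask

def initialize_causal_mask(context_length):
    # Created once on the main thread; callers only read slices of it afterwards.
    return torch.tensor(make_causal_mask(context_length, 0), dtype=torch.float16)

causal_mask = initialize_causal_mask(512)          # once, before any inference
batch_causal_mask = causal_mask[:, :, 0:64, :]     # prefill: rows for one batch
single_causal_mask = causal_mask[:, :, 10:11, :]   # decode: row for a single position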

Files changed (1)
  1. chat_full.py +122 -95
chat_full.py CHANGED
@@ -193,6 +193,89 @@ def load_model(path, function_name=None):
         print("\nTry using the .mlpackage version instead, or recompile the model.")
         raise
 
+def parse_args():
+    parser = argparse.ArgumentParser(description='Full Chat with CoreML LLaMA with context window shifting, gil resolved (c) 2025 Anemll')
+
+    # Add meta.yaml option
+    parser.add_argument('--meta', type=str, help='Path to meta.yaml to load all parameters')
+
+    # Add existing arguments
+    parser.add_argument('--d', '--dir', type=str, default='.',
+                        help='Directory containing model files (default: current directory)')
+    parser.add_argument('--embed', type=str, required=False,
+                        help='Path to embeddings model (relative to --dir)')
+    parser.add_argument('--ffn', type=str, required=False,
+                        help='Path to FFN model (can be chunked, relative to --dir)')
+    parser.add_argument('--lmhead', type=str, required=False,
+                        help='Path to LM head model (relative to --dir)')
+    parser.add_argument('--tokenizer', type=str, required=False,
+                        help='Path to tokenizer')
+
+    # Add new argument for auto-generation
+    parser.add_argument('--prompt', type=str,
+                        help='If specified, run once with this prompt and exit')
+
+    # Add no-warmup flag
+    parser.add_argument('--nw', action='store_true',
+                        help='Skip warmup phase')
+
+    # Model configuration
+    parser.add_argument('--context-length', type=int,
+                        help='Context length for the model (default: 512), if not provided, it will be detected from the model directory name ctxNUMBER')
+    parser.add_argument('--batch-size', type=int,
+                        help='Batch size for prefill (default: 64)')
+
+    args = parser.parse_args()
+
+    # If meta.yaml is provided, load parameters from it
+    if args.meta:
+        try:
+            with open(args.meta, 'r') as f:
+                meta = yaml.safe_load(f)
+            params = meta['model_info']['parameters']
+
+            # Set model directory to meta.yaml directory if not specified
+            if not args.d or args.d == '.':
+                args.d = str(Path(args.meta).parent)
+
+            # Build model paths based on parameters
+            prefix = params.get('model_prefix', 'llama') # Default to 'llama' if not specified
+            lut_ffn = f"_lut{params['lut_ffn']}" if params['lut_ffn'] != 'none' else ''
+            lut_lmhead = f"_lut{params['lut_lmhead']}" if params['lut_lmhead'] != 'none' else ''
+            num_chunks = int(params['num_chunks'])
+
+            # Set model paths if not specified
+            if not args.embed:
+                args.embed = f'{prefix}_embeddings'
+            if not args.lmhead:
+                args.lmhead = f'{prefix}_lm_head{lut_lmhead}'
+            if not args.ffn:
+                args.ffn = f'{prefix}_FFN_PF{lut_ffn}_chunk_01of{num_chunks:02d}'
+            if not args.tokenizer:
+                args.tokenizer = args.d
+
+            # Set other parameters if not overridden by command line
+            if args.context_length is None:
+                args.context_length = int(params['context_length'])
+            if args.batch_size is None:
+                args.batch_size = int(params['batch_size'])
+            args.num_chunks = num_chunks
+
+            print(f"\nLoaded parameters from {args.meta}:")
+            print(f" Context Length: {args.context_length}")
+            print(f" Batch Size: {args.batch_size}")
+            print(f" Num Chunks: {args.num_chunks}")
+            print(f" Models Directory: {args.d}")
+            print(f" Embeddings: {args.embed}")
+            print(f" LM Head: {args.lmhead}")
+            print(f" FFN: {args.ffn}")
+
+        except Exception as e:
+            print(f"\nError loading meta.yaml: {str(e)}")
+            sys.exit(1)
+
+    return args
+
 def load_metadata(model,args):
     # Extract metadata and config parameters
     metadata = {}
@@ -248,18 +331,28 @@ def load_metadata(model,args):
     else:
         ctx_len = args.context_length
 
-    # Use defaults
+    # Use defaults or values from args
     metadata['context_length'] = ctx_len
     metadata['state_length'] = ctx_len
-    metadata['batch_size'] = 64
+    # Get batch size from args or use default
+    metadata['batch_size'] = getattr(args, 'batch_size', 64)
     metadata['lut_bits'] = 4
-    metadata['num_chunks'] = 4
-    print("\nUsing default parameters:")
+    metadata['num_chunks'] = getattr(args, 'num_chunks', 4)
+    print("\nUsing parameters:")
     print(f" Context Length: {metadata['context_length']}")
     print(f" State Length: {metadata['state_length']}")
     print(f" Prefill Batch Size: {metadata['batch_size']}")
     print(f" LUT Bits: {metadata['lut_bits']}")
    print(f" Number of Chunks: {metadata['num_chunks']}")
+
+    # Override with values from args if they exist
+    if hasattr(args, 'batch_size') and args.batch_size is not None:
+        metadata['batch_size'] = args.batch_size
+        print(f"\nOverriding batch size from args: {args.batch_size}")
+    if hasattr(args, 'num_chunks') and args.num_chunks is not None:
+        metadata['num_chunks'] = args.num_chunks
+        print(f"\nOverriding num chunks from args: {args.num_chunks}")
+
     return metadata
 
 def load_models(args,metadata):
@@ -381,7 +474,7 @@ def make_causal_mask(length, start):
     mask[:, :, col_indices <= (row_indices + start)] = 0
     return mask
 
-def run_prefill(embed_model, ffn_models, input_ids, current_pos, context_length, batch_size, state):
+def run_prefill(embed_model, ffn_models, input_ids, current_pos, context_length, batch_size, state, causal_mask):
     """Run prefill on the input sequence."""
     #print(f"[DEBUG] Running prefill from 0 to {current_pos}")
 
@@ -406,9 +499,7 @@ def run_prefill(embed_model, ffn_models, input_ids, current_pos, context_length,
         # Generate position IDs for this batch
         position_ids = torch.arange(batch_pos, batch_pos + batch_size, dtype=torch.int32)
 
-        # Create causal mask for this batch
-        causal_mask = make_causal_mask(context_length, 0) # Always start from 0 for prefill
-        causal_mask = torch.tensor(causal_mask, dtype=torch.float16)
+        # Use the pre-initialized causal mask and extract the batch portion
        batch_causal_mask = causal_mask[:, :, batch_pos:batch_pos + batch_size, :]
 
         # Run embeddings
@@ -432,7 +523,7 @@ def run_prefill(embed_model, ffn_models, input_ids, current_pos, context_length,
 
     return torch.tensor([current_pos], dtype=torch.int32)
 
-def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, context_length, state=None, temperature=0.0):
+def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, context_length, state, causal_mask, temperature=0.0):
     """Generate the next token."""
     # Get current token
     current_token = input_ids[:, pos-1:pos]
@@ -447,9 +538,8 @@ def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, c
     update_mask[0, 0, pos-1, 0] = 1.0
     position_ids = torch.tensor([pos-1], dtype=torch.int32)
 
-    # Create causal mask for current position
-    causal_mask = make_causal_mask(context_length, 0) # Always start from 0 for generation
-    single_causal_mask = torch.tensor(causal_mask[:, :, pos-1:pos, :], dtype=torch.float16)
+    # Use the pre-initialized causal mask and extract the single position portion
+    single_causal_mask = causal_mask[:, :, pos-1:pos, :]
 
     # Run through FFN chunks
     for ffn_model in ffn_models:
@@ -498,6 +588,13 @@ def create_unified_state(ffn_models, context_length):
     print("\nCreated unified transformer state")
     return state
 
+def initialize_causal_mask(context_length):
+    """Initialize causal mask for transformer attention."""
+    causal_mask = make_causal_mask(context_length, 0)
+    causal_mask = torch.tensor(causal_mask, dtype=torch.float16)
+    print(f"\nInitialized causal mask for context length {context_length}")
+    return causal_mask
+
 def get_user_input():
     """Get input from user, handling special key combinations."""
     global THINKING_MODE
@@ -558,7 +655,7 @@ def get_user_input():
         # Fallback for systems without termios
         return input("> ")
 
-def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state, auto_prompt=None, warmup=False):
+def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state, causal_mask, auto_prompt=None, warmup=False):
     """Interactive chat loop."""
     global THINKING_MODE
     context_length = metadata.get('context_length')
@@ -650,10 +747,6 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
         generation_start_time = time.time()
 
         try:
-            # Create initial causal mask
-            causal_mask = make_causal_mask(context_length, 0)
-            causal_mask = torch.tensor(causal_mask, dtype=torch.float16)
-
            # Run prefill on entire context
            current_pos = run_prefill(
                embed_model,
@@ -662,7 +755,8 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
                context_pos,
                context_length,
                batch_size,
-                state
+                state,
+                causal_mask
            )
            #print(f"\n[DEBUG] After initial prefill - current_pos: {current_pos}")
 
@@ -696,7 +790,8 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
                        new_size, # Prefill the entire shifted content
                        context_length,
                        batch_size,
-                        state
+                        state,
+                        causal_mask
                    )
 
                    # Start generating from the next position
@@ -715,7 +810,8 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
                    input_ids,
                    pos,
                    context_length,
-                    state
+                    state,
+                    causal_mask
                )
 
                # Add token
@@ -768,81 +864,7 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
         traceback.print_exc()
 
 def main():
-    parser = argparse.ArgumentParser(description='Full Chat with CoreML LLaMA with context window shifting (c) 2025 Anemll')
-
-    # Add meta.yaml option
-    parser.add_argument('--meta', type=str, help='Path to meta.yaml to load all parameters')
-
-    # Add existing arguments
-    parser.add_argument('--d', '--dir', type=str, default='.',
-                        help='Directory containing model files (default: current directory)')
-    parser.add_argument('--embed', type=str, required=False,
-                        help='Path to embeddings model (relative to --dir)')
-    parser.add_argument('--ffn', type=str, required=False,
-                        help='Path to FFN model (can be chunked, relative to --dir)')
-    parser.add_argument('--lmhead', type=str, required=False,
-                        help='Path to LM head model (relative to --dir)')
-    parser.add_argument('--tokenizer', type=str, required=False,
-                        help='Path to tokenizer')
-
-    # Add new argument for auto-generation
-    parser.add_argument('--prompt', type=str,
-                        help='If specified, run once with this prompt and exit')
-
-    # Add no-warmup flag
-    parser.add_argument('--nw', action='store_true',
-                        help='Skip warmup phase')
-
-    # Model configuration
-    parser.add_argument('--context-length', type=int,
-                        help='Context length for the model (default: 512), if not provided, it will be detected from the model directory name ctxNUMBER')
-
-    args = parser.parse_args()
-
-    # If meta.yaml is provided, load parameters from it
-    if args.meta:
-        try:
-            with open(args.meta, 'r') as f:
-                meta = yaml.safe_load(f)
-            params = meta['model_info']['parameters']
-
-            # Set model directory to meta.yaml directory if not specified
-            if not args.d or args.d == '.':
-                args.d = str(Path(args.meta).parent)
-
-            # Build model paths based on parameters
-            prefix = params.get('model_prefix', 'llama') # Default to 'llama' if not specified
-            lut_ffn = f"_lut{params['lut_ffn']}" if params['lut_ffn'] != 'none' else ''
-            lut_lmhead = f"_lut{params['lut_lmhead']}" if params['lut_lmhead'] != 'none' else ''
-            num_chunks = int(params['num_chunks'])
-
-            # Set model paths if not specified
-            if not args.embed:
-                args.embed = f'{prefix}_embeddings'
-            if not args.lmhead:
-                args.lmhead = f'{prefix}_lm_head{lut_lmhead}'
-            if not args.ffn:
-                args.ffn = f'{prefix}_FFN_PF{lut_ffn}_chunk_01of{num_chunks:02d}'
-            if not args.tokenizer:
-                args.tokenizer = args.d
-
-            # Set other parameters
-            args.context_length = int(params['context_length'])
-            args.batch_size = int(params['batch_size'])
-            args.num_chunks = num_chunks
-
-            print(f"\nLoaded parameters from {args.meta}:")
-            print(f" Context Length: {args.context_length}")
-            print(f" Batch Size: {args.batch_size}")
-            print(f" Num Chunks: {args.num_chunks}")
-            print(f" Models Directory: {args.d}")
-            print(f" Embeddings: {args.embed}")
-            print(f" LM Head: {args.lmhead}")
-            print(f" FFN: {args.ffn}")
-
-        except Exception as e:
-            print(f"\nError loading meta.yaml: {str(e)}")
-            sys.exit(1)
+    args = parse_args()
 
     # Convert directory to absolute path
     model_dir = Path(args.d).resolve()
@@ -892,6 +914,9 @@ def main():
     # Create unified state once
     state = create_unified_state(ffn_models, metadata['context_length'])
 
+    # Initialize causal mask once
+    causal_mask = initialize_causal_mask(metadata['context_length'])
+
     # Warmup runs to prevent Python GIL issues with CoreML !
     if not args.nw:
         for i in range(2):
@@ -902,6 +927,7 @@ def main():
                tokenizer=tokenizer,
                metadata=metadata,
                state=state, # Pass the state
+                causal_mask=causal_mask, # Pass the causal mask
                warmup=True,
                auto_prompt="who are you?"
            )
@@ -914,6 +940,7 @@ def main():
        tokenizer=tokenizer,
        metadata=metadata,
        state=state, # Pass the state
+        causal_mask=causal_mask, # Pass the causal mask
        warmup=False,
        auto_prompt=args.prompt
    )