anemll committed
Commit 3f3158f (verified)
Parent: 3315517

Fixed GIL issue


Fixes a race condition between CoreML execution and the causal_mask update: the mask is now initialized once in main() and passed into run_prefill() and generate_next_token() instead of being rebuilt on every call.
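The change is easiest to see in isolation: the causal mask is built once, up front, and each prefill/decode step only takes a read-only slice of that single tensor, rather than re-allocating the mask inside the call while a CoreML prediction may still be reading it. A paraphrased, standalone sketch of that pattern (simplified helpers, not the exact chat_full.py code):

import torch

def make_causal_mask(length: int, start: int) -> torch.Tensor:
    """Additive causal mask: 0 where attention is allowed, -inf elsewhere."""
    mask = torch.full((1, 1, length, length), float("-inf"), dtype=torch.float16)
    row = torch.arange(length).unsqueeze(1)
    col = torch.arange(length).unsqueeze(0)
    mask[:, :, col <= (row + start)] = 0.0
    return mask

def initialize_causal_mask(context_length: int) -> torch.Tensor:
    # Built once on the main thread, before any warmup or chat loop runs.
    return make_causal_mask(context_length, 0)

def generation_step_mask(causal_mask: torch.Tensor, pos: int) -> torch.Tensor:
    # Each decode step only slices the shared tensor; nothing is rebuilt or
    # mutated, so a concurrent CoreML predict never sees a half-built mask.
    return causal_mask[:, :, pos - 1:pos, :]

if __name__ == "__main__":
    mask = initialize_causal_mask(8)
    print(generation_step_mask(mask, pos=3).shape)  # torch.Size([1, 1, 1, 8])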

Files changed (1)
chat_full.py  +219 -116
chat_full.py CHANGED
@@ -28,6 +28,8 @@ RESET_COLOR = "\033[0m"
 
 # Add at the top with other constants
 WARMUP_TOKEN_LIMIT = 10  # Maximum tokens to generate during warmup
+THINKING_MODE = False
+THINKING_PROMPT = """You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering. You should enclose your thoughts and internal monologue inside <think> </think> tags, and then provide your solution or response to the problem."""
 
 class TokenPrinter:
     """Handles background printing of generated tokens."""
@@ -191,6 +193,89 @@ def load_model(path, function_name=None):
         print("\nTry using the .mlpackage version instead, or recompile the model.")
         raise
 
+def parse_args():
+    parser = argparse.ArgumentParser(description='Full Chat with CoreML LLaMA with context window shifting, gil resolved (c) 2025 Anemll')
+
+    # Add meta.yaml option
+    parser.add_argument('--meta', type=str, help='Path to meta.yaml to load all parameters')
+
+    # Add existing arguments
+    parser.add_argument('--d', '--dir', type=str, default='.',
+                        help='Directory containing model files (default: current directory)')
+    parser.add_argument('--embed', type=str, required=False,
+                        help='Path to embeddings model (relative to --dir)')
+    parser.add_argument('--ffn', type=str, required=False,
+                        help='Path to FFN model (can be chunked, relative to --dir)')
+    parser.add_argument('--lmhead', type=str, required=False,
+                        help='Path to LM head model (relative to --dir)')
+    parser.add_argument('--tokenizer', type=str, required=False,
+                        help='Path to tokenizer')
+
+    # Add new argument for auto-generation
+    parser.add_argument('--prompt', type=str,
+                        help='If specified, run once with this prompt and exit')
+
+    # Add no-warmup flag
+    parser.add_argument('--nw', action='store_true',
+                        help='Skip warmup phase')
+
+    # Model configuration
+    parser.add_argument('--context-length', type=int,
+                        help='Context length for the model (default: 512), if not provided, it will be detected from the model directory name ctxNUMBER')
+    parser.add_argument('--batch-size', type=int,
+                        help='Batch size for prefill (default: 64)')
+
+    args = parser.parse_args()
+
+    # If meta.yaml is provided, load parameters from it
+    if args.meta:
+        try:
+            with open(args.meta, 'r') as f:
+                meta = yaml.safe_load(f)
+            params = meta['model_info']['parameters']
+
+            # Set model directory to meta.yaml directory if not specified
+            if not args.d or args.d == '.':
+                args.d = str(Path(args.meta).parent)
+
+            # Build model paths based on parameters
+            prefix = params.get('model_prefix', 'llama')  # Default to 'llama' if not specified
+            lut_ffn = f"_lut{params['lut_ffn']}" if params['lut_ffn'] != 'none' else ''
+            lut_lmhead = f"_lut{params['lut_lmhead']}" if params['lut_lmhead'] != 'none' else ''
+            num_chunks = int(params['num_chunks'])
+
+            # Set model paths if not specified
+            if not args.embed:
+                args.embed = f'{prefix}_embeddings'
+            if not args.lmhead:
+                args.lmhead = f'{prefix}_lm_head{lut_lmhead}'
+            if not args.ffn:
+                args.ffn = f'{prefix}_FFN_PF{lut_ffn}_chunk_01of{num_chunks:02d}'
+            if not args.tokenizer:
+                args.tokenizer = args.d
+
+            # Set other parameters if not overridden by command line
+            if args.context_length is None:
+                args.context_length = int(params['context_length'])
+            if args.batch_size is None:
+                args.batch_size = int(params['batch_size'])
+            args.num_chunks = num_chunks
+
+            print(f"\nLoaded parameters from {args.meta}:")
+            print(f"  Context Length: {args.context_length}")
+            print(f"  Batch Size: {args.batch_size}")
+            print(f"  Num Chunks: {args.num_chunks}")
+            print(f"  Models Directory: {args.d}")
+            print(f"  Embeddings: {args.embed}")
+            print(f"  LM Head: {args.lmhead}")
+            print(f"  FFN: {args.ffn}")
+
+        except Exception as e:
+            print(f"\nError loading meta.yaml: {str(e)}")
+            sys.exit(1)
+
+    return args
+
 def load_metadata(model,args):
     # Extract metadata and config parameters
     metadata = {}
@@ -246,18 +331,28 @@ def load_metadata(model,args):
     else:
         ctx_len = args.context_length
 
-    # Use defaults
+    # Use defaults or values from args
     metadata['context_length'] = ctx_len
     metadata['state_length'] = ctx_len
-    metadata['batch_size'] = 64
+    # Get batch size from args or use default
+    metadata['batch_size'] = getattr(args, 'batch_size', 64)
     metadata['lut_bits'] = 4
-    metadata['num_chunks'] = 4
-    print("\nUsing default parameters:")
+    metadata['num_chunks'] = getattr(args, 'num_chunks', 4)
+    print("\nUsing parameters:")
     print(f"  Context Length: {metadata['context_length']}")
     print(f"  State Length: {metadata['state_length']}")
     print(f"  Prefill Batch Size: {metadata['batch_size']}")
     print(f"  LUT Bits: {metadata['lut_bits']}")
     print(f"  Number of Chunks: {metadata['num_chunks']}")
+
+    # Override with values from args if they exist
+    if hasattr(args, 'batch_size') and args.batch_size is not None:
+        metadata['batch_size'] = args.batch_size
+        print(f"\nOverriding batch size from args: {args.batch_size}")
+    if hasattr(args, 'num_chunks') and args.num_chunks is not None:
+        metadata['num_chunks'] = args.num_chunks
+        print(f"\nOverriding num chunks from args: {args.num_chunks}")
+
     return metadata
 
 def load_models(args,metadata):
@@ -379,7 +474,7 @@ def make_causal_mask(length, start):
     mask[:, :, col_indices <= (row_indices + start)] = 0
     return mask
 
-def run_prefill(embed_model, ffn_models, input_ids, current_pos, context_length, batch_size, state):
+def run_prefill(embed_model, ffn_models, input_ids, current_pos, context_length, batch_size, state, causal_mask):
     """Run prefill on the input sequence."""
     #print(f"[DEBUG] Running prefill from 0 to {current_pos}")
 
@@ -404,9 +499,7 @@ def run_prefill(embed_model, ffn_models, input_ids, current_pos, context_length,
         # Generate position IDs for this batch
         position_ids = torch.arange(batch_pos, batch_pos + batch_size, dtype=torch.int32)
 
-        # Create causal mask for this batch
-        causal_mask = make_causal_mask(context_length, 0)  # Always start from 0 for prefill
-        causal_mask = torch.tensor(causal_mask, dtype=torch.float16)
+        # Use the pre-initialized causal mask and extract the batch portion
        batch_causal_mask = causal_mask[:, :, batch_pos:batch_pos + batch_size, :]
 
         # Run embeddings
@@ -430,7 +523,7 @@ def run_prefill(embed_model, ffn_models, input_ids, current_pos, context_length,
 
     return torch.tensor([current_pos], dtype=torch.int32)
 
-def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, context_length, state=None, temperature=0.0):
+def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, context_length, state, causal_mask, temperature=0.0):
     """Generate the next token."""
     # Get current token
     current_token = input_ids[:, pos-1:pos]
@@ -445,9 +538,8 @@ def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, c
     update_mask[0, 0, pos-1, 0] = 1.0
     position_ids = torch.tensor([pos-1], dtype=torch.int32)
 
-    # Create causal mask for current position
-    causal_mask = make_causal_mask(context_length, 0)  # Always start from 0 for generation
-    single_causal_mask = torch.tensor(causal_mask[:, :, pos-1:pos, :], dtype=torch.float16)
+    # Use the pre-initialized causal mask and extract the single position portion
+    single_causal_mask = causal_mask[:, :, pos-1:pos, :]
 
     # Run through FFN chunks
     for ffn_model in ffn_models:
@@ -496,23 +588,84 @@ def create_unified_state(ffn_models, context_length):
     print("\nCreated unified transformer state")
     return state
 
+def initialize_causal_mask(context_length):
+    """Initialize causal mask for transformer attention."""
+    causal_mask = make_causal_mask(context_length, 0)
+    causal_mask = torch.tensor(causal_mask, dtype=torch.float16)
+    print(f"\nInitialized causal mask for context length {context_length}")
+    return causal_mask
+
 def get_user_input():
-    sys.stdout.write(f"\n{LIGHT_GREEN}You:{RESET_COLOR} ")
-    sys.stdout.flush()
-    line = sys.stdin.readline()
-    if not line:
-        raise EOFError
-    return line.rstrip('\n')
-
-def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state, auto_prompt=None, warmup=False):
+    """Get input from user, handling special key combinations."""
+    global THINKING_MODE
+    try:
+        import termios
+        import tty
+        import sys
+
+        def _getch():
+            fd = sys.stdin.fileno()
+            old_settings = termios.tcgetattr(fd)
+            try:
+                tty.setraw(sys.stdin.fileno())
+                ch = sys.stdin.read(1)
+            finally:
+                termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
+            return ch
+
+        buffer = []
+        while True:
+            char = _getch()
+
+            # Debug: print the character code
+            print(f"\nKey pressed: {repr(char)} (hex: {hex(ord(char))})")
+
+            # Check for Enter key
+            if char == '\r' or char == '\n':
+                print()  # Move to next line
+                input_text = ''.join(buffer)
+                # Check if the command is /t
+                if input_text == '/t':
+                    THINKING_MODE = not THINKING_MODE
+                    print(f"Thinking mode {'ON' if THINKING_MODE else 'OFF'}")
+                    buffer = []  # Clear buffer
+                    print(f"\n{LIGHT_GREEN}You{' (thinking)' if THINKING_MODE else ''}:{RESET_COLOR}", end=' ', flush=True)
+                    continue
+                return input_text
+
+            # Handle backspace
+            if char == '\x7f':  # backspace
+                if buffer:
+                    buffer.pop()
+                    sys.stdout.write('\b \b')  # Erase character
+                    sys.stdout.flush()
+                continue
+
+            # Handle Ctrl-C
+            if char == '\x03':  # Ctrl-C
+                print("^C")
+                raise KeyboardInterrupt
+
+            # Print character and add to buffer
+            sys.stdout.write(char)
+            sys.stdout.flush()
+            buffer.append(char)
+
+    except ImportError:
+        # Fallback for systems without termios
+        return input("> ")
+
+def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state, causal_mask, auto_prompt=None, warmup=False):
     """Interactive chat loop."""
+    global THINKING_MODE
     context_length = metadata.get('context_length')
     batch_size = metadata.get('batch_size', 64)
 
     if not warmup:
         print(f"\nUsing context length: {context_length}")
         print("\nStarting chat session. Press Ctrl+D to exit.")
-        print("Type your message and press Enter to chat.")
+        print("Type your message and press Enter to chat. Use /t to toggle thinking mode.")
+        print(f"Thinking mode is {'ON' if THINKING_MODE else 'OFF'}")
 
     # Keep track of conversation history
     conversation = []
@@ -521,7 +674,7 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
     while True:
         try:
             if not warmup:
-                print(f"\n{LIGHT_GREEN}You:{RESET_COLOR}", end=' ', flush=True)
+                print(f"\n{LIGHT_GREEN}You{' (thinking)' if THINKING_MODE else ''}:{RESET_COLOR}", end=' ', flush=True)
             if auto_prompt is not None:
                 user_input = auto_prompt
                 if not warmup:
@@ -535,16 +688,31 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
 
             if not user_input:
                 continue
+
+            # Handle /t command
+            if user_input == "/t":
+                THINKING_MODE = not THINKING_MODE
+                print(f"Thinking mode {'ON' if THINKING_MODE else 'OFF'}")
+                continue
 
             # Add user message to conversation
             conversation.append({"role": "user", "content": user_input})
 
             # Format using chat template with full history
-            base_input_ids = tokenizer.apply_chat_template(
-                conversation,
-                return_tensors="pt",
-                add_generation_prompt=True
-            ).to(torch.int32)
+            if THINKING_MODE:
+                # Add thinking prompt to system message
+                conversation_with_thinking = [{"role": "system", "content": THINKING_PROMPT}] + conversation
+                base_input_ids = tokenizer.apply_chat_template(
+                    conversation_with_thinking,
+                    return_tensors="pt",
+                    add_generation_prompt=True
+                ).to(torch.int32)
+            else:
+                base_input_ids = tokenizer.apply_chat_template(
+                    conversation,
+                    return_tensors="pt",
+                    add_generation_prompt=True
+                ).to(torch.int32)
 
             # Check if we need to trim history
             while base_input_ids.size(1) > context_length - 100:  # Leave room for response
@@ -579,10 +747,6 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
             generation_start_time = time.time()
 
             try:
-                # Create initial causal mask
-                causal_mask = make_causal_mask(context_length, 0)
-                causal_mask = torch.tensor(causal_mask, dtype=torch.float16)
-
                 # Run prefill on entire context
                 current_pos = run_prefill(
                     embed_model,
@@ -591,7 +755,8 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
                     context_pos,
                     context_length,
                     batch_size,
-                    state
+                    state,
+                    causal_mask
                 )
                 #print(f"\n[DEBUG] After initial prefill - current_pos: {current_pos}")
 
@@ -625,7 +790,8 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
                         new_size,  # Prefill the entire shifted content
                         context_length,
                         batch_size,
-                        state
+                        state,
+                        causal_mask
                     )
 
                     # Start generating from the next position
@@ -644,7 +810,8 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
                         input_ids,
                         pos,
                         context_length,
-                        state
+                        state,
+                        causal_mask
                    )
 
                    # Add token
@@ -697,77 +864,7 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
             traceback.print_exc()
 
 def main():
-    parser = argparse.ArgumentParser(description='Full Chat with CoreML LLaMA with context window shifting (c) 2025 Anemll')
-
-    # Add meta.yaml option
-    parser.add_argument('--meta', type=str, help='Path to meta.yaml to load all parameters')
-
-    # Add existing arguments
-    parser.add_argument('--d', '--dir', type=str, default='.',
-                        help='Directory containing model files (default: current directory)')
-    parser.add_argument('--embed', type=str, required=False,
-                        help='Path to embeddings model (relative to --dir)')
-    parser.add_argument('--ffn', type=str, required=False,
-                        help='Path to FFN model (can be chunked, relative to --dir)')
-    parser.add_argument('--lmhead', type=str, required=False,
-                        help='Path to LM head model (relative to --dir)')
-    parser.add_argument('--tokenizer', type=str, required=False,
-                        help='Path to tokenizer')
-
-    # Add new argument for auto-generation
-    parser.add_argument('--prompt', type=str,
-                        help='If specified, run once with this prompt and exit')
-
-    # Model configuration
-    parser.add_argument('--context-length', type=int,
-                        help='Context length for the model (default: 512), if not provided, it will be detected from the model directory name ctxNUMBER')
-
-    args = parser.parse_args()
-
-    # If meta.yaml is provided, load parameters from it
-    if args.meta:
-        try:
-            with open(args.meta, 'r') as f:
-                meta = yaml.safe_load(f)
-            params = meta['model_info']['parameters']
-
-            # Set model directory to meta.yaml directory if not specified
-            if not args.d or args.d == '.':
-                args.d = str(Path(args.meta).parent)
-
-            # Build model paths based on parameters
-            prefix = params.get('model_prefix', 'llama')  # Default to 'llama' if not specified
-            lut_ffn = f"_lut{params['lut_ffn']}" if params['lut_ffn'] != 'none' else ''
-            lut_lmhead = f"_lut{params['lut_lmhead']}" if params['lut_lmhead'] != 'none' else ''
-            num_chunks = int(params['num_chunks'])
-
-            # Set model paths if not specified
-            if not args.embed:
-                args.embed = f'{prefix}_embeddings'
-            if not args.lmhead:
-                args.lmhead = f'{prefix}_lm_head{lut_lmhead}'
-            if not args.ffn:
-                args.ffn = f'{prefix}_FFN_PF{lut_ffn}_chunk_01of{num_chunks:02d}'
-            if not args.tokenizer:
-                args.tokenizer = args.d
-
-            # Set other parameters
-            args.context_length = int(params['context_length'])
-            args.batch_size = int(params['batch_size'])
-            args.num_chunks = num_chunks
-
-            print(f"\nLoaded parameters from {args.meta}:")
-            print(f"  Context Length: {args.context_length}")
-            print(f"  Batch Size: {args.batch_size}")
-            print(f"  Num Chunks: {args.num_chunks}")
-            print(f"  Models Directory: {args.d}")
-            print(f"  Embeddings: {args.embed}")
-            print(f"  LM Head: {args.lmhead}")
-            print(f"  FFN: {args.ffn}")
-
-        except Exception as e:
-            print(f"\nError loading meta.yaml: {str(e)}")
-            sys.exit(1)
+    args = parse_args()
 
     # Convert directory to absolute path
     model_dir = Path(args.d).resolve()
@@ -817,18 +914,23 @@ def main():
     # Create unified state once
     state = create_unified_state(ffn_models, metadata['context_length'])
 
+    # Initialize causal mask once
+    causal_mask = initialize_causal_mask(metadata['context_length'])
+
     # Warmup runs to prevent Python GIL issues with CoreML !
-    for i in range(2):
-        chat_loop(
-            embed_model=embed_model,
-            ffn_models=ffn_models,
-            lmhead_model=lmhead_model,
-            tokenizer=tokenizer,
-            metadata=metadata,
-            state=state,  # Pass the state
-            warmup=True,
-            auto_prompt="who are you?"
-        )
+    if not args.nw:
+        for i in range(2):
+            chat_loop(
+                embed_model=embed_model,
+                ffn_models=ffn_models,
+                lmhead_model=lmhead_model,
+                tokenizer=tokenizer,
+                metadata=metadata,
+                state=state,  # Pass the state
+                causal_mask=causal_mask,  # Pass the causal mask
+                warmup=True,
+                auto_prompt="who are you?"
+            )
 
     # Main run
     chat_loop(
@@ -838,6 +940,7 @@ def main():
        tokenizer=tokenizer,
        metadata=metadata,
        state=state,  # Pass the state
+        causal_mask=causal_mask,  # Pass the causal mask
        warmup=False,
        auto_prompt=args.prompt
    )