Fixed GIL issue
Race condition between CoreML and causal_mask update
chat_full.py  CHANGED  (+122 -95)
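In outline, the fix replaces per-call mask construction with a single shared mask. A minimal sketch of the pattern (condensed from the hunks below; the race-condition explanation is the commit's, the CoreML timing details are not shown in this diff):

    # Before: run_prefill() and generate_next_token() rebuilt the mask on every call,
    # while CoreML could still be executing against the previous buffer.
    causal_mask = make_causal_mask(context_length, 0)
    causal_mask = torch.tensor(causal_mask, dtype=torch.float16)

    # After: build the mask once in main(), pass it down, and only take slices of it.
    causal_mask = initialize_causal_mask(metadata['context_length'])
    batch_causal_mask = causal_mask[:, :, batch_pos:batch_pos + batch_size, :]
    single_causal_mask = causal_mask[:, :, pos-1:pos, :]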
@@ -193,6 +193,89 @@ def load_model(path, function_name=None):
         print("\nTry using the .mlpackage version instead, or recompile the model.")
         raise

+def parse_args():
+    parser = argparse.ArgumentParser(description='Full Chat with CoreML LLaMA with context window shifting, gil resolved (c) 2025 Anemll')
+
+    # Add meta.yaml option
+    parser.add_argument('--meta', type=str, help='Path to meta.yaml to load all parameters')
+
+    # Add existing arguments
+    parser.add_argument('--d', '--dir', type=str, default='.',
+                        help='Directory containing model files (default: current directory)')
+    parser.add_argument('--embed', type=str, required=False,
+                        help='Path to embeddings model (relative to --dir)')
+    parser.add_argument('--ffn', type=str, required=False,
+                        help='Path to FFN model (can be chunked, relative to --dir)')
+    parser.add_argument('--lmhead', type=str, required=False,
+                        help='Path to LM head model (relative to --dir)')
+    parser.add_argument('--tokenizer', type=str, required=False,
+                        help='Path to tokenizer')
+
+    # Add new argument for auto-generation
+    parser.add_argument('--prompt', type=str,
+                        help='If specified, run once with this prompt and exit')
+
+    # Add no-warmup flag
+    parser.add_argument('--nw', action='store_true',
+                        help='Skip warmup phase')
+
+    # Model configuration
+    parser.add_argument('--context-length', type=int,
+                        help='Context length for the model (default: 512), if not provided, it will be detected from the model directory name ctxNUMBER')
+    parser.add_argument('--batch-size', type=int,
+                        help='Batch size for prefill (default: 64)')
+
+    args = parser.parse_args()
+
+    # If meta.yaml is provided, load parameters from it
+    if args.meta:
+        try:
+            with open(args.meta, 'r') as f:
+                meta = yaml.safe_load(f)
+            params = meta['model_info']['parameters']
+
+            # Set model directory to meta.yaml directory if not specified
+            if not args.d or args.d == '.':
+                args.d = str(Path(args.meta).parent)
+
+            # Build model paths based on parameters
+            prefix = params.get('model_prefix', 'llama')  # Default to 'llama' if not specified
+            lut_ffn = f"_lut{params['lut_ffn']}" if params['lut_ffn'] != 'none' else ''
+            lut_lmhead = f"_lut{params['lut_lmhead']}" if params['lut_lmhead'] != 'none' else ''
+            num_chunks = int(params['num_chunks'])
+
+            # Set model paths if not specified
+            if not args.embed:
+                args.embed = f'{prefix}_embeddings'
+            if not args.lmhead:
+                args.lmhead = f'{prefix}_lm_head{lut_lmhead}'
+            if not args.ffn:
+                args.ffn = f'{prefix}_FFN_PF{lut_ffn}_chunk_01of{num_chunks:02d}'
+            if not args.tokenizer:
+                args.tokenizer = args.d
+
+            # Set other parameters if not overridden by command line
+            if args.context_length is None:
+                args.context_length = int(params['context_length'])
+            if args.batch_size is None:
+                args.batch_size = int(params['batch_size'])
+            args.num_chunks = num_chunks
+
+            print(f"\nLoaded parameters from {args.meta}:")
+            print(f"  Context Length: {args.context_length}")
+            print(f"  Batch Size: {args.batch_size}")
+            print(f"  Num Chunks: {args.num_chunks}")
+            print(f"  Models Directory: {args.d}")
+            print(f"  Embeddings: {args.embed}")
+            print(f"  LM Head: {args.lmhead}")
+            print(f"  FFN: {args.ffn}")
+
+        except Exception as e:
+            print(f"\nError loading meta.yaml: {str(e)}")
+            sys.exit(1)
+
+    return args
+
 def load_metadata(model,args):
     # Extract metadata and config parameters
     metadata = {}
@@ -248,18 +331,28 @@ def load_metadata(model,args):
     else:
         ctx_len = args.context_length

-    # Use defaults
+    # Use defaults or values from args
     metadata['context_length'] = ctx_len
     metadata['state_length'] = ctx_len
-
+    # Get batch size from args or use default
+    metadata['batch_size'] = getattr(args, 'batch_size', 64)
     metadata['lut_bits'] = 4
-    metadata['num_chunks'] = 4
-    print("\nUsing
+    metadata['num_chunks'] = getattr(args, 'num_chunks', 4)
+    print("\nUsing parameters:")
     print(f"  Context Length: {metadata['context_length']}")
     print(f"  State Length: {metadata['state_length']}")
     print(f"  Prefill Batch Size: {metadata['batch_size']}")
     print(f"  LUT Bits: {metadata['lut_bits']}")
     print(f"  Number of Chunks: {metadata['num_chunks']}")
+
+    # Override with values from args if they exist
+    if hasattr(args, 'batch_size') and args.batch_size is not None:
+        metadata['batch_size'] = args.batch_size
+        print(f"\nOverriding batch size from args: {args.batch_size}")
+    if hasattr(args, 'num_chunks') and args.num_chunks is not None:
+        metadata['num_chunks'] = args.num_chunks
+        print(f"\nOverriding num chunks from args: {args.num_chunks}")
+
     return metadata

 def load_models(args,metadata):
@@ -381,7 +474,7 @@ def make_causal_mask(length, start):
     mask[:, :, col_indices <= (row_indices + start)] = 0
     return mask

-def run_prefill(embed_model, ffn_models, input_ids, current_pos, context_length, batch_size, state):
+def run_prefill(embed_model, ffn_models, input_ids, current_pos, context_length, batch_size, state, causal_mask):
     """Run prefill on the input sequence."""
     #print(f"[DEBUG] Running prefill from 0 to {current_pos}")

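For context, the body of make_causal_mask() above this hunk is unchanged and not shown; a plausible reconstruction consistent with the two visible lines (the numpy base, -inf fill value, and index shapes are assumptions, not part of this diff):

    import numpy as np

    def make_causal_mask(length, start):
        # Assumed additive mask of shape (1, 1, length, length): -inf where a position
        # may not attend, 0 where column <= row + start (the line shown in the hunk).
        mask = np.full((1, 1, length, length), -np.inf, dtype=np.float32)
        row_indices = np.arange(length).reshape(length, 1)
        col_indices = np.arange(length).reshape(1, length)
        mask[:, :, col_indices <= (row_indices + start)] = 0
        return mask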
@@ -406,9 +499,7 @@ def run_prefill(embed_model, ffn_models, input_ids, current_pos, context_length,
         # Generate position IDs for this batch
         position_ids = torch.arange(batch_pos, batch_pos + batch_size, dtype=torch.int32)

-        #
-        causal_mask = make_causal_mask(context_length, 0)  # Always start from 0 for prefill
-        causal_mask = torch.tensor(causal_mask, dtype=torch.float16)
+        # Use the pre-initialized causal mask and extract the batch portion
         batch_causal_mask = causal_mask[:, :, batch_pos:batch_pos + batch_size, :]

         # Run embeddings
@@ -432,7 +523,7 @@ def run_prefill(embed_model, ffn_models, input_ids, current_pos, context_length,

     return torch.tensor([current_pos], dtype=torch.int32)

-def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, context_length, state
+def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, context_length, state, causal_mask, temperature=0.0):
     """Generate the next token."""
     # Get current token
     current_token = input_ids[:, pos-1:pos]
@@ -447,9 +538,8 @@ def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, c
     update_mask[0, 0, pos-1, 0] = 1.0
     position_ids = torch.tensor([pos-1], dtype=torch.int32)

-    #
-
-    single_causal_mask = torch.tensor(causal_mask[:, :, pos-1:pos, :], dtype=torch.float16)
+    # Use the pre-initialized causal mask and extract the single position portion
+    single_causal_mask = causal_mask[:, :, pos-1:pos, :]

     # Run through FFN chunks
     for ffn_model in ffn_models:
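A note on why slicing is enough here: basic indexing on a torch tensor returns a view of the already-allocated fp16 buffer, so decode no longer constructs or rewrites mask storage while a CoreML predict may still be in flight (the timing claim is the commit's; the snippet below only demonstrates the view behaviour):

    import torch

    mask = torch.zeros((1, 1, 8, 8), dtype=torch.float16)
    row = mask[:, :, 3:4, :]         # a view into the same storage, not a copy
    mask[0, 0, 3, 0] = -1.0
    print(row[0, 0, 0, 0].item())    # -1.0: the slice tracks the shared buffer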
@@ -498,6 +588,13 @@ def create_unified_state(ffn_models, context_length):
     print("\nCreated unified transformer state")
     return state

+def initialize_causal_mask(context_length):
+    """Initialize causal mask for transformer attention."""
+    causal_mask = make_causal_mask(context_length, 0)
+    causal_mask = torch.tensor(causal_mask, dtype=torch.float16)
+    print(f"\nInitialized causal mask for context length {context_length}")
+    return causal_mask
+
 def get_user_input():
     """Get input from user, handling special key combinations."""
     global THINKING_MODE
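Taken together with the hunks below, the wiring ends up as follows (a condensed sketch using only names that appear in this diff; argument order follows the new signatures):

    # In main(): build shared resources once.
    state = create_unified_state(ffn_models, metadata['context_length'])
    causal_mask = initialize_causal_mask(metadata['context_length'])

    # chat_loop() receives both and forwards the mask unchanged.
    current_pos = run_prefill(embed_model, ffn_models, input_ids, context_pos,
                              context_length, batch_size, state, causal_mask)
    next_token = generate_next_token(embed_model, ffn_models, lmhead_model,
                                     input_ids, pos, context_length, state, causal_mask)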
@@ -558,7 +655,7 @@ def get_user_input():
         # Fallback for systems without termios
         return input("> ")

-def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state, auto_prompt=None, warmup=False):
+def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state, causal_mask, auto_prompt=None, warmup=False):
     """Interactive chat loop."""
     global THINKING_MODE
     context_length = metadata.get('context_length')
@@ -650,10 +747,6 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
             generation_start_time = time.time()

             try:
-                # Create initial causal mask
-                causal_mask = make_causal_mask(context_length, 0)
-                causal_mask = torch.tensor(causal_mask, dtype=torch.float16)
-
                 # Run prefill on entire context
                 current_pos = run_prefill(
                     embed_model,
@@ -662,7 +755,8 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
                     context_pos,
                     context_length,
                     batch_size,
-                    state
+                    state,
+                    causal_mask
                 )
                 #print(f"\n[DEBUG] After initial prefill - current_pos: {current_pos}")

@@ -696,7 +790,8 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
                         new_size,  # Prefill the entire shifted content
                         context_length,
                         batch_size,
-                        state
+                        state,
+                        causal_mask
                     )

                     # Start generating from the next position
@@ -715,7 +810,8 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
                     input_ids,
                     pos,
                     context_length,
-                    state
+                    state,
+                    causal_mask
                 )

                 # Add token
@@ -768,81 +864,7 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
         traceback.print_exc()

 def main():
-
-
-    # Add meta.yaml option
-    parser.add_argument('--meta', type=str, help='Path to meta.yaml to load all parameters')
-
-    # Add existing arguments
-    parser.add_argument('--d', '--dir', type=str, default='.',
-                        help='Directory containing model files (default: current directory)')
-    parser.add_argument('--embed', type=str, required=False,
-                        help='Path to embeddings model (relative to --dir)')
-    parser.add_argument('--ffn', type=str, required=False,
-                        help='Path to FFN model (can be chunked, relative to --dir)')
-    parser.add_argument('--lmhead', type=str, required=False,
-                        help='Path to LM head model (relative to --dir)')
-    parser.add_argument('--tokenizer', type=str, required=False,
-                        help='Path to tokenizer')
-
-    # Add new argument for auto-generation
-    parser.add_argument('--prompt', type=str,
-                        help='If specified, run once with this prompt and exit')
-
-    # Add no-warmup flag
-    parser.add_argument('--nw', action='store_true',
-                        help='Skip warmup phase')
-
-    # Model configuration
-    parser.add_argument('--context-length', type=int,
-                        help='Context length for the model (default: 512), if not provided, it will be detected from the model directory name ctxNUMBER')
-
-    args = parser.parse_args()
-
-    # If meta.yaml is provided, load parameters from it
-    if args.meta:
-        try:
-            with open(args.meta, 'r') as f:
-                meta = yaml.safe_load(f)
-            params = meta['model_info']['parameters']
-
-            # Set model directory to meta.yaml directory if not specified
-            if not args.d or args.d == '.':
-                args.d = str(Path(args.meta).parent)
-
-            # Build model paths based on parameters
-            prefix = params.get('model_prefix', 'llama')  # Default to 'llama' if not specified
-            lut_ffn = f"_lut{params['lut_ffn']}" if params['lut_ffn'] != 'none' else ''
-            lut_lmhead = f"_lut{params['lut_lmhead']}" if params['lut_lmhead'] != 'none' else ''
-            num_chunks = int(params['num_chunks'])
-
-            # Set model paths if not specified
-            if not args.embed:
-                args.embed = f'{prefix}_embeddings'
-            if not args.lmhead:
-                args.lmhead = f'{prefix}_lm_head{lut_lmhead}'
-            if not args.ffn:
-                args.ffn = f'{prefix}_FFN_PF{lut_ffn}_chunk_01of{num_chunks:02d}'
-            if not args.tokenizer:
-                args.tokenizer = args.d
-
-            # Set other parameters
-            args.context_length = int(params['context_length'])
-            args.batch_size = int(params['batch_size'])
-            args.num_chunks = num_chunks
-
-            print(f"\nLoaded parameters from {args.meta}:")
-            print(f"  Context Length: {args.context_length}")
-            print(f"  Batch Size: {args.batch_size}")
-            print(f"  Num Chunks: {args.num_chunks}")
-            print(f"  Models Directory: {args.d}")
-            print(f"  Embeddings: {args.embed}")
-            print(f"  LM Head: {args.lmhead}")
-            print(f"  FFN: {args.ffn}")
-
-        except Exception as e:
-            print(f"\nError loading meta.yaml: {str(e)}")
-            sys.exit(1)
+    args = parse_args()

     # Convert directory to absolute path
     model_dir = Path(args.d).resolve()
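For reference, parse_args() expects meta.yaml to resolve to the structure below once yaml.safe_load() has run; the keys are the ones read in the hunk above, the values are purely illustrative:

    # Hypothetical meta.yaml contents after yaml.safe_load(); values are examples only.
    meta = {
        'model_info': {
            'parameters': {
                'model_prefix': 'llama',   # optional, defaults to 'llama'
                'context_length': 512,
                'batch_size': 64,
                'lut_ffn': 4,              # or 'none' to omit the _lutN suffix
                'lut_lmhead': 6,
                'num_chunks': 2,
            }
        }
    }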
@@ -892,6 +914,9 @@ def main():
     # Create unified state once
     state = create_unified_state(ffn_models, metadata['context_length'])

+    # Initialize causal mask once
+    causal_mask = initialize_causal_mask(metadata['context_length'])
+
     # Warmup runs to prevent Python GIL issues with CoreML !
     if not args.nw:
         for i in range(2):
@@ -902,6 +927,7 @@ def main():
                 tokenizer=tokenizer,
                 metadata=metadata,
                 state=state,  # Pass the state
+                causal_mask=causal_mask,  # Pass the causal mask
                 warmup=True,
                 auto_prompt="who are you?"
             )
@@ -914,6 +940,7 @@ def main():
             tokenizer=tokenizer,
             metadata=metadata,
             state=state,  # Pass the state
+            causal_mask=causal_mask,  # Pass the causal mask
             warmup=False,
             auto_prompt=args.prompt
         )