Fixed GIL issue
Race condition between CoreML and causal_mask update.
- chat_full.py +219 -116
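
The fix moves causal-mask construction out of the per-call path: run_prefill() and generate_next_token() no longer rebuild the FP16 mask while CoreML predictions are in flight. Instead, main() builds the mask once via the new initialize_causal_mask() helper and passes it through chat_loop() down to both functions, which only take read-only slices of it. Below is a minimal standalone sketch of that pattern; prefill_mask_slice and decode_mask_slice are hypothetical names for the slicing now done inside run_prefill() and generate_next_token(), and the mask construction shown is an illustration rather than the exact make_causal_mask() from chat_full.py.

import torch

def make_causal_mask(length, start):
    # Illustrative re-implementation; chat_full.py defines its own version.
    mask = torch.full((1, 1, length, length), float("-inf"))
    rows = torch.arange(length).view(-1, 1)
    cols = torch.arange(length).view(1, -1)
    mask[:, :, cols <= (rows + start)] = 0  # allow attention to past and current positions
    return mask

def initialize_causal_mask(context_length):
    # Built once, before any prefill/decode call (mirrors the new helper in the diff).
    return make_causal_mask(context_length, 0).to(torch.float16)

def prefill_mask_slice(causal_mask, batch_pos, batch_size):
    # Hypothetical helper: the per-batch slice run_prefill() now takes from the shared mask.
    return causal_mask[:, :, batch_pos:batch_pos + batch_size, :]

def decode_mask_slice(causal_mask, pos):
    # Hypothetical helper: the single-row slice generate_next_token() now takes.
    return causal_mask[:, :, pos - 1:pos, :]

if __name__ == "__main__":
    mask = initialize_causal_mask(512)            # once, in main()
    print(prefill_mask_slice(mask, 0, 64).shape)  # torch.Size([1, 1, 64, 512])
    print(decode_mask_slice(mask, 10).shape)      # torch.Size([1, 1, 1, 512])

Because the shared mask is created once before any prediction runs, the warmup passes and the main chat loop all reuse the same tensor instead of re-creating it per token or per prefill batch.
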
chat_full.py
CHANGED
@@ -28,6 +28,8 @@ RESET_COLOR = "\033[0m"
 
 # Add at the top with other constants
 WARMUP_TOKEN_LIMIT = 10 # Maximum tokens to generate during warmup
+THINKING_MODE = False
+THINKING_PROMPT = """You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering. You should enclose your thoughts and internal monologue inside <think> </think> tags, and then provide your solution or response to the problem."""
 
 class TokenPrinter:
     """Handles background printing of generated tokens."""
@@ -191,6 +193,89 @@ def load_model(path, function_name=None):
         print("\nTry using the .mlpackage version instead, or recompile the model.")
         raise
 
+def parse_args():
+    parser = argparse.ArgumentParser(description='Full Chat with CoreML LLaMA with context window shifting, gil resolved (c) 2025 Anemll')
+
+    # Add meta.yaml option
+    parser.add_argument('--meta', type=str, help='Path to meta.yaml to load all parameters')
+
+    # Add existing arguments
+    parser.add_argument('--d', '--dir', type=str, default='.',
+                        help='Directory containing model files (default: current directory)')
+    parser.add_argument('--embed', type=str, required=False,
+                        help='Path to embeddings model (relative to --dir)')
+    parser.add_argument('--ffn', type=str, required=False,
+                        help='Path to FFN model (can be chunked, relative to --dir)')
+    parser.add_argument('--lmhead', type=str, required=False,
+                        help='Path to LM head model (relative to --dir)')
+    parser.add_argument('--tokenizer', type=str, required=False,
+                        help='Path to tokenizer')
+
+    # Add new argument for auto-generation
+    parser.add_argument('--prompt', type=str,
+                        help='If specified, run once with this prompt and exit')
+
+    # Add no-warmup flag
+    parser.add_argument('--nw', action='store_true',
+                        help='Skip warmup phase')
+
+    # Model configuration
+    parser.add_argument('--context-length', type=int,
+                        help='Context length for the model (default: 512), if not provided, it will be detected from the model directory name ctxNUMBER')
+    parser.add_argument('--batch-size', type=int,
+                        help='Batch size for prefill (default: 64)')
+
+    args = parser.parse_args()
+
+    # If meta.yaml is provided, load parameters from it
+    if args.meta:
+        try:
+            with open(args.meta, 'r') as f:
+                meta = yaml.safe_load(f)
+            params = meta['model_info']['parameters']
+
+            # Set model directory to meta.yaml directory if not specified
+            if not args.d or args.d == '.':
+                args.d = str(Path(args.meta).parent)
+
+            # Build model paths based on parameters
+            prefix = params.get('model_prefix', 'llama') # Default to 'llama' if not specified
+            lut_ffn = f"_lut{params['lut_ffn']}" if params['lut_ffn'] != 'none' else ''
+            lut_lmhead = f"_lut{params['lut_lmhead']}" if params['lut_lmhead'] != 'none' else ''
+            num_chunks = int(params['num_chunks'])
+
+            # Set model paths if not specified
+            if not args.embed:
+                args.embed = f'{prefix}_embeddings'
+            if not args.lmhead:
+                args.lmhead = f'{prefix}_lm_head{lut_lmhead}'
+            if not args.ffn:
+                args.ffn = f'{prefix}_FFN_PF{lut_ffn}_chunk_01of{num_chunks:02d}'
+            if not args.tokenizer:
+                args.tokenizer = args.d
+
+            # Set other parameters if not overridden by command line
+            if args.context_length is None:
+                args.context_length = int(params['context_length'])
+            if args.batch_size is None:
+                args.batch_size = int(params['batch_size'])
+            args.num_chunks = num_chunks
+
+            print(f"\nLoaded parameters from {args.meta}:")
+            print(f" Context Length: {args.context_length}")
+            print(f" Batch Size: {args.batch_size}")
+            print(f" Num Chunks: {args.num_chunks}")
+            print(f" Models Directory: {args.d}")
+            print(f" Embeddings: {args.embed}")
+            print(f" LM Head: {args.lmhead}")
+            print(f" FFN: {args.ffn}")
+
+        except Exception as e:
+            print(f"\nError loading meta.yaml: {str(e)}")
+            sys.exit(1)
+
+    return args
+
 def load_metadata(model,args):
     # Extract metadata and config parameters
     metadata = {}
@@ -246,18 +331,28 @@ def load_metadata(model,args):
     else:
         ctx_len = args.context_length
 
-    # Use defaults
+    # Use defaults or values from args
     metadata['context_length'] = ctx_len
     metadata['state_length'] = ctx_len
-
+    # Get batch size from args or use default
+    metadata['batch_size'] = getattr(args, 'batch_size', 64)
     metadata['lut_bits'] = 4
-    metadata['num_chunks'] = 4
-    print("\nUsing
+    metadata['num_chunks'] = getattr(args, 'num_chunks', 4)
+    print("\nUsing parameters:")
     print(f" Context Length: {metadata['context_length']}")
     print(f" State Length: {metadata['state_length']}")
    print(f" Prefill Batch Size: {metadata['batch_size']}")
     print(f" LUT Bits: {metadata['lut_bits']}")
     print(f" Number of Chunks: {metadata['num_chunks']}")
+
+    # Override with values from args if they exist
+    if hasattr(args, 'batch_size') and args.batch_size is not None:
+        metadata['batch_size'] = args.batch_size
+        print(f"\nOverriding batch size from args: {args.batch_size}")
+    if hasattr(args, 'num_chunks') and args.num_chunks is not None:
+        metadata['num_chunks'] = args.num_chunks
+        print(f"\nOverriding num chunks from args: {args.num_chunks}")
+
     return metadata
 
 def load_models(args,metadata):
@@ -379,7 +474,7 @@ def make_causal_mask(length, start):
     mask[:, :, col_indices <= (row_indices + start)] = 0
     return mask
 
-def run_prefill(embed_model, ffn_models, input_ids, current_pos, context_length, batch_size, state):
+def run_prefill(embed_model, ffn_models, input_ids, current_pos, context_length, batch_size, state, causal_mask):
     """Run prefill on the input sequence."""
     #print(f"[DEBUG] Running prefill from 0 to {current_pos}")
 
@@ -404,9 +499,7 @@ def run_prefill(embed_model, ffn_models, input_ids, current_pos, context_length,
         # Generate position IDs for this batch
         position_ids = torch.arange(batch_pos, batch_pos + batch_size, dtype=torch.int32)
 
-        #
-        causal_mask = make_causal_mask(context_length, 0) # Always start from 0 for prefill
-        causal_mask = torch.tensor(causal_mask, dtype=torch.float16)
+        # Use the pre-initialized causal mask and extract the batch portion
         batch_causal_mask = causal_mask[:, :, batch_pos:batch_pos + batch_size, :]
 
         # Run embeddings
@@ -430,7 +523,7 @@ def run_prefill(embed_model, ffn_models, input_ids, current_pos, context_length,
 
     return torch.tensor([current_pos], dtype=torch.int32)
 
-def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, context_length, state
+def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, context_length, state, causal_mask, temperature=0.0):
     """Generate the next token."""
     # Get current token
     current_token = input_ids[:, pos-1:pos]
@@ -445,9 +538,8 @@ def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, c
     update_mask[0, 0, pos-1, 0] = 1.0
     position_ids = torch.tensor([pos-1], dtype=torch.int32)
 
-    #
-
-    single_causal_mask = torch.tensor(causal_mask[:, :, pos-1:pos, :], dtype=torch.float16)
+    # Use the pre-initialized causal mask and extract the single position portion
+    single_causal_mask = causal_mask[:, :, pos-1:pos, :]
 
     # Run through FFN chunks
     for ffn_model in ffn_models:
@@ -496,23 +588,84 @@ def create_unified_state(ffn_models, context_length):
     print("\nCreated unified transformer state")
     return state
 
+def initialize_causal_mask(context_length):
+    """Initialize causal mask for transformer attention."""
+    causal_mask = make_causal_mask(context_length, 0)
+    causal_mask = torch.tensor(causal_mask, dtype=torch.float16)
+    print(f"\nInitialized causal mask for context length {context_length}")
+    return causal_mask
+
 def get_user_input():
-
-
-
-
-
-
-
-def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state, auto_prompt=None, warmup=False):
+    """Get input from user, handling special key combinations."""
+    global THINKING_MODE
+    try:
+        import termios
+        import tty
+        import sys
+
+        def _getch():
+            fd = sys.stdin.fileno()
+            old_settings = termios.tcgetattr(fd)
+            try:
+                tty.setraw(sys.stdin.fileno())
+                ch = sys.stdin.read(1)
+            finally:
+                termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
+            return ch
+
+        buffer = []
+        while True:
+            char = _getch()
+
+            # Debug: print the character code
+            print(f"\nKey pressed: {repr(char)} (hex: {hex(ord(char))})")
+
+            # Check for Enter key
+            if char == '\r' or char == '\n':
+                print() # Move to next line
+                input_text = ''.join(buffer)
+                # Check if the command is /t
+                if input_text == '/t':
+                    THINKING_MODE = not THINKING_MODE
+                    print(f"Thinking mode {'ON' if THINKING_MODE else 'OFF'}")
+                    buffer = [] # Clear buffer
+                    print(f"\n{LIGHT_GREEN}You{' (thinking)' if THINKING_MODE else ''}:{RESET_COLOR}", end=' ', flush=True)
+                    continue
+                return input_text
+
+            # Handle backspace
+            if char == '\x7f': # backspace
+                if buffer:
+                    buffer.pop()
+                    sys.stdout.write('\b \b') # Erase character
+                    sys.stdout.flush()
+                continue
+
+            # Handle Ctrl-C
+            if char == '\x03': # Ctrl-C
+                print("^C")
+                raise KeyboardInterrupt
+
+            # Print character and add to buffer
+            sys.stdout.write(char)
+            sys.stdout.flush()
+            buffer.append(char)
+
+    except ImportError:
+        # Fallback for systems without termios
+        return input("> ")
+
+def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state, causal_mask, auto_prompt=None, warmup=False):
     """Interactive chat loop."""
+    global THINKING_MODE
     context_length = metadata.get('context_length')
     batch_size = metadata.get('batch_size', 64)
 
     if not warmup:
         print(f"\nUsing context length: {context_length}")
         print("\nStarting chat session. Press Ctrl+D to exit.")
-        print("Type your message and press Enter to chat.")
+        print("Type your message and press Enter to chat. Use /t to toggle thinking mode.")
+        print(f"Thinking mode is {'ON' if THINKING_MODE else 'OFF'}")
 
     # Keep track of conversation history
     conversation = []
@@ -521,7 +674,7 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
     while True:
         try:
             if not warmup:
-                print(f"\n{LIGHT_GREEN}You:{RESET_COLOR}", end=' ', flush=True)
+                print(f"\n{LIGHT_GREEN}You{' (thinking)' if THINKING_MODE else ''}:{RESET_COLOR}", end=' ', flush=True)
             if auto_prompt is not None:
                 user_input = auto_prompt
                 if not warmup:
@@ -535,16 +688,31 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
 
             if not user_input:
                 continue
+
+            # Handle /t command
+            if user_input == "/t":
+                THINKING_MODE = not THINKING_MODE
+                print(f"Thinking mode {'ON' if THINKING_MODE else 'OFF'}")
+                continue
 
             # Add user message to conversation
             conversation.append({"role": "user", "content": user_input})
 
             # Format using chat template with full history
-
-
-
-
-
+            if THINKING_MODE:
+                # Add thinking prompt to system message
+                conversation_with_thinking = [{"role": "system", "content": THINKING_PROMPT}] + conversation
+                base_input_ids = tokenizer.apply_chat_template(
+                    conversation_with_thinking,
+                    return_tensors="pt",
+                    add_generation_prompt=True
+                ).to(torch.int32)
+            else:
+                base_input_ids = tokenizer.apply_chat_template(
+                    conversation,
+                    return_tensors="pt",
+                    add_generation_prompt=True
+                ).to(torch.int32)
 
             # Check if we need to trim history
             while base_input_ids.size(1) > context_length - 100: # Leave room for response
@@ -579,10 +747,6 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
             generation_start_time = time.time()
 
             try:
-                # Create initial causal mask
-                causal_mask = make_causal_mask(context_length, 0)
-                causal_mask = torch.tensor(causal_mask, dtype=torch.float16)
-
                 # Run prefill on entire context
                 current_pos = run_prefill(
                     embed_model,
@@ -591,7 +755,8 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
                     context_pos,
                     context_length,
                     batch_size,
-                    state
+                    state,
+                    causal_mask
                 )
                 #print(f"\n[DEBUG] After initial prefill - current_pos: {current_pos}")
 
@@ -625,7 +790,8 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
                         new_size, # Prefill the entire shifted content
                         context_length,
                         batch_size,
-                        state
+                        state,
+                        causal_mask
                     )
 
                     # Start generating from the next position
@@ -644,7 +810,8 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
                         input_ids,
                         pos,
                         context_length,
-                        state
+                        state,
+                        causal_mask
                     )
 
                     # Add token
@@ -697,77 +864,7 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
         traceback.print_exc()
 
 def main():
-
-
-    # Add meta.yaml option
-    parser.add_argument('--meta', type=str, help='Path to meta.yaml to load all parameters')
-
-    # Add existing arguments
-    parser.add_argument('--d', '--dir', type=str, default='.',
-                        help='Directory containing model files (default: current directory)')
-    parser.add_argument('--embed', type=str, required=False,
-                        help='Path to embeddings model (relative to --dir)')
-    parser.add_argument('--ffn', type=str, required=False,
-                        help='Path to FFN model (can be chunked, relative to --dir)')
-    parser.add_argument('--lmhead', type=str, required=False,
-                        help='Path to LM head model (relative to --dir)')
-    parser.add_argument('--tokenizer', type=str, required=False,
-                        help='Path to tokenizer')
-
-    # Add new argument for auto-generation
-    parser.add_argument('--prompt', type=str,
-                        help='If specified, run once with this prompt and exit')
-
-    # Model configuration
-    parser.add_argument('--context-length', type=int,
-                        help='Context length for the model (default: 512), if not provided, it will be detected from the model directory name ctxNUMBER')
-
-    args = parser.parse_args()
-
-    # If meta.yaml is provided, load parameters from it
-    if args.meta:
-        try:
-            with open(args.meta, 'r') as f:
-                meta = yaml.safe_load(f)
-            params = meta['model_info']['parameters']
-
-            # Set model directory to meta.yaml directory if not specified
-            if not args.d or args.d == '.':
-                args.d = str(Path(args.meta).parent)
-
-            # Build model paths based on parameters
-            prefix = params.get('model_prefix', 'llama') # Default to 'llama' if not specified
-            lut_ffn = f"_lut{params['lut_ffn']}" if params['lut_ffn'] != 'none' else ''
-            lut_lmhead = f"_lut{params['lut_lmhead']}" if params['lut_lmhead'] != 'none' else ''
-            num_chunks = int(params['num_chunks'])
-
-            # Set model paths if not specified
-            if not args.embed:
-                args.embed = f'{prefix}_embeddings'
-            if not args.lmhead:
-                args.lmhead = f'{prefix}_lm_head{lut_lmhead}'
-            if not args.ffn:
-                args.ffn = f'{prefix}_FFN_PF{lut_ffn}_chunk_01of{num_chunks:02d}'
-            if not args.tokenizer:
-                args.tokenizer = args.d
-
-            # Set other parameters
-            args.context_length = int(params['context_length'])
-            args.batch_size = int(params['batch_size'])
-            args.num_chunks = num_chunks
-
-            print(f"\nLoaded parameters from {args.meta}:")
-            print(f" Context Length: {args.context_length}")
-            print(f" Batch Size: {args.batch_size}")
-            print(f" Num Chunks: {args.num_chunks}")
-            print(f" Models Directory: {args.d}")
-            print(f" Embeddings: {args.embed}")
-            print(f" LM Head: {args.lmhead}")
-            print(f" FFN: {args.ffn}")
-
-        except Exception as e:
-            print(f"\nError loading meta.yaml: {str(e)}")
-            sys.exit(1)
+    args = parse_args()
 
     # Convert directory to absolute path
     model_dir = Path(args.d).resolve()
@@ -817,18 +914,23 @@ def main():
     # Create unified state once
     state = create_unified_state(ffn_models, metadata['context_length'])
 
+    # Initialize causal mask once
+    causal_mask = initialize_causal_mask(metadata['context_length'])
+
     # Warmup runs to prevent Python GIL issues with CoreML !
-
-
-
-
-
-
-
-
-
-
-
+    if not args.nw:
+        for i in range(2):
+            chat_loop(
+                embed_model=embed_model,
+                ffn_models=ffn_models,
+                lmhead_model=lmhead_model,
+                tokenizer=tokenizer,
+                metadata=metadata,
+                state=state, # Pass the state
+                causal_mask=causal_mask, # Pass the causal mask
+                warmup=True,
+                auto_prompt="who are you?"
+            )
 
     # Main run
     chat_loop(
@@ -838,6 +940,7 @@ def main():
         tokenizer=tokenizer,
         metadata=metadata,
         state=state, # Pass the state
+        causal_mask=causal_mask, # Pass the causal mask
        warmup=False,
        auto_prompt=args.prompt
    )