import os
import logging
from pathlib import Path
from typing import Dict, List

import torch
from datasets import load_dataset

from greedy_search import find_best_combination
from cases_collect import valid_results_collect


def setup_logger() -> logging.Logger:
    """Configure and return logger."""
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s'
    )
    return logging.getLogger(__name__)


def get_model_paths(model_names: List[str], base_path: str = './') -> List[str]:
    """Generate model paths from names."""
    return [os.path.join(base_path, f"{name}_model") for name in model_names]


def load_test_data(dataset_name: str = 'hippocrates/MedNLI_test') -> List[Dict]:
    """Load and prepare test dataset."""
    dataset = load_dataset(dataset_name)
    return [
        {'Input': item['query'], 'Output': item['answer']}
        for item in dataset['test']
    ]


def calculate_accuracy(correct: List, failed: List) -> float:
    """Calculate accuracy from correct and failed cases."""
    total = len(correct) + len(failed)
    return len(correct) / total if total > 0 else 0.0


def main():
    """Main execution function."""
    logger = setup_logger()

    try:
        # Configuration
        config = {
            'search_name': 'randoms_model',
            'model_names': ['randoms_data_3k_model'],
            'base_path': './',
            'valid_data_path': 'nli_demo.pt',
            'seed': True,
            'iteration': 5
        }

        # Generate model paths
        model_paths = get_model_paths(config['model_names'], config['base_path'])
        logger.info(f"Generated model paths: {model_paths}")

        # Load datasets
        logger.info("Loading test data...")
        test_examples = load_test_data()
        logger.info(f"Loaded {len(test_examples)} test examples")

        logger.info("Loading validation data...")
        try:
            valid_data = torch.load(config['valid_data_path'])
            logger.info(f"Loaded validation data from {config['valid_data_path']}")
        except Exception as e:
            logger.error(f"Failed to load validation data: {str(e)}")
            raise

        # Find best combination
        logger.info("Finding best model combination...")
        best_path, update_scores = find_best_combination(
            model_paths,
            valid_data,
            valid_data,
            config['search_name'],
            iteration=config['iteration'],
            seed=config['seed']
        )
        logger.info(f"Best path found with scores: {update_scores}")

        # Evaluate on test set
        logger.info("Evaluating on test set...")
        failed_test, correct_test = valid_results_collect(
            best_path, test_examples, 'nli'
        )

        # Calculate and log accuracy
        accuracy = calculate_accuracy(correct_test, failed_test)
        logger.info(f"Test Accuracy: {accuracy:.4f}")

        # Save results
        results = {
            'best_path': best_path,
            'update_scores': update_scores,
            'test_accuracy': accuracy,
            'test_results': {
                'correct': len(correct_test),
                'failed': len(failed_test)
            }
        }
        save_path = Path(f"results_{config['search_name']}.pt")
        torch.save(results, save_path)
        logger.info(f"Results saved to {save_path}")

    except Exception as e:
        logger.error(f"Error in main execution: {str(e)}", exc_info=True)
        raise


if __name__ == "__main__":
    main()
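
# Note (assumption, not confirmed by the source): the script does not document the
# layout of `nli_demo.pt`. Since it is passed to `find_best_combination` alongside
# the test examples built by `load_test_data`, a reasonable guess is that it holds
# validation examples in the same {'Input': ..., 'Output': ...} dict format. Under
# that assumption, a small placeholder file could be created for smoke-testing with:
#
#   torch.save(
#       [{'Input': 'premise ... hypothesis ...', 'Output': 'entailment'}],
#       'nli_demo.pt',
#   )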