"""Greedy checkpoint-averaging search.

Starting from the single best-scoring checkpoint, the search repeatedly
averages the current best model with every remaining candidate, keeps the
candidate whose average scores highest, and stops as soon as no average is
at least as good as the current best.  Progress is persisted with
``torch.save`` so an interrupted search can be resumed.
"""
import os
import pdb
import shutil

import torch

from cases_collect import valid_results_collect
from model_soups_utils import average_two_model

# Suffixes appended to a base model path when ``seed=True``
# ("2020" .. "2029", i.e. ``<model_path>_2020`` and so on).
_SEED_SUFFIXES = tuple(str(value) for value in range(2020, 2030))


def remove_folder(path):
    """Delete the directory tree at *path* if it exists.

    The outcome (removed / not found) is reported on stdout either way.
    """
    if os.path.isdir(path):
        shutil.rmtree(path)
        print(f"Directory '{path}' has been removed.")
    else:
        print(f"Directory '{path}' does not exist.")


def score_criteria(x):
    """Return the score stored at index 1 of a candidate tuple."""
    return x[1]


def compare_criteria(x, y):
    """Return True when *y* is at least as large as *x*.

    Called as ``compare_criteria(best_score, win_score)``: a tie counts as
    an acceptable result, so the search continues on equal scores.
    """
    return x <= y


def _accuracy(failed_cases, correct_cases):
    """Return the fraction of correct cases among all evaluated cases.

    Raises ``ZeroDivisionError`` when both collections are empty.
    """
    return len(correct_cases) / (len(failed_cases) + len(correct_cases))


def _state_files(search_name):
    """Return the (scores, consumed paths, best path) checkpoint file names."""
    return (
        '{0}_score.pt'.format(search_name),
        '{0}_path.pt'.format(search_name),
        '{0}_best_path.pt'.format(search_name),
    )


def find_best_combination(model_path, valid_data, test_examples, search_name,
                          iteration=5, seed=True, task='nli'):
    """Greedily average checkpoints while the evaluation score does not drop.

    Args:
        model_path: With ``seed=True`` a base path (or list of base paths);
            each is expanded to ``<base>_2020`` .. ``<base>_2029``.  With
            ``seed=False`` an iterable of candidate paths used as given.
        valid_data: Unused by the current implementation; kept for
            interface compatibility.
        test_examples: Evaluation data forwarded to
            ``valid_results_collect``.
        search_name: Prefix of the three state files
            (``<search_name>_score.pt``, ``_path.pt``, ``_best_path.pt``)
            used to persist and resume the search.
        iteration: Unused by the current implementation; kept for
            interface compatibility.
        seed: Whether ``model_path`` must be expanded with seed suffixes.
        task: Task name forwarded to every ``valid_results_collect`` call.

    Returns:
        ``(best_path, update_scores)``: the path of the best model found and
        the list of best scores recorded after each accepted step.

    Side effects:
        Writes the three state files, may create averaged-model folders via
        ``average_two_model`` and deletes rejected ones with
        ``remove_folder``.  The caller's ``model_path`` list is not mutated.
    """
    if seed:
        base_paths = model_path if isinstance(model_path, list) else [model_path]
        paths = [
            base + '_{0}'.format(suffix)
            for base in base_paths
            for suffix in _SEED_SUFFIXES
        ]
    else:
        # Work on a copy so the caller's list is not emptied by the search.
        paths = list(model_path)

    score_file, path_file, best_path_file = _state_files(search_name)

    try:
        # Resume: load the whole state first and prune a *copy* of the
        # candidate list, so a failure half-way leaves ``paths`` untouched
        # for the fresh start below.
        update_scores = torch.load(score_file)
        del_paths = torch.load(path_file)
        best_path = torch.load(best_path_file)
        best_score = update_scores[-1]
        remaining = list(paths)
        for used_path in del_paths:
            remaining.remove(used_path)
        paths = remaining
    except Exception as err:  # missing, stale or unreadable state files
        print("No usable saved state for '{0}' ({1!r}); starting a new search."
              .format(search_name, err))
        del_paths = []
        update_scores = []
        path_count = []
        # Round 0: score every candidate on its own.
        for path_id, path in enumerate(paths):
            print(0, path_id, len(paths))
            f_test, c_test = valid_results_collect(path, test_examples, task)
            path_count.append((path, _accuracy(f_test, c_test)))
            print(path_count[-1][1])
        path_count.sort(key=score_criteria, reverse=True)
        best_path = path_count[0][0]
        best_score = score_criteria(path_count[0])
        update_scores.append(best_score)
        # The winner is evaluated once more and both numbers are printed
        # (kept from the original implementation).
        f_test, c_test = valid_results_collect(best_path, test_examples, task)
        print(best_score, _accuracy(f_test, c_test))
        torch.save(update_scores, score_file)
        del_paths.append(best_path)
        torch.save(del_paths, path_file)
        paths.remove(best_path)
        torch.save(best_path, best_path_file)

    while paths:
        path_count = []
        for path_id, path in enumerate(paths):
            print(len(update_scores), path_id, len(paths))
            # Reuse an already-averaged folder when one with the expected
            # name exists; otherwise build the average now.
            average_path = "{0}_average".format(best_path + path.split('/')[-1])
            if not os.path.isdir(average_path):
                average_path = average_two_model(best_path, path, len(update_scores))
            f_test, c_test = valid_results_collect(average_path, test_examples, task)
            score = _accuracy(f_test, c_test)
            # ``path_count`` only ever grows with non-decreasing scores, so
            # its last entry is the best candidate seen so far this round.
            if not path_count or score >= path_count[-1][1]:
                path_count.append((path, score, average_path))
            else:
                remove_folder(average_path)
            print(path_count[-1][1])

        # Stable sort: among equal scores the earliest kept candidate wins.
        path_count.sort(key=score_criteria, reverse=True)
        win_path, win_score, win_average_path = path_count[0]

        if not compare_criteria(best_score, win_score):
            # No average matches the current best: the search is finished.
            break

        # Only drop the previous best once more than two source paths have
        # been consumed (threshold kept from the original implementation).
        if len(del_paths) > 2:
            remove_folder(best_path)
        best_path = win_average_path
        torch.save(best_path, best_path_file)
        best_score = win_score
        print(best_score)
        paths.remove(win_path)
        del_paths.append(win_path)
        torch.save(del_paths, path_file)
        update_scores.append(best_score)
        torch.save(update_scores, score_file)

    return best_path, update_scores