from sklearn.cluster import * import os import numpy as np from config import config import yaml from multiprocessing import Pool from tqdm import tqdm def process_speaker(speaker): embs = [] wavnames = [] print("\nspeaker: " + speaker) for file in filelist_dict[speaker]: try: embs.append( np.expand_dims(np.load(f"{os.path.splitext(file)[0]}.emo.npy"), axis=0) ) wavnames.append(os.path.basename(file)) except Exception as e: print(e) x = np.concatenate(embs, axis=0) x = np.squeeze(x) # 聚类算法类的数量 n_clusters = config.emo_cluster_config.n_clusters method = config.emo_cluster_config.method if method == "b": model = Birch(n_clusters=n_clusters, threshold=0.2) elif method == "s": model = SpectralClustering(n_clusters=n_clusters) elif method == "a": model = AgglomerativeClustering(n_clusters=n_clusters) else: model = KMeans(n_clusters=n_clusters, random_state=42) # 可以自行尝试各种不同的聚类算法 y_predict = model.fit_predict(x) classes = [[] for i in range(y_predict.max() + 1)] for idx, wavname in enumerate(wavnames): classes[y_predict[idx]].append(wavname) yml_result = {} yml_result[speaker] = {} os.makedirs( os.path.join(config.dataset_path, f"emo_clustering/{speaker}"), exist_ok=True ) for i in range(y_predict.max() + 1): class_length = len(classes[i]) print("类别:", i, "本类中样本数量:", class_length) yml_result[speaker][f"class{i}"] = [] for j in range(config.emo_cluster_config.n_samples): if j >= class_length: break print(classes[i][j]) yml_result[speaker][f"class{i}"].append(classes[i][j]) if hasattr(model, "cluster_centers_") and config.emo_cluster_config.save_center: centers = model.cluster_centers_ filename = os.path.join( config.dataset_path, f"emo_clustering/{speaker}/cluster_center_{i}.npy" ) # 保存中心 np.save(filename, centers[i]) elif config.emo_cluster_config.save_center: labels = model.labels_ centers = np.array([X[labels == i].mean(0) for i in range(n_clusters)]) filename = os.path.join( config.dataset_path, f"emo_clustering/{speaker}/cluster_center_{i}.npy" ) np.save(filename, centers[i]) return yml_result if __name__ == "__main__": filelist_dict = {} with open( config.preprocess_text_config.train_path, mode="r", encoding="utf-8" ) as f: for line in f: speaker = line.split("|")[1] if speaker not in filelist_dict: filelist_dict[speaker] = [] filelist_dict[speaker].append(line.split("|")[0]) with Pool() as p: results = list( tqdm( p.imap(process_speaker, list(filelist_dict.keys())), total=len(filelist_dict), ) ) yml_result = {} for result in results: yml_result.update(result) with open( os.path.join(config.dataset_path, "emo_clustering/emo_clustering.yml"), "w", encoding="utf-8", ) as f: yaml.dump(yml_result, f)