import argparse

import torch
from torch.utils.data import DataLoader

from commonsense_model import CommonsenseGRUModel
from dataloader import RobertaCometDataset


def load_model(model_path, args):
    """Build a CommonsenseGRUModel with the COSMIC dimensions and load trained weights."""
    emo_gru = True
    n_classes = 15
    cuda = args.cuda

    # Feature dimensions: utterance (D_m), commonsense (D_s), global, party,
    # and emotion states, plus hidden and attention sizes.
    D_m = 1024
    D_s = 768
    D_g = 150
    D_p = 150
    D_r = 150
    D_i = 150
    D_h = 100
    D_a = 100
    D_e = D_p + D_r + D_i

    model = CommonsenseGRUModel(
        D_m, D_s, D_g, D_p, D_r, D_i, D_e, D_h, D_a,
        n_classes=n_classes,
        listener_state=args.active_listener,
        context_attention=args.attention,
        dropout_rec=args.rec_dropout,
        dropout=args.dropout,
        emo_gru=emo_gru,
        mode1=args.mode1,
        norm=args.norm,
        residual=args.residual,
    )

    if cuda:
        model.cuda()

    # Map the checkpoint onto the available device so CPU-only runs can load GPU-trained weights.
    model.load_state_dict(
        torch.load(model_path, map_location="cuda" if cuda else "cpu")
    )
    model.eval()
    return model


def get_valid_dataloader(
    roberta_features_path: str,
    comet_features_path: str,
    batch_size=1,
    num_workers=0,
    pin_memory=False,
):
    """Return a DataLoader over the validation split together with the conversation ids."""
    valid_set = RobertaCometDataset("valid", roberta_features_path, comet_features_path)
    valid_loader = DataLoader(
        valid_set,
        batch_size=batch_size,
        collate_fn=valid_set.collate_fn,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    return valid_loader, valid_set.keys


def predict(model, data_loader, args):
    """Run the model over data_loader and return predicted class indices per conversation."""
    predictions = []
    with torch.no_grad():
        for data in data_loader:
            r1, r2, r3, r4, x1, x2, x3, x4, x5, x6, o1, o2, o3, qmask, umask, label = (
                [d.cuda() for d in data[:-1]] if args.cuda else data[:-1]
            )
            log_prob, _, alpha, alpha_f, alpha_b, _ = model(
                r1, r2, r3, r4, x5, x6, x1, o2, o3, qmask, umask
            )
            # Flatten (seq_len, batch, n_classes) -> (seq_len * batch, n_classes) before argmax.
            lp_ = log_prob.transpose(0, 1).contiguous().view(-1, log_prob.size()[2])
            preds = torch.argmax(lp_, dim=-1)
            predictions.append(preds.data.cpu().numpy())
    return predictions


def parse_cosmic_args():
    # Command-line arguments for running the COSMIC model.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--no-cuda", action="store_true", default=False, help="do not use the GPU"
    )
    parser.add_argument(
        "--lr", type=float, default=0.0001, metavar="LR", help="learning rate"
    )
    parser.add_argument(
        "--l2",
        type=float,
        default=0.00003,
        metavar="L2",
        help="L2 regularization weight",
    )
    parser.add_argument(
        "--rec-dropout",
        type=float,
        default=0.3,
        metavar="rec_dropout",
        help="recurrent dropout rate",
    )
    parser.add_argument(
        "--dropout", type=float, default=0.5, metavar="dropout", help="dropout rate"
    )
    parser.add_argument(
        "--batch-size", type=int, default=1, metavar="BS", help="batch size"
    )
    parser.add_argument(
        "--epochs", type=int, default=10, metavar="E", help="number of epochs"
    )
    parser.add_argument(
        "--class-weight", action="store_true", default=True, help="use class weights"
    )
    parser.add_argument(
        "--active-listener", action="store_true", default=True, help="active listener"
    )
    parser.add_argument(
        "--attention", default="simple", help="attention type in the context GRU"
    )
    parser.add_argument(
        "--tensorboard",
        action="store_true",
        default=False,
        help="enables tensorboard logging",
    )
    parser.add_argument("--mode1", type=int, default=2, help="RoBERTa features to use")
    parser.add_argument("--seed", type=int, default=500, metavar="seed", help="seed")
    parser.add_argument("--norm", type=int, default=0, help="normalization strategy")
    parser.add_argument("--mu", type=float, default=0, help="class_weight_mu")
    parser.add_argument(
        "--residual", action="store_true", default=True, help="use residual connection"
    )

    args = parser.parse_args()
    args.cuda = torch.cuda.is_available() and not args.no_cuda
    if args.cuda:
        print("Running on GPU")
    else:
        print("Running on CPU")
    return args


if __name__ == "__main__":

    def pred_to_labels(preds):
        """Map predicted class indices to sentiment label strings, one list per conversation."""
        mapped_predictions = []
        # Map the predictions for each conversation.
        for pred in preds:
            mapped_labels = []
            for label in pred:
                mapped_labels.append(label_mapping[label])
            mapped_predictions.append(mapped_labels)
        # Return the mapped labels for each conversation.
        return mapped_predictions

    # Index -> sentiment label used by the trained model (15 classes).
    label_mapping = {
        0: "Curiosity",
        1: "Obscene",
        2: "Informative",
        3: "Openness",
        4: "Acceptance",
        5: "Interest",
        6: "Greeting",
        7: "Disapproval",
        8: "Denial",
        9: "Anxious",
        10: "Uninterested",
        11: "Remorse",
        12: "Confused",
        13: "Accusatory",
        14: "Annoyed",
    }

    args = parse_cosmic_args()
    model = load_model("epik/best_model.pt", args)
    # NOTE: get_valid_dataloader requires the paths to the RoBERTa and COMET feature files.
    test_dataloader, ids = get_valid_dataloader()
    predicted_labels = pred_to_labels(predict(model, test_dataloader, args))

    for conv_id, labels in zip(ids, predicted_labels):
        print(f"Conversation ID: {conv_id}")
        print(f"Predicted Sentiment Labels: {labels}")
        print(len(labels))
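# Example invocation (a sketch, not taken from this file: the script filename and the
# idea of passing the feature pickle locations on the command line are assumptions;
# get_valid_dataloader itself expects roberta_features_path and comet_features_path
# to be supplied by the caller, and the flags below are the ones defined in
# parse_cosmic_args):
#
#   python predict_cosmic.py --no-cuda --batch-size 1 --mode1 2
#
# On a machine with a CUDA device, drop --no-cuda to run inference on the GPU.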