jonruida commited on
Commit
eff1c6a
verified
1 Parent(s): f620723

Create inference.py

Browse files
Files changed (1) hide show
  1. inference.py +156 -0
inference.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch_geometric.data import Data
3
+ import numpy as np
4
+ import json
5
+
6
+ class GNN(torch.nn.Module):
7
+ """
8
+ Overall graph neural network. Consists of learnable user/item (i.e., playlist/song) embeddings
9
+ and LightGCN layers.
10
+ """
11
+ def __init__(self, embedding_dim, num_nodes, num_playlists, num_layers):
12
+ super(GNN, self).__init__()
13
+
14
+ self.embedding_dim = embedding_dim
15
+ self.num_nodes = num_nodes # total number of nodes (songs + playlists) in dataset
16
+ self.num_playlists = num_playlists # total number of playlists in dataset
17
+ self.num_layers = num_layers
18
+
19
+ # Initialize embeddings for all playlists and songs. Playlists will have indices from 0...num_playlists-1,
20
+ # songs will have indices from num_playlists...num_nodes-1
21
+ self.embeddings = torch.nn.Embedding(num_embeddings=self.num_nodes, embedding_dim=self.embedding_dim)
22
+ torch.nn.init.normal_(self.embeddings.weight, std=0.1)
23
+
24
+ self.layers = torch.nn.ModuleList() # LightGCN layers
25
+ for _ in range(self.num_layers):
26
+ self.layers.append(LightGCN())
27
+
28
+ self.sigmoid = torch.sigmoid
29
+
30
+ def forward(self):
31
+ raise NotImplementedError("forward() has not been implemented for the GNN class. Do not use")
32
+
33
+ def gnn_propagation(self, edge_index_mp):
34
+ """
35
+ Performs the linear embedding propagation (using the LightGCN layers) and calculates final (multi-scale) embeddings
36
+ for each user/item, which are calculated as a weighted sum of that user/item's embeddings at each layer (from
37
+ 0 to self.num_layers). Technically, the weighted sum here is the average, which is what the LightGCN authors recommend.
38
+
39
+ args:
40
+ edge_index_mp: a tensor of all (undirected) edges in the graph, which is used for message passing/propagation and
41
+ calculating the multi-scale embeddings. (In contrast to the evaluation/supervision edges, which are distinct
42
+ from the message passing edges and will be used for calculating loss/performance metrics).
43
+ returns:
44
+ final multi-scale embeddings for all users/items
45
+ """
46
+ x = self.embeddings.weight # layer-0 embeddings
47
+
48
+ x_at_each_layer = [x] # stores embeddings from each layer. Start with layer-0 embeddings
49
+ for i in range(self.num_layers): # now performing the GNN propagation
50
+ x = self.layers[i](x, edge_index_mp)
51
+ x_at_each_layer.append(x)
52
+ final_embs = torch.stack(x_at_each_layer, dim=0).mean(dim=0) # take average to calculate multi-scale embeddings
53
+ return final_embs
54
+
55
+ def predict_scores(self, edge_index, embs):
56
+ """
57
+ Calculates predicted scores for each playlist/song pair in the list of edges. Uses dot product of their embeddings.
58
+
59
+ args:
60
+ edge_index: tensor of edges (between playlists and songs) whose scores we will calculate.
61
+ embs: node embeddings for calculating predicted scores (typically the multi-scale embeddings from gnn_propagation())
62
+ returns:
63
+ predicted scores for each playlist/song pair in edge_index
64
+ """
65
+ scores = embs[edge_index[0,:], :] * embs[edge_index[1,:], :] # taking dot product for each playlist/song pair
66
+ scores = scores.sum(dim=1)
67
+ scores = self.sigmoid(scores)
68
+ return scores
69
+
70
+ def calc_loss(self, data_mp, data_pos, data_neg):
71
+ """
72
+ The main training step. Performs GNN propagation on message passing edges, to get multi-scale embeddings.
73
+ Then predicts scores for each training example, and calculates Bayesian Personalized Ranking (BPR) loss.
74
+
75
+ args:
76
+ data_mp: tensor of edges used for message passing / calculating multi-scale embeddings
77
+ data_pos: set of positive edges that will be used during loss calculation
78
+ data_neg: set of negative edges that will be used during loss calculation
79
+ returns:
80
+ loss calculated on the positive/negative training edges
81
+ """
82
+ # Perform GNN propagation on message passing edges to get final embeddings
83
+ final_embs = self.gnn_propagation(data_mp.edge_index)
84
+
85
+ # Get edge prediction scores for all positive and negative evaluation edges
86
+ pos_scores = self.predict_scores(data_pos.edge_index, final_embs)
87
+ neg_scores = self.predict_scores(data_neg.edge_index, final_embs)
88
+
89
+ # # Calculate loss (binary cross-entropy). Commenting out, but can use instead of BPR if desired.
90
+ # all_scores = torch.cat([pos_scores, neg_scores], dim=0)
91
+ # all_labels = torch.cat([torch.ones(pos_scores.shape[0]), torch.zeros(neg_scores.shape[0])], dim=0)
92
+ # loss_fn = torch.nn.BCELoss()
93
+ # loss = loss_fn(all_scores, all_labels)
94
+
95
+ # Calculate loss (using variation of Bayesian Personalized Ranking loss, similar to the one used in official
96
+ # LightGCN implementation at https://github.com/gusye1234/LightGCN-PyTorch/blob/master/code/model.py#L202)
97
+ loss = -torch.log(self.sigmoid(pos_scores - neg_scores)).mean()
98
+ return loss
99
+
100
+ def evaluation(self, data_mp, data_pos, k):
101
+ """
102
+ Performs evaluation on validation or test set. Calculates recall@k.
103
+
104
+ args:
105
+ data_mp: message passing edges to use for propagation/calculating multi-scale embeddings
106
+ data_pos: positive edges to use for scoring metrics. Should be no overlap between these edges and data_mp's edges
107
+ k: value of k to use for recall@k
108
+ returns:
109
+ dictionary mapping playlist ID -> recall@k on that playlist
110
+ """
111
+ # Run propagation on the message-passing edges to get multi-scale embeddings
112
+ final_embs = self.gnn_propagation(data_mp.edge_index)
113
+
114
+ # Get embeddings of all unique playlists in the batch of evaluation edges
115
+ unique_playlists = torch.unique_consecutive(data_pos.edge_index[0,:])
116
+ playlist_emb = final_embs[unique_playlists, :] # has shape [number of playlists in batch, 64]
117
+
118
+ # Get embeddings of ALL songs in dataset
119
+ song_emb = final_embs[self.num_playlists:, :] # has shape [total number of songs in dataset, 64]
120
+
121
+ # All ratings for each playlist in batch to each song in entire dataset (using dot product as the scoring function)
122
+ ratings = self.sigmoid(torch.matmul(playlist_emb, song_emb.t())) # shape: [# playlists in batch, # songs in dataset]
123
+ # where entry i,j is rating of song j for playlist i
124
+ # Calculate recall@k
125
+ result = recall_at_k(ratings.cpu(), k, self.num_playlists, data_pos.edge_index.cpu(),
126
+ unique_playlists.cpu(), data_mp.edge_index.cpu())
127
+ return result
128
+
129
+
130
+ # Carga el modelo previamente entrenado
131
+ data = torch.load(os.path.join(base_dir, "data_object.pt"))
132
+ with open(os.path.join(base_dir, "dataset_stats.json"), 'r') as f:
133
+ stats = json.load(f)
134
+ num_playlists, num_nodes = stats["num_playlists"], stats["num_nodes"]
135
+ model = GNN(embedding_dim=64, num_nodes=data.num_nodes, num_playlists=num_playlists, num_layers=3)
136
+ model.load_state_dict(torch.load("pesos_modelo.pth")) # Reemplaza "pesos_modelo.pth" con el nombre de tu archivo de pesos
137
+
138
+ # Define la funci贸n de inferencia
139
+ def predict(edge_index):
140
+ # Convierte la entrada en un objeto PyG Data
141
+ data = Data(edge_index=edge_index)
142
+
143
+ # Realiza la inferencia con el modelo
144
+ model.eval()
145
+ with torch.no_grad():
146
+ output = model.gnn_propagation(data.edge_index)
147
+
148
+ # Aqu铆 puedes realizar cualquier postprocesamiento necesario de las predicciones
149
+ return output
150
+
151
+ # Ejemplo de uso
152
+ if __name__ == "__main__":
153
+ # Aqu铆 puedes realizar pruebas con datos de ejemplo
154
+ edge_index = np.array([[0, 1, 2], [1, 2, 0]]) # Ejemplo de datos de entrada (lista de aristas)
155
+ predictions = predict(edge_index)
156
+ print(predictions)