# HNSCC-MultiOmics-Risk-Feature-Extraction / MultiOmicsGraphAttentionAutoencoderModel.py
from transformers import PreTrainedModel
import torch
import torch.nn.functional as F
from torch.nn.functional import cosine_similarity
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from torch_geometric.data import Batch
from torch_geometric.utils import negative_sampling

from OmicsConfig import OmicsConfig
from GATv2EncoderModel import GATv2EncoderModel
from GATv2DecoderModel import GATv2DecoderModel
from EdgeWeightPredictorModel import EdgeWeightPredictorModel


class MultiOmicsGraphAttentionAutoencoderModel(PreTrainedModel):
    config_class = OmicsConfig
    base_model_prefix = "graph-attention-autoencoder"

    def __init__(self, config):
        super().__init__(config)
        self.encoder = GATv2EncoderModel(config)
        self.decoder = GATv2DecoderModel(config)
        # One optimizer over both halves of the autoencoder, with step decay.
        self.optimizer = AdamW(
            list(self.encoder.parameters()) + list(self.decoder.parameters()),
            lr=config.learning_rate,
        )
        self.scheduler = StepLR(self.optimizer, step_size=30, gamma=0.7)

    def forward(self, x, edge_index, edge_attr):
        # Encode node features into latent embeddings, then reconstruct them.
        z, attention_weights = self.encoder(x, edge_index, edge_attr)
        x_reconstructed = self.decoder(z)
        return x_reconstructed, attention_weights

    def predict_edge_weights(self, z, edge_index):
        return self.decoder.predict_edge_weights(z, edge_index)

    def train_model(self, data_loader, device):
        self.encoder.to(device)
        self.decoder.to(device)
        self.encoder.train()
        self.decoder.train()
        total_loss = 0
        total_cosine_similarity = 0
        # Relative weights for the three reconstruction objectives.
        loss_weight_node = 1.0
        loss_weight_edge = 1.0
        loss_weight_edge_attr = 1.0
        for data in data_loader:
            data = data.to(device)
            self.optimizer.zero_grad()
            z, attention_weights = self.encoder(data.x, data.edge_index, data.edge_attr)
            x_reconstructed = self.decoder(z)
            node_loss = graph_reconstruction_loss(x_reconstructed, data.x)
            edge_loss = edge_reconstruction_loss(z, data.edge_index)
            cos_sim = cosine_similarity(x_reconstructed, data.x, dim=-1).mean()
            total_cosine_similarity += cos_sim.item()
            pred_edge_weights = self.decoder.predict_edge_weights(z, data.edge_index)
            edge_weight_loss = edge_weight_reconstruction_loss(pred_edge_weights, data.edge_attr)
            loss = (
                loss_weight_node * node_loss
                + loss_weight_edge * edge_loss
                + loss_weight_edge_attr * edge_weight_loss
            )
            print(
                f"node_loss: {node_loss:.4f}, edge_loss: {edge_loss:.4f}, "
                f"edge_weight_loss: {edge_weight_loss:.4f}, cosine_similarity: {cos_sim:.4f}"
            )
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(data_loader)
        avg_cosine_similarity = total_cosine_similarity / len(data_loader)
        return avg_loss, avg_cosine_similarity

    def fit(self, train_loader, validation_loader, epochs, device):
        train_losses = []
        val_losses = []
        for epoch in range(1, epochs + 1):
            train_loss, train_cosine_similarity = self.train_model(train_loader, device)
            torch.cuda.empty_cache()
            val_loss, val_cosine_similarity = self.validate(validation_loader, device)
            # Record per-epoch losses so the returned histories are populated.
            train_losses.append(train_loss)
            val_losses.append(val_loss)
            print(
                f"Epoch: {epoch}, Train Loss: {train_loss:.4f}, "
                f"Train Cosine Similarity: {train_cosine_similarity:.4f}, "
                f"Validation Loss: {val_loss:.4f}, "
                f"Validation Cosine Similarity: {val_cosine_similarity:.4f}"
            )
            self.scheduler.step()
        return train_losses, val_losses

    def validate(self, validation_loader, device):
        self.encoder.to(device)
        self.decoder.to(device)
        self.encoder.eval()
        self.decoder.eval()
        total_loss = 0
        total_cosine_similarity = 0
        with torch.no_grad():
            for data in validation_loader:
                data = data.to(device)
                z, attention_weights = self.encoder(data.x, data.edge_index, data.edge_attr)
                x_reconstructed = self.decoder(z)
                node_loss = graph_reconstruction_loss(x_reconstructed, data.x)
                edge_loss = edge_reconstruction_loss(z, data.edge_index)
                cos_sim = cosine_similarity(x_reconstructed, data.x, dim=-1).mean()
                total_cosine_similarity += cos_sim.item()
                loss = node_loss + edge_loss
                total_loss += loss.item()
        avg_loss = total_loss / len(validation_loader)
        avg_cosine_similarity = total_cosine_similarity / len(validation_loader)
        return avg_loss, avg_cosine_similarity

    def evaluate(self, test_loader, device):
        self.encoder.to(device)
        self.decoder.to(device)
        self.encoder.eval()
        self.decoder.eval()
        total_loss = 0
        total_cosine_similarity = 0
        with torch.no_grad():
            for data in test_loader:
                data = data.to(device)
                z, attention_weights = self.encoder(data.x, data.edge_index, data.edge_attr)
                x_reconstructed = self.decoder(z)
                node_loss = graph_reconstruction_loss(x_reconstructed, data.x)
                edge_loss = edge_reconstruction_loss(z, data.edge_index)
                cos_sim = cosine_similarity(x_reconstructed, data.x, dim=-1).mean()
                total_cosine_similarity += cos_sim.item()
                loss = node_loss + edge_loss
                total_loss += loss.item()
        avg_loss = total_loss / len(test_loader)
        avg_cosine_similarity = total_cosine_similarity / len(test_loader)
        return avg_loss, avg_cosine_similarity
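

# Example (sketch): a single forward pass on one PyG graph. Assumes `data` is
# a torch_geometric.data.Data with x, edge_index, and edge_attr, and `model`
# is an instantiated MultiOmicsGraphAttentionAutoencoderModel:
#
#     x_rec, attn = model(data.x, data.edge_index, data.edge_attr)
#     z, _ = model.encoder(data.x, data.edge_index, data.edge_attr)
#     edge_w = model.predict_edge_weights(z, data.edge_index)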


# Collate function for the DataLoader: merge a list of PyG Data objects into
# one batched (disconnected) graph.
def collate_graph_data(batch):
    return Batch.from_data_list(batch)


# Create a DataLoader over the graph data.
def create_data_loader(train_data, batch_size=1, shuffle=True):
    graph_data = list(train_data.values())
    return DataLoader(graph_data, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_graph_data)
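

# Usage sketch: create_data_loader expects a dict mapping sample IDs to
# torch_geometric.data.Data graphs (hence the .values() call above), e.g.:
#
#     loader = create_data_loader(train_data, batch_size=4)
#     for batch in loader:
#         ...  # each `batch` is a torch_geometric.data.Batch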


# Loss functions.

def graph_reconstruction_loss(pred_features, true_features):
    # Node-feature reconstruction: mean squared error.
    return F.mse_loss(pred_features, true_features)


def edge_reconstruction_loss(z, pos_edge_index, neg_edge_index=None):
    # Inner-product edge decoder trained with binary cross-entropy on
    # observed (positive) edges and sampled negative edges.
    pos_logits = (z[pos_edge_index[0]] * z[pos_edge_index[1]]).sum(dim=-1)
    pos_loss = F.binary_cross_entropy_with_logits(pos_logits, torch.ones_like(pos_logits))
    if neg_edge_index is None:
        neg_edge_index = negative_sampling(pos_edge_index, z.size(0))
    neg_logits = (z[neg_edge_index[0]] * z[neg_edge_index[1]]).sum(dim=-1)
    neg_loss = F.binary_cross_entropy_with_logits(neg_logits, torch.zeros_like(neg_logits))
    return pos_loss + neg_loss


def edge_weight_reconstruction_loss(pred_weights, true_weights):
    # Align shapes element-wise so mse_loss does not silently broadcast.
    pred_weights = pred_weights.squeeze(-1)
    if true_weights.dim() > 1:
        true_weights = true_weights.squeeze(-1)
    return F.mse_loss(pred_weights, true_weights)
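

if __name__ == "__main__":
    # Minimal end-to-end sketch, not part of the released model code. It
    # assumes OmicsConfig accepts a `learning_rate` kwarg (used in __init__
    # above) plus whatever encoder/decoder sizes the real config defines, and
    # that node features are 4-dimensional; adjust both to the actual schema.
    from torch_geometric.data import Data

    def random_graph(num_nodes=8, num_edges=16, num_features=4):
        # Tiny random graph in the layout the model consumes: node features
        # `x`, connectivity `edge_index`, and scalar edge weights `edge_attr`.
        edge_index = torch.randint(0, num_nodes, (2, num_edges))
        return Data(
            x=torch.randn(num_nodes, num_features),
            edge_index=edge_index,
            edge_attr=torch.rand(num_edges),
        )

    config = OmicsConfig(learning_rate=1e-3)  # hypothetical kwargs
    model = MultiOmicsGraphAttentionAutoencoderModel(config)

    # Dict-of-graphs layout expected by create_data_loader.
    train_data = {f"sample_{i}": random_graph() for i in range(4)}
    val_data = {f"sample_{i}": random_graph() for i in range(2)}
    train_loader = create_data_loader(train_data, batch_size=2)
    val_loader = create_data_loader(val_data, batch_size=2, shuffle=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_losses, val_losses = model.fit(train_loader, val_loader, epochs=1, device=device)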