|
library(ggplot2) |
|
# Load the per-fold results prepared for figure 5 and keep only the rows
# whose task.type is 'Gene'.
result.plot <- readRDS('figs/fig.5.prepare.RDS')
result.plot <- result.plot[result.plot$task.type == 'Gene', ]

# Flag recording whether a row was taken from the '.lw' variant of a model.
# It starts out FALSE everywhere and is switched on row by row by the
# selection loop further down.
result.plot$use.lw <- FALSE

# Drop tasks whose id matches '.itan.split'. The pattern is a regular
# expression, so each '.' in it matches any single character.
result.plot <- result.plot[!grepl('.itan.split', result.plot$task.id), ]

# Criterion used to choose between a model and its '.lw' variant:
# 'auc', 'loss' or 'auc+loss'; any other value never picks '.lw'.
pick.cond <- 'auc'

# Base model names with the '.lw' marker removed. fixed = TRUE makes the
# match literal: read as a regular expression, '.lw' would also strip any
# other character followed by 'lw' and silently corrupt such model names.
uniq.models <- unique(gsub('.lw', '', result.plot$model, fixed = TRUE))

# Only models whose name ends in '/' take part in the '.lw' selection.
uniq.models <- uniq.models[grepl('/$', uniq.models)]

uniq.genes <- unique(result.plot$task.id)
|
|
|
# For every gene / model / fold combination keep exactly one of the two
# candidate rows: the plain model or its '.lw' variant. When the '.lw' row
# wins it is relabelled with the plain model name and flagged via use.lw;
# the losing row is removed from result.plot.
for (g in uniq.genes) {
  for (m in uniq.models) {
    for (f in 0:4) {
      # Locate both candidates once per iteration. result.plot shrinks on
      # every pass, so positions cannot be reused across iterations.
      in.task <- result.plot$task.id == g & result.plot$fold == f
      plain.rows <- which(in.task & result.plot$model == m)
      lw.rows <- which(in.task & result.plot$model == paste0(m, '.lw'))

      lw.loss <- result.plot$val.loss[lw.rows]
      loss <- result.plot$val.loss[plain.rows]
      lw.tr.auc <- result.plot$tr.auc[lw.rows]
      tr.auc <- result.plot$tr.auc[plain.rows]

      cond <- switch(
        pick.cond,
        # '.lw' wins when its training AUC is available and higher.
        'auc' = !is.na(mean(lw.tr.auc)) & lw.tr.auc > tr.auc,
        # '.lw' wins when its validation loss is available and lower.
        'loss' = !is.na(mean(lw.loss)) & loss > lw.loss,
        # NOTE(review): the two criteria above pick '.lw' when the '.lw'
        # metric is the better one, whereas this one picks '.lw' when the
        # PLAIN model has the larger tr.auc / val.loss ratio. Confirm the
        # direction of this comparison is intended.
        'auc+loss' = !is.na(lw.loss) & !is.na(lw.tr.auc) &
          (tr.auc / loss > lw.tr.auc / lw.loss),
        FALSE
      )

      # cond is zero-length when a candidate row is missing, NA when the
      # plain metric is NA, and longer than one when rows are duplicated;
      # a bare if (cond) stops with an error on the first two. Only an
      # unambiguous single TRUE selects the '.lw' variant; everything else
      # falls back to the plain model.
      if (isTRUE(cond)) {
        result.plot$model[lw.rows] <- m
        result.plot$use.lw[lw.rows] <- TRUE
        to.remove <- plain.rows
      } else {
        to.remove <- lw.rows
      }

      # Negative indexing with an empty index selects zero rows, i.e.
      # x[-integer(0), ] would wipe the whole table, so only drop rows
      # when there is something to drop.
      if (length(to.remove) > 0) {
        result.plot <- result.plot[-to.remove, ]
      }
    }
  }
}
|
|
|
# Give the task with id "Q14524.clean" the display name "Gene: SCN5A".
is.scn5a <- result.plot$task.id == "Q14524.clean"
result.plot$task.name[is.scn5a] <- "Gene: SCN5A"

# Raw model identifiers mapped to the numbered labels used in the figure.
# The numeric prefix fixes the order in which the models are listed.
model.dic <- c(
  "PreMode/"             = "1: PreMode",
  "PreMode.noPretrain/"  = "2: PreMode: no Pretrain",
  "random.forest"        = "3: Random Forest",
  "ESM.SLP/"             = "4: ESM + SLP",
  "PreMode.noESM/"       = "5: PreMode: no ESM",
  "PreMode.noMSA/"       = "6: PreMode: no MSA",
  "PreMode.noStructure/" = "7: PreMode: no Structure",
  "PreMode.ptm/"         = "8: PreMode: add ptm"
)

# Keep only the models that have a label, then replace each identifier
# with its label.
result.plot <- result.plot[result.plot$model %in% names(model.dic), ]
result.plot$model <- model.dic[result.plot$model]
|
|
|
|
|
# One row per model / task: start from the fold-0 rows and overwrite auc
# with the mean over all rows of that model and task, adding auc.se as the
# standard error of that mean.
uniq.result.plot <- result.plot[result.plot$fold == 0, ]

# Create the column up front so it exists even when there are no rows.
uniq.result.plot$auc.se <- rep(NA_real_, nrow(uniq.result.plot))

# seq_len() yields an empty sequence for zero rows, whereas 1:0 would
# iterate over c(1, 0).
for (i in seq_len(nrow(uniq.result.plot))) {
  aucs <- result.plot$auc[
    result.plot$model == uniq.result.plot$model[i] &
      result.plot$task.name == uniq.result.plot$task.name[i]
  ]

  # sd(na.rm = TRUE) ignores missing values, so the standard error must be
  # scaled by the number of values actually used, not by length(aucs).
  n.obs <- sum(!is.na(aucs))

  uniq.result.plot$auc[i] <- mean(aucs, na.rm = TRUE)
  uniq.result.plot$auc.se[i] <- sd(aucs, na.rm = TRUE) / sqrt(n.obs)
}
|
|
|
# One row per model: collapse the per-task means into a single weighted
# average, with each task weighted by lof * gof / (lof + gof) computed from
# its task.size.lof and task.size.gof.
uniq.model.result.plot <- uniq.result.plot[!duplicated(uniq.result.plot$model), ]

for (i in seq_len(nrow(uniq.model.result.plot))) {
  of.model <- uniq.result.plot$model == uniq.model.result.plot$model[i]

  task.sizes.lof <- uniq.result.plot$task.size.lof[of.model]
  task.sizes.gof <- uniq.result.plot$task.size.gof[of.model]
  task.sizes <- task.sizes.lof * task.sizes.gof / (task.sizes.lof + task.sizes.gof)

  aucs <- uniq.result.plot$auc[of.model]
  auc.ses <- uniq.result.plot$auc.se[of.model]

  # Drop tasks without an AUC using ONE shared mask. Filtering auc.ses by
  # its own NA pattern would change its length independently of task.sizes
  # and pair standard errors with the weights of the wrong tasks.
  has.auc <- !is.na(aucs)
  task.sizes <- task.sizes[has.auc]
  aucs <- aucs[has.auc]
  auc.ses <- auc.ses[has.auc]

  # A task whose standard error is NA stays aligned with its weight and is
  # skipped by na.rm = TRUE, so it contributes nothing to the sum.
  uniq.model.result.plot$auc[i] <- sum(aucs * task.sizes / sum(task.sizes), na.rm = TRUE)
  uniq.model.result.plot$auc.se[i] <- sum(auc.ses * task.sizes / sum(task.sizes), na.rm = TRUE)
}
|
|
|
# Group the models for colouring. Every model is an ablation unless its
# label appears in the lookup table below.
type.by.model <- c(
  "8: PreMode: add ptm" = 'PreMode: add ptm',
  "4: ESM + SLP"        = 'Baselines',
  "3: Random Forest"    = 'Baselines',
  "1: PreMode"          = 'PreMode'
)

uniq.model.result.plot$model.type <- 'PreMode: Ablation'
for (label in names(type.by.model)) {
  is.label <- uniq.model.result.plot$model == label
  uniq.model.result.plot$model.type[is.label] <- type.by.model[[label]]
}

# The level order fixes the order of the groups in the colour scale.
uniq.model.result.plot$model.type <- factor(
  uniq.model.result.plot$model.type,
  levels = c('PreMode', 'PreMode: add ptm', 'PreMode: Ablation', 'Baselines')
)
|
|
|
# Colour of each model group, keyed by factor level. Unnamed values are
# matched by position, so a group missing from the data would shift the
# colours of the remaining groups; named values keep the pairing fixed.
type.colours <- c(
  'PreMode'           = "#F8766D",
  'PreMode: add ptm'  = "#CD9600",
  'PreMode: Ablation' = "#999999",
  'Baselines'         = "#619CFF"
)

# Point-and-error-bar chart of the weighted AUC per model (+/- auc.se),
# drawn with flipped coordinates and coloured by model group.
p <- ggplot(uniq.model.result.plot, aes(x = model, y = auc, col = model.type)) +
  geom_point() +
  scale_color_manual(values = type.colours) +
  geom_errorbar(aes(ymin = auc - auc.se, ymax = auc + auc.se), width = .2) +
  coord_flip() +
  guides(col = guide_legend(ncol = 2)) +
  # No fill aesthetic is mapped, so only the axis labels are set here.
  labs(x = "models", y = "auc") +
  theme_bw() +
  ylim(0.5, 0.9) +
  ggtitle('PreMode ablation analysis') +
  theme(
    axis.text.x = element_text(angle = 60, vjust = 1, hjust = 1),
    text = element_text(size = 16),
    plot.title = element_text(size = 15),
    legend.text = element_text(size = 10),
    legend.title = element_blank(),
    legend.position = "bottom",
    legend.direction = "horizontal"
  ) +
  # Centre the title once, after the theme tweaks above.
  ggeasy::easy_center_title()

ggsave('figs/fig.5b.pdf', p, height = 5, width = 6)
|
|
|
|
|
|