File size: 1,469 Bytes
44c9b90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
from diffusers import DiffusionPipeline
import torch.nn.functional as F

class LindelPipeline(DiffusionPipeline):
    def __init__(self, indel_model, ins_model, del_model):
        super().__init__()

        self.register_modules(indel_model=indel_model, ins_model=ins_model, del_model=del_model)
        Lindel_dlen = int(round((-7 + (49 + 4 * (8 + 2 * self.del_model.linear.weight.shape[0])) ** 0.5) / 2))
        self.dstarts, self.dends = [], []
        for dlen in range(Lindel_dlen - 1, 0, -1):
            for dstart in range(-dlen - 1, 3):
                self.dstarts.append(dstart)
                self.dends.append(dstart + dlen)

    @torch.no_grad()
    def __call__(self, batch):
        indel_proba = F.softmax(self.indel_model(batch["input_indel"].to(self.indel_model.device))["logit"], dim=1)
        ins_base_proba = F.softmax(self.ins_model(batch["input_ins"].to(self.ins_model.device))["logit"], dim=1)
        del_pos_proba = F.softmax(self.del_model(batch["input_del"].to(self.del_model.device))["logit"], dim=1)
        return {
            "del_proba": indel_proba[:, 0],
            "ins_proba": indel_proba[:, 1],
            "ins_base": ["A", "C", "G", "T", "AA", "AC", "AG", "AT", "CA", "CC", "CG", "CT", "GA", "GC", "GG", "GT", "TA", "TC", "TG", "TT", ">2"],
            "ins_base_proba": ins_base_proba,
            "dstart": self.dstarts,
            "dend": self.dends,
            "del_pos_proba": del_pos_proba
        }