File size: 1,319 Bytes
98e362a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
from diffusers import DiffusionPipeline
import torch.nn.functional as F
import numpy as np

class FOREcasTPipeline(DiffusionPipeline):
    """Pipeline wrapping a FOREcasT model together with its outcome tables.

    The constructor enumerates one entry per outcome the pipeline reports:

    * for every deletion size from ``MAX_DEL_SIZE`` down to ``0``, every
      ``(left, right)`` pair with ``left`` in ``[-size, 0]`` and
      ``right = left + size``;
    * followed by 20 insertion entries (all 1- and 2-letter strings over
      ``A``, ``C``, ``G``, ``T``), whose ``left``/``right`` are both ``0``.

    ``lefts``, ``rights`` and ``inss`` are index-aligned and have the same
    length.
    """

    def __init__(self, FOREcasT_model, MAX_DEL_SIZE):
        """Register the model and precompute the outcome tables.

        Args:
            FOREcasT_model: Model registered with the pipeline via
                ``register_modules``.
            MAX_DEL_SIZE: Largest deletion size to enumerate.
        """
        super().__init__()

        self.register_modules(FOREcasT_model=FOREcasT_model)
        self.MAX_DEL_SIZE = MAX_DEL_SIZE

        # Left offsets: for each size (largest first) the run -size, ..., 0,
        # then 20 zeros for the insertion entries.
        left_runs = []
        for size in range(self.MAX_DEL_SIZE, -1, -1):
            left_runs.append(np.arange(-size, 1))
        left_runs.append(np.zeros(20, np.int64))
        self.lefts = np.concatenate(left_runs)

        # Right offsets: for each size (largest first) the run 0, ..., size,
        # then 20 zeros for the insertion entries.
        right_runs = []
        for size in range(self.MAX_DEL_SIZE, -1, -1):
            right_runs.append(np.arange(0, size + 1))
        right_runs.append(np.zeros(20, np.int64))
        self.rights = np.concatenate(right_runs)

        # Deletion entries carry an empty inserted sequence; their count is
        # the sum of (size + 1) over size = 0..MAX_DEL_SIZE.
        deletion_count = (self.MAX_DEL_SIZE + 2) * (self.MAX_DEL_SIZE + 1) // 2
        insertion_seqs = [
            "A", "C", "G", "T",
            "AA", "AC", "AG", "AT",
            "CA", "CC", "CG", "CT",
            "GA", "GC", "GG", "GT",
            "TA", "TC", "TG", "TT",
        ]
        self.inss = deletion_count * [""] + insertion_seqs

    @torch.no_grad()
    def __call__(self, batch):
        """Run the model on ``batch["feature"]`` and return outcome probabilities.

        Args:
            batch: Mapping with a ``"feature"`` tensor whose size along axis 1
                must equal the number of enumerated outcomes.

        Returns:
            A dict with ``"proba"`` (softmax over the last axis of the model's
            ``"logit"`` output) plus the ``"left"``, ``"right"`` and
            ``"ins_seq"`` tables built in ``__init__``.

        Raises:
            AssertionError: If axis 1 of the feature does not match the number
                of enumerated outcomes.
        """
        feature = batch["feature"]
        # NOTE(review): this check is an ``assert`` and is skipped when Python
        # runs with -O; kept as-is to preserve the raised exception type.
        assert feature.shape[1] == len(self.lefts), "the possible mutation number of the input feature does not fit the pipeline"

        model = self.FOREcasT_model
        # Move the input to the model's device before the forward pass.
        logit = model(feature.to(model.device))["logit"]
        proba = F.softmax(logit, dim=-1)

        return {
            "proba": proba,
            "left": self.lefts,
            "right": self.rights,
            "ins_seq": self.inss,
        }