File size: 463 Bytes
44bb39d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 |
import torch
from diffusers import DiffusionPipeline
class CRISPRTransformerPipeline(DiffusionPipeline):
    """Inference-only pipeline around a single CRISPR transformer model.

    The model is registered with ``register_modules`` under the name
    ``CRISPR_transformer_model``; ``__call__`` then reaches it through the
    attribute of that same name.
    """

    def __init__(self, CRISPR_transformer_model):
        """Register the wrapped model as this pipeline's only component.

        Args:
            CRISPR_transformer_model: the model to run. ``__call__`` reads its
                ``.device`` attribute and indexes its output with ``"logit"``,
                so it is assumed to expose both (TODO confirm against the
                model class).
        """
        super().__init__()
        self.register_modules(CRISPR_transformer_model=CRISPR_transformer_model)

    @torch.no_grad()  # inference only: no autograd graph is recorded
    def __call__(self, batch):
        """Run the registered model on ``batch["refcode"]`` and return logits.

        Args:
            batch: mapping whose ``"refcode"`` entry is the model input;
                presumably a tensor, since it must support ``.to(device)``.

        Returns:
            A dict with one key, ``"logit"``, holding the ``"logit"`` entry of
            the model's output.
        """
        net = self.CRISPR_transformer_model
        # Put the input on whatever device the model currently lives on.
        refcode = batch["refcode"].to(net.device)
        outputs = net(refcode)
        return {"logit": outputs["logit"]}