File size: 2,298 Bytes
a220803
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch
from diffusers import ModelMixin, ConfigMixin
from torch import nn
import os
import json
import pytorch_lightning as pl
from diffusers.configuration_utils import ConfigMixin
from diffusers.models.modeling_utils import ModelMixin


class VideoBaseAE(ModelMixin, ConfigMixin):
    """Common base for video autoencoders built on the diffusers mixins.

    Only :meth:`load_from_checkpoint` carries real logic. ``encode``,
    ``decode`` and ``download_and_load_model`` are no-op placeholders that
    return ``None``; presumably concrete subclasses override them.
    """

    _supports_gradient_checkpointing = False

    def __init__(self, *args, **kwargs) -> None:
        """Forward every argument unchanged to the parent initialisers."""
        super().__init__(*args, **kwargs)

    @classmethod
    def load_from_checkpoint(cls, model_path):
        """Build a model from a directory holding a config and weights.

        Args:
            model_path: Directory containing ``config.json`` and
                ``pytorch_model.bin``.

        Returns:
            A new instance of ``cls`` with the stored weights loaded.
        """
        config_file = os.path.join(model_path, "config.json")
        weights_file = os.path.join(model_path, "pytorch_model.bin")

        with open(config_file, "r") as handle:
            config_kwargs = json.load(handle)

        # NOTE(review): torch.load unpickles the file, so only point this at
        # checkpoints from a trusted source.
        checkpoint = torch.load(weights_file, map_location="cpu")
        # Some checkpoints nest the weights under a 'state_dict' key; others
        # are the weight mapping itself.
        weights = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint

        # CONFIGURATION_CLS is not defined on this class; presumably each
        # subclass supplies its own config type - verify against subclasses.
        instance = cls(config=cls.CONFIGURATION_CLS(**config_kwargs))
        instance.load_state_dict(weights)
        return instance

    @classmethod
    def download_and_load_model(cls, model_name, cache_dir=None):
        """Placeholder: does nothing and returns ``None``."""
        return None

    def encode(self, x: torch.Tensor, *args, **kwargs):
        """Placeholder: does nothing and returns ``None``."""
        return None

    def decode(self, encoding: torch.Tensor, *args, **kwargs):
        """Placeholder: does nothing and returns ``None``."""
        return None

class VideoBaseAE_PL(pl.LightningModule, ModelMixin, ConfigMixin):
    """Lightning-module base for video autoencoders.

    ``encode`` and ``decode`` are no-op placeholders that return ``None``;
    presumably concrete subclasses override them. The class also offers
    :attr:`num_training_steps`, an estimate of the total number of
    optimizer steps derived from the attached trainer.
    """

    config_name = "config.json"

    def __init__(self, *args, **kwargs) -> None:
        """Forward every argument unchanged to the parent initialisers."""
        super().__init__(*args, **kwargs)

    def encode(self, x: torch.Tensor, *args, **kwargs):
        """Placeholder: does nothing and returns ``None``."""
        pass

    def decode(self, encoding: torch.Tensor, *args, **kwargs):
        """Placeholder: does nothing and returns ``None``."""
        pass

    def _num_devices(self) -> int:
        """Return the trainer's device count (at least 1) across Lightning versions."""
        trainer = self.trainer

        # Newer Lightning exposes a single `num_devices` attribute and has
        # dropped `num_gpus` / `num_processes` / `tpu_cores`.
        modern_count = getattr(trainer, "num_devices", None)
        if modern_count is not None:
            return max(1, modern_count)

        # Legacy trainers: take the largest of the per-accelerator counts.
        count = max(1, trainer.num_gpus, trainer.num_processes)
        if trainer.tpu_cores:
            count = max(count, trainer.tpu_cores)
        return count

    @property
    def num_training_steps(self) -> int:
        """Total training steps inferred from the trainer and dataloader.

        If the trainer has a positive ``max_steps`` it is returned as-is.
        Otherwise the count is
        ``(batches // (accumulate_grad_batches * devices)) * max_epochs``,
        where ``batches`` is the train dataloader length after applying
        ``limit_train_batches`` (an ``int`` is an absolute cap, anything
        else is treated as a fraction of the dataloader length).

        The epoch-based result is only meaningful when ``max_epochs`` is a
        positive integer.
        """
        trainer = self.trainer

        # `max_steps` may be None, 0 or -1 when no step limit is set; a
        # plain truthiness check would wrongly return -1.
        max_steps = trainer.max_steps
        if max_steps is not None and max_steps > 0:
            return max_steps

        limit_batches = trainer.limit_train_batches
        batches = len(self.train_dataloader())
        if isinstance(limit_batches, int):
            batches = min(batches, limit_batches)
        else:
            batches = int(limit_batches * batches)

        # One optimizer step consumes this many batches across all devices.
        effective_accum = trainer.accumulate_grad_batches * self._num_devices()
        return (batches // effective_accum) * trainer.max_epochs