File size: 2,660 Bytes
2890711
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63

import torch.nn as nn

from sgm.models.diffusion import DiffusionEngine
from sgm.motionctrl.modified_svd import (
                                         _forward_VideoTransformerBlock_attan2,
                                         forward_SpatialVideoTransformer,
                                         forward_VideoTransformerBlock,
                                         forward_VideoUnet)

class CameraMotionControl(DiffusionEngine):
    """DiffusionEngine variant that retrofits the video diffusion UNet with
    camera-pose conditioning (MotionCtrl-style).

    At construction time this class monkey-patches, on the instance level:
      * the UNet's ``forward`` (-> ``forward_VideoUnet``),
      * every ``VideoTransformerBlock``'s ``forward`` and ``_forward``
        (-> ``forward_VideoTransformerBlock`` / ``_forward_VideoTransformerBlock_attan2``),
      * every ``SpatialVideoTransformer``'s ``forward``
        (-> ``forward_SpatialVideoTransformer``),
    and attaches a ``cc_projection`` linear layer to each
    ``VideoTransformerBlock`` that maps the concatenation of the attn2 query
    features and a flattened pose vector back to the query feature width.
    The names of the modules intended for fine-tuning are collected in
    ``self.train_module_names``.
    """

    def __init__(self,
                 pose_embedding_dim = 1,
                 pose_dim = 12,
                 *args, **kwargs):
        """
        Args:
            pose_embedding_dim: number of pose embeddings per frame; the pose
                vector appended to attn2's query features has
                ``pose_embedding_dim * pose_dim`` elements.
            pose_dim: size of a single pose descriptor (12 — presumably a
                flattened 3x4 camera matrix; TODO confirm against the caller).
            *args, **kwargs: forwarded to ``DiffusionEngine.__init__``.
                ``kwargs['ckpt_path']`` (optional) is consumed here so the
                checkpoint is loaded only after the modules were patched.
                ``kwargs['network_config']['params']['use_checkpoint']`` must
                be present (raises KeyError otherwise, as before).
        """
        # Pop ckpt_path so the parent ctor never sees it; we load it ourselves
        # at the end, after cc_projection modules exist.
        ckpt_path = kwargs.pop('ckpt_path', None)

        self.use_checkpoint = kwargs['network_config']['params']['use_checkpoint']

        super().__init__(*args, **kwargs)

        # Bind the replacement forward as a method of THIS UNet instance via
        # the descriptor protocol; assigning to the instance shadows the class
        # method without affecting other instances.
        bound_method = forward_VideoUnet.__get__(
                self.model.diffusion_model,
                self.model.diffusion_model.__class__)
        setattr(self.model.diffusion_model, 'forward', bound_method)

        self.train_module_names = []
        for _name, _module in self.model.diffusion_model.named_modules():
            # Matching on the class NAME (not isinstance) keeps this decoupled
            # from the exact sgm module that defines the block.
            if _module.__class__.__name__ == 'VideoTransformerBlock':
                bound_method = forward_VideoTransformerBlock.__get__(
                    _module, _module.__class__)
                setattr(_module, 'forward', bound_method)

                bound_method = _forward_VideoTransformerBlock_attan2.__get__(
                    _module, _module.__class__)
                setattr(_module, '_forward', bound_method)

                # cc_projection: (query_dim + pose) -> query_dim.  The square
                # left part of the weight is initialised to the identity and
                # the bias to zero, so at step 0 the original query features
                # pass through unchanged; the pose columns keep nn.Linear's
                # default init.
                query_dim = _module.attn2.to_q.in_features  # 1024
                cc_projection = nn.Linear(
                    query_dim + pose_embedding_dim * pose_dim, query_dim)
                # Use the named weight/bias attributes instead of indexing
                # into parameters(), which relies on registration order.
                nn.init.eye_(cc_projection.weight[:query_dim, :query_dim])
                nn.init.zeros_(cc_projection.bias)

                cc_projection.requires_grad_(True)

                _module.add_module('cc_projection', cc_projection)

                self.train_module_names.append(f'{_name}.cc_projection')
                self.train_module_names.append(f'{_name}.attn2')
                self.train_module_names.append(f'{_name}.norm2')

            if _module.__class__.__name__ == 'SpatialVideoTransformer':
                bound_method = forward_SpatialVideoTransformer.__get__(
                    _module, _module.__class__)
                setattr(_module, 'forward', bound_method)

        # Load weights last so the patched/added modules are present.
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path)