ldkong commited on
Commit
11e4216
·
1 Parent(s): 7a902b1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -20
app.py CHANGED
@@ -6,28 +6,74 @@ import imageio
6
  import cv2
7
 
8
 
9
- class Generator(nn.Module):
10
- # Refer to the link below for explanations about nc, nz, and ngf
11
- # https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html#inputs
12
- def __init__(self, nc=4, nz=100, ngf=64):
13
- super(Generator, self).__init__()
14
- self.network = nn.Sequential(
15
- nn.ConvTranspose2d(nz, ngf * 4, 3, 1, 0, bias=False),
16
- nn.BatchNorm2d(ngf * 4),
17
- nn.ReLU(True),
18
- nn.ConvTranspose2d(ngf * 4, ngf * 2, 3, 2, 1, bias=False),
19
- nn.BatchNorm2d(ngf * 2),
20
- nn.ReLU(True),
21
- nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 0, bias=False),
22
- nn.BatchNorm2d(ngf),
23
- nn.ReLU(True),
24
- nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
25
- nn.Tanh(),
26
- )
 
27
 
28
  def forward(self, input):
29
- output = self.network(input)
30
- return output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
 
33
  def display_gif(file_name, save_name):
 
6
  import cv2
7
 
8
 
9
+ class RelationModuleMultiScale(torch.nn.Module):
10
+
11
+ def __init__(self, img_feature_dim, num_bottleneck, num_frames):
12
+ super(RelationModuleMultiScale, self).__init__()
13
+ self.subsample_num = 3
14
+ self.img_feature_dim = img_feature_dim
15
+ self.scales = [i for i in range(num_frames, 1, -1)]
16
+ self.relations_scales = []
17
+ self.subsample_scales = []
18
+ for scale in self.scales:
19
+ relations_scale = self.return_relationset(num_frames, scale)
20
+ self.relations_scales.append(relations_scale)
21
+ self.subsample_scales.append(min(self.subsample_num, len(relations_scale)))
22
+ self.num_frames = num_frames
23
+ self.fc_fusion_scales = nn.ModuleList() # high-tech modulelist
24
+ for i in range(len(self.scales)):
25
+ scale = self.scales[i]
26
+ fc_fusion = nn.Sequential(nn.ReLU(), nn.Linear(scale * self.img_feature_dim, num_bottleneck), nn.ReLU())
27
+ self.fc_fusion_scales += [fc_fusion]
28
 
29
  def forward(self, input):
30
+ act_scale_1 = input[:, self.relations_scales[0][0] , :]
31
+ act_scale_1 = act_scale_1.view(act_scale_1.size(0), self.scales[0] * self.img_feature_dim)
32
+ act_scale_1 = self.fc_fusion_scales[0](act_scale_1)
33
+ act_scale_1 = act_scale_1.unsqueeze(1)
34
+ act_all = act_scale_1.clone()
35
+ for scaleID in range(1, len(self.scales)):
36
+ act_relation_all = torch.zeros_like(act_scale_1)
37
+ num_total_relations = len(self.relations_scales[scaleID])
38
+ num_select_relations = self.subsample_scales[scaleID]
39
+ idx_relations_evensample = [int(ceil(i * num_total_relations / num_select_relations)) for i in range(num_select_relations)]
40
+ for idx in idx_relations_evensample:
41
+ act_relation = input[:, self.relations_scales[scaleID][idx], :]
42
+ act_relation = act_relation.view(act_relation.size(0), self.scales[scaleID] * self.img_feature_dim)
43
+ act_relation = self.fc_fusion_scales[scaleID](act_relation)
44
+ act_relation = act_relation.unsqueeze(1)
45
+ act_relation_all += act_relation
46
+ act_all = torch.cat((act_all, act_relation_all), 1)
47
+ return act_all
48
+
49
+ def return_relationset(self, num_frames, num_frames_relation):
50
+ import itertools
51
+ return list(itertools.combinations([i for i in range(num_frames)], num_frames_relation))
52
+
53
+
54
+ parser = argparse.ArgumentParser()
55
+ parser.add_argument('--dataset', default='Sprite', help='datasets')
56
+ parser.add_argument('--data_root', default='dataset', help='root directory for data')
57
+ parser.add_argument('--num_class', type=int, default=15, help='the number of class for jester dataset')
58
+ parser.add_argument('--input_type', default='image', choices=['feature', 'image'], help='the type of input')
59
+ parser.add_argument('--src', default='domain_1', help='source domain')
60
+ parser.add_argument('--tar', default='domain_2', help='target domain')
61
+ parser.add_argument('--num_segments', type=int, default=8, help='the number of frame segment')
62
+ parser.add_argument('--backbone', type=str, default="dcgan", choices=['dcgan', 'resnet101', 'I3Dpretrain','I3Dfinetune'], help='backbone')
63
+ parser.add_argument('--channels', default=3, type=int, help='input channels for image inputs')
64
+ parser.add_argument('--add_fc', default=1, type=int, metavar='M', help='number of additional fc layers (excluding the last fc layer) (e.g. 0, 1, 2)')
65
+ parser.add_argument('--fc_dim', type=int, default=1024, help='dimension of added fc')
66
+ parser.add_argument('--frame_aggregation', type=str, default='trn', choices=[ 'rnn', 'trn'], help='aggregation of frame features (none if baseline_type is not video)')
67
+ parser.add_argument('--dropout_rate', default=0.5, type=float, help='dropout ratio for frame-level feature (default: 0.5)')
68
+ parser.add_argument('--f_dim', type=int, default=512, help='dim of f')
69
+ parser.add_argument('--z_dim', type=int, default=512, help='dimensionality of z_t')
70
+ parser.add_argument('--f_rnn_layers', type=int, default=1, help='number of layers (content lstm)')
71
+ parser.add_argument('--use_bn', type=str, default='none', choices=['none', 'AdaBN', 'AutoDIAL'], help='normalization-based methods')
72
+ parser.add_argument('--prior_sample', type=str, default='random', choices=['random', 'post'], help='how to sample prior')
73
+ parser.add_argument('--batch_size', default=128, type=int, help='-batch size')
74
+ parser.add_argument('--use_attn', type=str, default='TransAttn', choices=['none', 'TransAttn', 'general'], help='attention-mechanism')
75
+ parser.add_argument('--data_threads', type=int, default=5, help='number of data loading threads')
76
+ opt = parser.parse_args(args=[])
77
 
78
 
79
  def display_gif(file_name, save_name):