qninhdt commited on
Commit
1ec8802
1 Parent(s): 9d93b93
Files changed (1) hide show
  1. baselines/cyclegan-cut/train.py +66 -31
baselines/cyclegan-cut/train.py CHANGED
@@ -4,31 +4,37 @@ from options.train_options import TrainOptions
4
  from data import create_dataset
5
  from models import create_model
6
  from util.visualizer import Visualizer
 
7
 
 
 
 
 
 
 
8
 
9
- if __name__ == '__main__':
10
- opt = TrainOptions().parse() # get training options
11
- dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options
12
- dataset_size = len(dataset) # get the number of images in the dataset.
13
 
14
- model = create_model(opt) # create a model given opt.model and other options
15
- print('The number of training images = %d' % dataset_size)
16
-
17
- visualizer = Visualizer(opt) # create a visualizer that display/save images and plots
18
  opt.visualizer = visualizer
19
- total_iters = 0 # the total number of training iterations
20
 
21
  optimize_time = 0.1
22
 
23
  times = []
24
- for epoch in range(opt.epoch_count, opt.n_epochs + opt.n_epochs_decay + 1): # outer loop for different epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>
 
 
25
  epoch_start_time = time.time() # timer for entire epoch
26
- iter_data_time = time.time() # timer for data loading per iteration
27
- epoch_iter = 0 # the number of training iterations in current epoch, reset to 0 every epoch
28
- visualizer.reset() # reset the visualizer: make sure it saves the results to HTML at least once every epoch
29
 
30
  dataset.set_epoch(epoch)
31
- for i, data in enumerate(dataset): # inner loop within one epoch
32
  iter_start_time = time.time() # timer for computation per iteration
33
  if total_iters % opt.print_freq == 0:
34
  t_data = iter_start_time - iter_data_time
@@ -41,37 +47,66 @@ if __name__ == '__main__':
41
  optimize_start_time = time.time()
42
  if epoch == opt.epoch_count and i == 0:
43
  model.data_dependent_initialize(data)
44
- model.setup(opt) # regular setup: load and print networks; create schedulers
 
 
45
  model.parallelize()
46
  model.set_input(data) # unpack data from dataset and apply preprocessing
47
- model.optimize_parameters() # calculate loss functions, get gradients, update network weights
48
  if len(opt.gpu_ids) > 0:
49
  torch.cuda.synchronize()
50
- optimize_time = (time.time() - optimize_start_time) / batch_size * 0.005 + 0.995 * optimize_time
 
 
51
 
52
- if total_iters % opt.display_freq == 0: # display images on visdom and save images to a HTML file
 
 
53
  save_result = total_iters % opt.update_html_freq == 0
54
  model.compute_visuals()
55
- visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)
 
 
56
 
57
- if total_iters % opt.print_freq == 0: # print training losses and save logging information to the disk
 
 
58
  losses = model.get_current_losses()
59
- visualizer.print_current_losses(epoch, epoch_iter, losses, optimize_time, t_data)
 
 
60
  if opt.display_id is None or opt.display_id > 0:
61
- visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses)
 
 
62
 
63
- if total_iters % opt.save_latest_freq == 0: # cache our latest model every <save_latest_freq> iterations
64
- print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
65
- print(opt.name) # it's useful to occasionally show the experiment name on console
66
- save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest'
 
 
 
 
 
 
 
67
  model.save_networks(save_suffix)
68
 
69
  iter_data_time = time.time()
70
 
71
- if epoch % opt.save_epoch_freq == 0: # cache our model every <save_epoch_freq> epochs
72
- print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))
73
- model.save_networks('latest')
 
 
 
 
 
74
  model.save_networks(epoch)
75
 
76
- print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.n_epochs + opt.n_epochs_decay, time.time() - epoch_start_time))
77
- model.update_learning_rate() # update learning rates at the end of every epoch.
 
 
 
 
4
  from data import create_dataset
5
  from models import create_model
6
  from util.visualizer import Visualizer
7
+ from tqdm.auto import tqdm
8
 
9
+ if __name__ == "__main__":
10
+ opt = TrainOptions().parse() # get training options
11
+ dataset = create_dataset(
12
+ opt
13
+ ) # create a dataset given opt.dataset_mode and other options
14
+ dataset_size = len(dataset) # get the number of images in the dataset.
15
 
16
+ model = create_model(opt) # create a model given opt.model and other options
17
+ print("The number of training images = %d" % dataset_size)
 
 
18
 
19
+ visualizer = Visualizer(
20
+ opt
21
+ ) # create a visualizer that display/save images and plots
 
22
  opt.visualizer = visualizer
23
+ total_iters = 0 # the total number of training iterations
24
 
25
  optimize_time = 0.1
26
 
27
  times = []
28
+ for epoch in range(
29
+ opt.epoch_count, opt.n_epochs + opt.n_epochs_decay + 1
30
+ ): # outer loop for different epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>
31
  epoch_start_time = time.time() # timer for entire epoch
32
+ iter_data_time = time.time() # timer for data loading per iteration
33
+ epoch_iter = 0 # the number of training iterations in current epoch, reset to 0 every epoch
34
+ visualizer.reset() # reset the visualizer: make sure it saves the results to HTML at least once every epoch
35
 
36
  dataset.set_epoch(epoch)
37
+ for i, data in tqdm(enumerate(dataset)): # inner loop within one epoch
38
  iter_start_time = time.time() # timer for computation per iteration
39
  if total_iters % opt.print_freq == 0:
40
  t_data = iter_start_time - iter_data_time
 
47
  optimize_start_time = time.time()
48
  if epoch == opt.epoch_count and i == 0:
49
  model.data_dependent_initialize(data)
50
+ model.setup(
51
+ opt
52
+ ) # regular setup: load and print networks; create schedulers
53
  model.parallelize()
54
  model.set_input(data) # unpack data from dataset and apply preprocessing
55
+ model.optimize_parameters() # calculate loss functions, get gradients, update network weights
56
  if len(opt.gpu_ids) > 0:
57
  torch.cuda.synchronize()
58
+ optimize_time = (
59
+ time.time() - optimize_start_time
60
+ ) / batch_size * 0.005 + 0.995 * optimize_time
61
 
62
+ if (
63
+ total_iters % opt.display_freq == 0
64
+ ): # display images on visdom and save images to a HTML file
65
  save_result = total_iters % opt.update_html_freq == 0
66
  model.compute_visuals()
67
+ visualizer.display_current_results(
68
+ model.get_current_visuals(), epoch, save_result
69
+ )
70
 
71
+ if (
72
+ total_iters % opt.print_freq == 0
73
+ ): # print training losses and save logging information to the disk
74
  losses = model.get_current_losses()
75
+ visualizer.print_current_losses(
76
+ epoch, epoch_iter, losses, optimize_time, t_data
77
+ )
78
  if opt.display_id is None or opt.display_id > 0:
79
+ visualizer.plot_current_losses(
80
+ epoch, float(epoch_iter) / dataset_size, losses
81
+ )
82
 
83
+ if (
84
+ total_iters % opt.save_latest_freq == 0
85
+ ): # cache our latest model every <save_latest_freq> iterations
86
+ print(
87
+ "saving the latest model (epoch %d, total_iters %d)"
88
+ % (epoch, total_iters)
89
+ )
90
+ print(
91
+ opt.name
92
+ ) # it's useful to occasionally show the experiment name on console
93
+ save_suffix = "iter_%d" % total_iters if opt.save_by_iter else "latest"
94
  model.save_networks(save_suffix)
95
 
96
  iter_data_time = time.time()
97
 
98
+ if (
99
+ epoch % opt.save_epoch_freq == 0
100
+ ): # cache our model every <save_epoch_freq> epochs
101
+ print(
102
+ "saving the model at the end of epoch %d, iters %d"
103
+ % (epoch, total_iters)
104
+ )
105
+ model.save_networks("latest")
106
  model.save_networks(epoch)
107
 
108
+ print(
109
+ "End of epoch %d / %d \t Time Taken: %d sec"
110
+ % (epoch, opt.n_epochs + opt.n_epochs_decay, time.time() - epoch_start_time)
111
+ )
112
+ model.update_learning_rate() # update learning rates at the end of every epoch.