Haoxin Chen commited on
Commit
2959057
·
1 Parent(s): 15190a9

fix gpu OOM

Browse files
Files changed (2) hide show
  1. i2v_test.py +3 -1
  2. t2v_test.py +3 -1
i2v_test.py CHANGED
@@ -22,7 +22,7 @@ class Image2Video():
22
  model_list = []
23
  for gpu_id in range(gpu_num):
24
  model = instantiate_from_config(model_config)
25
- model = model.cuda(gpu_id)
26
  assert os.path.exists(ckpt_path), "Error: checkpoint Not Found!"
27
  model = load_model_checkpoint(model, ckpt_path)
28
  model.eval()
@@ -38,6 +38,7 @@ class Image2Video():
38
  if steps > 60:
39
  steps = 60
40
  model = self.model_list[gpu_id]
 
41
  batch_size=1
42
  channels = model.model.diffusion_model.in_channels
43
  frames = model.temporal_length
@@ -65,6 +66,7 @@ class Image2Video():
65
 
66
  save_videos(batch_samples, self.result_dir, filenames=[prompt_str], fps=self.save_fps)
67
  print(f"Saved in {prompt_str}. Time used: {(time.time() - start):.2f} seconds")
 
68
  return os.path.join(self.result_dir, f"{prompt_str}.mp4")
69
 
70
  def download_model(self):
 
22
  model_list = []
23
  for gpu_id in range(gpu_num):
24
  model = instantiate_from_config(model_config)
25
+ # model = model.cuda(gpu_id)
26
  assert os.path.exists(ckpt_path), "Error: checkpoint Not Found!"
27
  model = load_model_checkpoint(model, ckpt_path)
28
  model.eval()
 
38
  if steps > 60:
39
  steps = 60
40
  model = self.model_list[gpu_id]
41
+ model = model.cuda()
42
  batch_size=1
43
  channels = model.model.diffusion_model.in_channels
44
  frames = model.temporal_length
 
66
 
67
  save_videos(batch_samples, self.result_dir, filenames=[prompt_str], fps=self.save_fps)
68
  print(f"Saved in {prompt_str}. Time used: {(time.time() - start):.2f} seconds")
69
+ model = model.cpu()
70
  return os.path.join(self.result_dir, f"{prompt_str}.mp4")
71
 
72
  def download_model(self):
t2v_test.py CHANGED
@@ -20,7 +20,7 @@ class Text2Video():
20
  model_list = []
21
  for gpu_id in range(gpu_num):
22
  model = instantiate_from_config(model_config)
23
- model = model.cuda(gpu_id)
24
  assert os.path.exists(ckpt_path), "Error: checkpoint Not Found!"
25
  model = load_model_checkpoint(model, ckpt_path)
26
  model.eval()
@@ -36,6 +36,7 @@ class Text2Video():
36
  if steps > 60:
37
  steps = 60
38
  model = self.model_list[gpu_id]
 
39
  batch_size=1
40
  channels = model.model.diffusion_model.in_channels
41
  frames = model.temporal_length
@@ -56,6 +57,7 @@ class Text2Video():
56
 
57
  save_videos(batch_samples, self.result_dir, filenames=[prompt_str], fps=self.save_fps)
58
  print(f"Saved in {prompt_str}. Time used: {(time.time() - start):.2f} seconds")
 
59
  return os.path.join(self.result_dir, f"{prompt_str}.mp4")
60
 
61
  def download_model(self):
 
20
  model_list = []
21
  for gpu_id in range(gpu_num):
22
  model = instantiate_from_config(model_config)
23
+ # model = model.cuda(gpu_id)
24
  assert os.path.exists(ckpt_path), "Error: checkpoint Not Found!"
25
  model = load_model_checkpoint(model, ckpt_path)
26
  model.eval()
 
36
  if steps > 60:
37
  steps = 60
38
  model = self.model_list[gpu_id]
39
+ model = model.cuda()
40
  batch_size=1
41
  channels = model.model.diffusion_model.in_channels
42
  frames = model.temporal_length
 
57
 
58
  save_videos(batch_samples, self.result_dir, filenames=[prompt_str], fps=self.save_fps)
59
  print(f"Saved in {prompt_str}. Time used: {(time.time() - start):.2f} seconds")
60
+ model=model.cpu()
61
  return os.path.join(self.result_dir, f"{prompt_str}.mp4")
62
 
63
  def download_model(self):