Haoxin Chen commited on
Commit
fe42b63
·
1 Parent(s): 2288e12

update ui and bug fix

Browse files
Files changed (2) hide show
  1. app.py +10 -6
  2. videocontrol_test.py +15 -6
app.py CHANGED
@@ -40,7 +40,7 @@ def videocrafter_demo(result_dir='./tmp/'):
40
  lora_scale = gr.Slider(minimum=0.0, maximum=2.0, step=0.1, label='Lora Scale', value=1.0, elem_id="lora_scale")
41
  cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=15.0, elem_id="cfg_scale")
42
  send_btn = gr.Button("Send")
43
- with gr.Tab(label='show'):
44
  output_video_1 = gr.Video().style(width=384)
45
  gr.Examples(examples=t2v_examples,
46
  inputs=[input_text,steps,model_index,eta,cfg_scale,lora_scale],
@@ -59,7 +59,9 @@ def videocrafter_demo(result_dir='./tmp/'):
59
  with gr.Row():
60
  # with gr.Tab(label='input'):
61
  with gr.Column():
62
- vc_input_video = gr.Video().style(width=256)
 
 
63
  with gr.Row():
64
  vc_input_text = gr.Text(label='Prompts')
65
  with gr.Row():
@@ -72,16 +74,18 @@ def videocrafter_demo(result_dir='./tmp/'):
72
  vc_end_btn = gr.Button("Send")
73
  with gr.Tab(label='Result'):
74
  vc_output_info = gr.Text(label='Info')
75
- vc_output_video = gr.Video().style(width=384)
 
 
76
 
77
  gr.Examples(examples=control_examples,
78
  inputs=[vc_input_video, vc_input_text, frame_stride, vc_steps, vc_cfg_scale, vc_eta],
79
- outputs=[vc_output_info, vc_output_video],
80
  fn = videocontrol.get_video,
81
- cache_examples=False
82
  )
83
  vc_end_btn.click(inputs=[vc_input_video, vc_input_text, frame_stride, vc_steps, vc_cfg_scale, vc_eta],
84
- outputs=[vc_output_info, vc_output_video],
85
  fn = videocontrol.get_video
86
  )
87
 
 
40
  lora_scale = gr.Slider(minimum=0.0, maximum=2.0, step=0.1, label='Lora Scale', value=1.0, elem_id="lora_scale")
41
  cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=15.0, elem_id="cfg_scale")
42
  send_btn = gr.Button("Send")
43
+ with gr.Tab(label='result'):
44
  output_video_1 = gr.Video().style(width=384)
45
  gr.Examples(examples=t2v_examples,
46
  inputs=[input_text,steps,model_index,eta,cfg_scale,lora_scale],
 
59
  with gr.Row():
60
  # with gr.Tab(label='input'):
61
  with gr.Column():
62
+ with gr.Row():
63
+ vc_input_video = gr.Video(label="Input Video").style(width=256)
64
+ vc_origin_video = gr.Video(label='Center-cropped Video').style(width=256)
65
  with gr.Row():
66
  vc_input_text = gr.Text(label='Prompts')
67
  with gr.Row():
 
74
  vc_end_btn = gr.Button("Send")
75
  with gr.Tab(label='Result'):
76
  vc_output_info = gr.Text(label='Info')
77
+ with gr.Row():
78
+ vc_depth_video = gr.Video(label="Depth Video").style(width=256)
79
+ vc_output_video = gr.Video(label="Generated Video").style(width=256)
80
 
81
  gr.Examples(examples=control_examples,
82
  inputs=[vc_input_video, vc_input_text, frame_stride, vc_steps, vc_cfg_scale, vc_eta],
83
+ outputs=[vc_output_info, vc_origin_video, vc_depth_video, vc_output_video],
84
  fn = videocontrol.get_video,
85
+ cache_examples=os.getenv('SYSTEM') == 'spaces',
86
  )
87
  vc_end_btn.click(inputs=[vc_input_video, vc_input_text, frame_stride, vc_steps, vc_cfg_scale, vc_eta],
88
+ outputs=[vc_output_info, vc_origin_video, vc_depth_video, vc_output_video],
89
  fn = videocontrol.get_video
90
  )
91
 
videocontrol_test.py CHANGED
@@ -70,7 +70,7 @@ class VideoControl:
70
  h, w, c = VideoReader(input_video, ctx=cpu(0))[0].shape
71
  except:
72
  os.remove(input_video)
73
- return 'please input video', None
74
 
75
  if h < w:
76
  scale = h / self.resolution
@@ -82,7 +82,7 @@ class VideoControl:
82
  video, info_str = load_video(input_video, frame_stride, video_size=(h, w), video_frames=16)
83
  except:
84
  os.remove(input_video)
85
- return 'load video error', None
86
  video = self.spatial_transform(video)
87
  print('video shape', video.shape)
88
 
@@ -103,14 +103,23 @@ class VideoControl:
103
  filename = prompt
104
  filename = filename.replace("/", "_slash_") if "/" in filename else filename
105
  filename = filename.replace(" ", "_") if " " in filename else filename
 
 
106
  video_path = os.path.join(self.savedir, f'{filename}_sample.mp4')
107
- # tensor_to_mp4(video=batch_conds.detach().cpu(), savepath=os.path.join(self.savedir, f'{filename}_depth.mp4'), fps=10)
108
- tensor_to_mp4(video=batch_samples.detach().cpu(), savepath=os.path.join(self.savedir, f'{filename}_sample.mp4'), fps=8)
 
 
 
109
 
110
  print(f"Saved in {video_path}. Time used: {(time.time() - start):.2f} seconds")
111
  # delete video
112
- os.remove(input_video)
113
- return info_str, video_path
 
 
 
 
114
  def download_model(self):
115
  REPO_ID = 'VideoCrafter/t2v-version-1-1'
116
  filename_list = ['models/base_t2v/model.ckpt',
 
70
  h, w, c = VideoReader(input_video, ctx=cpu(0))[0].shape
71
  except:
72
  os.remove(input_video)
73
+ return 'please input video', None, None, None
74
 
75
  if h < w:
76
  scale = h / self.resolution
 
82
  video, info_str = load_video(input_video, frame_stride, video_size=(h, w), video_frames=16)
83
  except:
84
  os.remove(input_video)
85
+ return 'load video error', None, None, None
86
  video = self.spatial_transform(video)
87
  print('video shape', video.shape)
88
 
 
103
  filename = prompt
104
  filename = filename.replace("/", "_slash_") if "/" in filename else filename
105
  filename = filename.replace(" ", "_") if " " in filename else filename
106
+ if len(filename) > 200:
107
+ filename = filename[:200]
108
  video_path = os.path.join(self.savedir, f'{filename}_sample.mp4')
109
+ depth_path = os.path.join(self.savedir, f'{filename}_depth.mp4')
110
+ origin_path = os.path.join(self.savedir, f'{filename}.mp4')
111
+ tensor_to_mp4(video=video.detach().cpu(), savepath=origin_path, fps=8)
112
+ tensor_to_mp4(video=batch_conds.detach().cpu(), savepath=depth_path, fps=8)
113
+ tensor_to_mp4(video=batch_samples.detach().cpu(), savepath=video_path, fps=8)
114
 
115
  print(f"Saved in {video_path}. Time used: {(time.time() - start):.2f} seconds")
116
  # delete video
117
+ (path, input_filename) = os.path.split(input_video)
118
+ if input_filename != 'flamingo.mp4':
119
+ os.remove(input_video)
120
+ print('delete input video')
121
+ # print(input_video)
122
+ return info_str, origin_path, depth_path, video_path
123
  def download_model(self):
124
  REPO_ID = 'VideoCrafter/t2v-version-1-1'
125
  filename_list = ['models/base_t2v/model.ckpt',