Spaces:
Runtime error
Runtime error
Haoxin Chen
commited on
Commit
·
fe42b63
1
Parent(s):
2288e12
update ui and bug fix
Browse files- app.py +10 -6
- videocontrol_test.py +15 -6
app.py
CHANGED
@@ -40,7 +40,7 @@ def videocrafter_demo(result_dir='./tmp/'):
|
|
40 |
lora_scale = gr.Slider(minimum=0.0, maximum=2.0, step=0.1, label='Lora Scale', value=1.0, elem_id="lora_scale")
|
41 |
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=15.0, elem_id="cfg_scale")
|
42 |
send_btn = gr.Button("Send")
|
43 |
-
with gr.Tab(label='
|
44 |
output_video_1 = gr.Video().style(width=384)
|
45 |
gr.Examples(examples=t2v_examples,
|
46 |
inputs=[input_text,steps,model_index,eta,cfg_scale,lora_scale],
|
@@ -59,7 +59,9 @@ def videocrafter_demo(result_dir='./tmp/'):
|
|
59 |
with gr.Row():
|
60 |
# with gr.Tab(label='input'):
|
61 |
with gr.Column():
|
62 |
-
|
|
|
|
|
63 |
with gr.Row():
|
64 |
vc_input_text = gr.Text(label='Prompts')
|
65 |
with gr.Row():
|
@@ -72,16 +74,18 @@ def videocrafter_demo(result_dir='./tmp/'):
|
|
72 |
vc_end_btn = gr.Button("Send")
|
73 |
with gr.Tab(label='Result'):
|
74 |
vc_output_info = gr.Text(label='Info')
|
75 |
-
|
|
|
|
|
76 |
|
77 |
gr.Examples(examples=control_examples,
|
78 |
inputs=[vc_input_video, vc_input_text, frame_stride, vc_steps, vc_cfg_scale, vc_eta],
|
79 |
-
outputs=[vc_output_info, vc_output_video],
|
80 |
fn = videocontrol.get_video,
|
81 |
-
cache_examples=
|
82 |
)
|
83 |
vc_end_btn.click(inputs=[vc_input_video, vc_input_text, frame_stride, vc_steps, vc_cfg_scale, vc_eta],
|
84 |
-
outputs=[vc_output_info, vc_output_video],
|
85 |
fn = videocontrol.get_video
|
86 |
)
|
87 |
|
|
|
40 |
lora_scale = gr.Slider(minimum=0.0, maximum=2.0, step=0.1, label='Lora Scale', value=1.0, elem_id="lora_scale")
|
41 |
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=15.0, elem_id="cfg_scale")
|
42 |
send_btn = gr.Button("Send")
|
43 |
+
with gr.Tab(label='result'):
|
44 |
output_video_1 = gr.Video().style(width=384)
|
45 |
gr.Examples(examples=t2v_examples,
|
46 |
inputs=[input_text,steps,model_index,eta,cfg_scale,lora_scale],
|
|
|
59 |
with gr.Row():
|
60 |
# with gr.Tab(label='input'):
|
61 |
with gr.Column():
|
62 |
+
with gr.Row():
|
63 |
+
vc_input_video = gr.Video(label="Input Video").style(width=256)
|
64 |
+
vc_origin_video = gr.Video(label='Center-cropped Video').style(width=256)
|
65 |
with gr.Row():
|
66 |
vc_input_text = gr.Text(label='Prompts')
|
67 |
with gr.Row():
|
|
|
74 |
vc_end_btn = gr.Button("Send")
|
75 |
with gr.Tab(label='Result'):
|
76 |
vc_output_info = gr.Text(label='Info')
|
77 |
+
with gr.Row():
|
78 |
+
vc_depth_video = gr.Video(label="Depth Video").style(width=256)
|
79 |
+
vc_output_video = gr.Video(label="Generated Video").style(width=256)
|
80 |
|
81 |
gr.Examples(examples=control_examples,
|
82 |
inputs=[vc_input_video, vc_input_text, frame_stride, vc_steps, vc_cfg_scale, vc_eta],
|
83 |
+
outputs=[vc_output_info, vc_origin_video, vc_depth_video, vc_output_video],
|
84 |
fn = videocontrol.get_video,
|
85 |
+
cache_examples=os.getenv('SYSTEM') == 'spaces',
|
86 |
)
|
87 |
vc_end_btn.click(inputs=[vc_input_video, vc_input_text, frame_stride, vc_steps, vc_cfg_scale, vc_eta],
|
88 |
+
outputs=[vc_output_info, vc_origin_video, vc_depth_video, vc_output_video],
|
89 |
fn = videocontrol.get_video
|
90 |
)
|
91 |
|
videocontrol_test.py
CHANGED
@@ -70,7 +70,7 @@ class VideoControl:
|
|
70 |
h, w, c = VideoReader(input_video, ctx=cpu(0))[0].shape
|
71 |
except:
|
72 |
os.remove(input_video)
|
73 |
-
return 'please input video', None
|
74 |
|
75 |
if h < w:
|
76 |
scale = h / self.resolution
|
@@ -82,7 +82,7 @@ class VideoControl:
|
|
82 |
video, info_str = load_video(input_video, frame_stride, video_size=(h, w), video_frames=16)
|
83 |
except:
|
84 |
os.remove(input_video)
|
85 |
-
return 'load video error', None
|
86 |
video = self.spatial_transform(video)
|
87 |
print('video shape', video.shape)
|
88 |
|
@@ -103,14 +103,23 @@ class VideoControl:
|
|
103 |
filename = prompt
|
104 |
filename = filename.replace("/", "_slash_") if "/" in filename else filename
|
105 |
filename = filename.replace(" ", "_") if " " in filename else filename
|
|
|
|
|
106 |
video_path = os.path.join(self.savedir, f'{filename}_sample.mp4')
|
107 |
-
|
108 |
-
|
|
|
|
|
|
|
109 |
|
110 |
print(f"Saved in {video_path}. Time used: {(time.time() - start):.2f} seconds")
|
111 |
# delete video
|
112 |
-
os.
|
113 |
-
|
|
|
|
|
|
|
|
|
114 |
def download_model(self):
|
115 |
REPO_ID = 'VideoCrafter/t2v-version-1-1'
|
116 |
filename_list = ['models/base_t2v/model.ckpt',
|
|
|
70 |
h, w, c = VideoReader(input_video, ctx=cpu(0))[0].shape
|
71 |
except:
|
72 |
os.remove(input_video)
|
73 |
+
return 'please input video', None, None, None
|
74 |
|
75 |
if h < w:
|
76 |
scale = h / self.resolution
|
|
|
82 |
video, info_str = load_video(input_video, frame_stride, video_size=(h, w), video_frames=16)
|
83 |
except:
|
84 |
os.remove(input_video)
|
85 |
+
return 'load video error', None, None, None
|
86 |
video = self.spatial_transform(video)
|
87 |
print('video shape', video.shape)
|
88 |
|
|
|
103 |
filename = prompt
|
104 |
filename = filename.replace("/", "_slash_") if "/" in filename else filename
|
105 |
filename = filename.replace(" ", "_") if " " in filename else filename
|
106 |
+
if len(filename) > 200:
|
107 |
+
filename = filename[:200]
|
108 |
video_path = os.path.join(self.savedir, f'{filename}_sample.mp4')
|
109 |
+
depth_path = os.path.join(self.savedir, f'{filename}_depth.mp4')
|
110 |
+
origin_path = os.path.join(self.savedir, f'{filename}.mp4')
|
111 |
+
tensor_to_mp4(video=video.detach().cpu(), savepath=origin_path, fps=8)
|
112 |
+
tensor_to_mp4(video=batch_conds.detach().cpu(), savepath=depth_path, fps=8)
|
113 |
+
tensor_to_mp4(video=batch_samples.detach().cpu(), savepath=video_path, fps=8)
|
114 |
|
115 |
print(f"Saved in {video_path}. Time used: {(time.time() - start):.2f} seconds")
|
116 |
# delete video
|
117 |
+
(path, input_filename) = os.path.split(input_video)
|
118 |
+
if input_filename != 'flamingo.mp4':
|
119 |
+
os.remove(input_video)
|
120 |
+
print('delete input video')
|
121 |
+
# print(input_video)
|
122 |
+
return info_str, origin_path, depth_path, video_path
|
123 |
def download_model(self):
|
124 |
REPO_ID = 'VideoCrafter/t2v-version-1-1'
|
125 |
filename_list = ['models/base_t2v/model.ckpt',
|