yerang commited on
Commit
614c2f6
·
verified ·
1 Parent(s): 7931de6

Update stf_utils.py

Browse files
Files changed (1) hide show
  1. stf_utils.py +96 -109
stf_utils.py CHANGED
@@ -2,7 +2,7 @@ import torch
2
  import os
3
  from concurrent.futures import ThreadPoolExecutor
4
  from pydub import AudioSegment
5
- import cv2
6
  from pathlib import Path
7
  import subprocess
8
  from pathlib import Path
@@ -14,7 +14,6 @@ from tqdm import tqdm
14
 
15
  import stf_alternative
16
 
17
- import spaces
18
 
19
 
20
  def exec_cmd(cmd):
@@ -69,138 +68,126 @@ def merge_audio_video(video_fp, audio_fp, wfp):
69
 
70
 
71
  class STFPipeline:
72
- def __init__(
73
- self,
74
- stf_path: str = "/home/user/app/stf/",
75
- template_video_path: str = "templates/front_one_piece_dress_nodded_cut.webm",
76
- config_path: str = "front_config.json",
77
- checkpoint_path: str = "089.pth",
78
- root_path: str = "works",
79
- wavlm_path: str = "microsoft/wavlm-large",
80
- device: str = "cuda:0"
 
81
  ):
82
- self.device = device
83
- self.stf_path = stf_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  self.config_path = os.path.join(stf_path, config_path)
85
  self.checkpoint_path = os.path.join(stf_path, checkpoint_path)
86
- self.work_root_path = os.path.join(stf_path, root_path)
87
- self.wavlm_path = wavlm_path
88
- self.template_video_path = template_video_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
- # 비동기적으로 모델 로딩
91
- self.model = self.load_model()
92
- self.template = self.create_template()
93
 
94
- @spaces.GPU(duration=120)
95
- def load_model(self):
96
- """모델을 생성하고 GPU에 할당."""
97
  model = stf_alternative.create_model(
98
  config_path=self.config_path,
99
  checkpoint_path=self.checkpoint_path,
100
  work_root_path=self.work_root_path,
101
  device=self.device,
102
- wavlm_path=self.wavlm_path
103
  )
104
- return model
105
 
106
- @spaces.GPU(duration=120)
107
- def create_template(self):
108
- """템플릿 생성."""
109
- template = stf_alternative.Template(
110
- model=self.model,
111
  config_path=self.config_path,
112
- template_video_path=self.template_video_path
113
  )
114
- return template
115
 
116
- def execute(self, audio: str) -> str:
117
- """오디오를 입력 받아 비디오를 생성."""
118
- # 폴더 생성
119
- Path("dubbing").mkdir(exist_ok=True)
120
- save_path = os.path.join("dubbing", Path(audio).stem + "--lip.mp4")
121
 
 
 
 
 
 
 
122
  reader = iter(self.template._get_reader(num_skip_frames=0))
 
123
  audio_segment = AudioSegment.from_file(audio)
 
124
  results = []
125
 
126
- # 비동기 프레임 생성
127
- with ThreadPoolExecutor(max_workers=4) as executor:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  try:
 
129
  gen_infer = self.template.gen_infer_concurrent(
130
- executor, audio_segment, 0
 
 
131
  )
132
- for idx, (it, _) in enumerate(gen_infer):
133
  frame = next(reader)
134
  composed = self.template.compose(idx, frame, it)
135
- results.append(it["pred"])
136
- except StopIteration:
 
 
137
  pass
138
 
139
- self.images_to_video(results, save_path)
140
- return save_path
141
-
142
- @staticmethod
143
- def images_to_video(images, output_path, fps=24):
144
- """이미지 배열을 비디오로 변환."""
145
- writer = imageio.get_writer(output_path, fps=fps, format="mp4", codec="libx264")
146
- for i in track(range(len(images)), description="비디오 생성 중"):
147
- writer.append_data(images[i])
148
- writer.close()
149
- print(f"비디오 저장 완료: {output_path}")
150
-
151
- # class STFPipeline:
152
- # def __init__(self,
153
- # stf_path: str = "/home/user/app/stf/",
154
- # device: str = "cuda:0",
155
- # template_video_path: str = "templates/front_one_piece_dress_nodded_cut.webm",
156
- # config_path: str = "front_config.json",
157
- # checkpoint_path: str = "089.pth",
158
- # root_path: str = "works"
159
-
160
- # ):
161
-
162
- # config_path = os.path.join(stf_path, config_path)
163
- # checkpoint_path = os.path.join(stf_path, checkpoint_path)
164
- # work_root_path = os.path.join(stf_path, root_path)
165
-
166
- # model = stf_alternative.create_model(
167
- # config_path=config_path,
168
- # checkpoint_path=checkpoint_path,
169
- # work_root_path=work_root_path,
170
- # device=device,
171
- # wavlm_path="microsoft/wavlm-large",
172
- # )
173
- # self.template = stf_alternative.Template(
174
- # model=model,
175
- # config_path=config_path,
176
- # template_video_path=template_video_path,
177
- # )
178
-
179
-
180
- # def execute(self, audio: str):
181
- # Path("dubbing").mkdir(exist_ok=True)
182
- # save_path = os.path.join("dubbing", Path(audio).stem+"--lip.mp4")
183
- # reader = iter(self.template._get_reader(num_skip_frames=0))
184
- # audio_segment = AudioSegment.from_file(audio)
185
- # pivot = 0
186
- # results = []
187
- # with ThreadPoolExecutor(4) as p:
188
- # try:
189
-
190
- # gen_infer = self.template.gen_infer_concurrent(
191
- # p,
192
- # audio_segment,
193
- # pivot,
194
- # )
195
- # for idx, (it, chunk) in enumerate(gen_infer, pivot):
196
- # frame = next(reader)
197
- # composed = self.template.compose(idx, frame, it)
198
- # frame_name = f"{idx}".zfill(5)+".jpg"
199
- # results.append(it['pred'])
200
- # pivot = idx + 1
201
- # except StopIteration as e:
202
- # pass
203
-
204
- # images2video(results, save_path)
205
 
206
- # return save_path
 
2
  import os
3
  from concurrent.futures import ThreadPoolExecutor
4
  from pydub import AudioSegment
5
+ import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False)
6
  from pathlib import Path
7
  import subprocess
8
  from pathlib import Path
 
14
 
15
  import stf_alternative
16
 
 
17
 
18
 
19
  def exec_cmd(cmd):
 
68
 
69
 
70
  class STFPipeline:
71
+ def __init__(self,
72
+ stf_path: str = "/home/user/app/stf/",
73
+ device: str = "cuda:0",
74
+ template_video_path: str = "templates/front_one_piece_dress_nodded_cut.webm",
75
+ config_path: str = "front_config.json",
76
+ checkpoint_path: str = "089.pth",
77
+ #root_path: str = "works"
78
+ root_path: str = "/tmp/works",
79
+ male : bool = False
80
+
81
  ):
82
+ #os.makedirs(root_path, exist_ok=True)
83
+ import shutil; shutil.copytree('/home/user/app/stf/works', '/tmp/works', dirs_exist_ok=True)
84
+
85
+ import zipfile
86
+
87
+ if not male:
88
+ dir_zip='/tmp/works/preprocess/nasilhong_f_v1_front/crop_video_front_one_piece_dress_nodded_cut.zip'
89
+ dir_target='/tmp/works/preprocess/nasilhong_f_v1_front/'
90
+ zipfile.ZipFile(dir_zip, 'r').extractall(dir_target)
91
+
92
+ dir_zip='/tmp/works/preprocess/nasilhong_f_v1_front/front_one_piece_dress_nodded_cut.zip'
93
+ dir_target='/tmp/works/preprocess/nasilhong_f_v1_front/'
94
+ zipfile.ZipFile(dir_zip, 'r').extractall(dir_target)
95
+ else:
96
+ dir_zip='/tmp/works/preprocess/Ian_v3_front/crop_video_Cam2_2309071202_0012_Natural_Looped.zip'
97
+ dir_target='/tmp/works/preprocess/Ian_v3_front/'
98
+ zipfile.ZipFile(dir_zip, 'r').extractall(dir_target)
99
+
100
+ dir_zip='/tmp/works/preprocess/Ian_v3_front/Cam2_2309071202_0012_Natural_Looped.zip'
101
+ dir_target='/tmp/works/preprocess/Ian_v3_front/'
102
+ zipfile.ZipFile(dir_zip, 'r').extractall(dir_target)
103
+
104
+
105
  self.config_path = os.path.join(stf_path, config_path)
106
  self.checkpoint_path = os.path.join(stf_path, checkpoint_path)
107
+ #self.work_root_path = os.path.join(stf_path, root_path)
108
+ self.work_root_path = os.path.join(root_path)
109
+ self.device = device
110
+ self.template_video_path=os.path.join(stf_path, template_video_path)
111
+
112
+ # model = stf_alternative.create_model(
113
+ # config_path=config_path,
114
+ # checkpoint_path=checkpoint_path,
115
+ # work_root_path=work_root_path,
116
+ # device=device,
117
+ # wavlm_path="microsoft/wavlm-large",
118
+ # )
119
+ # self.template = stf_alternative.Template(
120
+ # model=model,
121
+ # config_path=config_path,
122
+ # template_video_path=template_video_path,
123
+ # )
124
 
125
+
 
 
126
 
127
+ def execute(self, audio: str):
128
+
129
+
130
  model = stf_alternative.create_model(
131
  config_path=self.config_path,
132
  checkpoint_path=self.checkpoint_path,
133
  work_root_path=self.work_root_path,
134
  device=self.device,
135
+ wavlm_path="microsoft/wavlm-large",
136
  )
 
137
 
138
+
139
+ self.template = stf_alternative.Template(
140
+ model=model,
 
 
141
  config_path=self.config_path,
142
+ template_video_path=self.template_video_path,
143
  )
 
144
 
 
 
 
 
 
145
 
146
+
147
+ # Path("dubbing").mkdir(exist_ok=True)
148
+ # save_path = os.path.join("dubbing", Path(audio).stem+"--lip.mp4")
149
+ Path("/tmp/dubbing").mkdir(exist_ok=True)
150
+ save_path = os.path.join("/tmp/dubbing", Path(audio).stem+"--lip.mp4")
151
+
152
  reader = iter(self.template._get_reader(num_skip_frames=0))
153
+
154
  audio_segment = AudioSegment.from_file(audio)
155
+ pivot = 0
156
  results = []
157
 
158
+ # try:
159
+
160
+ # gen_infer = self.template.gen_infer(
161
+ # audio_segment,
162
+ # pivot,
163
+ # )
164
+ # for idx, (it, chunk) in enumerate(gen_infer, pivot):
165
+ # frame = next(reader)
166
+ # composed = self.template.compose(idx, frame, it)
167
+ # frame_name = f"{idx}".zfill(5)+".jpg"
168
+ # results.append(it['pred'])
169
+ # pivot = idx + 1
170
+ # except StopIteration as e:
171
+ # pass
172
+
173
+
174
+ with ThreadPoolExecutor(1) as p:
175
  try:
176
+
177
  gen_infer = self.template.gen_infer_concurrent(
178
+ p,
179
+ audio_segment,
180
+ pivot,
181
  )
182
+ for idx, (it, chunk) in enumerate(gen_infer, pivot):
183
  frame = next(reader)
184
  composed = self.template.compose(idx, frame, it)
185
+ frame_name = f"{idx}".zfill(5)+".jpg"
186
+ results.append(it['pred'])
187
+ pivot = idx + 1
188
+ except StopIteration as e:
189
  pass
190
 
191
+ images2video(results, save_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
 
193
+ return save_path