DmitrMakeev committed on
Commit f03fe1a
1 Parent(s): 0b2b527

Upload interface.py

Files changed (1)
  1. tools/interface.py +190 -0
tools/interface.py ADDED
@@ -0,0 +1,190 @@
+import os
+
+from skimage import io, img_as_float32
+import cv2
+import torch
+import numpy as np
+import subprocess
+import pandas
+from models.audio2pose import audio2poseLSTM
+from scipy.io import wavfile
+import python_speech_features
+import pyworld
+import config
+import json
+from scipy.interpolate import interp1d
+
+def inter_pitch(y, y_flag):
+    # Fill unvoiced gaps (y_flag == 0) in the F0 contour by linear interpolation
+    # between the surrounding voiced frames; edge gaps are held constant.
+    frame_num = y.shape[0]
+    i = 0
+    last = -1
+    while i < frame_num:
+        if y_flag[i] == 0:
+            # Advance i to the end of the current unvoiced run.
+            while True:
+                if y_flag[i] == 0:
+                    if i == frame_num - 1:
+                        if last != -1:
+                            y[last + 1:] = y[last]
+                        i += 1
+                        break
+                    i += 1
+                else:
+                    break
+            if i >= frame_num:
+                break
+            elif last == -1:
+                # Gap at the very start: back-fill with the first voiced value.
+                y[:i] = y[i]
+            else:
+                inter_num = i - last + 1
+                fy = np.array([y[last], y[i]])
+                fx = np.linspace(0, 1, num=2)
+                f = interp1d(fx, fy)
+                fx_new = np.linspace(0, 1, inter_num)
+                fy_new = f(fx_new)
+                y[last + 1:i] = fy_new[1:-1]
+                last = i
+                i += 1
+        else:
+            last = i
+            i += 1
+    return y
+
+
+def load_ckpt(checkpoint_path, generator=None, kp_detector=None, ph2kp=None):
+    # Restore whichever sub-networks are passed in from a single checkpoint file.
+    checkpoint = torch.load(checkpoint_path)
+    if ph2kp is not None:
+        ph2kp.load_state_dict(checkpoint['ph2kp'])
+    if generator is not None:
+        generator.load_state_dict(checkpoint['generator'])
+    if kp_detector is not None:
+        kp_detector.load_state_dict(checkpoint['kp_detector'])
+
+def get_img_pose(img_path):
+    # Extract the head pose of a single image with the OpenFace binary,
+    # which writes its results as a CSV file into tmp_dir.
+    processor = config.OPENFACE_POSE_EXTRACTOR_PATH
+
+    tmp_dir = "samples/tmp_dir"
+    os.makedirs(tmp_dir, exist_ok=True)
+    subprocess.call([processor, "-f", img_path, "-out_dir", tmp_dir, "-pose"])
+
+    img_file = os.path.basename(img_path)[:-4] + ".csv"
+    csv_file = os.path.join(tmp_dir, img_file)
+    pos_data = pandas.read_csv(csv_file)
+    i = 0
+    pose = [pos_data["pose_Rx"][i], pos_data["pose_Ry"][i], pos_data["pose_Rz"][i],
+            pos_data["pose_Tx"][i], pos_data["pose_Ty"][i], pos_data["pose_Tz"][i]]
+    pose = np.array(pose, dtype=np.float32)
+    return pose
+
+def read_img(path):
+    # Load an RGB image, resize to 256x256, and convert to a 1x3x256x256 float tensor.
+    img = io.imread(path)[:, :, :3]
+    img = cv2.resize(img, (256, 256))
+    img = np.array(img_as_float32(img))
+    img = img.transpose((2, 0, 1))
+    img = torch.from_numpy(img).unsqueeze(0)
+    return img
+
+
+def parse_phoneme_file(phoneme_path, use_index=True):
+    # Expand word-level phoneme timings ("ed" end times, in 10 ms units)
+    # into one phoneme label per 25 fps video frame (4 units per frame).
+    with open(phoneme_path, 'r') as f:
+        result_text = json.load(f)
+    frame_num = int(result_text[-1]['phones'][-1]['ed'] / 100 * 25)
+    phoneset_list = []
+    index = 0
+
+    word_len = len(result_text)
+    word_index = 0
+    phone_index = 0
+    cur_phone_list = result_text[0]["phones"]
+    phone_len = len(cur_phone_list)
+    cur_end = cur_phone_list[0]["ed"]
+
+    phone_list = []
+
+    phoneset_list.append(cur_phone_list[0]["ph"])
+    i = 0
+    while i < frame_num:
+        if i * 4 < cur_end:
+            phone_list.append(cur_phone_list[phone_index]["ph"])
+            i += 1
+        else:
+            # Current phoneme is exhausted: advance to the next phoneme or word.
+            phone_index += 1
+            if phone_index >= phone_len:
+                word_index += 1
+                if word_index >= word_len:
+                    # Past the last word: repeat the final phoneme.
+                    phone_list.append(cur_phone_list[-1]["ph"])
+                    i += 1
+                else:
+                    phone_index = 0
+                    cur_phone_list = result_text[word_index]["phones"]
+                    phone_len = len(cur_phone_list)
+                    cur_end = cur_phone_list[phone_index]["ed"]
+                    phoneset_list.append(cur_phone_list[phone_index]["ph"])
+                    index += 1
+            else:
+                cur_end = cur_phone_list[phone_index]["ed"]
+                phoneset_list.append(cur_phone_list[phone_index]["ph"])
+                index += 1
+
+    # Map phoneme symbols to integer indices via the project's phindex.json.
+    with open("phindex.json") as f:
+        ph2index = json.load(f)
+    if use_index:
+        phone_list = [ph2index[p] for p in phone_list]
+    saves = {"phone_list": phone_list}
+
+    return saves
+
+def get_audio_feature_from_audio(audio_path):
+    # Per-frame features: MFCCs, log filterbank energies, interpolated F0,
+    # and a voiced/unvoiced flag, concatenated along the feature axis.
+    sample_rate, audio = wavfile.read(audio_path)
+    if len(audio.shape) == 2:
+        # Stereo input: keep a single channel.
+        if np.min(audio[:, 0]) <= 0:
+            audio = audio[:, 1]
+        else:
+            audio = audio[:, 0]
+
+    # Normalize to zero mean and unit peak amplitude.
+    audio = audio - np.mean(audio)
+    audio = audio / np.max(np.abs(audio))
+    a = python_speech_features.mfcc(audio, sample_rate)
+    b = python_speech_features.logfbank(audio, sample_rate)
+    c, _ = pyworld.harvest(audio, sample_rate, frame_period=10)
+    c_flag = (c == 0.0) ^ 1
+    c = inter_pitch(c, c_flag)
+    c = np.expand_dims(c, axis=1)
+    c_flag = np.expand_dims(c_flag, axis=1)
+    frame_num = np.min([a.shape[0], b.shape[0], c.shape[0]])
+
+    cat = np.concatenate([a[:frame_num], b[:frame_num], c[:frame_num], c_flag[:frame_num]], axis=1)
+    return cat
+
+def get_pose_from_audio(img, audio, audio2pose):
+    # Predict a head-pose sequence (Rx, Ry, Rz, Tx, Ty, Tz) from audio features,
+    # grouping 4 audio frames (100 fps) per video frame (25 fps).
+    num_frame = len(audio) // 4
+
+    # Value ranges used to un-normalize the network output.
+    minv = np.array([-0.6, -0.6, -0.6, -128.0, -128.0, 128.0], dtype=np.float32)
+    maxv = np.array([0.6, 0.6, 0.6, 128.0, 128.0, 384.0], dtype=np.float32)
+    generator = audio2poseLSTM().cuda().eval()
+
+    ckpt_para = torch.load(audio2pose)
+    generator.load_state_dict(ckpt_para["generator"])
+
+    audio_seq = []
+    for i in range(num_frame):
+        audio_seq.append(audio[i * 4:i * 4 + 4])
+
+    audio = torch.from_numpy(np.array(audio_seq, dtype=np.float32)).unsqueeze(0).cuda()
+
+    x = {"img": img, "audio": audio}
+    poses = generator(x)
+
+    # Map the network output from [-1, 1] back to the pose value ranges.
+    poses = poses.cpu().data.numpy()[0]
+    poses = (poses + 1) / 2 * (maxv - minv) + minv
+
+    return poses
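
A minimal sketch of how these helpers might be chained for inference. The file paths and checkpoint name are placeholders for illustration, not part of this commit:

    # Hypothetical usage example -- paths below are assumptions.
    from tools.interface import (read_img, get_audio_feature_from_audio,
                                 get_pose_from_audio)

    img = read_img("samples/source.jpg").cuda()            # 1x3x256x256 tensor
    audio_feat = get_audio_feature_from_audio("samples/speech.wav")
    poses = get_pose_from_audio(img, audio_feat, "checkpoints/audio2pose.pth")
    print(poses.shape)  # (num_video_frames, 6): Rx, Ry, Rz, Tx, Ty, Tz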