BartPoint commited on
Commit
66ac02c
·
1 Parent(s): e8ec4e4

Should fix V2

Browse files
Files changed (1) hide show
  1. app_multi.py +30 -18
app_multi.py CHANGED
@@ -70,7 +70,16 @@ app = gr.Blocks(
70
  )
71
 
72
  # Load hubert model
73
- hubert_model = util.load_hubert_model(config.device, args.hubert)
 
 
 
 
 
 
 
 
 
74
  hubert_model.eval()
75
 
76
  # Load models
@@ -91,26 +100,29 @@ for model_name in multi_cfg.get('models'):
91
  map_location='cpu'
92
  )
93
  tgt_sr = cpt['config'][-1]
94
- cpt['config'][-3] = cpt['weight']['emb_g.weight'].shape[0] # n_spk
95
-
96
- if_f0 = cpt.get('f0', 1)
97
- net_g: Union[SynthesizerTrnMs768NSFsid, SynthesizerTrnMs768NSFsid_nono]
98
- if if_f0 == 1:
99
- net_g = SynthesizerTrnMs768NSFsid(
100
- *cpt['config'],
101
- is_half=util.is_half(config.device)
102
- )
103
- else:
104
- net_g = SynthesizerTrnMs768NSFsid_nono(*cpt['config'])
105
-
 
 
 
106
  del net_g.enc_q
107
 
108
- # According to original code, this thing seems necessary.
109
- print(net_g.load_state_dict(cpt['weight'], strict=False))
110
-
111
  net_g.eval().to(config.device)
112
- net_g = net_g.half() if util.is_half(config.device) else net_g.float()
113
-
 
 
114
  vc = VC(tgt_sr, config)
115
 
116
  loaded_models.append(dict(
 
70
  )
71
 
72
  # Load hubert model
73
+ models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
74
+ ["hubert_base.pt"],
75
+ suffix="",
76
+ )
77
+ hubert_model = models[0]
78
+ hubert_model = hubert_model.to(config.device)
79
+ if config.is_half:
80
+ hubert_model = hubert_model.half()
81
+ else:
82
+ hubert_model = hubert_model.float()
83
  hubert_model.eval()
84
 
85
  # Load models
 
100
  map_location='cpu'
101
  )
102
  tgt_sr = cpt['config'][-1]
103
+ cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] # n_spk
104
+ if_f0 = cpt.get("f0", 1)
105
+ version = cpt.get("version", "v1")
106
+ if version == "v1":
107
+ if if_f0 == 1:
108
+ net_g = SynthesizerTrnMs256NSFsid(*cpt["config"], is_half=config.is_half)
109
+ else:
110
+ net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
111
+ model_version = "V1"
112
+ elif version == "v2":
113
+ if if_f0 == 1:
114
+ net_g = SynthesizerTrnMs768NSFsid(*cpt["config"], is_half=config.is_half)
115
+ else:
116
+ net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
117
+ model_version = "V2"
118
  del net_g.enc_q
119
 
120
+ print(net_g.load_state_dict(cpt["weight"], strict=False))
 
 
121
  net_g.eval().to(config.device)
122
+ if config.is_half:
123
+ net_g = net_g.half()
124
+ else:
125
+ net_g = net_g.float()
126
  vc = VC(tgt_sr, config)
127
 
128
  loaded_models.append(dict(