Spaces:
Runtime error
Runtime error
ShaoTengLiu
commited on
Commit
·
e3712a5
1
Parent(s):
f5c12d4
update
Browse files- trainer.py +10 -5
trainer.py
CHANGED
@@ -11,6 +11,7 @@ import sys
|
|
11 |
import gradio as gr
|
12 |
import slugify
|
13 |
import torch
|
|
|
14 |
from huggingface_hub import HfApi
|
15 |
from omegaconf import OmegaConf
|
16 |
|
@@ -33,16 +34,20 @@ class Trainer:
|
|
33 |
self.checkpoint_dir = pathlib.Path('checkpoints')
|
34 |
self.checkpoint_dir.mkdir(exist_ok=True)
|
35 |
|
36 |
-
def download_base_model(self, base_model_id: str) -> str:
|
37 |
model_dir = self.checkpoint_dir / base_model_id
|
38 |
if not model_dir.exists():
|
39 |
org_name = base_model_id.split('/')[0]
|
40 |
org_dir = self.checkpoint_dir / org_name
|
41 |
org_dir.mkdir(exist_ok=True)
|
42 |
print(f'https://huggingface.co/{base_model_id}')
|
43 |
-
|
44 |
-
|
45 |
-
|
|
|
|
|
|
|
|
|
46 |
return model_dir.as_posix()
|
47 |
|
48 |
def join_model_library_org(self, token: str) -> None:
|
@@ -241,7 +246,7 @@ class Trainer:
|
|
241 |
self.hf_token if self.hf_token else input_token)
|
242 |
|
243 |
config = OmegaConf.load('Video-P2P/configs/man-skiing.yaml')
|
244 |
-
config.pretrained_model_path = self.download_base_model(tuned_model)
|
245 |
config.output_dir = output_dir.as_posix()
|
246 |
config.train_data.video_path = training_video.name # type: ignore
|
247 |
config.train_data.prompt = training_prompt
|
|
|
11 |
import gradio as gr
|
12 |
import slugify
|
13 |
import torch
|
14 |
+
import huggingface_hub
|
15 |
from huggingface_hub import HfApi
|
16 |
from omegaconf import OmegaConf
|
17 |
|
|
|
34 |
self.checkpoint_dir = pathlib.Path('checkpoints')
|
35 |
self.checkpoint_dir.mkdir(exist_ok=True)
|
36 |
|
37 |
+
def download_base_model(self, base_model_id: str, token=None) -> str:
|
38 |
model_dir = self.checkpoint_dir / base_model_id
|
39 |
if not model_dir.exists():
|
40 |
org_name = base_model_id.split('/')[0]
|
41 |
org_dir = self.checkpoint_dir / org_name
|
42 |
org_dir.mkdir(exist_ok=True)
|
43 |
print(f'https://huggingface.co/{base_model_id}')
|
44 |
+
try:
|
45 |
+
subprocess.run(shlex.split(
|
46 |
+
f'git clone https://huggingface.co/{base_model_id}'),
|
47 |
+
cwd=org_dir)
|
48 |
+
except:
|
49 |
+
temp_path = huggingface_hub.snapshot_download(base_model_id, use_auth_token=token)
|
50 |
+
subprocess.run(shlex.split(f'mv {temp_path} {org_dir}'))
|
51 |
return model_dir.as_posix()
|
52 |
|
53 |
def join_model_library_org(self, token: str) -> None:
|
|
|
246 |
self.hf_token if self.hf_token else input_token)
|
247 |
|
248 |
config = OmegaConf.load('Video-P2P/configs/man-skiing.yaml')
|
249 |
+
config.pretrained_model_path = self.download_base_model(tuned_model, token=input_token)
|
250 |
config.output_dir = output_dir.as_posix()
|
251 |
config.train_data.video_path = training_video.name # type: ignore
|
252 |
config.train_data.prompt = training_prompt
|