File size: 4,225 Bytes
75f2d00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62b6d65
75f2d00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4a341a4
75f2d00
 
 
 
 
 
 
62b6d65
75f2d00
4a341a4
 
 
75f2d00
62b6d65
 
75f2d00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import os
import time
import yaml, math
from tqdm import trange
import torch
import numpy as np
from omegaconf import OmegaConf
import torch.distributed as dist
from pytorch_lightning import seed_everything

from lvdm.samplers.ddim import DDIMSampler
from lvdm.utils.common_utils import str2bool
from lvdm.utils.saving_utils import npz_to_video_grid, npz_to_imgsheet_5d
from scripts.sample_text2video import sample_text2video
from scripts.sample_utils import load_model, get_conditions, make_model_input_shape, torch_to_np
from lvdm.models.modules.lora import change_lora, change_lora_v2

from huggingface_hub import hf_hub_download


def save_results(videos, save_dir, 
                 save_name="results", save_fps=8
                 ):
    """Write each video of a batch to its own MP4 file and return the paths.

    Args:
        videos: batch indexed along axis 0; each ``videos[i:i+1, ...]`` slice is
            handed to ``npz_to_video_grid`` (exact layout is whatever that
            helper expects — TODO confirm).
        save_dir: root directory; files are written into ``<save_dir>/videos``,
            which is created if missing.
        save_name: file-name stem; files are named ``{save_name}_{i:03d}.mp4``.
        save_fps: frame rate forwarded to ``npz_to_video_grid``.

    Returns:
        List of the written file paths, in batch order.
    """
    save_subdir = os.path.join(save_dir, "videos")
    os.makedirs(save_subdir, exist_ok=True)
    # Build the paths once so the returned list is guaranteed to be exactly
    # the set of files that were written.
    video_path_list = [os.path.join(save_subdir, f"{save_name}_{i:03d}.mp4")
                       for i in range(videos.shape[0])]
    for i, video_path in enumerate(video_path_list):
        npz_to_video_grid(videos[i:i+1, ...], video_path, fps=save_fps)
    print(f'Successfully saved videos in {save_subdir}')
    return video_path_list
    

class Text2Video():
    """Text-to-video pipeline around the VideoCrafter base T2V model with optional VideoLoRA styles.

    ``lora_path_list`` and ``lora_trigger_word_list`` are parallel lists:
    index 0 means "no LoRA"; an index >= 1 selects a style checkpoint and the
    trigger word that is appended to the prompt when that style is used.
    """

    # Hugging Face Hub repository hosting the base and LoRA checkpoints.
    _REPO_ID = 'VideoCrafter/t2v-version-1-1'
    _CONFIG_FILE = 'models/base_t2v/model_config.yaml'
    _CKPT_PATH = 'models/base_t2v/model.ckpt'
    # Single source of truth for the LoRA checkpoints; index 0 ('') is "no LoRA".
    _LORA_PATHS = (
        '',
        'models/videolora/lora_001_Loving_Vincent_style.ckpt',
        'models/videolora/lora_002_frozenmovie_style.ckpt',
        'models/videolora/lora_003_MakotoShinkaiYourName_style.ckpt',
        'models/videolora/lora_004_coco_style.ckpt',
    )
    # Trigger words aligned index-for-index with _LORA_PATHS.
    _LORA_TRIGGER_WORDS = (
        '',
        'Loving Vincent style',
        'frozenmovie style',
        'MakotoShinkaiYourName style',
        'coco style',
    )

    def __init__(self,result_dir='./tmp/') -> None:
        """Download missing checkpoints, load the base model and create a DDIM sampler.

        Args:
            result_dir: directory under which generated videos are saved.
        """
        self.download_model()
        config = OmegaConf.load(self._CONFIG_FILE)
        # Per-instance copies so mutating them on one instance doesn't leak
        # into the class-level defaults.
        self.lora_path_list = list(self._LORA_PATHS)
        self.lora_trigger_word_list = list(self._LORA_TRIGGER_WORDS)
        model, _, _ = load_model(config, self._CKPT_PATH, gpu_id=0, inject_lora=False)
        self.model = model
        # LoRA path/scale currently applied to self.model ('' means none);
        # handed to change_lora_v2 on the next call.
        self.last_time_lora = ''
        self.last_time_lora_scale = 1.0
        self.result_dir = result_dir
        self.save_fps = 8
        self.ddim_sampler = DDIMSampler(model)
        # Whatever change_lora_v2 returned last time; fed back on the next call
        # (presumably the pre-LoRA weights used for restoring — confirm).
        self.origin_weight = None

    @staticmethod
    def _prompt_to_filename(prompt):
        """Return the file-name stem for ``prompt``.

        '/' is replaced by '_slash_' so it is not treated as a directory
        separator, and spaces are replaced by '_'.
        """
        return prompt.replace("/", "_slash_").replace(" ", "_")

    def get_prompt(self, input_text, steps=50, model_index=0, eta=1.0, cfg_scale=15.0, lora_scale=1.0):
        """Generate one video for ``input_text`` and return the saved file path.

        Args:
            input_text: text prompt.
            steps: number of DDIM sampling steps.
            model_index: 0 for the base model, or an index into
                ``lora_path_list`` selecting a LoRA style.
            eta: DDIM eta.
            cfg_scale: classifier-free guidance scale.
            lora_scale: strength passed to ``change_lora_v2``.

        Returns:
            Path of the first (and only) saved video.

        Raises:
            IndexError: if ``model_index`` is outside ``[0, len(lora_path_list))``.
        """
        # Reject negative indices explicitly: Python would otherwise select a
        # LoRA from the end of the list while inject_lora stays False, leaving
        # last_time_lora out of sync with the model's real weights.
        if not 0 <= model_index < len(self.lora_path_list):
            raise IndexError(
                f"model_index must be in [0, {len(self.lora_path_list) - 1}], got {model_index}")
        lora_path = self.lora_path_list[model_index]
        inject_lora = model_index > 0
        if inject_lora:
            input_text = input_text + ', ' + self.lora_trigger_word_list[model_index]
        self.origin_weight = change_lora_v2(self.model, inject_lora=inject_lora, lora_scale=lora_scale, lora_path=lora_path,
                    last_time_lora=self.last_time_lora, last_time_lora_scale=self.last_time_lora_scale, origin_weight=self.origin_weight)
        # Record the applied LoRA state right after the weights change, so a
        # failure during sampling below can't leave this bookkeeping stale.
        self.last_time_lora = lora_path
        self.last_time_lora_scale = lora_scale

        all_videos = sample_text2video(self.model, input_text, n_samples=1, batch_size=1,
                        sample_type='ddim', sampler=self.ddim_sampler,
                        ddim_steps=steps, eta=eta, 
                        cfg_scale=cfg_scale,
                        )
        prompt_str = self._prompt_to_filename(input_text)
        video_path_list = save_results(all_videos, self.result_dir, save_name=prompt_str, save_fps=self.save_fps)
        return video_path_list[0]

    def download_model(self):
        """Download any missing base/LoRA checkpoint from the Hugging Face Hub into the CWD.

        Files already present at their relative path are skipped.
        """
        filename_list = [self._CKPT_PATH] + [path for path in self._LORA_PATHS if path]
        for filename in filename_list:
            if not os.path.exists(filename):
                hf_hub_download(repo_id=self._REPO_ID, filename=filename, local_dir='./', local_dir_use_symlinks=False)