import copy
import os
import random

import gradio as gr
import numpy as np
import spaces
import torch
import torch.nn.functional as F
from diffusers import StableDiffusionXLPipeline, DDIMScheduler
from diffusers.utils import load_image
from huggingface_hub import hf_hub_download

from examples.story_examples import get_examples
from storyDiffusion.utils.gradio_utils import AttnProcessor2_0 as AttnProcessor, cal_attn_mask_xl
from storyDiffusion.utils import PhotoMakerStableDiffusionXLPipeline
from storyDiffusion.utils.utils import get_comic
from storyDiffusion.utils.style_template import styles
# Constants
image_encoder_path = "./data/models/ip_adapter/sdxl_models/image_encoder"
ip_ckpt = "./data/models/ip_adapter/sdxl_models/ip-adapter_sdxl_vit-h.bin"
os.environ["no_proxy"] = "localhost,127.0.0.1,::1"
STYLE_NAMES = list(styles.keys())
DEFAULT_STYLE_NAME = "Japanese Anime"
MAX_SEED = np.iinfo(np.int32).max
# Global variables (shared between the attention processors and the UI callbacks)
models_dict = {
    "RealVision": "SG161222/RealVisXL_V4.0",
    "Unstable": "stablediffusionapi/sdxl-unstable-diffusers-y",
}
use_va = True
photomaker_path = hf_hub_download(
    repo_id="TencentARC/PhotoMaker", filename="photomaker-v1.bin", repo_type="model")
device = "cuda"
# Functions
def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
def set_text_unfinished():
    return gr.update(visible=True, value="<h3>(Not Finished) Generating ··· The intermediate results will be shown.</h3>")

def set_text_finished():
    return gr.update(visible=True, value="<h3>Generation Finished</h3>")
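# How the paired ("consistent") self-attention below works: generation runs in
# two phases controlled by the global `write` flag. In the write phase the
# first `id_length` prompts are generated together and their self-attention
# hidden states are cached per denoising step in `id_bank`; in the read phase
# each remaining frame also attends to those cached states, which keeps the
# character consistent across frames.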
class SpatialAttnProcessor2_0(torch.nn.Module):
    r"""
    Paired self-attention processor for PyTorch 2.0.

    Args:
        hidden_size (`int`, *optional*):
            The hidden size of the attention layer.
        cross_attention_dim (`int`, *optional*):
            The number of channels in the `encoder_hidden_states`.
        id_length (`int`, defaults to 4):
            The number of identity frames whose hidden states are cached and shared.
        device (`str`, defaults to `"cuda"`):
            Device on which cached hidden states and attention masks are kept.
        dtype (`torch.dtype`, defaults to `torch.float16`):
            Dtype used for the attention masks.
    """

    def __init__(self, hidden_size=None, cross_attention_dim=None, id_length=4, device="cuda", dtype=torch.float16):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "SpatialAttnProcessor2_0 requires PyTorch 2.0; please upgrade PyTorch to use it.")
        self.device = device
        self.dtype = dtype
        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.total_length = id_length + 1
        self.id_length = id_length
        self.id_bank = {}
    def __call__(
            self,
            attn,
            hidden_states,
            encoder_hidden_states=None,
            attention_mask=None,
            temb=None):
        global total_count, attn_count, cur_step, mask1024, mask4096
        global sa32, sa64
        global write
        global height, width
        global num_steps
        if write:
            # Write phase: cache the hidden states of the identity frames for this step.
            self.id_bank[cur_step] = [
                hidden_states[:self.id_length], hidden_states[self.id_length:]]
        else:
            # Read phase: let the current frame attend to the cached identity states.
            encoder_hidden_states = torch.cat((self.id_bank[cur_step][0].to(
                self.device), hidden_states[:1], self.id_bank[cur_step][1].to(self.device), hidden_states[1:]))
        if cur_step <= 1:
            hidden_states = self.__call2__(
                attn, hidden_states, None, attention_mask, temb)
        else:
            # Randomly decide whether to apply paired attention at this step;
            # it is skipped more often late in sampling.
            random_number = random.random()
            if cur_step < 0.4 * num_steps:
                rand_num = 0.3
            else:
                rand_num = 0.1
            if random_number > rand_num:
                if not write:
                    if hidden_states.shape[1] == (height // 32) * (width // 32):
                        attention_mask = mask1024[mask1024.shape[0] // self.total_length * self.id_length:]
                    else:
                        attention_mask = mask4096[mask4096.shape[0] // self.total_length * self.id_length:]
                else:
                    if hidden_states.shape[1] == (height // 32) * (width // 32):
                        attention_mask = mask1024[:mask1024.shape[0] // self.total_length * self.id_length,
                                                  :mask1024.shape[0] // self.total_length * self.id_length]
                    else:
                        attention_mask = mask4096[:mask4096.shape[0] // self.total_length * self.id_length,
                                                  :mask4096.shape[0] // self.total_length * self.id_length]
                hidden_states = self.__call1__(
                    attn, hidden_states, encoder_hidden_states, attention_mask, temb)
            else:
                hidden_states = self.__call2__(
                    attn, hidden_states, None, attention_mask, temb)
        attn_count += 1
        if attn_count == total_count:
            attn_count = 0
            cur_step += 1
            mask1024, mask4096 = cal_attn_mask_xl(
                self.total_length, self.id_length, sa32, sa64, height, width, device=self.device, dtype=self.dtype)
        return hidden_states
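    # __call1__ runs attention jointly over all frames in the batch: the token
    # sequences of every frame are flattened into one long sequence so each
    # frame can attend to the others, restricted by the precomputed
    # paired-attention mask.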
    def __call1__(
            self,
            attn,
            hidden_states,
            encoder_hidden_states=None,
            attention_mask=None,
            temb=None):
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)
        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            total_batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                total_batch_size, channel, height * width).transpose(1, 2)
        total_batch_size, nums_token, channel = hidden_states.shape
        img_nums = total_batch_size // 2
        # Flatten the per-frame token sequences into one joint sequence per
        # classifier-free-guidance branch.
        hidden_states = hidden_states.view(-1, img_nums, nums_token, channel).reshape(
            -1, img_nums * nums_token, channel)
        batch_size, sequence_length, _ = hidden_states.shape
        if attn.group_norm is not None:
            hidden_states = attn.group_norm(
                hidden_states.transpose(1, 2)).transpose(1, 2)
        query = attn.to_q(hidden_states)
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states  # (B, N, C)
        else:
            encoder_hidden_states = encoder_hidden_states.view(
                -1, self.id_length + 1, nums_token, channel).reshape(
                -1, (self.id_length + 1) * nums_token, channel)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        # The output of SDPA has shape (batch, num_heads, seq_len, head_dim).
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(
            total_batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)
        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                total_batch_size, channel, height, width)
        if attn.residual_connection:
            hidden_states = hidden_states + residual
        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states
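    # __call2__ is the ordinary per-frame scaled-dot-product attention path,
    # used on the first couple of steps and whenever the paired path is
    # randomly skipped.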
    def __call2__(
            self,
            attn,
            hidden_states,
            encoder_hidden_states=None,
            attention_mask=None,
            temb=None):
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)
        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width).transpose(1, 2)
        batch_size, sequence_length, channel = hidden_states.shape
        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1])
        if attn.group_norm is not None:
            hidden_states = attn.group_norm(
                hidden_states.transpose(1, 2)).transpose(1, 2)
        query = attn.to_q(hidden_states)
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states  # (B, N, C)
        else:
            encoder_hidden_states = encoder_hidden_states.view(
                -1, self.id_length + 1, sequence_length, channel).reshape(
                -1, (self.id_length + 1) * sequence_length, channel)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        # The output of SDPA has shape (batch, num_heads, seq_len, head_dim).
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)
        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width)
        if attn.residual_connection:
            hidden_states = hidden_states + residual
        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states
def set_attention_processor(unet, id_length, is_ipadapter=False):
    global total_count
    total_count = 0
    attn_procs = {}
    for name in unet.attn_processors.keys():
        # attn1 processors are self-attention; everything else is cross-attention.
        cross_attention_dim = None if name.endswith(
            "attn1.processor") else unet.config.cross_attention_dim
        if cross_attention_dim is None and name.startswith("up_blocks"):
            # Only the self-attention layers of the up blocks get paired attention.
            attn_procs[name] = SpatialAttnProcessor2_0(id_length=id_length)
            total_count += 1
        else:
            attn_procs[name] = AttnProcessor()
    unet.set_attn_processor(copy.deepcopy(attn_procs))
    print("Successfully loaded paired self-attention")
    print(f"Number of processors: {total_count}")
# Default values for the shared generation state
attn_count = 0
total_count = 0
cur_step = 0
id_length = 4
total_length = 5
cur_model_type = ""
attn_procs = {}
write = False
sa32 = 0.5
sa64 = 0.5
height = 768
width = 768
def swap_to_gallery(images):
    return gr.update(value=images, visible=True), gr.update(visible=True), gr.update(visible=False)

def upload_example_to_gallery(images, prompt, style, negative_prompt):
    return gr.update(value=images, visible=True), gr.update(visible=True), gr.update(visible=False)

def remove_back_to_files():
    return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)

def remove_tips():
    return gr.update(visible=False)

def apply_style_positive(style_name: str, positive: str):
    p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
    return p.replace("{prompt}", positive)

def apply_style(style_name: str, positives: list, negative: str = ""):
    p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
    return [p.replace("{prompt}", positive) for positive in positives], n + " " + negative
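# Example with hypothetical template values (the real templates live in
# storyDiffusion/utils/style_template.py): if styles["Japanese Anime"] were
# ("anime artwork of {prompt}", "photo, realistic"), then
# apply_style("Japanese Anime", ["a boy img riding a bike"], "lowres")
# would return (["anime artwork of a boy img riding a bike"],
# "photo, realistic lowres").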
def change_visible_by_model_type(_model_type):
    if _model_type == "Only Using Textual Description":
        return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
    elif _model_type == "Using Ref Images":
        return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)
    else:
        raise ValueError("Invalid model type", _model_type)
# ZeroGPU: request a GPU for the duration of each call (assumption: the
# otherwise-unused `spaces` import indicates this Space runs on ZeroGPU).
@spaces.GPU
def process_generation(_sd_type, _model_type, _upload_images, _num_steps, style_name, _Ip_Adapter_Strength, _style_strength_ratio, guidance_scale, seed_, sa32_, sa64_, id_length_, general_prompt, negative_prompt, prompt_array, G_height, G_width, _comic_type):
    global sa32, sa64, id_length, total_length, attn_procs, unet, cur_model_type, device, num_steps, write, cur_step, attn_count, height, width, pipe2, pipe4, sd_model_path, models_dict
    _model_type = "Photomaker" if _model_type == "Using Ref Images" else "original"
    if _model_type == "Photomaker" and "img" not in general_prompt:
        raise gr.Error(
            "Please add the trigger word 'img' after the class word you want to customize, e.g. 'man img' or 'woman img'.")
    if _upload_images is None and _model_type != "original":
        raise gr.Error("Cannot find any input face image!")
    if len(prompt_array.splitlines()) > 10:
        raise gr.Error(
            f"For speed, this Hugging Face demo allows at most 10 prompts, but {len(prompt_array.splitlines())} were given.")
    height = G_height
    width = G_width
    sd_model_path = models_dict[_sd_type]
    num_steps = _num_steps
    if style_name == "(No style)":
        sd_model_path = models_dict["RealVision"]
    if _model_type == "original":
        pipe = StableDiffusionXLPipeline.from_pretrained(
            sd_model_path, torch_dtype=torch.float16)
        pipe = pipe.to(device)
        set_attention_processor(pipe.unet, id_length_, is_ipadapter=False)
    elif _model_type == "Photomaker":
        # pipe2 is the Unstable-based pipeline, pipe4 the RealVision-based one.
        if _sd_type != "RealVision" and style_name != "(No style)":
            pipe = pipe2.to(device)
        else:
            pipe = pipe4.to(device)
        pipe.id_encoder.to(device)
        set_attention_processor(pipe.unet, id_length_, is_ipadapter=False)
    else:
        raise NotImplementedError(
            "You should choose between original and Photomaker!", f"But you chose {_model_type}")
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    pipe.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)
    cur_model_type = _sd_type + "-" + _model_type + str(id_length_)
    if _model_type != "original":
        input_id_images = [load_image(img) for img in _upload_images]
    prompts = prompt_array.splitlines()
    start_merge_step = int(float(_style_strength_ratio) / 100 * _num_steps)
    if start_merge_step > 30:
        start_merge_step = 30
    print(f"start_merge_step: {start_merge_step}")
    generator = torch.Generator(device="cuda").manual_seed(seed_)
    sa32, sa64 = sa32_, sa64_
    id_length = id_length_
    clipped_prompts = prompts[:]
    # A line containing "[NC]" is rendered without the character description;
    # otherwise the general character prompt is prepended to the frame prompt.
    prompts = [general_prompt + "," + prompt if "[NC]" not in prompt else prompt.replace(
        "[NC]", "") for prompt in clipped_prompts]
    # Text after "#" is a caption only; strip it from the diffusion prompt.
    prompts = [prompt.rpartition("#")[0] if "#" in prompt else prompt for prompt in prompts]
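    # Example of the per-line prompt syntax (hypothetical input): with
    # general_prompt = "a man img", the line "wake up in bed #Morning." becomes
    # the diffusion prompt "a man img,wake up in bed " and the comic caption
    # "Morning.", while "[NC] a quiet street" drops the character entirely.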
    print(prompts)
    id_prompts = prompts[:id_length]
    real_prompts = prompts[id_length:]
    torch.cuda.empty_cache()
    write = True
    cur_step = 0
    attn_count = 0
    id_prompts, negative_prompt = apply_style(
        style_name, id_prompts, negative_prompt)
    setup_seed(seed_)
    total_results = []
    # Write phase: generate the identity frames together so their attention
    # states are cached for the remaining frames.
    if _model_type == "original":
        id_images = pipe(id_prompts, num_inference_steps=_num_steps, guidance_scale=guidance_scale,
                         height=height, width=width, negative_prompt=negative_prompt, generator=generator).images
    elif _model_type == "Photomaker":
        id_images = pipe(id_prompts, input_id_images=input_id_images, num_inference_steps=_num_steps, guidance_scale=guidance_scale,
                         start_merge_step=start_merge_step, height=height, width=width, negative_prompt=negative_prompt, generator=generator).images
    else:
        raise NotImplementedError(
            "You should choose between original and Photomaker!", f"But you chose {_model_type}")
    total_results = id_images + total_results
    yield total_results
    real_images = []
    write = False
    # Read phase: generate the remaining frames one at a time, each attending
    # to the cached identity states.
    for real_prompt in real_prompts:
        setup_seed(seed_)
        cur_step = 0
        real_prompt = apply_style_positive(style_name, real_prompt)
        if _model_type == "original":
            real_images.append(pipe(real_prompt, num_inference_steps=_num_steps, guidance_scale=guidance_scale,
                                    height=height, width=width, negative_prompt=negative_prompt, generator=generator).images[0])
        elif _model_type == "Photomaker":
            real_images.append(pipe(real_prompt, input_id_images=input_id_images, num_inference_steps=_num_steps, guidance_scale=guidance_scale,
                                    start_merge_step=start_merge_step, height=height, width=width, negative_prompt=negative_prompt, generator=generator).images[0])
        else:
            raise NotImplementedError(
                "You should choose between original and Photomaker!", f"But you chose {_model_type}")
        total_results = [real_images[-1]] + total_results
        yield total_results
    if _comic_type != "No typesetting (default)":
        from PIL import ImageFont
        captions = prompt_array.splitlines()
        captions = [caption.replace("[NC]", "") for caption in captions]
        captions = [caption.split("#")[-1] if "#" in caption else caption for caption in captions]
        total_results = get_comic(id_images + real_images, _comic_type, captions=captions,
                                  font=ImageFont.truetype("./storyDiffusion/fonts/Inkfree.ttf", int(45))) + total_results
    if _model_type == "Photomaker":
        # Move the PhotoMaker pipeline that was actually used back to the CPU
        # to free GPU memory.
        pipe = pipe.to("cpu")
        pipe.id_encoder.to("cpu")
        set_attention_processor(pipe.unet, id_length_, is_ipadapter=False)
    yield total_results
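# Both PhotoMaker pipelines are built once at import time and parked on the
# CPU; process_generation moves the selected one to the GPU on demand and back
# afterwards, trading load time for GPU memory.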
# Initialize pipelines
pipe2 = PhotoMakerStableDiffusionXLPipeline.from_pretrained(
    models_dict["Unstable"], torch_dtype=torch.float16, use_safetensors=False)
pipe2 = pipe2.to("cpu")
pipe2.load_photomaker_adapter(
    os.path.dirname(photomaker_path),
    subfolder="",
    weight_name=os.path.basename(photomaker_path),
    trigger_word="img",
)
pipe2 = pipe2.to("cpu")
pipe2.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)
pipe2.fuse_lora()

pipe4 = PhotoMakerStableDiffusionXLPipeline.from_pretrained(
    models_dict["RealVision"], torch_dtype=torch.float16, use_safetensors=True)
pipe4 = pipe4.to("cpu")
pipe4.load_photomaker_adapter(
    os.path.dirname(photomaker_path),
    subfolder="",
    weight_name=os.path.basename(photomaker_path),
    trigger_word="img",
)
pipe4 = pipe4.to("cpu")
pipe4.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)
pipe4.fuse_lora()
def story_generation_ui():
    with gr.Row():
        with gr.Group(elem_id="main-image"):
            prompts = []
            colors = []
            with gr.Column(visible=True) as gen_prompt_vis:
                sd_type = gr.Dropdown(choices=list(models_dict.keys()),
                                      value="Unstable", label="sd_type", info="Select pretrained model")
                model_type = gr.Radio(["Only Using Textual Description", "Using Ref Images"], label="model_type",
                                      value="Only Using Textual Description", info="Control type of the character")
                with gr.Group(visible=False) as control_image_input:
                    files = gr.Files(
                        label="Drag (select) 1 or more photos of your face",
                        file_types=["image"],
                    )
                    uploaded_files = gr.Gallery(
                        label="Your images", visible=False, columns=5, rows=1, height=200)
                    with gr.Column(visible=False) as clear_button:
                        remove_and_reupload = gr.ClearButton(
                            value="Remove and upload new ones", components=files, size="sm")
                general_prompt = gr.Textbox(
                    value="", label="(1) Textual description of the character", interactive=True)
                negative_prompt = gr.Textbox(
                    value="", label="(2) Negative prompt", interactive=True)
                style = gr.Dropdown(
                    label="Style template", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME)
                prompt_array = gr.Textbox(
                    lines=3, value="", label="(3) Comic description (each line corresponds to a frame)", interactive=True)
                with gr.Accordion("(4) Tune the hyperparameters", open=False):
                    sa32_ = gr.Slider(label="Degree of Paired Attention at 32 x 32 self-attention layers",
                                      minimum=0, maximum=1., value=0.7, step=0.1)
                    sa64_ = gr.Slider(label="Degree of Paired Attention at 64 x 64 self-attention layers",
                                      minimum=0, maximum=1., value=0.7, step=0.1)
                    id_length_ = gr.Slider(
                        label="Number of id images in total images", minimum=2, maximum=4, value=3, step=1)
                    seed_ = gr.Slider(label="Seed", minimum=-1,
                                      maximum=MAX_SEED, value=0, step=1)
                    num_steps = gr.Slider(
                        label="Number of sample steps",
                        minimum=25,
                        maximum=50,
                        step=1,
                        value=50,
                    )
                    G_height = gr.Slider(
                        label="height",
                        minimum=256,
                        maximum=1024,
                        step=32,
                        value=1024,
                    )
                    G_width = gr.Slider(
                        label="width",
                        minimum=256,
                        maximum=1024,
                        step=32,
                        value=1024,
                    )
                    comic_type = gr.Radio(["No typesetting (default)", "Four Pannel", "Classic Comic Style"],
                                          value="Classic Comic Style", label="Typesetting Style", info="Select the typesetting style")
                    guidance_scale = gr.Slider(
                        label="Guidance scale",
                        minimum=0.1,
                        maximum=10.0,
                        step=0.1,
                        value=5,
                    )
                    style_strength_ratio = gr.Slider(
                        label="Style strength of ref image (%)",
                        minimum=15,
                        maximum=50,
                        step=1,
                        value=20,
                        visible=False,
                    )
                    Ip_Adapter_Strength = gr.Slider(
                        label="Ip_Adapter_Strength",
                        minimum=0,
                        maximum=1,
                        step=0.1,
                        value=0.5,
                        visible=False,
                    )
                final_run_btn = gr.Button("Generate ! 😺")
        with gr.Column():
            out_image = gr.Gallery(label="Result", columns=2, height="auto")
            generated_information = gr.Markdown(
                label="Generation Details", value="", visible=False)
    model_type.change(fn=change_visible_by_model_type, inputs=model_type,
                      outputs=[control_image_input, style_strength_ratio, Ip_Adapter_Strength])
    files.upload(fn=swap_to_gallery, inputs=files,
                 outputs=[uploaded_files, clear_button, files])
    remove_and_reupload.click(fn=remove_back_to_files,
                              outputs=[uploaded_files, clear_button, files])
    final_run_btn.click(fn=set_text_unfinished, outputs=generated_information
                        ).then(process_generation,
                               inputs=[sd_type, model_type, files, num_steps, style, Ip_Adapter_Strength, style_strength_ratio, guidance_scale, seed_, sa32_, sa64_, id_length_, general_prompt, negative_prompt, prompt_array, G_height, G_width, comic_type],
                               outputs=out_image
                        ).then(fn=set_text_finished, outputs=generated_information)
    gr.Examples(
        examples=get_examples(),
        inputs=[seed_, sa32_, sa64_, id_length_, general_prompt, negative_prompt,
                prompt_array, style, model_type, files, G_height, G_width],
        label="😺 Examples 😺",
    )
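
# A minimal launch sketch (assumption: this module is the Space's entry point).
# story_generation_ui() opens gr.Row()/gr.Column() contexts, which must run
# inside a gr.Blocks context, so one is created here; the title is illustrative.
if __name__ == "__main__":
    with gr.Blocks(title="StoryDiffusion") as demo:
        story_generation_ui()
    demo.launch()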