MultilingualSD3-adapter
MultilingualSD3-adapter is a multilingual adapter tailored for SD3 (Stable Diffusion 3). It originates from the ECCV 2024 paper PEA-Diffusion; the open-source code is available at https://github.com/OPPO-Mente-Lab/PEA-Diffusion.
Usage
We used the multilingual encoders umt5-xxl, Mul-OpenCLIP, and HunyuanDiT_CLIP, and implemented a reverse denoising process for distillation training.
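The full inference script is shown below. It expects local paths to the two adapter checkpoints, pytorch_model.bin (CLIP branch) and pytorch_model_t5.bin (umt5-xxl branch). A minimal sketch for fetching them, assuming the weights are downloaded from the OPPOer/MultilingualSD3-adapter repository with huggingface_hub (the filenames match those used in the script):

from huggingface_hub import hf_hub_download

# Fetch the two adapter projection checkpoints used as proj_path and proj_t5_path below.
proj_path = hf_hub_download("OPPOer/MultilingualSD3-adapter", "pytorch_model.bin")
proj_t5_path = hf_hub_download("OPPOer/MultilingualSD3-adapter", "pytorch_model_t5.bin")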
MultilingualSD3
import os
import torch
import torch.nn as nn
from typing import Any, Callable, Dict, List, Optional, Union
import inspect
from diffusers.models.transformers import SD3Transformer2DModel
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers import AutoencoderKL
from tqdm import tqdm
from PIL import Image
from transformers import T5Tokenizer, T5EncoderModel, BertModel, BertTokenizer
import open_clip
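# MLP adapter: maps the concatenated hidden states of the two CLIP-style
# multilingual text encoders (2048 dims in total, as instantiated below) to the
# 4096-dim token space expected by the SD3 transformer.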
class MLP(nn.Module):
    def __init__(self, in_dim=1024, out_dim=2048, hidden_dim=2048, out_dim1=4096, use_residual=True):
        super().__init__()
        if use_residual:
            assert in_dim == out_dim
        self.layernorm = nn.LayerNorm(in_dim)
        self.projector = nn.Sequential(
            nn.Linear(in_dim, hidden_dim, bias=False),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim, bias=False),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim, bias=False),
            nn.GELU(),
            nn.Linear(hidden_dim, out_dim, bias=False),
        )
        self.fc = nn.Linear(out_dim, out_dim1)
        self.use_residual = use_residual

    def forward(self, x):
        residual = x
        x = self.layernorm(x)
        x = self.projector(x)
        x2 = nn.GELU()(x)
        x2 = self.fc(x2)
        return x2
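# Transformer adapter for the umt5-xxl hidden states: a small encoder whose
# mean-pooled output (out_dim1) is used as the pooled projection and whose
# per-token output (out_dim2) is concatenated with the CLIP branch along the
# sequence dimension.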
class Transformer(nn.Module):
    def __init__(self, d_model, n_heads, out_dim1, out_dim2, num_layers=1) -> None:
        super().__init__()
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_heads, dim_feedforward=2048, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
        self.linear1 = nn.Linear(d_model, out_dim1)
        self.linear2 = nn.Linear(d_model, out_dim2)

    def forward(self, x):
        x = self.transformer_encoder(x)
        x1 = self.linear1(x)
        x1 = torch.mean(x1, 1)
        x2 = self.linear2(x)
        return x1, x2
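# Utility: paste a list of PIL images into a rows x cols grid.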
def image_grid(imgs, rows, cols):
    assert len(imgs) == rows * cols
    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    grid_w, grid_h = grid.size
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
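# Helper mirroring diffusers' `retrieve_timesteps`: resolves the timestep
# schedule from `num_inference_steps`, custom `timesteps`, or custom `sigmas`.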
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
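# Minimal SD3 pipeline that replaces the original English text encoders with
# the multilingual encoders plus the two adapter projections loaded from
# proj_path and proj_t5_path.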
class StableDiffusionTest:
    def __init__(self, model_path, text_encoder_path, text_encoder_path1, text_encoder_path2, proj_path, proj_t5_path):
        super().__init__()
        self.transformer = SD3Transformer2DModel.from_pretrained(model_path, subfolder="transformer", torch_dtype=dtype).to(device)
        self.vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae").to(device, dtype=dtype)
        self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(model_path, subfolder="scheduler")
        self.vae_scale_factor = (
            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
        )
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.default_sample_size = (
            self.transformer.config.sample_size
            if hasattr(self, "transformer") and self.transformer is not None
            else 128
        )
        self.text_encoder_t5 = T5EncoderModel.from_pretrained(text_encoder_path).to(device, dtype=dtype)
        self.tokenizer_t5 = T5Tokenizer.from_pretrained(text_encoder_path)
        self.text_encoder = BertModel.from_pretrained(f"{text_encoder_path1}/clip_text_encoder", False, revision=None).to(device, dtype=dtype)
        self.tokenizer = BertTokenizer.from_pretrained(f"{text_encoder_path1}/tokenizer")
        self.text_encoder2, _, _ = open_clip.create_model_and_transforms('xlm-roberta-large-ViT-H-14', pretrained=text_encoder_path2)
        self.tokenizer2 = open_clip.get_tokenizer('xlm-roberta-large-ViT-H-14')
        self.text_encoder2.text.output_tokens = True
        self.text_encoder2 = self.text_encoder2.to(device, dtype=dtype)
        self.proj = MLP(2048, 2048, 2048, 4096, use_residual=False).to(device, dtype=dtype)
        self.proj.load_state_dict(torch.load(proj_path, map_location="cpu"))
        self.proj_t5 = Transformer(d_model=4096, n_heads=8, out_dim1=2048, out_dim2=4096).to(device, dtype=dtype)
        self.proj_t5.load_state_dict(torch.load(proj_t5_path, map_location="cpu"))
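    # Encode the prompt with the three multilingual encoders, run the adapter
    # projections, and (for classifier-free guidance) prepend the unconditional
    # embeddings along the batch dimension.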
    def encode_prompt(self, prompt, device, do_classifier_free_guidance=True, negative_prompt=None):
        batch_size = len(prompt) if isinstance(prompt, list) else 1
        text_input_ids_t5 = self.tokenizer_t5(
            prompt,
            padding="max_length",
            max_length=77,
            truncation=True,
            add_special_tokens=False,
            return_tensors="pt",
        ).input_ids.to(device)
        text_embeddings = self.text_encoder_t5(text_input_ids_t5)
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=77,
            truncation=True,
            return_tensors="pt",
        )
        input_ids = text_inputs.input_ids.to(device)
        attention_mask = text_inputs.attention_mask.to(device)
        encoder_hidden_states = self.text_encoder(input_ids, attention_mask=attention_mask)[0]
        text_input_ids = self.tokenizer2(prompt).to(device)
        _, encoder_hidden_states2 = self.text_encoder2.encode_text(text_input_ids)
        encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states2], dim=-1)
        encoder_hidden_states_t5 = text_embeddings[0]
        encoder_hidden_states = self.proj(encoder_hidden_states)
        add_text_embeds, encoder_hidden_states_t5 = self.proj_t5(encoder_hidden_states_t5.half())
        prompt_embeds = torch.cat([encoder_hidden_states, encoder_hidden_states_t5], dim=-2)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            else:
                uncond_tokens = negative_prompt
            text_input_ids_t5 = self.tokenizer_t5(
                uncond_tokens,
                padding="max_length",
                max_length=77,
                truncation=True,
                add_special_tokens=False,
                return_tensors="pt",
            ).input_ids.to(device)
            text_embeddings = self.text_encoder_t5(text_input_ids_t5)
            text_inputs = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=77,
                truncation=True,
                return_tensors="pt",
            )
            input_ids = text_inputs.input_ids.to(device)
            attention_mask = text_inputs.attention_mask.to(device)
            encoder_hidden_states = self.text_encoder(input_ids, attention_mask=attention_mask)[0]
            text_input_ids = self.tokenizer2(uncond_tokens).to(device)
            _, encoder_hidden_states2 = self.text_encoder2.encode_text(text_input_ids)
            encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states2], dim=-1)
            encoder_hidden_states_t5 = text_embeddings[0]
            encoder_hidden_states_uncond = self.proj(encoder_hidden_states)
            add_text_embeds_uncond, encoder_hidden_states_t5_uncond = self.proj_t5(encoder_hidden_states_t5.half())
            prompt_embeds_uncond = torch.cat([encoder_hidden_states_uncond, encoder_hidden_states_t5_uncond], dim=-2)
            prompt_embeds = torch.cat([prompt_embeds_uncond, prompt_embeds], dim=0)
            pooled_prompt_embeds = torch.cat([add_text_embeds_uncond, add_text_embeds], dim=0)
        return prompt_embeds, pooled_prompt_embeds
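    # Draw the initial latents in the transformer's latent space
    # (spatial size = height / vae_scale_factor x width / vae_scale_factor).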
    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        height,
        width,
        dtype,
        device,
        generator,
        latents=None,
    ):
        if latents is not None:
            return latents.to(device=device, dtype=dtype)
        shape = (
            batch_size,
            num_channels_latents,
            int(height) // self.vae_scale_factor,
            int(width) // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )
        latents = torch.randn(shape, generator=generator, dtype=dtype).to(device)
        return latents
    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def clip_skip(self):
        return self._clip_skip

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1

    @property
    def joint_attention_kwargs(self):
        return self._joint_attention_kwargs

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def interrupt(self):
        return self._interrupt
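    # Denoising loop: flow-matching Euler steps with classifier-free guidance,
    # followed by VAE decoding and post-processing to PIL images.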
    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        prompt_3: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 28,
        timesteps: List[int] = None,
        guidance_scale: float = 7.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        negative_prompt_3: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        clip_skip: Optional[int] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    ):
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor
        self._guidance_scale = guidance_scale
        self._clip_skip = clip_skip
        self._joint_attention_kwargs = joint_attention_kwargs
        self._interrupt = False
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]
        prompt_embeds, pooled_prompt_embeds = self.encode_prompt(prompt, device)
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)
        num_channels_latents = self.transformer.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )
        for i, t in tqdm(enumerate(timesteps)):
            if self.interrupt:
                continue
            latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
            timestep = t.expand(latent_model_input.shape[0]).to(dtype=dtype)
            noise_pred = self.transformer(
                hidden_states=latent_model_input,
                timestep=timestep,
                encoder_hidden_states=prompt_embeds.to(dtype=self.transformer.dtype),
                pooled_projections=pooled_prompt_embeds.to(dtype=self.transformer.dtype),
                joint_attention_kwargs=self.joint_attention_kwargs,
                return_dict=False,
            )[0]
            if self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
            latents_dtype = latents.dtype
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
            if latents.dtype != latents_dtype:
                if torch.backends.mps.is_available():
                    latents = latents.to(latents_dtype)
            if callback_on_step_end is not None:
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
                negative_pooled_prompt_embeds = callback_outputs.pop(
                    "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
                )
        if output_type == "latent":
            image = latents
        else:
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
            image = self.vae.decode(latents, return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)
        return image
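# Interactive demo: reads prompts from stdin and saves a 1 x `batch` grid of
# generated images to MultilingualSD3.png.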
if __name__ == '__main__':
    device = "cuda"
    dtype = torch.float16

    text_encoder_path = 'google/umt5-xxl'
    text_encoder_path1 = "Tencent-Hunyuan/HunyuanDiT/t2i"
    text_encoder_path2 = 'laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/open_clip_pytorch_model.bin'
    model_path = "stabilityai/stable-diffusion-3-medium-diffusers"
    proj_path = "OPPOer/MultilingualSD3-adapter/pytorch_model.bin"
    proj_t5_path = "OPPOer/MultilingualSD3-adapter/pytorch_model_t5.bin"

    sdt = StableDiffusionTest(model_path, text_encoder_path, text_encoder_path1, text_encoder_path2, proj_path, proj_t5_path)
    batch = 2
    height = 1024
    width = 1024
    while True:
        raw_text = input("\nPlease Input Query (stop to exit) >>> ")
        if not raw_text:
            print('Query should not be empty!')
            continue
        if raw_text == "stop":
            break
        images = sdt([raw_text] * batch, height=height, width=width)
        grid = image_grid(images, rows=1, cols=batch)
        grid.save("MultilingualSD3.png")
To learn more, check out the diffusers documentation (https://huggingface.co/docs/diffusers).
License
The adapter itself is released under the Apache License 2.0, but its use must also comply with the license of the base model.