from PIL import Image
from base64 import b64encode

import torch
import torch.nn as nn
from torch import autocast
from torch.nn import functional as F

from diffusers import StableDiffusionPipeline, AutoencoderKL
from diffusers import UNet2DConditionModel, PNDMScheduler, LMSDiscreteScheduler
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
# from transformers import CLIPTextModel, CLIPTokenizer
from tqdm.auto import tqdm
from huggingface_hub import notebook_login

import transformers
from Multilingual_CLIP.multilingual_clip import Config_MCLIP
from Multilingual_CLIP.multilingual_clip import pt_multilingual_clip

device = 'cpu'


class MultilingualCLIP(transformers.PreTrainedModel):
    config_class = Config_MCLIP.MCLIPConfig

    def __init__(self, config, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        self.transformer = transformers.AutoModel.from_pretrained(config.modelBase)
        self.LinearTransformation = torch.nn.Linear(
            in_features=config.transformerDimensions, out_features=config.numDims)

    def forward(self, txt, tokenizer, device):
        # Tokenize, encode, mean-pool over the attention mask,
        # then project into the CLIP text-embedding space.
        txt_tok = tokenizer(txt, padding='max_length', max_length=77,
                            truncation=True, return_tensors='pt').to(device)
        embs = self.transformer(**txt_tok)
        embs = embs[0]
        att = txt_tok['attention_mask']
        embs = (embs * att.unsqueeze(2)) / att.sum(dim=1)[:, None].unsqueeze(2)
        return self.LinearTransformation(embs)

    @classmethod
    def _load_state_dict_into_model(cls, model, state_dict,
                                    pretrained_model_name_or_path, _fast_init=True):
        model.load_state_dict(state_dict)
        return model, [], [], []


# Define the adaptation layer that maps the multilingual text embeddings into the
# embedding space expected by the Stable Diffusion UNet; weights: 'checkpoint_9.pth'.
class AdaptationLayer(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(AdaptationLayer, self).__init__()
        self.fc1 = nn.Linear(input_dim, output_dim * 2)
        torch.nn.init.kaiming_uniform_(self.fc1.weight, nonlinearity='relu')
        self.bn1 = nn.BatchNorm1d(77)

        self.fc2 = nn.Linear(input_dim * 2, output_dim * 2)
        torch.nn.init.kaiming_uniform_(self.fc2.weight, nonlinearity='relu')
        self.bn2 = nn.BatchNorm1d(77)

        self.fc3 = nn.Linear(input_dim * 2, output_dim)
        torch.nn.init.kaiming_uniform_(self.fc3.weight, nonlinearity='relu')
        self.bn3 = nn.BatchNorm1d(77)

        self.fc4 = nn.Linear(input_dim, output_dim)
        torch.nn.init.kaiming_uniform_(self.fc4.weight, nonlinearity='relu')
        self.bn4 = nn.BatchNorm1d(77)

        self.fc5 = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        x = nn.functional.normalize(x, p=2.0, dim=1, eps=1e-12, out=None)
        x = torch.relu(self.bn1(self.fc1(x)))
        x = torch.relu(self.bn2(self.fc2(x)))
        x = torch.relu(self.bn3(self.fc3(x)))
        x = torch.relu(self.bn4(self.fc4(x)))
        return self.fc5(x)


adapt_model = AdaptationLayer(768, 768)
adapt_model.to(device)

state_dict = torch.load('weights/checkpoint_9.pth', map_location=device)
adapt_model.load_state_dict(state_dict)
adapt_model.eval()  # use the checkpoint's BatchNorm running statistics at inference time

texts = ['قطة تقرأ كتابا']  # 'A cat reading a book'

model_name = 'M-CLIP/LABSE-Vit-L-14'

# Load the multilingual text model & tokenizer
text_model = pt_multilingual_clip.MultilingualCLIP.from_pretrained(model_name)
text_model = text_model.to(device)
text_tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

# Sanity check: encode the example prompt
embeddings = text_model.forward(texts, text_tokenizer, device)

# 1. Load the autoencoder model which will be used to decode the latents into image space.
vae = AutoencoderKL.from_pretrained(
    'CompVis/stable-diffusion-v1-4', subfolder='vae', use_auth_token=True)
vae = vae.to(device)

# 2. Load the tokenizer and text encoder to tokenize and encode the text.
tokenizer = text_tokenizer
text_encoder = text_model

# 3. The UNet model for generating the latents.
unet = UNet2DConditionModel.from_pretrained(
    'CompVis/stable-diffusion-v1-4', subfolder='unet', use_auth_token=True)
unet = unet.to(device)

# 4. Create a scheduler for inference
scheduler = LMSDiscreteScheduler(
    beta_start=0.00085, beta_end=0.012,
    beta_schedule='scaled_linear', num_train_timesteps=1000)


def get_text_embeds(prompt):
    # Conditional embeddings: multilingual CLIP features passed through the adaptation layer
    with torch.no_grad():
        text_embeddings = text_model(prompt, text_tokenizer, device)
        text_embeddings = adapt_model(text_embeddings)

    # Do the same for the unconditional (empty-prompt) embeddings
    with torch.no_grad():
        uncond_embeddings = text_model([''] * len(prompt), text_tokenizer, device)
        uncond_embeddings = adapt_model(uncond_embeddings)

    # Concatenate for the final embeddings used in classifier-free guidance
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
    return text_embeddings


def produce_latents(text_embeddings, height=512, width=512,
                    num_inference_steps=50, guidance_scale=7.5, latents=None):
    if latents is None:
        latents = torch.randn((text_embeddings.shape[0] // 2, unet.in_channels,
                               height // 8, width // 8))
    latents = latents.to(device)

    scheduler.set_timesteps(num_inference_steps)
    latents = latents * scheduler.sigmas[0]

    with autocast('cpu'):
        for i, t in tqdm(enumerate(scheduler.timesteps)):
            # Expand the latents if we are doing classifier-free guidance
            # to avoid doing two forward passes.
            latent_model_input = torch.cat([latents] * 2)
            sigma = scheduler.sigmas[i]
            latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)

            # predict the noise residual
            with torch.no_grad():
                noise_pred = unet(latent_model_input, t,
                                  encoder_hidden_states=text_embeddings.to(device))['sample']

            # perform guidance
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = scheduler.step(noise_pred, i, latents)['prev_sample']

    return latents


def decode_img_latents(latents):
    # Rescale the latents and decode them with the VAE, then convert to PIL images
    latents = 1 / 0.18215 * latents

    with torch.no_grad():
        imgs = vae.decode(latents)

    imgs = (imgs / 2 + 0.5).clamp(0, 1)
    imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()
    imgs = (imgs * 255).round().astype('uint8')
    pil_images = [Image.fromarray(image) for image in imgs]
    return pil_images


def prompt_to_img(prompts, height=512, width=512, num_inference_steps=50,
                  guidance_scale=7.5, latents=None):
    if isinstance(prompts, str):
        prompts = [prompts]

    # Prompts -> text embeds
    text_embeds = get_text_embeds(prompts)

    # Text embeds -> img latents
    latents = produce_latents(
        text_embeds, height=height, width=width, latents=latents,
        num_inference_steps=num_inference_steps, guidance_scale=guidance_scale)

    # Img latents -> imgs
    imgs = decode_img_latents(latents)

    return imgs
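

# Example usage: a minimal sketch of running the full pipeline end to end. It assumes
# the Hugging Face weights referenced above are accessible and that
# 'weights/checkpoint_9.pth' exists locally; the output filename is arbitrary.
if __name__ == '__main__':
    images = prompt_to_img('قطة تقرأ كتابا', num_inference_steps=50)  # 'A cat reading a book'
    images[0].save('output.png')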