#!/usr/bin/env python3
"""Pix2Pix-Zero "cat -> dog" editing demo.

Steps:

1. Download a test cat image.
2. Caption it with BLIP.
3. DDIM-invert it to latents.
4. Edit it with ``StableDiffusionPix2PixZeroPipeline`` along the direction
   between a set of "cat" prompts and a set of "dog" prompts.
5. Save the result locally and upload it to a Hugging Face Hub dataset repo.
"""
import io
import os

import requests
import torch
from diffusers import DDIMScheduler, StableDiffusionPix2PixZeroPipeline
from diffusers.schedulers.scheduling_ddim_inverse import DDIMInverseScheduler
from huggingface_hub import HfApi
from PIL import Image
from transformers import BlipForConditionalGeneration, BlipProcessor

IMG_URL = (
    "https://github.com/pix2pixzero/pix2pix-zero/raw/main/"
    "assets/test_images/cats/cat_6.png"
)
CAPTION_CKPT = "Salesforce/blip-image-captioning-base"
MODEL_CKPT = "CompVis/stable-diffusion-v1-4"
OUTPUT_PATH = "/home/patrick_huggingface_co/images/aa.png"
REPO_ID = "patrickvonplaten/images"
# Seconds to wait on the image download before giving up.
HTTP_TIMEOUT = 60


def main() -> None:
    """Run the caption -> invert -> edit -> save -> upload flow once."""
    api = HfApi()

    # Fetch the source image. The timeout prevents an indefinite hang on a
    # stalled connection. raise_for_status() fails loudly on HTTP errors
    # instead of passing an error page to PIL. Decoding from .content (rather
    # than .raw) lets requests apply any transfer content-decoding first.
    response = requests.get(IMG_URL, timeout=HTTP_TIMEOUT)
    response.raise_for_status()
    raw_image = (
        Image.open(io.BytesIO(response.content)).convert("RGB").resize((512, 512))
    )

    # BLIP captioning model and processor, handed to the pipeline so that
    # pipeline.generate_caption() can describe the input image.
    processor = BlipProcessor.from_pretrained(CAPTION_CKPT)
    model = BlipForConditionalGeneration.from_pretrained(
        CAPTION_CKPT, torch_dtype=torch.float16, low_cpu_mem_usage=True
    )

    pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained(
        MODEL_CKPT,
        caption_generator=model,
        caption_processor=processor,
        torch_dtype=torch.float16,
        safety_checker=None,
    )
    pipeline.enable_model_cpu_offload()

    caption = pipeline.generate_caption(raw_image)

    # DDIM scheduler for the edit and its inverse (built from the same
    # config) for mapping the image back to latents.
    pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
    pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(
        pipeline.scheduler.config
    )
    print(caption)

    # One seeded generator is shared by invert() and the edit call below.
    # The edit therefore continues from the RNG state left by the inversion.
    generator = torch.manual_seed(0)
    inv_latents = pipeline.invert(caption, image=raw_image, generator=generator).latents

    # Prompt sets whose embeddings are passed to the pipeline as
    # source_embeds / target_embeds.
    source_prompts = 4 * [
        "a cat sitting on the street",
        "a cat playing in the field",
        "a face of a cat",
    ]
    target_prompts = 4 * [
        "a dog sitting on the street",
        "a dog playing in the field",
        "a face of a dog",
    ]
    source_embeds = pipeline.get_embeds(source_prompts, batch_size=2)
    target_embeds = pipeline.get_embeds(target_prompts, batch_size=2)

    image = pipeline(
        caption,
        source_embeds=source_embeds,
        target_embeds=target_embeds,
        num_inference_steps=50,
        cross_attention_guidance_amount=0.15,
        generator=generator,
        latents=inv_latents,
        negative_prompt=caption,
    ).images[0]

    # Make sure the output directory exists before saving.
    os.makedirs(os.path.dirname(OUTPUT_PATH), exist_ok=True)
    image.save(OUTPUT_PATH)

    filename = os.path.basename(OUTPUT_PATH)
    api.upload_file(
        path_or_fileobj=OUTPUT_PATH,
        path_in_repo=filename,
        repo_id=REPO_ID,
        repo_type="dataset",
    )
    # Derive the link from the uploaded filename so it always matches.
    print(f"https://huggingface.co/datasets/{REPO_ID}/blob/main/{filename}")


if __name__ == "__main__":
    main()