diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..0cebb3f50a1ec24241e261472a48ae2b3e04dfd8 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,20 @@ +# Use an official CUDA runtime as a parent image +FROM nvidia/cuda:11.2.2-runtime-ubuntu20.04 + +# Set the working directory in the container +WORKDIR /usr/src/app + +# Install Python +RUN apt-get update && apt-get install -y python3.8 python3-pip + +# Copy the current directory contents into the container at /usr/src/app +COPY . /usr/src/app + +# Install any needed packages specified in requirements.txt +RUN pip3 install --no-cache-dir -r requirements.txt + +# Make port 80 available to the world outside this container +EXPOSE 7860 + +# Run app.py when the container launches +CMD ["python3", "app.py"] diff --git a/README.md b/README.md index ce8b683ef0c224f09c8bb3fd279a3f09011815d6..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644 --- a/README.md +++ b/README.md @@ -1,10 +0,0 @@ ---- -title: Minerva Generate Docker -emoji: 😻 -colorFrom: red -colorTo: purple -sdk: docker -pinned: false ---- - -Check out the configuration reference at https://huggingface.co./docs/hub/spaces-config-reference diff --git a/__pycache__/app.cpython-310.pyc b/__pycache__/app.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e0c712b54bd6825f797efcd5621667b551c64765 Binary files /dev/null and b/__pycache__/app.cpython-310.pyc differ diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..45f0bf48876e1dc6ada816686a35e32a4f8ddeb2 --- /dev/null +++ b/app.py @@ -0,0 +1,5 @@ +import gradio as gr +from blocks.main import main_box +blocks = main_box() +blocks.launch(share = True) + diff --git a/blocks/.DS_Store b/blocks/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..df93a9b0d9655287359ac7432a9c6c959589de12 Binary files /dev/null and b/blocks/.DS_Store differ diff --git a/blocks/__init__.py b/blocks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e08103f4d0e1c4396623022cfd2c4ff0ec17381b --- /dev/null +++ b/blocks/__init__.py @@ -0,0 +1,17 @@ +IMG2IMG_MODEL_LIST = { + "StableDiffusion 1.5" : "runwayml/stable-diffusion-v1-5", + "StableDiffusion 2.1" : "stabilityai/stable-diffusion-2-1", + "OpenJourney v4" : "prompthero/openjourney-v4", + "DreamLike 1.0" : "dreamlike-art/dreamlike-diffusion-1.0", + "DreamLike 2.0" : "dreamlike-art/dreamlike-photoreal-2.0" +} + +TEXT2IMG_MODEL_LIST = { + "OpenJourney v4" : "prompthero/openjourney-v4", + "StableDiffusion 1.5" : "runwayml/stable-diffusion-v1-5", + "StableDiffusion 2.1" : "stabilityai/stable-diffusion-2-1", + "DreamLike 1.0" : "dreamlike-art/dreamlike-diffusion-1.0", + "DreamLike 2.0" : "dreamlike-art/dreamlike-photoreal-2.0", + "DreamShaper" : "Lykon/DreamShaper", + "NeverEnding-Dream" : "Lykon/NeverEnding-Dream" +} \ No newline at end of file diff --git a/blocks/__pycache__/__init__.cpython-310.pyc b/blocks/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e32727616cd70082efd0f4db814fc1a9c2e6856 Binary files /dev/null and b/blocks/__pycache__/__init__.cpython-310.pyc differ diff --git a/blocks/__pycache__/__init__.cpython-39.pyc b/blocks/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6bb98b075ea4d2e2b1b5186edf7160707c123984 Binary files /dev/null and b/blocks/__pycache__/__init__.cpython-39.pyc differ diff --git 
a/blocks/__pycache__/download.cpython-310.pyc b/blocks/__pycache__/download.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..26929a76993e7fa3eb6e1f7d16a7420647843bb7 Binary files /dev/null and b/blocks/__pycache__/download.cpython-310.pyc differ diff --git a/blocks/__pycache__/download.cpython-39.pyc b/blocks/__pycache__/download.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b7a7adaa7111b7d81198b617bbd28a47e6a11fa4 Binary files /dev/null and b/blocks/__pycache__/download.cpython-39.pyc differ diff --git a/blocks/__pycache__/img2img.cpython-39.pyc b/blocks/__pycache__/img2img.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44815539cb900ef45262923d9084e4f40c4981e6 Binary files /dev/null and b/blocks/__pycache__/img2img.cpython-39.pyc differ diff --git a/blocks/__pycache__/inpainting.cpython-310.pyc b/blocks/__pycache__/inpainting.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d2e10b5209af9dfe1721f85bbf5fef3beff2d6a Binary files /dev/null and b/blocks/__pycache__/inpainting.cpython-310.pyc differ diff --git a/blocks/__pycache__/inpainting.cpython-39.pyc b/blocks/__pycache__/inpainting.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2597c8dd8a29cda5f38bf346a4ae394886c88134 Binary files /dev/null and b/blocks/__pycache__/inpainting.cpython-39.pyc differ diff --git a/blocks/__pycache__/main.cpython-310.pyc b/blocks/__pycache__/main.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4851fcef4b191a71a9ba5fcf7401416482b2d5be Binary files /dev/null and b/blocks/__pycache__/main.cpython-310.pyc differ diff --git a/blocks/__pycache__/main.cpython-39.pyc b/blocks/__pycache__/main.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60fce1c6c945b3033bd4f3ce099b98ca30c2e495 Binary files /dev/null and b/blocks/__pycache__/main.cpython-39.pyc differ diff --git a/blocks/__pycache__/text2img.cpython-39.pyc b/blocks/__pycache__/text2img.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4517a9c56a0eccfc9d379a49e9442a5c223ebc08 Binary files /dev/null and b/blocks/__pycache__/text2img.cpython-39.pyc differ diff --git a/blocks/download.py b/blocks/download.py new file mode 100644 index 0000000000000000000000000000000000000000..9b6931e58d4518e32f3cb4e0f1e6d77fbbc5aa80 --- /dev/null +++ b/blocks/download.py @@ -0,0 +1,400 @@ +community_icon_html = """""" + +def get_community_loading_icon(task = "text2img"): + if task == "text2img": + community_icon = """""" + loading_icon = """""" + + elif task == "img2img": + community_icon = """""" + loading_icon = """""" + + elif task == "inpainting": + community_icon = """""" + loading_icon = """""" + return community_icon, loading_icon + +loading_icon_html = """""" + +CSS = """ + #col-container {margin-left: auto; margin-right: auto;} + a {text-decoration-line: underline; font-weight: 600;} + .animate-spin { + animation: spin 1s linear infinite; + } + @keyframes spin { + from { transform: rotate(0deg); } + to { transform: rotate(360deg); } + } + .gradio-container { + font-family: 'IBM Plex Sans', sans-serif; + } + .gr-button { + color: white; + border-color: black; + background: black; + } + input[type='range'] { + accent-color: black; + } + .dark input[type='range'] { + accent-color: #dfdfdf; + } + .container { + max-width: 730px; + margin: auto; + padding-top: 1.5rem; + } + #gallery { + 
min-height: 22rem; + margin-bottom: 15px; + margin-left: auto; + margin-right: auto; + border-bottom-right-radius: .5rem !important; + border-bottom-left-radius: .5rem !important; + } + #gallery>div>.h-full { + min-height: 20rem; + } + .details:hover { + text-decoration: underline; + } + .gr-button { + white-space: nowrap; + } + .gr-button:focus { + border-color: rgb(147 197 253 / var(--tw-border-opacity)); + outline: none; + box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000); + --tw-border-opacity: 1; + --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color); + --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px var(--tw-ring-offset-width)) var(--tw-ring-color); + --tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity)); + --tw-ring-opacity: .5; + } + #advanced-btn { + font-size: .7rem !important; + line-height: 19px; + margin-top: 12px; + margin-bottom: 12px; + padding: 2px 8px; + border-radius: 14px !important; + } + #advanced-options { + display: none; + margin-bottom: 20px; + } + .footer { + margin-bottom: 45px; + margin-top: 35px; + text-align: center; + border-bottom: 1px solid #e5e5e5; + } + .footer>p { + font-size: .8rem; + display: inline-block; + padding: 0 10px; + transform: translateY(10px); + background: white; + } + .dark .footer { + border-color: #303030; + } + .dark .footer>p { + background: #0b0f19; + } + .acknowledgments h4{ + margin: 1.25em 0 .25em 0; + font-weight: bold; + font-size: 115%; + } + .animate-spin { + animation: spin 1s linear infinite; + } + @keyframes spin { + from { + transform: rotate(0deg); + } + to { + transform: rotate(360deg); + } + } + #share-btn-container { + display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem; + margin-top: 10px; + margin-left: auto; + margin-right: auto; + } + #share-btn { + all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;right:0; + } + #share-btn * { + all: unset; + } + + #share-btn-inpainting { + all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;right:0; + } + #share-btn-inpainting * { + all: unset; + } + + #share-btn-img2img { + all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;right:0; + } + #share-btn-img2img * { + all: unset; + } + #share-btn-container div:nth-child(-n+2){ + width: auto !important; + min-height: 0px !important; + } + #share-btn-container .wrap { + display: none !important; + } + + #download-btn-container { + display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem; + margin-top: 10px; + margin-left: auto; + margin-right: auto; + } + #download-btn { + all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;right:0; + 
} + #download-btn * { + all: unset; + } + #download-btn-container div:nth-child(-n+2){ + width: auto !important; + min-height: 0px !important; + } + #download-btn-container .wrap { + display: none !important; + } + + .gr-form{ + flex: 1 1 50%; border-top-right-radius: 0; border-bottom-right-radius: 0; + } + #prompt-container{ + gap: 0; + } + + #prompt-text-input, #negative-prompt-text-input{padding: .45rem 0.625rem} + #component-16{border-top-width: 1px!important;margin-top: 1em} + .image_duplication{position: absolute; width: 100px; left: 50px} +""" + +# window.addEventListener('message', function (event) { +# if (event.origin !== 'http://127.0.0.1:5000'){ +# console.log('Origin not allowed'); +# return; +# } +# userId = event.data.userId; +# console.log('User ID received from parent:', userId); +# }); + +def get_share_js(): + share_js = """ + async () => { + // Get the username from the URL itself + const urlParams = new URLSearchParams(window.location.search); + const username = urlParams.get('username'); + + async function uploadFile( + file, + _meta_prompt, + _meta_negative_prompt, + _meta_model_name, + _meta_scheduler_name, + _meta_model_guidance_scale, + _meta_model_num_steps, + _meta_model_image_size, + _meta_seed, + _meta_mask = null, + _meta_reference_image = null, + ){ + const UPLOAD_URL = 'http://127.0.0.1:5000/v1/api/upload-image'; + const formData = new FormData(); + formData.append('file', file); + + // Add the meta data headers to the form data + formData.append('text_prompt', _meta_prompt); + formData.append('negative_prompt', _meta_negative_prompt); + formData.append('model_name', _meta_model_name); + formData.append('model_guidance_scale', _meta_model_guidance_scale); + formData.append('model_num_steps', _meta_model_num_steps); + formData.append('scheduler_name', _meta_scheduler_name); + formData.append('seed', _meta_seed); + formData.append('model_image_size', _meta_model_image_size); + + // Add the optional meta data headers to the form data + if(_meta_mask){ + formData.append('mask', _meta_mask); + } + if(_meta_reference_image){ + formData.append('reference_image', _meta_reference_image); + } + + formData.append('username',username); // This is constant for all the images + const response = await fetch(UPLOAD_URL, { + method: 'POST', + headers: { + 'X-Requested-With': 'XMLHttpRequest', + }, + body: formData, + }); + const url = await response.text(); // This returns the URL of the uploaded file (S3) bucket + return url; + } + + const gradioEl = document.querySelector('gradio-app'); + const imgEls = gradioEl.querySelectorAll('#gallery img'); + + // Get the necessary fields + const promptTxt = gradioEl.querySelector('#prompt-text-input textarea').value; + const negativePromptTxt = gradioEl.querySelector('#negative-prompt-text-input textarea').value; + + console.log(promptTxt); + console.log(negativePromptTxt); + + // Get values from the sliders + const modelGuidanceScale = parseFloat(gradioEl.querySelector('#guidance-scale-slider input').value); + console.log(modelGuidanceScale); + + const numSteps = parseInt(gradioEl.querySelector('#num-inference-step-slider input').value); + const imageSize = parseInt(gradioEl.querySelector('#image-size-slider input').value); + const seed = parseInt(gradioEl.querySelector('#seed-slider input').value); + + console.log(numSteps); + console.log(imageSize); + console.log(seed); + + // Get the values from dropdowns + const modelName = gradioEl.querySelector('#model-dropdown input').value; + const schedulerName = 
gradioEl.querySelector('#scheduler-dropdown input').value; + + console.log(modelName); + console.log(schedulerName); + + const shareBtnEl = gradioEl.querySelector('#share-btn'); + const shareIconEl = gradioEl.querySelector('#share-btn-share-icon'); + const loadingIconEl = gradioEl.querySelector('#share-btn-loading-icon'); + + if(!imgEls.length){ + return; + }; + + shareBtnEl.style.pointerEvents = 'none'; + shareIconEl.style.display = 'none'; + loadingIconEl.style.removeProperty('display'); + const files = await Promise.all( + [...imgEls].map(async (imgEl) => { + const res = await fetch(imgEl.src); + const blob = await res.blob(); + const fileSrc = imgEl.src.split('/').pop(); // Get the file name from the img src path + const imgId = Date.now(); + const fileName = `${fileSrc}-${imgId}.jpg`; // Fixed fileName construction + return new File([blob], fileName, { type: 'image/jpeg' }); + }) + ); + + // Ensure that only one image is uploaded by taking the first element if there are multiple + if (files.length > 1) { + files.splice(1, files.length - 1); + } + + const urls = await Promise.all(files.map((f) => uploadFile( + f, + promptTxt, + negativePromptTxt, + modelName, + schedulerName, + modelGuidanceScale, + numSteps, + imageSize, + seed, + ))); + const htmlImgs = urls.map(url => ``); + + shareBtnEl.style.removeProperty('pointer-events'); + shareIconEl.style.removeProperty('display'); + loadingIconEl.style.display = 'none'; + } + """ + return share_js + +def get_load_from_artwork_js(): + load_artwork_js = """ + async () => { + const urlParams = new URLSearchParams(window.location.search); + const username = urlParams.get('username'); + const artworkId = urlParams.get('artworkId'); + + const LOAD_URL = `http://127.0.0.1:5000/v1/api/load-parameters?artworkId=${artworkId}`; + const response = await fetch(LOAD_URL, { + method: 'GET', + headers: { + 'X-Requested-With': 'XMLHttpRequest', + } + }); + + // Check if the response is okay + if (!response.ok) { + console.error("An error occurred while fetching the parameters."); + return; + } + + const parameters = await response.json(); // Assuming you're getting a JSON response + + // Get the necessary elements + const gradioEl = document.querySelector('gradio-app'); + const promptInput = gradioEl.querySelector('#prompt-text-input textarea'); + const negativePromptInput = gradioEl.querySelector('#negative-prompt-text-input textarea'); + + // Get the slider inputs + const guidanceScaleInput = gradioEl.querySelector('#guidance-scale-slider input'); + const numInferenceStepInput = gradioEl.querySelector('#num-inference-step-slider input'); + const imageSizeInput = gradioEl.querySelector('#image-size-slider input'); + const seedInput = gradioEl.querySelector('#seed-slider input'); + + // Get the dropdown inputs + const modelDropdown = gradioEl.querySelector('#model-dropdown input'); + const schedulerDropdown = gradioEl.querySelector('#scheduler-dropdown input'); + + // Set the values based on the parameters received + promptInput.value = parameters.text_prompt; + negativePromptInput.value = parameters.negative_prompt; + guidanceScaleInput.value = parameters.model_guidance_scale; + numInferenceStepInput.value = parameters.model_num_steps; + imageSizeInput.value = parameters.model_image_size; + seedInput.value = parameters.seed; + modelDropdown.value = parameters.model_name; + schedulerDropdown.value = parameters.scheduler_name; + } + """ + return load_artwork_js \ No newline at end of file diff --git a/blocks/img2img.py b/blocks/img2img.py new file mode 
100644 index 0000000000000000000000000000000000000000..094eab4f1ccde4125dc9fa07846c598d496b0de6 --- /dev/null +++ b/blocks/img2img.py @@ -0,0 +1,204 @@ +import gradio as gr +import torch +from diffusers import StableDiffusionImg2ImgPipeline +from .utils.schedulers import SCHEDULER_LIST, get_scheduler_list +from .utils.prompt2prompt import generate +from .utils.device import get_device +from PIL import Image +from .download import get_share_js, get_community_loading_icon, CSS + +IMG2IMG_MODEL_LIST = { + "StableDiffusion 1.5" : "runwayml/stable-diffusion-v1-5", + "StableDiffusion 2.1" : "stabilityai/stable-diffusion-2-1", + "OpenJourney v4" : "prompthero/openjourney-v4", + "DreamLike 1.0" : "dreamlike-art/dreamlike-diffusion-1.0", + "DreamLike 2.0" : "dreamlike-art/dreamlike-photoreal-2.0" +} + +class StableDiffusionImage2ImageGenerator: + def __init__(self): + self.pipe = None + + def load_model(self, model_path, scheduler): + model_path = IMG2IMG_MODEL_LIST[model_path] + if self.pipe is None: + self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained( + model_path, safety_checker=None, torch_dtype=torch.float32 + ) + + device = get_device() + self.pipe = get_scheduler_list(pipe=self.pipe, scheduler=scheduler) + + self.pipe.to(device) + #self.pipe.enable_attention_slicing() + + return self.pipe + + def generate_image( + self, + image_path: str, + model_path: str, + prompt: str, + negative_prompt: str, + num_images_per_prompt: int, + scheduler: str, + guidance_scale: int, + num_inference_step: int, + seed_generator=0, + ): + pipe = self.load_model( + model_path=model_path, + scheduler=scheduler, + ) + + if seed_generator == 0: + random_seed = torch.randint(0, 1000000, (1,)) + generator = torch.manual_seed(random_seed) + else: + generator = torch.manual_seed(seed_generator) + + image = Image.open(image_path) + images = pipe( + prompt, + image=image, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + num_inference_steps=num_inference_step, + guidance_scale=guidance_scale, + generator=generator, + ).images + + return images + + def app(): + demo = gr.Blocks(css=CSS) + with demo: + with gr.Row(): + with gr.Column(): + image2image_image_file = gr.Image( + type="filepath", label="Upload",elem_id="image-upload-img2img" + ).style(height=260) + + image2image_prompt = gr.Textbox( + lines=1, + placeholder="Prompt", + show_label=False, + elem_id="prompt-text-input-img2img", + value='' + ) + + image2image_negative_prompt = gr.Textbox( + lines=1, + placeholder="Negative Prompt", + show_label=False, + elem_id = "negative-prompt-text-input-img2img", + value='' + ) + + # add button for generating a prompt from the prompt + image2image_generate_prompt_button = gr.Button( + label="Generate Prompt", + type="primary", + align="center", + value = "Generate Prompt" + ) + + # show a text box with the generated prompt + image2image_generated_prompt = gr.Textbox( + lines=1, + placeholder="Generated Prompt", + show_label=False, + ) + + with gr.Row(): + with gr.Column(): + image2image_model_path = gr.Dropdown( + choices=list(IMG2IMG_MODEL_LIST.keys()), + value=list(IMG2IMG_MODEL_LIST.keys())[0], + label="Imaget2Image Model Selection", + elem_id="model-dropdown-img2img", + ) + + image2image_guidance_scale = gr.Slider( + minimum=0.1, + maximum=15, + step=0.1, + value=7.5, + label="Guidance Scale", + elem_id = "guidance-scale-slider-img2img" + ) + + image2image_num_inference_step = gr.Slider( + minimum=1, + maximum=100, + step=1, + value=50, + label="Num Inference Step", + elem_id = 
"num-inference-step-slider-img2img" + ) + with gr.Row(): + with gr.Column(): + image2image_scheduler = gr.Dropdown( + choices=SCHEDULER_LIST, + value=SCHEDULER_LIST[0], + label="Scheduler", + elem_id="scheduler-dropdown-img2img", + ) + image2image_num_images_per_prompt = gr.Slider( + minimum=1, + maximum=30, + step=1, + value=1, + label="Number Of Images", + ) + + image2image_seed_generator = gr.Slider( + label="Seed(0 for random)", + minimum=0, + maximum=1000000, + value=0, + elem_id="seed-slider-img2img", + ) + + image2image_predict_button = gr.Button(value="Generator") + + with gr.Column(): + output_image = gr.Gallery( + label="Generated images", + show_label=False, + elem_id="gallery", + ).style(grid=(1, 2)) + + with gr.Group(elem_id="container-advanced-btns"): + with gr.Group(elem_id="share-btn-container"): + community_icon_html, loading_icon_html = get_community_loading_icon("img2img") + community_icon = gr.HTML(community_icon_html) + loading_icon = gr.HTML(loading_icon_html) + share_button = gr.Button("Save artwork", elem_id="share-btn-img2img") + + image2image_predict_button.click( + fn=StableDiffusionImage2ImageGenerator().generate_image, + inputs=[ + image2image_image_file, + image2image_model_path, + image2image_prompt, + image2image_negative_prompt, + image2image_num_images_per_prompt, + image2image_scheduler, + image2image_guidance_scale, + image2image_num_inference_step, + image2image_seed_generator, + ], + outputs=[output_image], + ) + + image2image_generate_prompt_button.click( + fn=generate, + inputs=[image2image_prompt], + outputs=[image2image_generated_prompt], + ) + + + return demo + + \ No newline at end of file diff --git a/blocks/inpainting.py b/blocks/inpainting.py new file mode 100644 index 0000000000000000000000000000000000000000..a61d2d487b76ad2384a087d07bf8176fc6b19050 --- /dev/null +++ b/blocks/inpainting.py @@ -0,0 +1,219 @@ +import gradio as gr +from diffusers import DiffusionPipeline,StableDiffusionInpaintPipeline +import torch +from .utils.prompt2prompt import generate +from .utils.device import get_device +from .utils.schedulers import SCHEDULER_LIST, get_scheduler_list +from .download import get_share_js, CSS, get_community_loading_icon + +INPAINT_MODEL_LIST = { + "Stable Diffusion 2" : "stabilityai/stable-diffusion-2-inpainting", + "Stable Diffusion 1" : "runwayml/stable-diffusion-inpainting", +} + +class StableDiffusionInpaintGenerator: + def __init__(self): + self.pipe = None + + def load_model(self, model_path, scheduler): + model_path = INPAINT_MODEL_LIST[model_path] + if self.pipe is None: + self.pipe = StableDiffusionInpaintPipeline.from_pretrained( + model_path, torch_dtype=torch.float32 + ) + device = get_device() + self.pipe = get_scheduler_list(pipe=self.pipe, scheduler=scheduler) + self.pipe.to(device) + #self.pipe.enable_attention_slicing() + return self.pipe + + def generate_image( + self, + pil_image: str, + model_path: str, + prompt: str, + negative_prompt: str, + num_images_per_prompt: int, + scheduler: str, + guidance_scale: int, + num_inference_step: int, + height: int, + width: int, + seed_generator=0, + ): + + image = pil_image["image"].convert("RGB").resize((width, height)) + mask_image = pil_image["mask"].convert("RGB").resize((width, height)) + + pipe = self.load_model(model_path,scheduler) + + if seed_generator == 0: + random_seed = torch.randint(0, 1000000, (1,)) + generator = torch.manual_seed(random_seed) + else: + generator = torch.manual_seed(seed_generator) + + output = pipe( + prompt=prompt, + image=image, + 
mask_image=mask_image, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + num_inference_steps=num_inference_step, + guidance_scale=guidance_scale, + generator=generator, + ).images + + return output + + + def app(): + demo = gr.Blocks(css=CSS) + with demo: + with gr.Row(): + with gr.Column(): + stable_diffusion_inpaint_image_file = gr.Image( + source="upload", + tool="sketch", + elem_id="image-upload-inpainting", + type="pil", + label="Upload", + + ).style(height=260) + + stable_diffusion_inpaint_prompt = gr.Textbox( + lines=1, + placeholder="Prompt", + show_label=False, + elem_id="prompt-text-input-inpainting", + value='' + ) + + stable_diffusion_inpaint_negative_prompt = gr.Textbox( + lines=1, + placeholder="Negative Prompt", + show_label=False, + elem_id = "negative-prompt-text-input-inpainting", + value='' + ) + # add button for generating a prompt from the prompt + stable_diffusion_inpaint_generate = gr.Button( + label="Generate Prompt", + type="primary", + align="center", + value = "Generate Prompt" + ) + + # show a text box with the generated prompt + stable_diffusion_inpaint_generated_prompt = gr.Textbox( + lines=1, + placeholder="Generated Prompt", + show_label=False, + ) + + stable_diffusion_inpaint_model_id = gr.Dropdown( + choices=list(INPAINT_MODEL_LIST.keys()), + value=list(INPAINT_MODEL_LIST.keys())[0], + label="Inpaint Model Selection", + elem_id="model-dropdown-inpainting", + ) + with gr.Row(): + with gr.Column(): + stable_diffusion_inpaint_guidance_scale = gr.Slider( + minimum=0.1, + maximum=15, + step=0.1, + value=7.5, + label="Guidance Scale", + elem_id = "guidance-scale-slider-inpainting" + ) + + stable_diffusion_inpaint_num_inference_step = gr.Slider( + minimum=1, + maximum=100, + step=1, + value=50, + label="Num Inference Step", + elem_id = "num-inference-step-slider-inpainting" + ) + + stable_diffusion_inpiant_num_images_per_prompt = gr.Slider( + minimum=1, + maximum=10, + step=1, + value=1, + label="Number Of Images", + ) + + with gr.Row(): + with gr.Column(): + stable_diffusion_inpaint_scheduler = gr.Dropdown( + choices=SCHEDULER_LIST, + value=SCHEDULER_LIST[0], + label="Scheduler", + elem_id="scheduler-dropdown-inpainting", + ) + + stable_diffusion_inpaint_size = gr.Slider( + minimum=128, + maximum=1280, + step=32, + value=512, + label="Image Size", + elem_id="image-size-slider-inpainting", + ) + + stable_diffusion_inpaint_seed_generator = gr.Slider( + label="Seed(0 for random)", + minimum=0, + maximum=1000000, + value=0, + elem_id="seed-slider-inpainting", + ) + + stable_diffusion_inpaint_predict = gr.Button( + value="Generator" + ) + + with gr.Column(): + output_image = gr.Gallery( + label="Generated images", + show_label=False, + elem_id="gallery-inpainting", + ).style(grid=(1, 2)) + + with gr.Group(elem_id="container-advanced-btns"): + with gr.Group(elem_id="share-btn-container"): + community_icon_html, loading_icon_html = get_community_loading_icon("inpainting") + community_icon = gr.HTML(community_icon_html) + loading_icon = gr.HTML(loading_icon_html) + share_button = gr.Button("Save artwork", elem_id="share-btn-inpainting") + + stable_diffusion_inpaint_predict.click( + fn=StableDiffusionInpaintGenerator().generate_image, + inputs=[ + stable_diffusion_inpaint_image_file, + stable_diffusion_inpaint_model_id, + stable_diffusion_inpaint_prompt, + stable_diffusion_inpaint_negative_prompt, + stable_diffusion_inpiant_num_images_per_prompt, + stable_diffusion_inpaint_scheduler, + stable_diffusion_inpaint_guidance_scale, + 
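# the single "Image Size" slider is wired in twice below, once as height and once as width, so generated images are square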
stable_diffusion_inpaint_num_inference_step, + stable_diffusion_inpaint_size, + stable_diffusion_inpaint_size, + stable_diffusion_inpaint_seed_generator, + ], + outputs=[output_image], + ) + + stable_diffusion_inpaint_generate.click( + fn=generate, + inputs=[stable_diffusion_inpaint_prompt], + outputs=[stable_diffusion_inpaint_generated_prompt], + ) + + + + + return demo diff --git a/blocks/main.py b/blocks/main.py new file mode 100644 index 0000000000000000000000000000000000000000..ff9ac3bdfce4efb9352dcf2585aca75748fb7efc --- /dev/null +++ b/blocks/main.py @@ -0,0 +1,31 @@ +import gradio as gr +from .download import CSS +from .inpainting import StableDiffusionInpaintGenerator +from .text2img import StableDiffusionText2ImageGenerator +from .img2img import StableDiffusionImage2ImageGenerator + +def main_box(username : str = "admin"): + """ + Implement the main interface for the app which will be served + to the frontend. + """ + # customize the share_js button by letting username + app = gr.Blocks(css = CSS) + with app: + with gr.Row(): + with gr.Column(): + with gr.Tab("Text-to-Image", id = 'text-to-image', elem_id='text-to-image-tab'): + StableDiffusionText2ImageGenerator.app() + with gr.Tab("Image-to-Image", id = 'image-to-image', elem_id='image-to-image-tab'): + StableDiffusionImage2ImageGenerator.app() + with gr.Tab("Inpainting", id = 'inpainting', elem_id = 'inpainting-tab'): + StableDiffusionInpaintGenerator.app() + + # Add a footer that will be displayed at the bottom of the app + + gr.HTML(""" +
Minerva: Only your imagination is the limit!
+ """) + + app.queue(concurrency_count=2) + return app \ No newline at end of file diff --git a/blocks/text2img.py b/blocks/text2img.py new file mode 100644 index 0000000000000000000000000000000000000000..b4ac5edbab1d7d934102c492b38ebddd67f898bb --- /dev/null +++ b/blocks/text2img.py @@ -0,0 +1,235 @@ +import gradio as gr +import torch +from diffusers import StableDiffusionPipeline +from .utils.schedulers import SCHEDULER_LIST, get_scheduler_list +from .utils.prompt2prompt import generate +from .utils.device import get_device +from .download import get_share_js, community_icon_html, loading_icon_html, CSS + +#--- create a download button that takes the output image from gradio and downloads it + +TEXT2IMG_MODEL_LIST = { + "OpenJourney v4" : "prompthero/openjourney-v4", + "StableDiffusion 1.5" : "runwayml/stable-diffusion-v1-5", + "StableDiffusion 2.1" : "stabilityai/stable-diffusion-2-1", + "DreamLike 1.0" : "dreamlike-art/dreamlike-diffusion-1.0", + "DreamLike 2.0" : "dreamlike-art/dreamlike-photoreal-2.0", + "DreamShaper" : "Lykon/DreamShaper", + "NeverEnding-Dream" : "Lykon/NeverEnding-Dream" +} + +class StableDiffusionText2ImageGenerator: + def __init__(self): + self.pipe = None + + def load_model( + self, + model_path, + scheduler + ): + model_path = TEXT2IMG_MODEL_LIST[model_path] + if self.pipe is None: + self.pipe = StableDiffusionPipeline.from_pretrained( + model_path, safety_checker=None, torch_dtype=torch.float32 + ) + + device = get_device() + self.pipe = get_scheduler_list(pipe=self.pipe, scheduler=scheduler) + self.pipe.to(device) + #self.pipe.enable_attention_slicing() + + return self.pipe + + def generate_image( + self, + model_path: str, + prompt: str, + negative_prompt: str, + num_images_per_prompt: int, + scheduler: str, + guidance_scale: int, + num_inference_step: int, + height: int, + width: int, + seed_generator=0, + ): + print("model_path", model_path) + print("prompt", prompt) + print("negative_prompt", negative_prompt) + print("num_images_per_prompt", num_images_per_prompt) + print("scheduler", scheduler) + print("guidance_scale", guidance_scale) + print("num_inference_step", num_inference_step) + print("height", height) + print("width", width) + print("seed_generator", seed_generator) + + pipe = self.load_model( + model_path=model_path, + scheduler=scheduler, + ) + if seed_generator == 0: + random_seed = torch.randint(0, 1000000, (1,)) + generator = torch.manual_seed(random_seed) + else: + generator = torch.manual_seed(seed_generator) + + images = pipe( + prompt=prompt, + height=height, + width=width, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + num_inference_steps=num_inference_step, + guidance_scale=guidance_scale, + generator=generator, + ).images + + return images + + + def app(username : str = "admin"): + demo = gr.Blocks(css = CSS) + with demo: + with gr.Row(): + with gr.Column(): + text2image_prompt = gr.Textbox( + lines=1, + placeholder="Prompt", + show_label=False, + elem_id="prompt-text-input", + value='' + ) + + text2image_negative_prompt = gr.Textbox( + lines=1, + placeholder="Negative Prompt", + show_label=False, + elem_id = "negative-prompt-text-input", + value='' + ) + + # add button for generating a prompt from the prompt + text2image_prompt_generate_button = gr.Button( + label="Generate Prompt", + type="primary", + align="center", + value = "Generate Prompt" + ) + + # show a text box with the generated prompt + text2image_prompt_generated_prompt = gr.Textbox( + lines=1, + placeholder="Generated Prompt", + 
show_label=False, + ) + with gr.Row(): + with gr.Column(): + text2image_model_path = gr.Dropdown( + choices=list(TEXT2IMG_MODEL_LIST.keys()), + value=list(TEXT2IMG_MODEL_LIST.keys())[0], + label="Text2Image Model Selection", + elem_id="model-dropdown", + ) + + text2image_guidance_scale = gr.Slider( + minimum=0.1, + maximum=15, + step=0.1, + value=7.5, + label="Guidance Scale", + elem_id = "guidance-scale-slider" + ) + + text2image_num_inference_step = gr.Slider( + minimum=1, + maximum=100, + step=1, + value=50, + label="Num Inference Step", + elem_id = "num-inference-step-slider" + ) + text2image_num_images_per_prompt = gr.Slider( + minimum=1, + maximum=30, + step=1, + value=1, + label="Number Of Images", + ) + with gr.Row(): + with gr.Column(): + + text2image_scheduler = gr.Dropdown( + choices=SCHEDULER_LIST, + value=SCHEDULER_LIST[0], + label="Scheduler", + elem_id="scheduler-dropdown", + ) + + text2image_size = gr.Slider( + minimum=128, + maximum=1280, + step=32, + value=512, + label="Image Size", + elem_id="image-size-slider", + ) + + text2image_seed_generator = gr.Slider( + label="Seed(0 for random)", + minimum=0, + maximum=1000000, + value=0, + elem_id="seed-slider", + ) + text2image_predict = gr.Button(value="Generator") + + with gr.Column(): + output_image = gr.Gallery( + label="Generated images", + show_label=False, + elem_id="gallery", + ).style(grid=(1, 2), height='auto') + + with gr.Group(elem_id="container-advanced-btns"): + with gr.Group(elem_id="share-btn-container"): + community_icon = gr.HTML(community_icon_html) + loading_icon = gr.HTML(loading_icon_html) + share_button = gr.Button("Save artwork", elem_id="share-btn") + + + text2image_predict.click( + fn=StableDiffusionText2ImageGenerator().generate_image, + inputs=[ + text2image_model_path, + text2image_prompt, + text2image_negative_prompt, + text2image_num_images_per_prompt, + text2image_scheduler, + text2image_guidance_scale, + text2image_num_inference_step, + text2image_size, + text2image_size, + text2image_seed_generator, + ], + outputs=output_image, + ) + + text2image_prompt_generate_button.click( + fn=generate, + inputs=[text2image_prompt], + outputs=[text2image_prompt_generated_prompt], + ) + + # share_button.click( + # None, + # [], + # [], + # _js=get_share_js(), + # ) + + # autoclik the share button + + + + return demo \ No newline at end of file diff --git a/blocks/utils/__init__.py b/blocks/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/blocks/utils/__pycache__/__init__.cpython-39.pyc b/blocks/utils/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d208b7409dd98f7964029118abe7df906b05d000 Binary files /dev/null and b/blocks/utils/__pycache__/__init__.cpython-39.pyc differ diff --git a/blocks/utils/__pycache__/device.cpython-39.pyc b/blocks/utils/__pycache__/device.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..841f24cccd338e5da4a5328c27ef1eb2d05bf657 Binary files /dev/null and b/blocks/utils/__pycache__/device.cpython-39.pyc differ diff --git a/blocks/utils/__pycache__/prompt2prompt.cpython-39.pyc b/blocks/utils/__pycache__/prompt2prompt.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a9f1a07c1b8386a5db5c76050e3ee6db5afae1f5 Binary files /dev/null and b/blocks/utils/__pycache__/prompt2prompt.cpython-39.pyc differ diff --git a/blocks/utils/__pycache__/schedulers.cpython-39.pyc 
b/blocks/utils/__pycache__/schedulers.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b3299fa7fb7e0082568e28018ad7afeef141ff27 Binary files /dev/null and b/blocks/utils/__pycache__/schedulers.cpython-39.pyc differ diff --git a/blocks/utils/device.py b/blocks/utils/device.py new file mode 100644 index 0000000000000000000000000000000000000000..f5e27d43454f1ad2557c5f3971f93893ceff4ae3 --- /dev/null +++ b/blocks/utils/device.py @@ -0,0 +1,16 @@ +import torch + + +def get_device(device = None): + if device is None: + # get cuda -> mps -> cpu + if torch.cuda.is_available(): + device = "cuda" + elif torch.backends.mps.is_available(): + if torch.backends.mps.is_built(): + device = "mps" + else: + device = "cpu" + else: + device = "cpu" + return device \ No newline at end of file diff --git a/blocks/utils/prompt2prompt.py b/blocks/utils/prompt2prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..1d50b450c9e048b808156fe22b48f850f7f06331 --- /dev/null +++ b/blocks/utils/prompt2prompt.py @@ -0,0 +1,23 @@ +from transformers import AutoTokenizer, AutoModelForCausalLM +from transformers import pipeline, set_seed +import re +import random +gpt2_pipe = pipeline('text-generation', model='Gustavosta/MagicPrompt-Stable-Diffusion', tokenizer='gpt2') + +def generate(starting_text): + seed = random.randint(100, 1000000) + set_seed(seed) + + response = gpt2_pipe(starting_text, max_length=(len(starting_text) + random.randint(60, 90)), num_return_sequences=4) + response_list = [] + for x in response: + resp = x['generated_text'].strip() + if resp != starting_text and len(resp) > (len(starting_text) + 4) and resp.endswith((":", "-", "—")) is False: + response_list.append(resp+'\n') + + response_end = "\n".join(response_list) + response_end = re.sub('[^ ]+\.[^ ]+','', response_end) + response_end = response_end.replace("<", "").replace(">", "") + + if response_end != "": + return response_end diff --git a/blocks/utils/schedulers.py b/blocks/utils/schedulers.py new file mode 100644 index 0000000000000000000000000000000000000000..926bd3d43cb072f0229578e33db2858c7762ee69 --- /dev/null +++ b/blocks/utils/schedulers.py @@ -0,0 +1,47 @@ +from diffusers import ( + DDIMScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + LMSDiscreteScheduler, + DPMSolverMultistepScheduler +) + +SCHEDULER_LIST = [ + "DDIM", + "EulerA", + "Euler", + "LMS", + "Heun", + "DPMMultistep", +] + + +def get_scheduler_list(pipe, scheduler): + if scheduler == SCHEDULER_LIST[0]: + pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + + elif scheduler == SCHEDULER_LIST[1]: + pipe.scheduler = EulerAncestralDiscreteScheduler.from_config( + pipe.scheduler.config + ) + + elif scheduler == SCHEDULER_LIST[2]: + pipe.scheduler = EulerDiscreteScheduler.from_config( + pipe.scheduler.config + ) + + elif scheduler == SCHEDULER_LIST[3]: + pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config) + + elif scheduler == SCHEDULER_LIST[4]: + pipe.scheduler = HeunDiscreteScheduler.from_config( + pipe.scheduler.config + ) + + elif scheduler == SCHEDULER_LIST[5]: + pipe.scheduler = DPMSolverMultistepScheduler.from_config( + pipe.scheduler.config + ) + + return pipe \ No newline at end of file diff --git a/diffmodels/__init__.py b/diffmodels/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7a3f735f33280d6b9b40e8e9129f3621794e489b --- /dev/null +++ b/diffmodels/__init__.py @@ -0,0 +1,25 @@ +from 
.diffusion_utils import build_pipeline + +NAME_TO_MODEL = { + "stable-diffusion-v1-4": + { + "model" : "CompVis/stable-diffusion-v1-4", + "unet" : "CompVis/stable-diffusion-v1-4", + "tokenizer" : "openai/clip-vit-large-patch14", + "text_encoder" : "openai/clip-vit-large-patch14", + }, + "stable_diffusion_v2_1": + { + "model" : "stabilityai/stable-diffusion-2-1", + "unet" : "stabilityai/stable-diffusion-2-1", + "tokenizer" : "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", + "text_encoder" : "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", + } +} + +def get_model(model_name): + model = NAME_TO_MODEL.get(model_name) + if model is None: + raise ValueError(f"Model name {model_name} not found. Available models: {list(NAME_TO_MODEL.keys())}") + vae, tokenizer, text_encoder, unet = build_pipeline(model["model"], model["tokenizer"], model["text_encoder"], model["unet"]) + return vae, tokenizer, text_encoder, unet \ No newline at end of file diff --git a/diffmodels/__pycache__/__init__.cpython-310.pyc b/diffmodels/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6cb1f31977ffdd98571638bfdc1585024865f6fd Binary files /dev/null and b/diffmodels/__pycache__/__init__.cpython-310.pyc differ diff --git a/diffmodels/__pycache__/__init__.cpython-39.pyc b/diffmodels/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0daccf864cf8ee03645aff7e6ded950f21c4fe7 Binary files /dev/null and b/diffmodels/__pycache__/__init__.cpython-39.pyc differ diff --git a/diffmodels/__pycache__/diffusion_utils.cpython-310.pyc b/diffmodels/__pycache__/diffusion_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e5a1f0ed42f10c696b0c069aae3af09bc762e4e Binary files /dev/null and b/diffmodels/__pycache__/diffusion_utils.cpython-310.pyc differ diff --git a/diffmodels/__pycache__/diffusion_utils.cpython-39.pyc b/diffmodels/__pycache__/diffusion_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d7d2ccbc53bc4fe3657079b4e10084870e17e295 Binary files /dev/null and b/diffmodels/__pycache__/diffusion_utils.cpython-39.pyc differ diff --git a/diffmodels/__pycache__/simple_diffusion.cpython-310.pyc b/diffmodels/__pycache__/simple_diffusion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cce2ffd19a7fa667ff19b03b3d0e53fdbcf70400 Binary files /dev/null and b/diffmodels/__pycache__/simple_diffusion.cpython-310.pyc differ diff --git a/diffmodels/diffusion_utils.py b/diffmodels/diffusion_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..512a477f5a7b3caa61191da014618a2670f83ef7 --- /dev/null +++ b/diffmodels/diffusion_utils.py @@ -0,0 +1,218 @@ +# Utility class for loading and using diffusers model +import diffusers +import transformers + +import torch +from typing import Union +import os +import warnings +import numpy as np +from PIL import Image +import tqdm +from copy import deepcopy +import matplotlib.pyplot as plt + +def build_generator( + device : torch.device, + seed : int, +): + """ + Build a torch.Generator with a given seed. + """ + generator = torch.Generator(device).manual_seed(seed) + return generator + +def load_stablediffusion_model( + model_id : Union[str, os.PathLike], + device : torch.device, + ): + """ + Load a complete diffusion model from a model id. + Returns a tuple of the model and a torch.Generator if seed is not None. 
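    Example (a minimal sketch; it assumes a CUDA device, a checkpoint that ships
    an "fp16" revision such as "runwayml/stable-diffusion-v1-5", and a logged-in
    Hugging Face account, since the pipeline is loaded with use_auth_token=True):

        import torch
        pipe = load_stablediffusion_model(
            model_id="runwayml/stable-diffusion-v1-5",
            device=torch.device("cuda"),
        )
        image = pipe("a lighthouse at dawn, oil painting").images[0]
        image.save("lighthouse.png")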
+ + """ + pipe = diffusers.DiffusionPipeline.from_pretrained( + model_id, + revision="fp16", + torch_dtype=torch.float16, + use_auth_token=True, + ) + pipe.scheduler = diffusers.DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) + try: + pipe = pipe.to(device) + except: + warnings.warn( + f'Could not load model to device:{device}. Using CPU instead.' + ) + pipe = pipe.to('cpu') + device = 'cpu' + + return pipe + + +def visualize_image_grid( + imgs : np.array, + rows : int, + cols : int): + + assert len(imgs) == rows*cols + + # create grid + w, h = imgs[0].size # assuming each image is the same size + + grid = Image.new('RGB', size=(cols*w, rows*h)) + + for i,img in enumerate(imgs): + grid.paste(img, box=(i%cols*w, i//cols*h)) + return grid + + +def build_pipeline( + autoencoder : Union[str, os.PathLike] = "CompVis/stable-diffusion-v1-4", + tokenizer : Union[str, os.PathLike] = "openai/clip-vit-large-patch14", + text_encoder : Union[str, os.PathLike] = "openai/clip-vit-large-patch14", + unet : Union[str, os.PathLike] = "CompVis/stable-diffusion-v1-4", + device : torch.device = torch.device('cuda'), + ): + """ + Create a pipeline for StableDiffusion by loading the model and component seperetely. + Arguments: + autoencoder: path to model that autoencoder will be loaded from + tokenizer: path to tokenizer + text_encoder: path to text_encoder + unet: path to unet + """ + # Load the VAE for encoding images into the latent space + vae = diffusers.AutoencoderKL.from_pretrained(autoencoder, subfolder = 'vae') + + # Load tokenizer & text encoder for encoding text into the latent space + tokenizer = transformers.CLIPTokenizer.from_pretrained(tokenizer) + text_encoder = transformers.CLIPTextModel.from_pretrained(text_encoder) + + # Use the UNet model for conditioning the diffusion process + unet = diffusers.UNet2DConditionModel.from_pretrained(unet, subfolder = 'unet') + + # Move all the components to device + vae = vae.to(device) + text_encoder = text_encoder.to(device) + unet = unet.to(device) + + return vae, tokenizer, text_encoder, unet + +#TODO : Add negative prompting +def custom_stablediffusion_inference( + vae, + tokenizer, + text_encoder, + unet, + noise_scheduler, + prompt : list, + device : torch.device, + num_inference_steps = 100, + image_size = (512,512), + guidance_scale = 8, + seed = 42, + return_image_step = 5, + ): + # Get the text embeddings that will condition the diffusion process + if isinstance(prompt,str): + prompt = [prompt] + + batch_size = len(prompt) + text_input = tokenizer( + prompt, + padding = 'max_length', + truncation = True, + max_length = tokenizer.model_max_length, + return_tensors = 'pt').to(device) + + text_embeddings = text_encoder( + text_input.input_ids.to(device) + )[0] + + # Get the text embeddings for classifier-free guidance + max_length = text_input.input_ids.shape[-1] + empty = [""] * batch_size + uncond_input = tokenizer( + empty, + padding = 'max_length', + truncation = True, + max_length = max_length, + return_tensors = 'pt').to(device) + + uncond_embeddings = text_encoder( + uncond_input.input_ids.to(device) + )[0] + + # Concatenate the text embeddings to get the conditioning vector + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + # Generate initial noise + latents = torch.randn( + (1, unet.in_channels, image_size[0] // 8, image_size[1] // 8), + generator=torch.manual_seed(seed) if seed is not None else None + ) + print(latents.shape) + + latents = latents.to(device) + + # Initialize scheduler for noise generation + 
noise_scheduler.set_timesteps(num_inference_steps) + + latents = latents * noise_scheduler.init_noise_sigma + + noise_scheduler.set_timesteps(num_inference_steps) + for i,t in tqdm.tqdm(enumerate(noise_scheduler.timesteps)): + # If no text embedding is provided (classifier-free guidance), extend the conditioning vector + latent_model_input = torch.cat([latents] * 2) + + latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t) + + with torch.no_grad(): + # Get the noise prediction from the UNet + noise_pred = unet(latent_model_input, t, encoder_hidden_states = text_embeddings).sample + + # Perform guidance from the text embeddings + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # Compute the previously noisy sample x_t -> x_t-1 + latents = noise_scheduler.step(noise_pred, t, latents).prev_sample + + # Now that latent is generated from a noise, use unet decoder to generate images + if i % return_image_step == 0: + with torch.no_grad(): + latents_copy = deepcopy(latents) + image = vae.decode(1/0.18215 * latents_copy).sample + + image = (image / 2 + 0.5).clamp(0,1) + image = image.detach().cpu().permute(0,2,3,1).numpy() # bxhxwxc + images = (image * 255).round().astype("uint8") + + pil_images = [Image.fromarray(img) for img in images] + + yield pil_images[0] + + yield pil_images[0] + +if __name__ == "__main__": + device = torch.device("cpu") + model_id = "stabilityai/stable-diffusion-2-1" + tokenizer_id = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K" + #noise_scheduler = diffusers.LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000) + noise_scheduler = diffusers.DPMSolverMultistepScheduler.from_pretrained(model_id,subfolder="scheduler") + prompt = "A Hyperrealistic photograph of Italian architectural modern home in Italy, lens flares,\ + cinematic, hdri, matte painting, concept art, celestial, soft render, highly detailed, octane\ + render, architectural HD, HQ, 4k, 8k" + + vae, tokenizer, text_encoder, unet = build_pipeline( + autoencoder = model_id, + tokenizer=tokenizer_id, + text_encoder=tokenizer_id, + unet=model_id, + device=device, + ) + image_iter = custom_stablediffusion_inference(vae, tokenizer, text_encoder, unet, noise_scheduler, prompt = prompt, device=device, seed = None) + for i, image in enumerate(image_iter): + image.save(f"step_{i}.png") + diff --git a/diffmodels/simple_diffusion.py b/diffmodels/simple_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..088998d09310b4a0df4332c5d997d8c01dabecdb --- /dev/null +++ b/diffmodels/simple_diffusion.py @@ -0,0 +1,309 @@ +import diffusers +import transformers +import utils.log +import torch +import PIL +from typing import Union, Dict, Any, Optional, List, Tuple, Callable +import os +import re + +class SimpleDiffusion(diffusers.DiffusionPipeline): + """ + An unified interface for diffusion models. This allow us to use : + - txt2img + - img2img + - inpainting + - unconditional image generation + + This class is highly inspired from the Stable-Diffusion-Mega pipeline. + DiffusionPipeline class allow us to load/download all the models hubbed by HuggingFace with an ease. Read more information + about the DiffusionPipeline class here: https://huggingface.co./transformers/main_classes/pipelines.html#transformers.DiffusionPipeline + + Args: + logger (:obj:`utils.log.Logger`): + The logger to use for logging any information. 
+ vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co./docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co./openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co./docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionMegaSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co./runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + + """ + def __init__( + self, + vae: diffusers.AutoencoderKL, + text_encoder: transformers.CLIPTextModel, + tokenizer: transformers.CLIPTokenizer, + unet: diffusers.UNet2DConditionModel, + scheduler: Union[diffusers.DDIMScheduler, diffusers.PNDMScheduler, diffusers.LMSDiscreteScheduler], + safety_checker: diffusers.pipelines.stable_diffusion.safety_checker.StableDiffusionSafetyChecker, + feature_extractor: transformers.CLIPFeatureExtractor, + prompt_generation = "succinctly/text2image-prompt-generator" + ): + super().__init__() + self._logger = None + self.register_modules( # already defined in ConfigMixin class, from_pretrained loads these modules + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + + ) + self._generated_prompts = [] + self._enable_prompt_generation = False + if prompt_generation: + self._enable_prompt_generation = True + self._prompt_generator = transformers.pipeline('text-generation', model='Gustavosta/MagicPrompt-Stable-Diffusion', tokenizer='gpt2') + + def _generate_prompt(self, prompt, **kwargs): + """ + Generate a prompt from a given text. + Args: + prompt (str): The text to generate a prompt from. + **kwargs: Additional keyword arguments passed to the prompt generator pipeline. 
+ """ + max_length = kwargs.pop("max_length", None) + num_return_sequences = kwargs.pop("num_return_sequences", None) + + prompt = self._prompt_generator(prompt, max_length=max_length, num_return_sequences=num_return_sequences) + prompt = self._process_prompt(prompt, **kwargs) + return prompt[0]['generated_text'] + + def _process_prompt(self,original_prompt, prompt_list): + # TODO : Add documentation; add more prompt processing + response_list = [] + for x in prompt_list: + resp = x['generated_text'].strip() + if resp != original_prompt and len(resp) > (len(original_prompt) + 4) and resp.endswith((":", "-", "—")) is False: + response_list.append(resp+'\n') + + response_end = "\n".join(response_list) + response_end = re.sub('[^ ]+\.[^ ]+','', response_end) + response_end = response_end.replace("<", "").replace(">", "") + + if response_end != "": + return response_end + + # Following components are required for the DiffusionPipeline class - but they exist in the StableDiffusionModel class + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + Refer to the [StableDiffusionModel](https://github.com/huggingface/diffusers/blob/main/examples/community/stable_diffusion_mega.py) repo + for more information. + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + if self._logger is not None: + self._logger.info("Attention slicing enabled!") + slice_size = self.unet.config.attention_head_dim // 2 + self.unet.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + if self._logger is not None: + self._logger.info("Attention slicing disabled!") + self.enable_attention_slicing(None) + + def set_logger(self, logger): + r""" + Set logger. This is useful to log information about the model. 
+ """ + self._logger = logger + + @property + def components(self) -> Dict[str, Any]: + # Return the non-private variables + return {k : getattr(self, k) for k in self.config.keys() if not k.startswith("_")} + + @torch.no_grad() + def inpaint( + self, + prompt: Union[str, List[str]], + init_image: Union[torch.FloatTensor, PIL.Image.Image], + mask_image: Union[torch.FloatTensor, PIL.Image.Image], + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[torch.Generator] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, + ): + if self._enable_prompt_generation: + prompt = self._generate_prompt(p, **kwargs)[0] + self._logger.info(f"Generated prompt: {prompt}") + # For more information on how this function works, please see: https://huggingface.co./docs/diffusers/api/pipelines/stable_diffusion#diffusers.StableDiffusionImg2ImgPipeline + return diffusers.StableDiffusionInpaintPipelineLegacy(**self.components)( + prompt=prompt, + init_image=init_image, + mask_image=mask_image, + strength=strength, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + output_type=output_type, + return_dict=return_dict, + callback=callback, + ) + + @torch.no_grad() + def img2img( + self, + prompt: Union[str, List[str]], + init_image: Union[torch.FloatTensor, PIL.Image.Image], + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[torch.Generator] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, + ): + if self._enable_prompt_generation: + prompt = self._generate_prompt(p, **kwargs)[0] + self._logger.info(f"Generated prompt: {prompt}") + # For more information on how this function works, please see: https://huggingface.co./docs/diffusers/api/pipelines/stable_diffusion#diffusers.StableDiffusionImg2ImgPipeline + return diffusers.StableDiffusionImg2ImgPipeline(**self.components)( + prompt=prompt, + init_image=init_image, + strength=strength, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + output_type=output_type, + return_dict=return_dict, + callback=callback, + callback_steps=callback_steps, + ) + + @torch.no_grad() + def text2img( + self, + prompt: Union[str, List[str]], + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + 
callback_steps: Optional[int] = 1, + ): + if self._enable_prompt_generation: + prompt = self._generate_prompt(p, **kwargs)[0] + self._logger.info(f"Generated prompt: {prompt}") + + # For more information on how this function https://huggingface.co./docs/diffusers/api/pipelines/stable_diffusion#diffusers.StableDiffusionPipeline + return diffusers.StableDiffusionPipeline(**self.components)( + prompt=prompt, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + latents=latents, + output_type=output_type, + return_dict=return_dict, + callback=callback, + callback_steps=callback_steps, + ) + + @torch.no_grad() + def upscale( + self, + prompt: Union[str, List[str]], + init_image: Union[torch.FloatTensor, PIL.Image.Image], + num_inference_steps: Optional[int] = 75, + guidance_scale: Optional[float] = 9.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + ): + """ + Upscale an image using the StableDiffusionUpscalePipeline. + """ + if self._enable_prompt_generation: + prompt = self._generate_prompt(p, **kwargs)[0] + self._logger.info(f"Generated prompt: {prompt}") + + return diffusers.StableDiffusionUpscalePipeline(**self.components)( + prompt=prompt, + image=init_image, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + latents=latents, + output_type = output_type, + return_dict=return_dict, + callback=callback, + callback_steps=callback_steps) + + def set_scheduler(self, scheduler: Union[diffusers.DDIMScheduler, diffusers.PNDMScheduler, diffusers.LMSDiscreteScheduler, diffusers.EulerDiscreteScheduler]): + """ + Set the scheduler for the pipeline. This is useful for controlling the diffusion process. + Args: + scheduler (Union[diffusers.DDIMScheduler, diffusers.PNDMScheduler, diffusers.LMSDiscreteScheduler]): The scheduler to use. 
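+
+        Example (a sketch; ``model`` is assumed to be an existing wrapper instance,
+        and the current scheduler config is read from ``model.components``):
+
+            >>> from diffusers import EulerDiscreteScheduler
+            >>> old_config = model.components["scheduler"].config
+            >>> model.set_scheduler(EulerDiscreteScheduler.from_config(old_config))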
+ + """ + self.components["scheduler"] = scheduler \ No newline at end of file diff --git a/diffmodels/textual_inversion.py b/diffmodels/textual_inversion.py new file mode 100644 index 0000000000000000000000000000000000000000..8f641eabc1567d3bd4f20253c6f9517e2d9970ae --- /dev/null +++ b/diffmodels/textual_inversion.py @@ -0,0 +1,269 @@ +#@title Import required libraries +import argparse +import itertools +import math +import os +import random + +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch.utils.data import Dataset + +import PIL +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import set_seed +from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel +from diffusers.optimization import get_scheduler +from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker +from PIL import Image +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + +pretrained_model_name_or_path = "stabilityai/stable-diffusion-2" #@param ["stabilityai/stable-diffusion-2", "stabilityai/stable-diffusion-2-base", "CompVis/stable-diffusion-v1-4", "runwayml/stable-diffusion-v1-5"] {allow-input: true} + +# example image urls +urls = [ + "https://huggingface.co./datasets/valhalla/images/resolve/main/2.jpeg", + "https://huggingface.co./datasets/valhalla/images/resolve/main/3.jpeg", + "https://huggingface.co./datasets/valhalla/images/resolve/main/5.jpeg", + "https://huggingface.co./datasets/valhalla/images/resolve/main/6.jpeg", + ] + +# what is it that you are teaching? `object` enables you to teach the model a new object to be used, `style` allows you to teach the model a new style one can use. +what_to_teach = "object" #@param ["object", "style"] +# the token you are going to use to represent your new concept (so when you prompt the model, you will say "A `` in an amusement park"). We use angle brackets to differentiate a token from other words/tokens, to avoid collision. 
+placeholder_token = "" #@param {type:"string"} +# is a word that can summarise what your new concept is, to be used as a starting point +initializer_token = "toy" #@param {type:"string"} + +def image_grid(imgs, rows, cols): + assert len(imgs) == rows*cols + + w, h = imgs[0].size + grid = Image.new('RGB', size=(cols*w, rows*h)) + grid_w, grid_h = grid.size + + for i, img in enumerate(imgs): + grid.paste(img, box=(i%cols*w, i//cols*h)) + return grid + +#@title Setup the prompt templates for training +imagenet_templates_small = [ + "a photo of a {}", + "a rendering of a {}", + "a cropped photo of the {}", + "the photo of a {}", + "a photo of a clean {}", + "a photo of a dirty {}", + "a dark photo of the {}", + "a photo of my {}", + "a photo of the cool {}", + "a close-up photo of a {}", + "a bright photo of the {}", + "a cropped photo of a {}", + "a photo of the {}", + "a good photo of the {}", + "a photo of one {}", + "a close-up photo of the {}", + "a rendition of the {}", + "a photo of the clean {}", + "a rendition of a {}", + "a photo of a nice {}", + "a good photo of a {}", + "a photo of the nice {}", + "a photo of the small {}", + "a photo of the weird {}", + "a photo of the large {}", + "a photo of a cool {}", + "a photo of a small {}", +] + +imagenet_style_templates_small = [ + "a painting in the style of {}", + "a rendering in the style of {}", + "a cropped painting in the style of {}", + "the painting in the style of {}", + "a clean painting in the style of {}", + "a dirty painting in the style of {}", + "a dark painting in the style of {}", + "a picture in the style of {}", + "a cool painting in the style of {}", + "a close-up painting in the style of {}", + "a bright painting in the style of {}", + "a cropped painting in the style of {}", + "a good painting in the style of {}", + "a close-up painting in the style of {}", + "a rendition in the style of {}", + "a nice painting in the style of {}", + "a small painting in the style of {}", + "a weird painting in the style of {}", + "a large painting in the style of {}", +] + +#@title Setup the dataset +class TextualInversionDataset(Dataset): + def __init__( + self, + data_root, + tokenizer, + learnable_property="object", # [object, style] + size=512, + repeats=100, + interpolation="bicubic", + flip_p=0.5, + set="train", + placeholder_token="*", + center_crop=False, + ): + + self.data_root = data_root + self.tokenizer = tokenizer + self.learnable_property = learnable_property + self.size = size + self.placeholder_token = placeholder_token + self.center_crop = center_crop + self.flip_p = flip_p + + self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)] + + self.num_images = len(self.image_paths) + self._length = self.num_images + + if set == "train": + self._length = self.num_images * repeats + + self.interpolation = { + "linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + }[interpolation] + + self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small + self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p) + + def __len__(self): + return self._length + + def __getitem__(self, i): + example = {} + image = Image.open(self.image_paths[i % self.num_images]) + + if not image.mode == "RGB": + image = image.convert("RGB") + + placeholder_string = self.placeholder_token + text = random.choice(self.templates).format(placeholder_string) + + 
example["input_ids"] = self.tokenizer( + text, + padding="max_length", + truncation=True, + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ).input_ids[0] + + # default to score-sde preprocessing + img = np.array(image).astype(np.uint8) + + if self.center_crop: + crop = min(img.shape[0], img.shape[1]) + h, w, = ( + img.shape[0], + img.shape[1], + ) + img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2] + + image = Image.fromarray(img) + image = image.resize((self.size, self.size), resample=self.interpolation) + + image = self.flip_transform(image) + image = np.array(image).astype(np.uint8) + image = (image / 127.5 - 1.0).astype(np.float32) + + example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1) + return example + + +#@title Load the tokenizer and add the placeholder token as a additional special token. +tokenizer = CLIPTokenizer.from_pretrained( + pretrained_model_name_or_path, + subfolder="tokenizer", +) + +# Add the placeholder token in tokenizer +num_added_tokens = tokenizer.add_tokens(placeholder_token) +if num_added_tokens == 0: + raise ValueError( + f"The tokenizer already contains the token {placeholder_token}. Please pass a different" + " `placeholder_token` that is not already in the tokenizer." + ) + + + +#@title Get token ids for our placeholder and initializer token. This code block will complain if initializer string is not a single token +# Convert the initializer_token, placeholder_token to ids +token_ids = tokenizer.encode(initializer_token, add_special_tokens=False) +# Check if initializer_token is a single token or a sequence of tokens +if len(token_ids) > 1: + raise ValueError("The initializer token must be a single token.") + +initializer_token_id = token_ids[0] +placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token) + + +#@title Load the Stable Diffusion model +# Load models and create wrapper for stable diffusion +# pipeline = StableDiffusionPipeline.from_pretrained(pretrained_model_name_or_path) +# del pipeline +text_encoder = CLIPTextModel.from_pretrained( + pretrained_model_name_or_path, subfolder="text_encoder" +) +vae = AutoencoderKL.from_pretrained( + pretrained_model_name_or_path, subfolder="vae" +) +unet = UNet2DConditionModel.from_pretrained( + pretrained_model_name_or_path, subfolder="unet" +) + +text_encoder.resize_token_embeddings(len(tokenizer)) + +token_embeds = text_encoder.get_input_embeddings().weight.data +token_embeds[placeholder_token_id] = token_embeds[initializer_token_id] + +def freeze_params(params): + for param in params: + param.requires_grad = False + +# Freeze vae and unet +freeze_params(vae.parameters()) +freeze_params(unet.parameters()) +# Freeze all parameters except for the token embeddings in text encoder +params_to_freeze = itertools.chain( + text_encoder.text_model.encoder.parameters(), + text_encoder.text_model.final_layer_norm.parameters(), + text_encoder.text_model.embeddings.position_embedding.parameters(), +) +freeze_params(params_to_freeze) + +train_dataset = TextualInversionDataset( + data_root=save_path, + tokenizer=tokenizer, + size=vae.sample_size, + placeholder_token=placeholder_token, + repeats=100, + learnable_property=what_to_teach, #Option selected above between object and style + center_crop=False, + set="train", +) + +def create_dataloader(train_batch_size=1): + return torch.utils.data.DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True) + +noise_scheduler = DDPMScheduler.from_config(pretrained_model_name_or_path, 
subfolder="scheduler") + +# TODO: Add training scripts diff --git a/image_0.png b/image_0.png new file mode 100644 index 0000000000000000000000000000000000000000..6304f0a85be5206347837ce48e9cbdb1be84da89 Binary files /dev/null and b/image_0.png differ diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..d6eeb217a9a4fb385180322bb9c56b615f068546 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,80 @@ +accelerate==0.19.0 +aiofiles==23.1.0 +aiohttp==3.8.4 +aiosignal==1.3.1 +altair==5.0.0 +anyio==3.6.2 +async-timeout==4.0.2 +attrs==23.1.0 +certifi==2023.5.7 +charset-normalizer==3.1.0 +click==8.1.3 +contourpy==1.0.7 +cycler==0.11.0 +diffusers==0.16.1 +fastapi==0.95.2 +ffmpy==0.3.0 +filelock==3.12.0 +fonttools==4.39.4 +frozenlist==1.3.3 +fsspec==2023.5.0 +gradio==3.30.0 +gradio_client==0.2.4 +h11==0.14.0 +httpcore==0.17.0 +httpx==0.24.0 +huggingface-hub==0.14.1 +idna==3.4 +importlib-metadata==6.6.0 +importlib-resources==5.12.0 +Jinja2==3.1.2 +jsonschema==4.17.3 +kiwisolver==1.4.4 +linkify-it-py==2.0.2 +markdown-it-py==2.2.0 +MarkupSafe==2.1.2 +matplotlib==3.7.1 +mdit-py-plugins==0.3.3 +mdurl==0.1.2 +mpmath==1.3.0 +multidict==6.0.4 +networkx==3.1 +numpy==1.24.3 +opencv-python==4.7.0.72 +orjson==3.8.10 +packaging==23.1 +pandas==2.0.1 +Pillow==9.5.0 +pip==23.1.2 +psutil==5.9.5 +pydantic==1.10.7 +pydub==0.25.1 +Pygments==2.15.1 +pyparsing==3.0.9 +pyrsistent==0.19.3 +python-dateutil==2.8.2 +python-multipart==0.0.6 +pytz==2023.3 +PyYAML==6.0 +regex==2023.5.5 +requests==2.30.0 +semantic-version==2.10.0 +setuptools==67.7.2 +six==1.16.0 +sniffio==1.3.0 +starlette==0.27.0 +sympy==1.12 +tokenizers==0.13.3 +toolz==0.12.0 +torch==2.0.1 +tqdm==4.65.0 +transformers==4.29.1 +typing_extensions==4.5.0 +tzdata==2023.3 +uc-micro-py==1.0.2 +urllib3==2.0.2 +uvicorn==0.22.0 +websockets==11.0.3 +wheel==0.40.0 +yarl==1.9.2 +zipp==3.15.0 diff --git a/static/load_from_artwork.js b/static/load_from_artwork.js new file mode 100644 index 0000000000000000000000000000000000000000..e2d2d2ab76642acd749dfab58196e69f8a6bc715 --- /dev/null +++ b/static/load_from_artwork.js @@ -0,0 +1,46 @@ +async () => { + const urlParams = new URLSearchParams(window.location.search); + const username = urlParams.get('username'); + const artworkId = urlParams.get('artworkId'); + + const LOAD_URL = `http://127.0.0.1:5000/v1/api/load-parameters/${artworkId}`; + const response = await fetch(LOAD_URL, { + method: 'GET', + headers: { + 'X-Requested-With': 'XMLHttpRequest', + } + }); + + // Check if the response is okay + if (!response.ok) { + console.error("An error occurred while fetching the parameters."); + return; + } + + const parameters = await response.json(); // Assuming you're getting a JSON response + + // Get the necessary elements + const gradioEl = document.querySelector('gradio-app'); + const promptInput = gradioEl.querySelector('#prompt-text-input textarea'); + const negativePromptInput = gradioEl.querySelector('#negative-prompt-text-input textarea'); + + // Get the slider inputs + const guidanceScaleInput = gradioEl.querySelector('#guidance-scale-slider input'); + const numInferenceStepInput = gradioEl.querySelector('#num-inference-step-slider input'); + const imageSizeInput = gradioEl.querySelector('#image-size-slider input'); + const seedInput = gradioEl.querySelector('#seed-slider input'); + + // Get the dropdown inputs + const modelDropdown = gradioEl.querySelector('#model-dropdown input'); + const schedulerDropdown = gradioEl.querySelector('#scheduler-dropdown input'); + 
+ // Set the values based on the parameters received + promptInput.value = parameters.text_prompt; + negativePromptInput.value = parameters.negative_prompt; + guidanceScaleInput.value = parameters.model_guidance_scale; + numInferenceStepInput.value = parameters.model_num_steps; + imageSizeInput.value = parameters.model_image_size; + seedInput.value = parameters.seed; + modelDropdown.value = parameters.model_name; + schedulerDropdown.value = parameters.scheduler_name; +} diff --git a/static/save_artwork.js b/static/save_artwork.js new file mode 100644 index 0000000000000000000000000000000000000000..887645a7ea5d0e3de4e5f1753f101f01fc330476 --- /dev/null +++ b/static/save_artwork.js @@ -0,0 +1,63 @@ +async () => { + // Get the username from the URL itself + const gradioEl = document.querySelector('gradio-app'); + const imgEls = gradioEl.querySelectorAll('#gallery img'); + + // Get the necessary fields + const promptTxt = gradioEl.querySelector('#prompt-text-input textarea').value; + const negativePromptTxt = gradioEl.querySelector('#negative-prompt-text-input textarea').value; + + // Get values from the sliders + const modelGuidanceScale = parseFloat(gradioEl.querySelector('#guidance-scale-slider input').value); + + const numSteps = parseInt(gradioEl.querySelector('#num-inference-step-slider input').value); + const imageSize = parseInt(gradioEl.querySelector('#image-size-slider input').value); + const seed = parseInt(gradioEl.querySelector('#seed-slider input').value); + + // Get the values from dropdowns + const modelName = gradioEl.querySelector('#model-dropdown input').value; + const schedulerName = gradioEl.querySelector('#scheduler-dropdown input').value; + + const shareBtnEl = gradioEl.querySelector('#share-btn'); + const shareIconEl = gradioEl.querySelector('#share-btn-share-icon'); + const loadingIconEl = gradioEl.querySelector('#share-btn-loading-icon'); + + if(!imgEls.length){ + return; + }; + + shareBtnEl.style.pointerEvents = 'none'; + shareIconEl.style.display = 'none'; + loadingIconEl.style.removeProperty('display'); + const files = await Promise.all( + [...imgEls].map(async (imgEl) => { + const res = await fetch(imgEl.src); + const blob = await res.blob(); + const fileSrc = imgEl.src.split('/').pop(); // Get the file name from the img src path + const imgId = Date.now(); + const fileName = `${fileSrc}-${imgId}.jpg`; // Fixed fileName construction + return new File([blob], fileName, { type: 'image/jpeg' }); + }) + ); + + // Ensure that only one image is uploaded by taking the first element if there are multiple + if (files.length > 1) { + files.splice(1, files.length - 1); + } + + const urls = await Promise.all(files.map((f) => uploadFile( + f, + promptTxt, + negativePromptTxt, + modelName, + schedulerName, + modelGuidanceScale, + numSteps, + imageSize, + seed, + ))); + + shareBtnEl.style.removeProperty('pointer-events'); + shareIconEl.style.removeProperty('display'); + loadingIconEl.style.display = 'none'; + } \ No newline at end of file diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/utils/__pycache__/__init__.cpython-310.pyc b/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79f5ab8bb19adf22d7f6359debb472b763a69a00 Binary files /dev/null and b/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/utils/__pycache__/device.cpython-310.pyc b/utils/__pycache__/device.cpython-310.pyc 
new file mode 100644
index 0000000000000000000000000000000000000000..fd158d91d29d279588b4f43a2caec19697f7d13e
Binary files /dev/null and b/utils/__pycache__/device.cpython-310.pyc differ
diff --git a/utils/__pycache__/image.cpython-310.pyc b/utils/__pycache__/image.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..71595b047953b5dc6f36d80c68dc6aa88c301f8a
Binary files /dev/null and b/utils/__pycache__/image.cpython-310.pyc differ
diff --git a/utils/__pycache__/log.cpython-310.pyc b/utils/__pycache__/log.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..06e4ab20a87da667f8634a8b66a9d1930d0134f8
Binary files /dev/null and b/utils/__pycache__/log.cpython-310.pyc differ
diff --git a/utils/__pycache__/prompt2prompt.cpython-310.pyc b/utils/__pycache__/prompt2prompt.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c5e577fa78b2bf301f02da8a9d2b6e363e7febfd
Binary files /dev/null and b/utils/__pycache__/prompt2prompt.cpython-310.pyc differ
diff --git a/utils/device.py b/utils/device.py
new file mode 100644
index 0000000000000000000000000000000000000000..a707c355f9264424611d7728626bd253cc832c8d
--- /dev/null
+++ b/utils/device.py
@@ -0,0 +1,22 @@
+from typing import Union
+import torch
+
+def set_device(device: Union[str, torch.device]) -> torch.device:
+    """
+    Set the device to use for inference. A GPU is recommended.
+    Arguments:
+        device Union[str, torch.device]
+            The device to use for inference. Can be either a string or a torch.device object.
+
+    Returns:
+        torch.device
+            The device to use for inference.
+    """
+    if isinstance(device, str):
+        if device == 'cuda':
+            # Fall back to CPU when CUDA is requested but unavailable
+            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        elif device == 'mps':
+            # Fall back to CPU when MPS is requested but unavailable
+            device = torch.device('mps' if torch.backends.mps.is_built() and torch.backends.mps.is_available() else 'cpu')
+        else:
+            device = torch.device(device)
+    return device
\ No newline at end of file
diff --git a/utils/image.py b/utils/image.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9477805047f44e3f2f6c2ffdbc36e1fd03fe7c1
--- /dev/null
+++ b/utils/image.py
@@ -0,0 +1,29 @@
+import io
+from datetime import datetime, timezone
+import uuid
+import boto3
+from config import AWSConfig
+
+# Upload data to S3
+def write_to_s3(image, fname, region_name='ap-south-1'):
+    """
+    Write an image to S3. Returns the url. Requires AWSConfig.
+    # TODO : Add error handling
+    # TODO : Add logging
+    """
+    s3 = boto3.client('s3', region_name, aws_access_key_id=AWSConfig.aws_access_key_id, aws_secret_access_key=AWSConfig.aws_secret_access_key)
+    s3.upload_fileobj(image, AWSConfig.bucket_name, fname)
+    return f'https://{AWSConfig.bucket_name}.s3.{region_name}.amazonaws.com/{fname}'
+
+def save_image(img):
+    """
+    Save an image to S3. Returns the url and filename for JSON output
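+    Example (sketch; assumes `config.AWSConfig` holds valid credentials and a bucket name):
+        >>> url, name = save_image(pil_image)  # `pil_image` is any PIL.Image.Image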
+    # TODO : Add error handling
+    """
+    in_mem_file = io.BytesIO()
+    img.save(in_mem_file, format='PNG')
+    in_mem_file.seek(0)
+    dt = datetime.now(timezone.utc)
+    file_name = str(uuid.uuid4()) + '-' + str(int(dt.timestamp()))
+    img_url = write_to_s3(in_mem_file, f'sdimage/{file_name}.png')
+    return img_url, file_name
diff --git a/utils/log.py b/utils/log.py
new file mode 100644
index 0000000000000000000000000000000000000000..58c68a0193127564d1f440c43c4a046e7be37be4
--- /dev/null
+++ b/utils/log.py
@@ -0,0 +1,27 @@
+###########
+# Utilities for logging
+###########
+import logging
+
+def set_logger():
+    """
+    Custom logger for logging to the console
+    Returns:
+        logger
+            The logger object
+    """
+    logger = logging.getLogger()
+    logger.setLevel(logging.INFO)
+
+    ch = logging.StreamHandler()
+    ch.setLevel(logging.INFO)
+
+    # create formatter
+    formatter = logging.Formatter('[%(asctime)s] %(levelname)s - %(message)s')
+
+    # add formatter to ch
+    ch.setFormatter(formatter)
+
+    logger.addHandler(ch)
+
+    return logger
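diffmodels/textual_inversion.py above ends with a "# TODO: Add training scripts" marker. The sketch below illustrates the kind of training step that TODO refers to, closely following the standard diffusers textual-inversion example. It reuses the objects defined in that file (tokenizer, text_encoder, vae, unet, noise_scheduler, create_dataloader, placeholder_token_id); the optimizer choice, learning rate and step count are illustrative assumptions, not values taken from this diff.

import torch
import torch.nn.functional as F

# Only the token embeddings of the text encoder are trainable (everything else was frozen above).
optimizer = torch.optim.AdamW(text_encoder.get_input_embeddings().parameters(), lr=5e-4)  # assumed hyperparameters
dataloader = create_dataloader(train_batch_size=1)
max_train_steps = 3000  # illustrative

# For brevity this sketch runs on the default (CPU) device; a real run would move the
# modules to GPU and wrap them with accelerate, which the file already imports.
text_encoder.train()
step = 0
while step < max_train_steps:
    for batch in dataloader:
        # Encode the image into latents and add noise at a random timestep
        latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() * 0.18215
        noise = torch.randn_like(latents)
        timesteps = torch.randint(
            0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=latents.device
        ).long()
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        # Predict the noise (or velocity) with the text-conditioned UNet
        encoder_hidden_states = text_encoder(batch["input_ids"])[0]
        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
        if noise_scheduler.config.prediction_type == "v_prediction":
            target = noise_scheduler.get_velocity(latents, noise, timesteps)
        else:
            target = noise

        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
        loss.backward()

        # Only the new placeholder embedding should learn: zero every other row's gradient
        grads = text_encoder.get_input_embeddings().weight.grad
        index_no_update = torch.arange(len(tokenizer)) != placeholder_token_id
        grads.data[index_no_update, :] = grads.data[index_no_update, :].fill_(0)

        optimizer.step()
        optimizer.zero_grad()
        step += 1
        if step >= max_train_steps:
            break

After training, the learned embedding row at placeholder_token_id is typically the only thing saved, and is later loaded alongside the base model to reproduce the new concept.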