diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..0cebb3f50a1ec24241e261472a48ae2b3e04dfd8
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,20 @@
+# Use an official CUDA runtime as a parent image
+FROM nvidia/cuda:11.2.2-runtime-ubuntu20.04
+
+# Set the working directory in the container
+WORKDIR /usr/src/app
+
+# Install Python
+RUN apt-get update && apt-get install -y python3.8 python3-pip
+
+# Copy the current directory contents into the container at /usr/src/app
+COPY . /usr/src/app
+
+# Install any needed packages specified in requirements.txt
+RUN pip3 install --no-cache-dir -r requirements.txt
+
+# Expose the port Gradio listens on (7860) to the world outside this container
+EXPOSE 7860
+
+# Run app.py when the container launches
+CMD ["python3", "app.py"]
diff --git a/README.md b/README.md
index ce8b683ef0c224f09c8bb3fd279a3f09011815d6..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644
--- a/README.md
+++ b/README.md
@@ -1,10 +0,0 @@
----
-title: Minerva Generate Docker
-emoji: 😻
-colorFrom: red
-colorTo: purple
-sdk: docker
-pinned: false
----
-
-Check out the configuration reference at https://huggingface.co./docs/hub/spaces-config-reference
diff --git a/__pycache__/app.cpython-310.pyc b/__pycache__/app.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e0c712b54bd6825f797efcd5621667b551c64765
Binary files /dev/null and b/__pycache__/app.cpython-310.pyc differ
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..45f0bf48876e1dc6ada816686a35e32a4f8ddeb2
--- /dev/null
+++ b/app.py
@@ -0,0 +1,5 @@
+import gradio as gr
+from blocks.main import main_box
+blocks = main_box()
+blocks.launch(share=True)
+
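Since the Dockerfile exposes port 7860, a containerized deployment usually wants Gradio bound to all interfaces on that port rather than relying only on a `share=True` tunnel. Below is a minimal sketch of such a launch call; the `server_name`/`server_port` arguments are assumptions for the Docker setup, not part of the original app.py.

```python
# Sketch (assumption): launching the Blocks app for the Docker image above,
# binding to all interfaces on the port the Dockerfile EXPOSEs.
from blocks.main import main_box

blocks = main_box()
blocks.launch(
    server_name="0.0.0.0",  # listen on all interfaces inside the container
    server_port=7860,       # matches EXPOSE 7860 in the Dockerfile
)
```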
diff --git a/blocks/.DS_Store b/blocks/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..df93a9b0d9655287359ac7432a9c6c959589de12
Binary files /dev/null and b/blocks/.DS_Store differ
diff --git a/blocks/__init__.py b/blocks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e08103f4d0e1c4396623022cfd2c4ff0ec17381b
--- /dev/null
+++ b/blocks/__init__.py
@@ -0,0 +1,17 @@
+IMG2IMG_MODEL_LIST = {
+ "StableDiffusion 1.5" : "runwayml/stable-diffusion-v1-5",
+ "StableDiffusion 2.1" : "stabilityai/stable-diffusion-2-1",
+ "OpenJourney v4" : "prompthero/openjourney-v4",
+ "DreamLike 1.0" : "dreamlike-art/dreamlike-diffusion-1.0",
+ "DreamLike 2.0" : "dreamlike-art/dreamlike-photoreal-2.0"
+}
+
+TEXT2IMG_MODEL_LIST = {
+ "OpenJourney v4" : "prompthero/openjourney-v4",
+ "StableDiffusion 1.5" : "runwayml/stable-diffusion-v1-5",
+ "StableDiffusion 2.1" : "stabilityai/stable-diffusion-2-1",
+ "DreamLike 1.0" : "dreamlike-art/dreamlike-diffusion-1.0",
+ "DreamLike 2.0" : "dreamlike-art/dreamlike-photoreal-2.0",
+ "DreamShaper" : "Lykon/DreamShaper",
+ "NeverEnding-Dream" : "Lykon/NeverEnding-Dream"
+}
\ No newline at end of file
diff --git a/blocks/__pycache__/__init__.cpython-310.pyc b/blocks/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3e32727616cd70082efd0f4db814fc1a9c2e6856
Binary files /dev/null and b/blocks/__pycache__/__init__.cpython-310.pyc differ
diff --git a/blocks/__pycache__/__init__.cpython-39.pyc b/blocks/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6bb98b075ea4d2e2b1b5186edf7160707c123984
Binary files /dev/null and b/blocks/__pycache__/__init__.cpython-39.pyc differ
diff --git a/blocks/__pycache__/download.cpython-310.pyc b/blocks/__pycache__/download.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..26929a76993e7fa3eb6e1f7d16a7420647843bb7
Binary files /dev/null and b/blocks/__pycache__/download.cpython-310.pyc differ
diff --git a/blocks/__pycache__/download.cpython-39.pyc b/blocks/__pycache__/download.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b7a7adaa7111b7d81198b617bbd28a47e6a11fa4
Binary files /dev/null and b/blocks/__pycache__/download.cpython-39.pyc differ
diff --git a/blocks/__pycache__/img2img.cpython-39.pyc b/blocks/__pycache__/img2img.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..44815539cb900ef45262923d9084e4f40c4981e6
Binary files /dev/null and b/blocks/__pycache__/img2img.cpython-39.pyc differ
diff --git a/blocks/__pycache__/inpainting.cpython-310.pyc b/blocks/__pycache__/inpainting.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0d2e10b5209af9dfe1721f85bbf5fef3beff2d6a
Binary files /dev/null and b/blocks/__pycache__/inpainting.cpython-310.pyc differ
diff --git a/blocks/__pycache__/inpainting.cpython-39.pyc b/blocks/__pycache__/inpainting.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2597c8dd8a29cda5f38bf346a4ae394886c88134
Binary files /dev/null and b/blocks/__pycache__/inpainting.cpython-39.pyc differ
diff --git a/blocks/__pycache__/main.cpython-310.pyc b/blocks/__pycache__/main.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4851fcef4b191a71a9ba5fcf7401416482b2d5be
Binary files /dev/null and b/blocks/__pycache__/main.cpython-310.pyc differ
diff --git a/blocks/__pycache__/main.cpython-39.pyc b/blocks/__pycache__/main.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..60fce1c6c945b3033bd4f3ce099b98ca30c2e495
Binary files /dev/null and b/blocks/__pycache__/main.cpython-39.pyc differ
diff --git a/blocks/__pycache__/text2img.cpython-39.pyc b/blocks/__pycache__/text2img.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4517a9c56a0eccfc9d379a49e9442a5c223ebc08
Binary files /dev/null and b/blocks/__pycache__/text2img.cpython-39.pyc differ
diff --git a/blocks/download.py b/blocks/download.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b6931e58d4518e32f3cb4e0f1e6d77fbbc5aa80
--- /dev/null
+++ b/blocks/download.py
@@ -0,0 +1,400 @@
+community_icon_html = """"""
+
+def get_community_loading_icon(task = "text2img"):
+ if task == "text2img":
+ community_icon = """"""
+ loading_icon = """"""
+
+ elif task == "img2img":
+ community_icon = """"""
+ loading_icon = """"""
+
+ elif task == "inpainting":
+ community_icon = """"""
+ loading_icon = """"""
+ return community_icon, loading_icon
+
+loading_icon_html = """"""
+
+CSS = """
+ #col-container {margin-left: auto; margin-right: auto;}
+ a {text-decoration-line: underline; font-weight: 600;}
+ .animate-spin {
+ animation: spin 1s linear infinite;
+ }
+ @keyframes spin {
+ from { transform: rotate(0deg); }
+ to { transform: rotate(360deg); }
+ }
+ .gradio-container {
+ font-family: 'IBM Plex Sans', sans-serif;
+ }
+ .gr-button {
+ color: white;
+ border-color: black;
+ background: black;
+ }
+ input[type='range'] {
+ accent-color: black;
+ }
+ .dark input[type='range'] {
+ accent-color: #dfdfdf;
+ }
+ .container {
+ max-width: 730px;
+ margin: auto;
+ padding-top: 1.5rem;
+ }
+ #gallery {
+ min-height: 22rem;
+ margin-bottom: 15px;
+ margin-left: auto;
+ margin-right: auto;
+ border-bottom-right-radius: .5rem !important;
+ border-bottom-left-radius: .5rem !important;
+ }
+ #gallery>div>.h-full {
+ min-height: 20rem;
+ }
+ .details:hover {
+ text-decoration: underline;
+ }
+ .gr-button {
+ white-space: nowrap;
+ }
+ .gr-button:focus {
+ border-color: rgb(147 197 253 / var(--tw-border-opacity));
+ outline: none;
+ box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000);
+ --tw-border-opacity: 1;
+ --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color);
+ --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px var(--tw-ring-offset-width)) var(--tw-ring-color);
+ --tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity));
+ --tw-ring-opacity: .5;
+ }
+ #advanced-btn {
+ font-size: .7rem !important;
+ line-height: 19px;
+ margin-top: 12px;
+ margin-bottom: 12px;
+ padding: 2px 8px;
+ border-radius: 14px !important;
+ }
+ #advanced-options {
+ display: none;
+ margin-bottom: 20px;
+ }
+ .footer {
+ margin-bottom: 45px;
+ margin-top: 35px;
+ text-align: center;
+ border-bottom: 1px solid #e5e5e5;
+ }
+ .footer>p {
+ font-size: .8rem;
+ display: inline-block;
+ padding: 0 10px;
+ transform: translateY(10px);
+ background: white;
+ }
+ .dark .footer {
+ border-color: #303030;
+ }
+ .dark .footer>p {
+ background: #0b0f19;
+ }
+ .acknowledgments h4{
+ margin: 1.25em 0 .25em 0;
+ font-weight: bold;
+ font-size: 115%;
+ }
+ .animate-spin {
+ animation: spin 1s linear infinite;
+ }
+ @keyframes spin {
+ from {
+ transform: rotate(0deg);
+ }
+ to {
+ transform: rotate(360deg);
+ }
+ }
+ #share-btn-container {
+ display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem;
+ margin-top: 10px;
+ margin-left: auto;
+ margin-right: auto;
+ }
+ #share-btn {
+ all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;right:0;
+ }
+ #share-btn * {
+ all: unset;
+ }
+
+ #share-btn-inpainting {
+ all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;right:0;
+ }
+ #share-btn-inpainting * {
+ all: unset;
+ }
+
+ #share-btn-img2img {
+ all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;right:0;
+ }
+ #share-btn-img2img * {
+ all: unset;
+ }
+ #share-btn-container div:nth-child(-n+2){
+ width: auto !important;
+ min-height: 0px !important;
+ }
+ #share-btn-container .wrap {
+ display: none !important;
+ }
+
+ #download-btn-container {
+ display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem;
+ margin-top: 10px;
+ margin-left: auto;
+ margin-right: auto;
+ }
+ #download-btn {
+ all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;right:0;
+ }
+ #download-btn * {
+ all: unset;
+ }
+ #download-btn-container div:nth-child(-n+2){
+ width: auto !important;
+ min-height: 0px !important;
+ }
+ #download-btn-container .wrap {
+ display: none !important;
+ }
+
+ .gr-form{
+ flex: 1 1 50%; border-top-right-radius: 0; border-bottom-right-radius: 0;
+ }
+ #prompt-container{
+ gap: 0;
+ }
+
+ #prompt-text-input, #negative-prompt-text-input{padding: .45rem 0.625rem}
+ #component-16{border-top-width: 1px!important;margin-top: 1em}
+ .image_duplication{position: absolute; width: 100px; left: 50px}
+"""
+
+# window.addEventListener('message', function (event) {
+# if (event.origin !== 'http://127.0.0.1:5000'){
+# console.log('Origin not allowed');
+# return;
+# }
+# userId = event.data.userId;
+# console.log('User ID received from parent:', userId);
+# });
+
+def get_share_js():
+ share_js = """
+ async () => {
+ // Get the username from the URL itself
+ const urlParams = new URLSearchParams(window.location.search);
+ const username = urlParams.get('username');
+
+ async function uploadFile(
+ file,
+ _meta_prompt,
+ _meta_negative_prompt,
+ _meta_model_name,
+ _meta_scheduler_name,
+ _meta_model_guidance_scale,
+ _meta_model_num_steps,
+ _meta_model_image_size,
+ _meta_seed,
+ _meta_mask = null,
+ _meta_reference_image = null,
+ ){
+ const UPLOAD_URL = 'http://127.0.0.1:5000/v1/api/upload-image';
+ const formData = new FormData();
+ formData.append('file', file);
+
+ // Add the meta data headers to the form data
+ formData.append('text_prompt', _meta_prompt);
+ formData.append('negative_prompt', _meta_negative_prompt);
+ formData.append('model_name', _meta_model_name);
+ formData.append('model_guidance_scale', _meta_model_guidance_scale);
+ formData.append('model_num_steps', _meta_model_num_steps);
+ formData.append('scheduler_name', _meta_scheduler_name);
+ formData.append('seed', _meta_seed);
+ formData.append('model_image_size', _meta_model_image_size);
+
+ // Add the optional meta data headers to the form data
+ if(_meta_mask){
+ formData.append('mask', _meta_mask);
+ }
+ if(_meta_reference_image){
+ formData.append('reference_image', _meta_reference_image);
+ }
+
+ formData.append('username',username); // This is constant for all the images
+ const response = await fetch(UPLOAD_URL, {
+ method: 'POST',
+ headers: {
+ 'X-Requested-With': 'XMLHttpRequest',
+ },
+ body: formData,
+ });
+ const url = await response.text(); // This returns the URL of the uploaded file (S3) bucket
+ return url;
+ }
+
+ const gradioEl = document.querySelector('gradio-app');
+ const imgEls = gradioEl.querySelectorAll('#gallery img');
+
+ // Get the necessary fields
+ const promptTxt = gradioEl.querySelector('#prompt-text-input textarea').value;
+ const negativePromptTxt = gradioEl.querySelector('#negative-prompt-text-input textarea').value;
+
+ console.log(promptTxt);
+ console.log(negativePromptTxt);
+
+ // Get values from the sliders
+ const modelGuidanceScale = parseFloat(gradioEl.querySelector('#guidance-scale-slider input').value);
+ console.log(modelGuidanceScale);
+
+ const numSteps = parseInt(gradioEl.querySelector('#num-inference-step-slider input').value);
+ const imageSize = parseInt(gradioEl.querySelector('#image-size-slider input').value);
+ const seed = parseInt(gradioEl.querySelector('#seed-slider input').value);
+
+ console.log(numSteps);
+ console.log(imageSize);
+ console.log(seed);
+
+ // Get the values from dropdowns
+ const modelName = gradioEl.querySelector('#model-dropdown input').value;
+ const schedulerName = gradioEl.querySelector('#scheduler-dropdown input').value;
+
+ console.log(modelName);
+ console.log(schedulerName);
+
+ const shareBtnEl = gradioEl.querySelector('#share-btn');
+ const shareIconEl = gradioEl.querySelector('#share-btn-share-icon');
+ const loadingIconEl = gradioEl.querySelector('#share-btn-loading-icon');
+
+ if(!imgEls.length){
+ return;
+ };
+
+ shareBtnEl.style.pointerEvents = 'none';
+ shareIconEl.style.display = 'none';
+ loadingIconEl.style.removeProperty('display');
+ const files = await Promise.all(
+ [...imgEls].map(async (imgEl) => {
+ const res = await fetch(imgEl.src);
+ const blob = await res.blob();
+ const fileSrc = imgEl.src.split('/').pop(); // Get the file name from the img src path
+ const imgId = Date.now();
+ const fileName = `${fileSrc}-${imgId}.jpg`; // Fixed fileName construction
+ return new File([blob], fileName, { type: 'image/jpeg' });
+ })
+ );
+
+ // Ensure that only one image is uploaded by taking the first element if there are multiple
+ if (files.length > 1) {
+ files.splice(1, files.length - 1);
+ }
+
+ const urls = await Promise.all(files.map((f) => uploadFile(
+ f,
+ promptTxt,
+ negativePromptTxt,
+ modelName,
+ schedulerName,
+ modelGuidanceScale,
+ numSteps,
+ imageSize,
+ seed,
+ )));
+ const htmlImgs = urls.map(url => ``);
+
+ shareBtnEl.style.removeProperty('pointer-events');
+ shareIconEl.style.removeProperty('display');
+ loadingIconEl.style.display = 'none';
+ }
+ """
+ return share_js
+
+def get_load_from_artwork_js():
+ load_artwork_js = """
+ async () => {
+ const urlParams = new URLSearchParams(window.location.search);
+ const username = urlParams.get('username');
+ const artworkId = urlParams.get('artworkId');
+
+ const LOAD_URL = `http://127.0.0.1:5000/v1/api/load-parameters?artworkId=${artworkId}`;
+ const response = await fetch(LOAD_URL, {
+ method: 'GET',
+ headers: {
+ 'X-Requested-With': 'XMLHttpRequest',
+ }
+ });
+
+ // Check if the response is okay
+ if (!response.ok) {
+ console.error("An error occurred while fetching the parameters.");
+ return;
+ }
+
+ const parameters = await response.json(); // Assuming you're getting a JSON response
+
+ // Get the necessary elements
+ const gradioEl = document.querySelector('gradio-app');
+ const promptInput = gradioEl.querySelector('#prompt-text-input textarea');
+ const negativePromptInput = gradioEl.querySelector('#negative-prompt-text-input textarea');
+
+ // Get the slider inputs
+ const guidanceScaleInput = gradioEl.querySelector('#guidance-scale-slider input');
+ const numInferenceStepInput = gradioEl.querySelector('#num-inference-step-slider input');
+ const imageSizeInput = gradioEl.querySelector('#image-size-slider input');
+ const seedInput = gradioEl.querySelector('#seed-slider input');
+
+ // Get the dropdown inputs
+ const modelDropdown = gradioEl.querySelector('#model-dropdown input');
+ const schedulerDropdown = gradioEl.querySelector('#scheduler-dropdown input');
+
+ // Set the values based on the parameters received
+ promptInput.value = parameters.text_prompt;
+ negativePromptInput.value = parameters.negative_prompt;
+ guidanceScaleInput.value = parameters.model_guidance_scale;
+ numInferenceStepInput.value = parameters.model_num_steps;
+ imageSizeInput.value = parameters.model_image_size;
+ seedInput.value = parameters.seed;
+ modelDropdown.value = parameters.model_name;
+ schedulerDropdown.value = parameters.scheduler_name;
+ }
+ """
+ return load_artwork_js
\ No newline at end of file
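The JavaScript returned by `get_share_js()` and `get_load_from_artwork_js()` is meant to be attached to Gradio events through the `_js` argument, as the commented-out handler in blocks/text2img.py suggests. A hedged wiring sketch, assuming a Gradio 3.x `gr.Blocks()` context in which `demo` and `share_button` are defined:

```python
# Sketch (assumptions: Gradio 3.x, `demo` and `share_button` exist in a Blocks context).
from blocks.download import get_share_js, get_load_from_artwork_js

share_button.click(
    fn=None, inputs=[], outputs=[],
    _js=get_share_js(),              # runs the browser-side upload script
)
demo.load(
    fn=None, inputs=[], outputs=[],
    _js=get_load_from_artwork_js(),  # restores parameters from ?artworkId=...
)
```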
diff --git a/blocks/img2img.py b/blocks/img2img.py
new file mode 100644
index 0000000000000000000000000000000000000000..094eab4f1ccde4125dc9fa07846c598d496b0de6
--- /dev/null
+++ b/blocks/img2img.py
@@ -0,0 +1,204 @@
+import gradio as gr
+import torch
+from diffusers import StableDiffusionImg2ImgPipeline
+from .utils.schedulers import SCHEDULER_LIST, get_scheduler_list
+from .utils.prompt2prompt import generate
+from .utils.device import get_device
+from PIL import Image
+from .download import get_share_js, get_community_loading_icon, CSS
+
+IMG2IMG_MODEL_LIST = {
+ "StableDiffusion 1.5" : "runwayml/stable-diffusion-v1-5",
+ "StableDiffusion 2.1" : "stabilityai/stable-diffusion-2-1",
+ "OpenJourney v4" : "prompthero/openjourney-v4",
+ "DreamLike 1.0" : "dreamlike-art/dreamlike-diffusion-1.0",
+ "DreamLike 2.0" : "dreamlike-art/dreamlike-photoreal-2.0"
+}
+
+class StableDiffusionImage2ImageGenerator:
+ def __init__(self):
+ self.pipe = None
+
+ def load_model(self, model_path, scheduler):
+ model_path = IMG2IMG_MODEL_LIST[model_path]
+ if self.pipe is None:
+ self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
+ model_path, safety_checker=None, torch_dtype=torch.float32
+ )
+
+ device = get_device()
+ self.pipe = get_scheduler_list(pipe=self.pipe, scheduler=scheduler)
+
+ self.pipe.to(device)
+ #self.pipe.enable_attention_slicing()
+
+ return self.pipe
+
+ def generate_image(
+ self,
+ image_path: str,
+ model_path: str,
+ prompt: str,
+ negative_prompt: str,
+ num_images_per_prompt: int,
+ scheduler: str,
+ guidance_scale: int,
+ num_inference_step: int,
+ seed_generator=0,
+ ):
+ pipe = self.load_model(
+ model_path=model_path,
+ scheduler=scheduler,
+ )
+
+ if seed_generator == 0:
+ random_seed = torch.randint(0, 1000000, (1,))
+ generator = torch.manual_seed(random_seed)
+ else:
+ generator = torch.manual_seed(seed_generator)
+
+ image = Image.open(image_path)
+ images = pipe(
+ prompt,
+ image=image,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ num_inference_steps=num_inference_step,
+ guidance_scale=guidance_scale,
+ generator=generator,
+ ).images
+
+ return images
+
+ def app():
+ demo = gr.Blocks(css=CSS)
+ with demo:
+ with gr.Row():
+ with gr.Column():
+ image2image_image_file = gr.Image(
+ type="filepath", label="Upload",elem_id="image-upload-img2img"
+ ).style(height=260)
+
+ image2image_prompt = gr.Textbox(
+ lines=1,
+ placeholder="Prompt",
+ show_label=False,
+ elem_id="prompt-text-input-img2img",
+ value=''
+ )
+
+ image2image_negative_prompt = gr.Textbox(
+ lines=1,
+ placeholder="Negative Prompt",
+ show_label=False,
+ elem_id = "negative-prompt-text-input-img2img",
+ value=''
+ )
+
+                    # Button that expands the user's prompt with the MagicPrompt generator
+ image2image_generate_prompt_button = gr.Button(
+ label="Generate Prompt",
+ type="primary",
+ align="center",
+ value = "Generate Prompt"
+ )
+
+ # show a text box with the generated prompt
+ image2image_generated_prompt = gr.Textbox(
+ lines=1,
+ placeholder="Generated Prompt",
+ show_label=False,
+ )
+
+ with gr.Row():
+ with gr.Column():
+ image2image_model_path = gr.Dropdown(
+ choices=list(IMG2IMG_MODEL_LIST.keys()),
+ value=list(IMG2IMG_MODEL_LIST.keys())[0],
+                        label="Image2Image Model Selection",
+ elem_id="model-dropdown-img2img",
+ )
+
+ image2image_guidance_scale = gr.Slider(
+ minimum=0.1,
+ maximum=15,
+ step=0.1,
+ value=7.5,
+ label="Guidance Scale",
+ elem_id = "guidance-scale-slider-img2img"
+ )
+
+ image2image_num_inference_step = gr.Slider(
+ minimum=1,
+ maximum=100,
+ step=1,
+ value=50,
+ label="Num Inference Step",
+ elem_id = "num-inference-step-slider-img2img"
+ )
+ with gr.Row():
+ with gr.Column():
+ image2image_scheduler = gr.Dropdown(
+ choices=SCHEDULER_LIST,
+ value=SCHEDULER_LIST[0],
+ label="Scheduler",
+ elem_id="scheduler-dropdown-img2img",
+ )
+ image2image_num_images_per_prompt = gr.Slider(
+ minimum=1,
+ maximum=30,
+ step=1,
+ value=1,
+ label="Number Of Images",
+ )
+
+ image2image_seed_generator = gr.Slider(
+                        label="Seed (0 for random)",
+ minimum=0,
+ maximum=1000000,
+ value=0,
+ elem_id="seed-slider-img2img",
+ )
+
+                    image2image_predict_button = gr.Button(value="Generate")
+
+ with gr.Column():
+ output_image = gr.Gallery(
+ label="Generated images",
+ show_label=False,
+ elem_id="gallery",
+ ).style(grid=(1, 2))
+
+ with gr.Group(elem_id="container-advanced-btns"):
+ with gr.Group(elem_id="share-btn-container"):
+ community_icon_html, loading_icon_html = get_community_loading_icon("img2img")
+ community_icon = gr.HTML(community_icon_html)
+ loading_icon = gr.HTML(loading_icon_html)
+ share_button = gr.Button("Save artwork", elem_id="share-btn-img2img")
+
+ image2image_predict_button.click(
+ fn=StableDiffusionImage2ImageGenerator().generate_image,
+ inputs=[
+ image2image_image_file,
+ image2image_model_path,
+ image2image_prompt,
+ image2image_negative_prompt,
+ image2image_num_images_per_prompt,
+ image2image_scheduler,
+ image2image_guidance_scale,
+ image2image_num_inference_step,
+ image2image_seed_generator,
+ ],
+ outputs=[output_image],
+ )
+
+ image2image_generate_prompt_button.click(
+ fn=generate,
+ inputs=[image2image_prompt],
+ outputs=[image2image_generated_prompt],
+ )
+
+
+ return demo
+
+
\ No newline at end of file
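For reference, a minimal sketch of driving the image-to-image generator without the Gradio UI; the file path and prompts below are placeholders, and the model/scheduler names are keys taken from IMG2IMG_MODEL_LIST and SCHEDULER_LIST:

```python
# Sketch: calling the img2img generator directly (paths and prompts are placeholders).
from blocks.img2img import StableDiffusionImage2ImageGenerator

images = StableDiffusionImage2ImageGenerator().generate_image(
    image_path="reference.png",        # placeholder path to the source image
    model_path="StableDiffusion 1.5",  # key into IMG2IMG_MODEL_LIST
    prompt="a watercolor painting of a lighthouse",
    negative_prompt="blurry, low quality",
    num_images_per_prompt=1,
    scheduler="DDIM",                  # any entry of SCHEDULER_LIST
    guidance_scale=7.5,
    num_inference_step=50,
    seed_generator=0,                  # 0 picks a random seed
)
images[0].save("img2img_result.png")
```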
diff --git a/blocks/inpainting.py b/blocks/inpainting.py
new file mode 100644
index 0000000000000000000000000000000000000000..a61d2d487b76ad2384a087d07bf8176fc6b19050
--- /dev/null
+++ b/blocks/inpainting.py
@@ -0,0 +1,219 @@
+import gradio as gr
+from diffusers import DiffusionPipeline,StableDiffusionInpaintPipeline
+import torch
+from .utils.prompt2prompt import generate
+from .utils.device import get_device
+from .utils.schedulers import SCHEDULER_LIST, get_scheduler_list
+from .download import get_share_js, CSS, get_community_loading_icon
+
+INPAINT_MODEL_LIST = {
+ "Stable Diffusion 2" : "stabilityai/stable-diffusion-2-inpainting",
+ "Stable Diffusion 1" : "runwayml/stable-diffusion-inpainting",
+}
+
+class StableDiffusionInpaintGenerator:
+ def __init__(self):
+ self.pipe = None
+
+ def load_model(self, model_path, scheduler):
+ model_path = INPAINT_MODEL_LIST[model_path]
+ if self.pipe is None:
+ self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
+ model_path, torch_dtype=torch.float32
+ )
+ device = get_device()
+ self.pipe = get_scheduler_list(pipe=self.pipe, scheduler=scheduler)
+ self.pipe.to(device)
+ #self.pipe.enable_attention_slicing()
+ return self.pipe
+
+ def generate_image(
+ self,
+ pil_image: str,
+ model_path: str,
+ prompt: str,
+ negative_prompt: str,
+ num_images_per_prompt: int,
+ scheduler: str,
+ guidance_scale: int,
+ num_inference_step: int,
+ height: int,
+ width: int,
+ seed_generator=0,
+ ):
+
+ image = pil_image["image"].convert("RGB").resize((width, height))
+ mask_image = pil_image["mask"].convert("RGB").resize((width, height))
+
+ pipe = self.load_model(model_path,scheduler)
+
+ if seed_generator == 0:
+ random_seed = torch.randint(0, 1000000, (1,))
+ generator = torch.manual_seed(random_seed)
+ else:
+ generator = torch.manual_seed(seed_generator)
+
+ output = pipe(
+ prompt=prompt,
+ image=image,
+ mask_image=mask_image,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ num_inference_steps=num_inference_step,
+ guidance_scale=guidance_scale,
+ generator=generator,
+ ).images
+
+ return output
+
+
+ def app():
+ demo = gr.Blocks(css=CSS)
+ with demo:
+ with gr.Row():
+ with gr.Column():
+ stable_diffusion_inpaint_image_file = gr.Image(
+ source="upload",
+ tool="sketch",
+ elem_id="image-upload-inpainting",
+ type="pil",
+ label="Upload",
+
+ ).style(height=260)
+
+ stable_diffusion_inpaint_prompt = gr.Textbox(
+ lines=1,
+ placeholder="Prompt",
+ show_label=False,
+ elem_id="prompt-text-input-inpainting",
+ value=''
+ )
+
+ stable_diffusion_inpaint_negative_prompt = gr.Textbox(
+ lines=1,
+ placeholder="Negative Prompt",
+ show_label=False,
+ elem_id = "negative-prompt-text-input-inpainting",
+ value=''
+ )
+                    # Button that expands the user's prompt with the MagicPrompt generator
+ stable_diffusion_inpaint_generate = gr.Button(
+ label="Generate Prompt",
+ type="primary",
+ align="center",
+ value = "Generate Prompt"
+ )
+
+ # show a text box with the generated prompt
+ stable_diffusion_inpaint_generated_prompt = gr.Textbox(
+ lines=1,
+ placeholder="Generated Prompt",
+ show_label=False,
+ )
+
+ stable_diffusion_inpaint_model_id = gr.Dropdown(
+ choices=list(INPAINT_MODEL_LIST.keys()),
+ value=list(INPAINT_MODEL_LIST.keys())[0],
+ label="Inpaint Model Selection",
+ elem_id="model-dropdown-inpainting",
+ )
+ with gr.Row():
+ with gr.Column():
+ stable_diffusion_inpaint_guidance_scale = gr.Slider(
+ minimum=0.1,
+ maximum=15,
+ step=0.1,
+ value=7.5,
+ label="Guidance Scale",
+ elem_id = "guidance-scale-slider-inpainting"
+ )
+
+ stable_diffusion_inpaint_num_inference_step = gr.Slider(
+ minimum=1,
+ maximum=100,
+ step=1,
+ value=50,
+ label="Num Inference Step",
+ elem_id = "num-inference-step-slider-inpainting"
+ )
+
+                        stable_diffusion_inpaint_num_images_per_prompt = gr.Slider(
+ minimum=1,
+ maximum=10,
+ step=1,
+ value=1,
+ label="Number Of Images",
+ )
+
+ with gr.Row():
+ with gr.Column():
+ stable_diffusion_inpaint_scheduler = gr.Dropdown(
+ choices=SCHEDULER_LIST,
+ value=SCHEDULER_LIST[0],
+ label="Scheduler",
+ elem_id="scheduler-dropdown-inpainting",
+ )
+
+ stable_diffusion_inpaint_size = gr.Slider(
+ minimum=128,
+ maximum=1280,
+ step=32,
+ value=512,
+ label="Image Size",
+ elem_id="image-size-slider-inpainting",
+ )
+
+ stable_diffusion_inpaint_seed_generator = gr.Slider(
+                            label="Seed (0 for random)",
+ minimum=0,
+ maximum=1000000,
+ value=0,
+ elem_id="seed-slider-inpainting",
+ )
+
+ stable_diffusion_inpaint_predict = gr.Button(
+                            value="Generate"
+ )
+
+ with gr.Column():
+ output_image = gr.Gallery(
+ label="Generated images",
+ show_label=False,
+ elem_id="gallery-inpainting",
+ ).style(grid=(1, 2))
+
+ with gr.Group(elem_id="container-advanced-btns"):
+ with gr.Group(elem_id="share-btn-container"):
+ community_icon_html, loading_icon_html = get_community_loading_icon("inpainting")
+ community_icon = gr.HTML(community_icon_html)
+ loading_icon = gr.HTML(loading_icon_html)
+ share_button = gr.Button("Save artwork", elem_id="share-btn-inpainting")
+
+ stable_diffusion_inpaint_predict.click(
+ fn=StableDiffusionInpaintGenerator().generate_image,
+ inputs=[
+ stable_diffusion_inpaint_image_file,
+ stable_diffusion_inpaint_model_id,
+ stable_diffusion_inpaint_prompt,
+ stable_diffusion_inpaint_negative_prompt,
+                stable_diffusion_inpaint_num_images_per_prompt,
+ stable_diffusion_inpaint_scheduler,
+ stable_diffusion_inpaint_guidance_scale,
+ stable_diffusion_inpaint_num_inference_step,
+ stable_diffusion_inpaint_size,
+ stable_diffusion_inpaint_size,
+ stable_diffusion_inpaint_seed_generator,
+ ],
+ outputs=[output_image],
+ )
+
+ stable_diffusion_inpaint_generate.click(
+ fn=generate,
+ inputs=[stable_diffusion_inpaint_prompt],
+ outputs=[stable_diffusion_inpaint_generated_prompt],
+ )
+
+
+
+
+ return demo
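The inpainting generator expects `pil_image` in the dict form produced by `gr.Image(source="upload", tool="sketch", type="pil")`: the base picture under `"image"` and the drawn mask under `"mask"`. A hedged sketch of calling it outside the UI, with placeholder file names:

```python
# Sketch: invoking the inpainting generator with the dict layout that the
# sketch-enabled gr.Image component produces. File names are placeholders.
from PIL import Image
from blocks.inpainting import StableDiffusionInpaintGenerator

pil_image = {
    "image": Image.open("photo.png"),  # base image
    "mask": Image.open("mask.png"),    # white pixels mark the region to repaint
}
outputs = StableDiffusionInpaintGenerator().generate_image(
    pil_image=pil_image,
    model_path="Stable Diffusion 2",   # key into INPAINT_MODEL_LIST
    prompt="a red brick wall",
    negative_prompt="",
    num_images_per_prompt=1,
    scheduler="DDIM",
    guidance_scale=7.5,
    num_inference_step=50,
    height=512,
    width=512,
    seed_generator=0,
)
outputs[0].save("inpaint_result.png")
```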
diff --git a/blocks/main.py b/blocks/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff9ac3bdfce4efb9352dcf2585aca75748fb7efc
--- /dev/null
+++ b/blocks/main.py
@@ -0,0 +1,31 @@
+import gradio as gr
+from .download import CSS
+from .inpainting import StableDiffusionInpaintGenerator
+from .text2img import StableDiffusionText2ImageGenerator
+from .img2img import StableDiffusionImage2ImageGenerator
+
+def main_box(username : str = "admin"):
+ """
+ Implement the main interface for the app which will be served
+ to the frontend.
+ """
+    # The username could be used to customize the share JS on a per-user basis
+ app = gr.Blocks(css = CSS)
+ with app:
+ with gr.Row():
+ with gr.Column():
+ with gr.Tab("Text-to-Image", id = 'text-to-image', elem_id='text-to-image-tab'):
+ StableDiffusionText2ImageGenerator.app()
+ with gr.Tab("Image-to-Image", id = 'image-to-image', elem_id='image-to-image-tab'):
+ StableDiffusionImage2ImageGenerator.app()
+ with gr.Tab("Inpainting", id = 'inpainting', elem_id = 'inpainting-tab'):
+ StableDiffusionInpaintGenerator.app()
+
+ # Add a footer that will be displayed at the bottom of the app
+
+ gr.HTML("""
+        Minerva : Only your imagination is the limit!
+ """)
+
+ app.queue(concurrency_count=2)
+ return app
\ No newline at end of file
diff --git a/blocks/text2img.py b/blocks/text2img.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4ac5edbab1d7d934102c492b38ebddd67f898bb
--- /dev/null
+++ b/blocks/text2img.py
@@ -0,0 +1,235 @@
+import gradio as gr
+import torch
+from diffusers import StableDiffusionPipeline
+from .utils.schedulers import SCHEDULER_LIST, get_scheduler_list
+from .utils.prompt2prompt import generate
+from .utils.device import get_device
+from .download import get_share_js, community_icon_html, loading_icon_html, CSS
+
+#--- TODO: add a download button that lets the user download the generated output image
+
+TEXT2IMG_MODEL_LIST = {
+ "OpenJourney v4" : "prompthero/openjourney-v4",
+ "StableDiffusion 1.5" : "runwayml/stable-diffusion-v1-5",
+ "StableDiffusion 2.1" : "stabilityai/stable-diffusion-2-1",
+ "DreamLike 1.0" : "dreamlike-art/dreamlike-diffusion-1.0",
+ "DreamLike 2.0" : "dreamlike-art/dreamlike-photoreal-2.0",
+ "DreamShaper" : "Lykon/DreamShaper",
+ "NeverEnding-Dream" : "Lykon/NeverEnding-Dream"
+}
+
+class StableDiffusionText2ImageGenerator:
+ def __init__(self):
+ self.pipe = None
+
+ def load_model(
+ self,
+ model_path,
+ scheduler
+ ):
+ model_path = TEXT2IMG_MODEL_LIST[model_path]
+ if self.pipe is None:
+ self.pipe = StableDiffusionPipeline.from_pretrained(
+ model_path, safety_checker=None, torch_dtype=torch.float32
+ )
+
+ device = get_device()
+ self.pipe = get_scheduler_list(pipe=self.pipe, scheduler=scheduler)
+ self.pipe.to(device)
+ #self.pipe.enable_attention_slicing()
+
+ return self.pipe
+
+ def generate_image(
+ self,
+ model_path: str,
+ prompt: str,
+ negative_prompt: str,
+ num_images_per_prompt: int,
+ scheduler: str,
+ guidance_scale: int,
+ num_inference_step: int,
+ height: int,
+ width: int,
+ seed_generator=0,
+ ):
+ print("model_path", model_path)
+ print("prompt", prompt)
+ print("negative_prompt", negative_prompt)
+ print("num_images_per_prompt", num_images_per_prompt)
+ print("scheduler", scheduler)
+ print("guidance_scale", guidance_scale)
+ print("num_inference_step", num_inference_step)
+ print("height", height)
+ print("width", width)
+ print("seed_generator", seed_generator)
+
+ pipe = self.load_model(
+ model_path=model_path,
+ scheduler=scheduler,
+ )
+ if seed_generator == 0:
+ random_seed = torch.randint(0, 1000000, (1,))
+ generator = torch.manual_seed(random_seed)
+ else:
+ generator = torch.manual_seed(seed_generator)
+
+ images = pipe(
+ prompt=prompt,
+ height=height,
+ width=width,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ num_inference_steps=num_inference_step,
+ guidance_scale=guidance_scale,
+ generator=generator,
+ ).images
+
+ return images
+
+
+ def app(username : str = "admin"):
+ demo = gr.Blocks(css = CSS)
+ with demo:
+ with gr.Row():
+ with gr.Column():
+ text2image_prompt = gr.Textbox(
+ lines=1,
+ placeholder="Prompt",
+ show_label=False,
+ elem_id="prompt-text-input",
+ value=''
+ )
+
+ text2image_negative_prompt = gr.Textbox(
+ lines=1,
+ placeholder="Negative Prompt",
+ show_label=False,
+ elem_id = "negative-prompt-text-input",
+ value=''
+ )
+
+                    # Button that expands the user's prompt with the MagicPrompt generator
+ text2image_prompt_generate_button = gr.Button(
+ label="Generate Prompt",
+ type="primary",
+ align="center",
+ value = "Generate Prompt"
+ )
+
+ # show a text box with the generated prompt
+ text2image_prompt_generated_prompt = gr.Textbox(
+ lines=1,
+ placeholder="Generated Prompt",
+ show_label=False,
+ )
+ with gr.Row():
+ with gr.Column():
+ text2image_model_path = gr.Dropdown(
+ choices=list(TEXT2IMG_MODEL_LIST.keys()),
+ value=list(TEXT2IMG_MODEL_LIST.keys())[0],
+ label="Text2Image Model Selection",
+ elem_id="model-dropdown",
+ )
+
+ text2image_guidance_scale = gr.Slider(
+ minimum=0.1,
+ maximum=15,
+ step=0.1,
+ value=7.5,
+ label="Guidance Scale",
+ elem_id = "guidance-scale-slider"
+ )
+
+ text2image_num_inference_step = gr.Slider(
+ minimum=1,
+ maximum=100,
+ step=1,
+ value=50,
+ label="Num Inference Step",
+ elem_id = "num-inference-step-slider"
+ )
+ text2image_num_images_per_prompt = gr.Slider(
+ minimum=1,
+ maximum=30,
+ step=1,
+ value=1,
+ label="Number Of Images",
+ )
+ with gr.Row():
+ with gr.Column():
+
+ text2image_scheduler = gr.Dropdown(
+ choices=SCHEDULER_LIST,
+ value=SCHEDULER_LIST[0],
+ label="Scheduler",
+ elem_id="scheduler-dropdown",
+ )
+
+ text2image_size = gr.Slider(
+ minimum=128,
+ maximum=1280,
+ step=32,
+ value=512,
+ label="Image Size",
+ elem_id="image-size-slider",
+ )
+
+ text2image_seed_generator = gr.Slider(
+                        label="Seed (0 for random)",
+ minimum=0,
+ maximum=1000000,
+ value=0,
+ elem_id="seed-slider",
+ )
+                    text2image_predict = gr.Button(value="Generate")
+
+ with gr.Column():
+ output_image = gr.Gallery(
+ label="Generated images",
+ show_label=False,
+ elem_id="gallery",
+ ).style(grid=(1, 2), height='auto')
+
+ with gr.Group(elem_id="container-advanced-btns"):
+ with gr.Group(elem_id="share-btn-container"):
+ community_icon = gr.HTML(community_icon_html)
+ loading_icon = gr.HTML(loading_icon_html)
+ share_button = gr.Button("Save artwork", elem_id="share-btn")
+
+
+ text2image_predict.click(
+ fn=StableDiffusionText2ImageGenerator().generate_image,
+ inputs=[
+ text2image_model_path,
+ text2image_prompt,
+ text2image_negative_prompt,
+ text2image_num_images_per_prompt,
+ text2image_scheduler,
+ text2image_guidance_scale,
+ text2image_num_inference_step,
+ text2image_size,
+ text2image_size,
+ text2image_seed_generator,
+ ],
+ outputs=output_image,
+ )
+
+ text2image_prompt_generate_button.click(
+ fn=generate,
+ inputs=[text2image_prompt],
+ outputs=[text2image_prompt_generated_prompt],
+ )
+
+ # share_button.click(
+ # None,
+ # [],
+ # [],
+ # _js=get_share_js(),
+ # )
+
+    # auto-click the share button
+
+
+
+ return demo
\ No newline at end of file
diff --git a/blocks/utils/__init__.py b/blocks/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/blocks/utils/__pycache__/__init__.cpython-39.pyc b/blocks/utils/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d208b7409dd98f7964029118abe7df906b05d000
Binary files /dev/null and b/blocks/utils/__pycache__/__init__.cpython-39.pyc differ
diff --git a/blocks/utils/__pycache__/device.cpython-39.pyc b/blocks/utils/__pycache__/device.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..841f24cccd338e5da4a5328c27ef1eb2d05bf657
Binary files /dev/null and b/blocks/utils/__pycache__/device.cpython-39.pyc differ
diff --git a/blocks/utils/__pycache__/prompt2prompt.cpython-39.pyc b/blocks/utils/__pycache__/prompt2prompt.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a9f1a07c1b8386a5db5c76050e3ee6db5afae1f5
Binary files /dev/null and b/blocks/utils/__pycache__/prompt2prompt.cpython-39.pyc differ
diff --git a/blocks/utils/__pycache__/schedulers.cpython-39.pyc b/blocks/utils/__pycache__/schedulers.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b3299fa7fb7e0082568e28018ad7afeef141ff27
Binary files /dev/null and b/blocks/utils/__pycache__/schedulers.cpython-39.pyc differ
diff --git a/blocks/utils/device.py b/blocks/utils/device.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5e27d43454f1ad2557c5f3971f93893ceff4ae3
--- /dev/null
+++ b/blocks/utils/device.py
@@ -0,0 +1,16 @@
+import torch
+
+
+def get_device(device=None):
+    if device is None:
+        # Prefer CUDA, then Apple MPS, and fall back to CPU
+        if torch.cuda.is_available():
+            device = "cuda"
+        elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
+            device = "mps"
+        else:
+            device = "cpu"
+ return device
\ No newline at end of file
diff --git a/blocks/utils/prompt2prompt.py b/blocks/utils/prompt2prompt.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d50b450c9e048b808156fe22b48f850f7f06331
--- /dev/null
+++ b/blocks/utils/prompt2prompt.py
@@ -0,0 +1,23 @@
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import pipeline, set_seed
+import re
+import random
+gpt2_pipe = pipeline('text-generation', model='Gustavosta/MagicPrompt-Stable-Diffusion', tokenizer='gpt2')
+
+def generate(starting_text):
+ seed = random.randint(100, 1000000)
+ set_seed(seed)
+
+ response = gpt2_pipe(starting_text, max_length=(len(starting_text) + random.randint(60, 90)), num_return_sequences=4)
+ response_list = []
+ for x in response:
+ resp = x['generated_text'].strip()
+ if resp != starting_text and len(resp) > (len(starting_text) + 4) and resp.endswith((":", "-", "—")) is False:
+ response_list.append(resp+'\n')
+
+ response_end = "\n".join(response_list)
+    response_end = re.sub(r'[^ ]+\.[^ ]+', '', response_end)
+ response_end = response_end.replace("<", "").replace(">", "")
+
+ if response_end != "":
+ return response_end
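generate() loads the MagicPrompt model once at import time and expands a short idea into longer Stable-Diffusion-style prompts. A quick usage sketch (output varies between runs, and the function can return None when every candidate is filtered out):

```python
# Sketch: expanding a short prompt with the MagicPrompt-based generator.
from blocks.utils.prompt2prompt import generate

expanded = generate("a castle on a cliff at sunset")
print(expanded)  # elaborated prompt(s), or None if all candidates were filtered
```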
diff --git a/blocks/utils/schedulers.py b/blocks/utils/schedulers.py
new file mode 100644
index 0000000000000000000000000000000000000000..926bd3d43cb072f0229578e33db2858c7762ee69
--- /dev/null
+++ b/blocks/utils/schedulers.py
@@ -0,0 +1,47 @@
+from diffusers import (
+ DDIMScheduler,
+ EulerAncestralDiscreteScheduler,
+ EulerDiscreteScheduler,
+ HeunDiscreteScheduler,
+ LMSDiscreteScheduler,
+ DPMSolverMultistepScheduler
+)
+
+SCHEDULER_LIST = [
+ "DDIM",
+ "EulerA",
+ "Euler",
+ "LMS",
+ "Heun",
+ "DPMMultistep",
+]
+
+
+def get_scheduler_list(pipe, scheduler):
+ if scheduler == SCHEDULER_LIST[0]:
+ pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+
+ elif scheduler == SCHEDULER_LIST[1]:
+ pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
+ pipe.scheduler.config
+ )
+
+ elif scheduler == SCHEDULER_LIST[2]:
+ pipe.scheduler = EulerDiscreteScheduler.from_config(
+ pipe.scheduler.config
+ )
+
+ elif scheduler == SCHEDULER_LIST[3]:
+ pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
+
+ elif scheduler == SCHEDULER_LIST[4]:
+ pipe.scheduler = HeunDiscreteScheduler.from_config(
+ pipe.scheduler.config
+ )
+
+ elif scheduler == SCHEDULER_LIST[5]:
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(
+ pipe.scheduler.config
+ )
+
+ return pipe
\ No newline at end of file
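get_scheduler_list() swaps the scheduler on an already constructed pipeline by name, which is how the generator classes use it. A brief hedged sketch:

```python
# Sketch: swapping the scheduler of an existing pipeline by name.
import torch
from diffusers import StableDiffusionPipeline
from blocks.utils.schedulers import SCHEDULER_LIST, get_scheduler_list

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float32
)
pipe = get_scheduler_list(pipe, "EulerA")  # any name from SCHEDULER_LIST
```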
diff --git a/diffmodels/__init__.py b/diffmodels/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a3f735f33280d6b9b40e8e9129f3621794e489b
--- /dev/null
+++ b/diffmodels/__init__.py
@@ -0,0 +1,25 @@
+from .diffusion_utils import build_pipeline
+
+NAME_TO_MODEL = {
+ "stable-diffusion-v1-4":
+ {
+ "model" : "CompVis/stable-diffusion-v1-4",
+ "unet" : "CompVis/stable-diffusion-v1-4",
+ "tokenizer" : "openai/clip-vit-large-patch14",
+ "text_encoder" : "openai/clip-vit-large-patch14",
+ },
+ "stable_diffusion_v2_1":
+ {
+ "model" : "stabilityai/stable-diffusion-2-1",
+ "unet" : "stabilityai/stable-diffusion-2-1",
+ "tokenizer" : "laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
+ "text_encoder" : "laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
+ }
+}
+
+def get_model(model_name):
+ model = NAME_TO_MODEL.get(model_name)
+ if model is None:
+ raise ValueError(f"Model name {model_name} not found. Available models: {list(NAME_TO_MODEL.keys())}")
+ vae, tokenizer, text_encoder, unet = build_pipeline(model["model"], model["tokenizer"], model["text_encoder"], model["unet"])
+ return vae, tokenizer, text_encoder, unet
\ No newline at end of file
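get_model() resolves a registry key to its separately loaded components via build_pipeline(); a short usage sketch:

```python
# Sketch: loading the raw components registered in NAME_TO_MODEL.
from diffmodels import get_model

vae, tokenizer, text_encoder, unet = get_model("stable-diffusion-v1-4")
```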
diff --git a/diffmodels/__pycache__/__init__.cpython-310.pyc b/diffmodels/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6cb1f31977ffdd98571638bfdc1585024865f6fd
Binary files /dev/null and b/diffmodels/__pycache__/__init__.cpython-310.pyc differ
diff --git a/diffmodels/__pycache__/__init__.cpython-39.pyc b/diffmodels/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b0daccf864cf8ee03645aff7e6ded950f21c4fe7
Binary files /dev/null and b/diffmodels/__pycache__/__init__.cpython-39.pyc differ
diff --git a/diffmodels/__pycache__/diffusion_utils.cpython-310.pyc b/diffmodels/__pycache__/diffusion_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8e5a1f0ed42f10c696b0c069aae3af09bc762e4e
Binary files /dev/null and b/diffmodels/__pycache__/diffusion_utils.cpython-310.pyc differ
diff --git a/diffmodels/__pycache__/diffusion_utils.cpython-39.pyc b/diffmodels/__pycache__/diffusion_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d7d2ccbc53bc4fe3657079b4e10084870e17e295
Binary files /dev/null and b/diffmodels/__pycache__/diffusion_utils.cpython-39.pyc differ
diff --git a/diffmodels/__pycache__/simple_diffusion.cpython-310.pyc b/diffmodels/__pycache__/simple_diffusion.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cce2ffd19a7fa667ff19b03b3d0e53fdbcf70400
Binary files /dev/null and b/diffmodels/__pycache__/simple_diffusion.cpython-310.pyc differ
diff --git a/diffmodels/diffusion_utils.py b/diffmodels/diffusion_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..512a477f5a7b3caa61191da014618a2670f83ef7
--- /dev/null
+++ b/diffmodels/diffusion_utils.py
@@ -0,0 +1,218 @@
+# Utility class for loading and using diffusers model
+import diffusers
+import transformers
+
+import torch
+from typing import Union
+import os
+import warnings
+import numpy as np
+from PIL import Image
+import tqdm
+from copy import deepcopy
+import matplotlib.pyplot as plt
+
+def build_generator(
+ device : torch.device,
+ seed : int,
+):
+ """
+ Build a torch.Generator with a given seed.
+ """
+ generator = torch.Generator(device).manual_seed(seed)
+ return generator
+
+def load_stablediffusion_model(
+ model_id : Union[str, os.PathLike],
+ device : torch.device,
+ ):
+ """
+    Load a complete diffusion pipeline from a model id, switch it to the
+    DPMSolverMultistep scheduler, and move it to the given device.
+    Returns the loaded pipeline.
+ """
+ pipe = diffusers.DiffusionPipeline.from_pretrained(
+ model_id,
+ revision="fp16",
+ torch_dtype=torch.float16,
+ use_auth_token=True,
+ )
+ pipe.scheduler = diffusers.DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
+ try:
+ pipe = pipe.to(device)
+    except Exception:
+ warnings.warn(
+ f'Could not load model to device:{device}. Using CPU instead.'
+ )
+ pipe = pipe.to('cpu')
+ device = 'cpu'
+
+ return pipe
+
+
+def visualize_image_grid(
+    imgs : list,  # list of equally sized PIL.Image objects
+ rows : int,
+ cols : int):
+
+ assert len(imgs) == rows*cols
+
+ # create grid
+ w, h = imgs[0].size # assuming each image is the same size
+
+ grid = Image.new('RGB', size=(cols*w, rows*h))
+
+ for i,img in enumerate(imgs):
+ grid.paste(img, box=(i%cols*w, i//cols*h))
+ return grid
+
+
+def build_pipeline(
+ autoencoder : Union[str, os.PathLike] = "CompVis/stable-diffusion-v1-4",
+ tokenizer : Union[str, os.PathLike] = "openai/clip-vit-large-patch14",
+ text_encoder : Union[str, os.PathLike] = "openai/clip-vit-large-patch14",
+ unet : Union[str, os.PathLike] = "CompVis/stable-diffusion-v1-4",
+ device : torch.device = torch.device('cuda'),
+ ):
+ """
+    Create a pipeline for StableDiffusion by loading each of its components separately.
+ Arguments:
+ autoencoder: path to model that autoencoder will be loaded from
+ tokenizer: path to tokenizer
+ text_encoder: path to text_encoder
+ unet: path to unet
+ """
+ # Load the VAE for encoding images into the latent space
+ vae = diffusers.AutoencoderKL.from_pretrained(autoencoder, subfolder = 'vae')
+
+ # Load tokenizer & text encoder for encoding text into the latent space
+ tokenizer = transformers.CLIPTokenizer.from_pretrained(tokenizer)
+ text_encoder = transformers.CLIPTextModel.from_pretrained(text_encoder)
+
+ # Use the UNet model for conditioning the diffusion process
+ unet = diffusers.UNet2DConditionModel.from_pretrained(unet, subfolder = 'unet')
+
+ # Move all the components to device
+ vae = vae.to(device)
+ text_encoder = text_encoder.to(device)
+ unet = unet.to(device)
+
+ return vae, tokenizer, text_encoder, unet
+
+#TODO : Add negative prompting
+def custom_stablediffusion_inference(
+ vae,
+ tokenizer,
+ text_encoder,
+ unet,
+ noise_scheduler,
+ prompt : list,
+ device : torch.device,
+ num_inference_steps = 100,
+ image_size = (512,512),
+ guidance_scale = 8,
+ seed = 42,
+ return_image_step = 5,
+ ):
+ # Get the text embeddings that will condition the diffusion process
+ if isinstance(prompt,str):
+ prompt = [prompt]
+
+ batch_size = len(prompt)
+ text_input = tokenizer(
+ prompt,
+ padding = 'max_length',
+ truncation = True,
+ max_length = tokenizer.model_max_length,
+ return_tensors = 'pt').to(device)
+
+ text_embeddings = text_encoder(
+ text_input.input_ids.to(device)
+ )[0]
+
+ # Get the text embeddings for classifier-free guidance
+ max_length = text_input.input_ids.shape[-1]
+ empty = [""] * batch_size
+ uncond_input = tokenizer(
+ empty,
+ padding = 'max_length',
+ truncation = True,
+ max_length = max_length,
+ return_tensors = 'pt').to(device)
+
+ uncond_embeddings = text_encoder(
+ uncond_input.input_ids.to(device)
+ )[0]
+
+ # Concatenate the text embeddings to get the conditioning vector
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ # Generate initial noise
+ latents = torch.randn(
+ (1, unet.in_channels, image_size[0] // 8, image_size[1] // 8),
+ generator=torch.manual_seed(seed) if seed is not None else None
+ )
+ print(latents.shape)
+
+ latents = latents.to(device)
+
+    # Initialize the scheduler timesteps and scale the initial noise
+    noise_scheduler.set_timesteps(num_inference_steps)
+
+    latents = latents * noise_scheduler.init_noise_sigma
+
+ for i,t in tqdm.tqdm(enumerate(noise_scheduler.timesteps)):
+        # Duplicate the latents so the unconditional and text-conditioned branches are denoised in one pass (classifier-free guidance)
+ latent_model_input = torch.cat([latents] * 2)
+
+ latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)
+
+ with torch.no_grad():
+ # Get the noise prediction from the UNet
+ noise_pred = unet(latent_model_input, t, encoder_hidden_states = text_embeddings).sample
+
+ # Perform guidance from the text embeddings
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+        # Compute the previous noisy sample x_t -> x_t-1
+ latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
+
+        # Periodically decode the current latents with the VAE decoder to produce intermediate images
+ if i % return_image_step == 0:
+ with torch.no_grad():
+ latents_copy = deepcopy(latents)
+ image = vae.decode(1/0.18215 * latents_copy).sample
+
+ image = (image / 2 + 0.5).clamp(0,1)
+ image = image.detach().cpu().permute(0,2,3,1).numpy() # bxhxwxc
+ images = (image * 255).round().astype("uint8")
+
+ pil_images = [Image.fromarray(img) for img in images]
+
+ yield pil_images[0]
+
+ yield pil_images[0]
+
+if __name__ == "__main__":
+ device = torch.device("cpu")
+ model_id = "stabilityai/stable-diffusion-2-1"
+ tokenizer_id = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
+ #noise_scheduler = diffusers.LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
+ noise_scheduler = diffusers.DPMSolverMultistepScheduler.from_pretrained(model_id,subfolder="scheduler")
+ prompt = "A Hyperrealistic photograph of Italian architectural modern home in Italy, lens flares,\
+ cinematic, hdri, matte painting, concept art, celestial, soft render, highly detailed, octane\
+ render, architectural HD, HQ, 4k, 8k"
+
+ vae, tokenizer, text_encoder, unet = build_pipeline(
+ autoencoder = model_id,
+ tokenizer=tokenizer_id,
+ text_encoder=tokenizer_id,
+ unet=model_id,
+ device=device,
+ )
+ image_iter = custom_stablediffusion_inference(vae, tokenizer, text_encoder, unet, noise_scheduler, prompt = prompt, device=device, seed = None)
+ for i, image in enumerate(image_iter):
+ image.save(f"step_{i}.png")
+
diff --git a/diffmodels/simple_diffusion.py b/diffmodels/simple_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..088998d09310b4a0df4332c5d997d8c01dabecdb
--- /dev/null
+++ b/diffmodels/simple_diffusion.py
@@ -0,0 +1,309 @@
+import diffusers
+import transformers
+import utils.log
+import torch
+import PIL
+from typing import Union, Dict, Any, Optional, List, Tuple, Callable
+import os
+import re
+
+class SimpleDiffusion(diffusers.DiffusionPipeline):
+ """
+    A unified interface for diffusion models. It supports:
+        - txt2img
+        - img2img
+        - inpainting
+        - unconditional image generation
+
+    This class is heavily inspired by the Stable-Diffusion-Mega community pipeline.
+    The DiffusionPipeline base class lets us load/download any model hosted on the HuggingFace Hub with ease. Read more information
+ about the DiffusionPipeline class here: https://huggingface.co./transformers/main_classes/pipelines.html#transformers.DiffusionPipeline
+
+ Args:
+ logger (:obj:`utils.log.Logger`):
+ The logger to use for logging any information.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co./docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co./openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co./docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]):
+ Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionMegaSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co./runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPFeatureExtractor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+
+ """
+ def __init__(
+ self,
+ vae: diffusers.AutoencoderKL,
+ text_encoder: transformers.CLIPTextModel,
+ tokenizer: transformers.CLIPTokenizer,
+ unet: diffusers.UNet2DConditionModel,
+ scheduler: Union[diffusers.DDIMScheduler, diffusers.PNDMScheduler, diffusers.LMSDiscreteScheduler],
+ safety_checker: diffusers.pipelines.stable_diffusion.safety_checker.StableDiffusionSafetyChecker,
+ feature_extractor: transformers.CLIPFeatureExtractor,
+ prompt_generation = "succinctly/text2image-prompt-generator"
+ ):
+ super().__init__()
+ self._logger = None
+ self.register_modules( # already defined in ConfigMixin class, from_pretrained loads these modules
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+
+ )
+ self._generated_prompts = []
+ self._enable_prompt_generation = False
+ if prompt_generation:
+ self._enable_prompt_generation = True
+            self._prompt_generator = transformers.pipeline('text-generation', model=prompt_generation, tokenizer='gpt2')
+
+ def _generate_prompt(self, prompt, **kwargs):
+ """
+ Generate a prompt from a given text.
+ Args:
+ prompt (str): The text to generate a prompt from.
+ **kwargs: Additional keyword arguments passed to the prompt generator pipeline.
+ """
+ max_length = kwargs.pop("max_length", None)
+ num_return_sequences = kwargs.pop("num_return_sequences", None)
+
+        generated = self._prompt_generator(prompt, max_length=max_length, num_return_sequences=num_return_sequences)
+        # Filter the raw generations against the original text and return the cleaned prompt string
+        return self._process_prompt(prompt, generated)
+
+ def _process_prompt(self,original_prompt, prompt_list):
+ # TODO : Add documentation; add more prompt processing
+ response_list = []
+ for x in prompt_list:
+ resp = x['generated_text'].strip()
+ if resp != original_prompt and len(resp) > (len(original_prompt) + 4) and resp.endswith((":", "-", "—")) is False:
+ response_list.append(resp+'\n')
+
+ response_end = "\n".join(response_list)
+        response_end = re.sub(r'[^ ]+\.[^ ]+', '', response_end)
+ response_end = response_end.replace("<", "").replace(">", "")
+
+ if response_end != "":
+ return response_end
+
+ # Following components are required for the DiffusionPipeline class - but they exist in the StableDiffusionModel class
+ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+ r"""
+ Enable sliced attention computation.
+ Refer to the [StableDiffusionModel](https://github.com/huggingface/diffusers/blob/main/examples/community/stable_diffusion_mega.py) repo
+ for more information.
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
+ Args:
+ slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+ a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
+ `attention_head_dim` must be a multiple of `slice_size`.
+ """
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ if self._logger is not None:
+ self._logger.info("Attention slicing enabled!")
+ slice_size = self.unet.config.attention_head_dim // 2
+ self.unet.set_attention_slice(slice_size)
+
+ def disable_attention_slicing(self):
+ r"""
+ Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
+ back to computing attention in one step.
+ """
+ if self._logger is not None:
+ self._logger.info("Attention slicing disabled!")
+ self.enable_attention_slicing(None)
+
+ def set_logger(self, logger):
+ r"""
+ Set logger. This is useful to log information about the model.
+ """
+ self._logger = logger
+
+ @property
+ def components(self) -> Dict[str, Any]:
+ # Return the non-private variables
+ return {k : getattr(self, k) for k in self.config.keys() if not k.startswith("_")}
+
+ @torch.no_grad()
+ def inpaint(
+ self,
+ prompt: Union[str, List[str]],
+ init_image: Union[torch.FloatTensor, PIL.Image.Image],
+ mask_image: Union[torch.FloatTensor, PIL.Image.Image],
+ strength: float = 0.8,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: Optional[float] = 0.0,
+ generator: Optional[torch.Generator] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: Optional[int] = 1,
+ **kwargs,
+ ):
+        if self._enable_prompt_generation:
+            prompt = self._generate_prompt(prompt, **kwargs)
+            if self._logger is not None:
+                self._logger.info(f"Generated prompt: {prompt}")
+        # This forwards to the legacy inpainting pipeline; for more information, please see: https://huggingface.co./docs/diffusers/api/pipelines/stable_diffusion
+ return diffusers.StableDiffusionInpaintPipelineLegacy(**self.components)(
+ prompt=prompt,
+            image=init_image,
+ mask_image=mask_image,
+ strength=strength,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ output_type=output_type,
+ return_dict=return_dict,
+            callback=callback,
+            callback_steps=callback_steps,
+        )
+
+ @torch.no_grad()
+ def img2img(
+ self,
+ prompt: Union[str, List[str]],
+ init_image: Union[torch.FloatTensor, PIL.Image.Image],
+ strength: float = 0.8,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: Optional[float] = 0.0,
+ generator: Optional[torch.Generator] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: Optional[int] = 1,
+ **kwargs,
+ ):
+        if self._enable_prompt_generation:
+            prompt = self._generate_prompt(prompt, **kwargs)
+            if self._logger is not None:
+                self._logger.info(f"Generated prompt: {prompt}")
+ # For more information on how this function works, please see: https://huggingface.co./docs/diffusers/api/pipelines/stable_diffusion#diffusers.StableDiffusionImg2ImgPipeline
+ return diffusers.StableDiffusionImg2ImgPipeline(**self.components)(
+ prompt=prompt,
+            image=init_image,
+ strength=strength,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ output_type=output_type,
+ return_dict=return_dict,
+ callback=callback,
+ callback_steps=callback_steps,
+ )
+
+ @torch.no_grad()
+ def text2img(
+ self,
+ prompt: Union[str, List[str]],
+ height: int = 512,
+ width: int = 512,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: Optional[int] = 1,
+        **kwargs,
+    ):
+        if self._enable_prompt_generation:
+            prompt = self._generate_prompt(prompt, **kwargs)
+            if self._logger is not None:
+                self._logger.info(f"Generated prompt: {prompt}")
+
+        # For more information on how this function works, please see: https://huggingface.co./docs/diffusers/api/pipelines/stable_diffusion#diffusers.StableDiffusionPipeline
+ return diffusers.StableDiffusionPipeline(**self.components)(
+ prompt=prompt,
+ height=height,
+ width=width,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ latents=latents,
+ output_type=output_type,
+ return_dict=return_dict,
+ callback=callback,
+ callback_steps=callback_steps,
+ )
+
+ @torch.no_grad()
+ def upscale(
+ self,
+ prompt: Union[str, List[str]],
+ init_image: Union[torch.FloatTensor, PIL.Image.Image],
+ num_inference_steps: Optional[int] = 75,
+ guidance_scale: Optional[float] = 9.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: Optional[float] = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: Optional[int] = 1,
+        **kwargs,
+    ):
+        """
+        Upscale an image using the StableDiffusionUpscalePipeline.
+        """
+        if self._enable_prompt_generation:
+            prompt = self._generate_prompt(prompt, **kwargs)
+            if self._logger is not None:
+                self._logger.info(f"Generated prompt: {prompt}")
+
+ return diffusers.StableDiffusionUpscalePipeline(**self.components)(
+ prompt=prompt,
+ image=init_image,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ latents=latents,
+            output_type=output_type,
+            return_dict=return_dict,
+            callback=callback,
+            callback_steps=callback_steps,
+        )
+
+    def set_scheduler(self, scheduler: Union[diffusers.DDIMScheduler, diffusers.PNDMScheduler, diffusers.LMSDiscreteScheduler, diffusers.EulerDiscreteScheduler]):
+        """
+        Set the scheduler for the pipeline. This is useful for controlling the diffusion process.
+        Args:
+            scheduler (Union[diffusers.DDIMScheduler, diffusers.PNDMScheduler, diffusers.LMSDiscreteScheduler, diffusers.EulerDiscreteScheduler]): The scheduler to use.
+        """
+        # `components` builds a fresh dict on every access, so mutating it has no effect;
+        # register the new scheduler on the pipeline itself instead.
+        self.register_modules(scheduler=scheduler)
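+
+# --- Usage sketch (added for illustration) -----------------------------------------------
+# The methods above forward to the stock diffusers pipelines, so this minimal sketch shows
+# the equivalent calls (scheduler swap, attention slicing, text-to-image) on a plain
+# StableDiffusionPipeline. Assumptions: network access to download
+# "runwayml/stable-diffusion-v1-5" and, ideally, a CUDA device; this is not the wrapper's
+# own entry point.
+if __name__ == "__main__":
+    pipe = diffusers.StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+    # Swap in an Euler scheduler, mirroring what set_scheduler does for this wrapper
+    pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(pipe.scheduler.config)
+    # Trade a little speed for lower memory, as in enable_attention_slicing above
+    pipe.enable_attention_slicing()
+    if torch.cuda.is_available():
+        pipe = pipe.to("cuda")
+    image = pipe("a watercolor painting of a lighthouse at dawn", num_inference_steps=30, guidance_scale=7.5).images[0]
+    image.save("sample.png")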
\ No newline at end of file
diff --git a/diffmodels/textual_inversion.py b/diffmodels/textual_inversion.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f641eabc1567d3bd4f20253c6f9517e2d9970ae
--- /dev/null
+++ b/diffmodels/textual_inversion.py
@@ -0,0 +1,269 @@
+#@title Import required libraries
+import argparse
+import itertools
+import math
+import os
+import random
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch.utils.data import Dataset
+
+import PIL
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import set_seed
+from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from diffusers.optimization import get_scheduler
+from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
+from PIL import Image
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+
+pretrained_model_name_or_path = "stabilityai/stable-diffusion-2" #@param ["stabilityai/stable-diffusion-2", "stabilityai/stable-diffusion-2-base", "CompVis/stable-diffusion-v1-4", "runwayml/stable-diffusion-v1-5"] {allow-input: true}
+
+# example image urls
+urls = [
+ "https://huggingface.co./datasets/valhalla/images/resolve/main/2.jpeg",
+ "https://huggingface.co./datasets/valhalla/images/resolve/main/3.jpeg",
+ "https://huggingface.co./datasets/valhalla/images/resolve/main/5.jpeg",
+ "https://huggingface.co./datasets/valhalla/images/resolve/main/6.jpeg",
+ ]
+
+# What are you teaching? `object` teaches the model a new object to depict; `style` teaches it a new visual style.
+what_to_teach = "object" #@param ["object", "style"]
+# The token used to represent your new concept when prompting (e.g. "A <my-concept> in an amusement park"). Angle brackets keep the token from colliding with existing words/tokens.
+placeholder_token = "" #@param {type:"string"}
+# `initializer_token` is a word that summarises your new concept, used as a starting point for its embedding.
+initializer_token = "toy" #@param {type:"string"}
+
+def image_grid(imgs, rows, cols):
+ assert len(imgs) == rows*cols
+
+ w, h = imgs[0].size
+ grid = Image.new('RGB', size=(cols*w, rows*h))
+ grid_w, grid_h = grid.size
+
+ for i, img in enumerate(imgs):
+ grid.paste(img, box=(i%cols*w, i//cols*h))
+ return grid
+
+#@title Setup the prompt templates for training
+imagenet_templates_small = [
+ "a photo of a {}",
+ "a rendering of a {}",
+ "a cropped photo of the {}",
+ "the photo of a {}",
+ "a photo of a clean {}",
+ "a photo of a dirty {}",
+ "a dark photo of the {}",
+ "a photo of my {}",
+ "a photo of the cool {}",
+ "a close-up photo of a {}",
+ "a bright photo of the {}",
+ "a cropped photo of a {}",
+ "a photo of the {}",
+ "a good photo of the {}",
+ "a photo of one {}",
+ "a close-up photo of the {}",
+ "a rendition of the {}",
+ "a photo of the clean {}",
+ "a rendition of a {}",
+ "a photo of a nice {}",
+ "a good photo of a {}",
+ "a photo of the nice {}",
+ "a photo of the small {}",
+ "a photo of the weird {}",
+ "a photo of the large {}",
+ "a photo of a cool {}",
+ "a photo of a small {}",
+]
+
+imagenet_style_templates_small = [
+ "a painting in the style of {}",
+ "a rendering in the style of {}",
+ "a cropped painting in the style of {}",
+ "the painting in the style of {}",
+ "a clean painting in the style of {}",
+ "a dirty painting in the style of {}",
+ "a dark painting in the style of {}",
+ "a picture in the style of {}",
+ "a cool painting in the style of {}",
+ "a close-up painting in the style of {}",
+ "a bright painting in the style of {}",
+ "a cropped painting in the style of {}",
+ "a good painting in the style of {}",
+ "a close-up painting in the style of {}",
+ "a rendition in the style of {}",
+ "a nice painting in the style of {}",
+ "a small painting in the style of {}",
+ "a weird painting in the style of {}",
+ "a large painting in the style of {}",
+]
+
+#@title Setup the dataset
+class TextualInversionDataset(Dataset):
+ def __init__(
+ self,
+ data_root,
+ tokenizer,
+ learnable_property="object", # [object, style]
+ size=512,
+ repeats=100,
+ interpolation="bicubic",
+ flip_p=0.5,
+ set="train",
+ placeholder_token="*",
+ center_crop=False,
+ ):
+
+ self.data_root = data_root
+ self.tokenizer = tokenizer
+ self.learnable_property = learnable_property
+ self.size = size
+ self.placeholder_token = placeholder_token
+ self.center_crop = center_crop
+ self.flip_p = flip_p
+
+ self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
+
+ self.num_images = len(self.image_paths)
+ self._length = self.num_images
+
+ if set == "train":
+ self._length = self.num_images * repeats
+
+ self.interpolation = {
+ "linear": PIL.Image.LINEAR,
+ "bilinear": PIL.Image.BILINEAR,
+ "bicubic": PIL.Image.BICUBIC,
+ "lanczos": PIL.Image.LANCZOS,
+ }[interpolation]
+
+ self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
+ self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, i):
+ example = {}
+ image = Image.open(self.image_paths[i % self.num_images])
+
+        if image.mode != "RGB":
+ image = image.convert("RGB")
+
+ placeholder_string = self.placeholder_token
+ text = random.choice(self.templates).format(placeholder_string)
+
+ example["input_ids"] = self.tokenizer(
+ text,
+ padding="max_length",
+ truncation=True,
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ ).input_ids[0]
+
+ # default to score-sde preprocessing
+ img = np.array(image).astype(np.uint8)
+
+ if self.center_crop:
+ crop = min(img.shape[0], img.shape[1])
+ h, w, = (
+ img.shape[0],
+ img.shape[1],
+ )
+ img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
+
+ image = Image.fromarray(img)
+ image = image.resize((self.size, self.size), resample=self.interpolation)
+
+ image = self.flip_transform(image)
+ image = np.array(image).astype(np.uint8)
+ image = (image / 127.5 - 1.0).astype(np.float32)
+
+ example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
+ return example
+
+
+#@title Load the tokenizer and add the placeholder token as an additional special token.
+tokenizer = CLIPTokenizer.from_pretrained(
+ pretrained_model_name_or_path,
+ subfolder="tokenizer",
+)
+
+# Add the placeholder token in tokenizer
+num_added_tokens = tokenizer.add_tokens(placeholder_token)
+if num_added_tokens == 0:
+ raise ValueError(
+ f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
+ " `placeholder_token` that is not already in the tokenizer."
+ )
+
+
+
+#@title Get token ids for our placeholder and initializer token. This code block will complain if the initializer string is not a single token.
+# Convert the initializer_token, placeholder_token to ids
+token_ids = tokenizer.encode(initializer_token, add_special_tokens=False)
+# Check if initializer_token is a single token or a sequence of tokens
+if len(token_ids) > 1:
+ raise ValueError("The initializer token must be a single token.")
+
+initializer_token_id = token_ids[0]
+placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)
+
+
+#@title Load the Stable Diffusion model
+# Load models and create wrapper for stable diffusion
+# pipeline = StableDiffusionPipeline.from_pretrained(pretrained_model_name_or_path)
+# del pipeline
+text_encoder = CLIPTextModel.from_pretrained(
+ pretrained_model_name_or_path, subfolder="text_encoder"
+)
+vae = AutoencoderKL.from_pretrained(
+ pretrained_model_name_or_path, subfolder="vae"
+)
+unet = UNet2DConditionModel.from_pretrained(
+ pretrained_model_name_or_path, subfolder="unet"
+)
+
+text_encoder.resize_token_embeddings(len(tokenizer))
+
+token_embeds = text_encoder.get_input_embeddings().weight.data
+token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
+
+def freeze_params(params):
+ for param in params:
+ param.requires_grad = False
+
+# Freeze vae and unet
+freeze_params(vae.parameters())
+freeze_params(unet.parameters())
+# Freeze all parameters except for the token embeddings in text encoder
+params_to_freeze = itertools.chain(
+ text_encoder.text_model.encoder.parameters(),
+ text_encoder.text_model.final_layer_norm.parameters(),
+ text_encoder.text_model.embeddings.position_embedding.parameters(),
+)
+freeze_params(params_to_freeze)
+
+# `save_path` must point to a local folder of concept images (e.g. the example `urls` above downloaded to disk); the value below is a placeholder to adjust for your data.
+save_path = "./my_concept"
+
+train_dataset = TextualInversionDataset(
+    data_root=save_path,
+ tokenizer=tokenizer,
+ size=vae.sample_size,
+ placeholder_token=placeholder_token,
+ repeats=100,
+ learnable_property=what_to_teach, #Option selected above between object and style
+ center_crop=False,
+ set="train",
+)
+
+def create_dataloader(train_batch_size=1):
+ return torch.utils.data.DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
+
+noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
+
+# TODO: Add training scripts
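+
+# --- Minimal training-step sketch (added for illustration) --------------------------------
+# The full training script is still TODO above; the loop below sketches a few textual
+# inversion steps using the objects defined in this file. Assumptions: `placeholder_token`
+# has been set, `save_path` contains concept images, and only a handful of steps is run
+# here for illustration rather than a full training run.
+if __name__ == "__main__":
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    vae.to(device)
+    unet.to(device)
+    text_encoder.to(device)
+    text_encoder.train()
+
+    dataloader = create_dataloader(train_batch_size=1)
+    optimizer = torch.optim.AdamW(text_encoder.get_input_embeddings().parameters(), lr=5e-4)
+    orig_embeds = text_encoder.get_input_embeddings().weight.data.clone()
+
+    for step, batch in enumerate(tqdm(dataloader)):
+        if step >= 10:  # illustrative cap; a real run needs thousands of steps
+            break
+        pixel_values = batch["pixel_values"].to(device)
+        input_ids = batch["input_ids"].to(device)
+
+        # Encode images into latents and add noise at a random timestep
+        latents = vae.encode(pixel_values).latent_dist.sample().detach() * 0.18215
+        noise = torch.randn_like(latents)
+        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=device).long()
+        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+        # Predict the noise (or velocity) conditioned on the text prompt
+        encoder_hidden_states = text_encoder(input_ids)[0]
+        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+        if noise_scheduler.config.prediction_type == "v_prediction":
+            target = noise_scheduler.get_velocity(latents, noise, timesteps)
+        else:
+            target = noise
+
+        loss = F.mse_loss(model_pred.float(), target.float())
+        loss.backward()
+        optimizer.step()
+        optimizer.zero_grad()
+
+        # Restore every embedding except the placeholder token, so only the new token is learned
+        with torch.no_grad():
+            keep = torch.ones(len(tokenizer), dtype=torch.bool, device=device)
+            keep[placeholder_token_id] = False
+            text_encoder.get_input_embeddings().weight.data[keep] = orig_embeds[keep]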
diff --git a/image_0.png b/image_0.png
new file mode 100644
index 0000000000000000000000000000000000000000..6304f0a85be5206347837ce48e9cbdb1be84da89
Binary files /dev/null and b/image_0.png differ
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d6eeb217a9a4fb385180322bb9c56b615f068546
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,80 @@
+accelerate==0.19.0
+aiofiles==23.1.0
+aiohttp==3.8.4
+aiosignal==1.3.1
+altair==5.0.0
+anyio==3.6.2
+async-timeout==4.0.2
+attrs==23.1.0
+certifi==2023.5.7
+charset-normalizer==3.1.0
+click==8.1.3
+contourpy==1.0.7
+cycler==0.11.0
+diffusers==0.16.1
+fastapi==0.95.2
+ffmpy==0.3.0
+filelock==3.12.0
+fonttools==4.39.4
+frozenlist==1.3.3
+fsspec==2023.5.0
+gradio==3.30.0
+gradio_client==0.2.4
+h11==0.14.0
+httpcore==0.17.0
+httpx==0.24.0
+huggingface-hub==0.14.1
+idna==3.4
+importlib-metadata==6.6.0
+importlib-resources==5.12.0
+Jinja2==3.1.2
+jsonschema==4.17.3
+kiwisolver==1.4.4
+linkify-it-py==2.0.2
+markdown-it-py==2.2.0
+MarkupSafe==2.1.2
+matplotlib==3.7.1
+mdit-py-plugins==0.3.3
+mdurl==0.1.2
+mpmath==1.3.0
+multidict==6.0.4
+networkx==3.1
+numpy==1.24.3
+opencv-python==4.7.0.72
+orjson==3.8.10
+packaging==23.1
+pandas==2.0.1
+Pillow==9.5.0
+pip==23.1.2
+psutil==5.9.5
+pydantic==1.10.7
+pydub==0.25.1
+Pygments==2.15.1
+pyparsing==3.0.9
+pyrsistent==0.19.3
+python-dateutil==2.8.2
+python-multipart==0.0.6
+pytz==2023.3
+PyYAML==6.0
+regex==2023.5.5
+requests==2.30.0
+semantic-version==2.10.0
+setuptools==67.7.2
+six==1.16.0
+sniffio==1.3.0
+starlette==0.27.0
+sympy==1.12
+tokenizers==0.13.3
+toolz==0.12.0
+torch==2.0.1
+tqdm==4.65.0
+transformers==4.29.1
+typing_extensions==4.5.0
+tzdata==2023.3
+uc-micro-py==1.0.2
+urllib3==2.0.2
+uvicorn==0.22.0
+websockets==11.0.3
+wheel==0.40.0
+yarl==1.9.2
+zipp==3.15.0
diff --git a/static/load_from_artwork.js b/static/load_from_artwork.js
new file mode 100644
index 0000000000000000000000000000000000000000..e2d2d2ab76642acd749dfab58196e69f8a6bc715
--- /dev/null
+++ b/static/load_from_artwork.js
@@ -0,0 +1,46 @@
+async () => {
+ const urlParams = new URLSearchParams(window.location.search);
+ const username = urlParams.get('username');
+ const artworkId = urlParams.get('artworkId');
+
+ const LOAD_URL = `http://127.0.0.1:5000/v1/api/load-parameters/${artworkId}`;
+ const response = await fetch(LOAD_URL, {
+ method: 'GET',
+ headers: {
+ 'X-Requested-With': 'XMLHttpRequest',
+ }
+ });
+
+ // Check if the response is okay
+ if (!response.ok) {
+ console.error("An error occurred while fetching the parameters.");
+ return;
+ }
+
+ const parameters = await response.json(); // Assuming you're getting a JSON response
+
+ // Get the necessary elements
+ const gradioEl = document.querySelector('gradio-app');
+ const promptInput = gradioEl.querySelector('#prompt-text-input textarea');
+ const negativePromptInput = gradioEl.querySelector('#negative-prompt-text-input textarea');
+
+ // Get the slider inputs
+ const guidanceScaleInput = gradioEl.querySelector('#guidance-scale-slider input');
+ const numInferenceStepInput = gradioEl.querySelector('#num-inference-step-slider input');
+ const imageSizeInput = gradioEl.querySelector('#image-size-slider input');
+ const seedInput = gradioEl.querySelector('#seed-slider input');
+
+ // Get the dropdown inputs
+ const modelDropdown = gradioEl.querySelector('#model-dropdown input');
+ const schedulerDropdown = gradioEl.querySelector('#scheduler-dropdown input');
+
+ // Set the values based on the parameters received
+ promptInput.value = parameters.text_prompt;
+ negativePromptInput.value = parameters.negative_prompt;
+ guidanceScaleInput.value = parameters.model_guidance_scale;
+ numInferenceStepInput.value = parameters.model_num_steps;
+ imageSizeInput.value = parameters.model_image_size;
+ seedInput.value = parameters.seed;
+ modelDropdown.value = parameters.model_name;
+ schedulerDropdown.value = parameters.scheduler_name;
+}
diff --git a/static/save_artwork.js b/static/save_artwork.js
new file mode 100644
index 0000000000000000000000000000000000000000..887645a7ea5d0e3de4e5f1753f101f01fc330476
--- /dev/null
+++ b/static/save_artwork.js
@@ -0,0 +1,63 @@
+async () => {
+ // Get the username from the URL itself
+ const gradioEl = document.querySelector('gradio-app');
+ const imgEls = gradioEl.querySelectorAll('#gallery img');
+
+ // Get the necessary fields
+ const promptTxt = gradioEl.querySelector('#prompt-text-input textarea').value;
+ const negativePromptTxt = gradioEl.querySelector('#negative-prompt-text-input textarea').value;
+
+ // Get values from the sliders
+ const modelGuidanceScale = parseFloat(gradioEl.querySelector('#guidance-scale-slider input').value);
+
+ const numSteps = parseInt(gradioEl.querySelector('#num-inference-step-slider input').value);
+ const imageSize = parseInt(gradioEl.querySelector('#image-size-slider input').value);
+ const seed = parseInt(gradioEl.querySelector('#seed-slider input').value);
+
+ // Get the values from dropdowns
+ const modelName = gradioEl.querySelector('#model-dropdown input').value;
+ const schedulerName = gradioEl.querySelector('#scheduler-dropdown input').value;
+
+ const shareBtnEl = gradioEl.querySelector('#share-btn');
+ const shareIconEl = gradioEl.querySelector('#share-btn-share-icon');
+ const loadingIconEl = gradioEl.querySelector('#share-btn-loading-icon');
+
+ if(!imgEls.length){
+ return;
+ };
+
+ shareBtnEl.style.pointerEvents = 'none';
+ shareIconEl.style.display = 'none';
+ loadingIconEl.style.removeProperty('display');
+ const files = await Promise.all(
+ [...imgEls].map(async (imgEl) => {
+ const res = await fetch(imgEl.src);
+ const blob = await res.blob();
+ const fileSrc = imgEl.src.split('/').pop(); // Get the file name from the img src path
+ const imgId = Date.now();
+ const fileName = `${fileSrc}-${imgId}.jpg`; // Fixed fileName construction
+ return new File([blob], fileName, { type: 'image/jpeg' });
+ })
+ );
+
+ // Ensure that only one image is uploaded by taking the first element if there are multiple
+ if (files.length > 1) {
+ files.splice(1, files.length - 1);
+ }
+
+ const urls = await Promise.all(files.map((f) => uploadFile(
+ f,
+ promptTxt,
+ negativePromptTxt,
+ modelName,
+ schedulerName,
+ modelGuidanceScale,
+ numSteps,
+ imageSize,
+ seed,
+ )));
+
+ shareBtnEl.style.removeProperty('pointer-events');
+ shareIconEl.style.removeProperty('display');
+ loadingIconEl.style.display = 'none';
+ }
\ No newline at end of file
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/utils/__pycache__/__init__.cpython-310.pyc b/utils/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..79f5ab8bb19adf22d7f6359debb472b763a69a00
Binary files /dev/null and b/utils/__pycache__/__init__.cpython-310.pyc differ
diff --git a/utils/__pycache__/device.cpython-310.pyc b/utils/__pycache__/device.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fd158d91d29d279588b4f43a2caec19697f7d13e
Binary files /dev/null and b/utils/__pycache__/device.cpython-310.pyc differ
diff --git a/utils/__pycache__/image.cpython-310.pyc b/utils/__pycache__/image.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..71595b047953b5dc6f36d80c68dc6aa88c301f8a
Binary files /dev/null and b/utils/__pycache__/image.cpython-310.pyc differ
diff --git a/utils/__pycache__/log.cpython-310.pyc b/utils/__pycache__/log.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..06e4ab20a87da667f8634a8b66a9d1930d0134f8
Binary files /dev/null and b/utils/__pycache__/log.cpython-310.pyc differ
diff --git a/utils/__pycache__/prompt2prompt.cpython-310.pyc b/utils/__pycache__/prompt2prompt.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c5e577fa78b2bf301f02da8a9d2b6e363e7febfd
Binary files /dev/null and b/utils/__pycache__/prompt2prompt.cpython-310.pyc differ
diff --git a/utils/device.py b/utils/device.py
new file mode 100644
index 0000000000000000000000000000000000000000..a707c355f9264424611d7728626bd253cc832c8d
--- /dev/null
+++ b/utils/device.py
@@ -0,0 +1,22 @@
+from typing import Union
+import torch
+
+def set_device(device : Union[str, torch.device]) -> torch.device:
+ """
+ Set the device to use for inference. Recommended to use GPU.
+ Arguments:
+ device Union[str, torch.device]
+ The device to use for inference. Can be either a string or a torch.device object.
+
+ Returns:
+ torch.device
+ The device to use for inference.
+ """
+    if isinstance(device, str):
+        if device == 'cuda':
+            # Fall back to CPU when CUDA is not available
+            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        elif device == 'mps':
+            # Fall back to CPU when the MPS backend is not built
+            device = torch.device('mps' if torch.backends.mps.is_built() else 'cpu')
+        else:
+            device = torch.device(device)
+    return device
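+
+# Minimal usage sketch (illustration only): resolve a device string, falling back to CPU when needed.
+if __name__ == "__main__":
+    print(set_device("cuda"))
+    print(set_device(torch.device("cpu")))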
\ No newline at end of file
diff --git a/utils/image.py b/utils/image.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9477805047f44e3f2f6c2ffdbc36e1fd03fe7c1
--- /dev/null
+++ b/utils/image.py
@@ -0,0 +1,29 @@
+import io
+from datetime import datetime, timezone
+import uuid
+import boto3
+from config import AWSConfig
+
+#Upload data to s3
+def write_to_s3(image, fname, region_name='ap-south-1'):
+ """
+ Write an image to s3. Returns the url. Requires AWSConfig
+ # TODO : Add error handling
+ # TODO : Add logging
+ """
+    s3 = boto3.client('s3', region_name=region_name,
+                      aws_access_key_id=AWSConfig.aws_access_key_id,
+                      aws_secret_access_key=AWSConfig.aws_secret_access_key)
+    s3.upload_fileobj(image, AWSConfig.bucket_name, fname)
+ return f'https://{AWSConfig.bucket_name}.s3.{region_name}.amazonaws.com/{fname}'
+
+def save_image(img):
+ """
+ Save an image to s3. Returns the url and filename for JSON output
+ # TODO : Add error handling
+ """
+ in_mem_file = io.BytesIO()
+    img.save(in_mem_file, format='PNG')
+    in_mem_file.seek(0)
+    dt = datetime.now()
+    file_name = f"{uuid.uuid4()}-{int(dt.replace(tzinfo=timezone.utc).timestamp())}"
+    # Note: the object is PNG-encoded even though the S3 key uses a .jpeg suffix
+    img_url = write_to_s3(in_mem_file, f'sdimage/{file_name}.jpeg')
+    return img_url, file_name
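+
+# Minimal usage sketch (illustration only): uploads a tiny placeholder image, so it requires
+# valid AWSConfig credentials and performs a real S3 upload.
+if __name__ == "__main__":
+    from PIL import Image
+    demo_img = Image.new("RGB", (64, 64), color="black")
+    url, name = save_image(demo_img)
+    print(url, name)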
diff --git a/utils/log.py b/utils/log.py
new file mode 100644
index 0000000000000000000000000000000000000000..58c68a0193127564d1f440c43c4a046e7be37be4
--- /dev/null
+++ b/utils/log.py
@@ -0,0 +1,27 @@
+###########
+# Utilities for logging
+###########
+import logging
+
+def set_logger():
+ """
+    Custom logger that logs to the console
+ Returns:
+ logger
+ The logger object
+ """
+ logger = logging.getLogger()
+ logger.setLevel(logging.INFO)
+
+ ch = logging.StreamHandler()
+ ch.setLevel(logging.INFO)
+
+ # create formatter
+ formatter = logging.Formatter('[%(asctime)s] %(levelname)s - %(message)s')
+
+ # add formatter to ch
+ ch.setFormatter(formatter)
+
+ logger.addHandler(ch)
+
+ return logger
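+
+# Minimal usage sketch (illustration only):
+if __name__ == "__main__":
+    log = set_logger()
+    log.info("Logger configured with %d handler(s)", len(log.handlers))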