File size: 4,087 Bytes
1bc22a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9f25063
1bc22a0
 
 
 
 
 
9f25063
 
1bc22a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9f25063
 
1bc22a0
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# This demo needs to be run from the repo folder.
# python demo/fake_gan/run.py
import replicate
import gradio as gr
import requests
from PIL import Image
from io import BytesIO

import aiohttp
import asyncio

async def get_image_from_url(api_url, payload, file_path):
    """POST ``payload`` as JSON to ``api_url`` and save the raw response body.

    Args:
        api_url: URL of the endpoint to call.
        payload: JSON-serialisable request body.
        file_path: Destination path; opened in ``'wb'`` mode, so an existing
            file is overwritten.

    Raises:
        aiohttp.ClientResponseError: If the endpoint answers with an HTTP
            error status (>= 400).
    """
    async with aiohttp.ClientSession() as session:
        print("making api request to", api_url)
        async with session.post(api_url, json=payload) as resp:
            # Fail loudly on an error status instead of silently saving the
            # error body to disk under the image file name.
            resp.raise_for_status()
            data = await resp.read()

    # The body is fully buffered at this point, so the connection and session
    # can be released before touching the filesystem.
    with open(file_path, 'wb') as f:
        f.write(data)

async def fake_gan_async(text_input):
    """Generate comparison images for ``text_input`` from several models.

    One replicate call and three HTTP endpoint requests are issued
    concurrently. The HTTP responses are written to fixed file names in the
    current working directory.

    Args:
        text_input: Prompt string. For the LoRA endpoint, the substrings
            "baby yoda" and "grogu" (case-sensitive) are replaced by "gr1gu".

    Returns:
        A three-element list: the first item of the replicate output, the
        SD XL file path, and the LoRA file path.
    """
    loop = asyncio.get_running_loop()

    # SD 2-1
    # replicate.run is a synchronous call; running it directly here would hold
    # up the event loop and delay the other requests until it finished. The
    # default executor lets it overlap with the HTTP requests below.
    sd_future = loop.run_in_executor(
        None,
        lambda: replicate.run(
            "stability-ai/stable-diffusion:ac732df83cea7fff18b8472768c88ad041fa750ff7682a21affe81863cbe77e4",
            input={"prompt": text_input},
        ),
    )

    # SD XL
    sdxl_api_url = "http://3bfb-34-142-187-46.ngrok-free.app/get_image"
    payload_sdxl = {"prompt": text_input}
    sdxl_filepath = 'sdxl.png'

    # SD LoRA
    lora_api_url = "http://841a-35-240-138-5.ngrok-free.app/get_image"
    # Map both spellings of the character onto "gr1gu" for the LoRA prompt.
    lora_text_input = text_input.replace("baby yoda", "gr1gu").replace("grogu", "gr1gu")
    payload_lora = {
        "prompt": lora_text_input,
        "lora_weights_path": "/content/drive/MyDrive/kohya-trainer/output/grogu_lora_v4-000030.safetensors",
    }
    lora_filepath = 'grogu_sdxl_lora.png'

    # SD XL Finetuned, CI
    # NOTE(review): this image is still requested and saved to disk, but its
    # path is not part of the returned list - confirm whether that is intended.
    finetuned_ci_url = "http://f20b-34-125-36-4.ngrok-free.app/get_image"
    payload_finetuned_ci = {
        "prompt": text_input,
        "lora_weights_path": "/content/drive/MyDrive/kohya-trainer/output/sdxl_20_epochs_ci.safetensors",
    }
    finetuned_ci_filepath = 'sdxl_finetuned_ci.png'

    # Make all requests concurrently; gather returns results in argument
    # order, so the replicate output is the first element.
    sd_response, *_ = await asyncio.gather(
        sd_future,
        get_image_from_url(sdxl_api_url, payload_sdxl, sdxl_filepath),
        get_image_from_url(lora_api_url, payload_lora, lora_filepath),
        get_image_from_url(finetuned_ci_url, payload_finetuned_ci, finetuned_ci_filepath),
    )
    sd_image_url = sd_response[0]

    return [
        sd_image_url,
        sdxl_filepath,
        lora_filepath,
    ]

# Running the async function:
# asyncio.run(fake_gan_async('your_text_input'))



# Heading shown at the top of the demo page.
title = "The Mandalorian Vango Demo"

# Blurb shown under the title. Adjacent literals are concatenated by the
# parser, so the runtime value is a single string.
description = (
    "We experimented with a few models to make this demo! "
    "Here we compare A) Stable Diffusion 2-1, B) Stable Diffusion XL, "
    "C) finetuned SD XL on stills and automated labels from The Mandolorian "
    "and D) LoRA finetuned SD XL on 20 hand written labels of Baby Yoda"
)
# article = "<p style='text-align: center'><a href='https://github.com/bryandlee/animegan2-pytorch' target='_blank'>Github Repo Pytorch</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_animegan' alt='visitor badge'></center></p>"

# Gradio UI: a single-line prompt textbox in, three labelled images out.
demo = gr.Interface(
    # asyncio.run creates a fresh event loop for each call and drives the
    # coroutine to completion, giving Gradio a plain synchronous callable.
    fn=lambda x: asyncio.run(fake_gan_async(x)),
    inputs=gr.Textbox(
        label="Enter your prompt",
        show_label=False,
        max_lines=1,
        placeholder="Enter your prompt",
    ).style(
        container=False,
    ),
    # Order must match the list returned by fake_gan_async.
    outputs=[
        gr.outputs.Image(type="pil", label="Stable Diffusion 2-1"),
        gr.outputs.Image(type="pil", label="Stable Diffusion XL"),
        gr.outputs.Image(type="pil", label="SD XL, Grogu LoRA on 20 hand written labels"),
    ],
    title=title,
    description=description,
)

# Launch only when executed as a script. The previous unguarded
# demo.launch() also started the server on import and caused launch() to be
# called a second time when run as a script.
if __name__ == "__main__":
    demo.launch()