bharatcoder committed on
Commit
58a7875
·
verified ·
1 Parent(s): aa4a213

Initial Commits

Browse files
Files changed (2) hide show
  1. text_image.py +124 -0
  2. utils.py +66 -0
text_image.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dotenv import load_dotenv, find_dotenv
3
+ from huggingface_hub import InferenceClient
4
+ import gradio as gr
5
+
6
+ from utils import get_default_hyperparameters # Import the utility function
7
+
8
+ load_dotenv(find_dotenv())
9
+ hf_token = os.getenv("HF_TOKEN")
10
+
11
+ headers = {
12
+ "x-wait-for-model": "true",
13
+ "x-use-cache": "true", # Use past generations
14
+ }
15
+
16
+ imagegen_models_list = [
17
+ "black-forest-labs/FLUX.1-schnell",
18
+ "black-forest-labs/FLUX.1-dev",
19
+ "strangerzonehf/Flux-Midjourney-Mix2-LoRA",
20
+ "stabilityai/stable-diffusion-3.5-large",
21
+ "stabilityai/stable-diffusion-xl-base-1.0",
22
+ "stable-diffusion-v1-5/stable-diffusion-v1-5"
23
+ ]
24
+
25
+ promptgen_models_list = [
26
+ "meta-llama/Llama-3.2-11B-Vision-Instruct",
27
+ "mistralai/Mistral-Nemo-Instruct-2407"
28
+ ]
29
+
30
+ # Function to generate image
31
+ def generate_image(model, prompt, guidance, width, height, num_inference_steps, seed):
32
+ print(f"""Generating image with following parameters:
33
+ Model: {model}, Prompt: {prompt}, Guidance Scale: {guidance}, Width: {width}, Height: {height},""")
34
+ client = InferenceClient(model=model, headers=headers, token=hf_token)
35
+ image = client.text_to_image(
36
+ model=model,
37
+ prompt=prompt,
38
+ guidance_scale=guidance,
39
+ height=height,
40
+ width=width,
41
+ num_inference_steps=num_inference_steps,
42
+ seed=seed
43
+ )
44
+ return image
45
+
46
+ # Function to update hyperparameters dynamically
47
+ def update_hyperparameters(model_name):
48
+ default_params = get_default_hyperparameters(model_name)
49
+ return (default_params['guidance_scale'],
50
+ default_params['width'],
51
+ default_params['height'],
52
+ default_params['num_inference_steps'],
53
+ default_params['seed'])
54
+
55
+ # Function to expand the idea using the selected prompt generation model
56
+ def expand_idea(promptgen_model, idea_text):
57
+ print(f"Expanding idea with model: {promptgen_model}")
58
+ client = InferenceClient(model=promptgen_model, headers=headers, token=hf_token)
59
+ response = client.chat_completion(
60
+ messages=[
61
+ {
62
+ 'role': 'user',
63
+ 'content': f'For the given idea, generate a text prompt to generate an image from a text to image generator. Be creative and include both subject and style prompts into one. Do not, explain your decisions. Idea: {idea_text}',
64
+ },
65
+ ],
66
+ max_tokens=80,
67
+ temperature=1.1, # Set temperature higher for dynamic responses
68
+ top_p=0.9,
69
+ )
70
+ expanded_prompt = response.choices[0].message.content # Assuming this is how the prompt is expanded
71
+ return expanded_prompt
72
+
73
+ # Interface for generating images and expanding idea into a prompt
74
+ def run_interface():
75
+ with gr.Blocks() as iface:
76
+ with gr.Row():
77
+ # Image generation controls
78
+ with gr.Column(scale=1):
79
+ model_dropdown = gr.Dropdown(choices=imagegen_models_list, label="Image Model", value=imagegen_models_list[0])
80
+ prompt_textbox = gr.Textbox(label="Prompt", lines=5, value="Astronaut floating in space")
81
+
82
+ # Initial default values based on the first model
83
+ default_params = get_default_hyperparameters(model_dropdown.value)
84
+
85
+ guidance_slider = gr.Slider(0, 10, step=0.1, label="Guidance Scale", value=default_params['guidance_scale'])
86
+ width_slider = gr.Slider(256, 2048, step=32, label="Width", value=default_params['width'])
87
+ height_slider = gr.Slider(256, 2048, step=32, label="Height", value=default_params['height'])
88
+ steps_slider = gr.Slider(1, 100, step=1, label="Number of Inference Steps", value=default_params['num_inference_steps'])
89
+ seed_number = gr.Number(label="Seed", value=default_params['seed'])
90
+
91
+ # Update sliders based on model selection
92
+ model_dropdown.change(
93
+ fn=update_hyperparameters,
94
+ inputs=model_dropdown,
95
+ outputs=[guidance_slider, width_slider, height_slider, steps_slider, seed_number]
96
+ )
97
+
98
+ generate_button = gr.Button("Generate Image")
99
+ output_image = gr.Image(type="pil", format="png")
100
+
101
+ generate_button.click(
102
+ fn=generate_image,
103
+ inputs=[model_dropdown, prompt_textbox, guidance_slider, width_slider, height_slider, steps_slider, seed_number],
104
+ outputs=output_image
105
+ )
106
+
107
+ # Prompt expansion section in a collapsible panel
108
+ with gr.Column(scale=1):
109
+ gr.Markdown("**Expand an Idea into Prompt**")
110
+ with gr.Accordion("Your Image Idea", open=False):
111
+ promptgen_dropdown = gr.Dropdown(choices=promptgen_models_list, label="Prompt Generation Model", value=promptgen_models_list[0])
112
+ idea_textbox = gr.Textbox(label="Enter your idea/sketch (rough text)", lines=5, placeholder="Describe your idea here...")
113
+ expand_button = gr.Button("Expand to Prompt")
114
+
115
+ # Expand idea and directly output to the main prompt textbox
116
+ expand_button.click(
117
+ fn=expand_idea,
118
+ inputs=[promptgen_dropdown, idea_textbox],
119
+ outputs=prompt_textbox # Updating the prompt_textbox directly
120
+ )
121
+
122
+ iface.launch()
123
+
124
+ run_interface()
utils.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # util.py
2
+
3
+ def get_default_hyperparameters(model_name):
4
+ """
5
+ Returns default hyperparameters based on the model name.
6
+
7
+ Args:
8
+ model_name (str): Name of the selected model.
9
+
10
+ Returns:
11
+ dict: A dictionary of default hyperparameters.
12
+ """
13
+ # Define default hyperparameters for each model
14
+ default_hyperparameters = {
15
+ "black-forest-labs/FLUX.1-schnell": {
16
+ "guidance_scale": 0,
17
+ "width": 1024,
18
+ "height": 1024,
19
+ "num_inference_steps": 4,
20
+ "seed": 0
21
+ },
22
+ "black-forest-labs/FLUX.1-dev": {
23
+ "guidance_scale": 3.5,
24
+ "width": 1024,
25
+ "height": 1024,
26
+ "num_inference_steps": 28,
27
+ "seed": 0
28
+ },
29
+ "strangerzonehf/Flux-Midjourney-Mix2-LoRA": {
30
+ "guidance_scale": 3.5,
31
+ "width": 1024,
32
+ "height": 1024,
33
+ "num_inference_steps": 28,
34
+ "seed": 0
35
+ },
36
+ "stabilityai/stable-diffusion-3.5-large": {
37
+ "guidance_scale": 4.5,
38
+ "width": 1024,
39
+ "height": 1024,
40
+ "num_inference_steps": 35,
41
+ "seed": 0
42
+ },
43
+ "stabilityai/stable-diffusion-xl-base-1.0": {
44
+ "guidance_scale": 7,
45
+ "width": 1024,
46
+ "height": 1024,
47
+ "num_inference_steps": 30,
48
+ "seed": 0
49
+ },
50
+ "stable-diffusion-v1-5/stable-diffusion-v1-5": {
51
+ "guidance_scale": 5.0,
52
+ "width": 512,
53
+ "height": 512,
54
+ "num_inference_steps": 20,
55
+ "seed": 0
56
+ }
57
+ }
58
+
59
+ # Return the hyperparameters for the selected model or a default set if not found
60
+ return default_hyperparameters.get(model_name, {
61
+ "guidance_scale": 7.5,
62
+ "width": 512,
63
+ "height": 512,
64
+ "num_inference_steps": 20,
65
+ "seed": 42
66
+ })