Duskfallcrew commited on
Commit
9eb8ba7
·
verified ·
1 Parent(s): ea14a31

Update app.py

Browse files

I DONT EVEN KNOW IF THIS WORKS BUT THIS IS THE OLD VERSION BUT YET NEW SINCE THE IDE KEPT EATING MY FILES

Files changed (1) hide show
  1. app.py +143 -138
app.py CHANGED
@@ -1,16 +1,10 @@
1
- # ---------------------- IMPORTS ----------------------
2
- # Core functionality
3
  import os
4
  import gradio as gr
5
  import torch
6
  from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, AutoencoderKL
7
  from transformers import CLIPTextModel
8
-
9
- # Model handling
10
  from safetensors.torch import load_file
11
  from collections import OrderedDict
12
-
13
- # Utilities
14
  import re
15
  import json
16
  import gdown
@@ -26,142 +20,153 @@ import shutil
26
  import hashlib
27
  from datetime import datetime
28
  from typing import Dict, List, Optional
 
29
 
30
- # Hugging Face integration
31
- from huggingface_hub import login, HfApi
32
- from types import SimpleNamespace
 
 
 
 
 
33
 
34
  # ---------------------- UTILITY FUNCTIONS ----------------------
35
- def is_valid_url(url):
36
- """Check if a string is a valid URL."""
37
- try:
38
- result = urlparse(url)
39
- return all([result.scheme, result.netloc])
40
- except:
41
- return False
 
 
 
42
 
43
- def get_filename(url):
44
- """Extract filename from URL with error handling."""
45
- try:
46
- response = requests.get(url, stream=True)
47
- response.raise_for_status()
48
- if 'content-disposition' in response.headers:
49
- return re.findall('filename="?([^"]+)"?', response.headers['content-disposition'])[0]
50
- return os.path.basename(urlparse(url).path)
51
- except Exception as e:
52
- print(f"Error getting filename: {e}")
53
- return "downloaded_model"
54
-
55
- def get_supported_extensions():
56
- """Return supported model extensions."""
57
- return (".ckpt", ".safetensors", ".pt", ".pth")
58
-
59
- # ---------------------- MODEL CONVERSION CORE ----------------------
60
- class ConversionHistory:
61
- """Track conversion attempts and provide optimization suggestions."""
62
- def __init__(self, history_file="conversion_history.json"):
63
- self.history_file = history_file
64
- self.history = self._load_history()
65
-
66
- def _load_history(self):
67
- try:
68
- with open(self.history_file, 'r') as f:
69
- return json.load(f)
70
- except:
71
- return []
72
-
73
- def add_entry(self, model_path, settings, success, message):
74
- entry = {
75
- "timestamp": datetime.now().isoformat(),
76
- "model": model_path,
77
- "settings": settings,
78
- "success": success,
79
- "message": message
80
- }
81
- self.history.append(entry)
82
- self._save_history()
83
-
84
- def get_optimization_suggestions(self, model_path):
85
- """Generate suggestions based on conversion history."""
86
- suggestions = []
87
- for entry in self.history:
88
- if entry["model"] == model_path and not entry["success"]:
89
- suggestions.append(f"Previous failure: {entry['message']}")
90
- return suggestions
91
-
92
- def convert_model(model_to_load, save_precision_as, epoch, global_step, reference_model, fp16, use_xformers, hf_token, orgs_name, model_name, make_private, output_widget):
93
- """Main conversion logic with error handling."""
94
- history = ConversionHistory()
95
  try:
96
- # Conversion steps here
97
- return "Conversion successful!"
98
- except Exception as e:
99
- history.add_entry(model_to_load, locals(), False, str(e))
100
- return f" Error: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
  # ---------------------- GRADIO INTERFACE ----------------------
103
- # ---------------------- GRADIO INTERFACE ----------------------
104
- # Temporarily disabled theme configuration
105
- # def build_theme(theme_name, font):
106
- # """Create accessible theme with dynamic settings."""
107
- # return gr.themes.Default().set(
108
- # primary_hue="violet" if "dark" in theme_name else "indigo",
109
- # font=[font, "ui-sans-serif", "sans-serif"],
110
- # button_primary_background="*primary_300",
111
- # button_primary_text_color="white",
112
- # background_fill="*neutral_50" if "light" in theme_name else "*neutral_950"
113
- # )
114
-
115
- with gr.Blocks(
116
- css="""
117
- .single-column {max-width: 800px; margin: 0 auto;}
118
- .output-panel {background: rgba(0,0,0,0.05); padding: 20px; border-radius: 8px;}
119
- """,
120
- # theme=build_theme("dark", "Arial") # Theme disabled temporarily
121
- ) as demo:
122
-
123
- # Accessibility Controls
124
- with gr.Accordion(" Accessibility Settings", open=False):
125
- with gr.Row():
126
- theme_selector = gr.Dropdown(
127
- ["Dark Mode", "Light Mode", "High Contrast"],
128
- label="Color Theme",
129
- value="Dark Mode"
130
- )
131
- font_selector = gr.Dropdown(
132
- ["Arial", "OpenDyslexic", "Comic Neue"],
133
- label="Font Choice",
134
- value="Arial"
135
- )
136
- font_size = gr.Slider(12, 24, value=16, label="Font Size (px)")
137
-
138
- # Main Content
139
- with gr.Column(elem_classes="single-column"):
140
- gr.Markdown("""
141
- # 🎨 SDXL Model Converter
142
- Convert models between formats with accessibility in mind!
143
-
144
- ### Features:
145
- - 🧠 Memory-efficient conversions
146
- - ♿ Dyslexia-friendly fonts
147
- - 🌓 Dark/Light modes
148
- - 🤗 HF Hub integration
149
- """)
150
-
151
- # Input Fields
152
- model_to_load = gr.Textbox(label="Model Path/URL")
153
- save_precision_as = gr.Dropdown(["float32", "float16"], label="Precision")
154
-
155
- with gr.Row():
156
- epoch = gr.Number(label="Epoch", value=0)
157
- global_step = gr.Number(label="Global Step", value=0)
158
-
159
- # Conversion Button
160
- convert_btn = gr.Button("Convert", variant="primary")
161
-
162
- # Output Panel
163
- output = gr.Markdown(elem_classes="output-panel")
164
-
165
- # ---------------------- MAIN EXECUTION ----------------------
166
- if __name__ == "__main__":
167
- demo.launch(share=True)
 
 
 
1
  import os
2
  import gradio as gr
3
  import torch
4
  from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, AutoencoderKL
5
  from transformers import CLIPTextModel
 
 
6
  from safetensors.torch import load_file
7
  from collections import OrderedDict
 
 
8
  import re
9
  import json
10
  import gdown
 
20
  import hashlib
21
  from datetime import datetime
22
  from typing import Dict, List, Optional
23
+ from huggingface_hub import login, HfApi, validate_repo_id, HfHubHTTPError
24
 
25
+ # ---------------------- DEPENDENCIES ----------------------
26
+ def install_dependencies_gradio():
27
+ """Installs the necessary dependencies for the Gradio app. Run this ONCE."""
28
+ try:
29
+ subprocess.run(["pip", "install", "-U", "torch", "diffusers", "transformers", "accelerate", "safetensors", "huggingface_hub", "xformers"])
30
+ print("Dependencies installed successfully.")
31
+ except Exception as e:
32
+ print(f"Error installing dependencies: {e}")
33
 
34
  # ---------------------- UTILITY FUNCTIONS ----------------------
35
+ def get_save_dtype(save_precision_as):
36
+ """Determines the save dtype based on the user's choice."""
37
+ if save_precision_as == "fp16":
38
+ return torch.float16
39
+ elif save_precision_as == "bf16":
40
+ return torch.bfloat16
41
+ elif save_precision_as == "float":
42
+ return torch.float32
43
+ else:
44
+ return None
45
 
46
+ def determine_load_checkpoint(model_to_load):
47
+ """Determines if the model to load is a checkpoint or a Diffusers model."""
48
+ if model_to_load.endswith('.ckpt') or model_to_load.endswith('.safetensors'):
49
+ return True
50
+ elif os.path.isdir(model_to_load):
51
+ required_folders = {"unet", "text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2", "scheduler", "vae"}
52
+ if required_folders.issubset(set(os.listdir(model_to_load))) and os.path.isfile(os.path.join(model_to_load, "model_index.json")):
53
+ return False
54
+ return None
55
+
56
+ def increment_filename(filename):
57
+ """Increments the filename to avoid overwriting existing files."""
58
+ base, ext = os.path.splitext(filename)
59
+ counter = 1
60
+ while os.path.exists(filename):
61
+ filename = f"{base}({counter}){ext}"
62
+ counter += 1
63
+ return filename
64
+
65
+ def create_model_repo(api, user, orgs_name, model_name, make_private=False):
66
+ """Creates a Hugging Face model repository if it doesn't exist."""
67
+ repo_id = f"{orgs_name}/{model_name.strip()}" if orgs_name else f"{user['name']}/{model_name.strip()}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  try:
69
+ validate_repo_id(repo_id)
70
+ api.create_repo(repo_id=repo_id, repo_type="model", private=make_private)
71
+ print(f"Model repo '{repo_id}' created.")
72
+ except HfHubHTTPError:
73
+ print(f"Model repo '{repo_id}' already exists.")
74
+
75
+ return repo_id
76
+
77
+ # ---------------------- MODEL LOADING AND CONVERSION ----------------------
78
+ def load_sdxl_model(model_to_load, is_load_checkpoint, load_dtype):
79
+ """Loads the SDXL model from a checkpoint or Diffusers model."""
80
+ model_load_message = "checkpoint" if is_load_checkpoint else "Diffusers" + (" as fp16" if load_dtype == torch.float16 else "")
81
+ print(f"Loading {model_load_message}: {model_to_load}")
82
+
83
+ if is_load_checkpoint:
84
+ return load_from_sdxl_checkpoint(model_to_load)
85
+ else:
86
+ return load_sdxl_from_diffusers(model_to_load, load_dtype)
87
+
88
+ def load_from_sdxl_checkpoint(model_to_load):
89
+ """Loads the SDXL model components from a checkpoint file."""
90
+ # Implement loading logic here
91
+ text_encoder1, text_encoder2, vae, unet = None, None, None, None
92
+ # Example loading logic (replace with actual loading code)
93
+ # text_encoder1, text_encoder2, vae, unet = sdxl_model_util.load_models_from_sdxl_checkpoint("sdxl_base_v1-0", model_to_load, "cpu")
94
+ print(f"Loaded from checkpoint: {model_to_load}")
95
+ return text_encoder1, text_encoder2, vae, unet
96
+
97
+ def load_sdxl_from_diffusers(model_to_load, load_dtype):
98
+ """Loads an SDXL model from a Diffusers model directory."""
99
+ pipeline = StableDiffusionXLPipeline.from_pretrained(model_to_load, torch_dtype=load_dtype)
100
+ text_encoder1 = pipeline.text_encoder
101
+ text_encoder2 = pipeline.text_encoder_2
102
+ vae = pipeline.vae
103
+ unet = pipeline.unet
104
+
105
+ return text_encoder1, text_encoder2, vae, unet
106
+
107
+ def convert_and_save_sdxl_model(model_to_load, is_save_checkpoint, loaded_model_data, save_dtype):
108
+ """Converts and saves the SDXL model as either a checkpoint or a Diffusers model."""
109
+ text_encoder1, text_encoder2, vae, unet = loaded_model_data
110
+ if is_save_checkpoint:
111
+ save_sdxl_as_checkpoint(model_to_load, text_encoder1, text_encoder2, vae, unet, save_dtype)
112
+ else:
113
+ save_sdxl_as_diffusers(model_to_load, text_encoder1, text_encoder2, vae, unet, save_dtype)
114
+
115
+ def save_sdxl_as_checkpoint(model_to_save, text_encoder1, text_encoder2, vae, unet, save_dtype):
116
+ """Saves the SDXL model components as a checkpoint file."""
117
+ # Implement saving logic here
118
+ print(f"Model saved as checkpoint: {model_to_save}")
119
+
120
+ def save_sdxl_as_diffusers(model_to_save, text_encoder1, text_encoder2, vae, unet, save_dtype):
121
+ """Saves the SDXL model as a Diffusers model."""
122
+ pipeline = StableDiffusionXLPipeline(
123
+ vae=vae,
124
+ text_encoder=text_encoder1,
125
+ text_encoder_2=text_encoder2,
126
+ unet=unet
127
+ )
128
+ pipeline.save_pretrained(model_to_save)
129
+ print(f"Model saved as Diffusers format: {model_to_save}")
130
+
131
+ # ---------------------- UPLOAD FUNCTION ----------------------
132
+ def upload_to_huggingface(model_path, hf_token, orgs_name, model_name, make_private):
133
+ """Uploads a model to the Hugging Face Hub."""
134
+ login(hf_token, add_to_git_credential=True)
135
+ api = HfApi()
136
+ user = api.whoami(hf_token)
137
+ model_repo = create_model_repo(api, user, orgs_name, model_name, make_private)
138
+
139
+ # Upload logic here
140
+ print(f"Model uploaded to: https://huggingface.co/{model_repo}")
141
 
142
  # ---------------------- GRADIO INTERFACE ----------------------
143
+ def main(model_to_load, save_precision_as, epoch, global_step, reference_model, output_path, fp16, hf_token, orgs_name, model_name, make_private):
144
+ """Main function orchestrating the entire process."""
145
+ load_dtype = get_save_dtype(save_precision_as)
146
+ is_load_checkpoint = determine_load_checkpoint(model_to_load)
147
+ is_save_checkpoint = not is_load_checkpoint
148
+
149
+ loaded_model_data = load_sdxl_model(model_to_load, is_load_checkpoint, load_dtype)
150
+ convert_and_save_sdxl_model(model_to_load, is_save_checkpoint, loaded_model_data, load_dtype)
151
+ upload_to_huggingface(output_path, hf_token, orgs_name, model_name, make_private)
152
+
153
+ return "Conversion and upload completed successfully!"
154
+
155
+ with gr.Blocks() as demo:
156
+ model_to_load = gr.Textbox(label="Model to Load (Checkpoint or Diffusers)", placeholder="Path to model")
157
+ save_precision_as = gr.Dropdown(choices=["fp16", "bf16", "float"], label="Save Precision As")
158
+ epoch = gr.Number(value=0, label="Epoch to Write (Checkpoint)")
159
+ global_step = gr.Number(value=0, label="Global Step to Write (Checkpoint)")
160
+ reference_model = gr.Textbox(label="Reference Diffusers Model", placeholder="e.g., stabilityai/stable-diffusion-xl-base-1.0")
161
+ output_path = gr.Textbox(label="Output Path", value="/content/output")
162
+ hf_token = gr.Textbox(label="Hugging Face Token", placeholder="Your Hugging Face write token")
163
+ orgs_name = gr.Textbox(label="Organization Name (Optional)", placeholder="Your organization name")
164
+ model_name = gr.Textbox(label="Model Name", placeholder="The name of your model on Hugging Face")
165
+ make_private = gr.Checkbox(label="Make Repository Private", value=False)
166
+
167
+ convert_button = gr.Button("Convert and Upload")
168
+ output = gr.Markdown()
169
+
170
+ convert_button.click(fn=main, inputs=[model_to_load, save_precision_as, epoch, global_step, reference_model, output_path, fp16, hf_token, orgs_name, model_name, make_private], outputs=output)
171
+
172
+ demo.launch()