patrickvonplaten commited on
Commit
d7c590b
·
1 Parent(s): f56edba
Files changed (2) hide show
  1. convert_flax_to_pt.py +5 -148
  2. parti_prompts.py +2 -2
convert_flax_to_pt.py CHANGED
@@ -2,109 +2,16 @@ import argparse
2
  import json
3
  import os
4
  import shutil
 
5
  from tempfile import TemporaryDirectory
6
  from typing import List, Optional
 
7
 
8
  from huggingface_hub import CommitInfo, CommitOperationAdd, Discussion, HfApi, hf_hub_download
9
  from huggingface_hub.file_download import repo_folder_name
10
 
11
 
12
- class AlreadyExists(Exception):
13
- pass
14
-
15
-
16
- def is_index_stable_diffusion_like(config_dict):
17
- if "_class_name" not in config_dict:
18
- return False
19
-
20
- compatible_classes = [
21
- "AltDiffusionImg2ImgPipeline",
22
- "AltDiffusionPipeline",
23
- "CycleDiffusionPipeline",
24
- "StableDiffusionImageVariationPipeline",
25
- "StableDiffusionImg2ImgPipeline",
26
- "StableDiffusionInpaintPipeline",
27
- "StableDiffusionInpaintPipelineLegacy",
28
- "StableDiffusionPipeline",
29
- "StableDiffusionPipelineSafe",
30
- "StableDiffusionUpscalePipeline",
31
- "VersatileDiffusionDualGuidedPipeline",
32
- "VersatileDiffusionImageVariationPipeline",
33
- "VersatileDiffusionPipeline",
34
- "VersatileDiffusionTextToImagePipeline",
35
- "OnnxStableDiffusionImg2ImgPipeline",
36
- "OnnxStableDiffusionInpaintPipeline",
37
- "OnnxStableDiffusionInpaintPipelineLegacy",
38
- "OnnxStableDiffusionPipeline",
39
- "StableDiffusionOnnxPipeline",
40
- "FlaxStableDiffusionPipeline",
41
- ]
42
- return config_dict["_class_name"] in compatible_classes
43
-
44
-
45
- def convert_single(model_id: str, folder: str) -> List["CommitOperationAdd"]:
46
- config_file = "model_index.json"
47
- # os.makedirs(os.path.join(folder, "scheduler"), exist_ok=True)
48
- model_index_file = hf_hub_download(repo_id=model_id, filename="model_index.json")
49
-
50
- with open(model_index_file, "r") as f:
51
- index_dict = json.load(f)
52
- if index_dict.get("feature_extractor", None) is None:
53
- print(f"{model_id} has no feature extractor")
54
- return False, False
55
-
56
- if index_dict["feature_extractor"][-1] != "CLIPFeatureExtractor":
57
- print(f"{model_id} is not out of date or is not CLIP")
58
- return False, False
59
-
60
- # old_config_file = hf_hub_download(repo_id=model_id, filename=config_file)
61
- old_config_file = model_index_file
62
-
63
- new_config_file = os.path.join(folder, config_file)
64
- success = convert_file(old_config_file, new_config_file)
65
- if success:
66
- operations = [CommitOperationAdd(path_in_repo=config_file, path_or_fileobj=new_config_file)]
67
- model_type = success
68
- return operations, model_type
69
- else:
70
- return False, False
71
-
72
-
73
- def convert_file(
74
- old_config: str,
75
- new_config: str,
76
- ):
77
- with open(old_config, "r") as f:
78
- old_dict = json.load(f)
79
-
80
- old_dict["feature_extractor"][-1] = "CLIPImageProcessor"
81
- # if "clip_sample" not in old_dict:
82
- # print("Make scheduler DDIM compatible")
83
- # old_dict["clip_sample"] = False
84
- # else:
85
- # print("No matching config")
86
- # return False
87
-
88
- with open(new_config, 'w') as f:
89
- json_str = json.dumps(old_dict, indent=2, sort_keys=True) + "\n"
90
- f.write(json_str)
91
-
92
- return "Stable Diffusion"
93
-
94
-
95
- def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discussion"]:
96
- try:
97
- discussions = api.get_repo_discussions(repo_id=model_id)
98
- except Exception:
99
- return None
100
- for discussion in discussions:
101
- if discussion.status == "open" and discussion.is_pull_request and discussion.title == pr_title:
102
- return discussion
103
-
104
-
105
  def convert(api: "HfApi", model_id: str, force: bool = False) -> Optional["CommitInfo"]:
106
- # pr_title = "Correct `sample_size` of {}'s unet to have correct width and height default"
107
- pr_title = "Fix deprecation warning by changing `CLIPFeatureExtractor` to `CLIPImageProcessor`."
108
  info = api.model_info(model_id)
109
  filenames = set(s.rfilename for s in info.siblings)
110
 
@@ -134,54 +41,9 @@ def convert(api: "HfApi", model_id: str, force: bool = False) -> Optional["Commi
134
  folder_path=folder,
135
  repo_id=model_id,
136
  repo_type="model",
 
137
  )
138
- )
139
-
140
- new_pr = None
141
- try:
142
- operations = None
143
- pr = previous_pr(api, model_id, pr_title)
144
- if pr is not None and not force:
145
- url = f"https://huggingface.co/{model_id}/discussions/{pr.num}"
146
- new_pr = pr
147
- raise AlreadyExists(f"Model {model_id} already has an open PR check out {url}")
148
- else:
149
- operations, model_type = convert_single(model_id, folder)
150
-
151
- if operations:
152
- pr_title = pr_title.format(model_type)
153
- # if model_type == "Stable Diffusion 1":
154
- # sample_size = 64
155
- # image_size = 512
156
- # elif model_type == "Stable Diffusion 2":
157
- # sample_size = 96
158
- # image_size = 768
159
-
160
- # pr_description = (
161
- # f"Since `diffusers==0.9.0` the width and height is automatically inferred from the `sample_size` attribute of your unet's config. It seems like your diffusion model has the same architecture as {model_type} which means that when using this model, by default an image size of {image_size}x{image_size} should be generated. This in turn means the unet's sample size should be **{sample_size}**. \n\n In order to suppress to update your configuration on the fly and to suppress the deprecation warning added in this PR: https://github.com/huggingface/diffusers/pull/1406/files#r1035703505 it is strongly recommended to merge this PR."
162
- # )
163
- contributor = model_id.split("/")[0]
164
- pr_description = (
165
- f"Hey {contributor} 👋, \n\n Your model repository seems to contain logic to load a feature extractor that is deprecated, which you should notice by seeing the warning: "
166
- "\n\n ```\ntransformers/models/clip/feature_extraction_clip.py:28: FutureWarning: The class CLIPFeatureExtractor is deprecated and will be removed in version 5 of Transformers. "
167
- f"Please use CLIPImageProcessor instead. warnings.warn(\n``` \n\n when running `pipe = DiffusionPipeline.from_pretrained({model_id})`."
168
- "This PR makes sure that the warning does not show anymore by replacing `CLIPFeatureExtractor` with `CLIPImageProcessor`. This will certainly not change or break your checkpoint, but only"
169
- "make sure that everything is up to date. \n\n Best, the 🧨 Diffusers team."
170
- )
171
- new_pr = api.create_commit(
172
- repo_id=model_id,
173
- operations=operations,
174
- commit_message=pr_title,
175
- commit_description=pr_description,
176
- create_pr=True,
177
- )
178
- print(f"Pr created at {new_pr.pr_url}")
179
- else:
180
- print(f"No files to convert for {model_id}")
181
- finally:
182
- shutil.rmtree(folder)
183
- return new_pr
184
-
185
 
186
  if __name__ == "__main__":
187
  DESCRIPTION = """
@@ -196,12 +58,7 @@ if __name__ == "__main__":
196
  type=str,
197
  help="The name of the model on the hub to convert. E.g. `gpt2` or `facebook/wav2vec2-base-960h`",
198
  )
199
- parser.add_argument(
200
- "--force",
201
- action="store_true",
202
- help="Create the PR even if it already exists of if the model was already converted.",
203
- )
204
  args = parser.parse_args()
205
  model_id = args.model_id
206
  api = HfApi()
207
- convert(api, model_id, force=args.force)
 
2
  import json
3
  import os
4
  import shutil
5
+ import torch
6
  from tempfile import TemporaryDirectory
7
  from typing import List, Optional
8
+ from diffusers import StableDiffusionPipeline, ControlNetModel
9
 
10
  from huggingface_hub import CommitInfo, CommitOperationAdd, Discussion, HfApi, hf_hub_download
11
  from huggingface_hub.file_download import repo_folder_name
12
 
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  def convert(api: "HfApi", model_id: str, force: bool = False) -> Optional["CommitInfo"]:
 
 
15
  info = api.model_info(model_id)
16
  filenames = set(s.rfilename for s in info.siblings)
17
 
 
41
  folder_path=folder,
42
  repo_id=model_id,
43
  repo_type="model",
44
+ create_pr=True,
45
  )
46
+ print(model_id)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
  if __name__ == "__main__":
49
  DESCRIPTION = """
 
58
  type=str,
59
  help="The name of the model on the hub to convert. E.g. `gpt2` or `facebook/wav2vec2-base-960h`",
60
  )
 
 
 
 
 
61
  args = parser.parse_args()
62
  model_id = args.model_id
63
  api = HfApi()
64
+ convert(api, model_id)
parti_prompts.py CHANGED
@@ -28,8 +28,8 @@ def get_karlo_eval(ckpt):
28
  pipe = DiffusionPipeline.from_pretrained(ckpt, torch_dtype=torch.float16)
29
  pipe.to("cuda")
30
 
31
- def karlo_eval(prompt):
32
- images = pipe(prompt, prior_num_inference_steps=50, decoder_num_inference_steps=NUM_INFERENCE_STEPS).images
33
  return images
34
 
35
  return karlo_eval
 
28
  pipe = DiffusionPipeline.from_pretrained(ckpt, torch_dtype=torch.float16)
29
  pipe.to("cuda")
30
 
31
+ def karlo_eval(prompt, generator=None):
32
+ images = pipe(prompt, prior_num_inference_steps=50, generator=generator, decoder_num_inference_steps=NUM_INFERENCE_STEPS).images
33
  return images
34
 
35
  return karlo_eval