JOY-Huang commited on
Commit
fbecea1
1 Parent(s): cef57ef
Files changed (1) hide show
  1. app.py +9 -3
app.py CHANGED
@@ -22,6 +22,12 @@ from module.ip_adapter.resampler import Resampler
22
  from module.aggregator import Aggregator
23
  from pipelines.sdxl_instantir import InstantIRPipeline, LCM_LORA_MODULES, PREVIEWER_LORA_MODULES
24
 
 
 
 
 
 
 
25
 
26
  transform = transforms.Compose([
27
  transforms.Resize(1024, interpolation=transforms.InterpolationMode.BILINEAR),
@@ -66,7 +72,7 @@ image_proj_model = Resampler(
66
  init_ip_adapter_in_unet(
67
  unet,
68
  image_proj_model,
69
- "InstantX/InstantIR/adapter.pt",
70
  adapter_tokens=64,
71
  )
72
  print("Initializing InstantIR...")
@@ -77,7 +83,7 @@ pipe = InstantIRPipeline(
77
 
78
  # Add Previewer LoRA.
79
  lora_state_dict, alpha_dict = StableDiffusionXLPipeline.lora_state_dict(
80
- "InstantX/InstantIR/previewer_lora_weights.bin",
81
  # weight_name="previewer_lora_weights.bin",
82
 
83
  )
@@ -145,7 +151,7 @@ lcm_scheduler = LCMSingleStepScheduler.from_config(pipe.scheduler.config)
145
  # Load weights.
146
  print("Loading checkpoint...")
147
  aggregator_state_dict = torch.load(
148
- "InstantX/InstantIR/aggregator.pt",
149
  map_location="cpu"
150
  )
151
  aggregator.load_state_dict(aggregator_state_dict, strict=True)
 
22
  from module.aggregator import Aggregator
23
  from pipelines.sdxl_instantir import InstantIRPipeline, LCM_LORA_MODULES, PREVIEWER_LORA_MODULES
24
 
25
+ from huggingface_hub import hf_hub_download
26
+
27
+ hf_hub_download(repo_id="InstantX/InstantIR", filename="adapter.pt", local_dir="./checkpoints")
28
+ hf_hub_download(repo_id="InstantX/InstantIR", filename="aggregator.pt", local_dir="./checkpoints")
29
+ hf_hub_download(repo_id="InstantX/InstantIR", filename="previewer_lora_weights.bin", local_dir="./checkpoints")
30
+
31
 
32
  transform = transforms.Compose([
33
  transforms.Resize(1024, interpolation=transforms.InterpolationMode.BILINEAR),
 
72
  init_ip_adapter_in_unet(
73
  unet,
74
  image_proj_model,
75
+ "checkpoints/adapter.pt",
76
  adapter_tokens=64,
77
  )
78
  print("Initializing InstantIR...")
 
83
 
84
  # Add Previewer LoRA.
85
  lora_state_dict, alpha_dict = StableDiffusionXLPipeline.lora_state_dict(
86
+ "checkpoints/previewer_lora_weights.bin",
87
  # weight_name="previewer_lora_weights.bin",
88
 
89
  )
 
151
  # Load weights.
152
  print("Loading checkpoint...")
153
  aggregator_state_dict = torch.load(
154
+ "checkpoints/aggregator.pt",
155
  map_location="cpu"
156
  )
157
  aggregator.load_state_dict(aggregator_state_dict, strict=True)