Spaces:
Running
on
Zero
Running
on
Zero
fix dir
Browse files
app.py
CHANGED
@@ -22,6 +22,12 @@ from module.ip_adapter.resampler import Resampler
|
|
22 |
from module.aggregator import Aggregator
|
23 |
from pipelines.sdxl_instantir import InstantIRPipeline, LCM_LORA_MODULES, PREVIEWER_LORA_MODULES
|
24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
|
26 |
transform = transforms.Compose([
|
27 |
transforms.Resize(1024, interpolation=transforms.InterpolationMode.BILINEAR),
|
@@ -66,7 +72,7 @@ image_proj_model = Resampler(
|
|
66 |
init_ip_adapter_in_unet(
|
67 |
unet,
|
68 |
image_proj_model,
|
69 |
-
"
|
70 |
adapter_tokens=64,
|
71 |
)
|
72 |
print("Initializing InstantIR...")
|
@@ -77,7 +83,7 @@ pipe = InstantIRPipeline(
|
|
77 |
|
78 |
# Add Previewer LoRA.
|
79 |
lora_state_dict, alpha_dict = StableDiffusionXLPipeline.lora_state_dict(
|
80 |
-
"
|
81 |
# weight_name="previewer_lora_weights.bin",
|
82 |
|
83 |
)
|
@@ -145,7 +151,7 @@ lcm_scheduler = LCMSingleStepScheduler.from_config(pipe.scheduler.config)
|
|
145 |
# Load weights.
|
146 |
print("Loading checkpoint...")
|
147 |
aggregator_state_dict = torch.load(
|
148 |
-
"
|
149 |
map_location="cpu"
|
150 |
)
|
151 |
aggregator.load_state_dict(aggregator_state_dict, strict=True)
|
|
|
22 |
from module.aggregator import Aggregator
|
23 |
from pipelines.sdxl_instantir import InstantIRPipeline, LCM_LORA_MODULES, PREVIEWER_LORA_MODULES
|
24 |
|
25 |
+
from huggingface_hub import hf_hub_download
|
26 |
+
|
27 |
+
hf_hub_download(repo_id="InstantX/InstantIR", filename="adapter.pt", local_dir="./checkpoints")
|
28 |
+
hf_hub_download(repo_id="InstantX/InstantIR", filename="aggregator.pt", local_dir="./checkpoints")
|
29 |
+
hf_hub_download(repo_id="InstantX/InstantIR", filename="previewer_lora_weights.bin", local_dir="./checkpoints")
|
30 |
+
|
31 |
|
32 |
transform = transforms.Compose([
|
33 |
transforms.Resize(1024, interpolation=transforms.InterpolationMode.BILINEAR),
|
|
|
72 |
init_ip_adapter_in_unet(
|
73 |
unet,
|
74 |
image_proj_model,
|
75 |
+
"checkpoints/adapter.pt",
|
76 |
adapter_tokens=64,
|
77 |
)
|
78 |
print("Initializing InstantIR...")
|
|
|
83 |
|
84 |
# Add Previewer LoRA.
|
85 |
lora_state_dict, alpha_dict = StableDiffusionXLPipeline.lora_state_dict(
|
86 |
+
"checkpoints/previewer_lora_weights.bin",
|
87 |
# weight_name="previewer_lora_weights.bin",
|
88 |
|
89 |
)
|
|
|
151 |
# Load weights.
|
152 |
print("Loading checkpoint...")
|
153 |
aggregator_state_dict = torch.load(
|
154 |
+
"checkpoints/aggregator.pt",
|
155 |
map_location="cpu"
|
156 |
)
|
157 |
aggregator.load_state_dict(aggregator_state_dict, strict=True)
|