Spaces:
Sleeping
Sleeping
ChenyangSi
commited on
Commit
•
dd97a63
1
Parent(s):
97dc735
Update free_lunch_utils.py
Browse files- free_lunch_utils.py +25 -2
free_lunch_utils.py
CHANGED
@@ -93,13 +93,36 @@ def register_free_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2):
|
|
93 |
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
94 |
#print(f"in free upblock2d, hidden states shape: {hidden_states.shape}")
|
95 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
# --------------- FreeU code -----------------------
|
97 |
# Only operate on the first two stages
|
98 |
if hidden_states.shape[1] == 1280:
|
99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
|
101 |
if hidden_states.shape[1] == 640:
|
102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
|
104 |
# ---------------------------------------------------------
|
105 |
|
|
|
93 |
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
94 |
#print(f"in free upblock2d, hidden states shape: {hidden_states.shape}")
|
95 |
|
96 |
+
# # --------------- FreeU code -----------------------
|
97 |
+
# # Only operate on the first two stages
|
98 |
+
# if hidden_states.shape[1] == 1280:
|
99 |
+
# hidden_states[:,:640] = hidden_states[:,:640] * self.b1
|
100 |
+
# res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
|
101 |
+
# if hidden_states.shape[1] == 640:
|
102 |
+
# hidden_states[:,:320] = hidden_states[:,:320] * self.b2
|
103 |
+
# res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
|
104 |
+
# # ---------------------------------------------------------
|
105 |
+
|
106 |
# --------------- FreeU code -----------------------
|
107 |
# Only operate on the first two stages
|
108 |
if hidden_states.shape[1] == 1280:
|
109 |
+
hidden_mean = hidden_states.mean(1).unsqueeze(1)
|
110 |
+
B = hidden_mean.shape[0]
|
111 |
+
hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True)
|
112 |
+
hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True)
|
113 |
+
|
114 |
+
hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3)
|
115 |
+
|
116 |
+
hidden_states[:,:640] = hidden_states[:,:640] * ((self.b1 - 1 ) * hidden_mean + 1)
|
117 |
res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
|
118 |
if hidden_states.shape[1] == 640:
|
119 |
+
hidden_mean = hidden_states.mean(1).unsqueeze(1)
|
120 |
+
B = hidden_mean.shape[0]
|
121 |
+
hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True)
|
122 |
+
hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True)
|
123 |
+
hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3)
|
124 |
+
|
125 |
+
hidden_states[:,:320] = hidden_states[:,:320] * ((self.b2 - 1 ) * hidden_mean + 1)
|
126 |
res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
|
127 |
# ---------------------------------------------------------
|
128 |
|