SunderAli17 commited on
Commit
f724fc0
·
verified ·
1 Parent(s): 52f9ca7

Update flux/model.py

Browse files
Files changed (1) hide show
  1. flux/model.py +8 -7
flux/model.py CHANGED
@@ -79,9 +79,9 @@ class Flux(nn.Module):
79
 
80
  self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
81
 
82
- self.toonmage_ca = None
83
- self.toonmage_double_interval = 2
84
- self.toonmage_single_interval = 4
85
 
86
  def forward(
87
  self,
@@ -115,8 +115,8 @@ class Flux(nn.Module):
115
  for i, block in enumerate(self.double_blocks):
116
  img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
117
 
118
- if i % self.toonmage_double_interval == 0 and id is not None:
119
- img = img + id_weight * self.toonmage_ca[ca_idx](id, img)
120
  ca_idx += 1
121
 
122
  img = torch.cat((txt, img), 1)
@@ -124,8 +124,8 @@ class Flux(nn.Module):
124
  x = block(img, vec=vec, pe=pe)
125
  real_img, txt = x[:, txt.shape[1]:, ...], x[:, :txt.shape[1], ...]
126
 
127
- if i % self.toonmage_single_interval == 0 and id is not None:
128
- real_img = real_img + id_weight * self.toonmage_ca[ca_idx](id, real_img)
129
  ca_idx += 1
130
 
131
  img = torch.cat((txt, real_img), 1)
@@ -133,3 +133,4 @@ class Flux(nn.Module):
133
 
134
  img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
135
  return img
 
 
79
 
80
  self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
81
 
82
+ self.pulid_ca = None
83
+ self.pulid_double_interval = 2
84
+ self.pulid_single_interval = 4
85
 
86
  def forward(
87
  self,
 
115
  for i, block in enumerate(self.double_blocks):
116
  img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
117
 
118
+ if i % self.pulid_double_interval == 0 and id is not None:
119
+ img = img + id_weight * self.pulid_ca[ca_idx](id, img)
120
  ca_idx += 1
121
 
122
  img = torch.cat((txt, img), 1)
 
124
  x = block(img, vec=vec, pe=pe)
125
  real_img, txt = x[:, txt.shape[1]:, ...], x[:, :txt.shape[1], ...]
126
 
127
+ if i % self.pulid_single_interval == 0 and id is not None:
128
+ real_img = real_img + id_weight * self.pulid_ca[ca_idx](id, real_img)
129
  ca_idx += 1
130
 
131
  img = torch.cat((txt, real_img), 1)
 
133
 
134
  img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
135
  return img
136
+