ohayonguy committed
Commit
c50d79a
1 Parent(s): eafdcb2

fixing kdiff

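The only change here is dropping the `@flags.compile_wrap` decorators that these files carried over from k-diffusion; with no `flags` module importable in this repo, the decorated definitions fail at import time, which is presumably the "kdiff" breakage being fixed. For context, here is a minimal sketch of what a `compile_wrap`-style decorator typically does, assuming it lazily applies `torch.compile` behind an opt-in flag (`get_use_compile` and the env var name are illustrative, not this repo's code):

```python
import os
from functools import wraps

import torch


def get_use_compile():
    # Hypothetical flag check; compilation is gated behind an env var.
    return os.environ.get("USE_COMPILE", "0") == "1"


def compile_wrap(function, *args, **kwargs):
    """Lazily wrap `function` with torch.compile on first call (sketch)."""
    compiled_function = None

    @wraps(function)
    def wrapper(*call_args, **call_kwargs):
        nonlocal compiled_function
        if not get_use_compile():
            # Compilation disabled: fall through to the plain eager function.
            return function(*call_args, **call_kwargs)
        if compiled_function is None:
            # Compile once, on first use, and cache the result.
            compiled_function = torch.compile(function, *args, **kwargs)
        return compiled_function(*call_args, **call_kwargs)

    return wrapper
```

If the decorator works like this, removing it only changes how the functions run (always eager instead of optionally compiled), not what they compute.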
arch/hourglass/axial_rope.py CHANGED
@@ -21,7 +21,6 @@ def rotate_half(x):
     return x.view(*shape, d * r)
 
 
-@flags.compile_wrap
 def apply_rotary_emb(freqs, t, start_index=0, scale=1.0):
     freqs = freqs.to(t)
     rot_dim = freqs.shape[-1]
 
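With the decorator gone, `apply_rotary_emb` in `axial_rope.py` runs as an ordinary eager function. For reference, a self-contained sketch of the standard rotate-half RoPE that a function with this signature implements; the pairwise `rotate_half` below is the common convention and an assumption, not this file's exact view-based version:

```python
import torch


def rotate_half(x):
    # Pairwise rotation on the last dim: (a, b) -> (-b, a).
    x = x.unflatten(-1, (-1, 2))
    x1, x2 = x.unbind(-1)
    return torch.stack((-x2, x1), dim=-1).flatten(-2)


def apply_rotary_emb(freqs, t, start_index=0, scale=1.0):
    # Rotate only the slice [start_index, start_index + rot_dim) of t.
    freqs = freqs.to(t)
    rot_dim = freqs.shape[-1]
    end_index = start_index + rot_dim
    t_left = t[..., :start_index]
    t_mid = t[..., start_index:end_index]
    t_right = t[..., end_index:]
    t_mid = t_mid * freqs.cos() * scale + rotate_half(t_mid) * freqs.sin() * scale
    return torch.cat((t_left, t_mid, t_right), dim=-1)


t = torch.randn(1, 16, 64)   # (batch, seq, dim)
freqs = torch.randn(16, 64)  # one angle per position and channel
out = apply_rotary_emb(freqs, t)
print(out.shape)             # torch.Size([1, 16, 64])
```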
arch/hourglass/image_transformer_v2.py CHANGED
@@ -87,7 +87,6 @@ def filter_params(function, module):
 
 # Kernels
 
-@flags.compile_wrap
 def linear_geglu(x, weight, bias=None):
     x = x @ weight.mT
     if bias is not None:
@@ -96,7 +95,6 @@ def linear_geglu(x, weight, bias=None):
     return x * F.gelu(gate)
 
 
-@flags.compile_wrap
 def rms_norm(x, scale, eps):
     dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
     mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True)
@@ -104,7 +102,6 @@ def rms_norm(x, scale, eps):
     return x * scale.to(x.dtype)
 
 
-@flags.compile_wrap
 def scale_for_cosine_sim(q, k, scale, eps):
     dtype = reduce(torch.promote_types, (q.dtype, k.dtype, scale.dtype, torch.float32))
     sum_sq_q = torch.sum(q.to(dtype)**2, dim=-1, keepdim=True)
@@ -115,7 +112,6 @@ def scale_for_cosine_sim(q, k, scale, eps):
     return q * scale_q.to(q.dtype), k * scale_k.to(k.dtype)
 
 
-@flags.compile_wrap
 def scale_for_cosine_sim_qkv(qkv, scale, eps):
     q, k, v = qkv.unbind(2)
     q, k = scale_for_cosine_sim(q, k, scale[:, None], eps)
@@ -179,7 +175,6 @@ class AdaRMSNorm(nn.Module):
 
 # Rotary position embeddings
 
-@flags.compile_wrap
 def apply_rotary_emb(x, theta, conj=False):
     out_dtype = x.dtype
     dtype = reduce(torch.promote_types, (x.dtype, theta.dtype, torch.float32))
@@ -195,7 +190,6 @@ def apply_rotary_emb(x, theta, conj=False):
     return torch.cat((y1, y2, x3), dim=-1)
 
 
-@flags.compile_wrap
 def _apply_rotary_emb_inplace(x, theta, conj):
     dtype = reduce(torch.promote_types, (x.dtype, theta.dtype, torch.float32))
     d = theta.shape[-1]
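The kernels above likewise become plain eager functions. Below is a quick sanity check for `rms_norm` as shown in the hunk; the normalization step between `mean_sq` and the return is elided by the diff context, so the `torch.rsqrt` line is a reconstruction of the standard RMSNorm formulation, not a quote of the file:

```python
from functools import reduce

import torch


def rms_norm(x, scale, eps):
    # Promote to at least float32 for the reduction, as in the diff above.
    dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
    mean_sq = torch.mean(x.to(dtype) ** 2, dim=-1, keepdim=True)
    # Assumed middle step: fold the reciprocal root mean square into the scale.
    scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
    return x * scale.to(x.dtype)


x = torch.randn(2, 8, 64, dtype=torch.float16)
scale = torch.ones(64)
y = rms_norm(x, scale, 1e-6)     # runs eagerly now that the decorator is gone
print(y.dtype, y.shape)          # torch.float16 torch.Size([2, 8, 64])
```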
 
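The same applies to the cosine-similarity attention helpers: `scale_for_cosine_sim` rescales q and k so their dot product behaves like a scaled cosine similarity, and `scale_for_cosine_sim_qkv` applies it to a packed qkv tensor. A hedged sketch, with the lines elided by the hunk context filled in using the standard formulation:

```python
from functools import reduce

import torch


def scale_for_cosine_sim(q, k, scale, eps):
    # Normalize q and k so that q @ k.mT approximates a scaled cosine similarity.
    dtype = reduce(torch.promote_types, (q.dtype, k.dtype, scale.dtype, torch.float32))
    sum_sq_q = torch.sum(q.to(dtype) ** 2, dim=-1, keepdim=True)
    sum_sq_k = torch.sum(k.to(dtype) ** 2, dim=-1, keepdim=True)
    # Assumed middle steps: split the learned scale evenly across q and k.
    sqrt_scale = torch.sqrt(scale.to(dtype))
    scale_q = sqrt_scale * torch.rsqrt(sum_sq_q + eps)
    scale_k = sqrt_scale * torch.rsqrt(sum_sq_k + eps)
    return q * scale_q.to(q.dtype), k * scale_k.to(k.dtype)


q = torch.randn(2, 4, 16, 32)        # (batch, heads, seq, head_dim)
k = torch.randn(2, 4, 16, 32)
scale = torch.full((4, 1, 1), 10.0)  # one learned scale per head
q, k = scale_for_cosine_sim(q, k, scale, 1e-6)
sim = q @ k.mT                       # logits bounded by roughly +/- scale
print(sim.shape)                     # torch.Size([2, 4, 16, 16])
```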