Markus28 commited on
Commit
850b9a2
·
1 Parent(s): 5c4e4bf

fix: use proper initilization for embedding layer

Browse files
Files changed (1) hide show
  1. modeling_lora.py +28 -11
modeling_lora.py CHANGED
@@ -11,20 +11,37 @@ from torch.nn import Parameter
11
  from .modeling_bert import BertModel, BertPreTrainedModel, JinaBertConfig
12
 
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  class LoRAParametrization(nn.Module):
15
- def __init__(self, fan_in, fan_out, fan_in_fan_out=False, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
16
  super().__init__()
17
  # if weight is stored as (fan_out, fan_in), the memory layout of A & B follows (W + BA)x
18
  # otherwise, it's x(W + AB). This allows us to tie the weights between linear layers and embeddings
 
19
  self.swap = (lambda x: (x[1], x[0])) if fan_in_fan_out else (lambda x: x)
20
- lora_A_data = []
21
- for _ in range(num_adaptions):
22
- new_adaption = torch.zeros(self.swap((rank, fan_in)))
23
- nn.init.kaiming_uniform_(new_adaption, a=math.sqrt(5))
24
- lora_A_data.append(new_adaption)
25
- lora_A_data = torch.stack(lora_A_data, dim=0)
26
- self.lora_A = nn.Parameter(lora_A_data)
27
- self.lora_B = nn.Parameter(torch.zeros((num_adaptions, *self.swap((fan_out, rank)))))
 
 
28
  self.lora_alpha, self.rank = lora_alpha, rank
29
  self.scaling = lora_alpha / rank
30
  self.lora_dropout = nn.Dropout(p=lora_dropout_p) if lora_dropout_p > 0 else lambda x: x
@@ -55,14 +72,14 @@ class LoRAParametrization(nn.Module):
55
  def from_linear(cls, layer, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
56
  fan_out, fan_in = layer.weight.shape
57
  return cls(
58
- fan_in, fan_out, num_adaptions=num_adaptions, fan_in_fan_out=False, rank=rank, lora_dropout_p=lora_dropout_p, lora_alpha=lora_alpha
59
  )
60
 
61
  @classmethod
62
  def from_embedding(cls, layer, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
63
  fan_in, fan_out = layer.weight.shape
64
  return cls(
65
- fan_in, fan_out, num_adaptions=num_adaptions, fan_in_fan_out=True, rank=rank, lora_dropout_p=lora_dropout_p, lora_alpha=lora_alpha
66
  )
67
 
68
  @classmethod
 
11
  from .modeling_bert import BertModel, BertPreTrainedModel, JinaBertConfig
12
 
13
 
14
+ def initialized_weights(shape, num_adaptions, init='kaiming'):
15
+ weight_data = []
16
+ for _ in range(num_adaptions):
17
+ new_adaption = torch.zeros(shape)
18
+ if init == 'kaiming':
19
+ nn.init.kaiming_uniform_(new_adaption, a=math.sqrt(5))
20
+ elif init == 'normal':
21
+ nn.init.normal_(new_adaption)
22
+ else:
23
+ raise NotImplementedError
24
+ weight_data.append(new_adaption)
25
+ return torch.stack(weight_data, dim=0)
26
+
27
+
28
  class LoRAParametrization(nn.Module):
29
+ def __init__(self, fan_in, fan_out, layer_type='linear', num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
30
  super().__init__()
31
  # if weight is stored as (fan_out, fan_in), the memory layout of A & B follows (W + BA)x
32
  # otherwise, it's x(W + AB). This allows us to tie the weights between linear layers and embeddings
33
+ fan_in_fan_out = (layer_type == 'embedding')
34
  self.swap = (lambda x: (x[1], x[0])) if fan_in_fan_out else (lambda x: x)
35
+
36
+ if layer_type == 'linear':
37
+ self.lora_A = nn.Parameter(initialized_weights((rank, fan_in), num_adaptions, init='kaiming'))
38
+ self.lora_B = nn.Parameter(torch.zeros((num_adaptions, fan_out, rank)))
39
+ elif layer_type == 'embedding':
40
+ self.lora_A = nn.Parameter(torch.zeros((num_adaptions, fan_in, rank)))
41
+ self.lora_B = nn.Parameter(initialized_weights((rank, fan_out), num_adaptions=num_adaptions, init='normal'))
42
+ else:
43
+ raise NotImplementedError
44
+
45
  self.lora_alpha, self.rank = lora_alpha, rank
46
  self.scaling = lora_alpha / rank
47
  self.lora_dropout = nn.Dropout(p=lora_dropout_p) if lora_dropout_p > 0 else lambda x: x
 
72
  def from_linear(cls, layer, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
73
  fan_out, fan_in = layer.weight.shape
74
  return cls(
75
+ fan_in, fan_out, num_adaptions=num_adaptions, layer_type='linear', rank=rank, lora_dropout_p=lora_dropout_p, lora_alpha=lora_alpha
76
  )
77
 
78
  @classmethod
79
  def from_embedding(cls, layer, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
80
  fan_in, fan_out = layer.weight.shape
81
  return cls(
82
+ fan_in, fan_out, num_adaptions=num_adaptions, layer_type='embedding', rank=rank, lora_dropout_p=lora_dropout_p, lora_alpha=lora_alpha
83
  )
84
 
85
  @classmethod