add set_input_embeddings to support resizing token embeddings
#49
by
Yes365
- opened
- modeling_chatglm.py +3 -0
modeling_chatglm.py
CHANGED
@@ -766,6 +766,9 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|
766 |
def get_input_embeddings(self):
|
767 |
return self.embedding.word_embeddings
|
768 |
|
|
|
|
|
|
|
769 |
def get_prompt(self, batch_size, device, dtype=torch.half):
|
770 |
prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
|
771 |
past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
|
|
|
766 |
def get_input_embeddings(self):
|
767 |
return self.embedding.word_embeddings
|
768 |
|
769 |
+
def set_input_embeddings(self, new_embeddings: torch.Tensor):
|
770 |
+
self.embedding.word_embeddings = new_embeddings
|
771 |
+
|
772 |
def get_prompt(self, batch_size, device, dtype=torch.half):
|
773 |
prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
|
774 |
past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
|