duzx16 commited on
Commit
99564c0
·
1 Parent(s): ccb0160

Update modeling_chatglm.py

Browse files
Files changed (2) hide show
  1. config.json +1 -0
  2. modeling_chatglm.py +12 -8
config.json CHANGED
@@ -1,5 +1,6 @@
1
  {
2
  "_name_or_path": "THUDM/chatglm2-6b",
 
3
  "architectures": [
4
  "ChatGLMModel"
5
  ],
 
1
  {
2
  "_name_or_path": "THUDM/chatglm2-6b",
3
+ "model_type": "chatglm",
4
  "architectures": [
5
  "ChatGLMModel"
6
  ],
modeling_chatglm.py CHANGED
@@ -702,6 +702,9 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
702
  dtype=config.torch_dtype, **init_kwargs)
703
  self.gradient_checkpointing = False
704
 
 
 
 
705
  def forward(
706
  self,
707
  input_ids,
@@ -932,7 +935,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
932
 
933
 
934
  @torch.no_grad()
935
- def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,
936
  do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None, **kwargs):
937
  if history is None:
938
  history = []
@@ -951,7 +954,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
951
 
952
  @torch.no_grad()
953
  def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, past_key_values=None,
954
- max_length: int = 2048, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
955
  return_past_key_values=False, **kwargs):
956
  if history is None:
957
  history = []
@@ -976,12 +979,13 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
976
  outputs, past_key_values = outputs
977
  outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
978
  response = tokenizer.decode(outputs)
979
- response = self.process_response(response)
980
- new_history = history + [(query, response)]
981
- if return_past_key_values:
982
- yield response, new_history, past_key_values
983
- else:
984
- yield response, new_history
 
985
 
986
  @torch.no_grad()
987
  def stream_generate(
 
702
  dtype=config.torch_dtype, **init_kwargs)
703
  self.gradient_checkpointing = False
704
 
705
+ def get_input_embeddings(self):
706
+ return self.embedding.word_embeddings
707
+
708
  def forward(
709
  self,
710
  input_ids,
 
935
 
936
 
937
  @torch.no_grad()
938
+ def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 8192, num_beams=1,
939
  do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None, **kwargs):
940
  if history is None:
941
  history = []
 
954
 
955
  @torch.no_grad()
956
  def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, past_key_values=None,
957
+ max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
958
  return_past_key_values=False, **kwargs):
959
  if history is None:
960
  history = []
 
979
  outputs, past_key_values = outputs
980
  outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
981
  response = tokenizer.decode(outputs)
982
+ if response and response[-1] != "�":
983
+ response = self.process_response(response)
984
+ new_history = history + [(query, response)]
985
+ if return_past_key_values:
986
+ yield response, new_history, past_key_values
987
+ else:
988
+ yield response, new_history
989
 
990
  @torch.no_grad()
991
  def stream_generate(