add stream params for chat and char_crop

#24
Files changed (1) hide show
  1. modeling_GOT.py +5 -5
modeling_GOT.py CHANGED
@@ -484,7 +484,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
484
  setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
485
  setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
486
 
487
- def chat(self, tokenizer, image_file, ocr_type, ocr_box='', ocr_color='', render=False, save_render_file=None, print_prompt=False, gradio_input=False, stream_flag = False):
488
 
489
  self.disable_torch_init()
490
 
@@ -563,8 +563,8 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
563
  stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
564
  keywords = [stop_str]
565
  stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
566
- streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
567
-
568
  if stream_flag:
569
  with torch.autocast("cuda", dtype=torch.bfloat16):
570
  output_ids = self.generate(
@@ -728,7 +728,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
728
  return processed_images
729
 
730
 
731
- def chat_crop(self, tokenizer, image_file, ocr_type, render=False, save_render_file=None, print_prompt=False, gradio_input=False, stream_flag = False):
732
  # Model
733
  self.disable_torch_init()
734
  multi_page=False
@@ -817,7 +817,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
817
  stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
818
  keywords = [stop_str]
819
  stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
820
- streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
821
 
822
  if stream_flag:
823
  with torch.autocast("cuda", dtype=torch.bfloat16):
 
484
  setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
485
  setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
486
 
487
+ def chat(self, tokenizer, image_file, ocr_type, ocr_box='', ocr_color='', render=False, save_render_file=None, print_prompt=False, gradio_input=False, stream_flag = False, streamer=None):
488
 
489
  self.disable_torch_init()
490
 
 
563
  stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
564
  keywords = [stop_str]
565
  stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
566
+ streamer = streamer if streamer else TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
567
+
568
  if stream_flag:
569
  with torch.autocast("cuda", dtype=torch.bfloat16):
570
  output_ids = self.generate(
 
728
  return processed_images
729
 
730
 
731
+ def chat_crop(self, tokenizer, image_file, ocr_type, render=False, save_render_file=None, print_prompt=False, gradio_input=False, stream_flag = False, streamer=None):
732
  # Model
733
  self.disable_torch_init()
734
  multi_page=False
 
817
  stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
818
  keywords = [stop_str]
819
  stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
820
+ streamer = streamer if streamer else TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
821
 
822
  if stream_flag:
823
  with torch.autocast("cuda", dtype=torch.bfloat16):