add stream params for chat and char_crop
#24
by
weege007
- opened
- modeling_GOT.py +5 -5
modeling_GOT.py
CHANGED
@@ -484,7 +484,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
484 |
setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
|
485 |
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
|
486 |
|
487 |
-
def chat(self, tokenizer, image_file, ocr_type, ocr_box='', ocr_color='', render=False, save_render_file=None, print_prompt=False, gradio_input=False, stream_flag = False):
|
488 |
|
489 |
self.disable_torch_init()
|
490 |
|
@@ -563,8 +563,8 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
563 |
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
564 |
keywords = [stop_str]
|
565 |
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
566 |
-
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
567 |
-
|
568 |
if stream_flag:
|
569 |
with torch.autocast("cuda", dtype=torch.bfloat16):
|
570 |
output_ids = self.generate(
|
@@ -728,7 +728,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
728 |
return processed_images
|
729 |
|
730 |
|
731 |
-
def chat_crop(self, tokenizer, image_file, ocr_type, render=False, save_render_file=None, print_prompt=False, gradio_input=False, stream_flag = False):
|
732 |
# Model
|
733 |
self.disable_torch_init()
|
734 |
multi_page=False
|
@@ -817,7 +817,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
817 |
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
818 |
keywords = [stop_str]
|
819 |
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
820 |
-
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
821 |
|
822 |
if stream_flag:
|
823 |
with torch.autocast("cuda", dtype=torch.bfloat16):
|
|
|
484 |
setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
|
485 |
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
|
486 |
|
487 |
+
def chat(self, tokenizer, image_file, ocr_type, ocr_box='', ocr_color='', render=False, save_render_file=None, print_prompt=False, gradio_input=False, stream_flag = False, streamer=None):
|
488 |
|
489 |
self.disable_torch_init()
|
490 |
|
|
|
563 |
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
564 |
keywords = [stop_str]
|
565 |
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
566 |
+
streamer = streamer if streamer else TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
567 |
+
|
568 |
if stream_flag:
|
569 |
with torch.autocast("cuda", dtype=torch.bfloat16):
|
570 |
output_ids = self.generate(
|
|
|
728 |
return processed_images
|
729 |
|
730 |
|
731 |
+
def chat_crop(self, tokenizer, image_file, ocr_type, render=False, save_render_file=None, print_prompt=False, gradio_input=False, stream_flag = False, streamer=None):
|
732 |
# Model
|
733 |
self.disable_torch_init()
|
734 |
multi_page=False
|
|
|
817 |
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
818 |
keywords = [stop_str]
|
819 |
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
820 |
+
streamer = streamer if streamer else TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
821 |
|
822 |
if stream_flag:
|
823 |
with torch.autocast("cuda", dtype=torch.bfloat16):
|