marioggil committed
Commit 1fa72e2 · Parent(s): 96be5b8
Files changed (1): app.py (+74, −2)
app.py CHANGED
@@ -30,15 +30,87 @@ def ClassificateDocs(pathimage):
     sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
     return processor.token2json(sequence)
 
+# Receipt parsing: Donut fine-tuned on CORD-v2
+processor_prs = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
+model_prs = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
+
+def ProcessBill(pathimage):
+    image = Image.open(pathimage)
+    pixel_values = processor_prs(image, return_tensors="pt").pixel_values
+    task_prompt = "<s_cord-v2>"
+    decoder_input_ids = processor_prs.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt")["input_ids"]
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    model_prs.to(device)
+    outputs = model_prs.generate(
+        pixel_values.to(device),
+        decoder_input_ids=decoder_input_ids.to(device),
+        max_length=model_prs.decoder.config.max_position_embeddings,
+        early_stopping=True,
+        pad_token_id=processor_prs.tokenizer.pad_token_id,
+        eos_token_id=processor_prs.tokenizer.eos_token_id,
+        use_cache=True,
+        num_beams=1,
+        bad_words_ids=[[processor_prs.tokenizer.unk_token_id]],
+        return_dict_in_generate=True,
+        output_scores=True,
+    )
+    # Decode, strip special tokens, and convert the sequence to structured JSON
+    sequence = processor_prs.batch_decode(outputs.sequences)[0]
+    sequence = sequence.replace(processor_prs.tokenizer.eos_token, "").replace(processor_prs.tokenizer.pad_token, "")
+    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
+    return processor_prs.token2json(sequence)
+
+# Document question answering: Donut fine-tuned on DocVQA
+processor_qa = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
+model_qa = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
+
+def QAsBill(pathimage, question="When is the coffee break?"):
+    image = Image.open(pathimage)
+    pixel_values = processor_qa(image, return_tensors="pt").pixel_values
+    task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
+    prompt = task_prompt.replace("{user_input}", question)
+    decoder_input_ids = processor_qa.tokenizer(prompt, add_special_tokens=False, return_tensors="pt").input_ids
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    model_qa.to(device)
+    outputs = model_qa.generate(
+        pixel_values.to(device),
+        decoder_input_ids=decoder_input_ids.to(device),
+        max_length=model_qa.decoder.config.max_position_embeddings,
+        pad_token_id=processor_qa.tokenizer.pad_token_id,
+        eos_token_id=processor_qa.tokenizer.eos_token_id,
+        use_cache=True,
+        bad_words_ids=[[processor_qa.tokenizer.unk_token_id]],
+        return_dict_in_generate=True,
+    )
+
+    # Decode, strip special tokens, and convert the answer to structured JSON
+    sequence = processor_qa.batch_decode(outputs.sequences)[0]
+    sequence = sequence.replace(processor_qa.tokenizer.eos_token, "").replace(processor_qa.tokenizer.pad_token, "")
+    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
+    return processor_qa.token2json(sequence)
+
+
 demo = gr.Blocks()
 
-gradio_app = gr.Interface(
+gradio_app_cls = gr.Interface(
     fn=ClassificateDocs,
     inputs=[
         gr.Image(type='filepath')
     ],
     outputs="text",
 )
+gradio_app_prs = gr.Interface(
+    fn=ProcessBill,
+    inputs=[
+        gr.Image(type='filepath')
+    ],
+    outputs="text",
+)
+gradio_app_qa = gr.Interface(
+    fn=QAsBill,
+    inputs=[
+        gr.Image(type='filepath'),
+        gr.Text()
+    ],
+    outputs="text",
+)
+
+
+demo = gr.TabbedInterface([gradio_app_cls, gradio_app_prs, gradio_app_qa], ["class", "parse", "QA"])
 
 if __name__ == "__main__":
-    gradio_app.launch()
+    demo.launch()
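
Note: the hunk starts at line 30 of app.py, so the imports and the classification processor/model that ClassificateDocs relies on are not visible in this commit. For orientation, a minimal sketch of what the top of the file presumably contains is below; the classification checkpoint name is a placeholder guess (Donut's RVL-CDIP classifier) and is not confirmed by this diff.

# Sketch of the presumed header of app.py (NOT part of this commit).
# The classification checkpoint below is a placeholder guess; only the
# CORD-v2 and DocVQA checkpoints appear in the diff itself.
import re

import gradio as gr
import torch
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel

processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-rvlcdip")          # placeholder
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-rvlcdip")   # placeholder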
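
For a quick check outside the Gradio UI, the three entry points can be called directly. A minimal sketch, assuming the header above, that app.py is importable from the working directory, and a placeholder image path that is not shipped with the Space:

# Hypothetical smoke test; "sample_receipt.png" is a placeholder path.
# Importing app downloads the Donut checkpoints on first run.
from app import ClassificateDocs, ProcessBill, QAsBill

print(ClassificateDocs("sample_receipt.png"))                      # classifier output (token2json dict)
print(ProcessBill("sample_receipt.png"))                           # CORD-v2 style receipt parse
print(QAsBill("sample_receipt.png", "What is the total price?"))   # DocVQA answer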