ahmed-masry commited on
Commit
e145bd5
·
1 Parent(s): e287cff

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +87 -0
README.md CHANGED
@@ -1,3 +1,90 @@
1
  ---
2
  license: gpl-3.0
 
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: gpl-3.0
3
+ language:
4
+ - en
5
  ---
6
+ # TL;DR
7
+
8
+ The abstract of the paper states that:
9
+ > Charts are very popular for analyzing data, visualizing key insights and answering complex reasoning questions about data. To facilitate chart-based data analysis using natural language, several downstream tasks have been introduced recently such as chart question answering and chart summarization. However, most of the methods that solve these tasks use pretraining on language or vision-language tasks that do not attempt to explicitly model the structure of the charts (e.g., how data is visually encoded and how chart elements are related to each other). To address this, we first build a large corpus of charts covering a wide variety of topics and visual styles. We then present UniChart, a pretrained model for chart comprehension and reasoning. UniChart encodes the relevant text, data, and visual elements of charts and then uses a chart-grounded text decoder to generate the expected output in natural language. We propose several chart-specific pretraining tasks that include: (i) low-level tasks to extract the visual elements (e.g., bars, lines) and data from charts, and (ii) high-level tasks to acquire chart understanding and reasoning skills. We find that pretraining the model on a large corpus with chart-specific low- and high-level tasks followed by finetuning on three down-streaming tasks results in state-of-the-art performance on three downstream tasks.
10
+
11
+ # Web Demo
12
+ If you wish to quickly try our models, you can access our public web demoes hosted on the Hugging Face Spaces platform with a friendly interface!
13
+
14
+ | Tasks | Web Demo |
15
+ | ------------- | ------------- |
16
+ | Base Model (Best for Chart Summarization and Data Table Generation) | [UniChart-Base](https://huggingface.co/spaces/ahmed-masry/UniChart-Base) |
17
+ | Chart Question Answering | [UniChart-ChartQA](https://huggingface.co/spaces/ahmed-masry/UniChart-ChartQA) |
18
+
19
+ The input prompt for Chart summarization is **\<summarize_chart\>** and Data Table Generation is **\<extract_data_table\>**
20
+
21
+
22
+ # Inference
23
+ You can easily use our models for inference with the huggingface library!
24
+ You just need to do the following:
25
+ 1. Change _model_name_ to your prefered checkpoint.
26
+ 2. Chage the _imag_path_ to your chart example image path on your system
27
+ 3. Write the _input_prompt_ based on your prefered task as shown in the table below.
28
+
29
+ | Task | Input Prompt |
30
+ | ------------- | ------------- |
31
+ | Chart Question Answering | \<chartqa\> question \<s_answer\> |
32
+ | Open Chart Question Answering | \<opencqa\> question \<s_answer\> |
33
+ | Chart Summarization | \<summarize_chart\> \<s_answer\> |
34
+ | Data Table Extraction | \<extract_data_table\> \<s_answer\> |
35
+
36
+ ```
37
+ from transformers import DonutProcessor, VisionEncoderDecoderModel
38
+ from PIL import Image
39
+ import torch, os, re
40
+
41
+ torch.hub.download_url_to_file('https://raw.githubusercontent.com/vis-nlp/ChartQA/main/ChartQA%20Dataset/val/png/multi_col_1229.png', 'chart_example_1.png')
42
+
43
+ model_name = "ahmed-masry/unichart-chartqa-960"
44
+ image_path = "/content/chart_example_1.png"
45
+ input_prompt = "<chartqa> What is the lowest value in blue bar? <s_answer>"
46
+
47
+ model = VisionEncoderDecoderModel.from_pretrained(model_name)
48
+ processor = DonutProcessor.from_pretrained(model_name)
49
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
50
+ model.to(device)
51
+
52
+ image = Image.open(image_path).convert("RGB")
53
+ decoder_input_ids = processor.tokenizer(input_prompt, add_special_tokens=False, return_tensors="pt").input_ids
54
+ pixel_values = processor(image, return_tensors="pt").pixel_values
55
+
56
+ outputs = model.generate(
57
+ pixel_values.to(device),
58
+ decoder_input_ids=decoder_input_ids.to(device),
59
+ max_length=model.decoder.config.max_position_embeddings,
60
+ early_stopping=True,
61
+ pad_token_id=processor.tokenizer.pad_token_id,
62
+ eos_token_id=processor.tokenizer.eos_token_id,
63
+ use_cache=True,
64
+ num_beams=4,
65
+ bad_words_ids=[[processor.tokenizer.unk_token_id]],
66
+ return_dict_in_generate=True,
67
+ )
68
+ sequence = processor.batch_decode(outputs.sequences)[0]
69
+ sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
70
+ sequence = sequence.split("<s_answer>")[1].strip()
71
+ print(sequence)
72
+
73
+ ```
74
+
75
+ # Contact
76
+ If you have any questions about this work, please contact **[Ahmed Masry](https://ahmedmasryku.github.io/)** using the following email addresses: **[email protected]** or **[email protected]**.
77
+
78
+ # Reference
79
+ Please cite our paper if you use our models or dataset in your research.
80
+
81
+ ```
82
+ @misc{masry2023unichart,
83
+ title={UniChart: A Universal Vision-language Pretrained Model for Chart Comprehension and Reasoning},
84
+ author={Ahmed Masry and Parsa Kavehzadeh and Xuan Long Do and Enamul Hoque and Shafiq Joty},
85
+ year={2023},
86
+ eprint={2305.14761},
87
+ archivePrefix={arXiv},
88
+ primaryClass={cs.CL}
89
+ }
90
+ ```