Hanzoe committed on
Commit
8a83f83
2 Parent(s): 5d03dd3 c2dcab0

Merge pull request #1 from binary-husky/master

.github/ISSUE_TEMPLATE/bug_report.md ADDED
@@ -0,0 +1,19 @@
+ ---
+ name: Bug report
+ about: Create a report to help us improve
+ title: ''
+ labels: ''
+ assignees: ''
+
+ ---
+
+ **Describe the bug**
+
+ **Screen Shot**
+
+ **Terminal Traceback (if any)**
+
+
+ Before submitting an issue:
+ - Please try to upgrade your code first; if your copy is not the latest version, update it before filing.
+ - Please check the project [wiki](https://github.com/binary-husky/chatgpt_academic/wiki) for solutions to common problems.
.github/ISSUE_TEMPLATE/feature_request.md ADDED
@@ -0,0 +1,10 @@
+ ---
+ name: Feature request
+ about: Suggest an idea for this project
+ title: ''
+ labels: ''
+ assignees: ''
+
+ ---
+
+
.gitignore CHANGED
@@ -139,4 +139,6 @@ config_private.py
  gpt_log
  private.md
  private_upload
- other_llms
+ other_llms
+ cradle*
+ debug*
README.md CHANGED
@@ -2,29 +2,19 @@

  # ChatGPT Academic Optimization

- **If you like this project, please give it a Star; if you have invented more useful academic shortcut keys, feel free to open an issue or pull request.**

- If you like this project, please give it a Star. If you've come up with more useful academic shortcuts, feel free to open an issue or pull request.
-
- ```
- The code borrows designs from many other excellent projects, mainly:
-
- # Reference project 1: ChuanhuChatGPT — the way it reads the OpenAI json, keeps a history of queries, and uses the gradio queue
- https://github.com/GaiZhenbiao/ChuanhuChatGPT
-
- # Reference project 2: mdtex2html — its method for handling formulas
- https://github.com/polarwinkel/mdtex2html
-
- The project uses OpenAI's gpt-3.5-turbo model; hoping gpt-4 lowers the bar for access soon 😂
- ```

  > **Note**
  >
- > 1. Note that only function plugins (buttons) marked in "red" can read files. Translation and interpretation of pdf papers is not yet fully supported, and word files cannot be read yet.
  >
- > 2. The role of every file in this project is documented in `project_self_analysis.md`. As versions iterate, you can also click the relevant function plugin at any time to have GPT regenerate the project's self-analysis report.
  >
- > 3. If you are not comfortable with the partly Chinese-named functions, you can click the relevant function plugin at any time to have GPT generate a pure-English version of the project source code.

  <div align="center">

@@ -33,24 +23,30 @@ https://github.com/polarwinkel/mdtex2html
  One-click polishing | One-click polishing and one-click grammar checking for papers
  One-click Chinese-English translation | One-click translation between Chinese and English
  One-click code explanation | Displays and explains code correctly
- Custom shortcut keys | Supports user-defined shortcut keys
- Proxy server configuration | Supports configuring a proxy server
- Modular design | Supports custom high-level experimental features
- Self-analysis of the program | [Experimental] One-click reading of this project's own source code
- Program analysis | [Experimental] One-click analysis of other Python/C++ projects
- Read papers | [Experimental] One-click interpretation of a full latex paper and summary generation
- Batch comment generation | [Experimental] One-click batch generation of function comments
- Chat analysis report generation | [Experimental] Automatically generates a summary report after running
  Formula display | Shows both the tex form and the rendered form of formulas
  Image display | Can display images in markdown
  Markdown tables from GPT output | Can render the markdown tables produced by GPT
  …… | ……

  </div>

  - New interface
  <div align="center">
- <img src="https://user-images.githubusercontent.com/96192199/228600410-7d44e34f-63f1-4046-acb8-045cb05da8bb.png" width="700" >
  </div>

@@ -73,10 +69,11 @@ chat分析报告生成 | [实验性功能] 运行后自动生成总结汇报

  - If the output contains formulas, they are shown in both tex form and rendered form for easy copying and reading
  <div align="center">
- <img src="img/demo.jpg" width="500" >
  </div>


  - Too lazy to read the project code? Just feed the whole project to chatgpt
  <div align="center">
  <img src="https://user-images.githubusercontent.com/96192199/226935232-6b6a73ce-8900-4aee-93f9-733c7e6fef53.png" width="700" >
@@ -84,45 +81,43 @@

  ## Running directly (Windows, Linux or MacOS)

- Download the project
-
  ```sh
  git clone https://github.com/binary-husky/chatgpt_academic.git
  cd chatgpt_academic
  ```

- We recommend copying `config.py` to `config_private.py` and using the latter as your personal configuration file, so that changes to `config.py` do not affect your setup and you do not accidentally commit a `config.py` containing your OpenAI API KEY to this project.
-
- ```sh
- cp config.py config_private.py
- ```

- In `config_private.py`, configure the overseas Proxy and the OpenAI API KEY
  ```
- 1. If you are in China, you need an overseas proxy to use the OpenAI API; you can set it up through the config.py file.
  2. Configure the OpenAI API KEY. You need to register on the OpenAI website and obtain an API KEY. Once you have it, set it in the config.py file.
  ```
- Install the dependencies

  ```sh
- python -m pip install -r requirements.txt
- ```

- Or, if you prefer to use `conda`

- ```sh
- conda create -n gptac 'gradio>=3.23' requests
- conda activate gptac
- python3 -m pip install mdtex2html
  ```

- Run
-
  ```sh
  python main.py
  ```

- Test the experimental features
  ```
  - Test the C++ project header analysis
  Enter `./crazy_functions/test_project/cpp/libJPG` in the input area, then click "[Experimental] Parse the whole C++ project (enter the project root path in input)"
@@ -136,8 +131,6 @@ python main.py
  Click "[Experimental] Experimental function template"
  ```

- Issues related to the proxy network (timeouts, proxy not working) are collected at https://github.com/binary-husky/chatgpt_academic/issues/1
-
  ## Using docker (Linux)

  ``` sh
@@ -145,7 +138,7 @@ python main.py
  git clone https://github.com/binary-husky/chatgpt_academic.git
  cd chatgpt_academic
  # Configure the overseas Proxy and the OpenAI API KEY
- config.py
  # Install
  docker build -t gpt-academic .
  # Run
@@ -166,20 +159,12 @@ input区域 输入 ./crazy_functions/test_project/python/dqn , 然后点击 "[

  ```

- ## Using WSL2 (Windows Subsystem for Linux)
- This route assumes you already have some basic knowledge, so the extra steps are not repeated here. If that is not the case, you can learn more about the subsystem from [here](https://learn.microsoft.com/zh-cn/windows/wsl/about) or from GPT.
-
- WSL2 can be configured to go online through the proxy on the Windows side; the prerequisite steps are described [here](https://www.cnblogs.com/tuilk/p/16287472.html)
- Because the Windows IP as seen from WSL2 changes, we need to fetch that IP before every start to ensure access works. Change the proxies section of config.py to the following code:
- ```python
- import subprocess
- cmd_get_ip = 'grep -oP "(\d+\.)+(\d+)" /etc/resolv.conf'
- ip_proxy = subprocess.run(
-     cmd_get_ip, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, shell=True
- ).stdout.strip()  # get the Windows-side IP
- proxies = { "http": ip_proxy + ":51837", "https": ip_proxy + ":51837", }  # adjust the port to your setup
- ```
- After starting main.py you can access the service from a browser on Windows; from here on, testing and usage are the same as with the other methods above.


  ## Custom shortcut buttons (customize academic shortcut keys)
@@ -204,7 +189,7 @@ proxies = { "http": ip_proxy + ":51837", "https": ip_proxy + ":51837", } # 请
  If you invent a more useful academic shortcut key, feel free to open an issue or pull request!

  ## Configuring a proxy
-
  Edit ```config.py``` so that the port matches your proxy software

  <div align="center">
@@ -216,6 +201,8 @@
  ```
  python check_proxy.py
  ```

  ## Compatibility testing

@@ -259,13 +246,44 @@ python check_proxy.py

  ### Modular feature design
  <div align="center">
- <img src="https://user-images.githubusercontent.com/96192199/227504981-4c6c39c0-ae79-47e6-bffe-0e6442d9da65.png" height="400" >
  <img src="https://user-images.githubusercontent.com/96192199/227504931-19955f78-45cd-4d1c-adac-e71e50957915.png" height="400" >
  </div>

- ## Todo:
-
- - (Top priority) Call the web interface of another open-source project, text-generation-webui, to use other llm models
- - Text too long / token overflow when summarizing the source code of large projects (the current approach simply bisects and discards the overflow, which is too crude and loses a lot of useful information)
- - The UI is not pretty enough


  # ChatGPT Academic Optimization

+ **If you like this project, please give it a Star; if you have invented more useful shortcut keys or function plugins, feel free to open an issue or pull request (against the dev branch).**

+ If you like this project, please give it a Star. If you've come up with more useful academic shortcuts or functional plugins, feel free to open an issue or pull request (to `dev` branch).

  > **Note**
  >
+ > 1. Note that only function plugins (buttons) marked in "red" can read files. Plugin support for pdf/word files is being improved step by step and needs help from more developers.
  >
+ > 2. The role of every file in this project is documented in the self-translated [`self_analysis.md`](https://github.com/binary-husky/chatgpt_academic/wiki/chatgpt-academic%E9%A1%B9%E7%9B%AE%E8%87%AA%E8%AF%91%E8%A7%A3%E6%8A%A5%E5%91%8A). As versions iterate, you can also click the relevant function plugin at any time to have GPT regenerate the project's self-analysis report. Frequently asked questions are collected in the [`wiki`](https://github.com/binary-husky/chatgpt_academic/wiki/%E5%B8%B8%E8%A7%81%E9%97%AE%E9%A2%98).
  >
+ > 3. If you are not comfortable with the partly Chinese-named functions, comments or UI, you can click the relevant function plugin at any time to have ChatGPT generate a pure-English version of the project source code.
+ >
+ > 4. The project uses OpenAI's gpt-3.5-turbo model; hoping gpt-4 lowers the bar for access soon 😂

  <div align="center">

  One-click polishing | One-click polishing and one-click grammar checking for papers
  One-click Chinese-English translation | One-click translation between Chinese and English
  One-click code explanation | Displays and explains code correctly
+ [Custom shortcut keys](https://www.bilibili.com/video/BV14s4y1E7jN) | Supports user-defined shortcut keys
+ [Proxy server configuration](https://www.bilibili.com/video/BV1rc411W7Dr) | Supports configuring a proxy server
+ Modular design | Supports custom high-level experimental features and [function plugins]; plugins support [hot reloading](https://github.com/binary-husky/chatgpt_academic/wiki/%E5%87%BD%E6%95%B0%E6%8F%92%E4%BB%B6%E6%8C%87%E5%8D%97)
+ [Self-analysis of the program](https://www.bilibili.com/video/BV1cj411A7VW) | [Function plugin] One-click reading of this project's own source code
+ [Program analysis](https://www.bilibili.com/video/BV1cj411A7VW) | [Function plugin] One-click analysis of other Python/C/C++/Java project trees
+ Read papers | [Function plugin] One-click interpretation of a full latex paper and summary generation
+ Batch comment generation | [Function plugin] One-click batch generation of function comments
+ Chat analysis report generation | [Function plugin] Automatically generates a summary report after running
+ [arxiv assistant](https://www.bilibili.com/video/BV1LM4y1279X) | [Function plugin] Enter an arxiv article url to translate the abstract and download the PDF in one click
+ [Full-text PDF paper translation](https://www.bilibili.com/video/BV1KT411x7Wn) | [Function plugin] Extract the title & abstract of a PDF paper and translate the full text (multi-threaded)
  Formula display | Shows both the tex form and the rendered form of formulas
  Image display | Can display images in markdown
+ Multi-threaded function plugin support | Supports calling chatgpt from multiple threads to process huge amounts of text or code in one click
  Markdown tables from GPT output | Can render the markdown tables produced by GPT
+ Dark gradio [theme](https://github.com/binary-husky/chatgpt_academic/issues/173) at startup | Append ```/?__dark-theme=true``` to the browser url to switch to the dark theme
+ Huggingface [online experience](https://huggingface.co/spaces/qingxu98/gpt-academic) without a proxy | Log in to huggingface and duplicate [this space](https://huggingface.co/spaces/qingxu98/gpt-academic)
  …… | ……

  </div>

+ <!-- - New interface (left: master branch, right: dev bleeding edge) -->
  - New interface
  <div align="center">
+ <img src="https://user-images.githubusercontent.com/96192199/230361456-61078362-a966-4eb5-b49e-3c62ef18b860.gif" width="700" >
  </div>


  - If the output contains formulas, they are shown in both tex form and rendered form for easy copying and reading
  <div align="center">
+ <img src="https://user-images.githubusercontent.com/96192199/230598842-1d7fcddd-815d-40ee-af60-baf488a199df.png" width="700" >
  </div>


+
  - Too lazy to read the project code? Just feed the whole project to chatgpt
  <div align="center">
  <img src="https://user-images.githubusercontent.com/96192199/226935232-6b6a73ce-8900-4aee-93f9-733c7e6fef53.png" width="700" >

  ## Running directly (Windows, Linux or MacOS)

+ ### 1. Download the project
  ```sh
  git clone https://github.com/binary-husky/chatgpt_academic.git
  cd chatgpt_academic
  ```

+ ### 2. Configure the API_KEY and proxy settings

+ In `config.py`, configure the overseas Proxy and the OpenAI API KEY, as explained below
  ```
+ 1. If you are in China, you need an overseas proxy to use the OpenAI API smoothly; read config.py carefully for how to set it up (1. change USE_PROXY to True; 2. edit proxies as instructed).
  2. Configure the OpenAI API KEY. You need to register on the OpenAI website and obtain an API KEY. Once you have it, set it in the config.py file.
+ 3. Issues related to the proxy network (timeouts, proxy not working) are collected at https://github.com/binary-husky/chatgpt_academic/issues/1
  ```
+ (P.S. When the program runs it first checks whether a private configuration file named `config_private.py` exists, and uses its settings to override the settings of the same name in `config.py`. If you understand this loading logic, we strongly recommend creating a new configuration file named `config_private.py` next to `config.py` and moving (copying) the settings from `config.py` into `config_private.py`. `config_private.py` is not tracked by git, which keeps your private information safer; a sketch of the override idea follows.)
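The following is a minimal sketch of that override idea, added here for illustration only: it is not this project's actual loader, and the helper name `load_config` is an assumption.

```python
# Hypothetical sketch: prefer settings from config_private.py over config.py.
import importlib

def load_config():
    cfg = importlib.import_module('config')
    try:
        private = importlib.import_module('config_private')
        # Any UPPERCASE setting defined in config_private.py shadows config.py.
        for name in dir(private):
            if name.isupper():
                setattr(cfg, name, getattr(private, name))
    except ModuleNotFoundError:
        pass  # no private config present, fall back to config.py only
    return cfg

cfg = load_config()
print(cfg.LLM_MODEL)  # "gpt-3.5-turbo" unless overridden privately
```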


+ ### 3. Install the dependencies
  ```sh
+ # (Option 1) Recommended
+ python -m pip install -r requirements.txt

+ # (Option 2) If you use anaconda, the steps are similar:
+ # (Option 2.1) conda create -n gptac_venv python=3.11
+ # (Option 2.2) conda activate gptac_venv
+ # (Option 2.3) python -m pip install -r requirements.txt

+ # Note: use the official pip index or the Aliyun mirror; other pip mirrors (e.g. some university mirrors) may cause problems. To switch the index temporarily:
+ # python -m pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/
  ```

+ ### 4. Run
  ```sh
  python main.py
  ```

+ ### 5. Try the experimental features
  ```
  - Test the C++ project header analysis
  Enter `./crazy_functions/test_project/cpp/libJPG` in the input area, then click "[Experimental] Parse the whole C++ project (enter the project root path in input)"
  Click "[Experimental] Experimental function template"
  ```

  ## Using docker (Linux)

  ``` sh
  git clone https://github.com/binary-husky/chatgpt_academic.git
  cd chatgpt_academic
  # Configure the overseas Proxy and the OpenAI API KEY
+ Edit config.py with any text editor
  # Install
  docker build -t gpt-academic .
  # Run
 

  ```

+ ## Other deployment options
+ - Using WSL2 (Windows Subsystem for Linux)
+ See [deployment wiki-1](https://github.com/binary-husky/chatgpt_academic/wiki/%E4%BD%BF%E7%94%A8WSL2%EF%BC%88Windows-Subsystem-for-Linux-%E5%AD%90%E7%B3%BB%E7%BB%9F%EF%BC%89%E9%83%A8%E7%BD%B2)
+
+ - Remote deployment behind nginx
+ See [deployment wiki-2](https://github.com/binary-husky/chatgpt_academic/wiki/%E8%BF%9C%E7%A8%8B%E9%83%A8%E7%BD%B2%E7%9A%84%E6%8C%87%E5%AF%BC)


  ## Custom shortcut buttons (customize academic shortcut keys)
  If you invent a more useful academic shortcut key, feel free to open an issue or pull request!

  ## Configuring a proxy
+ ### Method 1: the usual way
  Edit ```config.py``` so that the port matches your proxy software

  <div align="center">

  ```
  python check_proxy.py
  ```
+ ### Method 2: beginner tutorial
+ [Beginner tutorial](https://github.com/binary-husky/chatgpt_academic/wiki/%E4%BB%A3%E7%90%86%E8%BD%AF%E4%BB%B6%E9%97%AE%E9%A2%98%E7%9A%84%E6%96%B0%E6%89%8B%E8%A7%A3%E5%86%B3%E6%96%B9%E6%B3%95%EF%BC%88%E6%96%B9%E6%B3%95%E5%8F%AA%E9%80%82%E7%94%A8%E4%BA%8E%E6%96%B0%E6%89%8B%EF%BC%89)

  ## Compatibility testing


  ### Modular feature design
  <div align="center">
+ <img src="https://user-images.githubusercontent.com/96192199/229288270-093643c1-0018-487a-81e6-1d7809b6e90f.png" height="400" >
  <img src="https://user-images.githubusercontent.com/96192199/227504931-19955f78-45cd-4d1c-adac-e71e50957915.png" height="400" >
  </div>


+ ### Translating the source code into English
+
+ <div align="center">
+ <img src="https://user-images.githubusercontent.com/96192199/229720562-fe6c3508-6142-4635-a83d-21eb3669baee.png" height="400" >
+ </div>

+ ## Todo and version planning:
+
+ - version 3 (Todo):
+ - - support gpt4 and more llms
+ - version 2.4+ (Todo):
+ - - text too long / token overflow when summarizing the source code of large projects
+ - - packaged deployment of the project
+ - - improved parameter interface for function plugins
+ - - self-update
+ - version 2.4: (1) added full-text PDF translation; (2) added the option to move the input area; (3) added a vertical layout option; (4) optimized the multi-threaded function plugins.
+ - version 2.3: improved multi-threaded interactivity
+ - version 2.2: function plugins support hot reloading
+ - version 2.1: collapsible layout
+ - version 2.0: introduced modular function plugins
+ - version 1.0: basic features
+
+ ## References and learning
+
+
+ ```
+ The code borrows designs from many other excellent projects, mainly:
+
+ # Reference project 1: ChuanhuChatGPT — the way it reads the OpenAI json, keeps a history of queries, and uses the gradio queue
+ https://github.com/GaiZhenbiao/ChuanhuChatGPT
+
+ # Reference project 2: mdtex2html — its method for handling formulas
+ https://github.com/polarwinkel/mdtex2html
+
+
+ ```
check_proxy.py CHANGED
@@ -3,7 +3,8 @@ def check_proxy(proxies):
      import requests
      proxies_https = proxies['https'] if proxies is not None else '无'
      try:
-         response = requests.get("https://ipapi.co/json/", proxies=proxies, timeout=4)
+         response = requests.get("https://ipapi.co/json/",
+                                 proxies=proxies, timeout=4)
          data = response.json()
          print(f'查询代理的地理位置,返回的结果是{data}')
          if 'country_name' in data:
@@ -19,9 +20,36 @@ def check_proxy(proxies):
      return result


+ def auto_update():
+     from toolbox import get_conf
+     import requests
+     import time
+     import json
+     proxies, = get_conf('proxies')
+     response = requests.get("https://raw.githubusercontent.com/binary-husky/chatgpt_academic/master/version",
+                             proxies=proxies, timeout=1)
+     remote_json_data = json.loads(response.text)
+     remote_version = remote_json_data['version']
+     if remote_json_data["show_feature"]:
+         new_feature = "新功能:" + remote_json_data["new_feature"]
+     else:
+         new_feature = ""
+     with open('./version', 'r', encoding='utf8') as f:
+         current_version = f.read()
+     current_version = json.loads(current_version)['version']
+     if (remote_version - current_version) >= 0.05:
+         print(
+             f'\n新版本可用。新版本:{remote_version},当前版本:{current_version}。{new_feature}')
+         print('Github更新地址:\nhttps://github.com/binary-husky/chatgpt_academic\n')
+         time.sleep(3)
+         return
+     else:
+         return
+
+
  if __name__ == '__main__':
-     import os; os.environ['no_proxy'] = '*'  # avoid unexpected pollution from the proxy network
+     import os
+     os.environ['no_proxy'] = '*'  # avoid unexpected pollution from the proxy network
      from toolbox import get_conf
      proxies, = get_conf('proxies')
      check_proxy(proxies)
-

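For reference, auto_update above assumes that both the remote version file and the local ./version file are JSON documents with a numeric version field plus the show_feature / new_feature fields it reads. The layout below is inferred from those keys only, so treat it as an assumption rather than a specification.

```python
import json

# Assumed layout of ./version, inferred from the keys auto_update reads.
example = {"version": 2.4, "show_feature": True, "new_feature": "PDF translation"}
with open('./version', 'w', encoding='utf8') as f:
    json.dump(example, f, ensure_ascii=False)

with open('./version', 'r', encoding='utf8') as f:
    print(json.loads(f.read())['version'])  # 2.4
```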
config.py CHANGED
@@ -1,23 +1,31 @@
- # API_KEY = "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" (this key is invalid)
+ # [step 1]>> e.g. API_KEY = "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" (this key is invalid)
  API_KEY = "sk-此处填API密钥"
- API_URL = "https://api.openai.com/v1/chat/completions"

- # Set to True to use a proxy
+ # [step 2]>> Set to True to use a proxy; do not change this if you deploy directly on an overseas server
  USE_PROXY = False
  if USE_PROXY:
-
-     # The format is [protocol]:// [address] :[port],
+     # The format is [protocol]:// [address] :[port]; do not forget to set USE_PROXY to True first; do not change this if you deploy directly on an overseas server
      # e.g. "socks5h://localhost:11284"
-     # [protocol] the common protocols are just socks5h/http; e.g. the default local protocol of v2***s** is socks5h and that of cl**h is http
+     # [protocol] the common protocols are just socks5h/http; e.g. the default local protocol of v2**yss* is socks5h, while that of cl**h is http
      # [address] if unsure, just use localhost or 127.0.0.1 (localhost means the proxy software runs on this machine)
-     # [port] in your proxy software's settings; interfaces differ, but the port number should be in the most prominent place
+     # [port] look in your proxy software's settings; interfaces differ, but the port number should be in the most prominent place

      # Address of the proxy network; open your proxy software to check the protocol (socks5/http), address (localhost) and port (11284)
-     proxies = { "http": "socks5h://localhost:11284", "https": "socks5h://localhost:11284", }
-     print('网络代理状态:运行。')
+     proxies = {
+         # [protocol]:// [address] :[port]
+         "http":  "socks5h://localhost:11284",
+         "https": "socks5h://localhost:11284",
+     }
  else:
      proxies = None
-     print('网络代理状态:未配置。无代理状态下很可能无法访问。')
+
+
+ # [step 3]>> The settings below can improve the experience, but in most cases they do not need to be changed
+ # Height of the chat window
+ CHATBOT_HEIGHT = 1115
+
+ # Window layout
+ LAYOUT = "LEFT-RIGHT"  # "LEFT-RIGHT" (side by side) # "TOP-DOWN" (stacked)

  # After a request is sent to OpenAI, how long to wait before it counts as a timeout
  TIMEOUT_SECONDS = 25
@@ -28,11 +36,15 @@ WEB_PORT = -1
  # If OpenAI does not respond (network lag, proxy failure, invalid KEY), how many times to retry
  MAX_RETRY = 2

- # The selected OpenAI model (gpt4 is currently only open to approved applicants)
+ # OpenAI model selection (gpt4 is currently only open to approved applicants)
  LLM_MODEL = "gpt-3.5-turbo"

+ # OpenAI's API_URL
+ API_URL = "https://api.openai.com/v1/chat/completions"
+
  # Number of threads used in parallel
  CONCURRENT_COUNT = 100

- # Set usernames and passwords
- AUTHENTICATION = []  # [("username", "password"), ("username2", "password2"), ...]
+ # Set usernames and passwords (this feature is unstable; it depends on the gradio version and the network, and is not recommended for local use)
+ # [("username", "password"), ("username2", "password2"), ...]
+ AUTHENTICATION = []
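As a quick orientation (not part of the diff): the settings above are read elsewhere in the project through toolbox.get_conf, as check_proxy.py already shows. Whether get_conf accepts several keys at once is not shown in this commit, so the sketch sticks to the single-key form.

```python
# How the settings above are consumed, mirroring the call in check_proxy.py.
from toolbox import get_conf

proxies, = get_conf('proxies')  # returned as a tuple, hence the trailing comma
print(proxies)                  # None unless USE_PROXY is set to True
```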
functional.py → core_functional.py RENAMED
@@ -4,29 +4,38 @@
  # the default button colour is secondary
  from toolbox import clear_line_break

- def get_functionals():
+
+ def get_core_functions():
      return {
          "英语学术润色": {
              # preamble
              "Prefix":   r"Below is a paragraph from an academic paper. Polish the writing to meet the academic style, " +
-                         r"improve the spelling, grammar, clarity, concision and overall readability. When neccessary, rewrite the whole sentence. " +
+                         r"improve the spelling, grammar, clarity, concision and overall readability. When necessary, rewrite the whole sentence. " +
                          r"Furthermore, list all modification and explain the reasons to do so in markdown table." + "\n\n",
              # postscript
              "Suffix":   r"",
              "Color":    r"secondary",    # button colour
          },
          "中文学术润色": {
              "Prefix":   r"作为一名中文学术论文写作改进助理,你的任务是改进所提供文本的拼写、语法、清晰、简洁和整体可读性," +
                          r"同时分解长句,减少重复,并提供改进建议。请只提供文本的更正版本,避免包括解释。请编辑以下文本" + "\n\n",
              "Suffix":   r"",
          },
          "查找语法错误": {
-             "Prefix":   r"Below is a paragraph from an academic paper. " +
-                         r"Can you help me ensure that the grammar and the spelling is correct? " +
+             "Prefix":   r"Can you help me ensure that the grammar and the spelling is correct? " +
                          r"Do not try to polish the text, if no mistake is found, tell me that this paragraph is good." +
                          r"If you find grammar or spelling mistakes, please list mistakes you find in a two-column markdown table, " +
                          r"put the original text the first column, " +
-                         r"put the corrected text in the second column and highlight the key words you fixed." + "\n\n",
+                         r"put the corrected text in the second column and highlight the key words you fixed.""\n"
+                         r"Example:""\n"
+                         r"Paragraph: How is you? Do you knows what is it?""\n"
+                         r"| Original sentence | Corrected sentence |""\n"
+                         r"| :--- | :--- |""\n"
+                         r"| How **is** you? | How **are** you? |""\n"
+                         r"| Do you **knows** what **is** **it**? | Do you **know** what **it** **is** ? |""\n"
+                         r"Below is a paragraph from an academic paper. "
+                         r"You need to report all grammar and spelling mistakes as the example before."
+                         + "\n\n",
              "Suffix":   r"",
              "PreProcess": clear_line_break,    # pre-processing: strip line breaks
          },
@@ -34,9 +43,17 @@ def get_functionals():
              "Prefix":   r"Please translate following sentence to English:" + "\n\n",
              "Suffix":   r"",
          },
-         "学术中译英": {
-             "Prefix":   r"Please translate following sentence to English with academic writing, and provide some related authoritative examples:" + "\n\n",
-             "Suffix":   r"",
+         "学术中英互译": {
+             "Prefix":   r"I want you to act as a scientific English-Chinese translator, " +
+                         r"I will provide you with some paragraphs in one language " +
+                         r"and your task is to accurately and academically translate the paragraphs only into the other language. " +
+                         r"Do not repeat the original provided paragraphs after translation. " +
+                         r"You should use artificial intelligence tools, " +
+                         r"such as natural language processing, and rhetorical knowledge " +
+                         r"and experience about effective writing techniques to reply. " +
+                         r"I'll give you my paragraphs as follows, tell me what language it is written in, and then translate:" + "\n\n",
+             "Suffix":   "",
+             "Color":    "secondary",
          },
          "英译中": {
              "Prefix":   r"请翻译成中文:" + "\n\n",
crazy_functional.py ADDED
@@ -0,0 +1,115 @@
+ from toolbox import HotReload  # HotReload means hot reloading: after editing a function plugin, the change takes effect without restarting the program
+
+
+ def get_crazy_functions():
+     ###################### Plugin group 1 ###########################
+     # [Plugin group 1]: the earliest project plugins plus some demos
+     from crazy_functions.读文章写摘要 import 读文章写摘要
+     from crazy_functions.生成函数注释 import 批量生成函数注释
+     from crazy_functions.解析项目源代码 import 解析项目本身
+     from crazy_functions.解析项目源代码 import 解析一个Python项目
+     from crazy_functions.解析项目源代码 import 解析一个C项目的头文件
+     from crazy_functions.解析项目源代码 import 解析一个C项目
+     from crazy_functions.解析项目源代码 import 解析一个Golang项目
+     from crazy_functions.解析项目源代码 import 解析一个Java项目
+     from crazy_functions.解析项目源代码 import 解析一个Rect项目
+     from crazy_functions.高级功能函数模板 import 高阶功能模板函数
+     from crazy_functions.代码重写为全英文_多线程 import 全项目切换英文
+
+     function_plugins = {
+         "请解析并解构此项目本身(源码自译解)": {
+             "AsButton": False,  # put it in the drop-down menu
+             "Function": HotReload(解析项目本身)
+         },
+         "解析整个Python项目": {
+             "Color": "stop",    # button colour
+             "Function": HotReload(解析一个Python项目)
+         },
+         "解析整个C++项目头文件": {
+             "Color": "stop",    # button colour
+             "Function": HotReload(解析一个C项目的头文件)
+         },
+         "解析整个C++项目(.cpp/.h)": {
+             "Color": "stop",    # button colour
+             "AsButton": False,  # put it in the drop-down menu
+             "Function": HotReload(解析一个C项目)
+         },
+         "解析整个Go项目": {
+             "Color": "stop",    # button colour
+             "AsButton": False,  # put it in the drop-down menu
+             "Function": HotReload(解析一个Golang项目)
+         },
+         "解析整个Java项目": {
+             "Color": "stop",    # button colour
+             "AsButton": False,  # put it in the drop-down menu
+             "Function": HotReload(解析一个Java项目)
+         },
+         "解析整个React项目": {
+             "Color": "stop",    # button colour
+             "AsButton": False,  # put it in the drop-down menu
+             "Function": HotReload(解析一个Rect项目)
+         },
+         "读Tex论文写摘要": {
+             "Color": "stop",    # button colour
+             "Function": HotReload(读文章写摘要)
+         },
+         "批量生成函数注释": {
+             "Color": "stop",    # button colour
+             "Function": HotReload(批量生成函数注释)
+         },
+         "[多线程demo] 把本项目源代码切换成全英文": {
+             # HotReload: after editing the plugin code, the change takes effect without restarting the program
+             "Function": HotReload(全项目切换英文)
+         },
+         "[函数插件模板demo] 历史上的今天": {
+             # HotReload: after editing the plugin code, the change takes effect without restarting the program
+             "Function": HotReload(高阶功能模板函数)
+         },
+     }
+     ###################### Plugin group 2 ###########################
+     # [Plugin group 2]: well tested, but still a little short of perfect
+     from crazy_functions.批量总结PDF文档 import 批量总结PDF文档
+     from crazy_functions.批量总结PDF文档pdfminer import 批量总结PDF文档pdfminer
+     from crazy_functions.总结word文档 import 总结word文档
+     from crazy_functions.批量翻译PDF文档_多线程 import 批量翻译PDF文档
+
+     function_plugins.update({
+         "批量翻译PDF文档(多线程)": {
+             "Color": "stop",
+             "AsButton": True,   # show as a button
+             "Function": HotReload(批量翻译PDF文档)
+         },
+         "[仅供开发调试] 批量总结PDF文档": {
+             "Color": "stop",
+             "AsButton": False,  # put it in the drop-down menu
+             # HotReload: after editing the plugin code, the change takes effect without restarting the program
+             "Function": HotReload(批量总结PDF文档)
+         },
+         "[仅供开发调试] 批量总结PDF文档pdfminer": {
+             "Color": "stop",
+             "AsButton": False,  # put it in the drop-down menu
+             "Function": HotReload(批量总结PDF文档pdfminer)
+         },
+         "批量总结Word文档": {
+             "Color": "stop",
+             "Function": HotReload(总结word文档)
+         },
+     })
+
+     ###################### Plugin group 3 ###########################
+     # [Plugin group 3]: function plugins that have not yet been fully tested go here
+     try:
+         from crazy_functions.下载arxiv论文翻译摘要 import 下载arxiv论文并翻译摘要
+         function_plugins.update({
+             "一键下载arxiv论文并翻译摘要(先在input输入编号,如1812.10695)": {
+                 "Color": "stop",
+                 "AsButton": False,  # put it in the drop-down menu
+                 "Function": HotReload(下载arxiv论文并翻译摘要)
+             }
+         })
+
+     except Exception as err:
+         print(f'[下载arxiv论文并翻译摘要] 插件导入失败 {str(err)}')
+
+     ###################### Plugin group n ###########################
+     return function_plugins
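Adding a further entry follows the same pattern. The sketch below is an illustration only, meant to sit inside get_crazy_functions() before the final return; the menu label is made up, and it simply reuses an existing plugin function as a stand-in.

```python
# Hypothetical extra entry (the label is a placeholder); an existing plugin
# function is reused so the sketch stays self-contained.
from crazy_functions.高级功能函数模板 import 高阶功能模板函数 as my_new_plugin

function_plugins.update({
    "My new plugin (demo)": {
        "Color": "stop",      # button colour
        "AsButton": False,    # put it in the drop-down menu instead of on a button
        "Function": HotReload(my_new_plugin),
    },
})
```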
crazy_functions/__init__.py ADDED
File without changes
crazy_functions/crazy_utils.py ADDED
@@ -0,0 +1,153 @@
+
+
+ def request_gpt_model_in_new_thread_with_ui_alive(inputs, inputs_show_user, top_p, temperature, chatbot, history, sys_prompt, refresh_interval=0.2):
+     import time
+     from concurrent.futures import ThreadPoolExecutor
+     from request_llm.bridge_chatgpt import predict_no_ui_long_connection
+     # user feedback
+     chatbot.append([inputs_show_user, ""])
+     msg = '正常'
+     yield chatbot, [], msg
+     executor = ThreadPoolExecutor(max_workers=16)
+     mutable = ["", time.time()]
+     future = executor.submit(lambda:
+                              predict_no_ui_long_connection(
+                                  inputs=inputs, top_p=top_p, temperature=temperature, history=history, sys_prompt=sys_prompt, observe_window=mutable)
+                              )
+     while True:
+         # yield once to refresh the front-end page
+         time.sleep(refresh_interval)
+         # "feed the dog" (watchdog)
+         mutable[1] = time.time()
+         if future.done():
+             break
+         chatbot[-1] = [chatbot[-1][0], mutable[0]]
+         msg = "正常"
+         yield chatbot, [], msg
+     return future.result()
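Because the helper above is itself a generator (it yields UI refreshes and only returns the model's reply at the end), a caller typically drives it with `yield from`. The sketch below is an illustration only: the plugin's argument list is an assumption, not a signature defined in this commit.

```python
# Hedged sketch of a caller; the argument names here are assumptions.
def demo_plugin(txt, top_p, temperature, chatbot, history, sys_prompt):
    gpt_reply = yield from request_gpt_model_in_new_thread_with_ui_alive(
        inputs=txt, inputs_show_user=txt,
        top_p=top_p, temperature=temperature,
        chatbot=chatbot, history=history, sys_prompt=sys_prompt)
    chatbot[-1] = [txt, gpt_reply]    # show the final answer in the UI
    yield chatbot, history, '正常'
```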
+
+
+ def request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(inputs_array, inputs_show_user_array, top_p, temperature, chatbot, history_array, sys_prompt_array, refresh_interval=0.2, max_workers=10, scroller_max_len=30):
+     import time
+     from concurrent.futures import ThreadPoolExecutor
+     from request_llm.bridge_chatgpt import predict_no_ui_long_connection
+     assert len(inputs_array) == len(history_array)
+     assert len(inputs_array) == len(sys_prompt_array)
+     executor = ThreadPoolExecutor(max_workers=max_workers)
+     n_frag = len(inputs_array)
+     # user feedback
+     chatbot.append(["请开始多线程操作。", ""])
+     msg = '正常'
+     yield chatbot, [], msg
+     # per-thread shared state
+     mutable = [["", time.time()] for _ in range(n_frag)]
+
+     def _req_gpt(index, inputs, history, sys_prompt):
+         gpt_say = predict_no_ui_long_connection(
+             inputs=inputs, top_p=top_p, temperature=temperature, history=history, sys_prompt=sys_prompt, observe_window=mutable[index]
+         )
+         return gpt_say
+     # start the asynchronous tasks
+     futures = [executor.submit(_req_gpt, index, inputs, history, sys_prompt) for index, inputs, history, sys_prompt in zip(
+         range(len(inputs_array)), inputs_array, history_array, sys_prompt_array)]
+     cnt = 0
+     while True:
+         # yield once to refresh the front-end page
+         time.sleep(refresh_interval)
+         cnt += 1
+         worker_done = [h.done() for h in futures]
+         if all(worker_done):
+             executor.shutdown()
+             break
+         # nicer UI feedback
+         observe_win = []
+         # every thread has to "feed the dog" (watchdog)
+         for thread_index, _ in enumerate(worker_done):
+             mutable[thread_index][1] = time.time()
+         # print something fun on the front end
+         for thread_index, _ in enumerate(worker_done):
+             print_something_really_funny = "[ ...`"+mutable[thread_index][0][-scroller_max_len:].\
+                 replace('\n', '').replace('```', '...').replace(
+                     ' ', '.').replace('<br/>', '.....').replace('$', '.')+"`... ]"
+             observe_win.append(print_something_really_funny)
+         stat_str = ''.join([f'执行中: {obs}\n\n' if not done else '已完成\n\n' for done, obs in zip(
+             worker_done, observe_win)])
+         chatbot[-1] = [chatbot[-1][0],
+                        f'多线程操作已经开始,完成情况: \n\n{stat_str}' + ''.join(['.']*(cnt % 10+1))]
+         msg = "正常"
+         yield chatbot, [], msg
+     # the asynchronous tasks have finished
+     gpt_response_collection = []
+     for inputs_show_user, f in zip(inputs_show_user_array, futures):
+         gpt_res = f.result()
+         gpt_response_collection.extend([inputs_show_user, gpt_res])
+     return gpt_response_collection
+
+
+ def breakdown_txt_to_satisfy_token_limit(txt, get_token_fn, limit):
+     def cut(txt_tocut, must_break_at_empty_line):  # recursive
+         if get_token_fn(txt_tocut) <= limit:
+             return [txt_tocut]
+         else:
+             lines = txt_tocut.split('\n')
+             estimated_line_cut = limit / get_token_fn(txt_tocut) * len(lines)
+             estimated_line_cut = int(estimated_line_cut)
+             for cnt in reversed(range(estimated_line_cut)):
+                 if must_break_at_empty_line:
+                     if lines[cnt] != "":
+                         continue
+                 print(cnt)
+                 prev = "\n".join(lines[:cnt])
+                 post = "\n".join(lines[cnt:])
+                 if get_token_fn(prev) < limit:
+                     break
+             if cnt == 0:
+                 print('what the fuck ?')
+                 raise RuntimeError("存在一行极长的文本!")
+             # print(len(post))
+             # recursive list chaining
+             result = [prev]
+             result.extend(cut(post, must_break_at_empty_line))
+             return result
+     try:
+         return cut(txt, must_break_at_empty_line=True)
+     except RuntimeError:
+         return cut(txt, must_break_at_empty_line=False)
+
+
+ def breakdown_txt_to_satisfy_token_limit_for_pdf(txt, get_token_fn, limit):
+     def cut(txt_tocut, must_break_at_empty_line):  # recursive
+         if get_token_fn(txt_tocut) <= limit:
+             return [txt_tocut]
+         else:
+             lines = txt_tocut.split('\n')
+             estimated_line_cut = limit / get_token_fn(txt_tocut) * len(lines)
+             estimated_line_cut = int(estimated_line_cut)
+             cnt = 0
+             for cnt in reversed(range(estimated_line_cut)):
+                 if must_break_at_empty_line:
+                     if lines[cnt] != "":
+                         continue
+                 print(cnt)
+                 prev = "\n".join(lines[:cnt])
+                 post = "\n".join(lines[cnt:])
+                 if get_token_fn(prev) < limit:
+                     break
+             if cnt == 0:
+                 # print('what the fuck ? 存在一行极长的文本!')
+                 raise RuntimeError("存在一行极长的文本!")
+             # print(len(post))
+             # recursive list chaining
+             result = [prev]
+             result.extend(cut(post, must_break_at_empty_line))
+             return result
+     try:
+         return cut(txt, must_break_at_empty_line=True)
+     except RuntimeError:
+         try:
+             return cut(txt, must_break_at_empty_line=False)
+         except RuntimeError:
+             # this Chinese full stop is deliberate; it serves as a marker
+             res = cut(txt.replace('.', '。\n'), must_break_at_empty_line=False)
+             return [r.replace('。\n', '.') for r in res]
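A quick usage sketch for breakdown_txt_to_satisfy_token_limit follows. The token counter below simply counts characters, which is an assumption made for illustration; any callable that maps text to a token count would do.

```python
# Illustrative only: use character count as a stand-in for a real tokenizer.
def get_token_fn(txt):
    return len(txt)

long_text = ("paragraph one, repeated line\n" * 40) + "\n" + ("paragraph two, repeated line\n" * 40)
fragments = breakdown_txt_to_satisfy_token_limit(long_text, get_token_fn, limit=500)
for frag in fragments:
    assert get_token_fn(frag) <= 500   # every fragment respects the limit
print(f"split into {len(fragments)} fragments")
```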
crazy_functions/test_project/cpp/longcode/jpgd.cpp ADDED
@@ -0,0 +1,3276 @@
1
+ // jpgd.cpp - C++ class for JPEG decompression.
2
+ // Public domain, Rich Geldreich <[email protected]>
3
+ // Last updated Apr. 16, 2011
4
+ // Alex Evans: Linear memory allocator (taken from jpge.h).
5
+ //
6
+ // Supports progressive and baseline sequential JPEG image files, and the most common chroma subsampling factors: Y, H1V1, H2V1, H1V2, and H2V2.
7
+ //
8
+ // Chroma upsampling quality: H2V2 is upsampled in the frequency domain, H2V1 and H1V2 are upsampled using point sampling.
9
+ // Chroma upsampling reference: "Fast Scheme for Image Size Change in the Compressed Domain"
10
+ // http://vision.ai.uiuc.edu/~dugad/research/dct/index.html
11
+
12
+ #include "jpgd.h"
13
+ #include <string.h>
14
+
15
+ #include <assert.h>
16
+ // BEGIN EPIC MOD
17
+ #define JPGD_ASSERT(x) { assert(x); CA_ASSUME(x); } (void)0
18
+ // END EPIC MOD
19
+
20
+ #ifdef _MSC_VER
21
+ #pragma warning (disable : 4611) // warning C4611: interaction between '_setjmp' and C++ object destruction is non-portable
22
+ #endif
23
+
24
+ // Set to 1 to enable freq. domain chroma upsampling on images using H2V2 subsampling (0=faster nearest neighbor sampling).
25
+ // This is slower, but results in higher quality on images with highly saturated colors.
26
+ #define JPGD_SUPPORT_FREQ_DOMAIN_UPSAMPLING 1
27
+
28
+ #define JPGD_TRUE (1)
29
+ #define JPGD_FALSE (0)
30
+
31
+ #define JPGD_MAX(a,b) (((a)>(b)) ? (a) : (b))
32
+ #define JPGD_MIN(a,b) (((a)<(b)) ? (a) : (b))
33
+
34
+ namespace jpgd {
35
+
36
+ static inline void *jpgd_malloc(size_t nSize) { return FMemory::Malloc(nSize); }
37
+ static inline void jpgd_free(void *p) { FMemory::Free(p); }
38
+
39
+ // BEGIN EPIC MOD
40
+ //@UE3 - use UE3 BGRA encoding instead of assuming RGBA
41
+ // stolen from IImageWrapper.h
42
+ enum ERGBFormatJPG
43
+ {
44
+ Invalid = -1,
45
+ RGBA = 0,
46
+ BGRA = 1,
47
+ Gray = 2,
48
+ };
49
+ static ERGBFormatJPG jpg_format;
50
+ // END EPIC MOD
51
+
52
+ // DCT coefficients are stored in this sequence.
53
+ static int g_ZAG[64] = { 0,1,8,16,9,2,3,10,17,24,32,25,18,11,4,5,12,19,26,33,40,48,41,34,27,20,13,6,7,14,21,28,35,42,49,56,57,50,43,36,29,22,15,23,30,37,44,51,58,59,52,45,38,31,39,46,53,60,61,54,47,55,62,63 };
54
+
55
+ enum JPEG_MARKER
56
+ {
57
+ M_SOF0 = 0xC0, M_SOF1 = 0xC1, M_SOF2 = 0xC2, M_SOF3 = 0xC3, M_SOF5 = 0xC5, M_SOF6 = 0xC6, M_SOF7 = 0xC7, M_JPG = 0xC8,
58
+ M_SOF9 = 0xC9, M_SOF10 = 0xCA, M_SOF11 = 0xCB, M_SOF13 = 0xCD, M_SOF14 = 0xCE, M_SOF15 = 0xCF, M_DHT = 0xC4, M_DAC = 0xCC,
59
+ M_RST0 = 0xD0, M_RST1 = 0xD1, M_RST2 = 0xD2, M_RST3 = 0xD3, M_RST4 = 0xD4, M_RST5 = 0xD5, M_RST6 = 0xD6, M_RST7 = 0xD7,
60
+ M_SOI = 0xD8, M_EOI = 0xD9, M_SOS = 0xDA, M_DQT = 0xDB, M_DNL = 0xDC, M_DRI = 0xDD, M_DHP = 0xDE, M_EXP = 0xDF,
61
+ M_APP0 = 0xE0, M_APP15 = 0xEF, M_JPG0 = 0xF0, M_JPG13 = 0xFD, M_COM = 0xFE, M_TEM = 0x01, M_ERROR = 0x100, RST0 = 0xD0
62
+ };
63
+
64
+ enum JPEG_SUBSAMPLING { JPGD_GRAYSCALE = 0, JPGD_YH1V1, JPGD_YH2V1, JPGD_YH1V2, JPGD_YH2V2 };
65
+
66
+ #define CONST_BITS 13
67
+ #define PASS1_BITS 2
68
+ #define SCALEDONE ((int32)1)
69
+
70
+ #define FIX_0_298631336 ((int32)2446) /* FIX(0.298631336) */
71
+ #define FIX_0_390180644 ((int32)3196) /* FIX(0.390180644) */
72
+ #define FIX_0_541196100 ((int32)4433) /* FIX(0.541196100) */
73
+ #define FIX_0_765366865 ((int32)6270) /* FIX(0.765366865) */
74
+ #define FIX_0_899976223 ((int32)7373) /* FIX(0.899976223) */
75
+ #define FIX_1_175875602 ((int32)9633) /* FIX(1.175875602) */
76
+ #define FIX_1_501321110 ((int32)12299) /* FIX(1.501321110) */
77
+ #define FIX_1_847759065 ((int32)15137) /* FIX(1.847759065) */
78
+ #define FIX_1_961570560 ((int32)16069) /* FIX(1.961570560) */
79
+ #define FIX_2_053119869 ((int32)16819) /* FIX(2.053119869) */
80
+ #define FIX_2_562915447 ((int32)20995) /* FIX(2.562915447) */
81
+ #define FIX_3_072711026 ((int32)25172) /* FIX(3.072711026) */
82
+
83
+ #define DESCALE(x,n) (((x) + (SCALEDONE << ((n)-1))) >> (n))
84
+ #define DESCALE_ZEROSHIFT(x,n) (((x) + (128 << (n)) + (SCALEDONE << ((n)-1))) >> (n))
85
+
86
+ #define MULTIPLY(var, cnst) ((var) * (cnst))
87
+
88
+ #define CLAMP(i) ((static_cast<uint>(i) > 255) ? (((~i) >> 31) & 0xFF) : (i))
89
+
90
+ // Compiler creates a fast path 1D IDCT for X non-zero columns
91
+ template <int NONZERO_COLS>
92
+ struct Row
93
+ {
94
+ static void idct(int* pTemp, const jpgd_block_t* pSrc)
95
+ {
96
+ // ACCESS_COL() will be optimized at compile time to either an array access, or 0.
97
+ #define ACCESS_COL(x) (((x) < NONZERO_COLS) ? (int)pSrc[x] : 0)
98
+
99
+ const int z2 = ACCESS_COL(2), z3 = ACCESS_COL(6);
100
+
101
+ const int z1 = MULTIPLY(z2 + z3, FIX_0_541196100);
102
+ const int tmp2 = z1 + MULTIPLY(z3, - FIX_1_847759065);
103
+ const int tmp3 = z1 + MULTIPLY(z2, FIX_0_765366865);
104
+
105
+ const int tmp0 = (ACCESS_COL(0) + ACCESS_COL(4)) << CONST_BITS;
106
+ const int tmp1 = (ACCESS_COL(0) - ACCESS_COL(4)) << CONST_BITS;
107
+
108
+ const int tmp10 = tmp0 + tmp3, tmp13 = tmp0 - tmp3, tmp11 = tmp1 + tmp2, tmp12 = tmp1 - tmp2;
109
+
110
+ const int atmp0 = ACCESS_COL(7), atmp1 = ACCESS_COL(5), atmp2 = ACCESS_COL(3), atmp3 = ACCESS_COL(1);
111
+
112
+ const int bz1 = atmp0 + atmp3, bz2 = atmp1 + atmp2, bz3 = atmp0 + atmp2, bz4 = atmp1 + atmp3;
113
+ const int bz5 = MULTIPLY(bz3 + bz4, FIX_1_175875602);
114
+
115
+ const int az1 = MULTIPLY(bz1, - FIX_0_899976223);
116
+ const int az2 = MULTIPLY(bz2, - FIX_2_562915447);
117
+ const int az3 = MULTIPLY(bz3, - FIX_1_961570560) + bz5;
118
+ const int az4 = MULTIPLY(bz4, - FIX_0_390180644) + bz5;
119
+
120
+ const int btmp0 = MULTIPLY(atmp0, FIX_0_298631336) + az1 + az3;
121
+ const int btmp1 = MULTIPLY(atmp1, FIX_2_053119869) + az2 + az4;
122
+ const int btmp2 = MULTIPLY(atmp2, FIX_3_072711026) + az2 + az3;
123
+ const int btmp3 = MULTIPLY(atmp3, FIX_1_501321110) + az1 + az4;
124
+
125
+ pTemp[0] = DESCALE(tmp10 + btmp3, CONST_BITS-PASS1_BITS);
126
+ pTemp[7] = DESCALE(tmp10 - btmp3, CONST_BITS-PASS1_BITS);
127
+ pTemp[1] = DESCALE(tmp11 + btmp2, CONST_BITS-PASS1_BITS);
128
+ pTemp[6] = DESCALE(tmp11 - btmp2, CONST_BITS-PASS1_BITS);
129
+ pTemp[2] = DESCALE(tmp12 + btmp1, CONST_BITS-PASS1_BITS);
130
+ pTemp[5] = DESCALE(tmp12 - btmp1, CONST_BITS-PASS1_BITS);
131
+ pTemp[3] = DESCALE(tmp13 + btmp0, CONST_BITS-PASS1_BITS);
132
+ pTemp[4] = DESCALE(tmp13 - btmp0, CONST_BITS-PASS1_BITS);
133
+ }
134
+ };
135
+
136
+ template <>
137
+ struct Row<0>
138
+ {
139
+ static void idct(int* pTemp, const jpgd_block_t* pSrc)
140
+ {
141
+ #ifdef _MSC_VER
142
+ pTemp; pSrc;
143
+ #endif
144
+ }
145
+ };
146
+
147
+ template <>
148
+ struct Row<1>
149
+ {
150
+ static void idct(int* pTemp, const jpgd_block_t* pSrc)
151
+ {
152
+ const int dcval = (pSrc[0] << PASS1_BITS);
153
+
154
+ pTemp[0] = dcval;
155
+ pTemp[1] = dcval;
156
+ pTemp[2] = dcval;
157
+ pTemp[3] = dcval;
158
+ pTemp[4] = dcval;
159
+ pTemp[5] = dcval;
160
+ pTemp[6] = dcval;
161
+ pTemp[7] = dcval;
162
+ }
163
+ };
164
+
165
+ // Compiler creates a fast path 1D IDCT for X non-zero rows
166
+ template <int NONZERO_ROWS>
167
+ struct Col
168
+ {
169
+ static void idct(uint8* pDst_ptr, const int* pTemp)
170
+ {
171
+ // ACCESS_ROW() will be optimized at compile time to either an array access, or 0.
172
+ #define ACCESS_ROW(x) (((x) < NONZERO_ROWS) ? pTemp[x * 8] : 0)
173
+
174
+ const int z2 = ACCESS_ROW(2);
175
+ const int z3 = ACCESS_ROW(6);
176
+
177
+ const int z1 = MULTIPLY(z2 + z3, FIX_0_541196100);
178
+ const int tmp2 = z1 + MULTIPLY(z3, - FIX_1_847759065);
179
+ const int tmp3 = z1 + MULTIPLY(z2, FIX_0_765366865);
180
+
181
+ const int tmp0 = (ACCESS_ROW(0) + ACCESS_ROW(4)) << CONST_BITS;
182
+ const int tmp1 = (ACCESS_ROW(0) - ACCESS_ROW(4)) << CONST_BITS;
183
+
184
+ const int tmp10 = tmp0 + tmp3, tmp13 = tmp0 - tmp3, tmp11 = tmp1 + tmp2, tmp12 = tmp1 - tmp2;
185
+
186
+ const int atmp0 = ACCESS_ROW(7), atmp1 = ACCESS_ROW(5), atmp2 = ACCESS_ROW(3), atmp3 = ACCESS_ROW(1);
187
+
188
+ const int bz1 = atmp0 + atmp3, bz2 = atmp1 + atmp2, bz3 = atmp0 + atmp2, bz4 = atmp1 + atmp3;
189
+ const int bz5 = MULTIPLY(bz3 + bz4, FIX_1_175875602);
190
+
191
+ const int az1 = MULTIPLY(bz1, - FIX_0_899976223);
192
+ const int az2 = MULTIPLY(bz2, - FIX_2_562915447);
193
+ const int az3 = MULTIPLY(bz3, - FIX_1_961570560) + bz5;
194
+ const int az4 = MULTIPLY(bz4, - FIX_0_390180644) + bz5;
195
+
196
+ const int btmp0 = MULTIPLY(atmp0, FIX_0_298631336) + az1 + az3;
197
+ const int btmp1 = MULTIPLY(atmp1, FIX_2_053119869) + az2 + az4;
198
+ const int btmp2 = MULTIPLY(atmp2, FIX_3_072711026) + az2 + az3;
199
+ const int btmp3 = MULTIPLY(atmp3, FIX_1_501321110) + az1 + az4;
200
+
201
+ int i = DESCALE_ZEROSHIFT(tmp10 + btmp3, CONST_BITS+PASS1_BITS+3);
202
+ pDst_ptr[8*0] = (uint8)CLAMP(i);
203
+
204
+ i = DESCALE_ZEROSHIFT(tmp10 - btmp3, CONST_BITS+PASS1_BITS+3);
205
+ pDst_ptr[8*7] = (uint8)CLAMP(i);
206
+
207
+ i = DESCALE_ZEROSHIFT(tmp11 + btmp2, CONST_BITS+PASS1_BITS+3);
208
+ pDst_ptr[8*1] = (uint8)CLAMP(i);
209
+
210
+ i = DESCALE_ZEROSHIFT(tmp11 - btmp2, CONST_BITS+PASS1_BITS+3);
211
+ pDst_ptr[8*6] = (uint8)CLAMP(i);
212
+
213
+ i = DESCALE_ZEROSHIFT(tmp12 + btmp1, CONST_BITS+PASS1_BITS+3);
214
+ pDst_ptr[8*2] = (uint8)CLAMP(i);
215
+
216
+ i = DESCALE_ZEROSHIFT(tmp12 - btmp1, CONST_BITS+PASS1_BITS+3);
217
+ pDst_ptr[8*5] = (uint8)CLAMP(i);
218
+
219
+ i = DESCALE_ZEROSHIFT(tmp13 + btmp0, CONST_BITS+PASS1_BITS+3);
220
+ pDst_ptr[8*3] = (uint8)CLAMP(i);
221
+
222
+ i = DESCALE_ZEROSHIFT(tmp13 - btmp0, CONST_BITS+PASS1_BITS+3);
223
+ pDst_ptr[8*4] = (uint8)CLAMP(i);
224
+ }
225
+ };
226
+
227
+ template <>
228
+ struct Col<1>
229
+ {
230
+ static void idct(uint8* pDst_ptr, const int* pTemp)
231
+ {
232
+ int dcval = DESCALE_ZEROSHIFT(pTemp[0], PASS1_BITS+3);
233
+ const uint8 dcval_clamped = (uint8)CLAMP(dcval);
234
+ pDst_ptr[0*8] = dcval_clamped;
235
+ pDst_ptr[1*8] = dcval_clamped;
236
+ pDst_ptr[2*8] = dcval_clamped;
237
+ pDst_ptr[3*8] = dcval_clamped;
238
+ pDst_ptr[4*8] = dcval_clamped;
239
+ pDst_ptr[5*8] = dcval_clamped;
240
+ pDst_ptr[6*8] = dcval_clamped;
241
+ pDst_ptr[7*8] = dcval_clamped;
242
+ }
243
+ };
244
+
245
+ static const uint8 s_idct_row_table[] =
246
+ {
247
+ 1,0,0,0,0,0,0,0, 2,0,0,0,0,0,0,0, 2,1,0,0,0,0,0,0, 2,1,1,0,0,0,0,0, 2,2,1,0,0,0,0,0, 3,2,1,0,0,0,0,0, 4,2,1,0,0,0,0,0, 4,3,1,0,0,0,0,0,
248
+ 4,3,2,0,0,0,0,0, 4,3,2,1,0,0,0,0, 4,3,2,1,1,0,0,0, 4,3,2,2,1,0,0,0, 4,3,3,2,1,0,0,0, 4,4,3,2,1,0,0,0, 5,4,3,2,1,0,0,0, 6,4,3,2,1,0,0,0,
249
+ 6,5,3,2,1,0,0,0, 6,5,4,2,1,0,0,0, 6,5,4,3,1,0,0,0, 6,5,4,3,2,0,0,0, 6,5,4,3,2,1,0,0, 6,5,4,3,2,1,1,0, 6,5,4,3,2,2,1,0, 6,5,4,3,3,2,1,0,
250
+ 6,5,4,4,3,2,1,0, 6,5,5,4,3,2,1,0, 6,6,5,4,3,2,1,0, 7,6,5,4,3,2,1,0, 8,6,5,4,3,2,1,0, 8,7,5,4,3,2,1,0, 8,7,6,4,3,2,1,0, 8,7,6,5,3,2,1,0,
251
+ 8,7,6,5,4,2,1,0, 8,7,6,5,4,3,1,0, 8,7,6,5,4,3,2,0, 8,7,6,5,4,3,2,1, 8,7,6,5,4,3,2,2, 8,7,6,5,4,3,3,2, 8,7,6,5,4,4,3,2, 8,7,6,5,5,4,3,2,
252
+ 8,7,6,6,5,4,3,2, 8,7,7,6,5,4,3,2, 8,8,7,6,5,4,3,2, 8,8,8,6,5,4,3,2, 8,8,8,7,5,4,3,2, 8,8,8,7,6,4,3,2, 8,8,8,7,6,5,3,2, 8,8,8,7,6,5,4,2,
253
+ 8,8,8,7,6,5,4,3, 8,8,8,7,6,5,4,4, 8,8,8,7,6,5,5,4, 8,8,8,7,6,6,5,4, 8,8,8,7,7,6,5,4, 8,8,8,8,7,6,5,4, 8,8,8,8,8,6,5,4, 8,8,8,8,8,7,5,4,
254
+ 8,8,8,8,8,7,6,4, 8,8,8,8,8,7,6,5, 8,8,8,8,8,7,6,6, 8,8,8,8,8,7,7,6, 8,8,8,8,8,8,7,6, 8,8,8,8,8,8,8,6, 8,8,8,8,8,8,8,7, 8,8,8,8,8,8,8,8,
255
+ };
256
+
257
+ static const uint8 s_idct_col_table[] = { 1, 1, 2, 3, 3, 3, 3, 3, 3, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8 };
258
+
259
+ void idct(const jpgd_block_t* pSrc_ptr, uint8* pDst_ptr, int block_max_zag)
260
+ {
261
+ JPGD_ASSERT(block_max_zag >= 1);
262
+ JPGD_ASSERT(block_max_zag <= 64);
263
+
264
+ if (block_max_zag == 1)
265
+ {
266
+ int k = ((pSrc_ptr[0] + 4) >> 3) + 128;
267
+ k = CLAMP(k);
268
+ k = k | (k<<8);
269
+ k = k | (k<<16);
270
+
271
+ for (int i = 8; i > 0; i--)
272
+ {
273
+ *(int*)&pDst_ptr[0] = k;
274
+ *(int*)&pDst_ptr[4] = k;
275
+ pDst_ptr += 8;
276
+ }
277
+ return;
278
+ }
279
+
280
+ int temp[64];
281
+
282
+ const jpgd_block_t* pSrc = pSrc_ptr;
283
+ int* pTemp = temp;
284
+
285
+ const uint8* pRow_tab = &s_idct_row_table[(block_max_zag - 1) * 8];
286
+ int i;
287
+ for (i = 8; i > 0; i--, pRow_tab++)
288
+ {
289
+ switch (*pRow_tab)
290
+ {
291
+ case 0: Row<0>::idct(pTemp, pSrc); break;
292
+ case 1: Row<1>::idct(pTemp, pSrc); break;
293
+ case 2: Row<2>::idct(pTemp, pSrc); break;
294
+ case 3: Row<3>::idct(pTemp, pSrc); break;
295
+ case 4: Row<4>::idct(pTemp, pSrc); break;
296
+ case 5: Row<5>::idct(pTemp, pSrc); break;
297
+ case 6: Row<6>::idct(pTemp, pSrc); break;
298
+ case 7: Row<7>::idct(pTemp, pSrc); break;
299
+ case 8: Row<8>::idct(pTemp, pSrc); break;
300
+ }
301
+
302
+ pSrc += 8;
303
+ pTemp += 8;
304
+ }
305
+
306
+ pTemp = temp;
307
+
308
+ const int nonzero_rows = s_idct_col_table[block_max_zag - 1];
309
+ for (i = 8; i > 0; i--)
310
+ {
311
+ switch (nonzero_rows)
312
+ {
313
+ case 1: Col<1>::idct(pDst_ptr, pTemp); break;
314
+ case 2: Col<2>::idct(pDst_ptr, pTemp); break;
315
+ case 3: Col<3>::idct(pDst_ptr, pTemp); break;
316
+ case 4: Col<4>::idct(pDst_ptr, pTemp); break;
317
+ case 5: Col<5>::idct(pDst_ptr, pTemp); break;
318
+ case 6: Col<6>::idct(pDst_ptr, pTemp); break;
319
+ case 7: Col<7>::idct(pDst_ptr, pTemp); break;
320
+ case 8: Col<8>::idct(pDst_ptr, pTemp); break;
321
+ }
322
+
323
+ pTemp++;
324
+ pDst_ptr++;
325
+ }
326
+ }
327
+
328
+ void idct_4x4(const jpgd_block_t* pSrc_ptr, uint8* pDst_ptr)
329
+ {
330
+ int temp[64];
331
+ int* pTemp = temp;
332
+ const jpgd_block_t* pSrc = pSrc_ptr;
333
+
334
+ for (int i = 4; i > 0; i--)
335
+ {
336
+ Row<4>::idct(pTemp, pSrc);
337
+ pSrc += 8;
338
+ pTemp += 8;
339
+ }
340
+
341
+ pTemp = temp;
342
+ for (int i = 8; i > 0; i--)
343
+ {
344
+ Col<4>::idct(pDst_ptr, pTemp);
345
+ pTemp++;
346
+ pDst_ptr++;
347
+ }
348
+ }
349
+
350
+ // Retrieve one character from the input stream.
351
+ inline uint jpeg_decoder::get_char()
352
+ {
353
+ // Any bytes remaining in buffer?
354
+ if (!m_in_buf_left)
355
+ {
356
+ // Try to get more bytes.
357
+ prep_in_buffer();
358
+ // Still nothing to get?
359
+ if (!m_in_buf_left)
360
+ {
361
+ // Pad the end of the stream with 0xFF 0xD9 (EOI marker)
362
+ int t = m_tem_flag;
363
+ m_tem_flag ^= 1;
364
+ if (t)
365
+ return 0xD9;
366
+ else
367
+ return 0xFF;
368
+ }
369
+ }
370
+
371
+ uint c = *m_pIn_buf_ofs++;
372
+ m_in_buf_left--;
373
+
374
+ return c;
375
+ }
376
+
377
+ // Same as previous method, except can indicate if the character is a pad character or not.
378
+ inline uint jpeg_decoder::get_char(bool *pPadding_flag)
379
+ {
380
+ if (!m_in_buf_left)
381
+ {
382
+ prep_in_buffer();
383
+ if (!m_in_buf_left)
384
+ {
385
+ *pPadding_flag = true;
386
+ int t = m_tem_flag;
387
+ m_tem_flag ^= 1;
388
+ if (t)
389
+ return 0xD9;
390
+ else
391
+ return 0xFF;
392
+ }
393
+ }
394
+
395
+ *pPadding_flag = false;
396
+
397
+ uint c = *m_pIn_buf_ofs++;
398
+ m_in_buf_left--;
399
+
400
+ return c;
401
+ }
402
+
403
+ // Inserts a previously retrieved character back into the input buffer.
404
+ inline void jpeg_decoder::stuff_char(uint8 q)
405
+ {
406
+ *(--m_pIn_buf_ofs) = q;
407
+ m_in_buf_left++;
408
+ }
409
+
410
+ // Retrieves one character from the input stream, but does not read past markers. Will continue to return 0xFF when a marker is encountered.
411
+ inline uint8 jpeg_decoder::get_octet()
412
+ {
413
+ bool padding_flag;
414
+ int c = get_char(&padding_flag);
415
+
416
+ if (c == 0xFF)
417
+ {
418
+ if (padding_flag)
419
+ return 0xFF;
420
+
421
+ c = get_char(&padding_flag);
422
+ if (padding_flag)
423
+ {
424
+ stuff_char(0xFF);
425
+ return 0xFF;
426
+ }
427
+
428
+ if (c == 0x00)
429
+ return 0xFF;
430
+ else
431
+ {
432
+ stuff_char(static_cast<uint8>(c));
433
+ stuff_char(0xFF);
434
+ return 0xFF;
435
+ }
436
+ }
437
+
438
+ return static_cast<uint8>(c);
439
+ }
440
+
441
+ // Retrieves a variable number of bits from the input stream. Does not recognize markers.
442
+ inline uint jpeg_decoder::get_bits(int num_bits)
443
+ {
444
+ if (!num_bits)
445
+ return 0;
446
+
447
+ uint i = m_bit_buf >> (32 - num_bits);
448
+
449
+ if ((m_bits_left -= num_bits) <= 0)
450
+ {
451
+ m_bit_buf <<= (num_bits += m_bits_left);
452
+
453
+ uint c1 = get_char();
454
+ uint c2 = get_char();
455
+ m_bit_buf = (m_bit_buf & 0xFFFF0000) | (c1 << 8) | c2;
456
+
457
+ m_bit_buf <<= -m_bits_left;
458
+
459
+ m_bits_left += 16;
460
+
461
+ JPGD_ASSERT(m_bits_left >= 0);
462
+ }
463
+ else
464
+ m_bit_buf <<= num_bits;
465
+
466
+ return i;
467
+ }
468
+
469
+ // Retrieves a variable number of bits from the input stream. Markers will not be read into the input bit buffer. Instead, an infinite number of all 1's will be returned when a marker is encountered.
470
+ inline uint jpeg_decoder::get_bits_no_markers(int num_bits)
471
+ {
472
+ if (!num_bits)
473
+ return 0;
474
+
475
+ uint i = m_bit_buf >> (32 - num_bits);
476
+
477
+ if ((m_bits_left -= num_bits) <= 0)
478
+ {
479
+ m_bit_buf <<= (num_bits += m_bits_left);
480
+
481
+ if ((m_in_buf_left < 2) || (m_pIn_buf_ofs[0] == 0xFF) || (m_pIn_buf_ofs[1] == 0xFF))
482
+ {
483
+ uint c1 = get_octet();
484
+ uint c2 = get_octet();
485
+ m_bit_buf |= (c1 << 8) | c2;
486
+ }
487
+ else
488
+ {
489
+ m_bit_buf |= ((uint)m_pIn_buf_ofs[0] << 8) | m_pIn_buf_ofs[1];
490
+ m_in_buf_left -= 2;
491
+ m_pIn_buf_ofs += 2;
492
+ }
493
+
494
+ m_bit_buf <<= -m_bits_left;
495
+
496
+ m_bits_left += 16;
497
+
498
+ JPGD_ASSERT(m_bits_left >= 0);
499
+ }
500
+ else
501
+ m_bit_buf <<= num_bits;
502
+
503
+ return i;
504
+ }
505
+
506
+ // Decodes a Huffman encoded symbol.
507
+ inline int jpeg_decoder::huff_decode(huff_tables *pH)
508
+ {
509
+ int symbol;
510
+
511
+ // Check first 8-bits: do we have a complete symbol?
512
+ if ((symbol = pH->look_up[m_bit_buf >> 24]) < 0)
513
+ {
514
+ // Decode more bits, use a tree traversal to find symbol.
515
+ int ofs = 23;
516
+ do
517
+ {
518
+ symbol = pH->tree[-(int)(symbol + ((m_bit_buf >> ofs) & 1))];
519
+ ofs--;
520
+ } while (symbol < 0);
521
+
522
+ get_bits_no_markers(8 + (23 - ofs));
523
+ }
524
+ else
525
+ get_bits_no_markers(pH->code_size[symbol]);
526
+
527
+ return symbol;
528
+ }
529
+
530
+ // Decodes a Huffman encoded symbol.
531
+ inline int jpeg_decoder::huff_decode(huff_tables *pH, int& extra_bits)
532
+ {
533
+ int symbol;
534
+
535
+ // Check first 8-bits: do we have a complete symbol?
536
+ if ((symbol = pH->look_up2[m_bit_buf >> 24]) < 0)
537
+ {
538
+ // Use a tree traversal to find symbol.
539
+ int ofs = 23;
540
+ do
541
+ {
542
+ symbol = pH->tree[-(int)(symbol + ((m_bit_buf >> ofs) & 1))];
543
+ ofs--;
544
+ } while (symbol < 0);
545
+
546
+ get_bits_no_markers(8 + (23 - ofs));
547
+
548
+ extra_bits = get_bits_no_markers(symbol & 0xF);
549
+ }
550
+ else
551
+ {
552
+ JPGD_ASSERT(((symbol >> 8) & 31) == pH->code_size[symbol & 255] + ((symbol & 0x8000) ? (symbol & 15) : 0));
553
+
554
+ if (symbol & 0x8000)
555
+ {
556
+ get_bits_no_markers((symbol >> 8) & 31);
557
+ extra_bits = symbol >> 16;
558
+ }
559
+ else
560
+ {
561
+ int code_size = (symbol >> 8) & 31;
562
+ int num_extra_bits = symbol & 0xF;
563
+ int bits = code_size + num_extra_bits;
564
+ if (bits <= (m_bits_left + 16))
565
+ extra_bits = get_bits_no_markers(bits) & ((1 << num_extra_bits) - 1);
566
+ else
567
+ {
568
+ get_bits_no_markers(code_size);
569
+ extra_bits = get_bits_no_markers(num_extra_bits);
570
+ }
571
+ }
572
+
573
+ symbol &= 0xFF;
574
+ }
575
+
576
+ return symbol;
577
+ }
578
+
579
+ // Tables and macro used to fully decode the DPCM differences.
580
+ static const int s_extend_test[16] = { 0, 0x0001, 0x0002, 0x0004, 0x0008, 0x0010, 0x0020, 0x0040, 0x0080, 0x0100, 0x0200, 0x0400, 0x0800, 0x1000, 0x2000, 0x4000 };
581
+ static const int s_extend_offset[16] = { 0, -1, -3, -7, -15, -31, -63, -127, -255, -511, -1023, -2047, -4095, -8191, -16383, -32767 };
582
+ static const int s_extend_mask[] = { 0, (1<<0), (1<<1), (1<<2), (1<<3), (1<<4), (1<<5), (1<<6), (1<<7), (1<<8), (1<<9), (1<<10), (1<<11), (1<<12), (1<<13), (1<<14), (1<<15), (1<<16) };
583
+ #define HUFF_EXTEND(x,s) ((x) < s_extend_test[s] ? (x) + s_extend_offset[s] : (x))
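+ // HUFF_EXTEND converts an s-bit magnitude x into a signed value: values below 2^(s-1) are shifted
+ // into the negative range (x - 2^s + 1), per the JPEG EXTEND procedure.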
584
+
585
+ // Clamps a value between 0-255.
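+ // ((~i) >> 31) is 0 for negative i and (with an arithmetic shift) all ones for i > 255, so the mask
+ // below clamps out-of-range values to 0 or 255 without branching on the sign.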
586
+ inline uint8 jpeg_decoder::clamp(int i)
587
+ {
588
+ if (static_cast<uint>(i) > 255)
589
+ i = (((~i) >> 31) & 0xFF);
590
+
591
+ return static_cast<uint8>(i);
592
+ }
593
+
594
+ namespace DCT_Upsample
595
+ {
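+ // Helpers used by transform_mcu_expand() for frequency-domain chroma upsampling: each 8x8
+ // coefficient block is reduced to four 4x4 matrices (P, Q, R, S) whose sums and differences are fed
+ // through idct_4x4() to produce four upsampled spatial blocks.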
596
+ struct Matrix44
597
+ {
598
+ typedef int Element_Type;
599
+ enum { NUM_ROWS = 4, NUM_COLS = 4 };
600
+
601
+ Element_Type v[NUM_ROWS][NUM_COLS];
602
+
603
+ inline int rows() const { return NUM_ROWS; }
604
+ inline int cols() const { return NUM_COLS; }
605
+
606
+ inline const Element_Type & at(int r, int c) const { return v[r][c]; }
607
+ inline Element_Type & at(int r, int c) { return v[r][c]; }
608
+
609
+ inline Matrix44() { }
610
+
611
+ inline Matrix44& operator += (const Matrix44& a)
612
+ {
613
+ for (int r = 0; r < NUM_ROWS; r++)
614
+ {
615
+ at(r, 0) += a.at(r, 0);
616
+ at(r, 1) += a.at(r, 1);
617
+ at(r, 2) += a.at(r, 2);
618
+ at(r, 3) += a.at(r, 3);
619
+ }
620
+ return *this;
621
+ }
622
+
623
+ inline Matrix44& operator -= (const Matrix44& a)
624
+ {
625
+ for (int r = 0; r < NUM_ROWS; r++)
626
+ {
627
+ at(r, 0) -= a.at(r, 0);
628
+ at(r, 1) -= a.at(r, 1);
629
+ at(r, 2) -= a.at(r, 2);
630
+ at(r, 3) -= a.at(r, 3);
631
+ }
632
+ return *this;
633
+ }
634
+
635
+ friend inline Matrix44 operator + (const Matrix44& a, const Matrix44& b)
636
+ {
637
+ Matrix44 ret;
638
+ for (int r = 0; r < NUM_ROWS; r++)
639
+ {
640
+ ret.at(r, 0) = a.at(r, 0) + b.at(r, 0);
641
+ ret.at(r, 1) = a.at(r, 1) + b.at(r, 1);
642
+ ret.at(r, 2) = a.at(r, 2) + b.at(r, 2);
643
+ ret.at(r, 3) = a.at(r, 3) + b.at(r, 3);
644
+ }
645
+ return ret;
646
+ }
647
+
648
+ friend inline Matrix44 operator - (const Matrix44& a, const Matrix44& b)
649
+ {
650
+ Matrix44 ret;
651
+ for (int r = 0; r < NUM_ROWS; r++)
652
+ {
653
+ ret.at(r, 0) = a.at(r, 0) - b.at(r, 0);
654
+ ret.at(r, 1) = a.at(r, 1) - b.at(r, 1);
655
+ ret.at(r, 2) = a.at(r, 2) - b.at(r, 2);
656
+ ret.at(r, 3) = a.at(r, 3) - b.at(r, 3);
657
+ }
658
+ return ret;
659
+ }
660
+
661
+ static inline void add_and_store(jpgd_block_t* pDst, const Matrix44& a, const Matrix44& b)
662
+ {
663
+ for (int r = 0; r < 4; r++)
664
+ {
665
+ pDst[0*8 + r] = static_cast<jpgd_block_t>(a.at(r, 0) + b.at(r, 0));
666
+ pDst[1*8 + r] = static_cast<jpgd_block_t>(a.at(r, 1) + b.at(r, 1));
667
+ pDst[2*8 + r] = static_cast<jpgd_block_t>(a.at(r, 2) + b.at(r, 2));
668
+ pDst[3*8 + r] = static_cast<jpgd_block_t>(a.at(r, 3) + b.at(r, 3));
669
+ }
670
+ }
671
+
672
+ static inline void sub_and_store(jpgd_block_t* pDst, const Matrix44& a, const Matrix44& b)
673
+ {
674
+ for (int r = 0; r < 4; r++)
675
+ {
676
+ pDst[0*8 + r] = static_cast<jpgd_block_t>(a.at(r, 0) - b.at(r, 0));
677
+ pDst[1*8 + r] = static_cast<jpgd_block_t>(a.at(r, 1) - b.at(r, 1));
678
+ pDst[2*8 + r] = static_cast<jpgd_block_t>(a.at(r, 2) - b.at(r, 2));
679
+ pDst[3*8 + r] = static_cast<jpgd_block_t>(a.at(r, 3) - b.at(r, 3));
680
+ }
681
+ }
682
+ };
683
+
684
+ const int FRACT_BITS = 10;
685
+ const int SCALE = 1 << FRACT_BITS;
686
+
687
+ typedef int Temp_Type;
688
+ #define D(i) (((i) + (SCALE >> 1)) >> FRACT_BITS)
689
+ #define F(i) ((int)((i) * SCALE + .5f))
690
+
691
+ // Any decent C++ compiler will optimize this at compile time to a 0, or an array access.
692
+ #define AT(c, r) ((((c)>=NUM_COLS)||((r)>=NUM_ROWS)) ? 0 : pSrc[(c)+(r)*8])
693
+
694
+ // NUM_ROWS/NUM_COLS = # of non-zero rows/cols in input matrix
695
+ template<int NUM_ROWS, int NUM_COLS>
696
+ struct P_Q
697
+ {
698
+ static void calc(Matrix44& P, Matrix44& Q, const jpgd_block_t* pSrc)
699
+ {
700
+ // 4x8 = 4x8 times 8x8, matrix 0 is constant
701
+ const Temp_Type X000 = AT(0, 0);
702
+ const Temp_Type X001 = AT(0, 1);
703
+ const Temp_Type X002 = AT(0, 2);
704
+ const Temp_Type X003 = AT(0, 3);
705
+ const Temp_Type X004 = AT(0, 4);
706
+ const Temp_Type X005 = AT(0, 5);
707
+ const Temp_Type X006 = AT(0, 6);
708
+ const Temp_Type X007 = AT(0, 7);
709
+ const Temp_Type X010 = D(F(0.415735f) * AT(1, 0) + F(0.791065f) * AT(3, 0) + F(-0.352443f) * AT(5, 0) + F(0.277785f) * AT(7, 0));
710
+ const Temp_Type X011 = D(F(0.415735f) * AT(1, 1) + F(0.791065f) * AT(3, 1) + F(-0.352443f) * AT(5, 1) + F(0.277785f) * AT(7, 1));
711
+ const Temp_Type X012 = D(F(0.415735f) * AT(1, 2) + F(0.791065f) * AT(3, 2) + F(-0.352443f) * AT(5, 2) + F(0.277785f) * AT(7, 2));
712
+ const Temp_Type X013 = D(F(0.415735f) * AT(1, 3) + F(0.791065f) * AT(3, 3) + F(-0.352443f) * AT(5, 3) + F(0.277785f) * AT(7, 3));
713
+ const Temp_Type X014 = D(F(0.415735f) * AT(1, 4) + F(0.791065f) * AT(3, 4) + F(-0.352443f) * AT(5, 4) + F(0.277785f) * AT(7, 4));
714
+ const Temp_Type X015 = D(F(0.415735f) * AT(1, 5) + F(0.791065f) * AT(3, 5) + F(-0.352443f) * AT(5, 5) + F(0.277785f) * AT(7, 5));
715
+ const Temp_Type X016 = D(F(0.415735f) * AT(1, 6) + F(0.791065f) * AT(3, 6) + F(-0.352443f) * AT(5, 6) + F(0.277785f) * AT(7, 6));
716
+ const Temp_Type X017 = D(F(0.415735f) * AT(1, 7) + F(0.791065f) * AT(3, 7) + F(-0.352443f) * AT(5, 7) + F(0.277785f) * AT(7, 7));
717
+ const Temp_Type X020 = AT(4, 0);
718
+ const Temp_Type X021 = AT(4, 1);
719
+ const Temp_Type X022 = AT(4, 2);
720
+ const Temp_Type X023 = AT(4, 3);
721
+ const Temp_Type X024 = AT(4, 4);
722
+ const Temp_Type X025 = AT(4, 5);
723
+ const Temp_Type X026 = AT(4, 6);
724
+ const Temp_Type X027 = AT(4, 7);
725
+ const Temp_Type X030 = D(F(0.022887f) * AT(1, 0) + F(-0.097545f) * AT(3, 0) + F(0.490393f) * AT(5, 0) + F(0.865723f) * AT(7, 0));
726
+ const Temp_Type X031 = D(F(0.022887f) * AT(1, 1) + F(-0.097545f) * AT(3, 1) + F(0.490393f) * AT(5, 1) + F(0.865723f) * AT(7, 1));
727
+ const Temp_Type X032 = D(F(0.022887f) * AT(1, 2) + F(-0.097545f) * AT(3, 2) + F(0.490393f) * AT(5, 2) + F(0.865723f) * AT(7, 2));
728
+ const Temp_Type X033 = D(F(0.022887f) * AT(1, 3) + F(-0.097545f) * AT(3, 3) + F(0.490393f) * AT(5, 3) + F(0.865723f) * AT(7, 3));
729
+ const Temp_Type X034 = D(F(0.022887f) * AT(1, 4) + F(-0.097545f) * AT(3, 4) + F(0.490393f) * AT(5, 4) + F(0.865723f) * AT(7, 4));
730
+ const Temp_Type X035 = D(F(0.022887f) * AT(1, 5) + F(-0.097545f) * AT(3, 5) + F(0.490393f) * AT(5, 5) + F(0.865723f) * AT(7, 5));
731
+ const Temp_Type X036 = D(F(0.022887f) * AT(1, 6) + F(-0.097545f) * AT(3, 6) + F(0.490393f) * AT(5, 6) + F(0.865723f) * AT(7, 6));
732
+ const Temp_Type X037 = D(F(0.022887f) * AT(1, 7) + F(-0.097545f) * AT(3, 7) + F(0.490393f) * AT(5, 7) + F(0.865723f) * AT(7, 7));
733
+
734
+ // 4x4 = 4x8 times 8x4, matrix 1 is constant
735
+ P.at(0, 0) = X000;
736
+ P.at(0, 1) = D(X001 * F(0.415735f) + X003 * F(0.791065f) + X005 * F(-0.352443f) + X007 * F(0.277785f));
737
+ P.at(0, 2) = X004;
738
+ P.at(0, 3) = D(X001 * F(0.022887f) + X003 * F(-0.097545f) + X005 * F(0.490393f) + X007 * F(0.865723f));
739
+ P.at(1, 0) = X010;
740
+ P.at(1, 1) = D(X011 * F(0.415735f) + X013 * F(0.791065f) + X015 * F(-0.352443f) + X017 * F(0.277785f));
741
+ P.at(1, 2) = X014;
742
+ P.at(1, 3) = D(X011 * F(0.022887f) + X013 * F(-0.097545f) + X015 * F(0.490393f) + X017 * F(0.865723f));
743
+ P.at(2, 0) = X020;
744
+ P.at(2, 1) = D(X021 * F(0.415735f) + X023 * F(0.791065f) + X025 * F(-0.352443f) + X027 * F(0.277785f));
745
+ P.at(2, 2) = X024;
746
+ P.at(2, 3) = D(X021 * F(0.022887f) + X023 * F(-0.097545f) + X025 * F(0.490393f) + X027 * F(0.865723f));
747
+ P.at(3, 0) = X030;
748
+ P.at(3, 1) = D(X031 * F(0.415735f) + X033 * F(0.791065f) + X035 * F(-0.352443f) + X037 * F(0.277785f));
749
+ P.at(3, 2) = X034;
750
+ P.at(3, 3) = D(X031 * F(0.022887f) + X033 * F(-0.097545f) + X035 * F(0.490393f) + X037 * F(0.865723f));
751
+ // 40 muls 24 adds
752
+
753
+ // 4x4 = 4x8 times 8x4, matrix 1 is constant
754
+ Q.at(0, 0) = D(X001 * F(0.906127f) + X003 * F(-0.318190f) + X005 * F(0.212608f) + X007 * F(-0.180240f));
755
+ Q.at(0, 1) = X002;
756
+ Q.at(0, 2) = D(X001 * F(-0.074658f) + X003 * F(0.513280f) + X005 * F(0.768178f) + X007 * F(-0.375330f));
757
+ Q.at(0, 3) = X006;
758
+ Q.at(1, 0) = D(X011 * F(0.906127f) + X013 * F(-0.318190f) + X015 * F(0.212608f) + X017 * F(-0.180240f));
759
+ Q.at(1, 1) = X012;
760
+ Q.at(1, 2) = D(X011 * F(-0.074658f) + X013 * F(0.513280f) + X015 * F(0.768178f) + X017 * F(-0.375330f));
761
+ Q.at(1, 3) = X016;
762
+ Q.at(2, 0) = D(X021 * F(0.906127f) + X023 * F(-0.318190f) + X025 * F(0.212608f) + X027 * F(-0.180240f));
763
+ Q.at(2, 1) = X022;
764
+ Q.at(2, 2) = D(X021 * F(-0.074658f) + X023 * F(0.513280f) + X025 * F(0.768178f) + X027 * F(-0.375330f));
765
+ Q.at(2, 3) = X026;
766
+ Q.at(3, 0) = D(X031 * F(0.906127f) + X033 * F(-0.318190f) + X035 * F(0.212608f) + X037 * F(-0.180240f));
767
+ Q.at(3, 1) = X032;
768
+ Q.at(3, 2) = D(X031 * F(-0.074658f) + X033 * F(0.513280f) + X035 * F(0.768178f) + X037 * F(-0.375330f));
769
+ Q.at(3, 3) = X036;
770
+ // 40 muls 24 adds
771
+ }
772
+ };
773
+
774
+ template<int NUM_ROWS, int NUM_COLS>
775
+ struct R_S
776
+ {
777
+ static void calc(Matrix44& R, Matrix44& S, const jpgd_block_t* pSrc)
778
+ {
779
+ // 4x8 = 4x8 times 8x8, matrix 0 is constant
780
+ const Temp_Type X100 = D(F(0.906127f) * AT(1, 0) + F(-0.318190f) * AT(3, 0) + F(0.212608f) * AT(5, 0) + F(-0.180240f) * AT(7, 0));
781
+ const Temp_Type X101 = D(F(0.906127f) * AT(1, 1) + F(-0.318190f) * AT(3, 1) + F(0.212608f) * AT(5, 1) + F(-0.180240f) * AT(7, 1));
782
+ const Temp_Type X102 = D(F(0.906127f) * AT(1, 2) + F(-0.318190f) * AT(3, 2) + F(0.212608f) * AT(5, 2) + F(-0.180240f) * AT(7, 2));
783
+ const Temp_Type X103 = D(F(0.906127f) * AT(1, 3) + F(-0.318190f) * AT(3, 3) + F(0.212608f) * AT(5, 3) + F(-0.180240f) * AT(7, 3));
784
+ const Temp_Type X104 = D(F(0.906127f) * AT(1, 4) + F(-0.318190f) * AT(3, 4) + F(0.212608f) * AT(5, 4) + F(-0.180240f) * AT(7, 4));
785
+ const Temp_Type X105 = D(F(0.906127f) * AT(1, 5) + F(-0.318190f) * AT(3, 5) + F(0.212608f) * AT(5, 5) + F(-0.180240f) * AT(7, 5));
786
+ const Temp_Type X106 = D(F(0.906127f) * AT(1, 6) + F(-0.318190f) * AT(3, 6) + F(0.212608f) * AT(5, 6) + F(-0.180240f) * AT(7, 6));
787
+ const Temp_Type X107 = D(F(0.906127f) * AT(1, 7) + F(-0.318190f) * AT(3, 7) + F(0.212608f) * AT(5, 7) + F(-0.180240f) * AT(7, 7));
788
+ const Temp_Type X110 = AT(2, 0);
789
+ const Temp_Type X111 = AT(2, 1);
790
+ const Temp_Type X112 = AT(2, 2);
791
+ const Temp_Type X113 = AT(2, 3);
792
+ const Temp_Type X114 = AT(2, 4);
793
+ const Temp_Type X115 = AT(2, 5);
794
+ const Temp_Type X116 = AT(2, 6);
795
+ const Temp_Type X117 = AT(2, 7);
796
+ const Temp_Type X120 = D(F(-0.074658f) * AT(1, 0) + F(0.513280f) * AT(3, 0) + F(0.768178f) * AT(5, 0) + F(-0.375330f) * AT(7, 0));
797
+ const Temp_Type X121 = D(F(-0.074658f) * AT(1, 1) + F(0.513280f) * AT(3, 1) + F(0.768178f) * AT(5, 1) + F(-0.375330f) * AT(7, 1));
798
+ const Temp_Type X122 = D(F(-0.074658f) * AT(1, 2) + F(0.513280f) * AT(3, 2) + F(0.768178f) * AT(5, 2) + F(-0.375330f) * AT(7, 2));
799
+ const Temp_Type X123 = D(F(-0.074658f) * AT(1, 3) + F(0.513280f) * AT(3, 3) + F(0.768178f) * AT(5, 3) + F(-0.375330f) * AT(7, 3));
800
+ const Temp_Type X124 = D(F(-0.074658f) * AT(1, 4) + F(0.513280f) * AT(3, 4) + F(0.768178f) * AT(5, 4) + F(-0.375330f) * AT(7, 4));
801
+ const Temp_Type X125 = D(F(-0.074658f) * AT(1, 5) + F(0.513280f) * AT(3, 5) + F(0.768178f) * AT(5, 5) + F(-0.375330f) * AT(7, 5));
802
+ const Temp_Type X126 = D(F(-0.074658f) * AT(1, 6) + F(0.513280f) * AT(3, 6) + F(0.768178f) * AT(5, 6) + F(-0.375330f) * AT(7, 6));
803
+ const Temp_Type X127 = D(F(-0.074658f) * AT(1, 7) + F(0.513280f) * AT(3, 7) + F(0.768178f) * AT(5, 7) + F(-0.375330f) * AT(7, 7));
804
+ const Temp_Type X130 = AT(6, 0);
805
+ const Temp_Type X131 = AT(6, 1);
806
+ const Temp_Type X132 = AT(6, 2);
807
+ const Temp_Type X133 = AT(6, 3);
808
+ const Temp_Type X134 = AT(6, 4);
809
+ const Temp_Type X135 = AT(6, 5);
810
+ const Temp_Type X136 = AT(6, 6);
811
+ const Temp_Type X137 = AT(6, 7);
812
+ // 80 muls 48 adds
813
+
814
+ // 4x4 = 4x8 times 8x4, matrix 1 is constant
815
+ R.at(0, 0) = X100;
816
+ R.at(0, 1) = D(X101 * F(0.415735f) + X103 * F(0.791065f) + X105 * F(-0.352443f) + X107 * F(0.277785f));
817
+ R.at(0, 2) = X104;
818
+ R.at(0, 3) = D(X101 * F(0.022887f) + X103 * F(-0.097545f) + X105 * F(0.490393f) + X107 * F(0.865723f));
819
+ R.at(1, 0) = X110;
820
+ R.at(1, 1) = D(X111 * F(0.415735f) + X113 * F(0.791065f) + X115 * F(-0.352443f) + X117 * F(0.277785f));
821
+ R.at(1, 2) = X114;
822
+ R.at(1, 3) = D(X111 * F(0.022887f) + X113 * F(-0.097545f) + X115 * F(0.490393f) + X117 * F(0.865723f));
823
+ R.at(2, 0) = X120;
824
+ R.at(2, 1) = D(X121 * F(0.415735f) + X123 * F(0.791065f) + X125 * F(-0.352443f) + X127 * F(0.277785f));
825
+ R.at(2, 2) = X124;
826
+ R.at(2, 3) = D(X121 * F(0.022887f) + X123 * F(-0.097545f) + X125 * F(0.490393f) + X127 * F(0.865723f));
827
+ R.at(3, 0) = X130;
828
+ R.at(3, 1) = D(X131 * F(0.415735f) + X133 * F(0.791065f) + X135 * F(-0.352443f) + X137 * F(0.277785f));
829
+ R.at(3, 2) = X134;
830
+ R.at(3, 3) = D(X131 * F(0.022887f) + X133 * F(-0.097545f) + X135 * F(0.490393f) + X137 * F(0.865723f));
831
+ // 40 muls 24 adds
832
+ // 4x4 = 4x8 times 8x4, matrix 1 is constant
833
+ S.at(0, 0) = D(X101 * F(0.906127f) + X103 * F(-0.318190f) + X105 * F(0.212608f) + X107 * F(-0.180240f));
834
+ S.at(0, 1) = X102;
835
+ S.at(0, 2) = D(X101 * F(-0.074658f) + X103 * F(0.513280f) + X105 * F(0.768178f) + X107 * F(-0.375330f));
836
+ S.at(0, 3) = X106;
837
+ S.at(1, 0) = D(X111 * F(0.906127f) + X113 * F(-0.318190f) + X115 * F(0.212608f) + X117 * F(-0.180240f));
838
+ S.at(1, 1) = X112;
839
+ S.at(1, 2) = D(X111 * F(-0.074658f) + X113 * F(0.513280f) + X115 * F(0.768178f) + X117 * F(-0.375330f));
840
+ S.at(1, 3) = X116;
841
+ S.at(2, 0) = D(X121 * F(0.906127f) + X123 * F(-0.318190f) + X125 * F(0.212608f) + X127 * F(-0.180240f));
842
+ S.at(2, 1) = X122;
843
+ S.at(2, 2) = D(X121 * F(-0.074658f) + X123 * F(0.513280f) + X125 * F(0.768178f) + X127 * F(-0.375330f));
844
+ S.at(2, 3) = X126;
845
+ S.at(3, 0) = D(X131 * F(0.906127f) + X133 * F(-0.318190f) + X135 * F(0.212608f) + X137 * F(-0.180240f));
846
+ S.at(3, 1) = X132;
847
+ S.at(3, 2) = D(X131 * F(-0.074658f) + X133 * F(0.513280f) + X135 * F(0.768178f) + X137 * F(-0.375330f));
848
+ S.at(3, 3) = X136;
849
+ // 40 muls 24 adds
850
+ }
851
+ };
852
+ } // end namespace DCT_Upsample
853
+
854
+ // Unconditionally frees all allocated m_blocks.
855
+ void jpeg_decoder::free_all_blocks()
856
+ {
857
+ m_pStream = NULL;
858
+ for (mem_block *b = m_pMem_blocks; b; )
859
+ {
860
+ mem_block *n = b->m_pNext;
861
+ jpgd_free(b);
862
+ b = n;
863
+ }
864
+ m_pMem_blocks = NULL;
865
+ }
866
+
867
+ // This method handles all errors.
868
+ // It could easily be changed to use C++ exceptions.
869
+ void jpeg_decoder::stop_decoding(jpgd_status status)
870
+ {
871
+ m_error_code = status;
872
+ free_all_blocks();
873
+ longjmp(m_jmp_state, status);
874
+
875
+ // we shouldn't get here as longjmp shouldn't return, but we put it here to make it explicit
876
+ // that this function doesn't return, otherwise we get this error:
877
+ //
878
+ // error : function declared 'noreturn' should not return
879
+ exit(1);
880
+ }
881
+
882
+ void *jpeg_decoder::alloc(size_t nSize, bool zero)
883
+ {
884
+ nSize = (JPGD_MAX(nSize, 1) + 3) & ~3;
885
+ char *rv = NULL;
886
+ for (mem_block *b = m_pMem_blocks; b; b = b->m_pNext)
887
+ {
888
+ if ((b->m_used_count + nSize) <= b->m_size)
889
+ {
890
+ rv = b->m_data + b->m_used_count;
891
+ b->m_used_count += nSize;
892
+ break;
893
+ }
894
+ }
895
+ if (!rv)
896
+ {
897
+ int capacity = JPGD_MAX(32768 - 256, (nSize + 2047) & ~2047);
898
+ mem_block *b = (mem_block*)jpgd_malloc(sizeof(mem_block) + capacity);
899
+ if (!b) stop_decoding(JPGD_NOTENOUGHMEM);
900
+ b->m_pNext = m_pMem_blocks; m_pMem_blocks = b;
901
+ b->m_used_count = nSize;
902
+ b->m_size = capacity;
903
+ rv = b->m_data;
904
+ }
905
+ if (zero) memset(rv, 0, nSize);
906
+ return rv;
907
+ }
908
+
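+ // Fills n 16-bit words at p with c, low byte first (used by prep_in_buffer() to pad with EOI markers).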
909
+ void jpeg_decoder::word_clear(void *p, uint16 c, uint n)
910
+ {
911
+ uint8 *pD = (uint8*)p;
912
+ const uint8 l = c & 0xFF, h = (c >> 8) & 0xFF;
913
+ while (n)
914
+ {
915
+ pD[0] = l; pD[1] = h; pD += 2;
916
+ n--;
917
+ }
918
+ }
919
+
920
+ // Refill the input buffer.
921
+ // This method will sit in a loop until (A) the buffer is full or (B)
922
+ // the stream's read() method reports an end-of-file condition.
923
+ void jpeg_decoder::prep_in_buffer()
924
+ {
925
+ m_in_buf_left = 0;
926
+ m_pIn_buf_ofs = m_in_buf;
927
+
928
+ if (m_eof_flag)
929
+ return;
930
+
931
+ do
932
+ {
933
+ int bytes_read = m_pStream->read(m_in_buf + m_in_buf_left, JPGD_IN_BUF_SIZE - m_in_buf_left, &m_eof_flag);
934
+ if (bytes_read == -1)
935
+ stop_decoding(JPGD_STREAM_READ);
936
+
937
+ m_in_buf_left += bytes_read;
938
+ } while ((m_in_buf_left < JPGD_IN_BUF_SIZE) && (!m_eof_flag));
939
+
940
+ m_total_bytes_read += m_in_buf_left;
941
+
942
+ // Pad the end of the block with M_EOI (prevents the decompressor from going off the rails if the stream is invalid).
943
+ // (This dates way back to when this decompressor was written in C/asm, and the all-asm Huffman decoder did some fancy things to increase perf.)
944
+ word_clear(m_pIn_buf_ofs + m_in_buf_left, 0xD9FF, 64);
945
+ }
946
+
947
+ // Read a Huffman code table.
948
+ void jpeg_decoder::read_dht_marker()
949
+ {
950
+ int i, index, count;
951
+ uint8 huff_num[17];
952
+ uint8 huff_val[256];
953
+
954
+ uint num_left = get_bits(16);
955
+
956
+ if (num_left < 2)
957
+ stop_decoding(JPGD_BAD_DHT_MARKER);
958
+
959
+ num_left -= 2;
960
+
961
+ while (num_left)
962
+ {
963
+ index = get_bits(8);
964
+
965
+ huff_num[0] = 0;
966
+
967
+ count = 0;
968
+
969
+ for (i = 1; i <= 16; i++)
970
+ {
971
+ huff_num[i] = static_cast<uint8>(get_bits(8));
972
+ count += huff_num[i];
973
+ }
974
+
975
+ if (count > 255)
976
+ stop_decoding(JPGD_BAD_DHT_COUNTS);
977
+
978
+ for (i = 0; i < count; i++)
979
+ huff_val[i] = static_cast<uint8>(get_bits(8));
980
+
981
+ i = 1 + 16 + count;
982
+
983
+ if (num_left < (uint)i)
984
+ stop_decoding(JPGD_BAD_DHT_MARKER);
985
+
986
+ num_left -= i;
987
+
988
+ if ((index & 0x10) > 0x10)
989
+ stop_decoding(JPGD_BAD_DHT_INDEX);
990
+
991
+ index = (index & 0x0F) + ((index & 0x10) >> 4) * (JPGD_MAX_HUFF_TABLES >> 1);
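+ // DC tables (index bit 4 clear) occupy the first half of the table slots, AC tables the second half.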
992
+
993
+ if (index >= JPGD_MAX_HUFF_TABLES)
994
+ stop_decoding(JPGD_BAD_DHT_INDEX);
995
+
996
+ if (!m_huff_num[index])
997
+ m_huff_num[index] = (uint8 *)alloc(17);
998
+
999
+ if (!m_huff_val[index])
1000
+ m_huff_val[index] = (uint8 *)alloc(256);
1001
+
1002
+ m_huff_ac[index] = (index & 0x10) != 0;
1003
+ memcpy(m_huff_num[index], huff_num, 17);
1004
+ memcpy(m_huff_val[index], huff_val, 256);
1005
+ }
1006
+ }
1007
+
1008
+ // Read a quantization table.
1009
+ void jpeg_decoder::read_dqt_marker()
1010
+ {
1011
+ int n, i, prec;
1012
+ uint num_left;
1013
+ uint temp;
1014
+
1015
+ num_left = get_bits(16);
1016
+
1017
+ if (num_left < 2)
1018
+ stop_decoding(JPGD_BAD_DQT_MARKER);
1019
+
1020
+ num_left -= 2;
1021
+
1022
+ while (num_left)
1023
+ {
1024
+ n = get_bits(8);
1025
+ prec = n >> 4;
1026
+ n &= 0x0F;
1027
+
1028
+ if (n >= JPGD_MAX_QUANT_TABLES)
1029
+ stop_decoding(JPGD_BAD_DQT_TABLE);
1030
+
1031
+ if (!m_quant[n])
1032
+ m_quant[n] = (jpgd_quant_t *)alloc(64 * sizeof(jpgd_quant_t));
1033
+
1034
+ // read quantization entries, in zag order
1035
+ for (i = 0; i < 64; i++)
1036
+ {
1037
+ temp = get_bits(8);
1038
+
1039
+ if (prec)
1040
+ temp = (temp << 8) + get_bits(8);
1041
+
1042
+ m_quant[n][i] = static_cast<jpgd_quant_t>(temp);
1043
+ }
1044
+
1045
+ i = 64 + 1;
1046
+
1047
+ if (prec)
1048
+ i += 64;
1049
+
1050
+ if (num_left < (uint)i)
1051
+ stop_decoding(JPGD_BAD_DQT_LENGTH);
1052
+
1053
+ num_left -= i;
1054
+ }
1055
+ }
1056
+
1057
+ // Read the start of frame (SOF) marker.
1058
+ void jpeg_decoder::read_sof_marker()
1059
+ {
1060
+ int i;
1061
+ uint num_left;
1062
+
1063
+ num_left = get_bits(16);
1064
+
1065
+ if (get_bits(8) != 8) /* precision: sorry, only 8-bit precision is supported right now */
1066
+ stop_decoding(JPGD_BAD_PRECISION);
1067
+
1068
+ m_image_y_size = get_bits(16);
1069
+
1070
+ if ((m_image_y_size < 1) || (m_image_y_size > JPGD_MAX_HEIGHT))
1071
+ stop_decoding(JPGD_BAD_HEIGHT);
1072
+
1073
+ m_image_x_size = get_bits(16);
1074
+
1075
+ if ((m_image_x_size < 1) || (m_image_x_size > JPGD_MAX_WIDTH))
1076
+ stop_decoding(JPGD_BAD_WIDTH);
1077
+
1078
+ m_comps_in_frame = get_bits(8);
1079
+
1080
+ if (m_comps_in_frame > JPGD_MAX_COMPONENTS)
1081
+ stop_decoding(JPGD_TOO_MANY_COMPONENTS);
1082
+
1083
+ if (num_left != (uint)(m_comps_in_frame * 3 + 8))
1084
+ stop_decoding(JPGD_BAD_SOF_LENGTH);
1085
+
1086
+ for (i = 0; i < m_comps_in_frame; i++)
1087
+ {
1088
+ m_comp_ident[i] = get_bits(8);
1089
+ m_comp_h_samp[i] = get_bits(4);
1090
+ m_comp_v_samp[i] = get_bits(4);
1091
+ m_comp_quant[i] = get_bits(8);
1092
+ }
1093
+ }
1094
+
1095
+ // Used to skip unrecognized markers.
1096
+ void jpeg_decoder::skip_variable_marker()
1097
+ {
1098
+ uint num_left;
1099
+
1100
+ num_left = get_bits(16);
1101
+
1102
+ if (num_left < 2)
1103
+ stop_decoding(JPGD_BAD_VARIABLE_MARKER);
1104
+
1105
+ num_left -= 2;
1106
+
1107
+ while (num_left)
1108
+ {
1109
+ get_bits(8);
1110
+ num_left--;
1111
+ }
1112
+ }
1113
+
1114
+ // Read a define restart interval (DRI) marker.
1115
+ void jpeg_decoder::read_dri_marker()
1116
+ {
1117
+ if (get_bits(16) != 4)
1118
+ stop_decoding(JPGD_BAD_DRI_LENGTH);
1119
+
1120
+ m_restart_interval = get_bits(16);
1121
+ }
1122
+
1123
+ // Read a start of scan (SOS) marker.
1124
+ void jpeg_decoder::read_sos_marker()
1125
+ {
1126
+ uint num_left;
1127
+ int i, ci, n, c, cc;
1128
+
1129
+ num_left = get_bits(16);
1130
+
1131
+ n = get_bits(8);
1132
+
1133
+ m_comps_in_scan = n;
1134
+
1135
+ num_left -= 3;
1136
+
1137
+ if ( (num_left != (uint)(n * 2 + 3)) || (n < 1) || (n > JPGD_MAX_COMPS_IN_SCAN) )
1138
+ stop_decoding(JPGD_BAD_SOS_LENGTH);
1139
+
1140
+ for (i = 0; i < n; i++)
1141
+ {
1142
+ cc = get_bits(8);
1143
+ c = get_bits(8);
1144
+ num_left -= 2;
1145
+
1146
+ for (ci = 0; ci < m_comps_in_frame; ci++)
1147
+ if (cc == m_comp_ident[ci])
1148
+ break;
1149
+
1150
+ if (ci >= m_comps_in_frame)
1151
+ stop_decoding(JPGD_BAD_SOS_COMP_ID);
1152
+
1153
+ m_comp_list[i] = ci;
1154
+ m_comp_dc_tab[ci] = (c >> 4) & 15;
1155
+ m_comp_ac_tab[ci] = (c & 15) + (JPGD_MAX_HUFF_TABLES >> 1);
1156
+ }
1157
+
1158
+ m_spectral_start = get_bits(8);
1159
+ m_spectral_end = get_bits(8);
1160
+ m_successive_high = get_bits(4);
1161
+ m_successive_low = get_bits(4);
1162
+
1163
+ if (!m_progressive_flag)
1164
+ {
1165
+ m_spectral_start = 0;
1166
+ m_spectral_end = 63;
1167
+ }
1168
+
1169
+ num_left -= 3;
1170
+
1171
+ while (num_left) /* read past whatever is num_left */
1172
+ {
1173
+ get_bits(8);
1174
+ num_left--;
1175
+ }
1176
+ }
1177
+
1178
+ // Finds the next marker.
1179
+ int jpeg_decoder::next_marker()
1180
+ {
1181
+ uint c, bytes;
1182
+
1183
+ bytes = 0;
1184
+
1185
+ do
1186
+ {
1187
+ do
1188
+ {
1189
+ bytes++;
1190
+ c = get_bits(8);
1191
+ } while (c != 0xFF);
1192
+
1193
+ do
1194
+ {
1195
+ c = get_bits(8);
1196
+ } while (c == 0xFF);
1197
+
1198
+ } while (c == 0);
1199
+
1200
+ // If bytes > 0 here, there were extra bytes before the marker (not good).
1201
+
1202
+ return c;
1203
+ }
1204
+
1205
+ // Process markers. Returns when an SOFx, SOI, EOI, or SOS marker is
1206
+ // encountered.
1207
+ int jpeg_decoder::process_markers()
1208
+ {
1209
+ int c;
1210
+
1211
+ for ( ; ; )
1212
+ {
1213
+ c = next_marker();
1214
+
1215
+ switch (c)
1216
+ {
1217
+ case M_SOF0:
1218
+ case M_SOF1:
1219
+ case M_SOF2:
1220
+ case M_SOF3:
1221
+ case M_SOF5:
1222
+ case M_SOF6:
1223
+ case M_SOF7:
1224
+ // case M_JPG:
1225
+ case M_SOF9:
1226
+ case M_SOF10:
1227
+ case M_SOF11:
1228
+ case M_SOF13:
1229
+ case M_SOF14:
1230
+ case M_SOF15:
1231
+ case M_SOI:
1232
+ case M_EOI:
1233
+ case M_SOS:
1234
+ {
1235
+ return c;
1236
+ }
1237
+ case M_DHT:
1238
+ {
1239
+ read_dht_marker();
1240
+ break;
1241
+ }
1242
+ // No arithmetic support - dumb patents!
1243
+ case M_DAC:
1244
+ {
1245
+ stop_decoding(JPGD_NO_ARITHMITIC_SUPPORT);
1246
+ break;
1247
+ }
1248
+ case M_DQT:
1249
+ {
1250
+ read_dqt_marker();
1251
+ break;
1252
+ }
1253
+ case M_DRI:
1254
+ {
1255
+ read_dri_marker();
1256
+ break;
1257
+ }
1258
+ //case M_APP0: /* no need to read the JFIF marker */
1259
+
1260
+ case M_JPG:
1261
+ case M_RST0: /* no parameters */
1262
+ case M_RST1:
1263
+ case M_RST2:
1264
+ case M_RST3:
1265
+ case M_RST4:
1266
+ case M_RST5:
1267
+ case M_RST6:
1268
+ case M_RST7:
1269
+ case M_TEM:
1270
+ {
1271
+ stop_decoding(JPGD_UNEXPECTED_MARKER);
1272
+ break;
1273
+ }
1274
+ default: /* must be DNL, DHP, EXP, APPn, JPGn, COM, RESn, or APP0 */
1275
+ {
1276
+ skip_variable_marker();
1277
+ break;
1278
+ }
1279
+ }
1280
+ }
1281
+ }
1282
+
1283
+ // Finds the start of image (SOI) marker.
1284
+ // This code is rather defensive: it only checks the first 4096 bytes to avoid
1285
+ // false positives.
1286
+ void jpeg_decoder::locate_soi_marker()
1287
+ {
1288
+ uint lastchar, thischar;
1289
+ uint bytesleft;
1290
+
1291
+ lastchar = get_bits(8);
1292
+
1293
+ thischar = get_bits(8);
1294
+
1295
+ /* ok if it's a normal JPEG file without a special header */
1296
+
1297
+ if ((lastchar == 0xFF) && (thischar == M_SOI))
1298
+ return;
1299
+
1300
+ bytesleft = 4096; //512;
1301
+
1302
+ for ( ; ; )
1303
+ {
1304
+ if (--bytesleft == 0)
1305
+ stop_decoding(JPGD_NOT_JPEG);
1306
+
1307
+ lastchar = thischar;
1308
+
1309
+ thischar = get_bits(8);
1310
+
1311
+ if (lastchar == 0xFF)
1312
+ {
1313
+ if (thischar == M_SOI)
1314
+ break;
1315
+ else if (thischar == M_EOI) // get_bits will keep returning M_EOI if we read past the end
1316
+ stop_decoding(JPGD_NOT_JPEG);
1317
+ }
1318
+ }
1319
+
1320
+ // Check the next character after marker: if it's not 0xFF, it can't be the start of the next marker, so the file is bad.
1321
+ thischar = (m_bit_buf >> 24) & 0xFF;
1322
+
1323
+ if (thischar != 0xFF)
1324
+ stop_decoding(JPGD_NOT_JPEG);
1325
+ }
1326
+
1327
+ // Find a start of frame (SOF) marker.
1328
+ void jpeg_decoder::locate_sof_marker()
1329
+ {
1330
+ locate_soi_marker();
1331
+
1332
+ int c = process_markers();
1333
+
1334
+ switch (c)
1335
+ {
1336
+ case M_SOF2:
1337
+ m_progressive_flag = JPGD_TRUE;
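+ // Note: intentional fall-through - progressive frames share the SOF parsing below.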
1338
+ case M_SOF0: /* baseline DCT */
1339
+ case M_SOF1: /* extended sequential DCT */
1340
+ {
1341
+ read_sof_marker();
1342
+ break;
1343
+ }
1344
+ case M_SOF9: /* Arithmetic coding */
1345
+ {
1346
+ stop_decoding(JPGD_NO_ARITHMITIC_SUPPORT);
1347
+ break;
1348
+ }
1349
+ default:
1350
+ {
1351
+ stop_decoding(JPGD_UNSUPPORTED_MARKER);
1352
+ break;
1353
+ }
1354
+ }
1355
+ }
1356
+
1357
+ // Find a start of scan (SOS) marker.
1358
+ int jpeg_decoder::locate_sos_marker()
1359
+ {
1360
+ int c;
1361
+
1362
+ c = process_markers();
1363
+
1364
+ if (c == M_EOI)
1365
+ return JPGD_FALSE;
1366
+ else if (c != M_SOS)
1367
+ stop_decoding(JPGD_UNEXPECTED_MARKER);
1368
+
1369
+ read_sos_marker();
1370
+
1371
+ return JPGD_TRUE;
1372
+ }
1373
+
1374
+ // Reset everything to default/uninitialized state.
1375
+ void jpeg_decoder::init(jpeg_decoder_stream *pStream)
1376
+ {
1377
+ m_pMem_blocks = NULL;
1378
+ m_error_code = JPGD_SUCCESS;
1379
+ m_ready_flag = false;
1380
+ m_image_x_size = m_image_y_size = 0;
1381
+ m_pStream = pStream;
1382
+ m_progressive_flag = JPGD_FALSE;
1383
+
1384
+ memset(m_huff_ac, 0, sizeof(m_huff_ac));
1385
+ memset(m_huff_num, 0, sizeof(m_huff_num));
1386
+ memset(m_huff_val, 0, sizeof(m_huff_val));
1387
+ memset(m_quant, 0, sizeof(m_quant));
1388
+
1389
+ m_scan_type = 0;
1390
+ m_comps_in_frame = 0;
1391
+
1392
+ memset(m_comp_h_samp, 0, sizeof(m_comp_h_samp));
1393
+ memset(m_comp_v_samp, 0, sizeof(m_comp_v_samp));
1394
+ memset(m_comp_quant, 0, sizeof(m_comp_quant));
1395
+ memset(m_comp_ident, 0, sizeof(m_comp_ident));
1396
+ memset(m_comp_h_blocks, 0, sizeof(m_comp_h_blocks));
1397
+ memset(m_comp_v_blocks, 0, sizeof(m_comp_v_blocks));
1398
+
1399
+ m_comps_in_scan = 0;
1400
+ memset(m_comp_list, 0, sizeof(m_comp_list));
1401
+ memset(m_comp_dc_tab, 0, sizeof(m_comp_dc_tab));
1402
+ memset(m_comp_ac_tab, 0, sizeof(m_comp_ac_tab));
1403
+
1404
+ m_spectral_start = 0;
1405
+ m_spectral_end = 0;
1406
+ m_successive_low = 0;
1407
+ m_successive_high = 0;
1408
+ m_max_mcu_x_size = 0;
1409
+ m_max_mcu_y_size = 0;
1410
+ m_blocks_per_mcu = 0;
1411
+ m_max_blocks_per_row = 0;
1412
+ m_mcus_per_row = 0;
1413
+ m_mcus_per_col = 0;
1414
+ m_expanded_blocks_per_component = 0;
1415
+ m_expanded_blocks_per_mcu = 0;
1416
+ m_expanded_blocks_per_row = 0;
1417
+ m_freq_domain_chroma_upsample = false;
1418
+
1419
+ memset(m_mcu_org, 0, sizeof(m_mcu_org));
1420
+
1421
+ m_total_lines_left = 0;
1422
+ m_mcu_lines_left = 0;
1423
+ m_real_dest_bytes_per_scan_line = 0;
1424
+ m_dest_bytes_per_scan_line = 0;
1425
+ m_dest_bytes_per_pixel = 0;
1426
+
1427
+ memset(m_pHuff_tabs, 0, sizeof(m_pHuff_tabs));
1428
+
1429
+ memset(m_dc_coeffs, 0, sizeof(m_dc_coeffs));
1430
+ memset(m_ac_coeffs, 0, sizeof(m_ac_coeffs));
1431
+ memset(m_block_y_mcu, 0, sizeof(m_block_y_mcu));
1432
+
1433
+ m_eob_run = 0;
1434
+
1435
+ memset(m_block_y_mcu, 0, sizeof(m_block_y_mcu));
1436
+
1437
+ m_pIn_buf_ofs = m_in_buf;
1438
+ m_in_buf_left = 0;
1439
+ m_eof_flag = false;
1440
+ m_tem_flag = 0;
1441
+
1442
+ memset(m_in_buf_pad_start, 0, sizeof(m_in_buf_pad_start));
1443
+ memset(m_in_buf, 0, sizeof(m_in_buf));
1444
+ memset(m_in_buf_pad_end, 0, sizeof(m_in_buf_pad_end));
1445
+
1446
+ m_restart_interval = 0;
1447
+ m_restarts_left = 0;
1448
+ m_next_restart_num = 0;
1449
+
1450
+ m_max_mcus_per_row = 0;
1451
+ m_max_blocks_per_mcu = 0;
1452
+ m_max_mcus_per_col = 0;
1453
+
1454
+ memset(m_last_dc_val, 0, sizeof(m_last_dc_val));
1455
+ m_pMCU_coefficients = NULL;
1456
+ m_pSample_buf = NULL;
1457
+
1458
+ m_total_bytes_read = 0;
1459
+
1460
+ m_pScan_line_0 = NULL;
1461
+ m_pScan_line_1 = NULL;
1462
+
1463
+ // Ready the input buffer.
1464
+ prep_in_buffer();
1465
+
1466
+ // Prime the bit buffer.
1467
+ m_bits_left = 16;
1468
+ m_bit_buf = 0;
1469
+
1470
+ get_bits(16);
1471
+ get_bits(16);
1472
+
1473
+ for (int i = 0; i < JPGD_MAX_BLOCKS_PER_MCU; i++)
1474
+ m_mcu_block_max_zag[i] = 64;
1475
+ }
1476
+
1477
+ #define SCALEBITS 16
1478
+ #define ONE_HALF ((int) 1 << (SCALEBITS-1))
1479
+ #define FIX(x) ((int) ((x) * (1L<<SCALEBITS) + 0.5f))
1480
+
1481
+ // Create a few tables that allow us to quickly convert YCbCr to RGB.
1482
+ void jpeg_decoder::create_look_ups()
1483
+ {
1484
+ for (int i = 0; i <= 255; i++)
1485
+ {
1486
+ int k = i - 128;
1487
+ m_crr[i] = ( FIX(1.40200f) * k + ONE_HALF) >> SCALEBITS;
1488
+ m_cbb[i] = ( FIX(1.77200f) * k + ONE_HALF) >> SCALEBITS;
1489
+ m_crg[i] = (-FIX(0.71414f)) * k;
1490
+ m_cbg[i] = (-FIX(0.34414f)) * k + ONE_HALF;
1491
+ }
1492
+ }
1493
+
1494
+ // This method throws back into the stream any bytes that were read
1495
+ // into the bit buffer during initial marker scanning.
1496
+ void jpeg_decoder::fix_in_buffer()
1497
+ {
1498
+ // In case any 0xFF's were pulled into the buffer during marker scanning.
1499
+ JPGD_ASSERT((m_bits_left & 7) == 0);
1500
+
1501
+ if (m_bits_left == 16)
1502
+ stuff_char( (uint8)(m_bit_buf & 0xFF));
1503
+
1504
+ if (m_bits_left >= 8)
1505
+ stuff_char( (uint8)((m_bit_buf >> 8) & 0xFF));
1506
+
1507
+ stuff_char((uint8)((m_bit_buf >> 16) & 0xFF));
1508
+ stuff_char((uint8)((m_bit_buf >> 24) & 0xFF));
1509
+
1510
+ m_bits_left = 16;
1511
+ get_bits_no_markers(16);
1512
+ get_bits_no_markers(16);
1513
+ }
1514
+
1515
+ void jpeg_decoder::transform_mcu(int mcu_row)
1516
+ {
1517
+ jpgd_block_t* pSrc_ptr = m_pMCU_coefficients;
1518
+ uint8* pDst_ptr = m_pSample_buf + mcu_row * m_blocks_per_mcu * 64;
1519
+
1520
+ for (int mcu_block = 0; mcu_block < m_blocks_per_mcu; mcu_block++)
1521
+ {
1522
+ idct(pSrc_ptr, pDst_ptr, m_mcu_block_max_zag[mcu_block]);
1523
+ pSrc_ptr += 64;
1524
+ pDst_ptr += 64;
1525
+ }
1526
+ }
1527
+
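+ // Maps a block's highest non-zero zig-zag index to a packed bound (rows * 16 + cols) on its non-zero
+ // coefficient region; transform_mcu_expand() switches on this value to pick a P_Q/R_S instantiation.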
1528
+ static const uint8 s_max_rc[64] =
1529
+ {
1530
+ 17, 18, 34, 50, 50, 51, 52, 52, 52, 68, 84, 84, 84, 84, 85, 86, 86, 86, 86, 86,
1531
+ 102, 118, 118, 118, 118, 118, 118, 119, 120, 120, 120, 120, 120, 120, 120, 136,
1532
+ 136, 136, 136, 136, 136, 136, 136, 136, 136, 136, 136, 136, 136, 136, 136, 136,
1533
+ 136, 136, 136, 136, 136, 136, 136, 136, 136, 136, 136, 136
1534
+ };
1535
+
1536
+ void jpeg_decoder::transform_mcu_expand(int mcu_row)
1537
+ {
1538
+ jpgd_block_t* pSrc_ptr = m_pMCU_coefficients;
1539
+ uint8* pDst_ptr = m_pSample_buf + mcu_row * m_expanded_blocks_per_mcu * 64;
1540
+
1541
+ // Y IDCT
1542
+ int mcu_block;
1543
+ for (mcu_block = 0; mcu_block < m_expanded_blocks_per_component; mcu_block++)
1544
+ {
1545
+ idct(pSrc_ptr, pDst_ptr, m_mcu_block_max_zag[mcu_block]);
1546
+ pSrc_ptr += 64;
1547
+ pDst_ptr += 64;
1548
+ }
1549
+
1550
+ // Chroma IDCT, with upsampling
1551
+ jpgd_block_t temp_block[64];
1552
+
1553
+ for (int i = 0; i < 2; i++)
1554
+ {
1555
+ DCT_Upsample::Matrix44 P, Q, R, S;
1556
+
1557
+ JPGD_ASSERT(m_mcu_block_max_zag[mcu_block] >= 1);
1558
+ JPGD_ASSERT(m_mcu_block_max_zag[mcu_block] <= 64);
1559
+
1560
+ switch (s_max_rc[m_mcu_block_max_zag[mcu_block++] - 1])
1561
+ {
1562
+ case 1*16+1:
1563
+ DCT_Upsample::P_Q<1, 1>::calc(P, Q, pSrc_ptr);
1564
+ DCT_Upsample::R_S<1, 1>::calc(R, S, pSrc_ptr);
1565
+ break;
1566
+ case 1*16+2:
1567
+ DCT_Upsample::P_Q<1, 2>::calc(P, Q, pSrc_ptr);
1568
+ DCT_Upsample::R_S<1, 2>::calc(R, S, pSrc_ptr);
1569
+ break;
1570
+ case 2*16+2:
1571
+ DCT_Upsample::P_Q<2, 2>::calc(P, Q, pSrc_ptr);
1572
+ DCT_Upsample::R_S<2, 2>::calc(R, S, pSrc_ptr);
1573
+ break;
1574
+ case 3*16+2:
1575
+ DCT_Upsample::P_Q<3, 2>::calc(P, Q, pSrc_ptr);
1576
+ DCT_Upsample::R_S<3, 2>::calc(R, S, pSrc_ptr);
1577
+ break;
1578
+ case 3*16+3:
1579
+ DCT_Upsample::P_Q<3, 3>::calc(P, Q, pSrc_ptr);
1580
+ DCT_Upsample::R_S<3, 3>::calc(R, S, pSrc_ptr);
1581
+ break;
1582
+ case 3*16+4:
1583
+ DCT_Upsample::P_Q<3, 4>::calc(P, Q, pSrc_ptr);
1584
+ DCT_Upsample::R_S<3, 4>::calc(R, S, pSrc_ptr);
1585
+ break;
1586
+ case 4*16+4:
1587
+ DCT_Upsample::P_Q<4, 4>::calc(P, Q, pSrc_ptr);
1588
+ DCT_Upsample::R_S<4, 4>::calc(R, S, pSrc_ptr);
1589
+ break;
1590
+ case 5*16+4:
1591
+ DCT_Upsample::P_Q<5, 4>::calc(P, Q, pSrc_ptr);
1592
+ DCT_Upsample::R_S<5, 4>::calc(R, S, pSrc_ptr);
1593
+ break;
1594
+ case 5*16+5:
1595
+ DCT_Upsample::P_Q<5, 5>::calc(P, Q, pSrc_ptr);
1596
+ DCT_Upsample::R_S<5, 5>::calc(R, S, pSrc_ptr);
1597
+ break;
1598
+ case 5*16+6:
1599
+ DCT_Upsample::P_Q<5, 6>::calc(P, Q, pSrc_ptr);
1600
+ DCT_Upsample::R_S<5, 6>::calc(R, S, pSrc_ptr);
1601
+ break;
1602
+ case 6*16+6:
1603
+ DCT_Upsample::P_Q<6, 6>::calc(P, Q, pSrc_ptr);
1604
+ DCT_Upsample::R_S<6, 6>::calc(R, S, pSrc_ptr);
1605
+ break;
1606
+ case 7*16+6:
1607
+ DCT_Upsample::P_Q<7, 6>::calc(P, Q, pSrc_ptr);
1608
+ DCT_Upsample::R_S<7, 6>::calc(R, S, pSrc_ptr);
1609
+ break;
1610
+ case 7*16+7:
1611
+ DCT_Upsample::P_Q<7, 7>::calc(P, Q, pSrc_ptr);
1612
+ DCT_Upsample::R_S<7, 7>::calc(R, S, pSrc_ptr);
1613
+ break;
1614
+ case 7*16+8:
1615
+ DCT_Upsample::P_Q<7, 8>::calc(P, Q, pSrc_ptr);
1616
+ DCT_Upsample::R_S<7, 8>::calc(R, S, pSrc_ptr);
1617
+ break;
1618
+ case 8*16+8:
1619
+ DCT_Upsample::P_Q<8, 8>::calc(P, Q, pSrc_ptr);
1620
+ DCT_Upsample::R_S<8, 8>::calc(R, S, pSrc_ptr);
1621
+ break;
1622
+ default:
1623
+ JPGD_ASSERT(false);
1624
+ }
1625
+
1626
+ DCT_Upsample::Matrix44 a(P + Q); P -= Q;
1627
+ DCT_Upsample::Matrix44& b = P;
1628
+ DCT_Upsample::Matrix44 c(R + S); R -= S;
1629
+ DCT_Upsample::Matrix44& d = R;
1630
+
1631
+ DCT_Upsample::Matrix44::add_and_store(temp_block, a, c);
1632
+ idct_4x4(temp_block, pDst_ptr);
1633
+ pDst_ptr += 64;
1634
+
1635
+ DCT_Upsample::Matrix44::sub_and_store(temp_block, a, c);
1636
+ idct_4x4(temp_block, pDst_ptr);
1637
+ pDst_ptr += 64;
1638
+
1639
+ DCT_Upsample::Matrix44::add_and_store(temp_block, b, d);
1640
+ idct_4x4(temp_block, pDst_ptr);
1641
+ pDst_ptr += 64;
1642
+
1643
+ DCT_Upsample::Matrix44::sub_and_store(temp_block, b, d);
1644
+ idct_4x4(temp_block, pDst_ptr);
1645
+ pDst_ptr += 64;
1646
+
1647
+ pSrc_ptr += 64;
1648
+ }
1649
+ }
1650
+
1651
+ // Loads and dequantizes the next row of (already decoded) coefficients.
1652
+ // Progressive images only.
1653
+ void jpeg_decoder::load_next_row()
1654
+ {
1655
+ int i;
1656
+ jpgd_block_t *p;
1657
+ jpgd_quant_t *q;
1658
+ int mcu_row, mcu_block, row_block = 0;
1659
+ int component_num, component_id;
1660
+ int block_x_mcu[JPGD_MAX_COMPONENTS];
1661
+
1662
+ memset(block_x_mcu, 0, JPGD_MAX_COMPONENTS * sizeof(int));
1663
+
1664
+ for (mcu_row = 0; mcu_row < m_mcus_per_row; mcu_row++)
1665
+ {
1666
+ int block_x_mcu_ofs = 0, block_y_mcu_ofs = 0;
1667
+
1668
+ for (mcu_block = 0; mcu_block < m_blocks_per_mcu; mcu_block++)
1669
+ {
1670
+ component_id = m_mcu_org[mcu_block];
1671
+ q = m_quant[m_comp_quant[component_id]];
1672
+
1673
+ p = m_pMCU_coefficients + 64 * mcu_block;
1674
+
1675
+ jpgd_block_t* pAC = coeff_buf_getp(m_ac_coeffs[component_id], block_x_mcu[component_id] + block_x_mcu_ofs, m_block_y_mcu[component_id] + block_y_mcu_ofs);
1676
+ jpgd_block_t* pDC = coeff_buf_getp(m_dc_coeffs[component_id], block_x_mcu[component_id] + block_x_mcu_ofs, m_block_y_mcu[component_id] + block_y_mcu_ofs);
1677
+ p[0] = pDC[0];
1678
+ memcpy(&p[1], &pAC[1], 63 * sizeof(jpgd_block_t));
1679
+
1680
+ for (i = 63; i > 0; i--)
1681
+ if (p[g_ZAG[i]])
1682
+ break;
1683
+
1684
+ m_mcu_block_max_zag[mcu_block] = i + 1;
1685
+
1686
+ for ( ; i >= 0; i--)
1687
+ if (p[g_ZAG[i]])
1688
+ p[g_ZAG[i]] = static_cast<jpgd_block_t>(p[g_ZAG[i]] * q[i]);
1689
+
1690
+ row_block++;
1691
+
1692
+ if (m_comps_in_scan == 1)
1693
+ block_x_mcu[component_id]++;
1694
+ else
1695
+ {
1696
+ if (++block_x_mcu_ofs == m_comp_h_samp[component_id])
1697
+ {
1698
+ block_x_mcu_ofs = 0;
1699
+
1700
+ if (++block_y_mcu_ofs == m_comp_v_samp[component_id])
1701
+ {
1702
+ block_y_mcu_ofs = 0;
1703
+
1704
+ block_x_mcu[component_id] += m_comp_h_samp[component_id];
1705
+ }
1706
+ }
1707
+ }
1708
+ }
1709
+
1710
+ if (m_freq_domain_chroma_upsample)
1711
+ transform_mcu_expand(mcu_row);
1712
+ else
1713
+ transform_mcu(mcu_row);
1714
+ }
1715
+
1716
+ if (m_comps_in_scan == 1)
1717
+ m_block_y_mcu[m_comp_list[0]]++;
1718
+ else
1719
+ {
1720
+ for (component_num = 0; component_num < m_comps_in_scan; component_num++)
1721
+ {
1722
+ component_id = m_comp_list[component_num];
1723
+
1724
+ m_block_y_mcu[component_id] += m_comp_v_samp[component_id];
1725
+ }
1726
+ }
1727
+ }
1728
+
1729
+ // Restart interval processing.
1730
+ void jpeg_decoder::process_restart()
1731
+ {
1732
+ int i;
1733
+ int c = 0;
1734
+
1735
+ // Align to a byte boundary
1736
+ // FIXME: Is this really necessary? get_bits_no_markers() never reads in markers!
1737
+ //get_bits_no_markers(m_bits_left & 7);
1738
+
1739
+ // Let's scan a little bit to find the marker, but not _too_ far.
1740
+ // 1536 is a "fudge factor" that determines how much to scan.
1741
+ for (i = 1536; i > 0; i--)
1742
+ if (get_char() == 0xFF)
1743
+ break;
1744
+
1745
+ if (i == 0)
1746
+ stop_decoding(JPGD_BAD_RESTART_MARKER);
1747
+
1748
+ for ( ; i > 0; i--)
1749
+ if ((c = get_char()) != 0xFF)
1750
+ break;
1751
+
1752
+ if (i == 0)
1753
+ stop_decoding(JPGD_BAD_RESTART_MARKER);
1754
+
1755
+ // Is it the expected marker? If not, something bad happened.
1756
+ if (c != (m_next_restart_num + M_RST0))
1757
+ stop_decoding(JPGD_BAD_RESTART_MARKER);
1758
+
1759
+ // Reset each component's DC prediction values.
1760
+ memset(&m_last_dc_val, 0, m_comps_in_frame * sizeof(uint));
1761
+
1762
+ m_eob_run = 0;
1763
+
1764
+ m_restarts_left = m_restart_interval;
1765
+
1766
+ m_next_restart_num = (m_next_restart_num + 1) & 7;
1767
+
1768
+ // Get the bit buffer going again...
1769
+
1770
+ m_bits_left = 16;
1771
+ get_bits_no_markers(16);
1772
+ get_bits_no_markers(16);
1773
+ }
1774
+
1775
+ static inline int dequantize_ac(int c, int q) { c *= q; return c; }
1776
+
1777
+ // Decodes and dequantizes the next row of coefficients.
1778
+ void jpeg_decoder::decode_next_row()
1779
+ {
1780
+ int row_block = 0;
1781
+
1782
+ for (int mcu_row = 0; mcu_row < m_mcus_per_row; mcu_row++)
1783
+ {
1784
+ if ((m_restart_interval) && (m_restarts_left == 0))
1785
+ process_restart();
1786
+
1787
+ jpgd_block_t* p = m_pMCU_coefficients;
1788
+ for (int mcu_block = 0; mcu_block < m_blocks_per_mcu; mcu_block++, p += 64)
1789
+ {
1790
+ int component_id = m_mcu_org[mcu_block];
1791
+ jpgd_quant_t* q = m_quant[m_comp_quant[component_id]];
1792
+
1793
+ int r, s;
1794
+ s = huff_decode(m_pHuff_tabs[m_comp_dc_tab[component_id]], r);
1795
+ s = HUFF_EXTEND(r, s);
1796
+
1797
+ m_last_dc_val[component_id] = (s += m_last_dc_val[component_id]);
1798
+
1799
+ p[0] = static_cast<jpgd_block_t>(s * q[0]);
1800
+
1801
+ int prev_num_set = m_mcu_block_max_zag[mcu_block];
1802
+
1803
+ huff_tables *pH = m_pHuff_tabs[m_comp_ac_tab[component_id]];
1804
+
1805
+ int k;
1806
+ for (k = 1; k < 64; k++)
1807
+ {
1808
+ int extra_bits;
1809
+ s = huff_decode(pH, extra_bits);
1810
+
1811
+ r = s >> 4;
1812
+ s &= 15;
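+ // r = run of zero coefficients preceding the next value, s = number of magnitude bits for that value.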
1813
+
1814
+ if (s)
1815
+ {
1816
+ if (r)
1817
+ {
1818
+ if ((k + r) > 63)
1819
+ stop_decoding(JPGD_DECODE_ERROR);
1820
+
1821
+ if (k < prev_num_set)
1822
+ {
1823
+ int n = JPGD_MIN(r, prev_num_set - k);
1824
+ int kt = k;
1825
+ while (n--)
1826
+ p[g_ZAG[kt++]] = 0;
1827
+ }
1828
+
1829
+ k += r;
1830
+ }
1831
+
1832
+ s = HUFF_EXTEND(extra_bits, s);
1833
+
1834
+ JPGD_ASSERT(k < 64);
1835
+
1836
+ p[g_ZAG[k]] = static_cast<jpgd_block_t>(dequantize_ac(s, q[k])); //s * q[k];
1837
+ }
1838
+ else
1839
+ {
1840
+ if (r == 15)
1841
+ {
1842
+ if ((k + 16) > 64)
1843
+ stop_decoding(JPGD_DECODE_ERROR);
1844
+
1845
+ if (k < prev_num_set)
1846
+ {
1847
+ int n = JPGD_MIN(16, prev_num_set - k);
1848
+ int kt = k;
1849
+ while (n--)
1850
+ {
1851
+ JPGD_ASSERT(kt <= 63);
1852
+ p[g_ZAG[kt++]] = 0;
1853
+ }
1854
+ }
1855
+
1856
+ k += 16 - 1; // - 1 because the loop counter is k
1857
+ // BEGIN EPIC MOD
1858
+ JPGD_ASSERT(k < 64 && p[g_ZAG[k]] == 0);
1859
+ // END EPIC MOD
1860
+ }
1861
+ else
1862
+ break;
1863
+ }
1864
+ }
1865
+
1866
+ if (k < prev_num_set)
1867
+ {
1868
+ int kt = k;
1869
+ while (kt < prev_num_set)
1870
+ p[g_ZAG[kt++]] = 0;
1871
+ }
1872
+
1873
+ m_mcu_block_max_zag[mcu_block] = k;
1874
+
1875
+ row_block++;
1876
+ }
1877
+
1878
+ if (m_freq_domain_chroma_upsample)
1879
+ transform_mcu_expand(mcu_row);
1880
+ else
1881
+ transform_mcu(mcu_row);
1882
+
1883
+ m_restarts_left--;
1884
+ }
1885
+ }
1886
+
1887
+ // YCbCr H1V1 (1x1:1:1, 3 m_blocks per MCU) to RGB
1888
+ void jpeg_decoder::H1V1Convert()
1889
+ {
1890
+ int row = m_max_mcu_y_size - m_mcu_lines_left;
1891
+ uint8 *d = m_pScan_line_0;
1892
+ uint8 *s = m_pSample_buf + row * 8;
1893
+
1894
+ for (int i = m_max_mcus_per_row; i > 0; i--)
1895
+ {
1896
+ for (int j = 0; j < 8; j++)
1897
+ {
1898
+ int y = s[j];
1899
+ int cb = s[64+j];
1900
+ int cr = s[128+j];
1901
+
1902
+ if (jpg_format == ERGBFormatJPG::BGRA)
1903
+ {
1904
+ d[0] = clamp(y + m_cbb[cb]);
1905
+ d[1] = clamp(y + ((m_crg[cr] + m_cbg[cb]) >> 16));
1906
+ d[2] = clamp(y + m_crr[cr]);
1907
+ d[3] = 255;
1908
+ }
1909
+ else
1910
+ {
1911
+ d[0] = clamp(y + m_crr[cr]);
1912
+ d[1] = clamp(y + ((m_crg[cr] + m_cbg[cb]) >> 16));
1913
+ d[2] = clamp(y + m_cbb[cb]);
1914
+ d[3] = 255;
1915
+ }
1916
+ d += 4;
1917
+ }
1918
+
1919
+ s += 64*3;
1920
+ }
1921
+ }
1922
+
1923
+ // YCbCr H2V1 (2x1:1:1, 4 m_blocks per MCU) to RGB
1924
+ void jpeg_decoder::H2V1Convert()
1925
+ {
1926
+ int row = m_max_mcu_y_size - m_mcu_lines_left;
1927
+ uint8 *d0 = m_pScan_line_0;
1928
+ uint8 *y = m_pSample_buf + row * 8;
1929
+ uint8 *c = m_pSample_buf + 2*64 + row * 8;
1930
+
1931
+ for (int i = m_max_mcus_per_row; i > 0; i--)
1932
+ {
1933
+ for (int l = 0; l < 2; l++)
1934
+ {
1935
+ for (int j = 0; j < 4; j++)
1936
+ {
1937
+ int cb = c[0];
1938
+ int cr = c[64];
1939
+
1940
+ int rc = m_crr[cr];
1941
+ int gc = ((m_crg[cr] + m_cbg[cb]) >> 16);
1942
+ int bc = m_cbb[cb];
1943
+
1944
+ int yy = y[j<<1];
1945
+ if (jpg_format == ERGBFormatJPG::BGRA)
1946
+ {
1947
+ d0[0] = clamp(yy+bc);
1948
+ d0[1] = clamp(yy+gc);
1949
+ d0[2] = clamp(yy+rc);
1950
+ d0[3] = 255;
1951
+ yy = y[(j<<1)+1];
1952
+ d0[4] = clamp(yy+bc);
1953
+ d0[5] = clamp(yy+gc);
1954
+ d0[6] = clamp(yy+rc);
1955
+ d0[7] = 255;
1956
+ }
1957
+ else
1958
+ {
1959
+ d0[0] = clamp(yy+rc);
1960
+ d0[1] = clamp(yy+gc);
1961
+ d0[2] = clamp(yy+bc);
1962
+ d0[3] = 255;
1963
+ yy = y[(j<<1)+1];
1964
+ d0[4] = clamp(yy+rc);
1965
+ d0[5] = clamp(yy+gc);
1966
+ d0[6] = clamp(yy+bc);
1967
+ d0[7] = 255;
1968
+ }
1969
+
1970
+ d0 += 8;
1971
+
1972
+ c++;
1973
+ }
1974
+ y += 64;
1975
+ }
1976
+
1977
+ y += 64*4 - 64*2;
1978
+ c += 64*4 - 8;
1979
+ }
1980
+ }
1981
+
1982
+ // YCbCr H1V2 (1x2:1:1, 4 m_blocks per MCU) to RGB
1983
+ void jpeg_decoder::H1V2Convert()
1984
+ {
1985
+ int row = m_max_mcu_y_size - m_mcu_lines_left;
1986
+ uint8 *d0 = m_pScan_line_0;
1987
+ uint8 *d1 = m_pScan_line_1;
1988
+ uint8 *y;
1989
+ uint8 *c;
1990
+
1991
+ if (row < 8)
1992
+ y = m_pSample_buf + row * 8;
1993
+ else
1994
+ y = m_pSample_buf + 64*1 + (row & 7) * 8;
1995
+
1996
+ c = m_pSample_buf + 64*2 + (row >> 1) * 8;
1997
+
1998
+ for (int i = m_max_mcus_per_row; i > 0; i--)
1999
+ {
2000
+ for (int j = 0; j < 8; j++)
2001
+ {
2002
+ int cb = c[0+j];
2003
+ int cr = c[64+j];
2004
+
2005
+ int rc = m_crr[cr];
2006
+ int gc = ((m_crg[cr] + m_cbg[cb]) >> 16);
2007
+ int bc = m_cbb[cb];
2008
+
2009
+ int yy = y[j];
2010
+ if (jpg_format == ERGBFormatJPG::BGRA)
2011
+ {
2012
+ d0[0] = clamp(yy+bc);
2013
+ d0[1] = clamp(yy+gc);
2014
+ d0[2] = clamp(yy+rc);
2015
+ d0[3] = 255;
2016
+ yy = y[8+j];
2017
+ d1[0] = clamp(yy+bc);
2018
+ d1[1] = clamp(yy+gc);
2019
+ d1[2] = clamp(yy+rc);
2020
+ d1[3] = 255;
2021
+ }
2022
+ else
2023
+ {
2024
+ d0[0] = clamp(yy+rc);
2025
+ d0[1] = clamp(yy+gc);
2026
+ d0[2] = clamp(yy+bc);
2027
+ d0[3] = 255;
2028
+ yy = y[8+j];
2029
+ d1[0] = clamp(yy+rc);
2030
+ d1[1] = clamp(yy+gc);
2031
+ d1[2] = clamp(yy+bc);
2032
+ d1[3] = 255;
2033
+ }
2034
+
2035
+ d0 += 4;
2036
+ d1 += 4;
2037
+ }
2038
+
2039
+ y += 64*4;
2040
+ c += 64*4;
2041
+ }
2042
+ }
2043
+
2044
+ // YCbCr H2V2 (2x2:1:1, 6 m_blocks per MCU) to RGB
2045
+ void jpeg_decoder::H2V2Convert()
2046
+ {
2047
+ int row = m_max_mcu_y_size - m_mcu_lines_left;
2048
+ uint8 *d0 = m_pScan_line_0;
2049
+ uint8 *d1 = m_pScan_line_1;
2050
+ uint8 *y;
2051
+ uint8 *c;
2052
+
2053
+ if (row < 8)
2054
+ y = m_pSample_buf + row * 8;
2055
+ else
2056
+ y = m_pSample_buf + 64*2 + (row & 7) * 8;
2057
+
2058
+ c = m_pSample_buf + 64*4 + (row >> 1) * 8;
2059
+
2060
+ for (int i = m_max_mcus_per_row; i > 0; i--)
2061
+ {
2062
+ for (int l = 0; l < 2; l++)
2063
+ {
2064
+ for (int j = 0; j < 8; j += 2)
2065
+ {
2066
+ int cb = c[0];
2067
+ int cr = c[64];
2068
+
2069
+ int rc = m_crr[cr];
2070
+ int gc = ((m_crg[cr] + m_cbg[cb]) >> 16);
2071
+ int bc = m_cbb[cb];
2072
+
2073
+ int yy = y[j];
2074
+ if (jpg_format == ERGBFormatJPG::BGRA)
2075
+ {
2076
+ d0[0] = clamp(yy+bc);
2077
+ d0[1] = clamp(yy+gc);
2078
+ d0[2] = clamp(yy+rc);
2079
+ d0[3] = 255;
2080
+ yy = y[j+1];
2081
+ d0[4] = clamp(yy+bc);
2082
+ d0[5] = clamp(yy+gc);
2083
+ d0[6] = clamp(yy+rc);
2084
+ d0[7] = 255;
2085
+ yy = y[j+8];
2086
+ d1[0] = clamp(yy+bc);
2087
+ d1[1] = clamp(yy+gc);
2088
+ d1[2] = clamp(yy+rc);
2089
+ d1[3] = 255;
2090
+ yy = y[j+8+1];
2091
+ d1[4] = clamp(yy+bc);
2092
+ d1[5] = clamp(yy+gc);
2093
+ d1[6] = clamp(yy+rc);
2094
+ d1[7] = 255;
2095
+ }
2096
+ else
2097
+ {
2098
+ d0[0] = clamp(yy+rc);
2099
+ d0[1] = clamp(yy+gc);
2100
+ d0[2] = clamp(yy+bc);
2101
+ d0[3] = 255;
2102
+ yy = y[j+1];
2103
+ d0[4] = clamp(yy+rc);
2104
+ d0[5] = clamp(yy+gc);
2105
+ d0[6] = clamp(yy+bc);
2106
+ d0[7] = 255;
2107
+ yy = y[j+8];
2108
+ d1[0] = clamp(yy+rc);
2109
+ d1[1] = clamp(yy+gc);
2110
+ d1[2] = clamp(yy+bc);
2111
+ d1[3] = 255;
2112
+ yy = y[j+8+1];
2113
+ d1[4] = clamp(yy+rc);
2114
+ d1[5] = clamp(yy+gc);
2115
+ d1[6] = clamp(yy+bc);
2116
+ d1[7] = 255;
2117
+ }
2118
+
2119
+ d0 += 8;
2120
+ d1 += 8;
2121
+
2122
+ c++;
2123
+ }
2124
+ y += 64;
2125
+ }
2126
+
2127
+ y += 64*6 - 64*2;
2128
+ c += 64*6 - 8;
2129
+ }
2130
+ }
2131
+
2132
+ // Y (1 block per MCU) to 8-bit grayscale
2133
+ void jpeg_decoder::gray_convert()
2134
+ {
2135
+ int row = m_max_mcu_y_size - m_mcu_lines_left;
2136
+ uint8 *d = m_pScan_line_0;
2137
+ uint8 *s = m_pSample_buf + row * 8;
2138
+
2139
+ for (int i = m_max_mcus_per_row; i > 0; i--)
2140
+ {
2141
+ *(uint *)d = *(uint *)s;
2142
+ *(uint *)(&d[4]) = *(uint *)(&s[4]);
2143
+
2144
+ s += 64;
2145
+ d += 8;
2146
+ }
2147
+ }
2148
+
2149
+ void jpeg_decoder::expanded_convert()
2150
+ {
2151
+ int row = m_max_mcu_y_size - m_mcu_lines_left;
2152
+
2153
+ uint8* Py = m_pSample_buf + (row / 8) * 64 * m_comp_h_samp[0] + (row & 7) * 8;
2154
+
2155
+ uint8* d = m_pScan_line_0;
2156
+
2157
+ for (int i = m_max_mcus_per_row; i > 0; i--)
2158
+ {
2159
+ for (int k = 0; k < m_max_mcu_x_size; k += 8)
2160
+ {
2161
+ const int Y_ofs = k * 8;
2162
+ const int Cb_ofs = Y_ofs + 64 * m_expanded_blocks_per_component;
2163
+ const int Cr_ofs = Y_ofs + 64 * m_expanded_blocks_per_component * 2;
2164
+ for (int j = 0; j < 8; j++)
2165
+ {
2166
+ int y = Py[Y_ofs + j];
2167
+ int cb = Py[Cb_ofs + j];
2168
+ int cr = Py[Cr_ofs + j];
2169
+
2170
+ if (jpg_format == ERGBFormatJPG::BGRA)
2171
+ {
2172
+ d[0] = clamp(y + m_cbb[cb]);
2173
+ d[1] = clamp(y + ((m_crg[cr] + m_cbg[cb]) >> 16));
2174
+ d[2] = clamp(y + m_crr[cr]);
2175
+ d[3] = 255;
2176
+ }
2177
+ else
2178
+ {
2179
+ d[0] = clamp(y + m_crr[cr]);
2180
+ d[1] = clamp(y + ((m_crg[cr] + m_cbg[cb]) >> 16));
2181
+ d[2] = clamp(y + m_cbb[cb]);
2182
+ d[3] = 255;
2183
+ }
2184
+
2185
+ d += 4;
2186
+ }
2187
+ }
2188
+
2189
+ Py += 64 * m_expanded_blocks_per_mcu;
2190
+ }
2191
+ }
2192
+
2193
+ // Find end of image (EOI) marker, so we can return to the user the exact size of the input stream.
2194
+ void jpeg_decoder::find_eoi()
2195
+ {
2196
+ if (!m_progressive_flag)
2197
+ {
2198
+ // Attempt to read the EOI marker.
2199
+ //get_bits_no_markers(m_bits_left & 7);
2200
+
2201
+ // Prime the bit buffer
2202
+ m_bits_left = 16;
2203
+ get_bits(16);
2204
+ get_bits(16);
2205
+
2206
+ // The next marker _should_ be EOI
2207
+ process_markers();
2208
+ }
2209
+
2210
+ m_total_bytes_read -= m_in_buf_left;
2211
+ }
2212
+
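+ // Decodes one scan line per call: *pScan_line points at the converted pixels and *pScan_line_len
+ // receives the line length in bytes. Returns JPGD_SUCCESS, JPGD_DONE, or JPGD_FAILED.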
2213
+ int jpeg_decoder::decode(const void** pScan_line, uint* pScan_line_len)
2214
+ {
2215
+ if ((m_error_code) || (!m_ready_flag))
2216
+ return JPGD_FAILED;
2217
+
2218
+ if (m_total_lines_left == 0)
2219
+ return JPGD_DONE;
2220
+
2221
+ if (m_mcu_lines_left == 0)
2222
+ {
2223
+ if (setjmp(m_jmp_state))
2224
+ return JPGD_FAILED;
2225
+
2226
+ if (m_progressive_flag)
2227
+ load_next_row();
2228
+ else
2229
+ decode_next_row();
2230
+
2231
+ // Find the EOI marker if that was the last row.
2232
+ if (m_total_lines_left <= m_max_mcu_y_size)
2233
+ find_eoi();
2234
+
2235
+ m_mcu_lines_left = m_max_mcu_y_size;
2236
+ }
2237
+
2238
+ if (m_freq_domain_chroma_upsample)
2239
+ {
2240
+ expanded_convert();
2241
+ *pScan_line = m_pScan_line_0;
2242
+ }
2243
+ else
2244
+ {
2245
+ switch (m_scan_type)
2246
+ {
2247
+ case JPGD_YH2V2:
2248
+ {
2249
+ if ((m_mcu_lines_left & 1) == 0)
2250
+ {
2251
+ H2V2Convert();
2252
+ *pScan_line = m_pScan_line_0;
2253
+ }
2254
+ else
2255
+ *pScan_line = m_pScan_line_1;
2256
+
2257
+ break;
2258
+ }
2259
+ case JPGD_YH2V1:
2260
+ {
2261
+ H2V1Convert();
2262
+ *pScan_line = m_pScan_line_0;
2263
+ break;
2264
+ }
2265
+ case JPGD_YH1V2:
2266
+ {
2267
+ if ((m_mcu_lines_left & 1) == 0)
2268
+ {
2269
+ H1V2Convert();
2270
+ *pScan_line = m_pScan_line_0;
2271
+ }
2272
+ else
2273
+ *pScan_line = m_pScan_line_1;
2274
+
2275
+ break;
2276
+ }
2277
+ case JPGD_YH1V1:
2278
+ {
2279
+ H1V1Convert();
2280
+ *pScan_line = m_pScan_line_0;
2281
+ break;
2282
+ }
2283
+ case JPGD_GRAYSCALE:
2284
+ {
2285
+ gray_convert();
2286
+ *pScan_line = m_pScan_line_0;
2287
+
2288
+ break;
2289
+ }
2290
+ }
2291
+ }
2292
+
2293
+ *pScan_line_len = m_real_dest_bytes_per_scan_line;
2294
+
2295
+ m_mcu_lines_left--;
2296
+ m_total_lines_left--;
2297
+
2298
+ return JPGD_SUCCESS;
2299
+ }
2300
+
2301
+ // Creates the tables needed for efficient Huffman decoding.
2302
+ void jpeg_decoder::make_huff_table(int index, huff_tables *pH)
2303
+ {
2304
+ int p, i, l, si;
2305
+ uint8 huffsize[257];
2306
+ uint huffcode[257];
2307
+ uint code;
2308
+ uint subtree;
2309
+ int code_size;
2310
+ int lastp;
2311
+ int nextfreeentry;
2312
+ int currententry;
2313
+
2314
+ pH->ac_table = m_huff_ac[index] != 0;
2315
+
2316
+ p = 0;
2317
+
2318
+ for (l = 1; l <= 16; l++)
2319
+ {
2320
+ for (i = 1; i <= m_huff_num[index][l]; i++)
2321
+ huffsize[p++] = static_cast<uint8>(l);
2322
+ }
2323
+
2324
+ huffsize[p] = 0;
2325
+
2326
+ lastp = p;
2327
+
2328
+ code = 0;
2329
+ si = huffsize[0];
2330
+ p = 0;
2331
+
2332
+ while (huffsize[p])
2333
+ {
2334
+ while (huffsize[p] == si)
2335
+ {
2336
+ huffcode[p++] = code;
2337
+ code++;
2338
+ }
2339
+
2340
+ code <<= 1;
2341
+ si++;
2342
+ }
2343
+
2344
+ memset(pH->look_up, 0, sizeof(pH->look_up));
2345
+ memset(pH->look_up2, 0, sizeof(pH->look_up2));
2346
+ memset(pH->tree, 0, sizeof(pH->tree));
2347
+ memset(pH->code_size, 0, sizeof(pH->code_size));
2348
+
2349
+ nextfreeentry = -1;
2350
+
2351
+ p = 0;
2352
+
2353
+ while (p < lastp)
2354
+ {
2355
+ i = m_huff_val[index][p];
2356
+ code = huffcode[p];
2357
+ code_size = huffsize[p];
2358
+
2359
+ pH->code_size[i] = static_cast<uint8>(code_size);
2360
+
2361
+ if (code_size <= 8)
2362
+ {
2363
+ code <<= (8 - code_size);
2364
+
2365
+ for (l = 1 << (8 - code_size); l > 0; l--)
2366
+ {
2367
+ JPGD_ASSERT(i < 256);
2368
+
2369
+ pH->look_up[code] = i;
2370
+
2371
+ bool has_extrabits = false;
2372
+ int extra_bits = 0;
2373
+ int num_extra_bits = i & 15;
2374
+
2375
+ int bits_to_fetch = code_size;
2376
+ if (num_extra_bits)
2377
+ {
2378
+ int total_codesize = code_size + num_extra_bits;
2379
+ if (total_codesize <= 8)
2380
+ {
2381
+ has_extrabits = true;
2382
+ extra_bits = ((1 << num_extra_bits) - 1) & (code >> (8 - total_codesize));
2383
+ JPGD_ASSERT(extra_bits <= 0x7FFF);
2384
+ bits_to_fetch += num_extra_bits;
2385
+ }
2386
+ }
2387
+
2388
+ if (!has_extrabits)
2389
+ pH->look_up2[code] = i | (bits_to_fetch << 8);
2390
+ else
2391
+ pH->look_up2[code] = i | 0x8000 | (extra_bits << 16) | (bits_to_fetch << 8);
2392
+
2393
+ code++;
2394
+ }
2395
+ }
2396
+ else
2397
+ {
2398
+ subtree = (code >> (code_size - 8)) & 0xFF;
2399
+
2400
+ currententry = pH->look_up[subtree];
2401
+
2402
+ if (currententry == 0)
2403
+ {
2404
+ pH->look_up[subtree] = currententry = nextfreeentry;
2405
+ pH->look_up2[subtree] = currententry = nextfreeentry;
2406
+
2407
+ nextfreeentry -= 2;
2408
+ }
2409
+
2410
+ code <<= (16 - (code_size - 8));
2411
+
2412
+ for (l = code_size; l > 9; l--)
2413
+ {
2414
+ if ((code & 0x8000) == 0)
2415
+ currententry--;
2416
+
2417
+ if (pH->tree[-currententry - 1] == 0)
2418
+ {
2419
+ pH->tree[-currententry - 1] = nextfreeentry;
2420
+
2421
+ currententry = nextfreeentry;
2422
+
2423
+ nextfreeentry -= 2;
2424
+ }
2425
+ else
2426
+ currententry = pH->tree[-currententry - 1];
2427
+
2428
+ code <<= 1;
2429
+ }
2430
+
2431
+ if ((code & 0x8000) == 0)
2432
+ currententry--;
2433
+
2434
+ pH->tree[-currententry - 1] = i;
2435
+ }
2436
+
2437
+ p++;
2438
+ }
2439
+ }
2440
+
2441
+ // Verifies the quantization tables needed for this scan are available.
2442
+ void jpeg_decoder::check_quant_tables()
2443
+ {
2444
+ for (int i = 0; i < m_comps_in_scan; i++)
2445
+ if (m_quant[m_comp_quant[m_comp_list[i]]] == NULL)
2446
+ stop_decoding(JPGD_UNDEFINED_QUANT_TABLE);
2447
+ }
2448
+
2449
+ // Verifies that all the Huffman tables needed for this scan are available.
2450
+ void jpeg_decoder::check_huff_tables()
2451
+ {
2452
+ for (int i = 0; i < m_comps_in_scan; i++)
2453
+ {
2454
+ if ((m_spectral_start == 0) && (m_huff_num[m_comp_dc_tab[m_comp_list[i]]] == NULL))
2455
+ stop_decoding(JPGD_UNDEFINED_HUFF_TABLE);
2456
+
2457
+ if ((m_spectral_end > 0) && (m_huff_num[m_comp_ac_tab[m_comp_list[i]]] == NULL))
2458
+ stop_decoding(JPGD_UNDEFINED_HUFF_TABLE);
2459
+ }
2460
+
2461
+ for (int i = 0; i < JPGD_MAX_HUFF_TABLES; i++)
2462
+ if (m_huff_num[i])
2463
+ {
2464
+ if (!m_pHuff_tabs[i])
2465
+ m_pHuff_tabs[i] = (huff_tables *)alloc(sizeof(huff_tables));
2466
+
2467
+ make_huff_table(i, m_pHuff_tabs[i]);
2468
+ }
2469
+ }
2470
+
2471
+ // Determines the component order inside each MCU.
2472
+ // Also calcs how many MCU's are on each row, etc.
2473
+ void jpeg_decoder::calc_mcu_block_order()
2474
+ {
2475
+ int component_num, component_id;
2476
+ int max_h_samp = 0, max_v_samp = 0;
2477
+
2478
+ for (component_id = 0; component_id < m_comps_in_frame; component_id++)
2479
+ {
2480
+ if (m_comp_h_samp[component_id] > max_h_samp)
2481
+ max_h_samp = m_comp_h_samp[component_id];
2482
+
2483
+ if (m_comp_v_samp[component_id] > max_v_samp)
2484
+ max_v_samp = m_comp_v_samp[component_id];
2485
+ }
2486
+
2487
+ for (component_id = 0; component_id < m_comps_in_frame; component_id++)
2488
+ {
2489
+ m_comp_h_blocks[component_id] = ((((m_image_x_size * m_comp_h_samp[component_id]) + (max_h_samp - 1)) / max_h_samp) + 7) / 8;
2490
+ m_comp_v_blocks[component_id] = ((((m_image_y_size * m_comp_v_samp[component_id]) + (max_v_samp - 1)) / max_v_samp) + 7) / 8;
2491
+ }
2492
+
2493
+ if (m_comps_in_scan == 1)
2494
+ {
2495
+ m_mcus_per_row = m_comp_h_blocks[m_comp_list[0]];
2496
+ m_mcus_per_col = m_comp_v_blocks[m_comp_list[0]];
2497
+ }
2498
+ else
2499
+ {
2500
+ m_mcus_per_row = (((m_image_x_size + 7) / 8) + (max_h_samp - 1)) / max_h_samp;
2501
+ m_mcus_per_col = (((m_image_y_size + 7) / 8) + (max_v_samp - 1)) / max_v_samp;
2502
+ }
2503
+
2504
+ if (m_comps_in_scan == 1)
2505
+ {
2506
+ m_mcu_org[0] = m_comp_list[0];
2507
+
2508
+ m_blocks_per_mcu = 1;
2509
+ }
2510
+ else
2511
+ {
2512
+ m_blocks_per_mcu = 0;
2513
+
2514
+ for (component_num = 0; component_num < m_comps_in_scan; component_num++)
2515
+ {
2516
+ int num_blocks;
2517
+
2518
+ component_id = m_comp_list[component_num];
2519
+
2520
+ num_blocks = m_comp_h_samp[component_id] * m_comp_v_samp[component_id];
2521
+
2522
+ while (num_blocks--)
2523
+ m_mcu_org[m_blocks_per_mcu++] = component_id;
2524
+ }
2525
+ }
2526
+ }
2527
+
2528
+ // Starts a new scan.
2529
+ int jpeg_decoder::init_scan()
2530
+ {
2531
+ if (!locate_sos_marker())
2532
+ return JPGD_FALSE;
2533
+
2534
+ calc_mcu_block_order();
2535
+
2536
+ check_huff_tables();
2537
+
2538
+ check_quant_tables();
2539
+
2540
+ memset(m_last_dc_val, 0, m_comps_in_frame * sizeof(uint));
2541
+
2542
+ m_eob_run = 0;
2543
+
2544
+ if (m_restart_interval)
2545
+ {
2546
+ m_restarts_left = m_restart_interval;
2547
+ m_next_restart_num = 0;
2548
+ }
2549
+
2550
+ fix_in_buffer();
2551
+
2552
+ return JPGD_TRUE;
2553
+ }
2554
+
2555
+ // Starts a frame. Determines if the number of components or sampling factors
2556
+ // are supported.
2557
+ void jpeg_decoder::init_frame()
2558
+ {
2559
+ int i;
2560
+
2561
+ if (m_comps_in_frame == 1)
2562
+ {
2563
+ if ((m_comp_h_samp[0] != 1) || (m_comp_v_samp[0] != 1))
2564
+ stop_decoding(JPGD_UNSUPPORTED_SAMP_FACTORS);
2565
+
2566
+ m_scan_type = JPGD_GRAYSCALE;
2567
+ m_max_blocks_per_mcu = 1;
2568
+ m_max_mcu_x_size = 8;
2569
+ m_max_mcu_y_size = 8;
2570
+ }
2571
+ else if (m_comps_in_frame == 3)
2572
+ {
2573
+ if ( ((m_comp_h_samp[1] != 1) || (m_comp_v_samp[1] != 1)) ||
2574
+ ((m_comp_h_samp[2] != 1) || (m_comp_v_samp[2] != 1)) )
2575
+ stop_decoding(JPGD_UNSUPPORTED_SAMP_FACTORS);
2576
+
2577
+ if ((m_comp_h_samp[0] == 1) && (m_comp_v_samp[0] == 1))
2578
+ {
2579
+ m_scan_type = JPGD_YH1V1;
2580
+
2581
+ m_max_blocks_per_mcu = 3;
2582
+ m_max_mcu_x_size = 8;
2583
+ m_max_mcu_y_size = 8;
2584
+ }
2585
+ else if ((m_comp_h_samp[0] == 2) && (m_comp_v_samp[0] == 1))
2586
+ {
2587
+ m_scan_type = JPGD_YH2V1;
2588
+ m_max_blocks_per_mcu = 4;
2589
+ m_max_mcu_x_size = 16;
2590
+ m_max_mcu_y_size = 8;
2591
+ }
2592
+ else if ((m_comp_h_samp[0] == 1) && (m_comp_v_samp[0] == 2))
2593
+ {
2594
+ m_scan_type = JPGD_YH1V2;
2595
+ m_max_blocks_per_mcu = 4;
2596
+ m_max_mcu_x_size = 8;
2597
+ m_max_mcu_y_size = 16;
2598
+ }
2599
+ else if ((m_comp_h_samp[0] == 2) && (m_comp_v_samp[0] == 2))
2600
+ {
2601
+ m_scan_type = JPGD_YH2V2;
2602
+ m_max_blocks_per_mcu = 6;
2603
+ m_max_mcu_x_size = 16;
2604
+ m_max_mcu_y_size = 16;
2605
+ }
2606
+ else
2607
+ stop_decoding(JPGD_UNSUPPORTED_SAMP_FACTORS);
2608
+ }
2609
+ else
2610
+ stop_decoding(JPGD_UNSUPPORTED_COLORSPACE);
2611
+
2612
+ m_max_mcus_per_row = (m_image_x_size + (m_max_mcu_x_size - 1)) / m_max_mcu_x_size;
2613
+ m_max_mcus_per_col = (m_image_y_size + (m_max_mcu_y_size - 1)) / m_max_mcu_y_size;
2614
+
2615
+ // These values are for the *destination* pixels: after conversion.
2616
+ if (m_scan_type == JPGD_GRAYSCALE)
2617
+ m_dest_bytes_per_pixel = 1;
2618
+ else
2619
+ m_dest_bytes_per_pixel = 4;
2620
+
2621
+ m_dest_bytes_per_scan_line = ((m_image_x_size + 15) & 0xFFF0) * m_dest_bytes_per_pixel;
2622
+
2623
+ m_real_dest_bytes_per_scan_line = (m_image_x_size * m_dest_bytes_per_pixel);
2624
+
2625
+ // Initialize two scan line buffers.
2626
+ m_pScan_line_0 = (uint8 *)alloc(m_dest_bytes_per_scan_line, true);
2627
+ if ((m_scan_type == JPGD_YH1V2) || (m_scan_type == JPGD_YH2V2))
2628
+ m_pScan_line_1 = (uint8 *)alloc(m_dest_bytes_per_scan_line, true);
2629
+
2630
+ m_max_blocks_per_row = m_max_mcus_per_row * m_max_blocks_per_mcu;
2631
+
2632
+ // Should never happen
2633
+ if (m_max_blocks_per_row > JPGD_MAX_BLOCKS_PER_ROW)
2634
+ stop_decoding(JPGD_ASSERTION_ERROR);
2635
+
2636
+ // Allocate the coefficient buffer, enough for one MCU
2637
+ m_pMCU_coefficients = (jpgd_block_t*)alloc(m_max_blocks_per_mcu * 64 * sizeof(jpgd_block_t));
2638
+
2639
+ for (i = 0; i < m_max_blocks_per_mcu; i++)
2640
+ m_mcu_block_max_zag[i] = 64;
2641
+
2642
+ m_expanded_blocks_per_component = m_comp_h_samp[0] * m_comp_v_samp[0];
2643
+ m_expanded_blocks_per_mcu = m_expanded_blocks_per_component * m_comps_in_frame;
2644
+ m_expanded_blocks_per_row = m_max_mcus_per_row * m_expanded_blocks_per_mcu;
2645
+ // Freq. domain chroma upsampling is only supported for H2V2 subsampling factor.
2646
+ // BEGIN EPIC MOD
2647
+ #if JPGD_SUPPORT_FREQ_DOMAIN_UPSAMPLING
2648
+ m_freq_domain_chroma_upsample = (m_expanded_blocks_per_mcu == 4*3);
2649
+ #else
2650
+ m_freq_domain_chroma_upsample = 0;
2651
+ #endif
2652
+ // END EPIC MOD
2653
+
2654
+ if (m_freq_domain_chroma_upsample)
2655
+ m_pSample_buf = (uint8 *)alloc(m_expanded_blocks_per_row * 64);
2656
+ else
2657
+ m_pSample_buf = (uint8 *)alloc(m_max_blocks_per_row * 64);
2658
+
2659
+ m_total_lines_left = m_image_y_size;
2660
+
2661
+ m_mcu_lines_left = 0;
2662
+
2663
+ create_look_ups();
2664
+ }
2665
+
2666
+ // The coeff_buf series of methods originally stored the coefficients
2667
+ // into a "virtual" file which was located in EMS, XMS, or a disk file. A cache
2668
+ // was used to make this process more efficient. Now, we can store the entire
2669
+ // thing in RAM.
2670
+ jpeg_decoder::coeff_buf* jpeg_decoder::coeff_buf_open(int block_num_x, int block_num_y, int block_len_x, int block_len_y)
2671
+ {
2672
+ coeff_buf* cb = (coeff_buf*)alloc(sizeof(coeff_buf));
2673
+
2674
+ cb->block_num_x = block_num_x;
2675
+ cb->block_num_y = block_num_y;
2676
+ cb->block_len_x = block_len_x;
2677
+ cb->block_len_y = block_len_y;
2678
+ cb->block_size = (block_len_x * block_len_y) * sizeof(jpgd_block_t);
2679
+ cb->pData = (uint8 *)alloc(cb->block_size * block_num_x * block_num_y, true);
2680
+ return cb;
2681
+ }
2682
+
2683
+ inline jpgd_block_t *jpeg_decoder::coeff_buf_getp(coeff_buf *cb, int block_x, int block_y)
2684
+ {
2685
+ JPGD_ASSERT((block_x < cb->block_num_x) && (block_y < cb->block_num_y));
2686
+ return (jpgd_block_t *)(cb->pData + block_x * cb->block_size + block_y * (cb->block_size * cb->block_num_x));
2687
+ }
2688
+
2689
+ // The following methods decode the various types of m_blocks encountered
2690
+ // in progressively encoded images.
2691
+ void jpeg_decoder::decode_block_dc_first(jpeg_decoder *pD, int component_id, int block_x, int block_y)
2692
+ {
2693
+ int s, r;
2694
+ jpgd_block_t *p = pD->coeff_buf_getp(pD->m_dc_coeffs[component_id], block_x, block_y);
2695
+
2696
+ if ((s = pD->huff_decode(pD->m_pHuff_tabs[pD->m_comp_dc_tab[component_id]])) != 0)
2697
+ {
2698
+ r = pD->get_bits_no_markers(s);
2699
+ s = HUFF_EXTEND(r, s);
2700
+ }
2701
+
2702
+ pD->m_last_dc_val[component_id] = (s += pD->m_last_dc_val[component_id]);
2703
+
2704
+ p[0] = static_cast<jpgd_block_t>(s << pD->m_successive_low);
2705
+ }
2706
+
2707
+ void jpeg_decoder::decode_block_dc_refine(jpeg_decoder *pD, int component_id, int block_x, int block_y)
2708
+ {
2709
+ if (pD->get_bits_no_markers(1))
2710
+ {
2711
+ jpgd_block_t *p = pD->coeff_buf_getp(pD->m_dc_coeffs[component_id], block_x, block_y);
2712
+
2713
+ p[0] |= (1 << pD->m_successive_low);
2714
+ }
2715
+ }
2716
+
2717
+ void jpeg_decoder::decode_block_ac_first(jpeg_decoder *pD, int component_id, int block_x, int block_y)
2718
+ {
2719
+ int k, s, r;
2720
+
2721
+ if (pD->m_eob_run)
2722
+ {
2723
+ pD->m_eob_run--;
2724
+ return;
2725
+ }
2726
+
2727
+ jpgd_block_t *p = pD->coeff_buf_getp(pD->m_ac_coeffs[component_id], block_x, block_y);
2728
+
2729
+ for (k = pD->m_spectral_start; k <= pD->m_spectral_end; k++)
2730
+ {
2731
+ s = pD->huff_decode(pD->m_pHuff_tabs[pD->m_comp_ac_tab[component_id]]);
2732
+
2733
+ r = s >> 4;
2734
+ s &= 15;
2735
+
2736
+ if (s)
2737
+ {
2738
+ if ((k += r) > 63)
2739
+ pD->stop_decoding(JPGD_DECODE_ERROR);
2740
+
2741
+ r = pD->get_bits_no_markers(s);
2742
+ s = HUFF_EXTEND(r, s);
2743
+
2744
+ p[g_ZAG[k]] = static_cast<jpgd_block_t>(s << pD->m_successive_low);
2745
+ }
2746
+ else
2747
+ {
2748
+ if (r == 15)
2749
+ {
2750
+ if ((k += 15) > 63)
2751
+ pD->stop_decoding(JPGD_DECODE_ERROR);
2752
+ }
2753
+ else
2754
+ {
2755
+ pD->m_eob_run = 1 << r;
2756
+
2757
+ if (r)
2758
+ pD->m_eob_run += pD->get_bits_no_markers(r);
2759
+
2760
+ pD->m_eob_run--;
2761
+
2762
+ break;
2763
+ }
2764
+ }
2765
+ }
2766
+ }
2767
+
2768
+ void jpeg_decoder::decode_block_ac_refine(jpeg_decoder *pD, int component_id, int block_x, int block_y)
2769
+ {
2770
+ int s, k, r;
2771
+ int p1 = 1 << pD->m_successive_low;
2772
+ int m1 = (-1) << pD->m_successive_low;
2773
+ jpgd_block_t *p = pD->coeff_buf_getp(pD->m_ac_coeffs[component_id], block_x, block_y);
2774
+
2775
+ k = pD->m_spectral_start;
2776
+
2777
+ if (pD->m_eob_run == 0)
2778
+ {
2779
+ for ( ; k <= pD->m_spectral_end; k++)
2780
+ {
2781
+ s = pD->huff_decode(pD->m_pHuff_tabs[pD->m_comp_ac_tab[component_id]]);
2782
+
2783
+ r = s >> 4;
2784
+ s &= 15;
2785
+
2786
+ if (s)
2787
+ {
2788
+ if (s != 1)
2789
+ pD->stop_decoding(JPGD_DECODE_ERROR);
2790
+
2791
+ if (pD->get_bits_no_markers(1))
2792
+ s = p1;
2793
+ else
2794
+ s = m1;
2795
+ }
2796
+ else
2797
+ {
2798
+ if (r != 15)
2799
+ {
2800
+ pD->m_eob_run = 1 << r;
2801
+
2802
+ if (r)
2803
+ pD->m_eob_run += pD->get_bits_no_markers(r);
2804
+
2805
+ break;
2806
+ }
2807
+ }
2808
+
2809
+ do
2810
+ {
2811
+ // BEGIN EPIC MOD
2812
+ JPGD_ASSERT(k < 64);
2813
+ // END EPIC MOD
2814
+
2815
+ jpgd_block_t *this_coef = p + g_ZAG[k];
2816
+
2817
+ if (*this_coef != 0)
2818
+ {
2819
+ if (pD->get_bits_no_markers(1))
2820
+ {
2821
+ if ((*this_coef & p1) == 0)
2822
+ {
2823
+ if (*this_coef >= 0)
2824
+ *this_coef = static_cast<jpgd_block_t>(*this_coef + p1);
2825
+ else
2826
+ *this_coef = static_cast<jpgd_block_t>(*this_coef + m1);
2827
+ }
2828
+ }
2829
+ }
2830
+ else
2831
+ {
2832
+ if (--r < 0)
2833
+ break;
2834
+ }
2835
+
2836
+ k++;
2837
+
2838
+ } while (k <= pD->m_spectral_end);
2839
+
2840
+ if ((s) && (k < 64))
2841
+ {
2842
+ p[g_ZAG[k]] = static_cast<jpgd_block_t>(s);
2843
+ }
2844
+ }
2845
+ }
2846
+
2847
+ if (pD->m_eob_run > 0)
2848
+ {
2849
+ for ( ; k <= pD->m_spectral_end; k++)
2850
+ {
2851
+ // BEGIN EPIC MOD
2852
+ JPGD_ASSERT(k < 64);
2853
+ // END EPIC MOD
2854
+
2855
+ jpgd_block_t *this_coef = p + g_ZAG[k];
2856
+
2857
+ if (*this_coef != 0)
2858
+ {
2859
+ if (pD->get_bits_no_markers(1))
2860
+ {
2861
+ if ((*this_coef & p1) == 0)
2862
+ {
2863
+ if (*this_coef >= 0)
2864
+ *this_coef = static_cast<jpgd_block_t>(*this_coef + p1);
2865
+ else
2866
+ *this_coef = static_cast<jpgd_block_t>(*this_coef + m1);
2867
+ }
2868
+ }
2869
+ }
2870
+ }
2871
+
2872
+ pD->m_eob_run--;
2873
+ }
2874
+ }
2875
+
2876
+ // Decode a scan in a progressively encoded image.
2877
+ void jpeg_decoder::decode_scan(pDecode_block_func decode_block_func)
2878
+ {
2879
+ int mcu_row, mcu_col, mcu_block;
2880
+ int block_x_mcu[JPGD_MAX_COMPONENTS], m_block_y_mcu[JPGD_MAX_COMPONENTS];
2881
+
2882
+ memset(m_block_y_mcu, 0, sizeof(m_block_y_mcu));
2883
+
2884
+ for (mcu_col = 0; mcu_col < m_mcus_per_col; mcu_col++)
2885
+ {
2886
+ int component_num, component_id;
2887
+
2888
+ memset(block_x_mcu, 0, sizeof(block_x_mcu));
2889
+
2890
+ for (mcu_row = 0; mcu_row < m_mcus_per_row; mcu_row++)
2891
+ {
2892
+ int block_x_mcu_ofs = 0, block_y_mcu_ofs = 0;
2893
+
2894
+ if ((m_restart_interval) && (m_restarts_left == 0))
2895
+ process_restart();
2896
+
2897
+ for (mcu_block = 0; mcu_block < m_blocks_per_mcu; mcu_block++)
2898
+ {
2899
+ component_id = m_mcu_org[mcu_block];
2900
+
2901
+ decode_block_func(this, component_id, block_x_mcu[component_id] + block_x_mcu_ofs, m_block_y_mcu[component_id] + block_y_mcu_ofs);
2902
+
2903
+ if (m_comps_in_scan == 1)
2904
+ block_x_mcu[component_id]++;
2905
+ else
2906
+ {
2907
+ if (++block_x_mcu_ofs == m_comp_h_samp[component_id])
2908
+ {
2909
+ block_x_mcu_ofs = 0;
2910
+
2911
+ if (++block_y_mcu_ofs == m_comp_v_samp[component_id])
2912
+ {
2913
+ block_y_mcu_ofs = 0;
2914
+ block_x_mcu[component_id] += m_comp_h_samp[component_id];
2915
+ }
2916
+ }
2917
+ }
2918
+ }
2919
+
2920
+ m_restarts_left--;
2921
+ }
2922
+
2923
+ if (m_comps_in_scan == 1)
2924
+ m_block_y_mcu[m_comp_list[0]]++;
2925
+ else
2926
+ {
2927
+ for (component_num = 0; component_num < m_comps_in_scan; component_num++)
2928
+ {
2929
+ component_id = m_comp_list[component_num];
2930
+ m_block_y_mcu[component_id] += m_comp_v_samp[component_id];
2931
+ }
2932
+ }
2933
+ }
2934
+ }
2935
+
2936
+ // Decode a progressively encoded image.
2937
+ void jpeg_decoder::init_progressive()
2938
+ {
2939
+ int i;
2940
+
2941
+ if (m_comps_in_frame == 4)
2942
+ stop_decoding(JPGD_UNSUPPORTED_COLORSPACE);
2943
+
2944
+ // Allocate the coefficient buffers.
2945
+ for (i = 0; i < m_comps_in_frame; i++)
2946
+ {
2947
+ m_dc_coeffs[i] = coeff_buf_open(m_max_mcus_per_row * m_comp_h_samp[i], m_max_mcus_per_col * m_comp_v_samp[i], 1, 1);
2948
+ m_ac_coeffs[i] = coeff_buf_open(m_max_mcus_per_row * m_comp_h_samp[i], m_max_mcus_per_col * m_comp_v_samp[i], 8, 8);
2949
+ }
2950
+
2951
+ for ( ; ; )
2952
+ {
2953
+ int dc_only_scan, refinement_scan;
2954
+ pDecode_block_func decode_block_func;
2955
+
2956
+ if (!init_scan())
2957
+ break;
2958
+
2959
+ dc_only_scan = (m_spectral_start == 0);
2960
+ refinement_scan = (m_successive_high != 0);
2961
+
2962
+ if ((m_spectral_start > m_spectral_end) || (m_spectral_end > 63))
2963
+ stop_decoding(JPGD_BAD_SOS_SPECTRAL);
2964
+
2965
+ if (dc_only_scan)
2966
+ {
2967
+ if (m_spectral_end)
2968
+ stop_decoding(JPGD_BAD_SOS_SPECTRAL);
2969
+ }
2970
+ else if (m_comps_in_scan != 1) /* AC scans can only contain one component */
2971
+ stop_decoding(JPGD_BAD_SOS_SPECTRAL);
2972
+
2973
+ if ((refinement_scan) && (m_successive_low != m_successive_high - 1))
2974
+ stop_decoding(JPGD_BAD_SOS_SUCCESSIVE);
2975
+
2976
+ if (dc_only_scan)
2977
+ {
2978
+ if (refinement_scan)
2979
+ decode_block_func = decode_block_dc_refine;
2980
+ else
2981
+ decode_block_func = decode_block_dc_first;
2982
+ }
2983
+ else
2984
+ {
2985
+ if (refinement_scan)
2986
+ decode_block_func = decode_block_ac_refine;
2987
+ else
2988
+ decode_block_func = decode_block_ac_first;
2989
+ }
2990
+
2991
+ decode_scan(decode_block_func);
2992
+
2993
+ m_bits_left = 16;
2994
+ get_bits(16);
2995
+ get_bits(16);
2996
+ }
2997
+
2998
+ m_comps_in_scan = m_comps_in_frame;
2999
+
3000
+ for (i = 0; i < m_comps_in_frame; i++)
3001
+ m_comp_list[i] = i;
3002
+
3003
+ calc_mcu_block_order();
3004
+ }
3005
+
3006
+ void jpeg_decoder::init_sequential()
3007
+ {
3008
+ if (!init_scan())
3009
+ stop_decoding(JPGD_UNEXPECTED_MARKER);
3010
+ }
3011
+
3012
+ void jpeg_decoder::decode_start()
3013
+ {
3014
+ init_frame();
3015
+
3016
+ if (m_progressive_flag)
3017
+ init_progressive();
3018
+ else
3019
+ init_sequential();
3020
+ }
3021
+
3022
+ void jpeg_decoder::decode_init(jpeg_decoder_stream *pStream)
3023
+ {
3024
+ init(pStream);
3025
+ locate_sof_marker();
3026
+ }
3027
+
3028
+ jpeg_decoder::jpeg_decoder(jpeg_decoder_stream *pStream)
3029
+ {
3030
+ if (setjmp(m_jmp_state))
3031
+ return;
3032
+ decode_init(pStream);
3033
+ }
3034
+
3035
+ int jpeg_decoder::begin_decoding()
3036
+ {
3037
+ if (m_ready_flag)
3038
+ return JPGD_SUCCESS;
3039
+
3040
+ if (m_error_code)
3041
+ return JPGD_FAILED;
3042
+
3043
+ if (setjmp(m_jmp_state))
3044
+ return JPGD_FAILED;
3045
+
3046
+ decode_start();
3047
+
3048
+ m_ready_flag = true;
3049
+
3050
+ return JPGD_SUCCESS;
3051
+ }
3052
+
3053
+ jpeg_decoder::~jpeg_decoder()
3054
+ {
3055
+ free_all_blocks();
3056
+ }
3057
+
3058
+ jpeg_decoder_file_stream::jpeg_decoder_file_stream()
3059
+ {
3060
+ m_pFile = NULL;
3061
+ m_eof_flag = false;
3062
+ m_error_flag = false;
3063
+ }
3064
+
3065
+ void jpeg_decoder_file_stream::close()
3066
+ {
3067
+ if (m_pFile)
3068
+ {
3069
+ fclose(m_pFile);
3070
+ m_pFile = NULL;
3071
+ }
3072
+
3073
+ m_eof_flag = false;
3074
+ m_error_flag = false;
3075
+ }
3076
+
3077
+ jpeg_decoder_file_stream::~jpeg_decoder_file_stream()
3078
+ {
3079
+ close();
3080
+ }
3081
+
3082
+ bool jpeg_decoder_file_stream::open(const char *Pfilename)
3083
+ {
3084
+ close();
3085
+
3086
+ m_eof_flag = false;
3087
+ m_error_flag = false;
3088
+
3089
+ #if defined(_MSC_VER)
3090
+ m_pFile = NULL;
3091
+ fopen_s(&m_pFile, Pfilename, "rb");
3092
+ #else
3093
+ m_pFile = fopen(Pfilename, "rb");
3094
+ #endif
3095
+ return m_pFile != NULL;
3096
+ }
3097
+
3098
+ int jpeg_decoder_file_stream::read(uint8 *pBuf, int max_bytes_to_read, bool *pEOF_flag)
3099
+ {
3100
+ if (!m_pFile)
3101
+ return -1;
3102
+
3103
+ if (m_eof_flag)
3104
+ {
3105
+ *pEOF_flag = true;
3106
+ return 0;
3107
+ }
3108
+
3109
+ if (m_error_flag)
3110
+ return -1;
3111
+
3112
+ int bytes_read = static_cast<int>(fread(pBuf, 1, max_bytes_to_read, m_pFile));
3113
+ if (bytes_read < max_bytes_to_read)
3114
+ {
3115
+ if (ferror(m_pFile))
3116
+ {
3117
+ m_error_flag = true;
3118
+ return -1;
3119
+ }
3120
+
3121
+ m_eof_flag = true;
3122
+ *pEOF_flag = true;
3123
+ }
3124
+
3125
+ return bytes_read;
3126
+ }
3127
+
3128
+ bool jpeg_decoder_mem_stream::open(const uint8 *pSrc_data, uint size)
3129
+ {
3130
+ close();
3131
+ m_pSrc_data = pSrc_data;
3132
+ m_ofs = 0;
3133
+ m_size = size;
3134
+ return true;
3135
+ }
3136
+
3137
+ int jpeg_decoder_mem_stream::read(uint8 *pBuf, int max_bytes_to_read, bool *pEOF_flag)
3138
+ {
3139
+ *pEOF_flag = false;
3140
+
3141
+ if (!m_pSrc_data)
3142
+ return -1;
3143
+
3144
+ uint bytes_remaining = m_size - m_ofs;
3145
+ if ((uint)max_bytes_to_read > bytes_remaining)
3146
+ {
3147
+ max_bytes_to_read = bytes_remaining;
3148
+ *pEOF_flag = true;
3149
+ }
3150
+
3151
+ memcpy(pBuf, m_pSrc_data + m_ofs, max_bytes_to_read);
3152
+ m_ofs += max_bytes_to_read;
3153
+
3154
+ return max_bytes_to_read;
3155
+ }
3156
+
3157
+ unsigned char *decompress_jpeg_image_from_stream(jpeg_decoder_stream *pStream, int *width, int *height, int *actual_comps, int req_comps)
3158
+ {
3159
+ if (!actual_comps)
3160
+ return NULL;
3161
+ *actual_comps = 0;
3162
+
3163
+ if ((!pStream) || (!width) || (!height) || (!req_comps))
3164
+ return NULL;
3165
+
3166
+ if ((req_comps != 1) && (req_comps != 3) && (req_comps != 4))
3167
+ return NULL;
3168
+
3169
+ jpeg_decoder decoder(pStream);
3170
+ if (decoder.get_error_code() != JPGD_SUCCESS)
3171
+ return NULL;
3172
+
3173
+ const int image_width = decoder.get_width(), image_height = decoder.get_height();
3174
+ *width = image_width;
3175
+ *height = image_height;
3176
+ *actual_comps = decoder.get_num_components();
3177
+
3178
+ if (decoder.begin_decoding() != JPGD_SUCCESS)
3179
+ return NULL;
3180
+
3181
+ const int dst_bpl = image_width * req_comps;
3182
+
3183
+ uint8 *pImage_data = (uint8*)jpgd_malloc(dst_bpl * image_height);
3184
+ if (!pImage_data)
3185
+ return NULL;
3186
+
3187
+ for (int y = 0; y < image_height; y++)
3188
+ {
3189
+ const uint8* pScan_line = 0;
3190
+ uint scan_line_len;
3191
+ if (decoder.decode((const void**)&pScan_line, &scan_line_len) != JPGD_SUCCESS)
3192
+ {
3193
+ jpgd_free(pImage_data);
3194
+ return NULL;
3195
+ }
3196
+
3197
+ uint8 *pDst = pImage_data + y * dst_bpl;
3198
+
3199
+ if (((req_comps == 4) && (decoder.get_num_components() == 3)) ||
3200
+ ((req_comps == 1) && (decoder.get_num_components() == 1)))
3201
+ {
3202
+ memcpy(pDst, pScan_line, dst_bpl);
3203
+ }
3204
+ else if (decoder.get_num_components() == 1)
3205
+ {
3206
+ if (req_comps == 3)
3207
+ {
3208
+ for (int x = 0; x < image_width; x++)
3209
+ {
3210
+ uint8 luma = pScan_line[x];
3211
+ pDst[0] = luma;
3212
+ pDst[1] = luma;
3213
+ pDst[2] = luma;
3214
+ pDst += 3;
3215
+ }
3216
+ }
3217
+ else
3218
+ {
3219
+ for (int x = 0; x < image_width; x++)
3220
+ {
3221
+ uint8 luma = pScan_line[x];
3222
+ pDst[0] = luma;
3223
+ pDst[1] = luma;
3224
+ pDst[2] = luma;
3225
+ pDst[3] = 255;
3226
+ pDst += 4;
3227
+ }
3228
+ }
3229
+ }
3230
+ else if (decoder.get_num_components() == 3)
3231
+ {
3232
+ if (req_comps == 1)
3233
+ {
3234
+ const int YR = 19595, YG = 38470, YB = 7471;
3235
+ for (int x = 0; x < image_width; x++)
3236
+ {
3237
+ int r = pScan_line[x*4+0];
3238
+ int g = pScan_line[x*4+1];
3239
+ int b = pScan_line[x*4+2];
3240
+ *pDst++ = static_cast<uint8>((r * YR + g * YG + b * YB + 32768) >> 16);
3241
+ }
3242
+ }
3243
+ else
3244
+ {
3245
+ for (int x = 0; x < image_width; x++)
3246
+ {
3247
+ pDst[0] = pScan_line[x*4+0];
3248
+ pDst[1] = pScan_line[x*4+1];
3249
+ pDst[2] = pScan_line[x*4+2];
3250
+ pDst += 3;
3251
+ }
3252
+ }
3253
+ }
3254
+ }
3255
+
3256
+ return pImage_data;
3257
+ }
3258
+
3259
+ // BEGIN EPIC MOD
3260
+ unsigned char *decompress_jpeg_image_from_memory(const unsigned char *pSrc_data, int src_data_size, int *width, int *height, int *actual_comps, int req_comps, int format)
3261
+ {
3262
+ jpg_format = (ERGBFormatJPG)format;
3263
+ // END EPIC MOD
3264
+ jpgd::jpeg_decoder_mem_stream mem_stream(pSrc_data, src_data_size);
3265
+ return decompress_jpeg_image_from_stream(&mem_stream, width, height, actual_comps, req_comps);
3266
+ }
3267
+
3268
+ unsigned char *decompress_jpeg_image_from_file(const char *pSrc_filename, int *width, int *height, int *actual_comps, int req_comps)
3269
+ {
3270
+ jpgd::jpeg_decoder_file_stream file_stream;
3271
+ if (!file_stream.open(pSrc_filename))
3272
+ return NULL;
3273
+ return decompress_jpeg_image_from_stream(&file_stream, width, height, actual_comps, req_comps);
3274
+ }
3275
+
3276
+ } // namespace jpgd
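
The exported helpers above (`decompress_jpeg_image_from_file` / `decompress_jpeg_image_from_memory`) are the intended entry points for this decoder. A minimal usage sketch follows, assuming `jpgd.h` declares these functions and `jpgd_free` as part of the `jpgd` namespace API, and using a placeholder filename:

```cpp
// Minimal sketch (not part of the diff): decode a JPEG file into an RGB buffer
// using the public helpers defined above. Error handling is deliberately terse.
#include <cstdio>
#include "jpgd.h"   // assumed to declare the jpgd namespace API shown above

int main()
{
    int width = 0, height = 0, actual_comps = 0;
    // Request 3 components (RGB); the decoder expands grayscale sources as needed.
    unsigned char* pixels = jpgd::decompress_jpeg_image_from_file(
        "input.jpg" /* placeholder path */, &width, &height, &actual_comps, 3);
    if (!pixels)
    {
        std::printf("decode failed\n");
        return 1;
    }
    std::printf("decoded %dx%d image (%d source components)\n", width, height, actual_comps);
    jpgd::jpgd_free(pixels);  // buffer is allocated with jpgd_malloc inside the decoder
    return 0;
}
```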
crazy_functions/test_project/cpp/longcode/jpge.cpp ADDED
@@ -0,0 +1,1049 @@
1
+ // jpge.cpp - C++ class for JPEG compression.
2
+ // Public domain, Rich Geldreich <[email protected]>
3
+ // v1.01, Dec. 18, 2010 - Initial release
4
+ // v1.02, Apr. 6, 2011 - Removed 2x2 ordered dither in H2V1 chroma subsampling method load_block_16_8_8(). (The rounding factor was 2, when it should have been 1. Either way, it wasn't helping.)
5
+ // v1.03, Apr. 16, 2011 - Added support for optimized Huffman code tables, optimized dynamic memory allocation down to only 1 alloc.
6
+ // Also from Alex Evans: Added RGBA support, linear memory allocator (no longer needed in v1.03).
7
+ // v1.04, May. 19, 2012: Forgot to set m_pFile ptr to NULL in cfile_stream::close(). Thanks to Owen Kaluza for reporting this bug.
8
+ // Code tweaks to fix VS2008 static code analysis warnings (all looked harmless).
9
+ // Code review revealed method load_block_16_8_8() (used for the non-default H2V1 sampling mode to downsample chroma) somehow didn't get the rounding factor fix from v1.02.
10
+
11
+ #include "jpge.h"
12
+
13
+ #include <stdlib.h>
14
+ #include <string.h>
15
+ #if PLATFORM_WINDOWS
16
+ #include <malloc.h>
17
+ #endif
18
+
19
+ #define JPGE_MAX(a,b) (((a)>(b))?(a):(b))
20
+ #define JPGE_MIN(a,b) (((a)<(b))?(a):(b))
21
+
22
+ namespace jpge {
23
+
24
+ static inline void *jpge_malloc(size_t nSize) { return FMemory::Malloc(nSize); }
25
+ static inline void jpge_free(void *p) { FMemory::Free(p); }
26
+
27
+ // Various JPEG enums and tables.
28
+ enum { M_SOF0 = 0xC0, M_DHT = 0xC4, M_SOI = 0xD8, M_EOI = 0xD9, M_SOS = 0xDA, M_DQT = 0xDB, M_APP0 = 0xE0 };
29
+ enum { DC_LUM_CODES = 12, AC_LUM_CODES = 256, DC_CHROMA_CODES = 12, AC_CHROMA_CODES = 256, MAX_HUFF_SYMBOLS = 257, MAX_HUFF_CODESIZE = 32 };
30
+
31
+ static uint8 s_zag[64] = { 0,1,8,16,9,2,3,10,17,24,32,25,18,11,4,5,12,19,26,33,40,48,41,34,27,20,13,6,7,14,21,28,35,42,49,56,57,50,43,36,29,22,15,23,30,37,44,51,58,59,52,45,38,31,39,46,53,60,61,54,47,55,62,63 };
32
+ static int16 s_std_lum_quant[64] = { 16,11,12,14,12,10,16,14,13,14,18,17,16,19,24,40,26,24,22,22,24,49,35,37,29,40,58,51,61,60,57,51,56,55,64,72,92,78,64,68,87,69,55,56,80,109,81,87,95,98,103,104,103,62,77,113,121,112,100,120,92,101,103,99 };
33
+ static int16 s_std_croma_quant[64] = { 17,18,18,24,21,24,47,26,26,47,99,66,56,66,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99 };
34
+ static uint8 s_dc_lum_bits[17] = { 0,0,1,5,1,1,1,1,1,1,0,0,0,0,0,0,0 };
35
+ static uint8 s_dc_lum_val[DC_LUM_CODES] = { 0,1,2,3,4,5,6,7,8,9,10,11 };
36
+ static uint8 s_ac_lum_bits[17] = { 0,0,2,1,3,3,2,4,3,5,5,4,4,0,0,1,0x7d };
37
+ static uint8 s_ac_lum_val[AC_LUM_CODES] =
38
+ {
39
+ 0x01,0x02,0x03,0x00,0x04,0x11,0x05,0x12,0x21,0x31,0x41,0x06,0x13,0x51,0x61,0x07,0x22,0x71,0x14,0x32,0x81,0x91,0xa1,0x08,0x23,0x42,0xb1,0xc1,0x15,0x52,0xd1,0xf0,
40
+ 0x24,0x33,0x62,0x72,0x82,0x09,0x0a,0x16,0x17,0x18,0x19,0x1a,0x25,0x26,0x27,0x28,0x29,0x2a,0x34,0x35,0x36,0x37,0x38,0x39,0x3a,0x43,0x44,0x45,0x46,0x47,0x48,0x49,
41
+ 0x4a,0x53,0x54,0x55,0x56,0x57,0x58,0x59,0x5a,0x63,0x64,0x65,0x66,0x67,0x68,0x69,0x6a,0x73,0x74,0x75,0x76,0x77,0x78,0x79,0x7a,0x83,0x84,0x85,0x86,0x87,0x88,0x89,
42
+ 0x8a,0x92,0x93,0x94,0x95,0x96,0x97,0x98,0x99,0x9a,0xa2,0xa3,0xa4,0xa5,0xa6,0xa7,0xa8,0xa9,0xaa,0xb2,0xb3,0xb4,0xb5,0xb6,0xb7,0xb8,0xb9,0xba,0xc2,0xc3,0xc4,0xc5,
43
+ 0xc6,0xc7,0xc8,0xc9,0xca,0xd2,0xd3,0xd4,0xd5,0xd6,0xd7,0xd8,0xd9,0xda,0xe1,0xe2,0xe3,0xe4,0xe5,0xe6,0xe7,0xe8,0xe9,0xea,0xf1,0xf2,0xf3,0xf4,0xf5,0xf6,0xf7,0xf8,
44
+ 0xf9,0xfa
45
+ };
46
+ static uint8 s_dc_chroma_bits[17] = { 0,0,3,1,1,1,1,1,1,1,1,1,0,0,0,0,0 };
47
+ static uint8 s_dc_chroma_val[DC_CHROMA_CODES] = { 0,1,2,3,4,5,6,7,8,9,10,11 };
48
+ static uint8 s_ac_chroma_bits[17] = { 0,0,2,1,2,4,4,3,4,7,5,4,4,0,1,2,0x77 };
49
+ static uint8 s_ac_chroma_val[AC_CHROMA_CODES] =
50
+ {
51
+ 0x00,0x01,0x02,0x03,0x11,0x04,0x05,0x21,0x31,0x06,0x12,0x41,0x51,0x07,0x61,0x71,0x13,0x22,0x32,0x81,0x08,0x14,0x42,0x91,0xa1,0xb1,0xc1,0x09,0x23,0x33,0x52,0xf0,
52
+ 0x15,0x62,0x72,0xd1,0x0a,0x16,0x24,0x34,0xe1,0x25,0xf1,0x17,0x18,0x19,0x1a,0x26,0x27,0x28,0x29,0x2a,0x35,0x36,0x37,0x38,0x39,0x3a,0x43,0x44,0x45,0x46,0x47,0x48,
53
+ 0x49,0x4a,0x53,0x54,0x55,0x56,0x57,0x58,0x59,0x5a,0x63,0x64,0x65,0x66,0x67,0x68,0x69,0x6a,0x73,0x74,0x75,0x76,0x77,0x78,0x79,0x7a,0x82,0x83,0x84,0x85,0x86,0x87,
54
+ 0x88,0x89,0x8a,0x92,0x93,0x94,0x95,0x96,0x97,0x98,0x99,0x9a,0xa2,0xa3,0xa4,0xa5,0xa6,0xa7,0xa8,0xa9,0xaa,0xb2,0xb3,0xb4,0xb5,0xb6,0xb7,0xb8,0xb9,0xba,0xc2,0xc3,
55
+ 0xc4,0xc5,0xc6,0xc7,0xc8,0xc9,0xca,0xd2,0xd3,0xd4,0xd5,0xd6,0xd7,0xd8,0xd9,0xda,0xe2,0xe3,0xe4,0xe5,0xe6,0xe7,0xe8,0xe9,0xea,0xf2,0xf3,0xf4,0xf5,0xf6,0xf7,0xf8,
56
+ 0xf9,0xfa
57
+ };
58
+
59
+ // Low-level helper functions.
60
+ template <class T> inline void clear_obj(T &obj) { memset(&obj, 0, sizeof(obj)); }
61
+
62
+ const int YR = 19595, YG = 38470, YB = 7471, CB_R = -11059, CB_G = -21709, CB_B = 32768, CR_R = 32768, CR_G = -27439, CR_B = -5329;
63
+ static inline uint8 clamp(int i) { if (static_cast<uint>(i) > 255U) { if (i < 0) i = 0; else if (i > 255) i = 255; } return static_cast<uint8>(i); }
64
+
65
+ static void RGB_to_YCC(uint8* pDst, const uint8 *pSrc, int num_pixels)
66
+ {
67
+ for ( ; num_pixels; pDst += 3, pSrc += 3, num_pixels--)
68
+ {
69
+ const int r = pSrc[0], g = pSrc[1], b = pSrc[2];
70
+ pDst[0] = static_cast<uint8>((r * YR + g * YG + b * YB + 32768) >> 16);
71
+ pDst[1] = clamp(128 + ((r * CB_R + g * CB_G + b * CB_B + 32768) >> 16));
72
+ pDst[2] = clamp(128 + ((r * CR_R + g * CR_G + b * CR_B + 32768) >> 16));
73
+ }
74
+ }
75
+
76
+ static void RGB_to_Y(uint8* pDst, const uint8 *pSrc, int num_pixels)
77
+ {
78
+ for ( ; num_pixels; pDst++, pSrc += 3, num_pixels--)
79
+ pDst[0] = static_cast<uint8>((pSrc[0] * YR + pSrc[1] * YG + pSrc[2] * YB + 32768) >> 16);
80
+ }
81
+
82
+ static void RGBA_to_YCC(uint8* pDst, const uint8 *pSrc, int num_pixels)
83
+ {
84
+ for ( ; num_pixels; pDst += 3, pSrc += 4, num_pixels--)
85
+ {
86
+ const int r = pSrc[0], g = pSrc[1], b = pSrc[2];
87
+ pDst[0] = static_cast<uint8>((r * YR + g * YG + b * YB + 32768) >> 16);
88
+ pDst[1] = clamp(128 + ((r * CB_R + g * CB_G + b * CB_B + 32768) >> 16));
89
+ pDst[2] = clamp(128 + ((r * CR_R + g * CR_G + b * CR_B + 32768) >> 16));
90
+ }
91
+ }
92
+
93
+ static void RGBA_to_Y(uint8* pDst, const uint8 *pSrc, int num_pixels)
94
+ {
95
+ for ( ; num_pixels; pDst++, pSrc += 4, num_pixels--)
96
+ pDst[0] = static_cast<uint8>((pSrc[0] * YR + pSrc[1] * YG + pSrc[2] * YB + 32768) >> 16);
97
+ }
98
+
99
+ static void Y_to_YCC(uint8* pDst, const uint8* pSrc, int num_pixels)
100
+ {
101
+ for( ; num_pixels; pDst += 3, pSrc++, num_pixels--) { pDst[0] = pSrc[0]; pDst[1] = 128; pDst[2] = 128; }
102
+ }
103
+
104
+ // Forward DCT - DCT derived from jfdctint.
105
+ #define CONST_BITS 13
106
+ #define ROW_BITS 2
107
+ #define DCT_DESCALE(x, n) (((x) + (((int32)1) << ((n) - 1))) >> (n))
108
+ #define DCT_MUL(var, c) (static_cast<int16>(var) * static_cast<int32>(c))
109
+ #define DCT1D(s0, s1, s2, s3, s4, s5, s6, s7) \
110
+ int32 t0 = s0 + s7, t7 = s0 - s7, t1 = s1 + s6, t6 = s1 - s6, t2 = s2 + s5, t5 = s2 - s5, t3 = s3 + s4, t4 = s3 - s4; \
111
+ int32 t10 = t0 + t3, t13 = t0 - t3, t11 = t1 + t2, t12 = t1 - t2; \
112
+ int32 u1 = DCT_MUL(t12 + t13, 4433); \
113
+ s2 = u1 + DCT_MUL(t13, 6270); \
114
+ s6 = u1 + DCT_MUL(t12, -15137); \
115
+ u1 = t4 + t7; \
116
+ int32 u2 = t5 + t6, u3 = t4 + t6, u4 = t5 + t7; \
117
+ int32 z5 = DCT_MUL(u3 + u4, 9633); \
118
+ t4 = DCT_MUL(t4, 2446); t5 = DCT_MUL(t5, 16819); \
119
+ t6 = DCT_MUL(t6, 25172); t7 = DCT_MUL(t7, 12299); \
120
+ u1 = DCT_MUL(u1, -7373); u2 = DCT_MUL(u2, -20995); \
121
+ u3 = DCT_MUL(u3, -16069); u4 = DCT_MUL(u4, -3196); \
122
+ u3 += z5; u4 += z5; \
123
+ s0 = t10 + t11; s1 = t7 + u1 + u4; s3 = t6 + u2 + u3; s4 = t10 - t11; s5 = t5 + u2 + u4; s7 = t4 + u1 + u3;
124
+
125
+ static void DCT2D(int32 *p)
126
+ {
127
+ int32 c, *q = p;
128
+ for (c = 7; c >= 0; c--, q += 8)
129
+ {
130
+ int32 s0 = q[0], s1 = q[1], s2 = q[2], s3 = q[3], s4 = q[4], s5 = q[5], s6 = q[6], s7 = q[7];
131
+ DCT1D(s0, s1, s2, s3, s4, s5, s6, s7);
132
+ q[0] = s0 << ROW_BITS; q[1] = DCT_DESCALE(s1, CONST_BITS-ROW_BITS); q[2] = DCT_DESCALE(s2, CONST_BITS-ROW_BITS); q[3] = DCT_DESCALE(s3, CONST_BITS-ROW_BITS);
133
+ q[4] = s4 << ROW_BITS; q[5] = DCT_DESCALE(s5, CONST_BITS-ROW_BITS); q[6] = DCT_DESCALE(s6, CONST_BITS-ROW_BITS); q[7] = DCT_DESCALE(s7, CONST_BITS-ROW_BITS);
134
+ }
135
+ for (q = p, c = 7; c >= 0; c--, q++)
136
+ {
137
+ int32 s0 = q[0*8], s1 = q[1*8], s2 = q[2*8], s3 = q[3*8], s4 = q[4*8], s5 = q[5*8], s6 = q[6*8], s7 = q[7*8];
138
+ DCT1D(s0, s1, s2, s3, s4, s5, s6, s7);
139
+ q[0*8] = DCT_DESCALE(s0, ROW_BITS+3); q[1*8] = DCT_DESCALE(s1, CONST_BITS+ROW_BITS+3); q[2*8] = DCT_DESCALE(s2, CONST_BITS+ROW_BITS+3); q[3*8] = DCT_DESCALE(s3, CONST_BITS+ROW_BITS+3);
140
+ q[4*8] = DCT_DESCALE(s4, ROW_BITS+3); q[5*8] = DCT_DESCALE(s5, CONST_BITS+ROW_BITS+3); q[6*8] = DCT_DESCALE(s6, CONST_BITS+ROW_BITS+3); q[7*8] = DCT_DESCALE(s7, CONST_BITS+ROW_BITS+3);
141
+ }
142
+ }
143
+
144
+ struct sym_freq { uint m_key, m_sym_index; };
145
+
146
+ // Radix sorts sym_freq[] array by 32-bit key m_key. Returns ptr to sorted values.
147
+ static inline sym_freq* radix_sort_syms(uint num_syms, sym_freq* pSyms0, sym_freq* pSyms1)
148
+ {
149
+ const uint cMaxPasses = 4;
150
+ uint32 hist[256 * cMaxPasses]; clear_obj(hist);
151
+ for (uint i = 0; i < num_syms; i++) { uint freq = pSyms0[i].m_key; hist[freq & 0xFF]++; hist[256 + ((freq >> 8) & 0xFF)]++; hist[256*2 + ((freq >> 16) & 0xFF)]++; hist[256*3 + ((freq >> 24) & 0xFF)]++; }
152
+ sym_freq* pCur_syms = pSyms0, *pNew_syms = pSyms1;
153
+ uint total_passes = cMaxPasses; while ((total_passes > 1) && (num_syms == hist[(total_passes - 1) * 256])) total_passes--;
154
+ for (uint pass_shift = 0, pass = 0; pass < total_passes; pass++, pass_shift += 8)
155
+ {
156
+ const uint32* pHist = &hist[pass << 8];
157
+ uint offsets[256], cur_ofs = 0;
158
+ for (uint i = 0; i < 256; i++) { offsets[i] = cur_ofs; cur_ofs += pHist[i]; }
159
+ for (uint i = 0; i < num_syms; i++)
160
+ pNew_syms[offsets[(pCur_syms[i].m_key >> pass_shift) & 0xFF]++] = pCur_syms[i];
161
+ sym_freq* t = pCur_syms; pCur_syms = pNew_syms; pNew_syms = t;
162
+ }
163
+ return pCur_syms;
164
+ }
165
+
166
+ // calculate_minimum_redundancy() originally written by: Alistair Moffat, [email protected], Jyrki Katajainen, [email protected], November 1996.
167
+ static void calculate_minimum_redundancy(sym_freq *A, int n)
168
+ {
169
+ int root, leaf, next, avbl, used, dpth;
170
+ if (n==0) return; else if (n==1) { A[0].m_key = 1; return; }
171
+ A[0].m_key += A[1].m_key; root = 0; leaf = 2;
172
+ for (next=1; next < n-1; next++)
173
+ {
174
+ if (leaf>=n || A[root].m_key<A[leaf].m_key) { A[next].m_key = A[root].m_key; A[root++].m_key = next; } else A[next].m_key = A[leaf++].m_key;
175
+ if (leaf>=n || (root<next && A[root].m_key<A[leaf].m_key)) { A[next].m_key += A[root].m_key; A[root++].m_key = next; } else A[next].m_key += A[leaf++].m_key;
176
+ }
177
+ A[n-2].m_key = 0;
178
+ for (next=n-3; next>=0; next--) A[next].m_key = A[A[next].m_key].m_key+1;
179
+ avbl = 1; used = dpth = 0; root = n-2; next = n-1;
180
+ while (avbl>0)
181
+ {
182
+ while (root>=0 && (int)A[root].m_key==dpth) { used++; root--; }
183
+ while (avbl>used) { A[next--].m_key = dpth; avbl--; }
184
+ avbl = 2*used; dpth++; used = 0;
185
+ }
186
+ }
187
+
188
+ // Limits canonical Huffman code table's max code size to max_code_size.
189
+ static void huffman_enforce_max_code_size(int *pNum_codes, int code_list_len, int max_code_size)
190
+ {
191
+ if (code_list_len <= 1) return;
192
+
193
+ for (int i = max_code_size + 1; i <= MAX_HUFF_CODESIZE; i++) pNum_codes[max_code_size] += pNum_codes[i];
194
+
195
+ uint32 total = 0;
196
+ for (int i = max_code_size; i > 0; i--)
197
+ total += (((uint32)pNum_codes[i]) << (max_code_size - i));
198
+
199
+ while (total != (1UL << max_code_size))
200
+ {
201
+ pNum_codes[max_code_size]--;
202
+ for (int i = max_code_size - 1; i > 0; i--)
203
+ {
204
+ if (pNum_codes[i]) { pNum_codes[i]--; pNum_codes[i + 1] += 2; break; }
205
+ }
206
+ total--;
207
+ }
208
+ }
209
+
210
+ // Generates an optimized Huffman table.
211
+ void jpeg_encoder::optimize_huffman_table(int table_num, int table_len)
212
+ {
213
+ sym_freq syms0[MAX_HUFF_SYMBOLS], syms1[MAX_HUFF_SYMBOLS];
214
+ syms0[0].m_key = 1; syms0[0].m_sym_index = 0; // dummy symbol, assures that no valid code contains all 1's
215
+ int num_used_syms = 1;
216
+ const uint32 *pSym_count = &m_huff_count[table_num][0];
217
+ for (int i = 0; i < table_len; i++)
218
+ if (pSym_count[i]) { syms0[num_used_syms].m_key = pSym_count[i]; syms0[num_used_syms++].m_sym_index = i + 1; }
219
+ sym_freq* pSyms = radix_sort_syms(num_used_syms, syms0, syms1);
220
+ calculate_minimum_redundancy(pSyms, num_used_syms);
221
+
222
+ // Count the # of symbols of each code size.
223
+ int num_codes[1 + MAX_HUFF_CODESIZE]; clear_obj(num_codes);
224
+ for (int i = 0; i < num_used_syms; i++)
225
+ num_codes[pSyms[i].m_key]++;
226
+
227
+ const uint JPGE_CODE_SIZE_LIMIT = 16; // the maximum possible size of a JPEG Huffman code (valid range is [9,16] - 9 vs. 8 because of the dummy symbol)
228
+ huffman_enforce_max_code_size(num_codes, num_used_syms, JPGE_CODE_SIZE_LIMIT);
229
+
230
+ // Compute m_huff_bits array, which contains the # of symbols per code size.
231
+ clear_obj(m_huff_bits[table_num]);
232
+ for (int i = 1; i <= (int)JPGE_CODE_SIZE_LIMIT; i++)
233
+ m_huff_bits[table_num][i] = static_cast<uint8>(num_codes[i]);
234
+
235
+ // Remove the dummy symbol added above, which must be in largest bucket.
236
+ for (int i = JPGE_CODE_SIZE_LIMIT; i >= 1; i--)
237
+ {
238
+ if (m_huff_bits[table_num][i]) { m_huff_bits[table_num][i]--; break; }
239
+ }
240
+
241
+ // Compute the m_huff_val array, which contains the symbol indices sorted by code size (smallest to largest).
242
+ for (int i = num_used_syms - 1; i >= 1; i--)
243
+ m_huff_val[table_num][num_used_syms - 1 - i] = static_cast<uint8>(pSyms[i].m_sym_index - 1);
244
+ }
245
+
246
+ // JPEG marker generation.
247
+ void jpeg_encoder::emit_byte(uint8 i)
248
+ {
249
+ m_all_stream_writes_succeeded = m_all_stream_writes_succeeded && m_pStream->put_obj(i);
250
+ }
251
+
252
+ void jpeg_encoder::emit_word(uint i)
253
+ {
254
+ emit_byte(uint8(i >> 8)); emit_byte(uint8(i & 0xFF));
255
+ }
256
+
257
+ void jpeg_encoder::emit_marker(int marker)
258
+ {
259
+ emit_byte(uint8(0xFF)); emit_byte(uint8(marker));
260
+ }
261
+
262
+ // Emit JFIF marker
263
+ void jpeg_encoder::emit_jfif_app0()
264
+ {
265
+ emit_marker(M_APP0);
266
+ emit_word(2 + 4 + 1 + 2 + 1 + 2 + 2 + 1 + 1);
267
+ emit_byte(0x4A); emit_byte(0x46); emit_byte(0x49); emit_byte(0x46); /* Identifier: ASCII "JFIF" */
268
+ emit_byte(0);
269
+ emit_byte(1); /* Major version */
270
+ emit_byte(1); /* Minor version */
271
+ emit_byte(0); /* Density unit */
272
+ emit_word(1);
273
+ emit_word(1);
274
+ emit_byte(0); /* No thumbnail image */
275
+ emit_byte(0);
276
+ }
277
+
278
+ // Emit quantization tables
279
+ void jpeg_encoder::emit_dqt()
280
+ {
281
+ for (int i = 0; i < ((m_num_components == 3) ? 2 : 1); i++)
282
+ {
283
+ emit_marker(M_DQT);
284
+ emit_word(64 + 1 + 2);
285
+ emit_byte(static_cast<uint8>(i));
286
+ for (int j = 0; j < 64; j++)
287
+ emit_byte(static_cast<uint8>(m_quantization_tables[i][j]));
288
+ }
289
+ }
290
+
291
+ // Emit start of frame marker
292
+ void jpeg_encoder::emit_sof()
293
+ {
294
+ emit_marker(M_SOF0); /* baseline */
295
+ emit_word(3 * m_num_components + 2 + 5 + 1);
296
+ emit_byte(8); /* precision */
297
+ emit_word(m_image_y);
298
+ emit_word(m_image_x);
299
+ emit_byte(m_num_components);
300
+ for (int i = 0; i < m_num_components; i++)
301
+ {
302
+ emit_byte(static_cast<uint8>(i + 1)); /* component ID */
303
+ emit_byte((m_comp_h_samp[i] << 4) + m_comp_v_samp[i]); /* h and v sampling */
304
+ emit_byte(i > 0); /* quant. table num */
305
+ }
306
+ }
307
+
308
+ // Emit Huffman table.
309
+ void jpeg_encoder::emit_dht(uint8 *bits, uint8 *val, int index, bool ac_flag)
310
+ {
311
+ emit_marker(M_DHT);
312
+
313
+ int length = 0;
314
+ for (int i = 1; i <= 16; i++)
315
+ length += bits[i];
316
+
317
+ emit_word(length + 2 + 1 + 16);
318
+ emit_byte(static_cast<uint8>(index + (ac_flag << 4)));
319
+
320
+ for (int i = 1; i <= 16; i++)
321
+ emit_byte(bits[i]);
322
+
323
+ for (int i = 0; i < length; i++)
324
+ emit_byte(val[i]);
325
+ }
326
+
327
+ // Emit all Huffman tables.
328
+ void jpeg_encoder::emit_dhts()
329
+ {
330
+ emit_dht(m_huff_bits[0+0], m_huff_val[0+0], 0, false);
331
+ emit_dht(m_huff_bits[2+0], m_huff_val[2+0], 0, true);
332
+ if (m_num_components == 3)
333
+ {
334
+ emit_dht(m_huff_bits[0+1], m_huff_val[0+1], 1, false);
335
+ emit_dht(m_huff_bits[2+1], m_huff_val[2+1], 1, true);
336
+ }
337
+ }
338
+
339
+ // emit start of scan
340
+ void jpeg_encoder::emit_sos()
341
+ {
342
+ emit_marker(M_SOS);
343
+ emit_word(2 * m_num_components + 2 + 1 + 3);
344
+ emit_byte(m_num_components);
345
+ for (int i = 0; i < m_num_components; i++)
346
+ {
347
+ emit_byte(static_cast<uint8>(i + 1));
348
+ if (i == 0)
349
+ emit_byte((0 << 4) + 0);
350
+ else
351
+ emit_byte((1 << 4) + 1);
352
+ }
353
+ emit_byte(0); /* spectral selection */
354
+ emit_byte(63);
355
+ emit_byte(0);
356
+ }
357
+
358
+ // Emit all markers at beginning of image file.
359
+ void jpeg_encoder::emit_markers()
360
+ {
361
+ emit_marker(M_SOI);
362
+ emit_jfif_app0();
363
+ emit_dqt();
364
+ emit_sof();
365
+ emit_dhts();
366
+ emit_sos();
367
+ }
368
+
369
+ // Compute the actual canonical Huffman codes/code sizes given the JPEG huff bits and val arrays.
370
+ void jpeg_encoder::compute_huffman_table(uint *codes, uint8 *code_sizes, uint8 *bits, uint8 *val)
371
+ {
372
+ int i, l, last_p, si;
373
+ uint8 huff_size[257];
374
+ uint huff_code[257];
375
+ uint code;
376
+
377
+ int p = 0;
378
+ for (l = 1; l <= 16; l++)
379
+ for (i = 1; i <= bits[l]; i++)
380
+ huff_size[p++] = (char)l;
381
+
382
+ huff_size[p] = 0; last_p = p; // write sentinel
383
+
384
+ code = 0; si = huff_size[0]; p = 0;
385
+
386
+ while (huff_size[p])
387
+ {
388
+ while (huff_size[p] == si)
389
+ huff_code[p++] = code++;
390
+ code <<= 1;
391
+ si++;
392
+ }
393
+
394
+ memset(codes, 0, sizeof(codes[0])*256);
395
+ memset(code_sizes, 0, sizeof(code_sizes[0])*256);
396
+ for (p = 0; p < last_p; p++)
397
+ {
398
+ codes[val[p]] = huff_code[p];
399
+ code_sizes[val[p]] = huff_size[p];
400
+ }
401
+ }
402
+
403
+ // Quantization table generation.
404
+ void jpeg_encoder::compute_quant_table(int32 *pDst, int16 *pSrc)
405
+ {
406
+ int32 q;
407
+ if (m_params.m_quality < 50)
408
+ q = 5000 / m_params.m_quality;
409
+ else
410
+ q = 200 - m_params.m_quality * 2;
411
+ for (int i = 0; i < 64; i++)
412
+ {
413
+ int32 j = *pSrc++; j = (j * q + 50L) / 100L;
414
+ *pDst++ = JPGE_MIN(JPGE_MAX(j, 1), 255);
415
+ }
416
+ }
417
+
418
+ // Higher-level methods.
419
+ void jpeg_encoder::first_pass_init()
420
+ {
421
+ m_bit_buffer = 0; m_bits_in = 0;
422
+ memset(m_last_dc_val, 0, 3 * sizeof(m_last_dc_val[0]));
423
+ m_mcu_y_ofs = 0;
424
+ m_pass_num = 1;
425
+ }
426
+
427
+ bool jpeg_encoder::second_pass_init()
428
+ {
429
+ compute_huffman_table(&m_huff_codes[0+0][0], &m_huff_code_sizes[0+0][0], m_huff_bits[0+0], m_huff_val[0+0]);
430
+ compute_huffman_table(&m_huff_codes[2+0][0], &m_huff_code_sizes[2+0][0], m_huff_bits[2+0], m_huff_val[2+0]);
431
+ if (m_num_components > 1)
432
+ {
433
+ compute_huffman_table(&m_huff_codes[0+1][0], &m_huff_code_sizes[0+1][0], m_huff_bits[0+1], m_huff_val[0+1]);
434
+ compute_huffman_table(&m_huff_codes[2+1][0], &m_huff_code_sizes[2+1][0], m_huff_bits[2+1], m_huff_val[2+1]);
435
+ }
436
+ first_pass_init();
437
+ emit_markers();
438
+ m_pass_num = 2;
439
+ return true;
440
+ }
441
+
442
+ bool jpeg_encoder::jpg_open(int p_x_res, int p_y_res, int src_channels)
443
+ {
444
+ m_num_components = 3;
445
+ switch (m_params.m_subsampling)
446
+ {
447
+ case Y_ONLY:
448
+ {
449
+ m_num_components = 1;
450
+ m_comp_h_samp[0] = 1; m_comp_v_samp[0] = 1;
451
+ m_mcu_x = 8; m_mcu_y = 8;
452
+ break;
453
+ }
454
+ case H1V1:
455
+ {
456
+ m_comp_h_samp[0] = 1; m_comp_v_samp[0] = 1;
457
+ m_comp_h_samp[1] = 1; m_comp_v_samp[1] = 1;
458
+ m_comp_h_samp[2] = 1; m_comp_v_samp[2] = 1;
459
+ m_mcu_x = 8; m_mcu_y = 8;
460
+ break;
461
+ }
462
+ case H2V1:
463
+ {
464
+ m_comp_h_samp[0] = 2; m_comp_v_samp[0] = 1;
465
+ m_comp_h_samp[1] = 1; m_comp_v_samp[1] = 1;
466
+ m_comp_h_samp[2] = 1; m_comp_v_samp[2] = 1;
467
+ m_mcu_x = 16; m_mcu_y = 8;
468
+ break;
469
+ }
470
+ case H2V2:
471
+ {
472
+ m_comp_h_samp[0] = 2; m_comp_v_samp[0] = 2;
473
+ m_comp_h_samp[1] = 1; m_comp_v_samp[1] = 1;
474
+ m_comp_h_samp[2] = 1; m_comp_v_samp[2] = 1;
475
+ m_mcu_x = 16; m_mcu_y = 16;
476
+ }
477
+ }
478
+
479
+ m_image_x = p_x_res; m_image_y = p_y_res;
480
+ m_image_bpp = src_channels;
481
+ m_image_bpl = m_image_x * src_channels;
482
+ m_image_x_mcu = (m_image_x + m_mcu_x - 1) & (~(m_mcu_x - 1));
483
+ m_image_y_mcu = (m_image_y + m_mcu_y - 1) & (~(m_mcu_y - 1));
484
+ m_image_bpl_xlt = m_image_x * m_num_components;
485
+ m_image_bpl_mcu = m_image_x_mcu * m_num_components;
486
+ m_mcus_per_row = m_image_x_mcu / m_mcu_x;
487
+
488
+ if ((m_mcu_lines[0] = static_cast<uint8*>(jpge_malloc(m_image_bpl_mcu * m_mcu_y))) == NULL) return false;
489
+ for (int i = 1; i < m_mcu_y; i++)
490
+ m_mcu_lines[i] = m_mcu_lines[i-1] + m_image_bpl_mcu;
491
+
492
+ compute_quant_table(m_quantization_tables[0], s_std_lum_quant);
493
+ compute_quant_table(m_quantization_tables[1], m_params.m_no_chroma_discrim_flag ? s_std_lum_quant : s_std_croma_quant);
494
+
495
+ m_out_buf_left = JPGE_OUT_BUF_SIZE;
496
+ m_pOut_buf = m_out_buf;
497
+
498
+ if (m_params.m_two_pass_flag)
499
+ {
500
+ clear_obj(m_huff_count);
501
+ first_pass_init();
502
+ }
503
+ else
504
+ {
505
+ memcpy(m_huff_bits[0+0], s_dc_lum_bits, 17); memcpy(m_huff_val [0+0], s_dc_lum_val, DC_LUM_CODES);
506
+ memcpy(m_huff_bits[2+0], s_ac_lum_bits, 17); memcpy(m_huff_val [2+0], s_ac_lum_val, AC_LUM_CODES);
507
+ memcpy(m_huff_bits[0+1], s_dc_chroma_bits, 17); memcpy(m_huff_val [0+1], s_dc_chroma_val, DC_CHROMA_CODES);
508
+ memcpy(m_huff_bits[2+1], s_ac_chroma_bits, 17); memcpy(m_huff_val [2+1], s_ac_chroma_val, AC_CHROMA_CODES);
509
+ if (!second_pass_init()) return false; // in effect, skip over the first pass
510
+ }
511
+ return m_all_stream_writes_succeeded;
512
+ }
513
+
514
+ void jpeg_encoder::load_block_8_8_grey(int x)
515
+ {
516
+ uint8 *pSrc;
517
+ sample_array_t *pDst = m_sample_array;
518
+ x <<= 3;
519
+ for (int i = 0; i < 8; i++, pDst += 8)
520
+ {
521
+ pSrc = m_mcu_lines[i] + x;
522
+ pDst[0] = pSrc[0] - 128; pDst[1] = pSrc[1] - 128; pDst[2] = pSrc[2] - 128; pDst[3] = pSrc[3] - 128;
523
+ pDst[4] = pSrc[4] - 128; pDst[5] = pSrc[5] - 128; pDst[6] = pSrc[6] - 128; pDst[7] = pSrc[7] - 128;
524
+ }
525
+ }
526
+
527
+ void jpeg_encoder::load_block_8_8(int x, int y, int c)
528
+ {
529
+ uint8 *pSrc;
530
+ sample_array_t *pDst = m_sample_array;
531
+ x = (x * (8 * 3)) + c;
532
+ y <<= 3;
533
+ for (int i = 0; i < 8; i++, pDst += 8)
534
+ {
535
+ pSrc = m_mcu_lines[y + i] + x;
536
+ pDst[0] = pSrc[0 * 3] - 128; pDst[1] = pSrc[1 * 3] - 128; pDst[2] = pSrc[2 * 3] - 128; pDst[3] = pSrc[3 * 3] - 128;
537
+ pDst[4] = pSrc[4 * 3] - 128; pDst[5] = pSrc[5 * 3] - 128; pDst[6] = pSrc[6 * 3] - 128; pDst[7] = pSrc[7 * 3] - 128;
538
+ }
539
+ }
540
+
541
+ void jpeg_encoder::load_block_16_8(int x, int c)
542
+ {
543
+ uint8 *pSrc1, *pSrc2;
544
+ sample_array_t *pDst = m_sample_array;
545
+ x = (x * (16 * 3)) + c;
546
+ int a = 0, b = 2;
547
+ for (int i = 0; i < 16; i += 2, pDst += 8)
548
+ {
549
+ pSrc1 = m_mcu_lines[i + 0] + x;
550
+ pSrc2 = m_mcu_lines[i + 1] + x;
551
+ pDst[0] = ((pSrc1[ 0 * 3] + pSrc1[ 1 * 3] + pSrc2[ 0 * 3] + pSrc2[ 1 * 3] + a) >> 2) - 128; pDst[1] = ((pSrc1[ 2 * 3] + pSrc1[ 3 * 3] + pSrc2[ 2 * 3] + pSrc2[ 3 * 3] + b) >> 2) - 128;
552
+ pDst[2] = ((pSrc1[ 4 * 3] + pSrc1[ 5 * 3] + pSrc2[ 4 * 3] + pSrc2[ 5 * 3] + a) >> 2) - 128; pDst[3] = ((pSrc1[ 6 * 3] + pSrc1[ 7 * 3] + pSrc2[ 6 * 3] + pSrc2[ 7 * 3] + b) >> 2) - 128;
553
+ pDst[4] = ((pSrc1[ 8 * 3] + pSrc1[ 9 * 3] + pSrc2[ 8 * 3] + pSrc2[ 9 * 3] + a) >> 2) - 128; pDst[5] = ((pSrc1[10 * 3] + pSrc1[11 * 3] + pSrc2[10 * 3] + pSrc2[11 * 3] + b) >> 2) - 128;
554
+ pDst[6] = ((pSrc1[12 * 3] + pSrc1[13 * 3] + pSrc2[12 * 3] + pSrc2[13 * 3] + a) >> 2) - 128; pDst[7] = ((pSrc1[14 * 3] + pSrc1[15 * 3] + pSrc2[14 * 3] + pSrc2[15 * 3] + b) >> 2) - 128;
555
+ int temp = a; a = b; b = temp;
556
+ }
557
+ }
558
+
559
+ void jpeg_encoder::load_block_16_8_8(int x, int c)
560
+ {
561
+ uint8 *pSrc1;
562
+ sample_array_t *pDst = m_sample_array;
563
+ x = (x * (16 * 3)) + c;
564
+ for (int i = 0; i < 8; i++, pDst += 8)
565
+ {
566
+ pSrc1 = m_mcu_lines[i + 0] + x;
567
+ pDst[0] = ((pSrc1[ 0 * 3] + pSrc1[ 1 * 3]) >> 1) - 128; pDst[1] = ((pSrc1[ 2 * 3] + pSrc1[ 3 * 3]) >> 1) - 128;
568
+ pDst[2] = ((pSrc1[ 4 * 3] + pSrc1[ 5 * 3]) >> 1) - 128; pDst[3] = ((pSrc1[ 6 * 3] + pSrc1[ 7 * 3]) >> 1) - 128;
569
+ pDst[4] = ((pSrc1[ 8 * 3] + pSrc1[ 9 * 3]) >> 1) - 128; pDst[5] = ((pSrc1[10 * 3] + pSrc1[11 * 3]) >> 1) - 128;
570
+ pDst[6] = ((pSrc1[12 * 3] + pSrc1[13 * 3]) >> 1) - 128; pDst[7] = ((pSrc1[14 * 3] + pSrc1[15 * 3]) >> 1) - 128;
571
+ }
572
+ }
573
+
574
+ void jpeg_encoder::load_quantized_coefficients(int component_num)
575
+ {
576
+ int32 *q = m_quantization_tables[component_num > 0];
577
+ int16 *pDst = m_coefficient_array;
578
+ for (int i = 0; i < 64; i++)
579
+ {
580
+ sample_array_t j = m_sample_array[s_zag[i]];
581
+ if (j < 0)
582
+ {
583
+ if ((j = -j + (*q >> 1)) < *q)
584
+ *pDst++ = 0;
585
+ else
586
+ *pDst++ = static_cast<int16>(-(j / *q));
587
+ }
588
+ else
589
+ {
590
+ if ((j = j + (*q >> 1)) < *q)
591
+ *pDst++ = 0;
592
+ else
593
+ *pDst++ = static_cast<int16>((j / *q));
594
+ }
595
+ q++;
596
+ }
597
+ }
598
+
599
+ void jpeg_encoder::flush_output_buffer()
600
+ {
601
+ if (m_out_buf_left != JPGE_OUT_BUF_SIZE)
602
+ m_all_stream_writes_succeeded = m_all_stream_writes_succeeded && m_pStream->put_buf(m_out_buf, JPGE_OUT_BUF_SIZE - m_out_buf_left);
603
+ m_pOut_buf = m_out_buf;
604
+ m_out_buf_left = JPGE_OUT_BUF_SIZE;
605
+ }
606
+
607
+ void jpeg_encoder::put_bits(uint bits, uint len)
608
+ {
609
+ m_bit_buffer |= ((uint32)bits << (24 - (m_bits_in += len)));
610
+ while (m_bits_in >= 8)
611
+ {
612
+ uint8 c;
613
+ #define JPGE_PUT_BYTE(c) { *m_pOut_buf++ = (c); if (--m_out_buf_left == 0) flush_output_buffer(); }
614
+ JPGE_PUT_BYTE(c = (uint8)((m_bit_buffer >> 16) & 0xFF));
615
+ if (c == 0xFF) JPGE_PUT_BYTE(0);
616
+ m_bit_buffer <<= 8;
617
+ m_bits_in -= 8;
618
+ }
619
+ }
620
+
621
+ void jpeg_encoder::code_coefficients_pass_one(int component_num)
622
+ {
623
+ if (component_num >= 3) return; // just to shut up static analysis
624
+ int i, run_len, nbits, temp1;
625
+ int16 *src = m_coefficient_array;
626
+ uint32 *dc_count = component_num ? m_huff_count[0 + 1] : m_huff_count[0 + 0], *ac_count = component_num ? m_huff_count[2 + 1] : m_huff_count[2 + 0];
627
+
628
+ temp1 = src[0] - m_last_dc_val[component_num];
629
+ m_last_dc_val[component_num] = src[0];
630
+ if (temp1 < 0) temp1 = -temp1;
631
+
632
+ nbits = 0;
633
+ while (temp1)
634
+ {
635
+ nbits++; temp1 >>= 1;
636
+ }
637
+
638
+ dc_count[nbits]++;
639
+ for (run_len = 0, i = 1; i < 64; i++)
640
+ {
641
+ if ((temp1 = m_coefficient_array[i]) == 0)
642
+ run_len++;
643
+ else
644
+ {
645
+ while (run_len >= 16)
646
+ {
647
+ ac_count[0xF0]++;
648
+ run_len -= 16;
649
+ }
650
+ if (temp1 < 0) temp1 = -temp1;
651
+ nbits = 1;
652
+ while (temp1 >>= 1) nbits++;
653
+ ac_count[(run_len << 4) + nbits]++;
654
+ run_len = 0;
655
+ }
656
+ }
657
+ if (run_len) ac_count[0]++;
658
+ }
659
+
660
+ void jpeg_encoder::code_coefficients_pass_two(int component_num)
661
+ {
662
+ int i, j, run_len, nbits, temp1, temp2;
663
+ int16 *pSrc = m_coefficient_array;
664
+ uint *codes[2];
665
+ uint8 *code_sizes[2];
666
+
667
+ if (component_num == 0)
668
+ {
669
+ codes[0] = m_huff_codes[0 + 0]; codes[1] = m_huff_codes[2 + 0];
670
+ code_sizes[0] = m_huff_code_sizes[0 + 0]; code_sizes[1] = m_huff_code_sizes[2 + 0];
671
+ }
672
+ else
673
+ {
674
+ codes[0] = m_huff_codes[0 + 1]; codes[1] = m_huff_codes[2 + 1];
675
+ code_sizes[0] = m_huff_code_sizes[0 + 1]; code_sizes[1] = m_huff_code_sizes[2 + 1];
676
+ }
677
+
678
+ temp1 = temp2 = pSrc[0] - m_last_dc_val[component_num];
679
+ m_last_dc_val[component_num] = pSrc[0];
680
+
681
+ if (temp1 < 0)
682
+ {
683
+ temp1 = -temp1; temp2--;
684
+ }
685
+
686
+ nbits = 0;
687
+ while (temp1)
688
+ {
689
+ nbits++; temp1 >>= 1;
690
+ }
691
+
692
+ put_bits(codes[0][nbits], code_sizes[0][nbits]);
693
+ if (nbits) put_bits(temp2 & ((1 << nbits) - 1), nbits);
694
+
695
+ for (run_len = 0, i = 1; i < 64; i++)
696
+ {
697
+ if ((temp1 = m_coefficient_array[i]) == 0)
698
+ run_len++;
699
+ else
700
+ {
701
+ while (run_len >= 16)
702
+ {
703
+ put_bits(codes[1][0xF0], code_sizes[1][0xF0]);
704
+ run_len -= 16;
705
+ }
706
+ if ((temp2 = temp1) < 0)
707
+ {
708
+ temp1 = -temp1;
709
+ temp2--;
710
+ }
711
+ nbits = 1;
712
+ while (temp1 >>= 1)
713
+ nbits++;
714
+ j = (run_len << 4) + nbits;
715
+ put_bits(codes[1][j], code_sizes[1][j]);
716
+ put_bits(temp2 & ((1 << nbits) - 1), nbits);
717
+ run_len = 0;
718
+ }
719
+ }
720
+ if (run_len)
721
+ put_bits(codes[1][0], code_sizes[1][0]);
722
+ }
723
+
724
+ void jpeg_encoder::code_block(int component_num)
725
+ {
726
+ DCT2D(m_sample_array);
727
+ load_quantized_coefficients(component_num);
728
+ if (m_pass_num == 1)
729
+ code_coefficients_pass_one(component_num);
730
+ else
731
+ code_coefficients_pass_two(component_num);
732
+ }
733
+
734
+ void jpeg_encoder::process_mcu_row()
735
+ {
736
+ if (m_num_components == 1)
737
+ {
738
+ for (int i = 0; i < m_mcus_per_row; i++)
739
+ {
740
+ load_block_8_8_grey(i); code_block(0);
741
+ }
742
+ }
743
+ else if ((m_comp_h_samp[0] == 1) && (m_comp_v_samp[0] == 1))
744
+ {
745
+ for (int i = 0; i < m_mcus_per_row; i++)
746
+ {
747
+ load_block_8_8(i, 0, 0); code_block(0); load_block_8_8(i, 0, 1); code_block(1); load_block_8_8(i, 0, 2); code_block(2);
748
+ }
749
+ }
750
+ else if ((m_comp_h_samp[0] == 2) && (m_comp_v_samp[0] == 1))
751
+ {
752
+ for (int i = 0; i < m_mcus_per_row; i++)
753
+ {
754
+ load_block_8_8(i * 2 + 0, 0, 0); code_block(0); load_block_8_8(i * 2 + 1, 0, 0); code_block(0);
755
+ load_block_16_8_8(i, 1); code_block(1); load_block_16_8_8(i, 2); code_block(2);
756
+ }
757
+ }
758
+ else if ((m_comp_h_samp[0] == 2) && (m_comp_v_samp[0] == 2))
759
+ {
760
+ for (int i = 0; i < m_mcus_per_row; i++)
761
+ {
762
+ load_block_8_8(i * 2 + 0, 0, 0); code_block(0); load_block_8_8(i * 2 + 1, 0, 0); code_block(0);
763
+ load_block_8_8(i * 2 + 0, 1, 0); code_block(0); load_block_8_8(i * 2 + 1, 1, 0); code_block(0);
764
+ load_block_16_8(i, 1); code_block(1); load_block_16_8(i, 2); code_block(2);
765
+ }
766
+ }
767
+ }
768
+
769
+ bool jpeg_encoder::terminate_pass_one()
770
+ {
771
+ optimize_huffman_table(0+0, DC_LUM_CODES); optimize_huffman_table(2+0, AC_LUM_CODES);
772
+ if (m_num_components > 1)
773
+ {
774
+ optimize_huffman_table(0+1, DC_CHROMA_CODES); optimize_huffman_table(2+1, AC_CHROMA_CODES);
775
+ }
776
+ return second_pass_init();
777
+ }
778
+
779
+ bool jpeg_encoder::terminate_pass_two()
780
+ {
781
+ put_bits(0x7F, 7);
782
+ flush_output_buffer();
783
+ emit_marker(M_EOI);
784
+ m_pass_num++; // purposely bump up m_pass_num, for debugging
785
+ return true;
786
+ }
787
+
788
+ bool jpeg_encoder::process_end_of_image()
789
+ {
790
+ if (m_mcu_y_ofs)
791
+ {
792
+ if (m_mcu_y_ofs < 16) // check here just to shut up static analysis
793
+ {
794
+ for (int i = m_mcu_y_ofs; i < m_mcu_y; i++)
795
+ memcpy(m_mcu_lines[i], m_mcu_lines[m_mcu_y_ofs - 1], m_image_bpl_mcu);
796
+ }
797
+
798
+ process_mcu_row();
799
+ }
800
+
801
+ if (m_pass_num == 1)
802
+ return terminate_pass_one();
803
+ else
804
+ return terminate_pass_two();
805
+ }
806
+
807
+ void jpeg_encoder::load_mcu(const void *pSrc)
808
+ {
809
+ const uint8* Psrc = reinterpret_cast<const uint8*>(pSrc);
810
+
811
+ uint8* pDst = m_mcu_lines[m_mcu_y_ofs]; // OK to write up to m_image_bpl_xlt bytes to pDst
812
+
813
+ if (m_num_components == 1)
814
+ {
815
+ if (m_image_bpp == 4)
816
+ RGBA_to_Y(pDst, Psrc, m_image_x);
817
+ else if (m_image_bpp == 3)
818
+ RGB_to_Y(pDst, Psrc, m_image_x);
819
+ else
820
+ memcpy(pDst, Psrc, m_image_x);
821
+ }
822
+ else
823
+ {
824
+ if (m_image_bpp == 4)
825
+ RGBA_to_YCC(pDst, Psrc, m_image_x);
826
+ else if (m_image_bpp == 3)
827
+ RGB_to_YCC(pDst, Psrc, m_image_x);
828
+ else
829
+ Y_to_YCC(pDst, Psrc, m_image_x);
830
+ }
831
+
832
+ // Possibly duplicate pixels at end of scanline if not a multiple of 8 or 16
833
+ if (m_num_components == 1)
834
+ memset(m_mcu_lines[m_mcu_y_ofs] + m_image_bpl_xlt, pDst[m_image_bpl_xlt - 1], m_image_x_mcu - m_image_x);
835
+ else
836
+ {
837
+ const uint8 y = pDst[m_image_bpl_xlt - 3 + 0], cb = pDst[m_image_bpl_xlt - 3 + 1], cr = pDst[m_image_bpl_xlt - 3 + 2];
838
+ uint8 *q = m_mcu_lines[m_mcu_y_ofs] + m_image_bpl_xlt;
839
+ for (int i = m_image_x; i < m_image_x_mcu; i++)
840
+ {
841
+ *q++ = y; *q++ = cb; *q++ = cr;
842
+ }
843
+ }
844
+
845
+ if (++m_mcu_y_ofs == m_mcu_y)
846
+ {
847
+ process_mcu_row();
848
+ m_mcu_y_ofs = 0;
849
+ }
850
+ }
851
+
852
+ void jpeg_encoder::clear()
853
+ {
854
+ m_mcu_lines[0] = NULL;
855
+ m_pass_num = 0;
856
+ m_all_stream_writes_succeeded = true;
857
+ }
858
+
859
+ jpeg_encoder::jpeg_encoder()
860
+ {
861
+ clear();
862
+ }
863
+
864
+ jpeg_encoder::~jpeg_encoder()
865
+ {
866
+ deinit();
867
+ }
868
+
869
+ bool jpeg_encoder::init(output_stream *pStream, int64_t width, int64_t height, int64_t src_channels, const params &comp_params)
870
+ {
871
+ deinit();
872
+ if (((!pStream) || (width < 1) || (height < 1)) || ((src_channels != 1) && (src_channels != 3) && (src_channels != 4)) || (!comp_params.check_valid())) return false;
873
+ m_pStream = pStream;
874
+ m_params = comp_params;
875
+ return jpg_open(width, height, src_channels);
876
+ }
877
+
878
+ void jpeg_encoder::deinit()
879
+ {
880
+ jpge_free(m_mcu_lines[0]);
881
+ clear();
882
+ }
883
+
884
+ bool jpeg_encoder::process_scanline(const void* pScanline)
885
+ {
886
+ if ((m_pass_num < 1) || (m_pass_num > 2)) return false;
887
+ if (m_all_stream_writes_succeeded)
888
+ {
889
+ if (!pScanline)
890
+ {
891
+ if (!process_end_of_image()) return false;
892
+ }
893
+ else
894
+ {
895
+ load_mcu(pScanline);
896
+ }
897
+ }
898
+ return m_all_stream_writes_succeeded;
899
+ }
900
+
901
+ // Higher level wrappers/examples (optional).
902
+ #include <stdio.h>
903
+
904
+ class cfile_stream : public output_stream
905
+ {
906
+ cfile_stream(const cfile_stream &);
907
+ cfile_stream &operator= (const cfile_stream &);
908
+
909
+ FILE* m_pFile;
910
+ bool m_bStatus;
911
+
912
+ public:
913
+ cfile_stream() : m_pFile(NULL), m_bStatus(false) { }
914
+
915
+ virtual ~cfile_stream()
916
+ {
917
+ close();
918
+ }
919
+
920
+ bool open(const char *pFilename)
921
+ {
922
+ close();
923
+ #if defined(_MSC_VER)
924
+ if (fopen_s(&m_pFile, pFilename, "wb") != 0)
925
+ {
926
+ return false;
927
+ }
928
+ #else
929
+ m_pFile = fopen(pFilename, "wb");
930
+ #endif
931
+ m_bStatus = (m_pFile != NULL);
932
+ return m_bStatus;
933
+ }
934
+
935
+ bool close()
936
+ {
937
+ if (m_pFile)
938
+ {
939
+ if (fclose(m_pFile) == EOF)
940
+ {
941
+ m_bStatus = false;
942
+ }
943
+ m_pFile = NULL;
944
+ }
945
+ return m_bStatus;
946
+ }
947
+
948
+ virtual bool put_buf(const void* pBuf, int64_t len)
949
+ {
950
+ m_bStatus = m_bStatus && (fwrite(pBuf, len, 1, m_pFile) == 1);
951
+ return m_bStatus;
952
+ }
953
+
954
+ uint get_size() const
955
+ {
956
+ return m_pFile ? ftell(m_pFile) : 0;
957
+ }
958
+ };
959
+
960
+ // Writes JPEG image to file.
961
+ bool compress_image_to_jpeg_file(const char *pFilename, int64_t width, int64_t height, int64_t num_channels, const uint8 *pImage_data, const params &comp_params)
962
+ {
963
+ cfile_stream dst_stream;
964
+ if (!dst_stream.open(pFilename))
965
+ return false;
966
+
967
+ jpge::jpeg_encoder dst_image;
968
+ if (!dst_image.init(&dst_stream, width, height, num_channels, comp_params))
969
+ return false;
970
+
971
+ for (uint pass_index = 0; pass_index < dst_image.get_total_passes(); pass_index++)
972
+ {
973
+ for (int64_t i = 0; i < height; i++)
974
+ {
975
+ // i, width, and num_channels are all 64bit
976
+ const uint8* pBuf = pImage_data + i * width * num_channels;
977
+ if (!dst_image.process_scanline(pBuf))
978
+ return false;
979
+ }
980
+ if (!dst_image.process_scanline(NULL))
981
+ return false;
982
+ }
983
+
984
+ dst_image.deinit();
985
+
986
+ return dst_stream.close();
987
+ }
988
+
989
+ class memory_stream : public output_stream
990
+ {
991
+ memory_stream(const memory_stream &);
992
+ memory_stream &operator= (const memory_stream &);
993
+
994
+ uint8 *m_pBuf;
995
+ uint64_t m_buf_size, m_buf_ofs;
996
+
997
+ public:
998
+ memory_stream(void *pBuf, uint64_t buf_size) : m_pBuf(static_cast<uint8*>(pBuf)), m_buf_size(buf_size), m_buf_ofs(0) { }
999
+
1000
+ virtual ~memory_stream() { }
1001
+
1002
+ virtual bool put_buf(const void* pBuf, int64_t len)
1003
+ {
1004
+ uint64_t buf_remaining = m_buf_size - m_buf_ofs;
1005
+ if ((uint64_t)len > buf_remaining)
1006
+ return false;
1007
+ memcpy(m_pBuf + m_buf_ofs, pBuf, len);
1008
+ m_buf_ofs += len;
1009
+ return true;
1010
+ }
1011
+
1012
+ uint64_t get_size() const
1013
+ {
1014
+ return m_buf_ofs;
1015
+ }
1016
+ };
1017
+
1018
+ bool compress_image_to_jpeg_file_in_memory(void *pDstBuf, int64_t &buf_size, int64_t width, int64_t height, int64_t num_channels, const uint8 *pImage_data, const params &comp_params)
1019
+ {
1020
+ if ((!pDstBuf) || (!buf_size))
1021
+ return false;
1022
+
1023
+ memory_stream dst_stream(pDstBuf, buf_size);
1024
+
1025
+ buf_size = 0;
1026
+
1027
+ jpge::jpeg_encoder dst_image;
1028
+ if (!dst_image.init(&dst_stream, width, height, num_channels, comp_params))
1029
+ return false;
1030
+
1031
+ for (uint pass_index = 0; pass_index < dst_image.get_total_passes(); pass_index++)
1032
+ {
1033
+ for (int64_t i = 0; i < height; i++)
1034
+ {
1035
+ const uint8* pScanline = pImage_data + i * width * num_channels;
1036
+ if (!dst_image.process_scanline(pScanline))
1037
+ return false;
1038
+ }
1039
+ if (!dst_image.process_scanline(NULL))
1040
+ return false;
1041
+ }
1042
+
1043
+ dst_image.deinit();
1044
+
1045
+ buf_size = dst_stream.get_size();
1046
+ return true;
1047
+ }
1048
+
1049
+ } // namespace jpge
crazy_functions/test_project/cpp/longcode/prod_cons.h ADDED
@@ -0,0 +1,433 @@
1
+ #pragma once
2
+
3
+ #include <atomic>
4
+ #include <utility>
5
+ #include <cstring>
6
+ #include <type_traits>
7
+ #include <cstdint>
8
+
9
+ #include "libipc/def.h"
10
+
11
+ #include "libipc/platform/detail.h"
12
+ #include "libipc/circ/elem_def.h"
13
+ #include "libipc/utility/log.h"
14
+ #include "libipc/utility/utility.h"
15
+
16
+ namespace ipc {
17
+
18
+ ////////////////////////////////////////////////////////////////
19
+ /// producer-consumer implementation
20
+ ////////////////////////////////////////////////////////////////
21
+
22
+ template <typename Flag>
23
+ struct prod_cons_impl;
24
+
25
+ template <>
26
+ struct prod_cons_impl<wr<relat::single, relat::single, trans::unicast>> {
27
+
28
+ template <std::size_t DataSize, std::size_t AlignSize>
29
+ struct elem_t {
30
+ std::aligned_storage_t<DataSize, AlignSize> data_ {};
31
+ };
32
+
33
+ alignas(cache_line_size) std::atomic<circ::u2_t> rd_; // read index
34
+ alignas(cache_line_size) std::atomic<circ::u2_t> wt_; // write index
35
+
36
+ constexpr circ::u2_t cursor() const noexcept {
37
+ return 0;
38
+ }
39
+
40
+ template <typename W, typename F, typename E>
41
+ bool push(W* /*wrapper*/, F&& f, E* elems) {
42
+ auto cur_wt = circ::index_of(wt_.load(std::memory_order_relaxed));
43
+ if (cur_wt == circ::index_of(rd_.load(std::memory_order_acquire) - 1)) {
44
+ return false; // full
45
+ }
46
+ std::forward<F>(f)(&(elems[cur_wt].data_));
47
+ wt_.fetch_add(1, std::memory_order_release);
48
+ return true;
49
+ }
50
+
51
+ /**
52
+ * In single-single-unicast, 'force_push' means 'no reader' or 'the only one reader is dead'.
53
+ * So we could just disconnect all connections of receiver, and return false.
54
+ */
55
+ template <typename W, typename F, typename E>
56
+ bool force_push(W* wrapper, F&&, E*) {
57
+ wrapper->elems()->disconnect_receiver(~static_cast<circ::cc_t>(0u));
58
+ return false;
59
+ }
60
+
61
+ template <typename W, typename F, typename R, typename E>
62
+ bool pop(W* /*wrapper*/, circ::u2_t& /*cur*/, F&& f, R&& out, E* elems) {
63
+ auto cur_rd = circ::index_of(rd_.load(std::memory_order_relaxed));
64
+ if (cur_rd == circ::index_of(wt_.load(std::memory_order_acquire))) {
65
+ return false; // empty
66
+ }
67
+ std::forward<F>(f)(&(elems[cur_rd].data_));
68
+ std::forward<R>(out)(true);
69
+ rd_.fetch_add(1, std::memory_order_release);
70
+ return true;
71
+ }
72
+ };
73
+
74
+ template <>
75
+ struct prod_cons_impl<wr<relat::single, relat::multi , trans::unicast>>
76
+ : prod_cons_impl<wr<relat::single, relat::single, trans::unicast>> {
77
+
78
+ template <typename W, typename F, typename E>
79
+ bool force_push(W* wrapper, F&&, E*) {
80
+ wrapper->elems()->disconnect_receiver(1);
81
+ return false;
82
+ }
83
+
84
+ template <typename W, typename F, typename R,
85
+ template <std::size_t, std::size_t> class E, std::size_t DS, std::size_t AS>
86
+ bool pop(W* /*wrapper*/, circ::u2_t& /*cur*/, F&& f, R&& out, E<DS, AS>* elems) {
87
+ byte_t buff[DS];
88
+ for (unsigned k = 0;;) {
89
+ auto cur_rd = rd_.load(std::memory_order_relaxed);
90
+ if (circ::index_of(cur_rd) ==
91
+ circ::index_of(wt_.load(std::memory_order_acquire))) {
92
+ return false; // empty
93
+ }
94
+ std::memcpy(buff, &(elems[circ::index_of(cur_rd)].data_), sizeof(buff));
95
+ if (rd_.compare_exchange_weak(cur_rd, cur_rd + 1, std::memory_order_release)) {
96
+ std::forward<F>(f)(buff);
97
+ std::forward<R>(out)(true);
98
+ return true;
99
+ }
100
+ ipc::yield(k);
101
+ }
102
+ }
103
+ };
104
+
105
+ template <>
106
+ struct prod_cons_impl<wr<relat::multi , relat::multi, trans::unicast>>
107
+ : prod_cons_impl<wr<relat::single, relat::multi, trans::unicast>> {
108
+
109
+ using flag_t = std::uint64_t;
110
+
111
+ template <std::size_t DataSize, std::size_t AlignSize>
112
+ struct elem_t {
113
+ std::aligned_storage_t<DataSize, AlignSize> data_ {};
114
+ std::atomic<flag_t> f_ct_ { 0 }; // commit flag
115
+ };
116
+
117
+ alignas(cache_line_size) std::atomic<circ::u2_t> ct_; // commit index
118
+
119
+ template <typename W, typename F, typename E>
120
+ bool push(W* /*wrapper*/, F&& f, E* elems) {
121
+ circ::u2_t cur_ct, nxt_ct;
122
+ for (unsigned k = 0;;) {
123
+ cur_ct = ct_.load(std::memory_order_relaxed);
124
+ if (circ::index_of(nxt_ct = cur_ct + 1) ==
125
+ circ::index_of(rd_.load(std::memory_order_acquire))) {
126
+ return false; // full
127
+ }
128
+ if (ct_.compare_exchange_weak(cur_ct, nxt_ct, std::memory_order_acq_rel)) {
129
+ break;
130
+ }
131
+ ipc::yield(k);
132
+ }
133
+ auto* el = elems + circ::index_of(cur_ct);
134
+ std::forward<F>(f)(&(el->data_));
135
+ // set flag & try update wt
136
+ el->f_ct_.store(~static_cast<flag_t>(cur_ct), std::memory_order_release);
137
+ while (1) {
138
+ auto cac_ct = el->f_ct_.load(std::memory_order_acquire);
139
+ if (cur_ct != wt_.load(std::memory_order_relaxed)) {
140
+ return true;
141
+ }
142
+ if ((~cac_ct) != cur_ct) {
143
+ return true;
144
+ }
145
+ if (!el->f_ct_.compare_exchange_strong(cac_ct, 0, std::memory_order_relaxed)) {
146
+ return true;
147
+ }
148
+ wt_.store(nxt_ct, std::memory_order_release);
149
+ cur_ct = nxt_ct;
150
+ nxt_ct = cur_ct + 1;
151
+ el = elems + circ::index_of(cur_ct);
152
+ }
153
+ return true;
154
+ }
155
+
156
+ template <typename W, typename F, typename E>
157
+ bool force_push(W* wrapper, F&&, E*) {
158
+ wrapper->elems()->disconnect_receiver(1);
159
+ return false;
160
+ }
161
+
162
+ template <typename W, typename F, typename R,
163
+ template <std::size_t, std::size_t> class E, std::size_t DS, std::size_t AS>
164
+ bool pop(W* /*wrapper*/, circ::u2_t& /*cur*/, F&& f, R&& out, E<DS, AS>* elems) {
165
+ byte_t buff[DS];
166
+ for (unsigned k = 0;;) {
167
+ auto cur_rd = rd_.load(std::memory_order_relaxed);
168
+ auto cur_wt = wt_.load(std::memory_order_acquire);
169
+ auto id_rd = circ::index_of(cur_rd);
170
+ auto id_wt = circ::index_of(cur_wt);
171
+ if (id_rd == id_wt) {
172
+ auto* el = elems + id_wt;
173
+ auto cac_ct = el->f_ct_.load(std::memory_order_acquire);
174
+ if ((~cac_ct) != cur_wt) {
175
+ return false; // empty
176
+ }
177
+ if (el->f_ct_.compare_exchange_weak(cac_ct, 0, std::memory_order_relaxed)) {
178
+ wt_.store(cur_wt + 1, std::memory_order_release);
179
+ }
180
+ k = 0;
181
+ }
182
+ else {
183
+ std::memcpy(buff, &(elems[circ::index_of(cur_rd)].data_), sizeof(buff));
184
+ if (rd_.compare_exchange_weak(cur_rd, cur_rd + 1, std::memory_order_release)) {
185
+ std::forward<F>(f)(buff);
186
+ std::forward<R>(out)(true);
187
+ return true;
188
+ }
189
+ ipc::yield(k);
190
+ }
191
+ }
192
+ }
193
+ };
194
+
195
+ template <>
196
+ struct prod_cons_impl<wr<relat::single, relat::multi, trans::broadcast>> {
197
+
198
+ using rc_t = std::uint64_t;
199
+
200
+ enum : rc_t {
201
+ ep_mask = 0x00000000ffffffffull,
202
+ ep_incr = 0x0000000100000000ull
203
+ };
204
+
205
+ template <std::size_t DataSize, std::size_t AlignSize>
206
+ struct elem_t {
207
+ std::aligned_storage_t<DataSize, AlignSize> data_ {};
208
+ std::atomic<rc_t> rc_ { 0 }; // read-counter
209
+ };
210
+
211
+ alignas(cache_line_size) std::atomic<circ::u2_t> wt_; // write index
212
+ alignas(cache_line_size) rc_t epoch_ { 0 }; // only one writer
213
+
214
+ circ::u2_t cursor() const noexcept {
215
+ return wt_.load(std::memory_order_acquire);
216
+ }
217
+
218
+ template <typename W, typename F, typename E>
219
+ bool push(W* wrapper, F&& f, E* elems) {
220
+ E* el;
221
+ for (unsigned k = 0;;) {
222
+ circ::cc_t cc = wrapper->elems()->connections(std::memory_order_relaxed);
223
+ if (cc == 0) return false; // no reader
224
+ el = elems + circ::index_of(wt_.load(std::memory_order_relaxed));
225
+ // check all consumers have finished reading this element
226
+ auto cur_rc = el->rc_.load(std::memory_order_acquire);
227
+ circ::cc_t rem_cc = cur_rc & ep_mask;
228
+ if ((cc & rem_cc) && ((cur_rc & ~ep_mask) == epoch_)) {
229
+ return false; // has not finished yet
230
+ }
231
+ // consider rem_cc to be 0 here
232
+ if (el->rc_.compare_exchange_weak(
233
+ cur_rc, epoch_ | static_cast<rc_t>(cc), std::memory_order_release)) {
234
+ break;
235
+ }
236
+ ipc::yield(k);
237
+ }
238
+ std::forward<F>(f)(&(el->data_));
239
+ wt_.fetch_add(1, std::memory_order_release);
240
+ return true;
241
+ }
242
+
243
+ template <typename W, typename F, typename E>
244
+ bool force_push(W* wrapper, F&& f, E* elems) {
245
+ E* el;
246
+ epoch_ += ep_incr;
247
+ for (unsigned k = 0;;) {
248
+ circ::cc_t cc = wrapper->elems()->connections(std::memory_order_relaxed);
249
+ if (cc == 0) return false; // no reader
250
+ el = elems + circ::index_of(wt_.load(std::memory_order_relaxed));
251
+ // check all consumers have finished reading this element
252
+ auto cur_rc = el->rc_.load(std::memory_order_acquire);
253
+ circ::cc_t rem_cc = cur_rc & ep_mask;
254
+ if (cc & rem_cc) {
255
+ ipc::log("force_push: k = %u, cc = %u, rem_cc = %u\n", k, cc, rem_cc);
256
+ cc = wrapper->elems()->disconnect_receiver(rem_cc); // disconnect all invalid readers
257
+ if (cc == 0) return false; // no reader
258
+ }
259
+ // just compare & exchange
260
+ if (el->rc_.compare_exchange_weak(
261
+ cur_rc, epoch_ | static_cast<rc_t>(cc), std::memory_order_release)) {
262
+ break;
263
+ }
264
+ ipc::yield(k);
265
+ }
266
+ std::forward<F>(f)(&(el->data_));
267
+ wt_.fetch_add(1, std::memory_order_release);
268
+ return true;
269
+ }
270
+
271
+ template <typename W, typename F, typename R, typename E>
272
+ bool pop(W* wrapper, circ::u2_t& cur, F&& f, R&& out, E* elems) {
273
+ if (cur == cursor()) return false; // acquire
274
+ auto* el = elems + circ::index_of(cur++);
275
+ std::forward<F>(f)(&(el->data_));
276
+ for (unsigned k = 0;;) {
277
+ auto cur_rc = el->rc_.load(std::memory_order_acquire);
278
+ if ((cur_rc & ep_mask) == 0) {
279
+ std::forward<R>(out)(true);
280
+ return true;
281
+ }
282
+ auto nxt_rc = cur_rc & ~static_cast<rc_t>(wrapper->connected_id());
283
+ if (el->rc_.compare_exchange_weak(cur_rc, nxt_rc, std::memory_order_release)) {
284
+ std::forward<R>(out)((nxt_rc & ep_mask) == 0);
285
+ return true;
286
+ }
287
+ ipc::yield(k);
288
+ }
289
+ }
290
+ };
291
+
292
+ template <>
293
+ struct prod_cons_impl<wr<relat::multi, relat::multi, trans::broadcast>> {
294
+
295
+ using rc_t = std::uint64_t;
296
+ using flag_t = std::uint64_t;
297
+
298
+ enum : rc_t {
299
+ rc_mask = 0x00000000ffffffffull,
300
+ ep_mask = 0x00ffffffffffffffull,
301
+ ep_incr = 0x0100000000000000ull,
302
+ ic_mask = 0xff000000ffffffffull,
303
+ ic_incr = 0x0000000100000000ull
304
+ };
305
+
306
+ template <std::size_t DataSize, std::size_t AlignSize>
307
+ struct elem_t {
308
+ std::aligned_storage_t<DataSize, AlignSize> data_ {};
309
+ std::atomic<rc_t > rc_ { 0 }; // read-counter
310
+ std::atomic<flag_t> f_ct_ { 0 }; // commit flag
311
+ };
312
+
313
+ alignas(cache_line_size) std::atomic<circ::u2_t> ct_; // commit index
314
+ alignas(cache_line_size) std::atomic<rc_t> epoch_ { 0 };
315
+
316
+ circ::u2_t cursor() const noexcept {
317
+ return ct_.load(std::memory_order_acquire);
318
+ }
319
+
320
+ constexpr static rc_t inc_rc(rc_t rc) noexcept {
321
+ return (rc & ic_mask) | ((rc + ic_incr) & ~ic_mask);
322
+ }
323
+
324
+ constexpr static rc_t inc_mask(rc_t rc) noexcept {
325
+ return inc_rc(rc) & ~rc_mask;
326
+ }
327
+
328
+ template <typename W, typename F, typename E>
329
+ bool push(W* wrapper, F&& f, E* elems) {
330
+ E* el;
331
+ circ::u2_t cur_ct;
332
+ rc_t epoch = epoch_.load(std::memory_order_acquire);
333
+ for (unsigned k = 0;;) {
334
+ circ::cc_t cc = wrapper->elems()->connections(std::memory_order_relaxed);
335
+ if (cc == 0) return false; // no reader
336
+ el = elems + circ::index_of(cur_ct = ct_.load(std::memory_order_relaxed));
337
+ // check all consumers have finished reading this element
338
+ auto cur_rc = el->rc_.load(std::memory_order_relaxed);
339
+ circ::cc_t rem_cc = cur_rc & rc_mask;
340
+ if ((cc & rem_cc) && ((cur_rc & ~ep_mask) == epoch)) {
341
+ return false; // has not finished yet
342
+ }
343
+ else if (!rem_cc) {
344
+ auto cur_fl = el->f_ct_.load(std::memory_order_acquire);
345
+ if ((cur_fl != cur_ct) && cur_fl) {
346
+ return false; // full
347
+ }
348
+ }
349
+ // consider rem_cc to be 0 here
350
+ if (el->rc_.compare_exchange_weak(
351
+ cur_rc, inc_mask(epoch | (cur_rc & ep_mask)) | static_cast<rc_t>(cc), std::memory_order_relaxed) &&
352
+ epoch_.compare_exchange_weak(epoch, epoch, std::memory_order_acq_rel)) {
353
+ break;
354
+ }
355
+ ipc::yield(k);
356
+ }
357
+ // only one thread/process would touch here at one time
358
+ ct_.store(cur_ct + 1, std::memory_order_release);
359
+ std::forward<F>(f)(&(el->data_));
360
+ // set flag & try update wt
361
+ el->f_ct_.store(~static_cast<flag_t>(cur_ct), std::memory_order_release);
362
+ return true;
363
+ }
364
+
365
+ template <typename W, typename F, typename E>
366
+ bool force_push(W* wrapper, F&& f, E* elems) {
367
+ E* el;
368
+ circ::u2_t cur_ct;
369
+ rc_t epoch = epoch_.fetch_add(ep_incr, std::memory_order_release) + ep_incr;
370
+ for (unsigned k = 0;;) {
371
+ circ::cc_t cc = wrapper->elems()->connections(std::memory_order_relaxed);
372
+ if (cc == 0) return false; // no reader
373
+ el = elems + circ::index_of(cur_ct = ct_.load(std::memory_order_relaxed));
374
+ // check all consumers have finished reading this element
375
+ auto cur_rc = el->rc_.load(std::memory_order_acquire);
376
+ circ::cc_t rem_cc = cur_rc & rc_mask;
377
+ if (cc & rem_cc) {
378
+ ipc::log("force_push: k = %u, cc = %u, rem_cc = %u\n", k, cc, rem_cc);
379
+ cc = wrapper->elems()->disconnect_receiver(rem_cc); // disconnect all invalid readers
380
+ if (cc == 0) return false; // no reader
381
+ }
382
+ // just compare & exchange
383
+ if (el->rc_.compare_exchange_weak(
384
+ cur_rc, inc_mask(epoch | (cur_rc & ep_mask)) | static_cast<rc_t>(cc), std::memory_order_relaxed)) {
385
+ if (epoch == epoch_.load(std::memory_order_acquire)) {
386
+ break;
387
+ }
388
+ else if (push(wrapper, std::forward<F>(f), elems)) {
389
+ return true;
390
+ }
391
+ epoch = epoch_.fetch_add(ep_incr, std::memory_order_release) + ep_incr;
392
+ }
393
+ ipc::yield(k);
394
+ }
395
+ // only one thread/process would touch here at one time
396
+ ct_.store(cur_ct + 1, std::memory_order_release);
397
+ std::forward<F>(f)(&(el->data_));
398
+ // set flag & try update wt
399
+ el->f_ct_.store(~static_cast<flag_t>(cur_ct), std::memory_order_release);
400
+ return true;
401
+ }
402
+
403
+ template <typename W, typename F, typename R, typename E, std::size_t N>
404
+ bool pop(W* wrapper, circ::u2_t& cur, F&& f, R&& out, E(& elems)[N]) {
405
+ auto* el = elems + circ::index_of(cur);
406
+ auto cur_fl = el->f_ct_.load(std::memory_order_acquire);
407
+ if (cur_fl != ~static_cast<flag_t>(cur)) {
408
+ return false; // empty
409
+ }
410
+ ++cur;
411
+ std::forward<F>(f)(&(el->data_));
412
+ for (unsigned k = 0;;) {
413
+ auto cur_rc = el->rc_.load(std::memory_order_acquire);
414
+ if ((cur_rc & rc_mask) == 0) {
415
+ std::forward<R>(out)(true);
416
+ el->f_ct_.store(cur + N - 1, std::memory_order_release);
417
+ return true;
418
+ }
419
+ auto nxt_rc = inc_rc(cur_rc) & ~static_cast<rc_t>(wrapper->connected_id());
420
+ bool last_one = false;
421
+ if ((last_one = (nxt_rc & rc_mask) == 0)) {
422
+ el->f_ct_.store(cur + N - 1, std::memory_order_release);
423
+ }
424
+ if (el->rc_.compare_exchange_weak(cur_rc, nxt_rc, std::memory_order_release)) {
425
+ std::forward<R>(out)(last_one);
426
+ return true;
427
+ }
428
+ ipc::yield(k);
429
+ }
430
+ }
431
+ };
432
+
433
+ } // namespace ipc
crazy_functions/下载arxiv论文翻译摘要.py ADDED
@@ -0,0 +1,186 @@
1
+ from request_llm.bridge_chatgpt import predict_no_ui
2
+ from toolbox import CatchException, report_execption, write_results_to_file, predict_no_ui_but_counting_down, get_conf
3
+ import re, requests, unicodedata, os
4
+
5
+ def download_arxiv_(url_pdf):
6
+ if 'arxiv.org' not in url_pdf:
7
+ if ('.' in url_pdf) and ('/' not in url_pdf):
8
+ new_url = 'https://arxiv.org/abs/'+url_pdf
9
+ print('下载编号:', url_pdf, '自动定位:', new_url)
10
+ # download_arxiv_(new_url)
11
+ return download_arxiv_(new_url)
12
+ else:
13
+ print('不能识别的URL!')
14
+ return None
15
+ if 'abs' in url_pdf:
16
+ url_pdf = url_pdf.replace('abs', 'pdf')
17
+ url_pdf = url_pdf + '.pdf'
18
+
19
+ url_abs = url_pdf.replace('.pdf', '').replace('pdf', 'abs')
20
+ title, other_info = get_name(_url_=url_abs)
21
+
22
+ paper_id = title.split()[0] # '[1712.00559]'
23
+ if '2' in other_info['year']:
24
+ title = other_info['year'] + ' ' + title
25
+
26
+ known_conf = ['NeurIPS', 'NIPS', 'Nature', 'Science', 'ICLR', 'AAAI']
27
+ for k in known_conf:
28
+ if k in other_info['comment']:
29
+ title = k + ' ' + title
30
+
31
+ download_dir = './gpt_log/arxiv/'
32
+ os.makedirs(download_dir, exist_ok=True)
33
+
34
+ title_str = title.replace('?', '?')\
35
+ .replace(':', ':')\
36
+ .replace('\"', '“')\
37
+ .replace('\n', '')\
38
+ .replace(' ', ' ')\
39
+ .replace(' ', ' ')
40
+
41
+ requests_pdf_url = url_pdf
42
+ file_path = download_dir+title_str
43
+ # if os.path.exists(file_path):
44
+ # print('返回缓存文件')
45
+ # return './gpt_log/arxiv/'+title_str
46
+
47
+ print('下载中')
48
+ proxies, = get_conf('proxies')
49
+ r = requests.get(requests_pdf_url, proxies=proxies)
50
+ with open(file_path, 'wb+') as f:
51
+ f.write(r.content)
52
+ print('下载完成')
53
+
54
+ # print('输出下载命令:','aria2c -o \"%s\" %s'%(title_str,url_pdf))
55
+ # subprocess.call('aria2c --all-proxy=\"172.18.116.150:11084\" -o \"%s\" %s'%(download_dir+title_str,url_pdf), shell=True)
56
+
57
+ x = "%s %s %s.bib" % (paper_id, other_info['year'], other_info['authors'])
58
+ x = x.replace('?', '?')\
59
+ .replace(':', ':')\
60
+ .replace('\"', '“')\
61
+ .replace('\n', '')\
62
+ .replace(' ', ' ')\
63
+ .replace(' ', ' ')
64
+ return './gpt_log/arxiv/'+title_str, other_info
65
+
66
+
67
+ def get_name(_url_):
68
+ import os
69
+ from bs4 import BeautifulSoup
70
+ print('正在获取文献名!')
71
+ print(_url_)
72
+
73
+ # arxiv_recall = {}
74
+ # if os.path.exists('./arxiv_recall.pkl'):
75
+ # with open('./arxiv_recall.pkl', 'rb') as f:
76
+ # arxiv_recall = pickle.load(f)
77
+
78
+ # if _url_ in arxiv_recall:
79
+ # print('在缓存中')
80
+ # return arxiv_recall[_url_]
81
+
82
+ proxies, = get_conf('proxies')
83
+ res = requests.get(_url_, proxies=proxies)
84
+
85
+ bs = BeautifulSoup(res.text, 'html.parser')
86
+ other_details = {}
87
+
88
+ # get year
89
+ try:
90
+ year = bs.find_all(class_='dateline')[0].text
91
+ year = re.search(r'(\d{4})', year, re.M | re.I).group(1)
92
+ other_details['year'] = year
93
+ abstract = bs.find_all(class_='abstract mathjax')[0].text
94
+ other_details['abstract'] = abstract
95
+ except:
96
+ other_details['year'] = ''
97
+ print('年份获取失败')
98
+
99
+ # get author
100
+ try:
101
+ authors = bs.find_all(class_='authors')[0].text
102
+ authors = authors.split('Authors:')[1]
103
+ other_details['authors'] = authors
104
+ except:
105
+ other_details['authors'] = ''
106
+ print('authors获取失败')
107
+
108
+ # get comment
109
+ try:
110
+ comment = bs.find_all(class_='metatable')[0].text
111
+ real_comment = None
112
+ for item in comment.replace('\n', ' ').split(' '):
113
+ if 'Comments' in item:
114
+ real_comment = item
115
+ if real_comment is not None:
116
+ other_details['comment'] = real_comment
117
+ else:
118
+ other_details['comment'] = ''
119
+ except:
120
+ other_details['comment'] = ''
121
+ print('comment获取失败')
122
+
123
+ title_str = BeautifulSoup(
124
+ res.text, 'html.parser').find('title').contents[0]
125
+ print('获取成功:', title_str)
126
+ # arxiv_recall[_url_] = (title_str+'.pdf', other_details)
127
+ # with open('./arxiv_recall.pkl', 'wb') as f:
128
+ # pickle.dump(arxiv_recall, f)
129
+
130
+ return title_str+'.pdf', other_details
131
+
132
+
133
+
134
+ @CatchException
135
+ def 下载arxiv论文并翻译摘要(txt, top_p, temperature, chatbot, history, systemPromptTxt, WEB_PORT):
136
+
137
+ CRAZY_FUNCTION_INFO = "下载arxiv论文并翻译摘要,函数插件作者[binary-husky]。正在提取摘要并下载PDF文档……"
138
+ import glob
139
+ import os
140
+
141
+ # 基本信息:功能、贡献者
142
+ chatbot.append(["函数插件功能?", CRAZY_FUNCTION_INFO])
143
+ yield chatbot, history, '正常'
144
+
145
+ # 尝试导入依赖,如果缺少依赖,则给出安装建议
146
+ try:
147
+ import pdfminer, bs4
148
+ except:
149
+ report_execption(chatbot, history,
150
+ a = f"解析项目: {txt}",
151
+ b = f"导入软件依赖失败。使用该模块需要额外依赖,安装方法```pip install --upgrade pdfminer beautifulsoup4```。")
152
+ yield chatbot, history, '正常'
153
+ return
154
+
155
+ # 清空历史,以免输入溢出
156
+ history = []
157
+
158
+ # 提取摘要,下载PDF文档
159
+ try:
160
+ pdf_path, info = download_arxiv_(txt)
161
+ except:
162
+ report_execption(chatbot, history,
163
+ a = f"解析项目: {txt}",
164
+ b = f"下载pdf文件未成功")
165
+ yield chatbot, history, '正常'
166
+ return
167
+
168
+ # 翻译摘要等
169
+ i_say = f"请你阅读以下学术论文相关的材料,提取摘要,翻译为中文。材料如下:{str(info)}"
170
+ i_say_show_user = f'请你阅读以下学术论文相关的材料,提取摘要,翻译为中文。论文:{pdf_path}'
171
+ chatbot.append((i_say_show_user, "[Local Message] waiting gpt response."))
172
+ yield chatbot, history, '正常'
173
+ msg = '正常'
174
+ # ** gpt request **
175
+ gpt_say = yield from predict_no_ui_but_counting_down(i_say, i_say_show_user, chatbot, top_p, temperature, history=[]) # 带超时倒计时
176
+ chatbot[-1] = (i_say_show_user, gpt_say)
177
+ history.append(i_say_show_user); history.append(gpt_say)
178
+ yield chatbot, history, msg
179
+ # 写入文件
180
+ import shutil
181
+ # 重置文件的创建时间
182
+ shutil.copyfile(pdf_path, f'./gpt_log/{os.path.basename(pdf_path)}'); os.remove(pdf_path)
183
+ res = write_results_to_file(history)
184
+ chatbot.append(("完成了吗?", res + "\n\nPDF文件也已经下载"))
185
+ yield chatbot, history, msg
186
+
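A minimal standalone sketch of the new `download_arxiv_` helper above, assuming it is run from the repository root so that `get_conf('proxies')` resolves through `config.py`/`config_private.py`; the arXiv ID used here is only an example:

```python
# Hypothetical direct call, outside the Gradio plugin flow.
# download_arxiv_ accepts a bare arXiv ID or an abs/pdf URL; a bare ID is expanded to https://arxiv.org/abs/<id>.
from crazy_functions.下载arxiv论文翻译摘要 import download_arxiv_

pdf_path, info = download_arxiv_('1812.10695')      # example ID, hypothetical
print(pdf_path)                                      # PDF is saved under ./gpt_log/arxiv/
print(info['year'], info['authors'], info['comment'])
print(info.get('abstract', '')[:200])                # abstract is only present when the abs page parsed cleanly
```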
crazy_functions/代码重写为全英文_多线程.py CHANGED
@@ -1,41 +1,97 @@
1
  import threading
2
- from predict import predict_no_ui_long_connection
3
- from toolbox import CatchException, write_results_to_file
 
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
 
7
  @CatchException
8
  def 全项目切换英文(txt, top_p, temperature, chatbot, history, sys_prompt, WEB_PORT):
9
- history = [] # 清空历史,以免输入溢出
10
- # 集合文件
11
- import time, glob, os
 
 
 
 
 
 
 
 
 
 
 
 
12
  os.makedirs('gpt_log/generated_english_version', exist_ok=True)
13
  os.makedirs('gpt_log/generated_english_version/crazy_functions', exist_ok=True)
14
  file_manifest = [f for f in glob.glob('./*.py') if ('test_project' not in f) and ('gpt_log' not in f)] + \
15
  [f for f in glob.glob('./crazy_functions/*.py') if ('test_project' not in f) and ('gpt_log' not in f)]
 
16
  i_say_show_user_buffer = []
17
 
18
- # 随便显示点什么防止卡顿的感觉
19
  for index, fp in enumerate(file_manifest):
20
  # if 'test_project' in fp: continue
21
  with open(fp, 'r', encoding='utf-8') as f:
22
  file_content = f.read()
23
- i_say_show_user =f'[{index}/{len(file_manifest)}] 接下来请将以下代码中包含的所有中文转化为英文,只输出代码: {os.path.abspath(fp)}'
24
  i_say_show_user_buffer.append(i_say_show_user)
25
  chatbot.append((i_say_show_user, "[Local Message] 等待多线程操作,中间过程不予显示."))
26
  yield chatbot, history, '正常'
27
 
28
- # 任务函数
 
 
 
 
 
 
 
 
 
 
29
  mutable_return = [None for _ in file_manifest]
 
30
  def thread_worker(fp,index):
 
 
 
31
  with open(fp, 'r', encoding='utf-8') as f:
32
  file_content = f.read()
33
- i_say = f'接下来请将以下代码中包含的所有中文转化为英文,只输出代码,文件名是{fp},文件代码是 ```{file_content}```'
34
- # ** gpt request **
35
- gpt_say = predict_no_ui_long_connection(inputs=i_say, top_p=top_p, temperature=temperature, history=history, sys_prompt=sys_prompt)
36
- mutable_return[index] = gpt_say
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
- # 所有线程同时开始执行任务函数
39
  handles = [threading.Thread(target=thread_worker, args=(fp,index)) for index, fp in enumerate(file_manifest)]
40
  for h in handles:
41
  h.daemon = True
@@ -43,19 +99,23 @@ def 全项目切换英文(txt, top_p, temperature, chatbot, history, sys_prompt,
43
  chatbot.append(('开始了吗?', f'多线程操作已经开始'))
44
  yield chatbot, history, '正常'
45
 
46
- # 循环轮询各个线程是否执行完毕
47
  cnt = 0
48
  while True:
49
- time.sleep(1)
 
50
  th_alive = [h.is_alive() for h in handles]
51
  if not any(th_alive): break
52
- stat = ['执行中' if alive else '已完成' for alive in th_alive]
53
- stat_str = '|'.join(stat)
54
- cnt += 1
55
- chatbot[-1] = (chatbot[-1][0], f'多线程操作已经开始,完成情况: {stat_str}' + ''.join(['.']*(cnt%4)))
 
 
 
56
  yield chatbot, history, '正常'
57
 
58
- # 把结果写入文件
59
  for index, h in enumerate(handles):
60
  h.join() # 这里其实不需要join了,肯定已经都结束了
61
  fp = file_manifest[index]
@@ -63,13 +123,17 @@ def 全项目切换英文(txt, top_p, temperature, chatbot, history, sys_prompt,
63
  i_say_show_user = i_say_show_user_buffer[index]
64
 
65
  where_to_relocate = f'gpt_log/generated_english_version/{fp}'
66
- with open(where_to_relocate, 'w+', encoding='utf-8') as f: f.write(gpt_say.lstrip('```').rstrip('```'))
 
 
 
 
67
  chatbot.append((i_say_show_user, f'[Local Message] 已完成{os.path.abspath(fp)}的转化,\n\n存入{os.path.abspath(where_to_relocate)}'))
68
  history.append(i_say_show_user); history.append(gpt_say)
69
  yield chatbot, history, '正常'
70
  time.sleep(1)
71
 
72
- # 备份一个文件
73
  res = write_results_to_file(history)
74
  chatbot.append(("生成一份任务执行报告", res))
75
  yield chatbot, history, '正常'
 
1
  import threading
2
+ from request_llm.bridge_chatgpt import predict_no_ui_long_connection
3
+ from toolbox import CatchException, write_results_to_file, report_execption
4
+ from .crazy_utils import breakdown_txt_to_satisfy_token_limit
5
 
6
+ def extract_code_block_carefully(txt):
7
+ splitted = txt.split('```')
8
+ n_code_block_seg = len(splitted) - 1
9
+ if n_code_block_seg <= 1: return txt
10
+ # 剩下的情况都开头除去 ``` 结尾除去一次 ```
11
+ txt_out = '```'.join(splitted[1:-1])
12
+ return txt_out
13
+
14
+
15
+
16
+ def break_txt_into_half_at_some_linebreak(txt):
17
+ lines = txt.split('\n')
18
+ n_lines = len(lines)
19
+ pre = lines[:(n_lines//2)]
20
+ post = lines[(n_lines//2):]
21
+ return "\n".join(pre), "\n".join(post)
22
 
23
 
24
  @CatchException
25
  def 全项目切换英文(txt, top_p, temperature, chatbot, history, sys_prompt, WEB_PORT):
26
+ # 第1步:清空历史,以免输入溢出
27
+ history = []
28
+
29
+ # 第2步:尝试导入依赖,如果缺少依赖,则给出安装建议
30
+ try:
31
+ import openai, transformers
32
+ except:
33
+ report_execption(chatbot, history,
34
+ a = f"解析项目: {txt}",
35
+ b = f"导入软件依赖失败。使用该模块需要额外依赖,安装方法```pip install --upgrade openai transformers```。")
36
+ yield chatbot, history, '正常'
37
+ return
38
+
39
+ # 第3步:集合文件
40
+ import time, glob, os, shutil, re, openai
41
  os.makedirs('gpt_log/generated_english_version', exist_ok=True)
42
  os.makedirs('gpt_log/generated_english_version/crazy_functions', exist_ok=True)
43
  file_manifest = [f for f in glob.glob('./*.py') if ('test_project' not in f) and ('gpt_log' not in f)] + \
44
  [f for f in glob.glob('./crazy_functions/*.py') if ('test_project' not in f) and ('gpt_log' not in f)]
45
+ # file_manifest = ['./toolbox.py']
46
  i_say_show_user_buffer = []
47
 
48
+ # 第4步:随便显示点什么防止卡顿的感觉
49
  for index, fp in enumerate(file_manifest):
50
  # if 'test_project' in fp: continue
51
  with open(fp, 'r', encoding='utf-8') as f:
52
  file_content = f.read()
53
+ i_say_show_user =f'[{index}/{len(file_manifest)}] 接下来请将以下代码中包含的所有中文转化为英文,只输出转化后的英文代码,请用代码块输出代码: {os.path.abspath(fp)}'
54
  i_say_show_user_buffer.append(i_say_show_user)
55
  chatbot.append((i_say_show_user, "[Local Message] 等待多线程操作,中间过程不予显示."))
56
  yield chatbot, history, '正常'
57
 
58
+
59
+ # 第5步:Token限制下的截断与处理
60
+ MAX_TOKEN = 3000
61
+ from transformers import GPT2TokenizerFast
62
+ print('加载tokenizer中')
63
+ tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
64
+ get_token_fn = lambda txt: len(tokenizer(txt)["input_ids"])
65
+ print('加载tokenizer结束')
66
+
67
+
68
+ # 第6步:任务函数
69
  mutable_return = [None for _ in file_manifest]
70
+ observe_window = [[""] for _ in file_manifest]
71
  def thread_worker(fp,index):
72
+ if index > 10:
73
+ time.sleep(60)
74
+ print('Openai 限制免费用户每分钟20次请求,降低请求频率中。')
75
  with open(fp, 'r', encoding='utf-8') as f:
76
  file_content = f.read()
77
+ i_say_template = lambda fp, file_content: f'接下来请将以下代码中包含的所有中文转化为英文,只输出代码,文件名是{fp},文件代码是 ```{file_content}```'
78
+ try:
79
+ gpt_say = ""
80
+ # 分解代码文件
81
+ file_content_breakdown = breakdown_txt_to_satisfy_token_limit(file_content, get_token_fn, MAX_TOKEN)
82
+ for file_content_partial in file_content_breakdown:
83
+ i_say = i_say_template(fp, file_content_partial)
84
+ # # ** gpt request **
85
+ gpt_say_partial = predict_no_ui_long_connection(inputs=i_say, top_p=top_p, temperature=temperature, history=[], sys_prompt=sys_prompt, observe_window=observe_window[index])
86
+ gpt_say_partial = extract_code_block_carefully(gpt_say_partial)
87
+ gpt_say += gpt_say_partial
88
+ mutable_return[index] = gpt_say
89
+ except ConnectionAbortedError as token_exceed_err:
90
+ print('至少一个线程任务Token溢出而失败', token_exceed_err)
91
+ except Exception as e:
92
+ print('至少一个线程任务意外失败', e)
93
 
94
+ # 第7步:所有线程同时开始执行任务函数
95
  handles = [threading.Thread(target=thread_worker, args=(fp,index)) for index, fp in enumerate(file_manifest)]
96
  for h in handles:
97
  h.daemon = True
 
99
  chatbot.append(('开始了吗?', f'多线程操作已经开始'))
100
  yield chatbot, history, '正常'
101
 
102
+ # 第8步:循环轮询各个线程是否执行完毕
103
  cnt = 0
104
  while True:
105
+ cnt += 1
106
+ time.sleep(0.2)
107
  th_alive = [h.is_alive() for h in handles]
108
  if not any(th_alive): break
109
+ # 更好的UI视觉效果
110
+ observe_win = []
111
+ for thread_index, alive in enumerate(th_alive):
112
+ observe_win.append("[ ..."+observe_window[thread_index][0][-60:].replace('\n','').replace('```','...').replace(' ','.').replace('<br/>','.....').replace('$','.')+"... ]")
113
+ stat = [f'执行中: {obs}\n\n' if alive else '已完成\n\n' for alive, obs in zip(th_alive, observe_win)]
114
+ stat_str = ''.join(stat)
115
+ chatbot[-1] = (chatbot[-1][0], f'多线程操作已经开始,完成情况: \n\n{stat_str}' + ''.join(['.']*(cnt%10+1)))
116
  yield chatbot, history, '正常'
117
 
118
+ # 第9步:把结果写入文件
119
  for index, h in enumerate(handles):
120
  h.join() # 这里其实不需要join了,肯定已经都结束了
121
  fp = file_manifest[index]
 
123
  i_say_show_user = i_say_show_user_buffer[index]
124
 
125
  where_to_relocate = f'gpt_log/generated_english_version/{fp}'
126
+ if gpt_say is not None:
127
+ with open(where_to_relocate, 'w+', encoding='utf-8') as f:
128
+ f.write(gpt_say)
129
+ else: # 失败
130
+ shutil.copyfile(file_manifest[index], where_to_relocate)
131
  chatbot.append((i_say_show_user, f'[Local Message] 已完成{os.path.abspath(fp)}的转化,\n\n存入{os.path.abspath(where_to_relocate)}'))
132
  history.append(i_say_show_user); history.append(gpt_say)
133
  yield chatbot, history, '正常'
134
  time.sleep(1)
135
 
136
+ # 第10步:备份一个文件
137
  res = write_results_to_file(history)
138
  chatbot.append(("生成一份任务执行报告", res))
139
  yield chatbot, history, '正常'
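For reference, a small sketch of what the two new helper functions above do in isolation; the input strings are fabricated and only illustrate the splitting behaviour:

```python
# Illustrative only: exercises extract_code_block_carefully / break_txt_into_half_at_some_linebreak as defined above.
from crazy_functions.代码重写为全英文_多线程 import extract_code_block_carefully, break_txt_into_half_at_some_linebreak

fence = '`' * 3
reply = f"以下是转化后的代码:\n{fence}\nprint('hello')\n{fence}\n以上。"
print(extract_code_block_carefully(reply))            # -> "\nprint('hello')\n", only the fenced part survives

long_src = "\n".join(f"line {i}" for i in range(10))
pre, post = break_txt_into_half_at_some_linebreak(long_src)
print(len(pre.splitlines()), len(post.splitlines()))  # 5 5 — split at the middle linebreak
```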
crazy_functions/总结word文档.py ADDED
@@ -0,0 +1,127 @@
1
+ from request_llm.bridge_chatgpt import predict_no_ui
2
+ from toolbox import CatchException, report_execption, write_results_to_file, predict_no_ui_but_counting_down
3
+ fast_debug = False
4
+
5
+
6
+ def 解析docx(file_manifest, project_folder, top_p, temperature, chatbot, history, systemPromptTxt):
7
+ import time, os
8
+ # pip install python-docx 用于docx格式,跨平台
9
+ # pip install pywin32 用于doc格式,仅支持Win平台
10
+
11
+ print('begin analysis on:', file_manifest)
12
+ for index, fp in enumerate(file_manifest):
13
+ if fp.split(".")[-1] == "docx":
14
+ from docx import Document
15
+ doc = Document(fp)
16
+ file_content = "\n".join([para.text for para in doc.paragraphs])
17
+ else:
18
+ import win32com.client
19
+ word = win32com.client.Dispatch("Word.Application")
20
+ word.visible = False
21
+ # 打开文件
22
+ print('fp', os.getcwd())
23
+ doc = word.Documents.Open(os.getcwd() + '/' + fp)
24
+ # file_content = doc.Content.Text
25
+ doc = word.ActiveDocument
26
+ file_content = doc.Range().Text
27
+ doc.Close()
28
+ word.Quit()
29
+
30
+ print(file_content)
31
+
32
+ prefix = "接下来请你逐文件分析下面的论文文件," if index == 0 else ""
33
+ # private_upload里面的文件名在解压zip后容易出现乱码(rar和7z格式正常),故可以只分析文章内容,不输入文件名
34
+ i_say = prefix + f'请对下面的文章片段用中英文做概述,文件名是{os.path.relpath(fp, project_folder)},' \
35
+ f'文章内容是 ```{file_content}```'
36
+ i_say_show_user = prefix + f'[{index+1}/{len(file_manifest)}] 假设你是论文审稿专家,请对下面的文章片段做概述: {os.path.abspath(fp)}'
37
+ chatbot.append((i_say_show_user, "[Local Message] waiting gpt response."))
38
+ yield chatbot, history, '正常'
39
+
40
+ if not fast_debug:
41
+ msg = '正常'
42
+ # ** gpt request **
43
+ gpt_say = yield from predict_no_ui_but_counting_down(i_say, i_say_show_user, chatbot, top_p, temperature,
44
+ history=[]) # 带超时倒计时
45
+ chatbot[-1] = (i_say_show_user, gpt_say)
46
+ history.append(i_say_show_user);
47
+ history.append(gpt_say)
48
+ yield chatbot, history, msg
49
+ if not fast_debug: time.sleep(2)
50
+
51
+ """
52
+ # 可按需启用
53
+ i_say = f'根据你上述的分析,对全文进行概括,用学术性语言写一段中文摘要,然后再写一篇英文的。'
54
+ chatbot.append((i_say, "[Local Message] waiting gpt response."))
55
+ yield chatbot, history, '正常'
56
+
57
+
58
+ i_say = f'我想让你做一个论文写作导师。您的任务是使用人工智能工具(例如自然语言处理)提供有关如何改进其上述文章的反馈。' \
59
+ f'您还应该利用您在有效写作技巧方面的修辞知识和经验来建议作者可以更好地以书面形式表达他们的想法和想法的方法。' \
60
+ f'根据你之前的分析,提出建议'
61
+ chatbot.append((i_say, "[Local Message] waiting gpt response."))
62
+ yield chatbot, history, '正常'
63
+
64
+ """
65
+
66
+ if not fast_debug:
67
+ msg = '正常'
68
+ # ** gpt request **
69
+ gpt_say = yield from predict_no_ui_but_counting_down(i_say, i_say, chatbot, top_p, temperature,
70
+ history=history) # 带超时倒计时
71
+
72
+ chatbot[-1] = (i_say, gpt_say)
73
+ history.append(i_say)
74
+ history.append(gpt_say)
75
+ yield chatbot, history, msg
76
+ res = write_results_to_file(history)
77
+ chatbot.append(("完成了吗?", res))
78
+ yield chatbot, history, msg
79
+
80
+
81
+ @CatchException
82
+ def 总结word文档(txt, top_p, temperature, chatbot, history, systemPromptTxt, WEB_PORT):
83
+ import glob, os
84
+
85
+ # 基本信息:功能、贡献者
86
+ chatbot.append([
87
+ "函数插件功能?",
88
+ "批量总结Word文档。函数插件贡献者: JasonGuo1"])
89
+ yield chatbot, history, '正常'
90
+
91
+ # 尝试导入依赖,如果缺少依赖,则给出安装建议
92
+ try:
93
+ from docx import Document
94
+ except:
95
+ report_execption(chatbot, history,
96
+ a=f"解析项目: {txt}",
97
+ b=f"导入软件依赖失败。使用该模块需要额外依赖,安装方法```pip install --upgrade python-docx pywin32```。")
98
+ yield chatbot, history, '正常'
99
+ return
100
+
101
+ # 清空历史,以免输入溢出
102
+ history = []
103
+
104
+ # 检测输入参数,如没有给定输入参数,直接退出
105
+ if os.path.exists(txt):
106
+ project_folder = txt
107
+ else:
108
+ if txt == "": txt = '空空如也的输入栏'
109
+ report_execption(chatbot, history, a=f"解析项目: {txt}", b=f"找不到本地项目或无权访问: {txt}")
110
+ yield chatbot, history, '正常'
111
+ return
112
+
113
+ # 搜索需要处理的文件清单
114
+ file_manifest = [f for f in glob.glob(f'{project_folder}/**/*.docx', recursive=True)] + \
115
+ [f for f in glob.glob(f'{project_folder}/**/*.doc', recursive=True)]
116
+ # [f for f in glob.glob(f'{project_folder}/**/*.tex', recursive=True)] + \
117
+ # [f for f in glob.glob(f'{project_folder}/**/*.cpp', recursive=True)] + \
118
+ # [f for f in glob.glob(f'{project_folder}/**/*.c', recursive=True)]
119
+
120
+ # 如果没找到任何文件
121
+ if len(file_manifest) == 0:
122
+ report_execption(chatbot, history, a=f"解析项目: {txt}", b=f"找不到任何.docx或doc文件: {txt}")
123
+ yield chatbot, history, '正常'
124
+ return
125
+
126
+ # 开始正式执行任务
127
+ yield from 解析docx(file_manifest, project_folder, top_p, temperature, chatbot, history, systemPromptTxt)
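The .docx branch above reduces to a single python-docx call; a minimal sketch (the file name is a placeholder) that can also be used to verify the dependency before invoking the plugin:

```python
# Sketch of the text extraction used by 解析docx above; "paper.docx" is a hypothetical path.
from docx import Document  # pip install python-docx

doc = Document("paper.docx")
file_content = "\n".join(para.text for para in doc.paragraphs)
print(file_content[:200])
```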
crazy_functions/批量总结PDF文档.py CHANGED
@@ -1,7 +1,61 @@
1
- from predict import predict_no_ui
2
  from toolbox import CatchException, report_execption, write_results_to_file, predict_no_ui_but_counting_down
 
 
3
  fast_debug = False
4
 
5
 
6
  def 解析PDF(file_manifest, project_folder, top_p, temperature, chatbot, history, systemPromptTxt):
7
  import time, glob, os, fitz
@@ -11,6 +65,7 @@ def 解析PDF(file_manifest, project_folder, top_p, temperature, chatbot, histor
11
  file_content = ""
12
  for page in doc:
13
  file_content += page.get_text()
 
14
  print(file_content)
15
 
16
  prefix = "接下来请你逐文件分析下面的论文文件,概括其内容" if index==0 else ""
@@ -58,7 +113,7 @@ def 批量总结PDF文档(txt, top_p, temperature, chatbot, history, systemPromp
58
  # 基本信息:功能、贡献者
59
  chatbot.append([
60
  "函数插件功能?",
61
- "批量总结PDF文档。函数插件贡献者: ValeriaWong"])
62
  yield chatbot, history, '正常'
63
 
64
  # 尝试导入依赖,如果缺少依赖,则给出安装建议
 
1
+ from request_llm.bridge_chatgpt import predict_no_ui
2
  from toolbox import CatchException, report_execption, write_results_to_file, predict_no_ui_but_counting_down
3
+ import re
4
+ import unicodedata
5
  fast_debug = False
6
 
7
+ def is_paragraph_break(match):
8
+ """
9
+ 根据给定的匹配结果来判断换行符是否表示段落分隔。
10
+ 如果换行符前为句子结束标志(句号,感叹号,问号),且下一个字符为大写字母,则换行符更有可能表示段落分隔。
11
+ 也可以根据之前的内容长度来判断段落是否已经足够长。
12
+ """
13
+ prev_char, next_char = match.groups()
14
+
15
+ # 句子结束标志
16
+ sentence_endings = ".!?"
17
+
18
+ # 设定一个最小段落长度阈值
19
+ min_paragraph_length = 140
20
+
21
+ if prev_char in sentence_endings and next_char.isupper() and len(match.string[:match.start(1)]) > min_paragraph_length:
22
+ return "\n\n"
23
+ else:
24
+ return " "
25
+
26
+ def normalize_text(text):
27
+ """
28
+ 通过把连字(ligatures)等文本特殊符号转换为其基本形式来对文本进行归一化处理。
29
+ 例如,将连字 "fi" 转换为 "f" 和 "i"。
30
+ """
31
+ # 对文本进行归一化处理,分解连字
32
+ normalized_text = unicodedata.normalize("NFKD", text)
33
+
34
+ # 替换其他特殊字符
35
+ cleaned_text = re.sub(r'[^\x00-\x7F]+', '', normalized_text)
36
+
37
+ return cleaned_text
38
+
39
+ def clean_text(raw_text):
40
+ """
41
+ 对从 PDF 提取出的原始文本进行清洗和格式化处理。
42
+ 1. 对原始文本进行归一化处理。
43
+ 2. 替换跨行的连词,例如 “Espe-\ncially” 转换为 “Especially”。
44
+ 3. 根据 heuristic 规则判断换行符是否是段落分隔,并相应地进行替换。
45
+ """
46
+ # 对文本进行归一化处理
47
+ normalized_text = normalize_text(raw_text)
48
+
49
+ # 替换跨行的连词
50
+ text = re.sub(r'(\w+-\n\w+)', lambda m: m.group(1).replace('-\n', ''), normalized_text)
51
+
52
+ # 根据前后相邻字符的特点,找到原文本中的换行符
53
+ newlines = re.compile(r'(\S)\n(\S)')
54
+
55
+ # 根据 heuristic 规则,用空格或段落分隔符替换原换行符
56
+ final_text = re.sub(newlines, lambda m: m.group(1) + is_paragraph_break(m) + m.group(2), text)
57
+
58
+ return final_text.strip()
59
 
60
  def 解析PDF(file_manifest, project_folder, top_p, temperature, chatbot, history, systemPromptTxt):
61
  import time, glob, os, fitz
 
65
  file_content = ""
66
  for page in doc:
67
  file_content += page.get_text()
68
+ file_content = clean_text(file_content)
69
  print(file_content)
70
 
71
  prefix = "接下来请你逐文件分析下面的论文文件,概括其内容" if index==0 else ""
 
113
  # 基本信息:功能、贡献者
114
  chatbot.append([
115
  "函数插件功能?",
116
+ "批量总结PDF文档。函数插件贡献者: ValeriaWong,Eralien"])
117
  yield chatbot, history, '正常'
118
 
119
  # 尝试导入依赖,如果缺少依赖,则给出安装建议
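A quick sketch of the new cleaning helpers on a fabricated PDF-style snippet: the hyphen split across a line break is rejoined, and because the text is far shorter than the 140-character paragraph threshold, the remaining newline is treated as a soft wrap rather than a paragraph break:

```python
# Illustrative input only; clean_text/is_paragraph_break/normalize_text are defined above.
from crazy_functions.批量总结PDF文档 import clean_text

raw = "Espe-\ncially for scanned PDFs.\nThis sentence still belongs to the same paragraph"
print(clean_text(raw))
# -> "Especially for scanned PDFs. This sentence still belongs to the same paragraph"
```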
crazy_functions/批量总结PDF文档pdfminer.py ADDED
@@ -0,0 +1,151 @@
1
+ from request_llm.bridge_chatgpt import predict_no_ui
2
+ from toolbox import CatchException, report_execption, write_results_to_file, predict_no_ui_but_counting_down
3
+
4
+ fast_debug = False
5
+
6
+ def readPdf(pdfPath):
7
+ """
8
+ 读取pdf文件,返回文本内容
9
+ """
10
+ import pdfminer
11
+ from pdfminer.pdfparser import PDFParser
12
+ from pdfminer.pdfdocument import PDFDocument
13
+ from pdfminer.pdfpage import PDFPage, PDFTextExtractionNotAllowed
14
+ from pdfminer.pdfinterp import PDFResourceManager, PDFPageInterpreter
15
+ from pdfminer.pdfdevice import PDFDevice
16
+ from pdfminer.layout import LAParams
17
+ from pdfminer.converter import PDFPageAggregator
18
+
19
+ fp = open(pdfPath, 'rb')
20
+
21
+ # Create a PDF parser object associated with the file object
22
+ parser = PDFParser(fp)
23
+
24
+ # Create a PDF document object that stores the document structure.
25
+ # Password for initialization as 2nd parameter
26
+ document = PDFDocument(parser)
27
+ # Check if the document allows text extraction. If not, abort.
28
+ if not document.is_extractable:
29
+ raise PDFTextExtractionNotAllowed
30
+
31
+ # Create a PDF resource manager object that stores shared resources.
32
+ rsrcmgr = PDFResourceManager()
33
+
34
+ # Create a PDF device object.
35
+ # device = PDFDevice(rsrcmgr)
36
+
37
+ # BEGIN LAYOUT ANALYSIS.
38
+ # Set parameters for analysis.
39
+ laparams = LAParams(
40
+ char_margin=10.0,
41
+ line_margin=0.2,
42
+ boxes_flow=0.2,
43
+ all_texts=False,
44
+ )
45
+ # Create a PDF page aggregator object.
46
+ device = PDFPageAggregator(rsrcmgr, laparams=laparams)
47
+ # Create a PDF interpreter object.
48
+ interpreter = PDFPageInterpreter(rsrcmgr, device)
49
+
50
+ # loop over all pages in the document
51
+ outTextList = []
52
+ for page in PDFPage.create_pages(document):
53
+ # read the page into a layout object
54
+ interpreter.process_page(page)
55
+ layout = device.get_result()
56
+ for obj in layout._objs:
57
+ if isinstance(obj, pdfminer.layout.LTTextBoxHorizontal):
58
+ # print(obj.get_text())
59
+ outTextList.append(obj.get_text())
60
+
61
+ return outTextList
62
+
63
+
64
+ def 解析Paper(file_manifest, project_folder, top_p, temperature, chatbot, history, systemPromptTxt):
65
+ import time, glob, os
66
+ from bs4 import BeautifulSoup
67
+ print('begin analysis on:', file_manifest)
68
+ for index, fp in enumerate(file_manifest):
69
+ if ".tex" in fp:
70
+ with open(fp, 'r', encoding='utf-8') as f:
71
+ file_content = f.read()
72
+ if ".pdf" in fp.lower():
73
+ file_content = readPdf(fp)
74
+ file_content = BeautifulSoup(''.join(file_content), features="lxml").body.text.encode('gbk', 'ignore').decode('gbk')
75
+
76
+ prefix = "接下来请你逐文件分析下面的论文文件,概括其内容" if index==0 else ""
77
+ i_say = prefix + f'请对下面的文章片段用中文做一个概述,文件名是{os.path.relpath(fp, project_folder)},文章内容是 ```{file_content}```'
78
+ i_say_show_user = prefix + f'[{index}/{len(file_manifest)}] 请对下面的文章片段做一个概述: {os.path.abspath(fp)}'
79
+ chatbot.append((i_say_show_user, "[Local Message] waiting gpt response."))
80
+ print('[1] yield chatbot, history')
81
+ yield chatbot, history, '正常'
82
+
83
+ if not fast_debug:
84
+ msg = '正常'
85
+ # ** gpt request **
86
+ gpt_say = yield from predict_no_ui_but_counting_down(i_say, i_say_show_user, chatbot, top_p, temperature, history=[]) # 带超时倒计时
87
+
88
+ print('[2] end gpt req')
89
+ chatbot[-1] = (i_say_show_user, gpt_say)
90
+ history.append(i_say_show_user); history.append(gpt_say)
91
+ print('[3] yield chatbot, history')
92
+ yield chatbot, history, msg
93
+ print('[4] next')
94
+ if not fast_debug: time.sleep(2)
95
+
96
+ all_file = ', '.join([os.path.relpath(fp, project_folder) for index, fp in enumerate(file_manifest)])
97
+ i_say = f'根据以上你自己的分析,对全文进行概括,用学术性语言写一段中文摘要,然后再写一段英文摘要(包括{all_file})。'
98
+ chatbot.append((i_say, "[Local Message] waiting gpt response."))
99
+ yield chatbot, history, '正常'
100
+
101
+ if not fast_debug:
102
+ msg = '正常'
103
+ # ** gpt request **
104
+ gpt_say = yield from predict_no_ui_but_counting_down(i_say, i_say, chatbot, top_p, temperature, history=history) # 带超时倒计时
105
+
106
+ chatbot[-1] = (i_say, gpt_say)
107
+ history.append(i_say); history.append(gpt_say)
108
+ yield chatbot, history, msg
109
+ res = write_results_to_file(history)
110
+ chatbot.append(("完成了吗?", res))
111
+ yield chatbot, history, msg
112
+
113
+
114
+
115
+ @CatchException
116
+ def 批量总结PDF文档pdfminer(txt, top_p, temperature, chatbot, history, systemPromptTxt, WEB_PORT):
117
+ history = [] # 清空历史,以免输入溢出
118
+ import glob, os
119
+
120
+ # 基本信息:功能、贡献者
121
+ chatbot.append([
122
+ "函数插件功能?",
123
+ "批量总结PDF文档,此版本使用pdfminer插件,带token约简功能。函数插件贡献者: Euclid-Jie。"])
124
+ yield chatbot, history, '正常'
125
+
126
+ # 尝试导入依赖,如果缺少依赖,则给出安装建议
127
+ try:
128
+ import pdfminer, bs4
129
+ except:
130
+ report_execption(chatbot, history,
131
+ a = f"解析项目: {txt}",
132
+ b = f"导入软件依赖失败。使用该模块需要额外依赖,安装方法```pip install --upgrade pdfminer beautifulsoup4```。")
133
+ yield chatbot, history, '正常'
134
+ return
135
+ if os.path.exists(txt):
136
+ project_folder = txt
137
+ else:
138
+ if txt == "": txt = '空空如也的输入栏'
139
+ report_execption(chatbot, history, a = f"解析项目: {txt}", b = f"找不到本地项目或无权访问: {txt}")
140
+ yield chatbot, history, '正常'
141
+ return
142
+ file_manifest = [f for f in glob.glob(f'{project_folder}/**/*.tex', recursive=True)] + \
143
+ [f for f in glob.glob(f'{project_folder}/**/*.pdf', recursive=True)] # + \
144
+ # [f for f in glob.glob(f'{project_folder}/**/*.cpp', recursive=True)] + \
145
+ # [f for f in glob.glob(f'{project_folder}/**/*.c', recursive=True)]
146
+ if len(file_manifest) == 0:
147
+ report_execption(chatbot, history, a = f"解析项目: {txt}", b = f"找不到任何.tex或pdf文件: {txt}")
148
+ yield chatbot, history, '正常'
149
+ return
150
+ yield from 解析Paper(file_manifest, project_folder, top_p, temperature, chatbot, history, systemPromptTxt)
151
+
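`readPdf` above returns one string per horizontal text box; a minimal standalone sketch (the file path is a placeholder) that can help check the pdfminer layout parameters before running the whole plugin:

```python
# Direct use of readPdf as defined above; "paper.pdf" is a hypothetical path.
from crazy_functions.批量总结PDF文档pdfminer import readPdf

boxes = readPdf("paper.pdf")       # one string per LTTextBoxHorizontal, in page order
full_text = "".join(boxes)
print(len(boxes), full_text[:200])
```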
crazy_functions/批量翻译PDF文档_多线程.py ADDED
@@ -0,0 +1,203 @@
1
+ from toolbox import CatchException, report_execption, write_results_to_file
2
+ from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
3
+ from .crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
4
+
5
+
6
+ def read_and_clean_pdf_text(fp):
7
+ """
8
+ **输入参数说明**
9
+ - `fp`:需要读取和清理文本的pdf文件路径
10
+
11
+ **输出参数说明**
12
+ - `meta_txt`:清理后的文本内容字符串
13
+ - `page_one_meta`:第一页清理后的文本内容列表
14
+
15
+ **函数功能**
16
+ 读取pdf文件并清理其中的文本内容,清理规则包括:
17
+ - 提取所有块元的文本信息,并合并为一个字符串
18
+ - 去除短块(字符数小于100)并替换为回车符
19
+ - 清理多余的空行
20
+ - 合并小写字母开头的段落块并替换为空格
21
+ - 清除重复的换行
22
+ - 将每个换行符替换为两个换行符,使每个段落之间有两个换行符分隔
23
+ """
24
+ import fitz
25
+ import re
26
+ import numpy as np
27
+ # file_content = ""
28
+ with fitz.open(fp) as doc:
29
+ meta_txt = []
30
+ meta_font = []
31
+ for index, page in enumerate(doc):
32
+ # file_content += page.get_text()
33
+ text_areas = page.get_text("dict") # 获取页面上的文本信息
34
+
35
+ # 块元提取 for each word segment with in line for each line cross-line words for each block
36
+ meta_txt.extend([" ".join(["".join([wtf['text'] for wtf in l['spans']]) for l in t['lines']]).replace(
37
+ '- ', '') for t in text_areas['blocks'] if 'lines' in t])
38
+ meta_font.extend([np.mean([np.mean([wtf['size'] for wtf in l['spans']])
39
+ for l in t['lines']]) for t in text_areas['blocks'] if 'lines' in t])
40
+ if index == 0:
41
+ page_one_meta = [" ".join(["".join([wtf['text'] for wtf in l['spans']]) for l in t['lines']]).replace(
42
+ '- ', '') for t in text_areas['blocks'] if 'lines' in t]
43
+
44
+ def 把字符太少的块清除为回车(meta_txt):
45
+ for index, block_txt in enumerate(meta_txt):
46
+ if len(block_txt) < 100:
47
+ meta_txt[index] = '\n'
48
+ return meta_txt
49
+ meta_txt = 把字符太少的块清除为回车(meta_txt)
50
+
51
+ def 清理多余的空行(meta_txt):
52
+ for index in reversed(range(1, len(meta_txt))):
53
+ if meta_txt[index] == '\n' and meta_txt[index-1] == '\n':
54
+ meta_txt.pop(index)
55
+ return meta_txt
56
+ meta_txt = 清理多余的空行(meta_txt)
57
+
58
+ def 合并小写开头的段落块(meta_txt):
59
+ def starts_with_lowercase_word(s):
60
+ pattern = r"^[a-z]+"
61
+ match = re.match(pattern, s)
62
+ if match:
63
+ return True
64
+ else:
65
+ return False
66
+ for _ in range(100):
67
+ for index, block_txt in enumerate(meta_txt):
68
+ if starts_with_lowercase_word(block_txt):
69
+ if meta_txt[index-1] != '\n':
70
+ meta_txt[index-1] += ' '
71
+ else:
72
+ meta_txt[index-1] = ''
73
+ meta_txt[index-1] += meta_txt[index]
74
+ meta_txt[index] = '\n'
75
+ return meta_txt
76
+ meta_txt = 合并小写开头的段落块(meta_txt)
77
+ meta_txt = 清理多余的空行(meta_txt)
78
+
79
+ meta_txt = '\n'.join(meta_txt)
80
+ # 清除重复的换行
81
+ for _ in range(5):
82
+ meta_txt = meta_txt.replace('\n\n', '\n')
83
+
84
+ # 换行 -> 双换行
85
+ meta_txt = meta_txt.replace('\n', '\n\n')
86
+
87
+ return meta_txt, page_one_meta
88
+
89
+
90
+ @CatchException
91
+ def 批量翻译PDF文档(txt, top_p, temperature, chatbot, history, sys_prompt, WEB_PORT):
92
+ import glob
93
+ import os
94
+
95
+ # 基本信息:功能、贡献者
96
+ chatbot.append([
97
+ "函数插件功能?",
98
+ "批量翻译PDF文档。函数插件贡献者: Binary-Husky(二进制哈士奇)"])
99
+ yield chatbot, history, '正常'
100
+
101
+ # 尝试导入依赖,如果缺少依赖,则给出安装建议
102
+ try:
103
+ import fitz
104
+ import tiktoken
105
+ except:
106
+ report_execption(chatbot, history,
107
+ a=f"解析项目: {txt}",
108
+ b=f"导入软件依赖失败。使用该模块需要额外依赖,安装方法```pip install --upgrade pymupdf tiktoken```。")
109
+ yield chatbot, history, '正常'
110
+ return
111
+
112
+ # 清空历史,以免输入溢出
113
+ history = []
114
+
115
+ # 检测输入参数,如没有给定输入参数,直接退出
116
+ if os.path.exists(txt):
117
+ project_folder = txt
118
+ else:
119
+ if txt == "":
120
+ txt = '空空如也的输入栏'
121
+ report_execption(chatbot, history,
122
+ a=f"解析项目: {txt}", b=f"找不到本地项目或无权访问: {txt}")
123
+ yield chatbot, history, '正常'
124
+ return
125
+
126
+ # 搜索需要处理的文件清单
127
+ file_manifest = [f for f in glob.glob(
128
+ f'{project_folder}/**/*.pdf', recursive=True)]
129
+
130
+ # 如果没找到任何文件
131
+ if len(file_manifest) == 0:
132
+ report_execption(chatbot, history,
133
+ a=f"解析项目: {txt}", b=f"找不到任何.pdf文件: {txt}")
134
+ yield chatbot, history, '正常'
135
+ return
136
+
137
+ # 开始正式执行任务
138
+ yield from 解析PDF(file_manifest, project_folder, top_p, temperature, chatbot, history, sys_prompt)
139
+
140
+
141
+ def 解析PDF(file_manifest, project_folder, top_p, temperature, chatbot, history, sys_prompt):
142
+ import os
143
+ import tiktoken
144
+ TOKEN_LIMIT_PER_FRAGMENT = 1600
145
+ generated_conclusion_files = []
146
+ for index, fp in enumerate(file_manifest):
147
+ # 读取PDF文件
148
+ file_content, page_one = read_and_clean_pdf_text(fp)
149
+ # 递归地切割PDF文件
150
+ from .crazy_utils import breakdown_txt_to_satisfy_token_limit_for_pdf
151
+ enc = tiktoken.get_encoding("gpt2")
152
+ def get_token_num(txt): return len(enc.encode(txt))
153
+ # 分解文本
154
+ paper_fragments = breakdown_txt_to_satisfy_token_limit_for_pdf(
155
+ txt=file_content, get_token_fn=get_token_num, limit=TOKEN_LIMIT_PER_FRAGMENT)
156
+ page_one_fragments = breakdown_txt_to_satisfy_token_limit_for_pdf(
157
+ txt=str(page_one), get_token_fn=get_token_num, limit=TOKEN_LIMIT_PER_FRAGMENT//4)
158
+ # 为了更好的效果,我们剥离Introduction之后的部分
159
+ paper_meta = page_one_fragments[0].split('introduction')[0].split(
160
+ 'Introduction')[0].split('INTRODUCTION')[0]
161
+ # 单线,获取文章meta信息
162
+ paper_meta_info = yield from request_gpt_model_in_new_thread_with_ui_alive(
163
+ inputs=f"以下是一篇学术论文的基础信息,请从中提取出“标题”、“收录会议或期刊”、“作者”、“摘要”、“编号”、“作者邮箱”这六个部分。请用markdown格式输出,最后用中文翻译摘要部分。请提取:{paper_meta}",
164
+ inputs_show_user=f"请从{fp}中提取出“标题”、“收录会议或期刊”等基本信息。",
165
+ top_p=top_p, temperature=temperature,
166
+ chatbot=chatbot, history=[],
167
+ sys_prompt="Your job is to collect information from materials。",
168
+ )
169
+ # 多线,翻译
170
+ gpt_response_collection = yield from request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
171
+ inputs_array=[
172
+ f"以下是你需要翻译的文章段落:\n{frag}" for frag in paper_fragments],
173
+ inputs_show_user_array=[f"" for _ in paper_fragments],
174
+ top_p=top_p, temperature=temperature,
175
+ chatbot=chatbot,
176
+ history_array=[[paper_meta] for _ in paper_fragments],
177
+ sys_prompt_array=[
178
+ "请你作为一个学术翻译,把整个段落翻译成中文,要求语言简洁,禁止重复输出原文。" for _ in paper_fragments],
179
+ max_workers=16 # OpenAI所允许的最大并行过载
180
+ )
181
+
182
+ final = ["", paper_meta_info + '\n\n---\n\n---\n\n---\n\n']
183
+ final.extend(gpt_response_collection)
184
+ create_report_file_name = f"{os.path.basename(fp)}.trans.md"
185
+ res = write_results_to_file(final, file_name=create_report_file_name)
186
+ generated_conclusion_files.append(
187
+ f'./gpt_log/{create_report_file_name}')
188
+ chatbot.append((f"{fp}完成了吗?", res))
189
+ msg = "完成"
190
+ yield chatbot, history, msg
191
+
192
+ # 准备文件的下载
193
+ import shutil
194
+ for pdf_path in generated_conclusion_files:
195
+ # 重命名文件
196
+ rename_file = f'./gpt_log/总结论文-{os.path.basename(pdf_path)}'
197
+ if os.path.exists(rename_file):
198
+ os.remove(rename_file)
199
+ shutil.copyfile(pdf_path, rename_file)
200
+ if os.path.exists(pdf_path):
201
+ os.remove(pdf_path)
202
+ chatbot.append(("给出输出文件清单", str(generated_conclusion_files)))
203
+ yield chatbot, history, msg
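上面的 `解析PDF` 先用 `read_and_clean_pdf_text` 清理文本,再用 tiktoken 统计 token 数,并通过 `breakdown_txt_to_satisfy_token_limit_for_pdf` 把全文切成不超过 `TOKEN_LIMIT_PER_FRAGMENT` 的片段。切分思路可以用下面的极简草稿理解(仅为示意,按段落贪心拼接,不是 crazy_utils 中的实际实现):

```python
# 示意:按 token 上限贪心切分文本(假设性草稿,非 breakdown_txt_to_satisfy_token_limit_for_pdf 的实际实现)
import tiktoken

enc = tiktoken.get_encoding("gpt2")
def get_token_num(txt): return len(enc.encode(txt))

def naive_breakdown(txt, limit=1600):
    fragments, current = [], ""
    for para in txt.split("\n\n"):
        # 拼上这一段会超限,就先收下当前片段,另起一段;单段超长的情况这里未处理,实际实现还需继续切割
        if current and get_token_num(current + "\n\n" + para) > limit:
            fragments.append(current); current = para
        else:
            current = para if not current else current + "\n\n" + para
    if current: fragments.append(current)
    return fragments
```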
crazy_functions/生成函数注释.py CHANGED
@@ -1,4 +1,4 @@
1
- from predict import predict_no_ui
2
  from toolbox import CatchException, report_execption, write_results_to_file, predict_no_ui_but_counting_down
3
  fast_debug = False
4
 
 
1
+ from request_llm.bridge_chatgpt import predict_no_ui
2
  from toolbox import CatchException, report_execption, write_results_to_file, predict_no_ui_but_counting_down
3
  fast_debug = False
4
 
crazy_functions/解析项目源代码.py CHANGED
@@ -1,4 +1,4 @@
1
- from predict import predict_no_ui
2
  from toolbox import CatchException, report_execption, write_results_to_file, predict_no_ui_but_counting_down
3
  fast_debug = False
4
 
@@ -50,7 +50,8 @@ def 解析源代码(file_manifest, project_folder, top_p, temperature, chatbot,
50
  def 解析项目本身(txt, top_p, temperature, chatbot, history, systemPromptTxt, WEB_PORT):
51
  history = [] # 清空历史,以免输入溢出
52
  import time, glob, os
53
- file_manifest = [f for f in glob.glob('*.py')]
 
54
  for index, fp in enumerate(file_manifest):
55
  # if 'test_project' in fp: continue
56
  with open(fp, 'r', encoding='utf-8') as f:
@@ -65,7 +66,7 @@ def 解析项目本身(txt, top_p, temperature, chatbot, history, systemPromptTx
65
  if not fast_debug:
66
  # ** gpt request **
67
  # gpt_say = predict_no_ui(inputs=i_say, top_p=top_p, temperature=temperature)
68
- gpt_say = yield from predict_no_ui_but_counting_down(i_say, i_say_show_user, chatbot, top_p, temperature, history=[]) # 带超时倒计时
69
 
70
  chatbot[-1] = (i_say_show_user, gpt_say)
71
  history.append(i_say_show_user); history.append(gpt_say)
@@ -79,7 +80,7 @@ def 解析项目本身(txt, top_p, temperature, chatbot, history, systemPromptTx
79
  if not fast_debug:
80
  # ** gpt request **
81
  # gpt_say = predict_no_ui(inputs=i_say, top_p=top_p, temperature=temperature, history=history)
82
- gpt_say = yield from predict_no_ui_but_counting_down(i_say, i_say, chatbot, top_p, temperature, history=history) # 带超时倒计时
83
 
84
  chatbot[-1] = (i_say, gpt_say)
85
  history.append(i_say); history.append(gpt_say)
@@ -118,8 +119,8 @@ def 解析一个C项目的头文件(txt, top_p, temperature, chatbot, history, s
118
  report_execption(chatbot, history, a = f"解析项目: {txt}", b = f"找不到本地项目或无权访问: {txt}")
119
  yield chatbot, history, '正常'
120
  return
121
- file_manifest = [f for f in glob.glob(f'{project_folder}/**/*.h', recursive=True)] # + \
122
- # [f for f in glob.glob(f'{project_folder}/**/*.cpp', recursive=True)] + \
123
  # [f for f in glob.glob(f'{project_folder}/**/*.c', recursive=True)]
124
  if len(file_manifest) == 0:
125
  report_execption(chatbot, history, a = f"解析项目: {txt}", b = f"找不到任何.h头文件: {txt}")
@@ -140,6 +141,7 @@ def 解析一个C项目(txt, top_p, temperature, chatbot, history, systemPromptT
140
  return
141
  file_manifest = [f for f in glob.glob(f'{project_folder}/**/*.h', recursive=True)] + \
142
  [f for f in glob.glob(f'{project_folder}/**/*.cpp', recursive=True)] + \
 
143
  [f for f in glob.glob(f'{project_folder}/**/*.c', recursive=True)]
144
  if len(file_manifest) == 0:
145
  report_execption(chatbot, history, a = f"解析项目: {txt}", b = f"找不到任何.h头文件: {txt}")
@@ -147,3 +149,66 @@ def 解析一个C项目(txt, top_p, temperature, chatbot, history, systemPromptT
147
  return
148
  yield from 解析源代码(file_manifest, project_folder, top_p, temperature, chatbot, history, systemPromptTxt)
149
 
1
+ from request_llm.bridge_chatgpt import predict_no_ui
2
  from toolbox import CatchException, report_execption, write_results_to_file, predict_no_ui_but_counting_down
3
  fast_debug = False
4
 
 
50
  def 解析项目本身(txt, top_p, temperature, chatbot, history, systemPromptTxt, WEB_PORT):
51
  history = [] # 清空历史,以免输入溢出
52
  import time, glob, os
53
+ file_manifest = [f for f in glob.glob('./*.py') if ('test_project' not in f) and ('gpt_log' not in f)] + \
54
+ [f for f in glob.glob('./crazy_functions/*.py') if ('test_project' not in f) and ('gpt_log' not in f)]
55
  for index, fp in enumerate(file_manifest):
56
  # if 'test_project' in fp: continue
57
  with open(fp, 'r', encoding='utf-8') as f:
 
66
  if not fast_debug:
67
  # ** gpt request **
68
  # gpt_say = predict_no_ui(inputs=i_say, top_p=top_p, temperature=temperature)
69
+ gpt_say = yield from predict_no_ui_but_counting_down(i_say, i_say_show_user, chatbot, top_p, temperature, history=[], long_connection=True) # 带超时倒计时
70
 
71
  chatbot[-1] = (i_say_show_user, gpt_say)
72
  history.append(i_say_show_user); history.append(gpt_say)
 
80
  if not fast_debug:
81
  # ** gpt request **
82
  # gpt_say = predict_no_ui(inputs=i_say, top_p=top_p, temperature=temperature, history=history)
83
+ gpt_say = yield from predict_no_ui_but_counting_down(i_say, i_say, chatbot, top_p, temperature, history=history, long_connection=True) # 带超时倒计时
84
 
85
  chatbot[-1] = (i_say, gpt_say)
86
  history.append(i_say); history.append(gpt_say)
 
119
  report_execption(chatbot, history, a = f"解析项目: {txt}", b = f"找不到本地项目或无权访问: {txt}")
120
  yield chatbot, history, '正常'
121
  return
122
+ file_manifest = [f for f in glob.glob(f'{project_folder}/**/*.h', recursive=True)] + \
123
+ [f for f in glob.glob(f'{project_folder}/**/*.hpp', recursive=True)] #+ \
124
  # [f for f in glob.glob(f'{project_folder}/**/*.c', recursive=True)]
125
  if len(file_manifest) == 0:
126
  report_execption(chatbot, history, a = f"解析项目: {txt}", b = f"找不到任何.h头文件: {txt}")
 
141
  return
142
  file_manifest = [f for f in glob.glob(f'{project_folder}/**/*.h', recursive=True)] + \
143
  [f for f in glob.glob(f'{project_folder}/**/*.cpp', recursive=True)] + \
144
+ [f for f in glob.glob(f'{project_folder}/**/*.hpp', recursive=True)] + \
145
  [f for f in glob.glob(f'{project_folder}/**/*.c', recursive=True)]
146
  if len(file_manifest) == 0:
147
  report_execption(chatbot, history, a = f"解析项目: {txt}", b = f"找不到任何.h头文件: {txt}")
 
149
  return
150
  yield from 解析源代码(file_manifest, project_folder, top_p, temperature, chatbot, history, systemPromptTxt)
151
 
152
+
153
+ @CatchException
154
+ def 解析一个Java项目(txt, top_p, temperature, chatbot, history, systemPromptTxt, WEB_PORT):
155
+ history = [] # 清空历史,以免输入溢出
156
+ import glob, os
157
+ if os.path.exists(txt):
158
+ project_folder = txt
159
+ else:
160
+ if txt == "": txt = '空空如也的输入栏'
161
+ report_execption(chatbot, history, a=f"解析项目: {txt}", b=f"找不到本地项目或无权访问: {txt}")
162
+ yield chatbot, history, '正常'
163
+ return
164
+ file_manifest = [f for f in glob.glob(f'{project_folder}/**/*.java', recursive=True)] + \
165
+ [f for f in glob.glob(f'{project_folder}/**/*.jar', recursive=True)] + \
166
+ [f for f in glob.glob(f'{project_folder}/**/*.xml', recursive=True)] + \
167
+ [f for f in glob.glob(f'{project_folder}/**/*.sh', recursive=True)]
168
+ if len(file_manifest) == 0:
169
+ report_execption(chatbot, history, a=f"解析项目: {txt}", b=f"找不到任何java文件: {txt}")
170
+ yield chatbot, history, '正常'
171
+ return
172
+ yield from 解析源代码(file_manifest, project_folder, top_p, temperature, chatbot, history, systemPromptTxt)
173
+
174
+
175
+ @CatchException
176
+ def 解析一个Rect项目(txt, top_p, temperature, chatbot, history, systemPromptTxt, WEB_PORT):
177
+ history = [] # 清空历史,以免输入溢出
178
+ import glob, os
179
+ if os.path.exists(txt):
180
+ project_folder = txt
181
+ else:
182
+ if txt == "": txt = '空空如也的输入栏'
183
+ report_execption(chatbot, history, a=f"解析项目: {txt}", b=f"找不到本地项目或无权访问: {txt}")
184
+ yield chatbot, history, '正常'
185
+ return
186
+ file_manifest = [f for f in glob.glob(f'{project_folder}/**/*.ts', recursive=True)] + \
187
+ [f for f in glob.glob(f'{project_folder}/**/*.tsx', recursive=True)] + \
188
+ [f for f in glob.glob(f'{project_folder}/**/*.json', recursive=True)] + \
189
+ [f for f in glob.glob(f'{project_folder}/**/*.js', recursive=True)] + \
190
+ [f for f in glob.glob(f'{project_folder}/**/*.jsx', recursive=True)]
191
+ if len(file_manifest) == 0:
192
+ report_execption(chatbot, history, a=f"解析项目: {txt}", b=f"找不到任何Rect文件: {txt}")
193
+ yield chatbot, history, '正常'
194
+ return
195
+ yield from 解析源代码(file_manifest, project_folder, top_p, temperature, chatbot, history, systemPromptTxt)
196
+
197
+
198
+ @CatchException
199
+ def 解析一个Golang项目(txt, top_p, temperature, chatbot, history, systemPromptTxt, WEB_PORT):
200
+ history = [] # 清空历史,以免输入溢出
201
+ import glob, os
202
+ if os.path.exists(txt):
203
+ project_folder = txt
204
+ else:
205
+ if txt == "": txt = '空空如也的输入栏'
206
+ report_execption(chatbot, history, a=f"解析项目: {txt}", b=f"找不到本地项目或无权访问: {txt}")
207
+ yield chatbot, history, '正常'
208
+ return
209
+ file_manifest = [f for f in glob.glob(f'{project_folder}/**/*.go', recursive=True)]
210
+ if len(file_manifest) == 0:
211
+ report_execption(chatbot, history, a=f"解析项目: {txt}", b=f"找不到任何golang文件: {txt}")
212
+ yield chatbot, history, '正常'
213
+ return
214
+ yield from 解析源代码(file_manifest, project_folder, top_p, temperature, chatbot, history, systemPromptTxt)
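以上 Java/Rect/Golang 几个函数遵循同一套路:校验输入路径、用 glob 递归收集目标后缀的文件清单、清单为空则报错、否则交给 `解析源代码`。若要支持一种新语言,照抄该骨架即可,下面以一个假设性的 Rust 插件为例(函数名与后缀均为示意,仓库中并无此函数):

```python
# 示意:按同样骨架新增一个假设性的 Rust 项目解析插件(非本仓库现有代码)
@CatchException
def 解析一个Rust项目(txt, top_p, temperature, chatbot, history, systemPromptTxt, WEB_PORT):
    history = []  # 清空历史,以免输入溢出
    import glob, os
    if os.path.exists(txt):
        project_folder = txt
    else:
        if txt == "": txt = '空空如也的输入栏'
        report_execption(chatbot, history, a=f"解析项目: {txt}", b=f"找不到本地项目或无权访问: {txt}")
        yield chatbot, history, '正常'
        return
    file_manifest = [f for f in glob.glob(f'{project_folder}/**/*.rs', recursive=True)] + \
                    [f for f in glob.glob(f'{project_folder}/**/*.toml', recursive=True)]
    if len(file_manifest) == 0:
        report_execption(chatbot, history, a=f"解析项目: {txt}", b=f"找不到任何rust文件: {txt}")
        yield chatbot, history, '正常'
        return
    yield from 解析源代码(file_manifest, project_folder, top_p, temperature, chatbot, history, systemPromptTxt)
```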
crazy_functions/读文章写摘要.py CHANGED
@@ -1,4 +1,4 @@
1
- from predict import predict_no_ui
2
  from toolbox import CatchException, report_execption, write_results_to_file, predict_no_ui_but_counting_down
3
  fast_debug = False
4
 
 
1
+ from request_llm.bridge_chatgpt import predict_no_ui
2
  from toolbox import CatchException, report_execption, write_results_to_file, predict_no_ui_but_counting_down
3
  fast_debug = False
4
 
crazy_functions/高级功能函数模板.py CHANGED
@@ -1,25 +1,20 @@
1
- from predict import predict_no_ui_long_connection
2
- from toolbox import CatchException, report_execption, write_results_to_file
3
  import datetime
4
-
5
  @CatchException
6
  def 高阶功能模板函数(txt, top_p, temperature, chatbot, history, systemPromptTxt, WEB_PORT):
7
  history = [] # 清空历史,以免输入溢出
8
- chatbot.append(("这是什么功能?", "[Local Message] 请注意,您正在调用一个函数模板,该函数面向希望实现更多有趣功能的开发者,它可以作为创建新功能函数的模板。为了做到简单易读,该函数只有25行代码,不会实时反馈文字流或心跳,请耐心等待程序输出完成。另外您若希望分享新的功能模组,请不吝PR!"))
9
  yield chatbot, history, '正常' # 由于请求gpt需要一段时间,我们先及时地做一次状态显示
10
-
11
  for i in range(5):
12
  currentMonth = (datetime.date.today() + datetime.timedelta(days=i)).month
13
  currentDay = (datetime.date.today() + datetime.timedelta(days=i)).day
14
- i_say = f'历史中哪些事件发生在{currentMonth}月{currentDay}日?列举两条并发送相关图片。发送图片时,请使用Markdown,将Unsplash API中的PUT_YOUR_QUERY_HERE替换成描述改事件的三个最重要的单词。'
15
- chatbot.append((i_say, "[Local Message] waiting gpt response."))
16
- yield chatbot, history, '正常' # 由于请求gpt需要一段时间,我们先及时地做一次状态显示
17
-
18
- # history = [] 每次询问不携带之前的询问历史
19
- gpt_say = predict_no_ui_long_connection(
20
- inputs=i_say, top_p=top_p, temperature=temperature, history=[],
21
- sys_prompt="当你想发送一张照片时,请使用Markdown, 并且不要有反斜线, 不要用代码块。使用 Unsplash API (https://source.unsplash.com/1280x720/? < PUT_YOUR_QUERY_HERE >)。") # 请求gpt,需要一段时间
22
-
23
  chatbot[-1] = (i_say, gpt_say)
24
  history.append(i_say);history.append(gpt_say)
25
- yield chatbot, history, '正常' # 显示
 
1
+ from toolbox import CatchException
2
+ from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
3
  import datetime
 
4
  @CatchException
5
  def 高阶功能模板函数(txt, top_p, temperature, chatbot, history, systemPromptTxt, WEB_PORT):
6
  history = [] # 清空历史,以免输入溢出
7
+ chatbot.append(("这是什么功能?", "[Local Message] 请注意,您正在调用一个[函数插件]的模板,该函数面向希望实现更多有趣功能的开发者,它可以作为创建新功能函数的模板(该函数只有20行代码)。此外我们也提供可同步处理大量文件的多线程Demo供您参考。您若希望分享新的功能模组,请不吝PR!"))
8
  yield chatbot, history, '正常' # 由于请求gpt需要一段时间,我们先及时地做一次状态显示
 
9
  for i in range(5):
10
  currentMonth = (datetime.date.today() + datetime.timedelta(days=i)).month
11
  currentDay = (datetime.date.today() + datetime.timedelta(days=i)).day
12
+ i_say = f'历史中哪些事件发生在{currentMonth}月{currentDay}日?列举两条并发送相关图片。发送图片时,请使用Markdown,将Unsplash API中的PUT_YOUR_QUERY_HERE替换成描述该事件的一个最重要的单词。'
13
+ gpt_say = yield from request_gpt_model_in_new_thread_with_ui_alive(
14
+ inputs=i_say, inputs_show_user=i_say,
15
+ top_p=top_p, temperature=temperature, chatbot=chatbot, history=[],
16
+ sys_prompt="当你想发送一张照片时,请使用Markdown, 并且不要有反斜线, 不要用代码块。使用 Unsplash API (https://source.unsplash.com/1280x720/? < PUT_YOUR_QUERY_HERE >)。"
17
+ )
 
 
 
18
  chatbot[-1] = (i_say, gpt_say)
19
  history.append(i_say);history.append(gpt_say)
20
+ yield chatbot, history, '正常'
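新版模板改用 `request_gpt_model_in_new_thread_with_ui_alive`,从函数名推断,它负责在新线程中发起请求并保持界面存活。参考该模板编写新插件时,大致骨架如下(示意草稿,插件名与提问内容均为假设):

```python
# 示意:基于高阶功能模板的最小插件骨架(假设性示例)
from toolbox import CatchException
from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive

@CatchException
def 示例插件(txt, top_p, temperature, chatbot, history, systemPromptTxt, WEB_PORT):
    history = []  # 清空历史,以免输入溢出
    chatbot.append(("这是什么功能?", "[Local Message] 一个演示用的函数插件骨架。"))
    yield chatbot, history, '正常'  # 先刷新一次界面
    i_say = f"请用一句话总结以下内容:{txt}"
    gpt_say = yield from request_gpt_model_in_new_thread_with_ui_alive(
        inputs=i_say, inputs_show_user=i_say,
        top_p=top_p, temperature=temperature, chatbot=chatbot, history=[],
        sys_prompt="你是一个简洁的摘要助手。")
    chatbot[-1] = (i_say, gpt_say)
    history.append(i_say); history.append(gpt_say)
    yield chatbot, history, '正常'
```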
functional_crazy.py DELETED
@@ -1,66 +0,0 @@
1
- # UserVisibleLevel是过滤器参数。
2
- # 由于UI界面空间有限,所以通过这种方式决定UI界面中显示哪些插件
3
- # 默认函数插件 VisibleLevel 是 0
4
- # 当 UserVisibleLevel >= 函数插件的 VisibleLevel 时,该函数插件才会被显示出来
5
- UserVisibleLevel = 1
6
-
7
- def get_crazy_functionals():
8
- from crazy_functions.读文章写摘要 import 读文章写摘要
9
- from crazy_functions.生成函数注释 import 批量生成函数注释
10
- from crazy_functions.解析项目源代码 import 解析项目本身
11
- from crazy_functions.解析项目源代码 import 解析一个Python项目
12
- from crazy_functions.解析项目源代码 import 解析一个C项目的头文件
13
- from crazy_functions.解析项目源代码 import 解析一个C项目
14
- from crazy_functions.高级功能函数模板 import 高阶功能模板函数
15
- from crazy_functions.代码重写为全英文_多线程 import 全项目切换英文
16
-
17
- function_plugins = {
18
- "请解析并解构此项目本身": {
19
- "Function": 解析项目本身
20
- },
21
- "解析整个py项目": {
22
- "Color": "stop", # 按钮颜色
23
- "Function": 解析一个Python项目
24
- },
25
- "解析整个C++项目头文件": {
26
- "Color": "stop", # 按钮颜色
27
- "Function": 解析一个C项目的头文件
28
- },
29
- "解析整个C++项目": {
30
- "Color": "stop", # 按钮颜色
31
- "Function": 解析一个C项目
32
- },
33
- "读tex论文写摘要": {
34
- "Color": "stop", # 按钮颜色
35
- "Function": 读文章写摘要
36
- },
37
- "批量生成函数注释": {
38
- "Color": "stop", # 按钮颜色
39
- "Function": 批量生成函数注释
40
- },
41
- "[多线程demo] 把本项目源代码切换成全英文": {
42
- "Function": 全项目切换英文
43
- },
44
- "[函数插件模板demo] 历史上的今天": {
45
- "Function": 高阶功能模板函数
46
- },
47
- }
48
-
49
- # VisibleLevel=1 经过测试,但功能未达到理想状态
50
- if UserVisibleLevel >= 1:
51
- from crazy_functions.批量总结PDF文档 import 批量总结PDF文档
52
- function_plugins.update({
53
- "[仅供开发调试] 批量总结PDF文档": {
54
- "Color": "stop",
55
- "Function": 批量总结PDF文档
56
- },
57
- })
58
-
59
- # VisibleLevel=2 尚未充分测试的函数插件,放在这里
60
- if UserVisibleLevel >= 2:
61
- function_plugins.update({
62
- })
63
-
64
- return function_plugins
65
-
66
-
main.py CHANGED
@@ -1,112 +1,173 @@
1
  import os; os.environ['no_proxy'] = '*' # 避免代理网络产生意外污染
2
  import gradio as gr
3
- from predict import predict
4
- from toolbox import format_io, find_free_port, on_file_uploaded, on_report_generated, get_conf
5
 
6
  # 建议您复制一个config_private.py放自己的秘密, 如API和代理网址, 避免不小心传github被别人看到
7
- proxies, WEB_PORT, LLM_MODEL, CONCURRENT_COUNT, AUTHENTICATION = \
8
- get_conf('proxies', 'WEB_PORT', 'LLM_MODEL', 'CONCURRENT_COUNT', 'AUTHENTICATION')
9
-
10
 
11
  # 如果WEB_PORT是-1, 则随机选取WEB端口
12
  PORT = find_free_port() if WEB_PORT <= 0 else WEB_PORT
13
- AUTHENTICATION = None if AUTHENTICATION == [] else AUTHENTICATION
14
 
15
  initial_prompt = "Serve me as a writing and programming assistant."
16
- title_html = """<h1 align="center">ChatGPT 学术优化</h1>"""
 
17
 
18
  # 问询记录, python 版本建议3.9+(越新越好)
19
  import logging
20
- os.makedirs('gpt_log', exist_ok=True)
21
- try:logging.basicConfig(filename='gpt_log/chat_secrets.log', level=logging.INFO, encoding='utf-8')
22
- except:logging.basicConfig(filename='gpt_log/chat_secrets.log', level=logging.INFO)
23
- print('所有问询记录将自动保存在本地目录./gpt_log/chat_secrets.log, 请注意自我隐私保护哦!')
24
 
25
  # 一些普通功能模块
26
- from functional import get_functionals
27
- functional = get_functionals()
28
 
29
- # 对一些丧心病狂的实验性功能模块进行测试
30
- from functional_crazy import get_crazy_functionals
31
- crazy_functional = get_crazy_functionals()
32
 
33
  # 处理markdown文本格式的转变
34
  gr.Chatbot.postprocess = format_io
35
 
36
  # 做一些外观色彩上的调整
37
- from theme import adjust_theme
38
  set_theme = adjust_theme()
39
 
 
 
 
 
 
 
 
 
 
 
 
40
  cancel_handles = []
41
- with gr.Blocks(theme=set_theme, analytics_enabled=False) as demo:
42
  gr.HTML(title_html)
43
- with gr.Row():
44
- with gr.Column(scale=2):
45
  chatbot = gr.Chatbot()
46
- chatbot.style(height=1150)
47
- chatbot.style()
48
  history = gr.State([])
49
- with gr.Column(scale=1):
50
- with gr.Row():
51
- txt = gr.Textbox(show_label=False, placeholder="Input question here.").style(container=False)
52
- with gr.Row():
53
- submitBtn = gr.Button("提交", variant="primary")
54
- with gr.Row():
55
- resetBtn = gr.Button("重置", variant="secondary"); resetBtn.style(size="sm")
56
- stopBtn = gr.Button("停止", variant="secondary"); stopBtn.style(size="sm")
57
- with gr.Row():
58
- from check_proxy import check_proxy
59
- statusDisplay = gr.Markdown(f"Tip: 按Enter提交, 按Shift+Enter换行。当前模型: {LLM_MODEL} \n {check_proxy(proxies)}")
60
- with gr.Row():
61
- for k in functional:
62
- variant = functional[k]["Color"] if "Color" in functional[k] else "secondary"
63
- functional[k]["Button"] = gr.Button(k, variant=variant)
64
- with gr.Row():
65
- gr.Markdown("注意:以下“红颜色”标识的函数插件需从input区读取路径作为参数.")
66
- with gr.Row():
67
- for k in crazy_functional:
68
- variant = crazy_functional[k]["Color"] if "Color" in crazy_functional[k] else "secondary"
69
- crazy_functional[k]["Button"] = gr.Button(k, variant=variant)
70
- with gr.Row():
71
- gr.Markdown("上传本地文件,供上面的函数插件调用.")
72
- with gr.Row():
73
- file_upload = gr.Files(label='任何文件, 但推荐上传压缩文件(zip, tar)', file_count="multiple")
74
- system_prompt = gr.Textbox(show_label=True, placeholder=f"System Prompt", label="System prompt", value=initial_prompt).style(container=True)
75
- with gr.Accordion("arguments", open=False):
 
 
 
 
 
 
 
 
 
 
76
  top_p = gr.Slider(minimum=-0, maximum=1.0, value=1.0, step=0.01,interactive=True, label="Top-p (nucleus sampling)",)
77
  temperature = gr.Slider(minimum=-0, maximum=2.0, value=1.0, step=0.01, interactive=True, label="Temperature",)
78
-
79
- predict_args = dict(fn=predict, inputs=[txt, top_p, temperature, chatbot, history, system_prompt], outputs=[chatbot, history, statusDisplay], show_progress=True)
80
- empty_txt_args = dict(fn=lambda: "", inputs=[], outputs=[txt]) # 用于在提交后清空输入栏
81
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  cancel_handles.append(txt.submit(**predict_args))
83
- # txt.submit(**empty_txt_args) 在提交后清空输入栏
84
  cancel_handles.append(submitBtn.click(**predict_args))
85
- # submitBtn.click(**empty_txt_args) 在提交后清空输入栏
86
- resetBtn.click(lambda: ([], [], "已重置"), None, [chatbot, history, statusDisplay])
 
 
87
  for k in functional:
88
- click_handle = functional[k]["Button"].click(predict,
89
- [txt, top_p, temperature, chatbot, history, system_prompt, gr.State(True), gr.State(k)], [chatbot, history, statusDisplay], show_progress=True)
90
  cancel_handles.append(click_handle)
 
91
  file_upload.upload(on_file_uploaded, [file_upload, chatbot, txt], [chatbot, txt])
92
- for k in crazy_functional:
93
- click_handle = crazy_functional[k]["Button"].click(crazy_functional[k]["Function"],
94
- [txt, top_p, temperature, chatbot, history, system_prompt, gr.State(PORT)], [chatbot, history, statusDisplay]
95
- )
96
- try: click_handle.then(on_report_generated, [file_upload, chatbot], [file_upload, chatbot])
97
- except: pass
98
  cancel_handles.append(click_handle)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  stopBtn.click(fn=None, inputs=None, outputs=None, cancels=cancel_handles)
100
-
101
  # gradio的inbrowser触发不太稳定,回滚代码到原始的浏览器打开函数
102
  def auto_opentab_delay():
103
  import threading, webbrowser, time
104
- print(f"URL http://localhost:{PORT}")
 
 
105
  def open():
106
  time.sleep(2)
107
- webbrowser.open_new_tab(f'http://localhost:{PORT}')
108
- t = threading.Thread(target=open)
109
- t.daemon = True; t.start()
 
110
 
111
  auto_opentab_delay()
112
  demo.title = "ChatGPT 学术优化"
 
1
  import os; os.environ['no_proxy'] = '*' # 避免代理网络产生意外污染
2
  import gradio as gr
3
+ from request_llm.bridge_chatgpt import predict
4
+ from toolbox import format_io, find_free_port, on_file_uploaded, on_report_generated, get_conf, ArgsGeneralWrapper, DummyWith
5
 
6
  # 建议您复制一个config_private.py放自己的秘密, 如API和代理网址, 避免不小心传github被别人看到
7
+ proxies, WEB_PORT, LLM_MODEL, CONCURRENT_COUNT, AUTHENTICATION, CHATBOT_HEIGHT, LAYOUT = \
8
+ get_conf('proxies', 'WEB_PORT', 'LLM_MODEL', 'CONCURRENT_COUNT', 'AUTHENTICATION', 'CHATBOT_HEIGHT', 'LAYOUT')
 
9
 
10
  # 如果WEB_PORT是-1, 则随机选取WEB端口
11
  PORT = find_free_port() if WEB_PORT <= 0 else WEB_PORT
12
+ if not AUTHENTICATION: AUTHENTICATION = None
13
 
14
  initial_prompt = "Serve me as a writing and programming assistant."
15
+ title_html = "<h1 align=\"center\">ChatGPT 学术优化</h1>"
16
+ description = """代码开源和更新[地址🚀](https://github.com/binary-husky/chatgpt_academic),感谢热情的[开发者们❤️](https://github.com/binary-husky/chatgpt_academic/graphs/contributors)"""
17
 
18
  # 问询记录, python 版本建议3.9+(越新越好)
19
  import logging
20
+ os.makedirs("gpt_log", exist_ok=True)
21
+ try:logging.basicConfig(filename="gpt_log/chat_secrets.log", level=logging.INFO, encoding="utf-8")
22
+ except:logging.basicConfig(filename="gpt_log/chat_secrets.log", level=logging.INFO)
23
+ print("所有问询记录将自动保存在本地目录./gpt_log/chat_secrets.log, 请注意自我隐私保护哦!")
24
 
25
  # 一些普通功能模块
26
+ from core_functional import get_core_functions
27
+ functional = get_core_functions()
28
 
29
+ # 高级函数插件
30
+ from crazy_functional import get_crazy_functions
31
+ crazy_fns = get_crazy_functions()
32
 
33
  # 处理markdown文本格式的转变
34
  gr.Chatbot.postprocess = format_io
35
 
36
  # 做一些外观色彩上的调整
37
+ from theme import adjust_theme, advanced_css
38
  set_theme = adjust_theme()
39
 
40
+ # 代理与自动更新
41
+ from check_proxy import check_proxy, auto_update
42
+ proxy_info = check_proxy(proxies)
43
+
44
+ gr_L1 = lambda: gr.Row().style()
45
+ gr_L2 = lambda scale: gr.Column(scale=scale)
46
+ if LAYOUT == "TOP-DOWN":
47
+ gr_L1 = lambda: DummyWith()
48
+ gr_L2 = lambda scale: gr.Row()
49
+ CHATBOT_HEIGHT /= 2
50
+
51
  cancel_handles = []
52
+ with gr.Blocks(theme=set_theme, analytics_enabled=False, css=advanced_css) as demo:
53
  gr.HTML(title_html)
54
+ with gr_L1():
55
+ with gr_L2(scale=2):
56
  chatbot = gr.Chatbot()
57
+ chatbot.style(height=CHATBOT_HEIGHT)
 
58
  history = gr.State([])
59
+ with gr_L2(scale=1):
60
+ with gr.Accordion("输入区", open=True) as area_input_primary:
61
+ with gr.Row():
62
+ txt = gr.Textbox(show_label=False, placeholder="Input question here.").style(container=False)
63
+ with gr.Row():
64
+ submitBtn = gr.Button("提交", variant="primary")
65
+ with gr.Row():
66
+ resetBtn = gr.Button("重置", variant="secondary"); resetBtn.style(size="sm")
67
+ stopBtn = gr.Button("停止", variant="secondary"); stopBtn.style(size="sm")
68
+ with gr.Row():
69
+ status = gr.Markdown(f"Tip: 按Enter提交, 按Shift+Enter换行。当前模型: {LLM_MODEL} \n {proxy_info}")
70
+ with gr.Accordion("基础功能区", open=True) as area_basic_fn:
71
+ with gr.Row():
72
+ for k in functional:
73
+ variant = functional[k]["Color"] if "Color" in functional[k] else "secondary"
74
+ functional[k]["Button"] = gr.Button(k, variant=variant)
75
+ with gr.Accordion("函数插件区", open=True) as area_crazy_fn:
76
+ with gr.Row():
77
+ gr.Markdown("注意:以下“红颜色”标识的函数插件需从输入区读取路径作为参数.")
78
+ with gr.Row():
79
+ for k in crazy_fns:
80
+ if not crazy_fns[k].get("AsButton", True): continue
81
+ variant = crazy_fns[k]["Color"] if "Color" in crazy_fns[k] else "secondary"
82
+ crazy_fns[k]["Button"] = gr.Button(k, variant=variant)
83
+ crazy_fns[k]["Button"].style(size="sm")
84
+ with gr.Row():
85
+ with gr.Accordion("更多函数插件", open=True):
86
+ dropdown_fn_list = [k for k in crazy_fns.keys() if not crazy_fns[k].get("AsButton", True)]
87
+ with gr.Column(scale=1):
88
+ dropdown = gr.Dropdown(dropdown_fn_list, value=r"打开插件列表", label="").style(container=False)
89
+ with gr.Column(scale=1):
90
+ switchy_bt = gr.Button(r"请先从插件列表中选择", variant="secondary")
91
+ with gr.Row():
92
+ with gr.Accordion("点击展开“文件上传区”。上传本地文件可供红色函数插件调用。", open=False) as area_file_up:
93
+ file_upload = gr.Files(label="任何文件, 但推荐上传压缩文件(zip, tar)", file_count="multiple")
94
+ with gr.Accordion("展开SysPrompt & 交互界面布局 & Github地址", open=(LAYOUT == "TOP-DOWN")):
95
+ system_prompt = gr.Textbox(show_label=True, placeholder=f"System Prompt", label="System prompt", value=initial_prompt)
96
  top_p = gr.Slider(minimum=-0, maximum=1.0, value=1.0, step=0.01,interactive=True, label="Top-p (nucleus sampling)",)
97
  temperature = gr.Slider(minimum=-0, maximum=2.0, value=1.0, step=0.01, interactive=True, label="Temperature",)
98
+ checkboxes = gr.CheckboxGroup(["基础功能区", "函数插件区", "底部输入区"], value=["基础功能区", "函数插件区"], label="显示/隐藏功能区")
99
+ gr.Markdown(description)
100
+ with gr.Accordion("备选输入区", open=True, visible=False) as area_input_secondary:
101
+ with gr.Row():
102
+ txt2 = gr.Textbox(show_label=False, placeholder="Input question here.", label="输入区2").style(container=False)
103
+ with gr.Row():
104
+ submitBtn2 = gr.Button("提交", variant="primary")
105
+ with gr.Row():
106
+ resetBtn2 = gr.Button("重置", variant="secondary"); resetBtn.style(size="sm")
107
+ stopBtn2 = gr.Button("停止", variant="secondary"); stopBtn.style(size="sm")
108
+ # 功能区显示开关与功能区的互动
109
+ def fn_area_visibility(a):
110
+ ret = {}
111
+ ret.update({area_basic_fn: gr.update(visible=("基础功能区" in a))})
112
+ ret.update({area_crazy_fn: gr.update(visible=("函数插件区" in a))})
113
+ ret.update({area_input_primary: gr.update(visible=("底部输入区" not in a))})
114
+ ret.update({area_input_secondary: gr.update(visible=("底部输入区" in a))})
115
+ if "底部输入区" in a: ret.update({txt: gr.update(value="")})
116
+ return ret
117
+ checkboxes.select(fn_area_visibility, [checkboxes], [area_basic_fn, area_crazy_fn, area_input_primary, area_input_secondary, txt, txt2] )
118
+ # 整理反复出现的控件句柄组合
119
+ input_combo = [txt, txt2, top_p, temperature, chatbot, history, system_prompt]
120
+ output_combo = [chatbot, history, status]
121
+ predict_args = dict(fn=ArgsGeneralWrapper(predict), inputs=input_combo, outputs=output_combo)
122
+ # 提交按钮、重置按钮
123
  cancel_handles.append(txt.submit(**predict_args))
124
+ cancel_handles.append(txt2.submit(**predict_args))
125
  cancel_handles.append(submitBtn.click(**predict_args))
126
+ cancel_handles.append(submitBtn2.click(**predict_args))
127
+ resetBtn.click(lambda: ([], [], "已重置"), None, output_combo)
128
+ resetBtn2.click(lambda: ([], [], "已重置"), None, output_combo)
129
+ # 基础功能区的回调函数注册
130
  for k in functional:
131
+ click_handle = functional[k]["Button"].click(fn=ArgsGeneralWrapper(predict), inputs=[*input_combo, gr.State(True), gr.State(k)], outputs=output_combo)
 
132
  cancel_handles.append(click_handle)
133
+ # 文件上传区,接收文件后与chatbot的互动
134
  file_upload.upload(on_file_uploaded, [file_upload, chatbot, txt], [chatbot, txt])
135
+ # 函数插件-固定按钮区
136
+ for k in crazy_fns:
137
+ if not crazy_fns[k].get("AsButton", True): continue
138
+ click_handle = crazy_fns[k]["Button"].click(ArgsGeneralWrapper(crazy_fns[k]["Function"]), [*input_combo, gr.State(PORT)], output_combo)
139
+ click_handle.then(on_report_generated, [file_upload, chatbot], [file_upload, chatbot])
 
140
  cancel_handles.append(click_handle)
141
+ # 函数插件-下拉菜单与随变按钮的互动
142
+ def on_dropdown_changed(k):
143
+ variant = crazy_fns[k]["Color"] if "Color" in crazy_fns[k] else "secondary"
144
+ return {switchy_bt: gr.update(value=k, variant=variant)}
145
+ dropdown.select(on_dropdown_changed, [dropdown], [switchy_bt] )
146
+ # 随变按钮的回调函数注册
147
+ def route(k, *args, **kwargs):
148
+ if k in [r"打开插件列表", r"请先从插件列表中选择"]: return
149
+ yield from ArgsGeneralWrapper(crazy_fns[k]["Function"])(*args, **kwargs)
150
+ click_handle = switchy_bt.click(route,[switchy_bt, *input_combo, gr.State(PORT)], output_combo)
151
+ click_handle.then(on_report_generated, [file_upload, chatbot], [file_upload, chatbot])
152
+ # def expand_file_area(file_upload, area_file_up):
153
+ # if len(file_upload)>0: return {area_file_up: gr.update(open=True)}
154
+ # click_handle.then(expand_file_area, [file_upload, area_file_up], [area_file_up])
155
+ cancel_handles.append(click_handle)
156
+ # 终止按钮的回调函数注册
157
  stopBtn.click(fn=None, inputs=None, outputs=None, cancels=cancel_handles)
158
+ stopBtn2.click(fn=None, inputs=None, outputs=None, cancels=cancel_handles)
159
  # gradio的inbrowser触发不太稳定,回滚代码到原始的浏览器打开函数
160
  def auto_opentab_delay():
161
  import threading, webbrowser, time
162
+ print(f"如果浏览器没有自动打开,请复制并转到以下URL")
163
+ print(f"\t(亮色主体): http://localhost:{PORT}")
164
+ print(f"\t(暗色主体): http://localhost:{PORT}/?__dark-theme=true")
165
  def open():
166
  time.sleep(2)
167
+ try: auto_update() # 检查新版本
168
+ except: pass
169
+ webbrowser.open_new_tab(f"http://localhost:{PORT}/?__dark-theme=true")
170
+ threading.Thread(target=open, name="open-browser", daemon=True).start()
171
 
172
  auto_opentab_delay()
173
  demo.title = "ChatGPT 学术优化"
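新版 main.py 通过 `crazy_fns[k].get("AsButton", True)` 决定插件显示为固定按钮还是收进“更多函数插件”下拉菜单,通过 `"Color"` 决定按钮颜色。注册条目的形态大致如下(示意草稿,键名参照 main.py 的用法,插件本身为占位,真实注册写法以 crazy_functional.py 为准):

```python
# 示意:函数插件注册条目(假设性示例,真实注册见 crazy_functional.py)
def 某个演示插件(*args, **kwargs):
    yield [], [], '正常'  # 占位:真实插件是不断 yield (chatbot, history, msg) 的生成器

function_plugins = {
    "某个演示插件(按钮)": {
        "Color": "stop",      # 红色按钮,提示需要从输入区读取路径作为参数
        "Function": 某个演示插件,
    },
    "某个演示插件(下拉菜单)": {
        "AsButton": False,    # 不生成固定按钮,只出现在“更多函数插件”下拉菜单中
        "Function": 某个演示插件,
    },
}
```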
project_self_analysis.md DELETED
@@ -1,122 +0,0 @@
1
- # chatgpt-academic项目分析报告
2
- (Author补充:以下分析均由本项目调用ChatGPT一键生成,如果有不准确的地方全怪GPT)
3
-
4
- ## [0/10] 程序摘要: check_proxy.py
5
-
6
- 这个程序是一个用来检查代理服务器是否有效的 Python 程序代码。程序文件名为 check_proxy.py。其中定义了一个函数 check_proxy,该函数接收一个代理配置信息 proxies,使用 requests 库向一个代理服务器发送请求,获取该代理的所在地信息并返回。如果请求超时或者异常,该函数将返回一个代理无效的结果。
7
-
8
- 程序代码分为两个部分,首先是 check_proxy 函数的定义部分,其次是程序文件的入口部分,在该部分代码中,程序从 config_private.py 文件或者 config.py 文件中加载代理配置信息,然后调用 check_proxy 函数来检测代理服务器是否有效。如果配置文件 config_private.py 存在,则会加载其中的代理配置信息,否则会从 config.py 文件中读取。
9
-
10
- ## [1/10] 程序摘要: config.py
11
-
12
- 本程序文件名为config.py,主要功能是存储应用所需的常量和配置信息。
13
-
14
- 其中,包含了应用所需的OpenAI API密钥、API接口地址、网络代理设置、超时设置、网络端口和OpenAI模型选择等信息,在运行应用前需要进行相应的配置。在未配置网络代理时,程序给出了相应的警告提示。
15
-
16
- 此外,还包含了一个检查函数,用于检查是否忘记修改API密钥。
17
-
18
- 总之,config.py文件是应用中的一个重要配置文件,用来存储应用所需的常量和配置信息,需要在应用运行前进行相应的配置。
19
-
20
- ## [2/10] 程序摘要: config_private.py
21
-
22
- 该文件是一个配置文件,命名为config_private.py。它是一个Python脚本,用于配置OpenAI的API密钥、模型和其它相关设置。该配置文件还可以设置是否使用代理。如果使用代理,需要设置代理协议、地址和端口。在设置代理之后,该文件还包括一些用于测试代理是否正常工作的代码。该文件还包括超时时间、随机端口、重试次数等设置。在文件末尾,还有一个检查代码,如果没有更改API密钥,则抛出异常。
23
-
24
- ## [3/10] 程序摘要: functional.py
25
-
26
- 该程序文件名为 functional.py,其中包含一个名为 get_functionals 的函数,该函数返回一个字典,该字典包含了各种翻译、校对等功能的名称、前缀、后缀以及默认按钮颜色等信息。具体功能包括:英语学术润色、中文学术润色、查找语法错误、中英互译、中译英、学术中译英、英译中、解释代码等。该程序的作用为提供各种翻译、校对等功能的模板,以便后续程序可以直接调用。
27
-
28
- (Author补充:这个文件汇总了模块化的Prompt调用,如果发现了新的好用Prompt,别藏着哦^_^速速PR)
29
-
30
-
31
- ## [4/10] 程序摘要: functional_crazy.py
32
-
33
- 这个程序文件 functional_crazy.py 导入了一些 python 模块,并提供了一个函数 get_crazy_functionals(),该函数返回不同实验功能的描述和函数。其中,使用的的模块包括:
34
-
35
- - crazy_functions.读文章写摘要 中的 读文章写摘要
36
- - crazy_functions.生成函数注释 中的 批量生成函数注释
37
- - crazy_functions.解析项目源代码 中的 解析项目本身、解析一个Python项目、解析一个C项目的头文件、解析一个C项目
38
- - crazy_functions.高级功能函数模板 中的 高阶功能模板函数
39
-
40
- 返回的实验功能函数包括:
41
-
42
- - "[实验] 请解析并解构此项目本身",包含函数:解析项目本身
43
- - "[实验] 解析整个py项目(配合input输入框)",包含函数:解析一个Python项目
44
- - "[实验] 解析整个C++项目头文件(配合input输入框)",包含函数:解析一个C项目的头文件
45
- - "[实验] 解析整个C++项目(配合input输入框)",包含函数:解析一个C项目
46
- - "[实验] 读tex论文写摘要(配合input输入框)",包含函数:读文章写摘要
47
- - "[实验] 批量生成函数注释(配合input输入框)",包含函数:批量生成函数注释
48
- - "[实验] 实验功能函数模板",包含函数:高阶功能模板函数
49
-
50
- 这些函数用于系统开发和测试,方便开发者进行特定程序语言后台功能开发的测试和实验,增加系统可靠稳定性和用户友好性。
51
-
52
- (Author补充:这个文件汇总了模块化的函数,如此设计以方便任何新功能的加入)
53
-
54
- ## [5/10] 程序摘要: main.py
55
-
56
- 该程序是一个基于Gradio框架的聊天机器人应用程序。用户可以通过输入问题来获取答案,并与聊天机器人进行对话。该应用程序还集成了一些实验性功能模块,用户可以通过上传本地文件或点击相关按钮来使用这些模块。程序还可以生成对话日志,并且具有一些外观上的调整。在运行时,它会自动打开一个网页并在本地启动服务器。
57
-
58
-
59
- ## [6/10] 程序摘要: predict.py
60
-
61
- 该程序文件名为predict.py,主要是针对一个基于ChatGPT的聊天机器人进行交互和预测。
62
-
63
- 第一部分是导入所需的库和配置文件。
64
-
65
- 第二部分是一个用于获取Openai返回的完整错误信息的函数。
66
-
67
- 第三部分是用于一次性完成向ChatGPT发送请求和等待回复的函数。
68
-
69
- 第四部分是用于基础的对话功能的函数,通过stream参数可以选择是否显示中间的过程。
70
-
71
- 第五部分是用于整合所需信息和选择LLM模型生成的HTTP请求。
72
-
73
- (Author补充:主要是predict_no_ui和predict两个函数。前者不用stream,方便、高效、易用。后者用stream,展现效果好。)
74
-
75
- ## [7/10] 程序摘要: show_math.py
76
-
77
- 这是一个名为show_math.py的Python程序文件,主要用于将Markdown-LaTeX混合文本转换为HTML格式,并包括MathML数学公式。程序使用latex2mathml.converter库将LaTeX公式转换为MathML格式,并使用正则表达式递归地翻译输入的Markdown-LaTeX混合文本。程序包括转换成双美元符号($$)形式、转换成单美元符号($)形式、转换成\[\]形式以及转换成\(\)形式的LaTeX数学公式。如果转换中出现错误,程序将返回相应的错误消息。
78
-
79
- ## [8/10] 程序摘要: theme.py
80
-
81
- 这是一个名为theme.py的程序文件,用于设置Gradio界面的颜色和字体主题。该文件中定义了一个名为adjust_theme()的函数,其作用是返回一个Gradio theme对象,设置了Gradio界面的颜色和字体主题。在该函数里面,使用了Graido可用的颜色列表,主要参数包括primary_hue、neutral_hue、font和font_mono等,用于设置Gradio界面的主题色调、字体等。另外,该函数还实现了一些参数的自定义,如input_background_fill_dark、button_transition、button_shadow_hover等,用于设置Gradio界面的渐变、阴影等特效。如果Gradio版本过于陈旧,该函数会抛出异常并返回None。
82
-
83
- ## [9/10] 程序摘要: toolbox.py
84
-
85
- 该文件为Python程序文件,文件名为toolbox.py。主要功能包括:
86
-
87
- 1. 导入markdown、mdtex2html、threading、functools等模块。
88
- 2. 定义函数predict_no_ui_but_counting_down,用于生成对话。
89
- 3. 定义函数write_results_to_file,用于将对话记录生成Markdown文件。
90
- 4. 定义函数regular_txt_to_markdown,将普通文本转换为Markdown格式的文本。
91
- 5. 定义装饰器函数CatchException,用于捕获函数执行异常并返回生成器。
92
- 6. 定义函数report_execption,用于向chatbot中添加错误信息。
93
- 7. 定义函数text_divide_paragraph,用于将文本按照段落分隔符分割开,生成带有段落标签的HTML代码。
94
- 8. 定义函数markdown_convertion,用于将Markdown格式的文本转换为HTML格式。
95
- 9. 定义函数format_io,用于将输入和输出解析为HTML格式。
96
- 10. 定义函数find_free_port,用于返回当前系统中可用的未使用端口。
97
- 11. 定义函数extract_archive,用于解压归档文件。
98
- 12. 定义函数find_recent_files,用于查找最近创建的文件。
99
- 13. 定义函数on_file_uploaded,用于处理上传文件的操作。
100
- 14. 定义函数on_report_generated,用于处理生成报告文件的操作。
101
-
102
- ## 程序的整体功能和构架做出概括。然后用一张markdown表格整理每个文件的功能。
103
-
104
- 这是一个基于Gradio框架的聊天机器人应用,支持通过文本聊天来获取答案,并可以使用一系列实验性功能模块,例如生成函数注释、解析项目源代码、读取Latex论文写摘要等。 程序架构分为前端和后端两个部分。前端使用Gradio实现,包括用户输入区域、应答区域、按钮、调用方式等。后端使用Python实现,包括聊天机器人模型、实验性功能模块、模板模块、管理模块、主程序模块等。
105
-
106
- 每个程序文件的功能如下:
107
-
108
- | 文件名 | 功能描述 |
109
- |:----:|:----:|
110
- | check_proxy.py | 检查代理服务器是否有效 |
111
- | config.py | 存储应用所需的常量和配置信息 |
112
- | config_private.py | 存储Openai的API密钥、模型和其他相关设置 |
113
- | functional.py | 提供各种翻译、校对等实用模板 |
114
- | functional_crazy.py | 提供一些实验性质的高级功能 |
115
- | main.py | 基于Gradio框架的聊天机器人应用程序的主程序 |
116
- | predict.py | 用于chatbot预测方案创建,向ChatGPT发送请求和获取回复 |
117
- | show_math.py | 将Markdown-LaTeX混合文本转换为HTML格式,并包括MathML数学公式 |
118
- | theme.py | 设置Gradio界面的颜色和字体主题 |
119
- | toolbox.py | 定义一系列工具函数,用于对输入输出进行格式转换、文件操作、异常捕捉和处理等 |
120
-
121
- 这些程序文件共同组成了一个聊天机器人应用程序的前端和后端实现,使用户可以方便地进行聊天,并可以使用相应的实验功能模块。
122
-
 
 
 
 
request_llm/README.md ADDED
@@ -0,0 +1,36 @@
1
+ # 如何使用其他大语言模型(dev分支测试中)
2
+
3
+ ## 1. 先运行text-generation
4
+ ``` sh
5
+ # 下载模型( text-generation 这么牛的项目,别忘了给人家star )
6
+ git clone https://github.com/oobabooga/text-generation-webui.git
7
+
8
+ # 安装text-generation的额外依赖
9
+ pip install accelerate bitsandbytes flexgen gradio llamacpp markdown numpy peft requests rwkv safetensors sentencepiece tqdm datasets git+https://github.com/huggingface/transformers
10
+
11
+ # 切换路径
12
+ cd text-generation-webui
13
+
14
+ # 下载模型
15
+ python download-model.py facebook/galactica-1.3b
16
+ # 其他可选如 facebook/opt-1.3b
17
+ # facebook/galactica-6.7b
18
+ # facebook/galactica-120b
19
+ # facebook/pygmalion-1.3b 等
20
+ # 详情见 https://github.com/oobabooga/text-generation-webui
21
+
22
+ # 启动text-generation,注意把模型的斜杠改成下划线
23
+ python server.py --cpu --listen --listen-port 7860 --model facebook_galactica-1.3b
24
+ ```
25
+
26
+ ## 2. 修改config.py
27
+ ``` sh
28
+ # LLM_MODEL格式较复杂 TGUI:[模型]@[ws地址]:[ws端口] , 端口要和上面给定的端口一致
29
+ LLM_MODEL = "TGUI:galactica-1.3b@localhost:7860"
30
+ ```
31
+
32
+ ## 3. 运行!
33
+ ``` sh
34
+ cd chatgpt-academic
35
+ python main.py
36
+ ```
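这里的 `LLM_MODEL = "TGUI:galactica-1.3b@localhost:7860"` 会在下文的 request_llm/bridge_tgui.py 开头按 `模型@地址:端口` 拆开,配置前可以用下面几行核对格式(解析方式与 bridge_tgui.py 一致):

```python
# 示意:核对 TGUI 模型字符串的格式(与 bridge_tgui.py 开头的解析一致)
LLM_MODEL = "TGUI:galactica-1.3b@localhost:7860"

model_name, addr_port = LLM_MODEL.split('@')   # -> "TGUI:galactica-1.3b", "localhost:7860"
assert ':' in addr_port, "LLM_MODEL 格式不正确!" + LLM_MODEL
addr, port = addr_port.split(':')              # -> "localhost", "7860"
print(model_name, addr, port)
```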
predict.py → request_llm/bridge_chatgpt.py RENAMED
@@ -12,6 +12,7 @@
12
  """
13
 
14
  import json
 
15
  import gradio as gr
16
  import logging
17
  import traceback
@@ -71,12 +72,22 @@ def predict_no_ui(inputs, top_p, temperature, history=[], sys_prompt=""):
71
  raise ConnectionAbortedError("Json解析不合常规,可能是文本过长" + response.text)
72
 
73
 
74
- def predict_no_ui_long_connection(inputs, top_p, temperature, history=[], sys_prompt=""):
75
  """
76
- 发送至chatGPT,等待回复,一次性完成,不显示中间过程。但内部用stream的方法避免有人中途掐网线。
 
 
 
 
 
 
 
 
 
 
77
  """
 
78
  headers, payload = generate_payload(inputs, top_p, temperature, history, system_prompt=sys_prompt, stream=True)
79
-
80
  retry = 0
81
  while True:
82
  try:
@@ -96,13 +107,28 @@ def predict_no_ui_long_connection(inputs, top_p, temperature, history=[], sys_pr
96
  except StopIteration: break
97
  if len(chunk)==0: continue
98
  if not chunk.startswith('data:'):
99
- chunk = get_full_error(chunk.encode('utf8'), stream_response)
100
- raise ConnectionAbortedError("OpenAI拒绝了请求:" + chunk.decode())
101
- delta = json.loads(chunk.lstrip('data:'))['choices'][0]["delta"]
 
 
 
 
102
  if len(delta) == 0: break
103
  if "role" in delta: continue
104
- if "content" in delta: result += delta["content"]; print(delta["content"], end='')
 
 
 
 
 
 
 
 
 
105
  else: raise RuntimeError("意外Json结构:"+delta)
 
 
106
  return result
107
 
108
 
@@ -118,11 +144,11 @@ def predict(inputs, top_p, temperature, chatbot=[], history=[], system_prompt=''
118
  additional_fn代表点击的哪个按钮,按钮见functional.py
119
  """
120
  if additional_fn is not None:
121
- import functional
122
- importlib.reload(functional) # 热更新prompt
123
- functional = functional.get_functionals()
124
- if "PreProcess" in functional[additional_fn]: inputs = functional[additional_fn]["PreProcess"](inputs) # 获取预处理函数(如果有的话)
125
- inputs = functional[additional_fn]["Prefix"] + inputs + functional[additional_fn]["Suffix"]
126
 
127
  if stream:
128
  raw_input = inputs
@@ -179,15 +205,17 @@ def predict(inputs, top_p, temperature, chatbot=[], history=[], system_prompt=''
179
  chunk = get_full_error(chunk, stream_response)
180
  error_msg = chunk.decode()
181
  if "reduce the length" in error_msg:
182
- chatbot[-1] = (chatbot[-1][0], "[Local Message] Input (or history) is too long, please reduce input or clear history by refreshing this page.")
183
- history = []
184
  elif "Incorrect API key" in error_msg:
185
- chatbot[-1] = (chatbot[-1][0], "[Local Message] Incorrect API key provided.")
 
 
186
  else:
187
  from toolbox import regular_txt_to_markdown
188
- tb_str = regular_txt_to_markdown(traceback.format_exc())
189
- chatbot[-1] = (chatbot[-1][0], f"[Local Message] Json Error \n\n {tb_str} \n\n {regular_txt_to_markdown(chunk.decode()[4:])}")
190
- yield chatbot, history, "Json解析不合常规" + error_msg
191
  return
192
 
193
  def generate_payload(inputs, top_p, temperature, history, system_prompt, stream):
 
12
  """
13
 
14
  import json
15
+ import time
16
  import gradio as gr
17
  import logging
18
  import traceback
 
72
  raise ConnectionAbortedError("Json解析不合常规,可能是文本过长" + response.text)
73
 
74
 
75
+ def predict_no_ui_long_connection(inputs, top_p, temperature, history=[], sys_prompt="", observe_window=None):
76
  """
77
+ 发送至chatGPT,等待回复,一次性完成,不显示中间过程。但内部用stream的方法避免中途网线被掐。
78
+ inputs:
79
+ 是本次问询的输入
80
+ sys_prompt:
81
+ 系统静默prompt
82
+ top_p, temperature:
83
+ chatGPT的内部调优参数
84
+ history:
85
+ 是之前的对话列表
86
+ observe_window = None:
87
+ 用于负责跨越线程传递已经输出的部分,大部分时候仅仅为了fancy的视觉效果,留空即可。observe_window[0]:观测窗。observe_window[1]:看门狗
88
  """
89
+ watch_dog_patience = 5 # 看门狗的耐心, 设置5秒即可
90
  headers, payload = generate_payload(inputs, top_p, temperature, history, system_prompt=sys_prompt, stream=True)
 
91
  retry = 0
92
  while True:
93
  try:
 
107
  except StopIteration: break
108
  if len(chunk)==0: continue
109
  if not chunk.startswith('data:'):
110
+ error_msg = get_full_error(chunk.encode('utf8'), stream_response).decode()
111
+ if "reduce the length" in error_msg:
112
+ raise ConnectionAbortedError("OpenAI拒绝了请求:" + error_msg)
113
+ else:
114
+ raise RuntimeError("OpenAI拒绝了请求:" + error_msg)
115
+ json_data = json.loads(chunk.lstrip('data:'))['choices'][0]
116
+ delta = json_data["delta"]
117
  if len(delta) == 0: break
118
  if "role" in delta: continue
119
+ if "content" in delta:
120
+ result += delta["content"]
121
+ print(delta["content"], end='')
122
+ if observe_window is not None:
123
+ # 观测窗,把已经获取的数据显示出去
124
+ if len(observe_window) >= 1: observe_window[0] += delta["content"]
125
+ # 看门狗,如果超过期限没有喂狗,则终止
126
+ if len(observe_window) >= 2:
127
+ if (time.time()-observe_window[1]) > watch_dog_patience:
128
+ raise RuntimeError("程序终止。")
129
  else: raise RuntimeError("意外Json结构:"+delta)
130
+ if json_data['finish_reason'] == 'length':
131
+ raise ConnectionAbortedError("正常结束,但显示Token不足,导致输出不完整,请削减单次输入的文本量。")
132
  return result
133
 
134
 
 
144
  additional_fn代表点击的哪个按钮,按钮见functional.py
145
  """
146
  if additional_fn is not None:
147
+ import core_functional
148
+ importlib.reload(core_functional) # 热更新prompt
149
+ core_functional = core_functional.get_core_functions()
150
+ if "PreProcess" in core_functional[additional_fn]: inputs = core_functional[additional_fn]["PreProcess"](inputs) # 获取预处理函数(如果有的话)
151
+ inputs = core_functional[additional_fn]["Prefix"] + inputs + core_functional[additional_fn]["Suffix"]
152
 
153
  if stream:
154
  raw_input = inputs
 
205
  chunk = get_full_error(chunk, stream_response)
206
  error_msg = chunk.decode()
207
  if "reduce the length" in error_msg:
208
+ chatbot[-1] = (chatbot[-1][0], "[Local Message] Reduce the length. 本次输入过长,或历史数据过长. 历史缓存数据现已释放,您可以再次尝试.")
209
+ history = [] # 清除历史
210
  elif "Incorrect API key" in error_msg:
211
+ chatbot[-1] = (chatbot[-1][0], "[Local Message] Incorrect API key. OpenAI以提供了不正确的API_KEY为由,拒绝服务.")
212
+ elif "exceeded your current quota" in error_msg:
213
+ chatbot[-1] = (chatbot[-1][0], "[Local Message] You exceeded your current quota. OpenAI以账户额度不足为由,拒绝服务.")
214
  else:
215
  from toolbox import regular_txt_to_markdown
216
+ tb_str = '```\n' + traceback.format_exc() + '```'
217
+ chatbot[-1] = (chatbot[-1][0], f"[Local Message] 异常 \n\n{tb_str} \n\n{regular_txt_to_markdown(chunk.decode()[4:])}")
218
+ yield chatbot, history, "Json异常" + error_msg
219
  return
220
 
221
  def generate_payload(inputs, top_p, temperature, history, system_prompt, stream):
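上面 `predict_no_ui_long_connection` 的 `observe_window` 约定是:`observe_window[0]` 为观测窗,函数把流式输出不断追加进去;`observe_window[1]` 为看门狗时间戳,调用方需要定期刷新,超过约 5 秒不喂狗则函数抛异常终止。跨线程调用的一个示意草稿如下(线程组织方式为假设,仅演示该约定,需已配置好 API_KEY 等):

```python
# 示意:子线程中调用 predict_no_ui_long_connection,主线程通过 observe_window 观察输出并喂狗(假设性示例)
import time, threading
from request_llm.bridge_chatgpt import predict_no_ui_long_connection

observe_window = ["", time.time()]  # [0] 观测窗,[1] 看门狗时间戳
result = {}

def worker():
    result['text'] = predict_no_ui_long_connection(
        inputs="用一句话介绍你自己", top_p=1.0, temperature=1.0,
        history=[], sys_prompt="", observe_window=observe_window)

t = threading.Thread(target=worker, daemon=True); t.start()
while t.is_alive():
    time.sleep(1)
    observe_window[1] = time.time()  # 喂狗,避免被看门狗终止
    print("已收到", len(observe_window[0]), "个字符")
print(result.get('text', ''))
```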
request_llm/bridge_tgui.py ADDED
@@ -0,0 +1,167 @@
1
+ '''
2
+ Contributed by SagsMug. Modified by binary-husky
3
+ https://github.com/oobabooga/text-generation-webui/pull/175
4
+ '''
5
+
6
+ import asyncio
7
+ import json
8
+ import random
9
+ import string
10
+ import websockets
11
+ import logging
12
+ import time
13
+ import threading
14
+ import importlib
15
+ from toolbox import get_conf
16
+ LLM_MODEL, = get_conf('LLM_MODEL')
17
+
18
+ # "TGUI:galactica-1.3b@localhost:7860"
19
+ model_name, addr_port = LLM_MODEL.split('@')
20
+ assert ':' in addr_port, "LLM_MODEL 格式不正确!" + LLM_MODEL
21
+ addr, port = addr_port.split(':')
22
+
23
+ def random_hash():
24
+ letters = string.ascii_lowercase + string.digits
25
+ return ''.join(random.choice(letters) for i in range(9))
26
+
27
+ async def run(context, max_token=512):
28
+ params = {
29
+ 'max_new_tokens': max_token,
30
+ 'do_sample': True,
31
+ 'temperature': 0.5,
32
+ 'top_p': 0.9,
33
+ 'typical_p': 1,
34
+ 'repetition_penalty': 1.05,
35
+ 'encoder_repetition_penalty': 1.0,
36
+ 'top_k': 0,
37
+ 'min_length': 0,
38
+ 'no_repeat_ngram_size': 0,
39
+ 'num_beams': 1,
40
+ 'penalty_alpha': 0,
41
+ 'length_penalty': 1,
42
+ 'early_stopping': True,
43
+ 'seed': -1,
44
+ }
45
+ session = random_hash()
46
+
47
+ async with websockets.connect(f"ws://{addr}:{port}/queue/join") as websocket:
48
+ while content := json.loads(await websocket.recv()):
49
+ #Python3.10 syntax, replace with if elif on older
50
+ if content["msg"] == "send_hash":
51
+ await websocket.send(json.dumps({
52
+ "session_hash": session,
53
+ "fn_index": 12
54
+ }))
55
+ elif content["msg"] == "estimation":
56
+ pass
57
+ elif content["msg"] == "send_data":
58
+ await websocket.send(json.dumps({
59
+ "session_hash": session,
60
+ "fn_index": 12,
61
+ "data": [
62
+ context,
63
+ params['max_new_tokens'],
64
+ params['do_sample'],
65
+ params['temperature'],
66
+ params['top_p'],
67
+ params['typical_p'],
68
+ params['repetition_penalty'],
69
+ params['encoder_repetition_penalty'],
70
+ params['top_k'],
71
+ params['min_length'],
72
+ params['no_repeat_ngram_size'],
73
+ params['num_beams'],
74
+ params['penalty_alpha'],
75
+ params['length_penalty'],
76
+ params['early_stopping'],
77
+ params['seed'],
78
+ ]
79
+ }))
80
+ elif content["msg"] == "process_starts":
81
+ pass
82
+ elif content["msg"] in ["process_generating", "process_completed"]:
83
+ yield content["output"]["data"][0]
84
+ # You can search for your desired end indicator and
85
+ # stop generation by closing the websocket here
86
+ if (content["msg"] == "process_completed"):
87
+ break
88
+
89
+
90
+
91
+
92
+
93
+ def predict_tgui(inputs, top_p, temperature, chatbot=[], history=[], system_prompt='', stream = True, additional_fn=None):
94
+ """
95
+ 发送至chatGPT,流式获取输出。
96
+ 用于基础的对话功能。
97
+ inputs 是本次问询的输入
98
+ top_p, temperature是chatGPT的内部调优参数
99
+ history 是之前的对话列表(注意无论是inputs还是history,内容太长了都会触发token数量溢出的错误)
100
+ chatbot 为WebUI中显示的对话列表,修改它,然后yeild出去,可以直接修改对话界面内容
101
+ additional_fn代表点击的哪个按钮,按钮见functional.py
102
+ """
103
+ if additional_fn is not None:
104
+ import core_functional
105
+ importlib.reload(core_functional) # 热更新prompt
106
+ core_functional = core_functional.get_core_functions()
107
+ if "PreProcess" in core_functional[additional_fn]: inputs = core_functional[additional_fn]["PreProcess"](inputs) # 获取预处理函数(如果有的话)
108
+ inputs = core_functional[additional_fn]["Prefix"] + inputs + core_functional[additional_fn]["Suffix"]
109
+
110
+ raw_input = "What I would like to say is the following: " + inputs
111
+ logging.info(f'[raw_input] {raw_input}')
112
+ history.extend([inputs, ""])
113
+ chatbot.append([inputs, ""])
114
+ yield chatbot, history, "等待响应"
115
+
116
+ prompt = inputs
117
+ tgui_say = ""
118
+
119
+ mutable = ["", time.time()]
120
+ def run_coorotine(mutable):
121
+ async def get_result(mutable):
122
+ async for response in run(prompt):
123
+ print(response[len(mutable[0]):])
124
+ mutable[0] = response
125
+ if (time.time() - mutable[1]) > 3:
126
+ print('exit when no listener')
127
+ break
128
+ asyncio.run(get_result(mutable))
129
+
130
+ thread_listen = threading.Thread(target=run_coorotine, args=(mutable,), daemon=True)
131
+ thread_listen.start()
132
+
133
+ while thread_listen.is_alive():
134
+ time.sleep(1)
135
+ mutable[1] = time.time()
136
+ # Print intermediate steps
137
+ if tgui_say != mutable[0]:
138
+ tgui_say = mutable[0]
139
+ history[-1] = tgui_say
140
+ chatbot[-1] = (history[-2], history[-1])
141
+ yield chatbot, history, "status_text"
142
+
143
+ logging.info(f'[response] {tgui_say}')
144
+
145
+
146
+
147
+ def predict_tgui_no_ui(inputs, top_p, temperature, history=[], sys_prompt=""):
148
+ raw_input = "What I would like to say is the following: " + inputs
149
+ prompt = inputs
150
+ tgui_say = ""
151
+ mutable = ["", time.time()]
152
+ def run_coorotine(mutable):
153
+ async def get_result(mutable):
154
+ async for response in run(prompt, max_token=20):
155
+ print(response[len(mutable[0]):])
156
+ mutable[0] = response
157
+ if (time.time() - mutable[1]) > 3:
158
+ print('exit when no listener')
159
+ break
160
+ asyncio.run(get_result(mutable))
161
+ thread_listen = threading.Thread(target=run_coorotine, args=(mutable,))
162
+ thread_listen.start()
163
+ while thread_listen.is_alive():
164
+ time.sleep(1)
165
+ mutable[1] = time.time()
166
+ tgui_say = mutable[0]
167
+ return tgui_say
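bridge_tgui.py 里的 `run(context, max_token)` 是一个异步生成器,按 text-generation-webui 的 gradio 队列协议(send_hash → send_data → process_generating/process_completed)逐步 yield 累积的输出。若想脱离 WebUI 单独试一下,可以这样消费它(示意草稿,假设已按上文 README 配置好 TGUI 形式的 LLM_MODEL 并在 localhost:7860 启动了服务):

```python
# 示意:单独消费 bridge_tgui.run() 这个异步生成器(假设 TGUI 服务已在 localhost:7860 运行)
import asyncio
from request_llm.bridge_tgui import run

async def main():
    last = ""
    async for partial in run("Question: What is attention? Answer:", max_token=64):
        print(partial[len(last):], end="", flush=True)  # 只打印新增的部分
        last = partial

asyncio.run(main())
```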
requirements.txt CHANGED
@@ -1,3 +1,12 @@
1
  gradio>=3.23
2
  requests[socks]
 
 
 
 
3
  mdtex2html
 
 
 
 
 
 
1
  gradio>=3.23
2
  requests[socks]
3
+ transformers
4
+ python-markdown-math
5
+ beautifulsoup4
6
+ latex2mathml
7
  mdtex2html
8
+ tiktoken
9
+ Markdown
10
+ pymupdf
11
+ openai
12
+ numpy
self_analysis.md ADDED
@@ -0,0 +1,262 @@
1
+ # chatgpt-academic项目自译解报告
2
+ (Author补充:以下分析均由本项目调用ChatGPT一键生成,如果有不准确的地方,全怪GPT😄)
3
+
4
+ ## 对程序的整体功能和构架做出概括。然后用一张markdown表格整理每个文件的功能(包括'check_proxy.py', 'config.py'等)。
5
+
6
+ 整体概括:
7
+
8
+ 该程序是一个基于自然语言处理和机器学习的科学论文辅助工具,主要功能包括聊天机器人、批量总结PDF文档、批量翻译PDF文档、生成函数注释、解析项目源代码等。程序基于 Gradio 构建 Web 服务,并集成了代理和自动更新功能,提高了用户的使用体验。
9
+
10
+ 文件功能表格:
11
+
12
+ | 文件名称 | 功能 |
13
+ | ------------------------------------------------------------ | ------------------------------------------------------------ |
14
+ | .\check_proxy.py | 检查代理设置功能。 |
15
+ | .\config.py | 配置文件,存储程序的基本设置。 |
16
+ | .\config_private.py | 存储代理网络地址的文件。 |
17
+ | .\core_functional.py | 主要的程序逻辑,包括聊天机器人和文件处理。 |
18
+ | .\cradle.py | 程序入口,初始化程序和启动 Web 服务。 |
19
+ | .\crazy_functional.py | 辅助程序功能,包括PDF文档处理、代码处理、函数注释生成等。 |
20
+ | .\main.py | 包含聊天机器人的具体实现。 |
21
+ | .\show_math.py | 处理 LaTeX 公式的函数。 |
22
+ | .\theme.py | 存储 Gradio Web 服务的 CSS 样式文件。 |
23
+ | .\toolbox.py | 提供了一系列工具函数,包括文件读写、网页抓取、解析函数参数、生成 HTML 等。 |
24
+ | ./crazy_functions/crazy_utils.py | 提供各种工具函数,如解析字符串、清洗文本、清理目录结构等。 |
25
+ | ./crazy_functions/\_\_init\_\_.py | crazy_functions 模块的入口文件。 |
26
+ | ./crazy_functions/下载arxiv论文翻译摘要.py | 对 arxiv.org 上的 PDF 论文进行下载和翻译。 |
27
+ | ./crazy_functions/代码重写为全英文_多线程.py | 将代码文件中的中文注释和字符串替换为英文。 |
28
+ | ./crazy_functions/总结word文档.py | 读取 Word 文档并生成摘要。 |
29
+ | ./crazy_functions/批量总结PDF文档.py | 批量读取 PDF 文件并生成摘要。 |
30
+ | ./crazy_functions/批量总结PDF文档pdfminer.py | 使用 pdfminer 库进行 PDF 文件处理。 |
31
+ | ./crazy_functions/批量翻译PDF文档_多线程.py | 使用多线程技术批量翻译 PDF 文件。 |
32
+ | ./crazy_functions/生成函数注释.py | 给 Python 函数自动生成说明文档。 |
33
+ | ./crazy_functions/解析项目源代码.py | 解析项目中的源代码,提取注释和函数名等信息。 |
34
+ | ./crazy_functions/读文章写摘要.py | 读取多个文本文件并生成对应的摘要。 |
35
+ | ./crazy_functions/高级功能函数模板.py | 使用 GPT 模型进行文本处理。 |
36
+
37
+
38
+
39
+ ## [0/22] 程序概述: check_proxy.py
40
+
41
+ 该程序的文件名是check_proxy.py,主要有两个函数:check_proxy和auto_update。
42
+
43
+ check_proxy函数中会借助requests库向一个IP查询API发送请求,并返回该IP的地理位置信息。同时根据返回的数据来判断代理是否有效。
44
+
45
+ auto_update函数主要用于检查程序更新,会从Github获取程序最新的版本信息,如果当前版本和最新版本相差较大,则会提示用户进行更新。该函数中也会依赖requests库进行网络请求。
46
+
47
+ 在程序的开头,还添加了一句防止代理网络影响的代码。程序使用了自己编写的toolbox模块中的get_conf函数来获取代理设置。
48
+
49
+ ## [1/22] 程序概述: config.py
50
+
51
+ 该程序文件是一个Python模块,文件名为config.py。该模块包含了一些变量和配置选项,用于配置一个OpenAI的聊天机器人。具体的配置选项如下(列表后附一份示例写法):
52
+
53
+ - API_KEY: 密钥,用于连接OpenAI的API。需要填写有效的API密钥。
54
+ - USE_PROXY: 是否使用代理。如果需要使用代理,需要将其改为True。
55
+ - proxies: 代理的协议、地址和端口。
56
+ - CHATBOT_HEIGHT: 聊天机器人对话框的高度。
57
+ - LAYOUT: 聊天机器人对话框的布局,默认为左右布局。
58
+ - TIMEOUT_SECONDS: 发送请求到OpenAI后,等待多久判定为超时。
59
+ - WEB_PORT: 网页的端口,-1代表随机端口。
60
+ - MAX_RETRY: 如果OpenAI不响应(网络卡顿、代理失败、KEY失效),重试的次数限制。
61
+ - LLM_MODEL: OpenAI模型选择,目前只对某些用户开放的gpt4。
62
+ - API_URL: OpenAI的API地址。
63
+ - CONCURRENT_COUNT: 使用的线程数。
64
+ - AUTHENTICATION: 用户名和密码,如果需要。
65
+
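A hypothetical excerpt showing how these options might look in config.py; the values below are placeholders, only the option names come from the list above:

```python
API_KEY = "sk-此处填API密钥"        # placeholder, not a real key
USE_PROXY = False                  # set True to route requests through a proxy
proxies = None                     # e.g. {"http": "socks5h://localhost:11284", "https": "socks5h://localhost:11284"}
CHATBOT_HEIGHT = 1115              # height of the chat panel (px)
LAYOUT = "LEFT-RIGHT"              # dialog layout
TIMEOUT_SECONDS = 25               # how long to wait for OpenAI before timing out
WEB_PORT = -1                      # -1 means choose a random free port
MAX_RETRY = 2                      # retries when OpenAI does not respond
LLM_MODEL = "gpt-3.5-turbo"        # model selection
API_URL = "https://api.openai.com/v1/chat/completions"
CONCURRENT_COUNT = 100             # number of concurrent requests/threads
AUTHENTICATION = []                # optional [("username", "password")] pairs
```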
66
+ ## [2/22] 程序概述: config_private.py
67
+
68
+ 该程序文件名为config_private.py,包含了API_KEY的设置和代理的配置。使用了一个名为API_KEY的常量来存储私人的API密钥。此外,还有一个名为USE_PROXY的常量来标记是否需要使用代理。如果需要代理,则使用了一个名为proxies的字典来存储代理网络的地址,其中包括协议类型、地址和端口。
69
+
70
+ ## [3/22] 程序概述: core_functional.py
71
+
72
+ 该程序文件名为`core_functional.py`,主要是定义了一些核心功能函数,包括英语和中文学术润色、查找语法错误、中译英、学术中英互译、英译中、找图片和解释代码等。每个功能都有一个`Prefix`属性和`Suffix`属性,`Prefix`是指在用户输入的任务前面要显示的文本,`Suffix`是指在任务后面要显示的文本。此外,还有一个`Color`属性指示按钮的颜色,以及一个`PreProcess`函数表示对输入进行预处理的函数。
73
+
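A hypothetical sketch of what one such entry might look like, using only the fields named above (Prefix, Suffix, Color, PreProcess); the prompt text and values are invented:

```python
def get_core_functions_sketch():
    # one illustrative entry; the real core_functional.py defines many of these
    return {
        "英语学术润色": {
            "Prefix": "Below is a paragraph from an academic paper. Polish the writing:\n\n",
            "Suffix": "",                            # appended after the user's input
            "Color": "secondary",                    # button color (hypothetical value)
            "PreProcess": lambda txt: txt.strip(),   # optional pre-processing of the input
        },
    }
```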
74
+ ## [4/22] 程序概述: cradle.py
75
+
76
+ 该程序文件名为cradle.py,主要功能是检测当前版本与远程最新版本是否一致,如果不一致则输出新版本信息并提示更新。其流程大致如下:
77
+
78
+ 1. 导入相关模块与自定义工具箱函数get_conf
79
+ 2. 读取配置文件中的代理proxies
80
+ 3. 使用requests模块请求远程版本信息(url为https://raw.githubusercontent.com/binary-husky/chatgpt_academic/master/version)并加载为json格式
81
+ 4. 获取远程版本号、是否显示新功能信息、新功能内容
82
+ 5. 读取本地版本文件version并加载为json格式
83
+ 6. 获取当前版本号
84
+ 7. 比较当前版本与远程版本,如果远程版本号比当前版本号高0.05以上,则输出新版本信息并提示更新
85
+ 8. 如果不需要更新,则直接返回
86
+
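A rough sketch of that flow; the URL and the 0.05 threshold come from the description above, the file path and key names mirror the version file added in this commit, and the rest is illustrative:

```python
import json
import requests

def auto_update_sketch(proxies):
    url = "https://raw.githubusercontent.com/binary-husky/chatgpt_academic/master/version"
    remote = json.loads(requests.get(url, proxies=proxies, timeout=5).text)
    with open("./version", "r", encoding="utf8") as f:
        current = json.load(f)
    # prompt the user only when the remote version is newer by 0.05 or more
    if remote["version"] - current["version"] >= 0.05:
        print(f"检测到新版本 {remote['version']},新功能:{remote.get('new_feature', '')}")
    else:
        print("当前已是最新版本,无需更新")
```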
87
+ ## [5/22] 程序概述: crazy_functional.py
88
+
89
+ 该程序文件名为.\crazy_functional.py,主要定义了一个名为get_crazy_functions()的函数,该函数返回一个字典类型的变量function_plugins,其中包含了一些函数插件。
90
+
91
+ 一些重要的函数插件包括:
92
+
93
+ - 读文章写摘要:可以自动读取Tex格式的论文,并生成其摘要。
94
+
95
+ - 批量生成函数注释:可以批量生成Python函数的文档注释。
96
+
97
+ - 解析项目源代码:可以解析Python、C++、Golang、Java及React项目的源代码。
98
+
99
+ - 批量总结PDF文档:可以对PDF文档进行批量总结,以提取其中的关键信息。
100
+
101
+ - 一键下载arxiv论文并翻译摘要:可以自动下载arxiv.org网站上的PDF论文,并翻译生成其摘要。
102
+
103
+ - 批量翻译PDF文档(多线程):可以对PDF文档进行批量翻译,并使用多线程方式提高翻译效率。
104
+
105
+ ## [6/22] 程序概述: main.py
106
+
107
+ 本程序为一个基于 Gradio 和 GPT-3 的交互式聊天机器人,文件名为 main.py。其中主要功能包括:
108
+
109
+ 1. 使用 Gradio 建立 Web 界面,实现用户与聊天机器人的交互;
110
+ 2. 通过 bridge_chatgpt 模块,利用 GPT-3 模型实现聊天机器人的逻辑;
111
+ 3. 提供一些基础功能和高级函数插件,用户可以通过按钮选择使用;
112
+ 4. 提供文档格式转变、外观调整以及代理和自动更新等功能。
113
+
114
+ 程序的主要流程为:
115
+
116
+ 1. 导入所需的库和模块,并通过 get_conf 函数获取配置信息;
117
+ 2. 设置 Gradio 界面的各个组件,包括聊天窗口、输入区、功能区、函数插件区等;
118
+ 3. 注册各个组件的回调函数,包括用户输入、信号按钮等,实现机器人逻辑的交互;
119
+ 4. 通过 Gradio 的 queue 函数和 launch 函数启动 Web 服务,并提供聊天机器人的功能。
120
+
121
+ 此外,程序还提供了代理和自动更新功能,可以确保用户的使用体验。
122
+
123
+ ## [7/22] 程序概述: show_math.py
124
+
125
+ 该程序是一个Python脚本,文件名为show_math.py。它转换Markdown和LaTeX混合语法到带MathML的HTML。程序使用latex2mathml模块来实现从LaTeX到MathML的转换,将符号转换为HTML实体以批量处理。程序利用正则表达式和递归函数的方法处理不同形式的LaTeX语法,支持以下四种情况:$$形式、$形式、\[..]形式和\(...\)形式。如果无法转换某个公式,则在该位置插入一条错误消息。最后,程序输出HTML字符串。
126
+
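As a sketch of the idea, the inline single-dollar case might be handled roughly like this (latex2mathml is the converter named above; the regex and error text are illustrative, not the script's actual ones):

```python
import re
from latex2mathml.converter import convert as tex2mathml

def convert_inline_dollar(text):
    # replace each $...$ span with its MathML rendering; fall back to an error note
    def repl(match):
        try:
            return tex2mathml(match.group(1))
        except Exception:
            return "<em>公式转换失败</em>"
    return re.sub(r"\$(.+?)\$", repl, text)
```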
127
+ ## [8/22] 程序概述: theme.py
128
+
129
+ 该程序文件为一个Python脚本,其功能是调整Gradio应用的主题和样式,包括字体、颜色、阴影、背景等等。在程序中,使用了Gradio提供的默认颜色主题,并针对不同元素设置了相应的样式属性,以达到美化显示的效果。此外,程序中还包含了一段高级CSS样式代码,针对表格、列表、聊天气泡、行内代码等元素进行了样式设定。
130
+
131
+ ## [9/22] 程序概述: toolbox.py
132
+
133
+ 此程序文件主要包含了一系列用于聊天机器人开发的实用工具函数和装饰器函数。主要函数包括:
134
+
135
+ 1. ArgsGeneralWrapper:一个装饰器函数,用于重组输入参数,改变输入参数的顺序与结构。
136
+
137
+ 2. get_reduce_token_percent:一个函数,用于计算自然语言处理时会出现的token溢出比例。
138
+
139
+ 3. predict_no_ui_but_counting_down:一个函数,调用聊天接口,并且保留了一定的界面心跳功能,即当对话太长时,会自动采用二分法截断。
140
+
141
+ 4. write_results_to_file:一个函数,将对话记录history生成Markdown格式的文本,并写入文件中。
142
+
143
+ 5. regular_txt_to_markdown:一个函数,将普通文本转换为Markdown格式的文本。
144
+
145
+ 6. CatchException:一个装饰器函数,捕捉函数调度中的异常,并封装到一个生成器中返回,并显示到聊天当中。
146
+
147
+ 7. HotReload:一个装饰器函数,实现函数插件的热更新。
148
+
149
+ 8. report_execption:一个函数,向chatbot中添加错误信息。
150
+
151
+ 9. text_divide_paragraph:一个函数,将文本按照段落分隔符分割开,生成带有段落标签的HTML代码。
152
+
153
+ 10. markdown_convertion:一个函数,将Markdown格式的文本转换为HTML格式。如果包含数学公式,则先将公式转换为HTML格式。
154
+
155
+ 11. close_up_code_segment_during_stream:一个函数,用于在gpt输出代码的中途,即输出了前面的```,但还没输出完后面的```,补上后面的```。
156
+
157
+ 12. format_io:一个函数,将输入和输出解析为HTML格式。将输出部分的Markdown和数学公式转换为HTML格式。
158
+
159
+ 13. find_free_port:一个函数,返回当前系统中可用的未使用端口。
160
+
161
+ 14. extract_archive:一个函数,解压缩文件。
162
+
163
+ 15. find_recent_files:一个函数,查找目录下一分钟内创建的文件。
164
+
165
+ 16. on_file_uploaded:一个函数,响应用户上传的文件。
166
+
167
+ ## [10/22] 程序概述: crazy_functions\crazy_utils.py
168
+
169
+ 这是一个名为"crazy_utils.py"的Python程序文件,包含了两个函数:
170
+ 1. `breakdown_txt_to_satisfy_token_limit()`:接受文本字符串、计算文本单词数量的函数和单词数量限制作为输入参数,将长文本拆分成合适的长度,以满足单词数量限制。这个函数使用一个递归方法去拆分长文本。
171
+ 2. `breakdown_txt_to_satisfy_token_limit_for_pdf()`:类似于`breakdown_txt_to_satisfy_token_limit()`,但是它使用一个不同的递归方法来拆分长文本,以满足PDF文档中的需求。当出现无法继续拆分的情况时,该函数将使用一个中文句号标记插入文本来截断长文本。如果还是无法拆分,则会引发运行时异常。
172
+
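The recursive splitting idea can be sketched as follows; `get_token_num` stands in for the caller-supplied token counter, and cutting near the middle at a newline is an assumption about the strategy, not the exact implementation:

```python
def breakdown_sketch(txt, get_token_num, limit):
    # base case: the text already fits within the token limit
    if get_token_num(txt) <= limit:
        return [txt]
    # otherwise cut near the middle, preferring a paragraph/newline boundary
    mid = len(txt) // 2
    cut = txt.rfind("\n", 0, mid)
    if cut <= 0:
        cut = mid
    return (breakdown_sketch(txt[:cut], get_token_num, limit)
            + breakdown_sketch(txt[cut:], get_token_num, limit))
```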
173
+ ## [11/22] 程序概述: crazy_functions\__init__.py
174
+
175
+ 这个程序文件是一个 Python 的包,包名为 "crazy_functions",并且是其中的一个子模块 "__init__.py"。该包中可能包含多个函数或类,用于实现各种疯狂的功能。由于该文件的具体代码没有给出,因此无法进一步确定该包中的功能。通常情况下,一个包应该具有 __init__.py、__main__.py 和其它相关的模块文件,用于实现该包的各种功能。
176
+
177
+ ## [12/22] 程序概述: crazy_functions\下载arxiv论文翻译摘要.py
178
+
179
+ 这个程序实现的功能是下载arxiv论文并翻译摘要,文件名为`下载arxiv论文翻译摘要.py`。这个程序引入了`requests`、`unicodedata`、`os`、`re`等Python标准库,以及`pdfminer`、`bs4`等第三方库。其中`download_arxiv_`函数主要实现了从arxiv网站下载论文的功能,包括解析链接、获取论文信息、下载论文和生成文件名等,`get_name`函数则是为了从arxiv网站中获取论文信息创建的辅助函数。`下载arxiv论文并翻译摘要`函数则是实现了从下载好的PDF文件中提取摘要,然后使用预先训练的GPT模型翻译为中文的功能。同时,该函数还会将历史记录写入文件中。函数还会通过`CatchException`函数来捕获程序中出现的异常信息。
180
+
181
+ ## [13/22] 程序概述: crazy_functions\代码重写为全英文_多线程.py
182
+
183
+ 该程序文件为一个Python多线程程序,文件名为"crazy_functions\代码重写为全英文_多线程.py"。该程序使用了多线程技术,将一个大任务拆成多个小任务,同时执行,提高运行效率。
184
+
185
+ 程序的主要功能是将Python文件中的中文转换为英文,同时将转换后的代码输出。程序先清空历史记录,然后尝试导入openai和transformers等依赖库。程序接下来会读取当前路径下的.py文件和crazy_functions文件夹中的.py文件,并将其整合成一个文件清单。随后程序会使用GPT2模型进行中英文的翻译,并将结果保存在本地路径下的"gpt_log/generated_english_version"文件夹中。程序最终会生成一个任务执行报告。
186
+
187
+ 需要注意的是,该程序依赖于"request_llm"和"toolbox"库以及本地的"crazy_utils"模块。
188
+
189
+ ## [14/22] 程序概述: crazy_functions\总结word文档.py
190
+
191
+ 该程序文件是一个 Python 脚本文件,文件名为 ./crazy_functions/总结word文档.py。该脚本是一个函数插件,提供了名为“总结word文档”的函数。该函数的主要功能是批量读取给定文件夹下的 Word 文档文件,并使用 GPT 模型生成对每个文件的概述和意见建议。其中涉及到了读取 Word 文档、使用 GPT 模型等操作,依赖于许多第三方库。该文件也提供了导入依赖的方法,使用该脚本需要安装依赖库 python-docx 和 pywin32。函数功能实现的过程中,使用了一些用于调试的变量(如 fast_debug),可在需要时设置为 True。该脚本文件也提供了对程序功能和贡献者的注释。
192
+
193
+ ## [15/22] 程序概述: crazy_functions\批量总结PDF文档.py
194
+
195
+ 该程序文件名为 `./crazy_functions\批量总结PDF文档.py`,主要实现了批量处理PDF文档的功能。具体实现了以下几个函数:
196
+
197
+ 1. `is_paragraph_break(match)`:根据给定的匹配结果判断换行符是否表示段落分隔。
198
+ 2. `normalize_text(text)`:通过将文本特殊符号转换为其基本形式来对文本进行归一化处理。
199
+ 3. `clean_text(raw_text)`:对从 PDF 提取出的原始文本进行清洗和格式化处理。
200
+ 4. `解析PDF(file_manifest, project_folder, top_p, temperature, chatbot, history, systemPromptTxt)`:对给定的PDF文件进行分析并生成相应的概述。
201
+ 5. `批量总结PDF文档(txt, top_p, temperature, chatbot, history, systemPromptTxt, WEB_PORT)`:批量处理PDF文件,对其进行摘要生成。
202
+
203
+ 其中,主要用到了第三方库`pymupdf`对PDF文件进行处理。程序通过调用`fitz.open`函数打开PDF文件,使用`page.get_text()`方法获取PDF文本内容。然后,使用`clean_text`函数对文本进行清洗和格式化处理,生成最终的摘要。最后,调用`write_results_to_file`函数将历史记录写入文件并输出。
204
+
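The extraction step built on pymupdf can be sketched like this (fitz.open and page.get_text are the calls mentioned above; the cleaning and summarisation steps are omitted):

```python
import fitz  # pymupdf

def extract_pdf_text(fp):
    # read every page of the PDF and join the plain text
    doc = fitz.open(fp)
    raw_text = "\n".join(page.get_text() for page in doc)
    doc.close()
    return raw_text
```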
205
+ ## [16/22] 程序概述: crazy_functions\批量总结PDF文档pdfminer.py
206
+
207
+ 这个程序文件名是./crazy_functions\批量总结PDF文档pdfminer.py,是一个用于批量读取PDF文件,解析其中的内容,并对其进行概括的程序。程序中引用了pdfminer和beautifulsoup4等Python库,读取PDF文件并将其转化为文本内容,然后利用GPT模型生成摘要语言,最终输出一个中文和英文的摘要。程序还有一些错误处理的代码,会输出错误信息。
208
+
209
+ ## [17/22] 程序概述: crazy_functions\批量翻译PDF文档_多线程.py
210
+
211
+ 这是一个 Python 程序文件,文件名为 `批量翻译PDF文档_多线程.py`,包含多个函数。主要功能是批量处理 PDF 文档,解析其中的文本,进行清洗和格式化处理,并使用 OpenAI 的 GPT 模型进行翻译。其中使用了多线程技术来提高程序的效率和并行度。
212
+
213
+ ## [18/22] 程序概述: crazy_functions\生成函数注释.py
214
+
215
+ 该程序文件名为./crazy_functions\生成函数注释.py。该文件包含两个函数,分别为`生成函数注释`和`批量生成函数注释`。
216
+
217
+ 函数`生成函数注释`包含参数`file_manifest`、`project_folder`、`top_p`、`temperature`、`chatbot`、`history`和`systemPromptTxt`。其中,`file_manifest`为一个包含待处理文件路径的列表,`project_folder`表示项目文件夹路径,`top_p`和`temperature`是GPT模型参数,`chatbot`为与用户交互的聊天机器人,`history`记录聊天机器人与用户的历史记录,`systemPromptTxt`为聊天机器人发送信息前的提示语。`生成函数注释`通过读取文件内容,并调用GPT模型对文件中的所有函数生成注释,最后使用markdown表格输出结果。函数中还包含一些条件判断和计时器,以及调用其他自定义模块的函数。
218
+
219
+ 函数`批量生成函数注释`包含参数`txt`、`top_p`、`temperature`、`chatbot`、`history`、`systemPromptTxt`和`WEB_PORT`。其中,`txt`表示用户输入的项目文件夹路径,其他参数含义与`生成函数注释`中相同。`批量生成函数注释`主要是通过解析项目文件夹,获取所有待处理文件的路径,并调用函数`生成函数注释`对每个文件进行处理,最终生成注释表格输出给用户。
220
+
221
+ ## [19/22] 程序概述: crazy_functions\解析项目源代码.py
222
+
223
+ 该程序文件包含了多个函数,用于解析不同类型的项目,如Python项目、C项目、Java项目等。其中,最核心的函数是`解析源代码()`,它会对给定的一组文件进行分析,并返回对应的结果。具体流程如下:
224
+
225
+ 1. 遍历所有待分析的文件,对每个文件进行如下处理:
226
+
227
+ 1.1 从文件中读取代码内容,构造成一个字符串。
228
+
229
+ 1.2 构造一条GPT请求,向`predict_no_ui_but_counting_down()`函数发送请求,等待GPT回复。
230
+
231
+ 1.3 将GPT回复添加到机器人会话列表中,更新历史记录。
232
+
233
+ 1.4 如果不是快速调试模式,则等待2秒钟,继续分析下一个文件。
234
+
235
+ 2. 如果所有文件都分析完成,则向机器人会话列表中添加一条新消息,提示用户整个分析过程已经结束。
236
+
237
+ 3. 返回机器人会话列表和历史记录。
238
+
239
+ 除此之外,该程序文件还定义了若干个函数,用于针对不同类型的项目进行解析。这些函数会按照不同的方式调用`解析源代码()`函数。例如,对于Python项目,只需要分析.py文件;对于C项目,需要同时分析.h和.cpp文件等。每个函数中都会首先根据给定的项目路径读取相应的文件,然后调用`解析源代码()`函数进行分析。
240
+
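A simplified outline of that per-file loop; predict_no_ui_but_counting_down is the toolbox.py helper shown later in this diff, while the prompt wording and the placement of the 2-second pause are illustrative:

```python
import os
import time
from toolbox import predict_no_ui_but_counting_down

def analyse_files_sketch(file_manifest, top_p, temperature, chatbot, history):
    for fp in file_manifest:
        with open(fp, "r", encoding="utf-8") as f:
            code = f.read()
        i_say = f"请概述下面的程序文件 {os.path.basename(fp)}:\n```\n{code}\n```"
        chatbot.append((i_say, "[Local Message] waiting gpt response."))
        yield chatbot, history, "正常"
        # the helper yields UI heartbeats and finally returns the GPT reply
        gpt_say = yield from predict_no_ui_but_counting_down(
            i_say, i_say, chatbot, top_p, temperature, history=history)
        chatbot[-1] = (i_say, gpt_say)
        history.extend([i_say, gpt_say])
        time.sleep(2)  # pause between files unless in fast-debug mode
        yield chatbot, history, "正常"
```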
241
+ ## [20/22] 程序概述: crazy_functions\读文章写摘要.py
242
+
243
+ 该程序文件为一个名为“读文章写摘要”的Python函数,用于解析项目文件夹中的.tex文件,并使用GPT模型生成文章的中英文摘要。函数使用了request_llm.bridge_chatgpt和toolbox模块中的函数,并包含两个子函数:解析Paper和CatchException。函数参数包括txt,top_p,temperature,chatbot,history,systemPromptTxt和WEB_PORT。执行过程中函数首先清空历史,然后根据项目文件夹中的.tex文件列表,对每个文件调用解析Paper函数生成中文摘要,最后根据所有文件的中文摘要,调用GPT模型生成英文摘要。函数运行过程中会将结果写入文件并返回聊天机器人和历史记录。
244
+
245
+ ## [21/22] 程序概述: crazy_functions\高级功能函数模板.py
246
+
247
+ 该程序文件为一个高级功能函数模板,文件名为"./crazy_functions\高级功能函数模板.py"。
248
+
249
+ 该文件导入了两个模块,分别是"request_llm.bridge_chatgpt"和"toolbox"。其中"request_llm.bridge_chatgpt"模块包含了一个函数"predict_no_ui_long_connection",该函数用于请求GPT模型进行对话生成。"toolbox"模块包含了三个函数,分别是"catchException"、"report_exception"和"write_results_to_file"函数,这三个函数主要用于异常处理和日志记录等。
250
+
251
+ 该文件定义了一个名为"高阶功能模板函数"的函数,并通过"decorator"装饰器将该函数装饰为一个异常处理函数,可以处理函数执行过程中出现的错误。该函数的作用是生成历史事件查询的问题,并向用户询问历史中哪些事件发生在指定日期,并索要相关图片。在查询完所有日期后,该函数返回所有历史事件及其相关图片的列表。其中,该函数的输入参数包括:
252
+
253
+ 1. txt: 一个字符串,表示当前消息的文本内容。
254
+ 2. top_p: 一个浮点数,表示GPT模型生成文本时的"top_p"参数。
255
+ 3. temperature: 一个浮点数,表示GPT模型生成文本时的"temperature"参数。
256
+ 4. chatbot: 一个列表,表示当前对话的记录列表。
257
+ 5. history: 一个列表,表示当前对话的历史记录列表。
258
+ 6. systemPromptTxt: 一个字符串,表示当前对话的系统提示信息。
259
+ 7. WEB_PORT: 一个整数,表示当前应用程序的WEB端口号。
260
+
261
+ 该函数在执行过程中,会先清空历史记录,以免输入溢出。然后,它会循环5次,生成5个历史事件查询的问题,并向用户请求输入相关信息。每次询问不携带之前的询问历史。在生成每个问题时,该函数会向"chatbot"列表中添加一条消息记录,并设置该记录的初始状态为"[Local Message] waiting gpt response."。然后,该函数会调用"predict_no_ui_long_connection"函数向GPT模型请求生成一段文本,并将生成的文本作为回答。如果请求过程中出现异常,该函数会忽略异常。最后,该函数将问题和回答添加到"chatbot"列表和"history"列表中,并将"chatbot"和"history"列表作为函数的返回值返回。
262
+
theme.py CHANGED
@@ -1,4 +1,4 @@
1
- import gradio as gr
2
 
3
  # gradio可用颜色列表
4
  # gr.themes.utils.colors.slate (石板色)
@@ -24,14 +24,16 @@ import gradio as gr
24
  # gr.themes.utils.colors.pink (粉红色)
25
  # gr.themes.utils.colors.rose (玫瑰色)
26
 
 
27
  def adjust_theme():
28
- try:
29
- color_er = gr.themes.utils.colors.pink
30
- set_theme = gr.themes.Default(
31
- primary_hue=gr.themes.utils.colors.orange,
32
- neutral_hue=gr.themes.utils.colors.gray,
33
- font=["sans-serif", "Microsoft YaHei", "ui-sans-serif", "system-ui", "sans-serif", gr.themes.utils.fonts.GoogleFont("Source Sans Pro")],
34
- font_mono=["ui-monospace", "Consolas", "monospace", gr.themes.utils.fonts.GoogleFont("IBM Plex Mono")])
 
35
  set_theme.set(
36
  # Colors
37
  input_background_fill_dark="*neutral_800",
@@ -77,6 +79,78 @@ def adjust_theme():
77
  button_cancel_text_color=color_er.c600,
78
  button_cancel_text_color_dark="white",
79
  )
80
- except:
81
- set_theme = None; print('gradio版本较旧, 不能自定义字体和颜色')
 
82
  return set_theme
 
1
+ import gradio as gr
2
 
3
  # gradio可用颜色列表
4
  # gr.themes.utils.colors.slate (石板色)
 
24
  # gr.themes.utils.colors.pink (粉红色)
25
  # gr.themes.utils.colors.rose (玫瑰色)
26
 
27
+
28
  def adjust_theme():
29
+ try:
30
+ color_er = gr.themes.utils.colors.fuchsia
31
+ set_theme = gr.themes.Default(
32
+ primary_hue=gr.themes.utils.colors.orange,
33
+ neutral_hue=gr.themes.utils.colors.gray,
34
+ font=["sans-serif", "Microsoft YaHei", "ui-sans-serif", "system-ui",
35
+ "sans-serif", gr.themes.utils.fonts.GoogleFont("Source Sans Pro")],
36
+ font_mono=["ui-monospace", "Consolas", "monospace", gr.themes.utils.fonts.GoogleFont("IBM Plex Mono")])
37
  set_theme.set(
38
  # Colors
39
  input_background_fill_dark="*neutral_800",
 
79
  button_cancel_text_color=color_er.c600,
80
  button_cancel_text_color_dark="white",
81
  )
82
+ except:
83
+ set_theme = None
84
+ print('gradio版本较旧, 不能自定义字体和颜色')
85
  return set_theme
86
+
87
+
88
+ advanced_css = """
89
+ /* 设置表格的外边距为1em,内部单元格之间边框合并,空单元格显示. */
90
+ .markdown-body table {
91
+ margin: 1em 0;
92
+ border-collapse: collapse;
93
+ empty-cells: show;
94
+ }
95
+
96
+ /* 设置表格单元格的内边距为5px,边框粗细为1.2px,颜色为--border-color-primary. */
97
+ .markdown-body th, .markdown-body td {
98
+ border: 1.2px solid var(--border-color-primary);
99
+ padding: 5px;
100
+ }
101
+
102
+ /* 设置表头背景颜色为rgba(175,184,193,0.2),透明度为0.2. */
103
+ .markdown-body thead {
104
+ background-color: rgba(175,184,193,0.2);
105
+ }
106
+
107
+ /* 设置表头单元格的内边距为0.5em和0.2em. */
108
+ .markdown-body thead th {
109
+ padding: .5em .2em;
110
+ }
111
+
112
+ /* 去掉列表前缀的默认间距,使其与文本线对齐. */
113
+ .markdown-body ol, .markdown-body ul {
114
+ padding-inline-start: 2em !important;
115
+ }
116
+
117
+ /* 设定聊天气泡的样式,包括圆角、最大宽度和阴影等. */
118
+ [class *= "message"] {
119
+ border-radius: var(--radius-xl) !important;
120
+ /* padding: var(--spacing-xl) !important; */
121
+ /* font-size: var(--text-md) !important; */
122
+ /* line-height: var(--line-md) !important; */
123
+ /* min-height: calc(var(--text-md)*var(--line-md) + 2*var(--spacing-xl)); */
124
+ /* min-width: calc(var(--text-md)*var(--line-md) + 2*var(--spacing-xl)); */
125
+ }
126
+ [data-testid = "bot"] {
127
+ max-width: 95%;
128
+ /* width: auto !important; */
129
+ border-bottom-left-radius: 0 !important;
130
+ }
131
+ [data-testid = "user"] {
132
+ max-width: 100%;
133
+ /* width: auto !important; */
134
+ border-bottom-right-radius: 0 !important;
135
+ }
136
+
137
+ /* 行内代码的背景设为淡灰色,设定圆角和间距. */
138
+ .markdown-body code {
139
+ display: inline;
140
+ white-space: break-spaces;
141
+ border-radius: 6px;
142
+ margin: 0 2px 0 2px;
143
+ padding: .2em .4em .1em .4em;
144
+ background-color: rgba(175,184,193,0.2);
145
+ }
146
+ /* 设定代码块的样式,包括背景颜色、内、外边距、圆角。 */
147
+ .markdown-body pre code {
148
+ display: block;
149
+ overflow: auto;
150
+ white-space: pre;
151
+ background-color: rgba(175,184,193,0.2);
152
+ border-radius: 10px;
153
+ padding: 1em;
154
+ margin: 1em 2em 1em 0.5em;
155
+ }
156
+ """
toolbox.py CHANGED
@@ -1,67 +1,134 @@
1
- import markdown, mdtex2html, threading, importlib, traceback
2
- from show_math import convert as convert_math
3
- from functools import wraps
 
 
 
 
 
 
 
4
 
5
- def predict_no_ui_but_counting_down(i_say, i_say_show_user, chatbot, top_p, temperature, history=[], sys_prompt=''):
6
  """
7
  调用简单的predict_no_ui接口,但是依然保留了些许界面心跳功能,当对话太长时,会自动采用二分法截断
 
 
 
 
 
 
 
8
  """
9
  import time
10
- from predict import predict_no_ui
11
  from toolbox import get_conf
12
  TIMEOUT_SECONDS, MAX_RETRY = get_conf('TIMEOUT_SECONDS', 'MAX_RETRY')
13
  # 多线程的时候,需要一个mutable结构在不同线程之间传递信息
14
  # list就是最简单的mutable结构,我们第一个位置放gpt输出,第二个位置传递报错信息
15
  mutable = [None, '']
16
  # multi-threading worker
 
17
  def mt(i_say, history):
18
  while True:
19
  try:
20
- mutable[0] = predict_no_ui(inputs=i_say, top_p=top_p, temperature=temperature, history=history, sys_prompt=sys_prompt)
 
 
 
 
 
21
  break
22
- except ConnectionAbortedError as e:
 
 
 
23
  if len(history) > 0:
24
- history = [his[len(his)//2:] for his in history if his is not None]
25
- mutable[1] = 'Warning! History conversation is too long, cut into half. '
26
  else:
27
- i_say = i_say[:len(i_say)//2]
28
- mutable[1] = 'Warning! Input file is too long, cut into half. '
29
  except TimeoutError as e:
30
- mutable[0] = '[Local Message] Failed with timeout.'
31
  raise TimeoutError
 
 
 
32
  # 创建新线程发出http请求
33
- thread_name = threading.Thread(target=mt, args=(i_say, history)); thread_name.start()
 
34
  # 原来的线程则负责持续更新UI,实现一个超时倒计时,并等待新线程的任务完成
35
  cnt = 0
36
  while thread_name.is_alive():
37
  cnt += 1
38
- chatbot[-1] = (i_say_show_user, f"[Local Message] {mutable[1]}waiting gpt response {cnt}/{TIMEOUT_SECONDS*2*(MAX_RETRY+1)}"+''.join(['.']*(cnt%4)))
 
39
  yield chatbot, history, '正常'
40
  time.sleep(1)
41
  # 把gpt的输出从mutable中取出来
42
  gpt_say = mutable[0]
43
- if gpt_say=='[Local Message] Failed with timeout.': raise TimeoutError
 
44
  return gpt_say
45
 
 
46
  def write_results_to_file(history, file_name=None):
47
  """
48
  将对话记录history以Markdown格式写入文件中。如果没有指定文件名,则使用当前时间生成文件名。
49
  """
50
- import os, time
 
51
  if file_name is None:
52
  # file_name = time.strftime("chatGPT分析报告%Y-%m-%d-%H-%M-%S", time.localtime()) + '.md'
53
- file_name = 'chatGPT分析报告' + time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) + '.md'
 
54
  os.makedirs('./gpt_log/', exist_ok=True)
55
- with open(f'./gpt_log/{file_name}', 'w', encoding = 'utf8') as f:
56
  f.write('# chatGPT 分析报告\n')
57
  for i, content in enumerate(history):
58
- if i%2==0: f.write('## ')
 
 
 
 
 
 
59
  f.write(content)
60
  f.write('\n\n')
61
  res = '以上材料已经被写入' + os.path.abspath(f'./gpt_log/{file_name}')
62
  print(res)
63
  return res
64
 
 
65
  def regular_txt_to_markdown(text):
66
  """
67
  将普通文本转换为Markdown格式的文本。
@@ -71,6 +138,7 @@ def regular_txt_to_markdown(text):
71
  text = text.replace('\n\n\n', '\n\n')
72
  return text
73
 
 
74
  def CatchException(f):
75
  """
76
  装饰器函数,捕捉函数f中的异常并封装到一个生成器中返回,并显示到聊天当中。
@@ -83,17 +151,35 @@ def CatchException(f):
83
  from check_proxy import check_proxy
84
  from toolbox import get_conf
85
  proxies, = get_conf('proxies')
86
- tb_str = regular_txt_to_markdown(traceback.format_exc())
87
- chatbot[-1] = (chatbot[-1][0], f"[Local Message] 实验性函数调用出错: \n\n {tb_str} \n\n 当前代理可用性: \n\n {check_proxy(proxies)}")
 
 
 
88
  yield chatbot, history, f'异常 {e}'
89
  return decorated
90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  def report_execption(chatbot, history, a, b):
92
  """
93
  向chatbot中添加错误信息
94
  """
95
  chatbot.append((a, b))
96
- history.append(a); history.append(b)
 
 
97
 
98
  def text_divide_paragraph(text):
99
  """
@@ -110,26 +196,105 @@ def text_divide_paragraph(text):
110
  text = "</br>".join(lines)
111
  return text
112
 
 
113
  def markdown_convertion(txt):
114
  """
115
  将Markdown格式的文本转换为HTML格式。如果包含数学公式,则先将公式转换为HTML格式。
116
  """
117
- if ('$' in txt) and ('```' not in txt):
118
- return markdown.markdown(txt,extensions=['fenced_code','tables']) + '<br><br>' + \
119
- markdown.markdown(convert_math(txt, splitParagraphs=False),extensions=['fenced_code','tables'])
120
  else:
121
- return markdown.markdown(txt,extensions=['fenced_code','tables'])
122
 
123
 
124
  def format_io(self, y):
125
  """
126
  将输入和输出解析为HTML格式。将y中最后一项的输入部分段落化,并将输出部分的Markdown和数学公式转换为HTML格式。
127
  """
128
- if y is None or y == []: return []
 
129
  i_ask, gpt_reply = y[-1]
130
- i_ask = text_divide_paragraph(i_ask) # 输入部分太自由,预处理一波
 
 
131
  y[-1] = (
132
- None if i_ask is None else markdown.markdown(i_ask, extensions=['fenced_code','tables']),
 
133
  None if gpt_reply is None else markdown_convertion(gpt_reply)
134
  )
135
  return y
@@ -164,8 +329,33 @@ def extract_archive(file_path, dest_dir):
164
  with tarfile.open(file_path, 'r:*') as tarobj:
165
  tarobj.extractall(path=dest_dir)
166
  print("Successfully extracted tar archive to {}".format(dest_dir))
167
  else:
168
- return
 
 
169
 
170
  def find_recent_files(directory):
171
  """
@@ -180,59 +370,101 @@ def find_recent_files(directory):
180
 
181
  for filename in os.listdir(directory):
182
  file_path = os.path.join(directory, filename)
183
- if file_path.endswith('.log'): continue
184
- created_time = os.path.getctime(file_path)
 
185
  if created_time >= one_minute_ago:
186
- if os.path.isdir(file_path): continue
 
187
  recent_files.append(file_path)
188
 
189
  return recent_files
190
 
191
 
192
  def on_file_uploaded(files, chatbot, txt):
193
- if len(files) == 0: return chatbot, txt
194
- import shutil, os, time, glob
 
 
 
 
195
  from toolbox import extract_archive
196
- try: shutil.rmtree('./private_upload/')
197
- except: pass
 
 
198
  time_tag = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
199
  os.makedirs(f'private_upload/{time_tag}', exist_ok=True)
 
200
  for file in files:
201
  file_origin_name = os.path.basename(file.orig_name)
202
  shutil.copy(file.name, f'private_upload/{time_tag}/{file_origin_name}')
203
- extract_archive(f'private_upload/{time_tag}/{file_origin_name}',
204
- dest_dir=f'private_upload/{time_tag}/{file_origin_name}.extract')
205
- moved_files = [fp for fp in glob.glob('private_upload/**/*', recursive=True)]
 
206
  txt = f'private_upload/{time_tag}'
207
  moved_files_str = '\t\n\n'.join(moved_files)
208
- chatbot.append(['我上传了文件,请查收',
209
- f'[Local Message] 收到以下文件: \n\n{moved_files_str}\n\n调用路径参数已自动修正到: \n\n{txt}\n\n现在您点击任意实验功能时,以上文件将被作为输入参数'])
 
 
210
  return chatbot, txt
211
 
212
 
213
  def on_report_generated(files, chatbot):
214
  from toolbox import find_recent_files
215
  report_files = find_recent_files('gpt_log')
216
- if len(report_files) == 0: return report_files, chatbot
 
217
  # files.extend(report_files)
218
- chatbot.append(['汇总报告如何远程获取?', '汇总报告已经添加到右侧文件上传区,请查收。'])
219
  return report_files, chatbot
220
221
  def get_conf(*args):
222
  # 建议您复制一个config_private.py放自己的秘密, 如API和代理网址, 避免不小心传github被别人看到
223
  res = []
224
  for arg in args:
225
- try: r = getattr(importlib.import_module('config_private'), arg)
226
- except: r = getattr(importlib.import_module('config'), arg)
227
  res.append(r)
228
- # 在读取API_KEY时,检查一下是不是忘了改config
229
- if arg=='API_KEY' and len(r) != 51:
230
- assert False, "正确的API_KEY密钥是51位,请在config文件中修改API密钥, 添加海外代理之后再运行。" + \
231
- "(如果您刚更新过代码,请确保旧版config_private文件中没有遗留任何新增键值)"
232
  return res
233
 
 
234
  def clear_line_break(txt):
235
  txt = txt.replace('\n', ' ')
236
  txt = txt.replace(' ', ' ')
237
  txt = txt.replace(' ', ' ')
238
- return txt
 
 
 
 
 
 
 
 
 
1
+ import markdown
2
+ import mdtex2html
3
+ import threading
4
+ import importlib
5
+ import traceback
6
+ import importlib
7
+ import inspect
8
+ import re
9
+ from latex2mathml.converter import convert as tex2mathml
10
+ from functools import wraps, lru_cache
11
 
12
+
13
+ def ArgsGeneralWrapper(f):
14
+ """
15
+ 装饰器函数,用于重组输入参数,改变输入参数的顺序与结构。
16
+ """
17
+ def decorated(txt, txt2, *args, **kwargs):
18
+ txt_passon = txt
19
+ if txt == "" and txt2 != "":
20
+ txt_passon = txt2
21
+ yield from f(txt_passon, *args, **kwargs)
22
+ return decorated
23
+
24
+
25
+ def get_reduce_token_percent(text):
26
+ try:
27
+ # text = "maximum context length is 4097 tokens. However, your messages resulted in 4870 tokens"
28
+ pattern = r"(\d+)\s+tokens\b"
29
+ match = re.findall(pattern, text)
30
+ EXCEED_ALLO = 500 # 稍微留一点余地,否则在回复时会因余量太少出问题
31
+ max_limit = float(match[0]) - EXCEED_ALLO
32
+ current_tokens = float(match[1])
33
+ ratio = max_limit/current_tokens
34
+ assert ratio > 0 and ratio < 1
35
+ return ratio, str(int(current_tokens-max_limit))
36
+ except:
37
+ return 0.5, '不详'
38
+
39
+
40
+ def predict_no_ui_but_counting_down(i_say, i_say_show_user, chatbot, top_p, temperature, history=[], sys_prompt='', long_connection=True):
41
  """
42
  调用简单的predict_no_ui接口,但是依然保留了些许界面心跳功能,当对话太长时,会自动采用二分法截断
43
+ i_say: 当前输入
44
+ i_say_show_user: 显示到对话界面上的当前输入,例如,输入整个文件时,你绝对不想把文件的内容都糊到对话界面上
45
+ chatbot: 对话界面句柄
46
+ top_p, temperature: gpt参数
47
+ history: gpt参数 对话历史
48
+ sys_prompt: gpt参数 sys_prompt
49
+ long_connection: 是否采用更稳定的连接方式(推荐)
50
  """
51
  import time
52
+ from request_llm.bridge_chatgpt import predict_no_ui, predict_no_ui_long_connection
53
  from toolbox import get_conf
54
  TIMEOUT_SECONDS, MAX_RETRY = get_conf('TIMEOUT_SECONDS', 'MAX_RETRY')
55
  # 多线程的时候,需要一个mutable结构在不同线程之间传递信息
56
  # list就是最简单的mutable结构,我们第一个位置放gpt输出,第二个位置传递报错信息
57
  mutable = [None, '']
58
  # multi-threading worker
59
+
60
  def mt(i_say, history):
61
  while True:
62
  try:
63
+ if long_connection:
64
+ mutable[0] = predict_no_ui_long_connection(
65
+ inputs=i_say, top_p=top_p, temperature=temperature, history=history, sys_prompt=sys_prompt)
66
+ else:
67
+ mutable[0] = predict_no_ui(
68
+ inputs=i_say, top_p=top_p, temperature=temperature, history=history, sys_prompt=sys_prompt)
69
  break
70
+ except ConnectionAbortedError as token_exceeded_error:
71
+ # 尝试计算比例,尽可能多地保留文本
72
+ p_ratio, n_exceed = get_reduce_token_percent(
73
+ str(token_exceeded_error))
74
  if len(history) > 0:
75
+ history = [his[int(len(his) * p_ratio):]
76
+ for his in history if his is not None]
77
  else:
78
+ i_say = i_say[: int(len(i_say) * p_ratio)]
79
+ mutable[1] = f'警告,文本过长将进行截断,Token溢出数:{n_exceed},截断比例:{(1-p_ratio):.0%}。'
80
  except TimeoutError as e:
81
+ mutable[0] = '[Local Message] 请求超时。'
82
  raise TimeoutError
83
+ except Exception as e:
84
+ mutable[0] = f'[Local Message] 异常:{str(e)}.'
85
+ raise RuntimeError(f'[Local Message] 异常:{str(e)}.')
86
  # 创建新线程发出http请求
87
+ thread_name = threading.Thread(target=mt, args=(i_say, history))
88
+ thread_name.start()
89
  # 原来的线程则负责持续更新UI,实现一个超时倒计时,并等待新线程的任务完成
90
  cnt = 0
91
  while thread_name.is_alive():
92
  cnt += 1
93
+ chatbot[-1] = (i_say_show_user,
94
+ f"[Local Message] {mutable[1]}waiting gpt response {cnt}/{TIMEOUT_SECONDS*2*(MAX_RETRY+1)}"+''.join(['.']*(cnt % 4)))
95
  yield chatbot, history, '正常'
96
  time.sleep(1)
97
  # 把gpt的输出从mutable中取出来
98
  gpt_say = mutable[0]
99
+ if gpt_say == '[Local Message] 请求超时。':
100
+ raise TimeoutError
101
  return gpt_say
102
 
103
+
104
  def write_results_to_file(history, file_name=None):
105
  """
106
  将对话记录history以Markdown格式写入文件中。如果没有指定文件名,则使用当前时间生成文件名。
107
  """
108
+ import os
109
+ import time
110
  if file_name is None:
111
  # file_name = time.strftime("chatGPT分析报告%Y-%m-%d-%H-%M-%S", time.localtime()) + '.md'
112
+ file_name = 'chatGPT分析报告' + \
113
+ time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) + '.md'
114
  os.makedirs('./gpt_log/', exist_ok=True)
115
+ with open(f'./gpt_log/{file_name}', 'w', encoding='utf8') as f:
116
  f.write('# chatGPT 分析报告\n')
117
  for i, content in enumerate(history):
118
+ try: # 这个bug没找到触发条件,暂时先这样顶一下
119
+ if type(content) != str:
120
+ content = str(content)
121
+ except:
122
+ continue
123
+ if i % 2 == 0:
124
+ f.write('## ')
125
  f.write(content)
126
  f.write('\n\n')
127
  res = '以上材料已经被写入' + os.path.abspath(f'./gpt_log/{file_name}')
128
  print(res)
129
  return res
130
 
131
+
132
  def regular_txt_to_markdown(text):
133
  """
134
  将普通文本转换为Markdown格式的文本。
 
138
  text = text.replace('\n\n\n', '\n\n')
139
  return text
140
 
141
+
142
  def CatchException(f):
143
  """
144
  装饰器函数,捕捉函数f中的异常并封装到一个生成器中返回,并显示到聊天当中。
 
151
  from check_proxy import check_proxy
152
  from toolbox import get_conf
153
  proxies, = get_conf('proxies')
154
+ tb_str = '```\n' + traceback.format_exc() + '```'
155
+ if chatbot is None or len(chatbot) == 0:
156
+ chatbot = [["插件调度异常", "异常原因"]]
157
+ chatbot[-1] = (chatbot[-1][0],
158
+ f"[Local Message] 实验性函数调用出错: \n\n{tb_str} \n\n当前代理可用性: \n\n{check_proxy(proxies)}")
159
  yield chatbot, history, f'异常 {e}'
160
  return decorated
161
 
162
+
163
+ def HotReload(f):
164
+ """
165
+ 装饰器函数,实现函数插件热更新
166
+ """
167
+ @wraps(f)
168
+ def decorated(*args, **kwargs):
169
+ fn_name = f.__name__
170
+ f_hot_reload = getattr(importlib.reload(inspect.getmodule(f)), fn_name)
171
+ yield from f_hot_reload(*args, **kwargs)
172
+ return decorated
173
+
174
+
175
  def report_execption(chatbot, history, a, b):
176
  """
177
  向chatbot中添加错误信息
178
  """
179
  chatbot.append((a, b))
180
+ history.append(a)
181
+ history.append(b)
182
+
183
 
184
  def text_divide_paragraph(text):
185
  """
 
196
  text = "</br>".join(lines)
197
  return text
198
 
199
+
200
  def markdown_convertion(txt):
201
  """
202
  将Markdown格式的文本转换为HTML格式。如果包含数学公式,则先将公式转换为HTML格式。
203
  """
204
+ pre = '<div class="markdown-body">'
205
+ suf = '</div>'
206
+ markdown_extension_configs = {
207
+ 'mdx_math': {
208
+ 'enable_dollar_delimiter': True,
209
+ 'use_gitlab_delimiters': False,
210
+ },
211
+ }
212
+ find_equation_pattern = r'<script type="math/tex(?:.*?)>(.*?)</script>'
213
+
214
+ def tex2mathml_catch_exception(content, *args, **kwargs):
215
+ try:
216
+ content = tex2mathml(content, *args, **kwargs)
217
+ except:
218
+ content = content
219
+ return content
220
+
221
+ def replace_math_no_render(match):
222
+ content = match.group(1)
223
+ if 'mode=display' in match.group(0):
224
+ content = content.replace('\n', '</br>')
225
+ return f"<font color=\"#00FF00\">$$</font><font color=\"#FF00FF\">{content}</font><font color=\"#00FF00\">$$</font>"
226
+ else:
227
+ return f"<font color=\"#00FF00\">$</font><font color=\"#FF00FF\">{content}</font><font color=\"#00FF00\">$</font>"
228
+
229
+ def replace_math_render(match):
230
+ content = match.group(1)
231
+ if 'mode=display' in match.group(0):
232
+ if '\\begin{aligned}' in content:
233
+ content = content.replace('\\begin{aligned}', '\\begin{array}')
234
+ content = content.replace('\\end{aligned}', '\\end{array}')
235
+ content = content.replace('&', ' ')
236
+ content = tex2mathml_catch_exception(content, display="block")
237
+ return content
238
+ else:
239
+ return tex2mathml_catch_exception(content)
240
+
241
+ def markdown_bug_hunt(content):
242
+ """
243
+ 解决一个mdx_math的bug(单$包裹begin命令时多余<script>)
244
+ """
245
+ content = content.replace('<script type="math/tex">\n<script type="math/tex; mode=display">', '<script type="math/tex; mode=display">')
246
+ content = content.replace('</script>\n</script>', '</script>')
247
+ return content
248
+
249
+
250
+ if ('$' in txt) and ('```' not in txt): # 有$标识的公式符号,且没有代码段```的标识
251
+ # convert everything to html format
252
+ split = markdown.markdown(text='---')
253
+ convert_stage_1 = markdown.markdown(text=txt, extensions=['mdx_math', 'fenced_code', 'tables', 'sane_lists'], extension_configs=markdown_extension_configs)
254
+ convert_stage_1 = markdown_bug_hunt(convert_stage_1)
255
+ # re.DOTALL: Make the '.' special character match any character at all, including a newline; without this flag, '.' will match anything except a newline. Corresponds to the inline flag (?s).
256
+ # 1. convert to easy-to-copy tex (do not render math)
257
+ convert_stage_2_1, n = re.subn(find_equation_pattern, replace_math_no_render, convert_stage_1, flags=re.DOTALL)
258
+ # 2. convert to rendered equation
259
+ convert_stage_2_2, n = re.subn(find_equation_pattern, replace_math_render, convert_stage_1, flags=re.DOTALL)
260
+ # cat them together
261
+ return pre + convert_stage_2_1 + f'{split}' + convert_stage_2_2 + suf
262
  else:
263
+ return pre + markdown.markdown(txt, extensions=['fenced_code', 'tables', 'sane_lists']) + suf
264
+
265
+
266
+ def close_up_code_segment_during_stream(gpt_reply):
267
+ """
268
+ 在gpt输出代码的中途(输出了前面的```,但还没输出完后面的```),补上后面的```
269
+ """
270
+ if '```' not in gpt_reply:
271
+ return gpt_reply
272
+ if gpt_reply.endswith('```'):
273
+ return gpt_reply
274
+
275
+ # 排除了以上两个情况,我们
276
+ segments = gpt_reply.split('```')
277
+ n_mark = len(segments) - 1
278
+ if n_mark % 2 == 1:
279
+ # print('输出代码片段中!')
280
+ return gpt_reply+'\n```'
281
+ else:
282
+ return gpt_reply
283
 
284
 
285
  def format_io(self, y):
286
  """
287
  将输入和输出解析为HTML格式。将y中最后一项的输入部分段落化,并将输出部分的Markdown和数学公式转换为HTML格式。
288
  """
289
+ if y is None or y == []:
290
+ return []
291
  i_ask, gpt_reply = y[-1]
292
+ i_ask = text_divide_paragraph(i_ask) # 输入部分太自由,预处理一波
293
+ gpt_reply = close_up_code_segment_during_stream(
294
+ gpt_reply) # 当代码输出半截的时候,试着补上后个```
295
  y[-1] = (
296
+ None if i_ask is None else markdown.markdown(
297
+ i_ask, extensions=['fenced_code', 'tables']),
298
  None if gpt_reply is None else markdown_convertion(gpt_reply)
299
  )
300
  return y
 
329
  with tarfile.open(file_path, 'r:*') as tarobj:
330
  tarobj.extractall(path=dest_dir)
331
  print("Successfully extracted tar archive to {}".format(dest_dir))
332
+
333
+ # 第三方库,需要预先pip install rarfile
334
+ # 此外,Windows上还需要安装winrar软件,配置其Path环境变量,如"C:\Program Files\WinRAR"才可以
335
+ elif file_extension == '.rar':
336
+ try:
337
+ import rarfile
338
+ with rarfile.RarFile(file_path) as rf:
339
+ rf.extractall(path=dest_dir)
340
+ print("Successfully extracted rar archive to {}".format(dest_dir))
341
+ except:
342
+ print("Rar format requires additional dependencies to install")
343
+ return '\n\n需要安装pip install rarfile来解压rar文件'
344
+
345
+ # 第三方库,需要预先pip install py7zr
346
+ elif file_extension == '.7z':
347
+ try:
348
+ import py7zr
349
+ with py7zr.SevenZipFile(file_path, mode='r') as f:
350
+ f.extractall(path=dest_dir)
351
+ print("Successfully extracted 7z archive to {}".format(dest_dir))
352
+ except:
353
+ print("7z format requires additional dependencies to install")
354
+ return '\n\n需要安装pip install py7zr来解压7z文件'
355
  else:
356
+ return ''
357
+ return ''
358
+
359
 
360
  def find_recent_files(directory):
361
  """
 
370
 
371
  for filename in os.listdir(directory):
372
  file_path = os.path.join(directory, filename)
373
+ if file_path.endswith('.log'):
374
+ continue
375
+ created_time = os.path.getmtime(file_path)
376
  if created_time >= one_minute_ago:
377
+ if os.path.isdir(file_path):
378
+ continue
379
  recent_files.append(file_path)
380
 
381
  return recent_files
382
 
383
 
384
  def on_file_uploaded(files, chatbot, txt):
385
+ if len(files) == 0:
386
+ return chatbot, txt
387
+ import shutil
388
+ import os
389
+ import time
390
+ import glob
391
  from toolbox import extract_archive
392
+ try:
393
+ shutil.rmtree('./private_upload/')
394
+ except:
395
+ pass
396
  time_tag = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
397
  os.makedirs(f'private_upload/{time_tag}', exist_ok=True)
398
+ err_msg = ''
399
  for file in files:
400
  file_origin_name = os.path.basename(file.orig_name)
401
  shutil.copy(file.name, f'private_upload/{time_tag}/{file_origin_name}')
402
+ err_msg += extract_archive(f'private_upload/{time_tag}/{file_origin_name}',
403
+ dest_dir=f'private_upload/{time_tag}/{file_origin_name}.extract')
404
+ moved_files = [fp for fp in glob.glob(
405
+ 'private_upload/**/*', recursive=True)]
406
  txt = f'private_upload/{time_tag}'
407
  moved_files_str = '\t\n\n'.join(moved_files)
408
+ chatbot.append(['我上传了文件,请查收',
409
+ f'[Local Message] 收到以下文件: \n\n{moved_files_str}' +
410
+ f'\n\n调用路径参数已自动修正到: \n\n{txt}' +
411
+ f'\n\n现在您点击任意实验功能时,以上文件将被作为输入参数'+err_msg])
412
  return chatbot, txt
413
 
414
 
415
  def on_report_generated(files, chatbot):
416
  from toolbox import find_recent_files
417
  report_files = find_recent_files('gpt_log')
418
+ if len(report_files) == 0:
419
+ return None, chatbot
420
  # files.extend(report_files)
421
+ chatbot.append(['汇总报告如何远程获取?', '汇总报告已经添加到右侧“文件上传区”(可能处于折叠状态),请查收。'])
422
  return report_files, chatbot
423
 
424
+
425
+ @lru_cache(maxsize=128)
426
+ def read_single_conf_with_lru_cache(arg):
427
+ try:
428
+ r = getattr(importlib.import_module('config_private'), arg)
429
+ except:
430
+ r = getattr(importlib.import_module('config'), arg)
431
+ # 在读取API_KEY时,检查一下是不是忘了改config
432
+ if arg == 'API_KEY':
433
+ # 正确的 API_KEY 是 "sk-" + 48 位大小写字母数字的组合
434
+ API_MATCH = re.match(r"sk-[a-zA-Z0-9]{48}$", r)
435
+ if API_MATCH:
436
+ print(f"[API_KEY] 您的 API_KEY 是: {r[:15]}*** API_KEY 导入成功")
437
+ else:
438
+ assert False, "正确的 API_KEY 是 'sk-' + '48 位大小写字母数字' 的组合,请在config文件中修改API密钥, 添加海外代理之后再运行。" + \
439
+ "(如果您刚更新过代码,请确保旧版config_private文件中没有遗留任何新增键值)"
440
+ if arg == 'proxies':
441
+ if r is None:
442
+ print('[PROXY] 网络代理状态:未配置。无代理状态下很可能无法访问。建议:检查USE_PROXY选项是否修改。')
443
+ else:
444
+ print('[PROXY] 网络代理状态:已配置。配置信息如下:', r)
445
+ assert isinstance(r, dict), 'proxies格式错误,请注意proxies选项的格式,不要遗漏括号。'
446
+ return r
447
+
448
+
449
  def get_conf(*args):
450
  # 建议您复制一个config_private.py放自己的秘密, 如API和代理网址, 避免不小心传github被别人看到
451
  res = []
452
  for arg in args:
453
+ r = read_single_conf_with_lru_cache(arg)
 
454
  res.append(r)
 
 
 
 
455
  return res
456
 
457
+
458
  def clear_line_break(txt):
459
  txt = txt.replace('\n', ' ')
460
  txt = txt.replace(' ', ' ')
461
  txt = txt.replace(' ', ' ')
462
+ return txt
463
+
464
+
465
+ class DummyWith():
466
+ def __enter__(self):
467
+ return self
468
+
469
+ def __exit__(self, exc_type, exc_value, traceback):
470
+ return
version ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "version": 2.4,
3
+ "show_feature": true,
4
+ "new_feature": "(1)新增PDF全文翻译功能; (2)新增输入区切换位置的功能; (3)新增垂直布局选项; (4)多线程函数插件优化。"
5
+ }