stanley commited on
Commit
ab328e5
·
1 Parent(s): 3f7c7ec

pushin to huggin

Browse files
PyPatchMatch/.DS_Store ADDED
Binary file (6.15 kB). View file
 
PyPatchMatch/.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ /build/
2
+ /*.so
3
+ __pycache__
4
+ *.py[cod]
PyPatchMatch/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2020 Jiayuan Mao
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
PyPatchMatch/Makefile ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # Makefile
3
+ # Jiayuan Mao, 2019-01-09 13:59
4
+ #
5
+
6
+ SRC_DIR = csrc
7
+ INC_DIR = csrc
8
+ OBJ_DIR = build/obj
9
+ TARGET = libpatchmatch.so
10
+
11
+ LIB_TARGET = $(TARGET)
12
+ INCLUDE_DIR = -I $(SRC_DIR) -I $(INC_DIR)
13
+
14
+ CXX = $(ENVIRONMENT_OPTIONS) g++
15
+ CXXFLAGS = -std=c++14
16
+ CXXFLAGS += -Ofast -ffast-math -w
17
+ # CXXFLAGS += -g
18
+ CXXFLAGS += $(shell pkg-config --cflags opencv) -fPIC
19
+ CXXFLAGS += $(INCLUDE_DIR)
20
+ LDFLAGS = $(shell pkg-config --cflags --libs opencv) -shared -fPIC
21
+
22
+
23
+ CXXSOURCES = $(shell find $(SRC_DIR)/ -name "*.cpp")
24
+ OBJS = $(addprefix $(OBJ_DIR)/,$(CXXSOURCES:.cpp=.o))
25
+ DEPFILES = $(OBJS:.o=.d)
26
+
27
+ .PHONY: all clean rebuild test
28
+
29
+ all: $(LIB_TARGET)
30
+
31
+ $(OBJ_DIR)/%.o: %.cpp
32
+ @echo "[CC] $< ..."
33
+ @$(CXX) -c $< $(CXXFLAGS) -o $@
34
+
35
+ $(OBJ_DIR)/%.d: %.cpp
36
+ @mkdir -pv $(dir $@)
37
+ @echo "[dep] $< ..."
38
+ @$(CXX) $(INCLUDE_DIR) $(CXXFLAGS) -MM -MT "$(OBJ_DIR)/$(<:.cpp=.o) $(OBJ_DIR)/$(<:.cpp=.d)" "$<" > "$@"
39
+
40
+ sinclude $(DEPFILES)
41
+
42
+ $(LIB_TARGET): $(OBJS)
43
+ @echo "[link] $(LIB_TARGET) ..."
44
+ @$(CXX) $(OBJS) -o $@ $(CXXFLAGS) $(LDFLAGS)
45
+
46
+ clean:
47
+ rm -rf $(OBJ_DIR) $(LIB_TARGET)
48
+
49
+ rebuild:
50
+ +@make clean
51
+ +@make
52
+
53
+ # vim:ft=make
54
+ #
PyPatchMatch/README.md ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ PatchMatch based Inpainting
2
+ =====================================
3
+ This library implements the PatchMatch based inpainting algorithm. It provides both C++ and Python interfaces.
4
+ This implementation is heavily based on the implementation by Younesse ANDAM:
5
+ (younesse-cv/PatchMatch)[https://github.com/younesse-cv/PatchMatch], with some bugs fix.
6
+
7
+ Usage
8
+ -------------------------------------
9
+
10
+ You need to first install OpenCV to compile the C++ libraries. Then, run `make` to compile the
11
+ shared library `libpatchmatch.so`.
12
+
13
+ For Python users (example available at `examples/py_example.py`)
14
+
15
+ ```python
16
+ import patch_match
17
+
18
+ image = ... # either a numpy ndarray or a PIL Image object.
19
+ mask = ... # either a numpy ndarray or a PIL Image object.
20
+ result = patch_match.inpaint(image, mask, patch_size=5)
21
+ ```
22
+
23
+ For C++ users (examples available at `examples/cpp_example.cpp`)
24
+
25
+ ```cpp
26
+ #include "inpaint.h"
27
+
28
+ int main() {
29
+ cv::Mat image = ...
30
+ cv::Mat mask = ...
31
+
32
+ cv::Mat result = Inpainting(image, mask, 5).run();
33
+
34
+ return 0;
35
+ }
36
+ ```
37
+
38
+
39
+ README and COPYRIGHT by Younesse ANDAM
40
+ -------------------------------------
41
+ @Author: Younesse ANDAM
42
+
43
+ @Contact: [email protected]
44
+
45
+ Description: This project is a personal implementation of an algorithm called PATCHMATCH that restores missing areas in an image.
46
+ The algorithm is presented in the following paper
47
+ PatchMatch A Randomized Correspondence Algorithm
48
+ for Structural Image Editing
49
+ by C.Barnes,E.Shechtman,A.Finkelstein and Dan B.Goldman
50
+ ACM Transactions on Graphics (Proc. SIGGRAPH), vol.28, aug-2009
51
+
52
+ For more information please refer to
53
+ http://www.cs.princeton.edu/gfx/pubs/Barnes_2009_PAR/index.php
54
+
55
+ Copyright (c) 2010-2011
56
+
57
+
58
+ Requirements
59
+ -------------------------------------
60
+
61
+ To run the project you need to install Opencv library and link it to your project.
62
+ Opencv can be download it here
63
+ http://opencv.org/downloads.html
64
+
PyPatchMatch/csrc/inpaint.cpp ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <algorithm>
2
+ #include <iostream>
3
+ #include <opencv2/imgcodecs.hpp>
4
+ #include <opencv2/imgproc.hpp>
5
+ #include <opencv2/highgui.hpp>
6
+
7
+ #include "inpaint.h"
8
+
9
+ namespace {
10
+ static std::vector<double> kDistance2Similarity;
11
+
12
+ void init_kDistance2Similarity() {
13
+ double base[11] = {1.0, 0.99, 0.96, 0.83, 0.38, 0.11, 0.02, 0.005, 0.0006, 0.0001, 0};
14
+ int length = (PatchDistanceMetric::kDistanceScale + 1);
15
+ kDistance2Similarity.resize(length);
16
+ for (int i = 0; i < length; ++i) {
17
+ double t = (double) i / length;
18
+ int j = (int) (100 * t);
19
+ int k = j + 1;
20
+ double vj = (j < 11) ? base[j] : 0;
21
+ double vk = (k < 11) ? base[k] : 0;
22
+ kDistance2Similarity[i] = vj + (100 * t - j) * (vk - vj);
23
+ }
24
+ }
25
+
26
+
27
+ inline void _weighted_copy(const MaskedImage &source, int ys, int xs, cv::Mat &target, int yt, int xt, double weight) {
28
+ if (source.is_masked(ys, xs)) return;
29
+ if (source.is_globally_masked(ys, xs)) return;
30
+
31
+ auto source_ptr = source.get_image(ys, xs);
32
+ auto target_ptr = target.ptr<double>(yt, xt);
33
+
34
+ #pragma unroll
35
+ for (int c = 0; c < 3; ++c)
36
+ target_ptr[c] += static_cast<double>(source_ptr[c]) * weight;
37
+ target_ptr[3] += weight;
38
+ }
39
+ }
40
+
41
+ /**
42
+ * This algorithme uses a version proposed by Xavier Philippeau.
43
+ */
44
+
45
+ Inpainting::Inpainting(cv::Mat image, cv::Mat mask, const PatchDistanceMetric *metric)
46
+ : m_initial(image, mask), m_distance_metric(metric), m_pyramid(), m_source2target(), m_target2source() {
47
+ _initialize_pyramid();
48
+ }
49
+
50
+ Inpainting::Inpainting(cv::Mat image, cv::Mat mask, cv::Mat global_mask, const PatchDistanceMetric *metric)
51
+ : m_initial(image, mask, global_mask), m_distance_metric(metric), m_pyramid(), m_source2target(), m_target2source() {
52
+ _initialize_pyramid();
53
+ }
54
+
55
+ void Inpainting::_initialize_pyramid() {
56
+ auto source = m_initial;
57
+ m_pyramid.push_back(source);
58
+ while (source.size().height > m_distance_metric->patch_size() && source.size().width > m_distance_metric->patch_size()) {
59
+ source = source.downsample();
60
+ m_pyramid.push_back(source);
61
+ }
62
+
63
+ if (kDistance2Similarity.size() == 0) {
64
+ init_kDistance2Similarity();
65
+ }
66
+ }
67
+
68
+ cv::Mat Inpainting::run(bool verbose, bool verbose_visualize, unsigned int random_seed) {
69
+ srand(random_seed);
70
+ const int nr_levels = m_pyramid.size();
71
+
72
+ MaskedImage source, target;
73
+ for (int level = nr_levels - 1; level >= 0; --level) {
74
+ if (verbose) std::cerr << "Inpainting level: " << level << std::endl;
75
+
76
+ source = m_pyramid[level];
77
+
78
+ if (level == nr_levels - 1) {
79
+ target = source.clone();
80
+ target.clear_mask();
81
+ m_source2target = NearestNeighborField(source, target, m_distance_metric);
82
+ m_target2source = NearestNeighborField(target, source, m_distance_metric);
83
+ } else {
84
+ m_source2target = NearestNeighborField(source, target, m_distance_metric, m_source2target);
85
+ m_target2source = NearestNeighborField(target, source, m_distance_metric, m_target2source);
86
+ }
87
+
88
+ if (verbose) std::cerr << "Initialization done." << std::endl;
89
+
90
+ if (verbose_visualize) {
91
+ auto visualize_size = m_initial.size();
92
+ cv::Mat source_visualize(visualize_size, m_initial.image().type());
93
+ cv::resize(source.image(), source_visualize, visualize_size);
94
+ cv::imshow("Source", source_visualize);
95
+ cv::Mat target_visualize(visualize_size, m_initial.image().type());
96
+ cv::resize(target.image(), target_visualize, visualize_size);
97
+ cv::imshow("Target", target_visualize);
98
+ cv::waitKey(0);
99
+ }
100
+
101
+ target = _expectation_maximization(source, target, level, verbose);
102
+ }
103
+
104
+ return target.image();
105
+ }
106
+
107
+ // EM-Like algorithm (see "PatchMatch" - page 6).
108
+ // Returns a double sized target image (unless level = 0).
109
+ MaskedImage Inpainting::_expectation_maximization(MaskedImage source, MaskedImage target, int level, bool verbose) {
110
+ const int nr_iters_em = 1 + 2 * level;
111
+ const int nr_iters_nnf = static_cast<int>(std::min(7, 1 + level));
112
+ const int patch_size = m_distance_metric->patch_size();
113
+
114
+ MaskedImage new_source, new_target;
115
+
116
+ for (int iter_em = 0; iter_em < nr_iters_em; ++iter_em) {
117
+ if (iter_em != 0) {
118
+ m_source2target.set_target(new_target);
119
+ m_target2source.set_source(new_target);
120
+ target = new_target;
121
+ }
122
+
123
+ if (verbose) std::cerr << "EM Iteration: " << iter_em << std::endl;
124
+
125
+ auto size = source.size();
126
+ for (int i = 0; i < size.height; ++i) {
127
+ for (int j = 0; j < size.width; ++j) {
128
+ if (!source.contains_mask(i, j, patch_size)) {
129
+ m_source2target.set_identity(i, j);
130
+ m_target2source.set_identity(i, j);
131
+ }
132
+ }
133
+ }
134
+ if (verbose) std::cerr << " NNF minimization started." << std::endl;
135
+ m_source2target.minimize(nr_iters_nnf);
136
+ m_target2source.minimize(nr_iters_nnf);
137
+ if (verbose) std::cerr << " NNF minimization finished." << std::endl;
138
+
139
+ // Instead of upsizing the final target, we build the last target from the next level source image.
140
+ // Thus, the final target is less blurry (see "Space-Time Video Completion" - page 5).
141
+ bool upscaled = false;
142
+ if (level >= 1 && iter_em == nr_iters_em - 1) {
143
+ new_source = m_pyramid[level - 1];
144
+ new_target = target.upsample(new_source.size().width, new_source.size().height, m_pyramid[level - 1].global_mask());
145
+ upscaled = true;
146
+ } else {
147
+ new_source = m_pyramid[level];
148
+ new_target = target.clone();
149
+ }
150
+
151
+ auto vote = cv::Mat(new_target.size(), CV_64FC4);
152
+ vote.setTo(cv::Scalar::all(0));
153
+
154
+ // Votes for best patch from NNF Source->Target (completeness) and Target->Source (coherence).
155
+ _expectation_step(m_source2target, 1, vote, new_source, upscaled);
156
+ if (verbose) std::cerr << " Expectation source to target finished." << std::endl;
157
+ _expectation_step(m_target2source, 0, vote, new_source, upscaled);
158
+ if (verbose) std::cerr << " Expectation target to source finished." << std::endl;
159
+
160
+ // Compile votes and update pixel values.
161
+ _maximization_step(new_target, vote);
162
+ if (verbose) std::cerr << " Minimization step finished." << std::endl;
163
+ }
164
+
165
+ return new_target;
166
+ }
167
+
168
+ // Expectation step: vote for best estimations of each pixel.
169
+ void Inpainting::_expectation_step(
170
+ const NearestNeighborField &nnf, bool source2target,
171
+ cv::Mat &vote, const MaskedImage &source, bool upscaled
172
+ ) {
173
+ auto source_size = nnf.source_size();
174
+ auto target_size = nnf.target_size();
175
+ const int patch_size = m_distance_metric->patch_size();
176
+
177
+ for (int i = 0; i < source_size.height; ++i) {
178
+ for (int j = 0; j < source_size.width; ++j) {
179
+ if (nnf.source().is_globally_masked(i, j)) continue;
180
+ int yp = nnf.at(i, j, 0), xp = nnf.at(i, j, 1), dp = nnf.at(i, j, 2);
181
+ double w = kDistance2Similarity[dp];
182
+
183
+ for (int di = -patch_size; di <= patch_size; ++di) {
184
+ for (int dj = -patch_size; dj <= patch_size; ++dj) {
185
+ int ys = i + di, xs = j + dj, yt = yp + di, xt = xp + dj;
186
+ if (!(ys >= 0 && ys < source_size.height && xs >= 0 && xs < source_size.width)) continue;
187
+ if (nnf.source().is_globally_masked(ys, xs)) continue;
188
+ if (!(yt >= 0 && yt < target_size.height && xt >= 0 && xt < target_size.width)) continue;
189
+ if (nnf.target().is_globally_masked(yt, xt)) continue;
190
+
191
+ if (!source2target) {
192
+ std::swap(ys, yt);
193
+ std::swap(xs, xt);
194
+ }
195
+
196
+ if (upscaled) {
197
+ for (int uy = 0; uy < 2; ++uy) {
198
+ for (int ux = 0; ux < 2; ++ux) {
199
+ _weighted_copy(source, 2 * ys + uy, 2 * xs + ux, vote, 2 * yt + uy, 2 * xt + ux, w);
200
+ }
201
+ }
202
+ } else {
203
+ _weighted_copy(source, ys, xs, vote, yt, xt, w);
204
+ }
205
+ }
206
+ }
207
+ }
208
+ }
209
+ }
210
+
211
+ // Maximization Step: maximum likelihood of target pixel.
212
+ void Inpainting::_maximization_step(MaskedImage &target, const cv::Mat &vote) {
213
+ auto target_size = target.size();
214
+ for (int i = 0; i < target_size.height; ++i) {
215
+ for (int j = 0; j < target_size.width; ++j) {
216
+ const double *source_ptr = vote.ptr<double>(i, j);
217
+ unsigned char *target_ptr = target.get_mutable_image(i, j);
218
+
219
+ if (target.is_globally_masked(i, j)) {
220
+ continue;
221
+ }
222
+
223
+ if (source_ptr[3] > 0) {
224
+ unsigned char r = cv::saturate_cast<unsigned char>(source_ptr[0] / source_ptr[3]);
225
+ unsigned char g = cv::saturate_cast<unsigned char>(source_ptr[1] / source_ptr[3]);
226
+ unsigned char b = cv::saturate_cast<unsigned char>(source_ptr[2] / source_ptr[3]);
227
+ target_ptr[0] = r, target_ptr[1] = g, target_ptr[2] = b;
228
+ } else {
229
+ target.set_mask(i, j, 0);
230
+ }
231
+ }
232
+ }
233
+ }
234
+
PyPatchMatch/csrc/inpaint.h ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <vector>
4
+
5
+ #include "masked_image.h"
6
+ #include "nnf.h"
7
+
8
+ class Inpainting {
9
+ public:
10
+ Inpainting(cv::Mat image, cv::Mat mask, const PatchDistanceMetric *metric);
11
+ Inpainting(cv::Mat image, cv::Mat mask, cv::Mat global_mask, const PatchDistanceMetric *metric);
12
+ cv::Mat run(bool verbose = false, bool verbose_visualize = false, unsigned int random_seed = 1212);
13
+
14
+ private:
15
+ void _initialize_pyramid(void);
16
+ MaskedImage _expectation_maximization(MaskedImage source, MaskedImage target, int level, bool verbose);
17
+ void _expectation_step(const NearestNeighborField &nnf, bool source2target, cv::Mat &vote, const MaskedImage &source, bool upscaled);
18
+ void _maximization_step(MaskedImage &target, const cv::Mat &vote);
19
+
20
+ MaskedImage m_initial;
21
+ std::vector<MaskedImage> m_pyramid;
22
+
23
+ NearestNeighborField m_source2target;
24
+ NearestNeighborField m_target2source;
25
+ const PatchDistanceMetric *m_distance_metric;
26
+ };
27
+
PyPatchMatch/csrc/masked_image.cpp ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include "masked_image.h"
2
+ #include <algorithm>
3
+ #include <iostream>
4
+
5
+ const cv::Size MaskedImage::kDownsampleKernelSize = cv::Size(6, 6);
6
+ const int MaskedImage::kDownsampleKernel[6] = {1, 5, 10, 10, 5, 1};
7
+
8
+ bool MaskedImage::contains_mask(int y, int x, int patch_size) const {
9
+ auto mask_size = size();
10
+ for (int dy = -patch_size; dy <= patch_size; ++dy) {
11
+ for (int dx = -patch_size; dx <= patch_size; ++dx) {
12
+ int yy = y + dy, xx = x + dx;
13
+ if (yy >= 0 && yy < mask_size.height && xx >= 0 && xx < mask_size.width) {
14
+ if (is_masked(yy, xx) && !is_globally_masked(yy, xx)) return true;
15
+ }
16
+ }
17
+ }
18
+ return false;
19
+ }
20
+
21
+ MaskedImage MaskedImage::downsample() const {
22
+ const auto &kernel_size = MaskedImage::kDownsampleKernelSize;
23
+ const auto &kernel = MaskedImage::kDownsampleKernel;
24
+
25
+ const auto size = this->size();
26
+ const auto new_size = cv::Size(size.width / 2, size.height / 2);
27
+
28
+ auto ret = MaskedImage(new_size.width, new_size.height);
29
+ if (!m_global_mask.empty()) ret.init_global_mask_mat();
30
+ for (int y = 0; y < size.height - 1; y += 2) {
31
+ for (int x = 0; x < size.width - 1; x += 2) {
32
+ int r = 0, g = 0, b = 0, ksum = 0;
33
+ bool is_gmasked = true;
34
+
35
+ for (int dy = -kernel_size.height / 2 + 1; dy <= kernel_size.height / 2; ++dy) {
36
+ for (int dx = -kernel_size.width / 2 + 1; dx <= kernel_size.width / 2; ++dx) {
37
+ int yy = y + dy, xx = x + dx;
38
+ if (yy >= 0 && yy < size.height && xx >= 0 && xx < size.width) {
39
+ if (!is_globally_masked(yy, xx)) {
40
+ is_gmasked = false;
41
+ }
42
+ if (!is_masked(yy, xx)) {
43
+ auto source_ptr = get_image(yy, xx);
44
+ int k = kernel[kernel_size.height / 2 - 1 + dy] * kernel[kernel_size.width / 2 - 1 + dx];
45
+ r += source_ptr[0] * k, g += source_ptr[1] * k, b += source_ptr[2] * k;
46
+ ksum += k;
47
+ }
48
+ }
49
+ }
50
+ }
51
+
52
+ if (ksum > 0) r /= ksum, g /= ksum, b /= ksum;
53
+
54
+ if (!m_global_mask.empty()) {
55
+ ret.set_global_mask(y / 2, x / 2, is_gmasked);
56
+ }
57
+ if (ksum > 0) {
58
+ auto target_ptr = ret.get_mutable_image(y / 2, x / 2);
59
+ target_ptr[0] = r, target_ptr[1] = g, target_ptr[2] = b;
60
+ ret.set_mask(y / 2, x / 2, 0);
61
+ } else {
62
+ ret.set_mask(y / 2, x / 2, 1);
63
+ }
64
+ }
65
+ }
66
+
67
+ return ret;
68
+ }
69
+
70
+ MaskedImage MaskedImage::upsample(int new_w, int new_h) const {
71
+ const auto size = this->size();
72
+ auto ret = MaskedImage(new_w, new_h);
73
+ if (!m_global_mask.empty()) ret.init_global_mask_mat();
74
+ for (int y = 0; y < new_h; ++y) {
75
+ for (int x = 0; x < new_w; ++x) {
76
+ int yy = y * size.height / new_h;
77
+ int xx = x * size.width / new_w;
78
+
79
+ if (is_globally_masked(yy, xx)) {
80
+ ret.set_global_mask(y, x, 1);
81
+ ret.set_mask(y, x, 1);
82
+ } else {
83
+ if (!m_global_mask.empty()) ret.set_global_mask(y, x, 0);
84
+
85
+ if (is_masked(yy, xx)) {
86
+ ret.set_mask(y, x, 1);
87
+ } else {
88
+ auto source_ptr = get_image(yy, xx);
89
+ auto target_ptr = ret.get_mutable_image(y, x);
90
+ for (int c = 0; c < 3; ++c)
91
+ target_ptr[c] = source_ptr[c];
92
+ ret.set_mask(y, x, 0);
93
+ }
94
+ }
95
+ }
96
+ }
97
+
98
+ return ret;
99
+ }
100
+
101
+ MaskedImage MaskedImage::upsample(int new_w, int new_h, const cv::Mat &new_global_mask) const {
102
+ auto ret = upsample(new_w, new_h);
103
+ ret.set_global_mask_mat(new_global_mask);
104
+ return ret;
105
+ }
106
+
107
+ void MaskedImage::compute_image_gradients() {
108
+ if (m_image_grad_computed) {
109
+ return;
110
+ }
111
+
112
+ const auto size = m_image.size();
113
+ m_image_grady = cv::Mat(size, CV_8UC3);
114
+ m_image_gradx = cv::Mat(size, CV_8UC3);
115
+ m_image_grady = cv::Scalar::all(0);
116
+ m_image_gradx = cv::Scalar::all(0);
117
+
118
+ for (int i = 1; i < size.height - 1; ++i) {
119
+ const auto *ptr = m_image.ptr<unsigned char>(i, 0);
120
+ const auto *ptry1 = m_image.ptr<unsigned char>(i + 1, 0);
121
+ const auto *ptry2 = m_image.ptr<unsigned char>(i - 1, 0);
122
+ const auto *ptrx1 = m_image.ptr<unsigned char>(i, 0) + 3;
123
+ const auto *ptrx2 = m_image.ptr<unsigned char>(i, 0) - 3;
124
+ auto *mptry = m_image_grady.ptr<unsigned char>(i, 0);
125
+ auto *mptrx = m_image_gradx.ptr<unsigned char>(i, 0);
126
+ for (int j = 3; j < size.width * 3 - 3; ++j) {
127
+ mptry[j] = (ptry1[j] / 2 - ptry2[j] / 2) + 128;
128
+ mptrx[j] = (ptrx1[j] / 2 - ptrx2[j] / 2) + 128;
129
+ }
130
+ }
131
+
132
+ m_image_grad_computed = true;
133
+ }
134
+
135
+ void MaskedImage::compute_image_gradients() const {
136
+ const_cast<MaskedImage *>(this)->compute_image_gradients();
137
+ }
138
+
PyPatchMatch/csrc/masked_image.h ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <opencv2/core.hpp>
4
+
5
+ class MaskedImage {
6
+ public:
7
+ MaskedImage() : m_image(), m_mask(), m_global_mask(), m_image_grady(), m_image_gradx(), m_image_grad_computed(false) {
8
+ // pass
9
+ }
10
+ MaskedImage(cv::Mat image, cv::Mat mask) : m_image(image), m_mask(mask), m_image_grad_computed(false) {
11
+ // pass
12
+ }
13
+ MaskedImage(cv::Mat image, cv::Mat mask, cv::Mat global_mask) : m_image(image), m_mask(mask), m_global_mask(global_mask), m_image_grad_computed(false) {
14
+ // pass
15
+ }
16
+ MaskedImage(cv::Mat image, cv::Mat mask, cv::Mat global_mask, cv::Mat grady, cv::Mat gradx, bool grad_computed) :
17
+ m_image(image), m_mask(mask), m_global_mask(global_mask),
18
+ m_image_grady(grady), m_image_gradx(gradx), m_image_grad_computed(grad_computed) {
19
+ // pass
20
+ }
21
+ MaskedImage(int width, int height) : m_global_mask(), m_image_grady(), m_image_gradx() {
22
+ m_image = cv::Mat(cv::Size(width, height), CV_8UC3);
23
+ m_image = cv::Scalar::all(0);
24
+
25
+ m_mask = cv::Mat(cv::Size(width, height), CV_8U);
26
+ m_mask = cv::Scalar::all(0);
27
+ }
28
+ inline MaskedImage clone() {
29
+ return MaskedImage(
30
+ m_image.clone(), m_mask.clone(), m_global_mask.clone(),
31
+ m_image_grady.clone(), m_image_gradx.clone(), m_image_grad_computed
32
+ );
33
+ }
34
+
35
+ inline cv::Size size() const {
36
+ return m_image.size();
37
+ }
38
+ inline const cv::Mat &image() const {
39
+ return m_image;
40
+ }
41
+ inline const cv::Mat &mask() const {
42
+ return m_mask;
43
+ }
44
+ inline const cv::Mat &global_mask() const {
45
+ return m_global_mask;
46
+ }
47
+ inline const cv::Mat &grady() const {
48
+ assert(m_image_grad_computed);
49
+ return m_image_grady;
50
+ }
51
+ inline const cv::Mat &gradx() const {
52
+ assert(m_image_grad_computed);
53
+ return m_image_gradx;
54
+ }
55
+
56
+ inline void init_global_mask_mat() {
57
+ m_global_mask = cv::Mat(m_mask.size(), CV_8U);
58
+ m_global_mask.setTo(cv::Scalar(0));
59
+ }
60
+ inline void set_global_mask_mat(const cv::Mat &other) {
61
+ m_global_mask = other;
62
+ }
63
+
64
+ inline bool is_masked(int y, int x) const {
65
+ return static_cast<bool>(m_mask.at<unsigned char>(y, x));
66
+ }
67
+ inline bool is_globally_masked(int y, int x) const {
68
+ return !m_global_mask.empty() && static_cast<bool>(m_global_mask.at<unsigned char>(y, x));
69
+ }
70
+ inline void set_mask(int y, int x, bool value) {
71
+ m_mask.at<unsigned char>(y, x) = static_cast<unsigned char>(value);
72
+ }
73
+ inline void set_global_mask(int y, int x, bool value) {
74
+ m_global_mask.at<unsigned char>(y, x) = static_cast<unsigned char>(value);
75
+ }
76
+ inline void clear_mask() {
77
+ m_mask.setTo(cv::Scalar(0));
78
+ }
79
+
80
+ inline const unsigned char *get_image(int y, int x) const {
81
+ return m_image.ptr<unsigned char>(y, x);
82
+ }
83
+ inline unsigned char *get_mutable_image(int y, int x) {
84
+ return m_image.ptr<unsigned char>(y, x);
85
+ }
86
+
87
+ inline unsigned char get_image(int y, int x, int c) const {
88
+ return m_image.ptr<unsigned char>(y, x)[c];
89
+ }
90
+ inline int get_image_int(int y, int x, int c) const {
91
+ return static_cast<int>(m_image.ptr<unsigned char>(y, x)[c]);
92
+ }
93
+
94
+ bool contains_mask(int y, int x, int patch_size) const;
95
+ MaskedImage downsample() const;
96
+ MaskedImage upsample(int new_w, int new_h) const;
97
+ MaskedImage upsample(int new_w, int new_h, const cv::Mat &new_global_mask) const;
98
+ void compute_image_gradients();
99
+ void compute_image_gradients() const;
100
+
101
+ static const cv::Size kDownsampleKernelSize;
102
+ static const int kDownsampleKernel[6];
103
+
104
+ private:
105
+ cv::Mat m_image;
106
+ cv::Mat m_mask;
107
+ cv::Mat m_global_mask;
108
+ cv::Mat m_image_grady;
109
+ cv::Mat m_image_gradx;
110
+ bool m_image_grad_computed = false;
111
+ };
112
+
PyPatchMatch/csrc/nnf.cpp ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <algorithm>
2
+ #include <iostream>
3
+ #include <cmath>
4
+
5
+ #include "masked_image.h"
6
+ #include "nnf.h"
7
+
8
+ /**
9
+ * Nearest-Neighbor Field (see PatchMatch algorithm).
10
+ * This algorithme uses a version proposed by Xavier Philippeau.
11
+ *
12
+ */
13
+
14
+ template <typename T>
15
+ T clamp(T value, T min_value, T max_value) {
16
+ return std::min(std::max(value, min_value), max_value);
17
+ }
18
+
19
+ void NearestNeighborField::_randomize_field(int max_retry, bool reset) {
20
+ auto this_size = source_size();
21
+ for (int i = 0; i < this_size.height; ++i) {
22
+ for (int j = 0; j < this_size.width; ++j) {
23
+ if (m_source.is_globally_masked(i, j)) continue;
24
+
25
+ auto this_ptr = mutable_ptr(i, j);
26
+ int distance = reset ? PatchDistanceMetric::kDistanceScale : this_ptr[2];
27
+ if (distance < PatchDistanceMetric::kDistanceScale) {
28
+ continue;
29
+ }
30
+
31
+ int i_target = 0, j_target = 0;
32
+ for (int t = 0; t < max_retry; ++t) {
33
+ i_target = rand() % this_size.height;
34
+ j_target = rand() % this_size.width;
35
+ if (m_target.is_globally_masked(i_target, j_target)) continue;
36
+
37
+ distance = _distance(i, j, i_target, j_target);
38
+ if (distance < PatchDistanceMetric::kDistanceScale)
39
+ break;
40
+ }
41
+
42
+ this_ptr[0] = i_target, this_ptr[1] = j_target, this_ptr[2] = distance;
43
+ }
44
+ }
45
+ }
46
+
47
+ void NearestNeighborField::_initialize_field_from(const NearestNeighborField &other, int max_retry) {
48
+ const auto &this_size = source_size();
49
+ const auto &other_size = other.source_size();
50
+ double fi = static_cast<double>(this_size.height) / other_size.height;
51
+ double fj = static_cast<double>(this_size.width) / other_size.width;
52
+
53
+ for (int i = 0; i < this_size.height; ++i) {
54
+ for (int j = 0; j < this_size.width; ++j) {
55
+ if (m_source.is_globally_masked(i, j)) continue;
56
+
57
+ int ilow = static_cast<int>(std::min(i / fi, static_cast<double>(other_size.height - 1)));
58
+ int jlow = static_cast<int>(std::min(j / fj, static_cast<double>(other_size.width - 1)));
59
+ auto this_value = mutable_ptr(i, j);
60
+ auto other_value = other.ptr(ilow, jlow);
61
+
62
+ this_value[0] = static_cast<int>(other_value[0] * fi);
63
+ this_value[1] = static_cast<int>(other_value[1] * fj);
64
+ this_value[2] = _distance(i, j, this_value[0], this_value[1]);
65
+ }
66
+ }
67
+
68
+ _randomize_field(max_retry, false);
69
+ }
70
+
71
+ void NearestNeighborField::minimize(int nr_pass) {
72
+ const auto &this_size = source_size();
73
+ while (nr_pass--) {
74
+ for (int i = 0; i < this_size.height; ++i)
75
+ for (int j = 0; j < this_size.width; ++j) {
76
+ if (m_source.is_globally_masked(i, j)) continue;
77
+ if (at(i, j, 2) > 0) _minimize_link(i, j, +1);
78
+ }
79
+ for (int i = this_size.height - 1; i >= 0; --i)
80
+ for (int j = this_size.width - 1; j >= 0; --j) {
81
+ if (m_source.is_globally_masked(i, j)) continue;
82
+ if (at(i, j, 2) > 0) _minimize_link(i, j, -1);
83
+ }
84
+ }
85
+ }
86
+
87
+ void NearestNeighborField::_minimize_link(int y, int x, int direction) {
88
+ const auto &this_size = source_size();
89
+ const auto &this_target_size = target_size();
90
+ auto this_ptr = mutable_ptr(y, x);
91
+
92
+ // propagation along the y direction.
93
+ if (y - direction >= 0 && y - direction < this_size.height && !m_source.is_globally_masked(y - direction, x)) {
94
+ int yp = at(y - direction, x, 0) + direction;
95
+ int xp = at(y - direction, x, 1);
96
+ int dp = _distance(y, x, yp, xp);
97
+ if (dp < at(y, x, 2)) {
98
+ this_ptr[0] = yp, this_ptr[1] = xp, this_ptr[2] = dp;
99
+ }
100
+ }
101
+
102
+ // propagation along the x direction.
103
+ if (x - direction >= 0 && x - direction < this_size.width && !m_source.is_globally_masked(y, x - direction)) {
104
+ int yp = at(y, x - direction, 0);
105
+ int xp = at(y, x - direction, 1) + direction;
106
+ int dp = _distance(y, x, yp, xp);
107
+ if (dp < at(y, x, 2)) {
108
+ this_ptr[0] = yp, this_ptr[1] = xp, this_ptr[2] = dp;
109
+ }
110
+ }
111
+
112
+ // random search with a progressive step size.
113
+ int random_scale = (std::min(this_target_size.height, this_target_size.width) - 1) / 2;
114
+ while (random_scale > 0) {
115
+ int yp = this_ptr[0] + (rand() % (2 * random_scale + 1) - random_scale);
116
+ int xp = this_ptr[1] + (rand() % (2 * random_scale + 1) - random_scale);
117
+ yp = clamp(yp, 0, target_size().height - 1);
118
+ xp = clamp(xp, 0, target_size().width - 1);
119
+
120
+ if (m_target.is_globally_masked(yp, xp)) {
121
+ random_scale /= 2;
122
+ }
123
+
124
+ int dp = _distance(y, x, yp, xp);
125
+ if (dp < at(y, x, 2)) {
126
+ this_ptr[0] = yp, this_ptr[1] = xp, this_ptr[2] = dp;
127
+ }
128
+ random_scale /= 2;
129
+ }
130
+ }
131
+
132
+ const int PatchDistanceMetric::kDistanceScale = 65535;
133
+ const int PatchSSDDistanceMetric::kSSDScale = 9 * 255 * 255;
134
+
135
+ namespace {
136
+
137
+ inline int pow2(int i) {
138
+ return i * i;
139
+ }
140
+
141
+ int distance_masked_images(
142
+ const MaskedImage &source, int ys, int xs,
143
+ const MaskedImage &target, int yt, int xt,
144
+ int patch_size
145
+ ) {
146
+ long double distance = 0;
147
+ long double wsum = 0;
148
+
149
+ source.compute_image_gradients();
150
+ target.compute_image_gradients();
151
+
152
+ auto source_size = source.size();
153
+ auto target_size = target.size();
154
+
155
+ for (int dy = -patch_size; dy <= patch_size; ++dy) {
156
+ const int yys = ys + dy, yyt = yt + dy;
157
+
158
+ if (yys <= 0 || yys >= source_size.height - 1 || yyt <= 0 || yyt >= target_size.height - 1) {
159
+ distance += (long double)(PatchSSDDistanceMetric::kSSDScale) * (2 * patch_size + 1);
160
+ wsum += 2 * patch_size + 1;
161
+ continue;
162
+ }
163
+
164
+ const auto *p_si = source.image().ptr<unsigned char>(yys, 0);
165
+ const auto *p_ti = target.image().ptr<unsigned char>(yyt, 0);
166
+ const auto *p_sm = source.mask().ptr<unsigned char>(yys, 0);
167
+ const auto *p_tm = target.mask().ptr<unsigned char>(yyt, 0);
168
+
169
+ const unsigned char *p_sgm = nullptr;
170
+ const unsigned char *p_tgm = nullptr;
171
+ if (!source.global_mask().empty()) {
172
+ p_sgm = source.global_mask().ptr<unsigned char>(yys, 0);
173
+ p_tgm = target.global_mask().ptr<unsigned char>(yyt, 0);
174
+ }
175
+
176
+ const auto *p_sgy = source.grady().ptr<unsigned char>(yys, 0);
177
+ const auto *p_tgy = target.grady().ptr<unsigned char>(yyt, 0);
178
+ const auto *p_sgx = source.gradx().ptr<unsigned char>(yys, 0);
179
+ const auto *p_tgx = target.gradx().ptr<unsigned char>(yyt, 0);
180
+
181
+ for (int dx = -patch_size; dx <= patch_size; ++dx) {
182
+ int xxs = xs + dx, xxt = xt + dx;
183
+ wsum += 1;
184
+
185
+ if (xxs <= 0 || xxs >= source_size.width - 1 || xxt <= 0 || xxt >= source_size.width - 1) {
186
+ distance += PatchSSDDistanceMetric::kSSDScale;
187
+ continue;
188
+ }
189
+
190
+ if (p_sm[xxs] || p_tm[xxt] || (p_sgm && p_sgm[xxs]) || (p_tgm && p_tgm[xxt]) ) {
191
+ distance += PatchSSDDistanceMetric::kSSDScale;
192
+ continue;
193
+ }
194
+
195
+ int ssd = 0;
196
+ for (int c = 0; c < 3; ++c) {
197
+ int s_value = p_si[xxs * 3 + c];
198
+ int t_value = p_ti[xxt * 3 + c];
199
+ int s_gy = p_sgy[xxs * 3 + c];
200
+ int t_gy = p_tgy[xxt * 3 + c];
201
+ int s_gx = p_sgx[xxs * 3 + c];
202
+ int t_gx = p_tgx[xxt * 3 + c];
203
+
204
+ ssd += pow2(static_cast<int>(s_value) - t_value);
205
+ ssd += pow2(static_cast<int>(s_gx) - t_gx);
206
+ ssd += pow2(static_cast<int>(s_gy) - t_gy);
207
+ }
208
+ distance += ssd;
209
+ }
210
+ }
211
+
212
+ distance /= (long double)(PatchSSDDistanceMetric::kSSDScale);
213
+
214
+ int res = int(PatchDistanceMetric::kDistanceScale * distance / wsum);
215
+ if (res < 0 || res > PatchDistanceMetric::kDistanceScale) return PatchDistanceMetric::kDistanceScale;
216
+ return res;
217
+ }
218
+
219
+ }
220
+
221
+ int PatchSSDDistanceMetric::operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const {
222
+ return distance_masked_images(source, source_y, source_x, target, target_y, target_x, m_patch_size);
223
+ }
224
+
225
+ int DebugPatchSSDDistanceMetric::operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const {
226
+ fprintf(stderr, "DebugPatchSSDDistanceMetric: %d %d %d %d\n", source.size().width, source.size().height, m_width, m_height);
227
+ return distance_masked_images(source, source_y, source_x, target, target_y, target_x, m_patch_size);
228
+ }
229
+
230
+ int RegularityGuidedPatchDistanceMetricV1::operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const {
231
+ double dx = remainder(double(source_x - target_x) / source.size().width, m_dx1);
232
+ double dy = remainder(double(source_y - target_y) / source.size().height, m_dy2);
233
+
234
+ double score1 = sqrt(dx * dx + dy *dy) / m_scale;
235
+ if (score1 < 0 || score1 > 1) score1 = 1;
236
+ score1 *= PatchDistanceMetric::kDistanceScale;
237
+
238
+ double score2 = distance_masked_images(source, source_y, source_x, target, target_y, target_x, m_patch_size);
239
+ double score = score1 * m_weight + score2 / (1 + m_weight);
240
+ return static_cast<int>(score / (1 + m_weight));
241
+ }
242
+
243
+ int RegularityGuidedPatchDistanceMetricV2::operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const {
244
+ if (target_y < 0 || target_y >= target.size().height || target_x < 0 || target_x >= target.size().width)
245
+ return PatchDistanceMetric::kDistanceScale;
246
+
247
+ int source_scale = m_ijmap.size().height / source.size().height;
248
+ int target_scale = m_ijmap.size().height / target.size().height;
249
+
250
+ // fprintf(stderr, "RegularityGuidedPatchDistanceMetricV2 %d %d %d %d\n", source_y * source_scale, m_ijmap.size().height, source_x * source_scale, m_ijmap.size().width);
251
+
252
+ double score1 = PatchDistanceMetric::kDistanceScale;
253
+ if (!source.is_globally_masked(source_y, source_x) && !target.is_globally_masked(target_y, target_x)) {
254
+ auto source_ij = m_ijmap.ptr<float>(source_y * source_scale, source_x * source_scale);
255
+ auto target_ij = m_ijmap.ptr<float>(target_y * target_scale, target_x * target_scale);
256
+
257
+ float di = fabs(source_ij[0] - target_ij[0]); if (di > 0.5) di = 1 - di;
258
+ float dj = fabs(source_ij[1] - target_ij[1]); if (dj > 0.5) dj = 1 - dj;
259
+ score1 = sqrt(di * di + dj *dj) / 0.707;
260
+ if (score1 < 0 || score1 > 1) score1 = 1;
261
+ score1 *= PatchDistanceMetric::kDistanceScale;
262
+ }
263
+
264
+ double score2 = distance_masked_images(source, source_y, source_x, target, target_y, target_x, m_patch_size);
265
+ double score = score1 * m_weight + score2;
266
+ return int(score / (1 + m_weight));
267
+ }
268
+
PyPatchMatch/csrc/nnf.h ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <opencv2/core.hpp>
4
+ #include "masked_image.h"
5
+
6
+ class PatchDistanceMetric {
7
+ public:
8
+ PatchDistanceMetric(int patch_size) : m_patch_size(patch_size) {}
9
+ virtual ~PatchDistanceMetric() = default;
10
+
11
+ inline int patch_size() const { return m_patch_size; }
12
+ virtual int operator()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const = 0;
13
+ static const int kDistanceScale;
14
+
15
+ protected:
16
+ int m_patch_size;
17
+ };
18
+
19
+ class NearestNeighborField {
20
+ public:
21
+ NearestNeighborField() : m_source(), m_target(), m_field(), m_distance_metric(nullptr) {
22
+ // pass
23
+ }
24
+ NearestNeighborField(const MaskedImage &source, const MaskedImage &target, const PatchDistanceMetric *metric, int max_retry = 20)
25
+ : m_source(source), m_target(target), m_distance_metric(metric) {
26
+ m_field = cv::Mat(m_source.size(), CV_32SC3);
27
+ _randomize_field(max_retry);
28
+ }
29
+ NearestNeighborField(const MaskedImage &source, const MaskedImage &target, const PatchDistanceMetric *metric, const NearestNeighborField &other, int max_retry = 20)
30
+ : m_source(source), m_target(target), m_distance_metric(metric) {
31
+ m_field = cv::Mat(m_source.size(), CV_32SC3);
32
+ _initialize_field_from(other, max_retry);
33
+ }
34
+
35
+ const MaskedImage &source() const {
36
+ return m_source;
37
+ }
38
+ const MaskedImage &target() const {
39
+ return m_target;
40
+ }
41
+ inline cv::Size source_size() const {
42
+ return m_source.size();
43
+ }
44
+ inline cv::Size target_size() const {
45
+ return m_target.size();
46
+ }
47
+ inline void set_source(const MaskedImage &source) {
48
+ m_source = source;
49
+ }
50
+ inline void set_target(const MaskedImage &target) {
51
+ m_target = target;
52
+ }
53
+
54
+ inline int *mutable_ptr(int y, int x) {
55
+ return m_field.ptr<int>(y, x);
56
+ }
57
+ inline const int *ptr(int y, int x) const {
58
+ return m_field.ptr<int>(y, x);
59
+ }
60
+
61
+ inline int at(int y, int x, int c) const {
62
+ return m_field.ptr<int>(y, x)[c];
63
+ }
64
+ inline int &at(int y, int x, int c) {
65
+ return m_field.ptr<int>(y, x)[c];
66
+ }
67
+ inline void set_identity(int y, int x) {
68
+ auto ptr = mutable_ptr(y, x);
69
+ ptr[0] = y, ptr[1] = x, ptr[2] = 0;
70
+ }
71
+
72
+ void minimize(int nr_pass);
73
+
74
+ private:
75
+ inline int _distance(int source_y, int source_x, int target_y, int target_x) {
76
+ return (*m_distance_metric)(m_source, source_y, source_x, m_target, target_y, target_x);
77
+ }
78
+
79
+ void _randomize_field(int max_retry = 20, bool reset = true);
80
+ void _initialize_field_from(const NearestNeighborField &other, int max_retry);
81
+ void _minimize_link(int y, int x, int direction);
82
+
83
+ MaskedImage m_source;
84
+ MaskedImage m_target;
85
+ cv::Mat m_field; // { y_target, x_target, distance_scaled }
86
+ const PatchDistanceMetric *m_distance_metric;
87
+ };
88
+
89
+
90
+ class PatchSSDDistanceMetric : public PatchDistanceMetric {
91
+ public:
92
+ using PatchDistanceMetric::PatchDistanceMetric;
93
+ virtual int operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const;
94
+ static const int kSSDScale;
95
+ };
96
+
97
+ class DebugPatchSSDDistanceMetric : public PatchDistanceMetric {
98
+ public:
99
+ DebugPatchSSDDistanceMetric(int patch_size, int width, int height) : PatchDistanceMetric(patch_size), m_width(width), m_height(height) {}
100
+ virtual int operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const;
101
+ protected:
102
+ int m_width, m_height;
103
+ };
104
+
105
+ class RegularityGuidedPatchDistanceMetricV1 : public PatchDistanceMetric {
106
+ public:
107
+ RegularityGuidedPatchDistanceMetricV1(int patch_size, double dx1, double dy1, double dx2, double dy2, double weight)
108
+ : PatchDistanceMetric(patch_size), m_dx1(dx1), m_dy1(dy1), m_dx2(dx2), m_dy2(dy2), m_weight(weight) {
109
+
110
+ assert(m_dy1 == 0);
111
+ assert(m_dx2 == 0);
112
+ m_scale = sqrt(m_dx1 * m_dx1 + m_dy2 * m_dy2) / 4;
113
+ }
114
+ virtual int operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const;
115
+
116
+ protected:
117
+ double m_dx1, m_dy1, m_dx2, m_dy2;
118
+ double m_scale, m_weight;
119
+ };
120
+
121
+ class RegularityGuidedPatchDistanceMetricV2 : public PatchDistanceMetric {
122
+ public:
123
+ RegularityGuidedPatchDistanceMetricV2(int patch_size, cv::Mat ijmap, double weight)
124
+ : PatchDistanceMetric(patch_size), m_ijmap(ijmap), m_weight(weight) {
125
+
126
+ }
127
+ virtual int operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const;
128
+
129
+ protected:
130
+ cv::Mat m_ijmap;
131
+ double m_width, m_height, m_weight;
132
+ };
133
+
PyPatchMatch/csrc/pyinterface.cpp ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include "pyinterface.h"
2
+ #include "inpaint.h"
3
+
4
+ static unsigned int PM_seed = 1212;
5
+ static bool PM_verbose = false;
6
+
7
+ int _dtype_py_to_cv(int dtype_py);
8
+ int _dtype_cv_to_py(int dtype_cv);
9
+ cv::Mat _py_to_cv2(PM_mat_t pymat);
10
+ PM_mat_t _cv2_to_py(cv::Mat cvmat);
11
+
12
+ void PM_set_random_seed(unsigned int seed) {
13
+ PM_seed = seed;
14
+ }
15
+
16
+ void PM_set_verbose(int value) {
17
+ PM_verbose = static_cast<bool>(value);
18
+ }
19
+
20
+ void PM_free_pymat(PM_mat_t pymat) {
21
+ free(pymat.data_ptr);
22
+ }
23
+
24
+ PM_mat_t PM_inpaint(PM_mat_t source_py, PM_mat_t mask_py, int patch_size) {
25
+ cv::Mat source = _py_to_cv2(source_py);
26
+ cv::Mat mask = _py_to_cv2(mask_py);
27
+ auto metric = PatchSSDDistanceMetric(patch_size);
28
+ cv::Mat result = Inpainting(source, mask, &metric).run(PM_verbose, false, PM_seed);
29
+ return _cv2_to_py(result);
30
+ }
31
+
32
+ PM_mat_t PM_inpaint_regularity(PM_mat_t source_py, PM_mat_t mask_py, PM_mat_t ijmap_py, int patch_size, float guide_weight) {
33
+ cv::Mat source = _py_to_cv2(source_py);
34
+ cv::Mat mask = _py_to_cv2(mask_py);
35
+ cv::Mat ijmap = _py_to_cv2(ijmap_py);
36
+
37
+ auto metric = RegularityGuidedPatchDistanceMetricV2(patch_size, ijmap, guide_weight);
38
+ cv::Mat result = Inpainting(source, mask, &metric).run(PM_verbose, false, PM_seed);
39
+ return _cv2_to_py(result);
40
+ }
41
+
42
+ PM_mat_t PM_inpaint2(PM_mat_t source_py, PM_mat_t mask_py, PM_mat_t global_mask_py, int patch_size) {
43
+ cv::Mat source = _py_to_cv2(source_py);
44
+ cv::Mat mask = _py_to_cv2(mask_py);
45
+ cv::Mat global_mask = _py_to_cv2(global_mask_py);
46
+
47
+ auto metric = PatchSSDDistanceMetric(patch_size);
48
+ cv::Mat result = Inpainting(source, mask, global_mask, &metric).run(PM_verbose, false, PM_seed);
49
+ return _cv2_to_py(result);
50
+ }
51
+
52
+ PM_mat_t PM_inpaint2_regularity(PM_mat_t source_py, PM_mat_t mask_py, PM_mat_t global_mask_py, PM_mat_t ijmap_py, int patch_size, float guide_weight) {
53
+ cv::Mat source = _py_to_cv2(source_py);
54
+ cv::Mat mask = _py_to_cv2(mask_py);
55
+ cv::Mat global_mask = _py_to_cv2(global_mask_py);
56
+ cv::Mat ijmap = _py_to_cv2(ijmap_py);
57
+
58
+ auto metric = RegularityGuidedPatchDistanceMetricV2(patch_size, ijmap, guide_weight);
59
+ cv::Mat result = Inpainting(source, mask, global_mask, &metric).run(PM_verbose, false, PM_seed);
60
+ return _cv2_to_py(result);
61
+ }
62
+
63
+ int _dtype_py_to_cv(int dtype_py) {
64
+ switch (dtype_py) {
65
+ case PM_UINT8: return CV_8U;
66
+ case PM_INT8: return CV_8S;
67
+ case PM_UINT16: return CV_16U;
68
+ case PM_INT16: return CV_16S;
69
+ case PM_INT32: return CV_32S;
70
+ case PM_FLOAT32: return CV_32F;
71
+ case PM_FLOAT64: return CV_64F;
72
+ }
73
+
74
+ return CV_8U;
75
+ }
76
+
77
+ int _dtype_cv_to_py(int dtype_cv) {
78
+ switch (dtype_cv) {
79
+ case CV_8U: return PM_UINT8;
80
+ case CV_8S: return PM_INT8;
81
+ case CV_16U: return PM_UINT16;
82
+ case CV_16S: return PM_INT16;
83
+ case CV_32S: return PM_INT32;
84
+ case CV_32F: return PM_FLOAT32;
85
+ case CV_64F: return PM_FLOAT64;
86
+ }
87
+
88
+ return PM_UINT8;
89
+ }
90
+
91
+ cv::Mat _py_to_cv2(PM_mat_t pymat) {
92
+ int dtype = _dtype_py_to_cv(pymat.dtype);
93
+ dtype = CV_MAKETYPE(pymat.dtype, pymat.shape.channels);
94
+ return cv::Mat(cv::Size(pymat.shape.width, pymat.shape.height), dtype, pymat.data_ptr).clone();
95
+ }
96
+
97
+ PM_mat_t _cv2_to_py(cv::Mat cvmat) {
98
+ PM_shape_t shape = {cvmat.size().width, cvmat.size().height, cvmat.channels()};
99
+ int dtype = _dtype_cv_to_py(cvmat.depth());
100
+ size_t dsize = cvmat.total() * cvmat.elemSize();
101
+
102
+ void *data_ptr = reinterpret_cast<void *>(malloc(dsize));
103
+ memcpy(data_ptr, reinterpret_cast<void *>(cvmat.data), dsize);
104
+
105
+ return PM_mat_t {data_ptr, shape, dtype};
106
+ }
107
+
PyPatchMatch/csrc/pyinterface.h ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <opencv2/core.hpp>
2
+ #include <cstdlib>
3
+ #include <cstdio>
4
+ #include <cstring>
5
+
6
+ extern "C" {
7
+
8
+ struct PM_shape_t {
9
+ int width, height, channels;
10
+ };
11
+
12
+ enum PM_dtype_e {
13
+ PM_UINT8,
14
+ PM_INT8,
15
+ PM_UINT16,
16
+ PM_INT16,
17
+ PM_INT32,
18
+ PM_FLOAT32,
19
+ PM_FLOAT64,
20
+ };
21
+
22
+ struct PM_mat_t {
23
+ void *data_ptr;
24
+ PM_shape_t shape;
25
+ int dtype;
26
+ };
27
+
28
+ void PM_set_random_seed(unsigned int seed);
29
+ void PM_set_verbose(int value);
30
+
31
+ void PM_free_pymat(PM_mat_t pymat);
32
+ PM_mat_t PM_inpaint(PM_mat_t image, PM_mat_t mask, int patch_size);
33
+ PM_mat_t PM_inpaint_regularity(PM_mat_t image, PM_mat_t mask, PM_mat_t ijmap, int patch_size, float guide_weight);
34
+ PM_mat_t PM_inpaint2(PM_mat_t image, PM_mat_t mask, PM_mat_t global_mask, int patch_size);
35
+ PM_mat_t PM_inpaint2_regularity(PM_mat_t image, PM_mat_t mask, PM_mat_t global_mask, PM_mat_t ijmap, int patch_size, float guide_weight);
36
+
37
+ } /* extern "C" */
38
+
PyPatchMatch/examples/.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ /cpp_example.exe
2
+ /images/*recovered.bmp
PyPatchMatch/examples/cpp_example.cpp ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <iostream>
2
+ #include <opencv2/imgcodecs.hpp>
3
+ #include <opencv2/highgui.hpp>
4
+
5
+ #include "masked_image.h"
6
+ #include "nnf.h"
7
+ #include "inpaint.h"
8
+
9
+ int main() {
10
+ auto source = cv::imread("./images/forest_pruned.bmp", cv::IMREAD_COLOR);
11
+
12
+ auto mask = cv::Mat(source.size(), CV_8UC1);
13
+ mask = cv::Scalar::all(0);
14
+ for (int i = 0; i < source.size().height; ++i) {
15
+ for (int j = 0; j < source.size().width; ++j) {
16
+ auto source_ptr = source.ptr<unsigned char>(i, j);
17
+ if (source_ptr[0] == 255 && source_ptr[1] == 255 && source_ptr[2] == 255) {
18
+ mask.at<unsigned char>(i, j) = 1;
19
+ }
20
+ }
21
+ }
22
+
23
+ auto metric = PatchSSDDistanceMetric(3);
24
+ auto result = Inpainting(source, mask, &metric).run(true, true);
25
+ // cv::imwrite("./images/forest_recovered.bmp", result);
26
+ // cv::imshow("Result", result);
27
+ // cv::waitKey();
28
+
29
+ return 0;
30
+ }
31
+
PyPatchMatch/examples/cpp_example_run.sh ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #! /bin/bash
2
+ #
3
+ # cpp_example_run.sh
4
+ # Copyright (C) 2020 Jiayuan Mao <[email protected]>
5
+ #
6
+ # Distributed under terms of the MIT license.
7
+ #
8
+
9
+ set -x
10
+
11
+ CFLAGS="-std=c++14 -O2 $(pkg-config --cflags opencv)"
12
+ LDFLAGS="$(pkg-config --libs opencv)"
13
+ g++ $CFLAGS cpp_example.cpp -I../csrc/ -L../ -lpatchmatch $LDFLAGS -o cpp_example.exe
14
+
15
+ export DYLD_LIBRARY_PATH=../:$DYLD_LIBRARY_PATH # For macOS
16
+ export LD_LIBRARY_PATH=../:$LD_LIBRARY_PATH # For Linux
17
+ time ./cpp_example.exe
18
+
PyPatchMatch/examples/images/forest.bmp ADDED
PyPatchMatch/examples/images/forest_pruned.bmp ADDED
PyPatchMatch/examples/py_example.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #! /usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # File : test.py
4
+ # Author : Jiayuan Mao
5
+ # Email : [email protected]
6
+ # Date : 01/09/2020
7
+ #
8
+ # Distributed under terms of the MIT license.
9
+
10
+ from PIL import Image
11
+
12
+ import sys
13
+ sys.path.insert(0, '../')
14
+ import patch_match
15
+
16
+
17
+ if __name__ == '__main__':
18
+ source = Image.open('./images/forest_pruned.bmp')
19
+ result = patch_match.inpaint(source, patch_size=3)
20
+ Image.fromarray(result).save('./images/forest_recovered.bmp')
21
+
PyPatchMatch/examples/py_example_global_mask.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #! /usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # File : test.py
4
+ # Author : Jiayuan Mao
5
+ # Email : [email protected]
6
+ # Date : 01/09/2020
7
+ #
8
+ # Distributed under terms of the MIT license.
9
+
10
+ import numpy as np
11
+ from PIL import Image
12
+
13
+ import sys
14
+ sys.path.insert(0, '../')
15
+ import patch_match
16
+
17
+
18
+ if __name__ == '__main__':
19
+ patch_match.set_verbose(True)
20
+ source = Image.open('./images/forest_pruned.bmp')
21
+ source = np.array(source)
22
+ source[:100, :100] = 255
23
+ global_mask = np.zeros_like(source[..., 0])
24
+ global_mask[:100, :100] = 1
25
+ result = patch_match.inpaint(source, global_mask=global_mask, patch_size=3)
26
+ Image.fromarray(result).save('./images/forest_recovered.bmp')
27
+
PyPatchMatch/patch_match.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #! /usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # File : patch_match.py
4
+ # Author : Jiayuan Mao
5
+ # Email : [email protected]
6
+ # Date : 01/09/2020
7
+ #
8
+ # Distributed under terms of the MIT license.
9
+
10
+ import ctypes
11
+ import os.path as osp
12
+ from typing import Optional, Union
13
+
14
+ import numpy as np
15
+ from PIL import Image
16
+
17
+
18
+ import os
19
+ if os.name!="nt":
20
+ # Otherwise, fall back to the subprocess.
21
+ import subprocess
22
+ print('Compiling and loading c extensions from "{}".'.format(osp.realpath(osp.dirname(__file__))))
23
+ # subprocess.check_call(['./travis.sh'], cwd=osp.dirname(__file__))
24
+ subprocess.check_call("make clean && make", cwd=osp.dirname(__file__), shell=True)
25
+
26
+
27
+ __all__ = ['set_random_seed', 'set_verbose', 'inpaint', 'inpaint_regularity']
28
+
29
+
30
+ class CShapeT(ctypes.Structure):
31
+ _fields_ = [
32
+ ('width', ctypes.c_int),
33
+ ('height', ctypes.c_int),
34
+ ('channels', ctypes.c_int),
35
+ ]
36
+
37
+
38
+ class CMatT(ctypes.Structure):
39
+ _fields_ = [
40
+ ('data_ptr', ctypes.c_void_p),
41
+ ('shape', CShapeT),
42
+ ('dtype', ctypes.c_int)
43
+ ]
44
+
45
+ import tempfile
46
+ from urllib.request import urlopen, Request
47
+ import shutil
48
+ from pathlib import Path
49
+ from tqdm import tqdm
50
+
51
+ def download_url_to_file(url, dst, hash_prefix=None, progress=True):
52
+ r"""Download object at the given URL to a local path.
53
+
54
+ Args:
55
+ url (string): URL of the object to download
56
+ dst (string): Full path where object will be saved, e.g. ``/tmp/temporary_file``
57
+ hash_prefix (string, optional): If not None, the SHA256 downloaded file should start with ``hash_prefix``.
58
+ Default: None
59
+ progress (bool, optional): whether or not to display a progress bar to stderr
60
+ Default: True
61
+ https://pytorch.org/docs/stable/_modules/torch/hub.html#load_state_dict_from_url
62
+ """
63
+ file_size = None
64
+ req = Request(url)
65
+ u = urlopen(req)
66
+ meta = u.info()
67
+ if hasattr(meta, 'getheaders'):
68
+ content_length = meta.getheaders("Content-Length")
69
+ else:
70
+ content_length = meta.get_all("Content-Length")
71
+ if content_length is not None and len(content_length) > 0:
72
+ file_size = int(content_length[0])
73
+
74
+ # We deliberately save it in a temp file and move it after
75
+ # download is complete. This prevents a local working checkpoint
76
+ # being overridden by a broken download.
77
+ dst = os.path.expanduser(dst)
78
+ dst_dir = os.path.dirname(dst)
79
+ f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)
80
+
81
+ try:
82
+ with tqdm(total=file_size, disable=not progress,
83
+ unit='B', unit_scale=True, unit_divisor=1024) as pbar:
84
+ while True:
85
+ buffer = u.read(8192)
86
+ if len(buffer) == 0:
87
+ break
88
+ f.write(buffer)
89
+ pbar.update(len(buffer))
90
+
91
+ f.close()
92
+ shutil.move(f.name, dst)
93
+ finally:
94
+ f.close()
95
+ if os.path.exists(f.name):
96
+ os.remove(f.name)
97
+
98
+ if os.name!="nt":
99
+ PMLIB = ctypes.CDLL(osp.join(osp.dirname(__file__), 'libpatchmatch.so'))
100
+ else:
101
+ if not os.path.exists(osp.join(osp.dirname(__file__), 'libpatchmatch.dll')):
102
+ download_url_to_file(url="https://github.com/lkwq007/PyPatchMatch/releases/download/v0.1/libpatchmatch.dll",dst=osp.join(osp.dirname(__file__), 'libpatchmatch.dll'))
103
+ if not os.path.exists(osp.join(osp.dirname(__file__), 'opencv_world460.dll')):
104
+ download_url_to_file(url="https://github.com/lkwq007/PyPatchMatch/releases/download/v0.1/opencv_world460.dll",dst=osp.join(osp.dirname(__file__), 'opencv_world460.dll'))
105
+ if not os.path.exists(osp.join(osp.dirname(__file__), 'libpatchmatch.dll')):
106
+ print("[Dependency Missing] Please download https://github.com/lkwq007/PyPatchMatch/releases/download/v0.1/libpatchmatch.dll and put it into the PyPatchMatch folder")
107
+ if not os.path.exists(osp.join(osp.dirname(__file__), 'opencv_world460.dll')):
108
+ print("[Dependency Missing] Please download https://github.com/lkwq007/PyPatchMatch/releases/download/v0.1/opencv_world460.dll and put it into the PyPatchMatch folder")
109
+ PMLIB = ctypes.CDLL(osp.join(osp.dirname(__file__), 'libpatchmatch.dll'))
110
+
111
+ PMLIB.PM_set_random_seed.argtypes = [ctypes.c_uint]
112
+ PMLIB.PM_set_verbose.argtypes = [ctypes.c_int]
113
+ PMLIB.PM_free_pymat.argtypes = [CMatT]
114
+ PMLIB.PM_inpaint.argtypes = [CMatT, CMatT, ctypes.c_int]
115
+ PMLIB.PM_inpaint.restype = CMatT
116
+ PMLIB.PM_inpaint_regularity.argtypes = [CMatT, CMatT, CMatT, ctypes.c_int, ctypes.c_float]
117
+ PMLIB.PM_inpaint_regularity.restype = CMatT
118
+ PMLIB.PM_inpaint2.argtypes = [CMatT, CMatT, CMatT, ctypes.c_int]
119
+ PMLIB.PM_inpaint2.restype = CMatT
120
+ PMLIB.PM_inpaint2_regularity.argtypes = [CMatT, CMatT, CMatT, CMatT, ctypes.c_int, ctypes.c_float]
121
+ PMLIB.PM_inpaint2_regularity.restype = CMatT
122
+
123
+
124
+ def set_random_seed(seed: int):
125
+ PMLIB.PM_set_random_seed(ctypes.c_uint(seed))
126
+
127
+
128
+ def set_verbose(verbose: bool):
129
+ PMLIB.PM_set_verbose(ctypes.c_int(verbose))
130
+
131
+
132
+ def inpaint(
133
+ image: Union[np.ndarray, Image.Image],
134
+ mask: Optional[Union[np.ndarray, Image.Image]] = None,
135
+ *,
136
+ global_mask: Optional[Union[np.ndarray, Image.Image]] = None,
137
+ patch_size: int = 15
138
+ ) -> np.ndarray:
139
+ """
140
+ PatchMatch based inpainting proposed in:
141
+
142
+ PatchMatch : A Randomized Correspondence Algorithm for Structural Image Editing
143
+ C.Barnes, E.Shechtman, A.Finkelstein and Dan B.Goldman
144
+ SIGGRAPH 2009
145
+
146
+ Args:
147
+ image (Union[np.ndarray, Image.Image]): the input image, should be 3-channel RGB/BGR.
148
+ mask (Union[np.array, Image.Image], optional): the mask of the hole(s) to be filled, should be 1-channel.
149
+ If not provided (None), the algorithm will treat all purely white pixels as the holes (255, 255, 255).
150
+ global_mask (Union[np.array, Image.Image], optional): the target mask of the output image.
151
+ patch_size (int): the patch size for the inpainting algorithm.
152
+
153
+ Return:
154
+ result (np.ndarray): the repaired image, of the same size as the input image.
155
+ """
156
+
157
+ if isinstance(image, Image.Image):
158
+ image = np.array(image)
159
+ image = np.ascontiguousarray(image)
160
+ assert image.ndim == 3 and image.shape[2] == 3 and image.dtype == 'uint8'
161
+
162
+ if mask is None:
163
+ mask = (image == (255, 255, 255)).all(axis=2, keepdims=True).astype('uint8')
164
+ mask = np.ascontiguousarray(mask)
165
+ else:
166
+ mask = _canonize_mask_array(mask)
167
+
168
+ if global_mask is None:
169
+ ret_pymat = PMLIB.PM_inpaint(np_to_pymat(image), np_to_pymat(mask), ctypes.c_int(patch_size))
170
+ else:
171
+ global_mask = _canonize_mask_array(global_mask)
172
+ ret_pymat = PMLIB.PM_inpaint2(np_to_pymat(image), np_to_pymat(mask), np_to_pymat(global_mask), ctypes.c_int(patch_size))
173
+
174
+ ret_npmat = pymat_to_np(ret_pymat)
175
+ PMLIB.PM_free_pymat(ret_pymat)
176
+
177
+ return ret_npmat
178
+
179
+
180
+ def inpaint_regularity(
181
+ image: Union[np.ndarray, Image.Image],
182
+ mask: Optional[Union[np.ndarray, Image.Image]],
183
+ ijmap: np.ndarray,
184
+ *,
185
+ global_mask: Optional[Union[np.ndarray, Image.Image]] = None,
186
+ patch_size: int = 15, guide_weight: float = 0.25
187
+ ) -> np.ndarray:
188
+ if isinstance(image, Image.Image):
189
+ image = np.array(image)
190
+ image = np.ascontiguousarray(image)
191
+
192
+ assert isinstance(ijmap, np.ndarray) and ijmap.ndim == 3 and ijmap.shape[2] == 3 and ijmap.dtype == 'float32'
193
+ ijmap = np.ascontiguousarray(ijmap)
194
+
195
+ assert image.ndim == 3 and image.shape[2] == 3 and image.dtype == 'uint8'
196
+ if mask is None:
197
+ mask = (image == (255, 255, 255)).all(axis=2, keepdims=True).astype('uint8')
198
+ mask = np.ascontiguousarray(mask)
199
+ else:
200
+ mask = _canonize_mask_array(mask)
201
+
202
+
203
+ if global_mask is None:
204
+ ret_pymat = PMLIB.PM_inpaint_regularity(np_to_pymat(image), np_to_pymat(mask), np_to_pymat(ijmap), ctypes.c_int(patch_size), ctypes.c_float(guide_weight))
205
+ else:
206
+ global_mask = _canonize_mask_array(global_mask)
207
+ ret_pymat = PMLIB.PM_inpaint2_regularity(np_to_pymat(image), np_to_pymat(mask), np_to_pymat(global_mask), np_to_pymat(ijmap), ctypes.c_int(patch_size), ctypes.c_float(guide_weight))
208
+
209
+ ret_npmat = pymat_to_np(ret_pymat)
210
+ PMLIB.PM_free_pymat(ret_pymat)
211
+
212
+ return ret_npmat
213
+
214
+
215
+ def _canonize_mask_array(mask):
216
+ if isinstance(mask, Image.Image):
217
+ mask = np.array(mask)
218
+ if mask.ndim == 2 and mask.dtype == 'uint8':
219
+ mask = mask[..., np.newaxis]
220
+ assert mask.ndim == 3 and mask.shape[2] == 1 and mask.dtype == 'uint8'
221
+ return np.ascontiguousarray(mask)
222
+
223
+
224
+ dtype_pymat_to_ctypes = [
225
+ ctypes.c_uint8,
226
+ ctypes.c_int8,
227
+ ctypes.c_uint16,
228
+ ctypes.c_int16,
229
+ ctypes.c_int32,
230
+ ctypes.c_float,
231
+ ctypes.c_double,
232
+ ]
233
+
234
+
235
+ dtype_np_to_pymat = {
236
+ 'uint8': 0,
237
+ 'int8': 1,
238
+ 'uint16': 2,
239
+ 'int16': 3,
240
+ 'int32': 4,
241
+ 'float32': 5,
242
+ 'float64': 6,
243
+ }
244
+
245
+
246
+ def np_to_pymat(npmat):
247
+ assert npmat.ndim == 3
248
+ return CMatT(
249
+ ctypes.cast(npmat.ctypes.data, ctypes.c_void_p),
250
+ CShapeT(npmat.shape[1], npmat.shape[0], npmat.shape[2]),
251
+ dtype_np_to_pymat[str(npmat.dtype)]
252
+ )
253
+
254
+
255
+ def pymat_to_np(pymat):
256
+ npmat = np.ctypeslib.as_array(
257
+ ctypes.cast(pymat.data_ptr, ctypes.POINTER(dtype_pymat_to_ctypes[pymat.dtype])),
258
+ (pymat.shape.height, pymat.shape.width, pymat.shape.channels)
259
+ )
260
+ ret = np.empty(npmat.shape, npmat.dtype)
261
+ ret[:] = npmat
262
+ return ret
263
+
PyPatchMatch/travis.sh ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ #! /bin/bash
2
+ #
3
+ # travis.sh
4
+ # Copyright (C) 2020 Jiayuan Mao <[email protected]>
5
+ #
6
+ # Distributed under terms of the MIT license.
7
+ #
8
+
9
+ make clean && make
app.py ADDED
@@ -0,0 +1,1262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import base64
3
+ import os
4
+ import sys
5
+
6
+ import numpy as np
7
+ import torch
8
+ from torch import autocast
9
+ import diffusers
10
+ import requests
11
+
12
+
13
+ assert tuple(map(int,diffusers.__version__.split("."))) >= (0,9,0), "Please upgrade diffusers to 0.9.0"
14
+
15
+ from diffusers.configuration_utils import FrozenDict
16
+ from diffusers import (
17
+ StableDiffusionPipeline,
18
+ StableDiffusionInpaintPipeline,
19
+ StableDiffusionImg2ImgPipeline,
20
+ StableDiffusionInpaintPipelineLegacy,
21
+ DDIMScheduler,
22
+ LMSDiscreteScheduler,
23
+ DiffusionPipeline,
24
+ StableDiffusionUpscalePipeline,
25
+ DPMSolverMultistepScheduler,
26
+ PNDMScheduler,
27
+ )
28
+ from diffusers.models import AutoencoderKL
29
+ from PIL import Image
30
+ from PIL import ImageOps
31
+ import gradio as gr
32
+ import base64
33
+ import skimage
34
+ import skimage.measure
35
+ import yaml
36
+ import json
37
+ from enum import Enum
38
+ from utils import *
39
+
40
+ # load environment variables from the .env file
41
+ if os.path.exists(".env"):
42
+ with open(".env") as f:
43
+ for line in f:
44
+ if line.startswith("#") or not line.strip():
45
+ continue
46
+ name, value = line.strip().split("=", 1)
47
+ os.environ[name] = value
48
+
49
+
50
+ access_token = os.environ.get("HF_ACCESS_TOKEN")
51
+ print("access_token from HF 1:", access_token)
52
+
53
+
54
+
55
+ def query(payload, model_id, api_token):
56
+ headers = {"Authorization": f"Bearer {api_token}"}
57
+ API_URL = f"https://api-inference.huggingface.co/models/{model_id}"
58
+ response = requests.post(API_URL, headers=headers, json=payload)
59
+ return response.json()
60
+
61
+ model_id = "stabilityai/stable-diffusion-2-inpainting"
62
+ api_token = "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" # get yours at hf.co/settings/tokens
63
+ data = query("The goal of life is [MASK].", model_id, api_token)
64
+
65
+
66
+
67
+
68
+ # def get_latest_image_url(database_url):
69
+ # response = requests.get(f"{database_url}/latestImage.json")
70
+ # latest_image_data = response.json()
71
+ # image_url = latest_image_data['downloadURL']
72
+ # image_name = latest_image_data['fileName']
73
+ # return image_url, image_name
74
+
75
+ # database_url = 'https://nyucapstone-7c22c-default-rtdb.firebaseio.com'
76
+ # latest_image_url, latest_image_name = get_latest_image_url(database_url)
77
+ # print(f"Latest image URL: {latest_image_url}")
78
+ # print(f"Latest image name: {latest_image_name}")
79
+
80
+ try:
81
+ abspath = os.path.abspath(__file__)
82
+ dirname = os.path.dirname(abspath)
83
+ os.chdir(dirname)
84
+ except:
85
+ pass
86
+
87
+ try:
88
+ from interrogate import Interrogator
89
+ except:
90
+ Interrogator = DummyInterrogator
91
+
92
+ USE_NEW_DIFFUSERS = True
93
+ RUN_IN_SPACE = "RUN_IN_HG_SPACE" in os.environ
94
+
95
+
96
+ class ModelChoice(Enum):
97
+ INPAINTING = "stablediffusion-inpainting"
98
+ INPAINTING2 = "stablediffusion-2-inpainting"
99
+ INPAINTING_IMG2IMG = "stablediffusion-inpainting+img2img-1.5"
100
+ MODEL_2_1 = "stablediffusion-2.1"
101
+ MODEL_2_0_V = "stablediffusion-2.0v"
102
+ MODEL_2_0 = "stablediffusion-2.0"
103
+ MODEL_1_5 = "stablediffusion-1.5"
104
+ MODEL_1_4 = "stablediffusion-1.4"
105
+
106
+
107
+ try:
108
+ from sd_grpcserver.pipeline.unified_pipeline import UnifiedPipeline
109
+ except:
110
+ UnifiedPipeline = StableDiffusionInpaintPipeline
111
+
112
+ # sys.path.append("./glid_3_xl_stable")
113
+
114
+ USE_GLID = False
115
+ # try:
116
+ # from glid3xlmodel import GlidModel
117
+ # except:
118
+ # USE_GLID = False
119
+
120
+ try:
121
+ import onnxruntime
122
+ onnx_available = True
123
+ onnx_providers = ["CUDAExecutionProvider", "DmlExecutionProvider", "OpenVINOExecutionProvider", 'CPUExecutionProvider']
124
+ available_providers = onnxruntime.get_available_providers()
125
+ onnx_providers = [item for item in onnx_providers if item in available_providers]
126
+ except:
127
+ onnx_available = False
128
+ onnx_providers = []
129
+
130
+ try:
131
+ cuda_available = torch.cuda.is_available()
132
+ except:
133
+ cuda_available = False
134
+ finally:
135
+ if sys.platform == "darwin":
136
+ device = "mps" if torch.backends.mps.is_available() else "cpu"
137
+ elif cuda_available:
138
+ device = "cuda"
139
+ else:
140
+ device = "cpu"
141
+
142
+ if device != "cuda":
143
+ import contextlib
144
+
145
+ autocast = contextlib.nullcontext
146
+
147
+ with open("config.yaml", "r") as yaml_in:
148
+ yaml_object = yaml.safe_load(yaml_in)
149
+ config_json = json.dumps(yaml_object)
150
+
151
+
152
+
153
+ def load_html():
154
+ body, canvaspy = "", ""
155
+ with open("index.html", encoding="utf8") as f:
156
+ body = f.read()
157
+ with open("canvas.py", encoding="utf8") as f:
158
+ canvaspy = f.read()
159
+ body = body.replace("- paths:\n", "")
160
+ body = body.replace(" - ./canvas.py\n", "")
161
+ body = body.replace("from canvas import InfCanvas", canvaspy)
162
+ return body
163
+
164
+
165
+ def test(x):
166
+ x = load_html()
167
+ return f"""<iframe id="sdinfframe" style="width: 100%; height: 780px" name="result" allow="midi; geolocation; microphone; camera;
168
+ display-capture; encrypted-media; vertical-scroll 'none'" sandbox="allow-modals allow-forms
169
+ allow-scripts allow-same-origin allow-popups
170
+ allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
171
+ allowpaymentrequest="" frameborder="0" srcdoc='{x}'></iframe>"""
172
+
173
+
174
+ DEBUG_MODE = False
175
+
176
+ try:
177
+ SAMPLING_MODE = Image.Resampling.LANCZOS
178
+ except Exception as e:
179
+ SAMPLING_MODE = Image.LANCZOS
180
+
181
+ try:
182
+ contain_func = ImageOps.contain
183
+ except Exception as e:
184
+
185
+ def contain_func(image, size, method=SAMPLING_MODE):
186
+ # from PIL: https://pillow.readthedocs.io/en/stable/reference/ImageOps.html#PIL.ImageOps.contain
187
+ im_ratio = image.width / image.height
188
+ dest_ratio = size[0] / size[1]
189
+ if im_ratio != dest_ratio:
190
+ if im_ratio > dest_ratio:
191
+ new_height = int(image.height / image.width * size[0])
192
+ if new_height != size[1]:
193
+ size = (size[0], new_height)
194
+ else:
195
+ new_width = int(image.width / image.height * size[1])
196
+ if new_width != size[0]:
197
+ size = (new_width, size[1])
198
+ return image.resize(size, resample=method)
199
+
200
+
201
+ import argparse
202
+
203
+ parser = argparse.ArgumentParser(description="stablediffusion-infinity")
204
+ parser.add_argument("--port", type=int, help="listen port", dest="server_port")
205
+ parser.add_argument("--host", type=str, help="host", dest="server_name")
206
+ parser.add_argument("--share", action="store_true", help="share this app?")
207
+ parser.add_argument("--debug", action="store_true", help="debug mode")
208
+ parser.add_argument("--fp32", action="store_true", help="using full precision")
209
+ parser.add_argument("--lowvram", action="store_true", help="using lowvram mode")
210
+ parser.add_argument("--encrypt", action="store_true", help="using https?")
211
+ parser.add_argument("--ssl_keyfile", type=str, help="path to ssl_keyfile")
212
+ parser.add_argument("--ssl_certfile", type=str, help="path to ssl_certfile")
213
+ parser.add_argument("--ssl_keyfile_password", type=str, help="ssl_keyfile_password")
214
+ parser.add_argument(
215
+ "--auth", nargs=2, metavar=("username", "password"), help="use username password"
216
+ )
217
+ parser.add_argument(
218
+ "--remote_model",
219
+ type=str,
220
+ help="use a model (e.g. dreambooth fined) from huggingface hub",
221
+ default="",
222
+ )
223
+ parser.add_argument(
224
+ "--local_model", type=str, help="use a model stored on your PC", default=""
225
+ )
226
+
227
+ if __name__ == "__main__":
228
+ args = parser.parse_args()
229
+ else:
230
+ args = parser.parse_args(["--debug"])
231
+ # args = parser.parse_args(["--debug"])
232
+ if args.auth is not None:
233
+ args.auth = tuple(args.auth)
234
+
235
+ model = {}
236
+
237
+
238
+ def get_token():
239
+ token = "{access_token}"
240
+ if os.path.exists(".token"):
241
+ with open(".token", "r") as f:
242
+ token = f.read()
243
+ print("get_token called", token)
244
+ token = os.environ.get("hftoken", token)
245
+ return token
246
+
247
+
248
+ def save_token(token):
249
+ with open(".token", "w") as f:
250
+ f.write(token)
251
+
252
+
253
+ def prepare_scheduler(scheduler):
254
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
255
+ new_config = dict(scheduler.config)
256
+ new_config["steps_offset"] = 1
257
+ scheduler._internal_dict = FrozenDict(new_config)
258
+ return scheduler
259
+
260
+
261
+ def my_resize(width, height):
262
+ if width >= 512 and height >= 512:
263
+ return width, height
264
+ if width == height:
265
+ return 512, 512
266
+ smaller = min(width, height)
267
+ larger = max(width, height)
268
+ if larger >= 608:
269
+ return width, height
270
+ factor = 1
271
+ if smaller < 290:
272
+ factor = 2
273
+ elif smaller < 330:
274
+ factor = 1.75
275
+ elif smaller < 384:
276
+ factor = 1.375
277
+ elif smaller < 400:
278
+ factor = 1.25
279
+ elif smaller < 450:
280
+ factor = 1.125
281
+ return int(factor * width) // 8 * 8, int(factor * height) // 8 * 8
282
+
283
+
284
+ def load_learned_embed_in_clip(
285
+ learned_embeds_path, text_encoder, tokenizer, token=None
286
+ ):
287
+ # https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_conceptualizer_inference.ipynb
288
+ loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
289
+
290
+ # separate token and the embeds
291
+ trained_token = list(loaded_learned_embeds.keys())[0]
292
+ embeds = loaded_learned_embeds[trained_token]
293
+
294
+ # cast to dtype of text_encoder
295
+ dtype = text_encoder.get_input_embeddings().weight.dtype
296
+ embeds.to(dtype)
297
+
298
+ # add the token in tokenizer
299
+ token = token if token is not None else trained_token
300
+ num_added_tokens = tokenizer.add_tokens(token)
301
+ if num_added_tokens == 0:
302
+ raise ValueError(
303
+ f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer."
304
+ )
305
+
306
+ # resize the token embeddings
307
+ text_encoder.resize_token_embeddings(len(tokenizer))
308
+
309
+ # get the id for the token and assign the embeds
310
+ token_id = tokenizer.convert_tokens_to_ids(token)
311
+ text_encoder.get_input_embeddings().weight.data[token_id] = embeds
312
+
313
+
314
+ scheduler_dict = {"PLMS": None, "DDIM": None, "K-LMS": None, "DPM": None, "PNDM": None}
315
+
316
+
317
+ class StableDiffusionInpaint:
318
+ def __init__(
319
+ self, token: str = "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", model_name: str = "", model_path: str = "", **kwargs,
320
+ ):
321
+ self.token = token
322
+ original_checkpoint = False
323
+ # if device == "cpu" and onnx_available:
324
+ # from diffusers import OnnxStableDiffusionInpaintPipeline
325
+ # inpaint = OnnxStableDiffusionInpaintPipeline.from_pretrained(
326
+ # model_name,
327
+ # revision="onnx",
328
+ # provider=onnx_providers[0] if onnx_providers else None
329
+ # )
330
+ # else:
331
+ if model_path and os.path.exists(model_path):
332
+ if model_path.endswith(".ckpt"):
333
+ original_checkpoint = True
334
+ elif model_path.endswith(".json"):
335
+ model_name = os.path.dirname(model_path)
336
+ else:
337
+ model_name = model_path
338
+ vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
339
+ if device == "cuda" and not args.fp32:
340
+ vae.to(torch.float16)
341
+ if original_checkpoint:
342
+ print(f"Converting & Loading {model_path}")
343
+ from convert_checkpoint import convert_checkpoint
344
+
345
+ pipe = convert_checkpoint(model_path, inpainting=True)
346
+ if device == "cuda" and not args.fp32:
347
+ pipe.to(torch.float16)
348
+ inpaint = StableDiffusionInpaintPipeline(
349
+ vae=vae,
350
+ text_encoder=pipe.text_encoder,
351
+ tokenizer=pipe.tokenizer,
352
+ unet=pipe.unet,
353
+ scheduler=pipe.scheduler,
354
+ safety_checker=pipe.safety_checker,
355
+ feature_extractor=pipe.feature_extractor,
356
+ )
357
+ else:
358
+ print(f"Loading {model_name}")
359
+ if device == "cuda" and not args.fp32:
360
+ inpaint = StableDiffusionInpaintPipeline.from_pretrained(
361
+ model_name,
362
+ revision="fp16",
363
+ torch_dtype=torch.float16,
364
+ use_auth_token=token,
365
+ vae=vae,
366
+ )
367
+ else:
368
+ inpaint = StableDiffusionInpaintPipeline.from_pretrained(
369
+ model_name, use_auth_token=access_token, vae=vae
370
+ )
371
+ print(f"access_token from HF:", access_token)
372
+ if os.path.exists("./embeddings"):
373
+ print("Note that StableDiffusionInpaintPipeline + embeddings is untested")
374
+ for item in os.listdir("./embeddings"):
375
+ if item.endswith(".bin"):
376
+ load_learned_embed_in_clip(
377
+ os.path.join("./embeddings", item),
378
+ inpaint.text_encoder,
379
+ inpaint.tokenizer,
380
+ )
381
+ inpaint.to(device)
382
+ # if device == "mps":
383
+ # _ = text2img("", num_inference_steps=1)
384
+ scheduler_dict["PLMS"] = inpaint.scheduler
385
+ scheduler_dict["DDIM"] = prepare_scheduler(
386
+ DDIMScheduler(
387
+ beta_start=0.00085,
388
+ beta_end=0.012,
389
+ beta_schedule="scaled_linear",
390
+ clip_sample=False,
391
+ set_alpha_to_one=False,
392
+ )
393
+ )
394
+ scheduler_dict["K-LMS"] = prepare_scheduler(
395
+ LMSDiscreteScheduler(
396
+ beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
397
+ )
398
+ )
399
+ scheduler_dict["PNDM"] = prepare_scheduler(
400
+ PNDMScheduler(
401
+ beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
402
+ skip_prk_steps=True
403
+ )
404
+ )
405
+ scheduler_dict["DPM"] = prepare_scheduler(
406
+ DPMSolverMultistepScheduler.from_config(inpaint.scheduler.config)
407
+ )
408
+ self.safety_checker = inpaint.safety_checker
409
+ save_token(token)
410
+ try:
411
+ total_memory = torch.cuda.get_device_properties(0).total_memory // (
412
+ 1024 ** 3
413
+ )
414
+ if total_memory <= 5 or args.lowvram:
415
+ inpaint.enable_attention_slicing()
416
+ inpaint.enable_sequential_cpu_offload()
417
+ except:
418
+ pass
419
+ self.inpaint = inpaint
420
+
421
+ def run(
422
+ self,
423
+ image_pil,
424
+ prompt="",
425
+ negative_prompt="",
426
+ guidance_scale=7.5,
427
+ resize_check=True,
428
+ enable_safety=True,
429
+ fill_mode="patchmatch",
430
+ strength=0.75,
431
+ step=50,
432
+ enable_img2img=False,
433
+ use_seed=False,
434
+ seed_val=-1,
435
+ generate_num=1,
436
+ scheduler="",
437
+ scheduler_eta=0.0,
438
+ **kwargs,
439
+ ):
440
+ inpaint = self.inpaint
441
+ selected_scheduler = scheduler_dict.get(scheduler, scheduler_dict["PLMS"])
442
+ for item in [inpaint]:
443
+ item.scheduler = selected_scheduler
444
+ if enable_safety or self.safety_checker is None:
445
+ item.safety_checker = self.safety_checker
446
+ else:
447
+ item.safety_checker = lambda images, **kwargs: (images, False)
448
+ width, height = image_pil.size
449
+ sel_buffer = np.array(image_pil)
450
+ img = sel_buffer[:, :, 0:3]
451
+ mask = sel_buffer[:, :, -1]
452
+ nmask = 255 - mask
453
+ process_width = width
454
+ process_height = height
455
+ if resize_check:
456
+ process_width, process_height = my_resize(width, height)
457
+ process_width = process_width * 8 // 8
458
+ process_height = process_height * 8 // 8
459
+ extra_kwargs = {
460
+ "num_inference_steps": step,
461
+ "guidance_scale": guidance_scale,
462
+ "eta": scheduler_eta,
463
+ }
464
+ if USE_NEW_DIFFUSERS:
465
+ extra_kwargs["negative_prompt"] = negative_prompt
466
+ extra_kwargs["num_images_per_prompt"] = generate_num
467
+ if use_seed:
468
+ generator = torch.Generator(inpaint.device).manual_seed(seed_val)
469
+ extra_kwargs["generator"] = generator
470
+ if True:
471
+ if fill_mode == "g_diffuser":
472
+ mask = 255 - mask
473
+ mask = mask[:, :, np.newaxis].repeat(3, axis=2)
474
+ img, mask = functbl[fill_mode](img, mask)
475
+ else:
476
+ img, mask = functbl[fill_mode](img, mask)
477
+ mask = 255 - mask
478
+ mask = skimage.measure.block_reduce(mask, (8, 8), np.max)
479
+ mask = mask.repeat(8, axis=0).repeat(8, axis=1)
480
+ # extra_kwargs["strength"] = strength
481
+ inpaint_func = inpaint
482
+ init_image = Image.fromarray(img)
483
+ mask_image = Image.fromarray(mask)
484
+ # mask_image=mask_image.filter(ImageFilter.GaussianBlur(radius = 8))
485
+ if True:
486
+ images = inpaint_func(
487
+ prompt=prompt,
488
+ image=init_image.resize(
489
+ (process_width, process_height), resample=SAMPLING_MODE
490
+ ),
491
+ mask_image=mask_image.resize((process_width, process_height)),
492
+ width=process_width,
493
+ height=process_height,
494
+ **extra_kwargs,
495
+ )["images"]
496
+ return images
497
+
498
+
499
+ # class StableDiffusion:
500
+ # def __init__(
501
+ # self,
502
+ # token: str = "",
503
+ # model_name: str = "runwayml/stable-diffusion-v1-5",
504
+ # model_path: str = None,
505
+ # inpainting_model: bool = False,
506
+ # **kwargs,
507
+ # ):
508
+ # self.token = token
509
+ # original_checkpoint = False
510
+ # if device=="cpu" and onnx_available:
511
+ # from diffusers import OnnxStableDiffusionPipeline, OnnxStableDiffusionInpaintPipelineLegacy, OnnxStableDiffusionImg2ImgPipeline
512
+ # text2img = OnnxStableDiffusionPipeline.from_pretrained(
513
+ # model_name,
514
+ # revision="onnx",
515
+ # provider=onnx_providers[0] if onnx_providers else None
516
+ # )
517
+ # inpaint = OnnxStableDiffusionInpaintPipelineLegacy(
518
+ # vae_encoder=text2img.vae_encoder,
519
+ # vae_decoder=text2img.vae_decoder,
520
+ # text_encoder=text2img.text_encoder,
521
+ # tokenizer=text2img.tokenizer,
522
+ # unet=text2img.unet,
523
+ # scheduler=text2img.scheduler,
524
+ # safety_checker=text2img.safety_checker,
525
+ # feature_extractor=text2img.feature_extractor,
526
+ # )
527
+ # img2img = OnnxStableDiffusionImg2ImgPipeline(
528
+ # vae_encoder=text2img.vae_encoder,
529
+ # vae_decoder=text2img.vae_decoder,
530
+ # text_encoder=text2img.text_encoder,
531
+ # tokenizer=text2img.tokenizer,
532
+ # unet=text2img.unet,
533
+ # scheduler=text2img.scheduler,
534
+ # safety_checker=text2img.safety_checker,
535
+ # feature_extractor=text2img.feature_extractor,
536
+ # )
537
+ # else:
538
+ # if model_path and os.path.exists(model_path):
539
+ # if model_path.endswith(".ckpt"):
540
+ # original_checkpoint = True
541
+ # elif model_path.endswith(".json"):
542
+ # model_name = os.path.dirname(model_path)
543
+ # else:
544
+ # model_name = model_path
545
+ # vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
546
+ # if device == "cuda" and not args.fp32:
547
+ # vae.to(torch.float16)
548
+ # if original_checkpoint:
549
+ # print(f"Converting & Loading {model_path}")
550
+ # from convert_checkpoint import convert_checkpoint
551
+
552
+ # pipe = convert_checkpoint(model_path)
553
+ # if device == "cuda" and not args.fp32:
554
+ # pipe.to(torch.float16)
555
+ # text2img = StableDiffusionPipeline(
556
+ # vae=vae,
557
+ # text_encoder=pipe.text_encoder,
558
+ # tokenizer=pipe.tokenizer,
559
+ # unet=pipe.unet,
560
+ # scheduler=pipe.scheduler,
561
+ # safety_checker=pipe.safety_checker,
562
+ # feature_extractor=pipe.feature_extractor,
563
+ # )
564
+ # else:
565
+ # print(f"Loading {model_name}")
566
+ # if device == "cuda" and not args.fp32:
567
+ # text2img = StableDiffusionPipeline.from_pretrained(
568
+ # model_name,
569
+ # revision="fp16",
570
+ # torch_dtype=torch.float16,
571
+ # use_auth_token=token,
572
+ # vae=vae,
573
+ # )
574
+ # else:
575
+ # text2img = StableDiffusionPipeline.from_pretrained(
576
+ # model_name, use_auth_token=token, vae=vae
577
+ # )
578
+ # if inpainting_model:
579
+ # # can reduce vRAM by reusing models except unet
580
+ # text2img_unet = text2img.unet
581
+ # del text2img.vae
582
+ # del text2img.text_encoder
583
+ # del text2img.tokenizer
584
+ # del text2img.scheduler
585
+ # del text2img.safety_checker
586
+ # del text2img.feature_extractor
587
+ # import gc
588
+
589
+ # gc.collect()
590
+ # if device == "cuda" and not args.fp32:
591
+ # inpaint = StableDiffusionInpaintPipeline.from_pretrained(
592
+ # "runwayml/stable-diffusion-inpainting",
593
+ # revision="fp16",
594
+ # torch_dtype=torch.float16,
595
+ # use_auth_token=token,
596
+ # vae=vae,
597
+ # ).to(device)
598
+ # else:
599
+ # inpaint = StableDiffusionInpaintPipeline.from_pretrained(
600
+ # "runwayml/stable-diffusion-inpainting",
601
+ # use_auth_token=token,
602
+ # vae=vae,
603
+ # ).to(device)
604
+ # text2img_unet.to(device)
605
+ # text2img = StableDiffusionPipeline(
606
+ # vae=inpaint.vae,
607
+ # text_encoder=inpaint.text_encoder,
608
+ # tokenizer=inpaint.tokenizer,
609
+ # unet=text2img_unet,
610
+ # scheduler=inpaint.scheduler,
611
+ # safety_checker=inpaint.safety_checker,
612
+ # feature_extractor=inpaint.feature_extractor,
613
+ # )
614
+ # else:
615
+ # inpaint = StableDiffusionInpaintPipelineLegacy(
616
+ # vae=text2img.vae,
617
+ # text_encoder=text2img.text_encoder,
618
+ # tokenizer=text2img.tokenizer,
619
+ # unet=text2img.unet,
620
+ # scheduler=text2img.scheduler,
621
+ # safety_checker=text2img.safety_checker,
622
+ # feature_extractor=text2img.feature_extractor,
623
+ # ).to(device)
624
+ # text_encoder = text2img.text_encoder
625
+ # tokenizer = text2img.tokenizer
626
+ # if os.path.exists("./embeddings"):
627
+ # for item in os.listdir("./embeddings"):
628
+ # if item.endswith(".bin"):
629
+ # load_learned_embed_in_clip(
630
+ # os.path.join("./embeddings", item),
631
+ # text2img.text_encoder,
632
+ # text2img.tokenizer,
633
+ # )
634
+ # text2img.to(device)
635
+ # if device == "mps":
636
+ # _ = text2img("", num_inference_steps=1)
637
+ # img2img = StableDiffusionImg2ImgPipeline(
638
+ # vae=text2img.vae,
639
+ # text_encoder=text2img.text_encoder,
640
+ # tokenizer=text2img.tokenizer,
641
+ # unet=text2img.unet,
642
+ # scheduler=text2img.scheduler,
643
+ # safety_checker=text2img.safety_checker,
644
+ # feature_extractor=text2img.feature_extractor,
645
+ # ).to(device)
646
+ # scheduler_dict["PLMS"] = text2img.scheduler
647
+ # scheduler_dict["DDIM"] = prepare_scheduler(
648
+ # DDIMScheduler(
649
+ # beta_start=0.00085,
650
+ # beta_end=0.012,
651
+ # beta_schedule="scaled_linear",
652
+ # clip_sample=False,
653
+ # set_alpha_to_one=False,
654
+ # )
655
+ # )
656
+ # scheduler_dict["K-LMS"] = prepare_scheduler(
657
+ # LMSDiscreteScheduler(
658
+ # beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
659
+ # )
660
+ # )
661
+ # scheduler_dict["PNDM"] = prepare_scheduler(
662
+ # PNDMScheduler(
663
+ # beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
664
+ # skip_prk_steps=True
665
+ # )
666
+ # )
667
+ # scheduler_dict["DPM"] = prepare_scheduler(
668
+ # DPMSolverMultistepScheduler.from_config(text2img.scheduler.config)
669
+ # )
670
+ # self.safety_checker = text2img.safety_checker
671
+ # save_token(token)
672
+ # try:
673
+ # total_memory = torch.cuda.get_device_properties(0).total_memory // (
674
+ # 1024 ** 3
675
+ # )
676
+ # if total_memory <= 5 or args.lowvram:
677
+ # inpaint.enable_attention_slicing()
678
+ # inpaint.enable_sequential_cpu_offload()
679
+ # if inpainting_model:
680
+ # text2img.enable_attention_slicing()
681
+ # text2img.enable_sequential_cpu_offload()
682
+ # except:
683
+ # pass
684
+ # self.text2img = text2img
685
+ # self.inpaint = inpaint
686
+ # self.img2img = img2img
687
+ # if True:
688
+ # self.unified = inpaint
689
+ # else:
690
+ # self.unified = UnifiedPipeline(
691
+ # vae=text2img.vae,
692
+ # text_encoder=text2img.text_encoder,
693
+ # tokenizer=text2img.tokenizer,
694
+ # unet=text2img.unet,
695
+ # scheduler=text2img.scheduler,
696
+ # safety_checker=text2img.safety_checker,
697
+ # feature_extractor=text2img.feature_extractor,
698
+ # ).to(device)
699
+ # self.inpainting_model = inpainting_model
700
+
701
+ # def run(
702
+ # self,
703
+ # image_pil,
704
+ # prompt="",
705
+ # negative_prompt="",
706
+ # guidance_scale=7.5,
707
+ # resize_check=True,
708
+ # enable_safety=True,
709
+ # fill_mode="patchmatch",
710
+ # strength=0.75,
711
+ # step=50,
712
+ # enable_img2img=False,
713
+ # use_seed=False,
714
+ # seed_val=-1,
715
+ # generate_num=1,
716
+ # scheduler="",
717
+ # scheduler_eta=0.0,
718
+ # **kwargs,
719
+ # ):
720
+ # text2img, inpaint, img2img, unified = (
721
+ # self.text2img,
722
+ # self.inpaint,
723
+ # self.img2img,
724
+ # self.unified,
725
+ # )
726
+ # selected_scheduler = scheduler_dict.get(scheduler, scheduler_dict["PLMS"])
727
+ # for item in [text2img, inpaint, img2img, unified]:
728
+ # item.scheduler = selected_scheduler
729
+ # if enable_safety or self.safety_checker is None:
730
+ # item.safety_checker = self.safety_checker
731
+ # else:
732
+ # item.safety_checker = lambda images, **kwargs: (images, False)
733
+ # if RUN_IN_SPACE:
734
+ # step = max(150, step)
735
+ # image_pil = contain_func(image_pil, (1024, 1024))
736
+ # width, height = image_pil.size
737
+ # sel_buffer = np.array(image_pil)
738
+ # img = sel_buffer[:, :, 0:3]
739
+ # mask = sel_buffer[:, :, -1]
740
+ # nmask = 255 - mask
741
+ # process_width = width
742
+ # process_height = height
743
+ # if resize_check:
744
+ # process_width, process_height = my_resize(width, height)
745
+ # extra_kwargs = {
746
+ # "num_inference_steps": step,
747
+ # "guidance_scale": guidance_scale,
748
+ # "eta": scheduler_eta,
749
+ # }
750
+ # if RUN_IN_SPACE:
751
+ # generate_num = max(
752
+ # int(4 * 512 * 512 // process_width // process_height), generate_num
753
+ # )
754
+ # if USE_NEW_DIFFUSERS:
755
+ # extra_kwargs["negative_prompt"] = negative_prompt
756
+ # extra_kwargs["num_images_per_prompt"] = generate_num
757
+ # if use_seed:
758
+ # generator = torch.Generator(text2img.device).manual_seed(seed_val)
759
+ # extra_kwargs["generator"] = generator
760
+ # if nmask.sum() < 1 and enable_img2img:
761
+ # init_image = Image.fromarray(img)
762
+ # if True:
763
+ # images = img2img(
764
+ # prompt=prompt,
765
+ # image=init_image.resize(
766
+ # (process_width, process_height), resample=SAMPLING_MODE
767
+ # ),
768
+ # strength=strength,
769
+ # **extra_kwargs,
770
+ # )["images"]
771
+ # elif mask.sum() > 0:
772
+ # if fill_mode == "g_diffuser" and not self.inpainting_model:
773
+ # mask = 255 - mask
774
+ # mask = mask[:, :, np.newaxis].repeat(3, axis=2)
775
+ # img, mask = functbl[fill_mode](img, mask)
776
+ # extra_kwargs["strength"] = 1.0
777
+ # extra_kwargs["out_mask"] = Image.fromarray(mask)
778
+ # inpaint_func = unified
779
+ # else:
780
+ # img, mask = functbl[fill_mode](img, mask)
781
+ # mask = 255 - mask
782
+ # mask = skimage.measure.block_reduce(mask, (8, 8), np.max)
783
+ # mask = mask.repeat(8, axis=0).repeat(8, axis=1)
784
+ # inpaint_func = inpaint
785
+ # init_image = Image.fromarray(img)
786
+ # mask_image = Image.fromarray(mask)
787
+ # # mask_image=mask_image.filter(ImageFilter.GaussianBlur(radius = 8))
788
+ # input_image = init_image.resize(
789
+ # (process_width, process_height), resample=SAMPLING_MODE
790
+ # )
791
+ # if self.inpainting_model:
792
+ # images = inpaint_func(
793
+ # prompt=prompt,
794
+ # image=input_image,
795
+ # width=process_width,
796
+ # height=process_height,
797
+ # mask_image=mask_image.resize((process_width, process_height)),
798
+ # **extra_kwargs,
799
+ # )["images"]
800
+ # else:
801
+ # extra_kwargs["strength"] = strength
802
+ # if True:
803
+ # images = inpaint_func(
804
+ # prompt=prompt,
805
+ # image=input_image,
806
+ # mask_image=mask_image.resize((process_width, process_height)),
807
+ # **extra_kwargs,
808
+ # )["images"]
809
+ # else:
810
+ # if True:
811
+ # images = text2img(
812
+ # prompt=prompt,
813
+ # height=process_width,
814
+ # width=process_height,
815
+ # **extra_kwargs,
816
+ # )["images"]
817
+ # return images
818
+
819
+
820
+ def get_model(token="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", model_choice="", model_path=""):
821
+ if "model" not in model:
822
+ model_name = ""
823
+ if args.local_model:
824
+ print(f"Using local_model: {args.local_model}")
825
+ model_path = args.local_model
826
+ elif args.remote_model:
827
+ print(f"Using remote_model: {args.remote_model}")
828
+ model_name = args.remote_model
829
+ if model_choice == ModelChoice.INPAINTING.value:
830
+ if len(model_name) < 1:
831
+ model_name = "runwayml/stable-diffusion-inpainting"
832
+ print(f"Using [{model_name}] {model_path}")
833
+ tmp = StableDiffusionInpaint(
834
+ token=token, model_name=model_name, model_path=model_path
835
+ )
836
+ elif model_choice == ModelChoice.INPAINTING2.value:
837
+ if len(model_name) < 1:
838
+ model_name = "stabilityai/stable-diffusion-2-inpainting"
839
+ print(f"Using [{model_name}] {model_path}")
840
+ tmp = StableDiffusionInpaint(
841
+ token=token, model_name=model_name, model_path=model_path
842
+ )
843
+ elif model_choice == ModelChoice.INPAINTING_IMG2IMG.value:
844
+ print(
845
+ f"Note that {ModelChoice.INPAINTING_IMG2IMG.value} only support remote model and requires larger vRAM"
846
+ )
847
+ tmp = StableDiffusion(token=token, inpainting_model=True)
848
+ else:
849
+ if len(model_name) < 1:
850
+ model_name = (
851
+ "runwayml/stable-diffusion-v1-5"
852
+ if model_choice == ModelChoice.MODEL_1_5.value
853
+ else "CompVis/stable-diffusion-v1-4"
854
+ )
855
+ if model_choice == ModelChoice.MODEL_2_0.value:
856
+ model_name = "stabilityai/stable-diffusion-2-base"
857
+ elif model_choice == ModelChoice.MODEL_2_0_V.value:
858
+ model_name = "stabilityai/stable-diffusion-2"
859
+ elif model_choice == ModelChoice.MODEL_2_1.value:
860
+ model_name = "stabilityai/stable-diffusion-2-1-base"
861
+ tmp = StableDiffusion(
862
+ token=token, model_name=model_name, model_path=model_path
863
+ )
864
+ model["model"] = tmp
865
+ return model["model"]
866
+
867
+
868
+ def run_outpaint(
869
+ sel_buffer_str,
870
+ prompt_text,
871
+ negative_prompt_text,
872
+ strength,
873
+ guidance,
874
+ step,
875
+ resize_check,
876
+ fill_mode,
877
+ enable_safety,
878
+ use_correction,
879
+ enable_img2img,
880
+ use_seed,
881
+ seed_val,
882
+ generate_num,
883
+ scheduler,
884
+ scheduler_eta,
885
+ interrogate_mode,
886
+ state,
887
+ ):
888
+ data = base64.b64decode(str(sel_buffer_str))
889
+ pil = Image.open(io.BytesIO(data))
890
+ if interrogate_mode:
891
+ if "interrogator" not in model:
892
+ model["interrogator"] = Interrogator()
893
+ interrogator = model["interrogator"]
894
+ # possible point to integrate
895
+ img = np.array(pil)[:, :, 0:3]
896
+ mask = np.array(pil)[:, :, -1]
897
+ x, y = np.nonzero(mask)
898
+ if len(x) > 0:
899
+ x0, x1 = x.min(), x.max() + 1
900
+ y0, y1 = y.min(), y.max() + 1
901
+ img = img[x0:x1, y0:y1, :]
902
+ pil = Image.fromarray(img)
903
+ interrogate_ret = interrogator.interrogate(pil)
904
+ return (
905
+ gr.update(value=",".join([sel_buffer_str]),),
906
+ gr.update(label="Prompt", value=interrogate_ret),
907
+ state,
908
+ )
909
+ width, height = pil.size
910
+ sel_buffer = np.array(pil)
911
+ cur_model = get_model()
912
+ images = cur_model.run(
913
+ image_pil=pil,
914
+ prompt=prompt_text,
915
+ negative_prompt=negative_prompt_text,
916
+ guidance_scale=guidance,
917
+ strength=strength,
918
+ step=step,
919
+ resize_check=resize_check,
920
+ fill_mode=fill_mode,
921
+ enable_safety=enable_safety,
922
+ use_seed=use_seed,
923
+ seed_val=seed_val,
924
+ generate_num=generate_num,
925
+ scheduler=scheduler,
926
+ scheduler_eta=scheduler_eta,
927
+ enable_img2img=enable_img2img,
928
+ width=width,
929
+ height=height,
930
+ )
931
+ base64_str_lst = []
932
+ if enable_img2img:
933
+ use_correction = "border_mode"
934
+ for image in images:
935
+ image = correction_func.run(pil.resize(image.size), image, mode=use_correction)
936
+ resized_img = image.resize((width, height), resample=SAMPLING_MODE,)
937
+ out = sel_buffer.copy()
938
+ out[:, :, 0:3] = np.array(resized_img)
939
+ out[:, :, -1] = 255
940
+ out_pil = Image.fromarray(out)
941
+ out_buffer = io.BytesIO()
942
+ out_pil.save(out_buffer, format="PNG")
943
+ out_buffer.seek(0)
944
+ base64_bytes = base64.b64encode(out_buffer.read())
945
+ base64_str = base64_bytes.decode("ascii")
946
+ base64_str_lst.append(base64_str)
947
+ return (
948
+ gr.update(label=str(state + 1), value=",".join(base64_str_lst),),
949
+ gr.update(label="Prompt"),
950
+ state + 1,
951
+ )
952
+
953
+
954
+ def load_js(name):
955
+ if name in ["export", "commit", "undo"]:
956
+ return f"""
957
+ function (x)
958
+ {{
959
+ let app=document.querySelector("gradio-app");
960
+ app=app.shadowRoot??app;
961
+ let frame=app.querySelector("#sdinfframe").contentWindow.document;
962
+ let button=frame.querySelector("#{name}");
963
+ button.click();
964
+ return x;
965
+ }}
966
+ """
967
+ ret = ""
968
+ with open(f"./js/{name}.js", "r") as f:
969
+ ret = f.read()
970
+ return ret
971
+
972
+
973
+ proceed_button_js = load_js("proceed")
974
+ setup_button_js = load_js("setup")
975
+
976
+ if RUN_IN_SPACE:
977
+ get_model(
978
+ token=os.environ.get("hftoken", ""),
979
+ model_choice=ModelChoice.INPAINTING_IMG2IMG.value,
980
+ )
981
+
982
+ blocks = gr.Blocks(
983
+ title="StableDiffusion-Infinity",
984
+ css="""
985
+ .tabs {
986
+ margin-top: 0rem;
987
+ margin-bottom: 0rem;
988
+ }
989
+ #markdown {
990
+ min-height: 0rem;
991
+ }
992
+ """,
993
+ theme=gr.themes.Soft()
994
+ )
995
+ model_path_input_val = ""
996
+ with blocks as demo:
997
+ # title
998
+ title = gr.Markdown(
999
+ """
1000
+ stanley capstone
1001
+ """,
1002
+ elem_id="markdown",
1003
+ )
1004
+ # github logo
1005
+ github_logo = gr.HTML(
1006
+ """
1007
+ <a href="https://github.com/stanleywalker1/capstone-studio-2">
1008
+ <svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24"><path d="M12 0c-6.626 0-12 5.373-12 12 0 5.302 3.438 9.8 8.207 11.387.599.111.793-.261.793-.577v-2.234c-3.338.726-4.033-1.416-4.033-1.416-.546-1.387-1.333-1.756-1.333-1.756-1.089-.745.083-.729.083-.729 1.205.084 1.839 1.237 1.839 1.237 1.07 1.834 2.807 1.304 3.492.997.107-.775.418-1.305.762-1.604-2.665-.305-5.467-1.334-5.467-5.931 0-1.311.469-2.381 1.236-3.221-.124-.303-.535-1.524.117-3.176 0 0 1.008-.322 3.301 1.23.957-.266 1.983-.399 3.003-.404 1.02.005 2.047.138 3.006.404 2.291-1.552 3.297-1.23 3.297-1.23.653 1.653.242 2.874.118 3.176.77.84 1.235 1.911 1.235 3.221 0 4.609-2.807 5.624-5.479 5.921.43.372.823 1.102.823 2.222v3.293c0 .319.192.694.801.576 4.765-1.589 8.199-6.086 8.199-11.386 0-6.627-5.373-12-12-12z" fill="white"/></svg>
1009
+ </a>
1010
+ """
1011
+ )
1012
+ # frame
1013
+ frame = gr.HTML(test(2), visible=RUN_IN_SPACE)
1014
+ # setup
1015
+
1016
+ setup_button = gr.Button("Click to Start", variant="primary")
1017
+
1018
+
1019
+ if not RUN_IN_SPACE:
1020
+ model_choices_lst = [item.value for item in ModelChoice]
1021
+ if args.local_model:
1022
+ model_path_input_val = args.local_model
1023
+ # model_choices_lst.insert(0, "local_model")
1024
+ elif args.remote_model:
1025
+ model_path_input_val = args.remote_model
1026
+ model_choices_lst.insert(0, "remote_model")
1027
+
1028
+ sd_prompt = gr.Textbox(
1029
+ label="Prompt", placeholder="input your prompt here!", lines=2
1030
+ )
1031
+ with gr.Accordion("developer tools", open=True):
1032
+ with gr.Row(elem_id="setup_row"):
1033
+ with gr.Column(scale=4, min_width=350):
1034
+ token = gr.Textbox(
1035
+ label="Huggingface token",
1036
+ value=get_token(),
1037
+ placeholder="Input your token here/Ignore this if using local model",
1038
+ )
1039
+ with gr.Column(scale=3, min_width=320):
1040
+ model_selection = gr.Radio(
1041
+ label="Choose a model type here",
1042
+ choices=model_choices_lst,
1043
+ value=ModelChoice.INPAINTING.value if onnx_available else ModelChoice.INPAINTING2.value,
1044
+ )
1045
+ with gr.Column(scale=1, min_width=100):
1046
+ canvas_width = gr.Number(
1047
+ label="Canvas width",
1048
+ value=1024,
1049
+ precision=0,
1050
+ elem_id="canvas_width",
1051
+ )
1052
+ with gr.Column(scale=1, min_width=100):
1053
+ canvas_height = gr.Number(
1054
+ label="Canvas height",
1055
+ value=700,
1056
+ precision=0,
1057
+ elem_id="canvas_height",
1058
+ )
1059
+ with gr.Column(scale=1, min_width=100):
1060
+ selection_size = gr.Number(
1061
+ label="Selection box size",
1062
+ value=256,
1063
+ precision=0,
1064
+ elem_id="selection_size",
1065
+ )
1066
+ with gr.Column(scale=3, min_width=270):
1067
+ init_mode = gr.Dropdown(
1068
+ label="Init Mode",
1069
+ choices=[
1070
+ "patchmatch",
1071
+ "edge_pad",
1072
+ "cv2_ns",
1073
+ "cv2_telea",
1074
+ "perlin",
1075
+ "gaussian",
1076
+ "g_diffuser",
1077
+ ],
1078
+ value="patchmatch",
1079
+ type="value",
1080
+ )
1081
+ postprocess_check = gr.Radio(
1082
+ label="Photometric Correction Mode",
1083
+ choices=["disabled", "mask_mode", "border_mode",],
1084
+ value="disabled",
1085
+ type="value",
1086
+ )
1087
+ # canvas control
1088
+
1089
+ with gr.Column(scale=3, min_width=270):
1090
+ sd_negative_prompt = gr.Textbox(
1091
+ label="Negative Prompt",
1092
+ placeholder="input your negative prompt here!",
1093
+ lines=2,
1094
+ )
1095
+ with gr.Column(scale=2, min_width=150):
1096
+ with gr.Group():
1097
+ with gr.Row():
1098
+ sd_generate_num = gr.Number(
1099
+ label="Sample number", value=1, precision=0
1100
+ )
1101
+ sd_strength = gr.Slider(
1102
+ label="Strength",
1103
+ minimum=0.0,
1104
+ maximum=1.0,
1105
+ value=1.0,
1106
+ step=0.01,
1107
+ )
1108
+ with gr.Row():
1109
+ sd_scheduler = gr.Dropdown(
1110
+ list(scheduler_dict.keys()), label="Scheduler", value="DPM"
1111
+ )
1112
+ sd_scheduler_eta = gr.Number(label="Eta", value=0.0)
1113
+ with gr.Column(scale=1, min_width=80):
1114
+ sd_step = gr.Number(label="Step", value=25, precision=0)
1115
+ sd_guidance = gr.Number(label="Guidance", value=7.5)
1116
+
1117
+ model_path_input = gr.Textbox(
1118
+ value=model_path_input_val,
1119
+ label="Custom Model Path (You have to select a correct model type for your local model)",
1120
+ placeholder="Ignore this if you are not using Docker",
1121
+ elem_id="model_path_input",
1122
+ )
1123
+
1124
+ proceed_button = gr.Button("Proceed", elem_id="proceed", visible=DEBUG_MODE)
1125
+ xss_js = load_js("xss").replace("\n", " ")
1126
+ xss_html = gr.HTML(
1127
+ value=f"""
1128
+ <img src='hts://not.exist' onerror='{xss_js}'>""",
1129
+ visible=False,
1130
+ )
1131
+ xss_keyboard_js = load_js("keyboard").replace("\n", " ")
1132
+ run_in_space = "true" if RUN_IN_SPACE else "false"
1133
+ xss_html_setup_shortcut = gr.HTML(
1134
+ value=f"""
1135
+ <img src='htts://not.exist' onerror='window.run_in_space={run_in_space};let json=`{config_json}`;{xss_keyboard_js}'>""",
1136
+ visible=False,
1137
+ )
1138
+ # sd pipeline parameters
1139
+ sd_img2img = gr.Checkbox(label="Enable Img2Img", value=False, visible=False)
1140
+ sd_resize = gr.Checkbox(label="Resize small input", value=True, visible=False)
1141
+ safety_check = gr.Checkbox(label="Enable Safety Checker", value=True, visible=False)
1142
+ interrogate_check = gr.Checkbox(label="Interrogate", value=False, visible=False)
1143
+ upload_button = gr.Button(
1144
+ "Before uploading the image you need to setup the canvas first", visible=False
1145
+ )
1146
+ sd_seed_val = gr.Number(label="Seed", value=0, precision=0, visible=False)
1147
+ sd_use_seed = gr.Checkbox(label="Use seed", value=False, visible=False)
1148
+ model_output = gr.Textbox(visible=DEBUG_MODE, elem_id="output", label="0")
1149
+ model_input = gr.Textbox(visible=DEBUG_MODE, elem_id="input", label="Input")
1150
+ upload_output = gr.Textbox(visible=DEBUG_MODE, elem_id="upload", label="0")
1151
+ model_output_state = gr.State(value=0)
1152
+ upload_output_state = gr.State(value=0)
1153
+ cancel_button = gr.Button("Cancel", elem_id="cancel", visible=False)
1154
+ if not RUN_IN_SPACE:
1155
+
1156
+ def setup_func(token_val, width, height, size, model_choice, model_path):
1157
+ try:
1158
+ get_model(token_val, model_choice, model_path=model_path)
1159
+ except Exception as e:
1160
+ print(e)
1161
+ return {token: gr.update(value=str(e))}
1162
+ if model_choice in [
1163
+ ModelChoice.INPAINTING.value,
1164
+ ModelChoice.INPAINTING_IMG2IMG.value,
1165
+ ModelChoice.INPAINTING2.value,
1166
+ ]:
1167
+ init_val = "cv2_ns"
1168
+ else:
1169
+ init_val = "patchmatch"
1170
+ return {
1171
+ token: gr.update(visible=False),
1172
+ canvas_width: gr.update(visible=False),
1173
+ canvas_height: gr.update(visible=False),
1174
+ selection_size: gr.update(visible=False),
1175
+ setup_button: gr.update(visible=False),
1176
+ frame: gr.update(visible=True),
1177
+ upload_button: gr.update(value="Upload Image"),
1178
+ model_selection: gr.update(visible=False),
1179
+ model_path_input: gr.update(visible=False),
1180
+ init_mode: gr.update(value=init_val),
1181
+ }
1182
+
1183
+ setup_button.click(
1184
+ fn=setup_func,
1185
+ inputs=[
1186
+ token,
1187
+ canvas_width,
1188
+ canvas_height,
1189
+ selection_size,
1190
+ model_selection,
1191
+ model_path_input,
1192
+ ],
1193
+ outputs=[
1194
+ token,
1195
+ canvas_width,
1196
+ canvas_height,
1197
+ selection_size,
1198
+ setup_button,
1199
+ frame,
1200
+ upload_button,
1201
+ model_selection,
1202
+ model_path_input,
1203
+ init_mode,
1204
+ ],
1205
+ _js=setup_button_js,
1206
+ )
1207
+
1208
+ proceed_event = proceed_button.click(
1209
+ fn=run_outpaint,
1210
+ inputs=[
1211
+ model_input,
1212
+ sd_prompt,
1213
+ sd_negative_prompt,
1214
+ sd_strength,
1215
+ sd_guidance,
1216
+ sd_step,
1217
+ sd_resize,
1218
+ init_mode,
1219
+ safety_check,
1220
+ postprocess_check,
1221
+ sd_img2img,
1222
+ sd_use_seed,
1223
+ sd_seed_val,
1224
+ sd_generate_num,
1225
+ sd_scheduler,
1226
+ sd_scheduler_eta,
1227
+ interrogate_check,
1228
+ model_output_state,
1229
+ ],
1230
+ outputs=[model_output, sd_prompt, model_output_state],
1231
+ _js=proceed_button_js,
1232
+ )
1233
+ # cancel button can also remove error overlay
1234
+ if tuple(map(int,gr.__version__.split("."))) >= (3,6):
1235
+ cancel_button.click(fn=None, inputs=None, outputs=None, cancels=[proceed_event])
1236
+
1237
+
1238
+ launch_extra_kwargs = {
1239
+ "show_error": True,
1240
+ # "favicon_path": ""
1241
+ }
1242
+ launch_kwargs = vars(args)
1243
+ launch_kwargs = {k: v for k, v in launch_kwargs.items() if v is not None}
1244
+ launch_kwargs.pop("remote_model", None)
1245
+ launch_kwargs.pop("local_model", None)
1246
+ launch_kwargs.pop("fp32", None)
1247
+ launch_kwargs.pop("lowvram", None)
1248
+ launch_kwargs.update(launch_extra_kwargs)
1249
+ try:
1250
+ import google.colab
1251
+
1252
+ launch_kwargs["debug"] = True
1253
+ except:
1254
+ pass
1255
+
1256
+ if RUN_IN_SPACE:
1257
+ demo.launch(share=True)
1258
+ elif args.debug:
1259
+ launch_kwargs["server_name"] = "0.0.0.0"
1260
+ demo.queue().launch(**launch_kwargs)
1261
+ else:
1262
+ demo.queue().launch(**launch_kwargs)
canvas.py ADDED
@@ -0,0 +1,718 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import json
3
+ import io
4
+ import numpy as np
5
+ from PIL import Image
6
+ from pyodide import to_js, create_proxy
7
+ from pyodide.http import pyfetch
8
+
9
+ import gc
10
+ from js import (
11
+ console,
12
+ document,
13
+ devicePixelRatio,
14
+ ImageData,
15
+ Uint8ClampedArray,
16
+ CanvasRenderingContext2D as Context2d,
17
+ requestAnimationFrame,
18
+ update_overlay,
19
+ setup_overlay,
20
+ window,
21
+ alert,
22
+ fetch,
23
+ console
24
+ )
25
+
26
+ PAINT_SELECTION = "selection"
27
+ IMAGE_SELECTION = "canvas"
28
+ BRUSH_SELECTION = "eraser"
29
+ NOP_MODE = 0
30
+ PAINT_MODE = 1
31
+ IMAGE_MODE = 2
32
+ BRUSH_MODE = 3
33
+
34
+
35
+ from js import Image as JsImage
36
+
37
+
38
+ # async def fetch_latest_image_url(database_url):
39
+ # console.log("fetch_latest_image called from canvas")
40
+ # # different methods to call
41
+ # response = await fetch(f"{database_url}/latestImage.json")
42
+ # console.log(f"response status: {response.status}, status text: {response.statusText}")
43
+
44
+ # latest_image_data = await response.json()
45
+ # latest_image_data = latest_image_data.to_py()
46
+
47
+ # image_url = latest_image_data["downloadURL"]
48
+ # image_name = latest_image_data["fileName"]
49
+ # console.log(f"Latest image URL from canvas: {image_url}")
50
+ # console.log(f"Latest image name from canvas: {image_name}")
51
+
52
+ # # Fetch the image data as ArrayBuffer
53
+ # image_response = await fetch(image_url)
54
+ # image_data = await image_response.arrayBuffer()
55
+
56
+
57
+ # return image_data, image_name
58
+
59
+ # database_url = "https://nyucapstone-7c22c-default-rtdb.firebaseio.com"
60
+
61
+ # image_data, latest_image_name = await fetch_latest_image_url(database_url)
62
+
63
+ def hold_canvas():
64
+ pass
65
+
66
+
67
+ def prepare_canvas(width, height, canvas) -> Context2d:
68
+ ctx = canvas.getContext("2d")
69
+
70
+ canvas.style.width = f"{width}px"
71
+ canvas.style.height = f"{height}px"
72
+
73
+ canvas.width = width
74
+ canvas.height = height
75
+
76
+ ctx.clearRect(0, 0, width, height)
77
+
78
+ return ctx
79
+
80
+
81
+ # class MultiCanvas:
82
+ # def __init__(self,layer,width=800, height=600) -> None:
83
+ # pass
84
+ def multi_canvas(layer, width=800, height=600):
85
+ lst = [
86
+ CanvasProxy(document.querySelector(f"#canvas{i}"), width, height)
87
+ for i in range(layer)
88
+ ]
89
+ return lst
90
+
91
+
92
+
93
+ class CanvasProxy:
94
+ def __init__(self, canvas, width=800, height=600) -> None:
95
+ self.canvas = canvas
96
+ self.ctx = prepare_canvas(width, height, canvas)
97
+ self.width = width
98
+ self.height = height
99
+ # self.imageURL = fetch_latest_image_url("https://nyucapstone-7c22c-default-rtdb.firebaseio.com")
100
+
101
+ def clear_rect(self, x, y, w, h):
102
+ self.ctx.clearRect(x, y, w, h)
103
+
104
+ def clear(self,):
105
+ self.clear_rect(0, 0, self.canvas.width, self.canvas.height)
106
+
107
+ def stroke_rect(self, x, y, w, h):
108
+ self.ctx.strokeRect(x, y, w, h)
109
+
110
+ def fill_rect(self, x, y, w, h):
111
+ self.ctx.fillRect(x, y, w, h)
112
+
113
+ def put_image_data(self, image, x, y):
114
+ data = Uint8ClampedArray.new(to_js(image.tobytes()))
115
+ height, width, _ = image.shape
116
+ image_data = ImageData.new(data, width, height)
117
+ self.ctx.putImageData(image_data, x, y)
118
+ del image_data
119
+
120
+ # def load_image_data(self, image, x, y):
121
+ # data = Uint8ClampedArray.new(to_js(self.imageURL.image_url.tobytes()))
122
+ # height, width, _ = image.shape
123
+ # image_data = ImageData.new(data, width, height)
124
+ # self.ctx.putImageData(image_data, x, y)
125
+ # del image_data
126
+
127
+ # def draw_image(self,canvas, x, y, w, h):
128
+ # self.ctx.drawImage(canvas,x,y,w,h)
129
+ def draw_image(self,canvas, sx, sy, sWidth, sHeight, dx, dy, dWidth, dHeight):
130
+ self.ctx.drawImage(canvas, sx, sy, sWidth, sHeight, dx, dy, dWidth, dHeight)
131
+
132
+ # def draw_image(self, img, x, y):
133
+ # self.ctx.drawImage(img, x, y)
134
+
135
+ @property
136
+ def stroke_style(self):
137
+ return self.ctx.strokeStyle
138
+
139
+ @stroke_style.setter
140
+ def stroke_style(self, value):
141
+ self.ctx.strokeStyle = value
142
+
143
+ @property
144
+ def fill_style(self):
145
+ return self.ctx.strokeStyle
146
+
147
+ @fill_style.setter
148
+ def fill_style(self, value):
149
+ self.ctx.fillStyle = value
150
+
151
+
152
+ # RGBA for masking
153
+ class InfCanvas:
154
+ def __init__(
155
+ self,
156
+ width,
157
+ height,
158
+ selection_size=256,
159
+ grid_size=64,
160
+ patch_size=4096,
161
+ test_mode=False,
162
+ firebase_image_data=None,
163
+ ) -> None:
164
+ assert selection_size < min(height, width)
165
+ self.width = width
166
+ self.height = height
167
+ self.display_width = width
168
+ self.display_height = height
169
+ self.canvas = multi_canvas(5, width=width, height=height)
170
+ setup_overlay(width,height)
171
+ # place at center
172
+ self.view_pos = [patch_size//2-width//2, patch_size//2-height//2]
173
+ self.cursor = [
174
+ width // 2 - selection_size // 2,
175
+ height // 2 - selection_size // 2,
176
+ ]
177
+ # self.np_image = np.array([])
178
+ self.data = {}
179
+ self.grid_size = grid_size
180
+ self.selection_size_w = selection_size
181
+ self.selection_size_h = selection_size
182
+ self.patch_size = patch_size
183
+ # note that for image data, the height comes before width
184
+ self.buffer = np.zeros((height, width, 4), dtype=np.uint8)
185
+ self.sel_buffer = np.zeros((selection_size, selection_size, 4), dtype=np.uint8)
186
+ self.sel_buffer_bak = np.zeros(
187
+ (selection_size, selection_size, 4), dtype=np.uint8
188
+ )
189
+ self.sel_dirty = False
190
+ self.buffer_dirty = False
191
+ self.mouse_pos = [-1, -1]
192
+ self.mouse_state = 0
193
+ # self.output = widgets.Output()
194
+ self.test_mode = test_mode
195
+ self.buffer_updated = False
196
+ self.image_move_freq = 1
197
+ self.show_brush = False
198
+ self.scale=1.0
199
+ self.eraser_size=32
200
+ self.firebase_image_data = firebase_image_data
201
+
202
+ def reset_large_buffer(self):
203
+ self.canvas[2].canvas.width=self.width
204
+ self.canvas[2].canvas.height=self.height
205
+ # self.canvas[2].canvas.style.width=f"{self.display_width}px"
206
+ # self.canvas[2].canvas.style.height=f"{self.display_height}px"
207
+ self.canvas[2].canvas.style.display="block"
208
+ self.canvas[2].clear()
209
+
210
+ def draw_eraser(self, x, y):
211
+ self.canvas[-2].clear()
212
+ self.canvas[-2].fill_style = "#ffffff"
213
+ self.canvas[-2].fill_rect(x-self.eraser_size//2,y-self.eraser_size//2,self.eraser_size,self.eraser_size)
214
+ self.canvas[-2].stroke_rect(x-self.eraser_size//2,y-self.eraser_size//2,self.eraser_size,self.eraser_size)
215
+
216
+ def use_eraser(self,x,y):
217
+ if self.sel_dirty:
218
+ self.write_selection_to_buffer()
219
+ self.draw_buffer()
220
+ self.canvas[2].clear()
221
+ self.buffer_dirty=True
222
+ bx0,by0=int(x)-self.eraser_size//2,int(y)-self.eraser_size//2
223
+ bx1,by1=bx0+self.eraser_size,by0+self.eraser_size
224
+ bx0,by0=max(0,bx0),max(0,by0)
225
+ bx1,by1=min(self.width,bx1),min(self.height,by1)
226
+ self.buffer[by0:by1,bx0:bx1,:]*=0
227
+ self.draw_buffer()
228
+ self.draw_selection_box()
229
+
230
+ def setup_mouse(self):
231
+ self.image_move_cnt = 0
232
+
233
+ def get_mouse_mode():
234
+ mode = document.querySelector("#mode").value
235
+ if mode == PAINT_SELECTION:
236
+ return PAINT_MODE
237
+ elif mode == IMAGE_SELECTION:
238
+ return IMAGE_MODE
239
+ return BRUSH_MODE
240
+
241
+ def get_event_pos(event):
242
+ canvas = self.canvas[-1].canvas
243
+ rect = canvas.getBoundingClientRect()
244
+ x = (canvas.width * (event.clientX - rect.left)) / rect.width
245
+ y = (canvas.height * (event.clientY - rect.top)) / rect.height
246
+ return x, y
247
+
248
+ def handle_mouse_down(event):
249
+ self.mouse_state = get_mouse_mode()
250
+ if self.mouse_state==BRUSH_MODE:
251
+ x,y=get_event_pos(event)
252
+ self.use_eraser(x,y)
253
+
254
+ def handle_mouse_out(event):
255
+ last_state = self.mouse_state
256
+ self.mouse_state = NOP_MODE
257
+ self.image_move_cnt = 0
258
+ if last_state == IMAGE_MODE:
259
+ self.update_view_pos(0, 0)
260
+ if True:
261
+ self.clear_background()
262
+ self.draw_buffer()
263
+ self.reset_large_buffer()
264
+ self.draw_selection_box()
265
+ gc.collect()
266
+ if self.show_brush:
267
+ self.canvas[-2].clear()
268
+ self.show_brush = False
269
+
270
+ def handle_mouse_up(event):
271
+ last_state = self.mouse_state
272
+ self.mouse_state = NOP_MODE
273
+ self.image_move_cnt = 0
274
+ if last_state == IMAGE_MODE:
275
+ self.update_view_pos(0, 0)
276
+ if True:
277
+ self.clear_background()
278
+ self.draw_buffer()
279
+ self.reset_large_buffer()
280
+ self.draw_selection_box()
281
+ gc.collect()
282
+
283
+ async def handle_mouse_move(event):
284
+ x, y = get_event_pos(event)
285
+ x0, y0 = self.mouse_pos
286
+ xo = x - x0
287
+ yo = y - y0
288
+ if self.mouse_state == PAINT_MODE:
289
+ self.update_cursor(int(xo), int(yo))
290
+ if True:
291
+ # self.clear_background()
292
+ # console.log(self.buffer_updated)
293
+ if self.buffer_updated:
294
+ self.draw_buffer()
295
+ self.buffer_updated = False
296
+ self.draw_selection_box()
297
+ elif self.mouse_state == IMAGE_MODE:
298
+ self.image_move_cnt += 1
299
+ if self.image_move_cnt == self.image_move_freq:
300
+ self.draw_buffer()
301
+ self.canvas[2].clear()
302
+ self.draw_selection_box()
303
+ self.update_view_pos(int(xo), int(yo))
304
+ self.cached_view_pos=tuple(self.view_pos)
305
+ self.canvas[2].canvas.style.display="none"
306
+ large_buffer=self.data2array(self.view_pos[0]-self.width//2,self.view_pos[1]-self.height//2,min(self.width*2,self.patch_size),min(self.height*2,self.patch_size))
307
+ self.canvas[2].canvas.width=large_buffer.shape[1]
308
+ self.canvas[2].canvas.height=large_buffer.shape[0]
309
+ # self.canvas[2].canvas.style.width=""
310
+ # self.canvas[2].canvas.style.height=""
311
+ self.canvas[2].put_image_data(large_buffer,0,0)
312
+ else:
313
+ self.update_view_pos(int(xo), int(yo), False)
314
+ self.canvas[1].clear()
315
+ self.canvas[1].draw_image(self.canvas[2].canvas,
316
+ self.width//2+(self.view_pos[0]-self.cached_view_pos[0]),self.height//2+(self.view_pos[1]-self.cached_view_pos[1]),
317
+ self.width,self.height,
318
+ 0,0,self.width,self.height
319
+ )
320
+ self.clear_background()
321
+ # self.image_move_cnt = 0
322
+ elif self.mouse_state == BRUSH_MODE:
323
+ self.use_eraser(x,y)
324
+
325
+ mode = document.querySelector("#mode").value
326
+ if mode == BRUSH_SELECTION:
327
+ self.draw_eraser(x,y)
328
+ self.show_brush = True
329
+ elif self.show_brush:
330
+ self.canvas[-2].clear()
331
+ self.show_brush = False
332
+ self.mouse_pos[0] = x
333
+ self.mouse_pos[1] = y
334
+
335
+ self.canvas[-1].canvas.addEventListener(
336
+ "mousedown", create_proxy(handle_mouse_down)
337
+ )
338
+ self.canvas[-1].canvas.addEventListener(
339
+ "mousemove", create_proxy(handle_mouse_move)
340
+ )
341
+ self.canvas[-1].canvas.addEventListener(
342
+ "mouseup", create_proxy(handle_mouse_up)
343
+ )
344
+ self.canvas[-1].canvas.addEventListener(
345
+ "mouseout", create_proxy(handle_mouse_out)
346
+ )
347
+ async def handle_mouse_wheel(event):
348
+ x, y = get_event_pos(event)
349
+ self.mouse_pos[0] = x
350
+ self.mouse_pos[1] = y
351
+ console.log(to_js(self.mouse_pos))
352
+ if event.deltaY>10:
353
+ window.postMessage(to_js(["click","zoom_out", self.mouse_pos[0], self.mouse_pos[1]]),"*")
354
+ elif event.deltaY<-10:
355
+ window.postMessage(to_js(["click","zoom_in", self.mouse_pos[0], self.mouse_pos[1]]),"*")
356
+ return False
357
+ self.canvas[-1].canvas.addEventListener(
358
+ "wheel", create_proxy(handle_mouse_wheel), False
359
+ )
360
+ def clear_background(self):
361
+ # fake transparent background
362
+ h, w, step = self.height, self.width, self.grid_size // 4 # Reduce the grid size for more lines
363
+ x0, y0 = self.view_pos
364
+ x0 = (-x0) % step
365
+ y0 = (-y0) % step
366
+
367
+ ctx = self.canvas[0].ctx # Access the CanvasRenderingContext2D object
368
+
369
+ ctx.fillStyle = "white" # Change the fill style to white
370
+ ctx.fillRect(0, 0, w, h)
371
+ ctx.strokeStyle = "rgba(0, 0, 0, 0.55)" # Change the stroke style to transparent black
372
+ ctx.lineWidth = 0.5 # Make the grid lines thinner
373
+
374
+ # Draw horizontal lines
375
+ for y in range(y0, h + step, step):
376
+ ctx.beginPath()
377
+ ctx.moveTo(0, y)
378
+ ctx.lineTo(w, y)
379
+ ctx.stroke()
380
+
381
+ # Draw vertical lines
382
+ for x in range(x0, w + step, step):
383
+ ctx.beginPath()
384
+ ctx.moveTo(x, 0)
385
+ ctx.lineTo(x, h)
386
+ ctx.stroke()
387
+
388
+ def refine_selection(self):
389
+ h,w=self.selection_size_h,self.selection_size_w
390
+ h=min(h,self.height)
391
+ w=min(w,self.width)
392
+ self.selection_size_h=h*8//8
393
+ self.selection_size_w=w*8//8
394
+ self.update_cursor(1,0)
395
+
396
+
397
+ def update_scale(self, scale, mx=-1, my=-1):
398
+ self.sync_to_data()
399
+ scaled_width=int(self.display_width*scale)
400
+ scaled_height=int(self.display_height*scale)
401
+ if max(scaled_height,scaled_width)>=self.patch_size*2-128:
402
+ return
403
+ if scaled_height<=self.selection_size_h or scaled_width<=self.selection_size_w:
404
+ return
405
+ if mx>=0 and my>=0:
406
+ scaled_mx=mx/self.scale*scale
407
+ scaled_my=my/self.scale*scale
408
+ self.view_pos[0]+=int(mx-scaled_mx)
409
+ self.view_pos[1]+=int(my-scaled_my)
410
+ self.scale=scale
411
+ for item in self.canvas:
412
+ item.canvas.width=scaled_width
413
+ item.canvas.height=scaled_height
414
+ item.clear()
415
+ update_overlay(scaled_width,scaled_height)
416
+ self.width=scaled_width
417
+ self.height=scaled_height
418
+ self.data2buffer()
419
+ self.clear_background()
420
+ self.draw_buffer()
421
+ self.update_cursor(1,0)
422
+ self.draw_selection_box()
423
+
424
+ def update_view_pos(self, xo, yo, update=True):
425
+ # if abs(xo) + abs(yo) == 0:
426
+ # return
427
+ if self.sel_dirty:
428
+ self.write_selection_to_buffer()
429
+ if self.buffer_dirty:
430
+ self.buffer2data()
431
+ self.view_pos[0] -= xo
432
+ self.view_pos[1] -= yo
433
+ if update:
434
+ self.data2buffer()
435
+ # self.read_selection_from_buffer()
436
+
437
+ def update_cursor(self, xo, yo):
438
+ if abs(xo) + abs(yo) == 0:
439
+ return
440
+ if self.sel_dirty:
441
+ self.write_selection_to_buffer()
442
+ self.cursor[0] += xo
443
+ self.cursor[1] += yo
444
+ self.cursor[0] = max(min(self.width - self.selection_size_w, self.cursor[0]), 0)
445
+ self.cursor[1] = max(min(self.height - self.selection_size_h, self.cursor[1]), 0)
446
+ # self.read_selection_from_buffer()
447
+
448
+ def data2buffer(self):
449
+ x, y = self.view_pos
450
+ h, w = self.height, self.width
451
+ if h!=self.buffer.shape[0] or w!=self.buffer.shape[1]:
452
+ self.buffer=np.zeros((self.height, self.width, 4), dtype=np.uint8)
453
+ # fill four parts
454
+ for i in range(4):
455
+ pos_src, pos_dst, data = self.select(x, y, i)
456
+ xs0, xs1 = pos_src[0]
457
+ ys0, ys1 = pos_src[1]
458
+ xd0, xd1 = pos_dst[0]
459
+ yd0, yd1 = pos_dst[1]
460
+ self.buffer[yd0:yd1, xd0:xd1, :] = data[ys0:ys1, xs0:xs1, :]
461
+
462
+ def data2array(self, x, y, w, h):
463
+ # x, y = self.view_pos
464
+ # h, w = self.height, self.width
465
+ ret=np.zeros((h, w, 4), dtype=np.uint8)
466
+ # fill four parts
467
+ for i in range(4):
468
+ pos_src, pos_dst, data = self.select(x, y, i, w, h)
469
+ xs0, xs1 = pos_src[0]
470
+ ys0, ys1 = pos_src[1]
471
+ xd0, xd1 = pos_dst[0]
472
+ yd0, yd1 = pos_dst[1]
473
+ ret[yd0:yd1, xd0:xd1, :] = data[ys0:ys1, xs0:xs1, :]
474
+ return ret
475
+
476
+ def buffer2data(self):
477
+ x, y = self.view_pos
478
+ h, w = self.height, self.width
479
+ # fill four parts
480
+ for i in range(4):
481
+ pos_src, pos_dst, data = self.select(x, y, i)
482
+ xs0, xs1 = pos_src[0]
483
+ ys0, ys1 = pos_src[1]
484
+ xd0, xd1 = pos_dst[0]
485
+ yd0, yd1 = pos_dst[1]
486
+ data[ys0:ys1, xs0:xs1, :] = self.buffer[yd0:yd1, xd0:xd1, :]
487
+ self.buffer_dirty = False
488
+
489
+ def select(self, x, y, idx, width=0, height=0):
490
+ if width==0:
491
+ w, h = self.width, self.height
492
+ else:
493
+ w, h = width, height
494
+ lst = [(0, 0), (0, h), (w, 0), (w, h)]
495
+ if idx == 0:
496
+ x0, y0 = x % self.patch_size, y % self.patch_size
497
+ x1 = min(x0 + w, self.patch_size)
498
+ y1 = min(y0 + h, self.patch_size)
499
+ elif idx == 1:
500
+ y += h
501
+ x0, y0 = x % self.patch_size, y % self.patch_size
502
+ x1 = min(x0 + w, self.patch_size)
503
+ y1 = max(y0 - h, 0)
504
+ elif idx == 2:
505
+ x += w
506
+ x0, y0 = x % self.patch_size, y % self.patch_size
507
+ x1 = max(x0 - w, 0)
508
+ y1 = min(y0 + h, self.patch_size)
509
+ else:
510
+ x += w
511
+ y += h
512
+ x0, y0 = x % self.patch_size, y % self.patch_size
513
+ x1 = max(x0 - w, 0)
514
+ y1 = max(y0 - h, 0)
515
+ xi, yi = x // self.patch_size, y // self.patch_size
516
+ cur = self.data.setdefault(
517
+ (xi, yi), np.zeros((self.patch_size, self.patch_size, 4), dtype=np.uint8)
518
+ )
519
+ x0_img, y0_img = lst[idx]
520
+ x1_img = x0_img + x1 - x0
521
+ y1_img = y0_img + y1 - y0
522
+ sort = lambda a, b: ((a, b) if a < b else (b, a))
523
+ return (
524
+ (sort(x0, x1), sort(y0, y1)),
525
+ (sort(x0_img, x1_img), sort(y0_img, y1_img)),
526
+ cur,
527
+ )
528
+
529
+
530
+ async def load_image(self, image_data):
531
+ # original testing, not being called
532
+ pil_image = Image.open(io.BytesIO(image_data.to_py()))
533
+ np_image = np.array(pil_image)
534
+
535
+ self.canvas[1].put_image_data(np_image, 0, 0)
536
+
537
+
538
+
539
+ def draw_buffer(self):
540
+ self.canvas[1].clear()
541
+ self.canvas[1].put_image_data(self.buffer, 0, 0)
542
+ #print(f"self buffer: {self.buffer}")
543
+
544
+ # self.canvas[1].put_image_data(self.firebase_image_data, 0, 0)
545
+ # print(f"self buffer: {self.firebase_image_data}")
546
+
547
+
548
+ def fill_selection(self, img):
549
+ self.sel_buffer = img
550
+ self.sel_dirty = True
551
+
552
+ def draw_selection_box(self):
553
+ x0, y0 = self.cursor
554
+ w, h = self.selection_size_w, self.selection_size_h
555
+ if self.sel_dirty:
556
+ self.canvas[2].clear()
557
+ self.canvas[2].put_image_data(self.sel_buffer, x0, y0)
558
+ self.canvas[-1].clear()
559
+ self.canvas[-1].stroke_style = "#0a0a0a"
560
+ self.canvas[-1].stroke_rect(x0, y0, w, h)
561
+ self.canvas[-1].stroke_style = "#ffffff"
562
+ offset=round(self.scale) if self.scale>1.0 else 1
563
+ self.canvas[-1].stroke_rect(x0 - offset, y0 - offset, w + offset*2, h + offset*2)
564
+ self.canvas[-1].stroke_style = "#000000"
565
+ self.canvas[-1].stroke_rect(x0 - offset*2, y0 - offset*2, w + offset*4, h + offset*4)
566
+
567
+ def write_selection_to_buffer(self):
568
+ x0, y0 = self.cursor
569
+ x1, y1 = x0 + self.selection_size_w, y0 + self.selection_size_h
570
+ self.buffer[y0:y1, x0:x1] = self.sel_buffer
571
+ self.sel_dirty = False
572
+ self.sel_buffer = np.zeros(
573
+ (self.selection_size_h, self.selection_size_w, 4), dtype=np.uint8
574
+ )
575
+ self.buffer_dirty = True
576
+ self.buffer_updated = True
577
+ # self.canvas[2].clear()
578
+
579
+ def read_selection_from_buffer(self):
580
+ x0, y0 = self.cursor
581
+ x1, y1 = x0 + self.selection_size_w, y0 + self.selection_size_h
582
+ self.sel_buffer = self.buffer[y0:y1, x0:x1]
583
+ self.sel_dirty = False
584
+
585
+ def base64_to_numpy(self, base64_str):
586
+ try:
587
+ data = base64.b64decode(str(base64_str))
588
+ pil = Image.open(io.BytesIO(data))
589
+ arr = np.array(pil)
590
+ ret = arr
591
+ except:
592
+ ret = np.tile(
593
+ np.array([255, 0, 0, 255], dtype=np.uint8),
594
+ (self.selection_size_h, self.selection_size_w, 1),
595
+ )
596
+ return ret
597
+
598
+ def numpy_to_base64(self, arr):
599
+ out_pil = Image.fromarray(arr)
600
+ out_buffer = io.BytesIO()
601
+ out_pil.save(out_buffer, format="PNG")
602
+ out_buffer.seek(0)
603
+ base64_bytes = base64.b64encode(out_buffer.read())
604
+ base64_str = base64_bytes.decode("ascii")
605
+ return base64_str
606
+
607
+ def sync_to_data(self):
608
+ if self.sel_dirty:
609
+ self.write_selection_to_buffer()
610
+ self.canvas[2].clear()
611
+ self.draw_buffer()
612
+ if self.buffer_dirty:
613
+ self.buffer2data()
614
+
615
+ def sync_to_buffer(self):
616
+ if self.sel_dirty:
617
+ self.canvas[2].clear()
618
+ self.write_selection_to_buffer()
619
+ self.draw_buffer()
620
+
621
+ def resize(self,width,height,scale=None,**kwargs):
622
+ self.display_width=width
623
+ self.display_height=height
624
+ for canvas in self.canvas:
625
+ prepare_canvas(width=width,height=height,canvas=canvas.canvas)
626
+ setup_overlay(width,height)
627
+ if scale is None:
628
+ scale=1
629
+ self.update_scale(scale)
630
+
631
+
632
+ def save(self):
633
+ self.sync_to_data()
634
+ state={}
635
+ state["width"]=self.display_width
636
+ state["height"]=self.display_height
637
+ state["selection_width"]=self.selection_size_w
638
+ state["selection_height"]=self.selection_size_h
639
+ state["view_pos"]=self.view_pos[:]
640
+ state["cursor"]=self.cursor[:]
641
+ state["scale"]=self.scale
642
+ keys=list(self.data.keys())
643
+ data={}
644
+ for key in keys:
645
+ if self.data[key].sum()>0:
646
+ data[f"{key[0]},{key[1]}"]=self.numpy_to_base64(self.data[key])
647
+ state["data"]=data
648
+ return json.dumps(state)
649
+
650
+ def load(self, state_json):
651
+ self.reset()
652
+ state=json.loads(state_json)
653
+ self.display_width=state["width"]
654
+ self.display_height=state["height"]
655
+ self.selection_size_w=state["selection_width"]
656
+ self.selection_size_h=state["selection_height"]
657
+ self.view_pos=state["view_pos"][:]
658
+ self.cursor=state["cursor"][:]
659
+ self.scale=state["scale"]
660
+ self.resize(state["width"],state["height"],scale=state["scale"])
661
+ for k,v in state["data"].items():
662
+ key=tuple(map(int,k.split(",")))
663
+ self.data[key]=self.base64_to_numpy(v)
664
+ self.data2buffer()
665
+ self.display()
666
+
667
+ def display(self):
668
+ self.clear_background()
669
+ self.draw_buffer()
670
+ self.draw_selection_box()
671
+
672
+ def reset(self):
673
+ self.data.clear()
674
+ self.buffer*=0
675
+ self.buffer_dirty=False
676
+ self.buffer_updated=False
677
+ self.sel_buffer*=0
678
+ self.sel_dirty=False
679
+ self.view_pos = [0, 0]
680
+ self.clear_background()
681
+ for i in range(1,len(self.canvas)-1):
682
+ self.canvas[i].clear()
683
+
684
+ def export(self):
685
+ self.sync_to_data()
686
+ xmin, xmax, ymin, ymax = 0, 0, 0, 0
687
+ if len(self.data.keys()) == 0:
688
+ return np.zeros(
689
+ (self.selection_size_h, self.selection_size_w, 4), dtype=np.uint8
690
+ )
691
+ for xi, yi in self.data.keys():
692
+ buf = self.data[(xi, yi)]
693
+ if buf.sum() > 0:
694
+ xmin = min(xi, xmin)
695
+ xmax = max(xi, xmax)
696
+ ymin = min(yi, ymin)
697
+ ymax = max(yi, ymax)
698
+ yn = ymax - ymin + 1
699
+ xn = xmax - xmin + 1
700
+ image = np.zeros(
701
+ (yn * self.patch_size, xn * self.patch_size, 4), dtype=np.uint8
702
+ )
703
+ for xi, yi in self.data.keys():
704
+ buf = self.data[(xi, yi)]
705
+ if buf.sum() > 0:
706
+ y0 = (yi - ymin) * self.patch_size
707
+ x0 = (xi - xmin) * self.patch_size
708
+ image[y0 : y0 + self.patch_size, x0 : x0 + self.patch_size] = buf
709
+ ylst, xlst = image[:, :, -1].nonzero()
710
+ if len(ylst) > 0:
711
+ yt, xt = ylst.min(), xlst.min()
712
+ yb, xb = ylst.max(), xlst.max()
713
+ image = image[yt : yb + 1, xt : xb + 1]
714
+ return image
715
+ else:
716
+ return np.zeros(
717
+ (self.selection_size_h, self.selection_size_w, 4), dtype=np.uint8
718
+ )
config.yaml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ shortcut:
2
+ clear: Escape
3
+ load: Ctrl+o
4
+ save: Ctrl+s
5
+ export: Ctrl+e
6
+ upload: Ctrl+u
7
+ selection: 1
8
+ canvas: 2
9
+ eraser: 3
10
+ outpaint: d
11
+ accept: a
12
+ cancel: c
13
+ retry: r
14
+ prev: q
15
+ next: e
16
+ zoom_in: z
17
+ zoom_out: x
18
+ random_seed: s
convert_checkpoint.py ADDED
@@ -0,0 +1,706 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # https://github.com/huggingface/diffusers/blob/main/scripts/convert_original_stable_diffusion_to_diffusers.py
16
+ """ Conversion script for the LDM checkpoints. """
17
+
18
+ import argparse
19
+ import os
20
+
21
+ import torch
22
+
23
+
24
+ try:
25
+ from omegaconf import OmegaConf
26
+ except ImportError:
27
+ raise ImportError(
28
+ "OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`."
29
+ )
30
+
31
+ from diffusers import (
32
+ AutoencoderKL,
33
+ DDIMScheduler,
34
+ LDMTextToImagePipeline,
35
+ LMSDiscreteScheduler,
36
+ PNDMScheduler,
37
+ StableDiffusionPipeline,
38
+ UNet2DConditionModel,
39
+ )
40
+ from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
41
+ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
42
+ from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer
43
+
44
+
45
+ def shave_segments(path, n_shave_prefix_segments=1):
46
+ """
47
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
48
+ """
49
+ if n_shave_prefix_segments >= 0:
50
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
51
+ else:
52
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
53
+
54
+
55
+ def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
56
+ """
57
+ Updates paths inside resnets to the new naming scheme (local renaming)
58
+ """
59
+ mapping = []
60
+ for old_item in old_list:
61
+ new_item = old_item.replace("in_layers.0", "norm1")
62
+ new_item = new_item.replace("in_layers.2", "conv1")
63
+
64
+ new_item = new_item.replace("out_layers.0", "norm2")
65
+ new_item = new_item.replace("out_layers.3", "conv2")
66
+
67
+ new_item = new_item.replace("emb_layers.1", "time_emb_proj")
68
+ new_item = new_item.replace("skip_connection", "conv_shortcut")
69
+
70
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
71
+
72
+ mapping.append({"old": old_item, "new": new_item})
73
+
74
+ return mapping
75
+
76
+
77
+ def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
78
+ """
79
+ Updates paths inside resnets to the new naming scheme (local renaming)
80
+ """
81
+ mapping = []
82
+ for old_item in old_list:
83
+ new_item = old_item
84
+
85
+ new_item = new_item.replace("nin_shortcut", "conv_shortcut")
86
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
87
+
88
+ mapping.append({"old": old_item, "new": new_item})
89
+
90
+ return mapping
91
+
92
+
93
+ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
94
+ """
95
+ Updates paths inside attentions to the new naming scheme (local renaming)
96
+ """
97
+ mapping = []
98
+ for old_item in old_list:
99
+ new_item = old_item
100
+
101
+ # new_item = new_item.replace('norm.weight', 'group_norm.weight')
102
+ # new_item = new_item.replace('norm.bias', 'group_norm.bias')
103
+
104
+ # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
105
+ # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
106
+
107
+ # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
108
+
109
+ mapping.append({"old": old_item, "new": new_item})
110
+
111
+ return mapping
112
+
113
+
114
+ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
115
+ """
116
+ Updates paths inside attentions to the new naming scheme (local renaming)
117
+ """
118
+ mapping = []
119
+ for old_item in old_list:
120
+ new_item = old_item
121
+
122
+ new_item = new_item.replace("norm.weight", "group_norm.weight")
123
+ new_item = new_item.replace("norm.bias", "group_norm.bias")
124
+
125
+ new_item = new_item.replace("q.weight", "query.weight")
126
+ new_item = new_item.replace("q.bias", "query.bias")
127
+
128
+ new_item = new_item.replace("k.weight", "key.weight")
129
+ new_item = new_item.replace("k.bias", "key.bias")
130
+
131
+ new_item = new_item.replace("v.weight", "value.weight")
132
+ new_item = new_item.replace("v.bias", "value.bias")
133
+
134
+ new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
135
+ new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
136
+
137
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
138
+
139
+ mapping.append({"old": old_item, "new": new_item})
140
+
141
+ return mapping
142
+
143
+
144
+ def assign_to_checkpoint(
145
+ paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
146
+ ):
147
+ """
148
+ This does the final conversion step: take locally converted weights and apply a global renaming
149
+ to them. It splits attention layers, and takes into account additional replacements
150
+ that may arise.
151
+
152
+ Assigns the weights to the new checkpoint.
153
+ """
154
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
155
+
156
+ # Splits the attention layers into three variables.
157
+ if attention_paths_to_split is not None:
158
+ for path, path_map in attention_paths_to_split.items():
159
+ old_tensor = old_checkpoint[path]
160
+ channels = old_tensor.shape[0] // 3
161
+
162
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
163
+
164
+ num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
165
+
166
+ old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
167
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
168
+
169
+ checkpoint[path_map["query"]] = query.reshape(target_shape)
170
+ checkpoint[path_map["key"]] = key.reshape(target_shape)
171
+ checkpoint[path_map["value"]] = value.reshape(target_shape)
172
+
173
+ for path in paths:
174
+ new_path = path["new"]
175
+
176
+ # These have already been assigned
177
+ if attention_paths_to_split is not None and new_path in attention_paths_to_split:
178
+ continue
179
+
180
+ # Global renaming happens here
181
+ new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
182
+ new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
183
+ new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
184
+
185
+ if additional_replacements is not None:
186
+ for replacement in additional_replacements:
187
+ new_path = new_path.replace(replacement["old"], replacement["new"])
188
+
189
+ # proj_attn.weight has to be converted from conv 1D to linear
190
+ if "proj_attn.weight" in new_path:
191
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
192
+ else:
193
+ checkpoint[new_path] = old_checkpoint[path["old"]]
194
+
195
+
196
+ def conv_attn_to_linear(checkpoint):
197
+ keys = list(checkpoint.keys())
198
+ attn_keys = ["query.weight", "key.weight", "value.weight"]
199
+ for key in keys:
200
+ if ".".join(key.split(".")[-2:]) in attn_keys:
201
+ if checkpoint[key].ndim > 2:
202
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
203
+ elif "proj_attn.weight" in key:
204
+ if checkpoint[key].ndim > 2:
205
+ checkpoint[key] = checkpoint[key][:, :, 0]
206
+
207
+
208
+ def create_unet_diffusers_config(original_config):
209
+ """
210
+ Creates a config for the diffusers based on the config of the LDM model.
211
+ """
212
+ unet_params = original_config.model.params.unet_config.params
213
+
214
+ block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
215
+
216
+ down_block_types = []
217
+ resolution = 1
218
+ for i in range(len(block_out_channels)):
219
+ block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
220
+ down_block_types.append(block_type)
221
+ if i != len(block_out_channels) - 1:
222
+ resolution *= 2
223
+
224
+ up_block_types = []
225
+ for i in range(len(block_out_channels)):
226
+ block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
227
+ up_block_types.append(block_type)
228
+ resolution //= 2
229
+
230
+ config = dict(
231
+ sample_size=unet_params.image_size,
232
+ in_channels=unet_params.in_channels,
233
+ out_channels=unet_params.out_channels,
234
+ down_block_types=tuple(down_block_types),
235
+ up_block_types=tuple(up_block_types),
236
+ block_out_channels=tuple(block_out_channels),
237
+ layers_per_block=unet_params.num_res_blocks,
238
+ cross_attention_dim=unet_params.context_dim,
239
+ attention_head_dim=unet_params.num_heads,
240
+ )
241
+
242
+ return config
243
+
244
+
245
+ def create_vae_diffusers_config(original_config):
246
+ """
247
+ Creates a config for the diffusers based on the config of the LDM model.
248
+ """
249
+ vae_params = original_config.model.params.first_stage_config.params.ddconfig
250
+ _ = original_config.model.params.first_stage_config.params.embed_dim
251
+
252
+ block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
253
+ down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
254
+ up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
255
+
256
+ config = dict(
257
+ sample_size=vae_params.resolution,
258
+ in_channels=vae_params.in_channels,
259
+ out_channels=vae_params.out_ch,
260
+ down_block_types=tuple(down_block_types),
261
+ up_block_types=tuple(up_block_types),
262
+ block_out_channels=tuple(block_out_channels),
263
+ latent_channels=vae_params.z_channels,
264
+ layers_per_block=vae_params.num_res_blocks,
265
+ )
266
+ return config
267
+
268
+
269
+ def create_diffusers_schedular(original_config):
270
+ schedular = DDIMScheduler(
271
+ num_train_timesteps=original_config.model.params.timesteps,
272
+ beta_start=original_config.model.params.linear_start,
273
+ beta_end=original_config.model.params.linear_end,
274
+ beta_schedule="scaled_linear",
275
+ )
276
+ return schedular
277
+
278
+
279
+ def create_ldm_bert_config(original_config):
280
+ bert_params = original_config.model.parms.cond_stage_config.params
281
+ config = LDMBertConfig(
282
+ d_model=bert_params.n_embed,
283
+ encoder_layers=bert_params.n_layer,
284
+ encoder_ffn_dim=bert_params.n_embed * 4,
285
+ )
286
+ return config
287
+
288
+
289
+ def convert_ldm_unet_checkpoint(checkpoint, config):
290
+ """
291
+ Takes a state dict and a config, and returns a converted checkpoint.
292
+ """
293
+
294
+ # extract state_dict for UNet
295
+ unet_state_dict = {}
296
+ unet_key = "model.diffusion_model."
297
+ keys = list(checkpoint.keys())
298
+ for key in keys:
299
+ if key.startswith(unet_key):
300
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
301
+
302
+ new_checkpoint = {}
303
+
304
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
305
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
306
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
307
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
308
+
309
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
310
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
311
+
312
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
313
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
314
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
315
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
316
+
317
+ # Retrieves the keys for the input blocks only
318
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
319
+ input_blocks = {
320
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
321
+ for layer_id in range(num_input_blocks)
322
+ }
323
+
324
+ # Retrieves the keys for the middle blocks only
325
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
326
+ middle_blocks = {
327
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
328
+ for layer_id in range(num_middle_blocks)
329
+ }
330
+
331
+ # Retrieves the keys for the output blocks only
332
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
333
+ output_blocks = {
334
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
335
+ for layer_id in range(num_output_blocks)
336
+ }
337
+
338
+ for i in range(1, num_input_blocks):
339
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
340
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
341
+
342
+ resnets = [
343
+ key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
344
+ ]
345
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
346
+
347
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
348
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
349
+ f"input_blocks.{i}.0.op.weight"
350
+ )
351
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
352
+ f"input_blocks.{i}.0.op.bias"
353
+ )
354
+
355
+ paths = renew_resnet_paths(resnets)
356
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
357
+ assign_to_checkpoint(
358
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
359
+ )
360
+
361
+ if len(attentions):
362
+ paths = renew_attention_paths(attentions)
363
+ meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
364
+ assign_to_checkpoint(
365
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
366
+ )
367
+
368
+ resnet_0 = middle_blocks[0]
369
+ attentions = middle_blocks[1]
370
+ resnet_1 = middle_blocks[2]
371
+
372
+ resnet_0_paths = renew_resnet_paths(resnet_0)
373
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
374
+
375
+ resnet_1_paths = renew_resnet_paths(resnet_1)
376
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
377
+
378
+ attentions_paths = renew_attention_paths(attentions)
379
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
380
+ assign_to_checkpoint(
381
+ attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
382
+ )
383
+
384
+ for i in range(num_output_blocks):
385
+ block_id = i // (config["layers_per_block"] + 1)
386
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
387
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
388
+ output_block_list = {}
389
+
390
+ for layer in output_block_layers:
391
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
392
+ if layer_id in output_block_list:
393
+ output_block_list[layer_id].append(layer_name)
394
+ else:
395
+ output_block_list[layer_id] = [layer_name]
396
+
397
+ if len(output_block_list) > 1:
398
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
399
+ attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
400
+
401
+ resnet_0_paths = renew_resnet_paths(resnets)
402
+ paths = renew_resnet_paths(resnets)
403
+
404
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
405
+ assign_to_checkpoint(
406
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
407
+ )
408
+
409
+ if ["conv.weight", "conv.bias"] in output_block_list.values():
410
+ index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
411
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
412
+ f"output_blocks.{i}.{index}.conv.weight"
413
+ ]
414
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
415
+ f"output_blocks.{i}.{index}.conv.bias"
416
+ ]
417
+
418
+ # Clear attentions as they have been attributed above.
419
+ if len(attentions) == 2:
420
+ attentions = []
421
+
422
+ if len(attentions):
423
+ paths = renew_attention_paths(attentions)
424
+ meta_path = {
425
+ "old": f"output_blocks.{i}.1",
426
+ "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
427
+ }
428
+ assign_to_checkpoint(
429
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
430
+ )
431
+ else:
432
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
433
+ for path in resnet_0_paths:
434
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
435
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
436
+
437
+ new_checkpoint[new_path] = unet_state_dict[old_path]
438
+
439
+ return new_checkpoint
440
+
441
+
442
+ def convert_ldm_vae_checkpoint(checkpoint, config):
443
+ # extract state dict for VAE
444
+ vae_state_dict = {}
445
+ vae_key = "first_stage_model."
446
+ keys = list(checkpoint.keys())
447
+ for key in keys:
448
+ if key.startswith(vae_key):
449
+ vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
450
+
451
+ new_checkpoint = {}
452
+
453
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
454
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
455
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
456
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
457
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
458
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
459
+
460
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
461
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
462
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
463
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
464
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
465
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
466
+
467
+ new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
468
+ new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
469
+ new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
470
+ new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
471
+
472
+ # Retrieves the keys for the encoder down blocks only
473
+ num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
474
+ down_blocks = {
475
+ layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
476
+ }
477
+
478
+ # Retrieves the keys for the decoder up blocks only
479
+ num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
480
+ up_blocks = {
481
+ layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
482
+ }
483
+
484
+ for i in range(num_down_blocks):
485
+ resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
486
+
487
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
488
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
489
+ f"encoder.down.{i}.downsample.conv.weight"
490
+ )
491
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
492
+ f"encoder.down.{i}.downsample.conv.bias"
493
+ )
494
+
495
+ paths = renew_vae_resnet_paths(resnets)
496
+ meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
497
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
498
+
499
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
500
+ num_mid_res_blocks = 2
501
+ for i in range(1, num_mid_res_blocks + 1):
502
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
503
+
504
+ paths = renew_vae_resnet_paths(resnets)
505
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
506
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
507
+
508
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
509
+ paths = renew_vae_attention_paths(mid_attentions)
510
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
511
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
512
+ conv_attn_to_linear(new_checkpoint)
513
+
514
+ for i in range(num_up_blocks):
515
+ block_id = num_up_blocks - 1 - i
516
+ resnets = [
517
+ key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
518
+ ]
519
+
520
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
521
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
522
+ f"decoder.up.{block_id}.upsample.conv.weight"
523
+ ]
524
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
525
+ f"decoder.up.{block_id}.upsample.conv.bias"
526
+ ]
527
+
528
+ paths = renew_vae_resnet_paths(resnets)
529
+ meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
530
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
531
+
532
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
533
+ num_mid_res_blocks = 2
534
+ for i in range(1, num_mid_res_blocks + 1):
535
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
536
+
537
+ paths = renew_vae_resnet_paths(resnets)
538
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
539
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
540
+
541
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
542
+ paths = renew_vae_attention_paths(mid_attentions)
543
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
544
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
545
+ conv_attn_to_linear(new_checkpoint)
546
+ return new_checkpoint
547
+
548
+
549
+ def convert_ldm_bert_checkpoint(checkpoint, config):
550
+ def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
551
+ hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
552
+ hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
553
+ hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight
554
+
555
+ hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
556
+ hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias
557
+
558
+ def _copy_linear(hf_linear, pt_linear):
559
+ hf_linear.weight = pt_linear.weight
560
+ hf_linear.bias = pt_linear.bias
561
+
562
+ def _copy_layer(hf_layer, pt_layer):
563
+ # copy layer norms
564
+ _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])
565
+ _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])
566
+
567
+ # copy attn
568
+ _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])
569
+
570
+ # copy MLP
571
+ pt_mlp = pt_layer[1][1]
572
+ _copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
573
+ _copy_linear(hf_layer.fc2, pt_mlp.net[2])
574
+
575
+ def _copy_layers(hf_layers, pt_layers):
576
+ for i, hf_layer in enumerate(hf_layers):
577
+ if i != 0:
578
+ i += i
579
+ pt_layer = pt_layers[i : i + 2]
580
+ _copy_layer(hf_layer, pt_layer)
581
+
582
+ hf_model = LDMBertModel(config).eval()
583
+
584
+ # copy embeds
585
+ hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
586
+ hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight
587
+
588
+ # copy layer norm
589
+ _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)
590
+
591
+ # copy hidden layers
592
+ _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)
593
+
594
+ _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)
595
+
596
+ return hf_model
597
+
598
+
599
+ def convert_ldm_clip_checkpoint(checkpoint):
600
+ text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
601
+
602
+ keys = list(checkpoint.keys())
603
+
604
+ text_model_dict = {}
605
+
606
+ for key in keys:
607
+ if key.startswith("cond_stage_model.transformer"):
608
+ text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
609
+
610
+ text_model.load_state_dict(text_model_dict)
611
+
612
+ return text_model
613
+
614
+ import os
615
+ def convert_checkpoint(checkpoint_path, inpainting=False):
616
+ parser = argparse.ArgumentParser()
617
+
618
+ parser.add_argument(
619
+ "--checkpoint_path", default=checkpoint_path, type=str, help="Path to the checkpoint to convert."
620
+ )
621
+ # !wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml
622
+ parser.add_argument(
623
+ "--original_config_file",
624
+ default=None,
625
+ type=str,
626
+ help="The YAML config file corresponding to the original architecture.",
627
+ )
628
+ parser.add_argument(
629
+ "--scheduler_type",
630
+ default="pndm",
631
+ type=str,
632
+ help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim']",
633
+ )
634
+ parser.add_argument("--dump_path", default=None, type=str, help="Path to the output model.")
635
+
636
+ args = parser.parse_args([])
637
+ if args.original_config_file is None:
638
+ if inpainting:
639
+ args.original_config_file = "./models/v1-inpainting-inference.yaml"
640
+ else:
641
+ args.original_config_file = "./models/v1-inference.yaml"
642
+
643
+ original_config = OmegaConf.load(args.original_config_file)
644
+ checkpoint = torch.load(args.checkpoint_path)["state_dict"]
645
+
646
+ num_train_timesteps = original_config.model.params.timesteps
647
+ beta_start = original_config.model.params.linear_start
648
+ beta_end = original_config.model.params.linear_end
649
+ if args.scheduler_type == "pndm":
650
+ scheduler = PNDMScheduler(
651
+ beta_end=beta_end,
652
+ beta_schedule="scaled_linear",
653
+ beta_start=beta_start,
654
+ num_train_timesteps=num_train_timesteps,
655
+ skip_prk_steps=True,
656
+ )
657
+ elif args.scheduler_type == "lms":
658
+ scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear")
659
+ elif args.scheduler_type == "ddim":
660
+ scheduler = DDIMScheduler(
661
+ beta_start=beta_start,
662
+ beta_end=beta_end,
663
+ beta_schedule="scaled_linear",
664
+ clip_sample=False,
665
+ set_alpha_to_one=False,
666
+ )
667
+ else:
668
+ raise ValueError(f"Scheduler of type {args.scheduler_type} doesn't exist!")
669
+
670
+ # Convert the UNet2DConditionModel model.
671
+ unet_config = create_unet_diffusers_config(original_config)
672
+ converted_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config)
673
+
674
+ unet = UNet2DConditionModel(**unet_config)
675
+ unet.load_state_dict(converted_unet_checkpoint)
676
+
677
+ # Convert the VAE model.
678
+ vae_config = create_vae_diffusers_config(original_config)
679
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
680
+
681
+ vae = AutoencoderKL(**vae_config)
682
+ vae.load_state_dict(converted_vae_checkpoint)
683
+
684
+ # Convert the text model.
685
+ text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
686
+ if text_model_type == "FrozenCLIPEmbedder":
687
+ text_model = convert_ldm_clip_checkpoint(checkpoint)
688
+ tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
689
+ safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
690
+ feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
691
+ pipe = StableDiffusionPipeline(
692
+ vae=vae,
693
+ text_encoder=text_model,
694
+ tokenizer=tokenizer,
695
+ unet=unet,
696
+ scheduler=scheduler,
697
+ safety_checker=safety_checker,
698
+ feature_extractor=feature_extractor,
699
+ )
700
+ else:
701
+ text_config = create_ldm_bert_config(original_config)
702
+ text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
703
+ tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
704
+ pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
705
+
706
+ return pipe
css/w2ui.min.css ADDED
The diff for this file is too large to render. See raw diff
 
index.html ADDED
@@ -0,0 +1,843 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html>
3
+ <head>
4
+ <title>Stablediffusion Infinity</title>
5
+ <meta charset="utf-8">
6
+
7
+
8
+ <link rel="stylesheet" type="text/css" href="https://cdn.jsdelivr.net/gh/lkwq007/[email protected]/css/w2ui.min.css">
9
+ <script type="text/javascript" src="https://cdn.jsdelivr.net/gh/lkwq007/[email protected]/js/w2ui.min.js"></script>
10
+ <link rel="stylesheet" type="text/css" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.2.0/css/all.min.css">
11
+ <script src="https://cdn.jsdelivr.net/gh/lkwq007/[email protected]/js/fabric.min.js"></script>
12
+ <script defer src="https://cdn.jsdelivr.net/gh/lkwq007/[email protected]/js/toolbar.js"></script>
13
+ <link rel="stylesheet" href="https://pyscript.net/alpha/pyscript.css" />
14
+ <script defer src="https://pyscript.net/alpha/pyscript.js"></script>
15
+
16
+
17
+ <script src="https://www.gstatic.com/firebasejs/8.10.0/firebase-app.js"></script>
18
+ <script src="https://www.gstatic.com/firebasejs/8.10.0/firebase-analytics.js"></script>
19
+ <script src="https://www.gstatic.com/firebasejs/8.10.0/firebase-storage.js"></script>
20
+ <script src="https://www.gstatic.com/firebasejs/8.10.0/firebase-database.js"></script>
21
+
22
+
23
+ <style>
24
+ html, body {
25
+ width: 100%;
26
+ height: 100%;
27
+ margin: 0;
28
+ padding: 0;
29
+ overflow: hidden;
30
+ }
31
+
32
+
33
+ #container {
34
+ position: relative;
35
+ margin:auto;
36
+ display: block;
37
+ }
38
+ #container > canvas {
39
+ position: absolute;
40
+ top: 0;
41
+ left: 0;
42
+ }
43
+ .control {
44
+ display: none;
45
+ }
46
+ #outer_container {
47
+ width: 100%;
48
+ height: 100vh;
49
+ overflow: auto;
50
+ }
51
+
52
+ #hamburger-menu {
53
+ position: fixed;
54
+ top: 10px;
55
+ right: 10px;
56
+ width: 50px;
57
+ height: 50px;
58
+ background-color: #f1f1f1;
59
+ border-radius: 50%;
60
+ display: flex;
61
+ justify-content: center;
62
+ align-items: center;
63
+ cursor: pointer;
64
+ z-index: 1000;
65
+ overflow: hidden;
66
+ }
67
+
68
+ #hamburger-menu::before {
69
+ content: "";
70
+ position: absolute;
71
+ top: -50%;
72
+ left: -50%;
73
+ width: 200%;
74
+ height: 200%;
75
+ background-image: radial-gradient(circle, #00ff00, #00ffff, #ff00ff, #ff0000, #ffff00, #00ff00);
76
+ background-size: 300% 300%;
77
+ animation: gradient-animation 6s linear infinite;
78
+ z-index: -1;
79
+ }
80
+
81
+ #hamburger-menu i {
82
+ font-size: 24px;
83
+ position: relative;
84
+ z-index: 1;
85
+ }
86
+
87
+ .fa-bars {
88
+ position: relative;
89
+ display: inline-block;
90
+ width: 24px;
91
+ height: 2px;
92
+ background-color: currentColor;
93
+ transition: background-color 0.3s ease;
94
+ }
95
+
96
+ .fa-bars::before,
97
+ .fa-bars::after {
98
+ content: "";
99
+ position: absolute;
100
+ left: 0;
101
+ width: 100%;
102
+ height: 2px;
103
+ background-color: currentColor;
104
+ transition: transform 0.3s ease, opacity 0.3s ease;
105
+ }
106
+
107
+ .fa-bars::before {
108
+ top: -6px;
109
+ }
110
+
111
+ .fa-bars::after {
112
+ bottom: -6px;
113
+ }
114
+
115
+ .open .fa-bars {
116
+ background-color: transparent;
117
+ }
118
+
119
+ .open .fa-bars::before {
120
+ transform: translateY(6px) rotate(45deg);
121
+ }
122
+
123
+ .open .fa-bars::after {
124
+ transform: translateY(-6px) rotate(-45deg);
125
+ }
126
+
127
+ @keyframes gradient-animation {
128
+ 0% {
129
+ background-position: 0% 50%;
130
+ }
131
+ 50% {
132
+ background-position: 100% 50%;
133
+ }
134
+ 100% {
135
+ background-position: 0% 50%;
136
+ }
137
+ }
138
+ #toolbar {
139
+ display: none;
140
+ }
141
+
142
+
143
+
144
+ .generate-button {
145
+ background-color: #f1f1f1;
146
+ border: none;
147
+ color: #333;
148
+ padding: 10px 20px;
149
+ text-align: center;
150
+ text-decoration: none;
151
+ display: inline-block;
152
+ font-size: 16px;
153
+ margin: 4px 2px;
154
+ cursor: pointer;
155
+ border-radius: 4px;
156
+ }
157
+
158
+
159
+
160
+
161
+ </style>
162
+
163
+ </head>
164
+ <body>
165
+ <div>
166
+
167
+ <button type="button" class="control" id="export">Export</button>
168
+ <button type="button" class="control" id="undo">Undo</button>
169
+ <button type="button" class="control" id="commit">Commit</button>
170
+ <button type="button" class="control" id="transfer">Transfer</button>
171
+ <button type="button" class="control" id="upload">Upload</button>
172
+ <button type="button" class="control" id="draw">Draw</button>
173
+ <input type="text" id="mode" value="selection" class="control">
174
+ <input type="text" id="setup" value="0" class="control">
175
+ <input type="text" id="upload_content" value="0" class="control">
176
+ <textarea rows="1" id="selbuffer" name="selbuffer" class="control"></textarea>
177
+ <fieldset class="control">
178
+ <div>
179
+ <input type="radio" id="mode0" name="mode" value="0" checked>
180
+ <label for="mode0">SelBox</label>
181
+ </div>
182
+ <div>
183
+ <input type="radio" id="mode1" name="mode" value="1">
184
+ <label for="mode1">Image</label>
185
+ </div>
186
+ <div>
187
+ <input type="radio" id="mode2" name="mode" value="2">
188
+ <label for="mode2">Brush</label>
189
+ </div>
190
+ </fieldset>
191
+ </div>
192
+ <div id="hamburger-menu">
193
+ <i class="fa-solid fa-bars"></i>
194
+ </div>
195
+ <button type="button"id="outpaint">Outpaint</button>
196
+ <div id = "outer_container">
197
+ <div style="position: relative;">
198
+ <div id="toolbar" style></div>
199
+ </div>
200
+ <div id = "container">
201
+ <canvas id = "canvas0"></canvas>
202
+ <canvas id = "canvas1"></canvas>
203
+ <canvas id = "canvas2"></canvas>
204
+ <canvas id = "canvas3"></canvas>
205
+ <canvas id = "canvas4"></canvas>
206
+ <div id="overlay_container" style="pointer-events: none">
207
+ <canvas id = "overlay_canvas" width="1" height="1"></canvas>
208
+ </div>
209
+ </div>
210
+ <input type="file" name="file" id="upload_file" accept="image/*" hidden>
211
+ <input type="file" name="state" id="upload_state" accept=".sdinf" hidden>
212
+
213
+ </div>
214
+ </div>
215
+
216
+
217
+ <script>
218
+
219
+
220
+
221
+ alert("starting js");
222
+
223
+ function toggleToolbar() {
224
+ console.log("Hamburger menu button clicked");
225
+ const toolbar = document.getElementById("toolbar");
226
+ const hamburgerMenu = document.getElementById("hamburger-menu");
227
+ if (toolbar.style.display === "none" || toolbar.style.display === "") {
228
+ toolbar.style.display = "block";
229
+ hamburgerMenu.classList.add("open");
230
+ } else {
231
+ toolbar.style.display = "none";
232
+ hamburgerMenu.classList.remove("open");
233
+ }
234
+ }
235
+
236
+ function aws(name, x, y) {
237
+ return `coming from javascript ${name} ${x} ${y}`;
238
+ }
239
+
240
+
241
+
242
+
243
+ const { initializeApp } = firebase;
244
+
245
+ const { getStorage, ref, listAll, getDownloadURL, getMetadata, uploadBytesResumable } = firebase.storage;
246
+
247
+
248
+ const firebaseConfig = {
249
+ apiKey: "AIzaSyCxG7s_Wg6RAC4AQ5ZpkCgt0XcnSqcwt-A",
250
+ authDomain: "nyucapstone-7c22c.firebaseapp.com",
251
+ projectId: "nyucapstone-7c22c",
252
+ storageBucket: "nyucapstone-7c22c.appspot.com",
253
+ messagingSenderId: "658619789110",
254
+ appId: "1:658619789110:web:4eb43edacd4bbfcca74d97",
255
+ measurementId: "G-NCNE4TC0GC",
256
+ databaseURL: "https://nyucapstone-7c22c-default-rtdb.firebaseio.com/",
257
+ };
258
+
259
+
260
+
261
+ const fireapp = initializeApp(firebaseConfig);
262
+
263
+
264
+ function uploadImageToFirebase(base64_str, time_str) {
265
+ return new Promise((resolve, reject) => {
266
+ alert("starting to upload");
267
+ const atob = (str) => {
268
+ return window.atob(str);
269
+ };
270
+
271
+ const byteCharacters = atob(base64_str);
272
+ const byteNumbers = new Uint8Array(byteCharacters.length);
273
+ for (let i = 0; i < byteCharacters.length; i++) {
274
+ byteNumbers[i] = byteCharacters.charCodeAt(i);
275
+ }
276
+
277
+
278
+ const analytics = firebase.analytics();
279
+
280
+ const byteArray = new Uint8Array(byteNumbers);
281
+ const blob = new Blob([byteArray], {type: "image/png"});
282
+
283
+
284
+ const storage = firebase.storage(fireapp);
285
+
286
+
287
+ const storageRef = firebase.storage().ref(`images/${time_str}.png`);
288
+
289
+ const uploadTask = storageRef.put(blob);
290
+
291
+ alert("sucessful upload to firebae");
292
+ // Replace the successful upload handler with this:
293
+ uploadTask.on("state_changed", (snapshot) => {
294
+ // Handle the progress of the upload
295
+ }, (error) => {
296
+ // Handle the error during the upload
297
+ reject(error);
298
+ }, async () => {
299
+ // Handle the successful upload
300
+ const database = firebase.database();
301
+ const latestImageRef = database.ref("latestImage");
302
+ const downloadURL = await storageRef.getDownloadURL();
303
+ await latestImageRef.set({
304
+ fileName: `${time_str}.png`,
305
+ downloadURL: downloadURL
306
+ });
307
+ resolve();
308
+ });
309
+ });
310
+ }
311
+
312
+
313
+ document.getElementById("hamburger-menu").addEventListener("click", toggleToolbar);
314
+ alert("js loaded");
315
+ </script>
316
+
317
+ <py-env>
318
+ - numpy
319
+ - Pillow
320
+ - micropip:
321
+ - boto3
322
+ - paths:
323
+ - ./canvas.py
324
+ </py-env>
325
+
326
+
327
+ <py-script>
328
+ from pyodide import to_js, create_proxy
329
+ from PIL import Image
330
+ import io
331
+ import time
332
+ import base64
333
+ from collections import deque
334
+ import numpy as np
335
+ from js import (
336
+ console,
337
+ document,
338
+ parent,
339
+ devicePixelRatio,
340
+ ImageData,
341
+ Uint8ClampedArray,
342
+ CanvasRenderingContext2D as Context2d,
343
+ requestAnimationFrame,
344
+ window,
345
+ encodeURIComponent,
346
+ w2ui,
347
+ update_eraser,
348
+ update_scale,
349
+ adjust_selection,
350
+ update_count,
351
+ enable_result_lst,
352
+ setup_shortcut,
353
+ update_undo_redo,
354
+ alert,
355
+ uploadImageToFirebase,
356
+ firebase,
357
+ aws,
358
+ fetch
359
+ )
360
+ answer = aws("hello", 1, 2)
361
+ console.log(answer)
362
+
363
+ #addPhoto("demo")
364
+
365
+ # async def get_latest_image_from_firebase():
366
+ # alert("get_latest_image_from_firebase called")
367
+
368
+ # try:
369
+ # database = firebase.database()
370
+ # alert("try called")
371
+ # latestImageRef = database.ref("latestImage")
372
+ # latestImageSnapshot = await latestImageRef.once("value")
373
+ # latestImageInfo = latestImageSnapshot.val()
374
+
375
+
376
+ # download_url = latestImageInfo["downloadURL"]
377
+
378
+
379
+ # with pyodide.open_url(download_url) as f:
380
+
381
+ # img = Image.open(f)
382
+
383
+ # print("Downloaded image:", str(img))
384
+ # return img
385
+ # except Exception as e:
386
+ # print("Error while getting the latest image from Firebase:", str(e))
387
+ # return None
388
+
389
+ async def fetch_latest_image_url(database_url):
390
+ console.log("fetch_latest_image called")
391
+ # different methods to call
392
+ response = await fetch(f"{database_url}/latestImage.json")
393
+ console.log(f"response status: {response.status}, status text: {response.statusText}")
394
+
395
+ latest_image_data = await response.json()
396
+ latest_image_data = latest_image_data.to_py()
397
+
398
+ image_url = latest_image_data["downloadURL"]
399
+ image_name = latest_image_data["fileName"]
400
+ console.log(f"Latest image URL: {image_url}")
401
+ console.log(f"Latest image name: {image_name}")
402
+
403
+ # Fetch the image data as ArrayBuffer
404
+ image_response = await fetch(image_url)
405
+ image_data = await image_response.arrayBuffer()
406
+
407
+
408
+ return image_data, image_name
409
+
410
+
411
+ from canvas import InfCanvas
412
+
413
+
414
+ class History:
415
+ def __init__(self,maxlen=10):
416
+ self.idx=-1
417
+ self.undo_lst=deque([],maxlen=maxlen)
418
+ self.redo_lst=deque([],maxlen=maxlen)
419
+ self.state=None
420
+
421
+ def undo(self):
422
+ cur=None
423
+ if len(self.undo_lst):
424
+ cur=self.undo_lst.pop()
425
+ self.redo_lst.appendleft(cur)
426
+ return cur
427
+ def redo(self):
428
+ cur=None
429
+ if len(self.redo_lst):
430
+ cur=self.redo_lst.popleft()
431
+ self.undo_lst.append(cur)
432
+ return cur
433
+
434
+ def check(self):
435
+ return len(self.undo_lst)>0,len(self.redo_lst)>0
436
+
437
+ def append(self,state,update=True):
438
+ self.redo_lst.clear()
439
+ self.undo_lst.append(state)
440
+ if update:
441
+ update_undo_redo(*self.check())
442
+
443
+ history = History()
444
+
445
+ base_lst = [None]
446
+ async def draw_canvas() -> None:
447
+ alert("draw_canvas called")
448
+ width=1024
449
+ height=700
450
+ canvas=InfCanvas(1024,700)
451
+ update_eraser(canvas.eraser_size,min(canvas.selection_size_h,canvas.selection_size_w))
452
+ document.querySelector("#container").style.height= f"{height}px"
453
+ document.querySelector("#container").style.width = f"{width}px"
454
+ canvas.setup_mouse()
455
+ canvas.clear_background()
456
+ canvas.draw_buffer()
457
+ canvas.draw_selection_box()
458
+ base_lst[0]=canvas
459
+
460
+ # latest_image = await get_latest_image_from_firebase()
461
+
462
+ # if latest_image is not None:
463
+ # Log the URL of the latest image to the console
464
+ # console.log(f"Latest image URL: {latest_image.url}")
465
+ # Request the parent window to display the latest image on the canvas
466
+ # (commented out to fix the indentation error)
467
+ # window.parent.postMessage({ type: "displayLatestImageOnCanvas", image: latest_image }, "*")
468
+ # else:
469
+ # print("No latest image found in Firebase.")
470
+
471
+
472
+ async def draw_canvas_func(event):
473
+ alert("draw_canvas gradio called")
474
+ try:
475
+ app=parent.document.querySelector("gradio-app")
476
+ if app.shadowRoot:
477
+ app=app.shadowRoot
478
+ width=app.querySelector("#canvas_width input").value
479
+ height=app.querySelector("#canvas_height input").value
480
+ selection_size=app.querySelector("#selection_size input").value
481
+ except:
482
+ width=1024
483
+ height=768
484
+ selection_size=384
485
+ document.querySelector("#container").style.width = f"{width}px"
486
+ document.querySelector("#container").style.height= f"{height}px"
487
+
488
+ database_url = "https://nyucapstone-7c22c-default-rtdb.firebaseio.com"
489
+ image_data, latest_image_name = await fetch_latest_image_url(database_url)
490
+ pil_image = Image.open(io.BytesIO(image_data.to_py()))
491
+
492
+ np_image = np.array(pil_image)
493
+
494
+ canvas=InfCanvas(int(width),int(height),selection_size=int(selection_size),firebase_image_data=np_image)
495
+
496
+
497
+ canvas.setup_mouse()
498
+ canvas.clear_background()
499
+ canvas.draw_buffer()
500
+ canvas.draw_selection_box()
501
+
502
+ # await canvas.load_image(image_data)
503
+
504
+
505
+ # Update the canvas buffer with the new image data and redraw the buffer
506
+ h, w, c = canvas.buffer.shape
507
+ canvas.sync_to_buffer()
508
+ canvas.buffer_dirty = True
509
+
510
+ h_min = min(h, np_image.shape[0])
511
+ w_min = min(w, np_image.shape[1])
512
+
513
+
514
+
515
+ # mask = np_image[:, :, 3:4].repeat(4, axis=2)
516
+ # canvas.buffer[mask > 0] = 0
517
+ # canvas.buffer[0:h, 0:w, :] += np_image
518
+
519
+ mask = np_image[:h_min, :w_min, 3:4].repeat(4, axis=2)
520
+ canvas.buffer[:h_min, :w_min][mask > 0] = 0
521
+ canvas.buffer[:h_min, :w_min] += np_image[:h_min, :w_min]
522
+
523
+
524
+
525
+ canvas.draw_buffer()
526
+
527
+ base_lst[0]=canvas
528
+
529
+ alert("made it to end of draw_canvas gradio")
530
+
531
+
532
+ import js
533
+
534
+ async def export_func(event):
535
+ base = base_lst[0]
536
+
537
+ arr = base.export()
538
+ base.draw_buffer()
539
+ base.canvas[2].clear()
540
+ base64_str = base.numpy_to_base64(arr)
541
+ time_str = time.strftime("%Y%m%d_%H%M%S")
542
+
543
+ # The rest of the original export_func code
544
+ link = document.createElement("a")
545
+ if len(event.data) > 2 and event.data[2]:
546
+ filename = event.data[2]
547
+ else:
548
+ filename = f"outpaint_{time_str}"
549
+ link.download = f"{filename}.png"
550
+ link.href = "data:image/png;base64," + base64_str
551
+ link.click()
552
+ console.log(f"Canvas saved to {filename}.png")
553
+
554
+ img_candidate_lst=[None,0]
555
+
556
+ async def outpaint_func(event):
557
+ base=base_lst[0]
558
+ if len(event.data)==2:
559
+ app=parent.document.querySelector("gradio-app")
560
+ if app.shadowRoot:
561
+ app=app.shadowRoot
562
+ base64_str_raw=app.querySelector("#output textarea").value
563
+ base64_str_lst=base64_str_raw.split(",")
564
+ img_candidate_lst[0]=base64_str_lst
565
+ img_candidate_lst[1]=0
566
+ elif event.data[2]=="next":
567
+ img_candidate_lst[1]+=1
568
+ elif event.data[2]=="prev":
569
+ img_candidate_lst[1]-=1
570
+ enable_result_lst()
571
+ if img_candidate_lst[0] is None:
572
+ return
573
+ lst=img_candidate_lst[0]
574
+ idx=img_candidate_lst[1]
575
+ update_count(idx%len(lst)+1,len(lst))
576
+ arr=base.base64_to_numpy(lst[idx%len(lst)])
577
+ base.fill_selection(arr)
578
+ base.draw_selection_box()
579
+
580
+ async def undo_func(event):
581
+ base=base_lst[0]
582
+ img_candidate_lst[0]=None
583
+ if base.sel_dirty:
584
+ base.sel_buffer = np.zeros((base.selection_size_h, base.selection_size_w, 4), dtype=np.uint8)
585
+ base.sel_dirty = False
586
+ base.canvas[2].clear()
587
+
588
+ async def commit_func(event):
589
+ base = base_lst[0]
590
+ img_candidate_lst[0] = None
591
+ if base.sel_dirty:
592
+ base.write_selection_to_buffer()
593
+ base.draw_buffer()
594
+ base.canvas[2].clear()
595
+ if len(event.data) > 2:
596
+ history.append(base.save())
597
+
598
+ # sending the image to firebase here
599
+ arr = base.export()
600
+ base64_str = base.numpy_to_base64(arr)
601
+ time_str = time.strftime("%Y%m%d_%H%M%S")
602
+
603
+ # Call the JavaScript function to upload the image to Firebase storage
604
+ await js.uploadImageToFirebase(base64_str, time_str)
605
+
606
+
607
+ async def history_undo_func(event):
608
+ base=base_lst[0]
609
+ if base.buffer_dirty or len(history.redo_lst)>0:
610
+ state=history.undo()
611
+ else:
612
+ history.undo()
613
+ state=history.undo()
614
+ if state is not None:
615
+ base.load(state)
616
+ update_undo_redo(*history.check())
617
+
618
+ async def history_setup_func(event):
619
+ base=base_lst[0]
620
+ history.undo_lst.clear()
621
+ history.redo_lst.clear()
622
+ history.append(base.save(),update=False)
623
+
624
+ async def history_redo_func(event):
625
+ base=base_lst[0]
626
+ if len(history.undo_lst)>0:
627
+ state=history.redo()
628
+ else:
629
+ history.redo()
630
+ state=history.redo()
631
+ if state is not None:
632
+ base.load(state)
633
+ update_undo_redo(*history.check())
634
+
635
+
636
+ async def transfer_func(event):
637
+ base=base_lst[0]
638
+ base.read_selection_from_buffer()
639
+ sel_buffer=base.sel_buffer
640
+ sel_buffer_str=base.numpy_to_base64(sel_buffer)
641
+ app=parent.document.querySelector("gradio-app")
642
+ if app.shadowRoot:
643
+ app=app.shadowRoot
644
+ app.querySelector("#input textarea").value=sel_buffer_str
645
+ app.querySelector("#proceed").click()
646
+
647
+ async def upload_func(event):
648
+ base=base_lst[0]
649
+ # base64_str=event.data[1]
650
+ # Retrieve the base64 encoded image string from the #upload_content HTML element
651
+ base64_str=document.querySelector("#upload_content").value
652
+ base64_str=base64_str.split(",")[-1]
653
+ # base64_str=parent.document.querySelector("gradio-app").shadowRoot.querySelector("#upload textarea").value
654
+ arr=base.base64_to_numpy(base64_str)
655
+ h,w,c=base.buffer.shape
656
+ base.sync_to_buffer()
657
+ base.buffer_dirty=True
658
+ mask=arr[:,:,3:4].repeat(4,axis=2)
659
+ base.buffer[mask>0]=0
660
+ # in case mismatch
661
+ base.buffer[0:h,0:w,:]+=arr
662
+ #base.buffer[yo:yo+h,xo:xo+w,0:3]=arr[:,:,0:3]
663
+ #base.buffer[yo:yo+h,xo:xo+w,-1]=arr[:,:,-1]
664
+ base.draw_buffer()
665
+ if len(event.data)>2:
666
+ history.append(base.save())
667
+
668
+ async def setup_shortcut_func(event):
669
+ setup_shortcut(event.data[1])
670
+
671
+
672
+ document.querySelector("#export").addEventListener("click",create_proxy(export_func))
673
+ document.querySelector("#undo").addEventListener("click",create_proxy(undo_func))
674
+ document.querySelector("#commit").addEventListener("click",create_proxy(commit_func))
675
+ document.querySelector("#outpaint").addEventListener("click",create_proxy(outpaint_func))
676
+ document.querySelector("#upload").addEventListener("click",create_proxy(upload_func))
677
+
678
+ document.querySelector("#transfer").addEventListener("click",create_proxy(transfer_func))
679
+ document.querySelector("#draw").addEventListener("click",create_proxy(draw_canvas_func))
680
+
681
+ async def setup_func():
682
+ document.querySelector("#setup").value="1"
683
+
684
+ async def reset_func(event):
685
+ base=base_lst[0]
686
+ base.reset()
687
+
688
+ async def load_func(event):
689
+ base=base_lst[0]
690
+ base.load(event.data[1])
691
+
692
+ async def save_func(event):
693
+ base=base_lst[0]
694
+ json_str=base.save()
695
+ time_str = time.strftime("%Y%m%d_%H%M%S")
696
+ link = document.createElement("a")
697
+ if len(event.data)>2 and event.data[2]:
698
+ filename = str(event.data[2]).strip()
699
+ else:
700
+ filename = f"outpaint_{time_str}"
701
+ # link.download = f"sdinf_state_{time_str}.json"
702
+ link.download = f"{filename}.sdinf"
703
+ link.href = "data:text/json;charset=utf-8,"+encodeURIComponent(json_str)
704
+ link.click()
705
+
706
+ async def prev_result_func(event):
707
+ base=base_lst[0]
708
+ base.reset()
709
+
710
+ async def next_result_func(event):
711
+ base=base_lst[0]
712
+ base.reset()
713
+
714
+ async def zoom_in_func(event):
715
+ base=base_lst[0]
716
+ scale=base.scale
717
+ if scale>=0.2:
718
+ scale-=0.1
719
+ if len(event.data)>2:
720
+ base.update_scale(scale,int(event.data[2]),int(event.data[3]))
721
+ else:
722
+ base.update_scale(scale)
723
+ scale=base.scale
724
+ update_scale(f"{base.width}x{base.height} ({round(100/scale)}%)")
725
+
726
+ async def zoom_out_func(event):
727
+ base=base_lst[0]
728
+ scale=base.scale
729
+ if scale<10:
730
+ scale+=0.1
731
+ console.log(len(event.data))
732
+ if len(event.data)>2:
733
+ base.update_scale(scale,int(event.data[2]),int(event.data[3]))
734
+ else:
735
+ base.update_scale(scale)
736
+ scale=base.scale
737
+ update_scale(f"{base.width}x{base.height} ({round(100/scale)}%)")
738
+
739
+ async def sync_func(event):
740
+ base=base_lst[0]
741
+ base.sync_to_buffer()
742
+ base.canvas[2].clear()
743
+
744
+ async def eraser_size_func(event):
745
+ base=base_lst[0]
746
+ eraser_size=min(int(event.data[1]),min(base.selection_size_h,base.selection_size_w))
747
+ eraser_size=max(8,eraser_size)
748
+ base.eraser_size=eraser_size
749
+
750
+ async def resize_selection_func(event):
751
+ base=base_lst[0]
752
+ cursor=base.cursor
753
+ if len(event.data)>3:
754
+ console.log(event.data)
755
+ base.cursor[0]=int(event.data[1])
756
+ base.cursor[1]=int(event.data[2])
757
+ base.selection_size_w=int(event.data[3])//8*8
758
+ base.selection_size_h=int(event.data[4])//8*8
759
+ base.refine_selection()
760
+ base.draw_selection_box()
761
+ elif len(event.data)>2:
762
+ base.draw_selection_box()
763
+ else:
764
+ base.canvas[-1].clear()
765
+ adjust_selection(cursor[0],cursor[1],base.selection_size_w,base.selection_size_h)
766
+
767
+ async def eraser_func(event):
768
+ base=base_lst[0]
769
+ if event.data[1]!="eraser":
770
+ base.canvas[-2].clear()
771
+ else:
772
+ x,y=base.mouse_pos
773
+ base.draw_eraser(x,y)
774
+
775
+ async def resize_func(event):
776
+ base=base_lst[0]
777
+ width=int(event.data[1])
778
+ height=int(event.data[2])
779
+ if width>=256 and height>=256:
780
+ if max(base.selection_size_h,base.selection_size_w)>min(width,height):
781
+ base.selection_size_h=256
782
+ base.selection_size_w=256
783
+ base.resize(width,height)
784
+
785
+ async def message_func(event):
786
+ if event.data[0]=="click":
787
+ if event.data[1]=="clear":
788
+ await reset_func(event)
789
+ elif event.data[1]=="save":
790
+ await save_func(event)
791
+ elif event.data[1]=="export":
792
+ await export_func(event)
793
+ elif event.data[1]=="accept":
794
+ await commit_func(event)
795
+ elif event.data[1]=="cancel":
796
+ await undo_func(event)
797
+ elif event.data[1]=="zoom_in":
798
+ await zoom_in_func(event)
799
+ elif event.data[1]=="zoom_out":
800
+ await zoom_out_func(event)
801
+ elif event.data[1]=="redo":
802
+ await history_redo_func(event)
803
+ elif event.data[1]=="undo":
804
+ await history_undo_func(event)
805
+ elif event.data[1]=="history":
806
+ await history_setup_func(event)
807
+ elif event.data[0]=="sync":
808
+ await sync_func(event)
809
+ elif event.data[0]=="load":
810
+ await load_func(event)
811
+ elif event.data[0]=="upload":
812
+ await upload_func(event)
813
+ elif event.data[0]=="outpaint":
814
+ await outpaint_func(event)
815
+ elif event.data[0]=="mode":
816
+ if event.data[1]!="selection":
817
+ await sync_func(event)
818
+ await eraser_func(event)
819
+ document.querySelector("#mode").value=event.data[1]
820
+ elif event.data[0]=="transfer":
821
+ await transfer_func(event)
822
+ elif event.data[0]=="setup":
823
+ await draw_canvas_func(event)
824
+ elif event.data[0]=="eraser_size":
825
+ await eraser_size_func(event)
826
+ elif event.data[0]=="resize_selection":
827
+ await resize_selection_func(event)
828
+ elif event.data[0]=="shortcut":
829
+ await setup_shortcut_func(event)
830
+ elif event.data[0]=="resize":
831
+ await resize_func(event)
832
+
833
+ window.addEventListener("message",create_proxy(message_func))
834
+
835
+ import asyncio
836
+
837
+ _ = await asyncio.gather(
838
+ setup_func()
839
+ )
840
+ </py-script>
841
+
842
+ </body>
843
+ </html>
js/fabric.min.js ADDED
The diff for this file is too large to render. See raw diff
 
js/keyboard.js ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ window.my_setup_keyboard=setInterval(function(){
3
+ let app=document.querySelector("gradio-app");
4
+ app=app.shadowRoot??app;
5
+ let frame=app.querySelector("#sdinfframe").contentWindow;
6
+ console.log("Check iframe...");
7
+ if(frame.setup_shortcut)
8
+ {
9
+ frame.setup_shortcut(json);
10
+ clearInterval(window.my_setup_keyboard);
11
+ }
12
+ }, 1000);
13
+ var config=JSON.parse(json);
14
+ var key_map={};
15
+ Object.keys(config.shortcut).forEach(k=>{
16
+ key_map[config.shortcut[k]]=k;
17
+ });
18
+ document.addEventListener("keydown", e => {
19
+ if(e.target.tagName!="INPUT"&&e.target.tagName!="GRADIO-APP"&&e.target.tagName!="TEXTAREA")
20
+ {
21
+ let key=e.key;
22
+ if(e.ctrlKey)
23
+ {
24
+ key="Ctrl+"+e.key;
25
+ if(key in key_map)
26
+ {
27
+ e.preventDefault();
28
+ }
29
+ }
30
+ let app=document.querySelector("gradio-app");
31
+ app=app.shadowRoot??app;
32
+ let frame=app.querySelector("#sdinfframe").contentDocument;
33
+ frame.dispatchEvent(
34
+ new KeyboardEvent("keydown", {key: e.key, ctrlKey: e.ctrlKey})
35
+ );
36
+ }
37
+ })
js/mode.js ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ function(mode){
2
+ let app=document.querySelector("gradio-app").shadowRoot;
3
+ let frame=app.querySelector("#sdinfframe").contentWindow.document;
4
+ frame.querySelector("#mode").value=mode;
5
+ return mode;
6
+ }
js/outpaint.js ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ function(a){
2
+ if(!window.my_observe_outpaint)
3
+ {
4
+ console.log("setup outpaint here");
5
+ window.my_observe_outpaint = new MutationObserver(function (event) {
6
+ console.log(event);
7
+ let app=document.querySelector("gradio-app");
8
+ app=app.shadowRoot??app;
9
+ let frame=app.querySelector("#sdinfframe").contentWindow;
10
+ frame.postMessage(["outpaint", ""], "*");
11
+ });
12
+ var app=document.querySelector("gradio-app");
13
+ app=app.shadowRoot??app;
14
+ window.my_observe_outpaint_target=app.querySelector("#output span");
15
+ window.my_observe_outpaint.observe(window.my_observe_outpaint_target, {
16
+ attributes: false,
17
+ subtree: true,
18
+ childList: true,
19
+ characterData: true
20
+ });
21
+ }
22
+ return a;
23
+ }
js/proceed.js ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ function(sel_buffer_str,
2
+ prompt_text,
3
+ negative_prompt_text,
4
+ strength,
5
+ guidance,
6
+ step,
7
+ resize_check,
8
+ fill_mode,
9
+ enable_safety,
10
+ use_correction,
11
+ enable_img2img,
12
+ use_seed,
13
+ seed_val,
14
+ generate_num,
15
+ scheduler,
16
+ scheduler_eta,
17
+ state){
18
+ let app=document.querySelector("gradio-app");
19
+ app=app.shadowRoot??app;
20
+ sel_buffer=app.querySelector("#input textarea").value;
21
+ let use_correction_bak=false;
22
+ ({resize_check,enable_safety,use_correction_bak,enable_img2img,use_seed,seed_val}=window.config_obj);
23
+ return [
24
+ sel_buffer,
25
+ prompt_text,
26
+ negative_prompt_text,
27
+ strength,
28
+ guidance,
29
+ step,
30
+ resize_check,
31
+ fill_mode,
32
+ enable_safety,
33
+ use_correction,
34
+ enable_img2img,
35
+ use_seed,
36
+ seed_val,
37
+ generate_num,
38
+ scheduler,
39
+ scheduler_eta,
40
+ state,
41
+ ]
42
+ }
js/setup.js ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ function(token_val, width, height, size, model_choice, model_path){
2
+ let app=document.querySelector("gradio-app");
3
+ app=app.shadowRoot??app;
4
+ app.querySelector("#sdinfframe").style.height=80+Number(height)+"px";
5
+ // app.querySelector("#setup_row").style.display="none";
6
+ app.querySelector("#model_path_input").style.display="none";
7
+ let frame=app.querySelector("#sdinfframe").contentWindow.document;
8
+
9
+ if(frame.querySelector("#setup").value=="0")
10
+ {
11
+ window.my_setup=setInterval(function(){
12
+ let app=document.querySelector("gradio-app");
13
+ app=app.shadowRoot??app;
14
+ let frame=app.querySelector("#sdinfframe").contentWindow.document;
15
+ console.log("Check PyScript...")
16
+ if(frame.querySelector("#setup").value=="1")
17
+ {
18
+ frame.querySelector("#draw").click();
19
+ clearInterval(window.my_setup);
20
+ }
21
+ }, 100)
22
+ }
23
+ else
24
+ {
25
+ frame.querySelector("#draw").click();
26
+ }
27
+ return [token_val, width, height, size, model_choice, model_path];
28
+ }
js/toolbar.js ADDED
@@ -0,0 +1,581 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // import { w2ui,w2toolbar,w2field,query,w2alert, w2utils,w2confirm} from "https://rawgit.com/vitmalina/w2ui/master/dist/w2ui.es6.min.js"
2
+ // import { w2ui,w2toolbar,w2field,query,w2alert, w2utils,w2confirm} from "https://cdn.jsdelivr.net/gh/vitmalina/w2ui@master/dist/w2ui.es6.min.js"
3
+
4
+ // https://stackoverflow.com/questions/36280818/how-to-convert-file-to-base64-in-javascript
5
+ function getBase64(file) {
6
+ var reader = new FileReader();
7
+ reader.readAsDataURL(file);
8
+ reader.onload = function () {
9
+ add_image(reader.result);
10
+ // console.log(reader.result);
11
+ };
12
+ reader.onerror = function (error) {
13
+ console.log("Error: ", error);
14
+ };
15
+ }
16
+
17
+ function getText(file) {
18
+ var reader = new FileReader();
19
+ reader.readAsText(file);
20
+ reader.onload = function () {
21
+ window.postMessage(["load",reader.result],"*")
22
+ // console.log(reader.result);
23
+ };
24
+ reader.onerror = function (error) {
25
+ console.log("Error: ", error);
26
+ };
27
+ }
28
+
29
+ document.querySelector("#upload_file").addEventListener("change", (event)=>{
30
+ console.log(event);
31
+ let file = document.querySelector("#upload_file").files[0];
32
+ getBase64(file);
33
+ })
34
+
35
+ document.querySelector("#upload_state").addEventListener("change", (event)=>{
36
+ console.log(event);
37
+ let file = document.querySelector("#upload_state").files[0];
38
+ getText(file);
39
+ })
40
+
41
+ open_setting = function() {
42
+ if (!w2ui.foo) {
43
+ new w2form({
44
+ name: "foo",
45
+ style: "border: 0px; background-color: transparent;",
46
+ fields: [{
47
+ field: "canvas_width",
48
+ type: "int",
49
+ required: true,
50
+ html: {
51
+ label: "Canvas Width"
52
+ }
53
+ },
54
+ {
55
+ field: "canvas_height",
56
+ type: "int",
57
+ required: true,
58
+ html: {
59
+ label: "Canvas Height"
60
+ }
61
+ },
62
+ ],
63
+ record: {
64
+ canvas_width: 1200,
65
+ canvas_height: 600,
66
+ },
67
+ actions: {
68
+ Save() {
69
+ this.validate();
70
+ let record = this.getCleanRecord();
71
+ window.postMessage(["resize",record.canvas_width,record.canvas_height],"*");
72
+ w2popup.close();
73
+ },
74
+ custom: {
75
+ text: "Cancel",
76
+ style: "text-transform: uppercase",
77
+ onClick(event) {
78
+ w2popup.close();
79
+ }
80
+ }
81
+ }
82
+ });
83
+ }
84
+ w2popup.open({
85
+ title: "Form in a Popup",
86
+ body: "<div id='form' style='width: 100%; height: 100%;''></div>",
87
+ style: "padding: 15px 0px 0px 0px",
88
+ width: 500,
89
+ height: 280,
90
+ showMax: true,
91
+ async onToggle(event) {
92
+ await event.complete
93
+ w2ui.foo.resize();
94
+ }
95
+ })
96
+ .then((event) => {
97
+ w2ui.foo.render("#form")
98
+ });
99
+ }
100
+
101
+ var button_lst=["clear", "load", "save", "export", "upload", "selection", "canvas", "eraser", "outpaint", "accept", "cancel", "retry", "prev", "current", "next", "eraser_size_btn", "eraser_size", "resize_selection", "scale", "zoom_in", "zoom_out", "help"];
102
+ var upload_button_lst=['clear', 'load', 'save', "upload", 'export', 'outpaint', 'resize_selection', 'help', "setting"];
103
+ var resize_button_lst=['clear', 'load', 'save', "upload", 'export', "selection", "canvas", "eraser", 'outpaint', 'resize_selection',"zoom_in", "zoom_out", 'help', "setting"];
104
+ var outpaint_button_lst=['clear', 'load', 'save', "canvas", "eraser", "upload", 'export', 'resize_selection', "zoom_in", "zoom_out",'help', "setting"];
105
+ var outpaint_result_lst=["accept", "cancel", "retry", "prev", "current", "next"];
106
+ var outpaint_result_func_lst=["accept", "retry", "prev", "current", "next"];
107
+
108
+ function check_button(id,text="",checked=true,tooltip="")
109
+ {
110
+ return { type: "check", id: id, text: text, icon: checked?"fa-solid fa-square-check":"fa-regular fa-square", checked: checked, tooltip: tooltip };
111
+ }
112
+
113
+ var toolbar=new w2toolbar({
114
+ box: "#toolbar",
115
+ name: "toolbar",
116
+ tooltip: "top",
117
+ items: [
118
+ { type: "button", id: "clear", text: "Reset", tooltip: "Reset Canvas", icon: "fa-solid fa-rectangle-xmark" },
119
+ { type: "break" },
120
+ { type: "button", id: "load", tooltip: "Load Canvas", icon: "fa-solid fa-file-import" },
121
+ { type: "button", id: "save", tooltip: "Save Canvas", icon: "fa-solid fa-file-export" },
122
+ { type: "button", id: "export", tooltip: "Export Image", icon: "fa-solid fa-floppy-disk" },
123
+ { type: "break" },
124
+ { type: "button", id: "upload", text: "Upload Image", icon: "fa-solid fa-upload" },
125
+ { type: "break" },
126
+ { type: "radio", id: "selection", group: "1", tooltip: "Selection", icon: "fa-solid fa-arrows-up-down-left-right", checked: true },
127
+ { type: "radio", id: "canvas", group: "1", tooltip: "Canvas", icon: "fa-solid fa-image" },
128
+ { type: "radio", id: "eraser", group: "1", tooltip: "Eraser", icon: "fa-solid fa-eraser" },
129
+ { type: "break" },
130
+ { type: "button", id: "outpaint", text: "Outpaint", tooltip: "Run Outpainting", icon: "fa-solid fa-brush" },
131
+ { type: "break" },
132
+ { type: "button", id: "accept", text: "Accept", tooltip: "Accept current result", icon: "fa-solid fa-check", hidden: true, disable:true,},
133
+ { type: "button", id: "cancel", text: "Cancel", tooltip: "Cancel current outpainting/error", icon: "fa-solid fa-ban", hidden: true},
134
+ { type: "button", id: "retry", text: "Retry", tooltip: "Retry", icon: "fa-solid fa-rotate", hidden: true, disable:true,},
135
+ { type: "button", id: "prev", tooltip: "Prev Result", icon: "fa-solid fa-caret-left", hidden: true, disable:true,},
136
+ { type: "html", id: "current", hidden: true, disable:true,
137
+ async onRefresh(event) {
138
+ await event.complete
139
+ let fragment = query.html(`
140
+ <div class="w2ui-tb-text">
141
+ <div class="w2ui-tb-count">
142
+ <span>${this.sel_value ?? "1/1"}</span>
143
+ </div> </div>`)
144
+ query(this.box).find("#tb_toolbar_item_current").append(fragment)
145
+ }
146
+ },
147
+ { type: "button", id: "next", tooltip: "Next Result", icon: "fa-solid fa-caret-right", hidden: true,disable:true,},
148
+ { type: "button", id: "add_image", text: "Add Image", icon: "fa-solid fa-file-circle-plus", hidden: true,disable:true,},
149
+ { type: "button", id: "delete_image", text: "Delete Image", icon: "fa-solid fa-trash-can", hidden: true,disable:true,},
150
+ { type: "button", id: "confirm", text: "Confirm", icon: "fa-solid fa-check", hidden: true,disable:true,},
151
+ { type: "button", id: "cancel_overlay", text: "Cancel", icon: "fa-solid fa-ban", hidden: true,disable:true,},
152
+ { type: "break" },
153
+ { type: "spacer" },
154
+ { type: "break" },
155
+ { type: "button", id: "eraser_size_btn", tooltip: "Eraser Size", text:"Size", icon: "fa-solid fa-eraser", hidden: true, count: 32},
156
+ { type: "html", id: "eraser_size", hidden: true,
157
+ async onRefresh(event) {
158
+ await event.complete
159
+ // let fragment = query.html(`
160
+ // <input type="number" size="${this.eraser_size ? this.eraser_size.length:"2"}" style="margin: 0px 3px; padding: 4px;" min="8" max="${this.eraser_max ?? "256"}" value="${this.eraser_size ?? "32"}">
161
+ // <input type="range" style="margin: 0px 3px; padding: 4px;" min="8" max="${this.eraser_max ?? "256"}" value="${this.eraser_size ?? "32"}">`)
162
+ let fragment = query.html(`
163
+ <input type="range" style="margin: 0px 3px; padding: 4px;" min="8" max="${this.eraser_max ?? "256"}" value="${this.eraser_size ?? "32"}">
164
+ `)
165
+ fragment.filter("input").on("change", event => {
166
+ this.eraser_size = event.target.value;
167
+ window.overlay.freeDrawingBrush.width=this.eraser_size;
168
+ this.setCount("eraser_size_btn", event.target.value);
169
+ window.postMessage(["eraser_size", event.target.value],"*")
170
+ this.refresh();
171
+ })
172
+ query(this.box).find("#tb_toolbar_item_eraser_size").append(fragment)
173
+ }
174
+ },
175
+ // { type: "button", id: "resize_eraser", tooltip: "Resize Eraser", icon: "fa-solid fa-sliders" },
176
+ { type: "button", id: "resize_selection", text: "Resize Selection", tooltip: "Resize Selection", icon: "fa-solid fa-expand" },
177
+ { type: "break" },
178
+ { type: "html", id: "scale",
179
+ async onRefresh(event) {
180
+ await event.complete
181
+ let fragment = query.html(`
182
+ <div class="">
183
+ <div style="padding: 4px; border: 1px solid silver">
184
+ <span>${this.scale_value ?? "100%"}</span>
185
+ </div></div>`)
186
+ query(this.box).find("#tb_toolbar_item_scale").append(fragment)
187
+ }
188
+ },
189
+ { type: "button", id: "zoom_in", tooltip: "Zoom In", icon: "fa-solid fa-magnifying-glass-plus" },
190
+ { type: "button", id: "zoom_out", tooltip: "Zoom Out", icon: "fa-solid fa-magnifying-glass-minus" },
191
+ { type: "break" },
192
+ { type: "button", id: "help", tooltip: "Help", icon: "fa-solid fa-circle-info" },
193
+ { type: "new-line"},
194
+ { type: "button", id: "setting", text: "Canvas Setting", tooltip: "Resize Canvas Here", icon: "fa-solid fa-sliders" },
195
+ { type: "break" },
196
+ check_button("enable_img2img","Enable Img2Img",false),
197
+ // check_button("use_correction","Photometric Correction",false),
198
+ check_button("resize_check","Resize Small Input",true),
199
+ check_button("enable_safety","Enable Safety Checker",true),
200
+ check_button("square_selection","Square Selection Only",false),
201
+ {type: "break"},
202
+ check_button("use_seed","Use Seed:",false),
203
+ { type: "html", id: "seed_val",
204
+ async onRefresh(event) {
205
+ await event.complete
206
+ let fragment = query.html(`
207
+ <input type="number" style="margin: 0px 3px; padding: 4px; width:100px;" value="${this.config_obj.seed_val ?? "0"}">`)
208
+ fragment.filter("input").on("change", event => {
209
+ this.config_obj.seed_val = event.target.value;
210
+ parent.config_obj=this.config_obj;
211
+ this.refresh();
212
+ })
213
+ query(this.box).find("#tb_toolbar_item_seed_val").append(fragment)
214
+ }
215
+ },
216
+ { type: "button", id: "random_seed", tooltip: "Set a random seed", icon: "fa-solid fa-dice" },
217
+ ],
218
+ onClick(event) {
219
+ switch(event.target){
220
+ case "setting":
221
+ open_setting();
222
+ break;
223
+ case "upload":
224
+ this.upload_mode=true
225
+ document.querySelector("#overlay_container").style.pointerEvents="auto";
226
+ this.click("canvas");
227
+ this.click("selection");
228
+ this.show("confirm","cancel_overlay","add_image","delete_image");
229
+ this.enable("confirm","cancel_overlay","add_image","delete_image");
230
+ this.disable(...upload_button_lst);
231
+ query("#upload_file").click();
232
+ if(this.upload_tip)
233
+ {
234
+ this.upload_tip=false;
235
+ w2utils.notify("Note that only visible images will be added to canvas",{timeout:10000,where:query("#container")})
236
+ }
237
+ break;
238
+ case "resize_selection":
239
+ this.resize_mode=true;
240
+ this.disable(...resize_button_lst);
241
+ this.enable("confirm","cancel_overlay");
242
+ this.show("confirm","cancel_overlay");
243
+ window.postMessage(["resize_selection",""],"*");
244
+ document.querySelector("#overlay_container").style.pointerEvents="auto";
245
+ break;
246
+ case "confirm":
247
+ if(this.upload_mode)
248
+ {
249
+ export_image();
250
+ }
251
+ else
252
+ {
253
+ let sel_box=this.selection_box;
254
+ window.postMessage(["resize_selection",sel_box.x,sel_box.y,sel_box.width,sel_box.height],"*");
255
+ }
256
+ case "cancel_overlay":
257
+ end_overlay();
258
+ this.hide("confirm","cancel_overlay","add_image","delete_image");
259
+ if(this.upload_mode){
260
+ this.enable(...upload_button_lst);
261
+ }
262
+ else
263
+ {
264
+ this.enable(...resize_button_lst);
265
+ window.postMessage(["resize_selection","",""],"*");
266
+ if(event.target=="cancel_overlay")
267
+ {
268
+ this.selection_box=this.selection_box_bak;
269
+ }
270
+ }
271
+ if(this.selection_box)
272
+ {
273
+ this.setCount("resize_selection",`${Math.floor(this.selection_box.width/8)*8}x${Math.floor(this.selection_box.height/8)*8}`);
274
+ }
275
+ this.disable("confirm","cancel_overlay","add_image","delete_image");
276
+ this.upload_mode=false;
277
+ this.resize_mode=false;
278
+ this.click("selection");
279
+ break;
280
+ case "add_image":
281
+ query("#upload_file").click();
282
+ break;
283
+ case "delete_image":
284
+ let active_obj = window.overlay.getActiveObject();
285
+ if(active_obj)
286
+ {
287
+ window.overlay.remove(active_obj);
288
+ window.overlay.renderAll();
289
+ }
290
+ else
291
+ {
292
+ w2utils.notify("You need to select an image first",{error:true,timeout:2000,where:query("#container")})
293
+ }
294
+ break;
295
+ case "load":
296
+ query("#upload_state").click();
297
+ this.selection_box=null;
298
+ this.setCount("resize_selection","");
299
+ break;
300
+ case "next":
301
+ case "prev":
302
+ window.postMessage(["outpaint", "", event.target], "*");
303
+ break;
304
+ case "outpaint":
305
+ this.click("selection");
306
+ this.disable(...outpaint_button_lst);
307
+ this.show(...outpaint_result_lst);
308
+ if(this.outpaint_tip)
309
+ {
310
+ this.outpaint_tip=false;
311
+ w2utils.notify("The canvas stays locked until you accept/cancel current outpainting",{timeout:10000,where:query("#container")})
312
+ }
313
+ document.querySelector("#container").style.pointerEvents="none";
314
+ case "retry":
315
+ this.disable(...outpaint_result_func_lst);
316
+ window.postMessage(["transfer",""],"*")
317
+ break;
318
+ case "accept":
319
+ case "cancel":
320
+ this.hide(...outpaint_result_lst);
321
+ this.disable(...outpaint_result_func_lst);
322
+ this.enable(...outpaint_button_lst);
323
+ document.querySelector("#container").style.pointerEvents="auto";
324
+ window.postMessage(["click", event.target],"*");
325
+ let app=parent.document.querySelector("gradio-app");
326
+ app=app.shadowRoot??app;
327
+ app.querySelector("#cancel").click();
328
+ break;
329
+ case "eraser":
330
+ case "selection":
331
+ case "canvas":
332
+ if(event.target=="eraser")
333
+ {
334
+ this.show("eraser_size","eraser_size_btn");
335
+ window.overlay.freeDrawingBrush.width=this.eraser_size;
336
+ window.overlay.isDrawingMode = true;
337
+ }
338
+ else
339
+ {
340
+ this.hide("eraser_size","eraser_size_btn");
341
+ window.overlay.isDrawingMode = false;
342
+ }
343
+ if(this.upload_mode)
344
+ {
345
+ if(event.target=="canvas")
346
+ {
347
+ window.postMessage(["mode", event.target],"*")
348
+ document.querySelector("#overlay_container").style.pointerEvents="none";
349
+ document.querySelector("#overlay_container").style.opacity = 0.5;
350
+ }
351
+ else
352
+ {
353
+ document.querySelector("#overlay_container").style.pointerEvents="auto";
354
+ document.querySelector("#overlay_container").style.opacity = 1.0;
355
+ }
356
+ }
357
+ else
358
+ {
359
+ window.postMessage(["mode", event.target],"*")
360
+ }
361
+ break;
362
+ case "help":
363
+ w2popup.open({
364
+ title: "Document",
365
+ body: "Usage: <a href='https://github.com/lkwq007/stablediffusion-infinity/blob/master/docs/usage.md' target='_blank'>https://github.com/lkwq007/stablediffusion-infinity/blob/master/docs/usage.md</a>"
366
+ })
367
+ break;
368
+ case "clear":
369
+ w2confirm("Reset canvas?").yes(() => {
370
+ window.postMessage(["click", event.target],"*");
371
+ }).no(() => {})
372
+ break;
373
+ case "random_seed":
374
+ this.config_obj.seed_val=Math.floor(Math.random() * 3000000000);
375
+ parent.config_obj=this.config_obj;
376
+ this.refresh();
377
+ break;
378
+ case "enable_img2img":
379
+ case "use_correction":
380
+ case "resize_check":
381
+ case "enable_safety":
382
+ case "use_seed":
383
+ case "square_selection":
384
+ let target=this.get(event.target);
385
+ target.icon=target.checked?"fa-regular fa-square":"fa-solid fa-square-check";
386
+ this.config_obj[event.target]=!target.checked;
387
+ parent.config_obj=this.config_obj;
388
+ this.refresh();
389
+ break;
390
+ case "save":
391
+ case "export":
392
+ ask_filename(event.target);
393
+ break;
394
+ default:
395
+ // clear, save, export, outpaint, retry
396
+ // break, save, export, accept, retry, outpaint
397
+ window.postMessage(["click", event.target],"*")
398
+ }
399
+ console.log("Target: "+ event.target, event)
400
+ }
401
+ })
402
+ window.w2ui=w2ui;
403
+ w2ui.toolbar.config_obj={
404
+ resize_check: true,
405
+ enable_safety: true,
406
+ use_correction: false,
407
+ enable_img2img: false,
408
+ use_seed: false,
409
+ seed_val: 0,
410
+ square_selection: false,
411
+ };
412
+ w2ui.toolbar.outpaint_tip=true;
413
+ w2ui.toolbar.upload_tip=true;
414
+ window.update_count=function(cur,total){
415
+ w2ui.toolbar.sel_value=`${cur}/${total}`;
416
+ w2ui.toolbar.refresh();
417
+ }
418
+ window.update_eraser=function(val,max_val){
419
+ w2ui.toolbar.eraser_size=`${val}`;
420
+ w2ui.toolbar.eraser_max=`${max_val}`;
421
+ w2ui.toolbar.setCount("eraser_size_btn", `${val}`);
422
+ w2ui.toolbar.refresh();
423
+ }
424
+ window.update_scale=function(val){
425
+ w2ui.toolbar.scale_value=`${val}`;
426
+ w2ui.toolbar.refresh();
427
+ }
428
+ window.enable_result_lst=function(){
429
+ w2ui.toolbar.enable(...outpaint_result_lst);
430
+ }
431
+ function onObjectScaled(e)
432
+ {
433
+ let object = e.target;
434
+ if(object.isType("rect"))
435
+ {
436
+ let width=object.getScaledWidth();
437
+ let height=object.getScaledHeight();
438
+ object.scale(1);
439
+ width=Math.max(Math.min(width,window.overlay.width-object.left),256);
440
+ height=Math.max(Math.min(height,window.overlay.height-object.top),256);
441
+ let l=Math.max(Math.min(object.left,window.overlay.width-width-object.strokeWidth),0);
442
+ let t=Math.max(Math.min(object.top,window.overlay.height-height-object.strokeWidth),0);
443
+ if(window.w2ui.toolbar.config_obj.square_selection)
444
+ {
445
+ let max_val = Math.min(Math.max(width,height),window.overlay.width,window.overlay.height);
446
+ width=max_val;
447
+ height=max_val;
448
+ }
449
+ object.set({ width: width, height: height, left:l,top:t})
450
+ window.w2ui.toolbar.selection_box={width: width, height: height, x:object.left, y:object.top};
451
+ window.w2ui.toolbar.setCount("resize_selection",`${Math.floor(width/8)*8}x${Math.floor(height/8)*8}`);
452
+ window.w2ui.toolbar.refresh();
453
+ }
454
+ }
455
+ function onObjectMoved(e)
456
+ {
457
+ let object = e.target;
458
+ if(object.isType("rect"))
459
+ {
460
+ let l=Math.max(Math.min(object.left,window.overlay.width-object.width-object.strokeWidth),0);
461
+ let t=Math.max(Math.min(object.top,window.overlay.height-object.height-object.strokeWidth),0);
462
+ object.set({left:l,top:t});
463
+ window.w2ui.toolbar.selection_box={width: object.width, height: object.height, x:object.left, y:object.top};
464
+ }
465
+ }
466
+ window.setup_overlay=function(width,height)
467
+ {
468
+ if(window.overlay)
469
+ {
470
+ window.overlay.setDimensions({width:width,height:height});
471
+ let app=parent.document.querySelector("gradio-app");
472
+ app=app.shadowRoot??app;
473
+ app.querySelector("#sdinfframe").style.height=80+Number(height)+"px";
474
+ document.querySelector("#container").style.height= height+"px";
475
+ document.querySelector("#container").style.width = width+"px";
476
+ }
477
+ else
478
+ {
479
+ canvas=new fabric.Canvas("overlay_canvas");
480
+ canvas.setDimensions({width:width,height:height});
481
+ let app=parent.document.querySelector("gradio-app");
482
+ app=app.shadowRoot??app;
483
+ app.querySelector("#sdinfframe").style.height=80+Number(height)+"px";
484
+ canvas.freeDrawingBrush = new fabric.EraserBrush(canvas);
485
+ canvas.on("object:scaling", onObjectScaled);
486
+ canvas.on("object:moving", onObjectMoved);
487
+ window.overlay=canvas;
488
+ }
489
+ document.querySelector("#overlay_container").style.pointerEvents="none";
490
+ }
491
+ window.update_overlay=function(width,height)
492
+ {
493
+ window.overlay.setDimensions({width:width,height:height},{backstoreOnly:true});
494
+ // document.querySelector("#overlay_container").style.pointerEvents="none";
495
+ }
496
+ window.adjust_selection=function(x,y,width,height)
497
+ {
498
+ var rect = new fabric.Rect({
499
+ left: x,
500
+ top: y,
501
+ fill: "rgba(0,0,0,0)",
502
+ strokeWidth: 3,
503
+ stroke: "rgba(0,0,0,0.7)",
504
+ cornerColor: "red",
505
+ cornerStrokeColor: "red",
506
+ borderColor: "rgba(255, 0, 0, 1.0)",
507
+ width: width,
508
+ height: height,
509
+ lockRotation: true,
510
+ });
511
+ rect.setControlsVisibility({ mtr: false });
512
+ window.overlay.add(rect);
513
+ window.overlay.setActiveObject(window.overlay.item(0));
514
+ window.w2ui.toolbar.selection_box={width: width, height: height, x:x, y:y};
515
+ window.w2ui.toolbar.selection_box_bak={width: width, height: height, x:x, y:y};
516
+ }
517
+ function add_image(url)
518
+ {
519
+ fabric.Image.fromURL(url,function(img){
520
+ window.overlay.add(img);
521
+ window.overlay.setActiveObject(img);
522
+ },{left:100,top:100});
523
+ }
524
+ function export_image()
525
+ {
526
+ data=window.overlay.toDataURL();
527
+ document.querySelector("#upload_content").value=data;
528
+ window.postMessage(["upload",""],"*");
529
+ end_overlay();
530
+ }
531
+ function end_overlay()
532
+ {
533
+ window.overlay.clear();
534
+ document.querySelector("#overlay_container").style.opacity = 1.0;
535
+ document.querySelector("#overlay_container").style.pointerEvents="none";
536
+ }
537
+ function ask_filename(target)
538
+ {
539
+ w2prompt({
540
+ label: "Enter filename",
541
+ value: `outpaint_${((new Date(Date.now() -(new Date()).getTimezoneOffset() * 60000))).toISOString().replace("T","_").replace(/[^0-9_]/g, "").substring(0,15)}`,
542
+ })
543
+ .change((event) => {
544
+ console.log("change", event.detail.originalEvent.target.value);
545
+ })
546
+ .ok((event) => {
547
+ console.log("value=", event.detail.value);
548
+ window.postMessage(["click",target,event.detail.value],"*");
549
+ })
550
+ .cancel((event) => {
551
+ console.log("cancel");
552
+ });
553
+ }
554
+
555
+ document.querySelector("#container").addEventListener("wheel",(e)=>{e.preventDefault()})
556
+ window.setup_shortcut=function(json)
557
+ {
558
+ var config=JSON.parse(json);
559
+ var key_map={};
560
+ Object.keys(config.shortcut).forEach(k=>{
561
+ key_map[config.shortcut[k]]=k;
562
+ })
563
+ document.addEventListener("keydown",(e)=>{
564
+ if(e.target.tagName!="INPUT")
565
+ {
566
+ let key=e.key;
567
+ if(e.ctrlKey)
568
+ {
569
+ key="Ctrl+"+e.key;
570
+ if(key in key_map)
571
+ {
572
+ e.preventDefault();
573
+ }
574
+ }
575
+ if(key in key_map)
576
+ {
577
+ w2ui.toolbar.click(key_map[key]);
578
+ }
579
+ }
580
+ })
581
+ }
js/upload.js ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ function(a,b){
2
+ if(!window.my_observe_upload)
3
+ {
4
+ console.log("setup upload here");
5
+ window.my_observe_upload = new MutationObserver(function (event) {
6
+ console.log(event);
7
+ var frame=document.querySelector("gradio-app").shadowRoot.querySelector("#sdinfframe").contentWindow.document;
8
+ frame.querySelector("#upload").click();
9
+ });
10
+ window.my_observe_upload_target = document.querySelector("gradio-app").shadowRoot.querySelector("#upload span");
11
+ window.my_observe_upload.observe(window.my_observe_upload_target, {
12
+ attributes: false,
13
+ subtree: true,
14
+ childList: true,
15
+ characterData: true
16
+ });
17
+ }
18
+ return [a,b];
19
+ }
js/w2ui.min.js ADDED
The diff for this file is too large to render. See raw diff
 
js/xss.js ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ var setup_outpaint=function(){
2
+ if(!window.my_observe_outpaint)
3
+ {
4
+ console.log("setup outpaint here");
5
+ window.my_observe_outpaint = new MutationObserver(function (event) {
6
+ console.log(event);
7
+ let app=document.querySelector("gradio-app");
8
+ app=app.shadowRoot??app;
9
+ let frame=app.querySelector("#sdinfframe").contentWindow;
10
+ frame.postMessage(["outpaint", ""], "*");
11
+ });
12
+ var app=document.querySelector("gradio-app");
13
+ app=app.shadowRoot??app;
14
+ window.my_observe_outpaint_target=app.querySelector("#output span");
15
+ window.my_observe_outpaint.observe(window.my_observe_outpaint_target, {
16
+ attributes: false,
17
+ subtree: true,
18
+ childList: true,
19
+ characterData: true
20
+ });
21
+ }
22
+ };
23
+ window.config_obj={
24
+ resize_check: true,
25
+ enable_safety: true,
26
+ use_correction: false,
27
+ enable_img2img: false,
28
+ use_seed: false,
29
+ seed_val: 0,
30
+ };
31
+ setup_outpaint();
models/v1-inference.yaml ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 1.0e-04
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.00085
6
+ linear_end: 0.0120
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: "jpg"
11
+ cond_stage_key: "txt"
12
+ image_size: 64
13
+ channels: 4
14
+ cond_stage_trainable: false # Note: different from the one we trained before
15
+ conditioning_key: crossattn
16
+ monitor: val/loss_simple_ema
17
+ scale_factor: 0.18215
18
+ use_ema: False
19
+
20
+ scheduler_config: # 10000 warmup steps
21
+ target: ldm.lr_scheduler.LambdaLinearScheduler
22
+ params:
23
+ warm_up_steps: [ 10000 ]
24
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
25
+ f_start: [ 1.e-6 ]
26
+ f_max: [ 1. ]
27
+ f_min: [ 1. ]
28
+
29
+ unet_config:
30
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31
+ params:
32
+ image_size: 32 # unused
33
+ in_channels: 4
34
+ out_channels: 4
35
+ model_channels: 320
36
+ attention_resolutions: [ 4, 2, 1 ]
37
+ num_res_blocks: 2
38
+ channel_mult: [ 1, 2, 4, 4 ]
39
+ num_heads: 8
40
+ use_spatial_transformer: True
41
+ transformer_depth: 1
42
+ context_dim: 768
43
+ use_checkpoint: True
44
+ legacy: False
45
+
46
+ first_stage_config:
47
+ target: ldm.models.autoencoder.AutoencoderKL
48
+ params:
49
+ embed_dim: 4
50
+ monitor: val/rec_loss
51
+ ddconfig:
52
+ double_z: true
53
+ z_channels: 4
54
+ resolution: 256
55
+ in_channels: 3
56
+ out_ch: 3
57
+ ch: 128
58
+ ch_mult:
59
+ - 1
60
+ - 2
61
+ - 4
62
+ - 4
63
+ num_res_blocks: 2
64
+ attn_resolutions: []
65
+ dropout: 0.0
66
+ lossconfig:
67
+ target: torch.nn.Identity
68
+
69
+ cond_stage_config:
70
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
models/v1-inpainting-inference.yaml ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 7.5e-05
3
+ target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
4
+ params:
5
+ linear_start: 0.00085
6
+ linear_end: 0.0120
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: "jpg"
11
+ cond_stage_key: "txt"
12
+ image_size: 64
13
+ channels: 4
14
+ cond_stage_trainable: false # Note: different from the one we trained before
15
+ conditioning_key: hybrid # important
16
+ monitor: val/loss_simple_ema
17
+ scale_factor: 0.18215
18
+ finetune_keys: null
19
+
20
+ scheduler_config: # 10000 warmup steps
21
+ target: ldm.lr_scheduler.LambdaLinearScheduler
22
+ params:
23
+ warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch
24
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
25
+ f_start: [ 1.e-6 ]
26
+ f_max: [ 1. ]
27
+ f_min: [ 1. ]
28
+
29
+ unet_config:
30
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31
+ params:
32
+ image_size: 32 # unused
33
+ in_channels: 9 # 4 data + 4 downscaled image + 1 mask
34
+ out_channels: 4
35
+ model_channels: 320
36
+ attention_resolutions: [ 4, 2, 1 ]
37
+ num_res_blocks: 2
38
+ channel_mult: [ 1, 2, 4, 4 ]
39
+ num_heads: 8
40
+ use_spatial_transformer: True
41
+ transformer_depth: 1
42
+ context_dim: 768
43
+ use_checkpoint: True
44
+ legacy: False
45
+
46
+ first_stage_config:
47
+ target: ldm.models.autoencoder.AutoencoderKL
48
+ params:
49
+ embed_dim: 4
50
+ monitor: val/rec_loss
51
+ ddconfig:
52
+ double_z: true
53
+ z_channels: 4
54
+ resolution: 256
55
+ in_channels: 3
56
+ out_ch: 3
57
+ ch: 128
58
+ ch_mult:
59
+ - 1
60
+ - 2
61
+ - 4
62
+ - 4
63
+ num_res_blocks: 2
64
+ attn_resolutions: []
65
+ dropout: 0.0
66
+ lossconfig:
67
+ target: torch.nn.Identity
68
+
69
+ cond_stage_config:
70
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
packages.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ build-essential
2
+ python3-opencv
3
+ libopencv-dev
4
+ cmake
perlin2d.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ ##########
4
+ # https://stackoverflow.com/questions/42147776/producing-2d-perlin-noise-with-numpy/42154921#42154921
5
+ def perlin(x, y, seed=0):
6
+ # permutation table
7
+ np.random.seed(seed)
8
+ p = np.arange(256, dtype=int)
9
+ np.random.shuffle(p)
10
+ p = np.stack([p, p]).flatten()
11
+ # coordinates of the top-left
12
+ xi, yi = x.astype(int), y.astype(int)
13
+ # internal coordinates
14
+ xf, yf = x - xi, y - yi
15
+ # fade factors
16
+ u, v = fade(xf), fade(yf)
17
+ # noise components
18
+ n00 = gradient(p[p[xi] + yi], xf, yf)
19
+ n01 = gradient(p[p[xi] + yi + 1], xf, yf - 1)
20
+ n11 = gradient(p[p[xi + 1] + yi + 1], xf - 1, yf - 1)
21
+ n10 = gradient(p[p[xi + 1] + yi], xf - 1, yf)
22
+ # combine noises
23
+ x1 = lerp(n00, n10, u)
24
+ x2 = lerp(n01, n11, u) # FIX1: I was using n10 instead of n01
25
+ return lerp(x1, x2, v) # FIX2: I also had to reverse x1 and x2 here
26
+
27
+
28
+ def lerp(a, b, x):
29
+ "linear interpolation"
30
+ return a + x * (b - a)
31
+
32
+
33
+ def fade(t):
34
+ "6t^5 - 15t^4 + 10t^3"
35
+ return 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3
36
+
37
+
38
+ def gradient(h, x, y):
39
+ "grad converts h to the right gradient vector and return the dot product with (x,y)"
40
+ vectors = np.array([[0, 1], [0, -1], [1, 0], [-1, 0]])
41
+ g = vectors[h % 4]
42
+ return g[:, :, 0] * x + g[:, :, 1] * y
43
+
44
+
45
+ ##########
postprocess.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ https://github.com/Trinkle23897/Fast-Poisson-Image-Editing
3
+ MIT License
4
+
5
+ Copyright (c) 2022 Jiayi Weng
6
+
7
+ Permission is hereby granted, free of charge, to any person obtaining a copy
8
+ of this software and associated documentation files (the "Software"), to deal
9
+ in the Software without restriction, including without limitation the rights
10
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
+ copies of the Software, and to permit persons to whom the Software is
12
+ furnished to do so, subject to the following conditions:
13
+
14
+ The above copyright notice and this permission notice shall be included in all
15
+ copies or substantial portions of the Software.
16
+
17
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
+ SOFTWARE.
24
+ """
25
+
26
+ import time
27
+ import argparse
28
+ import os
29
+ import fpie
30
+ from process import ALL_BACKEND, CPU_COUNT, DEFAULT_BACKEND
31
+ from fpie.io import read_images, write_image
32
+ from process import BaseProcessor, EquProcessor, GridProcessor
33
+
34
+ from PIL import Image
35
+ import numpy as np
36
+ import skimage
37
+ import skimage.measure
38
+ import scipy
39
+ import scipy.signal
40
+
41
+
42
+ class PhotometricCorrection:
43
+ def __init__(self,quite=False):
44
+ self.get_parser("cli")
45
+ args=self.parser.parse_args(["--method","grid","-g","src","-s","a","-t","a","-o","a"])
46
+ args.mpi_sync_interval = getattr(args, "mpi_sync_interval", 0)
47
+ self.backend=args.backend
48
+ self.args=args
49
+ self.quite=quite
50
+ proc: BaseProcessor
51
+ proc = GridProcessor(
52
+ args.gradient,
53
+ args.backend,
54
+ args.cpu,
55
+ args.mpi_sync_interval,
56
+ args.block_size,
57
+ args.grid_x,
58
+ args.grid_y,
59
+ )
60
+ print(
61
+ f"[PIE]Successfully initialize PIE {args.method} solver "
62
+ f"with {args.backend} backend"
63
+ )
64
+ self.proc=proc
65
+
66
+ def run(self, original_image, inpainted_image, mode="mask_mode"):
67
+ print(f"[PIE] start")
68
+ if mode=="disabled":
69
+ return inpainted_image
70
+ input_arr=np.array(original_image)
71
+ if input_arr[:,:,-1].sum()<1:
72
+ return inpainted_image
73
+ output_arr=np.array(inpainted_image)
74
+ mask=input_arr[:,:,-1]
75
+ mask=255-mask
76
+ if mask.sum()<1 and mode=="mask_mode":
77
+ mode=""
78
+ if mode=="mask_mode":
79
+ mask = skimage.measure.block_reduce(mask, (8, 8), np.max)
80
+ mask = mask.repeat(8, axis=0).repeat(8, axis=1)
81
+ else:
82
+ mask[8:-9,8:-9]=255
83
+ mask = mask[:,:,np.newaxis].repeat(3,axis=2)
84
+ nmask=mask.copy()
85
+ output_arr2=output_arr[:,:,0:3].copy()
86
+ input_arr2=input_arr[:,:,0:3].copy()
87
+ output_arr2[nmask<128]=0
88
+ input_arr2[nmask>=128]=0
89
+ output_arr2+=input_arr2
90
+ src = output_arr2[:,:,0:3]
91
+ tgt = src.copy()
92
+ proc=self.proc
93
+ args=self.args
94
+ if proc.root:
95
+ n = proc.reset(src, mask, tgt, (args.h0, args.w0), (args.h1, args.w1))
96
+ proc.sync()
97
+ if proc.root:
98
+ result = tgt
99
+ t = time.time()
100
+ if args.p == 0:
101
+ args.p = args.n
102
+
103
+ for i in range(0, args.n, args.p):
104
+ if proc.root:
105
+ result, err = proc.step(args.p) # type: ignore
106
+ print(f"[PIE] Iter {i + args.p}, abs_err {err}")
107
+ else:
108
+ proc.step(args.p)
109
+
110
+ if proc.root:
111
+ dt = time.time() - t
112
+ print(f"[PIE] Time elapsed: {dt:.4f}s")
113
+ # make sure consistent with dummy process
114
+ return Image.fromarray(result)
115
+
116
+
117
+ def get_parser(self,gen_type: str) -> argparse.Namespace:
118
+ parser = argparse.ArgumentParser()
119
+ parser.add_argument(
120
+ "-v", "--version", action="store_true", help="show the version and exit"
121
+ )
122
+ parser.add_argument(
123
+ "--check-backend", action="store_true", help="print all available backends"
124
+ )
125
+ if gen_type == "gui" and "mpi" in ALL_BACKEND:
126
+ # gui doesn't support MPI backend
127
+ ALL_BACKEND.remove("mpi")
128
+ parser.add_argument(
129
+ "-b",
130
+ "--backend",
131
+ type=str,
132
+ choices=ALL_BACKEND,
133
+ default=DEFAULT_BACKEND,
134
+ help="backend choice",
135
+ )
136
+ parser.add_argument(
137
+ "-c",
138
+ "--cpu",
139
+ type=int,
140
+ default=CPU_COUNT,
141
+ help="number of CPU used",
142
+ )
143
+ parser.add_argument(
144
+ "-z",
145
+ "--block-size",
146
+ type=int,
147
+ default=1024,
148
+ help="cuda block size (only for equ solver)",
149
+ )
150
+ parser.add_argument(
151
+ "--method",
152
+ type=str,
153
+ choices=["equ", "grid"],
154
+ default="equ",
155
+ help="how to parallelize computation",
156
+ )
157
+ parser.add_argument("-s", "--source", type=str, help="source image filename")
158
+ if gen_type == "cli":
159
+ parser.add_argument(
160
+ "-m",
161
+ "--mask",
162
+ type=str,
163
+ help="mask image filename (default is to use the whole source image)",
164
+ default="",
165
+ )
166
+ parser.add_argument("-t", "--target", type=str, help="target image filename")
167
+ parser.add_argument("-o", "--output", type=str, help="output image filename")
168
+ if gen_type == "cli":
169
+ parser.add_argument(
170
+ "-h0", type=int, help="mask position (height) on source image", default=0
171
+ )
172
+ parser.add_argument(
173
+ "-w0", type=int, help="mask position (width) on source image", default=0
174
+ )
175
+ parser.add_argument(
176
+ "-h1", type=int, help="mask position (height) on target image", default=0
177
+ )
178
+ parser.add_argument(
179
+ "-w1", type=int, help="mask position (width) on target image", default=0
180
+ )
181
+ parser.add_argument(
182
+ "-g",
183
+ "--gradient",
184
+ type=str,
185
+ choices=["max", "src", "avg"],
186
+ default="max",
187
+ help="how to calculate gradient for PIE",
188
+ )
189
+ parser.add_argument(
190
+ "-n",
191
+ type=int,
192
+ help="how many iteration would you perfer, the more the better",
193
+ default=5000,
194
+ )
195
+ if gen_type == "cli":
196
+ parser.add_argument(
197
+ "-p", type=int, help="output result every P iteration", default=0
198
+ )
199
+ if "mpi" in ALL_BACKEND:
200
+ parser.add_argument(
201
+ "--mpi-sync-interval",
202
+ type=int,
203
+ help="MPI sync iteration interval",
204
+ default=100,
205
+ )
206
+ parser.add_argument(
207
+ "--grid-x", type=int, help="x axis stride for grid solver", default=8
208
+ )
209
+ parser.add_argument(
210
+ "--grid-y", type=int, help="y axis stride for grid solver", default=8
211
+ )
212
+ self.parser=parser
213
+
214
+ if __name__ =="__main__":
215
+ import sys
216
+ import io
217
+ import base64
218
+ from PIL import Image
219
+ def base64_to_pil(base64_str):
220
+ data = base64.b64decode(str(base64_str))
221
+ pil = Image.open(io.BytesIO(data))
222
+ return pil
223
+
224
+ def pil_to_base64(out_pil):
225
+ out_buffer = io.BytesIO()
226
+ out_pil.save(out_buffer, format="PNG")
227
+ out_buffer.seek(0)
228
+ base64_bytes = base64.b64encode(out_buffer.read())
229
+ base64_str = base64_bytes.decode("ascii")
230
+ return base64_str
231
+ correction_func=PhotometricCorrection(quite=True)
232
+ while True:
233
+ buffer = sys.stdin.readline()
234
+ print(f"[PIE] suprocess {len(buffer)} {type(buffer)} ")
235
+ if len(buffer)==0:
236
+ break
237
+ if isinstance(buffer,str):
238
+ lst=buffer.strip().split(",")
239
+ else:
240
+ lst=buffer.decode("ascii").strip().split(",")
241
+ img0=base64_to_pil(lst[0])
242
+ img1=base64_to_pil(lst[1])
243
+ ret=correction_func.run(img0,img1,mode=lst[2])
244
+ ret_base64=pil_to_base64(ret)
245
+ if isinstance(buffer,str):
246
+ sys.stdout.write(f"{ret_base64}\n")
247
+ else:
248
+ sys.stdout.write(f"{ret_base64}\n".encode())
249
+ sys.stdout.flush()
process.py ADDED
@@ -0,0 +1,395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ https://github.com/Trinkle23897/Fast-Poisson-Image-Editing
3
+ MIT License
4
+
5
+ Copyright (c) 2022 Jiayi Weng
6
+
7
+ Permission is hereby granted, free of charge, to any person obtaining a copy
8
+ of this software and associated documentation files (the "Software"), to deal
9
+ in the Software without restriction, including without limitation the rights
10
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
+ copies of the Software, and to permit persons to whom the Software is
12
+ furnished to do so, subject to the following conditions:
13
+
14
+ The above copyright notice and this permission notice shall be included in all
15
+ copies or substantial portions of the Software.
16
+
17
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
+ SOFTWARE.
24
+ """
25
+ import os
26
+ from abc import ABC, abstractmethod
27
+ from typing import Any, Optional, Tuple
28
+
29
+ import numpy as np
30
+
31
+ from fpie import np_solver
32
+
33
+ import scipy
34
+ import scipy.signal
35
+
36
+ CPU_COUNT = os.cpu_count() or 1
37
+ DEFAULT_BACKEND = "numpy"
38
+ ALL_BACKEND = ["numpy"]
39
+
40
+ try:
41
+ from fpie import numba_solver
42
+ ALL_BACKEND += ["numba"]
43
+ DEFAULT_BACKEND = "numba"
44
+ except ImportError:
45
+ numba_solver = None # type: ignore
46
+
47
+ try:
48
+ from fpie import taichi_solver
49
+ ALL_BACKEND += ["taichi-cpu", "taichi-gpu"]
50
+ DEFAULT_BACKEND = "taichi-cpu"
51
+ except ImportError:
52
+ taichi_solver = None # type: ignore
53
+
54
+ # try:
55
+ # from fpie import core_gcc # type: ignore
56
+ # DEFAULT_BACKEND = "gcc"
57
+ # ALL_BACKEND.append("gcc")
58
+ # except ImportError:
59
+ # core_gcc = None
60
+
61
+ # try:
62
+ # from fpie import core_openmp # type: ignore
63
+ # DEFAULT_BACKEND = "openmp"
64
+ # ALL_BACKEND.append("openmp")
65
+ # except ImportError:
66
+ # core_openmp = None
67
+
68
+ # try:
69
+ # from mpi4py import MPI
70
+
71
+ # from fpie import core_mpi # type: ignore
72
+ # ALL_BACKEND.append("mpi")
73
+ # except ImportError:
74
+ # MPI = None # type: ignore
75
+ # core_mpi = None
76
+
77
+ try:
78
+ from fpie import core_cuda # type: ignore
79
+ DEFAULT_BACKEND = "cuda"
80
+ ALL_BACKEND.append("cuda")
81
+ except ImportError:
82
+ core_cuda = None
83
+
84
+
85
+ class BaseProcessor(ABC):
86
+ """API definition for processor class."""
87
+
88
+ def __init__(
89
+ self, gradient: str, rank: int, backend: str, core: Optional[Any]
90
+ ):
91
+ if core is None:
92
+ error_msg = {
93
+ "numpy":
94
+ "Please run `pip install numpy`.",
95
+ "numba":
96
+ "Please run `pip install numba`.",
97
+ "gcc":
98
+ "Please install cmake and gcc in your operating system.",
99
+ "openmp":
100
+ "Please make sure your gcc is compatible with `-fopenmp` option.",
101
+ "mpi":
102
+ "Please install MPI and run `pip install mpi4py`.",
103
+ "cuda":
104
+ "Please make sure nvcc and cuda-related libraries are available.",
105
+ "taichi":
106
+ "Please run `pip install taichi`.",
107
+ }
108
+ print(error_msg[backend.split("-")[0]])
109
+
110
+ raise AssertionError(f"Invalid backend {backend}.")
111
+
112
+ self.gradient = gradient
113
+ self.rank = rank
114
+ self.backend = backend
115
+ self.core = core
116
+ self.root = rank == 0
117
+
118
+ def mixgrad(self, a: np.ndarray, b: np.ndarray) -> np.ndarray:
119
+ if self.gradient == "src":
120
+ return a
121
+ if self.gradient == "avg":
122
+ return (a + b) / 2
123
+ # mix gradient, see Equ. 12 in PIE paper
124
+ mask = np.abs(a) < np.abs(b)
125
+ a[mask] = b[mask]
126
+ return a
127
+
128
+ @abstractmethod
129
+ def reset(
130
+ self,
131
+ src: np.ndarray,
132
+ mask: np.ndarray,
133
+ tgt: np.ndarray,
134
+ mask_on_src: Tuple[int, int],
135
+ mask_on_tgt: Tuple[int, int],
136
+ ) -> int:
137
+ pass
138
+
139
+ def sync(self) -> None:
140
+ self.core.sync()
141
+
142
+ @abstractmethod
143
+ def step(self, iteration: int) -> Optional[Tuple[np.ndarray, np.ndarray]]:
144
+ pass
145
+
146
+
147
+ class EquProcessor(BaseProcessor):
148
+ """PIE Jacobi equation processor."""
149
+
150
+ def __init__(
151
+ self,
152
+ gradient: str = "max",
153
+ backend: str = DEFAULT_BACKEND,
154
+ n_cpu: int = CPU_COUNT,
155
+ min_interval: int = 100,
156
+ block_size: int = 1024,
157
+ ):
158
+ core: Optional[Any] = None
159
+ rank = 0
160
+
161
+ if backend == "numpy":
162
+ core = np_solver.EquSolver()
163
+ elif backend == "numba" and numba_solver is not None:
164
+ core = numba_solver.EquSolver()
165
+ elif backend == "gcc":
166
+ core = core_gcc.EquSolver()
167
+ elif backend == "openmp" and core_openmp is not None:
168
+ core = core_openmp.EquSolver(n_cpu)
169
+ elif backend == "mpi" and core_mpi is not None:
170
+ core = core_mpi.EquSolver(min_interval)
171
+ rank = MPI.COMM_WORLD.Get_rank()
172
+ elif backend == "cuda" and core_cuda is not None:
173
+ core = core_cuda.EquSolver(block_size)
174
+ elif backend.startswith("taichi") and taichi_solver is not None:
175
+ core = taichi_solver.EquSolver(backend, n_cpu, block_size)
176
+
177
+ super().__init__(gradient, rank, backend, core)
178
+
179
+ def mask2index(
180
+ self, mask: np.ndarray
181
+ ) -> Tuple[np.ndarray, int, np.ndarray, np.ndarray]:
182
+ x, y = np.nonzero(mask)
183
+ max_id = x.shape[0] + 1
184
+ index = np.zeros((max_id, 3))
185
+ ids = self.core.partition(mask)
186
+ ids[mask == 0] = 0 # reserve id=0 for constant
187
+ index = ids[x, y].argsort()
188
+ return ids, max_id, x[index], y[index]
189
+
190
+ def reset(
191
+ self,
192
+ src: np.ndarray,
193
+ mask: np.ndarray,
194
+ tgt: np.ndarray,
195
+ mask_on_src: Tuple[int, int],
196
+ mask_on_tgt: Tuple[int, int],
197
+ ) -> int:
198
+ assert self.root
199
+ # check validity
200
+ # assert 0 <= mask_on_src[0] and 0 <= mask_on_src[1]
201
+ # assert mask_on_src[0] + mask.shape[0] <= src.shape[0]
202
+ # assert mask_on_src[1] + mask.shape[1] <= src.shape[1]
203
+ # assert mask_on_tgt[0] + mask.shape[0] <= tgt.shape[0]
204
+ # assert mask_on_tgt[1] + mask.shape[1] <= tgt.shape[1]
205
+
206
+ if len(mask.shape) == 3:
207
+ mask = mask.mean(-1)
208
+ mask = (mask >= 128).astype(np.int32)
209
+
210
+ # zero-out edge
211
+ mask[0] = 0
212
+ mask[-1] = 0
213
+ mask[:, 0] = 0
214
+ mask[:, -1] = 0
215
+
216
+ x, y = np.nonzero(mask)
217
+ x0, x1 = x.min() - 1, x.max() + 2
218
+ y0, y1 = y.min() - 1, y.max() + 2
219
+ mask_on_src = (x0 + mask_on_src[0], y0 + mask_on_src[1])
220
+ mask_on_tgt = (x0 + mask_on_tgt[0], y0 + mask_on_tgt[1])
221
+ mask = mask[x0:x1, y0:y1]
222
+ ids, max_id, index_x, index_y = self.mask2index(mask)
223
+
224
+ src_x, src_y = index_x + mask_on_src[0], index_y + mask_on_src[1]
225
+ tgt_x, tgt_y = index_x + mask_on_tgt[0], index_y + mask_on_tgt[1]
226
+
227
+ src_C = src[src_x, src_y].astype(np.float32)
228
+ src_U = src[src_x - 1, src_y].astype(np.float32)
229
+ src_D = src[src_x + 1, src_y].astype(np.float32)
230
+ src_L = src[src_x, src_y - 1].astype(np.float32)
231
+ src_R = src[src_x, src_y + 1].astype(np.float32)
232
+ tgt_C = tgt[tgt_x, tgt_y].astype(np.float32)
233
+ tgt_U = tgt[tgt_x - 1, tgt_y].astype(np.float32)
234
+ tgt_D = tgt[tgt_x + 1, tgt_y].astype(np.float32)
235
+ tgt_L = tgt[tgt_x, tgt_y - 1].astype(np.float32)
236
+ tgt_R = tgt[tgt_x, tgt_y + 1].astype(np.float32)
237
+
238
+ grad = self.mixgrad(src_C - src_L, tgt_C - tgt_L) \
239
+ + self.mixgrad(src_C - src_R, tgt_C - tgt_R) \
240
+ + self.mixgrad(src_C - src_U, tgt_C - tgt_U) \
241
+ + self.mixgrad(src_C - src_D, tgt_C - tgt_D)
242
+
243
+ A = np.zeros((max_id, 4), np.int32)
244
+ X = np.zeros((max_id, 3), np.float32)
245
+ B = np.zeros((max_id, 3), np.float32)
246
+
247
+ X[1:] = tgt[index_x + mask_on_tgt[0], index_y + mask_on_tgt[1]]
248
+ # four-way
249
+ A[1:, 0] = ids[index_x - 1, index_y]
250
+ A[1:, 1] = ids[index_x + 1, index_y]
251
+ A[1:, 2] = ids[index_x, index_y - 1]
252
+ A[1:, 3] = ids[index_x, index_y + 1]
253
+ B[1:] = grad
254
+ m = (mask[index_x - 1, index_y] == 0).astype(float).reshape(-1, 1)
255
+ B[1:] += m * tgt[index_x + mask_on_tgt[0] - 1, index_y + mask_on_tgt[1]]
256
+ m = (mask[index_x, index_y - 1] == 0).astype(float).reshape(-1, 1)
257
+ B[1:] += m * tgt[index_x + mask_on_tgt[0], index_y + mask_on_tgt[1] - 1]
258
+ m = (mask[index_x, index_y + 1] == 0).astype(float).reshape(-1, 1)
259
+ B[1:] += m * tgt[index_x + mask_on_tgt[0], index_y + mask_on_tgt[1] + 1]
260
+ m = (mask[index_x + 1, index_y] == 0).astype(float).reshape(-1, 1)
261
+ B[1:] += m * tgt[index_x + mask_on_tgt[0] + 1, index_y + mask_on_tgt[1]]
262
+
263
+ self.tgt = tgt.copy()
264
+ self.tgt_index = (index_x + mask_on_tgt[0], index_y + mask_on_tgt[1])
265
+ self.core.reset(max_id, A, X, B)
266
+ return max_id
267
+
268
+ def step(self, iteration: int) -> Optional[Tuple[np.ndarray, np.ndarray]]:
269
+ result = self.core.step(iteration)
270
+ if self.root:
271
+ x, err = result
272
+ self.tgt[self.tgt_index] = x[1:]
273
+ return self.tgt, err
274
+ return None
275
+
276
+
277
+ class GridProcessor(BaseProcessor):
278
+ """PIE grid processor."""
279
+
280
+ def __init__(
281
+ self,
282
+ gradient: str = "max",
283
+ backend: str = DEFAULT_BACKEND,
284
+ n_cpu: int = CPU_COUNT,
285
+ min_interval: int = 100,
286
+ block_size: int = 1024,
287
+ grid_x: int = 8,
288
+ grid_y: int = 8,
289
+ ):
290
+ core: Optional[Any] = None
291
+ rank = 0
292
+
293
+ if backend == "numpy":
294
+ core = np_solver.GridSolver()
295
+ elif backend == "numba" and numba_solver is not None:
296
+ core = numba_solver.GridSolver()
297
+ elif backend == "gcc":
298
+ core = core_gcc.GridSolver(grid_x, grid_y)
299
+ elif backend == "openmp" and core_openmp is not None:
300
+ core = core_openmp.GridSolver(grid_x, grid_y, n_cpu)
301
+ elif backend == "mpi" and core_mpi is not None:
302
+ core = core_mpi.GridSolver(min_interval)
303
+ rank = MPI.COMM_WORLD.Get_rank()
304
+ elif backend == "cuda" and core_cuda is not None:
305
+ core = core_cuda.GridSolver(grid_x, grid_y)
306
+ elif backend.startswith("taichi") and taichi_solver is not None:
307
+ core = taichi_solver.GridSolver(
308
+ grid_x, grid_y, backend, n_cpu, block_size
309
+ )
310
+
311
+ super().__init__(gradient, rank, backend, core)
312
+
313
+ def reset(
314
+ self,
315
+ src: np.ndarray,
316
+ mask: np.ndarray,
317
+ tgt: np.ndarray,
318
+ mask_on_src: Tuple[int, int],
319
+ mask_on_tgt: Tuple[int, int],
320
+ ) -> int:
321
+ assert self.root
322
+ # check validity
323
+ # assert 0 <= mask_on_src[0] and 0 <= mask_on_src[1]
324
+ # assert mask_on_src[0] + mask.shape[0] <= src.shape[0]
325
+ # assert mask_on_src[1] + mask.shape[1] <= src.shape[1]
326
+ # assert mask_on_tgt[0] + mask.shape[0] <= tgt.shape[0]
327
+ # assert mask_on_tgt[1] + mask.shape[1] <= tgt.shape[1]
328
+
329
+ if len(mask.shape) == 3:
330
+ mask = mask.mean(-1)
331
+ mask = (mask >= 128).astype(np.int32)
332
+
333
+ # zero-out edge
334
+ mask[0] = 0
335
+ mask[-1] = 0
336
+ mask[:, 0] = 0
337
+ mask[:, -1] = 0
338
+
339
+ x, y = np.nonzero(mask)
340
+ x0, x1 = x.min() - 1, x.max() + 2
341
+ y0, y1 = y.min() - 1, y.max() + 2
342
+ mask = mask[x0:x1, y0:y1]
343
+ max_id = np.prod(mask.shape)
344
+
345
+ src_crop = src[mask_on_src[0] + x0:mask_on_src[0] + x1,
346
+ mask_on_src[1] + y0:mask_on_src[1] + y1].astype(np.float32)
347
+ tgt_crop = tgt[mask_on_tgt[0] + x0:mask_on_tgt[0] + x1,
348
+ mask_on_tgt[1] + y0:mask_on_tgt[1] + y1].astype(np.float32)
349
+ grad = np.zeros([*mask.shape, 3], np.float32)
350
+ grad[1:] += self.mixgrad(
351
+ src_crop[1:] - src_crop[:-1], tgt_crop[1:] - tgt_crop[:-1]
352
+ )
353
+ grad[:-1] += self.mixgrad(
354
+ src_crop[:-1] - src_crop[1:], tgt_crop[:-1] - tgt_crop[1:]
355
+ )
356
+ grad[:, 1:] += self.mixgrad(
357
+ src_crop[:, 1:] - src_crop[:, :-1], tgt_crop[:, 1:] - tgt_crop[:, :-1]
358
+ )
359
+ grad[:, :-1] += self.mixgrad(
360
+ src_crop[:, :-1] - src_crop[:, 1:], tgt_crop[:, :-1] - tgt_crop[:, 1:]
361
+ )
362
+
363
+ grad[mask == 0] = 0
364
+ if True:
365
+ kernel = [[1] * 3 for _ in range(3)]
366
+ nmask = mask.copy()
367
+ nmask[nmask > 0] = 1
368
+ res = scipy.signal.convolve2d(
369
+ nmask, kernel, mode="same", boundary="fill", fillvalue=1
370
+ )
371
+ res[nmask < 1] = 0
372
+ res[res == 9] = 0
373
+ res[res > 0] = 1
374
+ grad[res>0]=0
375
+ # ylst, xlst = res.nonzero()
376
+ # for y, x in zip(ylst, xlst):
377
+ # grad[y,x]=0
378
+ # for yi in range(-1,2):
379
+ # for xi in range(-1,2):
380
+ # grad[y+yi,x+xi]=0
381
+ self.x0 = mask_on_tgt[0] + x0
382
+ self.x1 = mask_on_tgt[0] + x1
383
+ self.y0 = mask_on_tgt[1] + y0
384
+ self.y1 = mask_on_tgt[1] + y1
385
+ self.tgt = tgt.copy()
386
+ self.core.reset(max_id, mask, tgt_crop, grad)
387
+ return max_id
388
+
389
+ def step(self, iteration: int) -> Optional[Tuple[np.ndarray, np.ndarray]]:
390
+ result = self.core.step(iteration)
391
+ if self.root:
392
+ tgt, err = result
393
+ self.tgt[self.x0:self.x1, self.y0:self.y1] = tgt
394
+ return self.tgt, err
395
+ return None
requirements.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/nightly/cu117
2
+ imageio==2.19.5
3
+ imageio-ffmpeg==0.4.7
4
+ numpy==1.22.4
5
+ opencv-python-headless==4.6.0.66
6
+ torch[dynamo]
7
+ torchvision
8
+ Pillow
9
+ scipy
10
+ scikit-image
11
+ diffusers==0.9.0
12
+ transformers
13
+ ftfy
14
+ fpie
15
+ accelerate
16
+ ninja
17
+ setuptools==59.8.0
utils.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ from PIL import ImageFilter
3
+ import cv2
4
+ import numpy as np
5
+ import scipy
6
+ import scipy.signal
7
+ from scipy.spatial import cKDTree
8
+
9
+ import os
10
+ from perlin2d import *
11
+
12
+ patch_match_compiled = True
13
+
14
+ try:
15
+ from PyPatchMatch import patch_match
16
+ except Exception as e:
17
+ try:
18
+ import patch_match
19
+ except Exception as e:
20
+ patch_match_compiled = False
21
+
22
+ try:
23
+ patch_match
24
+ except NameError:
25
+ print("patch_match compiling failed, will fall back to edge_pad")
26
+ patch_match_compiled = False
27
+
28
+
29
+
30
+
31
+ def edge_pad(img, mask, mode=1):
32
+ if mode == 0:
33
+ nmask = mask.copy()
34
+ nmask[nmask > 0] = 1
35
+ res0 = 1 - nmask
36
+ res1 = nmask
37
+ p0 = np.stack(res0.nonzero(), axis=0).transpose()
38
+ p1 = np.stack(res1.nonzero(), axis=0).transpose()
39
+ min_dists, min_dist_idx = cKDTree(p1).query(p0, 1)
40
+ loc = p1[min_dist_idx]
41
+ for (a, b), (c, d) in zip(p0, loc):
42
+ img[a, b] = img[c, d]
43
+ elif mode == 1:
44
+ record = {}
45
+ kernel = [[1] * 3 for _ in range(3)]
46
+ nmask = mask.copy()
47
+ nmask[nmask > 0] = 1
48
+ res = scipy.signal.convolve2d(
49
+ nmask, kernel, mode="same", boundary="fill", fillvalue=1
50
+ )
51
+ res[nmask < 1] = 0
52
+ res[res == 9] = 0
53
+ res[res > 0] = 1
54
+ ylst, xlst = res.nonzero()
55
+ queue = [(y, x) for y, x in zip(ylst, xlst)]
56
+ # bfs here
57
+ cnt = res.astype(np.float32)
58
+ acc = img.astype(np.float32)
59
+ step = 1
60
+ h = acc.shape[0]
61
+ w = acc.shape[1]
62
+ offset = [(1, 0), (-1, 0), (0, 1), (0, -1)]
63
+ while queue:
64
+ target = []
65
+ for y, x in queue:
66
+ val = acc[y][x]
67
+ for yo, xo in offset:
68
+ yn = y + yo
69
+ xn = x + xo
70
+ if 0 <= yn < h and 0 <= xn < w and nmask[yn][xn] < 1:
71
+ if record.get((yn, xn), step) == step:
72
+ acc[yn][xn] = acc[yn][xn] * cnt[yn][xn] + val
73
+ cnt[yn][xn] += 1
74
+ acc[yn][xn] /= cnt[yn][xn]
75
+ if (yn, xn) not in record:
76
+ record[(yn, xn)] = step
77
+ target.append((yn, xn))
78
+ step += 1
79
+ queue = target
80
+ img = acc.astype(np.uint8)
81
+ else:
82
+ nmask = mask.copy()
83
+ ylst, xlst = nmask.nonzero()
84
+ yt, xt = ylst.min(), xlst.min()
85
+ yb, xb = ylst.max(), xlst.max()
86
+ content = img[yt : yb + 1, xt : xb + 1]
87
+ img = np.pad(
88
+ content,
89
+ ((yt, mask.shape[0] - yb - 1), (xt, mask.shape[1] - xb - 1), (0, 0)),
90
+ mode="edge",
91
+ )
92
+ return img, mask
93
+
94
+
95
+ def perlin_noise(img, mask):
96
+ lin = np.linspace(0, 5, mask.shape[0], endpoint=False)
97
+ x, y = np.meshgrid(lin, lin)
98
+ avg = img.mean(axis=0).mean(axis=0)
99
+ # noise=[((perlin(x, y)+1)*128+avg[i]).astype(np.uint8) for i in range(3)]
100
+ noise = [((perlin(x, y) + 1) * 0.5 * 255).astype(np.uint8) for i in range(3)]
101
+ noise = np.stack(noise, axis=-1)
102
+ # mask=skimage.measure.block_reduce(mask,(8,8),np.min)
103
+ # mask=mask.repeat(8, axis=0).repeat(8, axis=1)
104
+ # mask_image=Image.fromarray(mask)
105
+ # mask_image=mask_image.filter(ImageFilter.GaussianBlur(radius = 4))
106
+ # mask=np.array(mask_image)
107
+ nmask = mask.copy()
108
+ # nmask=nmask/255.0
109
+ nmask[mask > 0] = 1
110
+ img = nmask[:, :, np.newaxis] * img + (1 - nmask[:, :, np.newaxis]) * noise
111
+ # img=img.astype(np.uint8)
112
+ return img, mask
113
+
114
+
115
+ def gaussian_noise(img, mask):
116
+ noise = np.random.randn(mask.shape[0], mask.shape[1], 3)
117
+ noise = (noise + 1) / 2 * 255
118
+ noise = noise.astype(np.uint8)
119
+ nmask = mask.copy()
120
+ nmask[mask > 0] = 1
121
+ img = nmask[:, :, np.newaxis] * img + (1 - nmask[:, :, np.newaxis]) * noise
122
+ return img, mask
123
+
124
+
125
+ def cv2_telea(img, mask):
126
+ ret = cv2.inpaint(img, 255 - mask, 5, cv2.INPAINT_TELEA)
127
+ return ret, mask
128
+
129
+
130
+ def cv2_ns(img, mask):
131
+ ret = cv2.inpaint(img, 255 - mask, 5, cv2.INPAINT_NS)
132
+ return ret, mask
133
+
134
+
135
+ def patch_match_func(img, mask):
136
+ ret = patch_match.inpaint(img, mask=255 - mask, patch_size=3)
137
+ return ret, mask
138
+
139
+
140
+ def mean_fill(img, mask):
141
+ avg = img.mean(axis=0).mean(axis=0)
142
+ img[mask < 1] = avg
143
+ return img, mask
144
+
145
+ def g_diffuser(img,mask):
146
+ return img, mask
147
+
148
+ def dummy_fill(img,mask):
149
+ return img,mask
150
+ functbl = {
151
+ "gaussian": gaussian_noise,
152
+ "perlin": perlin_noise,
153
+ "edge_pad": edge_pad,
154
+ "patchmatch": patch_match_func if patch_match_compiled else edge_pad,
155
+ "cv2_ns": cv2_ns,
156
+ "cv2_telea": cv2_telea,
157
+ "g_diffuser": g_diffuser,
158
+ "g_diffuser_lib": dummy_fill,
159
+ }
160
+
161
+ try:
162
+ from postprocess import PhotometricCorrection
163
+ correction_func = PhotometricCorrection()
164
+ except Exception as e:
165
+ print(e, "so PhotometricCorrection is disabled")
166
+ class DummyCorrection:
167
+ def __init__(self):
168
+ self.backend=""
169
+ pass
170
+ def run(self,a,b,**kwargs):
171
+ return b
172
+ correction_func=DummyCorrection()
173
+
174
+ if "taichi" in correction_func.backend:
175
+ import sys
176
+ import io
177
+ import base64
178
+ from PIL import Image
179
+ def base64_to_pil(base64_str):
180
+ data = base64.b64decode(str(base64_str))
181
+ pil = Image.open(io.BytesIO(data))
182
+ return pil
183
+
184
+ def pil_to_base64(out_pil):
185
+ out_buffer = io.BytesIO()
186
+ out_pil.save(out_buffer, format="PNG")
187
+ out_buffer.seek(0)
188
+ base64_bytes = base64.b64encode(out_buffer.read())
189
+ base64_str = base64_bytes.decode("ascii")
190
+ return base64_str
191
+ from subprocess import Popen, PIPE, STDOUT
192
+ class SubprocessCorrection:
193
+ def __init__(self):
194
+ self.backend=correction_func.backend
195
+ self.child= Popen(["python", "postprocess.py"], stdin=PIPE, stdout=PIPE, stderr=STDOUT)
196
+ def run(self,img_input,img_inpainted,mode):
197
+ if mode=="disabled":
198
+ return img_inpainted
199
+ base64_str_input = pil_to_base64(img_input)
200
+ base64_str_inpainted = pil_to_base64(img_inpainted)
201
+ try:
202
+ if self.child.poll():
203
+ self.child= Popen(["python", "postprocess.py"], stdin=PIPE, stdout=PIPE, stderr=STDOUT)
204
+ self.child.stdin.write(f"{base64_str_input},{base64_str_inpainted},{mode}\n".encode())
205
+ self.child.stdin.flush()
206
+ out = self.child.stdout.readline()
207
+ base64_str=out.decode().strip()
208
+ while base64_str and base64_str[0]=="[":
209
+ print(base64_str)
210
+ out = self.child.stdout.readline()
211
+ base64_str=out.decode().strip()
212
+ ret=base64_to_pil(base64_str)
213
+ except:
214
+ print("[PIE] not working, photometric correction is disabled")
215
+ ret=img_inpainted
216
+ return ret
217
+ correction_func = SubprocessCorrection()