Spaces:
Runtime error
Runtime error
File size: 4,754 Bytes
09c675d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
#pragma once
#include <opencv2/core.hpp>
#include "masked_image.h"
class PatchDistanceMetric {
public:
PatchDistanceMetric(int patch_size) : m_patch_size(patch_size) {}
virtual ~PatchDistanceMetric() = default;
inline int patch_size() const { return m_patch_size; }
virtual int operator()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const = 0;
static const int kDistanceScale;
protected:
int m_patch_size;
};
class NearestNeighborField {
public:
NearestNeighborField() : m_source(), m_target(), m_field(), m_distance_metric(nullptr) {
// pass
}
NearestNeighborField(const MaskedImage &source, const MaskedImage &target, const PatchDistanceMetric *metric, int max_retry = 20)
: m_source(source), m_target(target), m_distance_metric(metric) {
m_field = cv::Mat(m_source.size(), CV_32SC3);
_randomize_field(max_retry);
}
NearestNeighborField(const MaskedImage &source, const MaskedImage &target, const PatchDistanceMetric *metric, const NearestNeighborField &other, int max_retry = 20)
: m_source(source), m_target(target), m_distance_metric(metric) {
m_field = cv::Mat(m_source.size(), CV_32SC3);
_initialize_field_from(other, max_retry);
}
const MaskedImage &source() const {
return m_source;
}
const MaskedImage &target() const {
return m_target;
}
inline cv::Size source_size() const {
return m_source.size();
}
inline cv::Size target_size() const {
return m_target.size();
}
inline void set_source(const MaskedImage &source) {
m_source = source;
}
inline void set_target(const MaskedImage &target) {
m_target = target;
}
inline int *mutable_ptr(int y, int x) {
return m_field.ptr<int>(y, x);
}
inline const int *ptr(int y, int x) const {
return m_field.ptr<int>(y, x);
}
inline int at(int y, int x, int c) const {
return m_field.ptr<int>(y, x)[c];
}
inline int &at(int y, int x, int c) {
return m_field.ptr<int>(y, x)[c];
}
inline void set_identity(int y, int x) {
auto ptr = mutable_ptr(y, x);
ptr[0] = y, ptr[1] = x, ptr[2] = 0;
}
void minimize(int nr_pass);
private:
inline int _distance(int source_y, int source_x, int target_y, int target_x) {
return (*m_distance_metric)(m_source, source_y, source_x, m_target, target_y, target_x);
}
void _randomize_field(int max_retry = 20, bool reset = true);
void _initialize_field_from(const NearestNeighborField &other, int max_retry);
void _minimize_link(int y, int x, int direction);
MaskedImage m_source;
MaskedImage m_target;
cv::Mat m_field; // { y_target, x_target, distance_scaled }
const PatchDistanceMetric *m_distance_metric;
};
class PatchSSDDistanceMetric : public PatchDistanceMetric {
public:
using PatchDistanceMetric::PatchDistanceMetric;
virtual int operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const;
static const int kSSDScale;
};
class DebugPatchSSDDistanceMetric : public PatchDistanceMetric {
public:
DebugPatchSSDDistanceMetric(int patch_size, int width, int height) : PatchDistanceMetric(patch_size), m_width(width), m_height(height) {}
virtual int operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const;
protected:
int m_width, m_height;
};
class RegularityGuidedPatchDistanceMetricV1 : public PatchDistanceMetric {
public:
RegularityGuidedPatchDistanceMetricV1(int patch_size, double dx1, double dy1, double dx2, double dy2, double weight)
: PatchDistanceMetric(patch_size), m_dx1(dx1), m_dy1(dy1), m_dx2(dx2), m_dy2(dy2), m_weight(weight) {
assert(m_dy1 == 0);
assert(m_dx2 == 0);
m_scale = sqrt(m_dx1 * m_dx1 + m_dy2 * m_dy2) / 4;
}
virtual int operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const;
protected:
double m_dx1, m_dy1, m_dx2, m_dy2;
double m_scale, m_weight;
};
class RegularityGuidedPatchDistanceMetricV2 : public PatchDistanceMetric {
public:
RegularityGuidedPatchDistanceMetricV2(int patch_size, cv::Mat ijmap, double weight)
: PatchDistanceMetric(patch_size), m_ijmap(ijmap), m_weight(weight) {
}
virtual int operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const;
protected:
cv::Mat m_ijmap;
double m_width, m_height, m_weight;
};
|