Finished the initial cleanup for OverlappingWaveFormCollapse, and added it to the build.

This commit is contained in:
Relintai 2022-04-22 18:50:27 +02:00
parent 21be696f6a
commit 6e5407c55f
3 changed files with 112 additions and 95 deletions

View File

@ -9,4 +9,4 @@ env_wfc.add_source_files(env.modules_sources, "register_types.cpp")
env_wfc.add_source_files(env.modules_sources, "wave_form_collapse.cpp") env_wfc.add_source_files(env.modules_sources, "wave_form_collapse.cpp")
env_wfc.add_source_files(env.modules_sources, "tiling_wave_form_collapse.cpp") env_wfc.add_source_files(env.modules_sources, "tiling_wave_form_collapse.cpp")
#env_wfc.add_source_files(env.modules_sources, "overlapping_wave_form_collapse.cpp") env_wfc.add_source_files(env.modules_sources, "overlapping_wave_form_collapse.cpp")

View File

@ -1,18 +1,24 @@
#include "overlapping_wave_form_collapse.h" #include "overlapping_wave_form_collapse.h"
#include "core/set.h"
void OverlappingWaveFormCollapse::set_input(const Array2D<uint32_t> &data) {
input = data;
}
uint32_t OverlappingWaveFormCollapse::get_wave_height() const { uint32_t OverlappingWaveFormCollapse::get_wave_height() const {
return periodic_output ? out_height : out_height - pattern_size + 1; return periodic_output ? out_height : out_height - pattern_size + 1;
} }
//Get the wave width given these options. //Get the wave width given these
uint32_t OverlappingWaveFormCollapse::get_wave_width() const { uint32_t OverlappingWaveFormCollapse::get_wave_width() const {
return periodic_output ? out_width : out_width - pattern_size + 1; return periodic_output ? out_width : out_width - pattern_size + 1;
} }
// Run the WFC algorithm, and return the result if the algorithm succeeded. // Run the WFC algorithm, and return the result if the algorithm succeeded.
Array2D<uint32_t> OverlappingWaveFormCollapse::run() { Array2D<uint32_t> OverlappingWaveFormCollapse::orun() {
Array2D<uint32_t> result = wfc.run(); Array2D<uint32_t> result = run();
if (result.width == 0 && result.height == 0) { if (result.width == 0 && result.height == 0) {
return Array2D<uint32_t>(0, 0); return Array2D<uint32_t>(0, 0);
@ -21,29 +27,29 @@ Array2D<uint32_t> OverlappingWaveFormCollapse::run() {
return to_image(result); return to_image(result);
} }
void OverlappingWaveFormCollapse::init_ground(WFC &wfc, const Array2D<T> &input, const Vector<Array2D<T>> &patterns, const OverlappingWFCOptions &options) { void OverlappingWaveFormCollapse::init_ground() {
uint32_t ground_pattern_id = get_ground_pattern_id(input, patterns, options); uint32_t ground_pattern_id = get_ground_pattern_id();
for (uint32_t j = 0; j < options.get_wave_width(); j++) { for (uint32_t j = 0; j < get_wave_width(); j++) {
set_pattern(ground_pattern_id, options.get_wave_height() - 1, j); set_pattern(ground_pattern_id, get_wave_height() - 1, j);
} }
for (uint32_t i = 0; i < options.get_wave_height() - 1; i++) { for (uint32_t i = 0; i < get_wave_height() - 1; i++) {
for (uint32_t j = 0; j < options.get_wave_width(); j++) { for (uint32_t j = 0; j < get_wave_width(); j++) {
wfc.remove_wave_pattern(i, j, ground_pattern_id); remove_wave_pattern(i, j, ground_pattern_id);
} }
} }
wfc.propagate(); propagate();
} }
// Set the pattern at a specific position. // Set the pattern at a specific position.
// Returns false if the given pattern does not exist, or if the // Returns false if the given pattern does not exist, or if the
// coordinates are not in the wave // coordinates are not in the wave
bool OverlappingWaveFormCollapse::set_pattern(const Array2D<uint32_t> &pattern, uint32_t i, uint32_t j) { bool OverlappingWaveFormCollapse::set_pattern(const Array2D<uint32_t> &pattern, uint32_t i, uint32_t j) {
auto pattern_id = get_pattern_id(pattern); uint32_t pattern_id = get_pattern_id(pattern);
if (pattern_id == std::nullopt || i >= options.get_wave_height() || j >= options.get_wave_width()) { if (pattern_id == static_cast<uint32_t>(-1) || i >= get_wave_height() || j >= get_wave_width()) {
return false; return false;
} }
@ -51,9 +57,9 @@ bool OverlappingWaveFormCollapse::set_pattern(const Array2D<uint32_t> &pattern,
return true; return true;
} }
static uint32_t OverlappingWaveFormCollapse::get_ground_pattern_id(const Array2D<uint32_t> &input, const Vector<Array2D<uint32_t>> &patterns, const OverlappingWFCOptions &options) { uint32_t OverlappingWaveFormCollapse::get_ground_pattern_id() {
// Get the pattern. // Get the pattern.
Array2D<uint32_t> ground_pattern = input.get_sub_array(input.height - 1, input.width / 2, options.pattern_size, options.pattern_size); Array2D<uint32_t> ground_pattern = input.get_sub_array(input.height - 1, input.width / 2, pattern_size, pattern_size);
// Retrieve the id of the pattern. // Retrieve the id of the pattern.
for (int i = 0; i < patterns.size(); i++) { for (int i = 0; i < patterns.size(); i++) {
@ -66,10 +72,10 @@ static uint32_t OverlappingWaveFormCollapse::get_ground_pattern_id(const Array2D
} }
uint32_t OverlappingWaveFormCollapse::get_pattern_id(const Array2D<uint32_t> &pattern) { uint32_t OverlappingWaveFormCollapse::get_pattern_id(const Array2D<uint32_t> &pattern) {
uint32_t *pattern_id = std::find(patterns.begin(), patterns.end(), pattern); for (int i = 0; i < patterns.size(); ++i) {
if (patterns[i] == pattern) {
if (pattern_id != patterns.end()) { return i;
return *pattern_id; }
} }
return -1; return -1;
@ -79,45 +85,59 @@ uint32_t OverlappingWaveFormCollapse::get_pattern_id(const Array2D<uint32_t> &pa
// pattern_id needs to be a valid pattern id, and i and j needs to be in the wave range // pattern_id needs to be a valid pattern id, and i and j needs to be in the wave range
void OverlappingWaveFormCollapse::set_pattern(uint32_t pattern_id, uint32_t i, uint32_t j) { void OverlappingWaveFormCollapse::set_pattern(uint32_t pattern_id, uint32_t i, uint32_t j) {
for (int p = 0; p < patterns.size(); p++) { for (int p = 0; p < patterns.size(); p++) {
if (pattern_id != p) { if (pattern_id != static_cast<uint32_t>(p)) {
wfc.remove_wave_pattern(i, j, p); remove_wave_pattern(i, j, p);
} }
} }
} }
//Return the list of patterns, as well as their probabilities of apparition. //Return the list of patterns, as well as their probabilities of apparition.
static std::pair<Vector<Array2D<uint32_t>>, Vector<double>> OverlappingWaveFormCollapse::get_patterns(const Array2D<uint32_t> &input, const OverlappingWFCOptions &options) { void OverlappingWaveFormCollapse::get_patterns() {
std::unordered_map<Array2D<uint32_t>, uint32_t> patterns_id; //OAHashMap<Array2D<uint32_t>, uint32_t> patterns_id;
Vector<Array2D<uint32_t>> patterns;
LocalVector<Array2D<uint32_t>> patterns_id;
patterns.clear();
// The number of time a pattern is seen in the input image. // The number of time a pattern is seen in the input image.
Vector<double> patterns_weight; Vector<double> patterns_weight;
Vector<Array2D<uint32_t>> symmetries(8, Array2D<uint32_t>(options.pattern_size, options.pattern_size)); Vector<Array2D<uint32_t>> symmetries;
uint32_t max_i = options.periodic_input ? input.height : input.height - options.pattern_size + 1; symmetries.resize(8);
uint32_t max_j = options.periodic_input ? input.width : input.width - options.pattern_size + 1;
for (int i = 0; i < 8; ++i) {
symmetries.write[i].resize(pattern_size, pattern_size);
}
uint32_t max_i = periodic_input ? input.height : input.height - pattern_size + 1;
uint32_t max_j = periodic_input ? input.width : input.width - pattern_size + 1;
for (uint32_t i = 0; i < max_i; i++) { for (uint32_t i = 0; i < max_i; i++) {
for (uint32_t j = 0; j < max_j; j++) { for (uint32_t j = 0; j < max_j; j++) {
// Compute the symmetries of every pattern in the image. // Compute the symmetries of every pattern in the image.
symmetries[0].data = input.get_sub_array(i, j, options.pattern_size, options.pattern_size).data; symmetries.write[0].data = input.get_sub_array(i, j, pattern_size, pattern_size).data;
symmetries[1].data = symmetries[0].reflected().data; symmetries.write[1].data = symmetries[0].reflected().data;
symmetries[2].data = symmetries[0].rotated().data; symmetries.write[2].data = symmetries[0].rotated().data;
symmetries[3].data = symmetries[2].reflected().data; symmetries.write[3].data = symmetries[2].reflected().data;
symmetries[4].data = symmetries[2].rotated().data; symmetries.write[4].data = symmetries[2].rotated().data;
symmetries[5].data = symmetries[4].reflected().data; symmetries.write[5].data = symmetries[4].reflected().data;
symmetries[6].data = symmetries[4].rotated().data; symmetries.write[6].data = symmetries[4].rotated().data;
symmetries[7].data = symmetries[6].reflected().data; symmetries.write[7].data = symmetries[6].reflected().data;
// The number of symmetries in the option class define which symetries // The number of symmetries in the option class define which symetries will be used.
// will be used. for (uint32_t k = 0; k < symmetry; k++) {
for (uint32_t k = 0; k < options.symmetry; k++) { int indx = patterns.size();
auto res = patterns_id.insert(std::make_pair(symmetries[k], patterns.size()));
// If the pattern already exist, we just have to increase its number for (uint32_t h = 0; h < patterns_id.size(); ++h) {
// of appearance. if (patterns_id[h] == symmetries[k]) {
if (!res.second) { indx = h;
patterns_weight[res.first->second] += 1; break;
}
}
if (indx != patterns.size()) {
// If the pattern already exist, we just have to increase its number of appearance.
patterns_weight.write[indx] += 1;
} else { } else {
patterns.push_back(symmetries[k]); patterns.push_back(symmetries[k]);
patterns_weight.push_back(1); patterns_weight.push_back(1);
@ -126,11 +146,11 @@ static std::pair<Vector<Array2D<uint32_t>>, Vector<double>> OverlappingWaveFormC
} }
} }
return { patterns, patterns_weight }; set_pattern_frequencies(patterns_weight);
} }
//Return true if the pattern1 is compatible with pattern2 when pattern2 is at a distance (dy,dx) from pattern1. //Return true if the pattern1 is compatible with pattern2 when pattern2 is at a distance (dy,dx) from pattern1.
static bool OverlappingWaveFormCollapse::agrees(const Array2D<uint32_t> &pattern1, const Array2D<uint32_t> &pattern2, int dy, int dx) { bool OverlappingWaveFormCollapse::agrees(const Array2D<uint32_t> &pattern1, const Array2D<uint32_t> &pattern2, int dy, int dx) {
uint32_t xmin = dx < 0 ? 0 : dx; uint32_t xmin = dx < 0 ? 0 : dx;
uint32_t xmax = dx < 0 ? dx + pattern2.width : pattern1.width; uint32_t xmax = dx < 0 ? dx + pattern2.width : pattern1.width;
uint32_t ymin = dy < 0 ? 0 : dy; uint32_t ymin = dy < 0 ? 0 : dy;
@ -153,8 +173,8 @@ static bool OverlappingWaveFormCollapse::agrees(const Array2D<uint32_t> &pattern
// If agrees(pattern1, pattern2, dy, dx), then compatible[pattern1][direction] // If agrees(pattern1, pattern2, dy, dx), then compatible[pattern1][direction]
// contains pattern2, where direction is the direction defined by (dy, dx) // contains pattern2, where direction is the direction defined by (dy, dx)
// (see direction.hpp). // (see direction.hpp).
static Vector<PropagatorStateEntry> OverlappingWaveFormCollapse::generate_compatible(const Vector<Array2D<uint32_t>> &patterns) { Vector<WaveFormCollapse::PropagatorStateEntry> OverlappingWaveFormCollapse::generate_compatible() {
Vector<PropagatorStateEntry> compatible; Vector<WaveFormCollapse::PropagatorStateEntry> compatible;
compatible.resize(patterns.size()); compatible.resize(patterns.size());
// Iterate on every dy, dx, pattern1 and pattern2 // Iterate on every dy, dx, pattern1 and pattern2
@ -162,7 +182,7 @@ static Vector<PropagatorStateEntry> OverlappingWaveFormCollapse::generate_compat
for (uint32_t direction = 0; direction < 4; direction++) { for (uint32_t direction = 0; direction < 4; direction++) {
for (int pattern2 = 0; pattern2 < patterns.size(); pattern2++) { for (int pattern2 = 0; pattern2 < patterns.size(); pattern2++) {
if (agrees(patterns[pattern1], patterns[pattern2], directions_y[direction], directions_x[direction])) { if (agrees(patterns[pattern1], patterns[pattern2], directions_y[direction], directions_x[direction])) {
compatible[pattern1][direction].push_back(pattern2); compatible.write[pattern1].directions[direction].push_back(pattern2);
} }
} }
} }
@ -173,41 +193,41 @@ static Vector<PropagatorStateEntry> OverlappingWaveFormCollapse::generate_compat
// Transform a 2D array containing the patterns id to a 2D array containing the pixels. // Transform a 2D array containing the patterns id to a 2D array containing the pixels.
Array2D<uint32_t> OverlappingWaveFormCollapse::to_image(const Array2D<uint32_t> &output_patterns) const { Array2D<uint32_t> OverlappingWaveFormCollapse::to_image(const Array2D<uint32_t> &output_patterns) const {
Array2D<uint32_t> output = Array2D<uint32_t>(options.out_height, options.out_width); Array2D<uint32_t> output = Array2D<uint32_t>(out_height, out_width);
if (options.periodic_output) { if (periodic_output) {
for (uint32_t y = 0; y < options.get_wave_height(); y++) { for (uint32_t y = 0; y < get_wave_height(); y++) {
for (uint32_t x = 0; x < options.get_wave_width(); x++) { for (uint32_t x = 0; x < get_wave_width(); x++) {
output.get(y, x) = patterns[output_patterns.get(y, x)].get(0, 0); output.get(y, x) = patterns[output_patterns.get(y, x)].get(0, 0);
} }
} }
} else { } else {
for (uint32_t y = 0; y < options.get_wave_height(); y++) { for (uint32_t y = 0; y < get_wave_height(); y++) {
for (uint32_t x = 0; x < options.get_wave_width(); x++) { for (uint32_t x = 0; x < get_wave_width(); x++) {
output.get(y, x) = patterns[output_patterns.get(y, x)].get(0, 0); output.get(y, x) = patterns[output_patterns.get(y, x)].get(0, 0);
} }
} }
for (uint32_t y = 0; y < options.get_wave_height(); y++) { for (uint32_t y = 0; y < get_wave_height(); y++) {
const Array2D<uint32_t> &pattern = patterns[output_patterns.get(y, options.get_wave_width() - 1)]; const Array2D<uint32_t> &pattern = patterns[output_patterns.get(y, get_wave_width() - 1)];
for (uint32_t dx = 1; dx < options.pattern_size; dx++) { for (uint32_t dx = 1; dx < pattern_size; dx++) {
output.get(y, options.get_wave_width() - 1 + dx) = pattern.get(0, dx); output.get(y, get_wave_width() - 1 + dx) = pattern.get(0, dx);
} }
} }
for (uint32_t x = 0; x < options.get_wave_width(); x++) { for (uint32_t x = 0; x < get_wave_width(); x++) {
const Array2D<uint32_t> &pattern = patterns[output_patterns.get(options.get_wave_height() - 1, x)]; const Array2D<uint32_t> &pattern = patterns[output_patterns.get(get_wave_height() - 1, x)];
for (uint32_t dy = 1; dy < options.pattern_size; dy++) { for (uint32_t dy = 1; dy < pattern_size; dy++) {
output.get(options.get_wave_height() - 1 + dy, x) = output.get(get_wave_height() - 1 + dy, x) =
pattern.get(dy, 0); pattern.get(dy, 0);
} }
} }
const Array2D<uint32_t> &pattern = patterns[output_patterns.get(options.get_wave_height() - 1, options.get_wave_width() - 1)]; const Array2D<uint32_t> &pattern = patterns[output_patterns.get(get_wave_height() - 1, get_wave_width() - 1)];
for (uint32_t dy = 1; dy < options.pattern_size; dy++) { for (uint32_t dy = 1; dy < pattern_size; dy++) {
for (uint32_t dx = 1; dx < options.pattern_size; dx++) { for (uint32_t dx = 1; dx < pattern_size; dx++) {
output.get(options.get_wave_height() - 1 + dy, options.get_wave_width() - 1 + dx) = pattern.get(dy, dx); output.get(get_wave_height() - 1 + dy, get_wave_width() - 1 + dx) = pattern.get(dy, dx);
} }
} }
} }
@ -215,32 +235,25 @@ Array2D<uint32_t> OverlappingWaveFormCollapse::to_image(const Array2D<uint32_t>
return output; return output;
} }
/* void OverlappingWaveFormCollapse::initialize() {
// If necessary, the ground is set.
OverlappingWFC(const Array2D<uint32_t> &input, const OverlappingWFCOptions &options, int seed) : if (ground) {
OverlappingWFC(input, options, seed, get_patterns(input, options)) {} init_ground();
OverlappingWFC(
const Array2D<uint32_t> &input, const OverlappingWFCOptions &options,
const int &seed,
const std::pair<Vector<Array2D<uint32_t>>, Vector<double>> &patterns,
const Vector<PropagatorStateEntry> &propagator) :
input(input), options(options), patterns(patterns.first), wfc(options.periodic_output, seed, patterns.second, propagator, options.get_wave_height(), options.get_wave_width()) {
// If necessary, the ground is set.
if (options.ground) {
init_ground(wfc, input, patterns.first, options);
}
} }
OverlappingWFC(const Array2D<uint32_t> &input, const OverlappingWFCOptions &options, set_propagator_state(generate_compatible());
const int &seed,
const std::pair<Vector<Array2D<uint32_t>>, Vector<double>> WaveFormCollapse::initialize();
&patterns) : }
OverlappingWFC(input, options, seed, patterns,
generate_compatible(patterns.first)) {}
*/
OverlappingWaveFormCollapse::OverlappingWaveFormCollapse() { OverlappingWaveFormCollapse::OverlappingWaveFormCollapse() {
periodic_input = false;
periodic_output = false;
out_height = 0;
out_width = 0;
symmetry = 0;
ground = false;
pattern_size = 0;
} }
OverlappingWaveFormCollapse::~OverlappingWaveFormCollapse() { OverlappingWaveFormCollapse::~OverlappingWaveFormCollapse() {
} }

View File

@ -18,25 +18,29 @@ public:
bool ground; bool ground;
uint32_t pattern_size; uint32_t pattern_size;
void set_input(const Array2D<uint32_t> &data);
uint32_t get_wave_height() const; uint32_t get_wave_height() const;
uint32_t get_wave_width() const; uint32_t get_wave_width() const;
Array2D<uint32_t> run(); Array2D<uint32_t> orun();
void init_ground(const Array2D<uint32_t> &input, const Vector<Array2D<uint32_t>> &patterns, const OverlappingWFCOptions &options); void init_ground();
bool set_pattern(const Array2D<uint32_t> &pattern, uint32_t i, uint32_t j); bool set_pattern(const Array2D<uint32_t> &pattern, uint32_t i, uint32_t j);
static uint32_t get_ground_pattern_id(const Array2D<uint32_t> &input, const Vector<Array2D<uint32_t>> &patterns, const OverlappingWFCOptions &options); uint32_t get_ground_pattern_id();
uint32_t get_pattern_id(const Array2D<uint32_t> &pattern); uint32_t get_pattern_id(const Array2D<uint32_t> &pattern);
void set_pattern(uint32_t pattern_id, uint32_t i, uint32_t j); void set_pattern(uint32_t pattern_id, uint32_t i, uint32_t j);
static std::pair<Vector<Array2D<uint32_t>>, Vector<double>> get_patterns(const Array2D<uint32_t> &input, const OverlappingWFCOptions &options); void get_patterns();
static bool agrees(const Array2D<uint32_t> &pattern1, const Array2D<uint32_t> &pattern2, int dy, int dx); bool agrees(const Array2D<uint32_t> &pattern1, const Array2D<uint32_t> &pattern2, int dy, int dx);
static Vector<PropagatorStateEntry> generate_compatible(const Vector<Array2D<uint32_t>> &patterns); Vector<WaveFormCollapse::PropagatorStateEntry> generate_compatible();
Array2D<uint32_t> to_image(const Array2D<uint32_t> &output_patterns) const; Array2D<uint32_t> to_image(const Array2D<uint32_t> &output_patterns) const;
void initialize();
OverlappingWaveFormCollapse(); OverlappingWaveFormCollapse();
~OverlappingWaveFormCollapse(); ~OverlappingWaveFormCollapse();