diff --git a/modules/wfc/SCsub b/modules/wfc/SCsub index ea2afa7d9..081bc5314 100644 --- a/modules/wfc/SCsub +++ b/modules/wfc/SCsub @@ -3,6 +3,6 @@ Import("env") Import("env_modules") -env_network_synchronizer = env_modules.Clone() +env_wfc = env_modules.Clone() -env_network_synchronizer.add_source_files(env.modules_sources, "register_types.cpp") +env_wfc.add_source_files(env.modules_sources, "register_types.cpp") diff --git a/modules/wfc/array2D.hpp b/modules/wfc/array2D.hpp deleted file mode 100644 index 9fad0d80c..000000000 --- a/modules/wfc/array2D.hpp +++ /dev/null @@ -1,131 +0,0 @@ -#ifndef FAST_WFC_UTILS_ARRAY2D_HPP_ -#define FAST_WFC_UTILS_ARRAY2D_HPP_ - -#include "assert.h" -#include - -/** - * Represent a 2D array. - * The 2D array is stored in a single array, to improve cache usage. - */ -template class Array2D { - -public: - /** - * Height and width of the 2D array. - */ - std::size_t height; - std::size_t width; - - /** - * The array containing the data of the 2D array. - */ - std::vector data; - - /** - * Build a 2D array given its height and width. - * All the array elements are initialized to default value. - */ - Array2D(std::size_t height, std::size_t width) noexcept - : height(height), width(width), data(width * height) {} - - /** - * Build a 2D array given its height and width. - * All the array elements are initialized to value. - */ - Array2D(std::size_t height, std::size_t width, T value) noexcept - : height(height), width(width), data(width * height, value) {} - - /** - * Return a const reference to the element in the i-th line and j-th column. - * i must be lower than height and j lower than width. - */ - const T &get(std::size_t i, std::size_t j) const noexcept { - assert(i < height && j < width); - return data[j + i * width]; - } - - /** - * Return a reference to the element in the i-th line and j-th column. - * i must be lower than height and j lower than width. - */ - T &get(std::size_t i, std::size_t j) noexcept { - assert(i < height && j < width); - return data[j + i * width]; - } - - /** - * Return the current 2D array reflected along the x axis. - */ - Array2D reflected() const noexcept { - Array2D result = Array2D(width, height); - for (std::size_t y = 0; y < height; y++) { - for (std::size_t x = 0; x < width; x++) { - result.get(y, x) = get(y, width - 1 - x); - } - } - return result; - } - - /** - * Return the current 2D array rotated 90° anticlockwise - */ - Array2D rotated() const noexcept { - Array2D result = Array2D(width, height); - for (std::size_t y = 0; y < width; y++) { - for (std::size_t x = 0; x < height; x++) { - result.get(y, x) = get(x, width - 1 - y); - } - } - return result; - } - - /** - * Return the sub 2D array starting from (y,x) and with size (sub_width, - * sub_height). The current 2D array is considered toric for this operation. - */ - Array2D get_sub_array(std::size_t y, std::size_t x, std::size_t sub_width, - std::size_t sub_height) const noexcept { - Array2D sub_array_2d = Array2D(sub_width, sub_height); - for (std::size_t ki = 0; ki < sub_height; ki++) { - for (std::size_t kj = 0; kj < sub_width; kj++) { - sub_array_2d.get(ki, kj) = get((y + ki) % height, (x + kj) % width); - } - } - return sub_array_2d; - } - - /** - * Check if two 2D arrays are equals. - */ - bool operator==(const Array2D &a) const noexcept { - if (height != a.height || width != a.width) { - return false; - } - - for (std::size_t i = 0; i < data.size(); i++) { - if (a.data[i] != data[i]) { - return false; - } - } - return true; - } -}; - -/** - * Hash function. - */ -namespace std { -template class hash> { -public: - std::size_t operator()(const Array2D &a) const noexcept { - std::size_t seed = a.data.size(); - for (const T &i : a.data) { - seed ^= hash()(i) + (std::size_t)0x9e3779b9 + (seed << 6) + (seed >> 2); - } - return seed; - } -}; -} // namespace std - -#endif // FAST_WFC_UTILS_ARRAY2D_HPP_ diff --git a/modules/wfc/array3D.hpp b/modules/wfc/array3D.hpp deleted file mode 100644 index 68b7c42af..000000000 --- a/modules/wfc/array3D.hpp +++ /dev/null @@ -1,79 +0,0 @@ -#ifndef FAST_WFC_UTILS_ARRAY3D_HPP_ -#define FAST_WFC_UTILS_ARRAY3D_HPP_ - -#include "assert.h" -#include - -/** - * Represent a 3D array. - * The 3D array is stored in a single array, to improve cache usage. - */ -template class Array3D { - -public: - /** - * The dimensions of the 3D array. - */ - std::size_t height; - std::size_t width; - std::size_t depth; - - /** - * The array containing the data of the 3D array. - */ - std::vector data; - - /** - * Build a 2D array given its height, width and depth. - * All the arrays elements are initialized to default value. - */ - Array3D(std::size_t height, std::size_t width, std::size_t depth) noexcept - : height(height), width(width), depth(depth), - data(width * height * depth) {} - - /** - * Build a 2D array given its height, width and depth. - * All the arrays elements are initialized to value - */ - Array3D(std::size_t height, std::size_t width, std::size_t depth, - T value) noexcept - : height(height), width(width), depth(depth), - data(width * height * depth, value) {} - - /** - * Return a const reference to the element in the i-th line, j-th column, and - * k-th depth. i must be lower than height, j lower than width, and k lower - * than depth. - */ - const T &get(std::size_t i, std::size_t j, std::size_t k) const noexcept { - assert(i < height && j < width && k < depth); - return data[i * width * depth + j * depth + k]; - } - - /** - * Return a reference to the element in the i-th line, j-th column, and k-th - * depth. i must be lower than height, j lower than width, and k lower than - * depth. - */ - T &get(std::size_t i, std::size_t j, std::size_t k) noexcept { - return data[i * width * depth + j * depth + k]; - } - - /** - * Check if two 3D arrays are equals. - */ - bool operator==(const Array3D &a) const noexcept { - if (height != a.height || width != a.width || depth != a.depth) { - return false; - } - - for (std::size_t i = 0; i < data.size(); i++) { - if (a.data[i] != data[i]) { - return false; - } - } - return true; - } -}; - -#endif // FAST_WFC_UTILS_ARRAY3D_HPP_ diff --git a/modules/wfc/array_2d.h b/modules/wfc/array_2d.h new file mode 100644 index 000000000..da25be876 --- /dev/null +++ b/modules/wfc/array_2d.h @@ -0,0 +1,134 @@ +#ifndef FAST_WFC_UTILS_ARRAY2D_HPP_ +#define FAST_WFC_UTILS_ARRAY2D_HPP_ + +#include "assert.h" +#include + +/** + * Represent a 2D array. + * The 2D array is stored in a single array, to improve cache usage. + */ +template +class Array2D { +public: + /** + * Height and width of the 2D array. + */ + std::size_t height; + std::size_t width; + + /** + * The array containing the data of the 2D array. + */ + std::vector data; + + /** + * Build a 2D array given its height and width. + * All the array elements are initialized to default value. + */ + Array2D(std::size_t height, std::size_t width) noexcept + : + height(height), width(width), data(width * height) {} + + /** + * Build a 2D array given its height and width. + * All the array elements are initialized to value. + */ + Array2D(std::size_t height, std::size_t width, T value) noexcept + : + height(height), width(width), data(width * height, value) {} + + /** + * Return a const reference to the element in the i-th line and j-th column. + * i must be lower than height and j lower than width. + */ + const T &get(std::size_t i, std::size_t j) const noexcept { + assert(i < height && j < width); + return data[j + i * width]; + } + + /** + * Return a reference to the element in the i-th line and j-th column. + * i must be lower than height and j lower than width. + */ + T &get(std::size_t i, std::size_t j) noexcept { + assert(i < height && j < width); + return data[j + i * width]; + } + + /** + * Return the current 2D array reflected along the x axis. + */ + Array2D reflected() const noexcept { + Array2D result = Array2D(width, height); + for (std::size_t y = 0; y < height; y++) { + for (std::size_t x = 0; x < width; x++) { + result.get(y, x) = get(y, width - 1 - x); + } + } + return result; + } + + /** + * Return the current 2D array rotated 90° anticlockwise + */ + Array2D rotated() const noexcept { + Array2D result = Array2D(width, height); + for (std::size_t y = 0; y < width; y++) { + for (std::size_t x = 0; x < height; x++) { + result.get(y, x) = get(x, width - 1 - y); + } + } + return result; + } + + /** + * Return the sub 2D array starting from (y,x) and with size (sub_width, + * sub_height). The current 2D array is considered toric for this operation. + */ + Array2D get_sub_array(std::size_t y, std::size_t x, std::size_t sub_width, + std::size_t sub_height) const noexcept { + Array2D sub_array_2d = Array2D(sub_width, sub_height); + for (std::size_t ki = 0; ki < sub_height; ki++) { + for (std::size_t kj = 0; kj < sub_width; kj++) { + sub_array_2d.get(ki, kj) = get((y + ki) % height, (x + kj) % width); + } + } + return sub_array_2d; + } + + /** + * Check if two 2D arrays are equals. + */ + bool operator==(const Array2D &a) const noexcept { + if (height != a.height || width != a.width) { + return false; + } + + for (std::size_t i = 0; i < data.size(); i++) { + if (a.data[i] != data[i]) { + return false; + } + } + return true; + } +}; + +/** + * Hash function. + */ +namespace std { +template +class hash> { +public: + std::size_t operator()(const Array2D &a) const noexcept { + std::size_t seed = a.data.size(); + for (const T &i : a.data) { + seed ^= hash()(i) + (std::size_t)0x9e3779b9 + (seed << 6) + (seed >> 2); + } + return seed; + } +}; +} // namespace std + +#endif // FAST_WFC_UTILS_ARRAY2D_HPP_ diff --git a/modules/wfc/array_3d.h b/modules/wfc/array_3d.h new file mode 100644 index 000000000..8256ad6b9 --- /dev/null +++ b/modules/wfc/array_3d.h @@ -0,0 +1,79 @@ +#ifndef FAST_WFC_UTILS_ARRAY3D_HPP_ +#define FAST_WFC_UTILS_ARRAY3D_HPP_ + +#include "assert.h" +#include + +/** + * Represent a 3D array. + * The 3D array is stored in a single array, to improve cache usage. + */ +template +class Array3D { +public: + /** + * The dimensions of the 3D array. + */ + std::size_t height; + std::size_t width; + std::size_t depth; + + /** + * The array containing the data of the 3D array. + */ + std::vector data; + + /** + * Build a 2D array given its height, width and depth. + * All the arrays elements are initialized to default value. + */ + Array3D(std::size_t height, std::size_t width, std::size_t depth) noexcept + : + height(height), width(width), depth(depth), data(width * height * depth) {} + + /** + * Build a 2D array given its height, width and depth. + * All the arrays elements are initialized to value + */ + Array3D(std::size_t height, std::size_t width, std::size_t depth, + T value) noexcept + : + height(height), width(width), depth(depth), data(width * height * depth, value) {} + + /** + * Return a const reference to the element in the i-th line, j-th column, and + * k-th depth. i must be lower than height, j lower than width, and k lower + * than depth. + */ + const T &get(std::size_t i, std::size_t j, std::size_t k) const noexcept { + assert(i < height && j < width && k < depth); + return data[i * width * depth + j * depth + k]; + } + + /** + * Return a reference to the element in the i-th line, j-th column, and k-th + * depth. i must be lower than height, j lower than width, and k lower than + * depth. + */ + T &get(std::size_t i, std::size_t j, std::size_t k) noexcept { + return data[i * width * depth + j * depth + k]; + } + + /** + * Check if two 3D arrays are equals. + */ + bool operator==(const Array3D &a) const noexcept { + if (height != a.height || width != a.width || depth != a.depth) { + return false; + } + + for (std::size_t i = 0; i < data.size(); i++) { + if (a.data[i] != data[i]) { + return false; + } + } + return true; + } +}; + +#endif // FAST_WFC_UTILS_ARRAY3D_HPP_ diff --git a/modules/wfc/direction.h b/modules/wfc/direction.h new file mode 100644 index 000000000..f5bd2708f --- /dev/null +++ b/modules/wfc/direction.h @@ -0,0 +1,11 @@ +#ifndef FAST_WFC_DIRECTION_HPP_ +#define FAST_WFC_DIRECTION_HPP_ + +constexpr int directions_x[4] = { 0, -1, 1, 0 }; +constexpr int directions_y[4] = { -1, 0, 0, 1 }; + +constexpr unsigned get_opposite_direction(unsigned direction) noexcept { + return 3 - direction; +} + +#endif diff --git a/modules/wfc/direction.hpp b/modules/wfc/direction.hpp deleted file mode 100644 index 4a193a999..000000000 --- a/modules/wfc/direction.hpp +++ /dev/null @@ -1,18 +0,0 @@ -#ifndef FAST_WFC_DIRECTION_HPP_ -#define FAST_WFC_DIRECTION_HPP_ - -/** - * A direction is represented by an unsigned integer in the range [0; 3]. - * The x and y values of the direction can be retrieved in these tables. - */ -constexpr int directions_x[4] = {0, -1, 1, 0}; -constexpr int directions_y[4] = {-1, 0, 0, 1}; - -/** - * Return the opposite direction of direction. - */ -constexpr unsigned get_opposite_direction(unsigned direction) noexcept { - return 3 - direction; -} - -#endif // FAST_WFC_DIRECTION_HPP_ diff --git a/modules/wfc/overlapping_wfc.h b/modules/wfc/overlapping_wfc.h new file mode 100644 index 000000000..543ef2612 --- /dev/null +++ b/modules/wfc/overlapping_wfc.h @@ -0,0 +1,360 @@ +#ifndef FAST_WFC_OVERLAPPING_WFC_HPP_ +#define FAST_WFC_OVERLAPPING_WFC_HPP_ + +#include +#include +#include + +#include "array_2d.h" +#include "wfc.h" + +/** + * Options needed to use the overlapping wfc. + */ +struct OverlappingWFCOptions { + bool periodic_input; // True if the input is toric. + bool periodic_output; // True if the output is toric. + unsigned out_height; // The height of the output in pixels. + unsigned out_width; // The width of the output in pixels. + unsigned symmetry; // The number of symmetries (the order is defined in wfc). + bool ground; // True if the ground needs to be set (see init_ground). + unsigned pattern_size; // The width and height in pixel of the patterns. + + /** + * Get the wave height given these options. + */ + unsigned get_wave_height() const noexcept { + return periodic_output ? out_height : out_height - pattern_size + 1; + } + + /** + * Get the wave width given these options. + */ + unsigned get_wave_width() const noexcept { + return periodic_output ? out_width : out_width - pattern_size + 1; + } +}; + +/** + * Class generating a new image with the overlapping WFC algorithm. + */ +template +class OverlappingWFC { +private: + /** + * The input image. T is usually a color. + */ + Array2D input; + + /** + * Options needed by the algorithm. + */ + OverlappingWFCOptions options; + + /** + * The array of the different patterns extracted from the input. + */ + std::vector> patterns; + + /** + * The underlying generic WFC algorithm. + */ + WFC wfc; + + /** + * Constructor initializing the wfc. + * This constructor is called by the other constructors. + * This is necessary in order to initialize wfc only once. + */ + OverlappingWFC( + const Array2D &input, const OverlappingWFCOptions &options, + const int &seed, + const std::pair>, std::vector> &patterns, + const std::vector, 4>> + &propagator) noexcept + : + input(input), options(options), patterns(patterns.first), wfc(options.periodic_output, seed, patterns.second, propagator, options.get_wave_height(), options.get_wave_width()) { + // If necessary, the ground is set. + if (options.ground) { + init_ground(wfc, input, patterns.first, options); + } + } + + /** + * Constructor used only to call the other constructor with more computed + * parameters. + */ + OverlappingWFC(const Array2D &input, const OverlappingWFCOptions &options, + const int &seed, + const std::pair>, std::vector> + &patterns) noexcept + : + OverlappingWFC(input, options, seed, patterns, + generate_compatible(patterns.first)) {} + + /** + * Init the ground of the output image. + * The lowest middle pattern is used as a floor (and ceiling when the input is + * toric) and is placed at the lowest possible pattern position in the output + * image, on all its width. The pattern cannot be used at any other place in + * the output image. + */ + void init_ground(WFC &wfc, const Array2D &input, + const std::vector> &patterns, + const OverlappingWFCOptions &options) noexcept { + unsigned ground_pattern_id = + get_ground_pattern_id(input, patterns, options); + + // Place the pattern in the ground. + for (unsigned j = 0; j < options.get_wave_width(); j++) { + set_pattern(ground_pattern_id, options.get_wave_height() - 1, j); + } + + // Remove the pattern from the other positions. + for (unsigned i = 0; i < options.get_wave_height() - 1; i++) { + for (unsigned j = 0; j < options.get_wave_width(); j++) { + wfc.remove_wave_pattern(i, j, ground_pattern_id); + } + } + + // Propagate the information with wfc. + wfc.propagate(); + } + + /** + * Return the id of the lowest middle pattern. + */ + static unsigned + get_ground_pattern_id(const Array2D &input, + const std::vector> &patterns, + const OverlappingWFCOptions &options) noexcept { + // Get the pattern. + Array2D ground_pattern = + input.get_sub_array(input.height - 1, input.width / 2, + options.pattern_size, options.pattern_size); + + // Retrieve the id of the pattern. + for (unsigned i = 0; i < patterns.size(); i++) { + if (ground_pattern == patterns[i]) { + return i; + } + } + + // The pattern exists. + assert(false); + return 0; + } + + /** + * Return the list of patterns, as well as their probabilities of apparition. + */ + static std::pair>, std::vector> + get_patterns(const Array2D &input, + const OverlappingWFCOptions &options) noexcept { + std::unordered_map, unsigned> patterns_id; + std::vector> patterns; + + // The number of time a pattern is seen in the input image. + std::vector patterns_weight; + + std::vector> symmetries( + 8, Array2D(options.pattern_size, options.pattern_size)); + unsigned max_i = options.periodic_input + ? input.height + : input.height - options.pattern_size + 1; + unsigned max_j = options.periodic_input + ? input.width + : input.width - options.pattern_size + 1; + + for (unsigned i = 0; i < max_i; i++) { + for (unsigned j = 0; j < max_j; j++) { + // Compute the symmetries of every pattern in the image. + symmetries[0].data = + input + .get_sub_array(i, j, options.pattern_size, options.pattern_size) + .data; + symmetries[1].data = symmetries[0].reflected().data; + symmetries[2].data = symmetries[0].rotated().data; + symmetries[3].data = symmetries[2].reflected().data; + symmetries[4].data = symmetries[2].rotated().data; + symmetries[5].data = symmetries[4].reflected().data; + symmetries[6].data = symmetries[4].rotated().data; + symmetries[7].data = symmetries[6].reflected().data; + + // The number of symmetries in the option class define which symetries + // will be used. + for (unsigned k = 0; k < options.symmetry; k++) { + auto res = patterns_id.insert( + std::make_pair(symmetries[k], patterns.size())); + + // If the pattern already exist, we just have to increase its number + // of appearance. + if (!res.second) { + patterns_weight[res.first->second] += 1; + } else { + patterns.push_back(symmetries[k]); + patterns_weight.push_back(1); + } + } + } + } + + return { patterns, patterns_weight }; + } + + /** + * Return true if the pattern1 is compatible with pattern2 + * when pattern2 is at a distance (dy,dx) from pattern1. + */ + static bool agrees(const Array2D &pattern1, const Array2D &pattern2, + int dy, int dx) noexcept { + unsigned xmin = dx < 0 ? 0 : dx; + unsigned xmax = dx < 0 ? dx + pattern2.width : pattern1.width; + unsigned ymin = dy < 0 ? 0 : dy; + unsigned ymax = dy < 0 ? dy + pattern2.height : pattern1.width; + + // Iterate on every pixel contained in the intersection of the two pattern. + for (unsigned y = ymin; y < ymax; y++) { + for (unsigned x = xmin; x < xmax; x++) { + // Check if the color is the same in the two patterns in that pixel. + if (pattern1.get(y, x) != pattern2.get(y - dy, x - dx)) { + return false; + } + } + } + return true; + } + + /** + * Precompute the function agrees(pattern1, pattern2, dy, dx). + * If agrees(pattern1, pattern2, dy, dx), then compatible[pattern1][direction] + * contains pattern2, where direction is the direction defined by (dy, dx) + * (see direction.hpp). + */ + static std::vector, 4>> + generate_compatible(const std::vector> &patterns) noexcept { + std::vector, 4>> compatible = + std::vector, 4>>(patterns.size()); + + // Iterate on every dy, dx, pattern1 and pattern2 + for (unsigned pattern1 = 0; pattern1 < patterns.size(); pattern1++) { + for (unsigned direction = 0; direction < 4; direction++) { + for (unsigned pattern2 = 0; pattern2 < patterns.size(); pattern2++) { + if (agrees(patterns[pattern1], patterns[pattern2], + directions_y[direction], directions_x[direction])) { + compatible[pattern1][direction].push_back(pattern2); + } + } + } + } + + return compatible; + } + + /** + * Transform a 2D array containing the patterns id to a 2D array containing + * the pixels. + */ + Array2D to_image(const Array2D &output_patterns) const noexcept { + Array2D output = Array2D(options.out_height, options.out_width); + + if (options.periodic_output) { + for (unsigned y = 0; y < options.get_wave_height(); y++) { + for (unsigned x = 0; x < options.get_wave_width(); x++) { + output.get(y, x) = patterns[output_patterns.get(y, x)].get(0, 0); + } + } + } else { + for (unsigned y = 0; y < options.get_wave_height(); y++) { + for (unsigned x = 0; x < options.get_wave_width(); x++) { + output.get(y, x) = patterns[output_patterns.get(y, x)].get(0, 0); + } + } + for (unsigned y = 0; y < options.get_wave_height(); y++) { + const Array2D &pattern = + patterns[output_patterns.get(y, options.get_wave_width() - 1)]; + for (unsigned dx = 1; dx < options.pattern_size; dx++) { + output.get(y, options.get_wave_width() - 1 + dx) = pattern.get(0, dx); + } + } + for (unsigned x = 0; x < options.get_wave_width(); x++) { + const Array2D &pattern = + patterns[output_patterns.get(options.get_wave_height() - 1, x)]; + for (unsigned dy = 1; dy < options.pattern_size; dy++) { + output.get(options.get_wave_height() - 1 + dy, x) = + pattern.get(dy, 0); + } + } + const Array2D &pattern = patterns[output_patterns.get( + options.get_wave_height() - 1, options.get_wave_width() - 1)]; + for (unsigned dy = 1; dy < options.pattern_size; dy++) { + for (unsigned dx = 1; dx < options.pattern_size; dx++) { + output.get(options.get_wave_height() - 1 + dy, + options.get_wave_width() - 1 + dx) = pattern.get(dy, dx); + } + } + } + + return output; + } + + std::optional get_pattern_id(const Array2D &pattern) { + unsigned *pattern_id = std::find(patterns.begin(), patterns.end(), pattern); + + if (pattern_id != patterns.end()) { + return *pattern_id; + } + + return std::nullopt; + } + + /** + * Set the pattern at a specific position, given its pattern id + * pattern_id needs to be a valid pattern id, and i and j needs to be in the wave range + */ + void set_pattern(unsigned pattern_id, unsigned i, unsigned j) noexcept { + for (unsigned p = 0; p < patterns.size(); p++) { + if (pattern_id != p) { + wfc.remove_wave_pattern(i, j, p); + } + } + } + +public: + /** + * The constructor used by the user. + */ + OverlappingWFC(const Array2D &input, const OverlappingWFCOptions &options, + int seed) noexcept + : + OverlappingWFC(input, options, seed, get_patterns(input, options)) {} + + /** + * Set the pattern at a specific position. + * Returns false if the given pattern does not exist, or if the + * coordinates are not in the wave + */ + bool set_pattern(const Array2D &pattern, unsigned i, unsigned j) noexcept { + auto pattern_id = get_pattern_id(pattern); + + if (pattern_id == std::nullopt || i >= options.get_wave_height() || j >= options.get_wave_width()) { + return false; + } + + set_pattern(pattern_id, i, j); + return true; + } + + /** + * Run the WFC algorithm, and return the result if the algorithm succeeded. + */ + std::optional> run() noexcept { + std::optional> result = wfc.run(); + if (result.has_value()) { + return to_image(*result); + } + return std::nullopt; + } +}; + +#endif // FAST_WFC_WFC_HPP_ diff --git a/modules/wfc/overlapping_wfc.hpp b/modules/wfc/overlapping_wfc.hpp deleted file mode 100644 index b899e5961..000000000 --- a/modules/wfc/overlapping_wfc.hpp +++ /dev/null @@ -1,359 +0,0 @@ -#ifndef FAST_WFC_OVERLAPPING_WFC_HPP_ -#define FAST_WFC_OVERLAPPING_WFC_HPP_ - -#include -#include -#include - -#include "utils/array2D.hpp" -#include "wfc.hpp" - -/** - * Options needed to use the overlapping wfc. - */ -struct OverlappingWFCOptions { - bool periodic_input; // True if the input is toric. - bool periodic_output; // True if the output is toric. - unsigned out_height; // The height of the output in pixels. - unsigned out_width; // The width of the output in pixels. - unsigned symmetry; // The number of symmetries (the order is defined in wfc). - bool ground; // True if the ground needs to be set (see init_ground). - unsigned pattern_size; // The width and height in pixel of the patterns. - - /** - * Get the wave height given these options. - */ - unsigned get_wave_height() const noexcept { - return periodic_output ? out_height : out_height - pattern_size + 1; - } - - /** - * Get the wave width given these options. - */ - unsigned get_wave_width() const noexcept { - return periodic_output ? out_width : out_width - pattern_size + 1; - } -}; - -/** - * Class generating a new image with the overlapping WFC algorithm. - */ -template class OverlappingWFC { - -private: - /** - * The input image. T is usually a color. - */ - Array2D input; - - /** - * Options needed by the algorithm. - */ - OverlappingWFCOptions options; - - /** - * The array of the different patterns extracted from the input. - */ - std::vector> patterns; - - /** - * The underlying generic WFC algorithm. - */ - WFC wfc; - - /** - * Constructor initializing the wfc. - * This constructor is called by the other constructors. - * This is necessary in order to initialize wfc only once. - */ - OverlappingWFC( - const Array2D &input, const OverlappingWFCOptions &options, - const int &seed, - const std::pair>, std::vector> &patterns, - const std::vector, 4>> - &propagator) noexcept - : input(input), options(options), patterns(patterns.first), - wfc(options.periodic_output, seed, patterns.second, propagator, - options.get_wave_height(), options.get_wave_width()) { - // If necessary, the ground is set. - if (options.ground) { - init_ground(wfc, input, patterns.first, options); - } - } - - /** - * Constructor used only to call the other constructor with more computed - * parameters. - */ - OverlappingWFC(const Array2D &input, const OverlappingWFCOptions &options, - const int &seed, - const std::pair>, std::vector> - &patterns) noexcept - : OverlappingWFC(input, options, seed, patterns, - generate_compatible(patterns.first)) {} - - /** - * Init the ground of the output image. - * The lowest middle pattern is used as a floor (and ceiling when the input is - * toric) and is placed at the lowest possible pattern position in the output - * image, on all its width. The pattern cannot be used at any other place in - * the output image. - */ - void init_ground(WFC &wfc, const Array2D &input, - const std::vector> &patterns, - const OverlappingWFCOptions &options) noexcept { - unsigned ground_pattern_id = - get_ground_pattern_id(input, patterns, options); - - // Place the pattern in the ground. - for (unsigned j = 0; j < options.get_wave_width(); j++) { - set_pattern(ground_pattern_id, options.get_wave_height() - 1, j); - } - - // Remove the pattern from the other positions. - for (unsigned i = 0; i < options.get_wave_height() - 1; i++) { - for (unsigned j = 0; j < options.get_wave_width(); j++) { - wfc.remove_wave_pattern(i, j, ground_pattern_id); - } - } - - // Propagate the information with wfc. - wfc.propagate(); - } - - /** - * Return the id of the lowest middle pattern. - */ - static unsigned - get_ground_pattern_id(const Array2D &input, - const std::vector> &patterns, - const OverlappingWFCOptions &options) noexcept { - // Get the pattern. - Array2D ground_pattern = - input.get_sub_array(input.height - 1, input.width / 2, - options.pattern_size, options.pattern_size); - - // Retrieve the id of the pattern. - for (unsigned i = 0; i < patterns.size(); i++) { - if (ground_pattern == patterns[i]) { - return i; - } - } - - // The pattern exists. - assert(false); - return 0; - } - - /** - * Return the list of patterns, as well as their probabilities of apparition. - */ - static std::pair>, std::vector> - get_patterns(const Array2D &input, - const OverlappingWFCOptions &options) noexcept { - std::unordered_map, unsigned> patterns_id; - std::vector> patterns; - - // The number of time a pattern is seen in the input image. - std::vector patterns_weight; - - std::vector> symmetries( - 8, Array2D(options.pattern_size, options.pattern_size)); - unsigned max_i = options.periodic_input - ? input.height - : input.height - options.pattern_size + 1; - unsigned max_j = options.periodic_input - ? input.width - : input.width - options.pattern_size + 1; - - for (unsigned i = 0; i < max_i; i++) { - for (unsigned j = 0; j < max_j; j++) { - // Compute the symmetries of every pattern in the image. - symmetries[0].data = - input - .get_sub_array(i, j, options.pattern_size, options.pattern_size) - .data; - symmetries[1].data = symmetries[0].reflected().data; - symmetries[2].data = symmetries[0].rotated().data; - symmetries[3].data = symmetries[2].reflected().data; - symmetries[4].data = symmetries[2].rotated().data; - symmetries[5].data = symmetries[4].reflected().data; - symmetries[6].data = symmetries[4].rotated().data; - symmetries[7].data = symmetries[6].reflected().data; - - // The number of symmetries in the option class define which symetries - // will be used. - for (unsigned k = 0; k < options.symmetry; k++) { - auto res = patterns_id.insert( - std::make_pair(symmetries[k], patterns.size())); - - // If the pattern already exist, we just have to increase its number - // of appearance. - if (!res.second) { - patterns_weight[res.first->second] += 1; - } else { - patterns.push_back(symmetries[k]); - patterns_weight.push_back(1); - } - } - } - } - - return {patterns, patterns_weight}; - } - - /** - * Return true if the pattern1 is compatible with pattern2 - * when pattern2 is at a distance (dy,dx) from pattern1. - */ - static bool agrees(const Array2D &pattern1, const Array2D &pattern2, - int dy, int dx) noexcept { - unsigned xmin = dx < 0 ? 0 : dx; - unsigned xmax = dx < 0 ? dx + pattern2.width : pattern1.width; - unsigned ymin = dy < 0 ? 0 : dy; - unsigned ymax = dy < 0 ? dy + pattern2.height : pattern1.width; - - // Iterate on every pixel contained in the intersection of the two pattern. - for (unsigned y = ymin; y < ymax; y++) { - for (unsigned x = xmin; x < xmax; x++) { - // Check if the color is the same in the two patterns in that pixel. - if (pattern1.get(y, x) != pattern2.get(y - dy, x - dx)) { - return false; - } - } - } - return true; - } - - /** - * Precompute the function agrees(pattern1, pattern2, dy, dx). - * If agrees(pattern1, pattern2, dy, dx), then compatible[pattern1][direction] - * contains pattern2, where direction is the direction defined by (dy, dx) - * (see direction.hpp). - */ - static std::vector, 4>> - generate_compatible(const std::vector> &patterns) noexcept { - std::vector, 4>> compatible = - std::vector, 4>>(patterns.size()); - - // Iterate on every dy, dx, pattern1 and pattern2 - for (unsigned pattern1 = 0; pattern1 < patterns.size(); pattern1++) { - for (unsigned direction = 0; direction < 4; direction++) { - for (unsigned pattern2 = 0; pattern2 < patterns.size(); pattern2++) { - if (agrees(patterns[pattern1], patterns[pattern2], - directions_y[direction], directions_x[direction])) { - compatible[pattern1][direction].push_back(pattern2); - } - } - } - } - - return compatible; - } - - /** - * Transform a 2D array containing the patterns id to a 2D array containing - * the pixels. - */ - Array2D to_image(const Array2D &output_patterns) const noexcept { - Array2D output = Array2D(options.out_height, options.out_width); - - if (options.periodic_output) { - for (unsigned y = 0; y < options.get_wave_height(); y++) { - for (unsigned x = 0; x < options.get_wave_width(); x++) { - output.get(y, x) = patterns[output_patterns.get(y, x)].get(0, 0); - } - } - } else { - for (unsigned y = 0; y < options.get_wave_height(); y++) { - for (unsigned x = 0; x < options.get_wave_width(); x++) { - output.get(y, x) = patterns[output_patterns.get(y, x)].get(0, 0); - } - } - for (unsigned y = 0; y < options.get_wave_height(); y++) { - const Array2D &pattern = - patterns[output_patterns.get(y, options.get_wave_width() - 1)]; - for (unsigned dx = 1; dx < options.pattern_size; dx++) { - output.get(y, options.get_wave_width() - 1 + dx) = pattern.get(0, dx); - } - } - for (unsigned x = 0; x < options.get_wave_width(); x++) { - const Array2D &pattern = - patterns[output_patterns.get(options.get_wave_height() - 1, x)]; - for (unsigned dy = 1; dy < options.pattern_size; dy++) { - output.get(options.get_wave_height() - 1 + dy, x) = - pattern.get(dy, 0); - } - } - const Array2D &pattern = patterns[output_patterns.get( - options.get_wave_height() - 1, options.get_wave_width() - 1)]; - for (unsigned dy = 1; dy < options.pattern_size; dy++) { - for (unsigned dx = 1; dx < options.pattern_size; dx++) { - output.get(options.get_wave_height() - 1 + dy, - options.get_wave_width() - 1 + dx) = pattern.get(dy, dx); - } - } - } - - return output; - } - - std::optional get_pattern_id(const Array2D &pattern) { - unsigned* pattern_id = std::find(patterns.begin(), patterns.end(), pattern); - - if (pattern_id != patterns.end()) { - return *pattern_id; - } - - return std::nullopt; - } - - /** - * Set the pattern at a specific position, given its pattern id - * pattern_id needs to be a valid pattern id, and i and j needs to be in the wave range - */ - void set_pattern(unsigned pattern_id, unsigned i, unsigned j) noexcept { - for (unsigned p = 0; p < patterns.size(); p++) { - if (pattern_id != p) { - wfc.remove_wave_pattern(i, j, p); - } - } - } - -public: - /** - * The constructor used by the user. - */ - OverlappingWFC(const Array2D &input, const OverlappingWFCOptions &options, - int seed) noexcept - : OverlappingWFC(input, options, seed, get_patterns(input, options)) {} - - /** - * Set the pattern at a specific position. - * Returns false if the given pattern does not exist, or if the - * coordinates are not in the wave - */ - bool set_pattern(const Array2D& pattern, unsigned i, unsigned j) noexcept { - auto pattern_id = get_pattern_id(pattern); - - if (pattern_id == std::nullopt || i >= options.get_wave_height() || j >= options.get_wave_width()) { - return false; - } - - set_pattern(pattern_id, i, j); - return true; - } - - /** - * Run the WFC algorithm, and return the result if the algorithm succeeded. - */ - std::optional> run() noexcept { - std::optional> result = wfc.run(); - if (result.has_value()) { - return to_image(*result); - } - return std::nullopt; - } -}; - -#endif // FAST_WFC_WFC_HPP_ diff --git a/modules/wfc/propagator.cpp b/modules/wfc/propagator.cpp index 98e2e9fb3..05b5e076f 100644 --- a/modules/wfc/propagator.cpp +++ b/modules/wfc/propagator.cpp @@ -1,77 +1,73 @@ -#include "propagator.hpp" -#include "wave.hpp" +#include "propagator.h" +#include "wave.h" void Propagator::init_compatible() noexcept { - std::array value; - // We compute the number of pattern compatible in all directions. - for (unsigned y = 0; y < wave_height; y++) { - for (unsigned x = 0; x < wave_width; x++) { - for (unsigned pattern = 0; pattern < patterns_size; pattern++) { - for (int direction = 0; direction < 4; direction++) { - value[direction] = - static_cast(propagator_state[pattern][get_opposite_direction(direction)] - .size()); - } - compatible.get(y, x, pattern) = value; - } - } - } + std::array value; + // We compute the number of pattern compatible in all directions. + for (unsigned y = 0; y < wave_height; y++) { + for (unsigned x = 0; x < wave_width; x++) { + for (unsigned pattern = 0; pattern < patterns_size; pattern++) { + for (int direction = 0; direction < 4; direction++) { + value[direction] = + static_cast(propagator_state[pattern][get_opposite_direction(direction)] + .size()); + } + compatible.get(y, x, pattern) = value; + } + } + } } void Propagator::propagate(Wave &wave) noexcept { + // We propagate every element while there is element to propagate. + while (propagating.size() != 0) { + // The cell and pattern that has been set to false. + unsigned y1, x1, pattern; + std::tie(y1, x1, pattern) = propagating.back(); + propagating.pop_back(); - // We propagate every element while there is element to propagate. - while (propagating.size() != 0) { + // We propagate the information in all 4 directions. + for (unsigned direction = 0; direction < 4; direction++) { + // We get the next cell in the direction direction. + int dx = directions_x[direction]; + int dy = directions_y[direction]; + int x2, y2; + if (periodic_output) { + x2 = ((int)x1 + dx + (int)wave.width) % wave.width; + y2 = ((int)y1 + dy + (int)wave.height) % wave.height; + } else { + x2 = x1 + dx; + y2 = y1 + dy; + if (x2 < 0 || x2 >= (int)wave.width) { + continue; + } + if (y2 < 0 || y2 >= (int)wave.height) { + continue; + } + } - // The cell and pattern that has been set to false. - unsigned y1, x1, pattern; - std::tie(y1, x1, pattern) = propagating.back(); - propagating.pop_back(); + // The index of the second cell, and the patterns compatible + unsigned i2 = x2 + y2 * wave.width; + const std::vector &patterns = + propagator_state[pattern][direction]; - // We propagate the information in all 4 directions. - for (unsigned direction = 0; direction < 4; direction++) { + // For every pattern that could be placed in that cell without being in + // contradiction with pattern1 + for (auto it = patterns.begin(), it_end = patterns.end(); it < it_end; + ++it) { + // We decrease the number of compatible patterns in the opposite + // direction If the pattern was discarded from the wave, the element + // is still negative, which is not a problem + std::array &value = compatible.get(y2, x2, *it); + value[direction]--; - // We get the next cell in the direction direction. - int dx = directions_x[direction]; - int dy = directions_y[direction]; - int x2, y2; - if (periodic_output) { - x2 = ((int)x1 + dx + (int)wave.width) % wave.width; - y2 = ((int)y1 + dy + (int)wave.height) % wave.height; - } else { - x2 = x1 + dx; - y2 = y1 + dy; - if (x2 < 0 || x2 >= (int)wave.width) { - continue; - } - if (y2 < 0 || y2 >= (int)wave.height) { - continue; - } - } - - // The index of the second cell, and the patterns compatible - unsigned i2 = x2 + y2 * wave.width; - const std::vector &patterns = - propagator_state[pattern][direction]; - - // For every pattern that could be placed in that cell without being in - // contradiction with pattern1 - for (auto it = patterns.begin(), it_end = patterns.end(); it < it_end; - ++it) { - - // We decrease the number of compatible patterns in the opposite - // direction If the pattern was discarded from the wave, the element - // is still negative, which is not a problem - std::array &value = compatible.get(y2, x2, *it); - value[direction]--; - - // If the element was set to 0 with this operation, we need to remove - // the pattern from the wave, and propagate the information - if (value[direction] == 0) { - add_to_propagator(y2, x2, *it); - wave.set(i2, *it, false); - } - } - } - } + // If the element was set to 0 with this operation, we need to remove + // the pattern from the wave, and propagate the information + if (value[direction] == 0) { + add_to_propagator(y2, x2, *it); + wave.set(i2, *it, false); + } + } + } + } } diff --git a/modules/wfc/propagator.h b/modules/wfc/propagator.h new file mode 100644 index 000000000..d7dd0b786 --- /dev/null +++ b/modules/wfc/propagator.h @@ -0,0 +1,96 @@ +#ifndef FAST_WFC_PROPAGATOR_HPP_ +#define FAST_WFC_PROPAGATOR_HPP_ + +#include "direction.h" +#include "array_3d.h" +#include +#include +#include + +class Wave; + +/** + * Propagate information about patterns in the wave. + */ +class Propagator { +public: + using PropagatorState = std::vector, 4>>; + +private: + /** + * The size of the patterns. + */ + const std::size_t patterns_size; + + /** + * propagator[pattern1][direction] contains all the patterns that can + * be placed in next to pattern1 in the direction direction. + */ + PropagatorState propagator_state; + + /** + * The wave width and height. + */ + const unsigned wave_width; + const unsigned wave_height; + + /** + * True if the wave and the output is toric. + */ + const bool periodic_output; + + /** + * All the tuples (y, x, pattern) that should be propagated. + * The tuple should be propagated when wave.get(y, x, pattern) is set to + * false. + */ + std::vector> propagating; + + /** + * compatible.get(y, x, pattern)[direction] contains the number of patterns + * present in the wave that can be placed in the cell next to (y,x) in the + * opposite direction of direction without being in contradiction with pattern + * placed in (y,x). If wave.get(y, x, pattern) is set to false, then + * compatible.get(y, x, pattern) has every element negative or null + */ + Array3D> compatible; + + /** + * Initialize compatible. + */ + void init_compatible() noexcept; + +public: + /** + * Constructor building the propagator and initializing compatible. + */ + Propagator(unsigned wave_height, unsigned wave_width, bool periodic_output, + PropagatorState propagator_state) noexcept + : + patterns_size(propagator_state.size()), + propagator_state(propagator_state), + wave_width(wave_width), + wave_height(wave_height), + periodic_output(periodic_output), + compatible(wave_height, wave_width, patterns_size) { + init_compatible(); + } + + /** + * Add an element to the propagator. + * This function is called when wave.get(y, x, pattern) is set to false. + */ + void add_to_propagator(unsigned y, unsigned x, unsigned pattern) noexcept { + // All the direction are set to 0, since the pattern cannot be set in (y,x). + std::array temp = {}; + compatible.get(y, x, pattern) = temp; + propagating.emplace_back(y, x, pattern); + } + + /** + * Propagate the information given with add_to_propagator. + */ + void propagate(Wave &wave) noexcept; +}; + +#endif // FAST_WFC_PROPAGATOR_HPP_ diff --git a/modules/wfc/propagator.hpp b/modules/wfc/propagator.hpp deleted file mode 100644 index cb4324c0a..000000000 --- a/modules/wfc/propagator.hpp +++ /dev/null @@ -1,93 +0,0 @@ -#ifndef FAST_WFC_PROPAGATOR_HPP_ -#define FAST_WFC_PROPAGATOR_HPP_ - -#include "direction.hpp" -#include "utils/array3D.hpp" -#include -#include -#include - -class Wave; - -/** - * Propagate information about patterns in the wave. - */ -class Propagator { -public: - using PropagatorState = std::vector, 4>>; - -private: - /** - * The size of the patterns. - */ - const std::size_t patterns_size; - - /** - * propagator[pattern1][direction] contains all the patterns that can - * be placed in next to pattern1 in the direction direction. - */ - PropagatorState propagator_state; - - /** - * The wave width and height. - */ - const unsigned wave_width; - const unsigned wave_height; - - /** - * True if the wave and the output is toric. - */ - const bool periodic_output; - - /** - * All the tuples (y, x, pattern) that should be propagated. - * The tuple should be propagated when wave.get(y, x, pattern) is set to - * false. - */ - std::vector> propagating; - - /** - * compatible.get(y, x, pattern)[direction] contains the number of patterns - * present in the wave that can be placed in the cell next to (y,x) in the - * opposite direction of direction without being in contradiction with pattern - * placed in (y,x). If wave.get(y, x, pattern) is set to false, then - * compatible.get(y, x, pattern) has every element negative or null - */ - Array3D> compatible; - - /** - * Initialize compatible. - */ - void init_compatible() noexcept; - -public: - /** - * Constructor building the propagator and initializing compatible. - */ - Propagator(unsigned wave_height, unsigned wave_width, bool periodic_output, - PropagatorState propagator_state) noexcept - : patterns_size(propagator_state.size()), - propagator_state(propagator_state), wave_width(wave_width), - wave_height(wave_height), periodic_output(periodic_output), - compatible(wave_height, wave_width, patterns_size) { - init_compatible(); - } - - /** - * Add an element to the propagator. - * This function is called when wave.get(y, x, pattern) is set to false. - */ - void add_to_propagator(unsigned y, unsigned x, unsigned pattern) noexcept { - // All the direction are set to 0, since the pattern cannot be set in (y,x). - std::array temp = {}; - compatible.get(y, x, pattern) = temp; - propagating.emplace_back(y, x, pattern); - } - - /** - * Propagate the information given with add_to_propagator. - */ - void propagate(Wave &wave) noexcept; -}; - -#endif // FAST_WFC_PROPAGATOR_HPP_ diff --git a/modules/wfc/tiling_wfc.h b/modules/wfc/tiling_wfc.h new file mode 100644 index 000000000..7ca229365 --- /dev/null +++ b/modules/wfc/tiling_wfc.h @@ -0,0 +1,407 @@ +#ifndef FAST_WFC_TILING_WFC_HPP_ +#define FAST_WFC_TILING_WFC_HPP_ + +#include +#include + +#include "array_2d.h" +#include "wfc.h" + +/** + * The distinct symmetries of a tile. + * It represents how the tile behave when it is rotated or reflected + */ +enum class Symmetry { + X, + T, + I, + L, + backslash, + P +}; + +/** + * Return the number of possible distinct orientations for a tile. + * An orientation is a combination of rotations and reflections. + */ +constexpr unsigned nb_of_possible_orientations(const Symmetry &symmetry) { + switch (symmetry) { + case Symmetry::X: + return 1; + case Symmetry::I: + case Symmetry::backslash: + return 2; + case Symmetry::T: + case Symmetry::L: + return 4; + default: + return 8; + } +} + +/** + * A tile that can be placed on the board. + */ +template +struct Tile { + std::vector> data; // The different orientations of the tile + Symmetry symmetry; // The symmetry of the tile + double weight; // Its weight on the distribution of presence of tiles + + /** + * Generate the map associating an orientation id to the orientation + * id obtained when rotating 90° anticlockwise the tile. + */ + static std::vector + generate_rotation_map(const Symmetry &symmetry) noexcept { + switch (symmetry) { + case Symmetry::X: + return { 0 }; + case Symmetry::I: + case Symmetry::backslash: + return { 1, 0 }; + case Symmetry::T: + case Symmetry::L: + return { 1, 2, 3, 0 }; + case Symmetry::P: + default: + return { 1, 2, 3, 0, 5, 6, 7, 4 }; + } + } + + /** + * Generate the map associating an orientation id to the orientation + * id obtained when reflecting the tile along the x axis. + */ + static std::vector + generate_reflection_map(const Symmetry &symmetry) noexcept { + switch (symmetry) { + case Symmetry::X: + return { 0 }; + case Symmetry::I: + return { 0, 1 }; + case Symmetry::backslash: + return { 1, 0 }; + case Symmetry::T: + return { 0, 3, 2, 1 }; + case Symmetry::L: + return { 1, 0, 3, 2 }; + case Symmetry::P: + default: + return { 4, 7, 6, 5, 0, 3, 2, 1 }; + } + } + + /** + * Generate the map associating an orientation id and an action to the + * resulting orientation id. + * Actions 0, 1, 2, and 3 are 0°, 90°, 180°, and 270° anticlockwise rotations. + * Actions 4, 5, 6, and 7 are actions 0, 1, 2, and 3 preceded by a reflection + * on the x axis. + */ + static std::vector> + generate_action_map(const Symmetry &symmetry) noexcept { + std::vector rotation_map = generate_rotation_map(symmetry); + std::vector reflection_map = generate_reflection_map(symmetry); + size_t size = rotation_map.size(); + std::vector> action_map(8, + std::vector(size)); + for (size_t i = 0; i < size; ++i) { + action_map[0][i] = i; + } + + for (size_t a = 1; a < 4; ++a) { + for (size_t i = 0; i < size; ++i) { + action_map[a][i] = rotation_map[action_map[a - 1][i]]; + } + } + for (size_t i = 0; i < size; ++i) { + action_map[4][i] = reflection_map[action_map[0][i]]; + } + for (size_t a = 5; a < 8; ++a) { + for (size_t i = 0; i < size; ++i) { + action_map[a][i] = rotation_map[action_map[a - 1][i]]; + } + } + return action_map; + } + + /** + * Generate all distincts rotations of a 2D array given its symmetries; + */ + static std::vector> generate_oriented(Array2D data, + Symmetry symmetry) noexcept { + std::vector> oriented; + oriented.push_back(data); + + switch (symmetry) { + case Symmetry::I: + case Symmetry::backslash: + oriented.push_back(data.rotated()); + break; + case Symmetry::T: + case Symmetry::L: + oriented.push_back(data = data.rotated()); + oriented.push_back(data = data.rotated()); + oriented.push_back(data = data.rotated()); + break; + case Symmetry::P: + oriented.push_back(data = data.rotated()); + oriented.push_back(data = data.rotated()); + oriented.push_back(data = data.rotated()); + oriented.push_back(data = data.rotated().reflected()); + oriented.push_back(data = data.rotated()); + oriented.push_back(data = data.rotated()); + oriented.push_back(data = data.rotated()); + break; + default: + break; + } + + return oriented; + } + + /** + * Create a tile with its differents orientations, its symmetries and its + * weight on the distribution of tiles. + */ + Tile(std::vector> data, Symmetry symmetry, double weight) noexcept + : + data(data), symmetry(symmetry), weight(weight) {} + + /* + * Create a tile with its base orientation, its symmetries and its + * weight on the distribution of tiles. + * The other orientations are generated with its first one. + */ + Tile(Array2D data, Symmetry symmetry, double weight) noexcept + : + data(generate_oriented(data, symmetry)), symmetry(symmetry), weight(weight) {} +}; + +/** + * Options needed to use the tiling wfc. + */ +struct TilingWFCOptions { + bool periodic_output; +}; + +/** + * Class generating a new image with the tiling WFC algorithm. + */ +template +class TilingWFC { +private: + /** + * The distincts tiles. + */ + std::vector> tiles; + + /** + * Map ids of oriented tiles to tile and orientation. + */ + std::vector> id_to_oriented_tile; + + /** + * Map tile and orientation to oriented tile id. + */ + std::vector> oriented_tile_ids; + + /** + * Otions needed to use the tiling wfc. + */ + TilingWFCOptions options; + + /** + * The underlying generic WFC algorithm. + */ + WFC wfc; + +public: + /** + * The number of vertical tiles + */ + unsigned height; + + /** + * The number of horizontal tiles + */ + unsigned width; + +private: + /** + * Generate mapping from id to oriented tiles and vice versa. + */ + static std::pair>, + std::vector>> + generate_oriented_tile_ids(const std::vector> &tiles) noexcept { + std::vector> id_to_oriented_tile; + std::vector> oriented_tile_ids; + + unsigned id = 0; + for (unsigned i = 0; i < tiles.size(); i++) { + oriented_tile_ids.push_back({}); + for (unsigned j = 0; j < tiles[i].data.size(); j++) { + id_to_oriented_tile.push_back({ i, j }); + oriented_tile_ids[i].push_back(id); + id++; + } + } + + return { id_to_oriented_tile, oriented_tile_ids }; + } + + /** + * Generate the propagator which will be used in the wfc algorithm. + */ + static std::vector, 4>> generate_propagator( + const std::vector> + &neighbors, + std::vector> tiles, + std::vector> id_to_oriented_tile, + std::vector> oriented_tile_ids) { + size_t nb_oriented_tiles = id_to_oriented_tile.size(); + std::vector, 4>> dense_propagator( + nb_oriented_tiles, { std::vector(nb_oriented_tiles, false), std::vector(nb_oriented_tiles, false), std::vector(nb_oriented_tiles, false), std::vector(nb_oriented_tiles, false) }); + + for (auto neighbor : neighbors) { + unsigned tile1 = std::get<0>(neighbor); + unsigned orientation1 = std::get<1>(neighbor); + unsigned tile2 = std::get<2>(neighbor); + unsigned orientation2 = std::get<3>(neighbor); + std::vector> action_map1 = + Tile::generate_action_map(tiles[tile1].symmetry); + std::vector> action_map2 = + Tile::generate_action_map(tiles[tile2].symmetry); + + auto add = [&](unsigned action, unsigned direction) { + unsigned temp_orientation1 = action_map1[action][orientation1]; + unsigned temp_orientation2 = action_map2[action][orientation2]; + unsigned oriented_tile_id1 = + oriented_tile_ids[tile1][temp_orientation1]; + unsigned oriented_tile_id2 = + oriented_tile_ids[tile2][temp_orientation2]; + dense_propagator[oriented_tile_id1][direction][oriented_tile_id2] = + true; + direction = get_opposite_direction(direction); + dense_propagator[oriented_tile_id2][direction][oriented_tile_id1] = + true; + }; + + add(0, 2); + add(1, 0); + add(2, 1); + add(3, 3); + add(4, 1); + add(5, 3); + add(6, 2); + add(7, 0); + } + + std::vector, 4>> propagator( + nb_oriented_tiles); + for (size_t i = 0; i < nb_oriented_tiles; ++i) { + for (size_t j = 0; j < nb_oriented_tiles; ++j) { + for (size_t d = 0; d < 4; ++d) { + if (dense_propagator[i][d][j]) { + propagator[i][d].push_back(j); + } + } + } + } + + return propagator; + } + + /** + * Get probability of presence of tiles. + */ + static std::vector + get_tiles_weights(const std::vector> &tiles) { + std::vector frequencies; + for (size_t i = 0; i < tiles.size(); ++i) { + for (size_t j = 0; j < tiles[i].data.size(); ++j) { + frequencies.push_back(tiles[i].weight / tiles[i].data.size()); + } + } + return frequencies; + } + + /** + * Translate the generic WFC result into the image result + */ + Array2D id_to_tiling(Array2D ids) { + unsigned size = tiles[0].data[0].height; + Array2D tiling(size * ids.height, size * ids.width); + for (unsigned i = 0; i < ids.height; i++) { + for (unsigned j = 0; j < ids.width; j++) { + std::pair oriented_tile = + id_to_oriented_tile[ids.get(i, j)]; + for (unsigned y = 0; y < size; y++) { + for (unsigned x = 0; x < size; x++) { + tiling.get(i * size + y, j * size + x) = + tiles[oriented_tile.first].data[oriented_tile.second].get(y, x); + } + } + } + } + return tiling; + } + + void set_tile(unsigned tile_id, unsigned i, unsigned j) noexcept { + for (unsigned p = 0; p < id_to_oriented_tile.size(); p++) { + if (tile_id != p) { + wfc.remove_wave_pattern(i, j, p); + } + } + } + +public: + /** + * Construct the TilingWFC class to generate a tiled image. + */ + TilingWFC( + const std::vector> &tiles, + const std::vector> + &neighbors, + const unsigned height, const unsigned width, + const TilingWFCOptions &options, int seed) : + tiles(tiles), + id_to_oriented_tile(generate_oriented_tile_ids(tiles).first), + oriented_tile_ids(generate_oriented_tile_ids(tiles).second), + options(options), + wfc(options.periodic_output, seed, get_tiles_weights(tiles), + generate_propagator(neighbors, tiles, id_to_oriented_tile, + oriented_tile_ids), + height, width), + height(height), + width(width) {} + + /** + * Set the tile at a specific position. + * Returns false if the given tile and orientation does not exist, + * or if the coordinates are not in the wave + */ + bool set_tile(unsigned tile_id, unsigned orientation, unsigned i, unsigned j) noexcept { + if (tile_id >= oriented_tile_ids.size() || orientation >= oriented_tile_ids[tile_id].size() || i >= height || j >= width) { + return false; + } + + unsigned oriented_tile_id = oriented_tile_ids[tile_id][orientation]; + set_tile(oriented_tile_id, i, j); + return true; + } + + /** + * Run the tiling wfc and return the result if the algorithm succeeded + */ + std::optional> run() { + auto a = wfc.run(); + if (a == std::nullopt) { + return std::nullopt; + } + return id_to_tiling(*a); + } +}; + +#endif // FAST_WFC_TILING_WFC_HPP_ diff --git a/modules/wfc/tiling_wfc.hpp b/modules/wfc/tiling_wfc.hpp deleted file mode 100644 index 397f77268..000000000 --- a/modules/wfc/tiling_wfc.hpp +++ /dev/null @@ -1,401 +0,0 @@ -#ifndef FAST_WFC_TILING_WFC_HPP_ -#define FAST_WFC_TILING_WFC_HPP_ - -#include -#include - -#include "utils/array2D.hpp" -#include "wfc.hpp" - -/** - * The distinct symmetries of a tile. - * It represents how the tile behave when it is rotated or reflected - */ -enum class Symmetry { X, T, I, L, backslash, P }; - -/** - * Return the number of possible distinct orientations for a tile. - * An orientation is a combination of rotations and reflections. - */ -constexpr unsigned nb_of_possible_orientations(const Symmetry &symmetry) { - switch (symmetry) { - case Symmetry::X: - return 1; - case Symmetry::I: - case Symmetry::backslash: - return 2; - case Symmetry::T: - case Symmetry::L: - return 4; - default: - return 8; - } -} - -/** - * A tile that can be placed on the board. - */ -template struct Tile { - std::vector> data; // The different orientations of the tile - Symmetry symmetry; // The symmetry of the tile - double weight; // Its weight on the distribution of presence of tiles - - /** - * Generate the map associating an orientation id to the orientation - * id obtained when rotating 90° anticlockwise the tile. - */ - static std::vector - generate_rotation_map(const Symmetry &symmetry) noexcept { - switch (symmetry) { - case Symmetry::X: - return {0}; - case Symmetry::I: - case Symmetry::backslash: - return {1, 0}; - case Symmetry::T: - case Symmetry::L: - return {1, 2, 3, 0}; - case Symmetry::P: - default: - return {1, 2, 3, 0, 5, 6, 7, 4}; - } - } - - /** - * Generate the map associating an orientation id to the orientation - * id obtained when reflecting the tile along the x axis. - */ - static std::vector - generate_reflection_map(const Symmetry &symmetry) noexcept { - switch (symmetry) { - case Symmetry::X: - return {0}; - case Symmetry::I: - return {0, 1}; - case Symmetry::backslash: - return {1, 0}; - case Symmetry::T: - return {0, 3, 2, 1}; - case Symmetry::L: - return {1, 0, 3, 2}; - case Symmetry::P: - default: - return {4, 7, 6, 5, 0, 3, 2, 1}; - } - } - - /** - * Generate the map associating an orientation id and an action to the - * resulting orientation id. - * Actions 0, 1, 2, and 3 are 0°, 90°, 180°, and 270° anticlockwise rotations. - * Actions 4, 5, 6, and 7 are actions 0, 1, 2, and 3 preceded by a reflection - * on the x axis. - */ - static std::vector> - generate_action_map(const Symmetry &symmetry) noexcept { - std::vector rotation_map = generate_rotation_map(symmetry); - std::vector reflection_map = generate_reflection_map(symmetry); - size_t size = rotation_map.size(); - std::vector> action_map(8, - std::vector(size)); - for (size_t i = 0; i < size; ++i) { - action_map[0][i] = i; - } - - for (size_t a = 1; a < 4; ++a) { - for (size_t i = 0; i < size; ++i) { - action_map[a][i] = rotation_map[action_map[a - 1][i]]; - } - } - for (size_t i = 0; i < size; ++i) { - action_map[4][i] = reflection_map[action_map[0][i]]; - } - for (size_t a = 5; a < 8; ++a) { - for (size_t i = 0; i < size; ++i) { - action_map[a][i] = rotation_map[action_map[a - 1][i]]; - } - } - return action_map; - } - - /** - * Generate all distincts rotations of a 2D array given its symmetries; - */ - static std::vector> generate_oriented(Array2D data, - Symmetry symmetry) noexcept { - std::vector> oriented; - oriented.push_back(data); - - switch (symmetry) { - case Symmetry::I: - case Symmetry::backslash: - oriented.push_back(data.rotated()); - break; - case Symmetry::T: - case Symmetry::L: - oriented.push_back(data = data.rotated()); - oriented.push_back(data = data.rotated()); - oriented.push_back(data = data.rotated()); - break; - case Symmetry::P: - oriented.push_back(data = data.rotated()); - oriented.push_back(data = data.rotated()); - oriented.push_back(data = data.rotated()); - oriented.push_back(data = data.rotated().reflected()); - oriented.push_back(data = data.rotated()); - oriented.push_back(data = data.rotated()); - oriented.push_back(data = data.rotated()); - break; - default: - break; - } - - return oriented; - } - - /** - * Create a tile with its differents orientations, its symmetries and its - * weight on the distribution of tiles. - */ - Tile(std::vector> data, Symmetry symmetry, double weight) noexcept - : data(data), symmetry(symmetry), weight(weight) {} - - /* - * Create a tile with its base orientation, its symmetries and its - * weight on the distribution of tiles. - * The other orientations are generated with its first one. - */ - Tile(Array2D data, Symmetry symmetry, double weight) noexcept - : data(generate_oriented(data, symmetry)), symmetry(symmetry), - weight(weight) {} -}; - -/** - * Options needed to use the tiling wfc. - */ -struct TilingWFCOptions { - bool periodic_output; -}; - -/** - * Class generating a new image with the tiling WFC algorithm. - */ -template class TilingWFC { -private: - /** - * The distincts tiles. - */ - std::vector> tiles; - - /** - * Map ids of oriented tiles to tile and orientation. - */ - std::vector> id_to_oriented_tile; - - /** - * Map tile and orientation to oriented tile id. - */ - std::vector> oriented_tile_ids; - - /** - * Otions needed to use the tiling wfc. - */ - TilingWFCOptions options; - - /** - * The underlying generic WFC algorithm. - */ - WFC wfc; - -public: - - /** - * The number of vertical tiles - */ - unsigned height; - - /** - * The number of horizontal tiles - */ - unsigned width; - -private: - - /** - * Generate mapping from id to oriented tiles and vice versa. - */ - static std::pair>, - std::vector>> - generate_oriented_tile_ids(const std::vector> &tiles) noexcept { - std::vector> id_to_oriented_tile; - std::vector> oriented_tile_ids; - - unsigned id = 0; - for (unsigned i = 0; i < tiles.size(); i++) { - oriented_tile_ids.push_back({}); - for (unsigned j = 0; j < tiles[i].data.size(); j++) { - id_to_oriented_tile.push_back({i, j}); - oriented_tile_ids[i].push_back(id); - id++; - } - } - - return {id_to_oriented_tile, oriented_tile_ids}; - } - - /** - * Generate the propagator which will be used in the wfc algorithm. - */ - static std::vector, 4>> generate_propagator( - const std::vector> - &neighbors, - std::vector> tiles, - std::vector> id_to_oriented_tile, - std::vector> oriented_tile_ids) { - size_t nb_oriented_tiles = id_to_oriented_tile.size(); - std::vector, 4>> dense_propagator( - nb_oriented_tiles, {std::vector(nb_oriented_tiles, false), - std::vector(nb_oriented_tiles, false), - std::vector(nb_oriented_tiles, false), - std::vector(nb_oriented_tiles, false)}); - - for (auto neighbor : neighbors) { - unsigned tile1 = std::get<0>(neighbor); - unsigned orientation1 = std::get<1>(neighbor); - unsigned tile2 = std::get<2>(neighbor); - unsigned orientation2 = std::get<3>(neighbor); - std::vector> action_map1 = - Tile::generate_action_map(tiles[tile1].symmetry); - std::vector> action_map2 = - Tile::generate_action_map(tiles[tile2].symmetry); - - auto add = [&](unsigned action, unsigned direction) { - unsigned temp_orientation1 = action_map1[action][orientation1]; - unsigned temp_orientation2 = action_map2[action][orientation2]; - unsigned oriented_tile_id1 = - oriented_tile_ids[tile1][temp_orientation1]; - unsigned oriented_tile_id2 = - oriented_tile_ids[tile2][temp_orientation2]; - dense_propagator[oriented_tile_id1][direction][oriented_tile_id2] = - true; - direction = get_opposite_direction(direction); - dense_propagator[oriented_tile_id2][direction][oriented_tile_id1] = - true; - }; - - add(0, 2); - add(1, 0); - add(2, 1); - add(3, 3); - add(4, 1); - add(5, 3); - add(6, 2); - add(7, 0); - } - - std::vector, 4>> propagator( - nb_oriented_tiles); - for (size_t i = 0; i < nb_oriented_tiles; ++i) { - for (size_t j = 0; j < nb_oriented_tiles; ++j) { - for (size_t d = 0; d < 4; ++d) { - if (dense_propagator[i][d][j]) { - propagator[i][d].push_back(j); - } - } - } - } - - return propagator; - } - - /** - * Get probability of presence of tiles. - */ - static std::vector - get_tiles_weights(const std::vector> &tiles) { - std::vector frequencies; - for (size_t i = 0; i < tiles.size(); ++i) { - for (size_t j = 0; j < tiles[i].data.size(); ++j) { - frequencies.push_back(tiles[i].weight / tiles[i].data.size()); - } - } - return frequencies; - } - - /** - * Translate the generic WFC result into the image result - */ - Array2D id_to_tiling(Array2D ids) { - unsigned size = tiles[0].data[0].height; - Array2D tiling(size * ids.height, size * ids.width); - for (unsigned i = 0; i < ids.height; i++) { - for (unsigned j = 0; j < ids.width; j++) { - std::pair oriented_tile = - id_to_oriented_tile[ids.get(i, j)]; - for (unsigned y = 0; y < size; y++) { - for (unsigned x = 0; x < size; x++) { - tiling.get(i * size + y, j * size + x) = - tiles[oriented_tile.first].data[oriented_tile.second].get(y, x); - } - } - } - } - return tiling; - } - - void set_tile(unsigned tile_id, unsigned i, unsigned j) noexcept { - for (unsigned p = 0; p < id_to_oriented_tile.size(); p++) { - if (tile_id != p) { - wfc.remove_wave_pattern(i, j, p); - } - } - } - -public: - /** - * Construct the TilingWFC class to generate a tiled image. - */ - TilingWFC( - const std::vector> &tiles, - const std::vector> - &neighbors, - const unsigned height, const unsigned width, - const TilingWFCOptions &options, int seed) - : tiles(tiles), - id_to_oriented_tile(generate_oriented_tile_ids(tiles).first), - oriented_tile_ids(generate_oriented_tile_ids(tiles).second), - options(options), - wfc(options.periodic_output, seed, get_tiles_weights(tiles), - generate_propagator(neighbors, tiles, id_to_oriented_tile, - oriented_tile_ids), - height, width), - height(height), width(width) {} - - /** - * Set the tile at a specific position. - * Returns false if the given tile and orientation does not exist, - * or if the coordinates are not in the wave - */ - bool set_tile(unsigned tile_id, unsigned orientation, unsigned i, unsigned j) noexcept { - if (tile_id >= oriented_tile_ids.size() || orientation >= oriented_tile_ids[tile_id].size() || i >= height || j >= width) { - return false; - } - - unsigned oriented_tile_id = oriented_tile_ids[tile_id][orientation]; - set_tile(oriented_tile_id, i, j); - return true; - } - - /** - * Run the tiling wfc and return the result if the algorithm succeeded - */ - std::optional> run() { - auto a = wfc.run(); - if (a == std::nullopt) { - return std::nullopt; - } - return id_to_tiling(*a); - } -}; - -#endif // FAST_WFC_TILING_WFC_HPP_ diff --git a/modules/wfc/wave.cpp b/modules/wfc/wave.cpp index 5d5067783..2e5a63b9c 100644 --- a/modules/wfc/wave.cpp +++ b/modules/wfc/wave.cpp @@ -1,4 +1,4 @@ -#include "wave.hpp" +#include "wave.h" #include @@ -9,113 +9,113 @@ namespace { */ std::vector get_plogp(const std::vector &distribution) noexcept { - std::vector plogp; - for (unsigned i = 0; i < distribution.size(); i++) { - plogp.push_back(distribution[i] * log(distribution[i])); - } - return plogp; + std::vector plogp; + for (unsigned i = 0; i < distribution.size(); i++) { + plogp.push_back(distribution[i] * log(distribution[i])); + } + return plogp; } /** * Return min(v) / 2. */ double get_min_abs_half(const std::vector &v) noexcept { - double min_abs_half = std::numeric_limits::infinity(); - for (unsigned i = 0; i < v.size(); i++) { - min_abs_half = std::min(min_abs_half, std::abs(v[i] / 2.0)); - } - return min_abs_half; + double min_abs_half = std::numeric_limits::infinity(); + for (unsigned i = 0; i < v.size(); i++) { + min_abs_half = std::min(min_abs_half, std::abs(v[i] / 2.0)); + } + return min_abs_half; } } // namespace Wave::Wave(unsigned height, unsigned width, - const std::vector &patterns_frequencies) noexcept - : patterns_frequencies(patterns_frequencies), - plogp_patterns_frequencies(get_plogp(patterns_frequencies)), - min_abs_half_plogp(get_min_abs_half(plogp_patterns_frequencies)), - is_impossible(false), nb_patterns(patterns_frequencies.size()), - data(width * height, nb_patterns, 1), width(width), height(height), - size(height * width) { - // Initialize the memoisation of entropy. - double base_entropy = 0; - double base_s = 0; - for (unsigned i = 0; i < nb_patterns; i++) { - base_entropy += plogp_patterns_frequencies[i]; - base_s += patterns_frequencies[i]; - } - double log_base_s = log(base_s); - double entropy_base = log_base_s - base_entropy / base_s; - memoisation.plogp_sum = std::vector(width * height, base_entropy); - memoisation.sum = std::vector(width * height, base_s); - memoisation.log_sum = std::vector(width * height, log_base_s); - memoisation.nb_patterns = - std::vector(width * height, static_cast(nb_patterns)); - memoisation.entropy = std::vector(width * height, entropy_base); + const std::vector &patterns_frequencies) noexcept + : + patterns_frequencies(patterns_frequencies), + plogp_patterns_frequencies(get_plogp(patterns_frequencies)), + min_abs_half_plogp(get_min_abs_half(plogp_patterns_frequencies)), + is_impossible(false), + nb_patterns(patterns_frequencies.size()), + data(width * height, nb_patterns, 1), + width(width), + height(height), + size(height * width) { + // Initialize the memoisation of entropy. + double base_entropy = 0; + double base_s = 0; + for (unsigned i = 0; i < nb_patterns; i++) { + base_entropy += plogp_patterns_frequencies[i]; + base_s += patterns_frequencies[i]; + } + double log_base_s = log(base_s); + double entropy_base = log_base_s - base_entropy / base_s; + memoisation.plogp_sum = std::vector(width * height, base_entropy); + memoisation.sum = std::vector(width * height, base_s); + memoisation.log_sum = std::vector(width * height, log_base_s); + memoisation.nb_patterns = + std::vector(width * height, static_cast(nb_patterns)); + memoisation.entropy = std::vector(width * height, entropy_base); } - void Wave::set(unsigned index, unsigned pattern, bool value) noexcept { - bool old_value = data.get(index, pattern); - // If the value isn't changed, nothing needs to be done. - if (old_value == value) { - return; - } - // Otherwise, the memoisation should be updated. - data.get(index, pattern) = value; - memoisation.plogp_sum[index] -= plogp_patterns_frequencies[pattern]; - memoisation.sum[index] -= patterns_frequencies[pattern]; - memoisation.log_sum[index] = log(memoisation.sum[index]); - memoisation.nb_patterns[index]--; - memoisation.entropy[index] = - memoisation.log_sum[index] - - memoisation.plogp_sum[index] / memoisation.sum[index]; - // If there is no patterns possible in the cell, then there is a - // contradiction. - if (memoisation.nb_patterns[index] == 0) { - is_impossible = true; - } + bool old_value = data.get(index, pattern); + // If the value isn't changed, nothing needs to be done. + if (old_value == value) { + return; + } + // Otherwise, the memoisation should be updated. + data.get(index, pattern) = value; + memoisation.plogp_sum[index] -= plogp_patterns_frequencies[pattern]; + memoisation.sum[index] -= patterns_frequencies[pattern]; + memoisation.log_sum[index] = log(memoisation.sum[index]); + memoisation.nb_patterns[index]--; + memoisation.entropy[index] = + memoisation.log_sum[index] - + memoisation.plogp_sum[index] / memoisation.sum[index]; + // If there is no patterns possible in the cell, then there is a + // contradiction. + if (memoisation.nb_patterns[index] == 0) { + is_impossible = true; + } } - int Wave::get_min_entropy(std::minstd_rand &gen) const noexcept { - if (is_impossible) { - return -2; - } + if (is_impossible) { + return -2; + } - std::uniform_real_distribution<> dis(0, min_abs_half_plogp); + std::uniform_real_distribution<> dis(0, min_abs_half_plogp); - // The minimum entropy (plus a small noise) - double min = std::numeric_limits::infinity(); - int argmin = -1; + // The minimum entropy (plus a small noise) + double min = std::numeric_limits::infinity(); + int argmin = -1; - for (unsigned i = 0; i < size; i++) { + for (unsigned i = 0; i < size; i++) { + // If the cell is decided, we do not compute the entropy (which is equal + // to 0). + double nb_patterns_local = memoisation.nb_patterns[i]; + if (nb_patterns_local == 1) { + continue; + } - // If the cell is decided, we do not compute the entropy (which is equal - // to 0). - double nb_patterns_local = memoisation.nb_patterns[i]; - if (nb_patterns_local == 1) { - continue; - } + // Otherwise, we take the memoised entropy. + double entropy = memoisation.entropy[i]; - // Otherwise, we take the memoised entropy. - double entropy = memoisation.entropy[i]; + // We first check if the entropy is less than the minimum. + // This is important to reduce noise computation (which is not + // negligible). + if (entropy <= min) { + // Then, we add noise to decide randomly which will be chosen. + // noise is smaller than the smallest p * log(p), so the minimum entropy + // will always be chosen. + double noise = dis(gen); + if (entropy + noise < min) { + min = entropy + noise; + argmin = i; + } + } + } - // We first check if the entropy is less than the minimum. - // This is important to reduce noise computation (which is not - // negligible). - if (entropy <= min) { - - // Then, we add noise to decide randomly which will be chosen. - // noise is smaller than the smallest p * log(p), so the minimum entropy - // will always be chosen. - double noise = dis(gen); - if (entropy + noise < min) { - min = entropy + noise; - argmin = i; - } - } - } - - return argmin; + return argmin; } diff --git a/modules/wfc/wave.h b/modules/wfc/wave.h new file mode 100644 index 000000000..3a9f6fd6b --- /dev/null +++ b/modules/wfc/wave.h @@ -0,0 +1,114 @@ +#ifndef FAST_WFC_WAVE_HPP_ +#define FAST_WFC_WAVE_HPP_ + +#include "array_2d.h" +#include +#include + +/** + * Struct containing the values needed to compute the entropy of all the cells. + * This struct is updated every time the wave is changed. + * p'(pattern) is equal to patterns_frequencies[pattern] if wave.get(cell, + * pattern) is set to true, otherwise 0. + */ +struct EntropyMemoisation { + std::vector plogp_sum; // The sum of p'(pattern) * log(p'(pattern)). + std::vector sum; // The sum of p'(pattern). + std::vector log_sum; // The log of sum. + std::vector nb_patterns; // The number of patterns present + std::vector entropy; // The entropy of the cell. +}; + +/** + * Contains the pattern possibilities in every cell. + * Also contains information about cell entropy. + */ +class Wave { +private: + /** + * The patterns frequencies p given to wfc. + */ + const std::vector patterns_frequencies; + + /** + * The precomputation of p * log(p). + */ + const std::vector plogp_patterns_frequencies; + + /** + * The precomputation of min (p * log(p)) / 2. + * This is used to define the maximum value of the noise. + */ + const double min_abs_half_plogp; + + /** + * The memoisation of important values for the computation of entropy. + */ + EntropyMemoisation memoisation; + + /** + * This value is set to true if there is a contradiction in the wave (all + * elements set to false in a cell). + */ + bool is_impossible; + + /** + * The number of distinct patterns. + */ + const size_t nb_patterns; + + /** + * The actual wave. data.get(index, pattern) is equal to 0 if the pattern can + * be placed in the cell index. + */ + Array2D data; + +public: + /** + * The size of the wave. + */ + const unsigned width; + const unsigned height; + const unsigned size; + + /** + * Initialize the wave with every cell being able to have every pattern. + */ + Wave(unsigned height, unsigned width, + const std::vector &patterns_frequencies) noexcept; + + /** + * Return true if pattern can be placed in cell index. + */ + bool get(unsigned index, unsigned pattern) const noexcept { + return data.get(index, pattern); + } + + /** + * Return true if pattern can be placed in cell (i,j) + */ + bool get(unsigned i, unsigned j, unsigned pattern) const noexcept { + return get(i * width + j, pattern); + } + + /** + * Set the value of pattern in cell index. + */ + void set(unsigned index, unsigned pattern, bool value) noexcept; + + /** + * Set the value of pattern in cell (i,j). + */ + void set(unsigned i, unsigned j, unsigned pattern, bool value) noexcept { + set(i * width + j, pattern, value); + } + + /** + * Return the index of the cell with lowest entropy different of 0. + * If there is a contradiction in the wave, return -2. + * If every cell is decided, return -1. + */ + int get_min_entropy(std::minstd_rand &gen) const noexcept; +}; + +#endif // FAST_WFC_WAVE_HPP_ diff --git a/modules/wfc/wave.hpp b/modules/wfc/wave.hpp deleted file mode 100644 index bf51e968d..000000000 --- a/modules/wfc/wave.hpp +++ /dev/null @@ -1,115 +0,0 @@ -#ifndef FAST_WFC_WAVE_HPP_ -#define FAST_WFC_WAVE_HPP_ - -#include "utils/array2D.hpp" -#include -#include - -/** - * Struct containing the values needed to compute the entropy of all the cells. - * This struct is updated every time the wave is changed. - * p'(pattern) is equal to patterns_frequencies[pattern] if wave.get(cell, - * pattern) is set to true, otherwise 0. - */ -struct EntropyMemoisation { - std::vector plogp_sum; // The sum of p'(pattern) * log(p'(pattern)). - std::vector sum; // The sum of p'(pattern). - std::vector log_sum; // The log of sum. - std::vector nb_patterns; // The number of patterns present - std::vector entropy; // The entropy of the cell. -}; - -/** - * Contains the pattern possibilities in every cell. - * Also contains information about cell entropy. - */ -class Wave { -private: - /** - * The patterns frequencies p given to wfc. - */ - const std::vector patterns_frequencies; - - /** - * The precomputation of p * log(p). - */ - const std::vector plogp_patterns_frequencies; - - /** - * The precomputation of min (p * log(p)) / 2. - * This is used to define the maximum value of the noise. - */ - const double min_abs_half_plogp; - - /** - * The memoisation of important values for the computation of entropy. - */ - EntropyMemoisation memoisation; - - /** - * This value is set to true if there is a contradiction in the wave (all - * elements set to false in a cell). - */ - bool is_impossible; - - /** - * The number of distinct patterns. - */ - const size_t nb_patterns; - - /** - * The actual wave. data.get(index, pattern) is equal to 0 if the pattern can - * be placed in the cell index. - */ - Array2D data; - -public: - /** - * The size of the wave. - */ - const unsigned width; - const unsigned height; - const unsigned size; - - /** - * Initialize the wave with every cell being able to have every pattern. - */ - Wave(unsigned height, unsigned width, - const std::vector &patterns_frequencies) noexcept; - - /** - * Return true if pattern can be placed in cell index. - */ - bool get(unsigned index, unsigned pattern) const noexcept { - return data.get(index, pattern); - } - - /** - * Return true if pattern can be placed in cell (i,j) - */ - bool get(unsigned i, unsigned j, unsigned pattern) const noexcept { - return get(i * width + j, pattern); - } - - /** - * Set the value of pattern in cell index. - */ - void set(unsigned index, unsigned pattern, bool value) noexcept; - - /** - * Set the value of pattern in cell (i,j). - */ - void set(unsigned i, unsigned j, unsigned pattern, bool value) noexcept { - set(i * width + j, pattern, value); - } - - /** - * Return the index of the cell with lowest entropy different of 0. - * If there is a contradiction in the wave, return -2. - * If every cell is decided, return -1. - */ - int get_min_entropy(std::minstd_rand &gen) const noexcept; - -}; - -#endif // FAST_WFC_WAVE_HPP_ diff --git a/modules/wfc/wfc.cpp b/modules/wfc/wfc.cpp index 263066d73..645944e53 100644 --- a/modules/wfc/wfc.cpp +++ b/modules/wfc/wfc.cpp @@ -1,109 +1,103 @@ -#include "wfc.hpp" +#include "wfc.h" #include namespace { - /** - * Normalize a vector so the sum of its elements is equal to 1.0f - */ - std::vector& normalize(std::vector& v) { - double sum_weights = 0.0; - for(double weight: v) { - sum_weights += weight; - } +/** + * Normalize a vector so the sum of its elements is equal to 1.0f + */ +std::vector &normalize(std::vector &v) { + double sum_weights = 0.0; + for (double weight : v) { + sum_weights += weight; + } - double inv_sum_weights = 1.0/sum_weights; - for(double& weight: v) { - weight *= inv_sum_weights; - } + double inv_sum_weights = 1.0 / sum_weights; + for (double &weight : v) { + weight *= inv_sum_weights; + } - return v; - } + return v; } - +} //namespace Array2D WFC::wave_to_output() const noexcept { - Array2D output_patterns(wave.height, wave.width); - for (unsigned i = 0; i < wave.size; i++) { - for (unsigned k = 0; k < nb_patterns; k++) { - if (wave.get(i, k)) { - output_patterns.data[i] = k; - } - } - } - return output_patterns; + Array2D output_patterns(wave.height, wave.width); + for (unsigned i = 0; i < wave.size; i++) { + for (unsigned k = 0; k < nb_patterns; k++) { + if (wave.get(i, k)) { + output_patterns.data[i] = k; + } + } + } + return output_patterns; } WFC::WFC(bool periodic_output, int seed, - std::vector patterns_frequencies, - Propagator::PropagatorState propagator, unsigned wave_height, - unsigned wave_width) - noexcept - : gen(seed), patterns_frequencies(normalize(patterns_frequencies)), - wave(wave_height, wave_width, patterns_frequencies), - nb_patterns(propagator.size()), - propagator(wave.height, wave.width, periodic_output, propagator) {} + std::vector patterns_frequencies, + Propagator::PropagatorState propagator, unsigned wave_height, + unsigned wave_width) noexcept + : + gen(seed), patterns_frequencies(normalize(patterns_frequencies)), wave(wave_height, wave_width, patterns_frequencies), nb_patterns(propagator.size()), propagator(wave.height, wave.width, periodic_output, propagator) {} std::optional> WFC::run() noexcept { - while (true) { + while (true) { + // Define the value of an undefined cell. + ObserveStatus result = observe(); - // Define the value of an undefined cell. - ObserveStatus result = observe(); + // Check if the algorithm has terminated. + if (result == failure) { + return std::nullopt; + } else if (result == success) { + return wave_to_output(); + } - // Check if the algorithm has terminated. - if (result == failure) { - return std::nullopt; - } else if (result == success) { - return wave_to_output(); - } - - // Propagate the information. - propagator.propagate(wave); - } + // Propagate the information. + propagator.propagate(wave); + } } - WFC::ObserveStatus WFC::observe() noexcept { - // Get the cell with lowest entropy. - int argmin = wave.get_min_entropy(gen); + // Get the cell with lowest entropy. + int argmin = wave.get_min_entropy(gen); - // If there is a contradiction, the algorithm has failed. - if (argmin == -2) { - return failure; - } + // If there is a contradiction, the algorithm has failed. + if (argmin == -2) { + return failure; + } - // If the lowest entropy is 0, then the algorithm has succeeded and - // finished. - if (argmin == -1) { - wave_to_output(); - return success; - } + // If the lowest entropy is 0, then the algorithm has succeeded and + // finished. + if (argmin == -1) { + wave_to_output(); + return success; + } - // Choose an element according to the pattern distribution - double s = 0; - for (unsigned k = 0; k < nb_patterns; k++) { - s += wave.get(argmin, k) ? patterns_frequencies[k] : 0; - } + // Choose an element according to the pattern distribution + double s = 0; + for (unsigned k = 0; k < nb_patterns; k++) { + s += wave.get(argmin, k) ? patterns_frequencies[k] : 0; + } - std::uniform_real_distribution<> dis(0, s); - double random_value = dis(gen); - size_t chosen_value = nb_patterns - 1; + std::uniform_real_distribution<> dis(0, s); + double random_value = dis(gen); + size_t chosen_value = nb_patterns - 1; - for (unsigned k = 0; k < nb_patterns; k++) { - random_value -= wave.get(argmin, k) ? patterns_frequencies[k] : 0; - if (random_value <= 0) { - chosen_value = k; - break; - } - } + for (unsigned k = 0; k < nb_patterns; k++) { + random_value -= wave.get(argmin, k) ? patterns_frequencies[k] : 0; + if (random_value <= 0) { + chosen_value = k; + break; + } + } - // And define the cell with the pattern. - for (unsigned k = 0; k < nb_patterns; k++) { - if (wave.get(argmin, k) != (k == chosen_value)) { - propagator.add_to_propagator(argmin / wave.width, argmin % wave.width, - k); - wave.set(argmin, k, false); - } - } + // And define the cell with the pattern. + for (unsigned k = 0; k < nb_patterns; k++) { + if (wave.get(argmin, k) != (k == chosen_value)) { + propagator.add_to_propagator(argmin / wave.width, argmin % wave.width, + k); + wave.set(argmin, k, false); + } + } - return to_continue; - } + return to_continue; +} diff --git a/modules/wfc/wfc.h b/modules/wfc/wfc.h new file mode 100644 index 000000000..499725a06 --- /dev/null +++ b/modules/wfc/wfc.h @@ -0,0 +1,92 @@ +#ifndef FAST_WFC_WFC_HPP_ +#define FAST_WFC_WFC_HPP_ + +#include +#include + +#include "propagator.h" +#include "array_2d.h" +#include "wave.h" + +/** + * Class containing the generic WFC algorithm. + */ +class WFC { +private: + /** + * The random number generator. + */ + std::minstd_rand gen; + + /** + * The distribution of the patterns as given in input. + */ + const std::vector patterns_frequencies; + + /** + * The wave, indicating which patterns can be put in which cell. + */ + Wave wave; + + /** + * The number of distinct patterns. + */ + const size_t nb_patterns; + + /** + * The propagator, used to propagate the information in the wave. + */ + Propagator propagator; + + /** + * Transform the wave to a valid output (a 2d array of patterns that aren't in + * contradiction). This function should be used only when all cell of the wave + * are defined. + */ + Array2D wave_to_output() const noexcept; + +public: + /** + * Basic constructor initializing the algorithm. + */ + WFC(bool periodic_output, int seed, std::vector patterns_frequencies, + Propagator::PropagatorState propagator, unsigned wave_height, + unsigned wave_width) + noexcept; + + /** + * Run the algorithm, and return a result if it succeeded. + */ + std::optional> run() noexcept; + + /** + * Return value of observe. + */ + enum ObserveStatus { + success, // WFC has finished and has succeeded. + failure, // WFC has finished and failed. + to_continue // WFC isn't finished. + }; + + /** + * Define the value of the cell with lowest entropy. + */ + ObserveStatus observe() noexcept; + + /** + * Propagate the information of the wave. + */ + void propagate() noexcept { propagator.propagate(wave); } + + /** + * Remove pattern from cell (i,j). + */ + void remove_wave_pattern(unsigned i, unsigned j, unsigned pattern) noexcept { + if (wave.get(i, j, pattern)) { + wave.set(i, j, pattern, false); + propagator.add_to_propagator(i, j, pattern); + } + } +}; + +#endif // FAST_WFC_WFC_HPP_ diff --git a/modules/wfc/wfc.hpp b/modules/wfc/wfc.hpp deleted file mode 100644 index ced4bbcd2..000000000 --- a/modules/wfc/wfc.hpp +++ /dev/null @@ -1,92 +0,0 @@ -#ifndef FAST_WFC_WFC_HPP_ -#define FAST_WFC_WFC_HPP_ - -#include -#include - -#include "utils/array2D.hpp" -#include "propagator.hpp" -#include "wave.hpp" - -/** - * Class containing the generic WFC algorithm. - */ -class WFC { -private: - /** - * The random number generator. - */ - std::minstd_rand gen; - - /** - * The distribution of the patterns as given in input. - */ - const std::vector patterns_frequencies; - - /** - * The wave, indicating which patterns can be put in which cell. - */ - Wave wave; - - /** - * The number of distinct patterns. - */ - const size_t nb_patterns; - - /** - * The propagator, used to propagate the information in the wave. - */ - Propagator propagator; - - /** - * Transform the wave to a valid output (a 2d array of patterns that aren't in - * contradiction). This function should be used only when all cell of the wave - * are defined. - */ - Array2D wave_to_output() const noexcept; - -public: - /** - * Basic constructor initializing the algorithm. - */ - WFC(bool periodic_output, int seed, std::vector patterns_frequencies, - Propagator::PropagatorState propagator, unsigned wave_height, - unsigned wave_width) - noexcept; - - /** - * Run the algorithm, and return a result if it succeeded. - */ - std::optional> run() noexcept; - - /** - * Return value of observe. - */ - enum ObserveStatus { - success, // WFC has finished and has succeeded. - failure, // WFC has finished and failed. - to_continue // WFC isn't finished. - }; - - /** - * Define the value of the cell with lowest entropy. - */ - ObserveStatus observe() noexcept; - - /** - * Propagate the information of the wave. - */ - void propagate() noexcept { propagator.propagate(wave); } - - /** - * Remove pattern from cell (i,j). - */ - void remove_wave_pattern(unsigned i, unsigned j, unsigned pattern) noexcept { - if (wave.get(i, j, pattern)) { - wave.set(i, j, pattern, false); - propagator.add_to_propagator(i, j, pattern); - } - } -}; - -#endif // FAST_WFC_WFC_HPP_