diff --git a/modules/denoise/SCsub b/modules/denoise/SCsub deleted file mode 100644 index 83193f522..000000000 --- a/modules/denoise/SCsub +++ /dev/null @@ -1,133 +0,0 @@ -#!/usr/bin/env python - -import resource_to_cpp - -Import("env") -Import("env_modules") - -env_oidn = env_modules.Clone() - -# Thirdparty source files - -thirdparty_obj = [] - -thirdparty_dir = "#thirdparty/oidn/" -thirdparty_sources = [ - "core/api.cpp", - "core/device.cpp", - "core/filter.cpp", - "core/network.cpp", - "core/autoencoder.cpp", - "core/transfer_function.cpp", - "weights/rtlightmap_hdr.gen.cpp", - "mkl-dnn/src/common/batch_normalization.cpp", - "mkl-dnn/src/common/concat.cpp", - "mkl-dnn/src/common/convolution.cpp", - "mkl-dnn/src/common/convolution_pd.cpp", - "mkl-dnn/src/common/deconvolution.cpp", - "mkl-dnn/src/common/eltwise.cpp", - "mkl-dnn/src/common/engine.cpp", - "mkl-dnn/src/common/inner_product.cpp", - "mkl-dnn/src/common/inner_product_pd.cpp", - "mkl-dnn/src/common/lrn.cpp", - "mkl-dnn/src/common/memory.cpp", - "mkl-dnn/src/common/memory_desc_wrapper.cpp", - "mkl-dnn/src/common/mkldnn_debug.cpp", - "mkl-dnn/src/common/mkldnn_debug_autogenerated.cpp", - "mkl-dnn/src/common/pooling.cpp", - "mkl-dnn/src/common/primitive.cpp", - "mkl-dnn/src/common/primitive_attr.cpp", - "mkl-dnn/src/common/primitive_desc.cpp", - "mkl-dnn/src/common/primitive_exec_types.cpp", - "mkl-dnn/src/common/primitive_iterator.cpp", - "mkl-dnn/src/common/query.cpp", - "mkl-dnn/src/common/reorder.cpp", - "mkl-dnn/src/common/rnn.cpp", - "mkl-dnn/src/common/scratchpad.cpp", - "mkl-dnn/src/common/shuffle.cpp", - "mkl-dnn/src/common/softmax.cpp", - "mkl-dnn/src/common/stream.cpp", - "mkl-dnn/src/common/sum.cpp", - "mkl-dnn/src/common/utils.cpp", - "mkl-dnn/src/common/verbose.cpp", - "mkl-dnn/src/cpu/cpu_barrier.cpp", - "mkl-dnn/src/cpu/cpu_concat.cpp", - "mkl-dnn/src/cpu/cpu_engine.cpp", - "mkl-dnn/src/cpu/cpu_memory.cpp", - "mkl-dnn/src/cpu/cpu_reducer.cpp", - "mkl-dnn/src/cpu/cpu_reorder.cpp", - "mkl-dnn/src/cpu/cpu_sum.cpp", - "mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.cpp", - "mkl-dnn/src/cpu/jit_avx2_convolution.cpp", - "mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.cpp", - "mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.cpp", - "mkl-dnn/src/cpu/jit_avx512_common_convolution.cpp", - "mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.cpp", - "mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.cpp", - "mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.cpp", - "mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.cpp", - "mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.cpp", - "mkl-dnn/src/cpu/jit_sse42_convolution.cpp", - "mkl-dnn/src/cpu/jit_transpose_src_utils.cpp", - "mkl-dnn/src/cpu/jit_uni_eltwise.cpp", - "mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.cpp", - "mkl-dnn/src/cpu/jit_uni_pooling.cpp", - "mkl-dnn/src/cpu/jit_uni_reorder.cpp", - "mkl-dnn/src/cpu/jit_uni_reorder_utils.cpp", - "mkl-dnn/src/cpu/jit_utils/jit_utils.cpp", - "mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.c", - "common/platform.cpp", - "common/thread.cpp", - "common/tensor.cpp", -] -thirdparty_sources = [thirdparty_dir + file for file in thirdparty_sources] - -thirdparty_include_dirs = [ - "", - "include", - "mkl-dnn/include", - "mkl-dnn/src", - "mkl-dnn/src/common", - "mkl-dnn/src/cpu/xbyak", - "mkl-dnn/src/cpu", -] -thirdparty_include_dirs = [thirdparty_dir + file for file in thirdparty_include_dirs] - - -env_oidn.Prepend(CPPPATH=thirdparty_include_dirs) -env_oidn.Append( - CPPDEFINES=[ - "MKLDNN_THR=MKLDNN_THR_SEQ", - 
"OIDN_STATIC_LIB", - "__STDC_CONSTANT_MACROS", - "__STDC_LIMIT_MACROS", - "DISABLE_VERBOSE", - "MKLDNN_ENABLE_CONCURRENT_EXEC", - "NDEBUG", - ] -) - -env_thirdparty = env_oidn.Clone() -env_thirdparty.disable_warnings() -env_thirdparty.add_source_files(thirdparty_obj, thirdparty_sources) -env.modules_sources += thirdparty_obj - -if env["platform"] == "windows" and not env.msvc: - env_thirdparty.Append(CPPFLAGS=["-mstackrealign"]) - -weights_in_path = thirdparty_dir + "weights/rtlightmap_hdr.tza" -weights_out_path = thirdparty_dir + "weights/rtlightmap_hdr.gen.cpp" - -env_thirdparty.Depends(weights_out_path, weights_in_path) -env_thirdparty.CommandNoCache(weights_out_path, weights_in_path, resource_to_cpp.tza_to_cpp) - -# Godot source files - -module_obj = [] - -env_oidn.add_source_files(module_obj, "denoise_wrapper.cpp") -env_modules.add_source_files(module_obj, ["register_types.cpp", "lightmap_denoiser.cpp"]) -env.modules_sources += module_obj - -# Needed to force rebuilding the module files when the thirdparty library is updated. -env.Depends(module_obj, thirdparty_obj) diff --git a/modules/denoise/config.py b/modules/denoise/config.py deleted file mode 100644 index 84f312042..000000000 --- a/modules/denoise/config.py +++ /dev/null @@ -1,29 +0,0 @@ -def can_build(env, platform): - # Thirdparty dependency OpenImage Denoise includes oneDNN library - # and the version we use only supports x86_64. - # It's also only relevant for tools build and desktop platforms, - # as doing lightmap generation and denoising on Android or HTML5 - # would be a bit far-fetched. - # Note: oneDNN doesn't support ARM64, OIDN needs updating to the latest version - supported_platform = platform in ["x11", "osx", "windows", "server"] - supported_arch = env["bits"] == "64" - if env["arch"] == "arm64": - supported_arch = False - if env["arch"].startswith("ppc"): - supported_arch = False - if env["arch"].startswith("rv"): - supported_arch = False - - # Hack to disable on Linux arm64. This won't work well for cross-compilation (checks - # host, not target) and would need a more thorough fix by refactoring our arch and - # bits-handling code. - from platform import machine - - if platform == "x11" and machine() != "x86_64": - supported_arch = False - - return env["tools"] and supported_platform and supported_arch - - -def configure(env): - pass diff --git a/modules/denoise/denoise_wrapper.cpp b/modules/denoise/denoise_wrapper.cpp deleted file mode 100644 index 1e99eb65f..000000000 --- a/modules/denoise/denoise_wrapper.cpp +++ /dev/null @@ -1,64 +0,0 @@ -/*************************************************************************/ -/* denoise_wrapper.cpp */ -/*************************************************************************/ -/* This file is part of: */ -/* GODOT ENGINE */ -/* https://godotengine.org */ -/*************************************************************************/ -/* Copyright (c) 2007-2022 Juan Linietsky, Ariel Manzur. */ -/* Copyright (c) 2014-2022 Godot Engine contributors (cf. AUTHORS.md). 
*/ -/* */ -/* Permission is hereby granted, free of charge, to any person obtaining */ -/* a copy of this software and associated documentation files (the */ -/* "Software"), to deal in the Software without restriction, including */ -/* without limitation the rights to use, copy, modify, merge, publish, */ -/* distribute, sublicense, and/or sell copies of the Software, and to */ -/* permit persons to whom the Software is furnished to do so, subject to */ -/* the following conditions: */ -/* */ -/* The above copyright notice and this permission notice shall be */ -/* included in all copies or substantial portions of the Software. */ -/* */ -/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ -/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ -/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ -/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ -/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ -/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ -/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ -/*************************************************************************/ - -#include "denoise_wrapper.h" -#include "core/os/memory.h" -#include "thirdparty/oidn/include/OpenImageDenoise/oidn.h" -#include - -void *oidn_denoiser_init() { - OIDNDeviceImpl *device = oidnNewDevice(OIDN_DEVICE_TYPE_CPU); - oidnCommitDevice(device); - return device; -} - -bool oidn_denoise(void *deviceptr, float *p_floats, int p_width, int p_height) { - OIDNDeviceImpl *device = (OIDNDeviceImpl *)deviceptr; - OIDNFilter filter = oidnNewFilter(device, "RTLightmap"); - oidnSetSharedFilterImage(filter, "color", (void *)p_floats, OIDN_FORMAT_FLOAT3, p_width, p_height, 0, 0, 0); - oidnSetSharedFilterImage(filter, "output", (void *)p_floats, OIDN_FORMAT_FLOAT3, p_width, p_height, 0, 0, 0); - oidnSetFilter1b(filter, "hdr", true); - oidnCommitFilter(filter); - oidnExecuteFilter(filter); - - const char *msg; - bool success = true; - if (oidnGetDeviceError(device, &msg) != OIDN_ERROR_NONE) { - printf("LightmapDenoiser: %s\n", msg); - success = false; - } - - oidnReleaseFilter(filter); - return success; -} - -void oidn_denoiser_finish(void *device) { - oidnReleaseDevice((OIDNDeviceImpl *)device); -} diff --git a/modules/denoise/denoise_wrapper.h b/modules/denoise/denoise_wrapper.h deleted file mode 100644 index 44e61ce31..000000000 --- a/modules/denoise/denoise_wrapper.h +++ /dev/null @@ -1,38 +0,0 @@ -/*************************************************************************/ -/* denoise_wrapper.h */ -/*************************************************************************/ -/* This file is part of: */ -/* GODOT ENGINE */ -/* https://godotengine.org */ -/*************************************************************************/ -/* Copyright (c) 2007-2022 Juan Linietsky, Ariel Manzur. */ -/* Copyright (c) 2014-2022 Godot Engine contributors (cf. AUTHORS.md). 
*/ -/* */ -/* Permission is hereby granted, free of charge, to any person obtaining */ -/* a copy of this software and associated documentation files (the */ -/* "Software"), to deal in the Software without restriction, including */ -/* without limitation the rights to use, copy, modify, merge, publish, */ -/* distribute, sublicense, and/or sell copies of the Software, and to */ -/* permit persons to whom the Software is furnished to do so, subject to */ -/* the following conditions: */ -/* */ -/* The above copyright notice and this permission notice shall be */ -/* included in all copies or substantial portions of the Software. */ -/* */ -/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ -/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ -/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ -/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ -/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ -/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ -/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ -/*************************************************************************/ - -#ifndef DENOISE_WRAPPER_H -#define DENOISE_WRAPPER_H - -void *oidn_denoiser_init(); -bool oidn_denoise(void *device, float *p_floats, int p_width, int p_height); -void oidn_denoiser_finish(void *device); - -#endif // DENOISE_WRAPPER_H diff --git a/modules/denoise/lightmap_denoiser.cpp b/modules/denoise/lightmap_denoiser.cpp deleted file mode 100644 index d0a71e116..000000000 --- a/modules/denoise/lightmap_denoiser.cpp +++ /dev/null @@ -1,66 +0,0 @@ -/*************************************************************************/ -/* lightmap_denoiser.cpp */ -/*************************************************************************/ -/* This file is part of: */ -/* GODOT ENGINE */ -/* https://godotengine.org */ -/*************************************************************************/ -/* Copyright (c) 2007-2022 Juan Linietsky, Ariel Manzur. */ -/* Copyright (c) 2014-2022 Godot Engine contributors (cf. AUTHORS.md). */ -/* */ -/* Permission is hereby granted, free of charge, to any person obtaining */ -/* a copy of this software and associated documentation files (the */ -/* "Software"), to deal in the Software without restriction, including */ -/* without limitation the rights to use, copy, modify, merge, publish, */ -/* distribute, sublicense, and/or sell copies of the Software, and to */ -/* permit persons to whom the Software is furnished to do so, subject to */ -/* the following conditions: */ -/* */ -/* The above copyright notice and this permission notice shall be */ -/* included in all copies or substantial portions of the Software. */ -/* */ -/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ -/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ -/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ -/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ -/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ -/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ -/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
*/ -/*************************************************************************/ - -#include "lightmap_denoiser.h" - -#include "denoise_wrapper.h" - -LightmapDenoiser *LightmapDenoiserOIDN::create_oidn_denoiser() { - return memnew(LightmapDenoiserOIDN); -} - -void LightmapDenoiserOIDN::make_default_denoiser() { - create_function = create_oidn_denoiser; -} - -Ref<Image> LightmapDenoiserOIDN::denoise_image(const Ref<Image> &p_image) { - Ref<Image> img = p_image->duplicate(); - - img->convert(Image::FORMAT_RGBF); - - PoolByteArray data = img->get_data(); - { - PoolByteArray::Write w = data.write(); - if (!oidn_denoise(device, (float *)w.ptr(), img->get_width(), img->get_height())) { - return p_image; - } - } - - img->create(img->get_width(), img->get_height(), false, img->get_format(), data); - return img; -} - -LightmapDenoiserOIDN::LightmapDenoiserOIDN() { - device = oidn_denoiser_init(); -} - -LightmapDenoiserOIDN::~LightmapDenoiserOIDN() { - oidn_denoiser_finish(device); -} diff --git a/modules/denoise/lightmap_denoiser.h b/modules/denoise/lightmap_denoiser.h deleted file mode 100644 index 63ecd8025..000000000 --- a/modules/denoise/lightmap_denoiser.h +++ /dev/null @@ -1,56 +0,0 @@ -/*************************************************************************/ -/* lightmap_denoiser.h */ -/*************************************************************************/ -/* This file is part of: */ -/* GODOT ENGINE */ -/* https://godotengine.org */ -/*************************************************************************/ -/* Copyright (c) 2007-2022 Juan Linietsky, Ariel Manzur. */ -/* Copyright (c) 2014-2022 Godot Engine contributors (cf. AUTHORS.md). */ -/* */ -/* Permission is hereby granted, free of charge, to any person obtaining */ -/* a copy of this software and associated documentation files (the */ -/* "Software"), to deal in the Software without restriction, including */ -/* without limitation the rights to use, copy, modify, merge, publish, */ -/* distribute, sublicense, and/or sell copies of the Software, and to */ -/* permit persons to whom the Software is furnished to do so, subject to */ -/* the following conditions: */ -/* */ -/* The above copyright notice and this permission notice shall be */ -/* included in all copies or substantial portions of the Software. */ -/* */ -/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ -/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ -/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ -/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ -/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ -/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ -/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
*/ -/*************************************************************************/ - -#ifndef LIGHTMAP_DENOISER_H -#define LIGHTMAP_DENOISER_H - -#include "core/class_db.h" -#include "scene/3d/lightmapper.h" - -struct OIDNDeviceImpl; - -class LightmapDenoiserOIDN : public LightmapDenoiser { - GDCLASS(LightmapDenoiserOIDN, LightmapDenoiser); - -protected: - void *device = nullptr; - -public: - static LightmapDenoiser *create_oidn_denoiser(); - - Ref denoise_image(const Ref &p_image); - - static void make_default_denoiser(); - - LightmapDenoiserOIDN(); - ~LightmapDenoiserOIDN(); -}; - -#endif // LIGHTMAP_DENOISER_H diff --git a/modules/denoise/register_types.cpp b/modules/denoise/register_types.cpp deleted file mode 100644 index 9a8c9ad29..000000000 --- a/modules/denoise/register_types.cpp +++ /dev/null @@ -1,41 +0,0 @@ -/*************************************************************************/ -/* register_types.cpp */ -/*************************************************************************/ -/* This file is part of: */ -/* GODOT ENGINE */ -/* https://godotengine.org */ -/*************************************************************************/ -/* Copyright (c) 2007-2022 Juan Linietsky, Ariel Manzur. */ -/* Copyright (c) 2014-2022 Godot Engine contributors (cf. AUTHORS.md). */ -/* */ -/* Permission is hereby granted, free of charge, to any person obtaining */ -/* a copy of this software and associated documentation files (the */ -/* "Software"), to deal in the Software without restriction, including */ -/* without limitation the rights to use, copy, modify, merge, publish, */ -/* distribute, sublicense, and/or sell copies of the Software, and to */ -/* permit persons to whom the Software is furnished to do so, subject to */ -/* the following conditions: */ -/* */ -/* The above copyright notice and this permission notice shall be */ -/* included in all copies or substantial portions of the Software. */ -/* */ -/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ -/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ -/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ -/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ -/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ -/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ -/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ -/*************************************************************************/ - -#include "register_types.h" - -#include "core/engine.h" -#include "lightmap_denoiser.h" - -void register_denoise_types() { - LightmapDenoiserOIDN::make_default_denoiser(); -} - -void unregister_denoise_types() { -} diff --git a/modules/denoise/register_types.h b/modules/denoise/register_types.h deleted file mode 100644 index 6ce386dc5..000000000 --- a/modules/denoise/register_types.h +++ /dev/null @@ -1,37 +0,0 @@ -/*************************************************************************/ -/* register_types.h */ -/*************************************************************************/ -/* This file is part of: */ -/* GODOT ENGINE */ -/* https://godotengine.org */ -/*************************************************************************/ -/* Copyright (c) 2007-2022 Juan Linietsky, Ariel Manzur. */ -/* Copyright (c) 2014-2022 Godot Engine contributors (cf. AUTHORS.md). 
*/ -/* */ -/* Permission is hereby granted, free of charge, to any person obtaining */ -/* a copy of this software and associated documentation files (the */ -/* "Software"), to deal in the Software without restriction, including */ -/* without limitation the rights to use, copy, modify, merge, publish, */ -/* distribute, sublicense, and/or sell copies of the Software, and to */ -/* permit persons to whom the Software is furnished to do so, subject to */ -/* the following conditions: */ -/* */ -/* The above copyright notice and this permission notice shall be */ -/* included in all copies or substantial portions of the Software. */ -/* */ -/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ -/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ -/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ -/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ -/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ -/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ -/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ -/*************************************************************************/ - -#ifndef DENOISE_REGISTER_TYPES_H -#define DENOISE_REGISTER_TYPES_H - -void register_denoise_types(); -void unregister_denoise_types(); - -#endif // DENOISE_REGISTER_TYPES_H diff --git a/modules/denoise/resource_to_cpp.py b/modules/denoise/resource_to_cpp.py deleted file mode 100644 index 6c8327735..000000000 --- a/modules/denoise/resource_to_cpp.py +++ /dev/null @@ -1,68 +0,0 @@ -#!/usr/bin/env python - -## ======================================================================== ## -## Copyright 2009-2019 Intel Corporation ## -## ## -## Licensed under the Apache License, Version 2.0 (the "License"); ## -## you may not use this file except in compliance with the License. ## -## You may obtain a copy of the License at ## -## ## -## http://www.apache.org/licenses/LICENSE-2.0 ## -## ## -## Unless required by applicable law or agreed to in writing, software ## -## distributed under the License is distributed on an "AS IS" BASIS, ## -## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ## -## See the License for the specific language governing permissions and ## -## limitations under the License. 
## -## ======================================================================== ## - -import os -from array import array - -# Generates a C++ file from the specified binary resource file -def generate(in_path, out_path): - - namespace = "oidn::weights" - scopes = namespace.split("::") - - file_name = os.path.basename(in_path) - var_name = os.path.splitext(file_name)[0] - - with open(in_path, "rb") as in_file, open(out_path, "w") as out_file: - # Header - out_file.write("// Generated from: %s\n" % file_name) - out_file.write("#include \n\n") - - # Open the namespaces - for s in scopes: - out_file.write("namespace %s {\n" % s) - if scopes: - out_file.write("\n") - - # Read the file - in_data = array("B", in_file.read()) - - # Write the size - out_file.write("//const size_t %s_size = %d;\n\n" % (var_name, len(in_data))) - - # Write the data - out_file.write("unsigned char %s[] = {" % var_name) - for i in range(len(in_data)): - c = in_data[i] - if i > 0: - out_file.write(",") - if (i + 1) % 20 == 1: - out_file.write("\n") - out_file.write("%d" % c) - out_file.write("\n};\n") - - # Close the namespaces - if scopes: - out_file.write("\n") - for scope in reversed(scopes): - out_file.write("} // namespace %s\n" % scope) - - -def tza_to_cpp(target, source, env): - for x in zip(source, target): - generate(str(x[0]), str(x[1])) diff --git a/thirdparty/oidn/LICENSE.txt b/thirdparty/oidn/LICENSE.txt deleted file mode 100644 index d64569567..000000000 --- a/thirdparty/oidn/LICENSE.txt +++ /dev/null @@ -1,202 +0,0 @@ - - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. 
For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. 
You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. 
In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. diff --git a/thirdparty/oidn/common/barrier.h b/thirdparty/oidn/common/barrier.h deleted file mode 100644 index b20f67005..000000000 --- a/thirdparty/oidn/common/barrier.h +++ /dev/null @@ -1,52 +0,0 @@ -// ======================================================================== // -// Copyright 2009-2019 Intel Corporation // -// // -// Licensed under the Apache License, Version 2.0 (the "License"); // -// you may not use this file except in compliance with the License. // -// You may obtain a copy of the License at // -// // -// http://www.apache.org/licenses/LICENSE-2.0 // -// // -// Unless required by applicable law or agreed to in writing, software // -// distributed under the License is distributed on an "AS IS" BASIS, // -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // -// See the License for the specific language governing permissions and // -// limitations under the License. 
// -// ======================================================================== // - -#pragma once - -#include "platform.h" -#include <mutex> -#include <condition_variable> - -namespace oidn { - - class Barrier - { - private: - std::mutex m; - std::condition_variable cv; - volatile int count; - - public: - Barrier(int count) : count(count) {} - - void wait() - { - std::unique_lock<std::mutex> lk(m); - count--; - - if (count == 0) - { - lk.unlock(); - cv.notify_all(); - } - else - { - cv.wait(lk, [&]{ return count == 0; }); - } - } - }; - -} // namespace oidn diff --git a/thirdparty/oidn/common/exception.h b/thirdparty/oidn/common/exception.h deleted file mode 100644 index 18069c6a7..000000000 --- a/thirdparty/oidn/common/exception.h +++ /dev/null @@ -1,45 +0,0 @@ -// ======================================================================== // -// Copyright 2009-2019 Intel Corporation // -// // -// Licensed under the Apache License, Version 2.0 (the "License"); // -// you may not use this file except in compliance with the License. // -// You may obtain a copy of the License at // -// // -// http://www.apache.org/licenses/LICENSE-2.0 // -// // -// Unless required by applicable law or agreed to in writing, software // -// distributed under the License is distributed on an "AS IS" BASIS, // -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // -// See the License for the specific language governing permissions and // -// limitations under the License. // -// ======================================================================== // - -#pragma once - -#include <exception> -#include "platform.h" - -namespace oidn { - - class Exception : public std::exception - { - private: - Error error; - const char* message; - - public: - Exception(Error error, const char* message) - : error(error), message(message) {} - - Error code() const noexcept - { - return error; - } - - const char* what() const noexcept override - { - return message; - } - }; - -} // namespace oidn diff --git a/thirdparty/oidn/common/platform.cpp b/thirdparty/oidn/common/platform.cpp deleted file mode 100644 index 59a14ff47..000000000 --- a/thirdparty/oidn/common/platform.cpp +++ /dev/null @@ -1,114 +0,0 @@ -// ======================================================================== // -// Copyright 2009-2019 Intel Corporation // -// // -// Licensed under the Apache License, Version 2.0 (the "License"); // -// you may not use this file except in compliance with the License. // -// You may obtain a copy of the License at // -// // -// http://www.apache.org/licenses/LICENSE-2.0 // -// // -// Unless required by applicable law or agreed to in writing, software // -// distributed under the License is distributed on an "AS IS" BASIS, // -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // -// See the License for the specific language governing permissions and // -// limitations under the License. 
// -// ======================================================================== // - -#include "platform.h" - -namespace oidn { - - // ---------------------------------------------------------------------------- - // Common functions - // ---------------------------------------------------------------------------- - - void* alignedMalloc(size_t size, size_t alignment) - { - if (size == 0) - return nullptr; - - assert((alignment & (alignment-1)) == 0); - void* ptr = _mm_malloc(size, alignment); - - if (ptr == nullptr) - throw std::bad_alloc(); - - return ptr; - } - - void alignedFree(void* ptr) - { - if (ptr) - _mm_free(ptr); - } - - // ---------------------------------------------------------------------------- - // System information - // ---------------------------------------------------------------------------- - - std::string getPlatformName() - { - std::string name; - - #if defined(__linux__) - name = "Linux"; - #elif defined(__FreeBSD__) - name = "FreeBSD"; - #elif defined(__CYGWIN__) - name = "Cygwin"; - #elif defined(_WIN32) - name = "Windows"; - #elif defined(__APPLE__) - name = "macOS"; - #elif defined(__unix__) - name = "Unix"; - #else - return "Unknown"; - #endif - - #if defined(__x86_64__) || defined(_M_X64) || defined(__ia64__) || defined(__aarch64__) - name += " (64-bit)"; - #else - name += " (32-bit)"; - #endif - - return name; - } - - std::string getCompilerName() - { - #if defined(__INTEL_COMPILER) - int mayor = __INTEL_COMPILER / 100 % 100; - int minor = __INTEL_COMPILER % 100; - std::string version = "Intel Compiler "; - version += toString(mayor); - version += "." + toString(minor); - #if defined(__INTEL_COMPILER_UPDATE) - version += "." + toString(__INTEL_COMPILER_UPDATE); - #endif - return version; - #elif defined(__clang__) - return "Clang " __clang_version__; - #elif defined(__GNUC__) - return "GCC " __VERSION__; - #elif defined(_MSC_VER) - std::string version = toString(_MSC_FULL_VER); - version.insert(4, "."); - version.insert(9, "."); - version.insert(2, "."); - return "Visual C++ Compiler " + version; - #else - return "Unknown"; - #endif - } - - std::string getBuildName() - { - #if defined(NDEBUG) - return "Release"; - #else - return "Debug"; - #endif - } - -} // namespace oidn diff --git a/thirdparty/oidn/common/platform.h b/thirdparty/oidn/common/platform.h deleted file mode 100644 index 9373b617b..000000000 --- a/thirdparty/oidn/common/platform.h +++ /dev/null @@ -1,131 +0,0 @@ -// ======================================================================== // -// Copyright 2009-2019 Intel Corporation // -// // -// Licensed under the Apache License, Version 2.0 (the "License"); // -// you may not use this file except in compliance with the License. // -// You may obtain a copy of the License at // -// // -// http://www.apache.org/licenses/LICENSE-2.0 // -// // -// Unless required by applicable law or agreed to in writing, software // -// distributed under the License is distributed on an "AS IS" BASIS, // -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // -// See the License for the specific language governing permissions and // -// limitations under the License. 
// -// ======================================================================== // - -#pragma once - -#if defined(_WIN32) - #define WIN32_LEAN_AND_MEAN - #define NOMINMAX - #include -#elif defined(__APPLE__) - #include -#endif - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "include/OpenImageDenoise/oidn.hpp" - -namespace oidn { - - // ---------------------------------------------------------------------------- - // Macros - // ---------------------------------------------------------------------------- - - #if defined(_WIN32) - // Windows - #if !defined(__noinline) - #define __noinline __declspec(noinline) - #endif - #else - // Unix - #if !defined(__forceinline) - #define __forceinline inline __attribute__((always_inline)) - #endif - #if !defined(__noinline) - #define __noinline __attribute__((noinline)) - #endif - #endif - - #ifndef UNUSED - #define UNUSED(x) ((void)x) - #endif - #ifndef MAYBE_UNUSED - #define MAYBE_UNUSED(x) UNUSED(x) - #endif - - // ---------------------------------------------------------------------------- - // Error handling and debugging - // ---------------------------------------------------------------------------- - - struct Verbose - { - int verbose; - - Verbose(int v = 0) : verbose(v) {} - __forceinline bool isVerbose(int v = 1) const { return v <= verbose; } - }; - - #define OIDN_WARNING(message) { if (isVerbose()) std::cerr << "Warning: " << message << std::endl; } - #define OIDN_FATAL(message) throw std::runtime_error(message); - - // ---------------------------------------------------------------------------- - // Common functions - // ---------------------------------------------------------------------------- - - using std::min; - using std::max; - - template - __forceinline T clamp(const T& value, const T& minValue, const T& maxValue) - { - return min(max(value, minValue), maxValue); - } - - void* alignedMalloc(size_t size, size_t alignment); - void alignedFree(void* ptr); - - template - inline std::string toString(const T& a) - { - std::stringstream sm; - sm << a; - return sm.str(); - } - -#if defined(__APPLE__) - template - bool getSysctl(const char* name, T& value) - { - int64_t result = 0; - size_t size = sizeof(result); - - if (sysctlbyname(name, &result, &size, nullptr, 0) != 0) - return false; - - value = T(result); - return true; - } -#endif - - // ---------------------------------------------------------------------------- - // System information - // ---------------------------------------------------------------------------- - - std::string getPlatformName(); - std::string getCompilerName(); - std::string getBuildName(); - -} // namespace oidn diff --git a/thirdparty/oidn/common/ref.h b/thirdparty/oidn/common/ref.h deleted file mode 100644 index de44603af..000000000 --- a/thirdparty/oidn/common/ref.h +++ /dev/null @@ -1,163 +0,0 @@ -// ======================================================================== // -// Copyright 2009-2019 Intel Corporation // -// // -// Licensed under the Apache License, Version 2.0 (the "License"); // -// you may not use this file except in compliance with the License. // -// You may obtain a copy of the License at // -// // -// http://www.apache.org/licenses/LICENSE-2.0 // -// // -// Unless required by applicable law or agreed to in writing, software // -// distributed under the License is distributed on an "AS IS" BASIS, // -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// -// See the License for the specific language governing permissions and // -// limitations under the License. // -// ======================================================================== // - -#pragma once - -#include "platform.h" - -namespace oidn { - - class RefCount - { - private: - std::atomic<size_t> count; - - public: - __forceinline RefCount(int count = 0) noexcept : count(count) {} - - __forceinline size_t incRef() noexcept - { - return count.fetch_add(1) + 1; - } - - __forceinline size_t decRef() - { - const size_t newCount = decRefKeep(); - if (newCount == 0) - destroy(); - return newCount; - } - - __forceinline size_t decRefKeep() noexcept - { - return count.fetch_add(-1) - 1; - } - - __forceinline void destroy() - { - delete this; - } - - protected: - // Disable copying - RefCount(const RefCount&) = delete; - RefCount& operator =(const RefCount&) = delete; - - virtual ~RefCount() noexcept = default; - }; - - template<typename T> - class Ref - { - private: - T* ptr; - - public: - __forceinline Ref() noexcept : ptr(nullptr) {} - __forceinline Ref(std::nullptr_t) noexcept : ptr(nullptr) {} - __forceinline Ref(const Ref& other) noexcept : ptr(other.ptr) { if (ptr) ptr->incRef(); } - __forceinline Ref(Ref&& other) noexcept : ptr(other.ptr) { other.ptr = nullptr; } - __forceinline Ref(T* ptr) noexcept : ptr(ptr) { if (ptr) ptr->incRef(); } - - template<typename Y> - __forceinline Ref(const Ref<Y>& other) noexcept : ptr(other.get()) { if (ptr) ptr->incRef(); } - - template<typename Y> - __forceinline explicit Ref(Y* ptr) noexcept : ptr(ptr) { if (ptr) ptr->incRef(); } - - __forceinline ~Ref() { if (ptr) ptr->decRef(); } - - __forceinline Ref& operator =(const Ref& other) - { - if (other.ptr) - other.ptr->incRef(); - if (ptr) - ptr->decRef(); - ptr = other.ptr; - return *this; - } - - __forceinline Ref& operator =(Ref&& other) - { - if (ptr) - ptr->decRef(); - ptr = other.ptr; - other.ptr = nullptr; - return *this; - } - - __forceinline Ref& operator =(T* other) - { - if (other) - other->incRef(); - if (ptr) - ptr->decRef(); - ptr = other; - return *this; - } - - __forceinline Ref& operator =(std::nullptr_t) - { - if (ptr) - ptr->decRef(); - ptr = nullptr; - return *this; - } - - __forceinline operator bool() const noexcept { return ptr != nullptr; } - - __forceinline T& operator *() const noexcept { return *ptr; } - __forceinline T* operator ->() const noexcept { return ptr; } - - __forceinline T* get() const noexcept { return ptr; } - - __forceinline T* detach() noexcept - { - T* res = ptr; - ptr = nullptr; - return res; - } - }; - - template<typename T> __forceinline bool operator < (const Ref<T>& a, const Ref<T>& b) noexcept { return a.ptr < b.ptr; } - - template<typename T> __forceinline bool operator ==(const Ref<T>& a, std::nullptr_t) noexcept { return a.ptr == nullptr; } - template<typename T> __forceinline bool operator ==(std::nullptr_t, const Ref<T>& b) noexcept { return nullptr == b.ptr; } - template<typename T> __forceinline bool operator ==(const Ref<T>& a, const Ref<T>& b) noexcept { return a.ptr == b.ptr; } - - template<typename T> __forceinline bool operator !=(const Ref<T>& a, std::nullptr_t) noexcept { return a.ptr != nullptr; } - template<typename T> __forceinline bool operator !=(std::nullptr_t, const Ref<T>& b) noexcept { return nullptr != b.ptr; } - template<typename T> __forceinline bool operator !=(const Ref<T>& a, const Ref<T>& b) noexcept { return a.ptr != b.ptr; } - - template<typename T, typename... Args> - __forceinline Ref<T> makeRef(Args&&... 
args) - { - return Ref<T>(new T(std::forward<Args>(args)...)); - } - - template<typename T, typename Y> - __forceinline Ref<T> staticRefCast(const Ref<Y>& a) - { - return Ref<T>(static_cast<T*>(a.get())); - } - - template<typename T, typename Y> - __forceinline Ref<T> dynamicRefCast(const Ref<Y>& a) - { - return Ref<T>(dynamic_cast<T*>(a.get())); - } - -} // namespace oidn diff --git a/thirdparty/oidn/common/tensor.cpp b/thirdparty/oidn/common/tensor.cpp deleted file mode 100644 index 0249f2e14..000000000 --- a/thirdparty/oidn/common/tensor.cpp +++ /dev/null @@ -1,83 +0,0 @@ -// ======================================================================== // -// Copyright 2009-2019 Intel Corporation // -// // -// Licensed under the Apache License, Version 2.0 (the "License"); // -// you may not use this file except in compliance with the License. // -// You may obtain a copy of the License at // -// // -// http://www.apache.org/licenses/LICENSE-2.0 // -// // -// Unless required by applicable law or agreed to in writing, software // -// distributed under the License is distributed on an "AS IS" BASIS, // -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // -// See the License for the specific language governing permissions and // -// limitations under the License. // -// ======================================================================== // - -#include "exception.h" -#include "tensor.h" - -namespace oidn { - - std::map<std::string, Tensor> parseTensors(void* buffer) - { - char* input = (char*)buffer; - - // Parse the magic value - const int magic = *(unsigned short*)input; - if (magic != 0x41D7) - throw Exception(Error::InvalidOperation, "invalid tensor archive"); - input += sizeof(unsigned short); - - // Parse the version - const int majorVersion = *(unsigned char*)input++; - const int minorVersion = *(unsigned char*)input++; - UNUSED(minorVersion); - if (majorVersion > 1) - throw Exception(Error::InvalidOperation, "unsupported tensor archive version"); - - // Parse the number of tensors - const int numTensors = *(int*)input; - input += sizeof(int); - - // Parse the tensors - std::map<std::string, Tensor> tensorMap; - for (int i = 0; i < numTensors; ++i) - { - Tensor tensor; - - // Parse the name - const int nameLen = *(unsigned char*)input++; - std::string name(input, nameLen); - input += nameLen; - - // Parse the number of dimensions - const int ndims = *(unsigned char*)input++; - - // Parse the shape of the tensor - tensor.dims.resize(ndims); - for (int i = 0; i < ndims; ++i) - tensor.dims[i] = ((int*)input)[i]; - input += ndims * sizeof(int); - - // Parse the format of the tensor - tensor.format = std::string(input, input + ndims); - input += ndims; - - // Parse the data type of the tensor - const char type = *(unsigned char*)input++; - if (type != 'f') // only float32 is supported - throw Exception(Error::InvalidOperation, "unsupported tensor data type"); - - // Skip the data - tensor.data = (float*)input; - input += tensor.size() * sizeof(float); - - // Add the tensor to the map - tensorMap.emplace(name, std::move(tensor)); - } - - return tensorMap; - } - -} // namespace oidn diff --git a/thirdparty/oidn/common/tensor.h b/thirdparty/oidn/common/tensor.h deleted file mode 100644 index 48e7d1123..000000000 --- a/thirdparty/oidn/common/tensor.h +++ /dev/null @@ -1,66 +0,0 @@ -// ======================================================================== // -// Copyright 2009-2019 Intel Corporation // -// // -// Licensed under the Apache License, Version 2.0 (the "License"); // -// you may not use this file except in compliance with the License. 
// -// You may obtain a copy of the License at // -// // -// http://www.apache.org/licenses/LICENSE-2.0 // -// // -// Unless required by applicable law or agreed to in writing, software // -// distributed under the License is distributed on an "AS IS" BASIS, // -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // -// See the License for the specific language governing permissions and // -// limitations under the License. // -// ======================================================================== // - -#pragma once - -#include "platform.h" -#include <vector> -#include <map> - -namespace oidn { - - template<typename T> - using shared_vector = std::shared_ptr<std::vector<T>>; - - // Generic tensor - struct Tensor - { - float* data; - std::vector<int> dims; - std::string format; - shared_vector<char> buffer; // optional, only for reference counting - - __forceinline Tensor() : data(nullptr) {} - - __forceinline Tensor(const std::vector<int>& dims, const std::string& format) - : dims(dims), - format(format) - { - buffer = std::make_shared<std::vector<char>>(size() * sizeof(float)); - data = (float*)buffer->data(); - } - - __forceinline operator bool() const { return data != nullptr; } - - __forceinline int ndims() const { return (int)dims.size(); } - - // Returns the number of values - __forceinline size_t size() const - { - size_t size = 1; - for (int i = 0; i < ndims(); ++i) - size *= dims[i]; - return size; - } - - __forceinline float& operator [](size_t i) { return data[i]; } - __forceinline const float& operator [](size_t i) const { return data[i]; } - }; - - // Parses tensors from a buffer - std::map<std::string, Tensor> parseTensors(void* buffer); - -} // namespace oidn diff --git a/thirdparty/oidn/common/thread.cpp b/thirdparty/oidn/common/thread.cpp deleted file mode 100644 index 48c489c57..000000000 --- a/thirdparty/oidn/common/thread.cpp +++ /dev/null @@ -1,297 +0,0 @@ -// ======================================================================== // -// Copyright 2009-2019 Intel Corporation // -// // -// Licensed under the Apache License, Version 2.0 (the "License"); // -// you may not use this file except in compliance with the License. // -// You may obtain a copy of the License at // -// // -// http://www.apache.org/licenses/LICENSE-2.0 // -// // -// Unless required by applicable law or agreed to in writing, software // -// distributed under the License is distributed on an "AS IS" BASIS, // -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // -// See the License for the specific language governing permissions and // -// limitations under the License. 
// -// ======================================================================== // - -#if defined(_MSC_VER) - #pragma warning (disable : 4146) // unary minus operator applied to unsigned type, result still unsigned -#endif - -#if defined(__APPLE__) - #include - #include -#endif - -#include "thread.h" -#include - -namespace oidn { - -#if defined(_WIN32) - - // -------------------------------------------------------------------------- - // ThreadAffinity - Windows - // -------------------------------------------------------------------------- - - ThreadAffinity::ThreadAffinity(int numThreadsPerCore, int verbose) - : Verbose(verbose) - { - HMODULE hLib = GetModuleHandle(TEXT("kernel32")); - pGetLogicalProcessorInformationEx = (GetLogicalProcessorInformationExFunc)GetProcAddress(hLib, "GetLogicalProcessorInformationEx"); - pSetThreadGroupAffinity = (SetThreadGroupAffinityFunc)GetProcAddress(hLib, "SetThreadGroupAffinity"); - - if (pGetLogicalProcessorInformationEx && pSetThreadGroupAffinity) - { - // Get logical processor information - PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX buffer = nullptr; - DWORD bufferSize = 0; - - // First call the function with an empty buffer to get the required buffer size - BOOL result = pGetLogicalProcessorInformationEx(RelationProcessorCore, buffer, &bufferSize); - if (result || GetLastError() != ERROR_INSUFFICIENT_BUFFER) - { - OIDN_WARNING("GetLogicalProcessorInformationEx failed"); - return; - } - - // Allocate the buffer - buffer = (PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX)malloc(bufferSize); - if (!buffer) - { - OIDN_WARNING("SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX allocation failed"); - return; - } - - // Call again the function but now with the properly sized buffer - result = pGetLogicalProcessorInformationEx(RelationProcessorCore, buffer, &bufferSize); - if (!result) - { - OIDN_WARNING("GetLogicalProcessorInformationEx failed"); - free(buffer); - return; - } - - // Iterate over the logical processor information structures - // There should be one structure for each physical core - char* ptr = (char*)buffer; - while (ptr < (char*)buffer + bufferSize) - { - PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX item = (PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX)ptr; - if (item->Relationship == RelationProcessorCore && item->Processor.GroupCount > 0) - { - // Iterate over the groups - int numThreads = 0; - for (int group = 0; (group < item->Processor.GroupCount) && (numThreads < numThreadsPerCore); ++group) - { - GROUP_AFFINITY coreAffinity = item->Processor.GroupMask[group]; - while ((coreAffinity.Mask != 0) && (numThreads < numThreadsPerCore)) - { - // Extract the next set bit/thread from the mask - GROUP_AFFINITY threadAffinity = coreAffinity; - threadAffinity.Mask = threadAffinity.Mask & -threadAffinity.Mask; - - // Push the affinity for this thread - affinities.push_back(threadAffinity); - oldAffinities.push_back(threadAffinity); - numThreads++; - - // Remove this bit/thread from the mask - coreAffinity.Mask ^= threadAffinity.Mask; - } - } - } - - // Next structure - ptr += item->Size; - } - - // Free the buffer - free(buffer); - } - } - - void ThreadAffinity::set(int threadIndex) - { - if (threadIndex >= (int)affinities.size()) - return; - - // Save the current affinity and set the new one - const HANDLE thread = GetCurrentThread(); - if (!pSetThreadGroupAffinity(thread, &affinities[threadIndex], &oldAffinities[threadIndex])) - OIDN_WARNING("SetThreadGroupAffinity failed"); - } - - void ThreadAffinity::restore(int threadIndex) - { - if (threadIndex >= 
(int)affinities.size()) - return; - - // Restore the original affinity - const HANDLE thread = GetCurrentThread(); - if (!pSetThreadGroupAffinity(thread, &oldAffinities[threadIndex], nullptr)) - OIDN_WARNING("SetThreadGroupAffinity failed"); - } - -#elif defined(__linux__) - - // -------------------------------------------------------------------------- - // ThreadAffinity - Linux - // -------------------------------------------------------------------------- - - ThreadAffinity::ThreadAffinity(int numThreadsPerCore, int verbose) - : Verbose(verbose) - { - std::vector threadIds; - - // Parse the thread/CPU topology - for (int cpuId = 0; ; cpuId++) - { - std::fstream fs; - std::string cpu = std::string("/sys/devices/system/cpu/cpu") + std::to_string(cpuId) + std::string("/topology/thread_siblings_list"); - fs.open(cpu.c_str(), std::fstream::in); - if (fs.fail()) break; - - int i; - int j = 0; - while ((j < numThreadsPerCore) && (fs >> i)) - { - if (std::none_of(threadIds.begin(), threadIds.end(), [&](int id) { return id == i; })) - threadIds.push_back(i); - - if (fs.peek() == ',') - fs.ignore(); - j++; - } - - fs.close(); - } - - #if 0 - for (size_t i = 0; i < thread_ids.size(); ++i) - std::cout << "thread " << i << " -> " << thread_ids[i] << std::endl; - #endif - - // Create the affinity structures - affinities.resize(threadIds.size()); - oldAffinities.resize(threadIds.size()); - - for (size_t i = 0; i < threadIds.size(); ++i) - { - cpu_set_t affinity; - CPU_ZERO(&affinity); - CPU_SET(threadIds[i], &affinity); - - affinities[i] = affinity; - oldAffinities[i] = affinity; - } - } - - void ThreadAffinity::set(int threadIndex) - { - if (threadIndex >= (int)affinities.size()) - return; - - const pthread_t thread = pthread_self(); - - // Save the current affinity - if (pthread_getaffinity_np(thread, sizeof(cpu_set_t), &oldAffinities[threadIndex]) != 0) - { - OIDN_WARNING("pthread_getaffinity_np failed"); - oldAffinities[threadIndex] = affinities[threadIndex]; - return; - } - - // Set the new affinity - if (pthread_setaffinity_np(thread, sizeof(cpu_set_t), &affinities[threadIndex]) != 0) - OIDN_WARNING("pthread_setaffinity_np failed"); - } - - void ThreadAffinity::restore(int threadIndex) - { - if (threadIndex >= (int)affinities.size()) - return; - - const pthread_t thread = pthread_self(); - - // Restore the original affinity - if (pthread_setaffinity_np(thread, sizeof(cpu_set_t), &oldAffinities[threadIndex]) != 0) - OIDN_WARNING("pthread_setaffinity_np failed"); - } - -#elif defined(__APPLE__) - - // -------------------------------------------------------------------------- - // ThreadAffinity - macOS - // -------------------------------------------------------------------------- - - ThreadAffinity::ThreadAffinity(int numThreadsPerCore, int verbose) - : Verbose(verbose) - { - // Query the thread/CPU topology - int numPhysicalCpus; - int numLogicalCpus; - - if (!getSysctl("hw.physicalcpu", numPhysicalCpus) || !getSysctl("hw.logicalcpu", numLogicalCpus)) - { - OIDN_WARNING("sysctlbyname failed"); - return; - } - - if ((numLogicalCpus % numPhysicalCpus != 0) && (numThreadsPerCore > 1)) - return; // this shouldn't happen - const int maxThreadsPerCore = numLogicalCpus / numPhysicalCpus; - - // Create the affinity structures - // macOS doesn't support binding a thread to a specific core, but we can at least group threads which - // should be on the same core together - for (int core = 1; core <= numPhysicalCpus; ++core) // tags start from 1! 
- { - thread_affinity_policy affinity; - affinity.affinity_tag = core; - - for (int thread = 0; thread < min(numThreadsPerCore, maxThreadsPerCore); ++thread) - { - affinities.push_back(affinity); - oldAffinities.push_back(affinity); - } - } - } - - void ThreadAffinity::set(int threadIndex) - { - if (threadIndex >= (int)affinities.size()) - return; - - const auto thread = mach_thread_self(); - - // Save the current affinity - mach_msg_type_number_t policyCount = THREAD_AFFINITY_POLICY_COUNT; - boolean_t getDefault = FALSE; - if (thread_policy_get(thread, THREAD_AFFINITY_POLICY, (thread_policy_t)&oldAffinities[threadIndex], &policyCount, &getDefault) != KERN_SUCCESS) - { - OIDN_WARNING("thread_policy_get failed"); - oldAffinities[threadIndex] = affinities[threadIndex]; - return; - } - - // Set the new affinity - if (thread_policy_set(thread, THREAD_AFFINITY_POLICY, (thread_policy_t)&affinities[threadIndex], THREAD_AFFINITY_POLICY_COUNT) != KERN_SUCCESS) - OIDN_WARNING("thread_policy_set failed"); - } - - void ThreadAffinity::restore(int threadIndex) - { - if (threadIndex >= (int)affinities.size()) - return; - - const auto thread = mach_thread_self(); - - // Restore the original affinity - if (thread_policy_set(thread, THREAD_AFFINITY_POLICY, (thread_policy_t)&oldAffinities[threadIndex], THREAD_AFFINITY_POLICY_COUNT) != KERN_SUCCESS) - OIDN_WARNING("thread_policy_set failed"); - } - -#endif - -} // namespace oidn diff --git a/thirdparty/oidn/common/thread.h b/thirdparty/oidn/common/thread.h deleted file mode 100644 index 2c731367d..000000000 --- a/thirdparty/oidn/common/thread.h +++ /dev/null @@ -1,202 +0,0 @@ -// ======================================================================== // -// Copyright 2009-2019 Intel Corporation // -// // -// Licensed under the Apache License, Version 2.0 (the "License"); // -// you may not use this file except in compliance with the License. // -// You may obtain a copy of the License at // -// // -// http://www.apache.org/licenses/LICENSE-2.0 // -// // -// Unless required by applicable law or agreed to in writing, software // -// distributed under the License is distributed on an "AS IS" BASIS, // -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // -// See the License for the specific language governing permissions and // -// limitations under the License. 
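The Linux ThreadAffinity constructor above discovers physical cores purely from sysfs: for each cpuN it reads topology/thread_siblings_list and records up to numThreadsPerCore previously unseen sibling IDs. A standalone sketch of the one-thread-per-core case (hypothetical helper names, not from the deleted file):

    #include <algorithm>
    #include <fstream>
    #include <iostream>
    #include <string>
    #include <vector>

    // Pick one logical CPU per physical core by reading the sysfs sibling lists,
    // the same source the Linux ThreadAffinity constructor parses.
    std::vector<int> primarySiblings()
    {
      std::vector<int> ids;
      for (int cpu = 0; ; ++cpu)
      {
        std::ifstream fs("/sys/devices/system/cpu/cpu" + std::to_string(cpu) +
                         "/topology/thread_siblings_list");
        if (!fs)
          break; // no more CPUs
        int first;
        if (fs >> first && std::find(ids.begin(), ids.end(), first) == ids.end())
          ids.push_back(first); // the first unseen sibling stands for the whole core
      }
      return ids;
    }

    int main()
    {
      for (int id : primarySiblings())
        std::cout << id << '\n'; // one entry per physical core
    }

The sibling list is typically a comma-separated list such as "0,4"; like the code above, the sketch only needs its first entry per core.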
// -// ======================================================================== // - -#pragma once - -#include "platform.h" - -#if !defined(_WIN32) - #include - #include - #if defined(__APPLE__) - #include - #endif -#endif - -#include -#include - -namespace oidn { - - // -------------------------------------------------------------------------- - // ThreadLocal - // -------------------------------------------------------------------------- - - // Wrapper which makes any variable thread-local - template - class ThreadLocal : public Verbose - { - private: - #if defined(_WIN32) - DWORD key; - #else - pthread_key_t key; - #endif - - std::vector instances; - std::mutex mutex; - - public: - ThreadLocal(int verbose = 0) - : Verbose(verbose) - { - #if defined(_WIN32) - key = TlsAlloc(); - if (key == TLS_OUT_OF_INDEXES) - OIDN_FATAL("TlsAlloc failed"); - #else - if (pthread_key_create(&key, nullptr) != 0) - OIDN_FATAL("pthread_key_create failed"); - #endif - } - - ~ThreadLocal() - { - std::lock_guard lock(mutex); - for (T* ptr : instances) - delete ptr; - - #if defined(_WIN32) - if (!TlsFree(key)) - OIDN_WARNING("TlsFree failed"); - #else - if (pthread_key_delete(key) != 0) - OIDN_WARNING("pthread_key_delete failed"); - #endif - } - - T& get() - { - #if defined(_WIN32) - T* ptr = (T*)TlsGetValue(key); - #else - T* ptr = (T*)pthread_getspecific(key); - #endif - - if (ptr) - return *ptr; - - ptr = new T; - std::lock_guard lock(mutex); - instances.push_back(ptr); - - #if defined(_WIN32) - if (!TlsSetValue(key, ptr)) - OIDN_FATAL("TlsSetValue failed"); - #else - if (pthread_setspecific(key, ptr) != 0) - OIDN_FATAL("pthread_setspecific failed"); - #endif - - return *ptr; - } - }; - -#if defined(_WIN32) - - // -------------------------------------------------------------------------- - // ThreadAffinity - Windows - // -------------------------------------------------------------------------- - - class ThreadAffinity : public Verbose - { - private: - typedef BOOL (WINAPI *GetLogicalProcessorInformationExFunc)(LOGICAL_PROCESSOR_RELATIONSHIP, - PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX, - PDWORD); - - typedef BOOL (WINAPI *SetThreadGroupAffinityFunc)(HANDLE, - CONST GROUP_AFFINITY*, - PGROUP_AFFINITY); - - GetLogicalProcessorInformationExFunc pGetLogicalProcessorInformationEx = nullptr; - SetThreadGroupAffinityFunc pSetThreadGroupAffinity = nullptr; - - std::vector affinities; // thread affinities - std::vector oldAffinities; // original thread affinities - - public: - ThreadAffinity(int numThreadsPerCore = INT_MAX, int verbose = 0); - - int getNumThreads() const - { - return (int)affinities.size(); - } - - // Sets the affinity (0..numThreads-1) of the thread after saving the current affinity - void set(int threadIndex); - - // Restores the affinity of the thread - void restore(int threadIndex); - }; - -#elif defined(__linux__) - - // -------------------------------------------------------------------------- - // ThreadAffinity - Linux - // -------------------------------------------------------------------------- - - class ThreadAffinity : public Verbose - { - private: - std::vector affinities; // thread affinities - std::vector oldAffinities; // original thread affinities - - public: - ThreadAffinity(int numThreadsPerCore = INT_MAX, int verbose = 0); - - int getNumThreads() const - { - return (int)affinities.size(); - } - - // Sets the affinity (0..numThreads-1) of the thread after saving the current affinity - void set(int threadIndex); - - // Restores the affinity of the thread - void restore(int 
threadIndex); - }; - -#elif defined(__APPLE__) - - // -------------------------------------------------------------------------- - // ThreadAffinity - macOS - // -------------------------------------------------------------------------- - - class ThreadAffinity : public Verbose - { - private: - std::vector affinities; // thread affinities - std::vector oldAffinities; // original thread affinities - - public: - ThreadAffinity(int numThreadsPerCore = INT_MAX, int verbose = 0); - - int getNumThreads() const - { - return (int)affinities.size(); - } - - // Sets the affinity (0..numThreads-1) of the thread after saving the current affinity - void set(int threadIndex); - - // Restores the affinity of the thread - void restore(int threadIndex); - }; - -#endif - -} // namespace oidn diff --git a/thirdparty/oidn/common/timer.h b/thirdparty/oidn/common/timer.h deleted file mode 100644 index 62aaaa1c3..000000000 --- a/thirdparty/oidn/common/timer.h +++ /dev/null @@ -1,49 +0,0 @@ -// ======================================================================== // -// Copyright 2009-2019 Intel Corporation // -// // -// Licensed under the Apache License, Version 2.0 (the "License"); // -// you may not use this file except in compliance with the License. // -// You may obtain a copy of the License at // -// // -// http://www.apache.org/licenses/LICENSE-2.0 // -// // -// Unless required by applicable law or agreed to in writing, software // -// distributed under the License is distributed on an "AS IS" BASIS, // -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // -// See the License for the specific language governing permissions and // -// limitations under the License. // -// ======================================================================== // - -#pragma once - -#include "platform.h" -#include - -namespace oidn { - - class Timer - { - private: - using clock = std::chrono::high_resolution_clock; - - std::chrono::time_point start; - - public: - Timer() - { - reset(); - } - - void reset() - { - start = clock::now(); - } - - double query() const - { - auto end = clock::now(); - return std::chrono::duration_cast>(end - start).count(); - } - }; - -} // namespace oidn diff --git a/thirdparty/oidn/core/api.cpp b/thirdparty/oidn/core/api.cpp deleted file mode 100644 index 7353fe4e2..000000000 --- a/thirdparty/oidn/core/api.cpp +++ /dev/null @@ -1,408 +0,0 @@ -// ======================================================================== // -// Copyright 2009-2019 Intel Corporation // -// // -// Licensed under the Apache License, Version 2.0 (the "License"); // -// you may not use this file except in compliance with the License. // -// You may obtain a copy of the License at // -// // -// http://www.apache.org/licenses/LICENSE-2.0 // -// // -// Unless required by applicable law or agreed to in writing, software // -// distributed under the License is distributed on an "AS IS" BASIS, // -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // -// See the License for the specific language governing permissions and // -// limitations under the License. // -// ======================================================================== // - -#ifdef _WIN32 -# define OIDN_API extern "C" __declspec(dllexport) -#else -# define OIDN_API extern "C" __attribute__ ((visibility ("default"))) -#endif - -// Locks the device that owns the specified object -// Use *only* inside OIDN_TRY/CATCH! 
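The Timer above just captures a high_resolution_clock time point and reports elapsed time as a double in seconds. An equivalent standalone usage sketch (hypothetical example, not from the deleted sources):

    #include <chrono>
    #include <cstdio>
    #include <thread>

    int main()
    {
      using clock = std::chrono::high_resolution_clock;
      const auto start = clock::now();                         // Timer() / reset()
      std::this_thread::sleep_for(std::chrono::milliseconds(50));
      const double seconds =                                   // query()
          std::chrono::duration_cast<std::chrono::duration<double>>(clock::now() - start).count();
      std::printf("elapsed: %.3f s\n", seconds);
    }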
-#define OIDN_LOCK(obj) \ - std::lock_guard lock(obj->getDevice()->getMutex()); - -// Try/catch for converting exceptions to errors -#define OIDN_TRY \ - try { - -#define OIDN_CATCH(obj) \ - } catch (Exception& e) { \ - Device::setError(obj ? obj->getDevice() : nullptr, e.code(), e.what()); \ - } catch (std::bad_alloc&) { \ - Device::setError(obj ? obj->getDevice() : nullptr, Error::OutOfMemory, "out of memory"); \ - } catch (mkldnn::error& e) { \ - if (e.status == mkldnn_out_of_memory) \ - Device::setError(obj ? obj->getDevice() : nullptr, Error::OutOfMemory, "out of memory"); \ - else \ - Device::setError(obj ? obj->getDevice() : nullptr, Error::Unknown, e.message); \ - } catch (std::exception& e) { \ - Device::setError(obj ? obj->getDevice() : nullptr, Error::Unknown, e.what()); \ - } catch (...) { \ - Device::setError(obj ? obj->getDevice() : nullptr, Error::Unknown, "unknown exception caught"); \ - } - -#include "device.h" -#include "filter.h" -#include - -namespace oidn { - - namespace - { - __forceinline void checkHandle(void* handle) - { - if (handle == nullptr) - throw Exception(Error::InvalidArgument, "invalid handle"); - } - - template - __forceinline void retainObject(T* obj) - { - if (obj) - { - obj->incRef(); - } - else - { - OIDN_TRY - checkHandle(obj); - OIDN_CATCH(obj) - } - } - - template - __forceinline void releaseObject(T* obj) - { - if (obj == nullptr || obj->decRefKeep() == 0) - { - OIDN_TRY - checkHandle(obj); - OIDN_LOCK(obj); - obj->destroy(); - OIDN_CATCH(obj) - } - } - - template<> - __forceinline void releaseObject(Device* obj) - { - if (obj == nullptr || obj->decRefKeep() == 0) - { - OIDN_TRY - checkHandle(obj); - // Do NOT lock the device because it owns the mutex - obj->destroy(); - OIDN_CATCH(obj) - } - } - } - - OIDN_API OIDNDevice oidnNewDevice(OIDNDeviceType type) - { - Ref device = nullptr; - OIDN_TRY - if (type == OIDN_DEVICE_TYPE_CPU || type == OIDN_DEVICE_TYPE_DEFAULT) - device = makeRef(); - else - throw Exception(Error::InvalidArgument, "invalid device type"); - OIDN_CATCH(device) - return (OIDNDevice)device.detach(); - } - - OIDN_API void oidnRetainDevice(OIDNDevice hDevice) - { - Device* device = (Device*)hDevice; - retainObject(device); - } - - OIDN_API void oidnReleaseDevice(OIDNDevice hDevice) - { - Device* device = (Device*)hDevice; - releaseObject(device); - } - - OIDN_API void oidnSetDevice1b(OIDNDevice hDevice, const char* name, bool value) - { - Device* device = (Device*)hDevice; - OIDN_TRY - checkHandle(hDevice); - OIDN_LOCK(device); - device->set1i(name, value); - OIDN_CATCH(device) - } - - OIDN_API void oidnSetDevice1i(OIDNDevice hDevice, const char* name, int value) - { - Device* device = (Device*)hDevice; - OIDN_TRY - checkHandle(hDevice); - OIDN_LOCK(device); - device->set1i(name, value); - OIDN_CATCH(device) - } - - OIDN_API bool oidnGetDevice1b(OIDNDevice hDevice, const char* name) - { - Device* device = (Device*)hDevice; - OIDN_TRY - checkHandle(hDevice); - OIDN_LOCK(device); - return device->get1i(name); - OIDN_CATCH(device) - return false; - } - - OIDN_API int oidnGetDevice1i(OIDNDevice hDevice, const char* name) - { - Device* device = (Device*)hDevice; - OIDN_TRY - checkHandle(hDevice); - OIDN_LOCK(device); - return device->get1i(name); - OIDN_CATCH(device) - return 0; - } - - OIDN_API void oidnSetDeviceErrorFunction(OIDNDevice hDevice, OIDNErrorFunction func, void* userPtr) - { - Device* device = (Device*)hDevice; - OIDN_TRY - checkHandle(hDevice); - OIDN_LOCK(device); - device->setErrorFunction((ErrorFunction)func, userPtr); 
- OIDN_CATCH(device) - } - - OIDN_API OIDNError oidnGetDeviceError(OIDNDevice hDevice, const char** outMessage) - { - Device* device = (Device*)hDevice; - OIDN_TRY - return (OIDNError)Device::getError(device, outMessage); - OIDN_CATCH(device) - if (outMessage) *outMessage = ""; - return OIDN_ERROR_UNKNOWN; - } - - OIDN_API void oidnCommitDevice(OIDNDevice hDevice) - { - Device* device = (Device*)hDevice; - OIDN_TRY - checkHandle(hDevice); - OIDN_LOCK(device); - device->commit(); - OIDN_CATCH(device) - } - - OIDN_API OIDNBuffer oidnNewBuffer(OIDNDevice hDevice, size_t byteSize) - { - Device* device = (Device*)hDevice; - OIDN_TRY - checkHandle(hDevice); - OIDN_LOCK(device); - Ref buffer = device->newBuffer(byteSize); - return (OIDNBuffer)buffer.detach(); - OIDN_CATCH(device) - return nullptr; - } - - OIDN_API OIDNBuffer oidnNewSharedBuffer(OIDNDevice hDevice, void* ptr, size_t byteSize) - { - Device* device = (Device*)hDevice; - OIDN_TRY - checkHandle(hDevice); - OIDN_LOCK(device); - Ref buffer = device->newBuffer(ptr, byteSize); - return (OIDNBuffer)buffer.detach(); - OIDN_CATCH(device) - return nullptr; - } - - OIDN_API void oidnRetainBuffer(OIDNBuffer hBuffer) - { - Buffer* buffer = (Buffer*)hBuffer; - retainObject(buffer); - } - - OIDN_API void oidnReleaseBuffer(OIDNBuffer hBuffer) - { - Buffer* buffer = (Buffer*)hBuffer; - releaseObject(buffer); - } - - OIDN_API void* oidnMapBuffer(OIDNBuffer hBuffer, OIDNAccess access, size_t byteOffset, size_t byteSize) - { - Buffer* buffer = (Buffer*)hBuffer; - OIDN_TRY - checkHandle(hBuffer); - OIDN_LOCK(buffer); - return buffer->map(byteOffset, byteSize); - OIDN_CATCH(buffer) - return nullptr; - } - - OIDN_API void oidnUnmapBuffer(OIDNBuffer hBuffer, void* mappedPtr) - { - Buffer* buffer = (Buffer*)hBuffer; - OIDN_TRY - checkHandle(hBuffer); - OIDN_LOCK(buffer); - return buffer->unmap(mappedPtr); - OIDN_CATCH(buffer) - } - - OIDN_API OIDNFilter oidnNewFilter(OIDNDevice hDevice, const char* type) - { - Device* device = (Device*)hDevice; - OIDN_TRY - checkHandle(hDevice); - OIDN_LOCK(device); - Ref filter = device->newFilter(type); - return (OIDNFilter)filter.detach(); - OIDN_CATCH(device) - return nullptr; - } - - OIDN_API void oidnRetainFilter(OIDNFilter hFilter) - { - Filter* filter = (Filter*)hFilter; - retainObject(filter); - } - - OIDN_API void oidnReleaseFilter(OIDNFilter hFilter) - { - Filter* filter = (Filter*)hFilter; - releaseObject(filter); - } - - OIDN_API void oidnSetFilterImage(OIDNFilter hFilter, const char* name, - OIDNBuffer hBuffer, OIDNFormat format, - size_t width, size_t height, - size_t byteOffset, - size_t bytePixelStride, size_t byteRowStride) - { - Filter* filter = (Filter*)hFilter; - OIDN_TRY - checkHandle(hFilter); - checkHandle(hBuffer); - OIDN_LOCK(filter); - Ref buffer = (Buffer*)hBuffer; - if (buffer->getDevice() != filter->getDevice()) - throw Exception(Error::InvalidArgument, "the specified objects are bound to different devices"); - Image data(buffer, (Format)format, (int)width, (int)height, byteOffset, bytePixelStride, byteRowStride); - filter->setImage(name, data); - OIDN_CATCH(filter) - } - - OIDN_API void oidnSetSharedFilterImage(OIDNFilter hFilter, const char* name, - void* ptr, OIDNFormat format, - size_t width, size_t height, - size_t byteOffset, - size_t bytePixelStride, size_t byteRowStride) - { - Filter* filter = (Filter*)hFilter; - OIDN_TRY - checkHandle(hFilter); - OIDN_LOCK(filter); - Image data(ptr, (Format)format, (int)width, (int)height, byteOffset, bytePixelStride, byteRowStride); - 
filter->setImage(name, data); - OIDN_CATCH(filter) - } - - OIDN_API void oidnSetFilter1b(OIDNFilter hFilter, const char* name, bool value) - { - Filter* filter = (Filter*)hFilter; - OIDN_TRY - checkHandle(hFilter); - OIDN_LOCK(filter); - filter->set1i(name, int(value)); - OIDN_CATCH(filter) - } - - OIDN_API bool oidnGetFilter1b(OIDNFilter hFilter, const char* name) - { - Filter* filter = (Filter*)hFilter; - OIDN_TRY - checkHandle(hFilter); - OIDN_LOCK(filter); - return filter->get1i(name); - OIDN_CATCH(filter) - return false; - } - - OIDN_API void oidnSetFilter1i(OIDNFilter hFilter, const char* name, int value) - { - Filter* filter = (Filter*)hFilter; - OIDN_TRY - checkHandle(hFilter); - OIDN_LOCK(filter); - filter->set1i(name, value); - OIDN_CATCH(filter) - } - - OIDN_API int oidnGetFilter1i(OIDNFilter hFilter, const char* name) - { - Filter* filter = (Filter*)hFilter; - OIDN_TRY - checkHandle(hFilter); - OIDN_LOCK(filter); - return filter->get1i(name); - OIDN_CATCH(filter) - return 0; - } - - OIDN_API void oidnSetFilter1f(OIDNFilter hFilter, const char* name, float value) - { - Filter* filter = (Filter*)hFilter; - OIDN_TRY - checkHandle(hFilter); - OIDN_LOCK(filter); - filter->set1f(name, value); - OIDN_CATCH(filter) - } - - OIDN_API float oidnGetFilter1f(OIDNFilter hFilter, const char* name) - { - Filter* filter = (Filter*)hFilter; - OIDN_TRY - checkHandle(hFilter); - OIDN_LOCK(filter); - return filter->get1f(name); - OIDN_CATCH(filter) - return 0; - } - - OIDN_API void oidnSetFilterProgressMonitorFunction(OIDNFilter hFilter, OIDNProgressMonitorFunction func, void* userPtr) - { - Filter* filter = (Filter*)hFilter; - OIDN_TRY - checkHandle(hFilter); - OIDN_LOCK(filter); - filter->setProgressMonitorFunction(func, userPtr); - OIDN_CATCH(filter) - } - - OIDN_API void oidnCommitFilter(OIDNFilter hFilter) - { - Filter* filter = (Filter*)hFilter; - OIDN_TRY - checkHandle(hFilter); - OIDN_LOCK(filter); - filter->commit(); - OIDN_CATCH(filter) - } - - OIDN_API void oidnExecuteFilter(OIDNFilter hFilter) - { - Filter* filter = (Filter*)hFilter; - OIDN_TRY - checkHandle(hFilter); - OIDN_LOCK(filter); - filter->execute(); - OIDN_CATCH(filter) - } - -} // namespace oidn diff --git a/thirdparty/oidn/core/autoencoder.cpp b/thirdparty/oidn/core/autoencoder.cpp deleted file mode 100644 index d8da684cb..000000000 --- a/thirdparty/oidn/core/autoencoder.cpp +++ /dev/null @@ -1,535 +0,0 @@ -// ======================================================================== // -// Copyright 2009-2019 Intel Corporation // -// // -// Licensed under the Apache License, Version 2.0 (the "License"); // -// you may not use this file except in compliance with the License. // -// You may obtain a copy of the License at // -// // -// http://www.apache.org/licenses/LICENSE-2.0 // -// // -// Unless required by applicable law or agreed to in writing, software // -// distributed under the License is distributed on an "AS IS" BASIS, // -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // -// See the License for the specific language governing permissions and // -// limitations under the License. 
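The entry points above wrap exception-throwing C++ objects behind a C API with per-device locking; a client is expected to create and commit a device, configure and commit a filter, then execute it and poll for errors. A minimal sketch of that sequence for the lightmap filter (hypothetical in-place buffer, assuming the public OpenImageDenoise/oidn.h header):

    #include <OpenImageDenoise/oidn.h>

    void denoiseLightmap(float* inout, size_t width, size_t height)
    {
      OIDNDevice device = oidnNewDevice(OIDN_DEVICE_TYPE_CPU);
      oidnCommitDevice(device);                      // a device can be committed only once

      OIDNFilter filter = oidnNewFilter(device, "RTLightmap");
      // In-place denoise: color and output share the same float3 buffer.
      oidnSetSharedFilterImage(filter, "color",  inout, OIDN_FORMAT_FLOAT3, width, height, 0, 0, 0);
      oidnSetSharedFilterImage(filter, "output", inout, OIDN_FORMAT_FLOAT3, width, height, 0, 0, 0);
      oidnCommitFilter(filter);                      // builds the network for the current images
      oidnExecuteFilter(filter);                     // runs the tiled denoise

      const char* msg = nullptr;
      if (oidnGetDeviceError(device, &msg) != OIDN_ERROR_NONE)
      {
        // report msg
      }

      oidnReleaseFilter(filter);
      oidnReleaseDevice(device);
    }

Committing the device before creating filters matters: Device::commit() may only run once, and newFilter()/newBuffer() check that the device has been committed.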
// -// ======================================================================== // - -#include "autoencoder.h" - -namespace oidn { - - // -------------------------------------------------------------------------- - // AutoencoderFilter - // -------------------------------------------------------------------------- - - AutoencoderFilter::AutoencoderFilter(const Ref& device) - : Filter(device) - { - } - - void AutoencoderFilter::setImage(const std::string& name, const Image& data) - { - if (name == "color") - color = data; - else if (name == "albedo") - albedo = data; - else if (name == "normal") - normal = data; - else if (name == "output") - output = data; - - dirty = true; - } - - void AutoencoderFilter::set1i(const std::string& name, int value) - { - if (name == "hdr") - hdr = value; - else if (name == "srgb") - srgb = value; - else if (name == "maxMemoryMB") - maxMemoryMB = value; - - dirty = true; - } - - int AutoencoderFilter::get1i(const std::string& name) - { - if (name == "hdr") - return hdr; - else if (name == "srgb") - return srgb; - else if (name == "maxMemoryMB") - return maxMemoryMB; - else if (name == "alignment") - return alignment; - else if (name == "overlap") - return overlap; - else - throw Exception(Error::InvalidArgument, "invalid parameter"); - } - - void AutoencoderFilter::set1f(const std::string& name, float value) - { - if (name == "hdrScale") - hdrScale = value; - - dirty = true; - } - - float AutoencoderFilter::get1f(const std::string& name) - { - if (name == "hdrScale") - return hdrScale; - else - throw Exception(Error::InvalidArgument, "invalid parameter"); - } - - void AutoencoderFilter::commit() - { - if (!dirty) - return; - - // -- GODOT start -- - //device->executeTask([&]() - //{ - // GODOT end -- - - if (mayiuse(avx512_common)) - net = buildNet<16>(); - else - net = buildNet<8>(); - - // GODOT start -- - //}); - // GODOT end -- - - dirty = false; - } - - void AutoencoderFilter::execute() - { - if (dirty) - throw Exception(Error::InvalidOperation, "changes to the filter are not committed"); - - if (!net) - return; - // -- GODOT start -- - //device->executeTask([&]() - //{ - // -- GODOT end -- - Progress progress; - progress.func = progressFunc; - progress.userPtr = progressUserPtr; - progress.taskCount = tileCountH * tileCountW; - - // Iterate over the tiles - int tileIndex = 0; - - for (int i = 0; i < tileCountH; ++i) - { - const int h = i * (tileH - 2*overlap); // input tile position (including overlap) - const int overlapBeginH = i > 0 ? overlap : 0; // overlap on the top - const int overlapEndH = i < tileCountH-1 ? overlap : 0; // overlap on the bottom - const int tileH1 = min(H - h, tileH); // input tile size (including overlap) - const int tileH2 = tileH1 - overlapBeginH - overlapEndH; // output tile size - const int alignOffsetH = tileH - roundUp(tileH1, alignment); // align to the bottom in the tile buffer - - for (int j = 0; j < tileCountW; ++j) - { - const int w = j * (tileW - 2*overlap); // input tile position (including overlap) - const int overlapBeginW = j > 0 ? overlap : 0; // overlap on the left - const int overlapEndW = j < tileCountW-1 ? 
overlap : 0; // overlap on the right - const int tileW1 = min(W - w, tileW); // input tile size (including overlap) - const int tileW2 = tileW1 - overlapBeginW - overlapEndW; // output tile size - const int alignOffsetW = tileW - roundUp(tileW1, alignment); // align to the right in the tile buffer - - // Set the input tile - inputReorder->setTile(h, w, - alignOffsetH, alignOffsetW, - tileH1, tileW1); - - // Set the output tile - outputReorder->setTile(alignOffsetH + overlapBeginH, alignOffsetW + overlapBeginW, - h + overlapBeginH, w + overlapBeginW, - tileH2, tileW2); - - //printf("Tile: %d %d -> %d %d\n", w+overlapBeginW, h+overlapBeginH, w+overlapBeginW+tileW2, h+overlapBeginH+tileH2); - - // Denoise the tile - net->execute(progress, tileIndex); - - // Next tile - tileIndex++; - } - } - // -- GODOT start -- - //}); - // -- GODOT end -- - } - - void AutoencoderFilter::computeTileSize() - { - const int minTileSize = 3*overlap; - const int estimatedBytesPerPixel = mayiuse(avx512_common) ? estimatedBytesPerPixel16 : estimatedBytesPerPixel8; - const int64_t maxTilePixels = (int64_t(maxMemoryMB)*1024*1024 - estimatedBytesBase) / estimatedBytesPerPixel; - - tileCountH = 1; - tileCountW = 1; - tileH = roundUp(H, alignment); - tileW = roundUp(W, alignment); - - // Divide the image into tiles until the tile size gets below the threshold - while (int64_t(tileH) * tileW > maxTilePixels) - { - if (tileH > minTileSize && tileH > tileW) - { - tileCountH++; - tileH = max(roundUp(ceilDiv(H - 2*overlap, tileCountH), alignment) + 2*overlap, minTileSize); - } - else if (tileW > minTileSize) - { - tileCountW++; - tileW = max(roundUp(ceilDiv(W - 2*overlap, tileCountW), alignment) + 2*overlap, minTileSize); - } - else - break; - } - - // Compute the final number of tiles - tileCountH = (H > tileH) ? ceilDiv(H - 2*overlap, tileH - 2*overlap) : 1; - tileCountW = (W > tileW) ? ceilDiv(W - 2*overlap, tileW - 2*overlap) : 1; - - if (device->isVerbose(2)) - { - std::cout << "Tile size : " << tileW << "x" << tileH << std::endl; - std::cout << "Tile count: " << tileCountW << "x" << tileCountH << std::endl; - } - } - - template - std::shared_ptr AutoencoderFilter::buildNet() - { - H = color.height; - W = color.width; - - // Configure the network - int inputC; - void* weightPtr; - - if (srgb && hdr) - throw Exception(Error::InvalidOperation, "srgb and hdr modes cannot be enabled at the same time"); - - if (color && !albedo && !normal && weightData.hdr) - { - inputC = 3; - weightPtr = hdr ? weightData.hdr : weightData.ldr; - } - else if (color && albedo && !normal && weightData.hdr_alb) - { - inputC = 6; - weightPtr = hdr ? weightData.hdr_alb : weightData.ldr_alb; - } - else if (color && albedo && normal && weightData.hdr_alb_nrm) - { - inputC = 9; - weightPtr = hdr ? 
weightData.hdr_alb_nrm : weightData.ldr_alb_nrm; - } - else - { - throw Exception(Error::InvalidOperation, "unsupported combination of input features"); - } - - if (!output) - throw Exception(Error::InvalidOperation, "output image not specified"); - - if ((color.format != Format::Float3) - || (albedo && albedo.format != Format::Float3) - || (normal && normal.format != Format::Float3) - || (output.format != Format::Float3)) - throw Exception(Error::InvalidOperation, "unsupported image format"); - - if ((albedo && (albedo.width != W || albedo.height != H)) - || (normal && (normal.width != W || normal.height != H)) - || (output.width != W || output.height != H)) - throw Exception(Error::InvalidOperation, "image size mismatch"); - - // Compute the tile size - computeTileSize(); - - // If the image size is zero, there is nothing else to do - if (H <= 0 || W <= 0) - return nullptr; - - // Parse the weights - const auto weightMap = parseTensors(weightPtr); - - // Create the network - std::shared_ptr> net = std::make_shared>(device, weightMap); - - // Compute the tensor sizes - const auto inputDims = memory::dims({1, inputC, tileH, tileW}); - const auto inputReorderDims = net->getInputReorderDims(inputDims, alignment); //-> concat0 - - const auto conv1Dims = net->getConvDims("conv1", inputReorderDims); //-> temp0 - const auto conv1bDims = net->getConvDims("conv1b", conv1Dims); //-> temp1 - const auto pool1Dims = net->getPoolDims(conv1bDims); //-> concat1 - const auto conv2Dims = net->getConvDims("conv2", pool1Dims); //-> temp0 - const auto pool2Dims = net->getPoolDims(conv2Dims); //-> concat2 - const auto conv3Dims = net->getConvDims("conv3", pool2Dims); //-> temp0 - const auto pool3Dims = net->getPoolDims(conv3Dims); //-> concat3 - const auto conv4Dims = net->getConvDims("conv4", pool3Dims); //-> temp0 - const auto pool4Dims = net->getPoolDims(conv4Dims); //-> concat4 - const auto conv5Dims = net->getConvDims("conv5", pool4Dims); //-> temp0 - const auto pool5Dims = net->getPoolDims(conv5Dims); //-> temp1 - const auto upsample4Dims = net->getUpsampleDims(pool5Dims); //-> concat4 - const auto concat4Dims = net->getConcatDims(upsample4Dims, pool4Dims); - const auto conv6Dims = net->getConvDims("conv6", concat4Dims); //-> temp0 - const auto conv6bDims = net->getConvDims("conv6b", conv6Dims); //-> temp1 - const auto upsample3Dims = net->getUpsampleDims(conv6bDims); //-> concat3 - const auto concat3Dims = net->getConcatDims(upsample3Dims, pool3Dims); - const auto conv7Dims = net->getConvDims("conv7", concat3Dims); //-> temp0 - const auto conv7bDims = net->getConvDims("conv7b", conv7Dims); //-> temp1 - const auto upsample2Dims = net->getUpsampleDims(conv7bDims); //-> concat2 - const auto concat2Dims = net->getConcatDims(upsample2Dims, pool2Dims); - const auto conv8Dims = net->getConvDims("conv8", concat2Dims); //-> temp0 - const auto conv8bDims = net->getConvDims("conv8b", conv8Dims); //-> temp1 - const auto upsample1Dims = net->getUpsampleDims(conv8bDims); //-> concat1 - const auto concat1Dims = net->getConcatDims(upsample1Dims, pool1Dims); - const auto conv9Dims = net->getConvDims("conv9", concat1Dims); //-> temp0 - const auto conv9bDims = net->getConvDims("conv9b", conv9Dims); //-> temp1 - const auto upsample0Dims = net->getUpsampleDims(conv9bDims); //-> concat0 - const auto concat0Dims = net->getConcatDims(upsample0Dims, inputReorderDims); - const auto conv10Dims = net->getConvDims("conv10", concat0Dims); //-> temp0 - const auto conv10bDims = net->getConvDims("conv10b", conv10Dims); //-> temp1 - 
const auto conv11Dims = net->getConvDims("conv11", conv10bDims); //-> temp0 - - const auto outputDims = memory::dims({1, 3, tileH, tileW}); - - // Allocate two temporary ping-pong buffers to decrease memory usage - const auto temp0Dims = getMaxTensorDims({ - conv1Dims, - conv2Dims, - conv3Dims, - conv4Dims, - conv5Dims, - conv6Dims, - conv7Dims, - conv8Dims, - conv9Dims, - conv10Dims, - conv11Dims - }); - - const auto temp1Dims = getMaxTensorDims({ - conv1bDims, - pool5Dims, - conv6bDims, - conv7bDims, - conv8bDims, - conv9bDims, - conv10bDims, - }); - - auto temp0 = net->allocTensor(temp0Dims); - auto temp1 = net->allocTensor(temp1Dims); - - // Allocate enough memory to hold the concat outputs. Then use the first - // half to hold the previous conv output and the second half to hold the - // pool/orig image output. This works because everything is C dimension - // outermost, padded to K floats, and all the concats are on the C dimension. - auto concat0Dst = net->allocTensor(concat0Dims); - auto concat1Dst = net->allocTensor(concat1Dims); - auto concat2Dst = net->allocTensor(concat2Dims); - auto concat3Dst = net->allocTensor(concat3Dims); - auto concat4Dst = net->allocTensor(concat4Dims); - - // Transfer function - std::shared_ptr transferFunc = makeTransferFunc(); - - // Autoexposure - if (auto tf = std::dynamic_pointer_cast(transferFunc)) - { - if (isnan(hdrScale)) - net->addAutoexposure(color, tf); - else - tf->setExposure(hdrScale); - } - - // Input reorder - auto inputReorderDst = net->castTensor(inputReorderDims, concat0Dst, upsample0Dims); - inputReorder = net->addInputReorder(color, albedo, normal, - transferFunc, - alignment, inputReorderDst); - - // conv1 - auto conv1 = net->addConv("conv1", inputReorder->getDst(), temp0); - - // conv1b - auto conv1b = net->addConv("conv1b", conv1->getDst(), temp1); - - // pool1 - // Adjust pointer for pool1 to eliminate concat1 - auto pool1Dst = net->castTensor(pool1Dims, concat1Dst, upsample1Dims); - auto pool1 = net->addPool(conv1b->getDst(), pool1Dst); - - // conv2 - auto conv2 = net->addConv("conv2", pool1->getDst(), temp0); - - // pool2 - // Adjust pointer for pool2 to eliminate concat2 - auto pool2Dst = net->castTensor(pool2Dims, concat2Dst, upsample2Dims); - auto pool2 = net->addPool(conv2->getDst(), pool2Dst); - - // conv3 - auto conv3 = net->addConv("conv3", pool2->getDst(), temp0); - - // pool3 - // Adjust pointer for pool3 to eliminate concat3 - auto pool3Dst = net->castTensor(pool3Dims, concat3Dst, upsample3Dims); - auto pool3 = net->addPool(conv3->getDst(), pool3Dst); - - // conv4 - auto conv4 = net->addConv("conv4", pool3->getDst(), temp0); - - // pool4 - // Adjust pointer for pool4 to eliminate concat4 - auto pool4Dst = net->castTensor(pool4Dims, concat4Dst, upsample4Dims); - auto pool4 = net->addPool(conv4->getDst(), pool4Dst); - - // conv5 - auto conv5 = net->addConv("conv5", pool4->getDst(), temp0); - - // pool5 - auto pool5 = net->addPool(conv5->getDst(), temp1); - - // upsample4 - auto upsample4Dst = net->castTensor(upsample4Dims, concat4Dst); - auto upsample4 = net->addUpsample(pool5->getDst(), upsample4Dst); - - // conv6 - auto conv6 = net->addConv("conv6", concat4Dst, temp0); - - // conv6b - auto conv6b = net->addConv("conv6b", conv6->getDst(), temp1); - - // upsample3 - auto upsample3Dst = net->castTensor(upsample3Dims, concat3Dst); - auto upsample3 = net->addUpsample(conv6b->getDst(), upsample3Dst); - - // conv7 - auto conv7 = net->addConv("conv7", concat3Dst, temp0); - - // conv7b - auto conv7b = net->addConv("conv7b", 
conv7->getDst(), temp1); - - // upsample2 - auto upsample2Dst = net->castTensor(upsample2Dims, concat2Dst); - auto upsample2 = net->addUpsample(conv7b->getDst(), upsample2Dst); - - // conv8 - auto conv8 = net->addConv("conv8", concat2Dst, temp0); - - // conv8b - auto conv8b = net->addConv("conv8b", conv8->getDst(), temp1); - - // upsample1 - auto upsample1Dst = net->castTensor(upsample1Dims, concat1Dst); - auto upsample1 = net->addUpsample(conv8b->getDst(), upsample1Dst); - - // conv9 - auto conv9 = net->addConv("conv9", concat1Dst, temp0); - - // conv9b - auto conv9b = net->addConv("conv9b", conv9->getDst(), temp1); - - // upsample0 - auto upsample0Dst = net->castTensor(upsample0Dims, concat0Dst); - auto upsample0 = net->addUpsample(conv9b->getDst(), upsample0Dst); - - // conv10 - auto conv10 = net->addConv("conv10", concat0Dst, temp0); - - // conv10b - auto conv10b = net->addConv("conv10b", conv10->getDst(), temp1); - - // conv11 - auto conv11 = net->addConv("conv11", conv10b->getDst(), temp0, false /* no relu */); - - // Output reorder - outputReorder = net->addOutputReorder(conv11->getDst(), transferFunc, output); - - net->finalize(); - return net; - } - - std::shared_ptr AutoencoderFilter::makeTransferFunc() - { - if (hdr) - return std::make_shared(); - else if (srgb) - return std::make_shared(); - else - return std::make_shared(); - } - -// -- GODOT start -- -// Godot doesn't need Raytracing filters. Removing them saves space in the weights files. -#if 0 -// -- GODOT end -- - - // -------------------------------------------------------------------------- - // RTFilter - // -------------------------------------------------------------------------- - - namespace weights - { - // LDR - extern unsigned char rt_ldr[]; // color - extern unsigned char rt_ldr_alb[]; // color, albedo - extern unsigned char rt_ldr_alb_nrm[]; // color, albedo, normal - - // HDR - extern unsigned char rt_hdr[]; // color - extern unsigned char rt_hdr_alb[]; // color, albedo - extern unsigned char rt_hdr_alb_nrm[]; // color, albedo, normal - } - - RTFilter::RTFilter(const Ref& device) - : AutoencoderFilter(device) - { - weightData.ldr = weights::rt_ldr; - weightData.ldr_alb = weights::rt_ldr_alb; - weightData.ldr_alb_nrm = weights::rt_ldr_alb_nrm; - weightData.hdr = weights::rt_hdr; - weightData.hdr_alb = weights::rt_hdr_alb; - weightData.hdr_alb_nrm = weights::rt_hdr_alb_nrm; - } -// -- GODOT start -- -#endif -// -- GODOT end -- - - // -------------------------------------------------------------------------- - // RTLightmapFilter - // -------------------------------------------------------------------------- - - namespace weights - { - // HDR - extern unsigned char rtlightmap_hdr[]; // color - } - - RTLightmapFilter::RTLightmapFilter(const Ref& device) - : AutoencoderFilter(device) - { - weightData.hdr = weights::rtlightmap_hdr; - - hdr = true; - } - - std::shared_ptr RTLightmapFilter::makeTransferFunc() - { - return std::make_shared(); - } - -} // namespace oidn diff --git a/thirdparty/oidn/core/autoencoder.h b/thirdparty/oidn/core/autoencoder.h deleted file mode 100644 index 98b610844..000000000 --- a/thirdparty/oidn/core/autoencoder.h +++ /dev/null @@ -1,120 +0,0 @@ -// ======================================================================== // -// Copyright 2009-2019 Intel Corporation // -// // -// Licensed under the Apache License, Version 2.0 (the "License"); // -// you may not use this file except in compliance with the License. 
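computeTileSize() above splits the image so that each tile, including the 2*overlap border shared with its neighbours, fits into the memory budget; the overlap itself is derived from the network's receptive field and the spatial alignment declared in autoencoder.h below. A standalone sketch of that arithmetic (hypothetical image and tile sizes):

    #include <cstdio>

    static int roundUp(int a, int b)  { return ((a + b - 1) / b) * b; }
    static int ceilDiv(int a, int b)  { return (a + b - 1) / b; }

    int main()
    {
      const int alignment      = 32;
      const int receptiveField = 222;
      const int overlap = roundUp(receptiveField / 2, alignment); // border shared between tiles

      // Example: a 4096-pixel-high image limited to 2048-pixel-high tiles.
      const int H = 4096, tileH = 2048;
      const int tileCountH = (H > tileH) ? ceilDiv(H - 2 * overlap, tileH - 2 * overlap) : 1;
      std::printf("overlap=%d tileCountH=%d\n", overlap, tileCountH);
    }

With alignment = 32 and receptiveField = 222 the overlap works out to 128 pixels per shared border.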
// -// You may obtain a copy of the License at // -// // -// http://www.apache.org/licenses/LICENSE-2.0 // -// // -// Unless required by applicable law or agreed to in writing, software // -// distributed under the License is distributed on an "AS IS" BASIS, // -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // -// See the License for the specific language governing permissions and // -// limitations under the License. // -// ======================================================================== // - -#pragma once - -#include "filter.h" -#include "network.h" -#include "transfer_function.h" - -namespace oidn { - - // -------------------------------------------------------------------------- - // AutoencoderFilter - Direct-predicting autoencoder - // -------------------------------------------------------------------------- - - class AutoencoderFilter : public Filter - { - protected: - static constexpr int alignment = 32; // required spatial alignment in pixels (padding may be necessary) - static constexpr int receptiveField = 222; // receptive field in pixels - static constexpr int overlap = roundUp(receptiveField / 2, alignment); // required spatial overlap between tiles in pixels - - static constexpr int estimatedBytesBase = 16*1024*1024; // estimated base memory usage - static constexpr int estimatedBytesPerPixel8 = 889; // estimated memory usage per pixel for K=8 - static constexpr int estimatedBytesPerPixel16 = 2185; // estimated memory usage per pixel for K=16 - - Image color; - Image albedo; - Image normal; - Image output; - bool hdr = false; - float hdrScale = std::numeric_limits::quiet_NaN(); - bool srgb = false; - int maxMemoryMB = 6000; // approximate maximum memory usage in MBs - - int H = 0; // image height - int W = 0; // image width - int tileH = 0; // tile height - int tileW = 0; // tile width - int tileCountH = 1; // number of tiles in H dimension - int tileCountW = 1; // number of tiles in W dimension - - std::shared_ptr net; - std::shared_ptr inputReorder; - std::shared_ptr outputReorder; - - struct - { - void* ldr = nullptr; - void* ldr_alb = nullptr; - void* ldr_alb_nrm = nullptr; - void* hdr = nullptr; - void* hdr_alb = nullptr; - void* hdr_alb_nrm = nullptr; - } weightData; - - explicit AutoencoderFilter(const Ref& device); - virtual std::shared_ptr makeTransferFunc(); - - public: - void setImage(const std::string& name, const Image& data) override; - void set1i(const std::string& name, int value) override; - int get1i(const std::string& name) override; - void set1f(const std::string& name, float value) override; - float get1f(const std::string& name) override; - - void commit() override; - void execute() override; - - private: - void computeTileSize(); - - template - std::shared_ptr buildNet(); - - bool isCommitted() const { return bool(net); } - }; - - // -------------------------------------------------------------------------- - // RTFilter - Generic ray tracing denoiser - // -------------------------------------------------------------------------- - -// -- GODOT start -- -// Godot doesn't need Raytracing filters. Removing them saves space in the weights files. 
-#if 0 -// -- GODOT end -- - class RTFilter : public AutoencoderFilter - { - public: - explicit RTFilter(const Ref& device); - }; -// -- GODOT start -- -#endif -// -- GODOT end -- - - // -------------------------------------------------------------------------- - // RTLightmapFilter - Ray traced lightmap denoiser - // -------------------------------------------------------------------------- - - class RTLightmapFilter : public AutoencoderFilter - { - public: - explicit RTLightmapFilter(const Ref& device); - std::shared_ptr makeTransferFunc() override; - }; - -} // namespace oidn diff --git a/thirdparty/oidn/core/buffer.h b/thirdparty/oidn/core/buffer.h deleted file mode 100644 index b95109152..000000000 --- a/thirdparty/oidn/core/buffer.h +++ /dev/null @@ -1,75 +0,0 @@ -// ======================================================================== // -// Copyright 2009-2019 Intel Corporation // -// // -// Licensed under the Apache License, Version 2.0 (the "License"); // -// you may not use this file except in compliance with the License. // -// You may obtain a copy of the License at // -// // -// http://www.apache.org/licenses/LICENSE-2.0 // -// // -// Unless required by applicable law or agreed to in writing, software // -// distributed under the License is distributed on an "AS IS" BASIS, // -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // -// See the License for the specific language governing permissions and // -// limitations under the License. // -// ======================================================================== // - -#pragma once - -#include "common.h" -#include "device.h" - -namespace oidn { - - class Device; - - // Buffer which may or may not own its data - class Buffer : public RefCount - { - private: - char* ptr; - size_t byteSize; - bool shared; - Ref device; - - public: - __forceinline Buffer(const Ref& device, size_t size) - : ptr((char*)alignedMalloc(size, 64)), - byteSize(size), - shared(false), - device(device) {} - - __forceinline Buffer(const Ref& device, void* data, size_t size) - : ptr((char*)data), - byteSize(size), - shared(true), - device(device) - { - if (data == nullptr) - throw Exception(Error::InvalidArgument, "buffer pointer null"); - } - - __forceinline ~Buffer() - { - if (!shared) - alignedFree(ptr); - } - - __forceinline char* data() { return ptr; } - __forceinline const char* data() const { return ptr; } - __forceinline size_t size() const { return byteSize; } - - void* map(size_t offset, size_t size) - { - if (offset + size > byteSize) - throw Exception(Error::InvalidArgument, "buffer region out of range"); - - return ptr + offset; - } - - void unmap(void* mappedPtr) {} - - Device* getDevice() { return device.get(); } - }; - -} // namespace oidn diff --git a/thirdparty/oidn/core/common.h b/thirdparty/oidn/core/common.h deleted file mode 100644 index a35dd908b..000000000 --- a/thirdparty/oidn/core/common.h +++ /dev/null @@ -1,136 +0,0 @@ -// ======================================================================== // -// Copyright 2009-2019 Intel Corporation // -// // -// Licensed under the Apache License, Version 2.0 (the "License"); // -// you may not use this file except in compliance with the License. 
// -// You may obtain a copy of the License at // -// // -// http://www.apache.org/licenses/LICENSE-2.0 // -// // -// Unless required by applicable law or agreed to in writing, software // -// distributed under the License is distributed on an "AS IS" BASIS, // -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // -// See the License for the specific language governing permissions and // -// limitations under the License. // -// ======================================================================== // - -#pragma once - -#include "common/platform.h" - -#include "mkl-dnn/include/mkldnn.hpp" -#include "mkl-dnn/include/mkldnn_debug.h" -#include "mkl-dnn/src/common/mkldnn_thread.hpp" -#include "mkl-dnn/src/common/type_helpers.hpp" -#include "mkl-dnn/src/cpu/jit_generator.hpp" - -#include "common/ref.h" -#include "common/exception.h" -#include "common/thread.h" -// -- GODOT start -- -//#include "common/tasking.h" -// -- GODOT end -- -#include "math.h" - -namespace oidn { - - using namespace mkldnn; - using namespace mkldnn::impl::cpu; - using mkldnn::impl::parallel_nd; - using mkldnn::impl::memory_desc_matches_tag; - - - inline size_t getFormatBytes(Format format) - { - switch (format) - { - case Format::Undefined: return 1; - case Format::Float: return sizeof(float); - case Format::Float2: return sizeof(float)*2; - case Format::Float3: return sizeof(float)*3; - case Format::Float4: return sizeof(float)*4; - } - assert(0); - return 0; - } - - - inline memory::dims getTensorDims(const std::shared_ptr& mem) - { - const mkldnn_memory_desc_t& desc = mem->get_desc().data; - return memory::dims(&desc.dims[0], &desc.dims[desc.ndims]); - } - - inline memory::data_type getTensorType(const std::shared_ptr& mem) - { - const mkldnn_memory_desc_t& desc = mem->get_desc().data; - return memory::data_type(desc.data_type); - } - - // Returns the number of values in a tensor - inline size_t getTensorSize(const memory::dims& dims) - { - size_t res = 1; - for (int i = 0; i < (int)dims.size(); ++i) - res *= dims[i]; - return res; - } - - inline memory::dims getMaxTensorDims(const std::vector& dims) - { - memory::dims result; - size_t maxSize = 0; - - for (const auto& d : dims) - { - const size_t size = getTensorSize(d); - if (size > maxSize) - { - result = d; - maxSize = size; - } - } - - return result; - } - - inline size_t getTensorSize(const std::shared_ptr& mem) - { - return getTensorSize(getTensorDims(mem)); - } - - - template - inline int getPadded(int dim) - { - return (dim + (K-1)) & ~(K-1); - } - - template - inline memory::dims getPadded_nchw(const memory::dims& dims) - { - assert(dims.size() == 4); - memory::dims padDims = dims; - padDims[1] = getPadded(dims[1]); // pad C - return padDims; - } - - - template - struct BlockedFormat; - - template<> - struct BlockedFormat<8> - { - static constexpr memory::format_tag nChwKc = memory::format_tag::nChw8c; - static constexpr memory::format_tag OIhwKiKo = memory::format_tag::OIhw8i8o; - }; - - template<> - struct BlockedFormat<16> - { - static constexpr memory::format_tag nChwKc = memory::format_tag::nChw16c; - static constexpr memory::format_tag OIhwKiKo = memory::format_tag::OIhw16i16o; - }; - -} // namespace oidn diff --git a/thirdparty/oidn/core/device.cpp b/thirdparty/oidn/core/device.cpp deleted file mode 100644 index 3cd658b9c..000000000 --- a/thirdparty/oidn/core/device.cpp +++ /dev/null @@ -1,238 +0,0 @@ -// ======================================================================== // -// Copyright 2009-2019 Intel Corporation // 
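The blocked nChw8c/nChw16c formats selected above require the channel dimension padded up to a multiple of K, which getPadded() does with a bit mask; the trick only works because K is a power of two. A quick standalone check (hypothetical demo, not from the deleted sources):

    // Round a dimension up to the next multiple of K (K must be a power of two).
    template<int K>
    constexpr int getPadded(int dim) { return (dim + (K - 1)) & ~(K - 1); }

    int main()
    {
      static_assert(getPadded<8>(3)   == 8,  "3 channels pad up to one block of 8");
      static_assert(getPadded<16>(9)  == 16, "9 channels pad up to one block of 16");
      static_assert(getPadded<16>(32) == 32, "aligned values are left unchanged");
      return 0;
    }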
-// // -// Licensed under the Apache License, Version 2.0 (the "License"); // -// you may not use this file except in compliance with the License. // -// You may obtain a copy of the License at // -// // -// http://www.apache.org/licenses/LICENSE-2.0 // -// // -// Unless required by applicable law or agreed to in writing, software // -// distributed under the License is distributed on an "AS IS" BASIS, // -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // -// See the License for the specific language governing permissions and // -// limitations under the License. // -// ======================================================================== // - -#include "device.h" -#include "autoencoder.h" - -namespace oidn { - - thread_local Device::ErrorState Device::globalError; - - Device::Device() - { - if (!mayiuse(sse41)) - throw Exception(Error::UnsupportedHardware, "SSE4.1 support is required at minimum"); - } - - Device::~Device() - { - // -- GODOT start -- - //observer.reset(); - // -- GODOT end -- - } - - void Device::setError(Device* device, Error code, const std::string& message) - { - // Update the stored error only if the previous error was queried - if (device) - { - ErrorState& curError = device->error.get(); - - if (curError.code == Error::None) - { - curError.code = code; - curError.message = message; - } - - // Print the error message in verbose mode - if (device->isVerbose()) - std::cerr << "Error: " << message << std::endl; - - // Call the error callback function - ErrorFunction errorFunc; - void* errorUserPtr; - - { - std::lock_guard lock(device->mutex); - errorFunc = device->errorFunc; - errorUserPtr = device->errorUserPtr; - } - - if (errorFunc) - errorFunc(errorUserPtr, code, (code == Error::None) ? nullptr : message.c_str()); - } - else - { - if (globalError.code == Error::None) - { - globalError.code = code; - globalError.message = message; - } - } - } - - Error Device::getError(Device* device, const char** outMessage) - { - // Return and clear the stored error code, but keep the error message so pointers to it will - // remain valid until the next getError call - if (device) - { - ErrorState& curError = device->error.get(); - const Error code = curError.code; - if (outMessage) - *outMessage = (code == Error::None) ? nullptr : curError.message.c_str(); - curError.code = Error::None; - return code; - } - else - { - const Error code = globalError.code; - if (outMessage) - *outMessage = (code == Error::None) ? 
nullptr : globalError.message.c_str(); - globalError.code = Error::None; - return code; - } - } - - void Device::setErrorFunction(ErrorFunction func, void* userPtr) - { - errorFunc = func; - errorUserPtr = userPtr; - } - - int Device::get1i(const std::string& name) - { - if (name == "numThreads") - return numThreads; - else if (name == "setAffinity") - return setAffinity; - else if (name == "verbose") - return verbose; - else if (name == "version") - return OIDN_VERSION; - else if (name == "versionMajor") - return OIDN_VERSION_MAJOR; - else if (name == "versionMinor") - return OIDN_VERSION_MINOR; - else if (name == "versionPatch") - return OIDN_VERSION_PATCH; - else - throw Exception(Error::InvalidArgument, "invalid parameter"); - } - - void Device::set1i(const std::string& name, int value) - { - if (name == "numThreads") - numThreads = value; - else if (name == "setAffinity") - setAffinity = value; - else if (name == "verbose") - { - verbose = value; - error.verbose = value; - } - - dirty = true; - } - - void Device::commit() - { - if (isCommitted()) - throw Exception(Error::InvalidOperation, "device can be committed only once"); - - // -- GODOT start -- - #if 0 - // -- GODOT end -- - // Get the optimal thread affinities - if (setAffinity) - { - affinity = std::make_shared(1, verbose); // one thread per core - if (affinity->getNumThreads() == 0) - affinity.reset(); - } - - // Create the task arena - const int maxNumThreads = affinity ? affinity->getNumThreads() : tbb::this_task_arena::max_concurrency(); - numThreads = (numThreads > 0) ? min(numThreads, maxNumThreads) : maxNumThreads; - arena = std::make_shared(numThreads); - - // Automatically set the thread affinities - if (affinity) - observer = std::make_shared(affinity, *arena); - // -- GODOT start -- - #endif - numThreads = 1; - // -- GODOT end -- - dirty = false; - - if (isVerbose()) - print(); - } - - void Device::checkCommitted() - { - if (dirty) - throw Exception(Error::InvalidOperation, "changes to the device are not committed"); - } - - Ref Device::newBuffer(size_t byteSize) - { - checkCommitted(); - return makeRef(Ref(this), byteSize); - } - - Ref Device::newBuffer(void* ptr, size_t byteSize) - { - checkCommitted(); - return makeRef(Ref(this), ptr, byteSize); - } - - Ref Device::newFilter(const std::string& type) - { - checkCommitted(); - - if (isVerbose()) - std::cout << "Filter: " << type << std::endl; - - Ref filter; - -// -- GODOT start -- -// Godot doesn't need Raytracing filters. Removing them saves space in the weights files. -#if 0 -// -- GODOT end -- - if (type == "RT") - filter = makeRef(Ref(this)); -// -- GODOT start -- -// Godot doesn't need Raytracing filters. Removing them saves space in the weights files. -#endif - if (type == "RTLightmap") -// -- GODOT end -- - filter = makeRef(Ref(this)); - else - throw Exception(Error::InvalidArgument, "unknown filter type"); - - return filter; - } - - void Device::print() - { - std::cout << std::endl; - - std::cout << "Intel(R) Open Image Denoise " << OIDN_VERSION_STRING << std::endl; - std::cout << " Compiler: " << getCompilerName() << std::endl; - std::cout << " Build : " << getBuildName() << std::endl; - std::cout << " Platform: " << getPlatformName() << std::endl; - -// -- GODOT start -- -// std::cout << " Tasking :"; -// std::cout << " TBB" << TBB_VERSION_MAJOR << "." 
<< TBB_VERSION_MINOR; -// std::cout << " TBB_header_interface_" << TBB_INTERFACE_VERSION << " TBB_lib_interface_" << tbb::TBB_runtime_interface_version(); -// std::cout << std::endl; -// -- GODOT end -- - std::cout << std::endl; - } - -} // namespace oidn diff --git a/thirdparty/oidn/core/device.h b/thirdparty/oidn/core/device.h deleted file mode 100644 index d9cfd8541..000000000 --- a/thirdparty/oidn/core/device.h +++ /dev/null @@ -1,102 +0,0 @@ -// ======================================================================== // -// Copyright 2009-2019 Intel Corporation // -// // -// Licensed under the Apache License, Version 2.0 (the "License"); // -// you may not use this file except in compliance with the License. // -// You may obtain a copy of the License at // -// // -// http://www.apache.org/licenses/LICENSE-2.0 // -// // -// Unless required by applicable law or agreed to in writing, software // -// distributed under the License is distributed on an "AS IS" BASIS, // -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // -// See the License for the specific language governing permissions and // -// limitations under the License. // -// ======================================================================== // - -#pragma once - -#include "common.h" - -namespace oidn { - - class Buffer; - class Filter; - - class Device : public RefCount, public Verbose - { - private: - // Thread-safety - std::mutex mutex; - - // Error handling - struct ErrorState - { - Error code = Error::None; - std::string message; - }; - - static thread_local ErrorState globalError; - ThreadLocal error; - ErrorFunction errorFunc = nullptr; - void* errorUserPtr = nullptr; - -// -- GODOT start -- -// // Tasking -// std::shared_ptr arena; -// std::shared_ptr observer; -// std::shared_ptr affinity; -// -- GODOT end -- - - // Parameters - int numThreads = 0; // autodetect by default - bool setAffinity = true; - - bool dirty = true; - - public: - Device(); - ~Device(); - - static void setError(Device* device, Error code, const std::string& message); - static Error getError(Device* device, const char** outMessage); - - void setErrorFunction(ErrorFunction func, void* userPtr); - - int get1i(const std::string& name); - void set1i(const std::string& name, int value); - - void commit(); - -// -- GODOT start -- -// template -// void executeTask(F& f) -// { -// arena->execute(f); -// } - -// template -// void executeTask(const F& f) -// { -// arena->execute(f); -// } -// -- GODOT end -- - - Ref newBuffer(size_t byteSize); - Ref newBuffer(void* ptr, size_t byteSize); - Ref newFilter(const std::string& type); - - __forceinline Device* getDevice() { return this; } - __forceinline std::mutex& getMutex() { return mutex; } - - private: -// -- GODOT start -- - //bool isCommitted() const { return bool(arena); } - bool isCommitted() const { return false; } -// -- GODOT end -- - void checkCommitted(); - - void print(); - }; - -} // namespace oidn diff --git a/thirdparty/oidn/core/filter.cpp b/thirdparty/oidn/core/filter.cpp deleted file mode 100644 index ec1f10af8..000000000 --- a/thirdparty/oidn/core/filter.cpp +++ /dev/null @@ -1,27 +0,0 @@ -// ======================================================================== // -// Copyright 2009-2019 Intel Corporation // -// // -// Licensed under the Apache License, Version 2.0 (the "License"); // -// you may not use this file except in compliance with the License. 
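Device::setError()/getError() above keep a per-thread "sticky" error: only the first error since the last query is stored, and querying clears the code while leaving the message buffer valid. A condensed standalone sketch of the same pattern (hypothetical names, not the deleted implementation):

    #include <iostream>
    #include <string>

    enum class Error { None, Unknown, InvalidArgument };

    struct ErrorState { Error code = Error::None; std::string message; };
    thread_local ErrorState g_error;

    void setError(Error code, const std::string& message)
    {
      // Keep only the first error until it is queried.
      if (g_error.code == Error::None) { g_error.code = code; g_error.message = message; }
    }

    Error getError(const char** outMessage)
    {
      const Error code = g_error.code;
      if (outMessage)
        *outMessage = (code == Error::None) ? nullptr : g_error.message.c_str();
      g_error.code = Error::None; // clear the code, keep the message buffer alive
      return code;
    }

    int main()
    {
      setError(Error::InvalidArgument, "invalid handle");
      setError(Error::Unknown, "second error is ignored");
      const char* msg = nullptr;
      if (getError(&msg) != Error::None)
        std::cout << msg << '\n'; // prints "invalid handle"
    }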
// -// You may obtain a copy of the License at // -// // -// http://www.apache.org/licenses/LICENSE-2.0 // -// // -// Unless required by applicable law or agreed to in writing, software // -// distributed under the License is distributed on an "AS IS" BASIS, // -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // -// See the License for the specific language governing permissions and // -// limitations under the License. // -// ======================================================================== // - -#include "filter.h" - -namespace oidn { - - void Filter::setProgressMonitorFunction(ProgressMonitorFunction func, void* userPtr) - { - progressFunc = func; - progressUserPtr = userPtr; - } - -} // namespace oidn diff --git a/thirdparty/oidn/core/filter.h b/thirdparty/oidn/core/filter.h deleted file mode 100644 index 935fa202f..000000000 --- a/thirdparty/oidn/core/filter.h +++ /dev/null @@ -1,52 +0,0 @@ -// ======================================================================== // -// Copyright 2009-2019 Intel Corporation // -// // -// Licensed under the Apache License, Version 2.0 (the "License"); // -// you may not use this file except in compliance with the License. // -// You may obtain a copy of the License at // -// // -// http://www.apache.org/licenses/LICENSE-2.0 // -// // -// Unless required by applicable law or agreed to in writing, software // -// distributed under the License is distributed on an "AS IS" BASIS, // -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // -// See the License for the specific language governing permissions and // -// limitations under the License. // -// ======================================================================== // - -#pragma once - -#include "common.h" -#include "device.h" -#include "image.h" - -namespace oidn { - - class Filter : public RefCount - { - protected: - Ref device; - - ProgressMonitorFunction progressFunc = nullptr; - void* progressUserPtr = nullptr; - - bool dirty = true; - - public: - explicit Filter(const Ref& device) : device(device) {} - - virtual void setImage(const std::string& name, const Image& data) = 0; - virtual void set1i(const std::string& name, int value) = 0; - virtual int get1i(const std::string& name) = 0; - virtual void set1f(const std::string& name, float value) = 0; - virtual float get1f(const std::string& name) = 0; - - void setProgressMonitorFunction(ProgressMonitorFunction func, void* userPtr); - - virtual void commit() = 0; - virtual void execute() = 0; - - Device* getDevice() { return device.get(); } - }; - -} // namespace oidn diff --git a/thirdparty/oidn/core/image.h b/thirdparty/oidn/core/image.h deleted file mode 100644 index 748f49c4e..000000000 --- a/thirdparty/oidn/core/image.h +++ /dev/null @@ -1,111 +0,0 @@ -// ======================================================================== // -// Copyright 2009-2019 Intel Corporation // -// // -// Licensed under the Apache License, Version 2.0 (the "License"); // -// you may not use this file except in compliance with the License. // -// You may obtain a copy of the License at // -// // -// http://www.apache.org/licenses/LICENSE-2.0 // -// // -// Unless required by applicable law or agreed to in writing, software // -// distributed under the License is distributed on an "AS IS" BASIS, // -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // -// See the License for the specific language governing permissions and // -// limitations under the License. 
// -// ======================================================================== // - -#pragma once - -#include "common.h" -#include "buffer.h" - -namespace oidn { - - struct Image - { - static constexpr int maxSize = 65536; - - char* ptr; // pointer to the first pixel - int width; // width in number of pixels - int height; // height in number of pixels - size_t bytePixelStride; // pixel stride in number of *bytes* - size_t rowStride; // row stride in number of *pixel strides* - Format format; // pixel format - Ref buffer; // buffer containing the image data - - Image() : ptr(nullptr), width(0), height(0), bytePixelStride(0), rowStride(0), format(Format::Undefined) {} - - Image(void* ptr, Format format, int width, int height, size_t byteOffset, size_t inBytePixelStride, size_t inByteRowStride) - { - if (ptr == nullptr) - throw Exception(Error::InvalidArgument, "buffer pointer null"); - - init((char*)ptr + byteOffset, format, width, height, inBytePixelStride, inByteRowStride); - } - - Image(const Ref& buffer, Format format, int width, int height, size_t byteOffset, size_t inBytePixelStride, size_t inByteRowStride) - { - init(buffer->data() + byteOffset, format, width, height, inBytePixelStride, inByteRowStride); - - if (byteOffset + height * rowStride * bytePixelStride > buffer->size()) - throw Exception(Error::InvalidArgument, "buffer region out of range"); - } - - void init(char* ptr, Format format, int width, int height, size_t inBytePixelStride, size_t inByteRowStride) - { - assert(width >= 0); - assert(height >= 0); - if (width > maxSize || height > maxSize) - throw Exception(Error::InvalidArgument, "image size too large"); - - this->ptr = ptr; - this->width = width; - this->height = height; - - const size_t pixelSize = getFormatBytes(format); - if (inBytePixelStride != 0) - { - if (inBytePixelStride < pixelSize) - throw Exception(Error::InvalidArgument, "pixel stride smaller than pixel size"); - - this->bytePixelStride = inBytePixelStride; - } - else - { - this->bytePixelStride = pixelSize; - } - - if (inByteRowStride != 0) - { - if (inByteRowStride < width * this->bytePixelStride) - throw Exception(Error::InvalidArgument, "row stride smaller than width * pixel stride"); - if (inByteRowStride % this->bytePixelStride != 0) - throw Exception(Error::InvalidArgument, "row stride not integer multiple of pixel stride"); - - this->rowStride = inByteRowStride / this->bytePixelStride; - } - else - { - this->rowStride = width; - } - - this->format = format; - } - - __forceinline char* get(int y, int x) - { - return ptr + ((size_t(y) * rowStride + size_t(x)) * bytePixelStride); - } - - __forceinline const char* get(int y, int x) const - { - return ptr + ((size_t(y) * rowStride + size_t(x)) * bytePixelStride); - } - - operator bool() const - { - return ptr != nullptr; - } - }; - -} // namespace oidn diff --git a/thirdparty/oidn/core/input_reorder.h b/thirdparty/oidn/core/input_reorder.h deleted file mode 100644 index 966856afe..000000000 --- a/thirdparty/oidn/core/input_reorder.h +++ /dev/null @@ -1,232 +0,0 @@ -// ======================================================================== // -// Copyright 2009-2019 Intel Corporation // -// // -// Licensed under the Apache License, Version 2.0 (the "License"); // -// you may not use this file except in compliance with the License. 
// -// You may obtain a copy of the License at // -// // -// http://www.apache.org/licenses/LICENSE-2.0 // -// // -// Unless required by applicable law or agreed to in writing, software // -// distributed under the License is distributed on an "AS IS" BASIS, // -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // -// See the License for the specific language governing permissions and // -// limitations under the License. // -// ======================================================================== // - -#pragma once - -#include "node.h" -#include "image.h" - -namespace oidn { - - // Input reorder node - template - class InputReorderNode : public Node - { - private: - // Source - Image color; - Image albedo; - Image normal; - - // Destination - std::shared_ptr dst; - float* dstPtr; - int C2; - int H2; - int W2; - - // Tile - int h1Begin; - int w1Begin; - int h2Begin; - int w2Begin; - int H; - int W; - - std::shared_ptr transferFunc; - - public: - InputReorderNode(const Image& color, - const Image& albedo, - const Image& normal, - const std::shared_ptr& dst, - const std::shared_ptr& transferFunc) - : color(color), albedo(albedo), normal(normal), - dst(dst), - h1Begin(0), w1Begin(0), - H(color.height), W(color.width), - transferFunc(transferFunc) - { - const mkldnn_memory_desc_t& dstDesc = dst->get_desc().data; - assert(memory_desc_matches_tag(dstDesc, mkldnn_format_tag_t(BlockedFormat::nChwKc))); - assert(dstDesc.ndims == 4); - assert(dstDesc.data_type == memory::data_type::f32); - assert(dstDesc.dims[0] == 1); - //assert(dstDesc.dims[1] >= getPadded(C1)); - - dstPtr = (float*)dst->get_data_handle(); - C2 = dstDesc.dims[1]; - H2 = dstDesc.dims[2]; - W2 = dstDesc.dims[3]; - } - - void setTile(int h1, int w1, int h2, int w2, int H, int W) override - { - h1Begin = h1; - w1Begin = w1; - h2Begin = h2; - w2Begin = w2; - this->H = H; - this->W = W; - } - - void execute(stream& sm) override - { - assert(H + h1Begin <= color.height); - assert(W + w1Begin <= color.width); - assert(H + h2Begin <= H2); - assert(W + w2Begin <= W2); - - parallel_nd(H2, [&](int h2) - { - const int h = h2 - h2Begin; - - if (h >= 0 && h < H) - { - const int h1 = h + h1Begin; - - // Zero pad - for (int w2 = 0; w2 < w2Begin; ++w2) - { - int c = 0; - while (c < C2) - store(h2, w2, c, 0.f); - } - - // Reorder - for (int w = 0; w < W; ++w) - { - const int w1 = w + w1Begin; - const int w2 = w + w2Begin; - - int c = 0; - storeColor(h2, w2, c, (float*)color.get(h1, w1)); - if (albedo) - storeAlbedo(h2, w2, c, (float*)albedo.get(h1, w1)); - if (normal) - storeNormal(h2, w2, c, (float*)normal.get(h1, w1)); - while (c < C2) - store(h2, w2, c, 0.f); - } - - // Zero pad - for (int w2 = W + w2Begin; w2 < W2; ++w2) - { - int c = 0; - while (c < C2) - store(h2, w2, c, 0.f); - } - } - else - { - // Zero pad - for (int w2 = 0; w2 < W2; ++w2) - { - int c = 0; - while (c < C2) - store(h2, w2, c, 0.f); - } - } - }); - } - - std::shared_ptr getDst() const override { return dst; } - - private: - // Stores a single value - __forceinline void store(int h, int w, int& c, float value) - { - // Destination is in nChwKc format - float* dst_c = dstPtr + (H2*W2*K*(c/K)) + h*W2*K + w*K + (c%K); - *dst_c = value; - c++; - } - - // Stores a color - __forceinline void storeColor(int h, int w, int& c, const float* values) - { - #pragma unroll - for (int i = 0; i < 3; ++i) - { - // Load the value - float x = values[i]; - - // Sanitize the value - x = maxSafe(x, 0.f); - - // Apply the transfer function - x = transferFunc->forward(x); 
- - // Store the value - store(h, w, c, x); - } - } - - // Stores an albedo - __forceinline void storeAlbedo(int h, int w, int& c, const float* values) - { - #pragma unroll - for (int i = 0; i < 3; ++i) - { - // Load the value - float x = values[i]; - - // Sanitize the value - x = clampSafe(x, 0.f, 1.f); - - // Store the value - store(h, w, c, x); - } - } - - // Stores a normal - __forceinline void storeNormal(int h, int w, int& c, const float* values) - { - // Load the normal - float x = values[0]; - float y = values[1]; - float z = values[2]; - - // Compute the length of the normal - const float lengthSqr = sqr(x) + sqr(y) + sqr(z); - - // Normalize the normal and transform it to [0..1] - if (isfinite(lengthSqr)) - { - const float invLength = (lengthSqr > minVectorLengthSqr) ? rsqrt(lengthSqr) : 1.f; - - const float scale = invLength * 0.5f; - const float offset = 0.5f; - - x = x * scale + offset; - y = y * scale + offset; - z = z * scale + offset; - } - else - { - x = 0.f; - y = 0.f; - z = 0.f; - } - - // Store the normal - store(h, w, c, x); - store(h, w, c, y); - store(h, w, c, z); - } - }; - -} // namespace oidn diff --git a/thirdparty/oidn/core/math.h b/thirdparty/oidn/core/math.h deleted file mode 100644 index a844ef0d1..000000000 --- a/thirdparty/oidn/core/math.h +++ /dev/null @@ -1,78 +0,0 @@ -// ======================================================================== // -// Copyright 2009-2019 Intel Corporation // -// // -// Licensed under the Apache License, Version 2.0 (the "License"); // -// you may not use this file except in compliance with the License. // -// You may obtain a copy of the License at // -// // -// http://www.apache.org/licenses/LICENSE-2.0 // -// // -// Unless required by applicable law or agreed to in writing, software // -// distributed under the License is distributed on an "AS IS" BASIS, // -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // -// See the License for the specific language governing permissions and // -// limitations under the License. // -// ======================================================================== // - -#pragma once - -#include "common/platform.h" - -namespace oidn { - - constexpr float minVectorLength = 1e-10f; - constexpr float minVectorLengthSqr = minVectorLength * minVectorLength; - - using std::log; - using std::log2; - using std::exp; - using std::exp2; - using std::pow; - using std::isfinite; - using std::isnan; - - __forceinline float sqr(float x) - { - return x * x; - } - - __forceinline float rcp(float x) - { - __m128 r = _mm_rcp_ss(_mm_set_ss(x)); - return _mm_cvtss_f32(_mm_sub_ss(_mm_add_ss(r, r), _mm_mul_ss(_mm_mul_ss(r, r), _mm_set_ss(x)))); - } - - __forceinline float rsqrt(float x) - { - __m128 r = _mm_rsqrt_ss(_mm_set_ss(x)); - return _mm_cvtss_f32(_mm_add_ss(_mm_mul_ss(_mm_set_ss(1.5f), r), - _mm_mul_ss(_mm_mul_ss(_mm_mul_ss(_mm_set_ss(x), _mm_set_ss(-0.5f)), r), _mm_mul_ss(r, r)))); - } - - __forceinline float maxSafe(float value, float minValue) - { - return isfinite(value) ? max(value, minValue) : minValue; - } - - __forceinline float clampSafe(float value, float minValue, float maxValue) - { - return isfinite(value) ? 
clamp(value, minValue, maxValue) : minValue; - } - - // Returns ceil(a / b) for non-negative integers - template - __forceinline constexpr Int ceilDiv(Int a, Int b) - { - //assert(a >= 0); - //assert(b > 0); - return (a + b - 1) / b; - } - - // Returns a rounded up to multiple of b - template - __forceinline constexpr Int roundUp(Int a, Int b) - { - return ceilDiv(a, b) * b; - } - -} // namespace oidn diff --git a/thirdparty/oidn/core/network.cpp b/thirdparty/oidn/core/network.cpp deleted file mode 100644 index ed8328c95..000000000 --- a/thirdparty/oidn/core/network.cpp +++ /dev/null @@ -1,436 +0,0 @@ -// ======================================================================== // -// Copyright 2009-2019 Intel Corporation // -// // -// Licensed under the Apache License, Version 2.0 (the "License"); // -// you may not use this file except in compliance with the License. // -// You may obtain a copy of the License at // -// // -// http://www.apache.org/licenses/LICENSE-2.0 // -// // -// Unless required by applicable law or agreed to in writing, software // -// distributed under the License is distributed on an "AS IS" BASIS, // -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // -// See the License for the specific language governing permissions and // -// limitations under the License. // -// ======================================================================== // - -#include "upsample.h" -#include "weights_reorder.h" -#include "network.h" -// -- GODOT start -- -#include -// -- GODOT end -- - -namespace oidn { - - template - Network::Network(const Ref& device, const std::map& weightMap) - : device(device), - eng(engine::cpu, 0), - sm(eng), - weightMap(weightMap) - { - } - - template - void Network::execute(const Progress& progress, int taskIndex) - { - if (progress.func) - { - const double value = double(taskIndex) / double(progress.taskCount); - if (!progress.func(progress.userPtr, value)) - throw Exception(Error::Cancelled, "execution was cancelled"); - } - - for (size_t i = 0; i < nodes.size(); ++i) - { - nodes[i]->execute(sm); - - if (progress.func) - { - const double value = (double(taskIndex) + double(i+1) / double(nodes.size())) / double(progress.taskCount); - if (!progress.func(progress.userPtr, value)) - throw Exception(Error::Cancelled, "execution was cancelled"); - } - } - } - - template - std::shared_ptr Network::allocTensor(const memory::dims& dims, - memory::format_tag format, - void* data) - { - if (format == memory::format_tag::any) - { - if (dims.size() == 4) - format = BlockedFormat::nChwKc; - else if (dims.size() == 1) - format = memory::format_tag::x; - else - assert(0); - } - memory::desc desc(dims, memory::data_type::f32, format); - if (data == nullptr) - { - const size_t bytes = getTensorSize(dims) * sizeof(float); - if (format == BlockedFormat::nChwKc) - activationAllocBytes += bytes; - totalAllocBytes += bytes; - - return std::make_shared(desc, eng); - } - else - { - return std::make_shared(desc, eng, data); - } - } - - template - std::shared_ptr Network::castTensor(const memory::dims& dims, - const std::shared_ptr& src, - size_t srcOffset, - memory::format_tag format) - { - const mkldnn_memory_desc_t& srcDesc = src->get_desc().data; - MAYBE_UNUSED(srcDesc); - assert(srcDesc.data_type == memory::data_type::f32); - assert(getTensorSize(src) >= srcOffset + getTensorSize(dims)); - - if (format == memory::format_tag::any) - { - if (dims.size() == 4) - format = BlockedFormat::nChwKc; - else if (dims.size() == 1) - format = 
memory::format_tag::x; - else - assert(0); - } - memory::desc desc(dims, memory::data_type::f32, format); - float* srcPtr = (float*)src->get_data_handle() + srcOffset; - return std::make_shared(desc, eng, srcPtr); - } - - template - std::shared_ptr Network::castTensor(const memory::dims& dims, - const std::shared_ptr& src, - const memory::dims& srcOffset) - { - return castTensor(dims, src, getTensorSize(srcOffset)); - } - - template - void Network::zeroTensor(const std::shared_ptr& dst) - { - assert(getTensorType(dst) == memory::data_type::f32); - memset(dst->get_data_handle(), 0, getTensorSize(dst)*sizeof(float)); - } - - template - memory::dims Network::getInputReorderDims(const memory::dims& srcDims, int alignment) - { - memory::dims dstDims = srcDims; - dstDims[1] = getPadded(srcDims[1]); // round up C - dstDims[2] = roundUp(srcDims[2], memory::dim(alignment)); // round up H - dstDims[3] = roundUp(srcDims[3], memory::dim(alignment)); // round up W - return dstDims; - } - - template - std::shared_ptr Network::addInputReorder(const Image& color, - const Image& albedo, - const Image& normal, - const std::shared_ptr& transferFunc, - int alignment, - const std::shared_ptr& userDst) - { - assert(color); - int inputC = 3; - if (albedo) inputC += 3; - if (normal) inputC += 3; - - memory::dims srcDims = {1, inputC, color.height, color.width}; - memory::dims dstDims = getInputReorderDims(srcDims, alignment); - - // Allocate padded memory - auto dst = userDst; - if (!dst) - dst = allocTensor(dstDims); - - // Push node - std::shared_ptr node; - - if (auto tf = std::dynamic_pointer_cast(transferFunc)) - node = std::make_shared>(color, albedo, normal, dst, tf); - else if (auto tf = std::dynamic_pointer_cast(transferFunc)) - node = std::make_shared>(color, albedo, normal, dst, tf); - else if (auto tf = std::dynamic_pointer_cast(transferFunc)) - node = std::make_shared>(color, albedo, normal, dst, tf); - else if (auto tf = std::dynamic_pointer_cast(transferFunc)) - node = std::make_shared>(color, albedo, normal, dst, tf); - else - assert(0); - - nodes.push_back(node); - return node; - } - - template - std::shared_ptr Network::addOutputReorder(const std::shared_ptr& src, - const std::shared_ptr& transferFunc, - const Image& output) - { - memory::dims srcDims = getTensorDims(src); - assert(srcDims[1] == K); - - // Push node - std::shared_ptr node; - - if (auto tf = std::dynamic_pointer_cast(transferFunc)) - node = std::make_shared>(src, output, tf); - else if (auto tf = std::dynamic_pointer_cast(transferFunc)) - node = std::make_shared>(src, output, tf); - else if (auto tf = std::dynamic_pointer_cast(transferFunc)) - node = std::make_shared>(src, output, tf); - else if (auto tf = std::dynamic_pointer_cast(transferFunc)) - node = std::make_shared>(src, output, tf); - else - assert(0); - - nodes.push_back(node); - return node; - } - - template - memory::dims Network::getConvDims(const std::string& name, const memory::dims& srcDims) - { - auto b = weightMap[name + "/b"]; - memory::dims dstDims = srcDims; - dstDims[1] = getPadded(b.dims[0]); // dstDims[C] = getPadded(OC) - return dstDims; - } - - template - std::shared_ptr Network::addConv(const std::string& name, - const std::shared_ptr& src, - const std::shared_ptr& userDst, - bool relu) - { - const memory::dims strides = {1, 1}; - const memory::dims padding = {1, 1}; - - memory::dims srcDims = getTensorDims(src); - - // Get the weights - const auto& W = weightMap[name + "/W"]; - if (W.ndims() != 4 || W.format != "oihw") - throw 
Exception(Error::InvalidOperation, "invalid convolution weights"); - memory::dims weightsDims = W.dims; - auto userWeights = allocTensor(weightsDims, memory::format_tag::oihw, W.data); - - // Pad the weights - memory::dims weightsPadDims = weightsDims; - weightsPadDims[1] = getPadded(weightsDims[1]); // IC - weightsPadDims[0] = getPadded(weightsDims[0]); // OC - assert(srcDims[1] == weightsPadDims[1]); // srcDims[C] == weightsPadDims[IC] - auto weightsPad = allocTensor(weightsPadDims, memory::format_tag::oihw); - WeightsReorderNode(userWeights, weightsPad).execute(sm); - - // Get the biases - const auto& b = weightMap[name + "/b"]; - if (b.ndims() != 1) - throw Exception(Error::InvalidOperation, "invalid convolution biases"); - memory::dims biasDims = b.dims; - - // Copy/pad the biases - memory::dims biasPadDims = {getPadded(biasDims[0])}; - auto bias = allocTensor(biasPadDims); - if (biasDims[0] != biasPadDims[0]) - memset(bias->get_data_handle(), 0, biasPadDims[0]*sizeof(float)); - memcpy(bias->get_data_handle(), b.data, biasDims[0]*sizeof(float)); - - // Allocate memory for destination - memory::dims dstDims = srcDims; - dstDims[1] = weightsPadDims[0]; // dstDims[C] = weightsPadDims[OC] - - std::shared_ptr dst; - if (!userDst) - dst = allocTensor(dstDims); - else if (getTensorDims(userDst) == dstDims) - dst = userDst; - else - dst = castTensor(dstDims, userDst); - - // Create a convolution - // Let the convolution primitive choose the weights format - auto weightsDesc = memory::desc({ weightsPadDims }, memory::data_type::f32, memory::format_tag::any); - - auto convAlgo = (K == 16) ? convolution_winograd : convolution_direct; - auto convDesc = convolution_forward::desc( - prop_kind::forward_inference, convAlgo, - src->get_desc(), - weightsDesc, - bias->get_desc(), - dst->get_desc(), - strides, padding, padding, padding_kind::zero); - - // Incorporate relu - mkldnn::primitive_attr convAttr; - if (relu) - { - mkldnn::post_ops ops; - ops.append_eltwise( - 1.f, // scale factor, not used - algorithm::eltwise_relu, - 0.f, // max with - 0.f // unused - ); - convAttr.set_post_ops(ops); - } - convAttr.set_scratchpad_mode(scratchpad_mode_user); - - auto convPrimDesc = convolution_forward::primitive_desc(convDesc, convAttr, eng); - - // Reorder the weights to the final format, if necessary - auto weights = weightsPad; - if (convPrimDesc.weights_desc() != weightsPad->get_desc()) - { - weights = std::make_shared(convPrimDesc.weights_desc(), eng); - ReorderNode(weightsPad, weights).execute(sm); - } - - // Create convolution node and add it to the net - auto node = std::make_shared(convPrimDesc, src, weights, bias, dst); - nodes.push_back(node); - return node; - } - - template - memory::dims Network::getPoolDims(const memory::dims& srcDims) - { - memory::dims dstDims = srcDims; - dstDims[2] /= 2; // H/2 - dstDims[3] /= 2; // W/2 - return dstDims; - } - - template - std::shared_ptr Network::addPool(const std::shared_ptr& src, - const std::shared_ptr& userDst) - { - const memory::dims kernel = {2, 2}; - const memory::dims strides = {2, 2}; - const memory::dims padding = {0, 0}; - - memory::dims srcDims = getTensorDims(src); - memory::dims dstDims = getPoolDims(srcDims); - - std::shared_ptr dst; - if (!userDst) - dst = allocTensor(dstDims); - else if (getTensorDims(userDst) == dstDims) - dst = userDst; - else - dst = castTensor(dstDims, userDst); - - auto poolDesc = pooling_forward::desc( - prop_kind::forward_inference, pooling_max, - src->get_desc(), - dst->get_desc(), - strides, kernel, padding, 
padding, padding_kind::zero); - - mkldnn::primitive_attr poolAttr; - poolAttr.set_scratchpad_mode(scratchpad_mode_user); - - auto poolPrimDesc = pooling_forward::primitive_desc(poolDesc, poolAttr, eng); - - auto node = std::make_shared(poolPrimDesc, src, dst); - nodes.push_back(node); - return node; - } - - template - memory::dims Network::getUpsampleDims(const memory::dims& srcDims) - { - memory::dims dstDims = srcDims; - dstDims[2] *= 2; // H*2 - dstDims[3] *= 2; // W*2 - return dstDims; - } - - template - std::shared_ptr Network::addUpsample(const std::shared_ptr& src, - const std::shared_ptr& userDst) - { - memory::dims srcDims = getTensorDims(src); - memory::dims dstDims = getUpsampleDims(srcDims); - - std::shared_ptr dst; - if (!userDst) - dst = allocTensor(dstDims); - else if (getTensorDims(userDst) == dstDims) - dst = userDst; - else - dst = castTensor(dstDims, userDst); - - // Create upsampling node and add it to net - auto node = std::make_shared>(src, dst); - nodes.push_back(node); - return node; - } - - template - memory::dims Network::getConcatDims(const memory::dims& src1Dims, const memory::dims& src2Dims) - { - assert(src1Dims[0] == src2Dims[0]); // N - assert(src1Dims[2] == src2Dims[2]); // H - assert(src1Dims[3] == src2Dims[3]); // W - - memory::dims dstDims = src1Dims; - dstDims[1] += src2Dims[1]; // C - return dstDims; - } - - template - std::shared_ptr Network::addAutoexposure(const Image& color, - const std::shared_ptr& transferFunc) - { - auto node = std::make_shared(color, transferFunc); - nodes.push_back(node); - return node; - } - - template - void Network::finalize() - { - // Compute the size of the scratchpad - size_t scratchpadSize = 0; - for (const auto& node : nodes) - scratchpadSize = max(scratchpadSize, node->getScratchpadSize()); - - // Allocate the scratchpad - memory::dims scratchpadDims = { memory::dim(scratchpadSize) }; - memory::desc scratchpadDesc(scratchpadDims, memory::data_type::u8, memory::format_tag::x); - auto scratchpad = std::make_shared(scratchpadDesc, eng); - activationAllocBytes += scratchpadSize; - totalAllocBytes += scratchpadSize; - - // Set the scratchpad for the nodes - for (auto& node : nodes) - node->setScratchpad(scratchpad); - - // Free the weights - weightMap.clear(); - - // Print statistics - if (device->isVerbose(2)) - { - std::cout << "Activation bytes: " << activationAllocBytes << std::endl; - std::cout << "Scratchpad bytes: " << scratchpadSize << std::endl; - std::cout << "Total bytes : " << totalAllocBytes << std::endl; - } - } - - template class Network<8>; - template class Network<16>; - -} // namespace oidn diff --git a/thirdparty/oidn/core/network.h b/thirdparty/oidn/core/network.h deleted file mode 100644 index 7a696fd35..000000000 --- a/thirdparty/oidn/core/network.h +++ /dev/null @@ -1,112 +0,0 @@ -// ======================================================================== // -// Copyright 2009-2019 Intel Corporation // -// // -// Licensed under the Apache License, Version 2.0 (the "License"); // -// you may not use this file except in compliance with the License. // -// You may obtain a copy of the License at // -// // -// http://www.apache.org/licenses/LICENSE-2.0 // -// // -// Unless required by applicable law or agreed to in writing, software // -// distributed under the License is distributed on an "AS IS" BASIS, // -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // -// See the License for the specific language governing permissions and // -// limitations under the License. 
// -// ======================================================================== // - -#include "common/tensor.h" -#include "image.h" -#include "node.h" -#include "input_reorder.h" -#include "output_reorder.h" -#include "transfer_function.h" - -#pragma once - -namespace oidn { - - // Progress state - struct Progress - { - ProgressMonitorFunction func; - void* userPtr; - int taskCount; - }; - - class Executable - { - public: - virtual ~Executable() {} - virtual void execute(const Progress& progress, int taskIndex) = 0; - }; - - template - class Network : public Executable - { - public: - Network(const Ref& device, const std::map& weightMap); - - void execute(const Progress& progress, int taskIndex) override; - - std::shared_ptr allocTensor(const memory::dims& dims, - memory::format_tag format = memory::format_tag::any, - void* data = nullptr); - - std::shared_ptr castTensor(const memory::dims& dims, - const std::shared_ptr& src, - size_t srcOffset = 0, - memory::format_tag format = memory::format_tag::any); - - std::shared_ptr castTensor(const memory::dims& dims, - const std::shared_ptr& src, - const memory::dims& srcOffset); - - void zeroTensor(const std::shared_ptr& dst); - - memory::dims getInputReorderDims(const memory::dims& srcDims, int alignment); - - std::shared_ptr addInputReorder(const Image& color, - const Image& albedo, - const Image& normal, - const std::shared_ptr& transferFunc, - int alignment, - const std::shared_ptr& userDst = nullptr); - - std::shared_ptr addOutputReorder(const std::shared_ptr& src, - const std::shared_ptr& transferFunc, - const Image& output); - - memory::dims getConvDims(const std::string& name, const memory::dims& srcDims); - std::shared_ptr addConv(const std::string& name, - const std::shared_ptr& src, - const std::shared_ptr& userDst = nullptr, - bool relu = true); - - memory::dims getPoolDims(const memory::dims& srcDims); - std::shared_ptr addPool(const std::shared_ptr& src, - const std::shared_ptr& userDst = nullptr); - - memory::dims getUpsampleDims(const memory::dims& srcDims); - std::shared_ptr addUpsample(const std::shared_ptr& src, - const std::shared_ptr& userDst = nullptr); - - memory::dims getConcatDims(const memory::dims& src1Dims, const memory::dims& src2Dims); - - std::shared_ptr addAutoexposure(const Image& color, - const std::shared_ptr& transferFunc); - - void finalize(); - - private: - Ref device; - engine eng; - stream sm; - std::vector> nodes; - std::map weightMap; - - // Memory allocation statistics - size_t activationAllocBytes = 0; // number of allocated activation bytes - size_t totalAllocBytes = 0; // total number of allocated bytes - }; - -} // namespace oidn diff --git a/thirdparty/oidn/core/node.h b/thirdparty/oidn/core/node.h deleted file mode 100644 index b9ffe906d..000000000 --- a/thirdparty/oidn/core/node.h +++ /dev/null @@ -1,142 +0,0 @@ -// ======================================================================== // -// Copyright 2009-2019 Intel Corporation // -// // -// Licensed under the Apache License, Version 2.0 (the "License"); // -// you may not use this file except in compliance with the License. // -// You may obtain a copy of the License at // -// // -// http://www.apache.org/licenses/LICENSE-2.0 // -// // -// Unless required by applicable law or agreed to in writing, software // -// distributed under the License is distributed on an "AS IS" BASIS, // -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// -// See the License for the specific language governing permissions and // -// limitations under the License. // -// ======================================================================== // - -#pragma once - -#include "common.h" -#include - -namespace oidn { - - class Node - { - public: - virtual ~Node() = default; - - virtual void execute(stream& sm) = 0; - - virtual std::shared_ptr getDst() const { return nullptr; } - - virtual size_t getScratchpadSize() const { return 0; } - virtual void setScratchpad(const std::shared_ptr& mem) {} - - virtual void setTile(int h1, int w1, int h2, int w2, int H, int W) - { - assert(0); // not supported - } - }; - - // Node wrapping an MKL-DNN primitive - class MklNode : public Node - { - private: - primitive prim; - std::unordered_map args; - std::shared_ptr scratchpad; - - public: - MklNode(const primitive& prim, const std::unordered_map& args) - : prim(prim), - args(args) - {} - - size_t getScratchpadSize() const override - { - const auto primDesc = prim.get_primitive_desc(); - const mkldnn_memory_desc_t* scratchpadDesc = mkldnn_primitive_desc_query_md(primDesc, mkldnn_query_scratchpad_md, 0); - if (scratchpadDesc == nullptr) - return 0; - return mkldnn_memory_desc_get_size(scratchpadDesc); - } - - void setScratchpad(const std::shared_ptr& mem) override - { - scratchpad = mem; - args.insert(std::make_pair(MKLDNN_ARG_SCRATCHPAD, *scratchpad)); - } - - void execute(stream& sm) override - { - prim.execute(sm, args); - } - }; - - // Convolution node - class ConvNode : public MklNode - { - private: - std::shared_ptr src; - std::shared_ptr weights; - std::shared_ptr bias; - std::shared_ptr dst; - - public: - ConvNode(const convolution_forward::primitive_desc& desc, - const std::shared_ptr& src, - const std::shared_ptr& weights, - const std::shared_ptr& bias, - const std::shared_ptr& dst) - : MklNode(convolution_forward(desc), - { { MKLDNN_ARG_SRC, *src }, - { MKLDNN_ARG_WEIGHTS, *weights }, - { MKLDNN_ARG_BIAS, *bias }, - { MKLDNN_ARG_DST, *dst } }), - src(src), weights(weights), bias(bias), dst(dst) - {} - - std::shared_ptr getDst() const override { return dst; } - }; - - // Pooling node - class PoolNode : public MklNode - { - private: - std::shared_ptr src; - std::shared_ptr dst; - - public: - PoolNode(const pooling_forward::primitive_desc& desc, - const std::shared_ptr& src, - const std::shared_ptr& dst) - : MklNode(pooling_forward(desc), - { { MKLDNN_ARG_SRC, *src }, - { MKLDNN_ARG_DST, *dst } }), - src(src), dst(dst) - {} - - std::shared_ptr getDst() const override { return dst; } - }; - - // Reorder node - class ReorderNode : public MklNode - { - private: - std::shared_ptr src; - std::shared_ptr dst; - - public: - ReorderNode(const std::shared_ptr& src, - const std::shared_ptr& dst) - : MklNode(reorder(reorder::primitive_desc(*src, *dst)), - { { MKLDNN_ARG_SRC, *src }, - { MKLDNN_ARG_DST, *dst } }), - src(src), dst(dst) - {} - - std::shared_ptr getDst() const override { return dst; } - }; - -} // namespace oidn diff --git a/thirdparty/oidn/core/output_reorder.h b/thirdparty/oidn/core/output_reorder.h deleted file mode 100644 index 7918d48e1..000000000 --- a/thirdparty/oidn/core/output_reorder.h +++ /dev/null @@ -1,126 +0,0 @@ -// ======================================================================== // -// Copyright 2009-2019 Intel Corporation // -// // -// Licensed under the Apache License, Version 2.0 (the "License"); // -// you may not use this file except in compliance with the License. 
// -// You may obtain a copy of the License at // -// // -// http://www.apache.org/licenses/LICENSE-2.0 // -// // -// Unless required by applicable law or agreed to in writing, software // -// distributed under the License is distributed on an "AS IS" BASIS, // -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // -// See the License for the specific language governing permissions and // -// limitations under the License. // -// ======================================================================== // - -#pragma once - -#include "node.h" -#include "image.h" - -namespace oidn { - - // Output reorder node - template - class OutputReorderNode : public Node - { - private: - // Source - std::shared_ptr src; - const float* srcPtr; - int H1; - int W1; - - // Destination - Image output; - - // Tile - int h1Begin; - int w1Begin; - int h2Begin; - int w2Begin; - int H; - int W; - - std::shared_ptr transferFunc; - - public: - OutputReorderNode(const std::shared_ptr& src, - const Image& output, - const std::shared_ptr& transferFunc) - : src(src), - output(output), - h1Begin(0), w1Begin(0), - h2Begin(0), w2Begin(0), - H(output.height), W(output.width), - transferFunc(transferFunc) - { - const mkldnn_memory_desc_t& srcDesc = src->get_desc().data; - MAYBE_UNUSED(srcDesc); - assert(memory_desc_matches_tag(srcDesc, mkldnn_format_tag_t(BlockedFormat::nChwKc))); - assert(srcDesc.ndims == 4); - assert(srcDesc.data_type == memory::data_type::f32); - assert(srcDesc.dims[0] == 1); - // We assume output data is <= K OC - assert(srcDesc.dims[1] == K); - - srcPtr = (float*)src->get_data_handle(); - H1 = srcDesc.dims[2]; - W1 = srcDesc.dims[3]; - } - - void setTile(int h1, int w1, int h2, int w2, int H, int W) override - { - h1Begin = h1; - w1Begin = w1; - h2Begin = h2; - w2Begin = w2; - this->H = H; - this->W = W; - } - - void execute(stream& sm) override - { - assert(h1Begin + H <= H1); - assert(w1Begin + W <= W1); - assert(h2Begin + H <= output.height); - assert(w2Begin + W <= output.width); - - const int C1 = K; - - parallel_nd(H, [&](int h) - { - const int h1 = h + h1Begin; - const int h2 = h + h2Begin; - - for (int w = 0; w < W; ++w) - { - const int w1 = w + w1Begin; - const int w2 = w + w2Begin; - float* dstPtr_C = (float*)output.get(h2, w2); - - // Source is in nChwKc format. In this case C is 1 so this is really nhwc - const float* srcPtr_C = srcPtr + h1*W1*C1 + w1*C1; - - #pragma unroll - for (int i = 0; i < 3; ++i) - { - // Load the value - float x = srcPtr_C[i]; - - // The CNN output may contain negative values or even NaNs, so it must be sanitized - x = maxSafe(x, 0.f); - - // Apply the inverse transfer function - x = transferFunc->inverse(x); - - // Sanitize and store the final value - dstPtr_C[i] = max(x, 0.f); - } - } - }); - } - }; - -} // namespace oidn diff --git a/thirdparty/oidn/core/transfer_function.cpp b/thirdparty/oidn/core/transfer_function.cpp deleted file mode 100644 index ce5deca56..000000000 --- a/thirdparty/oidn/core/transfer_function.cpp +++ /dev/null @@ -1,103 +0,0 @@ -// ======================================================================== // -// Copyright 2009-2019 Intel Corporation // -// // -// Licensed under the Apache License, Version 2.0 (the "License"); // -// you may not use this file except in compliance with the License. 
// -// You may obtain a copy of the License at // -// // -// http://www.apache.org/licenses/LICENSE-2.0 // -// // -// Unless required by applicable law or agreed to in writing, software // -// distributed under the License is distributed on an "AS IS" BASIS, // -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // -// See the License for the specific language governing permissions and // -// limitations under the License. // -// ======================================================================== // - -#include "transfer_function.h" - -namespace oidn { - - const float LogTransferFunction::xScale = 1.f / log(LogTransferFunction::yMax + 1.f); - const float PQXTransferFunction::xScale = 1.f / PQXTransferFunction::pqxForward(PQXTransferFunction::yMax * PQXTransferFunction::yScale); - - float AutoexposureNode::autoexposure(const Image& color) - { - assert(color.format == Format::Float3); - - constexpr float key = 0.18f; - constexpr float eps = 1e-8f; - constexpr int K = 16; // downsampling amount - - // Downsample the image to minimize sensitivity to noise - const int H = color.height; // original height - const int W = color.width; // original width - const int HK = (H + K/2) / K; // downsampled height - const int WK = (W + K/2) / K; // downsampled width - - // Compute the average log luminance of the downsampled image - using Sum = std::pair; - - // -- GODOT start -- - // Sum sum = - // tbb::parallel_reduce( - // tbb::blocked_range2d(0, HK, 0, WK), - // Sum(0.f, 0), - // [&](const tbb::blocked_range2d& r, Sum sum) -> Sum - // { - // // Iterate over blocks - // for (int i = r.rows().begin(); i != r.rows().end(); ++i) - // { - // for (int j = r.cols().begin(); j != r.cols().end(); ++j) - // { - - Sum sum = Sum(0.0f, 0); - - for (int i = 0; i != HK; ++i) - { - for (int j = 0; j != WK; ++j) - { - // Compute the average luminance in the current block - const int beginH = int(ptrdiff_t(i) * H / HK); - const int beginW = int(ptrdiff_t(j) * W / WK); - const int endH = int(ptrdiff_t(i+1) * H / HK); - const int endW = int(ptrdiff_t(j+1) * W / WK); - - float L = 0.f; - - for (int h = beginH; h < endH; ++h) - { - for (int w = beginW; w < endW; ++w) - { - const float* rgb = (const float*)color.get(h, w); - - const float r = maxSafe(rgb[0], 0.f); - const float g = maxSafe(rgb[1], 0.f); - const float b = maxSafe(rgb[2], 0.f); - - L += luminance(r, g, b); - } - } - - L /= (endH - beginH) * (endW - beginW); - - // Accumulate the log luminance - if (L > eps) - { - sum.first += log2(L); - sum.second++; - } - } - } - - // return sum; - // }, - // [](Sum a, Sum b) -> Sum { return Sum(a.first+b.first, a.second+b.second); }, - // tbb::static_partitioner() - // ); - // -- GODOT end -- - - return (sum.second > 0) ? (key / exp2(sum.first / float(sum.second))) : 1.f; - } - -} // namespace oidn diff --git a/thirdparty/oidn/core/transfer_function.h b/thirdparty/oidn/core/transfer_function.h deleted file mode 100644 index 35f283309..000000000 --- a/thirdparty/oidn/core/transfer_function.h +++ /dev/null @@ -1,201 +0,0 @@ -// ======================================================================== // -// Copyright 2009-2019 Intel Corporation // -// // -// Licensed under the Apache License, Version 2.0 (the "License"); // -// you may not use this file except in compliance with the License. 
// -// You may obtain a copy of the License at // -// // -// http://www.apache.org/licenses/LICENSE-2.0 // -// // -// Unless required by applicable law or agreed to in writing, software // -// distributed under the License is distributed on an "AS IS" BASIS, // -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // -// See the License for the specific language governing permissions and // -// limitations under the License. // -// ======================================================================== // - -#pragma once - -#include "image.h" -#include "node.h" - -namespace oidn { - - __forceinline float luminance(float r, float g, float b) - { - return 0.212671f * r + 0.715160f * g + 0.072169f * b; - } - - // Color transfer function base class - class TransferFunction - { - public: - virtual ~TransferFunction() = default; - - virtual float forward(float y) const = 0; - virtual float inverse(float x) const = 0; - }; - - // HDR transfer function base class - class HDRTransferFunction : public TransferFunction - { - protected: - static constexpr float yMax = 65504.f; - - float exposure; - float rcpExposure; - - public: - HDRTransferFunction(float exposure = 1.f) - { - setExposure(exposure); - } - - void setExposure(float exposure) - { - this->exposure = exposure; - this->rcpExposure = (exposure != 0.f) ? (1.f / exposure) : 0.f; - } - }; - - // Linear transfer function (LDR) - class LinearTransferFunction : public TransferFunction - { - public: - __forceinline float forward(float y) const override - { - return min(y, 1.f); - } - - __forceinline float inverse(float x) const override - { - return min(x, 1.f); - } - }; - - // 2.2 gamma transfer function (LDR) - class GammaTransferFunction : public TransferFunction - { - public: - __forceinline float forward(float y) const override - { - return min(pow(y, 1.f/2.2f), 1.f); - } - - __forceinline float inverse(float x) const override - { - return min(pow(x, 2.2f), 1.f); - } - }; - - // Logarithmic transfer function (HDR) - // Compresses [0..65504] to [0..1] - class LogTransferFunction : public HDRTransferFunction - { - private: - static const float xScale; - - public: - LogTransferFunction(float exposure = 1.f) - : HDRTransferFunction(exposure) - { - } - - __forceinline float forward(float y) const override - { - return log(y * exposure + 1.f) * xScale; - } - - __forceinline float inverse(float x) const override - { - return (exp(x * (1.f/xScale)) - 1.f) * rcpExposure; - } - }; - - // PQX transfer function (HDR) - // Compresses [0..65504] to [0..1] - class PQXTransferFunction : public HDRTransferFunction - { - private: - static constexpr float m1 = 2610.f / 4096.f / 4.f; - static constexpr float m2 = 2523.f / 4096.f * 128.f; - static constexpr float c1 = 3424.f / 4096.f; - static constexpr float c2 = 2413.f / 4096.f * 32.f; - static constexpr float c3 = 2392.f / 4096.f * 32.f; - static constexpr float a = 3711.f / 4096.f / 8.f; - - static constexpr float yScale = 100.f / 10000.f; - static const float xScale; - - public: - PQXTransferFunction(float exposure = 1.f) - : HDRTransferFunction(exposure) - { - } - - __forceinline float forward(float y) const override - { - return pqxForward(y * exposure * yScale) * xScale; - } - - __forceinline float inverse(float x) const override - { - return pqxInverse(x * (1.f/xScale)) * (1.f/yScale) * rcpExposure; - } - - private: - static __forceinline float pqForward(float y) - { - const float yp = pow(y, m1); - return pow((c1 + c2 * yp) * rcp(1.f + c3 * yp), m2); - } - - static __forceinline 
float pqxForward(float y) - { - if (y <= 1.f) - return pqForward(y); - else - return a * log(y) + 1.f; - } - - static __forceinline float pqInverse(float x) - { - const float xp = pow(x, 1.f/m2); - return pow(max((xp - c1) * rcp(c2 - c3 * xp), 0.f), 1.f/m1); - } - - static __forceinline float pqxInverse(float x) - { - if (x <= 1.f) - return pqInverse(x); - else - return exp((x - 1.f) * (1.f/a)); - } - }; - - // Autoexposure node - class AutoexposureNode : public Node - { - private: - Image color; - std::shared_ptr transferFunc; - - public: - AutoexposureNode(const Image& color, - const std::shared_ptr& transferFunc) - : color(color), - transferFunc(transferFunc) - {} - - void execute(stream& sm) override - { - const float exposure = autoexposure(color); - //printf("exposure = %f\n", exposure); - transferFunc->setExposure(exposure); - } - - private: - static float autoexposure(const Image& color); - }; - -} // namespace oidn diff --git a/thirdparty/oidn/core/upsample.h b/thirdparty/oidn/core/upsample.h deleted file mode 100644 index f6cace44c..000000000 --- a/thirdparty/oidn/core/upsample.h +++ /dev/null @@ -1,92 +0,0 @@ -// ======================================================================== // -// Copyright 2009-2019 Intel Corporation // -// // -// Licensed under the Apache License, Version 2.0 (the "License"); // -// you may not use this file except in compliance with the License. // -// You may obtain a copy of the License at // -// // -// http://www.apache.org/licenses/LICENSE-2.0 // -// // -// Unless required by applicable law or agreed to in writing, software // -// distributed under the License is distributed on an "AS IS" BASIS, // -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // -// See the License for the specific language governing permissions and // -// limitations under the License. 
// -// ======================================================================== // - -#pragma once - -#include "node.h" - -namespace oidn { - - // 2x2 nearest-neighbor upsampling node - template - class UpsampleNode : public Node - { - private: - std::shared_ptr src; - std::shared_ptr dst; - - public: - UpsampleNode(const std::shared_ptr& src, - const std::shared_ptr& dst) - : src(src), - dst(dst) - { - const mkldnn_memory_desc_t& srcDesc = src->get_desc().data; - const mkldnn_memory_desc_t& dstDesc = dst->get_desc().data; - MAYBE_UNUSED(srcDesc); - MAYBE_UNUSED(dstDesc); - assert(memory_desc_matches_tag(srcDesc, mkldnn_format_tag_t(BlockedFormat::nChwKc))); - assert(memory_desc_matches_tag(dstDesc, mkldnn_format_tag_t(BlockedFormat::nChwKc))); - assert(srcDesc.ndims == 4); - assert(dstDesc.ndims == 4); - assert(srcDesc.data_type == memory::data_type::f32); - assert(dstDesc.data_type == memory::data_type::f32); - assert(srcDesc.dims[0] == 1); - assert(dstDesc.dims[0] == 1); - // 2x2 upsampling - assert(dstDesc.dims[2] == srcDesc.dims[2] * 2); - assert(dstDesc.dims[3] == srcDesc.dims[3] * 2); - } - - void execute(stream& sm) override - { - const mkldnn_memory_desc_t& srcDesc = src->get_desc().data; - - const float* srcPtr = (float*)src->get_data_handle(); - float* dstPtr = (float*)dst->get_data_handle(); - - const int C = srcDesc.dims[1]; - const int H = srcDesc.dims[2]; - const int W = srcDesc.dims[3]; - const int CK = C / K; - - parallel_nd(CK, H, [&](int ck, int h) - { - const size_t offset = ck*H*W*K + h*W*K; - const float* srcPtr_line = srcPtr + offset; - float* dstPtr_line0 = dstPtr + offset * 4; - float* dstPtr_line1 = dstPtr_line0 + W*2*K; // next line - - for (int w = 0; w < W; ++w) - { - #pragma unroll - for (int k = 0; k < K; k += 4) - { - const __m128 m = _mm_load_ps(&srcPtr_line[w*K + k]); - - _mm_stream_ps(&dstPtr_line0[w*2*K + k], m); - _mm_stream_ps(&dstPtr_line0[w*2*K+K + k], m); - _mm_stream_ps(&dstPtr_line1[w*2*K + k], m); - _mm_stream_ps(&dstPtr_line1[w*2*K+K + k], m); - } - } - }); - } - - std::shared_ptr getDst() const override { return dst; } - }; - -} // namespace oidn diff --git a/thirdparty/oidn/core/weights_reorder.h b/thirdparty/oidn/core/weights_reorder.h deleted file mode 100644 index 6c5dacb8a..000000000 --- a/thirdparty/oidn/core/weights_reorder.h +++ /dev/null @@ -1,99 +0,0 @@ -// ======================================================================== // -// Copyright 2009-2019 Intel Corporation // -// // -// Licensed under the Apache License, Version 2.0 (the "License"); // -// you may not use this file except in compliance with the License. // -// You may obtain a copy of the License at // -// // -// http://www.apache.org/licenses/LICENSE-2.0 // -// // -// Unless required by applicable law or agreed to in writing, software // -// distributed under the License is distributed on an "AS IS" BASIS, // -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // -// See the License for the specific language governing permissions and // -// limitations under the License. 
// -// ======================================================================== // - -#pragma once - -#include "node.h" - -namespace oidn { - - // Reorders weights from oihw to padded oihw format - template - class WeightsReorderNode : public Node - { - private: - std::shared_ptr src; - std::shared_ptr dst; - - public: - WeightsReorderNode(const std::shared_ptr& src, - const std::shared_ptr& dst) - : src(src), - dst(dst) - { - const mkldnn_memory_desc_t& srcDesc = src->get_desc().data; - const mkldnn_memory_desc_t& dstDesc = dst->get_desc().data; - MAYBE_UNUSED(srcDesc); - MAYBE_UNUSED(dstDesc); - assert(memory_desc_matches_tag(srcDesc, mkldnn_format_tag_t(memory::format_tag::oihw))); - assert(memory_desc_matches_tag(dstDesc, mkldnn_format_tag_t(memory::format_tag::oihw))); - assert(srcDesc.ndims == 4); - assert(dstDesc.ndims == 4); - assert(srcDesc.data_type == memory::data_type::f32); - assert(dstDesc.data_type == memory::data_type::f32); - assert(getPadded(srcDesc.dims[0]) == dstDesc.dims[0]); // OC - assert(getPadded(srcDesc.dims[1]) == dstDesc.dims[1]); // IC - assert(srcDesc.dims[2] == dstDesc.dims[2]); - assert(srcDesc.dims[3] == dstDesc.dims[3]); - } - - void execute(stream& sm) override - { - const mkldnn_memory_desc_t& srcDesc = src->get_desc().data; - const mkldnn_memory_desc_t& dstDesc = dst->get_desc().data; - - const float* srcPtr = (float*)src->get_data_handle(); - float* dstPtr = (float*)dst->get_data_handle(); - - const int OC1 = srcDesc.dims[0]; - const int OC2 = dstDesc.dims[0]; - const int IC1 = srcDesc.dims[1]; - const int IC2 = dstDesc.dims[1]; - const int H = dstDesc.dims[2]; - const int W = dstDesc.dims[3]; - - for (int oc = 0; oc < OC2; ++oc) - { - for (int ic = 0; ic < IC2; ++ic) - { - for (int h = 0; h < H; ++h) - { - for (int w = 0; w < W; ++w) - { - // Output is in oihw format - float* dstPtr_c = dstPtr + oc*IC2*H*W + ic*H*W + h*W + w; - - if (oc < OC1 && ic < IC1) - { - // Input is in oihw format - const float* srcPtr_c = srcPtr + oc*IC1*H*W + ic*H*W + h*W + w; - *dstPtr_c = *srcPtr_c; - } - else - { - // padding - *dstPtr_c = 0; - } - } - } - } - } - } - - std::shared_ptr getDst() const override { return dst; } - }; - -} // namespace oidn diff --git a/thirdparty/oidn/include/OpenImageDenoise/oidn.h b/thirdparty/oidn/include/OpenImageDenoise/oidn.h deleted file mode 100644 index 57ba6baa2..000000000 --- a/thirdparty/oidn/include/OpenImageDenoise/oidn.h +++ /dev/null @@ -1,214 +0,0 @@ -// ======================================================================== // -// Copyright 2009-2019 Intel Corporation // -// // -// Licensed under the Apache License, Version 2.0 (the "License"); // -// you may not use this file except in compliance with the License. // -// You may obtain a copy of the License at // -// // -// http://www.apache.org/licenses/LICENSE-2.0 // -// // -// Unless required by applicable law or agreed to in writing, software // -// distributed under the License is distributed on an "AS IS" BASIS, // -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // -// See the License for the specific language governing permissions and // -// limitations under the License. 
// -// ======================================================================== // - -#pragma once - -#include -#include -#include - -#include "version.h" - -#if defined(__cplusplus) -extern "C" { -#endif - -#ifndef OIDN_API -#if defined(_WIN32) && !defined(OIDN_STATIC_LIB) -# define OIDN_API __declspec(dllimport) -#else -# define OIDN_API -#endif -#endif - -// ---------------------------------------------------------------------------- -// Device -// ---------------------------------------------------------------------------- - -// Device types -typedef enum -{ - OIDN_DEVICE_TYPE_DEFAULT = 0, // select device automatically - - OIDN_DEVICE_TYPE_CPU = 1, // CPU device -} OIDNDeviceType; - -// Error codes -typedef enum -{ - OIDN_ERROR_NONE = 0, // no error occurred - OIDN_ERROR_UNKNOWN = 1, // an unknown error occurred - OIDN_ERROR_INVALID_ARGUMENT = 2, // an invalid argument was specified - OIDN_ERROR_INVALID_OPERATION = 3, // the operation is not allowed - OIDN_ERROR_OUT_OF_MEMORY = 4, // not enough memory to execute the operation - OIDN_ERROR_UNSUPPORTED_HARDWARE = 5, // the hardware (e.g. CPU) is not supported - OIDN_ERROR_CANCELLED = 6, // the operation was cancelled by the user -} OIDNError; - -// Error callback function -typedef void (*OIDNErrorFunction)(void* userPtr, OIDNError code, const char* message); - -// Device handle -typedef struct OIDNDeviceImpl* OIDNDevice; - -// Creates a new device. -OIDN_API OIDNDevice oidnNewDevice(OIDNDeviceType type); - -// Retains the device (increments the reference count). -OIDN_API void oidnRetainDevice(OIDNDevice device); - -// Releases the device (decrements the reference count). -OIDN_API void oidnReleaseDevice(OIDNDevice device); - -// Sets a boolean parameter of the device. -OIDN_API void oidnSetDevice1b(OIDNDevice device, const char* name, bool value); - -// Sets an integer parameter of the device. -OIDN_API void oidnSetDevice1i(OIDNDevice device, const char* name, int value); - -// Gets a boolean parameter of the device. -OIDN_API bool oidnGetDevice1b(OIDNDevice device, const char* name); - -// Gets an integer parameter of the device (e.g. "version"). -OIDN_API int oidnGetDevice1i(OIDNDevice device, const char* name); - -// Sets the error callback function of the device. -OIDN_API void oidnSetDeviceErrorFunction(OIDNDevice device, OIDNErrorFunction func, void* userPtr); - -// Returns the first unqueried error code stored in the device for the current -// thread, optionally also returning a string message (if not NULL), and clears -// the stored error. Can be called with a NULL device as well to check why a -// device creation failed. -OIDN_API OIDNError oidnGetDeviceError(OIDNDevice device, const char** outMessage); - -// Commits all previous changes to the device. -// Must be called before first using the device (e.g. creating filters). 
-OIDN_API void oidnCommitDevice(OIDNDevice device); - -// ---------------------------------------------------------------------------- -// Buffer -// ---------------------------------------------------------------------------- - -// Formats for images and other data stored in buffers -typedef enum -{ - OIDN_FORMAT_UNDEFINED = 0, - - // 32-bit single-precision floating point scalar and vector formats - OIDN_FORMAT_FLOAT = 1, - OIDN_FORMAT_FLOAT2 = 2, - OIDN_FORMAT_FLOAT3 = 3, - OIDN_FORMAT_FLOAT4 = 4, -} OIDNFormat; - -// Access modes for mapping buffers -typedef enum -{ - OIDN_ACCESS_READ = 0, // read-only access - OIDN_ACCESS_WRITE = 1, // write-only access - OIDN_ACCESS_READ_WRITE = 2, // read and write access - OIDN_ACCESS_WRITE_DISCARD = 3, // write-only access, previous contents discarded -} OIDNAccess; - -// Buffer handle -typedef struct OIDNBufferImpl* OIDNBuffer; - -// Creates a new buffer (data allocated and owned by the device). -OIDN_API OIDNBuffer oidnNewBuffer(OIDNDevice device, size_t byteSize); - -// Creates a new shared buffer (data allocated and owned by the user). -OIDN_API OIDNBuffer oidnNewSharedBuffer(OIDNDevice device, void* ptr, size_t byteSize); - -// Maps a region of the buffer to host memory. -// If byteSize is 0, the maximum available amount of memory will be mapped. -OIDN_API void* oidnMapBuffer(OIDNBuffer buffer, OIDNAccess access, size_t byteOffset, size_t byteSize); - -// Unmaps a region of the buffer. -// mappedPtr must be a pointer returned by a previous call to oidnMapBuffer. -OIDN_API void oidnUnmapBuffer(OIDNBuffer buffer, void* mappedPtr); - -// Retains the buffer (increments the reference count). -OIDN_API void oidnRetainBuffer(OIDNBuffer buffer); - -// Releases the buffer (decrements the reference count). -OIDN_API void oidnReleaseBuffer(OIDNBuffer buffer); - -// ---------------------------------------------------------------------------- -// Filter -// ---------------------------------------------------------------------------- - -// Progress monitor callback function -typedef bool (*OIDNProgressMonitorFunction)(void* userPtr, double n); - -// Filter handle -typedef struct OIDNFilterImpl* OIDNFilter; - -// Creates a new filter of the specified type (e.g. "RT"). -OIDN_API OIDNFilter oidnNewFilter(OIDNDevice device, const char* type); - -// Retains the filter (increments the reference count). -OIDN_API void oidnRetainFilter(OIDNFilter filter); - -// Releases the filter (decrements the reference count). -OIDN_API void oidnReleaseFilter(OIDNFilter filter); - -// Sets an image parameter of the filter (stored in a buffer). -// If bytePixelStride and/or byteRowStride are zero, these will be computed automatically. -OIDN_API void oidnSetFilterImage(OIDNFilter filter, const char* name, - OIDNBuffer buffer, OIDNFormat format, - size_t width, size_t height, - size_t byteOffset, - size_t bytePixelStride, size_t byteRowStride); - -// Sets an image parameter of the filter (owned by the user). -// If bytePixelStride and/or byteRowStride are zero, these will be computed automatically. -OIDN_API void oidnSetSharedFilterImage(OIDNFilter filter, const char* name, - void* ptr, OIDNFormat format, - size_t width, size_t height, - size_t byteOffset, - size_t bytePixelStride, size_t byteRowStride); - -// Sets a boolean parameter of the filter. -OIDN_API void oidnSetFilter1b(OIDNFilter filter, const char* name, bool value); - -// Gets a boolean parameter of the filter. 
-OIDN_API bool oidnGetFilter1b(OIDNFilter filter, const char* name); - -// Sets an integer parameter of the filter. -OIDN_API void oidnSetFilter1i(OIDNFilter filter, const char* name, int value); - -// Gets an integer parameter of the filter. -OIDN_API int oidnGetFilter1i(OIDNFilter filter, const char* name); - -// Sets a float parameter of the filter. -OIDN_API void oidnSetFilter1f(OIDNFilter filter, const char* name, float value); - -// Gets a float parameter of the filter. -OIDN_API float oidnGetFilter1f(OIDNFilter filter, const char* name); - -// Sets the progress monitor callback function of the filter. -OIDN_API void oidnSetFilterProgressMonitorFunction(OIDNFilter filter, OIDNProgressMonitorFunction func, void* userPtr); - -// Commits all previous changes to the filter. -// Must be called before first executing the filter. -OIDN_API void oidnCommitFilter(OIDNFilter filter); - -// Executes the filter. -OIDN_API void oidnExecuteFilter(OIDNFilter filter); - -#if defined(__cplusplus) -} -#endif diff --git a/thirdparty/oidn/include/OpenImageDenoise/oidn.hpp b/thirdparty/oidn/include/OpenImageDenoise/oidn.hpp deleted file mode 100644 index 9f95a56fe..000000000 --- a/thirdparty/oidn/include/OpenImageDenoise/oidn.hpp +++ /dev/null @@ -1,468 +0,0 @@ -// ======================================================================== // -// Copyright 2009-2019 Intel Corporation // -// // -// Licensed under the Apache License, Version 2.0 (the "License"); // -// you may not use this file except in compliance with the License. // -// You may obtain a copy of the License at // -// // -// http://www.apache.org/licenses/LICENSE-2.0 // -// // -// Unless required by applicable law or agreed to in writing, software // -// distributed under the License is distributed on an "AS IS" BASIS, // -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // -// See the License for the specific language governing permissions and // -// limitations under the License. 
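Taken together, the C API in oidn.h above is typically driven as in the following minimal sketch. Only the "RT" filter type is named in the header; the "color"/"output" image names and the "hdr" flag are conventional OIDN parameter names assumed here, not shown in this excerpt.

    #include <cstdio>
    #include <OpenImageDenoise/oidn.h>

    // Denoise an HDR float3 image from `color` into `output` using the C API.
    void denoise_c_api(float* color, float* output, size_t width, size_t height) {
        OIDNDevice device = oidnNewDevice(OIDN_DEVICE_TYPE_DEFAULT);
        oidnCommitDevice(device);

        OIDNFilter filter = oidnNewFilter(device, "RT");
        // Zero pixel/row strides ask the library to compute them automatically.
        oidnSetSharedFilterImage(filter, "color", color, OIDN_FORMAT_FLOAT3,
                                 width, height, 0, 0, 0);
        oidnSetSharedFilterImage(filter, "output", output, OIDN_FORMAT_FLOAT3,
                                 width, height, 0, 0, 0);
        oidnSetFilter1b(filter, "hdr", true); // assumed RT-filter flag, not part of this header
        oidnCommitFilter(filter);
        oidnExecuteFilter(filter);

        const char* message = nullptr;
        if (oidnGetDeviceError(device, &message) != OIDN_ERROR_NONE)
            std::printf("OIDN error: %s\n", message);

        oidnReleaseFilter(filter);
        oidnReleaseDevice(device);
    }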
// -// ======================================================================== // - -#pragma once - -#include -#include "oidn.h" - -namespace oidn { - - // -------------------------------------------------------------------------- - // Buffer - // -------------------------------------------------------------------------- - - // Formats for images and other data stored in buffers - enum class Format - { - Undefined = OIDN_FORMAT_UNDEFINED, - - // 32-bit single-precision floating point scalar and vector formats - Float = OIDN_FORMAT_FLOAT, - Float2 = OIDN_FORMAT_FLOAT2, - Float3 = OIDN_FORMAT_FLOAT3, - Float4 = OIDN_FORMAT_FLOAT4, - }; - - // Access modes for mapping buffers - enum class Access - { - Read = OIDN_ACCESS_READ, // read-only access - Write = OIDN_ACCESS_WRITE, // write-only access - ReadWrite = OIDN_ACCESS_READ_WRITE, // read and write access - WriteDiscard = OIDN_ACCESS_WRITE_DISCARD, // write-only access, previous contents discarded - }; - - // Buffer object with automatic reference counting - class BufferRef - { - private: - OIDNBuffer handle; - - public: - BufferRef() : handle(nullptr) {} - BufferRef(OIDNBuffer handle) : handle(handle) {} - - BufferRef(const BufferRef& other) : handle(other.handle) - { - if (handle) - oidnRetainBuffer(handle); - } - - BufferRef(BufferRef&& other) : handle(other.handle) - { - other.handle = nullptr; - } - - BufferRef& operator =(const BufferRef& other) - { - if (&other != this) - { - if (other.handle) - oidnRetainBuffer(other.handle); - if (handle) - oidnReleaseBuffer(handle); - handle = other.handle; - } - return *this; - } - - BufferRef& operator =(BufferRef&& other) - { - std::swap(handle, other.handle); - return *this; - } - - BufferRef& operator =(OIDNBuffer other) - { - if (other) - oidnRetainBuffer(other); - if (handle) - oidnReleaseBuffer(handle); - handle = other; - return *this; - } - - ~BufferRef() - { - if (handle) - oidnReleaseBuffer(handle); - } - - OIDNBuffer getHandle() const - { - return handle; - } - - operator bool() const - { - return handle != nullptr; - } - - // Maps a region of the buffer to host memory. - // If byteSize is 0, the maximum available amount of memory will be mapped. - void* map(Access access = Access::ReadWrite, size_t byteOffset = 0, size_t byteSize = 0) - { - return oidnMapBuffer(handle, (OIDNAccess)access, byteOffset, byteSize); - } - - // Unmaps a region of the buffer. - // mappedPtr must be a pointer returned by a previous call to map. 
- void unmap(void* mappedPtr) - { - oidnUnmapBuffer(handle, mappedPtr); - } - }; - - // -------------------------------------------------------------------------- - // Filter - // -------------------------------------------------------------------------- - - // Progress monitor callback function - typedef bool (*ProgressMonitorFunction)(void* userPtr, double n); - - // Filter object with automatic reference counting - class FilterRef - { - private: - OIDNFilter handle; - - public: - FilterRef() : handle(nullptr) {} - FilterRef(OIDNFilter handle) : handle(handle) {} - - FilterRef(const FilterRef& other) : handle(other.handle) - { - if (handle) - oidnRetainFilter(handle); - } - - FilterRef(FilterRef&& other) : handle(other.handle) - { - other.handle = nullptr; - } - - FilterRef& operator =(const FilterRef& other) - { - if (&other != this) - { - if (other.handle) - oidnRetainFilter(other.handle); - if (handle) - oidnReleaseFilter(handle); - handle = other.handle; - } - return *this; - } - - FilterRef& operator =(FilterRef&& other) - { - std::swap(handle, other.handle); - return *this; - } - - FilterRef& operator =(OIDNFilter other) - { - if (other) - oidnRetainFilter(other); - if (handle) - oidnReleaseFilter(handle); - handle = other; - return *this; - } - - ~FilterRef() - { - if (handle) - oidnReleaseFilter(handle); - } - - OIDNFilter getHandle() const - { - return handle; - } - - operator bool() const - { - return handle != nullptr; - } - - // Sets an image parameter of the filter (stored in a buffer). - void setImage(const char* name, - const BufferRef& buffer, Format format, - size_t width, size_t height, - size_t byteOffset = 0, - size_t bytePixelStride = 0, size_t byteRowStride = 0) - { - oidnSetFilterImage(handle, name, - buffer.getHandle(), (OIDNFormat)format, - width, height, - byteOffset, - bytePixelStride, byteRowStride); - } - - // Sets an image parameter of the filter (owned by the user). - void setImage(const char* name, - void* ptr, Format format, - size_t width, size_t height, - size_t byteOffset = 0, - size_t bytePixelStride = 0, size_t byteRowStride = 0) - { - oidnSetSharedFilterImage(handle, name, - ptr, (OIDNFormat)format, - width, height, - byteOffset, - bytePixelStride, byteRowStride); - } - - // Sets a boolean parameter of the filter. - void set(const char* name, bool value) - { - oidnSetFilter1b(handle, name, value); - } - - // Sets an integer parameter of the filter. - void set(const char* name, int value) - { - oidnSetFilter1i(handle, name, value); - } - - // Sets a float parameter of the filter. - void set(const char* name, float value) - { - oidnSetFilter1f(handle, name, value); - } - - // Gets a parameter of the filter. - template - T get(const char* name); - - // Sets the progress monitor callback function of the filter. - void setProgressMonitorFunction(ProgressMonitorFunction func, void* userPtr = nullptr) - { - oidnSetFilterProgressMonitorFunction(handle, (OIDNProgressMonitorFunction)func, userPtr); - } - - // Commits all previous changes to the filter. - void commit() - { - oidnCommitFilter(handle); - } - - // Executes the filter. - void execute() - { - oidnExecuteFilter(handle); - } - }; - - // Gets a boolean parameter of the filter. - template<> - inline bool FilterRef::get(const char* name) - { - return oidnGetFilter1b(handle, name); - } - - // Gets an integer parameter of the filter. - template<> - inline int FilterRef::get(const char* name) - { - return oidnGetFilter1i(handle, name); - } - - // Gets a float parameter of the filter. 
- template<> - inline float FilterRef::get(const char* name) - { - return oidnGetFilter1f(handle, name); - } - - // -------------------------------------------------------------------------- - // Device - // -------------------------------------------------------------------------- - - // Device types - enum class DeviceType - { - Default = OIDN_DEVICE_TYPE_DEFAULT, // select device automatically - - CPU = OIDN_DEVICE_TYPE_CPU, // CPU device - }; - - // Error codes - enum class Error - { - None = OIDN_ERROR_NONE, // no error occurred - Unknown = OIDN_ERROR_UNKNOWN, // an unknown error occurred - InvalidArgument = OIDN_ERROR_INVALID_ARGUMENT, // an invalid argument was specified - InvalidOperation = OIDN_ERROR_INVALID_OPERATION, // the operation is not allowed - OutOfMemory = OIDN_ERROR_OUT_OF_MEMORY, // not enough memory to execute the operation - UnsupportedHardware = OIDN_ERROR_UNSUPPORTED_HARDWARE, // the hardware (e.g. CPU) is not supported - Cancelled = OIDN_ERROR_CANCELLED, // the operation was cancelled by the user - }; - - // Error callback function - typedef void (*ErrorFunction)(void* userPtr, Error code, const char* message); - - // Device object with automatic reference counting - class DeviceRef - { - private: - OIDNDevice handle; - - public: - DeviceRef() : handle(nullptr) {} - DeviceRef(OIDNDevice handle) : handle(handle) {} - - DeviceRef(const DeviceRef& other) : handle(other.handle) - { - if (handle) - oidnRetainDevice(handle); - } - - DeviceRef(DeviceRef&& other) : handle(other.handle) - { - other.handle = nullptr; - } - - DeviceRef& operator =(const DeviceRef& other) - { - if (&other != this) - { - if (other.handle) - oidnRetainDevice(other.handle); - if (handle) - oidnReleaseDevice(handle); - handle = other.handle; - } - return *this; - } - - DeviceRef& operator =(DeviceRef&& other) - { - std::swap(handle, other.handle); - return *this; - } - - DeviceRef& operator =(OIDNDevice other) - { - if (other) - oidnRetainDevice(other); - if (handle) - oidnReleaseDevice(handle); - handle = other; - return *this; - } - - ~DeviceRef() - { - if (handle) - oidnReleaseDevice(handle); - } - - OIDNDevice getHandle() const - { - return handle; - } - - operator bool() const - { - return handle != nullptr; - } - - // Sets a boolean parameter of the device. - void set(const char* name, bool value) - { - oidnSetDevice1b(handle, name, value); - } - - // Sets an integer parameter of the device. - void set(const char* name, int value) - { - oidnSetDevice1i(handle, name, value); - } - - // Gets a parameter of the device. - template - T get(const char* name); - - // Sets the error callback function of the device. - void setErrorFunction(ErrorFunction func, void* userPtr = nullptr) - { - oidnSetDeviceErrorFunction(handle, (OIDNErrorFunction)func, userPtr); - } - - // Returns the first unqueried error code and clears the stored error. - // Can be called for a null device as well to check why a device creation failed. - Error getError() - { - return (Error)oidnGetDeviceError(handle, nullptr); - } - - // Returns the first unqueried error code and string message, and clears the stored error. - // Can be called for a null device as well to check why a device creation failed. - Error getError(const char*& outMessage) - { - return (Error)oidnGetDeviceError(handle, &outMessage); - } - - // Commits all previous changes to the device. - // Must be called before first using the device (e.g. creating filters). 
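As a rough sketch of how these reference-counted wrappers fit together: the DeviceRef::commit(), newBuffer() and newFilter() members used below are declared a little further down in this same header, and the "color"/"output" image names are again assumed rather than shown here.

    #include <OpenImageDenoise/oidn.hpp>

    void denoise_cpp_api(float* color, float* output, size_t width, size_t height) {
        oidn::DeviceRef device = oidn::newDevice(); // DeviceType::Default
        device.commit();

        oidn::FilterRef filter = device.newFilter("RT");
        filter.setImage("color", color, oidn::Format::Float3, width, height);
        filter.setImage("output", output, oidn::Format::Float3, width, height);
        filter.commit();
        filter.execute();

        const char* message = nullptr;
        if (device.getError(message) != oidn::Error::None) {
            // report `message`
        }
        // DeviceRef/FilterRef release their handles automatically on destruction.
    }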
- void commit() - { - oidnCommitDevice(handle); - } - - // Creates a new buffer (data allocated and owned by the device). - BufferRef newBuffer(size_t byteSize) - { - return oidnNewBuffer(handle, byteSize); - } - - // Creates a new shared buffer (data allocated and owned by the user). - BufferRef newBuffer(void* ptr, size_t byteSize) - { - return oidnNewSharedBuffer(handle, ptr, byteSize); - } - - // Creates a new filter of the specified type (e.g. "RT"). - FilterRef newFilter(const char* type) - { - return oidnNewFilter(handle, type); - } - }; - - // Gets a boolean parameter of the device. - template<> - inline bool DeviceRef::get(const char* name) - { - return oidnGetDevice1b(handle, name); - } - - // Gets an integer parameter of the device (e.g. "version"). - template<> - inline int DeviceRef::get(const char* name) - { - return oidnGetDevice1i(handle, name); - } - - // Creates a new device. - inline DeviceRef newDevice(DeviceType type = DeviceType::Default) - { - return DeviceRef(oidnNewDevice((OIDNDeviceType)type)); - } - -} // namespace oidn diff --git a/thirdparty/oidn/include/OpenImageDenoise/version.h b/thirdparty/oidn/include/OpenImageDenoise/version.h deleted file mode 100644 index 66b347c99..000000000 --- a/thirdparty/oidn/include/OpenImageDenoise/version.h +++ /dev/null @@ -1,23 +0,0 @@ -// ======================================================================== // -// Copyright 2009-2019 Intel Corporation // -// // -// Licensed under the Apache License, Version 2.0 (the "License"); // -// you may not use this file except in compliance with the License. // -// You may obtain a copy of the License at // -// // -// http://www.apache.org/licenses/LICENSE-2.0 // -// // -// Unless required by applicable law or agreed to in writing, software // -// distributed under the License is distributed on an "AS IS" BASIS, // -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // -// See the License for the specific language governing permissions and // -// limitations under the License. // -// ======================================================================== // - -#pragma once - -#define OIDN_VERSION_MAJOR 1 -#define OIDN_VERSION_MINOR 1 -#define OIDN_VERSION_PATCH 0 -#define OIDN_VERSION 10100 -#define OIDN_VERSION_STRING "1.1.0" diff --git a/thirdparty/oidn/mkl-dnn/LICENSE b/thirdparty/oidn/mkl-dnn/LICENSE deleted file mode 100644 index d13f7b7ca..000000000 --- a/thirdparty/oidn/mkl-dnn/LICENSE +++ /dev/null @@ -1,214 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. 
- - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. 
If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. 
Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "{}" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright {yyyy} {name of copyright owner} - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - - ============================================================================ - - Intel MKL-DNN includes components with separate copyright - notices and license terms. - - XByak, 3-clause BSD license - Copyright (c) 2007 MITSUNARI Shigeo - See full copyright notice and license text in src/cpu/xbyak/COPYRIGHT - - gtest, 3-clause BSD license - Copyright 2008, Google Inc. 
- See full copyright notice and license text in tests/gtests/gtest/LICENSE diff --git a/thirdparty/oidn/mkl-dnn/include/mkldnn.h b/thirdparty/oidn/mkl-dnn/include/mkldnn.h deleted file mode 100644 index 9b6499492..000000000 --- a/thirdparty/oidn/mkl-dnn/include/mkldnn.h +++ /dev/null @@ -1,1771 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef MKLDNN_H -#define MKLDNN_H - -#ifndef DOXYGEN_SHOULD_SKIP_THIS - -/* All symbols shall be internal unless marked as MKLDNN_API */ -#if defined _WIN32 || defined __CYGWIN__ -# define MKLDNN_HELPER_DLL_IMPORT __declspec(dllimport) -# define MKLDNN_HELPER_DLL_EXPORT __declspec(dllexport) -#else -# if __GNUC__ >= 4 -# define MKLDNN_HELPER_DLL_IMPORT __attribute__ ((visibility ("default"))) -# define MKLDNN_HELPER_DLL_EXPORT __attribute__ ((visibility ("default"))) -# else -# define MKLDNN_HELPER_DLL_IMPORT -# define MKLDNN_HELPER_DLL_EXPORT -# endif -#endif - -#ifdef MKLDNN_DLL -# ifdef MKLDNN_DLL_EXPORTS -# define MKLDNN_API MKLDNN_HELPER_DLL_EXPORT -# else -# define MKLDNN_API MKLDNN_HELPER_DLL_IMPORT -# endif -#else -# define MKLDNN_API -#endif - -#if defined (__GNUC__) -# define MKLDNN_DEPRECATED __attribute__((deprecated)) -#elif defined(_MSC_VER) -# define MKLDNN_DEPRECATED __declspec(deprecated) -#else -# define MKLDNN_DEPRECATED -#endif - -#include "mkldnn_types.h" -#include "mkldnn_version.h" -#endif /* DOXYGEN_SHOULD_SKIP_THIS */ - -#ifdef __cplusplus -extern "C" { -#endif - -/** @addtogroup c_api C API - * @{ */ - -/** @addtogroup c_api_primitive Primitive operations - * @{ */ - -/** @addtogroup c_api_primitive_common Common primitive operations - * @{ */ - -/** Creates a primitive descriptor @p iterator for given @p op_desc, @p attr, - * @p engine, and optionally a hint primitive descriptor from forward - * propagation (required for backward propagation). Pass @c NULL for forward - * propagation. - */ -mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_iterator_create( - mkldnn_primitive_desc_iterator_t *iterator, - const_mkldnn_op_desc_t op_desc, const_mkldnn_primitive_attr_t attr, - mkldnn_engine_t engine, - const_mkldnn_primitive_desc_t hint_forward_primitive_desc); - -/** Iterates over primitive descriptors. Returns #mkldnn_iterator_ends if no - * more primitive descriptors are available. */ -mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_iterator_next( - mkldnn_primitive_desc_iterator_t iterator); - -/** Fetches the current primitive descriptor. - * - * @note - * The user should delete the fetched primitive descriptor using - * mkldnn_primitive_desc_destroy() once it is no longer needed. 
*/ -mkldnn_primitive_desc_t MKLDNN_API mkldnn_primitive_desc_iterator_fetch( - const_mkldnn_primitive_desc_iterator_t iterator); - -/** Deletes a primitive descriptor @p iterator */ -mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_iterator_destroy( - mkldnn_primitive_desc_iterator_t iterator); - -/** Creates a @p primitive_desc using @p op_desc, @p attr, @p engine, and - * optionally a hint primitive descriptor from forward propagation. The call is - * equivalent to creating a primitive descriptor iterator, immediately fetching - * a primitive descriptor, and then destroying the iterator. */ -mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_create( - mkldnn_primitive_desc_t *primitive_desc, - const_mkldnn_op_desc_t op_desc, const_mkldnn_primitive_attr_t attr, - mkldnn_engine_t engine, - const_mkldnn_primitive_desc_t hint_forward_primitive_desc); - -/** Makes a copy of a @p primitive_desc. */ -mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_clone( - mkldnn_primitive_desc_t *primitive_desc, - const_mkldnn_primitive_desc_t existing_primitive_desc); - -/** Returns a constant reference to the attribute of a @p primitive_desc. - * - * @warning - * The user should not destroy the obtained @p attr. - * - * @warning - * The lifetime of an @p attr is the same as that of a @p primitive_desc, - * so it is illegal to use the @p attr once @p primitive_desc has been - * destroyed. */ -mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_get_attr( - const_mkldnn_primitive_desc_t primitive_desc, - const_mkldnn_primitive_attr_t *attr); - -/** Deletes a @p primitive_desc. */ -mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_destroy( - mkldnn_primitive_desc_t primitive_desc); - -/** Queries primitive descriptor - * - * One of the most typical use cases is to query a convolution primitive - * descriptor created with source, weights, and destination formats equal - * to #mkldnn_format_tag_any about the corresponding memory descriptors - * (@p what equals #mkldnn_query_src_md, #mkldnn_query_weights_md, and - * #mkldnn_query_dst_md respectively) to be able to prepare memory and - * create reorders if required. - * - * Another quite typical use case is to query an operation primitive - * descriptor for a workspace (@p what equals #mkldnn_query_workspace_md). - * The returned status #mkldnn_not_required indicates that a workspace is - * not required. - * - * A few other possibilities: - * - query an operation primitive descriptor for the underlying operation - * descriptor (#mkldnn_query_convolution_d, #mkldnn_query_eltwise_d, - * #mkldnn_query_rnn_d, etc.) - * - query an operation primitive descriptor for the implementation - * information string (#mkldnn_query_impl_info_str) - * - query an operation primitive descriptor for the number of inputs and - * outputs (#mkldnn_query_num_of_inputs_s32 and - * #mkldnn_query_num_of_outputs_s32 respectively) - * - * @sa mkldnn_query_t for more options - */ -mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_query( - const_mkldnn_primitive_desc_t primitive_desc, mkldnn_query_t what, - int index, void *result); - -/** Queries primitive descriptor for memory descriptor - * - * @returns NULL in case of any error. - * - * This is just a specialized version of mkldnn_primitive_desc_query - * used for convenience. 
- */ -const mkldnn_memory_desc_t MKLDNN_API *mkldnn_primitive_desc_query_md( - const_mkldnn_primitive_desc_t primitive_desc, mkldnn_query_t what, - int index); - -/** Queries primitive descriptor for signed 32bit int - * - * @returns 0 in case of any error (in particular if the queried entity is - * not of type int32_t). Note that 0 might also be the actual returned - * value. - * - * This is just a specialized version of mkldnn_primitive_desc_query - * used for convenience. - */ -int MKLDNN_API mkldnn_primitive_desc_query_s32( - const_mkldnn_primitive_desc_t primitive_desc, mkldnn_query_t what, - int index); - -/** Creates a @p primitive using a @p primitive_desc descriptor. */ -mkldnn_status_t MKLDNN_API mkldnn_primitive_create( - mkldnn_primitive_t *primitive, - const_mkldnn_primitive_desc_t primitive_desc); - -/** Executes a @p primitive using a @p stream, and @p nargs arguments - * @p args. */ -mkldnn_status_t MKLDNN_API mkldnn_primitive_execute( - const_mkldnn_primitive_t primitive, mkldnn_stream_t stream, - int nargs, const mkldnn_exec_arg_t *args); - -/** Retrieves a reference to the @p primitive_desc descriptor of given @p - * primitive. - * - * @warning - * The returned object must not be destroyed by the user. The @c const - * qualifier of the returned object prevents such attempts. */ -mkldnn_status_t MKLDNN_API mkldnn_primitive_get_primitive_desc( - const_mkldnn_primitive_t primitive, - const_mkldnn_primitive_desc_t *primitive_desc); - -/** Deletes a @p primitive. */ -mkldnn_status_t MKLDNN_API mkldnn_primitive_destroy( - mkldnn_primitive_t primitive); - -/** @} */ - -/** @addtogroup c_api_attributes Attributes - * An extension for controlling primitive behavior. - * @{ */ - -/** Creates an empty (default) @p attr attribute. All the parameters are set to - * default values. - * - * An empty attribute is used in primitive descriptor creation whenever it - * is not passed explicitly, e.g. in mkldnn_primitive_desc_create. - */ -mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_create( - mkldnn_primitive_attr_t *attr); - -/** Makes a copy of an @p existing_attr. */ -mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_clone( - mkldnn_primitive_attr_t *attr, - const_mkldnn_primitive_attr_t existing_attr); - -/** Deletes an @p attr. */ -mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_destroy( - mkldnn_primitive_attr_t attr); - -/** Returns the scratchpad @p mode set in the attribute @p attr */ -mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_get_scratchpad_mode( - const_mkldnn_primitive_attr_t attr, mkldnn_scratchpad_mode_t *mode); - -/** Sets scratchpad @p mode. - * - * The possible values are: #mkldnn_scratchpad_mode_library (default) and - * #mkldnn_scratchpad_mode_user. */ -mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_scratchpad_mode( - mkldnn_primitive_attr_t attr, mkldnn_scratchpad_mode_t mode); - -/** Returns @p count, correspondence scale @p mask, and a pointer to a constant - * floating point array of output @p scales for given @p attr, previously set - * by mkldnn_primitive_attr_set_output_scales. - * - * @warning - * The @p scales array points to the internal @p attr field, so the user - * should not modify or destroy @p scales. - * - * @warning - * The lifetime of @p scales is the same as that of the @p attr to which it - * belongs, so it is illegal to use @p scales after @p attr is destroyed. 
- */ -mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_get_output_scales( - const_mkldnn_primitive_attr_t attr, mkldnn_dim_t *count, int *mask, - const float **scales); - -/** Sets output @p scales for primitive operations. The number of elements @p - * count and correspondence scale @p mask are stored for future use. - * - * The @p mask argument defines the correspondence between the output tensor - * dimensions and the @p scales array. Set the i-th bit of @p mask to 1 to use a - * dedicated scaling factor for each slice of the output tensor over the i-th - * dimension. Set @p mask to 0 to use a common scaling factor for the whole - * output tensor. - * - * @note - * The dimension order is always native and does not depend on the actual - * layout used. Examples: - * - 2D dimensional data the order of dimensions is always: (n, c) - * - 4D dimensional data the order is always: (n, c, h, w) - * - 5D dimensional weights the order is always: (g, oc, ic, kh, kw) - * - * Example usage: - * @code - * int mb = 32, oc = 32, oh = 14, ow = 14; // convolution output params - * float scales[oc] = { ... }; // unique output scales per output channel - * int oc_dim = 1; // mb_dim = 0, channel_dim = 1, height_dim = 2, ... - * - * mkldnn_convolution_desc_t cd; // create & configure convolution op_desc - * - * mkldnn_primitive_attr_t attr; - * mkldnn_primitive_attr_create(&attr); // create default attributes - * mkldnn_primitive_attr_set_output_scales(attr, oc, 1 << oc_dim, scales); - * - * mkldnn_primitive_desc_t cpd; - * mkldnn_primitive_desc_create(&cpd, &cd, attr, NULL); - * @endcode - * - * @note - * There is no way to check that @p count corresponds to @p mask until an - * actual primitive descriptor is created, so it is the user's - * responsibility to set proper values. The following formula must hold: - * - * \f[count = \prod\limits_{d \in mask} output.dims[d]\f] - */ -mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_output_scales( - mkldnn_primitive_attr_t attr, mkldnn_dim_t count, int mask, - const float *scales); - -/** Returns @p post_ops for given @p attr. - * - * @warning - * @p post_ops points to the internal @p attr field, so the user should not - * modify or destroy @p post_ops. Also, the lifetime of @p post_ops is the - * same as that of the @p attr it belongs to, so it is illegal to use @p - * post_ops after @p attr has been destroyed. - */ -mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_get_post_ops( - const_mkldnn_primitive_attr_t attr, const_mkldnn_post_ops_t *post_ops); - -/** Sets configured @p post_ops to an attribute @p attr for future use (when - * primitive descriptor is being created). - * - * @note - * At this point in time, there is no way to check whether the primitive - * descriptor does or does not support a given sequence of post operations. - * Therefore the user should handle an error that might occur at the - * mkldnn_primitive_desc_create call. - */ -mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_post_ops( - mkldnn_primitive_attr_t attr, const_mkldnn_post_ops_t post_ops); - -/** @addtogroup c_api_attributes_post_ops Sequence of post operations - * An extension for performing extra operations after a base operation. - * @{ */ - -/** Creates an empty sequence of post operations @p post_ops. */ -mkldnn_status_t MKLDNN_API mkldnn_post_ops_create(mkldnn_post_ops_t *post_ops); - -/** Deletes a @p post_ops sequence. 
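A short sketch of attaching post operations to an attribute; the append/destroy helpers used here are declared just below, and mkldnn_eltwise_relu is an algorithm kind from mkldnn_types.h that is assumed rather than shown in this excerpt.

    #include "mkldnn.h"

    // Builds an attribute that appends a "sum" post op followed by a ReLU eltwise.
    static mkldnn_primitive_attr_t make_attr_with_post_ops() {
        mkldnn_primitive_attr_t attr;
        mkldnn_post_ops_t post_ops;
        mkldnn_primitive_attr_create(&attr);
        mkldnn_post_ops_create(&post_ops);

        mkldnn_post_ops_append_sum(post_ops, 1.0f); // dst[] <- 1.0*dst[] + op(...)
        mkldnn_post_ops_append_eltwise(post_ops, 1.0f, mkldnn_eltwise_relu, 0.0f, 0.0f);
        mkldnn_primitive_attr_set_post_ops(attr, post_ops);

        // Assumed: set_post_ops stores its own copy, so the sequence can be freed here.
        mkldnn_post_ops_destroy(post_ops);
        // Pass `attr` to mkldnn_primitive_desc_create(); unsupported post-op sequences
        // surface as an error there, as noted above. Destroy `attr` afterwards.
        return attr;
    }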
*/ -mkldnn_status_t MKLDNN_API mkldnn_post_ops_destroy(mkldnn_post_ops_t post_ops); - -/** Returns the @p length of post operations for given @p post_ops. */ -int MKLDNN_API mkldnn_post_ops_len(const_mkldnn_post_ops_t post_ops); - -/** Returns the type of post operation with index @p index in given - * @p post_ops. In case of error, returns #mkldnn_undefined_primitive. */ -mkldnn_primitive_kind_t MKLDNN_API mkldnn_post_ops_get_kind( - const_mkldnn_post_ops_t post_ops, int index); - -/** Appends accumulation (sum) post operation to the @p post_ops. Prior to - * accumulating the result, the previous value would be multiplied by @p scale. - * - * The kind of this post operation is #mkldnn_sum. - * - * This feature might improve performance for cases like residual learning - * blocks, where the result of convolution is accumulated to the previously - * computed activations. The parameter @p scale might be extreme for the - * integer-based computations when the result and previous activations have - * different logical scaling factors. - * - * In the simplest case when the accumulation is the only post operation, the - * computations would be: - * dst[] <- scale * dst[] + op(...) // instead of dst[] <- op(...) - * - * @note - * This post operation (as well as all the others) disregards the original - * layout of the destination; that is, the layout of the original - * destination is expected to be the same as the layout of the stored - * destination. - */ -mkldnn_status_t MKLDNN_API mkldnn_post_ops_append_sum( - mkldnn_post_ops_t post_ops, float scale); - -/** Gets the parameters of the accumulation (sum) post operation with index - * @p index in the sequence of @p post_ops. - * - * @note - * If index @p index would not correspond to the accumulation post - * operation, the function returns #mkldnn_invalid_arguments. - */ -mkldnn_status_t MKLDNN_API mkldnn_post_ops_get_params_sum( - const_mkldnn_post_ops_t post_ops, int index, float *scale); - -/** Appends eltwise post operation to the @p post_ops with given parameters - * @p kind, @p alpha, and @p beta (@sa mkldnn_eltwise_forward_desc_init and - * mkldnn_eltwise_desc_t). - * - * The kind of this post operation is #mkldnn_eltwise. - * - * In the simplest case when the eltwise is the only post operation, the - * computations would be: - * dst[] <- scale * eltwise_op ( op(...) ) // instead of dst[] <- op(...) - * where eltwise_op is configured with the given parameters. - */ -mkldnn_status_t MKLDNN_API mkldnn_post_ops_append_eltwise( - mkldnn_post_ops_t post_ops, float scale, mkldnn_alg_kind_t alg, - float alpha, float beta); - -/** Gets the eltwise parameters of the post operation with index @p index in - * the sequence of @p post_ops. - */ -mkldnn_status_t MKLDNN_API mkldnn_post_ops_get_params_eltwise( - const_mkldnn_post_ops_t post_ops, int index, float *scale, - mkldnn_alg_kind_t *alg, float *alpha, float *beta); - -/** @} */ - -/** @} */ - -/** @addtogroup c_api_memory Memory - * A primitive to describe and store data. - * - * The library supports various data types and formats. Memory hierarchy - * consists of three levels of abstraction: - * 1. **Memory descriptor** -- engine agnostic logical description of data - * (number of dimensions, dimensions themselves, and data type), and - * optionally the format/layout that describes the physical representation - * of data in memory. If the format is not known yet, one can pass - * #mkldnn_format_tag_any. 
This approach is used to allow compute-intensive
- * primitives to specify the most appropriate format on their own with
- * users required to reorder the data if the incoming format doesn't match
- * the primitive's selection. Memory descriptor can be initialized with
- * mkldnn_memory_desc_init_by_tag() or mkldnn_memory_desc_init_by_strides()
- * functions, or by directly filling the mkldnn_memory_desc_t structure.
- * The latter requires deep knowledge of how the physical data
- * representation is mapped to the structure.
- * The @ref understanding_memory_formats topic should shed some light on
- * that.
- * For the fully defined memory descriptors (i.e. where the format kind is
- * not equal to #mkldnn_format_kind_any) a user can get the size, using the
- * mkldnn_memory_desc_get_size() function. As described in
- * @ref understanding_memory_formats, the size of data sometimes cannot
- * be computed as the product of dimensions times the size of the data
- * type. So users are encouraged to use this function for better code
- * portability.
- * Two memory descriptors can be compared with mkldnn_memory_desc_equal().
- * The comparison is especially useful when checking whether a primitive
- * requires reorder from the user's data format to the primitive's format.
- * 2. **Memory** -- an engine-specific object that handles the data and its
- * description (a memory descriptor). For the CPU engine, the data handle is
- * simply a pointer to @c void. The data handle can be queried using
- * mkldnn_memory_get_data_handle() and set using
- * mkldnn_memory_set_data_handle(). The latter function always sets the
- * memory in the padding region to zero, which is the invariant maintained
- * by all the primitives in Intel MKL-DNN.
- * See @ref understanding_memory_formats for more details.
- * A memory can be created using the mkldnn_memory_create() function.
- * A memory can also be queried for the underlying memory descriptor and
- * engine using mkldnn_memory_get_memory_desc() and
- * mkldnn_memory_get_engine() functions.
- *
- * Along with ordinary memory with all dimensions being positive, Intel
- * MKL-DNN supports *zero-volume* memory with one or more dimensions set to
- * zero. This is to support the NumPy\* convention.
- * If a *zero-volume* memory is passed to a primitive, the primitive does
- * not perform any computations on this memory. For example:
- * - Convolution with `(0 batch, 3 input channels, 13 height, 13 width)`
- * source and `(16 output channels, 3 input channels, 3 height, 3 width)`
- * weights would produce `(0 batch, 16 output channels, 11 height, 11 width)`
- * destination (assuming strides are `1` and paddings are zero) and perform
- * zero multiply-add operations.
- * - Concatenation of three memories of shapes `(3, 4, 13, 13)`,
- * `(3, 0, 13, 13)`, and `(3, 1, 13, 13)` along the second axis would produce
- * the output of the shape `(3, 5, 13, 13)`, effectively ignoring the second
- * input (however, if the user created a concatenation primitive descriptor
- * with three inputs they should also provide all three memories to the
- * concatenation primitive, including the one with zero second dimension).
- * - However, Intel MKL-DNN would return an error when attempting to create a
- * convolution with *zero-volume* memory passed for weights because such a
- * convolution is not well-defined:
- * ~~~
- * dst(1, 16, 11, 11) <-- src(1, 0, 13, 13) (*) wei(16, 0, 3, 3)
- * ~~~
- * Should the values in the destination be zeroes or just not accessed at
- * all?
Moreover, backward pass w.r.t. weights in such cases is also not - * well-defined. - * - * Data handle of *zero-volume* memory is never accessed and hence can be - * unset (NULL in case of CPU engine). - * - * @sa @ref understanding_memory_formats - * @{ */ - -/** Initializes a @p memory_desc memory descriptor using @p ndims, @p dims, @p - * data_type, and @p strides. - * - * The @p strides might be NULL, which means the order of physical dimensions - * is the same as the order of logical ones. - * - * @note The logical order of dimensions is defined by a primitive that - * consumes the memory. - */ -mkldnn_status_t MKLDNN_API mkldnn_memory_desc_init_by_strides( - mkldnn_memory_desc_t *memory_desc, int ndims, const mkldnn_dims_t dims, - mkldnn_data_type_t data_type, const mkldnn_dims_t strides); - -/** Initializes a @p memory_desc memory descriptor using @p ndims, @p dims, @p - * data_type, and format @p tag. - * - * @p tag can be #mkldnn_format_tag_any, which allows a primitive to define - * the appropriate memory format. In this case, the @p format_kind would be set - * to #mkldnn_format_kind_any */ -mkldnn_status_t MKLDNN_API mkldnn_memory_desc_init_by_tag( - mkldnn_memory_desc_t *memory_desc, int ndims, const mkldnn_dims_t dims, - mkldnn_data_type_t data_type, mkldnn_format_tag_t tag); - -/** Initializes a @p memory_desc for a given @p parent_memory_desc, with - * @p dims sizes and @p offsets. May fail if layout used does not allow - * obtain desired submemory. In this case consider using `extract` or `insert` - * primitive */ -mkldnn_status_t MKLDNN_API mkldnn_memory_desc_init_submemory( - mkldnn_memory_desc_t *memory_desc, - const mkldnn_memory_desc_t *parent_memory_desc, - const mkldnn_dims_t dims, const mkldnn_dims_t offsets); - -/** Compares two memory descriptors. - * @return 1 if the descriptors are the same. - * @return 0 if the descriptors are different. - * - * Use this function to identify whether a reorder is required between the - * two memories */ -int MKLDNN_API mkldnn_memory_desc_equal( - const mkldnn_memory_desc_t *lhs, - const mkldnn_memory_desc_t *rhs); - -/** Returns the size (in bytes) that is required for given @p memory_desc */ -size_t MKLDNN_API mkldnn_memory_desc_get_size( - const mkldnn_memory_desc_t *memory_desc); - -/** Creates a memory for given @p memory_desc and @p engine. Also sets handle - * to @p native_handle. - * The @p native_handle can: - * - point to the user allocated memory, i.e. valid handle. In this case the - * library doesn't own allocated memory. - * - be MKLDNN_NATIVE_HANDLE_ALLOCATE to ask the library to allocate and - * attach memory. In this case the library owns allocated memory. - * - be MKLDNN_NATIVE_HANDLE_NONE to create mkldnn_memory w/o attached memory. - */ -mkldnn_status_t MKLDNN_API mkldnn_memory_create(mkldnn_memory_t *memory, - const mkldnn_memory_desc_t *memory_desc, mkldnn_engine_t engine, - void *native_handle); - -/** Returns a @p memory_desc associated with @p memory. */ -mkldnn_status_t MKLDNN_API mkldnn_memory_get_memory_desc( - const_mkldnn_memory_t memory, - const mkldnn_memory_desc_t **memory_desc); - -/** Returns an @p engine associated with @p memory. */ -mkldnn_status_t MKLDNN_API mkldnn_memory_get_engine( - const_mkldnn_memory_t memory, mkldnn_engine_t *engine); - -/** For a @p memory, returns the data @p handle. - * - * For the CPU engine, the data handle is a pointer to the actual data. 
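A minimal sketch of the memory workflow described above. MKLDNN_NATIVE_HANDLE_ALLOCATE is the library-owned allocation mode mentioned in the text; mkldnn_f32, mkldnn_nchw and the engine come from other mkl-dnn headers and are assumptions here.

    #include "mkldnn.h"

    // Create a 32x16x13x13 f32 tensor whose storage is allocated by the library.
    static mkldnn_memory_t make_nchw_memory(mkldnn_engine_t engine) {
        mkldnn_dims_t dims = {32, 16, 13, 13}; // (N, C, H, W)
        mkldnn_memory_desc_t md;
        mkldnn_memory_desc_init_by_tag(&md, 4, dims, mkldnn_f32, mkldnn_nchw);

        // Padded layouts may need more bytes than N*C*H*W*sizeof(float),
        // hence the dedicated size query rather than multiplying dimensions.
        size_t bytes = mkldnn_memory_desc_get_size(&md);
        (void)bytes;

        mkldnn_memory_t mem;
        mkldnn_memory_create(&mem, &md, engine, MKLDNN_NATIVE_HANDLE_ALLOCATE);
        return mem; // release later with mkldnn_memory_destroy()
    }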
*/ -mkldnn_status_t MKLDNN_API mkldnn_memory_get_data_handle( - const_mkldnn_memory_t memory, void **handle); - -/** For a @p memory, sets the data @p handle. */ -mkldnn_status_t MKLDNN_API mkldnn_memory_set_data_handle( - mkldnn_memory_t memory, void *handle); - -/** Deletes a @p memory. */ -mkldnn_status_t MKLDNN_API mkldnn_memory_destroy(mkldnn_memory_t memory); - -/** @} */ - -/** @addtogroup c_api_reorder Reorder - * A primitive to copy data between memory formats. - * @{ */ - -/** Initializes a @p reorder_primitive_desc using the description of the source - * (@p src_engine and @p src_md) and destination (@p dst_engine and @p dst_md) - * memory, and an @p attr attribute. - * - * Inputs: - * - input (#mkldnn_query_src_md, 0) - * - * Outputs: - * - output (#mkldnn_query_dst_md, 0) - */ -mkldnn_status_t MKLDNN_API mkldnn_reorder_primitive_desc_create( - mkldnn_primitive_desc_t *reorder_primitive_desc, - mkldnn_engine_t src_engine, const mkldnn_memory_desc_t *src_md, - mkldnn_engine_t dst_engine, const mkldnn_memory_desc_t *dst_md, - const_mkldnn_primitive_attr_t attr); - -/** @} */ - -/** @addtogroup c_api_concat Concat - * A primitive to concatenate data by arbitrary dimension. - * @{ */ - -/** Creates out-of-place @p concat_primitive_desc for concatenation of @p n - * inputs by @p concat_dimension with resulting @p output_desc memory - * descriptor. @p output_desc can be NULL or specified with the - * #mkldnn_format_kind_any format kind -- in this case, the appropriate memory - * format would be chosen automatically. - * - * Inputs: - * - input 0 (#mkldnn_query_src_md, 0) - * - input 1 (#mkldnn_query_src_md, 1) - * - ... - * - input @p n - 1 (#mkldnn_query_src_md, @p n - 1) - * - * Outputs: - * - output (#mkldnn_query_dst_md, 0) - */ -mkldnn_status_t MKLDNN_API mkldnn_concat_primitive_desc_create( - mkldnn_primitive_desc_t *concat_primitive_desc, - const mkldnn_memory_desc_t *dst_md, - int n, int concat_dimension, - const mkldnn_memory_desc_t *src_mds, - const_mkldnn_primitive_attr_t attr, - mkldnn_engine_t engine); - -/** @} */ - -/** @addtogroup c_api_sum Sum - * A primitive to sum data. - * @{ */ - -/** Creates out-of-place @p sum_primitive_desc for sum of @p n - * inputs multiplied by scale with resulting @p output_desc memory - * descriptor. @p output_desc can be NULL or specified with the - * #mkldnn_format_kind_any format kind -- in this case, the appropriate memory - * format would be chosen automatically. - * - * Inputs: - * - src 0 (#mkldnn_query_src_md, 0) - * - src 1 (#mkldnn_query_src_md, 1) - * - ... - * - src @p n - 1 (#mkldnn_query_src_md, @p n - 1) - * - * Outputs: - * - output (#mkldnn_query_dst_md, 0) - */ -mkldnn_status_t MKLDNN_API mkldnn_sum_primitive_desc_create( - mkldnn_primitive_desc_t *sum_primitive_desc, - const mkldnn_memory_desc_t *dst_mds, - int n, const float *scales, - const mkldnn_memory_desc_t *src_mds, - const_mkldnn_primitive_attr_t attr, - mkldnn_engine_t engine); - -/** @} */ - -/** @addtogroup c_api_convolution Convolution - * A primitive to compute convolution using different algorithms. 
- * - * \f[dst[n][oc][oh][ow] = - * \sum_{kw=0}^{KW}\sum_{kh=0}^{KH}\sum_{ic=0}^{IC} - * src[n][ic][oh \cdot s_h - p_l[0] + kh][ow \cdot s_w - p_r[1] + kw] - * \cdot weights[g][oc][ic][kh][kw] - * + bias[g][oc],\f] - * - * where size of output spatial domain is given by - * \f$ OH = \left\lfloor{\frac{IH - KH + p_l[0] + p_r[0]}{s_h}} - * \right\rfloor + 1\f$, - * \f$ OW = \left\lfloor{\frac{IW - KW + p_l[1] + p_r[1]}{s_w}} - * \right\rfloor + 1\f$, - * - * and summation is carried over input channels \f$ic\f$ in - * group \f$g\f$, and \f$s_h, s_w\f$ are @p strides and - * \f$p_l, p_r\f$ are @p padding_l and @p padding_r. - * @{ */ - -/** Initializes a convolution descriptor @p conv_desc for forward propagation - * using @p prop_kind (possible values are #mkldnn_forward_training and - * #mkldnn_forward_inference), @p alg_kind, memory descriptors, @p strides, @p - * padding_l, @p padding_r, and @p padding_kind. In order to create a - * convolution without bias, @p bias_desc should either be @c NULL or point to - * a descriptor with memory format kind equal to #mkldnn_format_kind_undef. - * - * @note If @p padding_r is @c NULL, the padding is supposed to be symmetric. - * - * @note Memory descriptors are allowed to be initialized with - * #mkldnn_format_kind_any value of @p format_kind. - * - * Inputs: - * - src (#mkldnn_query_src_md, 0) - * - weights (#mkldnn_query_weights_md, 0) - * - bias (#mkldnn_query_weights_md, 1), if created with bias - * - * Outputs: - * - dst (#mkldnn_query_dst_md, 0) - */ -mkldnn_status_t MKLDNN_API mkldnn_convolution_forward_desc_init( - mkldnn_convolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind, - mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, - const mkldnn_memory_desc_t *weights_desc, - const mkldnn_memory_desc_t *bias_desc, - const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, - const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, - mkldnn_padding_kind_t padding_kind); - -/** Initializes a dilated convolution descriptor @p conv_desc for forward - * propagation using @p prop_kind (possible values are #mkldnn_forward_training - * and #mkldnn_forward_inference), @p alg_kind, memory descriptors, @p strides, - * @p dilates, @p padding_l, @p padding_r, and @p padding_kind. - * In order to create a dilated convolution without bias, @p bias_desc - * should either be @c NULL or point to a descriptor with memory format kind - * equals #mkldnn_format_kind_undef. - * - * @note If @p padding_r is @c NULL, the padding is supposed to be symmetric. - * - * @note Memory descriptors are allowed to be initialized with - * #mkldnn_format_kind_any value of @p format_kind. 
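For example, letting the convolution pick its own layouts (as recommended in the query section above) might look roughly like the sketch below; mkldnn_forward_inference, mkldnn_convolution_direct and mkldnn_padding_zero are enum values from mkldnn_types.h and are assumed here rather than shown in this excerpt.

    #include "mkldnn.h"

    // src_md/weights_md/dst_md are expected to have been initialized with
    // mkldnn_format_tag_any so the implementation can choose preferred formats.
    static mkldnn_status_t make_conv_pd(mkldnn_engine_t engine,
                                        const mkldnn_memory_desc_t* src_md,
                                        const mkldnn_memory_desc_t* weights_md,
                                        const mkldnn_memory_desc_t* dst_md,
                                        mkldnn_primitive_desc_t* conv_pd) {
        mkldnn_dims_t strides   = {1, 1};
        mkldnn_dims_t padding_l = {0, 0};
        mkldnn_dims_t padding_r = {0, 0};

        mkldnn_convolution_desc_t conv_desc;
        mkldnn_convolution_forward_desc_init(&conv_desc, mkldnn_forward_inference,
                mkldnn_convolution_direct, src_md, weights_md,
                NULL /* no bias */, dst_md,
                strides, padding_l, padding_r, mkldnn_padding_zero);

        // NULL attr selects the default (empty) attribute; NULL hint because this is
        // forward propagation. Query *conv_pd afterwards for the chosen layouts.
        return mkldnn_primitive_desc_create(conv_pd, &conv_desc, NULL, engine, NULL);
    }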
- * - * Inputs: - * - src (#mkldnn_query_src_md, 0) - * - weights (#mkldnn_query_weights_md, 0) - * - bias (#mkldnn_query_weights_md, 1), if created with bias - * - * Outputs: - * - dst (#mkldnn_query_dst_md, 0) - */ -mkldnn_status_t MKLDNN_API mkldnn_dilated_convolution_forward_desc_init( - mkldnn_convolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind, - mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, - const mkldnn_memory_desc_t *weights_desc, - const mkldnn_memory_desc_t *bias_desc, - const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, - const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, - const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind); - -/** Initializes a convolution descriptor @p conv_desc for backward propagation - * with respect to data using @p alg_kind, memory descriptors, @p strides, @p - * padding_l, @p padding_r, and @p padding_kind. - * - * @note Memory descriptors are allowed to be initialized with - * #mkldnn_format_kind_any value of @p format_kind. - * - * Inputs: - * - diff_dst (#mkldnn_query_diff_dst_md, 0) - * - weights (#mkldnn_query_weights_md, 0) - * - * Outputs: - * - diff_src (#mkldnn_query_diff_src_md, 0) - */ -mkldnn_status_t MKLDNN_API mkldnn_convolution_backward_data_desc_init( - mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, - const mkldnn_memory_desc_t *diff_src_desc, - const mkldnn_memory_desc_t *weights_desc, - const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, - const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, - mkldnn_padding_kind_t padding_kind); - -/** Initializes a dilated convolution descriptor @p conv_desc for backward - * propagation with respect to data using @p alg_kind, memory descriptors, @p - * strides, @p dilates @p padding_l, @p padding_r, and @p padding_kind. - * - * @note Memory descriptors are allowed to be initialized with - * #mkldnn_format_kind_any value of @p format_kind. - * - * Inputs: - * - diff_dst (#mkldnn_query_diff_dst_md, 0) - * - weights (#mkldnn_query_weights_md, 0) - * - * Outputs: - * - diff_src (#mkldnn_query_diff_src_md, 0) - */ -mkldnn_status_t MKLDNN_API mkldnn_dilated_convolution_backward_data_desc_init( - mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, - const mkldnn_memory_desc_t *diff_src_desc, - const mkldnn_memory_desc_t *weights_desc, - const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, - const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, - const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind); - -/** Initializes a convolution descriptor @p conv_desc for backward propagation - * with respect to weights using @p alg_kind, memory descriptors, @p strides, - * @p padding_l, @p padding_r, and @p padding_kind. - * - * @note Memory descriptors are allowed to be initialized with - * #mkldnn_format_kind_any value of @p format_kind. 
- * - * Inputs: - * - src (#mkldnn_query_src_md, 0) - * - diff_dst (#mkldnn_query_diff_dst_md, 0) - * - * Outputs: - * - diff_weights (#mkldnn_query_diff_weights_md, 0) - * - diff_bias (#mkldnn_query_diff_weights_md, 1), if created with bias - */ -mkldnn_status_t MKLDNN_API mkldnn_convolution_backward_weights_desc_init( - mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, - const mkldnn_memory_desc_t *src_desc, - const mkldnn_memory_desc_t *diff_weights_desc, - const mkldnn_memory_desc_t *diff_bias_desc, - const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, - const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, - mkldnn_padding_kind_t padding_kind); - -/** Initializes a convolution descriptor @p conv_desc for backward propagation - * with respect to weights using @p alg_kind, memory descriptors, @p strides, - * @p dilates @p padding_l, @p padding_r, and @p padding_kind. - * - * @note Memory descriptors are allowed to be initialized with - * #mkldnn_format_kind_any value of @p format_kind. - * - * Inputs: - * - src (#mkldnn_query_src_md, 0) - * - diff_dst (#mkldnn_query_diff_dst_md, 0) - * - * Outputs: - * - diff_weights (#mkldnn_query_diff_weights_md, 0) - * - diff_bias (#mkldnn_query_diff_weights_md, 1), if created with bias - */ -mkldnn_status_t MKLDNN_API -mkldnn_dilated_convolution_backward_weights_desc_init( - mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, - const mkldnn_memory_desc_t *src_desc, - const mkldnn_memory_desc_t *diff_weights_desc, - const mkldnn_memory_desc_t *diff_bias_desc, - const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, - const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, - const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind); - -/** @} */ - -/** @addtogroup c_api_deconvolution Deconvolution - * A primitive to compute deconvolution using different algorithms. - * - * @{ */ - - -/** Initializes a deconvolution descriptor @p deconv_desc for forward - * propagation using @p prop_kind (possible values are #mkldnn_forward_training - * and #mkldnn_forward_inference), @p alg_kind, memory descriptors, @p strides, - * @p padding_l, @p padding_r, and @p padding_kind. In order to create a - * deconvolution without bias, @p bias_desc should either be @c NULL or point to - * a descriptor with memory format kind equals #mkldnn_format_kind_undef. - * - * @note If @p padding_r is @c NULL, the padding is supposed to be symmetric. - * - * @note Memory descriptors are allowed to be initialized with - * #mkldnn_format_kind_any value of @p format_kind. 
- * - * Inputs: - * - src (#mkldnn_query_src_md, 0) - * - weights (#mkldnn_query_weights_md, 0) - * - bias (#mkldnn_query_weights_md, 1), if created with bias - * - * Outputs: - * - dst (#mkldnn_query_dst_md, 0) - */ -mkldnn_status_t MKLDNN_API mkldnn_deconvolution_forward_desc_init( - mkldnn_deconvolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind, - mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, - const mkldnn_memory_desc_t *weights_desc, - const mkldnn_memory_desc_t *bias_desc, - const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, - const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, - mkldnn_padding_kind_t padding_kind); - -/** Initializes a dilated deconvolution descriptor @p deconv_desc for forward - * propagation using @p prop_kind (possible values are #mkldnn_forward_training - * and #mkldnn_forward_inference), @p alg_kind, memory descriptors, @p strides, - * @p dilates, @p padding_l, @p padding_r, and @p padding_kind. In order to - * create a dilated deconvolution without bias, @p bias_desc should either be - * @c NULL or point to a descriptor with memory format kind equal - * #mkldnn_format_kind_undef. - * - * @note If @p padding_r is @c NULL, the padding is supposed to be symmetric. - * - * @note Memory descriptors are allowed to be initialized with - * #mkldnn_format_kind_any value of @p format_kind. - * - * Inputs: - * - src (#mkldnn_query_src_md, 0) - * - weights (#mkldnn_query_weights_md, 0) - * - bias (#mkldnn_query_weights_md, 1), if created with bias - * - * Outputs: - * - dst (#mkldnn_query_dst_md, 0) - */ -mkldnn_status_t MKLDNN_API mkldnn_dilated_deconvolution_forward_desc_init( - mkldnn_deconvolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind, - mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, - const mkldnn_memory_desc_t *weights_desc, - const mkldnn_memory_desc_t *bias_desc, - const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, - const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, - const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind); - -/** Initializes a deconvolution descriptor @p conv_desc for backward propagation - * with respect to data using @p alg_kind, memory descriptors, @p strides, @p - * padding_l, @p padding_r, and @p padding_kind. - * - * @note Memory descriptors are allowed to be initialized with - * #mkldnn_format_kind_any value of @p format_kind. - * - * Inputs: - * - diff_dst (#mkldnn_query_diff_dst_md, 0) - * - weights (#mkldnn_query_weights_md, 0) - * - * Outputs: - * - diff_src (#mkldnn_query_diff_src_md, 0) - */ -mkldnn_status_t MKLDNN_API mkldnn_deconvolution_backward_data_desc_init( - mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, - const mkldnn_memory_desc_t *diff_src_desc, - const mkldnn_memory_desc_t *weights_desc, - const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, - const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, - mkldnn_padding_kind_t padding_kind); - -/** Initializes a dilated deconvolution descriptor @p conv_desc for backward - * propagation with respect to data using @p alg_kind, memory descriptors, @p - * strides, @p dilates, @p padding_l, @p padding_r, and @p padding_kind. - * - * @note Memory descriptors are allowed to be initialized with - * #mkldnn_format_kind_any value of @p format_kind. 
- * - * Inputs: - * - diff_dst (#mkldnn_query_diff_dst_md, 0) - * - weights (#mkldnn_query_weights_md, 0) - * - * Outputs: - * - diff_src (#mkldnn_query_diff_src_md, 0) - */ -mkldnn_status_t MKLDNN_API mkldnn_dilated_deconvolution_backward_data_desc_init( - mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, - const mkldnn_memory_desc_t *diff_src_desc, - const mkldnn_memory_desc_t *weights_desc, - const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, - const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, - const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind); - -/** Initializes a deconvolution descriptor @p conv_desc for backward propagation - * with respect to weights using @p alg_kind, memory descriptors, @p strides, - * @p padding_l, @p padding_r, and @p padding_kind. - * - * @note Memory descriptors are allowed to be initialized with - * #mkldnn_format_kind_any value of @p format_kind. - * - * Inputs: - * - src (#mkldnn_query_src_md, 0) - * - diff_dst (#mkldnn_query_diff_dst_md, 0) - * - * Outputs: - * - diff_weights (#mkldnn_query_diff_weights_md, 0) - * - diff_bias (#mkldnn_query_diff_weights_md, 1), if created with bias - */ -mkldnn_status_t MKLDNN_API mkldnn_deconvolution_backward_weights_desc_init( - mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, - const mkldnn_memory_desc_t *src_desc, - const mkldnn_memory_desc_t *diff_weights_desc, - const mkldnn_memory_desc_t *diff_bias_desc, - const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, - const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, - mkldnn_padding_kind_t padding_kind); - -/** Initializes a dilated deconvolution descriptor @p conv_desc for backward - * propagation with respect to weights using @p alg_kind, memory descriptors, - * @p strides, @p dilates, @p padding_l, @p padding_r, and @p padding_kind. - * - * @note Memory descriptors are allowed to be initialized with - * #mkldnn_format_kind_any value of @p format_kind. - * - * Inputs: - * - src (#mkldnn_query_src_md, 0) - * - diff_dst (#mkldnn_query_diff_dst_md, 0) - * - * Outputs: - * - diff_weights (#mkldnn_query_diff_weights_md, 0) - * - diff_bias (#mkldnn_query_diff_weights_md, 1), if created with bias - */ -mkldnn_status_t MKLDNN_API mkldnn_dilated_deconvolution_backward_weights_desc_init( - mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, - const mkldnn_memory_desc_t *src_desc, - const mkldnn_memory_desc_t *diff_weights_desc, - const mkldnn_memory_desc_t *diff_bias_desc, - const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, - const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, - const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind); - -/** @} */ - -/** @addtogroup c_api_shuffle Shuffle - * A primitive to shuffle data along the axis. - * @{ */ - -/** Initializes a @p shuffle_desc for forward propagation using @p prop_kind, - * memory descriptor @p data_desc, @p axis, and @p group_size. - * - * Inputs: - * - src (#mkldnn_query_src_md, 0) - * - * Outputs: - * - dst (#mkldnn_query_dst_md, 0) - * - */ -mkldnn_status_t MKLDNN_API mkldnn_shuffle_forward_desc_init( - mkldnn_shuffle_desc_t *shuffle_desc, mkldnn_prop_kind_t prop_kind, - const mkldnn_memory_desc_t *data_desc, int axis, - mkldnn_dim_t group_size); - -/** Initializes a @p shuffle_desc for backward propagation using memory - * descriptor @p diff_data_desc, @p axis, and @p group_size. 
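A minimal sketch of the shuffle forward descriptor above. data_md is assumed to be a valid memory descriptor created elsewhere; shuffling axis 1 (the channel axis) with a group size of 4 is the usual channel-shuffle pattern, and the group size is expected to divide the size of the shuffled axis.

mkldnn_shuffle_desc_t sd;
mkldnn_status_t status = mkldnn_shuffle_forward_desc_init(&sd,
        mkldnn_forward_inference, data_md, /*axis=*/1, /*group_size=*/4);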
- * - * - * Inputs: - * - diff_dst (#mkldnn_query_diff_dst_md, 0) - * - * Outputs: - * - diff_src (#mkldnn_query_diff_src_md, 0) - * - */ -mkldnn_status_t MKLDNN_API mkldnn_shuffle_backward_desc_init( - mkldnn_shuffle_desc_t *shuffle_desc, - const mkldnn_memory_desc_t *diff_data_desc, int axis, - mkldnn_dim_t group_size); - -/** @} */ - -/** @addtogroup c_api_eltwise Eltwise - * A primitive to compute element-wise operations like parametric rectifier - * linear unit (ReLU). - * - * Both forward and backward passes support in-place operation; that is, src - * and dst point to the same memory for forward pass, and diff_dst and diff_src - * point to the same memory for backward pass. - * - * @warning Because the original src is required for backward pass, in-place - * forward pass in general cannot be applied during training. However, for some - * kinds of element-wise operations (namely ReLU with alpha parameter equals 0), - * dst and src can be interchangeable for the backward pass, which enables - * performing in-place forward even for training. - * - * @{ */ - -/** Initializes an @p eltwise_desc for forward propagation using @p prop_kind - * (possible values are #mkldnn_forward_training and #mkldnn_forward_inference), - * @p alg_kind algorithm, memory descriptor @p data_desc, @p alpha, and - * @p beta parameters. - * - * @sa mkldnn_eltwise_desc_t for details. - * - * Inputs: - * - src (#mkldnn_query_src_md, 0) - * - * Outputs: - * - dst (#mkldnn_query_dst_md, 0) - */ -mkldnn_status_t MKLDNN_API mkldnn_eltwise_forward_desc_init( - mkldnn_eltwise_desc_t *eltwise_desc, mkldnn_prop_kind_t prop_kind, - mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *data_desc, - float alpha, float beta); - -/** Initializes an @p eltwise_desc for backward propagation using @p alg_kind - * algorithm memory descriptors @p diff_data_desc and @p data_desc, and the - * @p alpha and @p beta parameters. - * - * @sa mkldnn_eltwise_desc_t for details. - * - * Inputs: - * - src (#mkldnn_query_src_md, 0) - * - diff_dst (#mkldnn_query_diff_dst_md, 0) - * - * Outputs: - * - diff_src (#mkldnn_query_diff_src_md, 0) - */ -mkldnn_status_t MKLDNN_API mkldnn_eltwise_backward_desc_init( - mkldnn_eltwise_desc_t *eltwise_desc, mkldnn_alg_kind_t alg_kind, - const mkldnn_memory_desc_t *diff_data_desc, - const mkldnn_memory_desc_t *data_desc, float alpha, float beta); - -/** @} */ - -/** @addtogroup c_api_softmax Softmax - * A primitive to perform softmax. - * - * \f[dst[u][c][in] = - * \frac{\exp(src[ou][c][in]) - \max\limits_{c}(src[ou][c][in])} - * {\sum\limits_{c}\{\exp(src[ou][c][in]) - * - \max\limits_{c}(src[ou][c][in])\}},\f] - * - * where \f$ou, iu\f$ are outer and inner sizes repectively, defined - * by @p data_desc.dims and @p softmax_axis. - * @{ */ - -/** Initializes a @p softmax_desc for forward propagation using @p prop_kind - * (possible values are #mkldnn_forward_training and #mkldnn_forward_inference) - * and memory descriptor @p data_desc. - * - * Inputs: - * - src (#mkldnn_query_src_md, 0) - * - * Outputs: - * - dst (#mkldnn_query_dst_md, 0) - */ -mkldnn_status_t MKLDNN_API mkldnn_softmax_forward_desc_init( - mkldnn_softmax_desc_t *softmax_desc, mkldnn_prop_kind_t prop_kind, - const mkldnn_memory_desc_t *data_desc, int softmax_axis); - -/** Initializes a @p softmax_desc for backward propagation using memory - * descriptors @p diff_desc and @p data_desc. 
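To make the in-place note above concrete, here is a sketch of a ReLU eltwise forward descriptor; data_md is assumed to exist, and alpha = 0 is exactly the case in which dst can stand in for src on the backward pass.

mkldnn_eltwise_desc_t ed;
mkldnn_status_t status = mkldnn_eltwise_forward_desc_init(&ed,
        mkldnn_forward_training, mkldnn_eltwise_relu, data_md,
        /*alpha=*/0.0f, /*beta=*/0.0f);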
- * - * Inputs: - * - dst (#mkldnn_query_dst_md, 0) - * - diff_dst (#mkldnn_query_diff_dst_md, 0) - * - * Outputs: - * - diff_src (#mkldnn_query_diff_src_md, 0) - */ -mkldnn_status_t MKLDNN_API mkldnn_softmax_backward_desc_init( - mkldnn_softmax_desc_t *softmax_desc, - const mkldnn_memory_desc_t *diff_desc, - const mkldnn_memory_desc_t *data_desc, int softmax_axis); - -/** @} */ - -/** @addtogroup c_api_pooling Pooling - * A primitive to perform max or average pooling. - * - * Max pooling: - * \f[dst[n][oc][oh][ow] = - * \max\limits_{kw,kh} - * (src[n][ic][oh \cdot s_h - p_l[0] + kh][ow \cdot s_w - p_r[1] + kw]),\f] - * - * Average pooling: - * \f[dst[n][oc][oh][ow] = - * \frac{1}{KW \cdot KH}\sum\limits_{kw,kh} - * src[n][ic][oh \cdot s_h - p_l[0] + kh][ow \cdot s_w - p_r[1] + kw],\f] - * - * where \f$p_l, p_r\f$ are @p padding_l and @p padding_r respectively, and - * output spatial dimensions are calculated similarly to how they are done in - * convolution. - * - * During training, max pooling requires a workspace on forward - * (#mkldnn_forward_training) and backward (#mkldnn_backward) passes to - * save indices where maximum was found. The workspace layout is opaque, and - * the indices cannot be restored from it. However, one can use backward - * pooling to perform up-sampling (used in some detection topologies). - * - * @{ */ - -/** Initializes a pooling descriptor @p pool_desc for forward propagation using - * @p prop_kind (possible values are #mkldnn_forward_training and - * #mkldnn_forward_inference), @p alg_kind, memory descriptors, and pooling - * parameters in the spatial domain: @p strides, @p kernel sizes, @p padding_l, - * @p padding_r, and @p padding_kind. - * - * @note If @p padding_r is @c NULL, the padding is supposed to be symmetric. - * - * Inputs: - * - src (#mkldnn_query_src_md, 0) - * - * Outputs: - * - dst (#mkldnn_query_dst_md, 0) - * - workspace (#mkldnn_query_workspace_md, 0), - * if @p alg_kind = #mkldnn_pooling_max and - * @p prop_kind = #mkldnn_forward_training - */ -mkldnn_status_t MKLDNN_API mkldnn_pooling_forward_desc_init( - mkldnn_pooling_desc_t *pool_desc, mkldnn_prop_kind_t prop_kind, - mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, - const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, - const mkldnn_dims_t kernel, const mkldnn_dims_t padding_l, - const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind); - -/** Initializes a pooling descriptor @p pool_desc for backward propagation - * using @p alg_kind, memory descriptors, and pooling parameters in the spatial - * domain: @p strides, @p kernel sizes, @p padding_l, @p padding_r, and @p - * padding_kind. - * - * @note If @p padding_r is @c NULL, the padding is supposed to be symmetric. - * - * Inputs: - * - diff_dst (#mkldnn_query_diff_dst_md, 0) - * - workspace (#mkldnn_query_workspace_md, 0), - * if @p alg_kind = #mkldnn_pooling_max - * - * Outputs: - * - diff_src (#mkldnn_query_diff_src_md, 0) - */ -mkldnn_status_t MKLDNN_API mkldnn_pooling_backward_desc_init( - mkldnn_pooling_desc_t *pool_desc, mkldnn_alg_kind_t alg_kind, - const mkldnn_memory_desc_t *diff_src_desc, - const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, - const mkldnn_dims_t kernel, const mkldnn_dims_t padding_l, - const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind); - -/** @} */ - -/** @addtogroup c_api_lrn LRN - * A primitive to perform local response normalization (LRN) across or within - * channels. 
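Before the LRN details below, a sketch of the max-pooling forward descriptor documented above. src_md and dst_md are assumed to be consistent with the illustrative 2x2 window and stride of 2 chosen here.

mkldnn_pooling_desc_t pd;
mkldnn_dims_t strides = {2, 2};
mkldnn_dims_t kernel  = {2, 2};
mkldnn_dims_t padding = {0, 0};
mkldnn_status_t status = mkldnn_pooling_forward_desc_init(&pd,
        mkldnn_forward_training, mkldnn_pooling_max,
        src_md, dst_md, strides, kernel, padding, padding,
        mkldnn_padding_zero);
/* With pooling_max and forward_training, the resulting primitive descriptor
 * also exposes a workspace (mkldnn_query_workspace_md) for the backward pass. */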
- * - * LRN accross channels: - * \f[dst[n][c][h][w] = \left\{k + \frac{\alpha}{n_{l}} - * \sum\limits_{i=-(n_{l}-1)/2}^{(n_{l}+1)/2} - * (src[n][c+i][h][w])^2\right\}^{-\beta} - * src[n][c][h][w],\f] - * - * LRN within channels: - * \f[dst[n][c][h][w] = \left\{k + \frac{\alpha}{n_{l}} - * \sum\limits_{i=-(n_{l}-1)/2}^{(n_{l}+1)/2} - * (src[n][c][h+i][w+i])^2\right\}^{-\beta} - * src[n][c][h][w],\f] - * - * where \f$n_{l}\f$ is the @p local_size. - * - * During training, LRN might or might not require a workspace on forward - * (#mkldnn_forward_training) and backward (#mkldnn_backward) passes. The - * behavior is implementation specific. Optimized implementations typically - * require a workspace and use it to save some intermediate results from the - * forward pass that accelerate computations on the backward pass. - * - * To check whether a workspace is required, query the LRN primitive descriptor - * for the workspace (#mkldnn_query_workspace_md). Success indicates that the - * workspace is required and its description will be returned. - * @sa mkldnn_primitive_desc_query and mkldnn_primitive_desc_query_pd - * - * @{ */ - -/** Initializes an @p lrn_desc for forward propagation using @p prop_kind - * (possible values are #mkldnn_forward_training and #mkldnn_forward_inference), - * @p alg_kind, memory descriptor @p data_desc, and regularization - * parameters @p local_size, @p alpha, @p beta, and @p k. - * - * Inputs: - * - src (#mkldnn_query_src_md, 0) - * - * Outputs: - * - dst (#mkldnn_query_dst_md, 0) - * - workspace (#mkldnn_query_workspace_md, 0), - * if the underlying implementation requires - */ -mkldnn_status_t MKLDNN_API mkldnn_lrn_forward_desc_init( - mkldnn_lrn_desc_t *lrn_desc, mkldnn_prop_kind_t prop_kind, - mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *data_desc, - mkldnn_dim_t local_size, float alpha, float beta, float k); - -/** Initializes an @p lrn_desc for backward propagation using @p alg_kind, - * memory descriptors @p data_desc and @p diff_data_desc, and regularization - * parameters @p local_size, @p alpha, @p beta, and @p k. - * - * Inputs: - * - src (#mkldnn_query_src_md, 0) - * - diff_dst (#mkldnn_query_diff_dst_md, 0) - * - workspace (#mkldnn_query_workspace_md, 0), - * if the underlying implementation requires - * - * Outputs: - * - diff_src (#mkldnn_query_diff_src_md, 0) - */ -mkldnn_status_t MKLDNN_API mkldnn_lrn_backward_desc_init( - mkldnn_lrn_desc_t *lrn_desc, mkldnn_alg_kind_t alg_kind, - const mkldnn_memory_desc_t *diff_data_desc, - const mkldnn_memory_desc_t *data_desc, mkldnn_dim_t local_size, - float alpha, float beta, float k); - -/** @} */ - -/** @addtogroup c_api_batch_normalization Batch Normalization - * A primitive to perform batch normalization. - * - * \f[dst[n][c][h][w] = \gamma[c] \frac{src[n][c][h][w] - \mu[c]} - * {\sqrt{\sigma[c] + eps}} + \beta[c],\f] - * - * where \f$\gamma[c], \beta[c]\f$ are weights and bias for a channel and, - * - * \f$\mu[c] = \frac{1}{NHW} \sum\limits_{whn} src[n][c][h][w]\f$, - * \f$\sigma[c] = \frac{1}{NHW} \sum\limits_{whn} - * (src[n][c][h][w] - \mu[c])^2\f$, - * - * and @c eps is a constant to improve numerical stability. - * - * Both forward and backward passes support in-place operation; that is, src - * and dst point to the same memory for forward pass, and diff_dst and diff_src - * point to the same memory for backward pass. - * - * Batch normalization supports different flavors controlled by - * mkldnn_batch_normalization_desc_t. 
For example, batch normalization can - * compute the mean and variance on its own or take them as inputs. It can - * either perform scaling and shifting using gamma and beta parameters or not. - * Optionally it can also perform a fused ReLU, which in case of training would - * also require a workspace. - * - * @sa mkldnn_batch_normalization_desc_t - * @{ */ - -/** Initializes a batch normalization descriptor @p bnrm_desc for forward - * propagation using @p prop_kind (possible values are - * #mkldnn_forward_training and #mkldnn_forward_inference), memory descriptor - * @p data_desc, normalization parameter @p epsilon, and @p flags set using bit - * flags of type mkldnn_batch_normalization_desc_t. - * - * Inputs: - * - src (#mkldnn_query_src_md, 0) - * - mean (#mkldnn_query_src_md, 1), - * if #mkldnn_use_global_stats bit-flags is set in @p flags - * - variance (#mkldnn_query_src_md, 2), - * if #mkldnn_use_global_stats bit-flags is set in @p flags - * - scale_and_shift (#mkldnn_query_weights_md, 0), - * if #mkldnn_use_scaleshift bit-flags is set in @p flags - * - * Outputs: - * - dst (#mkldnn_query_dst_md, 0) - * - mean (#mkldnn_query_dst_md, 1), - * if #mkldnn_use_global_stats bit-flags is not set in @p flags - * @p prop_kind = #mkldnn_forward_training - * - variance (#mkldnn_query_dst_md, 2), - * if #mkldnn_use_global_stats bit-flags is not set in @p flags - * and @p prop_kind = #mkldnn_forward_training - * - workspace (#mkldnn_query_workspace_md, 0), - * if #mkldnn_fuse_bn_relu bit-flags is set in @p flags - * and @p prop_kind = #mkldnn_forward_training - * - * @note In-place operation is supported; that is, dst points to the same memory - * as src. - * - * @sa mkldnn_batch_normalization_desc_t - */ -mkldnn_status_t MKLDNN_API mkldnn_batch_normalization_forward_desc_init( - mkldnn_batch_normalization_desc_t *bnrm_desc, - mkldnn_prop_kind_t prop_kind, const mkldnn_memory_desc_t *data_desc, - float epsilon, unsigned flags); - -/** Initializes a batch normalization descriptor @p bnrm_desc for backward - * propagation with respect to data and scale-shift parameters using memory - * descriptors @p data_desc and @p diff_data_desc, normalization parameter - * @p epsilon, and @p flags set using bit flags of type - * mkldnn_batch_normalization_desc_t. - * - * Inputs: - * - src (#mkldnn_query_src_md, 0) - * - mean (#mkldnn_query_src_md, 1) - * - variance (#mkldnn_query_src_md, 2) - * - diff_dst (#mkldnn_query_diff_dst_md, 0) - * - scale_and_shift (#mkldnn_query_weights_md, 0), - * if #mkldnn_use_scaleshift bit-flags is set in @p flags - * - workspace (#mkldnn_query_workspace_md, 0), - * if #mkldnn_fuse_bn_relu bit-flags is set in @p flags - * - * Outputs: - * - diff_src (#mkldnn_query_diff_src_md, 0) - * - diff_scale_and_shift (#mkldnn_query_diff_weights_md, 0), - * if #mkldnn_use_scaleshift bit-flags is set in @p flags - * and @p prop_kind = #mkldnn_backward - * - * @note in-place operation is supported, - * i.e. diff_src points to the same memory as diff_dst. - * - * @sa mkldnn_batch_normalization_desc_t - */ -mkldnn_status_t MKLDNN_API mkldnn_batch_normalization_backward_desc_init( - mkldnn_batch_normalization_desc_t *bnrm_desc, - mkldnn_prop_kind_t prop_kind, - const mkldnn_memory_desc_t *diff_data_desc, - const mkldnn_memory_desc_t *data_desc, - float epsilon, unsigned flags); - -/** @} */ - -/** @addtogroup c_api_inner_product Inner product - * A primitive to compute an inner product. - * - * Inner product layer is also known as fully connected layer. 
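A sketch of the inference flavor of the batch normalization descriptor above, using global statistics plus scale/shift; data_md is assumed valid and the epsilon value is illustrative.

mkldnn_batch_normalization_desc_t bd;
unsigned flags = mkldnn_use_global_stats | mkldnn_use_scaleshift;
mkldnn_status_t status = mkldnn_batch_normalization_forward_desc_init(&bd,
        mkldnn_forward_inference, data_md, /*epsilon=*/1e-5f, flags);
/* With mkldnn_use_global_stats set, mean and variance are consumed as extra
 * inputs (mkldnn_query_src_md, indices 1 and 2) instead of being produced. */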
- * With spatial dimension: - * - * \f[dst[n][oc] = \sum\limits_{ic, kh, kw} - * src[n][ic][kh][kw] \cdot weights[oc][ic][kh][kw] - * + bias[oc]\f] - * @{ */ - -/** Initializes an inner product descriptor @p ip_desc for forward propagation - * using @p prop_kind (possible values are #mkldnn_forward_training and - * #mkldnn_forward_inference) and memory descriptors. In order to create an - * inner product without bias, @p bias_desc should be either @c NULL or a - * pointer to a descriptor with memory format kind equals - * #mkldnn_format_kind_undef. - * - * @note Memory descriptors are allowed to be initialized with - * #mkldnn_format_kind_any value of @p format_kind. - * - * Inputs: - * - src (#mkldnn_query_src_md, 0) - * - weights (#mkldnn_query_weights_md, 0) - * - bias (#mkldnn_query_weights_md, 1), if created with bias - * - * Outputs: - * - dst (#mkldnn_query_dst_md, 0) - */ -mkldnn_status_t MKLDNN_API mkldnn_inner_product_forward_desc_init( - mkldnn_inner_product_desc_t *ip_desc, mkldnn_prop_kind_t prop_kind, - const mkldnn_memory_desc_t *src_desc, - const mkldnn_memory_desc_t *weights_desc, - const mkldnn_memory_desc_t *bias_desc, - const mkldnn_memory_desc_t *dst_desc); - -/** Initializes an inner product descriptor @p ip_desc for backward propagation - * with respect to data using memory descriptors. - * - * @note Memory descriptors are allowed to be initialized with - * #mkldnn_format_kind_any value of @p format_kind. - * - * Inputs: - * - diff_dst (#mkldnn_query_diff_dst_md, 0) - * - weights (#mkldnn_query_weights_md, 0) - * - * Outputs: - * - diff_src (#mkldnn_query_diff_src_md, 0) - */ -mkldnn_status_t MKLDNN_API mkldnn_inner_product_backward_data_desc_init( - mkldnn_inner_product_desc_t *ip_desc, - const mkldnn_memory_desc_t *diff_src_desc, - const mkldnn_memory_desc_t *weights_desc, - const mkldnn_memory_desc_t *diff_dst_desc); - -/** Initializes an inner product descriptor @p ip_desc for backward propagation - * with respect to weights using memory descriptors. - * - * @note Memory descriptors are allowed to be initialized with - * #mkldnn_format_kind_any value of @p format_kind. - * - * Inputs: - * - src (#mkldnn_query_src_md, 0) - * - diff_dst (#mkldnn_query_diff_dst_md, 0) - * - * Outputs: - * - diff_weights (#mkldnn_query_diff_weights_md, 0) - * - diff_bias (#mkldnn_query_diff_weights_md, 1), if created with bias - */ -mkldnn_status_t MKLDNN_API mkldnn_inner_product_backward_weights_desc_init( - mkldnn_inner_product_desc_t *ip_desc, - const mkldnn_memory_desc_t *src_desc, - const mkldnn_memory_desc_t *diff_weights_desc, - const mkldnn_memory_desc_t *diff_bias_desc, - const mkldnn_memory_desc_t *diff_dst_desc); - -/** @} */ - -/** @addtogroup c_api_rnn RNN - * A primitive to compute the common recurrent layer. - * @todo add additional description for the group - * @{ */ - -/** - * Initializes a recurrent cell descriptor @p rnn_cell_desc - * using @p rnn_cell_desc, @p kind (possible values are - * #mkldnn_vanilla_rnn, #mkldnn_vanilla_lstm, #mkldnn_vanilla_gru, and - * #mkldnn_gru_linear_before_reset), - * @p f (possible values are #mkldnn_eltwise_relu and - * #mkldnn_eltwise_tanh), @p flags, @p alpha, and @p clipping. - */ -mkldnn_status_t MKLDNN_API mkldnn_rnn_cell_desc_init( - mkldnn_rnn_cell_desc_t *rnn_cell_desc, - mkldnn_alg_kind_t kind, mkldnn_alg_kind_t f, - unsigned int flags, float alpha, float clipping); - -/** Returns the number of gates of a particular @p rnn_cell_desc. 
*/ -int MKLDNN_API mkldnn_rnn_cell_get_gates_count( - const mkldnn_rnn_cell_desc_t *rnn_cell_desc); - -/** Returns the number of states of a particular @p rnn_cell_desc. */ -int MKLDNN_API mkldnn_rnn_cell_get_states_count( - const mkldnn_rnn_cell_desc_t *rnn_cell_desc); - -/** Sets quantization @p scale and @p shift for RNN data tensors. - * For performance reasons, low precision configuration of RNN primitive - * expects input activations to have unsigned int8 data type. Scale and shift - * used to quantize floating point data to unsigned integer must be passed to - * RNN primitive using attributes. - * Example usage: - * @code - * // rnn parameters - * int l = 2, t = 2, mb = 32, sic = 32, slc = 32, dic = 32, dlc = 32; - * // activations quantization parameters - * float scale = ..., shift = ..; - * - * mkldnn_primitive_attr_t rnn_attr; - * // create default attributes - * mkldnn_primitive_attr_create(&rnn_attr); - * - * // set scale and shift for int8 quantization of activation - * mkldnn_primitive_attr_set_rnn_data_qparams(rnn_attr, scale, shift); - * - * // create & configure rnn op_desc - * mkldnn_rnn_desc_t rnn_d; - * mkldnn_primitive_desc_t rnn_pd; - * mkldnn_primitive_desc_create(&rnn_pd, &rnn_d, attr, engine, NULL); - * @endcode - * @note - * Quantization scale and shift are common for src_layer, src_iter, - * dst_iter and dst_layer. - */ -mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_rnn_data_qparams( - mkldnn_primitive_attr_t attr, const float scale, const float shift); - -/** Sets quantization scales @p weights_scales for RNN weights tensors. - * Low precision configuration of RNN primitive expects input weights to have - * signed int8 data type. Scales used to quantize floating point data - * to signed integer must be passed to RNN primitive using attributes. - * The @p mask argument defines correspondence between output tensor dimensions - * and the @p weights_scales array. Set i-th bit of @p mask to 1 to use - * dedicated scaling factor for each slice of the output tensor over i-th - * dimension. Set @p mask to 0 to use common scaling factor for the whole output - * tensor. Example usage: - * @code - * // rnn parameters - * int l = 2, t = 2, mb = 32, sic = 32, slc = 32, dic = 32, dlc = 32; - * // unique output scales per output channel - * float weights_scales[dic * n_gates] = { ... }; - * // mask that specifies last two dimensions of ldigo format - * int mask = 0x3; - * - * mkldnn_primitive_attr_t attr; - * // create default attributes - * mkldnn_primitive_attr_create(&attr); - * - * // set output channel-wise weights scales - * mkldnn_primitive_attr_set_rnn_weights_qparams(attr, dic * n_gates, mask, - * weights_scales); - * - * // create & configure rnn op_desc - * mkldnn_rnn_desc_t rnn_d; - * mkldnn_primitive_desc_t rnn_pd; - * mkldnn_primitive_desc_create(&rnn_pd, &rnn_d, attr, engine, NULL); - * @endcode - * @note - * The dimension order is always native and does not depend on the actual - * layout used. For example, 5 dimensional weights always have - * (l, d, i, g, o) logical dimension ordering. - * @note - * Quantization sales are common for weights_layer and weights_iteration - * @note - * There is no way to check that @p count corresponds to @p mask until an - * actual primitive descriptor is created, so it is user's responsibility - * to set proper values. 
The following formula must be held: - * - * \f[count = \prod\limits_{d \in mask} output.dims[d]\f] - */ -mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_rnn_weights_qparams ( - mkldnn_primitive_attr_t attr, mkldnn_dim_t count, int mask, - const float *weights_scales); - -/** Initializes a rnn descriptor @p rnn_desc for forward propagation - * using @p prop_kind, @p rnn_cell_desc, @p direction, and memory descriptors. - * @note If @p prop_kind equals #mkldnn_forward_training, you must query a - * workspace memory descriptor before creating the primitive. - * - * @p src_iter_desc, @p bias_desc, and @p dst_iter_desc are allowed to either be - * @c NULL or point to a zero memory descriptor, which would indicate that the - * RNN primitive should not use them. - * - * @note All memory descriptors except @p src_iter_desc are allowed to be - * initialized with #mkldnn_format_kind_any value of @p format_kind. - * - * Inputs: - * - src_layer (#mkldnn_query_src_md, 0) - * - src_iter (#mkldnn_query_src_md, 1), if used - * - weights_layer (#mkldnn_query_weights_md, 0) - * - weights_iter (#mkldnn_query_weights_md, 1) - * - bias (#mkldnn_query_weights_md, 2), if used - * - * Outputs: - * - dst_layer (#mkldnn_query_dst_md, 0) - * - dst_iter (#mkldnn_query_dst_md, 1), if used - * - workspace (#mkldnn_query_workspace_md, 0), - * if @p prop_kind equals #mkldnn_forward_training - */ -mkldnn_status_t MKLDNN_API mkldnn_rnn_forward_desc_init( - mkldnn_rnn_desc_t *rnn_desc, mkldnn_prop_kind_t prop_kind, - const mkldnn_rnn_cell_desc_t *rnn_cell_desc, - const mkldnn_rnn_direction_t direction, - const mkldnn_memory_desc_t *src_layer_desc, - const mkldnn_memory_desc_t *src_iter_desc, - const mkldnn_memory_desc_t *weights_layer_desc, - const mkldnn_memory_desc_t *weights_iter_desc, - const mkldnn_memory_desc_t *bias_desc, - const mkldnn_memory_desc_t *dst_layer_desc, - const mkldnn_memory_desc_t *dst_iter_desc); - -/** Initializes a rnn descriptor @p rnn_desc for backward propagation - * using @p prop_kind, @p rnn_cell_desc, @p direction, and memory descriptors. - * - * @note All memory descriptors are allowed to be initialized with - * #mkldnn_format_kind_any value of @p format_kind. - * - * @p src_iter_desc (simultaneously with @p diff_src_iter_desc), - * @p bias_desc (simultaneously with @p diff_bias_desc), and - * @p dst_iter_desc (simultaneously with @p diff_src_iter_desc) are allowed to - * either be @c NULL or point to a zero memory descriptor, which would indicate - * that the RNN primitive should not use them. 
- * - * Inputs: - * - src_layer (#mkldnn_query_src_md, 0) - * - src_iter (#mkldnn_query_src_md, 1), if used - * - weights_layer (#mkldnn_query_weights_md, 0) - * - weights_iter (#mkldnn_query_weights_md, 1) - * - bias (#mkldnn_query_weights_md, 2), if used - * - dst_layer (#mkldnn_query_dst_md, 0) - * - dst_iter (#mkldnn_query_dst_md, 1), if used - * - diff_dst_layer (#mkldnn_query_diff_dst_md, 0) - * - diff_dst_iter (#mkldnn_query_diff_dst_md, 1), if used - * - workspace (#mkldnn_query_workspace_md, 0) - * - * Outputs: - * - diff_src_layer (#mkldnn_query_diff_src_md, 0) - * - diff_src_iter (#mkldnn_query_diff_src_md, 1), if used - * - diff_weights_layer (#mkldnn_query_diff_weights_md, 0) - * - diff_weights_iter (#mkldnn_query_diff_weights_md, 1) - * - diff_bias (#mkldnn_query_diff_weights_md, 2), if used - */ -mkldnn_status_t MKLDNN_API mkldnn_rnn_backward_desc_init( - mkldnn_rnn_desc_t *rnn_desc, mkldnn_prop_kind_t prop_kind, - const mkldnn_rnn_cell_desc_t *rnn_cell_desc, - const mkldnn_rnn_direction_t direction, - const mkldnn_memory_desc_t *src_layer_desc, - const mkldnn_memory_desc_t *src_iter_desc, - const mkldnn_memory_desc_t *weights_layer_desc, - const mkldnn_memory_desc_t *weights_iter_desc, - const mkldnn_memory_desc_t *bias_desc, - const mkldnn_memory_desc_t *dst_layer_desc, - const mkldnn_memory_desc_t *dst_iter_desc, - const mkldnn_memory_desc_t *diff_src_layer_desc, - const mkldnn_memory_desc_t *diff_src_iter_desc, - const mkldnn_memory_desc_t *diff_weights_layer_desc, - const mkldnn_memory_desc_t *diff_weights_iter_desc, - const mkldnn_memory_desc_t *diff_bias_desc, - const mkldnn_memory_desc_t *diff_dst_layer, - const mkldnn_memory_desc_t *diff_dst_iter_desc); - -/** @} */ - -/** @} */ - -/** @addtogroup c_api_engine Engine operations - * @{ */ - -/** Returns the number of engines of a particular @p kind. */ -size_t MKLDNN_API mkldnn_engine_get_count(mkldnn_engine_kind_t kind); - -/** Creates an @p engine of particular @p kind and @p index. */ -mkldnn_status_t MKLDNN_API mkldnn_engine_create(mkldnn_engine_t *engine, - mkldnn_engine_kind_t kind, size_t index); - -/** Returns the kind of an @p engine. */ -mkldnn_status_t MKLDNN_API mkldnn_engine_get_kind(mkldnn_engine_t engine, - mkldnn_engine_kind_t *kind); - -/** Destroys an @p engine. */ -mkldnn_status_t MKLDNN_API mkldnn_engine_destroy(mkldnn_engine_t engine); - -/** @} */ - -/** @addtogroup c_api_stream Execution stream operations - * @{ */ - -/** Creates an execution @p stream for @p engine and with @p flags. */ -mkldnn_status_t MKLDNN_API mkldnn_stream_create(mkldnn_stream_t *stream, - mkldnn_engine_t engine, unsigned flags); - -/** Destroys an execution @p stream. */ -mkldnn_status_t MKLDNN_API mkldnn_stream_destroy(mkldnn_stream_t stream); - -/** @} */ - -/** @addtogroup c_api_service Service functions - * @{ */ - -/** Sets verbosity level (print information to stdout). - * Possible levels are: - * - 0 -- no verbose output (default) - * - 1 -- primitive information at execution - * - 2 -- primitive information at creation and execution - * - * @note - * Dumping information might affect performance. - * This setting overrides the MKLDNN_VERBOSE environment variable. */ -mkldnn_status_t MKLDNN_API mkldnn_set_verbose(int level); - -/** Enables or disables dumping of JIT-generated code. - * The enable parameter can be: - * - 0 -- disable - * - any other value -- enable - * - * @note - * This setting overrides the MKLDNN_JIT_DUMP environment variable. 
*/ -mkldnn_status_t MKLDNN_API mkldnn_set_jit_dump(int enable); - -/** Gets library version information. - * Version information includes: - * - major -- major version number - * - minor -- minor version number - * - patch -- patch release number - * - hash -- git commit hash */ -const mkldnn_version_t MKLDNN_API *mkldnn_version(); - -/** @} */ - -/** @addtogroup c_api_blas BLAS functions - * A subset of Basic Linear ALgebra (BLAS) functions to perform - * matrix-matrix multiplication. - * @{ */ - -/** SGEMM performs a matrix-matrix multiplication operation defined as - * - * C := alpha*op( A )*op( B ) + beta*C - * - * where - * - op( X ) is one of op( X ) = X or op( X ) = X**T, - * - alpha and beta are scalars, - * - A, B and C are matrices, with op( A ) an m by k matrix, op( B ) a k by n matrix - * and C an m by n matrix. - * - * The matrices are assumed to be stored in column-major order (the elements - * in a matrix columns are contiguous in memory). - * - * @note - * The API is different from the standard BLAS routine - * because it returns mkldnn_status_t for error handling. - * XERBLA is not supported: no error message will be printed - * in case of incorrect parameters. */ -mkldnn_status_t MKLDNN_API mkldnn_sgemm( - const char *transa, const char *transb, - const mkldnn_dim_t *M, const mkldnn_dim_t *N, const mkldnn_dim_t *K, - const float *alpha, const float *A, const mkldnn_dim_t *lda, - const float *B, const mkldnn_dim_t *ldb, - const float *beta, float *C, const mkldnn_dim_t *ldc); - -/** gemm_s8u8s32 and gemm_s8s8s32 perform a matrix-matrix multiplication - * operation and add the result to a scalar-matrix product. For the final - * result, a vector is added to each row or column of the output matrix. - * The operation is defined as: - * - * C := alpha*(op(A) + A_offset) * (op(B) + B_offset) + beta*C + C_offset - * - * where - * - op( X ) = X or op( X ) = X**T, - * - A_offset is an m-by-k matrix with every element equal to the value oa, - * - B_offset is an k-by-n matrix with every element equal to the value ob, - * - C_offset is an m-by-n matrix defined by the oc array, size len: - * - if offsetc = F: len must be at least 1 - * - if offsetc = C: len must be at least max(1, m) - * - if offsetc = R: len must be at least max(1, n) - * - alpha and beta are scalars, and A, B and C are matrices, with op( A ) - * an m-by-k matrix, op( B ) a k-by-n matrix and C an m-by-n matrix. - * - * The matrices are assumed to be stored in column-major order (the elements - * in a matrix columns are contiguous in memory). - * - * @note - * The API is different compared with the standard BLAS routine - * because it returns mkldnn_status_t for error handling. - * XERBLA is not supported: no error message will be printed - * in case of incorrect parameters. 
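A small column-major example of the SGEMM call above; note that, BLAS-style, every scalar argument is passed by pointer. The 4x2 by 2x3 sizes are arbitrary, and A and B would need real data before the call.

const mkldnn_dim_t m = 4, n = 3, k = 2;
float A[4 * 2] = {0};   /* m x k, column-major */
float B[2 * 3] = {0};   /* k x n, column-major */
float C[4 * 3] = {0};   /* m x n, column-major */
const float alpha = 1.0f, beta = 0.0f;
mkldnn_status_t status = mkldnn_sgemm("N", "N", &m, &n, &k,
        &alpha, A, /*lda=*/&m, B, /*ldb=*/&k, &beta, C, /*ldc=*/&m);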
*/ -mkldnn_status_t MKLDNN_API mkldnn_gemm_s8u8s32( - const char *transa, const char *transb, const char *offsetc, - const mkldnn_dim_t *M, const mkldnn_dim_t *N, const mkldnn_dim_t *K, - const float *alpha, - const int8_t *A, const mkldnn_dim_t *lda, const int8_t *ao, - const uint8_t *B, const mkldnn_dim_t *ldb, const int8_t *bo, - const float *beta, - int32_t *c, const mkldnn_dim_t *ldc, const int32_t *co); - -mkldnn_status_t MKLDNN_API mkldnn_gemm_s8s8s32( - const char *transa, const char *transb, const char *offsetc, - const mkldnn_dim_t *M, const mkldnn_dim_t *N, const mkldnn_dim_t *K, - const float *alpha, - const int8_t *A, const mkldnn_dim_t *lda, const int8_t *ao, - const int8_t *B, const mkldnn_dim_t *ldb, const int8_t *bo, - const float *beta, - int32_t *c, const mkldnn_dim_t *ldc, const int32_t *co); -/** @} */ - -/** @} */ - -#ifdef __cplusplus -} -#endif - -#endif diff --git a/thirdparty/oidn/mkl-dnn/include/mkldnn.hpp b/thirdparty/oidn/mkl-dnn/include/mkldnn.hpp deleted file mode 100644 index 581400a01..000000000 --- a/thirdparty/oidn/mkl-dnn/include/mkldnn.hpp +++ /dev/null @@ -1,2615 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef MKLDNN_HPP -#define MKLDNN_HPP - -#ifndef DOXYGEN_SHOULD_SKIP_THIS -#include -#include -#include -#include -#include -#include - -#include "mkldnn.h" -#endif - -namespace mkldnn { - -/// @addtogroup cpp_api C++ API -/// @{ - -/// @addtogroup cpp_api_utils Utils -/// @{ - -/// A class that provides the destructor for an Intel(R) MKL-DNN C handle -template class handle_traits {}; - -/// A class for wrapping an Intel(R) MKL-DNN handle. It is used as the base -/// class for primitive (#mkldnn_primitive_t), engine (#mkldnn_engine_t), and -/// stream (#mkldnn_stream_t) handles. An object of the #mkldnn::handle class -/// can be passed by value. This class enables wrapping: -/// - Newly constructed handles. -/// @n In this case, the constructed handle uses reference counting provided -/// by @p std::shared_ptr with a proper deleter function specified through -/// the @p handle_traits class. -/// - Pre-existing handles returned by the Intel(R) MKL-DNN C API (for -/// example, through mkldnn_primitive_get_primitive_desc()). -/// @n In this case, an Intel(R) MKL-DNN C API handle is wrapped without a -/// deleter because it is assumed that the handle wrapper for the original -/// object deletes the handle (this model is similar to @p std::weak_ptr). -template > class handle { -private: - std::shared_ptr::type> _data; - handle(const handle &&) = delete; - handle &operator=(const handle &&other) = delete; -protected: - bool operator==(const T other) const { return other == _data.get(); } - bool operator!=(const T other) const { return !(*this == other); } -public: - /// Constructs a C handle wrapper. - /// @param t The C handle to wrap. 
- /// @param weak A flag to specify whether to construct a weak wrapper. - handle(T t = 0, bool weak = false): _data(0) { - reset(t, weak); - } - - handle(const handle &other): _data(other._data) {} - handle &operator=(const handle &other) { - _data = other._data; - return *this; - } - /// Resets the value of a C handle. - /// @param t The new value of the C handle. - /// @param weak A flag to specify whether the wrapper should be weak. - void reset(T t, bool weak = false) { - auto dummy_destructor = [](T) { return decltype(traits::destructor(0))(0); }; - _data.reset(t, weak ? dummy_destructor : traits::destructor); - } - - /// Returns the value of the underlying C handle. - T get() const { return _data.get(); } - - bool operator==(const handle &other) const { return other._data.get() == _data.get(); } - bool operator!=(const handle &other) const { return !(*this == other); } -}; - -#ifndef DOXYGEN_SHOULD_SKIP_THIS -template <> struct handle_traits { - static constexpr auto destructor = &mkldnn_memory_destroy; -}; - -template <> struct handle_traits { - static constexpr auto destructor = &mkldnn_primitive_desc_destroy; -}; - -template <> struct handle_traits { - static constexpr auto destructor = &mkldnn_primitive_destroy; -}; - -template <> struct handle_traits { - static constexpr auto destructor = &mkldnn_primitive_desc_iterator_destroy; -}; -#endif - -struct memory; -struct primitive_desc; - -/// Base class for all computational primitives. -class primitive: public handle { - friend struct error; - friend struct stream; - using handle::handle; -public: - /// A proxy to C primitive kind enum - enum class kind { - undefined_primitive = mkldnn_undefined_primitive, - reorder = mkldnn_reorder, - concat = mkldnn_concat, - sum = mkldnn_sum, - convolution = mkldnn_convolution, - deconvolution = mkldnn_deconvolution, - shuffle = mkldnn_shuffle, - eltwise = mkldnn_eltwise, - softmax = mkldnn_softmax, - pooling = mkldnn_pooling, - lrn = mkldnn_lrn, - batch_normalization = mkldnn_batch_normalization, - inner_product = mkldnn_inner_product, - rnn = mkldnn_rnn, - }; - - primitive(const_mkldnn_primitive_desc_t c_pd); - primitive(const primitive_desc &pd); - - /// Returns the descriptor of the underlying C API primitive. - inline const_mkldnn_primitive_desc_t get_primitive_desc() const; - // TODO: use the C++ API wrapper structure. - - void execute(struct stream &astream, - const std::unordered_map &args) const; -}; - -inline mkldnn_primitive_kind_t convert_to_c(primitive::kind akind) { - return static_cast(akind); -} -/// Intel(R) MKL-DNN exception class. -/// -/// This class captures the status returned by the failed C API function, error -/// message, and, optionally, handle of the primitive that caused the error. -struct error: public std::exception { - mkldnn_status_t status; - const char *message; - - /// Constructs an error instance. - /// - /// @param astatus The error status returned by the C API. - /// @param amessage The error message. - error(mkldnn_status_t astatus, const char *amessage) - : status(astatus), message(amessage) {} - - /// A convenience function for wrapping calls to the C API. Checks the - /// return status and throws an #error in case of failure. - /// - /// @param status The error status returned by the C API. - /// @param message The error message. 
- static void wrap_c_api(mkldnn_status_t status, const char *message) { - if (status != mkldnn_success) - throw error(status, message); - } -}; - -const_mkldnn_primitive_desc_t primitive::get_primitive_desc() const { - const_mkldnn_primitive_desc_t pd; - error::wrap_c_api(mkldnn_primitive_get_primitive_desc(get(), &pd), - "could not get primitive descriptor by primitive"); - return pd; -} -/// @} - -/// @addtogroup cpp_api_enums Common data types and enumerations -/// A proxy to @ref c_api_types in @ref c_api. -/// -/// @{ - -enum scratchpad_mode { - scratchpad_mode_library = mkldnn_scratchpad_mode_library, - scratchpad_mode_user = mkldnn_scratchpad_mode_user, -}; - -inline mkldnn_scratchpad_mode_t convert_to_c(scratchpad_mode mode) { - return static_cast(mode); -} - -enum padding_kind { - zero = mkldnn_padding_zero -}; - -inline mkldnn_padding_kind_t convert_to_c(padding_kind kind) { - return static_cast(kind); -} - -enum prop_kind { - forward_training = mkldnn_forward_training, - forward_scoring = mkldnn_forward_scoring, - forward_inference = mkldnn_forward_inference, - forward = mkldnn_forward, - backward = mkldnn_backward, - backward_data = mkldnn_backward_data, - backward_weights = mkldnn_backward_weights, - backward_bias = mkldnn_backward_bias -}; - -inline mkldnn_prop_kind_t convert_to_c(prop_kind kind) { - return static_cast(kind); -} - -enum algorithm { - algorithm_undef = mkldnn_alg_kind_undef, - convolution_auto = mkldnn_convolution_auto, - convolution_direct = mkldnn_convolution_direct, - convolution_winograd = mkldnn_convolution_winograd, - deconvolution_direct = mkldnn_deconvolution_direct, - deconvolution_winograd = mkldnn_deconvolution_winograd, - eltwise_relu = mkldnn_eltwise_relu, - eltwise_tanh = mkldnn_eltwise_tanh, - eltwise_elu = mkldnn_eltwise_elu, - eltwise_square = mkldnn_eltwise_square, - eltwise_abs = mkldnn_eltwise_abs, - eltwise_sqrt = mkldnn_eltwise_sqrt, - eltwise_linear = mkldnn_eltwise_linear, - eltwise_bounded_relu = mkldnn_eltwise_bounded_relu, - eltwise_soft_relu = mkldnn_eltwise_soft_relu, - eltwise_logistic = mkldnn_eltwise_logistic, - lrn_across_channels = mkldnn_lrn_across_channels, - lrn_within_channel = mkldnn_lrn_within_channel, - pooling_max = mkldnn_pooling_max, - pooling_avg = mkldnn_pooling_avg, - pooling_avg_include_padding = mkldnn_pooling_avg_include_padding, - pooling_avg_exclude_padding = mkldnn_pooling_avg_exclude_padding, - vanilla_rnn = mkldnn_vanilla_rnn, - vanilla_lstm = mkldnn_vanilla_lstm, - vanilla_gru = mkldnn_vanilla_gru, - gru_linear_before_reset = mkldnn_gru_linear_before_reset -}; - -inline mkldnn_alg_kind_t convert_to_c(algorithm aalgorithm) { - return static_cast(aalgorithm); -} - -enum batch_normalization_flag { - use_global_stats = mkldnn_use_global_stats, - use_scale_shift = mkldnn_use_scaleshift, - fuse_bn_relu = mkldnn_fuse_bn_relu -}; - -inline mkldnn_batch_normalization_flag_t convert_to_c( - batch_normalization_flag aflag) { - return static_cast(aflag); -} - -enum rnn_direction { - unidirectional_left2right = mkldnn_unidirectional_left2right, - unidirectional_right2left = mkldnn_unidirectional_right2left, - unidirectional = mkldnn_unidirectional, - bidirectional_concat = mkldnn_bidirectional_concat, - bidirectional_sum = mkldnn_bidirectional_sum, -}; - -inline mkldnn_rnn_direction_t convert_to_c(rnn_direction adir) { - return static_cast(adir); -} - -enum query { - undef = mkldnn_query_undef, - - query_engine = mkldnn_query_engine, - primitive_kind = mkldnn_query_primitive_kind, - - num_of_inputs_s32 = 
mkldnn_query_num_of_inputs_s32, - num_of_outputs_s32 = mkldnn_query_num_of_outputs_s32, - - time_estimate_f64 = mkldnn_query_time_estimate_f64, - memory_consumption_s64 = mkldnn_query_memory_consumption_s64, - - query_scratchpad_engine = mkldnn_query_scratchpad_engine, - - impl_info_str = mkldnn_query_impl_info_str, - - op_d = mkldnn_query_op_d, - convolution_d = mkldnn_query_convolution_d, - deconvolution_d = mkldnn_query_deconvolution_d, - shuffle_d = mkldnn_query_shuffle_d, - eltwise_d = mkldnn_query_eltwise_d, - softmax_d = mkldnn_query_softmax_d, - pooling_d = mkldnn_query_pooling_d, - lrn_d = mkldnn_query_lrn_d, - batch_normalization_d = mkldnn_query_batch_normalization_d, - inner_product_d = mkldnn_query_inner_product_d, - rnn_d = mkldnn_query_rnn_d, - - src_md = mkldnn_query_src_md, - diff_src_md = mkldnn_query_diff_src_md, - weights_md = mkldnn_query_weights_md, - diff_weights_md = mkldnn_query_diff_weights_md, - dst_md = mkldnn_query_dst_md, - diff_dst_md = mkldnn_query_diff_dst_md, - workspace_md = mkldnn_query_workspace_md, - scratchpad_md = mkldnn_query_scratchpad_md, -}; - -inline mkldnn_query_t convert_to_c(query aquery) { - return static_cast(aquery); -} - -/// @} - -/// @addtogroup cpp_api_attr Attributes -/// An extension for controlling primitive behavior. -/// -/// @sa @ref c_api_attributes in @ref c_api -/// @{ - -#ifndef DOXYGEN_SHOULD_SKIP_THIS -template <> struct handle_traits { - static constexpr auto destructor = &mkldnn_post_ops_destroy; -}; -#endif - -struct post_ops: public handle { - post_ops() { - mkldnn_post_ops_t result; - error::wrap_c_api(mkldnn_post_ops_create(&result), - "could not create post operation sequence"); - reset(result); - } - - int len() const { return mkldnn_post_ops_len(get()); } - - primitive::kind kind(int index) const { - error::wrap_c_api( - index < len() ? mkldnn_success : mkldnn_invalid_arguments, - "post_ops index is out of range"); - return static_cast(mkldnn_post_ops_get_kind(get(), - index)); - } - - void append_sum(float scale = 1.) 
{ - error::wrap_c_api(mkldnn_post_ops_append_sum(get(), scale), - "could not append sum"); - } - - void get_params_sum(int index, float &scale) const { - error::wrap_c_api(mkldnn_post_ops_get_params_sum(get(), index, &scale), - "could not get sum params"); - } - - void append_eltwise(float scale, algorithm alg, float alpha, - float beta) { - error::wrap_c_api(mkldnn_post_ops_append_eltwise(get(), scale, - convert_to_c(alg), alpha, beta), - "could not append eltwise"); - } - - void get_params_eltwise(int index, float &scale, algorithm &alg, - float &alpha, float &beta) const { - mkldnn_alg_kind_t c_alg; - error::wrap_c_api(mkldnn_post_ops_get_params_eltwise(get(), index, - &scale, &c_alg, &alpha, &beta), - "could not get eltwise params"); - alg = static_cast(c_alg); - } -}; - -#ifndef DOXYGEN_SHOULD_SKIP_THIS -template <> struct handle_traits { - static constexpr auto destructor = &mkldnn_primitive_attr_destroy; -}; -#endif - -struct primitive_attr: public handle { - primitive_attr() { - mkldnn_primitive_attr_t result; - error::wrap_c_api(mkldnn_primitive_attr_create(&result), - "could not create a primitive attr"); - reset(result); - } - - scratchpad_mode get_scratchpad_mode() const { - mkldnn_scratchpad_mode_t result; - error::wrap_c_api(mkldnn_primitive_attr_get_scratchpad_mode( - get(), &result), "could not get scratchpad mode"); - return scratchpad_mode(result); - } - - void set_scratchpad_mode(scratchpad_mode mode) { - error::wrap_c_api(mkldnn_primitive_attr_set_scratchpad_mode( - get(), mkldnn::convert_to_c(mode)), - "could not set scratchpad mode"); - } - - void get_output_scales(int &mask, std::vector &scales) const - { - mkldnn_dim_t count; - int c_mask; - const float *c_scales; - error::wrap_c_api(mkldnn_primitive_attr_get_output_scales(get(), - &count, &c_mask, &c_scales), - "could not get int output scales"); - scales.resize(count); - - mask = c_mask; - for (mkldnn_dim_t c = 0; c < count; ++c) - scales[c] = c_scales[c]; - } - - void set_output_scales(int mask, const std::vector &scales) - { - error::wrap_c_api(mkldnn_primitive_attr_set_output_scales(get(), - (mkldnn_dim_t)scales.size(), mask, &scales[0]), - "could not set int output scales"); - } - - const post_ops get_post_ops() const { - post_ops result; - const_mkldnn_post_ops_t c_result; - error::wrap_c_api(mkldnn_primitive_attr_get_post_ops(get(), &c_result), - "could not get post operation sequence"); - result.reset(const_cast(c_result), true); - return result; - } - - void set_post_ops(post_ops ops) { - error::wrap_c_api(mkldnn_primitive_attr_set_post_ops(get(), ops.get()), - "could not set post operation sequence"); - } - - void set_rnn_data_qparams(const float scale, const float shift) - { - error::wrap_c_api(mkldnn_primitive_attr_set_rnn_data_qparams(get(), - scale, shift), "could not set rnn data int scale/shift"); - } - - void set_rnn_weights_qparams(int mask, const std::vector &scales) - { - error::wrap_c_api(mkldnn_primitive_attr_set_rnn_weights_qparams(get(), - (int)scales.size(), mask, &scales[0]), - "could not set rnn weights int scales"); - } -}; - -/// @} - -/// @addtogroup cpp_api_engine Engine -/// Engine operations. -/// -/// @sa @ref c_api_engine in @ref c_api -/// @{ - -#ifndef DOXYGEN_SHOULD_SKIP_THIS -template <> struct handle_traits { - static constexpr auto destructor = &mkldnn_engine_destroy; -}; -#endif - -/// An execution engine. -struct engine: public handle { - friend class primitive; - // gcc bug??? using handle::handle; - - /// Kinds of engines. 
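The attribute and post-ops wrappers above compose as in this sketch (the helper name and scale values are illustrative): an accumulating sum followed by a fused ReLU is the typical pattern for residual additions.

#include "mkldnn.hpp"

// Hypothetical helper: fills caller-provided attributes with fused post-ops.
void build_attr_with_fused_relu(mkldnn::primitive_attr &attr) {
    mkldnn::post_ops ops;
    ops.append_sum(1.0f);  // accumulate into the existing destination first
    ops.append_eltwise(1.0f, mkldnn::eltwise_relu, /*alpha=*/0.0f, /*beta=*/0.0f);
    attr.set_post_ops(ops);
    attr.set_scratchpad_mode(mkldnn::scratchpad_mode_user);
}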
- enum kind { - /// An unspecified engine - any = mkldnn_any_engine, - /// CPU engine - cpu = mkldnn_cpu, - }; - - /// Returns the number of engines of a certain kind. - /// - /// @param akind The kind of engines to count. - - static size_t get_count(kind akind) { - return mkldnn_engine_get_count(convert_to_c(akind)); - } - - /// Constructs an engine. - /// - /// @param akind The kind of engine to construct. - /// @param index The index of the engine. Must be less than the value - /// returned by #get_count() for this particular kind of engine. - - engine(kind akind, size_t index) { - mkldnn_engine_t aengine; - error::wrap_c_api( - mkldnn_engine_create(&aengine, - convert_to_c(akind), index), - "could not create an engine"); - reset(aengine); - } - - explicit engine(const mkldnn_engine_t& aengine) - : handle(aengine, true) {} - - engine(const handle &pd) { - mkldnn_engine_t engine_q; - error::wrap_c_api( - mkldnn_primitive_desc_query(pd.get(), - mkldnn::convert_to_c(query_engine), 0, &engine_q), - "could not get engine from primitive_desc"); - reset(engine_q, true); - } - - template - static engine query(const primitive_desc &pd) { - mkldnn_engine_t engine_q; - error::wrap_c_api( - mkldnn_primitive_desc_query(pd.get(), - mkldnn::convert_to_c(query_engine), 0, &engine_q), - "could not get engine from primitive_desc"); - - return engine(engine_q); - } - -private: - static mkldnn_engine_kind_t convert_to_c(kind akind) { - return static_cast(akind); - } -}; - -/// @} - -/// @addtogroup cpp_api_stream Stream -/// Execution stream operations -/// -/// @sa @ref c_api_stream in @ref c_api -/// @{ - -#ifndef DOXYGEN_SHOULD_SKIP_THIS -template <> struct handle_traits { - static constexpr auto destructor = &mkldnn_stream_destroy; -}; -#endif - -struct stream: public handle { - using handle::handle; - - enum: unsigned { - default_flags = mkldnn_stream_default_flags, - }; - - /// Constructs a stream. - stream(const engine &aengine, - unsigned flags = static_cast(default_flags)) { - mkldnn_stream_t astream; - error::wrap_c_api(mkldnn_stream_create(&astream, aengine.get(), flags), - "could not create a stream"); - reset(astream); - } -}; - -/// @} - -/// @addtogroup cpp_api_memory_related Memory and memory related operations -/// @{ - -/// @addtogroup cpp_api_memory Memory -/// A primitive to describe and store data. -/// -/// For more information, refer to @ref c_api_memory in @ref c_api. -/// @{ - -/// Memory that describes the data. -struct memory: public handle { - public: - typedef mkldnn_dim_t dim; - typedef std::vector dims; - - template static void validate_dims(const std::vector &v) { - if (v.size() > MKLDNN_MAX_NDIMS) - throw error(mkldnn_invalid_arguments, "invalid dimensions"); - } - - /// Data type specification. See #mkldnn_data_type_t for a detailed - /// description. - enum data_type { - data_undef = mkldnn_data_type_undef, - f32 = mkldnn_f32, - s32 = mkldnn_s32, - s8 = mkldnn_s8, - u8 = mkldnn_u8, - }; - - /// Memory format tag specification. See #mkldnn_format_tag_t - /// for a detailed description. 
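Constructing the engine and stream wrappers above takes only a couple of lines; this sketch assumes the library was built with a CPU engine available.

#include "mkldnn.hpp"

int main() {
    if (mkldnn::engine::get_count(mkldnn::engine::cpu) == 0)
        return 1;  // no CPU engine available
    mkldnn::engine eng(mkldnn::engine::cpu, 0);
    mkldnn::stream strm(eng);
    // primitives created against `eng` can now be executed on `strm`
    return 0;
}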
- enum format_tag { - format_tag_undef = mkldnn_format_tag_undef, - any = mkldnn_format_tag_any, - a = mkldnn_a, - ab = mkldnn_ab, - abc = mkldnn_abc, - abcd = mkldnn_abcd, - abcde = mkldnn_abcde, - abcdef = mkldnn_abcdef, - abdec = mkldnn_abdec, - acb = mkldnn_acb, - acbde = mkldnn_acbde, - acdb = mkldnn_acdb, - acdeb = mkldnn_acdeb, - ba = mkldnn_ba, - bac = mkldnn_bac, - bacd = mkldnn_bacd, - bcda = mkldnn_bcda, - cba = mkldnn_cba, - cdba = mkldnn_cdba, - cdeba = mkldnn_cdeba, - decab = mkldnn_decab, - Abc16a = mkldnn_Abc16a, - ABc16a16b = mkldnn_ABc16a16b, - aBc16b = mkldnn_aBc16b, - ABc16b16a = mkldnn_ABc16b16a, - Abc4a = mkldnn_Abc4a, - aBc4b = mkldnn_aBc4b, - ABc4b16a4b = mkldnn_ABc4b16a4b, - ABc4b4a = mkldnn_ABc4b4a, - ABc8a16b2a = mkldnn_ABc8a16b2a, - ABc8a8b = mkldnn_ABc8a8b, - aBc8b = mkldnn_aBc8b, - ABc8b16a2b = mkldnn_ABc8b16a2b, - ABc8b8a = mkldnn_ABc8b8a, - Abcd16a = mkldnn_Abcd16a, - ABcd16a16b = mkldnn_ABcd16a16b, - aBcd16b = mkldnn_aBcd16b, - ABcd16b16a = mkldnn_ABcd16b16a, - aBCd16b16c = mkldnn_aBCd16b16c, - aBCd16c16b = mkldnn_aBCd16c16b, - Abcd4a = mkldnn_Abcd4a, - aBcd4b = mkldnn_aBcd4b, - ABcd4b16a4b = mkldnn_ABcd4b16a4b, - ABcd4b4a = mkldnn_ABcd4b4a, - aBCd4c16b4c = mkldnn_aBCd4c16b4c, - aBCd4c4b = mkldnn_aBCd4c4b, - ABcd8a16b2a = mkldnn_ABcd8a16b2a, - ABcd8a8b = mkldnn_ABcd8a8b, - aBcd8b = mkldnn_aBcd8b, - ABcd8b16a2b = mkldnn_ABcd8b16a2b, - aBCd8b16c2b = mkldnn_aBCd8b16c2b, - ABcd8b8a = mkldnn_ABcd8b8a, - aBCd8b8c = mkldnn_aBCd8b8c, - aBCd8c16b2c = mkldnn_aBCd8c16b2c, - aBCd8c8b = mkldnn_aBCd8c8b, - Abcde16a = mkldnn_Abcde16a, - ABcde16a16b = mkldnn_ABcde16a16b, - aBcde16b = mkldnn_aBcde16b, - ABcde16b16a = mkldnn_ABcde16b16a, - aBCde16b16c = mkldnn_aBCde16b16c, - aBCde16c16b = mkldnn_aBCde16c16b, - aBCde2c8b4c = mkldnn_aBCde2c8b4c, - Abcde4a = mkldnn_Abcde4a, - aBcde4b = mkldnn_aBcde4b, - ABcde4b4a = mkldnn_ABcde4b4a, - aBCde4b4c = mkldnn_aBCde4b4c, - aBCde4c16b4c = mkldnn_aBCde4c16b4c, - aBCde4c4b = mkldnn_aBCde4c4b, - Abcde8a = mkldnn_Abcde8a, - ABcde8a8b = mkldnn_ABcde8a8b, - aBcde8b = mkldnn_aBcde8b, - ABcde8b16a2b = mkldnn_ABcde8b16a2b, - aBCde8b16c2b = mkldnn_aBCde8b16c2b, - ABcde8b8a = mkldnn_ABcde8b8a, - aBCde8b8c = mkldnn_aBCde8b8c, - aBCde8c16b2c = mkldnn_aBCde8c16b2c, - aBCde8c8b = mkldnn_aBCde8c8b, - aBcdef16b = mkldnn_aBcdef16b, - aBCdef16b16c = mkldnn_aBCdef16b16c, - aBCdef16c16b = mkldnn_aBCdef16c16b, - aBcdef4b = mkldnn_aBcdef4b, - aBCdef4c4b = mkldnn_aBCdef4c4b, - aBCdef8b8c = mkldnn_aBCdef8b8c, - aBCdef8c16b2c = mkldnn_aBCdef8c16b2c, - aBCdef8c8b = mkldnn_aBCdef8c8b, - aBdc16b = mkldnn_aBdc16b, - aBdc4b = mkldnn_aBdc4b, - aBdc8b = mkldnn_aBdc8b, - aBdec16b = mkldnn_aBdec16b, - aBdec4b = mkldnn_aBdec4b, - aBdec8b = mkldnn_aBdec8b, - aBdefc16b = mkldnn_aBdefc16b, - aBdefc4b = mkldnn_aBdefc4b, - aBdefc8b = mkldnn_aBdefc8b, - Acb16a = mkldnn_Acb16a, - Acb4a = mkldnn_Acb4a, - Acb8a = mkldnn_Acb8a, - aCBd16b16c = mkldnn_aCBd16b16c, - aCBde16b16c = mkldnn_aCBde16b16c, - Acdb16a = mkldnn_Acdb16a, - Acdb4a = mkldnn_Acdb4a, - Acdb8a = mkldnn_Acdb8a, - Acdeb16a = mkldnn_Acdeb16a, - Acdeb4a = mkldnn_Acdeb4a, - Acdeb8a = mkldnn_Acdeb8a, - BAc16a16b = mkldnn_BAc16a16b, - BAcd16a16b = mkldnn_BAcd16a16b, - format_tag_last = mkldnn_format_tag_last, - - x = mkldnn_x, - nc = mkldnn_nc, - cn = mkldnn_cn, - ncw = mkldnn_ncw, - nwc = mkldnn_nwc, - nchw = mkldnn_nchw, - nhwc = mkldnn_nhwc, - chwn = mkldnn_chwn, - ncdhw = mkldnn_ncdhw, - ndhwc = mkldnn_ndhwc, - oi = mkldnn_oi, - io = mkldnn_io, - oiw = mkldnn_oiw, - wio = mkldnn_wio, - oihw = mkldnn_oihw, - hwio = 
mkldnn_hwio, - ihwo = mkldnn_ihwo, - iohw = mkldnn_iohw, - oidhw = mkldnn_oidhw, - dhwio = mkldnn_dhwio, - goiw = mkldnn_goiw, - goihw = mkldnn_goihw, - hwigo = mkldnn_hwigo, - giohw = mkldnn_giohw, - goidhw = mkldnn_goidhw, - tnc = mkldnn_tnc, - ntc = mkldnn_ntc, - ldsnc = mkldnn_ldsnc, - ldigo = mkldnn_ldigo, - ldgoi = mkldnn_ldgoi, - ldgo = mkldnn_ldgo, - nCdhw16c = mkldnn_nCdhw16c, - nCdhw4c = mkldnn_nCdhw4c, - nCdhw8c = mkldnn_nCdhw8c, - nChw16c = mkldnn_nChw16c, - nChw4c = mkldnn_nChw4c, - nChw8c = mkldnn_nChw8c, - nCw16c = mkldnn_nCw16c, - nCw4c = mkldnn_nCw4c, - nCw8c = mkldnn_nCw8c, - IOw16o16i = mkldnn_IOw16o16i, - OIw16i16o = mkldnn_OIw16i16o, - OIw16o16i = mkldnn_OIw16o16i, - Oiw16o = mkldnn_Oiw16o, - OIw4i16o4i = mkldnn_OIw4i16o4i, - OIw4i4o = mkldnn_OIw4i4o, - Oiw4o = mkldnn_Oiw4o, - OIw8i16o2i = mkldnn_OIw8i16o2i, - OIw8i8o = mkldnn_OIw8i8o, - OIw8o16i2o = mkldnn_OIw8o16i2o, - OIw8o8i = mkldnn_OIw8o8i, - Owi16o = mkldnn_Owi16o, - Owi4o = mkldnn_Owi4o, - Owi8o = mkldnn_Owi8o, - IOhw16o16i = mkldnn_IOhw16o16i, - Ohwi16o = mkldnn_Ohwi16o, - Ohwi4o = mkldnn_Ohwi4o, - Ohwi8o = mkldnn_Ohwi8o, - OIhw16i16o = mkldnn_OIhw16i16o, - OIhw16o16i = mkldnn_OIhw16o16i, - Oihw16o = mkldnn_Oihw16o, - OIhw4i16o4i = mkldnn_OIhw4i16o4i, - OIhw4i4o = mkldnn_OIhw4i4o, - Oihw4o = mkldnn_Oihw4o, - OIhw8i16o2i = mkldnn_OIhw8i16o2i, - OIhw8i8o = mkldnn_OIhw8i8o, - OIhw8o16i2o = mkldnn_OIhw8o16i2o, - OIhw8o8i = mkldnn_OIhw8o8i, - Odhwi16o = mkldnn_Odhwi16o, - Odhwi4o = mkldnn_Odhwi4o, - Odhwi8o = mkldnn_Odhwi8o, - OIdhw16i16o = mkldnn_OIdhw16i16o, - OIdhw16o16i = mkldnn_OIdhw16o16i, - Oidhw16o = mkldnn_Oidhw16o, - OIdhw4i4o = mkldnn_OIdhw4i4o, - Oidhw4o = mkldnn_Oidhw4o, - OIdhw8i16o2i = mkldnn_OIdhw8i16o2i, - OIdhw8i8o = mkldnn_OIdhw8i8o, - OIdhw8o8i = mkldnn_OIdhw8o8i, - gIOw16o16i = mkldnn_gIOw16o16i, - gOIw16i16o = mkldnn_gOIw16i16o, - gOIw16o16i = mkldnn_gOIw16o16i, - gOiw16o = mkldnn_gOiw16o, - gOIw4i16o4i = mkldnn_gOIw4i16o4i, - gOIw4i4o = mkldnn_gOIw4i4o, - gOiw4o = mkldnn_gOiw4o, - gOIw8i16o2i = mkldnn_gOIw8i16o2i, - gOIw8i8o = mkldnn_gOIw8i8o, - gOIw8o16i2o = mkldnn_gOIw8o16i2o, - gOIw8o8i = mkldnn_gOIw8o8i, - gOwi16o = mkldnn_gOwi16o, - gOwi4o = mkldnn_gOwi4o, - gOwi8o = mkldnn_gOwi8o, - gIOhw16o16i = mkldnn_gIOhw16o16i, - gOhwi16o = mkldnn_gOhwi16o, - gOhwi4o = mkldnn_gOhwi4o, - gOhwi8o = mkldnn_gOhwi8o, - Goihw16g = mkldnn_Goihw16g, - gOIhw16i16o = mkldnn_gOIhw16i16o, - gOIhw16o16i = mkldnn_gOIhw16o16i, - gOihw16o = mkldnn_gOihw16o, - gOIhw2i8o4i = mkldnn_gOIhw2i8o4i, - gOIhw4i16o4i = mkldnn_gOIhw4i16o4i, - gOIhw4i4o = mkldnn_gOIhw4i4o, - gOIhw4o4i = mkldnn_gOIhw4o4i, - gOihw4o = mkldnn_gOihw4o, - Goihw8g = mkldnn_Goihw8g, - gOIhw8i16o2i = mkldnn_gOIhw8i16o2i, - gOIhw8i8o = mkldnn_gOIhw8i8o, - gOIhw8o16i2o = mkldnn_gOIhw8o16i2o, - gOIhw8o8i = mkldnn_gOIhw8o8i, - gOdhwi16o = mkldnn_gOdhwi16o, - gOdhwi4o = mkldnn_gOdhwi4o, - gOdhwi8o = mkldnn_gOdhwi8o, - gOIdhw16i16o = mkldnn_gOIdhw16i16o, - gOIdhw16o16i = mkldnn_gOIdhw16o16i, - gOidhw16o = mkldnn_gOidhw16o, - gOIdhw4i4o = mkldnn_gOIdhw4i4o, - gOidhw4o = mkldnn_gOidhw4o, - gOIdhw8i16o2i = mkldnn_gOIdhw8i16o2i, - gOIdhw8i8o = mkldnn_gOIdhw8i8o, - gOIdhw8o8i = mkldnn_gOIdhw8o8i, - }; - - /// A memory descriptor. - struct desc { - friend struct memory; - /// The underlying C API data structure. - mkldnn_memory_desc_t data; - - /// Constructs a zero memory descriptor - desc(): data() {} - - /// Constructs a memory descriptor. - /// - /// @param adims Data dimensions - /// @param adata_type Data precision/type. 
- /// @param aformat Data layout format tag. - desc(const dims &adims, data_type adata_type, - format_tag aformat) { - validate_dims(adims); - error::wrap_c_api(mkldnn_memory_desc_init_by_tag(&data, (int)adims.size(), - adims.size() == 0 ? nullptr : &adims[0], - convert_to_c(adata_type), convert_to_c(aformat)), - "could not initialize a memory descriptor"); - } - - /// Constructs a memory descriptor from a C API data structure. - /// - /// @param adata A C API #mkldnn_memory_desc_t structure. - desc(const mkldnn_memory_desc_t &adata): data(adata) {} - - /// Constructs a sub-memory descriptor - // - /// @param adims Sizes of a sub-memory - /// @param offsets Offsets of a sub-memory - desc submemory_desc(const dims &adims, const dims &offsets) { - mkldnn_memory_desc_t sub_md; - error::wrap_c_api(mkldnn_memory_desc_init_submemory(&sub_md, - &data, &adims[0], &offsets[0]), - "could not initialize a sub-memory"); - return desc(sub_md); - } - - /// Returns the number of bytes required to allocate the memory described - /// including the padding area. - size_t get_size() const { return mkldnn_memory_desc_get_size(&data); } - - bool operator==(const desc &other) const { - return mkldnn_memory_desc_equal(&data, &other.data) != 0; - } - - bool operator!=(const desc &other) const { return !operator==(other); } - }; - - /// Constructs a memory. - /// - /// @param md Memory descriptor. - /// @param aengine Engine. - /// @param ahandle Native handle. - memory(const desc &md, const engine &aengine, void *ahandle) { - mkldnn_memory_t result; - error::wrap_c_api(mkldnn_memory_create(&result, &md.data, - aengine.get(), ahandle), "could not create a memory"); - reset(result); - } - - /// Constructs a memory. - /// - /// @param md Memory descriptor. - /// @param aengine Engine. - memory(const desc &md, const engine &aengine) - : memory(md, aengine, MKLDNN_NATIVE_HANDLE_ALLOCATE) {} - - /// Returns the descriptor of the memory. - desc get_desc() const { - const mkldnn_memory_desc_t *cdesc; - error::wrap_c_api(mkldnn_memory_get_memory_desc(get(), &cdesc), - "could not get memory descriptor from a memory"); - return desc(*cdesc); - } - - /// Returns the engine of the memory. - engine get_engine() const { - mkldnn_engine_t engine_q; - error::wrap_c_api(mkldnn_memory_get_engine(get(), &engine_q), - "could not get engine from a memory"); - return engine(engine_q); - } - - /// Returns a handle of the data contained in the memory. - /// - /// On the CPU engine, this is a pointer to the allocated memory. 
- void *get_data_handle() const { - void *handle; - error::wrap_c_api(mkldnn_memory_get_data_handle(get(), &handle), - "could not get native handle"); - return handle; - } - - void set_data_handle(void *handle) const { - error::wrap_c_api(mkldnn_memory_set_data_handle(get(), handle), - "could not set native handle"); - } - - // Must go away or be private: - static mkldnn_data_type_t convert_to_c(data_type adata_type) { - return static_cast(adata_type); - } - static mkldnn_format_tag_t convert_to_c(format_tag aformat) { - return static_cast(aformat); - } -}; - -inline bool operator==(mkldnn_data_type_t a, memory::data_type b) { - return a == memory::convert_to_c(b); -} -inline bool operator!=(mkldnn_data_type_t a, memory::data_type b) { - return !(a == b); -} -inline bool operator==(memory::data_type a, mkldnn_data_type_t b) { - return b == a; -} -inline bool operator!=(memory::data_type a, mkldnn_data_type_t b) { - return !(a == b); -} - -inline bool operator==(mkldnn_format_tag_t a, memory::format_tag b) { - return a == memory::convert_to_c(b); -} -inline bool operator!=(mkldnn_format_tag_t a, memory::format_tag b) { - return !(a == b); -} -inline bool operator==(memory::format_tag a, mkldnn_format_tag_t b) { - return b == a; -} -inline bool operator!=(memory::format_tag a, mkldnn_format_tag_t b) { - return !(a == b); -} - -/// @} - -/// @addtogroup cpp_api_reorder Reorder -/// A primitive to copy data between memory formats. -/// -/// @sa @ref c_api_reorder in @ref c_api -/// @{ - -struct reorder : public primitive { - struct primitive_desc : public handle { - primitive_desc(const engine &src_engine, const memory::desc &src_md, - const engine &dst_engine, const memory::desc &dst_md, - const primitive_attr &aattr) { - mkldnn_primitive_desc_t result; - error::wrap_c_api(mkldnn_reorder_primitive_desc_create(&result, - src_engine.get(), &src_md.data, - dst_engine.get(), &dst_md.data, aattr.get()), - "could not create a reorder primitive descriptor"); - reset(result); - } - - primitive_desc(const engine &src_engine, const memory::desc &src_md, - const engine &dst_engine, const memory::desc &dst_md) { - mkldnn_primitive_desc_t result; - error::wrap_c_api(mkldnn_reorder_primitive_desc_create(&result, - src_engine.get(), &src_md.data, - dst_engine.get(), &dst_md.data, nullptr), - "could not create a reorder primitive descriptor"); - reset(result); - } - - primitive_desc(const memory &src, const memory &dst, - const primitive_attr &aattr) { - mkldnn_primitive_desc_t result; - auto src_md = src.get_desc(); - auto dst_md = dst.get_desc(); - error::wrap_c_api(mkldnn_reorder_primitive_desc_create(&result, - src.get_engine().get(), &src_md.data, - dst.get_engine().get(), &dst_md.data, aattr.get()), - "could not create a reorder primitive descriptor"); - reset(result); - } - - primitive_desc(const memory &src, const memory &dst) { - mkldnn_primitive_desc_t result; - auto src_md = src.get_desc(); - auto dst_md = dst.get_desc(); - error::wrap_c_api(mkldnn_reorder_primitive_desc_create(&result, - src.get_engine().get(), &src_md.data, - dst.get_engine().get(), &dst_md.data, nullptr), - "could not create a reorder primitive descriptor"); - reset(result); - } - - memory::desc scratchpad_desc() const { - const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md( - get(), mkldnn::convert_to_c(scratchpad_md), 0); - if (cdesc == nullptr) - return memory::desc(); - return memory::desc(*cdesc); - } - - engine scratchpad_engine() { - mkldnn_engine_t engine_q; - error::wrap_c_api( - 
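The memory::desc and memory wrappers above describe a tensor by its dimensions, data type and format tag, and optionally let the library allocate the backing buffer. A sketch using only what the hunk shows (dims, data_type::f32, format_tag::nchw, the handle-less memory constructor); the include path is an assumption:

#include "mkldnn.hpp"

// Describe and allocate a 1x3x224x224 f32 tensor in nchw layout, using
// memory::desc(dims, data_type, format_tag) and the memory(md, engine)
// constructor that lets the library allocate the buffer
// (MKLDNN_NATIVE_HANDLE_ALLOCATE).
mkldnn::memory make_nchw_tensor(const mkldnn::engine &eng) {
    using namespace mkldnn;

    memory::dims shape = {1, 3, 224, 224};
    memory::desc md(shape, memory::data_type::f32, memory::format_tag::nchw);

    // md.get_size() would report the byte size including padding; here the
    // library allocates that much itself. Pass a void* as third argument to
    // wrap an existing buffer instead.
    return memory(md, eng);
}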
mkldnn_primitive_desc_query(get(), - mkldnn::convert_to_c(query_scratchpad_engine), 0, &engine_q), - "could not get scratchpad engine from reorder primitive_desc"); - - return engine(engine_q); - } - - engine get_engine() { return engine::query(*this); } - }; - - reorder(const primitive_desc &pd): primitive(pd.get()) {} - - reorder(const memory &src, const memory &dst): - primitive(primitive_desc(src, dst).get()) {} - - void execute(stream astream, memory &src, memory &dst) { - primitive::execute(astream, - {{MKLDNN_ARG_FROM, src}, {MKLDNN_ARG_TO, dst}}); - } -}; - -/// @} - -/// @addtogroup cpp_api_concat Concat -/// A primitive to concatenate data by arbitrary dimension. -/// -/// @sa @ref c_api_concat in @ref c_api -/// @{ - -struct concat : public primitive { - struct primitive_desc : public handle { - std::vector cpp_to_c( - const std::vector &srcs) { - std::vector c_api_srcs; - c_api_srcs.reserve(srcs.size()); - for (const auto &s : srcs) c_api_srcs.push_back(s.data); - return c_api_srcs; - } - - primitive_desc(const memory::desc &dst, int concat_dimension, - const std::vector &srcs, const engine &aengine) { - auto c_api_srcs = cpp_to_c(srcs); - - mkldnn_primitive_desc_t result; - error::wrap_c_api(mkldnn_concat_primitive_desc_create( - &result, &dst.data, (int)c_api_srcs.size(), - concat_dimension, &c_api_srcs[0], nullptr, aengine.get()), - "could not create a concat primitive descriptor"); - reset(result); - } - - primitive_desc(int concat_dimension, - const std::vector &srcs, const engine &aengine) { - auto c_api_srcs = cpp_to_c(srcs); - - mkldnn_primitive_desc_t result; - error::wrap_c_api(mkldnn_concat_primitive_desc_create( - &result, nullptr, (int)c_api_srcs.size(), - concat_dimension, &c_api_srcs[0], nullptr, aengine.get()), - "could not create a concat primitive descriptor"); - reset(result); - } - - memory::desc dst_desc() const { - const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md( - get(), mkldnn::convert_to_c(dst_md), 0); - error::wrap_c_api( - cdesc == nullptr ? mkldnn_runtime_error : mkldnn_success, - "could not get a dst memory descriptor"); - return memory::desc(*cdesc); - } - - memory::desc scratchpad_desc() const { - const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md( - get(), mkldnn::convert_to_c(scratchpad_md), 0); - if (cdesc == nullptr) - return memory::desc(); - return memory::desc(*cdesc); - } - - engine get_engine() { return engine::query(*this); } - }; - - concat(const primitive_desc &pd): primitive(pd.get()) {} -}; - -/// @} - -/// @addtogroup cpp_api_sum Sum -/// A primitive to sum data. -/// -/// @sa @ref c_api_sum in @ref c_api -/// @{ - -struct sum : public primitive { - struct primitive_desc : public handle { - std::vector cpp_to_c( - const std::vector &srcs) { - std::vector c_api_srcs; - c_api_srcs.reserve(srcs.size()); - for (const auto &s : srcs) c_api_srcs.push_back(s.data); - return c_api_srcs; - } - - primitive_desc(const memory::desc &dst, - const std::vector &scales, - const std::vector &srcs, const engine &aengine) { - error::wrap_c_api(scales.size() == srcs.size() - ? 
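The reorder primitive above copies data between two memory objects, converting layout on the way. A sketch using the convenience reorder(src, dst) constructor and the execute(stream, src, dst) overload shown in the hunk; the blocked nChw8c destination layout is just an illustrative choice from the format_tag enum above, and the include path is an assumption:

#include "mkldnn.hpp"

// Convert an nchw f32 tensor into the blocked nChw8c layout.
mkldnn::memory to_blocked(const mkldnn::engine &eng, mkldnn::stream &s,
                          mkldnn::memory &src_nchw,
                          const mkldnn::memory::dims &shape) {
    using namespace mkldnn;

    memory::desc dst_md(shape, memory::data_type::f32,
                        memory::format_tag::nChw8c);
    memory dst(dst_md, eng);

    // reorder(src, dst) builds the primitive descriptor internally; execute()
    // maps the two memories to MKLDNN_ARG_FROM / MKLDNN_ARG_TO as shown above.
    reorder r(src_nchw, dst);
    r.execute(s, src_nchw, dst);
    return dst;
}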
mkldnn_success : mkldnn_invalid_arguments, - "number of scales not equal to number of srcs"); - - auto c_api_srcs = cpp_to_c(srcs); - - mkldnn_primitive_desc_t result; - error::wrap_c_api(mkldnn_sum_primitive_desc_create( - &result, &dst.data, (int)c_api_srcs.size(), - &scales[0], &c_api_srcs[0], nullptr, aengine.get()), - "could not create a sum primitive descriptor"); - reset(result); - } - - primitive_desc(const std::vector &scales, - const std::vector &srcs, const engine &aengine) { - error::wrap_c_api(scales.size() == srcs.size() - ? mkldnn_success : mkldnn_invalid_arguments, - "number of scales not equal to number of srcs"); - - auto c_api_srcs = cpp_to_c(srcs); - mkldnn_primitive_desc_t result; - error::wrap_c_api(mkldnn_sum_primitive_desc_create(&result, - nullptr, (int)c_api_srcs.size(), &scales[0], - &c_api_srcs[0], nullptr, aengine.get()), - "could not create a sum primitive descriptor"); - reset(result); - } - - memory::desc dst_desc() const { - const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md( - get(), mkldnn::convert_to_c(dst_md), 0); - error::wrap_c_api( - cdesc == nullptr ? mkldnn_runtime_error : mkldnn_success, - "could not get a dst memory descriptor"); - return memory::desc(*cdesc); - } - - memory::desc scratchpad_desc() const { - const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md( - get(), mkldnn::convert_to_c(scratchpad_md), 0); - if (cdesc == nullptr) - return memory::desc(); - return memory::desc(*cdesc); - } - - engine get_engine() { return engine::query(*this); } - }; - - sum(const primitive_desc &pd): primitive(pd.get()) {} -}; - -/// @} - -/// @} - -/// @addtogroup cpp_api_primitives Primitives -/// @{ - -/// @addtogroup cpp_api_primitive_descriptors Primitive descriptors -/// @{ - -/// A base class for all primitive descriptors. -struct primitive_desc : public handle { - primitive_desc(const_mkldnn_op_desc_t desc, const primitive_attr *attr, - const engine &e, const_mkldnn_primitive_desc_t hint_fwd_pd) { - mkldnn_primitive_desc_iterator_t iterator = nullptr; - mkldnn_status_t status = mkldnn_primitive_desc_iterator_create( - &iterator, desc, attr ? attr->get() : nullptr, e.get(), - hint_fwd_pd); - error::wrap_c_api(status, - "could not create a primitive descriptor iterator"); - pd_iterator.reset(iterator); - fetch_impl(); - } - - engine get_engine() { return engine::query(*this); } - - primitive_attr get_primitive_attr() const { - const_mkldnn_primitive_attr_t const_cattr; - error::wrap_c_api(mkldnn_primitive_desc_get_attr(get(), &const_cattr), - "could not get attributes"); - mkldnn_primitive_attr_t cattr; - error::wrap_c_api(mkldnn_primitive_attr_clone(&cattr, const_cattr), - "could not clone attributes"); - - primitive_attr attr; - attr.reset(cattr); - return attr; - } - - /// Returns implementation name - const char *impl_info_str() const { - const char *res; - error::wrap_c_api(mkldnn_primitive_desc_query(get(), - mkldnn_query_impl_info_str, 0, &res), - "could not query implementation info string"); - return res; - } - - /// Queries the memory::dim value (same as int64_t) - memory::dim query_s64(query q) const { - memory::dim res; - mkldnn_status_t status = mkldnn_primitive_desc_query(get(), - mkldnn::convert_to_c(q), 0, &res); - return status == mkldnn_success ? res : 0; - } - - /// Advances the next implementation for the given op descriptor. 
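The sum primitive descriptor above takes one scale per source and an explicit destination descriptor. A sketch for a weighted sum of two tensors; the element types of the two vectors (float and memory::desc) are an assumption, since the template arguments were dropped from this hunk, as is the include path:

#include "mkldnn.hpp"
#include <vector>

// dst = 1.0 * a + 0.5 * b, via sum::primitive_desc(dst, scales, srcs, engine).
mkldnn::sum::primitive_desc make_sum_pd(const mkldnn::engine &eng,
                                        const mkldnn::memory::desc &dst_md,
                                        const mkldnn::memory::desc &a_md,
                                        const mkldnn::memory::desc &b_md) {
    using namespace mkldnn;
    std::vector<float> scales = {1.0f, 0.5f};
    std::vector<memory::desc> srcs = {a_md, b_md};
    return sum::primitive_desc(dst_md, scales, srcs, eng);
}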
- /// - /// Returns: - /// - @c true on success - /// - @c false if the last implementation reached, and - /// the primitive descriptor itself is kept unchanged - bool next_impl() { - mkldnn_status_t status = mkldnn_primitive_desc_iterator_next( - pd_iterator.get()); - if (status == mkldnn_iterator_ends) return false; - error::wrap_c_api(status, "primitive descriptor iterator next failed"); - - fetch_impl(); - return true; - } - - /// Queries and returns requested memory descriptor. - memory::desc query_md(query what, int idx = 0) const { - std::vector valid_q{src_md, diff_src_md, weights_md, - diff_weights_md, dst_md, diff_dst_md, workspace_md, scratchpad_md}; - if (!std::any_of(valid_q.cbegin(), valid_q.cend(), - [=](query q) { return what == q; })) - throw error(mkldnn_invalid_arguments, "invalid memory query"); - - const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md( - get(), mkldnn::convert_to_c(what), idx); - if (cdesc == nullptr) return memory::desc(); - - return memory::desc(*cdesc); - } - - // register specialized queries, e.g. src_desc() -# define REG_QUERY_MD(name, what, idx) \ - memory::desc name ## _desc() const { return query_md(what ## _md, idx); } - - private: - handle pd_iterator; - void fetch_impl() { - mkldnn_primitive_desc_t pd = mkldnn_primitive_desc_iterator_fetch( - pd_iterator.get()); - error::wrap_c_api(pd != nullptr ? mkldnn_success : mkldnn_runtime_error, - "could not fetch a primitive descriptor from the iterator"); - reset(pd); - } -}; - -/// @} - -/// @addtogroup cpp_api_convolution Convolution -/// A primitive to compute convolution using different algorithms. -/// -/// @sa @ref c_api_convolution in @ref c_api -/// @{ - -struct convolution_forward: public primitive { - struct desc { - mkldnn_convolution_desc_t data; - desc(prop_kind aprop_kind, algorithm aalgorithm, - const memory::desc &src_desc, - const memory::desc &weights_desc, - const memory::desc &bias_desc, - const memory::desc &dst_desc, - const memory::dims strides, - const memory::dims padding_l, - const memory::dims padding_r, - const padding_kind apadding_kind) { - memory::validate_dims(strides); - memory::validate_dims(padding_l); - memory::validate_dims(padding_r); - error::wrap_c_api(mkldnn_convolution_forward_desc_init(&data, - mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), - &src_desc.data, &weights_desc.data, &bias_desc.data, - &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0], - mkldnn::convert_to_c(apadding_kind)), - "could not create a convolution forward descriptor"); - } - desc(prop_kind aprop_kind, algorithm aalgorithm, - const memory::desc &src_desc, - const memory::desc &weights_desc, - const memory::desc &dst_desc, - const memory::dims strides, - const memory::dims padding_l, - const memory::dims padding_r, - const padding_kind apadding_kind) { - memory::validate_dims(strides); - memory::validate_dims(padding_l); - memory::validate_dims(padding_r); - error::wrap_c_api(mkldnn_convolution_forward_desc_init(&data, - mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), - &src_desc.data, &weights_desc.data, nullptr, - &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0], - mkldnn::convert_to_c(apadding_kind)), - "could not create a convolution forward descriptor"); - } - desc(prop_kind aprop_kind, algorithm aalgorithm, - const memory::desc &src_desc, - const memory::desc &weights_desc, - const memory::desc &bias_desc, - const memory::desc &dst_desc, - const memory::dims strides, - const memory::dims dilates, - const memory::dims padding_l, 
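The primitive_desc base class above wraps the C iterator API: construction fetches the first implementation that matches the operation descriptor, and next_impl() advances to the next one. A small sketch that lists the available implementations via impl_info_str(), built only from members shown in this hunk:

#include "mkldnn.hpp"
#include <cstdio>

// Walk all candidate implementations for an already-built primitive
// descriptor (e.g. a convolution_forward::primitive_desc, which derives
// from this base class) and print their names.
void list_impls(mkldnn::primitive_desc &pd) {
    do {
        std::printf("impl: %s\n", pd.impl_info_str());
    } while (pd.next_impl());
}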
- const memory::dims padding_r, - const padding_kind apadding_kind) { - memory::validate_dims(strides); - memory::validate_dims(dilates); - memory::validate_dims(padding_l); - memory::validate_dims(padding_r); - error::wrap_c_api( - mkldnn_dilated_convolution_forward_desc_init(&data, - mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), - &src_desc.data, &weights_desc.data, &bias_desc.data, - &dst_desc.data, &strides[0], &dilates[0], - &padding_l[0], &padding_r[0], - mkldnn::convert_to_c(apadding_kind)), - "could not create a dilated convolution forward descriptor"); - } - desc(prop_kind aprop_kind, algorithm aalgorithm, - const memory::desc &src_desc, - const memory::desc &weights_desc, - const memory::desc &dst_desc, - const memory::dims strides, - const memory::dims dilates, - const memory::dims padding_l, - const memory::dims padding_r, - const padding_kind apadding_kind) { - memory::validate_dims(strides); - memory::validate_dims(dilates); - memory::validate_dims(padding_l); - memory::validate_dims(padding_r); - error::wrap_c_api( - mkldnn_dilated_convolution_forward_desc_init(&data, - mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), - &src_desc.data, &weights_desc.data, nullptr, - &dst_desc.data, &strides[0], &dilates[0], - &padding_l[0], &padding_r[0], - mkldnn::convert_to_c(apadding_kind)), - "could not create a dilated convolution forward descriptor"); - } - }; - - struct primitive_desc : public mkldnn::primitive_desc { - primitive_desc(const desc &desc, const engine &e) - : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} - - primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) - : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} - - REG_QUERY_MD(src, src, 0); - REG_QUERY_MD(weights, weights, 0); - REG_QUERY_MD(bias, weights, 1); - REG_QUERY_MD(dst, dst, 0); - REG_QUERY_MD(scratchpad, scratchpad, 0); - }; - - convolution_forward(const primitive_desc &pd): primitive(pd) {} -}; - -struct convolution_backward_data : public primitive { - struct desc { - mkldnn_convolution_desc_t data; - desc(algorithm aalgorithm, - const memory::desc &diff_src_desc, - const memory::desc &weights_desc, - const memory::desc &diff_dst_desc, - const memory::dims strides, - const memory::dims padding_l, - const memory::dims padding_r, - const padding_kind apadding_kind) { - memory::validate_dims(strides); - memory::validate_dims(padding_l); - memory::validate_dims(padding_r); - error::wrap_c_api(mkldnn_convolution_backward_data_desc_init( - &data, convert_to_c(aalgorithm), &diff_src_desc.data, - &weights_desc.data, &diff_dst_desc.data, - &strides[0], &padding_l[0], &padding_r[0], - mkldnn::convert_to_c(apadding_kind)), - "could not create a convolution backward data descriptor"); - } - desc(algorithm aalgorithm, - const memory::desc &diff_src_desc, - const memory::desc &weights_desc, - const memory::desc &diff_dst_desc, - const memory::dims strides, - const memory::dims dilates, - const memory::dims padding_l, - const memory::dims padding_r, - const padding_kind apadding_kind) { - memory::validate_dims(strides); - memory::validate_dims(dilates); - memory::validate_dims(padding_l); - memory::validate_dims(padding_r); - error::wrap_c_api( - mkldnn_dilated_convolution_backward_data_desc_init( - &data, convert_to_c(aalgorithm), &diff_src_desc.data, - &weights_desc.data, &diff_dst_desc.data, - &strides[0], &dilates[0], &padding_l[0], &padding_r[0], - mkldnn::convert_to_c(apadding_kind)), - "could not create a convolution backward data descriptor"); 
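convolution_forward above follows the usual chain: op descriptor, primitive descriptor, primitive, then execution with an argument map. The constructors come from this hunk; prop_kind::forward_inference, algorithm::convolution_direct, padding_kind::zero, the MKLDNN_ARG_* tags and the public primitive::execute(stream, args) overload are defined elsewhere in the removed sources and are assumptions here. A sketch for a 3x3, stride-1, pad-1 convolution:

#include "mkldnn.hpp"

void conv3x3(const mkldnn::engine &eng, mkldnn::stream &s,
             mkldnn::memory &src, mkldnn::memory &wei,
             mkldnn::memory &bia, mkldnn::memory &dst) {
    using namespace mkldnn;

    convolution_forward::desc d(
            prop_kind::forward_inference, algorithm::convolution_direct,
            src.get_desc(), wei.get_desc(), bia.get_desc(), dst.get_desc(),
            /*strides*/ {1, 1}, /*padding_l*/ {1, 1}, /*padding_r*/ {1, 1},
            padding_kind::zero);

    convolution_forward::primitive_desc pd(d, eng);
    convolution_forward conv(pd);

    // Argument tags come from the C API headers, not from this hunk.
    conv.execute(s, {{MKLDNN_ARG_SRC, src}, {MKLDNN_ARG_WEIGHTS, wei},
                     {MKLDNN_ARG_BIAS, bia}, {MKLDNN_ARG_DST, dst}});
}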
- } - }; - - struct primitive_desc : public mkldnn::primitive_desc { - primitive_desc(const desc &desc, const engine &e, - const convolution_forward::primitive_desc &hint_fwd_pd) - : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} - - primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, - const convolution_forward::primitive_desc &hint_fwd_pd) - : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} - - REG_QUERY_MD(diff_src, diff_src, 0); - REG_QUERY_MD(weights, weights, 0); - REG_QUERY_MD(diff_dst, diff_dst, 0); - REG_QUERY_MD(scratchpad, scratchpad, 0); - }; - - convolution_backward_data(const primitive_desc &pd): primitive(pd) {} -}; - -struct convolution_backward_weights : public primitive { - struct desc { - mkldnn_convolution_desc_t data; - desc(algorithm aalgorithm, - const memory::desc &src_desc, - const memory::desc &diff_weights_desc, - const memory::desc &diff_bias_desc, - const memory::desc &diff_dst_desc, - const memory::dims strides, - const memory::dims padding_l, - const memory::dims padding_r, - const padding_kind apadding_kind) { - memory::validate_dims(strides); - memory::validate_dims(padding_l); - memory::validate_dims(padding_r); - error::wrap_c_api(mkldnn_convolution_backward_weights_desc_init( - &data, convert_to_c(aalgorithm), &src_desc.data, - &diff_weights_desc.data, &diff_bias_desc.data, - &diff_dst_desc.data, - &strides[0], &padding_l[0], &padding_r[0], - mkldnn::convert_to_c(apadding_kind)), - "could not create a convolution backward weights descriptor"); - } - desc(algorithm aalgorithm, - const memory::desc &src_desc, - const memory::desc &diff_weights_desc, - const memory::desc &diff_dst_desc, - const memory::dims strides, - const memory::dims padding_l, - const memory::dims padding_r, - const padding_kind apadding_kind) { - memory::validate_dims(strides); - memory::validate_dims(padding_l); - memory::validate_dims(padding_r); - error::wrap_c_api(mkldnn_convolution_backward_weights_desc_init( - &data, convert_to_c(aalgorithm), &src_desc.data, - &diff_weights_desc.data, nullptr, &diff_dst_desc.data, - &strides[0], &padding_l[0], &padding_r[0], - mkldnn::convert_to_c(apadding_kind)), - "could not create a convolution backward weights descriptor"); - } - desc(algorithm aalgorithm, - const memory::desc &src_desc, - const memory::desc &diff_weights_desc, - const memory::desc &diff_bias_desc, - const memory::desc &diff_dst_desc, - const memory::dims strides, - const memory::dims dilates, - const memory::dims padding_l, - const memory::dims padding_r, - const padding_kind apadding_kind) { - memory::validate_dims(strides); - memory::validate_dims(dilates); - memory::validate_dims(padding_l); - memory::validate_dims(padding_r); - error::wrap_c_api(mkldnn_dilated_convolution_backward_weights_desc_init( - &data, convert_to_c(aalgorithm), &src_desc.data, - &diff_weights_desc.data, &diff_bias_desc.data, - &diff_dst_desc.data, - &strides[0], &dilates[0], &padding_l[0], &padding_r[0], - mkldnn::convert_to_c(apadding_kind)), - "could not create a convolution backward weights descriptor"); - } - desc(algorithm aalgorithm, - const memory::desc &src_desc, - const memory::desc &diff_weights_desc, - const memory::desc &diff_dst_desc, - const memory::dims strides, - const memory::dims dilates, - const memory::dims padding_l, - const memory::dims padding_r, - const padding_kind apadding_kind) { - memory::validate_dims(strides); - memory::validate_dims(dilates); - memory::validate_dims(padding_l); - 
memory::validate_dims(padding_r); - error::wrap_c_api(mkldnn_dilated_convolution_backward_weights_desc_init( - &data, convert_to_c(aalgorithm), &src_desc.data, - &diff_weights_desc.data, nullptr, &diff_dst_desc.data, - &strides[0], &dilates[0], &padding_l[0], &padding_r[0], - mkldnn::convert_to_c(apadding_kind)), - "could not create a convolution backward weights descriptor"); - } - - }; - - struct primitive_desc : public mkldnn::primitive_desc { - primitive_desc(const desc &desc, const engine &e, - const convolution_forward::primitive_desc &hint_fwd_pd) - : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} - - primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, - const convolution_forward::primitive_desc &hint_fwd_pd) - : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} - - REG_QUERY_MD(src, src, 0); - REG_QUERY_MD(diff_weights, diff_weights, 0); - REG_QUERY_MD(diff_bias, diff_weights, 1); - REG_QUERY_MD(diff_dst, diff_dst, 0); - REG_QUERY_MD(scratchpad, scratchpad, 0); - }; - - convolution_backward_weights(const primitive_desc &pd): primitive(pd) {} -}; - -/// @} -// -/// @addtogroup cpp_api_deconvolution Deconvolution -/// A primitive to compute deconvolution using different algorithms. -/// -/// @sa @ref c_api_deconvolution in @ref c_api -/// @{ - -struct deconvolution_forward: public primitive { - struct desc { - mkldnn_deconvolution_desc_t data; - desc(prop_kind aprop_kind, algorithm aalgorithm, - const memory::desc &src_desc, - const memory::desc &weights_desc, - const memory::desc &bias_desc, - const memory::desc &dst_desc, - const memory::dims strides, - const memory::dims padding_l, - const memory::dims padding_r, - const padding_kind apadding_kind) { - memory::validate_dims(strides); - memory::validate_dims(padding_l); - memory::validate_dims(padding_r); - error::wrap_c_api(mkldnn_deconvolution_forward_desc_init(&data, - mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), - &src_desc.data, &weights_desc.data, &bias_desc.data, - &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0], - mkldnn::convert_to_c(apadding_kind)), - "could not create a deconvolution forward descriptor"); - } - desc(prop_kind aprop_kind, algorithm aalgorithm, - const memory::desc &src_desc, - const memory::desc &weights_desc, - const memory::desc &dst_desc, - const memory::dims strides, - const memory::dims padding_l, - const memory::dims padding_r, - const padding_kind apadding_kind) { - memory::validate_dims(strides); - memory::validate_dims(padding_l); - memory::validate_dims(padding_r); - error::wrap_c_api(mkldnn_deconvolution_forward_desc_init(&data, - mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), - &src_desc.data, &weights_desc.data, nullptr, - &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0], - mkldnn::convert_to_c(apadding_kind)), - "could not create a deconvolution forward descriptor"); - } - desc(prop_kind aprop_kind, algorithm aalgorithm, - const memory::desc &src_desc, - const memory::desc &weights_desc, - const memory::desc &bias_desc, - const memory::desc &dst_desc, - const memory::dims strides, - const memory::dims dilates, - const memory::dims padding_l, - const memory::dims padding_r, - const padding_kind apadding_kind) { - memory::validate_dims(strides); - memory::validate_dims(dilates); - memory::validate_dims(padding_l); - memory::validate_dims(padding_r); - error::wrap_c_api(mkldnn_dilated_deconvolution_forward_desc_init(&data, - mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), 
- &src_desc.data, &weights_desc.data, &bias_desc.data, - &dst_desc.data, &strides[0], &dilates[0], &padding_l[0], - &padding_r[0], mkldnn::convert_to_c(apadding_kind)), - "could not create a dilated deconvolution forward descriptor"); - } - desc(prop_kind aprop_kind, algorithm aalgorithm, - const memory::desc &src_desc, - const memory::desc &weights_desc, - const memory::desc &dst_desc, - const memory::dims strides, - const memory::dims dilates, - const memory::dims padding_l, - const memory::dims padding_r, - const padding_kind apadding_kind) { - memory::validate_dims(strides); - memory::validate_dims(dilates); - memory::validate_dims(padding_l); - memory::validate_dims(padding_r); - error::wrap_c_api(mkldnn_dilated_deconvolution_forward_desc_init(&data, - mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), - &src_desc.data, &weights_desc.data, nullptr, - &dst_desc.data, &strides[0], &dilates[0], &padding_l[0], - &padding_r[0], mkldnn::convert_to_c(apadding_kind)), - "could not create a dilated deconvolution forward descriptor"); - } - }; - - struct primitive_desc : public mkldnn::primitive_desc { - primitive_desc(const desc &desc, const engine &e) - : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} - - primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) - : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} - - REG_QUERY_MD(src, src, 0); - REG_QUERY_MD(weights, weights, 0); - REG_QUERY_MD(bias, weights, 1); - REG_QUERY_MD(dst, dst, 0); - REG_QUERY_MD(scratchpad, scratchpad, 0); - }; - - deconvolution_forward(const primitive_desc &pd): primitive(pd) {} -}; - -struct deconvolution_backward_data : public primitive { - struct desc { - mkldnn_deconvolution_desc_t data; - desc(algorithm aalgorithm, - const memory::desc &diff_src_desc, - const memory::desc &weights_desc, - const memory::desc &diff_dst_desc, - const memory::dims strides, - const memory::dims padding_l, - const memory::dims padding_r, - const padding_kind apadding_kind) { - memory::validate_dims(strides); - memory::validate_dims(padding_l); - memory::validate_dims(padding_r); - error::wrap_c_api(mkldnn_deconvolution_backward_data_desc_init( - &data, convert_to_c(aalgorithm), &diff_src_desc.data, - &weights_desc.data, &diff_dst_desc.data, - &strides[0], &padding_l[0], &padding_r[0], - mkldnn::convert_to_c(apadding_kind)), - "could not create a deconvolution backward data descriptor"); - } - desc(algorithm aalgorithm, - const memory::desc &diff_src_desc, - const memory::desc &weights_desc, - const memory::desc &diff_dst_desc, - const memory::dims strides, - const memory::dims dilates, - const memory::dims padding_l, - const memory::dims padding_r, - const padding_kind apadding_kind) { - memory::validate_dims(strides); - memory::validate_dims(dilates); - memory::validate_dims(padding_l); - memory::validate_dims(padding_r); - error::wrap_c_api(mkldnn_dilated_deconvolution_backward_data_desc_init( - &data, convert_to_c(aalgorithm), &diff_src_desc.data, - &weights_desc.data, &diff_dst_desc.data, - &strides[0], &dilates[0], &padding_l[0], &padding_r[0], - mkldnn::convert_to_c(apadding_kind)), - "could not create a dilated deconvolution backward data descriptor"); - } - }; - - struct primitive_desc : public mkldnn::primitive_desc { - primitive_desc(const desc &desc, const engine &e, - const deconvolution_forward::primitive_desc &hint_fwd_pd) - : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} - - primitive_desc(const desc &desc, const primitive_attr &attr, const 
engine &e, - const deconvolution_forward::primitive_desc &hint_fwd_pd) - : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} - - REG_QUERY_MD(diff_src, diff_src, 0); - REG_QUERY_MD(weights, weights, 0); - REG_QUERY_MD(diff_dst, diff_dst, 0); - REG_QUERY_MD(scratchpad, scratchpad, 0); - }; - - deconvolution_backward_data(const primitive_desc &pd): primitive(pd) {} -}; - -struct deconvolution_backward_weights : public primitive { - struct desc { - mkldnn_deconvolution_desc_t data; - desc(algorithm aalgorithm, - const memory::desc &src_desc, - const memory::desc &diff_weights_desc, - const memory::desc &diff_bias_desc, - const memory::desc &diff_dst_desc, - const memory::dims strides, - const memory::dims padding_l, - const memory::dims padding_r, - const padding_kind apadding_kind) { - memory::validate_dims(strides); - memory::validate_dims(padding_l); - memory::validate_dims(padding_r); - error::wrap_c_api(mkldnn_deconvolution_backward_weights_desc_init( - &data, convert_to_c(aalgorithm), &src_desc.data, - &diff_weights_desc.data, &diff_bias_desc.data, - &diff_dst_desc.data, - &strides[0], &padding_l[0], &padding_r[0], - mkldnn::convert_to_c(apadding_kind)), - "could not create a deconvolution backward weights descriptor"); - } - desc(algorithm aalgorithm, - const memory::desc &src_desc, - const memory::desc &diff_weights_desc, - const memory::desc &diff_dst_desc, - const memory::dims strides, - const memory::dims padding_l, - const memory::dims padding_r, - const padding_kind apadding_kind) { - memory::validate_dims(strides); - memory::validate_dims(padding_l); - memory::validate_dims(padding_r); - error::wrap_c_api(mkldnn_deconvolution_backward_weights_desc_init( - &data, convert_to_c(aalgorithm), &src_desc.data, - &diff_weights_desc.data, nullptr, &diff_dst_desc.data, - &strides[0], &padding_l[0], &padding_r[0], - mkldnn::convert_to_c(apadding_kind)), - "could not create a deconvolution backward weights descriptor"); - } - desc(algorithm aalgorithm, - const memory::desc &src_desc, - const memory::desc &diff_weights_desc, - const memory::desc &diff_bias_desc, - const memory::desc &diff_dst_desc, - const memory::dims strides, - const memory::dims dilates, - const memory::dims padding_l, - const memory::dims padding_r, - const padding_kind apadding_kind) { - memory::validate_dims(strides); - memory::validate_dims(dilates); - memory::validate_dims(padding_l); - memory::validate_dims(padding_r); - error::wrap_c_api(mkldnn_dilated_deconvolution_backward_weights_desc_init( - &data, convert_to_c(aalgorithm), &src_desc.data, - &diff_weights_desc.data, &diff_bias_desc.data, - &diff_dst_desc.data, - &strides[0], &dilates[0], &padding_l[0], &padding_r[0], - mkldnn::convert_to_c(apadding_kind)), - "could not create a dilated deconvolution backward weights descriptor"); - } - desc(algorithm aalgorithm, - const memory::desc &src_desc, - const memory::desc &diff_weights_desc, - const memory::desc &diff_dst_desc, - const memory::dims strides, - const memory::dims dilates, - const memory::dims padding_l, - const memory::dims padding_r, - const padding_kind apadding_kind) { - memory::validate_dims(strides); - memory::validate_dims(dilates); - memory::validate_dims(padding_l); - memory::validate_dims(padding_r); - error::wrap_c_api(mkldnn_dilated_deconvolution_backward_weights_desc_init( - &data, convert_to_c(aalgorithm), &src_desc.data, - &diff_weights_desc.data, nullptr, &diff_dst_desc.data, - &strides[0], &dilates[0], &padding_l[0], &padding_r[0], - mkldnn::convert_to_c(apadding_kind)), 
- "could not create a dilated deconvolution backward weights descriptor"); - } - }; - - struct primitive_desc : public mkldnn::primitive_desc { - primitive_desc(const desc &desc, const engine &e, - const deconvolution_forward::primitive_desc &hint_fwd_pd) - : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} - - primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, - const deconvolution_forward::primitive_desc &hint_fwd_pd) - : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} - - REG_QUERY_MD(src, src, 0); - REG_QUERY_MD(diff_weights, diff_weights, 0); - REG_QUERY_MD(diff_bias, diff_weights, 1); - REG_QUERY_MD(diff_dst, diff_dst, 0); - REG_QUERY_MD(scratchpad, scratchpad, 0); - }; - - deconvolution_backward_weights(const primitive_desc &pd): primitive(pd) {} -}; - -/// @} - -/// @addtogroup cpp_api_lrn LRN -/// A primitive to perform local response normalization (LRN) across or within -/// channels. -/// -/// @sa @ref c_api_lrn in @ref c_api -/// @{ - -struct lrn_forward : public primitive { - struct desc { - mkldnn_lrn_desc_t data; - - desc(prop_kind aprop_kind, algorithm aalgorithm, - const memory::desc &src_desc, memory::dim local_size, - float alpha, float beta, float k = 1.f) { - error::wrap_c_api(mkldnn_lrn_forward_desc_init(&data, - mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), - &src_desc.data, local_size, alpha, beta, k), - "could not create a lrn forward descriptor"); - } - }; - - struct primitive_desc : public mkldnn::primitive_desc { - primitive_desc(const desc &desc, const engine &e) - : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} - - primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) - : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} - - REG_QUERY_MD(src, src, 0); - REG_QUERY_MD(dst, dst, 0); - REG_QUERY_MD(workspace, workspace, 0); - REG_QUERY_MD(scratchpad, scratchpad, 0); - }; - - lrn_forward(const primitive_desc &pd): primitive(pd) {} -}; - -struct lrn_backward : public primitive { - struct desc { - mkldnn_lrn_desc_t data; - - desc(algorithm aalgorithm, const memory::desc &data_desc, - const memory::desc &diff_data_desc, memory::dim local_size, - float alpha, float beta, float k = 1.f) { - error::wrap_c_api(mkldnn_lrn_backward_desc_init(&data, - convert_to_c(aalgorithm), &diff_data_desc.data, - &data_desc.data, local_size, alpha, beta, k), - "could not create a lrn backward descriptor"); - } - }; - - struct primitive_desc : public mkldnn::primitive_desc { - primitive_desc(const desc &desc, const engine &e, - const lrn_forward::primitive_desc &hint_fwd_pd) - : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} - - primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, - const lrn_forward::primitive_desc &hint_fwd_pd) - : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} - - REG_QUERY_MD(diff_src, diff_src, 0); - REG_QUERY_MD(diff_dst, diff_dst, 0); - REG_QUERY_MD(workspace, workspace, 0); - REG_QUERY_MD(scratchpad, scratchpad, 0); - }; - - lrn_backward(const primitive_desc &pd): primitive(pd) {} -}; - -/// @} - -/// @addtogroup cpp_api_pooling Pooling -/// A primitive to perform max or average pooling. 
-/// -/// @sa @ref c_api_pooling in @ref c_api -/// @{ - -struct pooling_forward : public primitive { - struct desc { - mkldnn_pooling_desc_t data; - desc(prop_kind aprop_kind, algorithm aalgorithm, - const memory::desc &src_desc, - const memory::desc &dst_desc, - const memory::dims strides, - const memory::dims kernel, - const memory::dims padding_l, - const memory::dims padding_r, - const padding_kind apadding_kind) { - memory::validate_dims(strides); - memory::validate_dims(kernel); - memory::validate_dims(padding_l); - memory::validate_dims(padding_r); - error::wrap_c_api(mkldnn_pooling_forward_desc_init(&data, - mkldnn::convert_to_c(aprop_kind), - convert_to_c(aalgorithm), - &src_desc.data, &dst_desc.data, - &strides[0], &kernel[0], - &padding_l[0], &padding_r[0], - mkldnn::convert_to_c(apadding_kind)), - "could not init a forward pooling descriptor"); - } - }; - - struct primitive_desc : public mkldnn::primitive_desc { - primitive_desc(const desc &desc, const engine &e) - : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} - - primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) - : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} - - REG_QUERY_MD(src, src, 0); - REG_QUERY_MD(dst, dst, 0); - REG_QUERY_MD(workspace, workspace, 0); - REG_QUERY_MD(scratchpad, scratchpad, 0); - }; - - pooling_forward(const primitive_desc &pd): primitive(pd) {} -}; - -struct pooling_backward : public primitive { - struct desc { - mkldnn_pooling_desc_t data; - desc(algorithm aalgorithm, - const memory::desc &diff_src_desc, - const memory::desc &diff_dst_desc, - const memory::dims &strides, - const memory::dims &kernel, - const memory::dims &padding_l, - const memory::dims &padding_r, - const padding_kind apadding_kind) { - memory::validate_dims(strides); - memory::validate_dims(kernel); - memory::validate_dims(padding_l); - memory::validate_dims(padding_r); - error::wrap_c_api(mkldnn_pooling_backward_desc_init(&data, - convert_to_c(aalgorithm), - &diff_src_desc.data, &diff_dst_desc.data, - &strides[0], &kernel[0], - &padding_l[0], &padding_r[0], - mkldnn::convert_to_c(apadding_kind)), - "could not init a backward pooling descriptor"); - } - }; - - struct primitive_desc : public mkldnn::primitive_desc { - primitive_desc(const desc &desc, const engine &e, - const pooling_forward::primitive_desc &hint_fwd_pd) - : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} - - primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, - const pooling_forward::primitive_desc &hint_fwd_pd) - : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} - - REG_QUERY_MD(diff_src, diff_src, 0); - REG_QUERY_MD(diff_dst, diff_dst, 0); - REG_QUERY_MD(workspace, workspace, 0); - REG_QUERY_MD(scratchpad, scratchpad, 0); - }; - - pooling_backward(const primitive_desc &pd): primitive(pd) {} -}; - -/// @} - -/// @addtogroup cpp_api_eltwise Eltwise -/// A primitive to compute element-wise operations like parametric rectifier -/// linear unit (ReLU). 
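pooling_forward above requires both the source and destination descriptors plus spatial strides, kernel and padding. A sketch for 2x2 max pooling with stride 2; prop_kind::forward_inference, algorithm::pooling_max and padding_kind::zero are assumed enum values from earlier in the header:

#include "mkldnn.hpp"

mkldnn::pooling_forward::primitive_desc make_pool_pd(
        const mkldnn::engine &eng,
        const mkldnn::memory::desc &src_md,
        const mkldnn::memory::desc &dst_md) {
    using namespace mkldnn;
    pooling_forward::desc d(prop_kind::forward_inference,
                            algorithm::pooling_max, src_md, dst_md,
                            /*strides*/ {2, 2}, /*kernel*/ {2, 2},
                            /*padding_l*/ {0, 0}, /*padding_r*/ {0, 0},
                            padding_kind::zero);
    return pooling_forward::primitive_desc(d, eng);
}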
-/// -/// @sa @ref c_api_eltwise in @ref c_api -/// @{ - -struct eltwise_forward : public primitive { - struct desc { - mkldnn_eltwise_desc_t data; - template - desc(prop_kind aprop_kind, algorithm alg_kind, - const memory::desc &src_desc, T alpha = 0, T beta = 0) { - error::wrap_c_api(mkldnn_eltwise_forward_desc_init(&data, - mkldnn::convert_to_c(aprop_kind), - mkldnn::convert_to_c(alg_kind), &src_desc.data, - static_cast(alpha), static_cast(beta)), - "could not create a eltwise forward descriptor"); - } - }; - - struct primitive_desc : public mkldnn::primitive_desc { - primitive_desc(const desc &desc, const engine &e) - : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} - - primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) - : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} - - REG_QUERY_MD(src, src, 0); - REG_QUERY_MD(dst, dst, 0); - REG_QUERY_MD(scratchpad, scratchpad, 0); - }; - - eltwise_forward(const primitive_desc &pd): primitive(pd) {} -}; - -struct eltwise_backward : public primitive { - struct desc { - mkldnn_eltwise_desc_t data; - - template - desc(algorithm alg_kind, const memory::desc &diff_data_desc, - const memory::desc &data_desc, T alpha = 0, T beta = 0) { - error::wrap_c_api(mkldnn_eltwise_backward_desc_init(&data, - mkldnn::convert_to_c(alg_kind), &diff_data_desc.data, - &data_desc.data, static_cast(alpha), - static_cast(beta)), - "could not create a eltwise backward descriptor"); - } - }; - - struct primitive_desc : public mkldnn::primitive_desc { - primitive_desc(const desc &desc, const engine &e, - const eltwise_forward::primitive_desc &hint_fwd_pd) - : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} - - primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, - const eltwise_forward::primitive_desc &hint_fwd_pd) - : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} - - REG_QUERY_MD(src, src, 0); - REG_QUERY_MD(diff_src, diff_src, 0); - REG_QUERY_MD(diff_dst, diff_dst, 0); - REG_QUERY_MD(scratchpad, scratchpad, 0); - }; - - eltwise_backward(const primitive_desc &pd): primitive(pd) {} -}; - -/// @} - -/// @addtogroup cpp_api_softmax Softmax -/// A primitive to perform softmax. 
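eltwise_forward above covers ReLU and the other element-wise activations; alpha and beta parameterize the chosen algorithm (for ReLU, alpha is the negative slope). A sketch; prop_kind::forward_inference and algorithm::eltwise_relu are assumed enum values:

#include "mkldnn.hpp"

mkldnn::eltwise_forward::primitive_desc make_relu_pd(
        const mkldnn::engine &eng, const mkldnn::memory::desc &src_md) {
    using namespace mkldnn;
    eltwise_forward::desc d(prop_kind::forward_inference,
                            algorithm::eltwise_relu, src_md,
                            /*alpha (negative slope)*/ 0.0f, /*beta*/ 0.0f);
    return eltwise_forward::primitive_desc(d, eng);
}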
-/// -/// @sa @ref c_api_softmax in @ref c_api -/// @{ - -struct softmax_forward : public primitive { - struct desc { - mkldnn_softmax_desc_t data; - desc(prop_kind aprop_kind, const memory::desc &data_desc, - int softmax_axis) { - error::wrap_c_api(mkldnn_softmax_forward_desc_init(&data, - mkldnn::convert_to_c(aprop_kind), &data_desc.data, - softmax_axis), - "could not create a softmax forward descriptor"); - } - }; - - struct primitive_desc : public mkldnn::primitive_desc { - primitive_desc(const desc &desc, const engine &e) - : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} - - primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) - : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} - - REG_QUERY_MD(src, src, 0); - REG_QUERY_MD(dst, dst, 0); - REG_QUERY_MD(scratchpad, scratchpad, 0); - }; - - softmax_forward(const primitive_desc &pd): primitive(pd) {} -}; - -struct softmax_backward : public primitive { - struct desc { - mkldnn_softmax_desc_t data; - desc(const memory::desc &diff_desc, const memory::desc &data_desc, - int softmax_axis) { - error::wrap_c_api(mkldnn_softmax_backward_desc_init(&data, - &diff_desc.data, &data_desc.data, softmax_axis), - "could not init a backward softmax descriptor"); - } - }; - - struct primitive_desc : public mkldnn::primitive_desc { - primitive_desc(const desc &desc, const engine &e, - const softmax_forward::primitive_desc &hint_fwd_pd) - : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} - - primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, - const softmax_forward::primitive_desc &hint_fwd_pd) - : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} - - REG_QUERY_MD(dst, dst, 0); - REG_QUERY_MD(diff_src, diff_src, 0); - REG_QUERY_MD(diff_dst, diff_dst, 0); - REG_QUERY_MD(workspace, workspace, 0); - REG_QUERY_MD(scratchpad, scratchpad, 0); - }; - - softmax_backward(const primitive_desc &pd): primitive(pd) {} -}; - -/// @} - -/// @addtogroup cpp_api_batch_norm Batch normalization -/// A primitive to perform batch normalization. 
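softmax_forward above only needs the data descriptor and the axis to normalize over. A sketch for softmax across the channel axis of an nc tensor; prop_kind::forward_inference is an assumed enum value:

#include "mkldnn.hpp"

mkldnn::softmax_forward::primitive_desc make_softmax_pd(
        const mkldnn::engine &eng, const mkldnn::memory::desc &data_md) {
    using namespace mkldnn;
    softmax_forward::desc d(prop_kind::forward_inference, data_md,
                            /*softmax_axis*/ 1);
    return softmax_forward::primitive_desc(d, eng);
}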
-/// -/// @sa @ref c_api_batch_normalization in @ref c_api -/// @{ - -struct batch_normalization_forward : public primitive { - struct desc { - mkldnn_batch_normalization_desc_t data; - template - desc(prop_kind aprop_kind, const memory::desc &src_desc, T epsilon, - unsigned flags) { - error::wrap_c_api( - mkldnn_batch_normalization_forward_desc_init(&data, - mkldnn::convert_to_c(aprop_kind), &src_desc.data, - static_cast(epsilon), flags), - "could not create a batch normalization forward descriptor"); - } - }; - - struct primitive_desc : public mkldnn::primitive_desc { - primitive_desc(const desc &desc, const engine &e) - : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} - - primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) - : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} - - REG_QUERY_MD(src, src, 0); - REG_QUERY_MD(weights, weights, 0); - REG_QUERY_MD(dst, dst, 0); - REG_QUERY_MD(workspace, workspace, 0); - REG_QUERY_MD(scratchpad, scratchpad, 0); - - memory::desc mean_desc() const { return stat_desc(mean); } - memory::desc variance_desc() const { return stat_desc(var); } - - private: - enum { mean = 1, var = 2, }; - memory::desc stat_desc(int kind) const { - mkldnn_batch_normalization_desc_t *p; - error::wrap_c_api(mkldnn_primitive_desc_query( - get(), mkldnn::convert_to_c(batch_normalization_d), 0, &p), - "could not get a batch-normalization descriptor"); - return query_md(p->flags & use_global_stats ? src_md : dst_md, kind); - } - }; - - batch_normalization_forward(const primitive_desc &pd): primitive(pd) {} -}; - -struct batch_normalization_backward : public primitive { - struct desc { - mkldnn_batch_normalization_desc_t data; - template - desc(prop_kind aprop_kind, const memory::desc &diff_data_desc, - const memory::desc &data_desc, T epsilon, unsigned flags) { - error::wrap_c_api( - mkldnn_batch_normalization_backward_desc_init(&data, - mkldnn::convert_to_c(aprop_kind), - &diff_data_desc.data, &data_desc.data, - static_cast(epsilon), flags), - "could not create a batch normalization backward descriptor"); - } - }; - - struct primitive_desc : public mkldnn::primitive_desc { - primitive_desc(const desc &desc, const engine &e, - const batch_normalization_forward::primitive_desc &hint_fwd_pd) - : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} - - primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, - const batch_normalization_forward::primitive_desc &hint_fwd_pd) - : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} - - REG_QUERY_MD(src, src, 0); - REG_QUERY_MD(mean, src, 1); - REG_QUERY_MD(variance, src, 2); - REG_QUERY_MD(weights, weights, 0); - REG_QUERY_MD(dst, dst, 0); - REG_QUERY_MD(diff_dst, diff_dst, 0); - REG_QUERY_MD(workspace, workspace, 0); - - REG_QUERY_MD(diff_src, diff_src, 0); - REG_QUERY_MD(diff_weights, diff_weights, 0); - REG_QUERY_MD(scratchpad, scratchpad, 0); - }; - - batch_normalization_backward(const primitive_desc &pd): primitive(pd) {} -}; - -/// @} - -/// @addtogroup cpp_api_inner_product Inner Product -/// A primitive to compute an inner product. 
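batch_normalization_forward above takes epsilon and a flags bitmask that selects, for example, whether precomputed statistics and a scale/shift tensor are used. A sketch for inference; only use_global_stats is referenced in this hunk, so use_scale_shift and the exact flag spelling are assumptions:

#include "mkldnn.hpp"

mkldnn::batch_normalization_forward::primitive_desc make_bnorm_pd(
        const mkldnn::engine &eng, const mkldnn::memory::desc &src_md) {
    using namespace mkldnn;
    batch_normalization_forward::desc d(
            prop_kind::forward_inference, src_md, /*epsilon*/ 1e-5f,
            use_global_stats | use_scale_shift);  // flag names are assumed
    return batch_normalization_forward::primitive_desc(d, eng);
}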
-/// -/// @sa @ref c_api_inner_product in @ref c_api -/// @{ - -struct inner_product_forward: public primitive { - struct desc { - mkldnn_inner_product_desc_t data; - desc(prop_kind aprop_kind, const memory::desc &src_desc, - const memory::desc &weights_desc, - const memory::desc &bias_desc, - const memory::desc &dst_desc) { - error::wrap_c_api( - mkldnn_inner_product_forward_desc_init(&data, - mkldnn::convert_to_c(aprop_kind), &src_desc.data, - &weights_desc.data, &bias_desc.data, &dst_desc.data), - "could not create a inner product forward descriptor"); - } - - desc(prop_kind aprop_kind, const memory::desc &src_desc, - const memory::desc &weights_desc, - const memory::desc &dst_desc) { - error::wrap_c_api( - mkldnn_inner_product_forward_desc_init(&data, - mkldnn::convert_to_c(aprop_kind), &src_desc.data, - &weights_desc.data, nullptr, &dst_desc.data), - "could not create a inner product forward descriptor"); - } - }; - - struct primitive_desc : public mkldnn::primitive_desc { - primitive_desc(const desc &desc, const engine &e) - : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} - - primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) - : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} - - REG_QUERY_MD(src, src, 0); - REG_QUERY_MD(weights, weights, 0); - REG_QUERY_MD(bias, weights, 1); - REG_QUERY_MD(dst, dst, 0); - REG_QUERY_MD(scratchpad, scratchpad, 0); - }; - - inner_product_forward(const primitive_desc &pd): primitive(pd) {} -}; - -struct inner_product_backward_data: public primitive { - struct desc { - mkldnn_inner_product_desc_t data; - desc(const memory::desc &diff_src_desc, - const memory::desc &weights_desc, - const memory::desc &diff_dst_desc) { - error::wrap_c_api( - mkldnn_inner_product_backward_data_desc_init(&data, - &diff_src_desc.data, &weights_desc.data, - &diff_dst_desc.data), - "could not create a inner product backward data descriptor"); - } - }; - - struct primitive_desc : public mkldnn::primitive_desc { - primitive_desc(const desc &desc, const engine &e, - const inner_product_forward::primitive_desc &hint_fwd_pd) - : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} - - primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, - const inner_product_forward::primitive_desc &hint_fwd_pd) - : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} - - REG_QUERY_MD(diff_src, diff_src, 0); - REG_QUERY_MD(weights, weights, 0); - REG_QUERY_MD(diff_dst, diff_dst, 0); - REG_QUERY_MD(scratchpad, scratchpad, 0); - }; - - inner_product_backward_data(const primitive_desc &pd): primitive(pd) {} -}; - -struct inner_product_backward_weights: public primitive { - struct desc { - mkldnn_inner_product_desc_t data; - desc(const memory::desc &src_desc, - const memory::desc &diff_weights_desc, - const memory::desc &diff_bias_desc, - const memory::desc &diff_dst_desc) { - error::wrap_c_api( - mkldnn_inner_product_backward_weights_desc_init( - &data, &src_desc.data, &diff_weights_desc.data, - &diff_bias_desc.data, &diff_dst_desc.data), - "could not create a inner product backward weights descriptor"); - } - desc(const memory::desc &src_desc, - const memory::desc &diff_weights_desc, - const memory::desc &diff_dst_desc) { - error::wrap_c_api( - mkldnn_inner_product_backward_weights_desc_init( - &data, &src_desc.data, &diff_weights_desc.data, - nullptr, &diff_dst_desc.data), - "could not create a inner product backward weights descriptor"); - } - }; - - struct primitive_desc : public 
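inner_product_forward above (a fully connected layer) needs only the memory descriptors, with or without bias. A sketch; prop_kind::forward_inference is an assumed enum value:

#include "mkldnn.hpp"

mkldnn::inner_product_forward::primitive_desc make_fc_pd(
        const mkldnn::engine &eng,
        const mkldnn::memory::desc &src_md,
        const mkldnn::memory::desc &wei_md,
        const mkldnn::memory::desc &bia_md,
        const mkldnn::memory::desc &dst_md) {
    using namespace mkldnn;
    inner_product_forward::desc d(prop_kind::forward_inference,
                                  src_md, wei_md, bia_md, dst_md);
    return inner_product_forward::primitive_desc(d, eng);
}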
mkldnn::primitive_desc { - primitive_desc(const desc &desc, const engine &e, - const inner_product_forward::primitive_desc &hint_fwd_pd) - : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} - - primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, - const inner_product_forward::primitive_desc &hint_fwd_pd) - : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} - - REG_QUERY_MD(src, src, 0); - REG_QUERY_MD(diff_weights, diff_weights, 0); - REG_QUERY_MD(diff_bias, diff_weights, 1); - REG_QUERY_MD(diff_dst, diff_dst, 0); - REG_QUERY_MD(scratchpad, scratchpad, 0); - }; - - inner_product_backward_weights(const primitive_desc &pd): primitive(pd) {} -}; - -/// @} - -/// @addtogroup cpp_api_rnn RNN -/// A primitive to compute common recurrent layer. -/// -/// @sa @ref c_api_rnn in @ref c_api -/// @{ - -struct rnn_cell { - struct desc { - mkldnn_rnn_cell_desc_t c_rnn_cell_; - - desc(algorithm kind, algorithm activation_f) { - error::wrap_c_api(mkldnn_rnn_cell_desc_init(&c_rnn_cell_, - mkldnn::convert_to_c(kind), - mkldnn::convert_to_c(activation_f), 0U, 0, 0), - "could not init an rnn cell descriptor"); - } - desc(algorithm kind): desc(kind, algorithm::algorithm_undef) {} - - operator const mkldnn_rnn_cell_desc_t*() const { return &c_rnn_cell_; } - - algorithm get_cell_kind() const - { return algorithm(c_rnn_cell_.cell_kind); } - algorithm get_activation() const - { return algorithm(c_rnn_cell_.activation_kind); } - - float get_alpha() const { return c_rnn_cell_.alpha; } - void set_alpha(float alpha) { - c_rnn_cell_.flags |= mkldnn_rnn_cell_with_relu; - c_rnn_cell_.alpha = alpha; - } - - float get_clipping() const { return c_rnn_cell_.clipping; } - void set_clipping(float clipping) { - c_rnn_cell_.flags |= mkldnn_rnn_cell_with_clipping; - c_rnn_cell_.clipping = clipping; - } - - int get_gates_count() const { - return mkldnn_rnn_cell_get_gates_count(&c_rnn_cell_); - } - int get_state_count() const { - return mkldnn_rnn_cell_get_states_count(&c_rnn_cell_); - } - }; -}; - -struct rnn_forward : public primitive { - struct desc { - mkldnn_rnn_desc_t data; - desc(prop_kind aprop_kind, rnn_cell::desc cell, - const rnn_direction direction, - const memory::desc &src_layer_desc, - const memory::desc &src_iter_desc, - const memory::desc &weights_layer_desc, - const memory::desc &weights_iter_desc, - const memory::desc &bias_desc, - const memory::desc &dst_layer_desc, - const memory::desc &dst_iter_desc - ) { - error::wrap_c_api(mkldnn_rnn_forward_desc_init(&data, - mkldnn::convert_to_c(aprop_kind), cell, - mkldnn::convert_to_c(direction), - &src_layer_desc.data, &src_iter_desc.data, - &weights_layer_desc.data, &weights_iter_desc.data, - &bias_desc.data, - &dst_layer_desc.data, &dst_iter_desc.data), - "could not create an RNN forward descriptor"); - } - - }; - - struct primitive_desc : public mkldnn::primitive_desc { - primitive_desc(const desc &desc, const engine &e) - : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} - - primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) - : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} - - REG_QUERY_MD(src_layer, src, 0); - REG_QUERY_MD(src_iter, src, 1); - REG_QUERY_MD(weights_layer, weights, 0); - REG_QUERY_MD(weights_iter, weights, 1); - REG_QUERY_MD(bias, weights, 2); - REG_QUERY_MD(dst_layer, dst, 0); - REG_QUERY_MD(dst_iter, dst, 1); - REG_QUERY_MD(workspace, workspace, 0); - REG_QUERY_MD(scratchpad, scratchpad, 0); - }; - - rnn_forward(const primitive_desc 
&pd): primitive(pd) {} -}; - -struct rnn_backward : public primitive { - struct desc { - mkldnn_rnn_desc_t data; - desc(prop_kind aprop_kind, rnn_cell::desc cell, - const rnn_direction direction, - const memory::desc &src_layer_desc, - const memory::desc &src_iter_desc, - const memory::desc &weights_layer_desc, - const memory::desc &weights_iter_desc, - const memory::desc &bias_desc, - const memory::desc &dst_layer_desc, - const memory::desc &dst_iter_desc, - const memory::desc &diff_src_layer_desc, - const memory::desc &diff_src_iter_desc, - const memory::desc &diff_weights_layer_desc, - const memory::desc &diff_weights_iter_desc, - const memory::desc &diff_bias_desc, - const memory::desc &diff_dst_layer_desc, - const memory::desc &diff_dst_iter_desc) { - error::wrap_c_api(mkldnn_rnn_backward_desc_init(&data, - mkldnn::convert_to_c(aprop_kind), cell, - mkldnn::convert_to_c(direction), - &src_layer_desc.data, &src_iter_desc.data, - &weights_layer_desc.data, &weights_iter_desc.data, - &bias_desc.data, - &dst_layer_desc.data, &dst_iter_desc.data, - &diff_src_layer_desc.data, &diff_src_iter_desc.data, - &diff_weights_layer_desc.data, - &diff_weights_iter_desc.data, &diff_bias_desc.data, - &diff_dst_layer_desc.data, &diff_dst_iter_desc.data), - "could not create an RNN backward descriptor"); - } - - }; - - struct primitive_desc : public mkldnn::primitive_desc { - primitive_desc(const desc &desc, const engine &e, - const rnn_forward::primitive_desc &hint_fwd_pd) - : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} - - primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, - const rnn_forward::primitive_desc &hint_fwd_pd) - : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} - - REG_QUERY_MD(src_layer, src, 0); - REG_QUERY_MD(src_iter, src, 1); - REG_QUERY_MD(weights_layer, weights, 0); - REG_QUERY_MD(weights_iter, weights, 1); - REG_QUERY_MD(bias, weights, 2); - REG_QUERY_MD(dst_layer, dst, 0); - REG_QUERY_MD(dst_iter, dst, 1); - REG_QUERY_MD(workspace, workspace, 0); - - REG_QUERY_MD(diff_src_layer, diff_src, 0); - REG_QUERY_MD(diff_src_iter, diff_src, 1); - REG_QUERY_MD(diff_weights_layer, diff_weights, 0); - REG_QUERY_MD(diff_weights_iter, diff_weights, 1); - REG_QUERY_MD(diff_bias, diff_weights, 2); - REG_QUERY_MD(diff_dst_layer, diff_dst, 0); - REG_QUERY_MD(diff_dst_iter, diff_dst, 1); - REG_QUERY_MD(scratchpad, scratchpad, 0); - }; - - // With last iteration (with and without input src_iter) - rnn_backward(const primitive_desc &pd): primitive(pd) {} -}; - -/// @} - -/// @addtogroup cpp_api_shuffle Shuffle -/// A primitive to shuffle data along the axis. 
-/// -/// @sa @ref c_api_shuffle in @ref c_api -/// @{ - -struct shuffle_forward : public primitive { - struct desc { - mkldnn_shuffle_desc_t data; - desc(prop_kind aprop_kind, const memory::desc &data_desc, - int axis, int group_size) { - error::wrap_c_api(mkldnn_shuffle_forward_desc_init(&data, - mkldnn::convert_to_c(aprop_kind), &data_desc.data, - axis, group_size), - "could not create a shuffle forward descriptor"); - } - }; - - struct primitive_desc : public mkldnn::primitive_desc { - primitive_desc(const desc &desc, const engine &e) - : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} - - REG_QUERY_MD(src, src, 0); - REG_QUERY_MD(dst, dst, 0); - REG_QUERY_MD(scratchpad, scratchpad, 0); - }; - - shuffle_forward(const primitive_desc &pd): primitive(pd) {} -}; - -struct shuffle_backward : public primitive { - struct desc { - mkldnn_shuffle_desc_t data; - desc(const memory::desc &diff_data_desc, int axis, int group_size) { - error::wrap_c_api(mkldnn_shuffle_backward_desc_init(&data, - &diff_data_desc.data, axis, group_size), - "could not create a shuffle backward descriptor"); - } - }; - - struct primitive_desc : public mkldnn::primitive_desc { - primitive_desc(const desc &desc, const engine &e, - const shuffle_forward::primitive_desc &hint_fwd_pd) - : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} - - REG_QUERY_MD(diff_src, diff_src, 0); - REG_QUERY_MD(diff_dst, diff_dst, 0); - REG_QUERY_MD(scratchpad, scratchpad, 0); - }; - - shuffle_backward(const primitive_desc &pd): primitive(pd) {} -}; - -/// @} - -/// @} Primitives - -/// @} C++ API - -#undef REG_QUERY_MD - -// implementation section -#ifndef DOXYGEN_SHOULD_SKIP_THIS - -inline primitive::primitive(const_mkldnn_primitive_desc_t c_pd) { - mkldnn_primitive_t result; - error::wrap_c_api(mkldnn_primitive_create(&result, c_pd), - "could not create a primitive"); - reset(result); -} - -inline primitive::primitive(const primitive_desc &pd): primitive(pd.get()) {} - -inline void primitive::execute(stream &astream, - const std::unordered_map &args) const { - std::vector c_args; - c_args.reserve(args.size()); - for (const auto &a: args) - c_args.push_back({a.first, a.second.get()}); - - error::wrap_c_api(mkldnn_primitive_execute(get(), astream.get(), - (int)c_args.size(), c_args.data()), - "primitive execution fail"); -} -#endif // DOXYGEN_SHOULD_SKIP_THIS - -} // namespace mkldnn - -#endif diff --git a/thirdparty/oidn/mkl-dnn/include/mkldnn_debug.h b/thirdparty/oidn/mkl-dnn/include/mkldnn_debug.h deleted file mode 100644 index f4dc2fdfa..000000000 --- a/thirdparty/oidn/mkl-dnn/include/mkldnn_debug.h +++ /dev/null @@ -1,98 +0,0 @@ -/******************************************************************************* -* Copyright 2018-2019 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -/* DO NOT EDIT, AUTO-GENERATED */ - -#ifndef MKLDNN_DEBUG_H -#define MKLDNN_DEBUG_H - -#ifndef DOXYGEN_SHOULD_SKIP_THIS - -/* All symbols shall be internal unless marked as MKLDNN_API */ -#if defined _WIN32 || defined __CYGWIN__ -# define MKLDNN_HELPER_DLL_IMPORT __declspec(dllimport) -# define MKLDNN_HELPER_DLL_EXPORT __declspec(dllexport) -#else -# if __GNUC__ >= 4 -# define MKLDNN_HELPER_DLL_IMPORT __attribute__ ((visibility ("default"))) -# define MKLDNN_HELPER_DLL_EXPORT __attribute__ ((visibility ("default"))) -# else -# define MKLDNN_HELPER_DLL_IMPORT -# define MKLDNN_HELPER_DLL_EXPORT -# endif -#endif - -#ifdef MKLDNN_DLL -# ifdef MKLDNN_DLL_EXPORTS -# define MKLDNN_API MKLDNN_HELPER_DLL_EXPORT -# else -# define MKLDNN_API MKLDNN_HELPER_DLL_IMPORT -# endif -#else -# define MKLDNN_API -#endif - -#if defined (__GNUC__) -# define MKLDNN_DEPRECATED __attribute__((deprecated)) -#elif defined(_MSC_VER) -# define MKLDNN_DEPRECATED __declspec(deprecated) -#else -# define MKLDNN_DEPRECATED -#endif - -#include "mkldnn_types.h" -#endif /* DOXYGEN_SHOULD_SKIP_THIS */ - -#ifdef __cplusplus -extern "C" { -#endif - -const char MKLDNN_API *mkldnn_status2str(mkldnn_status_t v); -const char MKLDNN_API *mkldnn_dt2str(mkldnn_data_type_t v); -const char MKLDNN_API *mkldnn_fmt_kind2str(mkldnn_format_kind_t v); -const char MKLDNN_API *mkldnn_fmt_tag2str(mkldnn_format_tag_t v); -const char MKLDNN_API *mkldnn_prop_kind2str(mkldnn_prop_kind_t v); -const char MKLDNN_API *mkldnn_prim_kind2str(mkldnn_primitive_kind_t v); -const char MKLDNN_API *mkldnn_alg_kind2str(mkldnn_alg_kind_t v); -const char MKLDNN_API *mkldnn_rnn_direction2str(mkldnn_rnn_direction_t v); - -/** Forms a format string for a given memory descriptor. - * - * The format is defined as: 'dt:[p|o|0]:fmt_kind:fmt:extra'. - * Here: - * - dt -- data type - * - p -- indicates there is non-trivial padding - * - o -- indicates there is non-trivial padding offset - * - 0 -- indicates there is non-trivial offset0 - * - fmt_kind -- format kind (blocked, wino, etc...) - * - fmt -- extended format string (format_kind specific) - * - extra -- shows extra fields (underspecified) - */ -int MKLDNN_API mkldnn_md2fmt_str(char *fmt_str, size_t fmt_str_len, - const mkldnn_memory_desc_t *md); - -/** Forms a dimension string for a given memory descriptor. - * - * The format is defined as: 'dim0xdim1x...xdimN - */ -int MKLDNN_API mkldnn_md2dim_str(char *dim_str, size_t dim_str_len, - const mkldnn_memory_desc_t *md); - -#ifdef __cplusplus -} -#endif - -#endif diff --git a/thirdparty/oidn/mkl-dnn/include/mkldnn_types.h b/thirdparty/oidn/mkl-dnn/include/mkldnn_types.h deleted file mode 100644 index 1b6c35698..000000000 --- a/thirdparty/oidn/mkl-dnn/include/mkldnn_types.h +++ /dev/null @@ -1,1415 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#ifndef MKLDNN_TYPES_H -#define MKLDNN_TYPES_H - -#ifdef __cplusplus -extern "C" { -#endif - -#ifndef DOXYGEN_SHOULD_SKIP_THIS -#include -#include -#endif - -/** @addtogroup c_api C API - * @{ - * - * @addtogroup c_api_types Types - * @{ - * - * @addtogroup c_api_types_generic Generic - * @{ */ - -/** Intel(R) MKL-DNN Version type */ -typedef struct { - int major; - int minor; - int patch; - const char *hash; -} mkldnn_version_t; - -/** Status values returned by Intel(R) MKL-DNN functions. */ -typedef enum { - /** The operation was successful */ - mkldnn_success = 0, - /** The operation failed due to an out-of-memory condition */ - mkldnn_out_of_memory = 1, - /** The operation failed and should be retried */ - mkldnn_try_again = 2, - /** The operation failed because of incorrect function arguments */ - mkldnn_invalid_arguments = 3, - /** The operation failed because a primitive was not ready for execution */ - mkldnn_not_ready = 4, - /** The operation failed because requested functionality is not implemented - */ - mkldnn_unimplemented = 5, - /** Primitive iterator passed over last primitive descriptor */ - mkldnn_iterator_ends = 6, - /** Primitive or engine failed on execution */ - mkldnn_runtime_error = 7, - /** Queried element is not required for given primitive */ - mkldnn_not_required = 8, -} mkldnn_status_t; - -/** Data type specification */ -typedef enum { - /** Undefined data type, used for empty memory descriptors. */ - mkldnn_data_type_undef = 0, - /** 32-bit/single-precision floating point. */ - mkldnn_f32 = 1, - /** 32-bit signed integer. */ - mkldnn_s32 = 2, - /** 8-bit signed integer. */ - mkldnn_s8 = 3, - /** 8-bit unsigned integer. */ - mkldnn_u8 = 4, -} mkldnn_data_type_t; - -/** Memory format kind */ -typedef enum { - /** Undefined memory format, used for empty memory descriptors. */ - mkldnn_format_kind_undef = 0, - /** Unspecified format. The primitive selects a format automatically. */ - mkldnn_format_kind_any, - /** A tensor in a generic format described by the stride and blocking - * values in each dimension. See #mkldnn_blocking_desc_t for more - * information. */ - mkldnn_blocked, - /** Weights format used in 8bit Winograd convolution */ - mkldnn_format_kind_wino, - /** Packed weights format used in RNN */ - mkldnn_format_kind_rnn_packed, -} mkldnn_format_kind_t; - -/** Memory format tag specification. - * - * Intel MKL-DNN formats describe physical data layout. The physical layout - * is described as a sequence of the dimensions as they are laid out in the - * memory (from the outer-most to the inner-most). Note that this order - * doesn't affect the logical order of the dimensions that is kept in the - * `dims` field of the mkldnn_memory_desc_t structure. The logical order of the - * dimensions is specified by the type of tensor. 
- * - * For example, CNN 5D tensor always has its logical dimensions in the order - * `(batch, channels, depth, height, width)`, while the physical layout might be - * #mkldnn_ncdhw or #mkldnn_ndhwc: - * - * ~~~cpp - * int batch = 2, channels = 16, depth = 13, height = 13, width = 13; - * - * int ndims = 5; // 5D tensor - * mkldnn_dims_t dims = {batch, channels, depth, height, width}; - * mkldnn_memory_desc_t data_in_ncdhw; - * mkldnn_memory_desc_init_by_tag( - * &data_in_ncdhw, 5, dims, mkldnn_f32, mkldnn_ncdhw); - * - * // note that in both cases dims passed are the same - * mkldnn_memory_desc_t data_in_ndhwc; - * mkldnn_memory_desc_init_by_tag( - * &data_in_ndhwc, 5, dims, mkldnn_f32, mkldnn_ndhwc); - * ~~~ - * - * The following notation applies to memory format names: - * - @c 'n' denotes the mini-batch dimension - * - @c 'c' denotes a channels dimension - * - When there are multiple channel dimensions (for example, in convolution - * weights tensor), @c 'i' and @c 'o' denote dimensions of input and output - * channels - * - @c 'd', @c 'h', and @c 'w' denote spatial depth, height, and width - * respectively - * - Upper-case letters indicate that the data is laid out in blocks - * for a particular dimension. In such cases, the format name contains both - * upper- and lower-case letters for that dimension with a lower-case letter - * preceded by the block size. For example: @c 'mkldnn_nChw8c' describes a - * format where the outermost dimension is mini-batch, followed by the - * channel block number, followed by the spatial height and width, and - * finally followed by 8-element channel blocks. - * - * @note - * Channel designations can be different. For example, both the @c - * 'mkldnn_nc' and @c 'mkldnn_io' formats can be used to describe a 2D - * tensor. - * - * @sa @ref understanding_memory_formats - */ -typedef enum { - /** Undefined memory format tag */ - mkldnn_format_tag_undef = 0, - /** Undefined memory format tag. - * The primitive selects a format automatically. */ - mkldnn_format_tag_any, - - /* Semantic agnostic section */ - /* The physical order of dimensions is defined by the permutation of the - * characters, assuming that ab..z defines the natural order. 
- */ - - /* Plain formats */ - - mkldnn_a, - mkldnn_ab, - mkldnn_abc, - mkldnn_abcd, - mkldnn_abcde, - mkldnn_abcdef, - mkldnn_abdec, - mkldnn_acb, - mkldnn_acbde, - mkldnn_acdb, - mkldnn_acdeb, - mkldnn_ba, - mkldnn_bac, - mkldnn_bacd, - mkldnn_bcda, - mkldnn_cba, - mkldnn_cdba, - mkldnn_cdeba, - mkldnn_decab, - - /* Opaque blocked formats */ - - mkldnn_Abc16a, - mkldnn_ABc16a16b, - mkldnn_aBc16b, - mkldnn_ABc16b16a, - mkldnn_Abc4a, - mkldnn_aBc4b, - mkldnn_ABc4b16a4b, - mkldnn_ABc4b4a, - mkldnn_ABc8a16b2a, - mkldnn_ABc8a8b, - mkldnn_aBc8b, - mkldnn_ABc8b16a2b, - mkldnn_ABc8b8a, - mkldnn_Abcd16a, - mkldnn_ABcd16a16b, - mkldnn_aBcd16b, - mkldnn_ABcd16b16a, - mkldnn_aBCd16b16c, - mkldnn_aBCd16c16b, - mkldnn_Abcd4a, - mkldnn_aBcd4b, - mkldnn_ABcd4b16a4b, - mkldnn_ABcd4b4a, - mkldnn_aBCd4c16b4c, - mkldnn_aBCd4c4b, - mkldnn_ABcd8a16b2a, - mkldnn_ABcd8a8b, - mkldnn_aBcd8b, - mkldnn_ABcd8b16a2b, - mkldnn_aBCd8b16c2b, - mkldnn_ABcd8b8a, - mkldnn_aBCd8b8c, - mkldnn_aBCd8c16b2c, - mkldnn_aBCd8c8b, - mkldnn_Abcde16a, - mkldnn_ABcde16a16b, - mkldnn_aBcde16b, - mkldnn_ABcde16b16a, - mkldnn_aBCde16b16c, - mkldnn_aBCde16c16b, - mkldnn_aBCde2c8b4c, - mkldnn_Abcde4a, - mkldnn_aBcde4b, - mkldnn_ABcde4b4a, - mkldnn_aBCde4b4c, - mkldnn_aBCde4c16b4c, - mkldnn_aBCde4c4b, - mkldnn_Abcde8a, - mkldnn_ABcde8a8b, - mkldnn_aBcde8b, - mkldnn_ABcde8b16a2b, - mkldnn_aBCde8b16c2b, - mkldnn_ABcde8b8a, - mkldnn_aBCde8b8c, - mkldnn_aBCde8c16b2c, - mkldnn_aBCde8c8b, - mkldnn_aBcdef16b, - mkldnn_aBCdef16b16c, - mkldnn_aBCdef16c16b, - mkldnn_aBcdef4b, - mkldnn_aBCdef4c4b, - mkldnn_aBCdef8b8c, - mkldnn_aBCdef8c16b2c, - mkldnn_aBCdef8c8b, - mkldnn_aBdc16b, - mkldnn_aBdc4b, - mkldnn_aBdc8b, - mkldnn_aBdec16b, - mkldnn_aBdec4b, - mkldnn_aBdec8b, - mkldnn_aBdefc16b, - mkldnn_aBdefc4b, - mkldnn_aBdefc8b, - mkldnn_Acb16a, - mkldnn_Acb4a, - mkldnn_Acb8a, - mkldnn_aCBd16b16c, - mkldnn_aCBde16b16c, - mkldnn_Acdb16a, - mkldnn_Acdb4a, - mkldnn_Acdb8a, - mkldnn_Acdeb16a, - mkldnn_Acdeb4a, - mkldnn_Acdeb8a, - mkldnn_BAc16a16b, - mkldnn_BAcd16a16b, - - /** Just a sentinel, not real memory format tag. Must be changed after new - * format tag is added. */ - mkldnn_format_tag_last, - - /* Aliases */ - - mkldnn_x = mkldnn_a, - mkldnn_nc = mkldnn_ab, - mkldnn_cn = mkldnn_ba, - mkldnn_ncw = mkldnn_abc, - mkldnn_nwc = mkldnn_acb, - mkldnn_nchw = mkldnn_abcd, - mkldnn_nhwc = mkldnn_acdb, - mkldnn_chwn = mkldnn_bcda, - mkldnn_ncdhw = mkldnn_abcde, - mkldnn_ndhwc = mkldnn_acdeb, - - mkldnn_oi = mkldnn_ab, - mkldnn_io = mkldnn_ba, - mkldnn_oiw = mkldnn_abc, - mkldnn_wio = mkldnn_cba, - mkldnn_oihw = mkldnn_abcd, - mkldnn_hwio = mkldnn_cdba, - mkldnn_ihwo = mkldnn_bcda, - mkldnn_iohw = mkldnn_bacd, - mkldnn_oidhw = mkldnn_abcde, - mkldnn_dhwio = mkldnn_cdeba, - mkldnn_goiw = mkldnn_abcd, - mkldnn_goihw = mkldnn_abcde, - mkldnn_hwigo = mkldnn_decab, - mkldnn_giohw = mkldnn_acbde, - mkldnn_goidhw = mkldnn_abcdef, - - /** 3D RNN data tensor in the format (seq_length, batch, input channels). */ - mkldnn_tnc = mkldnn_abc, - /** 3D RNN data tensor in the format (batch, seq_length, input channels). */ - mkldnn_ntc = mkldnn_bac, - /** 5D RNN states tensor in the format (num_layers, num_directions, - * num_states, batch, state channels). */ - mkldnn_ldsnc = mkldnn_abcde, - /** 5D RNN weights tensor in the format (num_layers, num_directions, - * input_channels, num_gates, output_channels). - * - * - For LSTM cells, the gates order is input, forget, candidate - * and output gate. - * - For GRU cells, the gates order is update, reset and output gate. 
*/ - mkldnn_ldigo = mkldnn_abcde, - /** 5D RNN weights tensor in the format (num_layers, num_directions, - * num_gates, output_channels, input_channels). - * - * - For LSTM cells, the gates order is input, forget, candidate - * and output gate. - * - For GRU cells, the gates order is update, reset and output gate. */ - mkldnn_ldgoi = mkldnn_abdec, - /** 4D RNN bias tensor in the format (num_layers, num_directions, - * num_gates, output_channels). - * - * - For LSTM cells, the gates order is input, forget, candidate - * and output gate. - * - For GRU cells, the gates order is update, reset and output gate. */ - mkldnn_ldgo = mkldnn_abcd, - - /* Opaque data types, are not to be used explicitly */ - - /* data */ - mkldnn_nCdhw16c = mkldnn_aBcde16b, - mkldnn_nCdhw4c = mkldnn_aBcde4b, - mkldnn_nCdhw8c = mkldnn_aBcde8b, - mkldnn_nChw16c = mkldnn_aBcd16b, - mkldnn_nChw4c = mkldnn_aBcd4b, - mkldnn_nChw8c = mkldnn_aBcd8b, - mkldnn_nCw16c = mkldnn_aBc16b, - mkldnn_nCw4c = mkldnn_aBc4b, - mkldnn_nCw8c = mkldnn_aBc8b, - - /* weights, 3D */ - mkldnn_IOw16o16i = mkldnn_BAc16a16b, - mkldnn_OIw16i16o = mkldnn_ABc16b16a, - mkldnn_OIw16o16i = mkldnn_ABc16a16b, - mkldnn_Oiw16o = mkldnn_Abc16a, - mkldnn_OIw4i16o4i = mkldnn_ABc4b16a4b, - mkldnn_OIw4i4o = mkldnn_ABc4b4a, - mkldnn_Oiw4o = mkldnn_Abc4a, - mkldnn_OIw8i16o2i = mkldnn_ABc8b16a2b, - mkldnn_OIw8i8o = mkldnn_ABc8b8a, - mkldnn_OIw8o16i2o = mkldnn_ABc8a16b2a, - mkldnn_OIw8o8i = mkldnn_ABc8a8b, - mkldnn_Owi16o = mkldnn_Acb16a, - mkldnn_Owi4o = mkldnn_Acb4a, - mkldnn_Owi8o = mkldnn_Acb8a, - - /* weights, 4D */ - mkldnn_IOhw16o16i = mkldnn_BAcd16a16b, - mkldnn_Ohwi16o = mkldnn_Acdb16a, - mkldnn_Ohwi4o = mkldnn_Acdb4a, - mkldnn_Ohwi8o = mkldnn_Acdb8a, - mkldnn_OIhw16i16o = mkldnn_ABcd16b16a, - mkldnn_OIhw16o16i = mkldnn_ABcd16a16b, - mkldnn_Oihw16o = mkldnn_Abcd16a, - mkldnn_OIhw4i16o4i = mkldnn_ABcd4b16a4b, - mkldnn_OIhw4i4o = mkldnn_ABcd4b4a, - mkldnn_Oihw4o = mkldnn_Abcd4a, - mkldnn_OIhw8i16o2i = mkldnn_ABcd8b16a2b, - mkldnn_OIhw8i8o = mkldnn_ABcd8b8a, - mkldnn_OIhw8o16i2o = mkldnn_ABcd8a16b2a, - mkldnn_OIhw8o8i = mkldnn_ABcd8a8b, - - /* weights, 5D */ - mkldnn_Odhwi16o = mkldnn_Acdeb16a, - mkldnn_Odhwi4o = mkldnn_Acdeb4a, - mkldnn_Odhwi8o = mkldnn_Acdeb8a, - mkldnn_OIdhw16i16o = mkldnn_ABcde16b16a, - mkldnn_OIdhw16o16i = mkldnn_ABcde16a16b, - mkldnn_Oidhw16o = mkldnn_Abcde16a, - mkldnn_OIdhw4i4o = mkldnn_ABcde4b4a, - mkldnn_Oidhw4o = mkldnn_Abcde4a, - mkldnn_OIdhw8i16o2i = mkldnn_ABcde8b16a2b, - mkldnn_OIdhw8i8o = mkldnn_ABcde8b8a, - mkldnn_OIdhw8o8i = mkldnn_ABcde8a8b, - - /* weights w/ groups, 3D */ - mkldnn_Goiw16g = mkldnn_Abcd16a, - mkldnn_gIOw16o16i = mkldnn_aCBd16b16c, - mkldnn_gOIw16i16o = mkldnn_aBCd16c16b, - mkldnn_gOIw16o16i = mkldnn_aBCd16b16c, - mkldnn_gOiw16o = mkldnn_aBcd16b, - mkldnn_gOIw4i16o4i = mkldnn_aBCd4c16b4c, - mkldnn_gOIw4i4o = mkldnn_aBCd4c4b, - mkldnn_gOiw4o = mkldnn_aBcd4b, - mkldnn_gOIw8i16o2i = mkldnn_aBCd8c16b2c, - mkldnn_gOIw8i8o = mkldnn_aBCd8c8b, - mkldnn_gOIw8o16i2o = mkldnn_aBCd8b16c2b, - mkldnn_gOIw8o8i = mkldnn_aBCd8b8c, - mkldnn_gOwi16o = mkldnn_aBdc16b, - mkldnn_gOwi4o = mkldnn_aBdc4b, - mkldnn_gOwi8o = mkldnn_aBdc8b, - - /* weights w/ groups, 4D */ - mkldnn_gIOhw16o16i = mkldnn_aCBde16b16c, - mkldnn_gOhwi16o = mkldnn_aBdec16b, - mkldnn_gOhwi4o = mkldnn_aBdec4b, - mkldnn_gOhwi8o = mkldnn_aBdec8b, - mkldnn_Goihw16g = mkldnn_Abcde16a, - mkldnn_gOIhw16i16o = mkldnn_aBCde16c16b, - mkldnn_gOIhw16o16i = mkldnn_aBCde16b16c, - mkldnn_gOihw16o = mkldnn_aBcde16b, - mkldnn_gOIhw2i8o4i = mkldnn_aBCde2c8b4c, - 
mkldnn_gOIhw4i16o4i = mkldnn_aBCde4c16b4c, - mkldnn_gOIhw4i4o = mkldnn_aBCde4c4b, - mkldnn_gOIhw4o4i = mkldnn_aBCde4b4c, - mkldnn_gOihw4o = mkldnn_aBcde4b, - mkldnn_Goihw8g = mkldnn_Abcde8a, - mkldnn_gOIhw8i16o2i = mkldnn_aBCde8c16b2c, - mkldnn_gOIhw8i8o = mkldnn_aBCde8c8b, - mkldnn_gOIhw8o16i2o = mkldnn_aBCde8b16c2b, - mkldnn_gOIhw8o8i = mkldnn_aBCde8b8c, - - /* weights w/ groups, 6D */ - mkldnn_gOdhwi16o = mkldnn_aBdefc16b, - mkldnn_gOdhwi4o = mkldnn_aBdefc4b, - mkldnn_gOdhwi8o = mkldnn_aBdefc8b, - mkldnn_gOIdhw16i16o = mkldnn_aBCdef16c16b, - mkldnn_gOIdhw16o16i = mkldnn_aBCdef16b16c, - mkldnn_gOidhw16o = mkldnn_aBcdef16b, - mkldnn_gOIdhw4i4o = mkldnn_aBCdef4c4b, - mkldnn_gOidhw4o = mkldnn_aBcdef4b, - mkldnn_gOIdhw8i16o2i = mkldnn_aBCdef8c16b2c, - mkldnn_gOIdhw8i8o = mkldnn_aBCdef8c8b, - mkldnn_gOIdhw8o8i = mkldnn_aBCdef8b8c, -} mkldnn_format_tag_t; - -/** Kinds of padding. Define how to interpret the data in padding regions. */ -typedef enum { - /** The data in padding regions is zero. */ - mkldnn_padding_zero, -} mkldnn_padding_kind_t; - -/** Kinds of propagation. */ -typedef enum { - /* TODO: suggest renames */ - /** Undefined propagation type. */ - mkldnn_prop_kind_undef = 0, - /** Forward data propagation (training mode). In this mode primitives - * perform computations necessary for subsequent backward propagation. */ - mkldnn_forward_training = 64, - /** Forward data propagation (inference mode). In this mode primitives - * perform only computations that are necessary for inference and omit - * computations that are necessary only for backward propagation. */ - mkldnn_forward_inference = 96, - /** Forward data propagation (alias for @c mkldnn_forward_inference) */ - mkldnn_forward_scoring = mkldnn_forward_inference, - /** Forward data propagation (alias for @c mkldnn_forward_training) */ - mkldnn_forward = mkldnn_forward_training, - /** Backward propagation (with respect to all parameters */ - mkldnn_backward = 128, - /** Backward data propagation */ - mkldnn_backward_data = 160, - /** Backward weights propagation */ - mkldnn_backward_weights = 192, - /** Backward bias propagation */ - mkldnn_backward_bias = 193, -} mkldnn_prop_kind_t; - -/** Kinds of primitives. Used to implement a way to extend the library with new - * primitives without changing the ABI. */ -typedef enum { - /** Undefined primitive (XXX: why do we have it?). */ - mkldnn_undefined_primitive, - /** A reorder primitive.*/ - mkldnn_reorder, - /** A shuffle primitive.*/ - mkldnn_shuffle, - /** A (out-of-place) concat primitive. */ - mkldnn_concat, - /** A sum primitive. */ - mkldnn_sum, - /** A convolution primitive. */ - mkldnn_convolution, - /** A deconvolution primitive. */ - mkldnn_deconvolution, - /** An element-wise primitive. */ - mkldnn_eltwise, - /** A Softmax primitive. */ - mkldnn_softmax, - /** A pooling primitive. */ - mkldnn_pooling, - /** An LRN primitive. */ - mkldnn_lrn, - /** An batch normalization primitive. */ - mkldnn_batch_normalization, - /** An inner product primitive. */ - mkldnn_inner_product, - /** A rnn primitive. */ - mkldnn_rnn, -} mkldnn_primitive_kind_t; - -/** Kinds of algorithms. 
*/ -typedef enum { - mkldnn_alg_kind_undef, - /** Direct convolution */ - mkldnn_convolution_direct = 0x1, - /** Winograd convolution */ - mkldnn_convolution_winograd = 0x2, - /** Convolution algorithm(either direct or Winograd) is chosen just in time **/ - mkldnn_convolution_auto = 0x3, - /** Direct deconvolution */ - mkldnn_deconvolution_direct = 0xa, - /** Winograd deconvolution */ - mkldnn_deconvolution_winograd = 0xb, - /** Eltwise: ReLU */ - mkldnn_eltwise_relu = 0x1f, - /** Eltwise: hyperbolic tangent non-linearity (tanh) */ - mkldnn_eltwise_tanh = 0x2f, - /** Eltwise: parametric exponential linear unit (elu) */ - mkldnn_eltwise_elu = 0x3f, - /** Eltwise: square */ - mkldnn_eltwise_square = 0x4f, - /** Eltwise: abs */ - mkldnn_eltwise_abs = 0x5f, - /** Eltwise: square root */ - mkldnn_eltwise_sqrt = 0x6f, - /** Eltwise: linear */ - mkldnn_eltwise_linear = 0x7f, - /** Eltwise: bounded_relu */ - mkldnn_eltwise_bounded_relu = 0x8f, - /** Eltwise: soft_relu */ - mkldnn_eltwise_soft_relu = 0x9f, - /** Eltwise: logistic */ - mkldnn_eltwise_logistic = 0xaf, - /** Max pooling */ - mkldnn_pooling_max = 0x1ff, - /** Average pooling include padding */ - mkldnn_pooling_avg_include_padding = 0x2ff, - /** Average pooling exclude padding */ - mkldnn_pooling_avg_exclude_padding = 0x3ff, - mkldnn_pooling_avg = mkldnn_pooling_avg_exclude_padding, - /** Local response normalization (LRN) across multiple channels */ - mkldnn_lrn_across_channels = 0xaff, - /** LRN within a single channel */ - mkldnn_lrn_within_channel = 0xbff, - /** RNN cell */ - mkldnn_vanilla_rnn = 0x1fff, - /** LSTM cell */ - mkldnn_vanilla_lstm = 0x2fff, - /** GRU cell */ - mkldnn_vanilla_gru = 0x3fff, - /** GRU cell with linear before reset - * - * Modification of original GRU cell. Differs from #mkldnn_vanilla_gru - * in how the new memory gate is calculated: - * \f[ c_t = tanh(W_c*x_t + b_{c_x} + r_t*(U_c*h_{t-1}+b_{c_h})) \f] - * Primitive expects 4 biases on input: - * \f$[b_{u}, b_{r}, b_{c_x}, b_{c_h}]\f$ - * */ - mkldnn_gru_linear_before_reset = 0x4fff, -} mkldnn_alg_kind_t; - -/** Flags for batch-normalization primititve. */ -typedef enum { - /** Use global statistics - * - * If specified - * - on forward propagation use mean and variance provided by user (input) - * - on backward propagation reduces the amount of computations, since - * mean and variance are considered as constants - * - * If not specified: - * - on forward propagation mean and variance are computed and stored in - * output - * - on backward propagation compute full derivative wrt to data - */ - mkldnn_use_global_stats = 0x1U, - /** Use scale and shift parameters - * - * If specified: - * - on forward propagation use scale and shift (aka scale and bias) for - * the batch normalization results - * - on backward propagation (for prop_kind == #mkldnn_backward) compute - * diff wrt to scale and shift (hence one extra output used) - * - * If no specified: - * - on backward propagation prop_kind == #mkldnn_backward_data has the - * same behavior as prop_kind == #mkldnn_backward - */ - mkldnn_use_scaleshift = 0x2U, - /** Fuse with ReLU - * - * If specified: - * - on inference this option behaves the same as if the primitive were - * fused with ReLU via post ops API - * - on training primitive requires workspace (required to be able to - * perform backward pass) - */ - mkldnn_fuse_bn_relu = 0x4U, -} mkldnn_batch_normalization_flag_t; - -/** @} */ - -/** @addtogroup c_api_types_memory Memory - * @{ */ - -/** Maximum number of dimensions a tensor can have. 
Only restricts the amount - * of space used for the tensor description. Individual computational - * primitives may support only tensors of certain dimensions. */ -#define MKLDNN_MAX_NDIMS 12 - -/** A type to describe tensor dimension. */ -typedef int64_t mkldnn_dim_t; - -/** A type to describe tensor dimensions. */ -typedef mkldnn_dim_t mkldnn_dims_t[MKLDNN_MAX_NDIMS]; - -/** A type to describe strides within a tensor. */ -typedef mkldnn_dim_t mkldnn_strides_t[MKLDNN_MAX_NDIMS]; - -/** Generic description of blocked data layout for most memory formats. - * - * @sa @ref understanding_memory_formats */ -typedef struct { - /** The strides between the outermost blocks. - * In case of plain (non-blocked) formats the strides between dimensions. */ - mkldnn_dims_t strides; - /* Innermost section - * ASSUMPTION: the innermost blocks are always dense */ - /** The number of innermost blocks, e.g. 3 in case of `OIhw_4i16o4i_` */ - int inner_nblks; - /** The size of the blocks, e.g. `{4, 16, 4}` in case of `OIhw_4i16o4i` */ - mkldnn_dims_t inner_blks; - /** The logical indices of the blocks, e.g. `{1, 0, 1}` in case of - * `4i16o4i`, because `i` is the 1st dim and `o` is the 0st dim */ - mkldnn_dims_t inner_idxs; -} mkldnn_blocking_desc_t; - -typedef enum { - /** Undefined memory format, used for empty memory descriptors. */ - mkldnn_wino_undef = 0, - /** Tensors of weights for 2x3 winograd convolutions. */ - mkldnn_wino_wei_aaOIoi, - mkldnn_wino_wei_aaOio, - mkldnn_wino_wei_aaOBiOo, - /** Tensor of weights for 4x3 convolution. */ - mkldnn_wino_wei_OBaaIBOIio -} mkldnn_wino_memory_format_t; - -/** Description of tensor of weights for winograd 2x3 convolution. */ -typedef struct { - mkldnn_wino_memory_format_t wino_format; - int r; - int alpha; - int ic; - int oc; - int ic_block; - int oc_block; - int ic2_block; - int oc2_block; - float adj_scale; - size_t size; -} mkldnn_wino_desc_t; - -typedef enum { - mkldnn_packed_format_undef = 0, - mkldnn_ldigo_p, - mkldnn_ldgoi_p -} mkldnn_rnn_packed_memory_format_t; - -/* Maximum number of parts of RNN weights tensor that require separate - * computation. */ -#define MKLDNN_RNN_MAX_N_PARTS 4 - -/** Description of tensor of packed weights for rnn. */ -typedef struct { - mkldnn_rnn_packed_memory_format_t format; - int n_parts; - int n; - int parts[MKLDNN_RNN_MAX_N_PARTS]; - size_t part_pack_size[MKLDNN_RNN_MAX_N_PARTS]; - size_t offset_compensation; - size_t size; -} mkldnn_rnn_packed_desc_t; - -typedef enum { - mkldnn_memory_extra_flag_none = 0x0U, - /** Indicates the weights have an additional buffer, that depends on the - * @p compensation_mask. - * - * For instance, in 4D case with the compensation mask equals (1 << 0) - * the additional buffer would consist of OC values: - * O[oc : 0,OC] = - * -128 * SUM(ic : 0,IC; kh : 0,KH; kw : 0,KW){ weights(oc, ic, kh, kw) } - */ - mkldnn_memory_extra_flag_compensation_conv_s8s8 = 0x1U, - mkldnn_memory_extra_flag_scale_adjust = 0x2U, -} mkldnn_memory_extra_flags_t; - -/** Description of extra information stored in memory */ -typedef struct { - /** The flags contain arbitrary extra information, such as compensation. - * @sa mkldnn_memory_extra_flags_t */ - uint64_t flags; - /** Compensation mask */ - int compensation_mask; - /** Scale applied to the data */ - float scale_adjust; - /** For future backwards compatibility */ - char reserved[64]; -} mkldnn_memory_extra_desc_t; - -/** Memory descriptor. 
The description is based on a number of dimensions, - * dimensions themselves, plus information about elements type and memory - * format. Additionally, contains format-specific descriptions of the data - * layout. */ -typedef struct { - /** Number of dimensions */ - int ndims; - /** Dimensions in the following order: - * - CNN data tensors: mini-batch, channel, spatial - * ({N, C, [[D,] H,] W}) - * - CNN weight tensors: group (optional), output channel, input channel, - * spatial ({[G,] O, I, [[D,] H,] W}) - * - RNN data tensors: time, mini-batch, channels ({T, N, C}) - * or layers, directions, states, mini-batch, channels ({L, D, S, N, C}) - * - RNN weight tensor: layers, directions, input channel, gates, output channels - * ({L, D, I, G, O}). - * - * @note - * The order of dimensions does not depend on the memory format, so - * whether the data is laid out in #mkldnn_nchw or #mkldnn_nhwc - * the dims for 4D CN data tensor would be {N, C, H, W}. - */ - mkldnn_dims_t dims; - /** Data type of the tensor elements. */ - mkldnn_data_type_t data_type; - - /** Size of the data including padding in each dimension. */ - mkldnn_dims_t padded_dims; - /** Per-dimension offset from the padding to actual data, the top-level - * tensor with offsets applied must lie within the padding area. */ - mkldnn_dims_t padded_offsets; - - /** Offset from memory origin to the current block, non-zero only in - * a description of a memory sub-block. */ - mkldnn_dim_t offset0; - - /** Memory format kind. */ - mkldnn_format_kind_t format_kind; - union { - /** Description of the data layout for memory formats that use - * blocking. */ - mkldnn_blocking_desc_t blocking; - /** Tensor of weights for integer 8bit winograd convolution. */ - mkldnn_wino_desc_t wino_desc; - /** Tensor of packed weights for RNN. */ - mkldnn_rnn_packed_desc_t rnn_packed_desc; - /* ... other descriptions possible */ - } format_desc; - - mkldnn_memory_extra_desc_t extra; -} mkldnn_memory_desc_t; - -/** @struct mkldnn_memory - * An opaque structure to describe a memory. */ -struct mkldnn_memory; - -/** A memory handle. */ -typedef struct mkldnn_memory *mkldnn_memory_t; - -/** A constant memory handle. */ -typedef const struct mkldnn_memory *const_mkldnn_memory_t; - -#define MKLDNN_NATIVE_HANDLE_NONE (NULL) -#define MKLDNN_NATIVE_HANDLE_ALLOCATE ((void *)(size_t)-1) - -/** @} */ - -/** @addtogroup c_api_types_op_descs Operation descriptors - * @{*/ - -/** A pointer to any of the operation descriptors. */ -typedef void *mkldnn_op_desc_t; -/** A pointer to any of the operation descriptors (constant variant). */ -typedef const void *const_mkldnn_op_desc_t; - -/** A descriptor of a convolution operation. */ -typedef struct { - /** The kind of primitive. Used for self-identifying the primitive - * descriptor. Must be #mkldnn_convolution. */ - mkldnn_primitive_kind_t primitive_kind; - /** The kind of propagation. Possible values: #mkldnn_forward_training, - * #mkldnn_forward_inference, #mkldnn_backward_data, - * #mkldnn_backward_weights, and #mkldnn_backward_bias. */ - mkldnn_prop_kind_t prop_kind; - /** The kind of the convolution algorithm. Possible values: - * #mkldnn_convolution_direct. */ - mkldnn_alg_kind_t alg_kind; - /** Source memory descriptor. */ - mkldnn_memory_desc_t src_desc; - /** Source gradient memory descriptor. */ - mkldnn_memory_desc_t diff_src_desc; - /** Weights memory descriptor. */ - mkldnn_memory_desc_t weights_desc; - /** Weights gradient memory descriptor. 
*/ - mkldnn_memory_desc_t diff_weights_desc; - /** Bias memory descriptor. */ - mkldnn_memory_desc_t bias_desc; - /** Bias gradient memory descriptor. */ - mkldnn_memory_desc_t diff_bias_desc; - /** Destination memory descriptor. */ - mkldnn_memory_desc_t dst_desc; - /** Destination gradient memory descriptor. */ - mkldnn_memory_desc_t diff_dst_desc; - /** Convolution strides in each spatial dimension. */ - mkldnn_dims_t strides; - /** Convolution dilates in each spatial dimension. */ - mkldnn_dims_t dilates; - /** Padding in each spatial dimension. padding[0] is a padding in the - * beginning (@p padding_l), padding[1] is a padding in the end (@p - * padding_r). */ - mkldnn_dims_t padding[2]; - /** The kind of padding to use. */ - mkldnn_padding_kind_t padding_kind; - /** The accumulator data type. Initialized automatically. */ - mkldnn_data_type_t accum_data_type; -} mkldnn_convolution_desc_t; - -/** A descriptor of a deconvolution operation. */ -typedef mkldnn_convolution_desc_t mkldnn_deconvolution_desc_t; - -/** A descriptor of a shuffle operation. */ -typedef struct { - /** The kind of primitive. Used for self-identifying the primitive - * descriptor. Must be #mkldnn_convolution. */ - mkldnn_primitive_kind_t primitive_kind; - /** The kind of propagation. Possible values: #mkldnn_forward_training, - * #mkldnn_forward_inference, and #mkldnn_backward_data. */ - mkldnn_prop_kind_t prop_kind; - /** Source and destination memory descriptor, - * and source and destination gradient memory descriptor. */ - mkldnn_memory_desc_t data_desc; - /** axis for shuffling. */ - int axis; - /** number of groups in group convolution */ - mkldnn_dim_t group_size; -} mkldnn_shuffle_desc_t; - -/** A descriptor of a element-wise operation. */ -typedef struct { - /** The kind of primitive. Used for self-identifying the primitive - * descriptor. Must be #mkldnn_eltwise. */ - mkldnn_primitive_kind_t primitive_kind; - /** The kind of propagation. Possible values: #mkldnn_forward_training, - * #mkldnn_forward_inference, #mkldnn_backward, and #mkldnn_backward_data. - */ - mkldnn_prop_kind_t prop_kind; - /** The kind of eltwise algorithm. Possible values: #mkldnn_eltwise_relu, - * #mkldnn_eltwise_tanh, #mkldnn_eltwise_elu, #mkldnn_eltwise_square, - * #mkldnn_eltwise_abs, #mkldnn_eltwise_sqrt, #mkldnn_eltwise_linear, - * #mkldnn_eltwise_bounded_relu, #mkldnn_eltwise_soft_relu, and - * #mkldnn_eltwise_logistic. */ - mkldnn_alg_kind_t alg_kind; - /** Source and destination memory descriptor. */ - mkldnn_memory_desc_t data_desc; - /** Source and destination gradient memory descriptor. */ - mkldnn_memory_desc_t diff_data_desc; - /** Algorithm specific parameter. - * Accordance table: - * - #mkldnn_eltwise_relu: @p alpha -- negative slope, @p beta ignored - * - #mkldnn_eltwise_tanh: @p alpha and @p beta ignored - * - #mkldnn_eltwise_elu: @p alpha -- negative slope, @p beta ignored - * - #mkldnn_eltwise_square: @p alpha and @p beta ignored - * - #mkldnn_eltwise_abs: @p alpha and @p beta ignored - * - #mkldnn_eltwise_sqrt: @p alpha and @p beta ignored - * - #mkldnn_eltwise_linear: @p alpha -- scale, @p beta -- shift - * - #mkldnn_eltwise_bounded_relu: @p alpha -- upper bound, @p beta ignored - * - #mkldnn_eltwise_soft_relu: @p alpha and @p beta ignored - * - #mkldnn_eltwise_logistic: @p alpha and @p beta ignored - */ - float alpha, beta; -} mkldnn_eltwise_desc_t; - -/** A descriptor of a Softmax operation. */ -typedef struct { - /** The kind of primitive. Used for self-identifying the primitive - * descriptor. 
Must be #mkldnn_softmax. */ - mkldnn_primitive_kind_t primitive_kind; - /** The kind of propagation. Possible values: #mkldnn_forward_training and - * #mkldnn_forward_inference. */ - mkldnn_prop_kind_t prop_kind; - /** Source and destination memory descriptor. */ - mkldnn_memory_desc_t data_desc; - /** Source and Destination of gradient memory descriptor. */ - mkldnn_memory_desc_t diff_desc; - /** The axis along which to perform the softmax. */ - int softmax_axis; -} mkldnn_softmax_desc_t; - -/** A descriptor of a pooling operation. */ -typedef struct { - /** The kind of primitive. Used for self-identifying the primitive - * descriptor. Must be #mkldnn_pooling. */ - mkldnn_primitive_kind_t primitive_kind; - /** The kind of propagation. Possible values: #mkldnn_forward_training, - * #mkldnn_forward_inference, #mkldnn_backward, and #mkldnn_backward_data. - */ - mkldnn_prop_kind_t prop_kind; - /** The kind of pooling algorithm. Possible values: #mkldnn_pooling_max and - * #mkldnn_pooling_avg. */ - mkldnn_alg_kind_t alg_kind; - /** Source memory descriptor. */ - mkldnn_memory_desc_t src_desc; - /** Source gradient memory descriptor. */ - mkldnn_memory_desc_t diff_src_desc; - /** Destination memory descriptor. */ - mkldnn_memory_desc_t dst_desc; - /** Destination gradient memory descriptor. */ - mkldnn_memory_desc_t diff_dst_desc; - /** Pooling kernel strides for spatial dimensions. */ - mkldnn_dims_t strides; - /** Pooling kernel spatial dimensions. */ - mkldnn_dims_t kernel; - /** Padding in each spatial dimension. padding[0] is a padding in the - * beginning (@p padding_l), padding[1] is a padding in the end (@p - * padding_r). */ - mkldnn_dims_t padding[2]; - /** The kind of padding to use. */ - mkldnn_padding_kind_t padding_kind; - /** The accumulator data type. Initialized automatically. */ - mkldnn_data_type_t accum_data_type; -} mkldnn_pooling_desc_t; - -/** A descriptor of a Local Response Normalization (LRN) operation. */ -typedef struct { - /** The kind of primitive. Used for self-identifying the primitive - * descriptor. Must be #mkldnn_lrn. */ - mkldnn_primitive_kind_t primitive_kind; - /** The kind of propagation. Possible values: #mkldnn_forward_training, - * #mkldnn_forward_inference, #mkldnn_backward, and #mkldnn_backward_data. - */ - mkldnn_prop_kind_t prop_kind; - /** LRN algorithm. Possible values: #mkldnn_lrn_within_channel and - * #mkldnn_lrn_across_channels. */ - mkldnn_alg_kind_t alg_kind; - /** Source and destination memory descriptor. */ - mkldnn_memory_desc_t data_desc; - /** Source and destination gradient memory descriptor. */ - mkldnn_memory_desc_t diff_data_desc; - /** The number of channels to sum over (for cross-channel LRN) or the side - * length of the square region to sum over (for within-channel LRN). */ - mkldnn_dim_t local_size; - /** LRN alpha parameter. */ - float lrn_alpha; - /** LRN beta parameter. */ - float lrn_beta; - /** LRN k parameter. */ - float lrn_k; -} mkldnn_lrn_desc_t; - -/** A descriptor of a Batch Normalization operation. */ -typedef struct { - /** The kind of primitive. Used for self-identifying the primitive - * descriptor. Must be #mkldnn_batch_normalization. */ - mkldnn_primitive_kind_t primitive_kind; - /** The kind of propagation. Possible values: #mkldnn_forward_training, - * #mkldnn_forward_inference, #mkldnn_backward, and #mkldnn_backward_data. - */ - mkldnn_prop_kind_t prop_kind; - /** Source and destination memory descriptor. */ - mkldnn_memory_desc_t data_desc; - /** Source and destination gradient memory descriptor. 
*/ - mkldnn_memory_desc_t diff_data_desc; - /** Scale and shift data and gradient memory descriptors. - * - * Scaleshift memory descriptor uses 2D #mkldnn_nc format[2,Channels]. 1-st - * dimension contains gamma parameter, 2-nd dimension contains beta - * parameter. */ - mkldnn_memory_desc_t data_scaleshift_desc; - mkldnn_memory_desc_t diff_data_scaleshift_desc; - /** Mean and variance data memory descriptors. - * - * Mean and variance memory descriptors use 1D #mkldnn_x format[Channels]. - */ - mkldnn_memory_desc_t mean_desc; - mkldnn_memory_desc_t variance_desc; - /** Batch normalization epsilon parameter. */ - float batch_norm_epsilon; - unsigned flags; -} mkldnn_batch_normalization_desc_t; - -/** A descriptor of an inner product operation. */ -typedef struct { - /** The kind of primitive. Used for self-identifying the primitive - * descriptor. Must be #mkldnn_inner_product. */ - mkldnn_primitive_kind_t primitive_kind; - /** The kind of propagation. Possible values: #mkldnn_forward_training, - * #mkldnn_forward_inference, #mkldnn_backward_data, - * #mkldnn_backward_weights, and #mkldnn_backward_bias. */ - mkldnn_prop_kind_t prop_kind; - /** Source memory descriptor. */ - mkldnn_memory_desc_t src_desc; - /** Source gradient memory descriptor. */ - mkldnn_memory_desc_t diff_src_desc; - /** Weights memory descriptor. */ - mkldnn_memory_desc_t weights_desc; - /** Weights gradient memory descriptor. */ - mkldnn_memory_desc_t diff_weights_desc; - /** Bias memory descriptor. */ - mkldnn_memory_desc_t bias_desc; - /** Bias gradient memory descriptor. */ - mkldnn_memory_desc_t diff_bias_desc; - /** Destination memory descriptor. */ - mkldnn_memory_desc_t dst_desc; - /** Destination gradient memory descriptor. */ - mkldnn_memory_desc_t diff_dst_desc; - /** The accumulator data type. Initialized automatically. */ - mkldnn_data_type_t accum_data_type; -} mkldnn_inner_product_desc_t; - -/** Flags for RNN cell. */ -typedef enum { - mkldnn_rnn_cell_with_relu = 0x1U, - mkldnn_rnn_cell_with_clipping = 0x2U, -} mkldnn_rnn_cell_flags_t; - -typedef struct { - /** RNN cell kind. Must be one of #mkldnn_vanilla_rnn, - * #mkldnn_vanilla_lstm, #mkldnn_vanilla_gru, - * or #mkldnn_gru_linear_before_reset. */ - mkldnn_alg_kind_t cell_kind; - /** Activation function used. Must be either #mkldnn_eltwise_relu or - * #mkldnn_eltwise_tanh. */ - mkldnn_alg_kind_t activation_kind; - /** RNN cell flags */ - unsigned int flags; - /** @c alpha is a negative slope parameter (used only if - * `(flags & #mkldnn_rnn_cell_with_relu) != 0`) */ - float alpha; - /** clipping parameter (used only if - * `(flags & #mkldnn_rnn_cell_with_clipping) != 0`) */ - float clipping; -} mkldnn_rnn_cell_desc_t; - -/** A direction of RNN primitive execution. */ -typedef enum { - /* Unidirectional execution of RNN primitive from left to right. */ - mkldnn_unidirectional_left2right, - /* Unidirectional execution of RNN primitive from right to left. */ - mkldnn_unidirectional_right2left, - /* Bidirectional execution of RNN primitive with concatenation of the - * results. */ - mkldnn_bidirectional_concat, - /* Bidirectional execution of RNN primitive with summation of the - * results. */ - mkldnn_bidirectional_sum, - mkldnn_unidirectional = mkldnn_unidirectional_left2right, -} mkldnn_rnn_direction_t; - -/** A descriptor for an RNN operation. */ -typedef struct { - /** The kind of primitive. Used for self-identifying the primitive - * descriptor. Must be #mkldnn_rnn. */ - mkldnn_primitive_kind_t primitive_kind; - /** The kind of propagation. 
Possible values: #mkldnn_forward_training, - * #mkldnn_forward_inference, and #mkldnn_backward. */ - mkldnn_prop_kind_t prop_kind; - /** The RNN cell desc. */ - mkldnn_rnn_cell_desc_t cell_desc; - /** The direction of RNN primitive execution. */ - mkldnn_rnn_direction_t direction; - /** Source layer memory descriptor. */ - mkldnn_memory_desc_t src_layer_desc; - /** Source iteration memory descriptor. */ - mkldnn_memory_desc_t src_iter_desc; - /** Weights layer memory descriptor. */ - mkldnn_memory_desc_t weights_layer_desc; - /** Weights iteration memory descriptor. */ - mkldnn_memory_desc_t weights_iter_desc; - /** Bias memory descriptor. */ - mkldnn_memory_desc_t bias_desc; - /** Destination layer memory descriptor. */ - mkldnn_memory_desc_t dst_layer_desc; - /** Destination iter memory descriptor. */ - mkldnn_memory_desc_t dst_iter_desc; - /** Source gradient layer memory descriptor. */ - mkldnn_memory_desc_t diff_src_layer_desc; - /** Source gradient iter memory descriptor. */ - mkldnn_memory_desc_t diff_src_iter_desc; - /** Weights gradient layer memory descriptor. */ - mkldnn_memory_desc_t diff_weights_layer_desc; - /** Weights gradient iter memory descriptor. */ - mkldnn_memory_desc_t diff_weights_iter_desc; - /** Bias gradient memory descriptor. */ - mkldnn_memory_desc_t diff_bias_desc; - /** Destination gradient layer memory descriptor. */ - mkldnn_memory_desc_t diff_dst_layer_desc; - /** Destination gradient iteration memory descriptor. */ - mkldnn_memory_desc_t diff_dst_iter_desc; -} mkldnn_rnn_desc_t; - -/** @} */ - -/** @addtogroup c_api_engine_types Engine - * @{ */ - -/** @brief Kinds of engines. */ -typedef enum { - /** An unspecified engine. */ - mkldnn_any_engine, - /** CPU engine. */ - mkldnn_cpu, -} mkldnn_engine_kind_t; - -/** @struct mkldnn_engine - * @brief An opaque structure to describe an engine. */ -struct mkldnn_engine; -/** @brief An engine handle. */ -typedef struct mkldnn_engine *mkldnn_engine_t; -#if 0 -/* FIXME: looks like this never happens */ -/** @brief A constant engine handle. */ -typedef const struct mkldnn_engine *const_mkldnn_engine_t; -#endif - -/** @} */ - -/** @addtogroup c_api_primitive_desc_iterators Primitive descriptor iterators - * @{ */ - -/** @struct mkldnn_primitive_desc_iterator - * @brief An opaque structure to describe a primitive descriptor iterator. */ -struct mkldnn_primitive_desc_iterator; - -/** @brief A primitive descriptor iterator handle. */ -typedef struct mkldnn_primitive_desc_iterator - *mkldnn_primitive_desc_iterator_t; - -/** @brief A constant primitive descriptor iterator handle. */ -typedef const struct mkldnn_primitive_desc_iterator - *const_mkldnn_primitive_desc_iterator_t; - -/** @} */ - -/** @addtogroup c_api_primitive_descs Primitive descriptors - * @{ */ - -/** @struct mkldnn_primitive_desc - * @brief An opaque structure to describe a primitive descriptor. */ -struct mkldnn_primitive_desc; - -/** @brief A primitive descriptor handle. */ -typedef struct mkldnn_primitive_desc *mkldnn_primitive_desc_t; - -/** @brief A constant primitive descriptor handle. 
*/ -typedef const struct mkldnn_primitive_desc *const_mkldnn_primitive_desc_t; - -/** @} */ - -/** @addtogroup c_api_primitive_attr Primitive descriptor attributes - * @{ */ - -/** Scratchpad mode */ -typedef enum { - /** The library manages scratchpad (default) */ - mkldnn_scratchpad_mode_library, - /** A user shall query and provide the scratchpad memory to primitives */ - mkldnn_scratchpad_mode_user, -} mkldnn_scratchpad_mode_t; - -/** @struct mkldnn_primitive_attr - * @brief An opaque structure for primitive descriptor attributes. - * - * Attributes may contain: - * - output scales (to scale the result prior to storing it to the memory) - */ -struct mkldnn_primitive_attr; - -/** @brief A primitive descriptor attributes handle that controls primitive - * behavior. */ -typedef struct mkldnn_primitive_attr *mkldnn_primitive_attr_t; - -/** @brief A constant primitive descriptor attributes handle. */ -typedef const struct mkldnn_primitive_attr *const_mkldnn_primitive_attr_t; - -/** @struct mkldnn_post_ops - * @brief An opaque structure for a chain of post operations. - * - * mkldnn_post_ops can be used to perform some (trivial) operations like - * accumulation or eltwise after certain primitives like convolution. - * - * Post operations might be combined together, making a chain of post - * operations. For instance one can configure convolution followed by - * accumulation followed by eltwise. This might be especially beneficial - * for residual learning blocks. - * - * @warning - * Of course not all combinations are supported, so the user should handle - * errors accordingly. - * - * Supported post operations: - * - accumulation (base primitive: convolution) - * - eltwise (base primitive: convolution) - */ -struct mkldnn_post_ops; - -/** @brief A post operation chain handle. */ -typedef struct mkldnn_post_ops *mkldnn_post_ops_t; - -/** @brief A constant post operation chain handle. */ -typedef const struct mkldnn_post_ops *const_mkldnn_post_ops_t; - -/** @} */ - -/** @addtogroup c_api_types_primitive Primitive - * @{ */ - -/** @struct mkldnn_primitive - * An opaque structure to describe a primitive. */ -struct mkldnn_primitive; -/** A primitive handle. */ -typedef struct mkldnn_primitive *mkldnn_primitive_t; -/** A constant primitive handle. 
*/ -typedef const struct mkldnn_primitive *const_mkldnn_primitive_t; - -/** @addtogroup c_api_types_arguments Argument indices - * @{ */ - -#define MKLDNN_ARG_SRC_0 1 -#define MKLDNN_ARG_SRC MKLDNN_ARG_SRC_0 -#define MKLDNN_ARG_SRC_LAYER MKLDNN_ARG_SRC_0 -#define MKLDNN_ARG_FROM MKLDNN_ARG_SRC_0 - -#define MKLDNN_ARG_SRC_1 2 -#define MKLDNN_ARG_SRC_ITER MKLDNN_ARG_SRC_1 - -#define MKLDNN_ARG_DST_0 17 -#define MKLDNN_ARG_DST MKLDNN_ARG_DST_0 -#define MKLDNN_ARG_TO MKLDNN_ARG_DST_0 -#define MKLDNN_ARG_DST_LAYER MKLDNN_ARG_DST_0 - -#define MKLDNN_ARG_DST_1 18 -#define MKLDNN_ARG_DST_ITER MKLDNN_ARG_DST_1 - -#define MKLDNN_ARG_WEIGHTS_0 33 -#define MKLDNN_ARG_WEIGHTS MKLDNN_ARG_WEIGHTS_0 -#define MKLDNN_ARG_SCALE_SHIFT MKLDNN_ARG_WEIGHTS_0 -#define MKLDNN_ARG_WEIGHTS_LAYER MKLDNN_ARG_WEIGHTS_0 - -#define MKLDNN_ARG_WEIGHTS_1 34 -#define MKLDNN_ARG_WEIGHTS_ITER MKLDNN_ARG_WEIGHTS_1 - -#define MKLDNN_ARG_BIAS 41 - -#define MKLDNN_ARG_MEAN 49 -#define MKLDNN_ARG_VARIANCE 50 - -#define MKLDNN_ARG_WORKSPACE 64 -#define MKLDNN_ARG_SCRATCHPAD 80 - -#define MKLDNN_ARG_DIFF_SRC_0 129 -#define MKLDNN_ARG_DIFF_SRC MKLDNN_ARG_DIFF_SRC_0 -#define MKLDNN_ARG_DIFF_SRC_LAYER MKLDNN_ARG_DIFF_SRC_0 - -#define MKLDNN_ARG_DIFF_SRC_1 130 -#define MKLDNN_ARG_DIFF_SRC_ITER MKLDNN_ARG_DIFF_SRC_1 - -#define MKLDNN_ARG_DIFF_DST_0 145 -#define MKLDNN_ARG_DIFF_DST MKLDNN_ARG_DIFF_DST_0 -#define MKLDNN_ARG_DIFF_DST_LAYER MKLDNN_ARG_DIFF_DST_0 - -#define MKLDNN_ARG_DIFF_DST_1 146 -#define MKLDNN_ARG_DIFF_DST_ITER MKLDNN_ARG_DIFF_DST_1 - -#define MKLDNN_ARG_DIFF_WEIGHTS_0 161 -#define MKLDNN_ARG_DIFF_WEIGHTS MKLDNN_ARG_DIFF_WEIGHTS_0 -#define MKLDNN_ARG_DIFF_SCALE_SHIFT MKLDNN_ARG_DIFF_WEIGHTS_0 -#define MKLDNN_ARG_DIFF_WEIGHTS_LAYER MKLDNN_ARG_DIFF_WEIGHTS_0 - -#define MKLDNN_ARG_DIFF_WEIGHTS_1 162 -#define MKLDNN_ARG_DIFF_WEIGHTS_ITER MKLDNN_ARG_DIFF_WEIGHTS_1 - -#define MKLDNN_ARG_DIFF_BIAS 169 - -#define MKLDNN_ARG_MULTIPLE_SRC 1024 -#define MKLDNN_ARG_MULTIPLE_DST 2048 - -/** @} */ - -/** An auxiliary structure to specify primitive's inputs/outputs at execution - * - * @warning - * With this API it's impossible to preserve constness of memory, so all - * memories are passed w/o const qualifier. However only memories with - * output semantics might be changed during the execution */ -typedef struct { - int arg; /**< An argument index, e.g. MKLDNN_ARG_SRC */ - mkldnn_memory_t memory; /**< Input/output memory */ -} mkldnn_exec_arg_t; - -/** @} */ - -/** @addtogroup c_api_types_query Queries - * @{ */ - -/** Primitive descriptor query specification - * - * For generic function mkldnn_primitive_desc_query(), the type of result must - * agree with the queried argument. The correspondence table: - * Query | type of result - * -------------------------------------------------------------- - * #mkldnn_query_engine | mkldnn_engine_t * - * #mkldnn_query_scratchpad_engine | mkldnn_engine_t * - * #mkldnn_query_primitive_kind | mkldnn_primitive_kind_t * - * *_s32 | int * - * *_s64 | mkldnn_dim_t * (same as int64_t *) - * *_f64 | double * - * *_str | const char ** - * #mkldnn_query_op_d | const_mkldnn_op_desc_t * - * *_md | const mkldnn_memory_desc_t ** - * *_${op}_d | const mkldnn_${op}_desc_t ** - * *_pd | const_mkldnn_primitive_desc_t * - * - * @note - * Rule of thumb: all opaque types and structures are returned by - * reference. All numbers are returned by value. - * - * @warning - * All returned references point to constant objects and are valid only - * during the lifetime of the queried primitive descriptor. 
Returned objects - * must not be destroyed by the user. If you need to keep the object longer - * than the lifetime of the queried primitive descriptor, use - * mkldnn_primitive_desc_clone() to make a copy. */ -typedef enum { - mkldnn_query_undef = 0, /**< no query */ - - mkldnn_query_engine, /**< execution engine */ - mkldnn_query_primitive_kind, /**< primitive kind */ - - mkldnn_query_num_of_inputs_s32, /**< number of inputs expected */ - mkldnn_query_num_of_outputs_s32, /**< number of outputs expected */ - - mkldnn_query_time_estimate_f64, /**< runtime estimation (seconds) */ - mkldnn_query_memory_consumption_s64, /**< memory consumption -- extra - (scratch) memory, additional to all - inputs and outputs memory (bytes) */ - - mkldnn_query_scratchpad_engine, /**< scratchpad engine -- engine to be used - for creating scratchpad memory */ - - mkldnn_query_impl_info_str, /**< implementation name */ - - /* memory and op descriptor section */ - mkldnn_query_some_d = 64, /**< stub */ - mkldnn_query_op_d, /**< op descriptor */ - mkldnn_query_convolution_d, /**< convolution descriptor */ - mkldnn_query_deconvolution_d, /**< deconvolution descriptor */ - mkldnn_query_shuffle_d, /**< shuffle descriptor */ - mkldnn_query_eltwise_d, /**< eltwise descriptor */ - mkldnn_query_softmax_d, /**< softmax descriptor */ - mkldnn_query_pooling_d, /**< pooling descriptor */ - mkldnn_query_lrn_d, /**< lrn descriptor */ - mkldnn_query_batch_normalization_d, /**< batch normalization descriptor */ - mkldnn_query_inner_product_d, /**< inner product descriptor */ - mkldnn_query_rnn_d, /**< rnn descriptor */ - - /* memory descriptor section */ - mkldnn_query_some_md = 128, /**< stub */ - mkldnn_query_src_md, /**< source memory desc */ - mkldnn_query_diff_src_md, /**< source gradient memory desc */ - mkldnn_query_weights_md, /**< weights memory descriptor desc */ - mkldnn_query_diff_weights_md, /**< weights grad. memory desc */ - mkldnn_query_dst_md, /**< destination memory desc */ - mkldnn_query_diff_dst_md, /**< destination grad. memory desc */ - mkldnn_query_workspace_md, /**< workspace memory desc */ - mkldnn_query_scratchpad_md, /**< scratchpad memory desc */ -} mkldnn_query_t; - -/** @} */ - -/** @addtogroup c_api_types_stream Execution stream - * @{ */ - -/** @brief Stream flags. */ -typedef enum { - /** A default stream configuration. */ - mkldnn_stream_default_flags = 0x0U, -} mkldnn_stream_flags_t; - -/** @struct mkldnn_stream - * An opaque structure to describe an execution stream. */ -struct mkldnn_stream; -/** An execution stream handle. */ -typedef struct mkldnn_stream *mkldnn_stream_t; -/** A constant execution stream handle. */ -typedef const struct mkldnn_stream *const_mkldnn_stream_t; - -/** @} */ -/** @} */ -/** @} */ - -#ifdef __cplusplus -} -#endif - - -#endif diff --git a/thirdparty/oidn/mkl-dnn/include/mkldnn_version.h b/thirdparty/oidn/mkl-dnn/include/mkldnn_version.h deleted file mode 100644 index a2713decc..000000000 --- a/thirdparty/oidn/mkl-dnn/include/mkldnn_version.h +++ /dev/null @@ -1,32 +0,0 @@ -/******************************************************************************* -* Copyright 2019 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. 
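The query enum above is consumed through the generic mkldnn_primitive_desc_query() described in the table. A short sketch of typical queries against an existing primitive descriptor; `pd` is assumed valid, and the returned descriptor pointers stay owned by it, as the warning above states:

#include <stddef.h>
#include "mkldnn.h"

/* Minimal sketch: query an existing primitive descriptor. Pointers returned
 * by reference are owned by `pd` and must not be freed or outlive it. */
static void inspect_pd(const_mkldnn_primitive_desc_t pd) {
    /* opaque/struct results come back by reference (see the table above) */
    const mkldnn_memory_desc_t *dst_md = NULL;
    mkldnn_primitive_desc_query(pd, mkldnn_query_dst_md, 0, &dst_md);

    /* plain numbers and strings come back by value */
    int n_inputs = 0;
    mkldnn_primitive_desc_query(pd, mkldnn_query_num_of_inputs_s32, 0, &n_inputs);

    const char *impl = NULL;
    mkldnn_primitive_desc_query(pd, mkldnn_query_impl_info_str, 0, &impl);

    (void)dst_md; (void)n_inputs; (void)impl;
}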
-* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef MKLDNN_VERSION_H -#define MKLDNN_VERSION_H - -/* Major version of MKL-DNN */ -#define MKLDNN_VERSION_MAJOR 0 - -/* Minor version of MKL-DNN */ -#define MKLDNN_VERSION_MINOR 90 - -/* Patch version of MKL-DNN */ -#define MKLDNN_VERSION_PATCH 0 - -/* Git Commit Hash of MKL-DNN */ -#define MKLDNN_VERSION_HASH "096bda1ca23324879f2df5a129e610e4405f775c" - -#endif diff --git a/thirdparty/oidn/mkl-dnn/include/mkldnn_version.h.in b/thirdparty/oidn/mkl-dnn/include/mkldnn_version.h.in deleted file mode 100644 index 5ee012618..000000000 --- a/thirdparty/oidn/mkl-dnn/include/mkldnn_version.h.in +++ /dev/null @@ -1,32 +0,0 @@ -/******************************************************************************* -* Copyright 2019 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef MKLDNN_VERSION_H -#define MKLDNN_VERSION_H - -/* Major version of MKL-DNN */ -#define MKLDNN_VERSION_MAJOR @MKLDNN_VERSION_MAJOR@ - -/* Minor version of MKL-DNN */ -#define MKLDNN_VERSION_MINOR @MKLDNN_VERSION_MINOR@ - -/* Patch version of MKL-DNN */ -#define MKLDNN_VERSION_PATCH @MKLDNN_VERSION_PATCH@ - -/* Git Commit Hash of MKL-DNN */ -#define MKLDNN_VERSION_HASH "@MKLDNN_VERSION_HASH@" - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/batch_normalization.cpp b/thirdparty/oidn/mkl-dnn/src/common/batch_normalization.cpp deleted file mode 100644 index 1a51d8562..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/batch_normalization.cpp +++ /dev/null @@ -1,104 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
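The removed mkldnn_version.h pins the bundled copy at 0.90.0. A hypothetical compile-time guard (not taken from the Godot sources) showing how downstream code could assert it builds against that exact snapshot:

#include "mkldnn_version.h"

/* Hypothetical guard: fail the build if the headers are not the 0.90.0
 * snapshot that this diff removes. */
#if !(MKLDNN_VERSION_MAJOR == 0 && MKLDNN_VERSION_MINOR == 90 \
        && MKLDNN_VERSION_PATCH == 0)
#error "expected the bundled mkl-dnn 0.90.0 snapshot"
#endif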
-*******************************************************************************/ - -#include -#include "mkldnn.h" - -#include "c_types_map.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -using namespace mkldnn::impl; -using namespace mkldnn::impl::utils; -using namespace mkldnn::impl::status; -using namespace mkldnn::impl::prop_kind; -using namespace mkldnn::impl::alg_kind; -using namespace mkldnn::impl::types; - -namespace { -status_t bnrm_desc_init(batch_normalization_desc_t *bnrm_desc, - prop_kind_t prop_kind, const memory_desc_t *data_desc, - const memory_desc_t *diff_data_desc, float epsilon, unsigned flags) { - bool args_ok = true - && !any_null(bnrm_desc, data_desc) - && one_of(prop_kind, forward_training, forward_inference, - backward_data, backward) - && IMPLICATION(prop_kind & backward, diff_data_desc != nullptr); - if (!args_ok) return invalid_arguments; - - auto bd = batch_normalization_desc_t(); - bd.primitive_kind = primitive_kind::batch_normalization; - bd.prop_kind = prop_kind; - - bd.data_desc = *data_desc; - bd.diff_data_desc = zero_md(); - if ( one_of(bd.prop_kind,backward_data, backward) ) - bd.diff_data_desc = *diff_data_desc; - - dims_t scaleshift_dims = { 2, data_desc->dims[1] }; - mkldnn_memory_desc_init_by_tag(&bd.data_scaleshift_desc, 2, - scaleshift_dims, data_type::f32, mkldnn_nc); - bd.diff_data_scaleshift_desc = zero_md(); - if (bd.prop_kind == backward) { - bd.diff_data_scaleshift_desc = bd.data_scaleshift_desc; - } - - dims_t stats_dims = { data_desc->dims[1] }; - mkldnn_memory_desc_init_by_tag(&bd.mean_desc, 1, stats_dims, - data_type::f32, mkldnn_x); - bd.variance_desc = bd.mean_desc; - bd.batch_norm_epsilon = epsilon; - - unsigned bnorm_flags = - mkldnn_use_global_stats | mkldnn_use_scaleshift | mkldnn_fuse_bn_relu; - if ((~bnorm_flags & flags) != 0) return invalid_arguments; - - bd.flags = flags; - - bool consistency = true - && utils::one_of(bd.data_desc.ndims, 2, 4, 5); - if (bd.prop_kind == backward_data) - consistency = consistency - && utils::one_of(bd.diff_data_desc.ndims, 2, 4, 5) - && array_cmp(bd.diff_data_desc.dims, bd.data_desc.dims, - bd.diff_data_desc.ndims); - if (!consistency) return invalid_arguments; - - *bnrm_desc = bd; - return success; -} -} - -status_t mkldnn_batch_normalization_forward_desc_init( - batch_normalization_desc_t *bnrm_desc, prop_kind_t prop_kind, - const memory_desc_t *data_desc, float epsilon, unsigned flags) { - if (!one_of(prop_kind, forward_training, forward_inference)) - return invalid_arguments; - return bnrm_desc_init(bnrm_desc, prop_kind, data_desc, nullptr, - epsilon, flags); -} - -status_t mkldnn_batch_normalization_backward_desc_init( - batch_normalization_desc_t *bnrm_desc, prop_kind_t prop_kind, - const memory_desc_t *diff_data_desc, const memory_desc_t *data_desc, - float epsilon, unsigned flags) { - if (!one_of(prop_kind, backward, backward_data)) - return invalid_arguments; - return bnrm_desc_init(bnrm_desc, prop_kind, data_desc, diff_data_desc, - epsilon, flags); -} - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/batch_normalization_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/batch_normalization_pd.hpp deleted file mode 100644 index f61410b33..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/batch_normalization_pd.hpp +++ /dev/null @@ -1,240 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* 
you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef BATCH_NORMALIZATION_PD_HPP -#define BATCH_NORMALIZATION_PD_HPP - -#include "mkldnn.h" - -#include "c_types_map.hpp" -#include "primitive_desc.hpp" -#include "utils.hpp" - -namespace mkldnn { -namespace impl { - -struct batch_normalization_fwd_pd_t; - -struct batch_normalization_pd_t: public primitive_desc_t { - static constexpr auto base_pkind = primitive_kind::batch_normalization; - - batch_normalization_pd_t(engine_t *engine, - const batch_normalization_desc_t *adesc, - const primitive_attr_t *attr, - const batch_normalization_fwd_pd_t *hint_fwd_pd) - : primitive_desc_t(engine, attr, base_pkind) - , desc_(*adesc) - , hint_fwd_pd_(hint_fwd_pd) - , data_md_(desc_.data_desc) - , stat_md_(desc_.mean_desc) - , scaleshift_md_(desc_.data_scaleshift_desc) - , ws_md_() - {} - - const batch_normalization_desc_t *desc() const { return &desc_; } - virtual const op_desc_t *op_desc() const override - { return reinterpret_cast(this->desc()); } - virtual void init_info() override { impl::init_info(this, this->info_); } - - virtual status_t query(query_t what, int idx, void *result) const override { - switch (what) { - case query::batch_normalization_d: - *(const batch_normalization_desc_t**)result = desc(); break; - default: return primitive_desc_t::query(what, idx, result); - } - return status::success; - } - - /* common batch_normalization aux functions */ - - dim_t MB() const { return data_desc().dims[0]; } - dim_t C() const { return data_desc().dims[1]; } - dim_t D() const { return ndims() >= 5 ? data_desc().dims[ndims() - 3] : 1; } - dim_t H() const { return ndims() >= 4 ? data_desc().dims[ndims() - 2] : 1; } - dim_t W() const { return ndims() >= 3 ? 
data_desc().dims[ndims() - 1] : 1; } - - int ndims() const { return desc_.data_desc.ndims; } - - bool stats_is_src() const { return desc_.flags & mkldnn_use_global_stats; } - bool use_scaleshift() const { return desc_.flags & mkldnn_use_scaleshift; } - bool use_global_stats() const - { return desc_.flags & mkldnn_use_global_stats; } - bool fuse_bn_relu() const { return desc_.flags & mkldnn_fuse_bn_relu; } - bool with_relu_post_op() const { - const auto &p = this->attr()->post_ops_; - return p.len_ == 1 && p.entry_[0].is_relu(true, true); - } - - bool is_fwd() const { - return utils::one_of(desc_.prop_kind, prop_kind::forward_training, - prop_kind::forward_inference); - } - bool is_bwd() const { return !this->is_fwd(); } - bool is_training() const - { return desc_.prop_kind == prop_kind::forward_training; } - - bool has_zero_dim_memory() const - { return memory_desc_wrapper(desc_.data_desc).has_zero_dim(); } - -protected: - batch_normalization_desc_t desc_; - const batch_normalization_fwd_pd_t *hint_fwd_pd_; - - memory_desc_t data_md_; - memory_desc_t stat_md_; - memory_desc_t scaleshift_md_; - - memory_desc_t ws_md_; - - void init_default_ws(size_t bits_per_element) { - const auto data_mdw = memory_desc_wrapper(data_md_); - - const dim_t data_nelems = data_mdw.nelems(true); - const dim_t bits_per_byte = 8; - const dims_t ws_sz = { (dim_t)utils::div_up( - data_nelems * bits_per_element, bits_per_byte) }; - mkldnn_memory_desc_init_by_tag(&ws_md_, 1, ws_sz, impl::data_type::u8, - format_tag::x); - } - -private: - const memory_desc_t &data_desc() const { return desc_.data_desc; } -}; - -struct batch_normalization_fwd_pd_t: public batch_normalization_pd_t { - typedef batch_normalization_fwd_pd_t base_class; - typedef batch_normalization_fwd_pd_t hint_class; - - batch_normalization_fwd_pd_t(engine_t *engine, - const batch_normalization_desc_t *adesc, - const primitive_attr_t *attr, - const batch_normalization_fwd_pd_t *hint_fwd_pd) - : batch_normalization_pd_t(engine, adesc, attr, hint_fwd_pd) - {} - - virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { - if (arg == MKLDNN_ARG_SRC) return arg_usage_t::input; - if (arg == MKLDNN_ARG_DST) return arg_usage_t::output; - - if (utils::one_of(arg, MKLDNN_ARG_MEAN, MKLDNN_ARG_VARIANCE)) { - if (stats_is_src()) return arg_usage_t::input; - if (!stats_is_src() && is_training()) return arg_usage_t::output; - return arg_usage_t::unused; - } - - if (arg == MKLDNN_ARG_SCALE_SHIFT && use_scaleshift()) - return arg_usage_t::input; - - if (arg == MKLDNN_ARG_WORKSPACE && is_training() && fuse_bn_relu()) - return arg_usage_t::output; - - return primitive_desc_t::arg_usage(arg); - } - - virtual const memory_desc_t *src_md(int index = 0) const override { - if (index == 0) return &data_md_; - if (stats_is_src() && (index == 1 || index == 2)) return &stat_md_; - return nullptr; - } - - virtual const memory_desc_t *dst_md(int index = 0) const override { - if (index == 0) return &data_md_; - if (!stats_is_src() && is_training() && (index == 1 || index == 2)) - return &stat_md_; - return nullptr; - } - - virtual const memory_desc_t *weights_md(int index = 0) const override - { return index == 0 ? &scaleshift_md_ : nullptr; } - - virtual const memory_desc_t *workspace_md(int index = 0) const override - { return index == 0 && is_training() && fuse_bn_relu() ? &ws_md_ : nullptr; } - - const memory_desc_t *stat_md() const - { return stats_is_src() ? 
src_md(1) : dst_md(1); } - - virtual int n_inputs() const override - { return 1 + 2 * stats_is_src() + use_scaleshift(); } - virtual int n_outputs() const override - { return 1 + (fuse_bn_relu() + 2 * (!stats_is_src())) * is_training(); } -}; - -struct batch_normalization_bwd_pd_t: public batch_normalization_pd_t { - typedef batch_normalization_bwd_pd_t base_class; - typedef batch_normalization_fwd_pd_t hint_class; - - batch_normalization_bwd_pd_t(engine_t *engine, - const batch_normalization_desc_t *adesc, - const primitive_attr_t *attr, - const batch_normalization_fwd_pd_t *hint_fwd_pd) - : batch_normalization_pd_t(engine, adesc, attr, hint_fwd_pd) - , diff_data_md_(desc_.diff_data_desc) - , diff_scaleshift_md_(desc_.diff_data_scaleshift_desc) - {} - - virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { - if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_MEAN, - MKLDNN_ARG_VARIANCE, MKLDNN_ARG_DIFF_DST)) - return arg_usage_t::input; - - if (arg == MKLDNN_ARG_SCALE_SHIFT && use_scaleshift()) - return arg_usage_t::input; - - if (arg == MKLDNN_ARG_WORKSPACE && fuse_bn_relu()) - return arg_usage_t::input; - - if (arg == MKLDNN_ARG_DIFF_SRC) - return arg_usage_t::output; - - if (arg == MKLDNN_ARG_DIFF_SCALE_SHIFT && use_scaleshift()) - return arg_usage_t::output; - - return primitive_desc_t::arg_usage(arg); - } - - virtual const memory_desc_t *src_md(int index = 0) const override - { return index == 0 ? &data_md_ : index <= 2 ? &stat_md_ : nullptr; } - virtual const memory_desc_t *diff_dst_md(int index = 0) const override - { return index == 0 ? &diff_data_md_ : nullptr; } - virtual const memory_desc_t *diff_src_md(int index = 0) const override - { return index == 0 ? &diff_data_md_ : nullptr; } - - virtual const memory_desc_t *weights_md(int index = 0) const override - { return index == 0 ? &scaleshift_md_ : nullptr; } - virtual const memory_desc_t *diff_weights_md(int index = 0) const override - { return index == 0 ? &diff_scaleshift_md_ : nullptr; } - - virtual const memory_desc_t *workspace_md(int index = 0) const override - { return index == 0 && fuse_bn_relu() ? &ws_md_ : nullptr; } - - const memory_desc_t *stat_md() const { return src_md(1); } - - virtual int n_inputs() const override - { return 4 + use_scaleshift() + fuse_bn_relu(); } - virtual int n_outputs() const override - { return 1 + (desc_.prop_kind == prop_kind::backward); } - -protected: - memory_desc_t diff_data_md_; - memory_desc_t diff_scaleshift_md_; -}; - -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/c_types_map.hpp b/thirdparty/oidn/mkl-dnn/src/common/c_types_map.hpp deleted file mode 100644 index 3d43a0fbe..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/c_types_map.hpp +++ /dev/null @@ -1,550 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
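bnrm_desc_init() above validates the mkldnn_use_* flags, and on the _pd side those same flags decide whether MEAN/VARIANCE behave as inputs or outputs. A minimal sketch of creating a forward-inference descriptor for an f32 NCHW tensor through the public wrapper defined above; the dimension values and epsilon are illustrative only:

#include "mkldnn.h"

/* Minimal sketch: forward-inference batch normalization descriptor for an
 * NCHW f32 tensor. With mkldnn_use_global_stats set, the pd's arg_usage()
 * treats MEAN/VARIANCE as inputs rather than outputs. */
static mkldnn_status_t make_bnorm_desc(mkldnn_batch_normalization_desc_t *bd,
        mkldnn_dim_t n, mkldnn_dim_t c, mkldnn_dim_t h, mkldnn_dim_t w) {
    mkldnn_memory_desc_t data_md;
    mkldnn_dims_t dims = { n, c, h, w };
    mkldnn_status_t s = mkldnn_memory_desc_init_by_tag(
            &data_md, 4, dims, mkldnn_f32, mkldnn_nchw);
    if (s != mkldnn_success) return s;

    const unsigned flags = mkldnn_use_scaleshift | mkldnn_use_global_stats;
    return mkldnn_batch_normalization_forward_desc_init(
            bd, mkldnn_forward_inference, &data_md, 1e-5f, flags);
}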
-*******************************************************************************/ - -#ifndef TYPE_MAPPING_HPP -#define TYPE_MAPPING_HPP - -#include "mkldnn_types.h" - -namespace mkldnn { -namespace impl { - -// TODO: autogenerate this - -using dim_t = mkldnn_dim_t; -using dims_t = mkldnn_dims_t; -using stride_t = mkldnn_dim_t; -using strides_t = mkldnn_strides_t; - -using status_t = mkldnn_status_t; -namespace status { - const status_t success = mkldnn_success; - const status_t out_of_memory = mkldnn_out_of_memory; - const status_t try_again = mkldnn_try_again; - const status_t invalid_arguments = mkldnn_invalid_arguments; - const status_t not_ready = mkldnn_not_ready; - const status_t unimplemented = mkldnn_unimplemented; - const status_t iterator_ends = mkldnn_iterator_ends; - const status_t runtime_error = mkldnn_runtime_error; - const status_t not_required = mkldnn_not_required; -} - -using prop_kind_t = mkldnn_prop_kind_t; -namespace prop_kind { - const prop_kind_t undef = mkldnn_prop_kind_undef; - const prop_kind_t forward_training = mkldnn_forward_training; - const prop_kind_t forward_inference = mkldnn_forward_inference; - const prop_kind_t forward_scoring = mkldnn_forward_scoring; - const prop_kind_t forward = mkldnn_forward; - const prop_kind_t backward = mkldnn_backward; - const prop_kind_t backward_data = mkldnn_backward_data; - const prop_kind_t backward_weights = mkldnn_backward_weights; - const prop_kind_t backward_bias = mkldnn_backward_bias; -} - -using alg_kind_t = mkldnn_alg_kind_t; -namespace alg_kind { - const alg_kind_t undef = mkldnn_alg_kind_undef; - const alg_kind_t convolution_auto = mkldnn_convolution_auto; - const alg_kind_t convolution_direct = mkldnn_convolution_direct; - const alg_kind_t convolution_winograd = mkldnn_convolution_winograd; - const alg_kind_t deconvolution_direct = mkldnn_deconvolution_direct; - const alg_kind_t deconvolution_winograd = mkldnn_deconvolution_winograd; - const alg_kind_t eltwise_relu = mkldnn_eltwise_relu; - const alg_kind_t eltwise_tanh = mkldnn_eltwise_tanh; - const alg_kind_t eltwise_elu = mkldnn_eltwise_elu; - const alg_kind_t eltwise_square = mkldnn_eltwise_square; - const alg_kind_t eltwise_abs = mkldnn_eltwise_abs; - const alg_kind_t eltwise_sqrt = mkldnn_eltwise_sqrt; - const alg_kind_t eltwise_linear = mkldnn_eltwise_linear; - const alg_kind_t eltwise_bounded_relu = mkldnn_eltwise_bounded_relu; - const alg_kind_t eltwise_soft_relu = mkldnn_eltwise_soft_relu; - const alg_kind_t eltwise_logistic = mkldnn_eltwise_logistic; - const alg_kind_t pooling_max = mkldnn_pooling_max; - const alg_kind_t pooling_avg = mkldnn_pooling_avg; - const alg_kind_t pooling_avg_include_padding = mkldnn_pooling_avg_include_padding; - const alg_kind_t pooling_avg_exclude_padding = mkldnn_pooling_avg_exclude_padding; - const alg_kind_t lrn_across_channels = mkldnn_lrn_across_channels; - const alg_kind_t lrn_within_channel = mkldnn_lrn_within_channel; - const alg_kind_t vanilla_rnn = mkldnn_vanilla_rnn; - const alg_kind_t vanilla_lstm = mkldnn_vanilla_lstm; - const alg_kind_t vanilla_gru = mkldnn_vanilla_gru; - const alg_kind_t gru_linear_before_reset = mkldnn_gru_linear_before_reset; -} - -using data_type_t = mkldnn_data_type_t; -namespace data_type { - const data_type_t undef = mkldnn_data_type_undef; - const data_type_t f32 = mkldnn_f32; - const data_type_t s32 = mkldnn_s32; - const data_type_t s8 = mkldnn_s8; - const data_type_t u8 = mkldnn_u8; -} - -using scratchpad_mode_t = mkldnn_scratchpad_mode_t; -namespace scratchpad_mode { - const 
scratchpad_mode_t library = mkldnn_scratchpad_mode_library; - const scratchpad_mode_t user = mkldnn_scratchpad_mode_user; -} - -using rnn_packed_format_t = mkldnn_rnn_packed_memory_format_t; -namespace rnn_packed_format { - const rnn_packed_format_t undef = mkldnn_packed_format_undef; - const rnn_packed_format_t ldigo_p = mkldnn_ldigo_p; - const rnn_packed_format_t ldgoi_p = mkldnn_ldgoi_p; -} - -using format_kind_t = mkldnn_format_kind_t; -namespace format_kind { - const format_kind_t undef = mkldnn_format_kind_undef; - const format_kind_t any = mkldnn_format_kind_any; - const format_kind_t blocked = mkldnn_blocked; - const format_kind_t wino = mkldnn_format_kind_wino; - const format_kind_t rnn_packed = mkldnn_format_kind_rnn_packed; -} - -using format_tag_t = mkldnn_format_tag_t; -namespace format_tag { - const format_tag_t undef = mkldnn_format_tag_undef; - const format_tag_t any = mkldnn_format_tag_any; - const format_tag_t a = mkldnn_a; - const format_tag_t ab = mkldnn_ab; - const format_tag_t abc = mkldnn_abc; - const format_tag_t abcd = mkldnn_abcd; - const format_tag_t abcde = mkldnn_abcde; - const format_tag_t abcdef = mkldnn_abcdef; - const format_tag_t abdec = mkldnn_abdec; - const format_tag_t acb = mkldnn_acb; - const format_tag_t acbde = mkldnn_acbde; - const format_tag_t acdb = mkldnn_acdb; - const format_tag_t acdeb = mkldnn_acdeb; - const format_tag_t ba = mkldnn_ba; - const format_tag_t bac = mkldnn_bac; - const format_tag_t bacd = mkldnn_bacd; - const format_tag_t bcda = mkldnn_bcda; - const format_tag_t cba = mkldnn_cba; - const format_tag_t cdba = mkldnn_cdba; - const format_tag_t cdeba = mkldnn_cdeba; - const format_tag_t decab = mkldnn_decab; - const format_tag_t Abc16a = mkldnn_Abc16a; - const format_tag_t ABc16a16b = mkldnn_ABc16a16b; - const format_tag_t aBc16b = mkldnn_aBc16b; - const format_tag_t ABc16b16a = mkldnn_ABc16b16a; - const format_tag_t Abc4a = mkldnn_Abc4a; - const format_tag_t aBc4b = mkldnn_aBc4b; - const format_tag_t ABc4b16a4b = mkldnn_ABc4b16a4b; - const format_tag_t ABc4b4a = mkldnn_ABc4b4a; - const format_tag_t ABc8a16b2a = mkldnn_ABc8a16b2a; - const format_tag_t ABc8a8b = mkldnn_ABc8a8b; - const format_tag_t aBc8b = mkldnn_aBc8b; - const format_tag_t ABc8b16a2b = mkldnn_ABc8b16a2b; - const format_tag_t ABc8b8a = mkldnn_ABc8b8a; - const format_tag_t Abcd16a = mkldnn_Abcd16a; - const format_tag_t ABcd16a16b = mkldnn_ABcd16a16b; - const format_tag_t aBcd16b = mkldnn_aBcd16b; - const format_tag_t ABcd16b16a = mkldnn_ABcd16b16a; - const format_tag_t aBCd16b16c = mkldnn_aBCd16b16c; - const format_tag_t aBCd16c16b = mkldnn_aBCd16c16b; - const format_tag_t Abcd4a = mkldnn_Abcd4a; - const format_tag_t aBcd4b = mkldnn_aBcd4b; - const format_tag_t ABcd4b16a4b = mkldnn_ABcd4b16a4b; - const format_tag_t ABcd4b4a = mkldnn_ABcd4b4a; - const format_tag_t aBCd4c16b4c = mkldnn_aBCd4c16b4c; - const format_tag_t aBCd4c4b = mkldnn_aBCd4c4b; - const format_tag_t ABcd8a16b2a = mkldnn_ABcd8a16b2a; - const format_tag_t ABcd8a8b = mkldnn_ABcd8a8b; - const format_tag_t aBcd8b = mkldnn_aBcd8b; - const format_tag_t ABcd8b16a2b = mkldnn_ABcd8b16a2b; - const format_tag_t aBCd8b16c2b = mkldnn_aBCd8b16c2b; - const format_tag_t ABcd8b8a = mkldnn_ABcd8b8a; - const format_tag_t aBCd8b8c = mkldnn_aBCd8b8c; - const format_tag_t aBCd8c16b2c = mkldnn_aBCd8c16b2c; - const format_tag_t aBCd8c8b = mkldnn_aBCd8c8b; - const format_tag_t Abcde16a = mkldnn_Abcde16a; - const format_tag_t ABcde16a16b = mkldnn_ABcde16a16b; - const format_tag_t aBcde16b = mkldnn_aBcde16b; - const format_tag_t 
ABcde16b16a = mkldnn_ABcde16b16a; - const format_tag_t aBCde16b16c = mkldnn_aBCde16b16c; - const format_tag_t aBCde16c16b = mkldnn_aBCde16c16b; - const format_tag_t aBCde2c8b4c = mkldnn_aBCde2c8b4c; - const format_tag_t Abcde4a = mkldnn_Abcde4a; - const format_tag_t aBcde4b = mkldnn_aBcde4b; - const format_tag_t ABcde4b4a = mkldnn_ABcde4b4a; - const format_tag_t aBCde4b4c = mkldnn_aBCde4b4c; - const format_tag_t aBCde4c16b4c = mkldnn_aBCde4c16b4c; - const format_tag_t aBCde4c4b = mkldnn_aBCde4c4b; - const format_tag_t Abcde8a = mkldnn_Abcde8a; - const format_tag_t ABcde8a8b = mkldnn_ABcde8a8b; - const format_tag_t aBcde8b = mkldnn_aBcde8b; - const format_tag_t ABcde8b16a2b = mkldnn_ABcde8b16a2b; - const format_tag_t aBCde8b16c2b = mkldnn_aBCde8b16c2b; - const format_tag_t ABcde8b8a = mkldnn_ABcde8b8a; - const format_tag_t aBCde8b8c = mkldnn_aBCde8b8c; - const format_tag_t aBCde8c16b2c = mkldnn_aBCde8c16b2c; - const format_tag_t aBCde8c8b = mkldnn_aBCde8c8b; - const format_tag_t aBcdef16b = mkldnn_aBcdef16b; - const format_tag_t aBCdef16b16c = mkldnn_aBCdef16b16c; - const format_tag_t aBCdef16c16b = mkldnn_aBCdef16c16b; - const format_tag_t aBcdef4b = mkldnn_aBcdef4b; - const format_tag_t aBCdef4c4b = mkldnn_aBCdef4c4b; - const format_tag_t aBCdef8b8c = mkldnn_aBCdef8b8c; - const format_tag_t aBCdef8c16b2c = mkldnn_aBCdef8c16b2c; - const format_tag_t aBCdef8c8b = mkldnn_aBCdef8c8b; - const format_tag_t aBdc16b = mkldnn_aBdc16b; - const format_tag_t aBdc4b = mkldnn_aBdc4b; - const format_tag_t aBdc8b = mkldnn_aBdc8b; - const format_tag_t aBdec16b = mkldnn_aBdec16b; - const format_tag_t aBdec4b = mkldnn_aBdec4b; - const format_tag_t aBdec8b = mkldnn_aBdec8b; - const format_tag_t aBdefc16b = mkldnn_aBdefc16b; - const format_tag_t aBdefc4b = mkldnn_aBdefc4b; - const format_tag_t aBdefc8b = mkldnn_aBdefc8b; - const format_tag_t Acb16a = mkldnn_Acb16a; - const format_tag_t Acb4a = mkldnn_Acb4a; - const format_tag_t Acb8a = mkldnn_Acb8a; - const format_tag_t aCBd16b16c = mkldnn_aCBd16b16c; - const format_tag_t aCBde16b16c = mkldnn_aCBde16b16c; - const format_tag_t Acdb16a = mkldnn_Acdb16a; - const format_tag_t Acdb4a = mkldnn_Acdb4a; - const format_tag_t Acdb8a = mkldnn_Acdb8a; - const format_tag_t Acdeb16a = mkldnn_Acdeb16a; - const format_tag_t Acdeb4a = mkldnn_Acdeb4a; - const format_tag_t Acdeb8a = mkldnn_Acdeb8a; - const format_tag_t BAc16a16b = mkldnn_BAc16a16b; - const format_tag_t BAcd16a16b = mkldnn_BAcd16a16b; - const format_tag_t last = mkldnn_format_tag_last; - - const format_tag_t x = mkldnn_x; - const format_tag_t nc = mkldnn_nc; - const format_tag_t cn = mkldnn_cn; - const format_tag_t ncw = mkldnn_ncw; - const format_tag_t nwc = mkldnn_nwc; - const format_tag_t nchw = mkldnn_nchw; - const format_tag_t nhwc = mkldnn_nhwc; - const format_tag_t chwn = mkldnn_chwn; - const format_tag_t ncdhw = mkldnn_ncdhw; - const format_tag_t ndhwc = mkldnn_ndhwc; - const format_tag_t oi = mkldnn_oi; - const format_tag_t io = mkldnn_io; - const format_tag_t oiw = mkldnn_oiw; - const format_tag_t wio = mkldnn_wio; - const format_tag_t oihw = mkldnn_oihw; - const format_tag_t hwio = mkldnn_hwio; - const format_tag_t ihwo = mkldnn_ihwo; - const format_tag_t iohw = mkldnn_iohw; - const format_tag_t oidhw = mkldnn_oidhw; - const format_tag_t dhwio = mkldnn_dhwio; - const format_tag_t goiw = mkldnn_goiw; - const format_tag_t goihw = mkldnn_goihw; - const format_tag_t hwigo = mkldnn_hwigo; - const format_tag_t giohw = mkldnn_giohw; - const format_tag_t goidhw = mkldnn_goidhw; - const format_tag_t tnc = 
mkldnn_tnc; - const format_tag_t ntc = mkldnn_ntc; - const format_tag_t ldsnc = mkldnn_ldsnc; - const format_tag_t ldigo = mkldnn_ldigo; - const format_tag_t ldgoi = mkldnn_ldgoi; - const format_tag_t ldgo = mkldnn_ldgo; - const format_tag_t nCdhw16c = mkldnn_nCdhw16c; - const format_tag_t nCdhw4c = mkldnn_nCdhw4c; - const format_tag_t nCdhw8c = mkldnn_nCdhw8c; - const format_tag_t nChw16c = mkldnn_nChw16c; - const format_tag_t nChw4c = mkldnn_nChw4c; - const format_tag_t nChw8c = mkldnn_nChw8c; - const format_tag_t nCw16c = mkldnn_nCw16c; - const format_tag_t nCw4c = mkldnn_nCw4c; - const format_tag_t nCw8c = mkldnn_nCw8c; - const format_tag_t IOw16o16i = mkldnn_IOw16o16i; - const format_tag_t OIw16i16o = mkldnn_OIw16i16o; - const format_tag_t OIw16o16i = mkldnn_OIw16o16i; - const format_tag_t Oiw16o = mkldnn_Oiw16o; - const format_tag_t OIw4i16o4i = mkldnn_OIw4i16o4i; - const format_tag_t OIw4i4o = mkldnn_OIw4i4o; - const format_tag_t Oiw4o = mkldnn_Oiw4o; - const format_tag_t OIw8i16o2i = mkldnn_OIw8i16o2i; - const format_tag_t OIw8i8o = mkldnn_OIw8i8o; - const format_tag_t OIw8o16i2o = mkldnn_OIw8o16i2o; - const format_tag_t OIw8o8i = mkldnn_OIw8o8i; - const format_tag_t Owi16o = mkldnn_Owi16o; - const format_tag_t Owi4o = mkldnn_Owi4o; - const format_tag_t Owi8o = mkldnn_Owi8o; - const format_tag_t IOhw16o16i = mkldnn_IOhw16o16i; - const format_tag_t Ohwi16o = mkldnn_Ohwi16o; - const format_tag_t Ohwi4o = mkldnn_Ohwi4o; - const format_tag_t Ohwi8o = mkldnn_Ohwi8o; - const format_tag_t OIhw16i16o = mkldnn_OIhw16i16o; - const format_tag_t OIhw16o16i = mkldnn_OIhw16o16i; - const format_tag_t Oihw16o = mkldnn_Oihw16o; - const format_tag_t OIhw4i16o4i = mkldnn_OIhw4i16o4i; - const format_tag_t OIhw4i4o = mkldnn_OIhw4i4o; - const format_tag_t Oihw4o = mkldnn_Oihw4o; - const format_tag_t OIhw8i16o2i = mkldnn_OIhw8i16o2i; - const format_tag_t OIhw8i8o = mkldnn_OIhw8i8o; - const format_tag_t OIhw8o16i2o = mkldnn_OIhw8o16i2o; - const format_tag_t OIhw8o8i = mkldnn_OIhw8o8i; - const format_tag_t Odhwi16o = mkldnn_Odhwi16o; - const format_tag_t Odhwi4o = mkldnn_Odhwi4o; - const format_tag_t Odhwi8o = mkldnn_Odhwi8o; - const format_tag_t OIdhw16i16o = mkldnn_OIdhw16i16o; - const format_tag_t OIdhw16o16i = mkldnn_OIdhw16o16i; - const format_tag_t Oidhw16o = mkldnn_Oidhw16o; - const format_tag_t OIdhw4i4o = mkldnn_OIdhw4i4o; - const format_tag_t Oidhw4o = mkldnn_Oidhw4o; - const format_tag_t OIdhw8i16o2i = mkldnn_OIdhw8i16o2i; - const format_tag_t OIdhw8i8o = mkldnn_OIdhw8i8o; - const format_tag_t OIdhw8o8i = mkldnn_OIdhw8o8i; - const format_tag_t gIOw16o16i = mkldnn_gIOw16o16i; - const format_tag_t Goiw16g = mkldnn_Goiw16g; - const format_tag_t gOIw16i16o = mkldnn_gOIw16i16o; - const format_tag_t gOIw16o16i = mkldnn_gOIw16o16i; - const format_tag_t gOiw16o = mkldnn_gOiw16o; - const format_tag_t gOIw4i16o4i = mkldnn_gOIw4i16o4i; - const format_tag_t gOIw4i4o = mkldnn_gOIw4i4o; - const format_tag_t gOiw4o = mkldnn_gOiw4o; - const format_tag_t gOIw8i16o2i = mkldnn_gOIw8i16o2i; - const format_tag_t gOIw8i8o = mkldnn_gOIw8i8o; - const format_tag_t gOIw8o16i2o = mkldnn_gOIw8o16i2o; - const format_tag_t gOIw8o8i = mkldnn_gOIw8o8i; - const format_tag_t gOwi16o = mkldnn_gOwi16o; - const format_tag_t gOwi4o = mkldnn_gOwi4o; - const format_tag_t gOwi8o = mkldnn_gOwi8o; - const format_tag_t gIOhw16o16i = mkldnn_gIOhw16o16i; - const format_tag_t gOhwi16o = mkldnn_gOhwi16o; - const format_tag_t gOhwi4o = mkldnn_gOhwi4o; - const format_tag_t gOhwi8o = mkldnn_gOhwi8o; - const format_tag_t Goihw16g = 
mkldnn_Goihw16g; - const format_tag_t gOIhw16i16o = mkldnn_gOIhw16i16o; - const format_tag_t gOIhw16o16i = mkldnn_gOIhw16o16i; - const format_tag_t gOihw16o = mkldnn_gOihw16o; - const format_tag_t gOIhw2i8o4i = mkldnn_gOIhw2i8o4i; - const format_tag_t gOIhw4i16o4i = mkldnn_gOIhw4i16o4i; - const format_tag_t gOIhw4i4o = mkldnn_gOIhw4i4o; - const format_tag_t gOIhw4o4i = mkldnn_gOIhw4o4i; - const format_tag_t gOihw4o = mkldnn_gOihw4o; - const format_tag_t Goihw8g = mkldnn_Goihw8g; - const format_tag_t gOIhw8i16o2i = mkldnn_gOIhw8i16o2i; - const format_tag_t gOIhw8i8o = mkldnn_gOIhw8i8o; - const format_tag_t gOIhw8o16i2o = mkldnn_gOIhw8o16i2o; - const format_tag_t gOIhw8o8i = mkldnn_gOIhw8o8i; - const format_tag_t gOdhwi16o = mkldnn_gOdhwi16o; - const format_tag_t gOdhwi4o = mkldnn_gOdhwi4o; - const format_tag_t gOdhwi8o = mkldnn_gOdhwi8o; - const format_tag_t gOIdhw16i16o = mkldnn_gOIdhw16i16o; - const format_tag_t gOIdhw16o16i = mkldnn_gOIdhw16o16i; - const format_tag_t gOidhw16o = mkldnn_gOidhw16o; - const format_tag_t gOIdhw4i4o = mkldnn_gOIdhw4i4o; - const format_tag_t gOidhw4o = mkldnn_gOidhw4o; - const format_tag_t gOIdhw8i16o2i = mkldnn_gOIdhw8i16o2i; - const format_tag_t gOIdhw8i8o = mkldnn_gOIdhw8i8o; - const format_tag_t gOIdhw8o8i = mkldnn_gOIdhw8o8i; -} - -using memory_extra_flags_t = mkldnn_memory_extra_flags_t; -namespace memory_extra_flags { - const memory_extra_flags_t none = mkldnn_memory_extra_flag_none; - const memory_extra_flags_t compensation_conv_s8s8 = mkldnn_memory_extra_flag_compensation_conv_s8s8; - const memory_extra_flags_t scale_adjust = mkldnn_memory_extra_flag_scale_adjust; -} - -using padding_kind_t = mkldnn_padding_kind_t; -namespace padding_kind { - const padding_kind_t padding_zero = mkldnn_padding_zero; -} - -using engine_kind_t = mkldnn_engine_kind_t; -namespace engine_kind { - const engine_kind_t any_engine = mkldnn_any_engine; - const engine_kind_t cpu = mkldnn_cpu; -} - -using primitive_kind_t = mkldnn_primitive_kind_t; -namespace primitive_kind { - const primitive_kind_t undefined = mkldnn_undefined_primitive; - const primitive_kind_t reorder = mkldnn_reorder; - const primitive_kind_t concat = mkldnn_concat; - const primitive_kind_t sum = mkldnn_sum; - const primitive_kind_t convolution = mkldnn_convolution; - const primitive_kind_t deconvolution = mkldnn_deconvolution; - const primitive_kind_t shuffle = mkldnn_shuffle; - const primitive_kind_t eltwise = mkldnn_eltwise; - const primitive_kind_t softmax = mkldnn_softmax; - const primitive_kind_t pooling = mkldnn_pooling; - const primitive_kind_t lrn = mkldnn_lrn; - const primitive_kind_t batch_normalization = mkldnn_batch_normalization; - const primitive_kind_t inner_product = mkldnn_inner_product; - const primitive_kind_t rnn = mkldnn_rnn; -} - -using query_t = mkldnn_query_t; -namespace query { - const query_t undef = mkldnn_query_undef; - - const query_t engine = mkldnn_query_engine; - const query_t primitive_kind = mkldnn_query_primitive_kind; - - const query_t num_of_inputs_s32 = mkldnn_query_num_of_inputs_s32; - const query_t num_of_outputs_s32 = mkldnn_query_num_of_outputs_s32; - - const query_t time_estimate_f64 = mkldnn_query_time_estimate_f64; - const query_t memory_consumption_s64 = mkldnn_query_memory_consumption_s64; - - const query_t scratchpad_engine = mkldnn_query_scratchpad_engine; - - const query_t impl_info_str = mkldnn_query_impl_info_str; - - const query_t some_d = mkldnn_query_some_d; - const query_t op_d = mkldnn_query_op_d; - const query_t convolution_d = 
mkldnn_query_convolution_d; - const query_t deconvolution_d = mkldnn_query_deconvolution_d; - const query_t shuffle_d = mkldnn_query_shuffle_d; - const query_t eltwise_d = mkldnn_query_eltwise_d; - const query_t softmax_d = mkldnn_query_softmax_d; - const query_t pooling_d = mkldnn_query_pooling_d; - const query_t lrn_d = mkldnn_query_lrn_d; - const query_t batch_normalization_d = mkldnn_query_batch_normalization_d; - const query_t inner_product_d = mkldnn_query_inner_product_d; - const query_t rnn_d = mkldnn_query_rnn_d; - - const query_t some_md = mkldnn_query_some_md; - const query_t src_md = mkldnn_query_src_md; - const query_t diff_src_md = mkldnn_query_diff_src_md; - const query_t weights_md = mkldnn_query_weights_md; - const query_t diff_weights_md = mkldnn_query_diff_weights_md; - const query_t dst_md = mkldnn_query_dst_md; - const query_t diff_dst_md = mkldnn_query_diff_dst_md; - - const query_t workspace_md = mkldnn_query_workspace_md; - const query_t scratchpad_md = mkldnn_query_scratchpad_md; -} - -using blocking_desc_t = mkldnn_blocking_desc_t; -using rnn_packed_desc_t = mkldnn_rnn_packed_desc_t; -using wino_desc_t = mkldnn_wino_desc_t; -using memory_extra_desc_t = mkldnn_memory_extra_desc_t; -using memory_desc_t = mkldnn_memory_desc_t; -using convolution_desc_t = mkldnn_convolution_desc_t; -using deconvolution_desc_t = mkldnn_deconvolution_desc_t; -using shuffle_desc_t = mkldnn_shuffle_desc_t; -using pooling_desc_t = mkldnn_pooling_desc_t; -using eltwise_desc_t = mkldnn_eltwise_desc_t; -using softmax_desc_t = mkldnn_softmax_desc_t; -using lrn_desc_t = mkldnn_lrn_desc_t; -using batch_normalization_desc_t = mkldnn_batch_normalization_desc_t; -using inner_product_desc_t = mkldnn_inner_product_desc_t; - -using rnn_direction_t = mkldnn_rnn_direction_t; -using rnn_cell_desc_t = mkldnn_rnn_cell_desc_t; -using rnn_desc_t = mkldnn_rnn_desc_t; - -/* C op_desc_t, which eventually are just (void*) */ -using c_op_desc_t = mkldnn_op_desc_t; -using const_c_op_desc_t = const_mkldnn_op_desc_t; - -struct op_desc_t { - union { - primitive_kind_t kind; - convolution_desc_t convolution; - deconvolution_desc_t deconvolution; - shuffle_desc_t shuffle; - pooling_desc_t pooling; - eltwise_desc_t eltwise; - softmax_desc_t softmax; - lrn_desc_t lrn; - batch_normalization_desc_t batch_normalization; - inner_product_desc_t inner_product; - rnn_desc_t rnn; - }; - - op_desc_t(const primitive_kind_t &_): kind(_) {} - -# define DECL_CTOR_AND_CONVERTERS(c_type, name) \ - op_desc_t(const c_type &_): name(_) {} \ - static op_desc_t *convert_from_c(c_type *_) \ - { return reinterpret_cast(_); } \ - static const op_desc_t *convert_from_c(const c_type *_) \ - { return reinterpret_cast(_); } - - DECL_CTOR_AND_CONVERTERS(convolution_desc_t, convolution); - DECL_CTOR_AND_CONVERTERS(shuffle_desc_t, shuffle); - DECL_CTOR_AND_CONVERTERS(pooling_desc_t, pooling); - DECL_CTOR_AND_CONVERTERS(eltwise_desc_t, eltwise); - DECL_CTOR_AND_CONVERTERS(softmax_desc_t, softmax); - DECL_CTOR_AND_CONVERTERS(lrn_desc_t, lrn); - DECL_CTOR_AND_CONVERTERS(batch_normalization_desc_t, batch_normalization); - DECL_CTOR_AND_CONVERTERS(inner_product_desc_t, inner_product); - DECL_CTOR_AND_CONVERTERS(rnn_desc_t, rnn); - -# undef DECL_CTOR_AND_CONVERTERS -}; - -using engine_t = mkldnn_engine; -using primitive_desc_iterator_t = mkldnn_primitive_desc_iterator; -using primitive_desc_t = mkldnn_primitive_desc; -using primitive_attr_t = mkldnn_primitive_attr; -using post_ops_t = mkldnn_post_ops; -using memory_t = mkldnn_memory; -using primitive_t = 
mkldnn_primitive; - -using primitive_arg_index_t = int; - -using stream_flags_t = mkldnn_stream_flags_t; -namespace stream_flags { - const stream_flags_t default_flags = mkldnn_stream_default_flags; -} -using stream_t = mkldnn_stream; - -/* forward declaration of the internal primitive_desc types */ -struct batch_normalization_bwd_pd_t; -struct batch_normalization_fwd_pd_t; -struct batch_normalization_pd_t; -struct concat_pd_t; -struct convolution_bwd_data_pd_t; -struct convolution_bwd_weights_pd_t; -struct convolution_fwd_pd_t; -struct convolution_pd_t; -struct deconvolution_bwd_data_pd_t; -struct deconvolution_bwd_weights_pd_t; -struct deconvolution_fwd_pd_t; -struct deconvolution_pd_t; -struct eltwise_bwd_pd_t; -struct eltwise_fwd_pd_t; -struct eltwise_pd_t; -struct inner_product_bwd_data_pd_t; -struct inner_product_bwd_weights_pd_t; -struct inner_product_fwd_pd_t; -struct inner_product_pd_t; -struct lrn_bwd_pd_t; -struct lrn_fwd_pd_t; -struct lrn_pd_t; -struct pooling_bwd_pd_t; -struct pooling_fwd_pd_t; -struct pooling_pd_t; -struct reorder_pd_t; -struct rnn_bwd_pd_t; -struct rnn_fwd_pd_t; -struct rnn_pd_t; -struct shuffle_pd_t; -struct softmax_bwd_pd_t; -struct softmax_fwd_pd_t; -struct softmax_pd_t; -struct sum_pd_t; - -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/concat.cpp b/thirdparty/oidn/mkl-dnn/src/common/concat.cpp deleted file mode 100644 index ed4c35c6e..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/concat.cpp +++ /dev/null @@ -1,86 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#include - -#include "mkldnn.h" - -#include "c_types_map.hpp" -#include "engine.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -#include "concat_pd.hpp" - -using namespace mkldnn::impl; -using namespace mkldnn::impl::utils; -using namespace mkldnn::impl::status; - -status_t mkldnn_concat_primitive_desc_create(primitive_desc_t **concat_pd, - const memory_desc_t *dst_md, int n, int concat_dim, - const memory_desc_t *src_mds, - const primitive_attr_t *attr, - engine_t *engine) { - bool args_ok = !any_null(concat_pd, src_mds) && n > 0; - if (!args_ok) return invalid_arguments; - - const primitive_attr_t dummy_attr; - if (attr == NULL) - attr = &dummy_attr; - - const int ndims = src_mds[0].ndims; - const dims_t &dims = src_mds[0].dims; - const data_type_t dt = src_mds[0].data_type; - - int concat_dim_sz = dims[concat_dim]; - for (int i = 1; i < n; ++i) { - if (src_mds[i].ndims != ndims) return invalid_arguments; - for (int d = 0; d < ndims; ++d) { - if (d == concat_dim) continue; - if (src_mds[i].dims[d] != dims[d]) - return invalid_arguments; - } - if (src_mds[i].data_type != dt) return invalid_arguments; - concat_dim_sz += src_mds[i].dims[concat_dim]; - } - - memory_desc_t dummy_dst_md; - if (dst_md) { - if (dst_md->ndims != ndims) return invalid_arguments; - for (int d = 0; d < ndims; ++d) { - if (dst_md->dims[d] != - (d == concat_dim ? concat_dim_sz : dims[d])) - return invalid_arguments; - } - } else { - dummy_dst_md = src_mds[0]; - dummy_dst_md.dims[concat_dim] = concat_dim_sz; - dummy_dst_md.format_kind = format_kind::any; - dst_md = &dummy_dst_md; - } - - auto c_pd = reinterpret_cast(concat_pd); - - for (auto c = engine->get_concat_implementation_list(); *c; ++c) { - if ((*c)(c_pd, engine, attr, dst_md, n, concat_dim, src_mds) - == success) { - (*c_pd)->init_info(); - (*c_pd)->init_scratchpad_md(); - return success; - } - } - return unimplemented; -} diff --git a/thirdparty/oidn/mkl-dnn/src/common/concat_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/concat_pd.hpp deleted file mode 100644 index 29311927e..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/concat_pd.hpp +++ /dev/null @@ -1,211 +0,0 @@ -/******************************************************************************* -* Copyright 2019 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
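mkldnn_concat_primitive_desc_create() above derives the destination descriptor itself when dst_md is NULL. A hedged usage sketch concatenating two NCHW f32 sources along the channel axis; the public signature, with attr and engine passed as opaque handles, is assumed from the implementation shown above:

#include <stddef.h>
#include "mkldnn.h"

/* Hedged sketch: concatenate two f32 NCHW tensors along the channel axis.
 * Passing NULL for dst_md lets the library derive an 8x(3+5)x32x32
 * destination with format_kind "any", exactly as the dummy_dst_md branch
 * above does; NULL attr means default attributes. */
static mkldnn_status_t make_concat_pd(mkldnn_primitive_desc_t *concat_pd,
        mkldnn_engine_t engine) {
    mkldnn_memory_desc_t src[2];
    mkldnn_dims_t d0 = { 8, 3, 32, 32 };
    mkldnn_dims_t d1 = { 8, 5, 32, 32 };
    mkldnn_memory_desc_init_by_tag(&src[0], 4, d0, mkldnn_f32, mkldnn_nchw);
    mkldnn_memory_desc_init_by_tag(&src[1], 4, d1, mkldnn_f32, mkldnn_nchw);

    /* concat_dim = 1 (channels) */
    return mkldnn_concat_primitive_desc_create(
            concat_pd, NULL, 2, 1, src, NULL, engine);
}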
-*******************************************************************************/ - -#ifndef CONCAT_PD_HPP -#define CONCAT_PD_HPP - -#include - -#include "c_types_map.hpp" -#include "nstl.hpp" -#include "primitive_desc.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -namespace mkldnn { -namespace impl { - -struct concat_pd_t: public primitive_desc_t { - concat_pd_t(engine_t *engine, const primitive_attr_t *attr, - const memory_desc_t *dst_md, int n, int concat_dim, - const memory_desc_t *src_mds) - : primitive_desc_t(engine, attr, primitive_kind::concat) - , n_(n), concat_dim_(concat_dim), dst_md_(*dst_md) - { - src_mds_.reserve(n_); - for (int i = 0; i < n_; ++i) src_mds_.push_back(src_mds[i]); - } - - concat_pd_t(const concat_pd_t &rhs) = default; - - virtual void init_info() override { impl::init_info(this, this->info_); } - - virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { - if (arg >= MKLDNN_ARG_MULTIPLE_SRC - && arg < MKLDNN_ARG_MULTIPLE_SRC + n_inputs()) - return arg_usage_t::input; - - if (arg == MKLDNN_ARG_DST) - return arg_usage_t::output; - - return primitive_desc_t::arg_usage(arg); - } - - virtual const memory_desc_t *src_md(int index = 0) const override - { return index < n_inputs() ? &src_mds_[index] : nullptr; } - virtual const memory_desc_t *dst_md(int index = 0) const override - { return index == 0 ? &dst_md_ : nullptr; } - - virtual int n_inputs() const override { return n_; } - virtual int n_outputs() const override { return 1; } - - int concat_dim() const { return concat_dim_; } - - const memory_desc_t *src_image_md(int index = 0) const - { return index < n_inputs() ? &src_image_mds_[index] : nullptr; } - -protected: - int n_, concat_dim_; - memory_desc_t dst_md_; - nstl::vector src_mds_; - - /* contains images of srcs in the dst memory (if possible) - * Lives here to simplify some implementations. An implementation might - * use this auxiliary array iff init() returned success */ - nstl::vector src_image_mds_; - -protected: - /* inits src_image_mds_ and dst_md_ in simple cases. The call may fail */ - status_t init() { - bool ok = true - && set_default_params() == status::success - && attr()->has_default_values(); - if (!ok) return status::unimplemented; - - for (int i = 0; i < n_; ++i) { - const memory_desc_wrapper i_d(&src_mds_[i]); - if (!i_d.is_blocking_desc() || i_d.is_additional_buffer()) - return status::unimplemented; - } - - const int ndims = dst_md_.ndims; - int current_concat_dim_offset = 0; - for (int i = 0; i < n_; ++i) { - const int dim = src_mds_[i].dims[concat_dim_]; - dims_t dims, offsets = {}; - utils::array_copy(dims, dst_md_.dims, ndims); - dims[concat_dim_] = dim; - offsets[concat_dim_] = current_concat_dim_offset; - - memory_desc_t src_img_d; - status_t status = mkldnn_memory_desc_init_submemory(&src_img_d, - &dst_md_, dims, offsets); - if (status != status::success) return status; - src_image_mds_.push_back(src_img_d); - current_concat_dim_offset += dim; - } - - return status::success; - } - - status_t set_default_params() { - if (dst_md_.format_kind != format_kind::any) - return status::success; - - const int ndims = dst_md_.ndims; - - /* The stupidest ever heuristics (but not the same as we had before): - * - Pick the first non-plain format; - * - If all formats are plain or it is not possible to create a - * blocked format for the output, pick the format of the plain input - * - If this fails as well, use plain layout (abcd...) 
- */ - status_t status = status::unimplemented; - for (int i = 0; i < n_; ++i) { - const memory_desc_wrapper src_d(src_mds_[i]); - if (src_d.is_blocking_desc() && !src_d.is_plain()) { - status = memory_desc_init_by_blocking_desc(dst_md_, - src_d.blocking_desc()); - if (status == status::success) break; - } - } - - if (status == status::success) { - /* check if we can create a sub-memory for the dst */ - bool desired_format_ok = true; - int current_concat_dim_offset = 0; - for (int i = 0; i < n_; ++i) { - const int dim = src_mds_[i].dims[concat_dim_]; - dims_t dims, offsets = {}; - utils::array_copy(dims, dst_md_.dims, ndims); - dims[concat_dim_] = dim; - offsets[concat_dim_] = current_concat_dim_offset; - - memory_desc_t src_img_d; - status_t status = mkldnn_memory_desc_init_submemory(&src_img_d, - &dst_md_, dims, offsets); - if (status != status::success) { - desired_format_ok = false; - break; - } - current_concat_dim_offset += dim; - } - - if (!desired_format_ok) - status = status::unimplemented; - } - - /* if no success so far, try using the format of the first plain input */ - if (status != status::success) { - for (int i = 0; i < n_; ++i) { - const memory_desc_wrapper src_d(src_mds_[i]); - if (src_d.is_blocking_desc() && src_d.is_plain()) { - status = memory_desc_init_by_blocking_desc(dst_md_, - memory_desc_wrapper(src_mds_[0]).blocking_desc()); - if (status == status::success) return status; - } - } - } - - /* the last line of defense: use plain abcd... format */ - if (status != status::success) - status = memory_desc_init_by_strides(dst_md_, nullptr); - - return status; - } -}; - -#define DECLARE_CONCAT_PD_t(impl_name, ...) \ - static status_t create(concat_pd_t **concat_pd, \ - engine_t *engine, const primitive_attr_t *attr, \ - const memory_desc_t *dst_md, int n, int concat_dim, \ - const memory_desc_t *src_mds) { \ - using namespace status; \ - auto _pd = new pd_t(engine, attr, dst_md, n, concat_dim, src_mds); \ - if (_pd == nullptr) return out_of_memory; \ - if (_pd->init() != success) { delete _pd; return unimplemented; } \ - return safe_ptr_assign(*concat_pd, _pd); \ - } \ - virtual status_t create_primitive(primitive_t **p) const override { \ - double ms = get_msec(); \ - auto ret = safe_ptr_assign(*p, new (__VA_ARGS__)(this)); \ - ms = get_msec() - ms; \ - if (mkldnn_verbose()->level >= 2) { \ - printf("mkldnn_verbose,create,%s,%g\n", this->info(), ms); \ - fflush(0); \ - } \ - return ret; \ - } \ - virtual pd_t *clone() const override { return new pd_t(*this); } \ - virtual const char *name() const override { return impl_name; } \ - -#define DECLARE_CONCAT_PD_T(impl_name, ...) \ - DECLARE_CONCAT_PD_t(impl_name, __VA_ARGS__) - -} -} - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/convolution.cpp b/thirdparty/oidn/mkl-dnn/src/common/convolution.cpp deleted file mode 100644 index 0c5c02bcd..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/convolution.cpp +++ /dev/null @@ -1,200 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
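The conv_desc_init() consistency check in the convolution.cpp hunk below validates every spatial dimension against the usual dilated-convolution output formula, dst = (src - ker_range + pad_l + pad_r) / stride + 1 with ker_range = 1 + (ker - 1) * (dilate + 1). A small worked instance with purely illustrative numbers:

#include <assert.h>

/* Worked example of the spatial-size check used by conv_desc_init() below.
 * A 3x3 kernel with unit stride, no dilation, and padding 1 on both sides
 * keeps the spatial size ("same" padding). */
static void conv_output_size_example(void) {
    const int src = 64, ker = 3, dilate = 0, pad_l = 1, pad_r = 1, stride = 1;
    const int ker_range = 1 + (ker - 1) * (dilate + 1);            /* 3 */
    const int dst = (src - ker_range + pad_l + pad_r) / stride + 1; /* 64 */
    assert(dst == src);
    (void)dst;
}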
-* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include -#include "mkldnn.h" - -#include "c_types_map.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -using namespace mkldnn::impl; -using namespace mkldnn::impl::utils; -using namespace mkldnn::impl::status; -using namespace mkldnn::impl::prop_kind; -using namespace mkldnn::impl::alg_kind; -using namespace mkldnn::impl::types; - -namespace mkldnn { -namespace impl { -status_t conv_desc_init(convolution_desc_t *conv_desc, - prop_kind_t prop_kind, alg_kind_t alg_kind, - const memory_desc_t *src_desc, const memory_desc_t *weights_desc, - const memory_desc_t *bias_desc, const memory_desc_t *dst_desc, - const dims_t strides, const dims_t dilates, - const dims_t padding_l, const dims_t padding_r, - padding_kind_t padding_kind) { - bool args_ok = true - && !any_null(conv_desc, src_desc, weights_desc, dst_desc, strides, - padding_l) - && one_of(alg_kind, convolution_auto, convolution_direct, convolution_winograd) - && one_of(padding_kind, padding_kind::padding_zero); - if (!args_ok) return invalid_arguments; - - if (padding_r == nullptr) padding_r = padding_l; - - auto cd = convolution_desc_t(); - cd.primitive_kind = primitive_kind::convolution; - cd.prop_kind = prop_kind; - cd.alg_kind = alg_kind; - - cd.diff_src_desc = cd.src_desc = zero_md(); - cd.diff_dst_desc = cd.dst_desc = zero_md(); - cd.diff_weights_desc = cd.weights_desc = zero_md(); - cd.diff_bias_desc = cd.bias_desc = zero_md(); - - const bool is_fwd = one_of(prop_kind, forward_training, forward_inference); - const bool with_bias = - bias_desc && bias_desc->format_kind != format_kind::undef; - const bool with_groups = weights_desc->ndims == src_desc->ndims + 1; - - (prop_kind == backward_data ? cd.diff_src_desc : cd.src_desc) = *src_desc; - (is_fwd ? cd.dst_desc : cd.diff_dst_desc) = *dst_desc; - (prop_kind == backward_weights ? cd.diff_weights_desc : cd.weights_desc) = - *weights_desc; - if (with_bias) - (prop_kind == backward_weights ? cd.diff_bias_desc : cd.bias_desc) = - *bias_desc; - - int sp_dims = src_desc->ndims - 2; - utils::array_copy(cd.strides, strides, sp_dims); - utils::array_copy(cd.padding[0], padding_l, sp_dims); - utils::array_copy(cd.padding[1], padding_r, sp_dims); - if (dilates) - utils::array_copy(cd.dilates, dilates, sp_dims); - else - utils::array_set(cd.dilates, 0, sp_dims); - - cd.padding_kind = padding_kind; - cd.accum_data_type = types::default_accum_data_type(src_desc->data_type, - weights_desc->data_type, dst_desc->data_type, prop_kind); - - const int g = with_groups ? weights_desc->dims[0] : 1; - const int bias_dim = prop_kind == backward_data - ? src_desc->dims[1] - : dst_desc->dims[1]; - - bool consistency = true - && memory_desc_wrapper(weights_desc).nelems() - && src_desc->ndims == dst_desc->ndims - && utils::one_of(src_desc->ndims, 3, 4, 5) - && utils::one_of(weights_desc->ndims, src_desc->ndims, - src_desc->ndims + 1) - && (with_bias ? bias_desc->ndims == 1 : true) - && (with_bias ? 
bias_desc->dims[0] == bias_dim : true) - && src_desc->dims[0] == dst_desc->dims[0] - && src_desc->dims[1] == g * weights_desc->dims[with_groups + 1] - && dst_desc->dims[1] == g * weights_desc->dims[with_groups + 0]; - for (int i = 2; i < src_desc->ndims; ++i) - { - int src = src_desc->dims[i]; - int ker = weights_desc->dims[with_groups + i]; - int dil = cd.dilates[i - 2]; - int pad_l = padding_l[i - 2]; - int pad_r = padding_r[i - 2]; - int str = strides[i - 2]; - int dst = dst_desc->dims[i]; - int ker_range = 1 + (ker - 1) * (dil + 1); - - if (str < 1) return invalid_arguments; - consistency = consistency - && dil >= 0 - && pad_l >= 0 - && pad_r + str > 0 - && (src - ker_range + pad_l + pad_r) / str + 1 == dst; - } - if (!consistency) return invalid_arguments; - - *conv_desc = cd; - return success; -} -} -} - -status_t mkldnn_convolution_forward_desc_init(convolution_desc_t *conv_desc, - prop_kind_t prop_kind, alg_kind_t alg_kind, - const memory_desc_t *src_desc, const memory_desc_t *weights_desc, - const memory_desc_t *bias_desc, const memory_desc_t *dst_desc, - const dims_t strides, const dims_t padding_l, const dims_t padding_r, - padding_kind_t padding_kind) { - if (!one_of(prop_kind, forward_training, forward_inference)) - return invalid_arguments; - return mkldnn::impl::conv_desc_init(conv_desc, prop_kind, alg_kind, src_desc, - weights_desc, bias_desc, dst_desc, strides, nullptr, - padding_l, padding_r, padding_kind); -} - -status_t mkldnn_dilated_convolution_forward_desc_init( - convolution_desc_t *conv_desc, prop_kind_t prop_kind, - alg_kind_t alg_kind, const memory_desc_t *src_desc, - const memory_desc_t *weights_desc, const memory_desc_t *bias_desc, - const memory_desc_t *dst_desc, const dims_t strides, - const dims_t dilates, const dims_t padding_l, - const dims_t padding_r, padding_kind_t padding_kind) { - if (!one_of(prop_kind, forward_training, forward_inference)) - return invalid_arguments; - return mkldnn::impl::conv_desc_init(conv_desc, prop_kind, alg_kind, src_desc, - weights_desc, bias_desc, dst_desc, strides, dilates, - padding_l, padding_r, padding_kind); -} - -status_t mkldnn_convolution_backward_data_desc_init( - convolution_desc_t *conv_desc, alg_kind_t alg_kind, - const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc, - const memory_desc_t *diff_dst_desc, const dims_t strides, - const dims_t padding_l, const dims_t padding_r, - padding_kind_t padding_kind) { - return mkldnn::impl::conv_desc_init(conv_desc, backward_data, alg_kind, diff_src_desc, - weights_desc, nullptr, diff_dst_desc, strides, nullptr, - padding_l, padding_r, padding_kind); -} - -status_t mkldnn_dilated_convolution_backward_data_desc_init( - convolution_desc_t *conv_desc, alg_kind_t alg_kind, - const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc, - const memory_desc_t *diff_dst_desc, const dims_t strides, - const dims_t dilates, const dims_t padding_l, const dims_t padding_r, - padding_kind_t padding_kind) { - return mkldnn::impl::conv_desc_init(conv_desc, backward_data, alg_kind, diff_src_desc, - weights_desc, nullptr, diff_dst_desc, strides, dilates, - padding_l, padding_r, padding_kind); -} - -status_t mkldnn_convolution_backward_weights_desc_init( - convolution_desc_t *conv_desc, alg_kind_t alg_kind, - const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc, - const memory_desc_t *diff_bias_desc, - const memory_desc_t *diff_dst_desc, const dims_t strides, - const dims_t padding_l, const dims_t padding_r, - padding_kind_t padding_kind) { - return 
mkldnn::impl::conv_desc_init(conv_desc, backward_weights, alg_kind, src_desc, - diff_weights_desc, diff_bias_desc, diff_dst_desc, strides, - nullptr, padding_l, padding_r, padding_kind); -} - -status_t mkldnn_dilated_convolution_backward_weights_desc_init( - convolution_desc_t *conv_desc, alg_kind_t alg_kind, - const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc, - const memory_desc_t *diff_bias_desc, - const memory_desc_t *diff_dst_desc, const dims_t strides, - const dims_t dilates, const dims_t padding_l, const dims_t padding_r, - padding_kind_t padding_kind) { - return mkldnn::impl::conv_desc_init(conv_desc, backward_weights, alg_kind, src_desc, - diff_weights_desc, diff_bias_desc, diff_dst_desc, strides, - dilates, padding_l, padding_r, padding_kind); -} - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.cpp b/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.cpp deleted file mode 100644 index 9604e0acf..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.cpp +++ /dev/null @@ -1,56 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include "utils.hpp" - -#include "convolution_pd.hpp" - -namespace mkldnn { -namespace impl { - -using namespace prop_kind; - -memory_desc_t *conv_prop_invariant_src_d(convolution_desc_t *desc) { - return desc->prop_kind == backward_data - ? &desc->diff_src_desc : &desc->src_desc; -} - -memory_desc_t *conv_prop_invariant_wei_d(convolution_desc_t *desc) { - return desc->prop_kind == backward_weights - ? &desc->diff_weights_desc : &desc->weights_desc; -} - -memory_desc_t *conv_prop_invariant_bia_d(convolution_desc_t *desc) { - return desc->prop_kind == backward_weights - ? &desc->diff_bias_desc : &desc->bias_desc; -} - -memory_desc_t *conv_prop_invariant_dst_d(convolution_desc_t *desc) { - return utils::one_of(desc->prop_kind, forward_inference, forward_training) - ? 
&desc->dst_desc : &desc->diff_dst_desc; -} - -const memory_desc_t *conv_prop_invariant_src_d(const convolution_desc_t *desc) -{ return conv_prop_invariant_src_d(const_cast(desc)); } -const memory_desc_t *conv_prop_invariant_wei_d(const convolution_desc_t *desc) -{ return conv_prop_invariant_wei_d(const_cast(desc)); } -const memory_desc_t *conv_prop_invariant_bia_d(const convolution_desc_t *desc) -{ return conv_prop_invariant_bia_d(const_cast(desc)); } -const memory_desc_t *conv_prop_invariant_dst_d(const convolution_desc_t *desc) -{ return conv_prop_invariant_dst_d(const_cast(desc)); } - -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.hpp deleted file mode 100644 index b10c36db4..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.hpp +++ /dev/null @@ -1,348 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef CONVOLUTION_PD_HPP -#define CONVOLUTION_PD_HPP - -#include "mkldnn.h" - -#include "c_types_map.hpp" -#include "primitive_desc.hpp" -#include "utils.hpp" - -namespace mkldnn { -namespace impl { - -status_t conv_desc_init(convolution_desc_t *conv_desc, - prop_kind_t prop_kind, alg_kind_t alg_kind, - const memory_desc_t *src_desc, const memory_desc_t *weights_desc, - const memory_desc_t *bias_desc, const memory_desc_t *dst_desc, - const dims_t strides, const dims_t dilates, - const dims_t padding_l, const dims_t padding_r, - padding_kind_t padding_kind); - -memory_desc_t *conv_prop_invariant_src_d(convolution_desc_t *desc); -memory_desc_t *conv_prop_invariant_wei_d(convolution_desc_t *desc); -memory_desc_t *conv_prop_invariant_bia_d(convolution_desc_t *desc); -memory_desc_t *conv_prop_invariant_dst_d(convolution_desc_t *desc); -const memory_desc_t *conv_prop_invariant_src_d(const convolution_desc_t *desc); -const memory_desc_t *conv_prop_invariant_wei_d(const convolution_desc_t *desc); -const memory_desc_t *conv_prop_invariant_bia_d(const convolution_desc_t *desc); -const memory_desc_t *conv_prop_invariant_dst_d(const convolution_desc_t *desc); - -struct convolution_fwd_pd_t; - -struct convolution_pd_t: public primitive_desc_t { - static constexpr auto base_pkind = primitive_kind::convolution; - - convolution_pd_t(engine_t *engine, - const convolution_desc_t *adesc, - const primitive_attr_t *attr, - const convolution_fwd_pd_t *hint_fwd_pd) - : primitive_desc_t(engine, attr, base_pkind) - , desc_(*adesc) - , hint_fwd_pd_(hint_fwd_pd) - {} - - const convolution_desc_t *desc() const { return &desc_; } - virtual const op_desc_t *op_desc() const override - { return reinterpret_cast(this->desc()); } - virtual void init_info() override { impl::init_info(this, this->info_); } - - virtual status_t query(query_t what, int idx, void *result) const override { - switch (what) { - case 
pkind_traits::query_d: - *(const convolution_desc_t**)result = desc(); break; - default: return primitive_desc_t::query(what, idx, result); - } - return status::success; - } - - /* common conv aux functions */ - - dim_t MB() const { return _src_md()->dims[0]; } - - dim_t IC() const { return _src_md()->dims[1]; } - dim_t OC() const { return _dst_md()->dims[1]; } - dim_t G() const { return with_groups() ? _wei_md()->dims[0] : 1; } - - dim_t ID() const { return ndims() >= 5 ? _src_md()->dims[ndims() - 3] : 1; } - dim_t IH() const { return ndims() >= 4 ? _src_md()->dims[ndims() - 2] : 1; } - dim_t IW() const { return _src_md()->dims[ndims() - 1]; } - - dim_t OD() const { return ndims() >= 5 ? _dst_md()->dims[ndims() - 3] : 1; } - dim_t OH() const { return ndims() >= 4 ? _dst_md()->dims[ndims() - 2] : 1; } - dim_t OW() const { return _dst_md()->dims[ndims() - 1]; } - - dim_t KD() const { return ndims() >= 5 ? _wei_md()->dims[ndims() + with_groups() - 3] : 1; } - dim_t KH() const { return ndims() >= 4 ? _wei_md()->dims[ndims() + with_groups() - 2] : 1; } - dim_t KW() const { return _wei_md()->dims[ndims() + with_groups() - 1]; } - - dim_t KSD() const { return ndims() >= 5 ? desc_.strides[ndims() - 5] : 1; } - dim_t KSH() const { return ndims() >= 4 ? desc_.strides[ndims() - 4] : 1; } - dim_t KSW() const { return desc_.strides[ndims() - 3]; } - - dim_t KDD() const { return ndims() >= 5 ? desc_.dilates[ndims() - 5] : 0; } - dim_t KDH() const { return ndims() >= 4 ? desc_.dilates[ndims() - 4] : 1; } - dim_t KDW() const { return desc_.dilates[ndims() - 3]; } - - dim_t padFront() const { return ndims() >= 5 ? desc_.padding[0][ndims() - 5] : 0; } - dim_t padBack() const { return ndims() >= 5 ? desc_.padding[1][ndims() - 5] : 0; } - dim_t padT() const { return ndims() >= 4 ? desc_.padding[0][ndims() - 4] : 0; } - dim_t padB() const { return ndims() >= 4 ? 
desc_.padding[1][ndims() - 4] : 0; } - dim_t padL() const { return desc_.padding[0][ndims() - 3]; } - dim_t padR() const { return desc_.padding[1][ndims() - 3]; } - - int ndims() const { return _src_md()->ndims; } - - bool with_bias() const { return !memory_desc_wrapper(*_bia_md()).is_zero(); } - bool with_groups() const { return _wei_md()->ndims == ndims() + 1; } - - bool is_fwd() const { - return utils::one_of(desc_.prop_kind, prop_kind::forward_training, - prop_kind::forward_inference); - } - - bool has_zero_dim_memory() const { - const auto s_d = memory_desc_wrapper(*_src_md()); - const auto d_d = memory_desc_wrapper(*_dst_md()); - return s_d.has_zero_dim() || d_d.has_zero_dim(); - } - -protected: - convolution_desc_t desc_; - const convolution_fwd_pd_t *hint_fwd_pd_; - - bool set_default_formats_common_template( - memory_desc_t &src_md, format_tag_t src_tag, - memory_desc_t &wei_md, format_tag_t wei_tag, - memory_desc_t &dst_md, format_tag_t dst_tag, - memory_desc_t &bia_md) { - using namespace format_tag; - -# define IS_OK(f) \ - do { if ((f) != status::success) return false; } while(0) - if (src_md.format_kind == format_kind::any - && !utils::one_of(src_tag, any, undef)) - IS_OK(memory_desc_init_by_tag(src_md, src_tag)); - if (dst_md.format_kind == format_kind::any - && !utils::one_of(dst_tag, any, undef)) - IS_OK(memory_desc_init_by_tag(dst_md, dst_tag)); - if (wei_md.format_kind == format_kind::any - && !utils::one_of(wei_tag, any, undef)) - IS_OK(memory_desc_init_by_tag(wei_md, wei_tag)); - if (with_bias() && bia_md.format_kind == format_kind::any) - IS_OK(memory_desc_init_by_tag(bia_md, x)); -# undef IS_OK - - return true; - } - - bool set_default_alg_kind(alg_kind_t alg_kind) { - assert(utils::one_of(alg_kind, alg_kind::convolution_direct, - alg_kind::convolution_winograd)); - if (desc_.alg_kind == alg_kind::convolution_auto) - desc_.alg_kind = alg_kind; - return desc_.alg_kind == alg_kind; - } - - bool expect_data_types(data_type_t src_dt, data_type_t wei_dt, - data_type_t bia_dt, data_type_t dst_dt, data_type_t acc_dt) const { - bool ok = true - && (src_dt == data_type::undef || _src_md()->data_type == src_dt) - && (wei_dt == data_type::undef || _wei_md()->data_type == wei_dt) - && (dst_dt == data_type::undef || _dst_md()->data_type == dst_dt) - && (acc_dt == data_type::undef || desc_.accum_data_type == acc_dt); - if (with_bias() && bia_dt != data_type::undef) - ok = ok && _bia_md()->data_type == bia_dt; - return ok; - } - -private: - const memory_desc_t *_src_md() const { return conv_prop_invariant_src_d(&desc_); } - const memory_desc_t *_wei_md() const { return conv_prop_invariant_wei_d(&desc_); } - const memory_desc_t *_bia_md() const { return conv_prop_invariant_bia_d(&desc_); } - const memory_desc_t *_dst_md() const { return conv_prop_invariant_dst_d(&desc_); } -}; - -struct convolution_fwd_pd_t: public convolution_pd_t { - typedef convolution_fwd_pd_t base_class; - typedef convolution_fwd_pd_t hint_class; - - convolution_fwd_pd_t(engine_t *engine, - const convolution_desc_t *adesc, - const primitive_attr_t *attr, - const convolution_fwd_pd_t *hint_fwd_pd) - : convolution_pd_t(engine, adesc, attr, hint_fwd_pd) - , src_md_(desc_.src_desc) - , weights_md_(desc_.weights_desc) - , bias_md_(desc_.bias_desc) - , dst_md_(desc_.dst_desc) - {} - - virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { - if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_WEIGHTS)) - return arg_usage_t::input; - - if (arg == MKLDNN_ARG_BIAS && with_bias()) - return 
arg_usage_t::input; - - if (arg == MKLDNN_ARG_DST) - return arg_usage_t::output; - - return primitive_desc_t::arg_usage(arg); - } - - virtual const memory_desc_t *src_md(int index = 0) const override - { return index == 0 ? &src_md_ : nullptr; } - virtual const memory_desc_t *dst_md(int index = 0) const override - { return index == 0 ? &dst_md_ : nullptr; } - virtual const memory_desc_t *weights_md(int index = 0) const override { - if (index == 0) return &weights_md_; - if (index == 1 && with_bias()) return &bias_md_; - return nullptr; - } - - virtual int n_inputs() const override { return 2 + with_bias(); } - virtual int n_outputs() const override { return 1; } - -protected: - memory_desc_t src_md_; - memory_desc_t weights_md_; - memory_desc_t bias_md_; - memory_desc_t dst_md_; - - bool set_default_formats_common(format_tag_t src_tag, - format_tag_t wei_tag, format_tag_t dst_tag) { - return set_default_formats_common_template(src_md_, src_tag, - weights_md_, wei_tag, dst_md_, dst_tag, bias_md_); - } -}; - -struct convolution_bwd_data_pd_t: public convolution_pd_t { - typedef convolution_bwd_data_pd_t base_class; - typedef convolution_fwd_pd_t hint_class; - - convolution_bwd_data_pd_t(engine_t *engine, - const convolution_desc_t *adesc, - const primitive_attr_t *attr, - const convolution_fwd_pd_t *hint_fwd_pd) - : convolution_pd_t(engine, adesc, attr, hint_fwd_pd) - , diff_src_md_(desc_.diff_src_desc) - , weights_md_(desc_.weights_desc) - , bias_md_(desc_.bias_desc) - , diff_dst_md_(desc_.diff_dst_desc) - {} - - virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { - if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS, MKLDNN_ARG_DIFF_DST)) - return arg_usage_t::input; - - if (arg == MKLDNN_ARG_DIFF_SRC) - return arg_usage_t::output; - - return primitive_desc_t::arg_usage(arg); - } - - virtual const memory_desc_t *diff_src_md(int index = 0) const override - { return index == 0 ? &diff_src_md_ : nullptr; } - virtual const memory_desc_t *diff_dst_md(int index = 0) const override - { return index == 0 ? 
&diff_dst_md_ : nullptr; } - virtual const memory_desc_t *weights_md(int index = 0) const override { - if (index == 0) return &weights_md_; - if (index == 1 && with_bias()) return &bias_md_; - return nullptr; - } - - virtual int n_inputs() const override { return 2 + with_bias(); } - virtual int n_outputs() const override { return 1; } - - virtual bool support_bias() const { return false; } - -protected: - memory_desc_t diff_src_md_; - memory_desc_t weights_md_; - memory_desc_t bias_md_; - memory_desc_t diff_dst_md_; - - bool set_default_formats_common(format_tag_t diff_src_tag, - format_tag_t wei_tag, format_tag_t diff_dst_tag) { - return set_default_formats_common_template(diff_src_md_, diff_src_tag, - weights_md_, wei_tag, diff_dst_md_, diff_dst_tag, bias_md_); - } -}; - -struct convolution_bwd_weights_pd_t: public convolution_pd_t { - typedef convolution_bwd_weights_pd_t base_class; - typedef convolution_fwd_pd_t hint_class; - - convolution_bwd_weights_pd_t(engine_t *engine, - const convolution_desc_t *adesc, - const primitive_attr_t *attr, - const convolution_fwd_pd_t *hint_fwd_pd) - : convolution_pd_t(engine, adesc, attr, hint_fwd_pd) - , src_md_(desc_.src_desc) - , diff_weights_md_(desc_.diff_weights_desc) - , diff_bias_md_(desc_.diff_bias_desc) - , diff_dst_md_(desc_.diff_dst_desc) - {} - - virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { - if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST)) - return arg_usage_t::input; - - if (arg == MKLDNN_ARG_DIFF_WEIGHTS) - return arg_usage_t::output; - - if (arg == MKLDNN_ARG_DIFF_BIAS && with_bias()) - return arg_usage_t::output; - - return primitive_desc_t::arg_usage(arg); - } - - virtual const memory_desc_t *src_md(int index = 0) const override - { return index == 0 ? &src_md_ : nullptr; } - virtual const memory_desc_t *diff_dst_md(int index = 0) const override - { return index == 0 ? &diff_dst_md_ : nullptr; } - virtual const memory_desc_t *diff_weights_md(int index = 0) const override { - if (index == 0) return &diff_weights_md_; - if (index == 1 && with_bias()) return &diff_bias_md_; - return nullptr; - } - - virtual int n_inputs() const override { return 2; } - virtual int n_outputs() const override { return 1 + with_bias(); } - -protected: - memory_desc_t src_md_; - memory_desc_t diff_weights_md_; - memory_desc_t diff_bias_md_; - memory_desc_t diff_dst_md_; - - bool set_default_formats_common(format_tag_t src_tag, - format_tag_t diff_wei_tag, format_tag_t diff_dst_tag) { - return set_default_formats_common_template(src_md_, src_tag, - diff_weights_md_, diff_wei_tag, diff_dst_md_, diff_dst_tag, - diff_bias_md_); - } -}; - -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/deconvolution.cpp b/thirdparty/oidn/mkl-dnn/src/common/deconvolution.cpp deleted file mode 100644 index 98063c1c3..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/deconvolution.cpp +++ /dev/null @@ -1,188 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include "mkldnn.h" -#include - -#include "c_types_map.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -using namespace mkldnn::impl; -using namespace mkldnn::impl::utils; -using namespace mkldnn::impl::status; -using namespace mkldnn::impl::prop_kind; -using namespace mkldnn::impl::alg_kind; -using namespace mkldnn::impl::types; - -namespace { -status_t deconv_desc_init(deconvolution_desc_t *deconv_desc, - prop_kind_t prop_kind, alg_kind_t alg_kind, - const memory_desc_t *src_desc, const memory_desc_t *weights_desc, - const memory_desc_t *bias_desc, const memory_desc_t *dst_desc, - const dims_t strides, const dims_t dilates, const dims_t padding_l, - const dims_t padding_r, padding_kind_t padding_kind) { - bool args_ok = true - && !any_null(deconv_desc, src_desc, weights_desc, dst_desc, strides, - padding_l) - && one_of(alg_kind, deconvolution_direct, deconvolution_winograd) - && one_of(padding_kind, padding_kind::padding_zero); - if (!args_ok) - return invalid_arguments; - - if (padding_r == nullptr) - padding_r = padding_l; - - auto dd = deconvolution_desc_t(); - dd.primitive_kind = primitive_kind::deconvolution; - dd.prop_kind = prop_kind; - dd.alg_kind = alg_kind; - - dd.diff_src_desc = dd.src_desc = zero_md(); - dd.diff_dst_desc = dd.dst_desc = zero_md(); - dd.diff_weights_desc = dd.weights_desc = zero_md(); - dd.diff_bias_desc = dd.bias_desc = zero_md(); - - const bool is_fwd = one_of(prop_kind, forward_training, forward_inference); - const bool with_bias - = bias_desc && bias_desc->format_kind != format_kind::undef; - const bool with_groups = weights_desc->ndims == src_desc->ndims + 1; - - (prop_kind == backward_data ? dd.diff_src_desc : dd.src_desc) = *src_desc; - (is_fwd ? dd.dst_desc : dd.diff_dst_desc) = *dst_desc; - (prop_kind == backward_weights ? dd.diff_weights_desc : dd.weights_desc) - = *weights_desc; - if (with_bias) - (prop_kind == backward_weights ? dd.diff_bias_desc : dd.bias_desc) - = *bias_desc; - - int sp_dims = src_desc->ndims - 2; - utils::array_copy(dd.strides, strides, sp_dims); - utils::array_copy(dd.padding[0], padding_l, sp_dims); - utils::array_copy(dd.padding[1], padding_r, sp_dims); - if (dilates) - utils::array_copy(dd.dilates, dilates, sp_dims); - else - utils::array_set(dd.dilates, 0, sp_dims); - - dd.padding_kind = padding_kind; - dd.accum_data_type = types::default_accum_data_type(src_desc->data_type, - weights_desc->data_type, dst_desc->data_type, prop_kind); - - const int g = with_groups ? weights_desc->dims[0] : 1; - bool consistency = true - && src_desc->ndims == dst_desc->ndims - && utils::one_of(src_desc->ndims, 3, 4, 5) - && utils::one_of(weights_desc->ndims, src_desc->ndims, - src_desc->ndims + 1) - && (with_bias ? bias_desc->ndims == 1 : true) - && (with_bias ? 
bias_desc->dims[0] == dst_desc->dims[1] : true) - && src_desc->dims[0] == dst_desc->dims[0] - && src_desc->dims[1] == g * weights_desc->dims[with_groups + 1] - && dst_desc->dims[1] == g * weights_desc->dims[with_groups + 0]; - for (int i = 2; i < src_desc->ndims; ++i) { - int src = src_desc->dims[i]; - int ker = weights_desc->dims[with_groups + i]; - int dil = dd.dilates[i - 2]; - int pad = padding_l[i - 2] + padding_r[i - 2]; - int str = strides[i - 2]; - int dst = dst_desc->dims[i]; - int ker_range = 1 + (ker - 1) * (dil + 1); - - consistency - = consistency && (dst - ker_range + pad) / str + 1 == src; - } - if (!consistency) - return invalid_arguments; - - *deconv_desc = dd; - return success; -} -} - -status_t mkldnn_deconvolution_forward_desc_init( - deconvolution_desc_t *deconv_desc, prop_kind_t prop_kind, - alg_kind_t alg_kind, const memory_desc_t *src_desc, - const memory_desc_t *weights_desc, const memory_desc_t *bias_desc, - const memory_desc_t *dst_desc, const dims_t strides, - const dims_t padding_l, const dims_t padding_r, - padding_kind_t padding_kind) { - if (!one_of(prop_kind, forward_training, forward_inference)) - return invalid_arguments; - return deconv_desc_init(deconv_desc, prop_kind, alg_kind, src_desc, - weights_desc, bias_desc, dst_desc, strides, nullptr, padding_l, - padding_r, padding_kind); -} - -status_t mkldnn_dilated_deconvolution_forward_desc_init( - deconvolution_desc_t *deconv_desc, prop_kind_t prop_kind, - alg_kind_t alg_kind, const memory_desc_t *src_desc, - const memory_desc_t *weights_desc, const memory_desc_t *bias_desc, - const memory_desc_t *dst_desc, const dims_t strides, - const dims_t dilates, const dims_t padding_l, const dims_t padding_r, - padding_kind_t padding_kind) { - if (!one_of(prop_kind, forward_training, forward_inference)) - return invalid_arguments; - return deconv_desc_init(deconv_desc, prop_kind, alg_kind, src_desc, - weights_desc, bias_desc, dst_desc, strides, dilates, padding_l, - padding_r, padding_kind); -} - -status_t mkldnn_deconvolution_backward_data_desc_init( - deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind, - const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc, - const memory_desc_t *diff_dst_desc, const dims_t strides, - const dims_t padding_l, const dims_t padding_r, - padding_kind_t padding_kind) { - return deconv_desc_init(deconv_desc, backward_data, alg_kind, diff_src_desc, - weights_desc, nullptr, diff_dst_desc, strides, nullptr, padding_l, - padding_r, padding_kind); -} - -status_t mkldnn_dilated_deconvolution_backward_data_desc_init( - deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind, - const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc, - const memory_desc_t *diff_dst_desc, const dims_t strides, - const dims_t dilates, const dims_t padding_l, const dims_t padding_r, - padding_kind_t padding_kind) { - return deconv_desc_init(deconv_desc, backward_data, alg_kind, diff_src_desc, - weights_desc, nullptr, diff_dst_desc, strides,dilates, padding_l, - padding_r, padding_kind); -} - -status_t mkldnn_deconvolution_backward_weights_desc_init( - deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind, - const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc, - const memory_desc_t *diff_bias_desc, const memory_desc_t *diff_dst_desc, - const dims_t strides, const dims_t padding_l, const dims_t padding_r, - padding_kind_t padding_kind) { - return deconv_desc_init(deconv_desc, backward_weights, alg_kind, src_desc, - diff_weights_desc, diff_bias_desc, diff_dst_desc, 
strides, nullptr, - padding_l, padding_r, padding_kind); -} - -status_t mkldnn_dilated_deconvolution_backward_weights_desc_init( - deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind, - const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc, - const memory_desc_t *diff_bias_desc, const memory_desc_t *diff_dst_desc, - const dims_t strides, const dims_t dilates, const dims_t padding_l, - const dims_t padding_r, padding_kind_t padding_kind) { - return deconv_desc_init(deconv_desc, backward_weights, alg_kind, src_desc, - diff_weights_desc, diff_bias_desc, diff_dst_desc, strides, dilates, - padding_l, padding_r, padding_kind); -} - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/deconvolution_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/deconvolution_pd.hpp deleted file mode 100644 index 539e44bd9..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/deconvolution_pd.hpp +++ /dev/null @@ -1,293 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef DECONVOLUTION_PD_HPP -#define DECONVOLUTION_PD_HPP - -#include "mkldnn.h" - -#include "c_types_map.hpp" -#include "convolution_pd.hpp" -#include "primitive_desc.hpp" -#include "utils.hpp" - -namespace mkldnn { -namespace impl { - -struct deconvolution_fwd_pd_t; - -struct deconvolution_pd_t: public primitive_desc_t { - static constexpr auto base_pkind = primitive_kind::deconvolution; - - deconvolution_pd_t(engine_t *engine, - const deconvolution_desc_t *adesc, - const primitive_attr_t *attr, - const deconvolution_fwd_pd_t *hint_fwd_pd) - : primitive_desc_t(engine, attr, base_pkind) - , desc_(*adesc) - , hint_fwd_pd_(hint_fwd_pd) - {} - - const deconvolution_desc_t *desc() const { return &desc_; } - virtual const op_desc_t *op_desc() const override - { return reinterpret_cast(this->desc()); } - virtual void init_info() override { impl::init_info(this, this->info_); } - - virtual status_t query(query_t what, int idx, void *result) const override { - switch (what) { - case pkind_traits::query_d: - *(const deconvolution_desc_t **)result = desc(); - break; - default: return primitive_desc_t::query(what, idx, result); - } - return status::success; - } - - /* common deconv aux functions (note that conv_desc_t == deconv_desc_t) */ - - dim_t MB() const { return conv_prop_invariant_src_d(&desc_)->dims[0]; } - - dim_t IC() const { return conv_prop_invariant_src_d(&desc_)->dims[1]; } - dim_t OC() const { return conv_prop_invariant_dst_d(&desc_)->dims[1]; } - dim_t G() const - { return with_groups() ? conv_prop_invariant_wei_d(&desc_)->dims[0] : 1; } - - dim_t ID() const { - return ndims() >= 5 - ? conv_prop_invariant_src_d(&desc_)->dims[ndims() - 3] : 1; - } - dim_t IH() const { - return ndims() >= 4 - ? 
conv_prop_invariant_src_d(&desc_)->dims[ndims() - 2] : 1; - } - dim_t IW() const { - return conv_prop_invariant_src_d(&desc_)->dims[ndims() - 1]; - } - - dim_t OD() const { - return ndims() >= 5 - ? conv_prop_invariant_dst_d(&desc_)->dims[ndims() - 3] : 1; - } - dim_t OH() const { - return ndims() >= 4 - ? conv_prop_invariant_dst_d(&desc_)->dims[ndims() - 2] : 1; - } - dim_t OW() const { - return conv_prop_invariant_dst_d(&desc_)->dims[ndims() - 1]; - } - - dim_t KD() const { - const int w_ndims = ndims() + with_groups(); - return ndims() >= 5 - ? conv_prop_invariant_wei_d(&desc_)->dims[w_ndims - 3] : 1; - } - dim_t KH() const { - const int w_ndims = ndims() + with_groups(); - return ndims() >= 4 - ? conv_prop_invariant_wei_d(&desc_)->dims[w_ndims - 2] : 1; - } - dim_t KW() const { - const int w_ndims = ndims() + with_groups(); - return conv_prop_invariant_wei_d(&desc_)->dims[w_ndims - 1]; - } - - dim_t KSD() const { return ndims() >= 5 ? desc_.strides[ndims() - 5] : 1; } - dim_t KSH() const { return ndims() >= 4 ? desc_.strides[ndims() - 4] : 1; } - dim_t KSW() const { return desc_.strides[ndims() - 3]; } - - dim_t KDD() const { return ndims() >= 5 ? desc_.dilates[ndims() - 5] : 0; } - dim_t KDH() const { return ndims() >= 4 ? desc_.dilates[ndims() - 4] : 1; } - dim_t KDW() const { return desc_.dilates[ndims() - 3]; } - - dim_t padFront() const - { return ndims() >= 5 ? desc_.padding[0][ndims() - 5] : 0; } - dim_t padBack() const - { return ndims() >= 5 ? desc_.padding[1][ndims() - 5] : 0; } - dim_t padT() const - { return ndims() >= 4 ? desc_.padding[0][ndims() - 4] : 0; } - dim_t padB() const - { return ndims() >= 4 ? desc_.padding[1][ndims() - 4] : 0; } - dim_t padL() const { return desc_.padding[0][ndims() - 3]; } - dim_t padR() const { return desc_.padding[1][ndims() - 3]; } - - bool with_bias() const { - return - !memory_desc_wrapper(*conv_prop_invariant_bia_d(&desc_)).is_zero(); - } - - bool with_groups() const - { return conv_prop_invariant_wei_d(&desc_)->ndims == ndims() + 1; } - - int ndims() const { return conv_prop_invariant_src_d(&desc_)->ndims; } - - bool is_fwd() const { - return utils::one_of(desc_.prop_kind, prop_kind::forward_training, - prop_kind::forward_inference); - } - - bool has_zero_dim_memory() const { - const auto s_d = memory_desc_wrapper(*conv_prop_invariant_src_d(&desc_)); - const auto d_d = memory_desc_wrapper(*conv_prop_invariant_dst_d(&desc_)); - return s_d.has_zero_dim() || d_d.has_zero_dim(); - } - -protected: - deconvolution_desc_t desc_; - const deconvolution_fwd_pd_t *hint_fwd_pd_; -}; - -struct deconvolution_fwd_pd_t: public deconvolution_pd_t { - typedef deconvolution_fwd_pd_t base_class; - typedef deconvolution_fwd_pd_t hint_class; - - deconvolution_fwd_pd_t(engine_t *engine, - const deconvolution_desc_t *adesc, - const primitive_attr_t *attr, - const deconvolution_fwd_pd_t *hint_fwd_pd) - : deconvolution_pd_t(engine, adesc, attr, hint_fwd_pd) - , src_md_(desc_.src_desc) - , weights_md_(desc_.weights_desc) - , bias_md_(desc_.bias_desc) - , dst_md_(desc_.dst_desc) - {} - - virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { - if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_WEIGHTS)) - return arg_usage_t::input; - - if (arg == MKLDNN_ARG_BIAS && with_bias()) - return arg_usage_t::input; - - if (arg == MKLDNN_ARG_DST) - return arg_usage_t::output; - - return primitive_desc_t::arg_usage(arg); - } - - virtual const memory_desc_t *src_md(int index = 0) const override - { return index == 0 ? 
&src_md_ : nullptr; } - virtual const memory_desc_t *dst_md(int index = 0) const override - { return index == 0 ? &dst_md_ : nullptr; } - virtual const memory_desc_t *weights_md(int index = 0) const override { - if (index == 0) return &weights_md_; - if (index == 1 && with_bias()) return &bias_md_; - return nullptr; - } - - virtual int n_inputs() const override { return 2 + with_bias(); } - virtual int n_outputs() const override { return 1; } - -protected: - memory_desc_t src_md_; - memory_desc_t weights_md_; - memory_desc_t bias_md_; - memory_desc_t dst_md_; -}; - -struct deconvolution_bwd_data_pd_t: public deconvolution_pd_t { - typedef deconvolution_bwd_data_pd_t base_class; - typedef deconvolution_fwd_pd_t hint_class; - - deconvolution_bwd_data_pd_t(engine_t *engine, - const deconvolution_desc_t *adesc, - const primitive_attr_t *attr, - const deconvolution_fwd_pd_t *hint_fwd_pd) - : deconvolution_pd_t(engine, adesc, attr, hint_fwd_pd) - , diff_src_md_(desc_.diff_src_desc) - , weights_md_(desc_.weights_desc) - , diff_dst_md_(desc_.diff_dst_desc) - {} - - virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { - if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS, MKLDNN_ARG_DIFF_DST)) - return arg_usage_t::input; - - if (arg == MKLDNN_ARG_DIFF_SRC) - return arg_usage_t::output; - - return primitive_desc_t::arg_usage(arg); - } - - virtual const memory_desc_t *diff_src_md(int index = 0) const override - { return index == 0 ? &diff_src_md_ : nullptr; } - virtual const memory_desc_t *diff_dst_md(int index = 0) const override - { return index == 0 ? &diff_dst_md_ : nullptr; } - virtual const memory_desc_t *weights_md(int index = 0) const override - { return index == 0 ? &weights_md_ : nullptr; } - - virtual int n_inputs() const override { return 2; } - virtual int n_outputs() const override { return 1; } - -protected: - memory_desc_t diff_src_md_; - memory_desc_t weights_md_; - memory_desc_t diff_dst_md_; -}; - -struct deconvolution_bwd_weights_pd_t: public deconvolution_pd_t { - typedef deconvolution_bwd_weights_pd_t base_class; - typedef deconvolution_fwd_pd_t hint_class; - - deconvolution_bwd_weights_pd_t(engine_t *engine, - const deconvolution_desc_t *adesc, - const primitive_attr_t *attr, - const deconvolution_fwd_pd_t *hint_fwd_pd) - : deconvolution_pd_t(engine, adesc, attr, hint_fwd_pd) - , src_md_(desc_.src_desc) - , diff_weights_md_(desc_.diff_weights_desc) - , diff_bias_md_(desc_.diff_bias_desc) - , diff_dst_md_(desc_.diff_dst_desc) - {} - - virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { - if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST)) - return arg_usage_t::input; - - if (arg == MKLDNN_ARG_DIFF_WEIGHTS) - return arg_usage_t::output; - - if (arg == MKLDNN_ARG_DIFF_BIAS && with_bias()) - return arg_usage_t::output; - - return primitive_desc_t::arg_usage(arg); - } - - virtual const memory_desc_t *src_md(int index = 0) const override - { return index == 0 ? &src_md_ : nullptr; } - virtual const memory_desc_t *diff_dst_md(int index = 0) const override - { return index == 0 ? 
&diff_dst_md_ : nullptr; } - virtual const memory_desc_t *diff_weights_md(int index = 0) const override { - if (index == 0) return &diff_weights_md_; - if (index == 1 && with_bias()) return &diff_bias_md_; - return nullptr; - } - - virtual int n_inputs() const override { return 2; } - virtual int n_outputs() const override { return 1 + with_bias(); } - -protected: - memory_desc_t src_md_; - memory_desc_t diff_weights_md_; - memory_desc_t diff_bias_md_; - memory_desc_t diff_dst_md_; -}; - -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/eltwise.cpp b/thirdparty/oidn/mkl-dnn/src/common/eltwise.cpp deleted file mode 100644 index f1708fca5..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/eltwise.cpp +++ /dev/null @@ -1,84 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include -#include "mkldnn.h" - -#include "c_types_map.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -using namespace mkldnn::impl; -using namespace mkldnn::impl::utils; -using namespace mkldnn::impl::status; -using namespace mkldnn::impl::prop_kind; -using namespace mkldnn::impl::alg_kind; -using namespace mkldnn::impl::types; - -namespace { -status_t eltwise_desc_init(eltwise_desc_t *eltwise_desc, prop_kind_t prop_kind, - alg_kind_t alg_kind, const memory_desc_t *data_desc, - const memory_desc_t *diff_data_desc, float alpha, float beta) { - bool args_ok = true - && !any_null(eltwise_desc, data_desc) - && one_of(prop_kind, forward_training, forward_inference, - backward_data) - && one_of(alg_kind, eltwise_relu, eltwise_tanh, eltwise_elu, - eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear, - eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic) - && IMPLICATION(prop_kind == backward_data, diff_data_desc != nullptr); - if (!args_ok) return invalid_arguments; - - auto ed = eltwise_desc_t(); - ed.primitive_kind = primitive_kind::eltwise; - ed.prop_kind = prop_kind; - ed.alg_kind = alg_kind; - - ed.data_desc = *data_desc; - ed.diff_data_desc = - (ed.prop_kind == backward_data) ? 
*diff_data_desc : zero_md(); - - ed.alpha = alpha; - ed.beta = beta; - - bool consistency = true - && IMPLICATION(ed.prop_kind == backward_data, - array_cmp(ed.diff_data_desc.dims, ed.data_desc.dims, - ed.diff_data_desc.ndims)); - if (!consistency) return invalid_arguments; - - *eltwise_desc = ed; - return success; -} -} - -status_t mkldnn_eltwise_forward_desc_init(eltwise_desc_t *eltwise_desc, - prop_kind_t prop_kind, alg_kind_t alg_kind, - const memory_desc_t *data_desc, float alpha, float beta) { - if (!one_of(prop_kind, forward_training, forward_inference)) - return invalid_arguments; - return eltwise_desc_init(eltwise_desc, prop_kind, alg_kind, data_desc, - nullptr, alpha, beta); -} - -status_t mkldnn_eltwise_backward_desc_init(eltwise_desc_t *eltwise_desc, - alg_kind_t alg_kind, const memory_desc_t *diff_data_desc, - const memory_desc_t *data_desc, float alpha, float beta) { - return eltwise_desc_init(eltwise_desc, backward_data, alg_kind, data_desc, - diff_data_desc, alpha, beta); -} - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/eltwise_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/eltwise_pd.hpp deleted file mode 100644 index 9fd260fce..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/eltwise_pd.hpp +++ /dev/null @@ -1,161 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef ELTWISE_PD_HPP -#define ELTWISE_PD_HPP - -#include "mkldnn.h" - -#include "c_types_map.hpp" -#include "primitive_desc.hpp" - -namespace mkldnn { -namespace impl { - -struct eltwise_fwd_pd_t; - -struct eltwise_pd_t: public primitive_desc_t { - static constexpr auto base_pkind = primitive_kind::eltwise; - - eltwise_pd_t(mkldnn::impl::engine_t *engine, - const eltwise_desc_t *adesc, - const primitive_attr_t *attr, - const eltwise_fwd_pd_t *hint_fwd_pd) - : primitive_desc_t(engine, attr, base_pkind) - , desc_(*adesc) - , hint_fwd_pd_(hint_fwd_pd) - , data_md_(desc_.data_desc) - {} - - const eltwise_desc_t *desc() const { return &desc_; } - virtual const op_desc_t *op_desc() const override - { return reinterpret_cast(this->desc()); } - virtual void init_info() override { impl::init_info(this, this->info_); } - - virtual status_t query(query_t what, int idx, void *result) const override { - switch (what) { - case query::eltwise_d: - *(const eltwise_desc_t**)result = desc(); break; - default: return primitive_desc_t::query(what, idx, result); - } - return status::success; - } - - /* common eltwise aux functions */ - - dim_t MB() const { return data_desc().dims[0]; } - dim_t C() const { return data_desc().dims[1]; } - dim_t D() const { return ndims() >= 5 ? data_desc().dims[ndims() - 3] : 1; } - dim_t H() const { return ndims() >= 4 ? data_desc().dims[ndims() - 2] : 1; } - dim_t W() const { return ndims() >= 3 ? 
data_desc().dims[ndims() - 1] : 1; } - - int ndims() const { return data_desc().ndims; } - - bool is_fwd() const { - return utils::one_of(desc_.prop_kind, prop_kind::forward_training, - prop_kind::forward_inference); - } - - bool has_zero_dim_memory() const - { return memory_desc_wrapper(desc_.data_desc).has_zero_dim(); } - -protected: - eltwise_desc_t desc_; - const eltwise_fwd_pd_t *hint_fwd_pd_; - - memory_desc_t data_md_; - -private: - const memory_desc_t &data_desc() const { return desc_.data_desc; } -}; - -struct eltwise_fwd_pd_t: public eltwise_pd_t { - typedef eltwise_fwd_pd_t base_class; - typedef eltwise_fwd_pd_t hint_class; - - eltwise_fwd_pd_t(mkldnn::impl::engine_t *engine, - const eltwise_desc_t *adesc, - const primitive_attr_t *attr, - const eltwise_fwd_pd_t *hint_fwd_pd) - : eltwise_pd_t(engine, adesc, attr, hint_fwd_pd) - {} - - virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { - if (arg == MKLDNN_ARG_SRC) - return arg_usage_t::input; - - if (arg == MKLDNN_ARG_DST) - return arg_usage_t::output; - - return primitive_desc_t::arg_usage(arg); - } - - virtual const memory_desc_t *src_md(int index = 0) const override - { return index == 0 ? &data_md_ : nullptr; } - virtual const memory_desc_t *dst_md(int index = 0) const override - { return index == 0 ? &data_md_ : nullptr; } - - virtual int n_inputs() const override { return 1; } - virtual int n_outputs() const override { return 1; } - - bool is_zero_preserved() const - { return math::eltwise_fwd_preserves_zero(desc_.alg_kind); } -}; - -struct eltwise_bwd_pd_t: public eltwise_pd_t { - typedef eltwise_bwd_pd_t base_class; - typedef eltwise_fwd_pd_t hint_class; - - eltwise_bwd_pd_t(engine_t *engine, - const eltwise_desc_t *adesc, - const primitive_attr_t *attr, - const eltwise_fwd_pd_t *hint_fwd_pd) - : eltwise_pd_t(engine, adesc, attr, hint_fwd_pd) - , diff_data_md_(desc_.diff_data_desc) - {} - - virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { - if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST)) - return arg_usage_t::input; - - if (arg == MKLDNN_ARG_DIFF_SRC) - return arg_usage_t::output; - - return primitive_desc_t::arg_usage(arg); - } - - virtual const memory_desc_t *src_md(int index = 0) const override - { return index == 0 ? &data_md_ : nullptr; } - virtual const memory_desc_t *diff_dst_md(int index = 0) const override - { return index == 0 ? &diff_data_md_ : nullptr; } - virtual const memory_desc_t *diff_src_md(int index = 0) const override - { return index == 0 ? &diff_data_md_ : nullptr; } - - virtual int n_inputs() const override { return 2; } - virtual int n_outputs() const override { return 1; } - - bool is_zero_preserved() const { return true; } - -protected: - memory_desc_t diff_data_md_; -}; - -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/engine.cpp b/thirdparty/oidn/mkl-dnn/src/common/engine.cpp deleted file mode 100644 index 3b3e25456..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/engine.cpp +++ /dev/null @@ -1,75 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. 
-* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include "mkldnn.h" -#include "engine.hpp" -#include "nstl.hpp" - -#include "c_types_map.hpp" -#include "../cpu/cpu_engine.hpp" - -namespace mkldnn { -namespace impl { - -engine_factory_t *engine_factories[] = { - &cpu::engine_factory, - nullptr, -}; - -static inline engine_factory_t *get_engine_factory(engine_kind_t kind) { - for (engine_factory_t **ef = engine_factories; *ef; ef++) - if ((*ef)->kind() == kind) - return *ef; - return nullptr; -} - -} -} - -using namespace mkldnn::impl; -using namespace mkldnn::impl::status; - -size_t mkldnn_engine_get_count(engine_kind_t kind) { - engine_factory_t *ef = get_engine_factory(kind); - return ef != nullptr ? ef->count() : 0; -} - -status_t mkldnn_engine_create(engine_t **engine, - engine_kind_t kind, size_t index) { - if (engine == nullptr) - return invalid_arguments; - - engine_factory_t *ef = get_engine_factory(kind); - if (ef == nullptr || index >= ef->count()) - return invalid_arguments; - - return ef->engine_create(engine, index); -} - -status_t mkldnn_engine_get_kind(engine_t *engine, engine_kind_t *kind) { - if (engine == nullptr) - return invalid_arguments; - *kind = engine->kind(); - return success; -} - -status_t mkldnn_engine_destroy(engine_t *engine) { - /* TODO: engine->dec_ref_count(); */ - delete engine; - return success; -} - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/engine.hpp b/thirdparty/oidn/mkl-dnn/src/common/engine.hpp deleted file mode 100644 index 8ac8a29de..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/engine.hpp +++ /dev/null @@ -1,119 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#ifndef ENGINE_HPP -#define ENGINE_HPP - -#include "mkldnn.h" - -#include "c_types_map.hpp" -#include "primitive.hpp" -#include "utils.hpp" - -/** \brief An abstraction of an execution unit with shared resources - * - * Responsibilities: - * - Provide engine specific memory allocation - * - Provide engine specific primitive_desc_t creators - */ -struct mkldnn_engine: public mkldnn::impl::c_compatible { - mkldnn_engine(mkldnn::impl::engine_kind_t kind) - : kind_(kind) - {} - virtual ~mkldnn_engine() {} - - /** get kind of the current engine */ - virtual mkldnn::impl::engine_kind_t kind() const { return kind_; } - - /** allocate memory */ - virtual mkldnn::impl::status_t memory_create( - mkldnn::impl::memory_t **memory, - const mkldnn::impl::memory_desc_t *md, - void *handle) = 0; - - /** implementation section (typedefs) */ - - // TODO: remove engine? - typedef mkldnn::impl::status_t (*reorder_primitive_desc_create_f)( - mkldnn::impl::reorder_pd_t **reorder_pd, - mkldnn::impl::engine_t *engine, - const mkldnn::impl::primitive_attr_t *attr, - mkldnn::impl::engine_t *src_engine, - const mkldnn::impl::memory_desc_t *src_md, - mkldnn::impl::engine_t *dst_engine, - const mkldnn::impl::memory_desc_t *dst_md); - - typedef mkldnn::impl::status_t (*concat_primitive_desc_create_f)( - mkldnn::impl::concat_pd_t **concat_pd, - mkldnn::impl::engine_t *engine, - const mkldnn::impl::primitive_attr_t *attr, - const mkldnn::impl::memory_desc_t *dst_md, - int n, int concat_dim, - const mkldnn::impl::memory_desc_t *src_mds); - - typedef mkldnn::impl::status_t (*sum_primitive_desc_create_f)( - mkldnn::impl::sum_pd_t **sum_pd, - mkldnn::impl::engine_t *engine, - const mkldnn::impl::primitive_attr_t *attr, - const mkldnn::impl::memory_desc_t *dst_md, - int n, const float *scales, - const mkldnn::impl::memory_desc_t *src_mds); - - typedef mkldnn::impl::status_t (*primitive_desc_create_f)( - mkldnn::impl::primitive_desc_t **, const mkldnn::impl::op_desc_t *, - const mkldnn::impl::primitive_attr_t *attr, - mkldnn::impl::engine_t *, const mkldnn::impl::primitive_desc_t *); - - /* implementation section */ - - /** return the list of reorder implementations. engine guarantees to return - * a NULL-terminated list */ - virtual const reorder_primitive_desc_create_f* - get_reorder_implementation_list() const = 0; - - /** return the list of concat implementations. engine guarantees to return - * a NULL-terminated list */ - virtual const concat_primitive_desc_create_f* - get_concat_implementation_list() const = 0; - - /** return the list of sum implementations. engine guarantees to return - * a NULL-terminated list */ - virtual const sum_primitive_desc_create_f* - get_sum_implementation_list() const = 0; - - /** return the list of implementations. 
engine guarantees to return a - * NULL-terminated list */ - virtual const primitive_desc_create_f* get_implementation_list() const = 0; - -protected: - mkldnn::impl::engine_kind_t kind_; -}; - -namespace mkldnn { -namespace impl { - -struct engine_factory_t: public c_compatible { - virtual size_t count() const = 0; - virtual engine_kind_t kind() const = 0; - virtual status_t engine_create(engine_t **engine, size_t index) const = 0; -}; - -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/inner_product.cpp b/thirdparty/oidn/mkl-dnn/src/common/inner_product.cpp deleted file mode 100644 index 5a9f58cb1..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/inner_product.cpp +++ /dev/null @@ -1,106 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include -#include "mkldnn.h" - -#include "c_types_map.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -using namespace mkldnn::impl; -using namespace mkldnn::impl::utils; -using namespace mkldnn::impl::status; -using namespace mkldnn::impl::prop_kind; -using namespace mkldnn::impl::types; - -namespace { -status_t ip_desc_init(inner_product_desc_t *ip_desc, prop_kind_t prop_kind, - const memory_desc_t *src_desc, const memory_desc_t *weights_desc, - const memory_desc_t *bias_desc, const memory_desc_t *dst_desc) { - bool args_ok = !any_null(ip_desc, src_desc, weights_desc, dst_desc); - if (!args_ok) return invalid_arguments; - - auto id = inner_product_desc_t(); - id.primitive_kind = primitive_kind::inner_product; - id.prop_kind = prop_kind; - - id.diff_src_desc = id.src_desc = zero_md(); - id.diff_dst_desc = id.dst_desc = zero_md(); - id.diff_weights_desc = id.weights_desc = zero_md(); - id.diff_bias_desc = id.bias_desc = zero_md(); - - const bool is_fwd = one_of(prop_kind, forward_training, forward_inference); - const bool with_bias = - bias_desc && bias_desc->format_kind != format_kind::undef; - - (prop_kind == backward_data ? id.diff_src_desc : id.src_desc) = *src_desc; - (is_fwd ? id.dst_desc : id.diff_dst_desc) = *dst_desc; - (prop_kind == backward_weights ? id.diff_weights_desc : id.weights_desc) = - *weights_desc; - if (with_bias) - (prop_kind == backward_weights ? id.diff_bias_desc : id.bias_desc) = - *bias_desc; - - id.accum_data_type = types::default_accum_data_type(src_desc->data_type, - weights_desc->data_type, dst_desc->data_type, prop_kind); - - bool consistency = true - && memory_desc_wrapper(weights_desc).nelems() - && one_of(src_desc->ndims, 2, 3, 4, 5) - && dst_desc->ndims == 2 - && weights_desc->ndims == src_desc->ndims - && (with_bias ? bias_desc->ndims == 1 : true) - && (with_bias ? 
bias_desc->dims[0] == dst_desc->dims[1] : true) - && src_desc->dims[0] == dst_desc->dims[0] - && array_cmp(&src_desc->dims[1], &weights_desc->dims[1], - src_desc->ndims - 1) - && dst_desc->dims[1] == weights_desc->dims[0]; - if (!consistency) return invalid_arguments; - - *ip_desc = id; - return success; -} -} - -status_t mkldnn_inner_product_forward_desc_init(inner_product_desc_t *ip_desc, - prop_kind_t prop_kind, const memory_desc_t *src_desc, - const memory_desc_t *weights_desc, const memory_desc_t *bias_desc, - const memory_desc_t *dst_desc) { - if (!one_of(prop_kind, forward_training, forward_inference)) - return invalid_arguments; - return ip_desc_init(ip_desc, prop_kind, src_desc, weights_desc, bias_desc, - dst_desc); -} - -status_t mkldnn_inner_product_backward_data_desc_init( - inner_product_desc_t *ip_desc, const memory_desc_t *diff_src_desc, - const memory_desc_t *weights_desc, const memory_desc_t *diff_dst_desc) -{ - return ip_desc_init(ip_desc, backward_data, diff_src_desc, weights_desc, - nullptr, diff_dst_desc); -} - -status_t mkldnn_inner_product_backward_weights_desc_init( - inner_product_desc_t *ip_desc, const memory_desc_t *src_desc, - const memory_desc_t *diff_weights_desc, - const memory_desc_t *diff_bias_desc, - const memory_desc_t *diff_dst_desc) { - return ip_desc_init(ip_desc, backward_weights, src_desc, diff_weights_desc, - diff_bias_desc, diff_dst_desc); -} - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.cpp b/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.cpp deleted file mode 100644 index 091cf0f5d..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.cpp +++ /dev/null @@ -1,56 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include "utils.hpp" - -#include "inner_product_pd.hpp" - -namespace mkldnn { -namespace impl { - -using namespace prop_kind; - -memory_desc_t *ip_prop_invariant_src_d(inner_product_desc_t *desc) { - return desc->prop_kind == backward_data - ? &desc->diff_src_desc : &desc->src_desc; -} - -memory_desc_t *ip_prop_invariant_wei_d(inner_product_desc_t *desc) { - return desc->prop_kind == backward_weights - ? &desc->diff_weights_desc : &desc->weights_desc; -} - -memory_desc_t *ip_prop_invariant_bia_d(inner_product_desc_t *desc) { - return desc->prop_kind == backward_weights - ? &desc->diff_bias_desc : &desc->bias_desc; -} - -memory_desc_t *ip_prop_invariant_dst_d(inner_product_desc_t *desc) { - return utils::one_of(desc->prop_kind, forward_inference, forward_training) - ? 
&desc->dst_desc : &desc->diff_dst_desc; -} - -const memory_desc_t *ip_prop_invariant_src_d(const inner_product_desc_t *desc) -{ return ip_prop_invariant_src_d(const_cast(desc)); } -const memory_desc_t *ip_prop_invariant_wei_d(const inner_product_desc_t *desc) -{ return ip_prop_invariant_wei_d(const_cast(desc)); } -const memory_desc_t *ip_prop_invariant_bia_d(const inner_product_desc_t *desc) -{ return ip_prop_invariant_bia_d(const_cast(desc)); } -const memory_desc_t *ip_prop_invariant_dst_d(const inner_product_desc_t *desc) -{ return ip_prop_invariant_dst_d(const_cast(desc)); } - -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.hpp deleted file mode 100644 index c426de632..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.hpp +++ /dev/null @@ -1,321 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef INNER_PRODUCT_PD_HPP -#define INNER_PRODUCT_PD_HPP - -#include "mkldnn.h" - -#include "c_types_map.hpp" -#include "primitive_desc.hpp" -#include "utils.hpp" - -namespace mkldnn { -namespace impl { - -memory_desc_t *ip_prop_invariant_src_d(inner_product_desc_t *desc); -memory_desc_t *ip_prop_invariant_wei_d(inner_product_desc_t *desc); -memory_desc_t *ip_prop_invariant_bia_d(inner_product_desc_t *desc); -memory_desc_t *ip_prop_invariant_dst_d(inner_product_desc_t *desc); -const memory_desc_t *ip_prop_invariant_src_d(const inner_product_desc_t *desc); -const memory_desc_t *ip_prop_invariant_wei_d(const inner_product_desc_t *desc); -const memory_desc_t *ip_prop_invariant_bia_d(const inner_product_desc_t *desc); -const memory_desc_t *ip_prop_invariant_dst_d(const inner_product_desc_t *desc); - -struct inner_product_fwd_pd_t; - -struct inner_product_pd_t: public primitive_desc_t { - static constexpr auto base_pkind = primitive_kind::inner_product; - - inner_product_pd_t(engine_t *engine, - const inner_product_desc_t *adesc, - const primitive_attr_t *attr, - const inner_product_fwd_pd_t *hint_fwd_pd) - : primitive_desc_t(engine, attr, base_pkind) - , desc_(*adesc) - , hint_fwd_pd_(hint_fwd_pd) - {} - - const inner_product_desc_t *desc() const { return &desc_; } - virtual const op_desc_t *op_desc() const override - { return reinterpret_cast(this->desc()); } - virtual void init_info() override { impl::init_info(this, this->info_); } - - virtual status_t query(query_t what, int idx, void *result) const override { - switch (what) { - case query::inner_product_d: - *(const inner_product_desc_t**)result = desc(); break; - default: return primitive_desc_t::query(what, idx, result); - } - return status::success; - } - - /* common inner_product aux functions */ - - dim_t MB() const { return ip_prop_invariant_src_d(&desc_)->dims[0]; } - dim_t IC() const { return 
ip_prop_invariant_src_d(&desc_)->dims[1]; } - dim_t OC() const { return ip_prop_invariant_dst_d(&desc_)->dims[1]; } - - dim_t ID() const { - return ndims() >= 5 - ? ip_prop_invariant_src_d(&desc_)->dims[ndims() - 3] : 1; - } - dim_t IH() const { - return ndims() >= 4 - ? ip_prop_invariant_src_d(&desc_)->dims[ndims() - 2] : 1; - } - dim_t IW() const { - return ndims() >= 3 - ? ip_prop_invariant_src_d(&desc_)->dims[ndims() - 1] : 1; - } - - dim_t OD() const { - return ndims() >= 5 - ? ip_prop_invariant_dst_d(&desc_)->dims[ndims() - 3] : 1; - } - dim_t OH() const { - return ndims() >= 4 - ? ip_prop_invariant_dst_d(&desc_)->dims[ndims() - 2] : 1; - } - dim_t OW() const { - return ndims() >= 3 - ? ip_prop_invariant_dst_d(&desc_)->dims[ndims() - 1] : 1; - } - - dim_t KD() const { - return ndims() >= 5 - ? ip_prop_invariant_wei_d(&desc_)->dims[ndims() - 3] : 1; - } - dim_t KH() const { - return ndims() >= 4 - ? ip_prop_invariant_wei_d(&desc_)->dims[ndims() - 2] : 1; - } - dim_t KW() const { - return ndims() >= 3 - ? ip_prop_invariant_wei_d(&desc_)->dims[ndims() - 1] : 1; - } - - dim_t IC_total() const { - return utils::array_product(&ip_prop_invariant_src_d(&desc_)->dims[1], - ndims() - 1); - } - - dim_t IC_total_padded() const { - auto src_d = desc()->prop_kind == prop_kind::backward_data - ? memory_desc_wrapper(diff_src_md()) - : memory_desc_wrapper(src_md()); - assert(src_d.is_blocking_desc()); - if (!src_d.is_blocking_desc()) return -1; - return utils::array_product(src_d.padded_dims() + 1, ndims() - 1); - } - - int ndims() const { return ip_prop_invariant_src_d(&desc_)->ndims; } - - bool with_bias() const - { return !memory_desc_wrapper(*ip_prop_invariant_bia_d(&desc_)).is_zero(); } - - bool has_zero_dim_memory() const { - const auto s_d = memory_desc_wrapper(*ip_prop_invariant_src_d(&desc_)); - const auto d_d = memory_desc_wrapper(*ip_prop_invariant_dst_d(&desc_)); - return s_d.has_zero_dim() || d_d.has_zero_dim(); - } - - bool is_fwd() const { - return utils::one_of(desc_.prop_kind, prop_kind::forward_training, - prop_kind::forward_inference); - } - -protected: - inner_product_desc_t desc_; - const inner_product_fwd_pd_t *hint_fwd_pd_; - - status_t template_set_default_params(memory_desc_t &src_md, - memory_desc_t &weights_md, memory_desc_t &dst_md, - memory_desc_t *bias_md) { - using namespace format_tag; - if (src_md.format_kind == format_kind::any) { - CHECK(memory_desc_init_by_tag(src_md, - utils::pick(ndims() - 2, nc, ncw, nchw, ncdhw))); - } - if (dst_md.format_kind == format_kind::any) - CHECK(memory_desc_init_by_tag(dst_md, nc)); - if (weights_md.format_kind == format_kind::any) { - CHECK(memory_desc_init_by_tag(weights_md, - utils::pick(ndims() - 2, oi, oiw, oihw, oidhw))); - } - if (bias_md && bias_md->format_kind == format_kind::any) - CHECK(memory_desc_init_by_tag(*bias_md, x)); - return status::success; - } -}; - -struct inner_product_fwd_pd_t: public inner_product_pd_t { - typedef inner_product_fwd_pd_t base_class; - typedef inner_product_fwd_pd_t hint_class; - - inner_product_fwd_pd_t(engine_t *engine, - const inner_product_desc_t *adesc, - const primitive_attr_t *attr, - const inner_product_fwd_pd_t *hint_fwd_pd) - : inner_product_pd_t(engine, adesc, attr, hint_fwd_pd) - , src_md_(desc_.src_desc) - , weights_md_(desc_.weights_desc) - , bias_md_(desc_.bias_desc) - , dst_md_(desc_.dst_desc) - {} - - virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { - if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_WEIGHTS)) - return arg_usage_t::input; - - if (arg == 
MKLDNN_ARG_BIAS && with_bias()) - return arg_usage_t::input; - - if (arg == MKLDNN_ARG_DST) - return arg_usage_t::output; - - return primitive_desc_t::arg_usage(arg); - } - - virtual const memory_desc_t *src_md(int index = 0) const override - { return index == 0 ? &src_md_ : nullptr; } - virtual const memory_desc_t *dst_md(int index = 0) const override - { return index == 0 ? &dst_md_ : nullptr; } - virtual const memory_desc_t *weights_md(int index = 0) const override { - if (index == 0) return &weights_md_; - if (index == 1 && with_bias()) return &bias_md_; - return nullptr; - } - - virtual int n_inputs() const override { return 2 + with_bias(); } - virtual int n_outputs() const override { return 1; } - -protected: - memory_desc_t src_md_; - memory_desc_t weights_md_; - memory_desc_t bias_md_; - memory_desc_t dst_md_; - - status_t set_default_params() { - return template_set_default_params(src_md_, weights_md_, dst_md_, - &bias_md_); - } -}; - -struct inner_product_bwd_data_pd_t: public inner_product_pd_t { - typedef inner_product_bwd_data_pd_t base_class; - typedef inner_product_fwd_pd_t hint_class; - - inner_product_bwd_data_pd_t(engine_t *engine, - const inner_product_desc_t *adesc, - const primitive_attr_t *attr, - const inner_product_fwd_pd_t *hint_fwd_pd) - : inner_product_pd_t(engine, adesc, attr, hint_fwd_pd) - , diff_src_md_(desc_.diff_src_desc) - , weights_md_(desc_.weights_desc) - , diff_dst_md_(desc_.diff_dst_desc) - {} - - virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { - if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS, MKLDNN_ARG_DIFF_DST)) - return arg_usage_t::input; - - if (arg == MKLDNN_ARG_DIFF_SRC) - return arg_usage_t::output; - - return primitive_desc_t::arg_usage(arg); - } - - virtual const memory_desc_t *diff_src_md(int index = 0) const override - { return index == 0 ? &diff_src_md_ : nullptr; } - virtual const memory_desc_t *diff_dst_md(int index = 0) const override - { return index == 0 ? &diff_dst_md_ : nullptr; } - virtual const memory_desc_t *weights_md(int index = 0) const override - { return index == 0 ? &weights_md_ : nullptr; } - - virtual int n_inputs() const override { return 2; } - virtual int n_outputs() const override { return 1; } - -protected: - memory_desc_t diff_src_md_; - memory_desc_t weights_md_; - memory_desc_t diff_dst_md_; - - status_t set_default_params() { - return template_set_default_params(diff_src_md_, weights_md_, - diff_dst_md_, nullptr); - } -}; - -struct inner_product_bwd_weights_pd_t: public inner_product_pd_t { - typedef inner_product_bwd_weights_pd_t base_class; - typedef inner_product_fwd_pd_t hint_class; - - inner_product_bwd_weights_pd_t(engine_t *engine, - const inner_product_desc_t *adesc, - const primitive_attr_t *attr, - const inner_product_fwd_pd_t *hint_fwd_pd) - : inner_product_pd_t(engine, adesc, attr, hint_fwd_pd) - , src_md_(desc_.src_desc) - , diff_weights_md_(desc_.diff_weights_desc) - , diff_bias_md_(desc_.diff_bias_desc) - , diff_dst_md_(desc_.diff_dst_desc) - {} - - virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { - if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST)) - return arg_usage_t::input; - - if (arg == MKLDNN_ARG_DIFF_WEIGHTS) - return arg_usage_t::output; - - if (arg == MKLDNN_ARG_DIFF_BIAS && with_bias()) - return arg_usage_t::output; - - return primitive_desc_t::arg_usage(arg); - } - - virtual const memory_desc_t *src_md(int index = 0) const override - { return index == 0 ? 
&src_md_ : nullptr; } - virtual const memory_desc_t *diff_dst_md(int index = 0) const override - { return index == 0 ? &diff_dst_md_ : nullptr; } - virtual const memory_desc_t *diff_weights_md(int index = 0) const override { - if (index == 0) return &diff_weights_md_; - if (index == 1 && with_bias()) return &diff_bias_md_; - return nullptr; - } - - virtual int n_inputs() const override { return 2; } - virtual int n_outputs() const override { return 1 + with_bias(); } - -protected: - memory_desc_t src_md_; - memory_desc_t diff_weights_md_; - memory_desc_t diff_bias_md_; - memory_desc_t diff_dst_md_; - - status_t set_default_params() { - return template_set_default_params(src_md_, diff_weights_md_, - diff_dst_md_, &diff_bias_md_); - } -}; - -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/lrn.cpp b/thirdparty/oidn/mkl-dnn/src/common/lrn.cpp deleted file mode 100644 index fcf18b556..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/lrn.cpp +++ /dev/null @@ -1,91 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include -#include "mkldnn.h" - -#include "c_types_map.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -using namespace mkldnn::impl; -using namespace mkldnn::impl::utils; -using namespace mkldnn::impl::status; -using namespace mkldnn::impl::prop_kind; -using namespace mkldnn::impl::alg_kind; -using namespace mkldnn::impl::types; - -namespace { -status_t lrn_desc_init(lrn_desc_t *lrn_desc, - prop_kind_t prop_kind, alg_kind_t alg_kind, - const memory_desc_t *data_desc, const memory_desc_t *diff_data_desc, - dim_t local_size, float alpha, float beta, float k) { - bool args_ok = true - && !any_null(lrn_desc, data_desc) - && one_of(alg_kind, lrn_within_channel, lrn_across_channels) - && one_of(prop_kind, forward_training, forward_inference, backward_data) - && IMPLICATION(prop_kind == backward_data, diff_data_desc != nullptr); - if (!args_ok) return invalid_arguments; - - auto ld = lrn_desc_t(); - ld.primitive_kind = primitive_kind::lrn; - ld.prop_kind = prop_kind; - ld.alg_kind = alg_kind; - - const bool is_fwd = one_of(prop_kind, forward_training, forward_inference); - - ld.data_desc = *data_desc; - if (!is_fwd) - ld.diff_data_desc = *diff_data_desc; - else - ld.diff_data_desc = zero_md(); - ld.local_size = local_size; - ld.lrn_alpha = alpha; - ld.lrn_beta = beta; - ld.lrn_k = k; - - bool consistency = true - && ld.data_desc.ndims == 4; - if (ld.prop_kind == backward_data) - consistency = consistency - && ld.diff_data_desc.ndims == 4 - && array_cmp(ld.diff_data_desc.dims, ld.data_desc.dims, 4); - if (!consistency) return invalid_arguments; - - *lrn_desc = ld; - return success; -} -} - -status_t mkldnn_lrn_forward_desc_init(lrn_desc_t *lrn_desc, - prop_kind_t prop_kind, alg_kind_t alg_kind, - const 
memory_desc_t *data_desc, dim_t local_size, float alpha, - float beta, float k) { - if (!one_of(prop_kind, forward_training, forward_inference)) - return invalid_arguments; - return lrn_desc_init(lrn_desc, prop_kind, alg_kind, data_desc, nullptr, - local_size, alpha, beta, k); -} - -status_t mkldnn_lrn_backward_desc_init(lrn_desc_t *lrn_desc, - alg_kind_t alg_kind, const memory_desc_t *data_desc, - const memory_desc_t *diff_data_desc, dim_t local_size, float alpha, - float beta, float k) { - return lrn_desc_init(lrn_desc, backward_data, alg_kind, data_desc, - diff_data_desc, local_size, alpha, beta, k); -} - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/lrn_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/lrn_pd.hpp deleted file mode 100644 index 90886e965..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/lrn_pd.hpp +++ /dev/null @@ -1,170 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef LRN_PD_HPP -#define LRN_PD_HPP - -#include "mkldnn.h" - -#include "c_types_map.hpp" -#include "primitive_desc.hpp" - -namespace mkldnn { -namespace impl { - -struct lrn_fwd_pd_t; - -struct lrn_pd_t: public primitive_desc_t { - static constexpr auto base_pkind = primitive_kind::lrn; - - lrn_pd_t(engine_t *engine, - const lrn_desc_t *adesc, - const primitive_attr_t *attr, - const lrn_fwd_pd_t *hint_fwd_pd) - : primitive_desc_t(engine, attr, base_pkind) - , desc_(*adesc) - , hint_fwd_pd_(hint_fwd_pd) - , data_md_(desc_.data_desc) - , ws_md_() - {} - - const lrn_desc_t *desc() const { return &desc_; } - virtual const op_desc_t *op_desc() const override - { return reinterpret_cast(this->desc()); } - virtual void init_info() override { impl::init_info(this, this->info_); } - - virtual status_t query(query_t what, int idx, void *result) const override { - switch (what) { - case query::lrn_d: - *(const lrn_desc_t**)result = desc(); break; - default: return primitive_desc_t::query(what, idx, result); - } - return status::success; - } - - /* common lrn aux functions */ - - dim_t MB() const { return data_desc().dims[0]; } - dim_t C() const { return data_desc().dims[1]; } - dim_t D() const { return ndims() >= 5 ? data_desc().dims[ndims() - 3] : 1; } - dim_t H() const { return ndims() >= 4 ? data_desc().dims[ndims() - 2] : 1; } - dim_t W() const { return ndims() >= 3 ? 
data_desc().dims[ndims() - 1] : 1; } - - int ndims() const { return data_desc().ndims; } - - bool has_zero_dim_memory() const - { return memory_desc_wrapper(desc_.data_desc).has_zero_dim(); } - - bool is_fwd() const { - return utils::one_of(desc_.prop_kind, prop_kind::forward_training, - prop_kind::forward_inference); - } - -protected: - lrn_desc_t desc_; - const lrn_fwd_pd_t *hint_fwd_pd_; - - memory_desc_t data_md_; - memory_desc_t ws_md_; - -private: - const memory_desc_t &data_desc() const { return desc_.data_desc; } -}; - -struct lrn_fwd_pd_t: public lrn_pd_t { - typedef lrn_fwd_pd_t base_class; - typedef lrn_fwd_pd_t hint_class; - - lrn_fwd_pd_t(engine_t *engine, - const lrn_desc_t *adesc, - const primitive_attr_t *attr, - const lrn_fwd_pd_t *hint_fwd_pd) - : lrn_pd_t(engine, adesc, attr, hint_fwd_pd) - {} - - virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { - if (arg == MKLDNN_ARG_SRC) - return arg_usage_t::input; - - if (arg == MKLDNN_ARG_DST) - return arg_usage_t::output; - - if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr)) - return arg_usage_t::output; - - return primitive_desc_t::arg_usage(arg); - } - - virtual const memory_desc_t *src_md(int index = 0) const override - { return index == 0 ? &data_md_ : nullptr; } - virtual const memory_desc_t *dst_md(int index = 0) const override - { return index == 0 ? &data_md_ : nullptr; } - virtual const memory_desc_t *workspace_md(int index = 0) const override - { return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; } - - virtual int n_inputs() const override { return 1; } - virtual int n_outputs() const override - { return 1 + (workspace_md() != nullptr); } -}; - -struct lrn_bwd_pd_t: public lrn_pd_t { - typedef lrn_bwd_pd_t base_class; - typedef lrn_fwd_pd_t hint_class; - - lrn_bwd_pd_t(engine_t *engine, - const lrn_desc_t *adesc, - const primitive_attr_t *attr, - const lrn_fwd_pd_t *hint_fwd_pd) - : lrn_pd_t(engine, adesc, attr, hint_fwd_pd) - , diff_data_md_(desc_.diff_data_desc) - {} - - virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { - if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST)) - return arg_usage_t::input; - - if (arg == MKLDNN_ARG_DIFF_SRC) - return arg_usage_t::output; - - if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr)) - return arg_usage_t::input; - - return primitive_desc_t::arg_usage(arg); - } - - virtual const memory_desc_t *src_md(int index = 0) const override - { return index == 0 ? &data_md_ : nullptr; } - virtual const memory_desc_t *diff_dst_md(int index = 0) const override - { return index == 0 ? &diff_data_md_ : nullptr; } - virtual const memory_desc_t *diff_src_md(int index = 0) const override - { return index == 0 ? &diff_data_md_ : nullptr; } - virtual const memory_desc_t *workspace_md(int index = 0) const override - { return index == 0 && !types::is_zero_md(&ws_md_) ? 
&ws_md_ : nullptr; } - - virtual int n_inputs() const override - { return 2 + (workspace_md() != nullptr); } - virtual int n_outputs() const override { return 1; } - -protected: - memory_desc_t diff_data_md_; -}; - -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/math_utils.hpp b/thirdparty/oidn/mkl-dnn/src/common/math_utils.hpp deleted file mode 100644 index 3fddc0bd4..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/math_utils.hpp +++ /dev/null @@ -1,280 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef MATH_UTILS_HPP -#define MATH_UTILS_HPP - -#include -#include - -#include "utils.hpp" -#include "nstl.hpp" -#include "mkldnn_traits.hpp" - -#if defined(MKLDNN_X86_64) -#include "immintrin.h" -#endif - -namespace mkldnn { -namespace impl { -namespace math { - -/** rounds @p f to an integer according to the mxcsr register */ -inline int mxcsr_round(float f) { -#if defined(MKLDNN_X86_64) - return _mm_cvtss_si32(_mm_load_ss(&f)); -#else - return (int)nearbyintf(f); // optimism -#endif -} - -template -inline typename utils::enable_if::value, - typename utils::remove_reference::type>::type -saturate(const acc_t &x) { - return (typename utils::remove_reference::type)x; -} - -template -inline typename utils::enable_if::value, - typename utils::remove_reference::type>::type -saturate(const acc_t &x) { - acc_t v = x; - if (v < (acc_t)nstl::numeric_limits::lowest()) - v = (acc_t)nstl::numeric_limits::lowest(); - if (v > (acc_t)nstl::numeric_limits::max()) - v = (acc_t)nstl::numeric_limits::max(); - return (typename utils::remove_reference::type)v; -} - -template -double saturate(const double &x) { - double v = x; - if (v < (double)nstl::numeric_limits::lowest()) - v = (double)nstl::numeric_limits::lowest(); - if (v > (double)nstl::numeric_limits::max()) - v = (double)nstl::numeric_limits::max(); - return v; -} - -template <> inline int8_t saturate(const uint8_t &x) { - return x <= 127u ? x : 127; -} - -template <> inline uint8_t saturate(const int8_t &x) { - return x >= 0 ? 
x : 0; -} - -template -typename utils::enable_if::value, out_t>::type -out_round(float v) { return (out_t)mxcsr_round(v); } - -template -typename utils::enable_if::value, out_t>::type -out_round(double v) { return (out_t)mxcsr_round((float)v); } - -template -typename utils::enable_if::value, out_t>::type -out_round(float v) { return v; } - -inline int gcd(int a, int b) { - a = impl::nstl::abs(a); - b = impl::nstl::abs(b); - if (a < b) { int x = a; a = b; b = x; } - - if (b == 0) return a; - - int r; - while ((r = a % b) != 0) { a = b; b = r; } - - return b; -} - -template -inline bool is_pow2(const T& v) { return (v & (v - 1)) == 0; } - -/** returns floor(log2(v)), aka the position of the leftmost non-0 bit */ -inline int ilog2q(size_t v) { - if (v == 0) - return -1; - - int p = 0; -# define CP(pw) do { if (v >= (1ull << pw)) { v >>= pw; p += pw; } } while(0) - CP(32); CP(16); CP(8); CP(4); CP(2); CP(1); -# undef CP - return p; -} - -template ::type> -inline U one_m_square(T x) { - return (U)(1 - x) * (1 + x); -} - -template ::type> -inline U x_m_square(T x) { - return (U)(1 - x) * x; -} - -/* activation */ -template ::type> -inline U relu_fwd(T s, A alpha) { - return s > 0 ? s : (U)(s * alpha); -} -template ::type> -inline U relu_bwd(T dd, T s, A alpha) { - return s > 0 ? dd : (U)(dd * alpha); -} - -template ::type> -inline U tanh_fwd(T s) { - const float e = tanhf((float) s); - return (U)e; -} - -template ::type> -inline U tanh_bwd(T dd, T s) { - const float e = tanh_fwd((float) s); - return (U)(dd * (1 - e) * (1 + e)); -} - -template ::type> -inline U elu_fwd(T s, A alpha) { - return s > 0 ? s : (U)(alpha * (::expm1f((float)s))); -} -template ::type> - inline U elu_bwd(T dd, T s, A alpha) { - return (U)(dd * (s > 0 ? 1 : alpha * ::expf((float)s))); -} - -template ::type> -inline U square_fwd(T s) { - return s * s; -} - -template ::type> -inline U square_bwd(T dd, T s) { - return dd * 2 * s; -} - -template ::type> -inline U abs_fwd(T s) { - return s > 0 ? s : -s; -} - -template ::type> -inline U abs_bwd(T dd, T s) { - return s > 0 ? dd : s < 0 ? -dd : 0; -} - -template ::type> -inline U sqrt_fwd(T s) { - return s > 0 ? (U)(::sqrtf((float)(s))) : 0; -} - -template ::type> -inline U sqrt_bwd(T dd, T s) { - return s > 0 - ? (U)(dd / (2 * ::sqrtf((float)(s)))) - : 0; -} - -template ::type> -inline U linear_fwd(T s, A alpha, A beta) { - return (U)(alpha * s + beta); -} - -template ::type> -inline U linear_bwd(T dd, T s, A alpha, A beta) { - (void) s; - (void) beta; - return (U)(dd * alpha); -} - -template ::type> -inline U bounded_relu_fwd(T s, A alpha) { - s = s > 0 ? s : 0; - return s > alpha ? (U)(alpha) : s; -} - -template ::type> -inline U bounded_relu_bwd(T dd, T s, A alpha) { - return dd * (0 < s && s < alpha ? 1 : 0); -} - -template ::type> -inline U soft_relu_fwd(T s) { - float max_logf = 8.872284e+01; //::logf(FLT_MAX) - return s < max_logf ? 
(U)(::log1pf(::expf((float)s))) : s; -} - -template ::type> -inline U soft_relu_bwd(T dd, T s) { - return (U)(dd / (1 + ::expf((float)(-s)))); -} - -template ::type> -inline U logistic_fwd(T s) { - U v = (U)(::expf((float) -s)); - return 1 / (1 + v); -} - -template ::type> -inline U logistic_bwd(T dd, T s) { - U v = logistic_fwd(s); - return dd * v * (1 - v); -} - -inline bool eltwise_fwd_preserves_zero(alg_kind_t alg, bool jit_impl = false) { - using namespace alg_kind; - using namespace utils; - const bool preserves_zero = true - && !one_of(alg, eltwise_linear, eltwise_soft_relu, eltwise_logistic) - && IMPLICATION(jit_impl, !one_of(alg, eltwise_elu, eltwise_tanh)); - return preserves_zero; -} - -inline float get_bias(const char *bias, size_t offset, data_type_t data_type) -{ - if (!bias) - return 0.0f; - -#define CASE(dt) \ - case dt: return (float)((const prec_traits
::type *)bias)[offset] - - switch (data_type) { - CASE(data_type::s8); - CASE(data_type::u8); - CASE(data_type::s32); - CASE(data_type::f32); - default: assert(!"unimplemented"); - } - return 0; // never happens (should probably be a NaN) -#undef CASE -} - -} -} -} - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/memory.cpp b/thirdparty/oidn/mkl-dnn/src/common/memory.cpp deleted file mode 100644 index cea849c96..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/memory.cpp +++ /dev/null @@ -1,238 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include -#include -#include - -#include "mkldnn.h" - -#include "c_types_map.hpp" -#include "engine.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -using namespace mkldnn::impl; -using namespace mkldnn::impl::utils; -using namespace mkldnn::impl::status; -using namespace mkldnn::impl::data_type; - -namespace { -bool memory_desc_sanity_check(int ndims,const dims_t dims, - data_type_t data_type, format_kind_t format_kind) { - if (ndims == 0) return true; - - bool ok = true - && dims != nullptr - && 0 < ndims && ndims <= MKLDNN_MAX_NDIMS - && one_of(data_type, f32, s32, s8, u8) - && format_kind != format_kind::undef; - if (!ok) return false; - for (int d = 0; d < ndims; ++d) - if (dims[d] < 0) return false; - - return true; -} - -bool memory_desc_sanity_check(const memory_desc_t *md) { - if (md == nullptr) return false; - return memory_desc_sanity_check(md->ndims, md->dims, md->data_type, - format_kind::any); -} -} - -status_t mkldnn_memory_desc_init_by_tag(memory_desc_t *memory_desc, int ndims, - const dims_t dims, data_type_t data_type, format_tag_t tag) { - if (any_null(memory_desc)) return invalid_arguments; - if (ndims == 0 || tag == format_tag::undef) { - *memory_desc = types::zero_md(); - return success; - } - - format_kind_t format_kind = types::format_tag_to_kind(tag); - - /* memory_desc != 0 */ - bool args_ok = !any_null(memory_desc) - && memory_desc_sanity_check(ndims, dims, data_type, format_kind); - if (!args_ok) return invalid_arguments; - - auto md = memory_desc_t(); - md.ndims = ndims; - array_copy(md.dims, dims, ndims); - md.data_type = data_type; - array_copy(md.padded_dims, dims, ndims); - md.format_kind = format_kind; - - status_t status = success; - if (tag == format_tag::undef) { - status = invalid_arguments; - } else if (tag == format_tag::any) { - // nop - } else if (format_kind == format_kind::blocked) { - status = memory_desc_wrapper::compute_blocking(md, tag); - } else { - assert(!"unreachable"); - status = invalid_arguments; - } - - if (status == success) - *memory_desc = md; - - return status; -} - -status_t mkldnn_memory_desc_init_by_strides(memory_desc_t *memory_desc, - int ndims, const dims_t dims, data_type_t data_type, - const dims_t strides) { - if (any_null(memory_desc)) return 
invalid_arguments; - if (ndims == 0) { - *memory_desc = types::zero_md(); - return success; - } - - /* memory_desc != 0 */ - bool args_ok = !any_null(memory_desc) - && memory_desc_sanity_check(ndims, dims, data_type, format_kind::any); - if (!args_ok) return invalid_arguments; - - auto md = memory_desc_t(); - md.ndims = ndims; - array_copy(md.dims, dims, ndims); - md.data_type = data_type; - array_copy(md.padded_dims, dims, ndims); - md.format_kind = format_kind::blocked; - - dims_t default_strides = {0}; - if (strides == nullptr) { - default_strides[md.ndims - 1] = 1; - for (int d = md.ndims - 2; d >= 0; --d) - default_strides[d] = default_strides[d + 1] * md.padded_dims[d + 1]; - strides = default_strides; - } else { - /* TODO: add sanity check for the provided strides */ - } - - array_copy(md.format_desc.blocking.strides, strides, md.ndims); - - *memory_desc = md; - - return status::success; -} - -status_t mkldnn_memory_desc_init_submemory(memory_desc_t *md, - const memory_desc_t *parent_md, const dims_t dims, - const dims_t offsets) { - if (any_null(md, parent_md) || !memory_desc_sanity_check(parent_md)) - return invalid_arguments; - - const memory_desc_wrapper src_d(parent_md); - - for (int d = 0; d < src_d.ndims(); ++d) { - if (dims[d] < 0 || offsets[d] < 0 - || (offsets[d] + dims[d] > src_d.dims()[d])) - return invalid_arguments; - } - - if (src_d.format_kind() != format_kind::blocked) - return unimplemented; - - dims_t blocks; - src_d.compute_blocks(blocks); - - memory_desc_t dst_d = *parent_md; - auto &dst_d_blk = dst_d.format_desc.blocking; - - /* TODO: put this into memory_desc_wrapper */ - for (int d = 0; d < src_d.ndims(); ++d) { - /* very limited functionality for now */ - const bool ok = true - && offsets[d] % blocks[d] == 0 /* [r1] */ - && src_d.padded_offsets()[d] == 0 - && (false - || dims[d] % blocks[d] == 0 - || dims[d] < blocks[d]); - if (!ok) - return unimplemented; - - const bool is_right_border = offsets[d] + dims[d] == src_d.dims()[d]; - - dst_d.dims[d] = dims[d]; - dst_d.padded_dims[d] = is_right_border - ? src_d.padded_dims()[d] - offsets[d] : dst_d.dims[d]; - dst_d.padded_offsets[d] = src_d.padded_offsets()[d]; - dst_d.offset0 += /* [r1] */ - offsets[d] / blocks[d] * dst_d_blk.strides[d]; - } - - *md = dst_d; - - return success; -} - -int mkldnn_memory_desc_equal(const memory_desc_t *lhs, - const memory_desc_t *rhs) { - if (lhs == rhs) return 1; - if (any_null(lhs, rhs)) return 0; - return memory_desc_wrapper(*lhs) == memory_desc_wrapper(*rhs); -} - -size_t mkldnn_memory_desc_get_size(const memory_desc_t *md) { - if (md == nullptr) return 0; - return memory_desc_wrapper(*md).size(); -} - -status_t mkldnn_memory_create(memory_t **memory, const memory_desc_t *md, - engine_t *engine, void *handle) { - if (any_null(memory, engine)) return invalid_arguments; - memory_desc_t z_md = types::zero_md(); - return engine->memory_create(memory, md ? 
md : &z_md, handle); -} - -status_t mkldnn_memory_get_memory_desc(const memory_t *memory, - const memory_desc_t **md) { - if (any_null(memory, md)) return invalid_arguments; - *md = memory->md(); - return success; -} - -status_t mkldnn_memory_get_engine(const memory_t *memory, engine_t **engine) { - if (any_null(memory, engine)) return invalid_arguments; - *engine = memory->engine(); - return success; -} - -status_t mkldnn_memory_get_data_handle(const memory_t *memory, - void **handle) { - if (any_null(handle)) - return invalid_arguments; - if (memory == nullptr) { - *handle = nullptr; - return success; - } - return memory->get_data_handle(handle); -} - -status_t mkldnn_memory_set_data_handle(memory_t *memory, void *handle) { - if (any_null(memory)) return invalid_arguments; - return memory->set_data_handle(handle); -} - -status_t mkldnn_memory_destroy(memory_t *memory) { - delete memory; - return success; -} - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/memory.hpp b/thirdparty/oidn/mkl-dnn/src/common/memory.hpp deleted file mode 100644 index 03dfee01f..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/memory.hpp +++ /dev/null @@ -1,63 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
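For reference, the deleted `mkldnn_memory_desc_init_by_strides` above falls back to dense row-major strides when the caller passes `strides == nullptr`: the innermost dimension gets stride 1 and every outer stride is the product of the padded dimensions to its right. A minimal standalone sketch of just that fallback (plain `std::vector` stand-ins for the library's `dims_t`, not the mkl-dnn API itself):

```c++
#include <cstdint>
#include <iostream>
#include <vector>

// Sketch of the "strides == nullptr" fallback in the deleted
// mkldnn_memory_desc_init_by_strides(): dense row-major strides, i.e.
// stride[last] = 1 and stride[d] = stride[d + 1] * padded_dims[d + 1].
std::vector<int64_t> default_dense_strides(const std::vector<int64_t> &padded_dims) {
    const int ndims = static_cast<int>(padded_dims.size());
    std::vector<int64_t> strides(ndims, 1); // innermost stride stays 1
    for (int d = ndims - 2; d >= 0; --d)
        strides[d] = strides[d + 1] * padded_dims[d + 1];
    return strides;
}

int main() {
    // A 2x16x5x5 (NCHW-shaped) tensor gets strides {400, 25, 5, 1}.
    for (int64_t s : default_dense_strides({2, 16, 5, 5}))
        std::cout << s << ' ';
    std::cout << '\n';
    return 0;
}
```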
-*******************************************************************************/ - -#ifndef MEMORY_HPP -#define MEMORY_HPP - -#include - -#include "mkldnn.h" - -#include "c_types_map.hpp" -#include "nstl.hpp" - -struct mkldnn_memory: public mkldnn::impl::c_compatible { - mkldnn_memory(mkldnn::impl::engine_t *engine, - const mkldnn::impl::memory_desc_t *md) - : engine_(engine), md_(*md) {} - virtual ~mkldnn_memory() {} - - /** allocates/initializes memory */ - virtual mkldnn::impl::status_t init() = 0; - - /** returns memory's engine */ - mkldnn::impl::engine_t *engine() const { return engine_; } - /** returns memory's description */ - const mkldnn::impl::memory_desc_t *md() const { return &md_; } - - /** returns data handle */ - virtual mkldnn::impl::status_t get_data_handle(void **handle) const = 0; - - /** sets data handle */ - virtual mkldnn::impl::status_t set_data_handle(void *handle) = 0; - - /** zeros padding */ - virtual mkldnn::impl::status_t zero_pad() const - { return mkldnn::impl::status::success; } - -protected: - mkldnn::impl::engine_t *engine_; - const mkldnn::impl::memory_desc_t md_; - -private: - mkldnn_memory() = delete; - mkldnn_memory(const mkldnn_memory &) = delete; - mkldnn_memory(mkldnn_memory &&) = delete; - mkldnn_memory &operator=(const mkldnn_memory &) = delete; - mkldnn_memory &operator=(mkldnn_memory &&) = delete; -}; - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.cpp b/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.cpp deleted file mode 100644 index 8a99be33f..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.cpp +++ /dev/null @@ -1,212 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#include - -#include - -#include "c_types_map.hpp" -#include "memory_desc_wrapper.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -namespace mkldnn { -namespace impl { - -status_t fill_blocked(memory_desc_t &md, - std::initializer_list perm, - std::initializer_list inner_blks, - std::initializer_list inner_idxs) { - const bool ok = true - && perm.size() == (size_t)md.ndims - && inner_blks.size() == inner_idxs.size(); - if (!ok) return status::invalid_arguments; - - md.offset0 = 0; - - blocking_desc_t &blk = md.format_desc.blocking; - - dim_t block_size = 1; - dims_t blocks = {0}; - utils::array_set(blocks, 1, md.ndims); - - blk.inner_nblks = (int)inner_blks.size(); - - int iblk = 0; - for (const auto &b: inner_idxs) - blk.inner_idxs[iblk++] = b; - - iblk = 0; - for (const auto &b: inner_blks) { - int dim = blk.inner_idxs[iblk]; - block_size *= b; - blocks[dim] *= b; - blk.inner_blks[iblk++] = b; - } - - utils::array_set(md.padded_offsets, 0, md.ndims); - for (int d = 0; d < md.ndims; ++d) - md.padded_dims[d] = utils::rnd_up(md.dims[d], blocks[d]); - - dim_t stride = block_size; - // if only we use C++14, the initializer_list would have rbegin()/rend()... - for (int d = 0; d < md.ndims; ++d) - stride *= md.padded_dims[d] == 0 ? 1 : md.padded_dims[d] / blocks[d]; - - for (const auto &d: perm) { - if (md.padded_dims[d] == 0) { - blk.strides[d] = 1; - continue; - } - stride /= md.padded_dims[d] / blocks[d]; - blk.strides[d] = stride; - } - - assert(stride == block_size); - - return status::success; -} - -status_t memory_desc_wrapper::compute_blocking(memory_desc_t &memory_desc, - format_tag_t tag) -{ - using namespace format_tag; - - if (memory_desc.ndims == 0) return status::invalid_arguments; - -# define C(tag, ... 
/* perm, inner_blks, inner_idxs */) \ - case tag: return fill_blocked(memory_desc, __VA_ARGS__) - - switch (tag) { - C(a, {0}, {}, {}); - C(ab, {0, 1}, {}, {}); - C(abc, {0, 1, 2}, {}, {}); - C(abcd, {0, 1, 2, 3}, {}, {}); - C(abcde, {0, 1, 2, 3, 4}, {}, {}); - C(abcdef, {0, 1, 2, 3, 4, 5}, {}, {}); - C(abdec, {0, 1, 3, 4, 2}, {}, {}); - C(acb, {0, 2, 1}, {}, {}); - C(acbde, {0, 2, 1, 3, 4}, {}, {}); - C(acdb, {0, 2, 3, 1}, {}, {}); - C(acdeb, {0, 2, 3, 4, 1}, {}, {}); - C(ba, {1, 0}, {}, {}); - C(bac, {1, 0, 2}, {}, {}); - C(bacd, {1, 0, 2, 3}, {}, {}); - C(bcda, {1, 2, 3, 0}, {}, {}); - C(cba, {2, 1, 0}, {}, {}); - C(cdba, {2, 3, 1, 0}, {}, {}); - C(cdeba, {2, 3, 4, 1, 0}, {}, {}); - C(decab, {3, 4, 2, 0, 1}, {}, {}); - - C(Abc4a, {0, 1, 2}, {4}, {0}); - C(aBc4b, {0, 1, 2}, {4}, {1}); - C(ABc4b16a4b, {0, 1, 2}, {4, 16, 4}, {1, 0, 1}); - C(ABc4b4a, {0, 1, 2}, {4, 4}, {1, 0}); - C(Abcd4a, {0, 1, 2, 3}, {4}, {0}); - C(aBcd4b, {0, 1, 2, 3}, {4}, {1}); - C(ABcd4b4a, {0, 1, 2, 3}, {4, 4}, {1, 0}); - C(aBCd4c16b4c, {0, 1, 2, 3}, {4, 16, 4}, {2, 1, 2}); - C(aBCd4c4b, {0, 1, 2, 3, 4}, {4, 4}, {2, 1}); - C(Abcde4a, {0, 1, 2, 3, 4}, {4}, {0}); - C(aBcde4b, {0, 1, 2, 3, 4}, {4}, {1}); - C(ABcde4b4a, {0, 1, 2, 3, 4}, {4, 4}, {1, 0}); - C(aBCde4c4b, {0, 1, 2, 3, 4}, {4, 4}, {2, 1}); - C(aBcdef4b, {0, 1, 2, 3, 4, 5}, {4}, {1}); - C(aBCdef4c4b, {0, 1, 2, 3, 4, 5}, {4, 4}, {2, 1}); - C(aBdc4b, {0, 1, 3, 2}, {4}, {1}); - C(aBdec4b, {0, 1, 3, 4, 2}, {4}, {1}); - C(aBdefc4b, {0, 1, 3, 4, 5, 2}, {4}, {1}); - C(Acb4a, {0, 2, 1}, {4}, {0}); - C(Acdb4a, {0, 2, 3, 1}, {4}, {0}); - C(Acdeb4a, {0, 2, 3, 4, 1}, {4}, {0}); - - C(Abc16a, {0, 1, 2}, {16}, {0}); - C(ABc16a16b, {0, 1, 2}, {16, 16}, {0, 1}); - C(aBc16b, {0, 1, 2}, {16}, {1}); - C(ABc16b16a, {0, 1, 2}, {16, 16}, {1, 0}); - C(ABc8a16b2a, {0, 1, 2}, {8, 16, 2}, {0, 1, 0}); - C(ABc8a8b, {0, 1, 2}, {8, 8}, {0, 1}); - C(aBc8b, {0, 1, 2}, {8}, {1}); - C(ABc8b16a2b, {0, 1, 2}, {8, 16, 2}, {1, 0, 1}); - C(ABc8b8a, {0, 1, 2}, {8, 8}, {1, 0}); - C(Abcd16a, {0, 1, 2, 3}, {16}, {0}); - C(ABcd16a16b, {0, 1, 2, 3}, {16, 16}, {0, 1}); - C(aBcd16b, {0, 1, 2, 3}, {16}, {1}); - C(ABcd16b16a, {0, 1, 2, 3}, {16, 16}, {1, 0}); - C(aBCd16b16c, {0, 1, 2, 3}, {16, 16}, {1, 2}); - C(aBCd16c16b, {0, 1, 2, 3}, {16, 16}, {2, 1}); - C(ABcd4b16a4b, {0, 1, 2, 3}, {4, 16, 4}, {1, 0, 1}); - C(ABcd8a16b2a, {0, 1, 2, 3}, {8, 16, 2}, {0, 1, 0}); - C(ABcd8a8b, {0, 1, 2, 3}, {8, 8}, {0, 1}); - C(aBcd8b, {0, 1, 2, 3}, {8}, {1}); - C(ABcd8b16a2b, {0, 1, 2, 3}, {8, 16, 2}, {1, 0, 1}); - C(aBCd8b16c2b, {0, 1, 2, 3}, {8, 16, 2}, {1, 2, 1}); - C(ABcd8b8a, {0, 1, 2, 3}, {8, 8}, {1, 0}); - C(aBCd8b8c, {0, 1, 2, 3}, {8, 8}, {1, 2}); - C(aBCd8c16b2c, {0, 1, 2, 3}, {8, 16, 2}, {2, 1, 2}); - C(aBCd8c8b, {0, 1, 2, 3}, {8, 8}, {2, 1}); - C(Abcde16a, {0, 1, 2, 3, 4}, {16}, {0}); - C(ABcde16a16b, {0, 1, 2, 3, 4}, {16, 16}, {0, 1}); - C(aBcde16b, {0, 1, 2, 3, 4}, {16}, {1}); - C(ABcde16b16a, {0, 1, 2, 3, 4}, {16, 16}, {1, 0}); - C(aBCde16b16c, {0, 1, 2, 3, 4}, {16, 16}, {1, 2}); - C(aBCde16c16b, {0, 1, 2, 3, 4}, {16, 16}, {2, 1}); - C(aBCde2c8b4c, {0, 1, 2, 3, 4}, {2, 8, 4}, {2, 1, 2}); - C(aBCde4b4c, {0, 1, 2, 3, 4}, {4, 4}, {1, 2}); - C(aBCde4c16b4c, {0, 1, 2, 3, 4}, {4, 16, 4}, {2, 1, 2}); - C(Abcde8a, {0, 1, 2, 3, 4}, {8}, {0}); - C(ABcde8a8b, {0, 1, 2, 3, 4}, {8, 8}, {0, 1}); - C(aBcde8b, {0, 1, 2, 3, 4}, {8}, {1}); - C(ABcde8b16a2b, {0, 1, 2, 3, 4}, {8, 16, 2}, {1, 0, 1}); - C(aBCde8b16c2b, {0, 1, 2, 3, 4}, {8, 16, 2}, {1, 2, 1}); - C(ABcde8b8a, {0, 1, 2, 3, 4}, {8, 8}, {1, 0}); - C(aBCde8b8c, {0, 1, 2, 
3, 4}, {8, 8}, {1, 2}); - C(aBCde8c16b2c, {0, 1, 2, 3, 4}, {8, 16, 2}, {2, 1, 2}); - C(aBCde8c8b, {0, 1, 2, 3, 4}, {8, 8}, {2, 1}); - C(aBcdef16b, {0, 1, 2, 3, 4, 5}, {16}, {1}); - C(aBCdef16b16c, {0, 1, 2, 3, 4, 5}, {16, 16}, {1, 2}); - C(aBCdef16c16b, {0, 1, 2, 3, 4, 5}, {16, 16}, {2, 1}); - C(aBCdef8b8c, {0, 1, 2, 3, 4, 5}, {8, 8}, {1, 2}); - C(aBCdef8c16b2c, {0, 1, 2, 3, 4, 5}, {8, 16, 2}, {2, 1, 2}); - C(aBCdef8c8b, {0, 1, 2, 3, 4, 5}, {8, 8}, {2, 1}); - C(aBdc16b, {0, 1, 3, 2}, {16}, {1}); - C(aBdc8b, {0, 1, 3, 2}, {8}, {1}); - C(aBdec16b, {0, 1, 3, 4, 2}, {16}, {1}); - C(aBdec8b, {0, 1, 3, 4, 2}, {8}, {1}); - C(aBdefc16b, {0, 1, 3, 4, 5, 2}, {16}, {1}); - C(aBdefc8b, {0, 1, 3, 4, 5, 2}, {8}, {1}); - C(Acb16a, {0, 2, 1}, {16}, {0}); - C(Acb8a, {0, 2, 1}, {8}, {0}); - C(aCBd16b16c, {0, 2, 1, 3}, {16, 16}, {1, 2}); - C(aCBde16b16c, {0, 2, 1, 3, 4}, {16, 16}, {1, 2}); - C(Acdb16a, {0, 2, 3, 1}, {16}, {0}); - C(Acdb8a, {0, 2, 3, 1}, {8}, {0}); - C(Acdeb16a, {0, 2, 3, 4, 1}, {16}, {0}); - C(Acdeb8a, {0, 2, 3, 4, 1}, {8}, {0}); - C(BAc16a16b, {1, 0, 2}, {16, 16}, {0, 1}); - C(BAcd16a16b, {1, 0, 2, 3}, {16, 16}, {0, 1}); - default: break; - } - -#undef C - - return status::invalid_arguments; -} - -} -} - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.hpp b/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.hpp deleted file mode 100644 index 1758f9078..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.hpp +++ /dev/null @@ -1,400 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
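The `C(tag, perm, inner_blks, inner_idxs)` table in the deleted `compute_blocking` maps each format tag to a dimension permutation plus inner-block sizes and the logical dimensions they block, and `fill_blocked` then derives the outer strides by walking the permutation from outermost to innermost. The self-contained sketch below reproduces that stride computation for illustration only: plain `std::vector` types, no zero-sized-dimension handling, and the `aBcd16b` (commonly called `nChw16c`) shape values are assumptions chosen for the example, not taken from the diff.

```c++
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Standalone sketch of the stride computation in the deleted fill_blocked():
// given a logical shape, a permutation (outermost..innermost), and inner
// blocks (size + which logical dim they block), derive the outer strides.
struct BlockedLayout {
    std::vector<int64_t> padded_dims;
    std::vector<int64_t> strides;      // stride of each outer (block-level) dim
    std::vector<int64_t> inner_blks;   // e.g. {16} for aBcd16b
    std::vector<int>     inner_idxs;   // e.g. {1}  (the channel dim)
};

BlockedLayout fill_blocked(const std::vector<int64_t> &dims,
        const std::vector<int> &perm,
        const std::vector<int64_t> &inner_blks,
        const std::vector<int> &inner_idxs) {
    assert(perm.size() == dims.size() && inner_blks.size() == inner_idxs.size());
    const int ndims = static_cast<int>(dims.size());
    BlockedLayout lt{std::vector<int64_t>(ndims), std::vector<int64_t>(ndims),
            inner_blks, inner_idxs};

    // Combined inner block per logical dim, and the size of one full block.
    std::vector<int64_t> blocks(ndims, 1);
    int64_t block_size = 1;
    for (size_t i = 0; i < inner_blks.size(); ++i) {
        blocks[inner_idxs[i]] *= inner_blks[i];
        block_size *= inner_blks[i];
    }

    // Pad every dim up to a multiple of its combined block.
    for (int d = 0; d < ndims; ++d)
        lt.padded_dims[d] = (dims[d] + blocks[d] - 1) / blocks[d] * blocks[d];

    // Walk the permutation outermost -> innermost, peeling one dim at a time.
    int64_t stride = block_size;
    for (int d = 0; d < ndims; ++d)
        stride *= lt.padded_dims[d] / blocks[d];
    for (int d : perm) {
        stride /= lt.padded_dims[d] / blocks[d];
        lt.strides[d] = stride;
    }
    return lt;
}

int main() {
    // aBcd16b: NCHW with channels blocked by 16. A 1x17x4x4 shape pads the
    // channel dim to 32 and yields outer strides {512, 256, 64, 16}.
    BlockedLayout lt = fill_blocked({1, 17, 4, 4}, {0, 1, 2, 3}, {16}, {1});
    for (int64_t s : lt.strides) std::cout << s << ' ';
    std::cout << '\n';
    return 0;
}
```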
-*******************************************************************************/ - -#ifndef MEMORY_DESC_WRAPPER_HPP -#define MEMORY_DESC_WRAPPER_HPP - -#include - -#include "c_types_map.hpp" -#include "nstl.hpp" -#include "utils.hpp" - -#include "type_helpers.hpp" - -namespace mkldnn { -namespace impl { - -/** thin wrapper class over \struct memory_desc_t which allows easy - * manipulations with underlying C structure, which is taken by reference */ -struct memory_desc_wrapper: public c_compatible { - const memory_desc_t *md_; - - /** constructor which takes a reference to a constant underlying C memory - * descriptor \param md */ - memory_desc_wrapper(const memory_desc_t *md): md_(md) {} - memory_desc_wrapper(const memory_desc_t &md): memory_desc_wrapper(&md) {} - - /* implementing attributes */ - int ndims() const { return md_->ndims; } - const dims_t &dims() const { return md_->dims; } - data_type_t data_type() const { return md_->data_type; } - - const dims_t &padded_dims() const { return md_->padded_dims; } - const dims_t &padded_offsets() const { return md_->padded_offsets; } - dim_t offset0() const { return md_->offset0; } - - format_kind_t format_kind() const { return md_->format_kind; } - - bool is_blocking_desc() const - { return format_kind() == format_kind::blocked; } - bool is_wino_desc() const - { return format_kind() == format_kind::wino; } - bool is_rnn_packed_desc() const - { return format_kind() == format_kind::rnn_packed; } - - const blocking_desc_t &blocking_desc() const { - assert(is_blocking_desc()); - return md_->format_desc.blocking; - } - const wino_desc_t &wino_desc() const { - assert(is_wino_desc()); - return md_->format_desc.wino_desc; - } - const rnn_packed_desc_t &rnn_packed_desc() const { - assert(is_rnn_packed_desc()); - return md_->format_desc.rnn_packed_desc; - } - - const memory_extra_desc_t &extra() const { return md_->extra; } - - /* some useful function */ - - /** returns the number of elements including padding if \param with_padding - * is true, and the number of data elements otherwise */ - dim_t nelems(bool with_padding = false) const { - if (is_zero()) return 0; - return utils::array_product( - with_padding ? 
padded_dims() : dims(), ndims()); - } - - /** returns true if memory descriptor is zero */ - bool is_zero() const { return ndims() == 0; } - - /** returns true if memory descriptor contains zero as one of its dim */ - bool has_zero_dim() const { return nelems() == 0; } - - /** return the size of data type (a shortcut) */ - size_t data_type_size() const - { return types::data_type_size(data_type()); } - - /** return the size of data type of additional buffer */ - size_t additional_buffer_data_size() const { - if (extra().flags & memory_extra_flags::compensation_conv_s8s8) - return sizeof(int32_t); - return 0; - } - - /** return true if memory format has additional buffer */ - bool is_additional_buffer() const { - return (extra().flags & memory_extra_flags::compensation_conv_s8s8); - } - - /** returns the size of additional buffer */ - size_t additional_buffer_size() const { - if (extra().flags & memory_extra_flags::compensation_conv_s8s8) { - int cmask = extra().compensation_mask; - assert(cmask == 1 || cmask == 3); - dim_t prod = 1; - for (int d = 0; d < ndims(); ++d) - if (cmask & (1<(max_size, - padded_dims()[d] / blocks[d] * bd.strides[d]); - - if (max_size == 1 && bd.inner_nblks != 0) { - max_size = utils::array_product(bd.inner_blks, bd.inner_nblks); - } - - return max_size * data_type_size() + additional_buffer_size(); - } - } - - /** returns true if data is dense in memory */ - bool is_dense(bool with_padding = false) const { - if (utils::one_of(format_kind(), format_kind::undef, format_kind::any)) - return false; - return nelems(with_padding) * data_type_size() == size(); - } - - /** returns true if memory desc is fully defined */ - bool is_defined() const { return format_kind() != format_kind::any; } - - /** returns true if the only (potentially) padded dim is \param dim */ - bool only_padded_dim(int dim) const { - for (int d = 0; d < ndims(); ++d) - if (d != dim && dims()[d] != padded_dims()[d]) - return false; - return true; - } - - /** returns true if memory desc has blocked layout and block dims are 1s */ - bool is_plain() const { - if (!is_blocking_desc()) return false; - return blocking_desc().inner_nblks == 0; - } - - /** returns overall block sizes */ - void compute_blocks(dims_t blocks) const { - if (!is_blocking_desc()) { - utils::array_set(blocks, 0, ndims()); - return; - } - - utils::array_set(blocks, 1, ndims()); - - const auto &bd = blocking_desc(); - for (int iblk = 0; iblk < bd.inner_nblks; ++iblk) - blocks[bd.inner_idxs[iblk]] *= bd.inner_blks[iblk]; - } - - /* comparison section */ - - bool operator==(const memory_desc_wrapper &rhs) const - { return *this->md_ == *rhs.md_; } - bool operator!=(const memory_desc_wrapper &rhs) const - { return !operator==(rhs); } - bool operator==(const memory_desc_t &rhs) const - { return operator==(memory_desc_wrapper(rhs)); } - bool operator!=(const memory_desc_t &rhs) const - { return !operator==(rhs); } - - /** returns true if data (w/o padding if with_padding == false and w/ - * padding otherwise) have the same physical structure, i.e. dimensions, - * strides, and blocked structure. Depending on with_data_type flag - * data_type is taken or not taken into account. dim_start allows to check - * similarity for the logical part of data [dim_start .. ndims()]. 
- * CAUTION: format kind any and undef are not similar to whatever, hence the - * following statement might be true: lhs == rhs && !lhs.similar_to(rhs) */ - /* TODO: revise */ - bool similar_to(const memory_desc_wrapper &rhs, - bool with_padding = true, bool with_data_type = true, - int dim_start = 0) const; - - /** returns true if one memory can be reordered to another */ - bool consistent_with(const memory_desc_wrapper &rhs) const; - - /** returns true if the memory desc corresponds to the given format tag and - * strides. - * @sa memory_desc_matches_tag */ - bool matches_tag(format_tag_t tag, const dims_t strides = nullptr) const { - return memory_desc_matches_tag(*md_, tag, strides); - } - - /** returns matching tag (or undef if match is not found) - * XXX: This is a workaround that eventually should go away! */ - template - format_tag_t matches_one_of_tag(Tags ...tags) const { - for (const auto tag: {tags...}) { - if (memory_desc_matches_tag(*md_, tag)) - return tag; - } - return format_tag::undef; - } - - /* offset section */ - - /** returns physical offset by logical one. logical offset is represented by - * an array \param pos. if \param is_pos_padded is true \param pos - * represents the position in already padded area */ - dim_t off_v(const dims_t pos, bool is_pos_padded = false) const { - assert(is_blocking_desc()); - const blocking_desc_t &blk = blocking_desc(); - - dims_t pos_copy = {0}; - for (int d = 0; d < ndims(); ++d) - pos_copy[d] = pos[d] + (is_pos_padded ? 0 : padded_offsets()[d]); - - dim_t phys_offset = offset0(); - - if (blk.inner_nblks > 0) { - dim_t blk_stride = 1; - for (int iblk = blk.inner_nblks - 1; iblk >= 0; --iblk) { - const int d = blk.inner_idxs[iblk]; - const dim_t p = pos_copy[d] % blk.inner_blks[iblk]; - - phys_offset += p * blk_stride; - - pos_copy[d] /= blk.inner_blks[iblk]; - - blk_stride *= blk.inner_blks[iblk]; - } - } - - for (int d = 0; d < ndims(); ++d) { - const dim_t p = pos_copy[d]; - phys_offset += p * blk.strides[d]; - } - - return phys_offset; - } - - /** returns physical offset by logical one. logical offset is represented by - * a scalar \param l_offset. if \param is_pos_padded is true, \param - * l_offset represents logical offset in already padded area */ - dim_t off_l(dim_t l_offset, bool is_pos_padded = false) const { - assert(is_blocking_desc()); - dims_t pos; - for (int rd = 0; rd < ndims(); ++rd) { - const int d = ndims() - 1 - rd; - const dim_t cur_dim = is_pos_padded ? padded_dims()[d] : dims()[d]; - pos[d] = l_offset % cur_dim; - l_offset /= cur_dim; - } - return off_v(pos, is_pos_padded); - } - - /** returns physical offset by logical one. logical offset is represented by - * a tuple of indices (\param xn, ..., \param x1, \param x0) */ - template - dim_t off(Args... args) const { - assert(sizeof...(args) == ndims()); - dims_t pos = { args... }; - return off_v(pos, false); - } - - /** returns physical offset by logical one. logical offset is represented by - * a tuple of indices (\param xn, ..., \param x1, \param x0) in already - * padded area */ - template - dim_t off_padding(Args... args) const { - assert(sizeof...(args) == ndims()); - dims_t pos = { args... }; - return off_v(pos, true); - } - - /** returns physical offset by logical one. Logical offset is represented by - * a tuple of block indices (\param bn, ..., \param b1, \param b0). It is a - * user responsibility to adjust the result to get offset within blocks */ - template - dim_t blk_off(Args... 
args) const { - return _blk_off(args...); - } - - template - dim_t blk_off(T xn, Args... args) const { - return skip_first - ? blk_off(args...) - : blk_off(xn, args...); - } - - /* static functions section */ - /* TODO: replace with non-static, once md_ becomes non-const ref */ - - static status_t compute_blocking(memory_desc_t &memory_desc, - format_tag_t tag); - -private: - /* TODO: put logical_offset in utils */ - template - dim_t logical_offset(T x0) const { return x0; } - - template - dim_t logical_offset(T xn, Args... args) const { - const size_t n_args = sizeof...(args); - return xn * utils::array_product( - &dims()[ndims() - n_args]) + logical_offset(args...); - } - - template - dim_t _blk_off() const { return offset0(); } - - template - dim_t _blk_off(T xc, Args ...args) const { - assert(is_blocking_desc()); - constexpr int dc = ORIG_LEN - sizeof...(args) - 1; - return xc * blocking_desc().strides[dc] - + _blk_off(args...); - } -}; - -inline bool memory_desc_wrapper::similar_to(const memory_desc_wrapper &rhs, - bool with_padding, bool with_data_type, int dim_start) const { - using namespace utils; - - if (one_of(format_kind(), format_kind::undef, format_kind::any)) - return false; - if (is_wino_desc() || is_rnn_packed_desc()) - return false; - - const int ds = dim_start; - const auto &blk = blocking_desc(); - const auto &r_blk = rhs.blocking_desc(); - - return ndims() == rhs.ndims() - && dim_start <= ndims() /* guard */ - && format_kind() == rhs.format_kind() - && IMPLICATION(with_data_type, data_type() == rhs.data_type()) - && array_cmp(dims() + ds, rhs.dims() + ds, ndims() - ds) - && array_cmp(blk.strides + ds, r_blk.strides + ds, ndims() - ds) - && blk.inner_nblks == r_blk.inner_nblks - && array_cmp(blk.inner_blks, r_blk.inner_blks, blk.inner_nblks) - && array_cmp(blk.inner_idxs, r_blk.inner_idxs, blk.inner_nblks) - && IMPLICATION(with_padding, true - && array_cmp(padded_dims() + ds, rhs.padded_dims() + ds, - ndims() - ds) - && array_cmp(padded_offsets() + ds, rhs.padded_offsets() + ds, - ndims() - ds)); -} - -inline bool memory_desc_wrapper::consistent_with( - const memory_desc_wrapper &rhs) const { - if (ndims() == rhs.ndims()) { - for (int d = 0; d < ndims(); ++d) { - if (dims()[d] != rhs.dims()[d]) return false; - } - return true; - } else { - /* TODO: revise. - * is the following possible? - * [1, a, b] <--reorder--> [a, b] - * [a, 1, b] <--reorder--> [a, b] - * not, at least for now */ - return false; - } -} - -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/memory_tracking.hpp b/thirdparty/oidn/mkl-dnn/src/common/memory_tracking.hpp deleted file mode 100644 index ec077b308..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/memory_tracking.hpp +++ /dev/null @@ -1,295 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
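The deleted `memory_desc_wrapper::off_v()` converts a logical coordinate into a physical offset in two passes: first the innermost blocks are peeled off (each remainder scaled by a running intra-block stride, innermost block first), then the reduced block coordinates are multiplied by the per-dimension outer strides. A self-contained sketch of that logic follows; padding offsets are omitted, and the 8-channel-blocked layout and coordinates are example values, not the library's types.

```c++
#include <cstdint>
#include <iostream>
#include <vector>

// Standalone sketch of the logical->physical offset logic in the deleted
// memory_desc_wrapper::off_v() for a blocked layout.
struct Blocking {
    std::vector<int64_t> strides;     // outer stride per logical dim
    std::vector<int64_t> inner_blks;  // inner block sizes, outermost first
    std::vector<int>     inner_idxs;  // logical dim each inner block refers to
};

int64_t off_v(const Blocking &blk, std::vector<int64_t> pos) {
    int64_t phys = 0;

    // Inner blocks: take the remainder w.r.t. each block (innermost first),
    // scale it by the running intra-block stride, and reduce the coordinate.
    int64_t blk_stride = 1;
    for (int i = static_cast<int>(blk.inner_blks.size()) - 1; i >= 0; --i) {
        const int d = blk.inner_idxs[i];
        phys += (pos[d] % blk.inner_blks[i]) * blk_stride;
        pos[d] /= blk.inner_blks[i];
        blk_stride *= blk.inner_blks[i];
    }

    // Outer part: remaining block coordinates times the per-dim strides.
    for (size_t d = 0; d < pos.size(); ++d)
        phys += pos[d] * blk.strides[d];
    return phys;
}

int main() {
    // An 8-channel-blocked NCHW layout (think aBcd8b) for shape 1x16x4x4,
    // with outer strides {256, 128, 32, 8} as fill_blocked would produce.
    Blocking blk{{256, 128, 32, 8}, {8}, {1}};
    // Element (n=0, c=11, h=2, w=3): c splits into block 1, remainder 3,
    // so offset = 1*128 + 2*32 + 3*8 + 3 = 219.
    std::cout << off_v(blk, {0, 11, 2, 3}) << '\n';
    return 0;
}
```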
-*******************************************************************************/ - -#ifndef MEMORY_TRACKING_HPP -#define MEMORY_TRACKING_HPP - -#include -#include - -#include "nstl.hpp" -#include "utils.hpp" - -namespace mkldnn { -namespace impl { -namespace memory_tracking { - -/* Memory tracking capabilities - * - * The main purpose of this header file is to provide uniform way to register - * required memory for a scratchpad at a primitive descriptor creation time - * and then easily access it having only the base address of the scratchpad. - * - * Primitives might contain multiple disjoint parts that require temporary - * buffers (known as scratchpad) during their execution. A primitive descriptor - * should summarize all the needs into one single number -- the buffer size - * that would be requested from a user. At execution time, the corresponding - * primitive will receive a base pointer to a scratchpad. It then needs to - * provide each part of algorithm the corresponding piece of memory. Three main - * challenges here are: - * 1. Track correct offset (from the base scratchpad address) for each piece - * 2. Algorithm might require that different memory pieces to be aligned, so - * the scratchpad size is no more just a sum of size of the corresponding - * subparts. - * 3. While a primitive is responsible for its scratchpad, the implementation - * might use some other basic blocks (e.g. cpu_reducer) that also require - * scratchpad memory. So there should be a simple way of passing the - * information back and force between the main algorithm (a primitive) and - * auxiliary stuff that lives completely separately from it (e.g. reducer). - * - * To address these challenges this header file provides 3 structures: - * 1. registry_t -- the class the stores the information about requested - * memory. The information includes required size and desired - * alignment for each piece. This class is also responsible - * for computing the right offset to a given piece using the - * base pointer. - * This class is basically a ledger with all entries. - * Lives in primitive descriptors. - * - * 2. registrar_t -- the interface to a registry_t to book memory. Used at - * primitive descriptor creation time only. Contains a - * reference to the corresponding *mutable* registry. - * Always modifiable. - * Allows chaining (using prefixes). - * - * 3. grantor_t -- the interface to a registry_t to access memory. Used at - * primitive execution time only. Contains a reference to - * the corresponding *constant* registry and base pointer. - * Always constant. - * Allows chaining (using prefixes). - * - * Both registrar_t and grantor_t allow chaining with extra prefix provided. - * The feature is useful when a primitive offload a part of computations to - * some other primitives which require their own scratchpad space - * (e.g. reducer). Prefixes are used to avoid key collision in cases when - * multiple sub-primitive (e.g. multiple reducers) are used. - * - * A short example below demonstrates how to use aforementioned classes. In it - * the main primitive is convolution that uses scratchpad for keeping padded - * bias. It also needs a reducer, that needs its own space as well. - * - * ``` c++ - * struct reducer_t { - * static void init(registrar_t &scratchpad) { - * // preserve space for the reduction (one page aligned) - * scratchpad.book(key_space, sizeof(float) * 980 * 1024, 4096); - * } - * - * void exec(const grantor_t &scratchpad) { - * // get the pointer to preserved space. 
scratchpad came from - * // upper primitive (convolution in this example) - * auto space = scratchpad.get(key_reducer_space); - * - * space[:] += ...; - * } - * }; - * - * struct conv_t { - * struct pd_t { - * void init() { - * registrar_t scratchpad(scratchpad_registry_); - * - * // preserve a space for padded bias (using default alignment) - * scratchpad.book(key_conv_padded_bias, 128); - * - * // create a proxy registrar for the reducer All entries made - * // by reducer would live in convolution's registry, but would - * // have their own `prefix`, so no interference with conv's - * // buffers. - * registrar_t reducer_scratchpad(scratchpad, prefix_reducer); - * - * reducer_t::init(reducer_scratchpad); - * } - * - * registry_t scratchpad_registry_; - * } - * - * void exec() { - * // get the base pointer to a scratchpad memory from a user - * void *scratchpad_ptr = this->input(MKLDNN_MEM_SCRATCHPAD); - * - * // create a grantor to the scratchpad (and provide the base - * // pointer). - * grantor_t scratchpad(pd()->scratchpad_registry_, scratchpad_ptr); - * - * // access the padded_bias (need only key name and the grantor) - * auto padded_bias = scratchpad.get(key_conv_padded_bias); - * - * // to give the `right` grantor to reducer we need to add the - * // corresponding prefix, so that reducer would be able to access - * // its keys. The call is very similar to the one in pd_t::init - * // with only difference in types: grantor_t vs registrar_t. - * grantor_t reducer_scratchpad(scratchpad, prefix_reducer); - * reducer->exec(reducer_scratchpad); - * } - * }; - * ``` - */ - - -/* namespace with common keys and prefixes */ -namespace names { -enum { - key_none = 0, - key_bnorm_tmp_mean, - key_bnorm_tmp_var, - key_bnorm_tmp_diff_ss, - key_bnorm_tmp_stats, - key_bnorm_reduction, - key_concat_iptrs, - key_concat_istrides, - key_concat_nelems, - key_concat_optrs, - key_conv_adjusted_scales, - key_conv_bia_reduction, - key_conv_gemm_col, - key_conv_gemm_imtr, - key_conv_int_dat_in_acc_dt, - key_conv_padded_bias, - key_conv_rtus_space, - key_conv_tr_diff_dst, - key_conv_tr_diff_dst_bctx, - key_conv_tr_src, - key_conv_tr_src_bctx, - key_conv_wei_reduction, - key_conv_wei_bia_reduction, - key_conv_wei_bia_reduction_bctx, - key_iprod_int_dat_in_acc_dt, - key_reducer_space, - key_reducer_space_bctx, - key_reorder_wino_plain, - key_reorder_wino_transform_space, - key_reorder_rnn_weights_quantization, - key_reorder_rnn_weights_reduction, - key_rnn_space, - key_rnn_ptrs_bia, - key_rnn_ptrs_wei_layer, - key_rnn_ptrs_wei_iter, - key_softmax_reduction, - key_wino_U, - key_wino_V, - key_wino_M, - key_barrier, -}; - -enum { - prefix_none = 0, - prefix_reducer_bia, - prefix_reducer_wei, -}; -} - -// level 0: 00 00 00 xxx -// level 1: 00 00 aa xxx -// level 2: 00 aa bb xxx -// level 3: aa bb cc xxx -// max # of levels: 3 + 1 (base_level) -// here: -// xxx : [1 .. MAX_KEY) : key -// aa, bb, cc : [1 .. 
MAX_PREFIX) : prefixes for levels 1, 2, and 3 - -using key_t = uint32_t; -enum { MAX_KEY = (1u << 10), MAX_PREFIX = (1u << 7), }; - -/// generates global key based on a prefix and a local key -inline key_t make_key(key_t prefix, key_t key) { return prefix + key; } - -/// generates global prefix based on the global parent and the local ones -inline key_t make_prefix(key_t parent_prefix, key_t prefix) -{ return MAX_PREFIX * parent_prefix + MAX_KEY * prefix; } - -struct registrar_t; -struct grantor_t; - -struct registry_t { - void book(const key_t &key, size_t size, size_t alignment) { - if (size == 0) return; - assert(offset_map_.count(key) == 0); - - size = utils::rnd_up(size, minimal_alignment); - alignment = nstl::max(alignment, minimal_alignment); - offset_map_[key] = entry_t{size_, size, alignment}; - - size_ += size + alignment - minimal_alignment; - } - - void *get(const key_t &key, void *base_ptr) const { - if (base_ptr == nullptr) { assert(size() == 0); return nullptr; } - if (offset_map_.count(key) != 1) return nullptr; - - const auto &e = offset_map_.at(key); - base_ptr = utils::align_ptr(base_ptr, minimal_alignment); - char *ptr = (char *)base_ptr + e.offset; - return utils::align_ptr(ptr, e.alignment); - } - - size_t size() const - { return size_ > 0 ? size_ + minimal_alignment - 1 : 0; } - - registrar_t registrar(); - grantor_t grantor(void *base_ptr) const; - -protected: - enum { minimal_alignment = 64 }; - struct entry_t { size_t offset, size, alignment; }; - - std::unordered_map<key_t, entry_t> offset_map_; - size_t size_ = 0; -}; - -struct registrar_t { - enum { default_alignment = 64 }; - - registrar_t(registry_t &registry): registry_(registry), prefix_(0) {} - registrar_t(registrar_t &parent, const key_t &prefix) - : registry_(parent.registry_) - , prefix_(make_prefix(parent.prefix_, prefix)) {} - - void book(const key_t &key, size_t size, - size_t alignment = default_alignment) - { registry_.book(make_key(prefix_, key), size, alignment); } - -protected: - registry_t &registry_; - const key_t prefix_; -}; - -struct grantor_t { - grantor_t(const registry_t &registry, void *base_ptr) - : registry_(registry), prefix_(0), base_ptr_(base_ptr) {} - grantor_t(const grantor_t &parent, const key_t &prefix) - : registry_(parent.registry_) - , prefix_(make_prefix(parent.prefix_, prefix)) - , base_ptr_(parent.base_ptr_) {} - - template <typename T> T *get(const key_t &key) const - { return (T *)registry_.get(make_key(prefix_, key), base_ptr_); } - -protected: - const registry_t &registry_; - const key_t prefix_; - void *base_ptr_; -}; - -inline registrar_t registry_t::registrar() { return registrar_t(*this); } -inline grantor_t registry_t::grantor(void *base_ptr) const -{ return grantor_t(*this, base_ptr); } - -} -} -} - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug.cpp b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug.cpp deleted file mode 100644 index 2ef4a8fdd..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug.cpp +++ /dev/null @@ -1,131 +0,0 @@ -/******************************************************************************* -* Copyright 2019 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License.
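For readers skimming the scratchpad bookkeeping removed above, the flow is easiest to see end to end. Below is a minimal, self-contained sketch of the same book/get idea; it is not the library code. toy_registry_t, its fixed 64-byte alignment, and the key constants in main() are illustrative stand-ins, and only the make_key()/make_prefix() encoding is borrowed from the header shown above.

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <unordered_map>
#include <vector>

using key_t = std::uint32_t;
enum : key_t { MAX_KEY = 1u << 10, MAX_PREFIX = 1u << 7 };

inline key_t make_key(key_t prefix, key_t key) { return prefix + key; }
inline key_t make_prefix(key_t parent, key_t prefix)
{ return MAX_PREFIX * parent + MAX_KEY * prefix; }

struct toy_registry_t { // hypothetical stand-in for registry_t
    static constexpr std::size_t alignment = 64;
    struct entry_t { std::size_t offset, size; };

    // "registrar" side: called while building a primitive descriptor.
    void book(key_t key, std::size_t size) {
        size = (size + alignment - 1) / alignment * alignment; // like rnd_up()
        map_[key] = entry_t{size_, size};
        size_ += size;
    }
    std::size_t size() const { return size_ ? size_ + alignment - 1 : 0; }

    // "grantor" side: called at execution time with the caller's base pointer.
    void *get(key_t key, void *base) const {
        std::uintptr_t p = (reinterpret_cast<std::uintptr_t>(base) + alignment - 1)
                / alignment * alignment;
        return reinterpret_cast<void *>(p + map_.at(key).offset);
    }

private:
    std::unordered_map<key_t, entry_t> map_;
    std::size_t size_ = 0;
};

int main() {
    enum : key_t { key_conv_padded_bias = 1, key_reducer_space = 2, prefix_reducer = 1 };

    toy_registry_t reg;
    reg.book(make_key(0, key_conv_padded_bias), 128); // booked by the "conv"
    reg.book(make_key(make_prefix(0, prefix_reducer), key_reducer_space), 4096); // by its "reducer"

    std::vector<char> scratchpad(reg.size()); // base pointer comes from the caller
    void *bias = reg.get(make_key(0, key_conv_padded_bias), scratchpad.data());
    void *space = reg.get(make_key(make_prefix(0, prefix_reducer), key_reducer_space),
            scratchpad.data());
    assert(bias != space);
    (void)bias; (void)space;
    return 0;
}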
-* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include -#include -#include - -#include "mkldnn_debug.h" -#include "mkldnn_types.h" - -#include "c_types_map.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -#define DPRINT(...) do { \ - int l = snprintf(str + written_len, str_len, __VA_ARGS__); \ - if (l < 0) return l; \ - if ((size_t)l >= str_len) return -1; \ - written_len += l; str_len -= l; \ -} while(0) - -int mkldnn_md2fmt_str(char *str, size_t str_len, - const mkldnn_memory_desc_t *mdesc) { - using namespace mkldnn::impl; - - if (str == nullptr || str_len <= 1u) - return -1; - - int written_len = 0; - - if (mdesc == nullptr) { - DPRINT("%s::%s::", - mkldnn_dt2str(data_type::undef), - mkldnn_fmt_kind2str(format_kind::undef)); - return written_len; - } - - memory_desc_wrapper md(mdesc); - - DPRINT("%s:", mkldnn_dt2str(md.data_type())); - - bool padded_dims = false, padded_offsets = false; - for (int d = 0; d < md.ndims(); ++d) { - if (md.dims()[d] != md.padded_dims()[d]) padded_dims = true; - if (md.padded_offsets()[d] != 0) padded_offsets = true; - } - bool offset0 = md.offset0(); - DPRINT("%s%s%s:", - padded_dims ? "p" : "", - padded_offsets ? "o" : "", - offset0 ? "0" : ""); - - DPRINT("%s:", mkldnn_fmt_kind2str(md.format_kind())); - - if (!md.is_blocking_desc()) { - /* TODO: extend */ - DPRINT("%s:", ""); - } else { - const auto &blk = md.blocking_desc(); - - dims_t blocks; - md.compute_blocks(blocks); - - char dim_chars[MKLDNN_MAX_NDIMS + 1]; - - bool plain = true; - for (int d = 0; d < md.ndims(); ++d) { - dim_chars[d] = (blocks[d] == 1 ? 
'a' : 'A') + (char)d; - if (blocks[d] != 1) plain = false; - } - - dims_t strides; - utils::array_copy(strides, blk.strides, md.ndims()); - utils::simultaneous_sort(strides, dim_chars, md.ndims(), - [](dim_t a, dim_t b) { return b - a; }); - - dim_chars[md.ndims()] = '\0'; - DPRINT("%s", dim_chars); - - if (!plain) { - for (int iblk = 0; iblk < blk.inner_nblks; ++iblk) { - DPRINT("%d%c", (int)blk.inner_blks[iblk], - 'a' + (char)blk.inner_idxs[iblk]); - } - } - - DPRINT("%s", ":"); - } - - DPRINT("f%lx", (long)md.extra().flags); - - return written_len; -} - -int mkldnn_md2dim_str(char *str, size_t str_len, - const mkldnn_memory_desc_t *mdesc) { - using namespace mkldnn::impl; - - if (str == nullptr || str_len <= 1) - return -1; - - int written_len = 0; - - if (mdesc == nullptr || mdesc->ndims == 0) { - DPRINT("%s", ""); - return written_len; - } - - memory_desc_wrapper md(mdesc); - - for (int d = 0; d < md.ndims() - 1; ++d) - DPRINT("%" PRId64 "x", md.dims()[d]); - DPRINT("%" PRId64, md.dims()[md.ndims() - 1]); - - return written_len; -} - -#undef DPRINT diff --git a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug_autogenerated.cpp b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug_autogenerated.cpp deleted file mode 100644 index 16a8f7ea5..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug_autogenerated.cpp +++ /dev/null @@ -1,365 +0,0 @@ -/******************************************************************************* -* Copyright 2018-2019 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
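The DPRINT macro in the deleted mkldnn_debug.cpp above captures a reusable pattern: append formatted fragments to a fixed buffer with snprintf, keep track of how much was written and how much room is left, and give up as soon as a fragment would not fit. A standalone sketch of that pattern follows; the append() helper and the sample strings are hypothetical, not part of the library.

#include <cstdarg>
#include <cstdio>

// Append printf-style output to buf, advancing `written` and shrinking `left`.
// Returns 0 on success and -1 once a fragment would be truncated, mirroring
// the early returns of the DPRINT macro above.
static int append(char *buf, int &written, std::size_t &left, const char *fmt, ...) {
    std::va_list ap;
    va_start(ap, fmt);
    int l = std::vsnprintf(buf + written, left, fmt, ap);
    va_end(ap);
    if (l < 0) return l;                    // encoding error
    if ((std::size_t)l >= left) return -1;  // would not fit
    written += l;
    left -= l;
    return 0;
}

int main() {
    char buf[32];
    int written = 0;
    std::size_t left = sizeof(buf);
    if (append(buf, written, left, "%s:", "f32") == 0
            && append(buf, written, left, "%dx%d", 16, 16) == 0)
        std::printf("%s\n", buf);           // prints "f32:16x16"
    return 0;
}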
-*******************************************************************************/ - -/* DO NOT EDIT, AUTO-GENERATED */ - -#include - -#include "mkldnn_debug.h" -#include "mkldnn_types.h" - -const char *mkldnn_status2str(mkldnn_status_t v) { - if (v == mkldnn_success) return "success"; - if (v == mkldnn_out_of_memory) return "out_of_memory"; - if (v == mkldnn_try_again) return "try_again"; - if (v == mkldnn_invalid_arguments) return "invalid_arguments"; - if (v == mkldnn_not_ready) return "not_ready"; - if (v == mkldnn_unimplemented) return "unimplemented"; - if (v == mkldnn_iterator_ends) return "iterator_ends"; - if (v == mkldnn_runtime_error) return "runtime_error"; - if (v == mkldnn_not_required) return "not_required"; - assert(!"unknown status"); - return "unknown status"; -} - -const char *mkldnn_dt2str(mkldnn_data_type_t v) { - if (v == mkldnn_data_type_undef) return "undef"; - if (v == mkldnn_f32) return "f32"; - if (v == mkldnn_s32) return "s32"; - if (v == mkldnn_s8) return "s8"; - if (v == mkldnn_u8) return "u8"; - assert(!"unknown dt"); - return "unknown dt"; -} - -const char *mkldnn_fmt_kind2str(mkldnn_format_kind_t v) { - if (v == mkldnn_format_kind_undef) return "undef"; - if (v == mkldnn_format_kind_any) return "any"; - if (v == mkldnn_blocked) return "blocked"; - if (v == mkldnn_format_kind_wino) return "wino"; - if (v == mkldnn_format_kind_rnn_packed) return "rnn_packed"; - assert(!"unknown fmt_kind"); - return "unknown fmt_kind"; -} - -const char *mkldnn_fmt_tag2str(mkldnn_format_tag_t v) { - if (v == mkldnn_format_tag_undef) return "undef"; - if (v == mkldnn_format_tag_any) return "format_tag_any"; - if (v == mkldnn_a) return "a"; - if (v == mkldnn_ab) return "ab"; - if (v == mkldnn_abc) return "abc"; - if (v == mkldnn_abcd) return "abcd"; - if (v == mkldnn_abcde) return "abcde"; - if (v == mkldnn_abcdef) return "abcdef"; - if (v == mkldnn_abdec) return "abdec"; - if (v == mkldnn_acb) return "acb"; - if (v == mkldnn_acbde) return "acbde"; - if (v == mkldnn_acdb) return "acdb"; - if (v == mkldnn_acdeb) return "acdeb"; - if (v == mkldnn_ba) return "ba"; - if (v == mkldnn_bac) return "bac"; - if (v == mkldnn_bacd) return "bacd"; - if (v == mkldnn_bcda) return "bcda"; - if (v == mkldnn_cba) return "cba"; - if (v == mkldnn_cdba) return "cdba"; - if (v == mkldnn_cdeba) return "cdeba"; - if (v == mkldnn_decab) return "decab"; - if (v == mkldnn_Abc16a) return "Abc16a"; - if (v == mkldnn_ABc16a16b) return "ABc16a16b"; - if (v == mkldnn_aBc16b) return "aBc16b"; - if (v == mkldnn_ABc16b16a) return "ABc16b16a"; - if (v == mkldnn_Abc4a) return "Abc4a"; - if (v == mkldnn_aBc4b) return "aBc4b"; - if (v == mkldnn_ABc4b16a4b) return "ABc4b16a4b"; - if (v == mkldnn_ABc4b4a) return "ABc4b4a"; - if (v == mkldnn_ABc8a16b2a) return "ABc8a16b2a"; - if (v == mkldnn_ABc8a8b) return "ABc8a8b"; - if (v == mkldnn_aBc8b) return "aBc8b"; - if (v == mkldnn_ABc8b16a2b) return "ABc8b16a2b"; - if (v == mkldnn_ABc8b8a) return "ABc8b8a"; - if (v == mkldnn_Abcd16a) return "Abcd16a"; - if (v == mkldnn_ABcd16a16b) return "ABcd16a16b"; - if (v == mkldnn_aBcd16b) return "aBcd16b"; - if (v == mkldnn_ABcd16b16a) return "ABcd16b16a"; - if (v == mkldnn_aBCd16b16c) return "aBCd16b16c"; - if (v == mkldnn_aBCd16c16b) return "aBCd16c16b"; - if (v == mkldnn_Abcd4a) return "Abcd4a"; - if (v == mkldnn_aBcd4b) return "aBcd4b"; - if (v == mkldnn_ABcd4b16a4b) return "ABcd4b16a4b"; - if (v == mkldnn_ABcd4b4a) return "ABcd4b4a"; - if (v == mkldnn_aBCd4c16b4c) return "aBCd4c16b4c"; - if (v == mkldnn_aBCd4c4b) return "aBCd4c4b"; 
- if (v == mkldnn_ABcd8a16b2a) return "ABcd8a16b2a"; - if (v == mkldnn_ABcd8a8b) return "ABcd8a8b"; - if (v == mkldnn_aBcd8b) return "aBcd8b"; - if (v == mkldnn_ABcd8b16a2b) return "ABcd8b16a2b"; - if (v == mkldnn_aBCd8b16c2b) return "aBCd8b16c2b"; - if (v == mkldnn_ABcd8b8a) return "ABcd8b8a"; - if (v == mkldnn_aBCd8b8c) return "aBCd8b8c"; - if (v == mkldnn_aBCd8c16b2c) return "aBCd8c16b2c"; - if (v == mkldnn_aBCd8c8b) return "aBCd8c8b"; - if (v == mkldnn_Abcde16a) return "Abcde16a"; - if (v == mkldnn_ABcde16a16b) return "ABcde16a16b"; - if (v == mkldnn_aBcde16b) return "aBcde16b"; - if (v == mkldnn_ABcde16b16a) return "ABcde16b16a"; - if (v == mkldnn_aBCde16b16c) return "aBCde16b16c"; - if (v == mkldnn_aBCde16c16b) return "aBCde16c16b"; - if (v == mkldnn_aBCde2c8b4c) return "aBCde2c8b4c"; - if (v == mkldnn_Abcde4a) return "Abcde4a"; - if (v == mkldnn_aBcde4b) return "aBcde4b"; - if (v == mkldnn_ABcde4b4a) return "ABcde4b4a"; - if (v == mkldnn_aBCde4b4c) return "aBCde4b4c"; - if (v == mkldnn_aBCde4c16b4c) return "aBCde4c16b4c"; - if (v == mkldnn_aBCde4c4b) return "aBCde4c4b"; - if (v == mkldnn_Abcde8a) return "Abcde8a"; - if (v == mkldnn_ABcde8a8b) return "ABcde8a8b"; - if (v == mkldnn_ABcde8b16a2b) return "ABcde8b16a2b"; - if (v == mkldnn_aBCde8b16c2b) return "aBCde8b16c2b"; - if (v == mkldnn_ABcde8b8a) return "ABcde8b8a"; - if (v == mkldnn_aBCde8b8c) return "aBCde8b8c"; - if (v == mkldnn_aBCde8c16b2c) return "aBCde8c16b2c"; - if (v == mkldnn_aBCde8c8b) return "aBCde8c8b"; - if (v == mkldnn_aBcdef16b) return "aBcdef16b"; - if (v == mkldnn_aBCdef16b16c) return "aBCdef16b16c"; - if (v == mkldnn_aBCdef16c16b) return "aBCdef16c16b"; - if (v == mkldnn_aBcdef4b) return "aBcdef4b"; - if (v == mkldnn_aBCdef4c4b) return "aBCdef4c4b"; - if (v == mkldnn_aBCdef8b8c) return "aBCdef8b8c"; - if (v == mkldnn_aBCdef8c16b2c) return "aBCdef8c16b2c"; - if (v == mkldnn_aBCdef8c8b) return "aBCdef8c8b"; - if (v == mkldnn_aBdc16b) return "aBdc16b"; - if (v == mkldnn_aBdc4b) return "aBdc4b"; - if (v == mkldnn_aBdc8b) return "aBdc8b"; - if (v == mkldnn_aBdec16b) return "aBdec16b"; - if (v == mkldnn_aBdec4b) return "aBdec4b"; - if (v == mkldnn_aBdec8b) return "aBdec8b"; - if (v == mkldnn_aBdefc16b) return "aBdefc16b"; - if (v == mkldnn_aBdefc4b) return "aBdefc4b"; - if (v == mkldnn_aBdefc8b) return "aBdefc8b"; - if (v == mkldnn_Acb16a) return "Acb16a"; - if (v == mkldnn_Acb4a) return "Acb4a"; - if (v == mkldnn_Acb8a) return "Acb8a"; - if (v == mkldnn_aCBd16b16c) return "aCBd16b16c"; - if (v == mkldnn_aCBde16b16c) return "aCBde16b16c"; - if (v == mkldnn_Acdb16a) return "Acdb16a"; - if (v == mkldnn_Acdb4a) return "Acdb4a"; - if (v == mkldnn_Acdb8a) return "Acdb8a"; - if (v == mkldnn_Acdeb16a) return "Acdeb16a"; - if (v == mkldnn_Acdeb4a) return "Acdeb4a"; - if (v == mkldnn_Acdeb8a) return "Acdeb8a"; - if (v == mkldnn_BAc16a16b) return "BAc16a16b"; - if (v == mkldnn_BAcd16a16b) return "BAcd16a16b"; - if (v == mkldnn_format_tag_last) return "format_tag_last"; - if (v == mkldnn_x) return "x"; - if (v == mkldnn_nc) return "nc"; - if (v == mkldnn_cn) return "cn"; - if (v == mkldnn_ncw) return "ncw"; - if (v == mkldnn_nwc) return "nwc"; - if (v == mkldnn_nchw) return "nchw"; - if (v == mkldnn_nhwc) return "nhwc"; - if (v == mkldnn_chwn) return "chwn"; - if (v == mkldnn_ncdhw) return "ncdhw"; - if (v == mkldnn_ndhwc) return "ndhwc"; - if (v == mkldnn_oi) return "oi"; - if (v == mkldnn_io) return "io"; - if (v == mkldnn_oiw) return "oiw"; - if (v == mkldnn_wio) return "wio"; - if (v == mkldnn_oihw) return "oihw"; - if (v == 
mkldnn_hwio) return "hwio"; - if (v == mkldnn_ihwo) return "ihwo"; - if (v == mkldnn_iohw) return "iohw"; - if (v == mkldnn_oidhw) return "oidhw"; - if (v == mkldnn_dhwio) return "dhwio"; - if (v == mkldnn_goiw) return "goiw"; - if (v == mkldnn_goihw) return "goihw"; - if (v == mkldnn_hwigo) return "hwigo"; - if (v == mkldnn_giohw) return "giohw"; - if (v == mkldnn_goidhw) return "goidhw"; - if (v == mkldnn_tnc) return "tnc"; - if (v == mkldnn_ntc) return "ntc"; - if (v == mkldnn_ldsnc) return "ldsnc"; - if (v == mkldnn_ldigo) return "ldigo"; - if (v == mkldnn_ldgoi) return "ldgoi"; - if (v == mkldnn_ldgo) return "ldgo"; - if (v == mkldnn_nCdhw16c) return "nCdhw16c"; - if (v == mkldnn_nCdhw4c) return "nCdhw4c"; - if (v == mkldnn_nCdhw8c) return "nCdhw8c"; - if (v == mkldnn_nChw16c) return "nChw16c"; - if (v == mkldnn_nChw4c) return "nChw4c"; - if (v == mkldnn_nChw8c) return "nChw8c"; - if (v == mkldnn_nCw16c) return "nCw16c"; - if (v == mkldnn_nCw4c) return "nCw4c"; - if (v == mkldnn_nCw8c) return "nCw8c"; - if (v == mkldnn_IOw16o16i) return "IOw16o16i"; - if (v == mkldnn_OIw16i16o) return "OIw16i16o"; - if (v == mkldnn_OIw16o16i) return "OIw16o16i"; - if (v == mkldnn_Oiw16o) return "Oiw16o"; - if (v == mkldnn_OIw4i16o4i) return "OIw4i16o4i"; - if (v == mkldnn_OIw4i4o) return "OIw4i4o"; - if (v == mkldnn_Oiw4o) return "Oiw4o"; - if (v == mkldnn_OIw8i16o2i) return "OIw8i16o2i"; - if (v == mkldnn_OIw8i8o) return "OIw8i8o"; - if (v == mkldnn_OIw8o16i2o) return "OIw8o16i2o"; - if (v == mkldnn_OIw8o8i) return "OIw8o8i"; - if (v == mkldnn_Owi16o) return "Owi16o"; - if (v == mkldnn_Owi4o) return "Owi4o"; - if (v == mkldnn_Owi8o) return "Owi8o"; - if (v == mkldnn_IOhw16o16i) return "IOhw16o16i"; - if (v == mkldnn_Ohwi16o) return "Ohwi16o"; - if (v == mkldnn_Ohwi4o) return "Ohwi4o"; - if (v == mkldnn_Ohwi8o) return "Ohwi8o"; - if (v == mkldnn_OIhw16i16o) return "OIhw16i16o"; - if (v == mkldnn_OIhw16o16i) return "OIhw16o16i"; - if (v == mkldnn_Oihw16o) return "Oihw16o"; - if (v == mkldnn_OIhw4i16o4i) return "OIhw4i16o4i"; - if (v == mkldnn_OIhw4i4o) return "OIhw4i4o"; - if (v == mkldnn_Oihw4o) return "Oihw4o"; - if (v == mkldnn_OIhw8i16o2i) return "OIhw8i16o2i"; - if (v == mkldnn_OIhw8i8o) return "OIhw8i8o"; - if (v == mkldnn_OIhw8o16i2o) return "OIhw8o16i2o"; - if (v == mkldnn_OIhw8o8i) return "OIhw8o8i"; - if (v == mkldnn_Odhwi16o) return "Odhwi16o"; - if (v == mkldnn_Odhwi4o) return "Odhwi4o"; - if (v == mkldnn_Odhwi8o) return "Odhwi8o"; - if (v == mkldnn_OIdhw16i16o) return "OIdhw16i16o"; - if (v == mkldnn_OIdhw16o16i) return "OIdhw16o16i"; - if (v == mkldnn_Oidhw16o) return "Oidhw16o"; - if (v == mkldnn_OIdhw4i4o) return "OIdhw4i4o"; - if (v == mkldnn_Oidhw4o) return "Oidhw4o"; - if (v == mkldnn_OIdhw8i16o2i) return "OIdhw8i16o2i"; - if (v == mkldnn_OIdhw8i8o) return "OIdhw8i8o"; - if (v == mkldnn_OIdhw8o8i) return "OIdhw8o8i"; - if (v == mkldnn_Goiw16g) return "Goiw16g"; - if (v == mkldnn_gIOw16o16i) return "gIOw16o16i"; - if (v == mkldnn_gOIw16i16o) return "gOIw16i16o"; - if (v == mkldnn_gOIw16o16i) return "gOIw16o16i"; - if (v == mkldnn_gOiw16o) return "gOiw16o"; - if (v == mkldnn_gOIw4i16o4i) return "gOIw4i16o4i"; - if (v == mkldnn_gOIw4i4o) return "gOIw4i4o"; - if (v == mkldnn_gOiw4o) return "gOiw4o"; - if (v == mkldnn_gOIw8i16o2i) return "gOIw8i16o2i"; - if (v == mkldnn_gOIw8i8o) return "gOIw8i8o"; - if (v == mkldnn_gOIw8o16i2o) return "gOIw8o16i2o"; - if (v == mkldnn_gOIw8o8i) return "gOIw8o8i"; - if (v == mkldnn_gOwi16o) return "gOwi16o"; - if (v == mkldnn_gOwi4o) return "gOwi4o"; - if 
(v == mkldnn_gOwi8o) return "gOwi8o"; - if (v == mkldnn_gIOhw16o16i) return "gIOhw16o16i"; - if (v == mkldnn_gOhwi16o) return "gOhwi16o"; - if (v == mkldnn_gOhwi4o) return "gOhwi4o"; - if (v == mkldnn_gOhwi8o) return "gOhwi8o"; - if (v == mkldnn_Goihw16g) return "Goihw16g"; - if (v == mkldnn_gOIhw16i16o) return "gOIhw16i16o"; - if (v == mkldnn_gOIhw16o16i) return "gOIhw16o16i"; - if (v == mkldnn_gOihw16o) return "gOihw16o"; - if (v == mkldnn_gOIhw2i8o4i) return "gOIhw2i8o4i"; - if (v == mkldnn_gOIhw4i16o4i) return "gOIhw4i16o4i"; - if (v == mkldnn_gOIhw4i4o) return "gOIhw4i4o"; - if (v == mkldnn_gOIhw4o4i) return "gOIhw4o4i"; - if (v == mkldnn_gOihw4o) return "gOihw4o"; - if (v == mkldnn_Goihw8g) return "Goihw8g"; - if (v == mkldnn_gOIhw8i16o2i) return "gOIhw8i16o2i"; - if (v == mkldnn_gOIhw8i8o) return "gOIhw8i8o"; - if (v == mkldnn_gOIhw8o16i2o) return "gOIhw8o16i2o"; - if (v == mkldnn_gOIhw8o8i) return "gOIhw8o8i"; - if (v == mkldnn_gOdhwi16o) return "gOdhwi16o"; - if (v == mkldnn_gOdhwi4o) return "gOdhwi4o"; - if (v == mkldnn_gOdhwi8o) return "gOdhwi8o"; - if (v == mkldnn_gOIdhw16i16o) return "gOIdhw16i16o"; - if (v == mkldnn_gOIdhw16o16i) return "gOIdhw16o16i"; - if (v == mkldnn_gOidhw16o) return "gOidhw16o"; - if (v == mkldnn_gOIdhw4i4o) return "gOIdhw4i4o"; - if (v == mkldnn_gOidhw4o) return "gOidhw4o"; - if (v == mkldnn_gOIdhw8i16o2i) return "gOIdhw8i16o2i"; - if (v == mkldnn_gOIdhw8i8o) return "gOIdhw8i8o"; - if (v == mkldnn_gOIdhw8o8i) return "gOIdhw8o8i"; - assert(!"unknown fmt_tag"); - return "unknown fmt_tag"; -} - -const char *mkldnn_prop_kind2str(mkldnn_prop_kind_t v) { - if (v == mkldnn_prop_kind_undef) return "undef"; - if (v == mkldnn_forward_training) return "forward_training"; - if (v == mkldnn_forward_inference) return "forward_inference"; - if (v == mkldnn_forward_scoring) return "forward_scoring"; - if (v == mkldnn_forward) return "forward"; - if (v == mkldnn_backward) return "backward"; - if (v == mkldnn_backward_data) return "backward_data"; - if (v == mkldnn_backward_weights) return "backward_weights"; - if (v == mkldnn_backward_bias) return "backward_bias"; - assert(!"unknown prop_kind"); - return "unknown prop_kind"; -} - -const char *mkldnn_prim_kind2str(mkldnn_primitive_kind_t v) { - if (v == mkldnn_undefined_primitive) return "undef"; - if (v == mkldnn_reorder) return "reorder"; - if (v == mkldnn_shuffle) return "shuffle"; - if (v == mkldnn_concat) return "concat"; - if (v == mkldnn_sum) return "sum"; - if (v == mkldnn_convolution) return "convolution"; - if (v == mkldnn_deconvolution) return "deconvolution"; - if (v == mkldnn_eltwise) return "eltwise"; - if (v == mkldnn_softmax) return "softmax"; - if (v == mkldnn_pooling) return "pooling"; - if (v == mkldnn_lrn) return "lrn"; - if (v == mkldnn_batch_normalization) return "batch_normalization"; - if (v == mkldnn_inner_product) return "inner_product"; - if (v == mkldnn_rnn) return "rnn"; - assert(!"unknown prim_kind"); - return "unknown prim_kind"; -} - -const char *mkldnn_alg_kind2str(mkldnn_alg_kind_t v) { - if (v == mkldnn_alg_kind_undef) return "undef"; - if (v == mkldnn_convolution_direct) return "convolution_direct"; - if (v == mkldnn_convolution_winograd) return "convolution_winograd"; - if (v == mkldnn_convolution_auto) return "convolution_auto"; - if (v == mkldnn_deconvolution_direct) return "deconvolution_direct"; - if (v == mkldnn_deconvolution_winograd) return "deconvolution_winograd"; - if (v == mkldnn_eltwise_relu) return "eltwise_relu"; - if (v == mkldnn_eltwise_tanh) return "eltwise_tanh"; - 
if (v == mkldnn_eltwise_elu) return "eltwise_elu"; - if (v == mkldnn_eltwise_square) return "eltwise_square"; - if (v == mkldnn_eltwise_abs) return "eltwise_abs"; - if (v == mkldnn_eltwise_sqrt) return "eltwise_sqrt"; - if (v == mkldnn_eltwise_linear) return "eltwise_linear"; - if (v == mkldnn_eltwise_bounded_relu) return "eltwise_bounded_relu"; - if (v == mkldnn_eltwise_soft_relu) return "eltwise_soft_relu"; - if (v == mkldnn_eltwise_logistic) return "eltwise_logistic"; - if (v == mkldnn_pooling_max) return "pooling_max"; - if (v == mkldnn_pooling_avg_include_padding) return "pooling_avg_include_padding"; - if (v == mkldnn_pooling_avg_exclude_padding) return "pooling_avg_exclude_padding"; - if (v == mkldnn_pooling_avg) return "pooling_avg"; - if (v == mkldnn_lrn_across_channels) return "lrn_across_channels"; - if (v == mkldnn_lrn_within_channel) return "lrn_within_channel"; - if (v == mkldnn_vanilla_rnn) return "vanilla_rnn"; - if (v == mkldnn_vanilla_lstm) return "vanilla_lstm"; - if (v == mkldnn_vanilla_gru) return "vanilla_gru"; - if (v == mkldnn_gru_linear_before_reset) return "gru_linear_before_reset"; - assert(!"unknown alg_kind"); - return "unknown alg_kind"; -} - -const char *mkldnn_rnn_direction2str(mkldnn_rnn_direction_t v) { - if (v == mkldnn_unidirectional_left2right) return "unidirectional_left2right"; - if (v == mkldnn_unidirectional_right2left) return "unidirectional_right2left"; - if (v == mkldnn_bidirectional_concat) return "bidirectional_concat"; - if (v == mkldnn_bidirectional_sum) return "bidirectional_sum"; - if (v == mkldnn_unidirectional) return "unidirectional"; - assert(!"unknown rnn_direction"); - return "unknown rnn_direction"; -} diff --git a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread.hpp b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread.hpp deleted file mode 100644 index 7e5789e2c..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread.hpp +++ /dev/null @@ -1,115 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef MKLDNN_THREAD_HPP -#define MKLDNN_THREAD_HPP - -#include "utils.hpp" -#include "z_magic.hpp" - -#define MKLDNN_THR_SEQ 0 -#define MKLDNN_THR_OMP 1 -#define MKLDNN_THR_TBB 2 - -/* Ideally this condition below should never happen (if the library is built - * using regular cmake). For the 3rd-party projects that build the library - * from the sources on their own try to guess the right threading... */ -#if !defined(MKLDNN_THR) -# define MKLDNN_THR MKLDNN_THR_TBB -#endif - -#if MKLDNN_THR == MKLDNN_THR_SEQ -#define MKLDNN_THR_SYNC 1 -inline int mkldnn_get_max_threads() { return 1; } -inline int mkldnn_get_num_threads() { return 1; } -inline int mkldnn_get_thread_num() { return 0; } -inline int mkldnn_in_parallel() { return 0; } -inline void mkldnn_thr_barrier() {} - -#define PRAGMA_OMP(...) 
- -#elif MKLDNN_THR == MKLDNN_THR_OMP -#include -#define MKLDNN_THR_SYNC 1 - -inline int mkldnn_get_max_threads() { return omp_get_max_threads(); } -inline int mkldnn_get_num_threads() { return omp_get_num_threads(); } -inline int mkldnn_get_thread_num() { return omp_get_thread_num(); } -inline int mkldnn_in_parallel() { return omp_in_parallel(); } -inline void mkldnn_thr_barrier() { -# pragma omp barrier -} - -#define PRAGMA_OMP(...) PRAGMA_MACRO(CHAIN2(omp, __VA_ARGS__)) - -#elif MKLDNN_THR == MKLDNN_THR_TBB -#include "tbb/task_arena.h" -#include "tbb/parallel_for.h" -#define MKLDNN_THR_SYNC 0 - -inline int mkldnn_get_max_threads() -{ return tbb::this_task_arena::max_concurrency(); } -inline int mkldnn_get_num_threads() { return mkldnn_get_max_threads(); } -inline int mkldnn_get_thread_num() -{ return tbb::this_task_arena::current_thread_index(); } -inline int mkldnn_in_parallel() { return 0; } -inline void mkldnn_thr_barrier() { assert(!"no barrier in TBB"); } - -#define PRAGMA_OMP(...) - -#endif - -/* MSVC still supports omp 2.0 only */ -#if defined(_MSC_VER) && !defined(__clang__) && !defined(__INTEL_COMPILER) -# define collapse(x) -# define PRAGMA_OMP_SIMD(...) -#else -# define PRAGMA_OMP_SIMD(...) PRAGMA_MACRO(CHAIN2(omp, simd __VA_ARGS__)) -#endif // defined(_MSC_VER) && !defined(__INTEL_COMPILER) - -namespace mkldnn { -namespace impl { - -inline bool mkldnn_thr_syncable() { return MKLDNN_THR_SYNC == 1; } - -template -inline void balance211(T n, U team, U tid, T &n_start, T &n_end) { - T n_min = 1; - T &n_my = n_end; - if (team <= 1 || n == 0) { - n_start = 0; - n_my = n; - } else if (n_min == 1) { - // team = T1 + T2 - // n = T1*n1 + T2*n2 (n1 - n2 = 1) - T n1 = utils::div_up(n, (T)team); - T n2 = n1 - 1; - T T1 = n - n2 * (T)team; - n_my = (T)tid < T1 ? n1 : n2; - n_start = (T)tid <= T1 ? tid * n1 : T1 * n1 + ((T)tid - T1) * n2; - } - - n_end += n_start; -} - -} // namespace impl -} // namespace mkldnn - -#include "mkldnn_thread_parallel_nd.hpp" - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread_parallel_nd.hpp b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread_parallel_nd.hpp deleted file mode 100644 index 50f9b2962..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread_parallel_nd.hpp +++ /dev/null @@ -1,277 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef MKLDNN_THREAD_PARALLEL_ND_HPP -#define MKLDNN_THREAD_PARALLEL_ND_HPP - -/* This header must be included by mkldnn_thread.hpp only */ - -/* Functions: - * - parallel(nthr, f) - executes f in parallel using at most - * nthr threads. 
If nthr equals 0 - * mkldnn_get_max_threads() threads is - * used - * - for_nd(ithr, nthr, dims..., f) - multidimensional for loop for already - * created threads - * - parallel_nd(dims..., f) - creates a parallel section and then - * calls for_nd - * - parallel_nd_in_omp(dims..., f) - queries current nthr and ithr and then - * calls for_nd (mostly for convenience) - */ - -namespace mkldnn { -namespace impl { - -/* general parallelization */ -template -void parallel(int nthr, F f) { - if (nthr == 0) nthr = mkldnn_get_max_threads(); -#if MKLDNN_THR == MKLDNN_THR_SEQ - assert(nthr == 1); - f(0, 1); -#elif MKLDNN_THR == MKLDNN_THR_OMP - if (nthr == 1) { f(0, 1); return; } -# pragma omp parallel num_threads(nthr) - f(mkldnn_get_thread_num(), mkldnn_get_num_threads()); -#elif MKLDNN_THR == MKLDNN_THR_TBB - if (nthr == 1) { f(0, 1); return; } - tbb::parallel_for(0, nthr, [&](int ithr) { f(ithr, nthr); }, tbb::static_partitioner()); -#endif -} - -/* for_nd section */ - -template -void for_nd(const int ithr, const int nthr, const T0 &D0, F f) { - T0 start{0}, end{0}; - balance211(D0, nthr, ithr, start, end); - for (T0 d0 = start; d0 < end; ++d0) f(d0); -} - -template -void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1, F f) { - const size_t work_amount = (size_t)D0 * D1; - if (work_amount == 0) return; - size_t start{0}, end{0}; - balance211(work_amount, nthr, ithr, start, end); - - T0 d0{0}; T1 d1{0}; - utils::nd_iterator_init(start, d0, D0, d1, D1); - for (size_t iwork = start; iwork < end; ++iwork) { - f(d0, d1); - utils::nd_iterator_step(d0, D0, d1, D1); - } -} - -template -void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1, - const T2 &D2, F f) { - const size_t work_amount = (size_t)D0 * D1 * D2; - if (work_amount == 0) return; - size_t start{0}, end{0}; - balance211(work_amount, nthr, ithr, start, end); - - T0 d0{0}; T1 d1{0}; T2 d2{0}; - utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2); - for (size_t iwork = start; iwork < end; ++iwork) { - f(d0, d1, d2); - utils::nd_iterator_step(d0, D0, d1, D1, d2, D2); - } -} - -template -void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1, - const T2 &D2, const T3 &D3, F f) { - const size_t work_amount = (size_t)D0 * D1 * D2 * D3; - if (work_amount == 0) return; - size_t start{0}, end{0}; - balance211(work_amount, nthr, ithr, start, end); - - T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; - utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3); - for (size_t iwork = start; iwork < end; ++iwork) { - f(d0, d1, d2, d3); - utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3); - } -} - -template -void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1, - const T2 &D2, const T3 &D3, const T4 &D4, F f) { - const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4; - if (work_amount == 0) return; - size_t start{0}, end{0}; - balance211(work_amount, nthr, ithr, start, end); - - T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0}; - utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3, d4, D4); - for (size_t iwork = start; iwork < end; ++iwork) { - f(d0, d1, d2, d3, d4); - utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4); - } -} - -template -void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1, - const T2 &D2, const T3 &D3, const T4 &D4, const T5 &D5, F f) { - const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4 * D5; - if (work_amount == 0) return; - size_t start{0}, end{0}; - balance211(work_amount, nthr, ithr, start, end); - - T0 
d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0}; T5 d5{0}; - utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, - d5, D5); - for (size_t iwork = start; iwork < end; ++iwork) { - f(d0, d1, d2, d3, d4, d5); - utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, d5, D5); - } -} - -// Skip a lambda function in the parameter pack. -template -constexpr size_t get_work_amount(const T &v) { return 1; } -template -constexpr size_t get_work_amount(const T &v, Args &&...args) -{ return (size_t)v * get_work_amount(utils::forward(args)...); } - -/* parallel_nd and parallel_nd_in_omp section */ - -#if MKLDNN_THR != MKLDNN_THR_TBB -template -void parallel_nd(Args &&...args) { -#if MKLDNN_THR == MKLDNN_THR_SEQ - for_nd(0, 1, utils::forward(args)...); -#elif MKLDNN_THR == MKLDNN_THR_OMP - const bool do_parallel = get_work_amount(utils::forward(args)...) > 1; -# pragma omp parallel if (do_parallel) - { - const int nthr = !do_parallel ? 1 : mkldnn_get_num_threads(); - const int ithr = !do_parallel ? 0 : mkldnn_get_thread_num(); - for_nd(ithr, nthr, utils::forward(args)...); - } -#endif -} -#else // MKLDNN_THR != MKLDNN_THR_TBB - -// gcc 4.8 has a bug with passing parameter pack to lambdas. -// So have to explicitly instantiate all the cases. - -template -void parallel_nd(const T0 &D0, F f) { - const size_t work_amount = (size_t)D0; - if (work_amount == 0) return; - tbb::parallel_for(tbb::blocked_range(0, work_amount), [&](const tbb::blocked_range& r) { - for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) { - f(T0(iwork)); - } - }, tbb::static_partitioner()); -} - -template -void parallel_nd(const T0 &D0, const T1 &D1, F f) { - const size_t work_amount = (size_t)D0 * D1; - if (work_amount == 0) return; - tbb::parallel_for(tbb::blocked_range(0, work_amount), [&](const tbb::blocked_range& r) { - T0 d0{0}; T1 d1{0}; - utils::nd_iterator_init(r.begin(), d0, D0, d1, D1); - for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) { - f(d0, d1); - utils::nd_iterator_step(d0, D0, d1, D1); - } - }, tbb::static_partitioner()); -} - -template -void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, F f) { - const size_t work_amount = (size_t)D0 * D1 * D2; - if (work_amount == 0) return; - tbb::parallel_for(tbb::blocked_range(0, work_amount), [&](const tbb::blocked_range& r) { - T0 d0{0}; T1 d1{0}; T2 d2{0}; - utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2); - for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) { - f(d0, d1, d2); - utils::nd_iterator_step(d0, D0, d1, D1, d2, D2); - } - }, tbb::static_partitioner()); -} - -template -void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3, F f) { - const size_t work_amount = (size_t)D0 * D1 * D2 * D3; - if (work_amount == 0) return; - tbb::parallel_for(tbb::blocked_range(0, work_amount), [&](const tbb::blocked_range& r) { - T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; - utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2, d3, D3); - for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) { - f(d0, d1, d2, d3); - utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3); - } - }, tbb::static_partitioner()); -} - -template -void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3, - const T4 &D4, F f) { - const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4; - if (work_amount == 0) return; - tbb::parallel_for(tbb::blocked_range(0, work_amount), [&](const tbb::blocked_range& r) { - T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0}; - utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, 
d2, D2, d3, D3, d4, D4); - for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) { - f(d0, d1, d2, d3, d4); - utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4); - } - }, tbb::static_partitioner()); -} - -template -void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3, - const T4 &D4, const T5 &D5, F f) { - const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4 * D5; - if (work_amount == 0) return; - tbb::parallel_for(tbb::blocked_range(0, work_amount), [&](const tbb::blocked_range& r) { - T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0}; T5 d5{0}; - utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, - d5, D5); - for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) { - f(d0, d1, d2, d3, d4, d5); - utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, d5, D5); - } - }, tbb::static_partitioner()); -} -#endif - -template -void parallel_nd_in_omp(Args &&...args) { -#if MKLDNN_THR == MKLDNN_THR_SEQ - for_nd(0, 1, utils::forward(args)...); -#elif MKLDNN_THR == MKLDNN_THR_OMP - for_nd(mkldnn_get_thread_num(), mkldnn_get_num_threads(), - utils::forward(args)...); -#elif MKLDNN_THR == MKLDNN_THR_TBB - assert(!"unsupported parallel_nd_in_omp()"); -#endif -} - -} // namespace impl -} // namespace mkldnn - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_traits.hpp b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_traits.hpp deleted file mode 100644 index aa671a0b6..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_traits.hpp +++ /dev/null @@ -1,77 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
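The parallel_nd() family removed above always follows the same recipe: multiply the dimensions into one flat work amount, hand each thread a contiguous slice via balance211() (chunk sizes differ by at most one), then turn each flat index back into per-dimension coordinates with the nd_iterator helpers. Here is a small self-contained sketch of that recipe, with a simplified balance() and plain division/modulo standing in for the library's utilities.

#include <cstdio>

// Same contract as balance211(): split n items over `team` workers so that
// consecutive workers get chunks whose sizes differ by at most one.
static void balance(int n, int team, int tid, int &start, int &end) {
    int n1 = (n + team - 1) / team; // ceil(n / team)
    int n2 = n1 - 1;
    int T1 = n - n2 * team;         // the first T1 workers get n1 items each
    int my = tid < T1 ? n1 : n2;
    start = tid <= T1 ? tid * n1 : T1 * n1 + (tid - T1) * n2;
    end = start + my;
}

int main() {
    const int N = 3, C = 4, team = 2; // 12 units of work split over 2 "threads"
    for (int tid = 0; tid < team; ++tid) {
        int start, end;
        balance(N * C, team, tid, start, end);
        for (int iwork = start; iwork < end; ++iwork) {
            int n = iwork / C, c = iwork % C; // nd_iterator_init/step equivalent
            std::printf("tid %d -> (n=%d, c=%d)\n", tid, n, c);
        }
    }
    return 0;
}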
-*******************************************************************************/ - -#ifndef MKLDNN_TRAITS_HPP -#define MKLDNN_TRAITS_HPP - -#include <assert.h> -#include <stdint.h> - -#include "mkldnn.h" -#include "c_types_map.hpp" -#include "nstl.hpp" -#include "utils.hpp" -#include "z_magic.hpp" - -namespace mkldnn { -namespace impl { - -template <data_type_t> struct prec_traits {}; /* ::type -> float */ -template <typename> struct data_traits {}; /* ::data_type -> f32 */ -template <int> struct typesize_traits {}; /* ::data_type_size -> f32 */ -template <primitive_kind_t> struct pkind_traits {}; /* ::desc_type, ::query_d */ - -template <> struct prec_traits<data_type::f32> { typedef float type; }; -template <> struct prec_traits<data_type::s32> { typedef int32_t type; }; -template <> struct prec_traits<data_type::s8> { typedef int8_t type; }; -template <> struct prec_traits<data_type::u8> { typedef uint8_t type; }; - -template <> struct data_traits<float> -{ static constexpr data_type_t data_type = data_type::f32; }; -template <> struct data_traits<int32_t> -{ static constexpr data_type_t data_type = data_type::s32; }; -template <> struct data_traits<int8_t> -{ static constexpr data_type_t data_type = data_type::s8; }; -template <> struct data_traits<uint8_t> -{ static constexpr data_type_t data_type = data_type::u8; }; - -template <> struct typesize_traits<4> { typedef float type; }; -template <> struct typesize_traits<2> { typedef int16_t type; }; -template <> struct typesize_traits<1> { typedef uint8_t type; }; - -#define PKIND_TRAITS_INST(op) \ -template <> struct pkind_traits<primitive_kind::op> { \ - typedef CONCAT2(op, _desc_t) desc_type; \ - static constexpr query_t query_d = query::CONCAT2(op, _d); \ -} -PKIND_TRAITS_INST(convolution); -PKIND_TRAITS_INST(deconvolution); -PKIND_TRAITS_INST(shuffle); -PKIND_TRAITS_INST(eltwise); -PKIND_TRAITS_INST(softmax); -PKIND_TRAITS_INST(pooling); -PKIND_TRAITS_INST(lrn); -PKIND_TRAITS_INST(batch_normalization); -PKIND_TRAITS_INST(inner_product); -PKIND_TRAITS_INST(rnn); -#undef PKIND_TRAITS_INST - -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/nstl.hpp b/thirdparty/oidn/mkl-dnn/src/common/nstl.hpp deleted file mode 100644 index f89ea999e..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/nstl.hpp +++ /dev/null @@ -1,193 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License.
-*******************************************************************************/ - -#ifndef NSTL_HPP -#define NSTL_HPP - -#include <stdint.h> -#include <limits.h> -#include <float.h> - -#include <vector> -#include <map> - -#include "z_magic.hpp" - -namespace mkldnn { -namespace impl { - -void *malloc(size_t size, int alignment); -void free(void *p); - -struct c_compatible { - enum { default_alignment = 64 }; - static void *operator new(size_t sz) { - return malloc(sz, default_alignment); - } - static void *operator new(size_t sz, void *p) { UNUSED(sz); return p; } - static void *operator new[](size_t sz) { - return malloc(sz, default_alignment); - } - static void operator delete(void *p) { free(p); } - static void operator delete[](void *p) { free(p); } -}; - -namespace nstl { - -template <typename T> -inline const T abs(const T& a) { - return a >= 0 ? a : -a; -} - -template <typename T> -inline const T& max(const T& a, const T& b) { - return a > b ? a : b; -} - -template <typename T> -inline const T& min(const T& a, const T& b) { - return a < b ? a : b; -} - -template <typename T> void swap(T& t1, T& t2) { - T tmp(t1); - t1 = t2; - t2 = tmp; -} - -// Rationale: MKL-DNN needs a numeric limits implementation that does not -// generate dependencies on C++ run-time libraries. - -template <typename T> struct numeric_limits; - -template<> struct numeric_limits<float> { - static constexpr float lowest() { return -FLT_MAX; } - static constexpr float max() { return FLT_MAX; } -}; - -template<> struct numeric_limits<int32_t> { - static constexpr int lowest() { return INT32_MIN; } - static constexpr int max() { return INT32_MAX; } -}; - -template<> struct numeric_limits<int16_t> { - static constexpr int16_t lowest() { return INT16_MIN; } - static constexpr int16_t max() { return INT16_MAX; } -}; - -template<> struct numeric_limits<int8_t> { - static constexpr int8_t lowest() { return INT8_MIN; } - static constexpr int8_t max() { return INT8_MAX; } -}; - -template<> struct numeric_limits<uint8_t> { - static constexpr uint8_t lowest() { return 0; } - static constexpr uint8_t max() { return UINT8_MAX; } -}; - -template <typename T> struct is_integral -{ static constexpr bool value = false; }; -template<> struct is_integral<int32_t> { static constexpr bool value = true; }; -template<> struct is_integral<int16_t> { static constexpr bool value = true; }; -template<> struct is_integral<int8_t> { static constexpr bool value = true; }; -template<> struct is_integral<uint8_t> { static constexpr bool value = true; }; - -template <typename T, typename U> struct is_same -{ static constexpr bool value = false; }; -template <typename T> struct is_same<T, T> -{ static constexpr bool value = true; }; - -// Rationale: MKL-DNN needs container implementations that do not generate -// dependencies on C++ run-time libraries. -// -// Implementation philosophy: caller is responsible to check if the operation -// is valid. The only functions that have to return status are those that -// depend on memory allocation or similar operations. -// -// This means that e.g. an operator [] does not have to check for boundaries. -// The caller should have checked the boundaries. If it did not we crash and -// burn: this is a bug in MKL-DNN and throwing an exception would not have been -// recoverable. -// -// On the other hand, insert() or resize() or a similar operation needs to -// return a status because the outcome depends on factors external to the -// caller. The situation is probably not recoverable either, but MKL-DNN -// needs to be nice and report "out of memory" to the users.
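To make the calling convention described in the comment above concrete, here is a tiny sketch of a container that follows it: element access is unchecked because bounds are the caller's responsibility, and only the allocating operation reports failure. The tiny_vector type and its status-returning append() are hypothetical illustrations, not the nstl::vector interface that follows below.

#include <cstdio>
#include <vector>

enum toy_status_t { toy_success = 0, toy_out_of_memory };

template <typename T>
struct tiny_vector { // hypothetical stand-in, not nstl::vector itself
    std::vector<T> impl;

    // Unchecked by design: passing a bad index is a caller bug.
    T &operator[](std::size_t i) { return impl[i]; }

    // The only thing that can legitimately fail here is allocation.
    toy_status_t append(const T &t) {
        try { impl.push_back(t); } catch (...) { return toy_out_of_memory; }
        return toy_success;
    }
    std::size_t size() const { return impl.size(); }
};

int main() {
    tiny_vector<int> v;
    if (v.append(42) != toy_success) return 1; // check status where allocation happens
    std::printf("%d\n", v[0]);                 // element access is not guarded
    return 0;
}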
- -enum nstl_status_t { - success = 0, - out_of_memory -}; - -template class vector: public c_compatible { -private: - std::vector _impl; -public: - typedef typename std::vector::iterator iterator; - typedef typename std::vector::const_iterator const_iterator; - typedef typename std::vector::size_type size_type; - vector() {} - vector(size_type n): _impl(n) {} - vector(size_type n, const T &value): _impl(n, value) {} - template - vector(input_iterator first, input_iterator last): _impl(first, last) {} - ~vector() {} - size_type size() const { return _impl.size(); } - T& operator[] (size_type i) { return _impl[i]; } - const T& operator[] (size_type i) const { return _impl[i]; } - iterator begin() { return _impl.begin(); } - const_iterator begin() const { return _impl.begin(); } - iterator end() { return _impl.end(); } - const_iterator end() const { return _impl.end(); } - template - nstl_status_t insert(iterator pos, input_iterator begin, input_iterator end) - { - _impl.insert(pos, begin, end); - return success; - } - void clear() { _impl.clear(); } - void push_back(const T& t) { _impl.push_back(t); } - void resize(size_type count) { _impl.resize(count); } - void reserve(size_type count) { _impl.reserve(count); } -}; - -template class map: public c_compatible { -private: - std::map _impl; -public: - typedef typename std::map::iterator iterator; - typedef typename std::map::const_iterator const_iterator; - typedef typename std::map::size_type size_type; - map() {} - ~map() {} - size_type size() const { return _impl.size(); } - T& operator[](const Key &k) { return _impl[k]; } - const T& operator[](const Key &k) const { return _impl[k]; } - iterator begin() { return _impl.begin(); } - const_iterator begin() const { return _impl.begin(); } - iterator end() { return _impl.end(); } - const_iterator end() const { return _impl.end(); } - template - void clear() { _impl.clear(); } -}; - -} -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/pooling.cpp b/thirdparty/oidn/mkl-dnn/src/common/pooling.cpp deleted file mode 100644 index be96e654f..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/pooling.cpp +++ /dev/null @@ -1,114 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#include -#include "mkldnn.h" - -#include "c_types_map.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -using namespace mkldnn::impl; -using namespace mkldnn::impl::utils; -using namespace mkldnn::impl::status; -using namespace mkldnn::impl::prop_kind; -using namespace mkldnn::impl::alg_kind; -using namespace mkldnn::impl::types; - -namespace { -status_t pooling_desc_init(pooling_desc_t *pool_desc, - prop_kind_t prop_kind, alg_kind_t alg_kind, - const memory_desc_t *src_desc, const memory_desc_t *dst_desc, - const dims_t strides, const dims_t kernel, const dims_t padding_l, - const dims_t padding_r, padding_kind_t padding_kind) { - bool args_ok = true - && !any_null(pool_desc, src_desc, dst_desc, strides, kernel, padding_l) - && one_of(alg_kind, pooling_max, - pooling_avg_include_padding, - pooling_avg_exclude_padding) - && one_of(padding_kind, padding_kind::padding_zero); - if (!args_ok) return invalid_arguments; - - if (padding_r == nullptr) padding_r = padding_l; - - auto pd = pooling_desc_t(); - pd.primitive_kind = primitive_kind::pooling; - pd.prop_kind = prop_kind; - pd.alg_kind = alg_kind; - pd.src_desc.ndims = src_desc->ndims; - - const bool is_fwd = one_of(prop_kind, forward_training, forward_inference); - - pd.diff_src_desc = pd.src_desc = zero_md(); - pd.diff_dst_desc = pd.dst_desc = zero_md(); - - (is_fwd ? pd.src_desc : pd.diff_src_desc) = *src_desc; - (is_fwd ? pd.dst_desc : pd.diff_dst_desc) = *dst_desc; - - int sp_dims = src_desc->ndims - 2; - utils::array_copy(pd.strides, strides, sp_dims); - utils::array_copy(pd.kernel, kernel, sp_dims); - utils::array_copy(pd.padding[0], padding_l, sp_dims); - utils::array_copy(pd.padding[1], padding_r, sp_dims); - - pd.padding_kind = padding_kind; - if (one_of(alg_kind, pooling_max, pooling_avg_include_padding, - pooling_avg_exclude_padding)) { - pd.accum_data_type = types::default_accum_data_type( - src_desc->data_type, dst_desc->data_type); - } else { - pd.accum_data_type = dst_desc->data_type; - } - - bool consistency = true - && utils::one_of(src_desc->ndims, 4, 5) - && utils::one_of(dst_desc->ndims, 4, 5) - && src_desc->dims[0] == dst_desc->dims[0] - && src_desc->dims[1] == dst_desc->dims[1]; - for (int i = 2; i < src_desc->ndims; ++i) - consistency = consistency && ( - (src_desc->dims[i] - kernel[i - 2] + padding_l[i - 2] - + padding_r[i - 2]) / strides[i - 2] + 1 - == dst_desc->dims[i]); - if (!consistency) return invalid_arguments; - - *pool_desc = pd; - return success; -} -} - -status_t mkldnn_pooling_forward_desc_init(pooling_desc_t *pool_desc, - prop_kind_t prop_kind, alg_kind_t alg_kind, - const memory_desc_t *src_desc, const memory_desc_t *dst_desc, - const dims_t strides, const dims_t kernel, const dims_t padding_l, - const dims_t padding_r, padding_kind_t padding_kind) { - if (!one_of(prop_kind, forward_training, forward_inference)) - return invalid_arguments; - return pooling_desc_init(pool_desc, prop_kind, alg_kind, src_desc, - dst_desc, strides, kernel, padding_l, padding_r, padding_kind); -} - -status_t mkldnn_pooling_backward_desc_init(pooling_desc_t *pool_desc, - alg_kind_t alg_kind, const memory_desc_t *diff_src_desc, - const memory_desc_t *diff_dst_desc, const dims_t strides, - const dims_t kernel, const dims_t padding_l, const dims_t padding_r, - padding_kind_t padding_kind) { - return pooling_desc_init(pool_desc, prop_kind::backward_data, alg_kind, - diff_src_desc, diff_dst_desc, strides, kernel, padding_l, - 
padding_r, padding_kind); -} - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/pooling_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/pooling_pd.hpp deleted file mode 100644 index 4c9c00941..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/pooling_pd.hpp +++ /dev/null @@ -1,238 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef POOLING_PD_HPP -#define POOLING_PD_HPP - -#include "mkldnn.h" - -#include "c_types_map.hpp" -#include "primitive_desc.hpp" -#include "type_helpers.hpp" - -namespace mkldnn { -namespace impl { - -struct pooling_fwd_pd_t; - -struct pooling_pd_t: public primitive_desc_t { - static constexpr auto base_pkind = primitive_kind::pooling; - - pooling_pd_t(engine_t *engine, - const pooling_desc_t *adesc, - const primitive_attr_t *attr, - const pooling_fwd_pd_t *hint_fwd_pd) - : primitive_desc_t(engine, attr, base_pkind) - , desc_(*adesc) - , hint_fwd_pd_(hint_fwd_pd) - , ws_md_() - {} - - const pooling_desc_t *desc() const { return &desc_; } - virtual const op_desc_t *op_desc() const override - { return reinterpret_cast(this->desc()); } - virtual void init_info() override { impl::init_info(this, this->info_); } - - virtual status_t query(query_t what, int idx, void *result) const override { - switch (what) { - case query::pooling_d: - *(const pooling_desc_t**)result = desc(); break; - default: return primitive_desc_t::query(what, idx, result); - } - return status::success; - } - - /* common pooling aux functions */ - - dim_t MB() const { return src_desc().dims[0]; } - dim_t C() const { return src_desc().dims[1]; } - - dim_t ID() const { return ndims() >= 5 ? src_desc().dims[ndims() - 3] : 1; } - dim_t IH() const { return ndims() >= 4 ? src_desc().dims[ndims() - 2] : 1; } - dim_t IW() const { return src_desc().dims[ndims() - 1]; } - - dim_t OD() const { return ndims() >= 5 ? dst_desc().dims[ndims() - 3] : 1; } - dim_t OH() const { return ndims() >= 4 ? dst_desc().dims[ndims() - 2] : 1; } - dim_t OW() const { return dst_desc().dims[ndims() - 1]; } - - dim_t KD() const { return ndims() >= 5 ? desc_.kernel[ndims() - 5] : 1; } - dim_t KH() const { return ndims() >= 4 ? desc_.kernel[ndims() - 4] : 1; } - dim_t KW() const { return desc_.kernel[ndims() - 3]; } - - dim_t KSD() const { return ndims() >= 5 ? desc_.strides[ndims() - 5] : 1; } - dim_t KSH() const { return ndims() >= 4 ? desc_.strides[ndims() - 4] : 1; } - dim_t KSW() const { return desc_.strides[ndims() - 3]; } - - dim_t padFront() const - { return ndims() >= 5 ? desc_.padding[0][ndims() - 5] : 0; } - dim_t padBack() const - { return ndims() >= 5 ? desc_.padding[1][ndims() - 5] : 0; } - dim_t padT() const - { return ndims() >= 4 ? desc_.padding[0][ndims() - 4] : 0; } - dim_t padB() const - { return ndims() >= 4 ? 
desc_.padding[1][ndims() - 4] : 0; } - dim_t padL() const { return desc_.padding[0][ndims() - 3]; } - dim_t padR() const { return desc_.padding[1][ndims() - 3]; } - - int ndims() const { return src_desc().ndims; } - bool is_3d() const { return ndims() == 5; } - - bool has_zero_dim_memory() const - { return memory_desc_wrapper(src_desc()).has_zero_dim(); } - - bool is_fwd() const { - return utils::one_of(desc_.prop_kind, prop_kind::forward_training, - prop_kind::forward_inference); - } - -protected: - pooling_desc_t desc_; - const pooling_fwd_pd_t *hint_fwd_pd_; - - memory_desc_t ws_md_; - - void init_default_ws() { - ws_md_ = is_fwd() ? *dst_md() : *diff_dst_md(); - ws_md_.data_type = indices_data_type(); - } - - data_type_t indices_data_type() const { - /* the simplest way to express 256... */ - const int u8_max = nstl::numeric_limits< - typename prec_traits::type>::max(); - return utils::array_product(desc()->kernel, ndims()) <= u8_max - ? data_type::u8 : data_type::s32; - } - -private: - const memory_desc_t &src_desc() const - { return is_fwd() ? desc_.src_desc : desc_.diff_src_desc; } - const memory_desc_t &dst_desc() const - { return is_fwd() ? desc_.dst_desc : desc_.diff_dst_desc; } -}; - -struct pooling_fwd_pd_t: public pooling_pd_t { - typedef pooling_fwd_pd_t base_class; - typedef pooling_fwd_pd_t hint_class; - - pooling_fwd_pd_t(engine_t *engine, - const pooling_desc_t *adesc, - const primitive_attr_t *attr, - const pooling_fwd_pd_t *hint_fwd_pd) - : pooling_pd_t(engine, adesc, attr, hint_fwd_pd) - , src_md_(desc_.src_desc) - , dst_md_(desc_.dst_desc) - {} - - virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { - if (arg == MKLDNN_ARG_SRC) - return arg_usage_t::input; - - if (arg == MKLDNN_ARG_DST) - return arg_usage_t::output; - - if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr)) - return arg_usage_t::output; - - return primitive_desc_t::arg_usage(arg); - } - - virtual const memory_desc_t *src_md(int index = 0) const override - { return index == 0 ? &src_md_ : nullptr; } - virtual const memory_desc_t *dst_md(int index = 0) const override - { return index == 0 ? &dst_md_ : nullptr; } - virtual const memory_desc_t *workspace_md(int index = 0) const override - { return index == 0 && !types::is_zero_md(&ws_md_) ? 
&ws_md_ : nullptr; } - - virtual int n_inputs() const override { return 1; } - virtual int n_outputs() const override - { return 1 + (workspace_md() != nullptr); } - -protected: - memory_desc_t src_md_; - memory_desc_t dst_md_; - - virtual status_t set_default_params() { - if (dst_md()->format_kind != format_kind::any) - return status::success; - - if (src_md()->format_kind != format_kind::blocked) - return status::unimplemented; - - return memory_desc_init_by_blocking_desc(dst_md_, - src_md_.format_desc.blocking); - } -}; - -struct pooling_bwd_pd_t: public pooling_pd_t { - typedef pooling_bwd_pd_t base_class; - typedef pooling_fwd_pd_t hint_class; - - pooling_bwd_pd_t(engine_t *engine, - const pooling_desc_t *adesc, - const primitive_attr_t *attr, - const pooling_fwd_pd_t *hint_fwd_pd) - : pooling_pd_t(engine, adesc, attr, hint_fwd_pd) - , diff_src_md_(desc_.diff_src_desc) - , diff_dst_md_(desc_.diff_dst_desc) - {} - - virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { - if (arg == MKLDNN_ARG_DIFF_DST) - return arg_usage_t::input; - - if (arg == MKLDNN_ARG_DIFF_SRC) - return arg_usage_t::output; - - if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr)) - return arg_usage_t::input; - - return primitive_desc_t::arg_usage(arg); - } - - virtual const memory_desc_t *diff_src_md(int index = 0) const override - { return index == 0 ? &diff_src_md_ : nullptr; } - virtual const memory_desc_t *diff_dst_md(int index = 0) const override - { return index == 0 ? &diff_dst_md_ : nullptr; } - virtual const memory_desc_t *workspace_md(int index = 0) const override - { return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; } - - virtual int n_inputs() const override - { return 1 + (workspace_md() != nullptr); } - virtual int n_outputs() const override { return 1; } - -protected: - memory_desc_t diff_src_md_; - memory_desc_t diff_dst_md_; - - virtual status_t set_default_params() { - if (diff_src_md()->format_kind != format_kind::any) - return status::success; - - if (diff_dst_md()->format_kind != format_kind::blocked) - return status::unimplemented; - - return memory_desc_init_by_blocking_desc(diff_src_md_, - diff_dst_md_.format_desc.blocking); - } -}; - -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive.cpp b/thirdparty/oidn/mkl-dnn/src/common/primitive.cpp deleted file mode 100644 index fdf6522f6..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/primitive.cpp +++ /dev/null @@ -1,103 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#include - -#include "c_types_map.hpp" -#include "engine.hpp" -#include "primitive_desc.hpp" -#include "primitive.hpp" -#include "type_helpers.hpp" -#include "stream.hpp" -#include "utils.hpp" - -using namespace mkldnn::impl; -using namespace mkldnn::impl::status; -using namespace mkldnn::impl::primitive_kind; - -namespace { -// XXX: this is a huge hammer. This disables all and any msan checks on -// primitives outputs. -// -// A proper approach would be an implementation-specific unpoisoning. -void unpoison_outputs(const exec_args_t &args) { - for(const auto &arg: args) { - if (arg.second.is_const) continue; - auto *mem = arg.second.mem; - void *p; - mem->get_data_handle(&p); - size_t s = memory_desc_wrapper(*mem->md()).size(); - msan_unpoison(p, s); - } -} -} - -status_t mkldnn_primitive_desc_destroy(primitive_desc_t *primitive_desc) { - if (primitive_desc) delete primitive_desc; - return success; -} - -status_t mkldnn_primitive_create(primitive_t **primitive, - const primitive_desc_t *primitive_desc) { - if (utils::any_null(primitive, primitive_desc)) - return invalid_arguments; - return primitive_desc->create_primitive(primitive); -} - -status_t mkldnn_primitive_execute(const primitive_t *primitive, - stream_t *stream, int nargs, const mkldnn_exec_arg_t *c_args) { - bool ok = true - && !utils::any_null(primitive, stream) - && primitive->engine() == stream->engine() - && IMPLICATION(nargs > 0, c_args != nullptr); - if (!ok) return invalid_arguments; - - exec_args_t args; - status_t status = cvt_primtive_args(primitive->pd(), nargs, c_args, args); - if (status != status::success) return status; - - exec_ctx_t ctx(stream, std::move(args)); - - if (mkldnn_verbose()->level) { - double ms = get_msec(); - status = primitive->execute(ctx); - ms = get_msec() - ms; - printf("mkldnn_verbose,exec,%s,%g\n", primitive->pd()->info(), ms); - fflush(0); - } else { - status = primitive->execute(ctx); - } - - if (msan_enabled) unpoison_outputs(ctx.args()); - - return status; -} - -status_t mkldnn_primitive_get_primitive_desc(const primitive_t *primitive, - const primitive_desc_t **primitive_desc) { - if (utils::any_null(primitive, primitive_desc)) - return invalid_arguments; - return safe_ptr_assign(*primitive_desc, - primitive->pd()); -} - -status_t mkldnn_primitive_destroy(primitive_t *primitive) { - if (primitive != nullptr) - delete primitive; - return success; -} - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive.hpp b/thirdparty/oidn/mkl-dnn/src/common/primitive.hpp deleted file mode 100644 index 3b506d6d1..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/primitive.hpp +++ /dev/null @@ -1,76 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#ifndef PRIMITIVE_HPP -#define PRIMITIVE_HPP - -#include - -#include "mkldnn.h" - -#include "c_types_map.hpp" -#include "nstl.hpp" -#include "primitive_desc.hpp" -#include "primitive_exec_types.hpp" - -/** \brief A pure virtual primitive class - * - * Primitive contains links to its inputs & outputs, though it does not track - * their readiness on execution step. - * - * @remark @b Rational. - * Dependencies are essential through-out the whole MKL-DNN library, so it - * makes sense to include them on the very low level. On the other hand, - * tracking them should be a task for corresponding essence, like scheduler, - * stream or whatever. Primitive itself should know nothing about the - * environment it is running in. - * - * @note - * To make user experience better we should provide API which allows - * achieving the best (or good enough) performance when creating primitives - * in natural order: i.e. from bottom to top for forward pass and from top to - * bottom for backward pass. Please consider restriction [1] in Level 0. - */ -struct mkldnn_primitive: public mkldnn::impl::c_compatible { - mkldnn_primitive(const mkldnn::impl::primitive_desc_t *pd) - : pd_(pd->clone()) {} - virtual ~mkldnn_primitive() { delete pd_; } - - /** returns primitive's engine */ - mkldnn::impl::engine_t *engine() const { return pd_->engine(); } - /** returns primitive's inputs */ - const mkldnn::impl::primitive_desc_t *pd() const { return pd_; } - /** returns primitive's kind */ - mkldnn::impl::primitive_kind_t kind() const { return pd_->kind(); } - - /** executes primitive with execution context @p ctx */ - virtual mkldnn::impl::status_t execute(const mkldnn::impl::exec_ctx_t &ctx) - const = 0; - -protected: - const mkldnn::impl::primitive_desc_t *pd_; - -private: - mkldnn_primitive() = delete; - mkldnn_primitive(const mkldnn_primitive &) = delete; - mkldnn_primitive(mkldnn_primitive &&) = delete; - mkldnn_primitive &operator=(const mkldnn_primitive &) = delete; - mkldnn_primitive &operator=(mkldnn_primitive &&) = delete; -}; - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.cpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.cpp deleted file mode 100644 index 9fd638842..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.cpp +++ /dev/null @@ -1,290 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#include "mkldnn.h" - -#include "c_types_map.hpp" -#include "primitive_attr.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -using namespace mkldnn::impl; -using namespace mkldnn::impl::status; -using namespace mkldnn::impl::utils; - -namespace mkldnn { -namespace impl { - -status_t scales_t::set(dim_t count, int mask, const float *scales) { - cleanup(); - - count_ = count; - mask_ = mask; - - if (count_ == 1) { - scales_ = scales_buf_; - utils::array_set(scales_, scales[0], scales_buf_size); - } else { - scales_ = (float *)impl::malloc(count_ * sizeof(*scales_), 64); - if (scales_ == nullptr) - return status::out_of_memory; - - for (dim_t c = 0; c < count_; ++c) - scales_[c] = scales[c]; - } - - return status::success; -} - -} -} - -status_t post_ops_t::append_sum(float scale) { - if (len_ == capacity) - return out_of_memory; - - entry_[len_].kind = primitive_kind::sum; - entry_[len_].sum.scale = scale; - - len_++; - - return success; -} - -status_t post_ops_t::append_eltwise(float scale, alg_kind_t alg, float alpha, - float beta) { - using namespace mkldnn::impl::alg_kind; - bool known_alg = one_of(alg, eltwise_relu, eltwise_tanh, eltwise_elu, - eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear, - eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic); - if (!known_alg) - return invalid_arguments; - - if (len_ == capacity) - return out_of_memory; - - entry_[len_].kind = primitive_kind::eltwise; - entry_[len_].eltwise.scale = scale; - entry_[len_].eltwise.alg = alg; - entry_[len_].eltwise.alpha = alpha; - entry_[len_].eltwise.beta = beta; - - len_++; - - return success; -} - -status_t primitive_attr_t::set_scratchpad_mode( - scratchpad_mode_t scratchpad_mode) { - using namespace mkldnn::impl::scratchpad_mode; - - const bool ok = one_of(scratchpad_mode, library, user); - if (!ok) - return invalid_arguments; - - scratchpad_mode_ = scratchpad_mode; - return success; -} - -status_t primitive_attr_t::set_post_ops(const post_ops_t &post_ops) { - this->post_ops_ = post_ops; - return success; -} - -/* Public C API */ - -status_t mkldnn_primitive_attr_create(primitive_attr_t **attr) { - if (attr == nullptr) - return invalid_arguments; - - return safe_ptr_assign(*attr, - new mkldnn_primitive_attr); -} - -status_t mkldnn_primitive_attr_clone(primitive_attr_t **attr, - const primitive_attr_t *existing_attr) { - if (any_null(attr, existing_attr)) - return invalid_arguments; - - return safe_ptr_assign(*attr, - existing_attr->clone()); -} - -status_t mkldnn_primitive_attr_destroy(primitive_attr_t *attr) { - if (attr) - delete attr; - - return success; -} - -status_t mkldnn_primitive_attr_get_scratchpad_mode( - const primitive_attr_t *attr, scratchpad_mode_t *scratchpad_mode) { - if (any_null(attr, scratchpad_mode)) - return invalid_arguments; - - *scratchpad_mode = attr->scratchpad_mode_; - - return success; -} - -status_t mkldnn_primitive_attr_set_scratchpad_mode( - primitive_attr_t *attr, scratchpad_mode_t scratchpad_mode) { - if (any_null(attr)) - return invalid_arguments; - - return attr->set_scratchpad_mode(scratchpad_mode); -} - -status_t mkldnn_primitive_attr_get_output_scales(const primitive_attr_t *attr, - dim_t *count, int *mask, const float **scales) { - if (any_null(attr, count, mask, scales)) - return invalid_arguments; - - *count = attr->output_scales_.count_; - *mask = attr->output_scales_.mask_; - *scales = attr->output_scales_.scales_; - - return success; -} - -status_t 
mkldnn_primitive_attr_set_output_scales(primitive_attr_t *attr, - dim_t count, int mask, const float *scales) { - bool ok = !any_null(attr, scales) && count > 0 && mask >= 0; - if (!ok) - return invalid_arguments; - - return attr->output_scales_.set(count, mask, scales); -} - -status_t mkldnn_primitive_attr_get_post_ops(const primitive_attr_t *attr, - const post_ops_t **post_ops) { - if (any_null(attr, post_ops)) - return invalid_arguments; - - *post_ops = &attr->post_ops_; - return success; -} - -status_t mkldnn_primitive_attr_set_post_ops(primitive_attr_t *attr, - const post_ops_t *post_ops) { - if (any_null(attr, post_ops)) - return invalid_arguments; - - return attr->set_post_ops(*post_ops); -} - -status_t mkldnn_post_ops_create(post_ops_t **post_ops) { - if (post_ops == nullptr) - return invalid_arguments; - - return safe_ptr_assign(*post_ops, new mkldnn_post_ops); -} - -status_t mkldnn_post_ops_destroy(post_ops_t *post_ops) { - if (post_ops) - delete post_ops; - - return success; -} - -int mkldnn_post_ops_len(const post_ops_t *post_ops) { - if (post_ops) - return post_ops->len_; - - return 0; -} - -primitive_kind_t mkldnn_post_ops_get_kind(const post_ops_t *post_ops, - int index) { - bool ok = post_ops && 0 <= index && index < post_ops->len_; - if (!ok) - return primitive_kind::undefined; - - return post_ops->entry_[index].kind; -} - -status_t mkldnn_post_ops_append_sum(post_ops_t *post_ops, float scale) { - if (post_ops == nullptr) - return invalid_arguments; - - return post_ops->append_sum(scale); -} - -namespace { -bool simple_get_params_check(const post_ops_t *post_ops, int index, - primitive_kind_t kind) { - bool ok = true - && post_ops != nullptr - && 0 <= index - && index < post_ops->len_ - && post_ops->entry_[index].kind == kind; - return ok; -} -} - -status_t mkldnn_post_ops_get_params_sum(const post_ops_t *post_ops, int index, - float *scale) { - bool ok = true - && simple_get_params_check(post_ops, index, primitive_kind::sum) - && !any_null(scale); - if (!ok) - return invalid_arguments; - - *scale = post_ops->entry_[index].sum.scale; - return success; -} - -status_t mkldnn_post_ops_append_eltwise(post_ops_t *post_ops, float scale, - alg_kind_t kind, float alpha, float beta) { - if (post_ops == nullptr) - return invalid_arguments; - - return post_ops->append_eltwise(scale, kind, alpha, beta); -} - -status_t mkldnn_post_ops_get_params_eltwise(const post_ops_t *post_ops, - int index, float *scale, alg_kind_t *alg, float *alpha, float *beta) { - bool ok = true - && simple_get_params_check(post_ops, index, primitive_kind::eltwise) - && !any_null(scale, alpha, beta); - if (!ok) - return invalid_arguments; - - const auto &e = post_ops->entry_[index].eltwise; - *scale = e.scale; - *alg = e.alg; - *alpha = e.alpha; - *beta = e.beta; - - return success; -} - -status_t mkldnn_primitive_attr_set_rnn_data_qparams( - primitive_attr_t *attr, const float scale, const float shift) { - if (attr == nullptr) - return invalid_arguments; - - return attr->rnn_data_qparams_.set(scale, shift); -} - -status_t mkldnn_primitive_attr_set_rnn_weights_qparams( - primitive_attr_t *attr, dim_t count, int mask, const float *scales) { - bool ok = !any_null(attr, scales) && count > 0 && mask >= 0; - if (!ok) - return invalid_arguments; - - return attr->rnn_weights_qparams_.set(count, mask, scales); -} diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.hpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.hpp deleted file mode 100644 index e2130c7ab..000000000 --- 
a/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.hpp +++ /dev/null @@ -1,183 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef PRIMITIVE_ATTR_HPP -#define PRIMITIVE_ATTR_HPP - -#include "mkldnn.h" - -#include "c_types_map.hpp" -#include "nstl.hpp" -#include "utils.hpp" - -namespace mkldnn { -namespace impl { - -struct rnn_data_qparams_t : public c_compatible { - rnn_data_qparams_t() : scale_(1.), shift_(0.) {} - bool has_default_values() const { return (scale_ == 1. && shift_ == 0.); } - - status_t set(float scale, float shift) { - scale_ = scale; - shift_ = shift; - return status::success; - } - - float scale_; - float shift_; -}; - -struct scales_t: public c_compatible { - scales_t(): count_(1), mask_(0), scales_(scales_buf_) - { set(1.); } - - scales_t(const scales_t &rhs): scales_t() - { set(rhs.count_, rhs.mask_, rhs.scales_); } - - ~scales_t() { cleanup(); } - - scales_t &operator=(const scales_t &rhs) { - if (&rhs == this) - return *this; - status_t status = set(rhs.count_, rhs.mask_, rhs.scales_); - assert(status == status::success); - (void)status; - return *this; - } - - bool has_default_values() const { - for (dim_t c = 0; c < count_; ++c) { - if(scales_[c] != 1.) 
return false; - } - return true; - } - - status_t set(dim_t count, int mask, const float *scales); - status_t set(float single_scale) { return this->set(1, 0, &single_scale); } - - dim_t count_; - int mask_; - float *scales_; - -private: - enum { scales_buf_size = 16 }; - float scales_buf_[scales_buf_size]; - - void cleanup() { - if (scales_ != scales_buf_ && scales_ != nullptr) - impl::free(scales_); - - count_ = 1; - mask_ = 0; - scales_ = scales_buf_; - } -}; - -} -} - -struct mkldnn_post_ops: public mkldnn::impl::c_compatible { - struct entry_t { - struct eltwise_t { - mkldnn::impl::alg_kind_t alg; - float scale, alpha, beta; - }; - - mkldnn::impl::primitive_kind_t kind; - union { - struct { float scale; } sum; - eltwise_t eltwise; - }; - - bool is_eltwise(bool require_scale_one = true) const { - using namespace mkldnn::impl; - return kind == primitive_kind::eltwise - && IMPLICATION(require_scale_one, eltwise.scale == 1.f); - } - - bool is_relu(bool require_scale_one = true, - bool require_nslope_zero = true) const { - using namespace mkldnn::impl; - return is_eltwise(require_scale_one) - && eltwise.alg == alg_kind::eltwise_relu - && IMPLICATION(require_nslope_zero, eltwise.alpha == 0.f); - } - - bool is_sum(bool require_scale_one = true) const { - using namespace mkldnn::impl; - return kind == primitive_kind::sum - && IMPLICATION(require_scale_one, sum.scale == 1.f); - } - }; - - mkldnn_post_ops(): len_(0) {} - - mkldnn::impl::status_t append_sum(float scale); - mkldnn::impl::status_t append_eltwise(float scale, - mkldnn::impl::alg_kind_t alg, float alpha, float beta); - - int find(mkldnn::impl::primitive_kind_t kind, int start = 0, - int stop = -1) const { - if (stop == -1) stop = len_; - stop = mkldnn::impl::nstl::min(stop, len_); - for (int idx = start; idx < stop; ++idx) - if (entry_[idx].kind == kind) return idx; - return -1; - } - - bool has_default_values() const { return len_ == 0; } - - bool contain(mkldnn::impl::primitive_kind_t kind, int index) const - { return find(kind, index, index + 1) == index; } - - enum { capacity = 4 }; - - int len_; - entry_t entry_[capacity]; -}; - -struct mkldnn_primitive_attr: public mkldnn::impl::c_compatible { - mkldnn_primitive_attr() - : scratchpad_mode_(mkldnn::impl::scratchpad_mode::library) - {} - - mkldnn_primitive_attr *clone() const - { return new mkldnn_primitive_attr(*this); } - - /** Returns true if the attributes have default values. 
- * - * @note The scratchpad_mode_ is not take into account */ - bool has_default_values() const { - return true - && output_scales_.has_default_values() - && post_ops_.has_default_values() - && rnn_data_qparams_.has_default_values() - && rnn_weights_qparams_.has_default_values(); - } - - mkldnn::impl::status_t set_scratchpad_mode( - mkldnn::impl::scratchpad_mode_t scratchpad_mode); - mkldnn::impl::status_t set_post_ops( - const mkldnn::impl::post_ops_t &post_ops); - - mkldnn::impl::scratchpad_mode_t scratchpad_mode_; - mkldnn::impl::scales_t output_scales_; - mkldnn::impl::post_ops_t post_ops_; - mkldnn::impl::rnn_data_qparams_t rnn_data_qparams_; - mkldnn::impl::scales_t rnn_weights_qparams_; -}; - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.cpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.cpp deleted file mode 100644 index 723c41e05..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.cpp +++ /dev/null @@ -1,78 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include "mkldnn.h" - -#include "c_types_map.hpp" -#include "nstl.hpp" -#include "primitive_desc.hpp" - -using namespace mkldnn::impl; -using namespace mkldnn::impl::status; - -status_t primitive_desc_t::query(query_t what, int idx, void *result) const { - auto safe_ret_md = [&](const memory_desc_t *_) { - if (_ == nullptr) return not_required; - *(const memory_desc_t **)result = _; - return success; - }; - - switch (what) { - case query::engine: *(engine_t**)result = engine(); break; - case query::primitive_kind: *(primitive_kind_t*)result = kind(); break; - - case query::scratchpad_engine: - *(engine_t**)result = scratchpad_engine(); break; - - case query::memory_consumption_s64: - *(dim_t *)result = scratchpad_size(scratchpad_mode::library); break; - - case query::op_d: - if (idx != 0 || op_desc() == nullptr) return invalid_arguments; - *(const_c_op_desc_t *)result - = static_cast(op_desc()); break; - - case query::src_md: return safe_ret_md(src_md(idx)); - case query::diff_src_md: return safe_ret_md(diff_src_md(idx)); - case query::dst_md: return safe_ret_md(dst_md(idx)); - case query::diff_dst_md: return safe_ret_md(diff_dst_md(idx)); - case query::weights_md: return safe_ret_md(weights_md(idx)); - case query::diff_weights_md: return safe_ret_md(diff_weights_md(idx)); - case query::workspace_md: - if (idx != 0) return status::invalid_arguments; - return safe_ret_md(workspace_md(idx)); - case query::scratchpad_md: - if (idx != 0) return status::invalid_arguments; - return safe_ret_md(scratchpad_md(idx)); - - case query::num_of_inputs_s32: *(int*)result = n_inputs(); break; - case query::num_of_outputs_s32: *(int*)result = n_outputs(); break; - - case query::impl_info_str: *(const char **)result = name(); break; - - default: return unimplemented; - } - return success; -} - 
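
The query() switch that ends just above is the single entry point behind the public C helpers defined later in this diff (mkldnn_primitive_desc_query(), mkldnn_primitive_desc_query_md() and mkldnn_primitive_desc_query_s32() in query.cpp). A minimal usage sketch, not part of the removed sources: it assumes the public mkldnn_query_* enum values and the const_mkldnn_primitive_desc_t handle type from the library's mkldnn.h header (not shown in this diff), plus a primitive descriptor pd created elsewhere.

    #include <stdio.h>
    #include "mkldnn.h"

    // Illustrative only: report what an existing primitive descriptor exposes
    // through the query API implemented above.
    static void describe_pd(const_mkldnn_primitive_desc_t pd) {
        // n_inputs()/n_outputs() on the C++ side surface as s32 queries.
        int n_in = mkldnn_primitive_desc_query_s32(
                pd, mkldnn_query_num_of_inputs_s32, 0);
        int n_out = mkldnn_primitive_desc_query_s32(
                pd, mkldnn_query_num_of_outputs_s32, 0);
        // impl_info_str maps to name() and is what the verbose log prints.
        const char *impl = NULL;
        mkldnn_primitive_desc_query(pd, mkldnn_query_impl_info_str, 0, &impl);
        printf("impl=%s, inputs=%d, outputs=%d\n",
                impl ? impl : "?", n_in, n_out);
    }

The same mechanism serves memory descriptors (query::src_md and friends via mkldnn_primitive_desc_query_md()), which is how callers discover the memory layouts an implementation actually selected.
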
-status_t mkldnn_primitive_desc_get_attr(const primitive_desc_t *primitive_desc, - const primitive_attr_t **attr) { - if (utils::any_null(primitive_desc, attr)) - return invalid_arguments; - - *attr = primitive_desc->attr(); - return success; -} diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.hpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.hpp deleted file mode 100644 index 536dcfa1d..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.hpp +++ /dev/null @@ -1,174 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef PRIMITIVE_DESC_HPP -#define PRIMITIVE_DESC_HPP - -#include "mkldnn.h" - -#include "c_types_map.hpp" -#include "memory_tracking.hpp" -#include "nstl.hpp" -#include "type_helpers.hpp" -#include "primitive_attr.hpp" -#include "verbose.hpp" - -struct mkldnn_primitive_desc: public mkldnn::impl::c_compatible { - using md_t = mkldnn::impl::memory_desc_t; - - mkldnn_primitive_desc(mkldnn::impl::engine_t *engine, - const mkldnn::impl::primitive_attr_t *attr, - mkldnn::impl::primitive_kind_t kind) - : engine_(engine), attr_(*attr), kind_(kind) { info_[0] = '\0'; } - - mkldnn_primitive_desc(mkldnn::impl::engine_t *engine, - mkldnn::impl::primitive_kind_t kind) - : engine_(engine), kind_(kind) { info_[0] = '\0'; } - - virtual mkldnn_primitive_desc *clone() const = 0; - virtual ~mkldnn_primitive_desc() {} - - const mkldnn::impl::primitive_attr_t *attr() const { return &attr_; } - mkldnn::impl::engine_t *engine() const { return engine_; } - mkldnn::impl::primitive_kind_t kind() const { return kind_; } - - virtual void init_info() {} - const char *info() const { return info_; } - - mkldnn::impl::memory_tracking::registry_t &scratchpad_registry() - { return scratchpad_registry_; } - const mkldnn::impl::memory_tracking::registry_t &scratchpad_registry() const - { return scratchpad_registry_; } - virtual mkldnn::impl::engine_t *scratchpad_engine() const - { return engine_; } - - virtual const mkldnn::impl::op_desc_t *op_desc() const { return nullptr; } - - enum class arg_usage_t { unused, input, output }; - virtual arg_usage_t arg_usage( - mkldnn::impl::primitive_arg_index_t arg) const { - using mkldnn::impl::types::is_zero_md; - if (arg == MKLDNN_ARG_SCRATCHPAD && !is_zero_md(scratchpad_md())) - return arg_usage_t::output; - return arg_usage_t::unused; - } - -# define DECLARE_MD_STUB(stub) \ - virtual const mkldnn::impl::memory_desc_t *stub(int idx = 0) const \ - { return nullptr; } - - DECLARE_MD_STUB(input_md); DECLARE_MD_STUB(output_md); - DECLARE_MD_STUB(src_md); DECLARE_MD_STUB(diff_src_md); - DECLARE_MD_STUB(dst_md); DECLARE_MD_STUB(diff_dst_md); - DECLARE_MD_STUB(weights_md); DECLARE_MD_STUB(diff_weights_md); - DECLARE_MD_STUB(workspace_md); -# undef DECLARE_MD_STUB - - const mkldnn::impl::memory_desc_t *scratchpad_md(int idx = 0) const { - 
return idx == 0 ? &scratchpad_md_ : nullptr; - } - - virtual void init_scratchpad_md() { - auto size = scratchpad_size(mkldnn::impl::scratchpad_mode::user); - mkldnn::impl::dims_t dims = { size }; - mkldnn_memory_desc_init_by_tag(&scratchpad_md_, size ? 1 : 0, dims, - mkldnn::impl::data_type::u8, mkldnn_x); - } - - /** returns the scratchpad size for the given scratchpad mode. */ - mkldnn::impl::dim_t scratchpad_size( - mkldnn::impl::scratchpad_mode_t mode) const { - if (mode != attr_.scratchpad_mode_) return 0; - return scratchpad_registry().size(); - } - - virtual int n_inputs() const { return 0; } - virtual int n_outputs() const { return 0; } - - virtual mkldnn::impl::status_t query(mkldnn::impl::query_t what, int idx, - void *result) const; - - virtual mkldnn::impl::status_t create_primitive( - mkldnn::impl::primitive_t **primitive) const = 0; - - virtual const char *name() const { return "mkldnn_primitive_desc"; } - - /* static magic */ - - template - static mkldnn::impl::status_t create(mkldnn::impl::primitive_desc_t **pd, - const mkldnn::impl::op_desc_t *adesc, - const mkldnn::impl::primitive_attr_t *attr, - mkldnn::impl::engine_t *engine, - const mkldnn::impl::primitive_desc_t *hint_fwd) { - using namespace mkldnn::impl; - using namespace mkldnn::impl::status; - using pd_op_desc_t = typename pkind_traits::desc_type; - if (adesc->kind != pd_t::base_pkind) return invalid_arguments; - assert(hint_fwd ? hint_fwd->kind() == pd_t::base_pkind : true); - auto hint = - reinterpret_cast(hint_fwd); - auto _pd = new pd_t(engine, (const pd_op_desc_t *)adesc, attr, hint); - if (_pd == nullptr) return out_of_memory; - if (_pd->init() != success) { delete _pd; return unimplemented; } - _pd->init_info(); - _pd->init_scratchpad_md(); - *pd = _pd; - return success; - } - -protected: - mkldnn::impl::engine_t *engine_; - mkldnn::impl::primitive_attr_t attr_; - mkldnn::impl::primitive_kind_t kind_; - - mkldnn::impl::memory_desc_t scratchpad_md_; - - char info_[MKLDNN_VERBOSE_BUF_LEN]; - - mkldnn::impl::memory_tracking::registry_t scratchpad_registry_; - -protected: - /** compares ws between fwd_pd and this (make sense to use for bwd_pd) - * Expectation: this already set workspace, and this workspace should - * exactly match the one from fwd_pd */ - bool compare_ws(const mkldnn_primitive_desc *fwd_pd) const { - using namespace mkldnn::impl; - if (!workspace_md()) return true; // the impl lives fine w/o workspace - return fwd_pd && fwd_pd->workspace_md() - && *fwd_pd->workspace_md() == *workspace_md(); - } -}; - -#define DECLARE_COMMON_PD_t(impl_name, ...) \ - virtual pd_t *clone() const override { return new pd_t(*this); } \ - virtual status_t create_primitive(primitive_t **p) const override { \ - double ms = get_msec(); \ - auto ret = safe_ptr_assign(*p, new (__VA_ARGS__)(this)); \ - ms = get_msec() - ms; \ - if (mkldnn_verbose()->level >= 2) { \ - printf("mkldnn_verbose,create,%s,%g\n", this->info(), ms); \ - fflush(0); \ - } \ - return ret; \ - } \ - virtual const char *name() const override { return impl_name; } -#define DECLARE_COMMON_PD_T(impl_name, ...) 
\ - DECLARE_COMMON_PD_t(impl_name, __VA_ARGS__) - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.cpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.cpp deleted file mode 100644 index 43e5a31ef..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.cpp +++ /dev/null @@ -1,90 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include "memory.hpp" -#include "primitive.hpp" -#include "primitive_exec_types.hpp" - -namespace mkldnn { -namespace impl { - -status_t cvt_primtive_args(const primitive_desc_t *pd, int nargs, - const mkldnn_exec_arg_t *c_args, exec_args_t &args) { - using namespace status; - - if (!IMPLICATION(nargs > 0, c_args != nullptr)) return invalid_arguments; - - int n_inputs = 0; - int n_outputs = 0; - - for (int i = 0; i < nargs; ++i) { - primitive_arg_index_t arg = c_args[i].arg; - auto *mem = c_args[i].memory; - - switch (pd->arg_usage(arg)) { - case primitive_desc_t::arg_usage_t::input: - if (args.count(arg) != 0) return invalid_arguments; - args[arg] = {mem, true}; - n_inputs++; - break; - case primitive_desc_t::arg_usage_t::output: - if (args.count(arg) != 0) return invalid_arguments; - args[arg] = {mem, false}; - n_outputs++; - break; - case primitive_desc_t::arg_usage_t::unused: - break; - } - } - - bool scratchpad_required = !types::is_zero_md(pd->scratchpad_md()); - - if (n_inputs != pd->n_inputs()) return invalid_arguments; - if (n_outputs != pd->n_outputs() + (scratchpad_required ? 
1 : 0)) - return invalid_arguments; - - return success; -} - -const void *exec_ctx_t::input(primitive_arg_index_t arg) const { - if (args_.count(arg) != 1) return nullptr; - const auto ma = args_.at(arg); - assert(ma.is_const); - void *ptr; - status_t status = ma.mem->get_data_handle(&ptr); - assert(status == status::success); MAYBE_UNUSED(status); - return ptr; -} - -void *exec_ctx_t::output(primitive_arg_index_t arg) const { - if (args_.count(arg) != 1) return nullptr; - const auto ma = args_.at(arg); - assert(!ma.is_const); - void *ptr; - status_t status = ma.mem->get_data_handle(&ptr); - assert(status == status::success); MAYBE_UNUSED(status); - return ptr; -} - -const memory_t *exec_ctx_t::memory(primitive_arg_index_t arg) const { - assert(args_.count(arg) == 1); - const auto ma = args_.at(arg); - assert(!ma.is_const); - return ma.mem; -} - -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.hpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.hpp deleted file mode 100644 index 0645891da..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.hpp +++ /dev/null @@ -1,68 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef PRIMITIVE_EXEC_TYPES_HPP -#define PRIMITIVE_EXEC_TYPES_HPP - -#include - -#include "mkldnn_types.h" - -#include "c_types_map.hpp" -#include "memory.hpp" -#include "primitive_desc.hpp" - -namespace mkldnn { -namespace impl { - -struct memory_arg_t { - memory_t *mem; - bool is_const; -}; - -using exec_args_t = std::unordered_map; - -status_t cvt_primtive_args(const primitive_desc_t *pd, int nargs, - const mkldnn_exec_arg_t *c_args, exec_args_t &args); - -/** Primitive execution context (helps passing stream, memories, and events. */ -struct exec_ctx_t { - exec_ctx_t(const exec_ctx_t &) = default; - exec_ctx_t(exec_ctx_t &&) = default; - - exec_ctx_t(stream_t *stream): stream_(stream) {} - exec_ctx_t(stream_t *stream, exec_args_t &&args) - : stream_(stream) - , args_(std::move(args)) {} - - stream_t *stream() const { return stream_; } - const exec_args_t &args() const { return args_; } - - /* tentative solution... 
TODO: replace with functions return memory_t */ - const void *input(primitive_arg_index_t arg) const; - void *output(primitive_arg_index_t arg) const; - - const memory_t *memory(primitive_arg_index_t arg) const; - -private: - stream_t *stream_; - exec_args_t args_; -}; - -} -} - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.cpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.cpp deleted file mode 100644 index 5a1cd7d37..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.cpp +++ /dev/null @@ -1,89 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include - -#include "mkldnn.h" - -#include "c_types_map.hpp" -#include "engine.hpp" -#include "primitive_desc.hpp" -#include "type_helpers.hpp" -#include "primitive_iterator.hpp" - -using namespace mkldnn::impl; -using namespace mkldnn::impl::status; - -status_t mkldnn_primitive_desc_iterator_create( - primitive_desc_iterator_t **iterator, const_c_op_desc_t c_op_desc, - const primitive_attr_t *attr, engine_t *engine, - const primitive_desc_t *hint_fwd_pd) { - const op_desc_t *op_desc = (const op_desc_t *)c_op_desc; - - auto it = new primitive_desc_iterator_t(engine, op_desc, attr, hint_fwd_pd); - if (it == nullptr) return out_of_memory; - - ++(*it); - if (*it == it->end()) { - delete it; - return unimplemented; - } - - *iterator = it; - return success; -} - -status_t mkldnn_primitive_desc_iterator_next( - primitive_desc_iterator_t *iterator) { - if (iterator == nullptr) return invalid_arguments; - ++(*iterator); - return *iterator == iterator->end() ? 
iterator_ends : success; -} - -primitive_desc_t *mkldnn_primitive_desc_iterator_fetch( - const primitive_desc_iterator_t *iterator) { - if (iterator == nullptr) return nullptr; - return *(*iterator); -} - -status_t mkldnn_primitive_desc_clone(primitive_desc_t **primitive_desc, - const primitive_desc_t *existing_primitive_desc) { - if (utils::any_null(primitive_desc, existing_primitive_desc)) - return invalid_arguments; - return safe_ptr_assign(*primitive_desc, - existing_primitive_desc->clone()); -} - -status_t mkldnn_primitive_desc_iterator_destroy( - primitive_desc_iterator_t *iterator) { - if (iterator != nullptr) - delete iterator; - return success; -} - -status_t mkldnn_primitive_desc_create(primitive_desc_t **primitive_desc, - const_c_op_desc_t c_op_desc, const primitive_attr_t *attr, - engine_t *engine, const primitive_desc_t *hint_fwd_pd) { - const op_desc_t *op_desc = (const op_desc_t *)c_op_desc; - - mkldnn_primitive_desc_iterator it(engine, op_desc, attr, hint_fwd_pd); - ++it; - if (it == it.end()) return unimplemented; - - return safe_ptr_assign(*primitive_desc, *it); -} - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.hpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.hpp deleted file mode 100644 index 4e88ab3aa..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.hpp +++ /dev/null @@ -1,79 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ -#ifndef PRIMITIVE_ITERATOR_HPP -#define PRIMITIVE_ITERATOR_HPP - -#include "mkldnn.h" - -#include "c_types_map.hpp" -#include "engine.hpp" -#include "primitive_desc.hpp" -#include "type_helpers.hpp" - -struct mkldnn_primitive_desc_iterator: public mkldnn::impl::c_compatible { - using pd_create_f = mkldnn::impl::engine_t::primitive_desc_create_f; - - mkldnn_primitive_desc_iterator(mkldnn::impl::engine_t *engine, const mkldnn::impl::op_desc_t *op_desc, - const mkldnn::impl::primitive_attr_t *attr, const mkldnn::impl::primitive_desc_t *hint_fwd_pd) - : idx_(-1), engine_(engine), pd_(nullptr), op_desc_(op_desc) - , attr_(attr ? 
*attr : mkldnn::impl::primitive_attr_t()), hint_fwd_pd_(hint_fwd_pd) - , impl_list_(engine_->get_implementation_list()), last_idx_(0) - { - while (impl_list_[last_idx_] != nullptr) ++last_idx_; - } - ~mkldnn_primitive_desc_iterator() { if (pd_) delete pd_; } - - bool operator==(const mkldnn::impl::primitive_desc_iterator_t& rhs) const - { return idx_ == rhs.idx_ && engine_ == rhs.engine_; } - bool operator!=(const mkldnn::impl::primitive_desc_iterator_t& rhs) const - { return !operator==(rhs); } - - mkldnn::impl::primitive_desc_iterator_t end() const - { return mkldnn_primitive_desc_iterator(engine_, last_idx_); } - - mkldnn::impl::primitive_desc_iterator_t &operator++() { - if (pd_) { delete pd_; pd_ = nullptr; } - while (++idx_ != last_idx_) { - auto s = impl_list_[idx_](&pd_, op_desc_, &attr_, engine_, - hint_fwd_pd_); - if (s == mkldnn::impl::status::success) break; - } - return *this; - } - - mkldnn::impl::primitive_desc_t *operator*() const { - if (*this == end() || pd_ == nullptr) return nullptr; - return pd_->clone(); - } - -protected: - int idx_; - mkldnn::impl::engine_t *engine_; - mkldnn::impl::primitive_desc_t *pd_; - const mkldnn::impl::op_desc_t *op_desc_; - const mkldnn::impl::primitive_attr_t attr_; - const mkldnn::impl::primitive_desc_t *hint_fwd_pd_; - const pd_create_f *impl_list_; - int last_idx_; - -private: - mkldnn_primitive_desc_iterator(mkldnn::impl::engine_t *engine, int last_idx) - : idx_(last_idx), engine_(engine), pd_(nullptr) - , op_desc_(nullptr), hint_fwd_pd_(nullptr) - , impl_list_(nullptr), last_idx_(last_idx) {} -}; - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/query.cpp b/thirdparty/oidn/mkl-dnn/src/common/query.cpp deleted file mode 100644 index 835cd7358..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/query.cpp +++ /dev/null @@ -1,59 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include -#include "mkldnn.h" - -#include "c_types_map.hpp" -#include "engine.hpp" -#include "primitive_desc.hpp" -#include "utils.hpp" - -using namespace mkldnn::impl; -using namespace mkldnn::impl::utils; -using namespace mkldnn::impl::status; - -status_t mkldnn_primitive_desc_query(const primitive_desc_t *primitive_desc, - query_t what, int index, void *result) { - if (any_null(primitive_desc, result)) - return invalid_arguments; - - return primitive_desc->query(what, index, result); -} - -const memory_desc_t *mkldnn_primitive_desc_query_md( - const primitive_desc_t *primitive_desc, query_t what, int index) { - const memory_desc_t *res_md = nullptr; - bool args_ok = true - && primitive_desc != nullptr - && (what & query::some_md) == query::some_md - && what != query::some_md - && mkldnn_primitive_desc_query(primitive_desc, - what, index, &res_md) == success; - return args_ok ? 
res_md : nullptr; -} - -int mkldnn_primitive_desc_query_s32(const primitive_desc_t *primitive_desc, - query_t what, int index) { - int res_s32; - bool args_ok = primitive_desc != nullptr - && one_of(what, query::num_of_inputs_s32, query::num_of_outputs_s32) - && mkldnn_primitive_desc_query(primitive_desc, what, index, &res_s32) - == success; - return args_ok ? res_s32 : 0; -} - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/reorder.cpp b/thirdparty/oidn/mkl-dnn/src/common/reorder.cpp deleted file mode 100644 index d11f1a036..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/reorder.cpp +++ /dev/null @@ -1,68 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include -#include "mkldnn.h" - -#include "c_types_map.hpp" -#include "engine.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -#include "reorder_pd.hpp" - -using namespace mkldnn::impl; -using namespace mkldnn::impl::utils; -using namespace mkldnn::impl::status; - -status_t mkldnn_reorder_primitive_desc_create( - primitive_desc_t **reorder_pd, - engine_t *src_engine, const memory_desc_t *src_md, - engine_t *dst_engine, const memory_desc_t *dst_md, - const primitive_attr_t *attr) { - if (any_null(reorder_pd, src_engine, src_md, dst_engine, dst_md)) - return invalid_arguments; - - auto s_ek = src_engine->kind(); - auto d_ek = dst_engine->kind(); - if (!IMPLICATION(s_ek != d_ek, one_of(engine_kind::cpu, s_ek, d_ek))) - return invalid_arguments; - - auto r_pd = reinterpret_cast(reorder_pd); - auto s_mdw = memory_desc_wrapper(*src_md); - auto d_mdw = memory_desc_wrapper(*dst_md); - - if (!s_mdw.consistent_with(d_mdw)) - return invalid_arguments; - - auto e = (s_ek != engine_kind::cpu) ? src_engine : dst_engine; - - const primitive_attr_t dummy_attr; - if (attr == NULL) - attr = &dummy_attr; - - for (auto r = e->get_reorder_implementation_list(); *r; ++r) { - if ((*r)(r_pd, e, attr, src_engine, src_md, dst_engine, dst_md) - == success) { - (*r_pd)->init_info(); - (*r_pd)->init_scratchpad_md(); - return success; - } - } - return unimplemented; -} - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/reorder_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/reorder_pd.hpp deleted file mode 100644 index 963cb0f58..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/reorder_pd.hpp +++ /dev/null @@ -1,85 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. 
-* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef REORDER_PD_HPP -#define REORDER_PD_HPP - -#include - -#include "c_types_map.hpp" -#include "primitive_attr.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -namespace mkldnn { -namespace impl { - -struct reorder_pd_t: public primitive_desc_t { - reorder_pd_t(engine_t *engine, const primitive_attr_t *attr, - engine_t *src_engine, const memory_desc_t *src_md, - engine_t *dst_engine, const memory_desc_t *dst_md) - : primitive_desc_t(engine, attr, primitive_kind::reorder) - , src_engine_(src_engine) - , dst_engine_(dst_engine) - , scratchpad_engine_(nullptr) - , src_md_(*src_md) - , dst_md_(*dst_md) - {} - - virtual const op_desc_t *op_desc() const override { return nullptr; } - virtual void init_info() override { impl::init_info(this, this->info_); } - - virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { - if (arg == MKLDNN_ARG_FROM) - return arg_usage_t::input; - - if (arg == MKLDNN_ARG_TO) - return arg_usage_t::output; - - return primitive_desc_t::arg_usage(arg); - } - - virtual const memory_desc_t *src_md(int index = 0) const override - { return index == 0 ? &src_md_ : nullptr; } - virtual const memory_desc_t *dst_md(int index = 0) const override - { return index == 0 ? &dst_md_ : nullptr; } - - virtual int n_inputs() const override { return 1; } - virtual int n_outputs() const override { return 1; } - - float alpha() const { return attr()->output_scales_.scales_[0]; } - float beta() const { - const int sum_idx = attr()->post_ops_.find(primitive_kind::sum); - return sum_idx == -1 ? 0 : attr()->post_ops_.entry_[sum_idx].sum.scale; - } - virtual mkldnn::impl::engine_t *scratchpad_engine() const override - { return scratchpad_engine_; } - -protected: - engine_t *src_engine_; - engine_t *dst_engine_; - engine_t *scratchpad_engine_; - - memory_desc_t src_md_; - memory_desc_t dst_md_; -}; - -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/rnn.cpp b/thirdparty/oidn/mkl-dnn/src/common/rnn.cpp deleted file mode 100644 index 36967431a..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/rnn.cpp +++ /dev/null @@ -1,400 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#include "mkldnn.h" - -#include "c_types_map.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" -#include "cpu/gemm/os_blas.hpp" - -using namespace mkldnn::impl; -using namespace mkldnn::impl::status; -using namespace mkldnn::impl::types; -using namespace mkldnn::impl::utils; - -namespace { -memory_desc_t copy_maybe_null(const memory_desc_t *md) { - return md ? *md : zero_md(); -} - -rnn_desc_t zero_rnn_desc() { - auto rd = rnn_desc_t(); - rd.src_layer_desc = zero_md(); - rd.src_iter_desc = zero_md(); - rd.weights_layer_desc = zero_md(); - rd.weights_iter_desc = zero_md(); - rd.bias_desc = zero_md(); - rd.dst_layer_desc = zero_md(); - rd.dst_iter_desc = zero_md(); - rd.diff_src_layer_desc = zero_md(); - rd.diff_src_iter_desc = zero_md(); - rd.diff_weights_layer_desc = zero_md(); - rd.diff_weights_iter_desc = zero_md(); - rd.diff_bias_desc = zero_md(); - rd.diff_dst_layer_desc = zero_md(); - rd.diff_dst_iter_desc = zero_md(); - return rd; -} -} - -/* Public C Api */ - -status_t mkldnn_rnn_cell_desc_init(rnn_cell_desc_t *rnn_cell_desc, - mkldnn_alg_kind_t cell_kind, mkldnn_alg_kind_t act_f, - unsigned int flags, float alpha, float clipping) { - using namespace mkldnn::impl::alg_kind; - - bool args_ok = true - && one_of(cell_kind, vanilla_rnn, vanilla_lstm, vanilla_gru, - gru_linear_before_reset) - && IMPLICATION(cell_kind == vanilla_rnn, - one_of(act_f, eltwise_relu, eltwise_tanh, eltwise_logistic)); - if (!args_ok) - return invalid_arguments; - - auto rcd = mkldnn_rnn_cell_desc_t(); - - rcd.cell_kind = cell_kind; - rcd.activation_kind = act_f; - rcd.flags = flags; - rcd.alpha = rcd.flags & mkldnn_rnn_cell_with_relu ? alpha : 0; - rcd.clipping = rcd.flags & mkldnn_rnn_cell_with_clipping ? 
clipping : 0; - - *rnn_cell_desc = rcd; - - return success; -} - -int mkldnn_rnn_cell_get_gates_count(const rnn_cell_desc_t *rnn_cell_desc) { - switch (rnn_cell_desc->cell_kind) { - case mkldnn::impl::alg_kind::vanilla_rnn: return 1; - case mkldnn::impl::alg_kind::vanilla_gru: return 3; - case mkldnn::impl::alg_kind::gru_linear_before_reset: return 3; - case mkldnn::impl::alg_kind::vanilla_lstm: return 4; - default: assert(!"unknown cell kind"); return 0; - } - return 0; -} - -int mkldnn_rnn_cell_get_states_count(const rnn_cell_desc_t *rnn_cell_desc) { - switch (rnn_cell_desc->cell_kind) { - case mkldnn::impl::alg_kind::vanilla_rnn: return 1; - case mkldnn::impl::alg_kind::vanilla_gru: return 1; - case mkldnn::impl::alg_kind::gru_linear_before_reset: return 1; - case mkldnn::impl::alg_kind::vanilla_lstm: return 2; - default: assert(!"unknown cell kind"); return 0; - } - return 0; -} - -status_t check_data_type_consistency_fwd(const rnn_cell_desc_t *rnn_cell_desc, - prop_kind_t prop_kind, const memory_desc_t *src_layer_desc, - const memory_desc_t *src_iter_desc, - const memory_desc_t *weights_layer_desc, - const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc, - const memory_desc_t *dst_layer_desc, - const memory_desc_t *dst_iter_desc) { - using namespace data_type; - data_type_t src_layer_dt = src_layer_desc->data_type; - data_type_t dst_layer_dt = dst_layer_desc->data_type; - data_type_t weights_iter_dt = weights_iter_desc->data_type; - data_type_t weights_layer_dt = weights_layer_desc->data_type; - - bool is_f32 = everyone_is(f32, src_layer_dt, dst_layer_dt, weights_iter_dt, - weights_layer_dt) - && IMPLICATION(!is_zero_md(src_iter_desc), - src_iter_desc->data_type == f32) - && IMPLICATION(!is_zero_md(dst_iter_desc), - dst_iter_desc->data_type == f32) - && IMPLICATION(!is_zero_md(bias_desc), bias_desc->data_type == f32); - -#if USE_MKL_PACKED_GEMM - bool is_u8u8u8 = src_layer_dt == u8 - && IMPLICATION(!is_zero_md(src_iter_desc), - src_iter_desc->data_type == u8) - && IMPLICATION(!is_zero_md(dst_iter_desc), - dst_iter_desc->data_type == u8) - && one_of(dst_layer_dt, u8, f32) - && everyone_is(s8, weights_iter_dt, weights_layer_dt) - && IMPLICATION(!is_zero_md(bias_desc), bias_desc->data_type == f32); - - bool is_f32u8f32 = src_layer_dt == u8 - && IMPLICATION(!is_zero_md(src_iter_desc), - src_iter_desc->data_type == f32) - && IMPLICATION(!is_zero_md(dst_iter_desc), - dst_iter_desc->data_type == f32) - && one_of(dst_layer_dt, u8, f32) - && everyone_is(s8, weights_iter_dt, weights_layer_dt) - && IMPLICATION(!is_zero_md(bias_desc), bias_desc->data_type == f32); - - bool is_inference = prop_kind == prop_kind::forward_inference; - bool is_lstm = rnn_cell_desc->cell_kind == mkldnn_vanilla_lstm; - - return (is_f32 || ((is_u8u8u8 || is_f32u8f32) && is_lstm && is_inference)) - ? success - : unimplemented; -#else - return is_f32 ? 
success : unimplemented; -#endif -} - -status_t check_dim_consistency(const rnn_cell_desc_t *rnn_cell_desc, - rnn_direction_t direction, int L, int D, int T, int N, int S, int G, - int SLC, int SIC, int DLC, int DIC, const memory_desc_t *src_layer_desc, - const memory_desc_t *src_iter_desc, - const memory_desc_t *weights_layer_desc, - const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc, - const memory_desc_t *dst_layer_desc, - const memory_desc_t *dst_iter_desc) { - bool args_ok; - - // * algorithm specific - args_ok = true - && IMPLICATION(rnn_cell_desc->cell_kind == alg_kind::vanilla_gru, - DIC == SIC); - if (!args_ok) return invalid_arguments; - int extra_bias = - rnn_cell_desc->cell_kind == alg_kind::gru_linear_before_reset; - - // * on num layers - args_ok = true - && L == weights_layer_desc->dims[0] - && L == weights_iter_desc->dims[0] - && IMPLICATION(!is_zero_md(bias_desc), L == bias_desc->dims[0]) - && IMPLICATION(!is_zero_md(src_iter_desc), L == src_iter_desc->dims[0]) - && IMPLICATION(!is_zero_md(dst_iter_desc), L == dst_iter_desc->dims[0]); - if (!args_ok) return invalid_arguments; - - // * on num directions - args_ok = true - && D == weights_layer_desc->dims[1] - && D == weights_iter_desc->dims[1] - && IMPLICATION(!is_zero_md(bias_desc), D == bias_desc->dims[1]) - && IMPLICATION(!is_zero_md(src_iter_desc), D == src_iter_desc->dims[1]) - && IMPLICATION(!is_zero_md(dst_iter_desc), D == dst_iter_desc->dims[1]); - if (!args_ok) return invalid_arguments; - - // * on num iterations - args_ok = true - && T == src_layer_desc->dims[0] - && T == dst_layer_desc->dims[0]; - if (!args_ok) return invalid_arguments; - - // * on mb - args_ok = true - && N == src_layer_desc->dims[1] - && N == dst_layer_desc->dims[1] - && IMPLICATION(!is_zero_md(src_iter_desc), N == src_iter_desc->dims[3]) - && IMPLICATION(!is_zero_md(dst_iter_desc), N == dst_iter_desc->dims[3]); - if (!args_ok) return invalid_arguments; - - // * on num gates - args_ok = true - && G == mkldnn_rnn_cell_get_gates_count(rnn_cell_desc) - && G == weights_layer_desc->dims[3] - && G == weights_iter_desc->dims[3] - && IMPLICATION(!is_zero_md(bias_desc), - G + extra_bias == bias_desc->dims[2]); - if (!args_ok) return invalid_arguments; - - // * on num states - args_ok = true - && S == mkldnn_rnn_cell_get_states_count(rnn_cell_desc) - && IMPLICATION(!is_zero_md(src_iter_desc), S == src_iter_desc->dims[2]) - && IMPLICATION(!is_zero_md(dst_iter_desc), S == dst_iter_desc->dims[2]); - if (!args_ok) return invalid_arguments; - - // * on slc - args_ok = true - && SLC == weights_layer_desc->dims[2] - && SLC == src_layer_desc->dims[2]; - if (!args_ok) return invalid_arguments; - - // * on sic - args_ok = true - && SIC == weights_iter_desc->dims[2] - && IMPLICATION(!is_zero_md(src_iter_desc), - SIC == src_iter_desc->dims[4]); - if (!args_ok) return invalid_arguments; - - // * on dlc - int dlc_multiplier = (direction == mkldnn_bidirectional_concat) ? 
2 : 1; - args_ok = true - && DLC == dlc_multiplier * DIC - && DLC == dst_layer_desc->dims[2]; - if (!args_ok) return invalid_arguments; - - // * on dic - args_ok = true - && DIC == weights_layer_desc->dims[4] - && DIC == weights_iter_desc->dims[4] - && IMPLICATION(!is_zero_md(bias_desc), DIC == bias_desc->dims[3]) - && IMPLICATION(!is_zero_md(dst_iter_desc), - DIC == dst_iter_desc->dims[4]); - if (!args_ok) return invalid_arguments; - - // * unrolling/fusion conditions - args_ok = true - && IMPLICATION(L > 1, (dlc_multiplier * SLC) == DLC) - && IMPLICATION(T > 1, SIC == DIC); - if (!args_ok) return invalid_arguments; - - return success; -} - -status_t MKLDNN_API mkldnn_rnn_forward_desc_init(mkldnn_rnn_desc_t *rnn_desc, - prop_kind_t prop_kind, const rnn_cell_desc_t *rnn_cell_desc, - const rnn_direction_t direction, const memory_desc_t *src_layer_desc, - const memory_desc_t *src_iter_desc, - const memory_desc_t *weights_layer_desc, - const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc, - const memory_desc_t *dst_layer_desc, - const memory_desc_t *dst_iter_desc) { - bool args_ok = true && rnn_cell_desc != nullptr - && !any_null(src_layer_desc, weights_layer_desc, weights_iter_desc, - dst_layer_desc); - if (!args_ok) return invalid_arguments; - - //check dimensions consistency - int L = weights_layer_desc->dims[0]; - int T = src_layer_desc->dims[0]; - int N = src_layer_desc->dims[1]; - const int D = one_of(direction, mkldnn_unidirectional_left2right, - mkldnn_unidirectional_right2left) ? - 1 : - 2; - int G = mkldnn_rnn_cell_get_gates_count(rnn_cell_desc); - int S = mkldnn_rnn_cell_get_states_count(rnn_cell_desc); - int SLC = src_layer_desc->dims[2]; - int SIC = weights_iter_desc->dims[2]; - int DLC = dst_layer_desc->dims[2]; - int DIC = weights_layer_desc->dims[4]; - - CHECK(check_dim_consistency(rnn_cell_desc, direction, L, D, T, N, S, - G, SLC, SIC, DLC, DIC, src_layer_desc, src_iter_desc, - weights_layer_desc, weights_iter_desc, bias_desc, dst_layer_desc, - dst_iter_desc)); - - CHECK(check_data_type_consistency_fwd(rnn_cell_desc, prop_kind, - src_layer_desc, src_iter_desc, weights_layer_desc, - weights_iter_desc, bias_desc, dst_layer_desc, dst_iter_desc)); - - // Create the descriptor - mkldnn_rnn_desc_t rd = zero_rnn_desc(); - - rd.primitive_kind = primitive_kind::rnn; - rd.prop_kind = prop_kind; - rd.cell_desc = *rnn_cell_desc; - rd.direction = direction; - rd.src_layer_desc = copy_maybe_null(src_layer_desc); - rd.src_iter_desc = copy_maybe_null(src_iter_desc); - rd.weights_layer_desc = copy_maybe_null(weights_layer_desc); - rd.weights_iter_desc = copy_maybe_null(weights_iter_desc); - rd.bias_desc = copy_maybe_null(bias_desc); - rd.dst_layer_desc = copy_maybe_null(dst_layer_desc); - rd.dst_iter_desc = copy_maybe_null(dst_iter_desc); - - *rnn_desc = rd; - - return success; -} - -status_t MKLDNN_API mkldnn_rnn_backward_desc_init(mkldnn_rnn_desc_t *rnn_desc, - prop_kind_t prop_kind, const rnn_cell_desc_t *rnn_cell_desc, - const rnn_direction_t direction, const memory_desc_t *src_layer_desc, - const memory_desc_t *src_iter_desc, - const memory_desc_t *weights_layer_desc, - const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc, - const memory_desc_t *dst_layer_desc, const memory_desc_t *dst_iter_desc, - const memory_desc_t *diff_src_layer_desc, - const memory_desc_t *diff_src_iter_desc, - const memory_desc_t *diff_weights_layer_desc, - const memory_desc_t *diff_weights_iter_desc, - const memory_desc_t *diff_bias_desc, - const memory_desc_t 
*diff_dst_layer_desc, - const memory_desc_t *diff_dst_iter_desc) { - bool args_ok = true - && !any_null(src_layer_desc, weights_layer_desc, weights_iter_desc, - dst_layer_desc, diff_src_layer_desc, - diff_weights_layer_desc, diff_weights_iter_desc, - diff_dst_layer_desc); - if (!args_ok) - return invalid_arguments; - - auto xnor_md = [=](const memory_desc_t *a_md, const memory_desc_t *b_md) { - return is_zero_md(a_md) == is_zero_md(b_md); - }; - - args_ok = args_ok && xnor_md(bias_desc, diff_bias_desc) - && xnor_md(dst_iter_desc, diff_dst_iter_desc) - && xnor_md(src_iter_desc, diff_src_iter_desc); - if (!args_ok) - return invalid_arguments; - - //check dimensions consistency - int L = weights_layer_desc->dims[0]; - int T = src_layer_desc->dims[0]; - int N = src_layer_desc->dims[1]; - const int D = one_of(direction, mkldnn_unidirectional_left2right, - mkldnn_unidirectional_right2left) ? - 1 : - 2; - int G = mkldnn_rnn_cell_get_gates_count(rnn_cell_desc); - int S = mkldnn_rnn_cell_get_states_count(rnn_cell_desc); - int SLC = src_layer_desc->dims[2]; - int SIC = weights_iter_desc->dims[2]; - int DLC = dst_layer_desc->dims[2]; - int DIC = weights_layer_desc->dims[4]; - - status_t st = check_dim_consistency(rnn_cell_desc, direction, L, D, T, N, S, - G, SLC, SIC, DLC, DIC, src_layer_desc, src_iter_desc, - weights_layer_desc, weights_iter_desc, bias_desc, dst_layer_desc, - dst_iter_desc); - if (st != success) return st; - - st = check_dim_consistency(rnn_cell_desc, direction, L, D, T, N, S, - G, SLC, SIC, DLC, DIC, diff_src_layer_desc, diff_src_iter_desc, - diff_weights_layer_desc, diff_weights_iter_desc, diff_bias_desc, - diff_dst_layer_desc, diff_dst_iter_desc); - if (st != success) return st; - - mkldnn_rnn_desc_t rd = zero_rnn_desc(); - - rd.primitive_kind = primitive_kind::rnn; - rd.prop_kind = prop_kind; - rd.cell_desc = *rnn_cell_desc; - rd.direction = direction; - - rd.src_layer_desc = copy_maybe_null(src_layer_desc); - rd.src_iter_desc = copy_maybe_null(src_iter_desc); - rd.weights_layer_desc = copy_maybe_null(weights_layer_desc); - rd.weights_iter_desc = copy_maybe_null(weights_iter_desc); - rd.bias_desc = copy_maybe_null(bias_desc); - rd.dst_layer_desc = copy_maybe_null(dst_layer_desc); - rd.dst_iter_desc = copy_maybe_null(dst_iter_desc); - rd.diff_src_layer_desc = copy_maybe_null(diff_src_layer_desc); - rd.diff_src_iter_desc = copy_maybe_null(diff_src_iter_desc); - rd.diff_weights_layer_desc = copy_maybe_null(diff_weights_layer_desc); - rd.diff_weights_iter_desc = copy_maybe_null(diff_weights_iter_desc); - rd.diff_bias_desc = copy_maybe_null(diff_bias_desc); - rd.diff_dst_layer_desc = copy_maybe_null(diff_dst_layer_desc); - rd.diff_dst_iter_desc = copy_maybe_null(diff_dst_iter_desc); - - *rnn_desc = rd; - - return success; -} diff --git a/thirdparty/oidn/mkl-dnn/src/common/rnn_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/rnn_pd.hpp deleted file mode 100644 index 1ee2ba111..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/rnn_pd.hpp +++ /dev/null @@ -1,280 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. 
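A minimal usage sketch for the RNN cell C API removed above (not part of the deleted sources; the example_ helper name and the sample arguments are illustrative, and the internal headers are assumed to be on the include path):

#include "mkldnn.h"
#include "c_types_map.hpp"

static void example_rnn_cell_desc() {
    using namespace mkldnn::impl;
    rnn_cell_desc_t cell;
    // An LSTM cell; the activation argument is only validated for vanilla_rnn,
    // so eltwise_tanh is passed here purely as a placeholder.
    status_t st = mkldnn_rnn_cell_desc_init(&cell,
            alg_kind::vanilla_lstm, alg_kind::eltwise_tanh,
            0u /* flags */, 0.0f /* alpha */, 0.0f /* clipping */);
    if (st != status::success) return;
    const int G = mkldnn_rnn_cell_get_gates_count(&cell);  // 4 for LSTM
    const int S = mkldnn_rnn_cell_get_states_count(&cell); // 2 for LSTM
    (void)G; (void)S;
}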
-* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef RNN_PD_HPP -#define RNN_PD_HPP - -#include "mkldnn.h" - -#include "c_types_map.hpp" -#include "primitive_desc.hpp" -#include "type_helpers.hpp" - -namespace mkldnn { -namespace impl { - -struct rnn_fwd_pd_t; - -struct rnn_pd_t : public primitive_desc_t { - static constexpr auto base_pkind = primitive_kind::rnn; - - rnn_pd_t(engine_t *engine, - const rnn_desc_t *adesc, - const primitive_attr_t *attr, - const rnn_fwd_pd_t *hint_fwd_pd) - : primitive_desc_t(engine, attr, base_pkind) - , desc_(*adesc) - , hint_fwd_pd_(hint_fwd_pd) - , src_layer_md_(desc_.src_layer_desc) - , src_iter_md_(desc_.src_iter_desc) - , weights_layer_md_(desc_.weights_layer_desc) - , weights_iter_md_(desc_.weights_iter_desc) - , bias_md_(desc_.bias_desc) - , dst_layer_md_(desc_.dst_layer_desc) - , dst_iter_md_(desc_.dst_iter_desc) - , ws_md_() - {} - - const rnn_desc_t *desc() const { return &desc_; } - virtual const op_desc_t *op_desc() const override - { return reinterpret_cast(this->desc()); } - virtual void init_info() override { impl::init_info(this, this->info_); } - - virtual status_t query(query_t what, int idx, void *result) const override { - switch (what) { - case query::rnn_d: *(const rnn_desc_t **)result = desc(); break; - default: return primitive_desc_t::query(what, idx, result); - } - return status::success; - } - - virtual const memory_desc_t *src_md(int index = 0) const override { - if (index == 0) return &src_layer_md_; - if (index == 1 && with_src_iter()) return &src_iter_md_; - return nullptr; - } - virtual const memory_desc_t *weights_md(int index = 0) const override { - if (index == 0) return &weights_layer_md_; - if (index == 1) return &weights_iter_md_; - if (index == 2 && with_bias()) return &bias_md_; - return nullptr; - } - virtual const memory_desc_t *dst_md(int index = 0) const override { - if (index == 0) return &dst_layer_md_; - if (index == 1 && with_dst_iter()) return &dst_iter_md_; - return nullptr; - } - virtual const memory_desc_t *workspace_md(int index = 0) const override - { return index == 0 && !types::is_zero_md(&ws_md_) ? 
&ws_md_ : nullptr; } - - /* common pooling aux functions */ - - bool is_training() const { - return utils::one_of(desc_.prop_kind, prop_kind::forward_training, - prop_kind::backward); - } - - bool is_fwd() const { - return utils::one_of(desc_.prop_kind, prop_kind::forward_training, - prop_kind::forward_inference); - } - - dim_t T() const { return desc_.src_layer_desc.dims[0]; } - dim_t MB() const { return desc_.src_layer_desc.dims[1]; } - - dim_t L() const { return desc_.weights_layer_desc.dims[0]; } - dim_t D() const { return desc_.weights_layer_desc.dims[1]; } - - dim_t SIC() const { return desc_.weights_iter_desc.dims[2]; } - - dim_t SLC() const { return desc_.weights_layer_desc.dims[2]; } - dim_t G() const { return desc_.weights_layer_desc.dims[3]; } - dim_t DIC() const { return desc_.weights_layer_desc.dims[4]; } - - dim_t DLC() const { return desc_.dst_layer_desc.dims[2]; } - - bool with_bias() const - { return !memory_desc_wrapper(desc_.bias_desc).is_zero(); } - - bool with_src_iter() const - { return !(memory_desc_wrapper(desc_.src_iter_desc).is_zero()); } - - bool with_dst_iter() const - { return !memory_desc_wrapper(desc_.dst_iter_desc).is_zero(); } - - mkldnn::impl::alg_kind_t cell_kind() const - { return desc_.cell_desc.cell_kind; } - mkldnn::impl::alg_kind_t activation_kind() const - { return desc_.cell_desc.activation_kind; } - - bool is_lbr() const - { return cell_kind() == mkldnn_gru_linear_before_reset; } - - mkldnn_rnn_direction_t direction() const { return desc_.direction; } - -protected: - rnn_desc_t desc_; - const rnn_fwd_pd_t *hint_fwd_pd_; - - memory_desc_t src_layer_md_; - memory_desc_t src_iter_md_; - memory_desc_t weights_layer_md_; - memory_desc_t weights_iter_md_; - memory_desc_t bias_md_; - memory_desc_t dst_layer_md_; - memory_desc_t dst_iter_md_; - - memory_desc_t ws_md_; -}; - -struct rnn_fwd_pd_t: public rnn_pd_t { - typedef rnn_fwd_pd_t base_class; - typedef rnn_fwd_pd_t hint_class; - - rnn_fwd_pd_t(engine_t *engine, - const rnn_desc_t *adesc, - const primitive_attr_t *attr, - const rnn_fwd_pd_t *hint_fwd_pd) - : rnn_pd_t(engine, adesc, attr, hint_fwd_pd) - {} - - virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { - if (arg == MKLDNN_ARG_SRC_LAYER) - return arg_usage_t::input; - - if (arg == MKLDNN_ARG_SRC_ITER && with_src_iter()) - return arg_usage_t::input; - - if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS_LAYER, - MKLDNN_ARG_WEIGHTS_ITER)) - return arg_usage_t::input; - - if (arg == MKLDNN_ARG_BIAS && with_bias()) - return arg_usage_t::input; - - if (arg == MKLDNN_ARG_DST_LAYER) - return arg_usage_t::output; - - if (arg == MKLDNN_ARG_DST_ITER && with_dst_iter()) - return arg_usage_t::output; - - if (arg == MKLDNN_ARG_WORKSPACE && is_training()) - return arg_usage_t::output; - - return primitive_desc_t::arg_usage(arg); - } - - virtual int n_inputs() const override - { return 3 + with_bias() + with_src_iter(); } - virtual int n_outputs() const override - { return 1 + with_dst_iter() + is_training(); } -}; - -struct rnn_bwd_pd_t : public rnn_pd_t { - typedef rnn_bwd_pd_t base_class; - typedef rnn_fwd_pd_t hint_class; - - rnn_bwd_pd_t(engine_t *engine, - const rnn_desc_t *adesc, - const primitive_attr_t *attr, - const rnn_fwd_pd_t *hint_fwd_pd) - : rnn_pd_t(engine, adesc, attr, hint_fwd_pd) - , diff_src_layer_md_(desc_.diff_src_layer_desc) - , diff_src_iter_md_(desc_.diff_src_iter_desc) - , diff_weights_layer_md_(desc_.diff_weights_layer_desc) - , diff_weights_iter_md_(desc_.diff_weights_iter_desc) - , diff_bias_md_(desc_.diff_bias_desc) 
- , diff_dst_layer_md_(desc_.diff_dst_layer_desc) - , diff_dst_iter_md_(desc_.diff_dst_iter_desc) - {} - - virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { - if (utils::one_of(arg, MKLDNN_ARG_SRC_LAYER, MKLDNN_ARG_DST_LAYER, - MKLDNN_ARG_DIFF_DST_LAYER)) - return arg_usage_t::input; - - if (with_src_iter()) { - if (arg == MKLDNN_ARG_SRC_ITER) - return arg_usage_t::input; - - if (arg == MKLDNN_ARG_DIFF_SRC_ITER) - return arg_usage_t::output; - } - - if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS_LAYER, - MKLDNN_ARG_WEIGHTS_ITER)) - return arg_usage_t::input; - - if (with_bias()) { - if (arg == MKLDNN_ARG_BIAS) - return arg_usage_t::input; - - if (arg == MKLDNN_ARG_DIFF_BIAS) - return arg_usage_t::output; - } - - if (utils::one_of(arg, MKLDNN_ARG_DST_ITER, MKLDNN_ARG_DIFF_DST_ITER) - && with_dst_iter()) - return arg_usage_t::input; - - if (arg == MKLDNN_ARG_WORKSPACE) - return arg_usage_t::input; - - if (utils::one_of(arg, MKLDNN_ARG_DIFF_SRC_LAYER, - MKLDNN_ARG_DIFF_WEIGHTS_LAYER, - MKLDNN_ARG_DIFF_WEIGHTS_ITER)) - return arg_usage_t::output; - - return primitive_desc_t::arg_usage(arg); - } - - virtual const memory_desc_t *diff_src_md(int index = 0) const override { - if (index == 0) return &diff_src_layer_md_; - if (index == 1 && with_src_iter()) return &diff_src_iter_md_; - return nullptr; - } - virtual const memory_desc_t *diff_weights_md( - int index = 0) const override { - if (index == 0) return &diff_weights_layer_md_; - if (index == 1) return &diff_weights_iter_md_; - if (index == 2 && with_bias()) return &diff_bias_md_; - return nullptr; - } - virtual const memory_desc_t *diff_dst_md(int index = 0) const override { - if (index == 0) return &diff_dst_layer_md_; - if (index == 1 && with_dst_iter()) return &diff_dst_iter_md_; - return nullptr; - } - - virtual int n_inputs() const override - { return 6 + with_src_iter() + with_bias() + 2 * with_dst_iter(); } - virtual int n_outputs() const override - { return 3 + with_src_iter() + with_bias(); } - -protected: - memory_desc_t diff_src_layer_md_; - memory_desc_t diff_src_iter_md_; - memory_desc_t diff_weights_layer_md_; - memory_desc_t diff_weights_iter_md_; - memory_desc_t diff_bias_md_; - memory_desc_t diff_dst_layer_md_; - memory_desc_t diff_dst_iter_md_; -}; - -} -} - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/scratchpad.cpp b/thirdparty/oidn/mkl-dnn/src/common/scratchpad.cpp deleted file mode 100644 index 6bc14fc72..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/scratchpad.cpp +++ /dev/null @@ -1,112 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
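The accessors of rnn_pd_t above encode a fixed dimension order; a brief sketch restating it, derived from check_dim_consistency() in rnn.cpp above (not part of the deleted sources; the helper name is illustrative):

#include "rnn_pd.hpp"

static void example_rnn_pd_dims(const mkldnn::impl::rnn_pd_t &pd) {
    // src_layer  = {T, MB, SLC}          dst_layer = {T, MB, DLC}
    // src_iter   = {L, D, S, MB, SIC}    dst_iter  = {L, D, S, MB, DIC}
    // weights_layer = {L, D, SLC, G, DIC}
    // weights_iter  = {L, D, SIC, G, DIC}
    // bias          = {L, D, G (+1 for "linear before reset" GRU), DIC}
    const mkldnn::impl::dim_t n_cells = pd.L() * pd.D();      // cells to run
    const mkldnn::impl::dim_t gates_w = pd.G() * pd.DIC();    // combined gate width
    (void)n_cells; (void)gates_w;
}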
-*******************************************************************************/ - -#include "mkldnn_thread.hpp" -#include "utils.hpp" - -#include "scratchpad.hpp" - -namespace mkldnn { -namespace impl { - -/* Allocating memory buffers on a page boundary to reduce TLB/page misses */ -const size_t page_size = 2097152; - -/* - Implementation of the scratchpad_t interface that is compatible with - a concurrent execution -*/ -struct concurent_scratchpad_t : public scratchpad_t { - concurent_scratchpad_t(size_t size) { - size_ = size; - scratchpad_ = (char *) malloc(size, page_size); - assert(scratchpad_ != nullptr); - } - - ~concurent_scratchpad_t() { - free(scratchpad_); - } - - virtual char *get() const { - return scratchpad_; - } - -private: - char *scratchpad_; - size_t size_; -}; - -/* - Implementation of the scratchpad_t interface that uses a global - scratchpad -*/ - -struct global_scratchpad_t : public scratchpad_t { - global_scratchpad_t(size_t size) { - if (size > size_) { - if (scratchpad_ != nullptr) free(scratchpad_); - size_ = size; - scratchpad_ = (char *) malloc(size, page_size); - assert(scratchpad_ != nullptr); - } - reference_count_++; - } - - ~global_scratchpad_t() { - reference_count_--; - if (reference_count_ == 0) { - free(scratchpad_); - scratchpad_ = nullptr; - size_ = 0; - } - } - - virtual char *get() const { - return scratchpad_; - } - -private: - /* - Using thread-local here is unnecessary and even buggy! All threads - actually share the same scratchpad, which is created and queried only - on the main thread. If the scratchpad is queried on some thread other - than the one it was created on (e.g. the application calls the API from - multiple threads), thread-local causes a segfault because the scratchpad - is uninitialized on the current thread. - */ - /*thread_local*/ static char *scratchpad_; - /*thread_local*/ static size_t size_; - /*thread_local*/ static unsigned int reference_count_; -}; - -/*thread_local*/ char *global_scratchpad_t::scratchpad_ = nullptr; -/*thread_local*/ size_t global_scratchpad_t::size_ = 0; -/*thread_local*/ unsigned int global_scratchpad_t::reference_count_ = 0; - - -/* - Scratchpad creation routine -*/ -scratchpad_t *create_scratchpad(size_t size) { -#ifndef MKLDNN_ENABLE_CONCURRENT_EXEC - return new global_scratchpad_t(size); -#else - return new concurent_scratchpad_t(size); -#endif -} - -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/common/scratchpad.hpp b/thirdparty/oidn/mkl-dnn/src/common/scratchpad.hpp deleted file mode 100644 index f7a246bc9..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/scratchpad.hpp +++ /dev/null @@ -1,36 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
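A minimal sketch of how the scratchpad factory above is consumed (not part of the deleted sources; the helper name and the 1 MiB size are illustrative). Which concrete type create_scratchpad() returns depends on whether MKLDNN_ENABLE_CONCURRENT_EXEC is defined:

#include "scratchpad.hpp"

static void example_scratchpad() {
    mkldnn::impl::scratchpad_t *sp = mkldnn::impl::create_scratchpad(1u << 20);
    char *buf = sp->get(); // at least the requested size, aligned to the
                           // 2 MiB page_size constant above
    (void)buf;
    delete sp;             // the global variant frees only when its reference
                           // count drops to zero
}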
-*******************************************************************************/ - -#ifndef COMMON_SCRATCHPAD_HPP -#define COMMON_SCRATCHPAD_HPP - -#include "utils.hpp" - -namespace mkldnn { -namespace impl { - -struct scratchpad_t { - virtual ~scratchpad_t() {} - virtual char *get() const = 0; -}; - -scratchpad_t *create_scratchpad(size_t size); - -} -} -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/shuffle.cpp b/thirdparty/oidn/mkl-dnn/src/common/shuffle.cpp deleted file mode 100644 index e32e73522..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/shuffle.cpp +++ /dev/null @@ -1,72 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include -#include "mkldnn.h" - -#include "c_types_map.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -using namespace mkldnn::impl; -using namespace mkldnn::impl::utils; -using namespace mkldnn::impl::status; -using namespace mkldnn::impl::prop_kind; -using namespace mkldnn::impl::types; - -namespace { -status_t shuffle_desc_init(shuffle_desc_t *shuffle_desc, prop_kind_t prop_kind, - const memory_desc_t *data_desc, int axis, dim_t group_size) { - bool args_ok = true - && !any_null(shuffle_desc, data_desc) - && one_of(prop_kind, forward_training, forward_inference, - backward, backward_data) - && axis >= 0 && axis < data_desc->ndims - && group_size > 0 && group_size <= data_desc->dims[axis]; - if (!args_ok) return invalid_arguments; - - auto sd = shuffle_desc_t(); - sd.primitive_kind = primitive_kind::shuffle; - sd.prop_kind = prop_kind; - sd.data_desc = *data_desc; - sd.axis = axis; - sd.group_size = group_size; - - bool consistency = true - && sd.data_desc.dims[axis] % sd.group_size == 0; - if (!consistency) return invalid_arguments; - - *shuffle_desc = sd; - return success; -} -} - -status_t mkldnn_shuffle_forward_desc_init(shuffle_desc_t *shuffle_desc, - prop_kind_t prop_kind, const memory_desc_t *data_desc, int axis, - dim_t group_size) { - if (!one_of(prop_kind, forward_training, forward_inference)) - return invalid_arguments; - return shuffle_desc_init(shuffle_desc, prop_kind, data_desc, axis, - group_size); -} - -status_t mkldnn_shuffle_backward_desc_init(shuffle_desc_t *shuffle_desc, - const memory_desc_t *diff_data_desc, int axis, dim_t group_size) { - return shuffle_desc_init(shuffle_desc, backward_data, diff_data_desc, axis, - group_size); -} - -// vim: et ts=5 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/shuffle_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/shuffle_pd.hpp deleted file mode 100644 index cc5553fe7..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/shuffle_pd.hpp +++ /dev/null @@ -1,121 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under 
the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef SHUFFLE_PD_HPP -#define SHUFFLE_PD_HPP - -#include "mkldnn.h" - -#include "c_types_map.hpp" -#include "primitive_desc.hpp" - -namespace mkldnn { -namespace impl { - -struct shuffle_pd_t: public primitive_desc_t { - static constexpr auto base_pkind = primitive_kind::shuffle; - - typedef shuffle_pd_t base_class; - typedef shuffle_pd_t hint_class; - - shuffle_pd_t(engine_t *engine, - const shuffle_desc_t *adesc, - const primitive_attr_t *attr, - const shuffle_pd_t *hint_fwd_pd) - : primitive_desc_t(engine, attr, base_pkind) - , desc_(*adesc) - , hint_fwd_pd_(hint_fwd_pd) - , data_md_(desc_.data_desc) - {} - - const shuffle_desc_t *desc() const { return &desc_; } - virtual const op_desc_t *op_desc() const override - { return reinterpret_cast(this->desc()); } - virtual void init_info() override { impl::init_info(this, this->info_); } - - virtual status_t query(query_t what, int idx, void *result) const override { - switch (what) { - case query::shuffle_d: - *(const shuffle_desc_t**)result = desc(); break; - default: return primitive_desc_t::query(what, idx, result); - } - return status::success; - } - - virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { - if (is_fwd()) { - if (arg == MKLDNN_ARG_SRC) - return arg_usage_t::input; - - if (arg == MKLDNN_ARG_DST) - return arg_usage_t::output; - } else { - if (arg == MKLDNN_ARG_DIFF_DST) - return arg_usage_t::input; - - if (arg == MKLDNN_ARG_DIFF_SRC) - return arg_usage_t::output; - } - - return primitive_desc_t::arg_usage(arg); - } - - virtual const memory_desc_t *src_md(int index = 0) const override - { return index == 0 && is_fwd() ? &data_md_ : nullptr; } - virtual const memory_desc_t *dst_md(int index = 0) const override - { return index == 0 && is_fwd() ? &data_md_ : nullptr; } - - virtual const memory_desc_t *diff_src_md(int index = 0) const override - { return index == 0 && !is_fwd() ? &data_md_ : nullptr; } - virtual const memory_desc_t *diff_dst_md(int index = 0) const override - { return index == 0 && !is_fwd() ? &data_md_ : nullptr; } - - virtual int n_inputs() const override { return 1; } - virtual int n_outputs() const override { return 1; } - - /* shuffle aux functions */ - - dim_t MB() const { return data_md()->dims[0]; } - dim_t C() const { return ndims() >= 2 ? data_md()->dims[1] : 1; } - dim_t D() const { return ndims() >= 5 ? data_md()->dims[ndims() - 3] : 1; } - dim_t H() const { return ndims() >= 4 ? data_md()->dims[ndims() - 2] : 1; } - dim_t W() const { return ndims() >= 3 ? 
data_md()->dims[ndims() - 1] : 1; } - - int ndims() const { return data_md()->ndims; } - - int axis() const { return desc_.axis; } - dim_t group_size() const { return desc_.group_size; } - dim_t axis_size() const { return data_md()->dims[axis()]; } - - bool is_fwd() const { - return utils::one_of(desc_.prop_kind, prop_kind::forward_training, - prop_kind::forward_inference); - } - - const memory_desc_t *data_md() const { return &data_md_; } - -protected: - shuffle_desc_t desc_; - const shuffle_pd_t *hint_fwd_pd_; - memory_desc_t data_md_; -}; - -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/softmax.cpp b/thirdparty/oidn/mkl-dnn/src/common/softmax.cpp deleted file mode 100644 index 82848e3d1..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/softmax.cpp +++ /dev/null @@ -1,68 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include -#include "mkldnn.h" - -#include "c_types_map.hpp" -#include "memory_desc_wrapper.hpp" -#include "utils.hpp" - -using namespace mkldnn::impl; -using namespace mkldnn::impl::utils; -using namespace mkldnn::impl::status; -using namespace mkldnn::impl::prop_kind; -using namespace mkldnn::impl::alg_kind; -using namespace mkldnn::impl::types; - -namespace { -status_t softmax_desc_init(softmax_desc_t *softmax_desc, prop_kind_t prop_kind, - const memory_desc_t *data_desc, const memory_desc_t *diff_desc, int softmax_axis) { - bool args_ok = true - && !any_null(softmax_desc, data_desc) - && 0 <= softmax_axis - && softmax_axis < data_desc->ndims; - if (!args_ok) return invalid_arguments; - - auto sd = softmax_desc_t(); - sd.primitive_kind = primitive_kind::softmax; - sd.prop_kind = prop_kind; - - bool is_bwd = (sd.prop_kind == backward_data); - sd.data_desc = *data_desc; - sd.diff_desc = is_bwd ? 
*diff_desc : zero_md(); - sd.softmax_axis = softmax_axis; - - *softmax_desc = sd; - return success; -} -} - -status_t mkldnn_softmax_forward_desc_init(softmax_desc_t *softmax_desc, - prop_kind_t prop_kind, const memory_desc_t *data_desc, - int softmax_axis) { - if (!one_of(prop_kind, forward_inference, forward_training)) - return invalid_arguments; - return softmax_desc_init(softmax_desc, prop_kind, data_desc, nullptr, softmax_axis); -} - -status_t mkldnn_softmax_backward_desc_init(softmax_desc_t *softmax_desc, - const memory_desc_t *diff_desc, const mkldnn_memory_desc_t *data_desc, - int softmax_axis) { - return softmax_desc_init(softmax_desc, prop_kind::backward_data, - data_desc, diff_desc, softmax_axis); -} -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/softmax_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/softmax_pd.hpp deleted file mode 100644 index 8a16ce901..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/softmax_pd.hpp +++ /dev/null @@ -1,161 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef SOFTMAX_PD_HPP -#define SOFTMAX_PD_HPP - -#include "mkldnn.h" - -#include "c_types_map.hpp" -#include "primitive_desc.hpp" - -namespace mkldnn { -namespace impl { - -struct softmax_fwd_pd_t; - -struct softmax_pd_t: public primitive_desc_t { - static constexpr auto base_pkind = primitive_kind::softmax; - - softmax_pd_t(engine_t *engine, - const softmax_desc_t *adesc, - const primitive_attr_t *attr, - const softmax_fwd_pd_t *hint_fwd_pd) - : primitive_desc_t(engine, attr, base_pkind) - , desc_(*adesc) - , hint_fwd_pd_(hint_fwd_pd) - , data_md_(desc_.data_desc) - {} - - const softmax_desc_t *desc() const { return &desc_; } - virtual const op_desc_t *op_desc() const override - { return reinterpret_cast(this->desc()); } - virtual void init_info() override { impl::init_info(this, this->info_); } - - virtual status_t query(query_t what, int idx, void *result) const override { - switch (what) { - case query::softmax_d: - *(const softmax_desc_t**)result = desc(); break; - default: return primitive_desc_t::query(what, idx, result); - } - return status::success; - } - - /* common softmax aux functions */ - - dim_t MB() const { return data_desc().dims[0]; } - dim_t C() const { return data_desc().dims[1]; } - dim_t D() const { return ndims() >= 5 ? data_desc().dims[ndims() - 3] : 1; } - dim_t H() const { return ndims() >= 4 ? data_desc().dims[ndims() - 2] : 1; } - dim_t W() const { return ndims() >= 3 ? 
data_desc().dims[ndims() - 1] : 1; } - - int ndims() const { return data_desc().ndims; } - - bool is_fwd() const { - return utils::one_of(desc_.prop_kind, prop_kind::forward_training, - prop_kind::forward_inference); - } - -protected: - softmax_desc_t desc_; - const softmax_fwd_pd_t *hint_fwd_pd_; - - memory_desc_t data_md_; - -private: - const memory_desc_t &data_desc() const { return desc_.data_desc; } -}; - -struct softmax_fwd_pd_t: public softmax_pd_t { - typedef softmax_fwd_pd_t base_class; - typedef softmax_fwd_pd_t hint_class; - - softmax_fwd_pd_t(engine_t *engine, - const softmax_desc_t *adesc, - const primitive_attr_t *attr, - const softmax_fwd_pd_t *hint_fwd_pd) - : softmax_pd_t(engine, adesc, attr, hint_fwd_pd) - {} - - virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { - if (arg == MKLDNN_ARG_SRC) - return arg_usage_t::input; - - if (arg == MKLDNN_ARG_DST) - return arg_usage_t::output; - - if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr)) - return arg_usage_t::output; - - return primitive_desc_t::arg_usage(arg); - } - - virtual const memory_desc_t *src_md(int index = 0) const override - { return index == 0 ? &data_md_ : nullptr; } - virtual const memory_desc_t *dst_md(int index = 0) const override - { return index == 0 ? &data_md_ : nullptr; } - - virtual int n_inputs() const override { return 1; } - virtual int n_outputs() const override - { return 1 + (workspace_md() != nullptr); } -}; - -struct softmax_bwd_pd_t: public softmax_pd_t { - typedef softmax_bwd_pd_t base_class; - typedef softmax_fwd_pd_t hint_class; - - softmax_bwd_pd_t(engine_t *engine, - const softmax_desc_t *adesc, - const primitive_attr_t *attr, - const softmax_fwd_pd_t *hint_fwd_pd) - : softmax_pd_t(engine, adesc, attr, hint_fwd_pd) - , diff_data_md_(desc_.diff_desc) - {} - - virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { - if (utils::one_of(arg, MKLDNN_ARG_DST, MKLDNN_ARG_DIFF_DST)) - return arg_usage_t::input; - - if (arg == MKLDNN_ARG_DIFF_SRC) - return arg_usage_t::output; - - if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr)) - return arg_usage_t::input; - - return primitive_desc_t::arg_usage(arg); - } - - virtual const memory_desc_t *dst_md(int index = 0) const override - { return index == 0 ? &data_md_ : nullptr; } - virtual const memory_desc_t *diff_dst_md(int index = 0) const override - { return index == 0 ? &diff_data_md_ : nullptr; } - virtual const memory_desc_t *diff_src_md(int index = 0) const override - { return index == 0 ? &diff_data_md_ : nullptr; } - - virtual int n_inputs() const override - { return 2 + (workspace_md() != nullptr); } - virtual int n_outputs() const override { return 1; } - -protected: - memory_desc_t diff_data_md_; -}; - -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/stream.cpp b/thirdparty/oidn/mkl-dnn/src/common/stream.cpp deleted file mode 100644 index 00af8935c..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/stream.cpp +++ /dev/null @@ -1,46 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. 
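A minimal sketch of the softmax descriptor C API defined in softmax.cpp above (not part of the deleted sources; the helper name is illustrative and data_md is assumed to be an initialized memory descriptor):

#include "mkldnn.h"
#include "c_types_map.hpp"

static mkldnn::impl::status_t example_softmax_desc(
        const mkldnn::impl::memory_desc_t *data_md,
        mkldnn::impl::softmax_desc_t *out) {
    // Softmax over axis 1, i.e. the channel axis of an {N, C, H, W} tensor.
    return mkldnn_softmax_forward_desc_init(out,
            mkldnn::impl::prop_kind::forward_inference, data_md, 1);
}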
-* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include -#include "mkldnn.h" - -#include "c_types_map.hpp" -#include "engine.hpp" -#include "stream.hpp" -#include "utils.hpp" - -using namespace mkldnn::impl; -using namespace mkldnn::impl::status; - -/* API */ - -status_t mkldnn_stream_create(stream_t **stream, engine_t *engine, - unsigned flags) { - bool args_ok = true - && !utils::any_null(stream, engine) - && flags == stream_flags::default_flags; - if (!args_ok) - return invalid_arguments; - - return safe_ptr_assign(*stream, new stream_t(engine, flags)); -} - -status_t mkldnn_stream_destroy(stream_t *stream) { - delete stream; - return success; -} - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/stream.hpp b/thirdparty/oidn/mkl-dnn/src/common/stream.hpp deleted file mode 100644 index f010e5f6e..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/stream.hpp +++ /dev/null @@ -1,44 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef STREAM_HPP -#define STREAM_HPP - -#include -#include "mkldnn.h" - -#include "c_types_map.hpp" -#include "engine.hpp" - -struct mkldnn_stream: public mkldnn::impl::c_compatible { - mkldnn_stream(mkldnn::impl::engine_t *engine, unsigned flags) - : engine_(engine), flags_(flags) {} - virtual ~mkldnn_stream() {} - - /** returns stream's engine */ - mkldnn::impl::engine_t *engine() const { return engine_; } - - /** returns stream's kind */ - unsigned flags() const { return flags_; } - -protected: - mkldnn::impl::engine_t *engine_; - unsigned flags_; -}; - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/sum.cpp b/thirdparty/oidn/mkl-dnn/src/common/sum.cpp deleted file mode 100644 index 365663c0f..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/sum.cpp +++ /dev/null @@ -1,79 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. 
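A minimal sketch of the stream API above (not part of the deleted sources; the helper name is illustrative and the engine is assumed to have been created already). Only stream_flags::default_flags is accepted by mkldnn_stream_create():

#include "stream.hpp"

static void example_stream(mkldnn::impl::engine_t *engine) {
    mkldnn::impl::stream_t *stream = nullptr;
    if (mkldnn_stream_create(&stream, engine,
                mkldnn::impl::stream_flags::default_flags)
            != mkldnn::impl::status::success)
        return;
    // ... submit primitives on `stream` ...
    mkldnn_stream_destroy(stream);
}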
-* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include - -#include "mkldnn.h" - -#include "c_types_map.hpp" -#include "engine.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -#include "sum_pd.hpp" - -using namespace mkldnn::impl; -using namespace mkldnn::impl::utils; -using namespace mkldnn::impl::status; - -status_t mkldnn_sum_primitive_desc_create(primitive_desc_t **sum_pd, - const memory_desc_t *dst_md, int n, const float *scales, - const memory_desc_t *src_mds, const primitive_attr_t *attr, - engine_t *engine) { - bool args_ok = !any_null(sum_pd, src_mds, scales) && n > 0; - if (!args_ok) return invalid_arguments; - - const primitive_attr_t dummy_attr; - if (attr == NULL) - attr = &dummy_attr; - - const int ndims = src_mds[0].ndims; - const dims_t &dims = src_mds[0].dims; - const data_type_t dt = src_mds[0].data_type; - - for (int i = 1; i < n; ++i) { - if (src_mds[i].ndims != ndims) return invalid_arguments; - for (int d = 0; d < ndims; ++d) { - if (src_mds[i].dims[d] != dims[d]) - return invalid_arguments; - } - if (src_mds[i].data_type != dt) return invalid_arguments; - } - - memory_desc_t dummy_dst_md; - if (dst_md) { - if (dst_md->ndims != ndims) return invalid_arguments; - for (int d = 0; d < ndims; ++d) { - if (dst_md->dims[d] != dims[d]) - return invalid_arguments; - } - } else { - dummy_dst_md = src_mds[0]; - dummy_dst_md.format_kind = format_kind::any; - dst_md = &dummy_dst_md; - } - - auto s_pd = reinterpret_cast(sum_pd); - - for (auto s = engine->get_sum_implementation_list(); *s; ++s) { - if ((*s)(s_pd, engine, attr, dst_md, n, scales, src_mds) == success) { - (*s_pd)->init_info(); - (*s_pd)->init_scratchpad_md(); - return success; - } - } - return unimplemented; -} diff --git a/thirdparty/oidn/mkl-dnn/src/common/sum_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/sum_pd.hpp deleted file mode 100644 index 80254667d..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/sum_pd.hpp +++ /dev/null @@ -1,143 +0,0 @@ -/******************************************************************************* -* Copyright 2019 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
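A minimal sketch of the sum primitive-descriptor API above (not part of the deleted sources; the helper name, the two-input case, and the unit scales are illustrative, and the engine and source descriptors are assumed valid). Passing dst_md == nullptr lets the implementation choose the destination format, see set_default_params() in sum_pd.hpp below:

#include "mkldnn.h"
#include "c_types_map.hpp"

static mkldnn::impl::status_t example_sum(mkldnn::impl::engine_t *engine,
        const mkldnn::impl::memory_desc_t src_mds[2]) {
    const float scales[2] = { 1.0f, 1.0f };
    mkldnn::impl::primitive_desc_t *sum_pd = nullptr;
    return mkldnn_sum_primitive_desc_create(&sum_pd, nullptr /* dst_md */,
            2, scales, src_mds, nullptr /* attr: defaults */, engine);
}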
-*******************************************************************************/ - -#ifndef SUM_PD_HPP -#define SUM_PD_HPP - -#include -#include "mkldnn.h" - -#include "c_types_map.hpp" -#include "nstl.hpp" -#include "primitive_desc.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -namespace mkldnn { -namespace impl { - -struct sum_pd_t: public primitive_desc_t { - sum_pd_t(engine_t *engine, const primitive_attr_t *attr, - const memory_desc_t *dst_md, int n, const float *scales, - const memory_desc_t *src_mds) - : primitive_desc_t(engine, attr, primitive_kind::sum) - , n_(n), dst_md_(*dst_md) - { - scales_.reserve(n_); - for (int i = 0; i < n_; ++i) scales_.push_back(scales[i]); - src_mds_.reserve(n_); - for (int i = 0; i < n_; ++i) src_mds_.push_back(src_mds[i]); - } - - virtual void init_info() override { impl::init_info(this, this->info_); } - - virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { - if (arg >= MKLDNN_ARG_MULTIPLE_SRC - && arg < MKLDNN_ARG_MULTIPLE_SRC + n_inputs()) - return arg_usage_t::input; - - if (arg == MKLDNN_ARG_DST) - return arg_usage_t::output; - - return primitive_desc_t::arg_usage(arg); - } - - virtual const memory_desc_t *src_md(int index = 0) const override - { return index < n_inputs() ? &src_mds_[index] : nullptr; } - virtual const memory_desc_t *dst_md(int index = 0) const override - { return index == 0 ? &dst_md_ : nullptr; } - - virtual int n_inputs() const override { return n_; } - virtual int n_outputs() const override { return 1; } - - const float *scales() const { return &scales_[0]; } - -protected: - int n_; - nstl::vector scales_; - memory_desc_t dst_md_; - nstl::vector src_mds_; - -protected: - /* inits dst_md_ in simple cases. The call may fail. */ - status_t init() { - for (int i = 0; i < n_; ++i) { - const memory_desc_wrapper src_d(&src_mds_[i]); - if (!src_d.is_blocking_desc() || src_d.is_additional_buffer()) - return status::unimplemented; - } - bool ok = true - && set_default_params() == status::success - && attr()->has_default_values(); - return ok ? status::success : status::unimplemented; - } - - status_t set_default_params() { - if (dst_md_.format_kind != format_kind::any) - return status::success; - - /* The stupidest ever heuristics (but not the same as we had before): - * - Pick the first non-plain format; - * - If all formats are plain, pick the format of the first input - */ - for (int i = 0; i < n_; ++i) { - const memory_desc_wrapper src_d(src_mds_[i]); - if (!src_d.is_plain() && src_d.is_blocking_desc()) { - return memory_desc_init_by_blocking_desc(dst_md_, - src_d.blocking_desc()); - } - } - - if (src_mds_[0].format_kind != format_kind::blocked) - return status::unimplemented; - - dst_md_ = src_mds_[0]; - - return status::success; - } -}; - -#define DECLARE_SUM_PD_t(impl_name, ...) 
\ - static status_t create(sum_pd_t **sum_pd, \ - engine_t *engine, const primitive_attr_t *attr, \ - const memory_desc_t *dst_md, int n, const float *scales, \ - const memory_desc_t *src_mds) { \ - using namespace status; \ - auto _pd = new pd_t(engine, attr, dst_md, n, scales, src_mds); \ - if (_pd == nullptr) return out_of_memory; \ - if (_pd->init() != success) { delete _pd; return unimplemented; } \ - return safe_ptr_assign(*sum_pd, _pd); \ - } \ - virtual status_t create_primitive(primitive_t **p) const override { \ - double ms = get_msec(); \ - auto ret = safe_ptr_assign(*p, new (__VA_ARGS__)(this)); \ - ms = get_msec() - ms; \ - if (mkldnn_verbose()->level >= 2) { \ - printf("mkldnn_verbose,create,%s,%g\n", this->info(), ms); \ - fflush(0); \ - } \ - return ret; \ - } \ - virtual pd_t *clone() const override { return new pd_t(*this); } \ - virtual const char *name() const override { return impl_name; } \ - -#define DECLARE_SUM_PD_T(impl_name, ...) \ - DECLARE_SUM_PD_t(impl_name, __VA_ARGS__) - -} -} - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/tag_traits.hpp b/thirdparty/oidn/mkl-dnn/src/common/tag_traits.hpp deleted file mode 100644 index a408f4598..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/tag_traits.hpp +++ /dev/null @@ -1,200 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef TAG_TRAITS_HPP -#define TAG_TRAITS_HPP - -#include - -#include "c_types_map.hpp" -#include "utils.hpp" - -namespace mkldnn { -namespace impl { - -enum class block_dim_t { - _, - _A, _B, - _AB, _BC, -}; - -enum class inner_blk_t { - _, - _4a, _4b, - _8a, _8b, - _16a, _16b, - - _4b4a, _4b4c, _4c4b, - _8a8b, _8b8a, _8b8c, _8c8b, - _16a16b, _16a4b, _16b16a, _16b4c, _16b16c, _16c16b, - - _2c8b4c, _8a16b2a, _4b16a4b, _8b16a2b, _8b16c2b, _4c16b4c, _8c16b2c, -}; - -/** returns the offset within the block for weights blocked over oc and ic */ -template -constexpr int AB_or_BC_blk_off(int x0, int x1) { - using ib = inner_blk_t; - static_assert(utils::one_of(f, ib::_4b4a, ib::_4b4c, ib::_4c4b, ib::_8a8b, - ib::_8b8a, ib::_8b8c, ib::_8c8b, ib::_16a16b, ib::_16a4b, - ib::_16b16a, ib::_16b4c, ib::_16b16c, ib::_16c16b, ib::_2c8b4c, - ib::_8a16b2a, ib::_4b16a4b, ib::_8b16a2b, ib::_8b16c2b, - ib::_4c16b4c, ib::_8c16b2c), - "unexpected inner_blk format"); - return false ? 0 - : (f == ib::_4b4c) ? 4 * x0 + x1 - : (f == ib::_4b4a || f == ib::_4c4b) ? 4 * x1 + x0 - : (f == ib::_8a8b || f == ib::_8b8c) ? 8 * x0 + x1 - : (f == ib::_8b8a || f == ib::_8c8b) ? 8 * x1 + x0 - : (f == ib::_16a16b || f == ib::_16b16c) ? 16 * x0 + x1 - : (f == ib::_16b16a || f == ib::_16c16b) ? 16 * x1 + x0 - : (f == ib::_16a4b || f == ib::_16b4c) ? 4 * x0 + x1 - : (f == ib::_8a16b2a || f == ib::_8b16c2b) ? (x0 / 2) * 32 + x1 * 2 + x0 % 2 - : (f == ib::_4b16a4b || f == ib::_4c16b4c) ? 
(x1 / 4) * 64 + x0 * 4 + x1 % 4 - : (f == ib::_8b16a2b || f == ib::_8c16b2c) ? (x1 / 2) * 32 + x0 * 2 + x1 % 2 - : (f == ib::_2c8b4c) ? (x1 / 4) * 32 + x0 * 4 + x1 % 4 - : INT_MIN; -} - -template struct inner_blk_traits { - using ib = inner_blk_t; -}; - -template struct tag_traits { - // block_dim_t block_dims; - // inner_blk_t inner_blks; - // int ndims; -}; - -#define DECL_TRAITS(_tag, _blk_fmt, _inner_blk, _ndims) \ -template <> struct tag_traits { \ - static constexpr block_dim_t block_dims = block_dim_t::_blk_fmt; \ - static constexpr inner_blk_t inner_blks = inner_blk_t::_inner_blk; \ - static constexpr int ndims = _ndims; \ -} - -DECL_TRAITS(a, _, _, 1); -DECL_TRAITS(ab, _, _, 2); -DECL_TRAITS(abc, _, _, 3); -DECL_TRAITS(abcd, _, _, 4); -DECL_TRAITS(abcde, _, _, 5); -DECL_TRAITS(abcdef, _, _, 6); -DECL_TRAITS(abdec, _, _, 5); -DECL_TRAITS(acb, _, _, 3); -DECL_TRAITS(acbde, _, _, 5); -DECL_TRAITS(acdb, _, _, 4); -DECL_TRAITS(acdeb, _, _, 5); -DECL_TRAITS(ba, _, _, 2); -DECL_TRAITS(bac, _, _, 3); -DECL_TRAITS(bacd, _, _, 4); -DECL_TRAITS(bcda, _, _, 4); -DECL_TRAITS(cba, _, _, 3); -DECL_TRAITS(cdba, _, _, 4); -DECL_TRAITS(cdeba, _, _, 5); -DECL_TRAITS(decab, _, _, 5); - -DECL_TRAITS(Abc4a, _A, _4a, 3); -DECL_TRAITS(aBc4b, _B, _4b, 3); -DECL_TRAITS(ABc4b16a4b, _AB, _4b16a4b, 3); -DECL_TRAITS(ABc4b4a, _AB, _4b4a, 3); -DECL_TRAITS(Abcd4a, _A, _4a, 4); -DECL_TRAITS(aBcd4b, _B, _4b, 4); -DECL_TRAITS(ABcd4b4a, _AB, _4b4a, 4); -DECL_TRAITS(aBCd4c16b4c, _BC, _4c16b4c, 4); -DECL_TRAITS(aBCd4c4b, _BC, _4c4b, 4); -DECL_TRAITS(Abcde4a, _A, _4a, 5); -DECL_TRAITS(aBcde4b, _B, _4b, 5); -DECL_TRAITS(ABcde4b4a, _AB, _4b4a, 5); -DECL_TRAITS(aBCde4c4b, _BC, _4c4b, 5); -DECL_TRAITS(aBcdef4b, _B, _4b, 6); -DECL_TRAITS(aBCdef4c4b, _BC, _4c4b, 6); -DECL_TRAITS(aBdc4b, _B, _4b, 4); -DECL_TRAITS(aBdec4b, _B, _4b, 5); -DECL_TRAITS(aBdefc4b, _B, _4b, 6); -DECL_TRAITS(Acb4a, _A, _4a, 3); -DECL_TRAITS(Acdb4a, _A, _4a, 4); -DECL_TRAITS(Acdeb4a, _A, _4a, 5); - -DECL_TRAITS(Abc16a, _A, _16a, 3); -DECL_TRAITS(ABc16a16b, _AB, _16a16b, 3); -DECL_TRAITS(aBc16b, _B, _16b, 3); -DECL_TRAITS(ABc16b16a, _AB, _16b16a, 3); -DECL_TRAITS(ABc8a16b2a, _AB, _8a16b2a, 3); -DECL_TRAITS(ABc8a8b, _AB, _8a8b, 3); -DECL_TRAITS(aBc8b, _B, _8b, 3); -DECL_TRAITS(ABc8b16a2b, _AB, _8b16a2b, 3); -DECL_TRAITS(ABc8b8a, _AB, _8b8a, 3); -DECL_TRAITS(Abcd16a, _A, _16a, 4); -DECL_TRAITS(ABcd16a16b, _AB, _16a16b, 4); -DECL_TRAITS(aBcd16b, _B, _16b, 4); -DECL_TRAITS(ABcd16b16a, _AB, _16b16a, 4); -DECL_TRAITS(aBCd16b16c, _BC, _16b16c, 4); -DECL_TRAITS(aBCd16c16b, _BC, _16c16b, 4); -DECL_TRAITS(ABcd4b16a4b, _AB, _4b16a4b, 4); -DECL_TRAITS(ABcd8a16b2a, _AB, _8a16b2a, 4); -DECL_TRAITS(ABcd8a8b, _AB, _8a8b, 4); -DECL_TRAITS(aBcd8b, _B, _8b, 4); -DECL_TRAITS(ABcd8b16a2b, _AB, _8b16a2b, 4); -DECL_TRAITS(aBCd8b16c2b, _BC, _8b16c2b, 4); -DECL_TRAITS(ABcd8b8a, _AB, _8b8a, 4); -DECL_TRAITS(aBCd8b8c, _BC, _8b8c, 4); -DECL_TRAITS(aBCd8c16b2c, _BC, _8c16b2c, 4); -DECL_TRAITS(aBCd8c8b, _BC, _8c8b, 4); -DECL_TRAITS(Abcde16a, _A, _16a, 5); -DECL_TRAITS(ABcde16a16b, _AB, _16a16b, 5); -DECL_TRAITS(aBcde16b, _B, _16b, 5); -DECL_TRAITS(ABcde16b16a, _AB, _16b16a, 5); -DECL_TRAITS(aBCde16b16c, _BC, _16b16c, 5); -DECL_TRAITS(aBCde16c16b, _BC, _16c16b, 5); -DECL_TRAITS(aBCde4c16b4c, _BC, _4c16b4c, 5); -DECL_TRAITS(Abcde8a, _A, _8a, 5); -DECL_TRAITS(ABcde8a8b, _AB, _8a8b, 5); -DECL_TRAITS(aBcde8b, _B, _8b, 5); -DECL_TRAITS(ABcde8b16a2b, _AB, _8b16a2b, 5); -DECL_TRAITS(aBCde8b16c2b, _BC, _8b16c2b, 5); -DECL_TRAITS(ABcde8b8a, _AB, _8b8a, 5); -DECL_TRAITS(aBCde8b8c, _BC, _8b8c, 
5); -DECL_TRAITS(aBCde2c8b4c, _BC, _2c8b4c, 5); -DECL_TRAITS(aBCde8c16b2c, _BC, _8c16b2c, 5); -DECL_TRAITS(aBCde4b4c, _BC, _4b4c, 5); -DECL_TRAITS(aBCde8c8b, _BC, _8c8b, 5); -DECL_TRAITS(aBcdef16b, _B, _16b, 6); -DECL_TRAITS(aBCdef16b16c, _BC, _16b16c, 6); -DECL_TRAITS(aBCdef16c16b, _BC, _16c16b, 6); -DECL_TRAITS(aBCdef8b8c, _BC, _8b8c, 6); -DECL_TRAITS(aBCdef8c16b2c, _BC, _8c16b2c, 6); -DECL_TRAITS(aBCdef8c8b, _BC, _8c8b, 6); -DECL_TRAITS(aBdc16b, _B, _16b, 4); -DECL_TRAITS(aBdc8b, _B, _8b, 4); -DECL_TRAITS(aBdec16b, _B, _16b, 5); -DECL_TRAITS(aBdec8b, _B, _8b, 5); -DECL_TRAITS(aBdefc16b, _B, _16b, 6); -DECL_TRAITS(aBdefc8b, _B, _8b, 6); -DECL_TRAITS(Acb16a, _A, _16a, 3); -DECL_TRAITS(Acb8a, _A, _8a, 3); -DECL_TRAITS(aCBd16b16c, _BC, _16b16c, 4); -DECL_TRAITS(aCBde16b16c, _BC, _16b16c, 5); -DECL_TRAITS(Acdb16a, _A, _16a, 4); -DECL_TRAITS(Acdb8a, _A, _8a, 4); -DECL_TRAITS(Acdeb16a, _A, _16a, 5); -DECL_TRAITS(Acdeb8a, _A, _8a, 5); -DECL_TRAITS(BAc16a16b, _AB, _16a16b, 3); -DECL_TRAITS(BAcd16a16b, _AB, _16a16b, 4); - -} // namespace impl -} // namespace mkldnn - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/type_helpers.hpp b/thirdparty/oidn/mkl-dnn/src/common/type_helpers.hpp deleted file mode 100644 index 4f0636873..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/type_helpers.hpp +++ /dev/null @@ -1,348 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
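A minimal sketch of what the tag traits above encode, using the 4-D blocked tag ABcd16b16a (not part of the deleted sources; the helper name and the sample indices are illustrative):

#include "tag_traits.hpp"

static void example_tag_traits() {
    using namespace mkldnn::impl;
    typedef tag_traits<format_tag::ABcd16b16a> tt;
    static_assert(tt::ndims == 4, "4-D tag");
    static_assert(tt::inner_blks == inner_blk_t::_16b16a, "16x16 inner block");
    // Element (a = 3, b = 5) within one 16x16 inner block: b is the outer index.
    static_assert(AB_or_BC_blk_off<inner_blk_t::_16b16a>(3, 5) == 16 * 5 + 3,
            "offset inside a 16b16a block");
}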
-*******************************************************************************/ - -#ifndef TYPE_HELPERS_HPP -#define TYPE_HELPERS_HPP - -#include -#include - -#include "mkldnn.h" - -#include "c_types_map.hpp" -#include "mkldnn_traits.hpp" -#include "nstl.hpp" -#include "utils.hpp" -#include "math_utils.hpp" - -namespace mkldnn { -namespace impl { - -template -status_t safe_ptr_assign(T * &lhs, T* rhs) { - if (rhs == nullptr) return status::out_of_memory; - lhs = rhs; - return status::success; -} - -template struct is_subset -{ static constexpr bool value = false; }; -template struct is_subset -{ static constexpr bool value = true; }; -template struct is_subset::value, float>::type> -{ static constexpr bool value = true; }; -#define ISSPEC(t1, t2) template <> \ - struct is_subset { static constexpr bool value = true; } -ISSPEC(int16_t, int32_t); -ISSPEC(int8_t, int32_t); -ISSPEC(uint8_t, int32_t); -ISSPEC(int8_t, int16_t); -ISSPEC(uint8_t, int16_t); -#undef ISSPEC - -inline bool operator==(const memory_desc_t &lhs, const memory_desc_t &rhs); - -namespace types { - -inline size_t data_type_size(data_type_t data_type) { - using namespace data_type; - switch (data_type) { - case f32: return sizeof(prec_traits::type); - case s32: return sizeof(prec_traits::type); - case s8: return sizeof(prec_traits::type); - case u8: return sizeof(prec_traits::type); - case data_type::undef: - default: assert(!"unknown data_type"); - } - return 0; /* not supposed to be reachable */ -} - -inline format_kind_t format_tag_to_kind(format_tag_t tag) { - switch (tag) { - case format_tag::undef: return format_kind::undef; - case format_tag::any: return format_kind::any; - case format_tag::last: return format_kind::undef; - default: return format_kind::blocked; - } - - assert(!"unreachable"); - return format_kind::undef; -} - -inline bool memory_extra_desc_is_equal(const memory_extra_desc_t &lhs, - const memory_extra_desc_t &rhs) { - return true - && lhs.flags == rhs.flags - && IMPLICATION(lhs.flags & memory_extra_flags::compensation_conv_s8s8, - lhs.compensation_mask == rhs.compensation_mask) - && IMPLICATION(lhs.flags & memory_extra_flags::scale_adjust, - lhs.scale_adjust == rhs.scale_adjust); -} - -inline bool blocking_desc_is_equal(const blocking_desc_t &lhs, - const blocking_desc_t &rhs, int ndims = MKLDNN_MAX_NDIMS) { - using mkldnn::impl::utils::array_cmp; - return true - && lhs.inner_nblks == rhs.inner_nblks - && array_cmp(lhs.strides, rhs.strides, ndims) - && array_cmp(lhs.inner_blks, rhs.inner_blks, lhs.inner_nblks) - && array_cmp(lhs.inner_idxs, rhs.inner_idxs, lhs.inner_nblks); -} - -inline bool wino_desc_is_equal(const wino_desc_t &lhs, - const wino_desc_t &rhs) { - return lhs.wino_format == rhs.wino_format - && lhs.alpha == rhs.alpha - && lhs.ic == rhs.ic - && lhs.oc == rhs.oc - && lhs.ic_block == rhs.ic_block - && lhs.oc_block == rhs.oc_block - && lhs.ic2_block == rhs.ic2_block - && lhs.oc2_block == rhs.oc2_block - && lhs.r == rhs.r; -} - -inline bool rnn_packed_desc_is_equal( - const rnn_packed_desc_t &lhs, const rnn_packed_desc_t &rhs) { - bool ok = true - && lhs.format == rhs.format - && lhs.n_parts == rhs.n_parts - && lhs.offset_compensation == rhs.offset_compensation - && lhs.size == rhs.size - && lhs.n == rhs.n; - if (!ok) - return false; - - for (int i = 0; i < rhs.n_parts; i++) - ok = ok && lhs.parts[i] == rhs.parts[i]; - for (int i = 0; i < rhs.n_parts; i++) - ok = ok && lhs.part_pack_size[i] == rhs.part_pack_size[i]; - return ok; -} - -inline memory_desc_t zero_md() { - auto zero = 
memory_desc_t(); - return zero; -} - -inline bool is_zero_md(const memory_desc_t *md) { - return md == nullptr || *md == zero_md(); -} - -inline data_type_t default_accum_data_type(data_type_t src_dt, - data_type_t dst_dt) { - using namespace utils; - using namespace data_type; - - if (one_of(f32, src_dt, dst_dt)) return f32; - if (one_of(s32, src_dt, dst_dt)) return s32; - - if (one_of(s8, src_dt, dst_dt) || one_of(u8, src_dt, dst_dt)) return s32; - - assert(!"unimplemented use-case: no default parameters available"); - return dst_dt; -} - -inline data_type_t default_accum_data_type(data_type_t src_dt, - data_type_t wei_dt, data_type_t dst_dt, prop_kind_t prop_kind) { - using namespace utils; - using namespace data_type; - using namespace prop_kind; - - /* prop_kind doesn't matter */ - if (everyone_is(f32, src_dt, wei_dt, dst_dt)) return f32; - - if (one_of(prop_kind, forward_training, forward_inference)) { - if ((src_dt == u8 || src_dt == s8) - && wei_dt == s8 && one_of(dst_dt, f32, s32, s8, u8)) - return s32; - } else if (prop_kind == backward_data) { - if (one_of(src_dt, f32, s32, s8, u8) && wei_dt == s8 && - one_of(dst_dt, s8, u8)) - return s32; - } - - assert(!"unimplemented use-case: no default parameters available"); - return dst_dt; -} - -} - -inline bool operator==(const memory_desc_t &lhs, const memory_desc_t &rhs) { - using namespace mkldnn::impl::utils; - bool base_equal = true - && lhs.ndims == rhs.ndims - && array_cmp(lhs.dims, rhs.dims, lhs.ndims) - && lhs.data_type == rhs.data_type - && array_cmp(lhs.padded_dims, rhs.padded_dims, lhs.ndims) - && array_cmp(lhs.padded_offsets, rhs.padded_offsets, lhs.ndims) - && lhs.offset0 == rhs.offset0 - && lhs.format_kind == rhs.format_kind; - if (!base_equal) return false; - if (!types::memory_extra_desc_is_equal(lhs.extra, rhs.extra)) return false; - if (lhs.format_kind == format_kind::blocked) - return types::blocking_desc_is_equal(lhs.format_desc.blocking, - rhs.format_desc.blocking, lhs.ndims); - else if (lhs.format_kind == format_kind::wino) - return types::wino_desc_is_equal(lhs.format_desc.wino_desc, - rhs.format_desc.wino_desc); - else if (lhs.format_kind == format_kind::rnn_packed) - return types::rnn_packed_desc_is_equal(lhs.format_desc.rnn_packed_desc, - rhs.format_desc.rnn_packed_desc); - return true; -} - -inline bool operator!=(const memory_desc_t &lhs, const memory_desc_t &rhs) { - return !operator==(lhs, rhs); -} - -inline status_t memory_desc_init_by_strides(memory_desc_t &md, - const dims_t strides) { - return mkldnn_memory_desc_init_by_strides( - &md, md.ndims, md.dims, md.data_type, strides); -} - -inline status_t memory_desc_init_by_tag(memory_desc_t &md, format_tag_t tag, - const dims_t strides = nullptr) { - status_t status = mkldnn_memory_desc_init_by_tag( - &md, md.ndims, md.dims, md.data_type, tag); - if (status != status::success || strides == nullptr) - return status; - - /* TODO: add consistency check */ - - for (int d = 0; d < md.ndims; ++d) - md.format_desc.blocking.strides[d] = strides[d]; - - return status::success; -} - -/** inits memory descriptor based on logical dimensions kept in @p md, and the - * blocking structure @p blk. 
- * - * @note blk.strides represent the order only (from smaller to bigger) - * - * TODO: move md related functions to one single place - */ -inline status_t memory_desc_init_by_blocking_desc(memory_desc_t &md, - const blocking_desc_t &blk) { - dims_t blocks = {0}; - utils::array_set(blocks, 1, md.ndims); - dim_t block_size = 1; - for (int iblk = 0; iblk < blk.inner_nblks; ++iblk) { - blocks[blk.inner_idxs[iblk]] *= blk.inner_blks[iblk]; - block_size *= blk.inner_blks[iblk]; - } - - for (int d = 0; d < md.ndims; ++d) { - md.padded_dims[d] = utils::rnd_up(md.dims[d], blocks[d]); - md.padded_offsets[d] = 0; - } - md.offset0 = 0; - - md.format_kind = format_kind::blocked; - auto &mblk = md.format_desc.blocking; - mblk = blk; - - const int ndims = nstl::min(MKLDNN_MAX_NDIMS, md.ndims); // make GCC 5 happy - utils::array_copy(mblk.strides, blk.strides, ndims); - - int perm[MKLDNN_MAX_NDIMS]; - for (int d = 0; d < ndims; ++d) perm[d] = d; - - utils::simultaneous_sort(mblk.strides, perm, ndims, - [](stride_t a, stride_t b) { return b - a; }); - - dim_t stride = block_size; - for (int _d = ndims - 1; _d >= 0; --_d) { - const int d = perm[_d]; - md.format_desc.blocking.strides[d] = stride; - stride *= md.padded_dims[d] / blocks[d]; - } - - md.extra = utils::zero(); - - return status::success; -} - -/** returns true if memory desc @p md corresponds to the given format tag and - * strides. - * If strides are not passed (or passed as nullptr) the dense structure is - * assumed (i.e. the one that mkldnn_memory_desc_init_by_tag() returns). - * Strides might contain `0` value, indicating the stride must match the one - * that mkldnn_memory_desc_init_by_tag() returns. - * Strides might contain `-1` values, that would be ignored during the - * comparison. For instance, this can be used if a stride along minibatch - * doesn't matter. */ -inline bool memory_desc_matches_tag(const memory_desc_t &md, format_tag_t tag, - const dims_t strides = nullptr) { - if (md.format_kind != types::format_tag_to_kind(tag)) - return false; - - memory_desc_t md_gold; - status_t status = mkldnn_memory_desc_init_by_tag( - &md_gold, md.ndims, md.dims, md.data_type, tag); - if (status != status::success) return false; - - if (md.format_kind != format_kind::blocked) - return false; // unimplemented yet - - const auto &blk = md.format_desc.blocking; - const auto &blk_gold = md_gold.format_desc.blocking; - - using utils::array_cmp; - bool same_blocks = true - && blk.inner_nblks == blk_gold.inner_nblks - && array_cmp(blk.inner_blks, blk_gold.inner_blks, blk.inner_nblks) - && array_cmp(blk.inner_idxs, blk_gold.inner_idxs, blk.inner_nblks); - - if (!same_blocks) - return false; - - if (strides == nullptr) - return array_cmp(blk.strides, blk_gold.strides, md.ndims); - - for (int d = 0; d < md.ndims; ++d) { - dim_t stride = strides[d]; - if (stride == -1) continue; - if (stride == 0) stride = blk_gold.strides[d]; - if (blk.strides[d] != stride) return false; - } - - return true; -} - -/** returns matching tag (or undef if match is not found) - * XXX: This is a workaround that eventually should go away! 
*/ -template -format_tag_t memory_desc_matches_one_of_tag(const memory_desc_t &md, - Tags ...tags) { - for (const auto tag: {tags...}) { - if (memory_desc_matches_tag(md, tag)) - return tag; - } - return format_tag::undef; -} - -} -} - -#include "memory_desc_wrapper.hpp" - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/utils.cpp b/thirdparty/oidn/mkl-dnn/src/common/utils.cpp deleted file mode 100644 index d23f4682d..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/utils.cpp +++ /dev/null @@ -1,135 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include -#ifdef _WIN32 -#include -#include -#endif -#include -#include -#include - -#include "mkldnn.h" -#include "utils.hpp" - -namespace mkldnn { -namespace impl { - -int getenv(const char *name, char *buffer, int buffer_size) { - if (name == NULL || buffer_size < 0 || (buffer == NULL && buffer_size > 0)) - return INT_MIN; - - int result = 0; - int term_zero_idx = 0; - size_t value_length = 0; - -#ifdef _WIN32 - value_length = GetEnvironmentVariable(name, buffer, buffer_size); -#else - const char *value = ::getenv(name); - value_length = value == NULL ? 0 : strlen(value); -#endif - - if (value_length > INT_MAX) - result = INT_MIN; - else { - int int_value_length = (int)value_length; - if (int_value_length >= buffer_size) { - result = -int_value_length; - } else { - term_zero_idx = int_value_length; - result = int_value_length; -#ifndef _WIN32 - strncpy(buffer, value, value_length); -#endif - } - } - - if (buffer != NULL) - buffer[term_zero_idx] = '\0'; - return result; -} - -int getenv_int(const char *name, int default_value) -{ - int value = default_value; - // # of digits in the longest 32-bit signed int + sign + terminating null - const int len = 12; - char value_str[len]; - if (getenv(name, value_str, len) > 0) - value = atoi(value_str); - return value; -} - -FILE *fopen(const char *filename, const char *mode) { -#ifdef _WIN32 - FILE *fp = NULL; - return ::fopen_s(&fp, filename, mode) ? NULL : fp; -#else - return ::fopen(filename, mode); -#endif -} - -void *malloc(size_t size, int alignment) { - void *ptr; - -#ifdef _WIN32 - ptr = _aligned_malloc(size, alignment); - int rc = ptr ? 0 : -1; -#else - int rc = ::posix_memalign(&ptr, alignment, size); -#endif - - return (rc == 0) ? 
ptr : 0; -} - -void free(void *p) { -#ifdef _WIN32 - _aligned_free(p); -#else - ::free(p); -#endif -} - -// Atomic operations -int32_t fetch_and_add(int32_t *dst, int32_t val) { -#ifdef _WIN32 - return InterlockedExchangeAdd(reinterpret_cast(dst), val); -#else - return __sync_fetch_and_add(dst, val); -#endif -} - -static int jit_dump_flag = 0; -static bool jit_dump_flag_initialized = false; -bool jit_dump_enabled() { - if (!jit_dump_flag_initialized) { - jit_dump_flag = getenv_int("MKLDNN_JIT_DUMP"); - jit_dump_flag_initialized = true; - } - return jit_dump_flag != 0; -} - -} -} - -mkldnn_status_t mkldnn_set_jit_dump(int enabled) { - using namespace mkldnn::impl::status; - mkldnn::impl::jit_dump_flag = enabled; - mkldnn::impl::jit_dump_flag_initialized = true; - return success; -} diff --git a/thirdparty/oidn/mkl-dnn/src/common/utils.hpp b/thirdparty/oidn/mkl-dnn/src/common/utils.hpp deleted file mode 100644 index d5a8ec513..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/utils.hpp +++ /dev/null @@ -1,370 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef UTILS_HPP -#define UTILS_HPP - -#include -#include -#include -#include -#include - -#if defined(__x86_64__) || defined(_M_X64) -#define MKLDNN_X86_64 -#endif - -#define MSAN_ENABLED 0 -#if defined(__has_feature) -#if __has_feature(memory_sanitizer) -#undef MSAN_ENABLED -#define MSAN_ENABLED 1 -#include -#endif -#endif - -#include "c_types_map.hpp" -#include "nstl.hpp" -#include "z_magic.hpp" - -namespace mkldnn { -namespace impl { - -// Sanity check for 64 bits -static_assert(sizeof(void*) == 8, "Intel(R) MKL-DNN supports 64 bit only"); - -#define CHECK(f) do { \ - status_t status = f; \ - if (status != status::success) \ - return status; \ -} while (0) - -#define IMPLICATION(cause, effect) (!(cause) || !!(effect)) - -namespace utils { - -/* a bunch of std:: analogues to be compliant with any msvs version - * - * Rationale: msvs c++ (and even some c) headers contain special pragma that - * injects msvs-version check into object files in order to abi-mismatches - * during the static linking. This makes sense if e.g. std:: objects are passed - * through between application and library, which is not the case for mkl-dnn - * (since there is no any c++-rt dependent stuff, ideally...). 
*/ - -/* SFINAE helper -- analogue to std::enable_if */ -template struct enable_if {}; -template struct enable_if { typedef T type; }; - -/* analogue std::conditional */ -template struct conditional {}; -template struct conditional -{ typedef T type; }; -template struct conditional -{ typedef F type; }; - -template struct conditional3 {}; -template -struct conditional3 { typedef T type; }; -template -struct conditional3 { typedef FT type; }; -template -struct conditional3 { typedef FF type; }; - -template struct conditional_v {}; -template struct conditional_v -{ static constexpr U value = t; }; -template struct conditional_v -{ static constexpr U value = f; }; - -template struct remove_reference { typedef T type; }; -template struct remove_reference { typedef T type; }; -template struct remove_reference { typedef T type; }; - -template -inline T&& forward(typename utils::remove_reference::type &t) -{ return static_cast(t); } -template -inline T&& forward(typename utils::remove_reference::type &&t) -{ return static_cast(t); } - -template -inline typename remove_reference::type zero() -{ auto zero = typename remove_reference::type(); return zero; } - -template -inline bool everyone_is(T val, P item) { return val == item; } -template -inline bool everyone_is(T val, P item, Args... item_others) { - return val == item && everyone_is(val, item_others...); -} - -template -constexpr bool one_of(T val, P item) { return val == item; } -template -constexpr bool one_of(T val, P item, Args... item_others) { - return val == item || one_of(val, item_others...); -} - -template -inline bool any_null(Args... ptrs) { return one_of(nullptr, ptrs...); } - -template -inline void array_copy(T *dst, const T *src, size_t size) { - for (size_t i = 0; i < size; ++i) dst[i] = src[i]; -} -template -inline bool array_cmp(const T *a1, const T *a2, size_t size) { - for (size_t i = 0; i < size; ++i) if (a1[i] != a2[i]) return false; - return true; -} -template -inline void array_set(T *arr, const U& val, size_t size) { - for (size_t i = 0; i < size; ++i) arr[i] = static_cast(val); -} - -namespace product_impl { -template struct int2type{}; - -template -constexpr int product_impl(const T *arr, int2type<0>) { return arr[0]; } - -template -inline T product_impl(const T *arr, int2type) { - return arr[0]*product_impl(arr+1, int2type()); } -} - -template -inline T array_product(const T *arr) { - return product_impl::product_impl(arr, product_impl::int2type()); -} - -template -inline R array_product(const T *arr, size_t size) { - R prod = 1; - for (size_t i = 0; i < size; ++i) prod *= arr[i]; - return prod; -} - -/** sorts an array of values using @p comparator. While sorting the array - * of value, the function permutes an array of @p keys accordingly. - * - * @note The arrays of @p keys can be omitted. In this case the function - * sorts the array of @vals only. 
- */ -template -inline void simultaneous_sort(T *vals, U *keys, size_t size, F comparator) { - if (size == 0) return; - - for (size_t i = 0; i < size - 1; ++i) { - bool swapped = false; - - for (size_t j = 0; j < size - i - 1; j++) { - if (comparator(vals[j], vals[j + 1]) > 0) { - nstl::swap(vals[j], vals[j + 1]); - if (keys) nstl::swap(keys[j], keys[j + 1]); - swapped = true; - } - } - - if (swapped == false) break; - } -} - -template -inline typename remove_reference::type div_up(const T a, const U b) { - assert(b); - return (a + b - 1) / b; -} - -template -inline typename remove_reference::type rnd_up(const T a, const U b) { - return div_up(a, b) * b; -} - -template -inline typename remove_reference::type rnd_dn(const T a, const U b) { - return (a / b) * b; -} - -template T *align_ptr(T *ptr, uintptr_t alignment) -{ return (T *)(((uintptr_t)ptr + alignment - 1) & ~(alignment - 1)); } - -template -inline U this_block_size(const T offset, const U max, const V block_size) { - assert(offset < max); - // TODO (Roma): can't use nstl::max() due to circular dependency... we - // need to fix this - const T block_boundary = offset + block_size; - if (block_boundary > max) - return max - offset; - else - return block_size; -} - -template -inline T nd_iterator_init(T start) { return start; } -template -inline T nd_iterator_init(T start, U &x, const W &X, Args &&... tuple) { - start = nd_iterator_init(start, utils::forward(tuple)...); - x = start % X; - return start / X; -} - -inline bool nd_iterator_step() { return true; } -template -inline bool nd_iterator_step(U &x, const W &X, Args &&... tuple) { - if (nd_iterator_step(utils::forward(tuple)...) ) { - x = (x + 1) % X; - return x == 0; - } - return false; -} - -template -inline bool nd_iterator_jump(U &cur, const U end, W &x, const Y &X) -{ - U max_jump = end - cur; - U dim_jump = X - x; - if (dim_jump <= max_jump) { - x = 0; - cur += dim_jump; - return true; - } else { - cur += max_jump; - x += max_jump; - return false; - } -} -template -inline bool nd_iterator_jump(U &cur, const U end, W &x, const Y &X, - Args &&... tuple) -{ - if (nd_iterator_jump(cur, end, utils::forward(tuple)...)) { - x = (x + 1) % X; - return x == 0; - } - return false; -} - -template -inline T pick(size_t i, const T &x0) { return x0; } -template -inline T pick(size_t i, const T &x0, Args &&... args) { - return i == 0 ? x0 : pick(i - 1, utils::forward(args)...); -} - -template -T pick_by_prop_kind(prop_kind_t prop_kind, const T &val_fwd_inference, - const T &val_fwd_training, const T &val_bwd_d, const T &val_bwd_w) { - switch (prop_kind) { - case prop_kind::forward_inference: return val_fwd_inference; - case prop_kind::forward_training: return val_fwd_training; - case prop_kind::backward_data: return val_bwd_d; - case prop_kind::backward_weights: return val_bwd_w; - default: assert(!"unsupported prop_kind"); - } - return T(); -} - -template -T pick_by_prop_kind(prop_kind_t prop_kind, - const T &val_fwd, const T &val_bwd_d, const T &val_bwd_w) -{ return pick_by_prop_kind(prop_kind, val_fwd, val_fwd, val_bwd_d, val_bwd_w); } - -template -struct array_offset_calculator { - template - array_offset_calculator(Telem *base, Targs... Fargs) : _dims{ Fargs... } - { - _base_ptr = base; - } - template - inline Telem &operator()(Targs... 
Fargs) - { - return *(_base_ptr + _offset(1, Fargs...)); - } - -private: - template - inline size_t _offset(size_t const dimension, size_t element) - { - return element; - } - - template - inline size_t _offset(size_t const dimension, size_t theta, size_t element) - { - return element + (_dims[dimension] * theta); - } - - template - inline size_t _offset(size_t const dimension, size_t theta, size_t element, - Targs... Fargs) - { - size_t t_prime = element + (_dims[dimension] * theta); - return _offset(dimension + 1, t_prime, Fargs...); - } - - Telem *_base_ptr; - const int _dims[Tdims]; -}; - -} - -int32_t fetch_and_add(int32_t *dst, int32_t val); -inline void yield_thread() {} - -// Reads an environment variable 'name' and stores its string value in the -// 'buffer' of 'buffer_size' bytes on success. -// -// - Returns the length of the environment variable string value (excluding -// the terminating 0) if it is set and its contents (including the terminating -// 0) can be stored in the 'buffer' without truncation. -// -// - Returns negated length of environment variable string value and writes -// "\0" to the buffer (if it is not NULL) if the 'buffer_size' is to small to -// store the value (including the terminating 0) without truncation. -// -// - Returns 0 and writes "\0" to the buffer (if not NULL) if the environment -// variable is not set. -// -// - Returns INT_MIN if the 'name' is NULL. -// -// - Returns INT_MIN if the 'buffer_size' is negative. -// -// - Returns INT_MIN if the 'buffer' is NULL and 'buffer_size' is greater than -// zero. Passing NULL 'buffer' with 'buffer_size' set to 0 can be used to -// retrieve the length of the environment variable value string. -// -int getenv(const char *name, char *buffer, int buffer_size); -// Reads an integer from the environment -int getenv_int(const char *name, int default_value = 0); -bool jit_dump_enabled(); -FILE *fopen(const char *filename, const char *mode); - -constexpr int msan_enabled = MSAN_ENABLED; -inline void msan_unpoison(void *ptr, size_t size) { -#if MSAN_ENABLED - __msan_unpoison(ptr, size); -#endif -} - -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/verbose.cpp b/thirdparty/oidn/mkl-dnn/src/common/verbose.cpp deleted file mode 100644 index 89a57772c..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/verbose.cpp +++ /dev/null @@ -1,665 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
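
The comment block above spells out the contract of the removed getenv wrapper: a positive return is the value length when the value and its terminator fit the buffer, a negated length signals truncation, 0 means the variable is unset, and INT_MIN flags invalid arguments. The standalone C++ sketch below reimplements only that documented contract on top of the standard library; it is an illustration, not the deleted MKL-DNN code, and MKLDNN_VERBOSE in main() is just an example variable name taken from elsewhere in this patch.

    #include <climits>
    #include <cstdio>
    #include <cstdlib>
    #include <cstring>

    // Sketch of the documented contract of the removed getenv wrapper.
    static int getenv_sketch(const char *name, char *buffer, int buffer_size) {
        if (name == nullptr || buffer_size < 0
                || (buffer == nullptr && buffer_size > 0))
            return INT_MIN;                        // invalid arguments

        const char *value = std::getenv(name);
        if (value == nullptr) {                    // variable is not set
            if (buffer != nullptr) buffer[0] = '\0';
            return 0;
        }

        const int len = (int)std::strlen(value);   // assumes the value fits in an int
        if (len < buffer_size) {                   // value plus terminator fits
            std::memcpy(buffer, value, (size_t)len + 1);
            return len;
        }
        if (buffer != nullptr) buffer[0] = '\0';   // would have been truncated
        return -len;
    }

    int main() {
        char buf[12];
        const int n = getenv_sketch("MKLDNN_VERBOSE", buf, (int)sizeof(buf));
        if (n > 0)
            std::printf("MKLDNN_VERBOSE=\"%s\" (length %d)\n", buf, n);
        else if (n == 0)
            std::printf("MKLDNN_VERBOSE is not set\n");
        else
            std::printf("value truncated or bad arguments (%d)\n", n);
        return 0;
    }

Passing a null buffer with a buffer size of 0 queries the length without copying, exactly as the original comment describes.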
-*******************************************************************************/ - -#include -#ifndef _WIN32 -#include -#endif - -#include "mkldnn.h" -#include "mkldnn_version.h" -#include "c_types_map.hpp" -#include "verbose.hpp" -#include "cpu/cpu_isa_traits.hpp" - -#include "batch_normalization_pd.hpp" -#include "pooling_pd.hpp" -#include "concat_pd.hpp" -#include "reorder_pd.hpp" -#include "convolution_pd.hpp" -#include "rnn_pd.hpp" -#include "deconvolution_pd.hpp" -#include "shuffle_pd.hpp" -#include "eltwise_pd.hpp" -#include "softmax_pd.hpp" -#include "inner_product_pd.hpp" -#include "sum_pd.hpp" -#include "lrn_pd.hpp" - -/* MKL-DNN CPU ISA info */ -#define ISA_ANY "No instruction set specific optimizations" -#define SSE42 "Intel(R) Streaming SIMD Extensions 4.2 (Intel(R) SSE4.2)" -#define AVX "Intel(R) Advanced Vector Extensions (Intel(R) AVX)" -#define AVX2 "Intel(R) Advanced Vector Extensions 2 (Intel(R) AVX2)" -#define AVX512_COMMON "Intel(R) Advanced Vector Extensions 512 (Intel(R) " \ - "AVX-512)" -#define AVX512_CORE "Intel(R) Advanced Vector Extensions 512 (Intel(R) " \ - "AVX-512) with AVX512BW, AVX512VL, and AVX512DQ extensions" -#define AVX512_CORE_VNNI "Intel(R) AVX512-Deep Learning Boost (Intel(R) " \ - "AVX512-DL Boost)" -#define AVX512_MIC "Intel(R) Advanced Vector Extensions 512 (Intel(R) " \ - "AVX-512) with AVX512CD, AVX512ER, and AVX512PF extensions" -#define AVX512_MIC_4OPS "Intel(R) Advanced Vector Extensions 512 (Intel(R) " \ - "AVX-512) with AVX512_4FMAPS and AVX512_4VNNIW extensions" - -namespace mkldnn { -namespace impl { - -static verbose_t verbose; -static bool initialized; -static bool version_printed = false; - -const verbose_t *mkldnn_verbose() { -#if !defined(DISABLE_VERBOSE) - if (!initialized) { - const int len = 2; - char val[len] = {0}; - if (getenv("MKLDNN_VERBOSE", val, len) == 1) - verbose.level = atoi(val); - initialized = true; - } - if (!version_printed && verbose.level > 0) { - printf("mkldnn_verbose,info," - "Intel(R) MKL-DNN v%d.%d.%d (Git Hash %s),%s\n", - mkldnn_version()->major, mkldnn_version()->minor, - mkldnn_version()->patch, mkldnn_version()->hash, - get_isa_info()); - version_printed = true; - } -#else - verbose.level = 0; -#endif - return &verbose; -} - -double get_msec() { -#ifdef _WIN32 - static LARGE_INTEGER frequency; - if (frequency.QuadPart == 0) - QueryPerformanceFrequency(&frequency); - LARGE_INTEGER now; - QueryPerformanceCounter(&now); - return 1e+3 * now.QuadPart / frequency.QuadPart; -#else - struct timeval time; - gettimeofday(&time, NULL); - return 1e+3 * time.tv_sec + 1e-3 * time.tv_usec; -#endif -} - -const char *get_isa_info() { - using namespace mkldnn::impl::cpu; - if (mayiuse(avx512_mic_4ops)) return AVX512_MIC_4OPS; - if (mayiuse(avx512_mic)) return AVX512_MIC; - if (mayiuse(avx512_core_vnni)) return AVX512_CORE_VNNI; - if (mayiuse(avx512_core)) return AVX512_CORE; - if (mayiuse(avx512_common)) return AVX512_COMMON; - if (mayiuse(avx2)) return AVX2; - if (mayiuse(avx)) return AVX; - if (mayiuse(sse42)) return SSE42; - return ISA_ANY; -} - -/* init_info section */ -namespace { -#if !defined(DISABLE_VERBOSE) -#define MKLDNN_VERBOSE_DAT_LEN 256 -#define MKLDNN_VERBOSE_AUX_LEN 384 -#define MKLDNN_VERBOSE_PRB_LEN 384 - -#define DECL_DAT_AUX_PRB_STRS() \ - int dat_written = 0, aux_written = 0, prb_written = 0; \ - MAYBE_UNUSED((dat_written * aux_written * prb_written)); \ - char dat_str[MKLDNN_VERBOSE_DAT_LEN] = {'\0'}; MAYBE_UNUSED(dat_str); \ - char aux_str[MKLDNN_VERBOSE_AUX_LEN] = {'\0'}; 
MAYBE_UNUSED(aux_str); \ - char prb_str[MKLDNN_VERBOSE_PRB_LEN] = {'\0'}; MAYBE_UNUSED(prb_str) - -#define DFMT "%" PRId64 - -void clear_buf(char *buf, int &written) { - /* TODO: do it better */ - buf[0] = '#'; - buf[1] = '\0'; - written = 1; -} - -#define DPRINT(buf, buf_len, written, ...) do { \ - int l = snprintf(buf + written, buf_len - written, __VA_ARGS__); \ - if (l < 0 || written + l > buf_len) { \ - clear_buf(buf, written); \ - } else { \ - written += l; \ - } \ -} while(0) - -// XXX: Outputs strings corresponding to memory formats used for data tensors. -void format_prb_desc_str(char *str, int len, const memory_desc_t *md) { - const auto dims = md->dims; - int written = 0; - if (md->ndims == 1) - DPRINT(str, len, written, - "x" DFMT, dims[0]); - else if (md->ndims == 2) - DPRINT(str, len, written, - "mb" DFMT "ic" DFMT, dims[0], dims[1]); - else if (md->ndims == 3) - DPRINT(str, len, written, - "mb" DFMT "ic" DFMT "iw" DFMT, - dims[0], dims[1], dims[2]); - else if (md->ndims == 4) - DPRINT(str, len, written, - "mb" DFMT "ic" DFMT "ih" DFMT "iw" DFMT, - dims[0], dims[1], dims[2], dims[3]); - else if (md->ndims == 5) - DPRINT(str, len, written, - "mb" DFMT "ic" DFMT "id" DFMT "ih" DFMT "iw" DFMT, - dims[0], dims[1], dims[2], dims[3], dims[4]); - else - mkldnn_md2dim_str(str, len, md); -} - -void verbose_templ(char *buffer, mkldnn_primitive_kind_t prim_kind, - const char *impl_str, mkldnn_prop_kind_t prop_kind, - const char *data_str, const char *aux_str, const char *prb_str) { - MAYBE_UNUSED(verbose_templ); - int written = 0; - DPRINT(buffer, MKLDNN_VERBOSE_BUF_LEN, written, "%s,%s,%s,%s,%s,%s", - mkldnn_prim_kind2str(prim_kind), impl_str, - mkldnn_prop_kind2str(prop_kind), data_str, aux_str, prb_str); -} - -template static void init_info_bnorm(pd_t *s, char *buffer) { - DECL_DAT_AUX_PRB_STRS(); - - if (1) { // data - auto md = s->src_md(); - DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "data_"); - int l = mkldnn_md2fmt_str(dat_str + dat_written, - MKLDNN_VERBOSE_DAT_LEN - dat_written, md); - if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); - } - if (1) { // diff data - auto md = s->diff_src_md(); - if (md) { - DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " diff_"); - int l = mkldnn_md2fmt_str(dat_str + dat_written, - MKLDNN_VERBOSE_DAT_LEN - dat_written, md); - if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); - } - } - - DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written, - "flags:%u", s->desc()->flags); - - format_prb_desc_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, s->src_md()); - - verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, - aux_str, prb_str); -} - -template static void init_info_conv(pd_t *s, char *buffer) { - DECL_DAT_AUX_PRB_STRS(); - - if (1) { // src - auto md = s->desc()->prop_kind == prop_kind::backward_data - ? s->diff_src_md() : s->src_md(); - DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_"); - int l = mkldnn_md2fmt_str(dat_str + dat_written, - MKLDNN_VERBOSE_DAT_LEN - dat_written, md); - if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); - } - if (1) { // wei - auto md = s->desc()->prop_kind == prop_kind::backward_weights - ? 
s->diff_weights_md() : s->weights_md(); - DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " wei_"); - int l = mkldnn_md2fmt_str(dat_str + dat_written, - MKLDNN_VERBOSE_DAT_LEN - dat_written, md); - if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); - } - if (1) { // bia - auto md = s->desc()->prop_kind == prop_kind::backward_weights - ? s->diff_weights_md(1) : s->weights_md(1); - if (md) { - DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " bia_"); - int l = mkldnn_md2fmt_str(dat_str + dat_written, - MKLDNN_VERBOSE_DAT_LEN - dat_written, md); - if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); - } - } - if (1) { // dst - auto md = !s->is_fwd() ? s->diff_dst_md() : s->dst_md(); - DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " dst_"); - int l = mkldnn_md2fmt_str(dat_str + dat_written, - MKLDNN_VERBOSE_DAT_LEN - dat_written, md); - if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); - } - - DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written, - "alg:%s", mkldnn_alg_kind2str(s->desc()->alg_kind)); - - if (s->ndims() == 5) { - if (s->with_groups()) - DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written, - "mb" DFMT "_g" DFMT "ic" DFMT "oc" DFMT - "_id" DFMT "od" DFMT "kd" DFMT "sd" DFMT "dd" DFMT "pd" DFMT - "_ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "dh" DFMT "ph" DFMT - "_iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "dw" DFMT "pw" DFMT, - s->MB(), s->G(), s->IC(), s->OC(), - s->ID(), s->OD(), s->KD(), s->KSD(), s->KDD(), s->padFront(), - s->IH(), s->OH(), s->KH(), s->KSH(), s->KDH(), s->padT(), - s->IW(), s->OW(), s->KW(), s->KSW(), s->KDW(), s->padL()); - else - DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written, - "mb" DFMT "_ic" DFMT "oc" DFMT - "_id" DFMT "od" DFMT "kd" DFMT "sd" DFMT "dd" DFMT "pd" DFMT - "_ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "dh" DFMT "ph" DFMT - "_iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "dw" DFMT "pw" DFMT, - s->MB(), s->IC(), s->OC(), - s->ID(), s->OD(), s->KD(), s->KSD(), s->KDD(), s->padFront(), - s->IH(), s->OH(), s->KH(), s->KSH(), s->KDH(), s->padT(), - s->IW(), s->OW(), s->KW(), s->KSW(), s->KDW(), s->padL()); - } else { - if (s->with_groups()) - DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written, - "mb" DFMT "_g" DFMT "ic" DFMT "oc" DFMT - "_ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "dh" DFMT "ph" DFMT - "_iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "dw" DFMT "pw" DFMT, - s->MB(), s->G(), s->IC(), s->OC(), - s->IH(), s->OH(), s->KH(), s->KSH(), s->KDH(), s->padT(), - s->IW(), s->OW(), s->KW(), s->KSW(), s->KDW(), s->padL()); - else - DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written, - "mb" DFMT "_ic" DFMT "oc" DFMT - "_ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "dh" DFMT "ph" DFMT - "_iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "dw" DFMT "pw" DFMT, - s->MB(), s->IC(), s->OC(), - s->IH(), s->OH(), s->KH(), s->KSH(), s->KDH(), s->padT(), - s->IW(), s->OW(), s->KW(), s->KSW(), s->KDW(), s->padL()); - } - - verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, - aux_str, prb_str); -} - -template static void init_info_shuffle(pd_t *s, char *buffer) { - DECL_DAT_AUX_PRB_STRS(); - - auto md = s->is_fwd() ? 
s->src_md() : s->diff_dst_md(); - - if (1) { // data - DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "data_"); - int l = mkldnn_md2fmt_str(dat_str + dat_written, - MKLDNN_VERBOSE_DAT_LEN - dat_written, md); - if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); - } - - DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written, - "axis:%d group_size:" DFMT, s->axis(), s->group_size()); - - mkldnn_md2dim_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, md); - - verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, - aux_str, prb_str); -} - -template static void init_info_eltwise(pd_t *s, char *buffer) { - DECL_DAT_AUX_PRB_STRS(); - - if (1) { // data - auto md = s->src_md(); - DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "data_"); - int l = mkldnn_md2fmt_str(dat_str + dat_written, - MKLDNN_VERBOSE_DAT_LEN - dat_written, md); - if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); - } - if (1) { // diff data - auto md = s->diff_src_md(); - if (md) { - DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " diff_"); - int l = mkldnn_md2fmt_str(dat_str + dat_written, - MKLDNN_VERBOSE_DAT_LEN - dat_written, md); - if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); - } - } - - DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written, - "alg:%s", mkldnn_alg_kind2str(s->desc()->alg_kind)); - - mkldnn_md2dim_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, s->src_md()); - - verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, - aux_str, prb_str); -} - -template static void init_info_iprod(pd_t *s, char *buffer) { - DECL_DAT_AUX_PRB_STRS(); - - if (1) { // src - auto md = s->desc()->prop_kind == prop_kind::backward_data - ? s->diff_src_md() : s->src_md(); - DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_"); - int l = mkldnn_md2fmt_str(dat_str + dat_written, - MKLDNN_VERBOSE_DAT_LEN - dat_written, md); - if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); - } - if (1) { // wei - auto md = s->desc()->prop_kind == prop_kind::backward_weights - ? s->diff_weights_md() : s->weights_md(); - DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " wei_"); - int l = mkldnn_md2fmt_str(dat_str + dat_written, - MKLDNN_VERBOSE_DAT_LEN - dat_written, md); - if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); - } - if (1) { // bia - auto md = s->desc()->prop_kind == prop_kind::backward_weights - ? s->diff_weights_md(1) : s->weights_md(1); - if (md) { - DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " bia_"); - int l = mkldnn_md2fmt_str(dat_str + dat_written, - MKLDNN_VERBOSE_DAT_LEN - dat_written, md); - if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); - } - } - if (1) { // dst - auto md = !s->is_fwd() ? 
s->diff_dst_md() : s->dst_md(); - DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " dst_"); - int l = mkldnn_md2fmt_str(dat_str + dat_written, - MKLDNN_VERBOSE_DAT_LEN - dat_written, md); - if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); - } - - DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written, - "mb" DFMT "ic" DFMT "oc" DFMT, s->MB(), s->IC_total(), s->OC()); - - verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, - aux_str, prb_str); -} - -template static void init_info_lrn(pd_t *s, char *buffer) { - DECL_DAT_AUX_PRB_STRS(); - - if (1) { // data - auto md = s->src_md(); - DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "data_"); - int l = mkldnn_md2fmt_str(dat_str + dat_written, - MKLDNN_VERBOSE_DAT_LEN - dat_written, md); - if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); - } - if (1) { // diff data - auto md = s->diff_src_md(); - if (md) { - DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " diff_"); - int l = mkldnn_md2fmt_str(dat_str + dat_written, - MKLDNN_VERBOSE_DAT_LEN - dat_written, md); - if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); - } - } - - DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written, - "alg:%s", mkldnn_alg_kind2str(s->desc()->alg_kind)); - - format_prb_desc_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, s->src_md()); - - verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, - aux_str, prb_str); -} - -template static void init_info_mem(pd_t *s, char *buffer) { - DECL_DAT_AUX_PRB_STRS(); - - if (1) { // src - auto md = s->src_md(); - DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_"); - int l = mkldnn_md2fmt_str(dat_str + dat_written, - MKLDNN_VERBOSE_DAT_LEN - dat_written, md); - if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); - } - if (1) { // dst - auto md = s->dst_md(); - DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " dst_"); - int l = mkldnn_md2fmt_str(dat_str + dat_written, - MKLDNN_VERBOSE_DAT_LEN - dat_written, md); - if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); - } - - DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written, - "num:%d", s->n_inputs()); - - mkldnn_md2dim_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, s->dst_md()); - - verbose_templ(buffer, s->kind(), s->name(), prop_kind::undef, dat_str, - aux_str, prb_str); -} - -template static void init_info_pool(pd_t *s, char *buffer) { - DECL_DAT_AUX_PRB_STRS(); - - if (1) { // src - auto md = s->is_fwd() ? s->src_md() : s->diff_src_md(); - DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_"); - int l = mkldnn_md2fmt_str(dat_str + dat_written, - MKLDNN_VERBOSE_DAT_LEN - dat_written, md); - if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); - } - if (1) { // dst - auto md = s->is_fwd() ? 
s->dst_md() : s->diff_dst_md(); - DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " dst_"); - int l = mkldnn_md2fmt_str(dat_str + dat_written, - MKLDNN_VERBOSE_DAT_LEN - dat_written, md); - if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); - } - if (1) { // ws - auto md = s->workspace_md(); - if (md) { - DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " ws_"); - int l = mkldnn_md2fmt_str(dat_str + dat_written, - MKLDNN_VERBOSE_DAT_LEN - dat_written, md); - if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); - } - } - - DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written, - "alg:%s", mkldnn_alg_kind2str(s->desc()->alg_kind)); - - if (s->is_3d()) { - DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written, - "mb" DFMT "ic" DFMT "_" - "id" DFMT "od" DFMT "kd" DFMT "sd" DFMT "pd" DFMT "_" - "ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "ph" DFMT "_" - "iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "pw" DFMT "", - s->MB(), s->C(), - s->ID(), s->OD(), s->KD(), s->KSD(), s->padFront(), - s->IH(), s->OH(), s->KH(), s->KSH(), s->padT(), - s->IW(), s->OW(), s->KW(), s->KSW(), s->padL()); - } else { - DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written, - "mb" DFMT "ic" DFMT "_" - "ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "ph" DFMT "_" - "iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "pw" DFMT, - s->MB(), s->C(), - s->IH(), s->OH(), s->KH(), s->KSH(), s->padT(), - s->IW(), s->OW(), s->KW(), s->KSW(), s->padL()); - } - - verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, - aux_str, prb_str); -} - -template static void init_info_softmax(pd_t *s, char *buffer) { - DECL_DAT_AUX_PRB_STRS(); - - if (1) { // data - auto md = s->dst_md(); - DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "data_"); - int l = mkldnn_md2fmt_str(dat_str + dat_written, - MKLDNN_VERBOSE_DAT_LEN - dat_written, md); - if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); - } - if (1) { // diff data - auto md = s->diff_src_md(); - if (md) { - DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " diff_"); - int l = mkldnn_md2fmt_str(dat_str + dat_written, - MKLDNN_VERBOSE_DAT_LEN - dat_written, md); - if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); - } - } - - mkldnn_md2dim_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, s->dst_md()); - - verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, - aux_str, prb_str); -} - -template static void init_info_rnn(pd_t *s, char *buffer) { - DECL_DAT_AUX_PRB_STRS(); - - if (1) { // src layer - auto md = s->is_fwd() ? s->src_md(0) : s->diff_src_md(0); - DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_layer_"); - int l = mkldnn_md2fmt_str(dat_str + dat_written, - MKLDNN_VERBOSE_DAT_LEN - dat_written, md); - if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); - } - if (1) { // src iter - auto md = s->is_fwd() ? s->src_md(1) : s->diff_src_md(1); - DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_iter_"); - int l = mkldnn_md2fmt_str(dat_str + dat_written, - MKLDNN_VERBOSE_DAT_LEN - dat_written, md); - if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); - } - if (1) { // wei_layer - auto md = s->is_fwd() ? s->weights_md(0) : s->diff_weights_md(0); - DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " wei_layer_"); - int l = mkldnn_md2fmt_str(dat_str + dat_written, - MKLDNN_VERBOSE_DAT_LEN - dat_written, md); - if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); - } - if (1) { // wei_iter - auto md = s->is_fwd() ? 
s->weights_md(1) : s->diff_weights_md(1); - DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " wei_layer_"); - int l = mkldnn_md2fmt_str(dat_str + dat_written, - MKLDNN_VERBOSE_DAT_LEN - dat_written, md); - if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); - } - if (1) { // bias - auto md = s->is_fwd() ? s->weights_md(2) : s->diff_weights_md(2); - DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " bias_"); - int l = mkldnn_md2fmt_str(dat_str + dat_written, - MKLDNN_VERBOSE_DAT_LEN - dat_written, md); - if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); - } - if (1) { // dst layer - auto md = s->is_fwd() ? s->dst_md(0) : s->diff_dst_md(0); - DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "dst_layer_"); - int l = mkldnn_md2fmt_str(dat_str + dat_written, - MKLDNN_VERBOSE_DAT_LEN - dat_written, md); - if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); - } - if (1) { // dst iter - auto md = s->is_fwd() ? s->dst_md(1) : s->diff_dst_md(1); - DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "dst_iter_"); - int l = mkldnn_md2fmt_str(dat_str + dat_written, - MKLDNN_VERBOSE_DAT_LEN - dat_written, md); - if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); - } - - alg_kind_t alg_kind = s->cell_kind(); - rnn_direction_t rnn_dir = s->direction(); - DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written, - "alg:%s_%s", mkldnn_alg_kind2str(alg_kind), - mkldnn_rnn_direction2str(rnn_dir)); - - DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written, - "l" DFMT "t" DFMT "mb" DFMT - "sic" DFMT "slc" DFMT "dic" DFMT "dlc" DFMT, - s->L(), s->T(), s->MB(), - s->SIC(), s->SLC(), s->DIC(), s->DLC()); - - verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, - aux_str, prb_str); -} - -#undef DPRINT - -#else // !defined(DISABLE_VERBOSE) - -#define DEFINE_STUB(name) \ - template \ - static void CONCAT2(init_info_, name)(pd_t *s, char *buffer) \ - { UNUSED(s); UNUSED(buffer); } - -DEFINE_STUB(bnorm); -DEFINE_STUB(conv); -DEFINE_STUB(eltwise); -DEFINE_STUB(iprod); -DEFINE_STUB(lrn); -DEFINE_STUB(mem); -DEFINE_STUB(pool); -DEFINE_STUB(softmax); -DEFINE_STUB(rnn); -DEFINE_STUB(shuffle); -#undef DEFINE_STUB - -#endif // !defined(DISABLE_VERBOSE) -} - -void init_info(batch_normalization_pd_t *s, char *b) -{ init_info_bnorm(s, b); } -void init_info(concat_pd_t *s, char *b) -{ init_info_mem(s, b); } -void init_info(convolution_pd_t *s, char *b) -{ init_info_conv(s, b); } -void init_info(deconvolution_pd_t *s, char *b) -{ init_info_conv(s, b); } -void init_info(eltwise_pd_t *s, char *b) -{ init_info_eltwise(s, b); } -void init_info(inner_product_pd_t *s, char *b) -{ init_info_iprod(s, b); } -void init_info(lrn_pd_t *s, char *b) -{ init_info_lrn(s, b); } -void init_info(pooling_pd_t *s, char *b) -{ init_info_pool(s, b); } -void init_info(reorder_pd_t *s, char *b) -{ init_info_mem(s, b); } -void init_info(rnn_pd_t *s, char *b) -{ init_info_rnn(s, b); } -void init_info(shuffle_pd_t *s, char *b) -{ init_info_shuffle(s, b); } -void init_info(softmax_pd_t *s, char *b) -{ init_info_softmax(s, b); } -void init_info(sum_pd_t *s, char *b) -{ init_info_mem(s, b); } - -} -} - -mkldnn_status_t mkldnn_set_verbose(int level) { - using namespace mkldnn::impl::status; - if (level < 0 || level > 2) return invalid_arguments; - mkldnn::impl::verbose.level = level; - mkldnn::impl::initialized = true; - return success; -} - -const mkldnn_version_t *mkldnn_version() { - static mkldnn_version_t ver = { - MKLDNN_VERSION_MAJOR, - MKLDNN_VERSION_MINOR, - 
MKLDNN_VERSION_PATCH, - MKLDNN_VERSION_HASH}; - return &ver; -} diff --git a/thirdparty/oidn/mkl-dnn/src/common/verbose.hpp b/thirdparty/oidn/mkl-dnn/src/common/verbose.hpp deleted file mode 100644 index e3049750c..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/verbose.hpp +++ /dev/null @@ -1,62 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef VERBOSE_HPP -#define VERBOSE_HPP - -#include -#include - -#include "mkldnn_debug.h" -#include "c_types_map.hpp" -#include "utils.hpp" -#include "z_magic.hpp" - -namespace mkldnn { -namespace impl { - -struct verbose_t { - int level; -}; - -const verbose_t *mkldnn_verbose(); -double get_msec(); -const char *get_isa_info(); - -#if !defined(DISABLE_VERBOSE) -#define MKLDNN_VERBOSE_BUF_LEN 1024 -#else -#define MKLDNN_VERBOSE_BUF_LEN 1 -#endif - -void init_info(batch_normalization_pd_t *s, char *buffer); -void init_info(concat_pd_t *s, char *buffer); -void init_info(convolution_pd_t *s, char *buffer); -void init_info(deconvolution_pd_t *s, char *buffer); -void init_info(eltwise_pd_t *s, char *buffer); -void init_info(inner_product_pd_t *s, char *buffer); -void init_info(lrn_pd_t *s, char *buffer); -void init_info(pooling_pd_t *s, char *buffer); -void init_info(reorder_pd_t *s, char *buffer); -void init_info(rnn_pd_t *s, char *buffer); -void init_info(shuffle_pd_t *s, char *buffer); -void init_info(softmax_pd_t *s, char *buffer); -void init_info(sum_pd_t *s, char *buffer); - -} -} - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/z_magic.hpp b/thirdparty/oidn/mkl-dnn/src/common/z_magic.hpp deleted file mode 100644 index 520bd4710..000000000 --- a/thirdparty/oidn/mkl-dnn/src/common/z_magic.hpp +++ /dev/null @@ -1,46 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
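
The removed verbose.cpp reads MKLDNN_VERBOSE from the environment once, caches the level in a static verbose_t, and prints a one-line version banner the first time a nonzero level is seen. Below is a condensed sketch of that lazy-initialization pattern in plain C++; the banner text and log line are placeholders, not the library's real output format.

    #include <cstdio>
    #include <cstdlib>

    struct verbose_t { int level = 0; };

    // Read the verbosity level once and print a banner on first use,
    // mirroring the pattern of the removed mkldnn_verbose().
    static const verbose_t *get_verbose() {
        static verbose_t verbose;
        static bool initialized = false;
        static bool banner_printed = false;

        if (!initialized) {
            if (const char *val = std::getenv("MKLDNN_VERBOSE"))
                verbose.level = std::atoi(val);
            initialized = true;
        }
        if (!banner_printed && verbose.level > 0) {
            std::printf("verbose,info,library vX.Y.Z\n");   // placeholder banner
            banner_printed = true;
        }
        return &verbose;
    }

    int main() {
        if (get_verbose()->level > 0)
            std::printf("verbose,exec,some_primitive,...\n"); // illustrative log line
        return 0;
    }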
-*******************************************************************************/ - -#ifndef Z_MAGIC_HPP -#define Z_MAGIC_HPP - -#define CHAIn2(a,b) a b -#define CHAIN2(a,b) CHAIn2(a,b) - -#define CONCAt2(a,b) a ## b -#define CONCAT2(a,b) CONCAt2(a,b) - -#define STRINGIFy(s) #s -#define STRINGIFY(s) STRINGIFy(s) - -#ifdef _MSC_VER -# define PRAGMA_MACRo(x) __pragma(x) -# define PRAGMA_MACRO(x) PRAGMA_MACRo(x) -#else -# define PRAGMA_MACRo(x) _Pragma(#x) -# define PRAGMA_MACRO(x) PRAGMA_MACRo(x) -#endif - -#define UNUSED(x) ((void)x) -#define MAYBE_UNUSED(x) UNUSED(x) - -#if defined(_WIN32) && !defined(__GNUC__) -#define __PRETTY_FUNCTION__ __FUNCSIG__ -#endif - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_barrier.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_barrier.cpp deleted file mode 100644 index 7cf7822d9..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_barrier.cpp +++ /dev/null @@ -1,112 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include - -#include "cpu_barrier.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -namespace simple_barrier { - -void generate(jit_generator &code, Xbyak::Reg64 reg_ctx, - Xbyak::Reg64 reg_nthr) { -# define BAR_CTR_OFF offsetof(ctx_t, ctr) -# define BAR_SENSE_OFF offsetof(ctx_t, sense) - using namespace Xbyak; - - Xbyak::Reg64 reg_tmp = [&]() { - /* returns register which is neither reg_ctx nor reg_nthr */ - Xbyak::Reg64 regs[] = { util::rax, util::rbx, util::rcx }; - for (size_t i = 0; i < sizeof(regs) / sizeof(regs[0]); ++i) - if (!utils::one_of(regs[i], reg_ctx, reg_nthr)) - return regs[i]; - return regs[0]; /* should not happen */ - }(); - - Label barrier_exit_label, barrier_exit_restore_label, spin_label; - - code.cmp(reg_nthr, 1); - code.jbe(barrier_exit_label); - - code.push(reg_tmp); - - /* take and save current sense */ - code.mov(reg_tmp, code.ptr[reg_ctx + BAR_SENSE_OFF]); - code.push(reg_tmp); - code.mov(reg_tmp, 1); - - if (mayiuse(avx512_mic)) { - code.prefetchwt1(code.ptr[reg_ctx + BAR_CTR_OFF]); - code.prefetchwt1(code.ptr[reg_ctx + BAR_CTR_OFF]); - } - - code.lock(); code.xadd(code.ptr[reg_ctx + BAR_CTR_OFF], reg_tmp); - code.add(reg_tmp, 1); - code.cmp(reg_tmp, reg_nthr); - code.pop(reg_tmp); /* restore previous sense */ - code.jne(spin_label); - - /* the last thread {{{ */ - code.mov(code.qword[reg_ctx + BAR_CTR_OFF], 0); // reset ctx - - // notify waiting threads - code.not_(reg_tmp); - code.mov(code.ptr[reg_ctx + BAR_SENSE_OFF], reg_tmp); - code.jmp(barrier_exit_restore_label); - /* }}} the last thread */ - - code.CodeGenerator::L(spin_label); - code.pause(); - code.cmp(reg_tmp, code.ptr[reg_ctx + BAR_SENSE_OFF]); - code.je(spin_label); - - code.CodeGenerator::L(barrier_exit_restore_label); - code.pop(reg_tmp); - - 
code.CodeGenerator::L(barrier_exit_label); -# undef BAR_CTR_OFF -# undef BAR_SENSE_OFF -} - -/** jit barrier generator */ -struct jit_t: public jit_generator { - void (*barrier)(ctx_t *ctx, size_t nthr); - - jit_t() { - generate(*this, abi_param1, abi_param2); - ret(); - barrier = reinterpret_cast(const_cast( - this->getCode())); - } - - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_t) -}; - -void barrier(ctx_t *ctx, int nthr) { - static jit_t j; /* XXX: constructed on load ... */ - j.barrier(ctx, nthr); -} - -} - -} -} -} - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_barrier.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_barrier.hpp deleted file mode 100644 index 0f55e33aa..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_barrier.hpp +++ /dev/null @@ -1,60 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef CPU_BARRIER_HPP -#define CPU_BARRIER_HPP - -#include - -#include "jit_generator.hpp" -#include "utils.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -namespace simple_barrier { - -STRUCT_ALIGN(64, -struct ctx_t { - enum { CACHE_LINE_SIZE = 64 }; - volatile size_t ctr; - char pad1[CACHE_LINE_SIZE - 1 * sizeof(size_t)]; - volatile size_t sense; - char pad2[CACHE_LINE_SIZE - 1 * sizeof(size_t)]; -}); - -inline void ctx_init(ctx_t *ctx) { *ctx = utils::zero(); } -void barrier(ctx_t *ctx, int nthr); - -/** injects actual barrier implementation into another jitted code - * @params: - * code -- jit_generator object where the barrier is to be injected - * reg_ctx -- read-only register with pointer to the barrier context - * reg_nnthr -- read-only register with the # of synchronizing threads - */ -void generate(jit_generator &code, Xbyak::Reg64 reg_ctx, - Xbyak::Reg64 reg_nthr); - -} - -} -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_pd.hpp deleted file mode 100644 index 1ed5ad57b..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_pd.hpp +++ /dev/null @@ -1,40 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
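
The deleted cpu_barrier files implement a centralized sense-reversing barrier: each thread records the current sense, atomically increments the counter, and either spins on the sense word or, as the last arrival, resets the counter and flips the sense. The sketch below expresses the same counter/sense scheme with std::atomic and std::thread instead of Xbyak-generated code; it is a portable illustration of the idea, not a drop-in replacement for the JIT version.

    #include <atomic>
    #include <cstddef>
    #include <cstdio>
    #include <thread>
    #include <vector>

    struct barrier_ctx_t {
        std::atomic<size_t> ctr{0};    // number of threads that have arrived
        std::atomic<size_t> sense{0};  // flipped by the last arrival of each round
    };

    static void barrier(barrier_ctx_t &ctx, size_t nthr) {
        if (nthr <= 1) return;
        const size_t my_sense = ctx.sense.load(std::memory_order_acquire);
        if (ctx.ctr.fetch_add(1, std::memory_order_acq_rel) + 1 == nthr) {
            // Last thread: reset the counter, then flip the sense to release the rest
            // (the reset must be visible before the flip, as in the JIT code).
            ctx.ctr.store(0, std::memory_order_relaxed);
            ctx.sense.store(~my_sense, std::memory_order_release);
        } else {
            while (ctx.sense.load(std::memory_order_acquire) == my_sense)
                std::this_thread::yield();  // the JIT version uses pause instead
        }
    }

    int main() {
        const size_t nthr = 4;
        barrier_ctx_t ctx;
        std::vector<std::thread> workers;
        for (size_t t = 0; t < nthr; ++t)
            workers.emplace_back([&ctx, nthr, t] {
                std::printf("thread %zu before barrier\n", t);
                barrier(ctx, nthr);
                std::printf("thread %zu after barrier\n", t);
            });
        for (auto &w : workers) w.join();
        return 0;
    }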
-*******************************************************************************/ - -#ifndef CPU_BATCH_NORMALIZATION_PD_HPP -#define CPU_BATCH_NORMALIZATION_PD_HPP - -#include "batch_normalization_pd.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -struct cpu_batch_normalization_fwd_pd_t: public batch_normalization_fwd_pd_t { - using batch_normalization_fwd_pd_t::batch_normalization_fwd_pd_t; -}; - -struct cpu_batch_normalization_bwd_pd_t: public batch_normalization_bwd_pd_t { - using batch_normalization_bwd_pd_t::batch_normalization_bwd_pd_t; -}; - -} -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_utils.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_utils.cpp deleted file mode 100644 index b8d5c4fca..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_utils.cpp +++ /dev/null @@ -1,140 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include "c_types_map.hpp" -#include "utils.hpp" - -#include "jit_generator.hpp" - -#include "cpu_batch_normalization_utils.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { -namespace bnorm_utils { - -void cache_balance(size_t working_set_size, dim_t C_blks, - dim_t &C_blks_per_iter, int64_t &iters) { - int nthrs = mkldnn_get_max_threads(); - int l3_size = get_cache_size(3, true) * nthrs / 2; - - C_blks_per_iter = l3_size / working_set_size; - - if (C_blks_per_iter == 0) - C_blks_per_iter = 1; - if (C_blks_per_iter > C_blks) - C_blks_per_iter = C_blks; - - iters = (C_blks + C_blks_per_iter - 1) / C_blks_per_iter; -} - -bool thread_balance(bool do_blocking, bool spatial_thr_allowed, int ithr, - int nthr, dim_t N, dim_t C_blks, dim_t SP, int &C_ithr, int &C_nthr, - dim_t &C_blk_s, dim_t &C_blk_e, int &N_ithr, int &N_nthr, dim_t &N_s, - dim_t &N_e, int &S_ithr, int &S_nthr, dim_t &S_s, dim_t &S_e) { - if (nthr <= C_blks || !mkldnn_thr_syncable()) { - C_ithr = ithr; C_nthr = nthr; - N_ithr = 0; N_nthr = 1; - S_ithr = 0; S_nthr = 1; - N_s = 0; N_e = N; S_s = 0; S_e = SP; - balance211(C_blks, C_nthr, C_ithr, C_blk_s, C_blk_e); - } else { - if (do_blocking) { - N_nthr = (int)nstl::min(N, nthr); - C_nthr = (int)nstl::min(C_blks, nthr / N_nthr); - S_nthr = (int)nstl::min(SP, nthr / (C_nthr * N_nthr)); - } else { - C_nthr = (int)math::gcd((dim_t)nthr, C_blks); - N_nthr = (int)nstl::min(N, nthr / C_nthr); - S_nthr = (int)nstl::min(SP, nthr / (C_nthr * N_nthr)); - } - - if (!spatial_thr_allowed) - S_nthr = 1; - - if (S_nthr < 1) S_nthr = 1; - if (ithr < C_nthr * N_nthr * S_nthr) { - N_ithr = (ithr / S_nthr) % N_nthr ; - C_ithr = ithr / (N_nthr * S_nthr); - S_ithr = ithr % S_nthr; - balance211(C_blks, C_nthr, C_ithr, C_blk_s, C_blk_e); - balance211(N, N_nthr, N_ithr, N_s, N_e); - balance211(SP, S_nthr, S_ithr, S_s, S_e); - 
} else { - S_ithr = N_ithr = C_ithr = -ithr; - S_s = S_e = N_s = N_e = C_blk_s = C_blk_e = -1; - } - } - - // spatial_thr_allowed is meant to help maintain - // consistent decisions about spatial threading - // between mutiple invocations of this routine. - // It is caller's responsibility to check the - // return value and pass it as a flag to the - // next call if needed. - if (S_nthr == 1) - spatial_thr_allowed = false; - - return spatial_thr_allowed; -} - -bool is_spatial_thr(const batch_normalization_pd_t *bdesc, int simd_w, - int data_size) { - if (!mkldnn_thr_syncable()) return false; - - dim_t nthr = mkldnn_get_max_threads(); - dim_t SP = bdesc->W() * bdesc->D() * bdesc->H(); - dim_t C_PADDED = memory_desc_wrapper(bdesc->src_md()) - .padded_dims()[1]; - assert(C_PADDED % simd_w == 0); - - size_t data = bdesc->MB() * C_PADDED * SP * data_size; - size_t l3_size_ = get_cache_size(3, true) * nthr / 2; - bool do_blocking = (data >= l3_size_ / 2 && l3_size_ > 0); - dim_t C_blks_per_iter{ 1 }, iters{ 1 }; - dim_t C_blks = C_PADDED / simd_w; - - if (do_blocking) { - int num_tensors = bdesc->is_fwd() ? 1 : 2; - size_t working_set_size - = (bdesc->MB() * SP * simd_w * data_size) * num_tensors; - cache_balance(working_set_size, C_blks, C_blks_per_iter, iters); - } - - // Spatial threading decision made in this function shall be consistent - // with thread_balance() behavior. - C_blks = do_blocking ? C_blks_per_iter : C_blks; - - if (nthr <= C_blks) return false; - - dim_t S_nthr = 1; - if (do_blocking) { - dim_t N_nthr = nstl::min(bdesc->MB(), nthr); - dim_t C_nthr = nstl::min(C_blks, nthr / N_nthr); - S_nthr = nstl::min(SP, nthr / (C_nthr * N_nthr)); - } else { - dim_t C_nthr = math::gcd(nthr, C_blks); - dim_t N_nthr = nstl::min(bdesc->MB(), nthr / C_nthr); - S_nthr = nstl::min(SP, nthr / (C_nthr * N_nthr)); - } - - return S_nthr > 1; -} - -} -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_utils.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_utils.hpp deleted file mode 100644 index 0daef0716..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_utils.hpp +++ /dev/null @@ -1,43 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#ifndef CPU_BATCH_NORMALIZATION_UTILS_HPP -#define CPU_BATCH_NORMALIZATION_UTILS_HPP - -#include "batch_normalization_pd.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { -namespace bnorm_utils { - -void cache_balance(size_t working_set_size, dim_t C_blks, - dim_t &C_blks_per_iter, int64_t &iters); - -bool thread_balance(bool do_blocking, bool spatial_thr_allowed, int ithr, - int nthr, dim_t N, dim_t C_blks, dim_t SP, int &C_ithr, int &C_nthr, - dim_t &C_blk_s, dim_t &C_blk_e, int &N_ithr, int &N_nthr, dim_t &N_s, - dim_t &N_e, int &S_ithr, int &S_nthr, dim_t &S_s, dim_t &S_e); - -bool is_spatial_thr(const batch_normalization_pd_t *bdesc, int simd_w, - int data_size); - -} -} -} -} - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_concat.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_concat.cpp deleted file mode 100644 index b92649120..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_concat.cpp +++ /dev/null @@ -1,51 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include "cpu_engine.hpp" - -/* -#include "cpu/ref_concat.hpp" -#include "cpu/simple_concat.hpp" -*/ - -namespace mkldnn { -namespace impl { -namespace cpu { - -using cpd_create_f = mkldnn::impl::engine_t::concat_primitive_desc_create_f; - -namespace { -#define INSTANCE(...) __VA_ARGS__::pd_t::create -static const cpd_create_f cpu_concat_impl_list[] = { - /* - INSTANCE(simple_concat_t), - INSTANCE(simple_concat_t), - INSTANCE(simple_concat_t), - INSTANCE(simple_concat_t), - INSTANCE(ref_concat_t), - */ - nullptr, -}; -#undef INSTANCE -} - -const cpd_create_f *cpu_engine_t::get_concat_implementation_list() const { - return cpu_concat_impl_list; -} - -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_concat_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_concat_pd.hpp deleted file mode 100644 index 0b01bcf16..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_concat_pd.hpp +++ /dev/null @@ -1,41 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#ifndef CPU_CONCAT_PD_HPP -#define CPU_CONCAT_PD_HPP - -#include - -#include "c_types_map.hpp" -#include "concat_pd.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -struct cpu_concat_pd_t: public concat_pd_t { - using concat_pd_t::concat_pd_t; -}; - -} -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_convolution_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_convolution_pd.hpp deleted file mode 100644 index 52a38a229..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_convolution_pd.hpp +++ /dev/null @@ -1,74 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef CPU_CONVOLUTION_PD_HPP -#define CPU_CONVOLUTION_PD_HPP - -#include - -#include "c_types_map.hpp" -#include "convolution_pd.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -struct cpu_convolution_fwd_pd_t: public convolution_fwd_pd_t { - using convolution_fwd_pd_t::convolution_fwd_pd_t; - - bool has_padded_dst() const { - memory_desc_wrapper dst_d(&dst_md_); - return OC() != dst_d.padded_dims()[1]; - } - - bool wants_padded_bias() const { - if (!with_bias()) return false; - return has_padded_dst(); - } - - bool wants_zero_pad_dst(bool jit_impl = true) const { - if (!has_padded_dst()) return false; - const auto &po = attr()->post_ops_; - int idx; - if ((idx = po.find(primitive_kind::eltwise)) == -1) return false; - return !math::eltwise_fwd_preserves_zero(po.entry_[idx].eltwise.alg, - jit_impl); - } -}; - -struct cpu_convolution_bwd_data_pd_t: public convolution_bwd_data_pd_t { - using convolution_bwd_data_pd_t::convolution_bwd_data_pd_t; -}; - -struct cpu_convolution_bwd_weights_pd_t: public convolution_bwd_weights_pd_t { - using convolution_bwd_weights_pd_t::convolution_bwd_weights_pd_t; - - bool wants_padded_bias() const { - if (!with_bias()) return false; - memory_desc_wrapper diff_dst_d(&diff_dst_md_); - return OC() != diff_dst_d.padded_dims()[1]; - } -}; - -} -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_deconvolution_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_deconvolution_pd.hpp deleted file mode 100644 index 164c8601d..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_deconvolution_pd.hpp +++ /dev/null @@ -1,46 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. 
-* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef CPU_DECONVOLUTION_PD_HPP -#define CPU_DECONVOLUTION_PD_HPP - -#include - -#include "deconvolution_pd.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -struct cpu_deconvolution_fwd_pd_t: public deconvolution_fwd_pd_t { - using deconvolution_fwd_pd_t::deconvolution_fwd_pd_t; -}; - -struct cpu_deconvolution_bwd_data_pd_t: public deconvolution_bwd_data_pd_t { - using deconvolution_bwd_data_pd_t::deconvolution_bwd_data_pd_t; -}; - -struct cpu_deconvolution_bwd_weights_pd_t: public deconvolution_bwd_weights_pd_t { - using deconvolution_bwd_weights_pd_t::deconvolution_bwd_weights_pd_t; -}; - -} -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_eltwise_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_eltwise_pd.hpp deleted file mode 100644 index c52f00026..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_eltwise_pd.hpp +++ /dev/null @@ -1,45 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef CPU_ELTWISE_PD_HPP -#define CPU_ELTWISE_PD_HPP - -#include - -#include "c_types_map.hpp" -#include "eltwise_pd.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -struct cpu_eltwise_fwd_pd_t: public eltwise_fwd_pd_t { - using eltwise_fwd_pd_t::eltwise_fwd_pd_t; -}; - -struct cpu_eltwise_bwd_pd_t: public eltwise_bwd_pd_t { - using eltwise_bwd_pd_t::eltwise_bwd_pd_t; -}; - -} -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_engine.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_engine.cpp deleted file mode 100644 index ce0a3667a..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_engine.cpp +++ /dev/null @@ -1,324 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include - -#include "type_helpers.hpp" -#include "verbose.hpp" - -#include "cpu_engine.hpp" -#include "cpu_memory.hpp" - -//#include "cpu/rnn/ref_rnn.hpp" - -//#include "cpu/jit_avx512_core_x8s8s32x_1x1_convolution.hpp" -//#include "cpu/jit_avx512_common_1x1_convolution.hpp" -#include "cpu/jit_avx512_core_fp32_wino_conv_4x3.hpp" -#include "cpu/jit_avx512_common_convolution_winograd.hpp" -//#include "cpu/jit_avx512_core_x8s8s32x_convolution.hpp" -#include "cpu/jit_avx512_common_convolution.hpp" -//#include "cpu/jit_avx2_1x1_convolution.hpp" -//#include "cpu/jit_sse42_1x1_convolution.hpp" -#include "cpu/jit_avx2_convolution.hpp" -#include "cpu/jit_sse42_convolution.hpp" -//#include "cpu/gemm_convolution.hpp" -//#include "cpu/gemm_x8s8s32x_convolution.hpp" -//#include "cpu/ref_convolution.hpp" -//#include "cpu/jit_avx512_core_x8s8s32x_deconvolution.hpp" -//#include "cpu/jit_avx512_core_x8s8s32x_1x1_deconvolution.hpp" -//#include "cpu/ref_deconvolution.hpp" -//#include "cpu/ref_shuffle.hpp" -//#include "cpu/jit_uni_eltwise.hpp" -//#include "cpu/ref_eltwise.hpp" -//#include "cpu/ref_softmax.hpp" -#include "cpu/jit_uni_pooling.hpp" -//#include "cpu/jit_uni_i8i8_pooling.hpp" -//#include "cpu/ref_pooling.hpp" -//#include "cpu/nchw_pooling.hpp" -//#include "cpu/nhwc_pooling.hpp" -//#include "cpu/jit_avx512_common_lrn.hpp" -//#include "cpu/jit_uni_lrn.hpp" -//#include "cpu/ref_lrn.hpp" -//#include "cpu/jit_uni_batch_normalization.hpp" -//#include "cpu/ref_batch_normalization.hpp" -//#include "cpu/ncsp_batch_normalization.hpp" -//#include "cpu/nspc_batch_normalization.hpp" -//#include "cpu/ref_inner_product.hpp" -//#include "cpu/gemm_inner_product.hpp" -//#include "cpu/gemm_x8s8s32x_inner_product.hpp" -//#include "cpu/jit_uni_dw_convolution.hpp" -//#include "cpu/jit_avx512_core_u8s8s32x_wino_convolution.hpp" -#include "cpu/jit_avx512_core_fp32_wino_conv_2x3.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -status_t cpu_engine_t::memory_create(memory_t **memory, - const memory_desc_t *md, void *handle) { - auto _memory = new cpu_memory_t(this, md, handle); - if (_memory == nullptr) - return status::out_of_memory; - - status_t status = _memory->init(); - if (status != status::success) { - delete _memory; - return status; - } - - return safe_ptr_assign(*memory, _memory); -} - -using pd_create_f = mkldnn::impl::engine_t::primitive_desc_create_f; - -namespace { -using namespace mkldnn::impl::data_type; - -#define INSTANCE(...) 
&primitive_desc_t::create<__VA_ARGS__::pd_t> -static const pd_create_f cpu_impl_list[] = { - /* RNN */ - /* - INSTANCE(ref_rnn_fwd_f32_t), - INSTANCE(ref_rnn_fwd_u8s8_t), - INSTANCE(ref_rnn_bwd_f32_t), - */ - /* conv */ - /* - INSTANCE(jit_avx512_common_dw_convolution_fwd_t), - INSTANCE(jit_avx512_common_dw_convolution_bwd_data_t), - INSTANCE(jit_avx512_common_dw_convolution_bwd_weights_t), - INSTANCE(jit_avx512_common_1x1_convolution_fwd_f32_t), - INSTANCE(jit_avx512_common_1x1_convolution_bwd_data_f32_t), - INSTANCE(jit_avx512_common_1x1_convolution_bwd_weights_t), - */ - INSTANCE(jit_avx512_core_fp32_wino_conv_2x3_fwd_t), - INSTANCE(jit_avx512_core_fp32_wino_conv_4x3_fwd_t), - //INSTANCE(jit_avx512_core_fp32_wino_conv_4x3_bwd_data_t), - //INSTANCE(jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t), - INSTANCE(jit_avx512_common_convolution_winograd_fwd_t), - //INSTANCE(jit_avx512_common_convolution_winograd_bwd_data_t), - //INSTANCE(jit_avx512_common_convolution_winograd_bwd_weights_t), - INSTANCE(jit_avx512_common_convolution_fwd_t), - //INSTANCE(jit_avx512_common_convolution_bwd_data_t), - //INSTANCE(jit_avx512_common_convolution_bwd_weights_t), - /* - INSTANCE(jit_avx2_dw_convolution_fwd_t), - INSTANCE(jit_avx2_dw_convolution_bwd_data_t), - INSTANCE(jit_avx2_dw_convolution_bwd_weights_t), - INSTANCE(jit_avx2_1x1_convolution_fwd_t), - INSTANCE(jit_avx2_1x1_convolution_bwd_data_t), - INSTANCE(jit_avx2_1x1_convolution_bwd_weights_t), - INSTANCE(jit_sse42_dw_convolution_fwd_t), - INSTANCE(jit_sse42_dw_convolution_bwd_data_t), - INSTANCE(jit_sse42_dw_convolution_bwd_weights_t), - INSTANCE(jit_sse42_1x1_convolution_fwd_t), - */ - INSTANCE(jit_avx2_convolution_fwd_t), - //INSTANCE(jit_avx2_convolution_bwd_data_t), - //INSTANCE(jit_avx2_convolution_bwd_weights_t), - INSTANCE(jit_sse42_convolution_fwd_t), - /* - INSTANCE(gemm_convolution_fwd_t), - INSTANCE(gemm_convolution_bwd_data_t), - INSTANCE(gemm_convolution_bwd_weights_t), - INSTANCE(ref_convolution_fwd_t), - INSTANCE(ref_convolution_bwd_data_t), - INSTANCE(ref_convolution_bwd_weights_t), - */ - /* conv (int) */ - /* - INSTANCE(jit_avx512_core_u8s8s32x_wino_convolution_fwd_t), - INSTANCE(jit_avx512_core_u8s8s32x_wino_convolution_fwd_t), - INSTANCE(jit_avx512_core_u8s8s32x_wino_convolution_fwd_t), - INSTANCE(jit_avx512_core_u8s8s32x_wino_convolution_fwd_t), - INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t), - INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t), - INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t), - INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t), - INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t), - INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t), - INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t), - INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t), - INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t), - INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t), - INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t), - INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t), - INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t), - INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t), - INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t), - INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t), - INSTANCE(_gemm_x8s8s32x_convolution_fwd_t), - INSTANCE(_gemm_x8s8s32x_convolution_fwd_t), - INSTANCE(_gemm_x8s8s32x_convolution_fwd_t), - INSTANCE(_gemm_x8s8s32x_convolution_fwd_t), - INSTANCE(_gemm_x8s8s32x_convolution_fwd_t), - INSTANCE(_gemm_x8s8s32x_convolution_fwd_t), - 
INSTANCE(_gemm_x8s8s32x_convolution_fwd_t), - INSTANCE(_gemm_x8s8s32x_convolution_fwd_t), - INSTANCE(_gemm_u8s8s32x_convolution_bwd_data_t), - INSTANCE(_gemm_u8s8s32x_convolution_bwd_data_t), - INSTANCE(_gemm_u8s8s32x_convolution_bwd_data_t), - INSTANCE(_gemm_u8s8s32x_convolution_bwd_data_t), - INSTANCE(ref_convolution_fwd_t), - INSTANCE(ref_convolution_fwd_t), - INSTANCE(ref_convolution_fwd_t), - INSTANCE(ref_convolution_fwd_t), - INSTANCE(ref_convolution_bwd_data_t), - INSTANCE(ref_convolution_bwd_data_t), - INSTANCE(ref_convolution_bwd_data_t), - INSTANCE(ref_convolution_bwd_data_t), - */ - /* deconv */ - /* - INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t), - INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t), - INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t), - INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t), - INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t), - INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t), - INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t), - INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t), - INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t), - INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t), - INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t), - INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t), - INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t), - INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t), - INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t), - INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t), - INSTANCE(ref_deconvolution_bwd_weights_t), - INSTANCE(ref_deconvolution_bwd_data_t), - INSTANCE(ref_deconvolution_fwd_t), - */ - /* shuffle */ - /* - INSTANCE(ref_shuffle_t<4>), // f32 or s32 - INSTANCE(ref_shuffle_t<1>), // s8 or u8 - */ - /* eltwise */ - /* - INSTANCE(jit_uni_eltwise_fwd_t), - INSTANCE(jit_uni_eltwise_bwd_t), - INSTANCE(jit_uni_eltwise_fwd_t), - INSTANCE(jit_uni_eltwise_bwd_t), - INSTANCE(jit_uni_eltwise_fwd_t), - INSTANCE(jit_uni_eltwise_bwd_t), - INSTANCE(ref_eltwise_fwd_t), - INSTANCE(ref_eltwise_bwd_t), - */ - /* eltwise (int) */ - /* - INSTANCE(ref_eltwise_fwd_t), - INSTANCE(ref_eltwise_fwd_t), - INSTANCE(ref_eltwise_fwd_t), - INSTANCE(ref_eltwise_bwd_t), - */ - /* softmax */ - /* - INSTANCE(ref_softmax_fwd_t), - INSTANCE(ref_softmax_bwd_t), - */ - /* pool */ - INSTANCE(jit_uni_pooling_fwd_t), - //INSTANCE(jit_uni_pooling_bwd_t), - INSTANCE(jit_uni_pooling_fwd_t), - //INSTANCE(jit_uni_pooling_bwd_t), - INSTANCE(jit_uni_pooling_fwd_t), - //INSTANCE(jit_uni_pooling_bwd_t), - /* - INSTANCE(nchw_pooling_fwd_t), - INSTANCE(nchw_pooling_bwd_t), - INSTANCE(nhwc_pooling_fwd_t), - INSTANCE(nhwc_pooling_bwd_t), - INSTANCE(ref_pooling_fwd_t), - INSTANCE(ref_pooling_bwd_t), - */ - /* pool (int) */ - /* - INSTANCE(jit_uni_i8i8_pooling_fwd_t), - INSTANCE(jit_uni_i8i8_pooling_fwd_t), - INSTANCE(ref_pooling_fwd_t), - INSTANCE(ref_pooling_fwd_t), - INSTANCE(ref_pooling_fwd_t), - INSTANCE(ref_pooling_bwd_t), - */ - /* lrn */ - /* - INSTANCE(jit_avx512_common_lrn_fwd_t), - INSTANCE(jit_avx512_common_lrn_bwd_t), - INSTANCE(jit_uni_lrn_fwd_t), - INSTANCE(jit_uni_lrn_bwd_t), - INSTANCE(jit_uni_lrn_fwd_t), - INSTANCE(ref_lrn_fwd_t), - INSTANCE(ref_lrn_bwd_t), - */ - /* batch normalization */ - /* - INSTANCE(jit_uni_batch_normalization_fwd_t), - INSTANCE(jit_uni_batch_normalization_bwd_t), - INSTANCE(jit_uni_batch_normalization_fwd_t), - INSTANCE(jit_uni_batch_normalization_bwd_t), - INSTANCE(jit_uni_batch_normalization_fwd_t), - 
INSTANCE(jit_uni_batch_normalization_bwd_t), - INSTANCE(ncsp_batch_normalization_fwd_t), - INSTANCE(ncsp_batch_normalization_bwd_t), - INSTANCE(nspc_batch_normalization_fwd_t), - INSTANCE(nspc_batch_normalization_bwd_t), - INSTANCE(ref_batch_normalization_fwd_t), - INSTANCE(ref_batch_normalization_bwd_t), - INSTANCE(ref_batch_normalization_fwd_t), - */ - /* inner product */ - /* - INSTANCE(gemm_inner_product_fwd_t), - INSTANCE(gemm_inner_product_bwd_data_t), - INSTANCE(gemm_inner_product_bwd_weights_t), - INSTANCE(ref_inner_product_fwd_t), - INSTANCE(ref_inner_product_bwd_data_t), - INSTANCE(ref_inner_product_bwd_weights_t), - */ - /* inner product (int) */ - /* - INSTANCE(gemm_x8s8s32x_inner_product_fwd_t), - INSTANCE(gemm_x8s8s32x_inner_product_fwd_t), - INSTANCE(gemm_x8s8s32x_inner_product_fwd_t), - INSTANCE(gemm_x8s8s32x_inner_product_fwd_t), - INSTANCE(gemm_x8s8s32x_inner_product_fwd_t), - INSTANCE(gemm_x8s8s32x_inner_product_fwd_t), - INSTANCE(gemm_x8s8s32x_inner_product_fwd_t), - INSTANCE(gemm_x8s8s32x_inner_product_fwd_t), - INSTANCE(ref_inner_product_fwd_t), - INSTANCE(ref_inner_product_fwd_t), - INSTANCE(ref_inner_product_fwd_t), - INSTANCE(ref_inner_product_fwd_t), - */ - /* eol */ - nullptr, -}; -#undef INSTANCE -} - -const pd_create_f* cpu_engine_t::get_implementation_list() const { - return cpu_impl_list; -} - -cpu_engine_factory_t engine_factory; - -} -} -} - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_engine.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_engine.hpp deleted file mode 100644 index e4c877ee0..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_engine.hpp +++ /dev/null @@ -1,70 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#ifndef CPU_ENGINE_HPP -#define CPU_ENGINE_HPP - -#include - -#include "mkldnn.h" - -#include "c_types_map.hpp" -#include "../common/engine.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -class cpu_engine_t: public engine_t { -public: - cpu_engine_t(): engine_t(engine_kind::cpu) {} - - /* implementation part */ - - virtual status_t memory_create(memory_t **memory, - const memory_desc_t *md, void *handle) override; - - virtual const concat_primitive_desc_create_f* - get_concat_implementation_list() const override; - virtual const reorder_primitive_desc_create_f* - get_reorder_implementation_list() const override; - virtual const sum_primitive_desc_create_f* - get_sum_implementation_list() const override; - virtual const primitive_desc_create_f* - get_implementation_list() const override; -}; - -class cpu_engine_factory_t: public engine_factory_t { -public: - virtual size_t count() const override { return 1; } - virtual engine_kind_t kind() const override { return engine_kind::cpu; } - virtual status_t engine_create(engine_t **engine, - size_t index) const override { - assert(index == 0); - *engine = new cpu_engine_t(); - return status::success; - }; -}; - -extern cpu_engine_factory_t engine_factory; - -} -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_inner_product_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_inner_product_pd.hpp deleted file mode 100644 index 5880d3450..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_inner_product_pd.hpp +++ /dev/null @@ -1,84 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#ifndef CPU_INNER_PRODUCT_PD_HPP -#define CPU_INNER_PRODUCT_PD_HPP - -#include - -#include "c_types_map.hpp" -#include "inner_product_pd.hpp" -#include "utils.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -namespace { -inline bool dense_gemm_consitency_check(const memory_desc_wrapper &src_d, - const memory_desc_wrapper &wei_d, const memory_desc_wrapper &dst_d) { - using namespace utils; - - auto strides_compatible = [&]() { - bool ok = true; - auto w_str = wei_d.blocking_desc().strides; - auto d_str = src_d.blocking_desc().strides; - for (int i = 1; i < src_d.ndims() - 1; i++) { - ok = ok && w_str[i] / d_str[i] == w_str[i + 1] / d_str[i + 1]; - } - return ok && one_of(w_str[1] / d_str[1], 1, wei_d.padded_dims()[0]); - }; - return true && src_d.is_blocking_desc() && wei_d.is_blocking_desc() - && src_d.ndims() == wei_d.ndims() - && src_d.blocking_desc().inner_nblks - == wei_d.blocking_desc().inner_nblks - && utils::one_of(src_d.blocking_desc().inner_nblks, 0, 1) - && array_cmp(src_d.blocking_desc().inner_blks, - wei_d.blocking_desc().inner_blks, - wei_d.blocking_desc().inner_nblks) - && array_cmp(src_d.blocking_desc().inner_idxs, - wei_d.blocking_desc().inner_idxs, - wei_d.blocking_desc().inner_nblks) - && strides_compatible() - && dst_d.matches_tag(format_tag::nc) - && src_d.only_padded_dim(1) - && wei_d.only_padded_dim(1) - && src_d.padded_dims()[1] == wei_d.padded_dims()[1] - && src_d.is_dense(true) - && dst_d.is_dense() - && wei_d.is_dense(true); -} -} - -struct cpu_inner_product_fwd_pd_t: public inner_product_fwd_pd_t { - using inner_product_fwd_pd_t::inner_product_fwd_pd_t; -}; - -struct cpu_inner_product_bwd_data_pd_t: public inner_product_bwd_data_pd_t { - using inner_product_bwd_data_pd_t::inner_product_bwd_data_pd_t; -}; - -struct cpu_inner_product_bwd_weights_pd_t: public inner_product_bwd_weights_pd_t { - using inner_product_bwd_weights_pd_t::inner_product_bwd_weights_pd_t; -}; - -} -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_isa_traits.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_isa_traits.hpp deleted file mode 100644 index da6e9dac8..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_isa_traits.hpp +++ /dev/null @@ -1,151 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#ifndef CPU_ISA_TRAITS_HPP -#define CPU_ISA_TRAITS_HPP - -#include - -#define XBYAK64 -#define XBYAK_NO_OP_NAMES -/* in order to make selinux happy memory that would be marked with X-bit should - * be obtained with mmap */ -#define XBYAK_USE_MMAP_ALLOCATOR -#if defined(_MSC_VER) && !defined(__INTEL_COMPILER) -/* turn off `size_t to other-type implicit casting` warning - * currently we have a lot of jit-generated instructions that - * take uint32_t, but we pass size_t (e.g. due to using sizeof). - * FIXME: replace size_t parameters with the appropriate ones */ -#pragma warning (disable: 4267) -#endif -#include "xbyak/xbyak.h" -#include "xbyak/xbyak_util.h" - -namespace mkldnn { -namespace impl { -namespace cpu { - -typedef enum { - isa_any, - sse41, - sse42, - avx, - avx2, - avx512_common, - avx512_core, - avx512_core_vnni, - avx512_mic, - avx512_mic_4ops, -} cpu_isa_t; - -template struct cpu_isa_traits {}; /* ::vlen -> 32 (for avx2) */ - -template <> struct cpu_isa_traits { - typedef Xbyak::Xmm Vmm; - static constexpr int vlen_shift = 4; - static constexpr int vlen = 16; - static constexpr int n_vregs = 16; -}; -template <> struct cpu_isa_traits { - typedef Xbyak::Ymm Vmm; - static constexpr int vlen_shift = 5; - static constexpr int vlen = 32; - static constexpr int n_vregs = 16; -}; -template <> struct cpu_isa_traits: - public cpu_isa_traits {}; - -template <> struct cpu_isa_traits { - typedef Xbyak::Zmm Vmm; - static constexpr int vlen_shift = 6; - static constexpr int vlen = 64; - static constexpr int n_vregs = 32; -}; -template <> struct cpu_isa_traits: - public cpu_isa_traits {}; - -template <> struct cpu_isa_traits: - public cpu_isa_traits {}; - -template <> struct cpu_isa_traits: - public cpu_isa_traits {}; - -namespace { - -static Xbyak::util::Cpu cpu; -static inline bool mayiuse(const cpu_isa_t cpu_isa) { - using namespace Xbyak::util; - - switch (cpu_isa) { - case sse41: - case sse42: - // FIXME: SSE4.2 is actually NOT required - //return cpu.has(Cpu::tSSE42); - return cpu.has(Cpu::tSSE41); - case avx: - return cpu.has(Cpu::tAVX); - case avx2: - return cpu.has(Cpu::tAVX2); - case avx512_common: - return cpu.has(Cpu::tAVX512F); - case avx512_core: - return true - && cpu.has(Cpu::tAVX512F) - && cpu.has(Cpu::tAVX512BW) - && cpu.has(Cpu::tAVX512VL) - && cpu.has(Cpu::tAVX512DQ); - case avx512_core_vnni: - return true - && cpu.has(Cpu::tAVX512F) - && cpu.has(Cpu::tAVX512BW) - && cpu.has(Cpu::tAVX512VL) - && cpu.has(Cpu::tAVX512DQ) - && cpu.has(Cpu::tAVX512_VNNI); - case avx512_mic: - return true - && cpu.has(Cpu::tAVX512F) - && cpu.has(Cpu::tAVX512CD) - && cpu.has(Cpu::tAVX512ER) - && cpu.has(Cpu::tAVX512PF); - case avx512_mic_4ops: - return true - && mayiuse(avx512_mic) - && cpu.has(Cpu::tAVX512_4FMAPS) - && cpu.has(Cpu::tAVX512_4VNNIW); - case isa_any: - return true; - } - return false; -} -} - -/* whatever is required to generate string literals... */ -#include "z_magic.hpp" -#define JIT_IMPL_NAME_HELPER(prefix, isa, suffix_if_any) \ - (isa == sse42 ? prefix STRINGIFY(sse42) : \ - (isa == avx ? prefix STRINGIFY(avx) : \ - (isa == avx2 ? prefix STRINGIFY(avx2) : \ - (isa == avx512_common ? prefix STRINGIFY(avx512_common) : \ - (isa == avx512_core ? prefix STRINGIFY(avx512_core) : \ - (isa == avx512_mic ? prefix STRINGIFY(avx512_mic) : \ - (isa == avx512_mic_4ops ? 
prefix STRINGIFY(avx512_mic_4ops) : \ - prefix suffix_if_any))))))) - -} -} -} - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_lrn_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_lrn_pd.hpp deleted file mode 100644 index 49988f4c2..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_lrn_pd.hpp +++ /dev/null @@ -1,42 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef CPU_LRN_PD_HPP -#define CPU_LRN_PD_HPP - -#include - -#include "lrn_pd.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -struct cpu_lrn_fwd_pd_t: public lrn_fwd_pd_t { - using lrn_fwd_pd_t::lrn_fwd_pd_t; -}; - -struct cpu_lrn_bwd_pd_t: public lrn_bwd_pd_t { - using lrn_bwd_pd_t::lrn_bwd_pd_t; -}; - -} -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_memory.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_memory.cpp deleted file mode 100644 index 3c0624cf4..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_memory.cpp +++ /dev/null @@ -1,277 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include - -#include "mkldnn_traits.hpp" -#include "mkldnn_thread.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -#include "cpu_memory.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace mkldnn::impl; -using namespace mkldnn::impl::data_type; -using namespace mkldnn::impl::status; -using namespace mkldnn::impl::format_tag; - -enum blk_kind_t { a, b, c, ab, ba, bc, cb }; - -template -void typed_zero_pad_blk( - const memory_desc_wrapper &m_d, typename prec_traits
::type *data) { - using data_t = typename prec_traits
::type; - const auto &dims = m_d.dims(); - const auto &pdims = m_d.padded_dims(); - const auto &blk = m_d.blocking_desc(); - auto dim_is_blocked = [&](int dim) { - for (int i = 0; i < blk.inner_nblks; i++) - if (blk.inner_idxs[i] == dim) - return true; - return false; - }; - bool A_blocked = dim_is_blocked(0), B_blocked = dim_is_blocked(1), - C_blocked = dim_is_blocked(2); - - assert(blk.inner_nblks < 4); - assert((A_blocked || B_blocked || C_blocked) || (A_blocked && B_blocked) - || (C_blocked && B_blocked)); - - const int a_tail_s = A_blocked ? dims[0] % blksize : 0; - const int b_tail_s = B_blocked ? dims[1] % blksize : 0; - const int c_tail_s = C_blocked ? dims[2] % blksize : 0; - assert(a_tail_s || b_tail_s || c_tail_s); - - const int A = A_blocked ? pdims[0] / blksize : dims[0]; - const int B = B_blocked ? pdims[1] / blksize : dims[1]; - const int C = C_blocked ? pdims[2] / blksize : dims[2]; - const int D = m_d.ndims() > 3 ? dims[3] : 1; - const int E = m_d.ndims() > 4 ? dims[4] : 1; - const int F = m_d.ndims() > 5 ? dims[5] : 1; - const int inner_blk = blk.inner_nblks == 3 ? blk.inner_blks[2] : 1; - - auto zeroize_tail = [&](data_t *d, const int tail_s) { - for (int b = tail_s; b < blksize; ++b) - d[b] = 0; - }; - auto zeroize_tail_inner = [&](data_t *d, const int tail_s) { - for (int b1 = 0; b1 < blksize; ++b1) - for (int b2 = tail_s; b2 < blksize; ++b2) - d[(b1 / inner_blk) * blksize * inner_blk + inner_blk * b2 - + b1 % inner_blk] - = 0; - }; - auto zeroize_tail_outer = [&](data_t *d, const int tail_s) { - for (int b1 = tail_s; b1 < blksize; ++b1) - for (int b2 = 0; b2 < blksize; ++b2) - d[(b1 / inner_blk) * blksize * inner_blk + inner_blk * b2 - + b1 % inner_blk] - = 0; - }; - - if (c_tail_s) { - parallel_nd(A, B, D, E, F, [&](int a, int b, int d, int e, int f) { - auto x = &data[m_d.blk_off(a, b, C - 1, d, e, f)]; - if (blk_kind == c) - zeroize_tail(x, c_tail_s); - else if (blk_kind == bc) - zeroize_tail_inner(x, c_tail_s); - else if (blk_kind == cb) - zeroize_tail_outer(x, c_tail_s); - }); - } - - if (b_tail_s) { - parallel_nd(A, C, D, E, F, [&](int a, int c, int d, int e, int f) { - auto x = &data[m_d.blk_off(a, B - 1, c, d, e, f)]; - if (blk_kind == b) - zeroize_tail(x, b_tail_s); - else if (blk_kind == ab || blk_kind == cb) - zeroize_tail_inner(x, b_tail_s); - else if (blk_kind == ba || blk_kind == bc) - zeroize_tail_outer(x, b_tail_s); - }); - } - - if (a_tail_s) { - parallel_nd(B, C, D, E, F, [&](int b, int c, int d, int e, int f) { - auto x = &data[m_d.blk_off(A - 1, b, c, d, e, f)]; - if (blk_kind == a) - zeroize_tail(x, a_tail_s); - else if (blk_kind == ba) - zeroize_tail_inner(x, a_tail_s); - else if (blk_kind == ab) - zeroize_tail_outer(x, a_tail_s); - }); - } -} - -/* - * all - */ -template -void typed_zero_pad_generic_blocked( - const memory_desc_wrapper &m_d, typename prec_traits
::type *data) { - const int ndims = m_d.ndims(); - const auto &dims = m_d.dims(); - const auto &pdims = m_d.padded_dims(); - - const ptrdiff_t nelems = (ptrdiff_t)m_d.nelems(true); - - /* [D_0] .. [D_k][D_k+1] .. [D_ndim - 1] - * | \ / - * | --------------------- - * has contiguous - * padding - * - * step <-- D_k+1 * ... * D_ndims-1 - * step_dim <-- k - */ - - ptrdiff_t step = 1; - int step_dim = ndims - 1; - for (; step_dim >= 0; --step_dim) { - if (dims[step_dim] != pdims[step_dim]) - break; - step *= dims[step_dim]; - } - - assert(step_dim >= 0 && "no zero padding is required"); - if (step_dim < 0) - return; - - parallel_nd(nelems / step, [&](ptrdiff_t e1) { - bool need_zero = false; - - ptrdiff_t idx = e1; - for (int d = step_dim; d >= 0; --d) { - if (idx % pdims[d] >= dims[d]) { - need_zero = true; - break; - } - idx /= pdims[d]; - } - - if (need_zero) { - for (ptrdiff_t e0 = 0; e0 < step; ++e0) - data[m_d.off_l(e1 * step + e0, true)] = 0; - } - }); -} - -template -status_t cpu_memory_t::typed_zero_pad() const { - const memory_desc_wrapper mdw(md()); - - if (mdw.format_kind() != format_kind::blocked) - return unimplemented; - - if (mdw.nelems(false) == mdw.nelems(true)) - return success; - - auto *data = (typename prec_traits
::type *)data_; - auto blk = mdw.blocking_desc(); - - auto get_blksize = [&](int ind) { - int blksize = 1; - for (int i = 0; i < blk.inner_nblks; i++) { - if (blk.inner_idxs[i] == ind) - blksize *= blk.inner_blks[i]; - } - return blksize; - }; - const int blksize = get_blksize(blk.inner_idxs[0]); - -# define CASE(blksize_, blk_kind) \ - do { \ - if (blksize == blksize_) { \ - typed_zero_pad_blk(mdw, data); \ - return success; \ - } \ - } while(0) - - switch (blk.inner_nblks) { - case 1: - if (blk.inner_idxs[0] == 0) { - CASE(4, a); - CASE(8, a); - CASE(16, a); - } else if (blk.inner_idxs[0] == 1) { - CASE(4, b); - CASE(8, b); - CASE(16, b); - } - break; - case 2: - case 3: - if (!IMPLICATION(blk.inner_nblks == 3, - blk.inner_idxs[0] == blk.inner_idxs[2])) - break; - - if (blk.inner_idxs[0] == 0 && blk.inner_idxs[1] == 1) { - CASE(4, ab); - CASE(8, ab); - CASE(16, ab); - } else if (blk.inner_idxs[0] == 1 && blk.inner_idxs[1] == 0) { - CASE(4, ba); - CASE(8, ba); - CASE(16, ba); - } - if (blk.inner_idxs[0] == 1 && blk.inner_idxs[1] == 2) { - CASE(4, bc); - CASE(8, bc); - CASE(16, bc); - } else if (blk.inner_idxs[0] == 2 && blk.inner_idxs[1] == 1) { - CASE(4, cb); - CASE(8, cb); - CASE(16, cb); - } - break; - default: break; - } - -# undef CASE - - // the last line of defence - typed_zero_pad_generic_blocked
(mdw, data); - return success; -} - -status_t cpu_memory_t::zero_pad() const { - memory_desc_wrapper mdw(md()); - const bool skip_zeroing = false - || data_ == nullptr - || mdw.is_zero() - || !mdw.is_blocking_desc(); - if (skip_zeroing) return success; - - switch (mdw.data_type()) { - case f32: return typed_zero_pad(); - case s32: return typed_zero_pad(); - case s8: return typed_zero_pad(); - case u8: return typed_zero_pad(); - default: assert(!"memory is undefined"); return unimplemented; - } - return unimplemented; -} - -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_memory.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_memory.hpp deleted file mode 100644 index 2c01bcc6a..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_memory.hpp +++ /dev/null @@ -1,89 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef CPU_MEMORY_HPP -#define CPU_MEMORY_HPP - -#include - -#include "c_types_map.hpp" -#include "memory.hpp" -#include "memory_desc_wrapper.hpp" - -#include "cpu_engine.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -struct cpu_memory_t: public memory_t { - cpu_memory_t(cpu_engine_t *engine, const memory_desc_t *md, void *handle) - : memory_t(engine, md) - , own_data_(handle == MKLDNN_NATIVE_HANDLE_ALLOCATE) - , data_((char *)handle) {} - - cpu_memory_t(cpu_engine_t *engine, const memory_desc_t *md) - : cpu_memory_t(engine, md, nullptr) {} - - ~cpu_memory_t() { if (own_data_) free(data_); } - - virtual status_t init() override { - if (own_data_) { - data_ = nullptr; - const size_t size = memory_desc_wrapper(this->md()).size(); - if (size) { - data_ = (char *)malloc(size, 64); - if (data_ == nullptr) - return status::out_of_memory; - } - } - return zero_pad(); - } - - cpu_engine_t *engine() const { return (cpu_engine_t *)memory_t::engine(); } - - virtual status_t get_data_handle(void **handle) const override { - *handle = static_cast(data_); - return status::success; - } - - virtual mkldnn::impl::status_t set_data_handle(void *handle) override { - if (own_data_) { free(data_); own_data_ = false; } - data_ = static_cast(handle); - return zero_pad(); - } - - virtual mkldnn::impl::status_t zero_pad() const override; - -private: - bool own_data_; - char *data_; - - template - mkldnn::impl::status_t typed_zero_pad() const; - - cpu_memory_t(const cpu_memory_t &) = delete; - cpu_memory_t &operator=(const cpu_memory_t &) = delete; - cpu_memory_t &operator=(cpu_memory_t &&) = delete; -}; - -} -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_pooling_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_pooling_pd.hpp deleted file mode 100644 index ac2daa415..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_pooling_pd.hpp +++ /dev/null @@ -1,40 +0,0 @@ 
-/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef CPU_POOLING_PD_HPP -#define CPU_POOLING_PD_HPP - -#include "pooling_pd.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -struct cpu_pooling_fwd_pd_t: public pooling_fwd_pd_t { - using pooling_fwd_pd_t::pooling_fwd_pd_t; -}; - -struct cpu_pooling_bwd_pd_t: public pooling_bwd_pd_t { - using pooling_bwd_pd_t::pooling_bwd_pd_t; -}; - -} -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_primitive.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_primitive.hpp deleted file mode 100644 index 56127f36c..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_primitive.hpp +++ /dev/null @@ -1,83 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef CPU_PRIMITIVE_HPP -#define CPU_PRIMITIVE_HPP - -#include "mkldnn.h" - -#include "c_types_map.hpp" -#include "memory_tracking.hpp" -#include "primitive.hpp" -#include "scratchpad.hpp" - -#define CTX_IN_MEM(type, arg) static_cast(ctx.input(arg)) -#define CTX_OUT_MEM(type, arg) static_cast(ctx.output(arg)) - -namespace mkldnn { -namespace impl { -namespace cpu { - -struct cpu_memory_t; - -struct cpu_primitive_t: public primitive_t { - cpu_primitive_t(const primitive_desc_t *pd, - bool use_global_scratchpad = false) - : primitive_t(pd) - , scratchpad_buffer_(nullptr) - , global_scratchpad_(nullptr) - { - const size_t scratchpad_size = - this->pd()->scratchpad_size(scratchpad_mode::library); - - if (scratchpad_size) { - if (use_global_scratchpad) - global_scratchpad_ = create_scratchpad(scratchpad_size); - else - scratchpad_buffer_ = malloc(scratchpad_size, 64); - } - } - - virtual ~cpu_primitive_t() { - delete global_scratchpad_; - free(scratchpad_buffer_); - } - -protected: - memory_tracking::grantor_t scratchpad(const exec_ctx_t &ctx) const { - void *ptr = nullptr; - if (pd()->attr()->scratchpad_mode_ == scratchpad_mode::user) { - ptr = CTX_OUT_MEM(void *, MKLDNN_ARG_SCRATCHPAD); - } else { - ptr = global_scratchpad_ - ? 
global_scratchpad_->get() : scratchpad_buffer_; - } - - return pd()->scratchpad_registry().grantor(ptr); - } - -private: - void *scratchpad_buffer_; - scratchpad_t *global_scratchpad_; -}; - -} -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reducer.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reducer.cpp deleted file mode 100644 index 1d41ac5ce..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reducer.cpp +++ /dev/null @@ -1,544 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include - -#include "mkldnn_thread.hpp" -#include "mkldnn_types.h" -#include "nstl.hpp" -#include "utils.hpp" - -#include "cpu_reducer.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace memory_tracking::names; - -void reduce_balancer_t::balance() { - using namespace nstl; - using namespace utils; - - assert(nthr_ > 0 && job_size_ > 0 && njobs_ > 0 && reduction_size_ > 0); - - const int job_complexity = 1; - - const int min_njobs_per_group = max(1, njobs_ / nthr_); - const int max_njobs_per_group = max(1, - static_cast(max_buffer_size_ / (nthr_ * job_size_))); - - /* initial guess */ - int ngroups = min(njobs_ / min_njobs_per_group, nthr_); - int nthr_per_group = syncable_ ? min(nthr_ / ngroups, reduction_size_) : 1; - int njobs_per_group_ub = div_up(njobs_, ngroups); - - /* rough upper-bound estimation, will be fixed during brute force */ - size_t thread_complexity_ub = njobs_ * job_size_ * reduction_size_; - - /* brute force parameters for the best balance... */ - for (int c_njobs_per_group = min_njobs_per_group; - c_njobs_per_group < njobs_; ++c_njobs_per_group) { - /* current assumption */ - int c_ngroups = min(njobs_ / c_njobs_per_group, nthr_); - int c_nthr_per_group = syncable_ - ? 
min(nthr_ / c_ngroups, reduction_size_) : 1; - int c_njobs_per_group_ub = div_up(njobs_, c_ngroups); - - if (c_nthr_per_group > 1 && c_njobs_per_group_ub > max_njobs_per_group) - continue; - - int c_thread_reduction_ub = div_up(reduction_size_, c_nthr_per_group); - size_t c_group_size_ub = job_size_ * c_njobs_per_group_ub; - size_t c_thread_complexity_ub = c_group_size_ub * ( - job_complexity * c_thread_reduction_ub - + (c_nthr_per_group != 1)); - - if (c_thread_complexity_ub < thread_complexity_ub) { - ngroups = c_ngroups; - nthr_per_group = c_nthr_per_group; - njobs_per_group_ub = c_njobs_per_group_ub; - thread_complexity_ub = c_thread_complexity_ub; - } - } - - assert(njobs_per_group_ub <= max_njobs_per_group || nthr_per_group == 1); - assert(ngroups * nthr_per_group <= nthr_); - assert((size_t)njobs_per_group_ub * job_size_ * nthr_ <= max_buffer_size_ - || nthr_per_group == 1); /* no reduction buffer overflow */ - assert(IMPLICATION(!syncable_, nthr_per_group == 1)); - - ngroups_ = ngroups; - nthr_per_group_ = nthr_per_group; - njobs_per_group_ub_ = njobs_per_group_ub; -} - -/* reducer jit-ted driver */ - -using namespace Xbyak; - -template -struct reducer_2d_driver_t: public c_compatible { - typedef typename prec_traits::type data_t; - - reducer_2d_driver_t(int n_src, size_t src_ld, - size_t src_step, size_t dst_step, bool nullify_dst) - : n_src_(n_src), src_ld_(src_ld), src_step_(src_step) - , dst_step_(dst_step), nullify_dst_(nullify_dst), ker_(nullptr) {} - virtual ~reducer_2d_driver_t() {} - void operator()(data_t *dst, const data_t *srcs, size_t ny, size_t nx) - { assert(ker_); ker_(dst, srcs, ny, nx); } - -protected: - int n_src_; - size_t src_ld_, src_step_, dst_step_; - bool nullify_dst_; - void (*ker_)(data_t *dst, const data_t *srcs, size_t ny, size_t nx); -}; - -template -struct reducer_2d_driver_f_s_32_t: public reducer_2d_driver_t, - public jit_generator -{ - DECLARE_CPU_JIT_AUX_FUNCTIONS(reducer_2d_driver_f_s_32_t) - - /* cpu specific part */ - using Vmm = typename utils::conditional::type; - const AddressFrame &vmmword = (isa == avx2) ? yword : zword; - void uni_vadd(const Xmm& x1, const Xmm& x2, const Operand& op) - { if (data_type == data_type::f32) vaddps(x1, x2, op); - else vpaddd(x1, x2, op); } - void uni_add(const Xmm& x1, const Operand& op) - { if (data_type == data_type::f32) addss(x1, op); else paddd(x1, op); } - - const int vlen = cpu_isa_traits::vlen; - const int typesize - = sizeof(typename mkldnn::impl::prec_traits::type); - Xbyak::Reg64 reg_dst = abi_param1; - Xbyak::Reg64 reg_src = abi_param2; - Xbyak::Reg64 reg_ny = abi_param3; - Xbyak::Reg64 reg_nx = abi_param4; - - Xbyak::Reg64 reg_x = rax; - Xbyak::Reg64 reg_src_id = r10; - - reducer_2d_driver_f_s_32_t(int n_src, size_t src_ld, size_t src_step, - size_t dst_step, bool nullify_dst) - : reducer_2d_driver_t(n_src, src_ld, src_step, - dst_step, nullify_dst) - { generate(); } - - void nullify_dst(int nloads, int load_len) { - UNUSED(load_len); - for (int i = 0; i < nloads; ++i) - uni_vpxor(Vmm(i), Vmm(i), Vmm(i)); - /* prefetches[dst] ? 
*/ - } - - void load_dst(int nloads, int load_len) { - for (int i = 0; i < nloads; ++i) { - if (load_len == typesize) - movd(Xmm(i), ptr[reg_dst + i * load_len]); - else if (load_len == vlen) - vmovups(Vmm(i), ptr[reg_dst + i * load_len]); - else - assert(!"unsupported"); - } - } - - void store_dst(int nloads, int load_len) { - for (int i = 0; i < nloads; ++i) { - if (load_len == typesize) - movd(ptr[reg_dst + i * load_len], Xmm(i)); - else if (load_len == vlen) - vmovups(ptr[reg_dst + i * load_len], Vmm(i)); - else - assert(!"unsupported"); - } - } - - void accumulate(int nloads, int load_len, size_t base_off) { - for (int i = 0; i < nloads; ++i) { - size_t off = base_off + i * load_len; - - if (load_len == typesize) - uni_add(Xmm(i), ptr[reg_src + off]); - else if (load_len == vlen) - uni_vadd(Vmm(i), Vmm(i), vmmword[reg_src + off]); - else - assert(!"unsupported"); - } - } - - void loop_x() { - const int nloads[] = {cpu_isa_traits::n_vregs, 1, 1}; - const int nbranches = sizeof(nloads) / sizeof(nloads[0]); - - const int load_len[nbranches] = {vlen, vlen, typesize}; - Label loop_x_label[nbranches + 1]; - - mov(reg_x, reg_nx); - - for (int id = 0; id < nbranches; ++id) { - L(loop_x_label[id]); - - cmp(reg_x, nloads[id] * load_len[id]); - jl(loop_x_label[id + 1], T_NEAR); - - if (this->nullify_dst_) - nullify_dst(nloads[id], load_len[id]); - else - load_dst(nloads[id], load_len[id]); - - if (nloads[id] > 1) { - Label loop_srcs; - mov(reg_src_id, this->n_src_); - L(loop_srcs); - - accumulate(nloads[id], load_len[id], 0); - add(reg_src, this->src_ld_ * typesize); - - dec(reg_src_id); - jnz(loop_srcs, T_NEAR); - - sub(reg_src, this->n_src_ * this->src_ld_ * typesize); - } else { - for (int src_id = 0; src_id < this->n_src_; ++src_id) { - const size_t base_off = src_id * this->src_ld_ * typesize; - accumulate(nloads[id], load_len[id], base_off); - } - } - - store_dst(nloads[id], load_len[id]); - - add(reg_src, nloads[id] * load_len[id]); - add(reg_dst, nloads[id] * load_len[id]); - - sub(reg_x, nloads[id] * load_len[id]); - - jmp(loop_x_label[id], T_NEAR); - } - - L(loop_x_label[nbranches]); - - /* restore address registers */ - sub(reg_src, reg_nx); - sub(reg_dst, reg_nx); - } - - void generate() { - assert(isa == avx2 || isa == avx512_common || isa == avx512_mic); - - preamble(); - - shl(reg_nx, 2); - - Label ny_loop; - L(ny_loop); - - loop_x(); - - add(reg_dst, this->dst_step_ * typesize); - add(reg_src, this->src_step_ * typesize); - - dec(reg_ny); - jnz(ny_loop, T_NEAR); - - postamble(); - this->ker_ = reinterpret_castker_)>( - const_cast(this->getCode())); - } -}; - -template -inline reducer_2d_driver_t *create_reduce_2d_drv(int n_src, - size_t src_ld, size_t src_step, size_t dst_step, bool nullify_dst) { - if (mayiuse(avx512_common)) - return new reducer_2d_driver_f_s_32_t(n_src, - src_ld, src_step, dst_step, nullify_dst); - else if (mayiuse(avx2)) - return new reducer_2d_driver_f_s_32_t(n_src, src_ld, - src_step, dst_step, nullify_dst); - assert(!"unimplemented"); - return nullptr; -} - -/* cpu_reducer_t */ - -template -void cpu_reducer_t::conf_t::init_scratchpad( - memory_tracking::registrar_t &scratchpad) const { - if (balancer_.nthr_per_group_ == 1) return; - - const size_t space_size = balancer_.ngroups_ - * (balancer_.nthr_per_group_ - 1) - * cpu_reducer_t::space_per_thread(balancer_); - scratchpad.book(key_reducer_space, sizeof(data_t) * space_size, PAGE_4K); - scratchpad.book(key_reducer_space_bctx, - sizeof(simple_barrier::ctx_t) * balancer_.ngroups_); -} - -template 
-cpu_reducer_t::cpu_reducer_t(const conf_t &conf) - : conf_(conf), drv_(nullptr) -{ - if (balancer().nthr_per_group_ == 1) return; - - drv_ = create_reduce_2d_drv(balancer().nthr_per_group_ - 1, - space_per_thread(balancer()), 0, 0, false); -} - -template -cpu_reducer_t::~cpu_reducer_t() { delete drv_; } - -template -typename cpu_reducer_t::data_t * -cpu_reducer_t::get_local_ptr(int ithr, data_t *dst, - const memory_tracking::grantor_t &scratchpad) const { - const int id_in_grp = balancer().id_in_group(ithr); - - /* threads 0 from each group writes directly to the destination */ - if (id_in_grp == 0) - return dst + balancer().ithr_job_off(ithr) * balancer().job_size_; - - const int grp_id = balancer().group_id(ithr); - const int offset_factor = grp_id * (balancer().nthr_per_group_ - 1) - + (id_in_grp - 1); - - auto space = scratchpad.template get(key_reducer_space); - return space + offset_factor * space_per_thread(balancer()); -} - -template -void cpu_reducer_t::reduce_nolock(int ithr, data_t *dst, - const memory_tracking::grantor_t &scratchpad) const { - bool redundant_reduction = balancer().nthr_per_group_ == 1 - || balancer().idle(ithr); - if (redundant_reduction) return; - -#ifdef SIMPLE_IMPL - if (balancer().id_in_group(ithr) != 0) - return; /* only threads 0 do the reduction */ - - const int njobs_in_grp = balancer().ithr_njobs(ithr); - data_t *d = get_local_ptr(ithr, dst, scratchpad); - for (int id_in_grp = 1; id_in_grp < balancer_.nthr_per_group_; ++id_in_grp) - { - const data_t *space = get_local_ptr(ithr + id_in_grp, dst, scratchpad); - for (size_t i = 0; i < (size_t)njobs_in_grp * balancer().job_size_; ++i) - d[i] += space[i]; - } -#else - using namespace utils; - - const int id_in_grp = balancer().id_in_group(ithr); - const int njobs_in_grp = balancer().ithr_njobs(ithr); - const size_t cl = 64 / sizeof(data_t); - - const size_t reduction_size = njobs_in_grp * balancer().job_size_; - size_t start{0}, end{0}; - balance211(div_up(reduction_size, cl), balancer().nthr_per_group_, - id_in_grp, start, end); - - if (start == end) return; - - data_t *d = get_local_ptr(ithr - id_in_grp, dst, scratchpad) + start * cl; - const data_t *space = get_local_ptr(ithr - id_in_grp + 1, dst, scratchpad) - + start * cl; - const size_t len = nstl::min(end * cl, reduction_size) - start * cl; - - (*drv_)(d, space, 1, len); -#endif -} - -template struct cpu_reducer_t; -template struct cpu_reducer_t; - -/* cpu_reducer_2d_t */ - -template -void cpu_reducer_2d_t::conf_t::init_scratchpad( - memory_tracking::registrar_t &scratchpad) const { - if (balancer_.nthr_per_group_ == 1) return; - - const size_t space_size = balancer_.ngroups_ * balancer_.nthr_per_group_ - * cpu_reducer_2d_t::space_per_thread(balancer_); - scratchpad.book(key_reducer_space, sizeof(data_t) * space_size); - scratchpad.book(key_reducer_space_bctx, - sizeof(simple_barrier::ctx_t) * balancer_.ngroups_); -} - -template -cpu_reducer_2d_t::cpu_reducer_2d_t(const conf_t &conf) - : conf_(conf), drv_(nullptr) -{ - if (balancer().nthr_per_group_ == 1) return; - - drv_ = create_reduce_2d_drv(balancer().nthr_per_group_, - space_per_thread(balancer()), conf_.job_size_x_, conf_.dst_x_, - true); -} - -template -cpu_reducer_2d_t::~cpu_reducer_2d_t() { delete drv_; } - -template -typename cpu_reducer_2d_t::data_t *cpu_reducer_2d_t:: -get_local_ptr(int ithr, const memory_tracking::grantor_t &scratchpad) const { - const int id_in_grp = balancer().id_in_group(ithr); - const int grp_id = balancer().group_id(ithr); - const int offset_factor = grp_id * 
balancer().nthr_per_group_ + id_in_grp; - auto space = scratchpad.template get(key_reducer_space); - return space + offset_factor * space_per_thread(balancer()); -} - -template -int cpu_reducer_2d_t::choose_x_blocking(int nx, int ny, - int nthr_per_grp) const { - // find x_blocking for better balance reducing work between threads - assert(conf_.x_block_ > 0 && nx > conf_.x_block_ - && nx % conf_.x_block_ == 0); - int x_blocking = nx / conf_.x_block_; - int min_x_blocking = - utils::div_up(x_blocking, nstl::max(1, nthr_per_grp / ny)); - while (true) { - if (x_blocking % 2 == 0 && x_blocking >= min_x_blocking * 2) - x_blocking /= 2; - else if (x_blocking % 3 == 0 && x_blocking >= min_x_blocking * 3) - x_blocking /= 3; - else - break; - } - if (x_blocking >= min_x_blocking * 4) x_blocking = 1; - x_blocking *= conf_.x_block_; - return x_blocking; -} - -template -void cpu_reducer_2d_t::reduce_block(const data_t* space_base, - data_t *dst, int job, int start_y, int start_x, - int ny_start, int nx_start, int ny_step, int nx_step) const { - data_t *d = dst + (start_y + ny_start) * conf_.dst_x_ - + start_x + nx_start; - const data_t *space = space_base + job * balancer().job_size_ - + ny_start * conf_.job_size_x_ + nx_start; -#ifdef SIMPLE_IMPL - for (int idg = 0; idg < balancer().nthr_per_group_; ++idg) { - const data_t *w = &space[idg * space_per_thread(balancer())]; - for (int y = 0; y < ny_step; ++y) - for (int x = 0; x < nx_step; ++x) { - d[y * conf_.dst_x_ + x] - = (idg == 0 ? 0 : d[y * conf_.dst_x_ + x]) - + w[y * conf_.job_size_x_ + x]; - } - } -#else - (*drv_)(d, space, ny_step, nx_step); -#endif -} - -template -void cpu_reducer_2d_t::reduce_nolock(int ithr, data_t *dst, - const memory_tracking::grantor_t &scratchpad) const { - bool redundant_reduction = balancer().nthr_per_group_ == 1 - || balancer().idle(ithr); - if (redundant_reduction) return; - - const int id_in_grp = balancer().id_in_group(ithr); - const int njobs_in_grp = balancer().ithr_njobs(ithr); - const int njobs_x = utils::div_up(conf_.dst_x_, conf_.job_size_x_); - const int global_job_start = balancer().ithr_job_off(ithr); - - const data_t *space_base = get_local_ptr(ithr - id_in_grp, scratchpad); - - const int pr_grps = nstl::min(njobs_in_grp, balancer().nthr_per_group_); - const int pr_nthr_per_grp = balancer().nthr_per_group_ / pr_grps; - - if (id_in_grp >= pr_grps * pr_nthr_per_grp) - return; /* idle */ - - const int pr_my_grp = id_in_grp / pr_nthr_per_grp; - const int pr_my_id = id_in_grp % pr_nthr_per_grp; - - int pr_job_start{0}, pr_job_end{0}; - balance211(njobs_in_grp, pr_grps, pr_my_grp, pr_job_start, pr_job_end); - - for (int j = pr_job_start; j < pr_job_end; ++j) { - const int global_job = global_job_start + j; - const int j_y = global_job / njobs_x; - const int j_x = global_job % njobs_x; - const int start_y = j_y * conf_.job_size_y_; - const int start_x = j_x * conf_.job_size_x_; - const int ny = nstl::min(conf_.dst_y_ - start_y, conf_.job_size_y_); - const int nx = nstl::min(conf_.dst_x_ - start_x, conf_.job_size_x_); - int x_blocking = choose_x_blocking(nx, ny, pr_nthr_per_grp); - - int nxy_start{0}, nxy_end{0}; - balance211(ny * nx / x_blocking, pr_nthr_per_grp, pr_my_id, - nxy_start, nxy_end); - if (nxy_start == nxy_end) continue; - nxy_start *= x_blocking; - nxy_end *= x_blocking; - - int nxy = nxy_start; - if (nxy % nx != 0) { - int nx_step = nstl::min(nx - nxy % nx, nxy_end - nxy); - reduce_block(space_base, dst, j, start_y, start_x, - nxy / nx, nxy % nx, 1, nx_step); - nxy += nx_step; - } - if ((nxy_end 
- nxy) > nx) { - int ny_step = (nxy_end - nxy) / nx; - reduce_block(space_base, dst, j, start_y, start_x, - nxy / nx, nxy % nx, ny_step, nx); - nxy += nx * ny_step; - } - if ((nxy_end - nxy) > 0) { - reduce_block(space_base, dst, j, start_y, start_x, - nxy / nx, nxy % nx, 1, nxy_end - nxy); - } - } -} - -template struct cpu_reducer_2d_t; -template struct cpu_reducer_2d_t; - -/* accumulator section */ - -template -cpu_accumulator_1d_t::cpu_accumulator_1d_t(): drv_(nullptr) { - drv_ = create_reduce_2d_drv(1, 0, 0, 0, false); -} - -template -cpu_accumulator_1d_t::~cpu_accumulator_1d_t() { - delete drv_; -} - -template -void cpu_accumulator_1d_t::accumulate(data_t *dst, - const data_t *src, size_t size) { - (*drv_)(dst, src, 1, size); -} - -template struct cpu_accumulator_1d_t; -template struct cpu_accumulator_1d_t; - -} -} -} - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reducer.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reducer.hpp deleted file mode 100644 index 27f5939cd..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reducer.hpp +++ /dev/null @@ -1,334 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef CPU_REDUCER_HPP -#define CPU_REDUCER_HPP - -#include - -#include "c_types_map.hpp" -#include "memory_tracking.hpp" -#include "mkldnn_thread.hpp" -#include "mkldnn_types.h" -#include "nstl.hpp" -#include "type_helpers.hpp" - -#include "cpu_barrier.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -/** class to perform balancing over 3D array - * - * Conceptually the reduction happens according to the picture below: - * - * <--job_size-> - * +-----------+ +-----------+ +-----------+ ^ - * | | | | | | | - * | | | | | | | - * | 1 | | 2 | . . . | njobs | | reduction_size - * | | | | | | | - * | | | | | | | - * +-----------+ +-----------+ +-----------+ v - * - * | | | | | | | | | - * v v v v v v v v v - * ===================================================== vertical reduction - * - * +-----------+ +-----------+ . . . +-----------+ result - * - * In a simple case the result must be contiguous in memory. - * @class cpu_reducer_t is an implementation. - * - * Threads are divided into groups. The groups are independent of each other. - * Each group may work on several jobs (the distribution is not uniform, since - * njobs might be not a multiple of groups). Threads within a group work on - * different parts of the reduction dimension. Thread 0 in each group is called - * master (@sa reduce_balancer_t::master()). - * - * If threading driver does not allow sync between sub-group of threads (e.g. - * Intel(R) TBB) the # of thread per group is enforced to be 1. 
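A minimal standalone sketch of the job distribution described in the balancer comment above, assuming only the arithmetic visible in reduce_balancer_t's grp_njobs()/grp_job_off() members; the free-function names and the sample sizes (3 groups, 7 jobs) are illustrative and not part of mkl-dnn.

#include <cstdio>

// Hypothetical mirrors of reduce_balancer_t::grp_njobs()/grp_job_off():
// the first (njobs % ngroups) groups each own one extra job.
static int grp_njobs(int grp, int njobs, int ngroups) {
    return njobs / ngroups + (grp < njobs % ngroups ? 1 : 0);
}
static int grp_job_off(int grp, int njobs, int ngroups) {
    const int rem = njobs % ngroups;
    return njobs / ngroups * grp + (grp < rem ? grp : rem);
}

int main() {
    const int njobs = 7, ngroups = 3;
    for (int g = 0; g < ngroups; ++g) {
        const int off = grp_job_off(g, njobs, ngroups);
        std::printf("group %d owns jobs [%d, %d)\n",
                g, off, off + grp_njobs(g, njobs, ngroups));
    }
    // prints: [0, 3), [3, 5), [5, 7)
    return 0;
}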
- */ -struct reduce_balancer_t { - reduce_balancer_t() { init(1, 1, 1, 1, 0); } /* trivial balance */ - reduce_balancer_t(int nthr, int job_size, int njobs, int reduction_size, - size_t max_buffer_size) - { init(nthr, job_size, njobs, reduction_size, max_buffer_size); } - - reduce_balancer_t &init(int nthr, int job_size, int njobs, - int reduction_size, size_t max_buffer_size) - { - syncable_ = mkldnn_thr_syncable(); - nthr_ = nthr; - job_size_ = job_size; - njobs_ = njobs; - reduction_size_ = reduction_size; - max_buffer_size_ = max_buffer_size; - balance(); - return *this; - } - - bool syncable_; - int nthr_; - int job_size_, njobs_, reduction_size_; - - int ngroups_; /** number of independent work (thread) groups */ - int nthr_per_group_; /** number of threads within a single work group */ - int njobs_per_group_ub_; /** the max # of jobs within a work group */ - - bool master(int ithr) const { return id_in_group(ithr) == 0; } - bool idle(int ithr) const { return ithr >= nthr_per_group_ * ngroups_; } - - int group_id(int ithr) const { return ithr / nthr_per_group_; } - int id_in_group(int ithr) const { return ithr % nthr_per_group_; } - - int grp_njobs(int grp) const { - if (grp >= ngroups_) return 0; - return njobs_ / ngroups_ + (grp < njobs_ % ngroups_); - } - int grp_job_off(int grp) const { - if (grp >= ngroups_) return njobs_; - return njobs_ / ngroups_ * grp + nstl::min(grp, njobs_ % ngroups_); - } - - int ithr_njobs(int ithr) const { return grp_njobs(group_id(ithr)); } - int ithr_job_off(int ithr) const { return grp_job_off(group_id(ithr)); } - -private: - size_t max_buffer_size_; - void balance(); -}; - -/** forward declaration of reduce driver */ -template struct reducer_2d_driver_t; - -/** class to perform a reduction over 3D array - * - * Balancing is based on @class reduce_balancer_t. - * Restrictions: the result of the reduction must be contiguous in memory. * - * The reduction happens according to the picture below (once more): - * - * <--job_size-> - * +-----------+ +-----------+ +-----------+ ^ - * | | | | | | | - * | | | | | | | - * | 1 | | 2 | . . . | njobs | | reduction_size - * | | | | | | | - * | | | | | | | - * +-----------+ +-----------+ +-----------+ v - * - * | | | | | | | | | - * v v v v v v v v v - * ===================================================== vertical reduction - * - * +-----------+ +-----------+ . . . +-----------+ (contiguous) result - * - * An example how work might be shared is shown below. - * - * In this example group 0 owns 2 (independent) jobs -- 2 big squares. - * The number of threads per group is also 2 (thread 0 of group 0 and thread 1 - * of group 0). Master threads (i.e. threads with id 0 in corresponding group) - * from each group put the partial result directly into destination memory, - * while all the other threads with-in the group use workspace (on the picture - * the only thread 1). Once intermediate results obtained each group reduces - * corresponding part (own jobs) to the destination memory. 
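The work-sharing picture above can be made concrete with a small sketch of where each thread's partial result lands, following the offset rule of cpu_reducer_t::get_local_ptr() earlier in cpu_reducer.cpp: group masters write straight into the destination, while the remaining nthr_per_group - 1 threads of each group get a private slice of the reducer scratchpad. The concrete sizes below are arbitrary examples, not values the library would pick.

#include <cstdio>

int main() {
    // Illustrative balancer values; in the library these come from reduce_balancer_t.
    const int ngroups = 2, nthr_per_group = 3;
    const long space_per_thread = 1024;   // njobs_per_group_ub * job_size

    for (int ithr = 0; ithr < ngroups * nthr_per_group; ++ithr) {
        const int grp_id = ithr / nthr_per_group;     // balancer.group_id(ithr)
        const int id_in_grp = ithr % nthr_per_group;  // balancer.id_in_group(ithr)
        if (id_in_grp == 0) {
            // thread 0 of each group writes its partial result directly into dst
            std::printf("thr %d -> dst (master of group %d)\n", ithr, grp_id);
        } else {
            // all other threads write into the scratchpad, one slice per non-master
            const long off = (grp_id * (nthr_per_group - 1) + (id_in_grp - 1))
                    * space_per_thread;
            std::printf("thr %d -> scratchpad + %ld\n", ithr, off);
        }
    }
    return 0;
}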
- * - * <------- group 0 -------> - * - * +-----------+ +-----------+ ^ - * | | | | | thread 0 of reduces to the dest-memory - * | | | | | group 0 +-----------+ +-----------+ - * |- - - - - -| |- - - - - -| X - * | | | | | thread 1 of reduces to workspace[tid=1]: - * | | | | | group 0 +-----------+ +-----------+ - * +-----------+ +-----------+ v - * | | | | | | - * v v v v v v - * ((barrier)) ============================= - * - * dest-memory: +-----------+ +-----------+ - */ -template -struct cpu_reducer_t { - typedef typename prec_traits::type data_t; - - struct conf_t { - conf_t() = default; - conf_t &init(const reduce_balancer_t &balancer) - { balancer_ = balancer; return *this; } - - void init_scratchpad(memory_tracking::registrar_t &scratchpad) const; - - reduce_balancer_t balancer_; - }; - - cpu_reducer_t(const conf_t &conf); - ~cpu_reducer_t(); - - /** initializes reducer. - * Must be called from a single thread prior to actual usage */ - void init(const memory_tracking::grantor_t &scratchpad) const { - if (balancer().nthr_per_group_ == 1) return; - - auto bctx = scratchpad.template get( - memory_tracking::names::key_reducer_space_bctx); - for (int i = 0; i < balancer().ngroups_; ++i) - simple_barrier::ctx_init(&bctx[i]); - } - - /** for given thread returns the pointer where to put partial results. - * Reduction destination @p dst must be provided as well (master threads - * from each group will use it for partial result to reduce memory - * pressure). - * - * @note: job offset is already applied by get_local_ptr(), which means all - * threads should start writing from the very beginning of returned - * address. - */ - data_t *get_local_ptr(int ithr, data_t *dst, - const memory_tracking::grantor_t &scratchpad) const; - - /** performs the reduction with built-in synchronization. */ - void reduce(int ithr, data_t *dst, - const memory_tracking::grantor_t &scratchpad) const { - bool redundant_reduction = balancer().nthr_per_group_ == 1 - || balancer().idle(ithr); - if (redundant_reduction) return; - - auto bctx = scratchpad.template get( - memory_tracking::names::key_reducer_space_bctx); - simple_barrier::barrier(&bctx[balancer().group_id(ithr)], - balancer().nthr_per_group_); - - reduce_nolock(ithr, dst, scratchpad); - } - - const reduce_balancer_t &balancer() const { return conf_.balancer_; } - -private: - static size_t space_per_thread(const reduce_balancer_t &balancer) - { return balancer.njobs_per_group_ub_ * balancer.job_size_; } - - /* The scratchpad is organized as follows: - * - * data_t space[nthr_][njobs_per_group_ub_][jobs_size_]; - * simple_barrier::ctx_t barriers[groups_]; */ - - const conf_t conf_; - reducer_2d_driver_t *drv_; - - void reduce_nolock(int ithr, data_t *dst, - const memory_tracking::grantor_t &scratchpad) const; -}; - -template -struct cpu_reducer_2d_t { - typedef typename prec_traits::type data_t; - - struct conf_t { - conf_t() = default; - conf_t &init(const reduce_balancer_t &balancer, int job_size_x, - int job_size_y, int x_block, int dst_x, int dst_y) { - balancer_ = balancer; - job_size_x_ = job_size_x; - job_size_y_ = job_size_y; - x_block_ = x_block; - dst_x_ = dst_x; - dst_y_ = dst_y; - return *this; - } - - void init_scratchpad(memory_tracking::registrar_t &scratchpad) const; - - reduce_balancer_t balancer_; - int job_size_x_, job_size_y_, x_block_, dst_x_, dst_y_; - }; - - cpu_reducer_2d_t(const conf_t &conf); - ~cpu_reducer_2d_t(); - - /** initializes reducer. 
- * Must be called from a single thread prior to actual usage */ - void init(const memory_tracking::grantor_t &scratchpad) const { - if (balancer().nthr_per_group_ == 1) return; - - auto bctx = scratchpad.template get( - memory_tracking::names::key_reducer_space_bctx); - for (int i = 0; i < balancer().ngroups_; ++i) - simple_barrier::ctx_init(&bctx[i]); - } - - /** for given thread returns the pointer where to put partial results */ - data_t *get_local_ptr(int ithr, - const memory_tracking::grantor_t &scratchpad) const; - - /** performs the reduction with built-in synchronization. */ - void reduce(int ithr, data_t *dst, - const memory_tracking::grantor_t &scratchpad) const { - bool redundant_reduction = balancer().nthr_per_group_ == 1 - || balancer().idle(ithr); - if (redundant_reduction) return; - - auto bctx = scratchpad.template get( - memory_tracking::names::key_reducer_space_bctx); - simple_barrier::barrier(&bctx[balancer().group_id(ithr)], - balancer().nthr_per_group_); - - reduce_nolock(ithr, dst, scratchpad); - } - - const reduce_balancer_t &balancer() const { return conf_.balancer_; } - -private: - static size_t space_per_thread(const reduce_balancer_t &balancer) - { return balancer.njobs_per_group_ub_ * balancer.job_size_; } - - /* The scratchpad is organized as follows: - * - * data_t space[nthr_][njobs_per_group_ub_][jobs_size_]; - * simple_barrier::ctx_t barriers[groups_]; */ - - const conf_t conf_; - reducer_2d_driver_t *drv_; - - int choose_x_blocking(int nx, int ny, int nthr_per_grp) const; - void reduce_block(const data_t* space_base, data_t *dst, - int job, int start_y, int start_x, - int ny_start, int nx_start, int ny_step, int nx_step) const; - void reduce_nolock(int ithr, data_t *dst, - const memory_tracking::grantor_t &scratchpad) const; -}; - -/** simple 1d accumulator: y[:] += x[:] */ -template -struct cpu_accumulator_1d_t { - typedef typename prec_traits::type data_t; - - cpu_accumulator_1d_t(); - ~cpu_accumulator_1d_t(); - void accumulate(data_t *dst, const data_t *src, size_t size); - - reducer_2d_driver_t *drv_; -}; - -} -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reorder.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reorder.cpp deleted file mode 100644 index 82be70353..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reorder.cpp +++ /dev/null @@ -1,262 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
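For reference, a scalar equivalent of what the accumulator declared above computes: cpu_accumulator_1d_t's accumulate() is a plain y[:] += x[:], dispatched in the library to the jitted reducer_2d_driver with a single source and a single row so it runs with AVX2/AVX-512 vector loads. The function below is only an illustrative restatement, not library code.

#include <cstddef>

// Scalar equivalent of accumulate(dst, src, size): dst[:] += src[:].
static void accumulate_ref(float *dst, const float *src, std::size_t size) {
    for (std::size_t i = 0; i < size; ++i)
        dst[i] += src[i];
}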
-*******************************************************************************/ - -#include - -#include "cpu_engine.hpp" -#include "cpu_primitive.hpp" -#include "cpu_reorder_pd.hpp" -#include "cpu_memory.hpp" -#include "type_helpers.hpp" - -#include "cpu/jit_uni_reorder.hpp" -#include "cpu/simple_reorder.hpp" -#include "cpu/wino_reorder.hpp" -#include "cpu/rnn/rnn_reorders.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -using rpd_create_f = mkldnn::impl::engine_t::reorder_primitive_desc_create_f; - -namespace { -using namespace mkldnn::impl::data_type; -using namespace mkldnn::impl::format_tag; - -#define REG_SR(idt, ifmt, odt, ofmt, ...) \ - simple_reorder_t::pd_t::create - -#define REG_SR_BIDIR(idt, ifmt, odt, ofmt) \ - REG_SR(idt, ifmt, odt, ofmt, fmt_order::keep), \ - REG_SR(idt, ifmt, odt, ofmt, fmt_order::reverse) - -#define REG_SR_DIRECT_COPY(idt, odt) \ - REG_SR(idt, any, odt, any, fmt_order::any, spec::direct_copy), \ - REG_SR(idt, any, odt, any, fmt_order::any, spec::direct_copy_except_dim_0) - -static const rpd_create_f cpu_reorder_impl_list[] = { - /* winograd */ - wino_reorder_t::pd_t::create, - //wino_reorder_t::pd_t::create, - - /* rnn reorders */ - rnn_data_reorder_t::pd_t::create, - rnn_weights_reorder_t::pd_t::create, - rnn_weights_reorder_t::pd_t::create, - - /* conv reorders w/ compensation */ - REG_SR(f32, any, s8, hwio, fmt_order::keep, spec::conv_s8s8), - REG_SR(f32, any, s8, hwigo, fmt_order::keep, spec::conv_s8s8), - REG_SR(s8, any, s8, hwio, fmt_order::keep, spec::conv_s8s8), - REG_SR(s8, any, s8, hwigo, fmt_order::keep, spec::conv_s8s8), - - REG_SR(f32, oiw, s8, OIw4i16o4i, fmt_order::keep, spec::conv_s8s8), - REG_SR(f32, goiw, s8, gOIw4i16o4i, fmt_order::keep, spec::conv_s8s8), - REG_SR(s8, oiw, s8, OIw4i16o4i, fmt_order::keep, spec::conv_s8s8), - REG_SR(s8, goiw, s8, gOIw4i16o4i, fmt_order::keep, spec::conv_s8s8), - - REG_SR(f32, oihw, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_s8s8), - REG_SR(f32, goihw, s8, gOIhw4i16o4i, fmt_order::keep, spec::conv_s8s8), - REG_SR(s8, oihw, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_s8s8), - REG_SR(s8, goihw, s8, gOIhw4i16o4i, fmt_order::keep, spec::conv_s8s8), - - REG_SR(f32, goihw, s8, gOIhw2i8o4i, fmt_order::keep, spec::conv_s8s8), - REG_SR(s8, goihw, s8, gOIhw2i8o4i, fmt_order::keep, spec::conv_s8s8), - - REG_SR(f32, goihw, s8, gOIhw4o4i, fmt_order::keep, spec::conv_s8s8), - REG_SR(s8, goihw, s8, gOIhw4o4i, fmt_order::keep, spec::conv_s8s8), - - REG_SR(f32, goiw, s8, Goiw16g, fmt_order::keep, spec::conv_s8s8), - REG_SR(s8, goiw, s8, Goiw16g, fmt_order::keep, spec::conv_s8s8), - REG_SR(f32, goihw, s8, Goihw16g, fmt_order::keep, spec::conv_s8s8), - REG_SR(s8, goihw, s8, Goihw16g, fmt_order::keep, spec::conv_s8s8), - - /* regular reorders */ - -#if defined(__INTEL_COMPILER) || (defined(__GNUC__) && !defined(__clang__)) - /* Direct copy for icc which is faster than jitted code; - * Direct copy for gcc which might or might not be faster than jitted - * code, but still worth it because doesn't require jitting, i.e. much - * faster creation time. This is tentative solution and should be removed - * later (when we will cache jitted code?...). 
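A hedged sketch of the registration pattern used by the REG_SR* macros above. The real entries name simple_reorder_t specializations whose template argument lists were lost in this listing; the stand-in below only shows the shape of the trick: each macro expansion yields a pd_t::create function pointer, and the variadic __VA_ARGS__ lets a template-id containing commas pass through the macro as a single argument.

// Stand-in template; the real entries are simple_reorder_t<...>::pd_t::create.
template <int in_t, int out_t>
struct fake_reorder_t {
    struct pd_t {
        static int create() { return in_t * 100 + out_t; }  // real create() builds a pd
    };
};

#define REG_FAKE(...) __VA_ARGS__::pd_t::create

static int (*const fake_impl_list[])() = {
    REG_FAKE(fake_reorder_t<1, 2>),   // the comma survives thanks to __VA_ARGS__
    REG_FAKE(fake_reorder_t<3, 4>),
    nullptr,                          // end-of-list sentinel, as in the table above
};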
*/ - REG_SR_DIRECT_COPY(f32, f32), -#endif - -#ifdef __INTEL_COMPILER - /* direct copy for icc, which is faster than jitted code */ - /* - REG_SR_DIRECT_COPY(f32, s32), - REG_SR_DIRECT_COPY(f32, s8), - REG_SR_DIRECT_COPY(f32, u8), - REG_SR_DIRECT_COPY(s32, f32), - REG_SR_DIRECT_COPY(s32, s32), - REG_SR_DIRECT_COPY(s32, s8), - REG_SR_DIRECT_COPY(s32, u8), - REG_SR_DIRECT_COPY(s8, f32), - REG_SR_DIRECT_COPY(s8, s32), - REG_SR_DIRECT_COPY(s8, s8), - REG_SR_DIRECT_COPY(s8, u8), - REG_SR_DIRECT_COPY(u8, f32), - REG_SR_DIRECT_COPY(u8, s32), - REG_SR_DIRECT_COPY(u8, s8), - REG_SR_DIRECT_COPY(u8, u8), - */ -#endif - - /* jit */ - jit_uni_reorder_create, - - /* fp32: flat <-> blocked with tail */ - /* - REG_SR_BIDIR(f32, any, f32, nCw4c), - REG_SR_BIDIR(f32, any, f32, nCw8c), - REG_SR_BIDIR(f32, any, f32, OIw4i4o), - REG_SR_BIDIR(f32, any, f32, OIw8i8o), - REG_SR_BIDIR(f32, any, f32, OIw8o8i), - REG_SR_BIDIR(f32, any, f32, gOIw4i4o), - REG_SR_BIDIR(f32, any, f32, gOIw8i8o), - REG_SR_BIDIR(f32, any, f32, gOIw8o8i), - - REG_SR_BIDIR(f32, any, f32, nCw16c), - REG_SR_BIDIR(f32, any, f32, OIw16o16i), - REG_SR_BIDIR(f32, any, f32, OIw16i16o), - REG_SR_BIDIR(f32, any, f32, IOw16o16i), - REG_SR_BIDIR(f32, any, f32, gOIw16o16i), - REG_SR_BIDIR(f32, any, f32, gOIw16i16o), - REG_SR_BIDIR(f32, any, f32, gIOw16o16i), - - REG_SR_BIDIR(f32, any, f32, nChw4c), - REG_SR_BIDIR(f32, any, f32, nChw8c), - REG_SR_BIDIR(f32, any, f32, OIhw4i4o), - REG_SR_BIDIR(f32, any, f32, Ohwi8o), - - REG_SR_BIDIR(f32, any, f32, OIhw8i8o), - REG_SR_BIDIR(f32, any, f32, OIhw8o8i), - REG_SR_BIDIR(f32, any, f32, gOIhw4i4o), - REG_SR_BIDIR(f32, any, f32, gOIhw4o4i), - REG_SR_BIDIR(f32, any, f32, gOhwi8o), - REG_SR_BIDIR(f32, any, f32, gOIhw8i8o), - REG_SR_BIDIR(f32, any, f32, gOIhw8o8i), - - REG_SR_BIDIR(f32, any, f32, nChw16c), - REG_SR_BIDIR(f32, any, f32, Oihw4o), - REG_SR_BIDIR(f32, any, f32, Oihw16o), - REG_SR_BIDIR(f32, any, f32, Ohwi4o), - REG_SR_BIDIR(f32, any, f32, Ohwi16o), - REG_SR_BIDIR(f32, any, f32, OIhw16o16i), - REG_SR_BIDIR(f32, any, f32, OIhw16i16o), - REG_SR_BIDIR(f32, any, f32, IOhw16o16i), - REG_SR_BIDIR(f32, any, f32, gOihw4o), - REG_SR_BIDIR(f32, any, f32, gOihw16o), - REG_SR_BIDIR(f32, any, f32, gOhwi4o), - REG_SR_BIDIR(f32, any, f32, gOhwi16o), - REG_SR_BIDIR(f32, any, f32, gOIhw16o16i), - REG_SR_BIDIR(f32, any, f32, gOIhw16i16o), - REG_SR_BIDIR(f32, any, f32, gIOhw16o16i), - - REG_SR_BIDIR(f32, any, f32, nCdhw4c), - REG_SR_BIDIR(f32, any, f32, nCdhw8c), - REG_SR_BIDIR(f32, any, f32, OIdhw4i4o), - REG_SR_BIDIR(f32, any, f32, Odhwi8o), - REG_SR_BIDIR(f32, any, f32, OIdhw8i8o), - REG_SR_BIDIR(f32, any, f32, OIdhw8o8i), - REG_SR_BIDIR(f32, any, f32, gOIdhw4i4o), - REG_SR_BIDIR(f32, any, f32, gOdhwi8o), - REG_SR_BIDIR(f32, any, f32, gOIdhw8i8o), - REG_SR_BIDIR(f32, any, f32, gOIdhw8o8i), - - REG_SR_BIDIR(f32, any, f32, nCdhw16c), - REG_SR_BIDIR(f32, any, f32, Oidhw4o), - REG_SR_BIDIR(f32, any, f32, Oidhw16o), - REG_SR_BIDIR(f32, any, f32, Odhwi16o), - REG_SR_BIDIR(f32, any, f32, OIdhw16o16i), - REG_SR_BIDIR(f32, any, f32, OIdhw16i16o), - REG_SR_BIDIR(f32, any, f32, gOidhw4o), - REG_SR_BIDIR(f32, any, f32, gOidhw16o), - REG_SR_BIDIR(f32, any, f32, gOdhwi16o), - REG_SR_BIDIR(f32, any, f32, gOIdhw16o16i), - REG_SR_BIDIR(f32, any, f32, gOIdhw16i16o), - */ - - /* fp32: blocked <-> blocked with tail */ - REG_SR_BIDIR(f32, nCw8c, f32, nCw16c), - REG_SR_BIDIR(f32, nChw8c, f32, nChw16c), - REG_SR_BIDIR(f32, nCdhw8c, f32, nCdhw16c), - - /* int: flat <-> blocked with tail */ - /* - REG_SR_BIDIR(f32, any, s32, nChw16c), - 
REG_SR_BIDIR(f32, any, s8, nChw16c), - REG_SR_BIDIR(f32, any, u8, nChw16c), - REG_SR_BIDIR(s32, any, f32, nChw16c), - REG_SR_BIDIR(s32, any, s32, nChw16c), - REG_SR_BIDIR(s32, any, s8, nChw16c), - REG_SR_BIDIR(s32, any, u8, nChw16c), - REG_SR_BIDIR(s8, any, f32, nChw16c), - REG_SR_BIDIR(s8, any, s32, nChw16c), - REG_SR_BIDIR(s8, any, s8, nChw16c), - REG_SR_BIDIR(s8, any, u8, nChw16c), - REG_SR_BIDIR(u8, any, f32, nChw16c), - REG_SR_BIDIR(u8, any, s32, nChw16c), - REG_SR_BIDIR(u8, any, s8, nChw16c), - REG_SR_BIDIR(u8, any, u8, nChw16c), - - REG_SR_BIDIR(f32, any, f32, OIhw4i16o4i), - REG_SR_BIDIR(f32, any, s8, OIhw4i16o4i), - REG_SR_BIDIR(s8, any, f32, OIhw4i16o4i), - REG_SR_BIDIR(s8, any, s8, OIhw4i16o4i), - REG_SR_BIDIR(f32, any, s8, gOIhw4i16o4i), - REG_SR_BIDIR(s8, any, f32, gOIhw4i16o4i), - REG_SR_BIDIR(f32, any, f32, gOIhw4i16o4i), - REG_SR_BIDIR(s8, any, s8, gOIhw4i16o4i), - */ - - /* reference: the last line of defence */ - /* - REG_SR(f32, any, f32, any, fmt_order::any, spec::reference), - REG_SR(f32, any, s32, any, fmt_order::any, spec::reference), - REG_SR(f32, any, s8, any, fmt_order::any, spec::reference), - REG_SR(f32, any, u8, any, fmt_order::any, spec::reference), - - REG_SR(s32, any, f32, any, fmt_order::any, spec::reference), - REG_SR(s32, any, s32, any, fmt_order::any, spec::reference), - REG_SR(s32, any, s8, any, fmt_order::any, spec::reference), - REG_SR(s32, any, u8, any, fmt_order::any, spec::reference), - - REG_SR(s8, any, f32, any, fmt_order::any, spec::reference), - REG_SR(s8, any, s32, any, fmt_order::any, spec::reference), - REG_SR(s8, any, s8, any, fmt_order::any, spec::reference), - REG_SR(s8, any, u8, any, fmt_order::any, spec::reference), - - REG_SR(u8, any, f32, any, fmt_order::any, spec::reference), - REG_SR(u8, any, s32, any, fmt_order::any, spec::reference), - REG_SR(u8, any, u8, any, fmt_order::any, spec::reference), - REG_SR(u8, any, s8, any, fmt_order::any, spec::reference), - */ - - /* eol */ - nullptr, -}; -} - -const rpd_create_f *cpu_engine_t::get_reorder_implementation_list() const { - return cpu_reorder_impl_list; -} - -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reorder_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reorder_pd.hpp deleted file mode 100644 index 1622eb684..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reorder_pd.hpp +++ /dev/null @@ -1,48 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
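The implementation table above ends with a nullptr sentinel and is handed back by get_reorder_implementation_list(); callers then walk the array and take the first create function that accepts the requested reorder. A simplified sketch of that first-match loop, with a stripped-down signature (the real rpd_create_f also receives the engines, attributes and memory descriptors):

#include <cstdio>

typedef int (*create_f)(void **pd);   // stand-in for rpd_create_f; 0 == success

static int unsupported(void **) { return -1; }            // rejects, like a failed pd_t::create
static int jit_like(void **pd) { *pd = (void *)"jit"; return 0; }

static const create_f impl_list[] = { unsupported, jit_like, nullptr /* eol */ };

int main() {
    void *pd = nullptr;
    for (const create_f *f = impl_list; *f; ++f)
        if ((*f)(&pd) == 0) break;    // first implementation that succeeds wins
    std::printf("picked: %s\n", pd ? (const char *)pd : "none");
    return 0;
}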
-*******************************************************************************/ - -#ifndef CPU_REORDER_PD_HPP -#define CPU_REORDER_PD_HPP - -#include - -#include "c_types_map.hpp" -#include "reorder_pd.hpp" -#include "utils.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -struct cpu_reorder_pd_t: public reorder_pd_t { - using reorder_pd_t::reorder_pd_t; - - status_t init() { - const auto &post_ops = attr()->post_ops_; - bool args_ok = IMPLICATION(post_ops.len_ != 0, post_ops.len_ == 1 - && post_ops.entry_[0].kind == primitive_kind::sum); - scratchpad_engine_ = src_engine_; - return args_ok ? status::success : status::unimplemented; - } -}; - -} -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_shuffle_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_shuffle_pd.hpp deleted file mode 100644 index f16587b99..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_shuffle_pd.hpp +++ /dev/null @@ -1,41 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef CPU_SHUFFLE_PD_HPP -#define CPU_SHUFFLE_PD_HPP - -#include - -#include "c_types_map.hpp" -#include "shuffle_pd.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -struct cpu_shuffle_pd_t: public shuffle_pd_t { - using shuffle_pd_t::shuffle_pd_t; -}; - -} -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_softmax_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_softmax_pd.hpp deleted file mode 100644 index 3a39eab97..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_softmax_pd.hpp +++ /dev/null @@ -1,45 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
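cpu_reorder_pd_t::init() above relies on IMPLICATION(a, b), which mkl-dnn's utils define as logical implication, i.e. !(a) || (b). Written out, the post-ops check accepts either an empty post-ops list or exactly one entry of kind sum. A small restatement (hypothetical helper, not library code):

// Restatement of the args_ok condition in cpu_reorder_pd_t::init():
// no post-ops at all, or exactly one post-op and it is a sum.
static bool reorder_post_ops_ok(int len, bool first_entry_is_sum) {
    return len == 0 || (len == 1 && first_entry_is_sum);
}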
-*******************************************************************************/ - -#ifndef CPU_SOFTMAX_PD_HPP -#define CPU_SOFTMAX_PD_HPP - -#include - -#include "c_types_map.hpp" -#include "softmax_pd.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -struct cpu_softmax_fwd_pd_t: public softmax_fwd_pd_t { - using softmax_fwd_pd_t::softmax_fwd_pd_t; -}; - -struct cpu_softmax_bwd_pd_t: public softmax_bwd_pd_t { - using softmax_bwd_pd_t::softmax_bwd_pd_t; -}; - -} -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_sum.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_sum.cpp deleted file mode 100644 index 1ab5d9f17..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_sum.cpp +++ /dev/null @@ -1,48 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include "cpu_engine.hpp" - -/* -#include "cpu/ref_sum.hpp" -#include "cpu/simple_sum.hpp" -*/ - -namespace mkldnn { -namespace impl { -namespace cpu { - -using spd_create_f = mkldnn::impl::engine_t::sum_primitive_desc_create_f; - -namespace { -#define INSTANCE(...) __VA_ARGS__::pd_t::create -static const spd_create_f cpu_sum_impl_list[] = { - /* - INSTANCE(simple_sum_t), - INSTANCE(ref_sum_t), - */ - nullptr, -}; -#undef INSTANCE -} - -const spd_create_f *cpu_engine_t::get_sum_implementation_list() const { - return cpu_sum_impl_list; -} - -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_sum_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_sum_pd.hpp deleted file mode 100644 index 0965129f9..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_sum_pd.hpp +++ /dev/null @@ -1,39 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#ifndef CPU_SUM_PD_HPP -#define CPU_SUM_PD_HPP - -#include "c_types_map.hpp" -#include "sum_pd.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -struct cpu_sum_pd_t: public sum_pd_t { - using sum_pd_t::sum_pd_t; -}; - -} -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.cpp deleted file mode 100644 index a9810dec2..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.cpp +++ /dev/null @@ -1,372 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ -#include - -#include "mkldnn_thread.hpp" -#include "utils.hpp" -#include "gemm_utils_f32.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { -namespace gemm_utils { -#define BM_NOCOPY_AVX 64 -#define BN_NOCOPY_AVX 48 -#define BK_NOCOPY_AVX 384 -#define BN_LARGE_NOCOPY_AVX 192 -#define BM_SMALL_NOCOPY_AVX 16 -#define BN_SMALL_NOCOPY_AVX 1 -#define BK_SMALL_NOCOPY_AVX 4 -// Determine number of threads for each dimension of a 3-D partitioning -// algorithm based on input parameters -// m/n/k - First/second/third parameter for GEMM -// nthrs - total available number of threads -// nthrs_m/nthrs_n/nthrs_k - number of threads to use in each dimension -// BM/BN/BK - blocking values -void calc_nthr_nocopy_avx(int m, int n, int k, - int nthrs, int *nthrs_m, int *nthrs_n, int *nthrs_k, int *BM, int *BN, - int *BK) -{ - int nthr, nthr_m, nthr_n, nthr_k; - int MB, NB, KB; - - nthr = nthrs; - nthr_m = (m + BM_NOCOPY_AVX - 1) / BM_NOCOPY_AVX; - nthr_n = (n + BN_NOCOPY_AVX - 1) / BN_NOCOPY_AVX; - nthr_k = 1; - - // Partition along K dimension - // - if threading allows having barriers (e.g. 
OMP) - // - if there is not enough parallelism along M or N - if (mkldnn_thr_syncable()) { - int nthr_other = nthr_k = 1; - while ((nthr_m * nthr_n * nthr_other < nthr) - && (k / (nthr_other + 1) > BK_NOCOPY_AVX)) { - nthr_other++; - if ((nthr / nthr_other) * nthr_other > 0.9 * nthr) - nthr_k = nthr_other; - } - } - nthr /= nthr_k; - - if (nthr_m == 1) - nthr_n = nthr; - if (nthr_n == 1) - nthr_m = nthr; - - // Simple partition reduction - while (nthr_m * nthr_n > nthr) - if (nthr_m > nthr_n) - nthr_m--; - else - nthr_n--; - while (nthr_m * nthr_n < nthr) - if (nthr_m < nthr_n) - nthr_m++; - else - nthr_n++; - - if ((nthr_m * nthr_n > nthr) && (nthr_m > 1) && (nthr_n > 1)) { - - if (nthr_m <= nthr_n) { - nthr_m = (int)sqrt((double)nthr); - if (nthr_m > (m + BM_SMALL_NOCOPY_AVX - 1) / BM_SMALL_NOCOPY_AVX) - nthr_m = (m + BM_SMALL_NOCOPY_AVX - 1) / BM_SMALL_NOCOPY_AVX; - nthr_n = nthr / nthr_m; - - while ((nthr_m > 1) && (nthr_m * nthr_n != nthr)) { - nthr_m--; - nthr_n = nthr / nthr_m; - } - } else { - nthr_n = (int)sqrt((double)nthr); - if (nthr_n > (n + BN_SMALL_NOCOPY_AVX - 1) / BN_SMALL_NOCOPY_AVX) - nthr_n = (n + BN_SMALL_NOCOPY_AVX - 1) / BN_SMALL_NOCOPY_AVX; - nthr_m = nthr / nthr_n; - - while ((nthr_n > 1) && (nthr_m * nthr_n != nthr)) { - nthr_n--; - nthr_m = nthr / nthr_n; - } - } - } - - MB = (m + nthr_m - 1) / nthr_m + BM_SMALL_NOCOPY_AVX - 1; - MB -= MB % BM_SMALL_NOCOPY_AVX; - NB = (n + nthr_n - 1) / nthr_n + BN_SMALL_NOCOPY_AVX - 1; - NB -= NB % BN_SMALL_NOCOPY_AVX; - KB = (k + nthr_k - 1) / nthr_k + BK_SMALL_NOCOPY_AVX - 1; - KB -= KB % BK_SMALL_NOCOPY_AVX; - - if (MB * nthr_m > m) - nthr_m = (m + MB - 1) / MB; - if (NB * nthr_n > n) - nthr_n = (n + NB - 1) / NB; - if (KB * nthr_k > k) - nthr_k = (k + KB - 1) / KB; - - *nthrs_m = nthr_m; - *nthrs_n = nthr_n; - *nthrs_k = nthr_k; - - *BM = MB; - *BN = NB; - *BK = KB; -} -#undef BM_NOCOPY_AVX -#undef BN_NOCOPY_AVX -#undef BK_NOCOPY_AVX -#undef BN_LARGE_NOCOPY_AVX -#undef BM_SMALL_NOCOPY_AVX -#undef BN_SMALL_NOCOPY_AVX -#undef BK_SMALL_NOCOPY_AVX - -#define BM_NOCOPY_AVX512_COMMON 32 -#define BN_NOCOPY_AVX512_COMMON 64 -#define BK_NOCOPY_AVX512_COMMON 192 -#define BN_LARGE_NOCOPY_AVX512_COMMON 192 -#define BM_SMALL_NOCOPY_AVX512_COMMON 16 -#define BN_SMALL_NOCOPY_AVX512_COMMON 1 -#define BK_SMALL_NOCOPY_AVX512_COMMON 4 -// Determine number of threads for each dimension of a 3-D partitioning -// algorithm based on input parameters -// m/n/k - First/second/third parameter for GEMM -// nthrs - total available number of threads -// nthrs_m/nthrs_n/nthrs_k - number of threads to use in each dimension -// BM/BN/BK - blocking values -void calc_nthr_nocopy_avx512_common(int m, - int n, int k, int nthrs, int *nthrs_m, int *nthrs_n, int *nthrs_k, - int *BM, int *BN, int *BK) -{ - int nthr, nthr_m, nthr_n, nthr_k = 1; - int MB, NB, KB; - nthr = nthrs; - - int counter = 0; - float ratio_float = 1.; - int ratio = 1; - nthr = nthrs; - int nthr_m_gt_n; - - // Partition along K dimension - // - if threading allows having barriers (e.g. 
OMP) - // - if there is not enough parallelism along M or N - if (mkldnn_thr_syncable()) { - if (n <= 2 * BN_NOCOPY_AVX512_COMMON && - m <= 2 * BM_NOCOPY_AVX512_COMMON * nthr) { - nthr_k = k / BK_NOCOPY_AVX512_COMMON; - if (nthr_k > nthr / 4) - nthr_k = nthr / 4; - if (nthr_k < 1) - nthr_k = 1; - - while ((nthr_k > 1) && (nthr % nthr_k)) { - nthr_k--; - } - nthr /= nthr_k; - } else { - nthr_k = 1; - } - } - nthr_m = (m + BM_NOCOPY_AVX512_COMMON - 1) / BM_NOCOPY_AVX512_COMMON; - nthr_n = (n + BN_NOCOPY_AVX512_COMMON - 1) / BN_NOCOPY_AVX512_COMMON; - - if (nthr_m < 1) - nthr_m = 1; - if (nthr_n < 1) - nthr_n = 1; - - nthr_m_gt_n = nthr_m > nthr_n ? 1 : 0; - ratio_float = (float)nthr_m / nthr_n; - - if (nthr_m_gt_n) - ratio = (int)ratio_float; - else - ratio = (int)(1. / ratio_float); - - // scale down nthr_m and nthr_n if they are too large - while (nthr_m * nthr_n > 4 * nthr) { - nthr_m /= 2; - nthr_n /= 2; - } - - if (nthr_m < 1) - nthr_m = 1; - if (nthr_n < 1) - nthr_n = 1; - - // Simple partition reduction - counter = 0; - while (nthr_m * nthr_n > nthr) { - if (nthr_m > nthr_n) { - if (counter < ratio) - nthr_m--; - else { - nthr_n--; - counter = -1; - } - } else { - if (counter < ratio) - nthr_n--; - else { - nthr_m--; - counter = -1; - } - } - counter++; - } - - // Simple partition increment - counter = 0; - while (nthr_m * nthr_n < 0.95 * nthr) { - if (nthr_m > nthr_n) { - if (counter < ratio) - nthr_m++; - else { - nthr_n++; - counter = -1; - } - } else { - if (counter < ratio) - nthr_n++; - else { - nthr_m++; - counter = -1; - } - } - counter++; - } - - // if nothing works out, then this should work - if ((nthr_m * nthr_n > nthr)) { - - if (nthr_m <= nthr_n) { - nthr_m = (int)sqrt((double)nthr); - if (nthr_m > (m + BM_SMALL_NOCOPY_AVX512_COMMON - 1) - / BM_SMALL_NOCOPY_AVX512_COMMON) - nthr_m = (m + BM_SMALL_NOCOPY_AVX512_COMMON - 1) - / BM_SMALL_NOCOPY_AVX512_COMMON; - nthr_n = nthr / nthr_m; - - while ((nthr_m > 1) && (nthr_m * nthr_n != nthr)) { - nthr_m--; - nthr_n = nthr / nthr_m; - } - } else { - nthr_n = (int)sqrt((double)nthr); - if (nthr_n > (n + BN_SMALL_NOCOPY_AVX512_COMMON - 1) - / BN_SMALL_NOCOPY_AVX512_COMMON) - nthr_n = (n + BN_SMALL_NOCOPY_AVX512_COMMON - 1) - / BN_SMALL_NOCOPY_AVX512_COMMON; - nthr_m = nthr / nthr_n; - - while ((nthr_n > 1) && (nthr_m * nthr_n != nthr)) { - nthr_n--; - nthr_m = nthr / nthr_n; - } - } - } - - MB = (m + nthr_m - 1) / nthr_m + BM_SMALL_NOCOPY_AVX512_COMMON - 1; - MB -= MB % BM_SMALL_NOCOPY_AVX512_COMMON; - NB = (n + nthr_n - 1) / nthr_n + BN_SMALL_NOCOPY_AVX512_COMMON - 1; - NB -= NB % BN_SMALL_NOCOPY_AVX512_COMMON; - KB = (k + nthr_k - 1) / nthr_k + BK_SMALL_NOCOPY_AVX512_COMMON - 1; - KB -= KB % BK_SMALL_NOCOPY_AVX512_COMMON; - - if (MB * nthr_m > m) - nthr_m = (m + MB - 1) / MB; - if (NB * nthr_n > n) - nthr_n = (n + NB - 1) / NB; - if (KB * nthr_k > k) - nthr_k = (k + KB - 1) / KB; - - *nthrs_m = nthr_m; - *nthrs_n = nthr_n; - *nthrs_k = nthr_k; - - *BM = MB; - *BN = NB; - *BK = KB; -} -#undef BM_NOCOPY_AVX512_COMMON -#undef BN_NOCOPY_AVX512_COMMON -#undef BK_NOCOPY_AVX512_COMMON -#undef BN_LARGE_NOCOPY_AVX512_COMMON -#undef BM_SMALL_NOCOPY_AVX512_COMMON -#undef BN_SMALL_NOCOPY_AVX512_COMMON -#undef BK_SMALL_NOCOPY_AVX512_COMMON - -// Partition n values as equally as possible among nthr threads -// and set the offset (t_offset) and number of values (t_block) for ithr -// Assumption: 0 <= ithr < nthr -void partition_unit_diff( - int ithr, int nthr, int n, int *t_offset, int *t_block) -{ - int band = n / nthr; - if (band == 0) - band 
= 1; - int tail = n - band * nthr; - if (tail < 0) - tail = 0; - - if (ithr < tail) { - band++; - *t_offset = band * ithr; - *t_block = band; - } else { - *t_offset = band * ithr + tail; - *t_block = band; - } - - if (*t_offset >= n) { - *t_offset = 0; - *t_block = 0; - } - - if (*t_offset + *t_block > n) { - *t_block = n - *t_offset; - } -} - -// Sum the m*n values from p_src into p_dst, assuming the two-dimensional -// arrays have leading dimensions ld_src and ld_dst, respectively -template -void sum_two_matrices(int m, int n, - data_t * __restrict p_src, dim_t ld_src, - data_t * __restrict p_dst, dim_t ld_dst) -{ - int i, j; - for (j = 0; j < n; j++) { - for (i = 0; i < m; i++) { - p_dst[i + j * ld_dst] += p_src[i + j * ld_src]; - } - } -} - -template -void sum_two_matrices(int m, int n, - float * __restrict p_src, dim_t ld_src, - float * __restrict p_dst, dim_t ld_dst); - -template -void sum_two_matrices(int m, int n, - double * __restrict p_src, dim_t ld_src, - double * __restrict p_dst, dim_t ld_dst); -} -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.hpp deleted file mode 100644 index 3352298b4..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.hpp +++ /dev/null @@ -1,72 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef GEMM_UTILS_HPP -#define GEMM_UTILS_HPP - -namespace mkldnn { -namespace impl { -namespace cpu { - -namespace gemm_utils { -// Alias for any dimension related variable. -typedef ptrdiff_t dim_t; - -template -struct gemm_traits {}; - -template -struct gemm_traits { - static constexpr int m = 8; - static constexpr int n = 6; - static constexpr int BM = 4032; - static constexpr int BN = isTransA ? 96 : 192; - static constexpr int BK = isTransB ? 96 : 512; -}; - -template -struct gemm_traits { - static constexpr int m = 16; - static constexpr int n = 6; - static constexpr int BM = 4032; - static constexpr int BN = isTransA ? 96 : 48; - static constexpr int BK = isTransB ? 
96 : 256; -}; - -template -using unroll_factor = gemm_traits; - -template -void sum_two_matrices(int m, int n, - data_t * __restrict p_src, dim_t ld_src, - data_t * __restrict p_dst, dim_t ld_dst); - -void calc_nthr_nocopy_avx512_common(int m, - int n, int k, int nthrs, int *nthrs_m, int *nthrs_n, int *nthrs_k, - int *BM, int *BN, int *BK); - -void calc_nthr_nocopy_avx(int m, int n, int k, - int nthrs, int *nthrs_m, int *nthrs_n, int *nthrs_k, int *BM, int *BN, - int *BK); - -void partition_unit_diff( - int ithr, int nthr, int n, int *t_offset, int *t_block); -}; - -} -} -} -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.cpp deleted file mode 100644 index d7be43e39..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.cpp +++ /dev/null @@ -1,2131 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include -#include - -#include "mkldnn_thread.hpp" -#include "utils.hpp" - -#include "ref_gemm_f32.hpp" -#include "gemm_utils_f32.hpp" -#include "jit_avx512_common_gemm_f32.hpp" - -#include "jit_generator.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -#define CACHE_LINE_SIZE 64 - -#define STACKSIZE get_size_of_abi_save_regs() -#ifdef _WIN32 -#define STACK_K_CAPACITY 32 -#else -#define STACK_K_CAPACITY 2048 -#endif -#define SIZE 4 -#define OFFSET 128 -#define BASE_SHIFT 2 -#define SECOND_FETCH unroll_n -#define UNROLL_M 48 -#define UNROLL_N 8 - -namespace avx512_common_gemm_f32 { -using namespace gemm_utils; - -struct xbyak_gemm : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_gemm_f32_xbyak_gemm) - - xbyak_gemm(char isTransA, char isTransB, float beta, bool hasBias = false, - void *code_ptr = nullptr, - size_t code_size = 80 * Xbyak::DEFAULT_MAX_CODE_SIZE) - : jit_generator(code_ptr, code_size) - { - using namespace Xbyak; - - enum { ver_avx512_core, ver_avx512_mic } ver = - mayiuse(avx512_core) ? 
ver_avx512_core : ver_avx512_mic; - - bool isBeta0 = (beta == 0.0); - bool isBetaN = (!isBeta0 && beta != 1.0); - - // various definitions for convenience - auto ARG_M = abi_param1; - auto ARG_N = abi_param2; - auto K = abi_param3; - auto ARG_ALPHA = abi_param4; -#ifdef _WIN32 - auto ARG_A = ptr[rsp + OFFSET_SHADOWSPACE + STACKSIZE]; - auto ARG_LDA = qword[rsp + OFFSET_SHADOWSPACE + - sizeof(float *) + STACKSIZE]; - const auto stackOffset = OFFSET_SHADOWSPACE + - sizeof(float *) + STACKSIZE; - auto A = rsi; - auto LDA = rdi; -#else - auto ARG_A = r8; - auto ARG_LDA = r9; - const auto stackOffset = STACKSIZE; - auto A = ARG_A; - auto LDA = ARG_LDA; -#endif - auto ARG_B = ptr[rsp + 8 + stackOffset]; - auto ARG_LDB = ptr[rsp + 16 + stackOffset]; - auto ARG_BETA = ptr[rsp + 24 + stackOffset]; - auto ARG_C = ptr[rsp + 32 + stackOffset]; - auto ARG_LDC = ptr[rsp + 40 + stackOffset]; - auto ARG_BIAS = ptr[rsp + 48 + stackOffset]; - auto ARG_WS = ptr[rsp + 56 + stackOffset]; - - auto B = r11; - auto LDB = rbx; - auto LDC = r13; - auto LL = rax; - auto AO1 = abi_param2; - auto BO1 = abi_param4; - auto BO2 = rbp; - auto CO1 = r14; - auto CO2 = r15; - auto LDB3 = r10; - auto LDA4 = abi_param1; - auto AA = r12; - auto BIAS1 = abi_param1; - - auto M = qword[rsp + 0]; - auto N = qword[rsp + 8]; - auto FLAG = qword[rsp + 16]; - auto I = qword[rsp + 24]; - auto C = qword[rsp + 32]; - auto BIAS = qword[rsp + 40]; - auto ALPHA = qword[rsp + 48]; - auto BETA = qword[rsp + 64]; - auto ORIG_A = qword[rsp + 80]; - auto ORIG_SP = qword[rsp + 120]; - - auto ZSTRIDE = zmm4; - auto VALPHA = zmm6; - auto VBETA = zmm7; - auto VBIAS1 = zmm1; - auto VBIAS2 = zmm2; - auto VBIAS3 = zmm3; - - auto PREFETCHSIZEA = ver == ver_avx512_core ? 48 : 80; - auto PREFETCHSIZEB = 16; - - Zmm regs[] = { zmm8, zmm9, zmm10, zmm11, zmm12, zmm13, zmm14, zmm15, - zmm16, zmm17, zmm18, zmm19, zmm20, zmm21, zmm22, zmm23, zmm24, - zmm25, zmm26, zmm27, zmm28, zmm29, zmm30, zmm31 }; - - // Function for packing if needed - auto do_pack = [&](int unroll_m) { - Label pack2, pack3, pack4, pack10; - - mov(BO1, A); - lea(AO1, ptr[rsp + 128 + OFFSET * SIZE]); - mov(LL, K); - sar(LL, 2); - jle(pack3, T_NEAR); - align(16); - - L(pack2); - if (!isTransA) { - for (int i = 0; i < 4; i++) { - vmovups(zmm0 | k1, ptr[BO1 + (0 * 16 - OFFSET) * SIZE]); - if (unroll_m > 16) - vmovups(zmm1 | k2, ptr[BO1 + (1 * 16 - OFFSET) * SIZE]); - if (unroll_m > 32) - vmovups(zmm2 | k3, ptr[BO1 + (2 * 16 - OFFSET) * SIZE]); - add(BO1, LDA); - - vmovups(ptr[AO1 + (unroll_m * i + 0 * 16 - OFFSET) * SIZE] - | k1, - zmm0); - if (unroll_m > 16) - vmovups(ptr[AO1 - + (unroll_m * i + 1 * 16 - OFFSET) - * SIZE] - | k2, - zmm1); - if (unroll_m > 32) - vmovups(ptr[AO1 - + (unroll_m * i + 2 * 16 - OFFSET) - * SIZE] - | k3, - zmm2); - } - } else { - for (int i = 0; i < 4; i++) { - kmovw(k4, k1); - vgatherqps(ymm5 | k4, - ptr[BO1 + ZSTRIDE + (i - OFFSET) * SIZE]); - lea(BO2, ptr[BO1 + LDA * 8]); - kshiftrw(k4, k1, 8); - vgatherqps(ymm6 | k4, - ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]); - vshuff64x2(zmm0, zmm5, zmm6, 0x44); - - if (unroll_m > 16) { - lea(BO2, ptr[BO2 + LDA * 8]); - kmovw(k4, k2); - vgatherqps(ymm5 | k4, - ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]); - lea(BO2, ptr[BO2 + LDA * 8]); - kshiftrw(k4, k2, 8); - vgatherqps(ymm6 | k4, - ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]); - vshuff64x2(zmm1, zmm5, zmm6, 0x44); - } - - if (unroll_m > 32) { - lea(BO2, ptr[BO2 + LDA * 8]); - kmovw(k4, k3); - vgatherqps(ymm5 | k4, - ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]); - lea(BO2, ptr[BO2 + 
LDA * 8]); - kshiftrw(k4, k3, 8); - vgatherqps(ymm6 | k4, - ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]); - lea(BO2, ptr[BO2 + LDA * 8]); - vshuff64x2(zmm2, zmm5, zmm6, 0x44); - } - - vmovups(ptr[AO1 + (unroll_m * i + 0 * 16 - OFFSET) * SIZE], - zmm0 | k1); - if (unroll_m > 16) - vmovups(ptr[AO1 - + (unroll_m * i + 1 * 16 - OFFSET) - * SIZE], - zmm1 | k2); - if (unroll_m > 32) - vmovups(ptr[AO1 - + (unroll_m * i + 2 * 16 - OFFSET) - * SIZE], - zmm2 | k3); - } - add(BO1, 4 * SIZE); - } - add(AO1, unroll_m * 4 * SIZE); - - sub(LL, 1); - jg(pack2, T_NEAR); - align(16); - - L(pack3); - mov(LL, K); - and_(LL, 3); - jle(pack10, T_NEAR); - align(16); - - L(pack4); - if (!isTransA) { - vmovups(zmm0 | k1, ptr[BO1 + (0 * 16 - OFFSET) * SIZE]); - if (unroll_m > 16) - vmovups(zmm1 | k2, ptr[BO1 + (1 * 16 - OFFSET) * SIZE]); - if (unroll_m > 32) - vmovups(zmm2 | k3, ptr[BO1 + (2 * 16 - OFFSET) * SIZE]); - add(BO1, LDA); - } else { - kmovw(k4, k1); - vgatherqps(ymm5 | k4, ptr[BO1 + ZSTRIDE + (0 - OFFSET) * SIZE]); - lea(BO2, ptr[BO1 + LDA * 8]); - kshiftrw(k4, k1, 8); - vgatherqps(ymm6 | k4, ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]); - vshuff64x2(zmm0, zmm5, zmm6, 0x44); - - if (unroll_m > 16) { - lea(BO2, ptr[BO2 + LDA * 8]); - kmovw(k4, k2); - vgatherqps(ymm5 | k4, - ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]); - lea(BO2, ptr[BO2 + LDA * 8]); - kshiftrw(k4, k2, 8); - vgatherqps(ymm6 | k4, - ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]); - vshuff64x2(zmm1, zmm5, zmm6, 0x44); - } - - if (unroll_m > 32) { - lea(BO2, ptr[BO2 + LDA * 8]); - kmovw(k4, k3); - vgatherqps(ymm5 | k4, - ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]); - lea(BO2, ptr[BO2 + LDA * 8]); - kshiftrw(k4, k3, 8); - vgatherqps(ymm6 | k4, - ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]); - lea(BO2, ptr[BO2 + LDA * 8]); - vshuff64x2(zmm2, zmm5, zmm6, 0x44); - } - add(BO1, SIZE); - } - - vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE], - zmm0 | k1); - if (unroll_m > 16) - vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 16 - OFFSET) * SIZE], - zmm1 | k2); - if (unroll_m > 32) - vmovups(ptr[AO1 + (unroll_m * 0 + 2 * 16 - OFFSET) * SIZE], - zmm2 | k3); - - add(AO1, unroll_m * SIZE); - sub(LL, 1); - jg(pack4, T_NEAR); - align(16); - - L(pack10); - }; - - // Function to update C, covering masking and other considerations - auto update = [&](Zmm reg, bool useCO1, int offset, int mask, - bool useScale = false) { - vmulps(reg, reg, VALPHA); - if (!isBeta0) { - if (!useScale) { - switch (mask) { - case 0: - if (useCO1) - vmovups(zmm0, ptr[CO1 + offset * SIZE]); - else - vmovups(zmm0, ptr[CO2 + offset * SIZE]); - break; - case 1: - if (useCO1) - vmovups(zmm0 | k1 | T_z, ptr[CO1 + offset * SIZE]); - else - vmovups(zmm0 | k1 | T_z, ptr[CO2 + offset * SIZE]); - break; - case 2: - if (useCO1) - vmovups(zmm0 | k2 | T_z, ptr[CO1 + offset * SIZE]); - else - vmovups(zmm0 | k2 | T_z, ptr[CO2 + offset * SIZE]); - break; - case 3: - if (useCO1) - vmovups(zmm0 | k3 | T_z, ptr[CO1 + offset * SIZE]); - else - vmovups(zmm0 | k3 | T_z, ptr[CO2 + offset * SIZE]); - break; - } - } else { - switch (mask) { - case 0: - if (useCO1) - vmovups(zmm0, ptr[CO1 + LDC + offset * SIZE]); - else - vmovups(zmm0, ptr[CO2 + LDC + offset * SIZE]); - break; - case 1: - if (useCO1) - vmovups(zmm0 | k1 | T_z, - ptr[CO1 + LDC + offset * SIZE]); - else - vmovups(zmm0 | k1 | T_z, - ptr[CO2 + LDC + offset * SIZE]); - break; - case 2: - if (useCO1) - vmovups(zmm0 | k2 | T_z, - ptr[CO1 + LDC + offset * SIZE]); - else - vmovups(zmm0 | k2 | T_z, - ptr[CO2 + LDC + offset * SIZE]); - break; - case 3: - if (useCO1) 
- vmovups(zmm0 | k3 | T_z, - ptr[CO1 + LDC + offset * SIZE]); - else - vmovups(zmm0 | k3 | T_z, - ptr[CO2 + LDC + offset * SIZE]); - break; - } - } - if (!isBetaN) { - vaddps(zmm0, reg, zmm0); - } else { - vfmadd132ps(zmm0, reg, VBETA); - } - if (!useScale) { - switch (mask) { - case 0: - if (useCO1) - vmovups(ptr[CO1 + offset * SIZE], zmm0); - else - vmovups(ptr[CO2 + offset * SIZE], zmm0); - break; - case 1: - if (useCO1) - vmovups(ptr[CO1 + offset * SIZE], zmm0 | k1); - else - vmovups(ptr[CO2 + offset * SIZE], zmm0 | k1); - break; - case 2: - if (useCO1) - vmovups(ptr[CO1 + offset * SIZE], zmm0 | k2); - else - vmovups(ptr[CO2 + offset * SIZE], zmm0 | k2); - break; - case 3: - if (useCO1) - vmovups(ptr[CO1 + offset * SIZE], zmm0 | k3); - else - vmovups(ptr[CO2 + offset * SIZE], zmm0 | k3); - break; - } - } else { - switch (mask) { - case 0: - if (useCO1) - vmovups(ptr[CO1 + LDC + offset * SIZE], zmm0); - else - vmovups(ptr[CO2 + LDC + offset * SIZE], zmm0); - break; - case 1: - if (useCO1) - vmovups(ptr[CO1 + LDC + offset * SIZE], zmm0 | k1); - else - vmovups(ptr[CO2 + LDC + offset * SIZE], zmm0 | k1); - break; - case 2: - if (useCO1) - vmovups(ptr[CO1 + LDC + offset * SIZE], zmm0 | k2); - else - vmovups(ptr[CO2 + LDC + offset * SIZE], zmm0 | k2); - break; - case 3: - if (useCO1) - vmovups(ptr[CO1 + LDC + offset * SIZE], zmm0 | k3); - else - vmovups(ptr[CO2 + LDC + offset * SIZE], zmm0 | k3); - break; - } - } - } else { - if (!useScale) { - switch (mask) { - case 0: - if (useCO1) - vmovups(ptr[CO1 + offset * SIZE], reg); - else - vmovups(ptr[CO2 + offset * SIZE], reg); - break; - case 1: - if (useCO1) - vmovups(ptr[CO1 + offset * SIZE], reg | k1); - else - vmovups(ptr[CO2 + offset * SIZE], reg | k1); - break; - case 2: - if (useCO1) - vmovups(ptr[CO1 + offset * SIZE], reg | k2); - else - vmovups(ptr[CO2 + offset * SIZE], reg | k2); - break; - case 3: - if (useCO1) - vmovups(ptr[CO1 + offset * SIZE], reg | k3); - else - vmovups(ptr[CO2 + offset * SIZE], reg | k3); - break; - } - } else { - switch (mask) { - case 0: - if (useCO1) - vmovups(ptr[CO1 + LDC + offset * SIZE], reg); - else - vmovups(ptr[CO2 + LDC + offset * SIZE], reg); - break; - case 1: - if (useCO1) - vmovups(ptr[CO1 + LDC + offset * SIZE], reg | k1); - else - vmovups(ptr[CO2 + LDC + offset * SIZE], reg | k1); - break; - case 2: - if (useCO1) - vmovups(ptr[CO1 + LDC + offset * SIZE], reg | k2); - else - vmovups(ptr[CO2 + LDC + offset * SIZE], reg | k2); - break; - case 3: - if (useCO1) - vmovups(ptr[CO1 + LDC + offset * SIZE], reg | k3); - else - vmovups(ptr[CO2 + LDC + offset * SIZE], reg | k3); - break; - } - } - } - vpxorq(reg, reg, reg); - }; - - // Loop with unroll_n - 2 FMAs; called by innerkernel - auto fmaloop = [&](int unroll_m, int unroll_n, int iteration) { - for (int i = 2; i < unroll_n; i++) { - if (ver == ver_avx512_core) { - if (!isTransB) { - switch (i) { - case 2: - vbroadcastss( - zmm3, - ptr[BO1 + LDB * 2 - + (iteration - OFFSET) * SIZE]); - break; - case 3: - vbroadcastss( - zmm3, - ptr[BO1 + LDB3 - + (iteration - OFFSET) * SIZE]); - break; - case 4: - vbroadcastss(zmm3, - ptr[BO2 + (iteration - OFFSET) * SIZE]); - break; - case 5: - vbroadcastss( - zmm3, - ptr[BO2 + LDB * 1 - + (iteration - OFFSET) * SIZE]); - break; - case 6: - vbroadcastss( - zmm3, - ptr[BO2 + LDB * 2 - + (iteration - OFFSET) * SIZE]); - break; - case 7: - vbroadcastss( - zmm3, - ptr[BO2 + LDB3 - + (iteration - OFFSET) * SIZE]); - break; - } - } else { - vbroadcastss(zmm3, ptr[BO1 + (i - OFFSET) * SIZE]); - } - vfmadd231ps(regs[i], 
zmm3, zmm0); - if (unroll_m >= 32) - vfmadd231ps(regs[i + 8], zmm3, zmm1); - if (unroll_m >= 48) - vfmadd231ps(regs[i + 16], zmm3, zmm2); - } else { - if (!isTransB) { - switch (i) { - case 2: - vfmadd231ps(regs[i], zmm0, - zword_b[BO1 + LDB * 2 - + (iteration - OFFSET) * SIZE]); - if (unroll_m >= 32) - vfmadd231ps(regs[i + 8], zmm1, - zword_b[BO1 + LDB * 2 - + (iteration - OFFSET) * SIZE]); - if (unroll_m >= 48) - vfmadd231ps(regs[i + 16], zmm2, - zword_b[BO1 + LDB * 2 - + (iteration - OFFSET) * SIZE]); - break; - case 3: - vfmadd231ps(regs[i], zmm0, - zword_b[BO1 + LDB3 - + (iteration - OFFSET) * SIZE]); - if (unroll_m >= 32) - vfmadd231ps(regs[i + 8], zmm1, - zword_b[BO1 + LDB3 - + (iteration - OFFSET) * SIZE]); - if (unroll_m >= 48) - vfmadd231ps(regs[i + 16], zmm2, - zword_b[BO1 + LDB3 - + (iteration - OFFSET) * SIZE]); - break; - case 4: - vfmadd231ps(regs[i], zmm0, - zword_b[BO2 + (iteration - OFFSET) * SIZE]); - if (unroll_m >= 32) - vfmadd231ps(regs[i + 8], zmm1, - zword_b[BO2 + (iteration - OFFSET) * SIZE]); - if (unroll_m >= 48) - vfmadd231ps(regs[i + 16], zmm2, - zword_b[BO2 + (iteration - OFFSET) * SIZE]); - break; - case 5: - vfmadd231ps(regs[i], zmm0, - zword_b[BO2 + LDB * 1 - + (iteration - OFFSET) * SIZE]); - if (unroll_m >= 32) - vfmadd231ps(regs[i + 8], zmm1, - zword_b[BO2 + LDB * 1 - + (iteration - OFFSET) * SIZE]); - if (unroll_m >= 48) - vfmadd231ps(regs[i + 16], zmm2, - zword_b[BO2 + LDB * 1 - + (iteration - OFFSET) * SIZE]); - break; - case 6: - vfmadd231ps(regs[i], zmm0, - zword_b[BO2 + LDB * 2 - + (iteration - OFFSET) * SIZE]); - if (unroll_m >= 32) - vfmadd231ps(regs[i + 8], zmm1, - zword_b[BO2 + LDB * 2 - + (iteration - OFFSET) * SIZE]); - if (unroll_m >= 48) - vfmadd231ps(regs[i + 16], zmm2, - zword_b[BO2 + LDB * 2 - + (iteration - OFFSET) * SIZE]); - break; - case 7: - vfmadd231ps(regs[i], zmm0, - zword_b[BO2 + LDB3 - + (iteration - OFFSET) * SIZE]); - if (unroll_m >= 32) - vfmadd231ps(regs[i + 8], zmm1, - zword_b[BO2 + LDB3 - + (iteration - OFFSET) * SIZE]); - if (unroll_m >= 48) - vfmadd231ps(regs[i + 16], zmm2, - zword_b[BO2 + LDB3 - + (iteration - OFFSET) * SIZE]); - break; - } - } else { - vfmadd231ps( - regs[i], zmm0, zword_b[BO1 + (i - OFFSET) * SIZE]); - if (unroll_m >= 32) - vfmadd231ps(regs[i + 8], zmm1, - zword_b[BO1 + (i - OFFSET) * SIZE]); - if (unroll_m >= 48) - vfmadd231ps(regs[i + 16], zmm2, - zword_b[BO1 + (i - OFFSET) * SIZE]); - } - } - } - }; - - // Innerkernel; called by kernel - auto innerkernel = [&](int unroll_m, int unroll_n, bool isDirect, - bool isCopy, bool doCPrefetch, bool isUnmasked = true) { - for (int i = 0; i < 8; i++) { - if (!isDirect) { - prefetcht0(ptr[AO1 - + (PREFETCHSIZEA + i * unroll_m + 0 * 16 - OFFSET) - * SIZE]); - if (unroll_m >= 32) - prefetcht0(ptr[AO1 - + (PREFETCHSIZEA + i * unroll_m + 1 * 16 - OFFSET) - * SIZE]); - if (unroll_m >= 48) - prefetcht0(ptr[AO1 - + (PREFETCHSIZEA + i * unroll_m + 2 * 16 - OFFSET) - * SIZE]); - } else { - prefetcht0(ptr[AO1 + LDA4 + (16 * 0 * SIZE)]); - if (unroll_m >= 32) - prefetcht0(ptr[AO1 + LDA4 + (16 * 1 * SIZE)]); - if (unroll_m >= 48) - prefetcht0(ptr[AO1 + LDA4 + (16 * 2 * SIZE)]); - } - - if (!isDirect) { - if (i != 0) { - if (isUnmasked || unroll_m > 16) { - vmovups(zmm0, - ptr[AO1 - + (unroll_m * i + 0 * 16 - OFFSET) - * SIZE]); - } else { - vmovups(zmm0 | k1 | T_z, - ptr[AO1 - + (unroll_m * i + 0 * 16 - OFFSET) - * SIZE]); - } - if (unroll_m >= 32) { - if (isUnmasked || unroll_m > 32) { - vmovups(zmm1, ptr[AO1 - + (unroll_m * i + 1 * 16 - - OFFSET) - * SIZE]); - } else { 
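// Masked A loads: opmasks k1/k2/k3 are built from M in subloop's "Create mask"
// step and cover the ragged tail of the 16/32/48-row panel; "| T_z" zeroes the
// inactive lanes so they contribute nothing to the FMAs.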
- vmovups(zmm1 | k2 | T_z, - ptr[AO1 - + (unroll_m * i + 1 * 16 - - OFFSET) - * SIZE]); - } - } - if (unroll_m >= 48) { - if (isUnmasked) { - vmovups(zmm2, ptr[AO1 - + (unroll_m * i + 2 * 16 - - OFFSET) - * SIZE]); - } else { - vmovups(zmm2 | k3 | T_z, - ptr[AO1 - + (unroll_m * i + 2 * 16 - - OFFSET) - * SIZE]); - } - } - } - } else { - if (isUnmasked || unroll_m > 16) { - vmovups(zmm0, ptr[AO1 + (0 * 16 - OFFSET) * SIZE]); - } else { - vmovups(zmm0 | k1 | T_z, - ptr[AO1 + (0 * 16 - OFFSET) * SIZE]); - } - if (unroll_m >= 32) { - if (isUnmasked || unroll_m > 32) { - vmovups(zmm1, ptr[AO1 + (1 * 16 - OFFSET) * SIZE]); - } else { - vmovups(zmm1 | k2 | T_z, - ptr[AO1 + (1 * 16 - OFFSET) * SIZE]); - } - } - if (unroll_m >= 48) { - if (isUnmasked) { - vmovups(zmm2, ptr[AO1 + (2 * 16 - OFFSET) * SIZE]); - } else { - vmovups(zmm2 | k3 | T_z, - ptr[AO1 + (2 * 16 - OFFSET) * SIZE]); - } - } - add(AO1, LDA); - } - - if (ver == ver_avx512_core) { - if (!isTransB) { - vbroadcastss(zmm3, ptr[BO1 + (i - OFFSET) * SIZE]); - } else { - vbroadcastss(zmm3, ptr[BO1 + (0 - OFFSET) * SIZE]); - } - vfmadd231ps(regs[0], zmm3, zmm0); - if (unroll_m >= 32) - vfmadd231ps(regs[0 + 8], zmm3, zmm1); - if (unroll_m >= 48) - vfmadd231ps(regs[0 + 16], zmm3, zmm2); - } else { - if (!isTransB) { - vfmadd231ps(regs[0], zmm0, - zword_b[BO1 + (i - OFFSET) * SIZE]); - if (unroll_m >= 32) - vfmadd231ps(regs[0 + 8], zmm1, - zword_b[BO1 + (i - OFFSET) * SIZE]); - if (unroll_m >= 48) - vfmadd231ps(regs[0 + 16], zmm2, - zword_b[BO1 + (i - OFFSET) * SIZE]); - } else { - vfmadd231ps(regs[0], zmm0, - zword_b[BO1 + (0 - OFFSET) * SIZE]); - if (unroll_m >= 32) - vfmadd231ps(regs[0 + 8], zmm1, - zword_b[BO1 + (0 - OFFSET) * SIZE]); - if (unroll_m >= 48) - vfmadd231ps(regs[0 + 16], zmm2, - zword_b[BO1 + (0 - OFFSET) * SIZE]); - } - } - - if (unroll_n >= i + 1) { - if (!isTransB) { - switch (i) { - case 0: - prefetcht0( - ptr[BO1 + (PREFETCHSIZEB - OFFSET) * SIZE]); - break; - case 1: - prefetcht0(ptr[BO1 + LDB - + (PREFETCHSIZEB - OFFSET) * SIZE]); - break; - case 2: - prefetcht0(ptr[BO1 + LDB * 2 - + (PREFETCHSIZEB - OFFSET) * SIZE]); - break; - case 3: - prefetcht0(ptr[BO1 + LDB3 - + (PREFETCHSIZEB - OFFSET) * SIZE]); - break; - case 4: - prefetcht0( - ptr[BO2 + (PREFETCHSIZEB - OFFSET) * SIZE]); - break; - case 5: - prefetcht0(ptr[BO2 + LDB - + (PREFETCHSIZEB - OFFSET) * SIZE]); - break; - case 6: - prefetcht0(ptr[BO2 + LDB * 2 - + (PREFETCHSIZEB - OFFSET) * SIZE]); - break; - case 7: - prefetcht0(ptr[BO2 + LDB3 - + (PREFETCHSIZEB - OFFSET) * SIZE]); - break; - } - } - } - - if (unroll_n >= 2) { - if (ver == ver_avx512_core) { - if (!isTransB) { - vbroadcastss(zmm3, - ptr[BO1 + LDB * 1 + (i - OFFSET) * SIZE]); - } else { - vbroadcastss(zmm3, ptr[BO1 + (1 - OFFSET) * SIZE]); - } - vfmadd231ps(regs[1], zmm3, zmm0); - if (unroll_m >= 32) - vfmadd231ps(regs[1 + 8], zmm3, zmm1); - if (unroll_m >= 48) - vfmadd231ps(regs[1 + 16], zmm3, zmm2); - } else { - if (!isTransB) { - vfmadd231ps(regs[1], zmm0, - zword_b[BO1 + LDB * 1 + (i - OFFSET) * SIZE]); - if (unroll_m >= 32) - vfmadd231ps(regs[1 + 8], zmm1, - zword_b[BO1 + LDB * 1 - + (i - OFFSET) * SIZE]); - if (unroll_m >= 48) - vfmadd231ps(regs[1 + 16], zmm2, - zword_b[BO1 + LDB * 1 - + (i - OFFSET) * SIZE]); - } else { - vfmadd231ps(regs[1], zmm0, - zword_b[BO1 + (1 - OFFSET) * SIZE]); - if (unroll_m >= 32) - vfmadd231ps(regs[1 + 8], zmm1, - zword_b[BO1 + (1 - OFFSET) * SIZE]); - if (unroll_m >= 48) - vfmadd231ps(regs[1 + 16], zmm2, - zword_b[BO1 + (1 - OFFSET) * SIZE]); - } - } - } - - if 
(isCopy) { - if (isUnmasked || unroll_m > 16) { - vmovups(ptr[LDA4 - + (unroll_m * i + 0 * 16 - OFFSET) - * SIZE], - zmm0); - } else { - vmovups(ptr[LDA4 - + (unroll_m * i + 0 * 16 - OFFSET) - * SIZE], - zmm0 | k1); - } - if (unroll_m >= 32) { - if (isUnmasked || unroll_m > 32) { - vmovups(ptr[LDA4 - + (unroll_m * i + 1 * 16 - OFFSET) - * SIZE], - zmm1); - } else { - vmovups(ptr[LDA4 - + (unroll_m * i + 1 * 16 - OFFSET) - * SIZE], - zmm1 | k2); - } - } - if (unroll_m >= 48) { - if (isUnmasked) { - vmovups(ptr[LDA4 - + (unroll_m * i + 2 * 16 - OFFSET) - * SIZE], - zmm2); - } else { - vmovups(ptr[LDA4 - + (unroll_m * i + 2 * 16 - OFFSET) - * SIZE], - zmm2 | k3); - } - } - if (i == 7) - sub(LDA4, -unroll_m * 8 * SIZE); - } - fmaloop(unroll_m, unroll_n, i); - - if (i == 1) { - if (doCPrefetch) { - if (ver == ver_avx512_core) - prefetchw(ptr[CO2 + 0 * 16 * SIZE]); - else - prefetcht0(ptr[CO2 + 0 * 16 * SIZE]); - } - } - if (i == 3) { - if (doCPrefetch && unroll_m >= 32) { - if (ver == ver_avx512_core) - prefetchw(ptr[CO2 + 1 * 16 * SIZE]); - else - prefetcht0(ptr[CO2 + 1 * 16 * SIZE]); - } - if (!isTransA) { - if (ver == ver_avx512_core) - prefetcht0(ptr[AA + 16 * 0 * SIZE]); - else - prefetcht2(ptr[AA + 16 * 0 * SIZE]); - } - } - if (i == 5) { - if (doCPrefetch) { - if (unroll_m >= 48) { - if (ver == ver_avx512_core) - prefetchw(ptr[CO2 + 2 * 16 * SIZE]); - else - prefetcht0(ptr[CO2 + 2 * 16 * SIZE]); - } - add(CO2, LDC); - } - if (!isTransA) { - if (unroll_m >= 32) { - if (ver == ver_avx512_core) - prefetcht0(ptr[AA + 16 * 1 * SIZE]); - else - prefetcht2(ptr[AA + 16 * 1 * SIZE]); - } - } - } - - if (isTransB) { - prefetcht0(ptr[BO1 + BO2]); - add(BO1, LDB); - } - } // end of for loop - - if (!isTransB) { - sub(BO1, -8 * SIZE); - if (unroll_n >= 4) - sub(BO2, -8 * SIZE); - } - if (!isTransA) { - if (unroll_m >= 48) { - if (ver == ver_avx512_core) - prefetcht0(ptr[AA + 16 * 2 * SIZE]); - else - prefetcht2(ptr[AA + 16 * 2 * SIZE]); - } - lea(AA, ptr[AA + LDA]); - } - - if (!isDirect) { - if (isUnmasked || unroll_m > 16) { - vmovups(zmm0, - ptr[AO1 + (unroll_m * 8 + 0 * 16 - OFFSET) * SIZE]); - } else { - vmovups(zmm0 | k1 | T_z, - ptr[AO1 + (unroll_m * 8 + 0 * 16 - OFFSET) * SIZE]); - } - if (unroll_m >= 32) { - if (isUnmasked || unroll_m > 32) { - vmovups(zmm1, ptr[AO1 - + (unroll_m * 8 + 1 * 16 - OFFSET) - * SIZE]); - } else { - vmovups(zmm1 | k2 | T_z, - ptr[AO1 - + (unroll_m * 8 + 1 * 16 - OFFSET) - * SIZE]); - } - } - if (unroll_m >= 48) { - if (isUnmasked) { - vmovups(zmm2, ptr[AO1 - + (unroll_m * 8 + 2 * 16 - OFFSET) - * SIZE]); - } else { - vmovups(zmm2 | k3 | T_z, - ptr[AO1 - + (unroll_m * 8 + 2 * 16 - OFFSET) - * SIZE]); - } - } - sub(AO1, -unroll_m * 8 * SIZE); - } - - sub(LL, 1); - }; - - // Main kernel; does prefetching and calls innerkernel - // After calculating results in registers, writes back to C matrix by - // calling update - auto kernel = [&](int unroll_m, int unroll_n, bool isDirect, - bool isCopy, bool isUnmasked = true) { - if (!isDirect) { - lea(AO1, ptr[rsp + 128 + OFFSET * SIZE]); - } else { - mov(AO1, A); - } - - if (isCopy) { - lea(LDA4, ptr[rsp + 128 + OFFSET * SIZE]); - } else { - auto step = ver == ver_avx512_core ? 
2 : 4; - lea(LDA4, ptr[LDA * step + (16 - 1 - OFFSET) * SIZE]); - } - - if (isTransB) { - lea(BO2, ptr[LDB * 4 + (16 / 2 - 1 - OFFSET) * SIZE]); - } - - if (!isDirect) { - if (isUnmasked || unroll_m > 16) { - vmovups(zmm0, - ptr[AO1 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE]); - } else { - vmovups(zmm0 | k1 | T_z, - ptr[AO1 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE]); - } - if (unroll_m >= 32) { - if (isUnmasked || unroll_m > 32) { - vmovups(zmm1, ptr[AO1 - + (unroll_m * 0 + 1 * 16 - OFFSET) - * SIZE]); - } else { - vmovups(zmm1 | k2 | T_z, - ptr[AO1 - + (unroll_m * 0 + 1 * 16 - OFFSET) - * SIZE]); - } - } - if (unroll_m >= 48) { - if (isUnmasked) { - vmovups(zmm2, ptr[AO1 - + (unroll_m * 0 + 2 * 16 - OFFSET) - * SIZE]); - } else { - vmovups(zmm2 | k3 | T_z, - ptr[AO1 - + (unroll_m * 0 + 2 * 16 - OFFSET) - * SIZE]); - } - } - } - - Label kernel12, kernel13, kernel14, kernel15, kernel16, kernel18; - - mov(LL, K); - sar(LL, 3); - sub(LL, SECOND_FETCH); - jle(kernel13, T_NEAR); - align(16); - - L(kernel12); - innerkernel( - unroll_m, unroll_n, isDirect, isCopy, false, isUnmasked); - jg(kernel12, T_NEAR); - align(16); - - L(kernel13); - lea(CO2, ptr[CO1 + (16 - 1) * SIZE]); - add(LL, unroll_n); - jle(kernel15, T_NEAR); - align(16); - - L(kernel14); - innerkernel(unroll_m, unroll_n, isDirect, isCopy, true, isUnmasked); - jg(kernel14, T_NEAR); - align(16); - - L(kernel15); - mov(LL, K); - and_(LL, 7); - jle(kernel18, T_NEAR); - align(16); - - L(kernel16); - if (isDirect) { - if (isUnmasked || unroll_m > 16) { - vmovups(zmm0, ptr[AO1 + (0 * 16 - OFFSET) * SIZE]); - } else { - vmovups(zmm0 | k1 | T_z, - ptr[AO1 + (0 * 16 - OFFSET) * SIZE]); - } - if (unroll_m >= 32) { - if (isUnmasked || unroll_m > 32) { - vmovups(zmm1, ptr[AO1 + (1 * 16 - OFFSET) * SIZE]); - } else { - vmovups(zmm1 | k2 | T_z, - ptr[AO1 + (1 * 16 - OFFSET) * SIZE]); - } - } - if (unroll_m >= 48) { - if (isUnmasked) { - vmovups(zmm2, ptr[AO1 + (2 * 16 - OFFSET) * SIZE]); - } else { - vmovups(zmm2 | k3 | T_z, - ptr[AO1 + (2 * 16 - OFFSET) * SIZE]); - } - } - add(AO1, LDA); - } - - for (int i = 0; i < unroll_n; i++) { - if (!isTransB) { - switch (i) { - case 0: - vbroadcastss(zmm3, ptr[BO1 + (0 - OFFSET) * SIZE]); - break; - case 1: - vbroadcastss( - zmm3, ptr[BO1 + LDB * 1 + (0 - OFFSET) * SIZE]); - break; - case 2: - vbroadcastss( - zmm3, ptr[BO1 + LDB * 2 + (0 - OFFSET) * SIZE]); - break; - case 3: - vbroadcastss( - zmm3, ptr[BO1 + LDB3 + (0 - OFFSET) * SIZE]); - break; - case 4: - vbroadcastss(zmm3, ptr[BO2 + (0 - OFFSET) * SIZE]); - break; - case 5: - vbroadcastss( - zmm3, ptr[BO2 + LDB * 1 + (0 - OFFSET) * SIZE]); - break; - case 6: - vbroadcastss( - zmm3, ptr[BO2 + LDB * 2 + (0 - OFFSET) * SIZE]); - break; - case 7: - vbroadcastss( - zmm3, ptr[BO2 + LDB3 + (0 - OFFSET) * SIZE]); - break; - } - } else { - vbroadcastss(zmm3, ptr[BO1 + (i - OFFSET) * SIZE]); - } - vfmadd231ps(regs[i], zmm3, zmm0); - if (unroll_m >= 32) { - vfmadd231ps(regs[i + 8], zmm3, zmm1); - } - if (unroll_m >= 48) { - vfmadd231ps(regs[i + 16], zmm3, zmm2); - } - } - - if (isCopy) { - if (isUnmasked || unroll_m > 16) { - vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE], - zmm0); - } else { - vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE], - zmm0 | k1); - } - if (unroll_m >= 32) { - if (isUnmasked || unroll_m > 32) { - vmovups(ptr[LDA4 - + (unroll_m * 0 + 1 * 16 - OFFSET) - * SIZE], - zmm1); - } else { - vmovups(ptr[LDA4 - + (unroll_m * 0 + 1 * 16 - OFFSET) - * SIZE], - zmm1 | k2); - } - } - if (unroll_m >= 48) { - if (isUnmasked) { - 
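// isCopy: the A panel that was just loaded is also written to the on-stack
// packed buffer addressed through LDA4 (set to rsp + 128 above), so later
// passes over these rows can stream contiguous packed data instead of
// re-reading strided columns of A.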
vmovups(ptr[LDA4 - + (unroll_m * 0 + 2 * 16 - OFFSET) - * SIZE], - zmm2); - } else { - vmovups(ptr[LDA4 - + (unroll_m * 0 + 2 * 16 - OFFSET) - * SIZE], - zmm2 | k3); - } - } - sub(LDA4, -unroll_m * SIZE); - } - - if (!isDirect) { - if (isUnmasked || unroll_m > 16) { - vmovups(zmm0, - ptr[AO1 + (unroll_m * 1 + 0 * 16 - OFFSET) * SIZE]); - } else { - vmovups(zmm0 | k1 | T_z, - ptr[AO1 + (unroll_m * 1 + 0 * 16 - OFFSET) * SIZE]); - } - if (unroll_m >= 32) { - if (isUnmasked || unroll_m > 32) { - vmovups(zmm1, ptr[AO1 - + (unroll_m * 1 + 1 * 16 - OFFSET) - * SIZE]); - } else { - vmovups(zmm1 | k2 | T_z, - ptr[AO1 - + (unroll_m * 1 + 1 * 16 - OFFSET) - * SIZE]); - } - } - if (unroll_m >= 48) { - if (isUnmasked) { - vmovups(zmm2, ptr[AO1 - + (unroll_m * 1 + 2 * 16 - OFFSET) - * SIZE]); - } else { - vmovups(zmm2 | k3 | T_z, - ptr[AO1 - + (unroll_m * 1 + 2 * 16 - OFFSET) - * SIZE]); - } - } - sub(AO1, -unroll_m * SIZE); - } - - if (!isTransB) { - sub(BO1, -SIZE); - if (unroll_n >= 4) { - sub(BO2, -SIZE); - } - } else { - add(BO1, LDB); - } - - sub(LL, 1); - jg(kernel16, T_NEAR); - align(16); - - L(kernel18); - vbroadcastss(VALPHA, ALPHA); - - if (isBetaN) { - vbroadcastss(VBETA, BETA); - } - - // Write back the results; all beta cases need to be handled - if (hasBias) { - mov(BIAS1, BIAS); - if (isUnmasked || unroll_m > 16) - vmovups(VBIAS1, ptr[BIAS1 + 0 * SIZE]); - else - vmovups(VBIAS1 | k1 | T_z, ptr[BIAS1 + 0 * SIZE]); - if (unroll_m >= 32) { - if (isUnmasked || unroll_m > 32) - vmovups(VBIAS2, ptr[BIAS1 + 16 * SIZE]); - else - vmovups(VBIAS2 | k2 | T_z, ptr[BIAS1 + 16 * SIZE]); - } - if (unroll_m >= 48) { - if (isUnmasked) - vmovups(VBIAS3, ptr[BIAS1 + 32 * SIZE]); - else - vmovups(VBIAS3 | k3 | T_z, ptr[BIAS1 + 32 * SIZE]); - } - } - - for (int i = 0; i < unroll_n; i++) { - bool useScale = i % 2 != 0; - bool useCO1 = i < 2; - if (i == 2) - lea(CO2, ptr[CO1 + LDC * 2]); - if (i == 4 || i == 6) - lea(CO2, ptr[CO2 + LDC * 2]); - if (hasBias) - vaddps(regs[i], VBIAS1, regs[i]); - if (isUnmasked || unroll_m > 16) { - update(regs[i], useCO1, 0, 0, useScale); - } else { - update(regs[i], useCO1, 0, 1, useScale); - } - if (unroll_m >= 32) { - if (hasBias) - vaddps(regs[i + 8], VBIAS2, regs[i + 8]); - if (isUnmasked || unroll_m > 32) { - update(regs[i + 8], useCO1, 16, 0, useScale); - } else { - update(regs[i + 8], useCO1, 16, 2, useScale); - } - } - if (unroll_m >= 48) { - if (hasBias) - vaddps(regs[i + 16], VBIAS3, regs[i + 16]); - if (isUnmasked) { - update(regs[i + 16], useCO1, 32, 0, useScale); - } else { - update(regs[i + 16], useCO1, 32, 3, useScale); - } - } - } - - switch (unroll_n) { - case 1: add(CO1, LDC); break; - case 2: lea(CO1, ptr[CO1 + LDC * 2]); break; - case 3: lea(CO1, ptr[CO2 + LDC * 1]); break; - case 4: lea(CO1, ptr[CO2 + LDC * 2]); break; - case 5: lea(CO1, ptr[CO2 + LDC * 1]); break; - case 6: lea(CO1, ptr[CO2 + LDC * 2]); break; - case 7: lea(CO1, ptr[CO2 + LDC * 1]); break; - case 8: lea(CO1, ptr[CO2 + LDC * 2]); break; - } - - // Compute next address of B - if (!isTransB) { - lea(rax, ptr[K * SIZE]); - switch (unroll_n) { - case 1: - add(BO1, LDB); - add(BO2, LDB); - break; - case 2: - lea(BO1, ptr[BO1 + LDB * 2]); - lea(BO2, ptr[BO2 + LDB * 2]); - break; - case 3: - lea(BO1, ptr[BO1 + LDB3]); - lea(BO2, ptr[BO2 + LDB3]); - break; - case 4: - lea(BO1, ptr[BO1 + LDB * 4]); - lea(BO2, ptr[BO2 + LDB * 4]); - break; - case 5: - lea(BO1, ptr[BO1 + LDB * 4]); - add(BO1, LDB); - lea(BO2, ptr[BO2 + LDB * 4]); - add(BO2, LDB); - break; - case 6: - lea(BO1, ptr[BO1 + LDB3 * 2]); - 
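// B-pointer advance: LDB3 holds 3 * LDB, so each case steps BO1/BO2 forward by
// unroll_n columns; rax = K * SIZE is subtracted afterwards to rewind the row
// offset accumulated while walking the K dimension.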
lea(BO2, ptr[BO2 + LDB3 * 2]); - break; - case 7: - lea(BO1, ptr[BO1 + LDB * 8]); - sub(BO1, LDB); - lea(BO2, ptr[BO2 + LDB * 8]); - sub(BO2, LDB); - break; - case 8: - lea(BO1, ptr[BO1 + LDB * 8]); - lea(BO2, ptr[BO2 + LDB * 8]); - break; - } - sub(BO1, rax); - sub(BO2, rax); - } else { - mov(rax, LDB); - imul(rax, K); - sub(BO1, rax); - add(BO1, unroll_n * SIZE); - } - }; - - // High-level subroutine; does packing if needed, then splits C matrix. - // Operates on chunks of 48 rows, 8 columns at a time (handling tail - // cases appropriately by doing 32 or 16 rows, and/or with masking, - // and/or fewer columns). - auto subloop = [&](int unroll_m) { - Label l_subloop_20x[8], l_subloop_mask_20x[8]; - Label l_subloop_30x[8], l_subloop_mask_30x[8]; - - Label subloop11, subloop11mask; - Label subloop30, subloop30mask; - Label subloop31, subloop31mask; - Label subloop96; - Label subloop98, subloop98mask; - Label subloop99; - - // Create mask - mov(BO1, rcx); - mov(rcx, M); - sub(rcx, unroll_m - 16); - mov(CO1, 16); - cmp(rcx, 16); - - cmovg(rcx, CO1); - mov(rax, 1); - sal(rax, cl); - sub(rax, 1); - mov(rcx, 0xffff); - - if (unroll_m == 16) { - kmovw(k1, eax); - } else if (unroll_m == 32) { - kmovw(k1, ecx); - kmovw(k2, eax); - } else { - kmovw(k1, ecx); - kmovw(k2, ecx); - kmovw(k3, eax); - } - mov(rcx, BO1); - - and_(rax, 0xffff); - cmp(rax, 0xffff); - jne(subloop96, T_NEAR); - - if (isTransA) { - do_pack(unroll_m); - } - - mov(CO1, C); - add(C, unroll_m * SIZE); - - mov(BO1, B); - if (!isTransB) { - lea(BO2, ptr[B + LDB * 4]); - } - - if (!isTransA) { - lea(AA, ptr[A + (unroll_m + 16 - 1 - OFFSET) * SIZE]); - cmp(M, UNROLL_M); - jg(subloop98, T_NEAR); - - mov(AA, ORIG_A); - lea(AA, ptr[AA + (16 - 1 - OFFSET) * SIZE]); - L(subloop98); - } - - mov(LL, N); - mov(I, LL); - if (!isTransA) { - // If N is too small, skip copy operation - cmp(LL, UNROLL_N * 3); - jle(subloop30, T_NEAR); - - // If A is not aligned to cache line - cmp(FLAG, 0); - je(subloop30, T_NEAR); - } else { - cmp(LL, UNROLL_N); - jl(l_subloop_20x[1], T_NEAR); - } - align(16); - - if (!isTransA) { - kernel(unroll_m, UNROLL_N, true, true); - } else { - kernel(unroll_m, UNROLL_N, false, false); - } - - sub(I, UNROLL_N); - cmp(I, UNROLL_N); - jl(l_subloop_20x[1], T_NEAR); - align(16); - - L(subloop11); - kernel(unroll_m, UNROLL_N, false, false); - sub(I, UNROLL_N); - cmp(I, UNROLL_N); - jge(subloop11, T_NEAR); - align(16); - - for (int i = 1; i <= 7; i++) { - L(l_subloop_20x[i]); - cmp(I, i); - if (i < 7) { - jne(l_subloop_20x[i + 1], T_NEAR); - } else { - jne(subloop99, T_NEAR); - } - kernel(unroll_m, i, false, false); - jmp(subloop99, T_NEAR); - align(16); - } - - if (!isTransA) { - L(subloop30); - cmp(I, UNROLL_N); - jl(l_subloop_30x[1], T_NEAR); - align(16); - - L(subloop31); - kernel(unroll_m, UNROLL_N, true, false); - sub(I, UNROLL_N); - cmp(I, UNROLL_N); - jge(subloop31, T_NEAR); - align(16); - - for (int i = 1; i <= 7; i++) { - L(l_subloop_30x[i]); - cmp(I, i); - if (i < 7) { - jne(l_subloop_30x[i + 1], T_NEAR); - } else { - jne(subloop99, T_NEAR); - } - kernel(unroll_m, i, true, false); - if (i < 7) - jmp(subloop99, T_NEAR); - align(16); - } - } - jmp(subloop99, T_NEAR); - align(16); - - L(subloop96); - if (isTransA) { - do_pack(unroll_m); - } - - mov(CO1, C); - add(C, unroll_m * SIZE); - mov(BO1, B); - if (!isTransB) { - lea(BO2, ptr[B + LDB * 4]); - } - - if (!isTransA) { - lea(AA, ptr[A + (unroll_m + 16 - 1 - OFFSET) * SIZE]); - cmp(M, UNROLL_M); - jg(subloop98mask, T_NEAR); - mov(AA, ORIG_A); - lea(AA, ptr[AA + (16 - 1 
- OFFSET) * SIZE]); - L(subloop98mask); - } - - mov(LL, N); - mov(I, LL); - if (!isTransA) { - // If N is too small, skip copy operation - cmp(LL, UNROLL_N * 3); - jle(subloop30mask, T_NEAR); - - // If A is not aligned to cache line - cmp(FLAG, 0); - je(subloop30mask, T_NEAR); - } else { - cmp(LL, UNROLL_N); - jl(l_subloop_mask_20x[1], T_NEAR); - } - align(16); - - if (!isTransA) { - kernel(unroll_m, UNROLL_N, true, true, false); - } else { - kernel(unroll_m, UNROLL_N, false, false, false); - } - - sub(I, UNROLL_N); - cmp(I, UNROLL_N); - jl(l_subloop_mask_20x[1], T_NEAR); - align(16); - - L(subloop11mask); - kernel(unroll_m, UNROLL_N, false, false, false); - sub(I, UNROLL_N); - cmp(I, UNROLL_N); - jge(subloop11mask, T_NEAR); - align(16); - - for (int i = 1; i <= 7; i++) { - L(l_subloop_mask_20x[i]); - cmp(I, i); - if (i < 7) { - jne(l_subloop_mask_20x[i + 1], T_NEAR); - } else { - jne(subloop99, T_NEAR); - } - kernel(unroll_m, i, false, false, false); - jmp(subloop99, T_NEAR); - align(16); - } - - if (!isTransA) { - L(subloop30mask); - cmp(I, UNROLL_N); - jl(l_subloop_mask_30x[1], T_NEAR); - align(16); - - L(subloop31mask); - kernel(unroll_m, UNROLL_N, true, false, false); - sub(I, UNROLL_N); - cmp(I, UNROLL_N); - jge(subloop31mask, T_NEAR); - align(16); - - for (int i = 1; i <= 7; i++) { - L(l_subloop_mask_30x[i]); - cmp(I, i); - if (i < 7) { - jne(l_subloop_mask_30x[i + 1], T_NEAR); - } else { - jne(subloop99, T_NEAR); - } - kernel(unroll_m, i, true, false, false); - if (i < 7) - jmp(subloop99, T_NEAR); - align(16); - } - } - - L(subloop99); - // Compute address for A - if (!isTransA) { - add(A, unroll_m * SIZE); - } else { - mov(rax, LDA); - imul(rax, rax, unroll_m); - add(A, rax); - } - - // Compute next address of BIAS - if (hasBias) { - add(BIAS, unroll_m * SIZE); - } - }; - - preamble(); - - Label buffer_in_ws, buffer_allocated; - - // Get the registers - mov(B, ARG_B); - mov(LDB, ARG_LDB); - mov(r15, ARG_BETA); - mov(r12, ARG_C); - if (hasBias) - mov(r10, ARG_BIAS); - mov(LDC, ARG_LDC); - mov(rbp, rsp); - - vmovss(xmm0, ptr[ARG_ALPHA]); - vmovss(xmm1, ptr[r15]); - -#if _WIN32 - mov(A, ARG_A); - mov(LDA, ARG_LDA); -#endif - - cmp(K, STACK_K_CAPACITY); - jg(buffer_in_ws, T_NEAR); - - // Create buffer and align to 4kB page - lea(rax, ptr[K * SIZE]); - imul(rax, rax, 0x30); - add(rax, 256); - sub(rsp, rax); - and_(rsp, -PAGE_4K); - jmp(buffer_allocated, T_NEAR); - - L(buffer_in_ws); - mov(rsp, ARG_WS); - - L(buffer_allocated); - - mov(ORIG_SP, rbp); - mov(M, ARG_M); - mov(N, ARG_N); - mov(C, r12); - if (hasBias) - mov(BIAS, r10); - vmovss(ALPHA, xmm0); - vmovss(BETA, xmm1); - sub(A, -OFFSET * SIZE); - sub(B, -OFFSET * SIZE); - mov(ORIG_A, A); - sal(LDA, BASE_SHIFT); - sal(LDB, BASE_SHIFT); - sal(LDC, BASE_SHIFT); - lea(LDB3, ptr[LDB + LDB * 2]); - - if (isTransA) { - vpbroadcastq(zmm2, LDA); - vpxorq(ZSTRIDE, ZSTRIDE, ZSTRIDE); - mov(rax, -2); - kmovw(k4, eax); - - for (int i = 0; i < 6; i++) { - vpaddq(ZSTRIDE | k4, ZSTRIDE, zmm2); - kshiftlw(k4, k4, 1); - } - vpaddq(ZSTRIDE | k4, ZSTRIDE, zmm2); - } - - // Check A alignment and leading dimension; take copy-based path as - // needed - mov(rax, LDA); - or_(rax, A); - and_(rax, ver == ver_avx512_core ? 
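// FLAG records whether A and LDA are sufficiently aligned for this ISA
// (8 bytes on avx512_core, a full 64-byte cache line otherwise); a nonzero
// value makes subloop take the copy/packing kernels instead of reading A
// directly.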
0x07 : 0x3f); - mov(FLAG, rax); - - for (int i = 8; i < 16; i++) { - for (int j = 0; j < 3; j++) { - vpxorq(Zmm(i + 8 * j), Zmm(i + 8 * j), Zmm(i + 8 * j)); - } - } - - Label main0, main1, main2, main999; - - cmp(M, 32); - jle(main0, T_NEAR); - align(16); - - L(main1); - subloop(48); - sub(M, UNROLL_M); - cmp(M, 32); - jg(main1, T_NEAR); - align(16); - - L(main0); - cmp(M, 16); - jle(main2, T_NEAR); - - subloop(32); - jmp(main999, T_NEAR); - align(16); - - L(main2); - cmp(M, 0); - jle(main999, T_NEAR); - subloop(16); - align(16); - - L(main999); - // Restore original stack - mov(rsp, ORIG_SP); - - vzeroupper(); - postamble(); - - ker_ = this->getCode(); - } - - typedef void (*ker_t)(dim_t m, dim_t n, dim_t k, - const float *alpha, const float *a, dim_t lda, - const float *b, dim_t ldb, const float *beta, float *c, - dim_t ldc, const float *bias, float *ws); - - void operator()(dim_t m, dim_t n, dim_t k, - const float *alpha, const float *a, dim_t lda, - const float *b, dim_t ldb, const float *beta, float *c, - dim_t ldc, const float *bias, float *ws) const - { - ker_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, bias, ws); - } - -private: - ker_t ker_; -}; - -const xbyak_gemm *get_xbyak_gemm( - bool isTransA, bool isTransB, float beta, bool hasBias) { - auto beta_idx = [](float beta) { - return (beta == 0.0) ? 0 : (beta == 1.0 ? 1 : 2); - }; - - // Kernel table [isTransA][isTransB][hasBias][beta (0, 1, other)] - static xbyak_gemm *kernel_table[2][2][2][3]; - static std::once_flag initialized; - std::call_once(initialized, [=]{ - for (bool isTransA: {false, true}) - for (bool isTransB: {false, true}) - for (bool hasBias: {false, true}) - for (float beta: {0.0f, 1.0f, 2.0f}) { - // nocopy sgemm with bias for beta != 0.0 is not supported - if (hasBias && beta != 0.0) - continue; - kernel_table[isTransA][isTransB][hasBias][beta_idx(beta)] = - new xbyak_gemm(isTransA, isTransB, beta, hasBias); - } - }); - - return kernel_table[isTransA][isTransB][hasBias][beta_idx(beta)]; -} - -void sgemm_nocopy_driver(const char *transa, - const char *transb, int m, int n, int k, const float *alpha, - const float *a, dim_t lda, const float *b, dim_t ldb, const float *beta, - float *c, dim_t ldc, const float *bias, float *ws) -{ - bool isTransA = (*transa == 'T' || *transa == 't'); - bool isTransB = (*transb == 'T' || *transb == 't'); - - int Bm, sizeM, Bn, sizeN, Bk, sizeK; - - int i, j; - - if ((m <= 0) || (n <= 0)) - return; - - if ((k <= 0) || (alpha[0] == 0.)) { - - if (beta[0] == 0.) { - for (j = 0; j < n; j++) - for (i = 0; i < m; i++) - c[i + j * ldc] = 0.0; - } else if (beta[0] != 1.) { - for (j = 0; j < n; j++) - for (i = 0; i < m; i++) - c[i + j * ldc] *= beta[0]; - } - - return; - } - - assert(IMPLICATION(bias != nullptr, *beta == 0.0)); - - // XXX: this happens on every thread... - bool hasBias = (bias != nullptr); - auto ker_bn = get_xbyak_gemm(isTransA, isTransB, *beta, hasBias); - auto ker_b1 = get_xbyak_gemm(isTransA, isTransB, 1.0, false); - auto ker_b0 = get_xbyak_gemm(isTransA, isTransB, 0.0, false); - assert(ker_bn && ker_b1 && ker_b0); - - int BM = 4032, BN, BK; - if (mayiuse(avx512_core)) { - BN = isTransA ? 384 : 64; - BK = 384; - } else { - BN = isTransA ? 96 : 64; - BK = isTransB ? 
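// Cache blocking: BM/BN/BK cap the panel sizes per pass, and the
// "size >= 2 * block" checks below split awkward remainders roughly in half
// so the final tail panel never becomes degenerately small.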
96 : 192; - if (!isTransA && !isTransB) - BK = 128; - } - const float *curA, *curB, *curBias = nullptr; - float *curC; - - for (Bk = 0; Bk < k; Bk += sizeK) { - sizeK = k - Bk; - if (sizeK >= BK * 2) - sizeK = BK; - else { - if (sizeK > BK) - sizeK = (sizeK + 1) / 2; - } - - for (Bm = 0; Bm < m; Bm += sizeM) { - sizeM = m - Bm; - if (sizeM >= BM * 2) - sizeM = BM; - else { - if (sizeM > BM + BM / 2) - sizeM = (sizeM + 1) / 2; - } - - for (Bn = 0; Bn < n; Bn += sizeN) { - sizeN = n - Bn; - if (sizeN >= BN * 2) - sizeN = BN; - else { - if (sizeN > BN + BN / 2) - sizeN = (sizeN + 1) / 2; - } - - if (!isTransA) { - curA = a + Bm + Bk * lda; - } else { - curA = a + Bk + Bm * lda; - } - if (!isTransB) { - curB = b + Bk + Bn * ldb; - } else { - curB = b + Bn + Bk * ldb; - } - curC = c + Bm + (size_t)Bn * ldc; - if (bias != nullptr) { - if (Bk == 0) { - curBias = bias + Bm; - } else { - curBias = nullptr; - } - } - if (Bk == 0) { - if (*beta == 0.0 && bias == nullptr) - (*ker_b0)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK, - alpha, curA, lda, curB, ldb, beta, curC, ldc, - curBias, ws); - else - (*ker_bn)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK, - alpha, curA, lda, curB, ldb, beta, curC, ldc, - curBias, ws); - } else { - (*ker_b1)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK, - alpha, curA, lda, curB, ldb, beta, curC, ldc, - curBias, ws); - } - } - } - } -} - -} - -mkldnn_status_t jit_avx512_common_gemm_f32( - const char *transa, const char *transb, - const int *p_m, const int *p_n, const int *p_k, const float *p_alpha, - const float *A, const int *p_lda, const float *B, const int *p_ldb, - const float *p_beta, float *C, const int *p_ldc, const float *bias) -{ - using namespace mkldnn::impl::utils; - using namespace avx512_common_gemm_f32; - using namespace gemm_utils; - - if (*p_beta != 0 && bias) - return ref_gemm(transa, transb, p_m, p_n, p_k, - p_alpha, A, p_lda, B, p_lda, p_beta, C, p_ldc, bias); - - int nthr = (mkldnn_in_parallel()) ? 
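// Nested-parallelism guard: when already inside a parallel region this GEMM
// runs on a single thread, otherwise it may use every available thread.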
1 : mkldnn_get_max_threads(); - - int m = *p_m; - int n = *p_n; - int k = *p_k; - dim_t lda = *p_lda; - dim_t ldb = *p_ldb; - dim_t ldc = *p_ldc; - float beta = *p_beta; - int MB, NB, KB; - - int nthr_m, nthr_n, nthr_k, nthr_mn; - - // Determine threading partitioning - calc_nthr_nocopy_avx512_common( - m, n, k, nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB); - assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_k == 1)); - - // May not happen, but just in case - if (nthr < nthr_m * nthr_n * nthr_k) - nthr = nthr_m * nthr_n * nthr_k; - - nthr_mn = nthr_m * nthr_n; - - unsigned char * ompstatus_ = nullptr; - unsigned char volatile *ompstatus = nullptr; - - float *c_buffers = nullptr; - float *ws_buffers = nullptr; - - if (nthr_k > 1) { - ompstatus_ = (unsigned char *) malloc( - nthr * CACHE_LINE_SIZE, - CACHE_LINE_SIZE); - ompstatus = (unsigned char volatile *) ompstatus_; - assert(ompstatus); - - for (int i = 0; i < nthr; i++) - ompstatus[i * CACHE_LINE_SIZE] = 0; - - c_buffers = (float *)malloc(nthr_m * nthr_n * (nthr_k - 1) * MB * NB - * sizeof(float), PAGE_4K); - } - - const size_t ws_elems_per_thr = (size_t)k * 48 + 64; - const size_t ws_size_per_thr - = rnd_up(ws_elems_per_thr * sizeof(float), PAGE_4K); - if (k > STACK_K_CAPACITY) { - ws_buffers = (float *)malloc(nthr * ws_size_per_thr, PAGE_4K); - } - - parallel_nd(nthr, [&](const int ithr) { - int ithr_m, ithr_n, ithr_k, ithr_mn; - int m_from, m_to, myM; - int n_from, n_to, myN; - int k_from, k_to, myK; - int cbase, ibase; - const float *myA, *myB, *myBias = nullptr; - float *myC = C, myBeta; - float *ws = ws_buffers ? - ws_buffers + ithr * ws_size_per_thr / sizeof(float) : 0; - dim_t ld = ldc; - - int sum_later = (mkldnn_get_num_threads() < nthr_m * nthr_n * nthr_k); - - if (ithr < nthr_m * nthr_n * nthr_k) { - - ithr_mn = ithr % nthr_mn; - ithr_m = ithr_mn % nthr_m; - ithr_n = ithr_mn / nthr_m; - ithr_k = ithr / nthr_mn; - - /* swap ithr_k for performance improvement */ - if (ithr_k == 0) - ithr_k = nthr_k - 1; - else if (ithr_k == nthr_k - 1) - ithr_k = 0; - - m_from = MB * (ithr_m); - m_to = MB * (ithr_m + 1); - if (m_to > m) - m_to = m; - myM = m_to - m_from; - - n_from = NB * (ithr_n); - n_to = NB * (ithr_n + 1); - if (n_to > n) - n_to = n; - myN = n_to - n_from; - - k_from = KB * (ithr_k); - k_to = KB * (ithr_k + 1); - if (k_to > k) - k_to = k; - myK = k_to - k_from; - - cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1); - ibase = (ithr_m + nthr_m * ithr_n) * nthr_k; - - if ((myM > 0) && (myN > 0)) { - - if (*transa == 'N' || *transa == 'n') { - myA = &(A[m_from + k_from * lda]); - } else { - myA = &(A[k_from + m_from * lda]); - } - if (*transb == 'N' || *transb == 'n') { - myB = &(B[k_from + n_from * ldb]); - } else { - myB = &(B[n_from + k_from * ldb]); - } - if (ithr_k == 0) { - myC = &(C[m_from + n_from * ldc]); - myBeta = beta; - ld = ldc; - if (bias) - myBias = &(bias[m_from]); - } else { - myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1); - myBeta = 0.0; - ld = MB; - myBias = nullptr; - } - - sgemm_nocopy_driver(transa, transb, myM, myN, myK, p_alpha, myA, - lda, myB, ldb, &myBeta, myC, ld, myBias, ws); - - if (nthr_k > 1 && !sum_later) - ompstatus[(ibase + ithr_k) * CACHE_LINE_SIZE] = 1; - } - - if (nthr_k > 1 && !sum_later) { - - // sum matrices partitioned along K dimension - int n1, n2; - - partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2); - - if (ithr_k > 0) { - - myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1) - + (dim_t)n1 * MB; - /* need to wait until main thread finishes */ - while 
(ompstatus[ibase * CACHE_LINE_SIZE] != 1) { - }; - - /* my cache is hot */ - sum_two_matrices(myM, n2, myC, MB, - &C[m_from + (n_from + n1) * ldc], ldc); - } - - for (int ik = 1; ik < nthr_k; ++ik) { - if (ik != ithr_k) { - - myC = c_buffers + (dim_t)MB * NB * (cbase + ik - 1) - + (dim_t)n1 * MB; - - while (ompstatus[(ibase + ik) * CACHE_LINE_SIZE] != 1) { - }; - - sum_two_matrices(myM, n2, myC, MB, - &C[m_from + (n_from + n1) * ldc], ldc); - } - } - } - } - }); - - - // handle C summation later - if (nthr_k > 1 && ompstatus[0] == 0) { - - parallel_nd(nthr, [&](const int ithr) { - int ithr_m, ithr_n, ithr_k, ithr_mn; - int m_from, m_to, myM; - int n_from, n_to, myN; - int cbase; - float *myC = C; - - if (ithr < nthr_m * nthr_n * nthr_k) { - - ithr_mn = ithr % nthr_mn; - ithr_m = ithr_mn % nthr_m; - ithr_n = ithr_mn / nthr_m; - ithr_k = ithr / nthr_mn; - - /* swap ithr_k for performance improvement */ - if (ithr_k == 0) - ithr_k = nthr_k - 1; - else if (ithr_k == nthr_k - 1) - ithr_k = 0; - - m_from = MB * (ithr_m); - m_to = MB * (ithr_m + 1); - if (m_to > m) - m_to = m; - myM = m_to - m_from; - - n_from = NB * (ithr_n); - n_to = NB * (ithr_n + 1); - if (n_to > n) - n_to = n; - myN = n_to - n_from; - - cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1); - - if (nthr_k > 1) { - // sum matrices partitioned along K dimension - int n1, n2; - - partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2); - - if (ithr_k > 0) { - - myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1) - + (dim_t)n1 * MB; - - /* my cache is hot */ - sum_two_matrices(myM, n2, myC, MB, - &C[m_from + (n_from + n1) * ldc], ldc); - } - - for (int ik = 1; ik < nthr_k; ++ik) { - if (ik != ithr_k) { - - myC = c_buffers + (dim_t)MB * NB * (cbase + ik - 1) - + (dim_t)n1 * MB; - - sum_two_matrices(myM, n2, myC, MB, - &C[m_from + (n_from + n1) * ldc], ldc); - } - } - } - } - }); - } - - free(c_buffers); - free(ompstatus_); - free(ws_buffers); - - return mkldnn_success; -} - -} -} -} - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.hpp deleted file mode 100644 index d581b7fd7..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.hpp +++ /dev/null @@ -1,36 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#ifndef JIT_AVX512_COMMON_GEMM_F32_HPP -#define JIT_AVX512_COMMON_GEMM_F32_HPP - -#include "mkldnn_types.h" - -namespace mkldnn { -namespace impl { -namespace cpu { - -mkldnn_status_t jit_avx512_common_gemm_f32( - const char *transa, const char *transb, const int *M, - const int *N, const int *K, const float *alpha, const float *A, - const int *lda, const float *B, const int *ldb, const float *beta, - float *C, const int *ldc, const float *bias = nullptr); - -} -} -} - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.cpp deleted file mode 100644 index 60d422083..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.cpp +++ /dev/null @@ -1,2705 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include -#include - -#include "mkldnn_thread.hpp" -#include "utils.hpp" - -#include "ref_gemm_f32.hpp" -#include "gemm_utils_f32.hpp" -#include "jit_avx_gemm_f32.hpp" - -#include "jit_generator.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -#define CACHE_LINE_SIZE 64 - -#define STACKSIZE get_size_of_abi_save_regs() -#if _WIN32 -#define STACK_K_CAPACITY 128 -#else -#define STACK_K_CAPACITY 8192 -#endif -#define SIZE 4 -#define OFFSET 32 -#define BASE_SHIFT 2 -#define SECOND_FETCH 14 - -namespace avx_gemm_f32 { -using namespace gemm_utils; - -struct xbyak_gemm : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_gemm_f32_xbyak_gemm) - - xbyak_gemm(char isTransA, char isTransB, float beta, bool hasBias = false, - void *code_ptr = nullptr, - size_t code_size = 80 * Xbyak::DEFAULT_MAX_CODE_SIZE) - : jit_generator(code_ptr, code_size) - { - using namespace Xbyak; - - const bool is_avx2 = mayiuse(avx2); - assert(IMPLICATION(!is_avx2, mayiuse(avx))); - - const int UNROLL_M = is_avx2 ? 
16 : 8; - const int UNROLL_N = 6; - - bool isBeta0 = (beta == 0.0); - bool isBetaN = (!isBeta0 && beta != 1.0); - - // various definitions for convenience - auto ARG_M = abi_param1; - auto ARG_N = abi_param2; - auto K = abi_param3; - auto ARG_ALPHA = abi_param4; -#ifdef _WIN32 - auto ARG_A = ptr[rsp + OFFSET_SHADOWSPACE + STACKSIZE]; - auto ARG_LDA = qword[rsp + OFFSET_SHADOWSPACE + - sizeof(float *) + STACKSIZE]; - const auto stackOffset = OFFSET_SHADOWSPACE + - sizeof(float *) + STACKSIZE; - auto A = rsi; - auto LDA = rdi; -#else - auto ARG_A = r8; - auto ARG_LDA = r9; - const auto stackOffset = STACKSIZE; - auto A = ARG_A; - auto LDA = ARG_LDA; -#endif - auto ARG_B = ptr[rsp + 8 + stackOffset]; - auto ARG_LDB = ptr[rsp + 16 + stackOffset]; - auto ARG_BETA = ptr[rsp + 24 + stackOffset]; - auto ARG_C = ptr[rsp + 32 + stackOffset]; - auto ARG_LDC = ptr[rsp + 40 + stackOffset]; - auto ARG_BIAS = ptr[rsp + 48 + stackOffset]; - auto ARG_WS = ptr[rsp + 56 + stackOffset]; - - auto B = r11; - auto LDB = rbx; - auto LDC = r13; - auto LL = rax; - auto AO1 = abi_param2; - auto BO1 = abi_param4; - auto BO2 = rbp; - auto CO1 = r14; - auto CO2 = r15; - auto LDB3 = r10; - auto LDA4 = abi_param1; - auto AA = r12; - auto BIAS1 = abi_param1; - - auto M = qword[rsp + 0]; - auto N = qword[rsp + 8]; - auto FLAG = qword[rsp + 16]; - auto I = qword[rsp + 24]; - auto C = qword[rsp + 32]; - auto BIAS = qword[rsp + 40]; - auto ALPHA = qword[rsp + 48]; - auto BETA = qword[rsp + 64]; - auto ORIG_A = qword[rsp + 80]; - auto MASK = dword[rsp + 88]; - auto STRIDE = qword[rsp + 120]; - auto ORIG_SP = qword[rsp + 152]; - - auto VALPHA = ymm1; - auto VBETA = ymm2; - auto VMASK = ymm3; - auto VBIAS1 = ymm2; - auto VBIAS2 = ymm4; - - auto PREFETCHSIZEA = 128; - auto PREFETCHSIZEB = (!isTransB) ? -16 : 0; - - // Function for packing if needed - auto do_pack = [&]( - int unroll_m, bool isLoad1Unmasked, bool isLoad2Unmasked) { - Label pack2, pack3, pack4, pack10; - - int regIdx; - Reg64 reg; - - mov(BO1, A); - lea(AO1, ptr[rsp + 256 + OFFSET * SIZE]); - - if (isTransA) { - lea(BO2, ptr[BO1 + LDA * 4]); - lea(CO1, ptr[LDA + LDA * 2]); - vmovupd(ymm7, STRIDE); - } - - mov(LL, K); - sar(LL, 2); - jle(pack3, T_NEAR); - align(16); - - L(pack2); - if (!isTransA) { - for (int i = 0; i < 4; i++) { - regIdx = (i % 2 == 0) ? 4 : 6; - if (isLoad1Unmasked) { - vmovups(Ymm(regIdx), - ptr[BO1 + (0 * 8 - OFFSET) * SIZE]); - } else { - vmaskmovps(Ymm(regIdx), VMASK, - ptr[BO1 + (0 * 8 - OFFSET) * SIZE]); - } - if (unroll_m > 8) { - if (isLoad2Unmasked) { - vmovups(Ymm(regIdx + 1), - ptr[BO1 + (1 * 8 - OFFSET) * SIZE]); - } else { - vmaskmovps(Ymm(regIdx + 1), VMASK, - ptr[BO1 + (1 * 8 - OFFSET) * SIZE]); - } - } - add(BO1, LDA); - - vmovups(ptr[AO1 + (unroll_m * i + 0 * 8 - OFFSET) * SIZE], - Ymm(regIdx)); - if (unroll_m > 8) { - vmovups(ptr[AO1 - + (unroll_m * i + 1 * 8 - OFFSET) - * SIZE], - Ymm(regIdx + 1)); - } - } - - } else { - if (isLoad1Unmasked) { - for (int i = 0; i < 2; i++) { - reg = (i % 2 == 0) ? 
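// Transposed-A packing, AVX path: rows are read four at a time and
// interleaved with vunpcklps/vunpckhps + vunpcklpd so the packed panel lands
// in the same unroll_m-strided layout the compute kernels read back with
// plain vmovups.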
BO1 : BO2; - vmovups(xmm0, ptr[reg + (0 * 8 - OFFSET) * SIZE]); - vmovups(xmm1, - ptr[reg + LDA * 1 + (0 * 8 - OFFSET) * SIZE]); - lea(BO2, ptr[reg + LDA * 2]); - vunpcklps(xmm4, xmm0, xmm1); - vunpckhps(xmm5, xmm0, xmm1); - vmovups(xmm0, ptr[BO2 + (0 * 8 - OFFSET) * SIZE]); - vmovups(xmm1, - ptr[BO2 + LDA * 1 + (0 * 8 - OFFSET) * SIZE]); - lea(BO2, ptr[BO2 + LDA * 2]); - vunpcklps(xmm6, xmm0, xmm1); - vunpckhps(xmm2, xmm0, xmm1); - - vunpcklpd(xmm0, xmm4, xmm6); - vunpckhpd(xmm1, xmm4, xmm6); - vmovups(ptr[AO1 - + (unroll_m * 0 + i * 4 - OFFSET) - * SIZE], - xmm0); - vmovups(ptr[AO1 - + (unroll_m * 1 + i * 4 - OFFSET) - * SIZE], - xmm1); - vunpcklpd(xmm0, xmm5, xmm2); - vunpckhpd(xmm1, xmm5, xmm2); - vmovups(ptr[AO1 - + (unroll_m * 2 + i * 4 - OFFSET) - * SIZE], - xmm0); - vmovups(ptr[AO1 - + (unroll_m * 3 + i * 4 - OFFSET) - * SIZE], - xmm1); - } - } else if (is_avx2) { - for (int i = 0; i < 2; i++) { - vmovaps(xmm4, xmm3); - vgatherqps(xmm0, - ptr[BO1 + ymm7 + ((2 * i) - OFFSET) * SIZE], - xmm4); - vmovaps(xmm4, xmm3); - vgatherqps(xmm1, - ptr[BO1 + ymm7 + ((2 * i + 1) - OFFSET) * SIZE], - xmm4); - - vmovups(ptr[AO1 - + (unroll_m * (2 * i) + 0 * 4 - OFFSET) - * SIZE], - xmm0); - vmovups(ptr[AO1 - + (unroll_m * (2 * i + 1) + 0 * 4 - - OFFSET) - * SIZE], - xmm1); - } - - lea(BO2, ptr[BO1 + LDA * 4]); - - for (int i = 0; i < 2; i++) { - vextractf128(xmm4, ymm3, 1); - vgatherqps(xmm0, - ptr[BO2 + ymm7 + ((2 * i) - OFFSET) * SIZE], - xmm4); - vextractf128(xmm4, ymm3, 1); - vgatherqps(xmm1, - ptr[BO2 + ymm7 + ((2 * i + 1) - OFFSET) * SIZE], - xmm4); - - vmovups(ptr[AO1 - + (unroll_m * (2 * i) + 1 * 4 - OFFSET) - * SIZE], - xmm0); - vmovups(ptr[AO1 - + (unroll_m * (2 * i + 1) + 1 * 4 - - OFFSET) - * SIZE], - xmm1); - } - - lea(BO2, ptr[BO2 + LDA * 4]); - } else { - vxorps(xmm4, xmm4, xmm4); - lea(BO2, ptr[BO1 + LDA * 4]); - - auto el_cp = [&](int section, int ld_step) { - RegExp src_addr = section == 0 ? 
BO1 : BO2; - if (ld_step == 1 || ld_step == 2) - src_addr = src_addr + LDA * ld_step; - else if (ld_step == 3) - src_addr = src_addr + CO1; - src_addr = src_addr - OFFSET * SIZE; - - vmovups(Xmm(ld_step % 2), ptr[src_addr]); - RegExp dst_addr = AO1 - + (ld_step + section * 4 - OFFSET) * SIZE; - for (int off = 0; off < 4; ++off) - pextrd(ptr[dst_addr + unroll_m * off * SIZE], - Xmm(ld_step % 2), off); - }; - - Label l_end; - el_cp(0, 0); cmp(M, 4 * 0 + 0 + 1); je(l_end, T_NEAR); - el_cp(0, 1); cmp(M, 4 * 0 + 1 + 1); je(l_end, T_NEAR); - el_cp(0, 2); cmp(M, 4 * 0 + 2 + 1); je(l_end, T_NEAR); - el_cp(0, 3); cmp(M, 4 * 0 + 3 + 1); je(l_end, T_NEAR); - el_cp(1, 0); cmp(M, 4 * 1 + 0 + 1); je(l_end, T_NEAR); - el_cp(1, 1); cmp(M, 4 * 1 + 1 + 1); je(l_end, T_NEAR); - el_cp(1, 2); - L(l_end); - - lea(BO2, ptr[BO2 + LDA * 4]); - } - - if (unroll_m >= 16) { - assert(is_avx2); - if (isLoad2Unmasked) { - for (int i = 0; i < 2; i++) { - vmovups(xmm0, ptr[BO2 + (0 * 8 - OFFSET) * SIZE]); - vmovups(xmm1, ptr[BO2 + LDA * 1 - + (0 * 8 - OFFSET) * SIZE]); - lea(BO2, ptr[BO2 + LDA * 2]); - vunpcklps(xmm4, xmm0, xmm1); - vunpckhps(xmm5, xmm0, xmm1); - vmovups(xmm0, ptr[BO2 + (0 * 8 - OFFSET) * SIZE]); - vmovups(xmm1, ptr[BO2 + LDA * 1 - + (0 * 8 - OFFSET) * SIZE]); - if (i == 0) - lea(BO2, ptr[BO2 + LDA * 2]); - vunpcklps(xmm6, xmm0, xmm1); - vunpckhps(xmm2, xmm0, xmm1); - - vunpcklpd(xmm0, xmm4, xmm6); - vunpckhpd(xmm1, xmm4, xmm6); - vmovups(ptr[AO1 - + (unroll_m * 0 + (i + 2) * 4 - - OFFSET) - * SIZE], - xmm0); - vmovups(ptr[AO1 - + (unroll_m * 1 + (i + 2) * 4 - - OFFSET) - * SIZE], - xmm1); - vunpcklpd(xmm0, xmm5, xmm2); - vunpckhpd(xmm1, xmm5, xmm2); - vmovups(ptr[AO1 - + (unroll_m * 2 + (i + 2) * 4 - - OFFSET) - * SIZE], - xmm0); - vmovups(ptr[AO1 - + (unroll_m * 3 + (i + 2) * 4 - - OFFSET) - * SIZE], - xmm1); - } - } else { - for (int i = 0; i < 2; i++) { - vmovaps(xmm4, xmm3); - vgatherqps(xmm0, - ptr[BO2 + ymm7 + ((2 * i) - OFFSET) * SIZE], - xmm4); - vmovaps(xmm4, xmm3); - vgatherqps(xmm1, - ptr[BO2 + ymm7 - + ((2 * i + 1) - OFFSET) * SIZE], - xmm4); - - vmovups(ptr[AO1 - + (unroll_m * (2 * i) + 2 * 4 - - OFFSET) - * SIZE], - xmm0); - vmovups(ptr[AO1 - + (unroll_m * (2 * i + 1) + 2 * 4 - - OFFSET) - * SIZE], - xmm1); - } - - lea(BO2, ptr[BO2 + LDA * 4]); - - for (int i = 0; i < 2; i++) { - vextractf128(xmm4, ymm3, 1); - vgatherqps(xmm0, - ptr[BO2 + ymm7 + ((2 * i) - OFFSET) * SIZE], - xmm4); - vextractf128(xmm4, ymm3, 1); - vgatherqps(xmm1, - ptr[BO2 + ymm7 - + ((2 * i + 1) - OFFSET) * SIZE], - xmm4); - - vmovups(ptr[AO1 - + (unroll_m * (2 * i) + 3 * 4 - - OFFSET) - * SIZE], - xmm0); - vmovups(ptr[AO1 - + (unroll_m * (2 * i + 1) + 3 * 4 - - OFFSET) - * SIZE], - xmm1); - } - - lea(BO2, ptr[BO2 + LDA * 4]); - } - } - add(BO1, (4 * SIZE)); - } - - add(AO1, unroll_m * 4 * SIZE); - sub(LL, 1); - jg(pack2, T_NEAR); - align(16); - - L(pack3); - mov(LL, K); - and_(LL, 3); - jle(pack10, T_NEAR); - align(16); - - L(pack4); - if (!isTransA) { - if (isLoad1Unmasked) { - vmovups(ymm4, ptr[BO1 + (0 * 8 - OFFSET) * SIZE]); - } else { - vmaskmovps(ymm4, VMASK, ptr[BO1 + (0 * 8 - OFFSET) * SIZE]); - } - if (unroll_m > 8) { - if (isLoad2Unmasked) { - vmovups(ymm5, ptr[BO1 + (1 * 8 - OFFSET) * SIZE]); - } else { - vmaskmovps(ymm5, VMASK, - ptr[BO1 + (1 + 8 - OFFSET) * SIZE]); - } - } - add(BO1, LDA); - vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE], - ymm4); - if (unroll_m > 8) { - vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 8 - OFFSET) * SIZE], - ymm5); - } - } else { - if (isLoad1Unmasked) { - for (int i = 0; i 
< 2; i++) { - reg = (i % 2 == 0) ? BO1 : BO2; - vmovss(Xmm(i + 1), ptr[reg + (0 * 8 - OFFSET) * SIZE]); - vmovss(xmm0, - ptr[reg + LDA * 1 + (0 * 8 - OFFSET) * SIZE]); - lea(BO2, ptr[reg + LDA * 2]); - vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0)); - } - vunpcklpd(xmm1, xmm1, xmm2); - vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 4 - OFFSET) * SIZE], - xmm1); - - for (int i = 0; i < 2; i++) { - vmovss(Xmm(i + 1), ptr[BO2 + (0 * 8 - OFFSET) * SIZE]); - vmovss(xmm0, - ptr[BO2 + LDA * 1 + (0 * 8 - OFFSET) * SIZE]); - lea(BO2, ptr[BO2 + LDA * 2]); - vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0)); - } - vunpcklpd(xmm1, xmm1, xmm2); - vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 4 - OFFSET) * SIZE], - xmm1); - } else if (is_avx2) { - vmovaps(xmm4, xmm3); - vgatherqps(xmm1, ptr[BO1 + ymm7 + (0 * 8 - OFFSET) * SIZE], - xmm4); - lea(BO2, ptr[BO1 + LDA * 4]); - vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 4 - OFFSET) * SIZE], - xmm1); - - vextractf128(xmm4, ymm3, 1); - vgatherqps(xmm1, ptr[BO2 + ymm7 + (0 * 8 - OFFSET) * SIZE], - xmm4); - lea(BO2, ptr[BO2 + LDA * 4]); - vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 4 - OFFSET) * SIZE], - xmm1); - } else { - vxorps(xmm4, xmm4, xmm4); - lea(BO2, ptr[BO1 + LDA * 4]); - - auto el_cp = [&](int section, int ld_step) { - RegExp src_addr = section == 0 ? BO1 : BO2; - if (ld_step == 1 || ld_step == 2) - src_addr = src_addr + LDA * ld_step; - else if (ld_step == 3) - src_addr = src_addr + CO1; - src_addr = src_addr - OFFSET * SIZE; - - vmovss(xmm1, ptr[src_addr]); - RegExp dst_addr = AO1 - + (ld_step + section * 4 - OFFSET) * SIZE; - movss(ptr[dst_addr], xmm1); - }; - - Label l_end; - el_cp(0, 0); cmp(M, 4 * 0 + 0 + 1); je(l_end, T_NEAR); - el_cp(0, 1); cmp(M, 4 * 0 + 1 + 1); je(l_end, T_NEAR); - el_cp(0, 2); cmp(M, 4 * 0 + 2 + 1); je(l_end, T_NEAR); - el_cp(0, 3); cmp(M, 4 * 0 + 3 + 1); je(l_end, T_NEAR); - el_cp(1, 0); cmp(M, 4 * 1 + 0 + 1); je(l_end, T_NEAR); - el_cp(1, 1); cmp(M, 4 * 1 + 1 + 1); je(l_end, T_NEAR); - el_cp(1, 2); - L(l_end); - - lea(BO2, ptr[BO2 + LDA * 4]); - } - - if (unroll_m >= 16) { - assert(is_avx2); - if (isLoad2Unmasked) { - for (int i = 0; i < 2; i++) { - vmovss(Xmm(i + 1), - ptr[BO2 + (0 * 8 - OFFSET) * SIZE]); - vmovss(xmm0, ptr[BO2 + LDA * 1 - + (0 * 8 - OFFSET) * SIZE]); - lea(BO2, ptr[BO2 + LDA * 2]); - vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0)); - } - vunpcklpd(xmm1, xmm1, xmm2); - } else { - vmovaps(xmm4, xmm3); - vgatherqps(xmm1, - ptr[BO2 + ymm7 + (0 * 8 - OFFSET) * SIZE], - xmm4); - lea(BO2, ptr[BO2 + LDA * 4]); - } - vmovups(ptr[AO1 + (unroll_m * 0 + 2 * 4 - OFFSET) * SIZE], - xmm1); - - if (isLoad2Unmasked) { - for (int i = 0; i < 2; i++) { - vmovss(Xmm(i + 1), - ptr[BO2 + (0 * 8 - OFFSET) * SIZE]); - vmovss(xmm0, ptr[BO2 + LDA * 1 - + (0 * 8 - OFFSET) * SIZE]); - lea(BO2, ptr[BO2 + LDA * 2]); - vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0)); - } - vunpcklpd(xmm1, xmm1, xmm2); - } else { - vextractf128(xmm4, ymm3, 1); - vgatherqps(xmm1, - ptr[BO2 + ymm7 + (0 * 8 - OFFSET) * SIZE], - xmm4); - } - vmovups(ptr[AO1 + (unroll_m * 0 + 3 * 4 - OFFSET) * SIZE], - xmm1); - } - add(BO1, SIZE); - } - - add(AO1, unroll_m * SIZE); - sub(LL, 1); - jg(pack4, T_NEAR); - align(16); - - L(pack10); - }; - - // Fused multiply add; may become one or two instructions - auto fma = [&](bool useFma, Ymm reg0, Ymm reg1, Ymm reg2, - bool overWrite = false) { - if (useFma) { - if (is_avx2) { - vfmadd231ps(reg2, reg1, reg0); - } else { - assert(UNROLL_M == 8); - auto tent_vreg = overWrite ? 
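// fma() fallback for plain AVX: with no fused multiply-add available, the
// helper emits vmulps into a scratch register followed by vaddps into the
// accumulator (the "one or two instructions" case noted above).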
reg1 : ymm1; - vmulps(tent_vreg, reg1, reg0); - vaddps(reg2, reg2, tent_vreg); - } - } else { - if (!overWrite) { - vmulps(ymm15, reg1, reg0); - vaddps(reg2, reg2, ymm15); - } else { - vmulps(reg1, reg1, reg0); - vaddps(reg2, reg2, reg1); - } - } - }; - - // Inner kernel with k=8 - auto innerkernel8 = [&](int unroll_m, int unroll_n, - bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect, - bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02, - Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07, - Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11, Ymm reg12, - Ymm reg13, Ymm reg14, Ymm reg15, Ymm reg16, Ymm reg17, - Ymm reg18, Ymm reg19, Ymm reg20, Ymm reg21, Ymm reg22, - Ymm reg23) { - - Ymm fmareg; - - if (!isDirect) { - prefetcht0(ptr[AO1 + (PREFETCHSIZEA + 0) * SIZE]); - } else { - prefetcht0(ptr[AO1 + LDA4]); - } - - for (int i = 0; i < 8; i++) { - if (isDirect) { - if (isLoad1Unmasked) { - vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]); - } else { - vmaskmovps(ymm0, VMASK, - ptr[AO1 + (0 * 8 - OFFSET) * SIZE]); - } - if (unroll_m >= 16) { - if (isLoad2Unmasked) { - vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]); - } else { - vmaskmovps(ymm1, VMASK, - ptr[AO1 + (1 * 8 - OFFSET) * SIZE]); - } - } - add(AO1, LDA); - } - - if (!isTransB) { - vbroadcastss(ymm2, ptr[BO1 + (i - OFFSET) * SIZE]); - } else { - vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]); - } - fmareg = (i % 2 == 0) ? reg00 : reg12; - fma(useFma, ymm0, ymm2, fmareg); - if (unroll_m >= 16) { - fmareg = (i % 2 == 0) ? reg06 : reg18; - fma(useFma, ymm1, ymm2, fmareg); - } - if (i == 0) { - if (!isTransB) { - prefetcht0(ptr[BO1 + PREFETCHSIZEB * SIZE]); - } - } - if (unroll_n >= 2) { - if (!isTransB) { - if (i == 1) { - prefetcht0(ptr[BO1 + LDB + PREFETCHSIZEB * SIZE]); - } - vbroadcastss( - ymm2, ptr[BO1 + LDB * 1 + (i - OFFSET) * SIZE]); - } else { - vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]); - } - fmareg = (i % 2 == 0) ? reg01 : reg13; - fma(useFma, ymm0, ymm2, fmareg); - if (unroll_m >= 16) { - fmareg = (i % 2 == 0) ? reg07 : reg19; - fma(useFma, ymm1, ymm2, fmareg); - } - } - - if (isCopy) { - vmovups(ptr[LDA4 + (unroll_m * i + 0 * 8 - OFFSET) * SIZE], - ymm0); - if (unroll_m >= 16) { - vmovups(ptr[LDA4 - + (unroll_m * i + 1 * 8 - OFFSET) - * SIZE], - ymm1); - } - if (i == 7) { - sub(LDA4, -unroll_m * 8 * SIZE); - } - } - - if (unroll_n >= 3) { - if (!isTransB) { - if (i == 2) { - prefetcht0( - ptr[BO1 + LDB * 2 + PREFETCHSIZEB * SIZE]); - } - vbroadcastss( - ymm2, ptr[BO1 + LDB * 2 + (i - OFFSET) * SIZE]); - } else { - vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]); - } - fmareg = (i % 2 == 0) ? reg02 : reg14; - fma(useFma, ymm0, ymm2, fmareg); - if (unroll_m >= 16) { - fmareg = (i % 2 == 0) ? reg08 : reg20; - fma(useFma, ymm1, ymm2, fmareg); - } - } - - if (i == 7) { - if (!isTransB) { - sub(BO1, -8 * SIZE); - } - } - - if (unroll_n >= 4) { - if (!isTransB) { - if (i == 3) { - prefetcht0(ptr[BO2 + PREFETCHSIZEB * SIZE]); - } - vbroadcastss(ymm2, ptr[BO2 + (i - OFFSET) * SIZE]); - } else { - vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]); - } - fmareg = (i % 2 == 0) ? reg03 : reg15; - fma(useFma, ymm0, ymm2, fmareg); - if (unroll_m >= 16) { - fmareg = (i % 2 == 0) ? 
reg09 : reg21; - fma(useFma, ymm1, ymm2, fmareg); - } - } - - if (unroll_n >= 5) { - if (!isTransB) { - if (i == 4) { - prefetcht0(ptr[BO2 + LDB + PREFETCHSIZEB * SIZE]); - } - vbroadcastss( - ymm2, ptr[BO2 + LDB * 1 + (i - OFFSET) * SIZE]); - } else { - vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]); - } - fmareg = (i % 2 == 0) ? reg04 : reg16; - fma(useFma, ymm0, ymm2, fmareg); - if (unroll_m >= 16) { - fmareg = (i % 2 == 0) ? reg10 : reg22; - fma(useFma, ymm1, ymm2, fmareg); - } - } - - if (unroll_n >= 6) { - if (!isTransB) { - if (i == 5) { - prefetcht0( - ptr[BO2 + LDB * 2 + PREFETCHSIZEB * SIZE]); - } - vbroadcastss( - ymm2, ptr[BO2 + LDB * 2 + (i - OFFSET) * SIZE]); - } else { - vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]); - } - fmareg = (i % 2 == 0) ? reg05 : reg17; - fma(useFma, ymm0, ymm2, fmareg); - if (unroll_m >= 16) { - fmareg = (i % 2 == 0) ? reg11 : reg23; - fma(useFma, ymm1, ymm2, fmareg); - } - } - if (isTransB) { - prefetcht0(ptr[BO1 + BO2]); - add(BO1, LDB); - } - - if (i == 0) { - if (unroll_m >= 4) { - if (!isDirect) { - prefetcht0( - ptr[AO1 + (PREFETCHSIZEA + 2 * 8) * SIZE]); - } else { - prefetcht0(ptr[AO1 + LDA4]); - } - } - } - if (i == 1 || i == 2) { - if (unroll_m >= 8) { - if (!isDirect) { - prefetcht0(ptr[AO1 - + (PREFETCHSIZEA + (2 + 2 * i) * 8) - * SIZE]); - } else { - prefetcht0(ptr[AO1 + LDA4]); - } - } - } - if (i == 3 || i == 4 || i == 5 || i == 6) { - if (unroll_m >= 16) { - if (!isDirect) { - prefetcht0(ptr[AO1 - + (PREFETCHSIZEA + (2 + 2 * i) * 8) - * SIZE]); - } else { - prefetcht0(ptr[AO1 + LDA4]); - } - } - } - if (i == 7) { - if (!isTransB) { - if (unroll_n >= 4) { - sub(BO2, -8 * SIZE); - } - } - if (!isTransA) { - prefetcht2(ptr[AA]); - lea(AA, ptr[AA + LDA]); - } - } - - if (!isDirect) { - if (isLoad1Unmasked) { - vmovups(ymm0, - ptr[AO1 - + (unroll_m * (i + 1) + 0 * 8 - OFFSET) - * SIZE]); - } else { - vmaskmovps( - ymm0, VMASK, - ptr[AO1 - + (unroll_m * (i + 1) + 0 * 8 - OFFSET) - * SIZE]); - } - if (unroll_m >= 16) { - if (isLoad2Unmasked) { - vmovups(ymm1, ptr[AO1 - + (unroll_m * (i + 1) + 1 * 8 - - OFFSET) - * SIZE]); - } else { - vmaskmovps(ymm1, VMASK, - ptr[AO1 - + (unroll_m * (i + 1) + 1 * 8 - - OFFSET) - * SIZE]); - } - } - } - } - - if (!isDirect) { - sub(AO1, -unroll_m * 8 * SIZE); - } - sub(LL, 1); - - }; - - // Inner kernel with k=4 - auto innerkernel4 = [&](int unroll_m, int unroll_n, - bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect, - bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02, - Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07, - Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11, Ymm reg12, - Ymm reg13, Ymm reg14, Ymm reg15, Ymm reg16, Ymm reg17, - Ymm reg18, Ymm reg19, Ymm reg20, Ymm reg21, Ymm reg22, - Ymm reg23) { - - Ymm fmareg; - - if (!isDirect) { - prefetcht0(ptr[AO1 + (PREFETCHSIZEA + 0) * SIZE]); - } else { - prefetcht0(ptr[AO1 + LDA4]); - } - - for (int i = 0; i < 4; i++) { - if (isDirect) { - if (isLoad1Unmasked) { - vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]); - } else { - vmaskmovps(ymm0, VMASK, - ptr[AO1 + (0 * 8 - OFFSET) * SIZE]); - } - if (unroll_m >= 16) { - if (isLoad2Unmasked) { - vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]); - } else { - vmaskmovps(ymm1, VMASK, - ptr[AO1 + (1 * 8 - OFFSET) * SIZE]); - } - } - add(AO1, LDA); - } - - if (!isTransB) { - vbroadcastss(ymm2, ptr[BO1 + (i - OFFSET) * SIZE]); - } else { - vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]); - } - fmareg = (i % 2 == 0) ? 
reg00 : reg12; - fma(useFma, ymm0, ymm2, fmareg); - if (unroll_m >= 16) { - fmareg = (i % 2 == 0) ? reg06 : reg18; - fma(useFma, ymm1, ymm2, fmareg); - } - if (i == 0) { - if (!isTransB) { - prefetcht0(ptr[BO1 + PREFETCHSIZEB * SIZE]); - } - } - if (unroll_n >= 2) { - if (!isTransB) { - if (i == 1) { - prefetcht0(ptr[BO1 + LDB + PREFETCHSIZEB * SIZE]); - } - vbroadcastss( - ymm2, ptr[BO1 + LDB * 1 + (i - OFFSET) * SIZE]); - } else { - vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]); - } - fmareg = (i % 2 == 0) ? reg01 : reg13; - fma(useFma, ymm0, ymm2, fmareg); - if (unroll_m >= 16) { - fmareg = (i % 2 == 0) ? reg07 : reg19; - fma(useFma, ymm1, ymm2, fmareg); - } - } - - if (isCopy) { - vmovups(ptr[LDA4 + (unroll_m * i + 0 * 8 - OFFSET) * SIZE], - ymm0); - if (unroll_m >= 16) { - vmovups(ptr[LDA4 - + (unroll_m * i + 1 * 8 - OFFSET) - * SIZE], - ymm1); - } - if (i == 3) { - sub(LDA4, -unroll_m * 4 * SIZE); - } - } - - if (unroll_n >= 3) { - if (!isTransB) { - if (i == 2) { - prefetcht0( - ptr[BO1 + LDB * 2 + PREFETCHSIZEB * SIZE]); - } - vbroadcastss( - ymm2, ptr[BO1 + LDB * 2 + (i - OFFSET) * SIZE]); - } else { - vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]); - } - fmareg = (i % 2 == 0) ? reg02 : reg14; - fma(useFma, ymm0, ymm2, fmareg); - if (unroll_m >= 16) { - fmareg = (i % 2 == 0) ? reg08 : reg20; - fma(useFma, ymm1, ymm2, fmareg); - } - } - - if (i == 7) { - if (!isTransB) { - sub(BO1, -8 * SIZE); - } - } - - if (unroll_n >= 4) { - if (!isTransB) { - if (i == 3) { - prefetcht0(ptr[BO2 + PREFETCHSIZEB * SIZE]); - } - vbroadcastss(ymm2, ptr[BO2 + (i - OFFSET) * SIZE]); - } else { - vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]); - } - fmareg = (i % 2 == 0) ? reg03 : reg15; - fma(useFma, ymm0, ymm2, fmareg); - if (unroll_m >= 16) { - fmareg = (i % 2 == 0) ? reg09 : reg21; - fma(useFma, ymm1, ymm2, fmareg); - } - } - - if (unroll_n >= 5) { - if (!isTransB) { - if (i == 4) { - prefetcht0(ptr[BO2 + LDB + PREFETCHSIZEB * SIZE]); - } - vbroadcastss( - ymm2, ptr[BO2 + LDB * 1 + (i - OFFSET) * SIZE]); - } else { - vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]); - } - fmareg = (i % 2 == 0) ? reg04 : reg16; - fma(useFma, ymm0, ymm2, fmareg); - if (unroll_m >= 16) { - fmareg = (i % 2 == 0) ? reg10 : reg22; - fma(useFma, ymm1, ymm2, fmareg); - } - } - - if (unroll_n >= 6) { - if (!isTransB) { - if (i == 5) { - prefetcht0( - ptr[BO2 + LDB * 2 + PREFETCHSIZEB * SIZE]); - } - vbroadcastss( - ymm2, ptr[BO2 + LDB * 2 + (i - OFFSET) * SIZE]); - } else { - vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]); - } - fmareg = (i % 2 == 0) ? reg05 : reg17; - fma(useFma, ymm0, ymm2, fmareg); - if (unroll_m >= 16) { - fmareg = (i % 2 == 0) ? 
reg11 : reg23; - fma(useFma, ymm1, ymm2, fmareg); - } - } - if (isTransB) { - prefetcht0(ptr[BO1 + BO2]); - add(BO1, LDB); - } - - if (i == 0) { - if (unroll_m >= 4) { - if (!isDirect) { - prefetcht0( - ptr[AO1 + (PREFETCHSIZEA + 2 * 8) * SIZE]); - } else { - prefetcht0(ptr[AO1 + LDA4]); - } - } - } - if (i == 1 || i == 2) { - if (unroll_m >= 8) { - if (!isDirect) { - prefetcht0(ptr[AO1 - + (PREFETCHSIZEA + (2 + 2 * i) * 8) - * SIZE]); - } else { - prefetcht0(ptr[AO1 + LDA4]); - } - } - } - if (i == 3) { - if (!isTransB) { - sub(BO1, -4 * SIZE); - if (unroll_n >= 4) { - sub(BO2, -4 * SIZE); - } - } - } - - if (!isDirect) { - if (isLoad1Unmasked) { - vmovups(ymm0, - ptr[AO1 - + (unroll_m * (i + 1) + 0 * 8 - OFFSET) - * SIZE]); - } else { - vmaskmovps( - ymm0, VMASK, - ptr[AO1 - + (unroll_m * (i + 1) + 0 * 8 - OFFSET) - * SIZE]); - } - if (unroll_m >= 16) { - if (isLoad2Unmasked) { - vmovups(ymm1, ptr[AO1 - + (unroll_m * (i + 1) + 1 * 8 - - OFFSET) - * SIZE]); - } else { - vmaskmovps(ymm1, VMASK, - ptr[AO1 - + (unroll_m * (i + 1) + 1 * 8 - - OFFSET) - * SIZE]); - } - } - } - } - - if (!isDirect) { - sub(AO1, -unroll_m * 4 * SIZE); - } - - }; - - // Inner kernel with k=2 - auto innerkernel2 = [&](int unroll_m, int unroll_n, - bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect, - bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02, - Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07, - Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11, Ymm reg12, - Ymm reg13, Ymm reg14, Ymm reg15, Ymm reg16, Ymm reg17, - Ymm reg18, Ymm reg19, Ymm reg20, Ymm reg21, Ymm reg22, - Ymm reg23) { - - Ymm fmareg; - - for (int i = 0; i < 2; i++) { - if (isDirect) { - if (isLoad1Unmasked) { - vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]); - } else { - vmaskmovps(ymm0, VMASK, - ptr[AO1 + (0 * 8 - OFFSET) * SIZE]); - } - if (unroll_m >= 16) { - if (isLoad2Unmasked) { - vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]); - } else { - vmaskmovps(ymm1, VMASK, - ptr[AO1 + (1 * 8 - OFFSET) * SIZE]); - } - } - add(AO1, LDA); - } - - if (!isTransB) { - vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]); - } else { - vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]); - } - fmareg = (i % 2 == 0) ? reg00 : reg12; - fma(useFma, ymm0, ymm2, fmareg); - if (unroll_m >= 16) { - fmareg = (i % 2 == 0) ? reg06 : reg18; - fma(useFma, ymm1, ymm2, fmareg); - } - if (unroll_n >= 2) { - if (!isTransB) { - vbroadcastss( - ymm2, ptr[BO1 + LDB * 1 + (0 - OFFSET) * SIZE]); - } else { - vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]); - } - fmareg = (i % 2 == 0) ? reg01 : reg13; - fma(useFma, ymm0, ymm2, fmareg); - if (unroll_m >= 16) { - fmareg = (i % 2 == 0) ? reg07 : reg19; - fma(useFma, ymm1, ymm2, fmareg); - } - } - - if (unroll_n >= 3) { - if (!isTransB) { - if (i == 2) { - prefetcht0( - ptr[BO1 + LDB * 2 + PREFETCHSIZEB * SIZE]); - } - vbroadcastss( - ymm2, ptr[BO1 + LDB * 2 + (0 - OFFSET) * SIZE]); - } else { - vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]); - } - fmareg = (i % 2 == 0) ? reg02 : reg14; - fma(useFma, ymm0, ymm2, fmareg); - if (unroll_m >= 16) { - fmareg = (i % 2 == 0) ? reg08 : reg20; - fma(useFma, ymm1, ymm2, fmareg); - } - } - - if (unroll_n >= 4) { - if (!isTransB) { - vbroadcastss(ymm2, ptr[BO2 + (0 - OFFSET) * SIZE]); - } else { - vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]); - } - fmareg = (i % 2 == 0) ? reg03 : reg15; - fma(useFma, ymm0, ymm2, fmareg); - if (unroll_m >= 16) { - fmareg = (i % 2 == 0) ? 
reg09 : reg21; - fma(useFma, ymm1, ymm2, fmareg); - } - } - - if (unroll_n >= 5) { - if (!isTransB) { - vbroadcastss( - ymm2, ptr[BO2 + LDB * 1 + (0 - OFFSET) * SIZE]); - } else { - vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]); - } - fmareg = (i % 2 == 0) ? reg04 : reg16; - fma(useFma, ymm0, ymm2, fmareg); - if (unroll_m >= 16) { - fmareg = (i % 2 == 0) ? reg10 : reg22; - fma(useFma, ymm1, ymm2, fmareg); - } - } - - if (unroll_n >= 6) { - if (!isTransB) { - vbroadcastss( - ymm2, ptr[BO2 + LDB * 2 + (0 - OFFSET) * SIZE]); - } else { - vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]); - } - fmareg = (i % 2 == 0) ? reg05 : reg17; - fma(useFma, ymm0, ymm2, fmareg); - if (unroll_m >= 16) { - fmareg = (i % 2 == 0) ? reg11 : reg23; - fma(useFma, ymm1, ymm2, fmareg); - } - } - - if (isCopy) { - vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE], - ymm0); - if (unroll_m >= 16) { - vmovups(ptr[LDA4 - + (unroll_m * 0 + 1 * 8 - OFFSET) - * SIZE], - ymm1); - } - sub(LDA4, -unroll_m * SIZE); - } - - if (!isDirect) { - if (isLoad1Unmasked) { - vmovups(ymm0, ptr[AO1 - + (unroll_m * 1 + 0 * 8 - OFFSET) - * SIZE]); - } else { - vmaskmovps(ymm0, VMASK, - ptr[AO1 - + (unroll_m * 1 + 0 * 8 - OFFSET) - * SIZE]); - } - if (unroll_m >= 16) { - if (isLoad2Unmasked) { - vmovups(ymm1, - ptr[AO1 - + (unroll_m * 1 + 1 * 8 - OFFSET) - * SIZE]); - } else { - vmaskmovps(ymm1, VMASK, - ptr[AO1 - + (unroll_m * 1 + 1 * 8 - OFFSET) - * SIZE]); - } - } - sub(AO1, -unroll_m * SIZE); - } - - if (!isTransB) { - sub(BO1, -SIZE); - if (unroll_n >= 4) { - sub(BO2, -SIZE); - } - } else { - add(BO1, LDB); - } - } - - }; - - // Inner kernel with k=1 - auto innerkernel1 = [&](int unroll_m, int unroll_n, - bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect, - bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02, - Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07, - Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11) { - - if (isDirect) { - if (isLoad1Unmasked) { - vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]); - } else { - vmaskmovps(ymm0, VMASK, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]); - } - if (unroll_m >= 16) { - if (isLoad2Unmasked) { - vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]); - } else { - vmaskmovps(ymm1, VMASK, - ptr[AO1 + (1 * 8 - OFFSET) * SIZE]); - } - } - add(AO1, LDA); - } - - if (!isTransB) { - vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]); - } else { - vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]); - } - fma(useFma, ymm0, ymm2, reg00); - if (unroll_m >= 16) { - fma(useFma, ymm1, ymm2, reg06); - } - - if (unroll_n >= 2) { - if (!isTransB) { - vbroadcastss( - ymm2, ptr[BO1 + LDB * 1 + (0 - OFFSET) * SIZE]); - } else { - vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]); - } - fma(useFma, ymm0, ymm2, reg01); - if (unroll_m >= 16) { - fma(useFma, ymm1, ymm2, reg07); - } - } - - if (unroll_n >= 3) { - if (!isTransB) { - vbroadcastss( - ymm2, ptr[BO1 + LDB * 2 + (0 - OFFSET) * SIZE]); - } else { - vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]); - } - fma(useFma, ymm0, ymm2, reg02); - if (unroll_m >= 16) { - fma(useFma, ymm1, ymm2, reg08); - } - } - - if (unroll_n >= 4) { - if (!isTransB) { - vbroadcastss(ymm2, ptr[BO2 + (0 - OFFSET) * SIZE]); - } else { - vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]); - } - fma(useFma, ymm0, ymm2, reg03); - if (unroll_m >= 16) { - fma(useFma, ymm1, ymm2, reg09); - } - } - - if (unroll_n >= 5) { - if (!isTransB) { - vbroadcastss( - ymm2, ptr[BO2 + LDB * 1 + (0 - OFFSET) * SIZE]); - } else { - vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]); 
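For reference, the accumulation that innerkernel{8,4,2,1} emit boils down to a broadcast/FMA rank-1 update per k step. Below is a minimal intrinsics sketch of that pattern; it is illustrative only (the real code is generated at runtime with Xbyak, and the function and parameter names here are invented):

    #include <immintrin.h>

    // One 8-row x n-column panel of C accumulated over k:
    //   acc[:,j] += A[:,p] * B[p,j] for every p, using vbroadcastss + FMA
    // (or vmulps + vaddps on plain AVX, as in the fma() helper above).
    void panel_accumulate_8xn(const float *a, long lda,  // 8 x k slice of A, column-major
                              const float *b, long ldb,  // k x n slice of B, column-major
                              float *acc, int k, int n) {
        for (int j = 0; j < n; ++j) {
            __m256 c = _mm256_loadu_ps(acc + 8 * j);
            for (int p = 0; p < k; ++p) {
                __m256 av = _mm256_loadu_ps(a + p * lda);          // 8 rows of A column p
                __m256 bv = _mm256_broadcast_ss(b + p + j * ldb);  // vbroadcastss
    #ifdef __FMA__
                c = _mm256_fmadd_ps(av, bv, c);                    // vfmadd231ps path
    #else
                c = _mm256_add_ps(c, _mm256_mul_ps(av, bv));       // AVX-only fallback
    #endif
            }
            _mm256_storeu_ps(acc + 8 * j, c);
        }
    }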
- } - fma(useFma, ymm0, ymm2, reg04); - if (unroll_m >= 16) { - fma(useFma, ymm1, ymm2, reg10); - } - } - - if (unroll_n >= 6) { - if (!isTransB) { - vbroadcastss( - ymm2, ptr[BO2 + LDB * 2 + (0 - OFFSET) * SIZE]); - } else { - vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]); - } - fma(useFma, ymm0, ymm2, reg05); - if (unroll_m >= 16) { - fma(useFma, ymm1, ymm2, reg11); - } - } - - if (isCopy) { - vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE], - ymm0); - if (unroll_m >= 16) { - vmovups(ptr[LDA4 + (unroll_m * 0 + 1 * 8 - OFFSET) * SIZE], - ymm1); - } - sub(LDA4, -unroll_m * SIZE); - } - - if (!isDirect) { - if (isLoad1Unmasked) { - vmovups(ymm0, - ptr[AO1 + (unroll_m * 1 + 0 * 8 - OFFSET) * SIZE]); - } else { - vmaskmovps(ymm0, VMASK, - ptr[AO1 + (unroll_m * 1 + 0 * 8 - OFFSET) * SIZE]); - } - if (unroll_m >= 16) { - if (isLoad2Unmasked) { - vmovups(ymm1, ptr[AO1 - + (unroll_m * 1 + 1 * 8 - OFFSET) - * SIZE]); - } else { - vmaskmovps(ymm1, VMASK, - ptr[AO1 - + (unroll_m * 1 + 1 * 8 - OFFSET) - * SIZE]); - } - } - sub(AO1, -unroll_m * SIZE); - } - - if (!isTransB) { - sub(BO1, -SIZE); - if (unroll_n >= 4) { - sub(BO2, -SIZE); - } - } else { - add(BO1, LDB); - } - - }; - - // Main kernel; does prefetching and calls innerkernel{1,2,4,8} as - // appropriate - // After calculating results in registers, writes back to C matrix - auto kernel = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, - bool isLoad2Unmasked, bool isDirect, bool isCopy, bool useFma, - Ymm reg00 = Ymm(4), Ymm reg01 = Ymm(5), Ymm reg02 = Ymm(6), - Ymm reg03 = Ymm(7), Ymm reg04 = Ymm(8), Ymm reg05 = Ymm(9), - Ymm reg06 = Ymm(10), Ymm reg07 = Ymm(11), Ymm reg08 = Ymm(12), - Ymm reg09 = Ymm(13), Ymm reg10 = Ymm(14), Ymm reg11 = Ymm(15), - Ymm reg12 = Ymm(4), Ymm reg13 = Ymm(5), Ymm reg14 = Ymm(6), - Ymm reg15 = Ymm(7), Ymm reg16 = Ymm(8), Ymm reg17 = Ymm(9), - Ymm reg18 = Ymm(10), Ymm reg19 = Ymm(11), Ymm reg20 = Ymm(12), - Ymm reg21 = Ymm(13), Ymm reg22 = Ymm(14), Ymm reg23 = Ymm(15)) { - if (!isDirect) { - lea(AO1, ptr[rsp + 256 + OFFSET * SIZE]); - } else { - mov(AO1, A); - } - - if (isCopy) { - lea(LDA4, ptr[rsp + 256 + OFFSET * SIZE]); - } else { - lea(LDA4, ptr[LDA * 8 + (8 - 1 - OFFSET) * SIZE]); - } - - if (isTransB) { - lea(BO2, ptr[LDB * 4 + (8 - 1 - OFFSET) * SIZE]); - lea(BO2, ptr[BO2 + LDB * 2]); - } - - if (!isDirect) { - if (isLoad1Unmasked) { - vmovups(ymm0, - ptr[AO1 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE]); - } else { - vmaskmovps(ymm0, VMASK, - ptr[AO1 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE]); - } - if (unroll_m >= 16) { - if (isLoad2Unmasked) { - vmovups(ymm1, ptr[AO1 - + (unroll_m * 0 + 1 * 8 - OFFSET) - * SIZE]); - } else { - vmaskmovps(ymm1, VMASK, - ptr[AO1 - + (unroll_m * 0 + 1 * 8 - OFFSET) - * SIZE]); - } - } - } - - for (int i = 4; i < 10; i++) { - vxorps(Ymm(i), Ymm(i), Ymm(i)); - vxorps(Ymm(i + 6), Ymm(i + 6), Ymm(i + 6)); - } - - mov(LL, K); - sar(LL, 3); - - Label kernel12, kernel13, kernel14, kernel15; - Label kernel16, kernel17, kernel18; - - sub(LL, SECOND_FETCH); - jle(kernel13, T_NEAR); - align(16); - - L(kernel12); - innerkernel8(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, - isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04, - reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12, - reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20, - reg21, reg22, reg23); - jg(kernel12, T_NEAR); - align(16); - - L(kernel13); - prefetcht0(ptr[CO1 + (unroll_m - 1) * SIZE]); - if (unroll_n >= 2) - prefetcht0(ptr[CO1 + LDC + (unroll_m - 1) * SIZE]); - if 
(unroll_n >= 3) - prefetcht0(ptr[CO1 + LDC * 2 + (unroll_m - 1) * SIZE]); - if (unroll_n >= 4) - prefetcht0(ptr[CO2 + (unroll_m - 1) * SIZE]); - if (unroll_n >= 5) - prefetcht0(ptr[CO2 + LDC + (unroll_m - 1) * SIZE]); - if (unroll_n >= 6) - prefetcht0(ptr[CO2 + LDC * 2 + (unroll_m - 1) * SIZE]); - - add(LL, SECOND_FETCH); - jle(kernel15, T_NEAR); - align(16); - - L(kernel14); - innerkernel8(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, - isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04, - reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12, - reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20, - reg21, reg22, reg23); - jg(kernel14, T_NEAR); - align(16); - - L(kernel15); - test(K, 4); - jle(kernel16, T_NEAR); - innerkernel4(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, - isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04, - reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12, - reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20, - reg21, reg22, reg23); - - L(kernel16); - test(K, 2); - jle(kernel17, T_NEAR); - innerkernel2(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, - isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04, - reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12, - reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20, - reg21, reg22, reg23); - align(16); - - L(kernel17); - if (unroll_m == 16) { - if (unroll_n <= 3) { - vaddps(reg00, reg00, reg12); - vaddps(reg01, reg01, reg13); - vaddps(reg02, reg02, reg14); - vaddps(reg06, reg06, reg18); - vaddps(reg07, reg07, reg19); - vaddps(reg08, reg08, reg20); - } - } - - if (unroll_m <= 8) { - vaddps(reg00, reg00, reg12); - vaddps(reg01, reg01, reg13); - vaddps(reg02, reg02, reg14); - vaddps(reg03, reg03, reg15); - vaddps(reg04, reg04, reg16); - vaddps(reg05, reg05, reg17); - } - - test(K, 1); - jle(kernel18, T_NEAR); - innerkernel1(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, - isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04, - reg05, reg06, reg07, reg08, reg09, reg10, reg11); - align(16); - - L(kernel18); - vbroadcastss(VALPHA, ALPHA); - - if (isBetaN) { - vbroadcastss(VBETA, BETA); - } - - // Write back the results; all beta and bias cases need to be - // handled - switch (unroll_n) { - case 1: mov(rax, LDC); break; - case 2: lea(rax, ptr[LDC * 2]); break; - case 3: lea(rax, ptr[LDC + LDC * 2]); break; - case 4: lea(rax, ptr[LDC + LDC * 4]); break; - case 5: - lea(rax, ptr[LDC * 4]); - add(rax, LDC); - break; - case 6: - lea(rax, ptr[LDC + LDC * 2]); - add(rax, rax); - break; - } - - if (hasBias) { - mov(BIAS1, BIAS); - if (isLoad1Unmasked) { - vmovups(VBIAS1, ptr[BIAS1 + 0 * SIZE]); - } else { - vmaskmovps(VBIAS1, VMASK, ptr[BIAS1 + 0 * SIZE]); - } - } - - for (int i = 0; i < unroll_n; i++) { - vmulps(Ymm(i + 4), Ymm(i + 4), VALPHA); - if (!isBeta0) { - if (isLoad1Unmasked) { - switch (i) { - case 0: vmovups(ymm0, ptr[CO1 + 0 * SIZE]); break; - case 1: vmovups(ymm0, ptr[CO1 + LDC + 0 * SIZE]); break; - case 2: - vmovups(ymm0, ptr[CO1 + LDC * 2 + 0 * SIZE]); - break; - case 3: vmovups(ymm0, ptr[CO2 + 0 * SIZE]); break; - case 4: vmovups(ymm0, ptr[CO2 + LDC + 0 * SIZE]); break; - case 5: - vmovups(ymm0, ptr[CO2 + LDC * 2 + 0 * SIZE]); - break; - } - } else { - switch (i) { - case 0: - vmaskmovps(ymm0, VMASK, ptr[CO1 + 0 * SIZE]); - break; - case 1: - vmaskmovps(ymm0, VMASK, ptr[CO1 + LDC + 0 * SIZE]); - break; - case 2: - vmaskmovps( - ymm0, VMASK, ptr[CO1 + LDC * 2 + 0 * SIZE]); - break; - case 3: - vmaskmovps(ymm0, VMASK, ptr[CO2 + 0 * SIZE]); 
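The write-back that follows (label kernel18 and the loop over unroll_n) reduces, per column of C, to the usual alpha/beta update plus the optional bias. A scalar sketch of the same arithmetic, with invented names; the JIT performs it 8 or 16 rows at a time and uses vmaskmovps for the M tail:

    void store_c_column(float *c, const float *acc, const float *bias,
                        int m, float alpha, float beta) {
        for (int i = 0; i < m; ++i) {
            float v = alpha * acc[i];
            if (beta != 0.0f)
                v += beta * c[i];   // beta == 1 is special-cased to skip the multiply
            if (bias)
                v += bias[i];       // bias is only combined with beta == 0 in this kernel
            c[i] = v;
        }
    }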
- break; - case 4: - vmaskmovps(ymm0, VMASK, ptr[CO2 + LDC + 0 * SIZE]); - break; - case 5: - vmaskmovps( - ymm0, VMASK, ptr[CO2 + LDC * 2 + 0 * SIZE]); - break; - } - } - - if (!isBetaN) { - vaddps(Ymm(i + 4), ymm0, Ymm(i + 4)); - } else { - fma(useFma, VBETA, ymm0, Ymm(i + 4), true); - } - } - if (hasBias) { - vaddps(Ymm(i + 4), VBIAS1, Ymm(i + 4)); - } - if (isLoad1Unmasked) { - switch (i) { - case 0: vmovups(ptr[CO1 + 0 * SIZE], Ymm(i + 4)); break; - case 1: - vmovups(ptr[CO1 + LDC + 0 * SIZE], Ymm(i + 4)); - break; - case 2: - vmovups(ptr[CO1 + LDC * 2 + 0 * SIZE], Ymm(i + 4)); - break; - case 3: vmovups(ptr[CO2 + 0 * SIZE], Ymm(i + 4)); break; - case 4: - vmovups(ptr[CO2 + LDC + 0 * SIZE], Ymm(i + 4)); - break; - case 5: - vmovups(ptr[CO2 + LDC * 2 + 0 * SIZE], Ymm(i + 4)); - break; - } - } else { - switch (i) { - case 0: - vmaskmovps(ptr[CO1 + 0 * SIZE], VMASK, Ymm(i + 4)); - break; - case 1: - vmaskmovps( - ptr[CO1 + LDC + 0 * SIZE], VMASK, Ymm(i + 4)); - break; - case 2: - vmaskmovps(ptr[CO1 + LDC * 2 + 0 * SIZE], VMASK, - Ymm(i + 4)); - break; - case 3: - vmaskmovps(ptr[CO2 + 0 * SIZE], VMASK, Ymm(i + 4)); - break; - case 4: - vmaskmovps( - ptr[CO2 + LDC + 0 * SIZE], VMASK, Ymm(i + 4)); - break; - case 5: - vmaskmovps(ptr[CO2 + LDC * 2 + 0 * SIZE], VMASK, - Ymm(i + 4)); - break; - } - } - - if (unroll_m >= 16) { - // Re-use ymm4 (VBIAS2) - if (i == 0) { - if (hasBias) { - if (isLoad1Unmasked) { - vmovups(VBIAS2, ptr[BIAS1 + 8 * SIZE]); - } else { - vmaskmovps( - VBIAS2, VMASK, ptr[BIAS1 + 8 * SIZE]); - } - } - } - vmulps(Ymm(i + 10), Ymm(i + 10), VALPHA); - if (!isBeta0) { - if (isLoad2Unmasked) { - switch (i) { - case 0: vmovups(ymm0, ptr[CO1 + 8 * SIZE]); break; - case 1: - vmovups(ymm0, ptr[CO1 + LDC + 8 * SIZE]); - break; - case 2: - vmovups(ymm0, ptr[CO1 + LDC * 2 + 8 * SIZE]); - break; - case 3: vmovups(ymm0, ptr[CO2 + 8 * SIZE]); break; - case 4: - vmovups(ymm0, ptr[CO2 + LDC + 8 * SIZE]); - break; - case 5: - vmovups(ymm0, ptr[CO2 + LDC * 2 + 8 * SIZE]); - break; - } - } else { - switch (i) { - case 0: - vmaskmovps(ymm0, VMASK, ptr[CO1 + 8 * SIZE]); - break; - case 1: - vmaskmovps( - ymm0, VMASK, ptr[CO1 + LDC + 8 * SIZE]); - break; - case 2: - vmaskmovps(ymm0, VMASK, - ptr[CO1 + LDC * 2 + 8 * SIZE]); - break; - case 3: - vmaskmovps(ymm0, VMASK, ptr[CO2 + 8 * SIZE]); - break; - case 4: - vmaskmovps( - ymm0, VMASK, ptr[CO2 + LDC + 8 * SIZE]); - break; - case 5: - vmaskmovps(ymm0, VMASK, - ptr[CO2 + LDC * 2 + 8 * SIZE]); - break; - } - } - if (!isBetaN) { - vaddps(Ymm(i + 10), ymm0, Ymm(i + 10)); - } else { - fma(useFma, VBETA, ymm0, Ymm(i + 10), true); - } - } - if (hasBias) { - vaddps(Ymm(i + 10), VBIAS2, Ymm(i + 10)); - } - if (isLoad2Unmasked) { - switch (i) { - case 0: - vmovups(ptr[CO1 + 8 * SIZE], Ymm(i + 10)); - break; - case 1: - vmovups(ptr[CO1 + LDC + 8 * SIZE], Ymm(i + 10)); - break; - case 2: - vmovups(ptr[CO1 + LDC * 2 + 8 * SIZE], Ymm(i + 10)); - break; - case 3: - vmovups(ptr[CO2 + 8 * SIZE], Ymm(i + 10)); - break; - case 4: - vmovups(ptr[CO2 + LDC + 8 * SIZE], Ymm(i + 10)); - break; - case 5: - vmovups(ptr[CO2 + LDC * 2 + 8 * SIZE], Ymm(i + 10)); - break; - } - } else { - switch (i) { - case 0: - vmaskmovps(ptr[CO1 + 8 * SIZE], VMASK, Ymm(i + 10)); - break; - case 1: - vmaskmovps(ptr[CO1 + LDC + 8 * SIZE], VMASK, - Ymm(i + 10)); - break; - case 2: - vmaskmovps(ptr[CO1 + LDC * 2 + 8 * SIZE], VMASK, - Ymm(i + 10)); - break; - case 3: - vmaskmovps(ptr[CO2 + 8 * SIZE], VMASK, Ymm(i + 10)); - break; - case 4: - vmaskmovps(ptr[CO2 + LDC + 8 * SIZE], VMASK, - 
Ymm(i + 10)); - break; - case 5: - vmaskmovps(ptr[CO2 + LDC * 2 + 8 * SIZE], VMASK, - Ymm(i + 10)); - break; - } - } - } - if (i == 2) - add(CO1, rax); - } - if (unroll_n >= 4) { - add(CO2, rax); - } - - // Compute next address of B - if (!isTransB) { - lea(rax, ptr[K * SIZE]); - switch (unroll_n) { - case 1: - add(BO1, LDB); - add(BO2, LDB); - break; - case 2: - lea(BO1, ptr[BO1 + LDB * 2]); - lea(BO2, ptr[BO2 + LDB * 2]); - break; - case 3: - lea(BO1, ptr[BO1 + LDB3]); - lea(BO2, ptr[BO2 + LDB3]); - break; - case 4: - lea(BO1, ptr[BO1 + LDB * 4]); - lea(BO2, ptr[BO2 + LDB * 4]); - break; - case 5: - lea(BO1, ptr[BO1 + LDB * 4]); - add(BO1, LDB); - lea(BO2, ptr[BO2 + LDB * 4]); - add(BO2, LDB); - break; - case 6: - lea(BO1, ptr[BO1 + LDB3 * 2]); - lea(BO2, ptr[BO2 + LDB3 * 2]); - break; - } - sub(BO1, rax); - sub(BO2, rax); - } else { - mov(rax, LDB); - imul(rax, K); - sub(BO1, rax); - add(BO1, unroll_n * SIZE); - } - }; - - auto kernel_16x6 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, - bool isLoad2Unmasked, bool isDirect, bool isCopy) { - kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, - isDirect, isCopy, true); - }; - - auto kernel_16x5 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, - bool isLoad2Unmasked, bool isDirect, bool isCopy) { - kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, - isDirect, isCopy, true); - }; - - auto kernel_16x4 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, - bool isLoad2Unmasked, bool isDirect, bool isCopy) { - kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, - isDirect, isCopy, true); - }; - - auto kernel_16x3 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, - bool isLoad2Unmasked, bool isDirect, bool isCopy, - bool useFma = true) { - kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, - isDirect, isCopy, useFma, Ymm(4), Ymm(5), Ymm(6), Ymm(7), - Ymm(8), Ymm(9), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14), - Ymm(15), Ymm(7), Ymm(8), Ymm(9), Ymm(7), Ymm(8), Ymm(9), - Ymm(13), Ymm(14), Ymm(15)); - }; - - auto kernel_16x2 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, - bool isLoad2Unmasked, bool isDirect, bool isCopy) { - kernel_16x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, - isDirect, isCopy, false); - }; - - auto kernel_16x1 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, - bool isLoad2Unmasked, bool isDirect, bool isCopy) { - kernel_16x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, - isDirect, isCopy, false); - }; - - auto kernel_8x6 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, - bool isLoad2Unmasked, bool isDirect, bool isCopy, - bool useFma = true) { - kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, - isDirect, isCopy, useFma, Ymm(4), Ymm(5), Ymm(6), Ymm(7), - Ymm(8), Ymm(9), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14), - Ymm(15), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14), - Ymm(15)); - }; - - auto kernel_8x5 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, - bool isLoad2Unmasked, bool isDirect, bool isCopy) { - kernel_8x6(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, - isDirect, isCopy); - }; - - auto kernel_8x4 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, - bool isLoad2Unmasked, bool isDirect, bool isCopy) { - kernel_8x6(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, - isDirect, isCopy); - }; - - auto kernel_8x3 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, - bool isLoad2Unmasked, bool isDirect, bool isCopy, - bool useFma = true) { - kernel(unroll_m, unroll_n, 
isLoad1Unmasked, isLoad2Unmasked, - isDirect, isCopy, useFma, Ymm(4), Ymm(5), Ymm(6), Ymm(7), - Ymm(8), Ymm(9), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14), - Ymm(15), Ymm(7), Ymm(8), Ymm(9), Ymm(7), Ymm(8), Ymm(9), - Ymm(13), Ymm(14), Ymm(15)); - }; - - auto kernel_8x2 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, - bool isLoad2Unmasked, bool isDirect, bool isCopy) { - kernel_8x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, - isDirect, isCopy, false); - }; - - auto kernel_8x1 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, - bool isLoad2Unmasked, bool isDirect, bool isCopy) { - kernel_8x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, - isDirect, isCopy, false); - }; - - // High-level subroutine; does packing if needed, then splits C matrix. - // Operates on chunks of 16 rows, 6 columns at a time (handling tail - // cases appropriately). - // Masking is used for tail cases where M is not divisible by 8. - auto subloop = [&]( - int unroll_m, bool isLoad1Unmasked, bool isLoad2Unmasked) { - if (isTransA) { - do_pack(unroll_m, isLoad1Unmasked, isLoad2Unmasked); - } - - Label subloop11, subloop11mask; - Label subloop20, subloop21, subloop22, subloop23; - Label subloop24, subloop25; - Label subloop30, subloop31, subloop32, subloop33; - Label subloop34, subloop35; - Label subloop98, subloop98mask; - Label subloop99, subloop99mask; - - mov(CO1, C); - lea(CO2, ptr[CO1 + LDC * 2]); - add(CO2, LDC); - add(C, unroll_m * SIZE); - mov(BO1, B); - if (!isTransB) { - lea(BO2, qword[B + LDB3]); - } - - if (!isTransA) { - lea(AA, ptr[A + (unroll_m * 2 - 1 - OFFSET) * SIZE]); - cmp(M, UNROLL_M); - jg(subloop98, T_NEAR); - - mov(AA, ORIG_A); - lea(AA, ptr[AA + (unroll_m - 1 - OFFSET) * SIZE]); - L(subloop98); - } - - mov(LL, N); - mov(I, LL); - if (!isTransA) { - // If N is too small, skip copy operation - cmp(LL, UNROLL_N * 3); - jle(subloop30, T_NEAR); - - // If A is not aligned to cache line - cmp(FLAG, 0); - je(subloop30, T_NEAR); - } else { - cmp(LL, UNROLL_N); - jl(subloop20, T_NEAR); - } - align(16); - - if (!isTransA) { - if (unroll_m == 16) { - kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked, - isLoad2Unmasked, true, true); - } else { - kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked, - isLoad2Unmasked, true, true); - } - } else { - if (unroll_m == 16) { - kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked, - isLoad2Unmasked, false, false); - } else { - kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked, - isLoad2Unmasked, false, false); - } - } - - sub(I, UNROLL_N); - cmp(I, UNROLL_N); - jl(subloop20, T_NEAR); - align(16); - - L(subloop11); - if (unroll_m == 16) { - kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked, - isLoad2Unmasked, false, false); - } else { - kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked, isLoad2Unmasked, - false, false); - } - sub(I, UNROLL_N); - cmp(I, UNROLL_N); - jge(subloop11, T_NEAR); - align(16); - - L(subloop20); - cmp(I, 1); - jne(subloop21, T_NEAR); - if (unroll_m == 16) { - kernel_16x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked, - false, false); - } else { - kernel_8x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked, false, - false); - } - jmp(subloop99, T_NEAR); - align(16); - - L(subloop21); - cmp(I, 2); - jne(subloop22, T_NEAR); - if (unroll_m == 16) { - kernel_16x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked, - false, false); - } else { - kernel_8x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked, false, - false); - } - jmp(subloop99, T_NEAR); - align(16); - - L(subloop22); - cmp(I, 3); - jne(subloop23, T_NEAR); - if (unroll_m == 16) { - 
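The chain of subloop20..subloop24 labels around this point is a remainder dispatch on the column count I. A simplified host-side sketch of that control flow, with hypothetical names (the real version also peels a first packed/copy iteration when A needs packing):

    #include <cstdio>

    // Stand-in for kernel_{16,8}x{1..6}; prints which width would be used.
    static void run_kernel(int unroll_n) { std::printf("kernel x%d\n", unroll_n); }

    // Process N columns: 6-wide kernels while possible, then one 1..5-wide tail.
    void process_columns(int n, int unroll_n = 6) {
        while (n >= unroll_n) { run_kernel(unroll_n); n -= unroll_n; }
        if (n > 0) run_kernel(n);  // corresponds to the subloop20..subloop24 compares
    }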
kernel_16x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked, - false, false); - } else { - kernel_8x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked, false, - false); - } - jmp(subloop99, T_NEAR); - align(16); - - L(subloop23); - cmp(I, 4); - jne(subloop24, T_NEAR); - if (unroll_m == 16) { - kernel_16x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked, - false, false); - } else { - kernel_8x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked, false, - false); - } - jmp(subloop99, T_NEAR); - align(16); - - L(subloop24); - cmp(I, 5); - jne(subloop99, T_NEAR); - if (unroll_m == 16) { - kernel_16x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked, - false, false); - } else { - kernel_8x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked, false, - false); - } - jmp(subloop99, T_NEAR); - align(16); - - if (!isTransA) { - L(subloop30); - cmp(I, UNROLL_N); - jl(subloop25, T_NEAR); - align(16); - - L(subloop31); - if (unroll_m == 16) { - kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked, - isLoad2Unmasked, true, false); - } else { - kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked, - isLoad2Unmasked, true, false); - } - sub(I, UNROLL_N); - cmp(I, UNROLL_N); - jge(subloop31, T_NEAR); - align(16); - - L(subloop25); - cmp(I, 1); - jne(subloop32, T_NEAR); - if (unroll_m == 16) { - kernel_16x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked, - true, false); - } else { - kernel_8x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked, - true, false); - } - jmp(subloop99, T_NEAR); - align(16); - - L(subloop32); - cmp(I, 2); - jne(subloop33, T_NEAR); - if (unroll_m == 16) { - kernel_16x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked, - true, false); - } else { - kernel_8x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked, - true, false); - } - jmp(subloop99, T_NEAR); - align(16); - - L(subloop33); - cmp(I, 3); - jne(subloop34, T_NEAR); - if (unroll_m == 16) { - kernel_16x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked, - true, false); - } else { - kernel_8x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked, - true, false); - } - jmp(subloop99, T_NEAR); - align(16); - - L(subloop34); - cmp(I, 4); - jne(subloop35, T_NEAR); - if (unroll_m == 16) { - kernel_16x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked, - true, false); - } else { - kernel_8x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked, - true, false); - } - jmp(subloop99, T_NEAR); - align(16); - - L(subloop35); - cmp(I, 5); - jne(subloop99, T_NEAR); - if (unroll_m == 16) { - kernel_16x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked, - true, false); - } else { - kernel_8x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked, - true, false); - } - align(16); - } - - L(subloop99); - // Compute address for A - if (!isTransA) { - add(A, unroll_m * SIZE); - } else { - mov(rax, LDA); - imul(rax, rax, unroll_m); - add(A, rax); - } - - // Compute next address of BIAS - if (hasBias) { - add(BIAS, unroll_m * SIZE); - } - }; - - preamble(); - - Label buffer_in_ws, buffer_allocated; - - // Get the registers - mov(B, ARG_B); - mov(LDB, ARG_LDB); - mov(r15, ARG_BETA); - mov(r12, ARG_C); - if (hasBias) - mov(r10, ARG_BIAS); - mov(LDC, ARG_LDC); - mov(rbp, rsp); - - vmovss(xmm0, ptr[ARG_ALPHA]); - vmovss(xmm1, ptr[r15]); - -#if _WIN32 - mov(A, ARG_A); - mov(LDA, ARG_LDA); -#endif - - cmp(K, STACK_K_CAPACITY); - jg(buffer_in_ws, T_NEAR); - - // Create buffer and align to 4kB page - lea(rax, ptr[K * SIZE]); - sal(rax, 4); - add(rax, 256); - sub(rsp, rax); - and_(rsp, -PAGE_4K); - jmp(buffer_allocated, T_NEAR); - - L(buffer_in_ws); - mov(rsp, ARG_WS); - - L(buffer_allocated); - - mov(ORIG_SP, rbp); - mov(M, 
ARG_M); - mov(N, ARG_N); - mov(C, r12); - if (hasBias) - mov(BIAS, r10); - vmovss(ALPHA, xmm0); - vmovss(BETA, xmm1); - sub(A, -OFFSET * SIZE); - sub(B, -OFFSET * SIZE); - mov(ORIG_A, A); - sal(LDA, BASE_SHIFT); - sal(LDB, BASE_SHIFT); - sal(LDC, BASE_SHIFT); - lea(LDB3, ptr[LDB + LDB * 2]); - - for (int i = 0; i < 8; i++) { - mov(dword[rsp + 88 + i * 4], i); - } - - if (isTransA && is_avx2) { - movq(xmm0, LDA); - vpbroadcastq(ymm1, xmm0); - vinsertf128(ymm0, ymm0, xmm0, 1); - vpermilpd(ymm0, ymm0, 5); - vpaddq(ymm1, ymm1, ymm1); - vperm2f128(ymm1, ymm1, ymm1, 8); - vpaddq(ymm0, ymm0, ymm1); - vmovups(STRIDE, ymm0); - } - - // Check A alignment and leading dimension; take copy-based path as - // needed - mov(rax, LDA); - or_(rax, A); - and_(rax, 0x1f); - mov(FLAG, rax); - - Label main0, main1, main2, main3, main999; - - cmp(M, UNROLL_M); - jl(main0, T_NEAR); - align(16); - - L(main1); - subloop(UNROLL_M, true, true); - sub(M, UNROLL_M); - cmp(M, UNROLL_M); - jge(main1, T_NEAR); - align(16); - - L(main0); - cmp(M, 0); - jle(main999, T_NEAR); - - if (UNROLL_M > 8) { - cmp(M, 8); - jle(main2, T_NEAR); - - sub(M, 8); - vbroadcastss(VMASK, M); - vpcmpgtd(VMASK, VMASK, MASK); - - subloop(16, true, false); - jmp(main999, T_NEAR); - align(16); - - L(main2); - cmp(M, 8); - jne(main3, T_NEAR); - subloop(8, true, true); - jmp(main999, T_NEAR); - } - - align(16); - - L(main3); - vbroadcastss(VMASK, M); - if (is_avx2) { - vpcmpgtd(VMASK, VMASK, MASK); - } else { - auto xmask = Xmm(VMASK.getIdx()); - auto xmm_tmp = xmm4; - - vextractf128(xmm_tmp, VMASK, 1); - vpcmpgtd(xmask, xmask, MASK); - vpcmpgtd(xmm_tmp, xmm_tmp, dword[rsp + 88 + 4 * 4]); // MASK + 4 - vinsertf128(VMASK, VMASK, xmm_tmp, 1); - } - subloop(8, false, false); - align(16); - - L(main999); - // Restore original stack - mov(rsp, ORIG_SP); - - vzeroupper(); - postamble(); - - ker_ = this->getCode(); - } - - typedef void (*ker_t)(dim_t m, dim_t n, dim_t k, - const float *alpha, const float *a, dim_t lda, - const float *b, dim_t ldb, const float *beta, float *c, - dim_t ldc, const float *bias, float *ws); - - void operator()(dim_t m, dim_t n, dim_t k, - const float *alpha, const float *a, dim_t lda, - const float *b, dim_t ldb, const float *beta, float *c, - dim_t ldc, const float *bias, float *ws) const - { - ker_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, bias, ws); - } - -private: - ker_t ker_; -}; - -const xbyak_gemm *get_xbyak_gemm( - bool isTransA, bool isTransB, float beta, bool hasBias) { - auto beta_idx = [](float beta) { - return (beta == 0.0) ? 0 : (beta == 1.0 ? 
1 : 2); - }; - - // Kernel table [isTransA][isTransB][hasBias][beta (0, 1, other)] - static xbyak_gemm *kernel_table[2][2][2][3]; - static std::once_flag initialized; - std::call_once(initialized, [=]{ - for (bool isTransA: {false, true}) - for (bool isTransB: {false, true}) - for (bool hasBias: {false, true}) - for (float beta: {0.0f, 1.0f, 2.0f}) { - // nocopy sgemm with bias for beta != 0.0 is not supported - if (hasBias && beta != 0.0) - continue; - kernel_table[isTransA][isTransB][hasBias][beta_idx(beta)] = - new xbyak_gemm(isTransA, isTransB, beta, hasBias); - } - }); - - return kernel_table[isTransA][isTransB][hasBias][beta_idx(beta)]; -} - -void sgemm_nocopy_driver(const char *transa, - const char *transb, int m, int n, int k, const float *alpha, - const float *a, dim_t lda, const float *b, dim_t ldb, const float *beta, - float *c, dim_t ldc, const float *bias, float *ws) -{ - bool isTransA = (*transa == 'T' || *transa == 't'); - bool isTransB = (*transb == 'T' || *transb == 't'); - - int Bm, sizeM, Bn, sizeN, Bk, sizeK; - - int i, j; - - if ((m <= 0) || (n <= 0)) - return; - - if ((k <= 0) || (alpha[0] == 0.)) { - - if (beta[0] == 0.) { - for (j = 0; j < n; j++) - for (i = 0; i < m; i++) - c[i + j * ldc] = 0.0; - } else if (beta[0] != 1.) { - for (j = 0; j < n; j++) - for (i = 0; i < m; i++) - c[i + j * ldc] *= beta[0]; - } - - return; - } - - assert(IMPLICATION(bias != nullptr, *beta == 0.0)); - - // XXX: this happens on every thread... - bool hasBias = (bias != nullptr); - auto ker_bn = get_xbyak_gemm(isTransA, isTransB, *beta, hasBias); - auto ker_b1 = get_xbyak_gemm(isTransA, isTransB, 1.0, false); - auto ker_b0 = get_xbyak_gemm(isTransA, isTransB, 0.0, false); - assert(ker_bn && ker_b1 && ker_b0); - - int BM = 4032; - int BN = isTransA ? 96 : 48; - int BK = isTransB ? 
96 : 256; - const float *curA, *curB, *curBias = nullptr; - float *curC; - - for (Bk = 0; Bk < k; Bk += sizeK) { - sizeK = k - Bk; - if (sizeK >= BK * 2) - sizeK = BK; - else { - if (sizeK > BK) - sizeK = (sizeK + 1) / 2; - } - - for (Bm = 0; Bm < m; Bm += sizeM) { - sizeM = m - Bm; - if (sizeM >= BM * 2) - sizeM = BM; - else { - if (sizeM > BM + BM / 2) - sizeM = (sizeM + 1) / 2; - } - - for (Bn = 0; Bn < n; Bn += sizeN) { - sizeN = n - Bn; - if (sizeN >= BN * 2) - sizeN = BN; - else { - if (sizeN > BN + BN / 2) - sizeN = (sizeN + 1) / 2; - } - - if (!isTransA) { - curA = a + Bm + Bk * lda; - } else { - curA = a + Bk + Bm * lda; - } - if (!isTransB) { - curB = b + Bk + Bn * ldb; - } else { - curB = b + Bn + Bk * ldb; - } - curC = c + Bm + (size_t)Bn * ldc; - if (bias != nullptr) { - if (Bk == 0) { - curBias = bias + Bm; - } else { - curBias = nullptr; - } - } - if (Bk == 0) { - if (*beta == 0.0 && bias == nullptr) - (*ker_b0)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK, - alpha, curA, lda, curB, ldb, beta, curC, ldc, - curBias, ws); - else - (*ker_bn)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK, - alpha, curA, lda, curB, ldb, beta, curC, ldc, - curBias, ws); - } else { - (*ker_b1)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK, - alpha, curA, lda, curB, ldb, beta, curC, ldc, - curBias, ws); - } - } - } - } -} - -} - -mkldnn_status_t jit_avx_gemm_f32( - const char *transa, const char *transb, - const int *p_m, const int *p_n, const int *p_k, const float *p_alpha, - const float *A, const int *p_lda, const float *B, const int *p_ldb, - const float *p_beta, float *C, const int *p_ldc, const float *bias) -{ - using namespace mkldnn::impl::utils; - using namespace avx_gemm_f32; - using namespace gemm_utils; - - if (*p_beta != 0 && bias) - return ref_gemm(transa, transb, p_m, p_n, p_k, - p_alpha, A, p_lda, B, p_lda, p_beta, C, p_ldc, bias); - - int nthr = (mkldnn_in_parallel()) ? 1 : mkldnn_get_max_threads(); - - int m = *p_m; - int n = *p_n; - int k = *p_k; - dim_t lda = *p_lda; - dim_t ldb = *p_ldb; - dim_t ldc = *p_ldc; - float beta = *p_beta; - int MB, NB, KB; - - int nthr_m, nthr_n, nthr_k, nthr_mn; - - // Determine threading partitioning - calc_nthr_nocopy_avx( - m, n, k, nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB); - assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_k == 1)); - - // May not happen, but just in case - if (nthr < nthr_m * nthr_n * nthr_k) - nthr = nthr_m * nthr_n * nthr_k; - - nthr_mn = nthr_m * nthr_n; - - unsigned char * ompstatus_ = nullptr; - unsigned char volatile *ompstatus = nullptr; - - float *c_buffers = nullptr; - float *ws_buffers = nullptr; - - if (nthr_k > 1) { - ompstatus_ = (unsigned char *) malloc( - nthr * CACHE_LINE_SIZE, - CACHE_LINE_SIZE); - ompstatus = (unsigned char volatile *) ompstatus_; - assert(ompstatus); - - for (int i = 0; i < nthr; i++) - ompstatus[i * CACHE_LINE_SIZE] = 0; - - c_buffers = (float *)malloc(nthr_m * nthr_n * (nthr_k - 1) * MB * NB - * sizeof(float), PAGE_4K); - } - - const size_t ws_elems_per_thr = (size_t)k * 16 + 64; - const size_t ws_size_per_thr - = rnd_up(ws_elems_per_thr * sizeof(float), PAGE_4K); - if (k > STACK_K_CAPACITY) { - ws_buffers = (float *)malloc(nthr * ws_size_per_thr, PAGE_4K); - } - - parallel_nd(nthr, [&](const int ithr) { - int ithr_m, ithr_n, ithr_k, ithr_mn; - int m_from, m_to, myM; - int n_from, n_to, myN; - int k_from, k_to, myK; - int cbase, ibase; - const float *myA, *myB, *myBias = nullptr; - float *myC = C, myBeta; - float *ws = ws_buffers ? 
- ws_buffers + ithr * ws_size_per_thr / sizeof(float) : 0; - dim_t ld = ldc; - - int sum_later = (mkldnn_get_num_threads() < nthr_m * nthr_n * nthr_k); - - if (ithr < nthr_m * nthr_n * nthr_k) { - - ithr_mn = ithr % nthr_mn; - ithr_m = ithr_mn % nthr_m; - ithr_n = ithr_mn / nthr_m; - ithr_k = ithr / nthr_mn; - - /* swap ithr_k for performance improvement */ - if (ithr_k == 0) - ithr_k = nthr_k - 1; - else if (ithr_k == nthr_k - 1) - ithr_k = 0; - - m_from = MB * (ithr_m); - m_to = MB * (ithr_m + 1); - if (m_to > m) - m_to = m; - myM = m_to - m_from; - - n_from = NB * (ithr_n); - n_to = NB * (ithr_n + 1); - if (n_to > n) - n_to = n; - myN = n_to - n_from; - - k_from = KB * (ithr_k); - k_to = KB * (ithr_k + 1); - if (k_to > k) - k_to = k; - myK = k_to - k_from; - - cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1); - ibase = (ithr_m + nthr_m * ithr_n) * nthr_k; - - if ((myM > 0) && (myN > 0)) { - - if (*transa == 'N' || *transa == 'n') { - myA = &(A[m_from + k_from * lda]); - } else { - myA = &(A[k_from + m_from * lda]); - } - if (*transb == 'N' || *transb == 'n') { - myB = &(B[k_from + n_from * ldb]); - } else { - myB = &(B[n_from + k_from * ldb]); - } - if (ithr_k == 0) { - myC = &(C[m_from + n_from * ldc]); - myBeta = beta; - ld = ldc; - if (bias) - myBias = &(bias[m_from]); - } else { - myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1); - myBeta = 0.0; - ld = MB; - myBias = nullptr; - } - - sgemm_nocopy_driver(transa, transb, myM, myN, myK, p_alpha, myA, - lda, myB, ldb, &myBeta, myC, ld, myBias, ws); - - if (nthr_k > 1 && !sum_later) - ompstatus[(ibase + ithr_k) * CACHE_LINE_SIZE] = 1; - } - - if (nthr_k > 1 && !sum_later) { - - // sum matrices partitioned along K dimension - int n1, n2; - - partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2); - - if (ithr_k > 0) { - - myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1) - + (dim_t)n1 * MB; - /* need to wait until main thread finishes */ - while (ompstatus[ibase * CACHE_LINE_SIZE] != 1) { - }; - - /* my cache is hot */ - sum_two_matrices(myM, n2, myC, MB, - &C[m_from + (n_from + n1) * ldc], ldc); - } - - for (int ik = 1; ik < nthr_k; ++ik) { - if (ik != ithr_k) { - - myC = c_buffers + (dim_t)MB * NB * (cbase + ik - 1) - + (dim_t)n1 * MB; - - while (ompstatus[(ibase + ik) * CACHE_LINE_SIZE] != 1) { - }; - - sum_two_matrices(myM, n2, myC, MB, - &C[m_from + (n_from + n1) * ldc], ldc); - } - } - } - } - }); - - // handle C summation later - if (nthr_k > 1 && ompstatus[0] == 0) { - - parallel_nd(nthr, [&](const int ithr) { - int ithr_m, ithr_n, ithr_k, ithr_mn; - int m_from, m_to, myM; - int n_from, n_to, myN; - int cbase; - float *myC = C; - - if (ithr < nthr_m * nthr_n * nthr_k) { - - ithr_mn = ithr % nthr_mn; - ithr_m = ithr_mn % nthr_m; - ithr_n = ithr_mn / nthr_m; - ithr_k = ithr / nthr_mn; - - /* swap ithr_k for performance improvement */ - if (ithr_k == 0) - ithr_k = nthr_k - 1; - else if (ithr_k == nthr_k - 1) - ithr_k = 0; - - m_from = MB * (ithr_m); - m_to = MB * (ithr_m + 1); - if (m_to > m) - m_to = m; - myM = m_to - m_from; - - n_from = NB * (ithr_n); - n_to = NB * (ithr_n + 1); - if (n_to > n) - n_to = n; - myN = n_to - n_from; - - cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1); - - if (nthr_k > 1) { - // sum matrices partitioned along K dimension - int n1, n2; - - partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2); - - if (ithr_k > 0) { - - myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1) - + (dim_t)n1 * MB; - - /* my cache is hot */ - sum_two_matrices(myM, n2, myC, MB, - &C[m_from + (n_from + n1) * ldc], 
ldc); - } - - for (int ik = 1; ik < nthr_k; ++ik) { - if (ik != ithr_k) { - - myC = c_buffers + (dim_t)MB * NB * (cbase + ik - 1) - + (dim_t)n1 * MB; - - sum_two_matrices(myM, n2, myC, MB, - &C[m_from + (n_from + n1) * ldc], ldc); - } - } - } - } - }); - } - - - free(c_buffers); - free(ompstatus_); - free(ws_buffers); - - return mkldnn_success; -} - -} -} -} - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.hpp deleted file mode 100644 index aabf520a3..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.hpp +++ /dev/null @@ -1,37 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef JIT_AVX_GEMM_F32_HPP -#define JIT_AVX_GEMM_F32_HPP - -#include "mkldnn_types.h" - -namespace mkldnn { -namespace impl { -namespace cpu { - -mkldnn_status_t jit_avx_gemm_f32( - const char *transa, const char *transb, const int *M, - const int *N, const int *K, const float *alpha, const float *A, - const int *lda, const float *B, const int *ldb, const float *beta, - float *C, const int *ldc, const float *bias = nullptr); - - -} -} -} - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.cpp deleted file mode 100644 index 5147885a8..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.cpp +++ /dev/null @@ -1,346 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include "mkldnn_types.h" - -#include "mkldnn_thread.hpp" -#include "nstl.hpp" -#include "utils.hpp" - -#include "jit_generator.hpp" - -#include "gemm_utils_f32.hpp" -#include "ref_gemm_f32.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace mkldnn::impl::utils; -using namespace gemm_utils; - -namespace { - -template -void copy_A( - bool isTransA, int K, const data_t *A, const dim_t lda, data_t *ws) { - for (int k = 0; k < K; k++) { - PRAGMA_OMP_SIMD() - for (int i = 0; i < unroll_factor::m; i++) { - ws[i] = isTransA ? 
A[i * lda + k] : A[i + k * lda]; - } - ws += unroll_factor::m; - } -} - -template -void kernel_mxn(int K, const data_t *A, const dim_t lda, - const data_t *B, const dim_t ldb, data_t *C, const dim_t ldc, - const data_t alpha, const data_t beta) { - data_t c[unroll_factor::m * unroll_factor::n] = - { static_cast(0.) }; - for (int k = 0; k < K; k++) { - for (int j = 0; j < unroll_factor::n; j++) { - data_t b = isTransB ? B[j + k * ldb] : B[k + j * ldb]; - PRAGMA_OMP_SIMD() - for (int i = 0; i < unroll_factor::m; i++) { - data_t a = isTransA ? A[i * lda + k] : A[i + lda * k]; - c[i + unroll_factor::m * j] += a * b; - } - } - } - for (int j = 0; j < unroll_factor::n; j++) { - PRAGMA_OMP_SIMD() - for (int i = 0; i < unroll_factor::m; i++) { - C[i + j * ldc] = (beta == static_cast(0.)) - ? alpha * c[i + unroll_factor::m * j] - : alpha * c[i + unroll_factor::m * j] - + beta * C[i + j * ldc]; - } - } -} - -template -void block_ker(const int M, const int N, const int K, - const data_t *A, const dim_t lda, const data_t *B, const dim_t ldb, - data_t *C, const dim_t ldc, const data_t alpha, const data_t beta, - data_t *ws, bool do_copy) { - int Nu = rnd_dn(N, unroll_factor::n); - int Mu = rnd_dn(M, unroll_factor::m); - for (int i = 0; i < Mu; i += unroll_factor::m) { - for (int j = 0; j < Nu; j += unroll_factor::n) { - const data_t *b = isTransB ? &B[j] : &B[j * ldb]; - const data_t *a = isTransA ? &A[i * lda] : &A[i]; - if (do_copy) { - if (j == 0) { - copy_A(isTransA, K, a, lda, ws); - } - kernel_mxn( - K, ws, unroll_factor::m, b, ldb, - &C[i + j * ldc], ldc, alpha, beta); - } else { - kernel_mxn( - K, a, lda, b, ldb, &C[i + j * ldc], ldc, alpha, beta); - } - } - } - // tail processing - for (int i = 0; i < M; i++) { - for (int j = Nu; j < N; j++) { - data_t c = beta == static_cast(0.) - ? static_cast(0.) - : beta * C[i + j * ldc]; - for (int p = 0; p < K; p++) { - data_t b = isTransB ? B[j + p * ldb] : B[p + j * ldb]; - data_t a = isTransA ? A[p + i * lda] : A[i + p * lda]; - c += alpha * a * b; - } - C[i + j * ldc] = c; - } - } - for (int i = Mu; i < M; i++) { - for (int j = 0; j < Nu; j++) { - data_t c = beta == static_cast(0.) - ? static_cast(0.) - : beta * C[i + j * ldc]; - for (int p = 0; p < K; p++) { - data_t b = isTransB ? B[j + p * ldb] : B[p + j * ldb]; - data_t a = isTransA ? A[p + i * lda] : A[i + p * lda]; - c += alpha * a * b; - } - C[i + j * ldc] = c; - } - } -} - -template -void gemm_ithr(const int M, const int N, const int K, const data_t alpha, - const data_t *A, const dim_t lda, const data_t *B, const dim_t ldb, - const data_t beta, data_t *C, const dim_t ldc, bool do_copy, - data_t *ws) { - constexpr int BM = gemm_traits::BM; - constexpr int BN = gemm_traits::BN; - constexpr int BK = gemm_traits::BK; - - const data_t *curA; - const data_t *curB; - data_t *curC; - - if ((M <= 0) || (N <= 0)) - return; - - if ((K <= 0) || (alpha == static_cast(0))) { - dim_t MN = N * M; - if (beta == static_cast(0.)) { - for (dim_t j = 0; j < MN; j++) - C[j] = static_cast(0.); - } else if (beta != static_cast(1.)) { - for (dim_t j = 0; j < MN; j++) - C[j] *= beta; - } - return; - } - - for (int Bk = 0; Bk < K; Bk += BK) { - int kb = nstl::min(K - Bk, BK); - for (int Bm = 0; Bm < M; Bm += BM) { - int mb = nstl::min(M - Bm, BM); - for (int Bn = 0; Bn < N; Bn += BN) { - int nb = nstl::min(N - Bn, BN); - curA = isTransA ? A + Bk + Bm * lda : A + Bm + Bk * lda; - curB = isTransB ? 
B + Bn + Bk * ldb : B + Bk + Bn * ldb; - curC = C + Bm + Bn * ldc; - if (Bk == 0) { - block_ker(mb, nb, kb, curA, lda, - curB, ldb, curC, ldc, alpha, beta, ws, do_copy); - } else { - block_ker(mb, nb, kb, curA, lda, - curB, ldb, curC, ldc, alpha, static_cast(1.0), - ws, do_copy); - } - } - } - } -} - -} - -template -mkldnn_status_t ref_gemm( - const char *transa_, const char *transb_, const int *M_, - const int *N_, const int *K_, const data_t *alpha_, const data_t *A, - const int *lda_, const data_t *B, const int *ldb_, const data_t *beta_, - data_t *C, const int *ldc_, const data_t *bias) { - - bool isTransA = (*transa_ == 'T' || *transa_ == 't'); - bool isTransB = (*transb_ == 'T' || *transb_ == 't'); - const int M = *M_, N = *N_, K = *K_; - const dim_t lda = *lda_, ldb = *ldb_, ldc = *ldc_; - const data_t alpha = *alpha_, beta = *beta_; - - int max_nthr = mkldnn_in_parallel() ? 1 : mkldnn_get_max_threads(); - int nthr_m, nthr_n, nthr_k; - int MB, NB, KB; - // thread balancing over M, N, K & size of blocking dimensions - calc_nthr_nocopy_avx( - M, N, K, max_nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB); - assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_k == 1)); - - data_t *c_buffers = nullptr; - data_t *ws_buffers = nullptr; - if (nthr_k > 1) { - c_buffers = (data_t *)malloc(nthr_m * nthr_n * (nthr_k - 1) * MB * NB - * sizeof(data_t), PAGE_4K); - if (!c_buffers) { - nthr_k = 1; - KB = K; - } - } - - bool do_copy = (NB / unroll_factor::n > 3); - const int nthr_mn = nthr_m * nthr_n; - const int nthr = nthr_mn * nthr_k; - const size_t ws_elems_per_thr = K * unroll_factor::m; - const size_t ws_size_per_thr - = rnd_up(ws_elems_per_thr * sizeof(data_t), PAGE_4K); - if (do_copy) { - ws_buffers = (data_t*)malloc(nthr * ws_size_per_thr, PAGE_4K); - if (!ws_buffers) - do_copy = false; - } - - auto get_thr_block = [&](int &from, int &to, int &myN, int NB, int N, - int ithr) { - from = NB * (ithr); - to = NB * (ithr + 1); - if (to > N) - to = N; - myN = to - from; - }; - - parallel_nd(nthr, [&](const int ithr) { - int ithr_mn = ithr % nthr_mn; - int ithr_m = ithr_mn % nthr_m; - int ithr_n = ithr_mn / nthr_m; - int ithr_k = ithr / nthr_mn; - - int cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1); - - data_t *ws = do_copy - ? ws_buffers + ithr * ws_size_per_thr / sizeof(data_t) - : nullptr; - - int m_from = 0, m_to = 0, myM = 0, n_from = 0, n_to = 0, myN = 0, - k_from = 0, k_to = 0, myK = 0; - - get_thr_block(m_from, m_to, myM, MB, M, ithr_m); - get_thr_block(n_from, n_to, myN, NB, N, ithr_n); - get_thr_block(k_from, k_to, myK, KB, K, ithr_k); - - if (myM > 0 && myN > 0) { - data_t myBeta, *myC; - dim_t ld; - if (ithr_k == 0) { - myC = &(C[m_from + n_from * ldc]); - myBeta = beta; - ld = ldc; - } else { - myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1); - myBeta = 0.0f; - ld = MB; - } - const data_t *myA = isTransA - ? &(A[k_from + m_from * lda]) - : &(A[m_from + k_from * lda]); - const data_t *myB = isTransB - ? 
&(B[n_from + k_from * ldb]) - : &(B[k_from + n_from * ldb]); - - if (!isTransA) { - if (!isTransB) { - gemm_ithr(myM, myN, myK, alpha, myA, - lda, myB, ldb, myBeta, myC, ld, do_copy, ws); - } else { - gemm_ithr(myM, myN, myK, alpha, myA, - lda, myB, ldb, myBeta, myC, ld, do_copy, ws); - } - } else { - if (!isTransB) { - gemm_ithr(myM, myN, myK, alpha, myA, - lda, myB, ldb, myBeta, myC, ld, do_copy, ws); - } else { - gemm_ithr(myM, myN, myK, alpha, myA, - lda, myB, ldb, myBeta, myC, ld, do_copy, ws); - } - } - } - }); - - if (nthr_k > 1) { - parallel_nd(nthr, [&](const int ithr) { - int ithr_mn = ithr % nthr_mn; - int ithr_m = ithr_mn % nthr_m; - int ithr_k = ithr / nthr_mn; - int ithr_n = ithr_mn / nthr_m; - - int n_from = 0, n_to = 0, myN = 0; - int m_from = 0, m_to = 0, myM = 0; - - int cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1); - - get_thr_block(n_from, n_to, myN, NB, N, ithr_n); - get_thr_block(m_from, m_to, myM, MB, M, ithr_m); - - // sum matrices partitioned along K dimension - int offset = 0, block = 0; - gemm_utils::partition_unit_diff(ithr_k, nthr_k, myN, &offset, - &block); - for (int ik = 1; ik < nthr_k; ++ik) { - data_t *myC = c_buffers - + MB * ((dim_t)NB * (cbase + ik - 1) + offset); - - gemm_utils::sum_two_matrices(myM, block, myC, MB, - &C[m_from + (n_from + offset) * ldc], ldc); - } - }); - } - - if (bias) { - parallel_nd(N, M, [&](int i, int j) { - C[i*ldc + j] += bias[j]; - }); - } - - free(ws_buffers); - free(c_buffers); - - return mkldnn_success; -} - -template mkldnn_status_t ref_gemm( - const char *transa_, const char *transb_, - const int *M_, const int *N_, const int *K_, const float *alpha_, - const float *A, const int *lda_, const float *B, const int *ldb_, - const float *beta_, float *C, const int *ldc_, const float *bias); - -template mkldnn_status_t ref_gemm( - const char *transa_, const char *transb_, - const int *M_, const int *N_, const int *K_, const double *alpha_, - const double *A, const int *lda_, const double *B, const int *ldb_, - const double *beta_, double *C, const int *ldc_, const double *bias); -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.hpp deleted file mode 100644 index 7c90ba627..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.hpp +++ /dev/null @@ -1,36 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
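As a plain-C++ counterpart to the blocked loops in gemm_ithr/block_ker above, here is a simplified single-threaded sketch (column-major, no transposes, no copy buffer; the BM/BN/BK values are placeholders rather than the gemm_traits constants):

    #include <algorithm>

    void ref_gemm_blocked(int M, int N, int K, float alpha,
                          const float *A, int lda, const float *B, int ldb,
                          float beta, float *C, int ldc) {
        const int BM = 64, BN = 48, BK = 256;  // placeholder block sizes
        for (int bk = 0; bk < K; bk += BK) {
            int kb = std::min(K - bk, BK);
            // beta is applied only while processing the first K block; later
            // K blocks use beta == 1 so partial results are summed into C.
            float cur_beta = (bk == 0) ? beta : 1.0f;
            for (int bm = 0; bm < M; bm += BM) {
                int mb = std::min(M - bm, BM);
                for (int bn = 0; bn < N; bn += BN) {
                    int nb = std::min(N - bn, BN);
                    for (int j = 0; j < nb; ++j)
                        for (int i = 0; i < mb; ++i) {
                            float acc = 0.f;
                            for (int p = 0; p < kb; ++p)
                                acc += A[(bm + i) + (bk + p) * lda]
                                     * B[(bk + p) + (bn + j) * ldb];
                            float &c = C[(bm + i) + (bn + j) * ldc];
                            c = alpha * acc + (cur_beta == 0.f ? 0.f : cur_beta * c);
                        }
                }
            }
        }
    }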
-*******************************************************************************/ - -#ifndef REF_GEMM_F32_HPP -#define REF_GEMM_F32_HPP - -#include "mkldnn_types.h" - -namespace mkldnn { -namespace impl { -namespace cpu { - -template -mkldnn_status_t ref_gemm(const char *transa, const char *transb, const int *M, - const int *N, const int *K, const data_t *alpha, const data_t *A, - const int *lda, const data_t *B, const int *ldb, const data_t *beta, - data_t *C, const int *ldc, const data_t *bias); - -} -} -} - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.cpp deleted file mode 100644 index 3dbe07d74..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.cpp +++ /dev/null @@ -1,280 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include "mkldnn.h" - -#include "mkldnn_traits.hpp" -#include "nstl.hpp" - -#include "jit_generator.hpp" - -#include "gemm.hpp" - -#include "f32/jit_avx512_common_gemm_f32.hpp" -#include "f32/jit_avx_gemm_f32.hpp" -#include "f32/ref_gemm_f32.hpp" - -#include "s8x8s32/jit_avx512_core_gemm_s8u8s32.hpp" -#include "s8x8s32/simple_gemm_s8s8s32.hpp" -#include "s8x8s32/ref_gemm_s8x8s32.hpp" - -#include "os_blas.hpp" - -/* USE_MKL USE_CBLAS effect - * ------- --------- ------ - * yes yes use Intel(R) MKL CBLAS - * yes no use jit - * no yes system-dependent CBLAS - * no no use jit - */ - -namespace mkldnn { -namespace impl { -namespace cpu { - -mkldnn_status_t check_gemm_input(const char *transa, const char *transb, - const int *M, const int *N, const int *K, const int *lda, - const int *ldb, const int *ldc, const float *alpha, const float *beta, - const bool with_bias) { - if (utils::any_null(transa, transb, M, N, K, lda, ldb, ldc, alpha, beta)) - return mkldnn_invalid_arguments; - if (with_bias && *beta != 0) - return mkldnn_unimplemented; - bool consistency = true - && utils::one_of(*transa, 'T', 't', 'N', 'n') - && utils::one_of(*transb, 'T', 't', 'N', 'n') - && *M >= 0 - && *N >= 0 - && *K >= 0; - - if (!consistency) - return mkldnn_invalid_arguments; - bool isTransA = utils::one_of(*transa, 'T', 't'); - bool isTransB = utils::one_of(*transb, 'T', 't'); - int nrowA = isTransA ? *K : *M; - int nrowB = isTransB ? 
*N : *K; - consistency = true - && *lda >= nstl::max(1, nrowA) - && *ldb >= nstl::max(1, nrowB) - && *ldc >= nstl::max(1, *M); - if (!consistency) - return mkldnn_invalid_arguments; - - return mkldnn_success; -} - -mkldnn_status_t check_gemm_x8x8x32_input(const char *offsetc, - const char *transa, const char *transb, const int *M, const int *N, - const int *K, const int *lda, const int *ldb, const int *ldc, - const float *alpha, const float *beta, const bool with_bias) { - if (offsetc == nullptr) - return mkldnn_invalid_arguments; - if (!utils::one_of(*offsetc, 'F', 'f', 'C', 'c', 'R', 'r')) - return mkldnn_invalid_arguments; - - return check_gemm_input(transa, transb, M, N, K, lda, ldb, ldc, alpha, - beta, with_bias); -} - -mkldnn_status_t extended_sgemm(const char *transa, const char *transb, - const int *M, const int *N, const int *K, const float *alpha, - const float *A, const int *lda, const float *B, const int *ldb, - const float *beta, float *C, const int *ldc, - const float *bias, const bool force_jit_gemm) { - mkldnn_status_t status = check_gemm_input(transa, transb, M, N, K, - lda, ldb, ldc, alpha, beta, bias != nullptr); - if (status != mkldnn_success) - return status; - -#ifdef USE_CBLAS - if (!force_jit_gemm) { - bool trA = *transa == 't' || *transa == 'T'; - bool trB = *transb == 't' || *transb == 'T'; - CBLAS_TRANSPOSE Cblas_trA = trA ? CblasTrans : CblasNoTrans; - CBLAS_TRANSPOSE Cblas_trB = trB ? CblasTrans : CblasNoTrans; - cblas_sgemm(CblasColMajor, Cblas_trA, Cblas_trB, - *M, *N, *K, *alpha, A, *lda, B, *ldb, *beta, C, *ldc); - - if (bias) { - // Add bias if necessary (bias is applied to columns of C) - cblas_int incx = 1, incy = 1; - parallel_nd(*N, [&](int n) { - ptrdiff_t offset = (ptrdiff_t)n * (*ldc); - cblas_saxpy(*M, 1.0, bias, incx, C + offset, incy); - }); - } - return mkldnn_success; - } -#endif - - if (mayiuse(avx512_common)) - return jit_avx512_common_gemm_f32(transa, transb, - M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, bias); - else if (mayiuse(avx)) - return jit_avx_gemm_f32(transa, transb, - M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, bias); - else - return ref_gemm(transa, transb, - M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, bias); -} - -template -mkldnn_status_t gemm_s8x8s32(const char *transa, const char *transb, - const char *offsetc, const int *M, const int *N, const int *K, - const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao, - const b_dt *B, const int *LDB, const int8_t *bo, const float *beta, - int32_t *C, const int *LDC, const int32_t *co) { - mkldnn_status_t status = check_gemm_x8x8x32_input(offsetc, transa, transb, - M, N, K, LDA, LDB, LDC, alpha, beta, false); - if (status != mkldnn_success) - return status; - - if (*M == 0 || *N == 0 || *K == 0) - return mkldnn_success; - -#if USE_MKL_IGEMM - bool OCisR = (*offsetc == 'R' || *offsetc == 'r'); - bool OCisC = (*offsetc == 'C' || *offsetc == 'c'); - bool AisN = (*transa == 'N' || *transa == 'n'); - bool BisN = (*transb == 'N' || *transb == 'n'); - - if (data_traits::data_type == data_type::u8) { - CBLAS_TRANSPOSE Cblas_trA = AisN ? CblasNoTrans : CblasTrans; - CBLAS_TRANSPOSE Cblas_trB = BisN ? CblasNoTrans : CblasTrans; - CBLAS_OFFSET Cblas_offsetc = - OCisR - ? CblasRowOffset - : OCisC - ? 
CblasColOffset - : CblasFixOffset; - cblas_gemm_s8u8s32(CblasColMajor, Cblas_trA, Cblas_trB, Cblas_offsetc, - *M, *N, *K, *alpha, A, *LDA, *ao, (uint8_t *)B, *LDB, *bo, - *beta, C, *LDC, co); - return mkldnn_success; - } else { - assert(data_traits::data_type == data_type::s8); - // TODO CBLAS implementation of gemm_s8s8s32 goes here. - // mkldnn_gemm_s8s8s32 doesn't support non-zero ao and bo - if (utils::everyone_is(0, *ao, *bo)) { - return simple_gemm_s8s8s32(transa, transb, offsetc, M, - N, K, alpha, A, LDA, ao, (int8_t *)B, LDB, bo, beta, - C, LDC, co); - } else { - return ref_gemm_s8x8s32(transa, transb, offsetc, M, N, K, - alpha, A, LDA, ao, B, LDB, bo, beta, C, LDC, co); - } - } -#else - cpu_isa_t isa = isa_any; - if (mayiuse(avx512_core_vnni)) { - isa = avx512_core_vnni; - } else if (mayiuse(avx512_core)) { - isa = avx512_core; - } - - if (data_traits::data_type == data_type::u8) { - switch (isa) { - case avx512_core: - case avx512_core_vnni: - return jit_avx512_core_gemm_s8u8s32(transa, transb, offsetc, M, - N, K, alpha, A, LDA, ao, (uint8_t *)B, LDB, bo, beta, - C, LDC, co); - default: - return ref_gemm_s8x8s32(transa, transb, offsetc, M, N, K, - alpha, A, LDA, ao, B, LDB, bo, beta, C, LDC, co); - } - } else { - assert(data_traits::data_type == data_type::s8); - // mkldnn_gemm_s8s8s32 doesn't support non-zero ao and bo - if ((mayiuse(avx512_core) || mayiuse(avx512_core_vnni)) - && *ao == 0 && *bo == 0) { - return simple_gemm_s8s8s32(transa, transb, offsetc, M, - N, K, alpha, A, LDA, ao, (int8_t *)B, LDB, bo, beta, - C, LDC, co); - } else { - return ref_gemm_s8x8s32(transa, transb, offsetc, M, N, K, - alpha, A, LDA, ao, B, LDB, bo, beta, C, LDC, co); - } - } -#endif -} - -template -mkldnn_status_t gemm_s8x8s32(const char *transa, const char *transb, - const char *offsetc, const int *M, const int *N, const int *K, - const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao, - const int8_t *B, const int *LDB, const int8_t *bo, const float *beta, - int32_t *C, const int *LDC, const int32_t *co); - -template -mkldnn_status_t gemm_s8x8s32(const char *transa, const char *transb, - const char *offsetc, const int *M, const int *N, const int *K, - const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao, - const uint8_t *B, const int *LDB, const int8_t *bo, const float *beta, - int32_t *C, const int *LDC, const int32_t *co); - -} -} -} - -using namespace mkldnn::impl; -using namespace mkldnn::impl::cpu; - -mkldnn_status_t mkldnn_sgemm(const char *transa, const char *transb, - const int64_t *M, const int64_t *N, const int64_t *K, const float *alpha, - const float *A, const int64_t *lda, const float *B, const int64_t *ldb, - const float *beta, float *C, const int64_t *ldc) { - int M_s32 = (int)*M; - int N_s32 = (int)*N; - int K_s32 = (int)*K; - int lda_s32 = (int)*lda; - int ldb_s32 = (int)*ldb; - int ldc_s32 = (int)*ldc; - - return extended_sgemm(transa, transb, &M_s32, &N_s32, &K_s32, - alpha, A, &lda_s32, B, &ldb_s32, beta, C, &ldc_s32); -} - -mkldnn_status_t mkldnn_gemm_s8u8s32(const char *transa, const char *transb, - const char *offsetc, const int64_t *M, const int64_t *N, const int64_t *K, - const float *alpha, const int8_t *A, const int64_t *lda, const int8_t *ao, - const uint8_t *B, const int64_t *ldb, const int8_t *bo, const float *beta, - int32_t *C, const int64_t *ldc, const int32_t *co) { - int M_s32 = (int)*M; - int N_s32 = (int)*N; - int K_s32 = (int)*K; - int lda_s32 = (int)*lda; - int ldb_s32 = (int)*ldb; - int ldc_s32 = (int)*ldc; - return 
gemm_s8x8s32(transa, transb, offsetc, &M_s32, &N_s32, &K_s32, - alpha, A, &lda_s32, ao, B, &ldb_s32, bo, beta, C, &ldc_s32, co); -} - -mkldnn_status_t mkldnn_gemm_s8s8s32(const char *transa, const char *transb, - const char *offsetc, const int64_t *M, const int64_t *N, const int64_t *K, - const float *alpha, const int8_t *A, const int64_t *lda, const int8_t *ao, - const int8_t *B, const int64_t *ldb, const int8_t *bo, const float *beta, - int32_t *C, const int64_t *ldc, const int32_t *co) { - int M_s32 = (int)*M; - int N_s32 = (int)*N; - int K_s32 = (int)*K; - int lda_s32 = (int)*lda; - int ldb_s32 = (int)*ldb; - int ldc_s32 = (int)*ldc; - - return gemm_s8x8s32(transa, transb, offsetc, &M_s32, &N_s32, &K_s32, - alpha, A, &lda_s32, ao, B, &ldb_s32, bo, beta, C, &ldc_s32, co); -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.hpp deleted file mode 100644 index dc15ff713..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.hpp +++ /dev/null @@ -1,58 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef GEMM_HPP -#define GEMM_HPP - -#include "mkldnn_types.h" -#include "os_blas.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -mkldnn_status_t extended_sgemm(const char *transa, const char *transb, - const int *M, const int *N, const int *K, const float *alpha, - const float *A, const int *lda, const float *B, const int *ldb, - const float *beta, float *C, const int *ldc, - const float *bias = nullptr, bool force_jit_gemm = false); - -template -mkldnn_status_t gemm_s8x8s32(const char *transa, const char *transb, - const char *offsetc, const int *M, const int *N, const int *K, - const float *alpha, const int8_t *A, const int *lda, const int8_t *ao, - const b_dt *B, const int *ldb, const int8_t *bo, const float *beta, - int32_t *c, const int *ldc, const int32_t *co); - -#ifdef USE_CBLAS -#define GEMM_IMPL_STR "gemm:blas" -#else -#define GEMM_IMPL_STR "gemm:jit" -#endif - -#if USE_MKL_IGEMM -#define IGEMM_S8U8S32_IMPL_STR "igemm_s8u8s32:blas" -#define IGEMM_S8S8S32_IMPL_STR "igemm_s8s8s32:blas" -#else -#define IGEMM_S8U8S32_IMPL_STR "igemm_s8u8s32:jit" -#define IGEMM_S8S8S32_IMPL_STR "igemm_s8s8s32:jit" -#endif - -} -} -} - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/os_blas.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/os_blas.hpp deleted file mode 100644 index 4d34ede0b..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/os_blas.hpp +++ /dev/null @@ -1,86 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. 
-* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef OS_BLAS_HPP -#define OS_BLAS_HPP - -/** \file - * Common stuff respecting USE_MKL and USE_CBLAS compile flags - * - * USE_MKL USE_CBLAS effect - * ------- --------- ------ - * yes yes normal compile: jit *may* be preferred over Intel(R) MKL CBLAS - * yes no jit calls OK; assert if cblas is ever called - * no yes system-dependent CBLAS - * no no gemm convolution (or other blas) N/A; create stubs - */ - -#if defined(USE_MKL) - -#include "mkl_version.h" - -#define USE_MKL_PACKED_GEMM (INTEL_MKL_VERSION >= 20190001) -#define USE_MKL_IGEMM \ - (INTEL_MKL_VERSION >= 20180000 && __INTEL_MKL_BUILD_DATE >= 20170628) - -#include "mkl_cblas.h" -#if !defined(USE_CBLAS) -#define cblas_sgemm(...) assert(!"CBLAS is unavailable") -#endif - -#else /* defined(USE_MKL) */ - -#define USE_MKL_PACKED_GEMM 0 -#define USE_MKL_IGEMM 0 - -#if defined(_SX) -/* TODO: _SX should also define USE_CBLAS in case the later is available */ -extern "C" { -#include "cblas.h" // CHECK: does SX also have a fortran API sgemm? -} - -#elif defined(USE_CBLAS) -#include "cblas.h" // Maybe a system/cmake cblas works for you? -#else -/* put the stubs to make a code compilable but not workable */ -#define cblas_sgemm(...) assert(!"CBLAS is unavailable") -#endif /* defined(_SX) */ - -#endif /* defined(USE_MKL) */ - -namespace mkldnn { -namespace impl { -namespace cpu { - -#if defined(USE_MKL) && defined(USE_CBLAS) -typedef MKL_INT cblas_int; - -#elif defined(USE_CBLAS) -typedef int cblas_int; - -#if defined(_SX) -/* this cblas.h is peculiar... */ -typedef CBLAS_ORDER CBLAS_LAYOUT; -#endif -#endif - -} -} -} - -#endif /* OS_BLAS_HPP */ - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/common.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/common.hpp deleted file mode 100644 index dde72f4a1..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/common.hpp +++ /dev/null @@ -1,206 +0,0 @@ -/******************************************************************************* -* Copyright 2019 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#ifndef COMMON_H -#define COMMON_H - -#define GEMM_CODE_SIZE (4096L * 32) - -#define AVX512_UNROLL_M 48 -#define AVX512_UNROLL_N 8 -#define AVX512_UNROLL_K 1 -#define AVX512_BM 9984 -#define AVX512_BN 384 -#define AVX512_BK 768 -#define AVX512_BK_VNNI 1536 -#define AVX512_BK_TRADITIONAL 384 -#define AVX512_BLOCKING_SMALL_K 48 -#define AVX512_BN_SMALL_K 24 - - -#define PAGESIZE 4096 - -#define PADD_BYTESIZE_ONPAGE(x, size) (((x) * (size) + PAGESIZE - 1) / PAGESIZE) * PAGESIZE -#define NEXT_THR_STRIDE(x, size) (PADD_BYTESIZE_ONPAGE(x, size)) / size - -#include "jit_generator.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -enum { - PARTITION_1D_ROW, - PARTITION_1D_COL, - PARTITION_2D_COL_MAJOR, - PARTITION_2D = PARTITION_2D_COL_MAJOR, -}; - -enum { - COPY_NONE, - COPY_A, -}; - -enum { - NO_OFFSET, - FIX_OFFSET, - COL_OFFSET, - ROW_OFFSET, -}; - -// Alias for any dimension related variable. -typedef long long int dim_t; - -typedef struct { - // Interface arguments. - int transa, transb, offsetc; - dim_t m, n, k; - dim_t lda, ldb, ldc; - const int8_t *a; - const uint8_t *b; - int32_t *c; - const float *alpha, *beta; - - int8_t ao, bo; - const int32_t *co; - - // Kernel parameters. - dim_t um, un, uk, bm, bn, bk; - dim_t bn_small_k, bk_traditional, blocking_small_k; - - int (*copyA)(const dim_t *m, const dim_t *n, const int8_t *a, - const dim_t *lda, const int8_t *alpha, int8_t *b, - const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum); - - int (*copyB)(const dim_t *m, const dim_t *n, const uint8_t *a, - const dim_t *lda, const uint8_t *alpha, uint8_t *b, - const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum); - - int (*kernel)(const dim_t *m, const dim_t *n, const dim_t *k, - const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, - const dim_t ldc, const int32_t *col_offset, - const int32_t *row_offset); - - int (*kernel_b)(const dim_t *m, const dim_t *n, const dim_t *k, - const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, - const dim_t ldc, const int32_t *col_offset, - const int32_t *row_offset); - - int (*kernel_r)(const dim_t *m, const dim_t *n, const dim_t *k, - const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, - const dim_t ldc, const int32_t *col_offset, - const int32_t *row_offset); - - int (*kernel_c)(const dim_t *m, const dim_t *n, const dim_t *k, - const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, - const dim_t ldc, const int32_t *col_offset, - const int32_t *row_offset); - - int (*kernel_b0)(const dim_t *m, const dim_t *n, const dim_t *k, - const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, - const dim_t ldc, const int32_t *col_offset, - const int32_t *row_offset); - - int (*kernel_b0_b)(const dim_t *m, const dim_t *n, const dim_t *k, - const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, - const dim_t ldc, const int32_t *col_offset, - const int32_t *row_offset); - - int (*kernel_b0_r)(const dim_t *m, const dim_t *n, const dim_t *k, - const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, - const dim_t ldc, const int32_t *col_offset, - const int32_t *row_offset); - - int (*kernel_b0_c)(const dim_t *m, const dim_t *n, const dim_t *k, - const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, - const dim_t ldc, const int32_t *col_offset, - const int32_t *row_offset); - - // Gemv kernels - void (*gemv_s8u8s32_kernel)(const dim_t, const dim_t, const 
float, - const int8_t*, const dim_t, const uint8_t*, - const float, int32_t*); - - void (*gemv_u8s8s32_kernel)(const dim_t, const dim_t, const float, - const uint8_t*, const dim_t, const int8_t*, - const float, int32_t*); - - // Gemv parameters - int swap; - -} blas_t; - - -class jit_avx512_core_u8_copy_an_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_an_kern); - - public: - jit_avx512_core_u8_copy_an_kern(); -}; - -class jit_avx512_core_u8_copy_at_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_at_kern); - - public: - jit_avx512_core_u8_copy_at_kern(); -}; - -class jit_avx512_core_u8_copy_bn_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_bn_kern); - - public: - jit_avx512_core_u8_copy_bn_kern(); -}; - -class jit_avx512_core_u8_copy_bt_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_bt_kern); - - public: - jit_avx512_core_u8_copy_bt_kern(); -}; - -class jit_avx512_core_u8_copy_sum_an_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_an_kern); - - public: - jit_avx512_core_u8_copy_sum_an_kern(); -}; - -class jit_avx512_core_u8_copy_sum_at_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_at_kern); - - public: - jit_avx512_core_u8_copy_sum_at_kern(); -}; - -class jit_avx512_core_u8_copy_sum_bn_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_bn_kern); - - public: - jit_avx512_core_u8_copy_sum_bn_kern(); -}; - -class jit_avx512_core_u8_copy_sum_bt_kern : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_bt_kern); - - public: - jit_avx512_core_u8_copy_sum_bt_kern(); -}; - -} -} -} -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/gemv.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/gemv.hpp deleted file mode 100644 index db9dd9ef9..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/gemv.hpp +++ /dev/null @@ -1,28 +0,0 @@ -/******************************************************************************* -* Copyright 2019 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#include "common.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -int gemm_s8u8s32_jump_to_gemv_s8u8s32(blas_t *arg); -int gemv_threading_driver(blas_t *arg); - -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.cpp deleted file mode 100644 index e4b8e1cde..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.cpp +++ /dev/null @@ -1,1409 +0,0 @@ -/******************************************************************************* -* Copyright 2019 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include -#include - -#include "common.hpp" -#include "mkldnn_types.h" -#include "nstl.hpp" -#include "utils.hpp" - -#include "jit_avx512_core_gemm_s8u8s32.hpp" -#include "jit_avx512_core_gemm_s8u8s32_kern.hpp" -#include "jit_avx512_core_kernel_gemv_s8u8s32_kern.hpp" -#include "gemv.hpp" - -#if defined(_MSC_VER) -#include -#endif - -namespace mkldnn { -namespace impl { -namespace cpu { - -typedef struct { - int nthrs_m, nthrs_n; - int partition; - int copy_type; -} blas_thread_t; - -static inline void round_to_nearest(int32_t *rounded_val, double fp_val) { - if (fp_val >= 0.) 
{ - fp_val += 0.5; - if (fp_val > INT32_MAX) { - fp_val = INT32_MAX; - } - } else { - fp_val -= 0.5; - if (fp_val < INT32_MIN) { - fp_val = INT32_MIN; - } - } - *rounded_val = (int32_t) fp_val; -} - -static inline void add_results(const dim_t m, const dim_t n, const dim_t k, - const float alpha, const float beta, const int32_t *c_partial_sum, - const dim_t ldcp, int32_t *c_data, const dim_t ldc, - const int32_t *a_row_sum, const int32_t *b_col_sum, const int8_t ao, - const int8_t bo, const int32_t *co, const int offsetc) -{ - for (dim_t j = 0; j < n; ++j) { - for (dim_t i = 0; i < m; ++i) { - int32_t ctemp = c_partial_sum[i + j * ldcp]; - - if (alpha == 1.0f) { - if (beta == 0.0f) { - c_data[i + j * ldc] = ctemp; - } else { - double c_float = (double) beta - * (double) c_data[i + j * ldc]; - c_float += (double) ctemp; - round_to_nearest(&c_data[i + j * ldc], c_float); - } - } else if (alpha == -1.0f) { - if (beta == 0.0f) { - c_data[i + j * ldc] = -ctemp; - } else { - double c_float = (double) beta - * (double) c_data[i + j * ldc]; - c_float -= (double) ctemp; - round_to_nearest(&c_data[i + j * ldc], c_float); - } - } else { - if (beta == 0.0f) { - double c_float = alpha * (double) ctemp; - round_to_nearest(&c_data[i + j * ldc], c_float); - } else { - double c_float = alpha * (double) ctemp + - beta * (double) c_data[i + j * ldc]; - round_to_nearest(&c_data[i + j * ldc], c_float); - } - } - - if (offsetc == FIX_OFFSET) { - c_data[i + j * ldc] += co[0]; - } else if (offsetc == ROW_OFFSET) { - c_data[i + j * ldc] += co[j]; - } else if (offsetc == COL_OFFSET) { - c_data[i + j * ldc] += co[i]; - } - } - } -} - -// TODO Find a better place for those functions. -static inline dim_t ld_padd(const dim_t x) -{ - return ((x + ((2048 / sizeof(int32_t)) - 1)) / (2048 / sizeof(int32_t))) - * (2048 / sizeof(int32_t)) + (64 / sizeof(int32_t)); -} - -void igemm_inner_kernel(const dim_t m, const dim_t n, const dim_t k, - const int8_t *a, const uint8_t *b, float beta, int32_t *c, - const dim_t ldc, const int32_t *a_row_sum, const int32_t *b_col_sum, - const int32_t *co, const int offsetc, const blas_t *arg) -{ - int8_t ao = arg->ao; - int8_t bo = arg->bo; - int32_t co_0 = (offsetc == NO_OFFSET)? 
0 : co[0]; - - // Since m and n are limited by blocking, stack overflow may not happen; - // it's up to 32kB -#if !defined(_MSC_VER) - int32_t col_offset[m]; - int32_t row_offset[n]; -#else - int32_t *col_offset = (int32_t *) _alloca(sizeof(*col_offset) * m); - int32_t *row_offset = (int32_t *) _alloca(sizeof(*row_offset) * n); -#endif - - int col_req = 0; - int row_req = 0; - - if ((bo != 0) || (offsetc == COL_OFFSET)) - col_req = 1; - if ((ao != 0) || (offsetc == ROW_OFFSET)) - row_req = 1; - - // It needs one of colum or row offsets, but it doesn't need both - if (((ao != 0) && (bo != 0)) || ((offsetc == FIX_OFFSET) && (co_0 != 0))) { - if ((col_req == 0) && (row_req == 0)) { - if (m <= n) { - col_req = 1; - } else { - row_req = 1; - } - } - } - - if (col_req) { - for (dim_t i = 0; i < m; i++) - col_offset[i] = 0; - - if (offsetc == COL_OFFSET) { - for (dim_t i = 0; i < m; i++) - col_offset[i] += co[i]; - } - - if (bo != 0) { - for (dim_t i = 0; i < m; i++) - col_offset[i] += bo * a_row_sum[i]; - } - } - - if (row_req) { - for (dim_t i = 0; i < n; i++) - row_offset[i] = 0; - - if (offsetc == ROW_OFFSET) { - for (dim_t i = 0; i < n; i++) - row_offset[i] += co[i]; - } - - if (ao != 0) { - for (dim_t i = 0; i < n; i++) - row_offset[i] += ao * b_col_sum[i]; - } - } - - if ((offsetc == FIX_OFFSET) && (co_0 != 0)) { - if (col_req) { - for (dim_t i = 0; i < m; i++) - col_offset[i] += co_0; - } else { - for (dim_t i = 0; i < n; i++) - row_offset[i] += co_0; - } - } - - if ((ao != 0) && (bo != 0)) { - if (col_req) { - for (dim_t i = 0; i < m; i++) - col_offset[i] += (int32_t) k * ao * bo; - } else { - for (dim_t i = 0; i < n; i++) - row_offset[i] += (int32_t) k * ao * bo; - } - } - - if (col_req == 0) { - if (row_req == 0) { - if (beta == 0.0) { - arg->kernel_b0(&m, &n, &k, NULL, a, b, c, ldc, col_offset, - row_offset); - } else { - arg->kernel(&m, &n, &k, NULL, a, b, c, ldc, col_offset, - row_offset); - } - } else { - if (beta == 0.0) { - arg->kernel_b0_r(&m, &n, &k, NULL, a, b, c, ldc, col_offset, - row_offset); - } else { - arg->kernel_r(&m, &n, &k, NULL, a, b, c, ldc, col_offset, - row_offset); - } - } - } else { - if (row_req == 0) { - if (beta == 0.0) { - arg->kernel_b0_c(&m, &n, &k, NULL, a, b, c, ldc, col_offset, - row_offset); - } else { - arg->kernel_c(&m, &n, &k, NULL, a, b, c, ldc, col_offset, - row_offset); - } - } else { - if (beta == 0.0) { - arg->kernel_b0_b(&m, &n, &k, NULL, a, b, c, ldc, col_offset, - row_offset); - } else { - arg->kernel_b(&m, &n, &k, NULL, a, b, c, ldc, col_offset, - row_offset); - } - } - } -} - -static inline void *align(void *ptr, size_t alignment) -{ - return (void *) utils::rnd_up((uintptr_t) ptr, alignment); -} - -static int gemm_kernel_driver(const dim_t m, const dim_t n, const dim_t k, - const int8_t *a, const uint8_t *b, int32_t *c, const int32_t *co, - const blas_t *arg) -{ - dim_t lda = arg->lda; - dim_t ldb = arg->ldb; - dim_t ldc = arg->ldc; - int8_t ao = arg->ao; - int8_t bo = arg->bo; - float alpha = *arg->alpha; - float beta = *arg->beta; - - if (m <= 0 || n <= 0) { - return 0; - } - - // Padding along K dimension. - dim_t k_padd = 0; - if (k <= arg->bk_traditional) { - k_padd = utils::rnd_up(k, arg->uk); - k_padd = nstl::max(128LL, k_padd); - } else if (k < 2 * arg->bk) { - k_padd = utils::rnd_up(k / 2, arg->uk); - } else { - k_padd = arg->bk; - } - - // Padding along M dimension. - dim_t m_padd = utils::rnd_up(nstl::min(nstl::max(m, arg->um), arg->bm), - arg->um); - - // Padding along N dimension. 
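For reference, the K- and M-dimension padding chosen just above (and reused by parallel_a_copy further down) follows a simple cache-blocking heuristic. The standalone sketch below restates it outside the deleted sources; rnd_up is assumed to round its first argument up to a multiple of the second, and uk/um/bk/bm/bk_traditional stand in for the unroll and block parameters carried in blas_t.

    // Minimal sketch of the blocking heuristic above; illustrative only.
    #include <algorithm>

    using dim_t = long long int;

    static dim_t rnd_up(dim_t x, dim_t y) { return ((x + y - 1) / y) * y; }

    // K blocking: small K is rounded up (at least 128), medium K is halved,
    // large K uses the nominal block size bk.
    static dim_t pick_k_block(dim_t k, dim_t uk, dim_t bk, dim_t bk_traditional) {
        if (k <= bk_traditional)
            return std::max<dim_t>(128, rnd_up(k, uk));
        if (k < 2 * bk)
            return rnd_up(k / 2, uk);
        return bk;
    }

    // M blocking: clamp M between one unroll and the nominal block, then round
    // to the unroll so the copy kernels always see full micro-panels.
    static dim_t pick_m_block(dim_t m, dim_t um, dim_t bm) {
        return rnd_up(std::min(std::max(m, um), bm), um);
    }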
- dim_t n_padd = 0; - if (k < arg->blocking_small_k) { - n_padd = utils::rnd_up(nstl::min(nstl::max(n, arg->un), - arg->bn_small_k), arg->un); - } else { - n_padd = utils::rnd_up(nstl::min(nstl::max(n, arg->un), arg->bn), - arg->un); - } - - // Padding for temporary buffer for C - dim_t ldc_buf = ld_padd(m_padd); - - dim_t strideAm = (arg->transa == 0)? 1 : lda; - dim_t strideAn = (arg->transa != 0)? 1 : lda; - dim_t strideBm = (arg->transb == 0)? 1 : ldb; - dim_t strideBn = (arg->transb != 0)? 1 : ldb; - - size_t a_buf_nelems = m_padd * k_padd; - size_t b_buf_nelems = k_padd * n_padd; - size_t a_row_sum_nelems = m_padd; - size_t b_col_sum_nelems = n_padd; - - size_t mem_size = a_buf_nelems * sizeof(*a) + PAGE_4K - + b_buf_nelems * sizeof(*b) + PAGE_4K - + a_row_sum_nelems * sizeof(*c) + PAGE_4K - + b_col_sum_nelems * sizeof(*c) + PAGE_4K; - - bool need_c_buffer = alpha != 1.0f || (beta != 1 && beta != 0); - if (need_c_buffer) { - size_t c_buf_nelems = ldc_buf * n_padd; - mem_size += c_buf_nelems * sizeof(*c) + PAGE_4K; - } - - char *mem = (char *) malloc(mem_size, 128); - - if (!mem) { - return -1; - } - - int8_t *bufferA = (int8_t *) align(mem, PAGE_4K); - uint8_t *bufferB = (uint8_t *) align(bufferA + a_buf_nelems, PAGE_4K); - int32_t *a_row_sum = (int32_t *) align(bufferB + b_buf_nelems, PAGE_4K); - int32_t *b_col_sum = (int32_t *) align(a_row_sum + a_row_sum_nelems, - PAGE_4K); - - int32_t *bufferC = NULL; - if (need_c_buffer) { - bufferC = (int32_t *) align(b_col_sum + b_col_sum_nelems, PAGE_4K); - } - - float beta_saved = beta; - - int a_block_copied = 0; - dim_t sizeM = 0; - for (dim_t Bm = 0; Bm < m; Bm += sizeM) { - sizeM = m - Bm; - if (sizeM > m_padd) - sizeM = m_padd; - - dim_t sizeK = 0; - for (dim_t Bk = 0; Bk < k; Bk += sizeK) { - sizeK = k - Bk; - if (sizeK > k_padd) - sizeK = k_padd; - - // Scale C blocks by beta only for the first time - if (Bk == 0) - beta = beta_saved; - else - beta = 1.0f; - - // Apply C offset when to the last k-block of the partial sum. - int offsetc = NO_OFFSET; - if (Bk + sizeK == k) - offsetc = arg->offsetc; - - dim_t sizeN = 0; - for (dim_t Bn = 0; Bn < n; Bn += sizeN) { - sizeN = n - Bn; - if (sizeN > n_padd) - sizeN = n_padd; - - const uint8_t *b_block = b + Bk * strideBm + Bn * strideBn; - arg->copyB(&sizeK, &sizeN, b_block, &ldb, NULL, bufferB, NULL, - NULL, b_col_sum); - - dim_t sizeUM = 0; - for (dim_t Um = 0; Um < sizeM; Um += sizeUM) { - sizeUM = sizeM - Um; - if (sizeUM > arg->um) - sizeUM = arg->um; - - /* - * Use the whole A buffer only if we have multiple B blocks - * for k-dimension, otherwise we are wasting cache to store - * B and C blocks. - */ - dim_t Um_forA = 0; - if (sizeN < n) - Um_forA = Um; - - const int8_t *a_block = a + (Bm + Um) * strideAm - + Bk * strideAn; - if (!a_block_copied) { - arg->copyA(&sizeK, &sizeUM, a_block, &lda, NULL, - bufferA + Um_forA * sizeK, NULL, NULL, - a_row_sum + Um_forA); - } - - int32_t *c_block = c + (Bm + Um) + Bn * ldc; - dim_t co_stride = 0; - if (offsetc == FIX_OFFSET) { - co_stride = 0; - } else if (offsetc == ROW_OFFSET) { - co_stride = Bn; - } else if (offsetc == COL_OFFSET) { - co_stride = Bm + Um; - } - if (need_c_buffer) { - igemm_inner_kernel(sizeUM, sizeN, sizeK, - bufferA + Um_forA * sizeK, bufferB, 0.0f, - bufferC + Um, ldc_buf, a_row_sum + Um_forA, - b_col_sum, NULL, NO_OFFSET, arg); - - // Finish the block adding the necessary alpha, beta - // and offsets. 
- add_results(sizeUM, sizeN, sizeK, alpha, beta, - bufferC + Um, ldc_buf, c_block, ldc, - a_row_sum + Um_forA, b_col_sum, ao, bo, - co + co_stride, offsetc); - } else { - igemm_inner_kernel(sizeUM, sizeN, sizeK, - bufferA + Um_forA * sizeK, bufferB, beta, - c_block, ldc, a_row_sum + Um_forA, b_col_sum, - co + co_stride, offsetc, arg); - } - } - a_block_copied = 1; - } - a_block_copied = 0; - } - } - - free(mem); - - return 0; -} - -static int kernel_driver_parallel_acopiedbcopy(const dim_t m, const dim_t n, - const dim_t k, const int8_t *bufferA, const uint8_t *b, - const float beta, int32_t *c, const int offsetc, const int32_t *co, - const int32_t *a_row_sum, const blas_t *arg) -{ - dim_t ldb = arg->ldb; - dim_t ldc = arg->ldc; - int8_t ao = arg->ao; - int8_t bo = arg->bo; - float alpha = *arg->alpha; - - if (m <= 0 || n <= 0) { - return 0; - } - - // Padding along N dimension. - dim_t n_padd = 0; - if (k < arg->blocking_small_k) { - n_padd = utils::rnd_up(nstl::min(nstl::max(n, arg->un), - arg->bn_small_k), arg->un); - } else { - n_padd = utils::rnd_up(nstl::min(nstl::max(n, arg->un), arg->bn), - arg->un); - } - - // Padding for temporary buffer for C - dim_t ldc_buf = ld_padd(m); - - dim_t strideBn = (arg->transb != 0)? 1 : ldb; - - size_t b_buf_nelems = k * n_padd; - size_t b_col_sum_nelems = n_padd; - - size_t mem_size = b_buf_nelems * sizeof(*b) + PAGE_4K - + b_col_sum_nelems * sizeof(*c) + PAGE_4K; - - bool need_c_buffer = alpha != 1.0f || (beta != 1 && beta != 0); - if (need_c_buffer) { - size_t c_buf_nelems = ldc_buf * n_padd; - mem_size += c_buf_nelems * sizeof(*c) + PAGE_4K; - } - - char *mem = (char *) malloc(mem_size, 128); - - if (!mem) { - return -1; - } - - uint8_t *bufferB = (uint8_t *) align(mem, PAGE_4K); - int32_t *b_col_sum = (int32_t *) align(bufferB + b_buf_nelems, PAGE_4K); - - int32_t *bufferC = NULL; - if (need_c_buffer) { - bufferC = (int32_t *) align(b_col_sum + b_col_sum_nelems, PAGE_4K); - } - - dim_t sizeN = 0; - for (dim_t Bn = 0; Bn < n; Bn += sizeN) { - sizeN = n - Bn; - if (sizeN > n_padd) - sizeN = n_padd; - - // Implement the kernel here. - const uint8_t *b_block = b + Bn * strideBn; - arg->copyB(&k, &sizeN, b_block, &ldb, NULL, bufferB, NULL, NULL, - b_col_sum); - - dim_t co_stride = 0; - if (offsetc == FIX_OFFSET) { - co_stride = 0; - } else if (offsetc == ROW_OFFSET) { - co_stride = Bn; - } else if (offsetc == COL_OFFSET) { - co_stride = 0; - } - int32_t *c_block = c + Bn * ldc; - if (need_c_buffer) { - igemm_inner_kernel(m, sizeN, k, bufferA, bufferB, 0.0f, bufferC, - ldc_buf, a_row_sum, b_col_sum, NULL, NO_OFFSET, arg); - - // Finish the block adding the necessary alpha, beta and offsets. - add_results(m, sizeN, k, alpha, beta, bufferC, ldc_buf, c_block, - ldc, a_row_sum, b_col_sum, ao, bo, co + co_stride, - offsetc); - } else { - igemm_inner_kernel(m, sizeN, k, bufferA, bufferB, beta, c_block, - ldc, a_row_sum, b_col_sum, co + co_stride, offsetc, arg); - } - } - - free(mem); - - return 0; - -} - -#define N2D_MAX_AVX512 384 -#define M2D_MIN_AVX512 384 -#define VECLEN 16 -#define NCONS 1 -static inline void set_thread_opts_avx512(int *p_nthrs, - blas_thread_t *thread_info, const blas_t *arg) -{ - int nthrs = *p_nthrs; - dim_t m = arg->m; - dim_t n = arg->n; - - thread_info->nthrs_m = 0; - thread_info->nthrs_n = 0; - thread_info->copy_type = COPY_NONE; // By default don't do parallel copy. 
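When alpha != 1 or beta is neither 0 nor 1, the drivers above first accumulate into a zero-initialised int32 scratch buffer and then let add_results() fold in alpha, beta and the C offsets with saturating round-to-nearest. A scalar sketch of that finishing step (general case only, ignoring the alpha == ±1 fast paths) is shown below; it is an illustration, not part of the deleted files.

    // Illustrative scalar version of the add_results() finishing step above.
    #include <cstdint>
    #include <limits>

    // Saturating round-to-nearest (ties away from zero), as in round_to_nearest().
    static int32_t saturate_round(double v) {
        v = (v >= 0.0) ? v + 0.5 : v - 0.5;
        if (v > std::numeric_limits<int32_t>::max()) return std::numeric_limits<int32_t>::max();
        if (v < std::numeric_limits<int32_t>::min()) return std::numeric_limits<int32_t>::min();
        return static_cast<int32_t>(v);
    }

    // partial: int32 partial sum computed with beta = 0 into the scratch buffer;
    // c: destination element of C; offset: the resolved co[...] term
    // (fixed, per-row or per-column), added after rounding as in the original.
    static void finish_element(int32_t &c, int32_t partial, float alpha, float beta,
                               int32_t offset) {
        double v = double(alpha) * double(partial) + double(beta) * double(c);
        c = saturate_round(v) + offset;
    }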
- - int condition_2D_bsrc = -1; - if ((256 * m > nthrs * n) && (nthrs * m < 256 * n)) { - condition_2D_bsrc = 1; - } else { - condition_2D_bsrc = 0; - } - - int condition_1D_copya = 0; - if ((m >= 1000) && (n >= nthrs * N2D_MAX_AVX512 / 4)) { - condition_2D_bsrc = 0; - condition_1D_copya = 1; - } - - // If offset is non-zero, we need to keep 1D_copya to reduce update overhead - if (arg->ao != 0 || arg->bo != 0 || arg->co[0] != 0 - || arg->offsetc != FIX_OFFSET) { - condition_2D_bsrc = 0; - condition_1D_copya = 1; - } - - if (condition_2D_bsrc == 1) { - int nthrs_m = 1; - int nthrs_n = nthrs; - - while ((nthrs_n % 2 == 0) && - (n / nthrs > N2D_MAX_AVX512 || - n / nthrs_n <= N2D_MAX_AVX512 / 2) && - (m / nthrs_m >= 2 * M2D_MIN_AVX512) && - (nthrs_m < 4)) { - nthrs_m *= 2; - nthrs_n /= 2; - } - - thread_info->nthrs_m = nthrs_m; - thread_info->nthrs_n = nthrs_n; - thread_info->partition = PARTITION_2D; - - // Reset the total number of threads that will be used. - *p_nthrs = nthrs_m * nthrs_n; - - } else if (condition_1D_copya && mkldnn_thr_syncable()) { - // Use parallel copy A algorithm - thread_info->copy_type = COPY_A; - thread_info->partition = PARTITION_1D_COL; - } else { - if ((m > n) && (m / nthrs >= VECLEN || n < NCONS * nthrs)) { - thread_info->partition = PARTITION_1D_ROW; - } else { - thread_info->partition = PARTITION_1D_COL; - } - } -} -#undef N2D_MAX_AVX512 -#undef M2D_MIN_AVX512 -#undef VECLEN -#undef NCONS - -static inline void partition_1d(const int ithr, const int nthrs, const dim_t n, - dim_t *t_offset, dim_t *t_block) -{ - dim_t band = n / nthrs; - - dim_t tail = n - (nthrs - 1) * band; - if (tail > (band + 1)) - band++; - tail = n - (nthrs - 1) * band; - - if (ithr < (nthrs - 1)) - *t_block = band; - else - *t_block = tail; - - *t_offset = ithr * band; - - if (*t_offset >= n) { - *t_block = 0; - *t_offset = 0; - } else if ((*t_offset + *t_block) > n) { - *t_block = n - *t_offset; - } -} - -static inline void partition_2d(const int ithr, int *nthrs, const int ithr_i, - const int ithr_j, const int nthrs_m, const int nthrs_n, const dim_t m, - const dim_t n, dim_t *p_m_disp, dim_t *p_m_band, dim_t *p_n_disp, - dim_t *p_n_band) -{ - dim_t m_disp = 0, n_disp = 0; - dim_t m_band = 0, n_band = 0; - - int mdiv = nthrs_m; - int ndiv = nthrs_n; - - dim_t m_bandt = m / mdiv; /* size per thread */ - dim_t n_bandt = n / ndiv; /* size per thread */ - int firstmgroup = mdiv - 1; - int firstngroup = ndiv - 1; - dim_t firstmval = m_bandt; - dim_t firstnval = n_bandt; - - int mthr_used = mdiv; - if (m - (mdiv - 1) * m_bandt > m_bandt + 1) { - if (m - (mdiv - 1) * m_bandt > mdiv) - ++m_bandt; - - firstmval = m_bandt + 1; - mthr_used = (int) (m / firstmval); - - if (mthr_used * firstmval < m) - ++mthr_used; - - firstmgroup = mthr_used - 1; - } - - int nthr_used = ndiv; - if (n - (ndiv - 1) * n_bandt > n_bandt + 1) { - firstnval = n_bandt + 1; - nthr_used = (int) (n / firstnval); - - if (nthr_used * firstnval < n) - ++nthr_used; - - firstngroup = nthr_used - 1; - } - - *nthrs = mthr_used * nthr_used; - - if (ithr < *nthrs) { - if (ithr_i < firstmgroup) { - m_band = firstmval; - m_disp = ithr_i * firstmval; - } else if (ithr_i <= mthr_used - 2) { - m_band = m_bandt; - m_disp = firstmgroup * firstmval + (ithr_i - firstmgroup) * m_bandt; - } else { - m_disp = firstmgroup * firstmval - + (mthr_used - 1 - firstmgroup) * m_bandt; - m_band = nstl::max(0LL, m - m_disp); - } - - if (ithr_j < firstngroup) { - n_band = firstnval; - n_disp = ithr_j * firstnval; - } else if (ithr_j <= nthr_used - 2) { - 
n_band = n_bandt; - n_disp = firstngroup * firstnval + (ithr_j - firstngroup) * n_bandt; - } else { - n_disp = firstngroup * firstnval - + (nthr_used - 1 - firstngroup) * n_bandt; - n_band = nstl::max(0LL, n - n_disp); - } - m_disp = nstl::max(nstl::min(m_disp, m - 1), 0LL); - n_disp = nstl::max(nstl::min(n_disp, n - 1), 0LL); - } - - if (ithr < *nthrs) { - *p_m_disp = m_disp; - *p_n_disp = n_disp; - *p_m_band = m_band; - *p_n_band = n_band; - } else { - *p_m_disp = 0; - *p_n_disp = 0; - *p_m_band = 0; - *p_n_band = 0; - } - - return; -} - -static inline void decompose_matrices(const int ithr, int *nthrs, dim_t *m, - dim_t *n, dim_t *k, const int8_t **a, const uint8_t **b, int32_t **c, - const int32_t **co, const blas_thread_t *thread_info, const blas_t *arg) -{ - dim_t strideAm = (arg->transa == 0)? 1 : arg->lda; - dim_t strideBn = (arg->transb != 0)? 1 : arg->ldb; - int offsetc = arg->offsetc; - - switch (thread_info->partition) { - case PARTITION_1D_ROW: - { - dim_t offset = 0; - dim_t block = 0; - partition_1d(ithr, *nthrs, arg->m, &offset, &block); - - *m = block; - *n = arg->n; - *k = arg->k; - - // Set matrix A. - *a = arg->a + offset * strideAm; - - // Set matrix B. - *b = arg->b; - - // Set matrix C. - *c = arg->c + offset; - - // Set offset vector for C matrix - dim_t co_stride = 0; - if (offsetc == FIX_OFFSET) { - co_stride = 0; - } else if (offsetc == ROW_OFFSET) { - co_stride = 0; - } else if (offsetc == COL_OFFSET) { - co_stride = offset; - } - *co = arg->co + co_stride; - break; - } - - case PARTITION_1D_COL: - { - dim_t offset = 0; - dim_t block = 0; - partition_1d(ithr, *nthrs, arg->n, &offset, &block); - - *m = arg->m; - *n = block; - *k = arg->k; - - // Set matrix A. - *a = arg->a; - - // Set matrix B. - *b = arg->b + offset * strideBn; - - // Set matrix C. - *c = arg->c + offset * arg->ldc; - - // Set offset vector for C matrix - dim_t co_stride = 0; - if (offsetc == FIX_OFFSET) { - co_stride = 0; - } else if (offsetc == ROW_OFFSET) { - co_stride = offset; - } else if (offsetc == COL_OFFSET) { - co_stride = 0; - } - *co = arg->co + co_stride; - break; - } - - case PARTITION_2D_COL_MAJOR: - { - int nthrs_m = thread_info->nthrs_m; - int nthrs_n = thread_info->nthrs_n; - int ithr_i = ithr % nthrs_m; - int ithr_j = ithr / nthrs_m; - - dim_t m_disp = 0; - dim_t m_band = 0; - dim_t n_disp = 0; - dim_t n_band = 0; - - partition_2d(ithr, nthrs, ithr_i, ithr_j, nthrs_m, nthrs_n, - arg->m, arg->n, &m_disp, &m_band, &n_disp, &n_band); - - *m = m_band; - *n = n_band; - *k = arg->k; - - // Set matrix A. - *a = arg->a + m_disp * strideAm; - - // Set matrix B. - *b = arg->b + n_disp * strideBn; - - // Set matrix C. - *c = arg->c + m_disp + n_disp * arg->ldc; - - // Set offset vector for C matrix - dim_t co_stride = 0; - if (offsetc == FIX_OFFSET) { - co_stride = 0; - } else if (offsetc == ROW_OFFSET) { - co_stride = n_disp; - } else if (offsetc == COL_OFFSET) { - co_stride = m_disp; - } - *co = arg->co + co_stride; - break; - } - } -} - -#define MULTIPLIER 10 -static int parallel_a_copy(const int ithr, const int nthrs, const dim_t m, - const dim_t n, const dim_t k, const int8_t *a, const uint8_t *b, - int32_t *c, const int32_t *co, const blas_t *arg, - char **p_shared_mem) -{ - const dim_t lda = arg->lda; - const dim_t ldb = arg->ldb; - const dim_t strideAm = (arg->transa == 0)? 1 : lda; - const dim_t strideAn = (arg->transa != 0)? 1 : lda; - const dim_t strideBm = (arg->transb == 0)? 1 : ldb; - - // Padding along M dimension. 
- dim_t m_padd = utils::rnd_up(nstl::min(nstl::max(m, arg->um), arg->bm), - arg->um); - - // Padding along K dimension. - dim_t k_padd = 0; - if (k <= arg->bk_traditional) { - k_padd = utils::rnd_up(k, arg->uk); - k_padd = nstl::max(128LL, k_padd); - } else if (k < 2 * arg->bk) { - k_padd = utils::rnd_up(k / 2, arg->uk); - } else { - k_padd = arg->bk; - } - - m_padd *= nthrs > MULTIPLIER ? MULTIPLIER : nthrs; - if (m_padd > m) { - m_padd = utils::rnd_up(m, arg->um); - } - - size_t a_buf_nelems = m_padd * k_padd; - - // Allocate shared memory for A and its row sum buffers in master thread. - if (ithr == 0) { // If thread master - size_t a_row_sum_nelems = m_padd; - - size_t mem_size = (a_buf_nelems * sizeof(*a) + PAGE_4K) - + a_row_sum_nelems * sizeof(*c) + PAGE_4K; - - *p_shared_mem = (char *) malloc(mem_size, 128); - - } - mkldnn_thr_barrier(); - - char *mem = *p_shared_mem; - int8_t *bufferA = (int8_t *) align(mem, PAGE_4K); - int32_t *a_row_sum = (int32_t *) align(bufferA + a_buf_nelems, PAGE_4K); - - if (!mem) { - return -1; - } - - int result = 0; // Return status - - dim_t sizeK = 0; - for (dim_t Bk = 0; Bk < k; Bk += sizeK) { - sizeK = k - Bk; - if (sizeK > k_padd) - sizeK = k_padd; - - // Scale C blocks by beta only for the first term of partial sum. - float beta = 1.0f; - if (Bk == 0) - beta = *(arg->beta); - - // Apply C offset for the last k-block of the partial sum. - int offsetc = NO_OFFSET; - if (Bk + sizeK == k) - offsetc = arg->offsetc; - - dim_t sizeM = 0; - for (dim_t Bm = 0; Bm < m; Bm += sizeM) { - sizeM = m - Bm; - if (sizeM > m_padd) - sizeM = m_padd; - - if (ithr < nthrs) { - dim_t band = (sizeM + nthrs - 1) / nthrs; - band = utils::rnd_up(band, arg->um); - - dim_t offset = band * ithr; - - // If offset is too large don't use that thread for copying. - if (offset >= sizeM) { - offset = 0; - band = 0; - } - - // Handle the tail of the copy. - if (offset + band > sizeM) { - band = sizeM - offset; - } - - if (band > 0) { - const int8_t *a_block = a + (Bm + offset) * strideAm - + Bk * strideAn; - arg->copyA(&sizeK, &band, a_block, &lda, NULL, - bufferA + offset * sizeK, NULL, NULL, - a_row_sum + offset); - } - } - mkldnn_thr_barrier(); // Wait for finishing parallel copy. - - const uint8_t *b_block = b + Bk * strideBm; - int32_t *c_block = c + Bm; - dim_t co_stride = 0; - if (offsetc == FIX_OFFSET) { - co_stride = 0; - } else if (offsetc == ROW_OFFSET) { - co_stride = 0; - } else if (offsetc == COL_OFFSET) { - co_stride = Bm; - } - - result = kernel_driver_parallel_acopiedbcopy(sizeM, n, sizeK, - bufferA, b_block, beta, c_block, offsetc, co + co_stride, - a_row_sum, arg); - - mkldnn_thr_barrier(); // Wait for kernel computations to finish. 
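parallel_a_copy() above has all threads cooperatively copy one A block into shared, page-aligned memory: each thread takes a band of rows rounded to the M unroll, with barriers separating the copy from the compute phase. The band computation reduces to the small helper sketched below (a restatement for clarity, not part of the deleted sources).

    // Sketch of the per-thread copy-band split used in parallel_a_copy() above.
    using dim_t = long long int;

    static dim_t rnd_up(dim_t x, dim_t y) { return ((x + y - 1) / y) * y; }

    // Rows [offset, offset + band) of the current sizeM x sizeK block are copied
    // by thread ithr; threads whose band falls past the block copy nothing.
    static void copy_band(dim_t sizeM, int nthrs, int ithr, dim_t um,
                          dim_t &offset, dim_t &band) {
        band = rnd_up((sizeM + nthrs - 1) / nthrs, um);
        offset = band * ithr;
        if (offset >= sizeM) {
            offset = 0;
            band = 0;
        } else if (offset + band > sizeM) {
            band = sizeM - offset; // tail band
        }
    }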
- } - } - - // Free memory allocated in master thread - if (ithr == 0) { - free(mem); - } - - return result; -} -#undef MULTIPLIER - -static inline void get_omp_thread_count(dim_t m, dim_t n, dim_t k, - double fp_per_cycle, int *nthrs) -{ - double omp_overhead_small_core = 3.0e+3; - double omp_intercept_big_core = 4.0e+3; - double omp_slope_big_core = 5.0e+2; - - double gemm_cycles = 8.0 * m * n * k / fp_per_cycle; - - int i = *nthrs; - - // Use a different model for omp overheads if nthrs is <= 4 - if (*nthrs <= 4 && omp_overhead_small_core > 0) { - double omp_cycles = omp_overhead_small_core; - if (gemm_cycles < omp_cycles) { - *nthrs = 1; - return; - } else { - while (i > 1) { - if (omp_cycles * i < gemm_cycles * (i - 1)) break; - --i; - } - } - } else { - if (gemm_cycles < (omp_intercept_big_core + 2 * omp_slope_big_core)) { - *nthrs = 1; - return; - } - - // adaptive decrement to march faster· - while (i > 1) { - double omp_cycles = omp_intercept_big_core + i * omp_slope_big_core; - if (omp_cycles * i < gemm_cycles * (i - 1)) - break; - - if (i < 10) - i -= 2; - else if (i < 30) - i -= 4; - else - i -= 8; - } - } - - if (i < 1) - i = 1; - - *nthrs = i; -} - -#define CACHE_LINE_SIZE 64 -static int gemm_threading_driver(blas_t *arg) -{ - if ((arg->m <= 0) || (arg->n <= 0)) - return mkldnn_success; - - if (gemm_s8u8s32_jump_to_gemv_s8u8s32(arg)) { - return mkldnn_success; - } - - int nthr = (mkldnn_in_parallel()) ? 1 : mkldnn_get_max_threads(); - get_omp_thread_count(arg->m, arg->n, arg->k, 64.0, &nthr); - - if (nthr == 1) { - return gemm_kernel_driver(arg->m, arg->n, arg->k, arg->a, arg->b, - arg->c, arg->co, arg); - } - - int *results = (int *) malloc(sizeof(*results) * nthr * CACHE_LINE_SIZE, - PAGE_4K); - - if (!results) { - return -1; - } - - for (int i = 0; i < nthr; i++) { - results[i * CACHE_LINE_SIZE] = 0; // Initialize to success - } - - char *shared_mem = NULL; - - parallel(nthr, [&](const int ithr, const int nthr) { - int nthrs = nthr; - if (nthrs == 1) { - results[0] = gemm_kernel_driver(arg->m, arg->n, arg->k, arg->a, - arg->b, arg->c, arg->co, arg); - } else { - blas_thread_t thread_info; - set_thread_opts_avx512(&nthrs, &thread_info, arg); - - const int8_t *a = NULL; - const uint8_t *b = NULL; - int32_t *c = NULL; - const int32_t *co = NULL; - dim_t m = -1; - dim_t n = -1; - dim_t k = -1; - decompose_matrices(ithr, &nthrs, &m, &n, &k, &a, &b, &c, &co, - &thread_info, arg); - - if (ithr < nthrs) { - switch (thread_info.copy_type) { - case COPY_A: - results[ithr * CACHE_LINE_SIZE] = - parallel_a_copy(ithr, nthrs, m, n, k, a, b, c, co, arg, - &shared_mem); - break; - - default: - case COPY_NONE: - results[ithr * CACHE_LINE_SIZE] = - gemm_kernel_driver(m, n, k, a, b, c, co, arg); - break; - } - } - } - }); - - int result = 0; // Initialize to success - for (int i = 0; i < nthr; i++) { - if (results[i] != 0) { - result = results[i * CACHE_LINE_SIZE]; - break; - } - } - - free(results); - - return result; -} -#undef CACHE_LINE_SIZE - -static jit_avx512_core_u8_copy_an_kern *copy_an; -static jit_avx512_core_u8_copy_at_kern *copy_at; -static jit_avx512_core_u8_copy_bn_kern *copy_bn; -static jit_avx512_core_u8_copy_bt_kern *copy_bt; -static jit_avx512_core_u8_copy_sum_an_kern *copy_sum_an; -static jit_avx512_core_u8_copy_sum_at_kern *copy_sum_at; -static jit_avx512_core_u8_copy_sum_bn_kern *copy_sum_bn; -static jit_avx512_core_u8_copy_sum_bt_kern *copy_sum_bt; -static jit_avx512_core_gemm_s8u8s32_kern *kernel; -static jit_avx512_core_gemm_s8u8s32_kern *kernel_b; -static 
jit_avx512_core_gemm_s8u8s32_kern *kernel_r; -static jit_avx512_core_gemm_s8u8s32_kern *kernel_c; -static jit_avx512_core_gemm_s8u8s32_kern *kernel_b0; -static jit_avx512_core_gemm_s8u8s32_kern *kernel_b0_b; -static jit_avx512_core_gemm_s8u8s32_kern *kernel_b0_r; -static jit_avx512_core_gemm_s8u8s32_kern *kernel_b0_c; -static jit_avx512_core_gemv_s8u8s32_kern *gemv_s8u8s32_kernel; -static jit_avx512_core_gemv_s8u8s32_kern *gemv_u8s8s32_kernel; - -static void jit_init(blas_t *arg) -{ - static int (*copyAn)(const dim_t *m, const dim_t *n, const int8_t *a, - const dim_t *lda, const int8_t *alpha, int8_t *b, - const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum); - - static int (*copyAt)(const dim_t *m, const dim_t *n, const int8_t *a, - const dim_t *lda, const int8_t *alpha, int8_t *b, - const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum); - - static int (*copyBn)(const dim_t *m, const dim_t *n, const uint8_t *a, - const dim_t *lda, const uint8_t *alpha, uint8_t *b, - const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum); - - static int (*copyBt)(const dim_t *m, const dim_t *n, const uint8_t *a, - const dim_t *lda, const uint8_t *alpha, uint8_t *b, - const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum); - - static int (*copySumAn)(const dim_t *m, const dim_t *n, const int8_t *a, - const dim_t *lda, const int8_t *alpha, int8_t *b, - const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum); - - static int (*copySumAt)(const dim_t *m, const dim_t *n, const int8_t *a, - const dim_t *lda, const int8_t *alpha, int8_t *b, - const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum); - - static int (*copySumBn)(const dim_t *m, const dim_t *n, const uint8_t *a, - const dim_t *lda, const uint8_t *alpha, uint8_t *b, - const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum); - - static int (*copySumBt)(const dim_t *m, const dim_t *n, const uint8_t *a, - const dim_t *lda, const uint8_t *alpha, uint8_t *b, - const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum); - - static int (*kern)(const dim_t *m, const dim_t *n, const dim_t *k, - const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, - const dim_t ldc, const int32_t *col_offset, - const int32_t *row_offset); - - static int (*kern_b)(const dim_t *m, const dim_t *n, const dim_t *k, - const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, - const dim_t ldc, const int32_t *col_offset, - const int32_t *row_offset); - - static int (*kern_r)(const dim_t *m, const dim_t *n, const dim_t *k, - const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, - const dim_t ldc, const int32_t *col_offset, - const int32_t *row_offset); - - static int (*kern_c)(const dim_t *m, const dim_t *n, const dim_t *k, - const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, - const dim_t ldc, const int32_t *col_offset, - const int32_t *row_offset); - - static int (*kern_b0)(const dim_t *m, const dim_t *n, const dim_t *k, - const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, - const dim_t ldc, const int32_t *col_offset, - const int32_t *row_offset); - - static int (*kern_b0_b)(const dim_t *m, const dim_t *n, const dim_t *k, - const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, - const dim_t ldc, const int32_t *col_offset, - const int32_t *row_offset); - - static int (*kern_b0_r)(const dim_t *m, const dim_t *n, const dim_t *k, - const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, - const dim_t ldc, const int32_t *col_offset, 
- const int32_t *row_offset); - - static int (*kern_b0_c)(const dim_t *m, const dim_t *n, const dim_t *k, - const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, - const dim_t ldc, const int32_t *col_offset, - const int32_t *row_offset); - - static void (*gemv_s8u8s32_kern)(const dim_t, const dim_t, const float, - const int8_t*, const dim_t, const uint8_t*, - const float, int32_t*); - - static void (*gemv_u8s8s32_kern)(const dim_t, const dim_t, const float, - const uint8_t*, const dim_t, const int8_t*, - const float, int32_t*); - - if (mayiuse(avx512_core_vnni)) { - arg->um = AVX512_UNROLL_M; - arg->un = AVX512_UNROLL_N; - arg->uk = AVX512_UNROLL_K; - arg->bm = AVX512_BM; - arg->bn = AVX512_BN; - arg->bk = AVX512_BK_VNNI; - - arg->bk_traditional = AVX512_BK_TRADITIONAL; - arg->bn_small_k = AVX512_BN_SMALL_K; - arg->blocking_small_k = AVX512_BLOCKING_SMALL_K; - } else { - arg->um = AVX512_UNROLL_M; - arg->un = AVX512_UNROLL_N; - arg->uk = AVX512_UNROLL_K; - arg->bm = AVX512_BM; - arg->bn = AVX512_BN; - arg->bk = AVX512_BK; - - arg->bk_traditional = AVX512_BK_TRADITIONAL; - arg->bn_small_k = AVX512_BN_SMALL_K; - arg->blocking_small_k = AVX512_BLOCKING_SMALL_K; - } - - static std::once_flag initialized; - std::call_once(initialized, []{ - - copy_an = new jit_avx512_core_u8_copy_an_kern(); - copy_at = new jit_avx512_core_u8_copy_at_kern(); - copy_bn = new jit_avx512_core_u8_copy_bn_kern(); - copy_bt = new jit_avx512_core_u8_copy_bt_kern(); - - copy_sum_an = new jit_avx512_core_u8_copy_sum_an_kern(); - copy_sum_at = new jit_avx512_core_u8_copy_sum_at_kern(); - copy_sum_bn = new jit_avx512_core_u8_copy_sum_bn_kern(); - copy_sum_bt = new jit_avx512_core_u8_copy_sum_bt_kern(); - - kernel = new jit_avx512_core_gemm_s8u8s32_kern(false, false, false); - kernel_b = new jit_avx512_core_gemm_s8u8s32_kern(false, true, true); - kernel_r = new jit_avx512_core_gemm_s8u8s32_kern(false, false, true); - kernel_c = new jit_avx512_core_gemm_s8u8s32_kern(false, true, false); - kernel_b0 = new jit_avx512_core_gemm_s8u8s32_kern(true, false, false); - kernel_b0_b = new jit_avx512_core_gemm_s8u8s32_kern(true, true, true); - kernel_b0_r = new jit_avx512_core_gemm_s8u8s32_kern(true, false, true); - kernel_b0_c = new jit_avx512_core_gemm_s8u8s32_kern(true, true, false); - - gemv_s8u8s32_kernel = new jit_avx512_core_gemv_s8u8s32_kern(); - gemv_u8s8s32_kernel = new jit_avx512_core_gemv_s8u8s32_kern(); - - - copyAn = copy_an->getCode(); - - copyAt = copy_at->getCode(); - - copyBn = copy_bn->getCode(); - - copyBt = copy_bt->getCode(); - - copySumAn = copy_sum_an->getCode(); - - copySumAt = copy_sum_at->getCode(); - - copySumBn = copy_sum_bn->getCode(); - - copySumBt = copy_sum_bt->getCode(); - - kern = kernel->getCode(); - - kern_b = kernel_b->getCode(); - - kern_r = kernel_r->getCode(); - - kern_c = kernel_c->getCode(); - - kern_b0 = kernel_b0->getCode(); - - kern_b0_b = kernel_b0_b->getCode(); - - kern_b0_r = kernel_b0_r->getCode(); - - kern_b0_c = kernel_b0_c->getCode(); - - gemv_s8u8s32_kern = - gemv_s8u8s32_kernel -> generate - (mayiuse(avx512_core_vnni)); - gemv_u8s8s32_kern = - gemv_u8s8s32_kernel -> generate - (mayiuse(avx512_core_vnni)); - }); - - if (arg->bo == 0) { // No need to compute A row sum if bo is zero - if (arg->transa == 0) { - arg->copyA = copyAn; - } else { - arg->copyA = copyAt; - } - } else { - if (arg->transa == 0) { - arg->copyA = copySumAn; - } else { - arg->copyA = copySumAt; - } - } - - if (arg->ao == 0) { // No need to compute B column sum if ao is zero - if (arg->transb == 0) { - 
arg->copyB = copyBn; - } else { - arg->copyB = copyBt; - } - } else { - if (arg->transb == 0) { - arg->copyB = copySumBn; - } else { - arg->copyB = copySumBt; - } - } - - arg->kernel = kern; - arg->kernel_b = kern_b; - arg->kernel_r = kern_r; - arg->kernel_c = kern_c; - arg->kernel_b0 = kern_b0; - arg->kernel_b0_b = kern_b0_b; - arg->kernel_b0_r = kern_b0_r; - arg->kernel_b0_c = kern_b0_c; - arg -> gemv_s8u8s32_kernel = gemv_s8u8s32_kern; - arg -> gemv_u8s8s32_kernel = gemv_u8s8s32_kern; -} - -mkldnn_status_t jit_avx512_core_gemm_s8u8s32( - const char *transA, const char *transB, const char *offsetC, - const int *m, const int *n, const int *k, - const float *alpha, const int8_t *a, const int *lda, const int8_t *oa, - const uint8_t *b, const int *ldb, const int8_t *ob, - const float *beta, int32_t *c, const int *ldc, const int32_t *oc) -{ - char transa = *transA; - char transb = *transB; - char offsetc = *offsetC; - - blas_t args; - - // Initialize blas structure - args.m = *m; - args.n = *n; - args.k = *k; - args.alpha = alpha; - args.a = a; - args.lda = *lda; - args.b = b; - args.ldb = *ldb; - args.beta = beta; - args.c = c; - args.ldc = *ldc; - args.transa = (transa == 'N' || transa == 'n') ? 0 : 1; - args.transb = (transb == 'N' || transb == 'n') ? 0 : 1; - args.um = 0; - args.un = 0; - args.bm = 0; - args.bn = 0; - args.bk = 0; - args.copyA = NULL; - args.copyB = NULL; - args.kernel = NULL; - args.kernel_b0 = NULL; - args.ao = *oa; - args.bo = *ob; - args.co = oc; - - if (offsetc == 'F' || offsetc == 'f') { - args.offsetc = FIX_OFFSET; - } else if (offsetc == 'R' || offsetc == 'r') { - args.offsetc = ROW_OFFSET; - } else { // offsetc == 'C' || offsetc == 'c' - args.offsetc = COL_OFFSET; - } - - jit_init(&args); - int result = gemm_threading_driver(&args); - - return (result < 0) ? mkldnn_out_of_memory : mkldnn_success; -} - -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.hpp deleted file mode 100644 index b2e2902a1..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.hpp +++ /dev/null @@ -1,38 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#ifndef JIT_AVX512_CORE_GEMM_S8U8S32_HPP -#define JIT_AVX512_CORE_GEMM_S8U8S32_HPP - -#include -#include "mkldnn_types.h" - -namespace mkldnn { -namespace impl { -namespace cpu { - -mkldnn_status_t jit_avx512_core_gemm_s8u8s32( - const char *transA, const char *transB, const char *offsetC, - const int *m, const int *n, const int *k, - const float *alpha, const int8_t *a, const int *lda, const int8_t *oa, - const uint8_t *b, const int *ldb, const int8_t *ob, - const float *beta, int32_t *c, const int *ldc, const int32_t *oc); - -} -} -} - -#endif // JIT_AVX512_CORE_GEMM_S8U8S32_HPP diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.cpp deleted file mode 100644 index 57554a185..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.cpp +++ /dev/null @@ -1,539 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include "jit_avx512_core_gemm_s8u8s32_kern.hpp" - - -#ifdef _WIN32 -static const bool is_windows = 1; -#else -static const bool is_windows = 0; -#endif - - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace Xbyak; - - - - -// Convert between vector register lengths. -static inline Xmm make_xmm(const Xmm &v) { return Xmm(v.getIdx()); } -static inline Ymm make_ymm(const Xmm &v) { return Ymm(v.getIdx()); } - -// Load from or store to C. -void jit_avx512_core_gemm_s8u8s32_kern::c_load(const Xbyak::Xmm &dst, - const Xbyak::Address &src, int nelems) -{ - switch (nelems) { - default: vmovups(dst, src); break; - case 8: vmovups(make_ymm(dst), src); break; - case 4: vmovups(make_xmm(dst), src); break; - case 2: vmovlps(make_xmm(dst), src); break; - case 1: vmovss(make_xmm(dst), src); break; - } -} -void jit_avx512_core_gemm_s8u8s32_kern::c_store(const Xbyak::Address &dst, - const Xbyak::Xmm &src, int nelems) -{ - switch (nelems) { - default: vmovups(dst, src); break; - case 8: vmovups(dst, make_ymm(src)); break; - case 4: vmovups(dst, make_xmm(src)); break; - case 2: vmovsd(dst, make_xmm(src)); break; - case 1: vmovss(dst, make_xmm(src)); break; - } -} - -// Perform length-4 dot product accumulations of unsigned and signed bytes -// in parallel. -// Use vpdpbusd if VNNI available, otherwise emulate. -void jit_avx512_core_gemm_s8u8s32_kern::dot_product(const Xmm &dst, - const Xmm &src1, const Xmm &src2) -{ - if (vnni) - vpdpbusd(dst, src1, src2); - else { - vpmaddubsw(dp_scratch, src1, src2); - vpmaddwd(dp_scratch, ones, dp_scratch); - vpaddd(dst, dst, dp_scratch); - } -} - -// Inner kernel. 
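For reference, the emulated path in dot_product() above (vpmaddubsw, vpmaddwd against a vector of ones, then vpaddd) computes the same quantity per 32-bit lane as vpdpbusd, provided the intermediate int16 sums do not saturate. A minimal scalar sketch of that per-lane computation (the helper name and the no-saturation assumption are illustrative, not part of the original source):

    #include <cstdint>

    // One 32-bit lane: accumulate a length-4 dot product of unsigned bytes
    // (broadcast from B) with signed bytes (from A) into an int32 accumulator.
    static int32_t dot4_u8s8(int32_t acc, const uint8_t u[4], const int8_t s[4]) {
        for (int k = 0; k < 4; ++k)
            acc += int32_t(u[k]) * int32_t(s[k]);
        return acc; // vpdpbusd accumulates exactly; the emulation may saturate at int16.
    }

The inner kernel that follows issues these dot products across the whole unrolled register tile.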
-void jit_avx512_core_gemm_s8u8s32_kern::kernel_loop(int unroll_m, int unroll_n, - bool cfetch) -{ - int um_vecs = (unroll_m + 15) >> 4; - Label label_kernel_loop; - - L_aligned(label_kernel_loop); { - for (int h = 0; h < 4; h++) { - for (int j = 0; j < unroll_n; j++) { - const Zmm b = b_regs[j & 1]; - - vpbroadcastd(b, ptr[BO + isize * - (2 * j + 2 * h * unroll_n - offset_b)]); - dot_product(c_regs[0][j], b, a_regs[0]); - - if (j == 1 && !(h & 1)) - prefetch_b(ptr[BO + isize * (prefetch_size_b - + 2 * h * unroll_n - offset_b)]); - else if (j % 3 == 0) - prefetch_a(ptr[AO + isize * (prefetch_size_a - + 32 * (j / 3) + 2 * h * unroll_m - offset_a)]); - - for (int i = 1; i < um_vecs; i++) - dot_product(c_regs[i][j], b, a_regs[i]); - - if (cfetch && (j == std::min(1, unroll_n - 1))) { - if (h == 3) - lea(CO2, ptr[CO2 + LDC]); - else if (h < um_vecs) - prefetch_c(ptr[CO2 + (16 * h * size)]); - } - - if (h == 3 && j == std::min(3, unroll_n - 1)) - lea(AA, ptr[AA + (32 * isize)]); - } - - for (int i = 0; i < um_vecs; i++) - vmovups(a_regs[i], ptr[AO + isize * - (32 * i + 2 * (h + 1) * unroll_m - offset_a)]); - - if (h == 2) - prefetch_x(ptr[AA - (offset_a * isize)]); - } - - add(AO, 8 * isize * unroll_m); - add(BO, 8 * isize * unroll_n); - sub(LoopCount, 1); - jg(label_kernel_loop, T_NEAR); - } -} - -// k remainder loop for kernel. -void jit_avx512_core_gemm_s8u8s32_kern::remainder_kernel(int unroll_m, - int unroll_n, int unroll_k, int bwidth) -{ - if ((unroll_m > IGEMM_UNROLL_M) || (unroll_n > IGEMM_UNROLL_N) - || (unroll_m < 0) || (unroll_n < 0)) - return; - - int um_vecs = (unroll_m + 15) >> 4; - - for (int h = 0; h < unroll_k; h++) { - for (int j = 0; j < unroll_n; j++) { - Zmm b = b_regs[j & 1]; - auto b_src = ptr[BO + (-isize * offset_b - + bwidth * (j + h * unroll_n))]; - - switch (bwidth) { - case 4: - vpbroadcastd(b, b_src); - break; - case 2: - vpbroadcastw(b, b_src); - break; - case 1: - vpbroadcastb(b, b_src); - break; - } - for (int i = 0; i < um_vecs; i++) - dot_product(c_regs[i][j], b, a_regs[i]); - } - - if (unroll_k > 1) { - for (int i = 0; i < um_vecs; i++) - vmovups(a_regs[i], ptr[AO + isize * (32 * i - + (h + 1) * 2 * unroll_m - offset_a)]); - } - } - - add(AO, unroll_k * unroll_m * bwidth); - add(BO, unroll_k * unroll_n * bwidth); -} - -// Inner loop. -void jit_avx512_core_gemm_s8u8s32_kern::innerloop(int unroll_m, int unroll_n) -{ - if ((unroll_m > IGEMM_UNROLL_M) || (unroll_n > IGEMM_UNROLL_N) - || (unroll_m < 0) || (unroll_n < 0)) - return; - - int um_vecs = (unroll_m + 15) >> 4; - int stage1 = unroll_n, stage2 = unroll_n; - - Label label_kernel_loop_1, label_k_main_loop_2, label_kernel_loop_2; - Label label_k_main_loop_3, label_kernel_loop_3; - Label label_k_remainder_loop_begin, label_k_rem_4, label_k_rem_2; - Label label_k_rem_1, label_update_begin; - - mov(AO, A); - for (int i = 0; i < um_vecs; i++) - vmovups(a_regs[i], ptr[AO + isize * (32 * i - offset_a)]); - - mov(LoopCount, K); - sar(LoopCount, 4); - jle(label_k_remainder_loop_begin, T_NEAR); - - // Main k loops, broken into three parts to time C prefetching. 
- sub(LoopCount, stage1 + stage2); - jle(label_k_main_loop_2, T_NEAR); - - kernel_loop(unroll_m, unroll_n, false); - - L_aligned(label_k_main_loop_2); - lea(CO2, ptr[CO1 + size * (std::min(unroll_m, 16) - 1)]); - add(LoopCount, stage1); - jle(label_k_main_loop_3, T_NEAR); - - kernel_loop(unroll_m, unroll_n, true); - - L_aligned(label_k_main_loop_3); - lea(CO2, ptr[CO1 + size * (std::min(unroll_m, 16) - 1)]); - add(LoopCount, stage2); - jle(label_k_remainder_loop_begin, T_NEAR); - - kernel_loop(unroll_m, unroll_n, true); - - // k remainder handling - L_aligned(label_k_remainder_loop_begin); - mov(LoopCount, K); - test(LoopCount, 8); - je(label_k_rem_4, T_NEAR); - - remainder_kernel(unroll_m, unroll_n, 2, 4); - - L_aligned(label_k_rem_4); - mov(LoopCount, K); - test(LoopCount, 4); - je(label_k_rem_2, T_NEAR); - - remainder_kernel(unroll_m, unroll_n, 1, 4); - - L_aligned(label_k_rem_2); - mov(LoopCount, K); - test(LoopCount, 2); - je(label_k_rem_1, T_NEAR); - - Zmm zero = zmm6; - Zmm tmp = zmm5; - - vpxorq(zero, zero, zero); - for (int i = 0; i < um_vecs; i++) { - Zmm a = a_regs[i]; - vbroadcasti64x4(a, ptr[AO + isize * (16 * i - offset_a)]); - vpunpcklwd(tmp, a, zero); - vpunpckhwd(a, a, zero); - vshufi32x4(a, tmp, a, 0x44); - vshufi32x4(a, a, a, 0xD8); - } - - remainder_kernel(unroll_m, unroll_n, 1, 2); - - L_aligned(label_k_rem_1); - mov(LoopCount, K); - test(LoopCount, 1); - je(label_update_begin, T_NEAR); - - vpxorq(zero, zero, zero); - for (int i = 0; i < um_vecs; i++) { - Zmm a = a_regs[i]; - vbroadcasti32x4(a, ptr[AO + isize * (8 * i - offset_a)]); - vpunpcklbw(tmp, a, zero); - vpunpckhbw(a, a, zero); - vinsertf128(make_ymm(a), make_ymm(tmp), make_xmm(a), 1); - vpunpcklwd(tmp, a, zero); - vpunpckhwd(a, a, zero); - vshufi32x4(a, tmp, a, 0x44); - vshufi32x4(a, a, a, 0xD8); - } - - remainder_kernel(unroll_m, unroll_n, 1, 1); - - // Add offsets and update C. - L_aligned(label_update_begin); - - if (enable_offset_r) { - // Add row offsets. - mov(rax, coffset_ry); - for (int j = 0; j < unroll_n; j++) { - Zmm row_offset = zmm0; - - vbroadcastss(row_offset, ptr[rax + size * j]); - - for (int i = 0; i < um_vecs; i++) - vpaddd(c_regs[i][j], c_regs[i][j], row_offset); - } - add(coffset_ry, size * unroll_n); - } - - if (enable_offset_c) { - // Add column offsets. - mov(rax, coffset_cy); - for (int i = 0; i < um_vecs; i++) { - Zmm col_offset = zmm0; - - c_load(col_offset, ptr[rax + size * 16 * i], unroll_m); - - for (int j = 0; j < unroll_n; j++) - vpaddd(c_regs[i][j], c_regs[i][j], col_offset); - } - } - - Reg64 LDC3 = rax; - lea(LDC3, ptr[LDC + LDC * 2]); - - // C updates. - int c_off_j = 0; - for (int j = 0; j < unroll_n; j++) { - if (j > 0 && (j & 3) == 0) { - lea(CO1, ptr[CO1 + LDC * 4]); - c_off_j += 4; - } - - int jj = j - c_off_j; - - for (int i = 0; i < um_vecs; i++) { - Zmm c = c_regs[i][j]; - Zmm c_old = zmm0; - decltype(LDC * jj) ldc_mult = (jj == 3) ? LDC3 : LDC * jj; - - auto c_mem = ptr[CO1 + ldc_mult + size * 16 * i]; - - if (beta_zero) - c_store(c_mem, c, unroll_m); - else { - c_load(c_old, c_mem, unroll_m); - vpaddd(c_old, c, c_old); - c_store(c_mem, c_old, unroll_m); - } - - vpxorq(c, c, c); - } - } - - lea(CO1, ptr[CO1 + LDC * (unroll_n - c_off_j)]); -} - -// Outer loop. 
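The tail of innerloop() above adds the optional row and column offsets to the int32 accumulators and then either overwrites C (beta_zero) or accumulates into it. A scalar sketch of that update for one unroll_m x unroll_n tile (function and parameter names are illustrative; the accumulator tile is modelled column-major with i fastest):

    #include <cstddef>
    #include <cstdint>

    // row_off is indexed by column j, col_off by row i; either may be null.
    static void update_c_tile(int32_t *C, std::size_t ldc, const int32_t *acc,
                              int m, int n, const int32_t *row_off,
                              const int32_t *col_off, bool beta_zero) {
        for (int j = 0; j < n; ++j)
            for (int i = 0; i < m; ++i) {
                int32_t v = acc[j * m + i];
                if (row_off) v += row_off[j];
                if (col_off) v += col_off[i];
                C[j * ldc + i] = beta_zero ? v : C[j * ldc + i] + v;
            }
    }

The outer loop that follows tiles M and N around this inner loop.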
-void jit_avx512_core_gemm_s8u8s32_kern::outerloop(int unroll_x, int unroll_y, - Label *&cur_outerloop_label) -{ - Label label_m_loop, label_n_loop, label_n_remainder_loops[6]; - - L(*cur_outerloop_label); - cur_outerloop_label++; - if (unroll_x >= IGEMM_UNROLL_M) { - mov(J, M); - cmp(J, unroll_x); - jl(*cur_outerloop_label, T_NEAR); // Jump to next outerloop label. - } else { - test(J, unroll_x); - jle(*cur_outerloop_label, T_NEAR); - } - - L_aligned(label_m_loop); { - mov(CO1, C); - add(C, unroll_x * size); - - mov(BO, B); - - mov(AA, K); - imul(AA, AA, unroll_x * isize); - lea(AA, ptr[A + AA + isize * prefetch_size_a]); - - if (enable_offset_c) { - mov(rax, coffset_cx); - mov(coffset_cy, rax); - add(rax, unroll_x * size); - mov(coffset_cx, rax); - } - - if (enable_offset_r) { - mov(rax, coffset_rx); - mov(coffset_ry, rax); - } - - mov(I, N); - cmp(I, unroll_y); - jl(label_n_remainder_loops[0], T_NEAR); - - L_aligned(label_n_loop); { - innerloop(unroll_x, unroll_y); - sub(I, unroll_y); - cmp(I, unroll_y); - jge(label_n_loop, T_NEAR); - } - - align(16); - - int label_idx = 0; - for (int uy = 16; uy > 0; uy >>= 1) { - L(label_n_remainder_loops[label_idx++]); - if (unroll_y > uy) { - test(I, uy); - jle(label_n_remainder_loops[label_idx], T_NEAR); - - innerloop(unroll_x, uy); - align(16); - } - } - L(label_n_remainder_loops[label_idx]); - - mov(A, AO); - if (unroll_x >= IGEMM_UNROLL_M) { - sub(J, unroll_x); - cmp(J, unroll_x); - jge(label_m_loop); - } - } - - align(16); -} - -void jit_avx512_core_gemm_s8u8s32_kern::generate() -{ - // Prologue - preamble(); - sub(rsp, stack_alloc_size); - - if (is_windows) { - mov(A, arg_a); - mov(B, arg_b); - } - - mov(C, arg_c); - mov(LDC, arg_ldc); - - sub(A, -offset_a * isize); - sub(B, -offset_b * isize); - - mov(M, qword[M]); - mov(N, qword[N]); - mov(K, qword[K]); - - lea(LDC, ptr[LDC * size]); - - if (enable_offset_c) { - mov(rax, arg_coffset_c); - mov(coffset_cx, rax); - } - if (enable_offset_r) { - mov(rax, arg_coffset_r); - mov(coffset_rx, rax); - } - - for (int i = 0; i < (max_unroll_m >> 4); i++) { - for (int j = 0; j < max_unroll_n; j++) { - auto &c = c_regs[i][j]; - vpxorq(c, c, c); - } - } - - if (!vnni) { - mov(rax, 1); - movq(make_xmm(ones), rax); - vpbroadcastw(ones, make_xmm(ones)); - } - - Label outerloop_labels[8]; - Label *cur_outerloop_label = &outerloop_labels[0]; - - // Main m loop. - outerloop(IGEMM_UNROLL_M, IGEMM_UNROLL_N, cur_outerloop_label); - - // m remainder loops. - for (int um = 32; um > 0; um >>= 1) - if (IGEMM_UNROLL_M > um) - outerloop(um, IGEMM_UNROLL_N, cur_outerloop_label); - - L(*cur_outerloop_label); - - // Epilogue. - add(rsp, stack_alloc_size); - postamble(); -} - - -jit_avx512_core_gemm_s8u8s32_kern::jit_avx512_core_gemm_s8u8s32_kern(bool - beta_zero_, bool enable_offset_c_, bool enable_offset_r_) : - jit_generator(nullptr, 100000), arg_a(0), arg_b(0), arg_c(0), arg_ldc(0), - arg_coffset_c(0), arg_coffset_r(0), coffset_cx(0), coffset_cy(0), - coffset_rx(0), coffset_ry(0) -{ - beta_zero = beta_zero_; - enable_offset_c = enable_offset_c_; - enable_offset_r = enable_offset_r_; - vnni = mayiuse(avx512_core_vnni); - - // Assign integer registers - M = is_windows ? rcx : rdi; - N = is_windows ? rdx : rsi; - K = is_windows ? r8 : rdx; - A = is_windows ? rsi : r8; - B = r9; - C = r10; - LDC = r11; - I = r12; - J = r13; - LoopCount = rax; - AO = r14; - BO = r15; - CO1 = rbx; - CO2 = rbp; - AA = is_windows ? 
rdi : rcx; - - // Assign vector registers - dp_scratch = zmm6; - ones = zmm7; - for (int i = 0; i < (max_unroll_m >> 4); i++) - a_regs[i] = Zmm(i); - b_regs[0] = zmm4; - b_regs[1] = zmm5; - - int rn = 0; - for (int i = 0; i < (max_unroll_m >> 4); i++) - for (int j = 0; j < max_unroll_n; j++) - c_regs[i][j] = Zmm(8 + rn++); - - // Assign stack variables. - stack_alloc_size = 32; - auto args_offset = stack_alloc_size + get_size_of_abi_save_regs() - + 8 + (is_windows ? 48 : 0); - - arg_a = ptr[rsp + (args_offset - 16)]; - arg_b = ptr[rsp + (args_offset - 8)]; - arg_c = ptr[rsp + (args_offset + 0)]; - arg_ldc = ptr[rsp + (args_offset + 8)]; - arg_coffset_c = ptr[rsp + (args_offset + 16)]; - arg_coffset_r = ptr[rsp + (args_offset + 24)]; - - coffset_cx = qword[rsp + 0]; - coffset_cy = qword[rsp + 8]; - coffset_rx = qword[rsp + 16]; - coffset_ry = qword[rsp + 24]; - - generate(); -} - -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.hpp deleted file mode 100644 index e8efcc1cc..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.hpp +++ /dev/null @@ -1,101 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#ifndef IGEMM_KERNEL_GENERATOR_HPP -#define IGEMM_KERNEL_GENERATOR_HPP - -#include "jit_generator.hpp" - - -namespace mkldnn { -namespace impl { -namespace cpu { - -class jit_avx512_core_gemm_s8u8s32_kern : public jit_generator { -public: - jit_avx512_core_gemm_s8u8s32_kern(bool beta_zero_, bool enable_offset_c_, - bool enable_offset_r_); - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_gemm_s8u8s32_kern); - -protected: - bool beta_zero; - bool enable_offset_c, enable_offset_r; - bool vnni; - - void prefetch_a(const Xbyak::Address &src) { - prefetcht0(src); - } - void prefetch_b(const Xbyak::Address &src) { - prefetcht0(src); - } - void prefetch_c(const Xbyak::Address &src) { - prefetchw(src); - } - void prefetch_x(const Xbyak::Address &src) { - prefetcht0(src); - } - - void c_load(const Xbyak::Xmm &dst, const Xbyak::Address &src, int nelems); - void c_store(const Xbyak::Address &dst, const Xbyak::Xmm &src, int nelems); - - void dot_product(const Xbyak::Xmm &dst, const Xbyak::Xmm &src1, - const Xbyak::Xmm &src2); - void kernel_loop(int unroll_m, int unroll_n, bool cfetch); - void remainder_kernel(int unroll_m, int unroll_n, int unroll_k, int bwidth); - void innerloop(int unroll_m, int unroll_n); - void outerloop(int unroll_x, int unroll_y, Xbyak::Label *&outerloop_label); - - void generate(); - - -private: - static const int IGEMM_UNROLL_M = 48; - static const int IGEMM_UNROLL_N = 8; - - static const int isize = 2; - static const int size = 4; - - // Prefetch configuration - static const int prefetch_size_a = 32 * 5; - static const int prefetch_size_b = 32 * 4; - - static const int offset_a = 256, offset_b = 256; - static const int max_unroll_m = 48, max_unroll_n = 8; - - // Integer register assignments - Xbyak::Reg64 M, N, K, A, B, C, LDC, I, J, LoopCount; - Xbyak::Reg64 AO, BO, CO1, CO2, AA; - - // Vector register assignments - Xbyak::Zmm dp_scratch, ones, a_regs[max_unroll_m >> 4], b_regs[2]; - Xbyak::Zmm c_regs[max_unroll_m >> 4][max_unroll_n]; - - // Stack variable assignments - int stack_alloc_size; - Xbyak::Address arg_a, arg_b, arg_c, arg_ldc, arg_coffset_c, arg_coffset_r; - Xbyak::Address coffset_cx, coffset_cy, coffset_rx, coffset_ry; - - void L_aligned(Xbyak::Label &label, int alignment = 16) { - align(alignment); - L(label); - } -}; - -} -} -} - -#endif /* header guard */ diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemv_s8u8s32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemv_s8u8s32.cpp deleted file mode 100644 index 4f0b10dad..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemv_s8u8s32.cpp +++ /dev/null @@ -1,290 +0,0 @@ -/******************************************************************************* - * Copyright 2019 Intel Corporation - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- *******************************************************************************/ - -#include "gemv.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -int gemm_s8u8s32_jump_to_gemv_s8u8s32(blas_t *arg) { - - blas_t arg_gemv = *arg; - - if ((arg -> offsetc == FIX_OFFSET) && // Fix offset - (arg -> ao == 0) && - (arg -> bo == 0) && - (arg -> co[0] == 0) && - (*(arg -> alpha) == 1.0f) && - ((*(arg -> beta) == 1.0f) || *(arg -> beta) == 0.0f)) { - - if (arg -> n == 1) { - - if (arg -> transa == 1) { // A transpose - arg_gemv.n = arg -> k; - arg_gemv.ldc = 1; - arg_gemv.swap = 0; - if (arg -> transb == 0) { // B non transpose - arg_gemv.ldb = 1; - } - // B transpose arg_gemv.ldb = arg -> ldb - gemv_threading_driver(&arg_gemv); - return 1; - } - } - - if (arg -> m == 1) { - - if (arg -> transb == 0) { // B non transpose - arg_gemv.transa = 1; - arg_gemv.m = arg -> n; - arg_gemv.n = arg -> k; - arg_gemv.a = (int8_t *) arg -> b; - arg_gemv.lda = arg -> ldb; - arg_gemv.b = (uint8_t *) arg -> a; - arg_gemv.swap = 1; - if (arg -> transa == 0) { // A non transpose - arg_gemv.ldb = arg -> lda; - } - else { // A transpose - arg_gemv.ldb = 1; - } - gemv_threading_driver(&arg_gemv); - return 1; - } - } - } - - return 0; -} - - -int gemv_kernel_driver(blas_t *arg) { - - dim_t m = arg -> m; - dim_t n = arg -> n; - uint8_t *a = (uint8_t *) arg -> a; - dim_t lda = arg -> lda; - int8_t *b = (int8_t *) arg -> b; - float beta = *(arg -> beta); - - if (arg -> swap) { - arg -> gemv_u8s8s32_kernel(m, n, 1.0f, a, lda, b, beta, arg -> c); - } - else { - arg -> gemv_s8u8s32_kernel(arg -> m, arg -> n, 1.0f, arg -> a, - arg -> lda, arg -> b, *(arg -> beta), arg -> c); - } - - return 0; -} - -int gemv_threading_driver(blas_t *arg) { - - dim_t nthr_m, nthr_n = 1; - dim_t MB, NB, UM = 16, UN = 64; - dim_t BLOCKM = 192, BLOCKN = 3072; - int status; - dim_t i; - - dim_t nthr = (mkldnn_in_parallel()) ? 1 : mkldnn_get_max_threads(); - - uint8_t *new_x = NULL; - int32_t *tmp_y = NULL, *new_y = NULL; - - dim_t m = arg -> m, n = arg -> n; - - blas_t arg_seq = *arg; - float zero = 0.0f; - - nthr_m = std::min(std::max(m / BLOCKM, (dim_t) 1), nthr); - MB = m / nthr_m; - MB = (((MB / UM) * UM) == MB) ? MB : (MB / UM) * UM + UM; - nthr_m = (((m / MB) * MB) == m) ? m / MB : m / MB + 1; - nthr_m = std::min(std::max(nthr_m, (dim_t) 1), nthr); - - while ((nthr_m * (nthr_n + 1) <= nthr) && ((n / (nthr_n + 1)) >= BLOCKN)) { - nthr_n++; - } - - NB = n / nthr_n; - NB = (((NB / UN) * UN) == NB) ? NB : (NB / UN) * UN + UN; - nthr_n = (((n / NB) * NB) == n) ? 
n / NB : n / NB + 1; - nthr_n = std::min(std::max(nthr_n, (dim_t) 1), nthr / nthr_m); - - nthr = nthr_m * nthr_n; - - if (arg -> ldb != 1) { - new_x = (uint8_t *)malloc(n, 64); - if (new_x == NULL) - return 1; - for (i = 0; i < n; i++) { - new_x[i] = (arg -> b)[i * arg -> ldb]; - } - arg_seq.b = new_x; - arg_seq.ldb = 1; - } - else new_x = (uint8_t *) arg -> b; - - if (arg -> ldc != 1) { - new_y = (int32_t *) malloc(nthr_m * PADD_BYTESIZE_ONPAGE(MB, sizeof(int32_t)), 64); - if (new_y == NULL) { - if (arg -> ldb != 1) { - free(new_x); - } - return 1; - } - } - - // GEMV computation - if (nthr == 1) { - - if (arg -> ldc != 1) { - if (*(arg -> beta) != 0.0f) { - for (i = 0; i < m; i++) { - new_y[i] = arg -> c[i * arg -> ldc]; - } - } - } - - status = gemv_kernel_driver(&arg_seq); - - if (arg -> ldc != 1) { - for (i = 0; i < m; i++) { - arg -> c[i * arg -> ldc] = new_y[i]; - } - } - - if (arg -> ldb != 1) { - free(new_x); - } - if (arg -> ldc != 1) { - free(new_y); - } - return status; - } - - if (nthr_n > 1) { - tmp_y = (int32_t *) malloc((nthr_n - 1) * PADD_BYTESIZE_ONPAGE(m, sizeof(int32_t)), PAGESIZE); - if (tmp_y == NULL) { - if (arg -> ldb != 1) { - free(new_x); - } - return 1; - } - } - - parallel_nd((int) nthr, [&](const dim_t ithr) { - - dim_t m_from, m_to, myM; - dim_t n_from, n_to, myN; - - dim_t n_id, m_id; - dim_t loc_incy = 1; - int32_t *loc_y; - - blas_t arg_loc = arg_seq; - int j; - - m_id = ithr / nthr_n; - n_id = ithr % nthr_n; - - m_from = MB * m_id; - m_to = MB * (m_id + 1); - if ((m_to > m) || (m_id == nthr_m - 1)) - m_to = m; - - myM = m_to - m_from; - - n_from = NB * n_id; - n_to = NB * (n_id + 1); - if ((n_to > n) || (n_id == nthr_n - 1)) - n_to = n; - - myN = n_to - n_from; - - if (n_id != 0) { - arg_loc.beta = &zero; - loc_y = tmp_y + (NEXT_THR_STRIDE(m, sizeof(int32_t))) * (n_id - 1) + m_from; - } - else { - if (arg -> ldc == 1) { - loc_y = arg_seq.c + m_from; - } - else { - // need to copy the block of c in new_y - loc_y = new_y + m_id * NEXT_THR_STRIDE(MB, sizeof(int32_t)); - if (*(arg -> beta) != 0.0f) { - for (j = 0; j < myM; j++) { - loc_y[j] = arg -> c[(m_from + j) * arg -> ldc]; - } - } - } - } - - arg_loc.m = myM; - arg_loc.n = myN; - arg_loc.a = arg_seq.a + m_from * arg_seq.lda + n_from; - arg_loc.b = arg_seq.b + n_from; - arg_loc.c = loc_y; - arg_loc.ldc = loc_incy; - - gemv_kernel_driver(&arg_loc); - - if ((n_id == 0) && (arg -> ldc != 1)) { - for (j = 0; j < myM; j++) { - arg -> c[(m_from + j) * arg -> ldc] = loc_y[j]; - } - } - - }); - - if (nthr_n > 1) { - parallel_nd((int) nthr_m, [&](const dim_t ithr) { - - dim_t j, j_from, j_to, ii; - int32_t acc; - - j_from = MB * ithr; - j_to = MB * (ithr + 1); - if ((j_to > m) || (ithr == nthr - 1)) - j_to = m; - - for (j = j_from; j < j_to; j++) { - acc = 0; - for (ii = 0; ii < nthr_n - 1; ii++) { - acc += tmp_y[ii * NEXT_THR_STRIDE(m, sizeof(int32_t)) + j]; - } - (arg -> c)[j * arg -> ldc] += acc; - } - }); - free(tmp_y); - } - - if (arg -> ldb != 1) { - free(new_x); - } - - if (arg -> ldc != 1) { - free(new_y); - } - - return 0; -} - -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.cpp deleted file mode 100644 index c57a8c1d1..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.cpp +++ /dev/null @@ -1,411 +0,0 @@ -/******************************************************************************* - * Copyright 2019 
Intel Corporation - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - *******************************************************************************/ - -#include "jit_avx512_core_kernel_gemv_s8u8s32_kern.hpp" - -#ifdef _WIN32 -#define is_windows 1 -#else -#define is_windows 0 -#endif - -namespace mkldnn { -namespace impl { -namespace cpu { - -void jit_avx512_core_gemv_s8u8s32_kern::vnni(Xbyak::Zmm acc, Xbyak::Zmm b, - Xbyak::Zmm a, Xbyak::Zmm tmp, - Xbyak::Zmm one, bool swap, - int use_vnni) { - - if (use_vnni) { - if (swap) - vpdpbusd(acc, a, b); - else - vpdpbusd(acc, b, a); - } - - else { - if (swap) - vpmaddubsw(tmp, a, b); - else - vpmaddubsw(tmp, b, a); - vpmaddwd(tmp, tmp, one); - vpaddd(acc, tmp, acc); - } - -} - -void jit_avx512_core_gemv_s8u8s32_kern::n_loop_body(int start_a_idx, int start_acc_idx, - int b_idx, int nreg_acc, - Xbyak::Reg64 A, Xbyak::Reg64 lda, - Xbyak::Reg64 X, Xbyak::Zmm tmp, - Xbyak::Zmm one, bool swap, int use_vnni, - int use_mask, Xbyak::Opmask mask_n) { - - int i; - int nreg_A = nreg_acc / 2 + (nreg_acc % 2); - - // load X + j - if (use_mask) - vmovdqu8(Xbyak::Zmm(b_idx) | mask_n | T_z, ptr[X]); - else - vmovdqu8(Xbyak::Zmm(b_idx), ptr[X]); - - xor_(r14, r14); - // load values of A - for (i = 0; i < nreg_A; i++) { - if (use_mask) - vmovdqu8(Xbyak::Zmm(start_a_idx + i) | mask_n | T_z, ptr[A + r14]); - else - vmovdqu8(Xbyak::Zmm(start_a_idx + i), ptr[A + r14]); - add(r14, lda); - } - - for (i = 0; i < nreg_A; i++) { - // vnni (acc, b, a, tmp, one, swap, use_vnni) - vnni(Xbyak::Zmm(start_acc_idx + i), Xbyak::Zmm(b_idx), - Xbyak::Zmm(start_a_idx + i), tmp, one, swap, use_vnni); - } - - for (i = 0; i < nreg_A - (nreg_acc % 2); i++) { - if (use_mask) - vmovdqu8(Xbyak::Zmm(start_a_idx + i) | mask_n | T_z, ptr[A + r14]); - else - vmovdqu8(Xbyak::Zmm(start_a_idx + i), ptr[A + r14]); - add(r14, lda); - } - - for (i = 0; i < nreg_A - (nreg_acc % 2); i++) { - vnni(Xbyak::Zmm(start_acc_idx + i + nreg_A), Xbyak::Zmm(b_idx), - Xbyak::Zmm(start_a_idx + i), tmp, one, swap, use_vnni); - } - -} - -void jit_avx512_core_gemv_s8u8s32_kern::shuffle_and_add(Xbyak::Zmm dest, Xbyak::Zmm A, - Xbyak::Zmm B, Xbyak::Zmm C, - Xbyak::Zmm D) { - - vshufi32x4(dest, A, C, 0x44); - vshufi32x4(A, A, C, 0xEE); - vpaddd(C, dest, A); // C = A0 + A2|A1 + A3|C0 + C2|C1 + C3 - - vshufi32x4(dest, B, D, 0x44); - vshufi32x4(B, B, D, 0xEE); - vpaddd(D, dest, B); // D = B0 + B2|B1 + B3|D0 + D2|D1 + D3 - - vshufi32x4(A, C, D, 0x88); - vshufi32x4(B, C, D, 0xDD); - vpaddd(dest, A, B); // dest = SAi|SBi|SCi|SDi - -} - -void jit_avx512_core_gemv_s8u8s32_kern::update_c(int nreg_acc, Xbyak::Reg64 Y, - int start_a_idx, int start_acc_idx, - Xbyak::Xmm beta, int use_mask, - Xbyak::Opmask mask_m) { - - int l, i, k, j, last_it; - Xbyak::Label store_label; - - l = 0; - for (k = 0; k < nreg_acc; k += 8) { - for (i = 0, j = k; i < 8; i += 4, j += 2) { - if (j < nreg_acc) { - // shuffle per block of 4 registers - shuffle_and_add(Xbyak::Zmm(start_a_idx + l), // dest - Xbyak::Zmm(start_acc_idx + j), // A = acc0 - 
Xbyak::Zmm(start_acc_idx + 1 + j), // B = acc1 - Xbyak::Zmm(start_acc_idx + 4 + j), // C = acc4 - Xbyak::Zmm(start_acc_idx + 5 + j)); // D = acc5 - - // extract low and high from dest and hadd - vextracti32x8(Xbyak::Ymm(start_a_idx + l + 1), Xbyak::Zmm(start_a_idx + l), 0); - vextracti32x8(Xbyak::Ymm(start_a_idx + l + 2), Xbyak::Zmm(start_a_idx + l), 1); - vphaddd(Xbyak::Ymm(start_a_idx + l), - Xbyak::Ymm(start_a_idx + l + 1), - Xbyak::Ymm(start_a_idx + l + 2)); - } - l++; - } - - vphaddd(Xbyak::Ymm(start_a_idx + l), - Xbyak::Ymm(start_a_idx + l - 2), - Xbyak::Ymm(start_a_idx + l - 1)); - - l++; - } - - // eventually add with C and store new value - vxorps(Xbyak::Ymm(start_a_idx), - Xbyak::Ymm(start_a_idx), - Xbyak::Ymm(start_a_idx)); - vucomiss(beta, Xbyak::Ymm(start_a_idx)); - je(store_label, T_NEAR); - - // beta = 1 - for (k = 0, l = 2; k < nreg_acc; k += 8, l += 3) { - // load Y and add - last_it = (k + 8) > nreg_acc; - if (use_mask && last_it) - vmovdqu32(Xbyak::Ymm(start_a_idx + k / 8) | mask_m | T_z, ptr[Y + (k / 8) * 32]); - else - vmovdqu32(Xbyak::Ymm(start_a_idx + k / 8), ptr[Y + (k / 8) * 32]); - - vpaddd(Xbyak::Ymm(start_a_idx + l), - Xbyak::Ymm(start_a_idx + l), - Xbyak::Ymm(start_a_idx + k / 8)); - } - - // store - aligned_label(store_label); - for (k = 0, l = 2; k < nreg_acc; k += 8, l += 3) { - last_it = (k + 8) > nreg_acc; - if (use_mask && last_it) - vmovdqu32(ptr[Y + (k / 8) * 32], Xbyak::Ymm(start_a_idx + l) | mask_m); - else - vmovdqu32(ptr[Y + (k / 8) * 32], Xbyak::Ymm(start_a_idx + l)); - } - -} - -template -T jit_avx512_core_gemv_s8u8s32_kern::generate(int use_vnni) { - - Xbyak::Opmask mask_n = k1, mask_m = k2; - Xbyak::Label one_label, m_tail_label, m_loop_label, n_loop_label; - Xbyak::Label n_tail_label, update_c_label, end_label; - constexpr unsigned int n_labels = (1 << unroll_m) - 1; - Xbyak::Label m_tail_label_case[n_labels]; - Xbyak::Label n_loop_label_case[n_labels]; - Xbyak::Label n_tail_label_case[n_labels]; - Xbyak::Label update_c_label_case[n_labels]; - - int i, ii; - - Xbyak::Zmm one, tmp; - Xbyak::Reg64 n = abi_param2, m = abi_param1; - Xbyak::Reg64 A = is_windows ? abi_param4 : abi_param3; - Xbyak::Reg64 lda = is_windows ? abi_param3 : abi_param4; - Xbyak::Reg64 X = is_windows ? rdi : r8; - Xbyak::Xmm beta = xmm1; - Xbyak::Reg64 Y = is_windows ? 
rsi : r9; - - bool swap = !std::is_same::value; - - // Windows: read on the stack lda, X, beta, Y - - int zmm_idx = 1; - int nreg_acc = 1 << unroll_m; - int nreg_A = 1 << (unroll_m - 1); - int nreg_A_acc = nreg_acc + nreg_A; - - if (!use_vnni) { - // set a zmm register to one - tmp = Xbyak::Zmm(0); - one = Xbyak::Zmm(zmm_idx + 1); - zmm_idx += 2; // one + tmp - } - else { - beta = xmm0; - } - - preamble(); - - if (is_windows) { - mov(lda, ptr[rsp + get_size_of_abi_save_regs() + 40]); - mov(X, ptr[rsp + get_size_of_abi_save_regs() + 48]); - movss(beta, ptr[rsp + get_size_of_abi_save_regs() + 56]); - mov(Y, ptr[rsp + get_size_of_abi_save_regs() + 64]); - } - - if (use_vnni && !is_windows) { - movaps(beta, xmm1); - } - - mov(rax, (1 << unroll_n) - 1); - kmovq(k3, rax); - - and_(rax, n); // rax contains n & ((1 << unroll_n) - 1) - mov(rbx, 1); - shlx(rbx, rbx, rax); - sub(rbx, 1); - kmovq(mask_n, rbx); - // mask_n set (AVX512 only), can use rax and rbx again - - // set mask_m for update of the C matrix - // load/store on the C matrix use Ymm so tail according to Ymm size - mov(rax, 7); // 8 * 32 = 256 Ymm size - and_(rax, m); // rax contains m & 7 - mov(rbx, 1); - shlx(rbx, rbx, rax); - sub(rbx, 1); - kmovq(mask_m, rbx); - // mask_m set (AVX512 only), can use rax and rbx again - - // setup register of ones when VNNI instructions not available - if (!use_vnni) { - vmovdqu16(one, ptr[rip + one_label]); - } - - // M loop - // base pointer for A rax contains a + i * lda - // Loop stop when rax >= a + (m & mask_um) * lda = rbx - // loop increment r10 = um * lda - // rbp = Y + i - mov(rax, A); // i = 0 - mov(rbx, m); - and_(rbx, mask_um); - imul(rbx, lda); - add(rbx, A); - mov(r10, lda); - sal(r10, unroll_m); - mov(rbp, Y); - - // N loop - // base pointer for X r11 contains x + j - // Loop stop when r11 >= x + n & mask_un = r12 - // loop increment un - // r13 = rax + j = A + i * lda + j - mov(r12, n); - and_(r12, mask_un); - add(r12, X); - - // M loop - aligned_label(m_loop_label); - cmp(rax, rbx); - jge(m_tail_label, T_NEAR); - - // enter M loop - for(i = 0; i < nreg_acc; i++) { - vpxorq(Xbyak::Zmm(i + zmm_idx + nreg_A), - Xbyak::Zmm(i + zmm_idx + nreg_A), - Xbyak::Zmm(i + zmm_idx + nreg_A)); - } - - // N loop - mov(r11, X); // j = 0 - mov(r13, rax); - aligned_label(n_loop_label); - cmp(r11, r12); - jge(n_tail_label, T_NEAR); - - // enter N loop - - n_loop_body(zmm_idx, zmm_idx + nreg_A, zmm_idx + nreg_A_acc, nreg_acc, - r13, lda, r11, tmp, one, swap, use_vnni, 0, mask_n); - - // increment rax with un - add(r11, 1 << unroll_n); - add(r13, 1 << unroll_n); - jmp(n_loop_label, T_NEAR); - // end N loop - - // N tail - aligned_label(n_tail_label); - - ktestq(mask_n, k3); - je(update_c_label, T_NEAR); - n_loop_body(zmm_idx, zmm_idx + nreg_A, zmm_idx + nreg_A_acc, nreg_acc, - r13, lda, r11, tmp, one, swap, use_vnni, 1, mask_n); - - // update C matrix - aligned_label(update_c_label); - - update_c(nreg_acc, rbp, zmm_idx, zmm_idx + nreg_A, beta, 0, mask_m); - - // increment rax with um * lda - add(rax, r10); - add(rbp, 1 << (unroll_m + 2)); - jmp(m_loop_label, T_NEAR); - // end M loop - - // M tail - aligned_label(m_tail_label); - - // r10 will contain m_tail = m % unroll_m = m & (1 << unroll_m) - 1 - mov(r10, m); - and_(r10, (1 << unroll_m) - 1); - for (ii = 1; ii < 1 << unroll_m; ii++) { - aligned_label(m_tail_label_case[ii-1]); - cmp(r10, ii); - if (ii == (1 << unroll_m) - 1) - jne(end_label, T_NEAR); - else - jne(m_tail_label_case[ii], T_NEAR); - - // m_tail = i, use i accumulators - - for(i = 0; i < ii; 
i++) { - vpxorq(Xbyak::Zmm(i + zmm_idx + nreg_A), - Xbyak::Zmm(i + zmm_idx + nreg_A), - Xbyak::Zmm(i + zmm_idx + nreg_A)); - } - - // N loop - mov(r11, X); // j = 0 - mov(r13, rax); - aligned_label(n_loop_label_case[ii - 1]); - cmp(r11, r12); - jge(n_tail_label_case[ii - 1], T_NEAR); - - n_loop_body(zmm_idx, zmm_idx + nreg_A, zmm_idx + nreg_A_acc, ii, r13, - lda, r11, tmp, one, swap, use_vnni, 0, mask_n); - - // increment rax with un - add(r11, 1 << unroll_n); - add(r13, 1 << unroll_n); - jmp(n_loop_label_case[ii - 1], T_NEAR); - // end N loop - - // N tail - aligned_label(n_tail_label_case[ii - 1]); - ktestq(mask_n, k3); - je(update_c_label_case[ii - 1], T_NEAR); - n_loop_body(zmm_idx, zmm_idx + nreg_A, zmm_idx + nreg_A_acc, ii, r13, - lda, r11, tmp, one, swap, use_vnni, 1, mask_n); - - // update C matrix - aligned_label(update_c_label_case[ii - 1]); - update_c(ii, rbp, zmm_idx, zmm_idx + nreg_A, beta, 1, mask_m); - - if (ii < ((1 << unroll_m) - 1)) - jmp(end_label, T_NEAR); - } - - aligned_label(end_label); - - postamble(); - - if (!use_vnni) { - aligned_label(one_label); - for (i = 0; i < size_vec_reg/8; i++) - dq(0x0001000100010001); - } - - return (T) getCode(); -} - -template jit_avx512_core_gemv_s8u8s32_kern::gemv_s8u8s32_kernel_t -jit_avx512_core_gemv_s8u8s32_kern::generate(int); - -template jit_avx512_core_gemv_s8u8s32_kern::gemv_u8s8s32_kernel_t -jit_avx512_core_gemv_s8u8s32_kern::generate(int); - -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.hpp deleted file mode 100644 index 9ea23a5f5..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.hpp +++ /dev/null @@ -1,64 +0,0 @@ -/******************************************************************************* - * Copyright 2019 Intel Corporation - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- *******************************************************************************/
-
-#include "jit_generator.hpp"
-#include "common.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-class jit_avx512_core_gemv_s8u8s32_kern : jit_generator {
-
-    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_gemv_s8u8s32_kern);
-
-    // assumes unroll_{m,n} are a power of 2
-    static constexpr unsigned int unroll_m = 4; // real unrolling factor is 2^unroll_m
-    const int mask_um = 0xFFFFFFF0;
-    static constexpr unsigned int unroll_n = 6; // real unrolling factor is 2^unroll_n
-    const int mask_un = 0xFFFFFFC0;
-    const int size_vec_reg = 64; // bytes
-
-    void aligned_label(Xbyak::Label &label, int alignment = 16) {
-        align(alignment);
-        L(label);
-    }
-
-    void vnni(Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm, bool, int);
-    void n_loop_body(int, int, int, int, Xbyak::Reg64, Xbyak::Reg64,
-                     Xbyak::Reg64, Xbyak::Zmm, Xbyak::Zmm, bool, int, int, Xbyak::Opmask);
-    void shuffle_and_add(Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm);
-    void update_c(int, Xbyak::Reg64, int, int, Xbyak::Xmm, int, Xbyak::Opmask);
-
-public:
-    jit_avx512_core_gemv_s8u8s32_kern() : jit_generator(nullptr, GEMM_CODE_SIZE) {};
-
-    // m, n, alpha, a, lda, x, beta, y
-    typedef void (*gemv_s8u8s32_kernel_t)(const dim_t, const dim_t, const float,
-                                          const int8_t*, const dim_t, const uint8_t*,
-                                          const float, int32_t*);
-    typedef void (*gemv_u8s8s32_kernel_t)(const dim_t, const dim_t, const float,
-                                          const uint8_t*, const dim_t, const int8_t*,
-                                          const float, int32_t*);
-
-    template <typename T>
-    T generate(int use_vnni);
-
-};
-
-}
-}
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_an_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_an_kern.cpp
deleted file mode 100644
index 544cd2ff2..000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_an_kern.cpp
+++ /dev/null
@@ -1,819 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/ - -#include "jit_generator.hpp" -#include "common.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -jit_avx512_core_u8_copy_an_kern::jit_avx512_core_u8_copy_an_kern(): jit_generator(nullptr, GEMM_CODE_SIZE) -{ - -#ifndef _WIN32 -#define M rdi -#define N rsi -#define A rdx -#define LDA rcx -#define ALPHA r8 -#define B r9 - -#define I rax -#define A1 r10 -#define A2 r8 -#define LDA3 r11 - -#else - -#define M rcx -#define N rdx -#define A r8 -#define LDA r9 -#define ALPHA rax -#define B rdi - -#define I rax -#define A1 rsi -#define A2 r10 -#define LDA3 r11 - -#define ARG_ALPHA 40+stacksize+rsp -#define ARG_B 48+stacksize+rsp - -#endif - -inLocalLabel(); -{ - -Xbyak::Label l170; -Xbyak::Label l1f0; -Xbyak::Label l20; -Xbyak::Label l224; -Xbyak::Label l234; -Xbyak::Label l240; -Xbyak::Label l254; -Xbyak::Label l32c; -Xbyak::Label l34; -Xbyak::Label l388; -Xbyak::Label l3b0; -Xbyak::Label l3c0; -Xbyak::Label l3cc; -Xbyak::Label l3dc; -Xbyak::Label l454; -Xbyak::Label l48c; -Xbyak::Label l4a8; -Xbyak::Label l4b8; -Xbyak::Label l4c4; -Xbyak::Label l4d8; -Xbyak::Label l570; -Xbyak::Label l5c4; -Xbyak::Label l5f0; -Xbyak::Label l60c; -Xbyak::Label l61c; -Xbyak::Label l628; -Xbyak::Label l638; -Xbyak::Label l6b0; -Xbyak::Label l6f4; -Xbyak::Label l720; -Xbyak::Label l73c; -Xbyak::Label l74c; -Xbyak::Label l758; -Xbyak::Label l76c; -Xbyak::Label l804; -Xbyak::Label l858; -Xbyak::Label l88c; -Xbyak::Label l8a4; -Xbyak::Label l8b2; -Xbyak::Label l8bc; -Xbyak::Label l8cc; -Xbyak::Label l944; -Xbyak::Label l98c; -Xbyak::Label l9b0; -Xbyak::Label l9c8; -Xbyak::Label l9d8; - - preamble(); -#ifdef _WIN32 - auto stacksize = get_size_of_abi_save_regs(); - mov(ALPHA, ptr[ARG_ALPHA]); - mov(B, ptr[ARG_B]); -#endif - - mov(M, qword[M]); - mov(N, qword[N]); - mov(LDA, qword[LDA]); - lea(LDA3, ptr[LDA+LDA*2]); - sub(A, -128); - sub(B, -128); - cmp(N, 0x30); - jl(l234, T_NEAR); - align(4); - -L(l20); - mov(A1, A); - add(A, 0x30); - mov(I, M); - sar(I, 0x2); - jle(l170, T_NEAR); - align(4); - -L(l34); - movdqu(xmm0, xword[A1-0x80]); - movdqu(xmm1, xword[A1+LDA*1-0x80]); - movdqu(xmm2, xword[A1+LDA*2-0x80]); - movdqu(xmm3, xword[A1+LDA3*1-0x80]); - movdqa(xmm4, xmm0); - punpcklbw(xmm0, xmm1); - punpckhbw(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpcklbw(xmm2, xmm3); - punpckhbw(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklwd(xmm0, xmm2); - punpckhwd(xmm1, xmm2); - movdqa(xmm2, xmm4); - punpcklwd(xmm4, xmm5); - punpckhwd(xmm2, xmm5); - movdqu(xword[B-0x80], xmm0); - movdqu(xword[B-0x70], xmm1); - movdqu(xword[B-0x60], xmm4); - movdqu(xword[B-0x50], xmm2); - movdqu(xmm0, xword[A1-0x70]); - movdqu(xmm1, xword[A1+LDA*1-0x70]); - movdqu(xmm2, xword[A1+LDA*2-0x70]); - movdqu(xmm3, xword[A1+LDA3*1-0x70]); - movdqa(xmm4, xmm0); - punpcklbw(xmm0, xmm1); - punpckhbw(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpcklbw(xmm2, xmm3); - punpckhbw(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklwd(xmm0, xmm2); - punpckhwd(xmm1, xmm2); - movdqa(xmm2, xmm4); - punpcklwd(xmm4, xmm5); - punpckhwd(xmm2, xmm5); - movdqu(xword[B-0x40], xmm0); - movdqu(xword[B-0x30], xmm1); - movdqu(xword[B-0x20], xmm4); - movdqu(xword[B-0x10], xmm2); - movdqu(xmm0, xword[A1-0x60]); - movdqu(xmm1, xword[A1+LDA*1-0x60]); - movdqu(xmm2, xword[A1+LDA*2-0x60]); - movdqu(xmm3, xword[A1+LDA3*1-0x60]); - lea(A1, ptr[A1+LDA*4]); - movdqa(xmm4, xmm0); - punpcklbw(xmm0, xmm1); - punpckhbw(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpcklbw(xmm2, xmm3); - punpckhbw(xmm5, xmm3); - 
movdqa(xmm1, xmm0); - punpcklwd(xmm0, xmm2); - punpckhwd(xmm1, xmm2); - movdqa(xmm2, xmm4); - punpcklwd(xmm4, xmm5); - punpckhwd(xmm2, xmm5); - movdqu(xword[B], xmm0); - movdqu(xword[B+0x10], xmm1); - movdqu(xword[B+0x20], xmm4); - movdqu(xword[B+0x30], xmm2); - sub(B, -192); - dec(I); - jg(l34, T_NEAR); - align(4); - -L(l170); - test(M, 0x2); - jle(l1f0, T_NEAR); - movdqu(xmm0, xword[A1-0x80]); - movdqu(xmm1, xword[A1-0x70]); - movdqu(xmm2, xword[A1-0x60]); - add(A1, LDA); - movdqu(xmm3, xword[A1-0x80]); - movdqu(xmm4, xword[A1-0x70]); - movdqu(xmm5, xword[A1-0x60]); - add(A1, LDA); - movdqa(xmm6, xmm0); - punpcklbw(xmm0, xmm3); - punpckhbw(xmm6, xmm3); - movdqu(xword[B-0x80], xmm0); - movdqu(xword[B-0x70], xmm6); - movdqa(xmm6, xmm1); - punpcklbw(xmm1, xmm4); - punpckhbw(xmm6, xmm4); - movdqu(xword[B-0x60], xmm1); - movdqu(xword[B-0x50], xmm6); - movdqa(xmm6, xmm2); - punpcklbw(xmm2, xmm5); - punpckhbw(xmm6, xmm5); - movdqu(xword[B-0x40], xmm2); - movdqu(xword[B-0x30], xmm6); - sub(B, -96); - align(4); - -L(l1f0); - test(M, 0x1); - jle(l224, T_NEAR); - movdqu(xmm0, xword[A1-0x80]); - movdqu(xmm1, xword[A1-0x70]); - movdqu(xmm2, xword[A1-0x60]); - add(A1, LDA); - movdqu(xword[B-0x80], xmm0); - movdqu(xword[B-0x70], xmm1); - movdqu(xword[B-0x60], xmm2); - sub(B, -48); - align(4); - -L(l224); - sub(N, 0x30); - cmp(N, 0x30); - jge(l20, T_NEAR); - align(4); - -L(l234); - cmp(N, 0x20); - jl(l3c0, T_NEAR); - align(4); - -L(l240); - mov(A1, A); - add(A, 0x20); - mov(I, M); - sar(I, 0x2); - jle(l32c, T_NEAR); - align(4); - -L(l254); - movdqu(xmm0, xword[A1-0x80]); - movdqu(xmm1, xword[A1+LDA*1-0x80]); - movdqu(xmm2, xword[A1+LDA*2-0x80]); - movdqu(xmm3, xword[A1+LDA3*1-0x80]); - movdqa(xmm4, xmm0); - punpcklbw(xmm0, xmm1); - punpckhbw(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpcklbw(xmm2, xmm3); - punpckhbw(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklwd(xmm0, xmm2); - punpckhwd(xmm1, xmm2); - movdqa(xmm2, xmm4); - punpcklwd(xmm4, xmm5); - punpckhwd(xmm2, xmm5); - movdqu(xword[B-0x80], xmm0); - movdqu(xword[B-0x70], xmm1); - movdqu(xword[B-0x60], xmm4); - movdqu(xword[B-0x50], xmm2); - movdqu(xmm0, xword[A1-0x70]); - movdqu(xmm1, xword[A1+LDA*1-0x70]); - movdqu(xmm2, xword[A1+LDA*2-0x70]); - movdqu(xmm3, xword[A1+LDA3*1-0x70]); - lea(A1, ptr[A1+LDA*4]); - movdqa(xmm4, xmm0); - punpcklbw(xmm0, xmm1); - punpckhbw(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpcklbw(xmm2, xmm3); - punpckhbw(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklwd(xmm0, xmm2); - punpckhwd(xmm1, xmm2); - movdqa(xmm2, xmm4); - punpcklwd(xmm4, xmm5); - punpckhwd(xmm2, xmm5); - movdqu(xword[B-0x40], xmm0); - movdqu(xword[B-0x30], xmm1); - movdqu(xword[B-0x20], xmm4); - movdqu(xword[B-0x10], xmm2); - sub(B, -128); - dec(I); - jg(l254, T_NEAR); - align(4); - -L(l32c); - test(M, 0x2); - jle(l388, T_NEAR); - movdqu(xmm0, xword[A1-0x80]); - movdqu(xmm1, xword[A1-0x70]); - add(A1, LDA); - movdqu(xmm2, xword[A1-0x80]); - movdqu(xmm3, xword[A1-0x70]); - add(A1, LDA); - movdqa(xmm4, xmm0); - punpcklbw(xmm0, xmm2); - punpckhbw(xmm4, xmm2); - movdqu(xword[B-0x80], xmm0); - movdqu(xword[B-0x70], xmm4); - movdqa(xmm4, xmm1); - punpcklbw(xmm1, xmm3); - punpckhbw(xmm4, xmm3); - movdqu(xword[B-0x60], xmm1); - movdqu(xword[B-0x50], xmm4); - sub(B, -64); - align(4); - -L(l388); - test(M, 0x1); - jle(l3b0, T_NEAR); - movdqu(xmm0, xword[A1-0x80]); - movdqu(xmm1, xword[A1-0x70]); - add(A1, LDA); - movdqu(xword[B-0x80], xmm0); - movdqu(xword[B-0x70], xmm1); - sub(B, -32); - align(4); - -L(l3b0); - sub(N, 0x20); - cmp(N, 0x20); - jge(l240, T_NEAR); - align(4); - 
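The unrolled blocks above implement the panel packing of this copy kernel: four rows of the source strip are loaded and interleaved with punpcklbw/punpckhbw and then punpcklwd/punpckhwd, so that for every column the four row bytes land contiguously in the packed buffer, which is the layout the 4-wide dot products in the GEMM kernel consume. A scalar sketch of the full-block case (tails are handled by the test(M, ...) paths; names are illustrative):

    #include <cstddef>
    #include <cstdint>

    // Pack one 4-row strip column by column: out receives src[r*lda + c]
    // for r = 0..3, matching the SSE unpack sequence above.
    static void pack_4row_strip(const uint8_t *src, std::size_t lda,
                                std::size_t ncols, uint8_t *out) {
        for (std::size_t c = 0; c < ncols; ++c)
            for (std::size_t r = 0; r < 4; ++r)
                *out++ = src[r * lda + c];
    }

The remaining labelled sections repeat the same scheme for narrower column widths (0x10, 8, 4, 2, 1) and for the row remainders.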
-L(l3c0); - cmp(N, 0x10); - jl(l4b8, T_NEAR); - align(4); - -L(l3cc); - mov(A1, A); - add(A, 0x10); - mov(I, M); - sar(I, 0x2); - jle(l454, T_NEAR); - align(4); - -L(l3dc); - movdqu(xmm0, xword[A1-0x80]); - add(A1, LDA); - movdqu(xmm1, xword[A1-0x80]); - add(A1, LDA); - movdqu(xmm2, xword[A1-0x80]); - add(A1, LDA); - movdqu(xmm3, xword[A1-0x80]); - add(A1, LDA); - movdqa(xmm4, xmm0); - punpcklbw(xmm0, xmm1); - punpckhbw(xmm4, xmm1); - movdqa(xmm1, xmm2); - punpcklbw(xmm2, xmm3); - punpckhbw(xmm1, xmm3); - movdqa(xmm3, xmm0); - punpcklwd(xmm0, xmm2); - punpckhwd(xmm3, xmm2); - movdqa(xmm2, xmm4); - punpcklwd(xmm4, xmm1); - punpckhwd(xmm2, xmm1); - movdqu(xword[B-0x80], xmm0); - movdqu(xword[B-0x70], xmm3); - movdqu(xword[B-0x60], xmm4); - movdqu(xword[B-0x50], xmm2); - sub(B, -64); - dec(I); - jg(l3dc, T_NEAR); - align(4); - -L(l454); - test(M, 0x2); - jle(l48c, T_NEAR); - movdqu(xmm0, xword[A1-0x80]); - add(A1, LDA); - movdqu(xmm1, xword[A1-0x80]); - add(A1, LDA); - movdqa(xmm2, xmm0); - punpcklbw(xmm0, xmm1); - punpckhbw(xmm2, xmm1); - movdqu(xword[B-0x80], xmm0); - movdqu(xword[B-0x70], xmm2); - sub(B, -32); - align(4); - -L(l48c); - test(M, 0x1); - jle(l4a8, T_NEAR); - movdqu(xmm0, xword[A1-0x80]); - add(A1, LDA); - movdqu(xword[B-0x80], xmm0); - sub(B, -16); - align(4); - -L(l4a8); - sub(N, 0x10); - cmp(N, 0x10); - jge(l3cc, T_NEAR); - align(4); - -L(l4b8); - cmp(N, 0x8); - jl(l61c, T_NEAR); - align(4); - -L(l4c4); - mov(A1, A); - add(A, 0x8); - mov(I, M); - sar(I, 0x3); - jle(l570, T_NEAR); - align(4); - -L(l4d8); - movq(xmm0, qword[A1-0x80]); - add(A1, LDA); - movq(xmm1, qword[A1-0x80]); - add(A1, LDA); - movq(xmm2, qword[A1-0x80]); - add(A1, LDA); - movq(xmm3, qword[A1-0x80]); - add(A1, LDA); - punpcklbw(xmm0, xmm1); - punpcklbw(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklwd(xmm0, xmm2); - punpckhwd(xmm1, xmm2); - movdqu(xword[B-0x80], xmm0); - movdqu(xword[B-0x70], xmm1); - movq(xmm0, qword[A1-0x80]); - add(A1, LDA); - movq(xmm1, qword[A1-0x80]); - add(A1, LDA); - movq(xmm2, qword[A1-0x80]); - add(A1, LDA); - movq(xmm3, qword[A1-0x80]); - add(A1, LDA); - punpcklbw(xmm0, xmm1); - punpcklbw(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklwd(xmm0, xmm2); - punpckhwd(xmm1, xmm2); - movdqu(xword[B-0x60], xmm0); - movdqu(xword[B-0x50], xmm1); - sub(B, -64); - dec(I); - jg(l4d8, T_NEAR); - align(4); - -L(l570); - test(M, 0x4); - jle(l5c4, T_NEAR); - movq(xmm0, qword[A1-0x80]); - add(A1, LDA); - movq(xmm1, qword[A1-0x80]); - add(A1, LDA); - movq(xmm2, qword[A1-0x80]); - add(A1, LDA); - movq(xmm3, qword[A1-0x80]); - add(A1, LDA); - punpcklbw(xmm0, xmm1); - punpcklbw(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklwd(xmm0, xmm2); - punpckhwd(xmm1, xmm2); - movdqu(xword[B-0x80], xmm0); - movdqu(xword[B-0x70], xmm1); - sub(B, -32); - align(4); - -L(l5c4); - test(M, 0x2); - jle(l5f0, T_NEAR); - movq(xmm0, qword[A1-0x80]); - add(A1, LDA); - movq(xmm1, qword[A1-0x80]); - add(A1, LDA); - punpcklbw(xmm0, xmm1); - movdqu(xword[B-0x80], xmm0); - sub(B, -16); - align(4); - -L(l5f0); - test(M, 0x1); - jle(l60c, T_NEAR); - movq(xmm0, qword[A1-0x80]); - add(A1, LDA); - movq(qword[B-0x80], xmm0); - sub(B, -8); - align(4); - -L(l60c); - sub(N, 0x8); - cmp(N, 0x8); - jge(l4c4, T_NEAR); - align(4); - -L(l61c); - cmp(N, 0x4); - jl(l74c, T_NEAR); - align(4); - -L(l628); - mov(A1, A); - add(A, 0x4); - mov(I, M); - sar(I, 0x3); - jle(l6b0, T_NEAR); - align(4); - -L(l638); - movd(xmm0, dword[A1-0x80]); - add(A1, LDA); - movd(xmm1, dword[A1-0x80]); - add(A1, LDA); - movd(xmm2, dword[A1-0x80]); - add(A1, LDA); - movd(xmm3, 
dword[A1-0x80]); - add(A1, LDA); - punpcklbw(xmm0, xmm1); - punpcklbw(xmm2, xmm3); - punpcklwd(xmm0, xmm2); - movdqu(xword[B-0x80], xmm0); - movd(xmm0, dword[A1-0x80]); - add(A1, LDA); - movd(xmm1, dword[A1-0x80]); - add(A1, LDA); - movd(xmm2, dword[A1-0x80]); - add(A1, LDA); - movd(xmm3, dword[A1-0x80]); - add(A1, LDA); - punpcklbw(xmm0, xmm1); - punpcklbw(xmm2, xmm3); - punpcklwd(xmm0, xmm2); - movdqu(xword[B-0x70], xmm0); - sub(B, -32); - dec(I); - jg(l638, T_NEAR); - align(4); - -L(l6b0); - test(M, 0x4); - jle(l6f4, T_NEAR); - movd(xmm0, dword[A1-0x80]); - add(A1, LDA); - movd(xmm1, dword[A1-0x80]); - add(A1, LDA); - movd(xmm2, dword[A1-0x80]); - add(A1, LDA); - movd(xmm3, dword[A1-0x80]); - add(A1, LDA); - punpcklbw(xmm0, xmm1); - punpcklbw(xmm2, xmm3); - punpcklwd(xmm0, xmm2); - movdqu(xword[B-0x80], xmm0); - sub(B, -16); - align(4); - -L(l6f4); - test(M, 0x2); - jle(l720, T_NEAR); - movd(xmm0, dword[A1-0x80]); - add(A1, LDA); - movd(xmm1, dword[A1-0x80]); - add(A1, LDA); - punpcklbw(xmm0, xmm1); - movq(qword[B-0x80], xmm0); - sub(B, -8); - align(4); - -L(l720); - test(M, 0x1); - jle(l73c, T_NEAR); - movd(xmm0, dword[A1-0x80]); - movd(dword[B-0x80], xmm0); - sub(B, -4); - align(4); - -L(l73c); - sub(N, 0x4); - cmp(N, 0x4); - jge(l628, T_NEAR); - align(4); - -L(l74c); - cmp(N, 0x2); - jl(l8b2, T_NEAR); - align(4); - -L(l758); - mov(A1, A); - add(A, 0x2); - mov(LDA3, M); - sar(LDA3, 0x3); - jle(l804, T_NEAR); - align(4); - -L(l76c); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm0, eax, 0x0); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm1, eax, 0x0); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm2, eax, 0x0); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm3, eax, 0x0); - punpcklbw(xmm0, xmm1); - punpcklbw(xmm2, xmm3); - punpcklwd(xmm0, xmm2); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm1, eax, 0x0); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm2, eax, 0x0); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm3, eax, 0x0); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm4, eax, 0x0); - punpcklbw(xmm1, xmm2); - punpcklbw(xmm3, xmm4); - punpcklwd(xmm1, xmm3); - punpcklqdq(xmm0, xmm1); - movdqu(xword[B-0x80], xmm0); - sub(B, -16); - dec(LDA3); - jg(l76c, T_NEAR); - align(4); - -L(l804); - test(M, 0x4); - jle(l858, T_NEAR); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm0, eax, 0x0); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm1, eax, 0x0); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm2, eax, 0x0); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm3, eax, 0x0); - punpcklbw(xmm0, xmm1); - punpcklbw(xmm2, xmm3); - punpcklwd(xmm0, xmm2); - movq(qword[B-0x80], xmm0); - sub(B, -8); - align(4); - -L(l858); - test(M, 0x2); - jle(l88c, T_NEAR); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm0, eax, 0x0); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm1, eax, 0x0); - punpcklbw(xmm0, xmm1); - movd(dword[B-0x80], xmm0); - sub(B, -4); - align(4); - -L(l88c); - test(M, 0x1); - jle(l8a4, T_NEAR); - mov(ax, word[A1-0x80]); - mov(word[B-0x80], ax); - sub(B, -2); - align(4); - -L(l8a4); - sub(N, 0x2); - cmp(N, 0x2); - jge(l758, T_NEAR); - align(4); - -L(l8b2); - cmp(N, 0x1); - jl(l9d8, T_NEAR); - align(4); - -L(l8bc); - mov(A1, A); - add(A, 0x1); - mov(LDA3, M); - sar(LDA3, 0x3); - jle(l944, T_NEAR); - align(4); - -L(l8cc); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x0); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x1); - mov(al, byte[A1-0x80]); - add(A1, LDA); - 
pinsrb(xmm0, eax, 0x2); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x3); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x4); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x5); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x6); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x7); - movq(qword[B-0x80], xmm0); - sub(B, -8); - dec(LDA3); - jg(l8cc, T_NEAR); - align(4); - -L(l944); - test(M, 0x4); - jle(l98c, T_NEAR); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x0); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x1); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x2); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x3); - movd(dword[B-0x80], xmm0); - sub(B, -4); - align(4); - -L(l98c); - test(M, 0x2); - jle(l9b0, T_NEAR); - mov(al, byte[A1-0x80]); - add(A1, LDA); - mov(byte[B-0x80], al); - mov(al, byte[A1-0x80]); - add(A1, LDA); - mov(byte[B-0x7f], al); - sub(B, -2); - align(4); - -L(l9b0); - test(M, 0x1); - jle(l9c8, T_NEAR); - mov(al, byte[A1-0x80]); - mov(byte[B-0x80], al); - sub(B, -1); - align(4); - -L(l9c8); - sub(N, 0x1); - cmp(N, 0x1); - jge(l8bc, T_NEAR); - align(4); - -L(l9d8); - - postamble(); -} -outLocalLabel(); - -#undef M -#undef N -#undef A -#undef LDA -#undef ALPHA -#undef B -#undef I -#undef A1 -#undef A2 -#undef LDA3 -#ifdef _WIN32 -#undef ARG_ALPHA -#undef ARG_B -#endif -} - -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_at_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_at_kern.cpp deleted file mode 100644 index 1c11fc6ce..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_at_kern.cpp +++ /dev/null @@ -1,2209 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#include "jit_generator.hpp" -#include "common.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -jit_avx512_core_u8_copy_at_kern::jit_avx512_core_u8_copy_at_kern(): jit_generator(nullptr, GEMM_CODE_SIZE) -{ - -#ifndef _WIN32 -#define M rdi -#define N rsi -#define A rdx -#define LDA rcx -#define ALPHA r8 -#define B r9 - -#define I rax -#define A1 r10 -#define A2 r8 -#define LDA3 r11 - -#else - -#define M rcx -#define N rdx -#define A r8 -#define LDA r9 -#define ALPHA rax -#define B rdi - -#define I rax -#define A1 rsi -#define A2 r10 -#define LDA3 r11 - -#define ARG_ALPHA 40+stacksize+rsp -#define ARG_B 48+stacksize+rsp - -#endif - -inLocalLabel(); -{ - -Xbyak::Label l1014; -Xbyak::Label l1390; -Xbyak::Label l159c; -Xbyak::Label l173c; -Xbyak::Label l18e4; -Xbyak::Label l1a7c; -Xbyak::Label l1a8c; -Xbyak::Label l1a98; -Xbyak::Label l1ab4; -Xbyak::Label l1c64; -Xbyak::Label l1d74; -Xbyak::Label l1e50; -Xbyak::Label l1f2c; -Xbyak::Label l1ffc; -Xbyak::Label l20; -Xbyak::Label l200c; -Xbyak::Label l2018; -Xbyak::Label l2034; -Xbyak::Label l2110; -Xbyak::Label l21a0; -Xbyak::Label l2210; -Xbyak::Label l2284; -Xbyak::Label l22f0; -Xbyak::Label l2300; -Xbyak::Label l230c; -Xbyak::Label l2324; -Xbyak::Label l2398; -Xbyak::Label l23e8; -Xbyak::Label l242c; -Xbyak::Label l2474; -Xbyak::Label l24b4; -Xbyak::Label l24c4; -Xbyak::Label l24d0; -Xbyak::Label l24e8; -Xbyak::Label l2520; -Xbyak::Label l254c; -Xbyak::Label l2578; -Xbyak::Label l25a8; -Xbyak::Label l25c8; -Xbyak::Label l25d6; -Xbyak::Label l25e0; -Xbyak::Label l25f0; -Xbyak::Label l260c; -Xbyak::Label l262c; -Xbyak::Label l264c; -Xbyak::Label l2668; -Xbyak::Label l2680; -Xbyak::Label l2690; -Xbyak::Label l44; -Xbyak::Label l58c; -Xbyak::Label l8b0; -Xbyak::Label lb14; -Xbyak::Label ld84; -Xbyak::Label lfdc; -Xbyak::Label lfec; -Xbyak::Label lff8; - - preamble(); -#ifdef _WIN32 - auto stacksize = get_size_of_abi_save_regs(); - mov(ALPHA, ptr[ARG_ALPHA]); - mov(B, ptr[ARG_B]); -#endif - - mov(N, qword[N]); - mov(M, qword[M]); - mov(LDA, qword[LDA]); - sub(A, -128); - sub(B, -128); - lea(LDA3, ptr[LDA+LDA*2]); - cmp(N, 0x30); - jl(lfec, T_NEAR); - align(4); - -L(l20); - mov(A1, A); - mov(I, LDA); - shl(I, 0x5); - lea(I, ptr[I+LDA*8]); - lea(I, ptr[I+LDA*8]); - add(A, I); - mov(I, M); - sar(I, 0x4); - jle(l58c, T_NEAR); - align(4); - -L(l44); - movdqu(xmm0, xword[A1-0x80]); - movdqu(xmm1, xword[A1+LDA*1-0x80]); - movdqu(xmm2, xword[A1+LDA*2-0x80]); - movdqu(xmm3, xword[A1+LDA3*1-0x80]); - lea(A2, ptr[A1+LDA*4]); - movdqa(xmm4, xmm0); - punpckldq(xmm0, xmm1); - punpckhdq(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpckldq(xmm2, xmm3); - punpckhdq(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqa(xmm3, xmm4); - punpcklqdq(xmm4, xmm5); - punpckhqdq(xmm3, xmm5); - movdqu(xword[B-0x80], xmm0); - movdqu(xword[B+0x40], xmm1); - movdqu(xword[B+0x100], xmm4); - movdqu(xword[B+0x1c0], xmm3); - movdqu(xmm0, xword[A2-0x80]); - movdqu(xmm1, xword[A2+LDA*1-0x80]); - movdqu(xmm2, xword[A2+LDA*2-0x80]); - movdqu(xmm3, xword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - movdqa(xmm4, xmm0); - punpckldq(xmm0, xmm1); - punpckhdq(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpckldq(xmm2, xmm3); - punpckhdq(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqa(xmm3, xmm4); - punpcklqdq(xmm4, xmm5); - punpckhqdq(xmm3, xmm5); - movdqu(xword[B-0x70], xmm0); - movdqu(xword[B+0x50], xmm1); - 
movdqu(xword[B+0x110], xmm4); - movdqu(xword[B+0x1d0], xmm3); - movdqu(xmm0, xword[A2-0x80]); - movdqu(xmm1, xword[A2+LDA*1-0x80]); - movdqu(xmm2, xword[A2+LDA*2-0x80]); - movdqu(xmm3, xword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - movdqa(xmm4, xmm0); - punpckldq(xmm0, xmm1); - punpckhdq(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpckldq(xmm2, xmm3); - punpckhdq(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqa(xmm3, xmm4); - punpcklqdq(xmm4, xmm5); - punpckhqdq(xmm3, xmm5); - movdqu(xword[B-0x60], xmm0); - movdqu(xword[B+0x60], xmm1); - movdqu(xword[B+0x120], xmm4); - movdqu(xword[B+0x1e0], xmm3); - movdqu(xmm0, xword[A2-0x80]); - movdqu(xmm1, xword[A2+LDA*1-0x80]); - movdqu(xmm2, xword[A2+LDA*2-0x80]); - movdqu(xmm3, xword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - movdqa(xmm4, xmm0); - punpckldq(xmm0, xmm1); - punpckhdq(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpckldq(xmm2, xmm3); - punpckhdq(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqa(xmm3, xmm4); - punpcklqdq(xmm4, xmm5); - punpckhqdq(xmm3, xmm5); - movdqu(xword[B-0x50], xmm0); - movdqu(xword[B+0x70], xmm1); - movdqu(xword[B+0x130], xmm4); - movdqu(xword[B+0x1f0], xmm3); - movdqu(xmm0, xword[A2-0x80]); - movdqu(xmm1, xword[A2+LDA*1-0x80]); - movdqu(xmm2, xword[A2+LDA*2-0x80]); - movdqu(xmm3, xword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - movdqa(xmm4, xmm0); - punpckldq(xmm0, xmm1); - punpckhdq(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpckldq(xmm2, xmm3); - punpckhdq(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqa(xmm3, xmm4); - punpcklqdq(xmm4, xmm5); - punpckhqdq(xmm3, xmm5); - movdqu(xword[B-0x40], xmm0); - movdqu(xword[B+0x80], xmm1); - movdqu(xword[B+0x140], xmm4); - movdqu(xword[B+0x200], xmm3); - movdqu(xmm0, xword[A2-0x80]); - movdqu(xmm1, xword[A2+LDA*1-0x80]); - movdqu(xmm2, xword[A2+LDA*2-0x80]); - movdqu(xmm3, xword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - movdqa(xmm4, xmm0); - punpckldq(xmm0, xmm1); - punpckhdq(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpckldq(xmm2, xmm3); - punpckhdq(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqa(xmm3, xmm4); - punpcklqdq(xmm4, xmm5); - punpckhqdq(xmm3, xmm5); - movdqu(xword[B-0x30], xmm0); - movdqu(xword[B+0x90], xmm1); - movdqu(xword[B+0x150], xmm4); - movdqu(xword[B+0x210], xmm3); - movdqu(xmm0, xword[A2-0x80]); - movdqu(xmm1, xword[A2+LDA*1-0x80]); - movdqu(xmm2, xword[A2+LDA*2-0x80]); - movdqu(xmm3, xword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - movdqa(xmm4, xmm0); - punpckldq(xmm0, xmm1); - punpckhdq(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpckldq(xmm2, xmm3); - punpckhdq(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqa(xmm3, xmm4); - punpcklqdq(xmm4, xmm5); - punpckhqdq(xmm3, xmm5); - movdqu(xword[B-0x20], xmm0); - movdqu(xword[B+0xa0], xmm1); - movdqu(xword[B+0x160], xmm4); - movdqu(xword[B+0x220], xmm3); - movdqu(xmm0, xword[A2-0x80]); - movdqu(xmm1, xword[A2+LDA*1-0x80]); - movdqu(xmm2, xword[A2+LDA*2-0x80]); - movdqu(xmm3, xword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - movdqa(xmm4, xmm0); - punpckldq(xmm0, xmm1); - punpckhdq(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpckldq(xmm2, xmm3); - punpckhdq(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqa(xmm3, xmm4); - punpcklqdq(xmm4, xmm5); - punpckhqdq(xmm3, xmm5); - movdqu(xword[B-0x10], xmm0); - movdqu(xword[B+0xb0], xmm1); - 
movdqu(xword[B+0x170], xmm4); - movdqu(xword[B+0x230], xmm3); - movdqu(xmm0, xword[A2-0x80]); - movdqu(xmm1, xword[A2+LDA*1-0x80]); - movdqu(xmm2, xword[A2+LDA*2-0x80]); - movdqu(xmm3, xword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - movdqa(xmm4, xmm0); - punpckldq(xmm0, xmm1); - punpckhdq(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpckldq(xmm2, xmm3); - punpckhdq(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqa(xmm3, xmm4); - punpcklqdq(xmm4, xmm5); - punpckhqdq(xmm3, xmm5); - movdqu(xword[B], xmm0); - movdqu(xword[B+0xc0], xmm1); - movdqu(xword[B+0x180], xmm4); - movdqu(xword[B+0x240], xmm3); - movdqu(xmm0, xword[A2-0x80]); - movdqu(xmm1, xword[A2+LDA*1-0x80]); - movdqu(xmm2, xword[A2+LDA*2-0x80]); - movdqu(xmm3, xword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - movdqa(xmm4, xmm0); - punpckldq(xmm0, xmm1); - punpckhdq(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpckldq(xmm2, xmm3); - punpckhdq(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqa(xmm3, xmm4); - punpcklqdq(xmm4, xmm5); - punpckhqdq(xmm3, xmm5); - movdqu(xword[B+0x10], xmm0); - movdqu(xword[B+0xd0], xmm1); - movdqu(xword[B+0x190], xmm4); - movdqu(xword[B+0x250], xmm3); - movdqu(xmm0, xword[A2-0x80]); - movdqu(xmm1, xword[A2+LDA*1-0x80]); - movdqu(xmm2, xword[A2+LDA*2-0x80]); - movdqu(xmm3, xword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - movdqa(xmm4, xmm0); - punpckldq(xmm0, xmm1); - punpckhdq(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpckldq(xmm2, xmm3); - punpckhdq(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqa(xmm3, xmm4); - punpcklqdq(xmm4, xmm5); - punpckhqdq(xmm3, xmm5); - movdqu(xword[B+0x20], xmm0); - movdqu(xword[B+0xe0], xmm1); - movdqu(xword[B+0x1a0], xmm4); - movdqu(xword[B+0x260], xmm3); - movdqu(xmm0, xword[A2-0x80]); - movdqu(xmm1, xword[A2+LDA*1-0x80]); - movdqu(xmm2, xword[A2+LDA*2-0x80]); - movdqu(xmm3, xword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - movdqa(xmm4, xmm0); - punpckldq(xmm0, xmm1); - punpckhdq(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpckldq(xmm2, xmm3); - punpckhdq(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqa(xmm3, xmm4); - punpcklqdq(xmm4, xmm5); - punpckhqdq(xmm3, xmm5); - movdqu(xword[B+0x30], xmm0); - movdqu(xword[B+0xf0], xmm1); - movdqu(xword[B+0x1b0], xmm4); - movdqu(xword[B+0x270], xmm3); - sub(A1, -16); - sub(B, -768); - dec(I); - jg(l44, T_NEAR); - align(4); - -L(l58c); - test(M, 0x8); - jle(l8b0, T_NEAR); - movq(xmm0, qword[A1-0x80]); - movq(xmm1, qword[A1+LDA*1-0x80]); - movq(xmm2, qword[A1+LDA*2-0x80]); - movq(xmm3, qword[A1+LDA3*1-0x80]); - lea(A2, ptr[A1+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqu(xword[B-0x80], xmm0); - movdqu(xword[B+0x40], xmm1); - movq(xmm0, qword[A2-0x80]); - movq(xmm1, qword[A2+LDA*1-0x80]); - movq(xmm2, qword[A2+LDA*2-0x80]); - movq(xmm3, qword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqu(xword[B-0x70], xmm0); - movdqu(xword[B+0x50], xmm1); - movq(xmm0, qword[A2-0x80]); - movq(xmm1, qword[A2+LDA*1-0x80]); - movq(xmm2, qword[A2+LDA*2-0x80]); - movq(xmm3, qword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - 
movdqu(xword[B-0x60], xmm0); - movdqu(xword[B+0x60], xmm1); - movq(xmm0, qword[A2-0x80]); - movq(xmm1, qword[A2+LDA*1-0x80]); - movq(xmm2, qword[A2+LDA*2-0x80]); - movq(xmm3, qword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqu(xword[B-0x50], xmm0); - movdqu(xword[B+0x70], xmm1); - movq(xmm0, qword[A2-0x80]); - movq(xmm1, qword[A2+LDA*1-0x80]); - movq(xmm2, qword[A2+LDA*2-0x80]); - movq(xmm3, qword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqu(xword[B-0x40], xmm0); - movdqu(xword[B+0x80], xmm1); - movq(xmm0, qword[A2-0x80]); - movq(xmm1, qword[A2+LDA*1-0x80]); - movq(xmm2, qword[A2+LDA*2-0x80]); - movq(xmm3, qword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqu(xword[B-0x30], xmm0); - movdqu(xword[B+0x90], xmm1); - movq(xmm0, qword[A2-0x80]); - movq(xmm1, qword[A2+LDA*1-0x80]); - movq(xmm2, qword[A2+LDA*2-0x80]); - movq(xmm3, qword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqu(xword[B-0x20], xmm0); - movdqu(xword[B+0xa0], xmm1); - movq(xmm0, qword[A2-0x80]); - movq(xmm1, qword[A2+LDA*1-0x80]); - movq(xmm2, qword[A2+LDA*2-0x80]); - movq(xmm3, qword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqu(xword[B-0x10], xmm0); - movdqu(xword[B+0xb0], xmm1); - movq(xmm0, qword[A2-0x80]); - movq(xmm1, qword[A2+LDA*1-0x80]); - movq(xmm2, qword[A2+LDA*2-0x80]); - movq(xmm3, qword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqu(xword[B], xmm0); - movdqu(xword[B+0xc0], xmm1); - movq(xmm0, qword[A2-0x80]); - movq(xmm1, qword[A2+LDA*1-0x80]); - movq(xmm2, qword[A2+LDA*2-0x80]); - movq(xmm3, qword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqu(xword[B+0x10], xmm0); - movdqu(xword[B+0xd0], xmm1); - movq(xmm0, qword[A2-0x80]); - movq(xmm1, qword[A2+LDA*1-0x80]); - movq(xmm2, qword[A2+LDA*2-0x80]); - movq(xmm3, qword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqu(xword[B+0x20], xmm0); - movdqu(xword[B+0xe0], xmm1); - movq(xmm0, qword[A2-0x80]); - movq(xmm1, qword[A2+LDA*1-0x80]); - movq(xmm2, qword[A2+LDA*2-0x80]); - movq(xmm3, qword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqu(xword[B+0x30], xmm0); - movdqu(xword[B+0xf0], xmm1); - sub(A1, -8); - sub(B, -384); - align(4); - -L(l8b0); - test(M, 0x4); - jle(lb14, T_NEAR); - movd(xmm0, dword[A1-0x80]); - movd(xmm1, dword[A1+LDA*1-0x80]); - movd(xmm2, dword[A1+LDA*2-0x80]); - movd(xmm3, dword[A1+LDA3*1-0x80]); - lea(A2, ptr[A1+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - punpcklqdq(xmm0, xmm2); - 
movdqu(xword[B-0x80], xmm0); - movd(xmm0, dword[A2-0x80]); - movd(xmm1, dword[A2+LDA*1-0x80]); - movd(xmm2, dword[A2+LDA*2-0x80]); - movd(xmm3, dword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - punpcklqdq(xmm0, xmm2); - movdqu(xword[B-0x70], xmm0); - movd(xmm0, dword[A2-0x80]); - movd(xmm1, dword[A2+LDA*1-0x80]); - movd(xmm2, dword[A2+LDA*2-0x80]); - movd(xmm3, dword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - punpcklqdq(xmm0, xmm2); - movdqu(xword[B-0x60], xmm0); - movd(xmm0, dword[A2-0x80]); - movd(xmm1, dword[A2+LDA*1-0x80]); - movd(xmm2, dword[A2+LDA*2-0x80]); - movd(xmm3, dword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - punpcklqdq(xmm0, xmm2); - movdqu(xword[B-0x50], xmm0); - movd(xmm0, dword[A2-0x80]); - movd(xmm1, dword[A2+LDA*1-0x80]); - movd(xmm2, dword[A2+LDA*2-0x80]); - movd(xmm3, dword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - punpcklqdq(xmm0, xmm2); - movdqu(xword[B-0x40], xmm0); - movd(xmm0, dword[A2-0x80]); - movd(xmm1, dword[A2+LDA*1-0x80]); - movd(xmm2, dword[A2+LDA*2-0x80]); - movd(xmm3, dword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - punpcklqdq(xmm0, xmm2); - movdqu(xword[B-0x30], xmm0); - movd(xmm0, dword[A2-0x80]); - movd(xmm1, dword[A2+LDA*1-0x80]); - movd(xmm2, dword[A2+LDA*2-0x80]); - movd(xmm3, dword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - punpcklqdq(xmm0, xmm2); - movdqu(xword[B-0x20], xmm0); - movd(xmm0, dword[A2-0x80]); - movd(xmm1, dword[A2+LDA*1-0x80]); - movd(xmm2, dword[A2+LDA*2-0x80]); - movd(xmm3, dword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - punpcklqdq(xmm0, xmm2); - movdqu(xword[B-0x10], xmm0); - movd(xmm0, dword[A2-0x80]); - movd(xmm1, dword[A2+LDA*1-0x80]); - movd(xmm2, dword[A2+LDA*2-0x80]); - movd(xmm3, dword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - punpcklqdq(xmm0, xmm2); - movdqu(xword[B], xmm0); - movd(xmm0, dword[A2-0x80]); - movd(xmm1, dword[A2+LDA*1-0x80]); - movd(xmm2, dword[A2+LDA*2-0x80]); - movd(xmm3, dword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - punpcklqdq(xmm0, xmm2); - movdqu(xword[B+0x10], xmm0); - movd(xmm0, dword[A2-0x80]); - movd(xmm1, dword[A2+LDA*1-0x80]); - movd(xmm2, dword[A2+LDA*2-0x80]); - movd(xmm3, dword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - punpcklqdq(xmm0, xmm2); - movdqu(xword[B+0x20], xmm0); - movd(xmm0, dword[A2-0x80]); - movd(xmm1, dword[A2+LDA*1-0x80]); - movd(xmm2, dword[A2+LDA*2-0x80]); - movd(xmm3, dword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - punpcklqdq(xmm0, xmm2); - movdqu(xword[B+0x30], xmm0); - sub(A1, -4); - sub(B, -192); - align(4); - -L(lb14); - test(M, 0x2); - jle(ld84, T_NEAR); - mov(ax, word[A1-0x80]); - pinsrw(xmm0, eax, 0x0); - mov(ax, word[A1+LDA*1-0x80]); - pinsrw(xmm0, eax, 0x1); - mov(ax, word[A1+LDA*2-0x80]); - pinsrw(xmm0, eax, 0x2); - mov(ax, word[A1+LDA3*1-0x80]); - lea(A2, ptr[A1+LDA*4]); - pinsrw(xmm0, eax, 0x3); - mov(ax, word[A2-0x80]); - pinsrw(xmm0, eax, 0x4); - mov(ax, word[A2+LDA*1-0x80]); - pinsrw(xmm0, eax, 0x5); - mov(ax, word[A2+LDA*2-0x80]); - pinsrw(xmm0, eax, 0x6); - mov(ax, 
word[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - pinsrw(xmm0, eax, 0x7); - movdqu(xword[B-0x80], xmm0); - mov(ax, word[A2-0x80]); - pinsrw(xmm0, eax, 0x0); - mov(ax, word[A2+LDA*1-0x80]); - pinsrw(xmm0, eax, 0x1); - mov(ax, word[A2+LDA*2-0x80]); - pinsrw(xmm0, eax, 0x2); - mov(ax, word[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - pinsrw(xmm0, eax, 0x3); - mov(ax, word[A2-0x80]); - pinsrw(xmm0, eax, 0x4); - mov(ax, word[A2+LDA*1-0x80]); - pinsrw(xmm0, eax, 0x5); - mov(ax, word[A2+LDA*2-0x80]); - pinsrw(xmm0, eax, 0x6); - mov(ax, word[A2+LDA3*1-0x80]); - pinsrw(xmm0, eax, 0x7); - lea(A2, ptr[A2+LDA*4]); - movdqu(xword[B-0x70], xmm0); - mov(ax, word[A2-0x80]); - pinsrw(xmm0, eax, 0x0); - mov(ax, word[A2+LDA*1-0x80]); - pinsrw(xmm0, eax, 0x1); - mov(ax, word[A2+LDA*2-0x80]); - pinsrw(xmm0, eax, 0x2); - mov(ax, word[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - pinsrw(xmm0, eax, 0x3); - mov(ax, word[A2-0x80]); - pinsrw(xmm0, eax, 0x4); - mov(ax, word[A2+LDA*1-0x80]); - pinsrw(xmm0, eax, 0x5); - mov(ax, word[A2+LDA*2-0x80]); - pinsrw(xmm0, eax, 0x6); - mov(ax, word[A2+LDA3*1-0x80]); - pinsrw(xmm0, eax, 0x7); - lea(A2, ptr[A2+LDA*4]); - movdqu(xword[B-0x60], xmm0); - mov(ax, word[A2-0x80]); - pinsrw(xmm0, eax, 0x0); - mov(ax, word[A2+LDA*1-0x80]); - pinsrw(xmm0, eax, 0x1); - mov(ax, word[A2+LDA*2-0x80]); - pinsrw(xmm0, eax, 0x2); - mov(ax, word[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - pinsrw(xmm0, eax, 0x3); - mov(ax, word[A2-0x80]); - pinsrw(xmm0, eax, 0x4); - mov(ax, word[A2+LDA*1-0x80]); - pinsrw(xmm0, eax, 0x5); - mov(ax, word[A2+LDA*2-0x80]); - pinsrw(xmm0, eax, 0x6); - mov(ax, word[A2+LDA3*1-0x80]); - pinsrw(xmm0, eax, 0x7); - lea(A2, ptr[A2+LDA*4]); - movdqu(xword[B-0x50], xmm0); - mov(ax, word[A2-0x80]); - pinsrw(xmm0, eax, 0x0); - mov(ax, word[A2+LDA*1-0x80]); - pinsrw(xmm0, eax, 0x1); - mov(ax, word[A2+LDA*2-0x80]); - pinsrw(xmm0, eax, 0x2); - mov(ax, word[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - pinsrw(xmm0, eax, 0x3); - mov(ax, word[A2-0x80]); - pinsrw(xmm0, eax, 0x4); - mov(ax, word[A2+LDA*1-0x80]); - pinsrw(xmm0, eax, 0x5); - mov(ax, word[A2+LDA*2-0x80]); - pinsrw(xmm0, eax, 0x6); - mov(ax, word[A2+LDA3*1-0x80]); - pinsrw(xmm0, eax, 0x7); - lea(A2, ptr[A2+LDA*4]); - movdqu(xword[B-0x40], xmm0); - mov(ax, word[A2-0x80]); - pinsrw(xmm0, eax, 0x0); - mov(ax, word[A2+LDA*1-0x80]); - pinsrw(xmm0, eax, 0x1); - mov(ax, word[A2+LDA*2-0x80]); - pinsrw(xmm0, eax, 0x2); - mov(ax, word[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - pinsrw(xmm0, eax, 0x3); - mov(ax, word[A2-0x80]); - pinsrw(xmm0, eax, 0x4); - mov(ax, word[A2+LDA*1-0x80]); - pinsrw(xmm0, eax, 0x5); - mov(ax, word[A2+LDA*2-0x80]); - pinsrw(xmm0, eax, 0x6); - mov(ax, word[A2+LDA3*1-0x80]); - pinsrw(xmm0, eax, 0x7); - lea(A2, ptr[A2+LDA*4]); - movdqu(xword[B-0x30], xmm0); - sub(A1, -2); - sub(B, -96); - align(4); - -L(ld84); - test(M, 0x1); - jle(lfdc, T_NEAR); - mov(al, byte[A1-0x80]); - pinsrb(xmm0, eax, 0x0); - mov(al, byte[A1+LDA*1-0x80]); - pinsrb(xmm0, eax, 0x1); - mov(al, byte[A1+LDA*2-0x80]); - pinsrb(xmm0, eax, 0x2); - mov(al, byte[A1+LDA3*1-0x80]); - lea(A2, ptr[A1+LDA*4]); - pinsrb(xmm0, eax, 0x3); - mov(al, byte[A2-0x80]); - pinsrb(xmm0, eax, 0x4); - mov(al, byte[A2+LDA*1-0x80]); - pinsrb(xmm0, eax, 0x5); - mov(al, byte[A2+LDA*2-0x80]); - pinsrb(xmm0, eax, 0x6); - mov(al, byte[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - pinsrb(xmm0, eax, 0x7); - mov(al, byte[A2-0x80]); - pinsrb(xmm0, eax, 0x8); - mov(al, byte[A2+LDA*1-0x80]); - pinsrb(xmm0, eax, 0x9); - mov(al, byte[A2+LDA*2-0x80]); - pinsrb(xmm0, eax, 
0xa); - mov(al, byte[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - pinsrb(xmm0, eax, 0xb); - mov(al, byte[A2-0x80]); - pinsrb(xmm0, eax, 0xc); - mov(al, byte[A2+LDA*1-0x80]); - pinsrb(xmm0, eax, 0xd); - mov(al, byte[A2+LDA*2-0x80]); - pinsrb(xmm0, eax, 0xe); - mov(al, byte[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - pinsrb(xmm0, eax, 0xf); - movdqu(xword[B-0x80], xmm0); - mov(al, byte[A2-0x80]); - pinsrb(xmm0, eax, 0x0); - mov(al, byte[A2+LDA*1-0x80]); - pinsrb(xmm0, eax, 0x1); - mov(al, byte[A2+LDA*2-0x80]); - pinsrb(xmm0, eax, 0x2); - mov(al, byte[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - pinsrb(xmm0, eax, 0x3); - mov(al, byte[A2-0x80]); - pinsrb(xmm0, eax, 0x4); - mov(al, byte[A2+LDA*1-0x80]); - pinsrb(xmm0, eax, 0x5); - mov(al, byte[A2+LDA*2-0x80]); - pinsrb(xmm0, eax, 0x6); - mov(al, byte[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - pinsrb(xmm0, eax, 0x7); - mov(al, byte[A2-0x80]); - pinsrb(xmm0, eax, 0x8); - mov(al, byte[A2+LDA*1-0x80]); - pinsrb(xmm0, eax, 0x9); - mov(al, byte[A2+LDA*2-0x80]); - pinsrb(xmm0, eax, 0xa); - mov(al, byte[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - pinsrb(xmm0, eax, 0xb); - mov(al, byte[A2-0x80]); - pinsrb(xmm0, eax, 0xc); - mov(al, byte[A2+LDA*1-0x80]); - pinsrb(xmm0, eax, 0xd); - mov(al, byte[A2+LDA*2-0x80]); - pinsrb(xmm0, eax, 0xe); - mov(al, byte[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - pinsrb(xmm0, eax, 0xf); - movdqu(xword[B-0x70], xmm0); - mov(al, byte[A2-0x80]); - pinsrb(xmm0, eax, 0x0); - mov(al, byte[A2+LDA*1-0x80]); - pinsrb(xmm0, eax, 0x1); - mov(al, byte[A2+LDA*2-0x80]); - pinsrb(xmm0, eax, 0x2); - mov(al, byte[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - pinsrb(xmm0, eax, 0x3); - mov(al, byte[A2-0x80]); - pinsrb(xmm0, eax, 0x4); - mov(al, byte[A2+LDA*1-0x80]); - pinsrb(xmm0, eax, 0x5); - mov(al, byte[A2+LDA*2-0x80]); - pinsrb(xmm0, eax, 0x6); - mov(al, byte[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - pinsrb(xmm0, eax, 0x7); - mov(al, byte[A2-0x80]); - pinsrb(xmm0, eax, 0x8); - mov(al, byte[A2+LDA*1-0x80]); - pinsrb(xmm0, eax, 0x9); - mov(al, byte[A2+LDA*2-0x80]); - pinsrb(xmm0, eax, 0xa); - mov(al, byte[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - pinsrb(xmm0, eax, 0xb); - mov(al, byte[A2-0x80]); - pinsrb(xmm0, eax, 0xc); - mov(al, byte[A2+LDA*1-0x80]); - pinsrb(xmm0, eax, 0xd); - mov(al, byte[A2+LDA*2-0x80]); - pinsrb(xmm0, eax, 0xe); - mov(al, byte[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - pinsrb(xmm0, eax, 0xf); - movdqu(xword[B-0x60], xmm0); - sub(B, -48); - align(4); - -L(lfdc); - sub(N, 0x30); - cmp(N, 0x30); - jge(l20, T_NEAR); - align(4); - -L(lfec); - cmp(N, 0x20); - jl(l1a8c, T_NEAR); - align(4); - -L(lff8); - mov(A1, A); - mov(I, LDA); - shl(I, 0x5); - add(A, I); - mov(I, M); - sar(I, 0x4); - jle(l1390, T_NEAR); - align(4); - -L(l1014); - movdqu(xmm0, xword[A1-0x80]); - movdqu(xmm1, xword[A1+LDA*1-0x80]); - movdqu(xmm2, xword[A1+LDA*2-0x80]); - movdqu(xmm3, xword[A1+LDA3*1-0x80]); - lea(A2, ptr[A1+LDA*4]); - movdqa(xmm4, xmm0); - punpckldq(xmm0, xmm1); - punpckhdq(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpckldq(xmm2, xmm3); - punpckhdq(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqa(xmm3, xmm4); - punpcklqdq(xmm4, xmm5); - punpckhqdq(xmm3, xmm5); - movdqu(xword[B-0x80], xmm0); - movdqu(xword[B], xmm1); - movdqu(xword[B+0x80], xmm4); - movdqu(xword[B+0x100], xmm3); - movdqu(xmm0, xword[A2-0x80]); - movdqu(xmm1, xword[A2+LDA*1-0x80]); - movdqu(xmm2, xword[A2+LDA*2-0x80]); - movdqu(xmm3, xword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - movdqa(xmm4, xmm0); - 
punpckldq(xmm0, xmm1); - punpckhdq(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpckldq(xmm2, xmm3); - punpckhdq(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqa(xmm3, xmm4); - punpcklqdq(xmm4, xmm5); - punpckhqdq(xmm3, xmm5); - movdqu(xword[B-0x70], xmm0); - movdqu(xword[B+0x10], xmm1); - movdqu(xword[B+0x90], xmm4); - movdqu(xword[B+0x110], xmm3); - movdqu(xmm0, xword[A2-0x80]); - movdqu(xmm1, xword[A2+LDA*1-0x80]); - movdqu(xmm2, xword[A2+LDA*2-0x80]); - movdqu(xmm3, xword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - movdqa(xmm4, xmm0); - punpckldq(xmm0, xmm1); - punpckhdq(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpckldq(xmm2, xmm3); - punpckhdq(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqa(xmm3, xmm4); - punpcklqdq(xmm4, xmm5); - punpckhqdq(xmm3, xmm5); - movdqu(xword[B-0x60], xmm0); - movdqu(xword[B+0x20], xmm1); - movdqu(xword[B+0xa0], xmm4); - movdqu(xword[B+0x120], xmm3); - movdqu(xmm0, xword[A2-0x80]); - movdqu(xmm1, xword[A2+LDA*1-0x80]); - movdqu(xmm2, xword[A2+LDA*2-0x80]); - movdqu(xmm3, xword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - movdqa(xmm4, xmm0); - punpckldq(xmm0, xmm1); - punpckhdq(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpckldq(xmm2, xmm3); - punpckhdq(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqa(xmm3, xmm4); - punpcklqdq(xmm4, xmm5); - punpckhqdq(xmm3, xmm5); - movdqu(xword[B-0x50], xmm0); - movdqu(xword[B+0x30], xmm1); - movdqu(xword[B+0xb0], xmm4); - movdqu(xword[B+0x130], xmm3); - movdqu(xmm0, xword[A2-0x80]); - movdqu(xmm1, xword[A2+LDA*1-0x80]); - movdqu(xmm2, xword[A2+LDA*2-0x80]); - movdqu(xmm3, xword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - movdqa(xmm4, xmm0); - punpckldq(xmm0, xmm1); - punpckhdq(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpckldq(xmm2, xmm3); - punpckhdq(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqa(xmm3, xmm4); - punpcklqdq(xmm4, xmm5); - punpckhqdq(xmm3, xmm5); - movdqu(xword[B-0x40], xmm0); - movdqu(xword[B+0x40], xmm1); - movdqu(xword[B+0xc0], xmm4); - movdqu(xword[B+0x140], xmm3); - movdqu(xmm0, xword[A2-0x80]); - movdqu(xmm1, xword[A2+LDA*1-0x80]); - movdqu(xmm2, xword[A2+LDA*2-0x80]); - movdqu(xmm3, xword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - movdqa(xmm4, xmm0); - punpckldq(xmm0, xmm1); - punpckhdq(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpckldq(xmm2, xmm3); - punpckhdq(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqa(xmm3, xmm4); - punpcklqdq(xmm4, xmm5); - punpckhqdq(xmm3, xmm5); - movdqu(xword[B-0x30], xmm0); - movdqu(xword[B+0x50], xmm1); - movdqu(xword[B+0xd0], xmm4); - movdqu(xword[B+0x150], xmm3); - movdqu(xmm0, xword[A2-0x80]); - movdqu(xmm1, xword[A2+LDA*1-0x80]); - movdqu(xmm2, xword[A2+LDA*2-0x80]); - movdqu(xmm3, xword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - movdqa(xmm4, xmm0); - punpckldq(xmm0, xmm1); - punpckhdq(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpckldq(xmm2, xmm3); - punpckhdq(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqa(xmm3, xmm4); - punpcklqdq(xmm4, xmm5); - punpckhqdq(xmm3, xmm5); - movdqu(xword[B-0x20], xmm0); - movdqu(xword[B+0x60], xmm1); - movdqu(xword[B+0xe0], xmm4); - movdqu(xword[B+0x160], xmm3); - movdqu(xmm0, xword[A2-0x80]); - movdqu(xmm1, xword[A2+LDA*1-0x80]); - movdqu(xmm2, xword[A2+LDA*2-0x80]); - movdqu(xmm3, xword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - movdqa(xmm4, xmm0); - 
punpckldq(xmm0, xmm1); - punpckhdq(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpckldq(xmm2, xmm3); - punpckhdq(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqa(xmm3, xmm4); - punpcklqdq(xmm4, xmm5); - punpckhqdq(xmm3, xmm5); - movdqu(xword[B-0x10], xmm0); - movdqu(xword[B+0x70], xmm1); - movdqu(xword[B+0xf0], xmm4); - movdqu(xword[B+0x170], xmm3); - sub(A1, -16); - sub(B, -512); - dec(I); - jg(l1014, T_NEAR); - align(4); - -L(l1390); - test(M, 0x8); - jle(l159c, T_NEAR); - movq(xmm0, qword[A1-0x80]); - movq(xmm1, qword[A1+LDA*1-0x80]); - movq(xmm2, qword[A1+LDA*2-0x80]); - movq(xmm3, qword[A1+LDA3*1-0x80]); - lea(A2, ptr[A1+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqu(xword[B-0x80], xmm0); - movdqu(xword[B], xmm1); - movq(xmm0, qword[A2-0x80]); - movq(xmm1, qword[A2+LDA*1-0x80]); - movq(xmm2, qword[A2+LDA*2-0x80]); - movq(xmm3, qword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqu(xword[B-0x70], xmm0); - movdqu(xword[B+0x10], xmm1); - movq(xmm0, qword[A2-0x80]); - movq(xmm1, qword[A2+LDA*1-0x80]); - movq(xmm2, qword[A2+LDA*2-0x80]); - movq(xmm3, qword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqu(xword[B-0x60], xmm0); - movdqu(xword[B+0x20], xmm1); - movq(xmm0, qword[A2-0x80]); - movq(xmm1, qword[A2+LDA*1-0x80]); - movq(xmm2, qword[A2+LDA*2-0x80]); - movq(xmm3, qword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqu(xword[B-0x50], xmm0); - movdqu(xword[B+0x30], xmm1); - movq(xmm0, qword[A2-0x80]); - movq(xmm1, qword[A2+LDA*1-0x80]); - movq(xmm2, qword[A2+LDA*2-0x80]); - movq(xmm3, qword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqu(xword[B-0x40], xmm0); - movdqu(xword[B+0x40], xmm1); - movq(xmm0, qword[A2-0x80]); - movq(xmm1, qword[A2+LDA*1-0x80]); - movq(xmm2, qword[A2+LDA*2-0x80]); - movq(xmm3, qword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqu(xword[B-0x30], xmm0); - movdqu(xword[B+0x50], xmm1); - movq(xmm0, qword[A2-0x80]); - movq(xmm1, qword[A2+LDA*1-0x80]); - movq(xmm2, qword[A2+LDA*2-0x80]); - movq(xmm3, qword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqu(xword[B-0x20], xmm0); - movdqu(xword[B+0x60], xmm1); - movq(xmm0, qword[A2-0x80]); - movq(xmm1, qword[A2+LDA*1-0x80]); - movq(xmm2, qword[A2+LDA*2-0x80]); - movq(xmm3, qword[A2+LDA3*1-0x80]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqu(xword[B-0x10], xmm0); - movdqu(xword[B+0x70], xmm1); - sub(A1, -8); - sub(B, -256); - align(4); - -L(l159c); - test(M, 0x4); - jle(l173c, T_NEAR); - movd(xmm0, dword[A1-0x80]); - movd(xmm1, dword[A1+LDA*1-0x80]); - movd(xmm2, dword[A1+LDA*2-0x80]); - movd(xmm3, dword[A1+LDA3*1-0x80]); - lea(A2, ptr[A1+LDA*4]); 
- punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - punpcklqdq(xmm0, xmm2); - movdqu(xword[B-0x80], xmm0); - movd(xmm0, dword[A2-0x80]); - movd(xmm1, dword[A2+LDA*1-0x80]); - movd(xmm2, dword[A2+LDA*2-0x80]); - movd(xmm3, dword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - punpcklqdq(xmm0, xmm2); - movdqu(xword[B-0x70], xmm0); - movd(xmm0, dword[A2-0x80]); - movd(xmm1, dword[A2+LDA*1-0x80]); - movd(xmm2, dword[A2+LDA*2-0x80]); - movd(xmm3, dword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - punpcklqdq(xmm0, xmm2); - movdqu(xword[B-0x60], xmm0); - movd(xmm0, dword[A2-0x80]); - movd(xmm1, dword[A2+LDA*1-0x80]); - movd(xmm2, dword[A2+LDA*2-0x80]); - movd(xmm3, dword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - punpcklqdq(xmm0, xmm2); - movdqu(xword[B-0x50], xmm0); - movd(xmm0, dword[A2-0x80]); - movd(xmm1, dword[A2+LDA*1-0x80]); - movd(xmm2, dword[A2+LDA*2-0x80]); - movd(xmm3, dword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - punpcklqdq(xmm0, xmm2); - movdqu(xword[B-0x40], xmm0); - movd(xmm0, dword[A2-0x80]); - movd(xmm1, dword[A2+LDA*1-0x80]); - movd(xmm2, dword[A2+LDA*2-0x80]); - movd(xmm3, dword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - punpcklqdq(xmm0, xmm2); - movdqu(xword[B-0x30], xmm0); - movd(xmm0, dword[A2-0x80]); - movd(xmm1, dword[A2+LDA*1-0x80]); - movd(xmm2, dword[A2+LDA*2-0x80]); - movd(xmm3, dword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - punpcklqdq(xmm0, xmm2); - movdqu(xword[B-0x20], xmm0); - movd(xmm0, dword[A2-0x80]); - movd(xmm1, dword[A2+LDA*1-0x80]); - movd(xmm2, dword[A2+LDA*2-0x80]); - movd(xmm3, dword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - punpcklqdq(xmm0, xmm2); - movdqu(xword[B-0x10], xmm0); - sub(A1, -4); - sub(B, -128); - align(4); - -L(l173c); - test(M, 0x2); - jle(l18e4, T_NEAR); - mov(ax, word[A1-0x80]); - pinsrw(xmm0, eax, 0x0); - mov(ax, word[A1+LDA*1-0x80]); - pinsrw(xmm0, eax, 0x1); - mov(ax, word[A1+LDA*2-0x80]); - pinsrw(xmm0, eax, 0x2); - mov(ax, word[A1+LDA3*1-0x80]); - lea(A2, ptr[A1+LDA*4]); - pinsrw(xmm0, eax, 0x3); - mov(ax, word[A2-0x80]); - pinsrw(xmm0, eax, 0x4); - mov(ax, word[A2+LDA*1-0x80]); - pinsrw(xmm0, eax, 0x5); - mov(ax, word[A2+LDA*2-0x80]); - pinsrw(xmm0, eax, 0x6); - mov(ax, word[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - pinsrw(xmm0, eax, 0x7); - movdqu(xword[B-0x80], xmm0); - mov(ax, word[A2-0x80]); - pinsrw(xmm0, eax, 0x0); - mov(ax, word[A2+LDA*1-0x80]); - pinsrw(xmm0, eax, 0x1); - mov(ax, word[A2+LDA*2-0x80]); - pinsrw(xmm0, eax, 0x2); - mov(ax, word[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - pinsrw(xmm0, eax, 0x3); - mov(ax, word[A2-0x80]); - pinsrw(xmm0, eax, 0x4); - mov(ax, word[A2+LDA*1-0x80]); - pinsrw(xmm0, eax, 0x5); - mov(ax, word[A2+LDA*2-0x80]); - pinsrw(xmm0, eax, 0x6); - mov(ax, word[A2+LDA3*1-0x80]); - pinsrw(xmm0, eax, 0x7); - lea(A2, ptr[A2+LDA*4]); - movdqu(xword[B-0x70], xmm0); - mov(ax, word[A2-0x80]); - pinsrw(xmm0, eax, 0x0); - mov(ax, word[A2+LDA*1-0x80]); - pinsrw(xmm0, eax, 0x1); - mov(ax, word[A2+LDA*2-0x80]); - pinsrw(xmm0, eax, 0x2); - mov(ax, word[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - pinsrw(xmm0, eax, 0x3); - mov(ax, word[A2-0x80]); - pinsrw(xmm0, eax, 0x4); - mov(ax, word[A2+LDA*1-0x80]); - pinsrw(xmm0, eax, 0x5); - mov(ax, 
word[A2+LDA*2-0x80]); - pinsrw(xmm0, eax, 0x6); - mov(ax, word[A2+LDA3*1-0x80]); - pinsrw(xmm0, eax, 0x7); - lea(A2, ptr[A2+LDA*4]); - movdqu(xword[B-0x60], xmm0); - mov(ax, word[A2-0x80]); - pinsrw(xmm0, eax, 0x0); - mov(ax, word[A2+LDA*1-0x80]); - pinsrw(xmm0, eax, 0x1); - mov(ax, word[A2+LDA*2-0x80]); - pinsrw(xmm0, eax, 0x2); - mov(ax, word[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - pinsrw(xmm0, eax, 0x3); - mov(ax, word[A2-0x80]); - pinsrw(xmm0, eax, 0x4); - mov(ax, word[A2+LDA*1-0x80]); - pinsrw(xmm0, eax, 0x5); - mov(ax, word[A2+LDA*2-0x80]); - pinsrw(xmm0, eax, 0x6); - mov(ax, word[A2+LDA3*1-0x80]); - pinsrw(xmm0, eax, 0x7); - lea(A2, ptr[A2+LDA*4]); - movdqu(xword[B-0x50], xmm0); - sub(A1, -2); - sub(B, -64); - align(4); - -L(l18e4); - test(M, 0x1); - jle(l1a7c, T_NEAR); - mov(al, byte[A1-0x80]); - pinsrb(xmm0, eax, 0x0); - mov(al, byte[A1+LDA*1-0x80]); - pinsrb(xmm0, eax, 0x1); - mov(al, byte[A1+LDA*2-0x80]); - pinsrb(xmm0, eax, 0x2); - mov(al, byte[A1+LDA3*1-0x80]); - lea(A2, ptr[A1+LDA*4]); - pinsrb(xmm0, eax, 0x3); - mov(al, byte[A2-0x80]); - pinsrb(xmm0, eax, 0x4); - mov(al, byte[A2+LDA*1-0x80]); - pinsrb(xmm0, eax, 0x5); - mov(al, byte[A2+LDA*2-0x80]); - pinsrb(xmm0, eax, 0x6); - mov(al, byte[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - pinsrb(xmm0, eax, 0x7); - mov(al, byte[A2-0x80]); - pinsrb(xmm0, eax, 0x8); - mov(al, byte[A2+LDA*1-0x80]); - pinsrb(xmm0, eax, 0x9); - mov(al, byte[A2+LDA*2-0x80]); - pinsrb(xmm0, eax, 0xa); - mov(al, byte[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - pinsrb(xmm0, eax, 0xb); - mov(al, byte[A2-0x80]); - pinsrb(xmm0, eax, 0xc); - mov(al, byte[A2+LDA*1-0x80]); - pinsrb(xmm0, eax, 0xd); - mov(al, byte[A2+LDA*2-0x80]); - pinsrb(xmm0, eax, 0xe); - mov(al, byte[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - pinsrb(xmm0, eax, 0xf); - movdqu(xword[B-0x80], xmm0); - mov(al, byte[A2-0x80]); - pinsrb(xmm0, eax, 0x0); - mov(al, byte[A2+LDA*1-0x80]); - pinsrb(xmm0, eax, 0x1); - mov(al, byte[A2+LDA*2-0x80]); - pinsrb(xmm0, eax, 0x2); - mov(al, byte[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - pinsrb(xmm0, eax, 0x3); - mov(al, byte[A2-0x80]); - pinsrb(xmm0, eax, 0x4); - mov(al, byte[A2+LDA*1-0x80]); - pinsrb(xmm0, eax, 0x5); - mov(al, byte[A2+LDA*2-0x80]); - pinsrb(xmm0, eax, 0x6); - mov(al, byte[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - pinsrb(xmm0, eax, 0x7); - mov(al, byte[A2-0x80]); - pinsrb(xmm0, eax, 0x8); - mov(al, byte[A2+LDA*1-0x80]); - pinsrb(xmm0, eax, 0x9); - mov(al, byte[A2+LDA*2-0x80]); - pinsrb(xmm0, eax, 0xa); - mov(al, byte[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - pinsrb(xmm0, eax, 0xb); - mov(al, byte[A2-0x80]); - pinsrb(xmm0, eax, 0xc); - mov(al, byte[A2+LDA*1-0x80]); - pinsrb(xmm0, eax, 0xd); - mov(al, byte[A2+LDA*2-0x80]); - pinsrb(xmm0, eax, 0xe); - mov(al, byte[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - pinsrb(xmm0, eax, 0xf); - movdqu(xword[B-0x70], xmm0); - sub(B, -32); - align(4); - -L(l1a7c); - sub(N, 0x20); - cmp(N, 0x20); - jge(lff8, T_NEAR); - align(4); - -L(l1a8c); - cmp(N, 0x10); - jl(l200c, T_NEAR); - align(4); - -L(l1a98); - mov(A1, A); - mov(I, LDA); - shl(I, 0x4); - add(A, I); - mov(I, M); - sar(I, 0x4); - jle(l1c64, T_NEAR); - align(4); - -L(l1ab4); - movdqu(xmm0, xword[A1-0x80]); - movdqu(xmm1, xword[A1+LDA*1-0x80]); - movdqu(xmm2, xword[A1+LDA*2-0x80]); - movdqu(xmm3, xword[A1+LDA3*1-0x80]); - lea(A2, ptr[A1+LDA*4]); - movdqa(xmm4, xmm0); - punpckldq(xmm0, xmm1); - punpckhdq(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpckldq(xmm2, xmm3); - punpckhdq(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, 
xmm2); - punpckhqdq(xmm1, xmm2); - movdqa(xmm3, xmm4); - punpcklqdq(xmm4, xmm5); - punpckhqdq(xmm3, xmm5); - movdqu(xword[B-0x80], xmm0); - movdqu(xword[B-0x40], xmm1); - movdqu(xword[B], xmm4); - movdqu(xword[B+0x40], xmm3); - movdqu(xmm0, xword[A2-0x80]); - movdqu(xmm1, xword[A2+LDA*1-0x80]); - movdqu(xmm2, xword[A2+LDA*2-0x80]); - movdqu(xmm3, xword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - movdqa(xmm4, xmm0); - punpckldq(xmm0, xmm1); - punpckhdq(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpckldq(xmm2, xmm3); - punpckhdq(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqa(xmm3, xmm4); - punpcklqdq(xmm4, xmm5); - punpckhqdq(xmm3, xmm5); - movdqu(xword[B-0x70], xmm0); - movdqu(xword[B-0x30], xmm1); - movdqu(xword[B+0x10], xmm4); - movdqu(xword[B+0x50], xmm3); - movdqu(xmm0, xword[A2-0x80]); - movdqu(xmm1, xword[A2+LDA*1-0x80]); - movdqu(xmm2, xword[A2+LDA*2-0x80]); - movdqu(xmm3, xword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - movdqa(xmm4, xmm0); - punpckldq(xmm0, xmm1); - punpckhdq(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpckldq(xmm2, xmm3); - punpckhdq(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqa(xmm3, xmm4); - punpcklqdq(xmm4, xmm5); - punpckhqdq(xmm3, xmm5); - movdqu(xword[B-0x60], xmm0); - movdqu(xword[B-0x20], xmm1); - movdqu(xword[B+0x20], xmm4); - movdqu(xword[B+0x60], xmm3); - movdqu(xmm0, xword[A2-0x80]); - movdqu(xmm1, xword[A2+LDA*1-0x80]); - movdqu(xmm2, xword[A2+LDA*2-0x80]); - movdqu(xmm3, xword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - movdqa(xmm4, xmm0); - punpckldq(xmm0, xmm1); - punpckhdq(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpckldq(xmm2, xmm3); - punpckhdq(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqa(xmm3, xmm4); - punpcklqdq(xmm4, xmm5); - punpckhqdq(xmm3, xmm5); - movdqu(xword[B-0x50], xmm0); - movdqu(xword[B-0x10], xmm1); - movdqu(xword[B+0x30], xmm4); - movdqu(xword[B+0x70], xmm3); - sub(A1, -16); - sub(B, -256); - dec(I); - jg(l1ab4, T_NEAR); - align(4); - -L(l1c64); - test(M, 0x8); - jle(l1d74, T_NEAR); - movq(xmm0, qword[A1-0x80]); - movq(xmm1, qword[A1+LDA*1-0x80]); - movq(xmm2, qword[A1+LDA*2-0x80]); - movq(xmm3, qword[A1+LDA3*1-0x80]); - lea(A2, ptr[A1+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqu(xword[B-0x80], xmm0); - movdqu(xword[B-0x40], xmm1); - movq(xmm0, qword[A2-0x80]); - movq(xmm1, qword[A2+LDA*1-0x80]); - movq(xmm2, qword[A2+LDA*2-0x80]); - movq(xmm3, qword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqu(xword[B-0x70], xmm0); - movdqu(xword[B-0x30], xmm1); - movq(xmm0, qword[A2-0x80]); - movq(xmm1, qword[A2+LDA*1-0x80]); - movq(xmm2, qword[A2+LDA*2-0x80]); - movq(xmm3, qword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqu(xword[B-0x60], xmm0); - movdqu(xword[B-0x20], xmm1); - movq(xmm0, qword[A2-0x80]); - movq(xmm1, qword[A2+LDA*1-0x80]); - movq(xmm2, qword[A2+LDA*2-0x80]); - movq(xmm3, qword[A2+LDA3*1-0x80]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqu(xword[B-0x50], xmm0); - movdqu(xword[B-0x10], xmm1); - sub(A1, -8); - sub(B, -128); - align(4); - -L(l1d74); - 
test(M, 0x4); - jle(l1e50, T_NEAR); - movd(xmm0, dword[A1-0x80]); - movd(xmm1, dword[A1+LDA*1-0x80]); - movd(xmm2, dword[A1+LDA*2-0x80]); - movd(xmm3, dword[A1+LDA3*1-0x80]); - lea(A2, ptr[A1+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - punpcklqdq(xmm0, xmm2); - movdqu(xword[B-0x80], xmm0); - movd(xmm0, dword[A2-0x80]); - movd(xmm1, dword[A2+LDA*1-0x80]); - movd(xmm2, dword[A2+LDA*2-0x80]); - movd(xmm3, dword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - punpcklqdq(xmm0, xmm2); - movdqu(xword[B-0x70], xmm0); - movd(xmm0, dword[A2-0x80]); - movd(xmm1, dword[A2+LDA*1-0x80]); - movd(xmm2, dword[A2+LDA*2-0x80]); - movd(xmm3, dword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - punpcklqdq(xmm0, xmm2); - movdqu(xword[B-0x60], xmm0); - movd(xmm0, dword[A2-0x80]); - movd(xmm1, dword[A2+LDA*1-0x80]); - movd(xmm2, dword[A2+LDA*2-0x80]); - movd(xmm3, dword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - punpcklqdq(xmm0, xmm2); - movdqu(xword[B-0x50], xmm0); - sub(A1, -4); - sub(B, -64); - align(4); - -L(l1e50); - test(M, 0x2); - jle(l1f2c, T_NEAR); - mov(ax, word[A1-0x80]); - pinsrw(xmm0, eax, 0x0); - mov(ax, word[A1+LDA*1-0x80]); - pinsrw(xmm0, eax, 0x1); - mov(ax, word[A1+LDA*2-0x80]); - pinsrw(xmm0, eax, 0x2); - mov(ax, word[A1+LDA3*1-0x80]); - lea(A2, ptr[A1+LDA*4]); - pinsrw(xmm0, eax, 0x3); - mov(ax, word[A2-0x80]); - pinsrw(xmm0, eax, 0x4); - mov(ax, word[A2+LDA*1-0x80]); - pinsrw(xmm0, eax, 0x5); - mov(ax, word[A2+LDA*2-0x80]); - pinsrw(xmm0, eax, 0x6); - mov(ax, word[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - pinsrw(xmm0, eax, 0x7); - movdqu(xword[B-0x80], xmm0); - mov(ax, word[A2-0x80]); - pinsrw(xmm0, eax, 0x0); - mov(ax, word[A2+LDA*1-0x80]); - pinsrw(xmm0, eax, 0x1); - mov(ax, word[A2+LDA*2-0x80]); - pinsrw(xmm0, eax, 0x2); - mov(ax, word[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - pinsrw(xmm0, eax, 0x3); - mov(ax, word[A2-0x80]); - pinsrw(xmm0, eax, 0x4); - mov(ax, word[A2+LDA*1-0x80]); - pinsrw(xmm0, eax, 0x5); - mov(ax, word[A2+LDA*2-0x80]); - pinsrw(xmm0, eax, 0x6); - mov(ax, word[A2+LDA3*1-0x80]); - pinsrw(xmm0, eax, 0x7); - movdqu(xword[B-0x70], xmm0); - sub(A1, -2); - sub(B, -32); - align(4); - -L(l1f2c); - test(M, 0x1); - jle(l1ffc, T_NEAR); - mov(al, byte[A1-0x80]); - pinsrb(xmm0, eax, 0x0); - mov(al, byte[A1+LDA*1-0x80]); - pinsrb(xmm0, eax, 0x1); - mov(al, byte[A1+LDA*2-0x80]); - pinsrb(xmm0, eax, 0x2); - mov(al, byte[A1+LDA3*1-0x80]); - lea(A2, ptr[A1+LDA*4]); - pinsrb(xmm0, eax, 0x3); - mov(al, byte[A2-0x80]); - pinsrb(xmm0, eax, 0x4); - mov(al, byte[A2+LDA*1-0x80]); - pinsrb(xmm0, eax, 0x5); - mov(al, byte[A2+LDA*2-0x80]); - pinsrb(xmm0, eax, 0x6); - mov(al, byte[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - pinsrb(xmm0, eax, 0x7); - mov(al, byte[A2-0x80]); - pinsrb(xmm0, eax, 0x8); - mov(al, byte[A2+LDA*1-0x80]); - pinsrb(xmm0, eax, 0x9); - mov(al, byte[A2+LDA*2-0x80]); - pinsrb(xmm0, eax, 0xa); - mov(al, byte[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - pinsrb(xmm0, eax, 0xb); - mov(al, byte[A2-0x80]); - pinsrb(xmm0, eax, 0xc); - mov(al, byte[A2+LDA*1-0x80]); - pinsrb(xmm0, eax, 0xd); - mov(al, byte[A2+LDA*2-0x80]); - pinsrb(xmm0, eax, 0xe); - mov(al, byte[A2+LDA3*1-0x80]); - pinsrb(xmm0, eax, 0xf); - movdqu(xword[B-0x80], xmm0); - sub(B, -16); - align(4); - -L(l1ffc); - sub(N, 0x10); - cmp(N, 0x10); - jge(l1a98, T_NEAR); - align(4); - -L(l200c); - cmp(N, 0x8); - jl(l2300, T_NEAR); - align(4); - -L(l2018); 
- mov(A1, A); - lea(A2, ptr[A1+LDA*4]); - lea(I, ptr[A1+LDA*8]); - mov(A, I); - mov(I, M); - sar(I, 0x4); - jle(l2110, T_NEAR); - align(4); - -L(l2034); - movdqu(xmm0, xword[A1-0x80]); - movdqu(xmm1, xword[A1+LDA*1-0x80]); - movdqu(xmm2, xword[A1+LDA*2-0x80]); - movdqu(xmm3, xword[A1+LDA3*1-0x80]); - sub(A1, -16); - movdqa(xmm4, xmm0); - punpckldq(xmm0, xmm1); - punpckhdq(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpckldq(xmm2, xmm3); - punpckhdq(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqa(xmm3, xmm4); - punpcklqdq(xmm4, xmm5); - punpckhqdq(xmm3, xmm5); - movdqu(xword[B-0x80], xmm0); - movdqu(xword[B-0x60], xmm1); - movdqu(xword[B-0x40], xmm4); - movdqu(xword[B-0x20], xmm3); - movdqu(xmm0, xword[A2-0x80]); - movdqu(xmm1, xword[A2+LDA*1-0x80]); - movdqu(xmm2, xword[A2+LDA*2-0x80]); - movdqu(xmm3, xword[A2+LDA3*1-0x80]); - sub(A2, -16); - movdqa(xmm4, xmm0); - punpckldq(xmm0, xmm1); - punpckhdq(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpckldq(xmm2, xmm3); - punpckhdq(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqa(xmm3, xmm4); - punpcklqdq(xmm4, xmm5); - punpckhqdq(xmm3, xmm5); - movdqu(xword[B-0x70], xmm0); - movdqu(xword[B-0x50], xmm1); - movdqu(xword[B-0x30], xmm4); - movdqu(xword[B-0x10], xmm3); - sub(B, -128); - dec(I); - jg(l2034, T_NEAR); - align(4); - -L(l2110); - test(M, 0x8); - jle(l21a0, T_NEAR); - movq(xmm0, qword[A1-0x80]); - movq(xmm1, qword[A1+LDA*1-0x80]); - movq(xmm2, qword[A1+LDA*2-0x80]); - movq(xmm3, qword[A1+LDA3*1-0x80]); - sub(A1, -8); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqu(xword[B-0x80], xmm0); - movdqu(xword[B-0x60], xmm1); - movq(xmm0, qword[A2-0x80]); - movq(xmm1, qword[A2+LDA*1-0x80]); - movq(xmm2, qword[A2+LDA*2-0x80]); - movq(xmm3, qword[A2+LDA3*1-0x80]); - sub(A2, -8); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqu(xword[B-0x70], xmm0); - movdqu(xword[B-0x50], xmm1); - sub(B, -64); - align(4); - -L(l21a0); - test(M, 0x4); - jle(l2210, T_NEAR); - movd(xmm0, dword[A1-0x80]); - movd(xmm1, dword[A1+LDA*1-0x80]); - movd(xmm2, dword[A1+LDA*2-0x80]); - movd(xmm3, dword[A1+LDA3*1-0x80]); - sub(A1, -4); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - punpcklqdq(xmm0, xmm2); - movdqu(xword[B-0x80], xmm0); - movd(xmm0, dword[A2-0x80]); - movd(xmm1, dword[A2+LDA*1-0x80]); - movd(xmm2, dword[A2+LDA*2-0x80]); - movd(xmm3, dword[A2+LDA3*1-0x80]); - sub(A2, -4); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - punpcklqdq(xmm0, xmm2); - movdqu(xword[B-0x70], xmm0); - sub(B, -32); - align(4); - -L(l2210); - test(M, 0x2); - jle(l2284, T_NEAR); - mov(ax, word[A1-0x80]); - pinsrw(xmm0, eax, 0x0); - mov(ax, word[A1+LDA*1-0x80]); - pinsrw(xmm0, eax, 0x1); - mov(ax, word[A1+LDA*2-0x80]); - pinsrw(xmm0, eax, 0x2); - mov(ax, word[A1+LDA3*1-0x80]); - sub(A1, -2); - pinsrw(xmm0, eax, 0x3); - mov(ax, word[A2-0x80]); - pinsrw(xmm0, eax, 0x4); - mov(ax, word[A2+LDA*1-0x80]); - pinsrw(xmm0, eax, 0x5); - mov(ax, word[A2+LDA*2-0x80]); - pinsrw(xmm0, eax, 0x6); - mov(ax, word[A2+LDA3*1-0x80]); - sub(A2, -2); - pinsrw(xmm0, eax, 0x7); - movdqu(xword[B-0x80], xmm0); - sub(B, -16); - align(4); - -L(l2284); - test(M, 0x1); - jle(l22f0, T_NEAR); - mov(al, byte[A1-0x80]); - pinsrb(xmm0, eax, 0x0); - mov(al, byte[A1+LDA*1-0x80]); - pinsrb(xmm0, eax, 0x1); - mov(al, byte[A1+LDA*2-0x80]); - pinsrb(xmm0, eax, 0x2); - mov(al, 
byte[A1+LDA3*1-0x80]); - pinsrb(xmm0, eax, 0x3); - mov(al, byte[A2-0x80]); - pinsrb(xmm0, eax, 0x4); - mov(al, byte[A2+LDA*1-0x80]); - pinsrb(xmm0, eax, 0x5); - mov(al, byte[A2+LDA*2-0x80]); - pinsrb(xmm0, eax, 0x6); - mov(al, byte[A2+LDA3*1-0x80]); - pinsrb(xmm0, eax, 0x7); - movq(qword[B-0x80], xmm0); - sub(B, -8); - align(4); - -L(l22f0); - sub(N, 0x8); - cmp(N, 0x8); - jge(l2018, T_NEAR); - align(4); - -L(l2300); - cmp(N, 0x4); - jl(l24c4, T_NEAR); - align(4); - -L(l230c); - mov(A1, A); - lea(A2, ptr[A1+LDA*2]); - lea(I, ptr[A1+LDA*4]); - mov(A, I); - mov(I, M); - sar(I, 0x4); - jle(l2398, T_NEAR); - align(4); - -L(l2324); - movdqu(xmm0, xword[A1-0x80]); - movdqu(xmm1, xword[A1+LDA*1-0x80]); - sub(A1, -16); - movdqu(xmm2, xword[A2-0x80]); - movdqu(xmm3, xword[A2+LDA*1-0x80]); - sub(A2, -16); - movdqa(xmm4, xmm0); - punpckldq(xmm0, xmm1); - punpckhdq(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpckldq(xmm2, xmm3); - punpckhdq(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqa(xmm3, xmm4); - punpcklqdq(xmm4, xmm5); - punpckhqdq(xmm3, xmm5); - movdqu(xword[B-0x80], xmm0); - movdqu(xword[B-0x70], xmm1); - movdqu(xword[B-0x60], xmm4); - movdqu(xword[B-0x50], xmm3); - sub(B, -64); - dec(I); - jg(l2324, T_NEAR); - align(4); - -L(l2398); - test(M, 0x8); - jle(l23e8, T_NEAR); - movq(xmm0, qword[A1-0x80]); - movq(xmm1, qword[A1+LDA*1-0x80]); - sub(A1, -8); - movq(xmm2, qword[A2-0x80]); - movq(xmm3, qword[A2+LDA*1-0x80]); - sub(A2, -8); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqu(xword[B-0x80], xmm0); - movdqu(xword[B-0x70], xmm1); - sub(B, -32); - align(4); - -L(l23e8); - test(M, 0x4); - jle(l242c, T_NEAR); - movd(xmm0, dword[A1-0x80]); - movd(xmm1, dword[A1+LDA*1-0x80]); - sub(A1, -4); - movd(xmm2, dword[A2-0x80]); - movd(xmm3, dword[A2+LDA*1-0x80]); - sub(A2, -4); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - punpcklqdq(xmm0, xmm2); - movdqu(xword[B-0x80], xmm0); - sub(B, -16); - align(4); - -L(l242c); - test(M, 0x2); - jle(l2474, T_NEAR); - mov(ax, word[A1-0x80]); - pinsrw(xmm0, eax, 0x0); - mov(ax, word[A1+LDA*1-0x80]); - sub(A1, -2); - pinsrw(xmm0, eax, 0x1); - mov(ax, word[A2-0x80]); - pinsrw(xmm0, eax, 0x2); - mov(ax, word[A2+LDA*1-0x80]); - sub(A2, -2); - pinsrw(xmm0, eax, 0x3); - movq(qword[B-0x80], xmm0); - sub(B, -8); - align(4); - -L(l2474); - test(M, 0x1); - jle(l24b4, T_NEAR); - mov(al, byte[A1-0x80]); - pinsrb(xmm0, eax, 0x0); - mov(al, byte[A1+LDA*1-0x80]); - pinsrb(xmm0, eax, 0x1); - mov(al, byte[A2-0x80]); - pinsrb(xmm0, eax, 0x2); - mov(al, byte[A2+LDA*1-0x80]); - pinsrb(xmm0, eax, 0x3); - movd(dword[B-0x80], xmm0); - sub(B, -4); - align(4); - -L(l24b4); - sub(N, 0x4); - cmp(N, 0x4); - jge(l230c, T_NEAR); - align(4); - -L(l24c4); - cmp(N, 0x2); - jl(l25d6, T_NEAR); - align(4); - -L(l24d0); - mov(A1, A); - lea(A2, ptr[A1+LDA*1]); - lea(I, ptr[A1+LDA*2]); - mov(A, I); - mov(I, M); - sar(I, 0x4); - jle(l2520, T_NEAR); - align(4); - -L(l24e8); - movdqu(xmm0, xword[A1-0x80]); - sub(A1, -16); - movdqu(xmm1, xword[A2-0x80]); - sub(A2, -16); - movdqa(xmm2, xmm0); - punpckldq(xmm0, xmm1); - punpckhdq(xmm2, xmm1); - movdqu(xword[B-0x80], xmm0); - movdqu(xword[B-0x70], xmm2); - sub(B, -32); - dec(I); - jg(l24e8, T_NEAR); - align(4); - -L(l2520); - test(M, 0x8); - jle(l254c, T_NEAR); - movq(xmm0, qword[A1-0x80]); - sub(A1, -8); - movq(xmm1, qword[A2-0x80]); - sub(A2, -8); - punpckldq(xmm0, xmm1); - movdqu(xword[B-0x80], xmm0); - sub(B, -16); - align(4); - 
-L(l254c); - test(M, 0x4); - jle(l2578, T_NEAR); - movd(xmm0, dword[A1-0x80]); - sub(A1, -4); - movd(xmm1, dword[A2-0x80]); - sub(A2, -4); - punpckldq(xmm0, xmm1); - movq(qword[B-0x80], xmm0); - sub(B, -8); - align(4); - -L(l2578); - test(M, 0x2); - jle(l25a8, T_NEAR); - mov(ax, word[A1-0x80]); - sub(A1, -2); - pinsrw(xmm0, eax, 0x0); - mov(ax, word[A2-0x80]); - sub(A2, -2); - pinsrw(xmm0, eax, 0x1); - movd(dword[B-0x80], xmm0); - sub(B, -4); - align(4); - -L(l25a8); - test(M, 0x1); - jle(l25c8, T_NEAR); - mov(al, byte[A1-0x80]); - mov(byte[B-0x80], al); - mov(al, byte[A2-0x80]); - mov(byte[B-0x7f], al); - sub(B, -2); - align(4); - -L(l25c8); - sub(N, 0x2); - cmp(N, 0x2); - jge(l24d0, T_NEAR); - align(4); - -L(l25d6); - cmp(N, 0x1); - jl(l2690, T_NEAR); - align(4); - -L(l25e0); - mov(A1, A); - add(A, LDA); - mov(I, M); - sar(I, 0x4); - jle(l260c, T_NEAR); - align(4); - -L(l25f0); - movdqu(xmm0, xword[A1-0x80]); - sub(A1, -16); - movdqu(xword[B-0x80], xmm0); - sub(B, -16); - dec(I); - jg(l25f0, T_NEAR); - align(4); - -L(l260c); - test(M, 0x8); - jle(l262c, T_NEAR); - movq(xmm0, qword[A1-0x80]); - sub(A1, -8); - movq(qword[B-0x80], xmm0); - sub(B, -8); - align(4); - -L(l262c); - test(M, 0x4); - jle(l264c, T_NEAR); - movd(xmm0, dword[A1-0x80]); - sub(A1, -4); - movd(dword[B-0x80], xmm0); - sub(B, -4); - align(4); - -L(l264c); - test(M, 0x2); - jle(l2668, T_NEAR); - mov(ax, word[A1-0x80]); - mov(word[B-0x80], ax); - sub(A1, -2); - sub(B, -2); - align(4); - -L(l2668); - test(M, 0x1); - jle(l2680, T_NEAR); - mov(al, byte[A1-0x80]); - mov(byte[B-0x80], al); - sub(B, -1); - align(4); - -L(l2680); - sub(N, 0x1); - cmp(N, 0x1); - jge(l25e0, T_NEAR); - align(4); - -L(l2690); - - postamble(); -} -outLocalLabel(); - -#undef M -#undef N -#undef A -#undef LDA -#undef ALPHA -#undef B -#undef I -#undef A1 -#undef A2 -#undef LDA3 -#ifdef _WIN32 -#undef ARG_ALPHA -#undef ARG_B -#endif -} - -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bn_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bn_kern.cpp deleted file mode 100644 index 56c36ee14..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bn_kern.cpp +++ /dev/null @@ -1,564 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#include "jit_generator.hpp" -#include "common.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -jit_avx512_core_u8_copy_bn_kern::jit_avx512_core_u8_copy_bn_kern(): jit_generator(nullptr, GEMM_CODE_SIZE) -{ - -#ifndef _WIN32 -#define M rdi -#define N rsi -#define A rdx -#define LDA rcx -#define ALPHA r8 -#define B r9 - -#define I rax -#define A1 r10 -#define A2 r8 -#define LDA3 r11 - -#else - -#define M rcx -#define N rdx -#define A r8 -#define LDA r9 -#define ALPHA rax -#define B rdi - -#define I rax -#define A1 rsi -#define A2 r10 -#define LDA3 r11 - -#define ARG_ALPHA 40+stacksize+rsp -#define ARG_B 48+stacksize+rsp - -#endif - -inLocalLabel(); -{ - -Xbyak::Label l118; -Xbyak::Label l1a8; -Xbyak::Label l20; -Xbyak::Label l218; -Xbyak::Label l28c; -Xbyak::Label l2f8; -Xbyak::Label l308; -Xbyak::Label l314; -Xbyak::Label l32c; -Xbyak::Label l3a0; -Xbyak::Label l3c; -Xbyak::Label l3f0; -Xbyak::Label l434; -Xbyak::Label l47c; -Xbyak::Label l4bc; -Xbyak::Label l4cc; -Xbyak::Label l4d8; -Xbyak::Label l4f0; -Xbyak::Label l528; -Xbyak::Label l554; -Xbyak::Label l580; -Xbyak::Label l5b0; -Xbyak::Label l5d0; -Xbyak::Label l5de; -Xbyak::Label l5e8; -Xbyak::Label l5f8; -Xbyak::Label l614; -Xbyak::Label l634; -Xbyak::Label l654; -Xbyak::Label l670; -Xbyak::Label l688; -Xbyak::Label l698; - - preamble(); -#ifdef _WIN32 - auto stacksize = get_size_of_abi_save_regs(); - mov(ALPHA, ptr[ARG_ALPHA]); - mov(B, ptr[ARG_B]); -#endif - - mov(N, qword[N]); - mov(M, qword[M]); - mov(LDA, qword[LDA]); - sub(A, -128); - sub(B, -128); - lea(LDA3, ptr[LDA+LDA*2]); - cmp(N, 0x8); - jl(l308, T_NEAR); - align(4); - -L(l20); - mov(A1, A); - lea(A2, ptr[A1+LDA*4]); - lea(I, ptr[A1+LDA*8]); - mov(A, I); - mov(I, M); - sar(I, 0x4); - jle(l118, T_NEAR); - align(4); - -L(l3c); - movdqu(xmm0, xword[A1-0x80]); - movdqu(xmm1, xword[A1+LDA*1-0x80]); - movdqu(xmm2, xword[A1+LDA*2-0x80]); - movdqu(xmm3, xword[A1+LDA3*1-0x80]); - sub(A1, -16); - movdqa(xmm4, xmm0); - punpckldq(xmm0, xmm1); - punpckhdq(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpckldq(xmm2, xmm3); - punpckhdq(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqa(xmm3, xmm4); - punpcklqdq(xmm4, xmm5); - punpckhqdq(xmm3, xmm5); - movdqu(xword[B-0x80], xmm0); - movdqu(xword[B-0x60], xmm1); - movdqu(xword[B-0x40], xmm4); - movdqu(xword[B-0x20], xmm3); - movdqu(xmm0, xword[A2-0x80]); - movdqu(xmm1, xword[A2+LDA*1-0x80]); - movdqu(xmm2, xword[A2+LDA*2-0x80]); - movdqu(xmm3, xword[A2+LDA3*1-0x80]); - sub(A2, -16); - movdqa(xmm4, xmm0); - punpckldq(xmm0, xmm1); - punpckhdq(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpckldq(xmm2, xmm3); - punpckhdq(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqa(xmm3, xmm4); - punpcklqdq(xmm4, xmm5); - punpckhqdq(xmm3, xmm5); - movdqu(xword[B-0x70], xmm0); - movdqu(xword[B-0x50], xmm1); - movdqu(xword[B-0x30], xmm4); - movdqu(xword[B-0x10], xmm3); - sub(B, -128); - dec(I); - jg(l3c, T_NEAR); - align(4); - -L(l118); - test(M, 0x8); - jle(l1a8, T_NEAR); - movq(xmm0, qword[A1-0x80]); - movq(xmm1, qword[A1+LDA*1-0x80]); - movq(xmm2, qword[A1+LDA*2-0x80]); - movq(xmm3, qword[A1+LDA3*1-0x80]); - sub(A1, -8); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqu(xword[B-0x80], xmm0); - movdqu(xword[B-0x60], xmm1); - movq(xmm0, qword[A2-0x80]); - movq(xmm1, qword[A2+LDA*1-0x80]); - 
movq(xmm2, qword[A2+LDA*2-0x80]); - movq(xmm3, qword[A2+LDA3*1-0x80]); - sub(A2, -8); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqu(xword[B-0x70], xmm0); - movdqu(xword[B-0x50], xmm1); - sub(B, -64); - align(4); - -L(l1a8); - test(M, 0x4); - jle(l218, T_NEAR); - movd(xmm0, dword[A1-0x80]); - movd(xmm1, dword[A1+LDA*1-0x80]); - movd(xmm2, dword[A1+LDA*2-0x80]); - movd(xmm3, dword[A1+LDA3*1-0x80]); - sub(A1, -4); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - punpcklqdq(xmm0, xmm2); - movdqu(xword[B-0x80], xmm0); - movd(xmm0, dword[A2-0x80]); - movd(xmm1, dword[A2+LDA*1-0x80]); - movd(xmm2, dword[A2+LDA*2-0x80]); - movd(xmm3, dword[A2+LDA3*1-0x80]); - sub(A2, -4); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - punpcklqdq(xmm0, xmm2); - movdqu(xword[B-0x70], xmm0); - sub(B, -32); - align(4); - -L(l218); - test(M, 0x2); - jle(l28c, T_NEAR); - mov(ax, word[A1-0x80]); - pinsrw(xmm0, eax, 0x0); - mov(ax, word[A1+LDA*1-0x80]); - pinsrw(xmm0, eax, 0x1); - mov(ax, word[A1+LDA*2-0x80]); - pinsrw(xmm0, eax, 0x2); - mov(ax, word[A1+LDA3*1-0x80]); - sub(A1, -2); - pinsrw(xmm0, eax, 0x3); - mov(ax, word[A2-0x80]); - pinsrw(xmm0, eax, 0x4); - mov(ax, word[A2+LDA*1-0x80]); - pinsrw(xmm0, eax, 0x5); - mov(ax, word[A2+LDA*2-0x80]); - pinsrw(xmm0, eax, 0x6); - mov(ax, word[A2+LDA3*1-0x80]); - sub(A2, -2); - pinsrw(xmm0, eax, 0x7); - movdqu(xword[B-0x80], xmm0); - sub(B, -16); - align(4); - -L(l28c); - test(M, 0x1); - jle(l2f8, T_NEAR); - mov(al, byte[A1-0x80]); - pinsrb(xmm0, eax, 0x0); - mov(al, byte[A1+LDA*1-0x80]); - pinsrb(xmm0, eax, 0x1); - mov(al, byte[A1+LDA*2-0x80]); - pinsrb(xmm0, eax, 0x2); - mov(al, byte[A1+LDA3*1-0x80]); - pinsrb(xmm0, eax, 0x3); - mov(al, byte[A2-0x80]); - pinsrb(xmm0, eax, 0x4); - mov(al, byte[A2+LDA*1-0x80]); - pinsrb(xmm0, eax, 0x5); - mov(al, byte[A2+LDA*2-0x80]); - pinsrb(xmm0, eax, 0x6); - mov(al, byte[A2+LDA3*1-0x80]); - pinsrb(xmm0, eax, 0x7); - movq(qword[B-0x80], xmm0); - sub(B, -8); - align(4); - -L(l2f8); - sub(N, 0x8); - cmp(N, 0x8); - jge(l20, T_NEAR); - align(4); - -L(l308); - cmp(N, 0x4); - jl(l4cc, T_NEAR); - align(4); - -L(l314); - mov(A1, A); - lea(A2, ptr[A1+LDA*2]); - lea(I, ptr[A1+LDA*4]); - mov(A, I); - mov(I, M); - sar(I, 0x4); - jle(l3a0, T_NEAR); - align(4); - -L(l32c); - movdqu(xmm0, xword[A1-0x80]); - movdqu(xmm1, xword[A1+LDA*1-0x80]); - sub(A1, -16); - movdqu(xmm2, xword[A2-0x80]); - movdqu(xmm3, xword[A2+LDA*1-0x80]); - sub(A2, -16); - movdqa(xmm4, xmm0); - punpckldq(xmm0, xmm1); - punpckhdq(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpckldq(xmm2, xmm3); - punpckhdq(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqa(xmm3, xmm4); - punpcklqdq(xmm4, xmm5); - punpckhqdq(xmm3, xmm5); - movdqu(xword[B-0x80], xmm0); - movdqu(xword[B-0x70], xmm1); - movdqu(xword[B-0x60], xmm4); - movdqu(xword[B-0x50], xmm3); - sub(B, -64); - dec(I); - jg(l32c, T_NEAR); - align(4); - -L(l3a0); - test(M, 0x8); - jle(l3f0, T_NEAR); - movq(xmm0, qword[A1-0x80]); - movq(xmm1, qword[A1+LDA*1-0x80]); - sub(A1, -8); - movq(xmm2, qword[A2-0x80]); - movq(xmm3, qword[A2+LDA*1-0x80]); - sub(A2, -8); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqu(xword[B-0x80], xmm0); - movdqu(xword[B-0x70], xmm1); - sub(B, -32); - align(4); - -L(l3f0); - test(M, 0x4); - jle(l434, T_NEAR); - movd(xmm0, dword[A1-0x80]); - movd(xmm1, dword[A1+LDA*1-0x80]); - sub(A1, -4); - movd(xmm2, 
dword[A2-0x80]); - movd(xmm3, dword[A2+LDA*1-0x80]); - sub(A2, -4); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - punpcklqdq(xmm0, xmm2); - movdqu(xword[B-0x80], xmm0); - sub(B, -16); - align(4); - -L(l434); - test(M, 0x2); - jle(l47c, T_NEAR); - mov(ax, word[A1-0x80]); - pinsrw(xmm0, eax, 0x0); - mov(ax, word[A1+LDA*1-0x80]); - sub(A1, -2); - pinsrw(xmm0, eax, 0x1); - mov(ax, word[A2-0x80]); - pinsrw(xmm0, eax, 0x2); - mov(ax, word[A2+LDA*1-0x80]); - sub(A2, -2); - pinsrw(xmm0, eax, 0x3); - movq(qword[B-0x80], xmm0); - sub(B, -8); - align(4); - -L(l47c); - test(M, 0x1); - jle(l4bc, T_NEAR); - mov(al, byte[A1-0x80]); - pinsrb(xmm0, eax, 0x0); - mov(al, byte[A1+LDA*1-0x80]); - pinsrb(xmm0, eax, 0x1); - mov(al, byte[A2-0x80]); - pinsrb(xmm0, eax, 0x2); - mov(al, byte[A2+LDA*1-0x80]); - pinsrb(xmm0, eax, 0x3); - movd(dword[B-0x80], xmm0); - sub(B, -4); - align(4); - -L(l4bc); - sub(N, 0x4); - cmp(N, 0x4); - jge(l314, T_NEAR); - align(4); - -L(l4cc); - cmp(N, 0x2); - jl(l5de, T_NEAR); - align(4); - -L(l4d8); - mov(A1, A); - lea(A2, ptr[A1+LDA*1]); - lea(I, ptr[A1+LDA*2]); - mov(A, I); - mov(I, M); - sar(I, 0x4); - jle(l528, T_NEAR); - align(4); - -L(l4f0); - movdqu(xmm0, xword[A1-0x80]); - sub(A1, -16); - movdqu(xmm1, xword[A2-0x80]); - sub(A2, -16); - movdqa(xmm2, xmm0); - punpckldq(xmm0, xmm1); - punpckhdq(xmm2, xmm1); - movdqu(xword[B-0x80], xmm0); - movdqu(xword[B-0x70], xmm2); - sub(B, -32); - dec(I); - jg(l4f0, T_NEAR); - align(4); - -L(l528); - test(M, 0x8); - jle(l554, T_NEAR); - movq(xmm0, qword[A1-0x80]); - sub(A1, -8); - movq(xmm1, qword[A2-0x80]); - sub(A2, -8); - punpckldq(xmm0, xmm1); - movdqu(xword[B-0x80], xmm0); - sub(B, -16); - align(4); - -L(l554); - test(M, 0x4); - jle(l580, T_NEAR); - movd(xmm0, dword[A1-0x80]); - sub(A1, -4); - movd(xmm1, dword[A2-0x80]); - sub(A2, -4); - punpckldq(xmm0, xmm1); - movq(qword[B-0x80], xmm0); - sub(B, -8); - align(4); - -L(l580); - test(M, 0x2); - jle(l5b0, T_NEAR); - mov(ax, word[A1-0x80]); - sub(A1, -2); - pinsrw(xmm0, eax, 0x0); - mov(ax, word[A2-0x80]); - sub(A2, -2); - pinsrw(xmm0, eax, 0x1); - movd(dword[B-0x80], xmm0); - sub(B, -4); - align(4); - -L(l5b0); - test(M, 0x1); - jle(l5d0, T_NEAR); - mov(al, byte[A1-0x80]); - mov(byte[B-0x80], al); - mov(al, byte[A2-0x80]); - mov(byte[B-0x7f], al); - sub(B, -2); - align(4); - -L(l5d0); - sub(N, 0x2); - cmp(N, 0x2); - jge(l4d8, T_NEAR); - align(4); - -L(l5de); - cmp(N, 0x1); - jl(l698, T_NEAR); - align(4); - -L(l5e8); - mov(A1, A); - add(A, LDA); - mov(I, M); - sar(I, 0x4); - jle(l614, T_NEAR); - align(4); - -L(l5f8); - movdqu(xmm0, xword[A1-0x80]); - sub(A1, -16); - movdqu(xword[B-0x80], xmm0); - sub(B, -16); - dec(I); - jg(l5f8, T_NEAR); - align(4); - -L(l614); - test(M, 0x8); - jle(l634, T_NEAR); - movq(xmm0, qword[A1-0x80]); - sub(A1, -8); - movq(qword[B-0x80], xmm0); - sub(B, -8); - align(4); - -L(l634); - test(M, 0x4); - jle(l654, T_NEAR); - movd(xmm0, dword[A1-0x80]); - sub(A1, -4); - movd(dword[B-0x80], xmm0); - sub(B, -4); - align(4); - -L(l654); - test(M, 0x2); - jle(l670, T_NEAR); - mov(ax, word[A1-0x80]); - mov(word[B-0x80], ax); - sub(A1, -2); - sub(B, -2); - align(4); - -L(l670); - test(M, 0x1); - jle(l688, T_NEAR); - mov(al, byte[A1-0x80]); - mov(byte[B-0x80], al); - sub(B, -1); - align(4); - -L(l688); - sub(N, 0x1); - cmp(N, 0x1); - jge(l5e8, T_NEAR); - align(4); - -L(l698); - - postamble(); -} -outLocalLabel(); - -#undef M -#undef N -#undef A -#undef LDA -#undef ALPHA -#undef B -#undef I -#undef A1 -#undef A2 -#undef LDA3 -#ifdef _WIN32 -#undef ARG_ALPHA -#undef ARG_B 
-#endif -} - -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bt_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bt_kern.cpp deleted file mode 100644 index 53e99d94d..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bt_kern.cpp +++ /dev/null @@ -1,501 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include "jit_generator.hpp" -#include "common.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -jit_avx512_core_u8_copy_bt_kern::jit_avx512_core_u8_copy_bt_kern(): jit_generator(nullptr, GEMM_CODE_SIZE) -{ - -#ifndef _WIN32 -#define M rdi -#define N rsi -#define A rdx -#define LDA rcx -#define ALPHA r8 -#define B r9 - -#define I rax -#define A1 r10 -#define A2 r8 -#define LDA3 r11 - -#else - -#define M rcx -#define N rdx -#define A r8 -#define LDA r9 -#define ALPHA rax -#define B rdi - -#define I rax -#define A1 rsi -#define A2 r10 -#define LDA3 r11 - -#define ARG_ALPHA 40+stacksize+rsp -#define ARG_B 48+stacksize+rsp - -#endif - -inLocalLabel(); -{ - -Xbyak::Label l120; -Xbyak::Label l14c; -Xbyak::Label l168; -Xbyak::Label l178; -Xbyak::Label l184; -Xbyak::Label l194; -Xbyak::Label l20; -Xbyak::Label l20c; -Xbyak::Label l250; -Xbyak::Label l27c; -Xbyak::Label l298; -Xbyak::Label l2a8; -Xbyak::Label l2b4; -Xbyak::Label l2c8; -Xbyak::Label l34; -Xbyak::Label l360; -Xbyak::Label l3b4; -Xbyak::Label l3e8; -Xbyak::Label l400; -Xbyak::Label l40e; -Xbyak::Label l418; -Xbyak::Label l428; -Xbyak::Label l4a0; -Xbyak::Label l4e8; -Xbyak::Label l50c; -Xbyak::Label l524; -Xbyak::Label l534; -Xbyak::Label lcc; - - preamble(); -#ifdef _WIN32 - auto stacksize = get_size_of_abi_save_regs(); - mov(ALPHA, ptr[ARG_ALPHA]); - mov(B, ptr[ARG_B]); -#endif - - mov(M, qword[M]); - mov(N, qword[N]); - mov(LDA, qword[LDA]); - lea(LDA3, ptr[LDA+LDA*2]); - sub(A, -128); - sub(B, -128); - cmp(N, 0x8); - jl(l178, T_NEAR); - align(4); - -L(l20); - mov(A1, A); - add(A, 0x8); - mov(I, M); - sar(I, 0x3); - jle(lcc, T_NEAR); - align(4); - -L(l34); - movq(xmm0, qword[A1-0x80]); - add(A1, LDA); - movq(xmm1, qword[A1-0x80]); - add(A1, LDA); - movq(xmm2, qword[A1-0x80]); - add(A1, LDA); - movq(xmm3, qword[A1-0x80]); - add(A1, LDA); - punpcklbw(xmm0, xmm1); - punpcklbw(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklwd(xmm0, xmm2); - punpckhwd(xmm1, xmm2); - movdqu(xword[B-0x80], xmm0); - movdqu(xword[B-0x70], xmm1); - movq(xmm0, qword[A1-0x80]); - add(A1, LDA); - movq(xmm1, qword[A1-0x80]); - add(A1, LDA); - movq(xmm2, qword[A1-0x80]); - add(A1, LDA); - movq(xmm3, qword[A1-0x80]); - add(A1, LDA); - punpcklbw(xmm0, xmm1); - punpcklbw(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklwd(xmm0, xmm2); - punpckhwd(xmm1, xmm2); - movdqu(xword[B-0x60], xmm0); - movdqu(xword[B-0x50], xmm1); - sub(B, -64); - dec(I); - jg(l34, 
T_NEAR); - align(4); - -L(lcc); - test(M, 0x4); - jle(l120, T_NEAR); - movq(xmm0, qword[A1-0x80]); - add(A1, LDA); - movq(xmm1, qword[A1-0x80]); - add(A1, LDA); - movq(xmm2, qword[A1-0x80]); - add(A1, LDA); - movq(xmm3, qword[A1-0x80]); - add(A1, LDA); - punpcklbw(xmm0, xmm1); - punpcklbw(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklwd(xmm0, xmm2); - punpckhwd(xmm1, xmm2); - movdqu(xword[B-0x80], xmm0); - movdqu(xword[B-0x70], xmm1); - sub(B, -32); - align(4); - -L(l120); - test(M, 0x2); - jle(l14c, T_NEAR); - movq(xmm0, qword[A1-0x80]); - add(A1, LDA); - movq(xmm1, qword[A1-0x80]); - add(A1, LDA); - punpcklbw(xmm0, xmm1); - movdqu(xword[B-0x80], xmm0); - sub(B, -16); - align(4); - -L(l14c); - test(M, 0x1); - jle(l168, T_NEAR); - movq(xmm0, qword[A1-0x80]); - add(A1, LDA); - movq(qword[B-0x80], xmm0); - sub(B, -8); - align(4); - -L(l168); - sub(N, 0x8); - cmp(N, 0x8); - jge(l20, T_NEAR); - align(4); - -L(l178); - cmp(N, 0x4); - jl(l2a8, T_NEAR); - align(4); - -L(l184); - mov(A1, A); - add(A, 0x4); - mov(I, M); - sar(I, 0x3); - jle(l20c, T_NEAR); - align(4); - -L(l194); - movd(xmm0, dword[A1-0x80]); - add(A1, LDA); - movd(xmm1, dword[A1-0x80]); - add(A1, LDA); - movd(xmm2, dword[A1-0x80]); - add(A1, LDA); - movd(xmm3, dword[A1-0x80]); - add(A1, LDA); - punpcklbw(xmm0, xmm1); - punpcklbw(xmm2, xmm3); - punpcklwd(xmm0, xmm2); - movdqu(xword[B-0x80], xmm0); - movd(xmm0, dword[A1-0x80]); - add(A1, LDA); - movd(xmm1, dword[A1-0x80]); - add(A1, LDA); - movd(xmm2, dword[A1-0x80]); - add(A1, LDA); - movd(xmm3, dword[A1-0x80]); - add(A1, LDA); - punpcklbw(xmm0, xmm1); - punpcklbw(xmm2, xmm3); - punpcklwd(xmm0, xmm2); - movdqu(xword[B-0x70], xmm0); - sub(B, -32); - dec(I); - jg(l194, T_NEAR); - align(4); - -L(l20c); - test(M, 0x4); - jle(l250, T_NEAR); - movd(xmm0, dword[A1-0x80]); - add(A1, LDA); - movd(xmm1, dword[A1-0x80]); - add(A1, LDA); - movd(xmm2, dword[A1-0x80]); - add(A1, LDA); - movd(xmm3, dword[A1-0x80]); - add(A1, LDA); - punpcklbw(xmm0, xmm1); - punpcklbw(xmm2, xmm3); - punpcklwd(xmm0, xmm2); - movdqu(xword[B-0x80], xmm0); - sub(B, -16); - align(4); - -L(l250); - test(M, 0x2); - jle(l27c, T_NEAR); - movd(xmm0, dword[A1-0x80]); - add(A1, LDA); - movd(xmm1, dword[A1-0x80]); - add(A1, LDA); - punpcklbw(xmm0, xmm1); - movq(qword[B-0x80], xmm0); - sub(B, -8); - align(4); - -L(l27c); - test(M, 0x1); - jle(l298, T_NEAR); - movd(xmm0, dword[A1-0x80]); - movd(dword[B-0x80], xmm0); - sub(B, -4); - align(4); - -L(l298); - sub(N, 0x4); - cmp(N, 0x4); - jge(l184, T_NEAR); - align(4); - -L(l2a8); - cmp(N, 0x2); - jl(l40e, T_NEAR); - align(4); - -L(l2b4); - mov(A1, A); - add(A, 0x2); - mov(LDA3, M); - sar(LDA3, 0x3); - jle(l360, T_NEAR); - align(4); - -L(l2c8); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm0, eax, 0x0); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm1, eax, 0x0); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm2, eax, 0x0); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm3, eax, 0x0); - punpcklbw(xmm0, xmm1); - punpcklbw(xmm2, xmm3); - punpcklwd(xmm0, xmm2); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm1, eax, 0x0); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm2, eax, 0x0); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm3, eax, 0x0); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm4, eax, 0x0); - punpcklbw(xmm1, xmm2); - punpcklbw(xmm3, xmm4); - punpcklwd(xmm1, xmm3); - punpcklqdq(xmm0, xmm1); - movdqu(xword[B-0x80], xmm0); - sub(B, -16); - dec(LDA3); - jg(l2c8, T_NEAR); - align(4); - -L(l360); - test(M, 0x4); - jle(l3b4, 
T_NEAR); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm0, eax, 0x0); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm1, eax, 0x0); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm2, eax, 0x0); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm3, eax, 0x0); - punpcklbw(xmm0, xmm1); - punpcklbw(xmm2, xmm3); - punpcklwd(xmm0, xmm2); - movq(qword[B-0x80], xmm0); - sub(B, -8); - align(4); - -L(l3b4); - test(M, 0x2); - jle(l3e8, T_NEAR); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm0, eax, 0x0); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm1, eax, 0x0); - punpcklbw(xmm0, xmm1); - movd(dword[B-0x80], xmm0); - sub(B, -4); - align(4); - -L(l3e8); - test(M, 0x1); - jle(l400, T_NEAR); - mov(ax, word[A1-0x80]); - mov(word[B-0x80], ax); - sub(B, -2); - align(4); - -L(l400); - sub(N, 0x2); - cmp(N, 0x2); - jge(l2b4, T_NEAR); - align(4); - -L(l40e); - cmp(N, 0x1); - jl(l534, T_NEAR); - align(4); - -L(l418); - mov(A1, A); - add(A, 0x1); - mov(LDA3, M); - sar(LDA3, 0x3); - jle(l4a0, T_NEAR); - align(4); - -L(l428); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x0); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x1); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x2); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x3); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x4); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x5); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x6); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x7); - movq(qword[B-0x80], xmm0); - sub(B, -8); - dec(LDA3); - jg(l428, T_NEAR); - align(4); - -L(l4a0); - test(M, 0x4); - jle(l4e8, T_NEAR); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x0); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x1); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x2); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x3); - movd(dword[B-0x80], xmm0); - sub(B, -4); - align(4); - -L(l4e8); - test(M, 0x2); - jle(l50c, T_NEAR); - mov(al, byte[A1-0x80]); - add(A1, LDA); - mov(byte[B-0x80], al); - mov(al, byte[A1-0x80]); - add(A1, LDA); - mov(byte[B-0x7f], al); - sub(B, -2); - align(4); - -L(l50c); - test(M, 0x1); - jle(l524, T_NEAR); - mov(al, byte[A1-0x80]); - mov(byte[B-0x80], al); - sub(B, -1); - align(4); - -L(l524); - sub(N, 0x1); - cmp(N, 0x1); - jge(l418, T_NEAR); - align(4); - -L(l534); - - postamble(); -} -outLocalLabel(); - -#undef M -#undef N -#undef A -#undef LDA -#undef ALPHA -#undef B -#undef I -#undef A1 -#undef A2 -#undef LDA3 -#ifdef _WIN32 -#undef ARG_ALPHA -#undef ARG_B -#endif -} - -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_an_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_an_kern.cpp deleted file mode 100644 index 49a312fc8..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_an_kern.cpp +++ /dev/null @@ -1,1283 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. 
-* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include "jit_generator.hpp" -#include "common.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -jit_avx512_core_u8_copy_sum_an_kern::jit_avx512_core_u8_copy_sum_an_kern(): jit_generator(nullptr, GEMM_CODE_SIZE) -{ - -#ifndef _WIN32 -#define M rdi -#define N rsi -#define A rdx -#define LDA rcx -#define ALPHA r8 -#define B r9 - -#define I rax -#define A1 r10 -#define A2 r8 -#define LDA3 r11 - -#define ARG_BIAS 24+stacksize+rsp - -#else - -#define M rcx -#define N rdx -#define A r8 -#define LDA r9 -#define ALPHA rax -#define B rdi - -#define I rax -#define A1 rsi -#define A2 r10 -#define LDA3 r11 - -#define ARG_ALPHA 40+stacksize+rsp -#define ARG_B 48+stacksize+rsp -#define ARG_BIAS 72+stacksize+rsp - -#endif - -inLocalLabel(); -{ - -Xbyak::Label l1024; -Xbyak::Label l1090; -Xbyak::Label l10d4; -Xbyak::Label l10fc; -Xbyak::Label l111a; -Xbyak::Label l1124; -Xbyak::Label l113c; -Xbyak::Label l11d4; -Xbyak::Label l1234; -Xbyak::Label l1278; -Xbyak::Label l129c; -Xbyak::Label l12bc; -Xbyak::Label l20; -Xbyak::Label l2a0; -Xbyak::Label l3c0; -Xbyak::Label l438; -Xbyak::Label l480; -Xbyak::Label l48c; -Xbyak::Label l4c8; -Xbyak::Label l5c; -Xbyak::Label l6a8; -Xbyak::Label l7b4; -Xbyak::Label l850; -Xbyak::Label l89c; -Xbyak::Label l8a8; -Xbyak::Label l8d0; -Xbyak::Label l9d0; -Xbyak::Label la64; -Xbyak::Label lab8; -Xbyak::Label lae8; -Xbyak::Label laf4; -Xbyak::Label lb14; -Xbyak::Label lc30; -Xbyak::Label lcc8; -Xbyak::Label ld1c; -Xbyak::Label ld54; -Xbyak::Label ld78; -Xbyak::Label ld84; -Xbyak::Label ld9c; -Xbyak::Label le58; -Xbyak::Label lebc; -Xbyak::Label lef8; -Xbyak::Label lf1c; -Xbyak::Label lf3c; -Xbyak::Label lf48; -Xbyak::Label lf60; - - preamble(); - auto stacksize = get_size_of_abi_save_regs(); -#ifdef _WIN32 - mov(ALPHA, ptr[ARG_ALPHA]); - mov(B, ptr[ARG_B]); -#endif - - mov(M, qword[M]); - mov(N, qword[N]); - mov(LDA, qword[LDA]); - lea(LDA3, ptr[LDA+LDA*2]); - sub(A, -128); - sub(B, -128); - cmp(N, 0x30); - jl(l480, T_NEAR); - align(4); - -L(l20); - mov(A1, A); - add(A, 0x30); - vxorps(ymm8, ymm8, ymm8); - vxorps(ymm9, ymm9, ymm9); - vxorps(ymm10, ymm10, ymm10); - vxorps(ymm11, ymm11, ymm11); - vxorps(ymm12, ymm12, ymm12); - vxorps(ymm13, ymm13, ymm13); - vxorps(ymm14, ymm14, ymm14); - vxorps(ymm15, ymm15, ymm15); - mov(I, M); - sar(I, 0x2); - jle(l2a0, T_NEAR); - align(4); - -L(l5c); - vmovdqu(xmm0, xword[A1-0x80]); - vmovdqu(xmm1, xword[A1+LDA*1-0x80]); - vmovdqu(xmm2, xword[A1+LDA*2-0x80]); - vmovdqu(xmm3, xword[A1+LDA3*1-0x80]); - vpunpcklbw(xmm4, xmm0, xmm1); - vpunpckhbw(xmm5, xmm0, xmm1); - vpunpcklbw(xmm6, xmm2, xmm3); - vpunpckhbw(xmm7, xmm2, xmm3); - vpunpcklwd(xmm0, xmm4, xmm6); - vpunpckhwd(xmm1, xmm4, xmm6); - vpunpcklwd(xmm2, xmm5, xmm7); - vpunpckhwd(xmm3, xmm5, xmm7); - vpmovsxbw(ymm5, xmm0); - vmovhlps(xmm6, xmm0, xmm0); - vpmovsxbw(ymm6, xmm6); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxbw(ymm6, xmm1); - vmovhlps(xmm7, xmm1, xmm1); - vpmovsxbw(ymm7, xmm7); - vphaddw(ymm6, ymm6, ymm7); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxwd(ymm5, xmm5); - vpaddd(ymm8, 
ymm8, ymm5); - vmovdqu(xword[B-0x80], xmm0); - vmovdqu(xword[B-0x70], xmm1); - vpmovsxbw(ymm5, xmm2); - vmovhlps(xmm6, xmm2, xmm2); - vpmovsxbw(ymm6, xmm6); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxbw(ymm6, xmm3); - vmovhlps(xmm7, xmm3, xmm3); - vpmovsxbw(ymm7, xmm7); - vphaddw(ymm6, ymm6, ymm7); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxwd(ymm5, xmm5); - vpaddd(ymm9, ymm9, ymm5); - vmovdqu(xword[B-0x60], xmm2); - vmovdqu(xword[B-0x50], xmm3); - vmovdqu(xmm0, xword[A1-0x70]); - vmovdqu(xmm1, xword[A1+LDA*1-0x70]); - vmovdqu(xmm2, xword[A1+LDA*2-0x70]); - vmovdqu(xmm3, xword[A1+LDA3*1-0x70]); - vpunpcklbw(xmm4, xmm0, xmm1); - vpunpckhbw(xmm5, xmm0, xmm1); - vpunpcklbw(xmm6, xmm2, xmm3); - vpunpckhbw(xmm7, xmm2, xmm3); - vpunpcklwd(xmm0, xmm4, xmm6); - vpunpckhwd(xmm1, xmm4, xmm6); - vpunpcklwd(xmm2, xmm5, xmm7); - vpunpckhwd(xmm3, xmm5, xmm7); - vpmovsxbw(ymm5, xmm0); - vmovhlps(xmm6, xmm0, xmm0); - vpmovsxbw(ymm6, xmm6); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxbw(ymm6, xmm1); - vmovhlps(xmm7, xmm1, xmm1); - vpmovsxbw(ymm7, xmm7); - vphaddw(ymm6, ymm6, ymm7); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxwd(ymm5, xmm5); - vpaddd(ymm10, ymm10, ymm5); - vmovdqu(xword[B-0x40], xmm0); - vmovdqu(xword[B-0x30], xmm1); - vpmovsxbw(ymm5, xmm2); - vmovhlps(xmm6, xmm2, xmm2); - vpmovsxbw(ymm6, xmm6); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxbw(ymm6, xmm3); - vmovhlps(xmm7, xmm3, xmm3); - vpmovsxbw(ymm7, xmm7); - vphaddw(ymm6, ymm6, ymm7); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxwd(ymm5, xmm5); - vpaddd(ymm11, ymm11, ymm5); - vmovdqu(xword[B-0x20], xmm2); - vmovdqu(xword[B-0x10], xmm3); - vmovdqu(xmm0, xword[A1-0x60]); - vmovdqu(xmm1, xword[A1+LDA*1-0x60]); - vmovdqu(xmm2, xword[A1+LDA*2-0x60]); - vmovdqu(xmm3, xword[A1+LDA3*1-0x60]); - lea(A1, ptr[A1+LDA*4]); - vpunpcklbw(xmm4, xmm0, xmm1); - vpunpckhbw(xmm5, xmm0, xmm1); - vpunpcklbw(xmm6, xmm2, xmm3); - vpunpckhbw(xmm7, xmm2, xmm3); - vpunpcklwd(xmm0, xmm4, xmm6); - vpunpckhwd(xmm1, xmm4, xmm6); - vpunpcklwd(xmm2, xmm5, xmm7); - vpunpckhwd(xmm3, xmm5, xmm7); - vpmovsxbw(ymm5, xmm0); - vmovhlps(xmm6, xmm0, xmm0); - vpmovsxbw(ymm6, xmm6); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxbw(ymm6, xmm1); - vmovhlps(xmm7, xmm1, xmm1); - vpmovsxbw(ymm7, xmm7); - vphaddw(ymm6, ymm6, ymm7); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxwd(ymm5, xmm5); - vpaddd(ymm12, ymm12, ymm5); - vmovdqu(xword[B], xmm0); - vmovdqu(xword[B+0x10], xmm1); - vpmovsxbw(ymm5, xmm2); - vmovhlps(xmm6, xmm2, xmm2); - vpmovsxbw(ymm6, xmm6); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxbw(ymm6, xmm3); - vmovhlps(xmm7, xmm3, xmm3); - vpmovsxbw(ymm7, xmm7); - vphaddw(ymm6, ymm6, ymm7); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxwd(ymm5, xmm5); - vpaddd(ymm13, ymm13, ymm5); - vmovdqu(xword[B+0x20], xmm2); - vmovdqu(xword[B+0x30], xmm3); - sub(B, -192); - dec(I); - jg(l5c, T_NEAR); - align(4); - -L(l2a0); - test(M, 0x2); - jle(l3c0, T_NEAR); - vmovdqu(xmm0, xword[A1-0x80]); - vmovdqu(xmm1, xword[A1-0x70]); - vmovdqu(xmm2, xword[A1-0x60]); - add(A1, LDA); - vmovdqu(xmm6, xword[A1-0x80]); - vmovdqu(xmm4, xword[A1-0x70]); - vmovdqu(xmm5, xword[A1-0x60]); - add(A1, LDA); - vpunpcklbw(xmm3, xmm0, xmm6); - vpunpckhbw(xmm0, xmm0, xmm6); - vpmovsxbw(ymm7, xmm3); - vmovhlps(xmm6, xmm3, xmm3); - vpmovsxbw(ymm6, xmm6); - vphaddw(ymm7, ymm7, ymm6); - vpmovsxwd(ymm7, xmm7); - vpaddd(ymm8, ymm8, ymm7); - vmovdqu(xword[B-0x80], xmm3); - vpmovsxbw(ymm7, xmm0); - vmovhlps(xmm6, xmm0, xmm0); - vpmovsxbw(ymm6, xmm6); - vphaddw(ymm7, ymm7, ymm6); - vpmovsxwd(ymm7, xmm7); - vpaddd(ymm9, ymm9, ymm7); - vmovdqu(xword[B-0x70], xmm0); - vpunpcklbw(xmm3, xmm1, 
xmm4); - vpunpckhbw(xmm0, xmm1, xmm4); - vpmovsxbw(ymm7, xmm3); - vmovhlps(xmm6, xmm3, xmm3); - vpmovsxbw(ymm6, xmm6); - vphaddw(ymm7, ymm7, ymm6); - vpmovsxwd(ymm7, xmm7); - vpaddd(ymm10, ymm10, ymm7); - vmovdqu(xword[B-0x60], xmm3); - vpmovsxbw(ymm7, xmm0); - vmovhlps(xmm6, xmm0, xmm0); - vpmovsxbw(ymm6, xmm6); - vphaddw(ymm7, ymm7, ymm6); - vpmovsxwd(ymm7, xmm7); - vpaddd(ymm11, ymm11, ymm7); - vmovdqu(xword[B-0x50], xmm0); - vpunpcklbw(xmm3, xmm2, xmm5); - vpunpckhbw(xmm0, xmm2, xmm5); - vpmovsxbw(ymm7, xmm3); - vmovhlps(xmm6, xmm3, xmm3); - vpmovsxbw(ymm6, xmm6); - vphaddw(ymm7, ymm7, ymm6); - vpmovsxwd(ymm7, xmm7); - vpaddd(ymm12, ymm12, ymm7); - vmovdqu(xword[B-0x40], xmm3); - vpmovsxbw(ymm7, xmm0); - vmovhlps(xmm6, xmm0, xmm0); - vpmovsxbw(ymm6, xmm6); - vphaddw(ymm7, ymm7, ymm6); - vpmovsxwd(ymm7, xmm7); - vpaddd(ymm13, ymm13, ymm7); - vmovdqu(xword[B-0x30], xmm0); - sub(B, -96); - align(4); - -L(l3c0); - test(M, 0x1); - jle(l438, T_NEAR); - vmovdqu(xmm0, xword[A1-0x80]); - vmovdqu(xmm1, xword[A1-0x70]); - vmovdqu(xmm2, xword[A1-0x60]); - add(A1, LDA); - vpmovsxbd(ymm7, xmm0); - vpaddd(ymm8, ymm8, ymm7); - vmovhlps(xmm7, xmm0, xmm0); - vpmovsxbd(ymm7, xmm7); - vpaddd(ymm9, ymm9, ymm7); - vmovdqu(xword[B-0x80], xmm0); - vpmovsxbd(ymm7, xmm1); - vpaddd(ymm10, ymm10, ymm7); - vmovhlps(xmm7, xmm1, xmm1); - vpmovsxbd(ymm7, xmm7); - vpaddd(ymm11, ymm11, ymm7); - vmovdqu(xword[B-0x70], xmm1); - vpmovsxbd(ymm7, xmm2); - vpaddd(ymm12, ymm12, ymm7); - vmovhlps(xmm7, xmm2, xmm2); - vpmovsxbd(ymm7, xmm7); - vpaddd(ymm13, ymm13, ymm7); - vmovdqu(xword[B-0x60], xmm2); - sub(B, -48); - align(4); - -L(l438); - mov(A1, qword[ARG_BIAS]); - vmovdqu(yword[A1], ymm8); - vmovdqu(yword[A1+0x20], ymm9); - vmovdqu(yword[A1+0x40], ymm10); - vmovdqu(yword[A1+0x60], ymm11); - vmovdqu(yword[A1+0x80], ymm12); - vmovdqu(yword[A1+0xa0], ymm13); - add(qword[ARG_BIAS], 0xc0); - sub(N, 0x30); - cmp(N, 0x30); - jge(l20, T_NEAR); - vzeroupper(); - align(4); - -L(l480); - cmp(N, 0x20); - jl(l89c, T_NEAR); - align(4); - -L(l48c); - mov(A1, A); - add(A, 0x20); - pxor(xmm8, xmm8); - pxor(xmm9, xmm9); - pxor(xmm10, xmm10); - pxor(xmm11, xmm11); - pxor(xmm12, xmm12); - pxor(xmm13, xmm13); - pxor(xmm14, xmm14); - pxor(xmm15, xmm15); - mov(I, M); - sar(I, 0x2); - jle(l6a8, T_NEAR); - align(4); - -L(l4c8); - movdqu(xmm0, xword[A1-0x80]); - movdqu(xmm1, xword[A1+LDA*1-0x80]); - movdqu(xmm2, xword[A1+LDA*2-0x80]); - movdqu(xmm3, xword[A1+LDA3*1-0x80]); - movdqa(xmm4, xmm0); - punpcklbw(xmm0, xmm1); - punpckhbw(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpcklbw(xmm2, xmm3); - punpckhbw(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklwd(xmm0, xmm2); - punpckhwd(xmm1, xmm2); - movdqa(xmm2, xmm4); - punpcklwd(xmm4, xmm5); - punpckhwd(xmm2, xmm5); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm8, xmm5); - movdqu(xword[B-0x80], xmm0); - pmovsxbw(xmm5, xmm1); - movhlps(xmm6, xmm1); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm9, xmm5); - movdqu(xword[B-0x70], xmm1); - pmovsxbw(xmm5, xmm4); - movhlps(xmm6, xmm4); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm10, xmm5); - movdqu(xword[B-0x60], xmm4); - pmovsxbw(xmm5, xmm2); - movhlps(xmm6, xmm2); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm11, xmm5); - movdqu(xword[B-0x50], xmm2); - movdqu(xmm0, xword[A1-0x70]); - 
movdqu(xmm1, xword[A1+LDA*1-0x70]); - movdqu(xmm2, xword[A1+LDA*2-0x70]); - movdqu(xmm3, xword[A1+LDA3*1-0x70]); - lea(A1, ptr[A1+LDA*4]); - movdqa(xmm4, xmm0); - punpcklbw(xmm0, xmm1); - punpckhbw(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpcklbw(xmm2, xmm3); - punpckhbw(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklwd(xmm0, xmm2); - punpckhwd(xmm1, xmm2); - movdqa(xmm2, xmm4); - punpcklwd(xmm4, xmm5); - punpckhwd(xmm2, xmm5); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm12, xmm5); - movdqu(xword[B-0x40], xmm0); - pmovsxbw(xmm5, xmm1); - movhlps(xmm6, xmm1); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm13, xmm5); - movdqu(xword[B-0x30], xmm1); - pmovsxbw(xmm5, xmm4); - movhlps(xmm6, xmm4); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm14, xmm5); - movdqu(xword[B-0x20], xmm4); - pmovsxbw(xmm5, xmm2); - movhlps(xmm6, xmm2); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm15, xmm5); - movdqu(xword[B-0x10], xmm2); - sub(B, -128); - dec(I); - jg(l4c8, T_NEAR); - align(4); - -L(l6a8); - test(M, 0x2); - jle(l7b4, T_NEAR); - movdqu(xmm0, xword[A1-0x80]); - movdqu(xmm1, xword[A1-0x70]); - add(A1, LDA); - movdqu(xmm2, xword[A1-0x80]); - movdqu(xmm3, xword[A1-0x70]); - add(A1, LDA); - movdqa(xmm4, xmm0); - punpcklbw(xmm0, xmm2); - punpckhbw(xmm4, xmm2); - pmovsxbw(xmm5, xmm0); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm8, xmm5); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm6, xmm6); - pmovsxwd(xmm6, xmm6); - paddd(xmm9, xmm6); - movdqu(xword[B-0x80], xmm0); - pmovsxbw(xmm5, xmm4); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm10, xmm5); - movhlps(xmm6, xmm4); - pmovsxbw(xmm6, xmm6); - phaddw(xmm6, xmm6); - pmovsxwd(xmm6, xmm6); - paddd(xmm11, xmm6); - movdqu(xword[B-0x70], xmm4); - movdqa(xmm4, xmm1); - punpcklbw(xmm1, xmm3); - punpckhbw(xmm4, xmm3); - pmovsxbw(xmm5, xmm1); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm12, xmm5); - movhlps(xmm6, xmm1); - pmovsxbw(xmm6, xmm6); - phaddw(xmm6, xmm6); - pmovsxwd(xmm6, xmm6); - paddd(xmm13, xmm6); - movdqu(xword[B-0x60], xmm1); - pmovsxbw(xmm5, xmm4); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm14, xmm5); - movhlps(xmm6, xmm4); - pmovsxbw(xmm6, xmm6); - phaddw(xmm6, xmm6); - pmovsxwd(xmm6, xmm6); - paddd(xmm15, xmm6); - movdqu(xword[B-0x50], xmm4); - sub(B, -64); - align(4); - -L(l7b4); - test(M, 0x1); - jle(l850, T_NEAR); - movdqu(xmm0, xword[A1-0x80]); - movdqu(xmm1, xword[A1-0x70]); - add(A1, LDA); - pmovsxbd(xmm5, xmm0); - paddd(xmm8, xmm5); - pshufd(xmm6, xmm0, 0x55); - pmovsxbd(xmm6, xmm6); - paddd(xmm9, xmm6); - pshufd(xmm5, xmm0, 0xaa); - pmovsxbd(xmm5, xmm5); - paddd(xmm10, xmm5); - pshufd(xmm6, xmm0, 0xff); - pmovsxbd(xmm6, xmm6); - paddd(xmm11, xmm6); - movdqu(xword[B-0x80], xmm0); - pmovsxbd(xmm5, xmm1); - paddd(xmm12, xmm5); - pshufd(xmm6, xmm1, 0x55); - pmovsxbd(xmm6, xmm6); - paddd(xmm13, xmm6); - pshufd(xmm5, xmm1, 0xaa); - pmovsxbd(xmm5, xmm5); - paddd(xmm14, xmm5); - pshufd(xmm6, xmm1, 0xff); - pmovsxbd(xmm6, xmm6); - paddd(xmm15, xmm6); - movdqu(xword[B-0x70], xmm1); - sub(B, -32); - align(4); - -L(l850); - mov(A1, qword[ARG_BIAS]); - movdqu(xword[A1], xmm8); - movdqu(xword[A1+0x10], xmm9); - movdqu(xword[A1+0x20], xmm10); - movdqu(xword[A1+0x30], xmm11); - movdqu(xword[A1+0x40], xmm12); - 
movdqu(xword[A1+0x50], xmm13); - movdqu(xword[A1+0x60], xmm14); - movdqu(xword[A1+0x70], xmm15); - add(qword[ARG_BIAS], 0x80); - sub(N, 0x20); - cmp(N, 0x20); - jge(l48c, T_NEAR); - align(4); - -L(l89c); - cmp(N, 0x10); - jl(lae8, T_NEAR); - align(4); - -L(l8a8); - mov(A1, A); - add(A, 0x10); - pxor(xmm8, xmm8); - pxor(xmm9, xmm9); - pxor(xmm10, xmm10); - pxor(xmm11, xmm11); - mov(I, M); - sar(I, 0x2); - jle(l9d0, T_NEAR); - align(4); - -L(l8d0); - movdqu(xmm0, xword[A1-0x80]); - add(A1, LDA); - movdqu(xmm1, xword[A1-0x80]); - add(A1, LDA); - movdqu(xmm2, xword[A1-0x80]); - add(A1, LDA); - movdqu(xmm3, xword[A1-0x80]); - add(A1, LDA); - movdqa(xmm4, xmm0); - punpcklbw(xmm0, xmm1); - punpckhbw(xmm4, xmm1); - movdqa(xmm1, xmm2); - punpcklbw(xmm2, xmm3); - punpckhbw(xmm1, xmm3); - movdqa(xmm3, xmm0); - punpcklwd(xmm0, xmm2); - punpckhwd(xmm3, xmm2); - movdqa(xmm2, xmm4); - punpcklwd(xmm4, xmm1); - punpckhwd(xmm2, xmm1); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm8, xmm5); - pmovsxbw(xmm5, xmm3); - movhlps(xmm6, xmm3); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm9, xmm5); - movdqu(xword[B-0x80], xmm0); - movdqu(xword[B-0x70], xmm3); - pmovsxbw(xmm5, xmm4); - movhlps(xmm6, xmm4); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm10, xmm5); - pmovsxbw(xmm5, xmm2); - movhlps(xmm6, xmm2); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm11, xmm5); - movdqu(xword[B-0x60], xmm4); - movdqu(xword[B-0x50], xmm2); - sub(B, -64); - dec(I); - jg(l8d0, T_NEAR); - align(4); - -L(l9d0); - test(M, 0x2); - jle(la64, T_NEAR); - movdqu(xmm0, xword[A1-0x80]); - add(A1, LDA); - movdqu(xmm1, xword[A1-0x80]); - add(A1, LDA); - movdqa(xmm2, xmm0); - punpcklbw(xmm0, xmm1); - punpckhbw(xmm2, xmm1); - pmovsxbw(xmm5, xmm0); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm8, xmm5); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm6, xmm6); - pmovsxwd(xmm6, xmm6); - paddd(xmm9, xmm6); - pmovsxbw(xmm5, xmm2); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm10, xmm5); - movhlps(xmm6, xmm2); - pmovsxbw(xmm6, xmm6); - phaddw(xmm6, xmm6); - pmovsxwd(xmm6, xmm6); - paddd(xmm11, xmm6); - movdqu(xword[B-0x80], xmm0); - movdqu(xword[B-0x70], xmm2); - sub(B, -32); - align(4); - -L(la64); - test(M, 0x1); - jle(lab8, T_NEAR); - movdqu(xmm0, xword[A1-0x80]); - add(A1, LDA); - pmovsxbd(xmm5, xmm0); - paddd(xmm8, xmm5); - pshufd(xmm6, xmm0, 0x55); - pmovsxbd(xmm6, xmm6); - paddd(xmm9, xmm6); - pshufd(xmm5, xmm0, 0xaa); - pmovsxbd(xmm5, xmm5); - paddd(xmm10, xmm5); - pshufd(xmm6, xmm0, 0xff); - pmovsxbd(xmm6, xmm6); - paddd(xmm11, xmm6); - movdqu(xword[B-0x80], xmm0); - sub(B, -16); - align(4); - -L(lab8); - mov(A1, qword[ARG_BIAS]); - movdqu(xword[A1], xmm8); - movdqu(xword[A1+0x10], xmm9); - movdqu(xword[A1+0x20], xmm10); - movdqu(xword[A1+0x30], xmm11); - add(qword[ARG_BIAS], 0x40); - sub(N, 0x10); - cmp(N, 0x10); - jge(l8a8, T_NEAR); - align(4); - -L(lae8); - cmp(N, 0x8); - jl(ld78, T_NEAR); - align(4); - -L(laf4); - mov(A1, A); - add(A, 0x8); - pxor(xmm8, xmm8); - pxor(xmm9, xmm9); - mov(I, M); - sar(I, 0x3); - jle(lc30, T_NEAR); - align(4); - -L(lb14); - movq(xmm0, qword[A1-0x80]); - add(A1, LDA); - movq(xmm1, qword[A1-0x80]); - add(A1, LDA); - movq(xmm2, qword[A1-0x80]); - add(A1, LDA); - movq(xmm3, qword[A1-0x80]); - add(A1, LDA); - 
punpcklbw(xmm0, xmm1); - punpcklbw(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklwd(xmm0, xmm2); - punpckhwd(xmm1, xmm2); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm8, xmm5); - pmovsxbw(xmm5, xmm1); - movhlps(xmm6, xmm1); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm9, xmm5); - movdqu(xword[B-0x80], xmm0); - movdqu(xword[B-0x70], xmm1); - movq(xmm0, qword[A1-0x80]); - add(A1, LDA); - movq(xmm1, qword[A1-0x80]); - add(A1, LDA); - movq(xmm2, qword[A1-0x80]); - add(A1, LDA); - movq(xmm3, qword[A1-0x80]); - add(A1, LDA); - punpcklbw(xmm0, xmm1); - punpcklbw(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklwd(xmm0, xmm2); - punpckhwd(xmm1, xmm2); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm8, xmm5); - pmovsxbw(xmm5, xmm1); - movhlps(xmm6, xmm1); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm9, xmm5); - movdqu(xword[B-0x60], xmm0); - movdqu(xword[B-0x50], xmm1); - sub(B, -64); - dec(I); - jg(lb14, T_NEAR); - align(4); - -L(lc30); - test(M, 0x4); - jle(lcc8, T_NEAR); - movq(xmm0, qword[A1-0x80]); - add(A1, LDA); - movq(xmm1, qword[A1-0x80]); - add(A1, LDA); - movq(xmm2, qword[A1-0x80]); - add(A1, LDA); - movq(xmm3, qword[A1-0x80]); - add(A1, LDA); - punpcklbw(xmm0, xmm1); - punpcklbw(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklwd(xmm0, xmm2); - punpckhwd(xmm1, xmm2); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm8, xmm5); - pmovsxbw(xmm5, xmm1); - movhlps(xmm6, xmm1); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm9, xmm5); - movdqu(xword[B-0x80], xmm0); - movdqu(xword[B-0x70], xmm1); - sub(B, -32); - align(4); - -L(lcc8); - test(M, 0x2); - jle(ld1c, T_NEAR); - movq(xmm0, qword[A1-0x80]); - add(A1, LDA); - movq(xmm1, qword[A1-0x80]); - add(A1, LDA); - punpcklbw(xmm0, xmm1); - pmovsxbw(xmm5, xmm0); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm8, xmm5); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm6, xmm6); - pmovsxwd(xmm6, xmm6); - paddd(xmm9, xmm6); - movdqu(xword[B-0x80], xmm0); - sub(B, -16); - align(4); - -L(ld1c); - test(M, 0x1); - jle(ld54, T_NEAR); - movq(xmm0, qword[A1-0x80]); - add(A1, LDA); - pmovsxbd(xmm5, xmm0); - pshufd(xmm6, xmm0, 0x55); - pmovsxbd(xmm6, xmm6); - paddd(xmm8, xmm5); - paddd(xmm9, xmm6); - movq(qword[B-0x80], xmm0); - sub(B, -8); - align(4); - -L(ld54); - mov(A1, qword[ARG_BIAS]); - movdqu(xword[A1], xmm8); - movdqu(xword[A1+0x10], xmm9); - add(qword[ARG_BIAS], 0x20); - sub(N, 0x8); - cmp(N, 0x8); - jge(laf4, T_NEAR); - align(4); - -L(ld78); - cmp(N, 0x4); - jl(lf3c, T_NEAR); - align(4); - -L(ld84); - mov(A1, A); - add(A, 0x4); - pxor(xmm7, xmm7); - mov(I, M); - sar(I, 0x3); - jle(le58, T_NEAR); - align(4); - -L(ld9c); - movd(xmm0, dword[A1-0x80]); - add(A1, LDA); - movd(xmm1, dword[A1-0x80]); - add(A1, LDA); - movd(xmm2, dword[A1-0x80]); - add(A1, LDA); - movd(xmm3, dword[A1-0x80]); - add(A1, LDA); - punpcklbw(xmm0, xmm1); - punpcklbw(xmm2, xmm3); - punpcklwd(xmm0, xmm2); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - movdqu(xword[B-0x80], xmm0); - 
movd(xmm0, dword[A1-0x80]); - add(A1, LDA); - movd(xmm1, dword[A1-0x80]); - add(A1, LDA); - movd(xmm2, dword[A1-0x80]); - add(A1, LDA); - movd(xmm3, dword[A1-0x80]); - add(A1, LDA); - punpcklbw(xmm0, xmm1); - punpcklbw(xmm2, xmm3); - punpcklwd(xmm0, xmm2); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - movdqu(xword[B-0x70], xmm0); - sub(B, -32); - dec(I); - jg(ld9c, T_NEAR); - align(4); - -L(le58); - test(M, 0x4); - jle(lebc, T_NEAR); - movd(xmm0, dword[A1-0x80]); - add(A1, LDA); - movd(xmm1, dword[A1-0x80]); - add(A1, LDA); - movd(xmm2, dword[A1-0x80]); - add(A1, LDA); - movd(xmm3, dword[A1-0x80]); - add(A1, LDA); - punpcklbw(xmm0, xmm1); - punpcklbw(xmm2, xmm3); - punpcklwd(xmm0, xmm2); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - movdqu(xword[B-0x80], xmm0); - sub(B, -16); - align(4); - -L(lebc); - test(M, 0x2); - jle(lef8, T_NEAR); - movd(xmm0, dword[A1-0x80]); - add(A1, LDA); - movd(xmm1, dword[A1-0x80]); - add(A1, LDA); - punpcklbw(xmm0, xmm1); - pmovsxbw(xmm5, xmm0); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - movq(qword[B-0x80], xmm0); - sub(B, -8); - align(4); - -L(lef8); - test(M, 0x1); - jle(lf1c, T_NEAR); - movd(xmm0, dword[A1-0x80]); - pmovsxbd(xmm5, xmm0); - paddd(xmm7, xmm5); - movd(dword[B-0x80], xmm0); - sub(B, -4); - align(4); - -L(lf1c); - mov(A1, qword[ARG_BIAS]); - movdqu(xword[A1], xmm7); - add(qword[ARG_BIAS], 0x10); - sub(N, 0x4); - cmp(N, 0x4); - jge(ld84, T_NEAR); - align(4); - -L(lf3c); - cmp(N, 0x2); - jl(l111a, T_NEAR); - align(4); - -L(lf48); - mov(A1, A); - add(A, 0x2); - pxor(xmm7, xmm7); - mov(LDA3, M); - sar(LDA3, 0x3); - jle(l1024, T_NEAR); - align(4); - -L(lf60); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm0, eax, 0x0); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm1, eax, 0x0); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm2, eax, 0x0); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm3, eax, 0x0); - punpcklbw(xmm0, xmm1); - punpcklbw(xmm2, xmm3); - punpcklwd(xmm0, xmm2); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm1, eax, 0x0); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm2, eax, 0x0); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm3, eax, 0x0); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm4, eax, 0x0); - punpcklbw(xmm1, xmm2); - punpcklbw(xmm3, xmm4); - punpcklwd(xmm1, xmm3); - punpcklqdq(xmm0, xmm1); - pshufd(xmm6, xmm0, 0xd8); - pmovsxbw(xmm5, xmm6); - movhlps(xmm6, xmm6); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - movdqu(xword[B-0x80], xmm0); - sub(B, -16); - dec(LDA3); - jg(lf60, T_NEAR); - align(4); - -L(l1024); - test(M, 0x4); - jle(l1090, T_NEAR); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm0, eax, 0x0); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm1, eax, 0x0); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm2, eax, 0x0); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm3, eax, 0x0); - punpcklbw(xmm0, xmm1); - punpcklbw(xmm2, xmm3); - punpcklwd(xmm0, xmm2); - pmovsxbw(xmm5, xmm0); - phaddw(xmm5, xmm5); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - movq(qword[B-0x80], xmm0); - sub(B, -8); - align(4); - -L(l1090); - test(M, 0x2); - jle(l10d4, T_NEAR); - mov(ax, word[A1-0x80]); - 
add(A1, LDA); - pinsrw(xmm0, eax, 0x0); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm1, eax, 0x0); - punpcklbw(xmm0, xmm1); - pmovsxbw(xmm5, xmm0); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - movd(dword[B-0x80], xmm0); - sub(B, -4); - align(4); - -L(l10d4); - test(M, 0x1); - jle(l10fc, T_NEAR); - mov(ax, word[A1-0x80]); - pinsrw(xmm0, eax, 0x0); - pmovsxbd(xmm5, xmm0); - paddd(xmm7, xmm5); - mov(word[B-0x80], ax); - sub(B, -2); - align(4); - -L(l10fc); - mov(A1, qword[ARG_BIAS]); - movq(qword[A1], xmm7); - add(qword[ARG_BIAS], 0x8); - sub(N, 0x2); - cmp(N, 0x2); - jge(lf48, T_NEAR); - align(4); - -L(l111a); - cmp(N, 0x1); - jl(l12bc, T_NEAR); - align(4); - -L(l1124); - mov(A1, A); - add(A, 0x1); - pxor(xmm7, xmm7); - mov(LDA3, M); - sar(LDA3, 0x3); - jle(l11d4, T_NEAR); - align(4); - -L(l113c); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x0); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x1); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x2); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x3); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x4); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x5); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x6); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x7); - pmovsxbw(xmm5, xmm0); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - movq(qword[B-0x80], xmm0); - sub(B, -8); - dec(LDA3); - jg(l113c, T_NEAR); - align(4); - -L(l11d4); - test(M, 0x4); - jle(l1234, T_NEAR); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x0); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x1); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x2); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x3); - pmovsxbw(xmm5, xmm0); - phaddw(xmm5, xmm5); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - movd(dword[B-0x80], xmm0); - sub(B, -4); - align(4); - -L(l1234); - test(M, 0x2); - jle(l1278, T_NEAR); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x0); - mov(byte[B-0x80], al); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x1); - pmovsxbw(xmm5, xmm0); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - mov(byte[B-0x7f], al); - sub(B, -2); - align(4); - -L(l1278); - test(M, 0x1); - jle(l129c, T_NEAR); - mov(al, byte[A1-0x80]); - pinsrw(xmm0, eax, 0x0); - pmovsxbd(xmm5, xmm0); - paddd(xmm7, xmm5); - mov(byte[B-0x80], al); - sub(B, -1); - align(4); - -L(l129c); - mov(A1, qword[ARG_BIAS]); - movd(dword[A1], xmm7); - add(qword[ARG_BIAS], 0x4); - sub(N, 0x1); - cmp(N, 0x1); - jge(l1124, T_NEAR); - align(4); - -L(l12bc); - - postamble(); -} -outLocalLabel(); - -#undef M -#undef N -#undef A -#undef LDA -#undef ALPHA -#undef B -#undef I -#undef A1 -#undef A2 -#undef LDA3 -#ifdef _WIN32 -#undef ARG_ALPHA -#undef ARG_B -#endif -#undef ARG_BIAS -} - -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_at_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_at_kern.cpp deleted file mode 100644 index a4f4ff09c..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_at_kern.cpp +++ /dev/null @@ -1,3163 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed 
under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include "jit_generator.hpp" -#include "common.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -jit_avx512_core_u8_copy_sum_at_kern::jit_avx512_core_u8_copy_sum_at_kern(): jit_generator(nullptr, GEMM_CODE_SIZE) -{ - -#ifndef _WIN32 -#define M rdi -#define N rsi -#define A rdx -#define LDA rcx -#define ALPHA r8 -#define B r9 - -#define I rax -#define A1 r10 -#define A2 r8 -#define LDA3 r11 - -#define ARG_BIAS 24+stacksize+rsp - -#else - -#define M rcx -#define N rdx -#define A r8 -#define LDA r9 -#define ALPHA rax -#define B rdi - -#define I rax -#define A1 rsi -#define A2 r10 -#define LDA3 r11 - -#define ARG_ALPHA 40+stacksize+rsp -#define ARG_B 48+stacksize+rsp -#define ARG_BIAS 72+stacksize+rsp - -#endif - -inLocalLabel(); -{ - -Xbyak::Label l1750; -Xbyak::Label l1b6c; -Xbyak::Label l1e14; -Xbyak::Label l20; -Xbyak::Label l2068; -Xbyak::Label l226c; -Xbyak::Label l22b8; -Xbyak::Label l22c4; -Xbyak::Label l22f4; -Xbyak::Label l26b4; -Xbyak::Label l28cc; -Xbyak::Label l2a2c; -Xbyak::Label l2b5c; -Xbyak::Label l2c64; -Xbyak::Label l2c94; -Xbyak::Label l2ca0; -Xbyak::Label l2cc8; -Xbyak::Label l2eac; -Xbyak::Label l2fc0; -Xbyak::Label l3078; -Xbyak::Label l3118; -Xbyak::Label l319c; -Xbyak::Label l31c0; -Xbyak::Label l31cc; -Xbyak::Label l31ec; -Xbyak::Label l32e4; -Xbyak::Label l3378; -Xbyak::Label l33dc; -Xbyak::Label l3434; -Xbyak::Label l347c; -Xbyak::Label l349c; -Xbyak::Label l34a8; -Xbyak::Label l34c8; -Xbyak::Label l3558; -Xbyak::Label l35b0; -Xbyak::Label l35f4; -Xbyak::Label l3638; -Xbyak::Label l366c; -Xbyak::Label l368a; -Xbyak::Label l3694; -Xbyak::Label l36a8; -Xbyak::Label l36ec; -Xbyak::Label l3728; -Xbyak::Label l3760; -Xbyak::Label l3794; -Xbyak::Label l37b8; -Xbyak::Label l37d8; -Xbyak::Label l5cc; -Xbyak::Label l6c; -Xbyak::Label l968; -Xbyak::Label lc80; -Xbyak::Label lf1c; -Xbyak::Label lf64; -Xbyak::Label lf70; -Xbyak::Label lfb4; - - preamble(); - auto stacksize = get_size_of_abi_save_regs(); -#ifdef _WIN32 - mov(ALPHA, ptr[ARG_ALPHA]); - mov(B, ptr[ARG_B]); -#endif - - mov(N, qword[N]); - mov(M, qword[M]); - mov(LDA, qword[LDA]); - sub(A, -128); - sub(B, -128); - lea(LDA3, ptr[LDA+LDA*2]); - cmp(N, 0x30); - jl(lf64, T_NEAR); - align(4); - -L(l20); - mov(A1, A); - mov(I, LDA); - shl(I, 0x5); - lea(I, ptr[I+LDA*8]); - lea(I, ptr[I+LDA*8]); - add(A, I); - vxorps(ymm8, ymm8, ymm8); - vxorps(ymm9, ymm9, ymm9); - vxorps(ymm10, ymm10, ymm10); - vxorps(ymm11, ymm11, ymm11); - vxorps(ymm12, ymm12, ymm12); - vxorps(ymm13, ymm13, ymm13); - vxorps(ymm14, ymm14, ymm14); - vxorps(ymm15, ymm15, ymm15); - mov(I, M); - sar(I, 0x3); - jle(l5cc, T_NEAR); - align(4); - -L(l6c); - vmovq(xmm0, qword[A1-0x80]); - vmovq(xmm1, qword[A1+LDA*1-0x80]); - vmovq(xmm2, qword[A1+LDA*2-0x80]); - vmovq(xmm3, qword[A1+LDA3*1-0x80]); - lea(A2, ptr[A1+LDA*4]); - vpunpckldq(xmm1, xmm0, xmm1); - vpunpckldq(xmm3, xmm2, xmm3); - vpunpcklqdq(xmm0, xmm1, xmm3); - vpunpckhqdq(xmm1, 
xmm1, xmm3); - vmovdqu(xword[B-0x80], xmm0); - vmovdqu(xword[B+0x40], xmm1); - vmovq(xmm2, qword[A2-0x80]); - vmovq(xmm3, qword[A2+LDA*1-0x80]); - vmovq(xmm4, qword[A2+LDA*2-0x80]); - vmovq(xmm5, qword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - vpunpckldq(xmm3, xmm2, xmm3); - vpunpckldq(xmm5, xmm4, xmm5); - vpunpcklqdq(xmm2, xmm3, xmm5); - vpunpckhqdq(xmm3, xmm3, xmm5); - vmovdqu(xword[B-0x70], xmm2); - vmovdqu(xword[B+0x50], xmm3); - vpmovsxbw(ymm5, xmm0); - vmovhlps(xmm6, xmm0, xmm0); - vpmovsxbw(ymm6, xmm6); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxbw(ymm6, xmm2); - vmovhlps(xmm7, xmm2, xmm2); - vpmovsxbw(ymm7, xmm7); - vphaddw(ymm6, ymm6, ymm7); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxwd(ymm5, xmm5); - vpaddd(ymm8, ymm8, ymm5); - vpmovsxbw(ymm5, xmm1); - vmovhlps(xmm6, xmm1, xmm1); - vpmovsxbw(ymm6, xmm6); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxbw(ymm6, xmm3); - vmovhlps(xmm7, xmm3, xmm3); - vpmovsxbw(ymm7, xmm7); - vphaddw(ymm6, ymm6, ymm7); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxwd(ymm5, xmm5); - vpaddd(ymm8, ymm8, ymm5); - vmovq(xmm0, qword[A2-0x80]); - vmovq(xmm1, qword[A2+LDA*1-0x80]); - vmovq(xmm2, qword[A2+LDA*2-0x80]); - vmovq(xmm3, qword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - vpunpckldq(xmm1, xmm0, xmm1); - vpunpckldq(xmm3, xmm2, xmm3); - vpunpcklqdq(xmm0, xmm1, xmm3); - vpunpckhqdq(xmm1, xmm1, xmm3); - vmovdqu(xword[B-0x60], xmm0); - vmovdqu(xword[B+0x60], xmm1); - vmovq(xmm2, qword[A2-0x80]); - vmovq(xmm3, qword[A2+LDA*1-0x80]); - vmovq(xmm4, qword[A2+LDA*2-0x80]); - vmovq(xmm5, qword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - vpunpckldq(xmm3, xmm2, xmm3); - vpunpckldq(xmm5, xmm4, xmm5); - vpunpcklqdq(xmm2, xmm3, xmm5); - vpunpckhqdq(xmm3, xmm3, xmm5); - vmovdqu(xword[B-0x50], xmm2); - vmovdqu(xword[B+0x70], xmm3); - vpmovsxbw(ymm5, xmm0); - vmovhlps(xmm6, xmm0, xmm0); - vpmovsxbw(ymm6, xmm6); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxbw(ymm6, xmm2); - vmovhlps(xmm7, xmm2, xmm2); - vpmovsxbw(ymm7, xmm7); - vphaddw(ymm6, ymm6, ymm7); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxwd(ymm5, xmm5); - vpaddd(ymm9, ymm9, ymm5); - vpmovsxbw(ymm5, xmm1); - vmovhlps(xmm6, xmm1, xmm1); - vpmovsxbw(ymm6, xmm6); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxbw(ymm6, xmm3); - vmovhlps(xmm7, xmm3, xmm3); - vpmovsxbw(ymm7, xmm7); - vphaddw(ymm6, ymm6, ymm7); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxwd(ymm5, xmm5); - vpaddd(ymm9, ymm9, ymm5); - vmovq(xmm0, qword[A2-0x80]); - vmovq(xmm1, qword[A2+LDA*1-0x80]); - vmovq(xmm2, qword[A2+LDA*2-0x80]); - vmovq(xmm3, qword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - vpunpckldq(xmm1, xmm0, xmm1); - vpunpckldq(xmm3, xmm2, xmm3); - vpunpcklqdq(xmm0, xmm1, xmm3); - vpunpckhqdq(xmm1, xmm1, xmm3); - vmovdqu(xword[B-0x40], xmm0); - vmovdqu(xword[B+0x80], xmm1); - vmovq(xmm2, qword[A2-0x80]); - vmovq(xmm3, qword[A2+LDA*1-0x80]); - vmovq(xmm4, qword[A2+LDA*2-0x80]); - vmovq(xmm5, qword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - vpunpckldq(xmm3, xmm2, xmm3); - vpunpckldq(xmm5, xmm4, xmm5); - vpunpcklqdq(xmm2, xmm3, xmm5); - vpunpckhqdq(xmm3, xmm3, xmm5); - vmovdqu(xword[B-0x30], xmm2); - vmovdqu(xword[B+0x90], xmm3); - vpmovsxbw(ymm5, xmm0); - vmovhlps(xmm6, xmm0, xmm0); - vpmovsxbw(ymm6, xmm6); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxbw(ymm6, xmm2); - vmovhlps(xmm7, xmm2, xmm2); - vpmovsxbw(ymm7, xmm7); - vphaddw(ymm6, ymm6, ymm7); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxwd(ymm5, xmm5); - vpaddd(ymm10, ymm10, ymm5); - vpmovsxbw(ymm5, xmm1); - vmovhlps(xmm6, xmm1, xmm1); - vpmovsxbw(ymm6, xmm6); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxbw(ymm6, xmm3); - vmovhlps(xmm7, xmm3, 
xmm3); - vpmovsxbw(ymm7, xmm7); - vphaddw(ymm6, ymm6, ymm7); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxwd(ymm5, xmm5); - vpaddd(ymm10, ymm10, ymm5); - vmovq(xmm0, qword[A2-0x80]); - vmovq(xmm1, qword[A2+LDA*1-0x80]); - vmovq(xmm2, qword[A2+LDA*2-0x80]); - vmovq(xmm3, qword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - vpunpckldq(xmm1, xmm0, xmm1); - vpunpckldq(xmm3, xmm2, xmm3); - vpunpcklqdq(xmm0, xmm1, xmm3); - vpunpckhqdq(xmm1, xmm1, xmm3); - vmovdqu(xword[B-0x20], xmm0); - vmovdqu(xword[B+0xa0], xmm1); - vmovq(xmm2, qword[A2-0x80]); - vmovq(xmm3, qword[A2+LDA*1-0x80]); - vmovq(xmm4, qword[A2+LDA*2-0x80]); - vmovq(xmm5, qword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - vpunpckldq(xmm3, xmm2, xmm3); - vpunpckldq(xmm5, xmm4, xmm5); - vpunpcklqdq(xmm2, xmm3, xmm5); - vpunpckhqdq(xmm3, xmm3, xmm5); - vmovdqu(xword[B-0x10], xmm2); - vmovdqu(xword[B+0xb0], xmm3); - vpmovsxbw(ymm5, xmm0); - vmovhlps(xmm6, xmm0, xmm0); - vpmovsxbw(ymm6, xmm6); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxbw(ymm6, xmm2); - vmovhlps(xmm7, xmm2, xmm2); - vpmovsxbw(ymm7, xmm7); - vphaddw(ymm6, ymm6, ymm7); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxwd(ymm5, xmm5); - vpaddd(ymm11, ymm11, ymm5); - vpmovsxbw(ymm5, xmm1); - vmovhlps(xmm6, xmm1, xmm1); - vpmovsxbw(ymm6, xmm6); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxbw(ymm6, xmm3); - vmovhlps(xmm7, xmm3, xmm3); - vpmovsxbw(ymm7, xmm7); - vphaddw(ymm6, ymm6, ymm7); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxwd(ymm5, xmm5); - vpaddd(ymm11, ymm11, ymm5); - vmovq(xmm0, qword[A2-0x80]); - vmovq(xmm1, qword[A2+LDA*1-0x80]); - vmovq(xmm2, qword[A2+LDA*2-0x80]); - vmovq(xmm3, qword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - vpunpckldq(xmm1, xmm0, xmm1); - vpunpckldq(xmm3, xmm2, xmm3); - vpunpcklqdq(xmm0, xmm1, xmm3); - vpunpckhqdq(xmm1, xmm1, xmm3); - vmovdqu(xword[B], xmm0); - vmovdqu(xword[B+0xc0], xmm1); - vmovq(xmm2, qword[A2-0x80]); - vmovq(xmm3, qword[A2+LDA*1-0x80]); - vmovq(xmm4, qword[A2+LDA*2-0x80]); - vmovq(xmm5, qword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - vpunpckldq(xmm3, xmm2, xmm3); - vpunpckldq(xmm5, xmm4, xmm5); - vpunpcklqdq(xmm2, xmm3, xmm5); - vpunpckhqdq(xmm3, xmm3, xmm5); - vmovdqu(xword[B+0x10], xmm2); - vmovdqu(xword[B+0xd0], xmm3); - vpmovsxbw(ymm5, xmm0); - vmovhlps(xmm6, xmm0, xmm0); - vpmovsxbw(ymm6, xmm6); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxbw(ymm6, xmm2); - vmovhlps(xmm7, xmm2, xmm2); - vpmovsxbw(ymm7, xmm7); - vphaddw(ymm6, ymm6, ymm7); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxwd(ymm5, xmm5); - vpaddd(ymm12, ymm12, ymm5); - vpmovsxbw(ymm5, xmm1); - vmovhlps(xmm6, xmm1, xmm1); - vpmovsxbw(ymm6, xmm6); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxbw(ymm6, xmm3); - vmovhlps(xmm7, xmm3, xmm3); - vpmovsxbw(ymm7, xmm7); - vphaddw(ymm6, ymm6, ymm7); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxwd(ymm5, xmm5); - vpaddd(ymm12, ymm12, ymm5); - vmovq(xmm0, qword[A2-0x80]); - vmovq(xmm1, qword[A2+LDA*1-0x80]); - vmovq(xmm2, qword[A2+LDA*2-0x80]); - vmovq(xmm3, qword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - vpunpckldq(xmm1, xmm0, xmm1); - vpunpckldq(xmm3, xmm2, xmm3); - vpunpcklqdq(xmm0, xmm1, xmm3); - vpunpckhqdq(xmm1, xmm1, xmm3); - vmovdqu(xword[B+0x20], xmm0); - vmovdqu(xword[B+0xe0], xmm1); - vmovq(xmm2, qword[A2-0x80]); - vmovq(xmm3, qword[A2+LDA*1-0x80]); - vmovq(xmm4, qword[A2+LDA*2-0x80]); - vmovq(xmm5, qword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - vpunpckldq(xmm3, xmm2, xmm3); - vpunpckldq(xmm5, xmm4, xmm5); - vpunpcklqdq(xmm2, xmm3, xmm5); - vpunpckhqdq(xmm3, xmm3, xmm5); - vmovdqu(xword[B+0x30], xmm2); - vmovdqu(xword[B+0xf0], xmm3); - vpmovsxbw(ymm5, xmm0); - 
vmovhlps(xmm6, xmm0, xmm0); - vpmovsxbw(ymm6, xmm6); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxbw(ymm6, xmm2); - vmovhlps(xmm7, xmm2, xmm2); - vpmovsxbw(ymm7, xmm7); - vphaddw(ymm6, ymm6, ymm7); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxwd(ymm5, xmm5); - vpaddd(ymm13, ymm13, ymm5); - vpmovsxbw(ymm5, xmm1); - vmovhlps(xmm6, xmm1, xmm1); - vpmovsxbw(ymm6, xmm6); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxbw(ymm6, xmm3); - vmovhlps(xmm7, xmm3, xmm3); - vpmovsxbw(ymm7, xmm7); - vphaddw(ymm6, ymm6, ymm7); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxwd(ymm5, xmm5); - vpaddd(ymm13, ymm13, ymm5); - sub(A1, -8); - sub(B, -384); - dec(I); - jg(l6c, T_NEAR); - align(4); - -L(l5cc); - test(M, 0x4); - jle(l968, T_NEAR); - vmovd(xmm0, dword[A1-0x80]); - vmovd(xmm1, dword[A1+LDA*1-0x80]); - vmovd(xmm2, dword[A1+LDA*2-0x80]); - vmovd(xmm3, dword[A1+LDA3*1-0x80]); - lea(A2, ptr[A1+LDA*4]); - vpunpckldq(xmm0, xmm0, xmm1); - vpunpckldq(xmm2, xmm2, xmm3); - vpunpcklqdq(xmm0, xmm0, xmm2); - vmovdqu(xword[B-0x80], xmm0); - vmovd(xmm1, dword[A2-0x80]); - vmovd(xmm2, dword[A2+LDA*1-0x80]); - vmovd(xmm3, dword[A2+LDA*2-0x80]); - vmovd(xmm4, dword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - vpunpckldq(xmm1, xmm1, xmm2); - vpunpckldq(xmm3, xmm3, xmm4); - vpunpcklqdq(xmm1, xmm1, xmm3); - vmovdqu(xword[B-0x70], xmm1); - vpmovsxbw(ymm5, xmm0); - vmovhlps(xmm6, xmm0, xmm0); - vpmovsxbw(ymm6, xmm6); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxbw(ymm6, xmm1); - vmovhlps(xmm7, xmm1, xmm1); - vpmovsxbw(ymm7, xmm7); - vphaddw(ymm6, ymm6, ymm7); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxwd(ymm5, xmm5); - vpaddd(ymm8, ymm8, ymm5); - vmovd(xmm0, dword[A2-0x80]); - vmovd(xmm1, dword[A2+LDA*1-0x80]); - vmovd(xmm2, dword[A2+LDA*2-0x80]); - vmovd(xmm3, dword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - vpunpckldq(xmm0, xmm0, xmm1); - vpunpckldq(xmm2, xmm2, xmm3); - vpunpcklqdq(xmm0, xmm0, xmm2); - vmovdqu(xword[B-0x60], xmm0); - vmovd(xmm1, dword[A2-0x80]); - vmovd(xmm2, dword[A2+LDA*1-0x80]); - vmovd(xmm3, dword[A2+LDA*2-0x80]); - vmovd(xmm4, dword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - vpunpckldq(xmm1, xmm1, xmm2); - vpunpckldq(xmm3, xmm3, xmm4); - vpunpcklqdq(xmm1, xmm1, xmm3); - vmovdqu(xword[B-0x50], xmm1); - vpmovsxbw(ymm5, xmm0); - vmovhlps(xmm6, xmm0, xmm0); - vpmovsxbw(ymm6, xmm6); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxbw(ymm6, xmm1); - vmovhlps(xmm7, xmm1, xmm1); - vpmovsxbw(ymm7, xmm7); - vphaddw(ymm6, ymm6, ymm7); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxwd(ymm5, xmm5); - vpaddd(ymm9, ymm9, ymm5); - vmovd(xmm0, dword[A2-0x80]); - vmovd(xmm1, dword[A2+LDA*1-0x80]); - vmovd(xmm2, dword[A2+LDA*2-0x80]); - vmovd(xmm3, dword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - vpunpckldq(xmm0, xmm0, xmm1); - vpunpckldq(xmm2, xmm2, xmm3); - vpunpcklqdq(xmm0, xmm0, xmm2); - vmovdqu(xword[B-0x40], xmm0); - vmovd(xmm1, dword[A2-0x80]); - vmovd(xmm2, dword[A2+LDA*1-0x80]); - vmovd(xmm3, dword[A2+LDA*2-0x80]); - vmovd(xmm4, dword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - vpunpckldq(xmm1, xmm1, xmm2); - vpunpckldq(xmm3, xmm3, xmm4); - vpunpcklqdq(xmm1, xmm1, xmm3); - vmovdqu(xword[B-0x30], xmm1); - vpmovsxbw(ymm5, xmm0); - vmovhlps(xmm6, xmm0, xmm0); - vpmovsxbw(ymm6, xmm6); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxbw(ymm6, xmm1); - vmovhlps(xmm7, xmm1, xmm1); - vpmovsxbw(ymm7, xmm7); - vphaddw(ymm6, ymm6, ymm7); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxwd(ymm5, xmm5); - vpaddd(ymm10, ymm10, ymm5); - vmovd(xmm0, dword[A2-0x80]); - vmovd(xmm1, dword[A2+LDA*1-0x80]); - vmovd(xmm2, dword[A2+LDA*2-0x80]); - vmovd(xmm3, dword[A2+LDA3*1-0x80]); - lea(A2, 
ptr[A2+LDA*4]); - vpunpckldq(xmm0, xmm0, xmm1); - vpunpckldq(xmm2, xmm2, xmm3); - vpunpcklqdq(xmm0, xmm0, xmm2); - vmovdqu(xword[B-0x20], xmm0); - vmovd(xmm1, dword[A2-0x80]); - vmovd(xmm2, dword[A2+LDA*1-0x80]); - vmovd(xmm3, dword[A2+LDA*2-0x80]); - vmovd(xmm4, dword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - vpunpckldq(xmm1, xmm1, xmm2); - vpunpckldq(xmm3, xmm3, xmm4); - vpunpcklqdq(xmm1, xmm1, xmm3); - vmovdqu(xword[B-0x10], xmm1); - vpmovsxbw(ymm5, xmm0); - vmovhlps(xmm6, xmm0, xmm0); - vpmovsxbw(ymm6, xmm6); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxbw(ymm6, xmm1); - vmovhlps(xmm7, xmm1, xmm1); - vpmovsxbw(ymm7, xmm7); - vphaddw(ymm6, ymm6, ymm7); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxwd(ymm5, xmm5); - vpaddd(ymm11, ymm11, ymm5); - vmovd(xmm0, dword[A2-0x80]); - vmovd(xmm1, dword[A2+LDA*1-0x80]); - vmovd(xmm2, dword[A2+LDA*2-0x80]); - vmovd(xmm3, dword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - vpunpckldq(xmm0, xmm0, xmm1); - vpunpckldq(xmm2, xmm2, xmm3); - vpunpcklqdq(xmm0, xmm0, xmm2); - vmovdqu(xword[B], xmm0); - vmovd(xmm1, dword[A2-0x80]); - vmovd(xmm2, dword[A2+LDA*1-0x80]); - vmovd(xmm3, dword[A2+LDA*2-0x80]); - vmovd(xmm4, dword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - vpunpckldq(xmm1, xmm1, xmm2); - vpunpckldq(xmm3, xmm3, xmm4); - vpunpcklqdq(xmm1, xmm1, xmm3); - vmovdqu(xword[B+0x10], xmm1); - vpmovsxbw(ymm5, xmm0); - vmovhlps(xmm6, xmm0, xmm0); - vpmovsxbw(ymm6, xmm6); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxbw(ymm6, xmm1); - vmovhlps(xmm7, xmm1, xmm1); - vpmovsxbw(ymm7, xmm7); - vphaddw(ymm6, ymm6, ymm7); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxwd(ymm5, xmm5); - vpaddd(ymm12, ymm12, ymm5); - vmovd(xmm0, dword[A2-0x80]); - vmovd(xmm1, dword[A2+LDA*1-0x80]); - vmovd(xmm2, dword[A2+LDA*2-0x80]); - vmovd(xmm3, dword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - vpunpckldq(xmm0, xmm0, xmm1); - vpunpckldq(xmm2, xmm2, xmm3); - vpunpcklqdq(xmm0, xmm0, xmm2); - vmovdqu(xword[B+0x20], xmm0); - vmovd(xmm1, dword[A2-0x80]); - vmovd(xmm2, dword[A2+LDA*1-0x80]); - vmovd(xmm3, dword[A2+LDA*2-0x80]); - vmovd(xmm4, dword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - vpunpckldq(xmm1, xmm1, xmm2); - vpunpckldq(xmm3, xmm3, xmm4); - vpunpcklqdq(xmm1, xmm1, xmm3); - vmovdqu(xword[B+0x30], xmm1); - vpmovsxbw(ymm5, xmm0); - vmovhlps(xmm6, xmm0, xmm0); - vpmovsxbw(ymm6, xmm6); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxbw(ymm6, xmm1); - vmovhlps(xmm7, xmm1, xmm1); - vpmovsxbw(ymm7, xmm7); - vphaddw(ymm6, ymm6, ymm7); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxwd(ymm5, xmm5); - vpaddd(ymm13, ymm13, ymm5); - sub(A1, -4); - sub(B, -192); - align(4); - -L(l968); - test(M, 0x2); - jle(lc80, T_NEAR); - mov(ax, word[A1-0x80]); - vpinsrw(xmm0, xmm0, eax, 0x0); - mov(ax, word[A1+LDA*1-0x80]); - vpinsrw(xmm0, xmm0, eax, 0x1); - mov(ax, word[A1+LDA*2-0x80]); - vpinsrw(xmm0, xmm0, eax, 0x2); - mov(ax, word[A1+LDA3*1-0x80]); - lea(A2, ptr[A1+LDA*4]); - vpinsrw(xmm0, xmm0, eax, 0x3); - mov(ax, word[A2-0x80]); - vpinsrw(xmm0, xmm0, eax, 0x4); - mov(ax, word[A2+LDA*1-0x80]); - vpinsrw(xmm0, xmm0, eax, 0x5); - mov(ax, word[A2+LDA*2-0x80]); - vpinsrw(xmm0, xmm0, eax, 0x6); - mov(ax, word[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - vpinsrw(xmm0, xmm0, eax, 0x7); - vpmovsxbw(ymm5, xmm0); - vmovhlps(xmm6, xmm0, xmm0); - vpmovsxbw(ymm6, xmm6); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxwd(ymm5, xmm5); - vpaddd(ymm8, ymm8, ymm5); - vmovdqu(xword[B-0x80], xmm0); - mov(ax, word[A2-0x80]); - vpinsrw(xmm0, xmm0, eax, 0x0); - mov(ax, word[A2+LDA*1-0x80]); - vpinsrw(xmm0, xmm0, eax, 0x1); - mov(ax, word[A2+LDA*2-0x80]); - 
vpinsrw(xmm0, xmm0, eax, 0x2); - mov(ax, word[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - vpinsrw(xmm0, xmm0, eax, 0x3); - mov(ax, word[A2-0x80]); - vpinsrw(xmm0, xmm0, eax, 0x4); - mov(ax, word[A2+LDA*1-0x80]); - vpinsrw(xmm0, xmm0, eax, 0x5); - mov(ax, word[A2+LDA*2-0x80]); - vpinsrw(xmm0, xmm0, eax, 0x6); - mov(ax, word[A2+LDA3*1-0x80]); - vpinsrw(xmm0, xmm0, eax, 0x7); - lea(A2, ptr[A2+LDA*4]); - vpmovsxbw(ymm5, xmm0); - vmovhlps(xmm6, xmm0, xmm0); - vpmovsxbw(ymm6, xmm6); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxwd(ymm5, xmm5); - vpaddd(ymm9, ymm9, ymm5); - vmovdqu(xword[B-0x70], xmm0); - mov(ax, word[A2-0x80]); - vpinsrw(xmm0, xmm0, eax, 0x0); - mov(ax, word[A2+LDA*1-0x80]); - vpinsrw(xmm0, xmm0, eax, 0x1); - mov(ax, word[A2+LDA*2-0x80]); - vpinsrw(xmm0, xmm0, eax, 0x2); - mov(ax, word[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - vpinsrw(xmm0, xmm0, eax, 0x3); - mov(ax, word[A2-0x80]); - vpinsrw(xmm0, xmm0, eax, 0x4); - mov(ax, word[A2+LDA*1-0x80]); - vpinsrw(xmm0, xmm0, eax, 0x5); - mov(ax, word[A2+LDA*2-0x80]); - vpinsrw(xmm0, xmm0, eax, 0x6); - mov(ax, word[A2+LDA3*1-0x80]); - vpinsrw(xmm0, xmm0, eax, 0x7); - lea(A2, ptr[A2+LDA*4]); - vpmovsxbw(ymm5, xmm0); - vmovhlps(xmm6, xmm0, xmm0); - vpmovsxbw(ymm6, xmm6); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxwd(ymm5, xmm5); - vpaddd(ymm10, ymm10, ymm5); - vmovdqu(xword[B-0x60], xmm0); - mov(ax, word[A2-0x80]); - vpinsrw(xmm0, xmm0, eax, 0x0); - mov(ax, word[A2+LDA*1-0x80]); - vpinsrw(xmm0, xmm0, eax, 0x1); - mov(ax, word[A2+LDA*2-0x80]); - vpinsrw(xmm0, xmm0, eax, 0x2); - mov(ax, word[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - vpinsrw(xmm0, xmm0, eax, 0x3); - mov(ax, word[A2-0x80]); - vpinsrw(xmm0, xmm0, eax, 0x4); - mov(ax, word[A2+LDA*1-0x80]); - vpinsrw(xmm0, xmm0, eax, 0x5); - mov(ax, word[A2+LDA*2-0x80]); - vpinsrw(xmm0, xmm0, eax, 0x6); - mov(ax, word[A2+LDA3*1-0x80]); - vpinsrw(xmm0, xmm0, eax, 0x7); - lea(A2, ptr[A2+LDA*4]); - vpmovsxbw(ymm5, xmm0); - vmovhlps(xmm6, xmm0, xmm0); - vpmovsxbw(ymm6, xmm6); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxwd(ymm5, xmm5); - vpaddd(ymm11, ymm11, ymm5); - vmovdqu(xword[B-0x50], xmm0); - mov(ax, word[A2-0x80]); - vpinsrw(xmm0, xmm0, eax, 0x0); - mov(ax, word[A2+LDA*1-0x80]); - vpinsrw(xmm0, xmm0, eax, 0x1); - mov(ax, word[A2+LDA*2-0x80]); - vpinsrw(xmm0, xmm0, eax, 0x2); - mov(ax, word[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - vpinsrw(xmm0, xmm0, eax, 0x3); - mov(ax, word[A2-0x80]); - vpinsrw(xmm0, xmm0, eax, 0x4); - mov(ax, word[A2+LDA*1-0x80]); - vpinsrw(xmm0, xmm0, eax, 0x5); - mov(ax, word[A2+LDA*2-0x80]); - vpinsrw(xmm0, xmm0, eax, 0x6); - mov(ax, word[A2+LDA3*1-0x80]); - vpinsrw(xmm0, xmm0, eax, 0x7); - lea(A2, ptr[A2+LDA*4]); - vpmovsxbw(ymm5, xmm0); - vmovhlps(xmm6, xmm0, xmm0); - vpmovsxbw(ymm6, xmm6); - vphaddw(ymm5, ymm5, ymm6); - vpmovsxwd(ymm5, xmm5); - vpaddd(ymm12, ymm12, ymm5); - vmovdqu(xword[B-0x40], xmm0); - mov(ax, word[A2-0x80]); - vpinsrw(xmm0, xmm0, eax, 0x0); - mov(ax, word[A2+LDA*1-0x80]); - vpinsrw(xmm0, xmm0, eax, 0x1); - mov(ax, word[A2+LDA*2-0x80]); - vpinsrw(xmm0, xmm0, eax, 0x2); - mov(ax, word[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - vpinsrw(xmm0, xmm0, eax, 0x3); - mov(ax, word[A2-0x80]); - vpinsrw(xmm0, xmm0, eax, 0x4); - mov(ax, word[A2+LDA*1-0x80]); - vpinsrw(xmm0, xmm0, eax, 0x5); - mov(ax, word[A2+LDA*2-0x80]); - vpinsrw(xmm0, xmm0, eax, 0x6); - mov(ax, word[A2+LDA3*1-0x80]); - vpinsrw(xmm0, xmm0, eax, 0x7); - lea(A2, ptr[A2+LDA*4]); - vpmovsxbw(ymm5, xmm0); - vmovhlps(xmm6, xmm0, xmm0); - vpmovsxbw(ymm6, xmm6); - vphaddw(ymm5, ymm5, ymm6); - 
vpmovsxwd(ymm5, xmm5); - vpaddd(ymm13, ymm13, ymm5); - vmovdqu(xword[B-0x30], xmm0); - sub(A1, -2); - sub(B, -96); - align(4); - -L(lc80); - test(M, 0x1); - jle(lf1c, T_NEAR); - mov(al, byte[A1-0x80]); - vpinsrb(xmm0, xmm0, eax, 0x0); - mov(al, byte[A1+LDA*1-0x80]); - vpinsrb(xmm0, xmm0, eax, 0x1); - mov(al, byte[A1+LDA*2-0x80]); - vpinsrb(xmm0, xmm0, eax, 0x2); - mov(al, byte[A1+LDA3*1-0x80]); - lea(A2, ptr[A1+LDA*4]); - vpinsrb(xmm0, xmm0, eax, 0x3); - mov(al, byte[A2-0x80]); - vpinsrb(xmm0, xmm0, eax, 0x4); - mov(al, byte[A2+LDA*1-0x80]); - vpinsrb(xmm0, xmm0, eax, 0x5); - mov(al, byte[A2+LDA*2-0x80]); - vpinsrb(xmm0, xmm0, eax, 0x6); - mov(al, byte[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - vpinsrb(xmm0, xmm0, eax, 0x7); - mov(al, byte[A2-0x80]); - vpinsrb(xmm0, xmm0, eax, 0x8); - mov(al, byte[A2+LDA*1-0x80]); - vpinsrb(xmm0, xmm0, eax, 0x9); - mov(al, byte[A2+LDA*2-0x80]); - vpinsrb(xmm0, xmm0, eax, 0xa); - mov(al, byte[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - vpinsrb(xmm0, xmm0, eax, 0xb); - mov(al, byte[A2-0x80]); - vpinsrb(xmm0, xmm0, eax, 0xc); - mov(al, byte[A2+LDA*1-0x80]); - vpinsrb(xmm0, xmm0, eax, 0xd); - mov(al, byte[A2+LDA*2-0x80]); - vpinsrb(xmm0, xmm0, eax, 0xe); - mov(al, byte[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - vpinsrb(xmm0, xmm0, eax, 0xf); - vpmovsxbd(ymm7, xmm0); - vpaddd(ymm8, ymm8, ymm7); - vmovhlps(xmm7, xmm0, xmm0); - vpmovsxbd(ymm7, xmm7); - vpaddd(ymm9, ymm9, ymm7); - vmovdqu(xword[B-0x80], xmm0); - mov(al, byte[A2-0x80]); - vpinsrb(xmm0, xmm0, eax, 0x0); - mov(al, byte[A2+LDA*1-0x80]); - vpinsrb(xmm0, xmm0, eax, 0x1); - mov(al, byte[A2+LDA*2-0x80]); - vpinsrb(xmm0, xmm0, eax, 0x2); - mov(al, byte[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - vpinsrb(xmm0, xmm0, eax, 0x3); - mov(al, byte[A2-0x80]); - vpinsrb(xmm0, xmm0, eax, 0x4); - mov(al, byte[A2+LDA*1-0x80]); - vpinsrb(xmm0, xmm0, eax, 0x5); - mov(al, byte[A2+LDA*2-0x80]); - vpinsrb(xmm0, xmm0, eax, 0x6); - mov(al, byte[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - vpinsrb(xmm0, xmm0, eax, 0x7); - mov(al, byte[A2-0x80]); - vpinsrb(xmm0, xmm0, eax, 0x8); - mov(al, byte[A2+LDA*1-0x80]); - vpinsrb(xmm0, xmm0, eax, 0x9); - mov(al, byte[A2+LDA*2-0x80]); - vpinsrb(xmm0, xmm0, eax, 0xa); - mov(al, byte[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - vpinsrb(xmm0, xmm0, eax, 0xb); - mov(al, byte[A2-0x80]); - vpinsrb(xmm0, xmm0, eax, 0xc); - mov(al, byte[A2+LDA*1-0x80]); - vpinsrb(xmm0, xmm0, eax, 0xd); - mov(al, byte[A2+LDA*2-0x80]); - vpinsrb(xmm0, xmm0, eax, 0xe); - mov(al, byte[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - vpinsrb(xmm0, xmm0, eax, 0xf); - vpmovsxbd(ymm7, xmm0); - vpaddd(ymm10, ymm10, ymm7); - vmovhlps(xmm7, xmm0, xmm0); - vpmovsxbd(ymm7, xmm7); - vpaddd(ymm11, ymm11, ymm7); - vmovdqu(xword[B-0x70], xmm0); - mov(al, byte[A2-0x80]); - vpinsrb(xmm0, xmm0, eax, 0x0); - mov(al, byte[A2+LDA*1-0x80]); - vpinsrb(xmm0, xmm0, eax, 0x1); - mov(al, byte[A2+LDA*2-0x80]); - vpinsrb(xmm0, xmm0, eax, 0x2); - mov(al, byte[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - vpinsrb(xmm0, xmm0, eax, 0x3); - mov(al, byte[A2-0x80]); - vpinsrb(xmm0, xmm0, eax, 0x4); - mov(al, byte[A2+LDA*1-0x80]); - vpinsrb(xmm0, xmm0, eax, 0x5); - mov(al, byte[A2+LDA*2-0x80]); - vpinsrb(xmm0, xmm0, eax, 0x6); - mov(al, byte[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - vpinsrb(xmm0, xmm0, eax, 0x7); - mov(al, byte[A2-0x80]); - vpinsrb(xmm0, xmm0, eax, 0x8); - mov(al, byte[A2+LDA*1-0x80]); - vpinsrb(xmm0, xmm0, eax, 0x9); - mov(al, byte[A2+LDA*2-0x80]); - vpinsrb(xmm0, xmm0, eax, 0xa); - mov(al, byte[A2+LDA3*1-0x80]); - 
lea(A2, ptr[A2+LDA*4]); - vpinsrb(xmm0, xmm0, eax, 0xb); - mov(al, byte[A2-0x80]); - vpinsrb(xmm0, xmm0, eax, 0xc); - mov(al, byte[A2+LDA*1-0x80]); - vpinsrb(xmm0, xmm0, eax, 0xd); - mov(al, byte[A2+LDA*2-0x80]); - vpinsrb(xmm0, xmm0, eax, 0xe); - mov(al, byte[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - vpinsrb(xmm0, xmm0, eax, 0xf); - vpmovsxbd(ymm7, xmm0); - vpaddd(ymm12, ymm12, ymm7); - vmovhlps(xmm7, xmm0, xmm0); - vpmovsxbd(ymm7, xmm7); - vpaddd(ymm13, ymm13, ymm7); - vmovdqu(xword[B-0x60], xmm0); - sub(B, -48); - align(4); - -L(lf1c); - mov(A1, qword[ARG_BIAS]); - vmovdqu(yword[A1], ymm8); - vmovdqu(yword[A1+0x20], ymm9); - vmovdqu(yword[A1+0x40], ymm10); - vmovdqu(yword[A1+0x60], ymm11); - vmovdqu(yword[A1+0x80], ymm12); - vmovdqu(yword[A1+0xa0], ymm13); - add(qword[ARG_BIAS], 0xc0); - sub(N, 0x30); - cmp(N, 0x30); - jge(l20, T_NEAR); - vzeroupper(); - align(4); - -L(lf64); - cmp(N, 0x20); - jl(l22b8, T_NEAR); - align(4); - -L(lf70); - mov(A1, A); - mov(I, LDA); - shl(I, 0x5); - add(A, I); - pxor(xmm8, xmm8); - pxor(xmm9, xmm9); - pxor(xmm10, xmm10); - pxor(xmm11, xmm11); - pxor(xmm12, xmm12); - pxor(xmm13, xmm13); - pxor(xmm14, xmm14); - pxor(xmm15, xmm15); - mov(I, M); - sar(I, 0x4); - jle(l1750, T_NEAR); - align(4); - -L(lfb4); - movdqu(xmm0, xword[A1-0x80]); - movdqu(xmm1, xword[A1+LDA*1-0x80]); - movdqu(xmm2, xword[A1+LDA*2-0x80]); - movdqu(xmm3, xword[A1+LDA3*1-0x80]); - lea(A2, ptr[A1+LDA*4]); - movdqa(xmm4, xmm0); - punpckldq(xmm0, xmm1); - punpckhdq(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpckldq(xmm2, xmm3); - punpckhdq(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqa(xmm3, xmm4); - punpcklqdq(xmm4, xmm5); - punpckhqdq(xmm3, xmm5); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm8, xmm5); - movdqu(xword[B-0x80], xmm0); - pmovsxbw(xmm5, xmm1); - movhlps(xmm6, xmm1); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm8, xmm5); - movdqu(xword[B], xmm1); - pmovsxbw(xmm5, xmm4); - movhlps(xmm6, xmm4); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm8, xmm5); - movdqu(xword[B+0x80], xmm4); - pmovsxbw(xmm5, xmm3); - movhlps(xmm6, xmm3); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm8, xmm5); - movdqu(xword[B+0x100], xmm3); - movdqu(xmm0, xword[A2-0x80]); - movdqu(xmm1, xword[A2+LDA*1-0x80]); - movdqu(xmm2, xword[A2+LDA*2-0x80]); - movdqu(xmm3, xword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - movdqa(xmm4, xmm0); - punpckldq(xmm0, xmm1); - punpckhdq(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpckldq(xmm2, xmm3); - punpckhdq(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqa(xmm3, xmm4); - punpcklqdq(xmm4, xmm5); - punpckhqdq(xmm3, xmm5); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm9, xmm5); - movdqu(xword[B-0x70], xmm0); - pmovsxbw(xmm5, xmm1); - movhlps(xmm6, xmm1); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm9, xmm5); - movdqu(xword[B+0x10], xmm1); - pmovsxbw(xmm5, xmm4); - movhlps(xmm6, xmm4); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm9, xmm5); - movdqu(xword[B+0x90], xmm4); - pmovsxbw(xmm5, xmm3); - 
movhlps(xmm6, xmm3); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm9, xmm5); - movdqu(xword[B+0x110], xmm3); - movdqu(xmm0, xword[A2-0x80]); - movdqu(xmm1, xword[A2+LDA*1-0x80]); - movdqu(xmm2, xword[A2+LDA*2-0x80]); - movdqu(xmm3, xword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - movdqa(xmm4, xmm0); - punpckldq(xmm0, xmm1); - punpckhdq(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpckldq(xmm2, xmm3); - punpckhdq(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqa(xmm3, xmm4); - punpcklqdq(xmm4, xmm5); - punpckhqdq(xmm3, xmm5); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm10, xmm5); - movdqu(xword[B-0x60], xmm0); - pmovsxbw(xmm5, xmm1); - movhlps(xmm6, xmm1); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm10, xmm5); - movdqu(xword[B+0x20], xmm1); - pmovsxbw(xmm5, xmm4); - movhlps(xmm6, xmm4); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm10, xmm5); - movdqu(xword[B+0xa0], xmm4); - pmovsxbw(xmm5, xmm3); - movhlps(xmm6, xmm3); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm10, xmm5); - movdqu(xword[B+0x120], xmm3); - movdqu(xmm0, xword[A2-0x80]); - movdqu(xmm1, xword[A2+LDA*1-0x80]); - movdqu(xmm2, xword[A2+LDA*2-0x80]); - movdqu(xmm3, xword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - movdqa(xmm4, xmm0); - punpckldq(xmm0, xmm1); - punpckhdq(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpckldq(xmm2, xmm3); - punpckhdq(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqa(xmm3, xmm4); - punpcklqdq(xmm4, xmm5); - punpckhqdq(xmm3, xmm5); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm11, xmm5); - movdqu(xword[B-0x50], xmm0); - pmovsxbw(xmm5, xmm1); - movhlps(xmm6, xmm1); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm11, xmm5); - movdqu(xword[B+0x30], xmm1); - pmovsxbw(xmm5, xmm4); - movhlps(xmm6, xmm4); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm11, xmm5); - movdqu(xword[B+0xb0], xmm4); - pmovsxbw(xmm5, xmm3); - movhlps(xmm6, xmm3); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm11, xmm5); - movdqu(xword[B+0x130], xmm3); - movdqu(xmm0, xword[A2-0x80]); - movdqu(xmm1, xword[A2+LDA*1-0x80]); - movdqu(xmm2, xword[A2+LDA*2-0x80]); - movdqu(xmm3, xword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - movdqa(xmm4, xmm0); - punpckldq(xmm0, xmm1); - punpckhdq(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpckldq(xmm2, xmm3); - punpckhdq(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqa(xmm3, xmm4); - punpcklqdq(xmm4, xmm5); - punpckhqdq(xmm3, xmm5); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm12, xmm5); - movdqu(xword[B-0x40], xmm0); - pmovsxbw(xmm5, xmm1); - movhlps(xmm6, xmm1); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm12, xmm5); - movdqu(xword[B+0x40], xmm1); - pmovsxbw(xmm5, xmm4); - movhlps(xmm6, xmm4); - pmovsxbw(xmm6, xmm6); 
- phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm12, xmm5); - movdqu(xword[B+0xc0], xmm4); - pmovsxbw(xmm5, xmm3); - movhlps(xmm6, xmm3); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm12, xmm5); - movdqu(xword[B+0x140], xmm3); - movdqu(xmm0, xword[A2-0x80]); - movdqu(xmm1, xword[A2+LDA*1-0x80]); - movdqu(xmm2, xword[A2+LDA*2-0x80]); - movdqu(xmm3, xword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - movdqa(xmm4, xmm0); - punpckldq(xmm0, xmm1); - punpckhdq(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpckldq(xmm2, xmm3); - punpckhdq(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqa(xmm3, xmm4); - punpcklqdq(xmm4, xmm5); - punpckhqdq(xmm3, xmm5); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm13, xmm5); - movdqu(xword[B-0x30], xmm0); - pmovsxbw(xmm5, xmm1); - movhlps(xmm6, xmm1); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm13, xmm5); - movdqu(xword[B+0x50], xmm1); - pmovsxbw(xmm5, xmm4); - movhlps(xmm6, xmm4); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm13, xmm5); - movdqu(xword[B+0xd0], xmm4); - pmovsxbw(xmm5, xmm3); - movhlps(xmm6, xmm3); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm13, xmm5); - movdqu(xword[B+0x150], xmm3); - movdqu(xmm0, xword[A2-0x80]); - movdqu(xmm1, xword[A2+LDA*1-0x80]); - movdqu(xmm2, xword[A2+LDA*2-0x80]); - movdqu(xmm3, xword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - movdqa(xmm4, xmm0); - punpckldq(xmm0, xmm1); - punpckhdq(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpckldq(xmm2, xmm3); - punpckhdq(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqa(xmm3, xmm4); - punpcklqdq(xmm4, xmm5); - punpckhqdq(xmm3, xmm5); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm14, xmm5); - movdqu(xword[B-0x20], xmm0); - pmovsxbw(xmm5, xmm1); - movhlps(xmm6, xmm1); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm14, xmm5); - movdqu(xword[B+0x60], xmm1); - pmovsxbw(xmm5, xmm4); - movhlps(xmm6, xmm4); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm14, xmm5); - movdqu(xword[B+0xe0], xmm4); - pmovsxbw(xmm5, xmm3); - movhlps(xmm6, xmm3); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm14, xmm5); - movdqu(xword[B+0x160], xmm3); - movdqu(xmm0, xword[A2-0x80]); - movdqu(xmm1, xword[A2+LDA*1-0x80]); - movdqu(xmm2, xword[A2+LDA*2-0x80]); - movdqu(xmm3, xword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - movdqa(xmm4, xmm0); - punpckldq(xmm0, xmm1); - punpckhdq(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpckldq(xmm2, xmm3); - punpckhdq(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqa(xmm3, xmm4); - punpcklqdq(xmm4, xmm5); - punpckhqdq(xmm3, xmm5); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm15, xmm5); - movdqu(xword[B-0x10], xmm0); - pmovsxbw(xmm5, xmm1); - movhlps(xmm6, xmm1); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); 
- pmovsxwd(xmm5, xmm5); - paddd(xmm15, xmm5); - movdqu(xword[B+0x70], xmm1); - pmovsxbw(xmm5, xmm4); - movhlps(xmm6, xmm4); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm15, xmm5); - movdqu(xword[B+0xf0], xmm4); - pmovsxbw(xmm5, xmm3); - movhlps(xmm6, xmm3); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm15, xmm5); - movdqu(xword[B+0x170], xmm3); - sub(A1, -16); - sub(B, -512); - dec(I); - jg(lfb4, T_NEAR); - align(4); - -L(l1750); - test(M, 0x8); - jle(l1b6c, T_NEAR); - movq(xmm0, qword[A1-0x80]); - movq(xmm1, qword[A1+LDA*1-0x80]); - movq(xmm2, qword[A1+LDA*2-0x80]); - movq(xmm3, qword[A1+LDA3*1-0x80]); - lea(A2, ptr[A1+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm8, xmm5); - movdqu(xword[B-0x80], xmm0); - pmovsxbw(xmm5, xmm1); - movhlps(xmm6, xmm1); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm8, xmm5); - movdqu(xword[B], xmm1); - movq(xmm0, qword[A2-0x80]); - movq(xmm1, qword[A2+LDA*1-0x80]); - movq(xmm2, qword[A2+LDA*2-0x80]); - movq(xmm3, qword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm9, xmm5); - movdqu(xword[B-0x70], xmm0); - pmovsxbw(xmm5, xmm1); - movhlps(xmm6, xmm1); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm9, xmm5); - movdqu(xword[B+0x10], xmm1); - movq(xmm0, qword[A2-0x80]); - movq(xmm1, qword[A2+LDA*1-0x80]); - movq(xmm2, qword[A2+LDA*2-0x80]); - movq(xmm3, qword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm10, xmm5); - movdqu(xword[B-0x60], xmm0); - pmovsxbw(xmm5, xmm1); - movhlps(xmm6, xmm1); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm10, xmm5); - movdqu(xword[B+0x20], xmm1); - movq(xmm0, qword[A2-0x80]); - movq(xmm1, qword[A2+LDA*1-0x80]); - movq(xmm2, qword[A2+LDA*2-0x80]); - movq(xmm3, qword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm11, xmm5); - movdqu(xword[B-0x50], xmm0); - pmovsxbw(xmm5, xmm1); - movhlps(xmm6, xmm1); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm11, xmm5); - movdqu(xword[B+0x30], xmm1); - movq(xmm0, qword[A2-0x80]); - movq(xmm1, qword[A2+LDA*1-0x80]); - movq(xmm2, qword[A2+LDA*2-0x80]); - movq(xmm3, qword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, 
xmm2); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm12, xmm5); - movdqu(xword[B-0x40], xmm0); - pmovsxbw(xmm5, xmm1); - movhlps(xmm6, xmm1); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm12, xmm5); - movdqu(xword[B+0x40], xmm1); - movq(xmm0, qword[A2-0x80]); - movq(xmm1, qword[A2+LDA*1-0x80]); - movq(xmm2, qword[A2+LDA*2-0x80]); - movq(xmm3, qword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm13, xmm5); - movdqu(xword[B-0x30], xmm0); - pmovsxbw(xmm5, xmm1); - movhlps(xmm6, xmm1); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm13, xmm5); - movdqu(xword[B+0x50], xmm1); - movq(xmm0, qword[A2-0x80]); - movq(xmm1, qword[A2+LDA*1-0x80]); - movq(xmm2, qword[A2+LDA*2-0x80]); - movq(xmm3, qword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm14, xmm5); - movdqu(xword[B-0x20], xmm0); - pmovsxbw(xmm5, xmm1); - movhlps(xmm6, xmm1); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm14, xmm5); - movdqu(xword[B+0x60], xmm1); - movq(xmm0, qword[A2-0x80]); - movq(xmm1, qword[A2+LDA*1-0x80]); - movq(xmm2, qword[A2+LDA*2-0x80]); - movq(xmm3, qword[A2+LDA3*1-0x80]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm15, xmm5); - movdqu(xword[B-0x10], xmm0); - pmovsxbw(xmm5, xmm1); - movhlps(xmm6, xmm1); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm15, xmm5); - movdqu(xword[B+0x70], xmm1); - sub(A1, -8); - sub(B, -256); - align(4); - -L(l1b6c); - test(M, 0x4); - jle(l1e14, T_NEAR); - movd(xmm0, dword[A1-0x80]); - movd(xmm1, dword[A1+LDA*1-0x80]); - movd(xmm2, dword[A1+LDA*2-0x80]); - movd(xmm3, dword[A1+LDA3*1-0x80]); - lea(A2, ptr[A1+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - punpcklqdq(xmm0, xmm2); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm8, xmm5); - movdqu(xword[B-0x80], xmm0); - movd(xmm0, dword[A2-0x80]); - movd(xmm1, dword[A2+LDA*1-0x80]); - movd(xmm2, dword[A2+LDA*2-0x80]); - movd(xmm3, dword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - punpcklqdq(xmm0, xmm2); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm9, xmm5); - movdqu(xword[B-0x70], xmm0); - movd(xmm0, dword[A2-0x80]); - movd(xmm1, dword[A2+LDA*1-0x80]); - movd(xmm2, dword[A2+LDA*2-0x80]); - movd(xmm3, dword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, 
xmm3); - punpcklqdq(xmm0, xmm2); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm10, xmm5); - movdqu(xword[B-0x60], xmm0); - movd(xmm0, dword[A2-0x80]); - movd(xmm1, dword[A2+LDA*1-0x80]); - movd(xmm2, dword[A2+LDA*2-0x80]); - movd(xmm3, dword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - punpcklqdq(xmm0, xmm2); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm11, xmm5); - movdqu(xword[B-0x50], xmm0); - movd(xmm0, dword[A2-0x80]); - movd(xmm1, dword[A2+LDA*1-0x80]); - movd(xmm2, dword[A2+LDA*2-0x80]); - movd(xmm3, dword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - punpcklqdq(xmm0, xmm2); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm12, xmm5); - movdqu(xword[B-0x40], xmm0); - movd(xmm0, dword[A2-0x80]); - movd(xmm1, dword[A2+LDA*1-0x80]); - movd(xmm2, dword[A2+LDA*2-0x80]); - movd(xmm3, dword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - punpcklqdq(xmm0, xmm2); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm13, xmm5); - movdqu(xword[B-0x30], xmm0); - movd(xmm0, dword[A2-0x80]); - movd(xmm1, dword[A2+LDA*1-0x80]); - movd(xmm2, dword[A2+LDA*2-0x80]); - movd(xmm3, dword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - punpcklqdq(xmm0, xmm2); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm14, xmm5); - movdqu(xword[B-0x20], xmm0); - movd(xmm0, dword[A2-0x80]); - movd(xmm1, dword[A2+LDA*1-0x80]); - movd(xmm2, dword[A2+LDA*2-0x80]); - movd(xmm3, dword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - punpcklqdq(xmm0, xmm2); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm15, xmm5); - movdqu(xword[B-0x10], xmm0); - sub(A1, -4); - sub(B, -128); - align(4); - -L(l1e14); - test(M, 0x2); - jle(l2068, T_NEAR); - mov(ax, word[A1-0x80]); - pinsrw(xmm0, eax, 0x0); - mov(ax, word[A1+LDA*1-0x80]); - pinsrw(xmm0, eax, 0x1); - mov(ax, word[A1+LDA*2-0x80]); - pinsrw(xmm0, eax, 0x2); - mov(ax, word[A1+LDA3*1-0x80]); - lea(A2, ptr[A1+LDA*4]); - pinsrw(xmm0, eax, 0x3); - mov(ax, word[A2-0x80]); - pinsrw(xmm0, eax, 0x4); - mov(ax, word[A2+LDA*1-0x80]); - pinsrw(xmm0, eax, 0x5); - mov(ax, word[A2+LDA*2-0x80]); - pinsrw(xmm0, eax, 0x6); - mov(ax, word[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - pinsrw(xmm0, eax, 0x7); - pmovsxbw(xmm5, xmm0); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm8, xmm5); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm6, xmm6); - pmovsxwd(xmm6, xmm6); - paddd(xmm9, xmm6); - movdqu(xword[B-0x80], xmm0); - mov(ax, word[A2-0x80]); - pinsrw(xmm0, eax, 0x0); - mov(ax, word[A2+LDA*1-0x80]); - pinsrw(xmm0, eax, 0x1); - mov(ax, word[A2+LDA*2-0x80]); - pinsrw(xmm0, eax, 0x2); - mov(ax, word[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - pinsrw(xmm0, eax, 0x3); - mov(ax, word[A2-0x80]); - pinsrw(xmm0, eax, 0x4); - mov(ax, 
word[A2+LDA*1-0x80]); - pinsrw(xmm0, eax, 0x5); - mov(ax, word[A2+LDA*2-0x80]); - pinsrw(xmm0, eax, 0x6); - mov(ax, word[A2+LDA3*1-0x80]); - pinsrw(xmm0, eax, 0x7); - lea(A2, ptr[A2+LDA*4]); - pmovsxbw(xmm5, xmm0); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm10, xmm5); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm6, xmm6); - pmovsxwd(xmm6, xmm6); - paddd(xmm11, xmm6); - movdqu(xword[B-0x70], xmm0); - mov(ax, word[A2-0x80]); - pinsrw(xmm0, eax, 0x0); - mov(ax, word[A2+LDA*1-0x80]); - pinsrw(xmm0, eax, 0x1); - mov(ax, word[A2+LDA*2-0x80]); - pinsrw(xmm0, eax, 0x2); - mov(ax, word[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - pinsrw(xmm0, eax, 0x3); - mov(ax, word[A2-0x80]); - pinsrw(xmm0, eax, 0x4); - mov(ax, word[A2+LDA*1-0x80]); - pinsrw(xmm0, eax, 0x5); - mov(ax, word[A2+LDA*2-0x80]); - pinsrw(xmm0, eax, 0x6); - mov(ax, word[A2+LDA3*1-0x80]); - pinsrw(xmm0, eax, 0x7); - lea(A2, ptr[A2+LDA*4]); - pmovsxbw(xmm5, xmm0); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm12, xmm5); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm6, xmm6); - pmovsxwd(xmm6, xmm6); - paddd(xmm13, xmm6); - movdqu(xword[B-0x60], xmm0); - mov(ax, word[A2-0x80]); - pinsrw(xmm0, eax, 0x0); - mov(ax, word[A2+LDA*1-0x80]); - pinsrw(xmm0, eax, 0x1); - mov(ax, word[A2+LDA*2-0x80]); - pinsrw(xmm0, eax, 0x2); - mov(ax, word[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - pinsrw(xmm0, eax, 0x3); - mov(ax, word[A2-0x80]); - pinsrw(xmm0, eax, 0x4); - mov(ax, word[A2+LDA*1-0x80]); - pinsrw(xmm0, eax, 0x5); - mov(ax, word[A2+LDA*2-0x80]); - pinsrw(xmm0, eax, 0x6); - mov(ax, word[A2+LDA3*1-0x80]); - pinsrw(xmm0, eax, 0x7); - lea(A2, ptr[A2+LDA*4]); - pmovsxbw(xmm5, xmm0); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm14, xmm5); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm6, xmm6); - pmovsxwd(xmm6, xmm6); - paddd(xmm15, xmm6); - movdqu(xword[B-0x50], xmm0); - sub(A1, -2); - sub(B, -64); - align(4); - -L(l2068); - test(M, 0x1); - jle(l226c, T_NEAR); - mov(al, byte[A1-0x80]); - pinsrb(xmm0, eax, 0x0); - mov(al, byte[A1+LDA*1-0x80]); - pinsrb(xmm0, eax, 0x1); - mov(al, byte[A1+LDA*2-0x80]); - pinsrb(xmm0, eax, 0x2); - mov(al, byte[A1+LDA3*1-0x80]); - lea(A2, ptr[A1+LDA*4]); - pinsrb(xmm0, eax, 0x3); - mov(al, byte[A2-0x80]); - pinsrb(xmm0, eax, 0x4); - mov(al, byte[A2+LDA*1-0x80]); - pinsrb(xmm0, eax, 0x5); - mov(al, byte[A2+LDA*2-0x80]); - pinsrb(xmm0, eax, 0x6); - mov(al, byte[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - pinsrb(xmm0, eax, 0x7); - mov(al, byte[A2-0x80]); - pinsrb(xmm0, eax, 0x8); - mov(al, byte[A2+LDA*1-0x80]); - pinsrb(xmm0, eax, 0x9); - mov(al, byte[A2+LDA*2-0x80]); - pinsrb(xmm0, eax, 0xa); - mov(al, byte[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - pinsrb(xmm0, eax, 0xb); - mov(al, byte[A2-0x80]); - pinsrb(xmm0, eax, 0xc); - mov(al, byte[A2+LDA*1-0x80]); - pinsrb(xmm0, eax, 0xd); - mov(al, byte[A2+LDA*2-0x80]); - pinsrb(xmm0, eax, 0xe); - mov(al, byte[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - pinsrb(xmm0, eax, 0xf); - pmovsxbd(xmm5, xmm0); - paddd(xmm8, xmm5); - pshufd(xmm6, xmm0, 0x55); - pmovsxbd(xmm6, xmm6); - paddd(xmm9, xmm6); - pshufd(xmm5, xmm0, 0xaa); - pmovsxbd(xmm5, xmm5); - paddd(xmm10, xmm5); - pshufd(xmm6, xmm0, 0xff); - pmovsxbd(xmm6, xmm6); - paddd(xmm11, xmm6); - movdqu(xword[B-0x80], xmm0); - mov(al, byte[A2-0x80]); - pinsrb(xmm0, eax, 0x0); - mov(al, byte[A2+LDA*1-0x80]); - pinsrb(xmm0, eax, 0x1); - mov(al, byte[A2+LDA*2-0x80]); - pinsrb(xmm0, eax, 0x2); - mov(al, byte[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - 
pinsrb(xmm0, eax, 0x3); - mov(al, byte[A2-0x80]); - pinsrb(xmm0, eax, 0x4); - mov(al, byte[A2+LDA*1-0x80]); - pinsrb(xmm0, eax, 0x5); - mov(al, byte[A2+LDA*2-0x80]); - pinsrb(xmm0, eax, 0x6); - mov(al, byte[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - pinsrb(xmm0, eax, 0x7); - mov(al, byte[A2-0x80]); - pinsrb(xmm0, eax, 0x8); - mov(al, byte[A2+LDA*1-0x80]); - pinsrb(xmm0, eax, 0x9); - mov(al, byte[A2+LDA*2-0x80]); - pinsrb(xmm0, eax, 0xa); - mov(al, byte[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - pinsrb(xmm0, eax, 0xb); - mov(al, byte[A2-0x80]); - pinsrb(xmm0, eax, 0xc); - mov(al, byte[A2+LDA*1-0x80]); - pinsrb(xmm0, eax, 0xd); - mov(al, byte[A2+LDA*2-0x80]); - pinsrb(xmm0, eax, 0xe); - mov(al, byte[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - pinsrb(xmm0, eax, 0xf); - pmovsxbd(xmm5, xmm0); - paddd(xmm12, xmm5); - pshufd(xmm6, xmm0, 0x55); - pmovsxbd(xmm6, xmm6); - paddd(xmm13, xmm6); - pshufd(xmm5, xmm0, 0xaa); - pmovsxbd(xmm5, xmm5); - paddd(xmm14, xmm5); - pshufd(xmm6, xmm0, 0xff); - pmovsxbd(xmm6, xmm6); - paddd(xmm15, xmm6); - movdqu(xword[B-0x70], xmm0); - sub(B, -32); - align(4); - -L(l226c); - mov(A1, qword[ARG_BIAS]); - movdqu(xword[A1], xmm8); - movdqu(xword[A1+0x10], xmm9); - movdqu(xword[A1+0x20], xmm10); - movdqu(xword[A1+0x30], xmm11); - movdqu(xword[A1+0x40], xmm12); - movdqu(xword[A1+0x50], xmm13); - movdqu(xword[A1+0x60], xmm14); - movdqu(xword[A1+0x70], xmm15); - add(qword[ARG_BIAS], 0x80); - sub(N, 0x20); - cmp(N, 0x20); - jge(lf70, T_NEAR); - align(4); - -L(l22b8); - cmp(N, 0x10); - jl(l2c94, T_NEAR); - align(4); - -L(l22c4); - mov(A1, A); - mov(I, LDA); - shl(I, 0x4); - add(A, I); - pxor(xmm8, xmm8); - pxor(xmm9, xmm9); - pxor(xmm10, xmm10); - pxor(xmm11, xmm11); - mov(I, M); - sar(I, 0x4); - jle(l26b4, T_NEAR); - align(4); - -L(l22f4); - movdqu(xmm0, xword[A1-0x80]); - movdqu(xmm1, xword[A1+LDA*1-0x80]); - movdqu(xmm2, xword[A1+LDA*2-0x80]); - movdqu(xmm3, xword[A1+LDA3*1-0x80]); - lea(A2, ptr[A1+LDA*4]); - movdqa(xmm4, xmm0); - punpckldq(xmm0, xmm1); - punpckhdq(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpckldq(xmm2, xmm3); - punpckhdq(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqa(xmm3, xmm4); - punpcklqdq(xmm4, xmm5); - punpckhqdq(xmm3, xmm5); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm8, xmm5); - movdqu(xword[B-0x80], xmm0); - pmovsxbw(xmm5, xmm1); - movhlps(xmm6, xmm1); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm8, xmm5); - movdqu(xword[B-0x40], xmm1); - pmovsxbw(xmm5, xmm4); - movhlps(xmm6, xmm4); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm8, xmm5); - movdqu(xword[B], xmm4); - pmovsxbw(xmm5, xmm3); - movhlps(xmm6, xmm3); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm8, xmm5); - movdqu(xword[B+0x40], xmm3); - movdqu(xmm0, xword[A2-0x80]); - movdqu(xmm1, xword[A2+LDA*1-0x80]); - movdqu(xmm2, xword[A2+LDA*2-0x80]); - movdqu(xmm3, xword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - movdqa(xmm4, xmm0); - punpckldq(xmm0, xmm1); - punpckhdq(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpckldq(xmm2, xmm3); - punpckhdq(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqa(xmm3, xmm4); - punpcklqdq(xmm4, xmm5); - punpckhqdq(xmm3, xmm5); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - 
phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm9, xmm5); - movdqu(xword[B-0x70], xmm0); - pmovsxbw(xmm5, xmm1); - movhlps(xmm6, xmm1); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm9, xmm5); - movdqu(xword[B-0x30], xmm1); - pmovsxbw(xmm5, xmm4); - movhlps(xmm6, xmm4); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm9, xmm5); - movdqu(xword[B+0x10], xmm4); - pmovsxbw(xmm5, xmm3); - movhlps(xmm6, xmm3); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm9, xmm5); - movdqu(xword[B+0x50], xmm3); - movdqu(xmm0, xword[A2-0x80]); - movdqu(xmm1, xword[A2+LDA*1-0x80]); - movdqu(xmm2, xword[A2+LDA*2-0x80]); - movdqu(xmm3, xword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - movdqa(xmm4, xmm0); - punpckldq(xmm0, xmm1); - punpckhdq(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpckldq(xmm2, xmm3); - punpckhdq(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqa(xmm3, xmm4); - punpcklqdq(xmm4, xmm5); - punpckhqdq(xmm3, xmm5); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm10, xmm5); - movdqu(xword[B-0x60], xmm0); - pmovsxbw(xmm5, xmm1); - movhlps(xmm6, xmm1); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm10, xmm5); - movdqu(xword[B-0x20], xmm1); - pmovsxbw(xmm5, xmm4); - movhlps(xmm6, xmm4); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm10, xmm5); - movdqu(xword[B+0x20], xmm4); - pmovsxbw(xmm5, xmm3); - movhlps(xmm6, xmm3); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm10, xmm5); - movdqu(xword[B+0x60], xmm3); - movdqu(xmm0, xword[A2-0x80]); - movdqu(xmm1, xword[A2+LDA*1-0x80]); - movdqu(xmm2, xword[A2+LDA*2-0x80]); - movdqu(xmm3, xword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - movdqa(xmm4, xmm0); - punpckldq(xmm0, xmm1); - punpckhdq(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpckldq(xmm2, xmm3); - punpckhdq(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqa(xmm3, xmm4); - punpcklqdq(xmm4, xmm5); - punpckhqdq(xmm3, xmm5); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm11, xmm5); - movdqu(xword[B-0x50], xmm0); - pmovsxbw(xmm5, xmm1); - movhlps(xmm6, xmm1); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm11, xmm5); - movdqu(xword[B-0x10], xmm1); - pmovsxbw(xmm5, xmm4); - movhlps(xmm6, xmm4); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm11, xmm5); - movdqu(xword[B+0x30], xmm4); - pmovsxbw(xmm5, xmm3); - movhlps(xmm6, xmm3); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm11, xmm5); - movdqu(xword[B+0x70], xmm3); - sub(A1, -16); - sub(B, -256); - dec(I); - jg(l22f4, T_NEAR); - align(4); - -L(l26b4); - test(M, 0x8); - jle(l28cc, T_NEAR); - movq(xmm0, qword[A1-0x80]); - movq(xmm1, qword[A1+LDA*1-0x80]); - movq(xmm2, qword[A1+LDA*2-0x80]); - movq(xmm3, qword[A1+LDA3*1-0x80]); - lea(A2, ptr[A1+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - 
punpckhqdq(xmm1, xmm2); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm8, xmm5); - movdqu(xword[B-0x80], xmm0); - pmovsxbw(xmm5, xmm1); - movhlps(xmm6, xmm1); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm8, xmm5); - movdqu(xword[B-0x40], xmm1); - movq(xmm0, qword[A2-0x80]); - movq(xmm1, qword[A2+LDA*1-0x80]); - movq(xmm2, qword[A2+LDA*2-0x80]); - movq(xmm3, qword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm9, xmm5); - movdqu(xword[B-0x70], xmm0); - pmovsxbw(xmm5, xmm1); - movhlps(xmm6, xmm1); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm9, xmm5); - movdqu(xword[B-0x30], xmm1); - movq(xmm0, qword[A2-0x80]); - movq(xmm1, qword[A2+LDA*1-0x80]); - movq(xmm2, qword[A2+LDA*2-0x80]); - movq(xmm3, qword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm10, xmm5); - movdqu(xword[B-0x60], xmm0); - pmovsxbw(xmm5, xmm1); - movhlps(xmm6, xmm1); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm10, xmm5); - movdqu(xword[B-0x20], xmm1); - movq(xmm0, qword[A2-0x80]); - movq(xmm1, qword[A2+LDA*1-0x80]); - movq(xmm2, qword[A2+LDA*2-0x80]); - movq(xmm3, qword[A2+LDA3*1-0x80]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm11, xmm5); - movdqu(xword[B-0x50], xmm0); - pmovsxbw(xmm5, xmm1); - movhlps(xmm6, xmm1); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm11, xmm5); - movdqu(xword[B-0x10], xmm1); - sub(A1, -8); - sub(B, -128); - align(4); - -L(l28cc); - test(M, 0x4); - jle(l2a2c, T_NEAR); - movd(xmm0, dword[A1-0x80]); - movd(xmm1, dword[A1+LDA*1-0x80]); - movd(xmm2, dword[A1+LDA*2-0x80]); - movd(xmm3, dword[A1+LDA3*1-0x80]); - lea(A2, ptr[A1+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - punpcklqdq(xmm0, xmm2); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm8, xmm5); - movdqu(xword[B-0x80], xmm0); - movd(xmm0, dword[A2-0x80]); - movd(xmm1, dword[A2+LDA*1-0x80]); - movd(xmm2, dword[A2+LDA*2-0x80]); - movd(xmm3, dword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - punpcklqdq(xmm0, xmm2); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm9, xmm5); - movdqu(xword[B-0x70], xmm0); - movd(xmm0, dword[A2-0x80]); - movd(xmm1, dword[A2+LDA*1-0x80]); - movd(xmm2, dword[A2+LDA*2-0x80]); - movd(xmm3, dword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - 
punpckldq(xmm2, xmm3); - punpcklqdq(xmm0, xmm2); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm10, xmm5); - movdqu(xword[B-0x60], xmm0); - movd(xmm0, dword[A2-0x80]); - movd(xmm1, dword[A2+LDA*1-0x80]); - movd(xmm2, dword[A2+LDA*2-0x80]); - movd(xmm3, dword[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - punpcklqdq(xmm0, xmm2); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm11, xmm5); - movdqu(xword[B-0x50], xmm0); - sub(A1, -4); - sub(B, -64); - align(4); - -L(l2a2c); - test(M, 0x2); - jle(l2b5c, T_NEAR); - mov(ax, word[A1-0x80]); - pinsrw(xmm0, eax, 0x0); - mov(ax, word[A1+LDA*1-0x80]); - pinsrw(xmm0, eax, 0x1); - mov(ax, word[A1+LDA*2-0x80]); - pinsrw(xmm0, eax, 0x2); - mov(ax, word[A1+LDA3*1-0x80]); - lea(A2, ptr[A1+LDA*4]); - pinsrw(xmm0, eax, 0x3); - mov(ax, word[A2-0x80]); - pinsrw(xmm0, eax, 0x4); - mov(ax, word[A2+LDA*1-0x80]); - pinsrw(xmm0, eax, 0x5); - mov(ax, word[A2+LDA*2-0x80]); - pinsrw(xmm0, eax, 0x6); - mov(ax, word[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - pinsrw(xmm0, eax, 0x7); - pmovsxbw(xmm5, xmm0); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm8, xmm5); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm6, xmm6); - pmovsxwd(xmm6, xmm6); - paddd(xmm9, xmm6); - movdqu(xword[B-0x80], xmm0); - mov(ax, word[A2-0x80]); - pinsrw(xmm0, eax, 0x0); - mov(ax, word[A2+LDA*1-0x80]); - pinsrw(xmm0, eax, 0x1); - mov(ax, word[A2+LDA*2-0x80]); - pinsrw(xmm0, eax, 0x2); - mov(ax, word[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - pinsrw(xmm0, eax, 0x3); - mov(ax, word[A2-0x80]); - pinsrw(xmm0, eax, 0x4); - mov(ax, word[A2+LDA*1-0x80]); - pinsrw(xmm0, eax, 0x5); - mov(ax, word[A2+LDA*2-0x80]); - pinsrw(xmm0, eax, 0x6); - mov(ax, word[A2+LDA3*1-0x80]); - pinsrw(xmm0, eax, 0x7); - pmovsxbw(xmm5, xmm0); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm10, xmm5); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm6, xmm6); - pmovsxwd(xmm6, xmm6); - paddd(xmm11, xmm6); - movdqu(xword[B-0x70], xmm0); - sub(A1, -2); - sub(B, -32); - align(4); - -L(l2b5c); - test(M, 0x1); - jle(l2c64, T_NEAR); - mov(al, byte[A1-0x80]); - pinsrb(xmm0, eax, 0x0); - mov(al, byte[A1+LDA*1-0x80]); - pinsrb(xmm0, eax, 0x1); - mov(al, byte[A1+LDA*2-0x80]); - pinsrb(xmm0, eax, 0x2); - mov(al, byte[A1+LDA3*1-0x80]); - lea(A2, ptr[A1+LDA*4]); - pinsrb(xmm0, eax, 0x3); - mov(al, byte[A2-0x80]); - pinsrb(xmm0, eax, 0x4); - mov(al, byte[A2+LDA*1-0x80]); - pinsrb(xmm0, eax, 0x5); - mov(al, byte[A2+LDA*2-0x80]); - pinsrb(xmm0, eax, 0x6); - mov(al, byte[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - pinsrb(xmm0, eax, 0x7); - mov(al, byte[A2-0x80]); - pinsrb(xmm0, eax, 0x8); - mov(al, byte[A2+LDA*1-0x80]); - pinsrb(xmm0, eax, 0x9); - mov(al, byte[A2+LDA*2-0x80]); - pinsrb(xmm0, eax, 0xa); - mov(al, byte[A2+LDA3*1-0x80]); - lea(A2, ptr[A2+LDA*4]); - pinsrb(xmm0, eax, 0xb); - mov(al, byte[A2-0x80]); - pinsrb(xmm0, eax, 0xc); - mov(al, byte[A2+LDA*1-0x80]); - pinsrb(xmm0, eax, 0xd); - mov(al, byte[A2+LDA*2-0x80]); - pinsrb(xmm0, eax, 0xe); - mov(al, byte[A2+LDA3*1-0x80]); - pinsrb(xmm0, eax, 0xf); - pmovsxbd(xmm5, xmm0); - paddd(xmm8, xmm5); - pshufd(xmm6, xmm0, 0x55); - pmovsxbd(xmm6, xmm6); - paddd(xmm9, xmm6); - pshufd(xmm5, xmm0, 0xaa); - pmovsxbd(xmm5, xmm5); - paddd(xmm10, xmm5); - pshufd(xmm6, xmm0, 0xff); - pmovsxbd(xmm6, xmm6); - 
paddd(xmm11, xmm6); - movdqu(xword[B-0x80], xmm0); - sub(B, -16); - align(4); - -L(l2c64); - mov(A1, qword[ARG_BIAS]); - movdqu(xword[A1], xmm8); - movdqu(xword[A1+0x10], xmm9); - movdqu(xword[A1+0x20], xmm10); - movdqu(xword[A1+0x30], xmm11); - add(qword[ARG_BIAS], 0x40); - sub(N, 0x10); - cmp(N, 0x10); - jge(l22c4, T_NEAR); - align(4); - -L(l2c94); - cmp(N, 0x8); - jl(l31c0, T_NEAR); - align(4); - -L(l2ca0); - mov(A1, A); - lea(A2, ptr[A1+LDA*4]); - lea(I, ptr[A1+LDA*8]); - mov(A, I); - pxor(xmm8, xmm8); - pxor(xmm9, xmm9); - mov(I, M); - sar(I, 0x4); - jle(l2eac, T_NEAR); - align(4); - -L(l2cc8); - movdqu(xmm0, xword[A1-0x80]); - movdqu(xmm1, xword[A1+LDA*1-0x80]); - movdqu(xmm2, xword[A1+LDA*2-0x80]); - movdqu(xmm3, xword[A1+LDA3*1-0x80]); - sub(A1, -16); - movdqa(xmm4, xmm0); - punpckldq(xmm0, xmm1); - punpckhdq(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpckldq(xmm2, xmm3); - punpckhdq(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqa(xmm3, xmm4); - punpcklqdq(xmm4, xmm5); - punpckhqdq(xmm3, xmm5); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm8, xmm5); - movdqu(xword[B-0x80], xmm0); - pmovsxbw(xmm5, xmm1); - movhlps(xmm6, xmm1); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm8, xmm5); - movdqu(xword[B-0x60], xmm1); - pmovsxbw(xmm5, xmm4); - movhlps(xmm6, xmm4); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm8, xmm5); - movdqu(xword[B-0x40], xmm4); - pmovsxbw(xmm5, xmm3); - movhlps(xmm6, xmm3); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm8, xmm5); - movdqu(xword[B-0x20], xmm3); - movdqu(xmm0, xword[A2-0x80]); - movdqu(xmm1, xword[A2+LDA*1-0x80]); - movdqu(xmm2, xword[A2+LDA*2-0x80]); - movdqu(xmm3, xword[A2+LDA3*1-0x80]); - sub(A2, -16); - movdqa(xmm4, xmm0); - punpckldq(xmm0, xmm1); - punpckhdq(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpckldq(xmm2, xmm3); - punpckhdq(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqa(xmm3, xmm4); - punpcklqdq(xmm4, xmm5); - punpckhqdq(xmm3, xmm5); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm9, xmm5); - movdqu(xword[B-0x70], xmm0); - pmovsxbw(xmm5, xmm1); - movhlps(xmm6, xmm1); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm9, xmm5); - movdqu(xword[B-0x50], xmm1); - pmovsxbw(xmm5, xmm4); - movhlps(xmm6, xmm4); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm9, xmm5); - movdqu(xword[B-0x30], xmm4); - pmovsxbw(xmm5, xmm3); - movhlps(xmm6, xmm3); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm9, xmm5); - movdqu(xword[B-0x10], xmm3); - sub(B, -128); - dec(I); - jg(l2cc8, T_NEAR); - align(4); - -L(l2eac); - test(M, 0x8); - jle(l2fc0, T_NEAR); - movq(xmm0, qword[A1-0x80]); - movq(xmm1, qword[A1+LDA*1-0x80]); - movq(xmm2, qword[A1+LDA*2-0x80]); - movq(xmm3, qword[A1+LDA3*1-0x80]); - sub(A1, -8); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - 
pmovsxwd(xmm5, xmm5); - paddd(xmm8, xmm5); - movdqu(xword[B-0x80], xmm0); - pmovsxbw(xmm5, xmm1); - movhlps(xmm6, xmm1); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm8, xmm5); - movdqu(xword[B-0x60], xmm1); - movq(xmm0, qword[A2-0x80]); - movq(xmm1, qword[A2+LDA*1-0x80]); - movq(xmm2, qword[A2+LDA*2-0x80]); - movq(xmm3, qword[A2+LDA3*1-0x80]); - sub(A2, -8); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm9, xmm5); - movdqu(xword[B-0x70], xmm0); - pmovsxbw(xmm5, xmm1); - movhlps(xmm6, xmm1); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm9, xmm5); - movdqu(xword[B-0x50], xmm1); - sub(B, -64); - align(4); - -L(l2fc0); - test(M, 0x4); - jle(l3078, T_NEAR); - movd(xmm0, dword[A1-0x80]); - movd(xmm1, dword[A1+LDA*1-0x80]); - movd(xmm2, dword[A1+LDA*2-0x80]); - movd(xmm3, dword[A1+LDA3*1-0x80]); - sub(A1, -4); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - punpcklqdq(xmm0, xmm2); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm8, xmm5); - movdqu(xword[B-0x80], xmm0); - movd(xmm0, dword[A2-0x80]); - movd(xmm1, dword[A2+LDA*1-0x80]); - movd(xmm2, dword[A2+LDA*2-0x80]); - movd(xmm3, dword[A2+LDA3*1-0x80]); - sub(A2, -4); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - punpcklqdq(xmm0, xmm2); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm9, xmm5); - movdqu(xword[B-0x70], xmm0); - sub(B, -32); - align(4); - -L(l3078); - test(M, 0x2); - jle(l3118, T_NEAR); - mov(ax, word[A1-0x80]); - pinsrw(xmm0, eax, 0x0); - mov(ax, word[A1+LDA*1-0x80]); - pinsrw(xmm0, eax, 0x1); - mov(ax, word[A1+LDA*2-0x80]); - pinsrw(xmm0, eax, 0x2); - mov(ax, word[A1+LDA3*1-0x80]); - sub(A1, -2); - pinsrw(xmm0, eax, 0x3); - mov(ax, word[A2-0x80]); - pinsrw(xmm0, eax, 0x4); - mov(ax, word[A2+LDA*1-0x80]); - pinsrw(xmm0, eax, 0x5); - mov(ax, word[A2+LDA*2-0x80]); - pinsrw(xmm0, eax, 0x6); - mov(ax, word[A2+LDA3*1-0x80]); - sub(A2, -2); - pinsrw(xmm0, eax, 0x7); - pmovsxbw(xmm5, xmm0); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm8, xmm5); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm6, xmm6); - pmovsxwd(xmm6, xmm6); - paddd(xmm9, xmm6); - movdqu(xword[B-0x80], xmm0); - sub(B, -16); - align(4); - -L(l3118); - test(M, 0x1); - jle(l319c, T_NEAR); - mov(al, byte[A1-0x80]); - pinsrb(xmm0, eax, 0x0); - mov(al, byte[A1+LDA*1-0x80]); - pinsrb(xmm0, eax, 0x1); - mov(al, byte[A1+LDA*2-0x80]); - pinsrb(xmm0, eax, 0x2); - mov(al, byte[A1+LDA3*1-0x80]); - pinsrb(xmm0, eax, 0x3); - mov(al, byte[A2-0x80]); - pinsrb(xmm0, eax, 0x4); - mov(al, byte[A2+LDA*1-0x80]); - pinsrb(xmm0, eax, 0x5); - mov(al, byte[A2+LDA*2-0x80]); - pinsrb(xmm0, eax, 0x6); - mov(al, byte[A2+LDA3*1-0x80]); - pinsrb(xmm0, eax, 0x7); - pmovsxbd(xmm5, xmm0); - pshufd(xmm6, xmm0, 0x55); - pmovsxbd(xmm6, xmm6); - paddd(xmm8, xmm5); - paddd(xmm9, xmm6); - movq(qword[B-0x80], xmm0); - sub(B, -8); - align(4); - -L(l319c); - mov(A1, qword[ARG_BIAS]); - movdqu(xword[A1], xmm8); - movdqu(xword[A1+0x10], xmm9); - add(qword[ARG_BIAS], 0x20); - sub(N, 0x8); - cmp(N, 0x8); - jge(l2ca0, T_NEAR); - align(4); - -L(l31c0); 
- cmp(N, 0x4); - jl(l349c, T_NEAR); - align(4); - -L(l31cc); - mov(A1, A); - lea(A2, ptr[A1+LDA*2]); - lea(I, ptr[A1+LDA*4]); - mov(A, I); - pxor(xmm7, xmm7); - mov(I, M); - sar(I, 0x4); - jle(l32e4, T_NEAR); - align(4); - -L(l31ec); - movdqu(xmm0, xword[A1-0x80]); - movdqu(xmm1, xword[A1+LDA*1-0x80]); - sub(A1, -16); - movdqu(xmm2, xword[A2-0x80]); - movdqu(xmm3, xword[A2+LDA*1-0x80]); - sub(A2, -16); - movdqa(xmm4, xmm0); - punpckldq(xmm0, xmm1); - punpckhdq(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpckldq(xmm2, xmm3); - punpckhdq(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqa(xmm3, xmm4); - punpcklqdq(xmm4, xmm5); - punpckhqdq(xmm3, xmm5); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - movdqu(xword[B-0x80], xmm0); - pmovsxbw(xmm5, xmm1); - movhlps(xmm6, xmm1); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - movdqu(xword[B-0x70], xmm1); - pmovsxbw(xmm5, xmm4); - movhlps(xmm6, xmm4); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - movdqu(xword[B-0x60], xmm4); - pmovsxbw(xmm5, xmm3); - movhlps(xmm6, xmm3); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - movdqu(xword[B-0x50], xmm3); - sub(B, -64); - dec(I); - jg(l31ec, T_NEAR); - align(4); - -L(l32e4); - test(M, 0x8); - jle(l3378, T_NEAR); - movq(xmm0, qword[A1-0x80]); - movq(xmm1, qword[A1+LDA*1-0x80]); - sub(A1, -8); - movq(xmm2, qword[A2-0x80]); - movq(xmm3, qword[A2+LDA*1-0x80]); - sub(A2, -8); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - movdqu(xword[B-0x80], xmm0); - pmovsxbw(xmm5, xmm1); - movhlps(xmm6, xmm1); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - movdqu(xword[B-0x70], xmm1); - sub(B, -32); - align(4); - -L(l3378); - test(M, 0x4); - jle(l33dc, T_NEAR); - movd(xmm0, dword[A1-0x80]); - movd(xmm1, dword[A1+LDA*1-0x80]); - sub(A1, -4); - movd(xmm2, dword[A2-0x80]); - movd(xmm3, dword[A2+LDA*1-0x80]); - sub(A2, -4); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - punpcklqdq(xmm0, xmm2); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - movdqu(xword[B-0x80], xmm0); - sub(B, -16); - align(4); - -L(l33dc); - test(M, 0x2); - jle(l3434, T_NEAR); - mov(ax, word[A1-0x80]); - pinsrw(xmm0, eax, 0x0); - mov(ax, word[A1+LDA*1-0x80]); - sub(A1, -2); - pinsrw(xmm0, eax, 0x1); - mov(ax, word[A2-0x80]); - pinsrw(xmm0, eax, 0x2); - mov(ax, word[A2+LDA*1-0x80]); - sub(A2, -2); - pinsrw(xmm0, eax, 0x3); - pmovsxbw(xmm5, xmm0); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - movq(qword[B-0x80], xmm0); - sub(B, -8); - align(4); - -L(l3434); - test(M, 0x1); - jle(l347c, T_NEAR); - mov(al, byte[A1-0x80]); - pinsrb(xmm0, eax, 0x0); - mov(al, byte[A1+LDA*1-0x80]); - pinsrb(xmm0, eax, 0x1); - mov(al, byte[A2-0x80]); - pinsrb(xmm0, eax, 0x2); - mov(al, byte[A2+LDA*1-0x80]); - pinsrb(xmm0, eax, 0x3); - pmovsxbd(xmm5, xmm0); - paddd(xmm7, xmm5); - 
movd(dword[B-0x80], xmm0); - sub(B, -4); - align(4); - -L(l347c); - mov(A1, qword[ARG_BIAS]); - movdqu(xword[A1], xmm7); - add(qword[ARG_BIAS], 0x10); - sub(N, 0x4); - cmp(N, 0x4); - jge(l31cc, T_NEAR); - align(4); - -L(l349c); - cmp(N, 0x2); - jl(l368a, T_NEAR); - align(4); - -L(l34a8); - mov(A1, A); - lea(A2, ptr[A1+LDA*1]); - lea(I, ptr[A1+LDA*2]); - mov(A, I); - pxor(xmm7, xmm7); - mov(I, M); - sar(I, 0x4); - jle(l3558, T_NEAR); - align(4); - -L(l34c8); - movdqu(xmm0, xword[A1-0x80]); - sub(A1, -16); - movdqu(xmm1, xword[A2-0x80]); - sub(A2, -16); - movdqa(xmm2, xmm0); - punpckldq(xmm0, xmm1); - punpckhdq(xmm2, xmm1); - pshufd(xmm6, xmm0, 0xd8); - pmovsxbw(xmm5, xmm6); - movhlps(xmm6, xmm6); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - movdqu(xword[B-0x80], xmm0); - pshufd(xmm6, xmm2, 0xd8); - pmovsxbw(xmm5, xmm6); - movhlps(xmm6, xmm6); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - movdqu(xword[B-0x70], xmm2); - sub(B, -32); - dec(I); - jg(l34c8, T_NEAR); - align(4); - -L(l3558); - test(M, 0x8); - jle(l35b0, T_NEAR); - movq(xmm0, qword[A1-0x80]); - sub(A1, -8); - movq(xmm1, qword[A2-0x80]); - sub(A2, -8); - punpckldq(xmm0, xmm1); - pshufd(xmm6, xmm0, 0xd8); - pmovsxbw(xmm5, xmm6); - movhlps(xmm6, xmm6); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - movdqu(xword[B-0x80], xmm0); - sub(B, -16); - align(4); - -L(l35b0); - test(M, 0x4); - jle(l35f4, T_NEAR); - movd(xmm0, dword[A1-0x80]); - sub(A1, -4); - movd(xmm1, dword[A2-0x80]); - sub(A2, -4); - punpckldq(xmm0, xmm1); - pmovsxbw(xmm5, xmm0); - phaddw(xmm5, xmm5); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - movq(qword[B-0x80], xmm0); - sub(B, -8); - align(4); - -L(l35f4); - test(M, 0x2); - jle(l3638, T_NEAR); - mov(ax, word[A1-0x80]); - sub(A1, -2); - pinsrw(xmm0, eax, 0x0); - mov(ax, word[A2-0x80]); - sub(A2, -2); - pinsrw(xmm0, eax, 0x1); - pmovsxbw(xmm5, xmm0); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - movd(dword[B-0x80], xmm0); - sub(B, -4); - align(4); - -L(l3638); - test(M, 0x1); - jle(l366c, T_NEAR); - mov(al, byte[A1-0x80]); - pinsrb(xmm0, eax, 0x0); - mov(byte[B-0x80], al); - mov(al, byte[A2-0x80]); - pinsrb(xmm0, eax, 0x1); - mov(byte[B-0x7f], al); - sub(B, -2); - pmovsxbd(xmm5, xmm0); - paddd(xmm7, xmm5); - align(4); - -L(l366c); - mov(A1, qword[ARG_BIAS]); - movq(qword[A1], xmm7); - add(qword[ARG_BIAS], 0x8); - sub(N, 0x2); - cmp(N, 0x2); - jge(l34a8, T_NEAR); - align(4); - -L(l368a); - cmp(N, 0x1); - jl(l37d8, T_NEAR); - align(4); - -L(l3694); - mov(A1, A); - add(A, LDA); - pxor(xmm7, xmm7); - mov(I, M); - sar(I, 0x4); - jle(l36ec, T_NEAR); - align(4); - -L(l36a8); - movdqu(xmm0, xword[A1-0x80]); - sub(A1, -16); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - phaddw(xmm5, xmm5); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - movdqu(xword[B-0x80], xmm0); - sub(B, -16); - dec(I); - jg(l36a8, T_NEAR); - align(4); - -L(l36ec); - test(M, 0x8); - jle(l3728, T_NEAR); - movq(xmm0, qword[A1-0x80]); - sub(A1, -8); - pmovsxbw(xmm5, xmm0); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - movq(qword[B-0x80], xmm0); - sub(B, -8); - align(4); - -L(l3728); - test(M, 
0x4); - jle(l3760, T_NEAR); - movd(xmm0, dword[A1-0x80]); - sub(A1, -4); - pmovsxbw(xmm5, xmm0); - phaddw(xmm5, xmm5); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - movd(dword[B-0x80], xmm0); - sub(B, -4); - align(4); - -L(l3760); - test(M, 0x2); - jle(l3794, T_NEAR); - mov(ax, word[A1-0x80]); - pinsrw(xmm0, eax, 0x0); - pmovsxbw(xmm5, xmm0); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - mov(word[B-0x80], ax); - sub(A1, -2); - sub(B, -2); - align(4); - -L(l3794); - test(M, 0x1); - jle(l37b8, T_NEAR); - mov(al, byte[A1-0x80]); - pinsrb(xmm0, eax, 0x0); - pmovsxbd(xmm5, xmm0); - paddd(xmm7, xmm5); - mov(byte[B-0x80], al); - sub(B, -1); - align(4); - -L(l37b8); - mov(A1, qword[ARG_BIAS]); - movd(dword[A1], xmm7); - add(qword[ARG_BIAS], 0x4); - sub(N, 0x1); - cmp(N, 0x1); - jge(l3694, T_NEAR); - align(4); - -L(l37d8); - - postamble(); -} -outLocalLabel(); - -#undef M -#undef N -#undef A -#undef LDA -#undef ALPHA -#undef B -#undef I -#undef A1 -#undef A2 -#undef LDA3 -#ifdef _WIN32 -#undef ARG_ALPHA -#undef ARG_B -#endif -#undef ARG_BIAS -} - -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bn_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bn_kern.cpp deleted file mode 100644 index c7f1393c9..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bn_kern.cpp +++ /dev/null @@ -1,821 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#include "jit_generator.hpp" -#include "common.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -jit_avx512_core_u8_copy_sum_bn_kern::jit_avx512_core_u8_copy_sum_bn_kern(): jit_generator(nullptr, GEMM_CODE_SIZE) -{ - -#ifndef _WIN32 -#define M rdi -#define N rsi -#define A rdx -#define LDA rcx -#define ALPHA r8 -#define B r9 - -#define I rax -#define A1 r10 -#define A2 r8 -#define LDA3 r11 - -#define ARG_BIAS 24+stacksize+rsp - -#else - -#define M rcx -#define N rdx -#define A r8 -#define LDA r9 -#define ALPHA rax -#define B rdi - -#define I rax -#define A1 rsi -#define A2 r10 -#define LDA3 r11 - -#define ARG_ALPHA 40+stacksize+rsp -#define ARG_B 48+stacksize+rsp -#define ARG_BIAS 72+stacksize+rsp - -#endif - -inLocalLabel(); -{ - -Xbyak::Label l20; -Xbyak::Label l22c; -Xbyak::Label l340; -Xbyak::Label l3f8; -Xbyak::Label l48; -Xbyak::Label l498; -Xbyak::Label l51c; -Xbyak::Label l540; -Xbyak::Label l54c; -Xbyak::Label l56c; -Xbyak::Label l664; -Xbyak::Label l6f8; -Xbyak::Label l75c; -Xbyak::Label l7b4; -Xbyak::Label l7fc; -Xbyak::Label l81c; -Xbyak::Label l828; -Xbyak::Label l848; -Xbyak::Label l8d8; -Xbyak::Label l930; -Xbyak::Label l974; -Xbyak::Label l9b8; -Xbyak::Label l9ec; -Xbyak::Label la0a; -Xbyak::Label la14; -Xbyak::Label la28; -Xbyak::Label la6c; -Xbyak::Label laa8; -Xbyak::Label lae0; -Xbyak::Label lb14; -Xbyak::Label lb38; -Xbyak::Label lb58; - - preamble(); - auto stacksize = get_size_of_abi_save_regs(); -#ifdef _WIN32 - mov(ALPHA, ptr[ARG_ALPHA]); - mov(B, ptr[ARG_B]); -#endif - - mov(N, qword[N]); - mov(M, qword[M]); - mov(LDA, qword[LDA]); - sub(A, -128); - sub(B, -128); - lea(LDA3, ptr[LDA+LDA*2]); - cmp(N, 0x8); - jl(l540, T_NEAR); - align(4); - -L(l20); - mov(A1, A); - lea(A2, ptr[A1+LDA*4]); - lea(I, ptr[A1+LDA*8]); - mov(A, I); - pxor(xmm8, xmm8); - pxor(xmm9, xmm9); - mov(I, M); - sar(I, 0x4); - jle(l22c, T_NEAR); - align(4); - -L(l48); - movdqu(xmm0, xword[A1-0x80]); - movdqu(xmm1, xword[A1+LDA*1-0x80]); - movdqu(xmm2, xword[A1+LDA*2-0x80]); - movdqu(xmm3, xword[A1+LDA3*1-0x80]); - sub(A1, -16); - movdqa(xmm4, xmm0); - punpckldq(xmm0, xmm1); - punpckhdq(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpckldq(xmm2, xmm3); - punpckhdq(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqa(xmm3, xmm4); - punpcklqdq(xmm4, xmm5); - punpckhqdq(xmm3, xmm5); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm8, xmm5); - movdqu(xword[B-0x80], xmm0); - pmovsxbw(xmm5, xmm1); - movhlps(xmm6, xmm1); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm8, xmm5); - movdqu(xword[B-0x60], xmm1); - pmovsxbw(xmm5, xmm4); - movhlps(xmm6, xmm4); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm8, xmm5); - movdqu(xword[B-0x40], xmm4); - pmovsxbw(xmm5, xmm3); - movhlps(xmm6, xmm3); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm8, xmm5); - movdqu(xword[B-0x20], xmm3); - movdqu(xmm0, xword[A2-0x80]); - movdqu(xmm1, xword[A2+LDA*1-0x80]); - movdqu(xmm2, xword[A2+LDA*2-0x80]); - movdqu(xmm3, xword[A2+LDA3*1-0x80]); - sub(A2, -16); - movdqa(xmm4, xmm0); - punpckldq(xmm0, xmm1); - punpckhdq(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpckldq(xmm2, xmm3); - punpckhdq(xmm5, xmm3); - movdqa(xmm1, xmm0); - 
punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqa(xmm3, xmm4); - punpcklqdq(xmm4, xmm5); - punpckhqdq(xmm3, xmm5); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm9, xmm5); - movdqu(xword[B-0x70], xmm0); - pmovsxbw(xmm5, xmm1); - movhlps(xmm6, xmm1); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm9, xmm5); - movdqu(xword[B-0x50], xmm1); - pmovsxbw(xmm5, xmm4); - movhlps(xmm6, xmm4); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm9, xmm5); - movdqu(xword[B-0x30], xmm4); - pmovsxbw(xmm5, xmm3); - movhlps(xmm6, xmm3); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm9, xmm5); - movdqu(xword[B-0x10], xmm3); - sub(B, -128); - dec(I); - jg(l48, T_NEAR); - align(4); - -L(l22c); - test(M, 0x8); - jle(l340, T_NEAR); - movq(xmm0, qword[A1-0x80]); - movq(xmm1, qword[A1+LDA*1-0x80]); - movq(xmm2, qword[A1+LDA*2-0x80]); - movq(xmm3, qword[A1+LDA3*1-0x80]); - sub(A1, -8); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm8, xmm5); - movdqu(xword[B-0x80], xmm0); - pmovsxbw(xmm5, xmm1); - movhlps(xmm6, xmm1); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm8, xmm5); - movdqu(xword[B-0x60], xmm1); - movq(xmm0, qword[A2-0x80]); - movq(xmm1, qword[A2+LDA*1-0x80]); - movq(xmm2, qword[A2+LDA*2-0x80]); - movq(xmm3, qword[A2+LDA3*1-0x80]); - sub(A2, -8); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm9, xmm5); - movdqu(xword[B-0x70], xmm0); - pmovsxbw(xmm5, xmm1); - movhlps(xmm6, xmm1); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm9, xmm5); - movdqu(xword[B-0x50], xmm1); - sub(B, -64); - align(4); - -L(l340); - test(M, 0x4); - jle(l3f8, T_NEAR); - movd(xmm0, dword[A1-0x80]); - movd(xmm1, dword[A1+LDA*1-0x80]); - movd(xmm2, dword[A1+LDA*2-0x80]); - movd(xmm3, dword[A1+LDA3*1-0x80]); - sub(A1, -4); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - punpcklqdq(xmm0, xmm2); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm8, xmm5); - movdqu(xword[B-0x80], xmm0); - movd(xmm0, dword[A2-0x80]); - movd(xmm1, dword[A2+LDA*1-0x80]); - movd(xmm2, dword[A2+LDA*2-0x80]); - movd(xmm3, dword[A2+LDA3*1-0x80]); - sub(A2, -4); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - punpcklqdq(xmm0, xmm2); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm9, xmm5); - movdqu(xword[B-0x70], xmm0); - sub(B, -32); - align(4); - -L(l3f8); - test(M, 0x2); - jle(l498, T_NEAR); - mov(ax, word[A1-0x80]); - pinsrw(xmm0, eax, 0x0); - mov(ax, word[A1+LDA*1-0x80]); - pinsrw(xmm0, eax, 0x1); - mov(ax, word[A1+LDA*2-0x80]); - pinsrw(xmm0, eax, 0x2); - mov(ax, word[A1+LDA3*1-0x80]); - sub(A1, -2); - pinsrw(xmm0, eax, 0x3); - 
mov(ax, word[A2-0x80]); - pinsrw(xmm0, eax, 0x4); - mov(ax, word[A2+LDA*1-0x80]); - pinsrw(xmm0, eax, 0x5); - mov(ax, word[A2+LDA*2-0x80]); - pinsrw(xmm0, eax, 0x6); - mov(ax, word[A2+LDA3*1-0x80]); - sub(A2, -2); - pinsrw(xmm0, eax, 0x7); - pmovsxbw(xmm5, xmm0); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm8, xmm5); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm6, xmm6); - pmovsxwd(xmm6, xmm6); - paddd(xmm9, xmm6); - movdqu(xword[B-0x80], xmm0); - sub(B, -16); - align(4); - -L(l498); - test(M, 0x1); - jle(l51c, T_NEAR); - mov(al, byte[A1-0x80]); - pinsrb(xmm0, eax, 0x0); - mov(al, byte[A1+LDA*1-0x80]); - pinsrb(xmm0, eax, 0x1); - mov(al, byte[A1+LDA*2-0x80]); - pinsrb(xmm0, eax, 0x2); - mov(al, byte[A1+LDA3*1-0x80]); - pinsrb(xmm0, eax, 0x3); - mov(al, byte[A2-0x80]); - pinsrb(xmm0, eax, 0x4); - mov(al, byte[A2+LDA*1-0x80]); - pinsrb(xmm0, eax, 0x5); - mov(al, byte[A2+LDA*2-0x80]); - pinsrb(xmm0, eax, 0x6); - mov(al, byte[A2+LDA3*1-0x80]); - pinsrb(xmm0, eax, 0x7); - pmovsxbd(xmm5, xmm0); - pshufd(xmm6, xmm0, 0x55); - pmovsxbd(xmm6, xmm6); - paddd(xmm8, xmm5); - paddd(xmm9, xmm6); - movq(qword[B-0x80], xmm0); - sub(B, -8); - align(4); - -L(l51c); - mov(A1, qword[ARG_BIAS]); - movdqu(xword[A1], xmm8); - movdqu(xword[A1+0x10], xmm9); - add(qword[ARG_BIAS], 0x20); - sub(N, 0x8); - cmp(N, 0x8); - jge(l20, T_NEAR); - align(4); - -L(l540); - cmp(N, 0x4); - jl(l81c, T_NEAR); - align(4); - -L(l54c); - mov(A1, A); - lea(A2, ptr[A1+LDA*2]); - lea(I, ptr[A1+LDA*4]); - mov(A, I); - pxor(xmm7, xmm7); - mov(I, M); - sar(I, 0x4); - jle(l664, T_NEAR); - align(4); - -L(l56c); - movdqu(xmm0, xword[A1-0x80]); - movdqu(xmm1, xword[A1+LDA*1-0x80]); - sub(A1, -16); - movdqu(xmm2, xword[A2-0x80]); - movdqu(xmm3, xword[A2+LDA*1-0x80]); - sub(A2, -16); - movdqa(xmm4, xmm0); - punpckldq(xmm0, xmm1); - punpckhdq(xmm4, xmm1); - movdqa(xmm5, xmm2); - punpckldq(xmm2, xmm3); - punpckhdq(xmm5, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - movdqa(xmm3, xmm4); - punpcklqdq(xmm4, xmm5); - punpckhqdq(xmm3, xmm5); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - movdqu(xword[B-0x80], xmm0); - pmovsxbw(xmm5, xmm1); - movhlps(xmm6, xmm1); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - movdqu(xword[B-0x70], xmm1); - pmovsxbw(xmm5, xmm4); - movhlps(xmm6, xmm4); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - movdqu(xword[B-0x60], xmm4); - pmovsxbw(xmm5, xmm3); - movhlps(xmm6, xmm3); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - movdqu(xword[B-0x50], xmm3); - sub(B, -64); - dec(I); - jg(l56c, T_NEAR); - align(4); - -L(l664); - test(M, 0x8); - jle(l6f8, T_NEAR); - movq(xmm0, qword[A1-0x80]); - movq(xmm1, qword[A1+LDA*1-0x80]); - sub(A1, -8); - movq(xmm2, qword[A2-0x80]); - movq(xmm3, qword[A2+LDA*1-0x80]); - sub(A2, -8); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklqdq(xmm0, xmm2); - punpckhqdq(xmm1, xmm2); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - movdqu(xword[B-0x80], xmm0); - pmovsxbw(xmm5, xmm1); - movhlps(xmm6, xmm1); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, 
xmm5); - paddd(xmm7, xmm5); - movdqu(xword[B-0x70], xmm1); - sub(B, -32); - align(4); - -L(l6f8); - test(M, 0x4); - jle(l75c, T_NEAR); - movd(xmm0, dword[A1-0x80]); - movd(xmm1, dword[A1+LDA*1-0x80]); - sub(A1, -4); - movd(xmm2, dword[A2-0x80]); - movd(xmm3, dword[A2+LDA*1-0x80]); - sub(A2, -4); - punpckldq(xmm0, xmm1); - punpckldq(xmm2, xmm3); - punpcklqdq(xmm0, xmm2); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - movdqu(xword[B-0x80], xmm0); - sub(B, -16); - align(4); - -L(l75c); - test(M, 0x2); - jle(l7b4, T_NEAR); - mov(ax, word[A1-0x80]); - pinsrw(xmm0, eax, 0x0); - mov(ax, word[A1+LDA*1-0x80]); - sub(A1, -2); - pinsrw(xmm0, eax, 0x1); - mov(ax, word[A2-0x80]); - pinsrw(xmm0, eax, 0x2); - mov(ax, word[A2+LDA*1-0x80]); - sub(A2, -2); - pinsrw(xmm0, eax, 0x3); - pmovsxbw(xmm5, xmm0); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - movq(qword[B-0x80], xmm0); - sub(B, -8); - align(4); - -L(l7b4); - test(M, 0x1); - jle(l7fc, T_NEAR); - mov(al, byte[A1-0x80]); - pinsrb(xmm0, eax, 0x0); - mov(al, byte[A1+LDA*1-0x80]); - pinsrb(xmm0, eax, 0x1); - mov(al, byte[A2-0x80]); - pinsrb(xmm0, eax, 0x2); - mov(al, byte[A2+LDA*1-0x80]); - pinsrb(xmm0, eax, 0x3); - pmovsxbd(xmm5, xmm0); - paddd(xmm7, xmm5); - movd(dword[B-0x80], xmm0); - sub(B, -4); - align(4); - -L(l7fc); - mov(A1, qword[ARG_BIAS]); - movdqu(xword[A1], xmm7); - add(qword[ARG_BIAS], 0x10); - sub(N, 0x4); - cmp(N, 0x4); - jge(l54c, T_NEAR); - align(4); - -L(l81c); - cmp(N, 0x2); - jl(la0a, T_NEAR); - align(4); - -L(l828); - mov(A1, A); - lea(A2, ptr[A1+LDA*1]); - lea(I, ptr[A1+LDA*2]); - mov(A, I); - pxor(xmm7, xmm7); - mov(I, M); - sar(I, 0x4); - jle(l8d8, T_NEAR); - align(4); - -L(l848); - movdqu(xmm0, xword[A1-0x80]); - sub(A1, -16); - movdqu(xmm1, xword[A2-0x80]); - sub(A2, -16); - movdqa(xmm2, xmm0); - punpckldq(xmm0, xmm1); - punpckhdq(xmm2, xmm1); - pshufd(xmm6, xmm0, 0xd8); - pmovsxbw(xmm5, xmm6); - movhlps(xmm6, xmm6); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - movdqu(xword[B-0x80], xmm0); - pshufd(xmm6, xmm2, 0xd8); - pmovsxbw(xmm5, xmm6); - movhlps(xmm6, xmm6); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - movdqu(xword[B-0x70], xmm2); - sub(B, -32); - dec(I); - jg(l848, T_NEAR); - align(4); - -L(l8d8); - test(M, 0x8); - jle(l930, T_NEAR); - movq(xmm0, qword[A1-0x80]); - sub(A1, -8); - movq(xmm1, qword[A2-0x80]); - sub(A2, -8); - punpckldq(xmm0, xmm1); - pshufd(xmm6, xmm0, 0xd8); - pmovsxbw(xmm5, xmm6); - movhlps(xmm6, xmm6); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - movdqu(xword[B-0x80], xmm0); - sub(B, -16); - align(4); - -L(l930); - test(M, 0x4); - jle(l974, T_NEAR); - movd(xmm0, dword[A1-0x80]); - sub(A1, -4); - movd(xmm1, dword[A2-0x80]); - sub(A2, -4); - punpckldq(xmm0, xmm1); - pmovsxbw(xmm5, xmm0); - phaddw(xmm5, xmm5); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - movq(qword[B-0x80], xmm0); - sub(B, -8); - align(4); - -L(l974); - test(M, 0x2); - jle(l9b8, T_NEAR); - mov(ax, word[A1-0x80]); - sub(A1, -2); - pinsrw(xmm0, eax, 0x0); - mov(ax, word[A2-0x80]); - sub(A2, -2); - pinsrw(xmm0, eax, 0x1); - pmovsxbw(xmm5, xmm0); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - 
movd(dword[B-0x80], xmm0); - sub(B, -4); - align(4); - -L(l9b8); - test(M, 0x1); - jle(l9ec, T_NEAR); - mov(al, byte[A1-0x80]); - pinsrb(xmm0, eax, 0x0); - mov(byte[B-0x80], al); - mov(al, byte[A2-0x80]); - pinsrb(xmm0, eax, 0x1); - mov(byte[B-0x7f], al); - sub(B, -2); - pmovsxbd(xmm5, xmm0); - paddd(xmm7, xmm5); - align(4); - -L(l9ec); - mov(A1, qword[ARG_BIAS]); - movq(qword[A1], xmm7); - add(qword[ARG_BIAS], 0x8); - sub(N, 0x2); - cmp(N, 0x2); - jge(l828, T_NEAR); - align(4); - -L(la0a); - cmp(N, 0x1); - jl(lb58, T_NEAR); - align(4); - -L(la14); - mov(A1, A); - add(A, LDA); - pxor(xmm7, xmm7); - mov(I, M); - sar(I, 0x4); - jle(la6c, T_NEAR); - align(4); - -L(la28); - movdqu(xmm0, xword[A1-0x80]); - sub(A1, -16); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - phaddw(xmm5, xmm5); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - movdqu(xword[B-0x80], xmm0); - sub(B, -16); - dec(I); - jg(la28, T_NEAR); - align(4); - -L(la6c); - test(M, 0x8); - jle(laa8, T_NEAR); - movq(xmm0, qword[A1-0x80]); - sub(A1, -8); - pmovsxbw(xmm5, xmm0); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - movq(qword[B-0x80], xmm0); - sub(B, -8); - align(4); - -L(laa8); - test(M, 0x4); - jle(lae0, T_NEAR); - movd(xmm0, dword[A1-0x80]); - sub(A1, -4); - pmovsxbw(xmm5, xmm0); - phaddw(xmm5, xmm5); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - movd(dword[B-0x80], xmm0); - sub(B, -4); - align(4); - -L(lae0); - test(M, 0x2); - jle(lb14, T_NEAR); - mov(ax, word[A1-0x80]); - pinsrw(xmm0, eax, 0x0); - pmovsxbw(xmm5, xmm0); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - mov(word[B-0x80], ax); - sub(A1, -2); - sub(B, -2); - align(4); - -L(lb14); - test(M, 0x1); - jle(lb38, T_NEAR); - mov(al, byte[A1-0x80]); - pinsrb(xmm0, eax, 0x0); - pmovsxbd(xmm5, xmm0); - paddd(xmm7, xmm5); - mov(byte[B-0x80], al); - sub(B, -1); - align(4); - -L(lb38); - mov(A1, qword[ARG_BIAS]); - movd(dword[A1], xmm7); - add(qword[ARG_BIAS], 0x4); - sub(N, 0x1); - cmp(N, 0x1); - jge(la14, T_NEAR); - align(4); - -L(lb58); - - postamble(); -} -outLocalLabel(); - -#undef M -#undef N -#undef A -#undef LDA -#undef ALPHA -#undef B -#undef I -#undef A1 -#undef A2 -#undef LDA3 -#ifdef _WIN32 -#undef ARG_ALPHA -#undef ARG_B -#endif -#undef ARG_BIAS -} - -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bt_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bt_kern.cpp deleted file mode 100644 index afe4f1713..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bt_kern.cpp +++ /dev/null @@ -1,647 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
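Each column block in these kernels ends with the same epilogue: the per-row totals are stored through the pointer saved at ARG_BIAS and that pointer is advanced by the block width (0x20, 0x10, 0x8 or 0x4 bytes), so successive blocks fill successive slices of the compensation vector. In scalar form (illustrative names, not the deleted API):

#include <cstdint>

// Model of the L(l51c)/L(l7fc)/L(l9ec)/L(lb38)-style epilogues above: append
// the sums of the block just packed and bump the cursor kept behind ARG_BIAS.
static void flush_block_sums(int32_t **bias_cursor,
        const int32_t *block_sums, int block_width) {
    for (int i = 0; i < block_width; ++i)
        (*bias_cursor)[i] = block_sums[i];  // movdqu/movq/movd(xword[A1], xmm...)
    *bias_cursor += block_width;            // add(qword[ARG_BIAS], width * 4)
}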
-*******************************************************************************/ - -#include "jit_generator.hpp" -#include "common.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -jit_avx512_core_u8_copy_sum_bt_kern::jit_avx512_core_u8_copy_sum_bt_kern(): jit_generator(nullptr, GEMM_CODE_SIZE) -{ - -#ifndef _WIN32 -#define M rdi -#define N rsi -#define A rdx -#define LDA rcx -#define ALPHA r8 -#define B r9 - -#define I rax -#define A1 r10 -#define A2 r8 -#define LDA3 r11 - -#define ARG_BIAS 24+stacksize+rsp - -#else - -#define M rcx -#define N rdx -#define A r8 -#define LDA r9 -#define ALPHA rax -#define B rdi - -#define I rax -#define A1 rsi -#define A2 r10 -#define LDA3 r11 - -#define ARG_ALPHA 40+stacksize+rsp -#define ARG_B 48+stacksize+rsp -#define ARG_BIAS 72+stacksize+rsp - -#endif - -inLocalLabel(); -{ - -Xbyak::Label l15c; -Xbyak::Label l1f4; -Xbyak::Label l20; -Xbyak::Label l248; -Xbyak::Label l280; -Xbyak::Label l2a4; -Xbyak::Label l2b0; -Xbyak::Label l2c8; -Xbyak::Label l384; -Xbyak::Label l3e8; -Xbyak::Label l40; -Xbyak::Label l424; -Xbyak::Label l448; -Xbyak::Label l468; -Xbyak::Label l474; -Xbyak::Label l48c; -Xbyak::Label l550; -Xbyak::Label l5bc; -Xbyak::Label l600; -Xbyak::Label l628; -Xbyak::Label l646; -Xbyak::Label l650; -Xbyak::Label l668; -Xbyak::Label l700; -Xbyak::Label l760; -Xbyak::Label l7a4; -Xbyak::Label l7c8; -Xbyak::Label l7e8; - - preamble(); - auto stacksize = get_size_of_abi_save_regs(); -#ifdef _WIN32 - mov(ALPHA, ptr[ARG_ALPHA]); - mov(B, ptr[ARG_B]); -#endif - - mov(M, qword[M]); - mov(N, qword[N]); - mov(LDA, qword[LDA]); - lea(LDA3, ptr[LDA+LDA*2]); - sub(A, -128); - sub(B, -128); - cmp(N, 0x8); - jl(l2a4, T_NEAR); - align(4); - -L(l20); - mov(A1, A); - add(A, 0x8); - pxor(xmm8, xmm8); - pxor(xmm9, xmm9); - mov(I, M); - sar(I, 0x3); - jle(l15c, T_NEAR); - align(4); - -L(l40); - movq(xmm0, qword[A1-0x80]); - add(A1, LDA); - movq(xmm1, qword[A1-0x80]); - add(A1, LDA); - movq(xmm2, qword[A1-0x80]); - add(A1, LDA); - movq(xmm3, qword[A1-0x80]); - add(A1, LDA); - punpcklbw(xmm0, xmm1); - punpcklbw(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklwd(xmm0, xmm2); - punpckhwd(xmm1, xmm2); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm8, xmm5); - pmovsxbw(xmm5, xmm1); - movhlps(xmm6, xmm1); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm9, xmm5); - movdqu(xword[B-0x80], xmm0); - movdqu(xword[B-0x70], xmm1); - movq(xmm0, qword[A1-0x80]); - add(A1, LDA); - movq(xmm1, qword[A1-0x80]); - add(A1, LDA); - movq(xmm2, qword[A1-0x80]); - add(A1, LDA); - movq(xmm3, qword[A1-0x80]); - add(A1, LDA); - punpcklbw(xmm0, xmm1); - punpcklbw(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklwd(xmm0, xmm2); - punpckhwd(xmm1, xmm2); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm8, xmm5); - pmovsxbw(xmm5, xmm1); - movhlps(xmm6, xmm1); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm9, xmm5); - movdqu(xword[B-0x60], xmm0); - movdqu(xword[B-0x50], xmm1); - sub(B, -64); - dec(I); - jg(l40, T_NEAR); - align(4); - -L(l15c); - test(M, 0x4); - jle(l1f4, T_NEAR); - movq(xmm0, qword[A1-0x80]); - add(A1, LDA); - movq(xmm1, qword[A1-0x80]); - add(A1, LDA); - movq(xmm2, qword[A1-0x80]); - add(A1, LDA); - movq(xmm3, qword[A1-0x80]); - add(A1, LDA); - punpcklbw(xmm0, xmm1); 
- punpcklbw(xmm2, xmm3); - movdqa(xmm1, xmm0); - punpcklwd(xmm0, xmm2); - punpckhwd(xmm1, xmm2); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm8, xmm5); - pmovsxbw(xmm5, xmm1); - movhlps(xmm6, xmm1); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm9, xmm5); - movdqu(xword[B-0x80], xmm0); - movdqu(xword[B-0x70], xmm1); - sub(B, -32); - align(4); - -L(l1f4); - test(M, 0x2); - jle(l248, T_NEAR); - movq(xmm0, qword[A1-0x80]); - add(A1, LDA); - movq(xmm1, qword[A1-0x80]); - add(A1, LDA); - punpcklbw(xmm0, xmm1); - pmovsxbw(xmm5, xmm0); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm8, xmm5); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm6, xmm6); - pmovsxwd(xmm6, xmm6); - paddd(xmm9, xmm6); - movdqu(xword[B-0x80], xmm0); - sub(B, -16); - align(4); - -L(l248); - test(M, 0x1); - jle(l280, T_NEAR); - movq(xmm0, qword[A1-0x80]); - add(A1, LDA); - pmovsxbd(xmm5, xmm0); - pshufd(xmm6, xmm0, 0x55); - pmovsxbd(xmm6, xmm6); - paddd(xmm8, xmm5); - paddd(xmm9, xmm6); - movq(qword[B-0x80], xmm0); - sub(B, -8); - align(4); - -L(l280); - mov(A1, qword[ARG_BIAS]); - movdqu(xword[A1], xmm8); - movdqu(xword[A1+0x10], xmm9); - add(qword[ARG_BIAS], 0x20); - sub(N, 0x8); - cmp(N, 0x8); - jge(l20, T_NEAR); - align(4); - -L(l2a4); - cmp(N, 0x4); - jl(l468, T_NEAR); - align(4); - -L(l2b0); - mov(A1, A); - add(A, 0x4); - pxor(xmm7, xmm7); - mov(I, M); - sar(I, 0x3); - jle(l384, T_NEAR); - align(4); - -L(l2c8); - movd(xmm0, dword[A1-0x80]); - add(A1, LDA); - movd(xmm1, dword[A1-0x80]); - add(A1, LDA); - movd(xmm2, dword[A1-0x80]); - add(A1, LDA); - movd(xmm3, dword[A1-0x80]); - add(A1, LDA); - punpcklbw(xmm0, xmm1); - punpcklbw(xmm2, xmm3); - punpcklwd(xmm0, xmm2); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - movdqu(xword[B-0x80], xmm0); - movd(xmm0, dword[A1-0x80]); - add(A1, LDA); - movd(xmm1, dword[A1-0x80]); - add(A1, LDA); - movd(xmm2, dword[A1-0x80]); - add(A1, LDA); - movd(xmm3, dword[A1-0x80]); - add(A1, LDA); - punpcklbw(xmm0, xmm1); - punpcklbw(xmm2, xmm3); - punpcklwd(xmm0, xmm2); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - movdqu(xword[B-0x70], xmm0); - sub(B, -32); - dec(I); - jg(l2c8, T_NEAR); - align(4); - -L(l384); - test(M, 0x4); - jle(l3e8, T_NEAR); - movd(xmm0, dword[A1-0x80]); - add(A1, LDA); - movd(xmm1, dword[A1-0x80]); - add(A1, LDA); - movd(xmm2, dword[A1-0x80]); - add(A1, LDA); - movd(xmm3, dword[A1-0x80]); - add(A1, LDA); - punpcklbw(xmm0, xmm1); - punpcklbw(xmm2, xmm3); - punpcklwd(xmm0, xmm2); - pmovsxbw(xmm5, xmm0); - movhlps(xmm6, xmm0); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - movdqu(xword[B-0x80], xmm0); - sub(B, -16); - align(4); - -L(l3e8); - test(M, 0x2); - jle(l424, T_NEAR); - movd(xmm0, dword[A1-0x80]); - add(A1, LDA); - movd(xmm1, dword[A1-0x80]); - add(A1, LDA); - punpcklbw(xmm0, xmm1); - pmovsxbw(xmm5, xmm0); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - movq(qword[B-0x80], xmm0); - sub(B, -8); - align(4); - -L(l424); - test(M, 0x1); - jle(l448, T_NEAR); - movd(xmm0, dword[A1-0x80]); - pmovsxbd(xmm5, xmm0); - paddd(xmm7, xmm5); - movd(dword[B-0x80], xmm0); - sub(B, -4); - 
align(4); - -L(l448); - mov(A1, qword[ARG_BIAS]); - movdqu(xword[A1], xmm7); - add(qword[ARG_BIAS], 0x10); - sub(N, 0x4); - cmp(N, 0x4); - jge(l2b0, T_NEAR); - align(4); - -L(l468); - cmp(N, 0x2); - jl(l646, T_NEAR); - align(4); - -L(l474); - mov(A1, A); - add(A, 0x2); - pxor(xmm7, xmm7); - mov(LDA3, M); - sar(LDA3, 0x3); - jle(l550, T_NEAR); - align(4); - -L(l48c); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm0, eax, 0x0); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm1, eax, 0x0); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm2, eax, 0x0); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm3, eax, 0x0); - punpcklbw(xmm0, xmm1); - punpcklbw(xmm2, xmm3); - punpcklwd(xmm0, xmm2); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm1, eax, 0x0); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm2, eax, 0x0); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm3, eax, 0x0); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm4, eax, 0x0); - punpcklbw(xmm1, xmm2); - punpcklbw(xmm3, xmm4); - punpcklwd(xmm1, xmm3); - punpcklqdq(xmm0, xmm1); - pshufd(xmm6, xmm0, 0xd8); - pmovsxbw(xmm5, xmm6); - movhlps(xmm6, xmm6); - pmovsxbw(xmm6, xmm6); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - movdqu(xword[B-0x80], xmm0); - sub(B, -16); - dec(LDA3); - jg(l48c, T_NEAR); - align(4); - -L(l550); - test(M, 0x4); - jle(l5bc, T_NEAR); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm0, eax, 0x0); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm1, eax, 0x0); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm2, eax, 0x0); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm3, eax, 0x0); - punpcklbw(xmm0, xmm1); - punpcklbw(xmm2, xmm3); - punpcklwd(xmm0, xmm2); - pmovsxbw(xmm5, xmm0); - phaddw(xmm5, xmm5); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - movq(qword[B-0x80], xmm0); - sub(B, -8); - align(4); - -L(l5bc); - test(M, 0x2); - jle(l600, T_NEAR); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm0, eax, 0x0); - mov(ax, word[A1-0x80]); - add(A1, LDA); - pinsrw(xmm1, eax, 0x0); - punpcklbw(xmm0, xmm1); - pmovsxbw(xmm5, xmm0); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - movd(dword[B-0x80], xmm0); - sub(B, -4); - align(4); - -L(l600); - test(M, 0x1); - jle(l628, T_NEAR); - mov(ax, word[A1-0x80]); - pinsrw(xmm0, eax, 0x0); - pmovsxbd(xmm5, xmm0); - paddd(xmm7, xmm5); - mov(word[B-0x80], ax); - sub(B, -2); - align(4); - -L(l628); - mov(A1, qword[ARG_BIAS]); - movq(qword[A1], xmm7); - add(qword[ARG_BIAS], 0x8); - sub(N, 0x2); - cmp(N, 0x2); - jge(l474, T_NEAR); - align(4); - -L(l646); - cmp(N, 0x1); - jl(l7e8, T_NEAR); - align(4); - -L(l650); - mov(A1, A); - add(A, 0x1); - pxor(xmm7, xmm7); - mov(LDA3, M); - sar(LDA3, 0x3); - jle(l700, T_NEAR); - align(4); - -L(l668); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x0); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x1); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x2); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x3); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x4); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x5); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x6); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x7); - pmovsxbw(xmm5, xmm0); - phaddw(xmm5, xmm6); - phaddw(xmm5, xmm5); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - 
movq(qword[B-0x80], xmm0); - sub(B, -8); - dec(LDA3); - jg(l668, T_NEAR); - align(4); - -L(l700); - test(M, 0x4); - jle(l760, T_NEAR); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x0); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x1); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x2); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x3); - pmovsxbw(xmm5, xmm0); - phaddw(xmm5, xmm5); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - movd(dword[B-0x80], xmm0); - sub(B, -4); - align(4); - -L(l760); - test(M, 0x2); - jle(l7a4, T_NEAR); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x0); - mov(byte[B-0x80], al); - mov(al, byte[A1-0x80]); - add(A1, LDA); - pinsrb(xmm0, eax, 0x1); - pmovsxbw(xmm5, xmm0); - phaddw(xmm5, xmm5); - pmovsxwd(xmm5, xmm5); - paddd(xmm7, xmm5); - mov(byte[B-0x7f], al); - sub(B, -2); - align(4); - -L(l7a4); - test(M, 0x1); - jle(l7c8, T_NEAR); - mov(al, byte[A1-0x80]); - pinsrw(xmm0, eax, 0x0); - pmovsxbd(xmm5, xmm0); - paddd(xmm7, xmm5); - mov(byte[B-0x80], al); - sub(B, -1); - align(4); - -L(l7c8); - mov(A1, qword[ARG_BIAS]); - movd(dword[A1], xmm7); - add(qword[ARG_BIAS], 0x4); - sub(N, 0x1); - cmp(N, 0x1); - jge(l650, T_NEAR); - align(4); - -L(l7e8); - - postamble(); -} -outLocalLabel(); - -#undef M -#undef N -#undef A -#undef LDA -#undef ALPHA -#undef B -#undef I -#undef A1 -#undef A2 -#undef LDA3 -#ifdef _WIN32 -#undef ARG_ALPHA -#undef ARG_B -#endif -#undef ARG_BIAS -} - -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.cpp deleted file mode 100644 index 4fc11afcb..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.cpp +++ /dev/null @@ -1,116 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include - -#include "math_utils.hpp" -#include "mkldnn_thread.hpp" -#include "utils.hpp" - -#include "../f32/ref_gemm_f32.hpp" -#include "jit_generator.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -template -mkldnn_status_t ref_gemm_s8x8s32(const char *transa, const char *transb, - const char *offsetc, const int *M, const int *N, const int *K, - const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao, - const b_dt *B, const int *LDB, const int8_t *bo, const float *beta, - int32_t *C, const int *LDC, const int32_t *co) { - - if (*M == 0 || *N == 0 || *K == 0) - return mkldnn_success; - - bool OCisR = (*offsetc == 'R' || *offsetc == 'r'); - bool OCisC = (*offsetc == 'C' || *offsetc == 'c'); - bool AisN = (*transa == 'N' || *transa == 'n'); - bool BisN = (*transb == 'N' || *transb == 'n'); - - int m = *M, n = *N, k = *K, lda = *LDA, ldb = *LDB, ldc = *LDC; - size_t sizeA = AisN ? 
lda * k : lda * m; - size_t sizeB = BisN ? ldb * n : ldb * k; - size_t sizeC = ldc * n; - - double *dA = (double *)malloc(sizeA * sizeof(double), PAGE_4K); - double *dB = (double *)malloc(sizeB * sizeof(double), PAGE_4K); - double *dC = (double *)malloc(sizeC * sizeof(double), PAGE_4K); - - if (utils::any_null(dA, dB, dC)) { - free(dA); - free(dB); - free(dC); - return mkldnn_out_of_memory; - } - - auto da_setter = [=] (int i, int j, double v) { dA[j * lda + i] = v; }; - auto db_setter = [=] (int i, int j, double v) { dB[j * ldb + i] = v; }; - - auto ia_accessor = [=] (int i, int j) { return A[j * lda + i]; }; - auto ib_accessor = [=] (int i, int j) { return B[j * ldb + i]; }; - - const int a_rows = AisN ? m : k; - const int a_cols = AisN ? k : m; - mkldnn::impl::parallel_nd(a_cols, a_rows, [&](int j, int i) { - da_setter(i, j, - static_cast(ia_accessor(i, j)) + static_cast(ao[0])); - }); - - const int b_rows = BisN ? k : n; - const int b_cols = BisN ? n : k; - mkldnn::impl::parallel_nd(b_cols, b_rows, [&](int j, int i) { - db_setter(i, j, - static_cast(ib_accessor(i, j)) + static_cast(bo[0])); - }); - double one = 1.0, zero = 0.0; - ref_gemm(transa, transb, M, N, K, &one, dA, LDA, dB, LDB, &zero, - dC, LDC, nullptr); - - auto i2d = [=] (int32_t v) { return static_cast(v); }; - auto f2d = [=] (float v) { return static_cast(v); }; - - mkldnn::impl::parallel_nd(n, m, [&] (int j, int i) { - double coffset = OCisR ? i2d(co[j]) : OCisC ? i2d(co[i]) : i2d(co[0]); - double val = ((*beta == 0.0f) ? 0.0 : f2d(*beta) * i2d(C[i + j * ldc])) - + f2d(*alpha) * dC[i + j * ldc] + coffset; - C[i + j * ldc] = math::out_round(math::saturate(val)); - }); - - free(dA); - free(dB); - free(dC); - return mkldnn_success; -} - -template mkldnn_status_t ref_gemm_s8x8s32( - const char *transa, const char *transb, const char *offsetc, - const int *M, const int *N, const int *K, - const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao, - const uint8_t *B, const int *LDB, const int8_t *bo, - const float *beta, int32_t *C, const int *LDC, const int32_t *co); - -template mkldnn_status_t ref_gemm_s8x8s32( - const char *transa, const char *transb, const char *offsetc, - const int *M, const int *N, const int *K, - const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao, - const int8_t *B, const int *LDB, const int8_t *bo, - const float *beta, int32_t *C, const int *LDC, const int32_t *co); - -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.hpp deleted file mode 100644 index 6c0370ae9..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.hpp +++ /dev/null @@ -1,38 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
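ref_gemm_s8x8s32 above sidesteps integer-overflow concerns by widening: it copies A and B into double buffers while adding the ao/bo offsets, runs the double-precision reference GEMM, and only then folds in alpha, beta and the offsetc vector with saturation and rounding back to int32. A scalar sketch of that final per-element update (a hedged reconstruction of the out_round/saturate step, not the deleted math helpers):

#include <algorithm>
#include <cmath>
#include <cstdint>

// One element of the final pass in ref_gemm_s8x8s32: dC_ij is the double
// product of the offset-shifted operands, co_val is co[j], co[i] or co[0]
// depending on offsetc ('R', 'C' or fixed).
static int32_t finalize_c(double alpha, double dC_ij, double beta,
        int32_t c_old, int32_t co_val) {
    double val = (beta == 0.0 ? 0.0 : beta * (double)c_old)
            + alpha * dC_ij + (double)co_val;
    val = std::min(val, (double)INT32_MAX);   // saturate to the int32 range
    val = std::max(val, (double)INT32_MIN);
    return (int32_t)std::nearbyint(val);      // round back to int32
}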
-*******************************************************************************/ - -#ifndef REF_GEMM_S8X8S32_HPP -#define REF_GEMM_S8X8S32_HPP - -#include - -#include "mkldnn_types.h" - -namespace mkldnn { -namespace impl { -namespace cpu { - -template -mkldnn_status_t ref_gemm_s8x8s32(const char *transa, const char *transb, - const char *offsetc, const int *M, const int *N, const int *K, - const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao, - const b_dt *B, const int *LDB, const int8_t *bo, const float *beta, - int32_t *C, const int *LDC, const int32_t *co); - -} -} -} -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.cpp deleted file mode 100644 index de1035f3b..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.cpp +++ /dev/null @@ -1,180 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include "common.hpp" -#include "nstl.hpp" -#include "math_utils.hpp" - -#include "../gemm.hpp" -#include "jit_avx512_core_gemm_s8u8s32.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -void compensation_init(const char *offsetC, int32_t *compensation, int len, - const int32_t *oc) { - bool OCisC = (*offsetC == 'C' || *offsetC == 'c'); - bool OCisF = (*offsetC == 'F' || *offsetC == 'f'); - - if (OCisF && (*oc) != 0) { - for (int i = 0; i < len; i++) - compensation[i] = *oc; - } else if (OCisC) { - for (int i = 0; i < len; i++) - compensation[i] = oc[i]; - } else { - parallel_nd(len, [=](int i) { compensation[i] = 0; }); - } -} - -void compensation_compute(bool transa, int m, int k, float alpha, - const int8_t *a, int lda, int32_t *compensation) { - if (!transa) { - const int L2_cache_size = get_cache_size(2, true); - const int blocking_factor = nstl::min(k, L2_cache_size / lda + 1); - const int npanels = k / blocking_factor; - const bool has_tile = k % blocking_factor > 0; - - parallel_nd(npanels, m, [&](int j, int i) { - int32_t val = 0; - for (int jb = 0; jb < blocking_factor; jb++) { - val += a[(i + (ptrdiff_t)j * blocking_factor * lda) - + (ptrdiff_t)jb * lda]; - } - if (alpha != 1.0f) { - val = math::out_round(math::saturate( - (double)val * alpha * -128.0)); - } else { - val *= -128; - } - fetch_and_add(&compensation[i], val); - }); - - if (has_tile) { - parallel_nd(m, [=](int i) { - int32_t val = 0; - for (int j = npanels * blocking_factor; j < k; j++) { - val += a[i + (ptrdiff_t)j * lda]; - } - if (alpha != 1.0f) { - val = math::out_round(math::saturate( - (double)val * alpha * -128.0)); - } else { - val *= -128; - } - fetch_and_add(&compensation[i], val); - }); - } - } else { - parallel_nd(m, [=](int i) { - int32_t val = 0; - for (int j = 0; j < k; j++) { - val += a[j + (ptrdiff_t)i * lda]; - } - if (alpha != 1.0f) { - val = 
math::out_round(math::saturate( - (double)val * alpha * -128.0)); - } else { - val *= -128; - } - compensation[i] += val; - }); - } -} - -void copy_and_shift_b(bool transb, int k, int n, uint8_t *b_u8, int ldb_u8, - const int8_t *b_s8, int ldb_s8) { - const int b_cols = transb ? k : n; - - parallel_nd(b_cols, [=](int j) { - const int b_rows = transb ? n : k; - - uint8_t *pb_u8 = b_u8 + j * ldb_u8; - const int8_t *pb_s8 = b_s8 + j * ldb_s8; - - for (int i = 0; i < b_rows; i++) { - (*pb_u8) = (*pb_s8) + 128; - pb_u8++; - pb_s8++; - } - }); -} - -/** - * gemm_s8s8s32 operation is defined as follows: - * C = alpha * op(A) * (op(B) + B_shift) + beta * C + C_offset + compensation - * - * where - * - compensation is a vector of length m that contains computed compensation - * that may contain C_offset if applicable. The compensation is applied inside - * gemm_s8u8s32 as a C_offset - * - B_shift is a k-by-n matrix, every element of B_shift is equal to 128 - * - * What is the compensation: - * In order to prepare the matrix B for gemm_s8u8s32 call the B_shift is applied: - * C = alpha * op(A) * (op(B) + B_shift) + beta * C + C_offset = - * alpha * op(A) * op(B) + alpha * op(A) * B_shift + beta * C + C_offset - * compensation = -alpha * op(A) * B_shift - * Since B_shift is a matrix, every element of which is equal to 128 then - * - if op(A) = A: compensation contains sum of the elements in each row - * scaled by -128 * alpha - * - if op(A) = A**T: compensation contains sum of the elements in each column - * scaled by -128 * alpha - * - * The rest of parameters is described in mkldnn.h - */ -mkldnn_status_t simple_gemm_s8s8s32( - const char *transA, const char *transB, const char *offsetC, - const int *m, const int *n, const int *k, - const float *alpha, const int8_t *a, const int *lda, const int8_t *oa, - const int8_t *b, const int *ldb, const int8_t *ob, - const float *beta, int32_t *c, const int *ldc, const int32_t *oc) { - if (*oa != 0 || *ob != 0) return mkldnn_unimplemented; - - int M = *m, N = *n, K = *k; - bool transa = (*transA == 'T' || *transA == 't'); - bool transb = (*transB == 'T' || *transB == 't'); - int ld = transb ? N : K; - - uint8_t *b_u8 = (uint8_t *)malloc(sizeof(uint8_t) * K * N, 64); - int32_t *compensation = (int32_t *)malloc(sizeof(int32_t) * M, 64); - - if (utils::any_null(b_u8, compensation)) { - free(b_u8); - free(compensation); - return mkldnn_out_of_memory; - } - - compensation_init(offsetC, compensation, M, oc); - compensation_compute(transa, M, K, *alpha, a, *lda, compensation); - copy_and_shift_b(transb, K, N, b_u8, ld, b, *ldb); - - gemm_s8x8s32(transA, transB, "C", m, n, k, alpha, a, lda, oa, b_u8, - &ld, ob, beta, c, ldc, compensation); - - if ((*offsetC == 'R' || *offsetC == 'r')) - parallel_nd(M, N, - [=](int i, int j) { c[i + (ptrdiff_t)j * *ldc] += oc[j]; }); - - free(b_u8); - free(compensation); - - return mkldnn_success; -} -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.hpp deleted file mode 100644 index 03a3d2f7e..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.hpp +++ /dev/null @@ -1,37 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. 
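Because every entry of B_shift is 128, the compensation described in the comment above collapses to row (or column) sums of op(A) scaled by -128 * alpha, which is exactly what compensation_compute accumulates block by block. A tiny self-contained check of that identity for a 2x3 non-transposed A with alpha = 1 (illustrative only):

#include <cstdint>
#include <cstdio>

int main() {
    const int m = 2, k = 3;
    const int8_t A[2][3] = { { 1, -2, 3 }, { 4, 0, -5 } };
    int32_t comp[2];
    for (int i = 0; i < m; ++i) {
        int32_t s = 0;
        for (int p = 0; p < k; ++p)
            s += A[i][p];
        comp[i] = -128 * s;   // -alpha * 128 * (row sum of A), alpha == 1
    }
    // comp == { -256, 128 }; adding it to alpha*A*(B+128) recovers alpha*A*B
    // for any B, which is why gemm_s8u8s32 can consume it as a C offset.
    std::printf("comp = { %d, %d }\n", comp[0], comp[1]);
    return 0;
}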
-* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef SIMPLE_GEMM_S8S8S32_HPP -#define SIMPLE_GEMM_S8S8S32_HPP - -#include -#include "mkldnn_types.h" - -namespace mkldnn { -namespace impl { -namespace cpu { - -mkldnn_status_t simple_gemm_s8s8s32( - const char *transA, const char *transB, const char *offsetC, - const int *m, const int *n, const int *k, - const float *alpha, const int8_t *a, const int *lda, const int8_t *oa, - const int8_t *b, const int *ldb, const int8_t *ob, - const float *beta, int32_t *c, const int *ldc, const int32_t *oc); -} -} -} - -#endif // SIMPLE_GEMM_S8S8S32_HPP diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution.cpp deleted file mode 100644 index 604a728b4..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution.cpp +++ /dev/null @@ -1,307 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#include "mkldnn_types.h" - -#include "c_types_map.hpp" -#include "gemm_convolution.hpp" -#include "utils.hpp" -#include "type_helpers.hpp" -#include "mkldnn_thread.hpp" -#include "ref_eltwise.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace mkldnn::impl::status; -using namespace mkldnn::impl::memory_tracking::names; -using namespace mkldnn::impl::utils; - -void gemm_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const { - auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); - auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); - auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS); - auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); - - auto col = scratchpad(ctx).get(key_conv_gemm_col); - - const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_; - - const int M = jcp.os * jcp.od; - const size_t src_step = jcp.ic * jcp.ih * jcp.iw * jcp.id; - const size_t dst_step = jcp.oc * M; - const size_t weights_g_size = jcp.ic * jcp.oc * jcp.ks; - - assert(IMPLICATION( - jcp.id != 1, jcp.oh_block == jcp.oh && jcp.ow_block == jcp.ow)); - assert(IMPLICATION(jcp.ow_block != jcp.ow, jcp.oh_block == 1)); - - const int K = jcp.ic * jcp.ks; - const int N = jcp.oc; - - if (jcp.im2col_sz && jcp.id != 1) - parallel_nd(jcp.im2col_sz * jcp.nthr, - [&](ptrdiff_t i) { col[i] = (data_t)0; }); - - const int nb_oh = div_up(jcp.oh, jcp.oh_block); - const int nb_ow = div_up(jcp.ow, jcp.ow_block); - const size_t work_amount = jcp.ngroups * jcp.mb * jcp.od * nb_oh * nb_ow; - parallel(jcp.nthr, [&](const int ithr, const int nthr) { - data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz; - - int g{ 0 }, n{ 0 }, od{ 0 }, ohb{ 0 }, owb{ 0 }; - size_t start = 0, end = 0; - - balance211(work_amount, nthr, ithr, start, end); - nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb, od, jcp.od, ohb, - nb_oh, owb, nb_ow); - for (size_t iwork = start; iwork < end; ++iwork) { - int oh = ohb * jcp.oh_block; - int ow = owb * jcp.ow_block; - const data_t *_src = src + (n * jcp.ngroups + g) * src_step; - const data_t *_weights = weights + g * weights_g_size; - data_t *_dst_im = dst + (n * jcp.ngroups + g) * dst_step; - const int h_step = nstl::min(jcp.oh_block, jcp.oh - oh); - const int w_step = nstl::min(jcp.ow_block, jcp.ow - ow); - if (jcp.im2col_sz) { - if (jcp.id == 1) - jit_gemm_convolution_utils::im2col( - jcp, _src, _col, oh, h_step, ow, w_step); - else - jit_gemm_convolution_utils::im2col_3d(jcp, _src, _col, od); - } - - const data_t one = 1.0; - - const int m = h_step * w_step; - const int LDA = jcp.im2col_sz ? m : M; - data_t *_dst = _dst_im + od * jcp.os + oh * jcp.ow + ow; - - extended_sgemm("N", "N", &m, &N, &K, &one, - jcp.im2col_sz ? _col : _src + od * m, &LDA, _weights, &K, - &this->beta_, _dst, &M); - - data_t *d = _dst; - if (eltwise_) { - // fast branch for ReLU case - if (eltwise_->alg_ == alg_kind::eltwise_relu) { - parallel_nd(jcp.oc, [&](const int oc) { - data_t b = jcp.with_bias ? bias[g * jcp.oc + oc] : 0; - data_t *d_ = d + oc * M; - PRAGMA_OMP_SIMD() - for (int oS = 0; oS < m; ++oS) { - d_[oS] += b; - if (d_[oS] < 0) d_[oS] *= eltwise_->alpha_; - } - }); - } else { - parallel_nd(jcp.oc, [&](const int oc) { - data_t b = jcp.with_bias ? 
bias[g * jcp.oc + oc] : 0; - data_t *d_ = d + oc * M; - PRAGMA_OMP_SIMD() - for (int oS = 0; oS < m; ++oS) { - d_[oS] += b; - d_[oS] = eltwise_->compute_scalar(d_[oS]); - } - }); - } - } else if (jcp.with_bias) { - parallel_nd(jcp.oc, [&](const int oc) { - data_t b = bias[g * jcp.oc + oc]; - data_t *d_ = d + oc * M; - PRAGMA_OMP_SIMD() - for (int oS = 0; oS < m; ++oS) { - d_[oS] += b; - } - }); - } - nd_iterator_step(g, jcp.ngroups, n, jcp.mb, od, jcp.od, ohb, nb_oh, - owb, nb_ow); - } - }); -} - -void gemm_convolution_bwd_data_t::execute_backward_data( - const exec_ctx_t &ctx) const { - auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); - auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); - auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); - - auto col = scratchpad(ctx).get(key_conv_gemm_col); - - const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_; - - const int M = jcp.os * jcp.od; - const size_t src_step = jcp.ic * jcp.ih * jcp.iw * jcp.id; - const size_t dst_step = jcp.oc * M; - const size_t weights_g_size = jcp.ic * jcp.oc * jcp.ks; - - const int m = jcp.os; - const int K = jcp.oc; - const int N = jcp.ic * jcp.ks; - const int LDC = jcp.im2col_sz ? m : M; - - const size_t work_amount = (size_t)jcp.ngroups * jcp.mb; - - if (jcp.id > 1) { - const ptrdiff_t diff_src_sz = (ptrdiff_t)(work_amount * src_step); - parallel_nd(diff_src_sz, [&](ptrdiff_t i) { diff_src[i] = (data_t)0; }); - } - - parallel(jcp.nthr, [&](const int ithr, const int nthr) { - data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz; - - int g{0}, n{0}; - size_t start = 0, end = 0; - balance211(work_amount, nthr, ithr, start, end); - nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb); - for (size_t iwork = start; iwork < end; ++iwork) { - - data_t *_diff_src = diff_src + (n * jcp.ngroups + g)*src_step; - const data_t *_weights = weights + g * weights_g_size; - for (int od = 0; od < jcp.od; ++od) { - const data_t *_diff_dst = diff_dst + (n * jcp.ngroups + g) - *dst_step + od * m; - - const data_t zero = 0.0, one = 1.0; - extended_sgemm("N", "T", &m, &N, &K, &one, _diff_dst, &M, - _weights, &N, &zero, - jcp.im2col_sz ? _col:_diff_src + od * m, &LDC); - - if (jcp.im2col_sz) { - if (jcp.id == 1) - jit_gemm_convolution_utils::col2im(jcp, _col, - _diff_src); - else - jit_gemm_convolution_utils::col2im_3d(jcp, _col, - _diff_src, od); - } - } - nd_iterator_step(g, jcp.ngroups, n, jcp.mb); - } - }); -} - -void gemm_convolution_bwd_weights_t::execute_backward_weights( - const exec_ctx_t &ctx) const { - auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); - auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); - auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS); - auto diff_bias = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS); - - auto col = scratchpad(ctx).get(key_conv_gemm_col); - auto wei_reduction = scratchpad(ctx).get(key_conv_wei_reduction); - - const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_; - - const int K = jcp.os * jcp.od; - const size_t src_step = jcp.ic * jcp.ih * jcp.iw * jcp.id; - const size_t dst_step = jcp.oc * K; - const size_t weights_g_size = jcp.ic * jcp.oc * jcp.ks; - - const int k = jcp.os; - const int N = jcp.oc; - const int M = jcp.ic * jcp.ks; - const int LDA = jcp.im2col_sz ? 
k : K; - - parallel_nd(jcp.im2col_sz * jcp.nthr, - [&](ptrdiff_t i) { col[i] = (data_t)0; }); - - parallel(jcp.nthr, [&](const int ithr, const int nthr) { - int ithr_g, nthr_g, ithr_mb, nthr_mb; - size_t g_start{0}, g_end{0}, mb_start{0}, mb_end{0}; - - const int mb_for_balance = jcp.need_wei_reduction ? jcp.mb : 1; - jit_gemm_convolution_utils::bwd_weights_balance(ithr, nthr, jcp.ngroups, - mb_for_balance, ithr_g, nthr_g, ithr_mb, nthr_mb); - - assert(IMPLICATION(!jcp.need_wei_reduction, nthr_mb == 1)); - const int need_reduction = nthr_mb != 1; - - if (ithr_g != -1 && ithr_mb != -1) { - balance211((size_t)jcp.ngroups, nthr_g, ithr_g, g_start, g_end); - balance211((size_t)jcp.mb, nthr_mb, ithr_mb, mb_start, mb_end); - - assert(IMPLICATION((g_end - g_start) > 1, need_reduction == 0)); - - data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz; - data_t *weights_reduce_base = wei_reduction - + ithr_g * nthr_mb * weights_g_size; - data_t *weights_reduce = weights_reduce_base - + ithr_mb * weights_g_size; - - for (size_t g = g_start; g < g_end; ++g) { - data_t *_diff_weights = need_reduction - ? weights_reduce : (diff_weights + g * weights_g_size); - for (size_t mb = mb_start; mb < mb_end; ++mb) { - const data_t *_src = src + (mb*jcp.ngroups+g)*src_step; - for (int od = 0; od < jcp.od; ++od) { - const data_t *_diff_dst = diff_dst - + (mb*jcp.ngroups+g)*dst_step + od * k; - - if (jcp.im2col_sz) { - if (jcp.id == 1) - jit_gemm_convolution_utils::im2col( - jcp, _src, _col, 0, jcp.oh, 0, jcp.ow); - else - jit_gemm_convolution_utils::im2col_3d(jcp, _src, - _col, od); - } - - const data_t zero = 0.0, one = 1.0; - extended_sgemm( - "T", "N", &M, &N, &k, &one, - jcp.im2col_sz ? _col : _src + od * k, - &LDA, _diff_dst, &K, - mb == mb_start && od == 0 ? &zero : &one, - _diff_weights, &M); - } - } - } - if (need_reduction) { - mkldnn_thr_barrier(); - data_t *weights_base = diff_weights + g_start * weights_g_size; - jit_gemm_convolution_utils::bwd_weights_reduction_par( - ithr_mb, nthr_mb, jcp, weights_reduce_base, weights_base); - } - } else - if (need_reduction) { mkldnn_thr_barrier(); } - }); - - if (jcp.with_bias) { - parallel_nd(jcp.ngroups, jcp.oc, [&](int g, int oc) { - data_t db = 0; - size_t offset_ = (size_t)g * dst_step + (size_t)oc * K; - for (int mb = 0; mb < jcp.mb; ++mb) - { - size_t offset = offset_ + (size_t)mb * jcp.ngroups * dst_step; - for (int od = 0; od < jcp.od; ++od) - for (int oh = 0; oh < jcp.oh; ++oh) - PRAGMA_OMP_SIMD(reduction(+:db)) - for (int ow = 0; ow < jcp.ow; ++ow) { - db += diff_dst[offset]; - offset++; - } - } - diff_bias[g*jcp.oc+oc] = db; - }); - } -} - -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution.hpp deleted file mode 100644 index 302e46369..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution.hpp +++ /dev/null @@ -1,250 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
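The three execute_* paths above all reduce convolution to a single SGEMM over an im2col scratch buffer: forward multiplies the (ic*ks) x (oh*ow) column matrix by the weights, backward-data runs the transposed product followed by col2im, and backward-weights accumulates col x diff_dst contributions per thread before an optional reduction. A minimal scalar sketch of the forward-side mapping for a stride-1, unpadded 2D case (illustrative, not the deleted jit_gemm_convolution_utils API):

#include <cstddef>
#include <vector>

// im2col for one image/group of a stride-1, unpadded 2D convolution:
// src is ic x ih x iw, col becomes (ic*kh*kw) x (oh*ow).  A subsequent
// sgemm of weights (oc x ic*kh*kw) with col reproduces dst (oc x oh*ow),
// which is the product execute_forward computes per output block.
static void im2col_stride1_nopad(const float *src, int ic, int ih, int iw,
        int kh, int kw, std::vector<float> &col) {
    const int oh = ih - kh + 1, ow = iw - kw + 1;
    col.assign((std::size_t)ic * kh * kw * oh * ow, 0.f);
    for (int c = 0; c < ic; ++c)
        for (int ki = 0; ki < kh; ++ki)
            for (int kj = 0; kj < kw; ++kj)
                for (int y = 0; y < oh; ++y)
                    for (int x = 0; x < ow; ++x) {
                        const std::size_t row =
                                ((std::size_t)c * kh + ki) * kw + kj;
                        col[row * oh * ow + (std::size_t)y * ow + x] =
                                src[((std::size_t)c * ih + (y + ki)) * iw
                                        + (x + kj)];
                    }
}

The bias and eltwise post-ops, and the col2im used by backward-data, are then simple per-element passes over the same buffers.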
-* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef CPU_JIT_GEMM_CONVOLUTION_HPP -#define CPU_JIT_GEMM_CONVOLUTION_HPP - -#include "c_types_map.hpp" -#include "memory_tracking.hpp" - -#include "gemm_convolution_utils.hpp" -#include "gemm/gemm.hpp" -#include "ref_eltwise.hpp" - -#include "cpu_convolution_pd.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -struct gemm_convolution_fwd_t: public cpu_primitive_t { - struct pd_t: public cpu_convolution_fwd_pd_t { - pd_t(engine_t *engine, - const convolution_desc_t *adesc, const primitive_attr_t *attr, - const typename pd_t::base_class *hint_fwd_pd) - : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) - , jcp_() {} - - DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_convolution_fwd_t); - - status_t init() { - bool ok = true - && is_fwd() - && set_default_alg_kind(alg_kind::convolution_direct) - && expect_data_types(data_type::f32, data_type::f32, - data_type::f32, data_type::f32, data_type::f32) - && !has_zero_dim_memory() - && set_default_formats_common(dat_tag(), wei_tag(), dat_tag()) - && post_ops_ok() - && memory_desc_matches_tag(*src_md(), dat_tag()) - && memory_desc_matches_tag(*dst_md(), dat_tag()) - && memory_desc_matches_tag(*weights_md(), wei_tag()); - if (!ok) return status::unimplemented; - - auto scratchpad = scratchpad_registry().registrar(); - return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad, - *desc(), src_md(), weights_md(0), dst_md(), - mkldnn_get_max_threads()); - } - - jit_gemm_conv_conf_t jcp_; - - protected: - format_tag_t dat_tag() const { - using namespace format_tag; - return utils::pick(ndims() - 3, ncw, nchw, ncdhw); - } - - format_tag_t wei_tag() const { - using namespace format_tag; - return with_groups() - ? utils::pick(ndims() - 3, goiw, goihw, goidhw) - : utils::pick(ndims() - 3, oiw, oihw, oidhw); - } - - bool post_ops_ok() const { - auto const &po = attr()->post_ops_; - auto is_eltwise = [&](int idx) - { return po.entry_[idx].is_eltwise(); }; - auto is_sum = [&](int idx) { return po.entry_[idx].is_sum(); }; - - switch (po.len_) { - case 0: return true; // no post_ops - case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise - case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise - default: return false; - } - return false; - } - }; - - gemm_convolution_fwd_t(const pd_t *apd) - : cpu_primitive_t(apd, true) - , eltwise_(nullptr) - { - const auto &post_ops = pd()->attr()->post_ops_; - const data_t one = 1.0, zero = 0.0; - beta_ = post_ops.find(primitive_kind::sum) >= 0 ? 
one : zero; - - const int entry_idx = post_ops.find(primitive_kind::eltwise); - if (entry_idx != -1) eltwise_ = new ref_eltwise_scalar_fwd_t( - post_ops.entry_[entry_idx].eltwise); - } - - ~gemm_convolution_fwd_t() { delete eltwise_; } - - typedef typename prec_traits::type data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_forward(ctx); - return status::success; - } - -private: - void execute_forward(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - - data_t beta_; - - ref_eltwise_scalar_fwd_t* eltwise_; -}; - -struct gemm_convolution_bwd_data_t: public cpu_primitive_t { - struct pd_t: public cpu_convolution_bwd_data_pd_t { - pd_t(engine_t *engine, - const convolution_desc_t *adesc, const primitive_attr_t *attr, - const convolution_fwd_pd_t *hint_fwd_pd) - : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) - , jcp_() {} - - DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_convolution_bwd_data_t); - - status_t init() { - bool ok = true - && desc()->prop_kind == prop_kind::backward_data - && set_default_alg_kind(alg_kind::convolution_direct) - && expect_data_types(data_type::f32, data_type::f32, - data_type::undef, data_type::f32, data_type::f32) - && !has_zero_dim_memory() - && set_default_formats_common(dat_tag(), wei_tag(), dat_tag()) - && memory_desc_matches_tag(*diff_src_md(), dat_tag()) - && memory_desc_matches_tag(*diff_dst_md(), dat_tag()) - && memory_desc_matches_tag(*weights_md(), wei_tag()); - if (!ok) return status::unimplemented; - - auto scratchpad = scratchpad_registry().registrar(); - return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad, - *desc(), diff_src_md(), weights_md(0), diff_dst_md(), - mkldnn_get_max_threads()); - } - - jit_gemm_conv_conf_t jcp_; - - protected: - format_tag_t dat_tag() const { - using namespace format_tag; - return utils::pick(ndims() - 3, ncw, nchw, ncdhw); - } - - format_tag_t wei_tag() const { - using namespace format_tag; - return with_groups() - ? 
utils::pick(ndims() - 3, goiw, goihw, goidhw) - : utils::pick(ndims() - 3, oiw, oihw, oidhw); - } - }; - - gemm_convolution_bwd_data_t(const pd_t *apd) - : cpu_primitive_t(apd, true) {} - - typedef typename prec_traits::type data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_backward_data(ctx); - return status::success; - } - -private: - void execute_backward_data(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } -}; - -struct gemm_convolution_bwd_weights_t: public cpu_primitive_t { - struct pd_t: public cpu_convolution_bwd_weights_pd_t { - pd_t(engine_t *engine, - const convolution_desc_t *adesc, - const primitive_attr_t *attr, - const convolution_fwd_pd_t *hint_fwd_pd) - : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd) - , jcp_() {} - - DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_convolution_bwd_weights_t); - - status_t init() { - bool ok = true - && desc()->prop_kind == prop_kind::backward_weights - && set_default_alg_kind(alg_kind::convolution_direct) - && expect_data_types(data_type::f32, data_type::f32, - data_type::f32, data_type::f32, data_type::f32) - && !has_zero_dim_memory() - && set_default_formats_common(dat_tag(), wei_tag(), dat_tag()) - && memory_desc_matches_tag(*src_md(), dat_tag()) - && memory_desc_matches_tag(*diff_dst_md(), dat_tag()) - && memory_desc_matches_tag(*diff_weights_md(), wei_tag()); - if (!ok) return status::unimplemented; - - auto scratchpad = scratchpad_registry().registrar(); - return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad, - *desc(), src_md(), diff_weights_md(0), diff_dst_md(), - mkldnn_get_max_threads()); - } - - jit_gemm_conv_conf_t jcp_; - - protected: - format_tag_t dat_tag() const { - using namespace format_tag; - return utils::pick(ndims() - 3, ncw, nchw, ncdhw); - } - - format_tag_t wei_tag() const { - using namespace format_tag; - return with_groups() - ? utils::pick(ndims() - 3, goiw, goihw, goidhw) - : utils::pick(ndims() - 3, oiw, oihw, oidhw); - } - }; - - gemm_convolution_bwd_weights_t(const pd_t *apd) - : cpu_primitive_t(apd, true) {} - - typedef typename prec_traits::type data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_backward_weights(ctx); - return status::success; - } - -private: - void execute_backward_weights(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } -}; - -} -} -} - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution_utils.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution_utils.cpp deleted file mode 100644 index f133b1e62..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution_utils.cpp +++ /dev/null @@ -1,771 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#include "mkldnn_types.h" - -#include "c_types_map.hpp" -#include "type_helpers.hpp" -#include "mkldnn_thread.hpp" -#include "utils.hpp" -#include "cpu_isa_traits.hpp" - -#include "gemm_convolution_utils.hpp" -#include "jit_generator.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace mkldnn::impl::status; -using namespace mkldnn::impl::utils; -using namespace prop_kind; -using namespace data_type; - -namespace jit_gemm_convolution_utils { - -void im2col_3d(const jit_gemm_conv_conf_t &jcp, const float *im, float *col, - int od) -{ - const size_t OHW = jcp.oh * jcp.ow; - const size_t im_step = jcp.ih * jcp.iw * jcp.id; - const size_t col_step = jcp.ks * OHW; - - parallel_nd(jcp.ic, [&](int ic) { - const float *__restrict im_loc = im + ic * im_step; - float *__restrict col_loc = col + ic * col_step; - int id = od * jcp.stride_d - jcp.f_pad; - for (int kd = 0; kd < jcp.kd; ++kd) { - float *__restrict col_ = col_loc + kd * jcp.kh * jcp.kw * OHW; - if (id < 0 || id >= jcp.id) { - int ih_ = -jcp.t_pad; - for (int kh = 0; kh < jcp.kh; ++kh) { - int ih = ih_; - for (int oh = 0; oh < jcp.oh; ++oh) { - if (ih < 0 || ih >= jcp.ih) { - ih += jcp.stride_h; - continue; - } - int iw_ = -jcp.l_pad; - for (int kw = 0; kw < jcp.kw; ++kw) { - int iw = iw_; - for (int ow = 0; ow < jcp.ow; ++ow) { - if (iw < 0 || iw >= jcp.iw) { - iw += jcp.stride_w; - continue; - } - - const size_t col_idx = kw * OHW + oh * jcp.ow - + ow; - - col_[col_idx] = 0; - iw += jcp.stride_w; - } - iw_ += (1 + jcp.dilate_w); - } - ih += jcp.stride_h; - } - ih_ += (1 + jcp.dilate_h); - col_ += jcp.kw * OHW; - } - } else { - const float *__restrict im_ = im_loc + id * jcp.ih * jcp.iw; - int ih_ = -jcp.t_pad; - for (int kh = 0; kh < jcp.kh; ++kh) { - int ih = ih_; - for (int oh = 0; oh < jcp.oh; ++oh) { - if (ih < 0 || ih >= jcp.ih) { - ih += jcp.stride_h; - continue; - } - int iw_ = -jcp.l_pad; - for (int kw = 0; kw < jcp.kw; ++kw) { - int iw = iw_; - for (int ow = 0; ow < jcp.ow; ++ow) { - if (iw < 0 || iw >= jcp.iw) { - iw += jcp.stride_w; - continue; - } - - const size_t col_idx = kw * OHW + oh * jcp.ow - + ow; - const size_t im_idx = ih * jcp.iw + iw; - - col_[col_idx] = im_[im_idx]; - iw += jcp.stride_w; - } - iw_ += (1 + jcp.dilate_w); - } - ih += jcp.stride_h; - } - ih_ += (1 + jcp.dilate_h); - col_ += jcp.kw * OHW; - } - } - id += (1 + jcp.dilate_d); - } - }); -} - -/* col[ic][kh][kw][oh][ow] <-- im2col(im[ic][ih][iw]) */ -void im2col(const jit_gemm_conv_conf_t &jcp, const float *__restrict im, - float *__restrict col, int hs, int hb, int ws, int wb) { - const size_t im_step = jcp.is; - const size_t col_step = jcp.ks * hb * wb; - if (jcp.stride_w == 1) { - // Generated code is more optimized for stride_w == 1 - // because innermost loop is by width - auto ker = [&](int ic, int kh, int kw, int oh) { - const float *__restrict im_ = im + ic * im_step; - float *__restrict col_ - = col + ic * col_step + ((kh * jcp.kw + kw) * hb + oh) * wb; - - const int ih = (oh + hs) * jcp.stride_h - jcp.t_pad - + kh * (1 + jcp.dilate_h); - if (ih < 0 || ih >= jcp.ih) { - for (int ow = 0; ow < wb; ++ow) - col_[ow] = 0.f; - } else { - for (int ow = 0; ow < wb; ++ow) { - const int iw = ow + ws - jcp.l_pad + kw * (1 + jcp.dilate_w); - if (iw < 0 || iw >= jcp.iw) - col_[ow] = 0.f; - else { - const size_t im_idx = ih * jcp.iw + iw; - col_[ow] = im_[im_idx]; - } - } - } - }; - - if (jcp.outer_threading) { - for (int ic = 0; ic < jcp.ic; ic++) - 
for (int kh = 0; kh < jcp.kh; kh++) - for (int kw = 0; kw < jcp.kw; kw++) - for (int oh = 0; oh < hb; oh++) - ker(ic, kh, kw, oh); - } - else { - parallel_nd(jcp.ic, jcp.kh, jcp.kw, hb, ker); - } - } else if (jcp.ic == 1) { - parallel_nd(jcp.kh, hb, [&](int kh, int oh) { - const int ih = (oh + hs) * jcp.stride_h - jcp.t_pad - + kh * (1 + jcp.dilate_h); - if (ih < 0 || ih >= jcp.ih) - for (int kw = 0; kw < jcp.kw; ++kw) { - for (int ow = 0; ow < wb; ++ow) { - const size_t col_idx - = ((kh * jcp.kw + kw) * hb + oh) * wb + ow; - col[col_idx] = 0; - } - } - else - for (int kw = 0; kw < jcp.kw; ++kw) { - for (int ow = 0; ow < wb; ++ow) { - const int iw = (ow + ws) * jcp.stride_w - jcp.l_pad - + kw * (1 + jcp.dilate_w); - const size_t col_idx - = ((kh * jcp.kw + kw) * hb + oh) * wb + ow; - const size_t im_idx = ih * jcp.iw + iw; - if (iw < 0 || iw >= jcp.iw) - col[col_idx] = 0; - else - col[col_idx] = im[im_idx]; - } - } - }); - } else { - - parallel_nd(jcp.ic, jcp.kh, jcp.kw, hb, - [&](int ic, int kh, int kw, int oh) { - const float *__restrict im_ = im + ic * im_step; - float *__restrict col_ = col + ic * col_step - + ((kh * jcp.kw + kw) * hb + oh) * wb; - - const int ih = (oh + hs) * jcp.stride_h - jcp.t_pad - + kh * (1 + jcp.dilate_h); - if (ih < 0 || ih >= jcp.ih) { - for (int ow = 0; ow < wb; ++ow) - col_[ow] = 0.f; - } else { - for (int ow = 0; ow < wb; ++ow) { - const int iw = (ow + ws) * jcp.stride_w - jcp.l_pad - + kw * (1 + jcp.dilate_w); - const size_t im_idx = ih * jcp.iw + iw; - if (iw < 0 || iw >= jcp.iw) - col_[ow] = 0.f; - else - col_[ow] = im_[im_idx]; - } - } - }); - } -} - -inline int limit(int low, int upper, int value) { - return nstl::max(low, nstl::min(upper, value)); -} - -/* col[kh][kw][ic][oh][ow] <-- im2col_u8(im[ih][iw][ic]) */ -template -void im2col_u8(const jit_gemm_conv_conf_t &jcp, const T *__restrict im, - T *__restrict imtr, uint8_t *__restrict col, int hs, int hb, int ws, - int wb) { - uint8_t shift = jcp.signed_input ? 
128 : 0; - const int dh = 1 + jcp.dilate_h; - const int dw = 1 + jcp.dilate_w; - const int sh = jcp.stride_h; - const int sw = jcp.stride_w; - const int im_iw_stride = jcp.ic * jcp.ngroups; - const int im_ih_stride = jcp.iw * im_iw_stride; - const int tp = jcp.t_pad; - const int lp = jcp.l_pad; - - if (jcp.outer_threading && sh == 1 && sw == 1 && dh == 1 && dw == 1) { - /* im[ih][iw][ic] --> imtr[ic][ih][iw] --> col[kh][kw][ic][oh][ow] */ - const int hp = hs - tp; - const int wp = ws - lp; - const int ih_start = limit(0, jcp.ih, hp); - const int ih_end = limit(0, jcp.ih, hp + hb + jcp.kh); - const int iw_start = limit(0, jcp.iw, wp); - const int iw_end = limit(0, jcp.iw, wp + wb + jcp.kw); - - const int ihb = ih_end - ih_start; - const int iwb = iw_end - iw_start; - - const int imtr_ic_stride = ihb * iwb; - const ptrdiff_t imtr_idx_shift = ih_start * iwb + iw_start; - for (int ic = 0; ic < jcp.ic; ic++) { - const ptrdiff_t imtr_idx_ic = ic * imtr_ic_stride - imtr_idx_shift; - for (int ih = ih_start; ih < ih_end; ih++) { - const ptrdiff_t im_idx_ih = ic + ih * im_ih_stride; - const ptrdiff_t imtr_idx_ih = imtr_idx_ic + ih * iwb; - for (int iw = iw_start; iw < iw_end; iw++) - imtr[imtr_idx_ih + iw] = im[im_idx_ih + iw * im_iw_stride]; - } - } - - const int col_ic_str = hb * wb; - const int col_kw_stride = jcp.ic * col_ic_str; - const int col_kh_stride = jcp.kw * col_kw_stride; - - const int oh_init = ih_start - hp; - const int ow_init = iw_start - wp; - for (int kh = 0; kh < jcp.kh; kh++) { - const ptrdiff_t col_idx_kh = kh * col_kh_stride; - const int oh_kh = oh_init - kh; - const int oh_start = limit(0, hb, oh_kh); - const int oh_end = limit(0, hb, oh_kh + ihb); - for (int kw = 0; kw < jcp.kw; kw++) { - const ptrdiff_t col_idx_kw - = col_idx_kh + kw * jcp.ic * col_ic_str; - const int ow_kw = ow_init - kw; - const int imtr_shift = oh_kh * iwb + ow_kw; - const int ow_start = limit(0, wb, ow_kw); - const int ow_end = limit(0, wb, ow_kw + iwb); - for (int ic = 0; ic < jcp.ic; ic++) { - const ptrdiff_t col_idx_ic = col_idx_kw + ic * col_ic_str; - const int imtr_idx_ic = ic * imtr_ic_stride - imtr_shift; - for (int oh = 0; oh < oh_start; oh++) { - const ptrdiff_t col_idx_oh = col_idx_ic + oh * wb; - for (int ow = 0; ow < wb; ++ow) - col[col_idx_oh + ow] = shift; - } - for (int oh = oh_start; oh < oh_end; oh++) { - const ptrdiff_t col_idx_oh = col_idx_ic + oh * wb; - const ptrdiff_t imtr_idx_oh = imtr_idx_ic + oh * iwb; - for (int ow = 0; ow < ow_start; ++ow) - col[col_idx_oh + ow] = shift; - for (int ow = ow_start; ow < ow_end; ++ow) - col[col_idx_oh + ow] - = imtr[imtr_idx_oh + ow] + shift; - for (int ow = ow_end; ow < wb; ++ow) - col[col_idx_oh + ow] = shift; - } - for (int oh = oh_end; oh < hb; oh++) { - const ptrdiff_t col_idx_oh = col_idx_ic + oh * wb; - for (int ow = 0; ow < wb; ++ow) - col[col_idx_oh + ow] = shift; - } - } - } - } - } else { - parallel_nd(jcp.kh, jcp.kw, jcp.ic, hb, - [&](int kh, int kw, int ic, int oh) { - const int hp = tp - kh * dh; - const int ih = (oh + hs) * sh - hp; - const ptrdiff_t col_idx_base - = (((kh * jcp.kw + kw) * jcp.ic + ic) * hb + oh) * wb; - if (ih < 0 || ih >= jcp.ih) - for (int ow = 0; ow < wb; ow++) - col[col_idx_base + ow] = shift; - else { - const int wp = lp - kw * dw; - const int ow_start = limit(0, wb, div_up(wp, sw) - ws); - const int ow_end - = limit(0, wb, div_up(jcp.iw + wp, sw) - ws); - for (int ow = 0; ow < ow_start; ow++) - col[col_idx_base + ow] = shift; - const int iw_base = ws * sw - wp; - const ptrdiff_t im_idx_base = ih * 
im_ih_stride + ic; - for (int ow = ow_start; ow < ow_end; ow++) { - const int iw = iw_base + ow * sw; - const ptrdiff_t im_idx - = im_idx_base + iw * im_iw_stride; - col[col_idx_base + ow] = im[im_idx] + shift; - } - for (int ow = ow_end; ow < wb; ow++) - col[col_idx_base + ow] = shift; - } - }); - } -} - -template void im2col_u8(const jit_gemm_conv_conf_t &jcp, - const int8_t *__restrict im, int8_t *__restrict imtr, - uint8_t *__restrict col, int hs, int hb, int ws, int wb); -template void im2col_u8(const jit_gemm_conv_conf_t &jcp, - const uint8_t *__restrict im, uint8_t *__restrict imtr, - uint8_t *__restrict col, int hs, int hb, int ws, int wb); - -/* im[ih][iw][ic] <-- col2im_s32(col[oh][ow][kh][kw][ic]) */ -void col2im_s32(const jit_gemm_conv_conf_t &jcp, const int32_t *__restrict col, - int32_t *__restrict im) -{ - parallel(0, [&](const int ithr, const int nthr) { - int h_nthr = nstl::min(jcp.ih, nthr); - int w_nthr = nstl::min(jcp.iw, nthr / h_nthr); - int h_ithr = 1, h_s = 0, h_e = 0, w_ithr = 1, w_s = 0, w_e = 0; - if (ithr < h_nthr * w_nthr) { - h_ithr = ithr / w_nthr; - w_ithr = ithr % w_nthr; - balance211(jcp.ih, h_nthr, h_ithr, h_s, h_e); - balance211(jcp.iw, w_nthr, w_ithr, w_s, w_e); - } else { - h_ithr = w_ithr = -ithr; - h_s = h_e = w_s = w_e = -1; - } - - for (int ih = h_s; ih < h_e; ++ih) { - for (int iw = w_s; iw < w_e; ++iw) { - PRAGMA_OMP_SIMD() - for (int ic = 0; ic < jcp.ic; ++ic) { - im[(ih * jcp.iw + iw) * jcp.ic + ic] = 0; - } - } - } - - // TODO: reduce region: [0.. oh] --> [h_s * sh .. h_e * sh] - for (int oh = 0; oh < jcp.oh; ++oh) { - for (int ow = 0; ow < jcp.ow; ++ow) { - for (int kh = 0; kh < jcp.kh; ++kh) { - const int ih = oh * jcp.stride_h - - jcp.t_pad + kh * (1 + jcp.dilate_h); - if (ih < h_s || ih >= h_e) continue; - - for (int kw = 0; kw < jcp.kw; ++kw) { - const int iw = ow * jcp.stride_w - - jcp.l_pad + kw * (1 + jcp.dilate_w); - if (iw < w_s || iw >= w_e) continue; - - const size_t col_idx = (((oh * jcp.ow + ow) * jcp.kh - + kh) * jcp.kw + kw) * jcp.ic; - const size_t im_idx - = (ih * jcp.iw + iw) * jcp.ic; - PRAGMA_OMP_SIMD() - for (int ic = 0; ic < jcp.ic; ++ic) { - im[im_idx + ic] += col[col_idx + ic]; - } - } - } - } - } - }); -} - -void col2im_3d(const jit_gemm_conv_conf_t &jcp, const float *col, float *im, - int od) -{ - parallel_nd(jcp.ic, [&](int ic) { - const float *__restrict col_ = col + (size_t)ic * jcp.ks * jcp.os; - float *__restrict im_ic = im + (size_t)ic * jcp.ih * jcp.iw * jcp.id; - - int id = od * jcp.stride_d - jcp.f_pad; - for (int kd = 0; kd < jcp.kd; ++kd) { - if (id < 0 || id >= jcp.id) { - col_ += jcp.kh * jcp.kw * jcp.os; - id += (1 + jcp.dilate_d); - continue; - } - - float *__restrict im_ = im_ic + id * jcp.ih * jcp.iw; - - for (int oh = 0; oh < jcp.oh; ++oh) { - for (int kh = 0; kh < jcp.kh; ++kh) { - const int ih = oh * jcp.stride_h - jcp.t_pad - + kh * (1 + jcp.dilate_h); - if (ih < 0 || ih >= jcp.ih) continue; - - for (int ow = 0; ow < jcp.ow; ++ow) { - for (int kw = 0; kw < jcp.kw; ++kw) { - const int iw = ow * jcp.stride_w - jcp.l_pad - + kw * (1 + jcp.dilate_w); - if (iw < 0 || iw >= jcp.iw) continue; - - const size_t col_idx = ((kh*jcp.kw + kw)*jcp.oh+oh)*jcp.ow+ow; - const size_t im_idx = ih*jcp.iw + iw; - im_[im_idx] += col_[col_idx]; - }} - }} - - col_ += jcp.kh * jcp.kw * jcp.os; - id += (1 + jcp.dilate_d); - } - }); -} - -void col2im(const jit_gemm_conv_conf_t &jcp, const float *col, float *im) { - const size_t col_step = jcp.ks * jcp.os; - const size_t im_step = jcp.ih * jcp.iw; - const int iS = jcp.ih * 
jcp.iw; - - parallel_nd(jcp.ic, [&](int ic) { - float *__restrict im_ = im + ic * im_step; - const float *__restrict col_ = col + ic * col_step; - PRAGMA_OMP_SIMD() - for (int is = 0; is < iS; ++is) im_[is] = 0.; - - for (int kh = 0; kh < jcp.kh; ++kh) { - for (int oh = 0; oh < jcp.oh; ++oh) { - const int ih = - oh * jcp.stride_h - jcp.t_pad + kh * (1 + jcp.dilate_h); - if (ih < 0 || ih >= jcp.ih) continue; - - for (int kw = 0; kw < jcp.kw; ++kw) { - for (int ow = 0; ow < jcp.ow; ++ow) { - const int iw = - ow * jcp.stride_w - jcp.l_pad + kw * (1 + jcp.dilate_w); - if (iw < 0 || iw >= jcp.iw) continue; - - const size_t col_idx = ((kh*jcp.kw + kw)*jcp.oh+oh)*jcp.ow+ow; - const size_t im_idx = ih*jcp.iw + iw; - im_[im_idx] += col_[col_idx]; - } - } - } - } - }); -} - -status_t init_conf(jit_gemm_conv_conf_t &jcp, - memory_tracking::registrar_t &scratchpad, const convolution_desc_t &cd, - const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d, - const memory_desc_wrapper &dst_d, int max_threads) { - const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; - const int ndims = src_d.ndims(); - const int is_1d = ndims == 3; - const int is_3d = ndims == 5; - - jcp.prop_kind = cd.prop_kind; - - jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; - jcp.mb = src_d.dims()[0]; - - jcp.oc = dst_d.dims()[1] / jcp.ngroups; - jcp.ic = src_d.dims()[1] / jcp.ngroups; - jcp.id = is_3d ? src_d.dims()[2] : 1; - jcp.ih = is_1d ? 1 : src_d.dims()[ndims - 2]; - jcp.iw = src_d.dims()[ndims - 1]; - jcp.od = is_3d ? dst_d.dims()[2] : 1; - jcp.oh = is_1d ? 1 : dst_d.dims()[ndims - 2]; - jcp.ow = dst_d.dims()[ndims - 1]; - - jcp.kd = is_3d ? weights_d.dims()[with_groups + 2] : 1; - jcp.kh = is_1d ? 1 : weights_d.dims()[with_groups + ndims - 2]; - jcp.kw = weights_d.dims()[with_groups + ndims - 1]; - - jcp.f_pad = is_3d ? cd.padding[0][0] : 0; - jcp.t_pad = is_1d ? 0 : cd.padding[0][ndims - 4]; - jcp.l_pad = cd.padding[0][ndims - 3]; - - jcp.stride_d = is_3d ? cd.strides[0] : 1; - jcp.stride_h = is_1d ? 1 : cd.strides[ndims - 4]; - jcp.stride_w = cd.strides[ndims - 3]; - - jcp.dilate_d = is_3d ? cd.dilates[0] : 0; - jcp.dilate_h = is_1d ? 0 : cd.dilates[ndims - 4]; - jcp.dilate_w = cd.dilates[ndims - 3]; - - jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef - || cd.diff_bias_desc.format_kind != format_kind::undef; - - jcp.is = jcp.ih * jcp.iw; - jcp.os = jcp.oh * jcp.ow; - jcp.ks = jcp.kh * jcp.kw * jcp.kd; - - jcp.signed_input = src_d.data_type() == data_type::s8; - - jcp.im2col_sz = !everyone_is(true, - jcp.ow == jcp.iw, jcp.oh == jcp.ih, jcp.od == jcp.id, - jcp.stride_w == 1, jcp.stride_h == 1, jcp.stride_d == 1, - jcp.ks == 1, !jcp.signed_input) - ? (ptrdiff_t)jcp.ic * jcp.ks * jcp.os : 0; - - jcp.outer_threading = false; - - bool is_int8_conv = utils::one_of(src_d.data_type(), s32, s8, u8) - && weights_d.data_type() == s8; - - const int vlen = mayiuse(avx512_common) - ? cpu_isa_traits::vlen - : mayiuse(avx) - ? cpu_isa_traits::vlen - : mayiuse(sse42) ? cpu_isa_traits::vlen : 4; - const int simd_w = vlen / (is_int8_conv ? 1 : 4); - - const bool is_bwd_d = jcp.prop_kind == backward_data; - const bool is_bwd_w = jcp.prop_kind == backward_weights; - const bool is_fwd = !is_bwd_d && !is_bwd_w; - jcp.oh_block = is_fwd ? jcp.oh : jcp.ih; - jcp.ow_block = is_fwd ? 
jcp.ow : jcp.iw; - - using namespace memory_tracking::names; - bool is_depthwise = jcp.ic == 1 && jcp.oc == 1 && jcp.ngroups != 1; - - // TODO: maybe mitigate blocking restriction - const int wei_size = jcp.oc * jcp.ic * jcp.kh * jcp.kw; - const int L2 = get_cache_size(2, true) - / (is_int8_conv ? sizeof(int8_t) : sizeof(float)); - bool is_blocking_applicable = true - && is_fwd && jcp.im2col_sz - && jcp.id == 1 && jcp.od == 1 - && jcp.dilate_h == 0 && jcp.dilate_w == 0 - && !is_depthwise - && wei_size < L2/2; - if (is_blocking_applicable) { - // looking for oh and ow blocking - int h_block{ jcp.oh_block }, w_block{ jcp.ow_block }; - const int ic = jcp.ic; - const int oc = jcp.oc; - const int iw = jcp.iw; - const int ow = jcp.ow; - const int oh = jcp.oh; - const int os = oh * ow; - - // 1. cache requirement - int row_size = ic * ow * jcp.ks + 2 * (ic * iw + oc * ow); - if (is_int8_conv) { - // Heuristic rule: gemm needed a lot of memory for internal usage - row_size *= 5; - // memory for accumulators - row_size += oc * ow * sizeof(uint32_t); - // memory for transposition - row_size += ic * iw; - } - - h_block = nstl::max(1, nstl::min(oh, div_up(L2, row_size))); - if (h_block == 1) { - int col_size = ic * jcp.ks + 2 * (ic + oc); - if (is_int8_conv) { - col_size *= 5; - col_size += oc * sizeof(uint32_t); - col_size += ic; - } - w_block = nstl::max(1, nstl::min(ow, div_up(L2, col_size))); - } - - // 2. threading requirement - if (h_block != oh) - h_block = nstl::max(1, rnd_dn(h_block, 4)); - if (w_block != ow) - w_block = nstl::max(1, rnd_dn(w_block, simd_w)); - - float thr_eff = 0.f; - float thr_eff_treshold = 0.9f; - if (w_block == ow) { - do { - int nb_h = div_up(oh, h_block); - size_t work = jcp.ngroups * jcp.mb * jcp.od * nb_h; - float disb = (float)oh / rnd_up(oh, h_block); - thr_eff = (float)work / rnd_up(work, max_threads); - thr_eff = (thr_eff + disb) / 2.f; - if (thr_eff >= thr_eff_treshold) - break; - h_block = rnd_dn(h_block - 4, 4); - } while (h_block > 0); - } - if (thr_eff < thr_eff_treshold) // we didn't find suitable h_block - { - h_block = 1; - int nb_h = oh; - do { - int nb_w = div_up(ow, w_block); - size_t work_amount = jcp.ngroups * jcp.mb * nb_h * nb_w; - float disb = (float)ow / rnd_up(ow, w_block); - thr_eff = (float)work_amount / rnd_up(work_amount, max_threads); - thr_eff = (thr_eff + disb) / 2.f; - if (thr_eff > thr_eff_treshold) - break; - w_block = rnd_dn(w_block - simd_w, simd_w); - } while (w_block > 0); - } - h_block = nstl::max(1, h_block); - w_block = nstl::max(1, w_block); - const size_t inner_work = div_up(os, simd_w) * div_up(oc, simd_w); - const float inner_thr_eff - = (float)inner_work / rnd_up(inner_work, max_threads); - if (thr_eff >= inner_thr_eff / 2 && h_block > 0 && w_block > 0) { - jcp.oh_block = h_block; - jcp.ow_block = w_block; - jcp.outer_threading = true; - } - // updating jcp.im2col_sz - if (jcp.oh_block != 1) - jcp.ow_block = ow; - jcp.im2col_sz = (ptrdiff_t)ic * jcp.ks * jcp.oh_block * jcp.ow_block; - } - // For threading selection in bwd_d we do: - // 1. Rough estimation of efficiency for inner and outer threading. - // 2. Gemm size estimation in assumption that it does not work - // so effectively for small sizes. - // 64K - this is heuristic gemm size per thread threshold. 
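
The comment above summarizes the heuristic applied in the code that follows: estimate how evenly the outer work (groups × minibatch) and the inner work (tiles inside a single GEMM) would spread across the available threads, and prefer outer threading when it balances at least as well, or when each thread's share of the GEMM would fall below the 64K-element threshold. Below is a minimal standalone sketch of that decision, not the mkl-dnn code itself; the helper names (thread_efficiency, prefer_outer_threading) and the example sizes are illustrative assumptions.

    #include <cstddef>
    #include <cstdio>

    // Fraction of threads doing useful work when `work` items are split across
    // `nthreads`: work / round_up(work, nthreads). 1.0 means a perfectly even
    // split; lower values mean idle threads in the last wave.
    static double thread_efficiency(std::size_t work, int nthreads) {
        if (work == 0 || nthreads <= 0) return 1.0;
        const std::size_t n = static_cast<std::size_t>(nthreads);
        const std::size_t rounded = (work + n - 1) / n * n; // round up to a multiple of n
        return static_cast<double>(work) / static_cast<double>(rounded);
    }

    // Choose between parallelizing over independent (group, minibatch) pairs
    // ("outer" threading, one sequential GEMM per thread) and running a single
    // multithreaded GEMM ("inner" threading). Mirrors the shape of the heuristic
    // in the surrounding code, including the 64K per-thread size threshold.
    static bool prefer_outer_threading(std::size_t outer_work, // e.g. ngroups * mb
                                       std::size_t inner_work, // e.g. SIMD tiles in one GEMM
                                       std::size_t gemm_elems, // roughly M * N * K of one GEMM
                                       int max_threads) {
        if (max_threads <= 0) max_threads = 1;
        const std::size_t gemm_threshold = 64 * 1024; // heuristic per-thread GEMM size

        const double outer_eff = thread_efficiency(outer_work, max_threads);
        const double inner_eff = thread_efficiency(inner_work, max_threads);

        // Prefer one sequential GEMM per thread when that balances at least as
        // well as threading inside a single GEMM, or when each thread's share of
        // the GEMM would be too small for a multithreaded GEMM to pay off.
        return outer_eff >= inner_eff
                || gemm_elems / static_cast<std::size_t>(max_threads) < gemm_threshold;
    }

    int main() {
        // Example: 2 groups, batch of 16, a 4096 x 64 x 9 GEMM per image, 8 threads.
        const bool outer = prefer_outer_threading(2 * 16, (4096 / 16) * (64 / 16),
                                                  4096u * 64u * 9u, 8);
        std::printf("use outer threading: %s\n", outer ? "yes" : "no");
        return 0;
    }
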
- const int gemm_thrld = 64 * 1024; - - if (is_int8_conv) { - if (is_fwd) { - if (!jcp.outer_threading) { - bool is_depthwise = jcp.ic == 1 && jcp.oc == 1 && jcp.ngroups != 1; - const size_t outer_work = jcp.ngroups * jcp.mb; - const float outer_thr_eff - = (float)outer_work / rnd_up(outer_work, max_threads); - const size_t inner_work - = div_up(jcp.is, simd_w) * div_up(jcp.ic, simd_w); - const float inner_thr_eff - = (float)inner_work / rnd_up(inner_work, max_threads); - jcp.outer_threading = (is_depthwise - || (jcp.is / max_threads < 64 && jcp.mb != 1)) - && (outer_thr_eff / inner_thr_eff >= 1.f - || (jcp.os * jcp.ic * jcp.oc) / max_threads < gemm_thrld); - } - jcp.nthr = jcp.outer_threading ? max_threads : 1; - scratchpad.book(key_conv_gemm_col, - sizeof(int8_t) * jcp.nthr * jcp.im2col_sz); - scratchpad.book(key_conv_int_dat_in_acc_dt, - sizeof(int32_t) * jcp.nthr * jcp.oh_block * jcp.ow_block * jcp.oc); - scratchpad.book(key_conv_gemm_imtr, - sizeof(int8_t) * jcp.nthr * jcp.is * jcp.ic); - } else if (is_bwd_d) { - bool is_depthwise = jcp.ic == 1 && jcp.oc == 1 && jcp.ngroups != 1; - const size_t outer_work = jcp.ngroups * jcp.mb; - const float outer_thr_eff - = (float)outer_work / rnd_up(outer_work, max_threads); - const size_t inner_work - = div_up(jcp.is, simd_w) * div_up(jcp.ic, simd_w); - const float inner_thr_eff - = (float)inner_work / rnd_up(inner_work, max_threads); - jcp.outer_threading = (is_depthwise - || (jcp.is / max_threads < 64 && jcp.mb != 1)) - && (outer_thr_eff / inner_thr_eff >= 1.f - || (jcp.is * jcp.ic * jcp.oc) / max_threads < gemm_thrld); - - jcp.nthr = jcp.outer_threading ? max_threads : 1; - scratchpad.book(key_conv_gemm_col, - sizeof(int32_t) * jcp.nthr * jcp.im2col_sz); - scratchpad.book(key_conv_int_dat_in_acc_dt, - sizeof(int32_t) * jcp.nthr * jcp.is * jcp.ic); - } else if (is_bwd_w) { - assert(!"unimplemented prop_kind"); - return status::unimplemented; - } - } else { - if (is_fwd) { - if (!jcp.outer_threading) { - const size_t outer_work_amount = jcp.ngroups * jcp.mb * jcp.od; - const float outer_thr_eff = (float)outer_work_amount - / rnd_up(outer_work_amount, max_threads); - const size_t inner_work_amount - = div_up(jcp.os, simd_w) * div_up(jcp.oc, simd_w); - const float inner_thr_eff = (float)inner_work_amount - / rnd_up(inner_work_amount, max_threads); - jcp.outer_threading = jcp.os / max_threads < 512 - && IMPLICATION(jcp.od == 1, jcp.mb != 1 || jcp.ngroups > 2) - && (outer_thr_eff / inner_thr_eff >= 1.f - || (jcp.os * jcp.ic * jcp.oc) / max_threads < gemm_thrld); - } - } else if (is_bwd_d) { - const size_t outer_work_amount = jcp.ngroups * jcp.mb; - const float outer_thr_eff = (float)outer_work_amount - / rnd_up(outer_work_amount, max_threads); - const size_t inner_work - = div_up(jcp.is, simd_w) * div_up(jcp.ic, simd_w); - const float inner_thr_eff = (float)inner_work - / rnd_up(inner_work, max_threads); - jcp.outer_threading = (jcp.os / max_threads < 512 || jcp.ks < 64) - && (jcp.mb != 1 || jcp.ngroups > 2) - && (outer_thr_eff / inner_thr_eff >= 1.f - || (jcp.is * jcp.ic * jcp.oc) / max_threads < gemm_thrld); - } else if (is_bwd_w) - jcp.outer_threading = jcp.os / max_threads < 256 - && (jcp.mb != 1 || jcp.ngroups > 2); - - jcp.nthr = jcp.outer_threading ? max_threads : 1; - scratchpad.book(key_conv_gemm_col, - sizeof(float) * jcp.nthr * jcp.im2col_sz); - - if (is_bwd_w) { - jcp.need_wei_reduction = mkldnn_thr_syncable() - ? 
jcp.mb != 1 && jcp.nthr != 1 : false; - scratchpad.book(key_conv_wei_reduction, - sizeof(float) * jcp.nthr * jcp.ngroups * weights_d.size()); - } - } - - return status::success; -} - -void bwd_weights_balance(int ithr, int nthr, int ngroups, int mb, int &ithr_g, - int &nthr_g, int &ithr_mb, int &nthr_mb) { - nthr_g = nstl::min(ngroups, nthr); - nthr_mb = nstl::min(mb, nthr / nthr_g); - if (ithr / nthr_mb >= ngroups) { - ithr_g = ithr_mb = -1; - } else { - ithr_g = ithr / nthr_mb; - ithr_mb = ithr % nthr_mb; - } -} - -void bwd_weights_reduction_par(int ithr, int nthr, - const jit_gemm_conv_conf_t &jcp, const float *weights_reduce_ws, - float *weights) { - const size_t weights_g_size = jcp.ic * jcp.oc * jcp.ks; - - size_t weights_start{0}, weights_end{0}; - balance211(weights_g_size, nthr, ithr, weights_start, weights_end); - - for (int i = 0; i < nthr; ++i) { - const float *ws_i = weights_reduce_ws + i * weights_g_size; - for (size_t s = weights_start; s < weights_end; ++s) - weights[s] = (i == 0 ? 0 : weights[s]) + ws_i[s]; - } -} - -}; - -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution_utils.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution_utils.hpp deleted file mode 100644 index e00678934..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution_utils.hpp +++ /dev/null @@ -1,66 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#ifndef CPU_JIT_GEMM_CONVOLUTION_UTILS_HPP -#define CPU_JIT_GEMM_CONVOLUTION_UTILS_HPP - -#include "c_types_map.hpp" -#include "memory_tracking.hpp" -#include "mkldnn_thread.hpp" - -#include "cpu_convolution_pd.hpp" -#include "cpu_engine.hpp" -#include "jit_primitive_conf.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -namespace jit_gemm_convolution_utils { - -void im2col_3d(const jit_gemm_conv_conf_t &jcp, const float *im, float *col, - int od); -void im2col(const jit_gemm_conv_conf_t &jcp, const float *__restrict im, - float *__restrict col, int hs, int hb, int ws, int wb); -template -void im2col_u8(const jit_gemm_conv_conf_t &jcp, const T *__restrict im, - T* __restrict imtr, uint8_t *__restrict col, - int hs, int hb, int ws, int wb); - -void col2im_s32(const jit_gemm_conv_conf_t &jcp, const int32_t *__restrict col, - int32_t *__restrict im); -void col2im_3d(const jit_gemm_conv_conf_t &jcp, const float *col, float *im, - int od); -void col2im(const jit_gemm_conv_conf_t &jcp, const float *col, float *im); - -status_t init_conf(jit_gemm_conv_conf_t &jcp, - memory_tracking::registrar_t &scratchpad, const convolution_desc_t &cd, - const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d, - const memory_desc_wrapper &dst_d, int max_threads); - -void bwd_weights_balance(int ithr, int nthr, int ngroups, int mb, - int &ithr_g, int &nthr_g, int &ithr_mb, int &nthr_mb); -void bwd_weights_reduction_par(int ithr, int nthr, - const jit_gemm_conv_conf_t &jcp, const float *weights_reduce_ws, - float *weights); - -} - -} -} -} - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_inner_product.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_inner_product.cpp deleted file mode 100644 index 2872122f0..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_inner_product.cpp +++ /dev/null @@ -1,156 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#include "c_types_map.hpp" -#include "type_helpers.hpp" -#include "mkldnn_thread.hpp" - -#include "gemm_inner_product.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace mkldnn::impl::status; -using namespace mkldnn::impl::prop_kind; -using namespace mkldnn::impl::data_type; -using namespace mkldnn::impl::format_tag; -using namespace mkldnn::impl::primitive_kind; - -template -void gemm_inner_product_fwd_t::execute_forward( - const exec_ctx_t &ctx) const { - auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); - auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); - auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS); - auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); - - const int MB = pd()->MB(); - const int OC = pd()->OC(); - const int IC = pd()->IC_total_padded(); - - bool wei_tr = !memory_desc_matches_one_of_tag( - *pd()->weights_md(), hwio, dhwio, io); - - const auto &post_ops = pd()->attr()->post_ops_; - const bool do_relu = post_ops.len_ == 1; - - float alpha = 1.0, beta = 0.0; - extended_sgemm(wei_tr ? "T" : "N", "N", &OC, &MB, &IC, &alpha, weights, - wei_tr ? &IC : &OC, src, &IC, &beta, dst, &OC, bias); - - if (do_relu) { - float nslope = post_ops.entry_[0].eltwise.alpha; - parallel_nd(MB, OC, [&](int mb, int oc) { - size_t dst_off = mb * OC + oc; - if (dst[dst_off] < 0) - dst[dst_off] *= nslope; - }); - } -} - -template -void gemm_inner_product_bwd_data_t::execute_backward_data( - const exec_ctx_t &ctx) const { - auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); - auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); - auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); - - const int MB = pd()->MB(); - const int OC = pd()->OC(); - const int IC = pd()->IC_total_padded(); - - bool wei_tr = memory_desc_matches_one_of_tag( - *pd()->weights_md(), hwio, dhwio, io); - - float alpha = 1.0, beta = 0.0; - extended_sgemm(wei_tr ? "T" : "N", "N", &IC, &MB, &OC, &alpha, weights, - wei_tr ? 
&OC : &IC, diff_dst, &OC, &beta, diff_src, &IC); -} - -template -void gemm_inner_product_bwd_weights_t::execute_backward_weights( - const exec_ctx_t &ctx) const { - auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); - auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); - auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS); - auto diff_bias = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS); - - const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); - const memory_desc_wrapper diff_bias_d(pd()->diff_weights_md(1)); - - diff_dst += diff_dst_d.offset0(); - - const int MB = pd()->MB(); - const int OC = pd()->OC(); - const int IC = pd()->IC_total_padded(); - - bool wei_tr = memory_desc_matches_one_of_tag( - *pd()->diff_weights_md(), hwio, dhwio, io); - - float alpha = 1.0, beta = 0.0; - if (wei_tr) - extended_sgemm("N", "T", &OC, &IC, &MB, &alpha, diff_dst, &OC, src, &IC, - &beta, diff_weights, &OC); - else - extended_sgemm("N", "T", &IC, &OC, &MB, &alpha, src, &IC, diff_dst, &OC, - &beta, diff_weights, &IC); - - if (diff_bias) { - diff_bias += diff_bias_d.offset0(); - constexpr int blksize = 8; - const int OC_blocks = OC / blksize; - const int rem_OC = OC % blksize; - parallel(0, [&](const int ithr, const int nthr) { - int oc_st{0}, oc_e{0}; - balance211(OC_blocks, nthr, ithr, oc_st, oc_e); - oc_st = oc_st * blksize; - oc_e = oc_e * blksize; - - PRAGMA_OMP_SIMD() - for (int oc = oc_st; oc < oc_e; ++oc) { - diff_bias[oc] = diff_dst[oc]; - } - - for (int mb = 1; mb < MB; ++mb) { - PRAGMA_OMP_SIMD() - for (int oc = oc_st; oc < oc_e; ++oc) { - diff_bias[oc] += diff_dst[mb * OC + oc]; - } - } - - if (rem_OC != 0 && ithr == nthr-1) { - for (int oc = OC_blocks * blksize; oc < OC; oc++) - diff_bias[oc] = diff_dst[oc]; - for (int mb = 1; mb < MB; ++mb) { - for (int oc = OC_blocks * blksize; oc < OC; oc++) { - diff_bias[oc] += diff_dst[mb * OC + oc]; - } - } - } - }); - } -} - -template struct gemm_inner_product_fwd_t; -template struct gemm_inner_product_bwd_data_t; -template struct gemm_inner_product_bwd_weights_t; - -} -} -} - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_inner_product.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_inner_product.hpp deleted file mode 100644 index acf0a49b9..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_inner_product.hpp +++ /dev/null @@ -1,157 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#ifndef CPU_GEMM_INNER_PRODUCT_HPP -#define CPU_GEMM_INNER_PRODUCT_HPP - -#include - -#include "c_types_map.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -#include "gemm/gemm.hpp" - -#include "cpu_inner_product_pd.hpp" -#include "cpu_primitive.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -template -struct gemm_inner_product_fwd_t: public cpu_primitive_t { - struct pd_t: public cpu_inner_product_fwd_pd_t { - using cpu_inner_product_fwd_pd_t::cpu_inner_product_fwd_pd_t; - - DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_inner_product_fwd_t); - - status_t init() { - using namespace utils; - - bool ok = true - && set_default_params() == status::success - && is_fwd() - && !has_zero_dim_memory() - && everyone_is(data_type, - src_md()->data_type, - weights_md()->data_type, - dst_md()->data_type, - with_bias() ? weights_md(1)->data_type : data_type) - && attr()->output_scales_.has_default_values() - && attr()->post_ops_.len_ <= 1 - && IMPLICATION(attr()->post_ops_.len_ == 1, - attr()->post_ops_.entry_[0].is_relu(true, false)) - && dense_gemm_consitency_check(src_md(), weights_md(), - dst_md()); - return ok ? status::success : status::unimplemented; - } - }; - - gemm_inner_product_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} - typedef typename prec_traits::type data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_forward(ctx); - return status::success; - } - -private: - void execute_forward(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } -}; - -template -struct gemm_inner_product_bwd_data_t: public cpu_primitive_t { - struct pd_t: public cpu_inner_product_bwd_data_pd_t { - using cpu_inner_product_bwd_data_pd_t::cpu_inner_product_bwd_data_pd_t; - - DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_inner_product_bwd_data_t); - - status_t init() { - bool ok = true - && set_default_params() == status::success - && desc()->prop_kind == prop_kind::backward_data - && !has_zero_dim_memory() - && utils::everyone_is(data_type, - diff_src_md()->data_type, - weights_md()->data_type, - diff_dst_md()->data_type) - && attr()->has_default_values() - && dense_gemm_consitency_check(diff_src_md(), weights_md(), - diff_dst_md()); - return ok ? status::success : status::unimplemented; - } - }; - - gemm_inner_product_bwd_data_t(const pd_t *apd): cpu_primitive_t(apd) {} - typedef typename prec_traits::type data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_backward_data(ctx); - return status::success; - } - -private: - void execute_backward_data(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } -}; - -template -struct gemm_inner_product_bwd_weights_t: public cpu_primitive_t { - struct pd_t: public cpu_inner_product_bwd_weights_pd_t { - using cpu_inner_product_bwd_weights_pd_t::cpu_inner_product_bwd_weights_pd_t; - - DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_inner_product_bwd_weights_t); - - status_t init() { - bool ok = true - && set_default_params() == status::success - && desc()->prop_kind == prop_kind::backward_weights - && !has_zero_dim_memory() - && utils::everyone_is(data_type, - src_md()->data_type, - diff_weights_md()->data_type, - diff_dst_md()->data_type, - with_bias() ? diff_weights_md(1)->data_type : data_type) - && attr()->has_default_values() - && dense_gemm_consitency_check(src_md(), diff_weights_md(), - diff_dst_md()); - - return ok ? 
status::success : status::unimplemented; - } - }; - - gemm_inner_product_bwd_weights_t(const pd_t *apd): cpu_primitive_t(apd) {} - typedef typename prec_traits::type data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_backward_weights(ctx); - return status::success; - } - -private: - void execute_backward_weights(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } -}; - -} -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.cpp deleted file mode 100644 index fed7e4d69..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.cpp +++ /dev/null @@ -1,740 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include "c_types_map.hpp" -#include "utils.hpp" -#include "type_helpers.hpp" -#include "mkldnn_thread.hpp" -#include "math_utils.hpp" - -#include "simple_q10n.hpp" - -#include "gemm/gemm.hpp" -#include "gemm_x8s8s32x_convolution.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace mkldnn::impl::utils; -using namespace mkldnn::impl::math; -using namespace mkldnn::impl::memory_tracking::names; - -template -void _gemm_x8s8s32x_convolution_fwd_t:: -execute_forward(const exec_ctx_t &ctx) const { - auto src_base = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); - auto wei_base = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); - auto bia_base = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); - auto dst_base = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); - - auto scratchpad = this->scratchpad(ctx); - - const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_; - - assert(IMPLICATION( - jcp.id != 1, jcp.oh_block == jcp.oh && jcp.ow_block == jcp.ow)); - assert(IMPLICATION(jcp.ow_block != jcp.ow, jcp.oh_block == 1)); - - parallel(jcp.nthr, [&](const int ithr, const int nthr) { - execute_forward_thr(ithr, nthr, src_base, wei_base, bia_base, dst_base, - scratchpad); - }); -} - -template -_gemm_x8s8s32x_convolution_fwd_t::pp_ker_t::pp_ker_t( - const pd_t *pd) - : ker_(nullptr) - , jcp_(pd->jcp_) - , OC_(pd->jcp_.oc) - , OS_(pd->jcp_.os) - , bias_data_type_(data_type::undef) - , bias_data_type_size_(0) - , scale_idx_mult_(0) - , do_bias_(false) - , do_relu_(false) - , do_sum_(false) -{ - using namespace types; - - const auto dst_md = memory_desc_wrapper(pd->dst_md()); - dst_os_stride_ = dst_md.blk_off(0, 0, 0, 1); - - scale_idx_mult_ = (pd->attr()->output_scales_.mask_ == (1 << 1)); - - auto &post_ops = pd->attr()->post_ops_; - - int entry_idx = -1; - for (int idx = 0; idx < post_ops.len_; ++idx) { - const auto &e = post_ops.entry_[idx]; - if (e.is_relu(true, false)) { - entry_idx = idx; - break; - } - } - do_relu_ = entry_idx >= 0; - - 
do_signed_scaling_ = jcp_.signed_input; - - do_sum_ = post_ops.contain(primitive_kind::sum, 0); - do_bias_ = pd->with_bias(); - bias_data_type_ = pd->desc()->bias_desc.data_type; - if (do_bias_) { - assert(bias_data_type_ != data_type::undef); - bias_data_type_size_ = data_type_size(bias_data_type_); - } - const size_t vlen_start - = cpu_isa_traits::vlen / sizeof(float); - - for (size_t i = vlen_start; i > 0; i--) { - if (OC_ % i == 0) { - vlen_ = i; - break; - } - } - - if (!mayiuse(avx512_core)) - // use fallback code for older CPUs - return; - else - generate(); -} - -template -void _gemm_x8s8s32x_convolution_fwd_t::pp_ker_t::generate() -{ - using namespace Xbyak; - using namespace utils; - - // TODO: clean-up - Reg64 reg_param = abi_param1; - Reg64 reg_dst = rdx; - Reg64 reg_acc = rax; - Reg64 reg_bias = rbx; - Reg64 reg_scales = rsi; - - Reg64 reg_len = r8; - Reg64 reg_tmp = rcx; // intentional for shifting purposes - Reg64 reg_oc_offset = r9; - Reg64 reg_rem_mask_short = r10; - Reg64 reg_rem_mask_vlen = r11; - Opmask kreg_rem_mask_short = k1; - Opmask kreg_rem_mask_vlen = k3; - Opmask kreg_relu_cmp = k2; - - const size_t vlen = vlen_; - - Zmm vreg_zero = Zmm(0); - Zmm vreg_scale = Zmm(1); - Zmm vreg_nslope = Zmm(2); - Zmm vreg_sum_scale = Zmm(3); - Zmm vreg_signed_scale = Zmm(4); - - size_t def_unroll = 4; - size_t max_unroll = 12; - size_t zmm_step = 2; - if (do_sum_) { - max_unroll = 8; - zmm_step = 3; - } - - auto vreg_dst = [&](int idx) { - return Zmm(5 + idx * zmm_step + 0); - }; - auto vreg_bias = [&](int idx) { - return Zmm(5 + idx * zmm_step + 1); - }; - auto vreg_prev_dst = [&](int idx) { - return Zmm(5 + idx * zmm_step + 2); - }; - - preamble(); - -#define PARAM_OFF(x) offsetof(ker_args, x) - mov(reg_dst, ptr[reg_param + PARAM_OFF(dst)]); - mov(reg_acc, ptr[reg_param + PARAM_OFF(acc)]); - mov(reg_bias, ptr[reg_param + PARAM_OFF(bias)]); - mov(reg_scales, ptr[reg_param + PARAM_OFF(scales)]); - mov(reg_len, ptr[reg_param + PARAM_OFF(len)]); - mov(reg_oc_offset, ptr[reg_param + PARAM_OFF(oc_offset)]); - vbroadcastss(vreg_nslope, ptr[reg_param + PARAM_OFF(nslope)]); - vbroadcastss(vreg_sum_scale, ptr[reg_param + PARAM_OFF(sum_scale)]); - vbroadcastss(vreg_signed_scale, ptr[reg_param + PARAM_OFF(signed_scale)]); - if (scale_idx_mult_ == 0) - vbroadcastss(vreg_scale, dword[reg_scales]); - -#undef PARAM_OFF - - mov(reg_rem_mask_vlen, 1); - shl(reg_rem_mask_vlen, vlen); - sub(reg_rem_mask_vlen, 1); - kmovq(kreg_rem_mask_vlen, reg_rem_mask_vlen); - - if (do_relu_ || dst_type == data_type::u8) - vxorps(vreg_zero, vreg_zero, vreg_zero); - - // Load accumulated value, convert to float, apply sum (if any), - // bias (if any), scaling, and relu (if any); - // then convert to destination type and store - auto compute = [&](size_t offset, int idx, bool apply_mask) { - auto acc_addr = ptr[reg_acc + offset * sizeof(acc_data_t)]; - - if (scale_idx_mult_ > 0) { - assert(scale_idx_mult_ == 1); - auto scale_addr = ptr[reg_scales + offset * sizeof(float)]; - auto vreg_scale_ = vreg_scale; - if (apply_mask) - vreg_scale_ = vreg_scale_ | kreg_rem_mask_short; - else - vreg_scale_ = vreg_scale_ | kreg_rem_mask_vlen; - vmovups(vreg_scale_, scale_addr); - } - - auto vreg_dst_ = vreg_dst(idx); - if (apply_mask) - vreg_dst_ = vreg_dst_ | kreg_rem_mask_short; - else - vreg_dst_ = vreg_dst_ | kreg_rem_mask_vlen; - vcvtdq2ps(vreg_dst_, acc_addr); - - if (do_signed_scaling_) - vmulps(vreg_dst(idx), vreg_dst(idx), vreg_signed_scale); - - if (do_bias_) { - auto bias_addr = ptr[reg_bias + offset * 
bias_data_type_size_]; - auto vreg_bias_ = vreg_bias(idx); - if (apply_mask) - vreg_bias_ = vreg_bias_ | kreg_rem_mask_short; - else - vreg_bias_ = vreg_bias_ | kreg_rem_mask_vlen; - - switch (bias_data_type_) { - case data_type::s8: - vpmovsxbd(vreg_bias_, bias_addr); - break; - case data_type::u8: - vpmovzxbd(vreg_bias_, bias_addr); - break; - case data_type::s32: - case data_type::f32: - vmovups(vreg_bias_, bias_addr); - break; - default: assert(!"unimplemented"); - } - if (bias_data_type_ != data_type::f32) - vcvtdq2ps(vreg_bias(idx), vreg_bias(idx)); - vaddps(vreg_dst(idx), vreg_dst(idx), vreg_bias(idx)); - } - - vmulps(vreg_dst(idx), vreg_dst(idx), vreg_scale); - - auto dst_addr = ptr[reg_dst + offset * sizeof(dst_data_t)]; - - if (do_sum_) - { - auto vreg_prev_dst_ = vreg_prev_dst(idx); - if (apply_mask) - vreg_prev_dst_ = vreg_prev_dst_ | kreg_rem_mask_short; - else - vreg_prev_dst_ = vreg_prev_dst_ | kreg_rem_mask_vlen; - - switch (dst_type) { - case data_type::f32: - case data_type::s32: vmovups(vreg_prev_dst_, dst_addr); break; - case data_type::s8: vpmovsxbd(vreg_prev_dst_, dst_addr); break; - case data_type::u8: vpmovzxbd(vreg_prev_dst_, dst_addr); break; - default: assert(!"unsupported data type"); - } - if (dst_type != data_type::f32) - vcvtdq2ps(vreg_prev_dst(idx), vreg_prev_dst(idx)); - - vfmadd231ps(vreg_dst(idx), vreg_prev_dst(idx), vreg_sum_scale); - } - - if (do_relu_) { - vcmpps(kreg_relu_cmp, vreg_dst(idx), vreg_zero, _cmp_lt_os); - vmulps(vreg_dst(idx) | kreg_relu_cmp, vreg_dst(idx), vreg_nslope); - } - - if (dst_type != data_type::f32) { - vcvtps2dq(vreg_dst(idx), vreg_dst(idx)); - } - - if (dst_type == data_type::u8) - vpmaxsd(vreg_dst(idx), vreg_dst(idx), vreg_zero); - - switch (dst_type) { - case data_type::s8: - vpmovsdb(dst_addr, vreg_dst_); - break; - case data_type::u8: - vpmovusdb(dst_addr, vreg_dst_); - break; - case data_type::f32: - case data_type::s32: - vmovups(dst_addr, vreg_dst_); - break; - default: assert(!"unimplemented"); - } - }; - - // Advance all pointers by an immediate - auto advance_ptrs_imm = [&](size_t offset) { - add(reg_dst, offset * sizeof(dst_data_t)); - add(reg_acc, offset * sizeof(acc_data_t)); - if (scale_idx_mult_) { - assert(scale_idx_mult_ == 1); - add(reg_scales, offset * sizeof(float)); - } - if (do_bias_) - add(reg_bias, offset * bias_data_type_size_); - }; - - // Advance all pointers by a value stored in a register - auto advance_ptrs_reg = [&](Reg64 offset) { - lea(reg_dst, ptr[reg_dst + offset * sizeof(dst_data_t)]); - lea(reg_acc, ptr[reg_acc + offset * sizeof(acc_data_t)]); - if (scale_idx_mult_) { - assert(scale_idx_mult_ == 1); - lea(reg_scales, ptr[reg_scales + offset * sizeof(float)]); - } - if (do_bias_) - lea(reg_bias, ptr[reg_bias + offset * bias_data_type_size_]); - }; - - // Rewind pointers that point to data that is indexed by output channel - // (bias or per-oc scaling factors) - auto rewind_ptrs = [&]() { - if (do_bias_) - sub(reg_bias, OC_ * bias_data_type_size_); - if (scale_idx_mult_) { - assert(scale_idx_mult_ == 1); - sub(reg_scales, OC_ * sizeof(float)); - } - add(reg_dst, (dst_os_stride_ - OC_) * sizeof(dst_data_t)); - }; - - // <--------- OC ---------------> - // - // ^ ................+..............+-------------+....................... - // | . : not accessed |Prologue loop| . - // | . +--------------+-------------+ . - // . | | . - // O . | Main loop (unrolled) | . - // S . | | . - // . +--------------+-------------+ . - // | . | Epilogue loop|not accessed : . 
- // v ................+--------------+.............+....................... - - Label prologue_end; - cmp(reg_oc_offset, 0); - je(prologue_end, T_NEAR); - - // Prologue loop - { - mov(reg_tmp, OC_); - sub(reg_tmp, reg_oc_offset); - cmp(reg_tmp, reg_len); - cmovg(reg_tmp, reg_len); - sub(reg_len, reg_tmp); - - Label prologue_loop, prologue_loop_tail, prologue_loop_end; - cmp(reg_tmp, vlen); - jle(prologue_loop_tail, T_NEAR); - L(prologue_loop); { - compute(0, 0, false); - advance_ptrs_imm(vlen); - sub(reg_tmp, vlen); - cmp(reg_tmp, vlen); - jge(prologue_loop, T_NEAR); - } - - L(prologue_loop_tail); - mov(reg_rem_mask_short, 1); - // cl == reg_tmp because reg_tmp <= vlen here - shl(reg_rem_mask_short, cl); - sub(reg_rem_mask_short, 1); - jz(prologue_loop_end, T_NEAR); - - kmovq(kreg_rem_mask_short, reg_rem_mask_short); - compute(0, 0, true); - advance_ptrs_reg(reg_tmp); - - L(prologue_loop_end); - rewind_ptrs(); - } - L(prologue_end); - - // Main loop - Label main_loop_end; - { - cmp(reg_len, OC_); - jle(main_loop_end, T_NEAR); - - Label main_loop; - L(main_loop); { - size_t OC_loop, OC_tail; - if (OC_ < max_unroll * vlen) { - // Fully unroll small loops - OC_loop = 0; - OC_tail = OC_; - } - else { - OC_loop = vlen * def_unroll; - OC_tail = OC_ % OC_loop; - } - - assert(!!OC_loop || !!OC_tail); - - if (OC_tail % vlen) { - int vlen_tail = OC_tail % vlen; - unsigned tail_mask = (1 << vlen_tail) - 1; - mov(reg_tmp, tail_mask); - kmovq(kreg_rem_mask_short, reg_tmp); - } - - if (OC_loop) { - mov(reg_tmp, rnd_dn(OC_, OC_loop)); - Label oc_loop; - L(oc_loop); { - for (size_t offset = 0; offset < OC_loop; offset += vlen) - compute(offset, offset / vlen, false); - advance_ptrs_imm(OC_loop); - sub(reg_tmp, OC_loop); - jnz(oc_loop); - } - } - - if (OC_tail) { - for (size_t offset = 0; offset < OC_tail; offset += vlen) { - bool use_mask = (offset + vlen) > OC_tail; - compute(offset, offset / vlen, use_mask); - } - advance_ptrs_imm(OC_tail); - } - - rewind_ptrs(); - sub(reg_len, OC_); - cmp(reg_len, OC_); - jge(main_loop, T_NEAR); - } - } - L(main_loop_end); - - // Epilogue loop - Label epilogue_end; - { - cmp(reg_len, 0); - je(epilogue_end, T_NEAR); - - Label epilogue_loop, epilogue_loop_tail; - cmp(reg_len, vlen); - jle(epilogue_loop_tail, T_NEAR); - L(epilogue_loop); { - compute(0, 0, false); - sub(reg_len, vlen); - advance_ptrs_imm(vlen); - cmp(reg_len, vlen); - jge(epilogue_loop, T_NEAR); - } - - L(epilogue_loop_tail); - mov(reg_tmp, reg_len); // reg_tmp is rcx, and we need cl for the shift - mov(reg_rem_mask_short, 1); - shl(reg_rem_mask_short, cl); // reg_tmp == rcx and reg_tail < vlen - sub(reg_rem_mask_short, 1); - jz(epilogue_end, T_NEAR); - kmovq(kreg_rem_mask_short, reg_rem_mask_short); - compute(0, 0, true); - } - - L(epilogue_end); - - postamble(); - - ker_ = getCode(); -} - -template -void _gemm_x8s8s32x_convolution_fwd_t::pp_ker_t::operator () - (dst_data_t *dst, const acc_data_t *acc, const char *bias, - const float *scales, float nslope, float sum_scale, float signed_scale, - int g, size_t start, size_t end) -{ - using math::get_bias; - - if (end <= start) - return; - - if (ker_) { - // JIT - ker_args args; - size_t oc_offset = start % OC_; - size_t os_offset = start / OC_; - args.acc = acc + start; - args.dst = dst + os_offset * dst_os_stride_ + oc_offset; - args.bias = bias + (g * jcp_.oc + oc_offset) * bias_data_type_size_; - args.scales = scales + scale_idx_mult_ * (g * jcp_.oc + oc_offset); - args.nslope = nslope; - args.sum_scale = sum_scale; - args.signed_scale = signed_scale; - 
args.len = end - start; - args.oc_offset = oc_offset; - ker_(&args); - } - else { - // Fallback - const size_t first_oc = start % OC_; - const size_t last_oc = (end - 1) % OC_; - const size_t first_os = start / OC_; - const size_t last_os = (end - 1) / OC_; - for (size_t os = first_os; os <= last_os; os++) { - const size_t start_oc = (os == first_os) ? first_oc : 0; - const size_t end_oc = (os == last_os) ? last_oc : OC_ - 1; - for (size_t oc = start_oc; oc <= end_oc; oc++) { - const size_t acc_off = os * jcp_.oc + oc; - const size_t dst_off = os * dst_os_stride_ + oc; - - float d = (float)(acc[acc_off]); - if (jcp_.signed_input) - d *= signed_scale; - - if (do_bias_) - d += get_bias(bias, g * jcp_.oc + oc, - bias_data_type_); - - d *= scales[(g * jcp_.oc + oc) * scale_idx_mult_]; - if (do_sum_) - d += sum_scale * dst[dst_off]; - if (do_relu_ && d < 0) - d *= nslope; - dst[dst_off] = qz_a1b0()(d); - } - } - } -}; - -template -void _gemm_x8s8s32x_convolution_fwd_t:: -execute_forward_thr(const int ithr, const int nthr, const src_data_t *src_base, - const wei_data_t *wei_base, const char *bia_base, dst_data_t *dst_base, - const memory_tracking::grantor_t &scratchpad) const { - const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_; - - const auto src_md = memory_desc_wrapper(pd()->src_md()); - const size_t src_mb_stride = src_md.blk_off(1); - const size_t src_g_stride = src_md.blk_off(0, 1) * jcp.ic; - - const auto wei_md = memory_desc_wrapper(pd()->weights_md(0)); - const size_t wei_g_stride = pd()->with_groups() ? wei_md.blk_off(1) : 0; - - const auto dst_md = memory_desc_wrapper(pd()->dst_md()); - const size_t dst_mb_stride = dst_md.blk_off(1); - const size_t dst_g_stride = dst_md.blk_off(0, 1) * jcp.oc; - - const float *scales = pd()->attr()->output_scales_.scales_; - - const auto &post_ops = pd()->attr()->post_ops_; - const bool do_sum = post_ops.contain(primitive_kind::sum, 0); - const float sum_scale = do_sum ? 
post_ops.entry_[0].sum.scale : 0; - - float nslope = 0; - for (int idx = 0; idx < post_ops.len_; ++idx) { - const auto &e = post_ops.entry_[idx]; - if (e.is_relu(true, false)) { - nslope = e.eltwise.alpha; - break; - } - } - - auto col = scratchpad.get(key_conv_gemm_col) - + (ptrdiff_t)ithr * jcp.im2col_sz; - src_data_t *__restrict imtr = scratchpad.get(key_conv_gemm_imtr) - + (ptrdiff_t)ithr * jcp.is * jcp.ic; - auto acc = scratchpad.get(key_conv_int_dat_in_acc_dt) - + (ptrdiff_t)ithr * jcp.oh_block * jcp.ow_block * jcp.oc; - - const ptrdiff_t offset = (ptrdiff_t)jcp.ngroups * jcp.ks * jcp.ic * jcp.oc; - const int32_t *_wei_comp = (const int32_t *)(wei_base + offset); - - int g{ 0 }, n{ 0 }, ohb{ 0 }, owb{ 0 }; - size_t start = 0, end = 0; - - const int nb_oh = div_up(jcp.oh, jcp.oh_block); - const int nb_ow = div_up(jcp.ow, jcp.ow_block); - const size_t work_amount = jcp.ngroups * jcp.mb * nb_oh * nb_ow; - balance211(work_amount, nthr, ithr, start, end); - nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ohb, - nb_oh, owb, nb_ow); - - for (size_t iwork = start; iwork < end; ++iwork) { - int oh = ohb * jcp.oh_block; - int ow = owb * jcp.ow_block; - const src_data_t *__restrict src = src_base + n * src_mb_stride - + g * src_g_stride; - const wei_data_t *__restrict wei = wei_base + g * wei_g_stride; - dst_data_t *__restrict dst = - dst_base + n * dst_mb_stride + g * dst_g_stride; - const int32_t *wei_comp = _wei_comp + g * jcp.oc; - const int h_step = nstl::min(jcp.oh_block, jcp.oh - oh); - const int w_step = nstl::min(jcp.ow_block, jcp.ow - ow); - - if (jcp.im2col_sz) - jit_gemm_convolution_utils::im2col_u8( - jcp, src, imtr, col, oh, h_step, ow, w_step); - - const int M = jcp.oc; - const int K = jcp.ks * jcp.ic; - const int N = h_step * w_step; - const int LDA = M * jcp.ngroups; - const int LDB = jcp.im2col_sz ? N : K; - const char *BT = jcp.im2col_sz ? "T" : "N"; - const int8_t off_a = 0, off_b = 0; - const int32_t off_c = 0; - const float onef = 1.0, zerof = 0.0; - gemm_s8x8s32("N", BT, jcp.signed_input ? "C" : "F", - &M, &N, &K, &onef, wei, &LDA, &off_a, - jcp.im2col_sz ? col : (uint8_t *)src, &LDB, &off_b, - &zerof, acc, &M, jcp.signed_input ? wei_comp : &off_c); - - auto wei_adj_scale = - (wei_md.extra().flags | memory_extra_flags::scale_adjust) - ? 
wei_md.extra().scale_adjust : 1.f; - - parallel(0, [&](int ithr, int nthr) { - size_t start, end; - balance211((size_t)N * jcp.oc, nthr, ithr, start, end); - (*pp_ker_)(dst + (oh * jcp.ow + ow) * pp_ker_->dst_os_stride_, - acc, bia_base, scales, nslope, sum_scale, - 1.f / wei_adj_scale, g, start, end); - }); - - nd_iterator_step(n, jcp.mb, g, jcp.ngroups, ohb, nb_oh, - owb, nb_ow); - } -} - -template -void _gemm_u8s8s32x_convolution_bwd_data_t:: -execute_backward_data(const exec_ctx_t &ctx) const { - auto diff_dst_base = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST); - auto wei_base = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); - auto bia_base = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); - auto diff_src_base = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC); - - auto scratchpad = this->scratchpad(ctx); - - const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_; - - parallel(jcp.nthr, [&](const int ithr, const int nthr) { - execute_backward_data_thr(ithr, nthr, diff_dst_base, wei_base, - bia_base, diff_src_base, scratchpad); - }); -} - -template -void _gemm_u8s8s32x_convolution_bwd_data_t:: -execute_backward_data_thr(const int ithr, const int nthr, - const diff_dst_data_t *diff_dst_base, const wei_data_t *wei_base, - const char *bia_base, diff_src_data_t *diff_src_base, - const memory_tracking::grantor_t &scratchpad) const -{ - const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_; - - const auto diff_dst_md = memory_desc_wrapper(pd()->diff_dst_md()); - const size_t diff_dst_mb_stride = diff_dst_md.blk_off(1); - const size_t diff_dst_g_stride = diff_dst_md.blk_off(0, 1) * jcp.oc; - - const auto wei_md = memory_desc_wrapper(pd()->weights_md(0)); - const size_t wei_g_stride = pd()->with_groups() ? wei_md.blk_off(1) : 0; - - const auto diff_src_md = memory_desc_wrapper(pd()->diff_src_md()); - const size_t diff_src_mb_stride = diff_src_md.blk_off(1); - const size_t diff_src_g_stride = diff_src_md.blk_off(0, 1) * jcp.ic; - const size_t diff_src_os_stride = diff_src_md.blk_off(0, 0, 0, 1); - - /* scale_idx_mult = 1 for per_oc scales and 0, otherwise */ - const int scale_idx_mult = pd()->attr()->output_scales_.mask_ == (1 << 1); - const float *scales = pd()->attr()->output_scales_.scales_; - const size_t work_amount = jcp.ngroups * jcp.mb; - - auto col = scratchpad.get(key_conv_gemm_col) - + (ptrdiff_t)ithr * jcp.im2col_sz; - auto acc = scratchpad.get(key_conv_int_dat_in_acc_dt) - + (ptrdiff_t)ithr * jcp.is * jcp.ic; - - int n{0}, g{0}; - size_t start = 0, end = 0; - - balance211(work_amount, nthr, ithr, start, end); - nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups); - - for (size_t iwork = start; iwork < end; ++iwork) { - const diff_dst_data_t *diff_dst = diff_dst_base - + n * diff_dst_mb_stride + g * diff_dst_g_stride; - const wei_data_t *wei = wei_base + g * wei_g_stride; - diff_src_data_t *diff_src = diff_src_base + n * diff_src_mb_stride - + g * diff_src_g_stride; - - const int M = jcp.ks * jcp.ic; - const int N = jcp.os; - const int K = jcp.oc; - const int8_t off_a = 0, off_b = 0; - const int32_t off_c = 0; - const float onef = 1.0, zerof = 0.0; - const int LD = K * jcp.ngroups; - - gemm_s8x8s32("T", "N", "F", &M, &N, &K, &onef, - wei, &LD, &off_a, diff_dst, &LD, &off_b, - &zerof, jcp.im2col_sz ? 
col : acc, &M, &off_c); - - if (jcp.im2col_sz) - jit_gemm_convolution_utils::col2im_s32(jcp, col, acc); - - parallel_nd(jcp.is, jcp.ic, [&](int is, int ic) { - float d = (float)acc[is * jcp.ic + ic]; - if (jcp.with_bias) - d += get_bias(bia_base, g * jcp.ic + ic, - pd()->desc()->bias_desc.data_type); - d *= scales[(g * jcp.ic + ic) * scale_idx_mult]; - const size_t diff_src_off = is * diff_src_os_stride + ic; - diff_src[diff_src_off] = - qz_a1b0()(d); - }); - nd_iterator_step(n, jcp.mb, g, jcp.ngroups); - } -} - -using namespace data_type; - -template struct _gemm_x8s8s32x_convolution_fwd_t; -template struct _gemm_x8s8s32x_convolution_fwd_t; -template struct _gemm_x8s8s32x_convolution_fwd_t; -template struct _gemm_x8s8s32x_convolution_fwd_t; - -template struct _gemm_x8s8s32x_convolution_fwd_t; -template struct _gemm_x8s8s32x_convolution_fwd_t; -template struct _gemm_x8s8s32x_convolution_fwd_t; -template struct _gemm_x8s8s32x_convolution_fwd_t; - -template struct _gemm_u8s8s32x_convolution_bwd_data_t; -template struct _gemm_u8s8s32x_convolution_bwd_data_t; -template struct _gemm_u8s8s32x_convolution_bwd_data_t; -template struct _gemm_u8s8s32x_convolution_bwd_data_t; -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.hpp deleted file mode 100644 index 9e77b890d..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.hpp +++ /dev/null @@ -1,266 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#ifndef GEMM_X8S8S32X_CONVOLUTION_HPP -#define GEMM_X8S8S32X_CONVOLUTION_HPP - -#include "c_types_map.hpp" -#include "memory_tracking.hpp" - -#include "cpu_convolution_pd.hpp" -#include "cpu_primitive.hpp" - -#include "jit_primitive_conf.hpp" -#include "jit_generator.hpp" -#include "gemm_convolution_utils.hpp" - -#include "gemm/gemm.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -template -struct _gemm_x8s8s32x_convolution_fwd_t: public cpu_primitive_t { - struct pd_t: public cpu_convolution_fwd_pd_t { - pd_t(engine_t *engine, const convolution_desc_t *adesc, - const primitive_attr_t *attr, - const typename pd_t::base_class *hint_fwd_pd) - : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) - , jcp_() {} - - DECLARE_COMMON_PD_T(IGEMM_S8U8S32_IMPL_STR, - _gemm_x8s8s32x_convolution_fwd_t); - - status_t init() { - using namespace data_type; - - bool ok = true - && is_fwd() - && set_default_alg_kind(alg_kind::convolution_direct) - && expect_data_types(src_type, s8, data_type::undef, dst_type, - s32) - && IMPLICATION(with_bias(), utils::one_of( - desc()->bias_desc.data_type, f32, s32, s8, u8)) - && !has_zero_dim_memory() - && set_default_formats_common( - dat_tag(), format_tag::any, dat_tag()) - && post_ops_ok() - && memory_desc_matches_tag(*src_md(), dat_tag()) - && memory_desc_matches_tag(*dst_md(), dat_tag()) - && set_or_check_wei_format(); - if (!ok) return status::unimplemented; - - auto scratchpad = scratchpad_registry().registrar(); - return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad, - *desc(), src_md(), weights_md(0), dst_md(), - mkldnn_get_max_threads()); - } - - jit_gemm_conv_conf_t jcp_; - - protected: - format_tag_t dat_tag() const { return format_tag::nhwc; } - - bool set_or_check_wei_format() { - using namespace format_tag; - - const bool is_src_s8 = src_md_.data_type == data_type::s8; - - memory_desc_t want_wei_md = weights_md_; - memory_desc_init_by_tag(want_wei_md, with_groups() ? hwigo : hwio); - - if (is_src_s8) { - want_wei_md.extra.flags = 0 - | memory_extra_flags::compensation_conv_s8s8 - | memory_extra_flags::scale_adjust; - want_wei_md.extra.compensation_mask = (1 << 0) - + (with_groups() ? (1 << 1) : 0); - want_wei_md.extra.scale_adjust = - mayiuse(avx512_core_vnni) ? 
1.f : 0.5f; - } - - if (weights_md_.format_kind == format_kind::any) { - weights_md_ = want_wei_md; - return true; - } - - return weights_md_ == want_wei_md; - } - - bool post_ops_ok() const { - using namespace mkldnn::impl::primitive_kind; - auto const &po = attr()->post_ops_; - auto is_relu = [&](int idx) { - return po.entry_[idx].is_relu(true, false); }; - - switch (po.len_) { - case 0: return true; - case 1: return is_relu(0) || po.contain(sum, 0); - case 2: return po.contain(sum, 0) && is_relu(1); - default: return false; - } - return false; - } - }; - - _gemm_x8s8s32x_convolution_fwd_t(const pd_t *apd) - : cpu_primitive_t(apd, true), pp_ker_(nullptr) - { pp_ker_ = new pp_ker_t(pd()); } - ~_gemm_x8s8s32x_convolution_fwd_t() { delete pp_ker_; } - - typedef typename prec_traits::type src_data_t; - typedef typename prec_traits::type wei_data_t; - typedef typename prec_traits::type dst_data_t; - typedef typename prec_traits::type acc_data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_forward(ctx); - return status::success; - } - -private: - // XXX: this is throwaway code that will become unnecessary when we have a - // sufficiently advanced igemm jit generator that supports quantization, - // relu, and whatnot - class pp_ker_t : jit_generator { - public: - DECLARE_CPU_JIT_AUX_FUNCTIONS( - _gemm_x8s8s32x_convolution_fwd_t::pp_kernel); - pp_ker_t(const pd_t *pd); - - void operator()(dst_data_t *dst, const acc_data_t *acc, - const char *bias, const float *scales, - float nslope, float sum_scale, float signed_scale, - int g, size_t start, size_t end); - - size_t dst_os_stride_; - - private: - void generate(); - - struct ker_args { - dst_data_t *dst; - const acc_data_t *acc; - const char *bias; - const float *scales; - float nslope; - float sum_scale; - float signed_scale; - size_t len; - size_t oc_offset; - }; - void(*ker_)(const ker_args *args); - - const jit_gemm_conv_conf_t &jcp_; - size_t OC_; - size_t OS_; - data_type_t bias_data_type_; - size_t bias_data_type_size_; - size_t scale_idx_mult_; - bool do_bias_; - bool do_relu_; - bool do_sum_; - bool do_signed_scaling_; - size_t vlen_; - }; - - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - void execute_forward(const exec_ctx_t &ctx) const; - void execute_forward_thr(const int ithr, const int nthr, - const src_data_t *src_base, const wei_data_t *wei_base, - const char *bia_base, dst_data_t *dst_base, - const memory_tracking::grantor_t &scratchpad) const; - - int nthr_; - pp_ker_t *pp_ker_; - -}; - -template -struct _gemm_u8s8s32x_convolution_bwd_data_t: public cpu_primitive_t { - struct pd_t: public cpu_convolution_bwd_data_pd_t{ - pd_t(engine_t *engine, - const convolution_desc_t *adesc, const primitive_attr_t *attr, - const convolution_fwd_pd_t *hint_fwd_pd) - : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) - , jcp_() {} - - DECLARE_COMMON_PD_T(IGEMM_S8U8S32_IMPL_STR, - _gemm_u8s8s32x_convolution_bwd_data_t); - - status_t init() { - using namespace data_type; - - bool ok = true - && desc()->prop_kind == prop_kind::backward_data - && set_default_alg_kind(alg_kind::convolution_direct) - && expect_data_types(dst_type, s8, data_type::undef, u8, s32) - && IMPLICATION(with_bias(), utils::one_of( - desc()->bias_desc.data_type, f32, s32, s8, u8)) - && !has_zero_dim_memory() - && set_default_formats_common(dat_tag(), wei_tag(), dat_tag()) - && attr()->post_ops_.has_default_values() - && memory_desc_matches_tag(*diff_src_md(), dat_tag()) - && 
memory_desc_matches_tag(*diff_dst_md(), dat_tag()) - && memory_desc_matches_tag(*weights_md(), wei_tag()); - if (!ok) return status::unimplemented; - - auto scratchpad = scratchpad_registry().registrar(); - return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad, - *desc(), diff_src_md(), weights_md(), diff_dst_md(), - mkldnn_get_max_threads()); - } - - virtual bool support_bias() const override { return true; } - - jit_gemm_conv_conf_t jcp_; - - protected: - format_tag_t dat_tag() const { return format_tag::nhwc; } - - format_tag_t wei_tag() const { - return with_groups() ? format_tag::hwigo : format_tag::hwio; - } - }; - - _gemm_u8s8s32x_convolution_bwd_data_t(const pd_t *apd) - : cpu_primitive_t(apd, true) {} - - typedef typename prec_traits::type diff_dst_data_t; - typedef typename prec_traits::type wei_data_t; - typedef typename prec_traits::type diff_src_data_t; - typedef typename prec_traits::type acc_data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_backward_data(ctx); - return status::success; - } - -private: - void execute_backward_data(const exec_ctx_t &ctx) const; - void execute_backward_data_thr(const int ithr, const int nthr, - const diff_dst_data_t *diff_dst_base, const wei_data_t *wei_base, - const char *bia_base, diff_src_data_t *diff_src_base, - const memory_tracking::grantor_t &scratchpad) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } -}; - -} -} -} - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_inner_product.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_inner_product.cpp deleted file mode 100644 index 1e435a233..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_inner_product.cpp +++ /dev/null @@ -1,453 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#include "math_utils.hpp" -#include "mkldnn_thread.hpp" -#include "simple_q10n.hpp" - -#include "gemm/gemm.hpp" -#include "gemm_x8s8s32x_inner_product.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace math; -using namespace format_tag; -using namespace memory_tracking::names; - -template -gemm_x8s8s32x_inner_product_fwd_t::pp_kernel_t::pp_kernel_t( - const pd_t *pd, bool dst_is_acc) - : ker_(nullptr), OC_(pd->OC()) - , bias_data_type_(data_type::undef), bias_data_type_size_(0) - , scale_idx_mult_(0), do_bias_(false), do_relu_(false) -{ - using namespace types; - - scale_idx_mult_ = (pd->attr()->output_scales_.mask_ == (1 << 1)); - - auto &post_ops = pd->attr()->post_ops_; - do_relu_ = post_ops.len_ == 1; - do_bias_ = pd->with_bias(); - bias_data_type_ = pd->desc()->bias_desc.data_type; - if (do_bias_) { - assert(bias_data_type_ != data_type::undef); - bias_data_type_size_ = data_type_size(bias_data_type_); - } - - if (!mayiuse(avx512_core)) - // use fallback code for older CPUs since they do not have optimized - // x8s8s32 GEMM anyways. The configuration variables above are used by - // the fallback code. - return; - else - generate(); -} - -template -void gemm_x8s8s32x_inner_product_fwd_t::pp_kernel_t::generate() -{ - using namespace Xbyak; - using namespace utils; - - // TODO: clean-up - Reg64 reg_param = abi_param1; - Reg64 reg_dst = rdx; - Reg64 reg_acc = rax; - Reg64 reg_bias = rbx; - Reg64 reg_scales = rsi; - - Reg64 reg_len = r8; - Reg64 reg_tmp = rcx; // intentional for shifting purposes - Reg64 reg_oc_offset = r9; - Reg64 reg_rem_mask = r10; - Opmask kreg_rem_mask = k1; - Opmask kreg_relu_cmp = k2; - - const size_t vlen = cpu_isa_traits::vlen / sizeof(float); - - Zmm vreg_zero = Zmm(0); - Zmm vreg_scale = Zmm(1); - Zmm vreg_nslope = Zmm(2); - - auto vreg_dst = [&](int idx) { return Zmm(3 + idx * 2 + 0); }; - auto vreg_bias = [&](int idx) { return Zmm(3 + idx * 2 + 1); }; - - preamble(); - -#define PARAM_OFF(x) offsetof(ker_args, x) - mov(reg_dst, ptr[reg_param + PARAM_OFF(dst)]); - mov(reg_acc, ptr[reg_param + PARAM_OFF(acc)]); - mov(reg_bias, ptr[reg_param + PARAM_OFF(bias)]); - mov(reg_scales, ptr[reg_param + PARAM_OFF(scales)]); - mov(reg_len, ptr[reg_param + PARAM_OFF(len)]); - mov(reg_oc_offset, ptr[reg_param + PARAM_OFF(oc_offset)]); - vbroadcastss(vreg_nslope, ptr[reg_param + PARAM_OFF(nslope)]); - if (scale_idx_mult_ == 0) - vbroadcastss(vreg_scale, dword[reg_scales]); -#undef PARAM_OFF - - if (do_relu_ || dst_type == data_type::u8) - vxorps(vreg_zero, vreg_zero, vreg_zero); - - // Load accumulated value, convert to float, apply bias (if any), scaling, - // and relu (if any); then convert to destination type and store - auto compute = [&](size_t offset, int idx, bool apply_mask) { - auto acc_addr = ptr[reg_acc + offset * sizeof(acc_data_t)]; - - if (scale_idx_mult_ > 0) { - assert(scale_idx_mult_ == 1); - auto scale_addr = ptr[reg_scales + offset * sizeof(float)]; - auto vreg_scale_ = vreg_scale; - if (apply_mask) - vreg_scale_ = vreg_scale_ | kreg_rem_mask; - vmovups(vreg_scale, scale_addr); - } - - auto vreg_dst_ = vreg_dst(idx); - if (apply_mask) - vreg_dst_ = vreg_dst_ | kreg_rem_mask; - vcvtdq2ps(vreg_dst_, acc_addr); - - if (do_bias_) { - auto bias_addr = ptr[reg_bias + offset * bias_data_type_size_]; - auto vreg_bias_ = vreg_bias(idx); - if (apply_mask) - vreg_bias_ = vreg_bias_ | kreg_rem_mask; - - switch (bias_data_type_) { - case data_type::s8: - 
vpmovsxbd(vreg_bias_, bias_addr); - break; - case data_type::u8: - vpmovzxbd(vreg_bias_, bias_addr); - break; - case data_type::s32: - case data_type::f32: - vmovups(vreg_bias_, bias_addr); - break; - default: assert(!"unimplemented"); - } - if (bias_data_type_ != data_type::f32) - vcvtdq2ps(vreg_bias(idx), vreg_bias(idx)); - vaddps(vreg_dst(idx), vreg_dst(idx), vreg_bias(idx)); - } - - vmulps(vreg_dst(idx), vreg_dst(idx), vreg_scale); - if (do_relu_) { - vcmpps(kreg_relu_cmp, vreg_dst(idx), vreg_zero, _cmp_lt_os); - vmulps(vreg_dst(idx) | kreg_relu_cmp, vreg_dst(idx), vreg_nslope); - } - - if (dst_type == data_type::u8) - vmaxps(vreg_dst(idx), vreg_dst(idx), vreg_zero); - - if (dst_type != data_type::f32) { - vcvtps2dq(vreg_dst(idx), vreg_dst(idx)); - } - - auto dst_addr = ptr[reg_dst + offset * sizeof(dst_data_t)]; - switch (dst_type) { - case data_type::s8: - vpmovsdb(dst_addr, vreg_dst_); - break; - case data_type::u8: - vpmovusdb(dst_addr, vreg_dst_); - break; - case data_type::f32: - case data_type::s32: - vmovups(dst_addr, vreg_dst_); - break; - default: assert(!"unimplemented"); - } - }; - - // Advance all pointers by an immediate - auto advance_ptrs_imm = [&](size_t offset) { - add(reg_dst, offset * sizeof(dst_data_t)); - add(reg_acc, offset * sizeof(acc_data_t)); - if (scale_idx_mult_) { - assert(scale_idx_mult_ == 1); - add(reg_scales, offset * sizeof(float)); - } - if (do_bias_) - add(reg_bias, offset * bias_data_type_size_); - }; - - // Advance all pointers by a value stored in a register - auto advance_ptrs_reg = [&](Reg64 offset) { - lea(reg_dst, ptr[reg_dst + offset * sizeof(dst_data_t)]); - lea(reg_acc, ptr[reg_acc + offset * sizeof(acc_data_t)]); - if (scale_idx_mult_) { - assert(scale_idx_mult_ == 1); - lea(reg_scales, ptr[reg_scales + offset * sizeof(float)]); - } - if (do_bias_) - lea(reg_bias, ptr[reg_bias + offset * bias_data_type_size_]); - }; - - // Rewind pointers that point to data that is indixed by output channel - // (bias or per-oc scaling factors) - auto rewind_ptrs = [&]() { - if (do_bias_) - sub(reg_bias, OC_ * bias_data_type_size_); - if (scale_idx_mult_) { - assert(scale_idx_mult_ == 1); - sub(reg_scales, OC_ * sizeof(float)); - } - }; - - // <-------------------- OC -------------------------------> - // - // ^ +....................+----------------------------------+ - // | : not accessed | Prologue loop | - // | +--------------------+----------------------------------+ - // | | - // M | Main loop (unrolled) | - // B | | - // +--------------------------------+----------------------+ - // | | Epilogue loop | not accessed : - // v +--------------------------------+......................+ - - Label prologue_end; - cmp(reg_oc_offset, 0); - je(prologue_end, T_NEAR); - - // Prologue loop - { - mov(reg_tmp, OC_); - sub(reg_tmp, reg_oc_offset); - cmp(reg_tmp, reg_len); - cmovg(reg_tmp, reg_len); - sub(reg_len, reg_tmp); - - Label prologue_loop, prologue_loop_tail, prologue_loop_end; - cmp(reg_tmp, vlen); - jle(prologue_loop_tail, T_NEAR); // Skips for reg_tmp == 16 too (?) 
- L(prologue_loop); { - compute(0, 0, false); - advance_ptrs_imm(vlen); - sub(reg_tmp, vlen); - cmp(reg_tmp, vlen); - jge(prologue_loop, T_NEAR); - } - - L(prologue_loop_tail); - mov(reg_rem_mask, 1); - shl(reg_rem_mask, cl); // cl == reg_tmp because reg_tmp <= vlen here - sub(reg_rem_mask, 1); - jz(prologue_loop_end, T_NEAR); - - kmovq(kreg_rem_mask, reg_rem_mask); - compute(0, 0, true); - advance_ptrs_reg(reg_tmp); - - L(prologue_loop_end); - rewind_ptrs(); - } - L(prologue_end); - - // Main loop - Label main_loop_end; - { - cmp(reg_len, OC_); - jle(main_loop_end, T_NEAR); - - Label main_loop; - L(main_loop); { - size_t def_unroll = 4; - size_t max_unroll = 13; - - size_t OC_loop, OC_tail; - if (OC_ < max_unroll * vlen) { - // Fully unroll small loops - OC_loop = 0; - OC_tail = OC_; - } else { - OC_loop = vlen * def_unroll; - OC_tail = OC_ % OC_loop; - } - - assert(!!OC_loop || !!OC_tail); - - if (OC_tail % vlen) { - int vlen_tail = OC_tail % vlen; - unsigned tail_mask = (1 << vlen_tail) - 1; - mov(reg_tmp, tail_mask); - kmovq(kreg_rem_mask, reg_tmp); - } - - if (OC_loop) { - mov(reg_tmp, rnd_dn(OC_, OC_loop)); - Label oc_loop; - L(oc_loop); { - for (size_t offset = 0; offset < OC_loop; offset += vlen) - compute(offset, offset / vlen, false); - advance_ptrs_imm(OC_loop); - sub(reg_tmp, OC_loop); - jnz(oc_loop); - } - } - - if (OC_tail) { - for (size_t offset = 0; offset < OC_tail; offset += vlen) { - bool use_mask = (offset + vlen) > OC_tail; - compute(offset, offset / vlen, use_mask); - } - advance_ptrs_imm(OC_tail); - } - - rewind_ptrs(); - sub(reg_len, OC_); - cmp(reg_len, OC_); - jge(main_loop, T_NEAR); - } - } - L(main_loop_end); - - // Epilogue loop - Label epilogue_end; - { - cmp(reg_len, 0); - je(epilogue_end, T_NEAR); - - Label epilogue_loop, epilogue_loop_tail; - cmp(reg_len, vlen); - jle(epilogue_loop_tail, T_NEAR); // Skips for reg_len == 16 (?) - L(epilogue_loop); { - compute(0, 0, false); - sub(reg_len, vlen); - advance_ptrs_imm(vlen); - cmp(reg_len, vlen); - jge(epilogue_loop, T_NEAR); - } - - L(epilogue_loop_tail); - mov(reg_tmp, reg_len); // reg_tmp is rcx, and we need cl for the shift - mov(reg_rem_mask, 1); - shl(reg_rem_mask, cl); // reg_tmp == rcx and reg_tail < vlen == 16 - sub(reg_rem_mask, 1); - jz(epilogue_end, T_NEAR); - kmovq(kreg_rem_mask, reg_rem_mask); - compute(0, 0, true); - } - - L(epilogue_end); - - postamble(); - - ker_ = getCode(); -} - -template -void gemm_x8s8s32x_inner_product_fwd_t::pp_kernel_t::operator ()( - dst_data_t *dst, const acc_data_t *acc, - const char *bias, const float *scales, float nslope, - size_t start, size_t end) -{ - using math::get_bias; - - if (end <= start) - return; - - if (ker_) { - // JIT - ker_args args; - size_t oc_offset = start % OC_; - args.dst = dst + start; - args.acc = acc + start; - args.bias = bias + oc_offset * bias_data_type_size_; - args.scales = scales + scale_idx_mult_ * oc_offset; - args.nslope = nslope; - args.len = end - start; - args.oc_offset = oc_offset; - ker_(&args); - } else { - // Fallback - size_t oc = start % OC_; - for (size_t i = start; i < end; i++) { - float d = (float)acc[i]; - float b = get_bias(bias, oc, bias_data_type_); - d = d + b; - d *= scales[oc * scale_idx_mult_]; - if (do_relu_ && d < 0) - d *= nslope; - dst[i] = qz_a1b0()(d); - oc = (oc == OC_ - 1) ? 
0 : oc + 1; - } - } -}; - -template -void gemm_x8s8s32x_inner_product_fwd_t::execute_forward( - const exec_ctx_t &ctx) const { - auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); - auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); - auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); - auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); - - const int MB = pd()->MB(); - const int OC = pd()->OC(); - - bool wei_tr = memory_desc_matches_one_of_tag( - *pd()->weights_md(), oiw, oihw, oidhw, oi); - - const int M = OC; - const int N = MB; - const int K = pd()->IC_total_padded(); - const int8_t off_a = 0, off_b = 0; - const int32_t off_c = 0; - - const float *scales = pd()->attr()->output_scales_.scales_; - - const auto &post_ops = pd()->attr()->post_ops_; - const bool do_relu = post_ops.len_ == 1; - const float nslope = do_relu ? post_ops.entry_[0].eltwise.alpha : 0.f; - - acc_data_t *acc = pd()->dst_is_acc_ - ? (acc_data_t *)dst - : scratchpad(ctx).template get(key_iprod_int_dat_in_acc_dt); - - const float onef = 1.0, zerof = 0.0; - gemm_s8x8s32(wei_tr ? "T" : "N", "N", "F", &M, &N, &K, &onef, weights, - wei_tr ? &K : &M, &off_a, src, &K, &off_b, &zerof, acc, &M, &off_c); - - if (!pd()->attr()->has_default_values() || !pd()->dst_is_acc_ - || pd()->with_bias()) { - const bool force_sequential = MB * OC < 2000; - parallel(force_sequential ? 1 : 0, [&](int ithr, int nthr) { - size_t start, end; - balance211((size_t)OC * MB, nthr, ithr, start, end); - (*pp_kernel_)(dst, acc, bias, scales, nslope, start, end); - }); - } -} - -using namespace data_type; - -template struct gemm_x8s8s32x_inner_product_fwd_t; -template struct gemm_x8s8s32x_inner_product_fwd_t; -template struct gemm_x8s8s32x_inner_product_fwd_t; -template struct gemm_x8s8s32x_inner_product_fwd_t; -template struct gemm_x8s8s32x_inner_product_fwd_t; -template struct gemm_x8s8s32x_inner_product_fwd_t; -template struct gemm_x8s8s32x_inner_product_fwd_t; -template struct gemm_x8s8s32x_inner_product_fwd_t; - -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_inner_product.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_inner_product.hpp deleted file mode 100644 index ac6a5c8f8..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_inner_product.hpp +++ /dev/null @@ -1,166 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#ifndef GEMM_X8S8S32X_INNER_PRODUCT_HPP -#define GEMM_X8S8S32X_INNER_PRODUCT_HPP - -#include - -#include "c_types_map.hpp" -#include "memory_tracking.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -#include "gemm/gemm.hpp" -#include "jit_generator.hpp" - -#include "cpu_inner_product_pd.hpp" -#include "cpu_primitive.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -template -struct gemm_x8s8s32x_inner_product_fwd_t: public cpu_primitive_t { - struct pd_t: public cpu_inner_product_fwd_pd_t { - using cpu_inner_product_fwd_pd_t::cpu_inner_product_fwd_pd_t; - - DECLARE_COMMON_PD_T(src_type == data_type::u8 - ? IGEMM_S8U8S32_IMPL_STR - : IGEMM_S8S8S32_IMPL_STR, - gemm_x8s8s32x_inner_product_fwd_t); - - status_t init() { - using namespace data_type; - - bool ok = true - && set_default_params() == status::success - && is_fwd() - && !has_zero_dim_memory() - && src_md()->data_type == src_type - && dst_md()->data_type == dst_type - && weights_md()->data_type == s8 - && IMPLICATION(with_bias(), utils::one_of( - weights_md(1)->data_type, f32, s32, s8, u8)) - && attr()->post_ops_.len_ <= 1 - && IMPLICATION(attr()->post_ops_.len_, - attr()->post_ops_.entry_[0].is_relu(true, false)) - && dense_gemm_consitency_check(src_md(), weights_md(), - dst_md()); - if (!ok) return status::unimplemented; - - dst_is_acc_ = utils::one_of(dst_type, s32, f32); - - init_scratchpad(); - - return status::success; - } - - bool dst_is_acc_; - - protected: - status_t set_default_params() { - using namespace format_tag; - if (src_md_.format_kind == format_kind::any) { - CHECK(memory_desc_init_by_tag(src_md_, - utils::pick(ndims() - 2, nc, nwc, nhwc, ndhwc))); - } - if (dst_md_.format_kind == format_kind::any) - CHECK(memory_desc_init_by_tag(dst_md_, nc)); - if (weights_md_.format_kind == format_kind::any) { - CHECK(memory_desc_init_by_tag(weights_md_, - utils::pick(ndims() - 2, io, wio, hwio, dhwio))); - } - return inner_product_fwd_pd_t::set_default_params(); - } - - private: - void init_scratchpad() { - if (!dst_is_acc_) { - auto scratchpad = scratchpad_registry().registrar(); - scratchpad.book( - memory_tracking::names::key_iprod_int_dat_in_acc_dt, - sizeof(acc_data_t) * MB() * OC()); - } - } - }; - - gemm_x8s8s32x_inner_product_fwd_t(const pd_t *apd) - : cpu_primitive_t(apd, true) - { pp_kernel_ = new pp_kernel_t(apd, pd()->dst_is_acc_); } - ~gemm_x8s8s32x_inner_product_fwd_t() { delete pp_kernel_; } - - typedef typename prec_traits::type data_t; - - typedef typename prec_traits::type src_data_t; - typedef typename prec_traits::type wei_data_t; - typedef typename prec_traits::type dst_data_t; - typedef typename prec_traits::type acc_data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_forward(ctx); - return status::success; - } - -private: - // XXX: this is throwaway code that will become unnecessary when we have a - // sufficiently advanced igemm jit generator that supports quantization, - // relu, and whatnot - class pp_kernel_t: jit_generator { - public: - DECLARE_CPU_JIT_AUX_FUNCTIONS( - gemm_x8s8s32x_inner_product_fwd_t::pp_kernel); - pp_kernel_t(const pd_t *pd, bool dst_is_acc); - - void operator()(dst_data_t *dst, const acc_data_t *acc, - const char *bias, const float *scales, float nslope, - size_t start, size_t end); - private: - void generate(); - - struct ker_args { - dst_data_t *dst; - const acc_data_t *acc; - const char *bias; - const float *scales; - float nslope; - 
size_t len; - size_t oc_offset; - }; - void (*ker_)(const ker_args *args); - - size_t OC_; - data_type_t bias_data_type_; - size_t bias_data_type_size_; - size_t scale_idx_mult_; - bool do_bias_; - bool do_relu_; - }; - - void execute_forward(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - - pp_kernel_t *pp_kernel_; -}; - -} -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_conv_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_conv_kernel_f32.cpp deleted file mode 100644 index 6fa251d46..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_conv_kernel_f32.cpp +++ /dev/null @@ -1,674 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* Copyright 2018 YANDEX LLC -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include - -#include "c_types_map.hpp" -#include "memory_tracking.hpp" -#include "nstl.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -#include "cpu_memory.hpp" - -#include "jit_avx2_1x1_conv_kernel_f32.hpp" - -#define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field) - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace mkldnn::impl::prop_kind; -using namespace mkldnn::impl::format_tag; -using namespace mkldnn::impl::utils; - -using namespace Xbyak; - -void jit_avx2_1x1_conv_kernel_f32::generate_bcast_loop(int load_loop_blk) -{ - mov(aux1_reg_bcast_data, reg_bcast_data); - mov(aux_reg_output_data, reg_output_data); - mov(bcast_loop_iter, reg_bcast_loop_work); - - Label bcast_loop, bcast_loop_tail; - - cmp(bcast_loop_iter, jcp.ur); - jl(bcast_loop_tail, T_NEAR); - - L(bcast_loop); { - assert(jcp.bcast_block % jcp.ur == 0); - int num_substeps = jcp.bcast_block / jcp.ur; - assert(num_substeps > 0 && num_substeps < 10); - for (int i = 0; i < num_substeps; i++) { - generate_reduce_loop(load_loop_blk, jcp.ur); - if (i < num_substeps - 1) { - add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep); - add(aux_reg_output_data, jcp.bcast_loop_output_substep); - } else { - add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step - - (num_substeps - 1) * jcp.bcast_loop_bcast_substep); - add(aux_reg_output_data, jcp.bcast_loop_output_step - - (num_substeps - 1) * jcp.bcast_loop_output_substep); - } - } - sub(bcast_loop_iter, jcp.bcast_block); - cmp(bcast_loop_iter, jcp.bcast_block); - jge(bcast_loop, T_NEAR); - } - - L(bcast_loop_tail); - if (jcp.ur_tail) { - Label bcast_loop_tail_out; - cmp(bcast_loop_iter, 0); - jz(bcast_loop_tail_out, T_NEAR); - generate_reduce_loop(load_loop_blk, jcp.ur_tail); - L(bcast_loop_tail_out); - } -} - -void jit_avx2_1x1_conv_kernel_f32::generate_reduce_loop( - int load_loop_blk, int ur) -{ - auto vreg_load = [=](int i) { - return Ymm(ur * load_loop_blk + i); - }; - - auto vreg_accum = [=](int i, int j) { - return Ymm(j * load_loop_blk + 
i); - }; - - auto bias_ptr = [=](int i) { - return ptr[reg_bias_data + sizeof(float) * jcp.oc_block * i]; - }; - - auto bcast_ptr = [=](int u, int j) { - assert(j < jcp.ur); - assert(u <= jcp.reduce_loop_unroll); - size_t offt; - if (one_of(jcp.prop_kind, - forward_training, forward_inference, backward_data)) - { - assert(jcp.reduce_loop_unroll == (jcp.prop_kind == backward_data) - ? jcp.oc_block : jcp.ic_block); - auto height = (jcp.prop_kind == backward_data) ? jcp.os : jcp.is; - offt = (u == jcp.reduce_loop_unroll) - ? (height + j) * jcp.reduce_loop_unroll - : j * jcp.reduce_loop_unroll + u; - } else - offt = u * jcp.ic_block + j; - return ptr[aux_reg_bcast_data + sizeof(float) * offt]; - }; - - auto load_ptr = [=](int u, int i) { - size_t offt; - size_t u0 = u % jcp.reduce_loop_unroll; - size_t u1 = u / jcp.reduce_loop_unroll; - switch (jcp.prop_kind) { - case backward_data: - offt = (i * jcp.oc_block + u0) * jcp.ic_block; - break; - case backward_weights: - offt = (i * jcp.os + u0) * jcp.oc_block; - break; - default: - offt = (i * jcp.ic + u0) * jcp.oc_block; - } - return ptr[aux_reg_load_data - + u1 * jcp.reduce_loop_load_step + sizeof(float) * offt]; - }; - - auto output_ptr = [=](int i, int j) { - switch (jcp.prop_kind) { - case backward_data: - return ptr[aux_reg_output_data + - (i * jcp.is + j) * jcp.ic_block * sizeof(float)]; - case backward_weights: - return ptr[aux_reg_output_data - + (i ? reg_output_stride * i : 0) // TODO: Xbyak should allow 0 scale - + sizeof(float) * jcp.oc_block * j]; - default: - return ptr[aux_reg_output_data + - (i * jcp.os + j) * jcp.oc_block * sizeof(float)]; - } - }; - - auto init = [=]() { - Label init_done, init_zero; - - if (jcp.with_bias && one_of(jcp.prop_kind, forward_training, - forward_inference)) { - test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); - jz(init_zero); - - for (int i = 0; i < load_loop_blk; i++) - for (int j = 0; j < ur; ++j) - vmovups(vreg_accum(i, j), bias_ptr(i)); - jmp(init_done); - } - - L(init_zero); - for (int i = 0; i < load_loop_blk; ++i) - for (int j = 0; j < ur; ++j) { - auto r = vreg_accum(i, j); - vxorps(r, r, r); - } - - L(init_done); - for (int i = 0; i < load_loop_blk; ++i) - vmovups(vreg_load(i), load_ptr(0, i)); - vbroadcastss(vreg_bcast, bcast_ptr(0, 0)); - }; - - auto store = [=]() { - Label store_noadd; - - if (!jcp.with_sum) { - test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); - jnz(store_noadd, T_NEAR); - } - - for (int j = 0; j < ur; ++j) - for (int i = 0; i < load_loop_blk; ++i) { - auto r = vreg_accum(i, j); - vaddps(r, r, output_ptr(i, j)); - } - - L(store_noadd); - - if (jcp.with_eltwise) { - assert(ur * load_loop_blk < 14); - - Label store_norelu; - test(reg_reduce_pos_flag, FLAG_REDUCE_LAST); - jz(store_norelu, T_NEAR); - - eltwise_injector_->compute_vector_range(0, ur * load_loop_blk); - - L(store_norelu); - } - - for (int j = 0; j < ur; ++j) - for (int i = 0; i < load_loop_blk; ++i) { - vmovups(output_ptr(i, j), vreg_accum(i, j)); - } - }; - - auto fma_block = [=](bool last_block) { - for (int u = 0; u < jcp.reduce_loop_unroll; ++u) { - for (int j = 0; j < ur; ++j) { - for (int i = 0; i < load_loop_blk; ++i) { - if (mayiuse(avx2)) - vfmadd231ps(vreg_accum(i, j), vreg_load(i), vreg_bcast); - else { // Intel(R) Advanced Vector Extensions (Intel(R) AVX) support - vmulps(vtmp, vreg_bcast, vreg_load(i)); - vaddps(vreg_accum(i, j), vreg_accum(i, j), vtmp); - } - if (j == ur - 1 && !(last_block - && u == jcp.reduce_loop_unroll - 1)) - vmovups(vreg_load(i), load_ptr(u + 1, i)); - } - if (j < ur - 1) - 
vbroadcastss(vreg_bcast, bcast_ptr(u, j + 1)); - } - if (!last_block || u < jcp.reduce_loop_unroll - 1) - vbroadcastss(vreg_bcast, bcast_ptr(u + 1, 0)); - } - }; - - Label reduce_loop, reduce_loop_tail; - - mov(aux_reg_load_data, reg_load_data); - mov(aux_reg_bcast_data, aux1_reg_bcast_data); - - init(); - - mov(reduce_loop_iter, reg_reduce_loop_work); - sub(reduce_loop_iter, jcp.reduce_loop_unroll); - jle(reduce_loop_tail, T_NEAR); - - L(reduce_loop); { - fma_block(false); - add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step); - add(aux_reg_load_data, jcp.reduce_loop_load_step); - sub(reduce_loop_iter, jcp.reduce_loop_unroll); - jg(reduce_loop, T_NEAR); - } - - L(reduce_loop_tail); - fma_block(true); - - store(); -} - -void jit_avx2_1x1_conv_kernel_f32::generate_diff_bias_loop(int load_loop_blk) -{ - if (!jcp.with_bias || jcp.prop_kind != backward_weights) - return; - - Label diff_bias_loop, diff_bias_loop_out, diff_bias_init_out; - Label diff_bias_load; - - auto diff_bias_ptr = [=](int i) { - return ptr[reg_diff_bias_data + i * jcp.oc_block * sizeof(float)]; - }; - - auto load_ptr = [=](int u, int i) { - return ptr[aux_reg_load_data - + (i * jcp.os + u) * jcp.oc_block * sizeof(float)]; - }; - - auto diff_bias_reg = [=](int i) { return Ymm(i); }; - - mov(reg_diff_bias_data, ptr[rsp + reg_diff_bias_data_stack_offt]); - cmp(reg_diff_bias_data, 0); - je(diff_bias_loop_out, T_NEAR); - - test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); - jz(diff_bias_load, T_NEAR); - - for (int i = 0; i < load_loop_blk; ++i) { - auto r = diff_bias_reg(i); - vxorps(r, r, r); - } - jmp(diff_bias_init_out, T_NEAR); - - L(diff_bias_load); - for (int i = 0; i < load_loop_blk; ++i) - vmovups(diff_bias_reg(i), diff_bias_ptr(i)); - - L(diff_bias_init_out); - mov(aux_reg_load_data, reg_load_data); - mov(reduce_loop_iter, reg_reduce_loop_work); - L(diff_bias_loop); { - for(int u = 0; u < jcp.reduce_loop_unroll; ++u) - for (int i = 0; i < load_loop_blk; ++i) - vaddps(diff_bias_reg(i), diff_bias_reg(i), load_ptr(u, i)); - assert(jcp.reduce_dim % jcp.reduce_loop_unroll == 0); - add(aux_reg_load_data, jcp.reduce_loop_load_step); - sub(reduce_loop_iter, jcp.reduce_loop_unroll); - jnz(diff_bias_loop, T_NEAR); - } - - for (int i = 0; i < load_loop_blk; i++) - vmovups(diff_bias_ptr(i), diff_bias_reg(i)); - add(reg_diff_bias_data, load_loop_blk * jcp.oc_block * sizeof(float)); - mov(ptr[rsp + reg_diff_bias_data_stack_offt], reg_diff_bias_data); - - L(diff_bias_loop_out); -} - -void jit_avx2_1x1_conv_kernel_f32::generate() -{ - preamble(); - - mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]); - mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]); - mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]); - if (jcp.with_bias) { - if (jcp.prop_kind == backward_weights) { - sub(rsp, stack_space_needed); - mov(reg_diff_bias_data, ptr[param1 + GET_OFF(bias_data)]); - mov(ptr[rsp + reg_diff_bias_data_stack_offt], reg_diff_bias_data); - } else - mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]); - } - - mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]); - mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]); - mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]); - mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]); - if (jcp.prop_kind == backward_weights) - mov(reg_output_stride, ptr[param1 + GET_OFF(output_stride)]); - - auto generate_load_loop_body = [=] (int load_loop_blk) { - generate_bcast_loop(load_loop_blk); - add(reg_load_data, load_loop_blk * jcp.load_loop_load_step); - switch 
(jcp.prop_kind) { - case forward_training: - case forward_inference: - add(reg_bias_data, load_loop_blk * jcp.oc_block * sizeof(float)); - add(reg_output_data, - load_loop_blk * jcp.os * jcp.oc_block * sizeof(float)); - break; - case backward_data: - add(reg_output_data, - load_loop_blk * jcp.is * jcp.ic_block * sizeof(float)); - break; - case backward_weights: - for (int i = 0; i < load_loop_blk; i++) - add(reg_output_data, reg_output_stride); - break; - default: - assert(!"invalid prop_kind"); - } - sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step); - }; - - Label load_loop_blk_8; - Label load_loop_blk_16; - Label load_loop_blk_24; - Label load_loop_blk_end; - - cmp(reg_load_loop_work, 8); - jle(load_loop_blk_8, T_NEAR); - - cmp(reg_load_loop_work, 32); - je(load_loop_blk_16, T_NEAR); - - cmp(reg_load_loop_work, 16); - jle(load_loop_blk_16, T_NEAR); - - L(load_loop_blk_24); { - generate_diff_bias_loop(3); - generate_load_loop_body(3); - cmp(reg_load_loop_work, 32); - je(load_loop_blk_16); - cmp(reg_load_loop_work, 24); - jge(load_loop_blk_24); - } - - cmp(reg_load_loop_work, 8); - jle(load_loop_blk_8, T_NEAR); - - L(load_loop_blk_16); { - generate_diff_bias_loop(2); - generate_load_loop_body(2); - cmp(reg_load_loop_work, 16); - jge(load_loop_blk_16); - } - - L(load_loop_blk_8); { - cmp(reg_load_loop_work, 0); - je(load_loop_blk_end, T_NEAR); - generate_diff_bias_loop(1); - generate_load_loop_body(1); - } - - L(load_loop_blk_end); - - if (jcp.with_bias && jcp.prop_kind == backward_weights) - add(rsp, 8); - - postamble(); - - if (jcp.with_eltwise) - eltwise_injector_->prepare_table(); -} - -bool jit_avx2_1x1_conv_kernel_f32::post_ops_ok( - jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) { - const auto &p = attr.post_ops_; - - auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; - auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; - - switch (p.len_) { - case 0: return true; // no post_ops - case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise - case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise - default: return false; - } - - return false; -} - -status_t jit_avx2_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp, - const convolution_desc_t &cd, const memory_desc_wrapper &src_d, - const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d, - const primitive_attr_t &attr) -{ - if (!mayiuse(avx)) return status::unimplemented; - - // TODO (Roma): this code is duplicated from the generic kernel; maybe the - // configuration struct could do some stuff below - const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; - const int ndims = src_d.ndims(); - - jcp.prop_kind = cd.prop_kind; - - jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; - jcp.mb = src_d.dims()[0]; - - jcp.oc = dst_d.dims()[1] / jcp.ngroups; - jcp.oc_without_padding = jcp.oc; - jcp.ic = src_d.dims()[1] / jcp.ngroups; - - jcp.ih = (ndims == 3) ? 1 : src_d.dims()[2]; - jcp.iw = src_d.dims()[ndims - 1]; - jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[2]; - jcp.ow = dst_d.dims()[ndims - 1]; - - jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + 2]; - jcp.kw = weights_d.dims()[with_groups + ndims - 1]; - - jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][0]; - jcp.l_pad = cd.padding[0][ndims - 3]; - - jcp.stride_h = (ndims == 3) ? 
1 : cd.strides[0]; - jcp.stride_w = cd.strides[ndims - 3]; - - jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; - - jcp.os = jcp.oh * jcp.ow; - jcp.is = jcp.ih * jcp.iw; - - if (!post_ops_ok(jcp, attr)) - return status::unimplemented; - - const auto &p = attr.post_ops_; - jcp.with_sum = p.find(primitive_kind::sum) != -1; - const int eltwise_ind = p.find(primitive_kind::eltwise); - jcp.with_eltwise = eltwise_ind != -1; - if (jcp.with_eltwise) { - jcp.eltwise = p.entry_[eltwise_ind].eltwise; - if (!mayiuse(avx2) && jcp.eltwise.alg != alg_kind::eltwise_relu) - return status::unimplemented; - } - - const int is_bwd_d = jcp.prop_kind == backward_data; - - format_tag_t dat_tag = ndims == 3 ? nCw8c : nChw8c; - format_tag_t wei_tag = with_groups - ? utils::pick(2 * ndims - 6 + is_bwd_d, gOIw8i8o, gOIw8o8i, gOIhw8i8o, - gOIhw8o8i) - : utils::pick(2 * ndims - 6 + is_bwd_d, OIw8i8o, OIw8o8i, OIhw8i8o, - OIhw8o8i); - - jcp.src_tag = src_d.matches_one_of_tag(dat_tag); - jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); - jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); - - const int simd_w = 8; - - jcp.oc = rnd_up(jcp.oc, simd_w); - jcp.ic = rnd_up(jcp.ic, simd_w); - - bool args_ok = true - && jcp.ngroups == 1 - && jcp.src_tag == dat_tag - && jcp.wei_tag == wei_tag - && jcp.dst_tag == dat_tag; - if (!args_ok) return status::unimplemented; - - args_ok = true - && jcp.ih == jcp.oh && jcp.iw == jcp.ow - && jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0 - && jcp.t_pad == 0 && jcp.l_pad == 0 - && jcp.stride_w == 1 && jcp.stride_h == 1 // TODO: support some strides - && jcp.kh == 1 && jcp.kw == 1; - if (!args_ok) return status::unimplemented; - - // TODO: remove this restriction - // optimized 1x1 bwd_w does not support Intel AVX - if (jcp.prop_kind == backward_weights && !mayiuse(avx2)) - return status::unimplemented; - - jcp.ic_block = jcp.oc_block = simd_w; - - jcp.ur = mayiuse(avx2) ? 
4 : 3; // Intel AVX support - - int load_blocking{ 0 }; - int load_blocking_max{ 0 }; - int bcast_blocking{ 0 }; - int bcast_blocking_max{ 0 }; - int reduce_blocking{ 0 }; - - if (one_of(jcp.prop_kind, forward_training, forward_inference)) { - jcp.reduce_dim = jcp.ic; - jcp.reduce_block = jcp.ic_block; - - jcp.load_dim = jcp.oc; - jcp.load_block = jcp.oc_block; - - jcp.bcast_dim = jcp.is; - jcp.bcast_block = jcp.ur; - - jcp.reduce_loop_unroll = jcp.reduce_block; - jcp.reduce_loop_bcast_step - = jcp.reduce_loop_unroll * jcp.is * sizeof(float); - jcp.reduce_loop_load_step - = jcp.reduce_loop_unroll * jcp.oc_block * sizeof(float); - - jcp.bcast_loop_output_step = jcp.ur * jcp.oc_block * sizeof(float); - jcp.bcast_loop_output_substep = -1; // unused - jcp.bcast_loop_bcast_step = jcp.ur * jcp.ic_block * sizeof(float); - jcp.bcast_loop_bcast_substep = -1; // unused - - jcp.load_loop_load_step = jcp.ic * jcp.oc_block * sizeof(float); - jcp.load_loop_iter_step = jcp.oc_block; - - load_blocking = 120; // assumes the kernel is jcp.ur x 3 - load_blocking_max = 144; - bcast_blocking = 128; // affects load balancing across threads - bcast_blocking_max = 192; - reduce_blocking = 128; // affects L1$ utilization - } else if (jcp.prop_kind == backward_data) { - jcp.reduce_dim = jcp.oc; - jcp.reduce_block = jcp.oc_block; - - jcp.load_dim = jcp.ic; - jcp.load_block = jcp.oc_block; - - jcp.bcast_dim = jcp.os; - jcp.bcast_block = jcp.ur; - - jcp.reduce_loop_unroll = jcp.reduce_block; - jcp.reduce_loop_bcast_step - = jcp.reduce_loop_unroll * jcp.os * sizeof(float); - jcp.reduce_loop_load_step - = jcp.reduce_loop_unroll * jcp.ic * sizeof(float); - - jcp.bcast_loop_output_step = jcp.ur * jcp.ic_block * sizeof(float); - jcp.bcast_loop_output_substep = -1; // unused - jcp.bcast_loop_bcast_step = jcp.ur * jcp.oc_block * sizeof(float); - jcp.bcast_loop_bcast_substep = -1; // unused - - jcp.load_loop_load_step = jcp.oc_block * jcp.ic_block * sizeof(float); - jcp.load_loop_iter_step = jcp.ic_block; - - load_blocking = 96; // assumes the kernel is jcp.ur x 3 - load_blocking_max = 144; - bcast_blocking = 128; // affects load balancing across threads - bcast_blocking_max = 196; - reduce_blocking = 64; // affects L1$ utilization - } else if (jcp.prop_kind == backward_weights) { - jcp.reduce_dim = jcp.os; - jcp.reduce_block = 1; - - jcp.load_dim = jcp.oc; - jcp.load_block = jcp.oc_block; - - jcp.bcast_dim = jcp.ic; - jcp.bcast_block = jcp.ic_block; - - jcp.reduce_loop_unroll = jcp.reduce_block; - jcp.reduce_loop_bcast_step - = jcp.reduce_loop_unroll * jcp.ic_block * sizeof(float); - jcp.reduce_loop_load_step - = jcp.reduce_loop_unroll * jcp.oc_block * sizeof(float); - - jcp.bcast_loop_output_step = jcp.oc_block * jcp.ic_block * sizeof(float); - jcp.bcast_loop_output_substep = jcp.oc_block * jcp.ur * sizeof(float); - jcp.bcast_loop_bcast_step = jcp.ic_block * jcp.is * sizeof(float); - jcp.bcast_loop_bcast_substep = jcp.ur * sizeof(float); - - jcp.load_loop_load_step = jcp.oc_block * jcp.os * sizeof(float); - jcp.load_loop_iter_step = jcp.oc_block; - - /* --- */ - - load_blocking = div_up(jcp.load_dim, jcp.load_block); - while (true) { - if (load_blocking <= 32) break; - else if (load_blocking % 2 == 0) load_blocking /= 2; - else if (load_blocking % 3 == 0) load_blocking /= 3; - else break; - } - load_blocking *= jcp.load_block; - load_blocking_max = load_blocking; - assert(jcp.load_dim % load_blocking == 0); - - bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block); - while (true) { - if (bcast_blocking <= 9) break; - else 
if (bcast_blocking % 2 == 0) bcast_blocking /= 2; - else if (bcast_blocking % 3 == 0) bcast_blocking /= 3; - else break; - } - bcast_blocking *= jcp.bcast_block; - bcast_blocking_max = bcast_blocking; - assert(jcp.bcast_dim % bcast_blocking == 0); - - reduce_blocking = 128; // affects L1$ utilization - } else - return status::unimplemented; - - assert(load_blocking); - assert(load_blocking_max); - assert(bcast_blocking); - assert(bcast_blocking_max); - assert(reduce_blocking); - - assert(jcp.bcast_block % jcp.ur == 0); - jcp.ur_tail = jcp.bcast_dim % jcp.ur; - - jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block; - jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block; - jcp.nb_load_blocking = load_blocking / jcp.load_block; - jcp.nb_load_blocking_max = load_blocking_max / jcp.load_block; - jcp.nb_reduce_blocking = reduce_blocking / jcp.reduce_block; - - jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block); - jcp.nb_load = div_up(jcp.load_dim, jcp.load_block); - jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block); - - return status::success; -} - -void jit_avx2_1x1_conv_kernel_f32::init_scratchpad( - memory_tracking::registrar_t &scratchpad, - const jit_1x1_conv_conf_t &jcp) { - using namespace mkldnn::impl::memory_tracking::names; - - if (jcp.prop_kind != backward_data && jcp.oc != jcp.oc_without_padding) - scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc); -} - -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_conv_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_conv_kernel_f32.hpp deleted file mode 100644 index bfdeb2b18..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_conv_kernel_f32.hpp +++ /dev/null @@ -1,110 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#ifndef JIT_AVX2_1x1_CONV_KERNEL_F32_HPP -#define JIT_AVX2_1x1_CONV_KERNEL_F32_HPP - -#include "c_types_map.hpp" -#include "memory_tracking.hpp" - -#include "cpu_memory.hpp" -#include "jit_generator.hpp" -#include "jit_primitive_conf.hpp" -#include "jit_uni_eltwise.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -struct jit_avx2_1x1_conv_kernel_f32: public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_1x1_conv_kernel_f32) - - jit_avx2_1x1_conv_kernel_f32(jit_1x1_conv_conf_t ajcp, - const primitive_attr_t &attr) - : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr) - { - if (jcp.with_eltwise) - eltwise_injector_ = new jit_uni_eltwise_injector_f32(this, - jcp.eltwise); - - this->generate(); - jit_ker = (void (*)(jit_1x1_conv_call_s *))this->getCode(); - } - - ~jit_avx2_1x1_conv_kernel_f32() { - delete eltwise_injector_; - } - - static bool post_ops_ok(jit_1x1_conv_conf_t &jcp, - const primitive_attr_t &attr); - - static status_t init_conf(jit_1x1_conv_conf_t &jcp, - const convolution_desc_t &cd, - const memory_desc_wrapper &src_d, - const memory_desc_wrapper &weights_d, - const memory_desc_wrapper &dst_d, - const primitive_attr_t &attr); - - static void init_scratchpad(memory_tracking::registrar_t &scratchpad, - const jit_1x1_conv_conf_t &jcp); - - jit_1x1_conv_conf_t jcp; - const primitive_attr_t &attr_; - void (*jit_ker)(jit_1x1_conv_call_s *); - -private: - using reg64_t = const Xbyak::Reg64; - using ymm_t = const Xbyak::Ymm; - - reg64_t reg_bcast_data = rax; - reg64_t reg_load_data = rsi; - reg64_t reg_output_data = rbx; - reg64_t aux_reg_bcast_data = rdx; - reg64_t aux1_reg_bcast_data = abi_not_param1; - reg64_t aux_reg_load_data = abi_param1; - reg64_t aux_reg_output_data = rbp; - reg64_t reg_load_loop_work = r9; - reg64_t reg_bcast_loop_work = r10; - reg64_t reg_reduce_loop_work = r11; - reg64_t load_loop_iter = r13; - reg64_t bcast_loop_iter = r14; - reg64_t reduce_loop_iter = r15; - reg64_t imm_addr64 = reduce_loop_iter; - reg64_t reg_reduce_pos_flag = r8; - reg64_t reg_output_stride = r12; - reg64_t reg_bias_data = r12; - reg64_t reg_diff_bias_data = bcast_loop_iter; - - int reg_diff_bias_data_stack_offt = 0; - int stack_space_needed = 8; - - ymm_t vreg_bcast = ymm_t(15); - ymm_t vtmp = ymm_t(14); - - jit_uni_eltwise_injector_f32 *eltwise_injector_; - - void generate_bcast_loop(int load_loop_blk); - void generate_reduce_loop(int load_loop_blk, int ur); - void generate_diff_bias_loop(int load_loop_blk); - - void generate(); -}; - -} -} -} - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_convolution.cpp deleted file mode 100644 index f116ac905..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_convolution.cpp +++ /dev/null @@ -1,545 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include "c_types_map.hpp" -#include "mkldnn_thread.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -#include "jit_generator.hpp" - -#include "jit_avx2_1x1_convolution.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace mkldnn::impl::status; -using namespace mkldnn::impl::memory_tracking::names; -using namespace mkldnn::impl::utils; - -#define data_blk_off(f, n, c, h, w) \ - ((ndims == 3) \ - ? (f).blk_off(n, c, w) \ - : (f).blk_off(n, c, h, w)) - -/* convolution forward */ - -void jit_avx2_1x1_convolution_fwd_t::execute_forward( - const exec_ctx_t &ctx) const { - auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); - auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); - auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS); - auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); - - const memory_desc_wrapper src_d(pd()->src_md()); - const memory_desc_wrapper dst_d(pd()->dst_md()); - const memory_desc_wrapper weights_d(pd()->weights_md(0)); - - const auto &jcp = kernel_->jcp; - auto rtus_space = scratchpad(ctx).get(key_conv_rtus_space); - - const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast; - const int ndims = dst_d.ndims(); - - const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0]; - const int stride_w = pd()->desc()->strides[ndims - 3]; - const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0]; - const int pad_l = pd()->desc()->padding[0][ndims - 3]; - - auto step = [](int default_step, int remaining, int tail_step) { - assert(default_step <= tail_step); - return remaining < tail_step ? remaining : default_step; - }; - - auto ker = [&](const int ithr, const int nthr) { - // TODO (Roma): remove this restriction - assert(jcp.stride_w == 1 && jcp.stride_h == 1); - - auto p = jit_1x1_conv_call_s(); - auto rp = rtus_driver_t::call_params_t(); - - const int nb_oc = jcp.nb_load; - const int nb_ic = jcp.nb_reduce; - const int nb_ic_blocking = jcp.nb_reduce_blocking; - const int os_block = jcp.bcast_block; - - int start{0}, end{0}; - balance211(work_amount, nthr, ithr, start, end); - - int iwork = start; - while (iwork < end) { - int n{0}, g{0}, osb{0}; - nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, - jcp.nb_bcast); - - int bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb, - jcp.nb_bcast_blocking_max); - bcast_step = nstl::min(bcast_step, end - iwork); - - const int os = osb * os_block; - const int oh = os / jcp.ow; - const int ow = os % jcp.ow; - - const int ih = nstl::max(oh * stride_h - pad_t, 0); - const int iw = nstl::max(ow * stride_w - pad_l, 0); - rp.iw_start = iw; - - p.bcast_dim = this_block_size(os, jcp.os, bcast_step * os_block); - rp.os = p.bcast_dim; - - int ocb = 0; - while (ocb < jcp.nb_load) { - const int load_step = step(jcp.nb_load_blocking, - jcp.nb_load - ocb, jcp.nb_load_blocking_max); - - const int _ocb = g * nb_oc + ocb; - p.load_dim = this_block_size(ocb * jcp.oc_block, jcp.oc, - load_step * jcp.oc_block); - const size_t dst_off = data_blk_off(dst_d, n, _ocb, oh, ow); - - p.output_data = &dst[dst_off]; - - p.bias_data = &bias[_ocb * jcp.oc_block]; - - for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { - p.first_last_flag = 0 - | (icb == 0 ? FLAG_REDUCE_FIRST : 0) - | (icb + nb_ic_blocking >= nb_ic - ? 
FLAG_REDUCE_LAST : 0); - - p.reduce_dim = this_block_size(icb * jcp.ic_block, jcp.ic, - nb_ic_blocking * jcp.ic_block); - rp.icb = p.reduce_dim / jcp.reduce_block; - - p.load_data = &weights[pd()->with_groups() - ? weights_d.blk_off(g, ocb, icb) - : weights_d.blk_off(ocb, icb)]; - - const int _icb = g * nb_ic + icb; - if (pd()->rtus_.reduce_src_) { - rp.ws = rtus_space - + ithr * pd()->rtus_.space_per_thread_ - + _icb * jcp.is * jcp.ic_block; - - if (ocb == 0) { - rp.src = src + data_blk_off(src_d, n, _icb, ih, iw); - rtus_driver_->ker_(&rp); - } - - p.bcast_data = rp.ws; - } else - p.bcast_data = src + data_blk_off(src_d, n, _icb, ih, iw); - - kernel_->jit_ker(&p); - } - - ocb += load_step; - } - - iwork += bcast_step; - } - }; - - if (pd()->wants_padded_bias()) { - auto padded_bias = scratchpad(ctx).get(key_conv_padded_bias); - utils::array_copy(padded_bias, bias, jcp.oc_without_padding); - utils::array_set(padded_bias + jcp.oc_without_padding, 0.f, - jcp.oc - jcp.oc_without_padding); - bias = padded_bias; - } - - parallel(0, ker); - - if (pd()->wants_zero_pad_dst()) - ctx.memory(MKLDNN_ARG_DST)->zero_pad(); -} - -/* convolution backward wtr data */ - -void jit_avx2_1x1_convolution_bwd_data_t::execute_backward_data( - const exec_ctx_t &ctx) const { - auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); - auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); - auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); - - const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); - const memory_desc_wrapper weights_d(pd()->weights_md(0)); - const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); - - const auto &jcp = kernel_->jcp; - auto rtus_space = scratchpad(ctx).get(key_conv_rtus_space); - - // TODO (Roma): remove this restriction - assert(jcp.stride_w == 1 && jcp.stride_h == 1); - const int ndims = diff_dst_d.ndims(); - - const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0]; - const int stride_w = pd()->desc()->strides[ndims - 3]; - const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0]; - const int pad_l = pd()->desc()->padding[0][ndims - 3]; - - const int nb_ic = jcp.nb_load; - const int nb_oc = jcp.nb_reduce; - const int os_block = jcp.bcast_block; - const int nb_oc_blocking = jcp.nb_reduce_blocking; - - const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast; - - auto step = [](int default_step, int remaining, int tail_step) { - assert(default_step <= tail_step); - return remaining < tail_step ? 
remaining : default_step; - }; - - auto ker = [&](const int ithr, const int nthr) { - auto p = jit_1x1_conv_call_s(); - auto rp = rtus_driver_t::call_params_t(); - - int start{0}, end{0}; - balance211(work_amount, nthr, ithr, start, end); - - int load_step = 0; - for (int icb = 0; icb < jcp.nb_load; icb += load_step) { - load_step = step(jcp.nb_load_blocking, jcp.nb_load - icb, - jcp.nb_load_blocking_max); - - p.load_dim = this_block_size(icb * jcp.ic_block, jcp.ic, - load_step * jcp.ic_block); - rp.icb = p.load_dim / jcp.ic_block; - - int bcast_step; - for (int iwork = start; iwork < end; iwork += bcast_step) { - int n{0}, g{0}, osb{0}; - nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, - jcp.nb_bcast); - - bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb, - jcp.nb_bcast_blocking_max); - bcast_step = nstl::min(bcast_step, end - iwork); - - const int os = osb * os_block; - p.bcast_dim = this_block_size(os, jcp.os, - bcast_step * os_block); - rp.os = p.bcast_dim; - - const int oh = os / jcp.ow; - const int ow = os % jcp.ow; - const int ih = nstl::max(oh * stride_h - pad_t, 0); - const int iw = nstl::max(ow * stride_w - pad_l, 0); - rp.iw_start = iw; - - const int _icb = g * nb_ic + icb; - rp.src = diff_src + data_blk_off(diff_src_d, n, _icb, ih, iw); - if (pd()->rtus_.reduce_src_) { - rp.ws = rtus_space - + ithr * pd()->rtus_.space_per_thread_; - p.output_data = rp.ws; - } else - p.output_data = rp.src; - - for (int ocb = 0; ocb < jcp.nb_reduce; - ocb += jcp.nb_reduce_blocking) { - const int _ocb = g * nb_oc + ocb; - size_t diff_dst_off = data_blk_off(diff_dst_d, n, _ocb, oh, - ow); - p.bcast_data = &diff_dst[diff_dst_off]; - - p.load_data = &weights[pd()->with_groups() - ? weights_d.blk_off(g, ocb, icb) - : weights_d.blk_off(ocb, icb)]; - - p.first_last_flag = ocb == 0 ? FLAG_REDUCE_FIRST : 0; - - p.reduce_dim = this_block_size(ocb * jcp.oc_block, jcp.oc, - nb_oc_blocking * jcp.oc_block); - - kernel_->jit_ker(&p); - } - - if (pd()->rtus_.reduce_src_) - rtus_driver_->ker_(&rp); - } - } - }; - - parallel(0, ker); -} - -/* convolution backward wtr weights */ - -jit_avx2_1x1_convolution_bwd_weights_t::jit_avx2_1x1_convolution_bwd_weights_t( - const pd_t *apd) - : cpu_primitive_t(apd) - , kernel_(nullptr) - , rtus_driver_(nullptr) -{ - kernel_ = new jit_avx2_1x1_conv_kernel_f32(pd()->jcp_, *pd()->attr()); - reducer_weights_ = - new cpu_reducer_2d_t(pd()->reducer_wei_conf_); - reducer_bias_ = new cpu_reducer_t(pd()->reducer_bia_conf_); - init_rtus_driver(this); -} - -void jit_avx2_1x1_convolution_bwd_weights_t::execute_backward_weights( - const exec_ctx_t &ctx) const { - auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); - auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); - auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS); - auto diff_bias_in = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS); - - auto scratchpad = this->scratchpad(ctx); - - const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); - const memory_desc_wrapper src_d(pd()->src_md()); - const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); - const memory_desc_wrapper diff_bias_d(pd()->diff_weights_md(1)); - - const auto &jcp = kernel_->jcp; - auto rtus_space = scratchpad.get(key_conv_rtus_space); - - data_t *diff_bias = pd()->wants_padded_bias() - ? 
scratchpad.get(key_conv_padded_bias) : diff_bias_in; - - auto reducer_bia_scratchpad = memory_tracking::grantor_t(scratchpad, - prefix_reducer_bia); - auto rb = this->reducer_bias_; - rb->init(reducer_bia_scratchpad); - - auto reducer_wei_scratchpad = memory_tracking::grantor_t(scratchpad, - prefix_reducer_wei); - auto rw = this->reducer_weights_; - rw->init(reducer_wei_scratchpad); - - const int ndims = diff_dst_d.ndims(); - // TODO (Roma): remove this restriction - assert(jcp.stride_w == 1 && jcp.stride_h == 1); - - const int nb_ic = jcp.nb_bcast; - const int nb_ic_blocking = jcp.nb_bcast_blocking; - const int bcast_work = div_up(nb_ic, nb_ic_blocking); - - const int nb_oc = jcp.nb_load; - const int nb_oc_blocking = jcp.nb_load_blocking; - const int load_work = div_up(nb_oc, nb_oc_blocking); - - const int sp_dim = jcp.reduce_dim; - const int mb_sp_work = jcp.mb * sp_dim; - - const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0]; - const int stride_w = pd()->desc()->strides[ndims - 3]; - const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0]; - const int pad_l = pd()->desc()->padding[0][ndims - 3]; - - auto step = [](int default_step, int remaining, int tail_step) { - assert(default_step <= tail_step); - return remaining < tail_step ? remaining : default_step; - }; - - auto oc_ic_sp_loop = [=](int sp_start, int sp_end, bool first_image, - data_t *store_to, size_t store_to_ld, const data_t *diff_dst, - const data_t *src, int ithr) { - auto p = jit_1x1_conv_call_s(); - auto rp = rtus_driver_t::call_params_t(); - - p.output_stride = store_to_ld * sizeof(float); - const int sp_step_def = jcp.nb_reduce_blocking * jcp.reduce_block; - - int oc_b_step = 0; - for (int oc_b = 0; oc_b < nb_oc_blocking; oc_b += oc_b_step) { - oc_b_step = step(12, nb_oc_blocking - oc_b, 18); - p.load_dim = oc_b_step * jcp.oc_block; - - int ic_b_step = 0; - for (int ic_b = 0; ic_b < nb_ic_blocking; ic_b += ic_b_step) { - ic_b_step = step(12, nb_ic_blocking - ic_b, 18); - p.bcast_dim = ic_b_step * jcp.ic_block; - rp.icb = p.bcast_dim / jcp.ic_block; - - p.output_data = store_to + oc_b * store_to_ld - + ic_b * jcp.ic_block * jcp.oc_block; - - /* spatial reduction */ - int sp_step = 0; - for (int sp = sp_start; sp < sp_end; sp += sp_step) { - sp_step = step(sp_step_def, sp_end - sp, 192); - p.reduce_dim = sp_step; - rp.os = p.reduce_dim; - - p.first_last_flag = sp == sp_start && first_image - ? 
FLAG_REDUCE_FIRST : 0; - - p.load_data = diff_dst - + (oc_b * jcp.reduce_dim + sp) * jcp.oc_block; - - if (pd()->rtus_.reduce_src_) { - const int oh = sp / jcp.ow; - const int ow = sp % jcp.ow; - - const int ih = nstl::max(oh * stride_h - pad_t, 0); - const int iw = nstl::max(ow * stride_w - pad_l, 0); - rp.iw_start = iw; - - rp.ws = rtus_space - + ithr * pd()->rtus_.space_per_thread_ - + (ic_b * jcp.is + sp) * jcp.ic_block; - if (ndims == 3) - rp.src = src - + iw * src_d.blocking_desc().strides[2]; - else - rp.src = src - + ih * src_d.blocking_desc().strides[2] - + iw * src_d.blocking_desc().strides[3]; - - if (oc_b == 0) - rtus_driver_->ker_(&rp); - - p.bcast_data = rp.ws; - } else - p.bcast_data = src - + (ic_b * jcp.reduce_dim + sp) * jcp.ic_block; - - kernel_->jit_ker(&p); - } - } - } - }; - - auto ker = [&](const int ithr, const int nthr) { - assert(nthr == rw->balancer().nthr_); - - const int w_njobs = rw->balancer().ithr_njobs(ithr); - if (w_njobs == 0) return; - - /* setup: independent work (oc, ic) */ - const int w_job_start = rw->balancer().ithr_job_off(ithr); - int g{0}, load_i{0}, bcast_i{0}; - nd_iterator_init(w_job_start, g, jcp.ngroups, load_i, load_work, - bcast_i, bcast_work); - - /* setup: reduction work (mb, sp) */ - int mb_sp_start{0}, mb_sp_end{0}; - balance211(mb_sp_work, rw->balancer().nthr_per_group_, - rw->balancer().id_in_group(ithr), mb_sp_start, mb_sp_end); - int img_start{0}, sp_start{0}; - nd_iterator_init(mb_sp_start, img_start, jcp.mb, sp_start, sp_dim); - - /* independent work */ - for (int iwork = 0; iwork < w_njobs; ++iwork) { - const int oc_b = nb_oc_blocking * load_i; - const int ic_b = nb_ic_blocking * bcast_i; - - const int _ic_b = g * nb_ic + ic_b; - const int _oc_b = g * nb_oc + oc_b; - - data_t *store_to; - size_t store_to_ld; - - if (rw->balancer().nthr_per_group_ == 1) { - const size_t off = pd()->with_groups() - ? 
diff_weights_d.blk_off(g, oc_b, ic_b) - : diff_weights_d.blk_off(oc_b, ic_b); - store_to = &diff_weights[off]; - store_to_ld = jcp.ic * jcp.oc_block; - } else { - const size_t off = iwork * rw->balancer().job_size_; - store_to = - rw->get_local_ptr(ithr, reducer_wei_scratchpad) + off; - store_to_ld = nb_ic_blocking * jcp.ic_block * jcp.oc_block; - } - - /* reduction work */ - int img = img_start; - int sp = sp_start; - int sp_step = 0; - for (int mb_sp = mb_sp_start; mb_sp < mb_sp_end; mb_sp += sp_step) - { - sp_step = nstl::min(sp_dim - sp, mb_sp_end - mb_sp); - - const bool first_image = img == img_start; - oc_ic_sp_loop(sp, sp + sp_step, first_image, store_to, - store_to_ld, &diff_dst[diff_dst_d.blk_off(img, _oc_b)], - &src[src_d.blk_off(img, _ic_b)], ithr); - - sp = 0; - img += 1; - } - - nd_iterator_step(g, jcp.ngroups, load_i, load_work, bcast_i, - bcast_work); - } - rw->reduce(ithr, diff_weights, reducer_wei_scratchpad); - }; - - auto ker_bias = [&](int ithr, int nthr) { - assert(nthr == rb->balancer().nthr_); - - const int b_job_start = rb->balancer().ithr_job_off(ithr); - const int b_njobs = rb->balancer().ithr_njobs(ithr); - - if (b_njobs == 0) return; - - /* reduction dimension */ - int img_start{0}, img_end{0}; - balance211(jcp.mb, rb->balancer().nthr_per_group_, - rb->balancer().id_in_group(ithr), img_start, img_end); - - /* jobs */ - int g_start{0}, ocb_start{0}; - nd_iterator_init(b_job_start, g_start, jcp.ngroups, ocb_start, nb_oc); - - for (int img = img_start; img < img_end; ++img) { - int g = g_start, ocb = ocb_start; - for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) { - const size_t _oc = g * nb_oc + ocb; - - const data_t *d_dst = &diff_dst[diff_dst_d.blk_off(img, _oc)]; - data_t *d_bias = - rb->get_local_ptr(ithr, diff_bias, reducer_bia_scratchpad) - + b_job_loc * rb->balancer().job_size_; - - if (img == img_start) - for (int o = 0; o < 8; ++o) d_bias[o] = 0.; - - for (int hw = 0; hw < jcp.oh * jcp.ow; ++hw) { - PRAGMA_OMP_SIMD() - for (int o = 0; o < 8; ++o) - d_bias[o] += d_dst[o]; - d_dst += 8; - } - - nd_iterator_step(g, jcp.ngroups, ocb, nb_oc); - } - } - rb->reduce(ithr, diff_bias, reducer_bia_scratchpad); - }; - - parallel(0, [&](const int ithr, const int nthr) { - ker(ithr, nthr); - if (pd()->with_bias()) - ker_bias(ithr, nthr); - }); - - /* TODO: put this in ker_bias */ - if (pd()->wants_padded_bias()) { - assert(jcp.ngroups == 1); - for (int oc = 0; oc < jcp.oc_without_padding; ++oc) - diff_bias_in[oc] = diff_bias[oc]; - } -} - -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_convolution.hpp deleted file mode 100644 index 976224217..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_convolution.hpp +++ /dev/null @@ -1,344 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
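The thread drivers above (execute_forward, execute_backward_data, execute_backward_weights) all slice their work with balance211() plus a small step() lambda before walking the slice with nd_iterator_init()/nd_iterator_step(). A compilable sketch, assuming balance211() has the usual near-even, contiguous-split semantics (the exact library implementation may differ); step_sketch() restates the local lambda from the drivers above:

    // Sketch, not the library source: contiguous near-even work split and the
    // tail-aware step size used by the convolution drivers above.
    void balance211_sketch(int n, int nthr, int ithr, int &start, int &end) {
        const int base = n / nthr;
        const int rem = n % nthr;          // first `rem` threads take one extra job
        start = ithr * base + (ithr < rem ? ithr : rem);
        end = start + base + (ithr < rem ? 1 : 0);
    }

    int step_sketch(int default_step, int remaining, int tail_step) {
        // same logic as the local lambda: fall back to the remainder only when
        // it is already smaller than the largest tail block accepted
        return remaining < tail_step ? remaining : default_step;
    }

Each thread then iterates its [start, end) range over the (mb, ngroups, nb_bcast) space, as the nd_iterator_init() calls above show.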
-*******************************************************************************/ - -#ifndef CPU_JIT_AVX2_1x1_CONVOLUTION_HPP -#define CPU_JIT_AVX2_1x1_CONVOLUTION_HPP - -#include "c_types_map.hpp" -#include "memory_tracking.hpp" -#include "mkldnn_thread.hpp" -#include "utils.hpp" - -#include "cpu_convolution_pd.hpp" -#include "cpu_primitive.hpp" -#include "cpu_reducer.hpp" - -#include "jit_avx2_1x1_conv_kernel_f32.hpp" -#include "jit_uni_1x1_conv_utils.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -struct jit_avx2_1x1_convolution_fwd_t: public cpu_primitive_t { - // TODO: (Roma) Code duplication duplication! Remove with templates - // (maybe...)! - struct pd_t: public cpu_convolution_fwd_pd_t { - pd_t(engine_t *engine, const convolution_desc_t *adesc, - const primitive_attr_t *attr, - const typename pd_t::base_class *hint_fwd_pd) - : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) - , jcp_(), rtus_() {} - - DECLARE_COMMON_PD_T( - JIT_IMPL_NAME_HELPER("jit_1x1:", avx2, ""), - jit_avx2_1x1_convolution_fwd_t); - - status_t init() { - bool ok = true - && is_fwd() - && set_default_alg_kind(alg_kind::convolution_direct) - && expect_data_types(data_type::f32, data_type::f32, - data_type::f32, data_type::f32, data_type::f32) - && !has_zero_dim_memory() - && set_default_formats(); - if (!ok) return status::unimplemented; - - const convolution_desc_t *conv_d = desc(); - const memory_desc_t *src_d = src_md(); - rtus_prepare(this, conv_d, src_d, dst_md()); - - status_t status = jit_avx2_1x1_conv_kernel_f32::init_conf(jcp_, - *conv_d, *src_d, *weights_md(), *dst_md(), *attr()); - if (status != status::success) return status; - - auto scratchpad = scratchpad_registry().registrar(); - jit_avx2_1x1_conv_kernel_f32::init_scratchpad(scratchpad, jcp_); - - rtus_prepare_space_info(this, scratchpad); - - return status::success; - } - - jit_1x1_conv_conf_t jcp_; - reduce_to_unit_stride_t rtus_; - - protected: - bool set_default_formats() { - using namespace format_tag; - - auto dat_tag = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); - auto wei_tag = with_groups() - ? 
utils::pick(ndims() - 3, gOIw8i8o, gOIhw8i8o) - : utils::pick(ndims() - 3, OIw8i8o, OIhw8i8o); - - return set_default_formats_common(dat_tag, wei_tag, dat_tag); - } - }; - - template - friend void init_rtus_driver(conv_t *self); - - jit_avx2_1x1_convolution_fwd_t(const pd_t *apd) - : cpu_primitive_t(apd) - , kernel_(nullptr), rtus_driver_(nullptr) - { - kernel_ = new jit_avx2_1x1_conv_kernel_f32(pd()->jcp_, *pd()->attr()); - init_rtus_driver(this); - } - - ~jit_avx2_1x1_convolution_fwd_t() { - delete kernel_; - delete rtus_driver_; - } - - typedef typename prec_traits::type data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_forward(ctx); - return status::success; - } - -private: - void execute_forward(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - - jit_avx2_1x1_conv_kernel_f32 *kernel_; - rtus_driver_t *rtus_driver_; -}; - -struct jit_avx2_1x1_convolution_bwd_data_t: public cpu_primitive_t { - struct pd_t: public cpu_convolution_bwd_data_pd_t { - pd_t(engine_t *engine, - const convolution_desc_t *adesc, - const primitive_attr_t *attr, - const convolution_fwd_pd_t *hint_fwd_pd) - : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) - , jcp_(), rtus_() {} - - DECLARE_COMMON_PD_T( - JIT_IMPL_NAME_HELPER("jit_1x1:", avx2, ""), - jit_avx2_1x1_convolution_bwd_data_t); - - status_t init() { - bool ok = true - && desc()->prop_kind == prop_kind::backward_data - && set_default_alg_kind(alg_kind::convolution_direct) - && expect_data_types(data_type::f32, data_type::f32, - data_type::undef, data_type::f32, data_type::f32) - && !has_zero_dim_memory() - && set_default_formats(); - if (!ok) return status::unimplemented; - - const convolution_desc_t *conv_d = desc(); - const memory_desc_t *diff_src_d = diff_src_md(); - rtus_prepare(this, conv_d, diff_src_d, diff_dst_md()); - - status_t status = jit_avx2_1x1_conv_kernel_f32::init_conf(jcp_, - *conv_d, *diff_src_d, *weights_md(), *diff_dst_md(), - *attr()); - if (status != status::success) return status; - - auto scratchpad = scratchpad_registry().registrar(); - jit_avx2_1x1_conv_kernel_f32::init_scratchpad(scratchpad, jcp_); - - rtus_prepare_space_info(this, scratchpad); - - return status::success; - } - - jit_1x1_conv_conf_t jcp_; - reduce_to_unit_stride_t rtus_; - - protected: - bool set_default_formats() { - using namespace format_tag; - - auto dat_tag = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); - auto wei_tag = with_groups() - ? 
utils::pick(ndims() - 3, gOIw8o8i, gOIhw8o8i) - : utils::pick(ndims() - 3, OIw8o8i, OIhw8o8i); - - return set_default_formats_common(dat_tag, wei_tag, dat_tag); - } - }; - - template - friend void init_rtus_driver(conv_t *self); - - jit_avx2_1x1_convolution_bwd_data_t(const pd_t *apd) - : cpu_primitive_t(apd) - , kernel_(nullptr) - , rtus_driver_(nullptr) - { - kernel_ = new jit_avx2_1x1_conv_kernel_f32(pd()->jcp_, *pd()->attr()); - init_rtus_driver(this); - } - - ~jit_avx2_1x1_convolution_bwd_data_t() { - delete kernel_; - delete rtus_driver_; - } - - typedef typename prec_traits::type data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_backward_data(ctx); - return status::success; - } - -private: - void execute_backward_data(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - - jit_avx2_1x1_conv_kernel_f32 *kernel_; - rtus_driver_t *rtus_driver_; -}; - -struct jit_avx2_1x1_convolution_bwd_weights_t: public cpu_primitive_t { - struct pd_t: public cpu_convolution_bwd_weights_pd_t { - pd_t(engine_t *engine, const convolution_desc_t *adesc, - const primitive_attr_t *attr, - const convolution_fwd_pd_t *hint_fwd_pd) - : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd) - , jcp_(), rtus_() {} - - DECLARE_COMMON_PD_T( - JIT_IMPL_NAME_HELPER("jit_1x1:", avx2, ""), - jit_avx2_1x1_convolution_bwd_weights_t); - - status_t init() { - bool ok = true - && desc()->prop_kind == prop_kind::backward_weights - && set_default_alg_kind(alg_kind::convolution_direct) - && expect_data_types(data_type::f32, data_type::f32, - data_type::f32, data_type::f32, data_type::f32) - && !has_zero_dim_memory() - && set_default_formats(); - if (!ok) return status::unimplemented; - - const convolution_desc_t *conv_d = desc(); - const memory_desc_t *src_d = src_md(); - rtus_prepare(this, conv_d, src_d, diff_dst_md()); - - status_t status = jit_avx2_1x1_conv_kernel_f32::init_conf(jcp_, - *conv_d, *src_d, *diff_weights_md(), *diff_dst_md(), - *attr()); - if (status != status::success) return status; - - init_balancers(); - - auto scratchpad = scratchpad_registry().registrar(); - jit_avx2_1x1_conv_kernel_f32::init_scratchpad(scratchpad, jcp_); - - rtus_prepare_space_info(this, scratchpad); - - auto reducer_bia_scratchpad = memory_tracking::registrar_t( - scratchpad, memory_tracking::names::prefix_reducer_bia); - reducer_bia_conf_.init_scratchpad(reducer_bia_scratchpad); - - auto reducer_wei_scratchpad = memory_tracking::registrar_t( - scratchpad, memory_tracking::names::prefix_reducer_wei); - reducer_wei_conf_.init_scratchpad(reducer_wei_scratchpad); - - return status::success; - } - - jit_1x1_conv_conf_t jcp_; - cpu_reducer_t::conf_t reducer_bia_conf_; - cpu_reducer_2d_t::conf_t reducer_wei_conf_; - reduce_to_unit_stride_t rtus_; - - protected: - bool set_default_formats() { - using namespace format_tag; - - auto dat_tag = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); - auto wei_tag = with_groups() - ? 
utils::pick(ndims() - 3, gOIw8i8o, gOIhw8i8o) - : utils::pick(ndims() - 3, OIw8i8o, OIhw8i8o); - - return set_default_formats_common(dat_tag, wei_tag, dat_tag); - } - - private: - void init_balancers() { - const int ic_block = jcp_.bcast_block; - const int nb_ic = jcp_.nb_bcast; - const int nb_ic_blocking = jcp_.nb_bcast_blocking; - const int bcast_work = utils::div_up(nb_ic, nb_ic_blocking); - - const int oc_block = jcp_.load_block; - const int nb_oc = jcp_.nb_load; - const int nb_oc_blocking = jcp_.nb_load_blocking; - const int load_work = utils::div_up(nb_oc, nb_oc_blocking); - - const int job_size - = nb_oc_blocking * nb_ic_blocking * ic_block * oc_block; - const int njobs_x = bcast_work; - const int njobs_y = jcp_.ngroups * load_work; - - const int max_threads = mkldnn_get_max_threads(); - const size_t max_buffer_size = max_threads * job_size * 8; - - if (with_bias()) { - reducer_bia_conf_.init(reduce_balancer_t(max_threads, - oc_block, jcp_.ngroups * jcp_.oc / oc_block, - jcp_.mb, max_buffer_size)); - } - - reducer_wei_conf_.init( - reduce_balancer_t(max_threads, job_size, njobs_y * njobs_x, - jcp_.mb * jcp_.nb_reduce, max_buffer_size), - job_size / nb_oc_blocking, nb_oc_blocking, ic_block, - nb_ic * ic_block * oc_block, nb_oc); - } - }; - - template - friend void init_rtus_driver(conv_t *self); - - jit_avx2_1x1_convolution_bwd_weights_t(const pd_t *apd); - - ~jit_avx2_1x1_convolution_bwd_weights_t() { - delete kernel_; - delete rtus_driver_; - delete reducer_weights_; - delete reducer_bias_; - } - - typedef typename prec_traits::type data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_backward_weights(ctx); - return status::success; - } - -private: - void execute_backward_weights(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - - jit_avx2_1x1_conv_kernel_f32 *kernel_; - cpu_reducer_2d_t *reducer_weights_; - cpu_reducer_t *reducer_bias_; - rtus_driver_t *rtus_driver_; -}; - -} -} -} - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.cpp deleted file mode 100644 index e24770a2d..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.cpp +++ /dev/null @@ -1,1501 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* Copyright 2018 YANDEX LLC -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
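The scratchpad sizing in init_balancers() above comes down to a few products. A short worked sketch: the job-count formulas and the 8x buffer factor are taken from the code above, while every concrete value is hypothetical:

    // Sketch only: the reducer sizing arithmetic from init_balancers(), with
    // hypothetical jcp_ values. Not a drop-in replacement for the library code.
    #include <cstddef>

    int main() {
        const int ic_block = 8, oc_block = 8;
        const int nb_ic = 8, nb_ic_blocking = 2;   // bcast side
        const int nb_oc = 8, nb_oc_blocking = 4;   // load side
        const int ngroups = 1, max_threads = 8;

        const int bcast_work = (nb_ic + nb_ic_blocking - 1) / nb_ic_blocking;        // 4
        const int load_work  = (nb_oc + nb_oc_blocking - 1) / nb_oc_blocking;        // 2

        const int job_size = nb_oc_blocking * nb_ic_blocking * ic_block * oc_block;  // 512
        const int njobs    = bcast_work * (ngroups * load_work);                     // 8
        const std::size_t max_buffer_size =
                std::size_t(max_threads) * job_size * 8;                             // 32768

        (void)njobs; (void)max_buffer_size;
        return 0;
    }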
-*******************************************************************************/ - -#include "c_types_map.hpp" -#include "nstl.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" -#include "cpu_memory.hpp" - -#include "jit_avx2_conv_kernel_f32.hpp" - -#define GET_OFF(field) offsetof(jit_conv_call_s, field) - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace mkldnn::impl::prop_kind; -using namespace mkldnn::impl::format_tag; -using namespace mkldnn::impl::memory_tracking::names; -using namespace mkldnn::impl::utils; - -using namespace Xbyak; - -void jit_avx2_conv_fwd_kernel_f32::oh_step_unroll_kw(int ur_w, - int pad_l, int pad_r, int oc_blocks) -{ - int iw = jcp.iw; - int ih = jcp.ih; - int id = jcp.id; - int kw = jcp.kw; - int kh = jcp.kh; - int kd = jcp.kd; - int nb_ic = jcp.nb_ic; - int stride_w = jcp.stride_w; - int dilate_w = jcp.dilate_w + 1; - int ic_blk = jcp.ic_block; - int oc_blk = jcp.oc_block; - - for (int ki = 0; ki < kw; ki++) { - int jj_start = nstl::max(0, div_up(pad_l - ki * dilate_w, stride_w)); - int jj_end = ur_w - - nstl::max(0, div_up(ki*dilate_w+pad_r-(kw-1)*dilate_w, stride_w)); - for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) { - for (int jj = jj_start; jj < jj_end; jj++) { - size_t inp_off; - if (one_of(jcp.src_tag, ncw, nchw, ncdhw)) - inp_off = sizeof(float)*((size_t)ifm2*id*ih*iw - + (ki*dilate_w + jj*stride_w - pad_l)); - else - inp_off = sizeof(float)*((ki*dilate_w + jj*stride_w - - pad_l)*ic_blk + ifm2); - vbroadcastss(Ymm(oc_blocks * ur_w + jj), - make_safe_addr(aux_reg_input, inp_off, reg_long_offt)); - } - - for (int ii = 0; ii < oc_blocks; ii++) { - int ker_off = ii * nb_ic * kd * kh * kw * ic_blk * oc_blk - + ki * ic_blk * oc_blk + ifm2 * oc_blk; - vmovups(ymm15, ptr[aux_reg_kernel + sizeof(float) * ker_off]); - for (int jj = jj_start; jj < jj_end; jj++) - if (mayiuse(avx2)) - vfmadd231ps(Ymm(ur_w * ii + jj), - Ymm(oc_blocks * ur_w + jj), ymm15); - else { // Intel(R) Advanced Vector Extensions (Intel(R) AVX) support - vmulps(ytmp, ymm15, Ymm(oc_blocks * ur_w + jj)); - vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), ytmp); - } - } - } - } -} - -void jit_avx2_conv_fwd_kernel_f32::oh_step_nopad(int ur_w, - int pad_l, int pad_r, char pad_tag, - int oc_blocks, char oc_blocks_tag) -{ - Label kw_loop; - - int iw = jcp.iw; - int ih = jcp.ih; - int id = jcp.id; - int kw = jcp.kw; - int kh = jcp.kh; - int kd = jcp.kd; - int nb_ic = jcp.nb_ic; - int stride_w = jcp.stride_w; - int dilate_w = jcp.dilate_w + 1; - int ic_blk = jcp.ic_block; - int oc_blk = jcp.oc_block; - - xor_(ki_iter, ki_iter); - L(kw_loop); - { - int jj_start = 0; - int jj_end = ur_w; - for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) { - for (int jj = jj_start; jj < jj_end; jj++) { - size_t inp_off; - if (one_of(jcp.src_tag, ncw, nchw, ncdhw)) - inp_off = sizeof(float)*((size_t)ifm2 * id * ih * iw - + (jj * stride_w - pad_l)); - else - inp_off = sizeof(float)*((jj * stride_w - pad_l) * ic_blk - + ifm2); - vbroadcastss(Ymm(oc_blocks * ur_w + jj), - make_safe_addr(aux_reg_input, inp_off, reg_long_offt)); - } - for (int ii = 0; ii < oc_blocks; ii++) { - int aux_kernel_offset = - ii * nb_ic * kd * kh * kw * ic_blk * oc_blk + ifm2 * oc_blk; - vmovups(ymm15, ptr[aux_reg_kernel - + sizeof(float) * aux_kernel_offset]); - for (int jj = jj_start; jj < jj_end; jj++) - if (mayiuse(avx2)) - vfmadd231ps(Ymm(ur_w * ii + jj), - Ymm(oc_blocks * ur_w + jj), ymm15); - else { // Intel AVX support - vmulps(ytmp, ymm15, Ymm(oc_blocks * ur_w + jj)); - vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), 
ytmp); - } - } - } - add(aux_reg_kernel, sizeof(float) * oc_blk * ic_blk); - add(aux_reg_input, sizeof(float) * (one_of(jcp.src_tag, ncw, nchw, ncdhw) - ? dilate_w : ic_blk * dilate_w)); - - inc(ki_iter); - cmp(ki_iter, kw); - jl(kw_loop, T_NEAR); - } -} - -void jit_avx2_conv_fwd_kernel_f32::width_blk_step(int ur_w, - int pad_l, int pad_r, char pad_tag, - int oc_blocks, char oc_blocks_tag) -{ - int iw = jcp.iw; - int kw = jcp.kw; - int ow = jcp.ow; - int oh = jcp.oh; - int od = jcp.od; - int dilate_h = jcp.dilate_h + 1; - int dilate_w = jcp.dilate_w + 1; - int ic_blk = jcp.ic_block; - int oc_blk = jcp.oc_block; - const int inp_mult = one_of(jcp.src_tag, ncw, nchw, ncdhw) - ? 1 : ic_blk; - const int inp_off = one_of(jcp.src_tag, ncw, nchw, ncdhw) - ? dilate_w : ic_blk * dilate_w; - - Label init_done, init_first; - - if (!jcp.with_sum) { - test(reg_ci_flag, FLAG_IC_FIRST); - jne(init_first, T_NEAR); - } - - for (int ii = 0; ii < oc_blocks; ii++) { - for (int jj = 0; jj < ur_w; jj++) { - size_t offt = - sizeof(float) * ((size_t)ii * od * oh * ow + jj) * oc_blk; - vmovups(Ymm(ur_w * ii + jj), - make_safe_addr(reg_output, offt, reg_long_offt)); - } - } - - if (jcp.with_sum && jcp.with_bias) { - test(reg_ci_flag, FLAG_IC_FIRST); - je(init_done, T_NEAR); - - for (int ii = 0; ii < oc_blocks; ii++) - for (int jj = 0; jj < ur_w; jj++) - vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), - yword[reg_bias + sizeof(float) * ii * oc_blk]); - } - - jmp(init_done); - - L(init_first); - if (this->jcp.with_bias) { - for (int ii = 0; ii < oc_blocks; ii++) - for (int jj = 0; jj < ur_w; jj++) - vmovups(Ymm(ur_w * ii + jj), - yword[reg_bias + sizeof(float) * ii * oc_blk]); - } else { - for (int ii = 0; ii < oc_blocks; ii++) - for (int jj = 0; jj < ur_w; jj++) - uni_vpxor(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj)); - } - - L(init_done); - - if (one_of(jcp.ndims, 3, 4)) { - mov(aux_reg_input, reg_input); - mov(aux_reg_kernel, reg_kernel); - } - - Label skip_kh_loop, skip_kd_loop, kd_loop; - if (jcp.ndims == 5) { - push(reg_output); - push(oi_iter); - - mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]); - mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]); - mov(aux_reg_inp_d, reg_input); - - if ((jcp.dilate_d >= jcp.id) - || (jcp.kd - 1) * (jcp.dilate_d + 1) < jcp.f_pad) { - cmp(reg_ki, 0); - je(skip_kd_loop, T_NEAR); - } - L(kd_loop); - mov(kj, ptr[param1 + GET_OFF(kh_padding)]); - } else { - mov(kj, reg_kh); - } - - if (jcp.ndims == 5) { - mov(aux_reg_input, aux_reg_inp_d); - mov(aux_reg_kernel, aux_reg_ker_d); - } - - if ((jcp.dilate_h >= jcp.ih) - || (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) { - cmp(kj, 0); - je(skip_kh_loop, T_NEAR); - } - Label kh_loop; - L(kh_loop); - { - if (jcp.kw >= 5 && pad_l == 0 && pad_r == 0) { - oh_step_nopad(ur_w, pad_l, pad_r, pad_tag, oc_blocks, - oc_blocks_tag); - sub(aux_reg_input, sizeof(float) * kw * inp_off); - add(aux_reg_input, sizeof(float) * iw * dilate_h * inp_mult); - } else { - oh_step_unroll_kw(ur_w, pad_l, pad_r, oc_blocks); - add(aux_reg_kernel, sizeof(float) * kw * oc_blk * ic_blk); - add(aux_reg_input, sizeof(float) * iw * dilate_h * inp_mult); - } - - dec(kj); - cmp(kj, 0); - jg(kh_loop, T_NEAR); - } - - L(skip_kh_loop); - - if (jcp.ndims == 5) { - add(aux_reg_inp_d, - sizeof(float) * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * inp_mult); - add(aux_reg_ker_d, sizeof(float) * jcp.kw * jcp.kh * jcp.oc_block - * jcp.ic_block); - - dec(reg_ki); - cmp(reg_ki, 0); - jg(kd_loop, T_NEAR); - L(skip_kd_loop); - - pop(oi_iter); - 
pop(reg_output); - } - - Label regular_store; - - if (jcp.with_eltwise) { - test(reg_ci_flag, FLAG_IC_LAST); - je(regular_store, T_NEAR); - - eltwise_injector_->compute_vector_range(0, oc_blocks * ur_w); - - L(regular_store); - } - - for (int ii = 0; ii < oc_blocks; ii++) { - for (int jj = 0; jj < ur_w; jj++) { - const size_t o_off - = sizeof(float) * ((size_t)ii * od * oh * ow + jj) * oc_blk; - Ymm reg_out = Ymm(ur_w * ii + jj); - vmovups(make_safe_addr(reg_output, o_off, reg_long_offt), reg_out); - } - } -} - -inline void jit_avx2_conv_fwd_kernel_f32::solve_common( - int oc_blocks, char oc_blocks_tag) -{ - int ur_w = jcp.ur_w; - int ur_w_tail = jcp.ur_w_tail; - int n_oi = jcp.ow / ur_w; - int iw = jcp.iw; - int kw = jcp.kw; - int ic_blk = jcp.ic_block; - int oc_blk = jcp.oc_block; - int dilate_w = jcp.dilate_w + 1; - int str_w = jcp.stride_w; - const int inp_mult = one_of(jcp.src_tag, ncw, nchw, ncdhw) ? 1 : ic_blk; - - int l_pad = jcp.l_pad; - int r_pad = nstl::max(0, (int(jcp.ow) - 1) * str_w + (kw - 1) * dilate_w - - (iw + l_pad - 1)); - int r_pad1 = (ur_w * n_oi - 1) * str_w + (kw - 1) * dilate_w - - (iw + l_pad - 1); - if (r_pad1 > 0) n_oi--; - - if (l_pad > 0) { - n_oi--; - if (n_oi < 0 && r_pad1 > 0) - width_blk_step(ur_w, l_pad, r_pad1, - 'l', oc_blocks, oc_blocks_tag); // "lrpad" - else - width_blk_step(ur_w, l_pad, 0, - 'l', oc_blocks, oc_blocks_tag); // "lpad" - add(reg_input, sizeof(float) * (ur_w * str_w - l_pad) * inp_mult); - add(reg_output, sizeof(float) * ur_w * oc_blk); - } - - Label ow_loop; - xor_(oi_iter, oi_iter); - - if (n_oi > 0) { - L(ow_loop); - - width_blk_step(ur_w, 0, 0, - 'm', oc_blocks, oc_blocks_tag); // "middle" - add(reg_input, sizeof(float) * ur_w * str_w * inp_mult); - add(reg_output, sizeof(float) * ur_w * oc_blk); - - inc(oi_iter); - cmp(oi_iter, n_oi); - jl(ow_loop, T_NEAR); - } - - if (r_pad1 > 0 && n_oi >=0) { - width_blk_step(ur_w, 0, r_pad1, - 'r', oc_blocks, oc_blocks_tag); // "rpad" - add(reg_input, sizeof(float) * ur_w * str_w * inp_mult); - add(reg_output, sizeof(float) * ur_w * oc_blk); - } - - if (ur_w_tail != 0) - width_blk_step(ur_w_tail, 0, r_pad, - 't', oc_blocks, oc_blocks_tag); // "tail" -} - -void jit_avx2_conv_fwd_kernel_f32::generate() -{ - this->preamble(); - - mov(reg_input, ptr[this->param1 + GET_OFF(src)]); - mov(reg_output, ptr[this->param1 + GET_OFF(dst)]); - mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]); - if (jcp.with_bias) - mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]); - mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]); - mov(reg_ci_flag, ptr[this->param1 + GET_OFF(flags)]); - mov(reg_oc_blocks, ptr[this->param1 + GET_OFF(oc_blocks)]); - - int nb_oc_tail = jcp.nb_oc % jcp.nb_oc_blocking; - Label tail, exit; - - if (jcp.nb_oc > jcp.nb_oc_blocking) { - cmp(reg_oc_blocks, jcp.nb_oc_blocking); - jne(nb_oc_tail ? 
tail : exit, T_NEAR); - - solve_common(jcp.nb_oc_blocking, '0' + jcp.nb_oc_blocking); - jmp(exit, T_NEAR); - - if (nb_oc_tail) { - L(tail); - cmp(reg_oc_blocks, nb_oc_tail); - jne(exit, T_NEAR); - solve_common(nb_oc_tail, '0' + nb_oc_tail); - } - - L(exit); - } else if (jcp.nb_oc == jcp.nb_oc_blocking) { - solve_common(jcp.nb_oc_blocking, '0' + jcp.nb_oc_blocking); - } else { - solve_common(nb_oc_tail, '0' + nb_oc_tail); - } - - this->postamble(); - - if (jcp.with_eltwise) - eltwise_injector_->prepare_table(); -} - -bool jit_avx2_conv_fwd_kernel_f32::post_ops_ok( - jit_conv_conf_t &jcp, const primitive_attr_t &attr) { - const auto &p = attr.post_ops_; - - auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; - auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; - - switch (p.len_) { - case 0: return true; // no post_ops - case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise - case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise - default: return false; - } - - return false; -} - -status_t jit_avx2_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp, - const convolution_desc_t &cd, const memory_desc_wrapper &src_d, - const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d, - const primitive_attr_t &attr) -{ - if (!mayiuse(avx)) return status::unimplemented; - - jcp.prop_kind = cd.prop_kind; - - const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; - int ndims = src_d.ndims(); - jcp.ndims = ndims; - - jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; - jcp.mb = src_d.dims()[0]; - - jcp.oc = dst_d.dims()[1] / jcp.ngroups; - jcp.oc_without_padding = jcp.oc; - jcp.ic = src_d.dims()[1] / jcp.ngroups; - - jcp.id = (ndims == 5) ? src_d.dims()[2] : 1; - jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims-2]; - jcp.iw = src_d.dims()[ndims-1]; - jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1; - jcp.oh = (ndims == 3) ? 1 :dst_d.dims()[ndims-2]; - jcp.ow = dst_d.dims()[ndims-1]; - jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1; - jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims-2]; - jcp.kw = weights_d.dims()[with_groups + ndims-1]; - - jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0; - jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4]; - jcp.l_pad = cd.padding[0][ndims-3]; - jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1; - jcp.stride_h = (ndims == 3) ? 1 :cd.strides[ndims-4]; - jcp.stride_w = cd.strides[ndims-3]; - - jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0; - jcp.dilate_h = (ndims == 3) ? 
0 : cd.dilates[ndims-4]; - jcp.dilate_w = cd.dilates[ndims-3]; - - jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1) - - (jcp.ih + jcp.t_pad - 1); - - if (ndims == 3) { - jcp.src_tag = src_d.matches_one_of_tag(ncw, nwc, nCw8c); - jcp.wei_tag = weights_d.matches_one_of_tag( - Owi8o, gOwi8o, OIw8i8o, gOIw8i8o); - jcp.dst_tag = dst_d.matches_one_of_tag(nCw8c); - } else if (ndims == 4) { - jcp.src_tag = src_d.matches_one_of_tag(nchw, nhwc, nChw8c); - jcp.wei_tag = weights_d.matches_one_of_tag( - Ohwi8o, gOhwi8o, OIhw8i8o, gOIhw8i8o); - jcp.dst_tag = dst_d.matches_one_of_tag(nChw8c); - } else if (ndims == 5) { - jcp.src_tag = src_d.matches_one_of_tag(ncdhw, ndhwc, nCdhw8c); - jcp.wei_tag = weights_d.matches_one_of_tag( - Odhwi8o, gOdhwi8o, OIdhw8i8o, gOIdhw8i8o); - jcp.dst_tag = dst_d.matches_one_of_tag(nCdhw8c); - } - jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; - - if (!post_ops_ok(jcp, attr)) - return status::unimplemented; - - const auto &p = attr.post_ops_; - jcp.with_sum = p.find(primitive_kind::sum) != -1; - const int eltwise_ind = p.find(primitive_kind::eltwise); - jcp.with_eltwise = eltwise_ind != -1; - if (jcp.with_eltwise) { - jcp.eltwise = p.entry_[eltwise_ind].eltwise; - if (!mayiuse(avx2) && jcp.eltwise.alg != alg_kind::eltwise_relu) - return status::unimplemented; - } - - const int simd_w = 8; - const bool flat = jcp.ic < simd_w; - const bool mimo = !flat; - - - /* Grouped channel offset to support 'non-blocked data' format for - * convolution sizes with '(input_channel / ngroups) < simd' */ - jcp.nonblk_group_off = - one_of(jcp.src_tag, ncw, nchw, ncdhw) && jcp.ngroups > 1 ? jcp.ic : 1; - - bool ok_to_pad_channels = true - && jcp.ngroups == 1; - - if (ok_to_pad_channels) { - jcp.oc = rnd_up(jcp.oc, simd_w); - if (mimo) - jcp.ic = rnd_up(jcp.ic, simd_w); - } - - bool args_ok = true - && IMPLICATION(flat, true - && one_of(jcp.src_tag, ncw, nwc, nchw, nhwc, ncdhw, ndhwc) - && one_of(jcp.wei_tag, Owi8o, gOwi8o, Ohwi8o, gOhwi8o, Odhwi8o, - gOdhwi8o)) - && IMPLICATION(mimo, true - && one_of(jcp.src_tag, nCw8c, nChw8c, nCdhw8c) - && one_of(jcp.wei_tag, OIw8i8o, gOIw8i8o, OIhw8i8o, gOIhw8i8o, - OIdhw8i8o, gOIdhw8i8o)) - && one_of(jcp.dst_tag, nCw8c, nChw8c, nCdhw8c); - if (!args_ok) return status::unimplemented; - - jcp.ur_h = 1; /* no code-unrolling by h so far */ - jcp.ur_w = 3; - - jcp.oc_block = simd_w; - jcp.nb_oc = jcp.oc / jcp.oc_block; - - jcp.nb_oc_blocking = 4; /* the optimal value for the kernel */ - - // Intel AVX and Intel AVX2 kernels need 2 and 1 temporary YMMs, respectively - // Thus, we can only assign 14 or 15 YMMs for data storage - const int num_avail_regs = mayiuse(avx2) ? 
15 : 14; - if (!mayiuse(avx2)) { - if ((jcp.nb_oc_blocking + 1) * jcp.ur_w > num_avail_regs) { - // current register assignment requires more YMMs than available - // adjust one of nb_oc_block, ur_w preserving to ur_w >= l_pad - if (jcp.ur_w > jcp.l_pad && jcp.ur_w > 1) - jcp.ur_w -= 1; - else - for (int b = 3; b > 1; b--) - if (jcp.nb_oc % b == 0) { - jcp.nb_oc_blocking = b; - break; - } - } - } - - if (jcp.ow < jcp.ur_w) jcp.ur_w = jcp.ow; - jcp.ur_w_tail = jcp.ow % jcp.ur_w; - - args_ok = true - && jcp.oc % simd_w == 0 - && jcp.l_pad <= jcp.ur_w - && IMPLICATION(jcp.kw > 7, (jcp.t_pad == 0 && jcp.l_pad == 0) - || (jcp.stride_w == 1 && jcp.stride_h == 1)) - && IMPLICATION(mimo, jcp.ic % simd_w == 0); - if (!args_ok) return status::unimplemented; - - int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w - + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1)); - - if (r_pad_no_tail > jcp.ur_w * jcp.stride_w && jcp.ow / jcp.ur_w > 1) { - /* recalculate ur_w, nb_oc_blocking and ur_w_tail */ - jcp.ur_w = nstl::min(r_pad_no_tail / jcp.stride_w + jcp.ur_w_tail, - nstl::min(jcp.ow, num_avail_regs / 2)); - jcp.nb_oc_blocking = (num_avail_regs - jcp.ur_w) / jcp.ur_w; - jcp.ur_w_tail = jcp.ow % jcp.ur_w; - /* check again ... */ - r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w - + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1)); - if (jcp.ur_w < nstl::max(jcp.l_pad, r_pad_no_tail)) - return status::unimplemented; - } - assert(jcp.nb_oc_blocking > 0); - assert(jcp.ur_w * (jcp.nb_oc_blocking + 1) <= num_avail_regs); - - jcp.ic_block = (jcp.ic % simd_w != 0) ? jcp.ic : simd_w; - jcp.nb_ic = jcp.ic / jcp.ic_block; - - if (one_of(jcp.prop_kind, forward_training, forward_inference)) { - jcp.nb_ic_blocking = 12; - jcp.nb_ic_blocking_max = 16; - } else { - jcp.nb_ic_blocking = 1; - jcp.nb_ic_blocking_max = jcp.nb_ic_blocking; - } - - return status::success; -} - -void jit_avx2_conv_fwd_kernel_f32::init_scratchpad( - memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { - if (jcp.with_bias && jcp.oc != jcp.oc_without_padding) - scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc); -} - -void jit_avx2_conv_bwd_data_kernel_f32::compute_loop(int ur_w, int l_overflow, - int r_overflow) -{ - int kw = jcp.kw; - int kh = jcp.kh; - int kd = jcp.kd; - int iw = jcp.iw; - int ih = jcp.ih; - int id = jcp.id; - int ow = jcp.ow; - - int ic_block = jcp.ic_block; - int oc_block = jcp.oc_block; - int nb_ic_block = jcp.nb_ic_blocking; - int stride_w = jcp.stride_w; - int stride_h = jcp.stride_h; - - Label kd_loop, skip_kd_loop; - Label oc_loop, skip_oc_loop; - - for (int ii = 0; ii < nb_ic_block; ii++) - for (int jj = 0; jj < ur_w; jj++) { - uni_vpxor(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), - Ymm(ur_w * ii + jj)); - } - - if (one_of(jcp.ndims, 3, 4)) { - cmp(reg_channel_work, 0); - jle(skip_oc_loop, T_NEAR); - xor_(reg_channel, reg_channel); - - mov(aux_reg_ddst_oc_loop, reg_ddst); - mov(aux_reg_kernel_oc_loop, reg_kernel); - - L(oc_loop); - mov(aux_reg_ddst, aux_reg_ddst_oc_loop); - mov(aux_reg_kernel, aux_reg_kernel_oc_loop); - } - - if (jcp.ndims == 5) { - assert(jcp.nb_oc_blocking == 1); - push(oi_iter); - - mov(reg_ki, ptr[this->param1 + GET_OFF(kd_padding)]); - mov(aux_reg_dst_d, reg_ddst); - mov(aux_reg_ker_d, ptr[this->param1 + GET_OFF(filt)]); - - L(kd_loop); - mov(kj, ptr[this->param1 + GET_OFF(kh_padding)]); - } else { - mov(kj, reg_kh); - } - - if (jcp.ndims == 5) { - mov(aux_reg_ddst, aux_reg_dst_d); - 
mov(aux_reg_kernel, aux_reg_ker_d); - } - - Label kh_loop, skip_kh_loop; - cmp(kj, 0); - jle(skip_kh_loop, T_NEAR); - L(kh_loop); { - for (int ki = 0; ki < kw; ki++) { - int jj_start = get_iw_start(ki, l_overflow); // 0; - int jj_end = get_iw_end(ur_w, ki, r_overflow); // ur_w; - for (int ofm2 = 0; ofm2 < jcp.oc_block; ofm2++) { - - for (int jj = jj_start ; jj < jj_end; jj += stride_w) { - int aux_output_offset - = (jj + jcp.l_pad - ki) / stride_w * jcp.oc_block + ofm2; - vbroadcastss(Ymm(nb_ic_block * ur_w + jj / stride_w), - ptr[aux_reg_ddst - + sizeof(float) * aux_output_offset]); - } - - for (int ii = 0; ii < nb_ic_block; ii++) { - int aux_kernel_offset - = ii * kd * kh * kw * jcp.ic_block * jcp.oc_block - + ki * jcp.ic_block * jcp.oc_block - + ofm2 * jcp.ic_block; - vmovups(ymm15, - ptr[aux_reg_kernel - + sizeof(float) * aux_kernel_offset]); - for (int jj = jj_start; jj < jj_end; jj += stride_w) - vfmadd231ps(Ymm(ur_w * ii + jj), - Ymm(nb_ic_block * ur_w + jj / stride_w), ymm15); - } - } - } - add(aux_reg_kernel, sizeof(float) * stride_h * kw * oc_block - * ic_block); - sub(aux_reg_ddst, sizeof(float) * ow * oc_block); - - dec(kj); - cmp(kj, 0); - jg(kh_loop, T_NEAR); - } - L(skip_kh_loop); - - if (jcp.ndims == 5) { - sub(aux_reg_dst_d, - sizeof(float) * (jcp.dilate_d + 1) * jcp.oh * ow * ic_block); - add(aux_reg_ker_d, - sizeof(float) * jcp.kw * jcp.kh * oc_block * ic_block); - - dec(reg_ki); - cmp(reg_ki, 0); - jg(kd_loop, T_NEAR); - L(skip_kd_loop); - - pop(oi_iter); - } - - if (one_of(jcp.ndims, 3, 4)) { - int ddst_oc_shift = sizeof(float) * jcp.od * jcp.oh * jcp.ow - * jcp.oc_block; - int kernel_oc_shift = sizeof(float) * jcp.kd * jcp.kh * jcp.kw - * jcp.ic * jcp.oc_block; - - add(aux_reg_ddst_oc_loop, ddst_oc_shift); - add(aux_reg_kernel_oc_loop, kernel_oc_shift); - - inc(reg_channel); - cmp(reg_channel, reg_channel_work); - jl(oc_loop, T_NEAR); - - L(skip_oc_loop); - mov(reg_channel, ptr[param1 + GET_OFF(channel)]); - } - - Label no_update_label; - cmp(reg_channel, 0); - je(no_update_label, T_NEAR); - for (int ii = 0; ii < nb_ic_block; ii++) { - for (int jj = 0; jj < ur_w; jj++) { - size_t offt = - sizeof(float) * ((size_t)ii * id * ih * iw + jj) * ic_block; - vmovups(Ymm(15), - make_safe_addr(reg_dsrc, offt, reg_long_offt)); - vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), - Ymm(15)); - - } - } - L(no_update_label); - - for (int ii = 0; ii < nb_ic_block; ii++) - for (int jj = 0; jj < ur_w; jj++) { - size_t offt = - sizeof(float) * ((size_t)ii * id * ih * iw + jj) * ic_block; - vmovups(make_safe_addr(reg_dsrc, offt, reg_long_offt), - Ymm(ur_w * ii + jj)); - } -} - -void jit_avx2_conv_bwd_data_kernel_f32::generate() { - preamble(); - - mov(reg_dsrc, ptr[this->param1 + GET_OFF(src)]); - mov(reg_ddst, ptr[this->param1 + GET_OFF(dst)]); - mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]); - mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]); - mov(reg_channel, ptr[param1 + GET_OFF(channel)]); - mov(reg_channel_work, ptr[param1 + GET_OFF(ch_blocks)]); - - int ddst_shift = sizeof(float) * (jcp.ur_w / jcp.stride_w) * jcp.ic_block; - int dsrc_shift = sizeof(float) * jcp.ur_w * jcp.oc_block; - - int l_overflow = nstl::max(0, (jcp.kw - 1 - jcp.l_pad) / jcp.stride_w); - int r_overflow = nstl::max(0, (jcp.kw - 1 - - nstl::max(0, jcp.r_pad)) / jcp.stride_w); - int r_overflow1 = nstl::max(0, (jcp.kw - 1 - - nstl::max(0, jcp.r_pad) - jcp.ur_w_tail) / jcp.stride_w); - - int n_oi = jcp.iw / jcp.ur_w; - if (r_overflow1 > 0) - n_oi--; - - if (jcp.ur_w == jcp.iw) { - compute_loop(jcp.ur_w, 
l_overflow, r_overflow); - } else if (n_oi == 0) { - compute_loop(jcp.ur_w, l_overflow, r_overflow1); - add(reg_dsrc, dsrc_shift); - add(reg_ddst, ddst_shift); - if (jcp.ur_w_tail != 0) - compute_loop(jcp.ur_w_tail, 0, r_overflow); - } else { - xor_(oi_iter, oi_iter); - if (l_overflow > 0) { - compute_loop(jcp.ur_w, l_overflow, 0); - add(reg_dsrc, dsrc_shift); - add(reg_ddst, ddst_shift); - inc(oi_iter); - } - - if ((l_overflow <= 0 && n_oi > 0) || (l_overflow > 0 && n_oi > 1)) { - Label ow_loop; - L(ow_loop); { - compute_loop(jcp.ur_w, 0, 0); - add(reg_dsrc, dsrc_shift); - add(reg_ddst, ddst_shift); - inc(oi_iter); - cmp(oi_iter, n_oi); jl(ow_loop, T_NEAR); - } - } - - if (r_overflow1 > 0 ) { - compute_loop(jcp.ur_w, 0, r_overflow1); - add(reg_dsrc, dsrc_shift); - add(reg_ddst, ddst_shift); - } - - if (jcp.ur_w_tail != 0) - compute_loop(jcp.ur_w_tail, 0, r_overflow); - } - - this->postamble(); -} - -status_t jit_avx2_conv_bwd_data_kernel_f32::init_conf(jit_conv_conf_t &jcp, - const convolution_desc_t &cd, const memory_desc_wrapper &diff_src_d, - const memory_desc_wrapper &weights_d, - const memory_desc_wrapper &diff_dst_d) -{ - if (!mayiuse(avx2)) return status::unimplemented; - - const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1; - - int ndims = diff_src_d.ndims(); - jcp.ndims = ndims; - - jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; - jcp.mb = diff_src_d.dims()[0]; - - jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups; - jcp.oc_without_padding = jcp.oc; - jcp.ic = diff_src_d.dims()[1] / jcp.ngroups; - - jcp.id = (ndims == 5) ? diff_src_d.dims()[2] : 1; - jcp.ih = (ndims == 3) ? 1 : diff_src_d.dims()[ndims-2]; - jcp.iw = diff_src_d.dims()[ndims-1]; - jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1; - jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims-2]; - jcp.ow = diff_dst_d.dims()[ndims-1]; - - jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1; - jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2]; - jcp.kw = weights_d.dims()[with_groups + ndims - 1]; - - jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0; - jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4]; - jcp.l_pad = cd.padding[0][ndims-3]; - - jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1; - jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4]; - jcp.stride_w = cd.strides[ndims-3]; - - jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0; - jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4]; - jcp.dilate_w = cd.dilates[ndims-3]; - - const int simd_w = 8; - - /* derivatives */ - jcp.idp = jcp.id + 2 * jcp.f_pad; - jcp.ihp = jcp.ih + 2 * jcp.t_pad; - jcp.iwp = jcp.iw + 2 * jcp.l_pad; - jcp.ohp = jcp.oh; /* do we really need */ - jcp.owp = jcp.ow; /* padded output ??? */ - - bool ok_to_pad_channels = true - && jcp.ngroups == 1; - - /* gemm-based convolution performs better in these cases */ - if (jcp.ic < simd_w && jcp.kw > 3 && jcp.stride_w > 1) - return status::unimplemented; - - if (ok_to_pad_channels) { - jcp.oc = rnd_up(jcp.oc, simd_w); - jcp.ic = rnd_up(jcp.ic, simd_w); - } - - jcp.ic_block = (jcp.ic % simd_w) ? 1 : simd_w; - jcp.nb_ic = jcp.ic / jcp.ic_block; - - jcp.oc_block = simd_w; - if (jcp.oc % jcp.oc_block) return status::unimplemented; - jcp.nb_oc = jcp.oc / jcp.oc_block; - - jcp.ur_h = 1; /* no code-unrolling by h so far */ - jcp.nb_ic_blocking = 1; - jcp.nb_oc_blocking = 1; - jcp.ur_w = 1; - - if(one_of(ndims, 3, 4) && jcp.ow < 40) - jcp.nb_oc_blocking = jcp.ow < 15 ? 
4 : 2; - - if (ndims == 3) { - jcp.src_tag = diff_src_d.matches_one_of_tag(nCw8c); - jcp.wei_tag = weights_d.matches_one_of_tag(OIw8i8o, gOIw8o8i); - jcp.dst_tag = diff_dst_d.matches_one_of_tag(nCw8c); - } else if (ndims == 4) { - jcp.src_tag = diff_src_d.matches_one_of_tag(nChw8c); - jcp.wei_tag = weights_d.matches_one_of_tag(OIhw8o8i, gOIhw8o8i); - jcp.dst_tag = diff_dst_d.matches_one_of_tag(nChw8c); - } else if (ndims == 5) { - jcp.src_tag = diff_src_d.matches_one_of_tag(nCdhw8c); - jcp.wei_tag = weights_d.matches_one_of_tag(OIdhw8o8i, gOIdhw8o8i); - jcp.dst_tag = diff_dst_d.matches_one_of_tag(nCdhw8c); - } - - bool args_ok = true - && one_of(jcp.src_tag, nCw8c, nChw8c, nCdhw8c) - && one_of(jcp.wei_tag, gOIw8o8i, OIw8i8o, gOIhw8o8i, OIhw8o8i, - gOIdhw8o8i, OIdhw8o8i) - && one_of(jcp.dst_tag, nCw8c, nChw8c, nCdhw8c) - && jcp.stride_w == jcp.stride_h - && jcp.stride_d == 1 - && jcp.dilate_d == 0 - && jcp.dilate_h == 0 - && jcp.dilate_w == 0 - && jcp.ic % simd_w == 0 - && jcp.oc % simd_w == 0 - && jcp.od == (jcp.idp - jcp.kd) / jcp.stride_d + 1 - && jcp.oh == (jcp.ihp - jcp.kh) / jcp.stride_h + 1 - && jcp.ow == (jcp.iwp - jcp.kw) / jcp.stride_w + 1; - if (!args_ok) return status::unimplemented; - jcp.r_pad = (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad; - jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad; - int l_overflow = nstl::max(0, (jcp.kw - 1 - jcp.l_pad) / jcp.stride_w); - - const int max_regs = 15; /* Maximun number of registers available for - result accumulation and delta dst data. - One additional register is reserved for weights - data. */ - - /* Find the best blocking with maximum number of fma instructions - per ur_w * nb_ic_blocking compute loops. Number of required registers - is num_regs = ur_w * nb_ic_blocking + ur_w / stride_w <= max_regs. 
- ur_w must be divisible by stride_w */ - if (jcp.stride_w + 1 > max_regs) /* Minimal possible registers - distribution exceeds max_regs */ - return status::unimplemented; - - int best_nfmas = 0; - for (int b = 1; b <= 4; b++) - { - if (jcp.nb_ic % b != 0) - continue; - - for (int u = jcp.stride_w; - u * b + u / jcp.stride_w <= max_regs && u < jcp.iw + jcp.stride_w; - u += jcp.stride_w) - { - int ur_w = nstl::min(u, jcp.iw); - /* maximum 1 step with l_overflow so far */ - if (l_overflow * jcp.stride_w > ur_w && ur_w != jcp.iw) - continue; - int nfmas = utils::div_up(ur_w, jcp.stride_w) * b; - if (nfmas > best_nfmas - || (nfmas == best_nfmas && jcp.ur_w < ur_w)) { - jcp.ur_w = ur_w; - jcp.nb_ic_blocking = b; - best_nfmas = nfmas; - } - } - } - if (best_nfmas == 0) /* can't find appropriate blocking */ - return status::unimplemented; - - jcp.ur_w_tail = jcp.iw % jcp.ur_w; - - int r_overflow_no_tail = nstl::max(0, (jcp.kw - 1 - jcp.ur_w_tail - - nstl::max(0, jcp.r_pad) - jcp.ur_w_tail) / jcp.stride_w); - /* maximum 1 ur_w block with r_overflow so far */ - if (r_overflow_no_tail * jcp.stride_w > jcp.ur_w) - return status::unimplemented; - - if ((jcp.iw > jcp.ur_w) && (jcp.ur_w % jcp.stride_w != 0)) - return status::unimplemented; - - return status::success; -} - -void jit_avx2_conv_bwd_data_kernel_f32::init_scratchpad( - memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { - UNUSED(scratchpad); - UNUSED(jcp); -} - -void jit_avx2_conv_bwd_weights_kernel_f32::generate() { - this->preamble(); - - mov(reg_input, ptr[this->param1 + GET_OFF(src)]); - mov(reg_output, ptr[this->param1 + GET_OFF(dst)]); - mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]); - compute_oh_loop_common(); - this->postamble(); -} - -status_t jit_avx2_conv_bwd_weights_kernel_f32::init_conf(jit_conv_conf_t &jcp, - const convolution_desc_t &cd, const memory_desc_wrapper &src_d, - const memory_desc_wrapper &diff_weights_d, - const memory_desc_wrapper &diff_dst_d) { - if (!mayiuse(avx2)) return status::unimplemented; - - const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1; - int ndims = src_d.ndims(); - jcp.ndims = ndims; - - jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1; - jcp.mb = src_d.dims()[0]; - - jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups; - jcp.oc_without_padding = jcp.oc; - jcp.ic = src_d.dims()[1] / jcp.ngroups; - - jcp.id = (ndims == 5) ? src_d.dims()[2] : 1; - jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims-2]; - jcp.iw = src_d.dims()[ndims-1]; - jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1; - jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims-2]; - jcp.ow = diff_dst_d.dims()[ndims-1]; - - jcp.kd = (ndims == 5) ? diff_weights_d.dims()[with_groups + 2] : 1; - jcp.kh = (ndims == 3) ? 1 : diff_weights_d.dims()[with_groups + ndims-2]; - jcp.kw = diff_weights_d.dims()[with_groups + ndims-1]; - - jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0; - jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4]; - jcp.l_pad = cd.padding[0][ndims-3]; - - jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1; - jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4]; - jcp.stride_w = cd.strides[ndims-3]; - - jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0; - jcp.dilate_h = (ndims == 3) ? 
0 : cd.dilates[ndims-4]; - jcp.dilate_w = cd.dilates[ndims-3]; - - if (ndims == 3) { - jcp.src_tag = src_d.matches_one_of_tag(ncw, nwc, nCw8c); - jcp.wei_tag = diff_weights_d.matches_one_of_tag( - Owi8o, gOwi8o, OIw8i8o, gOIw8i8o); - jcp.dst_tag = diff_dst_d.matches_one_of_tag(nCw8c); - } else if (ndims == 4) { - jcp.src_tag = src_d.matches_one_of_tag(nchw, nhwc, nChw8c); - jcp.wei_tag = diff_weights_d.matches_one_of_tag( - Ohwi8o, gOhwi8o, OIhw8i8o, gOIhw8i8o); - jcp.dst_tag = diff_dst_d.matches_one_of_tag(nChw8c); - } else if (ndims == 5) { - jcp.src_tag = src_d.matches_one_of_tag(ncdhw, ndhwc, nCdhw8c); - jcp.wei_tag = diff_weights_d.matches_one_of_tag( - Odhwi8o, gOdhwi8o, OIdhw8i8o, gOIdhw8i8o); - jcp.dst_tag = diff_dst_d.matches_one_of_tag(nCdhw8c); - } - jcp.with_bias = cd.diff_bias_desc.format_kind != format_kind::undef; - - const bool flat = jcp.ic == 3; - const bool mimo = !flat; - - const int simd_w = 8; - - jcp.b_pad = nstl::max( - 0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad); - jcp.r_pad = nstl::max( - 0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad); - - int back_pad = nstl::max(0, (jcp.od - 1) * jcp.stride_d + jcp.kd - jcp.id - - jcp.f_pad); - if (ndims == 5) - if (jcp.f_pad != 0 || back_pad != 0) - return status::unimplemented; - - const int max_h_pad = ((jcp.kh - 1) * (jcp.dilate_h + 1) + 1); - const int max_w_pad = ((jcp.kw - 1) * (jcp.dilate_w + 1) + 1); - const bool boundaries_ok = true - && jcp.t_pad < max_h_pad && jcp.b_pad < max_h_pad - && jcp.l_pad < max_w_pad && jcp.r_pad < max_w_pad; - if (!boundaries_ok) - return status::unimplemented; - - bool ok_to_pad_channels = true - && jcp.ngroups == 1; - - if (ok_to_pad_channels) { - jcp.oc = rnd_up(jcp.oc, simd_w); - if (mimo) - jcp.ic = rnd_up(jcp.ic, simd_w); - } - - bool args_ok = true - && IMPLICATION(flat, true - && one_of(jcp.src_tag, ncw, nwc, nchw, nhwc, ncdhw, ndhwc) - && one_of(jcp.wei_tag, Owi8o, gOwi8o, Ohwi8o, gOhwi8o, Odhwi8o, - gOdhwi8o)) - && IMPLICATION(mimo, true - && one_of(jcp.src_tag, nCw8c, nChw8c, nCdhw8c) - && one_of(jcp.wei_tag, OIw8i8o, gOIw8i8o, OIhw8i8o, gOIhw8i8o, - OIdhw8i8o, gOIdhw8i8o)) - && one_of(jcp.dst_tag, nCw8c, nChw8c, nCdhw8c) - && IMPLICATION(mimo, jcp.ic % simd_w == 0) - && jcp.oc % simd_w == 0 - && jcp.kw < 14 - && jcp.kh <= jcp.t_pad + jcp.ih /* [bwd_w:r1] */ - && jcp.kh <= jcp.ih /* [bwd_w:r2] */ - && jcp.kd <= jcp.f_pad + jcp.id - && jcp.kd <= jcp.id - && jcp.t_pad < jcp.kh /* XXX: must fix the kernel! */ - && jcp.dilate_d == 0 - && jcp.dilate_h == 0 - && jcp.dilate_w == 0; - if (!args_ok) return status::unimplemented; - - jcp.ic_block = (jcp.ic % simd_w != 0) ? jcp.ic : simd_w; - jcp.nb_ic = jcp.ic / jcp.ic_block; - - jcp.oc_block = simd_w; - jcp.nb_oc = jcp.oc / jcp.oc_block; - jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1; - - return status::success; -} - -void jit_avx2_conv_bwd_weights_kernel_f32::init_scratchpad( - memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { - if (jcp.with_bias && jcp.oc != jcp.oc_without_padding) - scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc); -} - -inline void jit_avx2_conv_bwd_weights_kernel_f32::od_step_comeback_pointers() -{ - Label kd_comeback_loop; - mov(kj, jcp.kd); //FIXME (Anton): this works only if f_pad = back_pad = 0 - L(kd_comeback_loop); { - const int inp_mult = one_of(jcp.src_tag, ncw, nchw, ncdhw) - ? 
1 : jcp.ic_block; - sub(aux_reg_input, sizeof(float) * jcp.iw * jcp.ih * inp_mult); - sub(aux_reg_kernel, sizeof(float) * jcp.kw * jcp.kh * jcp.ic_block - * jcp.oc_block); - dec(kj); - cmp(kj, 0); - jg(kd_comeback_loop, T_NEAR); - } -} - -inline void jit_avx2_conv_bwd_weights_kernel_f32::oh_step_comeback_pointers() -{ - mov(kj, reg_kh); - Label kh_comeback_loop; - L(kh_comeback_loop); { - const int inp_mult = one_of(jcp.src_tag, ncw, nchw, ncdhw) - ? 1 : jcp.ic_block; - sub(reg_input, sizeof(float) * jcp.iw * inp_mult); - sub(reg_kernel, sizeof(float) * jcp.kw * jcp.ic_block * jcp.oc_block); - dec(kj); - cmp(kj, 0); - jg(kh_comeback_loop, T_NEAR); - } -} - -inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_ic_block_step( - int ur_w, int pad_l, int pad_r, int ic_block_step, int input_offset, - int kernel_offset, int output_offset) -{ - const int kw = jcp.kw; - const int ic_block = jcp.ic_block; - const int oc_block = jcp.oc_block; - for (int i_kw = 0; i_kw < kw; i_kw++) - for (int i_ic = 0; i_ic < ic_block_step; i_ic++) { - size_t off - = sizeof(float) * (i_kw * ic_block + i_ic) * jcp.oc_block - + kernel_offset; - vmovups(Ymm(i_kw * ic_block_step + i_ic), yword[reg_kernel + off]); - } - - for (int i_ur = 0; i_ur < ur_w; i_ur++) { - vmovups(Ymm(kw * ic_block_step + 0), - yword[reg_output - + sizeof(float) * i_ur * oc_block + output_offset]); - - for (int i_kw = 0; i_kw < kw; i_kw++) { - int i_iw = i_ur * jcp.stride_w + i_kw; - if (i_iw - pad_l < 0 - || i_iw > (ur_w - 1) * jcp.stride_w + kw - 1 - pad_r) - continue; - for (int i_ic = 0; i_ic < ic_block_step; i_ic++) { - size_t i_off = (size_t)input_offset + sizeof(float)*( - one_of(jcp.src_tag, ncw, nchw, ncdhw) - ? (i_iw - pad_l) + i_ic - * ((size_t)jcp.id * jcp.ih * jcp.iw) - : (i_iw - pad_l) * ic_block + i_ic); - vbroadcastss(Ymm(kw * ic_block_step + 1), - make_safe_addr(reg_input, i_off, reg_long_offt)); - vfmadd231ps(Ymm(i_kw * ic_block_step + i_ic), - Ymm(kw * ic_block_step + 0), - Ymm(kw * ic_block_step + 1)); - } - } - } - - for (int i_kw = 0; i_kw < kw; i_kw++) - for (int i_ic = 0; i_ic < ic_block_step; i_ic++) { - size_t off - = sizeof(float) * (i_kw * ic_block + i_ic) * jcp.oc_block - + kernel_offset; - vmovups(yword[reg_kernel + off], - Ymm(i_kw * ic_block_step + i_ic)); - } -} - -inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_oh_step_disp() -{ - int ic_block_step; - if (one_of(jcp.src_tag, ncw, nchw, ncdhw)) { - ic_block_step = jcp.kw >= 5 ? 1 : jcp.ic_block; - } else { - ic_block_step = jcp.kw > 7 ? 1 - : jcp.kw > 3 ? 2 - : jcp.kw > 1 ? 4 : 8; - } - - const int max_ur_w = jcp.ow > 56 ? 14 : 28; - - if (jcp.ow <= max_ur_w) - compute_oh_step_unroll_ow(ic_block_step, max_ur_w); - else - compute_oh_step_common(ic_block_step, max_ur_w); - - if (jcp.ndims == 5) { - od_step_comeback_pointers(); - mov(reg_input, aux_reg_input); - mov(reg_kernel, aux_reg_kernel); - } else { - oh_step_comeback_pointers(); - } -} - -inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_oh_step_unroll_ow( - int ic_block_step, int max_ur_w) -{ - UNUSED(max_ur_w); - - const int ic_block = jcp.ic_block; - const int oc_block = jcp.oc_block; - int inp_mul = one_of(jcp.src_tag, ncw, nchw, ncdhw) ? 
1 : jcp.ic_block; - Label kd_loop; - - const int r_pad - = nstl::max(0, - (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad); - - if (jcp.ndims == 5) { - mov(aux_reg_input, reg_input); - mov(aux_reg_kernel, reg_kernel); - mov(ki, jcp.kd); - L(kd_loop); - mov(reg_input, aux_reg_input); - mov(reg_kernel, aux_reg_kernel); - } - - mov(kj, reg_kh); - Label kh_loop; - L(kh_loop); { - xor_(b_ic, b_ic); - Label ic_block_loop; - L(ic_block_loop); { - compute_ic_block_step(jcp.ow, jcp.l_pad, r_pad, ic_block_step, 0, - 0, 0); - size_t inp_icblk_stride = sizeof(float) * ic_block_step - * (one_of(jcp.src_tag, ncw, nchw, ncdhw) - ? jcp.id*jcp.ih*jcp.iw : 1); - safe_add(reg_input, inp_icblk_stride, reg_long_offt); - add(reg_kernel, sizeof(float) * ic_block_step * oc_block); - add(b_ic, ic_block_step); - cmp(b_ic, ic_block); - jl(ic_block_loop, T_NEAR); - } - if(one_of(jcp.src_tag, ncw, nchw, ncdhw)) { - size_t offt = sizeof(float) * jcp.id * jcp.ih * jcp.iw * ic_block; - safe_sub(reg_input, offt, reg_long_offt); - add(reg_input, sizeof(float) * jcp.iw); - } else { - add(reg_input, sizeof(float) * (jcp.iw - 1) * ic_block); - } - add(reg_kernel, sizeof(float) * (jcp.kw - 1) * ic_block * oc_block); - dec(kj); - cmp(kj, 0); - jg(kh_loop, T_NEAR); - } - - if (jcp.ndims == 5) { - add(aux_reg_input, sizeof(float) * jcp.ih * jcp.iw * inp_mul); - add(aux_reg_kernel, sizeof(float) * jcp.kh * jcp.kw * ic_block - * oc_block); - dec(ki); - cmp(ki, 0); - jg(kd_loop, T_NEAR); - } - -} - -inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_oh_step_common( - int ic_block_step, int max_ur_w) -{ - const int ic_block = jcp.ic_block; - const int oc_block = jcp.oc_block; - const int stride_w = jcp.stride_w; - int inp_mul = one_of(jcp.src_tag, ncw, nchw, ncdhw) ? 1 : jcp.ic_block; - Label kd_loop; - - const int r_pad = jcp.r_pad; - - int ur_w = nstl::min(jcp.ow, max_ur_w); - int ur_w_trips = jcp.ow / ur_w; - int ur_w_tail = jcp.ow % ur_w; - if ((ur_w_tail == 0 && r_pad != 0) || r_pad >= ur_w_tail) { - if (ur_w_trips > 1) { - ur_w_tail += ur_w; - ur_w_trips--; - } else { - ur_w_tail += (ur_w - ur_w / 2); - ur_w = ur_w / 2; - } - } - const int inp_mult = one_of(jcp.src_tag, ncw, nchw, ncdhw) ? 
1 : ic_block; - - int input_comeback = (ur_w_trips * ur_w * stride_w - jcp.l_pad) * inp_mult; - int output_comeback = ur_w_trips * ur_w * oc_block; - - if (jcp.ndims == 5) { - mov(aux_reg_input, reg_input); - mov(aux_reg_kernel, reg_kernel); - mov(ki, jcp.kd); - L(kd_loop); - mov(reg_input, aux_reg_input); - mov(reg_kernel, aux_reg_kernel); - } - - mov(kj, reg_kh); - Label kh_loop; - L(kh_loop); { - xor_(b_ic, b_ic); - Label ic_block_loop; - L(ic_block_loop); { - if (jcp.l_pad != 0) { - ur_w_trips--; - compute_ic_block_step(ur_w, - jcp.l_pad, 0, ic_block_step, 0, 0, 0); - add(reg_input, sizeof(float) - * (ur_w * stride_w - jcp.l_pad) * inp_mult); - add(reg_output, sizeof(float) * ur_w * oc_block); - } - - if (ur_w_trips > 0) { - xor_(reg_ur_w_trips, reg_ur_w_trips); - Label ow_block_loop; - L(ow_block_loop); { - compute_ic_block_step(ur_w, 0, 0, ic_block_step, 0, 0, 0); - add(reg_input, sizeof(float) * ur_w * stride_w * inp_mult); - add(reg_output, sizeof(float) * ur_w * oc_block); - - inc(reg_ur_w_trips); - cmp(reg_ur_w_trips, ur_w_trips); - jl(ow_block_loop, T_NEAR); - } - } - - if (ur_w_tail > 0) - compute_ic_block_step(ur_w_tail, - 0, r_pad, ic_block_step, 0, 0, 0); - - sub(reg_input, sizeof(float) * input_comeback); - sub(reg_output, sizeof(float) * output_comeback); - - size_t inp_icblk_stride = sizeof(float) * ic_block_step - * (one_of(jcp.src_tag, ncw, nchw, ncdhw) - ? jcp.id*jcp.ih*jcp.iw : 1); - safe_add(reg_input, inp_icblk_stride, reg_long_offt); - add(reg_kernel, sizeof(float) * ic_block_step * oc_block); - - add(b_ic, ic_block_step); - cmp(b_ic, jcp.ic_block); - jl(ic_block_loop, T_NEAR); - } - if (one_of(jcp.src_tag, ncw, nchw, ncdhw)) { - size_t offt = sizeof(float) * jcp.id * jcp.ih * jcp.iw * ic_block; - safe_sub(reg_input, offt, reg_long_offt); - add(reg_input, sizeof(float) * jcp.iw); - } else { - add(reg_input, sizeof(float) * (jcp.iw - 1) * ic_block); - } - add(reg_kernel, sizeof(float) * (jcp.kw - 1) * ic_block * oc_block); - dec(kj); - cmp(kj, 0); - jg(kh_loop, T_NEAR); - } - - if (jcp.ndims == 5) { - add(aux_reg_input, sizeof(float) * jcp.ih * jcp.iw * inp_mul); - add(aux_reg_kernel, sizeof(float) * jcp.kh * jcp.kw * ic_block - * oc_block); - dec(ki); - cmp(ki, 0); - jg(kd_loop, T_NEAR); - } - -} - -inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_oh_loop_common() -{ - const int icoc_block = jcp.ic_block * jcp.oc_block; - const int t_pad = jcp.t_pad; - const int stride_h = jcp.stride_h; - const int inp_mult = one_of(jcp.src_tag, ncw, nchw, ncdhw) - ? 1 : jcp.ic_block; - int b_pad = jcp.b_pad; - - Label oh_tpad_loop, oh_loop, oh_loop_end; - - mov(reg_kh, jcp.kh); - xor_(reg_ih_count, reg_ih_count); - xor_(reg_oj, reg_oj); - if (t_pad > 0) { - assert(jcp.kh <= t_pad + jcp.ih); /* [bwd_w:r1] */ - mov(reg_kh, jcp.kh <= t_pad + jcp.ih ? jcp.kh - t_pad : jcp.ih); - add(reg_kernel, sizeof(float) * t_pad * jcp.kw * icoc_block); - - L(oh_tpad_loop); { - compute_oh_step_disp(); - add(reg_output, sizeof(float) * jcp.ow * jcp.oc_block); - sub(reg_kernel, sizeof(float) * stride_h * jcp.kw * icoc_block); - - inc(reg_oj); - add(reg_ih_count, stride_h); - add(reg_kh, stride_h); - - /* the overlap between input and kernel may not reach kernel size. 
- * so far we do not support that (until we put constant here) */ - const int final_inp_ker_overlap = jcp.kh; /* [bwd_w:r2] */ - cmp(reg_kh, final_inp_ker_overlap); - jl(oh_tpad_loop, T_NEAR); - } - - if (t_pad % stride_h != 0) { - int inp_corr = stride_h - t_pad % stride_h; - add(reg_kernel, sizeof(float) * inp_corr * jcp.kw * icoc_block); - add(reg_input, sizeof(float) * inp_corr * jcp.iw * inp_mult); - } - } - cmp(reg_ih_count, jcp.ih + t_pad - jcp.kh + 1); - jge(oh_loop_end, T_NEAR); - cmp(reg_oj, jcp.oh); - jge(oh_loop, T_NEAR); - - mov(reg_kh, jcp.kh); - L(oh_loop); { - compute_oh_step_disp(); - add(reg_input, sizeof(float) * stride_h * jcp.iw * inp_mult); - add(reg_output, sizeof(float) * jcp.ow * jcp.oc_block); - - inc(reg_oj); - add(reg_ih_count, stride_h); - - cmp(reg_ih_count, jcp.ih + t_pad - jcp.kh + 1); - jge(oh_loop_end, T_NEAR); - - cmp(reg_oj, jcp.oh); - jl(oh_loop, T_NEAR); - } - L(oh_loop_end); - if (b_pad > 0) { - Label oh_bpad_loop, oh_bpad_loop_end; - cmp(reg_oj, jcp.oh); - jge(oh_bpad_loop_end, T_NEAR); - - mov(reg_kh, jcp.ih + t_pad); - sub(reg_kh, reg_ih_count); - L(oh_bpad_loop); { - compute_oh_step_disp(); - add(reg_input, sizeof(float) * stride_h * jcp.iw * inp_mult); - add(reg_output, sizeof(float) * jcp.ow * jcp.oc_block); - - sub(reg_kh, stride_h); - cmp(reg_kh, 0); - jle(oh_bpad_loop_end, T_NEAR); - - inc(reg_oj); - cmp(reg_oj, jcp.oh); - jl(oh_bpad_loop, T_NEAR); - } - L(oh_bpad_loop_end); - } -} - -} -} -} - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.hpp deleted file mode 100644 index 412c50c9e..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.hpp +++ /dev/null @@ -1,225 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#ifndef JIT_AVX2_CONV_KERNEL_F32_HPP -#define JIT_AVX2_CONV_KERNEL_F32_HPP - -#include "c_types_map.hpp" -#include "memory_tracking.hpp" - -#include "cpu_memory.hpp" -#include "jit_generator.hpp" -#include "jit_primitive_conf.hpp" -#include "jit_uni_eltwise.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -struct jit_avx2_conv_fwd_kernel_f32: public jit_generator { - jit_avx2_conv_fwd_kernel_f32(jit_conv_conf_t ajcp, - const primitive_attr_t &attr) - : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr) - { - if (jcp.with_eltwise) - eltwise_injector_ = new jit_uni_eltwise_injector_f32(this, - jcp.eltwise); - - this->generate(); - jit_ker = (void (*)(jit_conv_call_s *))this->getCode(); - } - - ~jit_avx2_conv_fwd_kernel_f32() { - delete eltwise_injector_; - } - - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_conv_fwd_kernel_f32) - - static bool post_ops_ok(jit_conv_conf_t &jcp, - const primitive_attr_t &attr); - static status_t init_conf(jit_conv_conf_t &jcp, - const convolution_desc_t &cd, const memory_desc_wrapper &src_d, - const memory_desc_wrapper &weights_d, - const memory_desc_wrapper &dst_d, - const primitive_attr_t &attr); - static void init_scratchpad(memory_tracking::registrar_t &scratchpad, - const jit_conv_conf_t &jcp); - - jit_conv_conf_t jcp; - const primitive_attr_t &attr_; - void (*jit_ker)(jit_conv_call_s *); - -private: - using reg64_t = const Xbyak::Reg64; - reg64_t reg_input = rax; - reg64_t aux_reg_input = r8; - reg64_t reg_kernel = rdx; - reg64_t aux_reg_kernel = r9; - reg64_t reg_output = rsi; - reg64_t reg_bias = rbx; - - reg64_t aux_reg_inp_d = r11; - reg64_t aux_reg_ker_d = abi_not_param1; - - reg64_t reg_ki = rsi; - reg64_t kj = r10; - reg64_t oi_iter = r11; - reg64_t ki_iter = r12; - reg64_t reg_kh = abi_not_param1; - reg64_t reg_oc_blocks = r14; - reg64_t imm_addr64 = r15; - reg64_t reg_long_offt = r15; - Xbyak::Reg32 reg_ci_flag = r13d; - - Xbyak::Ymm ytmp = Xbyak::Ymm(14); - - jit_uni_eltwise_injector_f32 *eltwise_injector_; - - inline void oh_step_unroll_kw(int ur_w, int pad_l, int pad_r, - int oc_blocks); - inline void oh_step_nopad(int ur_w, int pad_l, int pad_r, - char pad_label, int oc_blocks, char oc_blocks_label); - inline void width_blk_step(int ur_w, int pad_l, int pad_r, - char pad_label, int oc_blocks, char oc_blocks_label); - inline void solve_common(int oc_blocks, char oc_blocks_label); - - void generate(); -}; - -struct jit_avx2_conv_bwd_data_kernel_f32: public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_conv_bwd_data_kernel_f32) - - jit_avx2_conv_bwd_data_kernel_f32(jit_conv_conf_t ajcp): jcp(ajcp) - { - this->generate(); - jit_ker = (void (*)(jit_conv_call_s *))this->getCode(); - } - - static status_t init_conf(jit_conv_conf_t &jcp, - const convolution_desc_t &cd, const memory_desc_wrapper &diff_src_d, - const memory_desc_wrapper &weights_d, - const memory_desc_wrapper &diff_dst_d); - static void init_scratchpad(memory_tracking::registrar_t &scratchpad, - const jit_conv_conf_t &jcp); - - jit_conv_conf_t jcp; - void (*jit_ker)(jit_conv_call_s *); - -private: - using reg64_t = const Xbyak::Reg64; - - reg64_t reg_ddst = rax; - reg64_t aux_reg_ddst = r8; - reg64_t reg_kernel = rdx; - reg64_t aux_reg_kernel = r10; - reg64_t reg_dsrc = rsi; - reg64_t aux_reg_ddst_oc_loop = rbx; // used in ndims < 5 case only - reg64_t aux_reg_kernel_oc_loop = abi_not_param1; /* used in ndims < 5 - case only */ - - reg64_t aux_reg_dst_d = r12; // used in ndims == 5 
case only - reg64_t aux_reg_ker_d = r14; // used in ndims == 5 case only - - reg64_t reg_ki = abi_not_param1; // used in ndims == 5 case only - reg64_t kj = r11; - reg64_t oi_iter = r12; - reg64_t reg_kh = r14; - reg64_t reg_channel = r13; // used in ndims < 5 case only - reg64_t reg_channel_work = r9; // used in ndims < 5 case only - reg64_t reg_long_offt = r15; - - inline void compute_loop(int ur_w, int l_overflow, int r_overflow); - - void generate(); - - inline int get_iw_start(int ki, int l_overflow) - { - int res = (jcp.iw - 1 + jcp.r_pad) % jcp.stride_w - + l_overflow * jcp.stride_w - - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1); - while (res < 0) - res += jcp.stride_w; - - return res; - } - - inline int get_iw_end(int ur_w, int ki, int r_overflow) - { - if (utils::one_of(ur_w, jcp.iw, jcp.ur_w_tail)) - ur_w += nstl::min(0, jcp.r_pad); // remove negative padding - int res = (ur_w - 1 + jcp.l_pad) % jcp.stride_w - + r_overflow * jcp.stride_w - ki * (jcp.dilate_w + 1); - while (res < 0) - res += jcp.stride_w; - - return ur_w - res; - } -}; - -struct jit_avx2_conv_bwd_weights_kernel_f32: public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_conv_bwd_weights_kernel_f32) - - jit_avx2_conv_bwd_weights_kernel_f32(jit_conv_conf_t ajcp): jcp(ajcp) - { - this->generate(); - jit_ker = (void (*)(jit_conv_call_s *))this->getCode(); - } - - static status_t init_conf(jit_conv_conf_t &jcp, - const convolution_desc_t &cd, const memory_desc_wrapper &src_d, - const memory_desc_wrapper &diff_weights_d, - const memory_desc_wrapper &diff_dst_d); - static void init_scratchpad(memory_tracking::registrar_t &scratchpad, - const jit_conv_conf_t &jcp); - - jit_conv_conf_t jcp; - void (*jit_ker)(jit_conv_call_s *); - -private: - using reg64_t = const Xbyak::Reg64; - reg64_t reg_input = rax; - reg64_t reg_kernel = rdx; - reg64_t reg_output = rsi; - reg64_t b_ic = abi_not_param1; - reg64_t kj = r8; - reg64_t reg_kh = r9; - reg64_t reg_ur_w_trips = r10; - reg64_t reg_tmp = r11; - reg64_t reg_oj = r15; - reg64_t reg_ih_count = rbx; - reg64_t aux_reg_input = r12; - reg64_t aux_reg_kernel = r13; - reg64_t ki = r14; - reg64_t reg_long_offt = r11; - - inline void od_step_comeback_pointers(); - inline void oh_step_comeback_pointers(); - inline void compute_ic_block_step(int ur_w, int pad_l, int pad_r, - int ic_block_step, int input_offset, int kernel_offset, - int output_offset); - inline void compute_oh_step_disp(); - inline void compute_oh_step_unroll_ow(int ic_block_step, int max_ur_w); - inline void compute_oh_step_common(int ic_block_step, int max_ur_w); - inline void compute_oh_loop_common(); - - void generate(); -}; - -} -} -} - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_convolution.cpp deleted file mode 100644 index 13f61e84f..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_convolution.cpp +++ /dev/null @@ -1,410 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include "c_types_map.hpp" -#include "mkldnn_thread.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -#include "jit_avx2_convolution.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace mkldnn::impl::status; -using namespace mkldnn::impl::memory_tracking::names; -using namespace mkldnn::impl::utils; - -#define src_blk_off(f, n, c, d, h, w) \ - (pd()->ndims() == 3) \ - ? (f).blk_off(n, c, w) \ - : (pd()->ndims() == 4) \ - ? (f).blk_off(n, c, h, w) \ - : (f).blk_off(n, c, d, h, w) - -#define wht_blk_off_(f, g, ...) \ - pd()->with_groups() ? (f).blk_off(g, __VA_ARGS__) : (f).blk_off(__VA_ARGS__) -#define wht_blk_off(f, g, oc, ic, kd, kh, kw) \ - (pd()->ndims() == 3) \ - ? wht_blk_off_(f, g, oc, ic, kw) \ - : (pd()->ndims() == 4) \ - ? wht_blk_off_(f, g, oc, ic, kh, kw) \ - : wht_blk_off_(f, g, oc, ic, kd, kh, kw) - -void jit_avx2_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const { - auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); - auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); - auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS); - auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); - - const memory_desc_wrapper src_d(pd()->src_md()); - const memory_desc_wrapper dst_d(pd()->dst_md()); - const memory_desc_wrapper weights_d(pd()->weights_md(0)); - const memory_desc_wrapper bias_d(pd()->weights_md(1)); - - const auto &jcp = kernel_->jcp; - - int ocb_work = div_up(jcp.nb_oc, jcp.nb_oc_blocking); - const size_t work_amount = jcp.mb * jcp.ngroups * ocb_work * jcp.od - * jcp.oh; - - auto ker = [&](const int ithr, const int nthr) { - size_t start{0}, end{0}; - balance211(work_amount, nthr, ithr, start, end); - - int icbb = 0; - while (icbb < jcp.nb_ic) { - int icb_step = jcp.nb_ic_blocking; - int icb_step_rem = jcp.nb_ic - icbb; - if (icb_step_rem < jcp.nb_ic_blocking_max) - icb_step = icb_step_rem; - - size_t n{0}, g{0}, ocbb{0}, oh{0}, od{0}; - nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work, - od, jcp.od, oh, jcp.oh); - for (size_t iwork = start; iwork < end; ++iwork) { - int ocb = ocbb * jcp.nb_oc_blocking; - int ocb_num = jcp.nb_oc_blocking; - - for (int icb = icbb; icb < icbb + icb_step; ++icb) { - auto par_conv = jit_conv_call_s(); - - const int ij = oh * jcp.stride_h; - const int i_t_overflow = nstl::max(0, jcp.t_pad - ij); - const int i_b_overflow = nstl::max(jcp.ih, ij - + (jcp.kh-1) * (jcp.dilate_h+1) - jcp.t_pad+1) - jcp.ih; - - const int dj = od * jcp.stride_d; - const int d_t_overflow = nstl::max(0, jcp.f_pad - dj); - const int d_b_overflow = nstl::max(jcp.id, dj - + (jcp.kd-1) * (jcp.dilate_d+1) - jcp.f_pad+1) - jcp.id; - - const size_t _oc = g * jcp.nb_oc + ocb; - const size_t _ic = g * jcp.nb_ic * jcp.nonblk_group_off + icb; - - const int ih = nstl::max(ij - jcp.t_pad - + div_up(i_t_overflow, - (jcp.dilate_h+1)) * (jcp.dilate_h + 1), 0); - - const int id = nstl::max(dj - jcp.f_pad - + div_up(d_t_overflow, - (jcp.dilate_d+1)) * (jcp.dilate_d + 1), 0); - - par_conv.src = &src[src_blk_off(src_d, n, - jcp.ic == 3 ? 0 : _ic, id, ih, 0)]; - - par_conv.dst = &dst[src_blk_off(dst_d, n, _oc, od, oh, 0)]; - - const int wh = div_up(i_t_overflow, (jcp.dilate_h + 1)); - const int wd = div_up(d_t_overflow, (jcp.dilate_d + 1)); - par_conv.filt = &weights[wht_blk_off(weights_d, g, ocb, - jcp.ic == 3 ? 
0 : icb, wd, wh, 0)]; - - if (icb == 0) { - if (bias) - par_conv.bias = - &bias[bias_d.blk_off(_oc * jcp.oc_block)]; - par_conv.flags |= FLAG_IC_FIRST; - } - - if (jcp.with_eltwise && icb + 1 == jcp.nb_ic) { - par_conv.flags |= FLAG_IC_LAST; - } - - par_conv.oc_blocks = - nstl::min(ocb + ocb_num, jcp.nb_oc) - ocb; - - par_conv.kw_padding = 0; - const int kh_padding = jcp.kh - - div_up(i_t_overflow, (jcp.dilate_h + 1)) - - div_up(i_b_overflow, (jcp.dilate_h + 1)); - par_conv.kh_padding = nstl::max(0, kh_padding); - - const int kd_padding = jcp.kd - - div_up(d_t_overflow, (jcp.dilate_d + 1)) - - div_up(d_b_overflow, (jcp.dilate_d + 1)); - par_conv.kd_padding = nstl::max(0, kd_padding); - - kernel_->jit_ker(&par_conv); - } - nd_iterator_step(n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work, - od, jcp.od, oh, jcp.oh); - } - icbb += icb_step; - } - }; - - if (pd()->wants_padded_bias()) { - auto padded_bias = scratchpad(ctx).get(key_conv_padded_bias); - utils::array_copy(padded_bias, bias, jcp.oc_without_padding); - utils::array_set(padded_bias + jcp.oc_without_padding, 0.f, - jcp.oc - jcp.oc_without_padding); - bias = padded_bias; - } - - parallel(0, ker); - - if (pd()->wants_zero_pad_dst()) - ctx.memory(MKLDNN_ARG_DST)->zero_pad(); -} - -void jit_avx2_convolution_bwd_data_t::execute_backward_data( - const exec_ctx_t &ctx) const { - auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); - auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); - auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); - - const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); - const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); - const memory_desc_wrapper weights_d(pd()->weights_md(0)); - - const auto &jcp = kernel_->jcp; - - int icb_work = jcp.nb_ic / jcp.nb_ic_blocking; - int ih_block_size = jcp.ih; - int num_ih_blocks = utils::div_up(jcp.ih, ih_block_size); - size_t work_amount = jcp.mb * jcp.ngroups * icb_work * num_ih_blocks; - if (work_amount < (size_t)2 * mkldnn_get_max_threads()) { - ih_block_size = 1; - num_ih_blocks = utils::div_up(jcp.ih, ih_block_size); - work_amount *= num_ih_blocks; - } - - auto ker = [&](const int ithr, const int nthr) { - size_t start{0}, end{0}; - balance211(work_amount, nthr, ithr, start, end); - - size_t n{0}, g{0}, icbb{0}, ihb{0}; - nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, icbb, icb_work, - ihb, num_ih_blocks); - for (size_t iwork = start; iwork < end; ++iwork) { - for (int oc = 0; oc < jcp.nb_oc; oc += jcp.nb_oc_blocking) - for (int id = 0; id < jcp.id; ++id) { - auto par_conv = jit_conv_call_s(); - - const int idp = jcp.id + 2 * jcp.f_pad; - const int d_t_overflow = nstl::max(0, - jcp.kd - 1 - id - jcp.f_pad); - const int back_pad = idp - jcp.id - jcp.f_pad; - const int d_b_overflow = nstl::max(0, - jcp.kd - 1 - (jcp.id - 1 - id) - back_pad); - const int od = id + jcp.f_pad - d_b_overflow; - - int ih_start = ihb * ih_block_size; - int ih_end = nstl::min(jcp.ih, ih_start + ih_block_size); - for (int ih = ih_start; ih < ih_end; ++ih) { - - const int i_t_overflow = nstl::max(0, (jcp.kh - 1 - - ih - jcp.t_pad) / jcp.stride_h); - const int i_b_overflow = nstl::max(0, (jcp.kh - jcp.ih - + ih - jcp.b_pad) / jcp.stride_h); - int overflow_kh_hi = jcp.kh - 1 - abs((jcp.ih - 1 - + jcp.b_pad - ih) % jcp.stride_h); - int overflow_kh_lo = (ih + jcp.t_pad) % jcp.stride_h; - - par_conv.kd_padding = jcp.kd - d_t_overflow - d_b_overflow; - par_conv.kh_padding = (overflow_kh_hi - overflow_kh_lo) - / jcp.stride_h + 1 - i_t_overflow - i_b_overflow; - 
par_conv.kw_padding = 0; - - const int k_lo = overflow_kh_lo - + i_b_overflow * jcp.stride_h; - const int oh = (ih + jcp.t_pad - k_lo) / jcp.stride_h; - - par_conv.src = &diff_src[src_blk_off(diff_src_d, n, - /*jcp.ic == 3 ? 0 :*/ - g * jcp.nb_ic + jcp.nb_ic_blocking * icbb, id, ih, 0)]; - par_conv.dst = &diff_dst[src_blk_off(diff_dst_d, - n, g * jcp.nb_oc + oc, od, oh, 0)]; - par_conv.filt = &weights[wht_blk_off(weights_d, g, oc, - jcp.ic == 3 ? 0 : jcp.nb_ic_blocking * icbb, - d_b_overflow, k_lo, 0)]; - - par_conv.src_prf = nullptr; - par_conv.dst_prf = nullptr; - par_conv.filt_prf = nullptr; - par_conv.channel = oc; - par_conv.ch_blocks = nstl::min(jcp.nb_oc - oc, - jcp.nb_oc_blocking); - - kernel_->jit_ker(&par_conv); - } - } - nd_iterator_step(n, jcp.mb, g, jcp.ngroups, icbb, icb_work, ihb, - num_ih_blocks); - } - }; - - parallel(0, ker); -} - -void jit_avx2_convolution_bwd_weights_t::execute_backward_weights( - const exec_ctx_t &ctx) const { - auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); - auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); - auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS); - auto diff_bias_in = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS); - - auto scratchpad = this->scratchpad(ctx); - - data_t *diff_bias = pd()->wants_padded_bias() - ? scratchpad.get(key_conv_padded_bias) : diff_bias_in; - - const memory_desc_wrapper src_d(pd()->src_md()); - const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); - const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); - - const auto &jcp = kernel_->jcp; - - auto reducer_bia_scratchpad = memory_tracking::grantor_t(scratchpad, - prefix_reducer_bia); - auto rb = this->reducer_bias_; - rb->init(reducer_bia_scratchpad); - - auto reducer_wei_scratchpad = memory_tracking::grantor_t(scratchpad, - prefix_reducer_wei); - auto rw = this->reducer_weights_; - rw->init(reducer_wei_scratchpad); - - auto ker = [&](int ithr, int nthr) { - assert(nthr == rw->balancer().nthr_); - - const int w_job_start = rw->balancer().ithr_job_off(ithr); - const int w_njobs = rw->balancer().ithr_njobs(ithr); - - if (w_njobs == 0) return; - - /* reduction dimension */ - int img_od_start{0}, img_od_end{0}, img{0}, od_s{0}; - balance211(jcp.mb * jcp.od, rw->balancer().nthr_per_group_, - rw->balancer().id_in_group(ithr), img_od_start, img_od_end); - - int img_start = img_od_start, img_end = img_od_end; - nd_iterator_init(img_start, img, jcp.mb, od_s, jcp.od); - const int img_first = img; - - /* jobs */ - int g_start{0}, ocb_start{0}, icb_start{0}; - nd_iterator_init(w_job_start, g_start, jcp.ngroups, ocb_start, - jcp.nb_oc, icb_start, jcp.nb_ic); - - while (img_start < img_end) { - int g = g_start, ocb = ocb_start, icb = icb_start; - - const int work_rem = img_end - img_start; - const int od_e = od_s + work_rem > jcp.od ? 
jcp.od : od_s + work_rem; - const int id_s = od_s * jcp.stride_d; - const int idp = jcp.id + jcp.f_pad + jcp.back_pad; - - if (id_s < idp - jcp.back_pad - jcp.kd + 1) - for (int w_job_loc = 0; w_job_loc < w_njobs; ++w_job_loc) { - const size_t _oc = g * jcp.nb_oc + ocb; - const size_t _ic = g * jcp.nb_ic + icb; - - /* TODO: put dw <-- 0 in kernel */ - if (img == img_first) - array_set(rw->get_local_ptr(ithr, diff_weights, - reducer_wei_scratchpad) + - w_job_loc * rw->balancer().job_size_, 0, - rw->balancer().job_size_); - - for (int od = od_s; od < od_e; ++od) { - const int id = od * jcp.stride_d; - if (id >= jcp.id - jcp.back_pad - jcp.kd + 1) break; - - auto par_conv = jit_conv_call_s(); - par_conv.src = &src[src_blk_off(src_d, img, _ic, id, 0, 0)]; - par_conv.dst = - &diff_dst[src_blk_off(diff_dst_d, img, _oc, od, 0, 0)]; - par_conv.filt = rw->get_local_ptr(ithr, diff_weights, - reducer_wei_scratchpad) + - w_job_loc * rw->balancer().job_size_; - - kernel_->jit_ker(&par_conv); - } - nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_oc, icb, - jcp.nb_ic); - } - nd_iterator_jump(img_start, img_end, img, jcp.mb, od_s, jcp.od); - } - rw->reduce(ithr, diff_weights, reducer_wei_scratchpad); - }; - - auto ker_bias = [&](int ithr, int nthr) { - assert(nthr == rb->balancer().nthr_); - - const int b_job_start = rb->balancer().ithr_job_off(ithr); - const int b_njobs = rb->balancer().ithr_njobs(ithr); - - if (b_njobs == 0) return; - - /* reduction dimension */ - int img_start{0}, img_end{0}; - balance211(jcp.mb, rb->balancer().nthr_per_group_, - rb->balancer().id_in_group(ithr), img_start, img_end); - - /* jobs */ - int g_start{0}, ocb_start{0}; - nd_iterator_init(b_job_start, g_start, jcp.ngroups, ocb_start, - jcp.nb_oc); - - for (int img = img_start; img < img_end; ++img) { - int g = g_start, ocb = ocb_start; - for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) { - const size_t _oc = g * jcp.nb_oc + ocb; - - const data_t *d_dst = &diff_dst[diff_dst_d.blk_off(img, _oc)]; - data_t *d_bias = rb->get_local_ptr(ithr, diff_bias, - reducer_bia_scratchpad) + - b_job_loc * rb->balancer().job_size_; - - if (img == img_start) - for (int o = 0; o < 8; ++o) - d_bias[o] = 0.; - - for (int dhw = 0; dhw < jcp.od * jcp.oh * jcp.ow; ++dhw) { - PRAGMA_OMP_SIMD() - for (int o = 0; o < 8; ++o) - d_bias[o] += d_dst[o]; - d_dst += 8; - } - - nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_oc); - } - } - rb->reduce(ithr, diff_bias, reducer_bia_scratchpad); - }; - - parallel(0, [&](const int ithr, const int nthr) { - ker(ithr, nthr); - if (pd()->with_bias()) - ker_bias(ithr, nthr); - }); - - /* TODO: put this in ker_bias */ - if (pd()->wants_padded_bias()) { - assert(jcp.ngroups == 1); - for (int oc = 0; oc < jcp.oc_without_padding; ++oc) - diff_bias_in[oc] = diff_bias[oc]; - } -} - -} -} -} - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_convolution.hpp deleted file mode 100644 index bb65bce79..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_convolution.hpp +++ /dev/null @@ -1,302 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. 
-* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef CPU_JIT_AVX2_CONVOLUTION_HPP -#define CPU_JIT_AVX2_CONVOLUTION_HPP - -#include "c_types_map.hpp" -#include "memory_tracking.hpp" -#include "mkldnn_thread.hpp" -#include "utils.hpp" - -#include "cpu_convolution_pd.hpp" -#include "cpu_reducer.hpp" - -#include "jit_avx2_conv_kernel_f32.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -struct jit_avx2_convolution_fwd_t: public cpu_primitive_t { - struct pd_t: public cpu_convolution_fwd_pd_t { - pd_t(engine_t *engine, - const convolution_desc_t *adesc, - const primitive_attr_t *attr, - const typename pd_t::base_class *hint_fwd_pd) - : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) - , jcp_() {} - - DECLARE_COMMON_PD_T( - JIT_IMPL_NAME_HELPER("jit:", avx2, ""), - jit_avx2_convolution_fwd_t); - - status_t init() { - bool ok = true - && is_fwd() - && set_default_alg_kind(alg_kind::convolution_direct) - && expect_data_types(data_type::f32, data_type::f32, - data_type::f32, data_type::f32, data_type::f32) - && !has_zero_dim_memory() - && set_default_formats(); - if (!ok) return status::unimplemented; - - status_t status = jit_avx2_conv_fwd_kernel_f32::init_conf(jcp_, - *desc(), src_md(), weights_md(), dst_md(), *attr()); - if (status != status::success) return status; - - auto scratchpad = scratchpad_registry().registrar(); - jit_avx2_conv_fwd_kernel_f32::init_scratchpad(scratchpad, jcp_); - - return status::success; - } - - jit_conv_conf_t jcp_; - - protected: - bool set_default_formats() { - using namespace format_tag; - - const bool flat = IC() < 8; - auto src_tag = flat - ? utils::pick(ndims() - 3, ncw, nchw, ncdhw) - : utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); - auto dst_tag = - utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); - auto wei_tag = with_groups() - ? 
utils::pick(2 * ndims() - 6 + flat, gOIw8i8o, gOwi8o, - gOIhw8i8o, gOhwi8o, gOIdhw8i8o, gOdhwi8o) - : utils::pick(2 * ndims() - 6 + flat, OIw8i8o, Owi8o, - OIhw8i8o, Ohwi8o, OIdhw8i8o, Odhwi8o); - - return set_default_formats_common(src_tag, wei_tag, dst_tag); - } - }; - - jit_avx2_convolution_fwd_t(const pd_t *apd): cpu_primitive_t(apd) - { kernel_ = new jit_avx2_conv_fwd_kernel_f32(pd()->jcp_, *pd()->attr()); } - ~jit_avx2_convolution_fwd_t() { delete kernel_; } - - typedef typename prec_traits::type data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_forward(ctx); - return status::success; - } - -private: - void execute_forward(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - - jit_avx2_conv_fwd_kernel_f32 *kernel_; -}; - -struct jit_avx2_convolution_bwd_data_t: public cpu_primitive_t { - struct pd_t: public cpu_convolution_bwd_data_pd_t { - pd_t(engine_t *engine, - const convolution_desc_t *adesc, - const primitive_attr_t *attr, - const convolution_fwd_pd_t *hint_fwd_pd) - : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) - , jcp_() - {} - - DECLARE_COMMON_PD_T( - JIT_IMPL_NAME_HELPER("jit:", avx2, ""), - jit_avx2_convolution_bwd_data_t); - - status_t init() { - bool ok = true - && desc()->prop_kind == prop_kind::backward_data - && set_default_alg_kind(alg_kind::convolution_direct) - && expect_data_types(data_type::f32, data_type::f32, - data_type::undef, data_type::f32, data_type::f32) - && !has_zero_dim_memory() - && set_default_formats(); - if (!ok) return status::unimplemented; - - status_t status = jit_avx2_conv_bwd_data_kernel_f32::init_conf( - jcp_, *desc(), *diff_src_md(), *weights_md(), - *diff_dst_md()); - if (status != status::success) return status; - - auto scratchpad = scratchpad_registry().registrar(); - jit_avx2_conv_bwd_data_kernel_f32::init_scratchpad(scratchpad, - jcp_); - - return status::success; - } - - jit_conv_conf_t jcp_; - - protected: - bool set_default_formats() { - using namespace format_tag; - - auto dat_tag = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); - auto wei_tag = with_groups() - ? 
utils::pick(ndims() - 3, gOIw8o8i, gOIhw8o8i, gOIdhw8o8i) - : utils::pick(ndims() - 3, OIw8o8i, OIhw8o8i, OIdhw8o8i); - - return set_default_formats_common(dat_tag, wei_tag, dat_tag); - } - }; - - jit_avx2_convolution_bwd_data_t(const pd_t *apd): cpu_primitive_t(apd) - { kernel_ = new jit_avx2_conv_bwd_data_kernel_f32(pd()->jcp_); } - ~jit_avx2_convolution_bwd_data_t() { delete kernel_; } - - typedef typename prec_traits::type data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_backward_data(ctx); - return status::success; - } - -private: - void execute_backward_data(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - - jit_avx2_conv_bwd_data_kernel_f32 *kernel_; -}; - -struct jit_avx2_convolution_bwd_weights_t: public cpu_primitive_t { - struct pd_t: public cpu_convolution_bwd_weights_pd_t { - pd_t(engine_t *engine, const convolution_desc_t *adesc, - const primitive_attr_t *attr, - const convolution_fwd_pd_t *hint_fwd_pd) - : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd) - , jcp_() {} - - DECLARE_COMMON_PD_T( - JIT_IMPL_NAME_HELPER("jit:", avx2, ""), - jit_avx2_convolution_bwd_weights_t); - - status_t init() { - bool ok = true - && desc()->prop_kind == prop_kind::backward_weights - && set_default_alg_kind(alg_kind::convolution_direct) - && expect_data_types(data_type::f32, data_type::f32, - data_type::f32, data_type::f32, data_type::f32) - && !has_zero_dim_memory() - && set_default_formats(); - if (!ok) return status::unimplemented; - - status_t status = jit_avx2_conv_bwd_weights_kernel_f32::init_conf( - jcp_, *desc(), *src_md(), *diff_weights_md(), - *diff_dst_md()); - if (status != status::success) return status; - - init_balancers(); - - auto scratchpad = scratchpad_registry().registrar(); - jit_avx2_conv_bwd_weights_kernel_f32::init_scratchpad(scratchpad, - jcp_); - - auto reducer_bia_scratchpad = memory_tracking::registrar_t( - scratchpad, memory_tracking::names::prefix_reducer_bia); - reducer_bia_conf_.init_scratchpad(reducer_bia_scratchpad); - - auto reducer_wei_scratchpad = memory_tracking::registrar_t( - scratchpad, memory_tracking::names::prefix_reducer_wei); - reducer_wei_conf_.init_scratchpad(reducer_wei_scratchpad); - - return status::success; - } - - jit_conv_conf_t jcp_; - cpu_reducer_t::conf_t reducer_bia_conf_; - cpu_reducer_t::conf_t reducer_wei_conf_; - - protected: - bool set_default_formats() { - using namespace format_tag; - const bool flat = IC() == 3; - - auto src_tag = flat - ? utils::pick(ndims() - 3, ncw, nchw, ncdhw) - : utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); - auto dst_tag = - utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); - auto wei_tag = with_groups() - ? 
utils::pick(2 * ndims() - 6 + flat, gOIw8i8o, gOwi8o, - gOIhw8i8o, gOhwi8o, gOIdhw8i8o, gOdhwi8o) - : utils::pick(2 * ndims() - 6 + flat, OIw8i8o, Owi8o, - OIhw8i8o, Ohwi8o, OIdhw8i8o, Odhwi8o); - - return set_default_formats_common(src_tag, wei_tag, dst_tag); - } - - private: - void init_balancers() { - const int max_threads = mkldnn_get_max_threads(); - const size_t max_buffer_size = 1<<21; /* just a heuristic */ - - if(with_bias()) { - reducer_bia_conf_.init(reduce_balancer_t(max_threads, - jcp_.oc_block, jcp_.ngroups * jcp_.nb_oc, jcp_.mb, - max_buffer_size)); - } - - reducer_wei_conf_.init(reduce_balancer_t(max_threads, - jcp_.kd * jcp_.kh * jcp_.kw - * jcp_.ic_block * jcp_.oc_block, - jcp_.ngroups * jcp_.nb_ic * jcp_.nb_oc, - jcp_.mb * jcp_.od, max_buffer_size)); - } - }; - - jit_avx2_convolution_bwd_weights_t(const pd_t *apd) - : cpu_primitive_t(apd) - , kernel_(nullptr) - , reducer_weights_(nullptr) - , reducer_bias_(nullptr) - { - kernel_ = new jit_avx2_conv_bwd_weights_kernel_f32(pd()->jcp_); - reducer_bias_ = - new cpu_reducer_t(pd()->reducer_bia_conf_); - reducer_weights_ = - new cpu_reducer_t(pd()->reducer_wei_conf_); - } - - ~jit_avx2_convolution_bwd_weights_t() { - delete kernel_; - delete reducer_weights_; - delete reducer_bias_; - } - - typedef typename prec_traits::type data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_backward_weights(ctx); - return status::success; - } - -private: - void execute_backward_weights(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - - jit_avx2_conv_bwd_weights_kernel_f32 *kernel_; - cpu_reducer_t *reducer_weights_, *reducer_bias_; -}; - -} -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.cpp deleted file mode 100644 index 635b83b2b..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.cpp +++ /dev/null @@ -1,1255 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#include -#include - -#include "c_types_map.hpp" -#include "memory_tracking.hpp" -#include "mkldnn_thread.hpp" -#include "nstl.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -#include "cpu_memory.hpp" -#include "cpu_barrier.hpp" - -#include "jit_uni_1x1_conv_utils.hpp" -#include "jit_avx512_common_1x1_conv_kernel.hpp" - -#define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field) - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace mkldnn::impl::format_tag; -using namespace mkldnn::impl::prop_kind; -using namespace mkldnn::impl::utils; - -using namespace Xbyak; - -void jit_avx512_common_1x1_conv_kernel::bcast_loop(int load_loop_blk) -{ - mov(aux1_reg_bcast_data, reg_bcast_data); - mov(aux_reg_bcast_data, reg_bcast_data); - - mov(aux_reg_output_data, reg_output_data); - mov(bcast_loop_iter, EVEX_compress_addr(rsp, bcast_loop_work_offt)); - - if (jcp.ver == ver_4fma) - { - Label bcast_loop; - Label bcast_loop_wraparound; - Label bcast_loop_out; - Label bcast_loop_ur_full; - - cmp(bcast_loop_iter, jcp.ur); - jle(bcast_loop_wraparound, T_NEAR); - - L(bcast_loop); { - assert(jcp.bcast_block % jcp.ur == 0); - int num_substeps = jcp.bcast_block / jcp.ur; - assert(num_substeps > 0 && num_substeps < 10); - for (int i = 0; i < num_substeps; i++) { - reduce_loop(load_loop_blk, jcp.ur, i, false); - if (i < num_substeps - 1) { - add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep); - add(aux_reg_output_data, jcp.bcast_loop_output_substep); - } - else { - add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step - - (num_substeps - 1) * jcp.bcast_loop_bcast_substep); - add(aux_reg_output_data, jcp.bcast_loop_output_step - - (num_substeps - 1) * jcp.bcast_loop_output_substep); - } - } - sub(bcast_loop_iter, jcp.bcast_block); - cmp(bcast_loop_iter, jcp.bcast_block); - jg(bcast_loop, T_NEAR); - } - - L(bcast_loop_wraparound); - if (jcp.ur_tail) { - je(bcast_loop_ur_full, T_NEAR); - reduce_loop(load_loop_blk, jcp.ur_tail, 0, true); - jmp(bcast_loop_out, T_NEAR); - } - L(bcast_loop_ur_full); - reduce_loop(load_loop_blk, jcp.ur, 0, true); - L(bcast_loop_out); - } - else - { - Label bcast_loop; - Label bcast_loop_tail; - - cmp(bcast_loop_iter, jcp.ur); - jl(bcast_loop_tail, T_NEAR); - - L(bcast_loop); { - assert(jcp.bcast_block % jcp.ur == 0); - int num_substeps = jcp.bcast_block / jcp.ur; - assert(num_substeps > 0 && num_substeps < 10); - for (int i = 0; i < num_substeps; i++) { - reduce_loop(load_loop_blk, jcp.ur, i, false); - if (i < num_substeps - 1) { - add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep); - add(aux_reg_output_data, jcp.bcast_loop_output_substep); - } - else { - add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step - - (num_substeps - 1) * jcp.bcast_loop_bcast_substep); - add(aux_reg_output_data, jcp.bcast_loop_output_step - - (num_substeps - 1) * jcp.bcast_loop_output_substep); - } - } - sub(bcast_loop_iter, jcp.bcast_block); - cmp(bcast_loop_iter, jcp.bcast_block); - jge(bcast_loop, T_NEAR); - } - - L(bcast_loop_tail); - if (jcp.ur_tail) { - Label bcast_loop_tail_out; - cmp(bcast_loop_iter, 0); - jz(bcast_loop_tail_out, T_NEAR); - reduce_loop(load_loop_blk, jcp.ur_tail, 0, true); - L(bcast_loop_tail_out); - } - } -} - -void jit_avx512_common_1x1_conv_kernel::reduce_loop(int load_loop_blk, - int ur, int substep, bool wraparound) -{ - auto vreg_load = [=](int i_load, int i_fma) { - return Zmm(utils::rnd_up(ur * load_loop_blk, jcp.fma_step) - + jcp.fma_step * i_load + i_fma); - 
}; - - auto vreg_accum = [=](int i_load, int i_ur) { - return Zmm(i_ur * load_loop_blk + i_load); - }; - - auto bias_ptr = [=](int i_load) { - return EVEX_compress_addr(reg_bias_data, - jcp.typesize_out * jcp.oc_block * i_load); - }; - - auto bcast_ptr = [=](int i_reduce, int i_ur, bool bcast) { - assert(i_ur < jcp.ur); - assert(i_reduce <= jcp.reduce_loop_unroll); - int offt; - if (one_of(jcp.prop_kind, forward_training, forward_inference, - backward_data)) { - assert(jcp.reduce_loop_unroll == jcp.reduce_block); - offt = (i_reduce == jcp.reduce_loop_unroll) - ? (jcp.bcast_dim + i_ur) * jcp.reduce_loop_unroll - : i_ur * jcp.reduce_loop_unroll + i_reduce; - } else { - if (jcp.transpose_src) { - const int reduce_group = i_reduce / 4; - const int reduce_shift = i_reduce % 4; - offt = 4 * (reduce_group * jcp.ic_block + i_ur) + reduce_shift; - } - else - offt = i_reduce * jcp.ic_block + i_ur; - } - return EVEX_compress_addr(aux_reg_bcast_data, jcp.typesize_in * offt, - bcast); - }; - - auto load_ptr = [=](int i_reduce, int i_load) { - int offt; - int u0 = i_reduce % jcp.reduce_loop_unroll; - int u1 = i_reduce / jcp.reduce_loop_unroll; - offt = (i_load * jcp.reduce_dim + u0) * jcp.load_block; - return EVEX_compress_addr(aux_reg_load_data, - u1 * jcp.reduce_loop_load_step - + jcp.typesize_in * offt); - }; - - auto output_ptr = [=](int i_load, int i_ur) { - if (one_of(jcp.prop_kind, forward_training, forward_inference, - backward_data)) - return EVEX_compress_addr(aux_reg_output_data, - (i_load * jcp.bcast_dim + i_ur) * jcp.load_block - * jcp.typesize_out); - else - return ptr[aux_reg_output_data + - (i_load - ? reg_output_stride * i_load - : 0) // TODO: Xbyak should allow 0 scale - + jcp.typesize_out * jcp.load_block * i_ur]; - }; - - auto init = [=]() { - Label init_done; - Label init_zero; - - if (jcp.with_sum) { - for (int i_load = 0; i_load < load_loop_blk; ++i_load) { - for (int i_ur = 0; i_ur < ur; ++i_ur) { - mic_prefetcht1(output_ptr(i_load, i_ur)); - } - } - } - - if (jcp.with_bias - && one_of(jcp.prop_kind, forward_training, forward_inference)) { - test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); - jz(init_zero, T_NEAR); - - for (int i_load = 0; i_load < load_loop_blk; i_load++) - for (int i_ur = 0; i_ur < ur; ++i_ur) - vmovups(vreg_accum(i_load, i_ur), bias_ptr(i_load)); - jmp(init_done, T_NEAR); - } - - L(init_zero); - for (int i_load = 0; i_load < load_loop_blk; ++i_load) - for (int i_ur = 0; i_ur < ur; ++i_ur) { - auto r = vreg_accum(i_load, i_ur); - vpxord(r, r, r); - } - L(init_done); - }; - - auto store = [=]() { - Label store_noadd; - if (!jcp.with_sum) { - test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); - jnz(store_noadd, T_NEAR); - } - - for (int i_ur = 0; i_ur < ur; ++i_ur) - for (int i_load = 0; i_load < load_loop_blk; ++i_load) { - auto r = vreg_accum(i_load, i_ur); - vaddps(r, r, output_ptr(i_load, i_ur)); - } - - L(store_noadd); - if (jcp.with_eltwise) { - Label store_noeltwise; - test(reg_reduce_pos_flag, FLAG_REDUCE_LAST); - jz(store_noeltwise, T_NEAR); - - eltwise_injector_->compute_vector_range(0, ur * load_loop_blk); - - L(store_noeltwise); - } - - auto store_output = [=](bool output_is_aligned) { - for (int i_ur = 0; i_ur < ur; ++i_ur) - for (int i_load = 0; i_load < load_loop_blk; ++i_load) - if (output_is_aligned && jcp.use_vmovntps) - vmovntps(output_ptr(i_load, i_ur), - vreg_accum(i_load, i_ur)); - else - vmovups(output_ptr(i_load, i_ur), - vreg_accum(i_load, i_ur)); - }; - - Label unaligned_store, end_store; - test(aux_reg_output_data, cpu_isa_traits::vlen - 1); - 
jnz(unaligned_store, T_NEAR); - store_output(true); - jmp(end_store, T_NEAR); - L(unaligned_store); { - store_output(false); - } - L(end_store); - }; - - auto prefetch_callback = [=](int ur, int i_reduce, int i_ur, int i_load, - bool last_block, bool wraparound, int reduce_step) - { - bool pf_ker_l1 = true; - bool pf_ker_l2 = wraparound; - int n_ops = (jcp.reduce_loop_unroll / reduce_step) * ur * load_loop_blk; - int i_op = (i_reduce / reduce_step) * ur * load_loop_blk + - i_ur * load_loop_blk + i_load; - - int n_pf_ker_l1 = pf_ker_l1 ? jcp.reduce_block : 0; - int n_pf_ker_l2 = pf_ker_l2 && wraparound ? jcp.reduce_block : 0; - int n_pf_out_l1 = jcp.use_vmovntps ? 0 : ur; - - int pf_inp_ops = n_ops / 2; // # of operations during which to pf input - int pf_inp_trigger; - if (jcp.prop_kind == backward_weights) - pf_inp_trigger = nstl::max(1, pf_inp_ops / jcp.reduce_block); - else - pf_inp_trigger = nstl::max(1, pf_inp_ops / ur); - - int n_other_pf = - load_loop_blk * (n_pf_ker_l1 + n_pf_ker_l2 + n_pf_out_l1); - int n_other_pf_ops = n_ops - pf_inp_ops; - int other_pf_trigger - = n_other_pf ? nstl::max(1, n_other_pf_ops / n_other_pf) : 0; - - if (i_op < pf_inp_ops && i_op % pf_inp_trigger == 0) { - // input prefetches have the highest priority b/c the - // first iteration of the kernel block touches all the - // cache lines - int i_pf = i_op / pf_inp_trigger; - auto pf_reg = wraparound && last_block - ? reg_bcast_data - : (last_block ? aux1_reg_bcast_data - : aux_reg_bcast_data); - int offt = i_pf; - if (jcp.prop_kind == backward_weights) { - offt += wraparound && last_block - ? 0 - : (last_block ? jcp.is : jcp.reduce_block); - offt *= jcp.bcast_block; - } else { - offt += wraparound && last_block - ? 0 - : (last_block ? jcp.ur : jcp.bcast_dim); - offt *= jcp.reduce_block; - } - mic_prefetcht0(ptr[pf_reg + offt * jcp.typesize_in]); - } else if (i_op >= pf_inp_ops && n_other_pf) { - // remaining prefetches are spread among the rest of the - // operations; prefetches for output take priority - // TODO: spread L2 prefetches among L1 prefetches - i_op -= pf_inp_ops; - if (i_op % other_pf_trigger == 0) { - int i_pf = i_op / (load_loop_blk * other_pf_trigger); - if (i_pf < n_pf_ker_l2) { - int offt = (i_pf + (i_load + 1) * jcp.reduce_dim) - * jcp.load_block; - mic_prefetcht1(ptr[aux_reg_load_data - + offt * jcp.typesize_in]); - } else if (i_pf < n_pf_ker_l2 + n_pf_ker_l1) { - i_pf -= n_pf_ker_l2; - auto pf_reg = last_block ? reg_load_data - : aux_reg_load_data; - int offt = (i_pf + i_load * jcp.reduce_dim - + (last_block - ? (wraparound ? 
jcp.reduce_dim : 0) - : jcp.reduce_block)) - * jcp.load_block; - mic_prefetcht0(ptr[pf_reg + offt * jcp.typesize_in]); - } else if (i_pf < n_pf_ker_l1 + n_pf_ker_l2 + n_pf_out_l1) { - i_pf -= n_pf_ker_l1 + n_pf_ker_l2; - int offt = i_pf * jcp.load_block; - mic_prefetcht0(ptr[aux_reg_output_data - + offt * jcp.typesize_out]); - } - } - } - }; - - auto fma_block = [=](bool last_block) { - assert(jcp.reduce_loop_unroll % jcp.fma_step == 0); - - int reduce_step = jcp.fma_step; - - for (int i_reduce = 0; i_reduce < jcp.reduce_loop_unroll; - i_reduce += reduce_step) { - for (int i_load = 0; i_load < load_loop_blk; ++i_load) { - // if transposed input data used and if spatial size is - // not divided by transpose step (4) then for last reduce step - // we should load only needed load_registers data - // and clear remaining - if (jcp.transpose_src && jcp.is % jcp.fma_step && last_block - && i_reduce == jcp.reduce_loop_unroll - reduce_step) { - Label load_all; - Label load_finish; - test(reg_reduce_pos_flag, FLAG_SP_LAST); - jz(load_all, T_NEAR); - - const int n_loads = jcp.is % jcp.fma_step; - for (int i_fma = 0; i_fma < jcp.fma_step; i_fma++) { - if (i_fma < n_loads) - vmovups(vreg_load(i_load, i_fma), - load_ptr(i_reduce + i_fma, i_load)); - else - vpxord(vreg_load(i_load, i_fma), - vreg_load(i_load, i_fma), - vreg_load(i_load, i_fma)); - } - jmp(load_finish); - - L(load_all); - for (int i_fma = 0; i_fma < jcp.fma_step; i_fma++) { - vmovups(vreg_load(i_load, i_fma), - load_ptr(i_reduce + i_fma, i_load)); - } - L(load_finish); - } else { - for (int i_fma = 0; i_fma < jcp.fma_step; i_fma++) { - vmovups(vreg_load(i_load, i_fma), - load_ptr(i_reduce + i_fma, i_load)); - } - } - } - - for (int i_ur = 0; i_ur < ur; ++i_ur) { - if (jcp.ver == ver_avx512_core && jcp.expl_bcast - && load_loop_blk > 1) - vbroadcastss(vreg_bcast, bcast_ptr(i_reduce, i_ur, false)); - for (int i_load = 0; i_load < load_loop_blk; ++i_load) { - if (jcp.ver == ver_4fma) - v4fmaddps(vreg_accum(i_load, i_ur), - vreg_load(i_load, 0), - bcast_ptr(i_reduce, i_ur, false)); - else if (jcp.ver == ver_avx512_core && jcp.expl_bcast - && load_loop_blk > 1) - vfmadd231ps(vreg_accum(i_load, i_ur), - vreg_load(i_load, 0), vreg_bcast); - else - vfmadd231ps(vreg_accum(i_load, i_ur), - vreg_load(i_load, 0), - bcast_ptr(i_reduce, i_ur, true)); - prefetch_callback(ur, i_reduce, i_ur, i_load, - last_block, wraparound, reduce_step); - } - } - } - }; - Label reduce_loop; - Label reduce_loop_tail; - - mov(aux_reg_load_data, reg_load_data); - - mov(aux_reg_bcast_data, aux1_reg_bcast_data); - init(); - - mov(reduce_loop_iter, reg_reduce_loop_work); - sub(reduce_loop_iter, jcp.reduce_loop_unroll); - jle(reduce_loop_tail, T_NEAR); - - L(reduce_loop); { - fma_block(false); - add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step); - add(aux_reg_load_data, jcp.reduce_loop_load_step); - sub(reduce_loop_iter, jcp.reduce_loop_unroll); - jg(reduce_loop, T_NEAR); - } - - L(reduce_loop_tail); - fma_block(true); - - store(); -} - -void jit_avx512_common_1x1_conv_kernel::generate() -{ - preamble(); - - mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]); - mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]); - mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]); - - sub(rsp, stack_space_needed); - - if (jcp.with_bias) - mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]); - - mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]); - mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]); - mov(EVEX_compress_addr(rsp, bcast_loop_work_offt), 
reg_bcast_loop_work); - mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]); - mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]); - if (one_of(jcp.prop_kind, forward_training, forward_inference)) - mov(reg_relu_ns, reinterpret_cast<size_t>(&jcp.eltwise.alpha)); - if (jcp.prop_kind == backward_weights) - mov(reg_output_stride, ptr[param1 + GET_OFF(output_stride)]); - - auto load_loop_body = [=](int load_loop_blk) { - bcast_loop(load_loop_blk); - add(reg_load_data, load_loop_blk * jcp.load_loop_load_step); - switch (jcp.prop_kind) { - case forward_training: - case forward_inference: - add(reg_bias_data, - load_loop_blk * jcp.load_block * jcp.typesize_out); - add(reg_output_data, - load_loop_blk * jcp.bcast_dim * jcp.load_block * - jcp.typesize_out); - break; - case backward_data: - add(reg_output_data, - load_loop_blk * jcp.bcast_dim * jcp.load_block * - jcp.typesize_out); - break; - case backward_weights: - for (int i_load = 0; i_load < load_loop_blk; i_load++) - add(reg_output_data, reg_output_stride); - break; - default: - assert(!"invalid prop_kind"); - } - sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step); - }; - - const int simd_w = 16; - - Label load_loop_blk[7]; - - static const int ur_cases_fma_embd_bcast[] = { 2, 4, 5, 8, 14, 32 }; - static const int ur_cases_fma_expl_bcast[] = { 2, 5, 6, 9, 14, 32 }; - static const int ur_cases_4fma[] = { 2, 4, 6, 12, 32 }; - - const int size_ur_cases_fma - = (jcp.ver == ver_avx512_core && jcp.expl_bcast) ? - sizeof(ur_cases_fma_expl_bcast) : - sizeof(ur_cases_fma_embd_bcast); - const int size_ur_cases_4fma = sizeof(ur_cases_4fma); - - const int *ur_cases_fma = (jcp.ver == ver_avx512_core && jcp.expl_bcast) ? - ur_cases_fma_expl_bcast : - ur_cases_fma_embd_bcast; - const int *ur_cases = jcp.ver == ver_4fma ? ur_cases_4fma : ur_cases_fma; - const int num_ur_cases = - (jcp.ver == ver_4fma ? 
size_ur_cases_4fma : size_ur_cases_fma) - / sizeof(*ur_cases); - - for (int ur_idx = num_ur_cases - 1; ur_idx > 0; ur_idx--) { - int label_idx = num_ur_cases - ur_idx - 1; - if (jcp.ur <= ur_cases[ur_idx]) { - cmp(reg_load_loop_work, simd_w * (label_idx + 1)); - jle(load_loop_blk[label_idx], T_NEAR); - } - } - - for (int ur_idx = 0; ur_idx < num_ur_cases; ur_idx++) { - if (jcp.ur <= ur_cases[ur_idx]) { - int label_idx = num_ur_cases - ur_idx - 1; - L(load_loop_blk[label_idx]); - { - if (label_idx == 0) { - cmp(reg_load_loop_work, 0); - je(load_loop_blk[num_ur_cases], T_NEAR); - } - load_loop_body(label_idx + 1); - if (label_idx - 1 > 0) { - cmp(reg_load_loop_work, 2 * label_idx * simd_w); - je(load_loop_blk[label_idx - 1], T_NEAR); - } - cmp(reg_load_loop_work, (label_idx + 1) * simd_w); - jge(load_loop_blk[label_idx]); - } - for (int idx = label_idx - 1; idx > 0; --idx) { - cmp(reg_load_loop_work, simd_w * (idx + 1)); - je(load_loop_blk[idx], T_NEAR); - } - if (ur_idx < num_ur_cases - 2) { - cmp(reg_load_loop_work, simd_w); - jle(load_loop_blk[0], T_NEAR); - } - } - } - L(load_loop_blk[num_ur_cases]); - - add(rsp, stack_space_needed); - - postamble(); - - if (jcp.with_eltwise) - eltwise_injector_->prepare_table(); -} - -bool jit_avx512_common_1x1_conv_kernel::post_ops_ok( - jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) { - const auto &p = attr.post_ops_; - - auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; - auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; - - switch (p.len_) { - case 0: return true; // no post_ops - case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise - case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise - default: return false; - } - - return false; -} - -status_t jit_avx512_common_1x1_conv_kernel::init_conf(jit_1x1_conv_conf_t &jcp, - const convolution_desc_t &cd, const memory_desc_wrapper &src_d, - const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d, - const primitive_attr_t &attr, int nthreads, bool reduce_src) { - if (!mayiuse(avx512_common)) return status::unimplemented; - - const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; - const int simd_w = cpu_isa_traits<avx512_common>::vlen / sizeof(float); - const int ndims = src_d.ndims(); - - jcp.prop_kind = cd.prop_kind; - - jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; - jcp.mb = src_d.dims()[0]; - - jcp.oc_without_padding = dst_d.dims()[1] / jcp.ngroups; - jcp.oc = dst_d.dims()[1] / jcp.ngroups; - jcp.ic = src_d.dims()[1] / jcp.ngroups; - - bool ok_to_pad_channels = true - && jcp.ngroups == 1 - && src_d.data_type() == data_type::f32; - if (ok_to_pad_channels) { - jcp.oc = rnd_up(jcp.oc, simd_w); - jcp.ic = rnd_up(jcp.ic, simd_w); - } - - jcp.ih = (ndims == 3) ? 1 : src_d.dims()[2]; - jcp.iw = src_d.dims()[ndims - 1]; - jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[2]; - jcp.ow = dst_d.dims()[ndims - 1]; - - jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + 2]; - jcp.kw = weights_d.dims()[with_groups + ndims - 1]; - - jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][0]; - jcp.l_pad = cd.padding[0][ndims - 3]; - - jcp.stride_h = (ndims == 3) ? 
1 : cd.strides[0]; - jcp.stride_w = cd.strides[ndims - 3]; - - jcp.with_bias = pick_by_prop_kind(jcp.prop_kind, cd.bias_desc.format_kind, - format_kind::undef, cd.diff_bias_desc.format_kind) - != format_kind::undef; - - jcp.os = jcp.oh * jcp.ow; - jcp.is = jcp.ih * jcp.iw; - jcp.tr_is = rnd_up(jcp.is, 4); - - if (!post_ops_ok(jcp, attr)) - return status::unimplemented; - - const auto &p = attr.post_ops_; - jcp.with_sum = p.find(primitive_kind::sum) != -1; - const int eltwise_ind = p.find(primitive_kind::eltwise); - jcp.with_eltwise = eltwise_ind != -1; - if (jcp.with_eltwise) { - jcp.eltwise = p.entry_[eltwise_ind].eltwise; - if (dst_d.data_type() == data_type::s32) return status::unimplemented; - } - - auto dat_tag = pick(ndims - 3, nCw16c, nChw16c); - jcp.src_tag = src_d.matches_one_of_tag(dat_tag); - jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); - - bool args_ok = true - && jcp.ngroups == 1 - && jcp.src_tag == dat_tag - && jcp.dst_tag == dat_tag; - if (!args_ok) return status::unimplemented; - - args_ok = true - && jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0 - && jcp.t_pad == 0 && jcp.l_pad == 0 - && jcp.stride_w == 1 && jcp.stride_h == 1 // TODO: support some strides - && jcp.kh == 1 && jcp.kw == 1; - if (!args_ok) return status::unimplemented; - - jcp.ic_block = jcp.oc_block = simd_w; - jcp.transpose_src = false; - - if (everyone_is(data_type::f32, src_d.data_type(), - weights_d.data_type(), dst_d.data_type())) - { - const int is_bwd_d = jcp.prop_kind == backward_data; - format_tag_t wei_tag = with_groups - ? pick(2 * ndims - 6 + is_bwd_d, gOIw16i16o, gIOw16o16i, - gOIhw16i16o, gIOhw16o16i) - : pick(2 * ndims - 6 + is_bwd_d, OIw16i16o, IOw16o16i, - OIhw16i16o, IOhw16o16i); - - jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); - if (jcp.wei_tag != wei_tag) - return status::unimplemented; - - if (jcp.prop_kind != backward_weights && mayiuse(avx512_mic_4ops) && - ((jcp.prop_kind == backward_data) ? jcp.oc_block : jcp.ic_block) % 4 - == 0) { - jcp.ver = ver_4fma; - jcp.fma_step = 4; - } else if (jcp.prop_kind == backward_weights && mayiuse(avx512_mic_4ops) - && !reduce_src - /* Heuristic condition for relation of src size to oc. Otherwise - the src transposition overhead exceed the benefit from 4fma - */ - && ((jcp.is * jcp.ic) / jcp.oc <= 2048) - && mkldnn_thr_syncable() - ) - { - jcp.transpose_src = true; - jcp.ver = ver_4fma; - jcp.fma_step = 4; - } else { - jcp.ver = (mayiuse(avx512_core)) ? 
ver_avx512_core : ver_fma; - jcp.fma_step = 1; - } - jcp.typesize_in = sizeof(prec_traits<data_type::f32>::type); - jcp.typesize_out = sizeof(prec_traits<data_type::f32>::type); - } else { - return status::unimplemented; - } - - /* once all the formats are set, check the padding consistency */ - args_ok = true - && jcp.ic <= src_d.padded_dims()[1] - && jcp.oc <= dst_d.padded_dims()[1] - && jcp.ic <= weights_d.padded_dims()[with_groups + 1] - && jcp.oc <= weights_d.padded_dims()[with_groups + 0]; - if (!args_ok) return status::unimplemented; - - const int SMALL_SPATIAL = 10; - const int BIG_SPATIAL = 28; - const int BIG_REDUCE_DIM = 1024; - const int BIG_LOAD_DIM = 256; - - int load_blocking{ 0 }; - int load_blocking_max{ 0 }; - int bcast_blocking{ 0 }; - int bcast_blocking_max{ 0 }; - int reduce_blocking{ 0 }; - int reduce_blocking_max{ 0 }; - - jcp.load_grp_count = 1; - - const int L1_capacity = get_cache_size(1, true) / sizeof(float); - const int L2_size = get_cache_size(2, true) / sizeof(float); - const int L2_capacity = (L2_size * 3) / 4; - - if (one_of(jcp.prop_kind, forward_training, forward_inference, - backward_data)) { - if (one_of(jcp.prop_kind, forward_training, forward_inference)) { - jcp.reduce_dim = jcp.ic; - jcp.reduce_block = jcp.ic_block; - - jcp.load_dim = jcp.oc; - jcp.load_block = jcp.oc_block; - - jcp.bcast_dim = jcp.is; - } else { - jcp.reduce_dim = jcp.oc; - jcp.reduce_block = jcp.oc_block; - - jcp.load_dim = jcp.ic; - jcp.load_block = jcp.ic_block; - - jcp.bcast_dim = jcp.os; - } - jcp.reduce_loop_unroll = jcp.reduce_block; - jcp.reduce_loop_bcast_step - = jcp.reduce_loop_unroll * jcp.bcast_dim * jcp.typesize_in; - - jcp.reduce_loop_load_step - = jcp.reduce_loop_unroll * jcp.load_block * jcp.typesize_in; - jcp.load_loop_load_step - = jcp.reduce_dim * jcp.load_block * jcp.typesize_in; - - // adjusting registry blocking - int max_regs, min_regs, size_treshold, ur_step; - const int spatial - = (one_of(jcp.prop_kind, forward_training, forward_inference)) ? - jcp.oh : - jcp.ih; - if (jcp.ver == ver_avx512_core && (8 * jcp.mb) / nthreads >= 1) { - max_regs = 9; - min_regs = 6; - size_treshold = 14; - ur_step = 1; - jcp.expl_bcast = true; - - if (jcp.load_dim > 128 && jcp.load_dim < BIG_LOAD_DIM - && spatial > SMALL_SPATIAL && spatial < BIG_SPATIAL) { - max_regs = 6; - min_regs = 5; - } - } else { - max_regs = jcp.ver == ver_4fma ? 28 : 30; - min_regs = 9; - size_treshold = jcp.ver == ver_4fma ? 28 : 14; - ur_step = jcp.ver == ver_4fma ? 
4 : 1; - jcp.expl_bcast = false; - jcp.use_vmovntps = true; - } - jcp.ur = 1; - for (int ur_w = max_regs; ur_w >= min_regs; ur_w -= ur_step) { - if ((spatial >= size_treshold && spatial % ur_w == 0) - || (spatial < size_treshold && jcp.os % ur_w == 0)) { - jcp.ur = ur_w; - break; - } - } - if (jcp.ur == 1) { - jcp.ur = nstl::min(max_regs, jcp.os); - int os_tail = jcp.os % max_regs; - for (int i = max_regs; i >= min_regs; i -= ur_step) { - int i_tail = jcp.os % i; - if (i_tail > os_tail || i_tail == 0) { - jcp.ur = i; - os_tail = i_tail; - if (i_tail == 0) - break; - } - } - } - - jcp.reduce_loop_unroll = jcp.reduce_block; - jcp.reduce_loop_bcast_step - = jcp.reduce_loop_unroll * jcp.bcast_dim * jcp.typesize_in; - - jcp.bcast_block = jcp.ur; - - jcp.bcast_loop_output_step = jcp.ur * jcp.load_block * jcp.typesize_out; - jcp.bcast_loop_output_substep = -1; // unused - jcp.bcast_loop_bcast_step = jcp.ur * jcp.reduce_block * jcp.typesize_in; - jcp.bcast_loop_bcast_substep = -1; // unused - - jcp.load_loop_iter_step = jcp.load_block; - - if (jcp.prop_kind == backward_data) - jcp.loop_order = loop_lbr; - else - jcp.loop_order = reduce_src ? loop_blr : loop_lbr; - - int nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block); - int nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block); - int nb_load = div_up(jcp.load_dim, jcp.load_block); - - if (jcp.ver == ver_avx512_core && jcp.expl_bcast) { - if (jcp.load_dim <= BIG_LOAD_DIM && spatial > SMALL_SPATIAL - && spatial < BIG_SPATIAL) - reduce_blocking = nstl::min(jcp.reduce_dim, 80); - else if (spatial > SMALL_SPATIAL) - reduce_blocking = nstl::min(jcp.reduce_dim, 512); - else - reduce_blocking = nstl::min(jcp.reduce_dim, 256); - - if ((jcp.mb > 28 && spatial >= 28) - || (jcp.mb > 112 && spatial >= 17)) - jcp.use_vmovntps = true; - else - jcp.use_vmovntps = false; - } else { - - reduce_blocking = nb_reduce; - if (spatial <= SMALL_SPATIAL && jcp.reduce_dim >= BIG_REDUCE_DIM) - reduce_blocking = 16; - else if (spatial > SMALL_SPATIAL - && jcp.reduce_dim >= BIG_REDUCE_DIM) - reduce_blocking = 8; - reduce_blocking = best_divider(nb_reduce, 1, reduce_blocking, true); - reduce_blocking *= jcp.reduce_block; - } - - // Check input data cache aliasing. - // For other ISA constants may be updated. - // 64 * 1024 is chosen due to 1MB L2 16-way cache. - // 7 is empirical value. It is about half of 16. - // So we leave about half of the set for other data - weights, dst - int way_size = (64 * 1024) / jcp.typesize_in; - int max_hits = 7; - if (jcp.bcast_dim * reduce_blocking > way_size * max_hits) { - int nrb = reduce_blocking / simd_w; - int sp = jcp.bcast_dim; - int wl = way_size / simd_w; - for (int start_off = 0; start_off < jcp.ur; start_off++) { - for (int off = start_off, hits = 0; off < sp * nrb; off += wl) { - if (off % sp >= jcp.ur || ++hits < max_hits) - continue; - int max_r_blocking = simd_w * nstl::max(1, (off + wl) / sp); - reduce_blocking - = nstl::min(reduce_blocking, max_r_blocking); - break; - } - } - } - - if (reduce_blocking < jcp.reduce_dim) { - jcp.use_vmovntps = false; - if (jcp.prop_kind == backward_data) - jcp.loop_order = reduce_src ? loop_lbr : loop_rlb; - else - jcp.loop_order = reduce_src ? 
loop_rbl : loop_rlb; - } - load_blocking = jcp.load_dim; - - int load_size = jcp.load_dim * jcp.reduce_dim; - int bcast_size = jcp.mb * jcp.ngroups * jcp.bcast_dim * jcp.reduce_dim; - - if (jcp.ver == ver_avx512_core && nthreads <= 28 && jcp.mb < nthreads - && nb_load * nb_bcast > nthreads) { - // Some heuristic here - float calc_koef = 0.01, best_cost = FLT_MAX; - int n_lgc = nthreads; - float ratio = (float)load_size / (float)bcast_size; - int best_lgc = ratio > 1 ? n_lgc : 1; - auto calc_job_cost = [&](int lb, int tg, float mem_k) { - int bb_size = jcp.mb * div_up(nb_bcast, tg); - float calc_size = (float)(bb_size * jcp.ur) - * (lb * jcp.load_block) * jcp.reduce_dim; - float mem_size = (float)(bb_size * jcp.ur + lb * jcp.load_block) - * jcp.reduce_dim; - return calc_koef * calc_size + mem_k * mem_size; - }; - for (int lgc, ilgc = 0; ilgc < n_lgc; ilgc++) { - lgc = ratio > 1 ? n_lgc - ilgc : ilgc + 1; - int min_lb = nb_load / lgc; - int max_lb = div_up(nb_load, lgc); - int min_tg = nthreads / lgc; - int max_tg = div_up(nthreads, lgc); - // Some heuristic here - float mem_koef = (max_tg == 1) ? 1.f : 1.3f; - float job_cost = 0.; - if (nthreads % lgc < nb_load % lgc) { - job_cost = calc_job_cost(max_lb, min_tg, mem_koef); - } else { - auto job_cost1 = calc_job_cost(max_lb, max_tg, mem_koef); - auto job_cost2 = calc_job_cost(min_lb, min_tg, mem_koef); - job_cost = nstl::max(job_cost1, job_cost2); - } - - if (job_cost < best_cost) { - best_lgc = lgc; - best_cost = job_cost; - } - } - jcp.load_grp_count = best_lgc; - load_blocking = div_up(nb_load, jcp.load_grp_count) * jcp.load_block; - } else { - jcp.load_grp_count = div_up(nthreads, jcp.mb * jcp.ngroups * nb_bcast); - jcp.load_grp_count = best_divider( - nthreads, jcp.load_grp_count, 2 * jcp.load_grp_count, false); - } - - if (jcp.ver == ver_avx512_core && jcp.expl_bcast && jcp.bcast_dim <= 64 - && load_size >= L2_size) { - jcp.load_grp_count = nstl::max(jcp.load_grp_count, 4); - } else if (jcp.bcast_dim <= 49 && jcp.mb <= nthreads - && jcp.load_dim > 512 && jcp.load_dim / jcp.reduce_dim >= 4) { - jcp.load_grp_count = nstl::max(jcp.load_grp_count, 2); - load_blocking = jcp.load_block; - } - - if (jcp.ver == ver_4fma && jcp.bcast_dim * jcp.mb < jcp.load_dim - && jcp.oh * jcp.ow > 64 - && IMPLICATION(reduce_src, jcp.load_dim < 1024)) { - /* Looking for best loading dimension blocking - * to get the best thread and data read/write efficiency - * by finding the optimal 'load_chunk' value - * Example: - * for 72 threads and convolution with mb=1, ih=iw=7, oc = 512 - * the 'best' load_chunk value should be 1 - * TODO: remove heuristic constants in above condition - * TODO: check this blocking for other ISA - */ - float best_eff = -1.f; - int best_lgc = 1; - - for (int load_chunk = 1; load_chunk <= nb_load; load_chunk++) { - int lgc = div_up(nb_load, load_chunk); - if (lgc > nthreads) - continue; - int thr_per_grp = div_up(nthreads, lgc); - int bcast_per_thr = div_up(jcp.mb * nb_bcast, thr_per_grp) - * jcp.bcast_block; - int load_per_thr = load_chunk * simd_w; - float data_norm = (bcast_per_thr + load_per_thr) / 2.f; - float data_eff = (bcast_per_thr * load_per_thr) - / (data_norm * data_norm); - float thr_eff_over_grp = (float)nstl::max(1, nthreads / lgc) - / div_up(nthreads, lgc); - float thr_eff_in_grp = ((float)jcp.mb * nb_bcast) - / rnd_up(jcp.mb * nb_bcast, thr_per_grp); - float thr_eff = thr_eff_over_grp * thr_eff_in_grp; - float load_eff = (float)nb_load / rnd_up(nb_load, lgc); - float overall_eff = data_eff + thr_eff + load_eff; - if 
(overall_eff > best_eff) { - best_eff = overall_eff; - best_lgc = lgc; - } - } - jcp.load_grp_count = best_lgc; - load_blocking - = div_up(nb_load, jcp.load_grp_count) * jcp.load_block; - } - bcast_blocking = div_up(jcp.mb * jcp.ngroups * nb_bcast, - div_up(nthreads, jcp.load_grp_count)) - * jcp.bcast_block; - bcast_blocking = nstl::min(jcp.bcast_dim, bcast_blocking); - bcast_blocking = rnd_up(bcast_blocking, jcp.bcast_block); - - int space_for_bcast - = (L2_capacity - /* kernel_size - */ - 2 * jcp.load_block * reduce_blocking - - jcp.ur * reduce_blocking - 3 * 1024); - if (jcp.reduce_dim * jcp.bcast_dim > L2_capacity) - space_for_bcast /= 2; - - int bcast_in_cache - = nstl::max(jcp.bcast_block, space_for_bcast / reduce_blocking); - bcast_blocking = nstl::min( - bcast_blocking, rnd_dn(bcast_in_cache, jcp.bcast_block)); - - load_blocking_max = load_blocking; - bcast_blocking_max = bcast_blocking * 3 / 2; - reduce_blocking_max = reduce_blocking; - - } else if (jcp.prop_kind == backward_weights) { - - jcp.use_vmovntps = false; - if (jcp.is > SMALL_SPATIAL * SMALL_SPATIAL && jcp.ver == ver_4fma) - jcp.use_vmovntps = true; - - if (jcp.transpose_src) - jcp.reduce_dim = jcp.tr_is; - else - jcp.reduce_dim = jcp.is; - - if (jcp.ver == ver_4fma) { - // reduce_block should be divided by fma_step - jcp.reduce_block = best_divider(jcp.reduce_dim, 4, 16, true, 4); - } else { - jcp.reduce_block = best_divider(jcp.reduce_dim, 7, 16, true); - if (jcp.reduce_dim % jcp.reduce_block != 0) - jcp.reduce_block = best_divider(jcp.iw, 4, jcp.iw, false); - if (jcp.reduce_block > 256) { - jcp.reduce_block = 1; - } - - } - - jcp.load_dim = jcp.oc; - jcp.load_block = jcp.oc_block; - - jcp.bcast_dim = jcp.ic; - jcp.bcast_block = jcp.ic_block; - - if (jcp.ver == ver_avx512_core && jcp.reduce_block <= 19) { - // if reduce_block is big then generated JIT code may be big - // for small values of ur because reduce_loop_unroll = reduce_block - jcp.ur = jcp.bcast_block / 2; - jcp.expl_bcast = true; - } else { - jcp.ur = jcp.bcast_block; - jcp.expl_bcast = false; - } - - jcp.reduce_loop_unroll = jcp.reduce_block; - jcp.reduce_loop_bcast_step - = jcp.reduce_loop_unroll * jcp.ic_block * jcp.typesize_in; - jcp.reduce_loop_load_step - = jcp.reduce_loop_unroll * jcp.oc_block * jcp.typesize_in; - - jcp.bcast_loop_output_step = - jcp.oc_block * jcp.ic_block * jcp.typesize_out; - jcp.bcast_loop_output_substep = - jcp.oc_block * jcp.ur * jcp.typesize_out; - jcp.bcast_loop_bcast_step = - jcp.ic_block * jcp.reduce_dim * jcp.typesize_in; - jcp.bcast_loop_bcast_substep = jcp.ur * jcp.typesize_in; - - jcp.load_loop_load_step = jcp.oc_block * jcp.os * jcp.typesize_in; - jcp.load_loop_iter_step = jcp.oc_block; - - /* --- */ - balance(jcp, nthreads); - - load_blocking = div_up(jcp.load_dim, jcp.load_block); - load_blocking = best_divider(load_blocking, 16, load_blocking, false); - load_blocking *= jcp.load_block; - - load_blocking_max = load_blocking; - assert(jcp.load_dim % load_blocking == 0); - - int max_bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block); - int min_bcast_blocking = 5; - - bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block); - bcast_blocking = best_divider( - bcast_blocking, min_bcast_blocking, max_bcast_blocking, false); - bcast_blocking *= jcp.bcast_block; - bcast_blocking_max = bcast_blocking; - assert(jcp.bcast_dim % bcast_blocking == 0); - - // for reduction balance - if (jcp.ver == ver_avx512_core) { - int max_reduce_blocking - = nstl::min(L1_capacity / jcp.ur, jcp.reduce_dim); - int min_reduce_blocking = 
nstl::min( - L1_capacity / jcp.ur, nstl::max(jcp.iw, jcp.ih)); - reduce_blocking = best_divider(jcp.reduce_dim, min_reduce_blocking, - max_reduce_blocking, true); - reduce_blocking - = nstl::max(rnd_dn(reduce_blocking, jcp.reduce_block), - jcp.reduce_block); - } else { - int max_reduce_blocking = L2_capacity - / ((bcast_blocking + load_blocking) * jcp.reduce_block); - max_reduce_blocking = nstl::min(max_reduce_blocking, - (L1_capacity / (jcp.bcast_block)) / jcp.reduce_block); - - int num_jobs = div_up(jcp.load_dim, load_blocking) - * div_up(jcp.bcast_dim, bcast_blocking); - int threads_per_job = nstl::max(1, nthreads / num_jobs); - reduce_blocking = div_up(jcp.mb * jcp.reduce_dim, jcp.reduce_block); - reduce_blocking = div_up(reduce_blocking, threads_per_job); - - reduce_blocking = best_divider(reduce_blocking, - max_reduce_blocking - 2, max_reduce_blocking, true); - reduce_blocking *= jcp.reduce_block; - } - - reduce_blocking_max = rnd_dn(reduce_blocking * 3 / 2, jcp.reduce_block); - } else - return status::unimplemented; - - assert(load_blocking); - assert(load_blocking_max); - assert(bcast_blocking); - assert(bcast_blocking_max); - assert(reduce_blocking); - assert(reduce_blocking_max); - assert(load_blocking % jcp.load_block == 0); - assert(reduce_blocking % jcp.reduce_block == 0); - assert(load_blocking_max % jcp.load_block == 0); - assert(reduce_blocking_max % jcp.reduce_block == 0); - if (jcp.ver == ver_4fma) { - assert(jcp.reduce_loop_unroll % jcp.fma_step == 0); - assert(jcp.reduce_dim % jcp.reduce_loop_unroll == 0); - } - - assert(jcp.bcast_block % jcp.ur == 0); - assert(jcp.reduce_dim % jcp.reduce_block == 0); - - jcp.ur_tail = jcp.bcast_dim % jcp.ur; - - jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block; - jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block; - jcp.nb_load_blocking = load_blocking / jcp.load_block; - jcp.nb_load_blocking_max = load_blocking_max / jcp.load_block; - jcp.nb_reduce_blocking = reduce_blocking / jcp.reduce_block; - jcp.nb_reduce_blocking_max = reduce_blocking_max / jcp.reduce_block; - - jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block); - jcp.nb_load = div_up(jcp.load_dim, jcp.load_block); - jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block); - - return status::success; -} - -void jit_avx512_common_1x1_conv_kernel::init_scratchpad( - memory_tracking::registrar_t &scratchpad, - const jit_1x1_conv_conf_t &jcp) { - using namespace mkldnn::impl::memory_tracking::names; - - if (jcp.prop_kind != backward_data && jcp.with_bias - && jcp.oc != jcp.oc_without_padding) - scratchpad.book(key_conv_padded_bias, jcp.typesize_out * jcp.oc); - - if (jcp.prop_kind == backward_weights) { - const size_t wei_size = (size_t)jcp.ngroups * jcp.oc * jcp.ic; - scratchpad.book(key_conv_wei_reduction, - jcp.typesize_out * wei_size * (jcp.nthr_mb - 1)); - } - - if (jcp.transpose_src) { - const size_t tr_src_size = - (size_t)jcp.nthr_mb * jcp.ngroups * jcp.ic * jcp.tr_is; - scratchpad.book(key_conv_tr_src, jcp.typesize_out * tr_src_size); - scratchpad.book(key_conv_tr_src_bctx, - sizeof(simple_barrier::ctx_t) * jcp.nthr); - } -} - -void jit_avx512_common_1x1_conv_kernel::balance(jit_1x1_conv_conf_t &jcp, - int nthreads) -{ - // initialize jcp reduction threading properties - jcp.nthr = jcp.nthr_mb = jcp.nthr_g = jcp.nthr_oc_b = jcp.nthr_ic_b = 1; - if (nthreads < jcp.ngroups) { - /* simplification... 
fortunately it doesn't hurt much */ - return; - } - const int nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block); - const int nb_load = div_up(jcp.load_dim, jcp.load_block); - const int nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block); - - jcp.nthr_g = jcp.ngroups; - const int nthr = nthreads / jcp.nthr_g; - - auto calc_mem_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) { - /* calculate per thread memory cost (read/write). high level - * optimizer tries to minimize memory consumption. few notes: (n1) - * unclear why, but that essentially helps first convolution... - * (n2) assuming the reduction over minibatch is always there: - * - instead of 8 it should be 5 here (write ~= 2 read): - * kernel: temporal workspace 1 write - * reduction: 1 read from workspace and 1 write to the diff_wei - * - but experiments showed 8 works better than 5 or 6... */ - int bcast_koeff = 1; - int load_koeff = 1; - int output_koeff = 12; - if (jcp.transpose_src) { - bcast_koeff = 5; - load_koeff = 1; - output_koeff = 8; - } - return 0 - + (size_t)bcast_koeff * div_up(jcp.mb * nb_reduce, nthr_mb) - * div_up(jcp.ngroups, jcp.nthr_g) - * div_up(nb_bcast, nthr_ic_b) * jcp.ic_block * jcp.reduce_block - / jcp.stride_h / jcp.stride_w /* (n1) */ - + (size_t)load_koeff * div_up(jcp.mb * nb_reduce, nthr_mb) - * div_up(jcp.ngroups, jcp.nthr_g) - * div_up(nb_load, nthr_oc_b) * jcp.oc_block * jcp.reduce_block - + (size_t)output_koeff /* (n2) */ - * div_up(jcp.ngroups, jcp.nthr_g) * div_up(nb_load, nthr_oc_b) - * div_up(nb_bcast, nthr_ic_b) * jcp.ic_block - * jcp.oc_block; - }; - - int nthr_mb = 1, nthr_oc_b = 1, nthr_ic_b = 1; - auto best_mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b); - - /* step 1: find the best thread distribution with lowest memory cost */ - const int nthr_mb_max = nstl::min(nthr, jcp.mb * nb_reduce); - for (nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) { - const int nthr_par = nthr / nthr_mb; - const int nthr_oc_b_max = nstl::min(nthr_par, nb_load); - for (nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) { - nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, nb_bcast); - auto mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b); - if (mem_cost <= best_mem_cost) { - best_mem_cost = mem_cost; - jcp.nthr_mb = nthr_mb; - jcp.nthr_oc_b = nthr_oc_b; - jcp.nthr_ic_b = nthr_ic_b; - } - } - - if (!mkldnn_thr_syncable()) { assert(nthr_mb == 1); break; } - } - if (jcp.nthr_mb > nthreads / 2 && jcp.nthr_mb < nthreads) - jcp.nthr_mb = nstl::min(jcp.mb, nthreads); - - jcp.nthr = jcp.nthr_mb * jcp.nthr_g * jcp.nthr_oc_b * jcp.nthr_ic_b; - assert(jcp.nthr <= nthreads); -} - -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.hpp deleted file mode 100644 index d2ae01794..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.hpp +++ /dev/null @@ -1,108 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef JIT_AVX512_COMMON_1x1_CONV_KERNEL_HPP -#define JIT_AVX512_COMMON_1x1_CONV_KERNEL_HPP - -#include "c_types_map.hpp" -#include "memory_tracking.hpp" - -#include "jit_generator.hpp" -#include "jit_primitive_conf.hpp" -#include "jit_uni_eltwise.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -struct jit_avx512_common_1x1_conv_kernel : public jit_generator { - jit_avx512_common_1x1_conv_kernel(jit_1x1_conv_conf_t ajcp, - const primitive_attr_t &attr) - : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr) - { - if (jcp.with_eltwise) - eltwise_injector_ = new jit_uni_eltwise_injector_f32<avx512_common>( - this, jcp.eltwise); - - this->generate(); - jit_ker = (void (*)(jit_1x1_conv_call_s *)) this->getCode(); - } - - ~jit_avx512_common_1x1_conv_kernel() { - delete eltwise_injector_; - } - - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_1x1_conv_kernel) - - static bool post_ops_ok(jit_1x1_conv_conf_t &jcp, - const primitive_attr_t &attr); - - static status_t init_conf(jit_1x1_conv_conf_t &jcp, - const convolution_desc_t &cd, - const memory_desc_wrapper &src_d, - const memory_desc_wrapper &weights_d, - const memory_desc_wrapper &dst_d, - const primitive_attr_t &attr, - int nthreads, bool reduce_src); - - static void init_scratchpad(memory_tracking::registrar_t &scratchpad, - const jit_1x1_conv_conf_t &jcp); - - jit_1x1_conv_conf_t jcp; - const primitive_attr_t &attr_; - void (*jit_ker)(jit_1x1_conv_call_s *); - - private: - using reg64_t = const Xbyak::Reg64; - using zmm_t = const Xbyak::Zmm; - - reg64_t reg_bcast_data = r8; - reg64_t reg_load_data = r10; - reg64_t reg_output_data = r9; - reg64_t aux_reg_bcast_data = r14; - reg64_t aux1_reg_bcast_data = rbx; - reg64_t aux_reg_load_data = r15; - reg64_t imm_addr64 = aux_reg_load_data; - reg64_t aux_reg_output_data = abi_not_param1; - reg64_t reg_load_loop_work = rsi; - reg64_t reg_reduce_loop_work = r11; - reg64_t bcast_loop_iter = rdx; - reg64_t reduce_loop_iter = abi_param1; - reg64_t reg_reduce_pos_flag = rax; - reg64_t reg_output_stride = r13; - reg64_t reg_bias_data = r12; - reg64_t reg_relu_ns = r13; - reg64_t reg_bcast_loop_work = aux1_reg_bcast_data; - - Xbyak::Zmm vreg_bcast = Xbyak::Zmm(31); - - jit_uni_eltwise_injector_f32<avx512_common> *eltwise_injector_; - - int bcast_loop_work_offt = 0; - int stack_space_needed = 16; - - void bcast_loop(int load_loop_blk); - void reduce_loop(int load_loop_blk, int ur, int substep, bool wraparound); - - void generate(); - static void balance(jit_1x1_conv_conf_t &jcp, int nthreads); -}; - -} -} -} - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.cpp deleted file mode 100644 index 54d58c8a3..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.cpp +++ /dev/null @@ -1,816 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. 
-* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include "c_types_map.hpp" -#include "mkldnn_thread.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -#include "jit_generator.hpp" - -#include "jit_avx512_common_1x1_convolution.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace mkldnn::impl::status; -using namespace mkldnn::impl::memory_tracking::names; -using namespace mkldnn::impl::utils; - -#define data_blk_off(f, n, c, h, w) \ - ((ndims == 3) \ - ? (f).blk_off(n, c, w) \ - : (f).blk_off(n, c, h, w)) - - -namespace { -template -void balance2D(U nthr, U ithr, T ny, T &ny_start, T &ny_end, - T nx, T &nx_start, T &nx_end, T nx_divider) -{ - const int grp_count = nstl::min(nx_divider, nthr); - const int grp_size_big = nthr / grp_count + 1; - const int grp_size_small = nthr / grp_count; - const int n_grp_big = nthr % grp_count; - const int threads_in_big_groups = n_grp_big * grp_size_big; - - const int ithr_bound_distance = ithr - threads_in_big_groups; - T grp, grp_ithr, grp_nthr; - if (ithr_bound_distance < 0) { // ithr in first groups - grp = ithr / grp_size_big; - grp_ithr = ithr % grp_size_big; - grp_nthr = grp_size_big; - } else { // ithr in last groups - grp = n_grp_big + ithr_bound_distance / grp_size_small; - grp_ithr = ithr_bound_distance % grp_size_small; - grp_nthr = grp_size_small; - } - - balance211(nx, grp_count, grp, nx_start, nx_end); - balance211(ny, grp_nthr, grp_ithr, ny_start, ny_end); -} -} -/* convolution forward */ - -template -void jit_avx512_common_1x1_convolution_fwd_t:: -execute_forward(const exec_ctx_t &ctx) const { - auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); - auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); - auto bias = CTX_IN_MEM(const dst_data_t *, MKLDNN_ARG_BIAS); - auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); - - auto scratchpad = this->scratchpad(ctx); - - const auto &jcp = kernel_->jcp; - if (pd()->wants_padded_bias()) { - auto padded_bias = scratchpad.template get( - key_conv_padded_bias); - utils::array_copy(padded_bias, bias, jcp.oc_without_padding); - utils::array_set(padded_bias + jcp.oc_without_padding, 0.f, - jcp.oc - jcp.oc_without_padding); - bias = padded_bias; - } - - parallel(0, [&](const int ithr, const int nthr) { - execute_forward_thr(ithr, nthr, src, weights, bias, dst, scratchpad); - }); - - if (pd()->wants_zero_pad_dst()) - ctx.memory(MKLDNN_ARG_DST)->zero_pad(); -} - -template -void jit_avx512_common_1x1_convolution_fwd_t:: -execute_forward_thr(const int ithr, const int nthr, const src_data_t *src, - const wei_data_t *weights, const dst_data_t *bias, dst_data_t *dst, - const memory_tracking::grantor_t &scratchpad) const { - const memory_desc_wrapper src_d(pd()->src_md()); - const memory_desc_wrapper dst_d(pd()->dst_md()); - const memory_desc_wrapper weights_d(pd()->weights_md(0)); - - const auto &jcp = kernel_->jcp; - auto rtus_space = scratchpad.get(key_conv_rtus_space); - - const int ndims = src_d.ndims(); - const int stride_h = (ndims == 3) ? 
1 : pd()->desc()->strides[0]; - const int stride_w = pd()->desc()->strides[ndims - 3]; - const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0]; - const int pad_l = pd()->desc()->padding[0][ndims - 3]; - - const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast; - - auto step = [](int default_step, int remaining, int tail_step) { - assert(default_step <= tail_step); - return remaining < tail_step ? remaining : default_step; - }; - - auto p = jit_1x1_conv_call_s(); - - auto rp = rtus_driver_t::call_params_t(); - - const int nb_oc = jcp.nb_load; - const int nb_ic = jcp.nb_reduce; - const int nb_ic_blocking = jcp.nb_reduce_blocking; - const int os_block = jcp.bcast_block; - - int bcast_start{0}, bcast_end{0}, ocb_start{0}, ocb_end{0}; - balance2D(nthr, ithr, work_amount, bcast_start, bcast_end, - jcp.nb_load, ocb_start, ocb_end, jcp.load_grp_count); - - auto init_bcast = [&](int iwork, int &n, int &g, int &bcast_step, - int &oh, int &ow, int &ih, int &iw) - { - int osb{0}; - nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, - jcp.nb_bcast); - bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb, - jcp.nb_bcast_blocking_max); - bcast_step = nstl::min(bcast_step, bcast_end - iwork); - - const int os = osb * os_block; - oh = os / jcp.ow; - ow = os % jcp.ow; - - ih = nstl::max(oh * stride_h - pad_t, 0); - iw = nstl::max(ow * stride_w - pad_l, 0); - rp.iw_start = iw; - - p.bcast_dim = this_block_size(os, jcp.os, - bcast_step * os_block); - rp.os = p.bcast_dim; - }; - - auto init_load = [&](int ocb, int &load_step) - { - load_step = step(jcp.nb_load_blocking, ocb_end - ocb, - jcp.nb_load_blocking_max); - p.load_dim = this_block_size(ocb * jcp.oc_block, - ocb_end * jcp.oc_block, load_step * jcp.oc_block); - }; - - auto init_reduce = [&](int icb) - { - const int nb_ic_blocking_step = - nstl::min(icb + nb_ic_blocking, nb_ic) - icb; - p.first_last_flag = 0 - | (icb == 0 ? FLAG_REDUCE_FIRST : 0) - | (icb + nb_ic_blocking_step >= nb_ic - ? FLAG_REDUCE_LAST : 0); - - p.reduce_dim = this_block_size(icb * jcp.ic_block, - jcp.ic, nb_ic_blocking_step * jcp.ic_block); - rp.icb = p.reduce_dim / jcp.reduce_block; - }; - - auto inner_ker = [&](int ocb, int icb, int n, int g, int oh, int ow, - int ih, int iw) - { - - const int _ocb = g * nb_oc + ocb; - const size_t dst_off = data_blk_off(dst_d, n, _ocb, oh, ow); - - p.output_data = &dst[dst_off]; - p.bias_data = &bias[_ocb * jcp.oc_block]; - p.load_data = &weights[pd()->with_groups() - ? 
weights_d.blk_off(g, ocb, icb) - : weights_d.blk_off(ocb, icb)]; - - const int _icb = g * nb_ic + icb; - if (pd()->rtus_.reduce_src_) { - rp.ws = rtus_space + ithr * pd()->rtus_.space_per_thread_ - + _icb * jcp.is * jcp.ic_block; - if (ocb == ocb_start) { - rp.src = src + data_blk_off(src_d, n, _icb, ih, iw); - rtus_driver_->ker_(&rp); - } - p.bcast_data = rp.ws; - } else - p.bcast_data = src + data_blk_off(src_d, n, _icb, ih, iw); - - kernel_->jit_ker(&p); - }; - - if (jcp.loop_order == loop_rlb) { - for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { - init_reduce(icb); - int ocb = ocb_start; - while (ocb < ocb_end) { - int load_step; - init_load(ocb, load_step); - int iwork = bcast_start; - while (iwork < bcast_end) { - int n, g, bcast_step, oh, ow, ih, iw; - init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw); - inner_ker(ocb, icb, n, g, oh, ow, ih, iw); - iwork += bcast_step; - } - ocb += load_step; - } - } - } else if (jcp.loop_order == loop_lbr) { - int ocb = ocb_start; - while (ocb < ocb_end) { - int load_step; - init_load(ocb, load_step); - int iwork = bcast_start; - while (iwork < bcast_end) { - int n, g, bcast_step, oh, ow, ih, iw; - init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw); - for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { - init_reduce(icb); - inner_ker(ocb, icb, n, g, oh, ow, ih, iw); - } - iwork += bcast_step; - } - ocb += load_step; - } - } else if (jcp.loop_order == loop_rbl) { - for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { - init_reduce(icb); - int iwork = bcast_start; - while (iwork < bcast_end) { - int n, g, bcast_step, oh, ow, ih, iw; - init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw); - int ocb = ocb_start; - while (ocb < ocb_end) { - int load_step; - init_load(ocb, load_step); - inner_ker(ocb, icb, n, g, oh, ow, ih, iw); - ocb += load_step; - } - iwork += bcast_step; - } - } - } else if (jcp.loop_order == loop_blr) { - int iwork = bcast_start; - while (iwork < bcast_end) { - int n, g, bcast_step, oh, ow, ih, iw; - init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw); - int ocb = ocb_start; - while (ocb < ocb_end) { - int load_step; - init_load(ocb, load_step); - for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { - init_reduce(icb); - inner_ker(ocb, icb, n, g, oh, ow, ih, iw); - } - ocb += load_step; - } - iwork += bcast_step; - } - } else { - assert(!"unsupported loop order"); - } -} - - -template struct jit_avx512_common_1x1_convolution_fwd_t; -/* convolution backward wtr data */ - -template -void jit_avx512_common_1x1_convolution_bwd_data_t::execute_backward_data(const exec_ctx_t &ctx) const { - auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST); - auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); - auto diff_src = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC); - - const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); - const memory_desc_wrapper weights_d(pd()->weights_md(0)); - const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); - - const auto &jcp = kernel_->jcp; - auto rtus_space = scratchpad(ctx).template get( - key_conv_rtus_space); - - const int ndims = diff_src_d.ndims(); - - // TODO (Roma): remove this restriction - assert(jcp.stride_w == 1 && jcp.stride_h == 1); - - const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0]; - const int stride_w = pd()->desc()->strides[ndims - 3]; - const int pad_t = (ndims == 3) ? 
0 : pd()->desc()->padding[0][0]; - const int pad_l = pd()->desc()->padding[0][ndims - 3]; - - const int nb_ic = jcp.nb_load; - const int nb_oc = jcp.nb_reduce; - const int os_block = jcp.bcast_block; - const int nb_oc_blocking = jcp.nb_reduce_blocking; - - const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast; - - auto step = [](int default_step, int remaining, int tail_step) { - assert(default_step <= tail_step); - return remaining < tail_step ? remaining : default_step; - }; - - parallel(0, [&](const int ithr, const int nthr) { - auto p = jit_1x1_conv_call_s(); - auto rp = rtus_driver_t::call_params_t(); - - int bcast_start{0}, bcast_end{0}, icb_start{0}, icb_end{0}; - balance2D(nthr, ithr, work_amount, bcast_start, bcast_end, - jcp.nb_load, icb_start, icb_end, jcp.load_grp_count); - - bool reduce_outer = (jcp.loop_order == loop_rbl - || jcp.loop_order == loop_rlb); - int nboc_outer = reduce_outer ? nb_oc : 1; - int ocb_outer_step = reduce_outer ? nb_oc_blocking : 1; - - int nboc_inner = reduce_outer ? 1 : nb_oc; - int ocb_inner_step = reduce_outer ? 1 : nb_oc_blocking; - - for (int ocb_outer = 0; ocb_outer < nboc_outer; - ocb_outer += ocb_outer_step) { - size_t cur_ocb_outer = - nstl::min(ocb_outer + ocb_outer_step, nboc_outer) - ocb_outer; - - int load_step = 0; - for (int icb = icb_start; icb < icb_end; icb += load_step) { - load_step = step(jcp.nb_load_blocking, jcp.nb_load - icb, - jcp.nb_load_blocking_max); - - p.load_dim = this_block_size(icb * jcp.ic_block, - icb_end * jcp.ic_block, load_step * jcp.ic_block); - rp.icb = p.load_dim / jcp.ic_block; - - int bcast_step; - for (int iwork = bcast_start; iwork < bcast_end; - iwork += bcast_step) - { - int n{0}, g{0}, osb{0}; - nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, - jcp.nb_bcast); - - bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb, - jcp.nb_bcast_blocking_max); - bcast_step = nstl::min(bcast_step, bcast_end - iwork); - - const int os = osb * os_block; - p.bcast_dim = this_block_size(os, jcp.os, - bcast_step * os_block); - rp.os = p.bcast_dim; - - const int oh = os / jcp.ow; - const int ow = os % jcp.ow; - const int ih = nstl::max(oh * stride_h - pad_t, 0); - const int iw = nstl::max(ow * stride_w - pad_l, 0); - rp.iw_start = iw; - - const int _icb = g * nb_ic + icb; - rp.src = diff_src + data_blk_off(diff_src_d, n, _icb, ih, iw); - if (pd()->rtus_.reduce_src_) { - rp.ws = rtus_space - + ithr * pd()->rtus_.space_per_thread_; - p.output_data = rp.ws; - } else - p.output_data = rp.src; - - for (int ocb_inner = 0; ocb_inner < nboc_inner; - ocb_inner += ocb_inner_step) { - int cur_ocb_inner = - nstl::min(ocb_inner + ocb_inner_step, nboc_inner) - - ocb_inner; - - int ocb = reduce_outer ? ocb_outer : ocb_inner; - int nb_oc_blocking_step = reduce_outer - ? cur_ocb_outer : cur_ocb_inner; - const int _ocb = g * nb_oc + ocb; - size_t diff_dst_off = data_blk_off(diff_dst_d, n, _ocb, oh, ow); - p.bcast_data = &diff_dst[diff_dst_off]; - - p.load_data = &weights[pd()->with_groups() - ? weights_d.blk_off(g, ocb, icb) - : weights_d.blk_off(ocb, icb)]; - - p.first_last_flag = ocb == 0 ? FLAG_REDUCE_FIRST : 0; - - p.reduce_dim = this_block_size(ocb * jcp.oc_block, - jcp.oc, nb_oc_blocking_step * jcp.oc_block); - - kernel_->jit_ker(&p); - } - if (pd()->rtus_.reduce_src_) - rtus_driver_->ker_(&rp); - } - } - } - }); -} - -template struct jit_avx512_common_1x1_convolution_bwd_data_t; - -/* convolution backward wtr weights */ - -#define wht_blk_off(d, g, ...) \ - (pd()->with_groups() \ - ? 
(d).blk_off((g), __VA_ARGS__) \ - : (d).blk_off(__VA_ARGS__)) - -jit_avx512_common_1x1_convolution_bwd_weights_t :: - jit_avx512_common_1x1_convolution_bwd_weights_t(const pd_t *apd) - : cpu_primitive_t(apd) - , kernel_(nullptr), acc_ker_(nullptr), reducer_bias_(nullptr) - , trans_kernel_(nullptr), rtus_driver_(nullptr) -{ - kernel_ = new jit_avx512_common_1x1_conv_kernel(pd()->jcp_, *pd()->attr()); - acc_ker_ = new cpu_accumulator_1d_t(); - reducer_bias_ = new cpu_reducer_t(pd()->reducer_bia_conf_); - init_rtus_driver(this); - - const auto &jcp = kernel_->jcp; - - if (jcp.transpose_src) { - auto tp = jit_transpose4x16_src_t(); - tp.src_pf0_distance = 4; - tp.tr_src_pf0_distance = 0; - tp.src_pf1 = true; - tp.tr_src_pf1 = false; - trans_kernel_ = new jit_transpose4x16_src(&jcp, &tp); - } -} - -void jit_avx512_common_1x1_convolution_bwd_weights_t::execute_backward_weights( - const exec_ctx_t &ctx) const -{ - auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); - auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); - auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS); - auto diff_bias_in = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS); - - const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); - const memory_desc_wrapper src_d(pd()->src_md()); - const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); - - const auto &jcp = kernel_->jcp; - - const auto scratchpad = this->scratchpad(ctx); - - auto rtus_space = scratchpad.get(key_conv_rtus_space); - data_t *diff_bias = pd()->wants_padded_bias() - ? scratchpad.get(key_conv_padded_bias) : diff_bias_in; - auto wei_reduction = scratchpad.get(key_conv_wei_reduction); - - /* prepare src transposition barriers */ - auto tr_src = scratchpad.get(key_conv_tr_src); - auto tr_src_bctx = scratchpad.get( - key_conv_tr_src_bctx); - if (jcp.transpose_src) { - for (int i = 0; i < jcp.nthr; ++i) - simple_barrier::ctx_init(&tr_src_bctx[i]); - } - - const int ndims = src_d.ndims(); - const int wei_size = jcp.ngroups * jcp.oc * jcp.ic; - - simple_barrier::ctx_t reduction_barrier; - simple_barrier::ctx_init(&reduction_barrier); - - const auto reducer_bia_scratchpad = memory_tracking::grantor_t(scratchpad, - prefix_reducer_bia); - auto rb = this->reducer_bias_; - rb->init(reducer_bia_scratchpad); - - // TODO (Roma): remove this restriction - assert(jcp.stride_w == 1 && jcp.stride_h == 1); - - const int nb_ic = jcp.nb_bcast; - const int nb_ic_blocking = jcp.nb_bcast_blocking; - - const int nb_oc = jcp.nb_load; - const int nb_oc_blocking = jcp.nb_load_blocking; - - const int sp_nb = jcp.nb_reduce; - const int mb_sp_work = jcp.mb * sp_nb; - - const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0]; - const int stride_w = pd()->desc()->strides[ndims - 3]; - const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0]; - const int pad_l = pd()->desc()->padding[0][ndims - 3]; - - auto step = [](int default_step, int remaining, int tail_step) { - assert(default_step <= tail_step); - return remaining < tail_step ? 
remaining : default_step; - }; - - // TODO: use memory descriptor with the same fmt as src - // (or use a macro :)) - auto tr_src_off = [&](int img, int icb, int is) { - const size_t tr_chn_size = jcp.tr_is * jcp.ic_block; - const size_t tr_img_size = tr_chn_size * nb_ic * jcp.ngroups; - return img * tr_img_size + icb * tr_chn_size + is * jcp.ic_block; - }; - - auto uker_trans = [&](int ithr_mb, int img, int sp_b_start, int sp_size, - int g_start, int g_work, int ic_b_start, int ic_b_work, - int ithr, int nthr, int first_ic_b) - { - const int work_amount = g_work * ic_b_work; - - int start{ 0 }, end{ 0 }; - balance211(work_amount, nthr, ithr, start, end); - - int g{ 0 }, ic_b{ 0 }; - nd_iterator_init(start, g, g_work, ic_b, ic_b_work); - g += g_start; - const int ic_b_tr = g * nb_ic + first_ic_b + ic_b; - ic_b += ic_b_start; - - const int _ic = g * nb_ic + ic_b; - - const int is = sp_b_start * jcp.reduce_block; - const int ih = is / jcp.iw; - const int iw = is % jcp.iw; - - const int src1_off = data_blk_off(src_d, img, _ic, ih, iw); - data_t *src1 = (data_t *)&src[src1_off]; - data_t *tr_src1 = &tr_src[tr_src_off(ithr_mb, ic_b_tr, is)]; - - assert(jcp.ic_block == 16); - const int src_stride = jcp.is * jcp.ic_block; - const int tr_src_stride = jcp.tr_is * jcp.ic_block; - - const int my_work = end - start; - for (int iwork = 0; iwork < my_work; iwork++) { - auto par_trans = jit_src_transpose_s(); - assert(sp_size % 4 == 0 || sp_size % 4 == jcp.is % 4); - par_trans.size = sp_size; - par_trans.src = src1; - par_trans.tr_src = tr_src1; - par_trans.src_prf = src1 + 64 * 16; - par_trans.tr_src_prf = tr_src1 + 80 * 16; - trans_kernel_->jit_ker(&par_trans); - - src1 += src_stride; - tr_src1 += tr_src_stride; - } - }; - - auto ker = [&](const int ithr, const int nthr) { - assert(nthr == jcp.nthr); - assert(IMPLICATION(!mkldnn_thr_syncable(), jcp.nthr_mb == 1)); - - const int ithr_ic_b = ithr % jcp.nthr_ic_b; - const int ithr_oc_b = ithr / jcp.nthr_ic_b % jcp.nthr_oc_b; - const int ithr_g = ithr / jcp.nthr_ic_b / jcp.nthr_oc_b % jcp.nthr_g; - const int ithr_mb = ithr / jcp.nthr_ic_b / jcp.nthr_oc_b / - jcp.nthr_g; - - const int ithr_but_oc - = (ithr_mb * jcp.nthr_g + ithr_g) * jcp.nthr_ic_b + ithr_ic_b; - - /* reduction dimension */ - int mb_sp_b_start{ 0 }, mb_sp_b_end{ 0 }; - if (jcp.transpose_src && jcp.nthr_mb < jcp.mb / 2) { - // it's preferable to parallelize by mb if possible - int img_start{ 0 }, img_end{ 0 }; - balance211(jcp.mb, jcp.nthr_mb, ithr_mb, img_start, img_end); - mb_sp_b_start = img_start * sp_nb; - mb_sp_b_end = img_end * sp_nb; - } - else { - balance211(mb_sp_work, jcp.nthr_mb, ithr_mb, mb_sp_b_start, - mb_sp_b_end); - } - - /* independent dimensions */ - int g_start{ 0 }, oc_b_start{ 0 }, ic_b_start{ 0 }; - int g_end{ 0 }, oc_b_end{ 0 }, ic_b_end{ 0 }; - - balance211(jcp.ngroups, jcp.nthr_g, ithr_g, g_start, g_end); - balance211(jcp.nb_load, jcp.nthr_oc_b, ithr_oc_b, oc_b_start, - oc_b_end); - balance211(jcp.nb_bcast, jcp.nthr_ic_b, ithr_ic_b, ic_b_start, - ic_b_end); - - const int g_work = g_end - g_start; - const int oc_b_work = oc_b_end - oc_b_start; - const int ic_b_work = ic_b_end - ic_b_start; - - data_t *diff_wei = ithr_mb == 0 - ? 
diff_weights : wei_reduction + (ithr_mb - 1) * wei_size; - - int sp_b_step = 0; - for (int mb_sp_b = mb_sp_b_start; mb_sp_b < mb_sp_b_end; - mb_sp_b += sp_b_step) { - int img{ 0 }, sp_b{ 0 }; - nd_iterator_init(mb_sp_b, img, jcp.mb, sp_b, sp_nb); - sp_b_step = step(jcp.nb_reduce_blocking, - nstl::min(sp_nb - sp_b, mb_sp_b_end - mb_sp_b), - jcp.nb_reduce_blocking_max); - - for (int g = g_start; g < g_end; ++g) { - int load_step = 0; - int bcast_step = 0; - for (int ic_b = ic_b_start; ic_b < ic_b_end; - ic_b += bcast_step) { - bcast_step = step(nb_ic_blocking, ic_b_end - ic_b, - jcp.nb_bcast_blocking_max); - if (jcp.transpose_src) { - if (jcp.nthr_oc_b > 1) - simple_barrier::barrier( - &tr_src_bctx[ithr_but_oc], jcp.nthr_oc_b); - const int sp_size - = nstl::min(sp_b_step * jcp.reduce_block, - jcp.is - sp_b * jcp.reduce_block); - uker_trans(ithr_mb, img, sp_b, sp_size, g, 1, ic_b, - bcast_step, ithr_oc_b, jcp.nthr_oc_b, ic_b_start); - if (jcp.nthr_oc_b > 1) - simple_barrier::barrier( - &tr_src_bctx[ithr_but_oc], jcp.nthr_oc_b); - } - - for (int oc_b = oc_b_start; oc_b < oc_b_end; - oc_b += load_step) { - load_step = step(nb_oc_blocking, oc_b_end - oc_b, - jcp.nb_load_blocking_max); - const int _ic_b = g * nb_ic + ic_b; - const int _ic_b_tr = g * nb_ic + ic_b_start; - const int _oc_b = g * nb_oc + oc_b; - - data_t *store_to; - - const size_t off - = wht_blk_off(diff_weights_d, g, oc_b, ic_b); - store_to = diff_wei + off; - - const data_t *diff_src = jcp.transpose_src ? - &tr_src[tr_src_off(ithr_mb, _ic_b_tr, 0)] : - &src[src_d.blk_off(img, _ic_b)]; - - int sp_b_end = sp_b + sp_b_step; - const data_t *pdiff_dst - = &diff_dst[diff_dst_d.blk_off(img, _oc_b)]; - const data_t *local_src = diff_src; - - auto p = jit_1x1_conv_call_s(); - auto rp = rtus_driver_t::call_params_t(); - - p.output_stride - = jcp.ic * jcp.oc_block * jcp.typesize_out; - - p.load_dim = load_step * jcp.oc_block; - - p.bcast_dim = bcast_step * jcp.ic_block; - rp.icb = bcast_step; - p.output_data = store_to; - - p.reduce_dim = sp_b_step * jcp.reduce_block; - rp.os = p.reduce_dim; - - p.first_last_flag = 0 - | (mb_sp_b == mb_sp_b_start ? FLAG_REDUCE_FIRST : 0) - | (sp_b_end == sp_nb ? 
FLAG_SP_LAST : 0); - - int sp = sp_b * jcp.reduce_block; - p.load_data = pdiff_dst + sp * jcp.oc_block; - - if (pd()->rtus_.reduce_src_) { - const int oh = sp / jcp.ow; - const int ow = sp % jcp.ow; - - const int ih = nstl::max(oh * stride_h - pad_t, 0); - const int iw = nstl::max(ow * stride_w - pad_l, 0); - rp.iw_start = iw; - - rp.ws = rtus_space - + ithr * pd()->rtus_.space_per_thread_ - + sp * jcp.ic_block; - - if (ndims == 3) - rp.src = local_src + iw - * src_d.blocking_desc().strides[2]; - else - rp.src = local_src + ih - * src_d.blocking_desc().strides[2] - + iw * src_d.blocking_desc().strides[3]; - rtus_driver_->ker_(&rp); - - p.bcast_data = rp.ws; - } else - p.bcast_data = local_src + sp * jcp.ic_block; - - kernel_->jit_ker(&p); - } - } - } - } - - /* diff_weights[:] += sum(wei_reduction[thr_mb][:]) */ - if (jcp.nthr_mb > 1) { - simple_barrier::barrier(&reduction_barrier, jcp.nthr); - const int work = g_work * oc_b_work * ic_b_work; - int start{ 0 }, end{ 0 }; - balance211(work, jcp.nthr_mb, ithr_mb, start, end); - if (start == end) - return; - - for (int thr_mb = 1; thr_mb < jcp.nthr_mb; ++thr_mb) { - int w = start; - int sub_g_start{ 0 }, sub_oc_b_start{ 0 }, - sub_ic_b_start{ 0 }; - nd_iterator_init(w, sub_g_start, g_work, sub_oc_b_start, - oc_b_work, sub_ic_b_start, ic_b_work); - while (w < end) { - const int g = g_start + sub_g_start; - const int oc_b = oc_b_start + sub_oc_b_start; - const int ic_b = ic_b_start + sub_ic_b_start; - - const int acc_size - = nstl::min(end - w, ic_b_work - sub_ic_b_start) - * jcp.ic_block * jcp.oc_block; - - const size_t off - = wht_blk_off(diff_weights_d, g, oc_b, ic_b); - data_t *d = diff_weights + off; - data_t *s = wei_reduction + (thr_mb - 1) * wei_size + off; - - acc_ker_->accumulate(d, s, acc_size); - - nd_iterator_jump(w, end, sub_g_start, g_work, - sub_oc_b_start, oc_b_work, sub_ic_b_start, - ic_b_work); - } - } - } - }; - - auto ker_bias = [&](int ithr, int nthr) { - assert(nthr == rb->balancer().nthr_); - - const int b_job_start = rb->balancer().ithr_job_off(ithr); - const int b_njobs = rb->balancer().ithr_njobs(ithr); - - if (b_njobs == 0) - return; - - /* reduction dimension */ - int img_start{ 0 }, img_end{ 0 }; - - balance211(jcp.mb, rb->balancer().nthr_per_group_, - rb->balancer().id_in_group(ithr), img_start, img_end); - - /* jobs */ - int g_start{ 0 }, ocb_start{ 0 }; - nd_iterator_init( - b_job_start, g_start, jcp.ngroups, ocb_start, jcp.nb_load); - - for (int img = img_start; img < img_end; ++img) { - int g = g_start, ocb = ocb_start; - for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) { - const size_t _oc = g * jcp.nb_load + ocb; - - const data_t *d_dst = &diff_dst[diff_dst_d.blk_off(img, _oc)]; - data_t *d_bias = rb->get_local_ptr(ithr, diff_bias, - reducer_bia_scratchpad) - + b_job_loc * rb->balancer().job_size_; - - if (img == img_start) - for (int o = 0; o < 16; ++o) - d_bias[o] = 0.; - - for (int hw = 0; hw < jcp.oh * jcp.ow; ++hw) { - PRAGMA_OMP_SIMD() - for (int o = 0; o < 16; ++o) - d_bias[o] += d_dst[o]; - d_dst += 16; - } - - nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_load); - } - } - rb->reduce(ithr, diff_bias, reducer_bia_scratchpad); - }; - - parallel(jcp.nthr, [&](const int ithr, const int nthr) { - ker(ithr, jcp.nthr); - if (pd()->with_bias()) - ker_bias(ithr, jcp.nthr); - }); - - /* TODO: put this in ker_bias */ - if (pd()->wants_padded_bias()) { - assert(jcp.ngroups == 1); - utils::array_copy(diff_bias_in, diff_bias, jcp.oc_without_padding); - } -} - -} -} -} diff --git 
a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.hpp deleted file mode 100644 index 2e9fda76d..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.hpp +++ /dev/null @@ -1,344 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef CPU_JIT_AVX512_COMMON_1x1_CONVOLUTION_HPP -#define CPU_JIT_AVX512_COMMON_1x1_CONVOLUTION_HPP - -#include "c_types_map.hpp" -#include "memory_tracking.hpp" -#include "mkldnn_thread.hpp" -#include "utils.hpp" - -#include "cpu_convolution_pd.hpp" -#include "cpu_primitive.hpp" -#include "cpu_reducer.hpp" - -#include "jit_avx512_common_1x1_conv_kernel.hpp" -#include "jit_uni_1x1_conv_utils.hpp" -#include "jit_transpose_src_utils.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -template -struct jit_avx512_common_1x1_convolution_fwd_t : public cpu_primitive_t { - struct pd_t: public cpu_convolution_fwd_pd_t { - pd_t(engine_t *engine, const convolution_desc_t *adesc, - const primitive_attr_t *attr, - const typename pd_t::base_class *hint_fwd_pd) - : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) - , jcp_(), rtus_() {} - - DECLARE_COMMON_PD_T( - JIT_IMPL_NAME_HELPER("jit_1x1:", avx512_common, ""), - jit_avx512_common_1x1_convolution_fwd_t); - - status_t init() { - using namespace utils; - bool ok = true - && is_fwd() - && set_default_alg_kind(alg_kind::convolution_direct) - && expect_data_types(src_type, wei_type, dst_type, dst_type, - data_type::undef) - && !has_zero_dim_memory() - && set_default_formats(); - if (!ok) return status::unimplemented; - - const convolution_desc_t *conv_d = desc(); - const memory_desc_t *src_d = src_md(); - rtus_prepare(this, conv_d, src_d, dst_md()); - - status_t status = jit_avx512_common_1x1_conv_kernel::init_conf( - jcp_, *conv_d, *src_d, *weights_md(), *dst_md(), *attr(), - mkldnn_get_max_threads(), rtus_.reduce_src_); - if (status != status::success) return status; - - auto scratchpad = scratchpad_registry().registrar(); - jit_avx512_common_1x1_conv_kernel::init_scratchpad(scratchpad, - jcp_); - - rtus_prepare_space_info(this, scratchpad); - - return status::success; - } - - jit_1x1_conv_conf_t jcp_; - reduce_to_unit_stride_t rtus_; - - protected: - bool set_default_formats() { - using namespace format_tag; - - auto dat_tag = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c); - auto wei_tag = utils::pick(2 * ndims() - 6 + with_groups(), - OIw16i16o, gOIw16i16o, OIhw16i16o, gOIhw16i16o); - - return set_default_formats_common(dat_tag, wei_tag, dat_tag); - } - }; - - template - friend void init_rtus_driver(conv_t *self); - - jit_avx512_common_1x1_convolution_fwd_t(const pd_t *apd) - : cpu_primitive_t(apd) - , kernel_(nullptr), rtus_driver_(nullptr) - { - kernel_ = - new 
jit_avx512_common_1x1_conv_kernel(pd()->jcp_, *pd()->attr()); - init_rtus_driver(this); - } - - ~jit_avx512_common_1x1_convolution_fwd_t() { - delete kernel_; - delete rtus_driver_; - } - - typedef typename prec_traits::type src_data_t; - typedef typename prec_traits::type wei_data_t; - typedef typename prec_traits::type dst_data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_forward(ctx); - return status::success; - } - - private: - void execute_forward(const exec_ctx_t &ctx) const; - void execute_forward_thr(const int ithr, const int nthr, - const src_data_t *src, const wei_data_t *weights, - const dst_data_t *bias, dst_data_t *dst, - const memory_tracking::grantor_t &scratchpad) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - - jit_avx512_common_1x1_conv_kernel *kernel_; - rtus_driver_t *rtus_driver_; -}; - -using jit_avx512_common_1x1_convolution_fwd_f32_t - = jit_avx512_common_1x1_convolution_fwd_t; - -template -struct jit_avx512_common_1x1_convolution_bwd_data_t : public cpu_primitive_t { - struct pd_t : public cpu_convolution_bwd_data_pd_t { - pd_t(engine_t *engine, - const convolution_desc_t *adesc, - const primitive_attr_t *attr, - const convolution_fwd_pd_t *hint_fwd_pd) - : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) - , jcp_(), rtus_() {} - - DECLARE_COMMON_PD_T( - JIT_IMPL_NAME_HELPER("jit_1x1:", avx512_common, ""), - jit_avx512_common_1x1_convolution_bwd_data_t); - - status_t init() { - bool ok = true - && desc()->prop_kind == prop_kind::backward_data - && set_default_alg_kind(alg_kind::convolution_direct) - && expect_data_types(diff_src_type, wei_type, data_type::undef, - diff_dst_type, data_type::undef) - && !has_zero_dim_memory() - && set_default_formats(); - if (!ok) return status::unimplemented; - - const convolution_desc_t *conv_d = desc(); - const memory_desc_t *diff_src_d = diff_src_md(); - rtus_prepare(this, conv_d, diff_src_d, diff_dst_md()); - - status_t status = jit_avx512_common_1x1_conv_kernel::init_conf( - jcp_, *conv_d, *diff_src_d, *weights_md(), *diff_dst_md(), - *attr(), mkldnn_get_max_threads(), rtus_.reduce_src_); - if (status != status::success) return status; - - auto scratchpad = scratchpad_registry().registrar(); - jit_avx512_common_1x1_conv_kernel::init_scratchpad(scratchpad, - jcp_); - - rtus_prepare_space_info(this, scratchpad); - - return status::success; - } - - // TODO (Roma): structs conf header cleanup - jit_1x1_conv_conf_t jcp_; - reduce_to_unit_stride_t rtus_; - - protected: - bool set_default_formats() { - using namespace format_tag; - - auto dat_tag = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c); - auto wei_tag = utils::pick(2 * ndims() - 6 + with_groups(), - IOw16o16i, gIOw16o16i, IOhw16o16i, gIOhw16o16i); - - return set_default_formats_common(dat_tag, wei_tag, dat_tag); - } - }; - - template - friend void init_rtus_driver(conv_t *self); - - jit_avx512_common_1x1_convolution_bwd_data_t(const pd_t *apd) - : cpu_primitive_t(apd) - , kernel_(nullptr), rtus_driver_(nullptr) - { - kernel_ = new jit_avx512_common_1x1_conv_kernel(pd()->jcp_, - *pd()->attr()); - init_rtus_driver(this); - } - - ~jit_avx512_common_1x1_convolution_bwd_data_t() { - delete kernel_; - delete rtus_driver_; - } - - typedef typename prec_traits::type diff_dst_data_t; - typedef typename prec_traits::type wei_data_t; - typedef typename prec_traits::type diff_src_data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_backward_data(ctx); - return 
status::success; - } - - private: - void execute_backward_data(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - - jit_avx512_common_1x1_conv_kernel *kernel_; - rtus_driver_t *rtus_driver_; -}; - -using jit_avx512_common_1x1_convolution_bwd_data_f32_t - = jit_avx512_common_1x1_convolution_bwd_data_t; - -struct jit_avx512_common_1x1_convolution_bwd_weights_t : public cpu_primitive_t -{ - struct pd_t : public cpu_convolution_bwd_weights_pd_t { - pd_t(engine_t *engine, - const convolution_desc_t *adesc, - const primitive_attr_t *attr, - const convolution_fwd_pd_t *hint_fwd_pd) - : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd) - , jcp_(), rtus_() {} - - DECLARE_COMMON_PD_T( - JIT_IMPL_NAME_HELPER("jit_1x1:", avx512_common, ""), - jit_avx512_common_1x1_convolution_bwd_weights_t); - - status_t init() { - bool ok = true - && desc()->prop_kind == prop_kind::backward_weights - && set_default_alg_kind(alg_kind::convolution_direct) - && expect_data_types(data_type::f32, data_type::f32, - data_type::f32, data_type::f32, data_type::f32) - && !has_zero_dim_memory() - && set_default_formats(); - if (!ok) return status::unimplemented; - - const convolution_desc_t *conv_d = desc(); - const memory_desc_t *src_d = src_md(); - rtus_prepare(this, conv_d, src_d, diff_dst_md()); - - status_t status = jit_avx512_common_1x1_conv_kernel::init_conf( - jcp_, *conv_d, *src_d, *diff_weights_md(), *diff_dst_md(), - *attr(), mkldnn_get_max_threads(), rtus_.reduce_src_); - if (status != status::success) return status; - - init_balancers(); - - auto scratchpad = scratchpad_registry().registrar(); - jit_avx512_common_1x1_conv_kernel::init_scratchpad(scratchpad, - jcp_); - - auto reducer_bia_scratchpad = memory_tracking::registrar_t( - scratchpad, memory_tracking::names::prefix_reducer_bia); - reducer_bia_conf_.init_scratchpad(reducer_bia_scratchpad); - - rtus_prepare_space_info(this, scratchpad); - - return status::success; - } - - // TODO (Roma): structs conf header cleanup - jit_1x1_conv_conf_t jcp_; - cpu_reducer_t::conf_t reducer_bia_conf_; - reduce_to_unit_stride_t rtus_; - - protected: - bool set_default_formats() { - using namespace format_tag; - - auto dat_tag = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c); - auto wei_tag = utils::pick(2 * ndims() - 6 + with_groups(), - OIw16i16o, gOIw16i16o, OIhw16i16o, gOIhw16i16o); - - return set_default_formats_common(dat_tag, wei_tag, dat_tag); - } - - private: - void init_balancers() { - const size_t max_buffer_size = jcp_.nthr * 3 * 5 * 5 * 16 * 16; - if (with_bias()) { - reducer_bia_conf_.init(reduce_balancer_t(jcp_.nthr, - jcp_.oc_block, jcp_.ngroups * jcp_.nb_load, - jcp_.mb, max_buffer_size)); - } - } - }; - - template - friend void init_rtus_driver(conv_t *self); - - jit_avx512_common_1x1_convolution_bwd_weights_t(const pd_t *apd); - - ~jit_avx512_common_1x1_convolution_bwd_weights_t() { - delete kernel_; - delete acc_ker_; - delete reducer_bias_; - delete rtus_driver_; - delete trans_kernel_; - } - - typedef typename prec_traits::type data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_backward_weights(ctx); - return status::success; - } - - private: - void execute_backward_weights(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - - jit_avx512_common_1x1_conv_kernel *kernel_; - cpu_accumulator_1d_t *acc_ker_; - cpu_reducer_t *reducer_bias_; - jit_transpose4x16_src *trans_kernel_; - rtus_driver_t 
*rtus_driver_; -}; - -} -} -} - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.cpp deleted file mode 100644 index 235fb02fe..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.cpp +++ /dev/null @@ -1,4539 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include "c_types_map.hpp" -#include "nstl.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -#include "cpu_barrier.hpp" - -#include "jit_avx512_common_conv_kernel.hpp" - -#define GET_OFF(field) offsetof(jit_conv_call_s, field) -#define KNx_L2_EFFECTIVE_CAPACITY ((512-64)*1024) - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace mkldnn::impl::format_tag; -using namespace mkldnn::impl::memory_tracking::names; -using namespace mkldnn::impl::utils; -using namespace Xbyak; - -namespace { - -constexpr auto small_spatial = 14; -unsigned int L1_cache_size = get_cache_size(1, true); - -inline void pick_loop_order(jit_conv_conf_t &jcp) { - using namespace prop_kind; - assert(one_of(jcp.prop_kind, - forward_training, forward_inference, backward_data)); - auto w = (jcp.prop_kind == backward_data) ? jcp.iw : jcp.ow; - auto h = (jcp.prop_kind == backward_data) ? jcp.ih : jcp.oh; - - // ow-threading is currently implemented for forward only - // TODO: single code for fwd and bwd after ow-thr for bwd - // meaningless switch was removed - if (jcp.prop_kind == backward_data) { - jcp.loop_order = (w <= small_spatial && h <= small_spatial) - ? loop_cgn : loop_gnc; - } else { - jcp.loop_order = (w <= small_spatial && h <= small_spatial) - ? 
loop_cwgn : loop_gncw; - } -} - -inline bool is_1stconv(const jit_conv_conf_t &jcp) { - if (mayiuse(avx512_core)) - return (jcp.ic < 16 && jcp.ngroups == 1); - else - return one_of(jcp.ic, 1, 3); -} - -inline bool is_ow_threading_on(const jit_conv_conf_t &jcp) { - return (jcp.nb_ow > 1); -} - -inline bool is_owb_prefetching(const jit_conv_conf_t &jcp) { - return (jcp.ver == ver_4fma && is_ow_threading_on(jcp)); -} - -} - -template -void _jit_avx512_common_conv_fwd_kernel::prepare_output(int ur_w) -{ - for (int k = 0; k < jcp.nb_oc_blocking; k++) - for (int j = 0; j < ur_w; j++) { - Vmm vmm = vmm_out(j, k); - vpxord(vmm, vmm, vmm); - if (!is_owb_prefetching(jcp)) { - size_t aux_output_offset = get_output_offset(j, k); - mic_prefetcht1(EVEX_compress_addr_safe(reg_out_prf, - aux_output_offset, reg_out_long_offt)); - } - } -} - -template -void _jit_avx512_common_conv_fwd_kernel::store_output(int ur_w) -{ - Label no_update_label, store_label, eltwise_label; - - mov(reg_channel, ptr[param1 + GET_OFF(channel)]); - if (jcp.with_bias) { - mov(reg_bias, ptr[param1 + GET_OFF(bias)]); - } - - if (!jcp.with_sum) { - cmp(reg_channel, 0); - je(no_update_label, T_NEAR); - } - - for (int k = 0; k < jcp.nb_oc_blocking; k++) - for (int j = 0; j < ur_w; j++) { - Vmm vmm = vmm_out(j, k); - size_t aux_output_offset = get_output_offset(j, k); - vaddps(vmm, - make_safe_addr(reg_out, aux_output_offset, reg_out_long_offt)); - } - - if (!jcp.with_sum) { - jmp(eltwise_label, T_NEAR); - } else { - cmp(reg_channel, 0); - jne(eltwise_label, T_NEAR); - } - - L(no_update_label); - if (jcp.with_bias) { - for (int k = 0; k < jcp.nb_oc_blocking; k++) { - int bias_offset = jcp.typesize_out * k * jcp.oc_block; - for (int j = 0; j < ur_w; j++) { - Vmm vmm = vmm_out(j, k); - vaddps(vmm, EVEX_compress_addr(reg_bias, bias_offset)); - } - mic_prefetcht1(EVEX_compress_addr(reg_bias, bias_offset + 64)); - } - } - - L(eltwise_label); - if (jcp.with_eltwise) { - cmp(reg_channel, jcp.nb_ic - 1); - jl(store_label, T_NEAR); - - if (ur_w == jcp.ur_w) { - eltwise_injector_->compute_vector_range(0, - jcp.nb_oc_blocking * jcp.ur_w); - } else { - for (int k = 0; k < jcp.nb_oc_blocking; k++) - eltwise_injector_->compute_vector_range(k * jcp.ur_w, - k * jcp.ur_w + ur_w); - } - } - - L(store_label); - for (int k = 0; k < jcp.nb_oc_blocking; k++) - for (int j = 0; j < ur_w; j++) { - Vmm vmm = vmm_out(j, k); - size_t aux_output_offset = (size_t)typesize * - ((size_t)k * jcp.od * jcp.oh * jcp.ow + j) * jcp.oc_block; - vmovups(EVEX_compress_addr_safe(reg_out, aux_output_offset, - reg_out_long_offt), vmm); - if (!is_owb_prefetching(jcp)) - mic_prefetcht0(EVEX_compress_addr_safe(reg_out_prf, - aux_output_offset, reg_out_long_offt)); - } -} - -template -void _jit_avx512_common_conv_fwd_kernel::compute_loop_4fma_1st(int ur_w, - int pad_l, int pad_r) -{ -} - -template<> -void _jit_avx512_common_conv_fwd_kernel::compute_loop_4fma_1st(int ur_w, - int pad_l, int pad_r) -{ - assert(jcp.dilate_d == 0 && jcp.dilate_h == 0 && jcp.dilate_w == 0); - - int iw = jcp.iw; - int ih = jcp.ih; - int kw = jcp.kw; - int stride_w = jcp.stride_w; - int ic_block = jcp.ic_block; - int oc_block = jcp.oc_block; - - Label kh_label, kd_label; - - if (one_of(jcp.ndims, 3, 4)) { - mov(aux_reg_inp, reg_inp); - mov(aux_reg_ker, reg_ker); - mov(aux_reg_inp_prf, reg_inp_prf); - } - - size_t max_input_offset = (size_t)jcp.typesize_in - * ((size_t)(kw + ur_w * stride_w - pad_l) - + (size_t)ic_block * iw * ih * jcp.id); - assert(reg_inp_prf == reg_long_offt); - if (max_input_offset > 
INT_MAX) push(reg_inp_prf); - - if (jcp.ndims == 5) { - push(reg_out_prf); - push(reg_out); - - mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]); - mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]); - mov(aux_reg_inp_d, reg_inp); - mov(aux_reg_inp_d_prf, reg_inp_prf); - - L(kd_label); - } - mov(reg_kj, reg_kh); - if (jcp.ndims == 5) { - mov(aux_reg_inp, aux_reg_inp_d); - mov(aux_reg_ker, aux_reg_ker_d); - mov(aux_reg_inp_prf, aux_reg_inp_d_prf); - } - - L(kh_label); - for (int ki = 0; ki < kw; ki += 4) { - for (int ic = 0; ic < ic_block; ic++) { - for (int i = 0; i < 4; i++) { - int aux_ker_offset - = jcp.typesize_in - * ((ki + i) * oc_block - + ic * kw * jcp.kh * jcp.kd * oc_block); - if (ki + i < kw) - vmovups(vmm_ker(i), - EVEX_compress_addr(aux_reg_ker, aux_ker_offset)); - else - vpxord(vmm_ker(i), vmm_ker(i), vmm_ker(i)); - } - - int j_start = get_ow_start(ki, pad_l); - int j_end = get_ow_end(ur_w, ki, pad_r); - - for (int j = j_start, prf_count=0; j < j_end; j++) { - size_t aux_input_offset = (size_t)jcp.typesize_in - * ((size_t)(ki + j * stride_w - - pad_l) + (size_t)ic * iw * ih * jcp.id); - v4fmaddps(vmm_out(j, 0), vmm_ker(0), - EVEX_compress_addr_safe(aux_reg_inp, aux_input_offset, - reg_long_offt)); - if (ki + prf_count < kw && prf_count < 4 - && ((ki < 2 && j % 4) || j % 2)) { - int aux_ker_offset = jcp.typesize_in - * ((ki + prf_count) * oc_block - + ic * kw * jcp.kh * jcp.kd * oc_block + kw * oc_block); - mic_prefetcht0(EVEX_compress_addr(aux_reg_ker, - aux_ker_offset)); - prf_count++; - } - if (ki == 0 - && j % (64 / (stride_w * jcp.typesize_in)) == 0) { - mic_prefetcht0(EVEX_compress_addr_safe(aux_reg_inp_prf, - aux_input_offset, reg_long_offt)); - } - if (ki == 1 - && j % (64 / (stride_w * jcp.typesize_in)) == 0) { - mic_prefetcht0(EVEX_compress_addr_safe(aux_reg_inp, - aux_input_offset+jcp.typesize_in * iw, reg_long_offt)); - } - } - } - } - add(aux_reg_ker, jcp.typesize_in * kw * oc_block); - add(aux_reg_inp, jcp.typesize_in * iw); - add(aux_reg_inp_prf, jcp.typesize_in * iw); - - dec(reg_kj); - cmp(reg_kj, 0); - jg(kh_label, T_NEAR); - - if (jcp.ndims == 5) { - add(aux_reg_inp_d, typesize * jcp.ih * jcp.iw); - add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * oc_block); - add(aux_reg_inp_d_prf, typesize * jcp.ih * jcp.iw); - - dec(reg_ki); - cmp(reg_ki, 0); - jg(kd_label, T_NEAR); - - pop(reg_out); - pop(reg_out_prf); - } - - if (max_input_offset > INT_MAX) pop(reg_inp_prf); -} - -template -void _jit_avx512_common_conv_fwd_kernel::compute_loop_4fma(int ur_w, - int pad_l, int pad_r) -{ -} - -template<> -void _jit_avx512_common_conv_fwd_kernel::compute_loop_4fma(int ur_w, - int pad_l, int pad_r) -{ - int stride_w = jcp.stride_w; - int ic_block = jcp.ic_block; - int oc_block = jcp.oc_block; - Label kh_label, last_iter_label, loop_end_label, kd_label; - int ker_load_number = 4; - int shift_kernel_ptr = typesize * jcp.kw * jcp.oc_block * jcp.ic_block; - int shift_input_ptr = typesize * (jcp.dilate_h + 1) * jcp.iw * jcp.ic_block; - - bool check_last_kh = (jcp.kh > 3); - bool pref_current_inp = (jcp.iw < 14 || jcp.iw > 28); - - int oi_ipref_t0 = get_ow_start(0, pad_l); - int ow_end_ipref = get_ow_end(ur_w, 0, pad_r); - - assert(jcp.oc % jcp.nb_oc_blocking == 0); - - auto kernel_offset = [=](int ocb, int ic, int ki) { - int blk_idx = ocb * jcp.nb_ic * jcp.kh * jcp.kw * jcp.kd + ki; - int blk_offset = blk_idx * jcp.oc_block * jcp.ic_block; - int ic_offset = ic * jcp.oc_block; - return typesize * (blk_offset + ic_offset); - }; - auto kernel_loads = [=](int ki, int ic, int kk) { - for 
(int ii = 0; ii < ker_load_number; ii++) { - int aux_kernel_offset = kernel_offset(kk, ic + ii, ki); - vmovups(vmm_ker(ii), - EVEX_compress_addr(aux_reg_ker, aux_kernel_offset)); - } - }; - auto prefetch_inp_next_kh = [&](int ki, int ki_start, int cnt0, int cnt1) { - if (cnt1 >= ker_load_number && cnt0 >= ker_load_number - && ki >= ki_start && oi_ipref_t0 < ow_end_ipref) { - int aux_inp_offset - = typesize - * ((oi_ipref_t0 * stride_w - pad_l) * ic_block - + (jcp.dilate_h + 1) * jcp.iw * ic_block); - prefetcht0(EVEX_compress_addr(aux_reg_inp, - aux_inp_offset)); - oi_ipref_t0++; - } - }; - - if (one_of(jcp.ndims, 3, 4)) { - mov(aux_reg_inp, reg_inp); - mov(aux_reg_ker, reg_ker); - mov(aux_reg_ker_prf, reg_ker_prf); - mov(aux_reg_inp_prf, reg_inp_prf); - } - - if (jcp.ndims == 5) { - push(reg_out_prf); - push(reg_out); - - mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]); - mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]); - mov(aux_reg_inp_d, reg_inp); - mov(aux_reg_inp_d_prf, reg_inp_prf); - mov(aux_reg_ker_d_prf, reg_ker_prf); - L(kd_label); - mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]); - } else { - mov(reg_kj, reg_kh); - } - if (jcp.ndims == 5) { - mov(aux_reg_inp, aux_reg_inp_d); - mov(aux_reg_ker, aux_reg_ker_d); - mov(aux_reg_ker_prf, aux_reg_ker_d_prf); - mov(aux_reg_inp_prf, aux_reg_inp_d_prf); - } - - align(16); - L(kh_label); - int kw = jcp.kw; - if (check_last_kh) { - for (int ki = 0; ki < kw; ki++) - for (int ic = 0; ic < ic_block; ic += 4) - for (int kk = 0; kk < jcp.nb_oc_blocking; kk++) { - bool last_kernel_loads = (kk == jcp.nb_oc_blocking - 1 - && ki == kw - 1 && (ic + 4) == ic_block); - - if (last_kernel_loads) { - cmp(reg_kj, 1); - je(last_iter_label, T_NEAR); - } - - kernel_loads(ki, ic, kk); - for (int oi = get_ow_start(ki, pad_l), prf_count_t1 = 0, - prf_count_t0 = 0; - oi < get_ow_end(ur_w, ki, pad_r); oi++) { - int aux_input_offset = typesize - * ((ki * (jcp.dilate_w + 1) + oi * stride_w - - pad_l) * ic_block - + ic); - v4fmaddps(vmm_out(oi, kk), vmm_ker(0), - EVEX_compress_addr(aux_reg_inp, aux_input_offset)); - - if (oi % 2) { - if (prf_count_t0 < 4) { - int aux_kernel_prf; - if (last_kernel_loads) - aux_kernel_prf= kernel_offset(0, - prf_count_t0 + ic + 4 - - ic_block, 0) + typesize * kw - * oc_block * ic_block; - else - aux_kernel_prf = kernel_offset(kk, ic + 4 - + prf_count_t0, ki); - mic_prefetcht0(EVEX_compress_addr(aux_reg_ker, - aux_kernel_prf)); - prf_count_t0++; - } else if (prf_count_t1 < 4) { - mic_prefetcht1(EVEX_compress_addr( - aux_reg_ker_prf, kernel_offset(kk, ic - + prf_count_t1, ki))); - prf_count_t1++; - } - } else - prefetch_inp_next_kh(ki, 2, prf_count_t0, - prf_count_t1); - } - - if (last_kernel_loads) { - jmp(loop_end_label, T_NEAR); - - L(last_iter_label); - - kernel_loads(ki, ic, kk); - for (int oi = get_ow_start(ki, pad_l), prf_count_t1 = 0, - prf_count_t0 = 0; - oi < get_ow_end(ur_w, ki, pad_r); oi++) { - int aux_input_offset = typesize - * ((ki * (jcp.dilate_w + 1) + oi * stride_w - - pad_l) * ic_block - + ic); - v4fmaddps(vmm_out(oi, kk), vmm_ker(0), - EVEX_compress_addr(aux_reg_inp, - aux_input_offset)); - if (oi % 2) { - if (prf_count_t0 < 4) { - mic_prefetcht0(EVEX_compress_addr( - aux_reg_ker_prf, kernel_offset(0, - prf_count_t0, 0))); - prf_count_t0++; - } else if (prf_count_t1 < 4) { - mic_prefetcht1(EVEX_compress_addr( - aux_reg_ker_prf, kernel_offset(kk, - ic + prf_count_t1, ki))); - prf_count_t1++; - } - } - } - L(loop_end_label); - } - } - } else { - for (int ki = 0; ki < kw; ki++) - for (int ic = 0; ic < ic_block; ic += 4) 
- for (int kk = 0; kk < jcp.nb_oc_blocking; kk++) { - kernel_loads(ki, ic, kk); - for (int oi = get_ow_start(ki, pad_l), - prf_count_t1 = 0, prf_count_t0 = 0; - oi < get_ow_end(ur_w, ki, pad_r); oi++) { - int aux_input_offset = typesize - * ((ki * (jcp.dilate_w + 1) + oi * stride_w - - pad_l) * ic_block + ic); - v4fmaddps(vmm_out(oi, kk), vmm_ker(0), - EVEX_compress_addr(aux_reg_inp, - aux_input_offset)); - - if (!is_owb_prefetching(jcp)) { - if ((oi % 2) && (prf_count_t1 < 4)) { - mic_prefetcht1(EVEX_compress_addr( - aux_reg_ker_prf, kernel_offset(kk, - ic + prf_count_t1, ki))); - prf_count_t1++; - } - } else { - if (!(ki == 0 && ic == 0) - && !(ki == kw-1 && ic == 0) && - (oi % 2) && (prf_count_t1 < 4) - ) { - mic_prefetcht0(EVEX_compress_addr( - aux_reg_ker, kernel_offset(kk, - ic + 4 + prf_count_t0, ki))); - prf_count_t0++; - } - } - if (!is_owb_prefetching(jcp)) { - if (pref_current_inp) { - if (ki == 0 && ic == 0 && kk == 0) - mic_prefetcht0(EVEX_compress_addr( - aux_reg_inp, - aux_input_offset + shift_input_ptr)); - } else { - if (ki == 1 && ic == 0 && kk == 0) - mic_prefetcht1(EVEX_compress_addr( - aux_reg_inp_prf, aux_input_offset)); - } - } else { - int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block; - int inp_shift - = jcp.typesize_in * ur_w * stride_w * inp_mult; - bool kk_pref_slot = kk ? oi % 2 : !(oi % 2); - if (ki == 0 && ic == 0 && kk_pref_slot) - mic_prefetcht1(EVEX_compress_addr( - aux_reg_inp, - aux_input_offset + inp_shift)); - - if (ki == kw - 1 && ic == 0 && kk_pref_slot) - mic_prefetcht0(EVEX_compress_addr( - aux_reg_inp, - aux_input_offset + inp_shift)); - } - } - } - } - - add(aux_reg_ker, shift_kernel_ptr); - add(aux_reg_inp, shift_input_ptr); - add(aux_reg_ker_prf, shift_kernel_ptr); - add(aux_reg_inp_prf, shift_input_ptr); - - dec(reg_kj); - cmp(reg_kj, 0); - jg(kh_label, T_NEAR); - - if (jcp.ndims == 5) { - add(aux_reg_inp_d, - typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * jcp.ic_block); - add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * jcp.oc_block - * jcp.ic_block); - add(aux_reg_inp_d_prf, - typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * jcp.ic_block); - add(aux_reg_ker_d_prf, typesize * jcp.kw * jcp.kh * jcp.oc_block - * jcp.ic_block); - - dec(reg_ki); - cmp(reg_ki, 0); - jg(kd_label, T_NEAR); - - pop(reg_out); - pop(reg_out_prf); - } -} - -template -void _jit_avx512_common_conv_fwd_kernel::compute_loop_fma(int ur_w, - int pad_l, int pad_r) -{ - bool prf_ker = true; - bool prf_inp = true; - int ih = jcp.ih; - int stride_w = jcp.stride_w; - int id = jcp.id; - int iw = jcp.iw; - int kw = jcp.kw; - int ic_block = jcp.ic_block; - int oc_block = jcp.oc_block; - int nb_oc_block = jcp.nb_oc_blocking; - Label kh_label, kd_label; - - int ker_pipeline_depth = 4; - assert(ker_reg_base_idx + ker_pipeline_depth <= 32); - assert(oc_block >= ker_pipeline_depth); - - int num_ker_loads = ic_block * nb_oc_block * kw; - int num_ker_prfs = prf_ker ? num_ker_loads : 0; - int num_inp_prfs = prf_inp ? - ur_w * nstl::min(kw, stride_w) + nstl::max(0, kw - stride_w) : - 0; - if (jcp.is_1stconv && prf_inp) { - num_inp_prfs = div_up(num_inp_prfs, jcp.simd_w) * ic_block; - } - int num_prfs = num_ker_prfs + num_inp_prfs; - int num_fmas = num_ker_loads * ur_w; - int prf_inst_spacing - = (prf_ker || prf_inp) ? nstl::max(1, num_fmas / num_prfs) : 1; - int prf_inst_trigger = (num_fmas % prf_inst_spacing) / 2; - int inp_mul = !jcp.is_1stconv ? 
ic_block : 1; - - if (one_of(jcp.ndims, 3, 4)) { - mov(aux_reg_inp, reg_inp); - mov(aux_reg_ker, reg_ker); - mov(aux_reg_inp_prf, reg_inp_prf); - mov(aux_reg_ker_prf, reg_ker_prf); - } - - size_t max_input_offset = (size_t)jcp.typesize_in * ic_block * iw * ih * id; - assert(reg_inp_prf == reg_long_offt); - if (max_input_offset > INT_MAX) push(reg_inp_prf); - - - if (jcp.ndims == 5) { - push(reg_out_prf); - push(reg_out); - - mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]); - mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]); - mov(aux_reg_inp_d, reg_inp); - mov(aux_reg_inp_d_prf, reg_inp_prf); - mov(aux_reg_ker_d_prf, reg_ker_prf); - - L(kd_label); - mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]); - } else { - mov(reg_kj, reg_kh); - } - - if (jcp.ndims == 5) { - mov(aux_reg_inp, aux_reg_inp_d); - mov(aux_reg_ker, aux_reg_ker_d); - mov(aux_reg_ker_prf, aux_reg_ker_d_prf); - mov(aux_reg_inp_prf, aux_reg_inp_d_prf); - } - - align(16); - L(kh_label); - { - int step = 0; - int ker_prfs = 0; - for (int ki = 0; ki < kw; ki++) { - for (int ic = 0; ic < ic_block; ic++) { - int aux_kernel_offset = 0; - if (step == 0) { - for (int i = 0; i < ker_pipeline_depth; i++) { - aux_kernel_offset = get_kernel_offset(ki, ic, 0, i); - vmovups(vmm_ker(i), EVEX_compress_addr( - aux_reg_ker, aux_kernel_offset)); - } - } else if (step < num_ker_loads - ker_pipeline_depth + 1) { - int load_offset = ker_pipeline_depth - 1; - int ker_load_reg_idx - = (step + load_offset) % ker_pipeline_depth; - aux_kernel_offset - = get_kernel_offset(ki, ic, 0, load_offset); - vmovups(vmm_ker(ker_load_reg_idx), - EVEX_compress_addr(aux_reg_ker, aux_kernel_offset)); - } - - bool ker_prf_inserted = false; - Vmm vmm_kernel = vmm_ker(step % ker_pipeline_depth); - int j_start = get_ow_start(ki, pad_l); - int j_end = get_ow_end(ur_w, ki, pad_r); - for (int j = j_start; j < j_end; j++) { - size_t aux_input_offset = get_input_offset(ki, ic, j, pad_l); - auto addr = EVEX_compress_addr_safe(aux_reg_inp, - aux_input_offset, reg_long_offt, true); - vfmadd231ps(vmm_out(j, 0), vmm_kernel, addr); - int fma_idx = step * ur_w + j; - int prf_slot_idx = fma_idx / prf_inst_spacing; - if (fma_idx % prf_inst_spacing == prf_inst_trigger) { - if (prf_ker && !ker_prf_inserted - && ker_prfs < num_ker_prfs) { - int ker_prf_offset - = jcp.typesize_in * ker_prfs * jcp.oc_block; - mic_prefetcht2(EVEX_compress_addr( - aux_reg_ker_prf, ker_prf_offset)); - ker_prf_inserted = true; - ker_prfs++; - } else if (prf_inp) { - int inp_prf_idx = prf_slot_idx - ker_prfs; - if (inp_prf_idx < num_inp_prfs) { - size_t inp_prf_stride = nstl::max(kw, stride_w); - size_t inp_prf_offset; - if (!jcp.is_1stconv) { - inp_prf_offset - = ic_block * jcp.typesize_in - * ((inp_prf_idx / kw) - * inp_prf_stride - + (inp_prf_idx % kw)); - } else { - size_t ic_prf_stride = - (size_t)jcp.typesize_in * iw * ih * id; - size_t iw_prf_stride - = jcp.typesize_in * jcp.simd_w; - inp_prf_offset = ((inp_prf_idx / ic_block) - * iw_prf_stride - + (inp_prf_idx % ic_block) - * ic_prf_stride); - } - mic_prefetcht0(EVEX_compress_addr_safe( - aux_reg_inp_prf, inp_prf_offset, - reg_long_offt)); - } - } - } - } - step++; - } - } - add(aux_reg_ker, jcp.typesize_in * kw * oc_block * ic_block); - if (prf_ker) - add(aux_reg_ker_prf, jcp.typesize_in * kw * oc_block * ic_block); - add(aux_reg_inp, jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mul); - if (prf_inp) - add(aux_reg_inp_prf, - jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mul); - dec(reg_kj); - cmp(reg_kj, 0); - jg(kh_label, T_NEAR); - } - - - if 
(jcp.ndims == 5) { - add(aux_reg_inp_d, - typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * inp_mul); - add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * jcp.oc_block - * jcp.ic_block); - add(aux_reg_inp_d_prf, - typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * inp_mul); - add(aux_reg_ker_d_prf, typesize * jcp.kw * jcp.kh * jcp.oc_block - * jcp.ic_block); - - dec(reg_ki); - cmp(reg_ki, 0); - jg(kd_label, T_NEAR); - - pop(reg_out); - pop(reg_out_prf); - } - if (max_input_offset > INT_MAX) pop(reg_inp_prf); -} - -template -void _jit_avx512_common_conv_fwd_kernel::compute_loop_fma_core(int ur_w, - int pad_l, int pad_r) -{ - int kw = jcp.kw; - int stride_w = jcp.stride_w; - int ic_block = jcp.ic_block; - int oc_block = jcp.oc_block; - int nb_oc_block = jcp.nb_oc_blocking; - Label kh_label, kd_label; - int shift_kernel_ptr = jcp.typesize_in * jcp.kw * jcp.oc_block - * jcp.ic_block; - int inp_mul = !jcp.is_1stconv ? ic_block : 1; - int shift_input_ptr = jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw - * inp_mul; - - - auto input_offset = [=](int oi, int ic, int ki) { - return (size_t)jcp.typesize_in - * ((size_t)(ki * (jcp.dilate_w + 1) + oi * stride_w - pad_l) - * inp_mul + (size_t)ic - * (!jcp.is_1stconv ? 1 : (size_t)jcp.iw * jcp.ih * jcp.id)); - }; - - if (one_of(jcp.ndims, 3, 4)) { - mov(aux_reg_inp, reg_inp); - mov(aux_reg_ker, reg_ker); - } - - if (jcp.ndims == 5) { - push(reg_out); - - mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]); - mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]); - mov(aux_reg_inp_d, reg_inp); - - L(kd_label); - mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]); - } else { - mov(reg_kj, reg_kh); - } - - if (jcp.ndims == 5) { - mov(aux_reg_inp, aux_reg_inp_d); - mov(aux_reg_ker, aux_reg_ker_d); - } - - L(kh_label); - { - for (int ki = 0; ki < kw; ki++) { - int jj_start = get_ow_start(ki, pad_l); - int jj_end = get_ow_end(ur_w, ki, pad_r); - for (int ic = 0; ic < ic_block; ic++) { - if (jcp.kernel_kind == expl_bcast) { - for (int jj = jj_start; jj < jj_end; jj++) { - size_t aux_input_offset = input_offset(jj, ic, ki); - vbroadcastss(vmm_inp(jj, nb_oc_block), - EVEX_compress_addr_safe(aux_reg_inp, - aux_input_offset, reg_long_offt)); - } - } - for (int ii = 0; ii < nb_oc_block; ii++) { - int aux_kernel_offset = jcp.typesize_in - * (ii * jcp.nb_ic * jcp.kh * jcp.kw * jcp.kd * ic_block - * oc_block + ki * ic_block * oc_block + ic * oc_block); - if (jj_end - jj_start > 0) - vmovups(vmm_wei, EVEX_compress_addr(aux_reg_ker, - aux_kernel_offset)); - for (int jj = jj_start; jj < jj_end; jj++) - if (jcp.kernel_kind == expl_bcast) - vfmadd231ps(vmm_out(jj, ii), - vmm_inp(jj, nb_oc_block), vmm_wei); - else { - size_t aux_input_offset = input_offset(jj, ic, ki); - vfmadd231ps(vmm_out(jj, ii), vmm_wei, - EVEX_compress_addr_safe(aux_reg_inp, - aux_input_offset, reg_long_offt, true)); - } - } - } - } - add(aux_reg_ker, shift_kernel_ptr); - add(aux_reg_inp, shift_input_ptr); - dec(reg_kj); - cmp(reg_kj, 0); - jg(kh_label, T_NEAR); - } - - if (jcp.ndims == 5) { - add(aux_reg_inp_d, - typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * inp_mul); - add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * jcp.oc_block - * jcp.ic_block); - - dec(reg_ki); - cmp(reg_ki, 0); - jg(kd_label, T_NEAR); - - pop(reg_out); - } -} - -template -void _jit_avx512_common_conv_fwd_kernel::compute_loop(int ur_w, - int pad_l, int pad_r) -{ - if (jcp.ndims == 5) push(reg_oi); - - prepare_output(ur_w); - - Label skip_compute_loop; - if (jcp.ndims == 5) { - if ((jcp.dilate_d >= jcp.id) - || (jcp.kd - 1) * (jcp.dilate_d + 1) < 
nstl::max(jcp.f_pad, jcp.back_pad)) { - mov(reg_kj, ptr[param1 + GET_OFF(kd_padding)]); - cmp(reg_kj, 0); - je(skip_compute_loop, T_NEAR); - } - } - if ((jcp.dilate_h >= jcp.ih) - || (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) { - mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]); - cmp(reg_kj, 0); - je(skip_compute_loop, T_NEAR); - } - - if (jcp.ver == ver_4fma) - if(jcp.is_1stconv) - compute_loop_4fma_1st(ur_w, pad_l, pad_r); - else - compute_loop_4fma(ur_w, pad_l, pad_r); - else if (jcp.ver == ver_fma) - if ((jcp.is_1stconv && jcp.kernel_kind != expl_bcast) - || mayiuse(avx512_mic)) - compute_loop_fma(ur_w, pad_l, pad_r); - else - if (jcp.kernel_kind == embd_bcast && jcp.nb_oc_blocking == 1) - compute_loop_fma(ur_w, pad_l, pad_r); - else - compute_loop_fma_core(ur_w, pad_l, pad_r); - else - assert(!"unknown convolution version"); - - L(skip_compute_loop); - store_output(ur_w); - if (jcp.ndims == 5) pop(reg_oi); -} - -template -void _jit_avx512_common_conv_fwd_kernel::generate() -{ - int iw = jcp.iw; - int ow = jcp.ow; - int ow_block = jcp.ow_block; - int nb_ow = jcp.nb_ow; - int kw = jcp.kw; - int l_pad = jcp.l_pad; - int ur_w = jcp.ur_w; - int ur_w_tail = jcp.ur_w_tail; - int dilate_w = jcp.dilate_w + 1; - int stride_w = jcp.stride_w; - - int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block; - int inp_shift_pad = jcp.typesize_in * (ur_w * stride_w - l_pad) * inp_mult; - int inp_shift = jcp.typesize_in * ur_w * stride_w * inp_mult; - int inp_shift_pad_second_block = -1 * jcp.typesize_in * l_pad * inp_mult; - int out_shift = jcp.typesize_out * ur_w * jcp.oc_block; - - preamble(); - mov(reg_inp, ptr[param1 + GET_OFF(src)]); - mov(reg_out, ptr[param1 + GET_OFF(dst)]); - mov(reg_ker, ptr[param1 + GET_OFF(filt)]); - mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]); - mov(reg_kh, ptr[param1 + GET_OFF(kh_padding)]); - - int r_pad = nstl::max( - 0, (ow - 1) * stride_w + (kw - 1) * dilate_w - (iw + l_pad - 1)); - int n_oi = ow / ur_w; - int r_pad1 = (ur_w * n_oi - 1) * stride_w + (kw - 1) * dilate_w - - (iw + l_pad - 1); - - if (!is_ow_threading_on(jcp)) { - // ow is being processed as a whole - with left and right paddings - if (r_pad1 > 0) n_oi--; - - if (ow == ur_w) { - mov(reg_inp_prf, ptr[param1 + GET_OFF(src_prf)]); - mov(reg_out_prf, ptr[param1 + GET_OFF(dst_prf)]); - compute_loop(ur_w, l_pad, r_pad); - } else { - mov(reg_inp_prf, reg_inp); - mov(reg_out_prf, reg_out); - if (n_oi == 0) { - add(reg_inp_prf, inp_shift_pad); - add(reg_out_prf, out_shift); - compute_loop(ur_w, l_pad, r_pad1); - add(reg_inp, inp_shift_pad); - add(reg_out, out_shift); - if (ur_w_tail != 0) { - add(reg_inp_prf, inp_shift); - add(reg_out_prf, out_shift); - compute_loop(ur_w_tail, 0, r_pad); - } - } else { - xor_(reg_oi, reg_oi); - if (l_pad > 0) { - add(reg_inp_prf, inp_shift_pad); - add(reg_out_prf, out_shift); - compute_loop(ur_w, l_pad, 0); - add(reg_inp, inp_shift_pad); - add(reg_out, out_shift); - inc(reg_oi); - } - if ((l_pad <= 0 && n_oi > 0) || (l_pad > 0 && n_oi > 1)) { - Label ow_loop_label; - L(ow_loop_label); - { - add(reg_inp_prf, inp_shift); - add(reg_out_prf, out_shift); - compute_loop(ur_w, 0, 0); - add(reg_inp, inp_shift); - add(reg_out, out_shift); - inc(reg_oi); - cmp(reg_oi, n_oi); - jl(ow_loop_label, T_NEAR); - } - } - if (r_pad1 > 0) { - add(reg_inp_prf, inp_shift); - add(reg_out_prf, out_shift); - compute_loop(ur_w, 0, r_pad1); - add(reg_inp, inp_shift); - add(reg_out, out_shift); - } - if (ur_w_tail != 0) { - add(reg_inp_prf, inp_shift); - add(reg_out_prf, out_shift); - 
compute_loop(ur_w_tail, 0, r_pad); - } - } - } - } else { - // ow block is only processed. - // Number of block is passed as parameter owb, - // and padding processing depends on this number. - - Label end_label, last_oi_label, middle_ow_blocks_label, tail_label; - Label oi_loop_label, oi_loop_start_label, oi_loop_end_label; - - assert(ow_block % ur_w == 0); - int n_oi_not_last_ow_block = ow_block / ur_w; - // to simplify code (and general regs usage), - // size of ow block must be >= 2 * ur_w - assert(n_oi_not_last_ow_block > 1); - int n_oi_next_last_ow_block = n_oi_not_last_ow_block; - int n_oi_first_ow_block = n_oi_not_last_ow_block; - - int n_oi_last_ow_block = (ow - ow_block * (nb_ow-1)) / ur_w; - - // prepare right padding - bool next_last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block == 0; - bool first_ow_block_padded = next_last_ow_block_padded && jcp.nb_ow == 2; - bool last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block > 0; - - if (last_ow_block_padded) n_oi_last_ow_block--; - else if (first_ow_block_padded) n_oi_first_ow_block--; - else if (next_last_ow_block_padded) n_oi_next_last_ow_block--; - - mov(reg_owb, ptr[param1 + GET_OFF(owb)]); - cmp(reg_owb, 0); // is that the first ow-block ? - jg(middle_ow_blocks_label, T_NEAR); - - // the first ow block, compute left padding - - mov(reg_oi, n_oi_first_ow_block); - mov(reg_inp_prf, reg_inp); - mov(reg_out_prf, reg_out); - - if (l_pad > 0) { - mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]); - add(reg_inp_prf, inp_shift_pad); - add(reg_out_prf, out_shift); - compute_loop(ur_w, l_pad, 0); - add(reg_inp, inp_shift_pad); - add(reg_out, out_shift); - dec(reg_oi); - } - jmp(oi_loop_label, T_NEAR); - - // middle or last ow block entry - - L(middle_ow_blocks_label); - - if (l_pad > 0) { - // just to consider left padding, not compute - add(reg_inp, inp_shift_pad_second_block); - add(reg_inp_prf, inp_shift_pad_second_block); - } - - // set number of iteration for oi-loop - cmp(reg_owb, jcp.nb_ow - 1); // last ow-block ? - mov(reg_oi, n_oi_last_ow_block); - je(oi_loop_label, T_NEAR); - cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ? - mov(reg_oi, n_oi_next_last_ow_block); - je(oi_loop_label, T_NEAR); - mov(reg_oi, n_oi_not_last_ow_block); // other middle ow-blocks - - // oi loop w/o padding - L(oi_loop_label); - mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]); - L(oi_loop_start_label); - cmp(reg_oi, 0); - jle(oi_loop_end_label, T_NEAR); - - add(reg_inp_prf, inp_shift); - add(reg_out_prf, out_shift); - compute_loop(ur_w, 0, 0); - add(reg_inp, inp_shift); - add(reg_out, out_shift); - dec(reg_oi); - jmp(oi_loop_start_label, T_NEAR); - L(oi_loop_end_label); - - mov(reg_owb, ptr[param1 + GET_OFF(owb)]); - - cmp(reg_owb, 0); // first ow-block ? - if (first_ow_block_padded) { - je(last_oi_label, T_NEAR); - } else { - je(end_label, T_NEAR); - } - cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ? - jl(end_label, T_NEAR); - if (next_last_ow_block_padded) { - je(last_oi_label, T_NEAR); - } else { - je(end_label, T_NEAR); - } - // that is last block - if (!last_ow_block_padded) { - jmp(tail_label, T_NEAR); - } - - // last oi block with right padding - L(last_oi_label); - mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]); - add(reg_inp_prf, inp_shift); - add(reg_out_prf, out_shift); - compute_loop(ur_w, 0, r_pad1); - add(reg_inp, inp_shift); - add(reg_out, out_shift); - - mov(reg_owb, ptr[param1 + GET_OFF(owb)]); - cmp(reg_owb, jcp.nb_ow - 1); // last ow_block? 
- jl(end_label, T_NEAR); - - L(tail_label); - mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]); - if (ur_w_tail != 0) { - add(reg_inp_prf, inp_shift); - add(reg_out_prf, out_shift); - compute_loop(ur_w_tail, 0, r_pad); - } - L(end_label); - } - postamble(); - - if (jcp.with_eltwise) - eltwise_injector_->prepare_table(); -} - -bool jit_avx512_common_conv_fwd_kernel::post_ops_ok( - jit_conv_conf_t &jcp, const primitive_attr_t &attr) { - const auto &p = attr.post_ops_; - - auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; - auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; - - switch (p.len_) { - case 0: return true; // no post_ops - case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise - case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise - default: return false; - } - - return false; -} - -status_t jit_avx512_common_conv_fwd_kernel::init_conf( - jit_conv_conf_t &jcp, const convolution_desc_t &cd, - memory_desc_t &src_md, memory_desc_t &weights_md, - memory_desc_t &dst_md, memory_desc_t &bias_md, - const primitive_attr_t &attr, int nthreads) -{ - using namespace prop_kind; - - if (!mayiuse(avx512_common)) - return status::unimplemented; - - const memory_desc_wrapper src_d(&src_md); - const memory_desc_wrapper weights_d(&weights_md); - const memory_desc_wrapper dst_d(&dst_md); - const memory_desc_wrapper bias_d(&bias_md); - - const int regs = 28; - const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; - int ndims = src_d.ndims(); - - jcp = zero(); - jcp.ndims = ndims; - jcp.prop_kind = cd.prop_kind; - jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; - jcp.mb = src_d.dims()[0]; - jcp.oc = dst_d.dims()[1] / jcp.ngroups; - jcp.oc_without_padding = jcp.oc; - jcp.ic = src_d.dims()[1] / jcp.ngroups; - jcp.id = (ndims == 5) ? src_d.dims()[2] : 1; - jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims-2]; - jcp.iw = src_d.dims()[ndims-1]; - jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1; - jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[ndims-2]; - jcp.ow = dst_d.dims()[ndims-1]; - jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1; - jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims-2]; - jcp.kw = weights_d.dims()[with_groups + ndims-1]; - jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0; - jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4]; - jcp.l_pad = cd.padding[0][ndims-3]; - jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1; - jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4]; - jcp.stride_w = cd.strides[ndims-3]; - - jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0; - jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4]; - jcp.dilate_w = cd.dilates[ndims-3]; - - jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1) - - (jcp.ih + jcp.t_pad - 1); - jcp.back_pad = (jcp.od - 1) * jcp.stride_d - + (jcp.kd - 1) * (jcp.dilate_d + 1) - (jcp.id + jcp.f_pad - 1); - - jcp.is_1stconv = is_1stconv(jcp); - - bool ok_to_pad_channels = true - && jcp.ngroups == 1 - && src_d.data_type() == data_type::f32; - - const int full_simd_w = cpu_isa_traits::vlen / sizeof(float); - jcp.simd_w = full_simd_w; - bool ok_to_try_xmm = true - && mayiuse(avx512_core) - && src_d.data_type() == data_type::f32 - && !jcp.is_1stconv - && !ok_to_pad_channels - && (jcp.ic % jcp.simd_w != 0 || jcp.oc % jcp.simd_w != 0) - && (jcp.ic % 8 != 0 || jcp.oc % 8 != 0); - if (ok_to_try_xmm) - jcp.simd_w = 4; - - jcp.oc_block = jcp.simd_w; - jcp.ic_block = jcp.is_1stconv ? 
jcp.ic : jcp.simd_w; - jcp.aligned_threads = 0; - - if (ok_to_pad_channels) { - jcp.oc = rnd_up(jcp.oc, jcp.oc_block); - jcp.ic = rnd_up(jcp.ic, jcp.ic_block); - } - bool args_ok = true - && jcp.oc % jcp.oc_block == 0 - && jcp.ic % jcp.ic_block == 0; - if (!args_ok) - return status::unimplemented; - - if (!post_ops_ok(jcp, attr)) - return status::unimplemented; - - const auto &p = attr.post_ops_; - jcp.with_sum = p.find(primitive_kind::sum) != -1; - const int eltwise_ind = p.find(primitive_kind::eltwise); - jcp.with_eltwise = eltwise_ind != -1; - if (jcp.with_eltwise) { - jcp.eltwise = p.entry_[eltwise_ind].eltwise; - if (dst_d.data_type() == data_type::s32) return status::unimplemented; - } - - auto src_tag = jcp.is_1stconv - ? pick(ndims - 3, ncw, nchw, ncdhw) - : ((jcp.simd_w == 4) - ? pick(ndims - 3, nCw4c, nChw4c, nCdhw4c) - : pick(ndims - 3, nCw16c, nChw16c, nCdhw16c)); - auto dst_tag = (jcp.simd_w == 4) - ? pick(ndims - 3, nCw4c, nChw4c, nCdhw4c) - : pick(ndims - 3, nCw16c, nChw16c, nCdhw16c); - auto wei_tag = with_groups - ? ((jcp.simd_w == 4) - ? pick(ndims - 3, gOIw4i4o, gOIhw4i4o, gOIdhw4i4o) - : pick(ndims - 3, gOIw16i16o, gOIhw16i16o, gOIdhw16i16o)) - : ((jcp.simd_w == 4) - ? pick(ndims - 3, OIw4i4o, OIhw4i4o, OIdhw4i4o) - : pick(ndims - 3, OIw16i16o, OIhw16i16o, OIdhw16i16o)); - - if (src_d.format_kind() == format_kind::any) { - CHECK(memory_desc_init_by_tag(src_md, src_tag)); - jcp.src_tag = src_tag; - } else { - jcp.src_tag = src_d.matches_one_of_tag(src_tag); - } - if (jcp.src_tag != src_tag) - return status::unimplemented; - - if (dst_d.format_kind() == format_kind::any) { - CHECK(memory_desc_init_by_tag(dst_md, dst_tag)); - jcp.dst_tag = dst_tag; - } else { - jcp.dst_tag = dst_d.matches_one_of_tag(dst_tag); - } - if (jcp.dst_tag != dst_tag) - return status::unimplemented; - - jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; - if (jcp.with_bias) { - if (bias_d.format_kind() == format_kind::any) - CHECK(memory_desc_init_by_tag(bias_md, x)); - } - - if (mayiuse(avx512_common) && - src_d.data_type() == data_type::f32 - && weights_d.data_type() == data_type::f32 - && dst_d.data_type() == data_type::f32) { - jcp.ver = ver_fma; - jcp.typesize_in = sizeof(float); - jcp.typesize_out = sizeof(float); - if (mayiuse(avx512_mic_4ops)) - jcp.ver = ver_4fma; - - if (jcp.is_1stconv) { - // TODO: fix & remove constraints below - bool not_for_4fma - = IMPLICATION(everyone_is(0, jcp.l_pad, jcp.t_pad), - nstl::max(jcp.kw, jcp.kh) < 7); - bool is_dilated - = !everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w); - if (one_of(true, not_for_4fma, is_dilated)) - jcp.ver = ver_fma; - if (jcp.ver == ver_4fma) { - wei_tag = with_groups - ? ((jcp.simd_w == 4) - ? pick(ndims - 3, gOiw4o, gOihw4o, gOidhw4o) - : pick(ndims - 3, gOiw16o, gOihw16o, gOidhw16o)) - : ((jcp.simd_w == 4) - ? pick(ndims - 3, Oiw4o, Oihw4o, Oidhw4o) - : pick(ndims - 3, Oiw16o, Oihw16o, Oidhw16o)); - } else { - wei_tag = with_groups - ? ((jcp.simd_w == 4) - ? pick(ndims - 3, gOwi4o, gOhwi4o, gOdhwi4o) - : pick(ndims - 3, gOwi16o, gOhwi16o, gOdhwi16o)) - : ((jcp.simd_w == 4) - ? 
pick(ndims - 3, Owi4o, Ohwi4o, Odhwi4o) - : pick(ndims - 3, Owi16o, Ohwi16o, Odhwi16o)); - } - } - } else { - return status::unimplemented; - } - - if (weights_d.format_kind() == format_kind::any) { - CHECK(memory_desc_init_by_tag(weights_md, wei_tag)); - jcp.wei_tag = wei_tag; - } else { - jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); - } - if (jcp.wei_tag != wei_tag) - return status::unimplemented; - - if (jcp.is_1stconv) { - jcp.ur_w = nstl::min(jcp.ow, regs); - } else { - // avx512_core guard - just to avoid possible regression for other archs - if (jcp.ver == ver_fma && mayiuse(avx512_core)) { - jcp.ur_w = nstl::min(jcp.ow, regs); - } else { - for (int ur_w = regs; ur_w > 0; --ur_w) { - if (jcp.ow % ur_w == 0) { - jcp.ur_w = ur_w; - break; - } - } - } - if ((ndims == 5 && jcp.ur_w <= 8) || (jcp.ur_w <= 1)) { - jcp.ur_w = nstl::min(jcp.ow, regs); - } - } - // TODO (Tanya): currently applied to Segnet convolutions only. - // Need to try for other topologies - if (jcp.ow > 150 && jcp.ur_w < regs/2) - jcp.ur_w = regs; - - int n_oi = (jcp.ow / jcp.ur_w); - int r_pad = (jcp.ur_w * n_oi - 1) * jcp.stride_w - + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1); - if (jcp.l_pad > 0 && r_pad > 0) - n_oi--; - - bool large_code_size = jcp.ur_w != jcp.ow && jcp.l_pad > 0 && r_pad > 0 - && ((jcp.l_pad <= 0 && n_oi > 0) || (jcp.l_pad > 0 && n_oi > 1)); - if (large_code_size) { - const int max_code_size = 24 * 1024; - const int num_ops_per_reg = 6 + jcp.ic_block * jcp.kw; - int mult = 1; - if (jcp.l_pad > 0) mult += 1; - if (r_pad > 0) mult += 1; - for (int ur_w = jcp.ur_w; ur_w > regs/2; --ur_w) { - if (ur_w * mult * num_ops_per_reg * 9.0 < max_code_size) { - jcp.ur_w = ur_w; - break; - } - } - } - - /* Grouped channel offset to support 'non-blocked data' format for - * convolution sizes with '(input_channel / ngroups) < simd' */ - jcp.nonblk_group_off - = (jcp.ngroups > 1 && one_of(jcp.src_tag, ncw, nchw, ncdhw)) ? 
- jcp.ic : - 1; - - jcp.nb_ic = jcp.ic / jcp.ic_block; - jcp.nb_oc = jcp.oc / jcp.oc_block; - jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1; - - auto is_ow_threading_applicable = [=]() { - return (true && !jcp.is_1stconv && one_of(jcp.ndims, 3, 4) - && IMPLICATION(mayiuse(avx512_mic), - jcp.ver == ver_4fma - && IMPLICATION(jcp.mb != 1, - jcp.ih == 1 && jcp.kh == 1))); - }; - - if (jcp.ver == ver_4fma && !jcp.is_1stconv) { - if ((jcp.kw <= 5 && jcp.kh <= 5 && jcp.kw == jcp.kh && jcp.ow <= 8 - && jcp.oh <= 8 && jcp.ow == jcp.oh) - || (jcp.stride_h != 1 && jcp.ur_w < jcp.ow)) { - if (jcp.nb_oc % 2 == 0) { - jcp.nb_oc_blocking = 2; - jcp.ur_w = nstl::min(jcp.ow, regs / jcp.nb_oc_blocking); - } - } else { - for (int i = jcp.nb_oc; i > 0; i--) - if (i * jcp.ur_w <= regs && jcp.nb_oc % i == 0) { - jcp.nb_oc_blocking = i; - break; - } - } - if (jcp.ver == ver_4fma && is_ow_threading_applicable()) { - if (jcp.nb_oc % 2 == 0 && jcp.ur_w < jcp.ow - && jcp.ow != 2 * jcp.ur_w) { - jcp.nb_oc_blocking = 2; - jcp.ur_w = nstl::min(jcp.ow, regs / jcp.nb_oc_blocking); - } - } - } - - jcp.ow_block = jcp.ow; - - auto get_thr_eff = [=](int nb_oc_blocking, int ow_block) { - int nb_ow = div_up(jcp.ow, ow_block); - int nb_oc_chunks = div_up(jcp.nb_oc, nb_oc_blocking); - int work_amount = jcp.mb * jcp.oh * nb_oc_chunks * nb_ow; - float disbalance = (float)jcp.ow / rnd_up(jcp.ow, ow_block); - float thr_eff = disbalance * (float)work_amount - / rnd_up(work_amount, nthreads); - return thr_eff; - }; - - auto get_ow_block = [=](int nb_oc_blocking, int ur_w, float &eff) { - int res_ow_block = jcp.ow; - eff = get_thr_eff(nb_oc_blocking, res_ow_block); - if (!is_ow_threading_applicable()) - return res_ow_block; - - int L2_part = (get_cache_size(2) * 7 / 8) / typesize; - if (jcp.ver == ver_4fma) - L2_part /= 2; - int size_src_chunk = jcp.ic_block * ur_w * jcp.kh; - int size_dst_chunk = jcp.oc_block * nb_oc_blocking * ur_w; - int size_wei_chunk = jcp.oc_block * nb_oc_blocking * jcp.ic_block - * jcp.kw * jcp.kh; - int nurw_cache = (L2_part - 2 * size_wei_chunk) - / (2 * size_dst_chunk + 2 * size_src_chunk); - // current design of generate() requires ow_block >= 2 * ur_w - int ow_block_cache = ur_w * nstl::max(2, nurw_cache); - - int ow_block_thr = ow_block_cache; - eff = get_thr_eff(nb_oc_blocking, ow_block_thr); - - int max_nb_ow = div_up(jcp.ow, 2 * ur_w); - int start_nb_ow = div_up(jcp.ow, ow_block_thr); - for (int nb_ow = start_nb_ow; nb_ow <= max_nb_ow; nb_ow++) { - int ow_block - = nstl::min(rnd_up(div_up(jcp.ow, nb_ow), ur_w), jcp.ow); - float eff_threshold = (jcp.ver == ver_4fma) ? 0.8f : 0.9f; - if (ow_block < nb_oc_blocking * jcp.oc_block && eff > eff_threshold) - break; - if (div_up(jcp.ow, ow_block) != nb_ow) - continue; - float thr_eff = get_thr_eff(nb_oc_blocking, ow_block); - float eff_step = (jcp.ver == ver_4fma) ? 1.1f : 1.f; - if (ow_block >= 2 * ur_w && thr_eff > eff_step * eff) { - ow_block_thr = ow_block; - eff = thr_eff; - } - eff_threshold = (jcp.ver == ver_4fma) ? 
0.9f : 0.98f; - if (eff > eff_threshold) - break; - } - res_ow_block = nstl::min(jcp.ow, nstl::max(2 * ur_w, ow_block_thr)); - eff = get_thr_eff(nb_oc_blocking, res_ow_block); - return res_ow_block; - }; - - - if (jcp.ver == ver_fma && mayiuse(avx512_core)) { - int try_nb_oc_blocking = 2; - unsigned int ker_inp_size = typesize * div_up(jcp.iw, jcp.stride_w) - * jcp.ic_block * jcp.kh * jcp.kd; - unsigned int ker_out_size = typesize * jcp.ow * jcp.oc_block - * try_nb_oc_blocking; - unsigned int ker_wei_size = typesize * jcp.kh * jcp.kw * jcp.ic_block - * jcp.oc_block * try_nb_oc_blocking * jcp.kd; - unsigned int ker_total_size = ker_inp_size + ker_out_size - + ker_wei_size; - - bool embd_bcast_condition = true - && (jcp.kw == 3 && jcp.ow <= 28 && ker_total_size < L1_cache_size) - && !(jcp.kw == 3 && jcp.ow == 13 && jcp.ic >= 192) - && !(jcp.kw == 3 && jcp.ow == 28 && jcp.ic >= 512); - - if (jcp.mb == 1) { - unsigned int inp_size = jcp.mb * div_up(jcp.ih, jcp.stride_h) - * div_up(jcp.iw, jcp.stride_w) * jcp.ic; - unsigned int wei_size = jcp.ic * jcp.oc * jcp.kh * jcp.kw; - - // Estimate whether we need to limit the number of threads - // and calculate this number. Includes some heuristic. - int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; - int work_amount = jcp.mb * jcp.ngroups * oc_chunks * jcp.oh; - int job_size_min = work_amount / nthreads; - int job_size_max = div_up(work_amount, nthreads); - int ch_max = rnd_up(jcp.oh, job_size_max); - int ch_min = (job_size_min == 0) - ? jcp.oh - : rnd_up(jcp.oh, job_size_min); - bool not_aligned_max = ch_max % jcp.oh != 0 && ch_max / jcp.oh < 2 - && (jcp.oh != 8 || ch_max / jcp.oh > 1); - bool not_aligned_min = ch_min % jcp.oh != 0 && ch_min / jcp.oh < 2 - && (jcp.oh != 8 || ch_min / jcp.oh > 1); - bool eligible_case = (jcp.stride_h == 1 && jcp.stride_w == 1) - || nthreads > oc_chunks; - if (jcp.loop_order == loop_cgn && oc_chunks > 1 && nthreads > 1 - && wei_size / inp_size > 24 - && (not_aligned_max || not_aligned_min) - && eligible_case) { - // Try to find nthreads > mkldnn_get_max_threads() / 2 such - // that oc_chunks is a multiple of nthreads, or nthreads is a - // multiple of oc_chunks. Otherwise, keep default value. - // TODO: implement a task-based alternative without throttling. 
- jcp.aligned_threads = nthreads; - for (int i = nthreads; i > nthreads / 2; i--) { - if (oc_chunks % i == 0 || i % oc_chunks == 0) { - jcp.aligned_threads = i; - break; - } - } - } - } - - if (jcp.kw > 3 - || (jcp.stride_w == 1 && jcp.stride_h == 1 - && embd_bcast_condition) - || ((jcp.stride_w != 1 || jcp.stride_h != 1) - && ((jcp.mb <= 16 && (jcp.oc <= 192 || jcp.oh <= 10) - && embd_bcast_condition))) - || (jcp.mb == 1 - && (jcp.ur_w >= jcp.ow || jcp.is_1stconv - || (jcp.ow <= 147 && jcp.oc <= 96)))) { - jcp.kernel_kind = embd_bcast; - jcp.ur_w = nstl::min(jcp.ow, regs); - jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1; - if (ker_total_size < L1_cache_size && jcp.ow <= 8 && jcp.kh <= 3 - && jcp.kw <= 3 && jcp.nb_oc % try_nb_oc_blocking == 0 - && IMPLICATION(jcp.is_1stconv, jcp.mb == 1) - && IMPLICATION(jcp.mb == 1, jcp.ur_w < jcp.ow)) { - jcp.nb_oc_blocking = try_nb_oc_blocking; - jcp.ur_w = nstl::min(jcp.ow, 31 / (jcp.nb_oc_blocking + 1)); - } - } else { - jcp.kernel_kind = expl_bcast; - jcp.nb_ic_blocking = 1; - if (IMPLICATION(jcp.is_1stconv, jcp.mb > 1)) { - float best_thr_eff = 0.f; - int best_nb_oc_blocking = 1; - for (int i = nstl::min(jcp.nb_oc, 5); i > 0; i--) { - if (jcp.nb_oc % i == 0) { - float thr_eff; - int ur_w = nstl::min(jcp.ow, 31 / (i + 1)); - get_ow_block(i, ur_w, thr_eff); - if (thr_eff > 1.05f * best_thr_eff) { - best_nb_oc_blocking = i; - best_thr_eff = thr_eff; - } - } - } - jcp.nb_oc_blocking = best_nb_oc_blocking; - jcp.ur_w = nstl::min(jcp.ow, 31 / (jcp.nb_oc_blocking + 1)); - } - } - } - - jcp.ur_w_tail = jcp.ow % jcp.ur_w; - - args_ok = true - && jcp.l_pad <= jcp.ur_w - && jcp.ic <= src_d.padded_dims()[1] - && jcp.oc <= dst_d.padded_dims()[1] - && jcp.ic <= weights_d.padded_dims()[with_groups + 1] - && jcp.oc <= weights_d.padded_dims()[with_groups + 0]; - if (!args_ok) - return status::unimplemented; - - int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w - + (jcp.kw - 1) * (jcp.dilate_w + 1) - - (jcp.iw + jcp.l_pad - 1)); - if (r_pad_no_tail > jcp.ur_w) - return status::unimplemented; - - pick_loop_order(jcp); - - jcp.nb_ic_L2 = jcp.nb_ic; - - float thr_eff; - jcp.ow_block = get_ow_block(jcp.nb_oc_blocking, jcp.ur_w, thr_eff); - jcp.nb_ow = div_up(jcp.ow, jcp.ow_block); - - const int L2_size = get_cache_size(2, true) / sizeof(float); - // Source and output data needs to fit in L2, - // leaving some space for weights and prefetching. - int h_L2 = int(((0.6f * L2_size) / jcp.simd_w - - nstl::min(0, jcp.kh - jcp.stride_h) * jcp.iw) - / (jcp.stride_h * jcp.iw + jcp.ow)); - jcp.h_blocking = nstl::max(1, nstl::min(jcp.oh, h_L2)); - - if (jcp.ver == ver_4fma) { - if (!is_ow_threading_on(jcp)) { - for (int divf = 2, temp_nb = jcp.nb_ic_L2; divf <= jcp.nb_ic; - divf++) { - size_t l2_src - = (size_t)jcp.iw * jcp.ic_block * jcp.ih * temp_nb * jcp.id; - size_t l2_dst = (size_t)jcp.ow * jcp.oc_block * jcp.nb_oc_blocking - * jcp.oh * jcp.od; - size_t l2_filt = (size_t)jcp.kw * jcp.oc_block * jcp.ic_block - * jcp.kh * jcp.nb_oc_blocking * temp_nb * jcp.kd; - if (4 * (l2_src + l2_dst + l2_filt) > KNx_L2_EFFECTIVE_CAPACITY) { - if (jcp.kh == 3 && jcp.oh == 7) { - jcp.nb_ic_L2 = 1; - break; - } - temp_nb = (jcp.nb_ic_L2 % divf == 0 ? 
jcp.nb_ic_L2 / divf - : jcp.nb_ic_L2); - } else { - jcp.nb_ic_L2 = temp_nb; - break; - } - } - } else if (jcp.ic > 64) { - jcp.nb_ic_L2 = 2; /* according to performance data*/ - } - } - - return status::success; -} - -void jit_avx512_common_conv_fwd_kernel::init_scratchpad( - memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { - if (jcp.with_bias && jcp.oc != jcp.oc_without_padding) - scratchpad.book(key_conv_padded_bias, jcp.typesize_out * jcp.oc); -} - -void jit_avx512_common_conv_bwd_data_kernel_f32::prepare_output(int ur_w) -{ - for (int k = 0; k < jcp.nb_ic_blocking; k++) { - for (int j = 0; j < ur_w; j++) { - Zmm zmm = zmm_out(j, k); - vpxord(zmm, zmm, zmm); - size_t aux_src_offset - = (size_t)typesize * ((size_t)k * jcp.ih * jcp.iw * jcp.id + j) - * jcp.ic_block; - mic_prefetcht1(EVEX_compress_addr_safe(reg_src_prf, aux_src_offset, - reg_long_offt)); - } - } -} - -void jit_avx512_common_conv_bwd_data_kernel_f32::store_output(int ur_w) -{ - Label no_update_label; - - mov(reg_channel, ptr[param + GET_OFF(channel)]); - cmp(reg_channel, 0); - je(no_update_label, T_NEAR); - for (int k = 0; k < jcp.nb_ic_blocking; k++) { - for (int j = 0; j < ur_w; j++) { - Zmm zmm = zmm_out(j, k); - size_t aux_src_offset = (size_t)typesize - * ((size_t)k * jcp.ih * jcp.iw * jcp.id + j) * jcp.ic_block; - vaddps(zmm, EVEX_compress_addr_safe(reg_src, aux_src_offset, - reg_long_offt)); - } - } - - L(no_update_label); - for (int k = 0; k < jcp.nb_ic_blocking; k++) { - for (int j = 0; j < ur_w; j++) { - Zmm zmm = zmm_out(j, k); - size_t aux_src_offset = (size_t)typesize - * ((size_t)k * jcp.ih * jcp.iw * jcp.id + j) * jcp.ic_block; - vmovups(EVEX_compress_addr_safe(reg_src, aux_src_offset, - reg_long_offt), zmm); - mic_prefetcht0(EVEX_compress_addr_safe(reg_src_prf, aux_src_offset, - reg_long_offt)); - } - } -} - -void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_4fma( - int ur_w, int l_overflow, int r_overflow) -{ - int ow = jcp.ow; - int kw = jcp.kw; - int ic_block = jcp.ic_block; - int oc_block = jcp.oc_block; - Label kh_label, last_iter_label, loop_end_label, kd_label; - int ker_load_number = 4; - int shift_ker_ptr = typesize * kw * oc_block * ic_block; - int shift_dst_ptr = typesize * ow * oc_block; - int ii_dpref_t0 = get_iw_start(0, l_overflow); - int iw_end_ipref = get_iw_end(ur_w, 0, r_overflow); - - bool check_last_kh = (jcp.kh > 3); - auto kernel_offset = [=](int icb, int oc, int ki) { - int blk_idx = icb * jcp.kh * jcp.kw * jcp.kd + ki; - int blk_offset = blk_idx * jcp.oc_block * jcp.ic_block; - int oc_offset = oc * jcp.oc_block; - return typesize * (blk_offset + oc_offset); - }; - auto kernel_loads = [=](int ki, int oc, int kk) { - for (int ii = 0; ii < ker_load_number; ii++) { - int aux_kernel_offset = kernel_offset(kk, oc + ii, ki); - vmovups(zmm_ker(ii), - EVEX_compress_addr(aux_reg_ker, aux_kernel_offset)); - } - }; - auto prefetch_dst_next_kh = [&](int ki, int ki_start, int cnt0, int cnt1) { - if (cnt1 >= ker_load_number && cnt0 >= ker_load_number - && ki >= ki_start && ii_dpref_t0 < iw_end_ipref) { - int aux_dst_offset = typesize * ((ii_dpref_t0 - + jcp.l_pad) * oc_block + jcp.ow * oc_block); - prefetcht0(EVEX_compress_addr(aux_reg_dst, aux_dst_offset)); - ii_dpref_t0++; - } - }; - - if (one_of(jcp.ndims, 3, 4)) { - mov(aux_reg_dst, reg_dst); - mov(aux_reg_ker, reg_ker); - mov(aux_reg_dst_prf, reg_dst_prf); - mov(aux_reg_ker_prf, reg_ker_prf); - } - - if (jcp.ndims == 5) { - push(reg_src_prf); - push(reg_src); - - mov(reg_ki, ptr[param + GET_OFF(kd_padding)]); 
- mov(aux_reg_dst_d, reg_dst); - mov(aux_reg_ker_d, ptr[param + GET_OFF(filt)]); - mov(aux_reg_dst_d_prf, reg_dst_prf); - mov(aux_reg_ker_d_prf, reg_ker_prf); - - L(kd_label); - mov(reg_kj, ptr[param + GET_OFF(kh_padding)]); - } else { - mov(reg_kj, reg_kh); - } - - if (jcp.ndims == 5) { - mov(aux_reg_dst, aux_reg_dst_d); - mov(aux_reg_ker, aux_reg_ker_d); - mov(aux_reg_dst_prf, aux_reg_dst_d_prf); - mov(aux_reg_ker_prf, aux_reg_ker_d_prf); - } - - align(16); - L(kh_label); - if (check_last_kh) { - for (int ki = 0; ki < kw; ki++) - for (int oc = 0; oc < oc_block; oc += 4) - for (int kk = 0; kk < jcp.nb_ic_blocking; kk++) { - bool last_kernel_loads = (kk == jcp.nb_ic_blocking - 1 - && ki == kw - 1 && (oc + 4) == oc_block); - - if (last_kernel_loads) { - cmp(reg_kj, 1); - je(last_iter_label, T_NEAR); - } - - kernel_loads(ki, oc, kk); - for (int ii = get_iw_start(ki, l_overflow), - prf_count_t0 = 0, prf_count_t1 = 0; - ii < get_iw_end(ur_w, ki, r_overflow); ii++) { - int aux_dst_offset = typesize - * ((ii + jcp.l_pad - ki) * oc_block + oc); - v4fmaddps(zmm_out(ii, kk), zmm_ker(0), - EVEX_compress_addr(aux_reg_dst, aux_dst_offset)); - - if (ii % 2) { - if (prf_count_t0 < 4) { - int aux_kernel_prf; - if (last_kernel_loads) - aux_kernel_prf= kernel_offset(0, prf_count_t0 - + oc + 4 - oc_block, 0) + typesize * kw - * oc_block * ic_block; - else - aux_kernel_prf = kernel_offset(kk, oc + 4 - + prf_count_t0, ki); - mic_prefetcht0(EVEX_compress_addr(aux_reg_ker, - aux_kernel_prf)); - prf_count_t0++; - } else if (prf_count_t1 < 4) { - mic_prefetcht1(EVEX_compress_addr(aux_reg_ker_prf, - kernel_offset(kk, oc + prf_count_t1, ki))); - prf_count_t1++; - } - } else - prefetch_dst_next_kh(ki, 2, prf_count_t0, prf_count_t1); - } - if (last_kernel_loads) { - jmp(loop_end_label, T_NEAR); - - L(last_iter_label); - - kernel_loads(ki, oc, kk); - for (int ii = get_iw_start(ki, l_overflow), - prf_count_t0 = 0, prf_count_t1 = 0; - ii < get_iw_end(ur_w, ki, r_overflow); ii++) { - int aux_dst_offset = typesize - * ((ii + jcp.l_pad - ki) * oc_block + oc); - v4fmaddps(zmm_out(ii, kk), zmm_ker(0), - EVEX_compress_addr(aux_reg_dst, aux_dst_offset)); - if (ii % 2) { - if (prf_count_t0 < 4) { - mic_prefetcht0(EVEX_compress_addr(aux_reg_ker_prf, - kernel_offset(0, prf_count_t0, 0))); - prf_count_t0++; - } else if (prf_count_t1 < 4) { - mic_prefetcht1(EVEX_compress_addr(aux_reg_ker_prf, - kernel_offset(kk, oc + prf_count_t1, ki))); - prf_count_t1++; - } - } - } - L(loop_end_label); - } - } - } else { - for (int ki = 0; ki < kw; ki++) - for (int oc = 0; oc < oc_block; oc += 4) - for (int kk = 0; kk < jcp.nb_ic_blocking; kk++) { - kernel_loads(ki, oc, kk); - - for (int ii = get_iw_start(ki, l_overflow), prf_count_t1 = 0; - ii < get_iw_end(ur_w, ki, r_overflow); ii++) { - int aux_dst_offset = typesize - * ((ii + jcp.l_pad - ki) * oc_block + oc); - v4fmaddps(zmm_out(ii, kk), zmm_ker(0), - EVEX_compress_addr(aux_reg_dst, aux_dst_offset)); - if ((ii % 2) && (prf_count_t1 < 4)) { - mic_prefetcht1(EVEX_compress_addr( - aux_reg_ker_prf, kernel_offset(kk, - oc + prf_count_t1, ki))); - prf_count_t1++; - } - if ( ki == 1 && oc == 0 && kk == 0) - mic_prefetcht1(EVEX_compress_addr( - aux_reg_dst_prf, aux_dst_offset)); - } - } - } - - add(aux_reg_ker, shift_ker_ptr); - sub(aux_reg_dst, shift_dst_ptr); - add(aux_reg_ker_prf, shift_ker_ptr); - sub(aux_reg_dst_prf, shift_dst_ptr); - - dec(reg_kj); - cmp(reg_kj, 0); - jg(kh_label, T_NEAR); - - if (jcp.ndims == 5) { - sub(aux_reg_dst_d, typesize * (jcp.oh * ow) * ic_block); - add(aux_reg_ker_d, 
typesize * jcp.kw * jcp.kh * oc_block * ic_block); - sub(aux_reg_dst_d_prf, typesize * (jcp.oh * ow) * ic_block); - add(aux_reg_ker_d_prf, typesize * jcp.kw * jcp.kh *oc_block * ic_block); - - dec(reg_ki); - cmp(reg_ki, 0); - jg(kd_label, T_NEAR); - - pop(reg_src); - pop(reg_src_prf); - } -} - -void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_fma( - int ur_w, int l_overflow, int r_overflow) -{ - Label kh_label, kd_label; - int kw = jcp.kw; - int ow = jcp.ow; - - int ic_block = jcp.ic_block; - int oc_block = jcp.oc_block; - int l_pad = jcp.l_pad; - int dilate_w = jcp.dilate_w + 1; - int stride_w = jcp.stride_w; - int stride_h = jcp.stride_h; - - int ker_pipeline_depth = 4; - assert(ker_reg_base_idx + ker_pipeline_depth <= 32); - assert(oc_block >= ker_pipeline_depth); - - int num_ker_loads = oc_block * kw; - int num_inp_prfs = ur_w * nstl::min(kw, stride_w) - + nstl::max(0, kw - stride_w); - int num_prfs = num_ker_loads + num_inp_prfs; - int num_fmas = num_ker_loads * ur_w / stride_w; - int prf_inst_spacing = nstl::max(1, num_fmas / num_prfs); - int prf_inst_trigger = (num_fmas % prf_inst_spacing) / 2; - - if (one_of(jcp.ndims, 3, 4)) { - mov(aux_reg_dst, reg_dst); - mov(aux_reg_ker, reg_ker); - - mov(aux_reg_dst_prf, reg_dst_prf); - mov(aux_reg_ker_prf, reg_ker_prf); - } - - if (jcp.ndims == 5) { - push(reg_src_prf); - push(reg_src); - - mov(reg_ki, ptr[param + GET_OFF(kd_padding)]); - mov(aux_reg_dst_d, reg_dst); - mov(aux_reg_ker_d, ptr[param + GET_OFF(filt)]); - mov(aux_reg_dst_d_prf, reg_dst_prf); - mov(aux_reg_ker_d_prf, reg_ker_prf); - - L(kd_label); - mov(reg_kj, ptr[param + GET_OFF(kh_padding)]); - } else { - mov(reg_kj, reg_kh); - } - - if (jcp.ndims == 5) { - mov(aux_reg_dst, aux_reg_dst_d); - mov(aux_reg_ker, aux_reg_ker_d); - mov(aux_reg_dst_prf, aux_reg_dst_d_prf); - mov(aux_reg_ker_prf, aux_reg_ker_d_prf); - } - - L(kh_label); { - int step = 0; - int ker_prfs = 0; - for (int ki = 0; ki < kw; ki++) { - for (int oc = 0; oc < oc_block; oc++) { - if (step == 0) { - for (int i = 0; i < ker_pipeline_depth; i++) { - int aux_kernel_offset = typesize * ((oc + i) * oc_block - + ki * ic_block * oc_block); - vmovups(zmm_ker(i), EVEX_compress_addr( - aux_reg_ker, aux_kernel_offset)); - } - } else if (step < num_ker_loads - ker_pipeline_depth + 1) { - int load_offset = ker_pipeline_depth - 1; - int ker_load_reg_idx - = (step + load_offset) % ker_pipeline_depth; - int aux_kernel_offset = typesize * ((oc + load_offset) - * oc_block + ki * ic_block * oc_block); - vmovups(zmm_ker(ker_load_reg_idx), - EVEX_compress_addr(aux_reg_ker, aux_kernel_offset)); - } - - bool ker_prf_inserted = false; - auto zmm_kernel = zmm_ker(step % ker_pipeline_depth); - - int jj_start = get_iw_start(ki, l_overflow); - int jj_end = get_iw_end(ur_w, ki, r_overflow); - assert(stride_w != 1 - || jj_start == nstl::max(0, - l_overflow - (kw - 1 - ki) * dilate_w)); - assert(stride_w != 1 - || jj_end == ur_w - nstl::max(0, - r_overflow - ki * dilate_w)); - - for (int jj = jj_start; jj < jj_end; jj += stride_w) { - assert((jj + l_pad - ki * dilate_w) % stride_w == 0); - int aux_dst_offset = typesize * - (((jj + l_pad - ki * dilate_w) - / stride_w) * jcp.oc_block + oc); - vfmadd231ps(zmm_out(jj, 0), zmm_kernel, - EVEX_compress_addr(aux_reg_dst, aux_dst_offset, true)); - - int fma_idx = (step * ur_w + jj) / stride_w; - int prf_slot_idx = fma_idx / prf_inst_spacing; - if (fma_idx % prf_inst_spacing == prf_inst_trigger) { - if (!ker_prf_inserted && ker_prfs < num_ker_loads) { - int ker_prf_offset = typesize - * 
ker_prfs * jcp.oc_block; - mic_prefetcht1(EVEX_compress_addr( - aux_reg_ker_prf, ker_prf_offset)); - ker_prf_inserted = true; - ker_prfs++; - } else { - int inp_prf_idx = prf_slot_idx - ker_prfs; - if (inp_prf_idx < num_inp_prfs) { - int inp_prf_offset - = ic_block * typesize - * ((inp_prf_idx / kw) * kw - + (inp_prf_idx % kw)); - mic_prefetcht0(EVEX_compress_addr( - aux_reg_dst_prf, inp_prf_offset)); - } - } - } - } - step++; - } - } - - add(aux_reg_ker, typesize * stride_h * kw * oc_block * ic_block); - sub(aux_reg_dst, typesize * (jcp.dilate_h + 1) * ow * oc_block); - add(aux_reg_ker_prf, typesize * stride_h * kw * oc_block * ic_block); - sub(aux_reg_dst_prf, typesize * (jcp.dilate_h + 1) * ow * oc_block); - - dec(reg_kj); - cmp(reg_kj, 0); - jg(kh_label, T_NEAR); - } - if (jcp.ndims == 5) { - sub(aux_reg_dst_d, - typesize * (jcp.dilate_d + 1) * jcp.oh * ow * ic_block); - add(aux_reg_ker_d, typesize * jcp.stride_d * jcp.kw * jcp.kh - * oc_block * ic_block); - sub(aux_reg_dst_d_prf, - typesize * (jcp.dilate_d + 1) * jcp.oh * ow * ic_block); - add(aux_reg_ker_d_prf, typesize * jcp.stride_d * jcp.kw * jcp.kh - * oc_block * ic_block); - - dec(reg_ki); - cmp(reg_ki, 0); - jg(kd_label, T_NEAR); - } - - if (jcp.ndims == 5) - { - pop(reg_src); - pop(reg_src_prf); - } -} - -void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_fma_core( - int ur_w, int l_overflow, int r_overflow) -{ - int kw = jcp.kw; - int ow = jcp.ow; - int dilate_w = jcp.dilate_w + 1; - int stride_w = jcp.stride_w; - int ic_block = jcp.ic_block; - int oc_block = jcp.oc_block; - int nb_ic_block = jcp.nb_ic_blocking; - Label kh_label, kd_label; - - int shift_ker_ptr = typesize * kw * oc_block * ic_block; - int shift_dst_ptr = typesize * (jcp.dilate_h + 1) * ow * oc_block; - - auto output_offset = [=](int oi, int oc, int ki) { - return typesize * - (((oi + jcp.l_pad - ki * dilate_w) / stride_w) * oc_block + oc); - }; - auto kernel_offset = [=](int icb, int oc, int ki) { - int blk_idx = icb * jcp.kh * jcp.kw * jcp.kd + ki; - int blk_offset = blk_idx * jcp.oc_block * jcp.ic_block; - int oc_offset = oc * jcp.oc_block; - return typesize * (blk_offset + oc_offset); - }; - - if (one_of(jcp.ndims, 3, 4)) { - mov(aux_reg_dst, reg_dst); - mov(aux_reg_ker, reg_ker); - } - - if (jcp.ndims == 5) { - push(reg_src_prf); - push(reg_src); - - mov(reg_ki, ptr[param + GET_OFF(kd_padding)]); - mov(aux_reg_dst_d, reg_dst); - mov(aux_reg_ker_d, ptr[param + GET_OFF(filt)]); - - L(kd_label); - mov(reg_kj, ptr[param + GET_OFF(kh_padding)]); - } else { - mov(reg_kj, reg_kh); - } - - if (jcp.ndims == 5) { - mov(aux_reg_dst, aux_reg_dst_d); - mov(aux_reg_ker, aux_reg_ker_d); - } - - L(kh_label); - { - for (int ki = 0; ki < kw; ki++) { - int jj_start = get_iw_start(ki, l_overflow); - int jj_end = get_iw_end(ur_w, ki, r_overflow); - for (int oc = 0; oc < oc_block; oc++) { - if (jcp.kernel_kind == expl_bcast) { - for (int jj = jj_start; jj < jj_end; jj++) { - int aux_output_offset = output_offset(jj, oc, ki); - vbroadcastss(zmm_inp(jj, nb_ic_block), - ptr[aux_reg_dst + aux_output_offset]); - } - } - for (int ii = 0; ii < nb_ic_block; ii++) { - int aux_kernel_offset = kernel_offset(ii, oc, ki); - if (jj_end - jj_start > 0) - vmovups(zmm_wei, EVEX_compress_addr(aux_reg_ker, - aux_kernel_offset)); - for (int jj = jj_start; jj < jj_end; jj += stride_w) - if (jcp.kernel_kind == expl_bcast) - vfmadd231ps(zmm_out(jj, ii), - zmm_inp(jj, nb_ic_block), zmm_wei); - else - vfmadd231ps(zmm_out(jj, ii), zmm_wei, - EVEX_compress_addr(aux_reg_dst, - 
output_offset(jj, oc, ki), true)); - } - } - } - add(aux_reg_ker, shift_ker_ptr); - sub(aux_reg_dst, shift_dst_ptr); - dec(reg_kj); - cmp(reg_kj, 0); - jg(kh_label, T_NEAR); - } - - if (jcp.ndims == 5) { - sub(aux_reg_dst_d, - typesize * (jcp.dilate_d + 1) * jcp.oh * ow * ic_block); - add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * oc_block * ic_block); - - dec(reg_ki); - cmp(reg_ki, 0); - jg(kd_label, T_NEAR); - - pop(reg_src); - pop(reg_src_prf); - } -} - -inline void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop( - int ur_w, int l_overflow, int r_overflow) -{ - if (jcp.ndims == 5) push(reg_oi); - - prepare_output(ur_w); - - Label skip_compute_loop; - if (jcp.ndims == 5) { - mov(reg_kj, ptr[param + GET_OFF(kd_padding)]); - cmp(reg_kj, 0); - je(skip_compute_loop, T_NEAR); - } - mov(reg_kj, ptr[param + GET_OFF(kh_padding)]); - cmp(reg_kj, 0); - je(skip_compute_loop, T_NEAR); - - if (jcp.ver == ver_4fma) - compute_loop_4fma(ur_w, l_overflow, r_overflow); - else if (jcp.ver == ver_fma) - if (mayiuse(avx512_mic)) - compute_loop_fma(ur_w, l_overflow, r_overflow); - else - if (jcp.kernel_kind == embd_bcast && jcp.nb_ic_blocking == 1) - compute_loop_fma(ur_w, l_overflow, r_overflow); - else - compute_loop_fma_core(ur_w, l_overflow, r_overflow); - else - assert("!unknown convolution version"); - - L(skip_compute_loop); - store_output(ur_w); - if (jcp.ndims == 5) pop(reg_oi); -} - -void jit_avx512_common_conv_bwd_data_kernel_f32::generate() -{ - int iw = jcp.iw; - int kw = jcp.kw; - int ur_w = jcp.ur_w; - int ic_block = jcp.ic_block; - int oc_block = jcp.oc_block; - int ur_w_tail = jcp.ur_w_tail; - int dilate_w = jcp.dilate_w + 1; - int stride_w = jcp.stride_w; - - int dst_shift = jcp.typesize_in * (ur_w / stride_w) * ic_block; - int src_shift = jcp.typesize_out * ur_w * oc_block; - - preamble(); - - mov(reg_src, ptr[param + GET_OFF(src)]); - mov(reg_dst, ptr[param + GET_OFF(dst)]); - mov(reg_ker, ptr[param + GET_OFF(filt)]); - - mov(reg_kh, ptr[param + GET_OFF(kh_padding)]); - mov(reg_src_prf, ptr[param + GET_OFF(src_prf)]); - mov(reg_dst_prf, ptr[param + GET_OFF(dst_prf)]); - mov(reg_ker_prf, ptr[param + GET_OFF(filt_prf)]); - - int l_overflow = nstl::max(0, ((kw - 1) * dilate_w - jcp.l_pad) / stride_w); - int r_overflow = nstl::max(0, ((kw - 1) * dilate_w - - nstl::max(0, jcp.r_pad)) / stride_w); - int r_overflow1 = nstl::max(0, ((kw - 1) * dilate_w - - nstl::max(0, jcp.r_pad) - ur_w_tail) / stride_w); - - int n_oi = iw / ur_w; - if (r_overflow1 > 0) n_oi--; - - if (ur_w == iw) { - compute_loop(ur_w, l_overflow, r_overflow); - } else if (n_oi == 0) { - compute_loop(ur_w, l_overflow, r_overflow1); - add(reg_src, src_shift); - add(reg_dst, dst_shift); - add(reg_src_prf, src_shift); - add(reg_dst_prf, dst_shift); - if (ur_w_tail != 0) - compute_loop(ur_w_tail, 0, r_overflow); - } else { - xor_(reg_oi, reg_oi); - if (l_overflow > 0) { - compute_loop(ur_w, l_overflow, 0); - add(reg_src, src_shift); - add(reg_dst, dst_shift); - add(reg_src_prf, src_shift); - add(reg_dst_prf, dst_shift); - - inc(reg_oi); - } - if ((l_overflow <= 0 && n_oi > 0) - || (l_overflow > 0 && n_oi > 1)) { - Label ow_loop_label; - L(ow_loop_label); { - compute_loop(ur_w, 0, 0); - add(reg_src, src_shift); - add(reg_dst, dst_shift); - add(reg_src_prf, src_shift); - add(reg_dst_prf, dst_shift); - - inc(reg_oi); - cmp(reg_oi, n_oi); - jl(ow_loop_label, T_NEAR); - } - } - if (r_overflow1 > 0) { - compute_loop(ur_w, 0, r_overflow1); - add(reg_src, src_shift); - add(reg_dst, dst_shift); - add(reg_src_prf, src_shift); - 
add(reg_dst_prf, dst_shift); - } - if (ur_w_tail != 0) { - compute_loop(ur_w_tail, 0, r_overflow); - } - } - - postamble(); -} - -status_t jit_avx512_common_conv_bwd_data_kernel_f32::init_conf( - jit_conv_conf_t &jcp, - const convolution_desc_t &cd, - const memory_desc_wrapper &diff_src_d, - const memory_desc_wrapper &weights_d, - const memory_desc_wrapper &diff_dst_d) -{ - if (!mayiuse(avx512_common)) return status::unimplemented; - - jcp = zero(); - - jcp.simd_w = cpu_isa_traits::vlen / sizeof(float); - const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1; - int ndims = diff_src_d.ndims(); - - jcp.ndims = ndims; - jcp.prop_kind = cd.prop_kind; - - jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; - jcp.mb = diff_src_d.dims()[0]; - - jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups; - jcp.oc_without_padding = jcp.oc; - jcp.ic = diff_src_d.dims()[1] / jcp.ngroups; - - jcp.id = (ndims == 5) ? diff_src_d.dims()[2] : 1; - jcp.ih = (ndims == 3) ? 1 : diff_src_d.dims()[ndims-2]; - jcp.iw = diff_src_d.dims()[ndims-1]; - jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1; - jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims-2]; - jcp.ow = diff_dst_d.dims()[ndims-1]; - - jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1; - jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2]; - jcp.kw = weights_d.dims()[with_groups + ndims - 1]; - - jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0; - jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4]; - jcp.l_pad = cd.padding[0][ndims-3]; - - jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1; - jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4]; - jcp.stride_w = cd.strides[ndims-3]; - - jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0; - jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4]; - jcp.dilate_w = cd.dilates[ndims-3]; - if ((jcp.dilate_w != 0 && jcp.stride_w != 1) - || (jcp.dilate_d != 0 && jcp.stride_d != 1) - || (jcp.dilate_h != 0 && jcp.stride_h != 1)) - return status::unimplemented; - - jcp.r_pad = (jcp.ow - 1) * jcp.stride_w + (jcp.kw - 1) * (jcp.dilate_w + 1) - - (jcp.iw + jcp.l_pad - 1); - jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1) - - (jcp.ih + jcp.t_pad - 1); - jcp.back_pad = (jcp.od - 1) * jcp.stride_d - + (jcp.kd - 1) * (jcp.dilate_d + 1) - (jcp.id + jcp.f_pad - 1); - - jcp.aligned_threads = 0; - - jcp.is_1stconv = false; - - jcp.oc_block = jcp.simd_w; - jcp.ic_block = jcp.is_1stconv ? jcp.ic : jcp.simd_w; - - bool ok_to_pad_channels = true - && jcp.ngroups == 1 - && diff_src_d.data_type() == data_type::f32; - - if (ok_to_pad_channels) { - jcp.oc = rnd_up(jcp.oc, jcp.oc_block); - jcp.ic = rnd_up(jcp.ic, jcp.ic_block); - } - - auto dat_tag = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c); - auto wei_tag = with_groups - ? 
pick(ndims - 3, gOIw16o16i, gOIhw16o16i, gOIdhw16o16i) - : pick(ndims - 3, OIw16o16i, OIhw16o16i, OIdhw16o16i); - jcp.src_tag = diff_src_d.matches_one_of_tag(dat_tag); - jcp.dst_tag = diff_dst_d.matches_one_of_tag(dat_tag); - - bool args_ok = true - && jcp.oc % jcp.oc_block == 0 - && jcp.ic % jcp.ic_block == 0 - && jcp.src_tag == dat_tag - && jcp.dst_tag == dat_tag; - if (!args_ok) - return status::unimplemented; - - jcp.nb_ic = jcp.ic / jcp.ic_block; - jcp.nb_oc = jcp.oc / jcp.oc_block; - - jcp.ur_w = jcp.stride_w; - - int regs = 28; - if (jcp.iw <= regs) - jcp.ur_w = jcp.iw; - else { - for (int ur_w = regs; ur_w > 0; --ur_w) - if (ur_w % jcp.stride_w == 0) { - jcp.ur_w = ur_w; - break; - } - } - int l_overflow = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - - jcp.l_pad) / jcp.stride_w); - int r_overflow1 = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - - nstl::max(0, jcp.r_pad) - jcp.iw % jcp.ur_w) / jcp.stride_w); - int n_oi = jcp.iw / jcp.ur_w; - if (r_overflow1 > 0) n_oi--; - - if (mayiuse(avx512_common) - && diff_dst_d.data_type() == data_type::f32 - && weights_d.data_type() == data_type::f32 - && diff_src_d.data_type() == data_type::f32) { - jcp.ver = ver_fma; - jcp.typesize_in = sizeof(float); - jcp.typesize_out = sizeof(float); - if (mayiuse(avx512_mic_4ops) - && jcp.stride_w == 1 && jcp.stride_h == 1 && jcp.stride_d == 1) { - jcp.ver = ver_4fma; - } - } else { - return status::unimplemented; - } - - jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); - if (jcp.wei_tag != wei_tag) - return status::unimplemented; - - if (!utils::everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w) - && jcp.ver != ver_fma) - return status::unimplemented; - - jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1; - if (jcp.ver == ver_4fma) { - if (jcp.kw == 3 && jcp.kh == 3 && jcp.iw == 7 && jcp.ih == 7) { - jcp.nb_ic_blocking = 2; - } else { - for (int i = jcp.nb_ic; i > 0; i--) - if (i * jcp.ur_w <= regs && jcp.nb_ic % i == 0) { - jcp.nb_ic_blocking = i; - break; - } - } - } - - jcp.loop_order = loop_gnc; - - bool large_code_size = (jcp.ur_w != jcp.ow) - && ((l_overflow <= 0 && n_oi > 0) ||(l_overflow > 0 && n_oi > 1)) - && (r_overflow1 > 0) && (l_overflow > 0); - if (large_code_size) { - const int max_code_size = 24 * 1024; - const int num_ops_per_reg = 6 + jcp.oc_block * jcp.kw; - int mult = 1; - if (l_overflow > 0) mult += 1; - if (r_overflow1 > 0) mult += 1; - for (int ur_w = jcp.ur_w; ur_w > regs/2; --ur_w) { - if ((ur_w / jcp.stride_w) * mult * num_ops_per_reg * 9.2 - < max_code_size) { - if (ur_w % jcp.stride_w == 0) { - jcp.ur_w = ur_w; - break; - } - } - } - } - - if (jcp.ver == ver_fma && mayiuse(avx512_core)) { - int try_nb_ic_blocking = 2; - unsigned int ker_inp_size = typesize * jcp.iw * jcp.ic_block - * try_nb_ic_blocking * jcp.kh; - unsigned int ker_out_size = typesize * jcp.ow * jcp.oc_block; - unsigned int ker_wei_size = typesize * jcp.kh * jcp.kw * jcp.ic_block - * jcp.oc_block * try_nb_ic_blocking; - unsigned int ker_total_size = ker_inp_size + ker_out_size - + ker_wei_size; - if (!(jcp.kw == 1 || (jcp.kw == 5 && jcp.iw < 8) - || (jcp.kw < 5 && ((jcp.iw <= 5 || (jcp.iw > 8 && jcp.iw <= 13)) - || ker_total_size > L1_cache_size ))) - || jcp.stride_h > 1 || jcp.stride_d > 1) { - jcp.kernel_kind = embd_bcast; - jcp.ur_w = nstl::min(jcp.iw, regs); - jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1; - if (!(jcp.kw > 3 || (jcp.kw == 3 && ker_total_size < L1_cache_size - && jcp.ow > 8)) && jcp.stride_h == 1) - if (jcp.nb_ic % try_nb_ic_blocking == 0) { - jcp.nb_ic_blocking = 
try_nb_ic_blocking; - jcp.ur_w = 31 / (jcp.nb_ic_blocking + 1); - if (jcp.iw < jcp.ur_w) jcp.ur_w = jcp.iw; - } - } else { - jcp.kernel_kind = expl_bcast; - jcp.nb_oc_blocking = 1; - jcp.nb_ic_blocking = 4; - if (jcp.nb_ic < jcp.nb_ic_blocking) jcp.nb_ic_blocking = jcp.nb_ic; - if (jcp.nb_ic % jcp.nb_ic_blocking != 0) - for (int i = jcp.nb_ic_blocking; i > 0; i--) - if (jcp.nb_ic % i == 0) { - jcp.nb_ic_blocking = i; - break; - } - jcp.ur_w = 31 / (jcp.nb_ic_blocking + 1); - if (jcp.iw < jcp.ur_w) jcp.ur_w = jcp.iw; - } - } - jcp.ur_w_tail = jcp.iw % jcp.ur_w; - - if (l_overflow * jcp.stride_w > jcp.ur_w) - return status::unimplemented; - int r_overflow_no_tail = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - - nstl::max(0, jcp.r_pad) - jcp.ur_w_tail) / jcp.stride_w); - if (r_overflow_no_tail * jcp.stride_w > jcp.ur_w) - return status::unimplemented; - if ((jcp.iw > jcp.ur_w) && (jcp.ur_w % jcp.stride_w != 0)) - return status::unimplemented; - - pick_loop_order(jcp); - - jcp.nb_oc_L2 = jcp.nb_oc; - if (jcp.ver == ver_4fma && (jcp.kh < 5 && jcp.kw < 5)) { - for (int divf = 2, temp_nb = jcp.nb_oc_L2; divf <= jcp.nb_oc; - divf++) { - size_t l2_src = jcp.iw * jcp.ic_block * jcp.nb_ic_blocking * jcp.ih - * jcp.id; - size_t l2_dst = jcp.ow * jcp.oc_block * temp_nb * jcp.oh * jcp.od; - size_t l2_filt = jcp.kw * jcp.oc_block * jcp.ic_block * jcp.kh - * jcp.kd * jcp.nb_ic_blocking * temp_nb; - if (4 * (l2_src + l2_dst + l2_filt) > KNx_L2_EFFECTIVE_CAPACITY) { - if (jcp.kh == 3 && jcp.ih == 7) { - jcp.nb_oc_L2 = 1; - break; - } - temp_nb = (jcp.nb_oc_L2 % divf == 0 ? jcp.nb_oc_L2 / divf - : jcp.nb_oc_L2); - } else { - jcp.nb_oc_L2 = temp_nb; - break; - } - } - } - - args_ok = true - && jcp.ic <= diff_src_d.padded_dims()[1] - && jcp.oc <= diff_dst_d.padded_dims()[1] - && jcp.ic <= weights_d.padded_dims()[with_groups + 1] - && jcp.oc <= weights_d.padded_dims()[with_groups + 0]; - if (!args_ok) return status::unimplemented; - - return status::success; -} - -void jit_avx512_common_conv_bwd_data_kernel_f32::init_scratchpad( - memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { - UNUSED(scratchpad); - UNUSED(jcp); -} - -const int jit_avx512_common_conv_bwd_weights_kernel_f32::max_ur_w = 28; - -void jit_avx512_common_conv_bwd_weights_kernel_f32::od_step_comeback_pointers() -{ - Label kd_comeback_label; - - /* 'depth' loop count bound by 'kd_work_size' */ - mov(kj, reg_kd_count); - L(kd_comeback_label); { - int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block; - int iw = jcp.ver == ver_4fma ? jcp.tr_iw : jcp.iw; - sub(reg_input, - jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih * iw * inp_mult); - sub(reg_kernel, - jcp.typesize_out * jcp.kh * jcp.kw * jcp.ic_block * jcp.oc_block); - dec(kj); - cmp(kj, 0); - jg(kd_comeback_label, T_NEAR); - } -} - -void jit_avx512_common_conv_bwd_weights_kernel_f32::oh_step_comeback_pointers() -{ - Label kh_comeback_label, kd_comeback_label; - mov(kj, reg_kh); - L(kh_comeback_label); { - int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block; - int iw = jcp.ver == ver_4fma ? 
jcp.tr_iw : jcp.iw; - sub(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mult); - sub(reg_kernel, - jcp.typesize_out * jcp.kw * jcp.ic_block * jcp.oc_block); - dec(kj); - cmp(kj, 0); - jg(kh_comeback_label, T_NEAR); - } -} - -void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_ic_block_step_fma( - int ur_w, int pad_l, int pad_r, - int ic_block_step, int input_offset, int kernel_offset, - int output_offset, bool input_wraparound) -{ - - int kw = jcp.kw; - int ic_block = jcp.ic_block; - int oc_block = jcp.oc_block; - for (int i_kw = 0; i_kw < kw; i_kw++) - for (int i_ic = 0; i_ic < ic_block_step; i_ic++) - vmovups(Zmm(i_kw * ic_block_step + i_ic), - EVEX_compress_addr(reg_kernel, typesize * (i_kw * ic_block - + i_ic) * jcp.oc_block + kernel_offset)); - - for (int i_ur = 0; i_ur < ur_w; i_ur++) { - if (i_ur == 0) { - vmovups(Zmm(kw * ic_block_step + (i_ur + 0) % 4), - EVEX_compress_addr(reg_output, typesize * (i_ur + 0) - * oc_block + output_offset)); - if (ur_w > 1) vmovups(Zmm(kw * ic_block_step + (i_ur + 1) % 4), - EVEX_compress_addr(reg_output, typesize * (i_ur + 1) * oc_block - + output_offset)); - if (ur_w > 2) vmovups(Zmm(kw * ic_block_step + (i_ur + 2) % 4), - EVEX_compress_addr(reg_output, typesize * (i_ur + 2) * oc_block - + output_offset)); - if (ur_w > 3) vmovups(Zmm(kw * ic_block_step + (i_ur + 3) % 4), - EVEX_compress_addr(reg_output, typesize * (i_ur + 3) * oc_block - + output_offset)); - } else if (i_ur + 3 < ur_w) - vmovups(Zmm(kw * ic_block_step + (i_ur + 3) % 4), - EVEX_compress_addr(reg_output, typesize * (i_ur + 3) * oc_block - + output_offset)); - - for (int i_kw = 0; i_kw < kw; i_kw++) { - int i_iw = i_ur * jcp.stride_w + i_kw * (jcp.dilate_w + 1); - if (i_iw - pad_l < 0 || i_iw > (ur_w - 1) * jcp.stride_w + - (kw - 1) * (jcp.dilate_w + 1) - pad_r) continue; - for (int i_ic = 0; i_ic < ic_block_step; i_ic++) { - const size_t i_offset = (size_t)input_offset - + (size_t)typesize * (jcp.ver == ver_4fma - ? (i_iw - pad_l + i_ic * jcp.tr_iw) - : (jcp.is_1stconv - ? (i_iw - pad_l) + (size_t)i_ic - * ((size_t)jcp.ih*jcp.iw*jcp.id) - : (i_iw - pad_l) * ic_block + i_ic)); - vfmadd231ps(Zmm(i_kw * ic_block_step + i_ic), - Zmm(kw * ic_block_step + i_ur % 4), - EVEX_compress_addr_safe(reg_input, i_offset, reg_long_offt, - true)); - } - } - } - - for (int i_kw = 0; i_kw < kw; i_kw++) - for (int i_ic = 0; i_ic < ic_block_step; i_ic++) - vmovups(EVEX_compress_addr(reg_kernel, typesize - * (i_kw * ic_block + i_ic) * jcp.oc_block + kernel_offset), - Zmm(i_kw * ic_block_step + i_ic)); -} - -void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_ic_block_step_4fma( - int ur_w, int pad_l, int pad_r, - int ic_block_step, int input_offset, int kernel_offset, - int output_offset, bool input_wraparound) -{ - // TODO: add prefetches to fma version as well - - assert(jcp.ver == ver_4fma); - - int kw = jcp.kw; - int ic_block = jcp.ic_block; - int oc_block = jcp.oc_block; - - auto zmm_ker = [=](int i_kw, int i_ic) { - return Zmm(i_kw * ic_block_step + i_ic); - }; - - auto ker_addr = [=](int i_kw, int i_ic) { - size_t local_offset - = jcp.typesize_out * (i_kw * ic_block + i_ic) * jcp.oc_block; - return EVEX_compress_addr(reg_kernel, local_offset + kernel_offset); - }; - - auto inp_addr = [=](int i_iw, int i_ic, ptrdiff_t extra_offset = 0) { - int stride = jcp.tr_iw * (jcp.is_1stconv ? 
jcp.ih : 1); - int local_offset = jcp.typesize_in * (i_iw + i_ic * stride); - return EVEX_compress_addr(reg_input, - local_offset + input_offset + extra_offset); - }; - - auto zmm_out = [=](int i_iw) { - // TODO: move reg calc to global member funcs - const int out_zmm_base_idx = 28; - return Zmm(out_zmm_base_idx + i_iw % 4); - }; - - auto out_addr = [=](int i_ur) { - return EVEX_compress_addr(reg_output, - jcp.typesize_in * i_ur * oc_block + output_offset); - }; - - auto pf_callback = [=](int i_ur, int i_kw, int i_ic) { - assert(i_ur % 4 == 0); - if (i_ur == 0) - prefetcht1(ker_addr(i_kw, i_ic)); - if (i_ur + 4 >= ur_w) - prefetcht0(ker_addr(i_kw, i_ic)); - - const ptrdiff_t next_input_block_offset - = jcp.typesize_in * ic_block_step * jcp.tr_iw; - if (i_ur % 16 == 4 && i_kw == 0) { - if (i_ur + 16 < ur_w) - prefetcht0(inp_addr(i_ur + 16, i_ic)); - else - prefetcht0(inp_addr(0, i_ic, next_input_block_offset)); - } - if (i_ur % 16 == 4 && i_kw == 1) { - if (input_wraparound) - prefetcht1(inp_addr(i_ur, i_ic, -input_offset)); - else - prefetcht1(inp_addr(i_ur, i_ic, next_input_block_offset)); - } - }; - - for (int i_kw = 0; i_kw < kw; i_kw++) - for (int i_ic = 0; i_ic < ic_block_step; i_ic++) { - auto zmm = zmm_ker(i_kw, i_ic); - vpxord(zmm, zmm, zmm); - } - - for (int i_ur = 0; i_ur < ur_w; i_ur += 4) { - - for (int i = 0; i < 4; i++) { - auto zmm = zmm_out(i_ur + i); - if (i_ur + i < ur_w) - vmovups(zmm, out_addr(i_ur + i)); - else - vpxord(zmm, zmm, zmm); - prefetcht0(out_addr(i_ur + i + 4)); - } - - for (int i_kw = 0; i_kw < kw; i_kw++) - for (int i_ic = 0; i_ic < ic_block_step; i_ic++) { - int i_iw = i_ur + i_kw; - v4fmaddps(zmm_ker(i_kw, i_ic), - zmm_out(i_ur), inp_addr(i_iw, i_ic)); - pf_callback(i_ur, i_kw, i_ic); - } - } - - for (int i_kw = 0; i_kw < kw; i_kw++) - for (int i_ic = 0; i_ic < ic_block_step; i_ic++) { - auto addr = ker_addr(i_kw, i_ic); - auto zmm = zmm_ker(i_kw, i_ic); - vaddps(zmm, zmm, addr); - vmovups(addr, zmm); - } -} - -void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_ic_block_step( - int ur_w, int pad_l, int pad_r, - int ic_block_step, int input_offset, int kernel_offset, - int output_offset, bool input_wraparound) -{ - if (jcp.ver == ver_4fma) - compute_ic_block_step_4fma(ur_w, pad_l, pad_r, - ic_block_step, input_offset, kernel_offset, output_offset, - input_wraparound); - else if (jcp.ver == ver_fma) - compute_ic_block_step_fma(ur_w, pad_l, pad_r, - ic_block_step, input_offset, kernel_offset, output_offset, - input_wraparound); - else - assert(!"unknown convolution version"); -} - -void jit_avx512_common_conv_bwd_weights_kernel_f32 - ::compute_oh_step_unroll_ow_icblock( - int ic_block_step, int max_ur_w) -{ - UNUSED(max_ur_w); - - Label kh_label, kd_label; - - int ic_block = jcp.ic_block; - int oc_block = jcp.oc_block; - int inp_mul = !jcp.is_1stconv ? ic_block : 1; - int iw = jcp.ver == ver_4fma ? jcp.tr_iw : jcp.iw; - int ow = jcp.ow; - - int r_pad = nstl::max(0, (ow - 1) * jcp.stride_w - + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1)); - int l_pad = jcp.l_pad; - - if (jcp.ndims == 5) { - L(kd_label); - mov(reg_input, aux_reg_input); - mov(reg_kernel, aux_reg_kernel); - } - - mov(kj, reg_kh); - L(kh_label); - { - for (int i_b_ic = 0; i_b_ic < jcp.ic_block; i_b_ic += ic_block_step) { - const int input_offset = jcp.typesize_in - * (jcp.ver == ver_4fma ? 
i_b_ic * iw : i_b_ic); - compute_ic_block_step(jcp.ur_w, l_pad, r_pad, ic_block_step, - input_offset, jcp.typesize_out * i_b_ic * jcp.oc_block, 0, - i_b_ic + ic_block_step >= jcp.ic_block); - } - add(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mul); - add(reg_kernel, jcp.typesize_out * jcp.kw * ic_block * oc_block); - dec(kj); - cmp(kj, 0); - jg(kh_label, T_NEAR); - } - - if (jcp.ndims == 5) { - add(aux_reg_input, - jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih * iw * inp_mul); - add(aux_reg_kernel, jcp.typesize_out * jcp.kh * jcp.kw * ic_block - * oc_block); - dec(ki); - cmp(ki, 0); - jg(kd_label, T_NEAR); - } -} - -void jit_avx512_common_conv_bwd_weights_kernel_f32 - ::compute_oh_step_unroll_ow( - int ic_block_step, int max_ur_w) -{ - Label kh_label, ic_block_label, kd_label; - - UNUSED(max_ur_w); - - int ic_block = jcp.ic_block; - int oc_block = jcp.oc_block; - - int ow = jcp.ow; - - int r_pad = nstl::max(0, - (ow - 1) * jcp.stride_w + (jcp.kw - 1) * (jcp.dilate_w + 1) - - (jcp.iw + jcp.l_pad - 1)); - int l_pad = jcp.l_pad; - - if (jcp.ndims == 5) { - L(kd_label); - mov(reg_input, aux_reg_input); - mov(reg_kernel, aux_reg_kernel); - } - - mov(kj, reg_kh); - L(kh_label); - { - xor_(b_ic, b_ic); - L(ic_block_label); { - compute_ic_block_step(ow, l_pad, r_pad, ic_block_step, - 0, 0, 0); - size_t inp_icblk_stride = jcp.is_1stconv - ? (size_t)jcp.ih * jcp.iw * jcp.id - : (jcp.ver == ver_4fma ? jcp.tr_iw : 1); - size_t input_offset - = inp_icblk_stride * jcp.typesize_in * ic_block_step; - safe_add(reg_input, input_offset, reg_long_offt); - add(reg_kernel, jcp.typesize_out * ic_block_step * oc_block); - add(b_ic, ic_block_step); - cmp(b_ic, jcp.ic_block); - jl(ic_block_label, T_NEAR); - } - - if (jcp.is_1stconv) { - size_t input_offset - = (size_t)jcp.typesize_in * jcp.id * jcp.ih * jcp.iw * ic_block; - safe_sub(reg_input, input_offset, reg_long_offt); - add(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw); - } else if (jcp.ver != ver_4fma) { - add(reg_input, jcp.typesize_in - * ((jcp.dilate_h + 1) * jcp.iw - 1) * ic_block); - } - add(reg_kernel, jcp.typesize_out * (jcp.kw - 1) * ic_block * oc_block); - dec(kj); - cmp(kj, 0); - jg(kh_label, T_NEAR); - } - if (jcp.ndims == 5) { - add(aux_reg_input, jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih - * jcp.iw * (jcp.is_1stconv ? 1 : ic_block)); - add(aux_reg_kernel, jcp.typesize_out * jcp.kh * jcp.kw * ic_block - * oc_block); - dec(ki); - cmp(ki, 0); - jg(kd_label, T_NEAR); - } -} - -void jit_avx512_common_conv_bwd_weights_kernel_f32 - ::compute_oh_step_common( - int ic_block_step, int max_ur_w) -{ - Label kh_label, ic_block_label, ow_block_label, kd_label; - - int ic_block = jcp.ic_block; - int oc_block = jcp.oc_block; - - int ow = jcp.ow; - int r_pad = nstl::max(0, (ow - 1) * jcp.stride_w - + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1)); - int l_pad = jcp.ver == ver_4fma ? 0 : jcp.l_pad; - - int ur_w = nstl::min(ow, max_ur_w); - int ur_w_trips = ow / ur_w; - int ur_w_tail = ow % ur_w; - if ((ur_w_tail == 0 && r_pad != 0) - || r_pad >= ur_w_tail) { - if (ur_w_trips > 1) { - ur_w_tail += ur_w; - ur_w_trips--; - } else { - ur_w_tail += (ur_w - ur_w / 2); - ur_w = ur_w / 2; - } - } - - int inp_mult = (jcp.is_1stconv || jcp.ver == ver_4fma) ? 
1 : ic_block; - int input_comeback = (ur_w_trips * ur_w * jcp.stride_w - l_pad) * inp_mult; - int output_comeback = ur_w_trips * ur_w * oc_block; - - if (jcp.ndims == 5) { - L(kd_label); - mov(reg_input, aux_reg_input); - mov(reg_kernel, aux_reg_kernel); - } - - mov(kj, reg_kh); - L(kh_label); { - xor_(b_ic, b_ic); - L(ic_block_label); { - if (l_pad != 0) { - ur_w_trips--; - compute_ic_block_step(ur_w, l_pad, 0, ic_block_step, 0, 0, 0); - add(reg_input, jcp.typesize_in * (ur_w * jcp.stride_w - l_pad) - * inp_mult); - add(reg_output, jcp.typesize_in * ur_w * oc_block); - } - - if (ur_w_trips > 0) { - xor_(reg_ur_w_trips, reg_ur_w_trips); - L(ow_block_label); { - compute_ic_block_step(ur_w, 0, 0, ic_block_step, 0, 0, 0); - add(reg_input, jcp.typesize_in * ur_w * jcp.stride_w - * inp_mult); - add(reg_output, jcp.typesize_in * ur_w * oc_block); - - inc(reg_ur_w_trips); - cmp(reg_ur_w_trips, ur_w_trips); - jl(ow_block_label, T_NEAR); - } - } - - if (ur_w_tail > 0) compute_ic_block_step(ur_w_tail, 0, r_pad, - ic_block_step, 0, 0, 0); - - sub(reg_input, jcp.typesize_in * input_comeback); - sub(reg_output, jcp.typesize_in * output_comeback); - int inp_icblk_stride = jcp.is_1stconv - ? jcp.ih * jcp.iw * jcp.id - : (jcp.ver == ver_4fma ? jcp.tr_iw : 1); - size_t input_offset - = inp_icblk_stride * jcp.typesize_in * ic_block_step; - safe_add(reg_input, input_offset, reg_long_offt); - add(reg_kernel, jcp.typesize_out * ic_block_step * oc_block); - - add(b_ic, ic_block_step); - cmp(b_ic, jcp.ic_block); - jl(ic_block_label, T_NEAR); - } - if (jcp.is_1stconv) { - size_t input_offset - = (size_t)jcp.typesize_in * jcp.id * jcp.ih * jcp.iw * ic_block; - safe_sub(reg_input, input_offset, reg_long_offt); - add(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw); - } else if (jcp.ver != ver_4fma) { - add(reg_input, jcp.typesize_in - * ((jcp.dilate_h + 1 ) * jcp.iw - 1) * ic_block); - } - add(reg_kernel, jcp.typesize_out * (jcp.kw - 1) * ic_block * oc_block); - dec(kj); - cmp(kj, 0); - jg(kh_label, T_NEAR); - } - if (jcp.ndims == 5) { - add(aux_reg_input, jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih - * jcp.iw * (jcp.is_1stconv ? 1 : ic_block)); - add(aux_reg_kernel, jcp.typesize_out * jcp.kh * jcp.kw * ic_block - * oc_block); - dec(ki); - cmp(ki, 0); - jg(kd_label, T_NEAR); - } -} - -void jit_avx512_common_conv_bwd_weights_kernel_f32 - ::compute_oh_step_disp() -{ - int ic_block_step = jcp.kw <= 3 ? 8 : (jcp.kw <= 7 ? 4 : 2); - if (jcp.is_1stconv) { - bool large_code = jcp.kw >= 7 && (jcp.l_pad > 0 || jcp.t_pad > 0); - ic_block_step - = (jcp.kw * jcp.ic_block <= 28 && !large_code) ? jcp.ic_block : 1; - } - - bool too_large_to_unroll - = (jcp.kw > 1 || jcp.kh > 1 || jcp.kd > 1) - && (jcp.stride_w > 1 || jcp.stride_h > 1 || jcp.stride_d > 1); - - int ow = jcp.ow; - if (jcp.ndims == 5) { - /* NOTE: reg_kd_count = aux_reg_input = r12. The following order of - * 'movs' must be guaranteed. 
*/ - mov(ki, reg_kd_count); - push(reg_kd_count); - mov(aux_reg_input, reg_input); - mov(aux_reg_kernel, reg_kernel); - } - - if (jcp.kw <= 3 && ow <= 16 && !too_large_to_unroll) - compute_oh_step_unroll_ow_icblock(ic_block_step, max_ur_w); - else if (ow <= max_ur_w) - compute_oh_step_unroll_ow(ic_block_step, max_ur_w); - else - compute_oh_step_common(ic_block_step, max_ur_w); - - if (jcp.ndims == 5) { - mov(reg_input, aux_reg_input); - mov(reg_kernel, aux_reg_kernel); - pop(reg_kd_count); - od_step_comeback_pointers(); - } else { - oh_step_comeback_pointers(); - } -} - -void jit_avx512_common_conv_bwd_weights_kernel_f32::maybe_zero_kernel() -{ - Label skip_zeroing, zeroing_loop; - - mov(reg_tmp, ptr[param + GET_OFF(channel)]); - cmp(reg_tmp, 0); - jz(skip_zeroing, T_NEAR); - - Zmm zero = Zmm(0); - vpxord(zero, zero, zero); - xor_(reg_tmp, reg_tmp); - L(zeroing_loop); { - assert(jcp.oc_block * jcp.typesize_out - == cpu_isa_traits::vlen); - for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) - vmovups(ptr[reg_kernel + reg_tmp + ic1 * jcp.oc_block - * jcp.typesize_out], zero); - add(reg_tmp, jcp.ic_block * jcp.oc_block * jcp.typesize_out); - cmp(reg_tmp, jcp.ic_block * jcp.oc_block * jcp.kw * jcp.kh * jcp.kd - * jcp.typesize_out); - jnz(zeroing_loop); - } - - L(skip_zeroing); -} - -void jit_avx512_common_conv_bwd_weights_kernel_f32::bias_kernel() -{ - Label skip_bias, bias_loop, skip_load_bias; - - mov(reg_tmp, ptr[param + GET_OFF(flags)]); - test(reg_tmp,reg_tmp); - jne(skip_bias, T_NEAR); - - mov(reg_bias, ptr[param + GET_OFF(bias)]); - mov(reg_output, ptr[param + GET_OFF(dst)]); - vpxord(Zmm(1), Zmm(1), Zmm(1)); - - mov(reg_tmp, ptr[param + GET_OFF(channel)]); - cmp(reg_tmp, 0); - jne(skip_load_bias, T_NEAR); - vmovups(Zmm(1), ptr[reg_bias]); - - L(skip_load_bias); - - mov(reg_oi, ptr[param + GET_OFF(d_worksize)]); - sub(reg_oi, ptr[param + GET_OFF(d_index)]); - mov(reg_tmp, jcp.oc_block * jcp.ow * jcp.oh * jcp.typesize_out); - imul(reg_oi, reg_tmp); - - xor_(reg_tmp, reg_tmp); - L(bias_loop); { - vmovups(Zmm(0), ptr[reg_output + reg_tmp]); - vaddps(Zmm(1), Zmm(1), Zmm(0)); - add(reg_tmp, jcp.oc_block * jcp.typesize_out); - cmp(reg_tmp, reg_oi); - jl(bias_loop); - } - vmovups(EVEX_compress_addr(reg_bias,0), Zmm(1)); - - L(skip_bias); -} - -void jit_avx512_common_conv_bwd_weights_kernel_f32 - ::compute_oh_loop_common() -{ - int b_pad = jcp.b_pad; - int t_pad = jcp.t_pad; - bool is_dilated = jcp.dilate_h != 0; - int dilate_h = jcp.dilate_h + 1; - int stride_h = jcp.stride_h; - const int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block; - int iw = jcp.ver == ver_4fma ? jcp.tr_iw : jcp.iw; - Label oh_label, oh_label_end, oh_tpad_label, oh_tpad_tail_label, - oh_bpad_label, oh_bpad_label_end, od_label, od_label_end, - oh_dilate_label_shift, oh_dilate_label_noshift, oh_dilate_label_end; - - int ow = jcp.ow; - - mov(reg_kh, jcp.kh); - xor_(reg_ih_count, reg_ih_count); - xor_(reg_oj, reg_oj); - /* Compute 'top' edge */ - if (t_pad > 0) { - const int kh_range = 1 + (jcp.kh - 1) * dilate_h; - const int overflow - = nstl::max(0, jcp.kh - div_up(t_pad + jcp.ih, dilate_h)); - const int underflow = div_up(t_pad, dilate_h); - const int initial_inp_ker_overlap = jcp.kh - overflow - underflow; - mov(reg_kh, initial_inp_ker_overlap); - add(reg_kernel, jcp.typesize_out * underflow * jcp.kw * jcp.ic_block - * jcp.oc_block); - // generate loop to process kernel while it remains within t_pad + ih - if (kh_range < t_pad + jcp.ih) { - if (is_dilated) { - const int tail = t_pad % dilate_h; - const int shift = tail == 0 ? 
0 : dilate_h - tail; - mov(reg_tmp, shift); - if (tail != 0) - add(reg_input, jcp.typesize_in * shift * iw * inp_mult); - } - L(oh_tpad_label); { - compute_oh_step_disp(); - add(reg_output, jcp.typesize_in * ow * jcp.oc_block); - if (is_dilated) { - inc(reg_tmp); - cmp(reg_tmp, dilate_h); - jl(oh_dilate_label_shift, T_NEAR); - // unshift input as new kernel element enters - sub(reg_input, jcp.typesize_in * (dilate_h - 1) * iw * inp_mult); - xor_(reg_tmp, reg_tmp); - } - // kernel overlap only changes when (t_pad + oj) % dilate_h == 0 - sub(reg_kernel, jcp.typesize_out * stride_h * jcp.kw - * jcp.ic_block * jcp.oc_block); - add(reg_kh, stride_h); - if (is_dilated) { - jmp(oh_dilate_label_noshift, T_NEAR); - L(oh_dilate_label_shift); - // shift input as old kernel element progresses - add(reg_input, jcp.typesize_in * stride_h * iw * inp_mult); - L(oh_dilate_label_noshift); - } - inc(reg_oj); - add(reg_ih_count, stride_h); - - // final number of kernel elements that overlap with input - const int final_inp_ker_overlap - = nstl::min(jcp.kh, div_up(jcp.ih, dilate_h)); - cmp(reg_kh, final_inp_ker_overlap); - jl(oh_tpad_label, T_NEAR); - } - } - // need second loop to process kernel if it is larger than the input - // (does not apply to dilations as they must have unit stride) - if (kh_range >= jcp.ih + (t_pad % stride_h == 0 ? stride_h : - t_pad % stride_h)) { - assert(!is_dilated); - mov(reg_kh, jcp.ih); - L(oh_tpad_tail_label); { - compute_oh_step_disp(); - add(reg_output, jcp.typesize_in * ow * jcp.oc_block); - sub(reg_kernel, jcp.typesize_out * stride_h * jcp.kw - * jcp.ic_block * jcp.oc_block); - - inc(reg_oj); - add(reg_ih_count, stride_h); - - cmp(reg_ih_count, nstl::min(t_pad, jcp.oh * stride_h)); - jl(oh_tpad_tail_label, T_NEAR); - } - } - // correct any excess shifts to kernel and input - // (does not apply to dilations as they must have unit stride, - // kernel must fit inside input, and padding is smaller than input) - if (t_pad <= jcp.oh * stride_h) { - // kernel has moved beyond padding (adjust for stride effects) - if (t_pad % stride_h != 0) { - assert(!is_dilated); - int inp_corr = stride_h - t_pad % stride_h; - add(reg_kernel, jcp.typesize_out * inp_corr * jcp.kw - * jcp.ic_block * jcp.oc_block); - add(reg_input, jcp.typesize_in * inp_corr * iw * inp_mult); - } - } else { - // kernel still overlaps padding (complete reset) - assert(!is_dilated); - sub(reg_kernel, jcp.typesize_out * (t_pad - jcp.oh * stride_h) - * jcp.kw * jcp.ic_block * jcp.oc_block); - } - } - - cmp(reg_ih_count, jcp.ihp - b_pad - (jcp.kh - 1) * dilate_h); - jge(oh_label_end, T_NEAR); - cmp(reg_oj, jcp.oh); - jge(oh_label, T_NEAR); - - /* Compute middle block(s) */ - mov(reg_kh, jcp.kh); - L(oh_label); { - compute_oh_step_disp(); - add(reg_input, jcp.typesize_in * stride_h * iw * inp_mult); - add(reg_output, jcp.typesize_in * ow * jcp.oc_block); - - inc(reg_oj); - add(reg_ih_count, stride_h); - - cmp(reg_ih_count, jcp.ihp - b_pad - (jcp.kh - 1) * dilate_h); - jge(oh_label_end, T_NEAR); - - cmp(reg_oj, jcp.oh); - jl(oh_label, T_NEAR); - } - L(oh_label_end); - - /* Compute bottom edge */ - if (b_pad > 0) { - cmp(reg_oj, jcp.oh); - jge(oh_bpad_label_end, T_NEAR); - - if (is_dilated) { - mov(reg_kh, jcp.kh - 1); // assumes unit stride for dilations - mov(reg_tmp, 0); - } else { - mov(reg_kh, jcp.ihp - b_pad); - sub(reg_kh, reg_ih_count); - } - L(oh_bpad_label); - { - compute_oh_step_disp(); - add(reg_input, jcp.typesize_in * stride_h * iw * inp_mult); - add(reg_output, jcp.typesize_in * ow * jcp.oc_block); - if 
(is_dilated) { - inc(reg_tmp); - cmp(reg_tmp, dilate_h); - jl(oh_dilate_label_end, T_NEAR); - xor_(reg_tmp, reg_tmp); - } - sub(reg_kh, stride_h); - cmp(reg_kh, 0); - jle(oh_bpad_label_end, T_NEAR); - if (is_dilated) - L(oh_dilate_label_end); - - inc(reg_oj); - cmp(reg_oj, jcp.oh); - jl(oh_bpad_label, T_NEAR); - } - L(oh_bpad_label_end); - } -} - -void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_d_loop_common() { - int ic_block = jcp.ic_block; - int oc_block = jcp.oc_block; - const int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block; - int iw = jcp.ver == ver_4fma ? jcp.tr_iw : jcp.iw; - int ow = jcp.ow; - const int input_backpad_overlap - = div_up(jcp.id + jcp.f_pad - (jcp.kd - 1), jcp.stride_d); - - const size_t filter_shift - = jcp.typesize_out * jcp.kh * jcp.kw * ic_block * oc_block; - const size_t input_shift = jcp.typesize_in * jcp.ih * iw * inp_mult; - const size_t output_shift = jcp.typesize_in * jcp.oh * ow * jcp.oc_block; - - Label d_loop_label, loop_end_label, common_block_label, fpad_end_label, - backpad_end_label, backpad_label; - - if (jcp.with_bias) bias_kernel(); - - /* initially offset 'kd' by f_pad */ - add(reg_kernel, ptr[param + GET_OFF(kd_offset)]); - - mov(reg_input_d, ptr[param + GET_OFF(src)]); - mov(reg_output_d, ptr[param + GET_OFF(dst)]); - mov(reg_d_index, ptr[param + GET_OFF(d_index)]); - mov(reg_kd_count, ptr[param + GET_OFF(kd_padding)]); - - cmp(reg_d_index, ptr[param + GET_OFF(d_worksize)]); - jge(loop_end_label, T_NEAR); - - L(d_loop_label); - - mov(reg_input, reg_input_d); - mov(reg_output, reg_output_d); - - push(reg_input_d); - push(reg_output_d); - push(reg_d_index); - - compute_oh_loop_common(); - - pop(reg_d_index); - pop(reg_output_d); - pop(reg_input_d); - - /* Compute 'front' edge */ - if (jcp.f_pad > 0) { - - /* Check if within fpad region */ - cmp(reg_d_index, div_up(jcp.f_pad, jcp.stride_d)); - jge(fpad_end_label, T_NEAR); - - /* Fpad steps */ - sub(reg_kernel, filter_shift * jcp.stride_d); - add(reg_kd_count, jcp.stride_d); - - /* Final number of kernel elements that overlap with input */ - const int inp_ker_overlap = nstl::min(jcp.kd, jcp.id); - cmp(reg_kd_count, inp_ker_overlap); - jl(common_block_label, T_NEAR); - - /* Correct any excess shifts to kernel and input */ - if (jcp.f_pad <= jcp.od * jcp.stride_d) { - /* Filter has moved beyond padding (adjust for stride effects) */ - if (jcp.f_pad % jcp.stride_d != 0) { - int inp_corr = jcp.stride_d - jcp.f_pad % jcp.stride_d; - add(reg_kernel, filter_shift * inp_corr); - add(reg_input_d, input_shift * inp_corr); - } - } else { - /* Filter still overlaps padding (complete reset) */ - sub(reg_kernel, (jcp.f_pad - jcp.od * jcp.stride_d) * filter_shift); - } - - /* Apply correction */ - mov(reg_kd_count, jcp.kd); - jmp(common_block_label); - - L(fpad_end_label); - } - - /* Compute bottom edge */ - if (jcp.back_pad > 0) { - - /* Check if within back_pad region */ - cmp(reg_d_index, input_backpad_overlap - 1); - jl(backpad_end_label, T_NEAR); - jg(backpad_label, T_NEAR); - - /* Execute overlap correction between the filter and the initial - * back_pad region. 
*/ - mov(reg_kd_count, - jcp.id + jcp.f_pad - input_backpad_overlap * jcp.stride_d); - jmp(backpad_end_label, T_NEAR); - - L(backpad_label); - sub(reg_kd_count, jcp.stride_d); - cmp(reg_kd_count, 0); - jle(loop_end_label, T_NEAR); - - L(backpad_end_label); - } - - /* Compute middle block */ - add(reg_input_d, input_shift * jcp.stride_d); - - /* Execute common block and loop */ - L(common_block_label); - add(reg_output_d, output_shift); - inc(reg_d_index); - cmp(reg_d_index, ptr[param + GET_OFF(d_worksize)]); - jl(d_loop_label, T_NEAR); - - L(loop_end_label); -} - -bool jit_avx512_common_conv_bwd_weights_kernel_f32::compute_full_spat_loop() { - // FIXME: use register mapping from the class declaration - bool ok = jcp.ver == ver_4fma - && everyone_is(0, jcp.dilate_h, jcp.dilate_w) - && everyone_is(1, jcp.stride_h, jcp.stride_w); - if (!ok) return false; - if (jcp.l_pad != jcp.kw / 2 || jcp.t_pad != jcp.kh / 2) - return false; - - // General code layout: - // - // Blocking over OH -- top level - // (Reduces L2 pressure; not very useful right now) - // Loop over all KHxKW kernel -- emit_kh_kw_loop() - // Loop over OH block -- emit_h_loop() - // Loop over OW blocks -- emit_fma_block() - // (Supports both fully unrolled and partially unrolled versions to - // reduce code size) - // Loop over OW block -- emit_fma_step() - - int max_working_set_size = 128 * 1024; - int pad_ow = jcp.ow; - - int inp_row_size = jcp.ic_block * jcp.tr_iw * jcp.typesize_in; - int out_row_size = jcp.oc_block * pad_ow * jcp.typesize_in; - int row_size = inp_row_size + out_row_size; - - int h_block_size = jcp.oh; - int working_set_size = row_size * h_block_size; - - if (working_set_size > max_working_set_size) { - int opt_working_set_size = 48 * 1024; - assert(opt_working_set_size < max_working_set_size); - - while (working_set_size > opt_working_set_size) { - for (int i = 2; i <= h_block_size; i++) - if (i == h_block_size) - h_block_size = h_block_size / 2; - else if (h_block_size % i == 0) { - h_block_size = h_block_size / i; - break; - } - working_set_size = row_size * h_block_size; - - if (h_block_size == 1 && working_set_size > opt_working_set_size) - return false; - } - } - - // NB1: t_pad <= oh_block_size and b_pad <= last_oh_block_size (see below) - if (h_block_size < nstl::max(1, jcp.t_pad) - || jcp.b_pad > (jcp.oh % h_block_size == 0 ? 
h_block_size - : jcp.oh % h_block_size)) - return false; - - // check that we can use simple arithmetic for prefetch address - // calculations - // TODO: we need some traits for this check (Roma) - int cache_line_size = 64; - assert(jcp.ic_block * typesize == 64); - assert(jcp.oc_block * typesize == 64); - - int num_inp_l2_pfs = jcp.tr_iw * h_block_size; - int avg_h_loop_len = h_block_size; - int num_inp_l2_pfs_per_fma_block - = div_up(num_inp_l2_pfs, avg_h_loop_len * jcp.kw * jcp.kh); - int num_out_l2_pfs = pad_ow * h_block_size; - int num_out_l2_pfs_per_fma_block - = div_up(num_out_l2_pfs, avg_h_loop_len * jcp.kw * jcp.kh); - - Opmask reg_h_block = k1; // 32-bit only on Intel(R) Xeon Phi(TM) processors - Reg64 reg_kh = rax; - Reg64 reg_kw = rbx; - Reg64 reg_tmp = abi_not_param1; - Reg32 reg_tmp_w = reg_tmp.cvt32(); - Reg64 reg_ohs = rdx; - Reg64 reg_ihs = rsi; - Reg64 reg_h = r8; - Reg64 reg_i = r9; - Reg64 reg_j = r10; - - Reg64 reg_inp = r13; - Reg64 reg_out = r14; - Reg64 reg_ker = r15; - - Reg64 reg_inp_pf_l1 = rbp; - - Reg64 reg_inp_pf_l2 = r11; - Reg64 reg_out_pf_l2 = r12; - - Xmm reg_inp_pf_save = xmm17; - Xmm reg_out_pf_save = xmm18; - - Reg64 reg_inp_save = abi_param1; - Reg64 reg_out_save = reg_tmp; - - auto zmm_out = [&](int oi) { return Zmm(24 + oi % 8); }; - auto zmm_ker = [&](int ic1) { return Zmm(ic1); }; - auto inp_addr = [&](int oi, int ic1) { - return ptr[reg_inp + (ic1 * jcp.tr_iw + oi) * jcp.typesize_in]; - }; - auto out_addr = [&](int oi, int oj = 0) { - assert(jcp.ver == ver_4fma); - return ptr[reg_out - + ((oi + oj * jcp.ow) * jcp.oc_block) * jcp.typesize_in]; - }; - auto ker_addr = [&](int ic1) { - return ptr[reg_ker + ic1 * jcp.oc_block * jcp.typesize_out]; - }; - - auto emit_block = [&](int h_block_size, - bool is_last_block, bool is_last_kh_kw_iter, bool is_last_row) - { - // TODO: add an fma version (Roma) - auto pad_ow = jcp.ow; - - int ow4u = rnd_up(pad_ow, 4); - int def_step_size = 16; - - bool has_w_tail = (pad_ow % def_step_size != 0 - || pad_ow % 4 != 0); - bool full_w_unroll = pad_ow / def_step_size < 2 + has_w_tail; - - auto emit_step = [&](int ur_ow, - int num_inp_l1_pfs_per_fma_step, - int num_inp_l2_pfs_per_fma_step, - int num_out_l2_pfs_per_fma_step, bool is_w_tail) - { - bool block_wraparound = is_w_tail && is_last_row; - - assert(ur_ow % 4 == 0); - int tail_size = ow4u % ur_ow; - int this_ur_ow - = (is_w_tail && tail_size) ? tail_size : ur_ow; - int ow_last_chunk4 = pad_ow % 4; - int ow_zero_tail4 = ow_last_chunk4 - ? 4 - ow_last_chunk4 : 0; - - auto emit_out_pf = [&](int oi) { -#if 1 - if (oi + def_step_size < ur_ow || !block_wraparound) - mic_prefetcht0(ptr[reg_out - + ((def_step_size + oi) - * jcp.oc_block * jcp.typesize_in)]); - else { - assert(block_wraparound); - assert(oi + def_step_size >= ur_ow); - mic_prefetcht0(ptr[reg_out_save - + ((oi + def_step_size - ur_ow) - * jcp.oc_block * jcp.typesize_in)]); - } -#else - // XXX: This is an alternative prefetching strategy that - // always prefetches the next row. 
Keeping it here for - // future experiments (Roma) - if (!block_wraparound) - mic_prefetcht0(ptr[reg_out - + (jcp.ow + oi) * jcp.oc_block * jcp.typesize_in]); - else - mic_prefetcht0(ptr[reg_out + reg_ohs - - ((h_block_size - 1) * jcp.ow - - oi) * jcp.oc_block * jcp.typesize_in]); -#endif - if (oi < num_out_l2_pfs_per_fma_step) - mic_prefetcht1(ptr[reg_out_pf_l2 - + oi * jcp.oc_block * jcp.typesize_in]); - }; - - auto emit_inp_pf = [&](int oi4, int ic1) { - int pf_slot_idx = ic1 + oi4 / 4 * jcp.ic_block; - int num_pf_slots = jcp.ic_block * ur_ow / 4; - - int num_pfs = num_inp_l1_pfs_per_fma_step - + num_inp_l2_pfs_per_fma_step; - int pf_freq = nstl::max(1, num_pf_slots / num_pfs); - - if (pf_slot_idx % pf_freq) - return; - - int pf_idx = pf_slot_idx / pf_freq; - - if (pf_idx < num_inp_l2_pfs_per_fma_step) - mic_prefetcht1(ptr[reg_inp_pf_l2 - + pf_idx * jcp.ic_block * jcp.typesize_in]); - else { - pf_idx -= num_inp_l2_pfs_per_fma_step; - // prefetch the 'tail' of the cache line because most of - // the accesses are not aligned - mic_prefetcht0(ptr[reg_inp_pf_l1 - + pf_idx * jcp.ic_block * jcp.typesize_in - + cache_line_size - jcp.typesize_in]); - } - }; - - auto numloads = 4; - - int steps = this_ur_ow; - for (int oi4 = 0; oi4 < steps; oi4 += numloads) { - for (int oi1 = 0; oi1 < numloads; oi1++) { - int oi = oi4 + oi1; - if (!is_w_tail || oi < (this_ur_ow - ow_zero_tail4)) { - vmovups(zmm_out(oi), out_addr(oi)); - emit_out_pf(oi); - } else { - auto zmm = zmm_out(oi); - vpxord(zmm, zmm, zmm); - } - } - - for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) { - if (jcp.ver == ver_4fma) { - v4fmaddps(zmm_ker(ic1), - zmm_out(oi4), inp_addr(oi4, ic1)); - } else { - assert(!"unknown convolution version"); - } - emit_inp_pf(oi4, ic1); - } - } - }; - - // Input is transposed and padded but we only access about jcp.iw - // elements so use that to compute the # of cache lines in each 'row' - int num_inp_l1_pfs - = div_up(jcp.iw * jcp.typesize_in, cache_line_size) * jcp.ic_block; - - if (full_w_unroll) { - emit_step(ow4u, num_inp_l1_pfs, - num_inp_l2_pfs_per_fma_block, - num_out_l2_pfs_per_fma_block, true); - add(reg_inp_pf_l2, num_inp_l2_pfs_per_fma_block * cache_line_size); - add(reg_out_pf_l2, num_out_l2_pfs_per_fma_block * cache_line_size); - } else { - Label w_loop; - int num_w_iters = pad_ow / def_step_size; - int num_w_iters_full = num_w_iters + has_w_tail; - int num_inp_l1_pfs_per_fma_step - = div_up(num_inp_l1_pfs, num_w_iters_full); - int num_inp_l2_pfs_per_fma_step - = div_up(num_inp_l2_pfs_per_fma_block, num_w_iters_full); - int num_out_l2_pfs_per_fma_step - = div_up(num_out_l2_pfs_per_fma_block, num_w_iters_full); - mov(reg_i, num_w_iters); - L(w_loop); { - emit_step(def_step_size, num_inp_l1_pfs_per_fma_step, - num_inp_l2_pfs_per_fma_step, - num_out_l2_pfs_per_fma_step, false); - add(reg_inp, def_step_size * jcp.typesize_in); - add(reg_out, def_step_size * jcp.oc_block * jcp.typesize_in); - add(reg_inp_pf_l1, - num_inp_l1_pfs_per_fma_step * cache_line_size); - add(reg_inp_pf_l2, - num_inp_l2_pfs_per_fma_step * cache_line_size); - add(reg_out_pf_l2, - num_out_l2_pfs_per_fma_step * cache_line_size); - sub(reg_i, 1); - jnz(w_loop); - } - if (has_w_tail) { - emit_step(def_step_size, num_inp_l1_pfs_per_fma_step, - num_inp_l2_pfs_per_fma_step, - num_out_l2_pfs_per_fma_step, true); - add(reg_inp_pf_l2, - num_inp_l2_pfs_per_fma_step * cache_line_size); - add(reg_out_pf_l2, - num_out_l2_pfs_per_fma_step * cache_line_size); - } - // reset reg_inp and reg_out because emit_h_loop expects - // unmodified 
pointers - int w_offset = num_w_iters * def_step_size; - sub(reg_inp, w_offset * jcp.typesize_in); - sub(reg_out, w_offset * jcp.oc_block * jcp.typesize_in); - } - }; - - auto emit_h_loop = [&](int h_block_size, - bool is_last_block, bool is_last_kh_kw_iter) - { - Label h_loop, skip_h_loop; - mov(reg_j, 1); - cmp(reg_j, reg_h); - je(skip_h_loop, T_NEAR); - L(h_loop); { - - lea(reg_inp_pf_l1, - ptr[reg_inp + jcp.tr_iw * jcp.ic_block * jcp.typesize_in]); - emit_block(h_block_size, - is_last_block, is_last_kh_kw_iter, false); - - add(reg_inp, jcp.tr_iw * jcp.ic_block * jcp.typesize_in); - add(reg_out, pad_ow * jcp.oc_block * jcp.typesize_in); - add(reg_j, 1); - cmp(reg_j, reg_h); - jb(h_loop); - } - - L(skip_h_loop); - - for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) - mic_prefetcht0(ker_addr(ic1)); - - lea(reg_inp_pf_l1, ptr[reg_inp_save + reg_kw * jcp.typesize_in]); - emit_block(h_block_size, is_last_block, is_last_kh_kw_iter, true); - }; - - auto emit_kh_kw_loop = [&](bool is_first_block, bool is_last_block, - int h_block_size) - { - xor_(reg_kh, reg_kh); - Label kh_loop, kh_loop_end; - - int last_oh_block_size - = jcp.oh - rnd_up(jcp.oh - h_block_size, h_block_size); - int oh_block_size = (is_last_block) ? last_oh_block_size : h_block_size; - // NB1: t_pad <= oh_block_size and b_pad <= last_oh_block_size - int ih_block_size = oh_block_size - 1 + jcp.kh - - is_first_block * jcp.t_pad - is_last_block * jcp.b_pad; - - L(kh_loop); { - // determine starting indices for this block - if (is_first_block) { - xor_(reg_tmp, reg_tmp); - mov(reg_ohs, jcp.t_pad); - sub(reg_ohs, reg_kh); - cmovb(reg_ohs, reg_tmp); - - mov(reg_ihs, reg_ohs); - sub(reg_ihs, jcp.t_pad); - add(reg_ihs, reg_kh); - } else { - xor_(reg_ohs, reg_ohs); - mov(reg_ihs, reg_kh); - } - - // determine effective size of block based on padding - mov(reg_tmp, oh_block_size); - sub(reg_tmp, reg_ohs); - mov(reg_h, ih_block_size); - sub(reg_h, reg_ihs); - cmp(reg_tmp, reg_h); - cmovb(reg_h, reg_tmp); - - Label kh_loop_work; - cmp(reg_h, 0); - jg(kh_loop_work, T_NEAR); - - // empty h loop for this jcp.kh: - // - set the output to 0 if necessary - // - move ker pt - // - jump to the end - sub(reg_h, 1); - Label skip_ker_zeroing; - - // The reg_ker ptr has highest bit set if the output needs to be - // zeroed. Those who have byte-aligned their data will suffer the - // consiquences :( - // TODO: move the flag to a mask register? 
(Roma) - test(reg_ker, 1); - jz(skip_ker_zeroing, T_NEAR); - - Label zeroing_loop; - vpxord(zmm0, zmm0, zmm0); - and_(reg_ker, ~1); // temporarily clear the zeroing flag - mov(reg_tmp, jcp.kw); - L(zeroing_loop); { - for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) - vmovups(ker_addr(ic1), zmm0); - add(reg_ker, jcp.oc_block * jcp.ic_block * jcp.typesize_out); - sub(reg_tmp, 1); - jnz(zeroing_loop, T_NEAR); - } - // restore the zeroing flag (it will be cleared after the end of - // emit_kh_kw_loop, but we may need it until then) - or_(reg_ker, 1); - jmp(kh_loop_end, T_NEAR); - - L(skip_ker_zeroing); - add(reg_ker, jcp.oc_block * jcp.ic_block * jcp.kw - * jcp.typesize_out); - jmp(kh_loop_end, T_NEAR); - - L(kh_loop_work); - - mul_by_const(reg_ihs, reg_tmp, - jcp.tr_iw * jcp.ic_block * jcp.typesize_in); - mul_by_const(reg_ohs, reg_tmp, - pad_ow * jcp.oc_block * jcp.typesize_in); - - add(reg_inp, reg_ihs); - add(reg_out, reg_ohs); - - Label kw_loop; - xor_(reg_kw, reg_kw); - L(kw_loop); { - for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) { - auto zmm = zmm_ker(ic1); - vpxord(zmm, zmm, zmm); - mic_prefetcht1(ker_addr(ic1)); - } - - mov(reg_out_save, reg_out); - mov(reg_inp_save, reg_inp); - lea(reg_inp, ptr[reg_inp + reg_kw * jcp.typesize_in]); - -#if 0 - // XXX: Generate code with special prefetches when switching - // blocks or at the end of the last block. Disabled to reduce - // code size and because there's no performance benefit (Roma) - Label regular_h_loop, end_h_loop; - cmp(reg_kw, jcp.kw - 1); - jne(regular_h_loop, T_NEAR); - cmp(reg_kh, jcp.kh - 1); - jne(regular_h_loop, T_NEAR); - - emit_h_loop(oh_block_size, is_last_block, true); - jmp(end_h_loop, T_NEAR); - - L(regular_h_loop); - emit_h_loop(oh_block_size, is_last_block, false); - - L(end_h_loop); -#else - emit_h_loop(oh_block_size, is_last_block, false); -#endif - - mov(reg_out, reg_out_save); - mov(reg_inp, reg_inp_save); - - Label do_store; - // The reg_ker ptr has highest bit set if the output needs to - // be zeroed. 
Those who have byte-aligned their data will - // suffer the consiquences :( - mov(reg_tmp, reg_ker); - and_(reg_ker, ~1); - test(reg_tmp, 1); - jnz(do_store, T_NEAR); - - for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) { - auto zmm = zmm_ker(ic1); - if (jcp.ver == ver_4fma) { - vaddps(zmm, ker_addr(ic1)); - } else { - assert(!"unknown convolution version"); - } - } - - L(do_store); - for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) { - auto zmm = zmm_ker(ic1); - vmovups(ker_addr(ic1), zmm); - } - - mov(reg_ker, reg_tmp); - add(reg_ker, jcp.ic_block * jcp.oc_block * jcp.typesize_out); - add(reg_kw, 1); - cmp(reg_kw, jcp.kw); - jl(kw_loop); - } - - sub(reg_inp, reg_ihs); - sub(reg_out, reg_ohs); - - - L(kh_loop_end); - add(reg_kh, 1); - cmp(reg_kh, jcp.kh); - jl(kh_loop); - } - }; - - mov(reg_inp, ptr[param + GET_OFF(src)]); - mov(reg_out, ptr[param + GET_OFF(dst)]); - mov(reg_ker, ptr[param + GET_OFF(filt)]); - mov(reg_inp_pf_l2, ptr[param + GET_OFF(src_prf)]); - mov(reg_out_pf_l2, ptr[param + GET_OFF(dst_prf)]); - mov(reg_tmp, ptr[param + GET_OFF(channel)]); - or_(reg_ker, reg_tmp); - - bool single_kh_kw_loop = (h_block_size == jcp.oh); - - size_t inp_row_step = jcp.tr_iw * jcp.ic_block * jcp.typesize_in; - size_t first_inp_block_step = inp_row_step * (h_block_size - jcp.t_pad); - size_t inp_block_step = inp_row_step * h_block_size; - size_t out_block_step = pad_ow * jcp.oc_block * jcp.typesize_in - * h_block_size; - - if (!single_kh_kw_loop) { - // Save the original prefetch pointers from the OpenMP driver - vmovq(reg_inp_pf_save, reg_inp_pf_l2); - vmovq(reg_out_pf_save, reg_out_pf_l2); - mov(reg_inp_pf_l2, reg_inp); - add(reg_inp_pf_l2, first_inp_block_step); - mov(reg_out_pf_l2, reg_out); - add(reg_out_pf_l2, out_block_step); - } - emit_kh_kw_loop(true, single_kh_kw_loop, h_block_size); - - if (!single_kh_kw_loop) { - size_t ker_reset_offset - = jcp.oc_block * jcp.ic_block * jcp.typesize_out * jcp.kw * jcp.kh; - sub(reg_ker, ker_reset_offset); - and_(reg_ker, ~1); // Clear the zeroing flag for subsequent updates - - add(reg_inp, first_inp_block_step); - add(reg_out, out_block_step); - mov(reg_inp_pf_l2, reg_inp); - add(reg_inp_pf_l2, inp_block_step); - mov(reg_out_pf_l2, reg_out); - add(reg_out_pf_l2, out_block_step); - - int num_innermost_iters = div_up(jcp.oh, h_block_size) - 2; - if (num_innermost_iters > 0) { - Label h_block_loop; - - mov(reg_tmp_w, num_innermost_iters); - kmovw(reg_h_block, reg_tmp_w); - L(h_block_loop); { - emit_kh_kw_loop(false, false, h_block_size); - sub(reg_ker, ker_reset_offset); - add(reg_inp, inp_row_step * h_block_size); - add(reg_out, out_block_step); - mov(reg_inp_pf_l2, reg_inp); - add(reg_inp_pf_l2, inp_block_step); - mov(reg_out_pf_l2, reg_out); - add(reg_out_pf_l2, out_block_step); - kmovw(reg_tmp_w, reg_h_block); - sub(reg_tmp_w, 1); - kmovw(reg_h_block, reg_tmp_w); - jnz(h_block_loop); - } - } - - // Restore the original prefetch pointers that came from the OpenMP - // driver - vmovq(reg_inp_pf_l2, reg_inp_pf_save); - vmovq(reg_out_pf_l2, reg_out_pf_save); - emit_kh_kw_loop(false, true, h_block_size); - } - - return true; -} - -bool jit_avx512_common_conv_bwd_weights_kernel_f32 - ::flat_4ops_compute() { - const auto &j = jcp; - const bool ok = j.ver == ver_4fma && j.is_1stconv - && everyone_is(0, j.dilate_h, j.dilate_w); - if (!ok) return false; - - Reg64 reg_ptr_tr_src = r8; - Reg64 reg_ptr_dst = r9; - Reg64 reg_ptr_wei = r10; - Reg64 reg_ptr_bia = r11; - - Reg64 reg_kh_step = rax; - Reg64 reg_oh = abi_not_param1; - Reg64 reg_kh = rdx; - - Reg32 
reg_flag_save = ebx; - Reg32 reg_flag = esi; - - Zmm vbia(31); - - auto zmm_wei = [&](int kh, int kw) { - return Zmm(8 + kh * j.kw + kw); - }; - auto zmm_dst = [&](int ow) { - return Zmm(ow % 8); - }; - - auto addr_tr_src = [&](int kh, int iw) { - return ptr[reg_ptr_tr_src - + (kh * j.stride_w * j.tr_ld + iw) * jcp.typesize_in]; - }; - auto addr_dst = [&](int ow) { - return ptr[reg_ptr_dst + ow * jcp.oc_block * jcp.typesize_in]; - }; - auto addr_wei = [&](int kh, int kw) { - return ptr[reg_ptr_wei + (kh * j.kw + kw) * j.oc_block - * jcp.typesize_out]; - }; - - auto emit_fma_block = [&](int kh_step) { - for (int kh = 0; kh < kh_step; ++kh) { - for (int kw = 0; kw < j.kw; ++kw) { - auto vwei = zmm_wei(kh, kw); - vpxord(vwei, vwei, vwei); - } - } - - for (int ow = 0; ow < j.ow; ow += 4) { - for (int _ow = ow; _ow < ow + 4; ++_ow) { - auto vdst = zmm_dst(_ow); - if (_ow < j.ow) - vmovups(vdst, addr_dst(_ow)); - else - vpxord(vdst, vdst, vdst); - } - - for (int kh = 0; kh < kh_step; ++kh) { - for (int kw = 0; kw < j.kw; ++kw) { - const int iw = ow + (kw % j.stride_w) * j.tr_ld - + (kw / j.stride_w); - v4fmaddps(zmm_wei(kh, kw), zmm_dst(ow), - addr_tr_src(kh, iw)); - if (1 && kh == 0 && kw < 4) { - prefetcht1(ptr[reg_ptr_dst - + (j.ow + ow + kw) * jcp.oc_block - * jcp.typesize_in]); - } - if (j.with_bias && kh_step == 1) { /* [bwd_w:b:r1] */ - const int off = kw + 4 - j.kw; - if (off >= 0 && ow + off < j.ow) - vaddps(vbia, vbia, zmm_dst(ow + off)); - } - } - } - } - - Label l_store; - test(reg_flag, FLAG_MB_FIRST); - jnz(l_store, T_NEAR); - for (int kh = 0; kh < kh_step; ++kh) { - for (int kw = 0; kw < j.kw; ++kw) - vaddps(zmm_wei(kh, kw), addr_wei(kh, kw)); - } - L(l_store); - for (int kh = 0; kh < kh_step; ++kh) { - for (int kw = 0; kw < j.kw; ++kw) - vmovups(addr_wei(kh, kw), zmm_wei(kh, kw)); - } - }; - - auto emit_kh_loop = [&]() { - const int kh_step_rem = j.kh % j.kh_step; - xor_(reg_kh, reg_kh); - mov(reg_kh_step, j.kh_step); - - Label l_kh_loop; - L(l_kh_loop); { - Label l_done; - - if (kh_step_rem != 0) { - Label l_keep_kh_step; - cmp(reg_kh, j.kh - j.kh_step); - jle(l_keep_kh_step, T_NEAR); - - mov(reg_kh_step, kh_step_rem); - emit_fma_block(kh_step_rem); - jmp(l_done, T_NEAR); - - L(l_keep_kh_step); - } - - emit_fma_block(j.kh_step); - - L(l_done); - - add(reg_ptr_tr_src, j.kh_step * j.stride_w * j.tr_ld - * jcp.typesize_in); - add(reg_ptr_wei, j.kh_step * j.kw * j.oc_block * jcp.typesize_out); - add(reg_kh, j.kh_step); - - cmp(reg_kh, j.kh); - jl(l_kh_loop, T_NEAR); - } - - const int kh_steps = rnd_up(j.kh, j.kh_step); - sub(reg_ptr_tr_src, kh_steps * j.stride_w * j.tr_ld * jcp.typesize_in); - sub(reg_ptr_wei, kh_steps * j.kw * j.oc_block * jcp.typesize_out); - }; - - auto emit_oh_loop = [&]() { - mov(reg_oh, j.oh); - - Label l_oh_loop; - L(l_oh_loop); { - Label l_restore_mb_flag, l_jump; - - cmp(reg_oh, j.oh); - je(l_restore_mb_flag, T_NEAR); - - and_(reg_flag, ~FLAG_MB_FIRST); - jmp(l_jump, T_NEAR); - - L(l_restore_mb_flag); - mov(reg_flag, reg_flag_save); - - L(l_jump); - - emit_kh_loop(); - - add(reg_ptr_tr_src, j.stride_h * j.stride_w * j.tr_ld - * jcp.typesize_in); - add(reg_ptr_dst, j.ow * j.oc_block * jcp.typesize_in); - - dec(reg_oh); - jnz(l_oh_loop, T_NEAR); - } - }; - - auto emit_bia_store = [&]() { - if (!j.with_bias) return; - - Label l_bia_store, l_bia_skip; - test(reg_flag, FLAG_IC_FIRST); - jz(l_bia_skip); - - test(reg_flag, FLAG_MB_FIRST); - jnz(l_bia_store, T_NEAR); - vaddps(vbia, ptr[reg_ptr_bia]); - L(l_bia_store); - vmovups(ptr[reg_ptr_bia], vbia); - 
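The flat-4ops loop above addresses the transposed source through addr_tr_src, whose index splits the kernel column into a stride phase and a column offset: iw = ow + (kw % stride_w) * tr_ld + kw / stride_w. Below is a small standalone sketch of just that arithmetic; stride_w = 4 matches the requirement this path checks for later in init_conf, while tr_ld is an arbitrary illustrative value, not one taken from a real shape.

#include <cstdio>

int main() {
    // Assumed values for illustration only; tr_ld is the padded length of one
    // stride phase of a transposed input row (computed in init_conf below).
    const int stride_w = 4, tr_ld = 16;

    for (int ow = 0; ow < 2; ++ow)
        for (int kw = 0; kw < 7; ++kw) {
            // Same formula as addr_tr_src: phase-major layout of the row.
            const int iw = ow + (kw % stride_w) * tr_ld + (kw / stride_w);
            std::printf("ow=%d kw=%d -> tr_src column %2d (phase %d, offset %d)\n",
                    ow, kw, iw, kw % stride_w, ow + kw / stride_w);
        }
    return 0;
}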
L(l_bia_skip); - }; - - mov(reg_ptr_tr_src, ptr[param + GET_OFF(src)]); - mov(reg_ptr_dst, ptr[param + GET_OFF(dst)]); - mov(reg_ptr_wei, ptr[param + GET_OFF(filt)]); - mov(reg_ptr_bia, ptr[param + GET_OFF(bias)]); - mov(reg_flag_save, ptr[param + GET_OFF(flags)]); - - vpxord(vbia, vbia, vbia); - emit_oh_loop(); - emit_bia_store(); - - return true; -} - -void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_loop() -{ - if (flat_4ops_compute()) - return; - if (compute_full_spat_loop()) - return; - - maybe_zero_kernel(); - - if (jcp.ndims == 5) compute_d_loop_common(); - else compute_oh_loop_common(); -} - -void jit_avx512_common_conv_bwd_weights_kernel_f32::generate() -{ - preamble(); - - mov(reg_input, ptr[param + GET_OFF(src)]); - mov(reg_output, ptr[param + GET_OFF(dst)]); - mov(reg_kernel, ptr[param + GET_OFF(filt)]); - - compute_loop(); - - postamble(); -} - -status_t jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf( - jit_conv_conf_t &jcp, const convolution_desc_t &cd, - memory_desc_t &src_md, memory_desc_t &diff_weights_md, - memory_desc_t &diff_bias_md, memory_desc_t &diff_dst_md) { - if (!mayiuse(avx512_common)) - return status::unimplemented; - - const memory_desc_wrapper src_d(&src_md); - const memory_desc_wrapper diff_weights_d(&diff_weights_md); - const memory_desc_wrapper diff_bias_d(&diff_bias_md); - const memory_desc_wrapper diff_dst_d(&diff_dst_md); - - const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1; - int ndims = src_d.ndims(); - - jcp = zero(); - - jcp.simd_w = cpu_isa_traits::vlen / sizeof(float); - jcp.ndims = ndims; - jcp.prop_kind = cd.prop_kind; - - jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1; - jcp.mb = src_d.dims()[0]; - - jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups; - jcp.oc_without_padding = jcp.oc; - jcp.ic = src_d.dims()[1] / jcp.ngroups; - - jcp.id = (ndims == 5) ? src_d.dims()[2] : 1; - jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims-2]; - jcp.iw = src_d.dims()[ndims-1]; - jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1; - jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims-2]; - jcp.ow = diff_dst_d.dims()[ndims-1]; - - jcp.kd = (ndims == 5) ? diff_weights_d.dims()[with_groups + 2] : 1; - jcp.kh = (ndims == 3) ? 1 : diff_weights_d.dims()[with_groups + ndims-2]; - jcp.kw = diff_weights_d.dims()[with_groups + ndims-1]; - - jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0; - jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4]; - jcp.l_pad = cd.padding[0][ndims-3]; - - jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1; - jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4]; - jcp.stride_w = cd.strides[ndims-3]; - - jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0; - jcp.dilate_h = (ndims == 3) ? 
0 : cd.dilates[ndims-4]; - jcp.dilate_w = cd.dilates[ndims-3]; - - const int kh_range = 1 + (jcp.kh - 1) * (jcp.dilate_h + 1); - bool ok = true - // general condition to simplify dilations - && IMPLICATION(jcp.dilate_d != 0, jcp.stride_d == 1) - && IMPLICATION(jcp.dilate_h != 0, jcp.stride_h == 1) - && IMPLICATION(jcp.dilate_w != 0, jcp.stride_w == 1) - // special condition to simplify dilations in compute_oh_loop_common - && IMPLICATION(jcp.dilate_h != 0, kh_range <= jcp.ih); - if (!ok) - return status::unimplemented; - - jcp.r_pad = nstl::max(0, (jcp.ow - 1) * jcp.stride_w - + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1)); - jcp.b_pad = nstl::max(0, (jcp.oh - 1) * jcp.stride_h - + (jcp.kh - 1) * (jcp.dilate_h + 1) - (jcp.ih + jcp.t_pad - 1)); - jcp.back_pad = nstl::max(0, (jcp.od - 1) * jcp.stride_d - + (jcp.kd - 1) * (jcp.dilate_d + 1) - (jcp.id + jcp.f_pad - 1)); - - /* XXX: currently, does not support dilation_d > 0 */ - if (ndims == 5) - if (jcp.dilate_d > 0) - return status::unimplemented; - - jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad; - jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad; - jcp.ohp = jcp.oh; - jcp.owp = jcp.ow; - jcp.aligned_threads = 0; - - /* check for the 1st convolution */ - jcp.is_1stconv = is_1stconv(jcp); - - jcp.oc_block = jcp.simd_w; - - bool ok_to_pad_channels = true - && jcp.ngroups == 1 - && src_d.data_type() == data_type::f32; - - if (ok_to_pad_channels) - jcp.oc = rnd_up(jcp.oc, jcp.simd_w); - - if (jcp.oc % jcp.oc_block) - return status::unimplemented; - - auto dst_tag = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c); - auto wei_tag = with_groups - ? pick(ndims - 3, gOIw16i16o, gOIhw16i16o, gOIdhw16i16o) - : pick(ndims - 3, OIw16i16o, OIhw16i16o, OIdhw16i16o); - - if (diff_dst_d.format_kind() == format_kind::any) { - CHECK(memory_desc_init_by_tag(diff_dst_md, dst_tag)); - jcp.dst_tag = dst_tag; - } else { - jcp.dst_tag = diff_dst_d.matches_one_of_tag(dst_tag); - } - if (jcp.dst_tag != dst_tag) - return status::unimplemented; - - /* conditions on bias memory */ - jcp.with_bias = cd.diff_bias_desc.format_kind != format_kind::undef; - if (jcp.with_bias) { - if (diff_bias_d.format_kind() == format_kind::any) - CHECK(memory_desc_init_by_tag(diff_bias_md, x)); - } - - jcp.nb_oc = jcp.oc / jcp.oc_block; - - /* kernel applicability check wrt boundaries - * the conditions are quite general across the kernels we have, - * but ideally the check should belong to a specific kernel... 
*/ - const int max_pad = ((jcp.kh - 1) * (jcp.dilate_h + 1) + 1) / 2; - const bool boundaries_ok = true - && jcp.t_pad <= max_pad - && jcp.b_pad <= max_pad - && IMPLICATION(jcp.f_pad > 0, jcp.kd < jcp.id + jcp.f_pad) - && jcp.f_pad < jcp.kd; - if (!boundaries_ok) - return status::unimplemented; - - /* yet another common check */ - if (jcp.kw > 14) - return status::unimplemented; - - /* setting register strategy */ - for (int ur_w = nstl::min(max_ur_w, jcp.ow); ur_w > 0; --ur_w) { - if (jcp.ow % ur_w == 0) { jcp.ur_w = ur_w; break; } - } - - if (jcp.is_1stconv) { - auto src_tag = pick(ndims - 3, ncw, nchw, ncdhw); - if (src_d.format_kind() == format_kind::any) { - CHECK(memory_desc_init_by_tag(src_md, src_tag)); - jcp.src_tag = src_tag; - } else { - jcp.src_tag = src_d.matches_one_of_tag(src_tag); - if (jcp.ic == 1 && jcp.src_tag != src_tag) - jcp.src_tag = src_d.matches_one_of_tag( - pick(ndims - 3, nwc, nhwc, ndhwc)); - } - if (jcp.src_tag == format_tag::undef) - return status::unimplemented; - - const bool src_ok = true - && utils::everyone_is(data_type::f32, - src_d.data_type(), diff_weights_d.data_type(), - diff_dst_d.data_type()) - && one_of(jcp.ic, 1, 2, 3) - && jcp.ngroups == 1; - if (!src_ok) - return status::unimplemented; - - const int tr_ld = rnd_up(div_up(jcp.iw + jcp.l_pad + jcp.r_pad, - jcp.stride_w), 16); - const int kh_step = nstl::max((28 - jcp.with_bias) / jcp.kw, 1); - const int kh_step_rem = jcp.kh % kh_step; - - const auto wei_4fma_tag = with_groups - ? pick(ndims - 3, gOiw16o, gOihw16o, gOidhw16o) - : pick(ndims - 3, Oiw16o, Oihw16o, Oidhw16o); - - auto current_wei_tag = format_tag::undef; - if (diff_weights_d.format_kind() != format_kind::any) - current_wei_tag = diff_weights_d.matches_one_of_tag(wei_4fma_tag); - - const bool use_4fma = true - && one_of(ndims, 3, 4) - && mayiuse(avx512_mic_4ops) - && mkldnn_thr_syncable() - && everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w) - && everyone_is(0, jcp.l_pad, jcp.r_pad, jcp.t_pad, jcp.b_pad) - && jcp.kw <= 28 - jcp.with_bias - && jcp.stride_w == 4 - && tr_ld / jcp.simd_w <= 4 /* [bwd_w:tr_src:r1] */ - && IMPLICATION(jcp.with_bias, kh_step_rem == 1) /* [bwd_w:b:r1] */ - && IMPLICATION(diff_weights_d.format_kind() != format_kind::any, - current_wei_tag == wei_4fma_tag); - - if (use_4fma) { - jcp.ver = ver_4fma; - jcp.kh_step = kh_step; - jcp.tr_ld = tr_ld; - jcp.ic_block = 1; - if (diff_weights_d.format_kind() == format_kind::any) - CHECK(memory_desc_init_by_tag(diff_weights_md, wei_4fma_tag)); - jcp.wei_tag = wei_4fma_tag; - } else { - jcp.ver = ver_fma; - jcp.ic_block = jcp.ic; - - wei_tag = with_groups - ? 
pick(ndims - 3, gOwi16o, gOhwi16o, gOdhwi16o) - : pick(ndims - 3, Owi16o, Ohwi16o, Odhwi16o); - - if (diff_weights_d.format_kind() == format_kind::any) { - CHECK(memory_desc_init_by_tag(diff_weights_md, wei_tag)); - jcp.wei_tag = wei_tag; - } else { - jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag); - } - if (jcp.wei_tag != wei_tag) - return status::unimplemented; - } - - jcp.nb_ic = jcp.ic / jcp.ic_block; - } else { - auto src_tag = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c); - if (src_d.format_kind() == format_kind::any) { - CHECK(memory_desc_init_by_tag(src_md, src_tag)); - jcp.src_tag = src_tag; - } else { - jcp.src_tag = src_d.matches_one_of_tag(src_tag); - } - if (jcp.src_tag != src_tag) - return status::unimplemented; - - if (diff_weights_d.format_kind() == format_kind::any) { - CHECK(memory_desc_init_by_tag(diff_weights_md, wei_tag)); - jcp.wei_tag = wei_tag; - } else { - jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag); - } - if (jcp.wei_tag != wei_tag) - return status::unimplemented; - - jcp.ic_block = jcp.simd_w; - if (ok_to_pad_channels) - jcp.ic = rnd_up(jcp.ic, jcp.ic_block); - jcp.nb_ic = jcp.ic / jcp.ic_block; - if ((mayiuse(avx512_mic) || mayiuse(avx512_core)) - && utils::everyone_is(data_type::f32, - src_d.data_type(), diff_weights_d.data_type(), - diff_dst_d.data_type())) { - jcp.ver = ver_fma; - if (one_of(ndims, 3, 4) && mayiuse(avx512_mic_4ops) && jcp.stride_w == 1 && - everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w) && - mkldnn_thr_syncable()) { - jcp.ver = ver_4fma; - } - } else { - return status::unimplemented; - } - if (jcp.ver == ver_4fma) { - jcp.ur_w = jcp.ow; - // XXX, BUGBUGBUG, but not a FIXME: this assumes that it's OK to - // cross the right boundary. The only requirement is not to have - // NaNs there because another multiplicand is always guaranteed to - // be zero. This also may require the top-level driver to allocate - // four extra guarding elements at the very end of the buffer. 
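Spelled out, that padding is a round-up of the widest row access to the 4fma step plus a few spare elements past the end of the buffer. The following standalone restatement uses an assumed shape; only the arithmetic mirrors the tr_iw computation that follows.

#include <cstdio>

// rnd_up as used throughout this file: round x up to a multiple of m.
static int rnd_up_int(int x, int m) { return (x + m - 1) / m * m; }

int main() {
    // Assumed shape, purely for illustration.
    const int iw = 13, kw = 3, tr_round = 4;

    const int logical_row = iw + kw - 1;                  // row span in the formula below
    const int tr_iw = rnd_up_int(logical_row, tr_round);  // padded transposed row length
    std::printf("logical row = %d, tr_iw = %d, guard elements <= %d\n",
            logical_row, tr_iw, tr_round);
    return 0;
}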
- // I'm not proud of this hack, but it improves performance by - // about 5-10% depending on the dimensions (Roma) - - const int tr_round = 4; - - jcp.tr_iw = rnd_up(jcp.iw + jcp.kw - 1, tr_round); - jcp.tr_src_num_guard_elems = tr_round; // upper bound - } - } - - if (utils::one_of(jcp.ver, ver_4fma, ver_fma)) { - jcp.typesize_in = sizeof(float); - jcp.typesize_out = sizeof(float); - } else - return status::unimplemented; - - bool args_ok = true - && jcp.ic % jcp.ic_block == 0 - && jcp.oc % jcp.oc_block == 0 - && jcp.ic <= src_d.padded_dims()[1] - && jcp.oc <= diff_dst_d.padded_dims()[1] - && jcp.ic <= diff_weights_d.padded_dims()[with_groups + 1] - && jcp.oc <= diff_weights_d.padded_dims()[with_groups + 0]; - if (!args_ok) return status::unimplemented; - - { // balancing - int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b; - balance(jcp, nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b); - jcp.nthr = nthr; - jcp.nthr_mb = nthr_mb; - jcp.nthr_g = nthr_g; - jcp.nthr_oc_b = nthr_oc_b; - jcp.nthr_ic_b = nthr_ic_b; - } - - return status::success; -} - -void jit_avx512_common_conv_bwd_weights_kernel_f32::init_scratchpad( - memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { - if (jcp.ver == ver_4fma) { - if (jcp.is_1stconv) { - const size_t tr_src_size = - jcp.nthr / jcp.nthr_oc_b * jcp.ih * jcp.stride_w * jcp.tr_ld; - scratchpad.book(key_conv_tr_src, jcp.typesize_in * tr_src_size); - } else { - // XXX: See the comment about tr_iw and guarding elements in - // jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf() - const size_t max_nthr = jcp.nthr_mb * jcp.ngroups * jcp.nb_ic; - const size_t min_tr_src_size_per_thr - = jcp.ih * jcp.ic_block * jcp.tr_iw; - const size_t tr_src_size = max_nthr * min_tr_src_size_per_thr - + jcp.tr_src_num_guard_elems; - scratchpad.book(key_conv_tr_src, jcp.typesize_in * tr_src_size); - } - - /* prepare synchronization contexts */ - if (jcp.nthr_oc_b > 1) { - const int tr_src_bctx_size = jcp.nthr / jcp.nthr_oc_b; - scratchpad.book(key_conv_tr_src_bctx, - sizeof(simple_barrier::ctx_t) * tr_src_bctx_size); - } - } - - if (jcp.nthr_mb > 1) { - const int wei_size = jcp.ngroups * jcp.oc * jcp.ic - * jcp.kh * jcp.kw * jcp.kd; - const int bia_size = jcp.ngroups * jcp.oc; - const size_t wei_bia_reduction_size = wei_size + bia_size; - - scratchpad.book(key_conv_wei_bia_reduction, - jcp.typesize_out * wei_bia_reduction_size * (jcp.nthr_mb - 1)); - scratchpad.book(key_conv_wei_bia_reduction_bctx, - sizeof(simple_barrier::ctx_t)); - } - - if (jcp.with_bias && jcp.oc != jcp.oc_without_padding) - scratchpad.book(key_conv_padded_bias, jcp.typesize_out * jcp.oc); -} - -void jit_avx512_common_conv_bwd_weights_kernel_f32::balance( - const jit_conv_conf_t &j, int &nthr_, int &nthr_mb_, int &nthr_g_, - int &nthr_oc_b_, int &nthr_ic_b_) -{ - nthr_ = nthr_mb_ = nthr_g_ = nthr_oc_b_ = nthr_ic_b_ = 1; - - const int max_threads = mkldnn_get_max_threads(); - - if (max_threads < j.ngroups) { - /* simplification... 
fortunately it doesn't hurt much */ - return; - } - - if (!mkldnn_thr_syncable() && j.ver == ver_4fma) { - // should not happen -- the driver is not ready - // for TBB-like non-synchronous threading yet - return; - } - - if (j.ver == ver_4fma && j.is_1stconv) { - nthr_g_ = 1; - nthr_oc_b_ = 1; - nthr_ic_b_ = nstl::min(j.nb_ic, max_threads); - nthr_mb_ = nstl::min(max_threads / nthr_ic_b_, j.mb); - nthr_ = nthr_mb_ * nthr_oc_b_ * nthr_ic_b_ * nthr_g_; - return; - } - - nthr_g_ = j.ngroups; - const int nthr = max_threads / nthr_g_; - - auto calc_mem_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) { - /* calculate per thread memory cost (read/write). high level optimizer - * tries to minimize memory consumption. few notes: - * (n1) unclear why, but that essentially helps first convolution... - * (n2) assuming the reduction over minibatch is always there: - * - instead of 8 it should be 5 here (write ~= 2 read): - * kernel: temporal workspace 1 write - * reduction: 1 read from workspace and 1 write to the diff_wei - * - but experiments showed 8 works better than 5 or 6... */ - - const int src_coef = j.ver == ver_4fma ? 4 : 1; - const int dst_coef = 1; - const int wei_coef = 8; - - return 0 - + src_coef - * div_up(j.mb, nthr_mb) * div_up(j.ngroups, nthr_g_) - * div_up(j.nb_ic, nthr_ic_b) * j.ic_block * j.ih * j.iw * j.id - / j.stride_d / j.stride_h / j.stride_w /* (n1) */ - + dst_coef - * div_up(j.mb, nthr_mb) * div_up(j.ngroups, nthr_g_) - * div_up(j.nb_oc, nthr_oc_b) * j.oc_block * j.oh * j.ow * j.od - + wei_coef /* (n2) */ - * div_up(j.ngroups, nthr_g_) - * div_up(j.nb_oc, nthr_oc_b) * div_up(j.nb_ic, nthr_ic_b) - * j.kh * j.kw * j.kd * j.ic_block * j.oc_block; - }; - - int best_mem_cost = calc_mem_cost(nthr_mb_, nthr_oc_b_, nthr_ic_b_); - - /* step 1: find the best thread distribution with lowest memory cost */ - const int nthr_mb_max = nstl::min(nthr, j.mb * j.od); - for (int nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) { - const int nthr_par = nthr / nthr_mb; - const int nthr_oc_b_max = nstl::min(nthr_par, j.nb_oc); - for (int nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) { - int nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, j.nb_ic); - - int mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b); - if (mem_cost <= best_mem_cost) { - best_mem_cost = mem_cost; - nthr_mb_ = nthr_mb; - nthr_oc_b_ = nthr_oc_b; - nthr_ic_b_ = nthr_ic_b; - } - } - - if (!mkldnn_thr_syncable()) { assert(nthr_mb == 1); break; } - } - - if (!mayiuse(avx512_mic)) { - auto calc_comp_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) { - return 1 - * div_up(j.mb, nthr_mb) - * div_up(j.ngroups, nthr_g_) - * div_up(j.nb_oc, nthr_oc_b) - * div_up(j.nb_ic, nthr_ic_b); - }; - - /* step 2: search for a thread distribution with lower compute cost. 
- * the constrains: - * - memory cost cannot exceed 110% of the best found in the step 1 - * - unless compute cost is 133% lower than the current best case - * note: both constants were found empirically */ - int best_comp_cost = calc_comp_cost(nthr_mb_, nthr_oc_b_, nthr_ic_b_); - for (int nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) { - const int nthr_par = nthr / nthr_mb; - const int nthr_oc_b_max = nstl::min(nthr_par, j.nb_oc); - for (int nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) { - int nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, j.nb_ic); - int mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b); - int comp_cost = calc_comp_cost(nthr_mb, nthr_oc_b, nthr_ic_b); - - const bool opt1 = comp_cost <= best_comp_cost - && mem_cost < 1.1 * best_mem_cost; - const bool opt2 = 4 * comp_cost <= 3 * best_comp_cost; - - if (opt1 || opt2) { - best_comp_cost = comp_cost; - nthr_mb_ = nthr_mb; - nthr_oc_b_ = nthr_oc_b; - nthr_ic_b_ = nthr_ic_b; - } - } - - if (!mkldnn_thr_syncable()) { assert(nthr_mb == 1); break; } - } - } - - if (nthr_mb_ > max_threads/2 && nthr_mb_ < max_threads) - nthr_mb_ = nstl::min(j.mb * j.od, max_threads); - nthr_ = nthr_mb_ * nthr_g_ * nthr_oc_b_ * nthr_ic_b_; - - assert(nthr_ <= max_threads); - assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_mb_ == 1)); -} - -template struct _jit_avx512_common_conv_fwd_kernel; -template struct _jit_avx512_common_conv_fwd_kernel; - -} -} -} - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.hpp deleted file mode 100644 index f76770797..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.hpp +++ /dev/null @@ -1,423 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
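The balance() routine above picks the (nthr_mb, nthr_oc_b, nthr_ic_b) split by exhaustively scoring every candidate with the per-thread traffic estimate from calc_mem_cost, then optionally trading a little memory for less compute in step 2. Here is a simplified standalone model of the step-1 search only: groups, depth and the stride division are dropped, and every size and coefficient below is an assumption, not a value from a real shape.

#include <algorithm>
#include <cstdio>

static int div_up(int a, int b) { return (a + b - 1) / b; }

int main() {
    // Assumed problem size and thread budget.
    const int nthr = 64, mb = 32, nb_oc = 16, nb_ic = 16;
    const int oc_block = 16, ic_block = 16, oh = 28, ow = 28, ih = 28, iw = 28;
    const int kh = 3, kw = 3;
    // Same spirit as calc_mem_cost: weight traffic is deliberately over-weighted.
    const int src_coef = 4, dst_coef = 1, wei_coef = 8;

    auto mem_cost = [&](int t_mb, int t_oc, int t_ic) {
        return src_coef * div_up(mb, t_mb) * div_up(nb_ic, t_ic) * ic_block * ih * iw
             + dst_coef * div_up(mb, t_mb) * div_up(nb_oc, t_oc) * oc_block * oh * ow
             + wei_coef * div_up(nb_oc, t_oc) * div_up(nb_ic, t_ic)
                        * kh * kw * ic_block * oc_block;
    };

    int best = mem_cost(1, 1, 1), b_mb = 1, b_oc = 1, b_ic = 1;
    for (int t_mb = 1; t_mb <= std::min(nthr, mb); ++t_mb) {
        const int nthr_par = nthr / t_mb;
        for (int t_oc = 1; t_oc <= std::min(nthr_par, nb_oc); ++t_oc) {
            const int t_ic = std::min(nthr_par / t_oc, nb_ic);
            const int cost = mem_cost(t_mb, t_oc, t_ic);
            if (cost <= best) { best = cost; b_mb = t_mb; b_oc = t_oc; b_ic = t_ic; }
        }
    }
    std::printf("nthr_mb=%d nthr_oc_b=%d nthr_ic_b=%d est_cost=%d\n",
            b_mb, b_oc, b_ic, best);
    return 0;
}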
-*******************************************************************************/ - -#ifndef JIT_AVX512_COMMON_CONV_KERNEL_F32_HPP -#define JIT_AVX512_COMMON_CONV_KERNEL_F32_HPP - -#include "c_types_map.hpp" -#include "memory_tracking.hpp" - -#include "jit_generator.hpp" -#include "jit_primitive_conf.hpp" -#include "jit_uni_eltwise.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -template -struct _jit_avx512_common_conv_fwd_kernel : public jit_generator { - - _jit_avx512_common_conv_fwd_kernel(jit_conv_conf_t ajcp, - const primitive_attr_t &attr) - : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr) - { - if (jcp.with_eltwise) - eltwise_injector_ = new jit_uni_eltwise_injector_f32( - this, jcp.eltwise); - - generate(); - jit_ker_ = (void (*)(jit_conv_call_s *))getCode(); - } - - ~_jit_avx512_common_conv_fwd_kernel() { - delete eltwise_injector_; - } - - DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_common_conv_fwd_kernel) - - jit_conv_conf_t jcp; - const primitive_attr_t &attr_; - void (*jit_ker_)(jit_conv_call_s *); - -private: - using reg64_t = const Xbyak::Reg64; - enum { - typesize = sizeof(float), - ker_reg_base_idx = 28, - }; - - reg64_t param = abi_param1; - reg64_t reg_inp = r8; - reg64_t reg_ker = r9; - reg64_t reg_out = r10; - - reg64_t reg_inp_prf = r11; - reg64_t reg_ker_prf = r12; - reg64_t reg_out_prf = r13; - reg64_t reg_owb = r12; - - reg64_t aux_reg_inp = r14; - reg64_t aux_reg_ker = r15; - - reg64_t aux_reg_inp_prf = rsi; - reg64_t aux_reg_ker_prf = rdx; - - reg64_t reg_channel = rsi; - reg64_t reg_bias = rdx; - - reg64_t aux_reg_ker_d = r9; - reg64_t aux_reg_inp_d = rbx; - reg64_t aux_reg_inp_d_prf = r13; - reg64_t aux_reg_ker_d_prf = abi_not_param1; - reg64_t reg_ki = r10; - - reg64_t reg_kj = rax; - reg64_t reg_relu_ns = rax; - reg64_t reg_oi = rbx; - reg64_t reg_kh = abi_not_param1; - - reg64_t reg_tmp = rbp; - - reg64_t reg_ic_loop = rdx; - reg64_t reg_inp_loop = rsi; - - reg64_t reg_init_flag = r13; - reg64_t reg_bias_ptr = param; - - reg64_t aux_reg_ic = r12; - reg64_t reg_binp = rax; - reg64_t reg_bout = r11; - reg64_t aux1_reg_inp = rbx; - reg64_t aux_reg_out = abi_not_param1; - - reg64_t reg_long_offt = r11; - reg64_t reg_out_long_offt = r14; - - inline Vmm vmm_ker(int i_ic) { - assert(i_ic < 4); - return Vmm(ker_reg_base_idx + i_ic); - } - - inline Vmm vmm_out(int i_ur, int i_oc) { - int idx = i_ur + i_oc * jcp.ur_w; - assert(idx < ker_reg_base_idx); - return Vmm(idx); - } - - inline Vmm vmm_inp(int i_ic, int nb_x_blocking) { - int idx = i_ic + nb_x_blocking * jcp.ur_w; - assert(idx < 31); - return Vmm(idx); - } - - Xbyak::Reg64 imm_addr64 = r15; - Vmm vmm_wei = Vmm(31); - - jit_uni_eltwise_injector_f32 *eltwise_injector_; - - inline void prepare_output(int ur_w); - inline void store_output(int ur_w); - inline void compute_loop_fma(int ur_w, int pad_l, int pad_r); - inline void compute_loop_fma_core(int ur_w, int pad_l, int pad_r); - inline void compute_loop_4fma(int ur_w, int pad_l, int pad_r); - inline void compute_loop_4fma_1st(int ur_w, int pad_l, int pad_r); - inline void compute_loop(int ur_w, int pad_l, int pad_r); - - void generate(); - - inline size_t get_output_offset(int oi, int n_oc_block) { - return (size_t)jcp.typesize_out * ((size_t)n_oc_block * jcp.oh - * jcp.ow * jcp.od + oi) * jcp.oc_block; - } - - inline size_t get_input_offset(int ki, int ic, int oi, int pad_l) { - size_t iw_str = !jcp.is_1stconv ? jcp.ic_block : 1; - size_t ic_str = !jcp.is_1stconv ? 
1 : (size_t)jcp.iw * jcp.ih * jcp.id; - return (size_t)jcp.typesize_in * ((size_t)(ki * (jcp.dilate_w + 1) - + oi * jcp.stride_w - pad_l) * iw_str + ic * ic_str); - } - - inline int get_kernel_offset(int ki,int ic,int n_oc_block,int ker_number) { - return jcp.typesize_in * jcp.oc_block - * (n_oc_block * jcp.nb_ic * jcp.ic_block * jcp.kh * jcp.kw * jcp.kd - + (ic + ker_number) + ki * jcp.ic_block); - } - - inline int get_ow_start(int ki, int pad_l) { - return nstl::max(0, - utils::div_up(pad_l - ki * (jcp.dilate_w + 1), jcp.stride_w)); - } - - inline int get_ow_end(int ur_w, int ki, int pad_r) { - return ur_w - nstl::max(0, utils::div_up(pad_r - - (jcp.kw - 1 - ki) - * (jcp.dilate_w + 1), - jcp.stride_w)); - } -}; - -struct jit_avx512_common_conv_fwd_kernel { - - jit_avx512_common_conv_fwd_kernel(jit_conv_conf_t ajcp, - const primitive_attr_t &attr) : - jit_ker(nullptr), - zmm_kernel_(nullptr), - xmm_kernel_(nullptr) { - int ch_block = ajcp.is_depthwise ? ajcp.ch_block : ajcp.oc_block; - switch (ch_block) { - case 16: - zmm_kernel_ = - new _jit_avx512_common_conv_fwd_kernel( - ajcp, attr); - jit_ker = zmm_kernel_->jit_ker_; - return; - case 4: - xmm_kernel_ = - new _jit_avx512_common_conv_fwd_kernel( - ajcp, attr); - jit_ker = xmm_kernel_->jit_ker_; - return; - default: - assert(!"invalid channel blocking"); - } - } - - ~jit_avx512_common_conv_fwd_kernel() { - delete xmm_kernel_; - delete zmm_kernel_; - } - - enum { - typesize = sizeof(float) - }; - - static bool post_ops_ok(jit_conv_conf_t &jcp, - const primitive_attr_t &attr); - static status_t init_conf(jit_conv_conf_t &jcp, - const convolution_desc_t &cd, - memory_desc_t &src_pd, - memory_desc_t &weights_pd, - memory_desc_t &dst_pd, - memory_desc_t &bias_pd, - const primitive_attr_t &attr, - int nthreads); - static void init_scratchpad(memory_tracking::registrar_t &scratchpad, - const jit_conv_conf_t &jcp); - - void(*jit_ker)(jit_conv_call_s *); - _jit_avx512_common_conv_fwd_kernel *zmm_kernel_; - _jit_avx512_common_conv_fwd_kernel *xmm_kernel_; -}; - -struct jit_avx512_common_conv_bwd_data_kernel_f32: public jit_generator { - - jit_avx512_common_conv_bwd_data_kernel_f32(jit_conv_conf_t ajcp): jcp(ajcp) - { - generate(); - jit_ker = (void (*)(jit_conv_call_s *))getCode(); - } - - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_conv_bwd_data_kernel_f32) - - static status_t init_conf(jit_conv_conf_t &jcp, - const convolution_desc_t &cd, - const memory_desc_wrapper &diff_src_d, - const memory_desc_wrapper &weights_d, - const memory_desc_wrapper &diff_dst_d); - static void init_scratchpad(memory_tracking::registrar_t &scratchpad, - const jit_conv_conf_t &jcp); - - jit_conv_conf_t jcp; - void (*jit_ker)(jit_conv_call_s *); - -private: - using reg64_t = const Xbyak::Reg64; - enum { - typesize = sizeof(float), - ker_reg_base_idx = 28, - }; - - reg64_t param = abi_param1; - reg64_t reg_dst = r8; - reg64_t reg_ker = r9; - reg64_t reg_src = r10; - - reg64_t reg_dst_prf = r11; - reg64_t reg_ker_prf = r12; - reg64_t reg_src_prf = r13; - - reg64_t aux_reg_dst = r14; - reg64_t aux_reg_ker = r15; - - reg64_t aux_reg_dst_prf = rsi; - reg64_t aux_reg_ker_prf = rdx; - - reg64_t aux_reg_dst_d_prf = r13; - reg64_t aux_reg_dst_d = rbx; - reg64_t aux_reg_ker_d_prf = abi_not_param1; - reg64_t aux_reg_ker_d = r9; - reg64_t reg_ki = r10; - - reg64_t reg_kj = rax; - reg64_t reg_oi = rbx; - reg64_t reg_kh = abi_not_param1; - - reg64_t reg_channel = rsi; - - reg64_t reg_tmp = rbp; - reg64_t reg_long_offt = r14; - - inline Xbyak::Zmm zmm_ker(int i_ic) { - 
assert(i_ic < 4); - return Xbyak::Zmm(ker_reg_base_idx + i_ic); - } - inline Xbyak::Zmm zmm_inp(int i_ic, int nb_x_blocking) { - int idx = i_ic + nb_x_blocking * jcp.ur_w; - assert(idx < 31); - return Xbyak::Zmm(idx); - } - inline Xbyak::Zmm zmm_out(int i_ur, int i_oc) { - int idx = i_ur + i_oc * jcp.ur_w; - assert(idx < ker_reg_base_idx); - return Xbyak::Zmm(idx); - } - - Xbyak::Zmm zmm_wei = Xbyak::Zmm(31); - - inline void prepare_output(int ur_w); - inline void store_output(int ur_w); - inline void compute_loop_4fma(int ur_w, int l_overflow, int r_overflow); - inline void compute_loop_fma(int ur_w, int l_overflow, int r_overflow); - inline void compute_loop_fma_core(int ur_w, int l_overflow, int r_overflow); - inline void compute_loop(int ur_w, int l_overflow, int r_overflow); - void generate(); - - inline int get_iw_start(int ki, int l_overflow) - { - int res = (jcp.iw - 1 + jcp.r_pad) % jcp.stride_w - + l_overflow * jcp.stride_w - - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1); - while (res < 0) - res += jcp.stride_w; - - return res; - } - - inline int get_iw_end(int ur_w, int ki, int r_overflow) - { - if (utils::one_of(ur_w, jcp.iw, jcp.ur_w_tail)) - ur_w += nstl::min(0, jcp.r_pad); // remove negative padding - int res = (ur_w - 1 + jcp.l_pad) % jcp.stride_w - + r_overflow * jcp.stride_w - ki * (jcp.dilate_w + 1); - while (res < 0) - res += jcp.stride_w; - - return ur_w - res; - } -}; - -struct jit_avx512_common_conv_bwd_weights_kernel_f32 : public jit_generator { - - jit_avx512_common_conv_bwd_weights_kernel_f32(jit_conv_conf_t ajcp) - : jcp(ajcp) - { - generate(); - jit_ker = (void (*)(jit_conv_call_s *))getCode(); - } - - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_conv_bwd_weights_kernel_f32) - - static status_t init_conf(jit_conv_conf_t &jcp, - const convolution_desc_t &cd, - memory_desc_t &src_md, - memory_desc_t &diff_weights_md, - memory_desc_t &diff_bias_md, - memory_desc_t &diff_dst_md); - static void init_scratchpad(memory_tracking::registrar_t &scratchpad, - const jit_conv_conf_t &jcp); - - jit_conv_conf_t jcp; - void (*jit_ker)(jit_conv_call_s *); - -private: - using reg64_t = const Xbyak::Reg64; - enum {typesize = sizeof(float)}; - static const int max_ur_w; - - reg64_t param = abi_param1; - reg64_t reg_input = rax; - reg64_t reg_kernel = rdx; - reg64_t reg_output = rsi; - reg64_t b_ic = abi_not_param1; - reg64_t kj = r8; - reg64_t reg_kh = r9; - reg64_t reg_ur_w_trips = r10; - reg64_t reg_oj = r15; - reg64_t reg_ih_count = rbx; - reg64_t reg_tmp = r14; - reg64_t reg_long_offt = r14; - - reg64_t ki = r11; - reg64_t reg_kd_count = r12; - reg64_t reg_oi = r12; - reg64_t reg_d_index = r13; - reg64_t reg_input_d = r15; - reg64_t reg_output_d = rbx; - reg64_t aux_reg_input = r12; - reg64_t aux_reg_kernel = r13; - reg64_t reg_bias = rbx; - - inline void bias_kernel(); - inline void maybe_zero_kernel(); - inline void compute_oh_step_unroll_ow_icblock(int ic_block_step, - int max_ur_w); - inline void od_step_comeback_pointers(); - inline void oh_step_comeback_pointers(); - inline void compute_oh_step_unroll_ow(int ic_block_step, int max_ur_w); - inline void compute_ic_block_step(int ur_w, - int pad_l, int pad_r, int ic_block_step, - int input_offset, int kernel_offset, int output_offset, - bool input_wraparound = false); - inline void compute_ic_block_step_fma(int ur_w, - int pad_l, int pad_r, int ic_block_step, - int input_offset, int kernel_offset, int output_offset, - bool input_wraparound); - inline void compute_ic_block_step_4fma(int ur_w, - int pad_l, int pad_r, int 
ic_block_step, - int input_offset, int kernel_offset, int output_offset, - bool input_wraparound); - inline void compute_oh_step_common(int ic_block_step, int max_ur_w); - inline void compute_oh_step_disp(); - inline void compute_oh_loop_common(); - inline void compute_d_loop_common(); - - inline bool compute_full_spat_loop(); - inline bool flat_4ops_compute(); - - inline void compute_loop(); - - void generate(); - - static void balance(const jit_conv_conf_t &j, int &nthr, int &nthr_mb, - int &nthr_g, int &nthr_oc_b, int &nthr_ic_b); -}; - -} -} -} - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.cpp deleted file mode 100644 index 1bdcd0d6a..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.cpp +++ /dev/null @@ -1,1163 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include "c_types_map.hpp" -#include "mkldnn_thread.hpp" -#include "nstl.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" -#include "cpu_memory.hpp" - -#include - -#include "jit_avx512_common_conv_winograd_kernel_f32.hpp" - -#ifndef KERNEL_SIZE_THRESHOLD -#define KERNEL_SIZE_THRESHOLD 16 -#endif - -#define MIN_REQUIRED_DIMN_REG_BLOCK 14 - -namespace mkldnn { -namespace impl { -namespace cpu { - -namespace { - -using namespace mkldnn::impl::utils; - -unsigned int L1_cache_size = get_cache_size(1, true); -unsigned int L2_cache_size = get_cache_size(2, true); -unsigned int LLC_data_size = get_cache_size(3, false); - -// the test funtion takes jcp, the candidate and the current best. 
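The helper defined just below, get_divisor_satisfying_cond, is a predicate-driven search over the divisors of a dimension: walk divisor pairs (d, number / d) up to sqrt(number) and keep whichever candidate the predicate prefers over the current best. A self-contained sketch of the same idea follows; the predicate and the numbers in main() are stand-ins in the "largest divisor within a register budget" style used by the blocking heuristics further down, not values from a real configuration.

#include <cmath>
#include <cstdio>

// Standalone analogue of the divisor search: better(cand, best) returns true
// when the candidate divisor should replace the current best.
template <class Pred>
static int best_divisor(int number, int default_best, Pred better) {
    int best = default_best;
    auto consider = [&](int d) { if (better(d, best)) best = d; };
    for (int d = 1; d <= static_cast<int>(std::sqrt(number)); ++d)
        if (number % d == 0) { consider(d); consider(number / d); }
    return best;
}

int main() {
    // Stand-in predicate: largest divisor of 3136 (an assumed tile count)
    // that still fits a budget of 14 register-blocked columns.
    const int r = best_divisor(3136, 1, [](int cand, int best) {
        return cand <= 14 && cand > best;
    });
    std::printf("chosen divisor = %d\n", r);  // prints 14
    return 0;
}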
-// it returns true if the new candidate is better -int get_divisor_satisfying_cond(jit_conv_winograd_conf_t &jcp, int number, - int default_best, bool (*test)(jit_conv_winograd_conf_t &, int, int)) -{ - int best_divisor = default_best; - auto test_num - = [&best_divisor, test](jit_conv_winograd_conf_t &jcp, int num) { - if (test(jcp, num, best_divisor)) { - best_divisor = num; - } - }; - - for (int divisor = 1; divisor <= ::sqrt(number); divisor++) { - if (number % divisor == 0) { - test_num(jcp, divisor); - test_num(jcp, number / divisor); - } - } - - return best_divisor; -} - -namespace { -bool is_winograd_faster_than_direct(const jit_conv_winograd_conf_t &jcp) { - if (jcp.ver == ver_4fma) - return jcp.mb >= 32; - else - return jcp.mb >= 16; -} -} - -/* assumes 512 bits registers */ -/* TODO: add support for strides */ -/* TODO: handle the prefetch distance automatically */ -typedef enum cache_t_ { L1, L2, L3 } cache_t; - -template -struct prefetcher_t { - prefetcher_t(jit_generator *generator, Xbyak::Reg64 reg_base_addr, - cache_t cache_type, size_t block_size, /* in number of elements*/ - int nb_instructions_in_block, int fma_ipc) - : cg_(generator) - , reg_base_addr_(reg_base_addr) - , cache_type_(cache_type) - , cache_block_size_(block_size) - { - nb_cache_lines_to_prefetch_ = cache_block_size_ / (64 / sizeof(data_t)); - prefetch_spread_ - = div_up(nb_instructions_in_block, nb_cache_lines_to_prefetch_); - prefetch_blk_ - = div_up(nb_cache_lines_to_prefetch_, nb_instructions_in_block); - - /* assumption: when fetch in Li, data is already in L(i+1) */ - int cache_latency; - switch (cache_type_) { - case L1: cache_latency = 14; break; - case L2: - case L3: - default: cache_latency = 250; break; - } - - prefetch_distance_ = div_up(cache_latency, nb_cache_lines_to_prefetch_); - } - - void prefetch(int instruction_number) - { - if (instruction_number % prefetch_spread_ == 0) { - for (int i = 0; (i < prefetch_blk_) - && (prefetches_issued_ < nb_cache_lines_to_prefetch_); - i++, prefetches_issued_++) { - prefetch_inst_(cg_->EVEX_compress_addr( - reg_base_addr_, (cache_block_size_ * prefetch_distance_) - * sizeof(data_t) - + (prefetches_issued_ * 64))); - } - } - } - -private: - void prefetch_inst_(const Xbyak::Address &addr) - { - switch (cache_type_) { - case L1: cg_->prefetcht0(addr); break; - case L2: cg_->prefetcht1(addr); break; - case L3: cg_->prefetcht2(addr); break; - default: - break; // TODO: raise an exception or put an assert - } - } - - jit_generator *cg_; - Xbyak::Reg64 reg_base_addr_; - cache_t cache_type_; - int cache_block_size_ = 0; - int nb_cache_lines_to_prefetch_ = 0; - int prefetches_issued_ = 0; - int prefetch_spread_ = 0; - int prefetch_blk_ = 0; - int prefetch_distance_ = 0; -}; - -// utilities to support kernel parameter selection -bool check_cond1(int dimN_reg_block, int dimK_block, int dimK_reg_block, - int dimM_block, int dimM_simd_block, float C) -{ - float lhs = (dimM_block * dimN_reg_block * dimM_simd_block - + dimM_block * dimK_block * dimK_reg_block - * dimM_simd_block - + dimK_block * dimN_reg_block * dimK_reg_block) - * (float)sizeof(float); - float rhs = C * L1_cache_size; - return (lhs < rhs); -} - -bool check_cond1_bis(int dimN_reg_block, int dimK_block, int dimK_reg_block, - int dimM_block, int dimM_simd_block, float C) -{ - float lhs = (dimM_block * dimK_block * dimK_reg_block * dimM_simd_block - + dimK_block * dimN_reg_block * dimK_reg_block) - * (float)sizeof(float); - float rhs = C * L1_cache_size; - return (lhs < rhs); -} - -bool check_cond2(int 
nb_dimN_reg_block, int dimN_reg_block, int dimK_nb_block, - int dimK_block, int dimK_reg_block, int dimM_block, int dimM_simd_block, - float C) -{ - float lhs = (nb_dimN_reg_block * dimM_block * dimN_reg_block * dimM_simd_block - + dimK_nb_block * dimM_block * dimK_block * dimK_reg_block - * dimM_simd_block - + nb_dimN_reg_block * dimK_nb_block * dimK_block - * dimN_reg_block * dimK_reg_block) - * (float)sizeof(float); - float rhs = C * L2_cache_size; - return (lhs < rhs); -} -} - -using namespace mkldnn::impl::format_tag; -using namespace mkldnn::impl::utils; -using namespace Xbyak; - -void _jit_avx512_common_conv_winograd_data_kernel_f32::gemm_loop_generate( - bool is_beta_zero) -{ - // const int dimK_simd_block = jcp.dimK_reg_block; - - // for (int dimM_block =0; dimM_block < jcp.dimM_block; dimM_block++) - // for (int dimK_block = 0; dimK_block < jcp.dimK_block; dimK_block++) - // for (int dimK_reg_block= 0; dimK_reg_block < jcp.dimK_reg_block; - // dimK_reg_block++) - // for (int tile =0; tile < jcp.dimN_reg_block; tile++) - // C[dimM_block][tile] += - // A[dimM_block][dimK_block][dimK_reg_block] * - // broadcast(B[dimK_block][tile][dimK_reg_block]); - // 1) We do register blocking on A[dimM_block][dimK_block][dimK_reg_block], - // so we load it before the loop on tile - // 2) the loop on tile must be fully unrolled. Don't know about the one on - // dimK_reg_block. I think it should be - - auto inner_loops = [=]() { - Label dimM_block_loop, dimK_block_loop; - const int inc_dimK_reg_block = jcp.ver == ver_4fma ? 4 : 1; - const int fma_ipc = jcp.ver == ver_4fma ? 1 : 2; - - prefetcher_t L1_pf(this, reg_srcB, L1, - jcp.dimN_reg_block * jcp.dimK_reg_block, - jcp.dimK_reg_block * jcp.dimN_reg_block / inc_dimK_reg_block, - fma_ipc); - prefetcher_t L2_pf(this, reg_srcB, L2, - jcp.dimN_reg_block * jcp.dimK_reg_block, - jcp.dimK_reg_block * jcp.dimN_reg_block / inc_dimK_reg_block, - fma_ipc); - - if (jcp.dimM_block > 1) { - mov(reg_dimM_block_loop_cnt, jcp.dimM_block); - L(dimM_block_loop); - } - { - // First, we zero the accumulators if first nb_ic iteration, - // otherwise we load them - for (int tile = 0; tile < jcp.dimN_reg_block; tile++) { - Zmm zmm(jcp.zmm_start + tile); - if (is_beta_zero) - vpxord(zmm, zmm, zmm); - else - vmovups(zmm, zword[reg_dstC + 64 * tile]); - } - - if (jcp.dimK_block > 1) { - mov(reg_dimK_block_loop_cnt, jcp.dimK_block); - L(dimK_block_loop); - } - { - auto load_A = [=](int reg_idx, int offset) { - for (int i = 0; i < inc_dimK_reg_block; i++) - vmovups(Zmm(reg_idx + i), - zword[reg_srcA + 64 * (offset + i)]); - }; - - // Used when doing double buffering - int next = 0; - if (jcp.double_buffering) { - load_A(next, 0); - } - for (int dimK_reg_block = 0; - dimK_reg_block < jcp.dimK_reg_block; - dimK_reg_block += inc_dimK_reg_block) { - int current; - /* Loading the next vector from A */ - current = next; - if (jcp.double_buffering) { - next = (dimK_reg_block + inc_dimK_reg_block) - % (2 * inc_dimK_reg_block); - load_A(next, dimK_reg_block + inc_dimK_reg_block); - } else { - next = 0; - load_A(next, dimK_reg_block); - } - /* Performing the fmas */ - for (int tile = 0; tile < jcp.dimN_reg_block; tile++) { - Zmm zmm(jcp.zmm_start + tile); - if (jcp.ver != ver_avx512_core) - L1_pf.prefetch( - dimK_reg_block * jcp.dimN_reg_block + tile); - if (jcp.ver == ver_4fma) - v4fmaddps(zmm, Zmm(current), - EVEX_compress_addr(reg_srcB, - 64 * tile + dimK_reg_block * 4)); - else - vfmadd231ps(zmm, Zmm(current), - EVEX_compress_addr(reg_srcB, - 64 * tile + dimK_reg_block * 4, - 
true)); - if (jcp.ver != ver_avx512_core) - L2_pf.prefetch( - dimK_reg_block * jcp.dimN_reg_block + tile); - } - } - - add(reg_srcA, jcp.dimK_reg_block * 64); - add(reg_srcB, jcp.dimN_reg_block * 64); - if (jcp.dimK_block > 1) { - sub(reg_dimK_block_loop_cnt, 1); - jnz(dimK_block_loop); - } - } - - - auto store_output = [=](bool output_is_aligned) { - for (int tile = 0; tile < jcp.dimN_reg_block; tile++) { - Zmm zmm(jcp.zmm_start + tile); - if (output_is_aligned - && jcp.dimK_nb_block == 1 - && (jcp.dimN * jcp.dimM * alpha * alpha - * sizeof(float) > 2 * LLC_data_size)) - vmovntps(zword[reg_dstC + 64 * tile], zmm); - else - vmovups(zword[reg_dstC + 64 * tile], zmm); - } - }; - - Label unaligned_store, end_store; - test(reg_dstC, cpu_isa_traits::vlen - 1); - jnz(unaligned_store, T_NEAR); - store_output(true); - jmp(end_store, T_NEAR); - L(unaligned_store); { - store_output(false); - } - L(end_store); - - if (jcp.dimM_block > 1) { - sub(reg_srcB, jcp.dimK_block * jcp.dimN_reg_block * 64); - add(reg_dstC, jcp.dimN_reg_block * 64); - sub(reg_dimM_block_loop_cnt, 1); - jnz(dimM_block_loop); - } - } - }; - - /* Preamble */ - preamble(); - - /* kernel */ - inner_loops(); - - /* Postamble */ - postamble(); - ret(); -} - -status_t _jit_avx512_common_conv_winograd_data_kernel_f32::init_conf_common( - jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd, - const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d, - const memory_desc_wrapper &dst_d) -{ - - if (mayiuse(avx512_core)) - return status::unimplemented; - else if (!mayiuse(avx512_common)) - return status::unimplemented; - else if (mayiuse(avx512_mic_4ops)) - jcp.ver = ver_4fma; - else - jcp.ver = ver_fma; - - jcp.nthr = mkldnn_get_max_threads(); - - const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; - - jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; - jcp.mb = src_d.dims()[0]; - jcp.oc = dst_d.dims()[1] / jcp.ngroups; - jcp.oc_without_padding = jcp.oc; - jcp.ic = src_d.dims()[1] / jcp.ngroups; - jcp.ih = src_d.dims()[2]; - jcp.iw = src_d.dims()[3]; - jcp.oh = dst_d.dims()[2]; - jcp.ow = dst_d.dims()[3]; - jcp.kh = weights_d.dims()[with_groups + 2]; - jcp.kw = weights_d.dims()[with_groups + 3]; - jcp.t_pad = cd.padding[0][0]; - jcp.l_pad = cd.padding[0][1]; - jcp.stride_h = cd.strides[0]; - jcp.stride_w = cd.strides[1]; - jcp.dilate_h = cd.dilates[0]; - jcp.dilate_w = cd.dilates[1]; - jcp.r_pad = nstl::max( - 0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad); - jcp.b_pad = nstl::max( - 0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad); - jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad; - jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad; - jcp.ohp = jcp.oh; - jcp.owp = jcp.ow; - - bool ok_to_pad_channels = jcp.ngroups == 1; - if (ok_to_pad_channels) { - jcp.oc = rnd_up(jcp.oc, simd_w); - jcp.ic = rnd_up(jcp.ic, simd_w); - } - - if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto, - is_winograd_faster_than_direct(jcp))) - return status::unimplemented; - - // Checking conditions not supported by these kernels - if (jcp.ngroups != 1) - return status::unimplemented; - if ((jcp.kh != 3) || (jcp.kw != 3)) - return status::unimplemented; - if ((jcp.dilate_h != 0) || (jcp.dilate_w != 0)) - return status::unimplemented; - if ((jcp.stride_h != 1) || (jcp.stride_w != 1)) - return status::unimplemented; - if ((jcp.ic % simd_w) != 0 || (jcp.oc % simd_w) != 0) - return status::unimplemented; - - format_tag_t dat_tag = nChw16c; - format_tag_t wei_tag = with_groups ? 
gOIhw16i16o : OIhw16i16o; - jcp.src_tag = src_d.matches_one_of_tag(dat_tag); - jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); - jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); - - if (jcp.src_tag != dat_tag) return status::unimplemented; - if (jcp.wei_tag != wei_tag) return status::unimplemented; - if (jcp.dst_tag != dat_tag) return status::unimplemented; - - bool layout_consistency = true - && jcp.ic <= src_d.padded_dims()[1] - && jcp.oc <= dst_d.padded_dims()[1] - && jcp.ic <= weights_d.padded_dims()[with_groups + 1] - && jcp.oc <= weights_d.padded_dims()[with_groups + 0]; - if (!layout_consistency) return status::unimplemented; - - return status::success; -} - - -status_t set_wsched_DATA_W_S_G_D_avx512_common(jit_conv_winograd_conf_t &jcp) { - - auto test_cond_dimN_reg_block = [](jit_conv_winograd_conf_t &jcp, - int dimN_reg_block, int current_best) { - return (dimN_reg_block >= MIN_REQUIRED_DIMN_REG_BLOCK) - && (dimN_reg_block < jcp.nb_reg) - && (dimN_reg_block < current_best); - }; - jcp.dimN_reg_block = get_divisor_satisfying_cond( - jcp, jcp.dimN, jcp.dimN, test_cond_dimN_reg_block); - - if (jcp.dimN_reg_block >= jcp.nb_reg) { - auto test_cond_dimN_reg_block = [](jit_conv_winograd_conf_t &jcp, - int dimN_reg_block, int current_best) { - return (dimN_reg_block < jcp.nb_reg) - && (dimN_reg_block > current_best); - }; - - jcp.dimN_reg_block = get_divisor_satisfying_cond( - jcp, jcp.dimN, 1, test_cond_dimN_reg_block); - } - - //********************* Choosing dimK_block **********************// - auto test_cond1_dimK_block = []( - jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) { - return check_cond1(jcp.dimN_reg_block, dimK_block, jcp.dimK_reg_block, - 1, jcp.dimM_simd_block, .75f) - && (dimK_block > current_best); - }; - - auto test_cond1_bis_dimK_block = []( - jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) { - return check_cond1_bis(jcp.dimN_reg_block, dimK_block, - jcp.dimK_reg_block, 1, jcp.dimM_simd_block, .9f) - && (dimK_block > current_best); - }; - - jcp.dimK_block = get_divisor_satisfying_cond( - jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond1_bis_dimK_block); - // If we are not able to use streams, we fall back to condition [1] - if (jcp.dimK_block < jcp.dimK / jcp.dimK_reg_block) - jcp.dimK_block = get_divisor_satisfying_cond( - jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond1_dimK_block); - jcp.dimK_nb_block = (jcp.dimK / jcp.dimK_reg_block) / jcp.dimK_block; - - //********************* Choosing dimM_block **********************// - jcp.dimM_simd_block = 16; - /*XXX: Why C=0.5 here but C=0.75 for dimK_block?*/ - auto test_cond1_dimM_block = []( - jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) { - return check_cond1(jcp.dimN_reg_block, jcp.dimK_block, - jcp.dimK_reg_block, dimM_block, jcp.dimM_simd_block, .5f) - && (dimM_block > current_best); - }; - - auto test_cond1_bis_dimM_block = []( - jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) { - return check_cond1_bis(jcp.dimN_reg_block, jcp.dimK_block, - jcp.dimK_reg_block, dimM_block, jcp.dimM_simd_block, .3f) - && (dimM_block > current_best); - }; - - if (jcp.dimK_block < jcp.dimK / jcp.dimK_reg_block) - jcp.dimM_block = get_divisor_satisfying_cond( - jcp, jcp.dimM / jcp.dimM_simd_block, 1, test_cond1_dimM_block); - else - jcp.dimM_block = get_divisor_satisfying_cond(jcp, - jcp.dimM / jcp.dimM_simd_block, 1, test_cond1_bis_dimM_block); - jcp.dimM_nb_block = (jcp.dimM / jcp.dimM_simd_block) / jcp.dimM_block; - - //******************* Choosing 
dimN_block *******************// - auto test_cond2_dimN_block = []( - jit_conv_winograd_conf_t &jcp, int dimN_block, int current_best) { - return check_cond2(dimN_block, jcp.dimN_reg_block, jcp.dimK_nb_block, - jcp.dimK_block, jcp.dimK_reg_block, jcp.dimM_block, - jcp.dimM_simd_block, .5f) - && (dimN_block > current_best); - }; - - jcp.dimN_block = get_divisor_satisfying_cond( - jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond2_dimN_block); - jcp.dimN_nb_block = jcp.dimN / (jcp.dimN_reg_block * jcp.dimN_block); - jcp.sched_policy = WSCHED_DATA_W_S_G_D; - return status::success; -} - -status_t _jit_avx512_common_conv_winograd_data_kernel_f32::init_conf_kernel( - jit_conv_winograd_conf_t &jcp, int dimM, int dimN, int dimK) -{ - jcp.dimK_reg_block = 16; - jcp.dimM_simd_block = 16; - - // TODO: replace double buffering with nuple buffering to maximize register - // usage. - // the choice of the number of buffers will then come after choosing - // dimN_reg_block - jcp.double_buffering = true; - if (jcp.double_buffering) - jcp.zmm_start = 2 * ((jcp.ver == ver_4fma) ? 4 : 2); - else - jcp.zmm_start = 1; - jcp.nb_reg = 32 - jcp.zmm_start; - - jcp.dimN = dimN; - jcp.dimK = dimK; - jcp.dimM = dimM; - - jcp.sched_policy = WSCHED_INVALID; - set_wsched_DATA_W_S_G_D_avx512_common(jcp); - - assert(jcp.sched_policy == WSCHED_DATA_W_S_G_D); - return status::success; -} - -bool jit_avx512_common_conv_winograd_fwd_kernel_f32::post_ops_ok( - jit_conv_conf_t &jcp, const primitive_attr_t &attr) { - const auto &p = attr.post_ops_; - - auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); }; - auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; - - switch (p.len_) { - case 0: return true; // no post_ops - case 1: return is_relu(0) || is_sum(0); // relu or sum - case 2: return (is_sum(0) && is_relu(1)) || - (is_relu(0) && is_sum(1)); // sum->relu or relu->sum - case 3: return is_relu(0) && is_sum(1) && is_relu(2); // relu->sum->relu - default: return false; - } - - return false; -} - -status_t jit_avx512_common_conv_winograd_fwd_kernel_f32::init_conf( - jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd, - const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d, - const memory_desc_wrapper &dst_d, const primitive_attr_t &attr) { - status_t st = init_conf_common(jcp, cd, src_d, weights_d, dst_d); - - if (st != status::success) - return st; - - // Winograd specific initialization - jcp.itiles = (jcp.ow + tile_size - 1) / tile_size; - jcp.jtiles = (jcp.oh + tile_size - 1) / tile_size; - jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles; - - jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; - - if (!post_ops_ok(jcp, attr)) - return status::unimplemented; - - const auto &p = attr.post_ops_; - const int eltwise_ind = p.find(primitive_kind::eltwise, 0, 1); - jcp.with_eltwise = eltwise_ind != -1; - if (jcp.with_eltwise) jcp.eltwise = p.entry_[eltwise_ind].eltwise; - jcp.with_sum = p.find(primitive_kind::sum, 0) != -1; - - status_t res = init_conf_kernel(jcp, jcp.oc, jcp.ntiles, jcp.ic); - jcp.ic_simd_block = jcp.dimK_reg_block; - jcp.ic_block = jcp.dimK_block; - jcp.nb_ic = jcp.dimK_nb_block; - jcp.oc_simd_block = jcp.dimM_simd_block; - jcp.oc_block = jcp.dimM_block; - jcp.nb_oc = jcp.dimM_nb_block; - jcp.tile_block_ur = jcp.dimN_reg_block; - jcp.nb_tile_block_ur = jcp.dimN_block; - jcp.tile_block = jcp.dimN_nb_block; - jcp.tile_4fma_padding = 0; // only relevant for backward weights - - return res; -} - -status_t 
jit_avx512_common_conv_winograd_bwd_data_kernel_f32::init_conf( - jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd, - const memory_desc_wrapper &diff_src_d, - const memory_desc_wrapper &weights_d, - const memory_desc_wrapper &diff_dst_d) -{ - status_t st = init_conf_common(jcp, cd, diff_src_d, weights_d, diff_dst_d); - - if (st != status::success) - return st; - - jcp.itiles = (jcp.iw + tile_size - 1) / tile_size; - jcp.jtiles = (jcp.ih + tile_size - 1) / tile_size; - jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles; - - status_t res = init_conf_kernel(jcp, jcp.ic, jcp.ntiles, jcp.oc); - jcp.oc_simd_block = jcp.dimK_reg_block; - jcp.oc_block = jcp.dimK_block; - jcp.nb_oc = jcp.dimK_nb_block; - jcp.ic_simd_block = jcp.dimM_simd_block; - jcp.ic_block = jcp.dimM_block; - jcp.nb_ic = jcp.dimM_nb_block; - jcp.tile_block_ur = jcp.dimN_reg_block; - jcp.nb_tile_block_ur = jcp.dimN_block; - jcp.tile_block = jcp.dimN_nb_block; - jcp.tile_4fma_padding = 0; // only relevant for backward weights - - return res; -} - -void jit_avx512_common_conv_winograd_bwd_weights_kernel_f32::transpose_ker_generate() -{ - auto load_B = [=](int reg_idx, int offset) { - for (int i = 0; i < 4; i++) { - vmovups(Zmm(reg_idx + i), zword[reg_origB + (offset + i) * jcp.dimN_reg_block * sizeof(float)]); - } - }; - - preamble(); - int curr = 0; - for (int j = 0; j < alpha; j++) { - for (int i = 0; i < alpha; i++) { - int origB_offset = (j * alpha + i) * jcp.dimK_4fma; - size_t transB_offset = (size_t)(j * alpha + i) * jcp.dimK_nb_block * - jcp.dimN_block * jcp.dimK_block * jcp.dimK_reg_block * - jcp.dimK_4fma * jcp.dimN_reg_block * sizeof(float); - mov(reg_transB_idx, transB_offset); - for (int tb = 0; tb < jcp.dimK_4fma; tb+=4) { - /*double buffering to hide load latencies*/ - int next = (curr + 4) % 8; - if (i == 0 && tb == 0) { - load_B(0, origB_offset); - } - if (tb + 4 < (jcp.dimK_4fma -1)) { - load_B(next, origB_offset + 4); - } else if (i < alpha - 1) { - load_B(next, origB_offset + jcp.dimK_4fma); - } - - vunpcklps(Zmm(8), Zmm(curr), Zmm(curr + 1)); - vunpcklps(Zmm(9), Zmm(curr + 2), Zmm(curr + 3)); - vunpckhps(Zmm(curr), Zmm(curr), Zmm(curr + 1)); - vunpckhps(Zmm(curr + 1), Zmm(curr + 2), Zmm(curr + 3)); - - vunpcklpd(Zmm(curr + 2), Zmm(8), Zmm(9)); - vunpckhpd(Zmm(curr + 3), Zmm(8), Zmm(9)); - - vunpcklpd(Zmm(8), Zmm(curr), Zmm(curr + 1)); - vunpckhpd(Zmm(9), Zmm(curr), Zmm(curr + 1)); - - vmovntps(zword[reg_transB + reg_transB_idx - + sizeof(float) * tb * jcp.dimN_reg_block], - Zmm(curr+2)); - vmovntps(zword[reg_transB + reg_transB_idx - + sizeof(float) * (tb + 1) * jcp.dimN_reg_block], - Zmm(curr+3)); - vmovntps(zword[reg_transB + reg_transB_idx - + sizeof(float) * (tb + 2) * jcp.dimN_reg_block], - Zmm(8)); - vmovntps(zword[reg_transB + reg_transB_idx - + sizeof(float) * (tb + 3) * jcp.dimN_reg_block], - Zmm(9)); - curr = next; - - } - } - } - postamble(); - ret(); -} -void jit_avx512_common_conv_winograd_bwd_weights_kernel_f32::gemm_loop_generate( - bool is_first_tile) -{ - // for (int ofm2 = 0; ofm2 < jcp.oc_block; ofm2++) - // for (int ifm2 = 0; ifm2 < jcp.ic_block; ifm2++) - // for (int nb_tile_block_ur = 0; nb_tile_block_ur < - // jcp.nb_tile_block_ur; nb_tile_block_ur++) - // for (int tile_block_ur = 0; tile_block_ur < - // jcp.tile_block_ur; tile_block_ur++) - // for (int ifm3 = 0; ifm3 < jcp.ic_reg_block; ++ifm3) - // U[ofm2][ifm2][ofm3][ifm3][0:oc_simd_block] += - // M[ofm2][ofm3][nb_tile_block_ur][tile_block_ur][0:oc_simd_block] - // * - // broadcast(V[ifm2][nb_tile_block_ur][ifm3][tile_block_ur]) 
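(Editorial aside; the sketch below is not part of the deleted sources.) The commented loop nest above is the contraction that gemm_loop_generate() turns into AVX-512 FMAs: for every dimN_reg_block lane, an accumulator kept in a zmm register receives M * broadcast(V) once per reduction step over the tiles. A minimal scalar sketch of that update, with the register-block dimensions (ofm3/ifm3) folded away and illustrative flat row-major layouts standing in for the real blocked tensors:

static void winograd_gemm_sketch(float *U, const float *M, const float *V,
        int oc_block, int ic_block, int ntiles, int oc_simd_block) {
    // U[oc_block][ic_block][oc_simd_block] -- accumulator (held in zmm registers by the JIT)
    // M[oc_block][ntiles][oc_simd_block]   -- operand loaded by load_A (double buffered)
    // V[ic_block][ntiles]                  -- operand read with an embedded broadcast
    for (int ofm2 = 0; ofm2 < oc_block; ++ofm2)
        for (int ifm2 = 0; ifm2 < ic_block; ++ifm2)
            for (int t = 0; t < ntiles; ++t)
                for (int s = 0; s < oc_simd_block; ++s)
                    U[(ofm2 * ic_block + ifm2) * oc_simd_block + s]
                            += M[(ofm2 * ntiles + t) * oc_simd_block + s]
                            * V[ifm2 * ntiles + t]; // one lane of vfmadd231ps / v4fmaddps

The generated kernel keeps U resident in Zmm(jcp.zmm_start) and up, shares each loaded M vector across the dimN_reg_block accumulators, and prefetches the broadcast operand stream into L1/L2, as the loop body below shows.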
- auto inner_loops = [=]() { - int inc_fma = jcp.ver == ver_4fma ? 4 : 1; - const int fma_ipc = jcp.ver == ver_4fma ? 1 : 2; - prefetcher_t L1_pf(this, reg_srcB, L1, - jcp.dimK_reg_block * jcp.dimN_reg_block * jcp.dimK_4fma, - jcp.dimK_reg_block * jcp.dimN_reg_block * jcp.dimK_4fma - / inc_fma, - fma_ipc); - prefetcher_t L2_pf(this, reg_srcB, L2, - jcp.dimK_reg_block * jcp.dimN_reg_block * jcp.dimK_4fma, - jcp.dimK_reg_block * jcp.dimN_reg_block * jcp.dimK_4fma - / inc_fma, - fma_ipc); - - auto load_A = [=](int reg_idx, int offset) { - for (int i = 0; i < inc_fma; i++) { - vmovups(Zmm(reg_idx + i), - zword[reg_srcA + - sizeof(float) * jcp.dimM_simd_block * (offset + i)]); - } - }; - - Label dimM_block_loop, dimK_block_loop, dimN_block_loop; - if (jcp.dimM_block > 1) { - mov(reg_dimM_block_loop_cnt, jcp.dimM_block); - L(dimM_block_loop); - } - { /************* OC_block (M) loop ***********/ - if (jcp.dimN_block > 1) { - mov(reg_dimN_block_loop_cnt, jcp.dimN_block); - L(dimN_block_loop); - } - { /*************** IC_block (N) loop *********/ - for (int dimN_reg_block = 0; - dimN_reg_block < jcp.dimN_reg_block; ++dimN_reg_block) { - Zmm zmm(jcp.zmm_start + dimN_reg_block); - if (is_first_tile) - vpxord(zmm, zmm, zmm); - else - vmovups(zmm, zword[reg_dstC + - dimN_reg_block * jcp.dimM_simd_block * - sizeof(float)]); - } - - if (jcp.dimK_block > 1) { - mov(reg_dimK_block_loop_cnt, jcp.dimK_block); - L(dimK_block_loop); - } - { /************* nb_tile_ur(K) loop ********/ - int next = 0; - if (jcp.double_buffering) { - load_A(next, 0); - } - for (int dimK_reg_block = 0; - dimK_reg_block < jcp.dimK_reg_block; - dimK_reg_block++) { - int srcB_offset = dimK_reg_block * jcp.dimK_4fma - * jcp.dimN_reg_block; - for (int dimK_4fma = 0; dimK_4fma < jcp.dimK_4fma; - dimK_4fma += inc_fma) { - int current = next; - if (jcp.double_buffering) { - next = (dimK_reg_block * jcp.dimK_4fma - + dimK_4fma + inc_fma) - % (2 * inc_fma); - load_A(next, dimK_reg_block * jcp.dimK_4fma - + dimK_4fma + inc_fma); - } else { - next = 0; - load_A(next, dimK_reg_block * jcp.dimK_4fma - + dimK_4fma); - } - for (int dimN_reg_block = 0; - dimN_reg_block < jcp.dimN_reg_block; - ++dimN_reg_block) { - L1_pf.prefetch(srcB_offset / inc_fma - + dimK_4fma / inc_fma - * jcp.dimN_reg_block - + dimN_reg_block); - L2_pf.prefetch(srcB_offset / inc_fma - + dimK_4fma / inc_fma - * jcp.dimN_reg_block - + dimN_reg_block); - if (jcp.ver == ver_4fma) { - int srcB_trans_offset = (dimK_4fma / 4) * 64 - + dimK_4fma % 4; - v4fmaddps( - Zmm(jcp.zmm_start + dimN_reg_block), - Zmm(current), - EVEX_compress_addr(reg_srcB, - sizeof(float) * ( - srcB_offset + - srcB_trans_offset + - (dimN_reg_block % 4) * 16 + - (dimN_reg_block / 4) * 4))); - } else { - vfmadd231ps( - Zmm(jcp.zmm_start + dimN_reg_block), - Zmm(current), - EVEX_compress_addr(reg_srcB, - sizeof(float) * (srcB_offset + dimN_reg_block), - true)); - } - } - } - } - } - - add(reg_srcA, jcp.dimK_reg_block * jcp.dimK_4fma - * jcp.dimM_simd_block * sizeof(float)); - add(reg_srcB, jcp.dimK_reg_block * jcp.dimN_reg_block - * jcp.dimK_4fma * sizeof(float)); - if (jcp.dimK_block > 1) { - sub(reg_dimK_block_loop_cnt, 1); - jnz(dimK_block_loop); - } - - /******** Write C back to memory *******/ - for (int dimN_reg_block = 0; - dimN_reg_block < jcp.dimN_reg_block; ++dimN_reg_block) { - Zmm zmm(jcp.zmm_start + dimN_reg_block); - vmovups(zword[reg_dstC + - dimN_reg_block * jcp.dimM_simd_block * sizeof(float)], - zmm); - } - - sub(reg_srcA, jcp.dimK_block * jcp.dimK_reg_block * - jcp.dimK_4fma * 
jcp.dimM_simd_block * sizeof(float)); - add(reg_dstC, jcp.dimN_reg_block * jcp.dimM_simd_block - * sizeof(float)); - if (jcp.dimN_block > 1) { - sub(reg_dimN_block_loop_cnt, 1); - jnz(dimN_block_loop); - } - } - - if (jcp.dimM_block > 1) { - sub(reg_srcB, jcp.dimN_block * jcp.dimK_block - * jcp.dimK_reg_block * jcp.dimN_reg_block - * jcp.dimK_4fma * sizeof(float)); - add(reg_srcA, jcp.dimK_block * jcp.dimK_reg_block - * jcp.dimK_4fma * jcp.dimM_simd_block * sizeof(float)); - sub(reg_dimM_block_loop_cnt, 1); - jnz(dimM_block_loop); - } - } - }; - - /* Preamble */ - // register used to handle long fma encoding - preamble(); - mov(reg_srcA, reg_srcA_const); - inner_loops(); - - /* Postamble */ - postamble(); - ret(); -} - -namespace { -bool check_cond1_wu(int dimM_block, int dimM_simdw, int dimK_block, - int dimK_reg_block, int dimK_4fma, int dimN_reg_block, float C) -{ - float lhs = 1.0f * dimM_block * dimN_reg_block * dimM_simdw; - lhs += dimM_block * dimK_block * dimK_reg_block * dimK_4fma * dimM_simdw; - lhs += dimK_block * dimN_reg_block * dimK_reg_block * dimK_4fma; - lhs *= sizeof(float); - float rhs = C * L1_cache_size; - return (lhs <= rhs); -} - -bool check_cond1bis_wu(int dimM_block, int dimM_simdw, int dimK_block, - int dimK_reg_block, int dimK_4fma, int dimN_reg_block, float C) -{ - float lhs = 1.0f * dimM_block * dimK_block * dimK_reg_block * dimK_4fma - * dimM_simdw; - lhs += dimK_block * dimN_reg_block * dimK_reg_block * dimK_4fma; - lhs *= sizeof(float); - float rhs = C * L1_cache_size; - return (lhs <= rhs); -} - -bool check_cond2bis_wu(int dimM_block, int dimM_simdw, int dimK_block, - int dimK_reg_block, int dimK_4fma, int dimN_block, int dimN_reg_block, - float C) -{ - float lhs = 1.0f * dimM_block * dimM_simdw * dimK_block * dimK_reg_block - * dimK_4fma; - lhs += dimK_block * dimK_reg_block * dimK_4fma * dimN_block - * dimN_reg_block; - lhs *= sizeof(float); - float rhs = C * L2_cache_size; - return (lhs <= rhs); -} - -bool check_cond2_wu(int dimM_block, int dimM_simdw, int dimK_block, - int dimK_reg_block, int dimK_4fma, int dimN_block, int dimN_reg_block, - float C) -{ - float lhs = 1.0f * dimM_block * dimM_simdw * dimN_block * dimN_reg_block; - lhs += dimM_block * dimM_simdw * dimK_block * dimK_reg_block * dimK_4fma; - lhs += dimK_block * dimK_reg_block * dimK_4fma * dimN_block - * dimN_reg_block; - lhs *= sizeof(float); - float rhs = C * L2_cache_size; - return (lhs <= rhs); -} -} // namespace - -status_t set_wsched_WEI_S_D_G_W_avx512_common(jit_conv_winograd_conf_t &jcp) -{ - /*************** Choose dimN_reg_block (ic_simd_block) - * *******************************/ - jcp.dimN = jcp.ic; - /*Hardcoded to 16 because N = ic for bwd weights and - innermost dimension for ic is assumed 16 in src transforms. This - choice covers load latencies while maintaining simplicity of kernel - for POR topologies. 
FIXME in future??: Will not work for future topologies - when ic%16 != 0*/ - jcp.dimN_reg_block = jcp.ic_simd_block; - - /****************************** Choose dimK_block - * **************************/ - // No freedom for choosing dimM_simd_block because ic_simd_block - // is determined by input data format - jcp.dimM_simd_block = jcp.oc_simd_block; - - auto test_cond1bis_dimK_block = []( - jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) { - return check_cond1bis_wu(1, jcp.dimM_simd_block, dimK_block, 1, - jcp.dimK_4fma, jcp.dimN_reg_block, 0.4f) - && (dimK_block > current_best); - }; - - auto test_cond1_dimK_block = []( - jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) { - return check_cond1_wu(1, jcp.dimM_simd_block, dimK_block, 1, - jcp.dimK_4fma, jcp.dimN_reg_block, 0.4f) - && (dimK_block > current_best); - }; - - auto test_cond2bis_dimK_block = []( - jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) { - return check_cond2bis_wu(1, jcp.dimM_simd_block, dimK_block, 1, - jcp.dimK_4fma, 1, jcp.dimN_reg_block, 0.5f) - && (dimK_block > current_best); - }; - - auto test_cond2_dimK_block = []( - jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) { - return check_cond2_wu(1, jcp.dimM_simd_block, dimK_block, 1, - jcp.dimK_4fma, 1, jcp.dimN_reg_block, 0.1f) - && (dimK_block > current_best); - }; - - jcp.dimK_block = get_divisor_satisfying_cond( - jcp, jcp.dimK / jcp.dimK_4fma, 1, test_cond2bis_dimK_block); - if (jcp.dimK_block < jcp.dimK / jcp.dimK_4fma) - jcp.dimK_block = get_divisor_satisfying_cond( - jcp, jcp.dimK / jcp.dimK_4fma, 1, test_cond2_dimK_block); - - jcp.dimK_reg_block = get_divisor_satisfying_cond( - jcp, jcp.dimK_block, 1, test_cond1bis_dimK_block); - if (jcp.dimK_reg_block < jcp.dimK_block) { - jcp.dimK_reg_block = get_divisor_satisfying_cond( - jcp, jcp.dimK_block, 1, test_cond1_dimK_block); - } - jcp.dimK_block /= jcp.dimK_reg_block; - jcp.dimK_nb_block - = jcp.dimK / jcp.dimK_4fma / jcp.dimK_reg_block / jcp.dimK_block; - jcp.tile_block_ur = jcp.dimK_reg_block; - jcp.nb_tile_block_ur = jcp.dimK_block; - jcp.tile_block = jcp.dimK_nb_block; - - /***************************** Chose dimN block - * ****************************/ - auto test_cond2_dimN_block = []( - jit_conv_winograd_conf_t &jcp, int dimN_block, int current_best) { - return check_cond2_wu(1, jcp.dimM_simd_block, jcp.dimK_block, - jcp.dimK_reg_block, jcp.dimK_4fma, dimN_block, - jcp.dimN_reg_block, 0.5f) - && (dimN_block > current_best); - }; - - jcp.dimN_block = get_divisor_satisfying_cond( - jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond2_dimN_block); - jcp.ic_block = jcp.dimN_block; - jcp.dimN_nb_block = jcp.dimN / jcp.dimN_reg_block / jcp.dimN_block; - jcp.nb_ic = jcp.dimN_nb_block; - - /********************************* Choose dimM block - * ************************/ - jcp.dimM = jcp.oc; - - auto test_cond1_dimM_block = []( - jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) { - return check_cond1_wu(dimM_block, jcp.dimM_simd_block, 1, - jcp.dimK_reg_block, jcp.dimK_4fma, jcp.dimN_reg_block, - 1.0f) - && (dimM_block > current_best) - && (jcp.dimM / jcp.dimM_simd_block / dimM_block) >= 2; - }; - - jcp.dimM_block = get_divisor_satisfying_cond( - jcp, jcp.dimM / jcp.dimM_simd_block, 1, test_cond1_dimM_block); - jcp.dimM_nb_block = (jcp.dimM / jcp.dimM_simd_block) / jcp.dimM_block; - - jcp.sched_policy = WSCHED_WEI_S_D_G_W; - return status::success; -} - -status_t jit_avx512_common_conv_winograd_bwd_weights_kernel_f32::init_conf( - 
jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd, - const memory_desc_wrapper &src_d, const memory_desc_wrapper &diff_dst_d, - const memory_desc_wrapper &diff_weights_d) -{ - jcp.nthr = mkldnn_get_max_threads(); - - const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1; - - jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1; - jcp.mb = src_d.dims()[0]; - jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups; - jcp.oc_without_padding = jcp.oc; - jcp.ic = src_d.dims()[1] / jcp.ngroups; - jcp.ih = src_d.dims()[2]; - jcp.iw = src_d.dims()[3]; - jcp.oh = diff_dst_d.dims()[2]; - jcp.ow = diff_dst_d.dims()[3]; - jcp.kh = diff_weights_d.dims()[with_groups + 2]; - jcp.kw = diff_weights_d.dims()[with_groups + 3]; - jcp.t_pad = cd.padding[0][0]; - jcp.l_pad = cd.padding[0][1]; - jcp.stride_h = cd.strides[0]; - jcp.stride_w = cd.strides[1]; - jcp.r_pad = nstl::max( - 0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad); - jcp.b_pad = nstl::max( - 0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad); - jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad; - jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad; - jcp.ohp = jcp.oh; - jcp.owp = jcp.ow; - jcp.with_bias = (cd.diff_bias_desc.format_kind != format_kind::undef); - jcp.dilate_h = cd.dilates[0]; - jcp.dilate_w = cd.dilates[1]; - - bool ok_to_pad_channels = jcp.ngroups == 1; - if (ok_to_pad_channels) { - jcp.oc = rnd_up(jcp.oc, simd_w); - jcp.ic = rnd_up(jcp.ic, simd_w); - } - - if (mayiuse(avx512_core)) - return status::unimplemented; - if (!mayiuse(avx512_common)) - return status::unimplemented; - else if (mayiuse(avx512_mic_4ops)) - jcp.ver = ver_4fma; - else - jcp.ver = ver_fma; - - if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto, - is_winograd_faster_than_direct(jcp))) - return status::unimplemented; - // Winograd specific initialization - jcp.itiles = (jcp.ow + tile_size - 1) / tile_size; - jcp.jtiles = (jcp.oh + tile_size - 1) / tile_size; - jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles; - - // Winograd kernel works only for 3x3 convolution with stride 1 - if (jcp.ngroups != 1) - return status::unimplemented; - if ((jcp.kh != 3) || (jcp.kw != 3)) - return status::unimplemented; - if ((jcp.dilate_h != 0) || (jcp.dilate_w != 0)) - return status::unimplemented; - if ((jcp.stride_h != 1) || (jcp.stride_w != 1)) - return status::unimplemented; - if ((jcp.ic % simd_w) != 0 || (jcp.oc % simd_w) != 0) - return status::unimplemented; - - format_tag_t dat_tag = nChw16c; - format_tag_t wei_tag = with_groups ? 
gOIhw16i16o : OIhw16i16o; - jcp.src_tag = src_d.matches_one_of_tag(dat_tag); - jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag); - jcp.dst_tag = diff_dst_d.matches_one_of_tag(dat_tag); - - if (jcp.src_tag != dat_tag) return status::unimplemented; - if (jcp.wei_tag != wei_tag) return status::unimplemented; - if (jcp.dst_tag != dat_tag) return status::unimplemented; - - bool layout_consistency = true - && jcp.ic <= src_d.padded_dims()[1] - && jcp.oc <= diff_dst_d.padded_dims()[1] - && jcp.ic <= diff_weights_d.padded_dims()[with_groups + 1] - && jcp.oc <= diff_weights_d.padded_dims()[with_groups + 0]; - if (!layout_consistency) return status::unimplemented; - - /*************************** New Kernel Parameters - * *****************************/ - jcp.ic_simd_block = simd_w; - jcp.oc_simd_block = simd_w; - jcp.dimK_4fma = 1; - jcp.tile_4fma_padding = 0; - -#define MAX_4FMA_UR 8 - if (jcp.ver == ver_4fma) { - auto test_cond_4fma = [](jit_conv_winograd_conf_t &jcp, int dimK_4fma, - int current_best) { - return (dimK_4fma % 4 == 0) && (dimK_4fma <= MAX_4FMA_UR) - && (dimK_4fma > current_best); - }; - jcp.dimK_4fma = get_divisor_satisfying_cond( - jcp, jcp.itiles * jcp.jtiles, 4, test_cond_4fma); - if (jcp.dimK_4fma == 1) - jcp.dimK_4fma = 4; - if ((jcp.itiles * jcp.jtiles) % jcp.dimK_4fma != 0) - jcp.tile_4fma_padding = jcp.dimK_4fma - - ((jcp.itiles * jcp.jtiles) % jcp.dimK_4fma); - } - - jcp.tile_4fma = jcp.dimK_4fma; - /*NOTE: When (itiles * jtiles) % dimK_4fma != 0, transpose in diff_src - * transform - * will not work correctly, this is solved by applying padding.*/ - jcp.dimK = jcp.mb * (jcp.itiles * jcp.jtiles + jcp.tile_4fma_padding); - jcp.dimN = jcp.ic; - jcp.dimM = jcp.oc; - - jcp.double_buffering = true; - if (jcp.double_buffering) - jcp.zmm_start = jcp.ver == ver_4fma ? 8 : 2; - else - jcp.zmm_start = jcp.ver == ver_4fma ? 4 : 1; - jcp.nb_reg = 32 - jcp.zmm_start; - - jcp.sched_policy = WSCHED_INVALID; - status_t res = set_wsched_WEI_S_D_G_W_avx512_common(jcp); - assert(jcp.sched_policy == WSCHED_WEI_S_D_G_W); - - jcp.tile_block_ur = jcp.dimK_reg_block; - jcp.nb_tile_block_ur = jcp.dimK_block; - jcp.tile_block = jcp.dimK_nb_block; - - jcp.ic_block = jcp.dimN_block; - jcp.nb_ic = jcp.dimN_nb_block; - - jcp.oc_block = jcp.dimM_block; - jcp.nb_oc = jcp.dimM_nb_block; - - return res; - -} -} -} -} - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.hpp deleted file mode 100644 index 6c117143f..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.hpp +++ /dev/null @@ -1,179 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#ifndef JIT_AVX512_COMMON_CONV_WINOGRAD_KERNEL_F32_HPP -#define JIT_AVX512_COMMON_CONV_WINOGRAD_KERNEL_F32_HPP - -#include "c_types_map.hpp" -#include "cpu_memory.hpp" - -#include "jit_generator.hpp" -#include "jit_primitive_conf.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -//alpha determines the output tile_size -constexpr int alpha = 6; -constexpr int tile_size = 4; -//simd length used for vectorization -constexpr int simd_w = 16; - -struct _jit_avx512_common_conv_winograd_data_kernel_f32 : public jit_generator { - _jit_avx512_common_conv_winograd_data_kernel_f32( - jit_conv_winograd_conf_t ajcp) - : jcp(ajcp) - { - //******************* First iter kernel ********************// - this->gemm_loop_generate(true); - gemm_loop_ker_first_iter - = (decltype(gemm_loop_ker_first_iter)) this->getCode(); - - //************** Subsequent iterations kernel **************// - if (jcp.dimK_nb_block > 1) { - align(); - const Xbyak::uint8 *addr = getCurr(); - this->gemm_loop_generate(false); - gemm_loop_ker = (decltype(gemm_loop_ker))addr; - } - } - - DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_common_conv_winograd_data_kernel_f32) - - static status_t init_conf_common(jit_conv_winograd_conf_t &jcp, - const convolution_desc_t &cd, const memory_desc_wrapper &src_d, - const memory_desc_wrapper &weights_d, - const memory_desc_wrapper &dst_d); - - static status_t init_conf_kernel( - jit_conv_winograd_conf_t &jcp, int dimM, int dimN, int dimK); - - jit_conv_winograd_conf_t jcp; - void (*gemm_loop_ker)(float *, const float *, const float *); - void (*gemm_loop_ker_first_iter)(float *, const float *, const float *); - -protected: - using reg64_t = const Xbyak::Reg64; - enum { typesize = sizeof(float) }; - - void gemm_loop_generate(bool is_beta_zero); - - /* registers used for GEMM */ - reg64_t reg_dstC = abi_param1; - reg64_t reg_srcA = abi_param2; - reg64_t reg_srcB = abi_param3; - - reg64_t reg_dimM_block_loop_cnt = r10; - reg64_t reg_dimK_block_loop_cnt = r11; -}; - -struct jit_avx512_common_conv_winograd_fwd_kernel_f32 - : _jit_avx512_common_conv_winograd_data_kernel_f32 { - using _jit_avx512_common_conv_winograd_data_kernel_f32:: - _jit_avx512_common_conv_winograd_data_kernel_f32; - - static bool post_ops_ok(jit_conv_conf_t &jcp, const primitive_attr_t &attr); - - static status_t init_conf(jit_conv_winograd_conf_t &jcp, - const convolution_desc_t &cd, const memory_desc_wrapper &src_d, - const memory_desc_wrapper &weights_d, - const memory_desc_wrapper &dst_d, const primitive_attr_t &attr); -}; - -struct jit_avx512_common_conv_winograd_bwd_data_kernel_f32 - : public _jit_avx512_common_conv_winograd_data_kernel_f32 { - using _jit_avx512_common_conv_winograd_data_kernel_f32:: - _jit_avx512_common_conv_winograd_data_kernel_f32; - - static status_t init_conf(jit_conv_winograd_conf_t &jcp, - const convolution_desc_t &cd, const memory_desc_wrapper &diff_src_d, - const memory_desc_wrapper &weights_d, - const memory_desc_wrapper &diff_dst_d); -}; - -struct jit_avx512_common_conv_winograd_bwd_weights_kernel_f32 - : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_common_conv_winograd_bwd_weights_kernel_f32) - - jit_avx512_common_conv_winograd_bwd_weights_kernel_f32( - jit_conv_winograd_conf_t ajcp) - : jcp(ajcp) - { - - //******************* First iter kernel ********************// - { - align(); - const Xbyak::uint8 *addr = getCurr(); - this->gemm_loop_generate(true); - gemm_loop_ker_first_iter = 
(decltype(gemm_loop_ker_first_iter))addr; - } - - if (jcp.tile_block > 1) { - align(); - const Xbyak::uint8 *addr = getCurr(); - this->gemm_loop_generate(false); - gemm_loop_ker = (decltype(gemm_loop_ker))addr; - } - - if (jcp.ver == ver_4fma) { - align(); - const Xbyak::uint8 *addr = getCurr(); - this->transpose_ker_generate(); - transpose_4fma_ker = (decltype(transpose_4fma_ker))addr; - } - } - - static status_t init_conf(jit_conv_winograd_conf_t &jcp, - const convolution_desc_t &cd, const memory_desc_wrapper &src_d, - const memory_desc_wrapper &diff_dst_d, - const memory_desc_wrapper &diff_weights_d); - - jit_conv_winograd_conf_t jcp; - void (*gemm_loop_ker)(float *, const float *, const float *); - void (*gemm_loop_ker_first_iter)(float *, const float *, const float *); - void (*transpose_4fma_ker)(float *, float *); - -private: - using reg64_t = const Xbyak::Reg64; - enum { typesize = sizeof(float) }; - - void gemm_loop_generate(bool is_first_tile); - void transpose_ker_generate(); - - reg64_t reg_origB = abi_param2; - reg64_t reg_transB = abi_param1; - - reg64_t reg_dstC = abi_param1; - reg64_t reg_srcA_const = abi_param2; - reg64_t reg_srcB = abi_param3; - - reg64_t reg_sp = rsp; - reg64_t reg_srcA = r9; - reg64_t reg_nb_ic = r10; - reg64_t reg_loop_cpt = r11; - reg64_t reg_transB_idx = r13; - - /* Registers used by new kernel */ - reg64_t reg_dimM_block_loop_cnt = r10; - reg64_t reg_dimK_block_loop_cnt = r12; - reg64_t reg_dimN_block_loop_cnt = r11; -}; -} -} -} - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution.cpp deleted file mode 100644 index abddc1922..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution.cpp +++ /dev/null @@ -1,1526 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#include "c_types_map.hpp" -#include "mkldnn_thread.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -#include "jit_avx512_common_convolution.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace mkldnn::impl::status; -using namespace mkldnn::impl::memory_tracking::names; -using namespace mkldnn::impl::utils; - -using namespace nstl; - -using jit_conv_ker_t = void (*)(jit_conv_call_s *); - -#define PIPELINE(field) \ - do { \ - p.field = p.field ## _prf; \ - p.field ## _prf = field; \ - } while (0) - -inline void jit_conv_ker_pipeline(jit_conv_ker_t ker, jit_conv_call_s &p, - const void *src, const void *dst, const void *filt, const void *bias, - int channel, int kh_padding) -{ - PIPELINE(src); - PIPELINE(dst); - PIPELINE(filt); - PIPELINE(bias); - PIPELINE(channel); - PIPELINE(kh_padding); - - if (p.src) - ker(&p); -} -// The special case for the driver with ow-parallelization (FWD) -// TODO: implement it for BWD_D and BWD_W too -inline void jit_conv_ker_pipeline_ow_thr(jit_conv_ker_t ker, jit_conv_call_s &p, - const void *src, const void *dst, const void *filt, const void *bias, - int channel, int kh_padding, int owb) -{ - PIPELINE(src); - PIPELINE(dst); - PIPELINE(filt); - PIPELINE(bias); - PIPELINE(channel); - PIPELINE(kh_padding); - PIPELINE(owb); - - if (p.src) - ker(&p); -} - -inline void jit_conv_3d_ker_pipeline(jit_conv_ker_t ker, jit_conv_call_s &p, - const void *src, const void *dst, const void *filt, const void *bias, - int channel, int kh_padding, int kd_padding) -{ - PIPELINE(src); - PIPELINE(dst); - PIPELINE(filt); - PIPELINE(bias); - PIPELINE(channel); - PIPELINE(kh_padding); - PIPELINE(kd_padding); - - if (p.src) - ker(&p); -} -// The special case for the driver with ow-parallelization (FWD) -// TODO: implement it for BWD_D and BWD_W too -inline void jit_conv_3d_ker_pipeline_ow_thr(jit_conv_ker_t ker, - jit_conv_call_s &p, const void *src, const void *dst, const void *filt, - const void *bias, int channel, int kh_padding, int kd_padding, int owb) -{ - PIPELINE(src); - PIPELINE(dst); - PIPELINE(filt); - PIPELINE(bias); - PIPELINE(channel); - PIPELINE(kh_padding); - PIPELINE(kd_padding); - PIPELINE(owb); - - if (p.src) - ker(&p); -} - -void jit_conv_3d_ker_bwd_w_pipeline(jit_conv_ker_t ker, jit_conv_call_s &p, - const void *src, const void *dst, const void *filt, const void *bias, - int channel, int d_index, int d_worksize, - int kd_padding /* kd_work_size */, size_t kd_offset) { - PIPELINE(src); - PIPELINE(dst); - PIPELINE(filt); - PIPELINE(bias); - PIPELINE(channel); - PIPELINE(kd_padding); - PIPELINE(d_worksize); - PIPELINE(d_index); - PIPELINE(kd_offset); - - if (p.src) - ker(&p); -} -#define wht_blk_off(d, g, ...) \ - (pd()->with_groups() \ - ? 
(d).blk_off((g), __VA_ARGS__) \ - : (d).blk_off(__VA_ARGS__)) - -template -void jit_avx512_common_convolution_fwd_t::prepare_padded_bias(const dst_data_t *&bias, - const memory_tracking::grantor_t &scratchpad) const { - if (!pd()->wants_padded_bias()) return; - - auto padded_bias = scratchpad.template get( - key_conv_padded_bias); - utils::array_copy(padded_bias, bias, pd()->jcp_.oc_without_padding); - utils::array_set(padded_bias + pd()->jcp_.oc_without_padding, - (dst_data_t)0, pd()->jcp_.oc - pd()->jcp_.oc_without_padding); - bias = padded_bias; -} - -template -void jit_avx512_common_convolution_fwd_t:: -execute_forward_1d(const exec_ctx_t &ctx) const { - auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); - auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); - auto bias = CTX_IN_MEM(const dst_data_t *, MKLDNN_ARG_BIAS); - auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); - - prepare_padded_bias(bias, this->scratchpad(ctx)); - - const memory_desc_wrapper src_d(pd()->src_md()); - const memory_desc_wrapper dst_d(pd()->dst_md()); - const memory_desc_wrapper weights_d(pd()->weights_md(0)); - - const auto &jcp = pd()->jcp_; - assert(jcp.nb_oc % jcp.nb_oc_blocking == 0); - - int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; - int work_amount = jcp.mb * jcp.ngroups * oc_chunks * jcp.nb_ow; - - int nthr; - if (jcp.aligned_threads) - nthr = jcp.aligned_threads; - else - nthr = mkldnn_get_max_threads(); - - parallel(nthr, [&](const int ithr, const int nthr) { - int start{0}, end{0}, start_copy; - balance211(work_amount, nthr, ithr, start, end); - start_copy = start; - - auto par_conv = jit_conv_call_s(); - size_t src_c_stride = src_d.blk_off(0, 1); - size_t wht_ic_stride = wht_blk_off(weights_d, 0, 0, 1); - - for (int icb_l2 = 0 ; icb_l2 < jcp.nb_ic; icb_l2 += jcp.nb_ic_L2) { - start = start_copy; - int n{0}, g{0}, occ{0}, owb{0}; - - if (jcp.loop_order == loop_cwgn) { - int dummy{0}; - nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, - g, jcp.ngroups, n, jcp.mb, dummy, 1); - } else if (jcp.loop_order == loop_gncw) { - int dummy{0}; - nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb, occ, - oc_chunks, owb, jcp.nb_ow, dummy, 1); - } else { - assert(!"unsupported loop order"); - } - - while (start < end) { - int ocb = occ * jcp.nb_oc_blocking; - int g_ocb = g * jcp.nb_oc + ocb; - int g_oc = g_ocb * jcp.oc_block; - int g_icb = g * jcp.nb_ic * jcp.nonblk_group_off; - - int ow_s = owb * jcp.ow_block; - int iw_s = ow_s * jcp.stride_w; - auto bias_w = bias ? 
bias + g_oc : nullptr; - auto dst_w = dst + dst_d.blk_off(n, g_ocb, ow_s); - auto src_w = src + src_d.blk_off(n, g_icb + icb_l2, iw_s); - auto wht_w = weights + wht_blk_off(weights_d, g, ocb, icb_l2); - - for (int icb = icb_l2; - icb < min(jcp.nb_ic, icb_l2 + jcp.nb_ic_L2); ++icb) { - jit_conv_ker_pipeline_ow_thr(kernel_->jit_ker, par_conv, - src_w, dst_w, wht_w, bias_w, icb, 1, owb); - - src_w += src_c_stride; - wht_w += wht_ic_stride; - } - if (jcp.loop_order == loop_cwgn) { - int dummy{0}; - nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow, - g, jcp.ngroups, n, jcp.mb, dummy, 1); - } else if (jcp.loop_order == loop_gncw) { - int dummy{0}; - nd_iterator_jump(start, end, g, jcp.ngroups, n, jcp.mb, - occ, oc_chunks, owb, jcp.nb_ow, dummy, 1); - } else { - assert(!"unsupported loop order"); - } - } - } - jit_conv_ker_pipeline_ow_thr(kernel_->jit_ker, par_conv, - src, dst, weights, bias, 0, 0, 0); - }); -} - -template -void jit_avx512_common_convolution_fwd_t:: -execute_forward_2d(const exec_ctx_t &ctx) const { - auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); - auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); - auto bias = CTX_IN_MEM(const dst_data_t *, MKLDNN_ARG_BIAS); - auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); - - prepare_padded_bias(bias, this->scratchpad(ctx)); - - const memory_desc_wrapper src_d(pd()->src_md()); - const memory_desc_wrapper dst_d(pd()->dst_md()); - const memory_desc_wrapper weights_d(pd()->weights_md(0)); - - const auto &jcp = pd()->jcp_; - assert(jcp.nb_oc % jcp.nb_oc_blocking == 0); - - int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; - int work_amount = jcp.mb * jcp.ngroups * oc_chunks * jcp.oh * jcp.nb_ow; - - int nthr; - if (jcp.aligned_threads) - nthr = jcp.aligned_threads; - else - nthr = mkldnn_get_max_threads(); - - parallel(nthr, [&](const int ithr, const int nthr) { - int start{0}, end{0}, start_copy; - balance211(work_amount, nthr, ithr, start, end); - start_copy = start; - - auto par_conv = jit_conv_call_s(); - size_t src_h_stride = src_d.blk_off(0, 0, 1); - size_t src_c_stride = src_d.blk_off(0, 1); - size_t dst_h_stride = dst_d.blk_off(0, 0, 1); - size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1); - size_t wht_ic_stride = wht_blk_off(weights_d, 0, 0, 1); - - for (int icb_l2 = 0 ; icb_l2 < jcp.nb_ic; icb_l2 += jcp.nb_ic_L2) { - start = start_copy; - int n{0}, g{0}, occ{0}, oh_s{0}, owb{0}; - - if (jcp.loop_order == loop_cwgn) - nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, - g, jcp.ngroups, n, jcp.mb, oh_s, jcp.oh); - else if (jcp.loop_order == loop_gncw) - nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb, - occ, oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh); - else - assert(!"unsupported loop order"); - - while (start < end) { - int ocb = occ * jcp.nb_oc_blocking; - int g_ocb = g * jcp.nb_oc + ocb; - int g_oc = g_ocb * jcp.oc_block; - int g_icb = g * jcp.nb_ic * jcp.nonblk_group_off; - - int work_rem = end - start; - - int ow_s = owb * jcp.ow_block; - int iw_s = ow_s * jcp.stride_w; - int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem; - auto bias_w = bias ? 
bias + g_oc : nullptr; - - for (int oh_b = oh_s; oh_b < oh_e; oh_b += jcp.h_blocking) { - int ih_b = -jcp.t_pad + oh_b * jcp.stride_h; - - auto dst_w = dst + dst_d.blk_off(n, g_ocb, oh_b, ow_s); - auto src_w - = src + src_d.blk_off(n, g_icb + icb_l2, ih_b, iw_s); - auto wht_w - = weights + wht_blk_off(weights_d, g, ocb, icb_l2); - - for (int icb = icb_l2; - icb < min(jcp.nb_ic, icb_l2 + jcp.nb_ic_L2); - ++icb) { - auto src_c = src_w; - auto dst_c = dst_w; - for (int oj = oh_b, ij = ih_b; - oj < min(oh_e, oh_b + jcp.h_blocking); - ++oj, ij += jcp.stride_h) { - int dilate_h = jcp.dilate_h + 1; - int i_t_overflow = div_up(max(0, -ij), dilate_h); - int i_b_overflow = div_up(max(0, ij - jcp.ih - + (jcp.kh - 1) * dilate_h + 1), dilate_h); - int kh_padding = nstl::max( - 0, jcp.kh - i_t_overflow - i_b_overflow); - - auto aux_src = src_c - + i_t_overflow * dilate_h * src_h_stride; - auto aux_wht = wht_w + i_t_overflow * wht_h_stride; - - jit_conv_ker_pipeline_ow_thr(kernel_->jit_ker, - par_conv, aux_src, dst_c, aux_wht, bias_w, icb, - kh_padding, owb); - - src_c += src_h_stride * jcp.stride_h; - dst_c += dst_h_stride; - } - src_w += src_c_stride; - wht_w += wht_ic_stride; - } - } - - if (jcp.loop_order == loop_cwgn) - nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow, - g, jcp.ngroups, n, jcp.mb, oh_s, jcp.oh); - else if (jcp.loop_order == loop_gncw) - nd_iterator_jump(start, end, g, jcp.ngroups, n, jcp.mb, occ, - oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh); - else - assert(!"unsupported loop order"); - } - } - - jit_conv_ker_pipeline_ow_thr(kernel_->jit_ker, par_conv, - src, dst, weights, bias, 0, 0, 0); - }); -} - -template -void jit_avx512_common_convolution_fwd_t:: -execute_forward_3d(const exec_ctx_t &ctx) const { - auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); - auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); - auto bias = CTX_IN_MEM(const dst_data_t *, MKLDNN_ARG_BIAS); - auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); - - prepare_padded_bias(bias, this->scratchpad(ctx)); - - const memory_desc_wrapper src_d(pd()->src_md()); - const memory_desc_wrapper dst_d(pd()->dst_md()); - const memory_desc_wrapper weights_d(pd()->weights_md(0)); - const memory_desc_wrapper bias_d(pd()->weights_md(1)); - - const auto &jcp = pd()->jcp_; - assert(jcp.nb_oc % jcp.nb_oc_blocking == 0); - - parallel(0, [&](const int ithr, const int nthr) { - int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; - int start{0}, end{0}, start_copy; - int work_amount = jcp.mb * jcp.ngroups * oc_chunks * jcp.od * jcp.oh - * jcp.nb_ow; - balance211(work_amount, nthr, ithr, start, end); - start_copy = start; - - auto par_conv = jit_conv_call_s(); - size_t src_d_stride = src_d.blk_off(0, 0, 1); - size_t src_h_stride = src_d.blk_off(0, 0, 0, 1); - size_t src_c_stride = src_d.blk_off(0, 1); - size_t dst_h_stride = dst_d.blk_off(0, 0, 0, 1); - size_t wht_d_stride = wht_blk_off(weights_d, 0, 0, 0, 1); - size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 0, 1); - size_t wht_ic_stride = wht_blk_off(weights_d, 0, 0, 1); - - for (int icb_l2 = 0 ; icb_l2 < jcp.nb_ic; icb_l2 += jcp.nb_ic_L2) { - start = start_copy; - int n{0}, g{0}, occ{0}, oh_s{0}, od_s{0}, owb{0}; - - if (jcp.loop_order == loop_cwgn) - nd_iterator_init(start, - occ, oc_chunks, owb, jcp.nb_ow, g, jcp.ngroups, n, jcp.mb, - od_s, jcp.od, oh_s, jcp.oh); - else if (jcp.loop_order == loop_gncw) - nd_iterator_init(start, - g, jcp.ngroups, n, jcp.mb, occ, oc_chunks, owb, jcp.nb_ow, - od_s, jcp.od, oh_s, jcp.oh); - else - 
assert(!"unsupported loop order"); - - while (start < end) { - int ocb = occ * jcp.nb_oc_blocking; - int g_ocb = g * jcp.nb_oc + ocb; - int g_oc = g_ocb * jcp.oc_block; - int g_icb = g * jcp.nb_ic * jcp.nonblk_group_off; - - int work_rem = end - start; - int ih_s = -jcp.t_pad + oh_s * jcp.stride_h; - int ow_s = owb * jcp.ow_block; - int iw_s = ow_s * jcp.stride_w; - int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem; - - int id_s = -jcp.f_pad + od_s * jcp.stride_d; - - int dilate_d = jcp.dilate_d + 1; - int d_t_overflow = div_up(max(0, -id_s), dilate_d); - int d_b_overflow = div_up( - max(0, id_s - jcp.id + (jcp.kd - 1) * dilate_d + 1), - dilate_d); - int kd_padding = nstl::max(0, - jcp.kd - d_t_overflow - d_b_overflow); - - auto bias_w = bias ? bias + bias_d.blk_off(g_oc) : 0; - auto dst_w = dst + dst_d.blk_off(n, g_ocb, od_s, oh_s, ow_s); - auto src_w = src + src_d.blk_off(n, g_icb + icb_l2, id_s, ih_s, - iw_s) + d_t_overflow * dilate_d * src_d_stride; - auto wht_w = weights + wht_blk_off(weights_d, g, ocb, icb_l2) - + d_t_overflow * wht_d_stride; - - for (int icb = icb_l2; - icb < min(jcp.nb_ic, icb_l2 + jcp.nb_ic_L2); ++icb) { - auto src_c = src_w; - auto dst_c = dst_w; - for (int oj = oh_s, ij = ih_s; - oj < oh_e; ++oj, ij += jcp.stride_h) - { - int dilate_h = jcp.dilate_h + 1; - int i_t_overflow = div_up(max(0, -ij), dilate_h); - int i_b_overflow = div_up( - max(0, ij - jcp.ih + (jcp.kh - 1) * dilate_h - + 1), - dilate_h); - int kh_padding = nstl::max(0, - jcp.kh - i_t_overflow - i_b_overflow); - jit_conv_3d_ker_pipeline_ow_thr(kernel_->jit_ker, - par_conv, - src_c + i_t_overflow * dilate_h * src_h_stride, - dst_c, wht_w + i_t_overflow * wht_h_stride, - bias_w, icb, kh_padding, kd_padding, owb); - - src_c += src_h_stride * jcp.stride_h; - dst_c += dst_h_stride; - } - src_w += src_c_stride; - wht_w += wht_ic_stride; - } - - if (jcp.loop_order == loop_cwgn) - nd_iterator_jump(start, end, - occ, oc_chunks, owb, jcp.nb_ow, g, jcp.ngroups, n, jcp.mb, - od_s, jcp.od, oh_s, jcp.oh); - else if (jcp.loop_order == loop_gncw) - nd_iterator_jump(start, end, - g, jcp.ngroups, n, jcp.mb, occ, oc_chunks, owb, jcp.nb_ow, - od_s, jcp.od, oh_s, jcp.oh); - else - assert(!"unsupported loop order"); - } - } - jit_conv_3d_ker_pipeline(kernel_->jit_ker, par_conv, - src, dst, weights, bias, 0, 0, 0); - }); -} - -template struct jit_avx512_common_convolution_fwd_t; - -template -void jit_avx512_common_convolution_bwd_data_t::execute_backward_data_1d(const exec_ctx_t &ctx) const -{ - auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST); - auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); - auto diff_src = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC); - - const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); - const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); - const memory_desc_wrapper weights_d(pd()->weights_md(0)); - - const auto &jcp = kernel_->jcp; - - parallel(0, [&](const int ithr, const int nthr) { - int start{0}, end{0}, start_copy; - int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking; - int work_amount = jcp.ngroups * jcp.mb * ic_chunks * jcp.ih; - balance211(work_amount, nthr, ithr, start, end); - start_copy = start; - - auto par_conv = jit_conv_call_s(); - size_t diff_dst_c_stride = diff_dst_d.blk_off(0, 1); - size_t wht_oc_stride = wht_blk_off(weights_d, 0, 1); - - for (int ocb_l2 = 0; ocb_l2 < jcp.nb_oc; ocb_l2 += jcp.nb_oc_L2) { - start = start_copy; - int n{0}, g{0}, icc{0}; - if (jcp.loop_order == loop_cgn) { - int dummy{0}; - 
nd_iterator_init(start, icc, ic_chunks, g, jcp.ngroups, n, - jcp.mb, dummy, 1); - } else if (jcp.loop_order == loop_gnc) { - int dummy{0}; - nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb, icc, - ic_chunks, dummy, 1); - } else { - assert(!"unsupported loop order"); - } - - while (start < end) { - int icb = icc * jcp.nb_ic_blocking; - int g_icb = g * jcp.nb_ic + icb; - int g_ocb = g * jcp.nb_oc; - - auto diff_src_w = diff_src + diff_src_d.blk_off(n, g_icb); - auto diff_dst_w = diff_dst - + diff_dst_d.blk_off(n, g_ocb + ocb_l2); - auto wht_w = weights + wht_blk_off(weights_d, g, ocb_l2, icb); - - for (int ocb = ocb_l2; - ocb < min(jcp.nb_oc, ocb_l2 + jcp.nb_oc_L2); ++ocb) { - jit_conv_ker_pipeline(kernel_->jit_ker, par_conv, - diff_src_w, diff_dst_w, wht_w, 0, ocb, 1); - diff_dst_w += diff_dst_c_stride; - wht_w += wht_oc_stride; - } - - if (jcp.loop_order == loop_cgn) { - int dummy{0}; - nd_iterator_jump(start, end, icc, ic_chunks, g, jcp.ngroups, - n, jcp.mb, dummy, 1); - } else if (jcp.loop_order == loop_gnc) { - int dummy{0}; - nd_iterator_jump(start, end, g, jcp.ngroups, n, jcp.mb, icc, - ic_chunks, dummy, 1); - } else { - assert(!"unsupported loop order"); - } - } - } - - jit_conv_ker_pipeline(kernel_->jit_ker, par_conv, - diff_src, diff_dst, weights, 0, 0, 1); - }); -} - -template -void jit_avx512_common_convolution_bwd_data_t::execute_backward_data_2d(const exec_ctx_t &ctx) const -{ - auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST); - auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); - auto diff_src = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC); - - const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); - const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); - const memory_desc_wrapper weights_d(pd()->weights_md(0)); - - const auto &jcp = kernel_->jcp; - - parallel(0, [&](const int ithr, const int nthr) { - int start{0}, end{0}, start_copy; - int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking; - int work_amount = jcp.ngroups * jcp.mb * ic_chunks * jcp.ih; - balance211(work_amount, nthr, ithr, start, end); - start_copy = start; - - auto par_conv = jit_conv_call_s(); - size_t diff_src_h_stride = diff_src_d.blk_off(0, 0, 1); - size_t diff_dst_h_stride = diff_dst_d.blk_off(0, 0, 1); - size_t diff_dst_c_stride = diff_dst_d.blk_off(0, 1); - size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1); - size_t wht_oc_stride = wht_blk_off(weights_d, 0, 1); - - bool is_fast_path = jcp.dilate_h == 0 && jcp.stride_h == 1; - - for (int ocb_l2 = 0; ocb_l2 < jcp.nb_oc; ocb_l2 += jcp.nb_oc_L2) { - start = start_copy; - int n{0}, g{0}, icc{0}, ih_s{0}; - if (jcp.loop_order == loop_cgn) - nd_iterator_init(start, - icc, ic_chunks, g, jcp.ngroups, n, jcp.mb, ih_s, jcp.ih); - else if (jcp.loop_order == loop_gnc) - nd_iterator_init(start, - g, jcp.ngroups, n, jcp.mb, icc, ic_chunks, ih_s, jcp.ih); - else - assert(!"unsupported loop order"); - - while (start < end) { - int icb = icc * jcp.nb_ic_blocking; - int g_icb = g * jcp.nb_ic + icb; - int g_ocb = g * jcp.nb_oc; - - int work_rem = end - start; - int ih_e = ih_s + work_rem > jcp.ih ? 
jcp.ih : ih_s + work_rem; - - auto diff_src_w = diff_src + diff_src_d.blk_off(n, g_icb); - auto diff_dst_w = diff_dst - + diff_dst_d.blk_off(n, g_ocb + ocb_l2); - auto wht_w = weights + wht_blk_off(weights_d, g, ocb_l2, icb); - - for (int ocb = ocb_l2; - ocb < min(jcp.nb_oc, ocb_l2 + jcp.nb_oc_L2); ++ocb) { - for (int ij = ih_s; ij < ih_e; ++ij) { - int oj, k_len, k_lo; - if (is_fast_path) { // dilate == 0 && stride == 1 - int i_t_overflow = max(0, jcp.kh - 1 - ij - - jcp.t_pad); - int i_b_overflow = max(0, jcp.kh - jcp.ih + ij - - jcp.b_pad); - k_len = jcp.kh - i_t_overflow - i_b_overflow; - k_lo = i_b_overflow; - oj = ij + jcp.t_pad - i_b_overflow; - } else if (jcp.dilate_h != 0) { // stride == 1 - int dilate_h = jcp.dilate_h + 1; - // Note: use div_up to account for "holes" in filter - int i_t_overflow - = div_up(max(0, (jcp.kh - 1) * dilate_h - - ij - jcp.t_pad), dilate_h); - int i_b_overflow - = div_up(max(0, (jcp.kh - 1) * dilate_h + 1 - - jcp.ih + ij - jcp.b_pad), dilate_h); - k_len = jcp.kh - i_t_overflow - i_b_overflow; - k_lo = i_b_overflow; - oj = ij + jcp.t_pad - i_b_overflow * dilate_h; - } else { // dilate == 0 - int i_t_overflow = max(0, (jcp.kh - 1 - ij - - jcp.t_pad) / jcp.stride_h); - int i_b_overflow = max(0, (jcp.kh - jcp.ih + ij - - jcp.b_pad) / jcp.stride_h); - int overflow_kh_hi = jcp.kh - 1 - abs((jcp.ih - 1 - + jcp.b_pad - ij) % jcp.stride_h); - int overflow_kh_lo = (ij + jcp.t_pad) - % jcp.stride_h; - - k_len = (overflow_kh_hi - overflow_kh_lo) - / jcp.stride_h + 1 - i_t_overflow - - i_b_overflow; - k_lo = overflow_kh_lo + i_b_overflow * jcp.stride_h; - oj = (ij + jcp.t_pad - k_lo) / jcp.stride_h; - } - assert(k_len >= 0); - - jit_conv_ker_pipeline(kernel_->jit_ker, par_conv, - diff_src_w + ij * diff_src_h_stride, - diff_dst_w + oj * diff_dst_h_stride, - wht_w + k_lo * wht_h_stride, - 0, ocb, k_len); - } - diff_dst_w += diff_dst_c_stride; - wht_w += wht_oc_stride; - } - - if (jcp.loop_order == loop_cgn) - nd_iterator_jump(start, end, - icc, ic_chunks, g, jcp.ngroups, n, jcp.mb, ih_s, jcp.ih); - else if (jcp.loop_order == loop_gnc) - nd_iterator_jump(start, end, - g, jcp.ngroups, n, jcp.mb, icc, ic_chunks, ih_s, jcp.ih); - else - assert(!"unsupported loop order"); - } - } - - jit_conv_ker_pipeline(kernel_->jit_ker, par_conv, - diff_src, diff_dst, weights, 0, 0, 1); - }); -} - -template -void jit_avx512_common_convolution_bwd_data_t::execute_backward_data_3d(const exec_ctx_t &ctx) const -{ - auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST); - auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); - auto diff_src = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC); - - const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); - const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); - const memory_desc_wrapper weights_d(pd()->weights_md(0)); - - const auto &jcp = kernel_->jcp; - - parallel(0, [&](const int ithr, const int nthr) { - int start{0}, end{0}, start_copy; - int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking; - int work_amount = jcp.ngroups * jcp.mb * ic_chunks * jcp.id * jcp.ih; - balance211(work_amount, nthr, ithr, start, end); - start_copy = start; - - auto par_conv = jit_conv_call_s(); - size_t diff_src_h_stride = diff_src_d.blk_off(0, 0, 0, 1); - size_t diff_src_d_stride = diff_src_d.blk_off(0, 0, 1); - size_t diff_dst_h_stride = diff_dst_d.blk_off(0, 0, 0, 1); - size_t diff_dst_d_stride = diff_dst_d.blk_off(0, 0, 1); - size_t diff_dst_c_stride = diff_dst_d.blk_off(0, 1); - size_t wht_h_stride = 
wht_blk_off(weights_d, 0, 0, 0, 0, 1); - size_t wht_d_stride = wht_blk_off(weights_d, 0, 0, 0, 1); - size_t wht_oc_stride = wht_blk_off(weights_d, 0, 1); - - bool is_fast_path_d = jcp.dilate_d == 0 && jcp.stride_d == 1; - bool is_fast_path_h = jcp.dilate_h == 0 && jcp.stride_h == 1; - - for (int ocb_l2 = 0; ocb_l2 < jcp.nb_oc; ocb_l2 += jcp.nb_oc_L2) { - start = start_copy; - int n{0}, g{0}, icc{0}, ih_s{0}, id_s{0}; - if (jcp.loop_order == loop_cgn) - nd_iterator_init(start, - icc, ic_chunks, g, jcp.ngroups, n, jcp.mb, id_s, jcp.id, - ih_s, jcp.ih); - else if (jcp.loop_order == loop_gnc) - nd_iterator_init(start, - g, jcp.ngroups, n, jcp.mb, icc, ic_chunks, id_s, jcp.id, - ih_s, jcp.ih); - else - assert(!"unsupported loop order"); - - while (start < end) { - int icb = icc * jcp.nb_ic_blocking; - int g_icb = g * jcp.nb_ic + icb; - int g_ocb = g * jcp.nb_oc; - - int work_rem = end - start; - int ih_e = ih_s + work_rem > jcp.ih ? jcp.ih : ih_s + work_rem; - int d_len = 0, d_lo = 0, d_oj = 0; - if (is_fast_path_d) { // dilate == 0 && stride == 1 - int d_t_overflow = max(0, jcp.kd - 1 - id_s - - jcp.f_pad); - int d_b_overflow = max(0, jcp.kd - jcp.id + id_s - - jcp.back_pad); - d_len = jcp.kd - d_t_overflow - d_b_overflow; - d_lo = d_b_overflow; - d_oj = id_s + jcp.f_pad - d_b_overflow; - } else if (jcp.dilate_d != 0) { // stride == 1 - int dilate_d = jcp.dilate_d + 1; - // Note: use div_up to account for "holes" in filter - int d_t_overflow = div_up(max(0, (jcp.kd - 1) * dilate_d - - id_s - jcp.f_pad), dilate_d); - int d_b_overflow = div_up(max(0, (jcp.kd - 1) * dilate_d + 1 - - jcp.id + id_s - jcp.back_pad), dilate_d); - d_len = jcp.kd - d_t_overflow - d_b_overflow; - d_lo = d_b_overflow; - d_oj = id_s + jcp.f_pad - d_b_overflow * dilate_d; - } else { // dilate == 0 - int d_t_overflow = max(0, (jcp.kd - 1 - id_s - - jcp.f_pad) / jcp.stride_d); - int d_b_overflow = max(0, (jcp.kd - jcp.id + id_s - - jcp.back_pad) / jcp.stride_d); - int overflow_kd_hi = jcp.kd - 1 - abs((jcp.id - 1 - + jcp.back_pad - id_s) % jcp.stride_d); - int overflow_kd_lo = (id_s + jcp.f_pad) - % jcp.stride_d; - - d_len = (overflow_kd_hi - overflow_kd_lo) - / jcp.stride_d + 1 - d_t_overflow - - d_b_overflow; - d_lo = overflow_kd_lo + d_b_overflow * jcp.stride_d; - d_oj = (id_s + jcp.f_pad - d_lo) / jcp.stride_d; - } - assert(d_len >= 0); - - auto diff_src_w = diff_src + diff_src_d.blk_off(n, g_icb) - + id_s * diff_src_d_stride; - auto diff_dst_w = diff_dst - + diff_dst_d.blk_off(n, g_ocb + ocb_l2) - + d_oj * diff_dst_d_stride; - auto wht_w = weights + wht_blk_off(weights_d, g, ocb_l2, icb) - + d_lo * wht_d_stride; - - for (int ocb = ocb_l2; - ocb < min(jcp.nb_oc, ocb_l2 + jcp.nb_oc_L2); ++ocb) { - for (int ij = ih_s; ij < ih_e; ++ij) { - int oj, k_len, k_lo; - if (is_fast_path_h) { // dilate == 0 && stride == 1 - int i_t_overflow = max(0, jcp.kh - 1 - ij - - jcp.t_pad); - int i_b_overflow = max(0, jcp.kh - jcp.ih + ij - - jcp.b_pad); - k_len = jcp.kh - i_t_overflow - i_b_overflow; - k_lo = i_b_overflow; - oj = ij + jcp.t_pad - i_b_overflow; - } else if (jcp.dilate_h != 0) { // stride == 1 - int dilate_h = jcp.dilate_h + 1; - // Note: use div_up to account for "holes" in filter - int i_t_overflow - = div_up(max(0, (jcp.kh - 1) * dilate_h - - ij - jcp.t_pad), dilate_h); - int i_b_overflow - = div_up(max(0, (jcp.kh - 1) * dilate_h + 1 - - jcp.ih + ij - jcp.b_pad), dilate_h); - k_len = jcp.kh - i_t_overflow - i_b_overflow; - k_lo = i_b_overflow; - oj = ij + jcp.t_pad - i_b_overflow * dilate_h; - } else { // dilate == 0 - int 
i_t_overflow = max(0, (jcp.kh - 1 - ij - - jcp.t_pad) / jcp.stride_h); - int i_b_overflow = max(0, (jcp.kh - jcp.ih + ij - - jcp.b_pad) / jcp.stride_h); - int overflow_kh_hi = jcp.kh - 1 - abs((jcp.ih - 1 - + jcp.b_pad - ij) % jcp.stride_h); - int overflow_kh_lo = (ij + jcp.t_pad) - % jcp.stride_h; - - k_len = (overflow_kh_hi - overflow_kh_lo) - / jcp.stride_h + 1 - i_t_overflow - - i_b_overflow; - k_lo = overflow_kh_lo + i_b_overflow * jcp.stride_h; - oj = (ij + jcp.t_pad - k_lo) / jcp.stride_h; - } - assert(k_len >= 0); - - jit_conv_3d_ker_pipeline(kernel_->jit_ker, par_conv, - diff_src_w + ij * diff_src_h_stride, - diff_dst_w + oj * diff_dst_h_stride, - wht_w + k_lo * wht_h_stride, - 0, ocb, k_len, d_len); - } - diff_dst_w += diff_dst_c_stride; - wht_w += wht_oc_stride; - } - - if (jcp.loop_order == loop_cgn) - nd_iterator_jump(start, end, - icc, ic_chunks, g, jcp.ngroups, n, jcp.mb, id_s, jcp.id, - ih_s, jcp.ih); - else if (jcp.loop_order == loop_gnc) - nd_iterator_jump(start, end, - g, jcp.ngroups, n, jcp.mb, icc, ic_chunks, id_s, jcp.id, - ih_s, jcp.ih); - else - assert(!"unsupported loop order"); - } - } - - jit_conv_3d_ker_pipeline(kernel_->jit_ker, par_conv, - diff_src, diff_dst, weights, 0, 0, 1, 1); - }); -} - -template struct jit_avx512_common_convolution_bwd_data_t; - -template -jit_avx512_common_convolution_bwd_weights_t:: -jit_avx512_common_convolution_bwd_weights_t(const pd_t *apd) - : cpu_primitive_t(apd), kernel_(nullptr) - , trans_kernel_(nullptr), acc_ker_(nullptr), reducer_bias_(nullptr) -{ - const auto &j = pd()->jcp_; - - nthr_ = j.nthr; - nthr_mb_ = j.nthr_mb; - nthr_g_ = j.nthr_g; - nthr_oc_b_ = j.nthr_oc_b; - nthr_ic_b_ = j.nthr_ic_b; - - kernel_ = new jit_avx512_common_conv_bwd_weights_kernel_f32(j); - - if (j.ver == ver_4fma) - trans_kernel_ = create_trans_src(&j); - - if (nthr_mb_ > 1) - acc_ker_ = new cpu_accumulator_1d_t(); - - reducer_bias_ = - new cpu_reducer_t(pd()->reducer_bia_conf_); -} - -template -struct jit_avx512_common_convolution_bwd_weights_t::thread_info_t { - const src_data_t *src; - const diff_dst_data_t *diff_dst; - const diff_weights_data_t *diff_weights; - diff_weights_data_t *diff_bias; - - const memory_tracking::grantor_t scratchpad; - - src_data_t *tr_src; - simple_barrier::ctx_t *tr_src_bctx; - - diff_dst_data_t *tr_diff_dst; - simple_barrier::ctx_t *tr_diff_dst_bctx; - - diff_weights_data_t *wei_bia_reduction; - simple_barrier::ctx_t *wei_bia_reduction_bctx; - - int ithr; - int ithr_ic_b, ithr_oc_b, ithr_g, ithr_mb; - int ithr_but_oc; - int ithr_but_ic; - - int img_start = 0, img_end = 0, img_work; - int g_start = 0, g_end = 0, g_work; - int oc_b_start = 0, oc_b_end = 0, oc_b_work; - int ic_b_start = 0, ic_b_end = 0, ic_b_work; - - thread_info_t(const jit_avx512_common_convolution_bwd_weights_t *self, - const exec_ctx_t &ctx, int ithr) - : scratchpad(self->scratchpad(ctx)), ithr(ithr) - { - diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST); - src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); - diff_weights = CTX_OUT_MEM(diff_weights_data_t *, MKLDNN_ARG_DIFF_WEIGHTS); - diff_bias = self->pd()->wants_padded_bias() - ? 
scratchpad.template get( - key_conv_padded_bias) - : CTX_OUT_MEM(diff_weights_data_t *, MKLDNN_ARG_DIFF_BIAS); - - tr_src = scratchpad.template get(key_conv_tr_src); - tr_src_bctx = scratchpad.template get( - key_conv_tr_src_bctx); - - tr_diff_dst = scratchpad.template get( - key_conv_tr_diff_dst); - tr_diff_dst_bctx = scratchpad.template get( - key_conv_tr_diff_dst_bctx); - - wei_bia_reduction = scratchpad.template get( - key_conv_wei_bia_reduction); - wei_bia_reduction_bctx = scratchpad.template get( - key_conv_wei_bia_reduction_bctx); - - ithr_ic_b = ithr % self->nthr_ic_b_; - ithr_oc_b = ithr / self->nthr_ic_b_ % self->nthr_oc_b_; - ithr_g = ithr / self->nthr_ic_b_ / self->nthr_oc_b_ % self->nthr_g_; - ithr_mb = ithr / self->nthr_ic_b_ / self->nthr_oc_b_ / self->nthr_g_; - - ithr_but_oc = (ithr_mb * self->nthr_g_ + ithr_g) * self->nthr_ic_b_ - + ithr_ic_b; - - ithr_but_ic = (ithr_mb * self->nthr_g_ + ithr_g) * self->nthr_oc_b_ - + ithr_oc_b; - - const auto &jcp = self->kernel_->jcp; - - /* reduction dimension */ - balance211(jcp.mb*jcp.od, self->nthr_mb_, ithr_mb, img_start, img_end); - img_work = img_end - img_start; - - /* independent dimensions */ - balance211(jcp.ngroups, self->nthr_g_, ithr_g, g_start, g_end); - g_work = g_end - g_start; - - balance211(jcp.nb_oc, self->nthr_oc_b_, ithr_oc_b, oc_b_start, - oc_b_end); - oc_b_work = oc_b_end - oc_b_start; - - balance211(jcp.nb_ic, self->nthr_ic_b_, ithr_ic_b, ic_b_start, - ic_b_end); - ic_b_work = ic_b_end - ic_b_start; - } -}; - -template -void jit_avx512_common_convolution_bwd_weights_t::compute_diff_weights(const thread_info_t *ti) const { - const memory_desc_wrapper src_d(pd()->src_md()); - const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); - const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); - - const auto &jcp = kernel_->jcp; - const int wei_size = jcp.ngroups * jcp.oc * jcp.ic * jcp.kh*jcp.kw*jcp.kd; - - diff_weights_data_t *diff_wei = ti->ithr_mb == 0 - ? (diff_weights_data_t*)ti->diff_weights - : ti->wei_bia_reduction + (ti->ithr_mb - 1) * wei_size; - diff_weights_data_t *diff_bia = ti->ithr_mb == 0 - ? 
(diff_weights_data_t*)ti->diff_bias - : ti->wei_bia_reduction + (nthr_mb_ - 1) * wei_size - + (ti->ithr_mb - 1) * jcp.ngroups * jcp.oc; - - // TODO: use memory descriptor with the same fmt as src (or use a macro :)) - auto tr_src_off = [&](int ithr_mb, int ic, int ij) { - const size_t tr_row_size = jcp.tr_iw * jcp.ic_block; - const size_t tr_chn_size = tr_row_size * jcp.ih; - const size_t tr_img_size = tr_chn_size * jcp.nb_ic * jcp.ngroups; - - return ti->ithr_mb * tr_img_size + ic * tr_chn_size + ij * tr_row_size; - }; - - auto uker_trans = [&](int img) { - const int work_amount = ti->g_work * ti->ic_b_work * jcp.ih; - - int start{0}, end{0}; - balance211(work_amount, nthr_oc_b_, ti->ithr_oc_b, start, end); - const int my_work = end - start; - - int g{0}, ic_b{0}, j{0}; - nd_iterator_init(start, g, ti->g_work, ic_b, ti->ic_b_work, j, jcp.ih); - g += ti->g_start; - ic_b += ti->ic_b_start; - - const int _ic = g * jcp.nb_ic + ic_b; - src_data_t *src1 = (src_data_t*)&ti->src[src_d.blk_off(img, _ic, j)]; - src_data_t *tr_src1 = &ti->tr_src[tr_src_off(ti->ithr_mb, _ic, j)]; - - assert(jcp.ic_block == 16); - const int src_stride = jcp.iw * jcp.ic_block; - const int tr_src_stride = jcp.tr_iw * jcp.ic_block; - - const int pf_depth = 2; - struct { src_data_t *src, *tr_src; } pf_circ_buf[pf_depth]; - - for (int iwork = 0; iwork < my_work + pf_depth - 1; iwork++) { - pf_circ_buf[iwork % pf_depth] = {src1, tr_src1}; - - if (iwork >= pf_depth - 1) { - int old_idx = (iwork - pf_depth + 1) % pf_depth; - auto ctx = jit_trans_src_t::ctx_t(); - ctx.src = pf_circ_buf[old_idx].src; - ctx.tr_src = pf_circ_buf[old_idx].tr_src; - ctx.src_prf = src1; - ctx.tr_src_prf = tr_src1; - (*trans_kernel_)(&ctx); - } - src1 += src_stride; - tr_src1 += tr_src_stride; - } -#if 0 - // reference transposition - const int l_pad = jcp.l_pad; - const int iwlp = l_pad + jcp.iw; - const int tr_iw = jcp.tr_iw; - - for (size_t iwork = start; iwork < end; iwork++) { - PRAGMA_OMP_SIMD() -# pragma unroll - for (int i = 0; i < l_pad; i++) - for (int j = 0; j < jcp.ic_block; j++) - tr_src1[j * jcp.tr_iw + i] = (src_data_t)0.0; - - PRAGMA_OMP_SIMD() -# pragma unroll - for (int i = l_pad; i < iwlp; i++) - for (int j = 0; j < jcp.ic_block; j++) - tr_src1[j * jcp.tr_iw + i] - = (src_data_t)src1[(i - l_pad) * 16 + j]; - - PRAGMA_OMP_SIMD() -# pragma unroll - for (int i = iwlp; i < tr_iw; i++) - for (int j = 0; j < jcp.ic_block; j++) - tr_src1[j * jcp.tr_iw + i] = (src_data_t)0.0; - - src1 += src_stride; - tr_src1 += tr_src_stride; - } -#endif - }; - - if (jcp.is_1stconv && jcp.ver == ver_4fma) { - /* prepare contexts */ - auto tr_ctx = jit_trans_src_t::ctx_t(); - tr_ctx.tr_src = ti->tr_src - + ti->ithr_but_oc * jcp.ih * jcp.stride_w * jcp.tr_ld; - - assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_oc_b_ == 1)); - tr_ctx.nthr_oc_b = nthr_oc_b_; - int ih_start{0}, ih_end{0}; - balance211(jcp.ih, nthr_oc_b_, ti->ithr_oc_b, ih_start, ih_end); - tr_ctx.tr_src_ih_start = ih_start; - tr_ctx.tr_src_ih_end = ih_end; - tr_ctx.tr_src_bctx = ti->tr_src_bctx + ti->ithr_but_oc; - - auto p = jit_conv_call_s(); - p.src = tr_ctx.tr_src; - - /* zero diff_bias if applicable */ - if (jcp.with_bias && ti->ithr_ic_b == 0) { - assert(jcp.oc_block == 16); - for (int oc_b = ti->ic_b_start; oc_b < ti->oc_b_end; ++oc_b) { - diff_weights_data_t *db = &diff_bia[oc_b * 16]; - for (int o = 0; o < 16; ++o) - db[o] = 0; - } - } - - for (int img = ti->img_start; img < ti->img_end; ++img) { - p.flags = (img == ti->img_start) * FLAG_MB_FIRST; - - for (int g = ti->g_start; g < 
ti->g_end; ++g) { - for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end; ++ic_b) { - const int _ic = g * jcp.nb_ic + ic_b; - tr_ctx.src = &ti->src[src_d.blk_off(img, _ic)]; - - (*trans_kernel_)(&tr_ctx); - - if (ic_b == 0) - p.flags |= FLAG_IC_FIRST; - else - p.flags &= ~FLAG_IC_FIRST; - - for (int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; ++oc_b) { - const int _oc = g * jcp.nb_oc + oc_b; - p.dst = &ti->diff_dst[diff_dst_d.blk_off(img, _oc)]; - - const size_t off = - wht_blk_off(diff_weights_d, g, oc_b, ic_b); - p.filt = diff_wei + off; - p.bias = diff_bia + _oc * jcp.oc_block; - - kernel_->jit_ker(&p); - } - } - } - } - } else { - for (int img = ti->img_start; img < ti->img_end; ++img) { - auto p = jit_conv_call_s(); - - if (jcp.ver == ver_4fma) { - /* tr_src[nb_ic][ih][16][~iw~] <- src[nb_ic][ih][iw][16] */ - using simple_barrier::barrier; - if (nthr_oc_b_ > 1) - barrier(&ti->tr_src_bctx[ti->ithr_but_oc], nthr_oc_b_); - uker_trans(img); - if (nthr_oc_b_ > 1) - barrier(&ti->tr_src_bctx[ti->ithr_but_oc], nthr_oc_b_); - } - - for (int g = ti->g_start; g < ti->g_end; ++g) { - for (int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; ++oc_b) { - for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end; ++ic_b) { - const int _oc = g * jcp.nb_oc + oc_b; - const int _ic = g * jcp.nb_ic + ic_b; - - jit_conv_ker_pipeline(kernel_->jit_ker, p, - jcp.ver == ver_4fma - ? &ti->tr_src[tr_src_off(ti->ithr_mb, _ic, 0)] - : &ti->src[src_d.blk_off(img, _ic)], - &ti->diff_dst[diff_dst_d.blk_off(img, _oc)], - diff_wei + wht_blk_off(diff_weights_d, g, oc_b, ic_b), - 0, (img == ti->img_start), 0); - - } - } - } - - const int _oc = ti->g_start * jcp.nb_oc + ti->oc_b_start; - const int _ic = ti->g_start * jcp.nb_ic + ti->ic_b_start; - jit_conv_ker_pipeline(kernel_->jit_ker, p, - jcp.ver == ver_4fma - ? &ti->tr_src[tr_src_off(ti->ithr_mb, _ic, 0)] - : &ti->src[src_d.blk_off(img + 1, _ic)], - &ti->diff_dst[diff_dst_d.blk_off(img + 1, _oc)], - diff_wei + wht_blk_off( - diff_weights_d, ti->g_start, - ti->oc_b_start, ti->ic_b_start), - 0, 0, 0); - } - } -} - -template -void jit_avx512_common_convolution_bwd_weights_t::compute_diff_weights_3d(const thread_info_t *ti) const -{ - const memory_desc_wrapper src_d(pd()->src_md()); - const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); - const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); - - const auto &jcp = kernel_->jcp; - const int wei_size - = jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw * jcp.kd; - - diff_weights_data_t *diff_wei = ti->ithr_mb == 0 - ? (diff_weights_data_t*)ti->diff_weights - : ti->wei_bia_reduction + (ti->ithr_mb - 1) * wei_size; - diff_weights_data_t *diff_bia = ti->ithr_mb == 0 - ? (diff_weights_data_t*)ti->diff_bias - : ti->wei_bia_reduction + (nthr_mb_ - 1) * wei_size - + (ti->ithr_mb - 1) * jcp.ngroups * jcp.oc; - - const int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block; - const int input_step = jcp.ih * jcp.iw * inp_mult; - const int output_step = jcp.ow * jcp.oh * jcp.oc_block; - int img{0}, od_s{0}; - int img_start = ti->img_start, img_end = ti->img_end; - nd_iterator_init(img_start, img, jcp.mb, od_s, jcp.od); - const int img_first = img; - - while (img_start < img_end) { - auto p = jit_conv_call_s(); - - int work_rem = img_end - img_start; - const int od_e = od_s + work_rem > jcp.od ? 
jcp.od : od_s + work_rem; - const int id_s = od_s * jcp.stride_d; - const int ik_overlap = nstl::max(0, id_s - jcp.f_pad); - const int kd_front_pad = nstl::max(0, jcp.f_pad - id_s); - const int kd_back_pad - = nstl::max(0, id_s - jcp.f_pad - jcp.id + jcp.kd); - int kd_pad_off = nstl::min(jcp.kd - 1, kd_front_pad) * jcp.kh * jcp.kw - * jcp.ic_block * jcp.oc_block * jcp.typesize_out; - - for (int g = ti->g_start; g < ti->g_end; ++g) { - for (int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; ++oc_b) { - for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end; ++ic_b) { - const int _oc = g * jcp.nb_oc + oc_b; - const int _ic = g * jcp.nb_ic + ic_b; - - auto src = &ti->src[src_d.blk_off(img, _ic) - + ik_overlap * input_step]; - auto dst = &ti->diff_dst[diff_dst_d.blk_off(img, _oc) - + od_s * output_step]; - - jit_conv_3d_ker_bwd_w_pipeline(kernel_->jit_ker, p, src, dst, - diff_wei + wht_blk_off(diff_weights_d, g, oc_b, ic_b), - diff_bia + _oc * 16, (img == img_first), od_s, od_e, - jcp.kd - kd_front_pad - kd_back_pad, kd_pad_off); - - if (ic_b == 0) p.flags = 0; - else p.flags = 1; - } - } - } - - const int _oc = ti->g_start * jcp.nb_oc + ti->oc_b_start; - const int _ic = ti->g_start * jcp.nb_ic + ti->ic_b_start; - jit_conv_3d_ker_bwd_w_pipeline(kernel_->jit_ker, p, - &ti->src[src_d.blk_off(img + 1, _ic)], - &ti->diff_dst[diff_dst_d.blk_off(img + 1, _oc)], - diff_wei + wht_blk_off(diff_weights_d, ti->g_start, - ti->oc_b_start, ti->ic_b_start), - diff_bia, 0, 0, 0, 0, 0); - nd_iterator_jump(img_start, img_end, img, jcp.mb, od_s, jcp.od); - } -} - -template -void jit_avx512_common_convolution_bwd_weights_t::reduce_diff_weights(const thread_info_t *ti) const { - const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); - - const auto &jcp = kernel_->jcp; - const int wei_size = jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw; - const int bia_size = jcp.ngroups * jcp.oc; - const diff_weights_data_t *diff_bias_ws - = ti->wei_bia_reduction + (nthr_mb_ - 1) * wei_size; - - /* diff_weights[:] += sum(wei_reduction_[thr_mb][:]) */ - simple_barrier::barrier(ti->wei_bia_reduction_bctx, nthr_); - - const int ic_b_kh_work = ti->ic_b_work * jcp.kh; - const int work = ti->g_work * ti->oc_b_work * ic_b_kh_work; - - int start{0}, end{0}; - balance211(work, nthr_mb_, ti->ithr_mb, start, end); - if (start == end) return; - - for (int thr_mb = 1; thr_mb < nthr_mb_; ++thr_mb) { - int w = start; - int sub_g_start{0}, sub_oc_b_start{0}, sub_ic_b_kh_start{0}; - nd_iterator_init(w, sub_g_start, ti->g_work, sub_oc_b_start, - ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work); - while (w < end) { - const int g = ti->g_start + sub_g_start; - const int oc_b = ti->oc_b_start + sub_oc_b_start; - const int ic_b = ti->ic_b_start + sub_ic_b_kh_start / jcp.kh; - const int kh = sub_ic_b_kh_start % jcp.kh; - - const int acc_size - = nstl::min(end - w, ic_b_kh_work - sub_ic_b_kh_start) - * jcp.kw * jcp.ic_block * jcp.oc_block; - - const size_t off - = wht_blk_off(diff_weights_d, g, oc_b, ic_b, kh); - - diff_weights_data_t *d - = (diff_weights_data_t *)ti->diff_weights + off; - diff_weights_data_t *s - = ti->wei_bia_reduction + (thr_mb - 1) * wei_size + off; - - acc_ker_->accumulate(d, s, acc_size); - - nd_iterator_jump(w, end, sub_g_start, ti->g_work, sub_oc_b_start, - ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work); - } - - if (jcp.with_bias && jcp.is_1stconv && jcp.ver == ver_4fma) { - if (ti->ithr == 0) - acc_ker_->accumulate((diff_weights_data_t *)ti->diff_bias, - diff_bias_ws, bia_size); - diff_bias_ws += bia_size; - } - } -} - 
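The reduction just above works as follows: the first mb-thread writes its weight gradient straight into diff_weights, every other mb-thread writes into a private wei_size-sized slice of the wei_bia_reduction scratchpad, and after the barrier each thread folds its assigned chunk of those private slices back into the output. A minimal sketch of that accumulation, with hypothetical names and a plain scalar loop standing in for acc_ker_->accumulate():

#include <cstddef>

// Thread 0 of the mb dimension has already written diff_weights directly;
// threads 1..nthr_mb-1 wrote private copies into `scratch`, one wei_size-sized
// slice each. Every thread then folds its [start, end) chunk of those slices
// back into diff_weights.
static void reduce_private_copies(float *diff_weights, const float *scratch,
        std::size_t wei_size, int nthr_mb, std::size_t start, std::size_t end) {
    for (int thr_mb = 1; thr_mb < nthr_mb; ++thr_mb) {
        const float *src = scratch + (std::size_t)(thr_mb - 1) * wei_size;
        for (std::size_t i = start; i < end; ++i)
            diff_weights[i] += src[i]; // scalar stand-in for acc_ker_->accumulate()
    }
}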
-template -void jit_avx512_common_convolution_bwd_weights_t::reduce_diff_weights_3d(const thread_info_t *ti) const { - const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); - - const auto &jcp = kernel_->jcp; - const int wei_size = jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw - * jcp.kd; - - /* diff_weights[:] += sum(wei_reduction_[thr_mb][:]) */ - simple_barrier::barrier(ti->wei_bia_reduction_bctx, nthr_); - - const int ic_b_kh_work = ti->ic_b_work * jcp.kd; - const int work = ti->g_work * ti->oc_b_work * ic_b_kh_work; - - int start{0}, end{0}; - balance211(work, nthr_mb_, ti->ithr_mb, start, end); - if (start == end) return; - - for (int thr_mb = 1; thr_mb < nthr_mb_; ++thr_mb) { - int w = start; - int sub_g_start{0}, sub_oc_b_start{0}, sub_ic_b_kh_start{0}; - nd_iterator_init(w, sub_g_start, ti->g_work, sub_oc_b_start, - ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work); - while (w < end) { - const int g = ti->g_start + sub_g_start; - const int oc_b = ti->oc_b_start + sub_oc_b_start; - const int ic_b = ti->ic_b_start + sub_ic_b_kh_start / jcp.kd; - const int kd = sub_ic_b_kh_start % jcp.kd; - - const int acc_size - = nstl::min(end - w, ic_b_kh_work - sub_ic_b_kh_start) - * jcp.kw * jcp.ic_block * jcp.oc_block * jcp.kh; - - const size_t off - = wht_blk_off(diff_weights_d, g, oc_b, ic_b, kd); - diff_weights_data_t *d - = (diff_weights_data_t *)ti->diff_weights + off; - diff_weights_data_t *s - = ti->wei_bia_reduction + (thr_mb - 1) * wei_size + off; - acc_ker_->accumulate(d, s, acc_size); - - nd_iterator_jump(w, end, sub_g_start, ti->g_work, sub_oc_b_start, - ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work); - } - } -} - -template -void jit_avx512_common_convolution_bwd_weights_t::compute_diff_bias(const thread_info_t *ti) const { - const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); - - auto rb = this->reducer_bias_; - assert(nthr_ == rb->balancer().nthr_); - - const auto reducer_bia_scratchpad = memory_tracking::grantor_t( - ti->scratchpad, prefix_reducer_bia); - - const auto &jcp = kernel_->jcp; - - if (jcp.with_bias && jcp.is_1stconv && jcp.ver == ver_4fma) return; - - const int b_job_start = rb->balancer().ithr_job_off(ti->ithr); - const int b_njobs = rb->balancer().ithr_njobs(ti->ithr); - - if (b_njobs == 0) return; - - /* reduction dimension */ - int img_start{0}, img_end{0}; - balance211(jcp.mb, rb->balancer().nthr_per_group_, - rb->balancer().id_in_group(ti->ithr), img_start, img_end); - - /* jobs */ - int g_start{0}, ocb_start{0}; - nd_iterator_init(b_job_start, g_start, jcp.ngroups, ocb_start, jcp.nb_oc); - for (int img = img_start; img < img_end; ++img) { - int g = g_start, ocb = ocb_start; - for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) { - const size_t _oc = g * jcp.nb_oc + ocb; - - const diff_dst_data_t *d_dst - = &ti->diff_dst[diff_dst_d.blk_off(img, _oc)]; - diff_weights_data_t *d_bias = rb->get_local_ptr(ti->ithr, - ti->diff_bias, reducer_bia_scratchpad) - + b_job_loc * rb->balancer().job_size_; - - if (img == img_start) - for (int o = 0; o < 16; ++o) - d_bias[o] = 0; - for (int hw = 0; hw < jcp.oh * jcp.ow * jcp.od; ++hw) { - PRAGMA_OMP_SIMD() - for (int o = 0; o < 16; ++o) - d_bias[o] += d_dst[o]; - d_dst += 16; - } - - nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_oc); - } - } - - rb->reduce(ti->ithr, ti->diff_bias, reducer_bia_scratchpad); -} - -template -void jit_avx512_common_convolution_bwd_weights_t::compute_diff_bias_3d(const thread_info_t *ti) const { - - const auto &jcp = kernel_->jcp; - - const size_t wei_size = 
(size_t)jcp.ngroups * jcp.oc * jcp.ic * jcp.kh - * jcp.kw * jcp.kd; - const int bia_size = jcp.ngroups * jcp.oc; - const diff_weights_data_t *diff_bias_ws - = ti->wei_bia_reduction + (size_t)(nthr_mb_ - 1) * wei_size; - - if (nthr_mb_ > 1) mkldnn_thr_barrier(); - - if (ti->ithr == 0) - { - for (int thr_mb = 1; thr_mb < nthr_mb_; ++thr_mb) { - acc_ker_->accumulate(ti->diff_bias, diff_bias_ws, bia_size); - diff_bias_ws += bia_size; - } - } -} - -template -void jit_avx512_common_convolution_bwd_weights_t::prepare_scratchpad_data(const exec_ctx_t &ctx) const -{ - const auto &j = pd()->jcp_; - auto scratchpad = this->scratchpad(ctx); - - if (j.ver == ver_4fma) { - if (!j.is_1stconv) { - // XXX: See the comment about tr_iw and guarding elements in - // jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf() - const int max_nthr = j.nthr_mb * j.ngroups * j.nb_ic; - const int min_tr_src_size_per_thr = j.ih * j.ic_block * j.tr_iw; - - auto tr_src = scratchpad.template get(key_conv_tr_src); - /* to avoid NaNs in computations we zero tail num_guard_elems for - * each possible thread group */ - - for (int ithr = 1; ithr <= max_nthr; ++ithr) { - src_data_t *ts = &tr_src[ithr * min_tr_src_size_per_thr]; - for (int i = 0; i < j.tr_src_num_guard_elems; ++i) - ts[i] = 0; - } - } - - if (j.nthr_oc_b > 1) { - const int tr_src_bctx_size = j.nthr / j.nthr_oc_b; - auto tr_src_bctx = scratchpad.template get( - key_conv_tr_src_bctx); - for (int i = 0; i < tr_src_bctx_size; ++i) - simple_barrier::ctx_init(&tr_src_bctx[i]); - } - } - - if (nthr_mb_ > 1) { - simple_barrier::ctx_init(scratchpad.template get( - key_conv_wei_bia_reduction_bctx)); - } - - const auto reducer_bia_scratchpad = memory_tracking::grantor_t(scratchpad, - prefix_reducer_bia); - auto rb = this->reducer_bias_; - rb->init(reducer_bia_scratchpad); -} - -template -void jit_avx512_common_convolution_bwd_weights_t::execute_backward_weights(const exec_ctx_t &ctx) const { - prepare_scratchpad_data(ctx); - - parallel(nthr_, [&](const int ithr, const int nthr) { - assert(nthr_ == nthr); - - thread_info_t thread_info(this, ctx, ithr); - - if (utils::one_of(pd()->ndims(), 3, 4)) { - compute_diff_weights(&thread_info); - if (nthr_mb_ > 1) reduce_diff_weights(&thread_info); - if (pd()->with_bias()) compute_diff_bias(&thread_info); - } else if (pd()->ndims() == 5) { - compute_diff_weights_3d(&thread_info); - if (nthr_mb_ > 1) reduce_diff_weights_3d(&thread_info); - if (pd()->with_bias()) compute_diff_bias_3d(&thread_info); - } else { - assert(false); - } - }); - - /* TODO: put that into compute_diff_bias() */ - if (pd()->wants_padded_bias()) { - auto diff_bias = scratchpad(ctx).template get( - key_conv_padded_bias); - auto diff_bias_in = CTX_OUT_MEM(diff_weights_data_t *, MKLDNN_ARG_DIFF_BIAS); - for (int oc = 0; oc < pd()->jcp_.oc_without_padding; ++oc) - diff_bias_in[oc] = diff_bias[oc]; - } -} - -template struct jit_avx512_common_convolution_bwd_weights_t; - -} -} -} - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution.hpp deleted file mode 100644 index 3341c3ebe..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution.hpp +++ /dev/null @@ -1,302 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance 
with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef CPU_JIT_AVX512_COMMON_CONVOLUTION_HPP -#define CPU_JIT_AVX512_COMMON_CONVOLUTION_HPP - -#include "c_types_map.hpp" -#include "memory_tracking.hpp" -#include "mkldnn_thread.hpp" -#include "utils.hpp" - -#include "cpu_barrier.hpp" -#include "cpu_convolution_pd.hpp" -#include "cpu_primitive.hpp" -#include "cpu_reducer.hpp" - -#include "jit_transpose_src_utils.hpp" -#include "jit_avx512_common_conv_kernel.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -template -struct jit_avx512_common_convolution_fwd_t : public cpu_primitive_t { - struct pd_t : public cpu_convolution_fwd_pd_t { - pd_t(engine_t *engine, const convolution_desc_t *adesc, - const primitive_attr_t *attr, - const typename pd_t::base_class *hint_fwd_pd) - : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) - , jcp_() - {} - - DECLARE_COMMON_PD_T( - JIT_IMPL_NAME_HELPER("jit:", avx512_common, ""), - jit_avx512_common_convolution_fwd_t); - - status_t init() { - bool ok = true - && is_fwd() - && set_default_alg_kind(alg_kind::convolution_direct) - && expect_data_types(src_type, wei_type, dst_type, dst_type, - data_type::undef) - && !has_zero_dim_memory(); - if (!ok) return status::unimplemented; - - status_t status = jit_avx512_common_conv_fwd_kernel::init_conf( - jcp_, *desc(), src_md_, weights_md_, dst_md_, bias_md_, - *attr(), mkldnn_get_max_threads()); - if (status != status::success) return status; - - auto scratchpad = scratchpad_registry().registrar(); - jit_avx512_common_conv_fwd_kernel::init_scratchpad(scratchpad, - jcp_); - - return status; - } - - jit_conv_conf_t jcp_; - }; - - jit_avx512_common_convolution_fwd_t(const pd_t *apd) - : cpu_primitive_t(apd) - { - kernel_ = new jit_avx512_common_conv_fwd_kernel(pd()->jcp_, - *pd()->attr()); - } - ~jit_avx512_common_convolution_fwd_t() { delete kernel_; } - - typedef typename prec_traits::type src_data_t; - typedef typename prec_traits::type wei_data_t; - typedef typename prec_traits::type dst_data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - if (pd()->ndims() == 3) - execute_forward_1d(ctx); - else if (pd()->ndims() == 4) - execute_forward_2d(ctx); - else if (pd()->ndims() == 5) - execute_forward_3d(ctx); - else - assert(false); - - if (pd()->wants_zero_pad_dst()) - ctx.memory(MKLDNN_ARG_DST)->zero_pad(); - - return status::success; - } - -private: - void prepare_padded_bias(const dst_data_t *&bias, - const memory_tracking::grantor_t &scratchpad) const; - void execute_forward_1d(const exec_ctx_t &ctx) const; - void execute_forward_2d(const exec_ctx_t &ctx) const; - void execute_forward_3d(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - - jit_avx512_common_conv_fwd_kernel *kernel_; -}; - -template -struct jit_avx512_common_convolution_bwd_data_t: public cpu_primitive_t { - struct pd_t: public cpu_convolution_bwd_data_pd_t { - pd_t(engine_t *engine, - const convolution_desc_t *adesc, - const primitive_attr_t *attr, - const convolution_fwd_pd_t *hint_fwd_pd) 
- : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) - , jcp_() - {} - - DECLARE_COMMON_PD_T( - JIT_IMPL_NAME_HELPER("jit:", avx512_common, ""), - jit_avx512_common_convolution_bwd_data_t); - - status_t init() { - bool ok = true - && desc()->prop_kind == prop_kind::backward_data - && set_default_alg_kind(alg_kind::convolution_direct) - && expect_data_types(diff_src_type, wei_type, - data_type::undef, diff_dst_type, data_type::undef) - && !has_zero_dim_memory() - && set_default_formats(); - if (!ok) return status::unimplemented; - - status_t status = - jit_avx512_common_conv_bwd_data_kernel_f32::init_conf(jcp_, - *desc(), *diff_src_md(), *weights_md(), *diff_dst_md()); - if (status != status::success) return status; - - auto scratchpad = scratchpad_registry().registrar(); - jit_avx512_common_conv_bwd_data_kernel_f32::init_scratchpad( - scratchpad, jcp_); - - return status::success; - } - - jit_conv_conf_t jcp_; - - protected: - bool set_default_formats() { - using namespace format_tag; - - auto dat_tag = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c); - auto wei_tag = utils::pick(2 * ndims() - 6 + with_groups(), - OIw16o16i, gOIw16o16i, OIhw16o16i, gOIhw16o16i, - OIdhw16o16i, gOIdhw16o16i); - - return set_default_formats_common(dat_tag, wei_tag, dat_tag); - } - }; - - jit_avx512_common_convolution_bwd_data_t(const pd_t *apd) - : cpu_primitive_t(apd) - { kernel_ = new jit_avx512_common_conv_bwd_data_kernel_f32(pd()->jcp_); } - ~jit_avx512_common_convolution_bwd_data_t() { delete kernel_; }; - - typedef typename prec_traits::type diff_dst_data_t; - typedef typename prec_traits::type wei_data_t; - typedef typename prec_traits::type diff_src_data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - if (pd()->ndims() == 3) - execute_backward_data_1d(ctx); - else if (pd()->ndims() == 4) - execute_backward_data_2d(ctx); - else if (pd()->ndims() == 5) - execute_backward_data_3d(ctx); - else - assert(false); - return status::success; - } - -private: - void execute_backward_data_1d(const exec_ctx_t &ctx) const; - void execute_backward_data_2d(const exec_ctx_t &ctx) const; - void execute_backward_data_3d(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - - jit_avx512_common_conv_bwd_data_kernel_f32 *kernel_; -}; - -template -struct jit_avx512_common_convolution_bwd_weights_t: public cpu_primitive_t { - struct pd_t: public cpu_convolution_bwd_weights_pd_t { - pd_t(engine_t *engine, const convolution_desc_t *adesc, - const primitive_attr_t *attr, - const convolution_fwd_pd_t *hint_fwd_pd) - : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd) - , jcp_() {} - - DECLARE_COMMON_PD_T( - JIT_IMPL_NAME_HELPER("jit:", avx512_common, ""), - jit_avx512_common_convolution_bwd_weights_t); - - status_t init() { - bool ok = true - && desc()->prop_kind == prop_kind::backward_weights - && set_default_alg_kind(alg_kind::convolution_direct) - && expect_data_types(src_type, diff_weights_type, - diff_weights_type, diff_dst_type, data_type::undef) - && !has_zero_dim_memory(); - if (!ok) return status::unimplemented; - - status_t status = jit_avx512_common_conv_bwd_weights_kernel_f32:: - init_conf(jcp_, *desc(), src_md_, diff_weights_md_, - diff_bias_md_, diff_dst_md_); - if (status != status::success) return status; - - init_balancers(); - - auto scratchpad = scratchpad_registry().registrar(); - jit_avx512_common_conv_bwd_weights_kernel_f32::init_scratchpad( - scratchpad, jcp_); - - auto reducer_bia_scratchpad = 
memory_tracking::registrar_t( - scratchpad, memory_tracking::names::prefix_reducer_bia); - reducer_bia_conf_.init_scratchpad(reducer_bia_scratchpad); - - return status; - } - - jit_conv_conf_t jcp_; - typename cpu_reducer_t::conf_t reducer_bia_conf_; - - private: - void init_balancers() { - const size_t max_buffer_size = jcp_.nthr * 3 * 5 * 5 * 16 * 16; - if (with_bias()) { - reducer_bia_conf_.init(reduce_balancer_t(jcp_.nthr, - jcp_.oc_block, jcp_.ngroups * jcp_.nb_oc, jcp_.mb, - max_buffer_size)); - } - } - }; - - jit_avx512_common_convolution_bwd_weights_t(const pd_t *apd); - ~jit_avx512_common_convolution_bwd_weights_t() { - delete kernel_; - if (trans_kernel_) - delete trans_kernel_; - if (acc_ker_) - delete acc_ker_; - delete reducer_bias_; - } - - typedef typename prec_traits::type src_data_t; - typedef typename prec_traits::type diff_dst_data_t; - typedef typename prec_traits::type diff_weights_data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_backward_weights(ctx); - return status::success; - } - -private: - void execute_backward_weights(const exec_ctx_t &ctx) const; - void prepare_scratchpad_data(const exec_ctx_t &ctx) const; - struct thread_info_t; - void compute_diff_weights(const thread_info_t *) const; - void compute_diff_weights_3d(const thread_info_t *) const; - void reduce_diff_weights(const thread_info_t *) const; - void reduce_diff_weights_3d(const thread_info_t *) const; - void compute_diff_bias(const thread_info_t *) const; - void compute_diff_bias_3d(const thread_info_t *) const; - - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - - int nthr_, nthr_mb_, nthr_g_, nthr_oc_b_, nthr_ic_b_; - - jit_avx512_common_conv_bwd_weights_kernel_f32 *kernel_; - jit_trans_src_t *trans_kernel_; - cpu_accumulator_1d_t *acc_ker_; - cpu_reducer_t *reducer_bias_; -}; - -} -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.cpp deleted file mode 100644 index 62247c026..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.cpp +++ /dev/null @@ -1,1215 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
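The backward-weights primitive declared above splits its nthr_ threads across the minibatch, group, output-block and input-block dimensions (nthr_mb_, nthr_g_, nthr_oc_b_, nthr_ic_b_) and hands each thread contiguous work ranges via balance211. A minimal sketch of that bookkeeping, using illustrative names and a balance211-style split (assumed semantics: the first n % nthr workers get one extra job), mirroring the index arithmetic in thread_info_t earlier in this file:

#include <cstdio>

struct thr_coords { int mb, g, oc_b, ic_b; };

// Same decomposition as the thread_info_t constructor above: the flat thread
// id is peeled apart innermost-first (ic_b, then oc_b, then g, then mb).
static thr_coords decompose(int ithr, int nthr_g, int nthr_oc_b, int nthr_ic_b) {
    thr_coords c;
    c.ic_b = ithr % nthr_ic_b;
    c.oc_b = ithr / nthr_ic_b % nthr_oc_b;
    c.g    = ithr / nthr_ic_b / nthr_oc_b % nthr_g;
    c.mb   = ithr / nthr_ic_b / nthr_oc_b / nthr_g;
    return c;
}

// Even split of n jobs over nthr workers; |end - start| differs by at most one
// across workers.
static void balance(int n, int nthr, int ithr, int &start, int &end) {
    const int base = n / nthr, rem = n % nthr;
    start = ithr * base + (ithr < rem ? ithr : rem);
    end   = start + base + (ithr < rem ? 1 : 0);
}

int main() {
    int s, e;
    balance(/*n=*/10, /*nthr=*/4, /*ithr=*/1, s, e); // -> [3, 6)
    thr_coords c = decompose(/*ithr=*/13, /*nthr_g=*/1, /*nthr_oc_b=*/2, /*nthr_ic_b=*/4);
    std::printf("slice [%d,%d), coords mb=%d g=%d oc_b=%d ic_b=%d\n",
            s, e, c.mb, c.g, c.oc_b, c.ic_b);
    return 0;
}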
-*******************************************************************************/ - -#ifdef __INTEL_COMPILER -#include -#endif - -#include "mkldnn_types.h" - -#include "c_types_map.hpp" -#include "mkldnn_thread.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -#include "jit_avx512_common_convolution_winograd.hpp" - -#ifndef _MSC_VER -#define pragma_unroll _Pragma("unroll") -#else -#define pragma_unroll -#endif - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace memory_tracking::names; - -namespace { - -unsigned int LLC_cache_size = get_cache_size(3, false); - -void inline load_ps(float *dest, const float *src_mem) { -#ifdef __INTEL_COMPILER - __m512 *Iv512 = (__m512 *)dest; - Iv512[0] = _mm512_load_ps(src_mem); -#else - PRAGMA_OMP_SIMD() - for (int v = 0; v < simd_w; v++) dest[v] = src_mem[v]; -#endif -} - -void inline store_output(float *dest, const float *data, bool streamout) { -#ifdef __INTEL_COMPILER - if (streamout) - _mm512_stream_ps(dest, *((__m512 *)data)); - else - _mm512_store_ps(dest, *((__m512 *)data)); -#else - PRAGMA_OMP_SIMD() - for (int v = 0; v < simd_w; v++) - dest[v] = data[v]; -#endif -} - -void inline accum_output( - float *dest, float *data, bool streamout, bool with_relu_postsum) { -#ifdef __INTEL_COMPILER - __m512 _data = _mm512_loadu_ps(data); - __m512 _dest = _mm512_loadu_ps(dest); - _data = _mm512_add_ps(_data, _dest); - if (with_relu_postsum) - _data = _mm512_max_ps(_data, _mm512_setzero_ps()); - if (streamout) - _mm512_stream_ps(dest, _data); - else - _mm512_store_ps(dest, _data); -#else - PRAGMA_OMP_SIMD() - for (int v = 0; v < simd_w; v++) - data[v] += dest[v]; - - if (with_relu_postsum) { - PRAGMA_OMP_SIMD() - for (int v = 0; v < simd_w; v++) - if (data[v] < 0.f) - data[v] = 0.f; - } - - PRAGMA_OMP_SIMD() - for (int v = 0; v < simd_w; v++) - dest[v] = data[v]; -#endif -} -} - -using namespace mkldnn::impl::status; -using namespace mkldnn::impl::utils; - -void trans_W_4x4_3x3(float Fw_[6][6][16][16], float F[3][3][16][16]) { - float Fw[6][16]; - float T[6][3][16]; - float t0[16]; - float t1[16]; - float t2[16]; - - for (int j = 0; j < 16; j++) { -#pragma unroll - for (int i = 0; i < 3; i++) { - PRAGMA_OMP_SIMD() - for (int k = 0; k < 16; k++) { - t0[k] = 0.26890756302521f * F[2][i][j][k]; - t1[k] = -t0[k] - 0.688403361344538f * F[0][i][j][k]; - t2[k] = t0[k] + 0.119514472455649f * F[0][i][j][k]; - - T[0][i][k] = 1.13777777777778f * F[0][i][j][k]; - T[1][i][k] = t1[k] - 0.430252100840336f * F[1][i][j][k]; - T[2][i][k] = t1[k] + 0.430252100840336f * F[1][i][j][k]; - T[3][i][k] = t2[k] + 0.179271708683473f * F[1][i][j][k]; - T[4][i][k] = t2[k] - 0.179271708683473f * F[1][i][j][k]; - T[5][i][k] = F[2][i][j][k]; - } - } -#pragma unroll - for (int i = 0; i < 6; i++) { - PRAGMA_OMP_SIMD() - for (int k = 0; k < 16; k++) { - t0[k] = 0.26890756302521f * T[i][2][k]; - t1[k] = -t0[k] - 0.688403361344538f * T[i][0][k]; - t2[k] = t0[k] + 0.119514472455649f * T[i][0][k]; - - Fw[0][k] = 1.13777777777778f * T[i][0][k]; - Fw[1][k] = t1[k] - 0.430252100840336f * T[i][1][k]; - Fw[2][k] = t1[k] + 0.430252100840336f * T[i][1][k]; - Fw[3][k] = t2[k] + 0.179271708683473f * T[i][1][k]; - Fw[4][k] = t2[k] - 0.179271708683473f * T[i][1][k]; - Fw[5][k] = T[i][2][k]; -#pragma unroll - for (int l = 0; l < 6; l++) { - Fw_[i][l][j][k] = Fw[l][k]; - } - } - } - } -} - -void trans_O_4x4_3x3(float Mw[6][6][16], float O[4][4][16]) { - float T[4][6][16]; - float t0[16]; - float t1[16]; - float t2[16]; - float t3[16]; - -#pragma unroll - for (int i = 0; i < 6; 
i++) { - PRAGMA_OMP_SIMD() - for (int v = 0; v < 16; v++) { - t0[v] = Mw[1][i][v] + Mw[2][i][v]; - t1[v] = Mw[3][i][v] + Mw[4][i][v]; - t2[v] = Mw[1][i][v] - Mw[2][i][v]; - t3[v] = Mw[3][i][v] - Mw[4][i][v]; - - T[0][i][v] = t0[v] + t1[v] + Mw[0][i][v]; - T[1][i][v] = t2[v] * 0.625f + t3[v] * 1.5f; - T[2][i][v] = t0[v] * 0.390625f + t1[v] * 2.25f; - T[3][i][v] = t2[v] * 0.244140625f + t3[v] * 3.375f + Mw[5][i][v]; - } - } -#pragma unroll - for (int i = 0; i < 4; i++) { - PRAGMA_OMP_SIMD() - for (int v = 0; v < 16; v++) { - t0[v] = T[i][1][v] + T[i][2][v]; - t1[v] = T[i][3][v] + T[i][4][v]; - t2[v] = T[i][1][v] - T[i][2][v]; - t3[v] = T[i][3][v] - T[i][4][v]; - - O[i][0][v] = t0[v] + t1[v] + T[i][0][v]; - O[i][1][v] = t2[v] * 0.625f + t3[v] * 1.5f; - O[i][2][v] = t0[v] * 0.390625f + t1[v] * 2.25f; - O[i][3][v] = t2[v] * 0.244140625f + t3[v] * 3.375f + T[i][5][v]; - } - } -} - - -void trans_W_3x3_4x4(float Fw[6][6][16], float F[4][6][16]) -{ - const float rcp3 = 1.0f / 3.0f; - const float rcp4 = 1.0f / 4.0f; - const float rcp6 = 1.0f / 6.0f; - const float rcp12 = 1.0f / 12.0f; - const float rcp24 = 1.0f / 24.0f; - float t0[16]; - float t1[16]; - float t2[16]; - float t3[16]; - float t4[16]; - float T[6][4][16]; - -pragma_unroll - for (int i = 0; i < 4; i++) { - PRAGMA_OMP_SIMD() - for (int j = 0; j < 16; j++) { - t0[j] = F[2][i][j] * rcp6; - t1[j] = F[0][i][j] * -rcp6 - t0[j]; - t2[j] = F[0][i][j] * rcp24 + t0[j]; - t3[j] = (F[1][i][j] + F[3][i][j]) * rcp6; - t4[j] = F[1][i][j] * rcp12 + F[3][i][j] * rcp3; - - T[0][i][j] = F[0][i][j] * rcp4; - T[1][i][j] = t1[j] - t3[j]; - T[2][i][j] = t1[j] + t3[j]; - T[3][i][j] = t2[j] + t4[j]; - T[4][i][j] = t2[j] - t4[j]; - T[5][i][j] = F[3][i][j]; - } - } -pragma_unroll - for (int i = 0; i < 6; i++) { - PRAGMA_OMP_SIMD() - for (int j = 0; j < 16; j++) { - t0[j] = T[i][2][j] * rcp6; - t1[j] = T[i][0][j] * -rcp6 - t0[j]; - t2[j] = T[i][0][j] * rcp24 + t0[j]; - t3[j] = (T[i][1][j] + T[i][3][j]) * rcp6; - t4[j] = T[i][1][j] * rcp12 + T[i][3][j] * rcp3; - - Fw[i][0][j] = T[i][0][j] * rcp4; - Fw[i][1][j] = t1[j] - t3[j]; - Fw[i][2][j] = t1[j] + t3[j]; - Fw[i][3][j] = t2[j] + t4[j]; - Fw[i][4][j] = t2[j] - t4[j]; - Fw[i][5][j] = T[i][3][j]; - } - } -} - -void trans_O_3x3_4x4(float Mw[6][6][16][16], float M[3][3][16][16]) -{ - float T[4][6][16]; - float M_[3][16]; - float t0[16]; - float t1[16]; - float t2[16]; - - for (int j = 0; j < 16; j++) { -pragma_unroll - for (int i = 0; i < 6; i++) { - PRAGMA_OMP_SIMD() - for (int l = 0; l < 16; l++) { - t0[l] = Mw[1][i][j][l] + Mw[2][i][j][l]; - t1[l] = Mw[3][i][j][l] + Mw[4][i][j][l]; - t2[l] = t1[l] * 4.0f + Mw[5][i][j][l]; - - T[0][i][l] = Mw[0][i][j][l] + t0[l] + t1[l]; - T[1][i][l] = (Mw[1][i][j][l] - Mw[2][i][j][l]) + - 2.0f * (Mw[3][i][j][l] - Mw[4][i][j][l]); - T[2][i][l] = t0[l] + t2[l]; - } - } -pragma_unroll - for (int i = 0; i < 3; i++) { - PRAGMA_OMP_SIMD() - for (int l = 0; l < 16; l++) { - t0[l] = T[i][1][l] + T[i][2][l]; - t1[l] = T[i][3][l] + T[i][4][l]; - t2[l] = t1[l] * 4.0f + T[i][5][l]; - - M_[0][l] = T[i][0][l] + t0[l] + t1[l]; - M_[1][l] = (T[i][1][l] - T[i][2][l]) + - 2.0f * (T[i][3][l] - T[i][4][l]); - M_[2][l] = t0[l] + t2[l]; - - for (int k = 0; k < 3; k++) { - M[i][k][j][l] = M_[k][l]; - } - } - } - } -} - -void trans_I_4x4_3x3(float Iw[6][6][16], float I[6][6][16]) -{ - float T[6][6][16]; - float t0[16]; - float t1[16]; - float t2[16]; - float t3[16]; - float t4[16]; - float t5[16]; - -pragma_unroll - for (int i = 0; i < 6; i++) { - PRAGMA_OMP_SIMD() - for (int v = 0; v < 16; v++) { - t0[v] 
= I[2][i][v] * -2.25f + I[4][i][v]; - t1[v] = I[1][i][v] * -2.25f + I[3][i][v]; - t2[v] = I[2][i][v] * -0.390625f + I[4][i][v]; - t3[v] = I[1][i][v] * -0.390625f + I[3][i][v]; - t4[v] = I[0][i][v] * 0.87890625f + I[4][i][v]; - t5[v] = I[1][i][v] * 0.87890625f + I[5][i][v]; - - T[0][i][v] = I[2][i][v] * -2.640625f + t4[v]; - T[1][i][v] = t1[v] * 0.625f + t0[v]; - T[2][i][v] = t1[v] * -0.625f + t0[v]; - T[3][i][v] = t3[v] * 1.5f + t2[v]; - T[4][i][v] = t3[v] * -1.5f + t2[v]; - T[5][i][v] = I[3][i][v] * -2.640625f + t5[v]; - } - } - -pragma_unroll - for (int i = 0; i < 6; i++) { - PRAGMA_OMP_SIMD() - for (int v = 0; v < 16; v++) { - t0[v] = T[i][2][v] * -2.25f + T[i][4][v]; - t1[v] = T[i][1][v] * -2.25f + T[i][3][v]; - t2[v] = T[i][2][v] * -0.390625f + T[i][4][v]; - t3[v] = T[i][1][v] * -0.390625f + T[i][3][v]; - t4[v] = T[i][0][v] * 0.87890625f + T[i][4][v]; - t5[v] = T[i][1][v] * 0.87890625f + T[i][5][v]; - - Iw[i][0][v] = T[i][2][v] * -2.640625f + t4[v]; - Iw[i][1][v] = t1[v] * 0.625f + t0[v]; - Iw[i][2][v] = t1[v] * -0.625f + t0[v]; - Iw[i][3][v] = t3[v] * 1.5f + t2[v]; - Iw[i][4][v] = t3[v] * -1.5f + t2[v]; - Iw[i][5][v] = T[i][3][v] * -2.640625f + t5[v]; - } - } -} - -void trans_W_3x3_4x4_wu(float Fw[6][6][16], float F[4][6][16]) -{ - float T[6][4][16]; - float t0[16]; - float t1[16]; - float t2[16]; - float t3[16]; - float t4[16]; - -pragma_unroll - for (int i = 0; i < 4; i++) { - PRAGMA_OMP_SIMD() - for (int v = 0; v < 16; v++) { - t0[v] = F[2][i][v] * 0.26890756302521f; - t1[v] = F[0][i][v] * -0.688403361344538f - t0[v]; - t2[v] = F[0][i][v] * 0.119514472455649f + t0[v]; - t3[v] = F[1][i][v] * 0.430252100840336f + - F[3][i][v] * 0.168067226890756f; - t4[v] = F[1][i][v] * 0.179271708683473f + - F[3][i][v] * 0.403361344537815f; - - T[0][i][v] = F[0][i][v] * 1.13777777777778f; - T[1][i][v] = t1[v] - t3[v]; - T[2][i][v] = t1[v] + t3[v]; - T[3][i][v] = t2[v] + t4[v]; - T[4][i][v] = t2[v] - t4[v]; - T[5][i][v] = F[3][i][v]; - } - } -pragma_unroll - for (int i = 0; i < 6; i++) { - for (int v = 0; v < 16; v++) { - t0[v] = T[i][2][v] * 0.26890756302521f; - t1[v] = T[i][0][v] * -0.688403361344538f - t0[v]; - t2[v] = T[i][0][v] * 0.119514472455649f + t0[v]; - t3[v] = T[i][1][v] * 0.430252100840336f + - T[i][3][v] * 0.168067226890756f; - t4[v] = T[i][1][v] * 0.179271708683473f + - T[i][3][v] * 0.403361344537815f; - - Fw[i][0][v] = T[i][0][v] * 1.13777777777778f; - Fw[i][1][v] = t1[v] - t3[v]; - Fw[i][2][v] = t1[v] + t3[v]; - Fw[i][3][v] = t2[v] + t4[v]; - Fw[i][4][v] = t2[v] - t4[v]; - Fw[i][5][v] = T[i][3][v]; - } - } -} - -void trans_O_3x3_4x4_wu(float Mw[6][6][16][16], float M[3][3][16][16]) -{ - float T[3][6][16]; - float t0[16]; - float t1[16]; - float t2[16]; - float M_[3][16]; - - for (int j = 0; j < 16; j++) { -pragma_unroll - for (int i = 0; i < 6; i++) { - PRAGMA_OMP_SIMD() - for (int v = 0; v < 16; v++) { - t0[v] = Mw[1][i][j][v] + Mw[2][i][j][v]; - t1[v] = Mw[3][i][j][v] + Mw[4][i][j][v]; - t2[v] = t1[v] * 2.25f + Mw[5][i][j][v]; - - T[0][i][v] = Mw[0][i][j][v] + t0[v] + t1[v]; - T[1][i][v] = 0.625f * (Mw[1][i][j][v] - Mw[2][i][j][v]) + - 1.5f * (Mw[3][i][j][v] - Mw[4][i][j][v]); - T[2][i][v] = t0[v] * 0.390625f + t2[v]; - } - } -pragma_unroll - for (int i = 0; i < 3; i++) { - PRAGMA_OMP_SIMD() - for (int v = 0; v < 16; v++) { - t0[v] = T[i][1][v] + T[i][2][v]; - t1[v] = T[i][3][v] + T[i][4][v]; - t2[v] = t1[v] * 2.25f + T[i][5][v]; - - M_[0][v] = T[i][0][v] + t0[v] + t1[v]; - M_[1][v] = 0.625f * (T[i][1][v] - T[i][2][v]) + - 1.5f * (T[i][3][v] - T[i][4][v]); - M_[2][v] = t0[v] * 
0.390625f + t2[v]; - } - -pragma_unroll - for (int k = 0; k < 3; k++) { - PRAGMA_OMP_SIMD() - for (int v = 0; v < 16; v++) { - M[i][k][j][v] = M_[k][v]; - } - } - } - } -} - -template -void input_transform_data(int image, const jit_conv_winograd_conf_t &jcp, - float *inp, float *tinp, bool streamout = true) -{ - const int inpw = is_fwd ? jcp.iw : jcp.ow; - const int inph = is_fwd ? jcp.ih : jcp.oh; - const int l_pad = is_fwd ? jcp.l_pad : jcp.iw + jcp.r_pad - jcp.ow; - const int t_pad = is_fwd ? jcp.t_pad : jcp.ih + jcp.t_pad - jcp.oh; - const int wp_max = inpw + l_pad; - const int hp_max = inph + t_pad; - float Iw[alpha][alpha][simd_w]; - float I[alpha][alpha][simd_w]; - - array_offset_calculator input(inp, - jcp.mb, jcp.dimK/simd_w, inph, inpw, - simd_w); - array_offset_calculator output(tinp, - jcp.dimN_nb_block, alpha, alpha, - jcp.dimN_block, jcp.dimK_nb_block, jcp.dimK_block, - jcp.dimN_reg_block, jcp.dimK_reg_block); - - int tile_base_index = image * jcp.itiles * jcp.jtiles; - int tile_block_ur = tile_base_index % jcp.tile_block_ur; - int nb_tile_block_ur = - (tile_base_index / jcp.tile_block_ur) % jcp.nb_tile_block_ur; - int tile_block = - (tile_base_index / jcp.tile_block_ur) / jcp.nb_tile_block_ur; - - for (int tj = 0; tj < jcp.jtiles; tj++) { - for (int ti = 0; ti < jcp.itiles; ti++) { - for (int j = 0; j < alpha; j++) { - int ydim = tj * tile_size + j; - if ((t_pad <= ydim) && (ydim < hp_max)) { - float *pinp_j = inp + (ydim - t_pad) * inpw * 16 ; - for (int i = 0; i < alpha; i++) { - int xdim = ti * tile_size + i; - if ((l_pad <= xdim) && (xdim < wp_max)) { - float *pinp_i = pinp_j + (xdim - l_pad) * 16; - load_ps(I[j][i], pinp_i); - } else { - PRAGMA_OMP_SIMD() - for (int v = 0; v < simd_w; v++) { - I[j][i][v] = 0.0f; - } - } - } - } else { - for (int i = 0; i < alpha; i++) { - PRAGMA_OMP_SIMD() - for (int v = 0; v < simd_w; v++) { - I[j][i][v] = 0.0f; - } - } - } - } - - trans_I_4x4_3x3(Iw, I); - - for (int j = 0; j < alpha; j++) { - for (int i = 0; i < alpha; i++) { - store_output(&(output(tile_block, j, i, - nb_tile_block_ur, 0, 0, - tile_block_ur, 0)), - Iw[j][i], streamout); - } - } - tile_block_ur++; - if (tile_block_ur >= jcp.tile_block_ur) { - tile_block_ur = 0; - nb_tile_block_ur++; - } - if (nb_tile_block_ur >= jcp.nb_tile_block_ur) { - nb_tile_block_ur = 0; - tile_block++; - } - } - } -} - -template -void weight_transform_data(const jit_conv_winograd_conf_t &jcp, - float *wp, float *twp) -{ - const int kh = 3; - const int kw = 3; - array_offset_calculator input(wp, - jcp.oc/jcp.oc_simd_block, - jcp.ic/jcp.ic_simd_block, - jcp.kh, jcp.kw, - simd_w, simd_w); - array_offset_calculator output(twp, - jcp.dimM_nb_block, - alpha, alpha, - jcp.dimK_nb_block, - jcp.dimM_block, jcp.dimK_block, - simd_w, simd_w); - float Fw[alpha][alpha][simd_w][simd_w]; - float F[kh][kw][simd_w][simd_w]; - - for (int j = 0; j < kh; j++) { - for (int i = 0; i < kw; i++) { - for (int v1 = 0; v1 < simd_w; v1++) { - float *base_inp = is_fwd - ? 
&(input(0, 0, j, i, v1, 0)) - : &(input(0, 0, 2 - j, 2 - i, v1, 0)); - PRAGMA_OMP_SIMD() - for (int v2 = 0; v2 < simd_w; v2++) { - if (is_fwd) - F[j][i][v1][v2] = *(base_inp + v2); - else - F[j][i][v2][v1] = *(base_inp + v2); - } - } - } - } - - trans_W_4x4_3x3(Fw, F); - - for (int j = 0; j < alpha; j++) { - for (int i = 0; i < alpha; i++) { - for (int v1 = 0; v1 < simd_w; v1++) { - PRAGMA_OMP_SIMD() - for (int v2 = 0; v2 < simd_w; v2++) { - output(0, j, i, 0, 0, 0, v1, v2) = Fw[j][i][v1][v2]; - } - } - } - } -} - -template -void output_transform_data(int image, const jit_conv_winograd_conf_t &jcp, - const post_ops_t &p_ops, float *toutp, float *pout_b, float *bias, - bool streamout = true) { - float Ow[alpha][alpha][simd_w]; - float O[tile_size][tile_size][simd_w]; - int outw = is_fwd ? jcp.ow : jcp.iw; - int outh = is_fwd ? jcp.oh : jcp.ih; - - /* Prepare for PostOps */ - bool with_relu_postsum = p_ops.find(primitive_kind::eltwise, 1) != -1; - - array_offset_calculator input(toutp, - jcp.dimN_nb_block, jcp.dimM_nb_block, - alpha, alpha, - jcp.dimN_block, jcp.dimM_block, - jcp.dimN_reg_block, jcp.dimM_simd_block); - - int tile_base_index = image * jcp.itiles * jcp.jtiles; - int tile_block_ur = tile_base_index % jcp.tile_block_ur; - int nb_tile_block_ur = - (tile_base_index / jcp.tile_block_ur) % jcp.nb_tile_block_ur; - int tile_block = - (tile_base_index / jcp.tile_block_ur) / jcp.nb_tile_block_ur; - - for (int tj = 0; tj < jcp.jtiles; tj++) { - for (int ti = 0; ti < jcp.itiles; ti++) { - for (int j = 0; j < alpha; j++) { - for (int i = 0; i < alpha; i++) { - PRAGMA_OMP_SIMD() - for (int v = 0; v < simd_w; v++) { - Ow[j][i][v] = input(tile_block, 0, - j, i, - nb_tile_block_ur, 0, - tile_block_ur, v); - } - } - } - - trans_O_4x4_3x3(Ow, O); - - for (int j = 0; j < tile_size; j++) { - int ydim = tj * tile_size + j; - if (ydim < outh) { - float *pout_j = pout_b + ydim * outw * simd_w; - for (int i = 0; i < tile_size; i++) { - int xdim = ti * tile_size + i; - if (xdim < outw) { - float *pout_i = pout_j + xdim * simd_w; - if (is_fwd) { - PRAGMA_OMP_SIMD() - for (int v = 0; v < simd_w; v++) { - O[j][i][v] += with_bias ? bias[v] : 0.f; - O[j][i][v] = true - && with_relu_presum && O[j][i][v] < 0.f - ? 
O[j][i][v] - * jcp.eltwise.alpha - : O[j][i][v]; - } - } - if (with_sum) - accum_output(pout_i, O[j][i], streamout, - with_relu_postsum); - else - store_output(pout_i, O[j][i], streamout); - } - } - } - } - tile_block_ur++; - if (tile_block_ur >= jcp.tile_block_ur) { - tile_block_ur = 0; - nb_tile_block_ur++; - } - if (nb_tile_block_ur >= jcp.nb_tile_block_ur) { - nb_tile_block_ur = 0; - tile_block++; - } - } - } -} - -template -void diff_src_transform_bwd_weights(int image, jit_conv_winograd_conf_t conv, - float *inp, float *tinp, float *Iw_temp, - void (*transpose_4fma_ker)(float *, float *)) -{ - - const int ifwp = conv.iw + conv.l_pad; - const int ifhp = conv.ih + conv.t_pad; - float I[alpha][alpha][simd_w]; - float Iw[alpha][alpha][simd_w]; - - array_offset_calculator Iw_trans_temp(Iw_temp, - alpha, alpha, conv.tile_4fma, simd_w); - array_offset_calculator input(inp, - conv.mb, conv.ic/simd_w, conv.ih, conv.iw, simd_w); - array_offset_calculator output(tinp, - conv.nb_ic, alpha, alpha, - conv.tile_block, conv.ic_block, - conv.nb_tile_block_ur, conv.tile_block_ur, - conv.ic_simd_block * conv.tile_4fma); - - int tile_base_index = - image * (conv.itiles * conv.jtiles + conv.tile_4fma_padding); - int tile_4fma = 0; - int tile_block_ur = (tile_base_index / conv.tile_4fma) % conv.tile_block_ur; - int nb_tile_block_ur = - (tile_base_index / conv.tile_4fma / conv.tile_block_ur) - % conv.nb_tile_block_ur; - int tile_block = (tile_base_index / conv.tile_4fma / conv.tile_block_ur) - / conv.nb_tile_block_ur; - - for (int tj = 0; tj < conv.jtiles; tj++) { - for (int ti = 0; ti < conv.itiles; ti++) { - for (int j = 0; j < alpha; j++) { - int ydim = tj * tile_size + j; - if ((conv.t_pad <= ydim) && ydim < ifhp) { - for (int i = 0; i < alpha; i++) { - int xdim = ti * tile_size + i; - if ((conv.l_pad <= xdim) && xdim < ifwp) { - PRAGMA_OMP_SIMD() - for (int v = 0; v < simd_w; v++) { - I[j][i][v] = input(0, 0, - ydim - conv.t_pad, - xdim - conv.l_pad, v); - } - } else { - PRAGMA_OMP_SIMD() - for (int v = 0; v < simd_w; v++) { - I[j][i][v] = 0.0f; - } - } - } - } else { - for (int i = 0; i < alpha; i++) { - PRAGMA_OMP_SIMD() - for (int v = 0; v < simd_w; v++) { - I[j][i][v] = 0.0f; - } - } - } - } - trans_I_4x4_3x3(Iw, I); - - if (ver_4fma) { - for (int j = 0; j < alpha; j++) { - for (int i = 0; i < alpha; i++) { - float *Iw_temp_base = &(Iw_trans_temp(j, i, - tile_4fma, 0)); - PRAGMA_OMP_SIMD() - for (int v = 0; v < simd_w; v++) { - Iw_temp_base[v] = Iw[j][i][v]; - } - } - } - tile_4fma++; - if (tile_4fma == conv.tile_4fma) { - float *outp = &(output(0, 0, 0, - tile_block, 0, - nb_tile_block_ur, tile_block_ur, 0)); - transpose_4fma_ker(outp, (float *)Iw_temp); - tile_4fma = 0; - tile_block_ur++; - } - } else { - for (int j = 0; j < alpha; j++) { - for (int i = 0; i < alpha; i++) { - store_output(&(output(0, j, i, - tile_block, 0, - nb_tile_block_ur, tile_block_ur, 0)), - Iw[j][i], true); - } - } - tile_block_ur++; - } - - if (tile_block_ur == conv.tile_block_ur) { - tile_block_ur = 0; - ++nb_tile_block_ur; - } - if (nb_tile_block_ur == conv.nb_tile_block_ur) { - nb_tile_block_ur = 0; - tile_block++; - } - } - } - - if (ver_4fma && tile_4fma < conv.tile_4fma && conv.tile_4fma_padding != 0) { - - for (int j = 0; j < alpha; j++) { - for (int i = 0; i < alpha; i++) { - for (int tb = tile_4fma; tb < conv.tile_4fma; tb++) { - float *Iw_temp_base = &(Iw_trans_temp(j, i, tb, 0)); - PRAGMA_OMP_SIMD() - for (int v = 0; v < simd_w; v++) { - Iw_temp_base[v] = 0; - } - } - } - } - float *outp = &(output(0, 0, 0, - 
tile_block, 0, - nb_tile_block_ur, tile_block_ur, 0)); - transpose_4fma_ker(outp, (float *)Iw_temp); - } -} - -template -void diff_dst_transform_bwd_weights(int image, jit_conv_winograd_conf_t conv, - float *inp, float *tinp, float *dbias) -{ - - const int total_tiles = conv.itiles * conv.jtiles + conv.tile_4fma_padding; - float I[alpha][alpha][simd_w]; - float Iw[alpha][alpha][simd_w]; - - array_offset_calculator input(inp, - conv.mb, conv.oc/simd_w, conv.oh, conv.ow, conv.oc_simd_block); - array_offset_calculator output(tinp, - conv.nb_oc, alpha, alpha, - conv.tile_block, conv.oc_block, - conv.nb_tile_block_ur, - conv.tile_block_ur * conv.tile_4fma, conv.oc_simd_block); - - int tile_base_index = image * total_tiles; - int tile_block_ur = tile_base_index % (conv.tile_block_ur * conv.tile_4fma); - int nb_tile_block_ur = - (tile_base_index / conv.tile_block_ur / conv.tile_4fma) - % conv.nb_tile_block_ur; - int tile_block = (tile_base_index / conv.tile_block_ur / conv.tile_4fma) - / conv.nb_tile_block_ur; - - for (int tj = 0; tj < conv.jtiles; tj++) { - for (int ti = 0; ti < conv.itiles; ti++) { - for (int j = 0; j < alpha; j++) { - int ydim = tj * tile_size + j; - if (ydim < conv.oh) { - for (int i = 0; i < alpha; i++) { - int xdim = ti * tile_size + i; - if (xdim < conv.ow) { - float *input_base = &(input(0, 0, ydim, xdim, 0)); - - PRAGMA_OMP_SIMD() - for (int v = 0; v < simd_w; v++) { - I[j][i][v] = input_base[v]; - } - if (with_bias && j < tile_size && i < tile_size) { - PRAGMA_OMP_SIMD() - for (int v = 0; v < simd_w; v++) { - dbias[v] += input_base[v]; - } - } - } else { - PRAGMA_OMP_SIMD() - for (int v = 0; v < simd_w; v++) { - I[j][i][v] = 0.0f; - } - } - } - } else { - for (int i = 0; i < alpha; i++) { - PRAGMA_OMP_SIMD() - for (int v = 0; v < simd_w; v++) { - I[j][i][v] = 0.0f; - } - } - } - } - - trans_W_3x3_4x4_wu(Iw, I); - - for (int j = 0; j < alpha; j++) { - for (int i = 0; i < alpha; i++) { - store_output(&(output(0, j, i, - tile_block, 0, - nb_tile_block_ur, - tile_block_ur, 0)), - Iw[j][i], true); - } - } - tile_block_ur++; - if (tile_block_ur >= conv.tile_block_ur * conv.tile_4fma) { - tile_block_ur = 0; - nb_tile_block_ur++; - } - if (nb_tile_block_ur >= conv.nb_tile_block_ur) { - nb_tile_block_ur = 0; - tile_block++; - } - } - } -} - -void diff_weights_transform_bwd_weights(jit_conv_winograd_conf_t conv, - float *wp, float *twp) -{ - const int kh = 3; - const int kw = 3; - float Fw[alpha][alpha][simd_w][simd_w]; - float F[kh][kw][simd_w][simd_w]; - - array_offset_calculator input(twp, - conv.nb_ic, conv.nb_oc, - alpha, alpha, - conv.oc_block, conv.ic_block, - conv.ic_simd_block, conv.oc_simd_block); - array_offset_calculator output(wp, - conv.oc/simd_w, conv.ic/simd_w, - conv.kh, conv.kw, - conv.ic_simd_block, conv.oc_simd_block); - - for (int j = 0; j < alpha; j++) { - for (int i = 0; i < alpha; i++) { - for (int v = 0; v < conv.ic_simd_block; v++) { - PRAGMA_OMP_SIMD() - for (int k = 0; k < conv.oc_simd_block; k++) { - Fw[j][i][v][k] = input(0, 0, j, i, 0, 0, v, k); - } - } - } - } - - trans_O_3x3_4x4_wu(Fw, F); - - for (int j = 0; j < kh; j++) { - for (int i = 0; i < kw; i++) { - for (int v = 0; v < conv.ic_simd_block; v++) { - store_output(&(output(0, 0, j, i, v, 0)), - F[j][i][v], true); - } - } - } -} - -template -void _jit_avx512_common_convolution_winograd_t::_execute_data_W_S_G_D( - float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr, - const memory_tracking::grantor_t &scratchpad) const { - const auto &jcp = kernel_->jcp; - const auto &p_ops = 
attr_->post_ops_; - - const int inph = is_fwd ? jcp.ih : jcp.oh; - const int inpw = is_fwd ? jcp.iw : jcp.ow; - const int outh = is_fwd ? jcp.oh : jcp.ih; - const int outw = is_fwd ? jcp.ow : jcp.iw; - - /* Note that jcp.with_eltwise is true for both fused conv+relu primitive - * and conv primitive with PostOps with relu before sum - * (PostOps relu after sum is handled later) */ - auto output_transform = jcp.with_bias - ? (jcp.with_eltwise - ? (jcp.with_sum - ? output_transform_data - : output_transform_data) - : (jcp.with_sum - ? output_transform_data - : output_transform_data)) - : (jcp.with_eltwise - ? (jcp.with_sum - ? output_transform_data - : output_transform_data) - : (jcp.with_sum - ? output_transform_data - : output_transform_data)); - - /* Notation: - FWD: dimM:oc, dimN:ntiles, dimK:ic, - BWD: dimM:ic, dimN:ntiles, dimK:oc, - FWD/BWD: V: src/diff_dst transform, U:weight transform, - M:dst/diff_src transform */ - array_offset_calculator input(inp_ptr, - jcp.mb, jcp.dimK/jcp.dimK_reg_block, inph, inpw, - jcp.dimK_reg_block); - array_offset_calculator output(out_ptr, - jcp.mb, jcp.dimM/jcp.dimM_simd_block, outh, outw, - jcp.dimM_simd_block); - array_offset_calculator weights(wei_ptr, - jcp.oc/jcp.oc_simd_block, jcp.ic/jcp.ic_simd_block, jcp.kh, jcp.kw, - jcp.ic_simd_block, jcp.oc_simd_block); - array_offset_calculator bias(bias_ptr, - jcp.dimM/jcp.dimM_simd_block, jcp.dimM_simd_block); - - array_offset_calculator M(is_fwd - ? scratchpad.template get(key_wino_M) - : scratchpad.template get(key_wino_V), - jcp.dimN_nb_block, jcp.dimM_nb_block, - alpha, alpha, - jcp.dimN_block, jcp.dimM_block, - jcp.dimN_reg_block, jcp.dimM_simd_block); - array_offset_calculator U( - scratchpad.template get(key_wino_U), - jcp.dimM_nb_block, - alpha, alpha, - jcp.dimK_nb_block, - jcp.dimM_block, jcp.dimK_block, - jcp.dimK_reg_block, jcp.dimM_simd_block); - array_offset_calculator V(is_fwd - ? scratchpad.template get(key_wino_V) - : scratchpad.template get(key_wino_M), - jcp.dimN_nb_block, alpha, alpha, - jcp.dimN_block, jcp.dimK_nb_block, - jcp.dimK_block, jcp.dimN_reg_block, jcp.dimK_reg_block); - - bool V_streamout = jcp.dimN * jcp.dimK * alpha * alpha * sizeof(float) - > 2 * LLC_cache_size ? true : false; - - const bool output_is_aligned = ((size_t)out_ptr & (64 - 1)) == 0; - - const bool wants_padded_bias = jcp.with_bias - && jcp.oc_without_padding != jcp.oc; - float last_slice_bias[simd_w] = {0}; - if (wants_padded_bias) { - for (int oc = 0; oc < jcp.oc_without_padding % jcp.oc_simd_block; ++oc) - last_slice_bias[oc] = bias(jcp.dimM / jcp.dimM_simd_block - 1, oc); - } - - { - parallel_nd(jcp.mb, jcp.dimK_nb_block, jcp.dimK_block, - [&](int img, int K_blk1, int K_blk2) { - input_transform_data(img, jcp, - &(input(img, K_blk1 * jcp.dimK_block + K_blk2, 0, 0, 0)), - &(V(0, 0, 0, 0, K_blk1, K_blk2, 0, 0)), V_streamout); - }); - - parallel_nd(jcp.nb_oc, jcp.nb_ic, jcp.oc_block, jcp.ic_block, - [&](int ofm1, int ifm1, int ofm2, int ifm2) { - float *U_base_ptr = is_fwd - ? 
&(U(ofm1, 0, 0, ifm1, ofm2, ifm2, 0, 0)) - : &(U(ifm1, 0, 0, ofm1, ifm2, ofm2, 0, 0)); - weight_transform_data(jcp, - &(weights(ofm1 * jcp.oc_block + ofm2, - ifm1 * jcp.ic_block + ifm2, 0, 0, 0, 0)), U_base_ptr); - }); - - parallel_nd(jcp.dimN_nb_block, alpha, alpha, jcp.dimM_nb_block, jcp.dimN_block, - [&](int N_blk1, int oj, int oi, int M_blk1, int N_blk2) { - - kernel_->gemm_loop_ker_first_iter( - (float *)&(M(N_blk1, M_blk1, oj, oi, - N_blk2, 0, 0, 0)), - (const float *)&(U(M_blk1, oj, oi, - 0, 0, 0, 0, 0)), - (const float *)&(V(N_blk1, oj, oi, - N_blk2, 0, 0, 0, 0))); - for (int K_blk1 = 1; K_blk1 < jcp.dimK_nb_block; K_blk1++) { - kernel_->gemm_loop_ker( - (float *)&(M(N_blk1, M_blk1, oj, oi, - N_blk2, 0, 0, 0)), - (const float *)&(U(M_blk1, oj, oi, - K_blk1, 0, 0, 0, 0)), - (const float *)&(V(N_blk1, oj, oi, - N_blk2, K_blk1, - 0, 0, 0))); - } - - }); - - parallel_nd(jcp.mb, jcp.dimM_nb_block, jcp.dimM_block, - [&](int img, int M_blk1, int M_blk2) { - - const int M_blk = M_blk1 * jcp.dimM_block + M_blk2; - - float *bias_ptr = wants_padded_bias - && M_blk == jcp.dimM / jcp.dimM_simd_block - 1 - ? last_slice_bias : &bias(M_blk, 0); - - output_transform(img, jcp, p_ops, - &(M(0, M_blk1, 0, 0, 0, M_blk2, 0, 0)), - &(output(img, M_blk, 0, 0, 0)), - bias_ptr, output_is_aligned); - - }); - - } -} - -template struct _jit_avx512_common_convolution_winograd_t; -template struct _jit_avx512_common_convolution_winograd_t; - -void jit_avx512_common_convolution_winograd_bwd_weights_t:: -_maybe_execute_diff_bias_copy(float *diff_bias, - const memory_tracking::grantor_t &scratchpad) const { - if (pd()->wants_padded_bias()) { - auto padded_bias = scratchpad.get(key_conv_padded_bias); - for (int oc = 0; oc < pd()->jcp_.oc_without_padding; ++oc) - diff_bias[oc] = padded_bias[oc]; - } -} - -void jit_avx512_common_convolution_winograd_bwd_weights_t:: -_execute_backward_weights_S_D_G_W(const exec_ctx_t &ctx, - const memory_tracking::grantor_t &scratchpad) const { - auto ptr_diff_dst = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST); - auto ptr_src = CTX_IN_MEM(const float *, MKLDNN_ARG_SRC); - auto ptr_diff_weights = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_WEIGHTS); - auto ptr_diff_bias = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_BIAS); - - const auto &jcp = kernel_->jcp; - const int nthreads = jcp.nthr; - - auto diff_src_transform_bwd_weights_ver = jcp.ver == ver_4fma ? - diff_src_transform_bwd_weights : - diff_src_transform_bwd_weights; - auto diff_dst_transform_bwd_weights_ver = jcp.with_bias - ? diff_dst_transform_bwd_weights - : diff_dst_transform_bwd_weights; - - array_offset_calculator src((float *)ptr_src, - jcp.mb, jcp.ic/simd_w, jcp.ih, jcp.iw, simd_w); - array_offset_calculator diff_dst((float *)ptr_diff_dst, - jcp.mb, jcp.oc/simd_w, jcp.oh, jcp.ow, simd_w); - array_offset_calculator diff_weights(ptr_diff_weights, - jcp.oc/simd_w, jcp.ic/simd_w, jcp.kh, jcp.kw, simd_w, simd_w); - array_offset_calculator diff_bias(pd()->wants_padded_bias() - ? 
scratchpad.get(key_conv_padded_bias) : ptr_diff_bias, - jcp.oc/simd_w, simd_w); - - array_offset_calculator U( - scratchpad.get(key_wino_U), - jcp.nb_ic, jcp.nb_oc, - alpha, alpha, - jcp.oc_block, jcp.ic_block, - jcp.ic_simd_block, jcp.oc_simd_block); - - array_offset_calculator M( - scratchpad.get(key_wino_M), - jcp.nb_oc, alpha, alpha, - jcp.tile_block, jcp.oc_block, - jcp.nb_tile_block_ur, jcp.tile_block_ur * jcp.tile_4fma, - jcp.oc_simd_block); - array_offset_calculator V( - scratchpad.get(key_wino_V), - jcp.nb_ic, alpha, alpha, - jcp.tile_block, jcp.ic_block, - jcp.nb_tile_block_ur, jcp.tile_block_ur, - jcp.ic_simd_block * jcp.tile_4fma); - - const int trans_buffer_size = alpha * alpha * jcp.tile_4fma - * jcp.ic_simd_block; - array_offset_calculator trans_buffer( - scratchpad.get(key_conv_tr_src), - nthreads, - trans_buffer_size); - - array_offset_calculator diff_bias_prv( - scratchpad.get(key_conv_bia_reduction), - nthreads, - jcp.oc); - -PRAGMA_OMP(parallel num_threads(nthreads)) - { - if (jcp.with_bias) { - parallel_nd_in_omp(nthreads, jcp.oc, [&](int ithr, int ofm) { - diff_bias_prv(ithr, ofm) = 0.0f; - }); - -PRAGMA_OMP(for nowait) - for (int bofm = 0; bofm < jcp.oc / simd_w; bofm++) { - PRAGMA_OMP_SIMD() - for (int v = 0; v < simd_w; v++) - diff_bias(bofm, v) = 0.0f; - } - } - - const int ithread = mkldnn_get_thread_num(); - - parallel_nd_in_omp(jcp.mb, jcp.nb_ic, jcp.ic_block, - [&](int img, int ifm1, int ifm2) { - float *transb = jcp.ver == ver_4fma - ? &(trans_buffer(ithread, 0)) - : NULL; - diff_src_transform_bwd_weights_ver(img, jcp, - &(src(img, ifm1 * jcp.ic_block + ifm2, - 0, 0, 0)), - &(V(ifm1, 0, 0, 0, ifm2, 0, 0, 0)), - transb, - kernel_->transpose_4fma_ker); - }); - - parallel_nd_in_omp(jcp.mb, jcp.nb_oc, jcp.oc_block, - [&](int img, int ofm1, int ofm2) { - float *dbias = jcp.with_bias - ? 
&(diff_bias_prv(ithread, - simd_w * (ofm1 * jcp.oc_block + ofm2))) - : NULL; - diff_dst_transform_bwd_weights_ver(img, jcp, - &(diff_dst(img, ofm1 * jcp.oc_block + ofm2, - 0, 0, 0)), - &(M(ofm1, 0, 0, 0, ofm2, 0, 0, 0)), - dbias); - }); - -PRAGMA_OMP(barrier) - - for (int ifm1 = 0; ifm1 < jcp.nb_ic; ifm1++) { - parallel_nd_in_omp(alpha, alpha, jcp.nb_oc, - [&](int oj, int oi, int ofm1) { - kernel_->gemm_loop_ker_first_iter( - (float *)&(U(ifm1, ofm1, oj, oi, - 0, 0, 0, 0)), - (const float *)&(M(ofm1, oj, oi, - 0, 0, 0, 0, 0)), - (const float *)&(V(ifm1, oj, oi, - 0, 0, 0, 0, 0))); - for (int tile_block = 1; tile_block < jcp.tile_block; - tile_block++) { - kernel_->gemm_loop_ker((float *)&(U(ifm1, ofm1, - oj, oi, - 0, 0, 0, 0)), - (const float *)&(M(ofm1, oj, oi, tile_block, - 0, 0, 0, 0)), - (const float *)&(V(ifm1, oj, oi, tile_block, - 0, 0, 0, 0))); - } - }); - } - -PRAGMA_OMP(barrier) - - parallel_nd_in_omp(jcp.nb_ic, jcp.nb_oc, jcp.oc_block, jcp.ic_block, - [&](int ifm1, int ofm1, int ofm2, int ifm2) { - diff_weights_transform_bwd_weights(jcp, - &(diff_weights(ofm1 * jcp.oc_block + ofm2, - ifm1 * jcp.ic_block + ifm2, 0, 0, 0, 0)), - &(U(ifm1, ofm1, 0, 0, ofm2, ifm2, 0, 0))); - }); - - if (jcp.with_bias) { -PRAGMA_OMP(for) - for (int ofm1 = 0; ofm1 < jcp.oc / simd_w; ofm1++) { - for (int ithr = 0; ithr < nthreads; ithr++) { - float* base_bias_ptr = &(diff_bias(ofm1, 0)); - float* base_bias_prv_ptr = &(diff_bias_prv( - ithr * jcp.oc + ofm1 * simd_w)); - PRAGMA_OMP_SIMD() - for (int ofm2 = 0; ofm2 < simd_w; ofm2++) { - base_bias_ptr[ofm2] += base_bias_prv_ptr[ofm2]; - } - } - } - } - } - - _maybe_execute_diff_bias_copy(ptr_diff_bias, scratchpad); -} - -} -} -} -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.hpp deleted file mode 100644 index 6c76f37c7..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.hpp +++ /dev/null @@ -1,318 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
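The Winograd path removed above follows the W_S_G_D schedule: transform the weights into U and the source tiles into V, run one GEMM per (alpha x alpha) transform point, then inverse-transform the result tiles in M back to the destination layout. A rough sketch of the shape bookkeeping for F(4x4, 3x3), using example sizes rather than anything taken from the removed files:

#include <cstdio>

int main() {
    const int alpha = 6, tile_size = 4;                    // F(4x4, 3x3)
    const int mb = 1, oh = 64, ow = 64, oc = 64, ic = 32;  // example shape
    const int jtiles = (oh + tile_size - 1) / tile_size;
    const int itiles = (ow + tile_size - 1) / tile_size;
    const int ntiles = mb * jtiles * itiles;               // dimN for the forward pass
    // Per the notation comment in _execute_data_W_S_G_D: forward GEMMs are
    // dimM (= oc) x dimN (= ntiles) x dimK (= ic), one per transform point;
    // the backward-data pass swaps the roles of ic and oc.
    std::printf("%d x %d transform points, each an %d x %d x %d GEMM\n",
            alpha, alpha, oc, ntiles, ic);
    return 0;
}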
-*******************************************************************************/ - -#ifndef CPU_JIT_AVX512_COMMON_CONVOLUTION_WINOGRAD_HPP -#define CPU_JIT_AVX512_COMMON_CONVOLUTION_WINOGRAD_HPP - -#include "c_types_map.hpp" -#include "memory_tracking.hpp" -#include "mkldnn_thread.hpp" - -#include "cpu_convolution_pd.hpp" -#include "cpu_primitive.hpp" - -#include "jit_avx512_common_conv_winograd_kernel_f32.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -namespace winograd_avx512_common { -inline void init_scratchpad(memory_tracking::registrar_t &scratchpad, - const jit_conv_winograd_conf_t &jcp) { - using namespace memory_tracking::names; - - size_t U_sz = (size_t)alpha * alpha * jcp.ic * jcp.oc; - size_t V_sz = (size_t)alpha * alpha * jcp.mb * jcp.ic - * (jcp.itiles * jcp.jtiles + jcp.tile_4fma_padding); - size_t M_sz = (size_t)alpha * alpha * jcp.mb * jcp.oc - * (jcp.itiles * jcp.jtiles + jcp.tile_4fma_padding); - - scratchpad.book(key_wino_U, sizeof(float) * U_sz, PAGE_2M); - scratchpad.book(key_wino_V, sizeof(float) * V_sz, PAGE_2M); - scratchpad.book(key_wino_M, sizeof(float) * M_sz, PAGE_2M); - - if (jcp.sched_policy == WSCHED_WEI_S_D_G_W) { - const int nthr = mkldnn_get_max_threads(); - - size_t tr_src_sz = jcp.ver != ver_4fma ? 0 : (size_t)nthr - * alpha * alpha * jcp.tile_4fma * jcp.ic_simd_block; - scratchpad.book(key_conv_tr_src, sizeof(float) * tr_src_sz, PAGE_2M); - - size_t br_sz = jcp.with_bias ? nthr * jcp.oc : 0; - scratchpad.book(key_conv_bia_reduction, sizeof(float) * br_sz, PAGE_2M); - - size_t padded_bias_sz = - jcp.with_bias && jcp.oc_without_padding != jcp.oc ? jcp.oc : 0; - scratchpad.book(key_conv_padded_bias, sizeof(float) * padded_bias_sz); - } -} -} - -template -struct _jit_avx512_common_convolution_winograd_t { - _jit_avx512_common_convolution_winograd_t( - const jit_conv_winograd_conf_t &jcp, const primitive_attr_t *attr) - : kernel_(nullptr), attr_(attr) { - kernel_ = new _jit_avx512_common_conv_winograd_data_kernel_f32(jcp); - } - - ~_jit_avx512_common_convolution_winograd_t() { delete kernel_; } - - protected: - void _execute_data_W_S_G_D(float *inp_ptr, float *out_ptr, - float *wei_ptr, float *bias_ptr, - const memory_tracking::grantor_t &scratchpad) const; - _jit_avx512_common_conv_winograd_data_kernel_f32 *kernel_; - const primitive_attr_t *attr_; -}; - -struct jit_avx512_common_convolution_winograd_fwd_t - : _jit_avx512_common_convolution_winograd_t - , public cpu_primitive_t - { - struct pd_t : public cpu_convolution_fwd_pd_t { - pd_t(engine_t *engine, const convolution_desc_t *adesc, - const primitive_attr_t *attr, - const typename pd_t::base_class *hint_fwd_pd) - : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) - , jcp_() {} - - DECLARE_COMMON_PD_T( - JIT_IMPL_NAME_HELPER("jit_wino:", avx512_common, ""), - jit_avx512_common_convolution_winograd_fwd_t); - - status_t init() { - bool ok = true - && is_fwd() - && utils::one_of(desc()->alg_kind, - alg_kind::convolution_auto, - alg_kind::convolution_winograd) - && expect_data_types(data_type::f32, data_type::f32, - data_type::f32, data_type::f32, data_type::f32) - && !has_zero_dim_memory() - && set_default_formats(); - if (!ok) return status::unimplemented; - - status_t status = jit_avx512_common_conv_winograd_fwd_kernel_f32:: - init_conf(jcp_, *desc(), *src_md(), *weights_md(), *dst_md(), - *attr()); - if (status != status::success) return status; - set_default_alg_kind(alg_kind::convolution_winograd); - - auto scratchpad = scratchpad_registry().registrar(); - 
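            // A note on what the call below books (sizes taken from
            // winograd_avx512_common::init_scratchpad() earlier in this header):
            //   key_wino_U : transformed weights, ~alpha^2 * ic * oc floats
            //   key_wino_V : transformed sources, ~alpha^2 * mb * ic * (itiles*jtiles + tile_4fma_padding) floats
            //   key_wino_M : transformed outputs, ~alpha^2 * mb * oc * (itiles*jtiles + tile_4fma_padding) floats
            // All three are requested with PAGE_2M alignment, presumably to
            // reduce TLB pressure on the large transform buffers.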
winograd_avx512_common::init_scratchpad(scratchpad, jcp_); - - return status; - } - - jit_conv_winograd_conf_t jcp_; - - protected: - bool set_default_formats() { - using namespace format_tag; - auto wei_tag = with_groups() ? gOIhw16i16o : OIhw16i16o; - return set_default_formats_common(nChw16c, wei_tag, nChw16c); - } - }; - - jit_avx512_common_convolution_winograd_fwd_t(const pd_t *apd) - : _jit_avx512_common_convolution_winograd_t(apd->jcp_, apd->attr()) - , cpu_primitive_t(apd, true) {} - - ~jit_avx512_common_convolution_winograd_fwd_t(){}; - - typedef typename prec_traits::type data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override - { - auto src = CTX_IN_MEM(const float *, MKLDNN_ARG_SRC); - auto weights = CTX_IN_MEM(const float *, MKLDNN_ARG_WEIGHTS); - auto bias = CTX_IN_MEM(const float *, MKLDNN_ARG_BIAS); - auto dst = CTX_OUT_MEM(float *, MKLDNN_ARG_DST); - this->_execute_data_W_S_G_D((float *)src, dst, (float *)weights, - (float *)bias, this->scratchpad(ctx)); - return status::success; - } - -private: - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } -}; - -struct jit_avx512_common_convolution_winograd_bwd_data_t - : _jit_avx512_common_convolution_winograd_t, - public cpu_primitive_t { - struct pd_t : public cpu_convolution_bwd_data_pd_t { - pd_t(engine_t *engine, const convolution_desc_t *adesc, - const primitive_attr_t *attr, - const convolution_fwd_pd_t *hint_fwd_pd) - : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) - , jcp_() {} - - DECLARE_COMMON_PD_T( - JIT_IMPL_NAME_HELPER("jit_wino:", avx512_common, ""), - jit_avx512_common_convolution_winograd_bwd_data_t); - - status_t init() { - bool ok = true - && desc()->prop_kind == prop_kind::backward_data - && expect_data_types(data_type::f32, data_type::f32, - data_type::undef, data_type::f32, data_type::f32) - && utils::one_of(desc()->alg_kind, - alg_kind::convolution_auto, - alg_kind::convolution_winograd) - && !has_zero_dim_memory() - && set_default_formats() - && mkldnn_thr_syncable(); - if (!ok) return status::unimplemented; - - status_t status = - jit_avx512_common_conv_winograd_bwd_data_kernel_f32::init_conf( - jcp_, *desc(), *diff_src_md(), *weights_md(), - *diff_dst_md()); - if (status != status::success) return status; - set_default_alg_kind(alg_kind::convolution_winograd); - - auto scratchpad = scratchpad_registry().registrar(); - winograd_avx512_common::init_scratchpad(scratchpad, jcp_); - - return status; - } - - jit_conv_winograd_conf_t jcp_; - - protected: - bool set_default_formats() { - using namespace format_tag; - auto wei_tag = with_groups() ? 
gOIhw16i16o : OIhw16i16o; - return set_default_formats_common(nChw16c, wei_tag, nChw16c); - } - }; - - jit_avx512_common_convolution_winograd_bwd_data_t(const pd_t *apd) - : _jit_avx512_common_convolution_winograd_t(apd->jcp_, apd->attr()) - , cpu_primitive_t(apd, true) {} - - ~jit_avx512_common_convolution_winograd_bwd_data_t(){}; - - typedef typename prec_traits::type data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - auto diff_dst = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST); - auto weights = CTX_IN_MEM(const float *, MKLDNN_ARG_WEIGHTS); - auto diff_src = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_SRC); - this->_execute_data_W_S_G_D((float *)diff_dst, diff_src, - (float *)weights, nullptr, this->scratchpad(ctx)); - return status::success; - } - -private: - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } -}; - -struct jit_avx512_common_convolution_winograd_bwd_weights_t - : public cpu_primitive_t { - struct pd_t : public cpu_convolution_bwd_weights_pd_t { - pd_t(engine_t *engine, const convolution_desc_t *adesc, - const primitive_attr_t *attr, - const convolution_fwd_pd_t *hint_fwd_pd) - : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, - hint_fwd_pd) - , jcp_() {} - - DECLARE_COMMON_PD_T( - JIT_IMPL_NAME_HELPER("jit_wino:", avx512_common, ""), - jit_avx512_common_convolution_winograd_bwd_weights_t); - - status_t init() { - bool ok = true - && desc()->prop_kind == prop_kind::backward_weights - && utils::one_of(desc()->alg_kind, - alg_kind::convolution_auto, - alg_kind::convolution_winograd) - && expect_data_types(data_type::f32, data_type::f32, - data_type::f32, data_type::f32, data_type::f32) - && !has_zero_dim_memory() - && set_default_formats() - && mkldnn_thr_syncable(); - if (!ok) return status::unimplemented; - - status_t status = - jit_avx512_common_conv_winograd_bwd_weights_kernel_f32:: - init_conf(jcp_, *desc(), *src_md(), *diff_dst_md(), - *diff_weights_md()); - if (status != status::success) return status; - set_default_alg_kind(alg_kind::convolution_winograd); - - auto scratchpad = scratchpad_registry().registrar(); - winograd_avx512_common::init_scratchpad(scratchpad, jcp_); - - return status; - } - - jit_conv_winograd_conf_t jcp_; - - protected: - bool set_default_formats() { - using namespace format_tag; - auto wei_tag = with_groups() ? 
gOIhw16i16o : OIhw16i16o; - return set_default_formats_common(nChw16c, wei_tag, nChw16c); - } - }; - - jit_avx512_common_convolution_winograd_bwd_weights_t(const pd_t *apd) - : cpu_primitive_t(apd, true), kernel_(nullptr) - { - kernel_ = new jit_avx512_common_conv_winograd_bwd_weights_kernel_f32( - pd()->jcp_); - } - - ~jit_avx512_common_convolution_winograd_bwd_weights_t() - { delete kernel_; } - - typedef typename prec_traits::type data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override - { - _execute_backward_weights_S_D_G_W(ctx, scratchpad(ctx)); - return status::success; - } - -private: - void _execute_backward_weights_S_D_G_W(const exec_ctx_t &ctx, - const memory_tracking::grantor_t &scratchpad) const; - void _maybe_execute_diff_bias_copy(float *diff_bias, - const memory_tracking::grantor_t &scratchpad) const; - - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - jit_avx512_common_conv_winograd_bwd_weights_kernel_f32 *kernel_; -}; - -void trans_W_4x4_3x3(float Fw_[6][6][16][16], float F[3][3][16][16]); -void trans_O_4x4_3x3(float Mw[6][6][16], float O[4][4][16]); -void trans_W_3x3_4x4(float Fw[6][6][16], float F[4][6][16]); -void trans_O_3x3_4x4(float Mw[6][6][16][16], float M[3][3][16][16]); -void trans_I_4x4_3x3(float Iw[6][6][16], float I[6][6][16]); -void trans_W_3x3_4x4_wu(float Fw[6][6][16], float F[4][6][16]); -void trans_O_3x3_4x4_wu(float Mw[6][6][16][16], float M[3][3][16][16]); - -} -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_lrn.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_lrn.cpp deleted file mode 100644 index d4a451c02..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_lrn.cpp +++ /dev/null @@ -1,853 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#include "c_types_map.hpp" -#include "mkldnn_thread.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -#include "jit_avx512_common_lrn.hpp" - -#include "jit_generator.hpp" - -#define FWD_RBC 4 -#define BWD_RBC 3 - -#define XMM_SIZE (4*sizeof(float)) -#define ZMM_SIZE (vlen) -#define BUFFER_BLOCK (XMM_SIZE + ZMM_SIZE + XMM_SIZE) -#define BUFFER_NEXT_OFFSET (XMM_SIZE + ZMM_SIZE) -#define SRC_PREV_OFFSET (vlen - XMM_SIZE) - -#define IRB_LOOP(statement) for(int irb = 0; irb < loop_size; irb++) { \ - statement;\ -} - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace mkldnn::impl::status; -using namespace mkldnn::impl::utils; - -using namespace Xbyak; - -enum params { vsize = 16, vlen = 64}; - -typedef struct { - const float *src; - float *dst, *ws0, *ws1; -} jit_args_fwd_t; - -typedef struct { - const float *src, *diff_dst, *ws0, *ws1; - float *diff_src; -} jit_args_bwd_t; - -struct nChw16c_across { -/* version: - * -1: channels 0..15, - * 1: channels C-16 .. C-1, - * 0: other channels - * 3: channels only for this kernel(without prev and next) - */ - int H, W, version; - nChw16c_across(int h, int w, int v) : H(h), W(w), version(v) {} -}; - -struct jit_avx512_common_lrn_fwd_t::jit_avx512_common_lrn_kernel_f32: - public jit_generator { - int HW, W; - bool is_first; - bool is_last; - bool is_single; - - Reg64 src = rax; - Reg64 dst = r8; - Reg64 scratch0 = rdx; - Reg64 scratch1 = rsi; - Reg64 imm_addr64 = rbx; - - Zmm zalpha = zmm0; - Xmm xalpha = xmm0; - Zmm zk = zmm1; - Xmm xk = xmm1; - - Reg64 param = abi_param1; - Reg64 t = rsp; - Reg64 hw = r9; - - int xsrc_prev = 2; - int zsrc = 7; - int xsrc_next = 3; - int zc = 7; - - int za = 2; - int zb = 3; - int zd = 5; - int ze = 6; - int zsum = 4; - int zdst = 2; - int zbase = 3; - int zsum2 = 5; - - prop_kind_t pk; - int use_h_parallelism; - - float alpha, k; - - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_lrn_kernel_f32) - - void (*ker)(jit_args_fwd_t *); - void operator()(jit_args_fwd_t *arg) { ker(arg); } - - enum { - prf0_offt = 1*FWD_RBC, - prf2_offt = 8*FWD_RBC - }; - - inline void compute_loop(int loop_size_param) - { - // loop_size - param for IRB_LOOP macro - int loop_size = FWD_RBC; - - auto xreg = [=](int irb, int i) { - return Xmm(irb*3 + i); - }; - - auto zreg = [=](int irb, int i) { - return Zmm(irb*7 + i); - }; - - if (!is_first && !is_single) { - IRB_LOOP(mic_prefetcht0(ptr[src + (irb + prf0_offt - HW)*vlen])); - IRB_LOOP(mic_prefetcht2(ptr[src + (irb + prf2_offt - HW)*vlen])); - } - IRB_LOOP(mic_prefetcht0(EVEX_compress_addr(src, (irb + prf0_offt)*vlen))); - IRB_LOOP(mic_prefetcht2(EVEX_compress_addr(src, (irb + prf2_offt)*vlen))); - if (!is_last && !is_single) { - IRB_LOOP(mic_prefetcht0(ptr[src + (irb + prf0_offt + HW)*vlen])); - IRB_LOOP(mic_prefetcht2(ptr[src + (irb + prf2_offt + HW)*vlen])); - } - if (pk != prop_kind::forward_inference) { - IRB_LOOP(mic_prefetcht0(EVEX_compress_addr(scratch0, - (irb + prf0_offt)*vlen))); - IRB_LOOP(mic_prefetcht2(EVEX_compress_addr(scratch0, - (irb + prf2_offt)*vlen))); - } - IRB_LOOP(mic_prefetcht0(EVEX_compress_addr(dst, (irb + prf0_offt)*vlen))); - IRB_LOOP(mic_prefetcht2(EVEX_compress_addr(dst, (irb + prf2_offt)*vlen))); - if (pk != prop_kind::forward_inference) { - IRB_LOOP(mic_prefetcht0(EVEX_compress_addr(scratch1, - (irb + prf0_offt) * vlen))); - IRB_LOOP(mic_prefetcht2(EVEX_compress_addr(scratch1, - (irb + prf2_offt) * vlen))); - } - - loop_size = loop_size_param; - 
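        // A sketch of what the vector code below computes for each 16-channel
        // block (inferred from the register usage; alpha here is already
        // lrn_alpha / local_size, and pd_t::init() enforces local_size == 5
        // and lrn_beta == 0.75):
        //   base = k + alpha * (c[-2]^2 + c[-1]^2 + c[0]^2 + c[1]^2 + c[2]^2)
        //   ws0  = base^0.75              (base^3 followed by two vsqrtps)
        //   dst  = src / ws0
        //   ws1  = dst / base = src / base^1.75   (kept for the backward pass)
        // The xsrc_prev / xsrc_next XMM loads pull in the four neighbouring
        // channels that live in the adjacent nChw16c blocks.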
if (loop_size == 0) - return; - if (!is_first && !is_single) { - IRB_LOOP(vmovups(xreg(irb, xsrc_prev), - ptr[src + (irb - HW) * vlen + SRC_PREV_OFFSET])); - } - IRB_LOOP(vmovups(zreg(irb, zsrc), EVEX_compress_addr(src,irb*vlen))); - if (!is_last && !is_single) { - IRB_LOOP(vmovups(xreg(irb, xsrc_next), - ptr[src + (irb + HW) * vlen])); - } - - if (!is_first && !is_single) { - IRB_LOOP(vmovups(ptr[t + irb*BUFFER_BLOCK], - xreg(irb, xsrc_prev))); - } - IRB_LOOP(vmovups(EVEX_compress_addr(t, irb*BUFFER_BLOCK + XMM_SIZE), - zreg(irb, zsrc))); - if (!is_last && !is_single) { - IRB_LOOP(vmovups(ptr[t + irb*BUFFER_BLOCK + BUFFER_NEXT_OFFSET], - xreg(irb, xsrc_next))); - } - - IRB_LOOP(vmovups(zreg(irb, za), EVEX_compress_addr(t, irb*BUFFER_BLOCK - + XMM_SIZE - 2*sizeof(float)))); - IRB_LOOP(vmovups(zreg(irb, zb), EVEX_compress_addr(t, irb*BUFFER_BLOCK - + XMM_SIZE - sizeof(float)))); - IRB_LOOP(vmovups(zreg(irb, zd), EVEX_compress_addr(t, irb*BUFFER_BLOCK - + XMM_SIZE + sizeof(float)))); - IRB_LOOP(vmovups(zreg(irb, ze), EVEX_compress_addr(t, irb*BUFFER_BLOCK - + XMM_SIZE + 2*sizeof(float)))); - - assert(zc == zsrc); - IRB_LOOP(vmulps(zreg(irb, zsum), zreg(irb, zc), zreg(irb, zc))); - - IRB_LOOP(vfmadd231ps(zreg(irb, zsum), zreg(irb, za), zreg(irb, za))); - IRB_LOOP(vfmadd231ps(zreg(irb, zsum), zreg(irb, zb), zreg(irb, zb))); - IRB_LOOP(vfmadd231ps(zreg(irb, zsum), zreg(irb, zd), zreg(irb, zd))); - IRB_LOOP(vfmadd231ps(zreg(irb, zsum), zreg(irb, ze), zreg(irb, ze))); - - IRB_LOOP(vfmadd132ps(zreg(irb, zsum), zk, zalpha)); - - IRB_LOOP(vmovaps(zreg(irb, zbase), zreg(irb, zsum))); - - IRB_LOOP(vmulps(zreg(irb, zsum2), zreg(irb, zsum), zreg(irb, zsum))); - IRB_LOOP(vmulps(zreg(irb, zsum), zreg(irb, zsum), zreg(irb, zsum2))); - - IRB_LOOP(vsqrtps(zreg(irb, zsum), zreg(irb, zsum))); - IRB_LOOP(vsqrtps(zreg(irb, zsum), zreg(irb, zsum))); - - if (pk != prop_kind::forward_inference) { - IRB_LOOP(vmovups(EVEX_compress_addr(scratch0, irb*vlen), - zreg(irb, zsum))); - } - IRB_LOOP(vdivps(zreg(irb, zdst), zreg(irb, zsrc), zreg(irb, zsum))); - IRB_LOOP(vmovups(EVEX_compress_addr(dst, irb*vlen), zreg(irb, zdst))); - if (pk != prop_kind::forward_inference) { - /* ws1 = zdst / zbase = zsrc / (zbase^1.75) */ - IRB_LOOP(vdivps(zreg(irb, zsum), zreg(irb, zdst), zreg(irb, zbase))); - IRB_LOOP(vmovups(EVEX_compress_addr(scratch1, irb*vlen), - zreg(irb, zsum))); - } - } - - jit_avx512_common_lrn_kernel_f32( - const struct nChw16c_across &J, - prop_kind_t prop_kind, - int use_h_parallel, - float A, - float K, - void *code_ptr = nullptr, - size_t code_size = 2 * Xbyak::DEFAULT_MAX_CODE_SIZE) - : jit_generator(code_ptr, code_size) - , pk(prop_kind) - , use_h_parallelism(use_h_parallel) - , alpha(A) - , k(K) - { - this->preamble(); - - mov(src, ptr[param + 0]); - mov(dst, ptr[param + 8]); - if (pk != prop_kind::forward_inference) - { - mov(scratch0, ptr[param + 16]); - mov(scratch1, ptr[param + 24]); - } - is_first = J.version == -1 || J.version == -2; - is_last = J.version == +1 || J.version == -2; - is_single = J.version == 3; - - W = J.W; - HW = J.W*J.H; - int LSB = use_h_parallelism ? 
W : HW; - - sub(t, FWD_RBC*BUFFER_BLOCK); - mov(imm_addr64, float2int(this->alpha)); - movq(xalpha, imm_addr64); - vbroadcastss(zalpha, xalpha); - - mov(imm_addr64, float2int(this->k)); - movq(xk, imm_addr64); - vbroadcastss(zk, xk); - - if (is_first || is_single) { - vxorps(xmm2, xmm2, xmm2); - for(int irb = 0; irb < FWD_RBC; irb++) { - vmovups(ptr[t + irb*BUFFER_BLOCK], xmm2); - } - } - if (is_last || is_single) { - vxorps(xmm2, xmm2, xmm2); - for(int irb = 0; irb < FWD_RBC; irb++) { - vmovups(ptr[t + irb*BUFFER_BLOCK + BUFFER_NEXT_OFFSET], - xmm2); - } - } - - int LSREST = LSB % FWD_RBC; - int LS = LSB - LSREST; - - Label lrn_loop; - - if (LS > 0) { - mov(hw, LS); - - L(lrn_loop); - { - compute_loop(FWD_RBC); - - add(src, FWD_RBC*vlen); - add(dst, FWD_RBC*vlen); - if (pk != prop_kind::forward_inference) - { - add(scratch0, FWD_RBC*vlen); - add(scratch1, FWD_RBC*vlen); - } - - for(int irb = 0; irb < FWD_RBC; irb++) - dec(hw); - cmp(hw, 0); - jne(lrn_loop, T_NEAR); - } - } - - compute_loop(LSREST); - - add(t, FWD_RBC*BUFFER_BLOCK); - this->postamble(); - - ker = reinterpret_cast(const_cast( - this->getCode())); - } -}; - -status_t jit_avx512_common_lrn_fwd_t::pd_t::init() { - using namespace prop_kind; - using namespace alg_kind; - - const memory_desc_wrapper data_d(src_md()); - bool ok = true - && mayiuse(avx512_common) - && is_fwd() - && !has_zero_dim_memory() - && everyone_is(data_type::f32, data_d.data_type()) - && data_d.ndims() == 4 - && data_d.dims()[1] % vsize == 0 - && attr()->has_default_values(); - if (!ok) return unimplemented; - - if (desc()->prop_kind == forward_training) { - dims_t ws_dims = { MB(), C(), H(), 2*W() }; - mkldnn_memory_desc_init_by_tag(&ws_md_, 4, ws_dims, data_type::f32, - format_tag::nChw16c); - } - - bool args_ok_across = true - && desc()->alg_kind == lrn_across_channels - && desc()->local_size == 5 - && desc()->lrn_beta == 0.75 - && data_d.matches_tag(format_tag::nChw16c); - - return args_ok_across ? success : unimplemented; -} - -jit_avx512_common_lrn_fwd_t::jit_avx512_common_lrn_fwd_t(const pd_t *apd) - : cpu_primitive_t(apd) - , use_h_parallelism(0), ker_(nullptr), ker_first_(nullptr) - , ker_last_(nullptr) { - using namespace alg_kind; - const int C = pd()->C(); - const int H = pd()->H(); - const int W = pd()->W(); - const int ls = pd()->desc()->local_size; - const float alpha = pd()->desc()->lrn_alpha / ls; - const float k = pd()->desc()->lrn_k; - - auto pk = pd()->desc()->prop_kind; - - use_h_parallelism = H > 28 ? 
1 : 0; - - if (C / vsize == 1) { - ker_ = new jit_avx512_common_lrn_kernel_f32(nChw16c_across(H, W, 3), pk, - use_h_parallelism, alpha, k); - } else { - ker_ = new jit_avx512_common_lrn_kernel_f32(nChw16c_across(H, W, 0), pk, - use_h_parallelism, alpha, k); - ker_first_ = new jit_avx512_common_lrn_kernel_f32( - nChw16c_across(H, W, -1), pk, use_h_parallelism, alpha, k); - ker_last_ = new jit_avx512_common_lrn_kernel_f32( - nChw16c_across(H, W, +1), pk, use_h_parallelism, alpha, k); - } -} - -jit_avx512_common_lrn_fwd_t::~jit_avx512_common_lrn_fwd_t() -{ delete ker_; delete ker_first_; delete ker_last_; } - -void jit_avx512_common_lrn_fwd_t::execute_forward(const exec_ctx_t &ctx) const -{ - auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); - auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); - auto ws = CTX_OUT_MEM(data_t *, MKLDNN_ARG_WORKSPACE); - - const int N = pd()->MB(); - const int C = pd()->C(); - const int H = pd()->H(); - const int W = pd()->W(); - - parallel(0, [&](const int ithr, const int nthr) { - size_t start{0}, end{0}; - const int C16 = C / vsize; - const size_t work_amount = use_h_parallelism ? N*C16*H : N*C16; - - balance211(work_amount, nthr, ithr, start, end); - if (use_h_parallelism) { - int n{0}, c16{0}, h{0}; - nd_iterator_init(start, n, N, c16, C16, h, H); - for (size_t iwork = start; iwork < end; ++iwork) { - auto offset = n*C*H*W + c16*H*W*vsize - + h*W*vsize; - auto ws_offset0 = n*C*H*2*W + c16*H*2*W*vsize - + h*2*W*vsize; - auto ws_offset1 = ws_offset0 + W*vsize; - - jit_args_fwd_t args; - args.src = &src[offset]; - args.dst = &dst[offset]; - args.ws0 = &ws[ws_offset0]; - args.ws1 = &ws[ws_offset1]; - - if (C16 == 1) - (*ker_)(&args); - else if (c16 == 0) - (*ker_first_)(&args); - else if (c16 == C16 - 1) - (*ker_last_)(&args); - else - (*ker_)(&args); - nd_iterator_step(n, N, c16, C16, h, H); - } - } else { - int n{0}, c16{0}; - nd_iterator_init(start, n, N, c16, C16); - for (size_t iwork = start; iwork < end; ++iwork) { - auto offset = n*C*H*W + c16*H*W*vsize; - auto ws_offset0 = n*C*H*2*W + c16*H*2*W*vsize; - auto ws_offset1 = ws_offset0 + H*W*vsize; - - jit_args_fwd_t args; - args.src = &src[offset]; - args.dst = &dst[offset]; - args.ws0 = &ws[ws_offset0]; - args.ws1 = &ws[ws_offset1]; - - if (C16 == 1) - (*ker_)(&args); - else if (c16 == 0) - (*ker_first_)(&args); - else if (c16 == C16 - 1) - (*ker_last_)(&args); - else - (*ker_)(&args); - - nd_iterator_step(n, N, c16, C16); - } - } - }); -} - -struct jit_avx512_common_lrn_bwd_t::jit_avx512_common_lrn_kernel_f32: - public jit_generator { - int HW, W; - bool is_first; - bool is_last; - bool is_single; - - Reg64 src = rax; - Reg64 diffsrc = r8; - Reg64 diffdst = r9; - Reg64 workspace0 = rdx; - Reg64 workspace1 = rsi; - Reg64 imm_addr64 = rbx; - - Zmm znalphabeta = zmm0; - Xmm xnalphabeta = xmm0; - - Reg64 param = abi_param1; - Reg64 t = rsp; - Reg64 hw = r10; - - int xws1_prev = 1; - int xdiffdst_prev = 2; - int zws1 = 1; - - int zsrc = 1; - int zdiffdst = 5; - int zdiffsrc = 6; - - int xws1_next = 1; - int xdiffdst_next = 3; - - int za = 1; - int zb = 2; - int zd = 3; - int ze = 4; - int zws0 = 2; - - float nalphabeta; - - int use_h_parallelism; - - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_lrn_kernel_f32) - - void (*ker)(jit_args_bwd_t *); - void operator()(jit_args_bwd_t *arg) { ker(arg); } - - enum { - prf0_offt = 1*BWD_RBC, - prf2_offt = 8*BWD_RBC - }; - - inline void compute_loop(int loop_size_param, int prefetchL1, - int prefetchL2) - { - // loop_size - param for IRB_LOOP macro - int 
loop_size = loop_size_param; - - auto xreg = [=](int irb, int i) { - return Xmm(irb*6 + i); - }; - - auto zreg = [=](int irb, int i) { - return Zmm(irb*6 + i); - }; - -// ---- prefetching ------------------------------------------- - if (!is_first && !is_single) { - if (prefetchL1) - IRB_LOOP(mic_prefetcht0(ptr[workspace1 + (irb + prf0_offt - - 2 * HW) * vlen])); - if (prefetchL1) - IRB_LOOP(mic_prefetcht0(ptr[diffdst + (irb + prf0_offt - - HW) * vlen])); - } - - if (prefetchL1) - IRB_LOOP(mic_prefetcht0(ptr[src + (irb + prf0_offt)*vlen])); - if (prefetchL2) - IRB_LOOP(mic_prefetcht2(ptr[src + (irb + prf2_offt)*vlen])); - - if (prefetchL1) - IRB_LOOP(mic_prefetcht0(ptr[workspace1 + (irb + prf0_offt)*vlen])); - - if (prefetchL1) - IRB_LOOP(mic_prefetcht0(ptr[diffdst + (irb + prf0_offt)*vlen])); - - if (!is_last && !is_single) { - if (prefetchL1) - IRB_LOOP(mic_prefetcht0(ptr[workspace1 + (irb + prf0_offt - + 2 * HW) * vlen])); - if (prefetchL2) - IRB_LOOP(mic_prefetcht2(ptr[workspace1 + (irb + prf2_offt - + 2 * HW) * vlen])); - - if (prefetchL1) - IRB_LOOP(mic_prefetcht0(ptr[diffdst + (irb + prf0_offt - + HW) * vlen])); - if (prefetchL2) - IRB_LOOP(mic_prefetcht2(ptr[diffdst + (irb + prf2_offt - + HW) * vlen])); - } - if (prefetchL1) - IRB_LOOP(mic_prefetcht0(ptr[workspace0 + (irb + prf0_offt)*vlen])); - if (prefetchL2) - IRB_LOOP(mic_prefetcht2(ptr[workspace0 + (irb + prf2_offt)*vlen])); -// ----------------------------------------------------------- - - if (loop_size_param == 0) - return; - - if (!is_first && !is_single) { - IRB_LOOP(vmovups(xreg(irb, xws1_prev), ptr[workspace1 + (irb - - 2 * HW) * vlen + SRC_PREV_OFFSET])); - IRB_LOOP(vmovups(xreg(irb, xdiffdst_prev), ptr[diffdst + (irb - - HW) * vlen + SRC_PREV_OFFSET])); - IRB_LOOP(vmulps(xreg(irb, xdiffdst_prev), xreg(irb, xdiffdst_prev), - xreg(irb, xws1_prev))); - } - - IRB_LOOP(vmovups(zreg(irb, zws1), - EVEX_compress_addr(workspace1, irb*vlen))); - IRB_LOOP(vmovups(zreg(irb, zdiffdst), - EVEX_compress_addr(diffdst, irb*vlen))); - IRB_LOOP(vmulps(zreg(irb, zdiffsrc), zreg(irb, zdiffdst), - zreg(irb, zws1))); - - if (!is_last && !is_single) { - IRB_LOOP(vmovups(xreg(irb, xws1_next), ptr[workspace1 + (irb - + 2 * HW) * vlen])); - IRB_LOOP(vmovups(xreg(irb, xdiffdst_next), ptr[diffdst + (irb - + HW) * vlen])); - IRB_LOOP(vmulps(xreg(irb, xdiffdst_next), xreg(irb, xdiffdst_next), - xreg(irb, xws1_next))); - } - - if (!is_first && !is_single) { - IRB_LOOP(vmovups(ptr[t + irb*BUFFER_BLOCK], - xreg(irb, xdiffdst_prev))); - } - IRB_LOOP(vmovups(EVEX_compress_addr(t, irb*BUFFER_BLOCK + XMM_SIZE), - zreg(irb, zdiffsrc))); - if (!is_last && !is_single) { - IRB_LOOP(vmovups(ptr[t + irb*BUFFER_BLOCK + BUFFER_NEXT_OFFSET], - xreg(irb, xdiffdst_next))); - } - - IRB_LOOP(vmovups(zreg(irb, za), EVEX_compress_addr(t, irb*BUFFER_BLOCK - + XMM_SIZE - 2*sizeof(float)))); - IRB_LOOP(vmovups(zreg(irb, zb), EVEX_compress_addr(t, irb*BUFFER_BLOCK - + XMM_SIZE - 1*sizeof(float)))); - IRB_LOOP(vmovups(zreg(irb, zd), EVEX_compress_addr(t, irb*BUFFER_BLOCK - + XMM_SIZE + 1*sizeof(float)))); - IRB_LOOP(vmovups(zreg(irb, ze), EVEX_compress_addr(t, irb*BUFFER_BLOCK - + XMM_SIZE + 2*sizeof(float)))); - IRB_LOOP(vaddps(zreg(irb, zdiffsrc), zreg(irb, zdiffsrc), - zreg(irb, za))); - assert(zsrc == za); - IRB_LOOP(vmovups(zreg(irb, zsrc), EVEX_compress_addr(src, irb*vlen))); - IRB_LOOP(vaddps(zreg(irb, zdiffsrc), zreg(irb, zdiffsrc), - zreg(irb, zb))); - IRB_LOOP(vaddps(zreg(irb, zdiffsrc), zreg(irb, zdiffsrc), - zreg(irb, zd))); - IRB_LOOP(vaddps(zreg(irb, zdiffsrc), 
zreg(irb, zdiffsrc), - zreg(irb, ze))); - IRB_LOOP(vmulps(zreg(irb, zsrc), zreg(irb, zsrc), znalphabeta)); - - IRB_LOOP(vmovups(zreg(irb, zws0), - EVEX_compress_addr(workspace0, irb*vlen))); - IRB_LOOP(vdivps(zreg(irb, zdiffdst), zreg(irb, zdiffdst), - zreg(irb, zws0))); - IRB_LOOP(vfmadd213ps(zreg(irb, zdiffsrc), zreg(irb, zsrc), - zreg(irb, zdiffdst))); - - Label unaligned_store, end_store; - test(diffsrc, vlen - 1); - jnz(unaligned_store, T_NEAR); - IRB_LOOP(uni_vmovntps(EVEX_compress_addr(diffsrc, irb*vlen), - zreg(irb, zdiffsrc))); - jmp(end_store, T_NEAR); - L(unaligned_store); { - IRB_LOOP(uni_vmovups(EVEX_compress_addr(diffsrc, irb*vlen), - zreg(irb, zdiffsrc))); - } - L(end_store); - } - - jit_avx512_common_lrn_kernel_f32( - const struct nChw16c_across &J, - float A, - float B, - int use_h_parallel, - void *code_ptr = nullptr, - size_t code_size = 1 * Xbyak::DEFAULT_MAX_CODE_SIZE) - : jit_generator(code_ptr, code_size) - , nalphabeta(-2*A*B) - , use_h_parallelism(use_h_parallel) - { - this->preamble(); - - mov(src, ptr[param + 0]); - mov(diffdst, ptr[param + 8]); - mov(workspace0, ptr[param + 16]); - mov(workspace1, ptr[param + 24]); - mov(diffsrc, ptr[param + 32]); - - W = J.W; - HW = J.H*J.W; - int LSB = this->use_h_parallelism ? W : HW; - - sub(t, BWD_RBC*BUFFER_BLOCK); - mov(imm_addr64, float2int(this->nalphabeta)); - movq(xnalphabeta, imm_addr64); - vbroadcastss(znalphabeta, xnalphabeta); - - is_first = J.version == -1 || J.version == -2; - is_last = J.version == +1 || J.version == +2; - is_single = J.version == 3; - - if (is_first || is_single) { - vxorps(xmm1, xmm1, xmm1); - for(int irb = 0; irb < BWD_RBC; irb++) { - vmovups(ptr[t + irb*BUFFER_BLOCK], xmm1); - } - } - if (is_last || is_single) { - vxorps(xmm1, xmm1, xmm1); - for(int irb = 0; irb < BWD_RBC; irb++) { - vmovups(ptr[t + irb*BUFFER_BLOCK + BUFFER_NEXT_OFFSET], xmm1); - } - } - - int LSREST = LSB % BWD_RBC; - int LS = LSB - LSREST; - - Label lrn_loop; - - if (LS > 0) { - mov(hw, LS); - - L(lrn_loop); - { - compute_loop(BWD_RBC, 1, 1); - - add(src, BWD_RBC*vlen); - add(diffsrc, BWD_RBC*vlen); - add(diffdst, BWD_RBC*vlen); - add(workspace0, BWD_RBC*vlen); - add(workspace1, BWD_RBC*vlen); - - for(int irb = 0; irb < BWD_RBC; irb++) - dec(hw); - cmp(hw, 0); - jne(lrn_loop, T_NEAR); - } - } - - compute_loop(LSREST, 1, this->use_h_parallelism ? 0 : 1); - - add(t, BWD_RBC*BUFFER_BLOCK); - this->postamble(); - - ker = reinterpret_cast(const_cast( - this->getCode())); - } - -}; - -status_t jit_avx512_common_lrn_bwd_t::pd_t::init() { - using namespace alg_kind; - - const memory_desc_wrapper data_d(src_md()); - bool ok = true - && mayiuse(avx512_common) - && !is_fwd() - && utils::everyone_is(data_type::f32, data_d.data_type()) - && !has_zero_dim_memory() - && data_d.ndims() == 4 - && data_d.dims()[1] % vsize == 0 - && attr()->has_default_values(); - if (!ok) return unimplemented; - - dims_t ws_dims = { MB(), C(), H(), 2*W() }; - mkldnn_memory_desc_init_by_tag(&ws_md_, 4, ws_dims, data_type::f32, - format_tag::nChw16c); - - if (!compare_ws(hint_fwd_pd_)) return unimplemented; - - bool args_ok_across = true - && desc()->alg_kind == lrn_across_channels - && desc()->local_size == 5 - && desc()->lrn_beta == 0.75 - && data_d.matches_tag(format_tag::nChw16c); - - return args_ok_across ? 
success : unimplemented; -} - -jit_avx512_common_lrn_bwd_t::jit_avx512_common_lrn_bwd_t(const pd_t *apd) - : cpu_primitive_t(apd) - , use_h_parallelism(0), ker_(nullptr), ker_first_(nullptr) - , ker_last_(nullptr) { - const int C = pd()->C(); - const int H = pd()->H(); - const int W = pd()->W(); - const int ls = pd()->desc()->local_size; - const float alpha = pd()->desc()->lrn_alpha / ls; - const float beta = pd()->desc()->lrn_beta; - - use_h_parallelism = H > 28 ? 1 : 0; - - if (C / vsize == 1) { - ker_ = new jit_avx512_common_lrn_kernel_f32(nChw16c_across(H, W, 3), - alpha, beta, use_h_parallelism); - } else { - ker_ = new jit_avx512_common_lrn_kernel_f32(nChw16c_across(H, W, 0), - alpha, beta, use_h_parallelism); - ker_first_ = new jit_avx512_common_lrn_kernel_f32( - nChw16c_across(H, W, -1), alpha, beta, use_h_parallelism); - ker_last_ = new jit_avx512_common_lrn_kernel_f32( - nChw16c_across(H, W, +1), alpha, beta, use_h_parallelism); - } -} - -jit_avx512_common_lrn_bwd_t::~jit_avx512_common_lrn_bwd_t() -{ delete ker_; delete ker_first_; delete ker_last_; } - -void jit_avx512_common_lrn_bwd_t::execute_backward(const exec_ctx_t &ctx) const -{ - auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); - auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); - auto ws = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WORKSPACE); - auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); - - const int N = pd()->MB(); - const int C = pd()->C(); - const int H = pd()->H(); - const int W = pd()->W(); - - parallel(0, [&](const int ithr, const int nthr) { - size_t start{0}, end{0}; - const int C16 = C / vsize; - const size_t work_amount = use_h_parallelism ? N*C16*H : N*C16; - - balance211(work_amount, nthr, ithr, start, end); - if (use_h_parallelism) { - int n{0}, c16{0}, h{0}; - nd_iterator_init(start, n, N, h, H, c16, C16); - for (size_t iwork = start; iwork < end; ++iwork) { - auto offset = n*C*H*W + c16*H*W*vsize - + h*W*vsize; - auto ws_offset0 = n*C*H*2*W + c16*H*2*W*vsize - + h*2*W*vsize; - auto ws_offset1 = ws_offset0 + W*vsize; - - jit_args_bwd_t args; - args.src = &src[offset]; - args.diff_dst = &diff_dst[offset]; - args.ws0 = &ws[ws_offset0]; - args.ws1 = &ws[ws_offset1]; - args.diff_src = &diff_src[offset]; - - if (C16 == 1) - (*ker_)(&args); - else if (c16 == 0) - (*ker_first_)(&args); - else if (c16 == C16 - 1) - (*ker_last_)(&args); - else - (*ker_)(&args); - nd_iterator_step(n, N, h, H, c16, C16); - } - } else { - int n{0}, c16{0}; - nd_iterator_init(start, n, N, c16, C16); - for (size_t iwork = start; iwork < end; ++iwork) { - auto offset = n*C*H*W + c16*H*W*vsize; - auto ws_offset0 = n*C*H*2*W + c16*H*2*W*vsize; - auto ws_offset1 = ws_offset0 + H*W*vsize; - - jit_args_bwd_t args; - args.src = &src[offset]; - args.diff_dst = &diff_dst[offset]; - args.ws0 = &ws[ws_offset0]; - args.ws1 = &ws[ws_offset1]; - args.diff_src = &diff_src[offset]; - - if (C16 == 1) - (*ker_)(&args); - else if (c16 == 0) - (*ker_first_)(&args); - else if (c16 == C16 - 1) - (*ker_last_)(&args); - else - (*ker_)(&args); - - nd_iterator_step(n, N, c16, C16); - } - } - }); -} - -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_lrn.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_lrn.hpp deleted file mode 100644 index 37fbb9b3e..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_lrn.hpp +++ /dev/null @@ -1,96 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the 
Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef CPU_JIT_AVX512_COMMON_LRN_HPP -#define CPU_JIT_AVX512_COMMON_LRN_HPP - -#include "c_types_map.hpp" - -#include "cpu_isa_traits.hpp" -#include "cpu_lrn_pd.hpp" -#include "cpu_primitive.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -struct jit_avx512_common_lrn_fwd_t: public cpu_primitive_t { - struct pd_t: public cpu_lrn_fwd_pd_t { - using cpu_lrn_fwd_pd_t::cpu_lrn_fwd_pd_t; - - DECLARE_COMMON_PD_T( - JIT_IMPL_NAME_HELPER("jit:", avx512_common, ""), - jit_avx512_common_lrn_fwd_t); - - status_t init(); - }; - - jit_avx512_common_lrn_fwd_t(const pd_t *apd); - ~jit_avx512_common_lrn_fwd_t(); - - typedef typename prec_traits::type data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_forward(ctx); - return status::success; - } - -private: - void execute_forward(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - - int use_h_parallelism; - struct jit_avx512_common_lrn_kernel_f32; - jit_avx512_common_lrn_kernel_f32 *ker_, *ker_first_, *ker_last_; -}; - -struct jit_avx512_common_lrn_bwd_t: public cpu_primitive_t { - struct pd_t: public cpu_lrn_bwd_pd_t { - using cpu_lrn_bwd_pd_t::cpu_lrn_bwd_pd_t; - - DECLARE_COMMON_PD_T( - JIT_IMPL_NAME_HELPER("jit:", avx512_common, ""), - jit_avx512_common_lrn_bwd_t); - - status_t init(); - }; - - jit_avx512_common_lrn_bwd_t(const pd_t *apd); - ~jit_avx512_common_lrn_bwd_t(); - - typedef typename prec_traits::type data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_backward(ctx); - return status::success; - } - -private: - void execute_backward(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - - int use_h_parallelism; - struct jit_avx512_common_lrn_kernel_f32; - jit_avx512_common_lrn_kernel_f32 *ker_, *ker_first_, *ker_last_; -}; - -} -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.cpp deleted file mode 100644 index c58d3fa0a..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.cpp +++ /dev/null @@ -1,1103 +0,0 @@ -/******************************************************************************* - * Copyright 2018 Intel Corporation - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - *******************************************************************************/ - -#include - -#include "c_types_map.hpp" -#include "mkldnn_thread.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -#include "jit_avx512_core_fp32_wino_conv_2x3.hpp" -#include "jit_generator.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace mkldnn::impl::format_kind; -using namespace mkldnn::impl::memory_tracking::names; -using namespace mkldnn::impl::utils; -using namespace Xbyak; - -/// SRC TRANSFORMS ///////////////////////////////////////////////////////////// -struct jit_avx512_core_fp32_wino_conv_2x3_src_trans_t: public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS( - jit_avx512_core_fp32_wino_conv_2x3_src_trans_t) - - jit_conv_conf_2x3_wino_t jcp; - - struct call_params_t { - const void *src; - const void *wino_src; - const void *v_y_masks; - const void *v_x_masks; - }; - void (*ker_)(const call_params_t *); - - jit_avx512_core_fp32_wino_conv_2x3_src_trans_t( - jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr) - : jcp(ajcp) { - generate(); - ker_ = - reinterpret_cast(const_cast(getCode())); - } - - void generate(); - - Zmm vreg_inp(int i) { - assert(i < jcp.alpha * jcp.alpha); - return Zmm(31 - i); - } - - Zmm vreg_tmp(int i) { - assert(i < jcp.alpha * jcp.alpha); - return Zmm(15 - i); - } - - Zmm vreg_out(int i) { - assert(i < jcp.alpha * jcp.alpha); - return Zmm(31 - i); - } - - Opmask y_mask = Opmask(1); - Opmask r_mask = Opmask(2); - Opmask x_mask(int id) { - assert (id < 4); - return Opmask(3 + id); - } - - Reg64 reg_ptr_v_y_masks = r12; - Reg64 reg_ptr_v_x_masks = r11; - - Reg64 reg_aux_ptr_src = r10; - Reg64 reg_aux_ptr_dst = r9; - - Reg64 reg_ic_block = r8; - -}; - -void jit_avx512_core_fp32_wino_conv_2x3_src_trans_t::generate() { - Label ic_block_label; - - const int load_block = 16; - int out_offset = 0, inp_offset = 0; - preamble(); - -#define READ_PARAM(reg, field) \ - mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)]) - READ_PARAM(reg_aux_ptr_src, src); - READ_PARAM(reg_aux_ptr_dst, wino_src); - READ_PARAM(reg_ptr_v_y_masks, v_y_masks); - READ_PARAM(reg_ptr_v_x_masks, v_x_masks); -#undef READ_PARAM - - for (int i = 0; i < jcp.alpha; i++) { - kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(int16_t) * i]); - } - mov(reg_ic_block, jcp.ic / load_block); - L(ic_block_label); - { - for (int y = 0; y < jcp.alpha; y++) { - kmovw(y_mask, ptr[reg_ptr_v_y_masks + sizeof(int16_t) * y]); - for (int x = 0; x < jcp.alpha; x++) { - Zmm zmm = vreg_inp(y * jcp.alpha + x); - - vxorps(zmm, zmm, zmm); - kandw(r_mask, y_mask, x_mask(x)); - inp_offset = sizeof(float) - * ((-jcp.t_pad + y) * jcp.iw * load_block - + (-jcp.l_pad + x) * load_block); - vmovups(zmm | r_mask, - EVEX_compress_addr(reg_aux_ptr_src, inp_offset)); - } - } - for (int y = 0; y < jcp.alpha; y++) { - vsubps(vreg_tmp(y * jcp.alpha + 0), vreg_inp(y * jcp.alpha + 0), - vreg_inp(y * jcp.alpha + 2)); - vaddps(vreg_tmp(y * jcp.alpha + 1), vreg_inp(y * jcp.alpha + 1), - vreg_inp(y * jcp.alpha + 2)); - vsubps(vreg_tmp(y * jcp.alpha + 2), vreg_inp(y * jcp.alpha + 2), - vreg_inp(y * jcp.alpha + 1)); - vsubps(vreg_tmp(y * jcp.alpha + 3), vreg_inp(y * jcp.alpha + 1), - vreg_inp(y * jcp.alpha + 3)); - } - for (int x = 0; x < jcp.alpha; x++) { - vsubps(vreg_out(x + 0 * jcp.alpha), vreg_tmp(x + jcp.alpha * 0), - vreg_tmp(x + jcp.alpha * 2)); - vaddps(vreg_out(x + 1 * jcp.alpha), vreg_tmp(x + 
jcp.alpha * 1), - vreg_tmp(x + jcp.alpha * 2)); - vsubps(vreg_out(x + 2 * jcp.alpha), vreg_tmp(x + jcp.alpha * 2), - vreg_tmp(x + jcp.alpha * 1)); - vsubps(vreg_out(x + 3 * jcp.alpha), vreg_tmp(x + jcp.alpha * 1), - vreg_tmp(x + jcp.alpha * 3)); - } - - for (int i = 0; i < 16; i++) { - out_offset = sizeof(float) * (jcp.inp_stride * i); - vmovups(EVEX_compress_addr(reg_aux_ptr_dst, out_offset), - vreg_out(i)); - } - - add(reg_aux_ptr_src, sizeof(float) * jcp.ih * jcp.iw * load_block); - add(reg_aux_ptr_dst, sizeof(float) * load_block); - } - dec(reg_ic_block); - cmp(reg_ic_block, 0); - jg(ic_block_label, T_NEAR); - postamble(); -} - -/// DST TRANSFORMS ///////////////////////////////////////////////////////////// -struct jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t: public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS( - jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t) - - jit_conv_conf_2x3_wino_t jcp; - const primitive_attr_t &attr_; - - struct call_params_t { - const void *wino_dst; - const void *dst; - const void *v_y_masks; - const void *v_x_masks; - - const void *bias; - const void *scales; - }; - void (*ker_)(const call_params_t *); - - jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t( - jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr) - : jcp(ajcp), attr_(attr) { - generate(); - ker_ = reinterpret_cast( - const_cast(getCode())); - } - - void generate(); - bool maybe_relu(int position); - - Zmm vreg_inp(int i) { // 16 - assert(i < jcp.alpha * jcp.alpha); - return Zmm(31 - i); - } - - Zmm vreg_stg(int id) { // 8 - const int id_reg_stg = jcp.alpha * jcp.alpha + id; - assert(id_reg_stg < jcp.alpha * jcp.alpha + 8); - return Zmm(31 - id_reg_stg); - } - - Zmm vreg_out(int id) { // 4 - const int id_reg_out = jcp.alpha * jcp.alpha + 8 + id; - assert(id_reg_out < jcp.alpha * jcp.alpha + 12); - return Zmm(31 - id_reg_out); - } - - Zmm vreg_tmp(int id) { // 2 - const int id_reg_tmp = jcp.alpha * jcp.alpha + 12 + id; - assert(id_reg_tmp < jcp.alpha * jcp.alpha + 14); - return Zmm(31 - id_reg_tmp); - } - - Zmm vreg_zero = Zmm(0); - Zmm vreg_prev_dst = Zmm(0); - Zmm vreg_bias = Zmm(2); - - Opmask y_mask = Opmask(1); - Opmask r_mask = Opmask(2); - Opmask x_mask(int id) { - assert (id < 4); - return Opmask(3 + id); - } - - Reg64 reg_ptr_v_y_masks = r12; - Reg64 reg_ptr_v_x_masks = r11; - - Reg64 reg_aux_ptr_src = r10; - Reg64 reg_aux_ptr_dst = r9; - - Reg64 reg_oc_block = r8; - - Reg64 reg_ptr_bias = rbx; - Reg64 reg_ptr_scales = abi_not_param1; - Reg64 reg_ptr_sum_scale = rdx; -}; - -bool jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t::maybe_relu(int position) { - using namespace primitive_kind; - const auto &p = attr_.post_ops_; - - if (position == 0) { - /* relu before sum */ - return false - || p.contain(eltwise, 0); - } else if (position == 1) { - /* relu after sum */ - const int sum_idx = p.contain(sum, 0) - ? 0 : (p.contain(sum, 1) ? 1 : -1); - if (sum_idx == -1) - return false; - - return false - || p.contain(eltwise, sum_idx + 1); - } - - return false; -} - -void jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t::generate() { - Label oc_block_label; - - const int load_block = 16; - - auto loop_body = [=]() { - const auto &p = attr_.post_ops_; - const int sum_idx = p.find(primitive_kind::sum); - const float *p_sum_scale = (sum_idx != -1) - ? 
&p.entry_[sum_idx].sum.scale - : nullptr; - if (p_sum_scale && *p_sum_scale != 1.f) - mov(reg_ptr_sum_scale, (size_t)p_sum_scale); - - for (int i = 0; i < 16; i++) { - int internal_offset = sizeof(float) * jcp.out_stride * i; - vmovups(vreg_inp(i), - EVEX_compress_addr(reg_aux_ptr_src, internal_offset)); - } - for (int y = 0; y < jcp.alpha; y++) { - vaddps(vreg_tmp(0), vreg_inp(y * 4 + 0), vreg_inp(y * 4 + 1)); - vaddps(vreg_stg(y * 2), vreg_tmp(0), vreg_inp(y * 4 + 2)); - - vsubps(vreg_tmp(1), vreg_inp(y * 4 + 1), vreg_inp(y * 4 + 2)); - vsubps(vreg_stg(y * 2+1), vreg_tmp(1), vreg_inp(y * 4 + 3)); - } - for (int x = 0; x < jcp.m; x++) { - vaddps(vreg_tmp(0), vreg_stg(x), vreg_stg(x+2 * 1)); - vaddps(vreg_out(x), vreg_tmp(0), vreg_stg(x+2 * 2)); - - vsubps(vreg_tmp(1), vreg_stg(x+2 * 1), vreg_stg(x+2 * 2)); - vsubps(vreg_out(x+2), vreg_tmp(1), vreg_stg(x+2 * 3)); - } - - - if (jcp.with_bias) { - auto bias_addr = ptr [ reg_ptr_bias ]; - vmovups(vreg_bias, bias_addr); - } - for (int y = 0; y < jcp.m; y++) { - kmovw(y_mask, ptr[ reg_ptr_v_y_masks + sizeof(int16_t) * y ]); - for (int x = 0; x < jcp.m; x++) { - kandw(r_mask, y_mask, x_mask(x)); - - int i = y * jcp.m + x; - int offset = sizeof(float) * - (y * jcp.ow * jcp.oc_block + x * jcp.oc_block); - Address addr = EVEX_compress_addr(reg_aux_ptr_dst, offset); - - Zmm zmm = vreg_out(i); - if (jcp.with_bias) - vaddps(zmm, zmm, vreg_bias); - vmulps(zmm, zmm, ptr [reg_ptr_scales]); - - if (maybe_relu(0)) { - vxorps(vreg_zero, vreg_zero, vreg_zero); - vmaxps(zmm, vreg_zero, zmm); - } - if (p_sum_scale) { // post_op: sum - vxorps(vreg_prev_dst, vreg_prev_dst, vreg_prev_dst); - vmovups(vreg_prev_dst | r_mask, addr); - if (*p_sum_scale == 1.f) - vaddps(zmm, vreg_prev_dst); - else - vfmadd231ps(zmm, vreg_prev_dst, - zword_b[reg_ptr_sum_scale]); - } - if (maybe_relu(1)) { - vxorps(vreg_zero, vreg_zero, vreg_zero); - vmaxps(zmm, vreg_zero, zmm); - } - - vmovups(addr, zmm | r_mask); - } - } - }; - - preamble(); - -#define READ_PARAM(reg, field) \ - mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)]) - READ_PARAM(reg_aux_ptr_src, wino_dst); - READ_PARAM(reg_aux_ptr_dst, dst); - READ_PARAM(reg_ptr_v_y_masks, v_y_masks); - READ_PARAM(reg_ptr_v_x_masks, v_x_masks); - READ_PARAM(reg_ptr_bias, bias); - READ_PARAM(reg_ptr_scales, scales); -#undef READ_PARAM - - for (int i = 0; i < jcp.alpha * jcp.alpha; i++) - vxorps(vreg_inp(i), vreg_inp(i), vreg_inp(i)); - - for (int i = 0; i < jcp.alpha; i++) - kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(int16_t) * i]); - - int oc_blocks = 1; - oc_blocks = jcp.oc / load_block; - mov(reg_oc_block, oc_blocks); - L(oc_block_label); - { - loop_body(); - add(reg_aux_ptr_src, sizeof(float) * load_block); - add(reg_aux_ptr_dst, sizeof(float) * jcp.oh * jcp.ow * load_block); - - add(reg_ptr_scales, jcp.is_oc_scale * sizeof(float) * load_block); - add(reg_ptr_bias, jcp.typesize_bia * load_block); - } - dec(reg_oc_block); - cmp(reg_oc_block, 0); - jg(oc_block_label, T_NEAR); - - sub(reg_ptr_scales, jcp.is_oc_scale * sizeof(float) * load_block); - sub(reg_ptr_bias, oc_blocks * jcp.typesize_bia * load_block); - - postamble(); - -} - -/// GEMM kernel //////////////////////////////////////////////////////////////// -struct jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t: public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t) - jit_conv_conf_2x3_wino_t jcp; - - struct call_params_t { - const void *src; - const void *dst; - const void *wei; - const void *dst_b; - }; - void (*ker_)(const 
call_params_t *); - - void generate(); - static bool post_ops_ok(jit_conv_conf_2x3_wino_t &jcp, - const primitive_attr_t &attr); - - jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t( - jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr) - : jcp(ajcp) { - generate(); - ker_ = reinterpret_cast( - const_cast(getCode())); - } - - static status_t init_conf( - jit_conv_conf_2x3_wino_t &jcp, const convolution_desc_t &cd, - memory_desc_t &src_md, memory_desc_t &weights_md, - memory_desc_t &dst_md, memory_desc_t &bias_md, - const primitive_attr_t &attr, - memory_desc_t& expect_wei_md); - - Zmm vreg_out(int n, int m) { - const int id_reg_out = n * jcp.m_block + m; - assert(id_reg_out < jcp.n2_block * jcp.m_block); - return Zmm(31 - id_reg_out); - } - Zmm vreg_wei(int i) { - assert (31 - jcp.n2_block * jcp.m_block - i > 1); - return Zmm(31 - jcp.n2_block * jcp.m_block - i); - } - - Zmm vreg_src = Zmm(0); - Zmm vreg_one = Zmm(1); - Zmm vreg_tmp = Zmm(2); - - Reg64 reg_ptr_src = r15; - - Reg64 reg_aux_dst = r12; - Reg64 reg_aux_dst2 = r11; - Reg64 reg_aux_wei = r10; - Reg64 reg_aux_wei2 = r9; - Reg64 reg_aux_src = r8; - Reg64 reg_aux_src2 = rax; - - Reg64 reg_mb = rbx; - Reg64 reg_nnb = rdx; - Reg64 reg_K = rsi; - -}; - -bool jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t::post_ops_ok( - jit_conv_conf_2x3_wino_t &jcp, const primitive_attr_t &attr) { - using namespace primitive_kind; - const auto &p = attr.post_ops_; - - auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); }; - - switch (p.len_) { - case 0: return true; - case 1: return is_relu(0) || p.contain(sum, 0); - case 2: return (p.contain(sum, 0) && is_relu(1)) || - (p.contain(sum, 1) && is_relu(0)); - case 3: return is_relu(0) && p.contain(sum, 1) && is_relu(2); - default: return false; - } - - return false; -} - -void jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t::generate() { - Label nnb_loop_label, K_loop_label, mb_loop_label; - - preamble(); -#define READ_PARAM(reg, field) \ - mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)]) - READ_PARAM(reg_ptr_src, src); - READ_PARAM(reg_aux_dst, dst); - READ_PARAM(reg_aux_wei, wei); -#undef READ_PARAM - - if (!jcp.small_mb) { - mov(reg_nnb, jcp.n_chunks); - L(nnb_loop_label); - } - mov(reg_aux_dst2, reg_aux_dst); - mov(reg_aux_src, reg_ptr_src); - mov(reg_mb, jcp.M / jcp.m_block); - L(mb_loop_label); - { - int nb2 = 0; - for (nb2 = 0; nb2 < jcp.n2_block; nb2++) { - for (int m = 0; m < jcp.m_block; m++) { - vxorps(vreg_out(nb2, m), vreg_out(nb2, m), vreg_out(nb2, m)); - } - } - mov(reg_aux_src2, reg_aux_src); - mov(reg_aux_wei2, reg_aux_wei); - - mov(reg_K, jcp.k_chunks); - L(K_loop_label); { - int wei_offset = 0; - for (int _i = 0; _i < jcp.k2_block; _i++) { - for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) { - if (jcp.small_mb) { - int wei_offset = sizeof(float) - * ((nb2 * jcp.nb_ic * jcp.ic_block - * jcp.oc_block) - + _i * jcp.oc_block); - vmovups(vreg_wei(nb2), - EVEX_compress_addr(reg_aux_wei2, wei_offset)); - } else { - vmovups(vreg_wei(nb2), - EVEX_compress_addr(reg_aux_wei2, - sizeof(float) * wei_offset)); - wei_offset += jcp.oc_block; - } - } - for (int m = 0; m < jcp.m_block; m++) { - int inp_offset = sizeof(float) * (m * jcp.K + _i); - if (jcp.n2_block > 1) { - vbroadcastss(vreg_src, - EVEX_compress_addr(reg_aux_src2, inp_offset)); - for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) - vfmadd231ps(vreg_out(nb2, m), vreg_wei(nb2), - vreg_src); - } else { - vfmadd231ps(vreg_out(0, m), vreg_wei(0), - EVEX_compress_addr(reg_aux_src2, inp_offset, true)); - } - } - } - add(reg_aux_src2, 
sizeof(float) * jcp.ic_block); - if (jcp.small_mb) - add(reg_aux_wei2, sizeof(float) * jcp.oc_block * jcp.ic_block); - else - add(reg_aux_wei2, - sizeof(float) * jcp.k2_block * jcp.n2_block - * jcp.oc_block); - } - dec(reg_K); - cmp(reg_K, 0); - jg(K_loop_label, T_NEAR); - - for (int m = 0; m < jcp.m_block; m++) { - int nb2 = 0; - for (nb2 = 0; nb2 < jcp.n2_block; nb2++) { - int offset = sizeof(float) * - (m * jcp.N + nb2 * jcp.oc_block); - vmovups(EVEX_compress_addr(reg_aux_dst2,offset), - vreg_out(nb2, m)); - } - } - add(reg_aux_src, sizeof(float) * jcp.m_block * jcp.K); - add(reg_aux_dst2, sizeof(float) * jcp.m_block * jcp.N); - } - dec(reg_mb); - cmp(reg_mb, 0); - jg(mb_loop_label, T_NEAR); - - if (!jcp.small_mb) { - add(reg_aux_dst, sizeof(float) * jcp.n2_block * jcp.oc_block); - add(reg_aux_wei, - sizeof(float) * jcp.k_chunks * jcp.ic_block * jcp.n2_block - * jcp.oc_block); - - dec(reg_nnb); - cmp(reg_nnb, 0); - jg(nnb_loop_label, T_NEAR); - } - postamble(); -} - -namespace { -bool is_winograd_faster_than_direct(const jit_conv_conf_2x3_wino_t &jcp) { - return jcp.mb >= 4; -} -} - -status_t jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t ::init_conf( - jit_conv_conf_2x3_wino_t &jcp, const convolution_desc_t &cd, - memory_desc_t &src_md, memory_desc_t &wei_md, - memory_desc_t &dst_md, memory_desc_t &bias_md, - const primitive_attr_t &attr, memory_desc_t &expect_wei_md) { - const memory_desc_wrapper src_d(&src_md); - const memory_desc_wrapper wei_d(&wei_md); - const memory_desc_wrapper dst_d(&dst_md); - const memory_desc_wrapper bias_d(&bias_md); - - const bool with_groups = wei_d.ndims() == src_d.ndims() + 1; - - jcp.nthr = mkldnn_get_max_threads(); - - jcp.ngroups = with_groups ? wei_d.dims()[0] : 1; - jcp.mb = src_d.dims()[0]; - jcp.oc = dst_d.dims()[1] / jcp.ngroups; - jcp.oc_without_padding = jcp.oc; - jcp.ic = src_d.dims()[1] / jcp.ngroups; - jcp.ih = src_d.dims()[2]; - jcp.iw = src_d.dims()[3]; - jcp.oh = dst_d.dims()[2]; - jcp.ow = dst_d.dims()[3]; - jcp.kh = wei_d.dims()[with_groups + 2]; - jcp.kw = wei_d.dims()[with_groups + 3]; - jcp.t_pad = cd.padding[0][0]; - jcp.b_pad = cd.padding[1][0]; - jcp.l_pad = cd.padding[0][1]; - jcp.r_pad = cd.padding[1][1]; - jcp.stride_h = cd.strides[0]; - jcp.stride_w = cd.strides[1]; - jcp.dilate_h = cd.dilates[0]; - jcp.dilate_w = cd.dilates[1]; - - jcp.m = 2; - jcp.r = 3; - jcp.alpha = jcp.m + jcp.r - 1; - int simdw = 16; - - format_tag_t dat_tag = format_tag::nChw16c; - jcp.src_tag = src_d.matches_one_of_tag(dat_tag); - jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); - - if (jcp.src_tag != dat_tag) return status::unimplemented; - if (jcp.dst_tag != dat_tag) return status::unimplemented; - - jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; - - if (!post_ops_ok(jcp, attr)) - return status::unimplemented; - - bool ok_to_pad_channels = jcp.ngroups == 1; - if (ok_to_pad_channels) { - jcp.oc = rnd_up(jcp.oc, simdw); - jcp.ic = rnd_up(jcp.ic, simdw); - } - - jcp.ver = ver_avx512_core; - if (!(mayiuse(avx512_core))) - return status::unimplemented; - - if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto, - is_winograd_faster_than_direct(jcp))) - return status::unimplemented; - - if (src_d.data_type() != data_type::f32) - return status::unimplemented; - if (wei_d.data_type() != data_type::f32) - return status::unimplemented; - if (dst_d.data_type() != data_type::f32) - return status::unimplemented; - - jcp.ic_block = simdw; - jcp.oc_block = simdw; - - bool ok = true && jcp.kh == 3 && jcp.kw == 3 && jcp.ngroups == 1 - && jcp.oc % 
jcp.oc_block == 0 && jcp.ic % jcp.ic_block == 0 - && jcp.stride_h == 1 && jcp.stride_w == 1 && jcp.dilate_h == 0 - && jcp.dilate_w == 0 && jcp.t_pad == jcp.b_pad - && jcp.l_pad == jcp.r_pad && jcp.t_pad < 2 && jcp.t_pad >= 0 - && jcp.l_pad < 2 && jcp.l_pad >= 0; - if (!ok) - return status::unimplemented; - - const int L2_cap = get_cache_size(2, true) / sizeof(float); - const int L3_capacity = get_cache_size(3, false) / sizeof(float); - int a = jcp.alpha; - int aa = a * a; - int mb = jcp.mb; - int ic = jcp.ic; - int oc = jcp.oc; - int ih = jcp.ih; - int iw = jcp.iw; - auto wei_sz = (float)aa * ic * oc; - auto inp_sz = (float)mb * ih * iw * ic; - auto sp_sz = (float)mb * ih * iw; - - /* Heuristics here. Numbers '28','196' is an observation from data. */ - if (wei_sz / inp_sz > 5) - jcp.small_mb = true; - else - jcp.small_mb = false; - - if (mb > nstl::min(jcp.nthr, 28) - || (!jcp.small_mb - && (wei_sz >= 0.9f * L2_cap - || inp_sz > L2_cap * jcp.nthr + L3_capacity)) - || (jcp.small_mb && sp_sz > 196)) - return status::unimplemented; - - jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef; - jcp.dst_dt = cd.dst_desc.data_type; - - jcp.typesize_bia - = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0; - - jcp.nb_oc = jcp.oc / jcp.oc_block; - jcp.nb_ic = jcp.ic / jcp.ic_block; - - const int skx_free_regs = 30; - - auto find_m_n2_blocks = [=](int xb, int yb, int &M, int &m_block, - int &n2_block, float ®_eff) { - M = (xb * yb) / jcp.alpha; - int max_m_block = m_block = nstl::min(M, skx_free_regs); - int max_n2_block = n2_block = nstl::min(jcp.nb_oc, skx_free_regs); - reg_eff = 0; - for (int im = max_m_block; im > 0; im--) { - for (int in2 = max_n2_block; in2 > 0; in2--) { - int used_regs = in2 * im + in2; - float cur_reg_eff = ((float)in2 * im) / (im + in2) / 2.5f; - if (M % im || jcp.nb_oc % in2 || used_regs > skx_free_regs - || cur_reg_eff <= reg_eff) - continue; - reg_eff = cur_reg_eff; - m_block = im; - n2_block = in2; - } - } - }; - - int oh = jcp.oh; - int ow = jcp.ow; - int nb_oc = jcp.nb_oc; - int Z = ic + oc; - int Y = ic * oc; - const int L3_cap_per_core = get_cache_size(3, true) / sizeof(float); - - /* Selecting xb and yb blocking */ - int min_yb = jcp.alpha; - int min_xb = jcp.alpha; - int max_yb = nstl::max(min_yb, rnd_up(ih, 2)); - int max_xb = nstl::max(min_xb, rnd_up(iw, 2)); - float best_eff = 0.f; - for (int ix = max_xb; ix >= min_xb; ix -= 2) { - if (rnd_up(ow, ix) < iw - 2) - continue; - for (int iy = max_yb; iy >= min_yb; iy -= 2) { - if (rnd_up(oh, iy) < ih - 2) - continue; - int ex_y = rnd_up(oh, iy); - int ex_x = rnd_up(ow, ix); - float work_eff = (float)(ih * iw) / (ex_y * ex_x); - - int M, m_block, n2_b; - float reg_eff, thr_eff, par_eff, mem_eff, req_mem; - - find_m_n2_blocks(ix, iy, M, m_block, n2_b, reg_eff); - - /* outer parallelization */ - int nblocks = mb * div_up(ih, iy) * div_up(iw, ix); - thr_eff = (float)nblocks / rnd_up(nblocks, jcp.nthr); - - mem_eff = 1.f; - req_mem = (((float)ix + 2) * (iy + 2) + aa * M) * Z + aa * Y; - if (req_mem > L2_cap / 2) { - if (req_mem > ((L2_cap + L3_cap_per_core) * 4) / 7) - mem_eff /= (n2_b + 1) / 2.f; - else - mem_eff /= (n2_b + 1) / 3.f; - } - - float outer_eff = thr_eff + work_eff + reg_eff + mem_eff; - - /* inner parallelization */ - int bsz = iy * ix / a; - int gemmw = aa * (nb_oc / n2_b); - int bsz_r = rnd_up(bsz, jcp.nthr); - int gemmw_r = rnd_up(gemmw, jcp.nthr); - thr_eff = ((float)Z * bsz / bsz_r + Y * gemmw / gemmw_r) / (Z + Y); - - req_mem = (float)ix * iy * (ic + simdw * n2_b) + 
simdw * n2_b * ic; - mem_eff = nstl::min(1.f, L2_cap / req_mem); - int M_per_thr = nstl::max(2, div_up(aa, jcp.nthr)); - int oc_per_thr = - nstl::min(oc, div_up(aa * (nb_oc / n2_b), jcp.nthr)); - req_mem = (float)aa * oc_per_thr * ic + M_per_thr * M * Z; - if (req_mem > L2_cap) - mem_eff = 0.1f; - par_eff = 1 / (2.f * nblocks); - - float inner_eff = thr_eff + work_eff + mem_eff + par_eff; - - float eff = jcp.small_mb ? inner_eff : outer_eff; - if (eff > best_eff) { - best_eff = eff; - jcp.yb = iy; - jcp.xb = ix; - jcp.M = M; - jcp.m_block = m_block; - jcp.n2_block = n2_b; - } - } - } - - assert(jcp.xb % 2 == 0 && jcp.yb % 2 == 0); - - jcp.inp_stride = jcp.M * jcp.ic; - jcp.out_stride = jcp.M * jcp.oc; - jcp.wei_stride = jcp.ic * jcp.oc; - jcp.bia_stride = jcp.oc; - - jcp.N = jcp.oc; - jcp.K = jcp.ic; - - jcp.n_block = jcp.oc_block; - jcp.k_block = jcp.ic_block; - - assert(jcp.M % jcp.m_block == 0); - assert(jcp.nb_oc % jcp.n2_block == 0); - - jcp.n_chunks = jcp.nb_oc / jcp.n2_block; - jcp.k2_block = jcp.ic_block; - jcp.k_chunks = jcp.K / jcp.k2_block; - - const auto &oscales = attr.output_scales_; - jcp.is_oc_scale = oscales.mask_ == 1 << 1; - assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0)); - - /* re-create weights primitive descriptor - and set weights wino_blocking */ - expect_wei_md.format_kind = format_kind::wino; - expect_wei_md.data_type = data_type::f32; - mkldnn_wino_desc_t &wd = expect_wei_md.format_desc.wino_desc; - wd.wino_format - = jcp.small_mb ? mkldnn_wino_wei_aaOio : mkldnn_wino_wei_aaOBiOo; - wd.r = jcp.r; - wd.alpha = jcp.alpha; - wd.ic = jcp.ic; - wd.oc = jcp.oc; - wd.ic_block = jcp.ic_block; - wd.oc_block = jcp.oc_block; - wd.oc2_block = jcp.n2_block; - wd.ic2_block = 1; - wd.adj_scale = 1.f; - size_t max_size = sizeof(float) * jcp.alpha * jcp.alpha * jcp.ic * jcp.oc; - wd.size = max_size; - - return status::success; -} -//////////////////////////////////////////////////////////////////////////////// - -status_t jit_avx512_core_fp32_wino_conv_2x3_fwd_t - ::pd_t::jit_conf(memory_desc_t& expect_wei_md) { - return jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t::init_conf( - jcp_, *this->desc(), this->src_md_, this->weights_md_, - this->dst_md_,this->bias_md_, *this->attr(), expect_wei_md); -} - -jit_avx512_core_fp32_wino_conv_2x3_fwd_t:: - jit_avx512_core_fp32_wino_conv_2x3_fwd_t(const pd_t *apd) - : cpu_primitive_t(apd) -{ - kernel_ = new jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t( - pd()->jcp_, *pd()->attr()); - src_trans_ = new jit_avx512_core_fp32_wino_conv_2x3_src_trans_t( - pd()->jcp_, *pd()->attr()); - dst_trans_ = new jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t( - pd()->jcp_, *pd()->attr()); -} - -jit_avx512_core_fp32_wino_conv_2x3_fwd_t - ::~jit_avx512_core_fp32_wino_conv_2x3_fwd_t() { - delete kernel_; - delete src_trans_; - delete dst_trans_; -} - -void jit_avx512_core_fp32_wino_conv_2x3_fwd_t::execute_forward_mbN( - const float *src, const float *wei, const float *bia, float *dst, - const memory_tracking::grantor_t &scratchpad) const -{ - const auto &jcp = kernel_->jcp; - const auto &oscales = pd()->attr()->output_scales_; - - const size_t wino_size_offset = - (size_t)(pd()->jcp_.yb / 2) * (pd()->jcp_.xb / 2) + (pd()->jcp_.xb); - const size_t size_wino_src = wino_size_offset * pd()->jcp_.ic * 16; - const size_t size_wino_dst = wino_size_offset * pd()->jcp_.oc * 16; - - if (pd()->wants_padded_bias()) { - auto padded_bias = scratchpad.get(key_conv_padded_bias); - utils::array_copy(padded_bias, bia, jcp.oc_without_padding); - 
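/* Zero-fill the padded tail of the bias next: init_conf() rounds jcp.oc up to the
   16-float SIMD width, so after copying the real jcp.oc_without_padding values the
   remaining slots are set to 0.f and the padded scratch buffer replaces the caller's
   bias pointer, letting the dst transform read whole bias vectors safely. */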
utils::array_set(padded_bias + jcp.oc_without_padding, 0.f, - jcp.oc - jcp.oc_without_padding); - bia = padded_bias; - } - - auto ptr_V = scratchpad.get(key_wino_V); - auto ptr_M = scratchpad.get(key_wino_M); - - parallel_nd(jcp.mb, div_up(jcp.oh,jcp.yb), div_up(jcp.ow, jcp.xb), - [&](int mb, int tile_y_b, int tile_x_b) { - int tile_y = tile_y_b * jcp.yb; - int tile_x = tile_x_b * jcp.xb; - - int ithr = mkldnn_get_thread_num(); - auto wino_src = ptr_V + size_wino_src * ithr; - auto wino_dst = ptr_M + size_wino_dst * ithr; - - auto src_trans_p = - jit_avx512_core_fp32_wino_conv_2x3_src_trans_t - ::call_params_t(); - auto dst_trans_p = - jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t - ::call_params_t(); - auto gemm_p = jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t :: - call_params_t(); - - /* transformation of input tensor to winograd domain */ - for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) { - for (int x_in_block = 0; x_in_block < jcp.xb; - x_in_block += 2) { - - unsigned short v_y_masks[4], v_x_masks[4]; - - int y = y_in_block + tile_y; - int x = x_in_block + tile_x; - int m = (y_in_block / 2) * (jcp.xb / 2) - + (x_in_block / 2); - - int v_ys = nstl::max(0, jcp.t_pad - y); - int v_ye = nstl::min(jcp.alpha, - nstl::max(0, jcp.ih + jcp.t_pad - y)); - - int v_xs = nstl::max(0, jcp.l_pad - x); - int v_xe = nstl::min(jcp.alpha, - nstl::max(0, jcp.iw + jcp.l_pad - x)); - -#pragma unroll(4) - for (int i = 0; i < jcp.alpha; i++) { - v_y_masks[i] = (i < v_ys || i >= v_ye) ? 0 : 0xffff; - v_x_masks[i] = (i < v_xs || i >= v_xe) ? 0 : 0xffff; - } - auto local_s = src - + mb * jcp.nb_ic * jcp.ih * jcp.iw - * jcp.ic_block - + y * jcp.iw * jcp.ic_block + x * jcp.ic_block; - auto local_w = wino_src + m * jcp.ic; - - src_trans_p.src = local_s; - src_trans_p.wino_src = local_w; - src_trans_p.v_y_masks = v_y_masks; - src_trans_p.v_x_masks = v_x_masks; - - src_trans_->ker_(&src_trans_p); - } - } - /* gemms */ - for (int tile_ij = 0; tile_ij < 16; tile_ij++) { - int offset = (tile_ij + ithr) % 16; - gemm_p.src = wino_src + jcp.inp_stride * offset; - gemm_p.dst = wino_dst + jcp.out_stride * offset; - gemm_p.wei = wei + jcp.wei_stride * offset; - - kernel_->ker_(&gemm_p); - } - - /* transformation from winograd domain to output tensor */ - for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) { - for (int x_in_block = 0; x_in_block < jcp.xb; - x_in_block += 2) { - unsigned short v_y_masks[2], v_x_masks[2]; - - int y = y_in_block + tile_y; - int x = x_in_block + tile_x; - int m = (y_in_block / 2) * (jcp.xb / 2) - + (x_in_block / 2); - -#pragma unroll(2) - for (int i = 0; i < jcp.m; i++) { - v_x_masks[i] = (x + i < jcp.ow) ? 0xffff : 0; - v_y_masks[i] = (y + i < jcp.oh) ? 
0xffff : 0; - } - auto local_d = dst - + mb * jcp.nb_oc * jcp.oh * jcp.ow - * jcp.oc_block - + y * jcp.ow * jcp.oc_block + x * jcp.oc_block; - auto local_w = wino_dst + m * jcp.oc; - - auto scales = oscales.scales_; - dst_trans_p.dst = local_d; - dst_trans_p.wino_dst = local_w; - dst_trans_p.v_y_masks = v_y_masks; - dst_trans_p.v_x_masks = v_x_masks; - - dst_trans_p.scales = scales; - dst_trans_p.bias = bia; - - dst_trans_->ker_(&dst_trans_p); - } - } - }); -} - -void jit_avx512_core_fp32_wino_conv_2x3_fwd_t::execute_forward_small_mb( - const float *src, const float *wei, const float *bia, float *dst, - const memory_tracking::grantor_t &scratchpad) const -{ - const auto &jcp = kernel_->jcp; - const auto &oscales = pd()->attr()->output_scales_; - - if (pd()->wants_padded_bias()) { - auto padded_bias = scratchpad.get(key_conv_padded_bias); - utils::array_copy(padded_bias, bia, jcp.oc_without_padding); - utils::array_set(padded_bias + jcp.oc_without_padding, 0.f, - jcp.oc - jcp.oc_without_padding); - bia = padded_bias; - } - - auto ptr_V = scratchpad.get(key_wino_V); - auto ptr_M = scratchpad.get(key_wino_M); - - for (int mb = 0; mb < jcp.mb; mb++) { - for (int tile_y = 0; tile_y < jcp.oh; tile_y += jcp.yb) { - for (int tile_x = 0; tile_x < jcp.ow; tile_x += jcp.xb) { - /* transformation of input tensor to winograd domain */ - parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2), - [&](int y_in_block_b, int x_in_block_b) { - int y_in_block = y_in_block_b * 2; - int x_in_block = x_in_block_b * 2; - - auto src_trans_p = jit_avx512_core_fp32_wino_conv_2x3_src_trans_t :: - call_params_t(); - - unsigned short v_y_masks[4], v_x_masks[4]; - - int y = y_in_block + tile_y; - int x = x_in_block + tile_x; - int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2); - - int v_ys = nstl::max(0, jcp.t_pad - y); - int v_ye = nstl::min( - jcp.alpha, nstl::max(0, jcp.ih + jcp.t_pad - y)); - - int v_xs = nstl::max(0, jcp.l_pad - x); - int v_xe = nstl::min( - jcp.alpha, nstl::max(0, jcp.iw + jcp.l_pad - x)); - -#pragma unroll(4) - for (int i = 0; i < jcp.alpha; i++) { - v_y_masks[i] = (i < v_ys || i >= v_ye) ? 0 : 0xffff; - v_x_masks[i] = (i < v_xs || i >= v_xe) ? 0 : 0xffff; - } - auto local_s = src - + mb * jcp.nb_ic * jcp.ih * jcp.iw * jcp.ic_block - + y * jcp.iw * jcp.ic_block + x * jcp.ic_block; - auto local_w = ptr_V + m * jcp.ic; - - src_trans_p.src = local_s; - src_trans_p.wino_src = local_w; - src_trans_p.v_y_masks = v_y_masks; - src_trans_p.v_x_masks = v_x_masks; - - src_trans_->ker_(&src_trans_p); - }); - - /* gemms */ - parallel_nd(16, jcp.n_chunks, [&](int tile_ij, int nnb) { - auto gemm_p = jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t :: - call_params_t(); - - gemm_p.src = ptr_V + jcp.inp_stride * tile_ij; - gemm_p.dst = ptr_M + jcp.out_stride * tile_ij - + nnb * jcp.n2_block * jcp.n_block; - gemm_p.wei = wei + jcp.wei_stride * tile_ij - + nnb * jcp.n2_block * jcp.n_block * jcp.K; - - kernel_->ker_(&gemm_p); - }); - - /* transformation from winograd domain to output tensor */ - - parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2), - [&](int y_in_block_b, int x_in_block_b) { - int y_in_block = y_in_block_b * 2; - int x_in_block = x_in_block_b * 2; - - auto dst_trans_p = jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t :: - call_params_t(); - - unsigned short v_y_masks[2], v_x_masks[2]; - - int y = y_in_block + tile_y; - int x = x_in_block + tile_x; - int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2); - -#pragma unroll(2) - for (int i = 0; i < jcp.m; i++) { - v_x_masks[i] = (x + i < jcp.ow) ? 
0xffff : 0; - v_y_masks[i] = (y + i < jcp.oh) ? 0xffff : 0; - } - auto local_d = dst - + mb * jcp.nb_oc * jcp.oh * jcp.ow * jcp.oc_block - + y * jcp.ow * jcp.oc_block + x * jcp.oc_block; - auto local_w = ptr_M + m * jcp.oc; - - auto scales = oscales.scales_; - dst_trans_p.dst = local_d; - dst_trans_p.wino_dst = local_w; - dst_trans_p.v_y_masks = v_y_masks; - dst_trans_p.v_x_masks = v_x_masks; - - dst_trans_p.scales = scales; - dst_trans_p.bias = bia; - - dst_trans_->ker_(&dst_trans_p); - }); - }}} -} - -} // namespace cpu -} // namespace impl -} // namespace mkldnn diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.hpp deleted file mode 100644 index 7e38b07f5..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.hpp +++ /dev/null @@ -1,144 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef CPU_JIT_AVX512_CORE_FP32_WINO_CONV_2x3_HPP -#define CPU_JIT_AVX512_CORE_FP32_WINO_CONV_2x3_HPP - -#include - -#include "c_types_map.hpp" -#include "mkldnn_thread.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -#include "cpu_convolution_pd.hpp" -#include "cpu_primitive.hpp" - -#include "jit_primitive_conf.hpp" -#include "jit_generator.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -struct jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t; -struct jit_avx512_core_fp32_wino_conv_2x3_src_trans_t; -struct jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t; - -struct jit_avx512_core_fp32_wino_conv_2x3_fwd_t : public cpu_primitive_t { - struct pd_t : public cpu_convolution_fwd_pd_t { - pd_t(engine_t *engine, const convolution_desc_t *adesc, - const primitive_attr_t *attr, - const typename pd_t::base_class *hint_fwd_pd) - : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) - , jcp_() {} - - DECLARE_COMMON_PD_T( - JIT_IMPL_NAME_HELPER("jit_fp32_wino_2x3:", avx512_core, ""), - jit_avx512_core_fp32_wino_conv_2x3_fwd_t); - - status_t init() { - bool ok = true - && desc()->prop_kind == prop_kind::forward_inference - && utils::one_of(desc()->alg_kind, - alg_kind::convolution_auto, - alg_kind::convolution_winograd) - && expect_data_types(data_type::f32, data_type::f32, - data_type::f32, data_type::f32, data_type::f32) - && set_default_formats(); - if (!ok) return status::unimplemented; - - memory_desc_t expect_wei_md = *weights_md(); - status_t jit_conf_result = jit_conf(expect_wei_md); - if (jit_conf_result != status::success) return jit_conf_result; - set_default_alg_kind(alg_kind::convolution_winograd); - - if (weights_md_.format_kind == format_kind::any) - weights_md_ = expect_wei_md; - if (weights_md_ != expect_wei_md) - return status::unimplemented; - - init_scratchpad(); - - return status::success; - } - - jit_conv_conf_2x3_wino_t jcp_; - - 
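/* init_scratchpad() below books one Winograd source buffer (V) and one destination
   buffer (M) per thread; each covers (yb/2)*(xb/2) + xb tile slots with 16 floats
   per input (V) or output (M) channel, the 16 being the alpha^2 Winograd components.
   For example (hypothetical sizes, not taken from this file): with yb = xb = 8,
   ic = oc = 64 and 8 threads, wino_size_offset is (8/2)*(8/2) + 8 = 24, so V and M
   each book 64 * 16 * 24 * 8 floats, i.e. 768 KiB apiece. */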
protected: - status_t jit_conf(memory_desc_t& expect_wei_md); - - void init_scratchpad() { - using namespace memory_tracking::names; - - auto scratchpad = scratchpad_registry().registrar(); - - int wino_size_offset = (jcp_.yb / 2) * (jcp_.xb / 2) + jcp_.xb; - - size_t V_sz = (size_t)jcp_.ic * 16 * wino_size_offset * jcp_.nthr; - scratchpad.book(key_wino_V, sizeof(float) * V_sz, PAGE_4K); - - size_t M_sz = (size_t)jcp_.oc * 16 * wino_size_offset * jcp_.nthr; - scratchpad.book(key_wino_M, sizeof(float) * M_sz, PAGE_4K); - - if (wants_padded_bias()) { - assert(jcp_.ngroups == 1); - scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp_.oc); - } - } - - bool set_default_formats() { - using namespace format_tag; - return set_default_formats_common(nChw16c, any, nChw16c); - } - }; - - jit_avx512_core_fp32_wino_conv_2x3_fwd_t(const pd_t *apd); - ~jit_avx512_core_fp32_wino_conv_2x3_fwd_t(); - - virtual status_t execute(const exec_ctx_t &ctx) const override { - auto src = CTX_IN_MEM(const float *, MKLDNN_ARG_SRC); - auto wei = CTX_IN_MEM(const float *, MKLDNN_ARG_WEIGHTS); - auto bia = CTX_IN_MEM(const float *, MKLDNN_ARG_BIAS); - auto dst = CTX_OUT_MEM(float *, MKLDNN_ARG_DST); - - if (pd()->jcp_.small_mb) - execute_forward_small_mb(src, wei, bia, dst, this->scratchpad(ctx)); - else - execute_forward_mbN(src, wei, bia, dst, this->scratchpad(ctx)); - - return status::success; - } - -private: - void execute_forward_small_mb(const float *src, const float *wei, - const float *bia, float *dst, - const memory_tracking::grantor_t &scratchpad) const; - void execute_forward_mbN(const float *src, const float *wei, - const float *bia, float *dst, - const memory_tracking::grantor_t &scratchpad) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - - jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t *kernel_; - jit_avx512_core_fp32_wino_conv_2x3_src_trans_t *src_trans_; - jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t *dst_trans_; -}; - -} -} -} - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.cpp deleted file mode 100644 index 96325e3ad..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.cpp +++ /dev/null @@ -1,1020 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#ifdef __INTEL_COMPILER -#include -#endif - -#include "mkldnn_types.h" - -#include "c_types_map.hpp" -#include "mkldnn_thread.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -#include "jit_avx512_core_fp32_wino_conv_4x3.hpp" - -#ifndef _MSC_VER -#define pragma_unroll _Pragma("unroll") -#else -#define pragma_unroll -#endif - - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace mkldnn::impl::status; -using namespace mkldnn::impl::memory_tracking::names; -using namespace mkldnn::impl::utils; - -template -void _jit_avx512_core_fp32_wino_conv_4x3_t -::weight_transform_data(const jit_conv_winograd_conf_t &jcp, - float *wp, float *twp) const -{ - float G[] = {0.26890756302521f, 0.688403361344538f, 0.119514472455649f, - 1.13777777777778f, 0.430252100840336f, 0.179271708683473f}; - const int kh = 3; - const int kw = 3; - float Fw[alpha][alpha][simd_w][simd_w]; - float F[kh][kw][simd_w][simd_w]; - float T[alpha][3][simd_w]; - auto p = jit_wino_transform_call_s(); - - p.src = wp; - p.dst = twp; - p.G = G; - p.M = F; - p.Mw = Fw; - p.T = T; - - kernel_->weights_transform_data_ker(&p); -} - -template -void _jit_avx512_core_fp32_wino_conv_4x3_t::output_transform_data -(int image, const jit_conv_winograd_conf_t &jcp, - const post_ops_t &p_ops, float *toutp, float *pout_b, float *bias) const { - - float G[] = {0.625f, 1.5f, 0.390625f, 2.25f, 0.244140625f, 3.375f}; - float Ow[alpha][alpha][simd_w]; - float O[tile_size][tile_size][simd_w]; - float T[tile_size][alpha][simd_w]; - - auto p = jit_wino_transform_call_s(); - p.src = toutp; - p.dst = pout_b; - p.G = G; - p.M = O; - p.Mw = Ow; - p.T = T; - p.bias = bias; - - int tile_base_index = image * jcp.itiles * jcp.jtiles; - int tile_block_ur = tile_base_index % jcp.tile_block_ur; - int nb_tile_block_ur = - (tile_base_index / jcp.tile_block_ur) % jcp.nb_tile_block_ur; - int tile_block = - (tile_base_index / jcp.tile_block_ur) / jcp.nb_tile_block_ur; - - for (int tj = 0; tj < jcp.jtiles; tj++) { - for (int ti = 0; ti < jcp.itiles; ti++) { - - p.tile_block_ur = tile_block_ur; - p.nb_tile_block_ur = nb_tile_block_ur; - p.tile_block = tile_block; - p.tj = tj; - p.ti = ti; - - kernel_->output_transform_data_ker(&p); - - tile_block_ur++; - if (tile_block_ur >= jcp.tile_block_ur) { - tile_block_ur = 0; - nb_tile_block_ur++; - } - if (nb_tile_block_ur >= jcp.nb_tile_block_ur) { - nb_tile_block_ur = 0; - tile_block++; - } - } - } -} - -template -void _jit_avx512_core_fp32_wino_conv_4x3_t -::output_transform_tileblock_data(int tile_block, - const jit_conv_winograd_conf_t &jcp, const post_ops_t &p_ops, - float *toutp, float *outp, float *bias) const { - - float G[] = {0.625f, 1.5f, 0.390625f, 2.25f, 0.244140625f, 3.375f}; - float Ow[alpha][alpha][simd_w]; - float O[tile_size][tile_size][simd_w]; - float T[tile_size][alpha][simd_w]; - - auto p = jit_wino_transform_call_s(); - p.src = toutp; - p.dst = outp; - p.G = G; - p.M = O; - p.Mw = Ow; - p.T = T; - p.bias = bias; - - int outw = is_fwd ? jcp.ow : jcp.iw; - int outh = is_fwd ? 
jcp.oh : jcp.ih; - - int tile_index = tile_block * jcp.nb_tile_block_ur * jcp.tile_block_ur; - - for (int nb_tile_block_ur = 0; - nb_tile_block_ur < jcp.nb_tile_block_ur; - nb_tile_block_ur++) { - - for (int tile_block_ur = 0; tile_block_ur < jcp.tile_block_ur; - tile_block_ur++) { - int img = tile_index / (jcp.jtiles * jcp.itiles); - int ti = tile_index % jcp.itiles; - int tj = (tile_index / jcp.itiles) % jcp.jtiles; - - p.tile_block_ur = tile_block_ur; - p.nb_tile_block_ur = nb_tile_block_ur; - p.tile_block = tile_block; - p.tj = tj; - p.ti = ti; - p.dst = outp + img * (jcp.dimM / jcp.dimM_simd_block) - * outh * outw * jcp.dimM_simd_block; - - kernel_->output_transform_data_ker(&p); - - tile_index++; - } - } -} - - -template -void _jit_avx512_core_fp32_wino_conv_4x3_t - ::input_transform_data(int image, const jit_conv_winograd_conf_t &jcp, - float *inp, float *tinp) const -{ - float G[] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f, - 0.625f, -0.625f, 1.5f, -1.5f, -2.640625f}; - - float Iw[alpha][alpha][simd_w]; - float I[alpha][alpha][simd_w]; - float T[alpha][alpha][simd_w]; - - auto p = jit_wino_transform_call_s(); - - p.src = inp; - p.dst = tinp; - p.G = G; - p.M = I; - p.Mw = Iw; - p.T = T; - - int tile_base_index = image * jcp.itiles * jcp.jtiles; - int tile_block_ur = tile_base_index % jcp.tile_block_ur; - int nb_tile_block_ur = - (tile_base_index / jcp.tile_block_ur) % jcp.nb_tile_block_ur; - int tile_block = - (tile_base_index / jcp.tile_block_ur) / jcp.nb_tile_block_ur; - - for (int tj = 0; tj < jcp.jtiles; tj++) { - for (int ti = 0; ti < jcp.itiles; ti++) { - - p.tile_block_ur = tile_block_ur; - p.nb_tile_block_ur = nb_tile_block_ur; - p.tile_block = tile_block; - p.tj = tj; - p.ti = ti; - - kernel_->input_transform_data_ker(&p); - - tile_block_ur++; - if (tile_block_ur >= jcp.tile_block_ur) { - tile_block_ur = 0; - nb_tile_block_ur++; - } - if (nb_tile_block_ur >= jcp.nb_tile_block_ur) { - nb_tile_block_ur = 0; - tile_block++; - } - } - } -} - -template -void _jit_avx512_core_fp32_wino_conv_4x3_t - ::input_transform_tileblock_data(int tile_block, - const jit_conv_winograd_conf_t &jcp, - float *inp, float *tinp) const -{ - float G[] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f, - 0.625f, -0.625f, 1.5f, -1.5f, -2.640625f}; - float Iw[alpha][alpha][simd_w]; - float I[alpha][alpha][simd_w]; - float T[alpha][alpha][simd_w]; - - const int inph = is_fwd ? jcp.ih : jcp.oh; - const int inpw = is_fwd ? 
jcp.iw : jcp.ow; - - array_offset_calculator input(inp, - jcp.mb, jcp.dimK / simd_w, inph, inpw, simd_w); - array_offset_calculator output(tinp, - alpha, alpha, - jcp.dimN_block, jcp.dimK_nb_block, jcp.dimK_block, - jcp.dimN_reg_block, jcp.dimK_reg_block); - - auto p = jit_wino_transform_call_s(); - - p.dst = tinp; - p.G = G; - p.M = I; - p.Mw = Iw; - p.T = T; - - - int tile_index = tile_block * jcp.nb_tile_block_ur * jcp.tile_block_ur; - - for (int nb_tile_block_ur = 0; - nb_tile_block_ur < jcp.nb_tile_block_ur; - nb_tile_block_ur++) { - - for (int tile_block_ur = 0; tile_block_ur < jcp.tile_block_ur; - tile_block_ur++) { - - int img = tile_index / (jcp.jtiles * jcp.itiles); - int ti = tile_index % jcp.itiles; - int tj = (tile_index / jcp.itiles) % jcp.jtiles; - float *pinp_b = &(input(img, 0, 0, 0, 0)); - - p.src = pinp_b; - p.tile_block_ur = tile_block_ur; - p.nb_tile_block_ur = nb_tile_block_ur; - p.tj = tj; - p.ti = ti; - - kernel_->input_transform_data_ker(&p); - - tile_index++; - } - } -} - -template -void _jit_avx512_core_fp32_wino_conv_4x3_t::_execute_data_W_S_G_D( - float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr, - const memory_tracking::grantor_t &scratchpad) const { - const auto &jcp = kernel_->jcp; - const auto &p_ops = attr_->post_ops_; - - const int inph = is_fwd ? jcp.ih : jcp.oh; - const int inpw = is_fwd ? jcp.iw : jcp.ow; - const int outh = is_fwd ? jcp.oh : jcp.ih; - const int outw = is_fwd ? jcp.ow : jcp.iw; - - /* Notation: - FWD: dimM:oc, dimN:ntiles, dimK:ic, - BWD: dimM:ic, dimN:ntiles, dimK:oc, - FWD/BWD: V: src/diff_dst transform, U:weight transform, - M:dst/diff_src transform */ - array_offset_calculator input(inp_ptr, - jcp.mb, jcp.dimK/jcp.dimK_reg_block, inph, inpw, - jcp.dimK_reg_block); - array_offset_calculator output(out_ptr, - jcp.mb, jcp.dimM/jcp.dimM_simd_block, outh, outw, - jcp.dimM_simd_block); - array_offset_calculator weights(wei_ptr, - jcp.oc/jcp.oc_simd_block, jcp.ic/jcp.ic_simd_block, jcp.kh, jcp.kw, - jcp.ic_simd_block, jcp.oc_simd_block); - array_offset_calculator bias(bias_ptr, - jcp.dimM/jcp.dimM_simd_block, jcp.dimM_simd_block); - - array_offset_calculator M(is_fwd - ? scratchpad.template get(key_wino_M) - : scratchpad.template get(key_wino_V), - jcp.dimN_nb_block, jcp.dimM_nb_block, - alpha, alpha, - jcp.dimN_block, jcp.dimM_block * jcp.dimM_reg_block, - jcp.dimN_reg_block, jcp.dimM_simd_block); - - auto wino_wei = (jcp.prop_kind == prop_kind::forward_inference) - ? wei_ptr - : scratchpad.template get(key_wino_U); - - array_offset_calculator U(wino_wei, - jcp.dimM_nb_block, - alpha, alpha, - jcp.dimK_nb_block, - jcp.dimM_block * jcp.dimM_reg_block, jcp.dimK_block, - jcp.dimK_reg_block, jcp.dimM_simd_block); - array_offset_calculator V(is_fwd - ? 
scratchpad.template get(key_wino_V) - : scratchpad.template get(key_wino_M), - jcp.dimN_nb_block, alpha, alpha, - jcp.dimN_block, jcp.dimK_nb_block, - jcp.dimK_block, jcp.dimN_reg_block, jcp.dimK_reg_block); - - const bool wants_padded_bias = jcp.with_bias - && jcp.oc_without_padding != jcp.oc; - float last_slice_bias[simd_w] = {0}; - if (wants_padded_bias) { - for (int oc = 0; oc < jcp.oc_without_padding % jcp.oc_simd_block; ++oc) - last_slice_bias[oc] = bias(jcp.dimM / jcp.dimM_simd_block - 1, oc); - } - - { - - parallel_nd(jcp.mb, jcp.dimK_nb_block, jcp.dimK_block, - [&](int img, int K_blk1, int K_blk2) { - input_transform_data(img, jcp, - &(input(img, K_blk1 * jcp.dimK_block + K_blk2, - 0, 0, 0)), - &(V(0, 0, 0, 0, K_blk1, K_blk2, 0, 0))); - }); - - if (jcp.prop_kind != prop_kind::forward_inference) { - parallel_nd(jcp.nb_oc, jcp.nb_ic, (jcp.oc_block * jcp.oc_reg_block), - (jcp.ic_block * jcp.ic_reg_block), - [&](int ofm1, int ifm1, int ofm2, int ifm2) { - float *U_base_ptr = is_fwd - ? &(U(ofm1, 0, 0, ifm1, ofm2, ifm2, 0, 0)) - : &(U(ifm1, 0, 0, ofm1, ifm2, ofm2, 0, 0)); - weight_transform_data(jcp, - &(weights( - ofm1 * jcp.oc_block * jcp.oc_reg_block + ofm2, - ifm1 * jcp.ic_block * jcp.ic_reg_block + ifm2, - 0, 0, 0, 0)), - U_base_ptr); - }); - } - - parallel_nd(jcp.dimN_nb_block, alpha, alpha, jcp.dimM_nb_block, - [&](int N_blk1, int oj, int oi, int M_blk1) { - for (int K_blk1 = 0; K_blk1 < jcp.dimK_nb_block; - K_blk1++) - for (int N_blk2 = 0; N_blk2 < jcp.dimN_block; N_blk2++) - kernel_->gemm_loop_ker( - (float *)&(M(N_blk1, M_blk1, oj, oi, - N_blk2, 0, 0, 0)), - (const float *)&(U(M_blk1, oj, oi, - K_blk1, 0, 0, 0, 0)), - (const float *)&(V(N_blk1, oj, oi, - N_blk2, K_blk1, 0, 0, 0)), K_blk1); - }); - - parallel_nd(jcp.mb, jcp.dimM_nb_block, (jcp.dimM_block * jcp.dimM_reg_block), - [&](int img, int M_blk1, int M_blk2) { - const int M_blk = - M_blk1 * jcp.dimM_block * jcp.dimM_reg_block + M_blk2; - - float *bias_ptr = wants_padded_bias - && M_blk == jcp.dimM / jcp.dimM_simd_block - 1 - ? last_slice_bias : &bias(M_blk, 0); - output_transform_data(img, jcp, p_ops, - &(M(0, M_blk1, 0, 0, 0, M_blk2, 0, 0)), - &(output(img, M_blk, 0, 0, 0)), bias_ptr); - }); - - } -} - -template -void _jit_avx512_core_fp32_wino_conv_4x3_t::_execute_data_W_SGD( - float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr, - const memory_tracking::grantor_t &scratchpad) const { - const auto &jcp = kernel_->jcp; - const auto &p_ops = attr_->post_ops_; - - const int inph = is_fwd ? jcp.ih : jcp.oh; - const int inpw = is_fwd ? jcp.iw : jcp.ow; - const int outh = is_fwd ? jcp.oh : jcp.ih; - const int outw = is_fwd ? jcp.ow : jcp.iw; - - array_offset_calculator input(inp_ptr, - jcp.mb, jcp.dimK/jcp.dimK_reg_block, inph, inpw, jcp.dimK_reg_block); - array_offset_calculator output(out_ptr, - jcp.mb, jcp.dimM/jcp.dimM_simd_block, outh, outw, jcp.dimM_simd_block); - array_offset_calculator weights(wei_ptr, - jcp.oc/jcp.oc_simd_block, jcp.ic/jcp.ic_simd_block, jcp.kh, jcp.kw, - jcp.ic_simd_block, jcp.oc_simd_block); - array_offset_calculator bias(bias_ptr, - jcp.oc/jcp.oc_simd_block, jcp.oc_simd_block); - - auto wino_wei = (jcp.prop_kind == prop_kind::forward_inference) - ? wei_ptr - : scratchpad.template get(key_wino_U); - - array_offset_calculator U(wino_wei, - jcp.dimM_nb_block, - alpha, alpha, - jcp.dimK_nb_block, - jcp.dimM_block * jcp.dimM_reg_block, jcp.dimK_block, - jcp.dimK_reg_block, jcp.dimM_simd_block); - - array_offset_calculator M(is_fwd - ? 
scratchpad.template get(key_wino_M) - : scratchpad.template get(key_wino_V), - 0, jcp.dimM_nb_block, alpha, alpha, - jcp.dimN_block, jcp.dimM_block * jcp.dimM_reg_block, - jcp.dimN_reg_block, jcp.dimM_simd_block); - array_offset_calculator V(is_fwd - ? scratchpad.template get(key_wino_V) - : scratchpad.template get(key_wino_M), - 0, alpha, alpha, jcp.dimN_block, - jcp.dimK_nb_block, jcp.dimK_block, - jcp.dimN_reg_block, jcp.dimK_reg_block); - - const bool wants_padded_bias = jcp.with_bias - && jcp.oc_without_padding != jcp.oc; - float last_slice_bias[simd_w] = {0}; - if (wants_padded_bias) { - for (int oc = 0; oc < jcp.oc_without_padding % jcp.oc_simd_block; ++oc) - last_slice_bias[oc] = bias(jcp.dimM / jcp.dimM_simd_block - 1, oc); - } - - if (jcp.prop_kind != prop_kind::forward_inference) { - - parallel_nd(jcp.nb_oc, jcp.nb_ic, (jcp.oc_block * jcp.oc_reg_block), (jcp.ic_block * jcp.ic_reg_block), - [&](int ofm1, int ifm1, int ofm2, int ifm2) { - float *U_base_ptr = is_fwd - ? &(U(ofm1, 0, 0, ifm1, ofm2, ifm2, 0, 0)) - : &(U(ifm1, 0, 0, ofm1, ifm2, ofm2, 0, 0)); - weight_transform_data(jcp, - &(weights( - ofm1 * jcp.oc_block * jcp.oc_reg_block + ofm2, - ifm1 * jcp.ic_block * jcp.ic_reg_block + ifm2, - 0, 0, 0, 0)), - U_base_ptr); - }); - } - - parallel_nd(jcp.tile_block, [&](int tile_block) { - int ithr = mkldnn_get_thread_num(); - - for (int K_blk1 = 0; K_blk1 < jcp.dimK_nb_block; K_blk1++) { - for (int K_blk2 = 0; K_blk2 < jcp.dimK_block; K_blk2++) { - - input_transform_tileblock_data( - tile_block, jcp, - &(input(0, K_blk1 * jcp.dimK_block + K_blk2, 0, 0, 0)), - &(V(ithr, 0, 0, 0, K_blk1, K_blk2, 0, 0))); - } - } - - for (int oj = 0; oj < alpha; oj++) { - for (int oi = 0; oi < alpha; oi++) { - for (int M_blk1 = 0; M_blk1 < jcp.dimM_nb_block; M_blk1++) - for (int K_blk1 = 0; K_blk1 < jcp.dimK_nb_block; K_blk1++) - for (int N_blk = 0; N_blk < jcp.dimN_block; N_blk++) - kernel_->gemm_loop_ker( - (float *)&(M(ithr, M_blk1, oj, oi, - N_blk, 0, 0, 0)), - (const float *)&(U(M_blk1, oj, oi, K_blk1, - 0, 0, 0, 0)), - (const float *)&(V(ithr, oj, oi, - N_blk, K_blk1, 0, 0, 0)), K_blk1); - } - } - - for (int M_blk1 = 0; M_blk1 < jcp.dimM_nb_block; M_blk1++) { - for (int M_blk2 = 0; M_blk2 < jcp.dimM_block * jcp.dimM_reg_block; - M_blk2++) { - const int M_blk = - M_blk1 * jcp.dimM_block * jcp.dimM_reg_block + M_blk2; - - float *bias_ptr = wants_padded_bias - && M_blk == jcp.dimM / jcp.dimM_simd_block - 1 - ? 
last_slice_bias : &bias(M_blk, 0); - - output_transform_tileblock_data(tile_block, jcp, p_ops, - &(M(ithr, M_blk1, 0, 0, 0, M_blk2, 0, 0)), - &(output(0, M_blk, 0, 0, 0)), bias_ptr); - } - } - }); -} - -template struct _jit_avx512_core_fp32_wino_conv_4x3_t; -template struct _jit_avx512_core_fp32_wino_conv_4x3_t; - -namespace { - -void subarray_sum(size_t num_arrs, float *output, size_t nelems, - float *input_ptrs[], size_t input_starts[], size_t input_ends[]) { - using namespace nstl; - const size_t block_size = 16 * 1024 / sizeof(float); - const size_t blocks_number = nelems / block_size; - const size_t tail = nelems % block_size; - -PRAGMA_OMP(parallel) - { - const int ithr = mkldnn_get_thread_num(); - const int nthr = mkldnn_get_num_threads(); - size_t start{ 0 }, end{ 0 }; - balance211(blocks_number, nthr, ithr, start, end); - - for (size_t nb = start; nb < end; ++nb) { - size_t start_e = nb * block_size; - size_t end_e = start_e + block_size; - size_t input_start = max(start_e, min(input_starts[0], end_e)); - size_t input_end = max(start_e, min(input_ends[0], end_e)); - - PRAGMA_OMP_SIMD() - for (size_t e = start_e; e < input_start; e++) { - output[e] = 0.f; - } - - PRAGMA_OMP_SIMD() - for (size_t e = input_start; e < input_end; e++) { - output[e] = input_ptrs[0][e]; - } - - PRAGMA_OMP_SIMD() - for (size_t e = input_end; e < end_e; e++) { - output[e] = 0.f; - } - - for (size_t a = 1; a < num_arrs; a++) { - input_start = max(start_e, input_starts[a]); - input_end = min(input_ends[a], end_e); - - PRAGMA_OMP_SIMD() - for (size_t e = input_start; e < input_end; e++) { - output[e] += input_ptrs[a][e]; - } - } - } - - if (tail != 0 && ithr == nthr - 1) { - size_t start_e = nelems - tail; - size_t end_e = nelems; - size_t input_start = max(start_e, min(input_starts[0], end_e)); - size_t input_end = max(start_e, min(input_ends[0], end_e)); - - PRAGMA_OMP_SIMD() - for (size_t e = start_e; e < input_start; e++) { - output[e] = 0.f; - } - - PRAGMA_OMP_SIMD() - for (size_t e = input_start; e < input_end; e++) { - output[e] = input_ptrs[0][e]; - } - - PRAGMA_OMP_SIMD() - for (size_t e = input_end; e < end_e; e++) { - output[e] = 0.f; - } - - for (size_t a = 1; a < num_arrs; a++) { - input_start = max(start_e, input_starts[a]); - input_end = min(input_ends[a], end_e); - - PRAGMA_OMP_SIMD() - for (size_t e = input_start; e < input_end; e++) { - output[e] += input_ptrs[a][e]; - } - } - } - } -} - -const int max_threads_number = 1024; - -// Sum to the first buffer array -void array_sum(size_t num_arrs, float *output, - size_t nelems, float *input_ptrs[], bool reduce_to_first = true) { - const size_t block_size = 16 * 1024 / sizeof(float); - const size_t blocks_number = nelems / block_size; - const size_t tail = nelems % block_size; - -PRAGMA_OMP(parallel) - { - const size_t ithr = mkldnn_get_thread_num(); - const size_t nthr = mkldnn_get_num_threads(); - size_t start{ 0 }, end{ 0 }; - balance211(blocks_number, nthr, ithr, start, end); - - for (size_t nb = start; nb < end; ++nb) { - size_t start_e = nb * block_size; - size_t end_e = start_e + block_size; - if (!reduce_to_first) { - PRAGMA_OMP_SIMD() - for (size_t e = start_e; e < end_e; e++) { - output[e] = input_ptrs[0][e]; - } - } - for (size_t a = 1; a < num_arrs; a++) { - PRAGMA_OMP_SIMD() - for (size_t e = start_e; e < end_e; e++) { - output[e] += input_ptrs[a][e]; - } - } - } - - if (tail != 0 && ithr == nthr - 1) { - size_t start_e = nelems - tail; - size_t end_e = nelems; - if (!reduce_to_first) { - PRAGMA_OMP_SIMD() - for (size_t e = start_e; 
e < end_e; e++) { - output[e] = input_ptrs[0][e]; - } - } - for (size_t a = 1; a < num_arrs; a++) { - PRAGMA_OMP_SIMD() - for (size_t e = start_e; e < end_e; e++) { - output[e] += input_ptrs[a][e]; - } - } - } - } -} -} //bwdw namespace - -void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t:: -_execute_backward_weights_SDGtWo(const float *ptr_src, - const float *ptr_diff_dst, float *ptr_diff_weights, - float *ptr_diff_bias, - const memory_tracking::grantor_t &scratchpad) const { - const auto &jcp = kernel_->jcp; - const int nthreads = jcp.nthr; - - array_offset_calculator src((float *)ptr_src, - jcp.mb, jcp.ic / simd_w, jcp.ih, jcp.iw, simd_w); - array_offset_calculator diff_dst((float *)ptr_diff_dst, - jcp.mb, jcp.oc / simd_w, jcp.oh, jcp.ow, simd_w); - array_offset_calculator diff_weights(ptr_diff_weights, - jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w); - - array_offset_calculator Us(scratchpad.get(key_wino_U), - 0, alpha, alpha, - jcp.oc_block, jcp.ic_block, - jcp.ic_simd_block, - jcp.oc_reg_block, - jcp.oc_simd_block); - - const int U_sz = nthreads * alpha * alpha * jcp.oc / jcp.nb_oc - * jcp.ic / jcp.nb_ic; - array_offset_calculatordiff_weights_prv( - scratchpad.get(key_wino_U) + U_sz, - 0, jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w); - - array_offset_calculator M(scratchpad.get(key_wino_M), - 0, alpha, alpha, - jcp.oc_block, - jcp.nb_tile_block_ur, - jcp.tile_block_ur, - jcp.oc_reg_block, - jcp.oc_simd_block); - - array_offset_calculator V(scratchpad.get(key_wino_V), - 0, alpha, alpha, - jcp.ic_block, - jcp.nb_tile_block_ur, - jcp.tile_block_ur, - jcp.ic_simd_block); - - array_offset_calculator diff_bias_prv( - scratchpad.get(key_conv_bia_reduction), nthreads, jcp.oc); - - auto trans_ker_p = jit_wino_transform_call_s(); - float I[alpha][alpha][simd_w]; - float T[alpha][alpha][simd_w]; - float G_I_3x3_4x4[9] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f, - 0.625f, -0.625f, 1.5f, -1.5f, -2.640625f}; - float G_W_3x3_4x4[8] = {0.26890756302521f, -0.688403361344538f, 0.119514472455649f, - 0.430252100840336f, 0.168067226890756f, 0.179271708683473f, 0.403361344537815f, - 1.13777777777778f}; - float G_O_3x3_4x4[4] = {2.25f, 0.625f, 1.5f, 0.390625f}; - -PRAGMA_OMP(parallel num_threads(nthreads) firstprivate(trans_ker_p, I, T)) -{ - if (jcp.with_bias) { - parallel_nd_in_omp(nthreads, jcp.oc / simd_w, - [&](int ithr, int ofm){ - float *pdbias = &(diff_bias_prv(ithr, ofm * simd_w)); - PRAGMA_OMP_SIMD() - for (int v = 0; v < simd_w; v++) { - pdbias[v] = 0.0f; - } - }); - } - - int ithr = mkldnn_get_thread_num(); - for (int ifm1 = 0; ifm1 < jcp.nb_ic; ++ifm1) { - int first_tblk = 0; -PRAGMA_OMP(for) - for (int tblk1 = 0; tblk1 < jcp.tile_block; ++tblk1) { - int tile_index = tblk1 * jcp.nb_tile_block_ur * jcp.tile_block_ur; - int img = tile_index / (jcp.itiles * jcp.jtiles); - trans_ker_p.ti = tile_index % jcp.itiles; - trans_ker_p.tj = (tile_index / jcp.itiles) % jcp.jtiles; - trans_ker_p.M = I; - trans_ker_p.T = T; - trans_ker_p.G = G_I_3x3_4x4; - for (int ifm2 = 0; ifm2 < jcp.ic_block; ++ifm2) { - int ifm = ifm1 * jcp.ic_block + ifm2; - trans_ker_p.src = (float *)&(src(img, ifm, 0, 0, 0)); - trans_ker_p.dst = (float *)&(V(ithr, 0, 0, ifm2, 0, 0, 0)); - kernel_->src_transform(&trans_ker_p); - } - - for (int ofm1 = 0; ofm1 < jcp.nb_oc; ++ofm1) { - trans_ker_p.G = G_W_3x3_4x4; - for (int ofm2 = 0; ofm2 < jcp.oc_block; ++ofm2) { - int ofm = (ofm1 * jcp.oc_block + ofm2) * jcp.oc_reg_block; - trans_ker_p.src = (float *)&(diff_dst(img, ofm, 0, 0, 0)); - 
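/* The transformed diff_dst tile is written into this thread's M buffer next; when a
   bias is being computed, its per-thread reduction is folded into the transform only
   on the first input-channel block (ifm1 == 0), so each output element contributes to
   the bias exactly once. */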
trans_ker_p.dst = (float *)&(M(ithr, 0, 0, ofm2, 0, 0, 0, 0)); - if (jcp.with_bias && ifm1 == 0) { - trans_ker_p.bias = (float *)&(diff_bias_prv(ithr, ofm * simd_w)); - kernel_->diff_dst_transform_wbias(&trans_ker_p); - } else { - kernel_->diff_dst_transform(&trans_ker_p); - } - } - - for (int oj = 0; oj < alpha; ++oj) { - for (int oi = 0; oi < alpha; ++oi) { - kernel_->gemm_loop_ker_first_iter( - &(Us(ithr, oj, oi, 0, 0, 0, 0, 0)), - &(M(ithr, oj, oi, 0, 0, 0, 0, 0)), - &(V(ithr, oj, oi, 0, 0, 0, 0))); - } - } - trans_ker_p.G = G_O_3x3_4x4; - for (int ofm2 = 0; ofm2 < jcp.oc_block; ++ofm2) { - for (int ofm3 = 0; ofm3 < jcp.oc_reg_block; ++ofm3) { - int ofm = (ofm1 * jcp.oc_block + ofm2) * jcp.oc_reg_block - + ofm3; - for (int ifm2 = 0; ifm2 < jcp.ic_block; ++ifm2) { - int ifm = ifm1 * jcp.ic_block + ifm2; - trans_ker_p.src = (float *)&(Us(ithr, 0, 0, - ofm2, ifm2, 0, ofm3, 0)); - trans_ker_p.dst = (float *)&(diff_weights_prv(ithr, - ofm, ifm, 0, 0, 0, 0)); - if (first_tblk == 0) { - kernel_->diff_weights_transform(&trans_ker_p); - } else { - kernel_->diff_weights_transform_accum(&trans_ker_p); - } - } - } - } - } - ++first_tblk; - } - } -} - - // Reduce diff-weights - { - float *output = ptr_diff_weights; - float *input_base = scratchpad.get(key_wino_U) + U_sz; - int nelems = jcp.oc * jcp.ic * jcp.kh * jcp.kw; - float *input_ptrs[max_threads_number]; - for (int i = 0; i < nthreads; ++i) { - input_ptrs[i] = input_base + nelems * i; - } - array_sum(nthreads, output, nelems, input_ptrs, false); - - if (jcp.with_bias) { - output = ptr_diff_bias; - input_base = scratchpad.get(key_conv_bia_reduction); - for (int i = 0; i < nthreads; ++i) { - input_ptrs[i] = input_base + jcp.oc * i; - } - array_sum(nthreads, output, jcp.oc_without_padding, input_ptrs, - false); - } - } -} - -void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t:: -_execute_backward_weights_S_D_Giot_W(const float *ptr_src, - const float *ptr_diff_dst, float *ptr_diff_weights, - float *ptr_diff_bias, - const memory_tracking::grantor_t &scratchpad) const { - const auto &jcp = kernel_->jcp; - const int nthreads = jcp.nthr; - - array_offset_calculator src((float *)ptr_src, - jcp.mb, jcp.ic / simd_w, jcp.ih, jcp.iw, simd_w); - array_offset_calculator diff_dst((float *)ptr_diff_dst, - jcp.mb, jcp.oc / simd_w, jcp.oh, jcp.ow, simd_w); - array_offset_calculator diff_weights((float *)ptr_diff_weights, - jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w); - array_offset_calculator diff_bias((float *)ptr_diff_bias, jcp.oc); - - array_offset_calculator U(scratchpad.get(key_wino_U), - jcp.nb_ic, jcp.nb_oc, - alpha, alpha, - jcp.oc_block, jcp.ic_block, - jcp.ic_simd_block, - jcp.oc_reg_block, - jcp.oc_simd_block); - - const int U_size = jcp.oc * jcp.ic * alpha * alpha; - array_offset_calculator Us( - scratchpad.get(key_wino_U) + U_size, - 0, jcp.nb_ic, jcp.nb_oc, - alpha, alpha, - jcp.oc_block, jcp.ic_block, - jcp.ic_simd_block, - jcp.oc_reg_block, - jcp.oc_simd_block); - - array_offset_calculator M(scratchpad.get(key_wino_M), - jcp.nb_oc, - jcp.tile_block, - alpha, alpha, - jcp.oc_block, - jcp.nb_tile_block_ur, - jcp.tile_block_ur , - jcp.oc_reg_block, - jcp.oc_simd_block); - - array_offset_calculator V(scratchpad.get(key_wino_V), - jcp.nb_ic, - jcp.tile_block, - alpha, alpha, - jcp.ic_block, - jcp.nb_tile_block_ur, jcp.tile_block_ur, - jcp.ic_simd_block); - - array_offset_calculator diff_bias_prv( - scratchpad.get(key_conv_bia_reduction), nthreads, jcp.oc); - - size_t input_starts[max_threads_number] = {0}; - size_t 
input_ends[max_threads_number] = {0}; - size_t first_tblk = 0; - - auto trans_ker_p = jit_wino_transform_call_s(); - float G_I_3x3_4x4[9] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f, - 0.625f, -0.625f, 1.5f, -1.5f, -2.640625f}; - float G_W_3x3_4x4[8] = {0.26890756302521f, -0.688403361344538f, - 0.119514472455649f, 0.430252100840336f, 0.168067226890756f, - 0.179271708683473f, 0.403361344537815f, 1.13777777777778f}; - float G_O_3x3_4x4[4] = {2.25f, 0.625f, 1.5f, 0.390625f}; - float I[alpha][alpha][simd_w]; - float T[alpha][alpha][simd_w]; - -PRAGMA_OMP(parallel firstprivate(first_tblk, trans_ker_p, I, T)) -{ - if (jcp.with_bias) { - parallel_nd_in_omp(nthreads, jcp.oc, [&](int ithr, int ofm) { - diff_bias_prv(ithr, ofm) = 0.0f; - }); - } - - trans_ker_p.G = G_I_3x3_4x4; - trans_ker_p.M = I; - trans_ker_p.T = T; - - parallel_nd_in_omp(jcp.nb_ic, jcp.ic_block, jcp.mb, - [&](int ifm1, int ifm2, int img){ - size_t ifm = ifm1 * jcp.ic_block + ifm2; - size_t tile_base_index = img * (jcp.itiles * jcp.jtiles); - size_t tblk3 = tile_base_index % jcp.tile_block_ur; - size_t tblk2 = (tile_base_index / jcp.tile_block_ur) - % jcp.nb_tile_block_ur; - size_t tblk1 = (tile_base_index / jcp.tile_block_ur) - / jcp.nb_tile_block_ur; - trans_ker_p.tile_count = tblk2 * jcp.tile_block_ur + tblk3; - trans_ker_p.src = (float *)&(src(img, ifm, 0, 0, 0)); - trans_ker_p.dst = (float *)&(V(ifm1, tblk1, 0, 0, ifm2, 0, 0, 0)); - kernel_->src_transform(&trans_ker_p); - }); - - int ithr = mkldnn_get_thread_num(); - trans_ker_p.G = G_W_3x3_4x4; - parallel_nd_in_omp(jcp.nb_oc, jcp.oc_block, jcp.mb, - [&](int ofm1, int ofm2, int img){ - int ofm = (ofm1 * jcp.oc_block + ofm2) * jcp.oc_reg_block; - size_t tile_base_index = img * (jcp.itiles * jcp.jtiles); - size_t tblk3 = tile_base_index % jcp.tile_block_ur; - size_t tblk2 = (tile_base_index / jcp.tile_block_ur) - % jcp.nb_tile_block_ur; - size_t tblk1 = (tile_base_index / jcp.tile_block_ur) - / jcp.nb_tile_block_ur; - trans_ker_p.tile_count = tblk2 * jcp.tile_block_ur + tblk3; - trans_ker_p.src = (float *)&(diff_dst(img, ofm, 0, 0, 0)); - trans_ker_p.dst = (float *)&(M(ofm1, tblk1, 0, 0, ofm2, 0, 0, 0, 0)); - if (jcp.with_bias) { - trans_ker_p.bias = (float *)&(diff_bias_prv(ithr, ofm * simd_w)); - kernel_->diff_dst_transform_wbias(&trans_ker_p); - } else { - kernel_->diff_dst_transform(&trans_ker_p); - } - }); - - PRAGMA_OMP(barrier) - - parallel_nd_in_omp(jcp.nb_ic, jcp.nb_oc, alpha, alpha, jcp.tile_block, - [&](int ifm1, int ofm1, int oj, int oi, int tblk1){ - if (first_tblk == 0) { - input_starts[ithr] = - (float *)&(Us(ithr, ifm1, ofm1, oj, oi, 0, 0, 0, - 0, 0)) - - (float *)&(Us(ithr, 0, 0, 0, 0, 0, 0, - 0, 0, 0)); - input_ends[ithr] = input_starts[ithr] - + jcp.oc_block * jcp.ic_block - * jcp.ic_simd_block * jcp.oc_reg_block - * jcp.oc_simd_block; - } - else if (tblk1 == 0) { - input_ends[ithr] += jcp.oc_block * jcp.ic_block - * jcp.ic_simd_block * jcp.oc_reg_block - * jcp.oc_simd_block; - } - - if (first_tblk == 0 || tblk1 == 0) { - kernel_->gemm_loop_ker_first_iter( - &(Us(ithr, ifm1, ofm1, oj, oi, - 0, 0, 0, 0, 0)), - &(M(ofm1, tblk1, oj, oi, 0, 0, 0, 0, 0)), - &(V(ifm1, tblk1, oj, oi, 0, 0, 0, 0))); - } else { - kernel_->gemm_loop_ker( - &(Us(ithr, ifm1, ofm1, oj, oi, - 0, 0, 0, 0, 0)), - &(M(ofm1, tblk1, oj, oi, 0, 0, 0, 0, 0)), - &(V(ifm1, tblk1, oj, oi, 0, 0, 0, 0))); - } - ++first_tblk; - }); -} - - // Reduce diff-weights - { - float *output = &(U(0, 0, 0, 0, 0, 0, 0, 0, 0)); - size_t nelems = jcp.ic * jcp.oc * alpha * alpha; - float 
*input_ptrs[max_threads_number]; - for (int i = 0; i < nthreads; ++i) - input_ptrs[i] = output + nelems * (i + 1); - subarray_sum(nthreads, output, nelems, input_ptrs, - input_starts, input_ends); - } - - trans_ker_p.G = G_O_3x3_4x4; -PRAGMA_OMP(parallel firstprivate(trans_ker_p)) - { - parallel_nd_in_omp(jcp.nb_ic, jcp.nb_oc, jcp.oc_block, jcp.ic_block, jcp.oc_reg_block, - [&](int ifm1, int ofm1, int ofm2, int ifm2, int ofm3){ - int ofm = (ofm1 * jcp.oc_block + ofm2) - * jcp.oc_reg_block + ofm3; - int ifm = ifm1 * jcp.ic_block + ifm2; - trans_ker_p.src = (float *)&(U(ifm1, ofm1, 0, 0, - ofm2, ifm2, 0, ofm3, 0)); - trans_ker_p.dst = (float *)&(diff_weights(ofm, ifm, - 0, 0, 0, 0)); - kernel_->diff_weights_transform(&trans_ker_p); - }); - } - - if (jcp.with_bias) { - parallel_nd(jcp.oc / simd_w, [&](int ofm1) { - float* pbias = &(diff_bias(ofm1 * simd_w)); - float *pbias_prv = &(diff_bias_prv(0, ofm1 * simd_w)); - - const int blk_sz = ofm1 == jcp.oc / simd_w - 1 - ? jcp.oc_without_padding - ofm1 * simd_w : simd_w; - - PRAGMA_OMP_SIMD() - for (int ofm2 = 0; ofm2 < blk_sz; ++ofm2) { - pbias[ofm2] = pbias_prv[ofm2]; - } - - for (int ithr = 1; ithr < nthreads; ++ithr) { - pbias_prv = &(diff_bias_prv(ithr, ofm1 * simd_w)); - PRAGMA_OMP_SIMD() - for (int ofm2 = 0; ofm2 < blk_sz; ++ofm2) { - pbias[ofm2] += pbias_prv[ofm2]; - } - } - }); - } -} - -} -} -} -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.hpp deleted file mode 100644 index f1a56aac7..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.hpp +++ /dev/null @@ -1,386 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#ifndef CPU_JIT_AVX512_CORE_FP32_WINO_CONV_4x3_HPP -#define CPU_JIT_AVX512_CORE_FP32_WINO_CONV_4x3_HPP - -#include "c_types_map.hpp" -#include "memory_tracking.hpp" - -#include "cpu_convolution_pd.hpp" -#include "cpu_primitive.hpp" - -#include "jit_avx512_core_fp32_wino_conv_4x3_kernel.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -namespace winograd_avx512_core { -inline void init_scratchpad(memory_tracking::registrar_t &scratchpad, - const jit_conv_winograd_conf_t &jcp) { - using namespace utils; - using namespace memory_tracking::names; - - size_t U_sz = (size_t)alpha * alpha * jcp.ic * jcp.oc; - size_t V_sz = (size_t)alpha * alpha * jcp.mb * jcp.ic * jcp.itiles - * jcp.jtiles; - size_t M_sz = (size_t)alpha * alpha * jcp.mb * jcp.oc * jcp.itiles - * jcp.jtiles; - - switch (jcp.sched_policy) { - case WSCHED_DATA_W_SGD: - V_sz = (size_t)jcp.nthr * alpha * alpha * jcp.nb_tile_block_ur - * jcp.tile_block_ur * jcp.ic; - M_sz = (size_t)jcp.nthr * alpha * alpha * jcp.nb_tile_block_ur - * jcp.tile_block_ur * jcp.oc; - break; - case WSCHED_WEI_SDGtWo: - U_sz = (size_t)jcp.nthr * (alpha * alpha * jcp.oc - * (jcp.ic / jcp.nb_ic) + jcp.ic * jcp.oc * jcp.kh * jcp.kw); - M_sz = (size_t)jcp.nthr * alpha * alpha * (jcp.ntiles / jcp.tile_block) - * (jcp.oc / jcp.nb_oc); - V_sz = (size_t)jcp.nthr * alpha * alpha * (jcp.ntiles / jcp.tile_block) - * (jcp.ic / jcp.nb_ic); - break; - case WSCHED_WEI_S_D_Giot_W: - U_sz = (size_t)(jcp.nthr + 1) * alpha * alpha * jcp.ic * jcp.oc; - M_sz = (size_t)alpha * alpha * jcp.oc * jcp.ntiles; - V_sz = (size_t)alpha * alpha * jcp.ic * jcp.ntiles; - break; - default: break; - } - - scratchpad.book(key_wino_U, sizeof(float) * U_sz, PAGE_2M); - scratchpad.book(key_wino_V, sizeof(float) * V_sz, PAGE_2M); - scratchpad.book(key_wino_M, sizeof(float) * M_sz, PAGE_2M); - - if (one_of(jcp.sched_policy, WSCHED_WEI_SDGtWo, WSCHED_WEI_S_D_Giot_W)) { - size_t br_sz = (size_t)jcp.nthr * jcp.oc; - scratchpad.book(key_conv_bia_reduction, sizeof(float) * br_sz, PAGE_2M); - } -} -} - -template -struct _jit_avx512_core_fp32_wino_conv_4x3_t { - - _jit_avx512_core_fp32_wino_conv_4x3_t( - const jit_conv_winograd_conf_t &jcp, const primitive_attr_t *attr) - : kernel_(nullptr), attr_(attr) { - kernel_ = new _jit_avx512_core_fp32_wino_conv_4x3_data_kernel(jcp); - } - - ~_jit_avx512_core_fp32_wino_conv_4x3_t() { delete kernel_; } - - protected: - void weight_transform_data(const jit_conv_winograd_conf_t &jcp, - float *wp, float *twp) const; - void input_transform_data(int image, - const jit_conv_winograd_conf_t &jcp, - float *inp, float *tinp) const; - void input_transform_tileblock_data(int tile_block, - const jit_conv_winograd_conf_t &jcp, - float *inp, float *tinp) const; - void output_transform_data(int image, - const jit_conv_winograd_conf_t &jcp, - const post_ops_t &p_ops, float *toutp, float *pout_b, - float *bias) const; - void output_transform_tileblock_data(int tile_block, - const jit_conv_winograd_conf_t &jcp, const post_ops_t &p_ops, - float *toutp, float *outp, float *bias) const; - void _execute_data_W_S_G_D(float *inp_ptr, float *out_ptr, - float *wei_ptr, float *bias_ptr, - const memory_tracking::grantor_t &scratchpad) const; - void _execute_data_W_SGD(float *inp_ptr, float *out_ptr, - float *wei_ptr, float *bias_ptr, - const memory_tracking::grantor_t &scratchpad) const; - _jit_avx512_core_fp32_wino_conv_4x3_data_kernel *kernel_; - const primitive_attr_t *attr_; -}; - -struct 
jit_avx512_core_fp32_wino_conv_4x3_fwd_t - : _jit_avx512_core_fp32_wino_conv_4x3_t - , public cpu_primitive_t - { - struct pd_t : public cpu_convolution_fwd_pd_t { - pd_t(engine_t *engine, const convolution_desc_t *adesc, - const primitive_attr_t *attr, - const typename pd_t::base_class *hint_fwd_pd) - : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) - , jcp_() {} - - DECLARE_COMMON_PD_T( - JIT_IMPL_NAME_HELPER("jit_wino_4x3:", avx512_core, ""), - jit_avx512_core_fp32_wino_conv_4x3_fwd_t); - - status_t init() { - bool ok = true - && is_fwd() - && utils::one_of(desc()->alg_kind, - alg_kind::convolution_auto, - alg_kind::convolution_winograd) - && expect_data_types(data_type::f32, data_type::f32, - data_type::f32, data_type::f32, data_type::f32) - && set_default_formats(); - if (!ok) return status::unimplemented; - - status_t status = jit_avx512_core_fp32_wino_conv_4x3_fwd_kernel:: - init_conf(jcp_, *desc(), src_md_, weights_md_, dst_md_, - *attr()); - if (status != status::success) return status; - set_default_alg_kind(alg_kind::convolution_winograd); - - auto scratchpad = scratchpad_registry().registrar(); - winograd_avx512_core::init_scratchpad(scratchpad, jcp_); - - return status; - } - - jit_conv_winograd_conf_t jcp_; - - protected: - bool set_default_formats() { - using namespace format_tag; - auto wei_fmt = desc()->prop_kind == prop_kind::forward_training - ? (with_groups() ? gOIhw16i16o : OIhw16i16o) : any; - return set_default_formats_common(nChw16c, wei_fmt, nChw16c); - } - }; - - jit_avx512_core_fp32_wino_conv_4x3_fwd_t(const pd_t *apd) - : _jit_avx512_core_fp32_wino_conv_4x3_t(apd->jcp_, apd->attr()) - , cpu_primitive_t(apd, true) - {} - - typedef typename prec_traits::type data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - auto src = CTX_IN_MEM(const float *, MKLDNN_ARG_SRC); - auto weights = CTX_IN_MEM(const float *, MKLDNN_ARG_WEIGHTS); - auto bias = CTX_IN_MEM(const float *, MKLDNN_ARG_BIAS); - auto dst = CTX_OUT_MEM(float *, MKLDNN_ARG_DST); - - auto scratchpad = this->scratchpad(ctx); - - switch ((pd()->jcp_).sched_policy) { - case WSCHED_DATA_W_S_G_D: - this->_execute_data_W_S_G_D((float *)src, dst, (float *)weights, - (float *)bias, scratchpad); - break; - case WSCHED_DATA_W_SGD: - this->_execute_data_W_SGD((float *)src, dst, (float *)weights, - (float *)bias, scratchpad); - break; - default: - break; - } - return status::success; - } - -private: - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } -}; - -struct jit_avx512_core_fp32_wino_conv_4x3_bwd_data_t - : _jit_avx512_core_fp32_wino_conv_4x3_t, - public cpu_primitive_t { - struct pd_t : public cpu_convolution_bwd_data_pd_t { - pd_t(engine_t *engine, const convolution_desc_t *adesc, - const primitive_attr_t *attr, - const convolution_fwd_pd_t *hint_fwd_pd) - : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) - , jcp_() {} - - DECLARE_COMMON_PD_T( - JIT_IMPL_NAME_HELPER("jit_wino_4x3:", avx512_core, ""), - jit_avx512_core_fp32_wino_conv_4x3_bwd_data_t); - - status_t init() { - bool ok = true - && mkldnn_thr_syncable() - && desc()->prop_kind == prop_kind::backward_data - && utils::one_of(desc()->alg_kind, - alg_kind::convolution_auto, - alg_kind::convolution_winograd) - && expect_data_types(data_type::f32, data_type::f32, - data_type::undef, data_type::f32, data_type::f32) - && set_default_formats(); - if (!ok) return status::unimplemented; - - status_t status = jit_avx512_core_fp32_wino_conv_4x3_bwd_data_kernel - ::init_conf(jcp_, *desc(), 
*diff_src_md(), *weights_md(), - *diff_dst_md()); - if (status != status::success) return status; - set_default_alg_kind(alg_kind::convolution_winograd); - - auto scratchpad = scratchpad_registry().registrar(); - winograd_avx512_core::init_scratchpad(scratchpad, jcp_); - - return status; - } - - jit_conv_winograd_conf_t jcp_; - - protected: - bool set_default_formats() { - using namespace format_tag; - auto wei_fmt = with_groups() ? gOIhw16i16o : OIhw16i16o; - return set_default_formats_common(nChw16c, wei_fmt, nChw16c); - } - }; - - jit_avx512_core_fp32_wino_conv_4x3_bwd_data_t(const pd_t *apd) - : _jit_avx512_core_fp32_wino_conv_4x3_t(apd->jcp_, apd->attr()) - , cpu_primitive_t(apd, true) - {} - - typedef typename prec_traits::type data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - auto diff_dst = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST); - auto weights = CTX_IN_MEM(const float *, MKLDNN_ARG_WEIGHTS); - auto diff_src = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_SRC); - - auto scratchpad = this->scratchpad(ctx); - - switch ((pd()->jcp_).sched_policy) { - case WSCHED_DATA_W_S_G_D: - this->_execute_data_W_S_G_D((float *)diff_dst, diff_src, - (float *)weights, NULL, scratchpad); - break; - - case WSCHED_DATA_W_SGD: - this->_execute_data_W_SGD((float *)diff_dst, diff_src, - (float *)weights, NULL, scratchpad); - break; - - default: - break; - } - - return status::success; - } - -private: - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } -}; - -struct jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t - : public cpu_primitive_t { - struct pd_t : public cpu_convolution_bwd_weights_pd_t { - pd_t(engine_t *engine, const convolution_desc_t *adesc, - const primitive_attr_t *attr, - const convolution_fwd_pd_t *hint_fwd_pd) - : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd) - , jcp_() {} - - DECLARE_COMMON_PD_T( - JIT_IMPL_NAME_HELPER("jit_wino_4x3:", avx512_core, ""), - jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t); - - status_t init() { - bool ok = true - && mkldnn_thr_syncable() - && desc()->prop_kind == prop_kind::backward_weights - && utils::one_of(desc()->alg_kind, - alg_kind::convolution_auto, - alg_kind::convolution_winograd) - && expect_data_types(data_type::f32, data_type::f32, - data_type::f32, data_type::f32, data_type::f32) - && set_default_formats(); - if (!ok) - return status::unimplemented; - - status_t status = - jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel:: - init_conf(jcp_, *desc(), *src_md(), *diff_dst_md(), - *diff_weights_md()); - if (status != status::success) return status; - set_default_alg_kind(alg_kind::convolution_winograd); - - auto scratchpad = scratchpad_registry().registrar(); - winograd_avx512_core::init_scratchpad(scratchpad, jcp_); - - return status; - } - - jit_conv_winograd_conf_t jcp_; - - protected: - bool set_default_formats() { - using namespace format_tag; - auto wei_fmt = with_groups() ? 
gOIhw16i16o : OIhw16i16o; - return set_default_formats_common(nChw16c, wei_fmt, nChw16c); - } - }; - - jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t(const pd_t *apd) - : cpu_primitive_t(apd, true) - , kernel_(nullptr) - { - kernel_ = new jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel( - pd()->jcp_); - } - - ~jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t() - { - delete kernel_; - } - - typedef typename prec_traits::type data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - auto diff_dst = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST); - auto src = CTX_IN_MEM(const float *, MKLDNN_ARG_SRC); - auto diff_weights = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_WEIGHTS); - auto diff_bias = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_BIAS); - - switch (kernel_->jcp.sched_policy) { - case WSCHED_WEI_SDGtWo: - _execute_backward_weights_SDGtWo(src, diff_dst, diff_weights, - diff_bias, scratchpad(ctx)); - break; - case WSCHED_WEI_S_D_Giot_W: - _execute_backward_weights_S_D_Giot_W(src, diff_dst, diff_weights, - diff_bias, scratchpad(ctx)); - break; - default: - assert(kernel_->jcp.sched_policy != WSCHED_INVALID); - break; - } - return status::success; - } - -private: - void _execute_backward_weights_SDGtWo(const float *src, - const float *diff_dst, float *diff_weights, float *diff_bias, - const memory_tracking::grantor_t &scratchpad) const; - void _execute_backward_weights_S_D_Giot_W(const float *src, - const float *diff_dst, float *diff_weights, float *diff_bias, - const memory_tracking::grantor_t &scratchpad) const; - - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel *kernel_; -}; - -} -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.cpp deleted file mode 100644 index 0d64a2d13..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.cpp +++ /dev/null @@ -1,2596 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include "c_types_map.hpp" -#include "mkldnn_thread.hpp" -#include "nstl.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -#include - -#include "jit_avx512_core_fp32_wino_conv_4x3_kernel.hpp" - -#define GET_OFF(field) offsetof(jit_wino_transform_call_s, field) - -namespace mkldnn { -namespace impl { -namespace cpu { - -namespace { - -using namespace mkldnn::impl::utils; - -unsigned int L1_cache_size = get_cache_size(1, true); -unsigned int L2_cache_size = get_cache_size(2, true); -unsigned int LLC_data_size = get_cache_size(3, false); - -// the test funtion takes jcp, the candidate and the current best. 
-// it returns true if the new candidate is better -int get_divisor_satisfying_cond(jit_conv_winograd_conf_t &jcp, int number, - int default_best, bool (*test)(jit_conv_winograd_conf_t &, int, int)) -{ - int best_divisor = default_best; - auto test_num - = [&best_divisor, test](jit_conv_winograd_conf_t &jcp, int num) { - if (test(jcp, num, best_divisor)) { - best_divisor = num; - } - }; - - for (int divisor = 1; divisor <= ::sqrt(number); divisor++) { - if (number % divisor == 0) { - test_num(jcp, divisor); - test_num(jcp, number / divisor); - } - } - - return best_divisor; -} - -namespace { -bool is_winograd_faster_than_direct(const jit_conv_winograd_conf_t &jcp) { - /* Determines if current winograd implementation is faster than direct. - Following conditions are empirical and based on performance data */ - unsigned int ncores_per_socket = - cpu.getNumCores(Xbyak::util::IntelCpuTopologyLevel::CoreLevel); - unsigned int nthreads = mkldnn_get_max_threads(); - - if (jcp.prop_kind == prop_kind::forward_inference) { - return jcp.mb >= 4; - } else if (nthreads > ncores_per_socket) { - double src_dst_transforms_per_core = alpha * alpha - * (jcp.ic + jcp.oc) - * jcp.mb * ((jcp.oh + tile_size - 1) / tile_size) - * ((jcp.ow + tile_size - 1) / tile_size) - * sizeof(float) / 1024. / 1024. / nthreads; - double wei_transform = alpha * alpha - * jcp.ic * jcp.oc * sizeof(float) /1024. / 1024.; - - if (jcp.prop_kind == prop_kind::backward_weights) { - if (src_dst_transforms_per_core < 0.3 - || (src_dst_transforms_per_core <= 28 && wei_transform < 4)) - return false; - else - return true; - } else { - if (src_dst_transforms_per_core < 2.0 || wei_transform < 0.02) - return false; - } - } - - return jcp.mb > 8; -} -} - -/* assumes 512 bits registers */ -/* TODO: add support for strides */ -/* TODO: handle the prefetch distance automatically */ -typedef enum cache_t_ { L1, L2, L3 } cache_t; - -template -struct prefetcher_t { - prefetcher_t(jit_generator *generator, Xbyak::Reg64 reg_base_addr, - cache_t cache_type, size_t block_size, /* in number of elements*/ - int nb_instructions_in_block, int fma_ipc) - : cg_(generator) - , reg_base_addr_(reg_base_addr) - , cache_type_(cache_type) - , cache_block_size_(block_size) - { - nb_cache_lines_to_prefetch_ = cache_block_size_ / (64 / sizeof(data_t)); - prefetch_spread_ - = div_up(nb_instructions_in_block, nb_cache_lines_to_prefetch_); - prefetch_blk_ - = div_up(nb_cache_lines_to_prefetch_, nb_instructions_in_block); - - /* assumption: when fetch in Li, data is already in L(i+1) */ - int cache_latency; - switch (cache_type_) { - case L1: cache_latency = 14; break; - case L2: cache_latency = 250; break; - case L3: cache_latency = 250; break; - } - - prefetch_distance_ = div_up(cache_latency, nb_cache_lines_to_prefetch_); - } - - void prefetch(int instruction_number) - { - if (instruction_number % prefetch_spread_ == 0) { - for (int i = 0; (i < prefetch_blk_) - && (prefetches_issued_ < nb_cache_lines_to_prefetch_); - i++, prefetches_issued_++) { - prefetch_inst_(cg_->EVEX_compress_addr( - reg_base_addr_, (cache_block_size_ * prefetch_distance_) - * sizeof(data_t) - + (prefetches_issued_ * 64))); - } - } - } - -private: - void prefetch_inst_(const Xbyak::Address &addr) - { - switch (cache_type_) { - case L1: cg_->prefetcht0(addr); break; - case L2: cg_->prefetcht1(addr); break; - case L3: cg_->prefetcht2(addr); break; - default: - break; // TODO: raise an exception or put an assert - } - } - - jit_generator *cg_; - Xbyak::Reg64 reg_base_addr_; - cache_t cache_type_; - 
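The prefetcher_t constructor above turns a block size and an instruction count into a prefetch schedule: how many cache lines one block spans, how far apart to space the prefetch instructions among the compute instructions, and how many blocks ahead to fetch (it assumes latencies of roughly 14 cycles for L1 and 250 for L2/L3, as hard-coded above). A minimal scalar sketch of that bookkeeping, using _mm_prefetch in place of the Xbyak-emitted prefetcht0/1/2; the struct and function names here are illustrative and not part of mkl-dnn:

#include <xmmintrin.h>  // _mm_prefetch
#include <cstddef>

// Sketch of the bookkeeping in prefetcher_t's constructor: given a block of
// `block_elems` floats consumed by `n_instr` instructions, decide how often
// and how far ahead to prefetch.
struct prefetch_plan {
    int lines_to_prefetch;  // cache lines covered by one block
    int spread;             // issue prefetches every `spread` instructions
    int per_slot;           // prefetches issued per slot
    int distance;           // how many blocks ahead to fetch
};

static inline int div_up(int a, int b) { return (a + b - 1) / b; }

prefetch_plan make_plan(int block_elems, int n_instr, int cache_latency /* e.g. 14 for L1 */) {
    prefetch_plan p;
    p.lines_to_prefetch = block_elems / (64 / (int)sizeof(float)); // 16 floats per 64-byte line
    p.spread   = div_up(n_instr, p.lines_to_prefetch);
    p.per_slot = div_up(p.lines_to_prefetch, n_instr);
    p.distance = div_up(cache_latency, p.lines_to_prefetch);
    return p;
}

// Usage inside a compute loop: touch cache line `line` of the block that sits
// `distance` blocks ahead of the current base pointer.
void prefetch_ahead(const float *base, const prefetch_plan &p, int block_elems, int line) {
    const char *addr = reinterpret_cast<const char *>(base)
            + (size_t)block_elems * p.distance * sizeof(float) + (size_t)line * 64;
    _mm_prefetch(addr, _MM_HINT_T0);
}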
int cache_block_size_ = 0; - int nb_cache_lines_to_prefetch_ = 0; - int prefetches_issued_ = 0; - int prefetch_spread_ = 0; - int prefetch_blk_ = 0; - int prefetch_distance_ = 0; -}; - -// utilities to support kernel parameter selection -bool check_L2_block_per_thread(jit_conv_winograd_conf_t &jcp, - int dimN_block, float C2_min, float C2_max) { - float block_size = alpha * alpha * (2*(jcp.oc + jcp.ic) - * dimN_block * jcp.dimN_reg_block - + div_up(jcp.ic * jcp.oc,mkldnn_get_max_threads())) * (float)sizeof(float); - float L2_lb = C2_min * L2_cache_size; - float L2_ub = C2_max * L2_cache_size; - return (block_size > L2_lb && block_size < L2_ub); -} - -bool check_L1_block_gemm(jit_conv_winograd_conf_t &jcp, int dimK_block, - int dimM_block, float C1_min, float C1_max) { - float gemm_block_size = (dimM_block * jcp.dimM_simd_block * dimK_block - * jcp.dimK_reg_block * jcp.dimM_reg_block - + dimK_block * jcp.dimK_reg_block * jcp.dimN_reg_block - + dimM_block * jcp.dimM_simd_block * jcp.dimN_reg_block) - * (float)sizeof(float); - float L1_lb = C1_min * L1_cache_size; - float L1_ub = C1_max * L1_cache_size; - return (gemm_block_size > L1_lb && gemm_block_size < L1_ub); -} -bool check_cond1(int dimN_reg_block, int dimK_block, int dimK_reg_block, - int dimM_block, int dimM_reg_block, int dimM_simd_block, float C) -{ - float lhs = (dimM_block * dimN_reg_block * dimM_simd_block * dimM_reg_block - + dimM_block * dimK_block * dimK_reg_block - * dimM_simd_block * dimM_reg_block - + dimK_block * dimN_reg_block * dimK_reg_block) - * (float)sizeof(float); - float rhs = C * L1_cache_size; - return (lhs < rhs); -} -bool check_cond1_bis(int dimN_reg_block, int dimK_block, int dimK_reg_block, - int dimM_block, int dimM_reg_block, int dimM_simd_block, float C) -{ - float lhs = (dimM_block * dimM_reg_block * dimK_block * dimK_reg_block - * dimM_simd_block + dimK_block * dimN_reg_block * dimK_reg_block) - * (float)sizeof(float); - float rhs = C * L1_cache_size; - return (lhs < rhs); -} -bool check_cond2(int nb_dimN_reg_block, int dimN_reg_block, int dimK_nb_block, - int dimK_block, int dimK_reg_block, int dimM_block, int dimM_reg_block, - int dimM_simd_block, float C) -{ - float lhs = (nb_dimN_reg_block * dimM_block * dimN_reg_block - * dimM_simd_block * dimM_reg_block - + dimK_nb_block * dimM_block * dimK_block * dimK_reg_block - * dimM_simd_block * dimM_reg_block - + nb_dimN_reg_block * dimK_nb_block * dimK_block - * dimN_reg_block * dimK_reg_block) - * (float)sizeof(float); - float rhs = C * L2_cache_size; - return (lhs < rhs); -} - -bool check_kernel_cond(int dimM_block, int dimM_reg_block, int dimM_simd_block, - int dimN_block, int dimN_reg_block, int dimK, float C1, float C2) -{ - float A_size = dimM_block * dimM_reg_block * dimM_simd_block * dimK - * (float)sizeof(float); - float B_size = dimN_block * dimN_reg_block * dimK - * (float)sizeof(float); - return (A_size > C1 * L2_cache_size && B_size > C2 * L2_cache_size); -} -} - -using namespace mkldnn::impl::format_tag; -using namespace mkldnn::impl::utils; -using namespace Xbyak; - -void _jit_avx512_core_fp32_wino_conv_4x3_data_kernel::gemm_loop_generate() -{ - // for (int dimM_block =0; dimM_block < jcp.dimM_block; dimM_block++) - // for (int dimM_reg_block =0; dimM_reg_block < jcp.dimM_reg_block; - // dimM_reg_block++) // unrolled - // for (int dimK_block = 0; dimK_block < jcp.dimK_block; dimK_block++) - // for (int dimK_reg_block= 0; dimK_reg_block < jcp.dimK_reg_block; - // dimK_reg_block++) // unrolled - // for (int tile =0; tile < jcp.dimN_reg_block; 
tile++) - // C[dimM_block][dimM_reg_block][tile] += - // A[dimM_block][dimM_reg_block][dimK_block][dimK_reg_block] - // * broadcast(B[dimK_block][tile][dimK_reg_block]); - // Notes: - // jcp.kernel_kind defines embedded or explicit broadcast - // dimM_reg_block=1 for embedded bcast kernel - - auto zmm_srcA = [=]() { - return Xbyak::Zmm(0); - }; - auto zmm_srcB = [=](int tile) { - int idx = 1 + tile; - assert(idx < 1 + jcp.dimN_reg_block); - return Xbyak::Zmm(idx); - }; - auto zmm_dstC = [=](int dimM_reg_block, int tile) { - int idx{0}; - if (jcp.kernel_kind == embd_bcast) - idx = 1 + tile; - else - idx = 1 + jcp.dimN_reg_block - + dimM_reg_block * jcp.dimN_reg_block + tile; - assert(idx < 32); - return Xbyak::Zmm(idx); - }; - - auto prepare_output = [=]() { - for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block; - dimM_reg_block++) { - for (int tile = 0; tile < jcp.dimN_reg_block; tile++) { - Zmm zmm = zmm_dstC(dimM_reg_block, tile); - vpxord(zmm, zmm, zmm); - } - } - }; - auto store_output = [=](bool output_is_aligned) { - Label save; - cmp(reg_is_beta_zero, 0); - je(save, T_NEAR); - - for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block; - dimM_reg_block++) { - for (int tile = 0; tile < jcp.dimN_reg_block; tile++) { - Zmm zmm = zmm_dstC(dimM_reg_block,tile); - int output_offset - = jcp.dimN_reg_block * dimM_reg_block * 64 + tile * 64; - vaddps(zmm, zmm, EVEX_compress_addr(reg_dstC, output_offset)); - } - } - - L(save); - for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block; - dimM_reg_block++) { - for (int tile = 0; tile < jcp.dimN_reg_block; tile++) { - Zmm zmm = zmm_dstC(dimM_reg_block,tile); - int output_offset - = jcp.dimN_reg_block * dimM_reg_block * 64 + tile * 64; - - // In W_SGD, output will be reused. - if (output_is_aligned - && jcp.dimK_nb_block == 1 - && jcp.sched_policy == WSCHED_DATA_W_S_G_D - && (jcp.dimN * jcp.dimM * alpha * alpha - * sizeof(float) > 2 * LLC_data_size)) - vmovntps(EVEX_compress_addr(reg_dstC, output_offset), zmm); - else vmovups(EVEX_compress_addr(reg_dstC, output_offset), zmm); - } - } - }; - - auto inner_loops = [=]() { - Label dimM_block_loop, dimK_block_loop; - - if (jcp.dimM_block > 1) { - mov(reg_dimM_block_loop_cnt, jcp.dimM_block); - L(dimM_block_loop); - } - - prepare_output(); - - if (jcp.dimK_block > 1) { - mov(reg_dimK_block_loop_cnt, jcp.dimK_block); - L(dimK_block_loop); - } - - for (int dimK_reg_block = 0; - dimK_reg_block < jcp.dimK_reg_block; - dimK_reg_block ++) { - - if (jcp.kernel_kind == expl_bcast) { - for (int tile = 0; tile < jcp.dimN_reg_block; tile++) { - vbroadcastss(zmm_srcB(tile), - ptr[reg_srcB + 64 * tile + dimK_reg_block * 4]); - } - } - - /* Performing the fmas */ - - for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block; - dimM_reg_block++) { - - vmovups(zmm_srcA(), - zword[reg_srcA - + jcp.dimK_reg_block * jcp.dimK_block * 64 - * dimM_reg_block - + dimK_reg_block * 64] - ); - - for (int tile = 0; tile < jcp.dimN_reg_block; tile++) { - if (jcp.kernel_kind == expl_bcast) - vfmadd231ps(zmm_dstC(dimM_reg_block, tile), zmm_srcA(), - zmm_srcB(tile)); - else - vfmadd231ps(zmm_dstC(dimM_reg_block, tile), zmm_srcA(), - EVEX_compress_addr(reg_srcB, - 64 * tile + dimK_reg_block * 4, true)); - } - } - } - add(reg_srcA, jcp.dimK_reg_block * 64); - add(reg_srcB, jcp.dimN_reg_block * 64); - if (jcp.dimK_block > 1) { - sub(reg_dimK_block_loop_cnt, 1); - jnz(dimK_block_loop); - } - - Label unaligned_store, end_store; - test(reg_dstC, cpu_isa_traits::vlen - 1); - jnz(unaligned_store, T_NEAR); 
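The commented-out loop nest at the top of gemm_loop_generate is what the emitted code computes: one zmm holds a 16-wide slice of A, the explicit-broadcast variant keeps dimN_reg_block broadcast values of B live, and dimM_reg_block x dimN_reg_block zmm accumulators hold C, which is why the register-block search later requires dimN_reg_block * (dimM_reg_block + 1) < 32. A plain scalar reference of that accumulation follows; the layout comments mirror the byte offsets used above, and this is a readability sketch rather than the JIT kernel itself:

// Scalar reference of the accumulation the generated GEMM kernel performs.
// Each zmm accumulator holds `simd` (= 16) consecutive M values of one tile.
// Layouts (innermost index last), mirroring the offsets used in the kernel:
//   A[Mb][Mr][Kb][Kr][simd]  -- transformed weights, simd along M
//   B[Kb][Nr][Kr]            -- transformed input, one scalar per (k, tile), broadcast
//   C[Mb][Mr][Nr][simd]      -- output tiles; B is reused across the Mb loop
void wino_gemm_block_ref(float *C, const float *A, const float *B,
        int Mb, int Mr, int Kb, int Kr, int Nr, int simd, bool beta_zero) {
    for (int mb = 0; mb < Mb; mb++)
    for (int mr = 0; mr < Mr; mr++)
    for (int t = 0; t < Nr; t++)
    for (int s = 0; s < simd; s++) {
        float acc = beta_zero ? 0.f
                : C[((mb * Mr + mr) * Nr + t) * simd + s];  // accumulate or overwrite
        for (int kb = 0; kb < Kb; kb++)
        for (int kr = 0; kr < Kr; kr++)
            acc += A[(((mb * Mr + mr) * Kb + kb) * Kr + kr) * simd + s]
                 * B[(kb * Nr + t) * Kr + kr];
        C[((mb * Mr + mr) * Nr + t) * simd + s] = acc;
    }
}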
- store_output(true); - jmp(end_store, T_NEAR); - L(unaligned_store); { - store_output(false); - } - L(end_store); - - if (jcp.dimM_block > 1) { - sub(reg_srcB, jcp.dimK_block * jcp.dimN_reg_block * 64); - add(reg_dstC, jcp.dimM_reg_block * jcp.dimN_reg_block * 64); - if (jcp.kernel_kind == expl_bcast) { - add(reg_srcA, - (jcp.dimM_reg_block-1) * jcp.dimK_reg_block * 64 - * jcp.dimK_block); - } - sub(reg_dimM_block_loop_cnt, 1); - jnz(dimM_block_loop); - } - }; - - /* Preamble */ - preamble(); - - /* kernel */ - inner_loops(); - - /* Postamble */ - postamble(); - ret(); -} - -void _jit_avx512_core_fp32_wino_conv_4x3_data_kernel - ::weights_transform_data_ker_generate() -{ - bool is_fwd = one_of(jcp.prop_kind, - mkldnn_forward_training, mkldnn_forward_inference); - int kh = jcp.kh; - int kw = jcp.kw; - - auto zmm_temp = Xbyak::Zmm(31); - auto zmm_zero = Xbyak::Zmm(30); - - auto zmm_M = [=](int i) { - return Xbyak::Zmm(i); - }; - auto zmm_MT = [=](int i) { - return Xbyak::Zmm(i + simd_w); - }; - - auto zmm_G = [=](int i) { - return Xbyak::Zmm(i); - }; - auto zmm_F = [=](int i) { - return Xbyak::Zmm(alpha + i); - }; - auto zmm_T = [=](int i) { - return Xbyak::Zmm(alpha + 3 + i); - }; - auto zmm_t = [=](int i) { - return Xbyak::Zmm(2 * alpha + 3 + i); - }; - - auto zmm_load = [=](int i) { - return Xbyak::Zmm(i); - }; - - auto init_G = [=]() { - mov(wreg_temp, ptr[param1 + GET_OFF(G)]); - for (int i = 0; i < alpha; i++) { - vbroadcastss(zmm_G(i), ptr[wreg_temp + i * typesize]); - } - vpxord(zmm_zero, zmm_zero, zmm_zero); - }; - - auto trans16x16 = [=]() { - for (int i = 0; i < simd_w; i+=2 ) { - vmovups(zmm_M(i), ptr[wreg_M + i * simd_w * 4]); - vmovups(zmm_M(i+1), ptr[wreg_M + (i + 1) * simd_w * 4]); - vunpcklps(zmm_MT(i), zmm_M(i), zmm_M(i+1)); - vunpckhps(zmm_MT(i+1), zmm_M(i), zmm_M(i+1)); - } - for (int i = 0; i < simd_w; i+=4 ) { - vunpcklpd(zmm_M(i), zmm_MT(i), zmm_MT(i+2)); - vunpckhpd(zmm_M(i+1), zmm_MT(i), zmm_MT(i+2)); - vunpcklpd(zmm_M(i+2), zmm_MT(i+1), zmm_MT(i+3)); - vunpckhpd(zmm_M(i+3), zmm_MT(i+1), zmm_MT(i+3)); - } - for (int i = 0; i < simd_w; i += 8) { - vshuff32x4(zmm_MT(i), zmm_M(i), zmm_M(i + 4), 0x88); - vshuff32x4(zmm_MT(i+1), zmm_M(i+1), zmm_M(i + 5), 0x88); - vshuff32x4(zmm_MT(i+2), zmm_M(i+2), zmm_M(i + 6), 0x88); - vshuff32x4(zmm_MT(i+3), zmm_M(i+3), zmm_M(i + 7), 0x88); - vshuff32x4(zmm_MT(i+4), zmm_M(i), zmm_M(i + 4), 0xdd); - vshuff32x4(zmm_MT(i+5), zmm_M(i+1), zmm_M(i + 5), 0xdd); - vshuff32x4(zmm_MT(i+6), zmm_M(i+2), zmm_M(i + 6), 0xdd); - vshuff32x4(zmm_MT(i+7), zmm_M(i+3), zmm_M(i + 7), 0xdd); - } - { - int i = 0; - int mask = 0x88; - vshuff32x4(zmm_M(0), zmm_MT(i), zmm_MT(i + 8), mask); - vmovups(ptr[wreg_MT + 0 * 16 * 4], zmm_M(0)); - vshuff32x4(zmm_M(1), zmm_MT(i + 1), zmm_MT(i + 9), mask); - vmovups(ptr[wreg_MT + 1 * 16 * 4], zmm_M(1)); - vshuff32x4(zmm_M(2), zmm_MT(i + 2), zmm_MT(i + 10), mask); - vmovups(ptr[wreg_MT + 2 * 16 * 4], zmm_M(2)); - vshuff32x4(zmm_M(3), zmm_MT(i + 3), zmm_MT(i + 11), mask); - vmovups(ptr[wreg_MT + 3 * 16 * 4], zmm_M(3)); - vshuff32x4(zmm_M(4), zmm_MT(i + 4), zmm_MT(i + 12), mask); - vmovups(ptr[wreg_MT + 4 * 16 * 4], zmm_M(4)); - vshuff32x4(zmm_M(5), zmm_MT(i + 5), zmm_MT(i + 13), mask); - vmovups(ptr[wreg_MT + 5 * 16 * 4], zmm_M(5)); - vshuff32x4(zmm_M(6), zmm_MT(i + 6), zmm_MT(i + 14), mask); - vmovups(ptr[wreg_MT + 6 * 16 * 4], zmm_M(6)); - vshuff32x4(zmm_M(7), zmm_MT(i + 7), zmm_MT(i + 15), mask); - vmovups(ptr[wreg_MT + 7 * 16 * 4], zmm_M(7)); - mask = 0xdd; - vshuff32x4(zmm_M(8), zmm_MT(i), zmm_MT(i + 8), mask); - 
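trans16x16, used just below for the backward-data weight transform (which also mirrors the 3x3 taps via the (2 - j), (2 - i) indexing), builds an in-register 16x16 transpose out of vunpck{l,h}ps, vunpck{l,h}pd and vshuff32x4 passes. Its net effect is the plain transpose below, written in scalar form only to document what the shuffle sequence computes:

// Scalar equivalent of the in-register 16x16 transpose: column j of the
// source block becomes row j of the destination. Both buffers hold 16 rows
// of 16 contiguous floats.
void transpose16x16_ref(float *dst, const float *src) {
    const int n = 16;
    for (int i = 0; i < n; i++)
        for (int j = 0; j < n; j++)
            dst[j * n + i] = src[i * n + j];
}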
vmovups(ptr[wreg_MT + 8 * 16 * 4], zmm_M(8)); - vshuff32x4(zmm_M(9), zmm_MT(i + 1), zmm_MT(i + 9), mask); - vmovups(ptr[wreg_MT + 9 * 16 * 4], zmm_M(9)); - vshuff32x4(zmm_M(10), zmm_MT(i + 2), zmm_MT(i + 10), mask); - vmovups(ptr[wreg_MT + 10 * 16 * 4], zmm_M(10)); - vshuff32x4(zmm_M(11), zmm_MT(i + 3), zmm_MT(i + 11), mask); - vmovups(ptr[wreg_MT + 11 * 16 * 4], zmm_M(11)); - vshuff32x4(zmm_M(12), zmm_MT(i + 4), zmm_MT(i + 12), mask); - vmovups(ptr[wreg_MT + 12 * 16 * 4], zmm_M(12)); - vshuff32x4(zmm_M(13), zmm_MT(i + 5), zmm_MT(i + 13), mask); - vmovups(ptr[wreg_MT + 13 * 16 * 4], zmm_M(13)); - vshuff32x4(zmm_M(14), zmm_MT(i + 6), zmm_MT(i + 14), mask); - vmovups(ptr[wreg_MT + 14 * 16 * 4], zmm_M(14)); - vshuff32x4(zmm_M(15), zmm_MT(i + 7), zmm_MT(i + 15), mask); - vmovups(ptr[wreg_MT + 15 * 16 * 4], zmm_M(15)); - } - }; - - auto load_src = [=]() { - mov(wreg_src, ptr[param1 + GET_OFF(src)]); - mov(wreg_F, ptr[param1 + GET_OFF(M)]); - for (int j = 0; j < kh; j++) { - for (int i = 0; i < kw; i++) { - if (is_fwd) { - for (int v1 = 0; v1 < simd_w; v1++) { - int offset_src = (j * kw * simd_w * simd_w - + i * simd_w * simd_w + v1 * simd_w) * typesize; - int offset_F = (j * kw * simd_w * simd_w - + i * simd_w * simd_w + v1 * simd_w) * typesize; - vmovups(zmm_temp, ptr[wreg_src + offset_src]); - vmovups(ptr[wreg_F + offset_F], zmm_temp); - } - } else { - int offset_src = ((2 - j) * kw * simd_w * simd_w - + (2 - i) * simd_w * simd_w) * typesize; - int offset_F = (j * kw * simd_w * simd_w - + i * simd_w * simd_w) * typesize; - lea(wreg_M, ptr[wreg_src + offset_src]); - lea(wreg_MT, ptr[wreg_F + offset_F]); - trans16x16(); - } - } - } - }; - - auto store_dst = [=]() { - mov(wreg_dst, ptr[param1 + GET_OFF(dst)]); - mov(wreg_Fw, ptr[param1 + GET_OFF(Mw)]); - - Label Loop_j; - mov(wreg_cnt_j, 0); - mov(wreg_dst_aux, wreg_dst); - mov(wreg_Fw_aux, wreg_Fw); - - int dim5 = jcp.dimK_nb_block * (jcp.dimM_block * jcp.dimM_reg_block) - * jcp.dimK_block * simd_w * simd_w; - - L(Loop_j); - { - for (int i = 0; i < alpha; i++) { - // touch pages - vmovups(zmm_load(0), ptr[wreg_Fw_aux - + (i * simd_w * simd_w) * typesize]); - mov(wreg_dst_idx, i * dim5 * typesize); - vmovntps(ptr[wreg_dst_aux + wreg_dst_idx], zmm_load(0)); - } - for (int i = 0; i < alpha; i++) { - for (int v1 = 1; v1 < simd_w; v1++) { - int offset_Fw = (i * simd_w * simd_w + v1 * simd_w) - * typesize; - vmovups(zmm_load(v1), ptr[wreg_Fw_aux + offset_Fw]); - } - mov(wreg_dst_idx, i * dim5 * typesize); - for (int v1 = 1; v1 < simd_w; v1++) { - int offset_dst = v1 * simd_w * typesize; - vmovntps(ptr[wreg_dst_aux + wreg_dst_idx + offset_dst], - zmm_load(v1)); - } - } - add(wreg_Fw_aux, alpha * simd_w * simd_w * typesize); - add(wreg_dst_aux, alpha * dim5 * typesize); - add(wreg_cnt_j, 1); - cmp(wreg_cnt_j, alpha); - jl(Loop_j, T_NEAR); - } - }; - - auto trans_W_4x4_3x3 = [=]() { - auto fma4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) { - vmovups(dst, a); - vfmadd231ps(dst, b, c); - }; - auto fms4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) { - vmulps(zmm_temp, b, c); - vsubps(dst, a, zmm_temp); - }; - auto fnms4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) { - vsubps(dst, zmm_zero, a); - vfnmadd231ps(dst, b, c); - }; - - mov(wreg_Fw, ptr[param1 + GET_OFF(Mw)]); - mov(wreg_F, ptr[param1 + GET_OFF(M)]); - mov(wreg_T, ptr[param1 + GET_OFF(T)]); - - Label Loop_j; - mov(wreg_cnt_j, 0); - L(Loop_j); - mov(wreg_F_aux, wreg_F); - mov(wreg_Fw_aux, wreg_Fw); - mov(wreg_temp, wreg_cnt_j); - shl(wreg_temp, 4 + 2); - lea(wreg_F_aux, ptr[wreg_F + wreg_temp]); - lea(wreg_Fw_aux, 
ptr[wreg_Fw + wreg_temp]); - - for (int i = 0; i < 3; i++) { - for (int idx = 0; idx < 3; idx ++) { - vmovups(zmm_F(idx), ptr[wreg_F_aux + (idx * 3 * simd_w - * simd_w + i * simd_w * simd_w) * typesize]); - } - vmulps(zmm_t(0), zmm_G(0), zmm_F(2)); - fnms4(zmm_t(1), zmm_t(0), zmm_G(1), zmm_F(0)); - fma4(zmm_t(2), zmm_t(0), zmm_G(2), zmm_F(0)); - - vmulps(zmm_T(0), zmm_G(3), zmm_F(0)); - fms4(zmm_T(1), zmm_t(1), zmm_G(4), zmm_F(1)); - fma4(zmm_T(2), zmm_t(1), zmm_G(4), zmm_F(1)); - fma4(zmm_T(3), zmm_t(2), zmm_G(5), zmm_F(1)); - fms4(zmm_T(4), zmm_t(2), zmm_G(5), zmm_F(1)); - vmovaps(zmm_T(5), zmm_F(2)); - - for (int idx = 0; idx < 6; idx ++) { - vmovups(ptr[wreg_T + (idx * 3 * simd_w + i * simd_w) - * typesize], zmm_T(idx)); - } - } - for (int i = 0; i < 6; i++) { - - for (int idx = 0; idx < 3; idx ++) { - vmovups(zmm_T(idx), ptr[wreg_T - + (i * 3 * simd_w + idx * simd_w) * typesize]); - } - vmulps(zmm_t(0), zmm_G(0), zmm_T(2)); - fnms4(zmm_t(1), zmm_t(0), zmm_G(1), zmm_T(0)); - fma4(zmm_t(2), zmm_t(0), zmm_G(2), zmm_T(0)); - - vmulps(zmm_F(0), zmm_G(3), zmm_T(0)); - fms4(zmm_F(1), zmm_t(1), zmm_G(4), zmm_T(1)); - fma4(zmm_F(2), zmm_t(1), zmm_G(4), zmm_T(1)); - fma4(zmm_F(3), zmm_t(2), zmm_G(5), zmm_T(1)); - fms4(zmm_F(4), zmm_t(2), zmm_G(5), zmm_T(1)); - vmovaps(zmm_F(5), zmm_T(2)); - - for (int l = 0; l < 6; l++) { - vmovups(ptr[wreg_Fw_aux + (i * 6 * simd_w * simd_w - + l * simd_w * simd_w) * typesize], zmm_F(l)); - } - } - add(wreg_cnt_j, 1); - cmp(wreg_cnt_j, 16); - jl(Loop_j, T_NEAR); - }; - - auto inner_loops = [=]() { - load_src(); - init_G(); - trans_W_4x4_3x3(); - store_dst(); - }; - - preamble(); - inner_loops(); - postamble(); -} - -void _jit_avx512_core_fp32_wino_conv_4x3_data_kernel - ::output_transform_data_ker_generate() -{ - bool is_fwd = one_of(jcp.prop_kind, - mkldnn_forward_training, mkldnn_forward_inference); - int outw = is_fwd ? jcp.ow : jcp.iw; - int outh = is_fwd ? 
jcp.oh : jcp.ih; - bool not_tiled = jcp.sched_policy == WSCHED_DATA_W_S_G_D; - bool with_bias = jcp.with_bias; - bool with_relu = jcp.with_eltwise; - bool with_relu_postsum = jcp.with_relu_postsum; - bool with_sum = jcp.with_sum; - - auto zmm_zero = Xbyak::Zmm(0); - auto zmm_temp = Xbyak::Zmm(31); - auto zmm_G = [=](int i) { - return Xbyak::Zmm(1 + i); - }; - auto zmm_O = [=](int i) { - return Xbyak::Zmm(1 + alpha + i); - }; - auto zmm_T = [=](int i) { - return Xbyak::Zmm(1 + 2 * alpha + i); - }; - auto zmm_t = [=](int i) { - return Xbyak::Zmm(1 + 3 * alpha + i); - }; - - auto init_G = [=]() { - mov(oreg_temp, ptr[param1 + GET_OFF(G)]); - for (int i = 0; i < 6; i++) { - vbroadcastss(zmm_G(i), ptr[oreg_temp + i * typesize]); - } - }; - - auto load_src = [=]() { - mov(oreg_Ow, ptr[param1 + GET_OFF(Mw)]); - mov(oreg_src, ptr[param1 + GET_OFF(src)]); - - mov(oreg_nb_tile_block_ur, ptr[param1 + GET_OFF(nb_tile_block_ur)]); - imul(oreg_nb_tile_block_ur, oreg_nb_tile_block_ur, - (jcp.dimM_block * jcp.dimM_reg_block) * jcp.dimN_reg_block - * jcp.dimM_simd_block * typesize); - add(oreg_src, oreg_nb_tile_block_ur); - - mov(oreg_tile_block_ur, ptr[param1 + GET_OFF(tile_block_ur)]); - imul(oreg_tile_block_ur, oreg_tile_block_ur, - jcp.dimM_simd_block * typesize); - add(oreg_src, oreg_tile_block_ur); - - if (not_tiled) { - mov(oreg_tile_block, ptr[param1 + GET_OFF(tile_block)]); - imul(oreg_tile_block, oreg_tile_block, - jcp.dimM_nb_block * alpha * alpha * jcp.dimN_block - * (jcp.dimM_block * jcp.dimM_reg_block) * jcp.dimN_reg_block - * jcp.dimM_simd_block * typesize); - add(oreg_src, oreg_tile_block); - } - - int last4dim = jcp.dimN_block * (jcp.dimM_block * jcp.dimM_reg_block) - * jcp.dimN_reg_block * jcp.dimM_simd_block * typesize; - for (int j = 0; j < alpha; j++) { - for (int i = 0; i < alpha; i++) { - int j_base_offset = j * alpha * last4dim; - int i_base_offset = i * last4dim; - vmovups(zmm_temp, ptr[oreg_src + j_base_offset + i_base_offset]); - vmovups(ptr[oreg_Ow + (j * alpha * simd_w + i * simd_w) - * typesize], zmm_temp); - } - } - }; - - auto store_dst = [=]() { - vpxord(zmm_zero, zmm_zero, zmm_zero); - mov(oreg_dst, ptr[param1 + GET_OFF(dst)]); - mov(oreg_O, ptr[param1 + GET_OFF(M)]); - mov(oreg_ydim, ptr[param1 + GET_OFF(tj)]); - shl(oreg_ydim, 2); // tj * tile_size (==4) - mov(oreg_xdim, ptr[param1 + GET_OFF(ti)]); - shl(oreg_xdim, 2); // ti * tilesize (==4) - - if (with_bias) - mov(oreg_bias, ptr[param1 + GET_OFF(bias)]); - - auto store_one = [=](int j, int i, bool is_aligned) { - auto zmm_O = Xbyak::Zmm(31); - auto zmm_relu_ns = Xbyak::Zmm(30); - auto xmm_relu_ns = Xbyak::Xmm(30); - int offset = (j * tile_size * simd_w + i * simd_w) * typesize; - - vmovups(zmm_O, ptr[oreg_O + offset]); - if (is_fwd) { - if (with_bias) { - vaddps(zmm_O, zmm_O, ptr[oreg_bias]); - } - if (with_relu) { - if (jcp.eltwise.alpha == 0) { - vmaxps(zmm_O, zmm_O, zmm_zero); - } else { - Opmask kmask = Opmask(7); - mov(imm_addr64, float2int(jcp.eltwise.alpha)); - vmovq(xmm_relu_ns, imm_addr64); - vbroadcastss(zmm_relu_ns, xmm_relu_ns); - vcmpps(kmask, zmm_O, zmm_zero, _cmp_lt_os); - vmulps(zmm_O | kmask, zmm_O, zmm_relu_ns); - } - } - } - if (with_sum) { - vaddps(zmm_O, zmm_O, ptr[oreg_out_j + oreg_temp]); - if (with_relu_postsum) // orig: with_relu_postsum - vmaxps(zmm_O, zmm_O, zmm_zero); - } - if (is_aligned) - vmovntps(ptr[oreg_out_j + oreg_temp], zmm_O); - else - vmovups(ptr[oreg_out_j + oreg_temp], zmm_O); - }; - - auto i_loop = [=](int j, bool is_aligned) { - for (int i = 0; i < tile_size; i++) { - Label 
next; - mov(oreg_temp, oreg_xdim); - add(oreg_temp, i); - cmp(oreg_temp, outw); - jge(next, T_NEAR); - shl(oreg_temp, 4 + 2); // * 16 * 4 - - store_one(j, i, is_aligned); - - L(next); - } - }; - - - for (int j = 0; j < tile_size; j++) { - Label next, unaligned; - mov(oreg_temp, oreg_ydim); - add(oreg_temp, j); - cmp(oreg_temp, outh); - jge(next, T_NEAR); - - mov(oreg_out_j, oreg_dst); - imul(oreg_temp, oreg_temp, outw * simd_w * typesize); - add(oreg_out_j, oreg_temp); - - test(oreg_dst, 63); - jnz(unaligned, T_NEAR); - - i_loop(j, true); - jmp(next, T_NEAR); - - L(unaligned); - i_loop(j, false); - - L(next); - } - }; - - auto trans_O_4x4_3x3 = [=]() { - auto fma2 = [=](Zmm dst, Zmm v1, Zmm u1, Zmm v2, Zmm u2){ - vmulps(dst, v1, u1); - vfmadd231ps(dst, v2, u2); - }; - mov(oreg_Ow, ptr[param1 + GET_OFF(Mw)]); - mov(oreg_T, ptr[param1 + GET_OFF(T)]); - mov(oreg_O, ptr[param1 + GET_OFF(M)]); - - for (int i = 0; i < alpha; i++) { - for (int j = 0; j < alpha; j++) { - vmovups(zmm_O(j), ptr[oreg_Ow + (j * alpha * simd_w - + i * simd_w) * typesize]); - } - - vaddps(zmm_t(0), zmm_O(1), zmm_O(2)); - vaddps(zmm_t(1), zmm_O(3), zmm_O(4)); - vsubps(zmm_t(2), zmm_O(1), zmm_O(2)); - vsubps(zmm_t(3), zmm_O(3), zmm_O(4)); - - vaddps(zmm_T(0), zmm_t(0), zmm_t(1)); - vaddps(zmm_T(0), zmm_T(0), zmm_O(0)); - fma2(zmm_T(1), zmm_t(2), zmm_G(0), zmm_t(3), zmm_G(1)); - fma2(zmm_T(2), zmm_t(0), zmm_G(2), zmm_t(1), zmm_G(3)); - fma2(zmm_T(3), zmm_t(2), zmm_G(4), zmm_t(3), zmm_G(5)); - vaddps(zmm_T(3), zmm_T(3), zmm_O(5)); - - for (int j = 0; j < tile_size; j++) { - vmovups(ptr[oreg_T + (j * alpha * simd_w - + i * simd_w) * typesize], zmm_T(j)); - } - } - for (int j = 0; j < tile_size; j++) { - for (int i = 0; i < alpha; i++) { - vmovups(zmm_T(i), ptr[oreg_T + (j * alpha * simd_w - + i * simd_w) * typesize]); - } - vaddps(zmm_t(0), zmm_T(1), zmm_T(2)); - vaddps(zmm_t(1), zmm_T(3), zmm_T(4)); - vsubps(zmm_t(2), zmm_T(1), zmm_T(2)); - vsubps(zmm_t(3), zmm_T(3), zmm_T(4)); - - vaddps(zmm_O(0), zmm_t(0), zmm_t(1)); - vaddps(zmm_O(0), zmm_O(0), zmm_T(0)); - fma2(zmm_O(1), zmm_t(2), zmm_G(0), zmm_t(3), zmm_G(1)); - fma2(zmm_O(2), zmm_t(0), zmm_G(2), zmm_t(1), zmm_G(3)); - fma2(zmm_O(3), zmm_t(2), zmm_G(4), zmm_t(3), zmm_G(5)); - vaddps(zmm_O(3), zmm_O(3), zmm_T(5)); - - for (int i = 0; i < tile_size; i++) { - vmovups(ptr[oreg_O + (j * tile_size * simd_w - + i * simd_w) * typesize], zmm_O(i)); - } - } - }; - - auto inner_loops = [=]() { - init_G(); - load_src(); - trans_O_4x4_3x3(); - store_dst(); - }; - - preamble(); - inner_loops(); - postamble(); -} - -void _jit_avx512_core_fp32_wino_conv_4x3_data_kernel - ::input_transform_data_ker_generate() -{ - bool is_fwd = one_of(jcp.prop_kind, - mkldnn_forward_training, mkldnn_forward_inference); - int inpw = is_fwd ? jcp.iw : jcp.ow; - int inph = is_fwd ? jcp.ih : jcp.oh; - int l_pad = is_fwd ? jcp.l_pad : jcp.iw + jcp.r_pad - jcp.ow; - int t_pad = is_fwd ? 
jcp.t_pad : jcp.ih + jcp.t_pad - jcp.oh; - int wp_max = inpw + l_pad; - int hp_max = inph + t_pad; - bool not_tiled = jcp.sched_policy == WSCHED_DATA_W_S_G_D; - int G_size = 9; - - auto zmm_zero = Xbyak::Zmm(0); - auto zmm_temp = Xbyak::Zmm(31); - auto zmm_G = [=](int i) { - return Xbyak::Zmm(1 + i); - }; - auto zmm_I = [=](int i) { - return Xbyak::Zmm(1 + G_size + i); - }; - auto zmm_T = [=](int i) { - return Xbyak::Zmm(1 + G_size + alpha + i); - }; - auto zmm_t = [=](int i) { - return Xbyak::Zmm(1 + G_size + 2 * alpha + i); - }; - - auto init_G = [=]() { - mov(ireg_temp, ptr[param1 + GET_OFF(G)]); - for (int i = 0; i < G_size; i++) { - vbroadcastss(zmm_G(i), ptr[ireg_temp + i * typesize]); - } - }; - - auto load_src = [=]() { - mov(ireg_src, ptr[param1 + GET_OFF(src)]); // base addr of inp - mov(ireg_I, ptr[param1 + GET_OFF(M)]); - - xor_(ireg_zero, ireg_zero); - vpxord(zmm_zero, zmm_zero, zmm_zero); - - mov(ireg_ydim, ptr[param1 + GET_OFF(tj)]); - shl(ireg_ydim, 2); // tj * tile_size (==4) - mov(ireg_xdim, ptr[param1 + GET_OFF(ti)]); - shl(ireg_xdim, 2); // ti * tilesize (==4) - - for (int j = 0; j < alpha; j++) { - mov(ireg_temp, ireg_ydim); - add(ireg_temp, j); - - mov(ireg_mask_j, 0xffff); - cmp(ireg_temp, t_pad); - cmovl(ireg_mask_j, ireg_zero); - cmp(ireg_temp, hp_max); - cmovge(ireg_mask_j, ireg_zero); - - sub(ireg_temp, t_pad); - imul(ireg_temp, ireg_temp, inpw * simd_w * typesize); - mov(ireg_inp_j, ireg_src); - add(ireg_inp_j, ireg_temp); - - for (int i = 0; i < alpha; i++) { - - mov(ireg_temp, ireg_xdim); - add(ireg_temp, i); - - mov(ireg_mask, 0xffff); - cmp(ireg_temp, l_pad); - cmovl(ireg_mask, ireg_zero); - cmp(ireg_temp, wp_max); - cmovge(ireg_mask, ireg_zero); - and_(ireg_mask, ireg_mask_j); - - sub(ireg_temp, l_pad); - shl(ireg_temp, 4 + 2); - - vpxord(zmm_temp, zmm_temp, zmm_temp); - Opmask kmask = Opmask(7); - kmovw(kmask, ireg_mask_32); - vmovups(zmm_temp | kmask, ptr[ireg_inp_j + ireg_temp]); - vmovups(ptr[ireg_I + (j * alpha * simd_w + i * simd_w) - * typesize], zmm_temp); - } - } - }; - - auto store_Iw = [=]() { - - mov(ireg_Iw, ptr[param1 + GET_OFF(Mw)]); - mov(ireg_output, ptr[param1 + GET_OFF(dst)]); - - bool streamout - = jcp.dimN * jcp.dimK * alpha * alpha * sizeof(float) - > 2 * LLC_data_size - ? 
true : false; - - if (not_tiled) { - mov(ireg_tile_block, ptr[param1 + GET_OFF(tile_block)]); - imul(ireg_tile_block, ireg_tile_block, - alpha * alpha * jcp.dimN_block * jcp.dimK_nb_block - * jcp.dimK_block * jcp.dimN_reg_block * jcp.dimK_reg_block - * typesize); - } - - mov(ireg_nb_tile_block_ur, ptr[param1 + GET_OFF(nb_tile_block_ur)]); - imul(ireg_nb_tile_block_ur, ireg_nb_tile_block_ur, - jcp.dimK_nb_block * jcp.dimK_block * jcp.dimN_reg_block - * jcp.dimK_reg_block * typesize); - - mov(ireg_tile_block_ur, ptr[param1 + GET_OFF(tile_block_ur)]); - imul(ireg_tile_block_ur, ireg_tile_block_ur, - jcp.dimK_reg_block * typesize); - - add(ireg_output, ireg_nb_tile_block_ur); - add(ireg_output, ireg_tile_block_ur); - if (not_tiled) - add(ireg_output, ireg_tile_block); - - for (int j = 0; j < alpha; j++) { - for (int i = 0; i < alpha; i++) { - vmovups(zmm_temp,ptr[ireg_Iw + (j * alpha * simd_w - + i * simd_w) * typesize]); - - int j_base_offset = - j * alpha * jcp.dimN_block * jcp.dimK_nb_block - * jcp.dimK_block * jcp.dimN_reg_block * jcp.dimK_reg_block - * typesize; - int i_base_offset = - i * jcp.dimN_block * jcp.dimK_nb_block * jcp.dimK_block - * jcp.dimN_reg_block * jcp.dimK_reg_block * typesize; - - if (not_tiled && streamout) - vmovntps(ptr[ireg_output + j_base_offset + i_base_offset], - zmm_temp); - else - vmovups(ptr[ireg_output + j_base_offset + i_base_offset], - zmm_temp); - } - } - }; - - auto fma4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) { - vmulps(zmm_temp, a, b); - vaddps(dst, zmm_temp, c); - }; - - auto trans_I_4x4_3x3 = [=]() { - mov(ireg_Iw, ptr[param1 + GET_OFF(Mw)]); - mov(ireg_T, ptr[param1 + GET_OFF(T)]); - mov(ireg_I, ptr[param1 + GET_OFF(M)]); - - mov(ireg_output, ptr[param1 + GET_OFF(dst)]); // for prefetch - for (int i = 0; i < alpha; i++) { - for (int idx = 0; idx < alpha; idx++) { - vmovups(zmm_I(idx), ptr[ireg_I + (idx * alpha * simd_w - + i * simd_w) * typesize]); - int j_base_offset = - i * alpha * jcp.dimN_block * jcp.dimK_nb_block - * jcp.dimK_block * jcp.dimN_reg_block * jcp.dimK_reg_block - * typesize; - int idx_base_offset = - idx * jcp.dimN_block * jcp.dimK_nb_block * jcp.dimK_block - * jcp.dimN_reg_block * jcp.dimK_reg_block * typesize; - prefetcht0(ptr[ireg_output + j_base_offset + idx_base_offset]); - } - - fma4(zmm_t(0), zmm_I(2), zmm_G(0), zmm_I(4)); - fma4(zmm_t(1), zmm_I(1), zmm_G(0), zmm_I(3)); - fma4(zmm_t(2), zmm_I(2), zmm_G(1), zmm_I(4)); - fma4(zmm_t(3), zmm_I(1), zmm_G(1), zmm_I(3)); - fma4(zmm_t(4), zmm_I(0), zmm_G(2), zmm_I(4)); - fma4(zmm_t(5), zmm_I(1), zmm_G(2), zmm_I(5)); - - fma4(zmm_T(0), zmm_I(2), zmm_G(3), zmm_t(4)); - fma4(zmm_T(1), zmm_t(1), zmm_G(4), zmm_t(0)); - fma4(zmm_T(2), zmm_t(1), zmm_G(5), zmm_t(0)); - fma4(zmm_T(3), zmm_t(3), zmm_G(6), zmm_t(2)); - fma4(zmm_T(4), zmm_t(3), zmm_G(7), zmm_t(2)); - fma4(zmm_T(5), zmm_I(3), zmm_G(8), zmm_t(5)); - - for (int idx = 0; idx < alpha; idx++) { - vmovups(ptr[ireg_T + (idx * alpha * simd_w + i * simd_w) - * typesize],zmm_T(idx)); - } - } - for (int i = 0; i < alpha; i++) { - for (int idx = 0; idx < alpha; idx++) { - vmovups(zmm_T(idx), ptr[ireg_T + (i * alpha * simd_w + idx - * simd_w) * typesize]); - } - - fma4(zmm_t(0), zmm_T(2), zmm_G(0), zmm_T(4)); - fma4(zmm_t(1), zmm_T(1), zmm_G(0), zmm_T(3)); - fma4(zmm_t(2), zmm_T(2), zmm_G(1), zmm_T(4)); - fma4(zmm_t(3), zmm_T(1), zmm_G(1), zmm_T(3)); - fma4(zmm_t(4), zmm_T(0), zmm_G(2), zmm_T(4)); - fma4(zmm_t(5), zmm_T(1), zmm_G(2), zmm_T(5)); - - fma4(zmm_I(0), zmm_T(2), zmm_G(3), zmm_t(4)); - fma4(zmm_I(1), zmm_t(1), zmm_G(4), zmm_t(0)); - 
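trans_I_4x4_3x3 is two applications of the same six-point transform, first down the columns of the 6x6 input tile and then across the rows of the intermediate, with fma4(dst, a, b, c) computing a*b + c and the nine broadcast coefficients coming from the G array in the call parameters. The per-element dataflow, transcribed into scalar code (a sketch of the arithmetic only; register allocation, the prefetches and the strided stores into the transformed-input panel are omitted):

// One 1-D pass of the input transform exactly as the fma chain computes it:
// six inputs, nine coefficients, six outputs.
static void input_transform_1d(float T[6], const float I[6], const float G[9]) {
    float t0 = I[2] * G[0] + I[4];
    float t1 = I[1] * G[0] + I[3];
    float t2 = I[2] * G[1] + I[4];
    float t3 = I[1] * G[1] + I[3];
    float t4 = I[0] * G[2] + I[4];
    float t5 = I[1] * G[2] + I[5];

    T[0] = I[2] * G[3] + t4;
    T[1] = t1   * G[4] + t0;
    T[2] = t1   * G[5] + t0;
    T[3] = t3   * G[6] + t2;
    T[4] = t3   * G[7] + t2;
    T[5] = I[3] * G[8] + t5;
}

// Full 6x6 tile transform: the 1-D pass over columns, then over the rows of
// the intermediate (the kernel keeps the intermediate in the T scratch area).
static void input_transform_tile(float out[6][6], const float in[6][6], const float G[9]) {
    float mid[6][6], col_in[6], col_out[6];
    for (int i = 0; i < 6; i++) {                 // first pass: columns
        for (int j = 0; j < 6; j++) col_in[j] = in[j][i];
        input_transform_1d(col_out, col_in, G);
        for (int j = 0; j < 6; j++) mid[j][i] = col_out[j];
    }
    for (int j = 0; j < 6; j++)                   // second pass: rows
        input_transform_1d(out[j], mid[j], G);
}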
fma4(zmm_I(2), zmm_t(1), zmm_G(5), zmm_t(0)); - fma4(zmm_I(3), zmm_t(3), zmm_G(6), zmm_t(2)); - fma4(zmm_I(4), zmm_t(3), zmm_G(7), zmm_t(2)); - fma4(zmm_I(5), zmm_T(3), zmm_G(8), zmm_t(5)); - - for (int idx = 0; idx < alpha; idx++) { - vmovups(ptr[ireg_Iw + (i * alpha * simd_w + idx * simd_w) - * typesize],zmm_I(idx)); - } - } - }; - - auto inner_loops = [=]() { - init_G(); - load_src(); - trans_I_4x4_3x3(); - store_Iw(); - }; - - preamble(); - inner_loops(); - postamble(); -} - -status_t _jit_avx512_core_fp32_wino_conv_4x3_data_kernel::init_conf_common( - jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd, - const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d, - const memory_desc_wrapper &dst_d) -{ - if (!mayiuse(avx512_core)) { - return status::unimplemented; - } - - jcp.nthr = mkldnn_get_max_threads(); - - jcp.ver = ver_avx512_core; - jcp.prop_kind = cd.prop_kind; - - const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; - - jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; - jcp.mb = src_d.dims()[0]; - jcp.oc = dst_d.dims()[1] / jcp.ngroups; - jcp.oc_without_padding = jcp.oc; - jcp.ic = src_d.dims()[1] / jcp.ngroups; - jcp.ih = src_d.dims()[2]; - jcp.iw = src_d.dims()[3]; - jcp.oh = dst_d.dims()[2]; - jcp.ow = dst_d.dims()[3]; - jcp.kh = weights_d.dims()[with_groups + 2]; - jcp.kw = weights_d.dims()[with_groups + 3]; - jcp.t_pad = cd.padding[0][0]; - jcp.l_pad = cd.padding[0][1]; - jcp.stride_h = cd.strides[0]; - jcp.stride_w = cd.strides[1]; - jcp.dilate_h = cd.dilates[0]; - jcp.dilate_w = cd.dilates[1]; - jcp.r_pad = nstl::max( - 0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad); - jcp.b_pad = nstl::max( - 0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad); - jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad; - jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad; - jcp.ohp = jcp.oh; - jcp.owp = jcp.ow; - - bool ok_to_pad_channels = jcp.ngroups == 1; - if (ok_to_pad_channels) { - jcp.oc = rnd_up(jcp.oc, simd_w); - jcp.ic = rnd_up(jcp.ic, simd_w); - } - - // Checking conditions not supported by these kernels - if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto, - is_winograd_faster_than_direct(jcp))) - return status::unimplemented; - - if (jcp.ngroups != 1) - return status::unimplemented; - if ((jcp.kh != 3) || (jcp.kw != 3)) - return status::unimplemented; - if ((jcp.dilate_h != 0) || (jcp.dilate_w != 0)) - return status::unimplemented; - if ((jcp.stride_h != 1) || (jcp.stride_w != 1)) - return status::unimplemented; - if ((jcp.ic % simd_w) != 0 || (jcp.oc % simd_w) != 0) - return status::unimplemented; - - format_tag_t dat_tag = nChw16c; - jcp.src_tag = src_d.matches_one_of_tag(dat_tag); - jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); - - if (jcp.src_tag != dat_tag) return status::unimplemented; - if (jcp.dst_tag != dat_tag) return status::unimplemented; - - if (!one_of(weights_d.format_kind(), format_kind::any, format_kind::wino)) { - format_tag_t wei_tag = with_groups ? 
gOIhw16i16o : OIhw16i16o; - jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); - if (jcp.wei_tag != wei_tag) - return status::unimplemented; - } - - bool layout_consistency = true - && jcp.ic <= src_d.padded_dims()[1] - && jcp.oc <= dst_d.padded_dims()[1] - && (one_of(weights_d.format_kind(), - format_kind::any, format_kind::wino) - || (jcp.ic <= weights_d.padded_dims()[with_groups + 1] - && jcp.oc <= weights_d.padded_dims()[with_groups + 0])); - if (!layout_consistency) - return status::unimplemented; - - return status::success; -} - -void set_kernel_dims_reg_block(jit_conv_winograd_conf_t &jcp) { - - /* ----------- dimM reg block ---------------------*/ - auto test_cond_dimM_reg_block = [](jit_conv_winograd_conf_t &jcp, - int dimM_reg_block, int current_best) { - int max_dimM_reg_block = jcp.kernel_kind == embd_bcast ? 1 : 4; - return (dimM_reg_block >= 1) - && (dimM_reg_block <= max_dimM_reg_block ) - && (dimM_reg_block > current_best); - }; - jcp.dimM_reg_block = get_divisor_satisfying_cond(jcp, - jcp.dimM/jcp.dimM_simd_block, 1, test_cond_dimM_reg_block); - - /* ----------- dimN reg block ---------------------*/ - - auto test_cond_dimN_reg_block = [](jit_conv_winograd_conf_t &jcp, - int dimN_reg_block, int current_best) { - return jcp.kernel_kind == embd_bcast - ? dimN_reg_block < jcp.nb_reg && dimN_reg_block > current_best - : dimN_reg_block >= 1 - && (dimN_reg_block * jcp.dimM_reg_block + dimN_reg_block) - < jcp.nb_reg - && dimN_reg_block > current_best; - }; - jcp.dimN_reg_block = get_divisor_satisfying_cond(jcp, - jcp.dimN, 1, test_cond_dimN_reg_block); -} - -status_t set_wsched_DATA_W_SGD_avx512_core(jit_conv_winograd_conf_t &jcp) { - if (jcp.ver != ver_avx512_core) - return status::unimplemented; - - jcp.kernel_kind = embd_bcast; - - set_kernel_dims_reg_block(jcp); - - /*-------------- L2 blocking for dimN block ---------*/ - - auto test_cond_dimN_block = [](jit_conv_winograd_conf_t &jcp, - int dimN_block, int current_best) { - return check_L2_block_per_thread(jcp, dimN_block, 0.1, 2.0) - && (dimN_block > current_best) - && ((jcp.dimN / dimN_block / jcp.dimN_reg_block) - >= 1.5 * mkldnn_get_max_threads()); - }; - - jcp.dimN_block = get_divisor_satisfying_cond( - jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond_dimN_block); - jcp.dimN_nb_block = jcp.dimN / jcp.dimN_block / jcp.dimN_reg_block; - - if (check_L2_block_per_thread(jcp, jcp.dimN_block, 0.1, 3.2) - && (jcp.dimN_nb_block >= 1.5 * mkldnn_get_max_threads())) { - - /* ------------------- L1 blocking for GEMM --------------*/ - /* -------------------- Choose dimK block ----------------*/ - - auto test_cond_dimK_block = [](jit_conv_winograd_conf_t &jcp, - int dimK_block, int current_best) { - return check_L1_block_gemm(jcp, dimK_block, 1, 0.1, 0.5) - && (dimK_block > current_best); - }; - - jcp.dimK_block = get_divisor_satisfying_cond( - jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond_dimK_block); - - if (check_L1_block_gemm(jcp, jcp.dimK_block, 1, 0.1, 1.0)) { - jcp.dimK_nb_block = jcp.dimK / jcp.dimK_block / jcp.dimK_reg_block; - - /* -------------- Choose dimM block -------------------*/ - auto test_cond_dimM_block = [](jit_conv_winograd_conf_t &jcp, - int dimM_block, int current_best) { - return check_L1_block_gemm(jcp, jcp.dimK_block, dimM_block, - 0.2, 0.5) && (dimM_block > current_best); - }; - - jcp.dimM_block = get_divisor_satisfying_cond(jcp, - jcp.dimM / (jcp.dimM_simd_block * jcp.dimM_reg_block), 1, - test_cond_dimM_block); - jcp.dimM_nb_block = jcp.dimM / jcp.dimM_block / jcp.dimM_reg_block - / 
jcp.dimM_simd_block; - - jcp.sched_policy = WSCHED_DATA_W_SGD; - return status::success; - } - - } - return status::unimplemented; -} - -void set_kernel_blocking_DATA_W_S_G_D(jit_conv_winograd_conf_t &jcp) { - - set_kernel_dims_reg_block(jcp); - - //********************* Choosing dimK_block **********************// - auto test_cond1_dimK_block = []( - jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) { - return check_cond1(jcp.dimN_reg_block, dimK_block, jcp.dimK_reg_block, - 1, jcp.dimM_reg_block, jcp.dimM_simd_block, .75f) - && (dimK_block > current_best); - }; - - auto test_cond1_bis_dimK_block = []( - jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) { - return check_cond1_bis(jcp.dimN_reg_block, dimK_block, - jcp.dimK_reg_block, 1, jcp.dimM_reg_block, - jcp.dimM_simd_block, .9f) - && (dimK_block > current_best); - }; - - jcp.dimK_block = get_divisor_satisfying_cond( - jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond1_bis_dimK_block); - // If we are not able to use streams, we fall back to condition [1] - if (jcp.dimK_block < jcp.dimK / jcp.dimK_reg_block) - jcp.dimK_block = get_divisor_satisfying_cond( - jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond1_dimK_block); - jcp.dimK_nb_block = (jcp.dimK / jcp.dimK_reg_block) / jcp.dimK_block; - - //********************* Choosing dimM_block **********************// - auto test_cond1_dimM_block = []( - jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) { - return check_cond1(jcp.dimN_reg_block, jcp.dimK_block, - jcp.dimK_reg_block, dimM_block, jcp.dimM_reg_block, - jcp.dimM_simd_block, .5f) - && (dimM_block > current_best); - }; - - auto test_cond1_bis_dimM_block = []( - jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) { - return check_cond1_bis(jcp.dimN_reg_block, jcp.dimK_block, - jcp.dimK_reg_block, dimM_block, jcp.dimM_reg_block, - jcp.dimM_simd_block, .3f) - && (dimM_block > current_best); - }; - - if (jcp.dimK_block < jcp.dimK / jcp.dimK_reg_block) - jcp.dimM_block = get_divisor_satisfying_cond( - jcp, jcp.dimM / (jcp.dimM_simd_block*jcp.dimM_reg_block), 1, - test_cond1_dimM_block); - else - jcp.dimM_block = get_divisor_satisfying_cond(jcp, - jcp.dimM / (jcp.dimM_simd_block*jcp.dimM_reg_block), 1, - test_cond1_bis_dimM_block); - jcp.dimM_nb_block = jcp.dimM / (jcp.dimM_simd_block * jcp.dimM_block - * jcp.dimM_reg_block); - - //******************* Choosing dimN_block *******************// - auto test_cond2_dimN_block = []( - jit_conv_winograd_conf_t &jcp, int dimN_block, int current_best) { - return check_cond2(dimN_block, jcp.dimN_reg_block, jcp.dimK_nb_block, - jcp.dimK_block, jcp.dimK_reg_block, jcp.dimM_block, - jcp.dimM_reg_block, jcp.dimM_simd_block, .9f) - && (dimN_block > current_best); - }; - - jcp.dimN_block = get_divisor_satisfying_cond( - jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond2_dimN_block); - jcp.dimN_nb_block = jcp.dimN / (jcp.dimN_reg_block * jcp.dimN_block); -} - -status_t set_wsched_DATA_W_S_G_D_avx512_core(jit_conv_winograd_conf_t &jcp) { - - jcp.kernel_kind = expl_bcast; - set_kernel_blocking_DATA_W_S_G_D(jcp); - if (!(check_kernel_cond(jcp.dimM_block, jcp.dimM_reg_block, - jcp.dimM_simd_block, jcp.dimN_block, jcp.dimN_reg_block, jcp.dimK, - .1f, .35f))) { - jcp.kernel_kind = embd_bcast; - set_kernel_blocking_DATA_W_S_G_D(jcp); - } - jcp.sched_policy = WSCHED_DATA_W_S_G_D; - return status::success; -} - -status_t _jit_avx512_core_fp32_wino_conv_4x3_data_kernel::init_conf_kernel( - jit_conv_winograd_conf_t &jcp, int dimM, int dimN, int dimK) -{ - 
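Every blocking decision above (dimM/dimN register blocks, dimK/dimM/dimN cache blocks) goes through get_divisor_satisfying_cond: enumerate the divisors of the dimension in pairs up to its square root and keep the candidate that the predicate prefers over the current best. A standalone sketch of that scheme, paired with a predicate in the spirit of test_cond_dimN_reg_block for the explicit-broadcast kernel; pick_divisor and pick_dimN_reg_block are illustrative names, not mkl-dnn functions:

#include <cmath>
#include <functional>

// Enumerate divisors of `number` in (d, number/d) pairs and keep whichever
// candidate the predicate prefers over the running best -- the scheme used
// for every blocking choice in this file.
int pick_divisor(int number, int default_best,
        const std::function<bool(int candidate, int current_best)> &better) {
    int best = default_best;
    for (int d = 1; d <= (int)std::sqrt((double)number); d++) {
        if (number % d != 0) continue;
        if (better(d, best)) best = d;
        if (better(number / d, best)) best = number / d;
    }
    return best;
}

// Example predicate mirroring test_cond_dimN_reg_block: keep the accumulator
// plus broadcast register count within the 32 zmm budget.
int pick_dimN_reg_block(int dimN, int dimM_reg_block) {
    const int nb_reg = 32;
    return pick_divisor(dimN, 1, [&](int n, int best) {
        return n >= 1
                && n * dimM_reg_block + n < nb_reg
                && n > best;
    });
}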
jcp.nb_reg = 32; - jcp.dimN = dimN; - jcp.dimK = dimK; - jcp.dimM = dimM; - jcp.sched_policy = WSCHED_INVALID; - - jcp.dimK_reg_block = 16; - jcp.dimM_simd_block = 16; - - if (jcp.kernel_kind == embd_bcast) { - jcp.dimM_reg_block = 1; - } - - if (!(set_wsched_DATA_W_SGD_avx512_core(jcp) == status::success)) - set_wsched_DATA_W_S_G_D_avx512_core(jcp); - - assert(jcp.sched_policy != WSCHED_INVALID); - return status::success; -} - -bool jit_avx512_core_fp32_wino_conv_4x3_fwd_kernel::post_ops_ok( - jit_conv_conf_t &jcp, const primitive_attr_t &attr) { - const auto &p = attr.post_ops_; - - auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); }; - auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; - - switch (p.len_) { - case 0: return true; // no post_ops - case 1: return is_relu(0) || is_sum(0); // relu or sum - case 2: return (is_sum(0) && is_relu(1)) - || (is_relu(0) && is_sum(1)); // sum->relu or relu->sum - case 3: return is_relu(0) && is_sum(1) && is_relu(2); // relu->sum->relu - default: return false; - } - - return false; -} - -status_t jit_avx512_core_fp32_wino_conv_4x3_fwd_kernel::init_conf( - jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd, - const memory_desc_t &src_md, memory_desc_t &weights_md, - const memory_desc_t &dst_md, const primitive_attr_t &attr) { - - status_t st = init_conf_common(jcp, cd, src_md, weights_md, dst_md); - - if (st != status::success) - return st; - - // Winograd specific initialization - jcp.itiles = (jcp.ow + tile_size - 1) / tile_size; - jcp.jtiles = (jcp.oh + tile_size - 1) / tile_size; - jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles; - - jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; - - if (!post_ops_ok(jcp, attr)) - return status::unimplemented; - - const auto &p = attr.post_ops_; - const int eltwise_ind = p.find(primitive_kind::eltwise, 0, 1); - jcp.with_eltwise = eltwise_ind != -1; - if (jcp.with_eltwise) - jcp.eltwise = p.entry_[eltwise_ind].eltwise; - - jcp.with_sum = p.find(primitive_kind::sum, 0) != -1; - jcp.with_relu_postsum = p.find(primitive_kind::eltwise, 1) != -1; - - status_t res = init_conf_kernel(jcp, jcp.oc, jcp.ntiles, jcp.ic); - - jcp.ic_simd_block = jcp.dimK_reg_block; - jcp.ic_block = jcp.dimK_block; - jcp.nb_ic = jcp.dimK_nb_block; - jcp.oc_simd_block = jcp.dimM_simd_block; - jcp.oc_block = jcp.dimM_block; - jcp.oc_reg_block = jcp.dimM_reg_block; - jcp.ic_reg_block = 1; - jcp.nb_oc = jcp.dimM_nb_block; - jcp.tile_block_ur = jcp.dimN_reg_block; - jcp.nb_tile_block_ur = jcp.dimN_block; - jcp.tile_block = jcp.dimN_nb_block; - - /* re-create weights primitive descriptor - and set weights wino_blocking */ - if (cd.prop_kind == mkldnn_forward_inference) { - memory_desc_t expect_wei_md = weights_md; - - expect_wei_md.format_kind = format_kind::wino; - expect_wei_md.data_type = data_type::f32; - mkldnn_wino_desc_t &wd = expect_wei_md.format_desc.wino_desc; - wd.wino_format = mkldnn_wino_wei_OBaaIBOIio; - wd.r = 3; - wd.alpha = 6; - - wd.ic = jcp.ic; - wd.oc = jcp.oc; - wd.ic_block = jcp.dimK_reg_block; - wd.oc_block = jcp.dimM_simd_block; - wd.ic2_block = jcp.dimK_block; - wd.oc2_block = jcp.dimM_block * jcp.dimM_reg_block; - size_t max_size = sizeof(float) * wd.alpha * wd.alpha * jcp.ic * jcp.oc; - wd.size = max_size; - wd.adj_scale = 1.f; - - if (weights_md.format_kind == format_kind::any) - weights_md = expect_wei_md; - if (weights_md != expect_wei_md) - return status::unimplemented; - } - - return res; -} - -status_t jit_avx512_core_fp32_wino_conv_4x3_bwd_data_kernel::init_conf( - 
jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd, - const memory_desc_wrapper &diff_src_d, - const memory_desc_wrapper &weights_d, - const memory_desc_wrapper &diff_dst_d) -{ - status_t st = init_conf_common(jcp, cd, diff_src_d, weights_d, diff_dst_d); - - if (st != status::success) - return st; - - jcp.itiles = (jcp.iw + tile_size - 1) / tile_size; - jcp.jtiles = (jcp.ih + tile_size - 1) / tile_size; - jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles; - - status_t res = init_conf_kernel(jcp, jcp.ic, jcp.ntiles, jcp.oc); - - jcp.oc_simd_block = jcp.dimK_reg_block; - jcp.oc_block = jcp.dimK_block; - jcp.nb_oc = jcp.dimK_nb_block; - jcp.ic_simd_block = jcp.dimM_simd_block; - jcp.ic_block = jcp.dimM_block; - jcp.ic_reg_block = jcp.dimM_reg_block; - jcp.oc_reg_block = 1; - jcp.nb_ic = jcp.dimM_nb_block; - jcp.tile_block_ur = jcp.dimN_reg_block; - jcp.nb_tile_block_ur = jcp.dimN_block; - jcp.tile_block = jcp.dimN_nb_block; - - return res; -} - -void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel:: -src_transform_generate() { - constexpr int G_size = 9; - const size_t ifwp = jcp.iw + jcp.l_pad; - const size_t ifhp = jcp.ih + jcp.t_pad; - - auto zmm_G = [=](int i) { - return Xbyak::Zmm(i); - }; - auto zmm_I = [=](int i) { - return Xbyak::Zmm(G_size + i); - }; - auto zmm_T = [=](int i) { - return Xbyak::Zmm(G_size + alpha + i); - }; - auto zmm_t = [=](int i) { - return Xbyak::Zmm(G_size + 2 * alpha + i); - }; - - auto init_G = [=]() { - mov(reg_G, ptr[reg_transp + GET_OFF(G)]); - for (int i = 0; i < G_size; i++) { - vbroadcastss(zmm_G(i), ptr[reg_G + i * typesize]); - } - }; - - auto load_src = [=]() { - mov(reg_I, ptr[reg_transp + GET_OFF(M)]); - xor_(reg_zero, reg_zero); - - mov(reg_ydim, reg_tj); - shl(reg_ydim, 2); //tj * tile_size(=4) - - for (int j = 0; j < alpha; j++) { - /* check if tile index is within physical spatial boundaries*/ - mov(reg_maskj, 0xffff); - cmp(reg_ydim, jcp.t_pad); - cmovl(reg_maskj, reg_zero); - cmp(reg_ydim, ifhp); - cmovge(reg_maskj, reg_zero); - - /*address offset for tile in src*/ - mov(reg_src_offset, reg_ydim); - sub(reg_src_offset, jcp.t_pad); // tj*tile_size - t_pad - imul(reg_src_offset, reg_src_offset, jcp.iw); - - mov(reg_xdim, reg_ti); - shl(reg_xdim, 2); // xdim = ti * tile_size - - add(reg_src_offset, reg_xdim); - sub(reg_src_offset, jcp.l_pad); - imul(reg_src_offset, reg_src_offset, simd_w * typesize); - for (int i = 0; i < alpha; i++) { - /* check if tile index is within physical spatial boundaries*/ - mov(reg_maski, 0xffff); - cmp(reg_xdim, jcp.l_pad); - cmovl(reg_maski, reg_zero); - cmp(reg_xdim, ifwp); - cmovge(reg_maski, reg_zero); - and_(reg_maski, reg_maskj); - - Opmask kmask_src = Xbyak::Opmask(7); - auto zmm_src = Xbyak::Zmm(31); - kmovw(kmask_src, reg_maski_32); - vpxord(zmm_src, zmm_src, zmm_src); - vmovups(zmm_src | kmask_src, ptr[reg_src + reg_src_offset]); - vmovups(ptr[reg_I], zmm_src); - - add(reg_xdim, 1); //xdim = ti * tile_size + i - add(reg_src_offset, simd_w * typesize); - add(reg_I, simd_w * typesize); - } - add(reg_ydim, 1); - } - }; - - auto fma4 = [=](Xbyak::Zmm dst, Xbyak::Zmm a, Xbyak::Zmm b, Xbyak::Zmm c) { - vmovups(dst, c); - vfmadd231ps(dst, a, b); - }; - - auto trans_I_3x3_4x4 = [=]() { - //Use 24 registers - mov(reg_I, ptr[reg_transp + GET_OFF(M)]); - mov(reg_T, ptr[reg_transp + GET_OFF(T)]); - for (int i = 0; i < alpha; i++) { - for (int j = 0; j < alpha; j++) { - size_t I_off = (j * alpha + i) * simd_w * typesize; - vmovups(zmm_I(j), ptr[reg_I + I_off]); - } - - fma4(zmm_t(0), zmm_I(2), zmm_G(0), 
zmm_I(4)); - fma4(zmm_t(1), zmm_I(1), zmm_G(0), zmm_I(3)); - fma4(zmm_t(2), zmm_I(2), zmm_G(1), zmm_I(4)); - fma4(zmm_t(3), zmm_I(1), zmm_G(1), zmm_I(3)); - fma4(zmm_t(4), zmm_I(0), zmm_G(2), zmm_I(4)); - fma4(zmm_t(5), zmm_I(1), zmm_G(2), zmm_I(5)); - - fma4(zmm_T(0), zmm_I(2), zmm_G(3), zmm_t(4)); - fma4(zmm_T(1), zmm_t(1), zmm_G(4), zmm_t(0)); - fma4(zmm_T(2), zmm_t(1), zmm_G(5), zmm_t(0)); - fma4(zmm_T(3), zmm_t(3), zmm_G(6), zmm_t(2)); - fma4(zmm_T(4), zmm_t(3), zmm_G(7), zmm_t(2)); - fma4(zmm_T(5), zmm_I(3), zmm_G(8), zmm_t(5)); - - for (int j = 0; j < alpha; j++) { - vmovups(ptr[reg_T + (j * alpha + i) * simd_w * typesize], - zmm_T(j)); - } - - } - - for (int j = 0; j < alpha; j++) { - for (int i = 0; i < alpha; i++) { - vmovups(zmm_T(i), ptr[reg_T + (j * alpha + i) * simd_w * typesize]); - } - - fma4(zmm_t(0), zmm_T(2), zmm_G(0), zmm_T(4)); - fma4(zmm_t(1), zmm_T(1), zmm_G(0), zmm_T(3)); - fma4(zmm_t(2), zmm_T(2), zmm_G(1), zmm_T(4)); - fma4(zmm_t(3), zmm_T(1), zmm_G(1), zmm_T(3)); - fma4(zmm_t(4), zmm_T(0), zmm_G(2), zmm_T(4)); - fma4(zmm_t(5), zmm_T(1), zmm_G(2), zmm_T(5)); - - fma4(zmm_I(0), zmm_T(2), zmm_G(3), zmm_t(4)); - fma4(zmm_I(1), zmm_t(1), zmm_G(4), zmm_t(0)); - fma4(zmm_I(2), zmm_t(1), zmm_G(5), zmm_t(0)); - fma4(zmm_I(3), zmm_t(3), zmm_G(6), zmm_t(2)); - fma4(zmm_I(4), zmm_t(3), zmm_G(7), zmm_t(2)); - fma4(zmm_I(5), zmm_T(3), zmm_G(8), zmm_t(5)); - - for (int i = 0; i < alpha; i++) { - size_t dst_off = (j * alpha * jcp.ic_block - * jcp.nb_tile_block_ur * jcp.tile_block_ur - + i * jcp.ic_block * jcp.nb_tile_block_ur * jcp.tile_block_ur) - * simd_w * typesize; - vmovups(ptr[reg_dst + dst_off], zmm_I(i)); - } - } - }; - - auto compute_transform_SDGtWo = [=]() { - mov(reg_ti, ptr[reg_transp + GET_OFF(ti)]); - mov(reg_tj, ptr[reg_transp + GET_OFF(tj)]); - mov(reg_src, ptr[reg_transp + GET_OFF(src)]); - mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]); - xor_(reg_tile_count, reg_tile_count); - Label loop_mb, loop_jtiles, loop_itiles, done; - L(loop_mb); - { - L(loop_jtiles); - { - L(loop_itiles); - { - load_src(); - - trans_I_3x3_4x4(); - - add(reg_tile_count, 1); - cmp(reg_tile_count, jcp.nb_tile_block_ur * jcp.tile_block_ur); - jge(done); - - add(reg_dst, simd_w * typesize); - add(reg_ti, 1); - cmp(reg_ti, jcp.itiles); - jl(loop_itiles); - } - xor_(reg_ti, reg_ti); - add(reg_tj, 1); - cmp(reg_tj, jcp.jtiles); - jl(loop_jtiles); - } - xor_(reg_tj, reg_tj); - add(reg_src, jcp.ic * jcp.iw * jcp.ih * typesize); - jmp(loop_mb); - } - L(done); - }; - - auto compute_transform = [=]() { - mov(reg_src, ptr[reg_transp + GET_OFF(src)]); - xor_(reg_ti, reg_ti); - xor_(reg_tj, reg_tj); - - mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]); - mov(reg_tile_count, ptr[reg_transp + GET_OFF(tile_count)]); - imul(reg_temp, reg_tile_count, simd_w * typesize); - add(reg_dst, reg_temp); - - Label loop_jtiles, loop_itiles, next_tile_block, next_tile; - L(loop_jtiles); - - { - L(loop_itiles); - { - load_src(); - - trans_I_3x3_4x4(); - - add(reg_tile_count, 1); - cmp(reg_tile_count, jcp.nb_tile_block_ur * jcp.tile_block_ur); - jge(next_tile_block); - add(reg_dst, simd_w * typesize); - jmp(next_tile); - - L(next_tile_block); - sub(reg_dst, (jcp.nb_tile_block_ur * jcp.tile_block_ur - 1) - * simd_w * typesize); - size_t tblk_off = alpha * alpha * jcp.ic_block - * jcp.nb_tile_block_ur * jcp.tile_block_ur - * simd_w * typesize; - add(reg_dst, tblk_off); - xor_(reg_tile_count, reg_tile_count); - - L(next_tile); - add(reg_ti, 1); - cmp(reg_ti, jcp.itiles); - jl(loop_itiles); - } - xor_(reg_ti, reg_ti); - 
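compute_transform above walks the tiles in (tj, ti) order and uses tile_count to detect when the current panel of nb_tile_block_ur * tile_block_ur tiles is full; at that point the destination pointer rewinds to the start of the panel and jumps one whole alpha x alpha x ic_block panel forward. The same index arithmetic in scalar form, using element offsets rather than bytes and starting at the beginning of a panel; walk_tiles and emit_tile are illustrative names:

#include <cstddef>

// Sketch of the tile walk: tiles go into panels of
// nb_tile_block_ur * tile_block_ur slots; once a panel is full, the write
// offset rewinds to the panel start and advances by one full panel stride.
void walk_tiles(int jtiles, int itiles, int nb_tile_block_ur, int tile_block_ur,
        int ic_block, int alpha, int simd_w,
        void (*emit_tile)(int tj, int ti, size_t dst_elem_offset)) {
    const int panel_tiles = nb_tile_block_ur * tile_block_ur;
    const size_t panel_stride = (size_t)alpha * alpha * ic_block * panel_tiles * simd_w;
    size_t dst = 0;       // element offset of the current tile slot
    int tile_count = 0;   // position inside the current panel
    for (int tj = 0; tj < jtiles; tj++)
        for (int ti = 0; ti < itiles; ti++) {
            emit_tile(tj, ti, dst);
            if (++tile_count == panel_tiles) {             // panel full:
                dst -= (size_t)(panel_tiles - 1) * simd_w; //   rewind to panel start
                dst += panel_stride;                       //   jump to the next panel
                tile_count = 0;
            } else {
                dst += simd_w;                             // next slot in this panel
            }
        }
}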
add(reg_tj, 1); - cmp(reg_tj, jcp.jtiles); - jl(loop_jtiles); - } - }; - - preamble(); - init_G(); - if (jcp.sched_policy == WSCHED_WEI_SDGtWo) - compute_transform_SDGtWo(); - else - compute_transform(); - postamble(); -} - -void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel:: -diff_dst_transform_generate(bool with_bias) { - - constexpr int G_size = 8; - auto zmm_G = [](int i) { - return Xbyak::Zmm(31); - }; - - auto zmm_src = [=](int j, int i) { - return Xbyak::Zmm(G_size + j * 4 + i); - }; - - auto zmm_bias = Xbyak::Zmm(31); - - auto load_src = [=]() { - if (with_bias) vmovups(zmm_bias, ptr[reg_bias]); - mov(reg_ydim, reg_tj); - shl(reg_ydim, 2); //tj * tile_size(=4) - for (int j = 0; j < tile_size; j++) { - /* check if tile index is within physical spatial boundaries*/ - mov(reg_maskj, 0xffff); - cmp(reg_ydim, jcp.oh); - cmovge(reg_maskj, reg_zero); - - /*address offset for tile in src*/ - mov(reg_src_offset, reg_ydim); - imul(reg_src_offset, reg_src_offset, jcp.ow); - - mov(reg_xdim, reg_ti); - shl(reg_xdim, 2); // xdim = ti * tile_size - - add(reg_src_offset, reg_xdim); - imul(reg_src_offset, reg_src_offset, simd_w * typesize); - for (int i = 0; i < tile_size; i++) { - /* check if tile index is within physical spatial boundaries*/ - mov(reg_maski, 0xffff); - cmp(reg_xdim, jcp.ow); - cmovge(reg_maski, reg_zero); - and_(reg_maski, reg_maskj); - - Opmask kmask_src = Xbyak::Opmask(7); - kmovw(kmask_src, reg_maski_32); - vpxord(zmm_src(j, i), zmm_src(j, i), zmm_src(j, i)); - vmovups(zmm_src(j, i) | kmask_src, ptr[reg_src + reg_src_offset]); - if (with_bias) vaddps(zmm_bias | kmask_src, zmm_bias, - ptr[reg_src + reg_src_offset]); - - add(reg_xdim, 1); //xdim = ti * tile_size + i - add(reg_src_offset, simd_w * typesize); - } - add(reg_ydim, 1); - } - if(with_bias) vmovups(ptr[reg_bias], zmm_bias); - }; - - auto zmm_t = [=](int i) { - return Xbyak::Zmm(G_size + 16 + i); - }; - - auto zmm_T = [=](int j, int i) { - return Xbyak::Zmm(j * 4 + i); - }; - - auto movps = [=](Xbyak::Reg64 reg_dst, size_t dst_off, Xbyak::Zmm a) { - if (jcp.sched_policy == WSCHED_WEI_SDGtWo) - vmovups(ptr[reg_dst + dst_off], a); - else - vmovntps(ptr[reg_dst + dst_off], a); - }; - - auto trans_W_3x3_4x4 = [=]() { - mov(reg_G, ptr[reg_transp + GET_OFF(G)]); - for (int i = 0; i < tile_size; i++) { - vbroadcastss(zmm_G(0), ptr[reg_G]); - vmulps(zmm_t(0), zmm_src(2, i), zmm_G(0)); - - vbroadcastss(zmm_G(1), ptr[reg_G + typesize]); - vmovups(zmm_t(1), zmm_t(0)); - vfmsub231ps(zmm_t(1), zmm_src(0, i), zmm_G(1)); - - vbroadcastss(zmm_G(2), ptr[reg_G + 2 * typesize]); - vmovups(zmm_t(2), zmm_t(0)); - vfmadd231ps(zmm_t(2), zmm_src(0, i), zmm_G(2)); - - vbroadcastss(zmm_G(3), ptr[reg_G + 3 * typesize]); - vmulps(zmm_t(3), zmm_src(1, i), zmm_G(3)); - - vbroadcastss(zmm_G(4), ptr[reg_G + 4 * typesize]); - vfmadd231ps(zmm_t(3), zmm_src(3, i), zmm_G(4)); - - vbroadcastss(zmm_G(5), ptr[reg_G + 5 * typesize]); - vmulps(zmm_t(4), zmm_src(1, i), zmm_G(5)); - - vbroadcastss(zmm_G(6), ptr[reg_G + 6 * typesize]); - vfmadd231ps(zmm_t(4), zmm_src(3, i), zmm_G(6)); - - vbroadcastss(zmm_G(7), ptr[reg_G + 7 * typesize]); - vmulps(zmm_T(0, i), zmm_src(0, i), zmm_G(7)); - vsubps(zmm_T(1, i), zmm_t(1), zmm_t(3)); - vaddps(zmm_T(2, i), zmm_t(1), zmm_t(3)); - vaddps(zmm_T(3, i), zmm_t(2), zmm_t(4)); - vsubps(zmm_T(4, i), zmm_t(2), zmm_t(4)); - vmovups(zmm_T(5, i), zmm_src(3, i)); - } - - for (int j = 0; j < alpha; j++) { - vbroadcastss(zmm_G(0), ptr[reg_G]); - vmulps(zmm_t(0), zmm_T(j, 2), zmm_G(0)); - - vbroadcastss(zmm_G(1), ptr[reg_G + 
typesize]); - vmovups(zmm_t(1), zmm_t(0)); - vfmsub231ps(zmm_t(1), zmm_T(j, 0), zmm_G(1)); - - vbroadcastss(zmm_G(2), ptr[reg_G + 2 * typesize]); - vmovups(zmm_t(2), zmm_t(0)); - vfmadd231ps(zmm_t(2), zmm_T(j, 0), zmm_G(2)); - - vbroadcastss(zmm_G(3), ptr[reg_G + 3 * typesize]); - vmulps(zmm_t(3), zmm_T(j, 1), zmm_G(3)); - - vbroadcastss(zmm_G(4), ptr[reg_G + 4 * typesize]); - vfmadd231ps(zmm_t(3), zmm_T(j, 3), zmm_G(4)); - - vbroadcastss(zmm_G(5), ptr[reg_G + 5 * typesize]); - vmulps(zmm_t(4), zmm_T(j, 1), zmm_G(5)); - - vbroadcastss(zmm_G(6), ptr[reg_G + 6 * typesize]); - vfmadd231ps(zmm_t(4), zmm_T(j, 3), zmm_G(6)); - - vbroadcastss(zmm_G(7), ptr[reg_G + 7 * typesize]); - vmulps(zmm_t(0), zmm_T(j, 0), zmm_G(7)); - vsubps(zmm_t(5), zmm_t(1), zmm_t(3)); - vaddps(zmm_t(1), zmm_t(1), zmm_t(3)); - vaddps(zmm_t(6), zmm_t(2), zmm_t(4)); - vsubps(zmm_t(2), zmm_t(2), zmm_t(4)); - vmovups(zmm_t(3), zmm_T(j, 3)); - - int alpha_offset = (jcp.oc / jcp.nb_oc) - * (jcp.ntiles / jcp.tile_block) * typesize; - int dst_off = j * alpha * alpha_offset; - movps(reg_dst, dst_off, zmm_t(0)); - dst_off += alpha_offset; - movps(reg_dst, dst_off, zmm_t(5)); - dst_off += alpha_offset; - movps(reg_dst, dst_off, zmm_t(1)); - dst_off += alpha_offset; - movps(reg_dst, dst_off, zmm_t(6)); - dst_off += alpha_offset; - movps(reg_dst, dst_off, zmm_t(2)); - dst_off += alpha_offset; - movps(reg_dst, dst_off, zmm_t(3)); - } - - }; - auto compute_transform_SDGtWo = [=]() { - mov(reg_src, ptr[reg_transp + GET_OFF(src)]); - mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]); - if (with_bias) mov(reg_bias, ptr[reg_transp + GET_OFF(bias)]); - - xor_(reg_zero, reg_zero); - xor_(reg_oc_ur, reg_oc_ur); - Label loop_mb, loop_jtiles, loop_itiles, loop_oc_ur, tiles_done; - - L(loop_oc_ur); - { - mov(reg_ti, ptr[reg_transp + GET_OFF(ti)]); - mov(reg_tj, ptr[reg_transp + GET_OFF(tj)]); - xor_(reg_tile_count, reg_tile_count); - L(loop_mb); - { - L(loop_jtiles); - { - L(loop_itiles); - { - load_src(); - - trans_W_3x3_4x4(); - - add(reg_tile_count, 1); - cmp(reg_tile_count, jcp.nb_tile_block_ur * jcp.tile_block_ur); - jge(tiles_done); - - add(reg_dst, jcp.oc_reg_block * simd_w * typesize); - add(reg_ti, 1); - cmp(reg_ti, jcp.itiles); - jl(loop_itiles); - } - xor_(reg_ti, reg_ti); - add(reg_tj, 1); - cmp(reg_tj, jcp.jtiles); - jl(loop_jtiles); - } - xor_(reg_tj, reg_tj); - add(reg_src, jcp.oc * jcp.ow * jcp.oh * typesize); - jmp(loop_mb); - } - - L(tiles_done); - mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]); - add(reg_dst, simd_w * typesize); - mov(reg_src, ptr[reg_transp + GET_OFF(src)]); - add(reg_src, jcp.oh * jcp.ow * simd_w * typesize); - - if (with_bias) add(reg_bias, simd_w * typesize); - add(reg_oc_ur, 1); - cmp(reg_oc_ur, jcp.oc_reg_block); - jl(loop_oc_ur); - } - }; - - auto compute_transform = [=]() { - mov(reg_src, ptr[reg_transp + GET_OFF(src)]); - mov(reg_G, ptr[reg_transp + GET_OFF(G)]); - if (with_bias) mov(reg_bias, ptr[reg_transp + GET_OFF(bias)]); - - mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]); - mov(reg_tile_count, ptr[reg_transp + GET_OFF(tile_count)]); - imul(reg_temp, reg_tile_count, jcp.oc_reg_block * simd_w * typesize); - add(reg_dst, reg_temp); - - xor_(reg_zero, reg_zero); - xor_(reg_oc_ur, reg_oc_ur); - Label loop_mb, loop_jtiles, loop_itiles, loop_oc_ur, next_tile_block, next_tile; - - L(loop_oc_ur); - { - xor_(reg_ti, reg_ti); - xor_(reg_tj, reg_tj); - - L(loop_jtiles); - { - L(loop_itiles); - { - load_src(); - - trans_W_3x3_4x4(); - - add(reg_tile_count, 1); - cmp(reg_tile_count, jcp.nb_tile_block_ur * 
jcp.tile_block_ur); - jge(next_tile_block); - add(reg_dst, jcp.oc_reg_block * simd_w * typesize); - jmp(next_tile); - - L(next_tile_block); - sub(reg_dst, (jcp.nb_tile_block_ur * jcp.tile_block_ur - 1) - * jcp.oc_reg_block * simd_w * typesize); - int tblk_off = alpha * alpha * (jcp.oc/jcp.nb_oc) - * (jcp.ntiles/jcp.tile_block) * typesize; - add(reg_dst, tblk_off); - xor_(reg_tile_count, reg_tile_count); - - L(next_tile); - add(reg_ti, 1); - cmp(reg_ti, jcp.itiles); - jl(loop_itiles); - } - xor_(reg_ti, reg_ti); - add(reg_tj, 1); - cmp(reg_tj, jcp.jtiles); - jl(loop_jtiles); - } - - mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]); - mov(reg_tile_count, ptr[reg_transp + GET_OFF(tile_count)]); - imul(reg_temp, reg_tile_count, jcp.oc_reg_block * simd_w * typesize); - add(reg_dst, reg_temp); - add(reg_dst, simd_w * typesize); - mov(reg_src, ptr[reg_transp + GET_OFF(src)]); - add(reg_src, jcp.oh * jcp.ow * simd_w * typesize); - - if (with_bias) add(reg_bias, simd_w * typesize); - add(reg_oc_ur, 1); - cmp(reg_oc_ur, jcp.oc_reg_block); - jl(loop_oc_ur); - } - }; - - preamble(); - if (jcp.sched_policy == WSCHED_WEI_SDGtWo) { - compute_transform_SDGtWo(); - } else { - compute_transform(); - } - postamble(); -} - -void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel:: -diff_weights_transform_generate(bool first_tile) { - int G_size = 4; - - auto zmm_G = [](int i) { - return Xbyak::Zmm(i); - }; - - auto init_G = [=]() { - mov(reg_G, ptr[reg_transp + GET_OFF(G)]); - for (int i = 0; i < G_size; i++) - vbroadcastss(zmm_G(i), ptr[reg_G + i * typesize]); - }; - - auto zmm_src = [=](int i) { - return Xbyak::Zmm(G_size + i); - }; - - auto load_src = [=](int i) { - for (int j = 0; j < alpha; j++) { - size_t alpha_offset = jcp.oc_block * jcp.oc_reg_block - * jcp.ic_block * simd_w * simd_w * typesize; - size_t src_off = (j * alpha + i) * alpha_offset; - vmovups(zmm_src(j), EVEX_compress_addr(reg_src, src_off)); - } - }; - - auto zmm_t = [=](int i) { - return Xbyak::Zmm(G_size + 6 + i); - }; - - auto zmm_T = [=](int j, int i) { - return Xbyak::Zmm(G_size + 6 + 3 + j * 6 + i); - }; - - auto zmm_dst = [=](int i) { - return Xbyak::Zmm(G_size + i); - }; - - auto zmm_temp = Xbyak::Zmm(31); - - auto store_dst = [=](int j) { - for (int i = 0; i < jcp.kw; i++) { - size_t dst_off = (j * jcp.kw + i) * simd_w * simd_w * typesize; - - if (!first_tile) { - vmovups(zmm_temp, EVEX_compress_addr(reg_dst, dst_off)); - vaddps(zmm_dst(i), zmm_dst(i), zmm_temp); - } - vmovntps(EVEX_compress_addr(reg_dst, dst_off), zmm_dst(i)); - } - }; - - auto compute_transform = [=] () { - mov(reg_src, ptr[reg_transp + GET_OFF(src)]); - mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]); - - xor_(reg_ic_simd, reg_ic_simd); - Label loop_ic_simd; - L(loop_ic_simd); - { - for (int i = 0; i < alpha; i++) { - load_src(i); - - vaddps(zmm_t(0), zmm_src(1), zmm_src(2)); - vaddps(zmm_t(1), zmm_src(3), zmm_src(4)); - vmovups(zmm_t(2), zmm_src(5)); - vfmadd231ps(zmm_t(2), zmm_t(1), zmm_G(0)); - - vaddps(zmm_T(0, i), zmm_src(0), zmm_t(0)); - vaddps(zmm_T(0, i), zmm_T(0, i), zmm_t(1)); - vsubps(zmm_T(1, i), zmm_src(1), zmm_src(2)); - vmulps(zmm_T(1, i), zmm_T(1, i), zmm_G(1)); - vsubps(zmm_temp, zmm_src(3), zmm_src(4)); - vfmadd231ps(zmm_T(1, i), zmm_temp, zmm_G(2)); - vmovups(zmm_T(2, i), zmm_t(2)); - vfmadd231ps(zmm_T(2, i), zmm_t(0), zmm_G(3)); - } - - for (int j = 0; j < jcp.kh; j++) { - vaddps(zmm_t(0), zmm_T(j, 1), zmm_T(j, 2)); - vaddps(zmm_t(1), zmm_T(j, 3), zmm_T(j, 4)); - vmovups(zmm_t(2), zmm_T(j, 5)); - vfmadd231ps(zmm_t(2), zmm_t(1), zmm_G(0)); - - 
vaddps(zmm_dst(0), zmm_T(j, 0), zmm_t(0)); - vaddps(zmm_dst(0), zmm_dst(0), zmm_t(1)); - vsubps(zmm_dst(1), zmm_T(j, 1), zmm_T(j, 2)); - vmulps(zmm_dst(1), zmm_dst(1), zmm_G(1)); - vsubps(zmm_temp, zmm_T(j, 3), zmm_T(j, 4)); - vfmadd231ps(zmm_dst(1), zmm_temp, zmm_G(2)); - vmovups(zmm_dst(2), zmm_t(2)); - vfmadd231ps(zmm_dst(2), zmm_t(0), zmm_G(3)); - - store_dst(j); - } - - add(reg_src, jcp.oc_reg_block * simd_w * typesize); - add(reg_dst, simd_w * typesize); - add(reg_ic_simd, 1); - cmp(reg_ic_simd, simd_w); - jl(loop_ic_simd); - } - }; - preamble(); - push(reg_EVEX_max_8b_offt); - mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt); - init_G(); - compute_transform(); - pop(reg_EVEX_max_8b_offt); - postamble(); -} - -void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel::gemm_loop_generate( - bool is_first_tile) -{ - auto zmm_srcA = [=]() { - return Xbyak::Zmm(0); - }; - - auto zmm_srcB = [=] (size_t N_ur){ - return Xbyak::Zmm(N_ur + 1); - }; - - auto broadcastB = [=](size_t K_ur) { - for (int N_bcast = 0; N_bcast < jcp.dimN_bcast_ur; N_bcast++) { - size_t srcB_off = (K_ur * jcp.dimN_reg_block + N_bcast) - * sizeof(float); - vbroadcastss(zmm_srcB(N_bcast), EVEX_compress_addr(reg_srcB, srcB_off)); - } - }; - - auto load_srcA = [=] (size_t K_ur, int M_ur) { - size_t srcA_off = (K_ur * jcp.dimM_reg_block * jcp.dimM_simd_block - + M_ur * jcp.dimM_simd_block) * sizeof(float); - vmovups(zmm_srcA(), EVEX_compress_addr(reg_srcA, srcA_off)); - }; - - auto zmm_dstC = [=](size_t M_reg_ur, int N_bcast){ - size_t idx = 1 // zmm_srcA - + jcp.dimN_bcast_ur // zmm_srcB - + M_reg_ur * jcp.dimN_bcast_ur + N_bcast; - assert(idx < 32); - return Xbyak::Zmm(idx); - }; - auto prepare_accumm = [=](){ - for (int M_reg_ur = 0; M_reg_ur < jcp.dimM_reg_block; M_reg_ur++) { - for (int N_bcast = 0; N_bcast < jcp.dimN_bcast_ur; N_bcast++) { - Zmm zmm = zmm_dstC(M_reg_ur, N_bcast); - vpxord(zmm, zmm, zmm); - } - } - }; - - auto store_dstC = [=](){ - /******** Write C back to memory *******/ - for (int M_reg = 0; M_reg < jcp.dimM_reg_block; M_reg++) { - for (int N_ur = 0; N_ur < jcp.dimN_bcast_ur; ++N_ur) { - Zmm zmm = zmm_dstC(M_reg, N_ur); - size_t C_off = (N_ur * jcp.dimM_reg_block * jcp.dimM_simd_block - + M_reg * jcp.dimM_simd_block) * sizeof(float); - if (!is_first_tile) { - vmovups(Xbyak::Zmm(0), EVEX_compress_addr(reg_dstC, C_off)); - vaddps(zmm, zmm, Xbyak::Zmm(0)); - } - vmovups(EVEX_compress_addr(reg_dstC, C_off), zmm); - } - } - }; - - auto inner_loops = [=]() { - Label dimM_block_loop, dimK_block_loop, dimN_block_loop, dimN_bcast_ur; - - mov(reg_dimM_block_loop_cnt, jcp.dimM_block); - L(dimM_block_loop); - { /************* OC_block (M) loop ***********/ - mov(reg_dimN_block_loop_cnt, jcp.dimN_block); - L(dimN_block_loop); - { /*************** IC_block (N) loop *********/ - - mov(reg_nb_dimN_bcast_ur, jcp.dimN_reg_block/jcp.dimN_bcast_ur); - L(dimN_bcast_ur); - { - prepare_accumm(); - - mov(reg_dimK_block_loop_cnt, jcp.dimK_block); - L(dimK_block_loop); - { - /************* nb_tile_ur(K) loop ********/ - for (int K_ur = 0; K_ur < jcp.dimK_reg_block; K_ur++) { - - broadcastB(K_ur); - - for (int M_reg_ur = 0; M_reg_ur < jcp.dimM_reg_block; M_reg_ur++) { - load_srcA(K_ur, M_reg_ur); - for (int N_bcast = 0; N_bcast < jcp.dimN_bcast_ur; ++N_bcast) { - vfmadd231ps(zmm_dstC(M_reg_ur, N_bcast), zmm_srcA(), - zmm_srcB(N_bcast)); - } - } - } - add(reg_srcA, jcp.dimK_reg_block - * jcp.dimM_reg_block * jcp.dimM_simd_block - * sizeof(float)); - add(reg_srcB, jcp.dimK_reg_block - * jcp.dimN_reg_block - * sizeof(float)); 
- sub(reg_dimK_block_loop_cnt, 1); - jnz(dimK_block_loop); - } - - store_dstC(); - - sub(reg_srcA, jcp.dimK_block * jcp.dimK_reg_block - * jcp.dimM_reg_block * jcp.dimM_simd_block - * sizeof(float)); - sub(reg_srcB, jcp.dimK_block * jcp.dimK_reg_block - * jcp.dimN_reg_block - * sizeof(float)); - add(reg_srcB, jcp.dimN_bcast_ur * sizeof(float)); - add(reg_dstC, jcp.dimN_bcast_ur - * jcp.dimM_reg_block * jcp.dimM_simd_block - * sizeof(float)); - sub(reg_nb_dimN_bcast_ur, 1); - jnz(dimN_bcast_ur); - } - - sub(reg_srcB, jcp.dimN_reg_block * sizeof(float)); - add(reg_srcB, jcp.dimK_block - * jcp.dimK_reg_block - * jcp.dimN_reg_block * sizeof(float)); - sub(reg_dimN_block_loop_cnt, 1); - jnz(dimN_block_loop); - } - - sub(reg_srcB, jcp.dimN_block - * jcp.dimK_block * jcp.dimK_reg_block - * jcp.dimN_reg_block - * sizeof(float)); - add(reg_srcA, jcp.dimK_block * jcp.dimK_reg_block - * jcp.dimM_reg_block * jcp.dimM_simd_block - * sizeof(float)); - sub(reg_dimM_block_loop_cnt, 1); - jnz(dimM_block_loop); - } - }; - - /* Preamble */ - preamble(); - - inner_loops(); - - /* Postamble */ - postamble(); - ret(); -} - -namespace { - -void set_jcp_WEI_params(jit_conv_winograd_conf_t &jcp) { -/*M params*/ - jcp.dimM_nb_block = jcp.dimM / jcp.dimM_block / jcp.dimM_reg_block - / jcp.dimM_simd_block; - jcp.oc_reg_block = jcp.dimM_reg_block; - jcp.oc_block = jcp.dimM_block; - jcp.nb_oc = jcp.dimM_nb_block; - /*N params*/ - jcp.dimN_nb_block = jcp.dimN / jcp.dimN_block / jcp.dimN_reg_block; - jcp.ic_block = jcp.dimN_block; - jcp.nb_ic = jcp.dimN_nb_block; - - /*K params*/ - jcp.dimK_nb_block = jcp.dimK / jcp.dimK_block / jcp.dimK_reg_block; - jcp.tile_block_ur = jcp.dimK_reg_block; - jcp.nb_tile_block_ur = jcp.dimK_block; - jcp.tile_block = jcp.dimK_nb_block; -} - -status_t set_wsched_WEI_SDGtWo(jit_conv_winograd_conf_t &jcp) { - - size_t K_blk_ur, N_blk, M_blk; - /* IS this strategy feasible? 
*/ - auto test_MV_large_enough = [](jit_conv_winograd_conf_t &jcp) { - size_t M_sz = alpha * alpha * jcp.dimM * jcp.dimK * sizeof(float); - size_t V_sz = alpha * alpha * jcp.dimN * jcp.dimK * sizeof(float); - size_t nthreads = mkldnn_get_max_threads(); - return (((V_sz + M_sz) / nthreads) >= 2 * L2_cache_size) - && (jcp.dimK / nthreads >= 1.0); - }; - - auto test_min_dimK_L1 = [](jit_conv_winograd_conf_t &jcp, int dimK_block_ur, - int max_block=1) { - size_t L1_block_M = jcp.dimM_reg_block * jcp.dimM_simd_block * dimK_block_ur * sizeof(float); - size_t L1_block_N = jcp.dimN_reg_block * dimK_block_ur * sizeof(float); - size_t M_L2_block = alpha * alpha * jcp.dimM * dimK_block_ur * sizeof(float); - size_t nthreads = mkldnn_get_max_threads(); - bool load_balance=true; - if (!(jcp.dimK % nthreads)) { - load_balance = ((jcp.dimK / dimK_block_ur) % nthreads == 0); - } - return (L1_block_M + L1_block_N >= 0.1 * L1_cache_size) - && (L1_block_M + L1_block_N <= 0.5 * L1_cache_size) - && load_balance - && (M_L2_block < L2_cache_size); - }; - - auto test_dimK_ur = [](jit_conv_winograd_conf_t &jcp, int dimK_ur, - int useless_arg=0) { - return (dimK_ur >= 2) && (dimK_ur <= 8); - }; - - auto blocking_ok = [&](){ - size_t M_L2_block = alpha * alpha * M_blk * jcp.dimM_reg_block * jcp.dimM_simd_block - * K_blk_ur * sizeof(float); - size_t V_L2_block = alpha * alpha * N_blk * jcp.dimN_reg_block - * K_blk_ur * sizeof(float); - size_t U_L2_block = alpha * alpha * M_blk * jcp.dimM_reg_block * jcp.dimM_simd_block - * N_blk * jcp.dimN_reg_block * sizeof(float); - size_t L2_block = M_L2_block + V_L2_block + U_L2_block; - /*Replace 2.375 with L2+L3 cache size*/ - return (L2_block > 0.1 * L2_cache_size) && (L2_block <= 1.2 * L2_cache_size); - }; - - if (test_MV_large_enough(jcp)) { - if ((jcp.dimM/jcp.dimM_simd_block) % 2 == 0) { - jcp.dimM_reg_block = 2; - } else { - jcp.dimM_reg_block = 1; - } - jcp.dimM_simd_block = jcp.oc_simd_block; - jcp.dimN_reg_block = jcp.ic_simd_block; - jcp.dimN_bcast_ur = 8; - /*dimK_block and dimK_ur*/ - size_t min_dimK_block_ur = get_divisor_satisfying_cond(jcp, jcp.dimK, 1, test_min_dimK_L1); - - jcp.dimM_block = jcp.dimM/jcp.dimM_reg_block/jcp.dimM_simd_block; - jcp.dimN_block = jcp.dimN/jcp.dimN_reg_block; - for (K_blk_ur = min_dimK_block_ur; K_blk_ur >= 1; --K_blk_ur) { - if (test_min_dimK_L1(jcp, K_blk_ur) && !(jcp.dimK % K_blk_ur)) { - for (N_blk = jcp.dimN_block; N_blk >= 1; --N_blk) { - if (!(jcp.dimN_block % N_blk)) { - for (M_blk = jcp.dimM_block; M_blk >= 1; --M_blk) { - if (!(jcp.dimM_block % M_blk) && blocking_ok()) { - jcp.dimK_reg_block = get_divisor_satisfying_cond(jcp, K_blk_ur, 1, test_dimK_ur); - if (!test_dimK_ur(jcp, jcp.dimK_reg_block)) return status::unimplemented; - jcp.dimK_block = K_blk_ur / jcp.dimK_reg_block; - jcp.dimN_block = N_blk; - jcp.dimM_block = M_blk; - jcp.sched_policy = WSCHED_WEI_SDGtWo; - set_jcp_WEI_params(jcp); - jcp.nthr = nstl::min(mkldnn_get_max_threads(), - jcp.tile_block); - return status::success; - } - } - } - } - } - } - } - return status::unimplemented; -} - -status_t set_wsched_WEI_S_D_Giot_W(jit_conv_winograd_conf_t &jcp) { - if ((jcp.dimM/jcp.dimM_simd_block) % 2 == 0) { - jcp.dimM_reg_block = 2; - } else { - jcp.dimM_reg_block = 1; - } - jcp.dimN_bcast_ur = 8; - jcp.dimN_reg_block = jcp.ic_simd_block; - jcp.dimM_simd_block = jcp.oc_simd_block; - jcp.dimN_block = jcp.dimN / jcp.dimN_reg_block; - jcp.dimM_block = jcp.dimM / jcp.dimM_reg_block / jcp.dimM_simd_block; - float C1 = 0.0, C2 = 0.0; - float C1_max = 0.5, C2_max = 1.4; - 
int N_blk, M_blk, K_blk_ur; - - auto test_dimK_ur = [](jit_conv_winograd_conf_t &jcp, int dimK_ur, - int useless_arg=0) { - return (dimK_ur >= 2) && (dimK_ur <= 8); - }; - - auto blocking_ok = [&]() -> bool { - size_t L1_block_M = jcp.dimM_reg_block * jcp.dimM_simd_block * K_blk_ur * sizeof(float); - size_t L1_block_N = jcp.dimN_reg_block * K_blk_ur * sizeof(float); - bool L1_cond = ((L1_block_N + L1_block_M) >= C1 * L1_cache_size) - && ((L1_block_N + L1_block_M) <= C1_max * L1_cache_size); - - size_t nb_N_blk = jcp.dimN/N_blk/jcp.dimN_reg_block; - size_t nb_M_blk = jcp.dimM/M_blk/jcp.dimM_reg_block/jcp.dimM_simd_block; - size_t nb_K_blk = jcp.dimK / K_blk_ur; - size_t nthreads = mkldnn_get_max_threads(); - bool load_balance = (nb_K_blk * nb_N_blk * nb_M_blk) >= nthreads; - if (!(nb_K_blk % nthreads)) { - load_balance = load_balance && (nb_K_blk % nthreads == 0); - } - - size_t V_L2_block = alpha * alpha * N_blk * jcp.dimN_reg_block * K_blk_ur * sizeof(float); - - size_t L2_block = V_L2_block; - /*Replace 2.375 with L2+L3 cache size*/ - bool L2_cond = (L2_block >= C2 * L2_cache_size) && (L2_block <= C2_max * L2_cache_size); - return L1_cond && load_balance && L2_cond; - }; - - for (K_blk_ur = jcp.dimK; K_blk_ur >= 1; --K_blk_ur) { - if (jcp.dimK % K_blk_ur == 0) { - for (N_blk = jcp.dimN_block; N_blk >= 1; --N_blk) { - if (jcp.dimN_block % N_blk == 0) { - for (M_blk = jcp.dimM_block; M_blk >= 1; --M_blk) { - if (jcp.dimM_block % M_blk == 0) { - if (blocking_ok()) { - jcp.dimN_block = N_blk; - jcp.dimM_block = M_blk; - jcp.dimK_reg_block = get_divisor_satisfying_cond(jcp, K_blk_ur, 1, test_dimK_ur); - jcp.dimK_block = K_blk_ur / jcp.dimK_reg_block; - jcp.sched_policy = WSCHED_WEI_S_D_Giot_W; - set_jcp_WEI_params(jcp); - return status::success; - } - } - } - } - } - } - } - jcp.dimK_reg_block = 1; - jcp.dimK_block = 1; - jcp.sched_policy = WSCHED_WEI_S_D_Giot_W; - set_jcp_WEI_params(jcp); - return status::success; -} -} // namespace -status_t jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel::init_conf( - jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd, - const memory_desc_wrapper &src_d, const memory_desc_wrapper &diff_dst_d, - const memory_desc_wrapper &diff_weights_d) { - if (!mayiuse(avx512_core)) - return status::unimplemented; - else - jcp.ver = ver_avx512_core; - - jcp.nthr = mkldnn_get_max_threads(); - - jcp.prop_kind = cd.prop_kind; - const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1; - jcp.mb = src_d.dims()[0]; - jcp.ngroups = with_groups ? 
diff_weights_d.dims()[0] : 1; - jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups; - jcp.oc_without_padding = jcp.oc; - jcp.ic = src_d.dims()[1] / jcp.ngroups; - jcp.ih = src_d.dims()[2]; - jcp.iw = src_d.dims()[3]; - jcp.oh = diff_dst_d.dims()[2]; - jcp.ow = diff_dst_d.dims()[3]; - jcp.kh = diff_weights_d.dims()[with_groups + 2]; - jcp.kw = diff_weights_d.dims()[with_groups + 3]; - jcp.t_pad = cd.padding[0][0]; - jcp.l_pad = cd.padding[0][1]; - jcp.stride_h = cd.strides[0]; - jcp.stride_w = cd.strides[1]; - jcp.r_pad = nstl::max( - 0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad); - jcp.b_pad = nstl::max( - 0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad); - jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad; - jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad; - jcp.ohp = jcp.oh; - jcp.owp = jcp.ow; - jcp.with_bias = (cd.diff_bias_desc.format_kind != format_kind::undef); - jcp.dilate_h = cd.dilates[0]; - jcp.dilate_w = cd.dilates[1]; - - bool ok_to_pad_channels = jcp.ngroups == 1; - if (ok_to_pad_channels) { - jcp.oc = rnd_up(jcp.oc, simd_w); - jcp.ic = rnd_up(jcp.ic, simd_w); - } - - // Winograd specific initialization - jcp.itiles = (jcp.ow + tile_size - 1) / tile_size; - jcp.jtiles = (jcp.oh + tile_size - 1) / tile_size; - jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles; - - // Winograd kernel works only for 3x3 convolution with stride 1 - if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto, - is_winograd_faster_than_direct(jcp))) - return status::unimplemented; - - if (jcp.ngroups != 1) - return status::unimplemented; - if ((jcp.kh != 3) || (jcp.kw != 3)) - return status::unimplemented; - if ((jcp.dilate_h != 0) || (jcp.dilate_w != 0)) - return status::unimplemented; - if ((jcp.stride_h != 1) || (jcp.stride_w != 1)) - return status::unimplemented; - if ((jcp.ic % simd_w) != 0 || (jcp.oc % simd_w) != 0) - return status::unimplemented; - - format_tag_t dat_tag = nChw16c; - format_tag_t wei_tag = with_groups ? 
gOIhw16i16o : OIhw16i16o; - jcp.src_tag = src_d.matches_one_of_tag(dat_tag); - jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag); - jcp.dst_tag = diff_dst_d.matches_one_of_tag(dat_tag); - - if (jcp.src_tag != dat_tag) return status::unimplemented; - if (jcp.wei_tag != wei_tag) return status::unimplemented; - if (jcp.dst_tag != dat_tag) return status::unimplemented; - - bool layout_consistency = true - && jcp.ic <= src_d.padded_dims()[1] - && jcp.oc <= diff_dst_d.padded_dims()[1] - && jcp.ic <= diff_weights_d.padded_dims()[with_groups + 1] - && jcp.oc <= diff_weights_d.padded_dims()[with_groups + 0]; - if (!layout_consistency) return status::unimplemented; - - /******************Kernel blocking Parameters ***********/ - jcp.ic_simd_block = simd_w; - jcp.oc_simd_block = simd_w; - - jcp.dimK = jcp.ntiles; - jcp.dimN = jcp.ic; - jcp.dimM = jcp.oc; - jcp.dimM_simd_block = jcp.oc_simd_block; - jcp.dimN_reg_block = jcp.ic_simd_block; - jcp.sched_policy = WSCHED_INVALID; - status_t res = set_wsched_WEI_SDGtWo(jcp); - if (res == status::unimplemented) { - res = set_wsched_WEI_S_D_Giot_W(jcp); - assert(res == status::success); - } - return res; -} -} -} -} - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.hpp deleted file mode 100644 index 025a554d9..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.hpp +++ /dev/null @@ -1,291 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#ifndef JIT_AVX512_CORE_FP32_WINO_CONV_4x3_KERNEL_HPP -#define JIT_AVX512_CORE_FP32_WINO_CONV_4x3_KERNEL_HPP - -#include "c_types_map.hpp" - -#include "jit_generator.hpp" -#include "jit_primitive_conf.hpp" - -#include "jit_avx512_common_conv_winograd_kernel_f32.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -struct _jit_avx512_core_fp32_wino_conv_4x3_data_kernel - : public jit_generator { - _jit_avx512_core_fp32_wino_conv_4x3_data_kernel( - jit_conv_winograd_conf_t ajcp) - : jcp(ajcp) { - { - this->weights_transform_data_ker_generate(); - weights_transform_data_ker - = (decltype(weights_transform_data_ker)) this->getCode(); - } - { - align(); - const Xbyak::uint8 *addr = getCurr(); - this->input_transform_data_ker_generate(); - input_transform_data_ker = (decltype(input_transform_data_ker))addr; - } - { - align(); - const Xbyak::uint8 *addr = getCurr(); - this->output_transform_data_ker_generate(); - output_transform_data_ker - = (decltype(output_transform_data_ker))addr; - } - { - align(); - const Xbyak::uint8 *addr = getCurr(); - this->gemm_loop_generate(); - gemm_loop_ker = (decltype(gemm_loop_ker))addr; - } - } - - DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_core_fp32_wino_conv_4x3_data_kernel) - - static status_t init_conf_common(jit_conv_winograd_conf_t &jcp, - const convolution_desc_t &cd, const memory_desc_wrapper &src_d, - const memory_desc_wrapper &weights_d, - const memory_desc_wrapper &dst_d); - - static status_t init_conf_kernel( - jit_conv_winograd_conf_t &jcp, int dimM, int dimN, int dimK); - - jit_conv_winograd_conf_t jcp; - void (*gemm_loop_ker)(float *, const float *, const float *, const int); - void (*input_transform_data_ker)(jit_wino_transform_call_s *); - void (*output_transform_data_ker)(jit_wino_transform_call_s *); - void (*weights_transform_data_ker)(jit_wino_transform_call_s *); - -protected: - using reg64_t = const Xbyak::Reg64; - using reg32_t = const Xbyak::Reg32; - enum { typesize = sizeof(float) }; - - void gemm_loop_generate(); - void input_transform_data_ker_generate(); - void output_transform_data_ker_generate(); - void weights_transform_data_ker_generate(); - - /* registers used for GEMM */ - reg64_t reg_dstC = abi_param1; - reg64_t reg_srcA = abi_param2; - reg64_t reg_srcB = abi_param3; - reg64_t reg_is_beta_zero = abi_param4; - - reg64_t reg_dimM_block_loop_cnt = r10; - reg64_t reg_dimK_block_loop_cnt = r11; - - /* registers used for transforms*/ - reg64_t param = abi_param1; - - /* registers used for output_transform_data_ker */ - reg64_t oreg_temp = abi_not_param1; - reg64_t oreg_Ow = r9; - reg64_t oreg_src = r11; - reg64_t oreg_tile_block = r12; - reg64_t oreg_tile_block_ur = r13; - reg64_t oreg_nb_tile_block_ur = r14; - reg64_t oreg_O = r8; - reg64_t oreg_T = r10; - reg64_t oreg_dst = r11; - reg64_t oreg_ydim = r14; - reg64_t oreg_xdim = r15; - reg64_t oreg_out_j = r12; - reg64_t oreg_bias = rbx; - reg64_t imm_addr64 = rax; - - /* registers used for input_transform_data_ker */ - reg64_t ireg_temp = abi_not_param1; - reg64_t ireg_jtiles = rax; - reg64_t ireg_itiles = rbx; - reg64_t ireg_I = r8; - reg64_t ireg_src = r13; - reg64_t ireg_ydim = r14; - reg64_t ireg_xdim = r15; - reg64_t ireg_inp_j = r12; - reg64_t ireg_inp_i = rdx; - reg64_t ireg_mask_j = r11; - reg64_t ireg_mask = rsi; - reg32_t ireg_mask_32 = esi; - reg64_t ireg_zero = r9; - reg64_t ireg_Iw = r9; - reg64_t ireg_T = r10; - reg64_t ireg_tile_block = r12; - reg64_t ireg_tile_block_ur = 
r13; - reg64_t ireg_nb_tile_block_ur = r14; - reg64_t ireg_output = r15; - - /* registers used for wei transform */ - reg64_t wreg_temp = abi_not_param1; - reg64_t wreg_F = r8; - reg64_t wreg_src = r9; - reg64_t wreg_MT = r15; - reg64_t wreg_M = r14; - reg64_t wreg_dst = r10; - reg64_t wreg_dst_aux = r9; - reg64_t wreg_dst_idx = r8; - reg64_t wreg_Fw = r11; - reg64_t wreg_T = r12; - reg64_t wreg_cnt_j = rdx; - reg64_t wreg_F_aux = r14; - reg64_t wreg_Fw_aux = r15; -}; - -struct jit_avx512_core_fp32_wino_conv_4x3_fwd_kernel - : _jit_avx512_core_fp32_wino_conv_4x3_data_kernel { - using _jit_avx512_core_fp32_wino_conv_4x3_data_kernel:: - _jit_avx512_core_fp32_wino_conv_4x3_data_kernel; - - static bool post_ops_ok(jit_conv_conf_t &jcp, const primitive_attr_t &attr); - - static status_t init_conf(jit_conv_winograd_conf_t &jcp, - const convolution_desc_t &cd, const memory_desc_t &src_md, - memory_desc_t &weights_md, const memory_desc_t &dst_md, - const primitive_attr_t &attr); -}; - -struct jit_avx512_core_fp32_wino_conv_4x3_bwd_data_kernel - : public _jit_avx512_core_fp32_wino_conv_4x3_data_kernel { - using _jit_avx512_core_fp32_wino_conv_4x3_data_kernel:: - _jit_avx512_core_fp32_wino_conv_4x3_data_kernel; - - static status_t init_conf(jit_conv_winograd_conf_t &jcp, - const convolution_desc_t &cd, const memory_desc_wrapper &diff_src_d, - const memory_desc_wrapper &weights_d, - const memory_desc_wrapper &diff_dst_d); -}; - -struct jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel - : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS( - _jit_avx512_core_conv_winograd_bwd_weights_kernel_f32) - - jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel( - jit_conv_winograd_conf_t ajcp) - : jcp(ajcp) - { - //******************* First iter kernel ********************// - this->gemm_loop_generate(true); - gemm_loop_ker_first_iter = (decltype(gemm_loop_ker_first_iter))this->getCode(); - - align(); - const Xbyak::uint8 *addr = getCurr(); - this->src_transform_generate(); - src_transform = (decltype(src_transform))addr; - - if (jcp.with_bias) { - align(); - addr = getCurr(); - this->diff_dst_transform_generate(true); - diff_dst_transform_wbias = (decltype(diff_dst_transform_wbias))addr; - } - - align(); - addr = getCurr(); - this->diff_dst_transform_generate(false); - diff_dst_transform = (decltype(diff_dst_transform))addr; - - if (jcp.sched_policy != WSCHED_WEI_SDGtWo && jcp.tile_block > 1) { - align(); - addr = getCurr(); - this->gemm_loop_generate(false); - gemm_loop_ker = (decltype(gemm_loop_ker))addr; - } - - align(); - addr = getCurr(); - this->diff_weights_transform_generate(true); - diff_weights_transform = (decltype(diff_weights_transform))addr; - - if (jcp.sched_policy == WSCHED_WEI_SDGtWo) { - align(); - addr = getCurr(); - this->diff_weights_transform_generate(false); - diff_weights_transform_accum = - (decltype(diff_weights_transform_accum))addr; - }; - } - - static status_t init_conf(jit_conv_winograd_conf_t &jcp, - const convolution_desc_t &cd, const memory_desc_wrapper &src_d, - const memory_desc_wrapper &diff_dst_d, - const memory_desc_wrapper &diff_weights_d); - - jit_conv_winograd_conf_t jcp; - void (*gemm_loop_ker)(float *, const float *, const float *); - void (*gemm_loop_ker_first_iter)(float *, const float *, const float *); - void (*src_transform)(jit_wino_transform_call_s *); - void (*diff_dst_transform)(jit_wino_transform_call_s *); - void (*diff_dst_transform_wbias)(jit_wino_transform_call_s *); - void (*diff_weights_transform)(jit_wino_transform_call_s *); - void 
(*diff_weights_transform_accum)(jit_wino_transform_call_s *); - -private: - using reg64_t = const Xbyak::Reg64; - using reg32_t = const Xbyak::Reg32; - enum { typesize = sizeof(float) }; - - void src_transform_generate(); - void diff_dst_transform_generate(bool with_bias); - void diff_weights_transform_generate(bool first_tile); - - /*registers common to transforms*/ - reg64_t reg_transp = abi_param1; - reg64_t reg_ti = rbx; - reg64_t reg_tj = abi_not_param1; - reg64_t reg_src = r8; - reg64_t reg_dst = r9; - reg64_t reg_G = rsi; /*TODO: check if this is ok*/ - reg64_t reg_temp = rsi; - - /*registers common to src/diff_dst transform*/ - reg64_t reg_I = r10; - reg64_t reg_ydim = r11; - reg64_t reg_xdim = r12; - reg64_t reg_src_offset = r13; - reg64_t reg_zero = r14; - reg64_t reg_tile_count = r15; - reg64_t reg_maski = rsi; - reg32_t reg_maski_32 = esi; - reg64_t reg_maskj = rdx; - - reg64_t reg_T = rax; - reg64_t reg_oc_ur = rax; - reg64_t reg_ic_simd = r14; - reg64_t reg_bias = r10; - - void gemm_loop_generate(bool is_first_tile); - - reg64_t reg_dstC = abi_param1; - reg64_t reg_srcA = abi_param2; - reg64_t reg_srcB = abi_param3; - - reg64_t reg_dimM_block_loop_cnt = r9; - reg64_t reg_dimN_block_loop_cnt = r10; - reg64_t reg_nb_dimN_bcast_ur = r11; - reg64_t reg_dimK_block_loop_cnt = r12; -}; -} -} -} - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.cpp deleted file mode 100644 index 002010ffa..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.cpp +++ /dev/null @@ -1,1284 +0,0 @@ -/******************************************************************************* - * Copyright 2018 Intel Corporation - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- *******************************************************************************/ - -#include - -#include "c_types_map.hpp" -#include "memory_tracking.hpp" -#include "mkldnn_thread.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -#include "jit_avx512_core_u8s8s32x_wino_convolution.hpp" -#include "jit_generator.hpp" - -#include - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace mkldnn::impl::memory_tracking::names; -using namespace mkldnn::impl::utils; -using namespace Xbyak; - -namespace { - // Below scales are applied to source and weights data accordingly - // because this winograd implementation - // transforms source which may increase values up to 4x - // and transforms weights which may increase values up to 9/4x - const float adj_src_scale = 1.f / 4.f; - const float adj_wei_scale = 4.f / 9.f; - // Winograd transforms need ic and oc to be multiples of 16 - const int load_block = 16; -} - -/// SRC TRANSFORMS ///////////////////////////////////////////////////////////// -struct jit_avx512_core_u8s8s32x_wino_conv_src_trans_t: public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS( - jit_avx512_core_u8s8s32x_wino_conv_src_trans_t) - - jit_conv_conf_2x3_wino_t jcp; - const primitive_attr_t &attr_; - - struct call_params_t { - const void *src; - const void *wino_src; - const void *v_y_masks; - const void *v_x_masks; - }; - void (*ker_)(const call_params_t *); - - jit_avx512_core_u8s8s32x_wino_conv_src_trans_t( - jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr) - : jcp(ajcp), attr_(attr), unsign_val_in_wino_domain(5) { - generate(); - ker_ = reinterpret_cast(const_cast(getCode())); - } - void generate(); - - int reg_inp_ind(int i) { - assert(i < jcp.alpha * jcp.alpha); - return (31 - i); - } - - Xmm vreg_inp(int i) { - return Xmm(reg_inp_ind(i)); - } - - Zmm zmm_inp(int i) { - return Zmm(reg_inp_ind(i)); - } - - Xmm vreg_tmp(int i) { - assert(i < jcp.alpha * jcp.alpha); - return Xmm(15 - i); - } - Xmm vreg_out(int i) { - assert(i < jcp.alpha * jcp.alpha); - return Xmm(31 - i); - } - - Opmask y_mask = Opmask(1); - Opmask r_mask = Opmask(2); - Opmask x_mask(int id) { - assert(id < 4); - return Opmask(3 + id); - } - - Reg64 reg_ptr_src = r14; - Reg64 reg_ptr_dst = r13; - - Reg64 reg_ptr_v_y_masks = r12; - Reg64 reg_ptr_v_x_masks = r11; - - Reg64 reg_aux_ptr_src = r10; - Reg64 reg_aux_ptr_dst = r9; - - Reg64 reg_ic_block = r8; - - int unsign_val_in_wino_domain; - - Reg64 reg_scratch_src_alpha = rdx; - Xmm xmm_src_alpha = Xmm(0); - Zmm zmm_src_alpha = Zmm(0); - - Reg64 reg_shift = rax; - Xmm xmm_shift = Xmm(1); - Xmm xmm_zero = Xmm(0); - - Reg64 reg_maskx = rbx; - Reg64 reg_masky = rsi; - Reg64 reg_nomask = reg_maskx; -}; - -void jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::generate() { - Label ic_block_label; - Label end_label; - Label mask_label; - Label nomask_label; - - auto load_src = [=](bool mask) { - for (int y = 0; y < jcp.alpha; y++) { - if (mask) - kmovw(y_mask, ptr[reg_ptr_v_y_masks + sizeof(uint16_t) * y]); - for (int x = 0; x < jcp.alpha; x++) { - Zmm zmm_i = zmm_inp(y * jcp.alpha + x); - Xmm vreg_i = vreg_inp(y * jcp.alpha + x); - int inp_offset = sizeof(uint8_t) - * ((-jcp.t_pad + y) * jcp.iw * jcp.ic - + (-jcp.l_pad + x) * jcp.ic); - if (mask) { - kandw(r_mask, y_mask, x_mask(x)); - vmovdqu8(vreg_i | r_mask | T_z, - EVEX_compress_addr(reg_aux_ptr_src, inp_offset)); - } else { - vmovdqu8(vreg_i, - EVEX_compress_addr(reg_aux_ptr_src, inp_offset)); - } - vpmovzxbd(zmm_i, vreg_i); // to int32 - vcvtdq2ps(zmm_i, zmm_i); // to fp32 - 
vmulps(zmm_i, zmm_i, zmm_src_alpha); // *alpha - vcvtps2dq(zmm_i, zmm_i); // to int32 - vpmovusdb(vreg_i, zmm_i); // to u8 - } - } - }; - - preamble(); - -# define READ_PARAM(reg, field) \ - mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)]) - READ_PARAM(reg_ptr_src, src); - READ_PARAM(reg_ptr_dst, wino_src); - READ_PARAM(reg_ptr_v_y_masks, v_y_masks); - READ_PARAM(reg_ptr_v_x_masks, v_x_masks); -# undef READ_PARAM - - mov(reg_maskx, ptr[reg_ptr_v_x_masks]); - mov(reg_masky, ptr[reg_ptr_v_y_masks]); - test(reg_maskx, reg_maskx); - jz(end_label, T_NEAR); // skip kernel if x mask is all 0's - test(reg_masky, reg_masky); - jz(end_label, T_NEAR); // skip kernel if y mask is all 0's - and_(reg_maskx, reg_masky); - mov(reg_nomask, reg_maskx); - not_(reg_nomask); // zero if x and y masks are all 1's - - xor_(reg_shift, reg_shift); - mov(reg_shift.cvt8(), (int8_t)-128); - - mov(reg_aux_ptr_src, reg_ptr_src); - mov(reg_aux_ptr_dst, reg_ptr_dst); - - for (int i = 0; i < jcp.alpha; i++) { - kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(uint16_t) * i]); - } - - mov(reg_scratch_src_alpha, float2int(adj_src_scale)); - - mov(reg_ic_block, jcp.ic / load_block); - L(ic_block_label); - { - vmovq(xmm_src_alpha, reg_scratch_src_alpha); - vbroadcastss(zmm_src_alpha, xmm_src_alpha); - - test(reg_nomask, reg_nomask); - jz(nomask_label, T_NEAR); - load_src(true); - jmp(mask_label, T_NEAR); - L(nomask_label); - load_src(false); - L(mask_label); - - for(int y = 0; y < 4; y++) { - vpsubb(vreg_tmp(y*4+0), vreg_inp(y*4+0), vreg_inp(y*4+2)); - vpaddb(vreg_tmp(y*4+1), vreg_inp(y*4+1), vreg_inp(y*4+2)); - vpsubb(vreg_tmp(y*4+2), vreg_inp(y*4+2), vreg_inp(y*4+1)); - vpsubb(vreg_tmp(y*4+3), vreg_inp(y*4+1), vreg_inp(y*4+3)); - } - for(int x = 0;x < 4; x++) { - vpsubb(vreg_out(x+0*4), vreg_tmp(x+4*0), vreg_tmp(x+4*2)); - vpaddb(vreg_out(x+1*4), vreg_tmp(x+4*1), vreg_tmp(x+4*2)); - vpsubb(vreg_out(x+2*4), vreg_tmp(x+4*2), vreg_tmp(x+4*1)); - vpsubb(vreg_out(x+3*4), vreg_tmp(x+4*1), vreg_tmp(x+4*3)); - } - - vmovd(xmm_shift, reg_shift.cvt32()); - vpxor(xmm_zero, xmm_zero, xmm_zero); - vpshufb(xmm_shift, xmm_shift, xmm_zero); - - for (int i = 0; i < 16; i++) { - int out_offset = sizeof(uint8_t) * (jcp.inp_stride * i); - if (i != unsign_val_in_wino_domain) - vpsubb(vreg_out(i), vreg_out(i), Xmm(1)); - vmovups(EVEX_compress_addr(reg_aux_ptr_dst, out_offset), vreg_out(i)); - } - - add(reg_aux_ptr_src, sizeof(uint8_t) * load_block); - add(reg_aux_ptr_dst, sizeof(uint8_t) * load_block); - } - dec(reg_ic_block); - jnz(ic_block_label, T_NEAR); - - L(end_label); - postamble(); -} - -/// DST TRANSFORMS ///////////////////////////////////////////////////////////// -struct jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t: public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS( - jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t) - - jit_conv_conf_2x3_wino_t jcp; - const primitive_attr_t &attr_; - - struct call_params_t { - const void *wino_dst; - const void *dst; - const void *v_y_masks; - const void *v_x_masks; - - const void *bias; - const void *scales; - }; - void (*ker_)(const call_params_t *); - - jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t( - jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr) - : jcp(ajcp), attr_(attr) { - generate(); - ker_ = reinterpret_cast(const_cast(getCode())); - } - - void generate(); - bool maybe_relu(int position); - - Zmm vreg_inp(int i) { // 16 - assert(i < jcp.alpha * jcp.alpha); - return Zmm(31 - i); - } - Zmm vreg_stg(int id) { // 8 - const int id_reg_stg = jcp.alpha * jcp.alpha + 
id; - assert(id < 8); - return Zmm(31 - id_reg_stg); - } - Zmm vreg_out(int id) { // 4 - const int id_reg_out = jcp.alpha * jcp.alpha + 8 + id; - assert(id < 4); - return Zmm(31 - id_reg_out); - } - Xmm xmm_out(int id) { // 4 - const int id_reg_out = jcp.alpha * jcp.alpha + 8 + id; - assert(id < 4); - return Xmm(31 - id_reg_out); - } - Zmm vreg_tmp(int id) { // 2 - const int id_reg_tmp = jcp.alpha * jcp.alpha + 12 + id; - assert(id < 2); - return Zmm(31 - id_reg_tmp); - } - - Zmm vreg_zero = Zmm(0); - Zmm vreg_bias = Zmm(1); - Zmm vreg_prev_dst = Zmm(2); - Zmm zmm_bias_alpha = Zmm(2); - Xmm xmm_bias_alpha = Xmm(2); - - Opmask y_mask = Opmask(1); - Opmask r_mask = Opmask(2); - Opmask x_mask(int id) { - assert(id < 4); - return Opmask(3 + id); - } - - Reg64 reg_scratch_bias_alpha = r15; - - Reg64 reg_ptr_src = r14; - Reg64 reg_ptr_dst = r13; - - Reg64 reg_ptr_v_y_masks = r12; - Reg64 reg_ptr_v_x_masks = r11; - - Reg64 reg_aux_ptr_src = r10; - Reg64 reg_aux_ptr_dst = r9; - - Reg64 reg_oc_block = r8; - - Reg64 reg_ptr_bias = rbx; - Reg64 reg_ptr_scales = abi_not_param1; - Reg64 reg_ptr_sum_scale = rdx; -}; - -bool jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::maybe_relu(int position) { - using namespace primitive_kind; - const auto &p = attr_.post_ops_; - - if (position == 0) { - /* relu before sum */ - return false - || p.contain(eltwise, 0) - || (jcp.dst_dt == data_type::u8 && !p.contain(sum, 0)); - } else if (position == 1) { - /* relu after sum */ - const int sum_idx = p.contain(sum, 0) - ? 0 : (p.contain(sum, 1) ? 1 : -1); - if (sum_idx == -1) - return false; - - return false - || p.contain(eltwise, sum_idx + 1) - || jcp.dst_dt == data_type::u8; - } - - return false; -} - -void jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::generate() { - Label oc_block_label; - - auto loop_body = [=]() { - const auto &p = attr_.post_ops_; - const int sum_idx = p.find(primitive_kind::sum); - const float *p_sum_scale = (sum_idx != -1) - ? 
&p.entry_[sum_idx].sum.scale - : nullptr; - if (p_sum_scale && *p_sum_scale != 1.f) - mov(reg_ptr_sum_scale, (size_t)p_sum_scale); - - for(int i = 0; i < 16; i++) { - int internal_offset = sizeof(int32_t) * jcp.out_stride * i; - vmovups(vreg_inp(i), - EVEX_compress_addr(reg_aux_ptr_src, internal_offset)); - } - for(int y = 0; y < jcp.alpha; y++) { - vpaddd(vreg_tmp(0), vreg_inp(y*4 + 0), vreg_inp(y*4 + 1)); - vpaddd(vreg_stg(y*2), vreg_tmp(0), vreg_inp(y*4 + 2)); - - vpsubd(vreg_tmp(1), vreg_inp(y*4 + 1), vreg_inp(y*4 + 2)); - vpsubd(vreg_stg(y*2+1), vreg_tmp(1), vreg_inp(y*4 + 3)); - } - for(int x = 0; x < jcp.m; x++) { - vpaddd(vreg_tmp(0), vreg_stg(x), vreg_stg(x+2*1)); - vpaddd(vreg_out(x), vreg_tmp(0), vreg_stg(x+2*2)); - - vpsubd(vreg_tmp(1), vreg_stg(x+2*1), vreg_stg(x+2*2)); - vpsubd(vreg_out(x+2), vreg_tmp(1), vreg_stg(x+2*3)); - } - - - if (jcp.with_bias) { - vmovq(xmm_bias_alpha, reg_scratch_bias_alpha); - vbroadcastss(zmm_bias_alpha, xmm_bias_alpha); - - auto bias_addr = ptr [ reg_ptr_bias ]; - switch (jcp.bia_dt) { - case data_type::f32: - case data_type::s32: vmovups(vreg_bias, bias_addr); break; - case data_type::s8: vpmovsxbd(vreg_bias, bias_addr); break; - case data_type::u8: vpmovzxbd(vreg_bias, bias_addr); break; - default: assert(!"unsupported dst data type"); - } - if (jcp.bia_dt != data_type::f32) - vcvtdq2ps(vreg_bias, vreg_bias); - vmulps(vreg_bias, vreg_bias, zmm_bias_alpha); // *alpha - } - for(int y = 0; y < jcp.m; y++) { - kmovw(y_mask, ptr[ reg_ptr_v_y_masks + sizeof(uint16_t) * y ]); - for(int x = 0; x < jcp.m; x++) { - kandw(r_mask, y_mask, x_mask(x)); - - int i = y * jcp.m + x; - int offset = jcp.typesize_out * - (y * jcp.ow * jcp.oc + x * jcp.oc); - Address addr = EVEX_compress_addr(reg_aux_ptr_dst, offset); - - Zmm zmm = vreg_out(i); - Xmm xmm = xmm_out(i); - vcvtdq2ps(zmm, zmm); - if (jcp.with_bias) - vaddps(zmm, zmm, vreg_bias); - vmulps(zmm, zmm, ptr [reg_ptr_scales]); - if (maybe_relu(0)) - vmaxps(zmm, vreg_zero, zmm); - if (p_sum_scale) { // post_op: sum - vpxord(vreg_prev_dst, vreg_prev_dst, vreg_prev_dst); - switch (jcp.dst_dt) { - case data_type::f32: - case data_type::s32: - vmovups(vreg_prev_dst | r_mask, addr); break; - case data_type::s8: - vpmovsxbd(vreg_prev_dst | r_mask, addr); break; - case data_type::u8: - vpmovzxbd(vreg_prev_dst | r_mask, addr); break; - default: assert(!"unknown dst_dt"); - } - if (jcp.dst_dt != data_type::f32) - vcvtdq2ps(vreg_prev_dst, vreg_prev_dst); - if (*p_sum_scale == 1.f) - vaddps(zmm, vreg_prev_dst); - else - vfmadd231ps(zmm, vreg_prev_dst, - zword_b[reg_ptr_sum_scale]); - } - if (maybe_relu(1)) - vmaxps(zmm, vreg_zero, zmm); - if (jcp.dst_dt != data_type::f32) - vcvtps2dq(zmm, zmm); - switch (jcp.dst_dt) { - case data_type::f32: - case data_type::s32: - vmovups(addr, zmm | r_mask); break; - case data_type::s8: - vpmovsdb(xmm, zmm); vmovups(addr, xmm | r_mask); break; - case data_type::u8: - vpmovusdb(xmm, zmm); vmovups(addr, xmm | r_mask); break; - default: assert(!"unknown dst_dt"); - } - } - } - }; - - preamble(); - -# define READ_PARAM(reg, field) \ - mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)]) - READ_PARAM(reg_ptr_src, wino_dst); - READ_PARAM(reg_ptr_dst, dst); - READ_PARAM(reg_ptr_v_y_masks, v_y_masks); - READ_PARAM(reg_ptr_v_x_masks, v_x_masks); - READ_PARAM(reg_ptr_bias, bias); - READ_PARAM(reg_ptr_scales, scales); -# undef READ_PARAM - - if (jcp.with_bias) - mov(reg_scratch_bias_alpha, float2int(adj_src_scale * adj_wei_scale)); - - mov(reg_aux_ptr_src, reg_ptr_src); - mov(reg_aux_ptr_dst, 
reg_ptr_dst); - - vpxord(vreg_zero, vreg_zero, vreg_zero); - - for (int i = 0; i < jcp.m; i++) - kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(uint16_t) * i]); - - int oc_blocks = jcp.oc / load_block; - mov(reg_oc_block, oc_blocks); - L(oc_block_label); { - loop_body(); - add(reg_aux_ptr_src, sizeof(int32_t) * load_block); - add(reg_aux_ptr_dst, jcp.typesize_out * load_block); - - add(reg_ptr_scales, jcp.is_oc_scale * sizeof(float) * load_block); - add(reg_ptr_bias, sizeof(jcp.typesize_bia) * load_block); - } - dec(reg_oc_block); - jnz(oc_block_label, T_NEAR); - - postamble(); - -} - -/// GEMM kernel //////////////////////////////////////////////////////////////// -struct jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t: public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t) - jit_conv_conf_2x3_wino_t jcp; - const primitive_attr_t &attr_; - - struct call_params_t { - const void *src; - const void *dst; - const void *wei; - const void *dst_b; - }; - void (*ker_)(const call_params_t *); - - void generate(); - static bool post_ops_ok(jit_conv_conf_2x3_wino_t &jcp, - const primitive_attr_t &attr); - - jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t( - jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr) - : jcp(ajcp), attr_(attr) - { - generate(); - ker_ = reinterpret_cast(const_cast(getCode())); - } - - static status_t init_conf( - jit_conv_conf_2x3_wino_t &jcp, const convolution_desc_t &cd, - memory_desc_t &src_md, memory_desc_t &weights_md, - memory_desc_t &dst_md, memory_desc_t &bias_md, - const primitive_attr_t &attr); - - Zmm vreg_out(int n, int m) { - const int id_reg_out = n * jcp.m_block + m; - assert(id_reg_out < jcp.n2_block * jcp.m_block); - return Zmm(31 - id_reg_out); - } - Zmm vreg_wei(int i) { - assert(31 - jcp.n2_block * jcp.m_block - i - > (jcp.ver == ver_vnni ? 
0 : 2)); - return Zmm(31 - jcp.n2_block * jcp.m_block - i); - } - - Zmm vreg_src = Zmm(0); - Zmm vreg_one = Zmm(1); - Zmm vreg_tmp = Zmm(2); - - Reg64 reg_ptr_src = r15; - - Reg64 reg_aux_dst_b = r13; - Reg64 reg_aux_dst = r12; - Reg64 reg_aux_dst2 = r11; - Reg64 reg_aux_wei = r10; - Reg64 reg_aux_wei2 = r9; - Reg64 reg_aux_src = r8; - Reg64 reg_aux_src2 = rax; - Reg64 reg_mb = rbx; - Reg64 reg_nnb = abi_not_param1; - Reg64 reg_scratch = rdx; - Reg64 reg_K = rsi; -}; - -bool jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::post_ops_ok( - jit_conv_conf_2x3_wino_t &jcp, const primitive_attr_t &attr) { - using namespace primitive_kind; - const auto &p = attr.post_ops_; - - auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); }; - - switch (p.len_) { - case 0: return true; - case 1: return is_relu(0) || p.contain(sum, 0); - case 2: return (p.contain(sum, 0) && is_relu(1)) || - (p.contain(sum, 1) && is_relu(0)); - case 3: return is_relu(0) && p.contain(sum, 1) && is_relu(2); - default: return false; - } - - return false; -} - -void jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::generate() { - Label nnb_loop_label, K_loop_label, mb_loop_label; - - auto compute = [=](Zmm vreg_acc, Zmm vreg_wei, Zmm vreg_src) { - if (jcp.ver == ver_vnni) { - vpdpbusd(vreg_acc, vreg_src, vreg_wei); - } else { - vpmaddubsw(vreg_tmp, vreg_src, vreg_wei); - vpmaddwd(vreg_tmp, vreg_tmp, vreg_one); - vpaddd(vreg_acc, vreg_acc, vreg_tmp); - } - }; - - preamble(); -# define READ_PARAM(reg, field) \ - mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)]) - READ_PARAM(reg_ptr_src, src); - READ_PARAM(reg_aux_dst, dst); - READ_PARAM(reg_aux_wei, wei); - READ_PARAM(reg_aux_dst_b, dst_b); -# undef READ_PARAM - - if (jcp.ver != ver_vnni) { - xor_(reg_scratch, reg_scratch); - Reg16 _t = reg_scratch.cvt16(); - mov(_t, 0x1); - vpbroadcastw(vreg_one, _t); - } - - if (!jcp.small_mb) { - mov(reg_nnb, jcp.n_chunks); - L(nnb_loop_label); - } - mov(reg_aux_dst2, reg_aux_dst); - mov(reg_aux_src, reg_ptr_src); - mov(reg_mb, jcp.M / jcp.m_block); - L(mb_loop_label); - { - for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) { - for (int m = 0; m < jcp.m_block; m++) { - int offset = jcp.typesize_acc * nb2 * jcp.n_block; - vmovups(vreg_out(nb2, m), - EVEX_compress_addr(reg_aux_dst_b, offset)); - } - } - mov(reg_aux_src2, reg_aux_src); - mov(reg_aux_wei2, reg_aux_wei); - mov(reg_K, jcp.k_chunks); - L(K_loop_label); - { - for (int k = 0; k < jcp.k2_block; k += 4) { - for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) { - int wei_offset - = jcp.typesize_in * (nb2 * jcp.n_block * jcp.K); - vmovups(vreg_wei(nb2), - EVEX_compress_addr(reg_aux_wei2, wei_offset)); - } - for (int m = 0; m < jcp.m_block; m++) { - int inp_offset = jcp.typesize_in * m * jcp.K; - vpbroadcastd(vreg_src, - EVEX_compress_addr(reg_aux_src2, inp_offset)); - for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) - compute(vreg_out(nb2, m), vreg_wei(nb2), vreg_src); - } - add(reg_aux_src2, jcp.typesize_in * 4); - add(reg_aux_wei2, jcp.typesize_in * 4 * jcp.n_block); - } - } - dec(reg_K); - jnz(K_loop_label, T_NEAR); - - for (int m = 0; m < jcp.m_block; m++) { - for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) { - int offset = jcp.typesize_acc * (m * jcp.N + nb2 * jcp.n_block); - vmovups(EVEX_compress_addr(reg_aux_dst2, offset), - vreg_out(nb2, m)); - } - } - add(reg_aux_src, jcp.typesize_in * jcp.m_block * jcp.K); - add(reg_aux_dst2, jcp.typesize_acc * jcp.m_block * jcp.N); - } - dec(reg_mb); - jnz(mb_loop_label, T_NEAR); - - if (!jcp.small_mb) { - add(reg_aux_dst, jcp.typesize_acc * jcp.n2_block 
* jcp.n_block); - add(reg_aux_dst_b, jcp.typesize_acc * jcp.n2_block * jcp.n_block); - add(reg_aux_wei, jcp.typesize_in * jcp.n2_block * jcp.n_block * jcp.K); - - dec(reg_nnb); - jnz(nnb_loop_label, T_NEAR); - } - - postamble(); -} -namespace { -bool is_winograd_faster_than_direct(const jit_conv_conf_2x3_wino_t &jcp) { - if (jcp.ver == ver_vnni) { - return (jcp.mb <= mkldnn_get_max_threads() - && (jcp.mb > 4 - && jcp.ic > 64 - && !(jcp.oc > 128 && jcp.ih < 14))) - || jcp.mb > mkldnn_get_max_threads(); - } - return true; -} -} - -status_t jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t -::init_conf(jit_conv_conf_2x3_wino_t &jcp, - const convolution_desc_t &cd, memory_desc_t &src_md, - memory_desc_t &wei_md, memory_desc_t &dst_md, - memory_desc_t &bias_md, const primitive_attr_t &attr) { - const memory_desc_wrapper src_d(&src_md); - const memory_desc_wrapper wei_d(&wei_md); - const memory_desc_wrapper dst_d(&dst_md); - const memory_desc_wrapper bias_d(&bias_md); - - const bool with_groups = wei_d.ndims() == src_d.ndims() + 1; - - jcp.nthr = mkldnn_get_max_threads(); - - jcp.ngroups = with_groups ? wei_d.dims()[0] : 1; - jcp.mb = src_d.dims()[0]; - jcp.oc = dst_d.dims()[1] / jcp.ngroups; - jcp.ic = src_d.dims()[1] / jcp.ngroups; - jcp.ih = src_d.dims()[2]; - jcp.iw = src_d.dims()[3]; - jcp.oh = dst_d.dims()[2]; - jcp.ow = dst_d.dims()[3]; - jcp.kh = wei_d.dims()[with_groups + 2]; - jcp.kw = wei_d.dims()[with_groups + 3]; - jcp.t_pad = cd.padding[0][0]; - jcp.b_pad = cd.padding[1][0]; - jcp.l_pad = cd.padding[0][1]; - jcp.r_pad = cd.padding[1][1]; - jcp.stride_h = cd.strides[0]; - jcp.stride_w = cd.strides[1]; - jcp.dilate_h = cd.dilates[0]; - jcp.dilate_w = cd.dilates[1]; - - jcp.ver = ver_avx512_core; - if (!(mayiuse(avx512_core) && - src_d.data_type() == data_type::u8 - && wei_d.data_type() == data_type::s8 - && one_of(dst_d.data_type(), data_type::f32, data_type::s32, - data_type::s8, data_type::u8))) - return status::unimplemented; - if (mayiuse(avx512_core_vnni)) - jcp.ver = ver_vnni; - - if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto, - is_winograd_faster_than_direct(jcp))) - return status::unimplemented; - - // block sizes needed for GEMM kernel - jcp.ic_block = 4; - jcp.oc_block = 16; - - bool ok = true - && jcp.ngroups == 1 - && jcp.oc % load_block == 0 && jcp.ic % load_block == 0 - && jcp.oc % jcp.oc_block == 0 && jcp.ic % jcp.ic_block == 0 - && everyone_is(3, jcp.kh, jcp.kw) - && everyone_is(1, jcp.stride_h, jcp.stride_w) - && everyone_is(0, jcp.dilate_h, jcp.dilate_w) - && jcp.t_pad == jcp.b_pad && jcp.l_pad == jcp.r_pad - && one_of(jcp.t_pad, 0, 1) - && one_of(jcp.l_pad, 0, 1); - if (!ok) return status::unimplemented; - - jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; - - if (!post_ops_ok(jcp, attr)) - return status::unimplemented; - - jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef; - jcp.dst_dt = cd.dst_desc.data_type; - - jcp.typesize_in = types::data_type_size(src_d.data_type()); - jcp.typesize_out = types::data_type_size(dst_d.data_type()); - jcp.typesize_acc = sizeof(int32_t); - jcp.typesize_bia = jcp.with_bias - ? types::data_type_size(bias_d.data_type()) - : 0; - - jcp.nb_oc = jcp.oc / jcp.oc_block; - jcp.nb_ic = jcp.ic / jcp.ic_block; - - jcp.m = 2; - jcp.r = 3; - jcp.alpha = jcp.m + jcp.r - 1; - - int aa = jcp.alpha * jcp.alpha; - int L1_cap = get_cache_size(1, true); - int L2_cap = get_cache_size(2, true); - // need 1 extra reg for bcast, and 2 tmp regs for non-vnni - int free_regs = jcp.ver == ver_vnni ? 
31 : 29; - - auto get_thr_eff = [&](int small_mb, int ix, int iy, int n2_b) { - float thr_eff; - float Z = (float)jcp.ic + jcp.oc; - float Y = (float)jcp.ic * jcp.oc; - if (small_mb == 0) { // outer par - int nblocks = jcp.mb * div_up(jcp.oh, iy) * div_up(jcp.ow, ix); - thr_eff = (float)nblocks / rnd_up(nblocks, jcp.nthr); - } else { // inner par - int tranw = iy * ix / jcp.alpha; - int gemmw = aa * (jcp.nb_oc / n2_b); - int tranw_r = rnd_up(tranw, jcp.nthr); - int gemmw_r = rnd_up(gemmw, jcp.nthr); - thr_eff = (Z * tranw / tranw_r + Y * gemmw / gemmw_r) / (Z + Y); - } - return thr_eff; - }; - - auto get_mem_eff = [&](int small_mb, int ix, int iy, int n2_b) { - float mem_eff, req_mem; - int M = ix * iy / jcp.alpha; - if (small_mb == 0) { // outer parallelization strategy - // memory for wino transforms (other memory has poor reuse) - req_mem = (float)aa * M * (jcp.ic + jcp.typesize_acc * jcp.oc); - mem_eff = req_mem < L1_cap ? 1.f : req_mem < L2_cap ? 0.5f : 0.f; - } else { // inner parallelization strategy - // memory used during gemm - int N = jcp.oc_block * n2_b; - req_mem = (float)jcp.ic * (M + N) + jcp.typesize_acc * M * N; - mem_eff = nstl::min(1.f, L2_cap / req_mem); - // memory used during wino transforms - int M_per_thr = div_up(M, jcp.nthr); - req_mem = (float)aa * M_per_thr - * (jcp.ic + jcp.typesize_acc * jcp.oc); - if (req_mem > L2_cap) - mem_eff = 0.1f; - } - return mem_eff; - }; - - auto get_tot_eff = [&](int small_mb, float thr_eff, float work_eff, - float mem_eff, float reg_eff) { - // these coefficients are chosen empirically - float mem_fac = 0.1f, reg_fac = 0.2f; - // normalized overhead relative to memory and register components - float tot_eff = 1.f + mem_fac * mem_eff + reg_fac * reg_eff; - // thread and work components affect all others - tot_eff *= thr_eff * work_eff; - return tot_eff; - }; - - auto find_m_n2_blocks = [&](bool small_mb, int ix, int iy, float work_eff, - int &m_block, int &n2_block, float &tot_eff) { - int M = (ix * iy) / jcp.alpha; - int max_m_block = nstl::min(M, free_regs); - int max_n2_block = nstl::min(jcp.nb_oc, free_regs); - tot_eff = 0.f; - for (int im = max_m_block; im > 0; im--) { - if (M % im) - continue; - for (int in2 = max_n2_block; in2 > 0; in2--) { - int used_regs = (im + 1) * in2; - float mem_eff = get_mem_eff(small_mb, ix, iy, in2); - float reg_eff = (float)(im * in2) / (im + in2); - float thr_eff = get_thr_eff(small_mb, ix, iy, in2); - float cur_tot_eff = get_tot_eff( - small_mb, thr_eff, work_eff, mem_eff, reg_eff); - if (jcp.nb_oc % in2 || used_regs > free_regs - || cur_tot_eff <= tot_eff) - continue; - tot_eff = cur_tot_eff; - m_block = im; - n2_block = in2; - } - } - }; - - /* Selecting xb and yb blocking */ - int min_yb = jcp.m; - int min_xb = jcp.m; - int max_yb = nstl::max(min_yb, rnd_up(jcp.oh, 2)); - int max_xb = nstl::max(min_xb, rnd_up(jcp.ow, 2)); - float best_eff = 0.f; - for (int ix = min_xb; ix <= max_xb; ix += 2) { - assert(rnd_up(jcp.ow, ix) >= jcp.iw - 2); - for (int iy = max_yb; iy >= min_yb; iy -= 2) { - assert(rnd_up(jcp.oh, iy) >= jcp.ih - 2); - - int m_b[2]; - int n2_b[2]; - bool small_mb; - float inner_eff, outer_eff, work_eff; - - int tiled_area = rnd_up(jcp.oh, iy) * rnd_up(jcp.ow, ix); - work_eff = (float)jcp.oh * jcp.ow / tiled_area; - if (best_eff > 0.f && work_eff < 4.f / 9.f) - continue; // no gain from Winograd transformation - - /* outer parallelization */ - find_m_n2_blocks(0, ix, iy, work_eff, m_b[0], n2_b[0], outer_eff); - - /* inner parallelization */ - find_m_n2_blocks(1, ix, iy, work_eff, 
m_b[1], n2_b[1], inner_eff); - - small_mb = inner_eff > outer_eff; - float eff = small_mb ? inner_eff : outer_eff; - if (eff > best_eff) { - best_eff = eff; - jcp.yb = iy; - jcp.xb = ix; - jcp.m_block = m_b[small_mb]; - jcp.n2_block = n2_b[small_mb]; - jcp.small_mb = small_mb; - } - } - } - - assert((jcp.m_block + 1) * jcp.n2_block <= free_regs); - assert(jcp.xb % 2 == 0 && jcp.yb % 2 == 0); - - jcp.mb_block = 1; - if (jcp.small_mb) { - // For small mb harness, set mb_block as large as possible subject to - // the constraint that winograd activations fit into available L3 cache - int L3_cap = get_cache_size(3, true); - int M = jcp.xb * jcp.yb / 4; - int wino_src_size = 16 * M * jcp.ic * jcp.typesize_in; - int wino_dst_size = 16 * M * jcp.oc * jcp.typesize_acc; - int max_mb_block = nstl::min( - jcp.mb, jcp.nthr * L3_cap / (wino_src_size + wino_dst_size)); - for (int i = max_mb_block; i > 1; i--) { - if (jcp.mb % i == 0) { - jcp.mb_block = i; - break; - } - } - } - jcp.nb_mb = jcp.mb / jcp.mb_block; - - jcp.M = jcp.mb_block * jcp.xb * jcp.yb / 4; - jcp.N = jcp.oc; - jcp.K = jcp.ic; - - jcp.inp_stride = jcp.M * jcp.ic; - jcp.out_stride = jcp.M * jcp.oc; - jcp.wei_stride = jcp.ic * jcp.oc; - jcp.bia_stride = jcp.oc; - - jcp.n_block = jcp.oc_block; - jcp.k_block = jcp.ic_block; - - jcp.n_chunks = (jcp.N / jcp.n_block) / jcp.n2_block; - - // We need jcp.k2_block to be a multiple of jcp.k_block = jcp.ic_block = 4 - // and jcp.K = jcp.ic to be a multiple of jcp.k2_block. Since jcp.ic is - // a multiple of load_block = 16, we just use that for now. - jcp.k2_block = load_block; - jcp.k_chunks = jcp.K / jcp.k2_block; - - const auto &oscales = attr.output_scales_; - jcp.is_oc_scale = oscales.mask_ == 1 << 1; - assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0)); - - /* re-create weights primitive descriptor - and set weights wino_blocking */ - memory_desc_t expect_wei_md = wei_md; - - expect_wei_md.format_kind = format_kind::wino; - expect_wei_md.data_type = data_type::s8; - mkldnn_wino_desc_t &wd = expect_wei_md.format_desc.wino_desc; - wd.wino_format = mkldnn_wino_wei_aaOIoi; - wd.r = jcp.r; - wd.alpha = jcp.alpha; - wd.ic = jcp.ic; - wd.oc = jcp.oc; - wd.ic_block = jcp.ic_block; - wd.oc_block = jcp.oc_block; - wd.oc2_block = jcp.n2_block; - wd.ic2_block = 1; - wd.adj_scale = adj_wei_scale; - - size_t max_size = types::data_type_size(data_type::s8) * - jcp.alpha * jcp.alpha * jcp.ic * jcp.oc; - max_size += types::data_type_size(data_type::s32) * - jcp.alpha * jcp.alpha * jcp.oc; - wd.size = max_size; - - if (wei_md.format_kind == format_kind::any) - wei_md = expect_wei_md; - if (wei_md != expect_wei_md) - return status::unimplemented; - - const int tilesize = jcp.alpha * jcp.alpha; - const int numtiles = jcp.M; - const int alltiles = numtiles * tilesize; - - jcp.size_wino_src - = utils::rnd_up(jcp.typesize_in * alltiles * jcp.ic, PAGE_4K) - / jcp.typesize_in; - jcp.size_wino_wei = tilesize * jcp.oc * jcp.ic; - jcp.size_wino_dst = alltiles * jcp.oc; - - return status::success; -} -//////////////////////////////////////////////////////////////////////////////// - -template -status_t jit_avx512_core_u8s8s32x_wino_convolution_fwd_t:: - pd_t::jit_conf() { - return jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::init_conf( - jcp_, *this->desc(), this->src_md_, this->weights_md_, - this->dst_md_,this->bias_md_, *this->attr()); -} - -template -void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t::pd_t:: -init_scratchpad() { - auto scratchpad = this->scratchpad_registry().registrar(); - - int 
nthr_multiplier = jcp_.small_mb ? 1 : jcp_.nthr; - scratchpad.book(key_wino_V, - sizeof(src_data_t) * jcp_.size_wino_src * nthr_multiplier, PAGE_4K); - scratchpad.book(key_wino_M, - sizeof(acc_data_t) * jcp_.size_wino_dst * nthr_multiplier, PAGE_4K); - - dim_t scale_count = attr()->output_scales_.count_; - scratchpad.book(key_conv_adjusted_scales, - sizeof(float) * nstl::max(scale_count, 16)); -} - -template -jit_avx512_core_u8s8s32x_wino_convolution_fwd_t:: - jit_avx512_core_u8s8s32x_wino_convolution_fwd_t(const pd_t *apd) - : cpu_primitive_t(apd) -{ - kernel_ = new jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t( - pd()->jcp_, *pd()->attr()); - src_trans_ = new jit_avx512_core_u8s8s32x_wino_conv_src_trans_t( - pd()->jcp_, *pd()->attr()); - dst_trans_ = new jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t( - pd()->jcp_, *pd()->attr()); -} - -template -jit_avx512_core_u8s8s32x_wino_convolution_fwd_t:: - ~jit_avx512_core_u8s8s32x_wino_convolution_fwd_t() { - delete kernel_; - delete src_trans_; - delete dst_trans_; -} - -template -const float *jit_avx512_core_u8s8s32x_wino_convolution_fwd_t:: -adjust_oscales(const memory_tracking::grantor_t &scratchpad) const { - const float *oscales = pd()->attr()->output_scales_.scales_; - auto loc_scales = scratchpad.template get(key_conv_adjusted_scales); - size_t count = pd()->attr()->output_scales_.count_; - float factor = 1.f / (adj_src_scale * adj_wei_scale); - if (count == 1) - utils::array_set(loc_scales, oscales[0] * factor, 16); - else - for (size_t c = 0; c < count; c++) loc_scales[c] = oscales[c] * factor; - return loc_scales; -} - -template -void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t:: -execute_forward(const exec_ctx_t &ctx) const { - auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); - auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); - auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); - auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); - - const auto &jcp = kernel_->jcp; - if (jcp.small_mb) - execute_forward_small_mb(src, weights, bias, dst, this->scratchpad(ctx)); - else - execute_forward_mbN(src, weights, bias, dst, this->scratchpad(ctx)); -} - -template -void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t:: -execute_forward_mbN(const src_data_t *src, const wei_data_t *wei, - const char *bia, dst_data_t *dst, - const memory_tracking::grantor_t &scratchpad) const { - const auto &jcp = kernel_->jcp; - const float *oscales = adjust_oscales(scratchpad); - - auto dst_bias = (const acc_data_t *)(wei + jcp.size_wino_wei); - auto wino_src_base = scratchpad.template get(key_wino_V); - auto wino_dst_base = scratchpad.template get(key_wino_M); - - parallel_nd(jcp.mb, div_up(jcp.oh, jcp.yb), div_up(jcp.ow, jcp.xb), - [&](int mb, int tile_y_b, int tile_x_b) { - - int tile_y = tile_y_b * jcp.yb; - int tile_x = tile_x_b * jcp.xb; - - int ithr = mkldnn_get_thread_num(); - auto wino_src = wino_src_base + jcp.size_wino_src * ithr; - auto wino_dst = wino_dst_base + jcp.size_wino_dst * ithr; - - auto src_trans_p = - jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::call_params_t(); - auto dst_trans_p = - jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::call_params_t(); - auto gemm_p = - jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::call_params_t(); - - /* transformation of input tensor to winograd domain */ - for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) { - for (int x_in_block = 0; x_in_block < jcp.xb; x_in_block += 2) { - uint16_t v_y_masks[4], v_x_masks[4]; - - int y = y_in_block + tile_y; - int x = 
x_in_block + tile_x; - int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2); - - int v_ys = nstl::max(0, jcp.t_pad - y); - int v_ye = nstl::min(jcp.alpha, - nstl::max(0, jcp.ih + jcp.t_pad - y)); - - int v_xs = nstl::max(0, jcp.l_pad - x); - int v_xe = nstl::min(jcp.alpha, - nstl::max(0, jcp.iw + jcp.l_pad - x)); - -#pragma unroll(4) - for (int i = 0; i < jcp.alpha; i++) { - v_y_masks[i] = uint16_t(i < v_ys || i >= v_ye ? 0 : 0xffff); - v_x_masks[i] = uint16_t(i < v_xs || i >= v_xe ? 0 : 0xffff); - } - auto local_s = src - + mb * jcp.ih * jcp.iw * jcp.ic - + y * jcp.iw * jcp.ic + x * jcp.ic; - auto local_w = wino_src + m * jcp.ic; - - src_trans_p.src = local_s; - src_trans_p.wino_src = local_w; - src_trans_p.v_y_masks = v_y_masks; - src_trans_p.v_x_masks = v_x_masks; - - src_trans_->ker_(&src_trans_p); - } - } - /* gemms */ - for (int tile_ij = 0; tile_ij < 16; tile_ij++) { - // start threads at different GEMMs to help bring weights into LLC - int offset = (tile_ij + ithr) % 16; - gemm_p.src = wino_src + jcp.inp_stride * offset; - gemm_p.dst = wino_dst + jcp.out_stride * offset; - gemm_p.wei = wei + jcp.wei_stride * offset; - gemm_p.dst_b = dst_bias + jcp.bia_stride * offset; - - kernel_->ker_(&gemm_p); - } - - /* transformation from winograd domain to output tensor */ - for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) { - for (int x_in_block = 0; x_in_block < jcp.xb; x_in_block += 2) { - uint16_t v_y_masks[2], v_x_masks[2]; - - int y = y_in_block + tile_y; - int x = x_in_block + tile_x; - int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2); - -#pragma unroll(2) - for (int i = 0; i < jcp.m; i++) { - v_x_masks[i] = uint16_t(x + i < jcp.ow ? 0xffff : 0); - v_y_masks[i] = uint16_t(y + i < jcp.oh ? 0xffff : 0); - } - auto local_d = dst - + mb * jcp.oh * jcp.ow * jcp.oc - + y * jcp.ow * jcp.oc + x * jcp.oc; - auto local_w = wino_dst + m * jcp.oc; - - auto scales = oscales; - dst_trans_p.dst = local_d; - dst_trans_p.wino_dst = local_w; - dst_trans_p.v_y_masks = v_y_masks; - dst_trans_p.v_x_masks = v_x_masks; - - dst_trans_p.scales = scales; - dst_trans_p.bias = bia; - - dst_trans_->ker_(&dst_trans_p); - } - } - }); -} - -template -void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t:: -execute_forward_small_mb(const src_data_t *src, const wei_data_t *wei, - const char *bia, dst_data_t *dst, - const memory_tracking::grantor_t &scratchpad) const { - const auto &jcp = kernel_->jcp; - const float *oscales = adjust_oscales(scratchpad); - - auto dst_bias = (const acc_data_t *)(wei + jcp.size_wino_wei); - auto wino_src = scratchpad.template get(key_wino_V); - auto wino_dst = scratchpad.template get(key_wino_M); - - for (int mbb = 0; mbb < jcp.nb_mb; mbb++) { - for (int tile_y = 0; tile_y < jcp.oh; tile_y += jcp.yb) { - for (int tile_x = 0; tile_x < jcp.ow; tile_x += jcp.xb) { - /* transformation of input tensor to winograd domain */ - parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2), jcp.mb_block, - [&](int y_in_block_b, int x_in_block_b, int mb) { - int y_in_block = y_in_block_b * 2; - int x_in_block = x_in_block_b * 2; - - auto src_trans_p = - jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::call_params_t(); - - uint16_t v_y_masks[4], v_x_masks[4]; - - int y = y_in_block + tile_y; - int x = x_in_block + tile_x; - int m = (mb * (jcp.yb / 2) + (y_in_block / 2)) * (jcp.xb / 2) - + (x_in_block / 2); - - int v_ys = nstl::max(0, jcp.t_pad - y); - int v_ye = nstl::min( - jcp.alpha, nstl::max(0, jcp.ih + jcp.t_pad - y)); - - int v_xs = nstl::max(0, jcp.l_pad - x); - int v_xe = 
nstl::min( - jcp.alpha, nstl::max(0, jcp.iw + jcp.l_pad - x)); - -#pragma unroll(4) - for (int i = 0; i < jcp.alpha; i++) { - v_y_masks[i] = uint16_t(i < v_ys || i >= v_ye ? 0 : 0xffff); - v_x_masks[i] = uint16_t(i < v_xs || i >= v_xe ? 0 : 0xffff); - } - auto local_s = src - + (mbb * jcp.mb_block + mb) * jcp.ih * jcp.iw * jcp.ic - + y * jcp.iw * jcp.ic + x * jcp.ic; - auto local_w = wino_src + m * jcp.ic; - - src_trans_p.src = local_s; - src_trans_p.wino_src = local_w; - src_trans_p.v_y_masks = v_y_masks; - src_trans_p.v_x_masks = v_x_masks; - - src_trans_->ker_(&src_trans_p); - }); - - /* gemms */ - parallel_nd(16, jcp.n_chunks, [&](int tile_ij, int nnb) { - auto gemm_p = jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t:: - call_params_t(); - - gemm_p.src = wino_src + jcp.inp_stride * tile_ij; - gemm_p.dst = wino_dst + jcp.out_stride * tile_ij - + nnb * jcp.n2_block * jcp.n_block; - gemm_p.wei = wei + jcp.wei_stride * tile_ij - + nnb * jcp.n2_block * jcp.n_block * jcp.K; - gemm_p.dst_b = dst_bias + jcp.bia_stride * tile_ij - + nnb * jcp.n2_block * jcp.n_block; - - kernel_->ker_(&gemm_p); - }); - - /* transformation from winograd domain to output tensor */ - parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2), jcp.mb_block, - [&](int y_in_block_b, int x_in_block_b, int mb) { - int y_in_block = y_in_block_b * 2; - int x_in_block = x_in_block_b * 2; - - auto dst_trans_p = - jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::call_params_t(); - - uint16_t v_y_masks[2], v_x_masks[2]; - - int y = y_in_block + tile_y; - int x = x_in_block + tile_x; - int m = (mb * (jcp.yb / 2) + (y_in_block / 2)) * (jcp.xb / 2) - + (x_in_block / 2); - -#pragma unroll(2) - for (int i = 0; i < jcp.m; i++) { - v_x_masks[i] = uint16_t(x + i < jcp.ow ? 0xffff : 0); - v_y_masks[i] = uint16_t(y + i < jcp.oh ? 0xffff : 0); - } - auto local_d = dst - + (mbb * jcp.mb_block + mb) * jcp.oh * jcp.ow * jcp.oc - + y * jcp.ow * jcp.oc + x * jcp.oc; - auto local_w = wino_dst + m * jcp.oc; - - auto scales = oscales; - dst_trans_p.dst = local_d; - dst_trans_p.wino_dst = local_w; - dst_trans_p.v_y_masks = v_y_masks; - dst_trans_p.v_x_masks = v_x_masks; - - dst_trans_p.scales = scales; - dst_trans_p.bias = bia; - - dst_trans_->ker_(&dst_trans_p); - }); - }}} -} - -template struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t; -template struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t; -template struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t; -template struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t; - -} // namespace cpu -} // namespace impl -} // namespace mkldnn diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.hpp deleted file mode 100644 index 9e6e57b05..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.hpp +++ /dev/null @@ -1,128 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
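// ---------------------------------------------------------------------------
// Editor's note (illustrative sketch, not part of the deleted sources): the
// F(2x2, 3x3) Winograd driver above walks 2x2 output tiles and guards each
// 4x4 input patch and 2x2 store with per-row/column masks derived from the
// padding and tensor bounds. The scalar helper below, with hypothetical
// names, mirrors that mask computation; the real kernels pass the same masks
// to the JIT-ed source/destination transforms.
// ---------------------------------------------------------------------------
#include <algorithm>
#include <cstdint>

struct wino_tile_masks {
    uint16_t src_y[4], src_x[4]; // one mask per row/column of the 4x4 patch
    uint16_t dst_y[2], dst_x[2]; // one mask per row/column of the 2x2 tile
};

// x, y: top-left output coordinate of the tile; alpha = m + r - 1 = 4, m = 2.
inline wino_tile_masks make_tile_masks(int x, int y, int ih, int iw, int oh,
        int ow, int t_pad, int l_pad) {
    const int alpha = 4, m = 2;
    wino_tile_masks mk{};
    // Rows/columns of the input patch that fall inside the (padded) image.
    const int v_ys = std::max(0, t_pad - y);
    const int v_ye = std::min(alpha, std::max(0, ih + t_pad - y));
    const int v_xs = std::max(0, l_pad - x);
    const int v_xe = std::min(alpha, std::max(0, iw + l_pad - x));
    for (int i = 0; i < alpha; i++) {
        mk.src_y[i] = uint16_t(i < v_ys || i >= v_ye ? 0 : 0xffff);
        mk.src_x[i] = uint16_t(i < v_xs || i >= v_xe ? 0 : 0xffff);
    }
    // Output rows/columns that still lie inside the destination tensor.
    for (int i = 0; i < m; i++) {
        mk.dst_y[i] = uint16_t(y + i < oh ? 0xffff : 0);
        mk.dst_x[i] = uint16_t(x + i < ow ? 0xffff : 0);
    }
    return mk;
}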
-* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef CPU_JIT_AVX512_CORE_U8S8S32X_WINO_CONVOLUTION_HPP -#define CPU_JIT_AVX512_CORE_U8S8S32X_WINO_CONVOLUTION_HPP - -#include - -#include "c_types_map.hpp" -#include "mkldnn_thread.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -#include "cpu_convolution_pd.hpp" -#include "cpu_primitive.hpp" - -#include "jit_primitive_conf.hpp" -#include "jit_generator.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -struct jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t; -struct jit_avx512_core_u8s8s32x_wino_conv_src_trans_t; -struct jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t; - -template -struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t : public cpu_primitive_t { - struct pd_t : public cpu_convolution_fwd_pd_t { - pd_t(engine_t *engine, const convolution_desc_t *adesc, - const primitive_attr_t *attr, - const typename pd_t::base_class *hint_fwd_pd) - : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) - , jcp_() - {} - - DECLARE_COMMON_PD_T( - JIT_IMPL_NAME_HELPER("jit_int8_wino:", avx512_core, ""), - jit_avx512_core_u8s8s32x_wino_convolution_fwd_t); - - status_t init() { - bool ok = true - && is_fwd() - && utils::one_of(desc()->alg_kind, - alg_kind::convolution_auto, - alg_kind::convolution_winograd) - && expect_data_types(data_type::u8, data_type::s8, - data_type::undef, dst_data_type, data_type::s32) - && IMPLICATION(with_bias(), utils::one_of( - desc()->bias_desc.data_type, data_type::f32, - data_type::s32, data_type::s8, data_type::u8)) - && !has_zero_dim_memory() - && set_default_formats(); - - if (!ok) return status::unimplemented; - - status_t status = jit_conf(); - if (status != status::success) return status; - set_default_alg_kind(alg_kind::convolution_winograd); - - init_scratchpad(); - - return status; - } - - jit_conv_conf_2x3_wino_t jcp_; - - protected: - status_t jit_conf(); - void init_scratchpad(); - - bool set_default_formats() { - using namespace format_tag; - return set_default_formats_common(nhwc, any, nhwc); - } - }; - - typedef typename prec_traits::type src_data_t; - typedef typename prec_traits::type wei_data_t; - typedef typename prec_traits::type acc_data_t; - typedef typename prec_traits::type dst_data_t; - - jit_avx512_core_u8s8s32x_wino_convolution_fwd_t(const pd_t *apd); - ~jit_avx512_core_u8s8s32x_wino_convolution_fwd_t(); - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_forward(ctx); - return status::success; - } - -private: - const float *adjust_oscales(const memory_tracking::grantor_t &scratchpad) - const; - void execute_forward(const exec_ctx_t &ctx) const; - void execute_forward_small_mb(const src_data_t *src, const wei_data_t *wei, - const char *bia, dst_data_t *dst, - const memory_tracking::grantor_t &scratchpad) const; - void execute_forward_mbN(const src_data_t *src, const wei_data_t *wei, - const char *bia, dst_data_t *dst, - const memory_tracking::grantor_t &scratchpad) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - - jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t *kernel_; - jit_avx512_core_u8s8s32x_wino_conv_src_trans_t *src_trans_; - jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t *dst_trans_; -}; - -} -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp 
b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp deleted file mode 100644 index f4ec29ab0..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp +++ /dev/null @@ -1,820 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include - -#include "c_types_map.hpp" -#include "memory_tracking.hpp" -#include "nstl.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -#include "cpu_memory.hpp" - -#include "jit_uni_1x1_conv_utils.hpp" -#include "jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp" - -#define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field) - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace mkldnn::impl::utils; - -using namespace Xbyak; - -bool jit_avx512_core_x8s8s32x_1x1_conv_kernel::maybe_eltwise(int position) -{ - using namespace primitive_kind; - const auto &p = attr_.post_ops_; - - if (position == 0) { - /* eltwise before sum */ - return p.contain(eltwise, 0); - } else if (position == 1) { - /* eltwise after sum */ - return p.contain(sum, 0) && p.contain(eltwise, 1); - } - - return false; -} - -void jit_avx512_core_x8s8s32x_1x1_conv_kernel::bcast_loop(int load_loop_blk) -{ - mov(aux1_reg_bcast_data, reg_bcast_data); - mov(aux_reg_bcast_data, reg_bcast_data); - - mov(aux_reg_output_data, reg_output_data); - mov(bcast_loop_iter, EVEX_compress_addr(rsp, bcast_loop_work_off)); - - Label bcast_loop; - Label bcast_loop_tail; - - cmp(bcast_loop_iter, jcp.ur); - jl(bcast_loop_tail, T_NEAR); - - L(bcast_loop); { - assert(jcp.bcast_block % jcp.ur == 0); - int num_substeps = jcp.bcast_block / jcp.ur; - assert(num_substeps > 0 && num_substeps < 10); - for (int i = 0; i < num_substeps; i++) { - reduce_loop(load_loop_blk, jcp.ur, i, false); - if (i < num_substeps - 1) { - add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep); - add(aux_reg_output_data, jcp.bcast_loop_output_substep); - } - else { - add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step - - (num_substeps - 1) * jcp.bcast_loop_bcast_substep); - int output_offset = jcp.bcast_loop_output_step - - (num_substeps - 1) * jcp.bcast_loop_output_substep; - - add(aux_reg_output_data, output_offset); - } - } - sub(bcast_loop_iter, jcp.bcast_block); - cmp(bcast_loop_iter, jcp.bcast_block); - jge(bcast_loop, T_NEAR); - } - - L(bcast_loop_tail); - if (jcp.ur_tail) { - Label bcast_loop_tail_out; - cmp(bcast_loop_iter, 0); - jz(bcast_loop_tail_out, T_NEAR); - reduce_loop(load_loop_blk, jcp.ur_tail, 0, true); - L(bcast_loop_tail_out); - } -} - -void jit_avx512_core_x8s8s32x_1x1_conv_kernel::cvt2ps(data_type_t type_in, - zmm_t zmm_in, const Xbyak::Operand &op, bool mask_flag) { - zmm_t zmm = mask_flag ? 
zmm_in | ktail_mask | T_z : zmm_in; - switch (type_in) { - case data_type::f32: - case data_type::s32: vmovups(zmm, op); break; - case data_type::s8: vpmovsxbd(zmm, op); break; - case data_type::u8: vpmovzxbd(zmm, op); break; - default: assert(!"unsupported data type"); - } - if (type_in != data_type::f32) - vcvtdq2ps(zmm_in, zmm_in); -} - -void jit_avx512_core_x8s8s32x_1x1_conv_kernel::reduce_loop(int load_loop_blk, - int ur, int substep, bool wraparound) -{ - auto vreg_load = [=](int i_load) { - return Zmm(ur * load_loop_blk + i_load); - }; - - auto vreg_accum = [=](int i_load, int i_ur) { - return Zmm(i_ur * load_loop_blk + i_load); - }; - - auto zmm_bias_alpha = [=]() { - return Zmm(ur * load_loop_blk); - }; - - auto xmm_bias_alpha = [=]() { - return Xmm(ur * load_loop_blk); - }; - auto bias_ptr = [=](int i_load) { - return EVEX_compress_addr(reg_bias_data, - jcp.typesize_bia * jcp.oc_block * i_load); - }; - - auto comp_ptr = [=](int i_load) { - return EVEX_compress_addr(reg_comp_data, - sizeof(int32_t) * jcp.oc_block * i_load); - }; - - auto scale_ptr = [=](int i_load) { - return EVEX_compress_addr(reg_ptr_scales, - jcp.is_oc_scale * (sizeof(float) * jcp.oc_block * i_load)); - }; - - auto bcast_ptr = [=](int i_reduce, int i_ur, bool bcast) { - assert(i_ur < jcp.ur); - assert(i_reduce <= jcp.reduce_loop_unroll); - assert(jcp.reduce_loop_unroll == jcp.reduce_block); - - int offt = (jcp.ic_without_padding * i_ur + i_reduce); - - return EVEX_compress_addr(aux_reg_bcast_data, jcp.typesize_in * offt, - bcast); - }; - - auto load_ptr = [=](int i_reduce, int i_load) { - int u0 = i_reduce % jcp.reduce_loop_unroll; - int u1 = i_reduce / jcp.reduce_loop_unroll; - - int offt = (i_load * jcp.reduce_dim + u0) * jcp.load_block; - - return EVEX_compress_addr(aux_reg_load_data, - u1 * jcp.reduce_loop_load_step - + jcp.typesize_in * offt); - }; - - auto output_ptr = [=](int i_load, int i_ur) { - return EVEX_compress_addr(aux_reg_output_data, - jcp.typesize_out * (jcp.oc_without_padding * i_ur - + i_load * jcp.load_block)); - }; - - auto init = [=]() { - for (int i_load = 0; i_load < load_loop_blk; ++i_load) - for (int i_ur = 0; i_ur < ur; ++i_ur) { - auto r = vreg_accum(i_load, i_ur); - vpxord(r, r, r); - } - if (jcp.signed_input) { - xor_(reg_scratch, reg_scratch); - Reg8 _t8 = reg_scratch.cvt8(); - mov(_t8, (int8_t)-128); - vpbroadcastb(zmm_shift, _t8); - } - }; - - auto store = [=](const bool mask_flag_in) { - const auto &p = attr_.post_ops_; - const int sum_idx = p.find(primitive_kind::sum); - const float *p_sum_scale = (sum_idx != -1) - ? 
&p.entry_[sum_idx].sum.scale - : nullptr; - mov(EVEX_compress_addr(rsp, reg_bcast_data_off), reg_bcast_data); - mov(reg_ptr_scales, EVEX_compress_addr(rsp, reg_ptr_sum_scale_off)); - if (p_sum_scale && *p_sum_scale != 1.f) { - mov(EVEX_compress_addr(rsp, reg_load_data_off), reg_load_data); - mov(reg_ptr_sum_scale, (size_t)p_sum_scale); - } - if (jcp.signed_input && jcp.ver != ver_vnni) { - mov(reg_scratch, float2int(jcp.wei_adj_scale)); - vmovq(xmm_bias_alpha(), reg_scratch); - vbroadcastss(zmm_bias_alpha(), xmm_bias_alpha()); - } - for (int i_load = 0; i_load < load_loop_blk; ++i_load) { - const bool mask_flag = mask_flag_in && i_load == load_loop_blk - 1; - auto zmm_bias = zmm_tmp; - auto zmm_comp = zmm_bcast; - if (jcp.with_bias) { - if (jcp.signed_input) - mov(reg_bias_data, - EVEX_compress_addr(rsp,reg_bias_data_off)); - cvt2ps(jcp.bia_dt, zmm_bias, bias_ptr(i_load), mask_flag); - if (jcp.signed_input && jcp.ver != ver_vnni) - vmulps(zmm_bias, zmm_bias, zmm_bias_alpha()); - } - if (jcp.signed_input) { - mov(reg_comp_data, EVEX_compress_addr(rsp, reg_comp_data_off)); - cvt2ps(data_type::s32, zmm_comp, comp_ptr(i_load), mask_flag); - } - - for (int i_ur = 0; i_ur < ur; ++i_ur) { - auto r = vreg_accum(i_load, i_ur); - vcvtdq2ps(r, r); - if (jcp.signed_input) - vaddps(r, r, zmm_comp); - if (jcp.with_bias) - vaddps(r, r, zmm_bias); - - zmm_t mask_zmm = mask_flag ? r | ktail_mask | T_z : r; - vmulps(mask_zmm, r, scale_ptr(i_load)); - } - } - - if (maybe_eltwise(0)) - eltwise_injector_->compute_vector_range(0, ur * load_loop_blk); - - if (p_sum_scale) { // post_op: sum - for (int i_load = 0; i_load < load_loop_blk; ++i_load) { - const bool mask_flag = mask_flag_in && - i_load == load_loop_blk - 1; - for (int i_ur = 0; i_ur < ur; ++i_ur) { - vpxord(zmm_zero, zmm_zero, zmm_zero); - auto zmm_prev_dst = zmm_zero; - - auto r = vreg_accum(i_load, i_ur); - cvt2ps(jcp.dst_dt, zmm_prev_dst, output_ptr(i_load, i_ur), - mask_flag); - - if (*p_sum_scale == 1.f) - vaddps(r, zmm_prev_dst); - else - vfmadd231ps(r, zmm_prev_dst, zword_b[reg_ptr_sum_scale]); - } - } - } - - if (maybe_eltwise(1)) - eltwise_injector_->compute_vector_range(0, ur * load_loop_blk); - - for (int i_load = 0; i_load < load_loop_blk; ++i_load) { - const bool mask_flag = mask_flag_in && - i_load == load_loop_blk - 1; - for (int i_ur = 0; i_ur < ur; ++i_ur) { - auto r = vreg_accum(i_load, i_ur); - if (jcp.dst_dt == data_type::u8) { - vpxord(zmm_zero, zmm_zero, zmm_zero); - vmaxps(r, zmm_zero, r); - } - if (jcp.dst_dt != data_type::f32) - vcvtps2dq(r, r); - } - for (int i_ur = 0; i_ur < ur; ++i_ur) { - auto r = vreg_accum(i_load, i_ur); - zmm_t r_zmm = mask_flag ? 
r | ktail_mask : r; - - switch (jcp.dst_dt) { - case data_type::f32: - case data_type::s32: - vmovups(output_ptr(i_load, i_ur), r_zmm); break; - case data_type::s8: - vpmovsdb(output_ptr(i_load, i_ur), r_zmm); break; - case data_type::u8: - vpmovusdb(output_ptr(i_load, i_ur), r_zmm); break; - default: assert(!"unknown dst_dt"); - } - } - } - mov(reg_bcast_data, EVEX_compress_addr(rsp, reg_bcast_data_off)); - if (p_sum_scale && *p_sum_scale != 1.f) - mov(reg_load_data, EVEX_compress_addr(rsp, reg_load_data_off)); - }; - - auto compute = [=](Zmm vreg_acc, Zmm vreg_wei, Zmm vreg_src) { - if (jcp.ver == ver_vnni) { - vpdpbusd(vreg_acc, vreg_src, vreg_wei); - } else { - vpmaddubsw(zmm_tmp, vreg_src, vreg_wei); - vpmaddwd(zmm_tmp, zmm_tmp, zmm_one); - vpaddd(vreg_acc, vreg_acc, zmm_tmp); - } - }; - - auto fma_block = [=](bool last_block) { - int reduce_step = 4; - int tail_size = jcp.ic_without_padding % reduce_step; - int loop_unroll = last_block && jcp.ic != jcp.ic_without_padding - ? rnd_up(jcp.ic_without_padding % jcp.ic_block, reduce_step) - : jcp.reduce_loop_unroll; - for (int i_reduce = 0; i_reduce < loop_unroll; - i_reduce += reduce_step) { - for (int i_load = 0; i_load < load_loop_blk; ++i_load) - vmovups(vreg_load(i_load), load_ptr(i_reduce, i_load)); - for (int i_ur = 0; i_ur < ur; ++i_ur) { - if (last_block && tail_size != 0 - && i_reduce == loop_unroll - reduce_step) { - Xmm xmm_bcast = Xmm(zmm_bcast.getIdx()); - for (int r = 0; r < tail_size; ++r) - vpinsrb(xmm_bcast, xmm_bcast, ptr[aux_reg_bcast_data - + jcp.ic_without_padding * i_ur + i_reduce + r], r); - vpbroadcastd(zmm_bcast, xmm_bcast); - } else { - vpbroadcastd(zmm_bcast, bcast_ptr(i_reduce, i_ur, false)); - } - if (jcp.signed_input) - vpsubb(zmm_bcast, zmm_bcast, zmm_shift); - for (int i_load = 0; i_load < load_loop_blk; ++i_load) { - compute(vreg_accum(i_load, i_ur), - vreg_load(i_load), zmm_bcast); - } - } - } - }; - - Label reduce_loop; - Label reduce_loop_tail; - - mov(aux_reg_load_data, reg_load_data); - - mov(aux_reg_bcast_data, aux1_reg_bcast_data); - init(); - - mov(reduce_loop_iter, reg_reduce_loop_work); - sub(reduce_loop_iter, jcp.reduce_loop_unroll); - jle(reduce_loop_tail, T_NEAR); - - L(reduce_loop); { - fma_block(false); - add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step); - add(aux_reg_load_data, jcp.reduce_loop_load_step); - sub(reduce_loop_iter, jcp.reduce_loop_unroll); - jg(reduce_loop, T_NEAR); - } - - L(reduce_loop_tail); - if (jcp.ic != jcp.ic_without_padding) { - fma_block(true); - } else { - fma_block(false); - } - - if (jcp.oc_without_padding != jcp.oc) { - Label end_store, common_store; - mov(EVEX_compress_addr(rsp, reg_bcast_data_off), reg_bcast_data); - - /*Check if it is the last load_loop_blk*/ - sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step); - cmp(reg_load_loop_work, 0); - jg(common_store, T_NEAR); - - /*Check if it is the last ocb*/ - test(reg_reduce_pos_flag, FLAG_OC_LAST); - jz(common_store, T_NEAR); - - store(true); - jmp(end_store, T_NEAR); - - L(common_store); - store(false); - - L(end_store); - - add(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step); - } else { - store(false); - } -} - -void jit_avx512_core_x8s8s32x_1x1_conv_kernel::generate() -{ - preamble(); - - xor_(reg_scratch, reg_scratch); - Reg16 _t = reg_scratch.cvt16(); - mov(_t, 0x1); - vpbroadcastw(zmm_one, _t); - - sub(rsp, stack_space_needed); - - if (jcp.oc_without_padding != jcp.oc) { - int tail_size = jcp.oc_without_padding % jcp.oc_block; - int mask = (1 << tail_size) - 1; - Reg32 
regw_tmp = reg_last_load.cvt32(); - mov(regw_tmp, mask); - kmovw(ktail_mask, regw_tmp); - } - - if (jcp.with_bias) - mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]); - if (jcp.signed_input) { - mov(EVEX_compress_addr(rsp, reg_bias_data_off), reg_bias_data); - mov(reg_comp_data, ptr[param1 + GET_OFF(compensation)]); - mov(EVEX_compress_addr(rsp, reg_comp_data_off), reg_comp_data); - } - mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]); - mov(EVEX_compress_addr(rsp, reg_ptr_sum_scale_off), reg_ptr_scales); - mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]); - mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]); - mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]); - - mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]); - mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]); - mov(EVEX_compress_addr(rsp, bcast_loop_work_off), reg_bcast_loop_work); - mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]); - mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]); - - - auto load_loop_body = [=](int load_loop_blk) { - bcast_loop(load_loop_blk); - add(reg_load_data, load_loop_blk * jcp.load_loop_load_step); - if (jcp.with_bias) { - if (jcp.signed_input) - mov(reg_bias_data, EVEX_compress_addr(rsp, reg_bias_data_off)); - add(reg_bias_data, - load_loop_blk * jcp.load_block * jcp.typesize_bia); - if (jcp.signed_input) - mov(EVEX_compress_addr(rsp, reg_bias_data_off), reg_bias_data); - } - if (jcp.signed_input) { - mov(reg_comp_data, EVEX_compress_addr(rsp, reg_comp_data_off)); - add(reg_comp_data, - load_loop_blk * jcp.load_block * sizeof(int32_t)); - mov(EVEX_compress_addr(rsp, reg_comp_data_off), reg_comp_data); - } - mov(EVEX_compress_addr(rsp, reg_bcast_data_off), reg_bcast_data); - mov(reg_ptr_scales, EVEX_compress_addr(rsp, reg_ptr_sum_scale_off)); - add(reg_ptr_scales, - jcp.is_oc_scale * load_loop_blk * jcp.load_block * sizeof(float)); - mov(EVEX_compress_addr(rsp, reg_ptr_sum_scale_off), reg_ptr_scales); - mov(reg_bcast_data, EVEX_compress_addr(rsp, reg_bcast_data_off)); - add(reg_output_data, - load_loop_blk * jcp.load_block * jcp.typesize_out); - sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step); - }; - - const int simd_w = 16; - - Label load_loop_blk[7]; - - static const int ur_cases_fma_expl_bcast[] = { 2, 5, 6, 9, 14, 32 }; - const int size_ur_cases_fma = sizeof(ur_cases_fma_expl_bcast); - const int *ur_cases_fma = ur_cases_fma_expl_bcast; - const int *ur_cases = ur_cases_fma; - const int num_ur_cases = (size_ur_cases_fma) / sizeof(*ur_cases); - - for (int ur_idx = num_ur_cases - 1; ur_idx > 0; ur_idx--) { - int label_idx = num_ur_cases - ur_idx - 1; - if (jcp.ur <= ur_cases[ur_idx]) { - cmp(reg_load_loop_work, simd_w * (label_idx + 1)); - jle(load_loop_blk[label_idx], T_NEAR); - } - } - - for (int ur_idx = 0; ur_idx < num_ur_cases; ur_idx++) { - if (jcp.ur <= ur_cases[ur_idx]) { - int label_idx = num_ur_cases - ur_idx - 1; - L(load_loop_blk[label_idx]); - { - if (label_idx == 0) { - cmp(reg_load_loop_work, 0); - je(load_loop_blk[num_ur_cases], T_NEAR); - } - - for (int _i = 1; _i <= label_idx + 1; _i++) { - prefetcht0(ptr [ reg_load_data + _i * jcp.ic * jcp.oc_block ]); - prefetcht1(ptr [ reg_output_data + _i * jcp.oc_block ]); - } - - load_loop_body(label_idx + 1); - if (label_idx - 1 > 0) { - cmp(reg_load_loop_work, 2 * label_idx * simd_w); - je(load_loop_blk[label_idx - 1], T_NEAR); - } - cmp(reg_load_loop_work, (label_idx + 1) * simd_w); - jge(load_loop_blk[label_idx]); - } - for (int idx = label_idx - 1; idx > 0; 
--idx) { - cmp(reg_load_loop_work, simd_w * (idx + 1)); - je(load_loop_blk[idx], T_NEAR); - } - if (ur_idx < num_ur_cases - 2) { - cmp(reg_load_loop_work, simd_w); - jle(load_loop_blk[0], T_NEAR); - } - } - } - L(load_loop_blk[num_ur_cases]); - - add(rsp, stack_space_needed); - - postamble(); - - if (jcp.with_eltwise) - eltwise_injector_->prepare_table(); -} - -bool jit_avx512_core_x8s8s32x_1x1_conv_kernel::post_ops_ok( - jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) { - using namespace primitive_kind; - const auto &p = attr.post_ops_; - - auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; - - switch (p.len_) { - case 0: return true; - case 1: return is_eltwise(0) || p.contain(sum, 0); - case 2: return (p.contain(sum, 0) && is_eltwise(1)) - || (p.contain(sum, 1) && is_eltwise(0)); - default: return false; - } - - return false; -} - -status_t jit_avx512_core_x8s8s32x_1x1_conv_kernel::init_conf( - jit_1x1_conv_conf_t &jcp, const convolution_desc_t &cd, - const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d, - const memory_desc_wrapper &dst_d, const memory_desc_wrapper &bias_d, - const primitive_attr_t &attr, int nthreads, bool reduce_src) { - if (!mayiuse(avx512_core)) return status::unimplemented; - - const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; - if (!one_of(src_d.data_type(), data_type::u8, data_type::s8) - || weights_d.data_type() != data_type::s8 - || !one_of(dst_d.data_type(), - data_type::f32, data_type::s32, data_type::s8, data_type::u8)) - return status::unimplemented; - jcp.ver = ver_avx512_core; - if (mayiuse(avx512_core_vnni)) - jcp.ver = ver_vnni; - - jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; - jcp.mb = src_d.dims()[0]; - jcp.oc = dst_d.dims()[1] / jcp.ngroups; - jcp.oc_without_padding = jcp.oc; - jcp.ic = src_d.dims()[1] / jcp.ngroups; - jcp.ic_without_padding = jcp.ic; - jcp.ih = src_d.dims()[2]; - jcp.iw = src_d.dims()[3]; - jcp.oh = dst_d.dims()[2]; - jcp.ow = dst_d.dims()[3]; - jcp.kh = weights_d.dims()[with_groups + 2]; - jcp.kw = weights_d.dims()[with_groups + 3]; - jcp.t_pad = cd.padding[0][0]; - jcp.l_pad = cd.padding[0][1]; - jcp.stride_h = cd.strides[0]; - jcp.stride_w = cd.strides[1]; - jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; - - jcp.signed_input = (src_d.data_type() == data_type::s8) ? true : false; - - jcp.os = jcp.oh * jcp.ow; - jcp.is = jcp.ih * jcp.iw; - jcp.tr_is = rnd_up(jcp.is, 4); - - if (!post_ops_ok(jcp, attr)) - return status::unimplemented; - - const auto &p = attr.post_ops_; - const int eltwise_ind = p.find(primitive_kind::eltwise); - jcp.with_eltwise = eltwise_ind != -1; - if (jcp.with_eltwise) - jcp.eltwise = p.entry_[eltwise_ind].eltwise; - - format_tag_t dat_tag = format_tag::nhwc; - jcp.src_tag = src_d.matches_one_of_tag(dat_tag); - jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); - - bool args_ok = true - && jcp.ngroups == 1 - && jcp.src_tag == dat_tag - && jcp.dst_tag == dat_tag; - if (!args_ok) return status::unimplemented; - - const int simd_w = 16; - - jcp.oc = rnd_up(jcp.oc, simd_w); - jcp.ic = rnd_up(jcp.ic, simd_w); - - args_ok = true - && jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0 - && jcp.t_pad == 0 && jcp.l_pad == 0 - && jcp.stride_w == 1 && jcp.stride_h == 1 // TODO: support some strides - && jcp.kh == 1 && jcp.kw == 1; - if (!args_ok) return status::unimplemented; - - jcp.bia_dt = jcp.with_bias ? 
cd.bias_desc.data_type : data_type::undef; - jcp.dst_dt = cd.dst_desc.data_type; - - jcp.ic_block = jcp.oc_block = simd_w; - - jcp.typesize_in = types::data_type_size(src_d.data_type()); - jcp.typesize_out = types::data_type_size(dst_d.data_type()); - jcp.typesize_bia = jcp.with_bias - ? types::data_type_size(bias_d.data_type()) - : 0; - - const int SMALL_SPATIAL = 7 * 7; - const int BIG_REDUCE_DIM = 1024; - - int load_blocking = 0; - int load_blocking_max = 0; - int bcast_blocking = 0; - int bcast_blocking_max = 0; - int reduce_blocking = 0; - int reduce_blocking_max = 0; - jcp.load_grp_count = 1; - jcp.use_vmovntps = false; - - const int L2_size = get_cache_size(2, true) / sizeof(jcp.typesize_in); - const int L2_capacity = (L2_size * 3) / 4; - - int size_treshold = 28; - int max_regs = 0; - int min_regs = 6; - if (jcp.ver == ver_vnni) - max_regs = ((jcp.oh > size_treshold && jcp.ow > size_treshold) - && (jcp.oc < 128 || jcp.ic < 128)) ? min_regs : 9; - else - max_regs = 8; - jcp.expl_bcast = true; - - if (jcp.mb == 1 && jcp.ic > 128 - && (jcp.oh <= size_treshold && jcp.ow <= size_treshold)) { - if (jcp.os <= SMALL_SPATIAL && jcp.oc * jcp.ic < L2_size) - max_regs = min_regs; // mobilenet_v2 performance improvement - jcp.ur = nstl::min(max_regs, jcp.os); - } else { - const int spatial = jcp.oh; - jcp.ur = 1; - for (int ur_w = max_regs; ur_w >= min_regs; ur_w--) { - if ((spatial >= size_treshold && spatial % ur_w == 0) - || (spatial < size_treshold && jcp.os % ur_w == 0)) { - jcp.ur = ur_w; - break; - } - } - if (jcp.ur == 1) { - jcp.ur = nstl::min(max_regs, jcp.os); - int os_tail = jcp.os % max_regs; - for (int i = max_regs; i >= min_regs; i--) { - int i_tail = jcp.os % i; - if (i_tail > os_tail || i_tail == 0) { - jcp.ur = i; - os_tail = i_tail; - if (i_tail == 0) - break; - } - } - } - } - - jcp.reduce_dim = jcp.ic; - jcp.reduce_block = jcp.ic_block; - - jcp.load_dim = jcp.oc; - jcp.load_block = jcp.oc_block; - - jcp.bcast_dim = jcp.is; - - jcp.bcast_block = jcp.ur; - - jcp.reduce_loop_unroll = jcp.reduce_block; - jcp.reduce_loop_bcast_step - = jcp.reduce_loop_unroll * jcp.typesize_in; - - jcp.reduce_loop_load_step - = jcp.reduce_loop_unroll * jcp.load_block * jcp.typesize_in; - - jcp.bcast_loop_output_step = jcp.ur * jcp.oc_without_padding * jcp.typesize_out; - jcp.bcast_loop_output_substep = -1; // unused - jcp.bcast_loop_bcast_step = jcp.ur * jcp.ic_without_padding * jcp.typesize_in; - jcp.bcast_loop_bcast_substep = -1; // unused - - jcp.load_loop_load_step - = jcp.reduce_dim * jcp.load_block * jcp.typesize_in; - - jcp.load_loop_iter_step = jcp.load_block; - - jcp.loop_order = reduce_src ? loop_blr : loop_lbr; - - int nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block); - int nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block); - - reduce_blocking = nb_reduce; - if (jcp.bcast_dim <= SMALL_SPATIAL && jcp.reduce_dim >= BIG_REDUCE_DIM) - reduce_blocking = 64; - else if (jcp.bcast_dim > SMALL_SPATIAL && jcp.reduce_dim >= BIG_REDUCE_DIM) - reduce_blocking = 16; - reduce_blocking = best_divider(nb_reduce, 1, reduce_blocking, true); - reduce_blocking *= jcp.reduce_block; - - bool cmp_reduce = reduce_blocking <= jcp.reduce_dim; - if (cmp_reduce) - jcp.loop_order = reduce_src ? 
loop_rbl : loop_rlb; - load_blocking = jcp.load_dim; - - jcp.load_grp_count = div_up(nthreads, jcp.mb * jcp.ngroups * nb_bcast); - jcp.load_grp_count = best_divider( - nthreads, jcp.load_grp_count, 2 * jcp.load_grp_count, false); - - if (jcp.bcast_dim <= SMALL_SPATIAL && jcp.load_dim * jcp.reduce_dim >= L2_size) { - jcp.load_grp_count = nstl::max(jcp.load_grp_count, 4); - } else if (jcp.bcast_dim <= SMALL_SPATIAL && jcp.mb <= nthreads - && jcp.load_dim > 512 && jcp.load_dim / jcp.reduce_dim >= 4) { - jcp.load_grp_count = nstl::max(jcp.load_grp_count, 2); // - load_blocking = jcp.load_block; - } - - bcast_blocking = div_up(jcp.mb * jcp.ngroups * nb_bcast, - div_up(nthreads, jcp.load_grp_count)) * jcp.bcast_block; - bcast_blocking = nstl::min(jcp.bcast_dim, bcast_blocking); - bcast_blocking = rnd_up(bcast_blocking, jcp.bcast_block); - - int space_for_bcast - = (L2_capacity - /* kernel_size - */ - 2 * jcp.load_block * reduce_blocking - - jcp.ur * reduce_blocking - 3 * 1024); - if (jcp.reduce_dim * jcp.bcast_dim > L2_capacity) - space_for_bcast /= 2; - - int bcast_in_cache - = nstl::max(jcp.bcast_block, space_for_bcast / reduce_blocking); - bcast_blocking = nstl::min( - bcast_blocking, rnd_dn(bcast_in_cache, jcp.bcast_block)); - - load_blocking_max = load_blocking; - bcast_blocking_max = bcast_blocking * 3 / 2; - reduce_blocking_max = reduce_blocking; - - assert(load_blocking); - assert(load_blocking_max); - assert(bcast_blocking); - assert(bcast_blocking_max); - assert(reduce_blocking); - assert(reduce_blocking_max); - assert(load_blocking % jcp.load_block == 0); - assert(reduce_blocking % jcp.reduce_block == 0); - assert(load_blocking_max % jcp.load_block == 0); - assert(reduce_blocking_max % jcp.reduce_block == 0); - - assert(jcp.reduce_loop_unroll % 4 == 0); - assert(jcp.reduce_dim % jcp.reduce_loop_unroll == 0); - - assert(jcp.bcast_block % jcp.ur == 0); - assert(jcp.reduce_dim % jcp.reduce_block == 0); - - jcp.ur_tail = jcp.bcast_dim % jcp.ur; - - jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block; - jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block; - jcp.nb_load_blocking = load_blocking / jcp.load_block; - jcp.nb_load_blocking_max = load_blocking_max / jcp.load_block; - jcp.nb_reduce_blocking = reduce_blocking / jcp.reduce_block; - jcp.nb_reduce_blocking_max = reduce_blocking_max / jcp.reduce_block; - - jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block); - jcp.nb_load = div_up(jcp.load_dim, jcp.load_block); - jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block); - - // miniumum size of load dim chunk for work distribution within threads - jcp.nb_load_chunk = 1; - // peformance improvements for googlenet_v3, mb=1; - // TODO: generalize this condition and rewrite it in appropriate manner - if (jcp.mb == 1 && jcp.nb_load % 4 == 0 && jcp.ic / jcp.oc >= 4 - && jcp.ic * jcp.oc <= L2_size) { - jcp.nb_load_chunk = 4; - jcp.load_grp_count = nstl::max(jcp.nb_load / 4, jcp.load_grp_count); - } - - const auto &oscales = attr.output_scales_; - jcp.is_oc_scale = oscales.mask_ == 1 << 1; - assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0)); - - jcp.wei_adj_scale = - (weights_d.extra().flags | memory_extra_flags::scale_adjust) - ? 
weights_d.extra().scale_adjust : 1.f; - - return status::success; -} - -void jit_avx512_core_x8s8s32x_1x1_conv_kernel::init_scratchpad( - memory_tracking::registrar_t &scratchpad, - const jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) { - using namespace mkldnn::impl::memory_tracking::names; - - if (jcp.signed_input && jcp.ver != ver_vnni) { - dim_t count = nstl::max(attr.output_scales_.count_, 16); - scratchpad.book(key_conv_adjusted_scales, sizeof(float) * count); - } -} - -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp deleted file mode 100644 index 22e9732a1..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp +++ /dev/null @@ -1,131 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef JIT_AVX512_CORE_X8S8S32X_1X1_CONV_KERNEL_HPP -#define JIT_AVX512_CORE_X8S8S32X_1X1_CONV_KERNEL_HPP - -#include "c_types_map.hpp" -#include "memory_tracking.hpp" - -#include "jit_generator.hpp" -#include "jit_primitive_conf.hpp" -#include "jit_uni_eltwise.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -struct jit_avx512_core_x8s8s32x_1x1_conv_kernel: public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_x8s8s32x_1x1_conv_fwd_ker_t) - jit_avx512_core_x8s8s32x_1x1_conv_kernel(jit_1x1_conv_conf_t ajcp, - const primitive_attr_t &attr) : jcp(ajcp), attr_(attr), - eltwise_injector_(nullptr) - { - if (jcp.with_eltwise) - eltwise_injector_ = new jit_uni_eltwise_injector_f32( - this, jcp.eltwise); - - this->generate(); - jit_ker = (void (*)(jit_1x1_conv_call_s *)) this->getCode(); - } - - ~jit_avx512_core_x8s8s32x_1x1_conv_kernel() { - delete eltwise_injector_; - } - - static bool post_ops_ok(jit_1x1_conv_conf_t &jcp, - const primitive_attr_t &attr); - - static status_t init_conf(jit_1x1_conv_conf_t &jcp, - const convolution_desc_t &cd, - const memory_desc_wrapper &src_d, - const memory_desc_wrapper &weights_d, - const memory_desc_wrapper &dst_d, - const memory_desc_wrapper &bias_d, - const primitive_attr_t &attr, - int nthreads, bool reduce_src); - - static void init_scratchpad(memory_tracking::registrar_t &scratchpad, - const jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr); - - bool maybe_eltwise(int position); - - jit_1x1_conv_conf_t jcp; - const primitive_attr_t &attr_; - void (*jit_ker)(jit_1x1_conv_call_s *); - - private: - jit_uni_eltwise_injector_f32 *eltwise_injector_; - - using reg64_t = const Xbyak::Reg64; - using zmm_t = const Xbyak::Zmm; - using mask_t = const Xbyak::Opmask; - - reg64_t reg_bcast_data = r8; - reg64_t reg_ptr_scales = r8; - reg64_t reg_output_data = r9; - reg64_t reg_load_data = r10; - reg64_t reg_ptr_sum_scale = r10; - reg64_t reg_reduce_loop_work = r11; - 
reg64_t reg_bias_data = r12; - reg64_t reg_comp_data = r12; - reg64_t reg_scratch = r13; - reg64_t aux_reg_bcast_data = r14; - reg64_t aux_reg_load_data = r15; - reg64_t imm_addr64 = r15; - reg64_t reg_reduce_pos_flag = rax; - reg64_t aux1_reg_bcast_data = rbx; - reg64_t reg_bcast_loop_work = rbx; - reg64_t bcast_loop_iter = rdx; // Note: Fix me - reg64_t reg_load_loop_work = rsi; - reg64_t aux_reg_output_data = abi_not_param1; - reg64_t reduce_loop_iter = abi_param1; - - reg64_t reg_last_load = r8; - mask_t ktail_mask = k6; - - mask_t vmask = k7; - - Xbyak::Zmm zmm_tmp = Xbyak::Zmm(28); - Xbyak::Zmm zmm_one = Xbyak::Zmm(29); - Xbyak::Zmm zmm_zero = Xbyak::Zmm(30); - Xbyak::Zmm zmm_bcast = Xbyak::Zmm(31); - Xbyak::Zmm zmm_shift = Xbyak::Zmm(30); - - Xbyak::Zmm zmm_bias_alpha = Xbyak::Zmm(31); - Xbyak::Xmm xmm_bias_alpha = Xbyak::Xmm(31); - - int bcast_loop_work_off = 0; - int reg_bias_data_off = 8; - int reg_bcast_data_off = 16; - int reg_load_data_off = 24; - int reg_ptr_sum_scale_off = 32; - int reg_comp_data_off = 40; - int stack_space_needed = 48; - - void bcast_loop(int load_loop_blk); - void reduce_loop(int load_loop_blk, int ur, int substep, bool wraparound); - - void generate(); - void cvt2ps(data_type_t type_in, zmm_t zmm_in, const Xbyak::Operand &op, - bool mask_flag); -}; - -} -} -} - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.cpp deleted file mode 100644 index 0bf09fc67..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.cpp +++ /dev/null @@ -1,292 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
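// ---------------------------------------------------------------------------
// Editor's note (illustrative sketch, not part of the deleted sources): the
// int8 1x1 kernel's store() path converts the s32 accumulator to f32, folds
// in the signed-input compensation, bias and per-channel output scale,
// optionally applies ReLU, then saturates to the destination type. The
// scalar reference below, with hypothetical names, shows the same epilogue;
// the JIT code does this per zmm lane with opmasks and vpmov*db stores.
// ---------------------------------------------------------------------------
#include <algorithm>
#include <cmath>
#include <cstdint>

template <typename dst_t>
inline dst_t quantize_down(float v); // defined only for the supported types

template <> inline float quantize_down<float>(float v) { return v; }
template <> inline int32_t quantize_down<int32_t>(float v) {
    return (int32_t)std::lrintf(v); // sketch: no s32 saturation handling
}
template <> inline int8_t quantize_down<int8_t>(float v) {
    return (int8_t)std::max(-128.f, std::min(127.f, std::nearbyintf(v)));
}
template <> inline uint8_t quantize_down<uint8_t>(float v) {
    return (uint8_t)std::max(0.f, std::min(255.f, std::nearbyintf(v)));
}

template <typename dst_t>
inline dst_t int8_conv_epilogue(int32_t acc, int32_t compensation, float bias,
        float scale, bool signed_input, bool relu) {
    float v = (float)acc;
    if (signed_input) v += (float)compensation; // undo the s8 source shift
    v += bias;                                  // bias added before scaling
    v *= scale;                                 // per-tensor or per-oc scale
    if (relu) v = std::max(v, 0.f);
    return quantize_down<dst_t>(v);
}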
-*******************************************************************************/ - -#include "c_types_map.hpp" -#include "mkldnn_thread.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -#include "jit_generator.hpp" - -#include "jit_avx512_core_x8s8s32x_1x1_convolution.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace mkldnn::impl::status; -using namespace mkldnn::impl::memory_tracking::names; -using namespace mkldnn::impl::utils; - -namespace { -template -void balance2D(U nthr, U ithr, T ny, T &ny_start, T &ny_end, - T nx, T &nx_start, T &nx_end, T nx_divider) -{ - const T grp_size = utils::div_up(nthr, nx_divider); - const T grp_count = utils::div_up(nthr, grp_size); - - T grp = ithr / grp_size; - T grp_ithr = ithr % grp_size; - T grp_nthr = grp_size; - T first_grps = nthr % grp_count; - if (first_grps > 0 && grp >= first_grps) { - ithr -= first_grps * grp_size; - grp_nthr--; - grp = ithr / grp_nthr + first_grps; - grp_ithr = ithr % grp_nthr; - } - balance211(nx, grp_count, grp, nx_start, nx_end); - balance211(ny, grp_nthr, grp_ithr, ny_start, ny_end); -} -} - -/* convolution forward */ -template -void jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t:: -execute_forward(const exec_ctx_t &ctx) const -{ - auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); - auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); - auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); - auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); - - auto scratchpad = this->scratchpad(ctx); - - if (pd()->jcp_.signed_input && pd()->jcp_.ver != ver_vnni) { - auto local_scales = scratchpad.template get( - key_conv_adjusted_scales); - auto scales = pd()->attr()->output_scales_.scales_; - size_t count = pd()->attr()->output_scales_.count_; - float factor = 1.f / pd()->jcp_.wei_adj_scale; - if (count == 1) { - utils::array_set(local_scales, scales[0] * factor, 16); - } else { - for (size_t c = 0; c < count; c++) - local_scales[c] = scales[c] * factor; - } - } - - parallel(kernel_->jcp.nthr, [&](const int ithr, const int nthr) { - execute_forward_thr(ithr, nthr, src, weights, bias, dst, scratchpad); - }); -} - -template -void jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t -::execute_forward_thr(const int ithr, const int nthr, const src_data_t *src, - const wei_data_t *weights, const char *bias, dst_data_t *dst, - const memory_tracking::grantor_t &scratchpad) const { - const memory_desc_wrapper src_d(pd()->src_md()); - const memory_desc_wrapper dst_d(pd()->dst_md()); - const memory_desc_wrapper weights_d(pd()->weights_md(0)); - - const size_t bia_dt_size = pd()->with_bias() - ? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0; - - const auto &jcp = kernel_->jcp; - auto rtus_space = scratchpad.get(key_conv_rtus_space); - auto local_scales = scratchpad.get(key_conv_adjusted_scales); - - const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast; - - const int stride_h = pd()->desc()->strides[0]; - const int stride_w = pd()->desc()->strides[1]; - const int pad_t = pd()->desc()->padding[0][0]; - const int pad_l = pd()->desc()->padding[0][1]; - - const auto &oscales = pd()->attr()->output_scales_; - - int offset = jcp.ngroups * (jcp.oc / jcp.oc_block) * (jcp.ic / jcp.ic_block) - * jcp.oc_block * jcp.ic_block; - wei_data_t *w = const_cast(weights); - int32_t* compensation = (jcp.signed_input) - ? 
reinterpret_cast(w + offset) : 0; - - auto step = [](int default_step, int remaining, int tail_step) { - assert(default_step <= tail_step); - return remaining < tail_step ? remaining : default_step; - }; - - auto p = jit_1x1_conv_call_s(); - - auto rp = rtus_driver_t::call_params_t(); - const int nb_oc = jcp.nb_load; - const int os_block = jcp.bcast_block; - - int bcast_start{0}, bcast_end{0}, ocb_start{0}, ocb_end{0}; - balance2D(nthr, ithr, work_amount, bcast_start, bcast_end, - jcp.nb_load / jcp.nb_load_chunk, ocb_start, ocb_end, - jcp.load_grp_count); - if (jcp.nb_load_chunk > 1) { - ocb_start *= jcp.nb_load_chunk; - ocb_end *= jcp.nb_load_chunk; - } - - auto init_bcast = [&](int iwork, int &n, int &g, int &bcast_step, - int &oh, int &ow, int &ih, int &iw) - { - int osb{0}; - nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, - jcp.nb_bcast); - bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb, - jcp.nb_bcast_blocking_max); - bcast_step = nstl::min(bcast_step, bcast_end - iwork); - - const int os = osb * os_block; - oh = os / jcp.ow; - ow = os % jcp.ow; - - ih = nstl::max(oh * stride_h - pad_t, 0); - iw = nstl::max(ow * stride_w - pad_l, 0); - rp.iw_start = iw; - - p.bcast_dim = this_block_size(os, jcp.os, - bcast_step * os_block); - rp.os = p.bcast_dim; - }; - - auto init_load = [&](int ocb, int &load_step) - { - load_step = step(jcp.nb_load_blocking, ocb_end - ocb, - jcp.nb_load_blocking_max); - p.load_dim = this_block_size(ocb * jcp.oc_block, - ocb_end * jcp.oc_block, load_step * jcp.oc_block); - - if (ocb + load_step >= nb_oc) - p.first_last_flag |= FLAG_OC_LAST; - else - p.first_last_flag &= ~FLAG_OC_LAST; - - }; - - auto init_reduce = [&]() - { - p.reduce_dim = this_block_size(0, jcp.ic, jcp.ic); - rp.icb = p.reduce_dim / jcp.reduce_block; - }; - - auto inner_ker = [&](int ocb, int n, int g, int oh, int ow, - int ih, int iw) - { - const int icb = 0; // Start from the first IC block - const int _ocb = g * nb_oc + ocb; - const int _icb = g; - - const size_t dst_off = dst_d.blk_off(n, _ocb * jcp.oc_block, oh, ow); - - p.output_data = &dst[dst_off]; - p.load_data = &weights[pd()->with_groups() - ? weights_d.blk_off(g, ocb, icb) - : weights_d.blk_off(ocb, icb)]; - p.bias_data = &bias[_ocb * jcp.oc_block * bia_dt_size]; - p.compensation = (jcp.signed_input) - ? &compensation[_ocb * jcp.oc_block] : 0; - p.scales = (jcp.signed_input && jcp.ver != ver_vnni) - ? 
&local_scales[jcp.is_oc_scale * _ocb * jcp.oc_block] - : &oscales.scales_[jcp.is_oc_scale * _ocb * jcp.oc_block]; - if (pd()->rtus_.reduce_src_) { - rp.ws = rtus_space + ithr * pd()->rtus_.space_per_thread_ - + _icb * jcp.is * jcp.ic_block; - if (ocb == ocb_start) { - rp.src = src + src_d.blk_off(n, _icb * jcp.ic_block, ih, iw); - rtus_driver_->ker_(&rp); - } - p.bcast_data = rp.ws; - } else - p.bcast_data = src + src_d.blk_off(n, _icb * jcp.ic_block, ih, iw); - - kernel_->jit_ker(&p); - }; - - if (jcp.loop_order == loop_rlb) { - init_reduce(); - int ocb = ocb_start; - while (ocb < ocb_end) { - int load_step; - init_load(ocb, load_step); - int iwork = bcast_start; - while (iwork < bcast_end) { - int n, g, bcast_step, oh, ow, ih, iw; - init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw); - inner_ker(ocb, n, g, oh, ow, ih, iw); - iwork += bcast_step; - } - ocb += load_step; - } - } else if (jcp.loop_order == loop_lbr) { - int ocb = ocb_start; - while (ocb < ocb_end) { - int load_step; - init_load(ocb, load_step); - int iwork = bcast_start; - while (iwork < bcast_end) { - int n, g, bcast_step, oh, ow, ih, iw; - init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw); - init_reduce(); - inner_ker(ocb, n, g, oh, ow, ih, iw); - iwork += bcast_step; - } - ocb += load_step; - } - } else if (jcp.loop_order == loop_rbl) { - init_reduce(); - int iwork = bcast_start; - while (iwork < bcast_end) { - int n, g, bcast_step, oh, ow, ih, iw; - init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw); - int ocb = ocb_start; - while (ocb < ocb_end) { - int load_step; - init_load(ocb, load_step); - inner_ker(ocb, n, g, oh, ow, ih, iw); - ocb += load_step; - } - iwork += bcast_step; - } - } else if (jcp.loop_order == loop_blr) { - int iwork = bcast_start; - while (iwork < bcast_end) { - int n, g, bcast_step, oh, ow, ih, iw; - init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw); - int ocb = ocb_start; - while (ocb < ocb_end) { - int load_step; - init_load(ocb, load_step); - init_reduce(); - inner_ker(ocb, n, g, oh, ow, ih, iw); - ocb += load_step; - } - iwork += bcast_step; - } - } else { - assert(!"unsupported loop order"); - } -} - -using namespace data_type; -template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t; -template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t; -template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t; -template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t; -template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t; -template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t; -template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t; -template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t; - -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.hpp deleted file mode 100644 index ad9027ac1..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.hpp +++ /dev/null @@ -1,159 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. 
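// ---------------------------------------------------------------------------
// Editor's note (illustrative sketch, not part of the deleted sources): the
// 1x1 driver above splits work along two axes -- the bcast dimension
// (minibatch x spatial blocks) and the load dimension (oc blocks) -- before
// choosing one of the loop orders. The standalone helpers below show the
// even-split idea with hypothetical names; they are a simplification that
// assumes nthr is a multiple of the group size, whereas the library's
// balance211/balance2D also rebalance a partially filled last group.
// ---------------------------------------------------------------------------
#include <algorithm>

// Evenly split n items across `team` workers; worker `tid` gets [start, end).
inline void split_evenly(int n, int team, int tid, int &start, int &end) {
    const int chunk = n / team, rem = n % team;
    start = tid * chunk + std::min(tid, rem);
    end = start + chunk + (tid < rem ? 1 : 0);
}

// Partition an ny x nx work grid: threads are first grouped along nx (up to
// nx_divider groups), then each group splits ny among its members.
inline void split_2d(int nthr, int ithr, int ny, int &ny_s, int &ny_e, int nx,
        int &nx_s, int &nx_e, int nx_divider) {
    const int grp_size = (nthr + nx_divider - 1) / nx_divider;
    const int grp_count = (nthr + grp_size - 1) / grp_size;
    const int grp = ithr / grp_size;
    const int grp_ithr = ithr % grp_size;
    split_evenly(nx, grp_count, grp, nx_s, nx_e);
    split_evenly(ny, grp_size, grp_ithr, ny_s, ny_e);
}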
-* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef CPU_JIT_AVX512_CORE_X8S8S32X_1X1_CONVOLUTION_HPP -#define CPU_JIT_AVX512_CORE_X8S8S32X_1X1_CONVOLUTION_HPP - -#include "c_types_map.hpp" -#include "memory_tracking.hpp" -#include "mkldnn_thread.hpp" -#include "utils.hpp" - -#include "cpu_convolution_pd.hpp" -#include "cpu_primitive.hpp" - -#include "jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp" -#include "jit_uni_1x1_conv_utils.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -template -struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t : public cpu_primitive_t { - struct pd_t: public cpu_convolution_fwd_pd_t { - pd_t(engine_t *engine, const convolution_desc_t *adesc, - const primitive_attr_t *attr, - const typename pd_t::base_class *hint_fwd_pd) - : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) - , jcp_(), rtus_() {} - - DECLARE_COMMON_PD_T( - JIT_IMPL_NAME_HELPER("jit_int8_1x1:", avx512_core, ""), - jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t< - src_type, dst_type>); - - status_t init() { - bool ok = true - && is_fwd() - && set_default_alg_kind(alg_kind::convolution_direct) - && expect_data_types(src_type, data_type::s8, data_type::undef, - dst_type, data_type::s32) - && IMPLICATION(with_bias(), utils::one_of( - desc()->bias_desc.data_type, data_type::f32, - data_type::s32, data_type::s8, data_type::u8)) - && !has_zero_dim_memory() - && set_default_formats_common(dat_tag(), format_tag::any, - dat_tag()) - && set_or_check_wei_format(); - if (!ok) return status::unimplemented; - - const convolution_desc_t *conv_d = desc(); - const memory_desc_t *src_d = src_md(); - rtus_prepare(this, conv_d, src_d, dst_md()); - - status_t status = jit_avx512_core_x8s8s32x_1x1_conv_kernel:: - init_conf(jcp_, *conv_d, *src_d, *weights_md(), *dst_md(), - *weights_md(1), *attr(), mkldnn_get_max_threads(), - rtus_.reduce_src_); - if (status != status::success) return status; - - auto scratchpad = scratchpad_registry().registrar(); - jit_avx512_core_x8s8s32x_1x1_conv_kernel::init_scratchpad( - scratchpad, jcp_, *attr()); - - rtus_prepare_space_info(this, scratchpad); - - return status::success; - } - - jit_1x1_conv_conf_t jcp_; - reduce_to_unit_stride_t rtus_; - - protected: - format_tag_t dat_tag() const { return format_tag::nhwc; } - - bool set_or_check_wei_format() { - using namespace format_tag; - - const bool is_src_s8 = src_md_.data_type == data_type::s8; - format_tag_t wei_tag = with_groups() ? gOIhw4i16o4i : OIhw4i16o4i; - - memory_desc_t want_wei_md = weights_md_; - memory_desc_init_by_tag(want_wei_md, wei_tag); - if (is_src_s8) { - want_wei_md.extra.flags = 0 - | memory_extra_flags::compensation_conv_s8s8 - | memory_extra_flags::scale_adjust; - want_wei_md.extra.compensation_mask = (1 << 0) - + (with_groups() ? (1 << 1) : 0); - want_wei_md.extra.scale_adjust = - mayiuse(avx512_core_vnni) ? 
1.f : 0.5f; - } - - if (weights_md_.format_kind == format_kind::any) { - weights_md_ = want_wei_md; - return true; - } - - return weights_md_ == want_wei_md; - } - }; - - template - friend void init_rtus_driver(conv_t *self); - - jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t(const pd_t *apd) - : cpu_primitive_t(apd) - , kernel_(nullptr), rtus_driver_(nullptr) - { - kernel_ = new jit_avx512_core_x8s8s32x_1x1_conv_kernel(pd()->jcp_, - *pd()->attr()); - init_rtus_driver(this); - } - - ~jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t() { - delete kernel_; - delete rtus_driver_; - } - - typedef typename prec_traits::type src_data_t; - typedef typename prec_traits::type wei_data_t; - typedef typename prec_traits::type dst_data_t; - typedef typename prec_traits::type acc_data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_forward(ctx); - return status::success; - } - - private: - void execute_forward(const exec_ctx_t &ctx) const; - void execute_forward_thr(const int ithr, const int nthr, - const src_data_t *src, const wei_data_t *weights, - const char *bias, dst_data_t *dst, - const memory_tracking::grantor_t &scratchpad) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - - jit_avx512_core_x8s8s32x_1x1_conv_kernel *kernel_; - rtus_driver_t *rtus_driver_; -}; - -} -} -} - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_deconvolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_deconvolution.hpp deleted file mode 100644 index e89d06830..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_deconvolution.hpp +++ /dev/null @@ -1,140 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#ifndef CPU_JIT_AVX512_CORE_X8S8S32X_1X1_DECONVOLUTION_HPP -#define CPU_JIT_AVX512_CORE_X8S8S32X_1X1_DECONVOLUTION_HPP - -#include "c_types_map.hpp" -#include "mkldnn_thread.hpp" -#include "utils.hpp" -#include "type_helpers.hpp" -#include "primitive_iterator.hpp" - -#include "cpu_convolution_pd.hpp" -#include "cpu_deconvolution_pd.hpp" -#include "cpu_primitive.hpp" - -#include "jit_uni_1x1_conv_utils.hpp" -#include "jit_avx512_core_x8s8s32x_1x1_convolution.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -template -struct jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t - : public cpu_primitive_t { - struct pd_t : public cpu_deconvolution_fwd_pd_t { - pd_t(engine_t *engine, const deconvolution_desc_t *adesc, - const primitive_attr_t *attr, - const deconvolution_fwd_pd_t *hint_fwd_pd) - : cpu_deconvolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) - , conv_pd_(nullptr) {} - - pd_t(const pd_t &other) - : cpu_deconvolution_fwd_pd_t(other) - , conv_pd_(other.conv_pd_->clone()) - {} - - ~pd_t() { delete conv_pd_; } - - DECLARE_COMMON_PD_T(conv_pd_->name(), - jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t); - - status_t init_convolution() { - convolution_desc_t cd; - status_t status; - - auto dd = desc(); - status = conv_desc_init(&cd, prop_kind::forward_training, - alg_kind::convolution_direct, &(dd->src_desc), - &(dd->weights_desc), &(dd->bias_desc), &(dd->dst_desc), - dd->strides, dd->dilates, dd->padding[0], dd->padding[1], - dd->padding_kind); - - if (status == status::success) { - status = mkldnn_primitive_desc::create( - &conv_pd_, (op_desc_t *)&cd, &attr_, engine_, nullptr); - } - - if (status == status::success) - status = set_default_params(); - - return status; - }; - - status_t init() { - bool ok = true - && is_fwd() - && desc()->alg_kind == alg_kind::deconvolution_direct - && !has_zero_dim_memory() - && desc()->src_desc.data_type == src_type - && desc()->dst_desc.data_type == dst_type - && desc()->weights_desc.data_type == data_type::s8 - && IMPLICATION(with_bias(), utils::one_of( - desc()->bias_desc.data_type, data_type::f32, - data_type::s32, data_type::s8, data_type::u8)) - && desc()->accum_data_type == data_type::s32; - if (!ok) return status::unimplemented; - - CHECK(init_convolution()); - - return status::success; - } - - virtual void init_scratchpad_md() override { - const auto conv_1x1_pd = static_cast(conv_pd_); - scratchpad_md_ = *conv_1x1_pd->scratchpad_md(); - } - - protected: - status_t set_default_params() { - auto conv_1x1_pd_ = static_cast(conv_pd_); - src_md_ = *conv_1x1_pd_->src_md(); - dst_md_ = *conv_1x1_pd_->dst_md(); - weights_md_ = *conv_1x1_pd_->weights_md(); - if (with_bias()) - bias_md_ = *conv_1x1_pd_->weights_md(1); - return status::success; - } - - using conv_pd_t = typename jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t - ::pd_t; - friend jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t; - primitive_desc_t *conv_pd_; - }; - - jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t(const pd_t *apd) - : cpu_primitive_t(apd) - { pd()->conv_pd_->create_primitive((primitive_t **)&conv_p_); } - - ~jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t() - { delete conv_p_; } - - virtual status_t execute(const exec_ctx_t &ctx) const override { - return conv_p_->execute(ctx); - } - -private: - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - primitive_t *conv_p_; -}; - -} -} -} - -#endif /* CPU_JIT_AVX512_CORE_X8S8S32X_1X1_DECONVOLUTION_HPP */ diff 
--git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.cpp deleted file mode 100644 index 10e98a00c..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.cpp +++ /dev/null @@ -1,1182 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include "c_types_map.hpp" -#include "memory_tracking.hpp" -#include "nstl.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -#include "cpu_memory.hpp" - -#include "jit_avx512_core_x8s8s32x_conv_kernel.hpp" - -#define GET_OFF(field) offsetof(jit_conv_call_s, field) - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace mkldnn::impl::memory_tracking::names; -using namespace mkldnn::impl::utils; -using namespace Xbyak; - -namespace { -void pick_loop_order(jit_conv_conf_t &jcp, int nthr) -{ - jcp.loop_order = loop_cwgn; - if (jcp.ngroups > 1) { - jcp.loop_order = loop_ngcw; - if (jcp.mb < nthr) - jcp.loop_order = jcp.ndims == 3 ? loop_nwcg : loop_nhwcg; - } -} -} - -template -bool _jit_avx512_core_x8s8s32x_fwd_kernel::maybe_eltwise(int position) -{ - using namespace primitive_kind; - const auto &p = attr_.post_ops_; - - if (position == 0) { - /* eltwise before sum */ - return p.contain(eltwise, 0); - } else if (position == 1) { - /* eltwise after sum */ - return p.contain(sum, 0) && p.contain(eltwise, 1); - } - - return false; -} - -template -void _jit_avx512_core_x8s8s32x_fwd_kernel::prepare_output(int ur_w) -{ - int nb_oc_block - = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking; - for (int k = 0; k < nb_oc_block; k++) - for (int j = 0; j < ur_w; j++) { - Vmm vmm = vmm_out(j, k); - vpxord(vmm, vmm, vmm); - } - if (jcp.signed_input) { - xor_(reg_scratch, reg_scratch); - if (jcp.is_depthwise && !jcp.is_fast_depthwise) { - Reg32 _t32 = reg_scratch.cvt32(); - mov(_t32, (uint32_t)128); - vpbroadcastd(vmm_shift, _t32); - } else { - Reg8 _t8 = reg_scratch.cvt8(); - mov(_t8, (int8_t)128); - vpbroadcastb(vmm_shift, _t8); - } - } -} - -template -const Vmm _jit_avx512_core_x8s8s32x_fwd_kernel:: - vmm_mask(const Vmm vmm_in, bool mask_flag, bool store) { - return vmm_in; -} - -template<> -const Zmm _jit_avx512_core_x8s8s32x_fwd_kernel:: - vmm_mask(const Zmm zmm_in, bool mask_flag, bool store) { - return mask_flag ? (store ? zmm_in | ktail_mask : zmm_in | ktail_mask | T_z) - : zmm_in; -} - - -template -void _jit_avx512_core_x8s8s32x_fwd_kernel::cvt2ps(data_type_t type_in, - const Vmm vmm_in, const Operand &op, bool mask_flag) { - //const Vmm vmm = mask_flag ? 
vmm_in | ktail_mask | T_z : vmm_in; - const Vmm vmm = vmm_mask(vmm_in, mask_flag); - switch (type_in) { - case data_type::f32: - case data_type::s32: vmovups(vmm, op); break; - case data_type::s8: vpmovsxbd(vmm, op); break; - case data_type::u8: vpmovzxbd(vmm, op); break; - default: assert(!"unsupported data type"); - } - if (type_in != data_type::f32) - vcvtdq2ps(vmm_in, vmm_in); -} - -template -void _jit_avx512_core_x8s8s32x_fwd_kernel::compute_eltwise(int ur_w) { - int nb_oc_block - = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking; - if (ur_w == jcp.ur_w) - eltwise_injector_->compute_vector_range(0, nb_oc_block * jcp.ur_w); - else - for (int k = 0; k < nb_oc_block; k++) - eltwise_injector_->compute_vector_range(k * jcp.ur_w, - k * jcp.ur_w + ur_w); -} - -template -void _jit_avx512_core_x8s8s32x_fwd_kernel::store_output( - int ur_w, bool last_oc_block_flag) { - int nb_oc_block - = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking; - int oc_block = jcp.is_depthwise ? jcp.ch_block : jcp.oc_block; - - mov(reg_bias, ptr[param1 + GET_OFF(bias)]); - mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]); - if (jcp.signed_input) - mov(reg_compensation, ptr[param1 + GET_OFF(compensation)]); - - const auto &p = attr_.post_ops_; - const int sum_idx = p.find(primitive_kind::sum); - const float *p_sum_scale = nullptr; - if (sum_idx != -1) { - const auto &p_entry = p.entry_[sum_idx]; - p_sum_scale = &p_entry.sum.scale; - } - - if (p_sum_scale && *p_sum_scale != 1.f) - mov(reg_ptr_sum_scale, (size_t)p_sum_scale); - - if (jcp.signed_input && jcp.ver != ver_vnni) { - /* put 'wei_adj_scale = 0.5' for bias calculation */ - mov(reg_bias_alpha, float2int(jcp.wei_adj_scale)); - vmovq(xmm_bias_alpha(), reg_bias_alpha); - vbroadcastss(vmm_bias_alpha(), xmm_bias_alpha()); - } - - for (int k = 0; k < nb_oc_block; k++) { - const bool mask_flag = last_oc_block_flag && k == nb_oc_block - 1; - int scale_offset = jcp.is_oc_scale * (sizeof(float) * k * oc_block); - if (jcp.with_bias) { - int bias_offset = jcp.typesize_bia * k * oc_block; - auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset); - - cvt2ps(jcp.bia_dt, vmm_bias, bias_addr, mask_flag); - if (jcp.signed_input && jcp.ver != ver_vnni) - /* bias *= 0.5 */ - vmulps(vmm_bias, vmm_bias, vmm_bias_alpha()); - } - if (jcp.signed_input) { - int comp_offset = sizeof(int32_t) * k * oc_block; - auto comp_addr = EVEX_compress_addr(reg_compensation, comp_offset); - - cvt2ps(data_type::s32, vmm_comp, comp_addr, mask_flag); - } - /* add to zmm_accum: compensation, bias and permute */ - for (int j = 0; j < ur_w; j++) { - Vmm vmm = vmm_out(j, k); - if (jcp.is_fast_depthwise) - vpermd(zmm_out(j, k), zmm_permute, zmm_out(j, k)); - vcvtdq2ps(vmm, vmm); - if (jcp.signed_input) - vaddps(vmm, vmm, vmm_comp); - if (jcp.with_bias) - vaddps(vmm, vmm, vmm_bias); - - const Vmm vmm_k = vmm_mask(vmm, mask_flag); - vmulps(vmm_k, vmm, - EVEX_compress_addr(reg_ptr_scales, scale_offset)); - } - } - - /* Do post-ops */ - if (maybe_eltwise(0)) compute_eltwise(ur_w); - if (p_sum_scale) { // post_op: sum - for (int k = 0; k < nb_oc_block; k++) { - const bool mask_flag = last_oc_block_flag && k == nb_oc_block - 1; - for (int j = 0; j < ur_w; j++) { - int aux_output_offset - = jcp.typesize_out - * (k * oc_block - + j * jcp.oc_without_padding * jcp.ngroups); - auto addr = EVEX_compress_addr(reg_out, aux_output_offset); - Vmm vmm = vmm_out(j, k); - cvt2ps(jcp.dst_dt, vmm_prev_dst, addr, mask_flag); - if (*p_sum_scale == 1.f) - vaddps(vmm, vmm_prev_dst); - else - 
vfmadd231ps(vmm, vmm_prev_dst, zword_b[reg_ptr_sum_scale]); - } - } - } - if (maybe_eltwise(1)) compute_eltwise(ur_w); - - /* write out register to output_addr */ - for (int k = 0; k < nb_oc_block; k++) { - const bool mask_flag = last_oc_block_flag && k == nb_oc_block - 1; - for (int j = 0; j < ur_w; j++) { - Vmm vmm = vmm_out(j, k); - if (jcp.dst_dt == data_type::u8) { - vpxord(vmm_zero, vmm_zero, vmm_zero); - vmaxps(vmm, vmm_zero, vmm); - } - - if (jcp.dst_dt != data_type::f32) { - /* Note: using Zmm for rounding in Xmm/Ymm kernel - because there is no instruction to do rounding - from Xmm/Ymm -> Xmm/Ymm. - Embedded rounding is not supported for Xmm. - TODO: maybe avoid Zmm if it helps performance.*/ - Zmm zmm = zmm_out(j, k); - vcvtps2dq(zmm, zmm); - } - } - - for (int j = 0; j < ur_w; j++) { - int aux_output_offset = jcp.typesize_out - * (k * oc_block + j * jcp.oc_without_padding * jcp.ngroups); - auto addr = EVEX_compress_addr(reg_out, aux_output_offset); - - Vmm vmm = vmm_out(j, k); - const Vmm r_vmm = vmm_mask(vmm, mask_flag, true); - - switch (jcp.dst_dt) { - case data_type::f32: - case data_type::s32: vmovups(addr, r_vmm); break; - case data_type::s8: vpmovsdb(addr, r_vmm); break; - case data_type::u8: vpmovusdb(addr, r_vmm); break; - default: assert(!"unknown dst_dt"); - } - } - } - -} - -template -void _jit_avx512_core_x8s8s32x_fwd_kernel::compute_ker_dw( - int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag, bool h_padded) { - assert(!"invalid group blocking for depthwise convolution"); -} - -template <> -void _jit_avx512_core_x8s8s32x_fwd_kernel::compute_ker_dw( - int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag, bool h_padded) { - - auto input_spatial_index = [=](int oi, int ki) { - return (ki * (jcp.dilate_w + 1) + oi * jcp.stride_w - pad_l); - }; - - auto input_offset2 = [=](int ii, int ci) { - return jcp.typesize_in * (ii * jcp.ngroups + ci * jcp.ch_block); - }; - - auto input_offset3 = [=](int oi, int ci, int ki) { - return jcp.typesize_in * input_offset2(input_spatial_index(oi, ki), ci); - }; - - auto kernel_offset = [=](int ci, int ki) { - return jcp.typesize_in * ((ci * jcp.kh * jcp.kw + ki) * jcp.ch_block); - }; - - auto compute = [=](Zmm vreg_acc, Zmm vreg_wei, Zmm vreg_src) { - // okay for depthwise since src is zero-extended - if (jcp.ver == ver_vnni) { - vpdpbusd(vreg_acc, vreg_src, vreg_wei); - } else { - vpmaddwd(zmm_tmp, vreg_src, vreg_wei); - vpaddd(vreg_acc, vreg_acc, zmm_tmp); - } - }; - - int ii_start = 0; - int ii_end = -1; - if (jcp.is_resrc_depthwise && !h_padded) { - // find bounds of input spatial indices - bool first = true; - for (int ki = 0; ki < jcp.kw; ki++) { - int oi_start = get_ow_start(ki, pad_l); - int oi_end = get_ow_end(ur_w, ki, pad_r); - for (int oi = oi_start; oi < oi_end; oi++) { - int ii = input_spatial_index(oi, ki); - if (first || ii < ii_start) - ii_start = ii; - if (first || ii > ii_end) - ii_end = ii; - first = false; - } - } - } - - if (jcp.signed_input) { - vpxord(zmm_shifted_zero, zmm_shifted_zero, zmm_shifted_zero); - vpaddb(zmm_shifted_zero, zmm_shifted_zero, vmm_shift); - } - for (int ci = 0; ci < jcp.nb_ch_blocking; ci++) { - const bool mask_flag = last_ic_block_flag != no_last_block - && ci == jcp.nb_ch_blocking - 1; - if (jcp.is_resrc_depthwise && !h_padded) { - // now we can load input once and reuse up to jcp.kw times - for (int ii = ii_start; ii <= ii_end; ii++) { - int aux_input_offset = input_offset2(ii, ci); - const Zmm zmm_inp_tmp = zmm_inp(ii, jcp.nb_ch_blocking); - const Zmm zmm_inp_msk = 
mask_flag - ? zmm_inp_tmp | ktail_mask | T_z - : zmm_inp_tmp; - if (jcp.is_fast_depthwise) { - assert(!mask_flag); - vbroadcasti32x4(zmm_inp_msk, - EVEX_compress_addr(aux_reg_inp, aux_input_offset)); - } else { - vpmovzxbd(zmm_inp_msk, - EVEX_compress_addr(aux_reg_inp, aux_input_offset)); - } - if (jcp.signed_input) - vpaddb(zmm_inp_tmp, zmm_inp_tmp, vmm_shift); - } - } - for (int ki = 0; ki < jcp.kw; ki++) { - int aux_kernel_offset = kernel_offset(ci, ki); - if (jcp.is_fast_depthwise) { - vbroadcasti32x4(zmm_wei, - EVEX_compress_addr(aux_reg_ker, aux_kernel_offset)); - vmovdqu8(zmm_wei | kblend_mask | T_z, zmm_wei); - } else { - vpmovsxbd(zmm_wei, - EVEX_compress_addr(aux_reg_ker, aux_kernel_offset)); - } - if (h_padded) { - assert(jcp.signed_input); - for (int oi = 0; oi < ur_w; oi++) - compute(zmm_out(oi, ci), zmm_wei, zmm_shifted_zero); - } else { - const Zmm r_zmm_src = mask_flag ? zmm_src | ktail_mask : zmm_src; - int oi_start = get_ow_start(ki, pad_l); - int oi_end = get_ow_end(ur_w, ki, pad_r); - int start_ = jcp.signed_input ? 0 : oi_start; - int end_ = jcp.signed_input ? ur_w : oi_end; - for (int oi = start_; oi < end_; oi++) { - if (oi >= oi_start && oi < oi_end) { - if (jcp.is_resrc_depthwise) { - int ii = input_spatial_index(oi, ki); - zmm_src = zmm_inp(ii, jcp.nb_ch_blocking); - } else { - int aux_input_offset = input_offset3(oi, ci, ki); - if (jcp.is_fast_depthwise) { - assert(!mask_flag); - vbroadcasti32x4(r_zmm_src, - EVEX_compress_addr(aux_reg_inp, - aux_input_offset)); - } else { - vpmovzxbd(r_zmm_src, - EVEX_compress_addr(aux_reg_inp, - aux_input_offset)); - } - if (jcp.signed_input) - vpaddb(zmm_src, zmm_src, vmm_shift); - } - } else if (jcp.signed_input) { - zmm_src = zmm_shifted_zero; - } - compute(zmm_out(oi, ci), zmm_wei, zmm_src); - } - } - } - } -} - -template -void _jit_avx512_core_x8s8s32x_fwd_kernel::compute_ker(int ur_w, int pad_l, - int pad_r, ic_block_t last_ic_block_flag, bool h_padded) { - if (jcp.is_depthwise) - return compute_ker_dw(ur_w, pad_l, pad_r, last_ic_block_flag, h_padded); - - int kw = jcp.kw; - int stride_w = jcp.stride_w; - int ic_block = jcp.ic_block; - int oc_block = jcp.oc_block; - int ch_block_all = jcp.ch_block * ic_block * oc_block; - - int nb_oc_block = jcp.nb_oc_blocking; - - auto input_offset = [=](int oi, int ic, int ki) { - return jcp.typesize_in - * ((ki * (jcp.dilate_w + 1) + oi * stride_w - pad_l) - * jcp.ic_without_padding * jcp.ngroups + 4 * ic); - }; - auto kernel_offset = [=](int ii, int ic, int ki) { - return jcp.typesize_in - * ((ii * jcp.nb_ic * jcp.kh * jcp.kw + ki) * ch_block_all - + 4 * ic * oc_block); - }; - auto compute = [=](Vmm vreg_acc, Vmm vreg_wei, Vmm vreg_src) { - if (jcp.ver == ver_vnni) { - vpdpbusd(vreg_acc, vreg_src, vreg_wei); - } else { - vpmaddubsw(vmm_tmp, vreg_src, vreg_wei); - vpmaddwd(vmm_tmp, vmm_tmp, vmm_one); - vpaddd(vreg_acc, vreg_acc, vmm_tmp); - } - }; - - for (int ki = 0; ki < kw; ki++) { - int jj_start = get_ow_start(ki, pad_l); - int jj_end = get_ow_end(ur_w, ki, pad_r); - int tail_size = jcp.ic_without_padding % 4; - int _start = (jcp.signed_input) ? 0 : jj_start; - int _end = (jcp.signed_input) ? ur_w : jj_end; - /* Skip the last loads of input if (ic%16)/4 < ic_block/4 */ - int icb = (last_ic_block_flag != no_last_block) - ? 
div_up((jcp.ic_without_padding % ic_block), 4) - : ic_block / 4; - for (int ic = 0; ic < icb; ic++) { - if (h_padded == true) { - /* fill padded area with shifted values */ - Vmm inp = vmm_inp(0,nb_oc_block); - vpxord(inp, inp, inp); - vpaddb(inp, inp, vmm_shift); - } else { - for (int jj = _start; jj < _end; jj++) { - int aux_input_offset = input_offset(jj, ic, ki); - if (jj >= jj_start && jj < jj_end) { - if (last_ic_block_flag == last_sp_block - && tail_size != 0 && ic == icb - 1) { - Xmm xmm_tmp = Xmm(vmm_inp(jj, nb_oc_block).getIdx()); - for (int r = 0; r < tail_size; ++r) - vpinsrb(xmm_tmp, xmm_tmp, - ptr[aux_reg_inp + aux_input_offset + r], r); - vpbroadcastd(vmm_inp(jj, nb_oc_block), xmm_tmp); - } else { - vpbroadcastd(vmm_inp(jj, nb_oc_block), - EVEX_compress_addr( - aux_reg_inp, aux_input_offset)); - } - if (jcp.signed_input) - vpaddb(vmm_inp(jj, nb_oc_block), - vmm_inp(jj, nb_oc_block), vmm_shift); - } else { - /* fill padded area with shifted values */ - if (jcp.signed_input) { - Vmm inp = vmm_inp(jj, nb_oc_block); - vpxord(inp, inp, inp); - vpaddb(inp, inp, vmm_shift); - } - } - } - } - for (int ii = 0; ii < nb_oc_block; ii++) { - int aux_kernel_offset = kernel_offset(ii, ic, ki); - vmovups(vmm_wei, - EVEX_compress_addr(aux_reg_ker, aux_kernel_offset)); - for (int jj = _start; jj < _end; jj++) { - Vmm inp = (h_padded == true) - ? vmm_inp(0,nb_oc_block) : vmm_inp(jj, nb_oc_block); - compute(vmm_out(jj, ii), vmm_wei, inp); - } - } - } - } -} - -template -void _jit_avx512_core_x8s8s32x_fwd_kernel::kh_loop( - int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag) { - Label kh_label, skip_kh_loop; - Label t_overflow_label, no_t_overflow_label, - b_overflow_label, no_b_overflow_label; - - int ch_block_all = jcp.ch_block * jcp.ic_block * jcp.oc_block; - int shift_kernel_ptr = jcp.typesize_in * jcp.kw * ch_block_all; - int shift_input_ptr = jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw - * jcp.ic_without_padding * jcp.ngroups; - - mov(aux_reg_inp, reg_inp); - mov(aux_reg_ker, reg_ker); - - if (jcp.signed_input && jcp.ndims > 3) { - mov(reg_overflow, ptr[param1 + GET_OFF(t_overflow)]); - cmp(reg_overflow, 0); - je(no_t_overflow_label, T_NEAR); - L(t_overflow_label); { - compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true); - - add(aux_reg_ker, shift_kernel_ptr); - dec(reg_overflow); - cmp(reg_overflow, 0); - jg(t_overflow_label, T_NEAR); - } - L(no_t_overflow_label); - } - mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]); - if ((jcp.signed_input) || (!jcp.signed_input && - (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad))) { - cmp(reg_kj, 0); - je(skip_kh_loop, T_NEAR); - } - L(kh_label); { - compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, false); - - add(aux_reg_ker, shift_kernel_ptr); - add(aux_reg_inp, shift_input_ptr); - dec(reg_kj); - cmp(reg_kj, 0); - jg(kh_label, T_NEAR); - } - L(skip_kh_loop); - if (jcp.signed_input && jcp.ndims > 3) { - mov(reg_overflow, ptr[param1 + GET_OFF(b_overflow)]); - cmp(reg_overflow, 0); - je(no_b_overflow_label, T_NEAR); - L(b_overflow_label); { - compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true); - - add(aux_reg_ker, shift_kernel_ptr); - dec(reg_overflow); - cmp(reg_overflow, 0); - jg(b_overflow_label, T_NEAR); - } - L(no_b_overflow_label); - } -} - -template -void _jit_avx512_core_x8s8s32x_fwd_kernel::icb_loop( - int ur_w, int pad_l, int pad_r, bool is_last_sp_block) -{ - prepare_output(ur_w); - - // IC loop - Label icb_label; - mov(reg_icb, jcp.nb_ic); - L(icb_label); - if (jcp.ngroups % jcp.ch_block != 0 
|| jcp.ic_without_padding != jcp.ic) { - Label common_ker, end_ker; - - cmp(reg_icb, 1); // The last IC block - jne(common_ker, T_NEAR); - - kh_loop(ur_w, pad_l, pad_r, - is_last_sp_block ? last_sp_block : last_ic_block); - jmp(end_ker, T_NEAR); - - L(common_ker); - kh_loop(ur_w, pad_l, pad_r, no_last_block); - - L(end_ker); - } else { - kh_loop(ur_w, pad_l, pad_r, no_last_block); - } - // End of IC Loop - int inp_step = jcp.ic_block; - int ker_step = jcp.kh * jcp.kw * jcp.oc_block * jcp.ic_block; - add(reg_inp, jcp.typesize_in * inp_step); - add(reg_ker, jcp.typesize_in * ker_step); - - dec(reg_icb); - cmp(reg_icb, 0); - jg(icb_label, T_NEAR); - - sub(reg_inp, jcp.typesize_in * inp_step * jcp.nb_ic); - sub(reg_ker, jcp.typesize_in * ker_step * jcp.nb_ic); - - if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) { - Label common_store, end_store; - - if (jcp.is_depthwise) - cmp(reg_oc_blocks, jcp.nb_ch - jcp.nb_ch_blocking); - else - cmp(reg_oc_blocks, jcp.nb_oc - jcp.nb_oc_blocking); - - jne(common_store, T_NEAR); - - store_output(ur_w, true); // last oc block - jmp(end_store, T_NEAR); - - L(common_store); - store_output(ur_w, false); - - L(end_store); - } else { - store_output(ur_w, false); - } -} - -template -void _jit_avx512_core_x8s8s32x_fwd_kernel::generate() -{ - Label permute_index_table; - int inp_shift_pad = jcp.typesize_in * (jcp.ur_w * jcp.stride_w - jcp.l_pad) - * jcp.ic_without_padding * jcp.ngroups; - int inp_shift_pad_second_block = -1 * jcp.typesize_in * jcp.l_pad - * jcp.ic_without_padding * jcp.ngroups; - int inp_shift = jcp.typesize_in * - (jcp.ur_w * jcp.stride_w * jcp.ic_without_padding - * jcp.ngroups); - int out_shift = jcp.typesize_out * - (jcp.ur_w * jcp.oc_without_padding * jcp.ngroups); - preamble(); - - if (jcp.is_depthwise) { - int idx = jcp.max_regs_ur - 1; - if (!jcp.is_resrc_depthwise) - zmm_src = Zmm(++idx); - if (jcp.ver != ver_vnni) - zmm_tmp = Zmm(++idx); - if (jcp.is_fast_depthwise) - zmm_permute = Zmm(++idx); - if (jcp.signed_input) { - zmm_shifted_zero = Zmm(++idx); - ++idx; // due to extra register used for shifts and compensations - } - assert(idx == ker_dw_reg_base_idx); - } - - if (!jcp.is_depthwise && jcp.ver != ver_vnni) { - xor_(reg_scratch, reg_scratch); - Reg16 _t16 = reg_scratch.cvt16(); - mov(_t16, 0x1); - vpbroadcastw(vmm_one, _t16); - } - - mov(reg_inp, ptr[param1 + GET_OFF(src)]); - mov(reg_out, ptr[param1 + GET_OFF(dst)]); - mov(reg_ker, ptr[param1 + GET_OFF(filt)]); - - if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) { - int tail_size = jcp.is_depthwise - ? 
jcp.ngroups % jcp.ch_block - : jcp.oc_without_padding % jcp.oc_block; - int mask = (1 << tail_size) - 1; - mov(reg_oc_blocks, ptr[param1 + GET_OFF(oc_blocks)]); - Reg32 regw_tmp = reg_oi.cvt32(); - mov(regw_tmp, mask); - kmovw(ktail_mask, regw_tmp); - } - if (jcp.is_fast_depthwise) { - // prepare mask register for blending weights - mov(reg_scratch, 0x8888444422221111); - kmovq(kblend_mask, reg_scratch); - // load permute indices from data section - mov(reg_scratch, permute_index_table); - vmovdqu32(zmm_permute, ptr[reg_scratch]); - } - - int r_pad = nstl::max(0, (jcp.ow - 1) * jcp.stride_w - + (jcp.kw - 1) * (jcp.dilate_w + 1) - - (jcp.iw + jcp.l_pad - 1)); - int n_oi = jcp.ow / jcp.ur_w; - int r_pad1 = (jcp.ur_w * n_oi - 1) * jcp.stride_w - + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1); - - if (jcp.nb_ow == 1) { - if (r_pad1 > 0 || jcp.ur_w_tail == 0) - n_oi--; - - xor_(reg_oi, reg_oi); - if (jcp.ow == jcp.ur_w) { - icb_loop(jcp.ur_w, jcp.l_pad, r_pad, true); - } else { - if (n_oi == 0) { - icb_loop(jcp.ur_w, jcp.l_pad, r_pad1, jcp.ur_w_tail == 0); - add(reg_inp, inp_shift_pad); - add(reg_out, out_shift); - if (jcp.ur_w_tail != 0) { - icb_loop(jcp.ur_w_tail, 0, r_pad, true); - } - } else { - if (jcp.l_pad > 0) { - icb_loop(jcp.ur_w, jcp.l_pad, 0, false); - add(reg_inp, inp_shift_pad); - add(reg_out, out_shift); - - inc(reg_oi); - } - if ((jcp.l_pad <= 0 && n_oi > 0) || (jcp.l_pad > 0 && n_oi > 1)) - { - Label ow_loop_label; - L(ow_loop_label); { - icb_loop(jcp.ur_w, 0, 0, false); - add(reg_inp, inp_shift); - add(reg_out, out_shift); - - inc(reg_oi); - cmp(reg_oi, n_oi); - jl(ow_loop_label, T_NEAR); - } - } - if (r_pad1 > 0 || jcp.ur_w_tail == 0) { - icb_loop(jcp.ur_w, 0, r_pad1, jcp.ur_w_tail == 0); - add(reg_inp, inp_shift); - add(reg_out, out_shift); - } - if (jcp.ur_w_tail != 0) { - icb_loop(jcp.ur_w_tail, 0, r_pad, true); - } - } - } - } else { - // ow block is only processed. - // Number of block is passed as parameter owb, - // and padding processing depends on this number. - Label end_label, last_oi_label, middle_ow_blocks_label, tail_label, - oi_loop_label, oi_loop_end_label; - - assert(jcp.ow_block % jcp.ur_w == 0); - int n_oi_not_last_ow_block = jcp.ow_block / jcp.ur_w; - // to simplify code (and general regs usage), - // size of ow block must be >= 2 * ur_w - assert(n_oi_not_last_ow_block > 1); - int n_oi_next_last_ow_block = n_oi_not_last_ow_block; - int n_oi_first_ow_block = n_oi_not_last_ow_block; - int n_oi_last_ow_block - = (jcp.ow - jcp.ow_block * (jcp.nb_ow - 1)) / jcp.ur_w; - // prepare right padding - bool next_last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block == 0; - bool first_ow_block_padded - = next_last_ow_block_padded && jcp.nb_ow == 2; - bool last_ow_block_padded - = (r_pad1 > 0 || jcp.ur_w_tail == 0) && n_oi_last_ow_block > 0; - - if (last_ow_block_padded) n_oi_last_ow_block--; - else if (first_ow_block_padded) n_oi_first_ow_block--; - else if (next_last_ow_block_padded) n_oi_next_last_ow_block--; - - mov(reg_owb, ptr[param1 + GET_OFF(owb)]); - cmp(reg_owb, 0); // is that the first ow-block ? 
- jg(middle_ow_blocks_label, T_NEAR); - - // the first ow block, compute left padding - mov(reg_oi, n_oi_first_ow_block); - if (jcp.l_pad > 0) { - icb_loop(jcp.ur_w, jcp.l_pad, 0, false); - add(reg_inp, inp_shift_pad); - add(reg_out, out_shift); - - dec(reg_oi); - } - jmp(oi_loop_label, T_NEAR); - - // middle or last ow block entry - L(middle_ow_blocks_label); - - if (jcp.l_pad > 0) { - // just to consider left padding, not compute - add(reg_inp, inp_shift_pad_second_block); - } - - // set number of iteration for oi-loop - if (n_oi_last_ow_block != n_oi_not_last_ow_block) { - cmp(reg_owb, jcp.nb_ow - 1); // last ow-block ? - mov(reg_oi, n_oi_last_ow_block); - je(oi_loop_label, T_NEAR); - } - - if (n_oi_next_last_ow_block != n_oi_not_last_ow_block) { - cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ? - - mov(reg_oi, n_oi_next_last_ow_block); - je(oi_loop_label, T_NEAR); - } - mov(reg_oi, n_oi_not_last_ow_block); // other middle ow-blocks - - // oi loop w/o padding - L(oi_loop_label); { - cmp(reg_oi, 0); - jle(oi_loop_end_label, T_NEAR); - - icb_loop(jcp.ur_w, 0, 0, false); - - add(reg_inp, inp_shift); - add(reg_out, out_shift); - dec(reg_oi); - - jmp(oi_loop_label, T_NEAR); - } - L(oi_loop_end_label); - - mov(reg_owb, ptr[param1 + GET_OFF(owb)]); - cmp(reg_owb, 0); // first ow-block ? - if (first_ow_block_padded) - je(last_oi_label, T_NEAR); - else - je(end_label, T_NEAR); - - cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ? - jl(end_label, T_NEAR); - if (next_last_ow_block_padded) - je(last_oi_label, T_NEAR); - else - je(end_label, T_NEAR); - - // that is last block - if (!last_ow_block_padded) - jmp(tail_label, T_NEAR); - - // last oi block with right padding - L(last_oi_label); - icb_loop(jcp.ur_w, 0, r_pad1, jcp.ur_w_tail == 0); - add(reg_inp, inp_shift); - add(reg_out, out_shift); - - mov(reg_owb, ptr[param1 + GET_OFF(owb)]); - cmp(reg_owb, jcp.nb_ow - 1); // last ow_block? 
- jl(end_label, T_NEAR); - - // ur_w tail - L(tail_label); - if (jcp.ur_w_tail != 0) { - icb_loop(jcp.ur_w_tail, 0, r_pad, true); - } - L(end_label); - } - postamble(); - - if (jcp.with_eltwise) - eltwise_injector_->prepare_table(); - - if (jcp.is_fast_depthwise) { - align(64); - L(permute_index_table); - const uint32_t _idx[] - = { 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15 }; - for (size_t i = 0; i < sizeof(_idx) / sizeof(_idx[0]); ++i) - dd(_idx[i]); - } -} - -bool jit_avx512_core_x8s8s32x_fwd_kernel::post_ops_ok( - jit_conv_conf_t &jcp, const primitive_attr_t &attr) -{ - using namespace primitive_kind; - const auto &p = attr.post_ops_; - - auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; - - switch (p.len_) { - case 0: return true; - case 1: return is_eltwise(0) || p.contain(sum, 0); - case 2: return (p.contain(sum, 0) && is_eltwise(1)) || - (p.contain(sum, 1) && is_eltwise(0)); - default: return false; - } - - return false; -} - -status_t jit_avx512_core_x8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp, - const convolution_desc_t &cd, memory_desc_t &src_md, - memory_desc_t &weights_md, memory_desc_t &dst_md, - memory_desc_t &bias_md, const primitive_attr_t &attr, - int nthreads) -{ - using namespace prop_kind; - - const memory_desc_wrapper src_d(&src_md); - const memory_desc_wrapper weights_d(&weights_md); - const memory_desc_wrapper dst_d(&dst_md); - const memory_desc_wrapper bias_d(&bias_md); - - const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; - int ndims = src_d.ndims(); - bool is_1d = ndims == 3; - - if (!(mayiuse(avx512_core) - && one_of(src_d.data_type(), data_type::u8, data_type::s8) - && weights_d.data_type() == data_type::s8 - && one_of(dst_d.data_type(), data_type::f32, data_type::s32, - data_type::s8, data_type::u8))) - return status::unimplemented; - - jcp = zero(); - jcp.ndims = ndims; - jcp.prop_kind = cd.prop_kind; - jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; - jcp.mb = src_d.dims()[0]; - jcp.oc = dst_d.dims()[1] / jcp.ngroups; - jcp.oc_without_padding = jcp.oc; - jcp.ic = src_d.dims()[1] / jcp.ngroups; - jcp.ic_without_padding = jcp.ic; - jcp.ih = is_1d ? 1 : src_d.dims()[ndims - 2]; - jcp.iw = src_d.dims()[ndims - 1]; - jcp.oh = is_1d ? 1 : dst_d.dims()[ndims - 2]; - jcp.ow = dst_d.dims()[ndims - 1]; - jcp.kh = is_1d ? 1 : weights_d.dims()[with_groups + ndims - 2]; - jcp.kw = weights_d.dims()[with_groups + ndims - 1]; - jcp.t_pad = is_1d ? 0 : cd.padding[0][ndims - 4]; - jcp.l_pad = cd.padding[0][ndims - 3]; - jcp.stride_h = is_1d ? 1 : cd.strides[ndims - 4]; - jcp.stride_w = cd.strides[ndims - 3]; - jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; - - jcp.ur_h = 1; /* no code-unrolling by h so far */ - - jcp.dilate_h = is_1d ? 0 : cd.dilates[ndims - 4]; - jcp.dilate_w = cd.dilates[ndims - 3]; - - jcp.signed_input = (src_d.data_type() == data_type::s8) ? true : false; - jcp.is_depthwise = true && with_groups && everyone_is(1, jcp.ic, jcp.oc); - - if (jcp.is_depthwise) { - jcp.ch_block = 16; - jcp.ic_block = 1; - jcp.oc_block = 1; - } else { - jcp.ch_block = 1; - jcp.ic_block = 16; - jcp.oc_block = 16; - - if (jcp.ngroups == 1) { - /* For non grouped convolutions, pad channels by 16 if needed */ - jcp.oc = rnd_up(jcp.oc, jcp.oc_block); - jcp.ic = rnd_up(jcp.ic, jcp.ic_block); - } else if (!is_1d && jcp.ngroups != 1 && jcp.ic % jcp.ic_block != 0) { - /* For grouped convolutions, MKL-DNN doesn't support padding. 
- Use Ymm when channels per group is multiple of 8, - Xmm when channels per group is multiple of 4 */ - jcp.ic_block = jcp.ic % 8 == 0 ? 8 : 4; - jcp.oc_block = jcp.ic_block; - } - if (jcp.ic % jcp.ic_block !=0 || jcp.oc % jcp.oc_block != 0) - return status::unimplemented; - } - - jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1) - - (jcp.ih + jcp.t_pad - 1); - - if (!post_ops_ok(jcp, attr)) - return status::unimplemented; - - const auto &p = attr.post_ops_; - const int eltwise_ind = p.find(primitive_kind::eltwise); - jcp.with_eltwise = eltwise_ind != -1; - if (jcp.with_eltwise) - jcp.eltwise = p.entry_[eltwise_ind].eltwise; - - jcp.ver = mayiuse(avx512_core_vnni) ? ver_vnni : ver_avx512_core; - jcp.is_fast_depthwise = true && jcp.is_depthwise && jcp.ver == ver_vnni - && jcp.ngroups % jcp.ch_block == 0; // for groups not multiple of 16 - // would require byte masking - // for load from src - jcp.is_resrc_depthwise = jcp.is_depthwise && jcp.stride_w < jcp.kw - && jcp.kw < 4 && jcp.dilate_w == 0; - if (jcp.is_depthwise) { - jcp.max_regs_ur = 31 - jcp.is_fast_depthwise - !jcp.is_resrc_depthwise - - 2 * jcp.signed_input - (jcp.ver != ver_vnni); - } else { - jcp.max_regs_ur = jcp.ver == ver_vnni ? 31 : 28; - } - - auto set_or_check_wei_format = [&]() { - using namespace format_tag; - format_tag_t wei_tag; - if (jcp.ic_block == 16 || jcp.ch_block == 16) { - if (is_1d) { - wei_tag = with_groups - ? jcp.is_depthwise ? Goiw16g : gOIw4i16o4i - : OIw4i16o4i; - } else { - wei_tag = with_groups - ? jcp.is_depthwise ? Goihw16g : gOIhw4i16o4i - : OIhw4i16o4i; - } - } else if (with_groups && jcp.ic_block == 8) { - wei_tag = gOIhw2i8o4i; - } else - wei_tag = gOIhw4o4i; - - memory_desc_t want_wei_md = weights_md; - memory_desc_init_by_tag(want_wei_md, wei_tag); - if (jcp.signed_input) { - want_wei_md.extra.flags = 0 - | memory_extra_flags::compensation_conv_s8s8 - | memory_extra_flags::scale_adjust; - want_wei_md.extra.compensation_mask = (1 << 0) - + (with_groups && !jcp.is_depthwise ? (1 << 1) : 0); - want_wei_md.extra.scale_adjust = - mayiuse(avx512_core_vnni) ? 1.f : 0.5f; - } - - if (weights_md.format_kind == format_kind::any) { - weights_md = want_wei_md; - return true; - } - - return weights_md == want_wei_md; - }; - - if (!set_or_check_wei_format()) - return status::unimplemented; - - format_tag_t dat_tag = utils::pick(ndims - 3, - format_tag::nwc, format_tag::nhwc); - - if (src_d.format_kind() == format_kind::any) { - CHECK(memory_desc_init_by_tag(src_md, dat_tag)); - jcp.src_tag = dat_tag; - } else { - jcp.src_tag = src_d.matches_one_of_tag(dat_tag); - } - if (jcp.src_tag != dat_tag) - return status::unimplemented; - - if (dst_d.format_kind() == format_kind::any) { - CHECK(memory_desc_init_by_tag(dst_md, dat_tag)); - jcp.dst_tag = dat_tag; - } else { - jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); - } - if (jcp.dst_tag != dat_tag) - return status::unimplemented; - - if (jcp.with_bias) { - if (bias_d.format_kind() == format_kind::any) - CHECK(memory_desc_init_by_tag(bias_md, format_tag::x)); - } - - jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef; - jcp.dst_dt = cd.dst_desc.data_type; - - jcp.typesize_in = types::data_type_size(src_d.data_type()); - jcp.typesize_out = types::data_type_size(dst_d.data_type()); - jcp.typesize_bia = jcp.with_bias - ? 
types::data_type_size(bias_d.data_type()) - : 0; - - jcp.nb_ch = div_up(jcp.ngroups, jcp.ch_block); - jcp.nb_ic = jcp.ic / jcp.ic_block; - jcp.nb_oc = jcp.oc / jcp.oc_block; - - // Try to use 4 channel-groups at a time to avoid false sharing (depthwise) - int nb_ch_blocking = 4; - for ( /* init above */ ; nb_ch_blocking > 1; nb_ch_blocking--) - if (jcp.nb_ch % nb_ch_blocking == 0) - break; - jcp.nb_ch_blocking = jcp.is_depthwise ? nb_ch_blocking : 1; - - // If OC blocking is incommensurate with the number of OC blocks (general - // requirement for all convolutions), or if it results in an unrolling - // factor smaller than the left padding (special requirement for SSD:fc6), - // then search for a smaller OC blocking that satisfies both constraints. - auto is_oc_blocking_ok = [&](int block) { - int ur_w = nstl::min(jcp.ow, jcp.max_regs_ur / (block + 1)); - return jcp.nb_oc % block == 0 - && jcp.l_pad <= ur_w && jcp.ow % ur_w != 1; - }; - - // choose nb_oc work chunk size for distribution within threads - int max_threading_nb_oc_chunk = 4; - // Performance improvements for googlenet_v3 and resnet_50 with mb = 1; - // TODO: generalize this condition and rewrite it in appropriate manner - if (jcp.ver == ver_vnni && jcp.mb == 1 && jcp.kh == 3 && jcp.kw == 3 - && jcp.stride_w == 1 && jcp.ic % 64 == 0) - max_threading_nb_oc_chunk = 2; - jcp.nb_oc_blocking_thr_chunk = - nstl::min(max_threading_nb_oc_chunk, jcp.nb_oc); - for (; jcp.nb_oc_blocking_thr_chunk > 1; jcp.nb_oc_blocking_thr_chunk--) { - if (is_oc_blocking_ok(jcp.nb_oc_blocking_thr_chunk)) - break; - } - - // choose oc blocking for computational kernel - jcp.nb_oc_blocking = jcp.nb_oc_blocking_thr_chunk; - // Performance improvements for googlenet_v3 with mb = 1; - // TODO: generalize this condition and rewrite it in appropriate manner - const int size_treshold_for_nb_oc_blocking_reduction = 17; - if (jcp.mb == 1 && jcp.ow <= size_treshold_for_nb_oc_blocking_reduction - && jcp.stride_w == 1 - && !(jcp.kh == 1 && jcp.kw == 3) - && !(jcp.kh >= 7 && jcp.oc % 64 == 0)) { - const int max_nb_oc_blocking = 2; - jcp.nb_oc_blocking = nstl::min(max_nb_oc_blocking, jcp.nb_oc); - for (; jcp.nb_oc_blocking > 1; jcp.nb_oc_blocking--) - if (jcp.nb_oc_blocking_thr_chunk % jcp.nb_oc_blocking == 0 - && is_oc_blocking_ok(jcp.nb_oc_blocking)) - break; - } - - if (jcp.is_resrc_depthwise) - jcp.ur_w = (jcp.max_regs_ur - jcp.kw + jcp.stride_w) - / (jcp.nb_ch_blocking + jcp.stride_w); - else - jcp.ur_w - = jcp.max_regs_ur / (jcp.is_depthwise ? 
jcp.nb_ch_blocking - : jcp.nb_oc_blocking + 1); - if (jcp.ow < jcp.ur_w) - jcp.ur_w = jcp.ow; - jcp.ur_w_tail = jcp.ow % jcp.ur_w; - - jcp.ow_block = jcp.ow; - int base_work_amount = jcp.mb * jcp.nb_ch * jcp.oh - * (jcp.nb_oc / jcp.nb_oc_blocking_thr_chunk); - float best_thr_eff - = (float)base_work_amount / rnd_up(base_work_amount, nthreads); - int max_nb_ow = div_up(jcp.ow, 2 * jcp.ur_w); - for (int nb_ow = 1; nb_ow <= max_nb_ow; nb_ow++) { - int ow_block - = nstl::min(rnd_up(div_up(jcp.ow, nb_ow), jcp.ur_w), jcp.ow); - if (ow_block < jcp.nb_oc_blocking_thr_chunk * jcp.oc_block - && best_thr_eff > 0.8f) - break; - if (div_up(jcp.ow, ow_block) != nb_ow) - continue; - auto work_amount = base_work_amount * nb_ow; - float thr_eff = (float)work_amount / rnd_up(work_amount, nthreads); - if (ow_block >= 2 * jcp.ur_w && thr_eff > 1.1f * best_thr_eff) { - jcp.ow_block = ow_block; - best_thr_eff = thr_eff; - } - if (best_thr_eff > 0.9f) - break; - } - jcp.nb_ow = div_up(jcp.ow, jcp.ow_block); - - bool args_ok = true - && jcp.oc % jcp.oc_block == 0 - && jcp.l_pad <= jcp.ur_w - && IMPLICATION(!jcp.is_1stconv, jcp.ic % jcp.ic_block == 0); - if (!args_ok) - return status::unimplemented; - - int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w - + (jcp.kw - 1) * (jcp.dilate_w + 1) - - (jcp.iw + jcp.l_pad - 1)); - if (r_pad_no_tail > jcp.ur_w) - return status::unimplemented; - - pick_loop_order(jcp, nthreads); - - jcp.nb_ic_L2 = jcp.nb_ic; - - const auto &oscales = attr.output_scales_; - jcp.is_oc_scale = oscales.mask_ == 1 << 1; - - assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0)); - - jcp.wei_adj_scale = - (weights_d.extra().flags | memory_extra_flags::scale_adjust) - ? weights_d.extra().scale_adjust : 1.f; - - return status::success; -} - -void jit_avx512_core_x8s8s32x_fwd_kernel::init_scratchpad( - memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp, - const primitive_attr_t &attr) { - if (jcp.signed_input && jcp.ver != ver_vnni) { - dim_t count = nstl::max(attr.output_scales_.count_, (dim_t)jcp.ic_block); - scratchpad.book(key_conv_adjusted_scales, sizeof(float) * count); - } -} - -template struct _jit_avx512_core_x8s8s32x_fwd_kernel; -template struct _jit_avx512_core_x8s8s32x_fwd_kernel; -template struct _jit_avx512_core_x8s8s32x_fwd_kernel; -} -} -} - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.hpp deleted file mode 100644 index d8a05ad53..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.hpp +++ /dev/null @@ -1,239 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#ifndef CPU_JIT_AVX512_CORE_X8S8S32X_CONV_KERNEL_HPP -#define CPU_JIT_AVX512_CORE_X8S8S32X_CONV_KERNEL_HPP - -#include "c_types_map.hpp" -#include "memory_tracking.hpp" - -#include "jit_generator.hpp" -#include "jit_primitive_conf.hpp" -#include "jit_uni_eltwise.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -template -struct _jit_avx512_core_x8s8s32x_fwd_kernel : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_core_x8s8s32x_conv_fwd_ker_t) - - enum { STATE_FIRST_DST_LOAD = 0x1U }; - - _jit_avx512_core_x8s8s32x_fwd_kernel(jit_conv_conf_t ajcp, - const primitive_attr_t &attr) : jcp(ajcp), attr_(attr), - eltwise_injector_(nullptr) - { - if (jcp.with_eltwise) - eltwise_injector_ = new jit_uni_eltwise_injector_f32( - this, jcp.eltwise); - - generate(); - jit_ker_ = (void (*)(jit_conv_call_s *))getCode(); - } - - ~_jit_avx512_core_x8s8s32x_fwd_kernel() { - delete eltwise_injector_; - } - - jit_conv_conf_t jcp; - const primitive_attr_t &attr_; - void (*jit_ker_)(jit_conv_call_s *); - -private: - jit_uni_eltwise_injector_f32 *eltwise_injector_; - - enum { - typesize = sizeof(float), - ker_reg_base_idx = 28, - ker_dw_reg_base_idx = 30, - }; - typedef enum { - no_last_block, - last_ic_block, - last_sp_block, - } ic_block_t; - - /* data regs */ - const Xbyak::Reg64 reg_ptr_scales = rax; - const Xbyak::Reg64 reg_inp = r8; - const Xbyak::Reg64 reg_ker = r9; - const Xbyak::Reg64 reg_out = r10; - const Xbyak::Reg64 aux_reg_inp = r11; - const Xbyak::Reg64 reg_ptr_sum_scale = r11; - const Xbyak::Reg64 aux_reg_ker = r12; - const Xbyak::Reg64 reg_compensation = r14; - /* counter regs */ - const Xbyak::Reg64 reg_bias_alpha = abi_not_param1; - const Xbyak::Reg64 reg_oi = rbx; - const Xbyak::Reg64 reg_bias = rdx; - const Xbyak::Reg64 reg_oc_blocks = rsi; - const Xbyak::Reg64 reg_owb = aux_reg_ker; - const Xbyak::Reg64 reg_scratch = reg_compensation; - const Xbyak::Reg64 reg_kj = reg_ptr_scales; - const Xbyak::Reg64 reg_overflow = reg_ptr_scales; - const Xbyak::Reg64 reg_icb = reg_bias; - - const Xbyak::Opmask ktail_mask = Xbyak::Opmask(2); - const Xbyak::Opmask kblend_mask = Xbyak::Opmask(3); - - const Vmm vmm_wei = Vmm(31); - /* used during bias section of store_output */ - const Vmm vmm_comp = Vmm(30); // only for signed input - const Vmm vmm_bias = Vmm(31); - /* used during post_op sum section of store_output */ - const Vmm vmm_prev_dst = Vmm(31); - /* used during write-out section of store_output */ - const Vmm vmm_zero = Vmm(31); - - /* used in compute_ker (but set during prepare_output) */ - const Vmm vmm_shift = vmm_comp; // only for signed input - /* used in compute_ker (but only for pre-VNNI machines) */ - const Vmm vmm_tmp = Vmm(28); // not used for depthwise - const Vmm vmm_one = Vmm(29); // set at start of kernel, not used for depthwise. - - /* registers use only for depthwise - groups are always blocked by 16(padded if needed), - hence use only Zmm registers */ - const Xbyak::Zmm zmm_wei = Xbyak::Zmm(31); - Xbyak::Zmm zmm_tmp; - Xbyak::Zmm zmm_src; - Xbyak::Zmm zmm_shifted_zero; - Xbyak::Zmm zmm_permute; - - Vmm vmm_out(int i_ur, int i_oc) { - int idx = i_ur + i_oc * jcp.ur_w; - assert(idx < (jcp.is_depthwise - ? ker_dw_reg_base_idx : ker_reg_base_idx)); - return Vmm(idx); - } - Xbyak::Zmm zmm_out(int i_ur, int i_oc) { - int idx = i_ur + i_oc * jcp.ur_w; - assert(idx < (jcp.is_depthwise - ? 
ker_dw_reg_base_idx : ker_reg_base_idx)); - return Xbyak::Zmm(idx); - } - Vmm vmm_inp(int i_ic, int nb_x_blocking) { - int idx = i_ic + nb_x_blocking * jcp.ur_w; - assert(idx < 31); - return Vmm(idx); - } - Xbyak::Zmm zmm_inp(int i_ic, int nb_x_blocking) { - int idx = i_ic + nb_x_blocking * jcp.ur_w; - assert(idx < 31); - return Xbyak::Zmm(idx); - } - Vmm vmm_bias_alpha() { - int nb_c_block = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking; - return Vmm(nb_c_block * jcp.ur_w); - } - Xbyak::Xmm xmm_bias_alpha() { - int nb_c_block = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking; - return Xbyak::Xmm(nb_c_block * jcp.ur_w); - } - int get_ow_start(int ki, int pad_l) { - return nstl::max(0, - utils::div_up(pad_l - ki * (jcp.dilate_w + 1), jcp.stride_w)); - } - int get_ow_end(int ur_w, int ki, int pad_r) { - return ur_w - nstl::max(0, utils::div_up(pad_r - - (jcp.kw - 1 - ki) - * (jcp.dilate_w + 1), - jcp.stride_w)); - } - - bool maybe_eltwise(int position); - void prepare_output(int ur_w); - void store_output(int ur_w, bool last_oc_block_flag); - void compute_ker_dw( - int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag, bool h_padded); - void compute_ker(int ur_w, int pad_l, int pad_r, - ic_block_t last_ic_block_flag, bool h_padded = false); - void compute_eltwise(int ur_w); - void kh_loop(int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag); - void icb_loop( - int ur_w, int pad_l, int pad_r, bool is_last_spatial_block); - void generate(); - void cvt2ps(data_type_t type_in, Vmm ymm_in, const Xbyak::Operand &op, - bool mask_flag); - const Vmm vmm_mask(const Vmm vmm_in, bool mask_flag, bool store = false); -}; - -struct jit_avx512_core_x8s8s32x_fwd_kernel { - - jit_avx512_core_x8s8s32x_fwd_kernel(jit_conv_conf_t ajcp, - const primitive_attr_t &attr) : - jit_ker(nullptr), - zmm_kernel_(nullptr), - ymm_kernel_(nullptr), - xmm_kernel_(nullptr) { - int ch_block = ajcp.is_depthwise ? 
ajcp.ch_block : ajcp.ic_block; - switch (ch_block) { - case 16: - zmm_kernel_ = - new _jit_avx512_core_x8s8s32x_fwd_kernel( - ajcp, attr); - jit_ker = zmm_kernel_->jit_ker_; - return; - case 8: - ymm_kernel_ = - new _jit_avx512_core_x8s8s32x_fwd_kernel( - ajcp, attr); - jit_ker = ymm_kernel_->jit_ker_; - return; - case 4: - xmm_kernel_ = - new _jit_avx512_core_x8s8s32x_fwd_kernel( - ajcp, attr); - jit_ker = xmm_kernel_->jit_ker_; - return; - default: - assert(!"invalid channel blocking"); - } - } - - ~jit_avx512_core_x8s8s32x_fwd_kernel() { - delete xmm_kernel_; - delete ymm_kernel_; - delete zmm_kernel_; - } - - static bool post_ops_ok(jit_conv_conf_t &jcp, - const primitive_attr_t &attr); - - static status_t init_conf(jit_conv_conf_t &jcp, - const convolution_desc_t &cd, - memory_desc_t &src_pd, - memory_desc_t &weights_pd, - memory_desc_t &dst_pd, - memory_desc_t &bias_pd, - const primitive_attr_t &attr, - int nthreads); - static void init_scratchpad(memory_tracking::registrar_t &scratchpad, - const jit_conv_conf_t &jcp, const primitive_attr_t &attr); - - void (*jit_ker)(jit_conv_call_s *); - _jit_avx512_core_x8s8s32x_fwd_kernel *zmm_kernel_; - _jit_avx512_core_x8s8s32x_fwd_kernel *ymm_kernel_; - _jit_avx512_core_x8s8s32x_fwd_kernel *xmm_kernel_; -}; - -} -} -} - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_convolution.cpp deleted file mode 100644 index cdbf333d5..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_convolution.cpp +++ /dev/null @@ -1,423 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include "c_types_map.hpp" -#include "mkldnn_thread.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -#include "jit_avx512_core_x8s8s32x_convolution.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace mkldnn::impl::status; -using namespace mkldnn::impl::memory_tracking::names; -using namespace mkldnn::impl::utils; - -using namespace nstl; - -using jit_conv_ker_t = void (*)(jit_conv_call_s *); - -#define wht_blk_off(d, g, ...) \ - (pd()->with_groups() \ - ? 
(d).blk_off((g), __VA_ARGS__) \ - : (d).blk_off(__VA_ARGS__)) - -template -void jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_1d(const exec_ctx_t &ctx) const { - auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); - auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); - auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); - auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); - - const memory_desc_wrapper src_d(pd()->src_md()); - const memory_desc_wrapper dst_d(pd()->dst_md()); - const memory_desc_wrapper weights_d(pd()->weights_md(0)); - const memory_desc_wrapper bias_d(pd()->weights_md(1)); - - const size_t bia_dt_size = pd()->with_bias() - ? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0; - - const auto &jcp = pd()->jcp_; - assert(jcp.nb_oc % jcp.nb_oc_blocking == 0); - assert(jcp.nb_ch % jcp.nb_ch_blocking == 0); - - const float *oscales = pd()->attr()->output_scales_.scales_; - if (jcp.signed_input && jcp.ver != ver_vnni) { - auto local_scales = scratchpad(ctx).template get( - key_conv_adjusted_scales); - size_t count = pd()->attr()->output_scales_.count_; - float factor = 1.f / pd()->jcp_.wei_adj_scale; - if (count == 1) { - utils::array_set(local_scales, oscales[0] * factor, 16); - } else { - for (size_t c = 0; c < count; c++) - local_scales[c] = oscales[c] * factor; - } - oscales = local_scales; - } - - size_t offset = weights_d.size() - weights_d.additional_buffer_size(); - auto w = const_cast(weights); - int32_t* compensation = (jcp.signed_input) - ? reinterpret_cast(&w[offset]) : 0; - int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; - int nb_groups = jcp.nb_ch / jcp.nb_ch_blocking; - int group_block = jcp.ch_block; - int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.nb_ow; - - parallel(0, [&](const int ithr, const int nthr) { - - int start{ 0 }, end{ 0 }; - balance211(work_amount, nthr, ithr, start, end); - - auto p = jit_conv_call_s(); - - int n{ 0 }, gg{ 0 }, occ{ 0 }, owb{ 0 }; - switch (jcp.loop_order) { - case loop_cwgn: - nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, gg, - nb_groups, n, jcp.mb); - break; - case loop_gncw: - nd_iterator_init(start, gg, nb_groups, n, jcp.mb, occ, oc_chunks, - owb, jcp.nb_ow); - break; - case loop_ngcw: - nd_iterator_init(start, n, jcp.mb, gg, nb_groups, occ, oc_chunks, - owb, jcp.nb_ow); - break; - case loop_nwcg: - nd_iterator_init(start, n, jcp.mb, owb, jcp.nb_ow, occ, oc_chunks, - gg, nb_groups); - break; - default: assert(!"unsupported loop order"); - } - while (start < end) { - int ocb = occ * jcp.nb_oc_blocking; - int gb = gg * jcp.nb_ch_blocking; - int g = gb * group_block; - int g_oc = (g * jcp.nb_oc + ocb) * jcp.oc_block; - int g_ic = g * jcp.nb_ic * jcp.ic_block; - int ow_s = owb * jcp.ow_block; - int iw_s = ow_s * jcp.stride_w; - - p.bias = bias ? bias + (bias_d.blk_off(g_oc) * bia_dt_size) : 0; - p.compensation = (jcp.signed_input) ? compensation + g_oc : 0; - p.dst = dst + dst_d.blk_off(n, g_oc, ow_s); - p.src = src + src_d.blk_off(n, g_ic, iw_s); - p.filt = weights + wht_blk_off(weights_d, gb, ocb, 0); - p.scales = &oscales[jcp.is_oc_scale * g_oc]; - p.oc_blocks = jcp.is_depthwise ? 
gb : ocb; - p.kh_padding = jcp.kh; - p.t_overflow = 0; - p.b_overflow = 0; - p.owb = owb; - - kernel_->jit_ker(&p); - - ++start; - switch (jcp.loop_order) { - case loop_cwgn: - nd_iterator_step(occ, oc_chunks, owb, jcp.nb_ow, gg, nb_groups, - n, jcp.mb); - break; - case loop_gncw: - nd_iterator_step(gg, nb_groups, n, jcp.mb, occ, oc_chunks, owb, - jcp.nb_ow); - break; - case loop_ngcw: - nd_iterator_step(n, jcp.mb, gg, nb_groups, occ, oc_chunks, owb, - jcp.nb_ow); - break; - case loop_nwcg: - nd_iterator_step(n, jcp.mb, owb, jcp.nb_ow, occ, oc_chunks, gg, - nb_groups); - break; - default: assert(!"unsupported loop order"); - } - } - }); -} - -template -void jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_2d(const exec_ctx_t &ctx) const { - auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); - auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); - auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); - auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); - - const memory_desc_wrapper src_d(pd()->src_md()); - const memory_desc_wrapper dst_d(pd()->dst_md()); - const memory_desc_wrapper weights_d(pd()->weights_md(0)); - const memory_desc_wrapper bias_d(pd()->weights_md(1)); - - const size_t bia_dt_size = pd()->with_bias() - ? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0; - - const auto &jcp = pd()->jcp_; - assert(jcp.ch_block == 1); - assert(jcp.nb_ch_blocking == 1); - assert(jcp.nb_oc % jcp.nb_oc_blocking == 0); - assert(jcp.nb_ch % jcp.nb_ch_blocking == 0); - - const float *oscales = pd()->attr()->output_scales_.scales_; - if (jcp.signed_input && jcp.ver != ver_vnni) { - auto local_scales = scratchpad(ctx).template get( - key_conv_adjusted_scales); - size_t count = pd()->attr()->output_scales_.count_; - float factor = 1.f / pd()->jcp_.wei_adj_scale; - if (count == 1) { - utils::array_set(local_scales, oscales[0] * factor, 16); - } else { - for (size_t c = 0; c < count; c++) - local_scales[c] = oscales[c] * factor; - } - oscales = local_scales; - } - - size_t offset = weights_d.size() - weights_d.additional_buffer_size(); - auto w = const_cast(weights); - int32_t* compensation = (jcp.signed_input) - ? 
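/* Editor's note: a minimal sketch, not part of the deleted sources, of the
 * output-scale adjustment that each execute_forward_* above repeats for
 * signed-input kernels when the CPU lacks VNNI: every user-supplied scale is
 * multiplied by 1 / jcp.wei_adj_scale (the weights appear to be pre-scaled by
 * 0.5 at reorder time), and a single common scale is broadcast to 16 lanes so
 * the kernel always reads one value per output lane. Names are illustrative. */
#include <vector>

static std::vector<float> adjust_output_scales(const float *oscales, size_t count,
        float wei_adj_scale) {
    const float factor = 1.f / wei_adj_scale;
    if (count == 1)
        return std::vector<float>(16, oscales[0] * factor); // broadcast common scale
    std::vector<float> local(count);
    for (size_t c = 0; c < count; ++c)
        local[c] = oscales[c] * factor;
    return local;
}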
reinterpret_cast(&w[offset]) : 0; - int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking_thr_chunk; - int nb_groups = jcp.nb_ch; - int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.oh * jcp.nb_ow; - - parallel(0, [&](const int ithr, const int nthr) { - - int start{0}, end{0}; - balance211(work_amount, nthr, ithr, start, end); - - auto p = jit_conv_call_s(); - - size_t src_h_stride = src_d.blk_off(0, 0, 1); - size_t dst_h_stride = dst_d.blk_off(0, 0, 1); - size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1); - - int n{ 0 }, g{ 0 }, occ{ 0 }, oh_s{ 0 }, owb{ 0 }; - switch (jcp.loop_order) { - case loop_cwgn: - nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, g, - nb_groups, n, jcp.mb, oh_s, jcp.oh); - break; - case loop_ngcw: - nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks, - owb, jcp.nb_ow, oh_s, jcp.oh); - break; - case loop_nhwcg: - nd_iterator_init(start, n, jcp.mb, oh_s, jcp.oh, owb, jcp.nb_ow, - occ, oc_chunks, g, nb_groups); - break; - default: assert(!"unsupported loop order"); - } - while (start < end) { - for (int occ1 = 0; occ1 < jcp.nb_oc_blocking_thr_chunk; - occ1 += jcp.nb_oc_blocking) { - int ocb = occ * jcp.nb_oc_blocking_thr_chunk + occ1; - int g_oc = (g * jcp.nb_oc + ocb) * jcp.oc_block; - - int g_ic = g * jcp.nb_ic * jcp.ic_block; - - int work_rem = end - start; - int ih_s = -jcp.t_pad + oh_s * jcp.stride_h; - int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem; - if (jcp.loop_order == loop_nhwcg) - oh_e = oh_s + 1; // step instead - int ow_s = owb * jcp.ow_block; - int iw_s = ow_s * jcp.stride_w; - - auto bias_w = bias - ? bias + (bias_d.blk_off(g_oc) * bia_dt_size) - : 0; - int32_t *compensation_w = (jcp.signed_input) - ? compensation + g_oc : 0; - - auto dst_w = dst + dst_d.blk_off(n, g_oc, oh_s, ow_s); - auto src_w = src + src_d.blk_off(n, g_ic, ih_s, iw_s); - auto wht_w = weights + wht_blk_off(weights_d, g, ocb, 0); - - auto scales = &oscales[jcp.is_oc_scale * g_oc]; - - for (int oj = oh_s, ij = ih_s; oj < oh_e; - ++oj, ij += jcp.stride_h) { - int dilate_h = jcp.dilate_h + 1; - int i_t_overflow = nstl::min(jcp.kh, - div_up(max(0, -ij), dilate_h)); - int i_b_overflow = nstl::min(jcp.kh, div_up( - max(0, ij - jcp.ih + (jcp.kh - 1) * dilate_h + 1), - dilate_h)); - int kh_padding = nstl::max(0, - jcp.kh - i_t_overflow - i_b_overflow); - - size_t wei_stride = (!jcp.signed_input) - ? 
i_t_overflow * wht_h_stride : 0; - p.src = src_w + i_t_overflow * dilate_h * src_h_stride; - p.dst = dst_w; - p.filt = wht_w + wei_stride; - p.bias = bias_w; - p.compensation = compensation_w; - p.oc_blocks = ocb; - p.kh_padding = kh_padding; - p.scales = scales; - p.t_overflow = i_t_overflow; - p.b_overflow = i_b_overflow; - p.owb = owb; - - kernel_->jit_ker(&p); - src_w += src_h_stride * jcp.stride_h; - dst_w += dst_h_stride; - } - } - switch (jcp.loop_order) { - case loop_cwgn: - nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow, g, - nb_groups, n, jcp.mb, oh_s, jcp.oh); - break; - case loop_ngcw: - nd_iterator_jump(start, end, n, jcp.mb, g, nb_groups, occ, - oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh); - break; - case loop_nhwcg: - ++start; - nd_iterator_step(n, jcp.mb, oh_s, jcp.oh, owb, jcp.nb_ow, occ, - oc_chunks, g, nb_groups); - break; - default: assert(!"unsupported loop order"); - } - } - }); -} - -template -void jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_2d_dw(const exec_ctx_t &ctx) const { - auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); - auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); - auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); - auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); - - const memory_desc_wrapper src_d(pd()->src_md()); - const memory_desc_wrapper dst_d(pd()->dst_md()); - const memory_desc_wrapper weights_d(pd()->weights_md(0)); - const memory_desc_wrapper bias_d(pd()->weights_md(1)); - - const size_t bia_dt_size = pd()->with_bias() - ? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0; - - const auto &jcp = pd()->jcp_; - assert(jcp.ic_block == 1); - assert(jcp.oc_block == 1); - assert(jcp.nb_ic == 1); - assert(jcp.nb_oc == 1); - assert(jcp.nb_oc_blocking == 1); - assert(jcp.nb_ch % jcp.nb_ch_blocking == 0); - - const float *oscales = pd()->attr()->output_scales_.scales_; - if (jcp.signed_input && jcp.ver != ver_vnni) { - auto local_scales = scratchpad(ctx).template get( - key_conv_adjusted_scales); - size_t count = pd()->attr()->output_scales_.count_; - float factor = 1.f / pd()->jcp_.wei_adj_scale; - if (count == 1) { - utils::array_set(local_scales, oscales[0] * factor, 16); - } else { - for (size_t c = 0; c < count; c++) - local_scales[c] = oscales[c] * factor; - } - oscales = local_scales; - } - - size_t offset = weights_d.size() - weights_d.additional_buffer_size(); - auto w = const_cast(weights); - int32_t* compensation = (jcp.signed_input) - ? reinterpret_cast(&w[offset]) : 0; - int nb_groups = jcp.nb_ch / jcp.nb_ch_blocking; - int group_block = jcp.ch_block; - - parallel_nd(jcp.mb, jcp.oh, jcp.nb_ow, nb_groups, - [&](int n, int oh_s, int owb, int gg) { - - auto p = jit_conv_call_s(); - - size_t src_h_stride = src_d.blk_off(0, 0, 1); - size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1); - - int gb = gg * jcp.nb_ch_blocking; - int g = gb * group_block; - - int ih_s = -jcp.t_pad + oh_s * jcp.stride_h; - int ow_s = owb * jcp.ow_block; - int iw_s = ow_s * jcp.stride_w; - - auto bias_w = bias ? bias + (bias_d.blk_off(g) * bia_dt_size) : 0; - int32_t *compensation_w = jcp.signed_input ? 
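/* Editor's note: illustrative sketch, not part of the deleted sources. For each
 * output row oj the loops above clip the kernel against the top/bottom of the
 * input: ij = oj * stride_h - t_pad is the first input row touched,
 * i_t_overflow / i_b_overflow count the (dilated) kernel rows falling outside
 * the input, and kh_padding is what the JIT kernel actually accumulates. */
#include <algorithm>

static inline int div_up_i(int a, int b) { return (a + b - 1) / b; }

struct kh_clip { int t_overflow, b_overflow, kh_padding; };

static kh_clip clip_kernel_rows(int ij, int ih, int kh, int dilate_h) {
    const int dil = dilate_h + 1; // matches "dilate_h = jcp.dilate_h + 1" above
    const int t_ov = std::min(kh, div_up_i(std::max(0, -ij), dil));
    const int b_ov = std::min(kh,
            div_up_i(std::max(0, ij - ih + (kh - 1) * dil + 1), dil));
    return { t_ov, b_ov, std::max(0, kh - t_ov - b_ov) };
}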
compensation + g : 0; - - auto dst_w = dst + dst_d.blk_off(n, g, oh_s, ow_s); - auto src_w = src + src_d.blk_off(n, g, ih_s, iw_s); - auto wht_w = weights + wht_blk_off(weights_d, gb, 0); - - auto scales = &oscales[jcp.is_oc_scale * g]; - - int dilate_h = jcp.dilate_h + 1; - int i_t_overflow = nstl::min(jcp.kh, div_up(max(0, -ih_s), dilate_h)); - int i_b_overflow = nstl::min(jcp.kh, - div_up(max(0, ih_s - jcp.ih + (jcp.kh - 1) * dilate_h + 1), - dilate_h)); - int kh_padding = nstl::max(0, jcp.kh - i_t_overflow - i_b_overflow); - - size_t wei_stride = jcp.signed_input ? 0 : i_t_overflow * wht_h_stride; - p.src = src_w + i_t_overflow * dilate_h * src_h_stride; - p.dst = dst_w; - p.filt = wht_w + wei_stride; - p.bias = bias_w; - p.compensation = compensation_w; - p.oc_blocks = gb; - p.kh_padding = kh_padding; - p.scales = scales; - p.t_overflow = i_t_overflow; - p.b_overflow = i_b_overflow; - p.owb = owb; - - kernel_->jit_ker(&p); - }); -} - -template struct jit_avx512_core_x8s8s32x_convolution_fwd_t< - data_type::s8, data_type::u8>; -template struct jit_avx512_core_x8s8s32x_convolution_fwd_t< - data_type::u8, data_type::u8>; -template struct jit_avx512_core_x8s8s32x_convolution_fwd_t< - data_type::s8, data_type::s8>; -template struct jit_avx512_core_x8s8s32x_convolution_fwd_t< - data_type::u8, data_type::s8>; -template struct jit_avx512_core_x8s8s32x_convolution_fwd_t< - data_type::s8, data_type::s32>; -template struct jit_avx512_core_x8s8s32x_convolution_fwd_t< - data_type::u8, data_type::s32>; -template struct jit_avx512_core_x8s8s32x_convolution_fwd_t< - data_type::s8, data_type::f32>; -template struct jit_avx512_core_x8s8s32x_convolution_fwd_t< - data_type::u8, data_type::f32>; -} -} -} - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_convolution.hpp deleted file mode 100644 index 203ebdf94..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_convolution.hpp +++ /dev/null @@ -1,115 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#ifndef CPU_JIT_AVX512_CORE_X8S8S32X_CONVOLUTION_HPP -#define CPU_JIT_AVX512_CORE_X8S8S32X_CONVOLUTION_HPP - -#include "c_types_map.hpp" -#include "memory_tracking.hpp" -#include "mkldnn_thread.hpp" -#include "utils.hpp" - -#include "cpu_convolution_pd.hpp" -#include "cpu_primitive.hpp" - -#include "jit_avx512_core_x8s8s32x_conv_kernel.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -template -struct jit_avx512_core_x8s8s32x_convolution_fwd_t : public cpu_primitive_t { - struct pd_t : public cpu_convolution_fwd_pd_t { - pd_t(engine_t *engine, const convolution_desc_t *adesc, - const primitive_attr_t *attr, - const typename pd_t::base_class *hint_fwd_pd) - : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) - , jcp_() - {} - - DECLARE_COMMON_PD_T( - JIT_IMPL_NAME_HELPER("jit_int8:", avx512_core, ""), - jit_avx512_core_x8s8s32x_convolution_fwd_t); - - status_t init() { - bool ok = true - && is_fwd() - && set_default_alg_kind(alg_kind::convolution_direct) - && expect_data_types(src_type, data_type::s8, data_type::undef, - dst_type, data_type::s32) - && IMPLICATION(with_bias(), utils::one_of(bias_md_.data_type, - data_type::f32, data_type::s32, data_type::s8, - data_type::u8)) - && !has_zero_dim_memory(); - if (!ok) return status::unimplemented; - - status_t status = jit_avx512_core_x8s8s32x_fwd_kernel::init_conf( - jcp_, *desc(), src_md_, weights_md_, dst_md_, bias_md_, - *attr(), mkldnn_get_max_threads()); - if (status != status::success) return status; - - auto scratchpad = scratchpad_registry().registrar(); - jit_avx512_core_x8s8s32x_fwd_kernel::init_scratchpad(scratchpad, - jcp_, *attr()); - - return status; - } - - jit_conv_conf_t jcp_; - }; - - jit_avx512_core_x8s8s32x_convolution_fwd_t(const pd_t *apd) - : cpu_primitive_t(apd) - { - kernel_ = new jit_avx512_core_x8s8s32x_fwd_kernel(pd()->jcp_, - *pd()->attr()); - } - - ~jit_avx512_core_x8s8s32x_convolution_fwd_t() { delete kernel_; } - - typedef typename prec_traits::type src_data_t; - typedef typename prec_traits::type wei_data_t; - typedef typename prec_traits::type dst_data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override - { - const auto &_pd = pd(); - if (_pd->ndims() == 3) - execute_forward_1d(ctx); - else if (_pd->jcp_.is_depthwise) - execute_forward_2d_dw(ctx); - else - execute_forward_2d(ctx); - return status::success; - } - -private: - void execute_forward_1d(const exec_ctx_t &ctx) const; - void execute_forward_2d(const exec_ctx_t &ctx) const; - void execute_forward_2d_dw(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - - jit_avx512_core_x8s8s32x_fwd_kernel *kernel_; -}; - -} -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_deconvolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_deconvolution.cpp deleted file mode 100644 index 142af1f54..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_deconvolution.cpp +++ /dev/null @@ -1,1034 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. 
-* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include "jit_avx512_core_x8s8s32x_deconvolution.hpp" - -#define GET_OFF(field) offsetof(jit_deconv_call_s, field) - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace mkldnn::impl::status; -using namespace mkldnn::impl::memory_tracking::names; -using namespace mkldnn::impl::utils; -using namespace Xbyak; - -using namespace nstl; - -#define wht_blk_off(d, g, ...) \ - (pd()->with_groups() ? (d).blk_off((g), __VA_ARGS__) : \ - (d).blk_off(__VA_ARGS__)) - -status_t jit_avx512_core_x8s8s32x_deconv_fwd_kernel::init_conf( - jit_conv_conf_t &jcp, const deconvolution_desc_t &cd, - memory_desc_t &src_md, memory_desc_t &weights_md, - memory_desc_t &dst_md, const bool with_bias, - memory_desc_t &bias_md, const primitive_attr_t &attr) { - const memory_desc_wrapper src_d(&src_md); - const memory_desc_wrapper dst_d(&dst_md); - const memory_desc_wrapper weights_d(&weights_md); - const memory_desc_wrapper bias_d(&bias_md); - - if (!(mayiuse(avx512_core) - && one_of(src_d.data_type(), data_type::u8, data_type::s8) - && weights_d.data_type() == data_type::s8 - && one_of(dst_d.data_type(), data_type::f32, data_type::s32, - data_type::s8, data_type::u8))) - return status::unimplemented; - - jcp = zero(); - - const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; - jcp.signed_input = src_d.data_type() == data_type::s8; - const int ndims = jcp.ndims = dst_d.ndims(); - const bool is_1d = ndims == 3; - - jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; - jcp.oc = dst_d.dims()[1] / jcp.ngroups; - jcp.ic = src_d.dims()[1] / jcp.ngroups; - jcp.oc_without_padding = dst_d.dims()[1] / jcp.ngroups; - jcp.ic_without_padding = src_d.dims()[1] / jcp.ngroups; - jcp.is_depthwise = true && with_groups - && utils::everyone_is(1, jcp.ic_without_padding, - jcp.oc_without_padding); - - /* TODO: future work, on hold until depthwise specialized kernel is - * implemented. */ - if (jcp.is_depthwise && jcp.signed_input) - return status::unimplemented; - - format_tag_t dat_tag = utils::pick(ndims - 3, - format_tag::nwc, format_tag::nhwc); - - if (src_d.format_kind() == format_kind::any) { - CHECK(memory_desc_init_by_tag(src_md, dat_tag)); - jcp.src_tag = dat_tag; - } else { - jcp.src_tag = src_d.matches_one_of_tag(dat_tag); - } - if (jcp.src_tag != dat_tag) - return status::unimplemented; - - if (dst_d.format_kind() == format_kind::any) { - CHECK(memory_desc_init_by_tag(dst_md, dat_tag)); - jcp.dst_tag = dat_tag; - } else { - jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); - } - if (jcp.dst_tag != dat_tag) - return status::unimplemented; - - auto set_or_check_wei_format = [&]() { - using namespace format_tag; - - format_tag_t wei_tag = is_1d - ? (jcp.is_depthwise - ? Goiw16g : (with_groups ? gOIw4i16o4i : OIw4i16o4i)) - : (jcp.is_depthwise - ? Goihw16g : (with_groups ? 
gOIhw4i16o4i : OIhw4i16o4i)); - - memory_desc_t want_wei_md = weights_md; - memory_desc_init_by_tag(want_wei_md, wei_tag); - if (jcp.signed_input && !jcp.is_depthwise) { - want_wei_md.extra.flags = 0 - | memory_extra_flags::compensation_conv_s8s8 - | memory_extra_flags::scale_adjust; - want_wei_md.extra.compensation_mask = (1 << 0) - + (with_groups && !jcp.is_depthwise ? (1 << 1) : 0); - want_wei_md.extra.scale_adjust = - mayiuse(avx512_core_vnni) ? 1.f : 0.5f; - } - - if (weights_md.format_kind == format_kind::any) { - weights_md = want_wei_md; - return true; - } - - return weights_md == want_wei_md; - }; - - if (!set_or_check_wei_format()) - return status::unimplemented; - - jcp.with_bias = with_bias; - if (jcp.with_bias) { - if (bias_d.format_kind() == format_kind::any) - CHECK(memory_desc_init_by_tag(bias_md, format_tag::x)); - } - - jcp.prop_kind = cd.prop_kind; - jcp.mb = src_d.dims()[0]; - jcp.ih = is_1d ? 1 : src_d.dims()[ndims - 2]; - jcp.iw = src_d.dims()[ndims - 1]; - jcp.oh = is_1d ? 1 : dst_d.dims()[ndims - 2]; - jcp.ow = dst_d.dims()[ndims - 1]; - jcp.kh = is_1d ? 1 : weights_d.dims()[with_groups + ndims - 2]; - jcp.kw = weights_d.dims()[with_groups + ndims - 1]; - jcp.t_pad = is_1d ? 0 : cd.padding[0][ndims - 4]; - jcp.l_pad = cd.padding[0][ndims - 3]; - jcp.stride_h = is_1d ? 1 : cd.strides[ndims - 4]; - jcp.stride_w = cd.strides[ndims - 3]; - - if (jcp.is_depthwise) { - jcp.ch_block = 16; - jcp.oc_block = 1; - jcp.ic_block = 1; - } else { - jcp.ch_block = 1; - jcp.oc_block = 16; - jcp.ic_block = 16; - - if (jcp.ngroups == 1) { - jcp.oc = utils::rnd_up(jcp.oc_without_padding, jcp.oc_block); - jcp.ic = utils::rnd_up(jcp.ic_without_padding, jcp.ic_block); - } - if (jcp.ic % jcp.ic_block != 0) - return status::unimplemented; - } - - jcp.dilate_h = is_1d ? 0 : cd.dilates[ndims - 4]; - jcp.dilate_w = cd.dilates[ndims - 3]; - - if (!IMPLICATION(jcp.dilate_h, jcp.stride_h == 1) - || !IMPLICATION(jcp.dilate_w, jcp.stride_w == 1)) - return status::unimplemented; - - /* padding: bottom and right */ - jcp.b_pad = (jcp.ih - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1) - - (jcp.oh + jcp.t_pad - 1); - jcp.r_pad = (jcp.iw - 1) * jcp.stride_w + (jcp.kw - 1) * (jcp.dilate_w + 1) - - (jcp.ow + jcp.l_pad - 1); - - if (!post_ops_ok(jcp, attr)) - return status::unimplemented; - - const auto &p = attr.post_ops_; - const int eltwise_ind = p.find(primitive_kind::eltwise); - jcp.with_eltwise = eltwise_ind != -1; - if (jcp.with_eltwise) - jcp.eltwise = p.entry_[eltwise_ind].eltwise; - - jcp.ver = ver_avx512_core; - if (mayiuse(avx512_core_vnni)) - jcp.ver = ver_vnni; - const auto &oscales = attr.output_scales_; - jcp.is_oc_scale = oscales.mask_ == 1 << 1; - - assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0)); - - jcp.dst_dt = dst_d.data_type(); - jcp.bia_dt = jcp.with_bias ? bias_d.data_type() : data_type::undef; - jcp.typesize_bia - = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0; - jcp.typesize_in = types::data_type_size(src_d.data_type()); - jcp.typesize_out = types::data_type_size(dst_d.data_type()); - - jcp.nb_ch = div_up(jcp.ngroups, jcp.ch_block); - jcp.nb_oc = jcp.oc / jcp.oc_block; - jcp.nb_ic = jcp.ic / jcp.ic_block; - - /* kernel blocking params */ - const int regs = jcp.ver == ver_vnni ? 
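/* Editor's note: a sketch, not part of the deleted sources, of the implicit
 * padding computed just above in init_conf: bottom/right padding is whatever
 * the strided, dilated kernel footprint needs beyond the declared output size;
 * the same formula serves both spatial axes. Parameter names are illustrative. */
static int deconv_end_padding(int in, int out, int k, int stride, int dilate,
        int start_pad) {
    // e.g. b_pad = (ih - 1) * stride_h + (kh - 1) * (dilate_h + 1) - (oh + t_pad - 1)
    return (in - 1) * stride + (k - 1) * (dilate + 1) - (out + start_pad - 1);
}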
30 : 28; - jcp.nb_oc_blocking = nstl::min(4, jcp.nb_oc); - for (; jcp.nb_oc_blocking > 1; jcp.nb_oc_blocking--) - if (jcp.nb_oc % jcp.nb_oc_blocking == 0 - && jcp.l_pad <= regs / (jcp.nb_oc_blocking + 1)) - break; - - jcp.ur_w = regs / (jcp.nb_oc_blocking + 1); - int l_overflow = max( - 0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - jcp.l_pad) / jcp.stride_w); - - if (jcp.ow < jcp.ur_w) { - jcp.ur_w = jcp.ow; - jcp.ur_w_tail = 0; - } else { - for (; jcp.ur_w >= 1; jcp.ur_w--) { - /* ur_w should be multiple of stride_w in order - to simplify logic for get_ow_start and get_ow_end */ - bool is_multiple_of_stride = jcp.ur_w % jcp.stride_w == 0; - - /* boundary conditions: - These conditions ensure all elements close to boundary - are computed in a single call of compute loop */ - bool left_boundary_covered = jcp.ur_w >= l_overflow * jcp.stride_w; - jcp.ur_w_tail = jcp.ow % jcp.ur_w; - int r_overflow_no_tail - = max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - - max(0, jcp.r_pad) - jcp.ur_w_tail) - / jcp.stride_w); - bool right_boundary_covered - = jcp.ur_w >= r_overflow_no_tail * jcp.stride_w; - - if (is_multiple_of_stride && left_boundary_covered - && right_boundary_covered) - break; - else if (jcp.ur_w == 1) - /* The boundary conditions above are also important - to maintain simplicity of calls to icb_loop, - if those conditions are not satisfied, - then special cases will need to be added - to use correct l_overflow/r_overflow values - when different iterations of compute loop - work on the locations close to boundary. - So to keep code simple, return unimplemented - for extreme case when a good ur_w cannot be found. - */ - return status::unimplemented; - } - } - - jcp.wei_adj_scale = - (weights_d.extra().flags | memory_extra_flags::scale_adjust) - ? weights_d.extra().scale_adjust : 1.f; - - jcp.loop_order = jcp.ngroups > 1 ? loop_ngc : loop_cgn; - return status::success; -} - -bool jit_avx512_core_x8s8s32x_deconv_fwd_kernel::maybe_eltwise(int position) { - using namespace primitive_kind; - const auto &p = attr_.post_ops_; - - if (position == 0) { - /* eltwise before sum */ - return p.contain(eltwise, 0); - } else if (position == 1) { - /* eltwise after sum */ - return p.contain(sum, 0) && p.contain(eltwise, 1); - } - return false; -} - -void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::compute_eltwise(int ur_w) { - int nb_oc_block - = jcp.is_depthwise ? 
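/* Editor's note: a compact model, not part of the deleted sources, of the ur_w
 * search init_conf performs above: starting from regs / (nb_oc_blocking + 1),
 * shrink the register tile until it is a multiple of stride_w and wide enough
 * to absorb the left/right overflow regions in a single kernel call; if no such
 * tile exists the original bails out with status::unimplemented. Names are
 * illustrative. */
#include <algorithm>

static int pick_ur_w(int regs, int nb_oc_blocking, int ow, int stride_w,
        int kw, int dilate_w, int l_pad, int r_pad) {
    const int l_overflow =
            std::max(0, ((kw - 1) * (dilate_w + 1) - l_pad) / stride_w);
    int ur_w = regs / (nb_oc_blocking + 1);
    if (ow < ur_w)
        return ow; // whole output row fits in one tile, no tail
    for (; ur_w >= 1; --ur_w) {
        const int ur_w_tail = ow % ur_w;
        const int r_overflow_no_tail = std::max(0,
                ((kw - 1) * (dilate_w + 1) - std::max(0, r_pad) - ur_w_tail)
                        / stride_w);
        const bool ok = ur_w % stride_w == 0             // simplifies ow start/end math
                && ur_w >= l_overflow * stride_w         // left boundary in one call
                && ur_w >= r_overflow_no_tail * stride_w; // right boundary in one call
        if (ok)
            return ur_w;
    }
    return -1; // no usable tile -> unimplemented in the original
}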
jcp.nb_ch_blocking : jcp.nb_oc_blocking; - eltwise_injector_->compute_vector_range(0, nb_oc_block * ur_w); -} - -bool jit_avx512_core_x8s8s32x_deconv_fwd_kernel::post_ops_ok( - jit_conv_conf_t &jcp, const primitive_attr_t &attr) { - using namespace primitive_kind; - const auto &p = attr.post_ops_; - - auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; - - switch (p.len_) { - case 0: return true; - case 1: return is_eltwise(0) || p.contain(sum, 0); - case 2: - return (p.contain(sum, 0) && is_eltwise(1)) - || (p.contain(sum, 1) && is_eltwise(0)); - default: return false; - } - - return false; -} - -void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::init_scratchpad( - memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp, - const primitive_attr_t &attr) { - if (jcp.signed_input && jcp.ver != ver_vnni) { - dim_t count = nstl::max(attr.output_scales_.count_, 16); - scratchpad.book(key_conv_adjusted_scales, sizeof(float) * count); - } -} - -void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::compute_ker(int ur_w, - int l_overflow, int r_overflow, ker_block_t last_ic_block_flag, - bool h_padded) { - - const int ch_block_all = jcp.ch_block * jcp.ic_block * jcp.oc_block; - const int ur_w_stride = jcp.signed_input ? 1 : jcp.stride_w; - - auto src_offset = [=](int oj, int icb, int ki) { - return jcp.typesize_in - * (((oj + jcp.l_pad - ki * (jcp.dilate_w + 1)) / jcp.stride_w) - * jcp.ngroups * jcp.ic_without_padding - + icb * 4); - }; - - auto kernel_offset = [=](int ocb, int icb, int ki) { - return jcp.typesize_in - * (ocb * jcp.nb_ic * jcp.kh * jcp.kw * ch_block_all - + icb * jcp.oc_block * jcp.ic_block / 4 - + ki * ch_block_all); - }; - - auto compute = [=](zmm_t vreg_acc, zmm_t vreg_wei, zmm_t vreg_src) { - if (jcp.ver == ver_vnni) { - vpdpbusd(vreg_acc, vreg_src, vreg_wei); - } else if (jcp.is_depthwise) { - vpmulld(zmm_tmp, vreg_src, vreg_wei); - vpaddd(vreg_acc, vreg_acc, zmm_tmp); - } else { - vpmaddubsw(zmm_tmp, vreg_src, vreg_wei); - vpmaddwd(zmm_tmp, zmm_tmp, zmm_one); - vpaddd(vreg_acc, vreg_acc, zmm_tmp); - } - }; - - for (int ki = 0; ki < jcp.kw; ki++) { - - int jj_start = get_ow_start(ki, l_overflow); - int jj_end = get_ow_end(ur_w, ki, r_overflow); - - int _start = (jcp.signed_input) ? 0 : jj_start; - int _end = (jcp.signed_input) ? ur_w : jj_end; - - int tail_size = jcp.ic_without_padding % 4; - int n_ic_blocks = jcp.is_depthwise ? - 1 : - (last_ic_block_flag & ~no_last_block ? 
- div_up(jcp.ic_without_padding % jcp.ic_block, - 4) : - jcp.ic_block / 4); - - for (int icb1 = 0; icb1 < n_ic_blocks; icb1++) { - if (h_padded == true) { - /* fill padded area with shifted values */ - Zmm inp = zmm_inp(0, jcp.nb_oc_blocking); - vpxord(inp, inp, inp); - vpsubb(inp, inp, zmm_shift); - } else { - - for (int jj = _start; jj < _end; jj += ur_w_stride) { - - int aux_src_off = src_offset(jj, icb1, ki); - - if (jj >= jj_start && jj < jj_end - && ((jj + jcp.l_pad - ki) % jcp.stride_w == 0)) { - if (jcp.is_depthwise) { - vpmovzxbd(zmm_inp(jj, jcp.nb_oc_blocking), - EVEX_compress_addr( - aux_reg_src, aux_src_off)); - } else if ((last_ic_block_flag & last_sp_block) - && tail_size != 0 && icb1 == n_ic_blocks - 1) { - xmm_t xmm_tmp = xmm_t( - zmm_inp(jj, jcp.nb_oc_blocking).getIdx()); - for (int r = 0; r < tail_size; ++r) - vpinsrb(xmm_tmp, xmm_tmp, - ptr[aux_reg_src + aux_src_off + r], r); - vpbroadcastd( - zmm_inp(jj, jcp.nb_oc_blocking), xmm_tmp); - } else { - vpbroadcastd(zmm_inp(jj, jcp.nb_oc_blocking), - EVEX_compress_addr( - aux_reg_src, aux_src_off)); - } - if (jcp.signed_input) - vpsubb(zmm_inp(jj, jcp.nb_oc_blocking), - zmm_inp(jj, jcp.nb_oc_blocking), zmm_shift); - } else { - /* fill padded area with shifted values */ - if (jcp.signed_input) { - Zmm inp = zmm_inp(jj, jcp.nb_oc_blocking); - vpxord(inp, inp, inp); - vpsubb(inp, inp, zmm_shift); - } - } - } - } - for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) { - int aux_filt_off = kernel_offset(ocb, icb1, ki); - - if (_end - _start > 0) { - if (jcp.is_depthwise) - vpmovsxbd(zmm_wei, - EVEX_compress_addr(aux_reg_filt, aux_filt_off)); - else - vmovups(zmm_wei, - EVEX_compress_addr(aux_reg_filt, aux_filt_off)); - } - for (int jj = _start; jj < _end; jj += ur_w_stride) { - Zmm inp = (h_padded == true) ? - zmm_inp(0, jcp.nb_oc_blocking) : - zmm_inp(jj, jcp.nb_oc_blocking); - compute(zmm_out(jj, ocb), zmm_wei, inp); - } - } - } - } -} - -void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::kh_loop(int ur_w, - int l_overflow, int r_overflow, ker_block_t last_ic_block_flag) { - - int ch_block_all = jcp.ch_block * jcp.ic_block * jcp.oc_block; - int shift_src_ih = jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw - * jcp.ngroups * jcp.ic_without_padding; - const int stride_h = jcp.signed_input ? 1 : jcp.stride_h; - int shift_filt_kh = jcp.typesize_in * jcp.kw * ch_block_all * stride_h; - - Label kh_loop_label, skip_kh_loop; - Label t_overflow_label, no_t_overflow_label, b_overflow_label, - no_b_overflow_label; - - mov(aux_reg_src, reg_src); - mov(aux_reg_filt, reg_filt); - - if (jcp.signed_input && jcp.ndims > 3) { - /* Weights are transposed, so first compute 'bottom' padding. 
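/* Editor's note: a scalar model, not part of the deleted sources, of what one
 * 32-bit accumulator lane of compute_ker above does for four consecutive input
 * channels. With VNNI a single vpdpbusd performs the whole update; without VNNI
 * it is emulated with vpmaddubsw + vpmaddwd(+1) + vpaddd, whose intermediate
 * 16-bit step can saturate -- which appears to be why the non-VNNI path uses
 * weights pre-scaled by 0.5 (scale_adjust) and folds the factor back into the
 * output scales. For s8 sources the kernel first adds 128 (vpsubb with -128) so
 * the data fits the unsigned operand, and the precomputed compensation restores
 * the result later in store_output. Names are illustrative. */
#include <cstdint>

static int32_t dot4_u8s8(const uint8_t src[4], const int8_t wei[4], int32_t acc) {
    int32_t sum = 0;
    for (int k = 0; k < 4; ++k)
        sum += int32_t(src[k]) * int32_t(wei[k]);
    return acc + sum; // one vpdpbusd acc, src, wei on VNNI hardware
}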
*/ - mov(reg_overflow, ptr[param1 + GET_OFF(b_overflow)]); - cmp(reg_overflow, 0); - je(no_b_overflow_label, T_NEAR); - L(b_overflow_label); { - compute_ker(ur_w, 0, 0, last_ic_block_flag, true); - - add(aux_reg_filt, shift_filt_kh); - dec(reg_overflow); - cmp(reg_overflow, 0); - jg(b_overflow_label, T_NEAR); - } - L(no_b_overflow_label); - } - - mov(reg_kh, ptr[param1 + GET_OFF(kh_padding)]); - - if (jcp.signed_input || ((!jcp.signed_input) - && ((min(jcp.t_pad, jcp.b_pad) < 0) - || ((jcp.kh - 1) * (jcp.dilate_h + 1) - < nstl::max(jcp.t_pad, jcp.b_pad))))) { - cmp(reg_kh, 0); - je(skip_kh_loop, T_NEAR); - } - - L(kh_loop_label); { - compute_ker(ur_w, l_overflow, r_overflow, last_ic_block_flag, false); - sub(aux_reg_src, shift_src_ih); - add(aux_reg_filt, shift_filt_kh); - dec(reg_kh); - - /* Insert weight compensation in stride 'holes' */ - if (jcp.signed_input && jcp.stride_h > 1) { - Label kh_comp_loop; - - cmp(reg_kh, 0); - je(skip_kh_loop, T_NEAR); - mov(reg_comp_strides, jcp.stride_h - 1); - L(kh_comp_loop); - { - compute_ker( - ur_w, 0, 0, last_ic_block_flag, true); - add(aux_reg_filt, shift_filt_kh); - dec(reg_comp_strides); - cmp(reg_comp_strides, 0); - jg(kh_comp_loop, T_NEAR); - } - } - cmp(reg_kh, 0); - jg(kh_loop_label, T_NEAR); - } - L(skip_kh_loop); - if (jcp.signed_input && jcp.ndims > 3) { - mov(reg_overflow, ptr[param1 + GET_OFF(t_overflow)]); - cmp(reg_overflow, 0); - je(no_t_overflow_label, T_NEAR); - L(t_overflow_label); { - compute_ker(ur_w, 0, 0, last_ic_block_flag, true); - - add(aux_reg_filt, shift_filt_kh); - dec(reg_overflow); - cmp(reg_overflow, 0); - jg(t_overflow_label, T_NEAR); - } - L(no_t_overflow_label); - } -} - -void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::prepare_output(int ur_w) { - for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) { - for (int ur = 0; ur < ur_w; ur++) { - zmm_t zmm = zmm_out(ur, ocb); - vpxord(zmm, zmm, zmm); - } - } - if (jcp.signed_input) { - xor_(reg_scratch, reg_scratch); - Reg8 _t8 = reg_scratch.cvt8(); - mov(_t8, (int8_t)-128); - vpbroadcastb(zmm_shift, _t8); - } -} - -void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::cvt2ps( - data_type_t type_in, zmm_t zmm_in, const Operand &op, bool mask_flag) { - zmm_t zmm = mask_flag ? zmm_in | ktail_mask | T_z : zmm_in; - switch (type_in) { - case data_type::f32: - case data_type::s32: vmovups(zmm, op); break; - case data_type::s8: vpmovsxbd(zmm, op); break; - case data_type::u8: vpmovzxbd(zmm, op); break; - default: assert(!"unsupported data type"); - } - if (type_in != data_type::f32) - vcvtdq2ps(zmm_in, zmm_in); -} - -void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::store_output( - int ur_w, bool last_oc_block) { - mov(reg_bias, ptr[param1 + GET_OFF(bias)]); - mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]); - - if (jcp.signed_input) - mov(reg_compensation, ptr[param1 + GET_OFF(compensation)]); - - const auto &p = attr_.post_ops_; - const int sum_idx = p.find(primitive_kind::sum); - const float *p_sum_scale - = (sum_idx != -1) ? 
&p.entry_[sum_idx].sum.scale : nullptr; - if (p_sum_scale && *p_sum_scale != 1.f) - mov(reg_ptr_sum_scale, (size_t)p_sum_scale); - - if (jcp.with_bias && jcp.signed_input && jcp.ver != ver_vnni) { - mov(reg_bias_alpha, float2int(jcp.wei_adj_scale)); - vmovq(xmm_bias_alpha(), reg_bias_alpha); - vbroadcastss(zmm_bias_alpha(), xmm_bias_alpha()); - } - - for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) { - const bool mask_flag = last_oc_block && ocb == jcp.nb_oc_blocking - 1; - int scale_offset - = jcp.is_oc_scale * (sizeof(float) * ocb * jcp.oc_block); - - auto zmm_bias = zmm_tmp; - if (jcp.with_bias) { - int bias_offset = jcp.typesize_bia * ocb * jcp.oc_block; - auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset); - cvt2ps(jcp.bia_dt, zmm_bias, bias_addr, mask_flag); - if (jcp.signed_input && jcp.ver != ver_vnni) - vmulps(zmm_bias, zmm_bias, zmm_bias_alpha()); - } - if (jcp.signed_input) { - int comp_offset = sizeof(int32_t) * ocb * jcp.oc_block; - auto comp_addr = EVEX_compress_addr(reg_compensation, comp_offset); - cvt2ps(data_type::s32, zmm_comp, comp_addr, mask_flag); - } - - for (int ur = 0; ur < ur_w; ur++) { - zmm_t zmm = zmm_out(ur, ocb); - vcvtdq2ps(zmm, zmm); - if (jcp.signed_input) - vaddps(zmm, zmm, zmm_comp); - if (jcp.with_bias) - vaddps(zmm, zmm, zmm_bias); - zmm_t mask_zmm = mask_flag ? zmm | ktail_mask | T_z : zmm; - vmulps(mask_zmm, zmm, - EVEX_compress_addr(reg_ptr_scales, scale_offset)); - } - } - if (maybe_eltwise(0)) - compute_eltwise(ur_w); - if (p_sum_scale) { // post_op: sum - for (int k = 0; k < jcp.nb_oc_blocking; k++) { - const bool mask_flag - = last_oc_block == 1 && k == jcp.nb_oc_blocking - 1; - for (int j = 0; j < ur_w; j++) { - int aux_output_offset - = jcp.typesize_out - * (k * jcp.oc_block - + j * jcp.oc_without_padding * jcp.ngroups); - auto addr = EVEX_compress_addr(reg_dst, aux_output_offset); - Zmm zmm = zmm_out(j, k); - cvt2ps(jcp.dst_dt, zmm_prev_dst, addr, mask_flag); - if (*p_sum_scale == 1.f) - vaddps(zmm, zmm_prev_dst); - else - vfmadd231ps(zmm, zmm_prev_dst, zword_b[reg_ptr_sum_scale]); - } - } - } - if (maybe_eltwise(1)) - compute_eltwise(ur_w); - - for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) { - const bool mask_flag = last_oc_block && ocb == jcp.nb_oc_blocking - 1; - for (int ur = 0; ur < ur_w; ur++) { - zmm_t zmm = zmm_out(ur, ocb); - if (jcp.dst_dt == data_type::u8) { - vpxord(zmm_zero, zmm_zero, zmm_zero); - vmaxps(zmm, zmm_zero, zmm); - } - if (jcp.dst_dt != data_type::f32) - vcvtps2dq(zmm, zmm); - } - for (int ur = 0; ur < ur_w; ur++) { - int aux_dst_off = jcp.typesize_out - * (ur * jcp.ngroups * jcp.oc_without_padding - + ocb * jcp.oc_block); - auto addr = EVEX_compress_addr(reg_dst, aux_dst_off); - - zmm_t zmm = zmm_out(ur, ocb); - zmm_t r_zmm = mask_flag ? 
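/* Editor's note: a scalar sketch, not part of the deleted sources, of the
 * per-element epilogue that store_output above emits, in the same order:
 * int32 accumulator -> float, add compensation (signed input), add bias,
 * multiply by the output scale, optional eltwise, optional scaled sum with the
 * previous destination, optional eltwise again, then clamp at zero for u8 and
 * convert to the destination type. Parameter names are illustrative. */
#include <cstdint>

static float int8_conv_epilogue(int32_t acc, float comp, float bias, float scale,
        bool signed_input, bool with_bias, bool with_sum, float sum_scale,
        float prev_dst) {
    float x = float(acc);              // vcvtdq2ps
    if (signed_input) x += comp;       // s8s8 compensation term
    if (with_bias)    x += bias;
    x *= scale;                        // per-tensor or per-oc output scale
    // (optional eltwise before sum would be applied here)
    if (with_sum)     x += sum_scale * prev_dst;
    // (optional eltwise after sum, then clamp/convert according to jcp.dst_dt)
    return x;
}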
zmm | ktail_mask : zmm; - switch (jcp.dst_dt) { - case data_type::f32: - case data_type::s32: vmovups(addr, r_zmm); break; - case data_type::s8: vpmovsdb(addr, r_zmm); break; - case data_type::u8: vpmovusdb(addr, r_zmm); break; - default: assert(!"unknown dst_dt"); - } - } - } -} - -void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::icb_loop( - int ur_w, int l_overflow, int r_overflow, bool is_last_sp_block) { - - int shift_src_icb = jcp.typesize_in * jcp.ic_block; - int shift_filt_icb - = jcp.typesize_in * jcp.kh * jcp.kw * jcp.ic_block * jcp.oc_block; - - prepare_output(ur_w); - - Label skip_icb_loop, icb_loop_label; - - mov(reg_icb, jcp.nb_ic); - L(icb_loop_label); { - - if (jcp.ic_without_padding != jcp.ic) { - Label common_ker, end_ker; - cmp(reg_icb, 1); - jg(common_ker, T_NEAR); - - kh_loop(ur_w, l_overflow, r_overflow, - is_last_sp_block ? last_sp_block : last_ic_block); - jmp(end_ker, T_NEAR); - - L(common_ker); - kh_loop(ur_w, l_overflow, r_overflow, no_last_block); - - L(end_ker); - } else { - kh_loop(ur_w, l_overflow, r_overflow, no_last_block); - } - - add(reg_src, shift_src_icb); - add(reg_filt, shift_filt_icb); - dec(reg_icb); - cmp(reg_icb, 0); - jg(icb_loop_label, T_NEAR); - } - - /* come-back pointers */ - sub(reg_src, jcp.nb_ic * shift_src_icb); - sub(reg_filt, jcp.nb_ic * shift_filt_icb); - L(skip_icb_loop); - - if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) { - Label common_store, end_store; - mov(reg_oc_blocks, ptr[param1 + GET_OFF(oc_blocks)]); - if (jcp.is_depthwise) - cmp(reg_oc_blocks, jcp.nb_ch - 1); - else - cmp(reg_oc_blocks, jcp.nb_oc - jcp.nb_oc_blocking); - jne(common_store, T_NEAR); - - store_output(ur_w, true); - jmp(end_store, T_NEAR); - - L(common_store); - store_output(ur_w, false); - - L(end_store); - - } else { - store_output(ur_w, false); - } -} - -void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::generate() { - preamble(); - - xor_(reg_scratch, reg_scratch); - Reg16 _t = reg_scratch.cvt16(); - mov(_t, 0x1); - vpbroadcastw(zmm_one, _t); - - if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) { - int tail_size = jcp.is_depthwise ? 
- jcp.ngroups % jcp.ch_block : - jcp.oc_without_padding % jcp.oc_block; - int mask = (1 << tail_size) - 1; - Reg32 regw_tmp = reg_nur_w.cvt32(); - mov(regw_tmp, mask); - kmovw(ktail_mask, regw_tmp); - } - - mov(reg_src, ptr[param1 + GET_OFF(src)]); - mov(reg_filt, ptr[param1 + GET_OFF(filt)]); - mov(reg_dst, ptr[param1 + GET_OFF(dst)]); - - int dst_shift = jcp.typesize_out * jcp.ur_w * jcp.ngroups - * jcp.oc_without_padding; - int src_shift = jcp.typesize_in * (jcp.ur_w / jcp.stride_w) * jcp.ngroups - * jcp.ic_without_padding; - - int l_overflow = max( - 0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - jcp.l_pad) / jcp.stride_w); - int r_overflow - = max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - max(0, jcp.r_pad)) - / jcp.stride_w); - - int r_overflow1 - = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - - nstl::max(0, jcp.r_pad) - jcp.ur_w_tail) - / jcp.stride_w); - int nur_w = jcp.ow / jcp.ur_w; - if (r_overflow1 > 0) - nur_w--; - - if (jcp.ur_w == jcp.ow) { - icb_loop(jcp.ur_w, l_overflow, r_overflow, true); - } else if (nur_w == 0) { - icb_loop(jcp.ur_w, l_overflow, r_overflow1, jcp.ur_w_tail == 0); - add(reg_src, src_shift); - add(reg_dst, dst_shift); - if (jcp.ur_w_tail != 0) - icb_loop(jcp.ur_w_tail, 0, r_overflow, true); - } else { - xor_(reg_nur_w, reg_nur_w); - if (l_overflow > 0) { - icb_loop(jcp.ur_w, l_overflow, 0, false); - add(reg_src, src_shift); - add(reg_dst, dst_shift); - inc(reg_nur_w); - } - if ((l_overflow <= 0 && nur_w > 0) || (l_overflow > 0 && nur_w > 1)) { - Label ow_loop_label; - L(ow_loop_label); - { - icb_loop(jcp.ur_w, 0, 0, false); - add(reg_src, src_shift); - add(reg_dst, dst_shift); - inc(reg_nur_w); - cmp(reg_nur_w, nur_w); - jl(ow_loop_label, T_NEAR); - } - } - if (r_overflow1 > 0) { - icb_loop(jcp.ur_w, 0, r_overflow1, jcp.ur_w_tail == 0); - add(reg_src, src_shift); - add(reg_dst, dst_shift); - } - if (jcp.ur_w_tail != 0) { - icb_loop(jcp.ur_w_tail, 0, r_overflow, true); - } - } - postamble(); - - if (jcp.with_eltwise) - eltwise_injector_->prepare_table(); -} - -template -void _jit_avx512_core_x8s8s32x_deconvolution_fwd_t::execute_forward_1d(const exec_ctx_t &ctx) const { - auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); - auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); - auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); - auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); - - const memory_desc_wrapper src_d(pd()->src_md()); - const memory_desc_wrapper dst_d(pd()->dst_md()); - const memory_desc_wrapper weights_d(pd()->weights_md(0)); - const memory_desc_wrapper bias_d(pd()->weights_md(1)); - - auto &jcp = kernel_->jcp; - - int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; - int nb_groups = jcp.nb_ch; - - const float *oscales = pd()->attr()->output_scales_.scales_; - if (jcp.signed_input && jcp.ver != ver_vnni) { - auto local_scales - = scratchpad(ctx).template get(key_conv_adjusted_scales); - size_t count = pd()->attr()->output_scales_.count_; - float factor = 1.f / pd()->jcp_.wei_adj_scale; - if (count == 1) { - utils::array_set(local_scales, oscales[0] * factor, 16); - } else { - for (size_t c = 0; c < count; c++) - local_scales[c] = oscales[c] * factor; - } - oscales = local_scales; - } - size_t offset = (size_t)jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw; - auto w = const_cast(weights); - int32_t *compensation - = (jcp.signed_input) ? 
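/* Editor's note: illustrative sketch, not part of the deleted sources. When the
 * channel count is not a multiple of the SIMD block, generate() above builds an
 * AVX-512 opmask with one bit per valid lane ((1 << tail) - 1) and loads it with
 * kmovw; full 16-lane stores need no mask. */
static unsigned tail_opmask(int channels, int block /* 16 here */) {
    const int tail = channels % block;
    return tail ? (1u << tail) - 1u : 0xffffu; // 0xffff = all 16 lanes enabled
}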
reinterpret_cast(&w[offset]) : 0; - - parallel(0, [&](const int ithr, const int nthr) { - int start{ 0 }, end{ 0 }; - int work_amount = jcp.mb * nb_groups * oc_chunks; - balance211(work_amount, nthr, ithr, start, end); - - auto p = jit_deconv_call_s(); - - int n{ 0 }, g{ 0 }, occ{ 0 }; - if (jcp.loop_order == loop_ngc) - nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks); - else if (jcp.loop_order == loop_cgn) - nd_iterator_init(start, occ, oc_chunks, g, nb_groups, n, jcp.mb); - else - assert(!"unsupported loop order"); - while (start < end) { - - int ocb = occ * jcp.nb_oc_blocking; - int g_oc = (g * jcp.ch_block * jcp.nb_oc + ocb) * jcp.oc_block; - int g_ic = g * jcp.ch_block * jcp.ic; - - p.dst = dst + dst_d.blk_off(n, g_oc); - p.src = src + src_d.blk_off(n, g_ic); - p.filt = weights + wht_blk_off(weights_d, g, ocb, 0); - p.bias = jcp.with_bias ? - bias + (bias_d.blk_off(g_oc) * jcp.typesize_bia) : - 0; - p.compensation = (jcp.signed_input) ? compensation + g_oc : 0; - p.scales = &oscales[jcp.is_oc_scale * g_oc]; - p.t_overflow = 0; - p.b_overflow = 0; - p.kh_padding = jcp.kh; - p.oc_blocks = jcp.is_depthwise ? g : ocb; - - kernel_->jit_ker(&p); - - ++start; - if (jcp.loop_order == loop_ngc) - nd_iterator_step(n, jcp.mb, g, nb_groups, occ, oc_chunks); - else if (jcp.loop_order == loop_cgn) - nd_iterator_step(occ, oc_chunks, g, nb_groups, n, jcp.mb); - else - assert(!"unsupported loop order"); - } - }); -} - -template -void _jit_avx512_core_x8s8s32x_deconvolution_fwd_t::execute_forward_2d(const exec_ctx_t &ctx) const { - auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); - auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); - auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); - auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); - - const memory_desc_wrapper src_d(pd()->src_md()); - const memory_desc_wrapper dst_d(pd()->dst_md()); - const memory_desc_wrapper weights_d(pd()->weights_md(0)); - const memory_desc_wrapper bias_d(pd()->weights_md(1)); - - auto &jcp = kernel_->jcp; - - int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; - int nb_groups = jcp.nb_ch; - - size_t src_h_stride = src_d.blk_off(0, 0, 1); - size_t dst_h_stride = dst_d.blk_off(0, 0, 1); - size_t wht_kh_stride = wht_blk_off(weights_d, 0, 0, 0, 1); - - const float *oscales = pd()->attr()->output_scales_.scales_; - if (jcp.signed_input && jcp.ver != ver_vnni) { - auto local_scales - = scratchpad(ctx).template get(key_conv_adjusted_scales); - size_t count = pd()->attr()->output_scales_.count_; - float factor = 1.f / pd()->jcp_.wei_adj_scale; - if (count == 1) { - utils::array_set(local_scales, oscales[0] * factor, 16); - } else { - for (size_t c = 0; c < count; c++) - local_scales[c] = oscales[c] * factor; - } - oscales = local_scales; - } - size_t offset = (size_t)jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw; - auto w = const_cast(weights); - int32_t *compensation - = (jcp.signed_input) ? 
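/* Editor's note: a sketch, not part of the deleted sources, of how the
 * deconvolution code above locates the s8s8 compensation terms: they are stored
 * directly behind the reordered int8 weights, so the pointer is recovered by
 * skipping ngroups*oc*ic*kh*kw weight bytes and reinterpreting the remainder as
 * int32_t. The convolution variant reaches the same place via
 * weights_d.size() - additional_buffer_size(). Names are illustrative. */
#include <cstddef>
#include <cstdint>

static const int32_t *compensation_ptr(const int8_t *weights, size_t ngroups,
        size_t oc, size_t ic, size_t kh, size_t kw) {
    const size_t wei_elems = ngroups * oc * ic * kh * kw; // one byte per s8 weight
    return reinterpret_cast<const int32_t *>(weights + wei_elems);
}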
reinterpret_cast(&w[offset]) : 0; - - parallel(0, [&](const int ithr, const int nthr) { - int start{ 0 }, end{ 0 }; - int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.oh; - balance211(work_amount, nthr, ithr, start, end); - - auto p = jit_deconv_call_s(); - - /*loop order = cgn*/ - int n{ 0 }, g{ 0 }, occ{ 0 }, oh_s{ 0 }; - if (jcp.loop_order == loop_ngc) - nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks, - oh_s, jcp.oh); - else if (jcp.loop_order == loop_cgn) - nd_iterator_init(start, occ, oc_chunks, g, nb_groups, n, jcp.mb, - oh_s, jcp.oh); - else - assert(!"unsupported loop order"); - while (start < end) { - - int ocb = occ * jcp.nb_oc_blocking; - int g_oc = (g * jcp.ch_block * jcp.nb_oc + ocb) * jcp.oc_block; - int g_ic = g * jcp.ch_block * jcp.ic; - int work_rem = end - start; - int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem; - - auto dst_w = dst + dst_d.blk_off(n, g_oc); - auto src_w = src + src_d.blk_off(n, g_ic); - auto wht_w = weights + wht_blk_off(weights_d, g, ocb, 0); - auto bias_w = jcp.with_bias ? - bias + (bias_d.blk_off(g_oc) * jcp.typesize_bia) : - 0; - int32_t *compensation_w - = (jcp.signed_input) ? compensation + g_oc : 0; - - auto scales = &oscales[jcp.is_oc_scale * g_oc]; - for (int oj = oh_s; oj < oh_e; oj++) { - int ih_max = 0, kh_lo = 0, kh_len = 0; - if (jcp.dilate_h != 0 && jcp.stride_h == 1) { - /* dilation */ - int dilate_h = jcp.dilate_h + 1; - // Note: use div_up to account for "holes" in filter - int o_t_overflow = div_up( - max(0, (jcp.kh - 1) * dilate_h - oj - jcp.t_pad), - dilate_h); - int o_b_overflow - = div_up(max(0, (jcp.kh - 1) * dilate_h + 1 - jcp.oh - + oj - jcp.b_pad), - dilate_h); - kh_len = jcp.kh - o_t_overflow - o_b_overflow; - kh_lo = o_b_overflow; - ih_max = oj + jcp.t_pad - o_b_overflow * dilate_h; - } else { - int o_t_overflow = max( - 0, (jcp.kh - (oj + 1 + jcp.t_pad)) / jcp.stride_h); - int o_b_overflow - = max(0, ((oj + jcp.kh) - (jcp.oh + jcp.b_pad)) - / jcp.stride_h); - int overflow_kh_hi = jcp.kh - 1 - - abs(jcp.oh + jcp.b_pad - (oj + 1)) % jcp.stride_h; - int overflow_kh_lo = (oj + jcp.t_pad) % jcp.stride_h; - - kh_len = (overflow_kh_hi - overflow_kh_lo) / jcp.stride_h - + 1 - o_t_overflow - o_b_overflow; - kh_lo = overflow_kh_lo + o_b_overflow * jcp.stride_h; - ih_max = (oj + jcp.t_pad - kh_lo) / jcp.stride_h; - } - - int wei_stride - = (!jcp.signed_input) ? kh_lo * wht_kh_stride : 0; - p.src = src_w + ih_max * src_h_stride; - p.dst = dst_w + oj * dst_h_stride; - p.filt = wht_w + wei_stride; - p.bias = bias_w; - p.compensation = compensation_w; - p.t_overflow = max( - 0, jcp.kh - (kh_lo + max(0, kh_len - 1) * jcp.stride_h - + 1)); - p.b_overflow = kh_lo; - p.kh_padding = kh_len; - p.scales = scales; - p.oc_blocks = jcp.is_depthwise ? 
g : ocb; - kernel_->jit_ker(&p); - } - if (jcp.loop_order == loop_ngc) - nd_iterator_jump(start, end, n, jcp.mb, g, nb_groups, occ, - oc_chunks, oh_s, jcp.oh); - else if (jcp.loop_order == loop_cgn) - nd_iterator_jump(start, end, occ, oc_chunks, g, nb_groups, n, - jcp.mb, oh_s, jcp.oh); - else - assert(!"unsupported loop order"); - } - }); -} - -template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t; -template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t; -template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t; -template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t; -template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t; -template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t; -template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t; -template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t; -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_deconvolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_deconvolution.hpp deleted file mode 100644 index 901038fa4..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_deconvolution.hpp +++ /dev/null @@ -1,237 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#ifndef CPU_JIT_AVX512_CORE_U8S8S32X_DECONVOLUTION_HPP -#define CPU_JIT_AVX512_CORE_U8S8S32X_DECONVOLUTION_HPP - -#include "c_types_map.hpp" -#include "cpu_primitive.hpp" -#include "cpu_memory.hpp" -#include "mkldnn_thread.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" -#include "nstl.hpp" - -#include "cpu_deconvolution_pd.hpp" -#include "jit_generator.hpp" -#include "jit_primitive_conf.hpp" -#include "jit_uni_eltwise.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -typedef enum { - no_last_block = 0x1U, - last_ic_block = 0x2U, - last_sp_block = 0x4U, -} ker_block_t; - -struct jit_avx512_core_x8s8s32x_deconv_fwd_kernel : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_x8s8s32x_deconv_fwd_ker_t); - - jit_avx512_core_x8s8s32x_deconv_fwd_kernel( - const jit_conv_conf_t &ajcp, const primitive_attr_t &attr) - : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr) { - if (jcp.with_eltwise) - eltwise_injector_ = new jit_uni_eltwise_injector_f32( - this, jcp.eltwise); - generate(); - jit_ker = (void (*)(jit_deconv_call_s *))getCode(); - } - - ~jit_avx512_core_x8s8s32x_deconv_fwd_kernel() { - delete eltwise_injector_; - } - - static bool post_ops_ok(jit_conv_conf_t &jcp, - const primitive_attr_t &attr); - - static status_t init_conf(jit_conv_conf_t &jcp, - const deconvolution_desc_t &cd, - memory_desc_t &src_md, - memory_desc_t &weights_md, - memory_desc_t &dst_md, - const bool with_bias, - memory_desc_t &bias_md, - const primitive_attr_t &attr); - - static void init_scratchpad(memory_tracking::registrar_t &scratchpad, - const jit_conv_conf_t &jcp, const primitive_attr_t &attr); - - const jit_conv_conf_t &jcp; - const primitive_attr_t &attr_; - void (*jit_ker)(jit_deconv_call_s *); -private: - jit_uni_eltwise_injector_f32 *eltwise_injector_; - using reg64_t = const Xbyak::Reg64; - using zmm_t = const Xbyak::Zmm; - using xmm_t = const Xbyak::Xmm; - - reg64_t reg_src = r8; - reg64_t reg_filt = r9; - reg64_t reg_dst = r10; - reg64_t param1 = abi_param1; - reg64_t reg_kh = abi_not_param1; - reg64_t reg_nur_w = rbx; - reg64_t reg_bias = rdx; - reg64_t reg_icb = reg_bias; - reg64_t reg_ptr_scales = rax; - reg64_t reg_oc_blocks = rsi; - - reg64_t aux_reg_src = r11; - reg64_t aux_reg_filt = r12; - - reg64_t reg_compensation = r14; - reg64_t reg_scratch = r14; - reg64_t reg_ptr_sum_scale = r11; - reg64_t reg_bias_alpha = abi_not_param1; - reg64_t reg_overflow = rax; - reg64_t reg_comp_strides = reg_overflow; - - Xbyak::Opmask ktail_mask = Xbyak::Opmask(2); - zmm_t zmm_tmp = zmm_t(28); - zmm_t zmm_one = zmm_t(29); - /* used during write-out section of store_output */ - zmm_t zmm_zero = zmm_t(31); - zmm_t zmm_wei = zmm_t(31); - - /* signed input */ - zmm_t zmm_shift = zmm_t(30); - zmm_t zmm_comp = zmm_t(30); - zmm_t zmm_bias = zmm_t(31); - zmm_t zmm_prev_dst = zmm_t(31); - - zmm_t zmm_out(int i_ur, int i_oc) { - int idx = i_ur * jcp.nb_oc_blocking + i_oc; - assert(idx < 31); - return zmm_t(idx); - } - zmm_t zmm_inp(int i_ic, int nb_x_blocking) { - int idx = i_ic + nb_x_blocking * jcp.ur_w; - assert(idx < 31); - return zmm_t(idx); - } - zmm_t zmm_bias_alpha() { - return zmm_t(jcp.nb_oc_blocking * jcp.ur_w); - } - xmm_t xmm_bias_alpha() { - return xmm_t(jcp.nb_oc_blocking * jcp.ur_w); - } - - int get_ow_start(int ki, int l_overflow) { - int res = (jcp.ow - 1 + jcp.r_pad) % jcp.stride_w - + l_overflow * jcp.stride_w - - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1); - while (res < 0) - res 
+= jcp.stride_w; - return res; - } - - int get_ow_end(int ur_w, int ki, int r_overflow) { - if (utils::one_of(ur_w, jcp.ow, jcp.ur_w_tail)) - ur_w += nstl::min(0, jcp.r_pad); // remove negative padding - int res = (ur_w - 1 + jcp.l_pad) % jcp.stride_w - + r_overflow * jcp.stride_w - ki * (jcp.dilate_w + 1); - while (res < 0) - res += jcp.stride_w; - return ur_w - res; - } - bool maybe_eltwise(int position); - void compute_eltwise(int ur_w); - void prepare_output(int ur_w); - void store_output(int ur_w, bool last_oc_block); - void compute_ker(int ur_w, int l_overflow, int r_overflow, - ker_block_t last_ic_block_flag, bool h_padded = false); - void kh_loop(int ur_w, int pad_l, int pad_r, ker_block_t last_ker_block); - void icb_loop(int ur_w, int pad_l, int pad_r, bool last_block); - void generate(); - void cvt2ps(data_type_t type_in, zmm_t zmm_in, const Xbyak::Operand &op, - bool mask_flag); -}; - -template -struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t : public cpu_primitive_t { - struct pd_t : public cpu_deconvolution_fwd_pd_t { - using cpu_deconvolution_fwd_pd_t::cpu_deconvolution_fwd_pd_t; - - DECLARE_COMMON_PD_T( - JIT_IMPL_NAME_HELPER("jit_deconvolution:", avx512_core, ""), - _jit_avx512_core_x8s8s32x_deconvolution_fwd_t); - - status_t init() { - bool ok = true - && is_fwd() - && (desc()->alg_kind & alg_kind::deconvolution_direct) - && desc()->src_desc.data_type == src_type - && desc()->dst_desc.data_type == dst_type - && IMPLICATION(with_bias(), utils::one_of( - desc()->bias_desc.data_type, data_type::f32, - data_type::s32, data_type::s8, data_type::u8)) - && desc()->accum_data_type == data_type::s32; - if (!ok) return status::unimplemented; - - status_t status = jit_avx512_core_x8s8s32x_deconv_fwd_kernel:: - init_conf(jcp_, *desc(), src_md_, weights_md_, dst_md_, - with_bias(), bias_md_, *attr()); - - if (status != status::success) return status; - - auto scratchpad = scratchpad_registry().registrar(); - jit_avx512_core_x8s8s32x_deconv_fwd_kernel::init_scratchpad(scratchpad, - jcp_, *attr()); - - return status::success; - } - - jit_conv_conf_t jcp_; - }; - - _jit_avx512_core_x8s8s32x_deconvolution_fwd_t(const pd_t *apd) - : cpu_primitive_t(apd) - { - kernel_ = new jit_avx512_core_x8s8s32x_deconv_fwd_kernel(pd()->jcp_, - *pd()->attr()); - } - - ~_jit_avx512_core_x8s8s32x_deconvolution_fwd_t() { delete kernel_; } - - typedef typename prec_traits::type src_data_t; - typedef typename prec_traits::type wei_data_t; - typedef typename prec_traits::type dst_data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - if(pd()->ndims() == 3) - execute_forward_1d(ctx); - else - execute_forward_2d(ctx); - return status::success; - } - -private: - void execute_forward_1d(const exec_ctx_t &ctx) const; - void execute_forward_2d(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - jit_avx512_core_x8s8s32x_deconv_fwd_kernel *kernel_; -}; - -} -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_generator.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_generator.hpp deleted file mode 100644 index c09592d5c..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_generator.hpp +++ /dev/null @@ -1,773 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. 
-* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef CPU_JIT_AVX2_GENERATOR_HPP -#define CPU_JIT_AVX2_GENERATOR_HPP - -#include - -#include "mkldnn_thread.hpp" -#include "utils.hpp" - -#include "cpu_isa_traits.hpp" -#include "jit_utils/jit_utils.hpp" - -#if defined(_WIN32) && !defined(__GNUC__) -# define STRUCT_ALIGN(al, ...) __declspec(align(al)) __VA_ARGS__ -#else -# define STRUCT_ALIGN(al, ...) __VA_ARGS__ __attribute__((__aligned__(al))) -#endif - -#if defined(_WIN32) -# define OFFSET_SHADOWSPACE 0x28 -#endif - -#define DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_name) \ - const char *name() const override { return STRINGIFY(jit_name); } \ - const char *source_file() const override { return __FILE__; } - -namespace mkldnn { -namespace impl { -namespace cpu { - -// TODO: move this to jit_generator class? -namespace { - -typedef enum { - PAGE_4K = 4096, - PAGE_2M = 2097152, -} cpu_page_size_t; - -// TODO: move this somewhere else? Although this is only used by jit kernels -// (Roma) -static inline int float2int(float x) { - union { - float vfloat; - int vint; - } cvt; - cvt.vfloat = x; - return cvt.vint; -} - -// TODO: A GPR class that hides ABI details from the JIT kernels and allows -// numbering registers from 0 to 14 (x86_64) / 6 (x32) (gpr0, gpr1, ...) and -// stack register (sr). -// -// This will allow using syntax like this: -// -// param = gpr0; -// reg_input = gpr0; -// reg_output = gpr1; -// ... -// -// #ifndef XBYAK64 -// mov(param, ptr[sr]) -// #endif -// -// (Roma) - -#ifdef XBYAK64 -constexpr Xbyak::Operand::Code abi_save_gpr_regs[] = { - Xbyak::Operand::RBX, Xbyak::Operand::RBP, Xbyak::Operand::R12, - Xbyak::Operand::R13, Xbyak::Operand::R14, Xbyak::Operand::R15, -#ifdef _WIN32 - Xbyak::Operand::RDI, Xbyak::Operand::RSI, -#endif -}; - -#ifdef _WIN32 -static const Xbyak::Reg64 abi_param1(Xbyak::Operand::RCX), - abi_param2(Xbyak::Operand::RDX), - abi_param3(Xbyak::Operand::R8), - abi_param4(Xbyak::Operand::R9), - abi_not_param1(Xbyak::Operand::RDI); -#else -static const Xbyak::Reg64 abi_param1(Xbyak::Operand::RDI), - abi_param2(Xbyak::Operand::RSI), - abi_param3(Xbyak::Operand::RDX), - abi_param4(Xbyak::Operand::RCX), - abi_param5(Xbyak::Operand::R8), - abi_param6(Xbyak::Operand::R9), - abi_not_param1(Xbyak::Operand::RCX); -#endif -#endif - -inline unsigned int get_cache_size(int level, bool per_core = true){ - unsigned int l = level - 1; - // Currently, if XByak is not able to fetch the cache topology - // we default to 32KB of L1, 512KB of L2 and 1MB of L3 per core. - if (cpu.getDataCacheLevels() == 0){ - const int L1_cache_per_core = 32000; - const int L2_cache_per_core = 512000; - const int L3_cache_per_core = 1024000; - int num_cores = per_core ? 1 : mkldnn_get_max_threads(); - switch(l){ - case(0): return L1_cache_per_core * num_cores; - case(1): return L2_cache_per_core * num_cores; - case(2): return L3_cache_per_core * num_cores; - default: return 0; - } - } - if (l < cpu.getDataCacheLevels()) { - return cpu.getDataCacheSize(l) - / (per_core ? 
cpu.getCoresSharingDataCache(l) : 1); - } else - return 0; -} - -} - -class jit_generator : public Xbyak::CodeGenerator -{ -private: - const size_t xmm_len = 16; -#ifdef _WIN32 - const size_t xmm_to_preserve_start = 6; - const size_t xmm_to_preserve = 10; -#else - const size_t xmm_to_preserve_start = 0; - const size_t xmm_to_preserve = 0; -#endif - - const size_t num_abi_save_gpr_regs - = sizeof(abi_save_gpr_regs) / sizeof(abi_save_gpr_regs[0]); - - const size_t size_of_abi_save_regs - = num_abi_save_gpr_regs * rax.getBit() / 8 - + xmm_to_preserve * xmm_len; - -public: - enum { - _cmp_eq_oq = 0u, - _cmp_lt_os = 1u, - _cmp_le_os = 2u, - _cmp_neq_uq = 4u, - _cmp_nlt_us = 5u, - _cmp_nle_us = 6u, - - _op_floor = 1u, - _op_mxcsr = 4u, - }; - - Xbyak::Reg64 param1 = abi_param1; - const int EVEX_max_8b_offt = 0x200; - const Xbyak::Reg64 reg_EVEX_max_8b_offt = rbp; - - inline size_t get_size_of_abi_save_regs() { - return size_of_abi_save_regs; - } - - void preamble() { - if (xmm_to_preserve) { - sub(rsp, xmm_to_preserve * xmm_len); - for (size_t i = 0; i < xmm_to_preserve; ++i) - movdqu(ptr[rsp + i * xmm_len], Xbyak::Xmm(xmm_to_preserve_start + i)); - } - for (size_t i = 0; i < num_abi_save_gpr_regs; ++i) - push(Xbyak::Reg64(abi_save_gpr_regs[i])); - if (mayiuse(avx512_common)) { - mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt); - } - } - - void mic_prefetcht0(Xbyak::Address a) { - if (mayiuse(avx512_mic)) - prefetcht0(a); - } - - void mic_prefetcht1(Xbyak::Address a) { - if (mayiuse(avx512_mic)) - prefetcht1(a); - } - - void mic_prefetcht2(Xbyak::Address a) { - if (mayiuse(avx512_mic)) - prefetcht2(a); - } - - void uni_vzeroupper() { - if (mayiuse(avx) && !mayiuse(avx512_mic)) - vzeroupper(); - } - - void postamble() { - for (size_t i = 0; i < num_abi_save_gpr_regs; ++i) - pop(Xbyak::Reg64(abi_save_gpr_regs[num_abi_save_gpr_regs - 1 - i])); - if (xmm_to_preserve) { - for (size_t i = 0; i < xmm_to_preserve; ++i) - movdqu(Xbyak::Xmm(xmm_to_preserve_start + i), ptr[rsp + i * xmm_len]); - add(rsp, xmm_to_preserve * xmm_len); - } - uni_vzeroupper(); - ret(); - } - - template - Xbyak::Address EVEX_compress_addr(Xbyak::Reg64 base, - T raw_offt, bool bcast = false) - { - using Xbyak::Zmm; - using Xbyak::Reg64; - using Xbyak::Address; - using Xbyak::RegExp; - - assert(raw_offt <= INT_MAX); - auto offt = static_cast(raw_offt); - - int scale = 0; - - if (EVEX_max_8b_offt <= offt && offt < 3 * EVEX_max_8b_offt) { - offt = offt - 2 * EVEX_max_8b_offt; - scale = 1; - } else if (3 * EVEX_max_8b_offt <= offt && offt < 5 * EVEX_max_8b_offt) { - offt = offt - 4 * EVEX_max_8b_offt; - scale = 2; - } - - auto re = RegExp() + base + offt; - if (scale) - re = re + reg_EVEX_max_8b_offt * scale; - - if (bcast) - return zword_b [re]; - else - return zword [re]; - } - - Xbyak::Address make_safe_addr(const Xbyak::Reg64 ®_out, size_t offt, - const Xbyak::Reg64 &tmp_reg, bool bcast = false) { - if (offt > INT_MAX) { - mov(tmp_reg, offt); - return bcast ? ptr_b[reg_out + tmp_reg] : ptr[reg_out + tmp_reg]; - } else { - return bcast ? 
ptr_b[reg_out + offt] : ptr[reg_out + offt]; - } - } - - Xbyak::Address EVEX_compress_addr_safe(const Xbyak::Reg64 &base, - size_t raw_offt, const Xbyak::Reg64 ®_offt, bool bcast = false) { - if (raw_offt > INT_MAX) { - return make_safe_addr(base, raw_offt, reg_offt, bcast); - } else { - return EVEX_compress_addr(base, raw_offt, bcast); - } - } - - void safe_add(const Xbyak::Reg64 &base, size_t raw_offt, - const Xbyak::Reg64 ®_offt) { - if (raw_offt > INT_MAX) { - mov(reg_offt, raw_offt); - add(base, reg_offt); - } else { - add(base, raw_offt); - } - } - - void safe_sub(const Xbyak::Reg64 &base, size_t raw_offt, - const Xbyak::Reg64 ®_offt) { - if (raw_offt > INT_MAX) { - mov(reg_offt, raw_offt); - sub(base, reg_offt); - } else { - sub(base, raw_offt); - } - } - - // Disallow char-based labels completely - void L(const char *label) = delete; - void L(Xbyak::Label& label) { Xbyak::CodeGenerator::L(label); } - - void uni_vpxor(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, - const Xbyak::Operand &op) { - assert(x1.getIdx() == x2.getIdx()); - pxor(x2, op); - } - void uni_vpxor(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, - const Xbyak::Operand &op) { - if (mayiuse(avx2)) { - vpxor(x1, x2, op); - } else { - vxorps(x1, x2, op); - } - } - void uni_vpxor(const Xbyak::Zmm &x1, const Xbyak::Zmm &x2, - const Xbyak::Operand &op) { - vpxord(x1, x2, op); - } - - void uni_vmovss(const Xbyak::Address& addr, const Xbyak::Xmm &x) { - movss(addr, x); - } - void uni_vmovss(const Xbyak::Address& addr, const Xbyak::Ymm &x) { - vmovss(addr, x); - } - void uni_vmovss(const Xbyak::Xmm &x, const Xbyak::Address& addr) { - movss(x, addr); - } - void uni_vmovss(const Xbyak::Ymm &x, const Xbyak::Address& addr) { - vmovss(x, addr); - } - - void uni_vmovsd(const Xbyak::Address& addr, const Xbyak::Xmm &x) { - movsd(addr, x); - } - void uni_vmovsd(const Xbyak::Address& addr, const Xbyak::Ymm &x) { - vmovsd(addr, x); - } - void uni_vmovsd(const Xbyak::Xmm &x, const Xbyak::Address& addr) { - movsd(x, addr); - } - void uni_vmovsd(const Xbyak::Ymm &x, const Xbyak::Address& addr) { - vmovsd(x, addr); - } - - void uni_vmovdqu(const Xbyak::Address &addr, const Xbyak::Xmm &x) { - movdqu(addr, x); - } - void uni_vmovdqu(const Xbyak::Address &addr, const Xbyak::Ymm &x) { - vmovdqu(addr, x); - } - void uni_vmovdqu(const Xbyak::Address &addr, const Xbyak::Zmm &x) { - vmovdqu32(addr, x); - } - - void uni_vmovdqu(const Xbyak::Xmm &x, const Xbyak::Address &addr) { - movdqu(x, addr); - } - void uni_vmovdqu(const Xbyak::Ymm &x, const Xbyak::Address &addr) { - vmovdqu(x, addr); - } - void uni_vmovdqu(const Xbyak::Zmm &x, const Xbyak::Address &addr) { - vmovdqu32(x, addr); - } - - void uni_vmovups(const Xbyak::Address &addr, const Xbyak::Xmm &x) { - movups(addr, x); - } - void uni_vmovups(const Xbyak::Address &addr, const Xbyak::Ymm &x) { - vmovups(addr, x); - } - - void uni_vmovups(const Xbyak::Xmm &x, const Xbyak::Operand &op) { - movups(x, op); - } - void uni_vmovups(const Xbyak::Ymm &x, const Xbyak::Operand &op) { - vmovups(x, op); - } - - void uni_vmovntps(const Xbyak::Address &addr, const Xbyak::Xmm &x) { - movntps(addr, x); - } - void uni_vmovntps(const Xbyak::Address &addr, const Xbyak::Ymm &x) { - vmovntps(addr, x); - } - - void uni_vbroadcastss(const Xbyak::Xmm &x, const Xbyak::Operand &op) { - movss(x, op); - shufps(x, x, 0x0); - } - void uni_vbroadcastss(const Xbyak::Ymm &x, const Xbyak::Operand &op) { - if (op.isMEM() || mayiuse(avx2)) { - vbroadcastss(x, op); - } else { - Xbyak::Xmm t(x.getIdx()); - if (t.getIdx() != op.getIdx()) 
movss(t, op); - vinsertf128(x, x, t, 1); - vshufps(x, x, x, 0); - } - } - - void uni_vpbroadcastd(const Xbyak::Xmm &x, const Xbyak::Operand &op) { - movsd(x, op); - pshufd(x, x, 0x0); - } - void uni_vpbroadcastd(const Xbyak::Ymm &x, const Xbyak::Operand &op) { - if (mayiuse(avx2)) { - vpbroadcastd(x, op); - } else { - Xbyak::Xmm t(x.getIdx()); - if (t.getIdx() != op.getIdx()) movsd(t, op); - vinsertf128(x, x, t, 1); - vshufps(x, x, x, 0); - } - } - - void uni_vrcpss(const Xbyak::Xmm &x, const Xbyak::Operand &op) { - rcpss(x, op); - } - void uni_vrcpss(const Xbyak::Ymm &x1, const Xbyak::Xmm &x2) { - Xbyak::Xmm x1_(x1.getIdx()); - Xbyak::Xmm x2_(x2.getIdx()); - vrcpss(x1_, x1_, x2_); - } - void uni_vrcpss(const Xbyak::Ymm &x, const Xbyak::Address &op) { - Xbyak::Xmm x_(x.getIdx()); - vrcpss(x_, x_, op); - } - - void uni_vrcpps(const Xbyak::Xmm &x, const Xbyak::Operand &op) { - rcpps(x, op); - } - void uni_vrcpps(const Xbyak::Ymm &x, const Xbyak::Operand &op) { - vrcpps(x, op); - } - void uni_vrcpps(const Xbyak::Zmm &x, const Xbyak::Operand &op) { - vrcp14ps(x, op); - } - - void uni_vdivps(const Xbyak::Xmm &x, const Xbyak::Operand &op1, - const Xbyak::Operand &op2 = Xbyak::Operand()) { - assert(x.getIdx() == op1.getIdx()); - divps(x, op2); - } - void uni_vdivps(const Xbyak::Ymm &x, const Xbyak::Operand &op1, - const Xbyak::Operand &op2 = Xbyak::Operand()) { - vdivps(x, op1, op2); - } - - void uni_vdivps(const Xbyak::Xmm &x, const Xbyak::Operand &op1, - const Xbyak::Operand &op2, const Xbyak::Xmm &buf) { - movups(buf, op1); - divps(buf, op2); - if (x.getIdx() != buf.getIdx()) { - movups(x, buf); - } - } - - void uni_vdivps(const Xbyak::Ymm &x, const Xbyak::Operand &op1, - const Xbyak::Operand &op2, const Xbyak::Ymm &buf) { - vdivps(x, op1, op2); - } - - void uni_vaddps(const Xbyak::Xmm &x, const Xbyak::Operand &op1, - const Xbyak::Operand &op2 = Xbyak::Operand()) { - assert(x.getIdx() == op1.getIdx()); - addps(x, op2); - } - void uni_vaddps(const Xbyak::Ymm &x, const Xbyak::Operand &op1, - const Xbyak::Operand &op2 = Xbyak::Operand()) { - vaddps(x, op1, op2); - } - - void uni_vpsignd(const Xbyak::Xmm& x1, const Xbyak::Xmm& x2, - const Xbyak::Operand& op) { - assert(x1.getIdx() == x2.getIdx()); - psignd(x1, op); - } - void uni_vpsignd(const Xbyak::Ymm& x1, const Xbyak::Ymm& x2, - const Xbyak::Operand& op) { - vpsignd(x1, x2, op); - } - - void uni_vsubps(const Xbyak::Xmm &x, const Xbyak::Operand &op1, - const Xbyak::Operand &op2 = Xbyak::Operand()) { - assert(x.getIdx() == op1.getIdx()); - subps(x, op2); - } - void uni_vsubps(const Xbyak::Ymm &x, const Xbyak::Operand &op1, - const Xbyak::Operand &op2 = Xbyak::Operand()) { - vsubps(x, op1, op2); - } - - void uni_vsubps(const Xbyak::Xmm &x, const Xbyak::Operand &op1, - const Xbyak::Operand &op2, const Xbyak::Xmm &buf) { - movups(buf, op1); - subps(buf, op2); - if (x.getIdx() != buf.getIdx()) { - movups(x, buf); - } - } - - void uni_vsubps(const Xbyak::Ymm &x, const Xbyak::Operand &op1, - const Xbyak::Operand &op2, const Xbyak::Ymm &buf) { - vsubps(x, op1, op2); - } - - void uni_vmulps(const Xbyak::Xmm &x, const Xbyak::Operand &op1, - const Xbyak::Operand &op2 = Xbyak::Operand()) { - assert(x.getIdx() == op1.getIdx()); - mulps(x, op2); - } - void uni_vmulps(const Xbyak::Ymm &x, const Xbyak::Operand &op1, - const Xbyak::Operand &op2 = Xbyak::Operand()) { - vmulps(x, op1, op2); - } - - void uni_vfmadd213ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, - const Xbyak::Operand &op) { - mulps(x1, x2); - addps(x1, op); - } - void uni_vfmadd213ps(const 
Xbyak::Ymm &x1, const Xbyak::Ymm &x2, - const Xbyak::Operand &op) { - vfmadd213ps(x1, x2, op); - } - - void uni_vfmadd231ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, - const Xbyak::Operand &op) { - mulps(x2, op); - addps(x1, x2); - } - void uni_vfmadd231ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, - const Xbyak::Operand &op) { - vfmadd231ps(x1, x2, op); - } - - void uni_vfnmadd231ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, - const Xbyak::Operand &op) { - mulps(x2, op); - subps(x1, x2); - } - - void uni_vfnmadd231ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, - const Xbyak::Operand &op) { - vfnmadd231ps(x1, x2, op); - } - - void uni_vsqrtps(const Xbyak::Xmm &x, const Xbyak::Operand &op) { - sqrtps(x, op); - } - void uni_vsqrtps(const Xbyak::Ymm &x, const Xbyak::Operand &op) { - vsqrtps(x, op); - } - - void uni_vpaddd(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, - const Xbyak::Operand &op) { - assert(x1.getIdx() == x2.getIdx()); - paddd(x2, op); - } - void uni_vpaddd(const Xbyak::Ymm &x1, const Xbyak::Xmm &x2, - const Xbyak::Operand &op) { - vpaddd(x1, x2, op); - } - - void uni_vandps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, - const Xbyak::Operand &op = Xbyak::Operand()) { - assert(x1.getIdx() == x2.getIdx()); - andps(x1, op); - } - void uni_vandps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, - const Xbyak::Operand &op = Xbyak::Operand()) { - if (!mayiuse(avx512_common) || x1.getBit() < 512) - vandps(x1, x2, op); - else - vpandd(x1, x2, op); - } - - void uni_vorps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, - const Xbyak::Operand &op = Xbyak::Operand()) { - assert(x1.getIdx() == x2.getIdx()); - orps(x1, op); - } - void uni_vorps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, - const Xbyak::Operand &op = Xbyak::Operand()) { - if (!mayiuse(avx512_common) || x1.getBit() < 512) - vorps(x1, x2, op); - else - vpord(x1, x2, op); - } - - void uni_vpslld(const Xbyak::Xmm &x, const Xbyak::Operand &op, - const int imm) { - assert(x.getIdx() == op.getIdx()); - pslld(x, imm); - } - void uni_vpslld(const Xbyak::Ymm &x, const Xbyak::Operand &op, - const int imm) { - vpslld(x, op, imm); - } - - void uni_vpsrld(const Xbyak::Xmm &x, const Xbyak::Operand &op, - const int imm) { - assert(x.getIdx() == op.getIdx()); - psrld(x, imm); - } - void uni_vpsrld(const Xbyak::Ymm &x, const Xbyak::Operand &op, - const int imm) { - vpsrld(x, op, imm); - } - - void uni_vmaxps(const Xbyak::Xmm &x, const Xbyak::Operand &op1, - const Xbyak::Operand &op2 = Xbyak::Operand()) { - assert(x.getIdx() == op1.getIdx()); - maxps(x, op2); - } - void uni_vmaxps(const Xbyak::Ymm &x, const Xbyak::Operand &op1, - const Xbyak::Operand &op2 = Xbyak::Operand()) { - vmaxps(x, op1, op2); - } - - void uni_vminps(const Xbyak::Xmm &x, const Xbyak::Operand &op1, - const Xbyak::Operand &op2 = Xbyak::Operand()) { - assert(x.getIdx() == op1.getIdx()); - minps(x, op2); - } - void uni_vminps(const Xbyak::Ymm &x, const Xbyak::Operand &op1, - const Xbyak::Operand &op2 = Xbyak::Operand()) { - vminps(x, op1, op2); - } - - void uni_vcmpgtps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, - const Xbyak::Operand &op) { - assert(x1.getIdx() == x2.getIdx()); - cmpps(x1, op, _cmp_nle_us); - } - - void uni_vcmpgtps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, - const Xbyak::Operand &op) { - vcmpgtps(x1, x2, op); - } - - void uni_vcmpgeps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, - const Xbyak::Operand &op) { - assert(x1.getIdx() == x2.getIdx()); - cmpps(x1, op, _cmp_nlt_us); - } - - void uni_vcmpgeps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, - const 
Xbyak::Operand &op) { - vcmpps(x1, x2, op, _cmp_nlt_us); - } - - void uni_vtestps(const Xbyak::Xmm &x1, const Xbyak::Operand &op) { - ptest(x1, op); - } - - void uni_vtestps(const Xbyak::Ymm &x1, const Xbyak::Operand &op) { - assert(!(x1.isZMM() || op.isZMM())); - vtestps(x1, op); - } - - void uni_vblendvps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, - const Xbyak::Operand &op, const Xbyak::Xmm &msk) { - assert(x1.getIdx() == x2.getIdx()); - assert(msk.getIdx() == 0); - blendvps(x1, op); - } - void uni_vblendvps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, - const Xbyak::Operand &op, const Xbyak::Ymm &msk) { - vblendvps(x1, x2, op, msk); - } - - void uni_vroundps(const Xbyak::Xmm &x, const Xbyak::Operand &op, - const int imm) { - roundps(x, op, imm); - } - void uni_vroundps(const Xbyak::Ymm &x, const Xbyak::Operand &op, - const int imm) { - vroundps(x, op, imm); - } - - void uni_vcvtps2dq(const Xbyak::Xmm &x, const Xbyak::Operand &op) { - cvtps2dq(x, op); - } - void uni_vcvtps2dq(const Xbyak::Ymm &x, const Xbyak::Operand &op) { - vcvtps2dq(x, op); - } - - void uni_vcvtdq2ps(const Xbyak::Xmm &x, const Xbyak::Operand &op) { - cvtdq2ps(x, op); - } - void uni_vcvtdq2ps(const Xbyak::Ymm &x, const Xbyak::Operand &op) { - vcvtdq2ps(x, op); - } - - void uni_vmovmskps(const Xbyak::Reg &x1, const Xbyak::Xmm &x2) { - movmskps(x1.cvt64(), x2); - } - void uni_vmovmskps(const Xbyak::Reg &x1, const Xbyak::Ymm &x2) { - vmovmskps(x1, x2); - } - - void uni_vpackssdw(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op){ - assert(x1.getIdx() == x1.getIdx()); - packssdw(x1, op); - } - void uni_vpackssdw(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op){ - vpackssdw(x1, x2, op); - } - - void uni_vpackuswb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op){ - assert(x1.getIdx() == x1.getIdx()); - packuswb(x1, op); - } - void uni_vpackuswb(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op){ - vpackuswb(x1, x2, op); - } - - - void mul_by_const(const Xbyak::Reg &out, - const Xbyak::Reg64 &tmp, int value) { - // Generates a shift + add sequence for multiplicating contents of the - // out register by a known JIT-time value. Clobbers the tmp register. - // - // Pros compared to mul/imul: - // - does not require using known registers - // - not microcoded on Intel(R) Xeon Phi(TM) processors - // Still, there are probably a lot of cases when mul/imul is faster on - // Intel(R) Core(TM) processors. Not intended for critical path. 
- - // TODO: detect when overflow is emminent (Roma) - // TODO: detect when using mul/imul is a better option (Roma) - - int p = 0; // the current power of 2 - int old_p = 0; // the last seen power of 2 such that value[old_p] != 0 - - xor_(tmp, tmp); - while (value) { - if (value & 1) { - int shift = p - old_p; - if (shift) { - shl(out, shift); - old_p = p; - } - add(tmp, out); - } - value >>= 1; - p++; - } - mov(out, tmp); - } - -public: - jit_generator( - void *code_ptr = nullptr, - size_t code_size = 256 * 1024 - ) : Xbyak::CodeGenerator(code_size, code_ptr) - { - } - virtual ~jit_generator() {} - - virtual const char *name() const = 0; - virtual const char *source_file() const = 0; - - const Xbyak::uint8 *getCode() { - const Xbyak::uint8 *code = CodeGenerator::getCode(); - size_t code_size = getSize(); - jit_utils::register_jit_code(code, code_size, name(), source_file()); - return code; - } - - template const F getCode() { - return (const F)getCode(); - } -}; - -} -} -} - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_primitive_conf.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_primitive_conf.hpp deleted file mode 100644 index 56d7f592e..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_primitive_conf.hpp +++ /dev/null @@ -1,481 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef JIT_PRIMITIVE_CONF_HPP -#define JIT_PRIMITIVE_CONF_HPP - -#include - -#include "common/primitive_attr.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -/* convolution */ -enum conv_version_t {ver_unused, ver_fma, ver_avx512_core, ver_4fma, ver_vnni}; -enum conv_loop_order_t {loop_cgn, loop_gnc, loop_ngc, loop_gncw, loop_cwgn, - loop_ngcw, loop_nhwcg, loop_nwcg}; -enum conv_1x1_loop_order_t {loop_rbl, loop_rlb, loop_lbr, loop_lrb, loop_blr, - loop_brl}; -enum conv_kernel_kind_t {embd_bcast, expl_bcast}; - -enum { - FLAG_MB_FIRST = 1 << 0, FLAG_MB_LAST = 1 << 1, - FLAG_OC_FIRST = 1 << 2, FLAG_OC_LAST = 1 << 3, - FLAG_IC_FIRST = 1 << 4, FLAG_IC_LAST = 1 << 5, - FLAG_SP_FIRST = 1 << 6, FLAG_SP_LAST = 1 << 7, - FLAG_REDUCE_FIRST = 1<<8, FLAG_REDUCE_LAST = 1<<9, - FLAG_ZERO_FILTER = 1 << 0, /* Controls whether the inner kernel skips - loading weights-data from memory; this - needs to happen on the first Group/16 - iteration. 
*/ - FLAG_ZERO_BIAS = 1 << 1, /* Controls whether the inner kernel skip - loading bias data from memory */ - FLAG_COMPUTE_BIAS = 1 << 2, /* Controls bias computation during execution - pass */ -}; - -struct jit_conv_conf_t { - prop_kind_t prop_kind; - conv_version_t ver; - conv_loop_order_t loop_order; - - int simd_w; - int ndims; - int mb; - int ngroups, ic, oc, oc_without_padding, ic_without_padding; - int id, ih, iw, od, oh, ow; - int f_pad, l_pad, t_pad; - int back_pad, r_pad, b_pad; - int kd, kh, kw; - int stride_d, stride_h, stride_w; - int dilate_d, dilate_h, dilate_w; - format_tag_t src_tag, wei_tag, dst_tag; // temporary workaround - bool with_bias; - bool with_sum; - bool with_eltwise; - - post_ops_t::entry_t::eltwise_t eltwise; - - int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b; - - int idp, ihp, iwp, ohp, owp; - int nb_ic, ic_block; - int nb_oc, oc_block; - int nb_ow, ow_block; - int nb_oc_blocking; /* used in jit kernels for nb_oc work bloking taking - into account vector registers distribution */ - int nb_oc_blocking_thr_chunk; /* used for distibution of nb_oc work - within threads */ - int nb_ic_blocking, nb_ic_blocking_max; // blocking of nb_ic work - int nb_ic_L2; - int h_blocking; - int nb_oc_L2; - int ur_h, ur_w; - int ur_w_tail; - bool is_1stconv; - int nonblk_group_off; - /* fma avx512_core */ - conv_kernel_kind_t kernel_kind; - /* 4fma */ - int tr_iw; - int tr_src_num_guard_elems; - /* 1st conv: 4fma */ - int tr_ld; - int kh_step; - /* 4vnni */ - int typesize_in; - int typesize_out; - int typesize_bia; - int typesize_acc; - /* avx512_u8s8u8 */ - int ic_nb1, ic_nb2; - int oc_nb1; - int ur_ow_max, ur_ow, ur_ow_tail; - int ur_ow_nsteps; - data_type_t bia_dt; - data_type_t dst_dt; - /* avx512: max possible value is nregs(32) - aux_regs(4) */ - int src_offsets[28]; - int src_count; - bool expl_bcast; - bool large_spatial; - int is_oc_scale; - int max_regs_ur; // maximum accumulation registers - // dw conv - int nb_ch, ch_block, nb_ch_blocking; - bool is_depthwise, is_fast_depthwise, is_resrc_depthwise; - int aligned_threads; - // large spatial - int oh_blk_size; - // s8s8 convolution - bool signed_input; - float wei_adj_scale; -}; - -struct jit_conv_conf_2x3_wino_t { - conv_version_t ver; - - int m; - int r; - int alpha; - int tile_h, tile_w; - - int mb; - int ngroups, ic, oc, oc_without_padding; - int ih, iw, oh, ow; - int l_pad, t_pad; - int r_pad, b_pad; - int kh, kw; - int stride_h, stride_w; - int dilate_h, dilate_w; - - int nb_ic, ic_block; - int nb_oc, oc_block; - - int w_block_size, h_block_size; - - data_type_t bia_dt; - data_type_t dst_dt; - - int is_oc_scale; - int typesize_in; - int typesize_out; - int typesize_bia; - int typesize_acc; - - format_tag_t src_tag, dst_tag; // temporary workaround - bool with_bias; - bool small_mb; - - int xb, yb; - int inp_stride; - int out_stride; - int wei_stride; - int bia_stride; - - int M, N, K; - int m_block, n_block, k_block; - int n2_block, n_chunks; - int k2_block, k_chunks; - - int mb_block, nb_mb; - - size_t size_wino_src, size_wino_wei, size_wino_dst; - - int nthr; -}; - -/* - Winograd sched policy: - - Computation Unit: - W: weights transform - S: src transform - D: dst transform - G: gemm - - Thread grouping by: - i: nb_ic - o: nb_oc - t: tile_block - e: element in tile - - Note: 'i' and 'o' are omited if - i. not comblined with t or - ii. 
with discrete transforms - - Current policies supported: -*/ -enum winograd_sched_t { - WSCHED_INVALID = 0, - - /* Forward & backward-data */ - /* W_S_G_D implements discrete transforms */ - WSCHED_DATA_W_S_G_D, - /* W_SGD implements tiled transforms s.t. GEMM could reuse data in L2*/ - WSCHED_DATA_W_SGD, - - /* Backward-weights */ - WSCHED_WEI_S_D_G_W, - WSCHED_WEI_SDGtWo, - WSCHED_WEI_S_D_Giot_W, - WSCHED_WEI_SDGt_W, -}; - -struct jit_conv_winograd_conf_t : public jit_conv_conf_t { - int itiles; - int jtiles; - int ntiles; - int ic_simd_block=16; - int tile_4fma_padding; - int tile_4fma; - int oc_simd_block=16; - int oc_reg_block; - int ic_reg_block; - int tile_block; - int tile_block_ur; - int nb_tile_block_ur; - - bool double_buffering; - bool with_relu_postsum; - int zmm_start; - int nb_reg; - - int dimK; - int dimK_4fma; - int dimK_reg_block; - int dimK_block; - int dimK_nb_block; - - int dimM; - int dimM_reg_block; - int dimM_simd_block; - int dimM_block; - int dimM_nb_block; - - int dimN; - int dimN_reg_block; - int dimN_bcast_ur; - int dimN_block; - int dimN_nb_block; - - winograd_sched_t sched_policy; -}; - -struct jit_conv_call_s { - const void *src; /* hack, non-const for backward_data */ - const void *dst; /* hack, non-const for forward */ - const void *filt; /* hack, non-const for backward_weights */ - const void *bias; /* hack, non-const for backward_bias */ - const void *src_prf; - const void *dst_prf; - const void *filt_prf; - const void *bias_prf; - const void *scales; - const void *acc_s32; - const void *compensation; - size_t kd_offset; - size_t kd_offset_prf; - size_t d_index; - size_t d_index_prf; - size_t d_worksize; - size_t d_worksize_prf; - size_t kd_padding; - size_t kd_padding_prf; - size_t kh_padding; - size_t kh_padding_prf; - size_t owb; - size_t owb_prf; - size_t kw_padding; - size_t channel; - size_t channel_prf; - size_t oc_blocks; - size_t ur_w; - size_t ur_str_w; - size_t ch_blocks; - size_t t_overflow; - size_t b_overflow; - int flags; -}; - -struct jit_deconv_call_s { - const void *src; /* hack, non-const for backward_data */ - const void *dst; /* hack, non-const for forward */ - const void *filt; /* hack, non-const for backward_weights */ - const void *bias; /* hack, non-const for backward_bias */ - const void *scales; - const void *compensation; - size_t t_overflow; - size_t b_overflow; - size_t kh_padding; - size_t oc_blocks; -}; - -struct jit_dw_conv_call_s { - const void *input; - const void *output; - const void *filter; - const void *bias; - size_t kh_count; - size_t oh_count; - size_t oh_index; - size_t filter_pad_off; - unsigned char - exec_flags; /* Flags passed by driver execution to inner kernel */ -}; - -struct jit_wino_transform_call_s { - size_t tile_block; - size_t tile_block_ur; - size_t nb_tile_block_ur; - size_t tile_count; - size_t tj; - size_t ti; - void *src; - void *dst; - void *Mw; - void *M; - void *T; - void *G; - void *bias; -}; - -struct jit_1x1_conv_conf_t { - prop_kind_t prop_kind; - conv_version_t ver; - - int mb; - int ngroups, ic, oc, oc_without_padding, ic_without_padding; - int iw, ih, ow, oh; - int l_pad, t_pad; - int kh, kw; - int stride_h, stride_w; - format_tag_t src_tag, wei_tag, dst_tag; // temporary workaround - bool with_bias; - bool with_sum; - bool with_eltwise; - - post_ops_t::entry_t::eltwise_t eltwise; - - int is, os; - int ic_block, oc_block; - - int ur, ur_tail; - - int reduce_dim, reduce_block, nb_reduce, - nb_reduce_blocking, nb_reduce_blocking_max; - int load_dim, load_block, nb_load, - 
nb_load_blocking, nb_load_blocking_max, nb_load_chunk; - int bcast_dim, bcast_block, nb_bcast, - nb_bcast_blocking, nb_bcast_blocking_max; - - int reduce_loop_unroll, reduce_loop_bcast_step, reduce_loop_load_step; - int load_loop_load_step, load_loop_iter_step; - int bcast_loop_output_step, bcast_loop_output_substep; - int bcast_loop_bcast_step, bcast_loop_bcast_substep; - int fma_step; - int load_grp_count; - conv_1x1_loop_order_t loop_order; - bool use_vmovntps; - /* avx512 core */ - bool expl_bcast; - /* 4vnni */ - int typesize_in; - int typesize_out; - int typesize_bia; - int typesize_acc; - /* 4fma */ - bool transpose_src; - int tr_is; - int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b; - int is_oc_scale; - data_type_t bia_dt; - data_type_t dst_dt; - bool signed_input; - float wei_adj_scale; -}; - -struct jit_gemm_conv_conf_t { - prop_kind_t prop_kind; - - int mb; - int ngroups, ic, oc; - int iw, ih, id, ow, oh, od; - int l_pad, t_pad, f_pad; - int kh, kw, kd; - int stride_h, stride_w, stride_d; - int dilate_h, dilate_w, dilate_d; - bool with_bias; - - int is, os, ks; - int ic_block, oc_block; - - int nthr; - ptrdiff_t im2col_sz; - bool need_wei_reduction; - bool signed_input; - int oh_block; - int ow_block; - bool outer_threading; -}; - -struct jit_1x1_conv_call_s { - const void *bcast_data; - const void *load_data; - const void *output_data; - const void *bias_data; // used in forward and backward_weights only - const void *acc_s32; - const void *scales; - const void *compensation; - - size_t load_dim; - size_t bcast_dim; - size_t reduce_dim; - - size_t output_stride; // used in backward_weights only - - size_t first_last_flag; -}; - -/* pooling */ -struct jit_pool_conf_t { - int ndims; - int mb, c; - int id, ih, iw, od, oh, ow; - int stride_d, stride_h, stride_w; - int kd, kh, kw; - int f_pad, t_pad, l_pad; - alg_kind_t alg; - bool is_training; - bool pad_w_is_null; - bool is_backward; - bool simple_alg; - data_type_t ind_dt; - - int c_block, c_tail, nb_c; - int ur_c, ur_c_tail; - int ur_w; - int ur_w_tail; - size_t tail[4]; - data_type_t src_dt; - data_type_t dst_dt; -}; - -struct jit_pool_call_s { - const float *src; - const float *dst; - const void *indices; - const float *src_prf; - const float *dst_prf; - const void *indices_prf; - size_t oh; - size_t kd_padding; - size_t kh_padding; - size_t kh_padding_shift; - size_t kd_padding_shift; - size_t kw_padding; - const float* init_value; - float ker_area_h; -}; - - -} -} -} - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_conv_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_conv_kernel_f32.cpp deleted file mode 100644 index 94d2101d6..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_conv_kernel_f32.cpp +++ /dev/null @@ -1,677 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#include "c_types_map.hpp" -#include "nstl.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" -#include "cpu_memory.hpp" - -#include "jit_sse42_1x1_conv_kernel_f32.hpp" - -#define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field) - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace mkldnn::impl::format_tag; -using namespace mkldnn::impl::prop_kind; -using namespace mkldnn::impl::utils; - -using namespace Xbyak; - -void jit_sse42_1x1_conv_kernel_f32::generate_bcast_loop(int load_loop_blk) -{ - mov(aux1_reg_bcast_data, reg_bcast_data); - mov(aux_reg_output_data, reg_output_data); - mov(bcast_loop_iter, reg_bcast_loop_work); - - Label bcast_loop; - Label bcast_loop_tail; - - cmp(bcast_loop_iter, jcp.ur); - jl(bcast_loop_tail, T_NEAR); - - L(bcast_loop); { - assert(jcp.bcast_block % jcp.ur == 0); - int num_substeps = jcp.bcast_block / jcp.ur; - assert(num_substeps > 0 && num_substeps < 10); - for (int i = 0; i < num_substeps; i++) { - generate_reduce_loop(load_loop_blk, jcp.ur); - if (i < num_substeps - 1) { - add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep); - add(aux_reg_output_data, jcp.bcast_loop_output_substep); - } else { - add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step - - (num_substeps - 1) * jcp.bcast_loop_bcast_substep); - add(aux_reg_output_data, jcp.bcast_loop_output_step - - (num_substeps - 1) * jcp.bcast_loop_output_substep); - } - } - sub(bcast_loop_iter, jcp.bcast_block); - cmp(bcast_loop_iter, jcp.bcast_block); - jge(bcast_loop, T_NEAR); - } - - L(bcast_loop_tail); - if (jcp.ur_tail) { - Label bcast_loop_tail_out; - cmp(bcast_loop_iter, 0); - jz(bcast_loop_tail_out, T_NEAR); - generate_reduce_loop(load_loop_blk, jcp.ur_tail); - L(bcast_loop_tail_out); - } -} - -void jit_sse42_1x1_conv_kernel_f32::generate_reduce_loop( - int load_loop_blk, int ur) -{ - auto reg_load = [=](int i, int n) { - return Xmm(2*ur * load_loop_blk + 2*i + n + 1); - }; - - auto reg_accum = [=](int i, int j, int n) { - return Xmm(2*j * load_loop_blk + 2*i + n + 1); - }; - - auto bias_ptr = [=](int i, int n) { - return ptr[reg_bias_data + sizeof(float) * jcp.oc_block * i + n*4*sizeof(float)]; - }; - - auto bcast_ptr = [=](int u, int j) { - assert(j < jcp.ur); - assert(u <= jcp.reduce_loop_unroll); - size_t offt; - if (one_of(jcp.prop_kind, - forward_training, forward_inference, backward_data)) { - assert(jcp.reduce_loop_unroll == (jcp.prop_kind == backward_data) - ? jcp.oc_block : jcp.ic_block); - auto height = (jcp.prop_kind == backward_data) ? jcp.os : jcp.is; - offt = (u == jcp.reduce_loop_unroll) - ? 
(height + j) * jcp.reduce_loop_unroll - : j * jcp.reduce_loop_unroll + u; - } else - offt = u * jcp.ic_block + j; - return ptr[aux_reg_bcast_data + sizeof(float) * offt]; - }; - - auto load_ptr = [=](int u, int i, int n) { - size_t offt; - size_t u0 = u % jcp.reduce_loop_unroll; - size_t u1 = u / jcp.reduce_loop_unroll; - switch (jcp.prop_kind) { - case backward_data: - offt = (i * jcp.oc_block + u0) * jcp.ic_block; - break; - case backward_weights: - offt = (i * jcp.os + u0) * jcp.oc_block; - break; - default: - offt = (i * jcp.ic + u0) * jcp.oc_block; - } - return ptr[aux_reg_load_data - + u1 * jcp.reduce_loop_load_step + sizeof(float) * offt + n * 4 * sizeof(float)]; - }; - - auto output_ptr = [=](int i, int j, int n) { - switch (jcp.prop_kind) { - case backward_data: - return ptr[aux_reg_output_data + - (i * jcp.is + j) * jcp.ic_block * sizeof(float) + n * 4 * sizeof(float)]; - case backward_weights: - return ptr[aux_reg_output_data - + (i ? reg_output_stride * i : 0) // TODO: Xbyak should allow 0 scale - + sizeof(float) * jcp.oc_block * j + n * 4 * sizeof(float)]; - default: - return ptr[aux_reg_output_data + - (i * jcp.os + j) * jcp.oc_block * sizeof(float) + n*4*sizeof(float)]; - } - }; - - auto init = [=]() { - Label init_done; - Label init_zero; - - if (jcp.with_bias && one_of(jcp.prop_kind, forward_training, - forward_inference)) { - test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); - jz(init_zero); - - for (int i = 0; i < load_loop_blk; i++) - for (int j = 0; j < ur; ++j) { - movups(reg_accum(i, j, 0), bias_ptr(i, 0)); - movups(reg_accum(i, j, 1), bias_ptr(i, 1)); - } - jmp(init_done); - } - - L(init_zero); - for (int i = 0; i < load_loop_blk; ++i) - for (int j = 0; j < ur; ++j) { - auto r0 = reg_accum(i, j, 0); - auto r1 = reg_accum(i, j, 1); - xorps(r0, r0); - xorps(r1, r1); - } - - L(init_done); - - // load weights - for (int i = 0; i < load_loop_blk; ++i) { - movups(reg_load(i, 0), load_ptr(0, i, 0)); - movups(reg_load(i, 1), load_ptr(0, i, 1)); - } - - movss(reg_bcast, bcast_ptr(0, 0)); - shufps(reg_bcast, reg_bcast, 0); - }; // init() - - auto store = [=]() { - Label store_noadd; - - if (!jcp.with_sum) { - test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); - jnz(store_noadd, T_NEAR); - } - - for (int j = 0; j < ur; ++j) - for (int i = 0; i < load_loop_blk; ++i) { - auto r0 = reg_accum(i, j, 0); - auto r1 = reg_accum(i, j, 1); - addps(r0, output_ptr(i, j, 0)); - addps(r1, output_ptr(i, j, 1)); - } - - L(store_noadd); - - if (jcp.with_eltwise) { - assert(ur * load_loop_blk < 14); - - Label store_norelu; - test(reg_reduce_pos_flag, FLAG_REDUCE_LAST); - jz(store_norelu, T_NEAR); - - eltwise_injector_->compute_vector_range(1, - 2 * ur * load_loop_blk + 1); - - L(store_norelu); - } - - for (int j = 0; j < ur; ++j) - for (int i = 0; i < load_loop_blk; ++i) { - movups(output_ptr(i, j, 0), reg_accum(i, j, 0)); - movups(output_ptr(i, j, 1), reg_accum(i, j, 1)); - } - }; - - auto fma_block = [=](bool last_block) { - for (int u = 0; u < jcp.reduce_loop_unroll; ++u) { - for (int j = 0; j < ur; ++j) { - for (int i = 0; i < load_loop_blk; ++i) { - mulps(reg_load(i, 0), reg_bcast); - mulps(reg_load(i, 1), reg_bcast); - addps(reg_accum(i, j, 0), reg_load(i, 0)); - addps(reg_accum(i, j, 1), reg_load(i, 1)); - - if (j == ur - 1 && !(last_block && u == jcp.reduce_loop_unroll - 1)) { - movups(reg_load(i, 0), load_ptr(u + 1, i, 0)); - movups(reg_load(i, 1), load_ptr(u + 1, i, 1)); - } - } - if (j < ur - 1) { - movss(reg_bcast, bcast_ptr(u, j + 1)); - shufps(reg_bcast, reg_bcast, 0); - } - } // for ur - 
if (!last_block || u < jcp.reduce_loop_unroll - 1) { - movss(reg_bcast, bcast_ptr(u + 1, 0)); - shufps(reg_bcast, reg_bcast, 0); - } - } // for reduce_loop_unroll - }; - - Label reduce_loop; - Label reduce_loop_tail; - - mov(aux_reg_load_data, reg_load_data); - mov(aux_reg_bcast_data, aux1_reg_bcast_data); - - init(); - - mov(reduce_loop_iter, reg_reduce_loop_work); - sub(reduce_loop_iter, jcp.reduce_loop_unroll); - jle(reduce_loop_tail, T_NEAR); - - L(reduce_loop); { - fma_block(false); - add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step); - add(aux_reg_load_data, jcp.reduce_loop_load_step); - sub(reduce_loop_iter, jcp.reduce_loop_unroll); - jg(reduce_loop, T_NEAR); - } - - L(reduce_loop_tail); - fma_block(true); - - store(); -} // reduce_loop() - -void jit_sse42_1x1_conv_kernel_f32::generate_diff_bias_loop(int load_loop_blk) -{ - if (!jcp.with_bias || jcp.prop_kind != backward_weights) - return; - - Label diff_bias_loop, diff_bias_loop_out, diff_bias_init_out; - Label diff_bias_load; - - auto diff_bias_ptr = [=](int i, int n) { - return ptr[reg_diff_bias_data + i * jcp.oc_block * sizeof(float)+ 4*n*sizeof(float)]; - }; - - auto load_ptr = [=](int u, int i, int n) { - return ptr[aux_reg_load_data - + (i * jcp.os + u) * jcp.oc_block * sizeof(float) + 4*n*sizeof(float)]; - }; - - auto diff_bias_reg = [=](int i, int n) { return Xmm(2*i + n + 1); }; - - mov(reg_diff_bias_data, ptr[rsp + reg_diff_bias_data_stack_offt]); - cmp(reg_diff_bias_data, 0); - je(diff_bias_loop_out, T_NEAR); - - test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); - jz(diff_bias_load, T_NEAR); - - for (int i = 0; i < load_loop_blk; ++i) { - auto r0 = diff_bias_reg(i, 0); - auto r1 = diff_bias_reg(i, 1); - xorps(r0, r0); - xorps(r1, r1); - } - jmp(diff_bias_init_out, T_NEAR); - - L(diff_bias_load); - for (int i = 0; i < load_loop_blk; ++i) { - movups(diff_bias_reg(i, 0), diff_bias_ptr(i, 0)); - movups(diff_bias_reg(i, 1), diff_bias_ptr(i, 1)); - } - - L(diff_bias_init_out); - mov(aux_reg_load_data, reg_load_data); - mov(reduce_loop_iter, reg_reduce_loop_work); - L(diff_bias_loop); { - for(int u = 0; u < jcp.reduce_loop_unroll; ++u) - for (int i = 0; i < load_loop_blk; ++i) { - addps(diff_bias_reg(i, 0), load_ptr(u, i, 0)); - addps(diff_bias_reg(i, 1), load_ptr(u, i, 1)); - } - assert(jcp.reduce_dim % jcp.reduce_loop_unroll == 0); - add(aux_reg_load_data, jcp.reduce_loop_load_step); - sub(reduce_loop_iter, jcp.reduce_loop_unroll); - jnz(diff_bias_loop, T_NEAR); - } - - for (int i = 0; i < load_loop_blk; i++) { - movups(diff_bias_ptr(i, 0), diff_bias_reg(i, 0)); - movups(diff_bias_ptr(i, 1), diff_bias_reg(i, 1)); - } - - add(reg_diff_bias_data, load_loop_blk * jcp.oc_block * sizeof(float)); - mov(ptr[rsp + reg_diff_bias_data_stack_offt], reg_diff_bias_data); - - L(diff_bias_loop_out); -} - -void jit_sse42_1x1_conv_kernel_f32::generate() -{ - preamble(); - - mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]); - mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]); - mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]); - if (jcp.with_bias) { - if (jcp.prop_kind == backward_weights) { - sub(rsp, stack_space_needed); - mov(reg_diff_bias_data, ptr[param1 + GET_OFF(bias_data)]); - mov(ptr[rsp + reg_diff_bias_data_stack_offt], reg_diff_bias_data); - } else - mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]); - } - - mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]); - mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]); - mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]); - mov(reg_reduce_pos_flag, 
ptr[param1 + GET_OFF(first_last_flag)]); - if (jcp.prop_kind == backward_weights) - mov(reg_output_stride, ptr[param1 + GET_OFF(output_stride)]); - - auto generate_load_loop_body = [=] (int load_loop_blk) { - generate_bcast_loop(load_loop_blk); - add(reg_load_data, load_loop_blk * jcp.load_loop_load_step); - switch (jcp.prop_kind) { - case forward_training: - case forward_inference: - add(reg_bias_data, load_loop_blk * jcp.oc_block * sizeof(float)); - add(reg_output_data, - load_loop_blk * jcp.os * jcp.oc_block * sizeof(float)); - break; - case backward_data: - add(reg_output_data, - load_loop_blk * jcp.is * jcp.ic_block * sizeof(float)); - break; - case backward_weights: - for (int i = 0; i < load_loop_blk; i++) - add(reg_output_data, reg_output_stride); - break; - default: - assert(!"invalid prop_kind"); - } - sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step); - }; - - Label load_loop_blk_8; - Label load_loop_blk_16; - Label load_loop_blk_24; - Label load_loop_blk_end; - - cmp(reg_load_loop_work, 8); - jle(load_loop_blk_8, T_NEAR); - - cmp(reg_load_loop_work, 32); - je(load_loop_blk_16, T_NEAR); - - cmp(reg_load_loop_work, 16); - jle(load_loop_blk_16, T_NEAR); - - L(load_loop_blk_24); { - generate_diff_bias_loop(3); - generate_load_loop_body(3); - cmp(reg_load_loop_work, 32); - je(load_loop_blk_16); - cmp(reg_load_loop_work, 24); - jge(load_loop_blk_24); - } - - cmp(reg_load_loop_work, 8); - jle(load_loop_blk_8, T_NEAR); - - L(load_loop_blk_16); { - generate_diff_bias_loop(2); - generate_load_loop_body(2); - cmp(reg_load_loop_work, 16); - jge(load_loop_blk_16); - } - - L(load_loop_blk_8); { - cmp(reg_load_loop_work, 0); - je(load_loop_blk_end, T_NEAR); - generate_diff_bias_loop(1); - generate_load_loop_body(1); - } - - L(load_loop_blk_end); - - if (jcp.with_bias && jcp.prop_kind == backward_weights) - add(rsp, stack_space_needed); - - postamble(); - - if (jcp.with_eltwise) - eltwise_injector_->prepare_table(); -} - -bool jit_sse42_1x1_conv_kernel_f32::post_ops_ok( - jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) { - const auto &p = attr.post_ops_; - - auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; - auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; - - switch (p.len_) { - case 0: return true; // no post_ops - case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise - case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise - default: return false; - } - - return false; -} - -status_t jit_sse42_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp, - const convolution_desc_t &cd, const memory_desc_wrapper &src_d, - const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d, - const primitive_attr_t &attr) -{ - if (!mayiuse(sse42)) - return status::unimplemented; - - // TODO (Roma): this code is duplicated from the generic kernel; maybe the - // configuration struct could do some stuff below - const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; - const int ndims = src_d.ndims(); - - jcp.prop_kind = cd.prop_kind; - - jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; - jcp.mb = src_d.dims()[0]; - - jcp.oc = dst_d.dims()[1] / jcp.ngroups; - jcp.ic = src_d.dims()[1] / jcp.ngroups; - - jcp.ih = (ndims == 3) ? 1 : src_d.dims()[2]; - jcp.iw = src_d.dims()[ndims - 1]; - jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[2]; - jcp.ow = dst_d.dims()[ndims - 1]; - - jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + 2]; - jcp.kw = weights_d.dims()[with_groups + ndims - 1]; - - jcp.t_pad = (ndims == 3) ? 
0 : cd.padding[0][0]; - jcp.l_pad = cd.padding[0][ndims - 3]; - - jcp.stride_h = (ndims == 3) ? 1 : cd.strides[0]; - jcp.stride_w = cd.strides[ndims - 3]; - - jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; - - jcp.os = jcp.oh * jcp.ow; - jcp.is = jcp.ih * jcp.iw; - - if (!post_ops_ok(jcp, attr)) - return status::unimplemented; - - const auto &p = attr.post_ops_; - jcp.with_sum = p.find(primitive_kind::sum) != -1; - const int eltwise_ind = p.find(primitive_kind::eltwise); - jcp.with_eltwise = eltwise_ind != -1; - if (jcp.with_eltwise) - jcp.eltwise = p.entry_[eltwise_ind].eltwise; - - const int is_bwd_d = jcp.prop_kind == backward_data; - - format_tag_t dat_tag = ndims == 3 ? nCw8c : nChw8c; - format_tag_t wei_tag = with_groups - ? utils::pick(2 * ndims - 6 + is_bwd_d, gOIw8i8o, gOIw8o8i, gOIhw8i8o, - gOIhw8o8i) - : utils::pick(2 * ndims - 6 + is_bwd_d, OIw8i8o, OIw8o8i, OIhw8i8o, - OIhw8o8i); - - jcp.src_tag = src_d.matches_one_of_tag(dat_tag); - jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); - jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); - - bool args_ok = true - && jcp.ngroups == 1 - && jcp.src_tag == dat_tag - && jcp.wei_tag == wei_tag - && jcp.dst_tag == dat_tag; - if (!args_ok) return status::unimplemented; - - const int simd_w = 4; - jcp.ic_block = jcp.oc_block = simd_w*2; - - args_ok = true - && jcp.oc % jcp.oc_block == 0 - && jcp.ic % jcp.ic_block == 0 - && jcp.t_pad == 0 && jcp.l_pad == 0 - && jcp.stride_w == 1 && jcp.stride_h == 1 // TODO: support some strides - && jcp.kh == 1 && jcp.kw == 1; - if (!args_ok) return status::unimplemented; - - jcp.ur = 1; - - int load_blocking{ 0 }; - int load_blocking_max{ 0 }; - int bcast_blocking{ 0 }; - int bcast_blocking_max{ 0 }; - int reduce_blocking{ 0 }; - - if (one_of(jcp.prop_kind, forward_training, forward_inference)) { - jcp.reduce_dim = jcp.ic; - jcp.reduce_block = jcp.ic_block; - - jcp.load_dim = jcp.oc; - jcp.load_block = jcp.oc_block; - - jcp.bcast_dim = jcp.is; - jcp.bcast_block = jcp.ur; - - jcp.reduce_loop_unroll = jcp.reduce_block; - jcp.reduce_loop_bcast_step - = jcp.reduce_loop_unroll * jcp.is * sizeof(float); - jcp.reduce_loop_load_step - = jcp.reduce_loop_unroll * jcp.oc_block * sizeof(float); - - jcp.bcast_loop_output_step = jcp.ur * jcp.oc_block * sizeof(float); - jcp.bcast_loop_output_substep = -1; // unused - jcp.bcast_loop_bcast_step = jcp.ur * jcp.ic_block * sizeof(float); - jcp.bcast_loop_bcast_substep = -1; // unused - - jcp.load_loop_load_step = jcp.ic * jcp.oc_block * sizeof(float); - jcp.load_loop_iter_step = jcp.oc_block; - - load_blocking = 120; // assumes the kernel is jcp.ur x 3 - load_blocking_max = 144; - bcast_blocking = 128; // affects load balancing across threads - bcast_blocking_max = 192; - reduce_blocking = 128; // affects L1$ utilization - } else if (jcp.prop_kind == backward_data) { - jcp.reduce_dim = jcp.oc; - jcp.reduce_block = jcp.oc_block; - - jcp.load_dim = jcp.ic; - jcp.load_block = jcp.oc_block; - - jcp.bcast_dim = jcp.os; - jcp.bcast_block = jcp.ur; - - jcp.reduce_loop_unroll = jcp.reduce_block; - jcp.reduce_loop_bcast_step - = jcp.reduce_loop_unroll * jcp.os * sizeof(float); - jcp.reduce_loop_load_step - = jcp.reduce_loop_unroll * jcp.ic * sizeof(float); - - jcp.bcast_loop_output_step = jcp.ur * jcp.ic_block * sizeof(float); - jcp.bcast_loop_output_substep = -1; // unused - jcp.bcast_loop_bcast_step = jcp.ur * jcp.oc_block * sizeof(float); - jcp.bcast_loop_bcast_substep = -1; // unused - - jcp.load_loop_load_step = jcp.oc_block * jcp.ic_block * sizeof(float); - 
jcp.load_loop_iter_step = jcp.ic_block; - - load_blocking = 96; // assumes the kernel is jcp.ur x 3 - load_blocking_max = 144; - bcast_blocking = 128; // affects load balancing across threads - bcast_blocking_max = 196; - reduce_blocking = 64; // affects L1$ utilization - } else if (jcp.prop_kind == backward_weights) { - jcp.reduce_dim = jcp.os; - jcp.reduce_block = 1; - - jcp.load_dim = jcp.oc; - jcp.load_block = jcp.oc_block; - - jcp.bcast_dim = jcp.ic; - jcp.bcast_block = jcp.ic_block; - - jcp.reduce_loop_unroll = jcp.reduce_block; - jcp.reduce_loop_bcast_step - = jcp.reduce_loop_unroll * jcp.ic_block * sizeof(float); - jcp.reduce_loop_load_step - = jcp.reduce_loop_unroll * jcp.oc_block * sizeof(float); - - jcp.bcast_loop_output_step = jcp.oc_block * jcp.ic_block * sizeof(float); - jcp.bcast_loop_output_substep = jcp.oc_block * jcp.ur * sizeof(float); - jcp.bcast_loop_bcast_step = jcp.ic_block * jcp.is * sizeof(float); - jcp.bcast_loop_bcast_substep = jcp.ur * sizeof(float); - - jcp.load_loop_load_step = jcp.oc_block * jcp.os * sizeof(float); - jcp.load_loop_iter_step = jcp.oc_block; - - /* --- */ - - load_blocking = div_up(jcp.load_dim, jcp.load_block); - while (true) { - if (load_blocking <= 32) break; - else if (load_blocking % 2 == 0) load_blocking /= 2; - else if (load_blocking % 3 == 0) load_blocking /= 3; - else break; - } - load_blocking *= jcp.load_block; - load_blocking_max = load_blocking; - assert(jcp.load_dim % load_blocking == 0); - - bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block); - while (true) { - if (bcast_blocking <= 9) break; - else if (bcast_blocking % 2 == 0) bcast_blocking /= 2; - else if (bcast_blocking % 3 == 0) bcast_blocking /= 3; - else break; - } - bcast_blocking *= jcp.bcast_block; - bcast_blocking_max = bcast_blocking; - assert(jcp.bcast_dim % bcast_blocking == 0); - - reduce_blocking = 128; // affects L1$ utilization - } else - return status::unimplemented; - - assert(load_blocking); - assert(load_blocking_max); - assert(bcast_blocking); - assert(bcast_blocking_max); - assert(reduce_blocking); - - assert(jcp.bcast_block % jcp.ur == 0); - jcp.ur_tail = jcp.bcast_dim % jcp.ur; - - jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block; - jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block; - jcp.nb_load_blocking = load_blocking / jcp.load_block; - jcp.nb_load_blocking_max = load_blocking_max / jcp.load_block; - jcp.nb_reduce_blocking = reduce_blocking / jcp.reduce_block; - - jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block); - jcp.nb_load = div_up(jcp.load_dim, jcp.load_block); - jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block); - - return status::success; -} - -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_conv_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_conv_kernel_f32.hpp deleted file mode 100644 index b314a5098..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_conv_kernel_f32.hpp +++ /dev/null @@ -1,104 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. 
-* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef JIT_SSE42_1x1_CONV_KERNEL_F32_HPP -#define JIT_SSE42_1x1_CONV_KERNEL_F32_HPP - -#include "c_types_map.hpp" -#include "cpu_memory.hpp" -#include "jit_generator.hpp" -#include "jit_primitive_conf.hpp" -#include "jit_uni_eltwise.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -struct jit_sse42_1x1_conv_kernel_f32: public jit_generator { - jit_sse42_1x1_conv_kernel_f32(jit_1x1_conv_conf_t ajcp, - const primitive_attr_t &attr) - : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr) - { - if (jcp.with_eltwise) - eltwise_injector_ = new jit_uni_eltwise_injector_f32(this, - jcp.eltwise); - - this->generate(); - jit_ker = (void (*)(jit_1x1_conv_call_s *))this->getCode(); - } - - ~jit_sse42_1x1_conv_kernel_f32() { - delete eltwise_injector_; - } - - static bool post_ops_ok(jit_1x1_conv_conf_t &jcp, - const primitive_attr_t &attr); - - static status_t init_conf(jit_1x1_conv_conf_t &jcp, - const convolution_desc_t &cd, - const memory_desc_wrapper &src_d, - const memory_desc_wrapper &weights_d, - const memory_desc_wrapper &dst_d, - const primitive_attr_t &attr); - - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse42_1x1_conv_kernel_f32) - - jit_1x1_conv_conf_t jcp; - const primitive_attr_t &attr_; - void (*jit_ker)(jit_1x1_conv_call_s *); - -private: - using reg64_t = const Xbyak::Reg64; - using xmm_t = const Xbyak::Xmm; - - reg64_t reg_bcast_data = rax; - reg64_t reg_load_data = rsi; - reg64_t reg_output_data = rbx; - reg64_t aux_reg_bcast_data = rdx; - reg64_t aux1_reg_bcast_data = abi_not_param1; - reg64_t aux_reg_load_data = abi_param1; - reg64_t aux_reg_output_data = rbp; - reg64_t reg_load_loop_work = r9; - reg64_t reg_bcast_loop_work = r10; - reg64_t reg_reduce_loop_work = r11; - reg64_t load_loop_iter = r13; - reg64_t imm_addr64 = load_loop_iter; - reg64_t bcast_loop_iter = r14; - reg64_t reduce_loop_iter = r15; - reg64_t reg_reduce_pos_flag = r8; - reg64_t reg_output_stride = r12; - reg64_t reg_bias_data = r12; - reg64_t reg_diff_bias_data = bcast_loop_iter; - - int reg_diff_bias_data_stack_offt = 0; - int stack_space_needed = 8; - - xmm_t reg_bcast = xmm_t(15); - - jit_uni_eltwise_injector_f32 *eltwise_injector_; - - void generate_bcast_loop(int load_loop_blk); - void generate_reduce_loop(int load_loop_blk, int ur); - void generate_diff_bias_loop(int load_loop_blk); - - void generate(); -}; - -} -} -} - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_convolution.cpp deleted file mode 100644 index 30c137641..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_convolution.cpp +++ /dev/null @@ -1,134 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. 
-* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include "mkldnn_types.h" - -#include "c_types_map.hpp" -#include "jit_sse42_1x1_convolution.hpp" -#include "utils.hpp" -#include "mkldnn_thread.hpp" -#include "type_helpers.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -#define data_blk_off(f, n, c, h, w) \ - ((ndims == 3) \ - ? (f).blk_off(n, c, w) \ - : (f).blk_off(n, c, h, w)) - -using namespace mkldnn::impl::status; -using namespace mkldnn::impl::utils; - -void jit_sse42_1x1_convolution_fwd_t::execute_forward( - const exec_ctx_t &ctx) const { - auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); - auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); - auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS); - auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); - - const memory_desc_wrapper src_d(pd()->src_md()); - const memory_desc_wrapper dst_d(pd()->dst_md()); - const memory_desc_wrapper weights_d(pd()->weights_md(0)); - - const auto &jcp = kernel_->jcp; - const int ndims = src_d.ndims(); - - const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast; - - parallel(0, [&](const int ithr, const int nthr) { - // TODO (Roma): remove this restriction - assert(jcp.stride_w == 1 && jcp.stride_h == 1); - - auto par_conv = jit_1x1_conv_call_s(); - - const int nb_oc = jcp.nb_load; - const int nb_ic = jcp.nb_reduce; - const int nb_ic_blocking = jcp.nb_reduce_blocking; - const int os_block = jcp.bcast_block; - - int start{0}, end{0}; - balance211(work_amount, nthr, ithr, start, end); - - int iwork = start; - while (iwork < end) { - int n{0}, g{0}, osb{0}; - nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, - jcp.nb_bcast); - - const int bcast_step_rem = jcp.nb_bcast - osb; - int bcast_step = bcast_step_rem <= jcp.nb_bcast_blocking_max - ? bcast_step_rem : jcp.nb_bcast_blocking; - bcast_step = nstl::min(bcast_step, end - iwork); - - const int os = osb * os_block; - const int ow = os % jcp.ow; - const int oh = os / jcp.ow; - const int iw = nstl::max(ow * jcp.stride_w - jcp.l_pad, 0); - const int ih = nstl::max(oh * jcp.stride_h - jcp.t_pad, 0); - - par_conv.bcast_dim = this_block_size(os, jcp.os, - bcast_step * os_block); - - int ocb = 0; - while (ocb < jcp.nb_load) { - const int load_step_rem = jcp.nb_load - ocb; - const int load_step = load_step_rem < jcp.nb_load_blocking_max - ? 
load_step_rem : jcp.nb_load_blocking; - - const size_t _ocb = g * nb_oc + ocb; - par_conv.load_dim = this_block_size(ocb * jcp.oc_block, jcp.oc, - load_step * jcp.oc_block); - - const size_t dst_off = data_blk_off(dst_d, n, _ocb, oh, ow); - par_conv.output_data = &dst[dst_off]; - - par_conv.bias_data = &bias[_ocb * jcp.oc_block]; - - for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { - par_conv.first_last_flag = 0 - | (icb == 0) * FLAG_REDUCE_FIRST - | (icb + nb_ic_blocking >= nb_ic) * FLAG_REDUCE_LAST; - - par_conv.reduce_dim = this_block_size(icb * jcp.ic_block, - jcp.ic, nb_ic_blocking * jcp.ic_block); - - const size_t _icb = g * nb_ic + icb; - const size_t src_off = data_blk_off(src_d, n, _icb, ih, iw); - par_conv.bcast_data = &src[src_off]; - - par_conv.load_data = &weights[pd()->with_groups() - ? weights_d.blk_off(g, ocb, icb) - : weights_d.blk_off(ocb, icb)]; - - kernel_->jit_ker(&par_conv); - } - - ocb += load_step; - } - - iwork += bcast_step; - } - }); - - if (pd()->wants_zero_pad_dst()) - ctx.memory(MKLDNN_ARG_DST)->zero_pad(); -} - -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_convolution.hpp deleted file mode 100644 index b32b1e478..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_convolution.hpp +++ /dev/null @@ -1,96 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef CPU_JIT_SSE42_1x1_CONVOLUTION_HPP -#define CPU_JIT_SSE42_1x1_CONVOLUTION_HPP - -#include "c_types_map.hpp" -#include "mkldnn_thread.hpp" -#include "utils.hpp" - -#include "cpu_convolution_pd.hpp" -#include "cpu_primitive.hpp" -#include "jit_sse42_1x1_conv_kernel_f32.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -struct jit_sse42_1x1_convolution_fwd_t: public cpu_primitive_t { - struct pd_t: public cpu_convolution_fwd_pd_t { - pd_t(engine_t *engine, - const convolution_desc_t *adesc, - const primitive_attr_t *attr, - const typename pd_t::base_class *hint_fwd_pd) - : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) - , jcp_() {} - - DECLARE_COMMON_PD_T( - JIT_IMPL_NAME_HELPER("jit_1x1:", sse42, ""), - jit_sse42_1x1_convolution_fwd_t); - - status_t init() { - bool ok = true - && is_fwd() - && set_default_alg_kind(alg_kind::convolution_direct) - && expect_data_types(data_type::f32, data_type::f32, - data_type::f32, data_type::f32, data_type::f32) - && !has_zero_dim_memory() - && set_default_formats(); - if (!ok) return status::unimplemented; - - return jit_sse42_1x1_conv_kernel_f32::init_conf(jcp_, *desc(), - *src_md(), *weights_md(), *dst_md(), *attr()); - } - - jit_1x1_conv_conf_t jcp_; - - protected: - bool set_default_formats() { - using namespace format_tag; - - auto dat_tag = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); - auto wei_tag = with_groups() - ? 
utils::pick(ndims() - 3, gOIw8i8o, gOIhw8i8o) - : utils::pick(ndims() - 3, OIw8i8o, OIhw8i8o); - - return set_default_formats_common(dat_tag, wei_tag, dat_tag); - } - }; - - jit_sse42_1x1_convolution_fwd_t(const pd_t *apd): cpu_primitive_t(apd) { - kernel_ = new jit_sse42_1x1_conv_kernel_f32(pd()->jcp_, *pd()->attr()); - } - ~jit_sse42_1x1_convolution_fwd_t() { delete kernel_; }; - - typedef typename prec_traits::type data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_forward(ctx); - return status::success; - } - -private: - void execute_forward(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - jit_sse42_1x1_conv_kernel_f32 *kernel_; -}; - -} -} -} - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.cpp deleted file mode 100644 index 17cabc118..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.cpp +++ /dev/null @@ -1,497 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include "c_types_map.hpp" -#include "nstl.hpp" -#include "type_helpers.hpp" -#include "cpu_memory.hpp" - -#include "jit_sse42_conv_kernel_f32.hpp" - -#define GET_OFF(field) offsetof(jit_conv_call_s, field) - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace mkldnn::impl::format_tag; -using namespace mkldnn::impl::prop_kind; -using namespace mkldnn::impl::utils; - -using namespace Xbyak; - -void jit_sse42_conv_fwd_kernel_f32::oh_step_unroll_kw(int ur_w, - int pad_l, int pad_r, int oc_blocks) -{ - int iw = jcp.iw; - int ih = jcp.ih; - int kw = jcp.kw; - int kh = jcp.kh; - int nb_ic = jcp.nb_ic; - int stride_w = jcp.stride_w; - int dilate_w = jcp.dilate_w + 1; - int ic_blk = jcp.ic_block; - int oc_blk = jcp.oc_block; - - for (int ki = 0; ki < kw; ki++) { - int jj_start = nstl::max(0, div_up(pad_l - ki * dilate_w, stride_w)); - int jj_end = ur_w - - nstl::max(0, div_up(ki*dilate_w + pad_r - (kw-1)*dilate_w, stride_w)); - for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) { - for (int jj = jj_start; jj < jj_end; jj++) { - int inp_off; - if (one_of(jcp.src_tag, ncw, nchw)) - inp_off = ifm2*ih*iw + (ki*dilate_w + jj*stride_w - pad_l); - else - inp_off = (ki*dilate_w + jj*stride_w - pad_l)*ic_blk + ifm2; - - movss(Xmm(oc_blocks * ur_w + jj + 1), - ptr[aux_reg_input + sizeof(float) * inp_off]); - shufps(Xmm(oc_blocks * ur_w + jj + 1), - Xmm(oc_blocks * ur_w + jj + 1), 0x0); - } - - for (int ii = 0; ii < oc_blocks; ii++) { - int ker_off = ii * nb_ic * kh * kw * ic_blk * oc_blk - + ki * ic_blk * oc_blk + ifm2 * oc_blk; - - for (int jj = jj_start; jj < jj_end; jj++) - { - movups(xmm0, - ptr[aux_reg_kernel + sizeof(float) * ker_off]); - mulps(xmm0, Xmm(oc_blocks * ur_w + jj + 1)); - addps(Xmm(ur_w * ii + jj + 1), 
xmm0); - } - } - } - } -} - -void jit_sse42_conv_fwd_kernel_f32::oh_step_nopad(int ur_w, - int pad_l, int pad_r, int oc_blocks) -{ - Label kw_loop; - - int iw = jcp.iw; - int ih = jcp.ih; - int kw = jcp.kw; - int kh = jcp.kh; - int nb_ic = jcp.nb_ic; - int stride_w = jcp.stride_w; - int dilate_w = jcp.dilate_w + 1; - int ic_blk = jcp.ic_block; - int oc_blk = jcp.oc_block; - - xor_(ki_iter, ki_iter); - L(kw_loop); - { - int jj_start = 0; - int jj_end = ur_w; - for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) { - for (int jj = jj_start; jj < jj_end; jj++) { - int inp_off; - if (one_of(jcp.src_tag, ncw, nchw)) - inp_off = ifm2 * ih * iw + (jj * stride_w - pad_l); - else - inp_off = (jj * stride_w - pad_l) * ic_blk + ifm2; - - movss(Xmm(oc_blocks * ur_w + jj + 1), - ptr[aux_reg_input + sizeof(float) * inp_off]); - shufps(Xmm(oc_blocks * ur_w + jj + 1), - Xmm(oc_blocks * ur_w + jj + 1), 0x0); - } - for (int ii = 0; ii < oc_blocks; ii++) { - int aux_kernel_offset = ii * nb_ic * kh * kw * ic_blk * oc_blk - + ifm2 * oc_blk; - for (int jj = jj_start; jj < jj_end; jj++) { - movups(xmm0, - ptr[aux_reg_kernel + sizeof(float) * aux_kernel_offset]); - mulps(xmm0, Xmm(oc_blocks * ur_w + jj + 1)); - addps(Xmm(ur_w * ii + jj + 1), xmm0); - } - } - } - add(aux_reg_kernel, sizeof(float) * oc_blk * ic_blk); - add(aux_reg_input, sizeof(float) * (one_of(jcp.src_tag, ncw, nchw) ? - dilate_w : ic_blk * dilate_w)); - - inc(ki_iter); - cmp(ki_iter, kw); - jl(kw_loop, T_NEAR); - } -} - -void jit_sse42_conv_fwd_kernel_f32::width_blk_step(int ur_w, - int pad_l, int pad_r, int oc_blocks) -{ - int iw = jcp.iw; - int kw = jcp.kw; - int ow = jcp.ow; - int oh = jcp.oh; - int dilate_h = jcp.dilate_h + 1; - int dilate_w = jcp.dilate_w + 1; - int ic_blk = jcp.ic_block; - int oc_blk = jcp.oc_block; - const int inp_mult = one_of(jcp.src_tag, ncw, nchw) - ? dilate_h : ic_blk * dilate_h; - const int inp_off = one_of(jcp.src_tag, ncw, nchw) - ? 
dilate_w : ic_blk * dilate_w; - - xor_(simd_iter, simd_iter); - - mov(aux_reg_input, reg_input); - mov(aux_reg_kernel, reg_kernel); - - Label init_simd_iter_loop; - Label init_done; - Label init_first; - - L(init_simd_iter_loop); - - if (!jcp.with_sum) { - test(reg_ci_flag, FLAG_IC_FIRST); - jne(init_first, T_NEAR); - } - - for (int ii = 0; ii < oc_blocks; ii++) - for (int jj = 0; jj < ur_w; jj++) - movups(Xmm(ur_w * ii + jj + 1), xword[reg_output - + sizeof(float) * (ii * oh * ow + jj) * oc_blk]); - - if (jcp.with_sum && jcp.with_bias) { - test(reg_ci_flag, FLAG_IC_FIRST); - je(init_done, T_NEAR); - - for (int ii = 0; ii < oc_blocks; ii++) - for (int jj = 0; jj < ur_w; jj++) - addps(Xmm(ur_w * ii + jj + 1), - xword[reg_bias + sizeof(float) * ii * oc_blk]); - } - - jmp(init_done); - - L(init_first); - if (this->jcp.with_bias) { - for (int ii = 0; ii < oc_blocks; ii++) - for (int jj = 0; jj < ur_w; jj++) - movups(Xmm(ur_w * ii + jj + 1), - xword[reg_bias + sizeof(float) * ii * oc_blk]); - } else { - for (int ii = 0; ii < oc_blocks; ii++) - for (int jj = 0; jj < ur_w; jj++) - pxor(Xmm(ur_w * ii + jj + 1), Xmm(ur_w * ii + jj + 1)); - } - - L(init_done); - - Label skip_kh_loop; - mov(kj, reg_kh); - if ((jcp.dilate_h >= jcp.ih) - || (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) { - cmp(kj, 0); - je(skip_kh_loop, T_NEAR); - } - Label kh_loop; - L(kh_loop); - { - if (jcp.kw >= 5 && pad_l == 0 && pad_r == 0) { - oh_step_nopad(ur_w, pad_l, pad_r, oc_blocks); - sub(aux_reg_input, sizeof(float) * kw * inp_off); - add(aux_reg_input, sizeof(float) * iw * inp_mult); - } else { - oh_step_unroll_kw(ur_w, pad_l, pad_r, oc_blocks); - add(aux_reg_kernel, sizeof(float) * kw * oc_blk * ic_blk); - add(aux_reg_input, sizeof(float) * iw * inp_mult); - } - - dec(kj); - cmp(kj, 0); - jg(kh_loop, T_NEAR); - } - - L(skip_kh_loop); - - if (jcp.with_eltwise) { - Label regular_store; - test(reg_ci_flag, FLAG_IC_LAST); - je(regular_store, T_NEAR); - - eltwise_injector_->compute_vector_range(1, oc_blocks * ur_w + 1); - - L(regular_store); - } - - for (int ii = 0; ii < oc_blocks; ii++) { - for (int jj = 0; jj < ur_w; jj++) { - const size_t o_off = (ii * oh * ow + jj) * oc_blk; - - Xmm reg_out = Xmm(ur_w * ii + jj + 1); - movups(xword[reg_output + sizeof(float) * o_off], reg_out); - } - } - - mov(aux_reg_kernel, reg_kernel); - mov(aux_reg_input, reg_input); - add(aux_reg_kernel, sizeof(float) * 4); - add(reg_output, sizeof(float) * 4); - add(reg_bias, sizeof(float) * 4); - - inc(simd_iter); - cmp(simd_iter, 2); - jl(init_simd_iter_loop, T_NEAR); - - sub(reg_output, sizeof(float) * 8); - sub(reg_bias, sizeof(float) * 8); -} - -inline void jit_sse42_conv_fwd_kernel_f32::solve_common(int oc_blocks) -{ - int ur_w = jcp.ur_w; - int ur_w_tail = jcp.ur_w_tail; - int n_oi = jcp.ow / ur_w; - int iw = jcp.iw; - int kw = jcp.kw; - int ic_blk = jcp.ic_block; - int oc_blk = jcp.oc_block; - int dilate_w = jcp.dilate_w + 1; - int str_w = jcp.stride_w; - const int inp_mult = one_of(jcp.src_tag, ncw, nchw) ? 
1 : ic_blk; - - int l_pad = jcp.l_pad; - int r_pad = nstl::max(0, (int(jcp.ow) - 1) * str_w + (kw - 1) * dilate_w - - (iw + l_pad - 1)); - int r_pad1 = (ur_w * n_oi - 1) * str_w + (kw - 1) * dilate_w - - (iw + l_pad - 1); - if (r_pad1 > 0) n_oi--; - - if (l_pad > 0) { - n_oi--; - if (n_oi < 0 && r_pad1 > 0) - width_blk_step(ur_w, l_pad, r_pad1, oc_blocks); // "lrpad" - else - width_blk_step(ur_w, l_pad, 0, oc_blocks); // "lpad" - add(reg_input, sizeof(float) * (ur_w * str_w - l_pad) * inp_mult); - add(reg_output, sizeof(float) * ur_w * oc_blk); - } - - Label ow_loop; - xor_(oi_iter, oi_iter); - - if (n_oi > 0) { - L(ow_loop); - - width_blk_step(ur_w, 0, 0, oc_blocks); // "middle" - add(reg_input, sizeof(float) * ur_w * str_w * inp_mult); - add(reg_output, sizeof(float) * ur_w * oc_blk); - - inc(oi_iter); - cmp(oi_iter, n_oi); - jl(ow_loop, T_NEAR); - } - - if (r_pad1 > 0 && n_oi >=0) { - width_blk_step(ur_w, 0, r_pad1, oc_blocks); // "rpad" - add(reg_input, sizeof(float) * ur_w * str_w * inp_mult); - add(reg_output, sizeof(float) * ur_w * oc_blk); - } - - if (ur_w_tail != 0) - width_blk_step(ur_w_tail, 0, r_pad, oc_blocks); // "tail" -} - -void jit_sse42_conv_fwd_kernel_f32::generate() -{ - this->preamble(); - - mov(reg_input, ptr[this->param1 + GET_OFF(src)]); - mov(reg_output, ptr[this->param1 + GET_OFF(dst)]); - mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]); - if (jcp.with_bias) - mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]); - mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]); - mov(reg_ci_flag, ptr[this->param1 + GET_OFF(flags)]); - mov(reg_oc_blocks, ptr[this->param1 + GET_OFF(oc_blocks)]); - - int nb_oc_tail = jcp.nb_oc % jcp.nb_oc_blocking; - Label tail, exit; - - cmp(reg_oc_blocks, jcp.nb_oc_blocking); - jne(nb_oc_tail ? tail : exit, T_NEAR); - - solve_common(jcp.nb_oc_blocking); - jmp(exit, T_NEAR); - - if (nb_oc_tail) { - L(tail); - cmp(reg_oc_blocks, nb_oc_tail); - jne(exit, T_NEAR); - solve_common(nb_oc_tail); - } - - L(exit); - - this->postamble(); - - if (jcp.with_eltwise) - eltwise_injector_->prepare_table(); -} - -bool jit_sse42_conv_fwd_kernel_f32::post_ops_ok( - jit_conv_conf_t &jcp, const primitive_attr_t &attr) { - const auto &p = attr.post_ops_; - - auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; - auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; - - switch (p.len_) { - case 0: return true; // no post_ops - case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise - case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise - default: return false; - } - - return false; -} - -status_t jit_sse42_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp, - const convolution_desc_t &cd, const memory_desc_wrapper &src_d, - const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d, - const primitive_attr_t &attr) -{ - if (!mayiuse(sse42)) return status::unimplemented; - - jcp.prop_kind = cd.prop_kind; - - const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; - const int ndims = src_d.ndims(); - jcp.ndims = ndims; - - jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; - jcp.mb = src_d.dims()[0]; - - jcp.oc = dst_d.dims()[1] / jcp.ngroups; - jcp.ic = src_d.dims()[1] / jcp.ngroups; - - jcp.ih = (ndims == 3) ? 1 : src_d.dims()[2]; - jcp.iw = src_d.dims()[ndims - 1]; - jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[2]; - jcp.ow = dst_d.dims()[ndims - 1]; - - jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + 2]; - jcp.kw = weights_d.dims()[with_groups + ndims - 1]; - - jcp.t_pad = (ndims == 3) ? 
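// ---- Editor's illustrative sketch (not part of the removed sources) ----
// solve_common() above derives the implicit right padding from the usual
// convolution size relation  ow = (iw + l_pad + r_pad - ((kw-1)*dw + 1))/sw + 1,
// rearranged for r_pad and clamped at zero. This restates the r_pad /
// r_pad_no_tail expression used in this file as a standalone helper:
static inline int required_r_pad(int ow, int iw, int kw,
                                 int sw, int dw /* = dilate_w + 1 */, int l_pad) {
    int r = (ow - 1) * sw + (kw - 1) * dw - (iw + l_pad - 1);
    return r > 0 ? r : 0;
}
// e.g. iw = 7, kw = 3, sw = 1, dw = 1, l_pad = 1, ow = 7  ->  r_pad == 1.
// ------------------------------------------------------------------------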
0 : cd.padding[0][0]; - jcp.l_pad = cd.padding[0][ndims - 3]; - - jcp.stride_h = (ndims == 3) ? 1 : cd.strides[0]; - jcp.stride_w = cd.strides[ndims - 3]; - - jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[0]; - jcp.dilate_w = cd.dilates[ndims - 3]; - jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1) - - (jcp.ih + jcp.t_pad - 1); - - if (ndims == 3) { - jcp.src_tag = src_d.matches_one_of_tag(ncw, nwc, nCw8c); - jcp.wei_tag = weights_d.matches_one_of_tag( - Owi8o, gOwi8o, OIw8i8o, gOIw8i8o); - jcp.dst_tag = dst_d.matches_one_of_tag(nCw8c); - } else if (ndims == 4) { - jcp.src_tag = src_d.matches_one_of_tag(nchw, nhwc, nChw8c); - jcp.wei_tag = weights_d.matches_one_of_tag( - Ohwi8o, gOhwi8o, OIhw8i8o, gOIhw8i8o); - jcp.dst_tag = dst_d.matches_one_of_tag(nChw8c); - } - jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; - - if (!post_ops_ok(jcp, attr)) - return status::unimplemented; - - const auto &p = attr.post_ops_; - jcp.with_sum = p.find(primitive_kind::sum) != -1; - const int eltwise_ind = p.find(primitive_kind::eltwise); - jcp.with_eltwise = eltwise_ind != -1; - if (jcp.with_eltwise) - jcp.eltwise = p.entry_[eltwise_ind].eltwise; - - const bool flat = jcp.ic == 3; - const bool mimo = !flat; - - bool args_ok = true - && IMPLICATION(flat, one_of(jcp.src_tag, ncw, nwc, nchw, nhwc) - && one_of(jcp.wei_tag, Owi8o, gOwi8o, Ohwi8o, gOhwi8o)) - && IMPLICATION(mimo, one_of(jcp.src_tag, nCw8c, nChw8c) - && one_of(jcp.wei_tag, OIw8i8o, gOIw8i8o, OIhw8i8o, gOIhw8i8o)) - && one_of(jcp.dst_tag, nCw8c, nChw8c); - if (!args_ok) return status::unimplemented; - - const int simd_w = 8; // 2 SSE vectors processing at once - - jcp.ur_h = 1; /* no code-unrolling by h so far */ - jcp.ur_w = 3; - if (jcp.ow < jcp.ur_w) jcp.ur_w = jcp.ow; - jcp.ur_w_tail = jcp.ow % jcp.ur_w; - - jcp.nb_oc_blocking = 4; /* the optimal value for the kernel */ - - args_ok = true - && jcp.oc % simd_w == 0 - && jcp.l_pad <= jcp.ur_w - && IMPLICATION(jcp.kw > 7, (jcp.t_pad == 0 && jcp.l_pad == 0) - || (jcp.stride_w == 1 && jcp.stride_h == 1)) - && IMPLICATION(mimo, jcp.ic % simd_w == 0); - if (!args_ok) return status::unimplemented; - - int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w - + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1)); - - // kernel needs 1 temporary YMM register - const int num_avail_regs = 15; - if (r_pad_no_tail > jcp.ur_w * jcp.stride_w && jcp.ow / jcp.ur_w > 1) { - /* recalculate ur_w, nb_oc_blocking and ur_w_tail */ - jcp.ur_w = nstl::min(r_pad_no_tail / jcp.stride_w + jcp.ur_w_tail, - nstl::min(jcp.ow, num_avail_regs / 2)); - jcp.nb_oc_blocking = (num_avail_regs - jcp.ur_w) / jcp.ur_w; - jcp.ur_w_tail = jcp.ow % jcp.ur_w; - /* check again ... */ - r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w - + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1)); - if (jcp.ur_w < nstl::max(jcp.l_pad, r_pad_no_tail)) - return status::unimplemented; - } - assert(jcp.nb_oc_blocking > 0); - assert(jcp.ur_w * (jcp.nb_oc_blocking + 1) <= num_avail_regs); - - jcp.ic_block = (jcp.ic % simd_w != 0) ? 
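// ---- Editor's illustrative sketch (not part of the removed sources) ----
// The register budget behind the asserts in init_conf() above: of the 16 XMM
// registers, xmm0 is the multiply temporary, ur_w registers hold broadcast
// inputs, and ur_w * nb_oc_blocking hold accumulators, so
// ur_w * (nb_oc_blocking + 1) must fit in the remaining num_avail_regs = 15.
// A helper that picks the largest blocking for a given unroll width under
// that constraint (equivalent to the (num_avail_regs - ur_w) / ur_w line):
static inline int max_oc_blocking_for(int ur_w, int num_avail_regs = 15) {
    int b = num_avail_regs / ur_w - 1;   // from ur_w * (b + 1) <= num_avail_regs
    return b > 0 ? b : 1;
}
// e.g. ur_w == 3 -> blocking of 4, matching jcp.nb_oc_blocking = 4 above.
// ------------------------------------------------------------------------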
jcp.ic : simd_w; - jcp.nb_ic = jcp.ic / jcp.ic_block; - - jcp.oc_block = simd_w; - jcp.nb_oc = jcp.oc / jcp.oc_block; - - if (one_of(jcp.prop_kind, forward_training, forward_inference)) { - jcp.nb_ic_blocking = 12; - jcp.nb_ic_blocking_max = 16; - } else { - jcp.nb_ic_blocking = 1; - jcp.nb_ic_blocking_max = jcp.nb_ic_blocking; - } - - return status::success; -} - -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.hpp deleted file mode 100644 index 33c26ef08..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.hpp +++ /dev/null @@ -1,93 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef JIT_SSE42_CONV_KERNEL_F32_HPP -#define JIT_SSE42_CONV_KERNEL_F32_HPP - -#include "c_types_map.hpp" -#include "cpu_memory.hpp" -#include "jit_generator.hpp" -#include "jit_primitive_conf.hpp" -#include "jit_uni_eltwise.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -struct jit_sse42_conv_fwd_kernel_f32: public jit_generator { - jit_sse42_conv_fwd_kernel_f32(jit_conv_conf_t ajcp, - const primitive_attr_t &attr) - : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr) - { - if (jcp.with_eltwise) - eltwise_injector_ = new jit_uni_eltwise_injector_f32(this, - jcp.eltwise); - - this->generate(); - jit_ker = (void (*)(jit_conv_call_s *))this->getCode(); - } - - ~jit_sse42_conv_fwd_kernel_f32() { - delete eltwise_injector_; - } - - static bool post_ops_ok(jit_conv_conf_t &jcp, - const primitive_attr_t &attr); - - static status_t init_conf(jit_conv_conf_t &jcp, - const convolution_desc_t &cd, const memory_desc_wrapper &src_d, - const memory_desc_wrapper &weights_d, - const memory_desc_wrapper &dst_d, const primitive_attr_t &attr); - - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse42_conv_fwd_kernel_f32) - jit_conv_conf_t jcp; - const primitive_attr_t &attr_; - void (*jit_ker)(jit_conv_call_s *); - -private: - using reg64_t = const Xbyak::Reg64; - reg64_t reg_input = rax; - reg64_t aux_reg_input = r8; - reg64_t reg_kernel = rdx; - reg64_t aux_reg_kernel = r9; - reg64_t reg_output = rsi; - reg64_t reg_bias = rbx; - - reg64_t kj = r10; - reg64_t oi_iter = r11; - reg64_t ki_iter = r12; - reg64_t reg_kh = abi_not_param1; - reg64_t simd_iter = r15; - reg64_t reg_oc_blocks = r14; - reg64_t imm_addr64 = reg_oc_blocks; - Xbyak::Reg32 reg_ci_flag = r13d; - - jit_uni_eltwise_injector_f32 *eltwise_injector_; - - inline void oh_step_unroll_kw(int ur_w, int pad_l, int pad_r, - int oc_blocks); - inline void oh_step_nopad(int ur_w, int pad_l, int pad_r, int oc_blocks); - inline void width_blk_step(int ur_w, int pad_l, int pad_r, int oc_blocks); - inline void solve_common(int oc_blocks); - - void generate(); -}; - -} -} -} - -#endif diff --git 
a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_convolution.cpp deleted file mode 100644 index 5f77d692f..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_convolution.cpp +++ /dev/null @@ -1,136 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include "mkldnn_types.h" - -#include "c_types_map.hpp" -#include "jit_sse42_convolution.hpp" -#include "mkldnn_thread.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace mkldnn::impl::status; -using namespace mkldnn::impl::utils; - -#define src_blk_off(f, n, c, h, w) \ - (pd()->ndims() == 3) \ - ? (f).blk_off(n, c, w) \ - : (f).blk_off(n, c, h, w) - -#define wht_blk_off_(f, g, ...) \ - pd()->with_groups() \ - ? (f).blk_off(g, __VA_ARGS__) \ - : (f).blk_off(__VA_ARGS__) -#define wht_blk_off(f, g, oc, ic, kh, kw) \ - pd()->ndims() == 3 \ - ? wht_blk_off_(f, g, oc, ic, kw) \ - : wht_blk_off_(f, g, oc, ic, kh, kw) - -void jit_sse42_convolution_fwd_t::execute_forward( - const exec_ctx_t &ctx) const { - auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); - auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); - auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS); - auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); - - const memory_desc_wrapper src_d(pd()->src_md()); - const memory_desc_wrapper dst_d(pd()->dst_md()); - const memory_desc_wrapper weights_d(pd()->weights_md(0)); - const memory_desc_wrapper bias_d(pd()->weights_md(1)); - - const auto &jcp = kernel_->jcp; - - int ocb_work = div_up(jcp.nb_oc, jcp.nb_oc_blocking); - const size_t work_amount = jcp.mb * jcp.ngroups * ocb_work * jcp.oh; - - parallel(0, [&](const int ithr, const int nthr) { - size_t start{ 0 }, end{ 0 }; - balance211(work_amount, nthr, ithr, start, end); - - int icbb = 0; - while (icbb < jcp.nb_ic) { - int icb_step = jcp.nb_ic_blocking; - int icb_step_rem = jcp.nb_ic - icbb; - if (icb_step_rem < jcp.nb_ic_blocking_max) - icb_step = icb_step_rem; - - size_t n{0}, g{0}, ocbb{0}, oh{0}; - nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work, - oh, jcp.oh); - for (size_t iwork = start; iwork < end; ++iwork) { - int ocb = ocbb * jcp.nb_oc_blocking; - int ocb_num = jcp.nb_oc_blocking; - - for (int icb = icbb; icb < icbb + icb_step; ++icb) { - auto par_conv = jit_conv_call_s(); - - const int ij = oh * jcp.stride_h; - const int i_t_overflow = nstl::max(0, jcp.t_pad - ij); - const int i_b_overflow = nstl::max(jcp.ih, ij - + (jcp.kh-1) * (jcp.dilate_h+1) - jcp.t_pad+1) - jcp.ih; - - const size_t _oc = g * jcp.nb_oc + ocb; - const size_t _ic = g * jcp.nb_ic + icb; - - const int ih = nstl::max(ij - jcp.t_pad - + div_up(i_t_overflow, - (jcp.dilate_h+1)) * (jcp.dilate_h + 1), 0); - par_conv.src = &src[src_blk_off(src_d, n, - jcp.ic == 3 ? 
0 : _ic, ih, 0)]; - - par_conv.dst = &dst[src_blk_off(dst_d, n, _oc, oh, 0)]; - - const int wh = div_up(i_t_overflow, (jcp.dilate_h + 1)); - par_conv.filt = &weights[wht_blk_off(weights_d, g, ocb, - jcp.ic == 3 ? 0 : icb, wh, 0)]; - - if (icb == 0) { - if (bias) - par_conv.bias = - &bias[bias_d.blk_off(_oc * jcp.oc_block)]; - par_conv.flags |= FLAG_IC_FIRST; - } - - if (jcp.with_eltwise && icb + 1 == jcp.nb_ic) { - par_conv.flags |= FLAG_IC_LAST; - } - - par_conv.oc_blocks = - nstl::min(ocb + ocb_num, jcp.nb_oc) - ocb; - - par_conv.kw_padding = 0; - const int kh_padding = jcp.kh - - div_up(i_t_overflow, (jcp.dilate_h + 1)) - - div_up(i_b_overflow, (jcp.dilate_h + 1)); - par_conv.kh_padding = nstl::max(0, kh_padding); - kernel_->jit_ker(&par_conv); - } - nd_iterator_step(n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work, - oh, jcp.oh); - } - icbb += icb_step; - } - }); - - if (pd()->wants_zero_pad_dst()) - ctx.memory(MKLDNN_ARG_DST)->zero_pad(); -} - -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_convolution.hpp deleted file mode 100644 index d2f0a38c5..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_convolution.hpp +++ /dev/null @@ -1,103 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef CPU_JIT_SSE42_CONVOLUTION_HPP -#define CPU_JIT_SSE42_CONVOLUTION_HPP - -#include "c_types_map.hpp" -#include "utils.hpp" - -#include "cpu_convolution_pd.hpp" -#include "cpu_primitive.hpp" - -#include "jit_primitive_conf.hpp" -#include "jit_sse42_conv_kernel_f32.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -struct jit_sse42_convolution_fwd_t: public cpu_primitive_t { - struct pd_t: public cpu_convolution_fwd_pd_t { - pd_t(engine_t *engine, - const convolution_desc_t *adesc, - const primitive_attr_t *attr, - const typename pd_t::base_class *hint_fwd_pd) - : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) - , jcp_() {} - - DECLARE_COMMON_PD_T( - JIT_IMPL_NAME_HELPER("jit:", sse42, ""), - jit_sse42_convolution_fwd_t); - - status_t init() { - bool ok = true - && is_fwd() - && set_default_alg_kind(alg_kind::convolution_direct) - && expect_data_types(data_type::f32, data_type::f32, - data_type::f32, data_type::f32, data_type::f32) - && !has_zero_dim_memory() - && set_default_formats(); - if (!ok) return status::unimplemented; - - return jit_sse42_conv_fwd_kernel_f32::init_conf(jcp_, *desc(), - *src_md(), *weights_md(), *dst_md(), *attr()); - } - - jit_conv_conf_t jcp_; - - protected: - bool set_default_formats() { - using namespace format_tag; - - const bool flat = IC() == 3; - auto src_tag = flat - ? 
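// ---- Editor's illustrative sketch (not part of the removed sources) ----
// The driver above clips the kernel height per output row: dilated filter
// rows that fall above the image (i_t_overflow) or below it (i_b_overflow)
// are skipped, and kh_padding is what remains. A scalar restatement of that
// bookkeeping, following the same formulas:
static inline int effective_kh(int oh_idx, int stride_h, int t_pad,
                               int kh, int dilate_h /* jcp.dilate_h */, int ih) {
    int ij = oh_idx * stride_h;
    int dh = dilate_h + 1;
    int i_t_overflow = ij < t_pad ? t_pad - ij : 0;
    int bottom = ij + (kh - 1) * dh - t_pad + 1;        // one past last input row
    int i_b_overflow = bottom > ih ? bottom - ih : 0;
    int rows = kh - (i_t_overflow + dh - 1) / dh - (i_b_overflow + dh - 1) / dh;
    return rows > 0 ? rows : 0;                          // = max(0, kh_padding)
}
// ------------------------------------------------------------------------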
utils::pick(ndims() - 3, ncw, nchw, ncdhw) - : utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); - auto dst_tag = - utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); - auto wei_tag = with_groups() - ? utils::pick(2 * ndims() - 6 + flat, gOIw8i8o, gOwi8o, - gOIhw8i8o, gOhwi8o, gOIdhw8i8o, gOdhwi8o) - : utils::pick(2 * ndims() - 6 + flat, OIw8i8o, Owi8o, - OIhw8i8o, Ohwi8o, OIdhw8i8o, Odhwi8o); - - return set_default_formats_common(src_tag, wei_tag, dst_tag); - } - }; - - jit_sse42_convolution_fwd_t(const pd_t *apd): cpu_primitive_t(apd) - { kernel_ = new jit_sse42_conv_fwd_kernel_f32(pd()->jcp_, *pd()->attr()); } - ~jit_sse42_convolution_fwd_t() { delete kernel_; }; - - typedef typename prec_traits::type data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_forward(ctx); - return status::success; - } - -private: - void execute_forward(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - jit_sse42_conv_fwd_kernel_f32 *kernel_; -}; - -} -} -} - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.cpp deleted file mode 100644 index 0e734f726..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.cpp +++ /dev/null @@ -1,1192 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#include "c_types_map.hpp" -#include "type_helpers.hpp" -#include "nstl.hpp" -#include "utils.hpp" -#include "jit_generator.hpp" -#include "cpu_barrier.hpp" - -#include "jit_transpose_src_utils.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace Xbyak; - -#define GET_OFF(x) offsetof(ctx_t, x) - -struct jit_trans_iw_ic_t: public jit_trans_src_t, public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_trans_iw_ic_t) - - jit_trans_iw_ic_t(const jit_conv_conf_t *conf): jit_trans_src_t(conf) { - generate(); - ker_ = (decltype(ker_))this->getCode(); - } - -private: - using reg64_t = const Xbyak::Reg64; - using reg32_t = const Xbyak::Reg32; - using opmask_t = const Xbyak::Opmask; - - enum { typesize = sizeof(float), transpose_size = 16, small_spatial = 14 }; - int src_stride, tr_src_stride; - int tail; - bool enable_prefetch; - - opmask_t k3333 = k1; - opmask_t k5555 = k2; - opmask_t kAAAA = k3; - opmask_t kCCCC = k4; - opmask_t k0F0F = k5; - opmask_t kF0F0 = k6; - opmask_t kTail = k7; - - reg64_t reg_src = r8; - reg64_t reg_tr_src = r9; - reg64_t reg_src_prf = r10; - reg64_t reg_tr_src_prf = r11; - reg64_t reg_loop = r12; - reg64_t reg_tr_src_tmp = r13; - reg32_t regw_tmp = r14d; - - void transpose(int nrows, int l_pad, int r_pad, bool nontemporal_stores); - void generate(); -}; - -void jit_trans_iw_ic_t::transpose(int nrows, int l_pad, int r_pad, - bool nontemporal_stores) { - assert(nrows >= 0 && nrows <= transpose_size); - static_assert(transpose_size == 16, "Unsupported transpose size"); - if (!nrows) - return; - - auto pf_src_t0 = [=](int i) { - if(enable_prefetch) prefetcht0(EVEX_compress_addr(reg_src, - (transpose_size + i) * src_stride)); - }; - - auto pf_tr_src_t0 = [=](int i) { - int offset = (transpose_size) * typesize + i * tr_src_stride; - if(enable_prefetch) prefetcht0(EVEX_compress_addr(reg_tr_src, offset)); - if(enable_prefetch) prefetcht0(EVEX_compress_addr(reg_tr_src, - offset + 64)); - }; - - auto pf_src_t1 = [=](int i) { - if(enable_prefetch) prefetcht1(EVEX_compress_addr(reg_src_prf, - i * src_stride)); - }; - - auto pf_tr_src_t1 = [=](int i) { - if(enable_prefetch) prefetchwt1(EVEX_compress_addr(reg_tr_src_prf, - i * tr_src_stride)); - }; - - auto src_zmm = [=](int i) { - assert(i >= 0 && i < 16); - return Zmm(i); - }; - - auto tmp_zmm = [=](int i) { - assert(i >= 0 && i < 16); - return Zmm(16 + i); - }; - - auto load = [=](int i) { - vmovups(src_zmm(i), EVEX_compress_addr(reg_src, i * src_stride)); - }; - - auto store = [=](Zmm r, int i) { - auto kmovw = [=](Opmask k, unsigned w) { - mov(regw_tmp, w); - jit_generator::kmovw(k, regw_tmp); - }; - - auto padding = [=] (Reg64 reg, int pad) { - kmovw(kTail, (1 << pad) - 1); - auto k = kTail; - auto base = reg; - base.setOpmaskIdx(k.getIdx(), true); - - auto zmm_zero = r; - vpxord(zmm_zero, zmm_zero, zmm_zero); - auto addr = EVEX_compress_addr(base, i * tr_src_stride); - vmovups(addr, zmm_zero); - }; - - mov(reg_tr_src_tmp, reg_tr_src); - if (l_pad > 0) - add(reg_tr_src_tmp, l_pad * typesize); - - if (tail != transpose_size) - kmovw(kTail, (1 << tail) - 1); - - // Xbyak does not allow k0 to be specified explicitly via the '|' - // operator, so we have to do this via a method call (implicitly - // EVEX encoding uses k0 to mean 'no mask') - bool partial_store = nrows < 16; - auto k = partial_store ? 
kTail : k0; - auto base = reg_tr_src_tmp; - base.setOpmaskIdx(k.getIdx(), true); - - auto addr = EVEX_compress_addr(base, i * tr_src_stride); - if (nontemporal_stores && !partial_store) - vmovntps(addr, r); - else - vmovups(addr, r); - - if (r_pad > 0) { - add(reg_tr_src_tmp, tail * typesize); - padding(reg_tr_src_tmp, r_pad); - } - - if (l_pad > 0) { - padding(reg_tr_src, l_pad); - } - }; - - auto transpose16x8 = [=](int base_idx) { - assert(base_idx == 0 || base_idx == 8); - - // swap 1 - for (int i = 0; i < 4; i++) { - int src_idx0 = base_idx + i * 2; - int src_idx1 = src_idx0 + 1; - - int next_src_idx0 = src_idx0 + 2; - int next_src_idx1 = src_idx1 + 2; - bool load_next = base_idx == 0 || i < 3; - - if (base_idx == 0 && i == 0) { - load(src_idx0); - load(src_idx1); - } - - auto tmp0 = tmp_zmm(src_idx0); - auto tmp1 = tmp_zmm(src_idx1); - auto src0 = src_zmm(src_idx0); - auto src1 = src_zmm(src_idx1); - - if (next_src_idx0 < nrows && load_next) - load(next_src_idx0); - valignd(tmp0, src0, src0, 0x1); - pf_src_t1(base_idx + i); - - if (next_src_idx1 < nrows && load_next) - load(next_src_idx1); - valignd(tmp1, src1, src1, 0xf); - pf_src_t0(base_idx + i); - - vmovaps(src0 | kAAAA, tmp1); - vmovaps(src1 | k5555, tmp0); - } - // swap 2 - for (int i = 0; i < 4; i++) { - int select_half = (i < 2) ? 0 : 2; - int src_idx0 = base_idx + i + select_half + 0; - int src_idx2 = src_idx0 + 2; - - auto tmp0 = tmp_zmm(src_idx0); - auto tmp1 = tmp_zmm(src_idx2); - auto src0 = src_zmm(src_idx0); - auto src2 = src_zmm(src_idx2); - - valignd(tmp0, src0, src0, 0x2); - pf_src_t1(base_idx + 4 + i); - valignd(tmp1, src2, src2, 0xe); - pf_src_t0(base_idx + 4 + i); - vmovaps(src2 | k3333, tmp0); - vmovaps(src0 | kCCCC, tmp1); - } - - // swap 4 - for (int i = 0; i < 4; i++) { - int src_idx0 = base_idx + i; - int src_idx4 = src_idx0 + 4; - - auto tmp0 = tmp_zmm(src_idx0); - auto src0 = src_zmm(src_idx0); - auto src4 = src_zmm(src_idx4); - - vmovaps(tmp0, src0); - vshuff32x4(src0 | kF0F0, src4, src4, 0xb1); - pf_tr_src_t1(base_idx / 2 + i); - vshuff32x4(src4 | k0F0F, tmp0, tmp0, 0xb1); - pf_tr_src_t0(base_idx / 2 + i); - } - }; - - auto fixup16x16 = [=]() { - // swap 8 - for (int i = 0; i < 8; i++) { - auto tmp = tmp_zmm(i); - auto src0 = src_zmm(i); - auto src8 = src_zmm(8 + i); - vshuff64x2(tmp, src0, src8, 0x44); - store(tmp, i); - if (i % 2 == 0) { - pf_tr_src_t1(8 + i / 2); - pf_tr_src_t0(8 + i / 2); - } - } - - for (int i = 0; i < 8; i++) { - auto tmp = tmp_zmm(8 + i); - auto src0 = src_zmm(i); - auto src8 = src_zmm(8 + i); - vshuff64x2(tmp, src0, src8, 0xee); - store(tmp, 8 + i); - if (i % 2 == 0) { - pf_tr_src_t1(12 + i / 2); - pf_tr_src_t0(12 + i / 2); - } - } - }; - - transpose16x8(0); - transpose16x8(8); - fixup16x16(); -} - -void jit_trans_iw_ic_t::generate() { - preamble(); - - const int ic_block = conf_->ic_block; - const int iw = conf_->iw; - const int tr_iw = conf_->tr_iw; - const int transposes = utils::div_up(iw, transpose_size); - int loop_iters = nstl::max(0, transposes - 1); - tail = iw - loop_iters * transpose_size; - - src_stride = ic_block * typesize; - assert(src_stride == 64); - tr_src_stride = tr_iw * typesize; - - bool nontemporal_stores = false; - enable_prefetch = iw > small_spatial ? 
1 : 0; - - assert(transpose_size == ic_block); - const int src_step = ic_block * transpose_size * typesize; - const int tr_src_step = ic_block * typesize; - - const int left_pad = conf_->l_pad; - const int right_pad = tr_iw - iw - left_pad; - - mov(reg_src, ptr [param1 + GET_OFF(src)]); - mov(reg_tr_src, ptr [param1 + GET_OFF(tr_src)]); - mov(reg_src_prf, ptr [param1 + GET_OFF(src_prf)]); - mov(reg_tr_src_prf, ptr [param1 + GET_OFF(tr_src_prf)]); - - auto kmovw = [=](Opmask k, unsigned w) { - mov(regw_tmp, w); - jit_generator::kmovw(k, regw_tmp); - }; - - kmovw(k3333, 0x3333); // 0011001100110011 - kmovw(k5555, 0x5555); // 0101010101010101 - kmovw(kAAAA, 0xaaaa); // 1010101010101010 - kmovw(kCCCC, 0xcccc); // 1100110011001100 - kmovw(k0F0F, 0x0f0f); // 0000111100001111 - kmovw(kF0F0, 0xf0f0); // 1111000011110000 - - if (left_pad > 0 && loop_iters > 0) { - loop_iters--; - transpose(transpose_size, left_pad, 0, nontemporal_stores); - add(reg_src, src_step); - add(reg_tr_src, tr_src_step + left_pad * typesize); - add(reg_src_prf, src_step); - add(reg_tr_src_prf, tr_src_step + left_pad * typesize); - } - - if (loop_iters) { - mov(reg_loop, loop_iters); - Label loop; - L(loop); { - transpose(transpose_size, 0, 0, nontemporal_stores); - add(reg_src, src_step); - add(reg_tr_src, tr_src_step); - add(reg_src_prf, src_step); - add(reg_tr_src_prf, tr_src_step); - sub(reg_loop, 1); - jnz(loop); - } - } - if (transposes > 1) - transpose(tail, 0, right_pad, nontemporal_stores); - else - transpose(tail, left_pad, right_pad, nontemporal_stores); - - postamble(); -} - -struct jit_trans_iw_ic_int16_t: public jit_trans_src_t, public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_trans_iw_ic_int16_t) - jit_trans_iw_ic_int16_t(const jit_conv_conf_t *conf): - jit_trans_src_t(conf) { - generate(); - ker_ = (decltype(ker_))this->getCode(); - } - -private: - using reg64_t = const Xbyak::Reg64; - using reg32_t = const Xbyak::Reg32; - using opmask_t = const Xbyak::Opmask; - - enum { typesize = sizeof(int16_t), transpose_size = 16, small_spatial = 14 }; - int src_stride, tr_src_stride; - int tail; - bool enable_prefetch; - - opmask_t kFFFF = k1; - opmask_t k5555 = k2; - opmask_t kAAAA = k3; - opmask_t kAA = k4; - opmask_t k55 = k5; - opmask_t kCC = k6; - opmask_t k33 = k7; - opmask_t kTail = k1; - - reg64_t reg_src = r8; - reg64_t reg_tr_src = r9; - reg64_t reg_src_prf = r10; - reg64_t reg_tr_src_prf = r11; - reg64_t reg_loop = r12; - reg64_t reg_tr_src_tmp = r13; - reg32_t regw_tmp = r14d; - reg64_t imm_addr64 = rbx; - - Xbyak::Zmm vidx1 = zmm31; - Xbyak::Zmm vidx2 = zmm30; - Xbyak::Zmm vidx3 = zmm29; - Xbyak::Zmm vidx4 = zmm28; - Xbyak::Zmm vidx5 = zmm27; - Xbyak::Zmm zmm_tmp = zmm26; - - - void transpose(int nrows, int l_pad, int r_pad, bool nontemporal_stores); - void generate(); -}; - -void jit_trans_iw_ic_int16_t::transpose(int nrows, int l_pad, int r_pad, - bool nontemporal_stores) { - assert(nrows >= 0 && nrows <= transpose_size); - static_assert(transpose_size == 16, "Unsupported transpose size"); - if (!nrows) - return; - - auto src_zmm = [=](int i) { - return Zmm(i); - }; - - auto src_ymm = [=](int i) { - assert(i >= 0 && i < 16); - return Ymm(i); - }; - - auto load_ymm = [=](int i) { - vmovups(src_ymm(i), EVEX_compress_addr(reg_src, i * src_stride)); - }; - - auto kmovw = [=](Opmask k, unsigned w) { - mov(regw_tmp, w); - jit_generator::kmovw(k, regw_tmp); - }; - - auto store = [=](Zmm r, int i) { - - auto padding = [=] (Reg64 reg, int pad) { - kmovw(kTail, (1 << pad) - 1); - auto k = kTail; - auto 
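// ---- Editor's illustrative sketch (not part of the removed sources) ----
// The masked valignd/vshuff32x4/vshuff64x2 sequence above is an in-register
// 16x16 transpose; the opmask constants (0x3333, 0x5555, 0xaaaa, ...) select
// which lanes each blend step overwrites. As a plain scalar reference for
// what one full tile computes (left/right padding handling aside):
static void transpose16x16_ref(const float *src, int src_stride,
                               float *dst, int dst_stride) {
    // Strides are in elements: src row r starts at src + r * src_stride.
    for (int r = 0; r < 16; ++r)
        for (int c = 0; c < 16; ++c)
            dst[c * dst_stride + r] = src[r * src_stride + c];
}
// ------------------------------------------------------------------------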
base = reg; - base.setOpmaskIdx(k.getIdx(), true); - - auto zmm_zero = zmm_tmp; - vpxord(zmm_zero, zmm_zero, zmm_zero); - auto addr = EVEX_compress_addr(base, i * tr_src_stride); - vmovups(addr, zmm_zero); - }; - - int store_tail = (nrows%2) ? nrows+1 : nrows; - - int store_pad = (l_pad%2) ? l_pad/2 + 1 : l_pad/2; - mov(reg_tr_src_tmp, reg_tr_src); - if (l_pad > 0) { - padding(reg_tr_src, store_pad); - add(reg_tr_src_tmp, l_pad * typesize); - } - if (r_pad > 0) { - store_pad = (r_pad%2) ? r_pad/2 + 1 : r_pad/2; - int addr_shift = (r_pad%2) ? 1 : 0; - add(reg_tr_src_tmp, (nrows - addr_shift) * typesize); - padding(reg_tr_src_tmp, store_pad); - } - - mov(reg_tr_src_tmp, reg_tr_src); - add(reg_tr_src_tmp, l_pad * typesize); - - kmovw(kTail, (1 << store_tail/2) - 1); - auto k = kTail; - auto base = reg_tr_src_tmp; - base.setOpmaskIdx(k.getIdx(), true); - - auto addr = EVEX_compress_addr(base, i * tr_src_stride); - vmovups(addr, r); - - }; - - kmovw(kFFFF, 0xffff); - //all loads - for (int i=0; i<16; i++){ - vpxord(src_zmm(i), src_zmm(i), src_zmm(i)); - } - - for (int i = 0; i < nrows/2; i++) { - auto src0 = src_ymm(2*i); - auto src1 = src_ymm(2*i+1); - auto zmm_src0 = src_zmm(2*i); - load_ymm(2*i); - - vpunpcklwd(src1, src0, - EVEX_compress_addr(reg_src, (2*i+1) * src_stride)); - vpunpckhwd(src0, src0, - EVEX_compress_addr(reg_src, (2*i+1) * src_stride)); - vinserti64x4(zmm_src0, zmm_src0, src1, 1); - vpermps(zmm_src0 | kFFFF, vidx4, zmm_src0); - } - - // for odd numbers we need to mix row with zeroes - if (nrows%2) { - int i = nrows-1; - auto src0 = src_ymm(i); - auto src1 = src_ymm(i+1); //zero - - auto zmm_src0 = src_zmm(i); - vpxor(src1, src1, src1); - - load_ymm(i); - vpunpckhwd(src0, src0, src1); - vinserti64x4(zmm_tmp, zmm_tmp, src0, 0); - vpxor(src0, src0, src0); - load_ymm(i); - vpunpcklwd(src1, src0, src1); - vinserti64x4(zmm_tmp, zmm_tmp, src1, 1); - vpxord(zmm_src0, zmm_src0, zmm_src0); - vmovups(zmm_src0, zmm_tmp); - vpermps(zmm_src0 | kFFFF, vidx4, zmm_src0); - } - - // swap 1 - for (int i=0; i<4; i++) { - auto zmm0 = src_zmm(4*i); - auto zmm1 = src_zmm(4*i+2); - auto tmp0 = src_zmm(4*i+1); - auto tmp1 = src_zmm(4*i+3); - - vmovups(tmp0, zmm0); - vmovups(tmp1, zmm1); - - vpermps(tmp0 | kAAAA, vidx3, zmm1); - vpermps(tmp1 | k5555, vidx3, zmm0); - } - // swap 2 - int base_idx; - base_idx=0; - for (int i=0; i<2; i++) { - auto zmm0 = src_zmm(base_idx+2*i+1); - auto zmm1 = src_zmm(base_idx+2*i+5); - - auto tmp0 = src_zmm(base_idx+2*i); - auto tmp1 = src_zmm(base_idx+2*i+4); - - vmovupd(tmp0, zmm0); - vmovupd(tmp1, zmm1); - - vpermpd(tmp0 | kAA, vidx2, zmm1); - vpermpd(tmp1 | k55, vidx2, zmm0); - } - base_idx=8; - for (int i=0; i<2; i++) { - auto zmm0 = src_zmm(base_idx+2*i+1); - auto zmm1 = src_zmm(base_idx+2*i+5); - - auto tmp0 = src_zmm(base_idx+2*i); - auto tmp1 = src_zmm(base_idx+2*i+4); - - vmovupd(tmp0, zmm0); - vmovupd(tmp1, zmm1); - - vpermpd(tmp0 | kAA, vidx2, zmm1); - vpermpd(tmp1 | k55, vidx2, zmm0); - } - - // swap 3 - for (int i=0; i<4; i++) { - auto zmm0 = src_zmm(2*i); - auto zmm1 = src_zmm(2*i+8); - - auto tmp0 = src_zmm(2*i+1); - auto tmp1 = src_zmm(2*i+9); - - vmovupd(tmp0, zmm0); - vmovupd(tmp1, zmm1); - - vpermpd(tmp0 | kCC, vidx1, zmm1); - vpermpd(tmp1 | k33, vidx1, zmm0); - } - - // all stores - for (int i=0; i<8; i++) - vextracti64x4(src_ymm(2*i), src_zmm(2*i+1), 1); - - store(src_zmm(1), 0); - store(src_zmm(0), 1); - store(src_zmm(3), 2); - store(src_zmm(2), 3); - store(src_zmm(9), 4); - store(src_zmm(8), 5); - store(src_zmm(11), 6); - store(src_zmm(10), 7); - 
store(src_zmm(5), 8); - store(src_zmm(4), 9); - store(src_zmm(7), 10); - store(src_zmm(6), 11); - store(src_zmm(13), 12); - store(src_zmm(12), 13); - store(src_zmm(15), 14); - store(src_zmm(14), 15); - -} - -void jit_trans_iw_ic_int16_t::generate() { - preamble(); - - alignas(64) static constexpr const int64_t idx1[8] - = { 2, 3, 0, 1, 6, 7, 4, 5 }; - alignas(64) static constexpr const int64_t idx2[8] - = { 1, 0, 3, 2, 5, 4, 7, 6 }; - alignas(64) static constexpr const int32_t idx3[16] - = { 1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14 }; - alignas(64) static constexpr const int32_t idx4[16] - = { 8, 10, 12, 14, 0, 2, 4, 6, 9, 11, 13, 15, 1, 3, 5, 7 }; - alignas(64) static constexpr const int32_t idx5[16] - = { 8, 10, 12, 14, 0, 2, 4, 6, 9, 11, 13, 15, 1, 3, 5, 7 }; - - const int ic_block = conf_->ic_block; - const int iw = conf_->iw; - const int tr_iw = conf_->tr_iw; - const int transposes = utils::div_up(iw, transpose_size); - int loop_iters = nstl::max(0, transposes - 1); - tail = iw - loop_iters * transpose_size; - - src_stride = ic_block * typesize; - tr_src_stride = tr_iw * typesize; - - bool nontemporal_stores = false; - enable_prefetch = iw > small_spatial ? 1 : 0; - - assert(transpose_size == ic_block); - const int src_step = ic_block * transpose_size * typesize; - const int tr_src_step = ic_block * typesize; - - const int left_pad = conf_->l_pad; - const int right_pad = tr_iw - iw - left_pad; - - mov(reg_src, ptr [param1 + GET_OFF(src)]); - mov(reg_tr_src, ptr [param1 + GET_OFF(tr_src)]); - mov(reg_src_prf, ptr [param1 + GET_OFF(src_prf)]); - mov(reg_tr_src_prf, ptr [param1 + GET_OFF(tr_src_prf)]); - - auto kmovw = [=](Opmask k, unsigned w) { - mov(regw_tmp, w); - jit_generator::kmovw(k, regw_tmp); - }; - - kmovw(kFFFF, 0xffff); - kmovw(k5555, 0x5555); - kmovw(kAAAA, 0xaaaa); - kmovw(kAA, 0xaa); - kmovw(k55, 0x55); - kmovw(kCC, 0xcc); - kmovw(k33, 0x33); - - auto vmovdqa64 = [=](Zmm z, const int64_t *addr) { - mov(imm_addr64, reinterpret_cast(addr)); - jit_generator::vmovdqa64(z, ptr[imm_addr64]); - }; - - auto vmovdqa32 = [=](Zmm z, const int32_t *addr) { - mov(imm_addr64, reinterpret_cast(addr)); - jit_generator::vmovdqa32(z, ptr[imm_addr64]); - }; - - vmovdqa64(vidx1, idx1); - vmovdqa64(vidx2, idx2); - vmovdqa32(vidx3, idx3); - vmovdqa32(vidx4, idx4); - vmovdqa32(vidx5, idx5); - - if (left_pad > 0 && loop_iters > 0) { - loop_iters--; - transpose(transpose_size, left_pad, 0, nontemporal_stores); - add(reg_src, src_step); - add(reg_tr_src, tr_src_step + left_pad * typesize); - add(reg_src_prf, src_step); - add(reg_tr_src_prf, tr_src_step + left_pad * typesize); - } - - if (loop_iters) { - mov(reg_loop, loop_iters); - Label loop; - L(loop); { - transpose(transpose_size, 0, 0, nontemporal_stores); - add(reg_src, src_step); - add(reg_tr_src, tr_src_step); - add(reg_src_prf, src_step); - add(reg_tr_src_prf, tr_src_step); - sub(reg_loop, 1); - jnz(loop); - } - } - if (transposes > 1) - transpose(tail, 0, right_pad, nontemporal_stores); - else - transpose(tail, left_pad, right_pad, nontemporal_stores); - - postamble(); - -} - -struct jit_trans_ow_oc_t: public jit_trans_dst_t, public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_trans_ow_oc_t) - jit_trans_ow_oc_t(const jit_conv_conf_t *conf): jit_trans_dst_t(conf) { - generate(); - ker_ = (decltype(ker_))this->getCode(); - } - -private: - using reg64_t = const Xbyak::Reg64; - using reg32_t = const Xbyak::Reg32; - using opmask_t = const Xbyak::Opmask; - using zmm = const Xbyak::Zmm; - - enum { typesize = sizeof(int16_t), 
transpose_size = 16, small_spatial = 14 }; - int src_stride, tr_src_stride; - int tail; - bool enable_prefetch; - - opmask_t kFF = k1; - - zmm vidx1 = zmm31; - - reg64_t reg_src = r8; - reg64_t reg_tr_src = r9; - reg64_t reg_src_prf = r10; - reg64_t reg_tr_src_prf = r11; - reg64_t reg_loop = r12; - reg64_t reg_tr_src_tmp = r13; - reg32_t regw_tmp = r14d; - reg64_t imm_addr64 = rbx; - - void transpose(int nrows, int l_pad, int r_pad, bool nontemporal_stores); - void generate(); -}; - -void jit_trans_ow_oc_t::transpose(int nrows, int l_pad, int r_pad, - bool nontemporal_stores) { - assert(nrows >= 0 && nrows <= transpose_size); - static_assert(transpose_size == 16, "Unsupported transpose size"); - if (!nrows) - return; - - auto src_zmm = [=](int i) { - return Zmm(i); - }; - - auto src_ymm = [=](int i) { - assert(i >= 0 && i < 16); - return Ymm(i); - }; - - auto load_ymm = [=](int i) { - vmovups(src_ymm(i), EVEX_compress_addr(reg_src, i * src_stride)); - }; - - - auto store = [=](Zmm r, int i) { - auto addr = EVEX_compress_addr(reg_tr_src, i * tr_src_stride); - if (nontemporal_stores) - vmovntps(addr, r); - else - vmovups(addr, r); - }; - - for (int i = 0; i < nrows/2; i++) { - auto src0 = src_ymm(2*i); - auto src1 = src_ymm(2*i+1); - auto zmm_src0 = src_zmm(2*i); - load_ymm(2*i); - vpunpcklwd(src1, src0, - EVEX_compress_addr(reg_src, (2*i+1) * src_stride)); - vpunpckhwd(src0, src0, - EVEX_compress_addr(reg_src, (2*i+1) * src_stride)); - vinserti64x4(zmm_src0, zmm_src0, src1, 1); - vpermpd(zmm_src0 | kFF, vidx1, zmm_src0); - store(zmm_src0, 2*i); - } - if (r_pad > 0) { - auto src0 = src_ymm(nrows-1); - auto src1 = src_ymm(nrows); - auto zmm_src0 = src_zmm(30); - load_ymm(nrows-1); - - vpxor(src1, src1, src1); - vpunpckhwd(src1, src0, src1); - vinserti64x4(zmm_src0, zmm_src0, src1, 0); - vpxor(src1, src1, src1); - vpunpcklwd(src0, src0, src1); - vinserti64x4(zmm_src0, zmm_src0, src0, 1); - vpermpd(zmm_src0 | kFF, vidx1, zmm_src0); - store(zmm_src0, nrows-1); - } -} - -void jit_trans_ow_oc_t::generate() { - preamble(); - - alignas(64) static constexpr const int64_t idx1[8] - = { 4, 5, 0, 1, 6, 7, 2, 3 }; - - const int oc_block = conf_->oc_block; - const int ow = conf_->ow; - const int transposes = utils::div_up(ow, transpose_size); - int loop_iters = nstl::max(0, transposes - 1); - tail = ow - loop_iters * transpose_size; - - src_stride = oc_block * typesize; - tr_src_stride = oc_block * typesize; - - bool nontemporal_stores = false; - enable_prefetch = ow > small_spatial ? 
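// ---- Editor's illustrative sketch (not part of the removed sources) ----
// My reading of the vpunpcklwd/vpunpckhwd/vpermpd sequence in
// jit_trans_ow_oc_t::transpose above: it pairs two consecutive output
// columns per channel (the usual 16-bit pairwise relayout), i.e. for rows
// r0 = src row 2i and r1 = src row 2i+1 of oc_block int16 values each, it
// emits dst[2*c] = r0[c], dst[2*c+1] = r1[c]; the odd-tail path zero-pads
// the missing second row. Scalar reference:
#include <cstdint>
static void pair_rows_ref(const int16_t *r0, const int16_t *r1,
                          int16_t *dst, int oc_block) {
    for (int c = 0; c < oc_block; ++c) {
        dst[2 * c + 0] = r0[c];
        dst[2 * c + 1] = r1 ? r1[c] : int16_t(0);   // r1 == nullptr: zero pad
    }
}
// ------------------------------------------------------------------------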
1 : 0; - - const int src_step = oc_block * transpose_size * typesize; - const int tr_src_step = oc_block * transpose_size * typesize; - const int right_pad = ow % 2; - - mov(reg_src, ptr [param1 + GET_OFF(src)]); - mov(reg_tr_src, ptr [param1 + GET_OFF(tr_src)]); - mov(reg_src_prf, ptr [param1 + GET_OFF(src_prf)]); - mov(reg_tr_src_prf, ptr [param1 + GET_OFF(tr_src_prf)]); - - auto kmovw = [=](Opmask k, unsigned w) { - mov(regw_tmp, w); - jit_generator::kmovw(k, regw_tmp); - }; - - kmovw(kFF, 0xFF); - - auto vmovdqa64 = [=](Zmm z, const int64_t *addr) { - mov(imm_addr64, reinterpret_cast(addr)); - jit_generator::vmovdqa64(z, ptr[imm_addr64]); - }; - - vmovdqa64(vidx1, idx1); - if (loop_iters) { - mov(reg_loop, loop_iters); - Label loop; - L(loop); { - transpose(transpose_size, 0, 0, nontemporal_stores); - add(reg_src, src_step); - add(reg_tr_src, tr_src_step); - add(reg_src_prf, src_step); - add(reg_tr_src_prf, tr_src_step); - sub(reg_loop, 1); - jnz(loop); - } - } - transpose(tail, 0, right_pad, nontemporal_stores); - - postamble(); -} - -struct jit_trans_iw_x4_4x_t: public jit_trans_src_t, public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_trans_iw_x4_4x_t) - - jit_trans_iw_x4_4x_t(const jit_conv_conf_t *conf): jit_trans_src_t(conf) { - generate(); - ker_ = (decltype(ker_))this->getCode(); - } - - void generate(); - enum { typesize = (int)sizeof(float) }; -}; - -/** @brief transposition of the form [:][iw/4][4] -> [:][4][iw/4] - * required for 1st 4fma backward by weights convolution */ -void jit_trans_iw_x4_4x_t::generate() { - using namespace utils; - - /* TODO: put into code */ - static int mask[16] = { - 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, }; - - const auto &c = *conf_; - const int simd_w = cpu_isa_traits::vlen / typesize; - const int niters = c.tr_ld / simd_w; - - assert(niters <= 4); /* [bwd_w:tr_src:r1] */ - - Reg64 reg_ptr_src = r8; - Reg64 reg_ptr_tr_src = r9; - - Reg64 reg_ih = rax; - Reg64 reg_ih_end = rbx; - - Reg64 reg_nthr_oc_b = rsi; - Reg64 reg_ptr_tr_src_bctx = abi_not_param1; - - Reg64 reg_tmp = rdx; - - Zmm vmsk = Zmm(31); - Opmask kmsk = k7; - - auto emit_tr_sync = [&]() { - simple_barrier::generate(*this, reg_ptr_tr_src_bctx, reg_nthr_oc_b); - }; - - auto emit_tr_iw = [&]() { - auto vreg = [](int iter, int i) { - assert(4 * iter + i < 24); - return Zmm(4 * iter + i); - }; - auto vtmp = [](int i) { return Zmm(24 + i); }; - - auto emit_load = [&](int iter) { - for (int i = 0; i < 4; ++i) { - auto v = vreg(iter, i); - const int off = (iter * 4 + i) * simd_w; - - if (off + simd_w <= c.iw) - vmovups(v, ptr[reg_ptr_src + off * typesize]); - else if (off < c.iw) - vmovups(v | kmsk | T_z, ptr[reg_ptr_src + off * typesize]); - else - vpxord(v, v, v); - } - }; - - auto emit_tr = [&](int iter) { - for (int i = 0; i < 4; ++i) - vpermps(vreg(iter, i), vmsk, vreg(iter, i)); - - vshuff32x4(vtmp(0), vreg(iter, 0), vreg(iter, 1), 0x88); - vshuff32x4(vtmp(1), vreg(iter, 0), vreg(iter, 1), 0xdd); - vshuff32x4(vtmp(2), vreg(iter, 2), vreg(iter, 3), 0x88); - vshuff32x4(vtmp(3), vreg(iter, 2), vreg(iter, 3), 0xdd); - - vshuff32x4(vreg(iter, 0), vtmp(0), vtmp(2), 0x88); - vshuff32x4(vreg(iter, 2), vtmp(0), vtmp(2), 0xdd); - vshuff32x4(vreg(iter, 1), vtmp(1), vtmp(3), 0x88); - vshuff32x4(vreg(iter, 3), vtmp(1), vtmp(3), 0xdd); - }; - - auto emit_store = [&]() { - for (int i = 0; i < 4; ++i) { - for (int iter = 0; iter < niters; ++iter) { - const size_t off = i * c.tr_ld + iter * simd_w; - vmovups(ptr[reg_ptr_tr_src + off * typesize], vreg(iter, i)); - } - } - }; - 
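// ---- Editor's illustrative sketch (not part of the removed sources) ----
// The @brief above describes a [:][iw/4][4] -> [:][4][iw/4] relayout used by
// the 4fma backward-by-weights path. A scalar reference of that inner
// transpose for one row, assuming for brevity that iw is a multiple of 4
// (the JIT version also masks the tail):
static void relayout_iw4_ref(const float *src, float *dst, int iw) {
    int quarters = iw / 4;
    for (int q = 0; q < quarters; ++q)
        for (int i = 0; i < 4; ++i)
            dst[i * quarters + q] = src[q * 4 + i];
}
// ------------------------------------------------------------------------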
- for (int iter = 0; iter < niters; ++iter) - emit_load(iter); - - for (int iter = 0; iter < niters; ++iter) - emit_tr(iter); - - emit_store(); - }; - - preamble(); - - mov(reg_ptr_src, ptr[abi_param1 + GET_OFF(src)]); - mov(reg_ptr_tr_src, ptr[abi_param1 + GET_OFF(tr_src)]); - - mov(reg_nthr_oc_b.cvt32(), ptr[abi_param1 + GET_OFF(nthr_oc_b)]); - mov(reg_ih.cvt32(), ptr[abi_param1 + GET_OFF(tr_src_ih_start)]); - mov(reg_ih_end.cvt32(), ptr[abi_param1 + GET_OFF(tr_src_ih_end)]); - mov(reg_ptr_tr_src_bctx, ptr[abi_param1 + GET_OFF(tr_src_bctx)]); - - emit_tr_sync(); - - Label l_ih_loop, l_tr_done; - cmp(reg_ih, reg_ih_end); - je(l_tr_done, T_NEAR); - - mov(reg_tmp, (size_t)&mask[0]); - vmovups(vmsk, ptr[reg_tmp]); - - if (c.iw % simd_w) { - const char load_mask = (1 << (c.iw % simd_w)) - 1; - mov(reg_tmp, load_mask); - kmovw(kmsk, reg_tmp.cvt32()); - } - - /* src += ih_start * c.iw; */ - imul(reg_tmp, reg_ih, c.iw * typesize); - add(reg_ptr_src, reg_tmp); - /* tr_src += ih_start * c.stride_w * c.tr_ld; */ - imul(reg_tmp, reg_ih, c.stride_w * c.tr_ld * typesize); - add(reg_ptr_tr_src, reg_tmp); - - L(l_ih_loop); { - emit_tr_iw(); - - add(reg_ptr_src, c.iw * typesize); - add(reg_ptr_tr_src, c.stride_w * c.tr_ld * typesize); - - inc(reg_ih); - cmp(reg_ih, reg_ih_end); - jl(l_ih_loop, T_NEAR); - } - - L(l_tr_done); - - emit_tr_sync(); - - postamble(); -} - -/* -// ------------------------------------------------- -// jit_transpose4x16_src -// ------------------------------------------------- -*/ - -void jit_transpose4x16_src::transpose(int nrows) -{ - assert(nrows >= 0 && nrows <= transpose_size); - static_assert(transpose_size == 4, "Unsupported transpose size"); - if (!nrows) - return; - - auto pf_src_t0 = [=](int i) { - if (tparams->src_pf0_distance) - prefetcht0(EVEX_compress_addr( - reg_src, (tparams->src_pf0_distance + i) * src_stride)); - }; - - auto pf_tr_src_t0 = [=](int i) { - if (tparams->tr_src_pf0_distance) - prefetcht0(EVEX_compress_addr(reg_tr_src, - (tparams->tr_src_pf0_distance + i) * src_stride)); - }; - - auto pf_src_t1 = [=](int i) { - if (tparams->src_pf1) - prefetcht1(EVEX_compress_addr(reg_src_prf, i * src_stride)); - }; - - auto pf_tr_src_t1 = [=](int i) { - if (tparams->tr_src_pf1) - prefetchwt1(EVEX_compress_addr(reg_tr_src_prf, i * tr_src_stride)); - }; - - auto src_zmm = [=](int i) { - assert(i >= 0 && i < 4); - return Zmm(i); - }; - - auto tmp_zmm = [=](int i) { - assert(i >= 0 && i < 4); - return Zmm(4 + i); - }; - - auto load = [=](int i) { - vmovups(src_zmm(i), EVEX_compress_addr(reg_src, i * src_stride)); - }; - - auto store = [=](Zmm r, int i) { - vmovups(EVEX_compress_addr(reg_tr_src, i * tr_src_stride), r); - }; - - auto tmp0 = tmp_zmm(0); - auto tmp1 = tmp_zmm(1); - auto tmp2 = tmp_zmm(2); - auto tmp3 = tmp_zmm(3); - - auto src0 = src_zmm(0); - auto src1 = src_zmm(1); - auto src2 = src_zmm(2); - auto src3 = src_zmm(3); - for (int i = 0; i < nrows; i++) { - load(i); - } - - for (size_t i = nrows; i < 4; i++) { - vpxord(src_zmm(i), src_zmm(i), src_zmm(i)); - } - - vmovupd(tmp0, src0); - vmovupd(tmp1, src1); - pf_src_t0(0); - vpermpd(tmp0 | kF0, vidx01, src2); - vpermpd(tmp1 | kF0, vidx01, src3); - - valignd(src0, src0, src0, 8); - valignd(src1, src1, src1, 8); - pf_src_t0(1); - vmovupd(tmp2, src0); - vmovupd(tmp3, src1); - pf_src_t0(2); - vpermpd(tmp2 | kF0, vidx10, src2); - vpermpd(tmp3 | kF0, vidx10, src3); - pf_src_t0(3); - - vmovupd(src0, tmp0); - pf_src_t1(0); - vmovupd(src1, tmp2); - pf_src_t1(1); - vmovupd(src2, tmp1); - pf_src_t1(2); - vmovupd(src3, 
tmp3); - pf_src_t1(3); - vpermpd(src0 | kCC, vidx1, tmp1); - vpermpd(src1 | kCC, vidx1, tmp3); - pf_tr_src_t0(0); - vpermpd(src2 | k33, vidx1, tmp0); - vpermpd(src3 | k33, vidx1, tmp2); - pf_tr_src_t0(1); - - vmovupd(tmp0, src0); - vmovupd(tmp1, src2); - pf_tr_src_t0(2); - vmovupd(tmp2, src1); - vmovupd(tmp3, src3); - pf_tr_src_t0(3); - vpermps(tmp0 | kFFFF, vidxP, src0); - pf_tr_src_t1(0); - vpermps(tmp1 | kFFFF, vidxP, src2); - pf_tr_src_t1(1); - vpermps(tmp2 | kFFFF, vidxP, src1); - pf_tr_src_t1(3); - vpermps(tmp3 | kFFFF, vidxP, src3); - pf_tr_src_t1(4); - - store(tmp0, 0); - store(tmp1, 1); - store(tmp2, 2); - store(tmp3, 3); -} - -alignas(64) static constexpr const int64_t idx01[8] - = { 0, 0, 0, 0, 0, 1, 2, 3 }; -alignas(64) static constexpr const int64_t idx10[8] - = { 0, 0, 0, 0, 4, 5, 6, 7 }; -alignas(64) static constexpr const int64_t idx1[8] = { 2, 3, 0, 1, 6, 7, 4, 5 }; -alignas(64) static constexpr const int32_t idxP[16] - = { 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15 }; - -void jit_transpose4x16_src::generate() -{ - preamble(); - - const int ic_block = params->ic_block; - const int is = params->is; - int tail = is % transpose_size; - - src_stride = ic_block * typesize; - assert(src_stride == 64); - tr_src_stride = ic_block * typesize; - - const int src_step = ic_block * transpose_size * typesize; - const int tr_src_step = ic_block * transpose_size * typesize; - -#define GET_TR_OFF(x) offsetof(jit_src_transpose_s, x) - mov(reg_loop, ptr[param1 + GET_TR_OFF(size)]); - mov(reg_src, ptr[param1 + GET_TR_OFF(src)]); - mov(reg_tr_src, ptr[param1 + GET_TR_OFF(tr_src)]); - mov(reg_src_prf, ptr[param1 + GET_TR_OFF(src_prf)]); - mov(reg_tr_src_prf, ptr[param1 + GET_TR_OFF(tr_src_prf)]); -#undef GET_TR_OFF - - auto kmovw = [=](Opmask k, unsigned w) { - mov(regw_tmp, w); - jit_generator::kmovw(k, regw_tmp); - }; - - auto vmovdqa64 = [=](Zmm z, const int64_t *addr) { - mov(imm_addr64, reinterpret_cast(addr)); - jit_generator::vmovdqa64(z, ptr[imm_addr64]); - }; - - auto vmovdqa32 = [=](Zmm z, const int32_t *addr) { - mov(imm_addr64, reinterpret_cast(addr)); - jit_generator::vmovdqa32(z, ptr[imm_addr64]); - }; - - kmovw(kF0, 0xf0); // 11110000 - kmovw(kCC, 0xcc); // 11001100 - kmovw(k33, 0x33); // 00110011 - kmovw(kFFFF, 0xffff); // 1111111111111111 - - vmovdqa64(vidx01, idx01); - vmovdqa64(vidx10, idx10); - vmovdqa64(vidx1, idx1); - vmovdqa32(vidxP, idxP); - - Label loop_label; - Label tail_label; - - cmp(reg_loop, transpose_size); - jl(tail_label, T_NEAR); - - L(loop_label); - { - transpose(transpose_size); - add(reg_src, src_step); - add(reg_tr_src, tr_src_step); - add(reg_src_prf, src_step); - add(reg_tr_src_prf, tr_src_step); - sub(reg_loop, transpose_size); - cmp(reg_loop, transpose_size); - jge(loop_label, T_NEAR); - } - L(tail_label); - transpose(tail); - - postamble(); -} - -jit_trans_src_t *create_trans_src(const jit_conv_conf_t *conf) { - if (conf->ver == ver_4fma && !conf->is_1stconv) - return new jit_trans_iw_ic_t(conf); - if (conf->ver == ver_4fma && conf->is_1stconv) - return new jit_trans_iw_x4_4x_t(conf); - assert(!"unsupported configuration"); - return nullptr; -} - -jit_trans_dst_t *create_trans_dst(const jit_conv_conf_t *conf) { - assert(!"unsupported configuration"); - return nullptr; -} -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.hpp deleted file mode 100644 index 565e97e4f..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.hpp +++ /dev/null 
@@ -1,145 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef CPU_JIT_TRANSPOSE_SRC_HPP -#define CPU_JIT_TRANSPOSE_SRC_HPP - -#include "cpu_barrier.hpp" -#include "jit_primitive_conf.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -struct jit_trans_src_t { - struct ctx_t { - const void *src; - const void *tr_src; - const void *src_prf; - const void *tr_src_prf; - - /* 1st conv 4fma: backward by weights */ - int nthr_oc_b; /* number of threads process given src image */ - int tr_src_ih_start, tr_src_ih_end; /* thread's transposition bounds */ - simple_barrier::ctx_t *tr_src_bctx; /* transposition synchronization */ - }; - - jit_trans_src_t(const jit_conv_conf_t *conf) - : conf_(conf), ker_(nullptr) {} - virtual ~jit_trans_src_t() {} - - void operator()(const ctx_t *ctx) - { assert(ker_); ker_(ctx); } - - const jit_conv_conf_t *conf_; - void (*ker_)(const ctx_t *); -}; - -struct jit_src_transpose_s { - size_t size; - const void *src; - const void *tr_src; - const void *src_prf; - const void *tr_src_prf; -}; - -struct jit_trans_dst_t { - struct ctx_t { - const void *src; - const void *tr_src; - const void *src_prf; - const void *tr_src_prf; - - /* 1st conv 4fma: backward by weights */ - int nthr_oc_b; /* number of threads process given src image */ - int tr_src_ih_start, tr_src_ih_end; /* thread's transposition bounds */ - simple_barrier::ctx_t *tr_src_bctx; /* transposition synchronization */ - }; - - jit_trans_dst_t(const jit_conv_conf_t *conf) - : conf_(conf), ker_(nullptr) {} - virtual ~jit_trans_dst_t() {} - - void operator()(const ctx_t *ctx) - { assert(ker_); ker_(ctx); } - - const jit_conv_conf_t *conf_; - void (*ker_)(const ctx_t *); -}; - -struct jit_transpose4x16_src_t { - int src_pf0_distance; - int tr_src_pf0_distance; - bool src_pf1; - bool tr_src_pf1; -}; - -struct jit_transpose4x16_src : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_transpose4x16_src) - - jit_transpose4x16_src(const jit_1x1_conv_conf_t *aparams, - jit_transpose4x16_src_t *tparams_) - : params(aparams), tparams(tparams_) - { - this->generate(); - jit_ker = (decltype(jit_ker))this->getCode(); - } - - const jit_1x1_conv_conf_t *params; - const jit_transpose4x16_src_t *tparams; - void (*jit_ker)(jit_src_transpose_s *); - - void operator()(jit_src_transpose_s *arg) { jit_ker(arg); } - - static const int transpose_size = 4; -private: - static const int typesize = sizeof(float); - - int src_stride, tr_src_stride; - - Xbyak::Reg64 imm_addr64 = rbx; - - Xbyak::Opmask kF0 = k1; - Xbyak::Opmask kCC = k2; - Xbyak::Opmask k33 = k3; - Xbyak::Opmask kFFFF = k4; - - Xbyak::Zmm vidx01 = zmm31; - Xbyak::Zmm vidx10 = zmm30; - Xbyak::Zmm vidx1 = zmm29; - Xbyak::Zmm vidxP = zmm28; - - Xbyak::Reg64 reg_src = r8; - Xbyak::Reg64 reg_tr_src = r9; - Xbyak::Reg64 reg_src_prf = r10; - Xbyak::Reg64 
reg_tr_src_prf = r11; - Xbyak::Reg64 reg_loop = r12; - Xbyak::Reg64 reg_tr_src_tmp = r13; - Xbyak::Reg32 regw_tmp = r14d; - - void transpose_block(int ur, int nrows); - void transpose(int nrows); - void generate(); -}; - -jit_trans_src_t *create_trans_src(const jit_conv_conf_t *conf); -jit_trans_dst_t *create_trans_dst(const jit_conv_conf_t *conf); - -} -} -} - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_1x1_conv_utils.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_1x1_conv_utils.hpp deleted file mode 100644 index 53313f9f0..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_1x1_conv_utils.hpp +++ /dev/null @@ -1,327 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef JIT_UNI_1x1_CONV_UTILS_HPP -#define JIT_UNI_1x1_CONV_UTILS_HPP - -#include "memory_tracking.hpp" -#include "mkldnn_thread.hpp" -#include "nstl.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -#include "jit_generator.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace mkldnn::impl::utils; - -struct reduce_to_unit_stride_t { - convolution_desc_t conv_d_; - bool reduce_src_; - size_t space_per_thread_; -}; - -/* 1x1-kernel does not support non-unit strides so far, so the idea is: - * - for fwd or bwd_weights: to copy src to a scratch memory (with strides - * equal to 1) and then call the kernel - * - for bwd_data: reduce the problem to the one with unit stride by - * performing computations in a scratch memory (with strides equal to 1) - * and then copy the result to diff_src */ -template -inline void rtus_prepare(conv_pd_t *self, const convolution_desc_t *&conv_d, - const memory_desc_t *&src_d, const memory_desc_t *dst_d) { - const bool is_bwd_data = self->desc()->prop_kind - == prop_kind::backward_data; - - const int ndims = src_d->ndims; - const auto dat_tag = ndims == 3 - ? 
memory_desc_wrapper(dst_d).matches_one_of_tag( - format_tag::nCw8c, format_tag::nCw16c) - : memory_desc_wrapper(dst_d).matches_one_of_tag( - format_tag::nChw8c, format_tag::nChw16c); - - bool rtus_applicable = true - && utils::pick(ndims - 3, - (conv_d->strides[0] != 1 && !one_of(conv_d->src_desc.data_type, - data_type::s32)), - (conv_d->strides[0] != 1 || conv_d->strides[1] != 1)) - && dat_tag != format_tag::undef; - for (int d = 2; d < ndims; ++d) { - /* TODO: relax these conditions (by improving reducer) */ - rtus_applicable = rtus_applicable - && conv_d->padding[0][d - 2] == 0 - && dst_d->dims[d] * conv_d->strides[d - 2] == src_d->dims[d]; - } - - if (rtus_applicable) { - self->rtus_.reduce_src_ = true; - conv_d = &(self->rtus_.conv_d_ = *conv_d); - self->rtus_.conv_d_.strides[0] = 1; - if (ndims == 4) - self->rtus_.conv_d_.strides[1] = 1; - utils::array_set(self->rtus_.conv_d_.padding[0], 0, 2); - if (ndims == 4) - utils::array_set(self->rtus_.conv_d_.padding[1], 0, 2); - const int ic = src_d->dims[1]; - if (is_bwd_data) { - src_d = &(self->rtus_.conv_d_.diff_src_desc = *dst_d); - self->rtus_.conv_d_.diff_src_desc.dims[1] = ic; - memory_desc_wrapper::compute_blocking( - self->rtus_.conv_d_.diff_src_desc, dat_tag); - } else { - data_type_t data_type = self->rtus_.conv_d_.src_desc.data_type; - src_d = &(self->rtus_.conv_d_.src_desc = *dst_d); - self->rtus_.conv_d_.src_desc.dims[1] = ic; - self->rtus_.conv_d_.src_desc.data_type = data_type; - memory_desc_wrapper::compute_blocking( - self->rtus_.conv_d_.src_desc, dat_tag); - } - } -} - -template -inline void rtus_prepare_space_info(conv_pd_t *self, - memory_tracking::registrar_t &scratchpad) { - const auto &jcp = self->jcp_; - - const int max_threads = mkldnn_get_max_threads(); - const size_t factor = utils::pick_by_prop_kind(self->desc()->prop_kind, - jcp.nb_reduce, jcp.nb_load_blocking_max, jcp.nb_bcast_blocking); - size_t typesize = types::data_type_size( - conv_prop_invariant_src_d(self->desc())->data_type); - - self->rtus_.space_per_thread_ = factor * jcp.is * jcp.ic_block; - scratchpad.book(memory_tracking::names::key_conv_rtus_space, - typesize * max_threads * self->rtus_.space_per_thread_); -} - -template -struct rtus_driver_t: public jit_generator { - - struct call_params_t { - const void *ws; /* reduced image (w/ strides = 1) */ - const void *src; /* source image (w/ non-unit strides) */ - size_t icb; - size_t os; - size_t iw_start; - }; - - void (*ker_)(const call_params_t *p); - - DECLARE_CPU_JIT_AUX_FUNCTIONS(rtus_driver_t) - - /* cpu specific part */ - using Vmm = typename utils::conditional::type; - - Xbyak::Reg64 reg_ws = abi_param1; - Xbyak::Reg64 reg_src = abi_not_param1; - Xbyak::Reg64 reg_icb = rdx; - Xbyak::Reg64 reg_os = r11; - Xbyak::Reg64 reg_iw_start = r8; - - Xbyak::Reg64 reg_cur_os = rax; - Xbyak::Reg64 reg_cur_iw = r9; - Xbyak::Reg64 reg_cur_src = r10; - - int iw_, stride_w_; - int src_step_h_, src_step_icb_, ws_step_icb_, vlen_, vlen_shift_; - bool src_to_ws_; - size_t typesize_; - Vmm reg_zero; - Vmm reg_v; - - rtus_driver_t(int iw, int stride_w, int src_step_h, - int src_step_icb, int ws_step_icb, bool src_to_ws, size_t typesize) - : iw_(iw), stride_w_(stride_w), src_step_h_(src_step_h) - , src_step_icb_(src_step_icb), ws_step_icb_(ws_step_icb) - , src_to_ws_(src_to_ws), typesize_(typesize) - { - using namespace Xbyak; - vlen_ = cpu_isa_traits::vlen; - vlen_shift_ = cpu_isa_traits::vlen_shift; - if (typesize_ == 2) { - vlen_ /= 2; - vlen_shift_--; - } - - reg_zero = Vmm(0); - reg_v = Vmm(1); - - generate(); 
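rtus_prepare and rtus_driver_t above implement the strategy spelled out in the comment: when the 1x1 kernel cannot handle non-unit strides, the strided source is first gathered into a dense per-thread scratch buffer so the kernel can run with stride 1, and for backward-data the scratch result is scattered back with the skipped positions zero-filled. A rough scalar sketch of the forward gather for a single row, assuming one value per pixel for brevity (function name is illustrative):

#include <vector>

// Gather every stride_w-th input pixel of one row into a dense buffer,
// so a downstream 1x1 kernel can treat the data as unit-stride.
static std::vector<float> reduce_row_to_unit_stride(
        const std::vector<float> &src_row, int stride_w) {
    std::vector<float> ws;
    for (int iw_pos = 0; iw_pos < (int)src_row.size(); iw_pos += stride_w)
        ws.push_back(src_row[iw_pos]);
    return ws; // holds ow = ceil(iw / stride_w) dense values
}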
- } - - void loop_is() { - using namespace Xbyak; - - mov(reg_cur_src, reg_src); - mov(reg_cur_iw, reg_iw_start); - mov(reg_cur_os, reg_os); - - Label is_loop, skip_h_step; - L(is_loop); - - if (src_to_ws_) { - vmovups(reg_v, ptr[reg_cur_src]); - vmovups(ptr[reg_ws], reg_v); - } else { - vmovups(reg_v, ptr[reg_ws]); - vmovups(ptr[reg_cur_src], reg_v); - for (int w = 1; w < stride_w_; ++w) - vmovups(ptr[reg_cur_src + w * vlen_], reg_zero); - } - - add(reg_ws, vlen_); - - add(reg_cur_iw, stride_w_); - add(reg_cur_src, stride_w_ * vlen_); - - cmp(reg_cur_iw, iw_); - jl(skip_h_step); - /* for 1d convolution the loop over h should be skipped */ - if (src_step_icb_ == iw_) jmp(skip_h_step); - - if (src_to_ws_) { - add(reg_cur_src, (src_step_h_ - iw_) * vlen_); - } else { - Xbyak::Reg64 reg_cur_src_fin = reg_cur_iw; /* just reuse */ - mov(reg_cur_src_fin, reg_cur_src); - add(reg_cur_src_fin, (src_step_h_ - iw_) * vlen_); - Label ih_loop; - L(ih_loop); - - for (int w = 0; w < stride_w_; ++w) - vmovups(ptr[reg_cur_src + w * vlen_], reg_zero); - - add(reg_cur_src, stride_w_ * vlen_); - cmp(reg_cur_src, reg_cur_src_fin); - jl(ih_loop); - } - xor_(reg_cur_iw, reg_cur_iw); - - L(skip_h_step); - - sub(reg_cur_os, vlen_); - jnz(is_loop); - - /* restore dst */ - sub(reg_ws, reg_os); - } - - void generate() { - using namespace Xbyak; - assert(isa == avx2 || isa == avx512_common - || isa == avx512_core || isa == avx512_mic); - -#if defined(_WIN32) - assert(reg_src == abi_not_param1 && abi_not_param1 == rdi); - push(rdi); -#endif - -#define READ_PARAM(what) \ - mov(reg_ ## what, ptr[abi_param1 + offsetof(call_params_t, what)]) - READ_PARAM(src); - READ_PARAM(icb); - READ_PARAM(os); - READ_PARAM(iw_start); - - assert(reg_ws == abi_param1); - READ_PARAM(ws); /* reg_ws should always be read the last */ -#undef READ_PARAM - - shl(reg_os, vlen_shift_); - - if (!src_to_ws_) - uni_vpxor(reg_zero, reg_zero, reg_zero); - - Label icb_loop; - L(icb_loop); - - loop_is(); - - add(reg_ws, ws_step_icb_ * vlen_); - add(reg_src, src_step_icb_ * vlen_); - - dec(reg_icb); - jnz(icb_loop, T_NEAR); - -#if defined(_WIN32) - pop(rdi); -#endif - - uni_vzeroupper(); - ret(); - this->ker_ = reinterpret_cast(const_cast( - this->getCode())); - } -}; - -template -inline void init_rtus_driver(conv_t *self) { - const auto &conf = *self->pd(); - if (!conf.rtus_.reduce_src_) return; - - const auto &cd = *conf.desc(); - const int ndims = conf.ndims(); - const int stride_h = (conf.ndims() == 3) ? 1 : cd.strides[0]; - const int stride_w = cd.strides[ndims - 3]; - - const bool is_bwd_data = cd.prop_kind == prop_kind::backward_data; - const auto &src_d = is_bwd_data ? *conf.diff_src_md() : *conf.src_md(); - - const int ih = ndims == 3 ? 
1 : src_d.dims[2]; - const int iw = src_d.dims[ndims - 1]; - - const int src_step_h = stride_h * iw; - const int src_step_icb = ih * iw; - const int ws_step_icb = conf.jcp_.is; - const bool src_to_ws = !is_bwd_data; - const size_t typesize = types::data_type_size( - conv_prop_invariant_src_d(self->pd()->desc())->data_type); - - self->rtus_driver_ = new rtus_driver_t(iw, stride_w, src_step_h, - src_step_icb, ws_step_icb, src_to_ws, typesize); -} - -inline int best_divider(int value, int min_divider, int max_divider, - bool find_max, int step = 1) -{ - max_divider = nstl::max(1, nstl::min(max_divider, value)); - min_divider = nstl::max(1, nstl::min(min_divider, max_divider)); - - auto loss_ratio = [](int total, int chunk) - { return float(rnd_up(total, chunk) - total) / rnd_up(total, chunk); }; - - float min_loss = FLT_MAX; - int x_divider = max_divider; - for (int divider = max_divider; divider >= min_divider; divider -= step) { - const float loss = loss_ratio(value, divider); - if ((find_max && loss < min_loss) || (!find_max && loss <= min_loss)) { - min_loss = loss; - x_divider = divider; - } - } - return x_divider; -} - -} -} -} - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_batch_normalization.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_batch_normalization.cpp deleted file mode 100644 index 72fe3a810..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_batch_normalization.cpp +++ /dev/null @@ -1,1407 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include - -#include "c_types_map.hpp" -#include "math_utils.hpp" -#include "memory_tracking.hpp" -#include "mkldnn_thread.hpp" -#include "nstl.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -#include "cpu_barrier.hpp" -#include "cpu_batch_normalization_utils.hpp" -#include "jit_generator.hpp" - -#include "jit_uni_batch_normalization.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -namespace { - -using namespace memory_tracking::names; - -using namespace Xbyak; -namespace barrier = simple_barrier; - -typedef float data_t; - -template -struct jit_bnorm_t: public jit_generator { - struct call_params_t { - // keep all sizes at 8 bytes -- jit code expects this - size_t N_ithr, N_nthr; - size_t coff_max, soff_max; - size_t mb_stride_Bc, spat_size, spat_size_loc; - size_t S_s, S_tail; - size_t is_cblk_tail; - data_t chan_size, eps, one; - const data_t *scale_shift; - const data_t *mean, *var; - const data_t *diff_scale_shift; - const data_t *src, *dst; - const data_t *diff_src, *diff_dst; - const data_t *rbuf1, *rbuf2; - const uint8_t *ws; - barrier::ctx_t *barrier; - }; - - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_bnorm_t) - - /* cpu specific part */ - using Vmm = typename utils::conditional3::type; - const AddressFrame &vmmword = (isa == sse42) ? xword : - (isa == avx2) ? 
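best_divider above scores each candidate chunk size by the padding ratio loss_ratio(total, chunk) = (rnd_up(total, chunk) - total) / rnd_up(total, chunk) and keeps the candidate with the smallest loss. A small standalone check of that arithmetic, with hypothetical numbers not taken from the removed code:

#include <cstdio>

static int rnd_up(int a, int b) { return ((a + b - 1) / b) * b; }

static float loss_ratio(int total, int chunk) {
    return float(rnd_up(total, chunk) - total) / rnd_up(total, chunk);
}

int main() {
    // Splitting 100 blocks: a chunk of 12 pads 100 up to 108 (loss ~0.074),
    // while a chunk of 10 divides evenly (loss 0), so best_divider would
    // prefer 10 within [min_divider, max_divider].
    std::printf("chunk 12: %.3f\n", loss_ratio(100, 12)); // 0.074
    std::printf("chunk 10: %.3f\n", loss_ratio(100, 10)); // 0.000
    return 0;
}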
yword : zword; - - const int vlen = isa == sse42 ? 32 : cpu_isa_traits::vlen; - - const batch_normalization_pd_t *bdesc_; - bool is_spatial_thr_; - - void (*ker)(const call_params_t *); - void operator()(const call_params_t *p) { (*ker)(p); } - - Reg64 reg_param = abi_param1; - - Reg64 reg_scale_shift = rbx; - Reg64 reg_rbuf1 = abi_not_param1; - Reg64 reg_rbuf2 = rdx; - - Reg64 reg_mean = rbp; - Reg64 reg_var = reg_param; - Reg64 reg_diff_scale_shift = rax; - - Reg64 reg_coff = r8; - Reg64 reg_coff_max = r9; - Reg64 reg_soff = r10; - Reg64 reg_soff_max = r11; - Reg64 reg_ctr = r12; - Reg64 reg_roff = r13; - - Reg64 reg_mb_stride_Bc = r14; - - Reg64 reg_src = r15; - Reg64 reg_diff_src = reg_rbuf1; - Reg64 reg_dst = rsi; - Reg64 reg_diff_dst = reg_dst; - - Reg64 reg_tmp_off = reg_roff; - - // Reuse loop counters - Reg64 reg_bar = reg_coff; - Reg64 reg_nnthr = reg_soff; // must be usable w/ loops over coff - Reg64 reg_tmp = reg_ctr; - - // Relu section - bool with_relu, with_relu_inf_only; - Vmm vzero; // is_fwd() ? vdiff_beta : vbeta - Reg64 reg_ws = reg_roff; - Label l_relu_mask_avx2; - Opmask kstore_mask = Opmask(1); - - // channel tail processing - Opmask ktail_mask = Opmask(2); - - size_t unroll_blocks; - size_t unroll_regs; - Vmm vbuf = Vmm(isa == avx512_common ? 20 : 5); - Vmm vdiff_beta = Vmm(isa == avx512_common ? 21 : 6); - Vmm vdiff_gamma = Vmm(isa == avx512_common ? 22 : 7); - Vmm vsqrtvar = Vmm(isa == avx512_common ? 23 : 8); - Vmm vone = Vmm(isa == avx512_common ? 24 : 9); - Vmm vmean = Vmm(isa == avx512_common ? 25 : 10); - Vmm vgamma = Vmm(isa == avx512_common ? 26 : 11); - Vmm vbeta = Vmm(isa == avx512_common ? 27 : 12); - Vmm veps = Vmm(isa == avx512_common ? 28 : 13); - Vmm vchan_size = Vmm(isa == avx512_common ? 29 : 14); - Vmm vtail_mask = Vmm(isa == avx512_common ? 
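The comment in call_params_t above ("keep all sizes at 8 bytes -- jit code expects this") reflects that the generated code reads this block through fixed offsetof()-based displacements, so the counters are kept as size_t and the layout must not shift between builds. One way to make that kind of expectation explicit in plain C++ (a hypothetical mirror for illustration, not part of the removed sources, and assuming the targeted 64-bit ABI):

#include <cstddef>
#include <cstdint>

// Hypothetical call-parameter block: every member occupies one 8-byte slot,
// so the displacements baked into hand-written ptr[] loads stay stable.
struct call_params_example_t {
    std::uint64_t n_ithr, n_nthr;
    const float *src, *dst;
    double eps;
};

static_assert(sizeof(std::uint64_t) == 8 && sizeof(void *) == 8
                && sizeof(double) == 8,
        "members must be 8 bytes on the targeted 64-bit ABI");
static_assert(offsetof(call_params_example_t, dst) == 3 * 8,
        "member offsets must match the displacements used by the JIT loads");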
30 : 15); - - size_t t0_pf_offt; - size_t t1_pf_offt; - size_t spat_size; - size_t chan_data_offt; - - enum { - stack_off_N_nthr = 0, - stack_off_N_ithr = 8, - stack_off_src = 16, - stack_off_dst = 24, - stack_off_diff_src = 32, - stack_off_diff_dst = 40, - stack_off_diff_scale_shift = 48, - stack_off_ws = 56, - stack_off_barrier = 64, - stack_off_spat_size_loc = 72, - stack_off_s_s = 80, - stack_off_s_tail = 88, - stack_off_is_cblk_tail = 96, - stack_size_required = 104, - }; - - bool is_c_padded() const { - const memory_desc_wrapper data_d(bdesc_->src_md()); - return bdesc_->C() != data_d.padded_dims()[1]; - } - - void compute_static_strides() { - spat_size = bdesc_->D() * bdesc_->W() * bdesc_->H(); - chan_data_offt = bdesc_->C() * sizeof(data_t); - - if (isa == avx512_mic) { - t0_pf_offt = 4096; - t1_pf_offt = 0; - } else { - t0_pf_offt = 0; - t1_pf_offt = 0; - } - } - - void load_common_params() { -# define PARAM_OFF(x) offsetof(call_params_t, x) - mov(reg_rbuf1, ptr[reg_param + PARAM_OFF(rbuf1)]); - if (bdesc_->is_bwd()) - mov(reg_rbuf2, ptr[reg_param + PARAM_OFF(rbuf2)]); - mov(reg_coff_max, ptr[reg_param + PARAM_OFF(coff_max)]); - mov(reg_soff_max, ptr[reg_param + PARAM_OFF(soff_max)]); - mov(reg_mb_stride_Bc, ptr[reg_param + PARAM_OFF(mb_stride_Bc)]); - shl(reg_coff_max, 2); - shl(reg_soff_max, 2); - shl(reg_mb_stride_Bc, 2); - - mov(reg_mean, ptr[reg_param + PARAM_OFF(mean)]); - mov(reg_scale_shift, ptr[reg_param + PARAM_OFF(scale_shift)]); - - uni_vbroadcastss(vchan_size, vmmword[reg_param + PARAM_OFF(chan_size)]); - uni_vbroadcastss(vone, vmmword[reg_param + PARAM_OFF(one)]); - uni_vbroadcastss(veps, vmmword[reg_param + PARAM_OFF(eps)]); - - mov(reg_tmp, ptr[reg_param + PARAM_OFF(N_nthr)]); - mov(ptr[rsp + stack_off_N_nthr], reg_tmp); - mov(reg_tmp, ptr[reg_param + PARAM_OFF(N_ithr)]); - mov(ptr[rsp + stack_off_N_ithr], reg_tmp); - mov(reg_tmp, ptr[reg_param + PARAM_OFF(src)]); - mov(ptr[rsp + stack_off_src], reg_tmp); - mov(reg_tmp, ptr[reg_param + PARAM_OFF(dst)]); - mov(ptr[rsp + stack_off_dst], reg_tmp); - mov(reg_tmp, ptr[reg_param + PARAM_OFF(diff_src)]); - mov(ptr[rsp + stack_off_diff_src], reg_tmp); - mov(reg_tmp, ptr[reg_param + PARAM_OFF(diff_dst)]); - mov(ptr[rsp + stack_off_diff_dst], reg_tmp); - mov(reg_tmp, ptr[reg_param + PARAM_OFF(ws)]); - mov(ptr[rsp + stack_off_ws], reg_tmp); - mov(reg_tmp, ptr[reg_param + PARAM_OFF(barrier)]); - mov(ptr[rsp + stack_off_barrier], reg_tmp); - if (is_spatial_thr_) { - mov(reg_tmp, ptr[reg_param + PARAM_OFF(spat_size_loc)]); - mov(ptr[rsp + stack_off_spat_size_loc], reg_tmp); - mov(reg_tmp, ptr[reg_param + PARAM_OFF(S_s)]); - mov(ptr[rsp + stack_off_s_s], reg_tmp); - mov(reg_tmp, ptr[reg_param + PARAM_OFF(S_tail)]); - mov(ptr[rsp + stack_off_s_tail], reg_tmp); - } - if (is_c_padded()) { - mov(reg_tmp, ptr[reg_param + PARAM_OFF(is_cblk_tail)]); - mov(ptr[rsp + stack_off_is_cblk_tail], reg_tmp); - } - - if (bdesc_->is_fwd()) { - mov(reg_tmp, ptr[reg_param + PARAM_OFF(var)]); - mov(reg_var, reg_tmp); - } else { - mov(reg_tmp, ptr[reg_param + PARAM_OFF(diff_scale_shift)]); - mov(ptr[rsp + stack_off_diff_scale_shift], reg_tmp); - mov(reg_tmp, ptr[reg_param + PARAM_OFF(var)]); - mov(reg_var, reg_tmp); - } -# undef PARAM_OFF - } - - void prepare_tail_mask_avx512_common() { - if (!is_c_padded()) return; - - const int tail = bdesc_->C() % (int)(vlen / sizeof(float)); - const int mask = (1 << tail) - 1; - - Reg32 regw_tmp = reg_tmp.cvt32(); - mov(regw_tmp, mask); - kmovw(ktail_mask, regw_tmp); - } - - void prepare_tail_mask_avx2_common() { 
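prepare_tail_mask_avx512_common above handles channel counts that are not a multiple of the vector width: tail = C % (vlen / sizeof(float)) lanes remain in the last channel block, and (1 << tail) - 1 sets exactly those low mask bits for the masked loads/stores. A quick standalone illustration of the arithmetic, with hypothetical values:

#include <cstdio>

int main() {
    const int simd_w = 16;       // floats per AVX-512 register
    const int C = 17;            // hypothetical channel count
    const int tail = C % simd_w; // 1 channel left over in the last block
    const unsigned mask = (1u << tail) - 1u;
    std::printf("tail = %d, mask = 0x%x\n", tail, mask); // tail = 1, mask = 0x1
    return 0;
}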
- if (!is_c_padded()) return; - - const int tail = bdesc_->C() % (int)(vlen / sizeof(float)); - static const uint32_t mask[16] = {0xffffffff, 0xffffffff, 0xffffffff, - 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, - 0, 0, 0, 0, 0, 0, 0, 0}; - - mov(reg_tmp, reinterpret_cast(&mask[8 - tail])); - vmovups(vtail_mask, ptr[reg_tmp]); - } - - void prepare_relu() { - with_relu = bdesc_->is_fwd() - ? bdesc_->with_relu_post_op() || bdesc_->fuse_bn_relu() - : bdesc_->fuse_bn_relu(); - with_relu_inf_only = with_relu && bdesc_->is_fwd() - && !(bdesc_->fuse_bn_relu() && bdesc_->is_training()); - - vzero = bdesc_->is_fwd() ? vdiff_beta : vbeta; - if (with_relu) { - uni_vpxor(vzero, vzero, vzero); - if (!bdesc_->is_fwd() && isa == avx2) - prepare_l_relu_mask_avx2(); - } - } - - void prepare_l_relu_mask_avx2() { - Label l_mask_after; - jmp(l_mask_after); - align(32); - L(l_relu_mask_avx2); /* [0x80 0x40 0x20 0x10 0x08 0x04 0x02 0x01] */ - for (int i = 0; i < 8; ++i) dd(1< - void spat_loop(size_t len, size_t blocks, size_t regs, - init_t init, body_t body, fini_t fini) { - size_t factor = regs * blocks; - size_t loop_unroll = len / factor * factor; - size_t loop_tail = len - loop_unroll; - size_t num_active_regs = (len < regs) ? len : regs; - for (size_t i = 0; i < num_active_regs; i++) - init(i); - if (loop_unroll) { - if (is_spatial_thr_) { - mov(reg_ctr, ptr[rsp + stack_off_spat_size_loc]); - add(reg_soff, ptr[rsp + stack_off_s_s]); - } else { - mov(reg_ctr, loop_unroll); - } - Label label; - L(label); { - for (size_t i = 0; i < factor; i++) { - size_t base_reg = i % regs; - body(base_reg, i); - } - add(reg_soff, factor * vlen); - sub(reg_ctr, factor); - jnz(label); - } - if (is_spatial_thr_) { - add(reg_soff, ptr[rsp + stack_off_s_tail]); - } - } - - for (size_t i = 0; i < loop_tail; i++) { - size_t base_reg = i % regs; - body(base_reg, i); - } - if (loop_tail) - add(reg_soff, loop_tail * vlen); - - for (size_t i = 0; i < num_active_regs; i++) - fini(i); - } - - void mean_channels() { - Label ch_label; - L(ch_label); { - uni_vmovups(Vmm(0), vmmword[reg_rbuf1 + reg_coff]); - spat_loop(spat_size, unroll_blocks, - unroll_regs, - [=](size_t base_reg) { - Vmm v = Vmm(base_reg * 2); - if (base_reg) - uni_vpxor(v, v, v); - }, - [=](size_t base_reg, size_t i) { - Vmm v0 = Vmm(base_reg * 2 + 0); - Vmm v1 = Vmm(base_reg * 2 + 1); - size_t offt = i * vlen; - uni_vmovups(v1, - vmmword[reg_src + reg_soff + offt]); - uni_vaddps(v0, v0, v1); - mic_prefetcht0(ptr[reg_src + reg_soff + offt - + t0_pf_offt]); - mic_prefetcht1(ptr[reg_src + reg_soff + offt - + t1_pf_offt]); - }, - [=](size_t base_reg) { - Vmm b = Vmm(0); - Vmm v = Vmm(base_reg * 2); - if (base_reg) - uni_vaddps(b, b, v); - }); - uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0)); - - add(reg_coff, vlen); - cmp(reg_coff, reg_coff_max); - jl(ch_label); - } - } - - void var_channels() { - Label ch_label; - L(ch_label); { - uni_vmovups_maybe_tail(vmean, mean_ptr()); - uni_vmovups(Vmm(0), vmmword[reg_rbuf1 + reg_coff]); - spat_loop(spat_size, unroll_blocks, unroll_regs, - [=](size_t base_reg) { - Vmm v = Vmm(base_reg * 3); - if (base_reg > 0) - uni_vpxor(v, v, v); - }, - [=](size_t base_reg, size_t i) { - Vmm v = Vmm(3 * base_reg); - Vmm vtmp0 = Vmm(3 * base_reg + 1); - Vmm vtmp1 = Vmm(3 * base_reg + 2); - size_t offt = i * vlen; - uni_vmovups(vtmp0, - vmmword[reg_src + reg_soff + offt]); - if (isa == sse42) { - movups(vtmp1, vmean); - subps(vtmp1, vtmp0); - } else { - vsubps(vtmp1, vmean, vtmp0); - } - uni_vfmadd231ps(v, vtmp1, vtmp1); - - 
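mean_channels and var_channels above accumulate, per channel, sum(x) and sum((mean - x)^2) over the mini-batch and spatial positions; the reduction step then divides by the number of contributing elements (chan_size in the kernel's parameters) to obtain the mean and the biased variance. In scalar form the two passes amount to the following sketch for one channel (helper name is illustrative):

#include <cstddef>
#include <vector>

// Two-pass per-channel statistics over N * spat_size values, mirroring
// the mean pass and the variance pass of the removed kernel.
static void channel_mean_var(const std::vector<float> &x,
        float &mean, float &var) {
    const float chan_size = (float)x.size(); // N * D * H * W for one channel
    float sum = 0.f;
    for (float v : x)
        sum += v;
    mean = sum / chan_size;

    float sq = 0.f;
    for (float v : x)
        sq += (mean - v) * (mean - v);
    var = sq / chan_size; // biased variance, later combined with eps
}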
mic_prefetcht0(ptr[reg_src + reg_soff + offt - + t0_pf_offt]); - mic_prefetcht1(ptr[reg_src + reg_soff + offt - + t1_pf_offt]); - }, - [=](size_t base_reg) { - Vmm b = Vmm(0); - Vmm v = Vmm(base_reg * 3); - if (base_reg) - uni_vaddps(b, b, v); - }); - uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0)); - add(reg_coff, vlen); - cmp(reg_coff, reg_coff_max); - jl(ch_label); - } - } - - void compute_mean_variance() { - uni_vpxor(Vmm(0), Vmm(0), Vmm(0)); - xor_(reg_coff, reg_coff); - Label zero_rbuf; - L(zero_rbuf); { - uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0)); - add(reg_coff, isa == sse42 ? vlen / 2 : vlen); - cmp(reg_coff, reg_coff_max); - jne(zero_rbuf); - } - - mov(reg_src, ptr[rsp + stack_off_src]); - - xor_(reg_soff, reg_soff); - Label mean_spatial; - L(mean_spatial); { - xor_(reg_coff, reg_coff); - - if (isa == sse42) - mov(reg_tmp_off, reg_soff); - - mean_channels(); - - if (isa == sse42) { - mov(reg_soff, reg_tmp_off); - add(reg_src, vlen / 2); - mov(reg_coff, vlen / 2); - - mean_channels(); - - sub(reg_src, vlen / 2); - } - - add(reg_soff, reg_mb_stride_Bc); - cmp(reg_soff, reg_soff_max); - jne(mean_spatial); - } - - Label no_mean_reduction; - barrier(); { - mov(reg_tmp, ptr[rsp + stack_off_N_ithr]); - cmp(reg_tmp, 0); - jne(no_mean_reduction); - mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]); - xor_(reg_coff, reg_coff); - Label mean_reduction_channels; - L(mean_reduction_channels); { - mov(reg_roff, reg_coff); - uni_vpxor(Vmm(0), Vmm(0), Vmm(0)); - uni_vpxor(Vmm(1), Vmm(1), Vmm(1)); - mov(reg_ctr, reg_nnthr); - Label mean_reduction_thrs; - L(mean_reduction_thrs); { - uni_vaddps(Vmm(1), Vmm(1), vmmword[reg_rbuf1 + reg_roff]); - uni_vmovups(vmmword[reg_rbuf1 + reg_roff], Vmm(0)); - add(reg_roff, reg_coff_max); - sub(reg_ctr, 1); - jnz(mean_reduction_thrs); - } - uni_vdivps(Vmm(1), Vmm(1), vchan_size); - uni_vmovups_maybe_tail(mean_ptr(), Vmm(1)); - - add(reg_coff, isa == sse42 ? vlen / 2 : vlen); - - cmp(reg_coff, reg_coff_max); - jne(mean_reduction_channels); - } - } - L(no_mean_reduction); - barrier(); - - xor_(reg_soff, reg_soff); - Label var_spatial; - L(var_spatial); { - xor_(reg_coff, reg_coff); - - if (isa == sse42) - mov(reg_tmp_off, reg_soff); - - var_channels(); - - if (isa == sse42) { - mov(reg_soff, reg_tmp_off); - add(reg_src, vlen / 2); - mov(reg_coff, vlen / 2); - - var_channels(); - - sub(reg_src, vlen / 2); - } - - add(reg_soff, reg_mb_stride_Bc); - cmp(reg_soff, reg_soff_max); - jne(var_spatial); - } - - Label no_var_reduction; - barrier(); { - mov(reg_tmp, ptr[rsp + stack_off_N_ithr]); - cmp(reg_tmp, 0); - jne(no_var_reduction); - - mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]); - xor_(reg_coff, reg_coff); - Label var_reduction_channels; - L(var_reduction_channels); { - mov(reg_roff, reg_coff); - uni_vpxor(Vmm(1), Vmm(1), Vmm(1)); - mov(reg_ctr, reg_nnthr); - Label var_reduction_thrs; - L(var_reduction_thrs); { // TODO: unroll (?) - uni_vaddps(Vmm(1), Vmm(1), vmmword[reg_rbuf1 + reg_roff]); - add(reg_roff, reg_coff_max); - sub(reg_ctr, 1); - jnz(var_reduction_thrs); - } - uni_vdivps(Vmm(1), Vmm(1), vchan_size); - uni_vmovups_maybe_tail(var_ptr(), Vmm(1)); - add(reg_coff, isa == sse42 ? 
vlen / 2 : vlen); - - cmp(reg_coff, reg_coff_max); - jne(var_reduction_channels); - } - } - L(no_var_reduction); - barrier(); - } - - void forward_channels() { - Label ch_label; - L(ch_label); { - uni_vmovups_maybe_tail(vmean, mean_ptr()); - uni_vmovups_maybe_tail(vsqrtvar, var_ptr()); - uni_vaddps(vsqrtvar, vsqrtvar, veps); - uni_vsqrtps(vsqrtvar, vsqrtvar); - - if (bdesc_->use_scaleshift()) { - uni_vmovups_maybe_tail(vgamma, gamma_ptr()); - uni_vmovups_maybe_tail(vbeta, beta_ptr()); - } - - Vmm vscale = bdesc_->use_scaleshift() ? vgamma : vone; - Vmm vdiv = bdesc_->use_scaleshift() ? vgamma : vsqrtvar; - - if (isa == sse42) { - movups(vbuf, vscale); - divps(vbuf, vsqrtvar); - movups(vdiv, vbuf); - } else { - vdivps(vdiv, vscale, vsqrtvar); - } - - auto compute = [=](bool output_is_aligned) { - spat_loop(spat_size, unroll_blocks, unroll_regs, - [](size_t base_reg) {UNUSED(base_reg);}, - [=](size_t base_reg, size_t i) { - Vmm v = Vmm(base_reg); - size_t offt = i * vlen; - uni_vmovups(v, - vmmword[reg_src + reg_soff + offt]); - mic_prefetcht0(ptr[reg_src + reg_soff + offt - + t0_pf_offt]); - mic_prefetcht1(ptr[reg_src + reg_soff + offt - + t1_pf_offt]); - uni_vsubps(v, v, vmean); - if (bdesc_->use_scaleshift()) { - uni_vfmadd213ps(v, vgamma, vbeta); - } else { - uni_vmulps(v, v, vsqrtvar); - } - if (with_relu_inf_only) { - uni_vmaxps(v, v, vzero); - } else if (with_relu) { - if (isa == avx512_common) - fwd_process_relu_avx512_common(v, offt); - else - fwd_process_relu_avx2(v, offt, Vmm(3)); - } - if (output_is_aligned) { - uni_vmovntps( - vmmword[reg_dst + reg_soff + offt], v); - } else { - uni_vmovups( - vmmword[reg_dst + reg_soff + offt], v); - } - }, - [](size_t base_reg) {UNUSED(base_reg);}); - }; - - Label unaligned_store, end_store; - test(reg_dst, vlen - 1); - jnz(unaligned_store, T_NEAR); - compute(true); - jmp(end_store, T_NEAR); - L(unaligned_store); { - compute(false); - } - L(end_store); - - add(reg_coff, vlen); - cmp(reg_coff, reg_coff_max); - jl(ch_label); - } - } - - void forward() { - mov(reg_src, ptr[rsp + stack_off_src]); - mov(reg_dst, ptr[rsp + stack_off_dst]); - mov(reg_ws, ptr[rsp + stack_off_ws]); - - xor_(reg_soff, reg_soff); - Label dst_spatial; - L(dst_spatial); { - xor_(reg_coff, reg_coff); - if (isa == sse42) - mov(reg_tmp_off, reg_soff); - - forward_channels(); - - if (isa == sse42) { - mov(reg_soff, reg_tmp_off); - add(reg_src, vlen / 2); - add(reg_dst, vlen / 2); - mov(reg_coff, vlen / 2); - - forward_channels(); - - sub(reg_src, vlen / 2); - sub(reg_dst, vlen / 2); - } - - add(reg_soff, reg_mb_stride_Bc); - cmp(reg_soff, reg_soff_max); - jnz(dst_spatial); - } - } - - void backward_sh_channels() { - Label sh_channels; - L(sh_channels); { - uni_vmovups_maybe_tail(vmean, mean_ptr()); - uni_vmovups(Vmm(0), vmmword[reg_rbuf1 + reg_coff]); - uni_vmovups(Vmm(1), vmmword[reg_rbuf2 + reg_coff]); - spat_loop(spat_size, 1, 1, - [=](size_t base_reg) { - if (base_reg > 0) { - for (int i = 0; i < 2; i++) { - Vmm v(base_reg * 5 + i); - uni_vpxor(v, v, v); - } - } - }, - [=](size_t base_reg, size_t i) { - Vmm o0 = Vmm(base_reg * 5 + 0); - Vmm o1 = Vmm(base_reg * 5 + 1); - Vmm t1 = Vmm(base_reg * 5 + 2); - Vmm t2 = Vmm(base_reg * 5 + 3); - Vmm t3 = Vmm(base_reg * 5 + 4); - size_t offt = i * vlen; - uni_vmovups(t1, vmmword[reg_src + reg_soff + offt]); - uni_vmovups(t2, vmmword[reg_diff_dst + reg_soff - + offt]); - if (with_relu) { - if (isa == avx512_common) - bwd_process_relu_avx512_common(t2, offt); - else if (isa == avx2) - bwd_process_relu_avx2(t2, offt, t3); - else - 
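forward_channels above then applies the per-channel normalization: with sqrtvar = sqrt(var + eps) it computes y = gamma * (x - mean) / sqrtvar + beta when use_scaleshift is set, or y = (x - mean) / sqrtvar otherwise, optionally clamped at zero when a ReLU is fused. A scalar equivalent of one output element, as a sketch (function name is illustrative):

#include <cmath>

// Scalar view of the forward batch-normalization computed by
// forward_channels(); with_relu corresponds to the fused-ReLU path.
static float bnorm_fwd_one(float x, float mean, float var, float eps,
        float gamma, float beta, bool use_scaleshift, bool with_relu) {
    const float sqrtvar = std::sqrt(var + eps);
    float y = (x - mean) / sqrtvar;
    if (use_scaleshift)
        y = gamma * y + beta;
    if (with_relu)
        y = y > 0.f ? y : 0.f;
    return y;
}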
assert(false); - } - uni_vsubps(t3, vmean, t1, t3); - if (isa == sse42) { - mulps(t3, t2); - subps(o0, t3); - } else { - vfnmadd231ps(o0, t3, t2); - } - uni_vaddps(o1, o1, t2); - mic_prefetcht0(ptr[reg_diff_dst + reg_soff + offt - + t0_pf_offt]); - mic_prefetcht0(ptr[reg_src + reg_soff + offt - + t0_pf_offt]); - mic_prefetcht1(ptr[reg_diff_dst + reg_soff + offt - + t1_pf_offt]); - mic_prefetcht1(ptr[reg_src + reg_soff + offt - + t1_pf_offt]); - }, - [=](size_t base_reg) { - Vmm b0 = Vmm(0); - Vmm b1 = Vmm(1); - if (base_reg) { - uni_vaddps(b0, b0, Vmm(base_reg * 5 + 0)); - uni_vaddps(b1, b1, Vmm(base_reg * 5 + 1)); - } - }); - uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0)); - uni_vmovups(vmmword[reg_rbuf2 + reg_coff], Vmm(1)); - add(reg_coff, vlen); - cmp(reg_coff, reg_coff_max); - jl(sh_channels); - } - } - - void backward_diff_channels() { - Label diff_channels; - L(diff_channels); { - uni_vmovups_maybe_tail(vmean, mean_ptr()); - uni_vmovups_maybe_tail(vsqrtvar, var_ptr()); - uni_vaddps(vsqrtvar, vsqrtvar, veps); - uni_vsqrtps(vsqrtvar, vsqrtvar); - uni_vdivps(vsqrtvar, vone, vsqrtvar, vbuf); - if (bdesc_->use_scaleshift()) - uni_vmovups_maybe_tail(vgamma, gamma_ptr()); - uni_vmovups_maybe_tail(vdiff_gamma, diff_gamma_ptr()); - uni_vmovups_maybe_tail(vdiff_beta, diff_beta_ptr()); - uni_vmulps(vdiff_gamma, vdiff_gamma, vsqrtvar); - uni_vdivps(vdiff_beta, vdiff_beta, vchan_size); - uni_vdivps(vdiff_gamma, vdiff_gamma, vchan_size); - - auto compute = [=](bool output_is_aligned) { - spat_loop(spat_size, unroll_blocks, unroll_regs, - [=](size_t base_reg) {UNUSED(base_reg);}, - [=](size_t base_reg, size_t i) { - Vmm v(base_reg * 2 + 0); - Vmm t(base_reg * 2 + 1); - Vmm t1(base_reg * 2 + 2); - size_t offt = i * vlen; - uni_vmovups(v, vmmword[reg_diff_dst + reg_soff - + offt]); - if (with_relu) { - if (isa == avx512_common) - bwd_process_relu_avx512_common(v, offt); - else if (isa == avx2) - bwd_process_relu_avx2(v, offt, t); - else - assert(false); - } - if (!bdesc_->use_global_stats()) { - uni_vsubps(v, v, vdiff_beta); - uni_vmovups(t, vmmword[reg_src + reg_soff - + offt]); - uni_vsubps(t, vmean, t, t1); - uni_vmulps(t, t, vdiff_gamma); - uni_vaddps(v, v, t); - } - uni_vmulps(v, v, vsqrtvar); - if (bdesc_->use_scaleshift()) { - uni_vmulps(v, v, vgamma); - } - if (output_is_aligned) { - uni_vmovntps( - vmmword[reg_diff_src + reg_soff + offt], - v); - } else { - uni_vmovups( - vmmword[reg_diff_src + reg_soff + offt], - v); - } - mic_prefetcht0(ptr[reg_diff_dst + reg_soff + offt - + t0_pf_offt]); - mic_prefetcht0(ptr[reg_src + reg_soff + offt - + t0_pf_offt]); - mic_prefetcht1(ptr[reg_diff_dst + reg_soff - + offt + t1_pf_offt]); - mic_prefetcht1(ptr[reg_src + reg_soff + offt - + t1_pf_offt]); - }, - [=](size_t base_reg) {UNUSED(base_reg);}); - }; - - Label unaligned_store, end_store; - test(reg_diff_src, vlen - 1); - jnz(unaligned_store, T_NEAR); - compute(true); - jmp(end_store, T_NEAR); - L(unaligned_store); { - compute(false); - } - L(end_store); - - add(reg_coff, vlen); - cmp(reg_coff, reg_coff_max); - jl(diff_channels); - } - } - - void backward() { - uni_vpxor(Vmm(0), Vmm(0), Vmm(0)); - xor_(reg_coff, reg_coff); - Label zero_rbuf, sh_spatial; - - L(zero_rbuf); { - uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0)); - uni_vmovups(vmmword[reg_rbuf2 + reg_coff], Vmm(0)); - add(reg_coff, isa == sse42 ? 
vlen / 2 : vlen); - cmp(reg_coff, reg_coff_max); - jne(zero_rbuf); - } - - mov(reg_src, ptr[rsp + stack_off_src]); - mov(reg_diff_dst, ptr[rsp + stack_off_diff_dst]); - if (with_relu) { - assert(isa == avx2 || isa == avx512_common); - mov(reg_ws, ptr[rsp + stack_off_ws]); - } - - xor_(reg_soff, reg_soff); - L(sh_spatial); { - xor_(reg_coff, reg_coff); - if (isa == sse42) { - mov(reg_tmp_off, reg_soff); - } - backward_sh_channels(); - if (isa == sse42) { - mov(reg_soff, reg_tmp_off); - add(reg_diff_dst, vlen / 2); - add(reg_src, vlen / 2); - mov(reg_coff, vlen / 2); - backward_sh_channels(); - sub(reg_diff_dst, vlen / 2); - sub(reg_src, vlen / 2); - } - add(reg_soff, reg_mb_stride_Bc); - cmp(reg_soff, reg_soff_max); - jne(sh_spatial); - } - - mov(reg_diff_scale_shift, ptr[rsp + stack_off_diff_scale_shift]); - - Label no_sh_reduction; - barrier(); { - mov(reg_tmp, ptr[rsp + stack_off_N_ithr]); - cmp(reg_tmp, 0); - Label sh_reduction_channels; - jne(no_sh_reduction, T_NEAR); - - mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]); - xor_(reg_coff, reg_coff); - L(sh_reduction_channels); { - mov(reg_roff, reg_coff); - uni_vpxor(Vmm(0), Vmm(0), Vmm(0)); - uni_vpxor(Vmm(1), Vmm(1), Vmm(1)); - uni_vmovups_maybe_tail(vsqrtvar, var_ptr()); - uni_vaddps(vsqrtvar, vsqrtvar, veps); - uni_vsqrtps(vsqrtvar, vsqrtvar); - uni_vdivps(vsqrtvar, vone, vsqrtvar, vbuf); - mov(reg_ctr, reg_nnthr); - Label sh_reduction_thrs; - L(sh_reduction_thrs); { // TODO: unroll (?) - uni_vaddps(Vmm(0), Vmm(0), vmmword[reg_rbuf1 + reg_roff]); - uni_vaddps(Vmm(1), Vmm(1), vmmword[reg_rbuf2 + reg_roff]); - add(reg_roff, reg_coff_max); - sub(reg_ctr, 1); - jnz(sh_reduction_thrs); - } - uni_vmulps(Vmm(0), Vmm(0), vsqrtvar); - uni_vmovups_maybe_tail(diff_gamma_ptr(), Vmm(0)); - uni_vmovups_maybe_tail(diff_beta_ptr(), Vmm(1)); - add(reg_coff, isa == sse42 ? vlen / 2 : vlen); - cmp(reg_coff, reg_coff_max); - jne(sh_reduction_channels); - } - } - L(no_sh_reduction); - barrier(); - - mov(reg_diff_src, ptr[rsp + stack_off_diff_src]); - if (with_relu) { - assert(isa == avx2 || isa == avx512_common); - mov(reg_ws, ptr[rsp + stack_off_ws]); - } - - xor_(reg_soff, reg_soff); - Label diff_spatial; - L(diff_spatial); { - xor_(reg_coff, reg_coff); - if (isa == sse42) { - mov(reg_tmp_off, reg_soff); - } - backward_diff_channels(); - if (isa == sse42) { - mov(reg_soff, reg_tmp_off); - add(reg_diff_dst, vlen / 2); - add(reg_diff_src, vlen / 2); - add(reg_src, vlen / 2); - mov(reg_coff, vlen / 2); - backward_diff_channels(); - sub(reg_diff_dst, vlen / 2); - sub(reg_diff_src, vlen / 2); - sub(reg_src, vlen / 2); - } - add(reg_soff, reg_mb_stride_Bc); - cmp(reg_soff, reg_soff_max); - jne(diff_spatial); - } - } - - jit_bnorm_t(const batch_normalization_pd_t *bdesc): bdesc_(bdesc) { - static_assert(isa == sse42 || isa == avx2 || isa == avx512_common - || isa == avx512_mic, "unsupported isa"); - - const int simd_w = isa == sse42 ? 8 : - cpu_isa_traits::vlen / sizeof(data_t); - is_spatial_thr_ = - bnorm_utils::is_spatial_thr(bdesc_, simd_w, sizeof(data_t)); - - unroll_blocks = isa == avx512_common && !is_spatial_thr_ ? 4 : 1; - unroll_regs = isa == avx512_common && !is_spatial_thr_ ? 
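backward_sh_channels and backward_diff_channels above implement the usual batch-norm backward pass: the first loop reduces sum(dy) and sum((x - mean) * dy) per channel (the latter scaled by 1/sqrt(var + eps) to give diff_gamma, the former giving diff_beta), and the second loop turns those reductions into diff_src. A scalar sketch of the math for one channel, assuming use_scaleshift is enabled and use_global_stats is not (this illustrates the formula, not the blocking; the helper name is mine):

#include <cmath>
#include <cstddef>
#include <vector>

// Per-channel batch-norm backward, matching the formula applied by
// backward_diff_channels():
//   dx = gamma / sqrt(var + eps) *
//        (dy - diff_beta / n - (x - mean) * diff_gamma / sqrt(var + eps) / n)
static void bnorm_bwd_channel(const std::vector<float> &x,
        const std::vector<float> &dy, float mean, float var, float eps,
        float gamma, std::vector<float> &dx, float &diff_gamma,
        float &diff_beta) {
    const float n = (float)x.size();
    const float inv_sqrtvar = 1.f / std::sqrt(var + eps);

    diff_gamma = 0.f;
    diff_beta = 0.f;
    for (std::size_t i = 0; i < x.size(); ++i) {
        diff_gamma += (x[i] - mean) * dy[i];
        diff_beta += dy[i];
    }
    diff_gamma *= inv_sqrtvar;

    dx.resize(x.size());
    for (std::size_t i = 0; i < x.size(); ++i) {
        const float v = dy[i] - diff_beta / n
                - (x[i] - mean) * (diff_gamma * inv_sqrtvar) / n;
        dx[i] = gamma * inv_sqrtvar * v;
    }
}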
4 : 1; - - preamble(); - - if (isa == avx512_common) - prepare_tail_mask_avx512_common(); - else if (isa == avx2) - prepare_tail_mask_avx2_common(); - - compute_static_strides(); - sub(rsp, stack_size_required); - load_common_params(); - prepare_relu(); - - if (bdesc_->is_fwd()) { - if (!bdesc_->stats_is_src()) { - compute_mean_variance(); - } - forward(); - } else { - backward(); - } - add(rsp, stack_size_required); - postamble(); - - ker = reinterpret_cast(const_cast( - this->getCode())); - } -}; - -template -struct uni_bnorm_driver_t: public c_compatible { - uni_bnorm_driver_t(const batch_normalization_pd_t *bdesc) - : bdesc_(bdesc), ker_(bdesc_) - { - const int nthrs = mkldnn_get_max_threads(); - const dim_t C_PADDED = get_c_padded(bdesc_); - - size_t data_size = sizeof(data_t) * bdesc_->MB() * C_PADDED - * bdesc_->D() * bdesc_->H() * bdesc_->W(); - l3_size_ = get_cache_size(3, true) * nthrs / 2; - do_blocking_ = (data_size >= l3_size_ / 2 && l3_size_ > 0); - } - - ~uni_bnorm_driver_t() {} - - static void init_scratchpad(memory_tracking::registrar_t &scratchpad, - const batch_normalization_pd_t *bdesc) { - int nthrs = mkldnn_get_max_threads(); - dim_t C_PADDED = get_c_padded(bdesc); - - int sbuf_sz = use_tmp_stats(bdesc) * 2 * C_PADDED; - int pbuf_sz = use_tmp_diff_scale_shift(bdesc) * 2 * C_PADDED; - int rbuf_sz = (bdesc->is_fwd() ? 1 : 2) * C_PADDED * nthrs; - - scratchpad.book(key_bnorm_tmp_stats, sizeof(data_t) * sbuf_sz); - scratchpad.book(key_bnorm_tmp_diff_ss, sizeof(data_t) * pbuf_sz); - scratchpad.book(key_bnorm_reduction, sizeof(data_t) * rbuf_sz); - - if (mkldnn_thr_syncable()) { - int n_barriers = C_PADDED / simd_w; - scratchpad.book(key_barrier, sizeof(barrier::ctx_t) * n_barriers); - } - } - - void exec(int ithr, int nthr, const data_t *src, data_t *diff_src, - data_t *dst, const data_t *diff_dst, const data_t *scale_shift, - data_t *diff_scale_shift, const data_t *mean, const data_t *var, - const uint8_t *ws, const memory_tracking::grantor_t &scratchpad) { - auto sbuf = scratchpad.get(key_bnorm_tmp_stats); - auto pbuf = scratchpad.get(key_bnorm_tmp_diff_ss); - auto rbuf = scratchpad.get(key_bnorm_reduction); - auto barriers = scratchpad.get(key_barrier); - - dim_t N = bdesc_->MB(); - dim_t C = bdesc_->C(); - dim_t C_PADDED = get_c_padded(bdesc_); - dim_t D = bdesc_->D(); - dim_t H = bdesc_->H(); - dim_t W = bdesc_->W(); - dim_t SP = D * H * W; - dim_t img_size = C_PADDED * D * H * W; - const int vlen = isa == sse42 ? 32 : cpu_isa_traits::vlen; - - typename jit_bnorm_t::call_params_t p; - - p.eps = bdesc_->desc()->batch_norm_epsilon; - p.one = 1.0f; - p.spat_size = D * H * W; - p.chan_size = 1.0f * N * p.spat_size; - - dim_t C_blks = C_PADDED / simd_w; - - int C_ithr{0}, C_nthr{0}, N_ithr{0}, N_nthr{0}, S_ithr{0}, S_nthr{0}; - dim_t C_blk_s{0}, C_blk_e{0}, N_s{0}, N_e{0}, S_s{0}, S_e{0}; - - dim_t C_blks_per_iter{ 1 }; - int64_t iters{ 1 }; - if (do_blocking_) { - int num_tensors = bdesc_->is_fwd() ? 1 : 2; - size_t working_set_size - = (N * D * H * W * simd_w * sizeof(data_t)) * num_tensors; - bnorm_utils::cache_balance(working_set_size, C_blks, - C_blks_per_iter, iters); - } - - bool spatial_thr_allowed = bnorm_utils::thread_balance(do_blocking_, - true, ithr, nthr, N, do_blocking_ ? 
C_blks_per_iter : C_blks, - SP, C_ithr, C_nthr, C_blk_s, C_blk_e, N_ithr, N_nthr, N_s, N_e, - S_ithr, S_nthr, S_s, S_e); - - int SP_N_ithr = N_ithr * S_nthr + S_ithr; - int SP_N_nthr = N_nthr * S_nthr; - assert(IMPLICATION(!mkldnn_thr_syncable(), SP_N_nthr == 1)); - - p.N_ithr = SP_N_ithr; - p.N_nthr = SP_N_nthr; - - int last_iter_blks = C_blks - (iters - 1) * C_blks_per_iter; - int global_C_blk_s; - int global_barriers_per_iter = C_nthr; - - for (int64_t it = 0; it < iters; it++) { - if (it == iters - 1 && iters > 1) { - C_blk_s = C_blk_e = N_s = N_e = 0; - spatial_thr_allowed = bnorm_utils::thread_balance(do_blocking_, - spatial_thr_allowed, ithr, nthr, N, last_iter_blks, SP, - C_ithr, C_nthr, C_blk_s, C_blk_e, N_ithr, N_nthr, N_s, - N_e, S_ithr, S_nthr, S_s, S_e); - - // Update call parameters for JIT, last iteration - p.N_ithr = N_ithr * S_nthr + S_ithr; - p.N_nthr = N_nthr * S_nthr; - } - - global_C_blk_s = do_blocking_ ? - (C_blk_s == -1) ? -1 : it * C_blks_per_iter + C_blk_s : - C_blk_s; - - int C_blks_thr = C_blk_e - C_blk_s; - int N_thr = N_e - N_s; - - size_t coff_base = global_C_blk_s * simd_w; - size_t soff_base - = global_C_blk_s * p.spat_size * simd_w + N_s * img_size; - - p.spat_size_loc = S_e - S_s; - p.S_s = S_s * vlen; - p.S_tail = (p.spat_size - S_e) * vlen; - p.coff_max = C_blks_thr * simd_w; - p.mean = (use_tmp_stats(bdesc_) ? sbuf : mean) + coff_base; - p.var = (use_tmp_stats(bdesc_) ? sbuf + C_PADDED : var) + coff_base; - p.scale_shift = scale_shift + coff_base; - p.diff_scale_shift = (use_tmp_diff_scale_shift(bdesc_) - ? pbuf : diff_scale_shift) + coff_base; - - p.soff_max = N_thr * img_size; - p.src = src + soff_base; - p.dst = dst + soff_base; - p.diff_src = diff_src + soff_base; - p.diff_dst = diff_dst + soff_base; - p.ws = ws + soff_base / 8; - - p.mb_stride_Bc = img_size - p.coff_max * p.spat_size; - - // use SP_N_nthr which is the same as p.N_nthr except maybe for - // the last iteration. - p.rbuf1 = rbuf + ((it * C_blks_per_iter) * SP_N_nthr - + C_blk_s * p.N_nthr + p.N_ithr * C_blks_thr) * simd_w; - // rbuf1 and rbuf2 have to be disjoint - p.rbuf2 = p.rbuf1 + C_PADDED * nthr; - p.is_cblk_tail = (it * C_blks_per_iter + C_blk_e) * simd_w > C; - - size_t iter_bariers - = do_blocking_ ? it * global_barriers_per_iter : 0; - p.barrier = barriers + C_ithr + iter_bariers; - if (p.soff_max != 0 && p.coff_max != 0) - ker_(&p); - } - } - - void init_barriers(const memory_tracking::grantor_t &scratchpad) { - auto barriers = scratchpad.get(key_barrier); - if (barriers) { - const int n_barriers = get_c_padded(bdesc_) / simd_w; - for (int i = 0; i < n_barriers; ++i) - barrier::ctx_init(&barriers[i]); - } - } - -private: - enum { - simd_w = isa == sse42 ? 
8 : cpu_isa_traits::vlen / sizeof(data_t) - }; - - static bool use_tmp_stats(const batch_normalization_pd_t *bdesc) { - return true - && !bdesc->stats_is_src() - && bdesc->desc()->prop_kind == prop_kind::forward_inference; - } - - static bool use_tmp_diff_scale_shift(const batch_normalization_pd_t *bdesc) - { - return false - || (bdesc->is_bwd() && !bdesc->use_scaleshift()) - || bdesc->desc()->prop_kind == prop_kind::backward_data; - } - - static dim_t get_c_padded(const batch_normalization_pd_t *bdesc) - { return bdesc->src_md()->padded_dims[1]; } - - const batch_normalization_pd_t *bdesc_; - bool do_blocking_; - size_t l3_size_; - - jit_bnorm_t ker_; -}; - -} - -using namespace data_type; -using namespace format_tag; -using namespace utils; - -/* fwd */ - -template -status_t jit_uni_batch_normalization_fwd_t::pd_t::init() { - auto desired_fmt_tag = (ndims() == 4) - ? isa == avx512_common ? nChw16c : nChw8c - : isa == avx512_common ? nCdhw16c : nCdhw8c; - - bool ok = true - && mayiuse(isa) - && is_fwd() - && !has_zero_dim_memory() - && one_of(ndims(), 4, 5) - && src_md()->data_type == f32 - && IMPLICATION(use_scaleshift(), weights_md()->data_type == f32) - && memory_desc_matches_tag(*src_md(), desired_fmt_tag) - && (attr()->has_default_values() || this->with_relu_post_op()); - if (!ok) return status::unimplemented; - - if (is_training() && fuse_bn_relu()) { - if (isa < avx2) return status::unimplemented; - init_default_ws(1); - } - - if (memory_desc_wrapper(src_md()).padded_dims()[1] != C() - && isa < avx2) - return status::unimplemented; - - auto scratchpad = scratchpad_registry().registrar(); - uni_bnorm_driver_t::init_scratchpad(scratchpad, this); - - return status::success; -} - -template -jit_uni_batch_normalization_fwd_t::jit_uni_batch_normalization_fwd_t( - const pd_t *apd): cpu_primitive_t(apd) -{ bnorm_driver_ = new uni_bnorm_driver_t(pd()); } - -template -status_t jit_uni_batch_normalization_fwd_t::execute( - const exec_ctx_t &ctx) const { - auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); - auto scale_shift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT); - - auto mean = pd()->stats_is_src() - ? const_cast(CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN)) - : CTX_OUT_MEM(data_t *, MKLDNN_ARG_MEAN); - auto var = pd()->stats_is_src() - ? const_cast(CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE)) - : CTX_OUT_MEM(data_t *, MKLDNN_ARG_VARIANCE); - - auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); - auto ws = CTX_OUT_MEM(uint8_t *, MKLDNN_ARG_WORKSPACE); - - auto scratchpad = this->scratchpad(ctx); - - bnorm_driver_->init_barriers(scratchpad); - - parallel(0, [&](const int ithr, const int nthr) { - bnorm_driver_->exec(ithr, nthr, src, nullptr, dst, nullptr, - scale_shift, nullptr, mean, var, ws, scratchpad); - }); - - return status::success; -} - -template -jit_uni_batch_normalization_fwd_t::~jit_uni_batch_normalization_fwd_t() -{ delete bnorm_driver_; } - -/* bwd */ - -template -status_t jit_uni_batch_normalization_bwd_t::pd_t::init() { - auto desired_fmt_tag = (ndims() == 4) - ? one_of(isa, sse42, avx2) ? nChw8c : nChw16c - : one_of(isa, sse42, avx2) ? 
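pd_t::init above only accepts the blocked memory formats nChw8c / nChw16c (and their 5D variants), i.e. channels grouped in blocks of the SIMD width so that one vector load covers one channel block. A hedged sketch of how an offset into such a layout is typically computed (illustrative helper, not taken from the removed sources):

#include <cstddef>

// Offset of element (n, c, h, w) in an nChw{blk}c blocked layout: the tensor
// is stored as [N][C_padded / blk][H][W][blk], with C padded up to a multiple
// of blk (8 for nChw8c on AVX2, 16 for nChw16c on AVX-512).
static std::size_t nchw_blocked_off(std::size_t n, std::size_t c,
        std::size_t h, std::size_t w, std::size_t C_padded, std::size_t H,
        std::size_t W, std::size_t blk) {
    const std::size_t Cb = C_padded / blk; // number of channel blocks
    return (((n * Cb + c / blk) * H + h) * W + w) * blk + c % blk;
}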
nCdhw8c : nCdhw16c; - - bool ok = true - && mayiuse(isa) - && is_bwd() - && !has_zero_dim_memory() - && one_of(ndims(), 4, 5) - && everyone_is(f32, src_md()->data_type, diff_src_md()->data_type) - && IMPLICATION(use_scaleshift(), - utils::everyone_is(f32, - weights_md()->data_type, - diff_weights_md()->data_type)) - && memory_desc_matches_tag(*src_md(), desired_fmt_tag) - && memory_desc_matches_tag(*diff_src_md(), desired_fmt_tag) - && attr()->has_default_values(); - if (!ok) return status::unimplemented; - - if (memory_desc_wrapper(src_md()).padded_dims()[1] != C() - && isa < avx2) - return status::unimplemented; - - if (fuse_bn_relu()) { - if (isa < avx2) return status::unimplemented; - init_default_ws(1); - if (!compare_ws(hint_fwd_pd_)) - return status::unimplemented; - } - - /* TODO: extra checks required */ - - auto scratchpad = scratchpad_registry().registrar(); - uni_bnorm_driver_t::init_scratchpad(scratchpad, this); - - return status::success; -} - -template -jit_uni_batch_normalization_bwd_t::jit_uni_batch_normalization_bwd_t( - const pd_t *apd): cpu_primitive_t(apd) -{ bnorm_driver_ = new uni_bnorm_driver_t(pd()); } - -template -status_t jit_uni_batch_normalization_bwd_t::execute( - const exec_ctx_t &ctx) const { - auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); - auto mean = CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN); - auto var = CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE); - auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); - auto scale_shift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT); - auto ws = CTX_IN_MEM(const uint8_t *, MKLDNN_ARG_WORKSPACE); - - auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); - auto diff_scale_shift = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SCALE_SHIFT); - - auto scratchpad = this->scratchpad(ctx); - - bnorm_driver_->init_barriers(scratchpad); - - parallel(0, [&](const int ithr, const int nthr) { - bnorm_driver_->exec(ithr, nthr, src, diff_src, nullptr, diff_dst, - scale_shift, diff_scale_shift, mean, var, ws, scratchpad); - }); - - return status::success; -} - -template -jit_uni_batch_normalization_bwd_t::~jit_uni_batch_normalization_bwd_t() -{ delete bnorm_driver_; } - -/* struct instantiation */ -template struct jit_uni_batch_normalization_fwd_t; -template struct jit_uni_batch_normalization_bwd_t; -template struct jit_uni_batch_normalization_fwd_t; -template struct jit_uni_batch_normalization_bwd_t; -template struct jit_uni_batch_normalization_fwd_t; -template struct jit_uni_batch_normalization_bwd_t; - -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_batch_normalization.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_batch_normalization.hpp deleted file mode 100644 index 96410ec84..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_batch_normalization.hpp +++ /dev/null @@ -1,100 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#ifndef JIT_UNI_BATCH_NORMALIZATION_HPP -#define JIT_UNI_BATCH_NORMALIZATION_HPP - -#include - -#include "c_types_map.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -#include "cpu_batch_normalization_pd.hpp" -#include "cpu_isa_traits.hpp" -#include "cpu_primitive.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -namespace { template struct uni_bnorm_driver_t; } - -template -struct jit_uni_batch_normalization_fwd_t: public cpu_primitive_t { - struct pd_t: public cpu_batch_normalization_fwd_pd_t { - pd_t(engine_t *engine, const batch_normalization_desc_t *adesc, - const primitive_attr_t *attr, - const batch_normalization_fwd_pd_t *hint_fwd_pd) - : cpu_batch_normalization_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) - {} - - DECLARE_COMMON_PD_T( - JIT_IMPL_NAME_HELPER("jit:", isa, ""), - jit_uni_batch_normalization_fwd_t); - - status_t init(); - }; - - typedef typename prec_traits::type data_t; - - jit_uni_batch_normalization_fwd_t(const pd_t *apd); - ~jit_uni_batch_normalization_fwd_t(); - - virtual status_t execute(const exec_ctx_t &ctx) const override; - -private: - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - - uni_bnorm_driver_t *bnorm_driver_; -}; - -template -struct jit_uni_batch_normalization_bwd_t: public cpu_primitive_t { - struct pd_t: public cpu_batch_normalization_bwd_pd_t { - pd_t(engine_t *engine, const batch_normalization_desc_t *adesc, - const primitive_attr_t *attr, - const batch_normalization_fwd_pd_t *hint_fwd_pd) - : cpu_batch_normalization_bwd_pd_t(engine, adesc, attr, hint_fwd_pd) - {} - - DECLARE_COMMON_PD_T( - JIT_IMPL_NAME_HELPER("jit:", isa, ""), - jit_uni_batch_normalization_bwd_t); - - status_t init(); - }; - - typedef typename prec_traits::type data_t; - - jit_uni_batch_normalization_bwd_t(const pd_t *apd); - ~jit_uni_batch_normalization_bwd_t(); - - virtual status_t execute(const exec_ctx_t &ctx) const override; - -private: - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - - uni_bnorm_driver_t *bnorm_driver_; -}; - -} -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.cpp deleted file mode 100644 index b7dc5f85c..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.cpp +++ /dev/null @@ -1,1302 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#include "c_types_map.hpp" -#include "nstl.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" -#include "cpu_memory.hpp" - -#include "jit_uni_dw_conv_kernel_f32.hpp" - -#define GET_OFF(field) offsetof(jit_conv_call_s, field) - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace mkldnn::impl::format_tag; -using namespace mkldnn::impl::prop_kind; -using namespace mkldnn::impl::memory_tracking::names; -using namespace mkldnn::impl::utils; - -using namespace Xbyak; - -template -void jit_uni_dw_conv_fwd_kernel_f32::load_src(int ur_ch_blocks, int ur_w) { - int repeats = isa == sse42 ? 2 : 1; - for (int i = 0; i < repeats; i++) { - for (int ch = 0; ch < ur_ch_blocks; ch++) { - for (int ow = 0; ow < ur_w; ow++) { - Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_w + ch*ur_w + ow); - - int b_off = ch*jcp.ch_block + i*4; - if (this->jcp.with_bias) - uni_vmovups(vmm_acc, - vmmword[reg_bias + b_off*sizeof(float)]); - else - uni_vpxor(vmm_acc, vmm_acc, vmm_acc); - - int o_off = ch*jcp.oh*jcp.ow*jcp.ch_block - + ow*jcp.ch_block + i*4; - if (this->jcp.with_sum) - uni_vaddps(vmm_acc, vmm_acc, - vmmword[reg_output + o_off*sizeof(float)]); - } - } - } -} - -template -void jit_uni_dw_conv_fwd_kernel_f32::apply_filter( - int ur_ch_blocks, int ur_w) { - int ch_blk = jcp.ch_block; - int dilate_h = jcp.dilate_h + 1; - int dilate_w = jcp.dilate_w + 1; - int stride_w = jcp.stride_w; - - Label iter_exit_label; - - cmp(reg_kh, 0); - je(iter_exit_label, T_NEAR); - cmp(reg_kw, 0); - je(iter_exit_label, T_NEAR); - - mov(iter_kh, reg_kh); - Label kh_label; - L(kh_label); { - mov(iter_kw, reg_kw); - mov(aux1_reg_input, aux_reg_input); - mov(aux1_reg_kernel, aux_reg_kernel); - - Label kw_label; - L(kw_label); { - int repeats = isa == sse42 ? 2 : 1; - for (int i = 0; i < repeats; i++) { - for (int ch = 0; ch < ur_ch_blocks; ch++) { - int ker_off = ch*jcp.kh*jcp.kw*ch_blk + i*4; - Vmm vmm_ker = get_ker_reg(0); - uni_vmovups(vmm_ker, ptr[aux1_reg_kernel - + ker_off*sizeof(float)]); - - for (int ow = 0; ow < ur_w; ow++) { - int inp_off = ch*jcp.ih*jcp.iw*ch_blk - + ow*stride_w*ch_blk + i*4; - Vmm vmm_src = get_src_reg(0); - uni_vmovups(vmm_src, ptr[aux1_reg_input - + inp_off*sizeof(float)]); - - Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_w - + ch*ur_w + ow); - uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker); - } - } - } - add(aux1_reg_kernel, ch_blk*sizeof(float)); - add(aux1_reg_input, ch_blk*dilate_w*sizeof(float)); - - dec(iter_kw); - cmp(iter_kw, 0); - jg(kw_label, T_NEAR); - } - add(aux_reg_kernel, jcp.kw*ch_blk*sizeof(float)); - add(aux_reg_input, jcp.iw*ch_blk*dilate_h*sizeof(float)); - - dec(iter_kh); - cmp(iter_kh, 0); - jg(kh_label, T_NEAR); - } - - L(iter_exit_label); -} - -template -void jit_uni_dw_conv_fwd_kernel_f32::apply_filter_unrolled( - int ur_ch_blocks, int ur_w) { - int ch_blk = jcp.ch_block; - int dilate_h = jcp.dilate_h + 1; - int dilate_w = jcp.dilate_w + 1; - int stride_w = jcp.stride_w; - - Label iter_exit_label; - - cmp(reg_kh, 0); - je(iter_exit_label, T_NEAR); - - mov(iter_kh, reg_kh); - Label kh_label; - L(kh_label); { - int repeats = isa == sse42 ? 
2 : 1; - for (int i = 0; i < repeats; i++) { - for (int ch = 0; ch < ur_ch_blocks; ch++) { - for (int kw = 0; kw < jcp.kw; kw++) { - int ker_off = ch*jcp.kh*jcp.kw*ch_blk + kw*ch_blk + i*4; - - Vmm vmm_ker = get_ker_reg(0); - uni_vmovups(vmm_ker, ptr[aux_reg_kernel - + ker_off*sizeof(float)]); - - for (int ow = 0; ow < ur_w; ow++) { - int inp_off = ch*jcp.ih*jcp.iw*ch_blk - + ow*stride_w*ch_blk + kw*ch_blk*dilate_w + i*4; - - Vmm vmm_src = get_src_reg(0); - uni_vmovups(vmm_src, ptr[aux_reg_input - + inp_off*sizeof(float)]); - - Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_w - + ch*ur_w + ow); - uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker); - } - } - } - } - - add(aux_reg_kernel, jcp.kw*ch_blk*sizeof(float)); - add(aux_reg_input, jcp.iw*ch_blk*dilate_h*sizeof(float)); - - dec(iter_kh); - cmp(iter_kh, 0); - jg(kh_label, T_NEAR); - } - - L(iter_exit_label); -} - -template -void jit_uni_dw_conv_fwd_kernel_f32::apply_activation( - int ur_ch_blocks, int ur_w) { - if (this->jcp.with_eltwise) { - int repeats = isa == sse42 ? 2 : 1; - eltwise_injector_->compute_vector_range(4, repeats * ur_w * ur_ch_blocks + 4); - } -} - -template -void jit_uni_dw_conv_fwd_kernel_f32::store_dst( - int ur_ch_blocks, int ur_w) { - int ch_blk = jcp.ch_block; - - int repeats = isa == sse42 ? 2 : 1; - for (int i = 0; i < repeats; i++) { - for (int ch = 0; ch < ur_ch_blocks; ch++) { - for (int ow = 0; ow < ur_w; ow++) { - int o_off = ch*jcp.oh*jcp.ow*ch_blk + ow*ch_blk + i*4; - Vmm vmm_dst = get_acc_reg(i*ur_ch_blocks*ur_w + ch*ur_w + ow); - - uni_vmovups(vmmword[reg_output + o_off*sizeof(float)], vmm_dst); - } - } - } -} - -template -void jit_uni_dw_conv_fwd_kernel_f32::loop_body(int ur_ch_blocks) { - Label unrolled_w_label; - Label tail_w_label; - Label exit_label; - - L(unrolled_w_label); { - int ur_w = jcp.ur_w; - - cmp(reg_ur_w, ur_w); - jl(tail_w_label, T_NEAR); - - mov(aux_reg_input, reg_input); - mov(aux_reg_kernel, reg_kernel); - - load_src(ur_ch_blocks, ur_w); - apply_filter_unrolled(ur_ch_blocks, ur_w); - apply_activation(ur_ch_blocks, ur_w); - store_dst(ur_ch_blocks, ur_w); - - add(reg_input, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w); - add(reg_output, sizeof(float) * ur_w * jcp.ch_block); - - sub(reg_ur_w, ur_w); - jmp(unrolled_w_label); - } - - L(tail_w_label); { - int ur_w = 1; - - cmp(reg_ur_w, ur_w); - jl(exit_label, T_NEAR); - - mov(aux_reg_input, reg_input); - mov(aux_reg_kernel, reg_kernel); - - load_src(ur_ch_blocks, ur_w); - apply_filter(ur_ch_blocks, ur_w); - apply_activation(ur_ch_blocks, ur_w); - store_dst(ur_ch_blocks, ur_w); - - add(reg_input, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w); - add(reg_output, sizeof(float) * ur_w * jcp.ch_block); - - sub(reg_ur_w, ur_w); - jmp(tail_w_label); - } - - L(exit_label); -} - -template -void jit_uni_dw_conv_fwd_kernel_f32::generate() { - this->preamble(); - - mov(reg_input, ptr[this->param1 + GET_OFF(src)]); - mov(reg_output, ptr[this->param1 + GET_OFF(dst)]); - mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]); - if (jcp.with_bias) - mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]); - mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]); - mov(reg_kw, ptr[this->param1 + GET_OFF(kw_padding)]); - mov(reg_ch_blocks, ptr[this->param1 + GET_OFF(ch_blocks)]); - mov(reg_ur_w, ptr[this->param1 + GET_OFF(ur_w)]); - - Label ch_blocks_tail_label; - Label exit_label; - - int ch_blocks_tail = jcp.nb_ch % jcp.nb_ch_blocking; - - cmp(reg_ch_blocks, jcp.nb_ch_blocking); - jne(ch_blocks_tail ? 
ch_blocks_tail_label : exit_label, T_NEAR); - - loop_body(jcp.nb_ch_blocking); // channel main loop - - if (ch_blocks_tail) { - L(ch_blocks_tail_label); - - cmp(reg_ch_blocks, ch_blocks_tail); - jne(exit_label, T_NEAR); - - loop_body(ch_blocks_tail); // channel tail loop - } - - L(exit_label); - - this->postamble(); - - if (jcp.with_eltwise) - eltwise_injector_->prepare_table(); -} - -template -bool jit_uni_dw_conv_fwd_kernel_f32::post_ops_ok( - jit_conv_conf_t &jcp, const primitive_attr_t &attr) { - const auto &p = attr.post_ops_; - - auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; - auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; - - switch (p.len_) { - case 0: return true; // no post_ops - case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise - case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise - default: return false; - } - - return false; -} - -template -status_t jit_uni_dw_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp, - const convolution_desc_t &cd, const memory_desc_wrapper &src_d, - const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d, - const primitive_attr_t &attr) -{ - if (!mayiuse(isa)) return status::unimplemented; - - const int simd_w = isa == avx512_common ? 16 : 8; - - jcp.prop_kind = cd.prop_kind; - - const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; - if (!with_groups) return status::unimplemented; - - jcp.ngroups = weights_d.dims()[0]; - jcp.mb = src_d.dims()[0]; - - jcp.oc = dst_d.dims()[1]; - jcp.oc_without_padding = jcp.oc; - jcp.ic = src_d.dims()[1]; - - jcp.ih = src_d.dims()[2]; - jcp.iw = src_d.dims()[3]; - jcp.oh = dst_d.dims()[2]; - jcp.ow = dst_d.dims()[3]; - - jcp.kh = weights_d.dims()[3]; - jcp.kw = weights_d.dims()[4]; - - jcp.t_pad = cd.padding[0][0]; - jcp.l_pad = cd.padding[0][1]; - jcp.b_pad = cd.padding[1][0]; - jcp.r_pad = cd.padding[1][1]; - - jcp.stride_h = cd.strides[0]; - jcp.stride_w = cd.strides[1]; - - jcp.dilate_h = cd.dilates[0]; - jcp.dilate_w = cd.dilates[1]; - - jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; - - if (!post_ops_ok(jcp, attr)) - return status::unimplemented; - - const auto &p = attr.post_ops_; - jcp.with_sum = p.find(primitive_kind::sum) != -1; - const int eltwise_ind = p.find(primitive_kind::eltwise); - jcp.with_eltwise = eltwise_ind != -1; - if (jcp.with_eltwise) - jcp.eltwise = p.entry_[eltwise_ind].eltwise; - - bool ok_to_pad_channels = true - && jcp.oc == jcp.ngroups - && jcp.ic == jcp.ngroups - && one_of(isa, avx512_common, avx2); - if (ok_to_pad_channels) { - jcp.oc = rnd_up(jcp.oc, simd_w); - jcp.ic = rnd_up(jcp.oc, simd_w); - jcp.ngroups = rnd_up(jcp.ngroups, simd_w); - } - - auto dat_tag = isa == avx512_common ? nChw16c : nChw8c; - auto wei_tag = isa == avx512_common ? Goihw16g : Goihw8g; - - jcp.src_tag = src_d.matches_one_of_tag(dat_tag); - jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); - jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); - - bool args_ok = true - && jcp.oc == jcp.ngroups - && jcp.ic == jcp.ngroups - && jcp.ngroups % simd_w == 0 - && jcp.src_tag == dat_tag - && jcp.wei_tag == wei_tag - && jcp.dst_tag == dat_tag - && jcp.ic <= src_d.padded_dims()[1] - && jcp.oc <= dst_d.padded_dims()[1] - && jcp.ngroups <= weights_d.padded_dims()[0]; - if (!args_ok) return status::unimplemented; - - jcp.ur_w = isa == avx512_common ? 6 : isa == avx2 ? 4 : 3; - - jcp.ch_block = simd_w; - jcp.nb_ch = jcp.oc / jcp.ch_block; - jcp.nb_ch_blocking = isa == avx512_common ? 4 : isa == avx2 ? 
3 : 2; - if (jcp.nb_ch < jcp.nb_ch_blocking) - jcp.nb_ch_blocking = jcp.nb_ch; - - return status::success; -} - -template -void jit_uni_dw_conv_fwd_kernel_f32::init_scratchpad( - memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { - if (jcp.with_bias && jcp.oc_without_padding != jcp.oc) - scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc); -} - -template struct jit_uni_dw_conv_fwd_kernel_f32; -template struct jit_uni_dw_conv_fwd_kernel_f32; -template struct jit_uni_dw_conv_fwd_kernel_f32; - -template -inline void jit_uni_dw_conv_bwd_data_kernel_f32::load_ddst( - int ur_ch_blocks, int ur_str_w) { - int repeats = isa == sse42 ? 2 : 1; - for (int i = 0; i < repeats; i++) { - for (int ch = 0; ch < ur_ch_blocks; ch++) { - for (int w = 0; w < ur_str_w; w++) { - Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_str_w - + ch*ur_str_w + w); - uni_vpxor(vmm_acc, vmm_acc, vmm_acc); - } - } - } -} - -template -inline void jit_uni_dw_conv_bwd_data_kernel_f32::apply_filter( - int ur_ch_blocks, int ur_str_w) { - int kw = jcp.kw; - int kh = jcp.kh; - int ow = jcp.ow; - int oh = jcp.oh; - - int ch_blk = jcp.ch_block; - int stride_h = jcp.stride_h; - int stride_w = jcp.stride_w; - - Label iter_exit_label; - - cmp(reg_kh, 0); - je(iter_exit_label, T_NEAR); - - cmp(reg_kw, 0); - je(iter_exit_label, T_NEAR); - - mov(iter_kh, reg_kh); - Label kh_label; - L(kh_label); { - mov(aux1_reg_ddst, aux_reg_ddst); - mov(aux1_reg_kernel, aux_reg_kernel); - - mov(iter_kw, reg_kw); - Label kw_label; - L(kw_label); { - int repeats = isa == sse42 ? 2 : 1; - for (int i = 0; i < repeats; i++) { - for (int ch = 0; ch < ur_ch_blocks; ch++) { - int ker_off = ch*kh*kw*ch_blk + i*4; - Vmm vmm_ker = get_ker_reg(0); - uni_vmovups(vmm_ker, ptr[aux1_reg_kernel - + ker_off*sizeof(float)]); - - for (int w = 0; w < ur_str_w; w++) { - int ddst_off = (ch*oh*ow + w)*ch_blk + i*4; - - Vmm vmm_src = get_src_reg(0); - uni_vmovups(vmm_src, ptr[aux1_reg_ddst - + ddst_off*sizeof(float)]); - - Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_str_w - + ch*ur_str_w + w); - uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker); - } - } - } - - add(aux1_reg_kernel, ch_blk*stride_w*sizeof(float)); - sub(aux1_reg_ddst, ch_blk*sizeof(float)); - - sub(iter_kw, stride_w); - cmp(iter_kw, 0); - jg(kw_label, T_NEAR); - } - - add(aux_reg_kernel, kw*ch_blk*stride_h*sizeof(float)); - sub(aux_reg_ddst, ow*ch_blk*sizeof(float)); - - sub(iter_kh, stride_h); - cmp(iter_kh, 0); - jg(kh_label, T_NEAR); - } - - L(iter_exit_label); -} - -template -inline void jit_uni_dw_conv_bwd_data_kernel_f32::store_dsrc( - int ur_ch_blocks, int ur_str_w) { - int ch_blk = jcp.ch_block; - int iw = jcp.iw; - int ih = jcp.ih; - int stride_w = jcp.stride_w; - - int repeats = isa == sse42 ? 
2 : 1; - for (int i = 0; i < repeats; i++) { - for (int ch = 0; ch < ur_ch_blocks; ch++) { - for (int w = 0; w < ur_str_w; w++) { - int dsrc_off = (ch*ih*iw + w*stride_w)*ch_blk + i*4; - Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_str_w - + ch*ur_str_w + w); - - uni_vmovups(ptr[reg_dsrc + dsrc_off*sizeof(float)], vmm_acc); - } - } - } -} - -template -inline void jit_uni_dw_conv_bwd_data_kernel_f32::loop_body( - int ur_ch_blocks) { - Label unrolled_w_label; - Label tail_w_label; - Label exit_label; - - L(unrolled_w_label); { - int ur_w = jcp.ur_w; - - cmp(reg_ur_str_w, ur_w); - jl(tail_w_label, T_NEAR); - - mov(aux_reg_ddst, reg_ddst); - mov(aux_reg_kernel, reg_kernel); - - load_ddst(ur_ch_blocks, ur_w); - apply_filter(ur_ch_blocks, ur_w); - store_dsrc(ur_ch_blocks, ur_w); - - add(reg_dsrc, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w); - add(reg_ddst, sizeof(float) * ur_w * jcp.ch_block); - - sub(reg_ur_str_w, ur_w); - jmp(unrolled_w_label); - } - - L(tail_w_label); { - int ur_w = 1; - - cmp(reg_ur_str_w, ur_w); - jl(exit_label, T_NEAR); - - mov(aux_reg_ddst, reg_ddst); - mov(aux_reg_kernel, reg_kernel); - - load_ddst(ur_ch_blocks, ur_w); - apply_filter(ur_ch_blocks, ur_w); - store_dsrc(ur_ch_blocks, ur_w); - - add(reg_dsrc, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w); - add(reg_ddst, sizeof(float) * ur_w * jcp.ch_block); - - sub(reg_ur_str_w, ur_w); - jmp(tail_w_label); - } - - L(exit_label); -} - -template -void jit_uni_dw_conv_bwd_data_kernel_f32::generate() { - preamble(); - - mov(reg_dsrc, ptr[this->param1 + GET_OFF(src)]); - mov(reg_ddst, ptr[this->param1 + GET_OFF(dst)]); - mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]); - mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]); - mov(reg_kw, ptr[this->param1 + GET_OFF(kw_padding)]); - mov(reg_ch_blocks, ptr[this->param1 + GET_OFF(ch_blocks)]); - mov(reg_ur_str_w, ptr[this->param1 + GET_OFF(ur_str_w)]); - - Label ch_blocks_tail_label; - Label exit_label; - - int ch_blocks_tail = jcp.nb_ch % jcp.nb_ch_blocking; - - cmp(reg_ch_blocks, jcp.nb_ch_blocking); - jne(ch_blocks_tail ? ch_blocks_tail_label : exit_label, T_NEAR); - - loop_body(jcp.nb_ch_blocking); // channel main loop - - if (ch_blocks_tail) { - L(ch_blocks_tail_label); - - cmp(reg_ch_blocks, ch_blocks_tail); - jne(exit_label, T_NEAR); - - loop_body(ch_blocks_tail); // channel tail loop - } - - L(exit_label); - - this->postamble(); -} - -template -status_t jit_uni_dw_conv_bwd_data_kernel_f32::init_conf( - jit_conv_conf_t &jcp, const convolution_desc_t &cd, - const memory_desc_wrapper &diff_src_d, - const memory_desc_wrapper &weights_d, - const memory_desc_wrapper &diff_dst_d) { - if (!mayiuse(isa)) return status::unimplemented; - - const int simd_w = isa == avx512_common ? 
16 : 8; - - const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1; - if (!with_groups) return status::unimplemented; - - jcp.ngroups = weights_d.dims()[0]; - jcp.mb = diff_src_d.dims()[0]; - - jcp.oc = diff_dst_d.dims()[1]; - jcp.oc_without_padding = jcp.oc; - jcp.ic = diff_src_d.dims()[1]; - - jcp.ih = diff_src_d.dims()[2]; - jcp.iw = diff_src_d.dims()[3]; - jcp.oh = diff_dst_d.dims()[2]; - jcp.ow = diff_dst_d.dims()[3]; - - jcp.kh = weights_d.dims()[3]; - jcp.kw = weights_d.dims()[4]; - - jcp.t_pad = cd.padding[0][0]; - jcp.l_pad = cd.padding[0][1]; - jcp.b_pad = cd.padding[1][0]; - jcp.r_pad = cd.padding[1][1]; - - jcp.stride_h = cd.strides[0]; - jcp.stride_w = cd.strides[1]; - - jcp.dilate_h = cd.dilates[0]; - jcp.dilate_w = cd.dilates[1]; - - jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad; - jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad; - - bool ok_to_pad_channels = true - && jcp.oc == jcp.ngroups - && jcp.ic == jcp.ngroups - && one_of(isa, avx512_common, avx2); - if (ok_to_pad_channels) { - jcp.oc = rnd_up(jcp.oc, simd_w); - jcp.ic = rnd_up(jcp.oc, simd_w); - jcp.ngroups = rnd_up(jcp.ngroups, simd_w); - } - - auto dat_tag = isa == avx512_common ? nChw16c : nChw8c; - auto wei_tag = isa == avx512_common ? Goihw16g : Goihw8g; - - jcp.src_tag = diff_src_d.matches_one_of_tag(dat_tag); - jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); - jcp.dst_tag = diff_dst_d.matches_one_of_tag(dat_tag); - - bool args_ok = true - && jcp.oc == jcp.ngroups - && jcp.ic == jcp.ngroups - && jcp.ngroups % simd_w == 0 - && jcp.dilate_h == 0 - && jcp.dilate_w == 0 - && jcp.src_tag == dat_tag - && jcp.wei_tag == wei_tag - && jcp.dst_tag == dat_tag - && jcp.oh == (jcp.ihp - jcp.kh) / jcp.stride_h + 1 - && jcp.ow == (jcp.iwp - jcp.kw) / jcp.stride_w + 1 - && jcp.ic <= diff_src_d.padded_dims()[1] - && jcp.oc <= diff_dst_d.padded_dims()[1] - && jcp.ngroups <= weights_d.padded_dims()[0]; - if (!args_ok) return status::unimplemented; - - jcp.ur_w = isa == avx512_common ? 6 : isa == avx2 ? 4 : 3; - - jcp.ch_block = simd_w; - jcp.nb_ch = jcp.ic / jcp.ch_block; - jcp.nb_ch_blocking = isa == avx512_common ? 4 : isa == avx2 ? 
3 : 2; - if (jcp.nb_ch < jcp.nb_ch_blocking) - jcp.nb_ch_blocking = jcp.nb_ch; - - return status::success; -} - -template -void jit_uni_dw_conv_bwd_data_kernel_f32::init_scratchpad( - memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { - UNUSED(scratchpad); - UNUSED(jcp); -} - -template struct jit_uni_dw_conv_bwd_data_kernel_f32; -template struct jit_uni_dw_conv_bwd_data_kernel_f32; -template struct jit_uni_dw_conv_bwd_data_kernel_f32; - -template -inline void jit_uni_dw_conv_bwd_weights_kernel_f32::zero_filter() { - for (int r = 0; r < reg_repeats; ++r) { - for (int i = 0; i < jcp.kw; ++i) { - Vmm vmm_acc = get_acc_reg(r * jcp.kw + i); - uni_vpxor(vmm_acc, vmm_acc, vmm_acc); - } - } -} - -template -inline void jit_uni_dw_conv_bwd_weights_kernel_f32::load_filter() { - for (int r = 0; r < reg_repeats; ++r) { - const int reg_set = r * jcp.kw; - for (int i = 0; i < jcp.kw; ++i) { - int off_filter = (reg_set + i) * simd_w; - Vmm vmm_acc = get_acc_reg(reg_set + i); - uni_vmovups(vmm_acc, - vmmword[reg_tmp_filter + off_filter * sizeof(float)]); - } - } -} - -template -inline void jit_uni_dw_conv_bwd_weights_kernel_f32::zero_bias() { - for (int r = 0; r < reg_repeats; ++r) { - Vmm vmm_bias = get_bias_reg(r); - uni_vpxor(vmm_bias, vmm_bias, vmm_bias); - } -} - -template -inline void jit_uni_dw_conv_bwd_weights_kernel_f32::load_bias() { - for (int r = 0; r < reg_repeats; ++r) { - Vmm vmm_bias = get_bias_reg(r); - uni_vmovups( - vmm_bias, vmmword[reg_bias_baddr + r * simd_w * sizeof(float)]); - } -} - -template -inline void jit_uni_dw_conv_bwd_weights_kernel_f32::compute_ow_step_unroll( - int unroll_w, int l_pad, int pad_offset, int ow_block) { - - const int iw_block = ow_block * jcp.stride_w; - const int right_border = jcp.iw - iw_block; - - const int cascade_input = nstl::min(jcp.stride_w, jcp.kw); - - /* preamble count for number of cascaded LOAD + FMA operation */ - const int input_overlap = nstl::max(jcp.kw - l_pad, 0); - - /* LOAD initial input registers, then cascade LOADs and FMAs*/ - for (int r = 0; r < reg_repeats; ++r) { - for (int i_ur = 0; i_ur < unroll_w; ++i_ur) { - int off_output = (i_ur * reg_repeats + r) * simd_w; - Vmm vmm_output = get_output_reg(r); - uni_vmovups(vmm_output, - ptr[reg_tmp_output + off_output * sizeof(float)]); - if (i_ur == 0) { - for (int c = 0; c < input_overlap; ++c) { - int off_input - = ((c - pad_offset) * reg_repeats + r) * simd_w; - Vmm vmm_input - = get_input_reg((c % jcp.kw) * reg_repeats + r); - uni_vmovups(vmm_input, - ptr[reg_tmp_input + off_input * sizeof(float)]); - } - } else { - for (int c = 0; c < cascade_input; ++c) { - int overlap = (i_ur - 1) * jcp.stride_w + input_overlap; - int off_input - = ((overlap + c - pad_offset) * reg_repeats + r) - * simd_w; - Vmm vmm_input = get_input_reg( - ((overlap + c) % jcp.kw) * reg_repeats + r); - uni_vmovups(vmm_input, - ptr[reg_tmp_input + off_input * sizeof(float)]); - } - } - - for (int i_kw = 0; i_kw < jcp.kw; ++i_kw) { - int io_overlap = i_kw + (i_ur * jcp.stride_w); - - /* Don't apply FMAs that fall into the padded region */ - if (io_overlap - l_pad < 0 - || io_overlap - jcp.l_pad >= right_border) - continue; - - Vmm vmm_input = get_input_reg( - ((io_overlap - l_pad) % jcp.kw) * reg_repeats + r); - Vmm vmm_acc = get_acc_reg(i_kw * reg_repeats + r); - Vmm vmm_aux = isa == sse42 ? 
get_aux_reg() : vmm_input; - if (isa == sse42) - uni_vmovups(vmm_aux, vmm_input); - uni_vfmadd231ps(vmm_acc, vmm_aux, vmm_output); - } - } - } -} - -template -inline void -jit_uni_dw_conv_bwd_weights_kernel_f32::compute_bias_step_unroll( - const int unroll_w) { - for (int r = 0; r < reg_repeats; ++r) { - for (int i = 0; i < unroll_w; ++i) { - Vmm vmm_bias = get_bias_reg(r); - int off_output = (i * reg_repeats + r) * simd_w; - if (isa == sse42) { - /* Need to support unaligned address loads for SSE42*/ - Vmm vmm_output = get_output_reg(1 + r); - uni_vmovups(vmm_output, - ptr[reg_tmp_output + off_output * sizeof(float)]); - uni_vaddps(vmm_bias, vmm_bias, vmm_output); - } else { - uni_vaddps(vmm_bias, vmm_bias, - vmmword[reg_tmp_output + off_output * sizeof(float)]); - } - } - } -} - -template -inline void jit_uni_dw_conv_bwd_weights_kernel_f32::store_filter() { - for (int r = 0; r < reg_repeats; ++r) { - const int reg_set = r * jcp.kw; - for (int i = 0; i < jcp.kw; ++i) { - int off_filter = (i + reg_set) * simd_w; - Vmm vmm_acc = get_acc_reg(i + reg_set); - uni_vmovups(vmmword[reg_tmp_filter + off_filter * sizeof(float)], - vmm_acc); - } - } -} - -template -inline void jit_uni_dw_conv_bwd_weights_kernel_f32::store_bias() { - for (int r = 0; r < reg_repeats; ++r) { - Vmm vmm_bias = get_bias_reg(r); - uni_vmovups( - vmmword[reg_bias_baddr + r * simd_w * sizeof(float)], vmm_bias); - } -} - -template -inline void jit_uni_dw_conv_bwd_weights_kernel_f32::compute_bias_loop( - const int block_size) { - Label oh_label; - Label ow_blk_label; - - const int unroll_w = nstl::min(block_size, jcp.ow); - const int unroll_w_trips = jcp.ow / unroll_w; - const int tail_w = jcp.ow > block_size ? jcp.ow % block_size : 0; - - const int ch_offset = jcp.ch_block; - - mov(reg_oh, ptr[this->param1 + offsetof(jit_dw_conv_call_s, oh_index)]); - mov(reg_oh_worksize, - ptr[this->param1 + offsetof(jit_dw_conv_call_s, oh_count)]); - - mov(reg_tmp_output, reg_output_baddr); - L(oh_label); - { - - mov(iter_ow_blk, unroll_w_trips); - L(ow_blk_label); - { - - compute_bias_step_unroll(unroll_w); - add(reg_tmp_output, unroll_w * ch_offset * sizeof(float)); - - dec(iter_ow_blk); - cmp(iter_ow_blk, 0); - jg(ow_blk_label, T_NEAR); - } - - if (tail_w > 0) { - compute_bias_step_unroll(tail_w); - add(reg_tmp_output, tail_w * ch_offset * sizeof(float)); - } - - inc(reg_oh); - cmp(reg_oh, reg_oh_worksize); - jl(oh_label, T_NEAR); - } -} - -template -inline void jit_uni_dw_conv_bwd_weights_kernel_f32::compute_zero_filter() { - - const int ch_offset = jcp.ch_block; - - Label kh_loop_label, skip_zeroing_label; - - mov(reg_exec_flags, - ptr[this->param1 + offsetof(jit_dw_conv_call_s, exec_flags)]); - and_(reg_exec_flags, FLAG_ZERO_FILTER); - test(reg_exec_flags, reg_exec_flags); - je(skip_zeroing_label); - - zero_filter(); - - mov(reg_tmp_filter, reg_filter_baddr); - mov(reg_kh, jcp.kh); - L(kh_loop_label); - { - store_filter(); - - add(reg_tmp_filter, jcp.kw * ch_offset * sizeof(float)); - dec(reg_kh); - cmp(reg_kh, 0); - jg(kh_loop_label); - } - - /* Comeback pointers */ - sub(reg_tmp_filter, jcp.kh * jcp.kw * ch_offset * sizeof(float)); - - L(skip_zeroing_label); -} - -template -inline void jit_uni_dw_conv_bwd_weights_kernel_f32::compute_h_step( - int unroll_w, int l_pad, int pad_offset, int ow_block) { - - const int ch_offset = jcp.ch_block; - - Label kh_loop_label, skip_loop_label; - - cmp(reg_kh_count, 0); - je(skip_loop_label, T_NEAR); - - mov(reg_kh, reg_kh_count); - L(kh_loop_label); - { - load_filter(); - 
compute_ow_step_unroll(unroll_w, l_pad, pad_offset, ow_block); - store_filter(); - - add(reg_tmp_filter, jcp.kw * ch_offset * sizeof(float)); - add(reg_tmp_input, jcp.iw * ch_offset * sizeof(float)); - dec(reg_kh); - cmp(reg_kh, 0); - jg(kh_loop_label); - } - - /* Comeback pointers */ - Label kh_comeback_label; - mov(reg_kh, reg_kh_count); - L(kh_comeback_label); - { - sub(reg_tmp_input, jcp.iw * ch_offset * sizeof(float)); - sub(reg_tmp_filter, jcp.kw * ch_offset * sizeof(float)); - dec(reg_kh); - cmp(reg_kh, 0); - jg(kh_comeback_label, T_NEAR); - } - - L(skip_loop_label); -} - -template -inline void jit_uni_dw_conv_bwd_weights_kernel_f32::compute_h_loop( - int unroll_w, int l_pad, int pad_offset, int ow_block) { - - const size_t io_overlap = jcp.ih / jcp.stride_h < jcp.oh ? - jcp.ih / jcp.stride_h - 1 : - jcp.oh - jcp.b_pad - 1; - const int ch_offset = jcp.ch_block; - const int t_overlap_off = jcp.t_pad % jcp.stride_h == 0 ? jcp.stride_h : 1; - const int b_overlap_off = jcp.b_pad % jcp.stride_h == 0 ? jcp.stride_h : 1; - - Label tpad_loop_label, h_loop_label, skip_tpad_label, skip_bpad_label, - end_h_loop_label; - - mov(reg_oh, ptr[this->param1 + offsetof(jit_dw_conv_call_s, oh_index)]); - mov(reg_oh_worksize, - ptr[this->param1 + offsetof(jit_dw_conv_call_s, oh_count)]); - mov(reg_kh_count, - ptr[this->param1 + offsetof(jit_dw_conv_call_s, kh_count)]); - - mov(reg_tmp_output, reg_output_baddr); - mov(reg_tmp_input, reg_input_baddr); - mov(reg_tmp_filter, reg_filter_baddr); - - L(h_loop_label); - { - - compute_h_step(unroll_w, l_pad, pad_offset, ow_block); - - add(reg_tmp_output, jcp.ow * ch_offset * sizeof(float)); - - /* If within the top_pad region */ - if (jcp.t_pad > 0) { - /* Skip t_pad area if no longer in initial h_block */ - cmp(reg_oh, jcp.t_pad); - jg(skip_tpad_label, T_NEAR); - - cmp(reg_kh_count, jcp.kh); - jge(skip_tpad_label, T_NEAR); - - add(reg_kh_count, t_overlap_off); - sub(reg_tmp_filter, - t_overlap_off * jcp.kw * ch_offset * sizeof(float)); - - /* kernel has moved beyond padding (adjust for stride effects) */ - if (jcp.t_pad % jcp.stride_h != 0) { - int inp_corr = jcp.stride_h - jcp.t_pad % jcp.stride_h; - add(reg_tmp_input, - inp_corr * jcp.iw * ch_offset * sizeof(float)); - } - jmp(tpad_loop_label, T_NEAR); - } - - L(skip_tpad_label); - - cmp(reg_oh, io_overlap); - jl(skip_bpad_label, T_NEAR); - sub(reg_kh_count, b_overlap_off); - - L(skip_bpad_label); - add(reg_tmp_input, jcp.stride_h * jcp.iw * ch_offset * sizeof(float)); - - L(tpad_loop_label); - - cmp(reg_oh, jcp.ih / jcp.stride_h); - jge(end_h_loop_label, T_NEAR); - - inc(reg_oh); - - cmp(reg_oh, reg_oh_worksize); - jl(h_loop_label, T_NEAR); - } - L(end_h_loop_label); -} - -template -inline void -jit_uni_dw_conv_bwd_weights_kernel_f32::compute_ow_block_unroll() { - - const int ch_offset = jcp.ch_block; - int ow = jcp.ow; - int pad_offset = 0; - int l_pad = jcp.l_pad; - - /* Calculate effective padding */ - int r_pad = nstl::max(0, (ow - 1) * jcp.stride_w - + (jcp.kw - 1) * (jcp.dilate_w + 1) - - (jcp.iw + jcp.l_pad - 1)); - - /* Is this strictly defined by: - * -code-size (?) - * -address size (?) 
*/ - const int max_unroll_w = 30; - const int block_size = 15; - - int unroll_w_tail = 0; - int unroll_w = 0; - int unroll_w_trips = 0; - - if (jcp.ow > max_unroll_w) { - unroll_w = nstl::min(block_size, jcp.ow); - unroll_w_trips = ow / unroll_w; - /* calculate tail */ - unroll_w_tail = ow % unroll_w; - /* Perform some rebalancing if tail too small*/ - if ((unroll_w_tail == 0 && r_pad != 0) - || (r_pad > 0 && r_pad >= unroll_w_tail)) { - if (unroll_w_trips > 1) { - unroll_w_tail += unroll_w; - unroll_w_trips--; - } else { - /* Idealy, this case shouldn't happen */ - unroll_w_tail += (unroll_w - unroll_w / 2); - unroll_w = unroll_w / 2; - } - } - } else { - unroll_w = jcp.ow; - unroll_w_trips = nstl::max(1, ow / unroll_w); - } - if (jcp.with_bias) { - Label skip_load_bias; - mov(reg_bias_baddr, - ptr[this->param1 + offsetof(jit_dw_conv_call_s, bias)]); - - zero_bias(); - - mov(reg_exec_flags, - ptr[this->param1 + offsetof(jit_dw_conv_call_s, exec_flags)]); - and_(reg_exec_flags, FLAG_ZERO_BIAS); - test(reg_exec_flags, reg_exec_flags); - jne(skip_load_bias); - - load_bias(); - - L(skip_load_bias); - compute_bias_loop(block_size); - - store_bias(); - } - - /* Pass filter address, then offset for h_padding. */ - compute_zero_filter(); - mov(reg_kh_offset, - ptr[this->param1 + offsetof(jit_dw_conv_call_s, filter_pad_off)]); - add(reg_filter_baddr, reg_kh_offset); - - /* compute left padded block */ - if (l_pad) { - compute_h_loop(unroll_w, l_pad, 0, 0); - add(reg_output_baddr, unroll_w * ch_offset * sizeof(float)); - add(reg_input_baddr, - unroll_w * jcp.stride_w * ch_offset * sizeof(float)); - unroll_w_trips--; - pad_offset = l_pad; - l_pad = 0; - } - - /* compute middle block */ - Label ow_blk_label; - - /* Insert loop for 'ow' block when middle block needs to execute more - * than once */ - bool do_ow_blk_loop = unroll_w_trips > 1; - if (do_ow_blk_loop) { - mov(iter_ow_blk, unroll_w_trips); - L(ow_blk_label); - } - if (unroll_w_trips > 0) { - compute_h_loop(unroll_w, l_pad, pad_offset, 0); - add(reg_output_baddr, unroll_w * ch_offset * sizeof(float)); - add(reg_input_baddr, - unroll_w * jcp.stride_w * ch_offset * sizeof(float)); - } - if (do_ow_blk_loop) { - dec(iter_ow_blk); - cmp(iter_ow_blk, 0); - jg(ow_blk_label, T_NEAR); - } - - /* compute right padded block */ - if (unroll_w_tail) { - compute_h_loop(unroll_w_tail, 0, pad_offset, jcp.ow - unroll_w_tail); - } -} - -template -void jit_uni_dw_conv_bwd_weights_kernel_f32::generate() { - preamble(); - - mov(reg_input_baddr, - ptr[this->param1 + offsetof(jit_dw_conv_call_s, input)]); - mov(reg_output_baddr, - ptr[this->param1 + offsetof(jit_dw_conv_call_s, output)]); - mov(reg_filter_baddr, - ptr[this->param1 + offsetof(jit_dw_conv_call_s, filter)]); - - compute_ow_block_unroll(); - - this->postamble(); -} - -template -status_t jit_uni_dw_conv_bwd_weights_kernel_f32::init_conf( - jit_conv_conf_t &jcp, const convolution_desc_t &cd, - const memory_desc_wrapper &src_d, - const memory_desc_wrapper &diff_weights_d, - const memory_desc_wrapper &diff_dst_d, int nthreads) { - if (!mayiuse(isa)) - return status::unimplemented; - - jcp.ngroups = diff_weights_d.dims()[0]; - jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups; - jcp.ic = src_d.dims()[1] / jcp.ngroups; - - const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1; - - jcp.is_depthwise = true && with_groups && everyone_is(1, jcp.oc, jcp.ic); - - if (!jcp.is_depthwise) - return status::unimplemented; - - jcp.ch_block = isa == avx512_common ? 
16 : 8; - - jcp.mb = src_d.dims()[0]; - - jcp.ih = src_d.dims()[2]; - jcp.iw = src_d.dims()[3]; - jcp.oh = diff_dst_d.dims()[2]; - jcp.ow = diff_dst_d.dims()[3]; - - jcp.kh = diff_weights_d.dims()[3]; - jcp.kw = diff_weights_d.dims()[4]; - - jcp.stride_h = cd.strides[0]; - jcp.stride_w = cd.strides[1]; - - jcp.t_pad = cd.padding[0][0]; - jcp.b_pad = cd.padding[1][0]; - - jcp.l_pad = cd.padding[0][1]; - jcp.r_pad = cd.padding[1][1]; - - jcp.dilate_h = cd.dilates[0]; - jcp.dilate_w = cd.dilates[1]; - - jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad; - jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad; - - jcp.with_bias = cd.diff_bias_desc.format_kind != format_kind::undef; - - auto dat_tag = isa == avx512_common ? nChw16c : nChw8c; - auto wei_tag = isa == avx512_common ? Goihw16g : Goihw8g; - - jcp.src_tag = src_d.matches_one_of_tag(dat_tag); - jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag); - jcp.dst_tag = diff_dst_d.matches_one_of_tag(dat_tag); - - bool args_ok = true - && jcp.src_tag == dat_tag - && jcp.wei_tag == wei_tag - && jcp.dst_tag == dat_tag - && jcp.ngroups % jcp.ch_block == 0 && jcp.dilate_h == 0 - && jcp.dilate_w == 0 && jcp.kw <= 3 - && jcp.oh == (jcp.ihp - jcp.kh) / jcp.stride_h + 1 - && jcp.ow == (jcp.iwp - jcp.kw) / jcp.stride_w + 1; - if (!args_ok) - return status::unimplemented; - - jcp.nb_ch = jcp.ngroups / jcp.ch_block; - - /* kernel applicability check wrt boundaries - * the conditions are quite general across the kernels we have, - * but ideally the check should belong to a specific kernel... */ - const int max_hpad = (jcp.kh - 1 + 1) / 2; - const int max_wpad = (jcp.kw - 1 + 1) / 2; - const bool boundaries_ok = true && jcp.t_pad <= max_hpad - && jcp.b_pad <= max_hpad && jcp.l_pad <= max_wpad - && jcp.r_pad <= max_wpad; - if (!boundaries_ok) - return status::unimplemented; - - balance(jcp, nthreads); - - return status::success; -} - -template -void jit_uni_dw_conv_bwd_weights_kernel_f32::init_scratchpad( - memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { - /* Notes: if splitting thread work on 'mb', then a reduction has to take - * place. Hence, book a per-thread, local weights-buffer for the - * reduction */ - if (jcp.nthr_mb > 1) { - const size_t wei_size = jcp.ngroups * jcp.kh * jcp.kw; - scratchpad.book(key_conv_wei_reduction, - sizeof(float) * wei_size * (jcp.nthr_mb - 1)); - - if (jcp.with_bias) - scratchpad.book(key_conv_bia_reduction, - sizeof(float) * jcp.ngroups * (jcp.nthr_mb - 1)); - } -} - -template -void jit_uni_dw_conv_bwd_weights_kernel_f32::balance(jit_conv_conf_t &jcp, - int nthreads) { - jcp.nthr = nthreads; - jcp.nthr_g = jcp.nthr_mb = 1; - - /* Basic-Heuristics for parallel strategy: - * 1) Tries to parallel on the number of Groups (g) where tasks are - * independent. Otherwise, - * 2) Tries to split the work across g and MiniBatch (mb). - * Parallelizing on mb requires computing a reduction for weights. - * - * NOTE: because of 'task partitioning' scheme, there will be unbalanced - * per-thread load when the number of threads is high (e.g. > 16). 
- */ - jcp.nthr_g = nstl::min(jcp.nb_ch, jcp.nthr); - jcp.nthr_mb = nstl::min(nstl::max(1, jcp.nthr / jcp.nthr_g), jcp.mb); - - jcp.nthr = jcp.nthr_g * jcp.nthr_mb; -} - -template struct jit_uni_dw_conv_bwd_weights_kernel_f32; -template struct jit_uni_dw_conv_bwd_weights_kernel_f32; -template struct jit_uni_dw_conv_bwd_weights_kernel_f32; - -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.hpp deleted file mode 100644 index 9c08fc4a0..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.hpp +++ /dev/null @@ -1,253 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef JIT_UNI_DW_CONV_KERNEL_F32_HPP -#define JIT_UNI_DW_CONV_KERNEL_F32_HPP - -#include "c_types_map.hpp" -#include "memory_tracking.hpp" - -#include "jit_generator.hpp" -#include "jit_primitive_conf.hpp" -#include "jit_uni_eltwise.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -template -struct jit_uni_dw_conv_fwd_kernel_f32: public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_dw_conv_fwd_kernel_f32) - - jit_uni_dw_conv_fwd_kernel_f32(jit_conv_conf_t ajcp) - : jcp(ajcp), eltwise_injector_(nullptr) - { - if (jcp.with_eltwise) - eltwise_injector_ = new jit_uni_eltwise_injector_f32(this, - jcp.eltwise); - - this->generate(); - jit_ker = (void (*)(jit_conv_call_s *))this->getCode(); - } - - ~jit_uni_dw_conv_fwd_kernel_f32() { - delete eltwise_injector_; - } - - static bool post_ops_ok(jit_conv_conf_t &jcp, - const primitive_attr_t &attr); - static status_t init_conf(jit_conv_conf_t &jcp, - const convolution_desc_t &cd, const memory_desc_wrapper &src_d, - const memory_desc_wrapper &weights_d, - const memory_desc_wrapper &dst_d, const primitive_attr_t &attr); - - static void init_scratchpad(memory_tracking::registrar_t &scratchpad, - const jit_conv_conf_t &jcp); - - jit_conv_conf_t jcp; - void (*jit_ker)(jit_conv_call_s *); - -private: - using Vmm = typename utils::conditional3::type; - using reg64_t = const Xbyak::Reg64; - const Xbyak::AddressFrame &vmmword = (isa == sse42) - ? xword : (isa == avx2) ? 
yword : zword; - const int vlen = cpu_isa_traits::vlen; - - // dw convolution - reg64_t reg_input = r8; - reg64_t aux_reg_input = r9; - reg64_t aux1_reg_input = r10; - reg64_t reg_kernel = r11; - reg64_t aux_reg_kernel = r12; - reg64_t aux1_reg_kernel = r13; - reg64_t reg_output = r14; - reg64_t reg_bias = r15; - reg64_t reg_kh = rax; - reg64_t reg_kw = rbx; - reg64_t iter_kh = rdx; - reg64_t iter_kw = rsi; - reg64_t reg_ur_w = rbp; - reg64_t reg_ch_blocks = aux1_reg_input; - reg64_t imm_addr64 = aux1_reg_input; - - inline Vmm get_ker_reg(int idx) { return Vmm(idx + 0); } - inline Vmm get_src_reg(int idx) { return Vmm(idx + 1); } - inline Vmm get_acc_reg(int idx) { return Vmm(idx + 4); } - - inline void load_src(int ur_ch_blocks, int ur_w); - inline void apply_filter(int ur_ch_blocks, int ur_w); - inline void apply_filter_unrolled(int ur_ch_blocks, int ur_w); - inline void apply_activation(int ur_ch_blocks, int ur_w); - inline void store_dst(int ur_ch_blocks, int ur_w); - inline void loop_body(int ur_ch_blocks); - - jit_uni_eltwise_injector_f32 *eltwise_injector_; - - void generate(); -}; - -template -struct jit_uni_dw_conv_bwd_data_kernel_f32: public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_dw_conv_bwd_data_kernel_f32) - - jit_uni_dw_conv_bwd_data_kernel_f32(jit_conv_conf_t ajcp): jcp(ajcp) { - this->generate(); - jit_ker = (void (*)(jit_conv_call_s *))this->getCode(); - } - - static status_t init_conf(jit_conv_conf_t &jcp, - const convolution_desc_t &cd, - const memory_desc_wrapper &diff_src_d, - const memory_desc_wrapper &weights_d, - const memory_desc_wrapper &diff_dst_d); - - static void init_scratchpad(memory_tracking::registrar_t &scratchpad, - const jit_conv_conf_t &jcp); - - jit_conv_conf_t jcp; - void (*jit_ker)(jit_conv_call_s *); - -private: - using Vmm = typename utils::conditional3::type; - using reg64_t = const Xbyak::Reg64; - - inline Vmm get_ker_reg(int idx) { return Vmm(idx + 0); } - inline Vmm get_src_reg(int idx) { return Vmm(idx + 1); } - inline Vmm get_acc_reg(int idx) { return Vmm(idx + 4); } - - reg64_t reg_ddst = rax; - reg64_t aux_reg_ddst = r8; - reg64_t aux1_reg_ddst = abi_not_param1; - reg64_t reg_kernel = rdx; - reg64_t aux_reg_kernel = r10; - reg64_t aux1_reg_kernel = rbp; - reg64_t reg_dsrc = rsi; - - reg64_t reg_ur_str_w = r9; - reg64_t reg_ch_blocks = rbx; - - reg64_t iter_kh = r11; - reg64_t iter_kw = r12; - reg64_t reg_kh = r13; - reg64_t reg_kw = r14; - - inline void loop_body(int ur_ch_blocks); - inline void load_ddst(int ur_ch_blocks, int ur_str_w); - inline void apply_filter(int ur_ch_blocks, int ur_str_w); - inline void store_dsrc(int ur_ch_blocks, int ur_str_w); - - void generate(); -}; - -template -struct jit_uni_dw_conv_bwd_weights_kernel_f32 : public jit_generator { - - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_dw_conv_bwd_weights_kernel_f32) - - jit_uni_dw_conv_bwd_weights_kernel_f32(jit_conv_conf_t ajcp) : jcp(ajcp) { - this->generate(); - jit_ker = (void (*)(jit_dw_conv_call_s *)) this->getCode(); - } - - static status_t init_conf(jit_conv_conf_t &jcp, - const convolution_desc_t &cd, const memory_desc_wrapper &src_d, - const memory_desc_wrapper &diff_weights_d, - const memory_desc_wrapper &diff_dst_d, int nthreads); - - static void init_scratchpad(memory_tracking::registrar_t &scratchpad, - const jit_conv_conf_t &jcp); - - static void balance(jit_conv_conf_t &jcp, int nthreads); - - jit_conv_conf_t jcp; - void (*jit_ker)(jit_dw_conv_call_s *); - -private: - using Vmm = typename utils::conditional3::type; - using reg64_t = const 
Xbyak::Reg64; - const int simd_w = cpu_isa_traits::vlen / sizeof(float); - const int reg_repeats = (isa == sse42) ? 2 : 1; - - const Xbyak::AddressFrame &vmmword - = (isa == sse42) ? xword : (isa == avx2) ? yword : zword; - - /* XXX: offset between input and accummulators is 3, therefore, assume 'kw' - * is no larger than 3*/ - inline Vmm get_bias_reg(int idx = 0) { return Vmm(idx); } - inline Vmm get_output_reg(int idx) { return Vmm(idx + 1); } - inline Vmm get_input_reg(int idx) { return Vmm(idx + 4 * reg_repeats + 1); } - inline Vmm get_acc_reg(int idx) { return Vmm(idx + 1 * reg_repeats + 1); } - inline Vmm get_aux_reg() { return Vmm(0); } - - reg64_t reg_tmp_input = r9; - reg64_t reg_tmp_output = r10; - reg64_t reg_tmp_filter = r13; - reg64_t reg_kh_offset = rax; - - /* parameter passed by driver into kernel */ - Xbyak::Reg8 reg_exec_flags = bl; - - reg64_t reg_oh_worksize = r14; - reg64_t reg_oh = rax; - - reg64_t iter_ow_blk = r11; - - reg64_t reg_kh = rsi; - reg64_t reg_kh_count = rdx; - - /* Base addresses for convolution parameters. */ - reg64_t reg_input_baddr = r15; - reg64_t reg_output_baddr = r12; - reg64_t reg_filter_baddr = abi_not_param1; - reg64_t reg_bias_baddr = r13; - - /* Micro-kernel JIT'ing, fusing 'kw' and 'ow_block' loops into unrolled FMAs - */ - inline void compute_ow_step_unroll( - int unroll_w, int l_pad, int pad_offset, int ow_block); - - /* JIT'ing the outer loops for the micro-kernel -> {kh, oh_block} */ - inline void compute_h_step( - int unroll_w, int l_pad, int pad_offset, int ow_block); - inline void compute_h_loop( - int unroll_w, int l_pad, int pad_offset, int ow_block); - - /* Write 'width' micro-kernel JITs; depending on the padding and convolution - * size, write a micro-kernel for the left ow-block, middle ow-block(s), and - * right ow-block.*/ - inline void compute_ow_block_unroll(); - - inline void compute_zero_filter(); - inline void load_filter(); - inline void zero_filter(); - inline void load_bias(); - inline void zero_bias(); - inline void compute_bias_step_unroll(const int unroll_w); - inline void compute_bias_loop(const int block_size); - inline void store_filter(); - inline void store_bias(); - - void generate(); -}; -} -} -} - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_convolution.cpp deleted file mode 100644 index 58449601a..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_convolution.cpp +++ /dev/null @@ -1,427 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
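Editorial aside: the header just deleted maps a (repeat, channel-block, output-pixel) triple onto a flat accumulator-register index (`get_acc_reg(i*ur_ch_blocks*ur_w + ch*ur_w + ow)`), and the forward kernel fills those registers with a depthwise convolution over a channel-blocked layout. The sketch below is a plain scalar reference of that computation, not part of the deleted sources; the names, the fixed block size `CB`, and the omission of padding and dilation are all simplifying assumptions.

#include <vector>
#include <cstdio>

static const int CB = 8; // channel block size (simd_w in the deleted code)

// offset into data laid out as [ch_blocks][H][W][CB]
static size_t off(int cb, int h, int w, int c, int H, int W) {
    return (((size_t)cb * H + h) * W + w) * CB + c;
}

// dst[cb][oh][ow][c] = bias[cb][c]
//     + sum_{kh,kw} src[cb][oh*sh+kh][ow*sw+kw][c] * wei[cb][kh][kw][c]
void dw_conv_ref(const std::vector<float> &src, const std::vector<float> &wei,
        const std::vector<float> &bias, std::vector<float> &dst,
        int ch_blocks, int IH, int IW, int OH, int OW, int KH, int KW,
        int sh, int sw) {
    for (int cb = 0; cb < ch_blocks; ++cb)
        for (int oh = 0; oh < OH; ++oh)
            for (int ow = 0; ow < OW; ++ow)
                for (int c = 0; c < CB; ++c) {
                    // load_src(): seed the accumulator with the bias (or zero)
                    float acc = bias[cb * CB + c];
                    for (int kh = 0; kh < KH; ++kh)
                        for (int kw = 0; kw < KW; ++kw)
                            // apply_filter(): one multiply-add per tap, same
                            // channel in and out (depthwise convolution)
                            acc += src[off(cb, oh * sh + kh, ow * sw + kw, c, IH, IW)]
                                    * wei[(((size_t)cb * KH + kh) * KW + kw) * CB + c];
                    dst[off(cb, oh, ow, c, OH, OW)] = acc; // store_dst()
                }
}

int main() {
    // 1 channel block, 4x4 input, 3x3 kernel, stride 1 -> 2x2 output
    int IH = 4, IW = 4, KH = 3, KW = 3, OH = 2, OW = 2;
    std::vector<float> src(CB * IH * IW, 1.f), wei(CB * KH * KW, 1.f),
            bias(CB, 0.f), dst(CB * OH * OW, 0.f);
    dw_conv_ref(src, wei, bias, dst, 1, IH, IW, OH, OW, KH, KW, 1, 1);
    std::printf("dst[0] = %g (expect 9)\n", dst[0]); // 3x3 window of ones
    return 0;
}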
-*******************************************************************************/ - -#include "c_types_map.hpp" -#include "memory_tracking.hpp" -#include "mkldnn_thread.hpp" - -#include "jit_uni_dw_convolution.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace mkldnn::impl::status; -using namespace mkldnn::impl::memory_tracking::names; -using namespace mkldnn::impl::utils; - -template -void _jit_uni_dw_convolution_fwd_t::execute_forward( - const exec_ctx_t &ctx) const { - auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); - auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); - auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS); - auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); - - const memory_desc_wrapper src_d(pd()->src_md()); - const memory_desc_wrapper dst_d(pd()->dst_md()); - const memory_desc_wrapper weights_d(pd()->weights_md(0)); - const memory_desc_wrapper bias_d(pd()->weights_md(1)); - - const auto &jcp = kernel_->jcp; - - if (pd()->wants_padded_bias()) { - auto padded_bias = this->scratchpad(ctx).template get( - key_conv_padded_bias); - utils::array_copy(padded_bias, bias, jcp.oc_without_padding); - utils::array_set(padded_bias + jcp.oc_without_padding, 0.f, - jcp.oc - jcp.oc_without_padding); - bias = padded_bias; - } - - int dil_h = jcp.dilate_h + 1; - int dil_w = jcp.dilate_w + 1; - int str_h = jcp.stride_h; - int str_w = jcp.stride_w; - - auto kernel_params = [&](int ur_w_step, int ow, int oh, int ih, int kh, - int kh_padding, int ch, int ch_num, int n) { - auto par_conv = jit_conv_call_s(); - - const int i_l_overflow = nstl::max(0, (jcp.l_pad - ow * str_w)); - const int i_r_overflow = nstl::max(jcp.iw, (ow * str_w - + (jcp.kw - 1)*dil_w - jcp.l_pad + 1)) - jcp.iw; - - const int iw = nstl::max((ow*str_w - jcp.l_pad - + div_up(i_l_overflow, dil_w)*dil_w), 0); - const int kw = div_up(i_l_overflow, dil_w); - - const int kw_padding = jcp.kw - div_up(i_l_overflow, dil_w) - - div_up(i_r_overflow, dil_w); - - par_conv.src = &src[src_d.blk_off(n, ch, ih, iw)]; - par_conv.dst = &dst[dst_d.blk_off(n, ch, oh, ow)]; - - par_conv.filt = &weights[weights_d.blk_off(ch, 0, 0, kh, kw)]; - if (bias) par_conv.bias = &bias[bias_d.blk_off(ch*jcp.ch_block)]; - - par_conv.kh_padding = (size_t)nstl::max(0, kh_padding); - par_conv.kw_padding = (size_t)nstl::max(0, kw_padding); - - par_conv.ur_w = (size_t)ur_w_step; - - par_conv.ch_blocks = nstl::min(ch + ch_num, jcp.nb_ch) - ch; - - return par_conv; - }; - - const int chb_work = utils::div_up(jcp.nb_ch, jcp.nb_ch_blocking); - parallel_nd(jcp.mb, chb_work, jcp.oh, - [&](int n, int chb, int oh) { - int ch = chb * jcp.nb_ch_blocking; - int ch_num = jcp.nb_ch_blocking; - - const int i_t_overflow = nstl::max(0, (int)(jcp.t_pad - oh*str_h)); - const int i_b_overflow = nstl::max(jcp.ih, - (int)(oh*str_h + (jcp.kh - 1)*dil_h - jcp.t_pad + 1)) - jcp.ih; - - const int ih = nstl::max((int)(oh*str_h - jcp.t_pad - + div_up(i_t_overflow, dil_h)*dil_h), 0); - const int kh = div_up(i_t_overflow, dil_h); - const int kh_padding = jcp.kh - div_up(i_t_overflow, dil_h) - - div_up(i_b_overflow, dil_h); - - // left border - int ow = 0; - int l_border = nstl::min(div_up(jcp.l_pad, str_w), jcp.ow); - int ur_w_step = 1; - for (; ow < l_border; ow++) { - jit_conv_call_s par_conv = kernel_params(ur_w_step, ow, oh, ih, - kh, kh_padding, ch, ch_num, n); - - kernel_->jit_ker(&par_conv); - } - - // main loop - ur_w_step = (jcp.iw - (jcp.kw - 1)*dil_w + jcp.l_pad - 1) - / jcp.stride_w - ow + 1; - if (ur_w_step > 0) { - 
jit_conv_call_s par_conv = kernel_params(ur_w_step, ow, oh, ih, - kh, kh_padding, ch, ch_num, n); - - kernel_->jit_ker(&par_conv); - - ow += ur_w_step; - } - - // right border - ur_w_step = 1; - for (; ow < jcp.ow; ow++) { - jit_conv_call_s par_conv = kernel_params(ur_w_step, ow, oh, ih, - kh, kh_padding, ch, ch_num, n); - - kernel_->jit_ker(&par_conv); - } - }); - - if (pd()->wants_zero_pad_dst()) - ctx.memory(MKLDNN_ARG_DST)->zero_pad(); -} - -template struct _jit_uni_dw_convolution_fwd_t; -template struct _jit_uni_dw_convolution_fwd_t; -template struct _jit_uni_dw_convolution_fwd_t; - -template -void _jit_uni_dw_convolution_bwd_data_t::execute_backward_data( - const exec_ctx_t &ctx) const { - auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); - auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); - auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); - - const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); - const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); - const memory_desc_wrapper weights_d(pd()->weights_md(0)); - - const auto &jcp = kernel_->jcp; - - auto kernel_params = [&](int ur_str_w, int iw, int oh, int ih, - int i_t_overflow, int i_b_overflow, int stride_off_h, - int ch, int ch_num, int n) { - auto par_conv = jit_conv_call_s(); - - const int i_l_overflow = nstl::max(0, (jcp.kw - 1 - iw - jcp.l_pad)); - const int i_r_overflow = nstl::max(0, (jcp.kw - 1 - (jcp.iw - 1 - iw) - - jcp.r_pad)); - - int ow = iw + jcp.l_pad - i_r_overflow; - int stride_off_w = ow % jcp.stride_w; - ow /= jcp.stride_w; - - par_conv.src = &diff_src[diff_src_d.blk_off(n, ch, ih, iw)]; - par_conv.dst = &diff_dst[diff_dst_d.blk_off(n, ch, oh, ow)]; - par_conv.filt = &weights[weights_d.blk_off(ch, 0, 0, i_b_overflow - + stride_off_h, i_r_overflow + stride_off_w)]; - - par_conv.kh_padding = nstl::max(0, jcp.kh - i_t_overflow - i_b_overflow - - stride_off_h); - par_conv.kw_padding = nstl::max(0, jcp.kw - i_l_overflow - i_r_overflow - - stride_off_w); - - par_conv.ur_str_w = ur_str_w; - - par_conv.ch_blocks = nstl::min(ch + ch_num, jcp.nb_ch) - ch; - - return par_conv; - }; - - const int chb_work = utils::div_up(jcp.nb_ch, jcp.nb_ch_blocking); - parallel_nd(jcp.mb, chb_work, jcp.ih, - [&](int n, int chb, int ih) { - int ch = chb * jcp.nb_ch_blocking; - int ch_num = jcp.nb_ch_blocking; - - const int i_t_overflow = nstl::max(0, (int)(jcp.kh - 1 - ih - - jcp.t_pad)); - const int i_b_overflow = nstl::max(0, (int)(jcp.kh - 1 - - (jcp.ih - 1 - ih) - jcp.b_pad)); - - int oh = ih + jcp.t_pad - i_b_overflow; - int stride_off_h = oh % jcp.stride_h; - oh /= jcp.stride_h; - - for (int i_str_w = 0; i_str_w < jcp.stride_w; i_str_w++) { - // left border - int iw = i_str_w; - int l_border = nstl::min(jcp.kw - 1 - jcp.l_pad, jcp.iw); - int ur_str_w = 1; - for (; iw < l_border; iw += jcp.stride_w) { - jit_conv_call_s par_conv = kernel_params(ur_str_w, iw, oh, - ih, i_t_overflow, i_b_overflow, - stride_off_h, ch, ch_num, n); - - kernel_->jit_ker(&par_conv); - } - - // main loop - ur_str_w = nstl::min((jcp.iw - jcp.kw + jcp.r_pad - iw) - / jcp.stride_w, jcp.iw); - if (ur_str_w > 0) { - jit_conv_call_s par_conv = kernel_params(ur_str_w, iw, oh, - ih, i_t_overflow, i_b_overflow, - stride_off_h, ch, ch_num, n); - - kernel_->jit_ker(&par_conv); - - iw += ur_str_w * jcp.stride_w; - } - - // right border - ur_str_w = 1; - for (; iw < jcp.iw; iw += jcp.stride_w) { - jit_conv_call_s par_conv = kernel_params(ur_str_w, iw, oh, - ih, i_t_overflow, i_b_overflow, - stride_off_h, ch, ch_num, n); - - 
kernel_->jit_ker(&par_conv); - } - } - }); -} - -template struct _jit_uni_dw_convolution_bwd_data_t; -template struct _jit_uni_dw_convolution_bwd_data_t; -template struct _jit_uni_dw_convolution_bwd_data_t; - -template -_jit_uni_dw_convolution_bwd_weights_t:: -_jit_uni_dw_convolution_bwd_weights_t(const pd_t *apd) - : cpu_primitive_t(apd) - , kernel_(nullptr), acc_ker_(nullptr) -{ - kernel_ = new jit_uni_dw_conv_bwd_weights_kernel_f32(pd()->jcp_); - if (pd()->jcp_.nthr_mb > 1 && do_parallel_reduction()) - acc_ker_ = new cpu_accumulator_1d_t(); -} - -template -void _jit_uni_dw_convolution_bwd_weights_t::execute_backward_weights( - const exec_ctx_t &ctx) const { - auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); - auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); - auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS); - auto diff_bias = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS); - - auto diff_wei_reduction_buf = - scratchpad(ctx).template get(key_conv_wei_reduction); - auto diff_bia_reduction_buf = - scratchpad(ctx).template get(key_conv_bia_reduction); - - const auto &jcp = kernel_->jcp; - - /* Used when executing a parallel reduction */ - simple_barrier::ctx_t reduction_bctx; - simple_barrier::ctx_init(&reduction_bctx); - - const size_t wei_size = jcp.ngroups * jcp.kh * jcp.kw; - const size_t bias_size = jcp.with_bias ? jcp.ngroups : 0; - - const int ch_block = jcp.ch_block; - - auto set_kernel_params = [&](jit_dw_conv_call_s *conv_params, - const int batch, const int group, const int oh_start, - const int work_size, const unsigned char exec_flag, - const size_t kh_padding, const size_t filter_off) { - const int tpad_underflow_off = jcp.t_pad - filter_off; - - conv_params->exec_flags = exec_flag; - conv_params->kh_count = jcp.kh - kh_padding; - - const int oh_s = oh_start; - const int oh_e = oh_start + work_size; - const int ih_s = oh_s * jcp.stride_h; - - conv_params->filter_pad_off - = filter_off * jcp.kw * ch_block * sizeof(float); - conv_params->oh_index = oh_s; - conv_params->oh_count = oh_e; - - size_t diff_dst_off - = ((batch * (jcp.ngroups / ch_block) + group) * jcp.oh - + oh_start) - * jcp.ow; - - size_t src_off = ((batch * (jcp.ngroups / ch_block) + group) * jcp.ih - + ih_s - tpad_underflow_off) * jcp.iw; - - conv_params->output = &diff_dst[diff_dst_off * ch_block]; - conv_params->input = &src[src_off * ch_block]; - }; - - parallel(jcp.nthr, [&](const int ithr, const int nthr) { - assert(nthr == jcp.nthr); - - auto conv_params = jit_dw_conv_call_s(); - const int h_block_size = 15; - - /* assign iteration space to thread */ - const int ithr_g = ithr % jcp.nthr_g; - const int ithr_mb = (ithr / jcp.nthr_g) % jcp.nthr_mb; - - /* split dimensions */ - int g_start{ 0 }, g_end{ 0 }; - balance211(jcp.nb_ch, jcp.nthr_g, ithr_g, g_start, g_end); - - int mb_start{ 0 }, mb_end{ 0 }; - balance211(jcp.mb, jcp.nthr_mb, ithr_mb, mb_start, mb_end); - - auto diff_wei = ithr_mb == 0 - ? diff_weights : diff_wei_reduction_buf + (ithr_mb - 1) * wei_size; - auto diff_bia = ithr_mb == 0 - ? diff_bias : diff_bia_reduction_buf + (ithr_mb - 1) * bias_size; - - for (int g = g_start; g < g_end; ++g) { - unsigned char zero_filter_flag = FLAG_ZERO_FILTER; - unsigned char zero_bias_flag = jcp.with_bias ? 
FLAG_ZERO_BIAS : 0; - - size_t diff_wei_off = g * jcp.kh * jcp.kw; - conv_params.filter = &diff_wei[diff_wei_off * ch_block]; - - if (jcp.with_bias) - conv_params.bias = &diff_bia[g * ch_block]; - - for (int mb = mb_start; mb < mb_end; ++mb) { - int oh = 0; - while (oh < jcp.oh) { - const int h_work = nstl::min(h_block_size, jcp.oh - oh); - auto kh_t_padding = nstl::max(0, jcp.t_pad - oh); - auto kh_b_padding - = (oh * jcp.stride_h + jcp.kh - 1 > jcp.ih) ? - jcp.b_pad - (h_work - 1) : - 0; - - set_kernel_params(&conv_params, mb, g, oh, h_work, - zero_filter_flag | zero_bias_flag, - kh_t_padding + kh_b_padding, kh_t_padding); - kernel_->jit_ker(&conv_params); - - zero_bias_flag &= ~FLAG_ZERO_BIAS; - zero_filter_flag &= ~FLAG_ZERO_FILTER; - oh += h_work; - } - } - } - - if (do_parallel_reduction() && jcp.nthr_mb > 1) { - size_t reduct_start{ 0 }, reduct_end{ 0 }; - balance211(wei_size, nthr, ithr, reduct_start, reduct_end); - - const int acc_size = reduct_end - reduct_start; - const size_t reduct_off = reduct_start; - auto *acc_data = diff_weights + reduct_off; - - simple_barrier::barrier(&reduction_bctx, nthr); - - for (int thr_mb = 1; thr_mb < jcp.nthr_mb; ++thr_mb) { - auto *src_data = diff_wei_reduction_buf - + (thr_mb - 1) * wei_size + reduct_off; - acc_ker_->accumulate(acc_data, src_data, acc_size); - } - } - }); - - if (jcp.nthr_mb <= 1) return; - - /* Apply single-threaded 'mb' reduction */ - for (int thr_mb = 1; thr_mb < jcp.nthr_mb; ++thr_mb) { - size_t mb_accum_offset = (thr_mb - 1) * wei_size; - size_t b_accum_offset = (thr_mb - 1) * bias_size; - - for (int g = 0; g < jcp.nb_ch; ++g) { - /* Reduction on Bias */ - if (jcp.with_bias) { - PRAGMA_OMP_SIMD() - for (int g_block = 0; g_block < ch_block; ++g_block) { - size_t bias_offset = g * ch_block + g_block; - diff_bias[bias_offset] += diff_bia_reduction_buf[ - b_accum_offset + bias_offset]; - } - } - - if (do_parallel_reduction()) continue; - - for (int kh = 0; kh < jcp.kh; ++kh) - for (int kw = 0; kw < jcp.kw; ++kw) - { - size_t wei_offset = (g * jcp.kh + kh) * jcp.kw + kw; - PRAGMA_OMP_SIMD() - for (int g_block = 0; g_block < ch_block; ++g_block) { - const size_t off = wei_offset * ch_block + g_block; - diff_weights[off] += - diff_wei_reduction_buf[mb_accum_offset + off]; - } - } - } - } -} - -template struct _jit_uni_dw_convolution_bwd_weights_t; -template struct _jit_uni_dw_convolution_bwd_weights_t; -template struct _jit_uni_dw_convolution_bwd_weights_t; - -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_convolution.hpp deleted file mode 100644 index ca53749ec..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_convolution.hpp +++ /dev/null @@ -1,266 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
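Editorial aside: the backward-weights driver just deleted splits its work over groups and minibatch; when more than one thread shares the minibatch dimension, thread 0 accumulates straight into `diff_weights` while every other thread writes a private scratch slice (`key_conv_wei_reduction`), and the slices are summed afterwards. The sketch below reproduces only that split-then-reduce scheme, collapsing the group dimension and the actual gradient math for brevity; it is not part of the deleted sources and all names in it are illustrative.

#include <thread>
#include <vector>
#include <cstdio>

int main() {
    const int nthr_mb = 4;      // threads splitting the minibatch dimension
    const int mb = 16;          // number of images
    const size_t wei_size = 8;  // flattened weights size (ngroups*kh*kw*ch_block)

    std::vector<float> diff_weights(wei_size, 0.f);
    // scratchpad analogue of the per-thread reduction buffer:
    // (nthr_mb - 1) extra copies of the weights
    std::vector<float> reduction_buf(wei_size * (nthr_mb - 1), 0.f);

    auto worker = [&](int ithr) {
        float *wei = ithr == 0 ? diff_weights.data()
                               : reduction_buf.data() + (ithr - 1) * wei_size;
        // each thread accumulates the contribution of its slice of the
        // minibatch; here every image just adds 1 to every weight as a stand-in
        for (int img = ithr; img < mb; img += nthr_mb)
            for (size_t w = 0; w < wei_size; ++w)
                wei[w] += 1.f;
    };

    std::vector<std::thread> pool;
    for (int t = 0; t < nthr_mb; ++t) pool.emplace_back(worker, t);
    for (auto &t : pool) t.join();

    // single-threaded reduction, as in the tail of execute_backward_weights()
    for (int t = 1; t < nthr_mb; ++t)
        for (size_t w = 0; w < wei_size; ++w)
            diff_weights[w] += reduction_buf[(t - 1) * wei_size + w];

    std::printf("diff_weights[0] = %g (expect %d)\n", diff_weights[0], mb);
    return 0;
}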
-*******************************************************************************/ - -#ifndef CPU_JIT_UNI_DW_CONVOLUTION_HPP -#define CPU_JIT_UNI_DW_CONVOLUTION_HPP - -#include "c_types_map.hpp" -#include "memory_tracking.hpp" - -#include "cpu_barrier.hpp" -#include "cpu_convolution_pd.hpp" -#include "cpu_primitive.hpp" -#include "cpu_reducer.hpp" - -#include "jit_uni_dw_conv_kernel_f32.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -template -struct _jit_uni_dw_convolution_fwd_t: public cpu_primitive_t { - struct pd_t: public cpu_convolution_fwd_pd_t { - pd_t(engine_t *engine, const convolution_desc_t *adesc, - const primitive_attr_t *attr, - const typename pd_t::base_class *hint_fwd_pd) - : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) - , jcp_() {} - - DECLARE_COMMON_PD_T( - JIT_IMPL_NAME_HELPER("jit_dw:", isa, ""), - _jit_uni_dw_convolution_fwd_t); - - status_t init() { - bool ok = true - && is_fwd() - && set_default_alg_kind(alg_kind::convolution_direct) - && expect_data_types(data_type::f32, data_type::f32, - data_type::f32, data_type::f32, data_type::f32) - && !has_zero_dim_memory() - && set_default_formats(); - if (!ok) return status::unimplemented; - - status_t status = jit_uni_dw_conv_fwd_kernel_f32::init_conf( - jcp_, *desc(), src_md(), *weights_md(), *dst_md(), *attr()); - if (status != status::success) return status; - - auto scratchpad = scratchpad_registry().registrar(); - jit_uni_dw_conv_fwd_kernel_f32::init_scratchpad(scratchpad, - jcp_); - - return status::success; - } - - jit_conv_conf_t jcp_; - - protected: - bool set_default_formats() { - using namespace format_tag; - - auto dat_tag = isa == avx512_common ? nChw16c : nChw8c; - auto wei_tag = isa == avx512_common ? Goihw16g : Goihw8g; - - return set_default_formats_common(dat_tag, wei_tag, dat_tag); - } - }; - - _jit_uni_dw_convolution_fwd_t(const pd_t *apd): cpu_primitive_t(apd) - { kernel_ = new jit_uni_dw_conv_fwd_kernel_f32(pd()->jcp_); } - - ~_jit_uni_dw_convolution_fwd_t() { delete kernel_; } - - typedef typename prec_traits::type data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_forward(ctx); - return status::success; - } - -private: - void execute_forward(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - - jit_uni_dw_conv_fwd_kernel_f32 *kernel_; -}; - -using jit_avx512_common_dw_convolution_fwd_t = - _jit_uni_dw_convolution_fwd_t; -using jit_avx2_dw_convolution_fwd_t = _jit_uni_dw_convolution_fwd_t; -using jit_sse42_dw_convolution_fwd_t = _jit_uni_dw_convolution_fwd_t; - -template -struct _jit_uni_dw_convolution_bwd_data_t: public cpu_primitive_t { - struct pd_t: public cpu_convolution_bwd_data_pd_t { - pd_t(engine_t *engine, - const convolution_desc_t *adesc, - const primitive_attr_t *attr, - const convolution_fwd_pd_t *hint_fwd_pd) - : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) - , jcp_() - {} - - DECLARE_COMMON_PD_T( - JIT_IMPL_NAME_HELPER("jit_dw:", isa, ""), - _jit_uni_dw_convolution_bwd_data_t); - - status_t init() { - bool ok = true - && desc()->prop_kind == prop_kind::backward_data - && set_default_alg_kind(alg_kind::convolution_direct) - && expect_data_types(data_type::f32, data_type::f32, - data_type::undef, data_type::f32, data_type::f32) - && !has_zero_dim_memory() - && set_default_formats(); - - if (!ok) return status::unimplemented; - - status_t status = jit_uni_dw_conv_bwd_data_kernel_f32:: - init_conf(jcp_, *desc(), *diff_src_md(), *weights_md(), - 
*diff_dst_md()); - if (status != status::success) return status; - - auto scratchpad = scratchpad_registry().registrar(); - jit_uni_dw_conv_bwd_data_kernel_f32::init_scratchpad( - scratchpad, jcp_); - - return status::success; - } - - jit_conv_conf_t jcp_; - - protected: - bool set_default_formats() { - using namespace format_tag; - - auto dat_tag = isa == avx512_common ? nChw16c : nChw8c; - auto wei_tag = isa == avx512_common ? Goihw16g : Goihw8g; - - return set_default_formats_common(dat_tag, wei_tag, dat_tag); - } - }; - - _jit_uni_dw_convolution_bwd_data_t(const pd_t *apd): cpu_primitive_t(apd) - { kernel_ = new jit_uni_dw_conv_bwd_data_kernel_f32(pd()->jcp_); } - ~_jit_uni_dw_convolution_bwd_data_t() { delete kernel_; }; - - typedef typename prec_traits::type data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_backward_data(ctx); - return status::success; - } - -private: - void execute_backward_data(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - - jit_uni_dw_conv_bwd_data_kernel_f32 *kernel_; -}; - -using jit_avx512_common_dw_convolution_bwd_data_t = - _jit_uni_dw_convolution_bwd_data_t; -using jit_avx2_dw_convolution_bwd_data_t = - _jit_uni_dw_convolution_bwd_data_t; -using jit_sse42_dw_convolution_bwd_data_t = - _jit_uni_dw_convolution_bwd_data_t; - -template -struct _jit_uni_dw_convolution_bwd_weights_t: public cpu_primitive_t { - struct pd_t: public cpu_convolution_bwd_weights_pd_t { - pd_t(engine_t *engine, - const convolution_desc_t *adesc, - const primitive_attr_t *attr, - const convolution_fwd_pd_t *hint_fwd_pd) - : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd) - , jcp_() {} - - DECLARE_COMMON_PD_T( - JIT_IMPL_NAME_HELPER("jit_dw:", isa, ""), - _jit_uni_dw_convolution_bwd_weights_t); - - status_t init() { - bool ok = true - && desc()->prop_kind == prop_kind::backward_weights - && set_default_alg_kind(alg_kind::convolution_direct) - && expect_data_types(data_type::f32, data_type::f32, - data_type::f32, data_type::f32, data_type::f32) - && !has_zero_dim_memory() - && set_default_formats(); - if (!ok) return status::unimplemented; - - const int max_threads = mkldnn_in_parallel() - ? 1 : mkldnn_get_max_threads(); - - status_t status = jit_uni_dw_conv_bwd_weights_kernel_f32:: - init_conf(jcp_, *desc(), *src_md(), *diff_weights_md(), - *diff_dst_md(), max_threads); - if (status != status::success) return status; - - auto scratchpad = scratchpad_registry().registrar(); - jit_uni_dw_conv_bwd_weights_kernel_f32::init_scratchpad( - scratchpad, jcp_); - - return status::success; - } - - jit_conv_conf_t jcp_; - - protected: - bool set_default_formats() { - using namespace format_tag; - - auto dat_tag = isa == avx512_common ? nChw16c : nChw8c; - auto wei_tag = isa == avx512_common ? 
Goihw16g : Goihw8g; - - return set_default_formats_common(dat_tag, wei_tag, dat_tag); - } - }; - - _jit_uni_dw_convolution_bwd_weights_t(const pd_t *apd); - ~_jit_uni_dw_convolution_bwd_weights_t() { - delete kernel_; - delete acc_ker_; - }; - - typedef typename prec_traits::type data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_backward_weights(ctx); - return status::success; - } - -private: - void execute_backward_weights(const exec_ctx_t &ctx) const; - bool do_parallel_reduction() const { return false; } - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - - jit_uni_dw_conv_bwd_weights_kernel_f32 *kernel_; - cpu_accumulator_1d_t *acc_ker_; -}; - -using jit_avx512_common_dw_convolution_bwd_weights_t = - _jit_uni_dw_convolution_bwd_weights_t; -using jit_avx2_dw_convolution_bwd_weights_t = - _jit_uni_dw_convolution_bwd_weights_t; -using jit_sse42_dw_convolution_bwd_weights_t = - _jit_uni_dw_convolution_bwd_weights_t; - -} -} -} - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_eltwise.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_eltwise.cpp deleted file mode 100644 index 2af643587..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_eltwise.cpp +++ /dev/null @@ -1,1142 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#include "c_types_map.hpp" -#include "mkldnn_thread.hpp" -#include "nstl.hpp" -#include "utils.hpp" - -#include "jit_uni_eltwise.hpp" - -#define GET_OFF(field) offsetof(jit_args, field) - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace Xbyak; - -template -void jit_uni_eltwise_injector_f32::injector_preamble(size_t start_idx, - size_t end_idx) { - preserved_vecs_count = 0; - vecs_to_preserve = (size_t)aux_vecs_count(alg_); - start_idx_tail = start_idx; - - // For sse42 mask register has to be Xmm(0) - if (isa == sse42 && vecs_to_preserve > 0) { - size_t idx = 0; - assert(idx < start_idx); - preserved_vec_idxs[preserved_vecs_count++] = idx; - } - - for (size_t idx = preserved_vecs_count; idx < vecs_count; idx++) { - if (preserved_vecs_count >= vecs_to_preserve) break; - if (start_idx <= idx && idx < end_idx) continue; - - preserved_vec_idxs[preserved_vecs_count++] = idx; - } - - size_t preserved_vecs_count_tail = vecs_to_preserve - preserved_vecs_count; - for (size_t i = 0; i < preserved_vecs_count_tail; i++) { - preserved_vec_idxs[preserved_vecs_count++] = start_idx_tail++; - } - - assert(preserved_vecs_count == vecs_to_preserve); - - if (save_state_) { - h->push(p_table); - - if (preserved_vecs_count) - h->sub(h->rsp, preserved_vecs_count * vlen); - - for (size_t i = 0; i < preserved_vecs_count; ++i) - h->uni_vmovups(h->ptr[h->rsp + i * vlen], - Vmm(preserved_vec_idxs[i])); - - load_table_addr(); - } - - assign_regs(); -} - -template -void jit_uni_eltwise_injector_f32::injector_preamble_tail(size_t start_idx) -{ - size_t tail_vecs_to_preserve = start_idx_tail - start_idx; - if (tail_vecs_to_preserve == 0) return; - - const int idx_off = vecs_to_preserve - tail_vecs_to_preserve; - - if (save_state_) { - if (idx_off) - h->add(h->rsp, idx_off * vlen); - - for (size_t i = 0; i < tail_vecs_to_preserve; ++i) - h->uni_vmovups(Vmm(preserved_vec_idxs[idx_off + i]), - h->ptr[h->rsp + i * vlen]); - } - - for (size_t i = 0; i < tail_vecs_to_preserve; ++i) - preserved_vec_idxs[idx_off + i] += tail_vecs_to_preserve; - - if (save_state_) { - for (size_t i = 0; i < tail_vecs_to_preserve; ++i) - h->uni_vmovups(h->ptr[h->rsp + i * vlen], - Vmm(preserved_vec_idxs[idx_off + i])); - - if (idx_off) - h->sub(h->rsp, idx_off * vlen); - } - - assign_regs(); -} - -template -void jit_uni_eltwise_injector_f32::injector_postamble() { - if (!save_state_) return; - - for (size_t i = 0; i < preserved_vecs_count; ++i) - h->uni_vmovups(Vmm(preserved_vec_idxs[i]), - h->ptr[h->rsp + i * vlen]); - - if (preserved_vecs_count) - h->add(h->rsp, preserved_vecs_count * vlen); - - h->pop(p_table); -} - -template -void jit_uni_eltwise_injector_f32::assign_regs() { - vmm_mask = Vmm(preserved_vec_idxs[0]); - vmm_aux0 = Vmm(preserved_vec_idxs[0]); - vmm_aux1 = Vmm(preserved_vec_idxs[1]); - vmm_aux2 = Vmm(preserved_vec_idxs[2]); - vmm_aux3 = Vmm(preserved_vec_idxs[3]); - vmm_aux4 = Vmm(preserved_vec_idxs[4]); -} - -template -void jit_uni_eltwise_injector_f32::exp_compute_vector(const Vmm &vmm_src) { - h->uni_vminps(vmm_src, vmm_src, table_val(10)); - h->uni_vmaxps(vmm_src, vmm_src, table_val(11)); - h->uni_vmovups(vmm_aux0, vmm_src); - //calculate exp(x) - // fx = x * log2ef + 0.5 - h->uni_vmulps(vmm_src, vmm_src, table_val(2)); - h->uni_vaddps(vmm_src, vmm_src, table_val(1)); - - // tmp = floorf(fx) - if (isa == avx512_common) { - h->vcvtps2dq(vmm_aux1 | h->T_rd_sae, vmm_src); - h->vcvtdq2ps(vmm_aux1, vmm_aux1); - - 
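For reference, exp_compute_vector (begun above and finished just below) follows the standard range-reduction recipe: clamp x, split it as n*ln(2) + r, evaluate a degree-5 polynomial for e^r, and rebuild 2^n by writing the float exponent field directly. A scalar sketch using the decimal constants from elu_prepare_table() further down (table indices in the comments); it illustrates the math, it is not a drop-in replacement for the JIT code:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>

// Scalar equivalent of the lane-wise exp() above; constants are taken from the
// elu_prepare_table() comments and noted by their table index.
static inline float exp_approx(float x) {
    x = std::min(x,  88.3762589f);        // max logf, table [10]
    x = std::max(x, -14.5f);              // min logf, table [11]
    float fx = x * 1.44269502f + 0.5f;    // x * log2(e) + 0.5, tables [2] and [1]
    float n  = std::floor(fx);            // the "tmp = floorf(fx)" step
    float r  = x - n * 0.69314718f;       // x - n * ln(2), table [3]
    float y  = 0.008369149f;              // p5, table [9]
    y = y * r + 0.041917507f;             // p4, table [8]
    y = y * r + 0.16666505f;              // p3, table [7]
    y = y * r + 0.4999887f;               // p2, table [6]
    y = y * r + 1.0f;                     // p1, table [0]
    y = y * r + 1.0000001f;               // p0, table [5]
    // rebuild 2^n by placing (n + 127) in the exponent field, the vpslld(..., 23) trick
    std::int32_t bits = (static_cast<std::int32_t>(n) + 127) << 23;  // bias 0x7f, table [4]
    float two_n;
    std::memcpy(&two_n, &bits, sizeof(two_n));
    return y * two_n;
}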
h->vcmpps(k_mask, vmm_aux1, vmm_src, _cmp_nle_us); - h->vmovups(vmm_aux3 | k_mask | h->T_z, table_val(0)); - - h->uni_vsubps(vmm_aux1, vmm_aux1, vmm_aux3); - } else { - h->uni_vroundps(vmm_aux1, vmm_src, _op_floor); - } - - //keep fx for further computations - h->uni_vmovups(vmm_src, vmm_aux1); //vmm_src = fx - - //x = x - fx * ln2 - h->uni_vfnmadd231ps(vmm_aux0, vmm_aux1, table_val(3)); - - // compute 2^n - h->uni_vcvtps2dq(vmm_aux1, vmm_src); - h->uni_vpaddd(vmm_aux1, vmm_aux1, table_val(4)); - h->uni_vpslld(vmm_aux1, vmm_aux1, 23); //Vmm(6) = 2^-fx - - // y = p5 - h->uni_vmovups(vmm_src, table_val(9)); - // y = y * x + p4 - h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(8)); - // y = y * x + p3 - h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(7)); - // y = y * x + p2 - h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(6)); - // y = y * x + p1 - h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(0)); - // y = y * x + p0 - h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(5)); //exp(q) - // y = y * 2^n - h->uni_vmulps(vmm_src, vmm_src, vmm_aux1); -} - -template -void jit_uni_eltwise_injector_f32::relu_compute_vector(const Vmm &vmm_src) -{ - const int alpha_off = 0, zero_off = 1; - - h->uni_vmovups(vmm_aux1, vmm_src); - if (isa == sse42) { - h->movups(vmm_mask, vmm_src); - h->mulps(vmm_src, table_val(alpha_off)); - h->cmpps(vmm_mask, table_val(zero_off), _cmp_nle_us); - h->blendvps(vmm_src, vmm_aux1); - } else if (isa == avx2) { - h->vmulps(vmm_src, vmm_src, table_val(alpha_off)); - h->vcmpgtps(vmm_mask, vmm_aux1, table_val(zero_off)); - h->vblendvps(vmm_src, vmm_src, vmm_aux1, vmm_mask); - } else if (isa == avx512_common) { - h->vmulps(vmm_src, vmm_src, table_val(alpha_off)); - h->vcmpps(k_mask, vmm_aux1, table_val(zero_off), _cmp_nle_us); - h->vblendmps(vmm_src | k_mask, vmm_src, vmm_aux1); - } -} - -template -void jit_uni_eltwise_injector_f32::relu_zero_ns_compute_vector( - const Vmm &vmm_src) { - const int zero_off = 1; - h->uni_vmaxps(vmm_src, vmm_src, table_val(zero_off)); -} - -template -void jit_uni_eltwise_injector_f32::elu_compute_vector(const Vmm &vmm_src) { - const int alpha_off = 23, zero_off = 24; - - // compute exponent - h->uni_vmovups(vmm_aux2, vmm_src); - exp_compute_vector(vmm_src); - - // alpha * (exp(x) - 1) - h->uni_vsubps(vmm_src, vmm_src, table_val(0)); - h->uni_vmulps(vmm_src, vmm_src, table_val(alpha_off)); - - // combine with mask - if (isa == sse42) { - h->pxor(vmm_mask, vmm_mask); - h->cmpps(vmm_mask, vmm_aux2, _cmp_le_os); - h->blendvps(vmm_src, vmm_aux2); - } else if (isa == avx2) { - h->uni_vcmpgtps(vmm_mask, vmm_aux2, table_val(zero_off)); - h->uni_vblendvps(vmm_src, vmm_src, vmm_aux2, vmm_mask); - } else if (isa == avx512_common) { - h->vcmpps(k_mask, vmm_aux2, table_val(zero_off), _cmp_nle_us); - h->vblendmps(vmm_src | k_mask, vmm_src, vmm_aux2); - } -} - -template -void jit_uni_eltwise_injector_f32::tanh_compute_vector(const Vmm &vmm_src) -{ - // # comes from Taylor expansion error bound - // > linear_sat_point = single(sqrt(3) * 1b-12); - // # comes from the exp formula cancellation - // > exp_bound_point = (single(log(3)/2)); - // # comes from rounding accuracy in float - // > one_sat_point = round(atanh(1 - 1b-25), single, RU); - // > P = fpminimax(f, [|1, 3, 5, 7, 9|], [|24... 
|], - // [linear_sat_point, exp_bound_point], relative, floating); - // > err_bound = D(sup(supnorm(P, tanh(x), - // [linear_sat_point, exp_bound_point], relative, theta))); - // 0x1.fffd6f00b9539p-25 - // > P; - // x * (0x1.fffffep-1 + x^0x1p1 * (-0x1.55539ep-2 + x^0x1p1 * - // (0x1.10be3ep-3 + x^0x1p1 * (-0x1.ae57b4p-5 - // + x^0x1p1 * 0x1.09fa1p-6)))) - - // register mapping - // vmm_src contains input - // vmm_aux0 contains mask of currently valid results. - // 1 is need computation, 0 is already computed - // vmm_aux1 contains current output - // vmm_aux2, vmm_aux3 contains auxiliary values - // vmm_aux4 contains the original sign of inputs - - Label end_tanh_label; - - auto test_exit =[&](Xbyak::Address threshold){ - // is not necessary for >AVX, but should not matter on perf - h->uni_vmovups(vmm_aux0, vmm_src); - if (isa == avx512_common){ - h->vcmpps(k_mask, vmm_aux0, threshold, 0x5); - h->kortestw(k_mask, k_mask); - } else { - h->uni_vcmpgeps(vmm_aux0, vmm_aux0, threshold); - h->uni_vtestps(vmm_aux0, vmm_aux0); - } - h->jz(end_tanh_label, Xbyak::CodeGenerator::T_NEAR); - }; - - auto blend_results=[&](Vmm vmm_partial_res){ - if (isa == avx512_common) - h->vblendmps(vmm_aux1 | k_mask, vmm_aux1, vmm_partial_res); - else - h->uni_vblendvps(vmm_aux1, vmm_aux1, vmm_partial_res, vmm_aux0); - }; - - // because tanh(x) = -tanh(-x), we extract sign to make x postive - // and reapply sign at the end - // mov is not necessary for >AVX, but should not matter for performance - h->uni_vmovups(vmm_aux4, vmm_src); - h->uni_vandps(vmm_aux4, vmm_aux4, table_val(12)); - h->uni_vandps(vmm_src, vmm_src, table_val(17)); - - // if x < linear_sat_point for all inputs, we just return the input - h->uni_vmovups(vmm_aux1, vmm_src); - test_exit(table_val(13)); - - // if one of the mask is one, we have to compute an better approx - h->uni_vmovups(vmm_aux2, vmm_src); - h->uni_vmulps(vmm_aux2, vmm_aux2, vmm_aux2); - h->uni_vmovups(vmm_aux3, table_val(22)); - h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(21)); - h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(20)); - h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(19)); - h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(18)); - h->uni_vmulps(vmm_aux3, vmm_aux3, vmm_src); - - // we blend only the result that need update - blend_results(vmm_aux3); - - // if x < exp_bound_point, we go to return point - test_exit(table_val(14)); - - // if not we use a better approx 1 - 2 / (1 + exp(2x)) - // compute 2x - h->uni_vmovups(vmm_aux3, vmm_src); - h->uni_vaddps(vmm_aux3, vmm_aux3, vmm_aux3); - - // Compute exp(2x) - // We need to save kmask, vmm_aux0, vmm_aux1 and vmm_src as exp can use them - // vmm_src is not more read afterwards, so we do not have to save it - auto stack_size = 3 * vlen + (isa == avx512_common) * 4; - h->sub(h->rsp, stack_size); - h->uni_vmovups(h->ptr[h->rsp + 0 * vlen], vmm_aux0); - h->uni_vmovups(h->ptr[h->rsp + 1 * vlen], vmm_aux1); - h->uni_vmovups(h->ptr[h->rsp + 2 * vlen], vmm_src); - if (isa == avx512_common) - h->kmovw(h->ptr[h->rsp + 3 * vlen], k_mask); - - exp_compute_vector(vmm_aux3); - - h->uni_vmovups(vmm_aux0, h->ptr[h->rsp + 0 * vlen]); - h->uni_vmovups(vmm_aux1, h->ptr[h->rsp + 1 * vlen]); - h->uni_vmovups(vmm_src, h->ptr[h->rsp + 2 * vlen]); - if (isa == avx512_common) - h->kmovw(k_mask, h->ptr[h->rsp + 3 * vlen]); - h->add(h->rsp, stack_size); - - // 1 + exp(2x) - h->uni_vaddps(vmm_aux3, vmm_aux3, table_val(0)); - - // 1 - 2 / (1 + exp(2x)) - h->uni_vmovups(vmm_aux2, table_val(16)); - h->uni_vdivps(vmm_aux2, vmm_aux2, vmm_aux3); - 
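Seen end to end, the tanh routine above and below is a four-way piecewise evaluation on |x| with the sign reapplied at the end: the identity below the linear threshold, an odd polynomial x*P(x^2) up to roughly ln(3)/2, the expression 1 - 2/(1 + e^(2x)) beyond that, and saturation to 1 for large arguments. A scalar sketch; the thresholds and coefficients are approximate decimal decodings of table entries [13]-[15] and [18]-[22]:

#include <cmath>

// Scalar outline of the branchy vector tanh above (illustrative only).
static inline float tanh_approx(float x) {
    const float linear_bound = 4.2286e-4f;  // [13]: below this, tanh(x) ~= x
    const float poly_bound   = 0.549306f;   // [14]: ~ln(3)/2, polynomial valid below
    const float one_bound    = 9.010961f;   // [15]: above this, tanh(x) ~= 1
    const float a = std::fabs(x);           // sign handled separately, as in the kernel
    float r;
    if (a < linear_bound) {
        r = a;
    } else if (a < poly_bound) {
        const float a2 = a * a;             // odd polynomial: a * P(a^2)
        float p = 0.016234f;                // [22]
        p = p * a2 - 0.052532f;             // [21]
        p = p * a2 + 0.133175f;             // [20]
        p = p * a2 - 0.333327f;             // [19]
        p = p * a2 + 0.99999994f;           // [18]
        r = a * p;
    } else if (a < one_bound) {
        r = 1.0f - 2.0f / (1.0f + std::exp(2.0f * a));
    } else {
        r = 1.0f;                           // the saturation blend at the end of the kernel
    }
    return std::copysign(r, x);             // reapply the extracted sign bit (the final vpxor)
}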
h->uni_vaddps(vmm_aux2, vmm_aux2, table_val(0)); - - // we blend only the result that need update - blend_results(vmm_aux2); - - // finally, we saturate to 1 if needed - // TODO: maybe move that up if most inputs saturate in practice - if (isa == avx512_common) - h->vcmpps(k_mask, vmm_aux0, table_val(15), 0x5); - else { - h->uni_vmovups(vmm_aux0, vmm_src); - h->uni_vcmpgeps(vmm_aux0, vmm_aux0, table_val(15)); - } - h->uni_vmovups(vmm_aux2, table_val(0)); - blend_results(vmm_aux2); - - h->L(end_tanh_label); - { - // we apply the sign of x to the result and we are done - h->uni_vmovups(vmm_src, vmm_aux1); - h->uni_vpxor(vmm_src, vmm_src, vmm_aux4); - } -} - -template -void jit_uni_eltwise_injector_f32::square_compute_vector( - const Vmm &vmm_src) { - h->uni_vmulps(vmm_src, vmm_src, vmm_src); -} - -template -void jit_uni_eltwise_injector_f32::abs_compute_vector(const Vmm &vmm_src) { - // compute abs(x) = _mm_and_ps(x, 01111..111)); - h->uni_vandps(vmm_src, vmm_src, table_val(0)); -} - -template -void jit_uni_eltwise_injector_f32::sqrt_compute_vector(const Vmm &vmm_src) -{ - if (isa == avx512_common) { - h->vcmpps(k_mask, vmm_src, table_val(0), _cmp_nle_us); - h->uni_vsqrtps(vmm_aux1, vmm_src); - h->uni_vmovups(vmm_src, table_val(0)); - h->vblendmps(vmm_src | k_mask, vmm_src, vmm_aux1); - } else { - h->uni_vmovups(vmm_mask, vmm_src); - h->uni_vcmpgtps(vmm_mask, vmm_mask, table_val(0)); - h->uni_vsqrtps(vmm_aux1, vmm_src); - h->uni_vmovups(vmm_src, table_val(0)); - h->uni_vblendvps(vmm_src, vmm_src, vmm_aux1, vmm_mask); - } -} - -template -void jit_uni_eltwise_injector_f32::linear_compute_vector( - const Vmm &vmm_src) { - // compute x = alpha * x + beta; - h->uni_vmovups(vmm_aux0, table_val(0)); - h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(1)); -} - -template -void jit_uni_eltwise_injector_f32::bounded_relu_compute_vector( - const Vmm &vmm_src) { - // compute bounded relu */ - h->uni_vmaxps(vmm_src, vmm_src, table_val(1)); - h->uni_vminps(vmm_src, vmm_src, table_val(0)); -} - -template -void jit_uni_eltwise_injector_f32::soft_relu_compute_vector( - const Vmm &vmm_src) { - // duplicate src - h->uni_vmovups(vmm_aux2, vmm_src); - - h->uni_vminps(vmm_src, vmm_src, table_val(24)); - h->uni_vmaxps(vmm_src, vmm_src, table_val(25)); - h->uni_vmovups(vmm_aux1, vmm_src); - // calculate exp(x) - // fx = x * log2ef + 0.5 - h->uni_vmulps(vmm_src, vmm_src, table_val(2)); - h->uni_vaddps(vmm_src, vmm_src, table_val(1)); - - // tmp = floorf(fx) - if (isa == avx512_common) { - h->vcvtps2dq(vmm_aux0 | h->T_rd_sae, vmm_src); - h->vcvtdq2ps(vmm_aux0, vmm_aux0); - - h->vcmpps(k_mask, vmm_aux0, vmm_src, _cmp_nle_us); - h->vmovups(vmm_aux3 | k_mask | h->T_z, table_val(0)); - - h->vsubps(vmm_aux0, vmm_aux0, vmm_aux3); - } else { - h->uni_vroundps(vmm_aux0, vmm_src, _op_floor); - } - - // keep fx for further computations - h->uni_vmovups(vmm_src, vmm_aux0); //vmm_src = fx - // calculation fx * ln2 - h->uni_vmulps(vmm_aux0, vmm_aux0, table_val(3)); - // x = x - fx * ln2 - h->uni_vsubps(vmm_aux1, vmm_aux1, vmm_aux0); - // y = p5 - h->uni_vmovups(vmm_aux3, table_val(22)); - // y = y * x + p4 - h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(21)); - // y = y * x + p3 - h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(20)); - // y = y * x + p2 - h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(19)); - // y = y * x + p1 - h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(0)); - // y = y * x + p0 - h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(17)); - - // compute 2^(-n) - if (isa == avx512_common) { - 
h->vmulps(vmm_aux1, vmm_src, table_val(23)); - h->vcvtps2dq(vmm_aux1, vmm_aux1); - } else { - h->uni_vcvtps2dq(vmm_aux1, vmm_src); - h->uni_vpsignd(vmm_aux1, vmm_aux1, table_val(23)); - } - - h->uni_vpaddd(vmm_aux1, vmm_aux1, table_val(4)); - h->uni_vpslld(vmm_aux1, vmm_aux1, 23); //vmm_aux1 = 2^-fx - // calculate ln(1 + y) - h->uni_vaddps(vmm_aux3, vmm_aux3, vmm_aux1); - // x = y; y is free; keep x for further computations - h->uni_vmovups(vmm_src, vmm_aux3); - // frexp() - h->uni_vpsrld(vmm_src, vmm_src, 23); - h->uni_vcvtdq2ps(vmm_src, vmm_src); - // got n. where n is x = 2^n * y. y = 0.5 .. 1 - h->uni_vsubps(vmm_src, vmm_src, table_val(5)); - - h->uni_vandps(vmm_aux3, vmm_aux3, table_val(6)); - // got y. (mantisa) 0.5 < y < 1 - h->uni_vorps(vmm_aux3, vmm_aux3, table_val(7)); - // y = y - 1 - h->uni_vsubps(vmm_aux3, vmm_aux3, table_val(0)); - // y = p8 - h->uni_vmovups(vmm_aux1, table_val(16)); - // y = y * x + p7 - h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(15)); - // y = y * x + p6 - h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(14)); - // y = y * x + p5 - h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(13)); - // y = y * x + p4 - h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(12)); - // y = y * x + p3 - h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(11)); - // y = y * x + p2 - h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(10)); - // y = y * x + p1 - h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(9)); - // y = y * x + p0 ; p0 = 0 - h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(8)); - //calculate ln(2) * n - h->uni_vmulps(vmm_src, vmm_src, table_val(3)); - h->uni_vaddps(vmm_aux1, vmm_aux1, vmm_src); - h->uni_vaddps(vmm_aux1, vmm_aux1, vmm_aux0); - - // get vmm_mask = src > max logf - h->uni_vmovups(vmm_mask, vmm_aux2); - if (isa == avx512_common) { - // y = (x < max log f) ? soft_relu(x) : x - h->vcmpps(k_mask, vmm_mask, table_val(24), _cmp_nle_us); - h->vblendmps(vmm_aux1 | k_mask, vmm_aux1, vmm_aux2); - } else { - // y = (x < max log f) ? soft_relu(x) : x - h->uni_vcmpgtps(vmm_mask, vmm_mask, table_val(24)); - h->uni_vblendvps(vmm_aux1, vmm_aux1, vmm_aux2, vmm_mask); - } - - h->uni_vmovups(vmm_src, vmm_aux1); -} - -template -void jit_uni_eltwise_injector_f32::logistic_compute_vector( - const Vmm &vmm_src) { - // we store the original sign and make x negative - // IMPORTANT: we assume vmm_aux0 to be xmm0, as for sse4.2 path it is required - // IMPORTANT: we use vmm_aux2 for the mask as exp_compute does not use it. 
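The comments above describe the usual numerically safe sigmoid: evaluate it on the negative half-line, where exp() cannot overflow, and mirror the result for positive inputs using sigmoid(x) = 1 - sigmoid(-x). A scalar sketch of what the vector code below computes:

#include <cmath>

// Scalar view of logistic_compute_vector(): force the argument negative via its
// sign bit, compute e/(e + 1), then undo the sign flip with the symmetry identity.
static inline float logistic_approx(float x) {
    const float neg = -std::fabs(x);     // the "or with 0x80000000" step
    const float e   = std::exp(neg);     // exp_compute_vector() in the kernel
    const float y   = e / (e + 1.0f);    // sigmoid of the negative argument
    return x > 0.0f ? 1.0f - y : y;      // blend on the saved sign bits
}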
- h->uni_vmovups(vmm_aux2, vmm_src); - h->uni_vandps(vmm_aux2, vmm_aux2, table_val(12)); - h->uni_vorps(vmm_src, vmm_src, table_val(12)); - - exp_compute_vector(vmm_src); - // dup exp(x) - h->uni_vmovups(vmm_aux1, vmm_src); - // (exp(x) + 1) - h->uni_vaddps(vmm_aux1, vmm_aux1, table_val(0)); - // y = exp(x) / (exp(x) + 1) - h->uni_vdivps(vmm_src, vmm_src, vmm_aux1); - - // Now we have to apply the "symmetry" based on original sign - h->uni_vmovups(vmm_aux3, table_val(0)); - h->uni_vsubps(vmm_aux3, vmm_aux3, vmm_src); - if (isa == avx512_common) { - h->vptestmd(k_mask, vmm_aux2, vmm_aux2); - h->vblendmps(vmm_aux3 | k_mask, vmm_aux3, vmm_src); - } else { - h->uni_vmovups(vmm_aux0, vmm_aux2);// The mask should be xmm0 for sse4.2 - h->uni_vblendvps(vmm_aux3, vmm_aux3, vmm_src, vmm_aux0); - } - h->uni_vmovups(vmm_src, vmm_aux3); -} - -template -void jit_uni_eltwise_injector_f32::relu_prepare_table() { - for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_)); - for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0); -} - -template -void jit_uni_eltwise_injector_f32::elu_prepare_table() { - const unsigned int cvals[] = { - 0x3f800000, // [0] 1.0f - 0x3f000000, // [1] 0.5f - 0x3fb8aa3b, // [2] log2ef = 1.44269502f - 0x3f317218, // [3] ln2f = 0.69314718f - 0x0000007f, // [4] 0x7f - // exp(x) polynom - 0x3f800001, // [5] p0 = 1.0000001f - 0x3efffe85, // [6] p2 = 0.4999887f - 0x3e2aaa3e, // [7] p3 = 0.16666505f - 0x3d2bb1b1, // [8] p4 = 0.041917507f - 0x3c091ec1, // [9] p5 = 0.008369149f - 0x42b0c0a5, //[10] max logf = 88.3762589f - 0xc1766666, //[11] min logf = -14.5f - // tanh(x) constants, - 0x80000000, //[12] mask to extract sign - 0x39ddb3d7, //[13] arg below which tanh(x) = x - 0x3f0c9f54, //[14] arg below which pol approx is valid - 0x41102cb4, //[15] arg after which tanh(x) = 1 - 0xc0000000, //[16] -2.0f - 0x7fffffff, //[17] mask to make positive - // tanh pol approx - 0x3f7fffff, //[18] p0 - 0xbeaaa9cf, //[19] p1 - 0x3e085f1f, //[20] p2 - 0xbd572bda, //[21] p3 - 0x3c84fd08, //[22] p4 - }; - - for (size_t i = 0; i < sizeof(cvals) / sizeof(cvals[0]); ++i) { - for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(cvals[i]); - } - - for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_)); - for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0); -} - -template -void jit_uni_eltwise_injector_f32::soft_relu_prepare_table() { - const unsigned int cvals[] = { - 0x3f800000, // [0] 1.0f - 0x3f000000, // [1] 0.5f - 0x3fb8aa3b, // [2] log2ef = 1.44269502f - 0x3f317218, // [3] ln2f = 0.69314718f - 0x0000007f, // [4] 0x7f - 0x42fc0000, // [5] 126 - 0x807fffff, // [6] and with (to get 0.5 * mantissa) - 0x3f000000, // [7] or with (to get 0.5 * mantissa) - // ln(1 + x) polynomial - 0xb2b4637d, // [8] p0 = 0.0000000244f - 0x3f7fff8e, // [9] p1 = 0.9999976971f - 0xbf001759, //[10] p2 = -0.5002478215f - 0x3ea70608, //[11] p3 = 0.3272714505f - 0xbea3d7bf, //[12] p4 = -0.3153830071f - 0xbe361d04, //[13] p5 = -0.1701777461f - 0xbfa8f1e6, //[14] p6 = -1.3254635147f - 0xbfe1e812, //[15] p7 = -1.7971917960f - 0xbfc4d30e, //[16] p8 = -1.5652673123f - // exp(x) polynomial - 0x3f800001, //[17] p0 = 1.0000001f - 0x3f800000, //[18] p1 = 1.0f - 0x3efffe85, //[19] p2 = 0.4999887f - 0x3e2aaa3e, //[20] p3 = 0.16666505f - 0x3d2bb1b1, //[21] p4 = 0.041917507f - 0x3c091ec1, //[22] p5 = 0.008369149f - 0xbf800000, //[23] is required for sign changing - 0x42b0c0a5, //[24] max logf = 88.3762589f - 0xc1766666 //[25] min logf = -14.5f - }; - - for (size_t i = 0; i < sizeof(cvals) / 
sizeof(cvals[0]); ++i) { - for (size_t d = 0; d < vlen / sizeof(float); ++d) { - h->dd(cvals[i]); - } - } -} - -template -void jit_uni_eltwise_injector_f32::abs_prepare_table() { - for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0x7fffffff); -} - -template -void jit_uni_eltwise_injector_f32::sqrt_prepare_table() { - for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0); -} - -template -void jit_uni_eltwise_injector_f32::linear_prepare_table() { - for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_)); - for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(beta_)); -} - -template -void jit_uni_eltwise_injector_f32::bounded_relu_prepare_table() { - for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_)); - for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0); -} - -template -int jit_uni_eltwise_injector_f32::aux_vecs_count(alg_kind_t alg_) { - switch (alg_) { - case alg_kind::eltwise_relu: return (alpha_ == 0.f) ? 0 : 2; - case alg_kind::eltwise_elu: return 4; - case alg_kind::eltwise_tanh: return 5; - case alg_kind::eltwise_square: return 0; - case alg_kind::eltwise_abs: return 0; - case alg_kind::eltwise_sqrt: return 2; - case alg_kind::eltwise_linear: return 1; - case alg_kind::eltwise_bounded_relu: return 0; - case alg_kind::eltwise_soft_relu: return 4; - case alg_kind::eltwise_logistic: return 4; - default: assert(!"unsupported eltwise algorithm"); - } - - return 0; -} - -template -void jit_uni_eltwise_injector_f32::compute_body(size_t start_idx, - size_t end_idx) { - using namespace alg_kind; - for (size_t idx = start_idx; idx < end_idx; idx++) { - switch (alg_) { - case eltwise_relu: - if (alpha_ == 0.f) relu_zero_ns_compute_vector(Vmm(idx)); - else relu_compute_vector(Vmm(idx)); - break; - case eltwise_elu: elu_compute_vector(Vmm(idx)); break; - case eltwise_tanh: tanh_compute_vector(Vmm(idx)); break; - case eltwise_square: square_compute_vector(Vmm(idx)); break; - case eltwise_abs: abs_compute_vector(Vmm(idx)); break; - case eltwise_sqrt: sqrt_compute_vector(Vmm(idx)); break; - case eltwise_linear: linear_compute_vector(Vmm(idx)); break; - case eltwise_bounded_relu: bounded_relu_compute_vector(Vmm(idx)); break; - case eltwise_soft_relu: soft_relu_compute_vector(Vmm(idx)); break; - case eltwise_logistic: logistic_compute_vector(Vmm(idx)); break; - default: assert(!"unsupported eltwise algorithm"); - } - } -} - -template -void jit_uni_eltwise_injector_f32::compute_vector_range(size_t start_idx, - size_t end_idx) { - assert(start_idx < end_idx && end_idx <= vecs_count); - - injector_preamble(start_idx, end_idx); - compute_body(start_idx_tail, end_idx); - injector_preamble_tail(start_idx); - compute_body(start_idx, start_idx_tail); - injector_postamble(); -} - -template -void jit_uni_eltwise_injector_f32::prepare_table(bool gen_table) { - using namespace alg_kind; - - h->align(64); - h->L(l_table); - - if (gen_table) { - switch (alg_) { - case eltwise_relu: relu_prepare_table(); break; - case eltwise_elu: - case eltwise_tanh: - case eltwise_logistic: - elu_prepare_table(); break; - case eltwise_soft_relu: soft_relu_prepare_table(); break; - case eltwise_abs: abs_prepare_table(); break; - case eltwise_sqrt: sqrt_prepare_table(); break; - case eltwise_linear: linear_prepare_table(); break; - case eltwise_bounded_relu: bounded_relu_prepare_table(); break; - case eltwise_square: break; - default: assert(!"unsupported eltwise algorithm"); - } - } -} - -template struct jit_uni_eltwise_injector_f32; -template struct 
jit_uni_eltwise_injector_f32; -template struct jit_uni_eltwise_injector_f32; - - -struct jit_args { - const float *from; - const float *for_comparison; - const float *to; - size_t work_amount; -}; - -struct jit_uni_eltwise_kernel_f32 : public c_compatible { - const eltwise_desc_t &desc_; - - void (*ker_)(const jit_args *); - void operator()(const jit_args *args) { assert(ker_); ker_(args); } - - jit_uni_eltwise_kernel_f32(const eltwise_desc_t &desc) - : desc_(desc), ker_(nullptr) {} - virtual ~jit_uni_eltwise_kernel_f32() {} - -protected: - bool is_bwd() const { return desc_.prop_kind == prop_kind::backward_data; } -}; - -/* jit kernels */ -namespace { - -template -struct jit_uni_relu_kernel_f32 : public jit_uni_eltwise_kernel_f32, - public jit_generator -{ - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_relu_kernel_f32) - - void compute_step(bool vectorize, const int uf, const int shift) { - for (int i = 0; i < uf; i++) { - if (vectorize) { - uni_vmovups(Vmm(i + 1), ptr[reg_from + i * shift]); - if (is_bwd()) - uni_vmovups(Vmm(uf + i + 1), - ptr[reg_for_comparison + i * shift]); - } else { - movss(Xmm(i + 1), ptr[reg_from + i * shift]); - if (is_bwd()) - movss(Xmm(uf + i + 1), - ptr[reg_for_comparison + i * shift]); - } - } - - if (isa == sse42) { - for (int i = 0; i < uf; i++) { - movups(Vmm(2 * uf + i + 1), Vmm(i + 1)); - mulps(Vmm(2 * uf + i + 1), vmm_ns); - - Vmm mask = Vmm(0); - if (is_bwd()) { - movups(mask, Vmm(uf + i + 1)); - cmpps(mask, vmm_zero, _cmp_nle_us); - } else { - movups(mask, Vmm(i + 1)); - cmpps(mask, vmm_zero, _cmp_nle_us); - } - blendvps(Vmm(2 * uf + i + 1), Vmm(i + 1)); - } - } else { - for (int i = 0; i < uf; i++) { - vmulps(Vmm(2 * uf + i + 1), Vmm(i + 1), vmm_ns); - if (isa == avx2) { - if (is_bwd()) - vcmpgtps(vmm_mask, Vmm(uf + i + 1), vmm_zero); - else - vcmpgtps(vmm_mask, Vmm(i + 1), vmm_zero); - - vblendvps(Vmm(2 * uf + i + 1), Vmm(2 * uf + i + 1), - Vmm(i + 1), vmm_mask); - - } else { - if (is_bwd()) - vcmpps(k_mask, Vmm(uf + i + 1), vmm_zero, _cmp_nle_us); - else - vcmpps(k_mask, Vmm(i + 1), vmm_zero, _cmp_nle_us); - vblendmps(Vmm(2 * uf + i + 1) | k_mask, Vmm(2 * uf + i + 1), - Vmm(i + 1)); - } - } - } - - for (int i = 0; i < uf; i++) { - if (vectorize) { - uni_vmovups(ptr[reg_to + i * shift], Vmm(2 * uf + i + 1)); - } else { - movss(ptr[reg_to + i * shift], Xmm(2 * uf + i + 1)); - } - } - } - - jit_uni_relu_kernel_f32(const eltwise_desc_t &desc) - : jit_uni_eltwise_kernel_f32(desc), jit_generator() { - assert(desc.alg_kind == alg_kind::eltwise_relu); - assert(isa == sse42 || isa == avx2 || isa == avx512_common); - - Reg64 param = abi_param1; - - const int simd_w = cpu_isa_traits::vlen / sizeof(float); - const int loop_dec[] = {simd_w, 1}; - const int uf[] = {1, 1}; - const int shift[] = {cpu_isa_traits::vlen, sizeof(float)}; - const bool loop_vectorize[] = {true, false}; - - this->preamble(); - - mov(reg_from, ptr[param + GET_OFF(from)]); - if (is_bwd()) - mov(reg_for_comparison, ptr[param + GET_OFF(for_comparison)]); - mov(reg_to, ptr[param + GET_OFF(to)]); - mov(reg_work_amount, ptr[param + GET_OFF(work_amount)]); - - mov(imm_addr64, float2int(desc.alpha)); - movq(xmm_ns, imm_addr64); - uni_vbroadcastss(vmm_ns, xmm_ns); - - uni_vpxor(vmm_zero, vmm_zero, vmm_zero); - - Label loop_label[3]; - - for (int id = 0; id < 2; id++) { - L(loop_label[id]); - cmp(reg_work_amount, uf[id] * loop_dec[id] - 1); - jle(loop_label[id + 1], T_NEAR); - - compute_step(loop_vectorize[id], uf[id], shift[id]); - - add(reg_from, uf[id] * shift[id]); - add(reg_to, uf[id] * shift[id]); - 
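Per element, compute_step() above boils down to one select: scale the loaded value by alpha and keep the original wherever the gating value is positive. On the forward pass the gate is the input itself (leaky ReLU); on the backward pass it is the saved forward input, so diff_dst is either passed through or scaled by alpha. A one-lane sketch:

// One lane of compute_step(): 'from' is the value being transformed (src on the
// forward pass, diff_dst on the backward pass) and 'gate' is what gets compared
// against zero (the src tensor in both cases).
static inline float relu_lane(float from, float gate, float alpha) {
    const float scaled = from * alpha;      // mulps with the broadcast alpha (vmm_ns)
    return gate > 0.0f ? from : scaled;     // blendvps / vblendvps / vblendmps on gate > 0
}

This is also why reg_for_comparison aliases reg_from on the forward pass: the gate and the transformed value are the same tensor there.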
if (is_bwd()) - add(reg_for_comparison, uf[id] * shift[id]); - - sub(reg_work_amount, uf[id] * loop_dec[id]); - jmp(loop_label[id]); - } - - L(loop_label[2]); - this->postamble(); - - ker_ = (decltype(ker_))this->getCode(); - } - -private: - using Vmm = typename utils::conditional3::type; - - Reg64 reg_from = rax; - Reg64 reg_for_comparison = is_bwd() ? rdx : reg_from; - Reg64 reg_to = r8; - Reg64 reg_work_amount = rsi; - Reg64 imm_addr64 = rbx; - - Xmm xmm_ns = Xmm(14); - - Vmm vmm_ns = Vmm(isa == avx512_common ? 30 : 14); - Vmm vmm_zero = Vmm(isa == avx512_common ? 31 : 15); - - Vmm vmm_mask = Vmm(isa == avx512_common ? 28 : 12); - Opmask k_mask = Opmask(1); -}; - -template -struct jit_uni_kernel_fwd_f32: public jit_uni_eltwise_kernel_f32, - public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_kernel_fwd_f32) - - jit_uni_kernel_fwd_f32(const eltwise_desc_t &desc) - : jit_uni_eltwise_kernel_f32(desc), jit_generator() { - - eltwise_injector_ = new jit_uni_eltwise_injector_f32(this, - desc.alg_kind, desc.alpha, desc.beta, false, r9, Opmask(1)); - - using namespace alg_kind; - - assert(is_bwd() == false); - assert(utils::one_of(desc.alg_kind, eltwise_tanh, eltwise_elu, - eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear, - eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic)); - - preamble(); - - Reg64 param = abi_param1; - mov(reg_from, ptr[param + GET_OFF(from)]); - mov(reg_to, ptr[param + GET_OFF(to)]); - mov(reg_work_amount, ptr[param + GET_OFF(work_amount)]); - eltwise_injector_->load_table_addr(); - - Label reminder_loop_start, reminder_loop_end; - Label vectorized_loop_start, vectorized_loop_end; - - cmp(reg_work_amount, simd_w); - jl(reminder_loop_start, T_NEAR); - - L(vectorized_loop_start); - - uni_vmovups(vmm_src, ptr[reg_from]); - eltwise_injector_->compute_vector(vmm_src.getIdx()); - uni_vmovups(ptr[reg_to], vmm_src); - - add(reg_from, vlen); - add(reg_to, vlen); - - sub(reg_work_amount, simd_w); - cmp(reg_work_amount, simd_w); - jge(vectorized_loop_start, T_NEAR); - - L(vectorized_loop_end); - - L(reminder_loop_start); - - cmp(reg_work_amount, 0); - jle(reminder_loop_end, T_NEAR); - - movss(xmm_src, ptr[reg_from]); - eltwise_injector_->compute_vector(xmm_src.getIdx()); - movss(ptr[reg_to], xmm_src); - - add(reg_from, sizeof(float)); - add(reg_to, sizeof(float)); - - dec(reg_work_amount); - jmp(reminder_loop_start, T_NEAR); - - L(reminder_loop_end); - - postamble(); - - eltwise_injector_->prepare_table(); - - ker_ = (decltype(ker_))this->getCode(); - } - - ~jit_uni_kernel_fwd_f32() { delete eltwise_injector_; } - -private: - using Vmm = typename utils::conditional3::type; - - const int simd_w = cpu_isa_traits::vlen / sizeof(float); - const int vlen = cpu_isa_traits::vlen; - - Reg64 reg_from = rax; - Reg64 reg_to = r8; - Reg64 reg_work_amount = rsi; - Reg64 imm_addr64 = rbx; - - Xmm xmm_src = Xmm(1); - Vmm vmm_src = Vmm(1); - - jit_uni_eltwise_injector_f32 *eltwise_injector_; -}; - -} /* namespace */ - -template -status_t jit_uni_eltwise_fwd_t::pd_t::init() { - using namespace alg_kind; - - bool ok = true - && mayiuse(isa) - && is_fwd() - && utils::everyone_is(data_type::f32, desc()->data_desc.data_type) - && !has_zero_dim_memory() - && utils::one_of(desc()->alg_kind, eltwise_relu, eltwise_tanh, - eltwise_elu, eltwise_square, eltwise_abs, eltwise_sqrt, - eltwise_linear, eltwise_bounded_relu, eltwise_soft_relu, - eltwise_logistic) - && memory_desc_wrapper(src_md()).is_dense(true) - && IMPLICATION(!memory_desc_wrapper(src_md()).is_dense(false), - 
math::eltwise_fwd_preserves_zero(desc()->alg_kind, true)) - && attr()->has_default_values(); - - return ok ? status::success : status::unimplemented; -} - -template -jit_uni_eltwise_fwd_t::jit_uni_eltwise_fwd_t(const pd_t *apd) - : cpu_primitive_t(apd), kernel_(nullptr) { - const auto &desc = *pd()->desc(); - switch (desc.alg_kind) { - case alg_kind::eltwise_relu: - kernel_ = new jit_uni_relu_kernel_f32(desc); break; - default: - kernel_ = new jit_uni_kernel_fwd_f32(desc); - } -} - -template -jit_uni_eltwise_fwd_t::~jit_uni_eltwise_fwd_t() -{ delete kernel_; } - -template -void jit_uni_eltwise_fwd_t::execute_forward(const exec_ctx_t &ctx) const { - auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); - auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); - - const memory_desc_wrapper data_d(pd()->src_md()); - - const size_t nelems = data_d.nelems(true); - - src += data_d.offset0(); - dst += data_d.offset0(); - - parallel(0, [&](const int ithr, const int nthr) { - size_t start{0}, end{0}; - - const int cache_line = 16; - - balance211(utils::div_up(nelems, cache_line), nthr, ithr, start, end); - start = nstl::min(nelems, start * cache_line); - end = nstl::min(nelems, end * cache_line); - - auto arg = jit_args(); - arg.from = &src[start]; - arg.for_comparison = &src[start]; - arg.to = &dst[start]; - arg.work_amount = end - start; - if (arg.work_amount) - (*kernel_)(&arg); - }); -} - -template -status_t jit_uni_eltwise_bwd_t::pd_t::init() { - bool ok = true - && !is_fwd() - && utils::one_of(desc()->alg_kind, alg_kind::eltwise_relu) - && src_md()->data_type == data_type::f32 - && !has_zero_dim_memory() - && mayiuse(isa) - && memory_desc_wrapper(src_md()).is_dense() - && memory_desc_wrapper(diff_dst_md()) == memory_desc_wrapper(src_md()) - && attr()->has_default_values(); - - return ok ? 
status::success : status::unimplemented; -} - -template -jit_uni_eltwise_bwd_t::jit_uni_eltwise_bwd_t(const pd_t *apd) - : cpu_primitive_t(apd), kernel_(nullptr) { - const auto &desc = *pd()->desc(); - switch (desc.alg_kind) { - case alg_kind::eltwise_relu: - kernel_ = new jit_uni_relu_kernel_f32(desc); break; - default: assert(!"unknown eltwise alg_kind"); - } -} - -template -jit_uni_eltwise_bwd_t::~jit_uni_eltwise_bwd_t() -{ delete kernel_; } - -template -void jit_uni_eltwise_bwd_t::execute_backward(const exec_ctx_t &ctx) const { - auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); - auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); - auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); - - const memory_desc_wrapper data_d(pd()->src_md()); - const memory_desc_wrapper diff_data_d(pd()->diff_src_md()); - - const size_t nelems = data_d.nelems(); - - src += data_d.offset0(); - diff_dst += diff_data_d.offset0(); - diff_src += diff_data_d.offset0(); - - parallel(0, [&](const int ithr, const int nthr) { - size_t start{0}, end{0}; - - const int cache_line = 16; - - balance211(utils::div_up(nelems, cache_line), nthr, ithr, start, end); - start = nstl::min(nelems, start * cache_line); - end = nstl::min(nelems, end * cache_line); - - auto arg = jit_args(); - arg.from = &diff_dst[start]; - arg.to = &diff_src[start]; - arg.for_comparison = &src[start]; - arg.work_amount = end - start; - if (arg.work_amount) - (*kernel_)(&arg); - }); -} - -template struct jit_uni_eltwise_fwd_t; -template struct jit_uni_eltwise_bwd_t; -template struct jit_uni_eltwise_fwd_t; -template struct jit_uni_eltwise_bwd_t; -template struct jit_uni_eltwise_fwd_t; -template struct jit_uni_eltwise_bwd_t; - -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_eltwise.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_eltwise.hpp deleted file mode 100644 index 45436b9f4..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_eltwise.hpp +++ /dev/null @@ -1,193 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#ifndef CPU_JIT_UNI_ELTWISE_HPP -#define CPU_JIT_UNI_ELTWISE_HPP - -#include - -#include "c_types_map.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -#include "cpu_eltwise_pd.hpp" -#include "cpu_primitive.hpp" - -#include "jit_generator.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -template -struct jit_uni_eltwise_injector_f32 { - using Vmm = typename utils::conditional3::type; - - jit_uni_eltwise_injector_f32(jit_generator *host, alg_kind_t alg, - float alpha, float beta, bool save_state = true, - Xbyak::Reg64 p_table = Xbyak::util::rax, - Xbyak::Opmask k_mask = Xbyak::Opmask(1)) - : alg_(alg), alpha_(alpha), beta_(beta), h(host) - , save_state_(save_state), p_table(p_table), k_mask(k_mask) - { - using namespace alg_kind; - assert(utils::one_of(isa, sse42, avx2, avx512_common)); - assert(utils::one_of(alg_, eltwise_relu, eltwise_tanh, eltwise_elu, - eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear, - eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic)); - } - - // note that eltwise.scale is ignored - jit_uni_eltwise_injector_f32(jit_generator *host, - const post_ops_t::entry_t::eltwise_t &eltwise, - bool save_state = true, Xbyak::Reg64 p_table = Xbyak::util::rax, - Xbyak::Opmask k_mask = Xbyak::Opmask(1)) - : jit_uni_eltwise_injector_f32(host, eltwise.alg, eltwise.alpha, - eltwise.beta, save_state, p_table, k_mask) {} - - void compute_vector_range(size_t start_idx, size_t end_idx); - void compute_vector(size_t idx) { compute_vector_range(idx, idx + 1); } - void prepare_table(bool gen_table=true); - void load_table_addr() { h->mov(p_table, l_table); } - - const alg_kind_t alg_; - const float alpha_; - const float beta_; - - jit_generator * const h; - - const bool save_state_; - const Xbyak::Reg64 p_table; - const Xbyak::Opmask k_mask; - Xbyak::Label l_table; - -private: - // if only the injector was inherited from jit_generator... - enum { - _cmp_le_os = jit_generator::_cmp_le_os, - _cmp_nle_us = jit_generator::_cmp_nle_us, - _op_floor = jit_generator::_op_floor, - }; - - size_t vlen = cpu_isa_traits::vlen; - - const static size_t preserved_vecs_max = 5; - - size_t vecs_to_preserve = 0; - size_t vecs_count = isa == avx512_common ? 
32 : 16; - size_t preserved_vecs_count = 0; - size_t preserved_vec_idxs[preserved_vecs_max] = {0}; - size_t start_idx_tail = 0; - - Vmm vmm_mask, vmm_aux0, vmm_aux1, vmm_aux2, vmm_aux3, vmm_aux4; - - Xbyak::Address table_val(int index) - { return h->ptr[p_table + index * vlen]; } - - int aux_vecs_count(alg_kind_t alg); - - void compute_body(size_t start_idx, size_t end_idx); - void injector_preamble(size_t start_idx, size_t end_idx); - void injector_preamble_tail(size_t start_idx); - void injector_postamble(); - void assign_regs(); - - void exp_compute_vector(const Vmm &vmm_src); - void relu_compute_vector(const Vmm &vmm_src); - void relu_zero_ns_compute_vector(const Vmm &vmm_src); - void elu_compute_vector(const Vmm &vmm_src); - void tanh_compute_vector(const Vmm &vmm_src); - void square_compute_vector(const Vmm &vmm_src); - void abs_compute_vector(const Vmm &vmm_src); - void sqrt_compute_vector(const Vmm &vmm_src); - void linear_compute_vector(const Vmm &vmm_src); - void bounded_relu_compute_vector(const Vmm &vmm_src); - void soft_relu_compute_vector(const Vmm &vmm_src); - void logistic_compute_vector(const Vmm &vmm_src); - - void relu_prepare_table(); - void elu_prepare_table(); - void soft_relu_prepare_table(); - void abs_prepare_table(); - void sqrt_prepare_table(); - void linear_prepare_table(); - void bounded_relu_prepare_table(); -}; - -struct jit_uni_eltwise_kernel_f32; - -template -struct jit_uni_eltwise_fwd_t : public cpu_primitive_t { - struct pd_t : public cpu_eltwise_fwd_pd_t { - using cpu_eltwise_fwd_pd_t::cpu_eltwise_fwd_pd_t; - - DECLARE_COMMON_PD_T( - JIT_IMPL_NAME_HELPER("jit:", isa, ""), - jit_uni_eltwise_fwd_t); - - status_t init(); - }; - - jit_uni_eltwise_fwd_t(const pd_t *apd); - ~jit_uni_eltwise_fwd_t(); - - typedef typename prec_traits::type data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_forward(ctx); - return status::success; - } - -private: - void execute_forward(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - jit_uni_eltwise_kernel_f32 *kernel_; -}; - -template -struct jit_uni_eltwise_bwd_t : public cpu_primitive_t { - struct pd_t : public cpu_eltwise_bwd_pd_t { - using cpu_eltwise_bwd_pd_t::cpu_eltwise_bwd_pd_t; - - DECLARE_COMMON_PD_T( - JIT_IMPL_NAME_HELPER("jit:", isa, ""), - jit_uni_eltwise_bwd_t); - - status_t init(); - }; - - jit_uni_eltwise_bwd_t(const pd_t *apd); - ~jit_uni_eltwise_bwd_t(); - - typedef typename prec_traits::type data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_backward(ctx); - return status::success; - } - -private: - void execute_backward(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - jit_uni_eltwise_kernel_f32 *kernel_; -}; - -} -} -} - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_i8i8_pooling.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_i8i8_pooling.cpp deleted file mode 100644 index a3ca6273a..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_i8i8_pooling.cpp +++ /dev/null @@ -1,949 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. 
-* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include "jit_uni_i8i8_pooling.hpp" - -#include - -#include "mkldnn_thread.hpp" -#include "utils.hpp" - -#include "jit_generator.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace Xbyak; - -using namespace mkldnn::impl::utils; -using namespace mkldnn::impl::utils; -using namespace mkldnn::impl::types; -using namespace alg_kind; - -template -struct jit_uni_i8i8_pooling_fwd_ker_t: public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_i8i8_pooling_fwd_ker_t) - - struct call_params_t { - const char *src_i8; - const char *dst_i8; - size_t kw_range; - size_t kh_range; - float idivider; - }; - - using Vmm = typename cpu_isa_traits::Vmm; - Xmm xreg(int idx) const { return Xmm(idx); } - Ymm yreg(int idx) const { return Ymm(xreg(idx).getIdx()); } - Vmm vreg(int idx) const { return Vmm(xreg(idx).getIdx()); } - - // In case of avx2 with data type i8 we need to use - // maskmovdqu instruction which has its destination hardcoded in rdi. - // Windows ABI: abi_param1 is rcx - nothing to do else - // Unix ABI: abi_param1 is rdi - copy it to rcx and use it as abi_param1 - Reg64 reg_param = rcx; // Our "unified abi_param1" - Reg64 reg_ptr_src_i8 = r8; - Reg64 reg_ptr_dst_i8 = r9; - Reg64 reg_ptr_maskmovdqu_dst = rdi; // store destination - must be rdi - - Reg64 ki = r10; - Reg64 kj = r11; - Reg64 reg_kw = r12; - Reg64 reg_kh = r13; - Reg64 c_iter = r14; - - Reg64 aux_reg_src_h = rax; - Reg64 aux_reg_src_w = rbx; - - Reg64 reg_tmp = rdx; - - Reg64 reg_mask = r15; - - Opmask k_cmp_mask = Opmask(7); - - Opmask mask(int idx) { - return Opmask(6 - idx); - } - - // ref to any of XYZ-regs via xreg/yreg/vreg functions - Xmm xmm_tmp = xreg(0); // temp to init vreg_tmp - Vmm vreg_tmp = vreg(0); // max pooling : holds minimum values for data_type - Vmm vreg_zeros = vreg(1); - - // only in case of == avx2 - Vmm vreg_mask = vreg(2); // full byte-mask - Xmm xreg_mask_lo = xreg(2); // low 128-bits part of byte-mask (alias for xmm part of vreg_mask) - Xmm xreg_mask_hi = xreg(3); // "max" - high 128-bits part of byte-mask (stored separately) - Xmm xreg_mask_q = xreg(3); // "avg" - 1/4 part of the mask for s8/u8 operations - Vmm vreg_mask_q = vreg(3); // "avg" - 1/4 part for non-zero tails - - enum:int {vidx_base = isa == avx2 ? 4 : 2}; - Vmm base_vr(int idx) const { return vreg(vidx_base + idx); } - - size_t sizeof_src_dt() const { return data_type_size(jpp.src_dt); } - size_t sizeof_dst_dt() const { return data_type_size(jpp.dst_dt); } - - /* max pooling */ - Vmm vreg_src(int idx) const { return base_vr(idx); } // [0 .. ur_c-1] - Vmm vreg_dst(int idx) const { return base_vr(jpp.ur_c + idx); } // [ur_c .. 
2*ur_c-1] - - /* avg pooling */ - // s32 used for processing of s8/u8 data - // thus we need to take into account ratio of sizes s32/i8 = 4 - static constexpr data_type_t avg_proc_dt = data_type::s32; - enum:int { - s32_to_i8_ratio = sizeof(typename prec_traits::type) - / sizeof(typename prec_traits::type), - max_num_ll = s32_to_i8_ratio - }; - Vmm vreg_src_s32(int jj, int ll) { return base_vr(3*max_num_ll*jj + ll + 0*max_num_ll); } // ll: 0..4 [0..3] - Vmm vreg_dst_s32(int jj, int ll) { return base_vr(3*max_num_ll*jj + ll + 1*max_num_ll); } // ll: 0..4 [4..7] - Vmm vreg_dst_f32(int jj, int ll) { return base_vr(3*max_num_ll*jj + ll + 2*max_num_ll); } // ll: 0..4 [8..11] - - void (*ker_)(const call_params_t *); - jit_pool_conf_t jpp; - - void init_tmp_reg(); - void init_mask(); - - void load_vreg_mask_q(int ll) {}; - - void load_src_max_op(int jj, int ll, size_t offset, bool masked, uint64_t msk); - void load_src_avg_op(int jj, int ll, size_t offset, bool masked, uint64_t msk); - void load_src(int jj, int ll, int c_tail); - - void store_dst_max_op(int jj, int ll, size_t offset, bool masked, uint64_t msk); - void store_dst_avg_op(int jj, int ll, size_t offset, bool masked, uint64_t msk); - void store_dst(int jj, int ll, int c_tail); - - void compute_avg_step(int ur_c, int c_tail); - void compute_max_op(const int jj); - void compute_max_step(int ur_c, int c_tail); - void compute_step(int ur_c, int c_tail); - - void compute_c_block(); - void generate(); - - static status_t init_conf(jit_pool_conf_t &jpp, const pooling_pd_t *ppd); - - jit_uni_i8i8_pooling_fwd_ker_t(const jit_pool_conf_t &jpp_) - : jpp(jpp_) { - generate(); - ker_ = reinterpret_cast(const_cast( - getCode())); - } -}; - -template <> -void jit_uni_i8i8_pooling_fwd_ker_t::load_vreg_mask_q(int ll) { - - // extract ll-th part of mask (ll-th QWORD) - vpblendd(vreg_mask_q, vreg_zeros, vreg_mask, 0x3 << ll); // 0x3 - mask for 2 x DWORD - - // Move mask from ll-th pos to 0-th pos - if (ll>0) - vpermq(vreg_mask_q, vreg_mask_q, ll); -}; - -template <> -void jit_uni_i8i8_pooling_fwd_ker_t::load_src_max_op(int jj, int ll, - size_t offset, bool masked, uint64_t msk) { - using namespace data_type; - - if (masked) { - if (jpp.src_dt == s32) { - vpblendd(vreg_src(jj), vreg_tmp, ptr[aux_reg_src_w + offset], static_cast(msk)); - } else { - vpblendvb(vreg_src(jj), vreg_tmp, ptr[aux_reg_src_w + offset], vreg_mask); - } - } else - vmovups(vreg_src(jj), ptr[aux_reg_src_w + offset]); -}; - -template <> -void jit_uni_i8i8_pooling_fwd_ker_t::load_src_max_op(int jj, int ll, - size_t offset, bool masked, uint64_t msk) { - using namespace data_type; - - if (masked) { - if (jpp.src_dt == s32) - vmovups(vreg_src(jj) | mask(0), ptr[aux_reg_src_w + offset]); - else - vmovdqu8(vreg_src(jj) | mask(0), ptr[aux_reg_src_w + offset]); - } else - vmovups(vreg_src(jj), ptr[aux_reg_src_w + offset]); -}; - -template <> -void jit_uni_i8i8_pooling_fwd_ker_t::load_src_avg_op(int jj, int ll, - size_t offset, bool masked, uint64_t msk) { - using namespace data_type; - - // Don't generate useless code - if (masked && !msk) - return; - - auto load_i8 = [&](bool is_signed, const Vmm& vr_src) { - - // Need to use mask of tail? 
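Before the masked tail branch below: the whole average-pooling path widens the 8-bit inputs to s32 so the window sum cannot overflow, scales the sum, and narrows back with saturation, which is why one 8-bit vector is handled as max_num_ll = 4 s32 sub-vectors (the 'll' index). A scalar sketch, assuming vreg_tmp holds the broadcast idivider from call_params_t (its setup lives in init_tmp_reg(), outside this hunk) and round-to-nearest conversion in vcvtps2dq:

#include <algorithm>
#include <cmath>
#include <cstdint>

// Scalar picture of the s8 average-pooling path: widen, accumulate, scale, narrow.
static inline std::int8_t avg_pool_s8(const std::int8_t *window, int len, float idivider) {
    std::int32_t acc = 0;
    for (int i = 0; i < len; ++i)
        acc += window[i];                                    // vpmovsxbd + vpaddd
    const float scaled = static_cast<float>(acc) * idivider; // vcvtdq2ps + vfmadd132ps
    std::int32_t q = static_cast<std::int32_t>(std::lrintf(scaled)); // vcvtps2dq
    q = std::max(-128, std::min(127, q));  // clamp to s8 (the avx2 path saturates via vpackssdw/vpacksswb)
    return static_cast<std::int8_t>(q);
}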
- if (masked) { - - // load ll-th part of mask into vreg_mask_q - load_vreg_mask_q(ll); - - // Load by mask from mem into register vr_src - vpblendvb(vr_src, vreg_zeros, ptr[aux_reg_src_w + offset], vreg_mask_q); - - // Conversion s8/u8 -> s32 - if (is_signed) - vpmovsxbd(vr_src, vr_src); - else - vpmovzxbd(vr_src, vr_src); - } else { - - // Load from mem into vr_src with conversion - if (is_signed) - vpmovsxbd(vr_src, ptr[aux_reg_src_w + offset]); - else - vpmovzxbd(vr_src, ptr[aux_reg_src_w + offset]); - } - }; - - switch (jpp.src_dt) { - case s32: - if (masked) - vpblendd(vreg_src_s32(jj, ll), vreg_zeros, ptr[aux_reg_src_w + offset], - static_cast(msk)); - else - vmovups(vreg_src_s32(jj, ll), ptr[aux_reg_src_w + offset]); - break; - case s8: - load_i8(true, vreg_src_s32(jj, ll)); - break; - case u8: - load_i8(false, vreg_src_s32(jj, ll)); - break; - default: assert(!"unsupported src data type"); - } -}; - -template <> -void jit_uni_i8i8_pooling_fwd_ker_t::load_src_avg_op(int jj, int ll, - size_t offset, bool masked, uint64_t msk) { - using namespace data_type; - - // Don't generate useless code - if (masked && !msk) - return; - - const Vmm& vr_src = masked ? - vreg_src_s32(jj, ll) | mask(ll) : - vreg_src_s32(jj, ll); - - switch (jpp.src_dt) { - case s32: - vmovups(vr_src, ptr[aux_reg_src_w + offset]); - break; - case s8: - vpmovsxbd(vr_src, ptr[aux_reg_src_w + offset]); - break; - case u8: - vpmovzxbd(vr_src, ptr[aux_reg_src_w + offset]); - break; - default: assert(!"unsupported src data type"); - } -}; - -template -void jit_uni_i8i8_pooling_fwd_ker_t::load_src(int jj, int ll, int c_tail) { - using namespace data_type; - - int c_block = jpp.c_block; - int ur_c = jpp.ur_c; - - switch (jpp.alg) { - case pooling_max: { - auto offset = jj*c_block*sizeof_src_dt(); - bool masked = jj == ur_c - 1 && c_tail; - load_src_max_op(jj, ll, offset, masked, jpp.tail[0]); - break; - } - case pooling_avg_include_padding: - case pooling_avg_exclude_padding: { - auto offset = (ll*(c_block/max_num_ll) + jj*c_block)*sizeof_src_dt(); - bool masked = jj == ur_c - 1 && c_tail; - load_src_avg_op(jj, ll, offset, masked, jpp.tail[ll]); - break; - } - default: assert(!"unsupported algorithm"); - } -} - -template <> -void jit_uni_i8i8_pooling_fwd_ker_t::store_dst_max_op(int jj, int ll, - size_t offset, bool masked, uint64_t msk) { - using namespace data_type; - - int c_block = jpp.c_block; - - if (masked) { - switch (jpp.src_dt) { - case s32: - vpmaskmovd(ptr[reg_ptr_dst_i8 + offset], vreg_mask, vreg_dst(jj)); - break; - case s8: - case u8: { - // Store low half by mask (bytes 0...15) - lea(reg_ptr_maskmovdqu_dst, ptr[reg_ptr_dst_i8 + offset]); - maskmovdqu(vreg_dst(jj), xreg_mask_lo); - - // Do we need to store high half (bytes 16...31) ? 
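The question above exists because AVX2 has no masked byte store, so the kernel falls back to SSE maskmovdqu, which writes only the bytes whose mask byte has its most significant bit set and always stores through rdi (hence reg_ptr_maskmovdqu_dst). One call covers 16 bytes, so the upper half of the Ymm register is extracted and stored with xreg_mask_hi in the branch below. A C-level sketch of one such masked 16-byte store:

#include <cstdint>

// What maskmovdqu does for the tail store: copy only the bytes whose mask byte
// has its high bit set (n is 16 for one xmm register).
static inline void masked_byte_store(std::uint8_t *dst, const std::uint8_t *src,
                                     const std::uint8_t *byte_mask, int n) {
    for (int i = 0; i < n; ++i)
        if (byte_mask[i] & 0x80u)
            dst[i] = src[i];
}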
- const uint64_t low_mask = (1ULL << (c_block/2))-1; - if (msk & ~low_mask) { - vextracti128(Xmm(vreg_dst(jj).getIdx()), vreg_dst(jj), 1); - add(reg_ptr_maskmovdqu_dst, c_block / 2); - maskmovdqu(vreg_dst(jj), xreg_mask_hi); - } - } break; - default: assert(!"unsupported src data type"); - } - } else - vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj)); -} - -template <> -void jit_uni_i8i8_pooling_fwd_ker_t::store_dst_max_op(int jj, int ll, - size_t offset, bool masked, uint64_t msk) { - using namespace data_type; - - if (masked) { - switch (jpp.src_dt) { - case s32: - vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj) | mask(0)); - break; - case s8: - case u8: - vmovdqu8(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj) | mask(0)); - break; - default: assert(!"unsupported src data type"); - } - } else - vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj)); -} - -template <> -void jit_uni_i8i8_pooling_fwd_ker_t::store_dst_avg_op(int jj, int ll, - size_t offset, bool masked, uint64_t msk){ - using namespace data_type; - - // Don't generate useless code - if (masked && !msk) - return; - - auto s32_to_i8 = [&](bool is_signed, const Vmm& vr_dst) { - - // conversion: s32 -> s16/u16 : {8 x s32}{8 x 0} -> {16 x s16/u16} - // Result QWORDs (qw0, qw1) permuted: {qw0, 0, qw1, 0} - if (is_signed) - vpackssdw(vr_dst, vr_dst, vreg_zeros); - else - vpackusdw(vr_dst, vr_dst, vreg_zeros); - - // Permute qwords to restore original order - // {qw0, 0, qw1, 0} -> {qw0, qw1, 0, 0} - vpermq(vr_dst, vr_dst, 0x58); - - // conversion: s16/u16 -> s8/u8 : {16 x s16/u16}{16 x 0} -> {32 x s8/u8} - // Target QWORD qw = {8 x s8/u8} has proper position: {qw, xx, xx, xx} - if (is_signed) - vpacksswb(vr_dst, vr_dst, vreg_zeros); - else - vpackuswb(vr_dst, vr_dst, vreg_zeros); - - }; - - auto store_i8 = [&](bool is_signed, bool is_masked, const Vmm& vr_dst) { - - // Conversion s32 -> s8/u8 - s32_to_i8(is_signed, vr_dst); - - // Need to use mask of tail? - if (is_masked) { - // load ll-th part of mask into vreg_mask_q - load_vreg_mask_q(ll); - } - - // store 8 bytes - lea(reg_ptr_maskmovdqu_dst, ptr[reg_ptr_dst_i8 + offset]); - maskmovdqu(vr_dst, xreg_mask_q); - }; - - switch (jpp.dst_dt) { - case s32: - if (masked) { - vpmaskmovd(ptr[reg_ptr_dst_i8 + offset], vreg_mask, vreg_dst_s32(jj, ll)); - } else - vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst_s32(jj, ll)); - break; - case s8: - store_i8(true, masked, vreg_dst_s32(jj, ll)); - break; - case u8: - store_i8(false, masked, vreg_dst_s32(jj, ll)); - break; - default: assert(!"unsuppotred dst data_type"); - } -} - -template <> -void jit_uni_i8i8_pooling_fwd_ker_t::store_dst_avg_op(int jj, int ll, - size_t offset, bool masked, uint64_t msk) { - using namespace data_type; - - // Don't generate useless code - if (masked && !msk) - return; - - const Vmm& vr_dst = masked ? 
- vreg_dst_s32(jj, ll) | mask(ll) : - vreg_dst_s32(jj, ll); - - switch (jpp.dst_dt) { - case s32: - vmovups(ptr[reg_ptr_dst_i8 + offset], vr_dst); - break; - case s8: - vpmovdb(ptr[reg_ptr_dst_i8 + offset], vr_dst); - break; - case u8: - vpmovusdb(ptr[reg_ptr_dst_i8 + offset], vr_dst); - break; - default: assert(!"unsupported dst data_type"); - } -} - - -template -void jit_uni_i8i8_pooling_fwd_ker_t::store_dst(int jj, int ll, - int c_tail) { - using namespace data_type; - - int c_block = jpp.c_block; - int ur_c = jpp.ur_c; - - switch(jpp.alg) { - case pooling_max: { - auto offset = jj*c_block*sizeof_dst_dt(); - bool masked = jj == ur_c - 1 && c_tail; - store_dst_max_op(jj, ll, offset, masked, jpp.tail[ll]); - break; - } - case pooling_avg_include_padding: - case pooling_avg_exclude_padding: { - auto offset = (ll*(c_block/max_num_ll) + jj*c_block)*sizeof_dst_dt(); - bool masked = jj == ur_c - 1 && c_tail; - store_dst_avg_op(jj, ll, offset, masked, jpp.tail[ll]); - break; - } - default: assert(!"unsupported pooling algorithm"); - } -} - -template <> -void jit_uni_i8i8_pooling_fwd_ker_t::compute_max_op(const int jj) -{ - using namespace data_type; - switch (jpp.src_dt) { - case s32: - vpmaxsd(vreg_dst(jj), vreg_dst(jj), vreg_src(jj)); - break; - case s8: - vpmaxsb(vreg_dst(jj), vreg_dst(jj), vreg_src(jj)); - break; - case u8: - vpmaxub(vreg_dst(jj), vreg_dst(jj), vreg_src(jj)); - break; - default: assert(!"unsupported src data type"); - } -} - -template <> -void jit_uni_i8i8_pooling_fwd_ker_t::compute_max_op(const int jj) -{ - using namespace data_type; - - // Compare - switch (jpp.src_dt) { - case s32: - vpcmpd(k_cmp_mask, vreg_dst(jj), vreg_src(jj), _cmp_lt_os); - break; - case s8: - vpcmpb(k_cmp_mask, vreg_dst(jj), vreg_src(jj), _cmp_lt_os); - break; - case u8: - vpcmpub(k_cmp_mask, vreg_dst(jj), vreg_src(jj), _cmp_lt_os); - break; - default: assert(!"unsupported src data type"); - } - - // move max values into vreg_dst - if (jpp.src_dt == s32) - vpblendmd(vreg_dst(jj) | k_cmp_mask, vreg_dst(jj), vreg_src(jj)); - else - vpblendmb(vreg_dst(jj) | k_cmp_mask, vreg_dst(jj), vreg_src(jj)); -} - - -template -void jit_uni_i8i8_pooling_fwd_ker_t::compute_max_step(int ur_c, int c_tail) -{ - Label l_kw, l_kh; - - int iw = jpp.iw; - int c = jpp.c; - - for (int jj = 0; jj < ur_c; jj++) - vmovups(vreg_dst(jj), vreg_tmp); - - mov(aux_reg_src_h, reg_ptr_src_i8); - - xor_(kj, kj); - L(l_kh); - { - mov(aux_reg_src_w, aux_reg_src_h); - xor_(ki, ki); - L(l_kw); - { - for (int jj = 0; jj < ur_c; jj++) { - load_src(jj, 0, c_tail); - compute_max_op(jj); - } - add(aux_reg_src_w, c * sizeof_src_dt()); - inc(ki); - cmp(ki, reg_kw); - jl(l_kw, T_NEAR); - } - add(aux_reg_src_h, iw * c * sizeof_src_dt()); - inc(kj); - cmp(kj, reg_kh); - jl(l_kh, T_NEAR); - } - - for (int jj = 0; jj < ur_c; jj++) - store_dst(jj, 0, c_tail); -} - -template -void jit_uni_i8i8_pooling_fwd_ker_t::compute_avg_step(int ur_c, int c_tail) -{ - using namespace data_type; - - Label l_kw, l_kh; - - int iw = jpp.iw; - int c = jpp.c; - - const int num_ll = data_type_size(avg_proc_dt)/data_type_size(jpp.src_dt); - - for (int jj = 0; jj < ur_c; jj++) { - for (int ll = 0; ll < num_ll; ll++) { - bool masked = jj == ur_c - 1 && c_tail; - size_t msk = jpp.tail[ll]; - if (!(masked && !msk)) { - uni_vpxor(vreg_src_s32(jj, ll), vreg_src_s32(jj, ll), vreg_src_s32(jj, ll)); - uni_vpxor(vreg_dst_s32(jj, ll), vreg_dst_s32(jj, ll), vreg_dst_s32(jj, ll)); - } - } - } - - mov(aux_reg_src_h, reg_ptr_src_i8); - - xor_(kj, kj); - L(l_kh); - { - mov(aux_reg_src_w, 
aux_reg_src_h); - xor_(ki, ki); - L(l_kw); - { - for (int jj = 0; jj < ur_c; jj++) { - for (int ll = 0; ll < num_ll; ll++) { - bool masked = jj == ur_c - 1 && c_tail; - size_t msk = jpp.tail[ll]; - if (!(masked && !msk)) { - load_src(jj, ll, c_tail); - vpaddd(vreg_dst_s32(jj, ll), vreg_dst_s32(jj, ll), - vreg_src_s32(jj, ll)); - } - } - } - add(aux_reg_src_w, c * sizeof_src_dt()); - inc(ki); - cmp(ki, reg_kw); - jl(l_kw, T_NEAR); - } - add(aux_reg_src_h, iw * c * sizeof_src_dt()); - inc(kj); - cmp(kj, reg_kh); - jl(l_kh, T_NEAR); - } - - for (int jj = 0; jj < ur_c; jj++) { - for (int ll = 0; ll < num_ll; ll++) { - bool masked = jj == ur_c - 1 && c_tail; - size_t msk = jpp.tail[ll]; - if (!(masked && !msk)) { - vcvtdq2ps(vreg_dst_f32(jj, ll), vreg_dst_s32(jj, ll)); - vfmadd132ps(vreg_dst_f32(jj, ll), vreg_zeros, vreg_tmp); - vcvtps2dq(vreg_dst_s32(jj, ll), vreg_dst_f32(jj, ll)); - store_dst(jj, ll, c_tail); - } - } - } -} - -template -void jit_uni_i8i8_pooling_fwd_ker_t::compute_step(int ur_c, int c_tail) { - switch (jpp.alg) { - case pooling_max: - compute_max_step(ur_c, c_tail); break; - case pooling_avg_include_padding: - case pooling_avg_exclude_padding: - compute_avg_step(ur_c, c_tail); break; - default: assert(!"unsupported pooling algorithm"); - } -} - -template -void jit_uni_i8i8_pooling_fwd_ker_t::compute_c_block(){ - Label l_main_loop; - - int nb_c = jpp.nb_c; - int c_block = jpp.c_block; - int ur_c = jpp.ur_c; - int ur_c_tail = jpp.ur_c_tail; - int c_steps = nb_c / ur_c; - int c_tail = jpp.c_tail; - - xor_(c_iter, c_iter); - if (c_steps > 0) { - L(l_main_loop); { - compute_step(ur_c, 0); - add(reg_ptr_src_i8, ur_c*c_block*sizeof_src_dt()); - add(reg_ptr_dst_i8, ur_c*c_block*sizeof_dst_dt()); - inc(c_iter); - cmp(c_iter, c_steps); - jl(l_main_loop, T_NEAR); - } - } - - if (ur_c_tail != 0) { - compute_step(ur_c_tail, c_tail); - } -} - -template<> -void jit_uni_i8i8_pooling_fwd_ker_t::init_mask() { - using namespace data_type; - using cpu_isa = cpu_isa_traits; - - // AVX2 mask initialization: mask stored in Ymm-regs - auto init = [&](uint64_t bit_mask, bool init_mask_q) { - const size_t QW_PER_VREG = cpu_isa::vlen / sizeof(uint64_t); - - uint64_t vmask[QW_PER_VREG]; - for (size_t i = 0; i < QW_PER_VREG; i++){ - - uint64_t qw_vmask=0ULL; - const size_t DBITS = 8*sizeof_src_dt(); - const uint64_t VMSK = 1ULL << (DBITS-1); - const size_t D_PER_QW = (8*sizeof(qw_vmask))/DBITS; - for (size_t j = 0; j < D_PER_QW; j++) { - if (bit_mask & 1) - qw_vmask |= VMSK << DBITS * j; - bit_mask >>= 1; - } - vmask[i] = qw_vmask; - } - - // Put QWORDS with target mask into xmm regs - const int xdst_i[QW_PER_VREG] = { - xreg_mask_lo.getIdx(), - xreg_mask_lo.getIdx(), - xreg_mask_hi.getIdx(), - xreg_mask_hi.getIdx() - }; - const int xsrc_i[QW_PER_VREG] = { - vreg_zeros.getIdx(), // 0-th qword insert in zeros -> {qw0, 0} - xreg_mask_lo.getIdx(), // 1-st and 0-th merge -> {qw0,qw1} - vreg_zeros.getIdx(), - xreg_mask_hi.getIdx() - }; - const uint8 qw_dst_idx[QW_PER_VREG] = {0, 1, 0, 1}; // qword index in 128-bit xreg - - for (size_t i = 0; i < QW_PER_VREG; i++) { - mov(reg_mask, vmask[i]); - vpinsrq(Xmm(xdst_i[i]), Xmm(xsrc_i[i]), reg_mask, qw_dst_idx[i]); - } - - // Merge Low (xreg_mask_lo alias for vreg_mask.xreg) - // and High (xreg_mask_hi) into full vreg_mask - // vreg_mask -> {xreg_mask_hi, vreg_mask.xreg} - vinserti128(vreg_mask, vreg_mask, xreg_mask_hi, 1); - - // Keep only low qword of mask in xreg_mask_q - if (init_mask_q) { - mov(reg_mask, vmask[0]); - vpinsrq(xreg_mask_q, 
Xmm(vreg_zeros.getIdx()), reg_mask, 0); - } - }; - - uint64_t tail_mask = (1ULL << jpp.c_tail) - 1; - switch (jpp.alg) { - case pooling_max: - // For "max" we need mask only in case of non-zero tail - if (tail_mask) - init(tail_mask, false); - break; - case pooling_avg_include_padding: - case pooling_avg_exclude_padding: - // For "avg" we need mask: - // - s32 - in case of the non-zero tail - // - s8/u8 - irrespective of the tail - switch (jpp.src_dt) { - case s32: - if (tail_mask) - init(tail_mask, false); - break; - case s8: - case u8: - init(tail_mask ? tail_mask : ~0ULL, tail_mask == 0); - break; - default: assert(!"unsupported src data type"); - } - break; - default: assert(!"unsupported pooling algorithm"); - } -} - -template<> -void jit_uni_i8i8_pooling_fwd_ker_t::init_mask() { - - for (int ll = 0; ll < max_num_ll; ll++) { - mov(reg_mask, jpp.tail[ll]); - kmovq(mask(ll), reg_mask); - } -} - -template -void jit_uni_i8i8_pooling_fwd_ker_t::init_tmp_reg() { - using namespace data_type; - - switch (jpp.alg) { - case pooling_avg_include_padding: - case pooling_avg_exclude_padding: - mov(reg_tmp, ptr[reg_param + offsetof(call_params_t, idivider)]); - movq(xmm_tmp, reg_tmp); - vpbroadcastd(vreg_tmp, xmm_tmp); - break; - case pooling_max: - switch (jpp.src_dt) { - case s32: - mov(reg_tmp, nstl::numeric_limits::lowest()); - break; - case s8: - mov(reg_tmp, nstl::numeric_limits::lowest()); - break; - case u8: - mov(reg_tmp, nstl::numeric_limits::lowest()); - break; - default: assert(!"unsupported src data_type"); - } - - movq(xmm_tmp, reg_tmp); - if (jpp.src_dt == s32) - vpbroadcastd(vreg_tmp, xmm_tmp); - else - vpbroadcastb(vreg_tmp, xmm_tmp); - break; - default: assert(!"unsupported pooling algorithm"); - } - -} - -template -void jit_uni_i8i8_pooling_fwd_ker_t::generate() { - preamble(); - -#if !defined(_WIN32) - // Always use rcx as abi_param1 - - // see the note about maskmovdqu near reg_param. 
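A note on the comment above: maskmovdqu stores through the pointer held in rdi implicitly, and on System V x86-64 rdi is also where the first argument (the call-params pointer) arrives, so the generated code copies that pointer into rcx — where the Windows calling convention already delivers it — and keeps rdi free for the masked stores (presumably what reg_ptr_maskmovdqu_dst aliases). The instruction consults only the most significant bit of each mask byte, which is why init_mask above sets VMSK = 1ULL << (DBITS - 1) in every selected lane. A scalar model of that behaviour, with an illustrative function name:

#include <cstdint>

// Writes data[i] to dst[i] only where the mask byte has its top bit set,
// leaving the other destination bytes untouched; in the real instruction
// the destination pointer is implicit (DS:RDI).
static void maskmovdqu_model(uint8_t *dst, const uint8_t data[16], const uint8_t mask[16]) {
    for (int i = 0; i < 16; ++i)
        if (mask[i] & 0x80)
            dst[i] = data[i];
}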
- mov(rcx, rdi); -#endif - -# define READ_PARAM(reg, field) \ - mov(reg, ptr[reg_param + offsetof(call_params_t, field)]) - READ_PARAM(reg_ptr_src_i8, src_i8); - READ_PARAM(reg_ptr_dst_i8, dst_i8); - READ_PARAM(reg_kw, kw_range); - READ_PARAM(reg_kh, kh_range); - -# undef READ_PARAM - - uni_vpxor(vreg_zeros, vreg_zeros, vreg_zeros); - - init_mask(); - - init_tmp_reg(); - - compute_c_block(); - - postamble(); -} - -template -status_t jit_uni_i8i8_pooling_fwd_ker_t::init_conf(jit_pool_conf_t &jpp, - const pooling_pd_t *ppd) { - if (!mayiuse(isa)) - return status::unimplemented; - - const auto &pd = *ppd->desc(); - const memory_desc_wrapper src_d(ppd->src_md()); - const memory_desc_wrapper dst_d(ppd->dst_md()); - - jpp.mb = src_d.dims()[0]; - jpp.c = src_d.dims()[1]; - jpp.ih = src_d.dims()[2]; - jpp.iw = src_d.dims()[3]; - jpp.oh = dst_d.dims()[2]; - jpp.ow = dst_d.dims()[3]; - - jpp.stride_h = pd.strides[0]; - jpp.stride_w = pd.strides[1]; - jpp.kh = pd.kernel[0]; - jpp.kw = pd.kernel[1]; - - jpp.t_pad = pd.padding[0][0]; - jpp.l_pad = pd.padding[0][1]; - - jpp.alg = pd.alg_kind; - - jpp.src_dt = pd.src_desc.data_type; - jpp.dst_dt = pd.dst_desc.data_type; - - // data_type items per one vreg on the - // isa == avx2 : 32 bytes -> 32 for s8/u8, 8 for s32 - // isa == avx512* : 64 bytes -> 64 for s8/u8, 16 for s32 - int simd_w = cpu_isa_traits::vlen / data_type_size(jpp.src_dt); - - jpp.c_block = simd_w; - jpp.c_tail = jpp.c % jpp.c_block; - jpp.nb_c = jpp.c / jpp.c_block; - jpp.ur_c = 1; - jpp.ur_c_tail = jpp.nb_c - (jpp.nb_c / jpp.ur_c)*jpp.ur_c + - (jpp.c_tail != 0); - - size_t tail_mask = (1ULL << jpp.c_tail) - 1; - - switch (jpp.alg) { - case pooling_max: - jpp.tail[0] = tail_mask; - jpp.tail[1] = 0; - jpp.tail[2] = 0; - jpp.tail[3] = 0; - break; - case pooling_avg_include_padding: - case pooling_avg_exclude_padding: { - // avg_proc_dt (s32) defines granularity (because u8/s8 processed as s32) - // avx2 : 8, avx512 : 16 - const size_t msk_gran = cpu_isa_traits::vlen / data_type_size(avg_proc_dt); - const size_t msk_msk = (1ULL << msk_gran) - 1; - size_t m = tail_mask; - for (size_t ll = 0; ll < max_num_ll; ll++) { - jpp.tail[ll] = m & msk_msk; - m = m >> msk_gran; - } - break; - } - default: return status::unimplemented; - } - - return status::success; -} - -template -status_t jit_uni_i8i8_pooling_fwd_t::pd_t::jit_conf() { - return jit_uni_i8i8_pooling_fwd_ker_t::init_conf(jpp_, this); -} - -template -jit_uni_i8i8_pooling_fwd_t:: -jit_uni_i8i8_pooling_fwd_t(const pd_t *apd) - : cpu_primitive_t(apd), ker_(nullptr) -{ ker_ = new jit_uni_i8i8_pooling_fwd_ker_t(pd()->jpp_); } - -template -jit_uni_i8i8_pooling_fwd_t:: -~jit_uni_i8i8_pooling_fwd_t() { delete ker_; } - -template -void jit_uni_i8i8_pooling_fwd_t::execute_forward( - const exec_ctx_t &ctx) const { - auto src_i8 = CTX_IN_MEM(const char *, MKLDNN_ARG_SRC); - auto dst_i8 = CTX_OUT_MEM(char *, MKLDNN_ARG_DST); - - const memory_desc_wrapper src_d(pd()->src_md()); - const memory_desc_wrapper dst_d(pd()->dst_md()); - - const auto &jpp = pd()->jpp_; - - parallel_nd(jpp.mb, jpp.oh, jpp.ow, - [&](int n, int oh, int ow) { - const int ih = nstl::max(oh*jpp.stride_h - jpp.t_pad, 0); - const int iw = nstl::max(ow*jpp.stride_w - jpp.l_pad, 0); - - const int kh_start = nstl::max(0, jpp.t_pad - oh * jpp.stride_h); - const int kh_end = nstl::min(jpp.kh, - jpp.ih + jpp.t_pad - oh * jpp.stride_h); - const int kw_start = nstl::max(0, jpp.l_pad - ow * jpp.stride_w); - const int kw_end = nstl::min(jpp.kw, - jpp.iw + jpp.l_pad - ow * jpp.stride_w); - - 
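For reference, the per-pixel window clamping computed just above and the averaging divider chosen in the call-parameter block just below can be written out in scalar form as follows; the _ref names are illustrative and in_size stands for jpp.ih or jpp.iw:

#include <algorithm>

struct window_ref_t { int in_start, k_start, k_end; };

// Mirrors: ih = max(oh*stride - pad, 0), kh_start = max(0, pad - oh*stride),
//          kh_end = min(kernel, in_size + pad - oh*stride)   (same along width).
static window_ref_t clamp_window_ref(int o, int stride, int pad, int kernel, int in_size) {
    window_ref_t w;
    w.in_start = std::max(o * stride - pad, 0);
    w.k_start = std::max(0, pad - o * stride);
    w.k_end = std::min(kernel, in_size + pad - o * stride);
    return w;
}

// idivider: exclude_padding averages over the taps that actually hit the
// tensor (kh_range * kw_range); include_padding always divides by kw * kh.
static float avg_divider_ref(bool exclude_padding, int kh_range, int kw_range, int kh, int kw) {
    return 1.0f / float(exclude_padding ? kh_range * kw_range : kh * kw);
}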
auto p = typename jit_uni_i8i8_pooling_fwd_ker_t::call_params_t(); - p.src_i8 = &src_i8[ - src_d.blk_off(n, 0, ih, iw) * src_d.data_type_size()]; - p.dst_i8 = &dst_i8[ - dst_d.blk_off(n, 0, oh, ow) * dst_d.data_type_size()]; - p.kw_range = (size_t)(kw_end - kw_start); - p.kh_range = (size_t)(kh_end - kh_start); - p.idivider = 1.0f / ((jpp.alg == pooling_avg_exclude_padding) ? - p.kh_range*p.kw_range : jpp.kw*jpp.kh); - - ker_->ker_(&p); - }); -} - -// Explicit instantiation only for supported values. -// -template struct jit_uni_i8i8_pooling_fwd_ker_t; -template struct jit_uni_i8i8_pooling_fwd_t; - -template struct jit_uni_i8i8_pooling_fwd_ker_t; -template struct jit_uni_i8i8_pooling_fwd_t; - -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_i8i8_pooling.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_i8i8_pooling.hpp deleted file mode 100644 index d757679df..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_i8i8_pooling.hpp +++ /dev/null @@ -1,89 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef CPU_JIT_UNI_I8I8_POOLING_HPP -#define CPU_JIT_UNI_I8I8_POOLING_HPP - -#include "c_types_map.hpp" - -#include "cpu_pooling_pd.hpp" -#include "cpu_primitive.hpp" - -#include "cpu_isa_traits.hpp" -#include "jit_primitive_conf.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -template -struct jit_uni_i8i8_pooling_fwd_ker_t; - -template -struct jit_uni_i8i8_pooling_fwd_t : public cpu_primitive_t { - struct pd_t : public cpu_pooling_fwd_pd_t { - using cpu_pooling_fwd_pd_t::cpu_pooling_fwd_pd_t; - - DECLARE_COMMON_PD_T( - JIT_IMPL_NAME_HELPER("jit:", isa, ""), - jit_uni_i8i8_pooling_fwd_t); - - status_t init() { - bool ok = true - && mayiuse(isa) - && ndims() == 4 - && set_default_params() == status::success - && desc()->prop_kind == prop_kind::forward_inference - && utils::one_of(desc()->alg_kind, alg_kind::pooling_max, - alg_kind::pooling_avg_include_padding, - alg_kind::pooling_avg_exclude_padding) - && utils::one_of(src_md()->data_type, data_type::s32, - data_type::s8, data_type::u8) - && src_md()->data_type == dst_md()->data_type - && attr()->has_default_values() - && memory_desc_matches_tag(*src_md(), format_tag::nhwc) - && memory_desc_matches_tag(*dst_md(), format_tag::nhwc); - if (!ok) return status::unimplemented; - - return jit_conf(); - } - - jit_pool_conf_t jpp_; - - protected: - status_t jit_conf(); - }; - - jit_uni_i8i8_pooling_fwd_t(const pd_t *apd); - ~jit_uni_i8i8_pooling_fwd_t(); - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_forward(ctx); - return status::success; - } - -private: - void execute_forward(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - - jit_uni_i8i8_pooling_fwd_ker_t *ker_; -}; - -} -} -} - -#endif diff --git 
a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.cpp deleted file mode 100644 index 2c5a8e897..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.cpp +++ /dev/null @@ -1,305 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include "c_types_map.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -#include "jit_uni_lrn_kernel_f32.hpp" -#include "jit_uni_lrn.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace mkldnn::impl::format_tag; -using namespace mkldnn::impl::status; -using namespace mkldnn::impl::utils; - -template -jit_uni_lrn_fwd_t::jit_uni_lrn_fwd_t(const pd_t *apd) - : cpu_primitive_t(apd), ker_(nullptr) - , ker_first_(nullptr), ker_last_(nullptr) -{ - using namespace alg_kind; - - const int C = pd()->C(); - const int H = pd()->H(); - const int W = pd()->W(); - const int ls = pd()->desc()->local_size; - float A = pd()->desc()->lrn_alpha / ls; - float K = pd()->desc()->lrn_k; - - auto pk = pd()->desc()->prop_kind; - auto ak = pd()->desc()->alg_kind; - auto dat_tag = pd()->dat_tag_; - - if (dat_tag == nChw8c && ls == 5 && ak == lrn_across_channels) { - ker_ = new jit_uni_lrn_fwd_kernel_f32( - nchw8c_across(H, W, 0), A, K, pk); - ker_first_ = new jit_uni_lrn_fwd_kernel_f32( - nchw8c_across(H, W, -1), A, K, pk); - ker_last_ = new jit_uni_lrn_fwd_kernel_f32( - nchw8c_across(H, W, +1), A, K, pk); - } else if (dat_tag == nChw8c && ak == lrn_within_channel) { - /* within channel, local_size (x) local_size */ - A /= ls; /* XXX: why? 
*/ - ker_ = new jit_uni_lrn_fwd_kernel_f32( - nchw8c_within(H, W, ls), A, K, pk); - } else if (dat_tag == nchw && ls == 5 && ak == lrn_across_channels) { - ker_ = new jit_uni_lrn_fwd_kernel_f32( - nchw_across(C, H*W, 0), A, K, pk); - int remind = (H*W) % VECTOR_LENGTH; - if (remind != 0) { - ker_last_ = new jit_uni_lrn_fwd_kernel_f32( - nchw_across(C, H*W, remind), A, K, pk); - } - } else if (true /* XXX: why */) { - ker_ = new jit_uni_lrn_fwd_kernel_f32(nhwc_across(C), A, K, pk); - } -} - -template -jit_uni_lrn_fwd_t::~jit_uni_lrn_fwd_t() -{ delete ker_; delete ker_first_; delete ker_last_; } - -template -void jit_uni_lrn_fwd_t::execute_forward(const exec_ctx_t &ctx) const { - using namespace alg_kind; - - auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); - auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); - auto ws = CTX_OUT_MEM(data_t *, MKLDNN_ARG_WORKSPACE); - - const int N = pd()->MB(); - const int C = pd()->C(); - const int HW = pd()->H() * pd()->W(); - const int ls = pd()->desc()->local_size; - - auto ak = pd()->desc()->alg_kind; - auto dat_tag = pd()->dat_tag_; - - if (dat_tag == nChw8c && ls == 5 && ak == lrn_across_channels) { - parallel_nd(N, C / VECTOR_LENGTH, [&](int n, int c8) { - jit_args_fwd_t args; - args.src = &src[n*HW*C + c8 * HW * VECTOR_LENGTH]; - args.dst = &dst[n*HW*C + c8 * HW * VECTOR_LENGTH]; - args.scratch = &ws[n*HW*C + c8 * HW * VECTOR_LENGTH]; - if (c8 == 0) - (*ker_first_)(&args); - else if (c8 == C / VECTOR_LENGTH - 1) - (*ker_last_)(&args); - else - (*ker_)(&args); - }); - } - else if (dat_tag == nChw8c && ak == lrn_within_channel) { - parallel_nd(N, C / VECTOR_LENGTH, [&](int n, int c8) { - jit_args_fwd_t args; - args.src = &src[n*HW*C + c8 * HW * VECTOR_LENGTH]; - args.dst = &dst[n*HW*C + c8 * HW * VECTOR_LENGTH]; - args.scratch = &ws[n*HW*C + c8 * HW * VECTOR_LENGTH]; - (*ker_)(&args); - }); - } - else if (dat_tag == nchw && ls == 5 && ak == lrn_across_channels) { - parallel_nd(N, (HW + VECTOR_LENGTH - 1) / VECTOR_LENGTH, - [&](int n, int hw8) { - jit_args_fwd_t args; - args.src = &src[n*HW*C + hw8 * VECTOR_LENGTH]; - args.dst = &dst[n*HW*C + hw8 * VECTOR_LENGTH]; - args.scratch = &ws[n*HW*C + hw8 * VECTOR_LENGTH]; - if ((hw8 + 1)*VECTOR_LENGTH > HW) - (*ker_last_)(&args); - else - (*ker_)(&args); - }); - } - else { // nhwc - parallel_nd(N, HW, [&](int n, int hw) { - jit_args_fwd_t args; - args.src = &src[n*HW*C + hw * C]; - args.dst = &dst[n*HW*C + hw * C]; - args.scratch = &ws[n*HW*C + hw * C]; - (*ker_)(&args); - }); - } -} - -template -status_t jit_uni_lrn_fwd_t::pd_t::init() { - using namespace prop_kind; - using namespace alg_kind; - - const memory_desc_wrapper data_d(src_md()); - bool ok = true - && mayiuse(isa) - && is_fwd() - && everyone_is(data_type::f32, data_d.data_type()) - && !has_zero_dim_memory() - && data_d.ndims() == 4 - && data_d.dims()[1] % VECTOR_LENGTH == 0 - && data_d.dims()[1] >= 2 * VECTOR_LENGTH - && desc()->lrn_beta == 0.75 - && attr()->has_default_values(); - if (!ok) return unimplemented; - - if (desc_.prop_kind == forward_training) ws_md_ = *src_md(); - - dat_tag_ = memory_desc_matches_one_of_tag(*src_md(), nChw8c, nchw, nhwc); - - bool args_ok_across = true - && desc()->alg_kind == lrn_across_channels - && desc()->local_size == 5 - && one_of(dat_tag_, nChw8c, nchw, nhwc); - - const int jit_max_local_size = 5; // bigger size triggers too big code size - bool args_ok_within = true - && desc()->alg_kind == lrn_within_channel - && desc()->local_size <= ( jit_max_local_size <= MAX_LOCAL_SIZE - ? 
jit_max_local_size : MAX_LOCAL_SIZE) - && data_d.dims()[2] >= desc()->local_size - && data_d.dims()[3] >= desc()->local_size - && one_of(dat_tag_, nChw8c); - - return args_ok_across || args_ok_within ? success : unimplemented; -} - -template -jit_uni_lrn_bwd_t::jit_uni_lrn_bwd_t(const pd_t *apd) - : cpu_primitive_t(apd) - , ker_(nullptr), ker_first_(nullptr), ker_last_(nullptr) -{ - using namespace alg_kind; - const int C = pd()->C(); - const int H = pd()->H(); - const int W = pd()->W(); - const int ls = pd()->desc()->local_size; - float A = pd()->desc()->lrn_alpha / ls; - float B = pd()->desc()->lrn_beta; - - int use_h_parallelizm = 0;// XXX - if (C / VECTOR_LENGTH == 1) { - ker_ = new jit_uni_lrn_bwd_kernel_f32( - nchw8c_across(H, W, 3), A, B, use_h_parallelizm); - } - else { - ker_ = new jit_uni_lrn_bwd_kernel_f32( - nchw8c_across(H, W, 0), A, B, use_h_parallelizm); - ker_first_ = new jit_uni_lrn_bwd_kernel_f32( - nchw8c_across(H, W, -1), A, B, use_h_parallelizm); - ker_last_ = new jit_uni_lrn_bwd_kernel_f32( - nchw8c_across(H, W, +1), A, B, use_h_parallelizm); - } -} - -template -jit_uni_lrn_bwd_t::~jit_uni_lrn_bwd_t() -{ - delete ker_; delete ker_first_; delete ker_last_; -} - -template -void jit_uni_lrn_bwd_t::execute_backward(const exec_ctx_t &ctx) const { - auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); - auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); - auto ws = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WORKSPACE); - auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); - - const int N = pd()->MB(); - const int C = pd()->C(); - const int H = pd()->H(); - const int W = pd()->W(); - - int use_h_parallelizm = 0; // XXX - if (use_h_parallelizm) { - parallel_nd(N, C / VECTOR_LENGTH, H, [&](int n, int c8, int h) { - auto offset = n*C*H*W + c8*H*W*VECTOR_LENGTH - + h*W*VECTOR_LENGTH; - jit_args_bwd_t args; - args.src = &src[offset]; - args.diff_dst = &diff_dst[offset]; - args.scratch = &ws[offset]; - args.diff_src = &diff_src[offset]; - if (C / VECTOR_LENGTH == 1) - (*ker_)(&args); - else if (c8 == 0) - (*ker_first_)(&args); - else if (c8 == C / VECTOR_LENGTH - 1) - (*ker_last_)(&args); - else - (*ker_)(&args); - }); - } - else { - parallel_nd(N, C / VECTOR_LENGTH, [&](int n, int c8) { - auto offset = n*C*H*W + c8*H*W*VECTOR_LENGTH; - jit_args_bwd_t args; - args.src = &src[offset]; - args.diff_dst = &diff_dst[offset]; - args.scratch = &ws[offset]; - args.diff_src = &diff_src[offset]; - if (C / VECTOR_LENGTH == 1) - (*ker_)(&args); - else if (c8 == 0) - (*ker_first_)(&args); - else if (c8 == C / VECTOR_LENGTH - 1) - (*ker_last_)(&args); - else - (*ker_)(&args); - }); - } -} - -template -status_t jit_uni_lrn_bwd_t::pd_t::init() { - using namespace prop_kind; - using namespace alg_kind; - - const memory_desc_wrapper data_d(src_md()); - bool ok = true - && mayiuse(isa) - && !is_fwd() - && utils::everyone_is(data_type::f32, data_d.data_type()) - && !has_zero_dim_memory() - && data_d.ndims() == 4 - && data_d.dims()[1] % VECTOR_LENGTH == 0 - && desc()->lrn_beta == 0.75 - && attr()->has_default_values(); - if (!ok) return unimplemented; - - ws_md_ = *src_md(); - if (!compare_ws(hint_fwd_pd_)) return unimplemented; - - dat_tag_ = memory_desc_matches_one_of_tag(*src_md(), nChw8c); - - bool args_ok_across = true - && desc()->alg_kind == lrn_across_channels - && desc()->local_size == 5 - && utils::one_of(dat_tag_, nChw8c); - - return args_ok_across ? 
success : unimplemented; -} - -template struct jit_uni_lrn_fwd_t; -template struct jit_uni_lrn_fwd_t; -template struct jit_uni_lrn_bwd_t; - -} -} -} - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.hpp deleted file mode 100644 index 333cd3396..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.hpp +++ /dev/null @@ -1,103 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef CPU_JIT_UNI_LRN_HPP -#define CPU_JIT_UNI_LRN_HPP - -#include "c_types_map.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -#include "cpu_isa_traits.hpp" -#include "cpu_lrn_pd.hpp" -#include "cpu_primitive.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -template struct jit_uni_lrn_fwd_kernel_f32; -template struct jit_uni_lrn_bwd_kernel_f32; - -template -struct jit_uni_lrn_fwd_t: public cpu_primitive_t { - struct pd_t: public cpu_lrn_fwd_pd_t { - using cpu_lrn_fwd_pd_t::cpu_lrn_fwd_pd_t; - - DECLARE_COMMON_PD_T( - JIT_IMPL_NAME_HELPER("jit:", isa, ""), - jit_uni_lrn_fwd_t); - - status_t init(); - - format_tag_t dat_tag_; - }; - - jit_uni_lrn_fwd_t(const pd_t *apd); - ~jit_uni_lrn_fwd_t(); - - typedef typename prec_traits::type data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_forward(ctx); - return status::success; - } - -private: - void execute_forward(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - - jit_uni_lrn_fwd_kernel_f32 *ker_, *ker_first_, *ker_last_; -}; - -template -struct jit_uni_lrn_bwd_t: public cpu_primitive_t { - struct pd_t: public cpu_lrn_bwd_pd_t { - using cpu_lrn_bwd_pd_t::cpu_lrn_bwd_pd_t; - - DECLARE_COMMON_PD_T( - JIT_IMPL_NAME_HELPER("jit:", isa, ""), - jit_uni_lrn_bwd_t); - - status_t init(); - - format_tag_t dat_tag_; - }; - - jit_uni_lrn_bwd_t(const pd_t *apd); - ~jit_uni_lrn_bwd_t(); - - typedef typename prec_traits::type data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_backward(ctx); - return status::success; - } - -private: - void execute_backward(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - - jit_uni_lrn_bwd_kernel_f32 *ker_, *ker_first_, *ker_last_; -}; - -} -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn_kernel_f32.cpp deleted file mode 100644 index 89af47272..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn_kernel_f32.cpp +++ /dev/null @@ -1,1487 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, 
Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include "c_types_map.hpp" -#include "nstl.hpp" -#include "utils.hpp" - -#include "jit_uni_lrn_kernel_f32.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace Xbyak; - -////////////////////////////////////////////////////////////////////////////// -// forward kernel -template -void jit_uni_lrn_fwd_kernel_f32::within_body( - int hoff, int Hoff, int woff, int Woff, int stride, - Xbyak::Ymm ysum, Xbyak::Ymm ydst, Xbyak::Ymm ytmp, Xbyak::Ymm ysum2, - prop_kind_t pk) -{ - vxorps(ysum, ysum, ysum); - for (int i = hoff; i <= Hoff; ++i) - { - for (int j = woff; j <= Woff; ++j) - { - if (i == 0 && j == 0) - { - vmovups(ydst, ptr[src]); - vfmadd231ps(ysum, ydst, ydst); - } - else - { - vmovups(ytmp, ptr[src + (i*stride + j)*VECTOR_LENGTH*4]); - vfmadd231ps(ysum, ytmp, ytmp); - } - } - } - vfmadd132ps(ysum, yk, yalpha); // ysum <- ysum*yalpha+yk - vmovaps(ytmp, ysum); - if (pk != prop_kind::forward_inference) - vmovups(ptr[scratch], ytmp); - vmulps(ysum2, ysum, ysum); - vmulps(ysum, ysum, ysum2); // ysum = (ysum*yalpha+yk)^3; - vsqrtps(ysum, ysum); - vsqrtps(ysum, ysum); // ysum = (ysum*yalpha+yk)^0.75 - vdivps(ydst, ydst, ysum); // ydst <- ydst / ysum - vmovups(ptr[dst], ydst); - add(src, 32); - add(dst, 32); - if (pk != prop_kind::forward_inference) - add(scratch, 32); -} - -template -void jit_uni_lrn_fwd_kernel_f32::within_body_sse42( - int hoff, int Hoff, int woff, int Woff, int stride, prop_kind_t pk) -{ - Xbyak::Xmm xtmp_lo = xmm12; - Xbyak::Xmm xtmp_hi = xmm13; - Xbyak::Xmm xsum_lo = xmm8; - Xbyak::Xmm xsum_hi = xmm9; - Xbyak::Xmm xdst_lo = xmm10; - Xbyak::Xmm xdst_hi = xmm11; - Xbyak::Xmm xsum2_lo = xmm14; - Xbyak::Xmm xsum2_hi = xmm15; - - xorps(xsum_lo, xsum_lo); - xorps(xsum_hi, xsum_hi); - for (int i = hoff; i <= Hoff; ++i) - { - for (int j = woff; j <= Woff; ++j) - { - if (i == 0 && j == 0) - { - movups(xdst_lo, ptr[src]); - movups(xdst_hi, ptr[src + 4 * sizeof(float)]); - mulps(xdst_lo, xdst_lo); - mulps(xdst_hi, xdst_hi); - addps(xsum_lo, xdst_lo); - addps(xsum_hi, xdst_hi); - } - else - { - movups(xtmp_lo, ptr[src + (i*stride + j)*VECTOR_LENGTH * 4]); - movups(xtmp_hi, ptr[src + (i*stride + j)*VECTOR_LENGTH * 4 + 4 * sizeof(float)]); - mulps(xtmp_lo, xtmp_lo); - mulps(xtmp_hi, xtmp_hi); - addps(xsum_lo, xtmp_lo); - addps(xsum_hi, xtmp_hi); - } - } - } - mulps(xsum_lo, xalpha); - mulps(xsum_hi, xalpha); - addps(xsum_lo, xk); - addps(xsum_hi, xk); // xsum <- xsum*xalpha+xk - movaps(xtmp_lo, xsum_lo); - movaps(xtmp_hi, xsum_hi); - if (pk != prop_kind::forward_inference) { - movups(ptr[scratch], xtmp_lo); - movups(ptr[scratch + 4 * sizeof(float)], xtmp_hi); - } - movaps(xsum2_lo, xsum_lo); - movaps(xsum2_hi, xsum_hi); - mulps(xsum2_lo, xsum_lo); - mulps(xsum2_hi, xsum_hi); - mulps(xsum_lo, xsum2_lo); - mulps(xsum_hi, xsum2_hi); // xsum = (xsum*xalpha+xk)^3; - - sqrtps(xsum_lo, xsum_lo); - sqrtps(xsum_hi, xsum_hi); - sqrtps(xsum_lo, xsum_lo); - sqrtps(xsum_hi, xsum_hi); // 
xsum = (xsum*xalpha+xk)^0.75 - - movups(xdst_lo, ptr[src]); - movups(xdst_hi, ptr[src + 4 * sizeof(float)]); - divps(xdst_lo, xsum_lo); - divps(xdst_hi, xsum_hi); // xdst <- xdst / xsum - - movups(ptr[dst], xdst_lo); - movups(ptr[dst + 4 * sizeof(float)], xdst_hi); - add(src, 32); - add(dst, 32); - if (pk != prop_kind::forward_inference) - add(scratch, 32); -} - -template -jit_uni_lrn_fwd_kernel_f32::jit_uni_lrn_fwd_kernel_f32( - const struct nchw8c_within &J, - float A, - float K, - prop_kind_t pk, - void *code_ptr, - size_t code_size) - : jit_generator(code_ptr, code_size) - , alpha(A), k(K) -{ - Xbyak::Reg64 h = r9; - Xbyak::Reg64 w = r10; - Vmm ysum = Vmm(isa == avx2 ? 9 : 9); - Vmm ysum2 = Vmm(isa == avx2 ? 10 : 10); - Vmm ydst = Vmm(isa == avx2 ? 11 : 11); - Vmm ytmp = Vmm(isa == avx2 ? 12 : 12); - - this->preamble(); - - mov(src, ptr[this->param1 + 0]); - mov(dst, ptr[this->param1 + 8]); - if (pk != prop_kind::forward_inference) - mov(scratch, ptr[this->param1 + 16]); - - mov(imm_addr64, float2int(this->alpha)); - movq(xalpha, imm_addr64); - if (isa == avx2) { - vbroadcastss(yalpha, xalpha); - } else { - shufps(xalpha, xalpha, 0); - } - - mov(imm_addr64, float2int(this->k)); - movq(xk, imm_addr64); - if (isa == avx2) { - vbroadcastss(yk, xk); - } else { - shufps(xk, xk, 0); - } - - int s2 = (J.size - 1) / 2, S2 = J.size - s2 - 1; - - for (int i = 0; i < s2; ++i) - { - Label label_t; - for (int j = 0; j < s2; ++j) { - if (isa == avx2) { - within_body(-i, S2, -j, S2, J.W, ysum, ydst, ytmp, ysum2, pk); - } - else { - within_body_sse42(-i, S2, -j, S2, J.W, pk); - } - } - mov(w, J.W - J.size + 1); - L(label_t); - if (isa == avx2) { - within_body(-i, S2, -s2, S2, J.W, ysum, ydst, ytmp, ysum2, pk); - } else { - within_body_sse42(-i, S2, -s2, S2, J.W, pk); - } - dec(w); - cmp(w, 0); - jne(label_t, T_NEAR); - for (int j = J.W - S2; j < J.W; ++j) { - if (isa == avx2) { - within_body(-i, S2, -s2, J.W - 1 - j, J.W, - ysum, ydst, ytmp, ysum2, pk); - } else { - within_body_sse42(-i, S2, -s2, J.W - 1 - j, J.W, pk); - } - } - } - - mov(h, J.H - J.size + 1); - Label lrn_loop_h; - L(lrn_loop_h); - for (int j = 0; j < s2; ++j) { - if (isa == avx2) { - within_body(-s2, S2, -j, S2, J.W, ysum, ydst, ytmp, ysum2, pk); - } else { - within_body_sse42(-s2, S2, -j, S2, J.W, pk); - } - } - mov(w, J.W - J.size + 1); - Label lrn_loop_w; - L(lrn_loop_w); - if (isa == avx2) { - within_body(-s2, S2, -s2, S2, J.W, ysum, ydst, ytmp, ysum2, pk); - } else { - within_body_sse42(-s2, S2, -s2, S2, J.W, pk); - } - dec(w); - cmp(w, 0); - jne(lrn_loop_w, T_NEAR); - for (int j = J.W - S2; j < J.W; ++j) { - if (isa == avx2) { - within_body(-s2, S2, -s2, J.W - 1 - j, J.W, - ysum, ydst, ytmp, ysum2, pk); - } else { - within_body_sse42(-s2, S2, -s2, J.W - 1 - j, J.W, pk); - } - } - dec(h); - cmp(h, 0); - jne(lrn_loop_h, T_NEAR); - - for (int i = J.H - S2; i < J.H; ++i) - { - for (int j = 0; j < s2; ++j) { - if (isa == avx2) { - within_body(-s2, J.H - 1 - i, -j, S2, J.W, - ysum, ydst, ytmp, ysum2, pk); - } else { - within_body_sse42(-s2, J.H - 1 - i, -j, S2, J.W, pk); - } - } - - mov(w, J.W - J.size + 1); - Label label_b; - L(label_b); - if (isa == avx2) { - within_body(-s2, J.H - 1 - i, -s2, S2, J.W, - ysum, ydst, ytmp, ysum2, pk); - } else { - within_body_sse42(-s2, J.H - 1 - i, -s2, S2, J.W, pk); - } - dec(w); - cmp(w, 0); - jne(label_b, T_NEAR); - - for (int j = J.W - S2; j < J.W; ++j) { - if (isa == avx2) { - within_body(-s2, J.H - 1 - i, -s2, J.W - 1 - j, J.W, - ysum, ydst, ytmp, ysum2, pk); - } else { - 
within_body_sse42(-s2, J.H - 1 - i, -s2, J.W - 1 - j, J.W, pk); - } - } - } - - this->postamble(); - - ker = reinterpret_cast(const_cast( - this->getCode())); -} - -template<> -jit_uni_lrn_fwd_kernel_f32::jit_uni_lrn_fwd_kernel_f32( - const struct nchw8c_across &J, - float A, - float K, - prop_kind_t pk, - void *code_ptr, - size_t code_size) - : jit_generator(code_ptr, code_size) - , alpha(A), k(K) -{ - Xbyak::Reg64 t = rsp; - Xbyak::Reg64 hw = r9; - Xbyak::Xmm xsrc_prev = xmm2; - Xbyak::Ymm ysrc = ymm3; - Xbyak::Ymm yc = ymm3; - Xbyak::Xmm xsrc_next = xmm4; - Xbyak::Ymm ya = ymm5; - Xbyak::Ymm yb = ymm6; - Xbyak::Ymm yd = ymm7; - Xbyak::Ymm ye = ymm8; - Xbyak::Ymm ysum = ymm9; - Xbyak::Ymm ysum2 = ymm10; - Xbyak::Ymm ydst = ymm11; - Xbyak::Ymm ybase = ymm12; - - this->preamble(); - - mov(src, ptr[this->param1 + 0]); - mov(dst, ptr[this->param1 + 8]); - if (pk != prop_kind::forward_inference) - mov(scratch, ptr[this->param1 + 16]); - sub(t, 64); - mov(imm_addr64, float2int(this->alpha)); - movq(xalpha, imm_addr64); - vbroadcastss(yalpha, xalpha); - - mov(imm_addr64, float2int(this->k)); - movq(xk, imm_addr64); - vbroadcastss(yk, xk); - - if (J.version == -1) - { - vxorps(xsrc_prev, xsrc_prev, xsrc_prev); - vmovups(ptr[t + 0], xsrc_prev); - } - if (J.version == +1) - { - vxorps(xsrc_next, xsrc_next, xsrc_next); - vmovups(ptr[t + 48], xsrc_next); - } - - mov(hw, J.H*J.W); - - Label lrn_loop; - L(lrn_loop); - - if (J.version != -1) vmovups(xsrc_prev, ptr[src - J.H*J.W * 32 + 16]); - vmovups(ysrc, ptr[src]); - if (J.version != +1) vmovups(xsrc_next, ptr[src + J.H*J.W * 32]); - - if (J.version != -1) vmovups(ptr[t + 0], xsrc_prev); - vmovups(ptr[t + 16], ysrc); - if (J.version != +1) vmovups(ptr[t + 48], xsrc_next); - - vmovups(ya, ptr[t + 16 - 8]); - vmovups(yb, ptr[t + 16 - 4]); - vmovups(yd, ptr[t + 16 + 4]); - vmovups(ye, ptr[t + 16 + 8]); - vmulps(ysum, yc, yc); - vfmadd231ps(ysum, ya, ya); // ysum <- ysum + ya*ya - vfmadd231ps(ysum, yb, yb); - vfmadd231ps(ysum, yd, yd); - vfmadd231ps(ysum, ye, ye); - vfmadd132ps(ysum, yk, yalpha); // ysum <- ysum*yalpha+yk - - vmovaps(ybase, ysum); - if (pk != prop_kind::forward_inference) - vmovups(ptr[scratch], ybase); - vmulps(ysum2, ysum, ysum); - vmulps(ysum, ysum, ysum2); // ysum = ybase^3; - vsqrtps(ysum, ysum); - vsqrtps(ysum, ysum); // ysum = ybase^0.75 - vdivps(ydst, ysrc, ysum); // ydst = ysrc / ysum - vmovups(ptr[dst], ydst); - - add(src, 32); - add(dst, 32); - if (pk != prop_kind::forward_inference) - add(scratch, 32); - dec(hw); - cmp(hw, 0); - jne(lrn_loop, T_NEAR); - - add(t, 64); - this->postamble(); - - ker = reinterpret_cast(const_cast( - this->getCode())); -} - -template<> -jit_uni_lrn_fwd_kernel_f32::jit_uni_lrn_fwd_kernel_f32( - const struct nchw8c_across &J, - float A, - float K, - prop_kind_t pk, - void *code_ptr, - size_t code_size) - : jit_generator(code_ptr, code_size) - , alpha(A), k(K) -{ - Xbyak::Reg64 t = rsp; - Xbyak::Reg64 hw = r9; - - Xbyak::Xmm xsrc_lo = xmm2; - Xbyak::Xmm xsrc_hi = xmm3; - Xbyak::Xmm xc_lo = xmm4; - Xbyak::Xmm xc_hi = xmm5; - Xbyak::Xmm xsum_lo = xc_lo; - Xbyak::Xmm xsum_hi = xc_hi; - Xbyak::Xmm xsrc_prev = xmm6; - Xbyak::Xmm xsrc_next = xmm7; - Xbyak::Xmm xa_lo = xmm8; - Xbyak::Xmm xa_hi = xmm9; - Xbyak::Xmm xb_lo = xmm10; - Xbyak::Xmm xb_hi = xmm11; - Xbyak::Xmm xd_lo = xmm12; - Xbyak::Xmm xd_hi = xmm13; - Xbyak::Xmm xe_lo = xmm14; - Xbyak::Xmm xe_hi = xmm15; - Xbyak::Xmm xbase_lo = xmm14; - Xbyak::Xmm xbase_hi = xmm15; - - this->preamble(); - - mov(src, ptr[this->param1 + 0]); - mov(dst, 
ptr[this->param1 + 8]); - if (pk != prop_kind::forward_inference) - mov(scratch, ptr[this->param1 + 16]); - sub(t, 64); - mov(imm_addr64, float2int(this->alpha)); - movq(xalpha, imm_addr64); - shufps(xalpha, xalpha, 0); - - mov(imm_addr64, float2int(this->k)); - movq(xk, imm_addr64); - shufps(xk, xk, 0); - - if (J.version == -1) - { - xorps(xsrc_prev, xsrc_prev); - movups(ptr[t + 0], xsrc_prev); - } - if (J.version == +1) - { - xorps(xsrc_next, xsrc_next); - movups(ptr[t + 48], xsrc_next); - } - - mov(hw, J.H*J.W); - Label lrn_loop; - L(lrn_loop); - - if (J.version != -1) movups(xsrc_prev, ptr[src - J.H*J.W * 32 + 16]); - movups(xsrc_lo, ptr[src]); - movups(xsrc_hi, ptr[src + 4 * sizeof(float)]); - if (J.version != +1) movups(xsrc_next, ptr[src + J.H*J.W * 32]); - - if (J.version != -1) movups(ptr[t + 0], xsrc_prev); - movups(ptr[t + 16], xsrc_lo); - movups(ptr[t + 16 + 4 * sizeof(float)], xsrc_hi); - if (J.version != +1) movups(ptr[t + 48], xsrc_next); - - movups(xa_lo, ptr[t + 16 - 8]); - movups(xa_hi, ptr[t + 16 - 8 + 4 * sizeof(float)]); - movups(xb_lo, ptr[t + 16 - 4]); - movups(xb_hi, ptr[t + 16 - 4 + 4 * sizeof(float)]); - movups(xd_lo, ptr[t + 16 + 4]); - movups(xd_hi, ptr[t + 16 + 4 + 4 * sizeof(float)]); - movups(xe_lo, ptr[t + 16 + 8]); - movups(xe_hi, ptr[t + 16 + 8 + 4 * sizeof(float)]); - movaps(xc_lo, xsrc_lo); - movaps(xc_hi, xsrc_hi); - mulps(xsum_lo, xc_lo); - mulps(xsum_hi, xc_hi); - mulps(xa_lo, xa_lo); - mulps(xa_hi, xa_hi); - addps(xsum_lo, xa_lo); - addps(xsum_hi, xa_hi); // xsum <- xsum + xa*xa - mulps(xb_lo, xb_lo); - mulps(xb_hi, xb_hi); - addps(xsum_lo, xb_lo); - addps(xsum_hi, xb_hi); - mulps(xd_lo, xd_lo); - mulps(xd_hi, xd_hi); - addps(xsum_lo, xd_lo); - addps(xsum_hi, xd_hi); - mulps(xe_lo, xe_lo); - mulps(xe_hi, xe_hi); - addps(xsum_lo, xe_lo); - addps(xsum_hi, xe_hi); - - mulps(xsum_lo, xalpha); - mulps(xsum_hi, xalpha); - addps(xsum_lo, xk); - addps(xsum_hi, xk); // xsum <- xsum*xalpha+xk - - movaps(xbase_lo, xsum_lo); - movaps(xbase_hi, xsum_hi); - if (pk != prop_kind::forward_inference) { - movups(ptr[scratch], xbase_lo); - movups(ptr[scratch + 4 * sizeof(float)], xbase_hi); - } - mulps(xsum_lo, xsum_lo); - mulps(xsum_hi, xsum_hi); - mulps(xsum_lo, xbase_lo); - mulps(xsum_hi, xbase_hi); // xsum = xbase^3; - sqrtps(xsum_lo, xsum_lo); - sqrtps(xsum_hi, xsum_hi); - sqrtps(xsum_lo, xsum_lo); - sqrtps(xsum_hi, xsum_hi); // xsum = xbase^0.75 - divps(xsrc_lo, xsum_lo); - divps(xsrc_hi, xsum_hi); // xdst = xsrc / xsum - movups(ptr[dst], xsrc_lo); - movups(ptr[dst + 4 * sizeof(float)], xsrc_hi); - - add(src, 32); - add(dst, 32); - if (pk != prop_kind::forward_inference) - add(scratch, 32); - dec(hw); - cmp(hw, 0); - jne(lrn_loop, T_NEAR); - - add(t, 64); - this->postamble(); - - ker = reinterpret_cast(const_cast( - this->getCode())); -} - -template<> -jit_uni_lrn_fwd_kernel_f32::jit_uni_lrn_fwd_kernel_f32( - const struct nhwc_across &J, - float A, - float K, - prop_kind_t pk, - void *code_ptr, - size_t code_size) - : jit_generator(code_ptr, code_size) - , alpha(A), k(K) -{ - static const uint32_t mask[] = { - 0, 0, 0x80000000, 0x80000000, 0x80000000, 0x80000000, - 0x80000000, 0x80000000, 0x80000000, 0, 0 - }; - - Xbyak::Reg64 c = r9; - Xbyak::Ymm ya = ymm2; - Xbyak::Ymm yb = ymm3; - Xbyak::Ymm yc = ymm4; - Xbyak::Ymm yd = ymm5; - Xbyak::Ymm ye = ymm6; - Xbyak::Ymm ysum = ymm7; - Xbyak::Ymm ydst = ymm8; - Xbyak::Ymm ybase = ymm9; - Xbyak::Ymm ymask = ymm10; - - this->preamble(); - - mov(src, ptr[this->param1 + 0]); - mov(dst, ptr[this->param1 + 8]); - if (pk != 
prop_kind::forward_inference) - mov(scratch, ptr[this->param1 + 16]); - mov(imm_addr64, float2int(this->alpha)); - movq(xalpha, imm_addr64); - vbroadcastss(yalpha, xalpha); - - mov(imm_addr64, float2int(this->k)); - movq(xk, imm_addr64); - vbroadcastss(yk, xk); - - vxorps(ysum, ysum, ysum); - - mov(imm_addr64, reinterpret_cast(&mask[0])); - vmovups(ymask, ptr[imm_addr64]); - vmaskmovps(ya, ymask, ptr[src - 8]); - vfmadd231ps(ysum, ya, ya); // ysum <- ysum + ya^2+yb^2+yc^2+yd^2+ye^2 - - mov(imm_addr64, reinterpret_cast(&mask[1])); - vmovups(ymask, ptr[imm_addr64]); - vmaskmovps(yb, ymask, ptr[src - 4]); - vfmadd231ps(ysum, yb, yb); - - mov(c, J.C / 8 - 1); - Label lrn_loop; - L(lrn_loop); - - vmovups(yc, ptr[src]); - vmovups(yd, ptr[src + 4]); - vmovups(ye, ptr[src + 8]); - vfmadd231ps(ysum, yc, yc); - vfmadd231ps(ysum, yd, yd); - vfmadd231ps(ysum, ye, ye); - - vmovups(ydst, ysum); - vfmadd132ps(ydst, yk, yalpha); // ydst <- ysum*yalpha+yk - - vmovaps(ybase, ydst); - if (pk != prop_kind::forward_inference) - vmovups(ptr[scratch], ybase); - vmulps(ydst, ydst, ydst); - vmulps(ydst, ydst, ybase); // ydst = (ysum*yalpha+yk)^3; - vsqrtps(ydst, ydst); - vsqrtps(ydst, ydst); // ydst = (ysum*yalpha+yk)^0.75 - - vdivps(ydst, yc, ydst); // ydst = ysrc / (ysum*yalpha+yk)^0.75 - vmovups(ptr[dst], ydst); - - vxorps(ysum, ysum, ysum); - - add(src, 32); - add(dst, 32); - if (pk != prop_kind::forward_inference) - add(scratch, 32); - - vmovups(ya, ptr[src - 8]); - vfmadd231ps(ysum, ya, ya); - vmovups(yb, ptr[src - 4]); - vfmadd231ps(ysum, yb, yb); - - dec(c); - cmp(c, 0); - jne(lrn_loop, T_NEAR); - - vmovups(yc, ptr[src]); - vfmadd231ps(ysum, yc, yc); - - mov(imm_addr64, reinterpret_cast(&mask[2])); - vmovups(ymask, ptr[imm_addr64]); - vmaskmovps(yd, ymask, ptr[src + 4]); - vfmadd231ps(ysum, yd, yd); // ysum <- ysum + ya^2+yb^2+yc^2+yd^2+ye^2 - - mov(imm_addr64, reinterpret_cast(&mask[3])); - vmovups(ymask, ptr[imm_addr64]); - vmaskmovps(ye, ymask, ptr[src + 8]); - vfmadd231ps(ysum, ye, ye); - - vmovups(ydst, ysum); - vfmadd132ps(ydst, yk, yalpha); // ydst <- ysum*yalpha+yk - - vmovaps(ybase, ydst); - if (pk != prop_kind::forward_inference) - vmovups(ptr[scratch], ybase); - vmulps(ydst, ydst, ydst); - vmulps(ydst, ydst, ybase); // ydst = (ysum*yalpha+yk)^3; - vsqrtps(ydst, ydst); - vsqrtps(ydst, ydst); // ydst = (ysum*yalpha+yk)^0.75 - vdivps(ydst, yc, ydst); // ydst = ysrc / (ysum*yalpha+yk)^0.75 - - vmovups(ptr[dst], ydst); - - this->postamble(); - - ker = reinterpret_cast(const_cast( - this->getCode())); -} - -template<> -jit_uni_lrn_fwd_kernel_f32::jit_uni_lrn_fwd_kernel_f32( - const struct nhwc_across &J, - float A, - float K, - prop_kind_t pk, - void *code_ptr, - size_t code_size) - : jit_generator(code_ptr, code_size) - , alpha(A), k(K) -{ - static const uint32_t mask[] = { - 0, 0, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, - 0xffffffff, 0xffffffff, 0xffffffff, 0, 0 - }; - - static uint32_t store[] = { - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 - }; - Xbyak::Reg64 c = r9; - - Xbyak::Xmm xdst_lo = xmm0; - Xbyak::Xmm xdst_hi = xmm1; - Xbyak::Xmm xa_lo = xmm2; - Xbyak::Xmm xa_hi = xmm3; - Xbyak::Xmm xb_lo = xmm2; - Xbyak::Xmm xb_hi = xmm3; - Xbyak::Xmm xc_lo = xmm4; - Xbyak::Xmm xc_hi = xmm5; - Xbyak::Xmm xd_lo = xmm6; - Xbyak::Xmm xd_hi = xmm7; - Xbyak::Xmm xe_lo = xmm8; - Xbyak::Xmm xe_hi = xmm9; - Xbyak::Xmm xsum_lo = xmm10; - Xbyak::Xmm xsum_hi = xmm11; - Xbyak::Xmm xmask_lo = xmm12; - Xbyak::Xmm xmask_hi = xmm13; - Xbyak::Xmm xbase_lo = xmm14; - Xbyak::Xmm xbase_hi = xmm15; - - this->preamble(); - - 
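Before the SSE4.2 variant continues below, here is a scalar reference of the across-channel normalization these forward kernels JIT-compile (the within-channel variant sums over a spatial window instead), restricted to the beta == 0.75 case that pd_t::init() enforces. lrn_across_ref is an illustrative name, and x^0.75 is computed the same way as in the generated code: cube the base, then take two square roots.

#include <cmath>
#include <vector>

static std::vector<float> lrn_across_ref(const std::vector<float> &src,
                                         int local_size, float alpha, float k) {
    const int C = (int)src.size();
    const int half = (local_size - 1) / 2;   // local_size is 5 in the across-channel kernels
    std::vector<float> dst(C);
    for (int c = 0; c < C; ++c) {
        float sum = 0.0f;
        for (int j = -half; j <= half; ++j)
            if (c + j >= 0 && c + j < C)
                sum += src[c + j] * src[c + j];
        const float base = k + (alpha / local_size) * sum;               // A = lrn_alpha / ls above
        const float base_075 = std::sqrt(std::sqrt(base * base * base)); // base^3 -> ^1.5 -> ^0.75
        dst[c] = src[c] / base_075;
    }
    return dst;
}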
mov(src, ptr[this->param1 + 0]); - mov(dst, ptr[this->param1 + 8]); - if (pk != prop_kind::forward_inference) - mov(scratch, ptr[this->param1 + 16]); - mov(imm_addr64, float2int(this->alpha)); - movq(xalpha, imm_addr64); - shufps(xalpha, xalpha, 0); - - mov(imm_addr64, float2int(this->k)); - movq(xk, imm_addr64); - shufps(xk, xk, 0); - - mov(store_addr, reinterpret_cast(&store[0])); - and_(store_addr, -15); - movups(ptr[store_addr], xalpha); - movups(ptr[store_addr + 4 * sizeof(float)], xk); - - xorps(xsum_lo, xsum_lo); - xorps(xsum_hi, xsum_hi); - - mov(imm_addr64, reinterpret_cast(&mask[0])); - movups(xmask_lo, ptr[imm_addr64]); - movups(xmask_hi, ptr[imm_addr64 + 4 * sizeof(float)]); - movups(xa_lo, ptr[src - 8]); - movups(xa_hi, ptr[src - 8 + 4 * sizeof(float)]); - andps(xa_lo, xmask_lo); - andps(xa_hi, xmask_hi); - mulps(xa_lo, xa_lo); - mulps(xa_hi, xa_hi); - addps(xsum_lo, xa_lo); - addps(xsum_hi, xa_hi); // xsum <- xsum + xa^2+xb^2+xc^2+xd^2+xe^2 - - mov(imm_addr64, reinterpret_cast(&mask[1])); - movups(xmask_lo, ptr[imm_addr64]); - movups(xmask_hi, ptr[imm_addr64 + 4 * sizeof(float)]); - movups(xb_lo, ptr[src - 4]); - movups(xb_hi, ptr[src - 4 + 4 * sizeof(float)]); - andps(xb_lo, xmask_lo); - andps(xb_hi, xmask_hi); - mulps(xb_lo, xb_lo); - mulps(xb_hi, xb_hi); - addps(xsum_lo, xb_lo); - addps(xsum_hi, xb_hi); - - mov(c, J.C / 8 - 1); - Label lrn_loop; - L(lrn_loop); - - movups(xc_lo, ptr[src]); - movups(xc_hi, ptr[src + 4 * sizeof(float)]); - movups(xd_lo, ptr[src + 4]); - movups(xd_hi, ptr[src + 4 + 4 * sizeof(float)]); - movups(xe_lo, ptr[src + 8]); - movups(xe_hi, ptr[src + 8 + 4 * sizeof(float)]); - mulps(xc_lo, xc_lo); - mulps(xc_hi, xc_hi); - addps(xsum_lo, xc_lo); - addps(xsum_hi, xc_hi); - mulps(xd_lo, xd_lo); - mulps(xd_hi, xd_hi); - addps(xsum_lo, xd_lo); - addps(xsum_hi, xd_hi); - mulps(xe_lo, xe_lo); - mulps(xe_hi, xe_hi); - addps(xsum_lo, xe_lo); - addps(xsum_hi, xe_hi); - - movaps(xdst_lo, xsum_lo); - movaps(xdst_hi, xsum_hi); - // xdst <- xsum*xalpha+xk - mulps(xdst_lo, ptr[store_addr]); - mulps(xdst_hi, ptr[store_addr]); - addps(xdst_lo, ptr[store_addr + 4 * sizeof(float)]); - addps(xdst_hi, ptr[store_addr + 4 * sizeof(float)]); - - movaps(xbase_lo, xdst_lo); - movaps(xbase_hi, xdst_hi); - if (pk != prop_kind::forward_inference) { - movups(ptr[scratch], xbase_lo); - movups(ptr[scratch + 4 * sizeof(float)], xbase_hi); - } - mulps(xdst_lo, xdst_lo); - mulps(xdst_hi, xdst_hi); - mulps(xdst_lo, xbase_lo); - mulps(xdst_hi, xbase_hi); // xdst = (xsum*xalpha+xk)^3; - sqrtps(xdst_lo, xdst_lo); - sqrtps(xdst_hi, xdst_hi); - sqrtps(xdst_lo, xdst_lo); - sqrtps(xdst_hi, xdst_hi); // xdst = (xsum*xalpha+xk)^0.75 - - movups(xc_lo, ptr[src]); - movups(xc_hi, ptr[src + 4 * sizeof(float)]); - divps(xc_lo, xdst_lo); - divps(xc_hi, xdst_hi); // xdst = xsrc / (xsum*xalpha+xk)^0.75 - movups(ptr[dst], xc_lo); - movups(ptr[dst + 4 * sizeof(float)], xc_hi); - - xorps(xsum_lo, xsum_lo); - xorps(xsum_hi, xsum_hi); - - add(src, 32); - add(dst, 32); - if (pk != prop_kind::forward_inference) - add(scratch, 32); - - movups(xa_lo, ptr[src - 8]); - movups(xa_hi, ptr[src - 8 + 4 * sizeof(float)]); - mulps(xa_lo, xa_lo); - mulps(xa_hi, xa_hi); - addps(xsum_lo, xa_lo); - addps(xsum_hi, xa_hi); - movups(xb_lo, ptr[src - 4]); - movups(xb_hi, ptr[src - 4 + 4 * sizeof(float)]); - mulps(xb_lo, xb_lo); - mulps(xb_hi, xb_hi); - addps(xsum_lo, xb_lo); - addps(xsum_hi, xb_hi); - - dec(c); - cmp(c, 0); - jne(lrn_loop, T_NEAR); - - movups(xc_lo, ptr[src]); - movups(xc_hi, ptr[src + 4 * sizeof(float)]); - 
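// Annotation: this is the epilogue of the channel loop above. The last 8-channel
// group accumulates xc^2 below and then applies mask[2]/mask[3] to the xd/xe
// neighbours (offsets +1 and +2 channels), zeroing the lanes that would reach
// past the final channel so they add nothing to the square-sum before the last
// normalization and store.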
mulps(xc_lo, xc_lo); - mulps(xc_hi, xc_hi); - addps(xsum_lo, xc_lo); - addps(xsum_hi, xc_hi); - - mov(imm_addr64, reinterpret_cast(&mask[2])); - movups(xmask_lo, ptr[imm_addr64]); - movups(xmask_hi, ptr[imm_addr64 + 4 * sizeof(float)]); - movups(xd_lo, ptr[src + 4]); - movups(xd_hi, ptr[src + 4 + 4 * sizeof(float)]); - andps(xd_lo, xmask_lo); - andps(xd_hi, xmask_hi); - mulps(xd_lo, xd_lo); - mulps(xd_hi, xd_hi); - addps(xsum_lo, xd_lo); - addps(xsum_hi, xd_hi); // xsum <- xsum + xa^2+xb^2+xc^2+xd^2+xe^2 - - mov(imm_addr64, reinterpret_cast(&mask[3])); - movups(xmask_lo, ptr[imm_addr64]); - movups(xmask_hi, ptr[imm_addr64 + 4 * sizeof(float)]); - movups(xe_lo, ptr[src + 8]); - movups(xe_hi, ptr[src + 8 + 4 * sizeof(float)]); - andps(xe_lo, xmask_lo); - andps(xe_hi, xmask_hi); - mulps(xe_lo, xe_lo); - mulps(xe_hi, xe_hi); - addps(xsum_lo, xe_lo); - addps(xsum_hi, xe_hi); - - movups(xdst_lo, xsum_lo); - movups(xdst_hi, xsum_hi); - // xdst <- xsum*xalpha+xk - mulps(xdst_lo, ptr[store_addr]); - mulps(xdst_hi, ptr[store_addr]); - addps(xdst_lo, ptr[store_addr + 4 * sizeof(float)]); - addps(xdst_hi, ptr[store_addr + 4 * sizeof(float)]); - - movaps(xbase_lo, xdst_lo); - movaps(xbase_hi, xdst_hi); - if (pk != prop_kind::forward_inference) { - movups(ptr[scratch], xbase_lo); - movups(ptr[scratch + 4 * sizeof(float)], xbase_hi); - } - mulps(xdst_lo, xdst_lo); - mulps(xdst_hi, xdst_hi); - mulps(xdst_lo, xbase_lo); - mulps(xdst_hi, xbase_hi); // xdst = (xsum*xalpha+xk)^3; - sqrtps(xdst_lo, xdst_lo); - sqrtps(xdst_hi, xdst_hi); - sqrtps(xdst_lo, xdst_lo); - sqrtps(xdst_hi, xdst_hi); // xdst = (xsum*xalpha+xk)^0.75 - movups(xc_lo, ptr[src]); - movups(xc_hi, ptr[src + 4 * sizeof(float)]); - divps(xc_lo, xdst_lo); - divps(xc_hi, xdst_hi); // xdst = xsrc / (xsum*xalpha+xk)^0.75 - - movups(ptr[dst], xc_lo); - movups(ptr[dst + 4 * sizeof(float)], xc_hi); - - this->postamble(); - - ker = reinterpret_cast(const_cast( - this->getCode())); -} - -template<> -void jit_uni_lrn_fwd_kernel_f32::nchw_body( - int tail, int HW, prop_kind_t pk, - Xbyak::Ymm ymask, - Xbyak::Ymm ya, - Xbyak::Ymm yb, - Xbyak::Ymm yc, - Xbyak::Ymm yd, - Xbyak::Ymm ye, - Xbyak::Ymm ysum) {} - -template<> -void jit_uni_lrn_fwd_kernel_f32::nchw_body( - int tail, int HW, prop_kind_t pk, - Xbyak::Ymm ymask, - Xbyak::Ymm ya, - Xbyak::Ymm yb, - Xbyak::Ymm yc, - Xbyak::Ymm yd, - Xbyak::Ymm ye, - Xbyak::Ymm ysum) -{ - Xbyak::Ymm ydst = ymm14; - Xbyak::Ymm ybase = ymm15; - - vfmadd231ps(ysum, ye, ye); - - vmovups(ydst, ysum); - vfmadd132ps(ydst, yk, yalpha); // ydst <- ysum*yalpha+yk - - vmovaps(ybase, ydst); - if (pk != prop_kind::forward_inference) - { - if (tail != 0) - vmaskmovps(ptr[scratch], ymask, ybase); - else - vmovups(ptr[scratch], ybase); - } - vmulps(ydst, ydst, ydst); - vmulps(ydst, ydst, ybase); // ydst = (ysum*yalpha+yk)^3; - vsqrtps(ydst, ydst); - vsqrtps(ydst, ydst); // ydst = (ysum*yalpha+yk)^0.75 - vdivps(ydst, yc, ydst); // ydst = ysrc / (ysum*yalpha+yk)^0.75 - - if (tail != 0) - vmaskmovps(ptr[dst], ymask, ydst); - else - vmovups(ptr[dst], ydst); - - - vfnmadd231ps(ysum, ya, ya); - vmovups(ya, yb); - vmovups(yb, yc); - vmovups(yc, yd); - vmovups(yd, ye); -} - -template<> -void jit_uni_lrn_fwd_kernel_f32::nchw_tail_sse42( - int tail, Xbyak::Reg64 reg_dst, Xbyak::Xmm xtail_lo, Xbyak::Xmm xtail_hi) -{} - -template<> -void jit_uni_lrn_fwd_kernel_f32::nchw_tail_sse42( - int tail, Xbyak::Reg64 reg_dst, Xbyak::Xmm xtail_lo, Xbyak::Xmm xtail_hi) -{ - Xbyak::Xmm xmm_tmp = xmm10; - movaps(xmm_tmp, xtail_lo); - size_t offset = 0; - - if 
(tail > 4) { - movups(ptr[reg_dst], xtail_lo); - movaps(xmm_tmp, xtail_hi); - offset += 4 * sizeof(float); - tail -= 4; - } - movss(ptr[reg_dst + offset], xmm_tmp); - for (int i = 1; i < tail; i++) - { - psrldq(xmm_tmp, 4); - movss(ptr[reg_dst + offset + i * sizeof(float)], xmm_tmp); - } -} - -template<> -void jit_uni_lrn_fwd_kernel_f32::nchw_body_sse42( - int tail, int HW, prop_kind_t pk, - Xbyak::Xmm xmask_lo, Xbyak::Xmm xmask_hi, - Xbyak::Xmm xe_lo, Xbyak::Xmm xe_hi, - Xbyak::Xmm xsum_lo, Xbyak::Xmm xsum_hi) -{ - Xbyak::Xmm xdst_lo = xmm0; - Xbyak::Xmm xdst_hi = xmm1; - Xbyak::Xmm xbase_lo = xmm6; - Xbyak::Xmm xbase_hi = xmm7; - Xbyak::Xmm xtmp_lo = xmm8; - Xbyak::Xmm xtmp_hi = xmm9; - Xbyak::Xmm xa_lo = xmm6; - Xbyak::Xmm xa_hi = xmm7; - Xbyak::Xmm xb_lo = xmm8; - Xbyak::Xmm xb_hi = xmm9; - Xbyak::Xmm xc_lo = xmm10; - Xbyak::Xmm xc_hi = xmm11; - Xbyak::Xmm xd_lo = xmm12; - Xbyak::Xmm xd_hi = xmm13; - - // store xe - movaps(ptr[store_addr + 10 * 4 * sizeof(float)], xe_lo); - movaps(ptr[store_addr + 11 * 4 * sizeof(float)], xe_hi); - - mulps(xe_lo, xe_lo); - mulps(xe_hi, xe_hi); - addps(xsum_lo, xe_lo); - addps(xsum_hi, xe_hi); - - // xdst <- xsum*xalpha+xk - movaps(xdst_lo, xsum_lo); - movaps(xdst_hi, xsum_hi); - mulps(xdst_lo, ptr[store_addr + 0 * 4 * sizeof(float)]); - mulps(xdst_hi, ptr[store_addr + 0 * 4 * sizeof(float)]); - addps(xdst_lo, ptr[store_addr + 1 * 4 * sizeof(float)]); - addps(xdst_hi, ptr[store_addr + 1 * 4 * sizeof(float)]); - - movaps(xbase_lo, xdst_lo); - movaps(xbase_hi, xdst_hi); - if (pk != prop_kind::forward_inference) - { - if (tail != 0) { - nchw_tail_sse42(tail, scratch, xbase_lo, xbase_hi); - } - else { - movups(ptr[scratch], xbase_lo); - movups(ptr[scratch + 4 * sizeof(float)], xbase_hi); - } - } - mulps(xdst_lo, xdst_lo); - mulps(xdst_hi, xdst_hi); - mulps(xdst_lo, xbase_lo); - mulps(xdst_hi, xbase_hi); // xdst = (xsum*xalpha+xk)^3; - sqrtps(xdst_lo, xdst_lo); - sqrtps(xdst_hi, xdst_hi); - sqrtps(xdst_lo, xdst_lo); - sqrtps(xdst_hi, xdst_hi); // xdst = (xsum*xalpha+xk)^0.75 - movaps(xtmp_lo, ptr[store_addr + 6 * 4 * sizeof(float)]); - movaps(xtmp_hi, ptr[store_addr + 7 * 4 * sizeof(float)]); - divps(xtmp_lo, xdst_lo); - divps(xtmp_hi, xdst_hi); // xdst = xsrc / (xsum*xalpha+xk)^0.75 - movaps(xdst_lo, xtmp_lo); - movaps(xdst_hi, xtmp_hi); - - if (tail != 0) { - nchw_tail_sse42(tail, dst, xdst_lo, xdst_hi); - } - else { - movups(ptr[dst], xdst_lo); - movups(ptr[dst + 4 * sizeof(float)], xdst_hi); - } - - movaps(xa_lo, ptr[store_addr + 2 * 4 * sizeof(float)]); - movaps(xa_hi, ptr[store_addr + 3 * 4 * sizeof(float)]); - mulps(xa_lo, xa_lo); - mulps(xa_hi, xa_hi); - subps(xsum_lo, xa_lo); - subps(xsum_hi, xa_hi); - - // xa <- xb - movaps(xb_lo, ptr[store_addr + 4 * 4 * sizeof(float)]); - movaps(xb_hi, ptr[store_addr + 5 * 4 * sizeof(float)]); - movaps(ptr[store_addr + 2 * 4 * sizeof(float)], xb_lo); - movaps(ptr[store_addr + 3 * 4 * sizeof(float)], xb_hi); - - // xb <- xc - movaps(xc_lo, ptr[store_addr + 6 * 4 * sizeof(float)]); - movaps(xc_hi, ptr[store_addr + 7 * 4 * sizeof(float)]); - movaps(ptr[store_addr + 4 * 4 * sizeof(float)], xc_lo); - movaps(ptr[store_addr + 5 * 4 * sizeof(float)], xc_hi); - - // xc <- xd - movaps(xd_lo, ptr[store_addr + 8 * 4 * sizeof(float)]); - movaps(xd_hi, ptr[store_addr + 9 * 4 * sizeof(float)]); - movaps(ptr[store_addr + 6 * 4 * sizeof(float)], xd_lo); - movaps(ptr[store_addr + 7 * 4 * sizeof(float)], xd_hi); - - // xd <- xe - movaps(xe_lo, ptr[store_addr + 10 * 4 * sizeof(float)]); - movaps(xe_hi, ptr[store_addr + 11 * 4 * 
sizeof(float)]); - movaps(ptr[store_addr + 8 * 4 * sizeof(float)], xe_lo); - movaps(ptr[store_addr + 9 * 4 * sizeof(float)], xe_hi); -} - -template<> -void jit_uni_lrn_fwd_kernel_f32::nchw_body_sse42( - int tail, int HW, prop_kind_t pk, - Xbyak::Xmm xmask_lo, Xbyak::Xmm xmask_hi, - Xbyak::Xmm xe_lo, Xbyak::Xmm xe_hi, - Xbyak::Xmm xsum_lo, Xbyak::Xmm xsum_hi) {} - -template<> -jit_uni_lrn_fwd_kernel_f32::jit_uni_lrn_fwd_kernel_f32( - struct nchw_across J, - float A, - float K, - prop_kind_t pk, - void* code_ptr, - size_t code_size) - : jit_generator(code_ptr, code_size) - , alpha(A), k(K) -{ - static const uint32_t mask[] = { - 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000, - 0x80000000, 0x80000000, 0, 0, 0, 0, 0, 0, 0 - }; - Xbyak::Reg64 c = r10; - Xbyak::Ymm ymask = ymm2; - Xbyak::Ymm ye = ymm3; - Xbyak::Ymm ya = ymm4; - Xbyak::Ymm yb = ymm5; - Xbyak::Ymm yc = ymm6; - Xbyak::Ymm yd = ymm7; - Xbyak::Ymm ysum = ymm8; - - this->preamble(); - - if (J.tail != 0) - { - mov(imm_addr64, reinterpret_cast(&mask[7 - J.tail])); - vmovups(ymask, ptr[imm_addr64]); - } - mov(imm_addr64, float2int(this->alpha)); - movq(xalpha, imm_addr64); - vbroadcastss(yalpha, xalpha); - - mov(imm_addr64, float2int(this->k)); - movq(xk, imm_addr64); - vbroadcastss(yk, xk); - - mov(src, ptr[this->param1 + 0]); - mov(dst, ptr[this->param1 + 8]); - if (pk != prop_kind::forward_inference) - mov(scratch, ptr[this->param1 + 16]); - - vxorps(ya, ya, ya); - vxorps(yb, yb, yb); - if (J.tail != 0) - vmaskmovps(yc, ymask, ptr[src + J.HW * 0]); - else - vmovups(yc, ptr[src + J.HW * 0]); - if (J.tail != 0) - vmaskmovps(yd, ymask, ptr[src + J.HW * 4]); - else - vmovups(yd, ptr[src + J.HW * 4]); - - vxorps(ysum, ysum, ysum); - vfmadd231ps(ysum, yc, yc); // ysum <- ysum + ya^2+yb^2+yc^2+yd^2+ye^2 - vfmadd231ps(ysum, yd, yd); - - mov(c, J.C - 2); - Label lrn_loop; - L(lrn_loop); - - if (J.tail != 0) - vmaskmovps(ye, ymask, ptr[src + J.HW * 8]); - else - vmovups(ye, ptr[src + J.HW * 8]); - - nchw_body(J.tail, J.HW, pk, ymask, ya, yb, yc, yd, ye, ysum); - - add(src, J.HW * 4); - add(dst, J.HW * 4); - if (pk != prop_kind::forward_inference) - add(scratch, J.HW * 4); - dec(c); - cmp(c, 0); - jne(lrn_loop, T_NEAR); - - vxorps(ye, ye, ye); - - nchw_body(J.tail, J.HW, pk, ymask, ya, yb, yc, yd, ye, ysum); - add(src, J.HW * 4); - add(dst, J.HW * 4); - if (pk != prop_kind::forward_inference) - add(scratch, J.HW * 4); - - nchw_body(J.tail, J.HW, pk, ymask, ya, yb, yc, yd, ye, ysum); - - this->postamble(); - - ker = reinterpret_cast(const_cast( - this->getCode())); -} - -template<> -jit_uni_lrn_fwd_kernel_f32::jit_uni_lrn_fwd_kernel_f32( - struct nchw_across J, - float A, - float K, - prop_kind_t pk, - void* code_ptr, - size_t code_size) - : jit_generator(code_ptr, code_size) - , alpha(A), k(K) -{ - static const uint32_t mask[] = { - 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, - 0xffffffff, 0xffffffff, 0, 0, 0, 0, 0, 0, 0 - }; - - Xbyak::Reg64 c = r10; - - Xbyak::Xmm xmask_lo = xmm2; - Xbyak::Xmm xmask_hi = xmm3; - Xbyak::Xmm xsum_lo = xmm4; - Xbyak::Xmm xsum_hi = xmm5; - Xbyak::Xmm xa_lo = xmm6; - Xbyak::Xmm xa_hi = xmm7; - Xbyak::Xmm xb_lo = xmm8; - Xbyak::Xmm xb_hi = xmm9; - Xbyak::Xmm xc_lo = xmm10; - Xbyak::Xmm xc_hi = xmm11; - Xbyak::Xmm xd_lo = xmm12; - Xbyak::Xmm xd_hi = xmm13; - Xbyak::Xmm xe_lo = xmm14; - Xbyak::Xmm xe_hi = xmm15; - - this->preamble(); - - mov(src, ptr[this->param1 + 0]); - mov(dst, ptr[this->param1 + 8]); - if (pk != prop_kind::forward_inference) - mov(scratch, ptr[this->param1 + 16]); - - 
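// The SSE4.2 variant being set up here spills alpha, k and the five sliding-window
// values (a..e, each split into lo/hi 128-bit halves) into the 16-byte-aligned stack
// area reserved just below, since unlike the AVX2 path it runs short of registers.
// What both paths compute per channel is the across-channel LRN response sketched
// below in scalar form -- a reference sketch only; the function and variable names
// are illustrative and are not part of this library.
#include <cmath>
static void lrn_across_scalar(const float *src, float *dst, float *scratch,
        int C, int HW, float alpha, float k, bool keep_workspace) {
    for (int hw = 0; hw < HW; ++hw) {
        for (int c = 0; c < C; ++c) {
            float sum = 0.f;                          // xsum
            for (int cc = c - 2; cc <= c + 2; ++cc)   // 5-wide window, zero-padded at the edges
                if (cc >= 0 && cc < C) {
                    const float v = src[cc * HW + hw];
                    sum += v * v;
                }
            const float base = sum * alpha + k;       // xsum*xalpha + xk
            if (keep_workspace)
                scratch[c * HW + hw] = base;          // kept for the backward pass
            // base^0.75 == sqrt(sqrt(base^3)), matching the mul/mul/sqrt/sqrt sequence
            const float denom = std::sqrt(std::sqrt(base * base * base));
            dst[c * HW + hw] = src[c * HW + hw] / denom;
        }
    }
}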
sub(rsp, stack_space_needed); - mov(store_addr, rsp); - and_(store_addr, -15); - - mov(imm_addr64, float2int(this->alpha)); - movq(xalpha, imm_addr64); - shufps(xalpha, xalpha, 0); - - mov(imm_addr64, float2int(this->k)); - movq(xk, imm_addr64); - shufps(xk, xk, 0); - - // put alpha and k into store (free up regs) - movaps(ptr[store_addr + 0 * 4 * sizeof(float)], xalpha); - movaps(ptr[store_addr + 1 * 4 * sizeof(float)], xk); - - if (J.tail != 0) - { - mov(imm_addr64, reinterpret_cast(&mask[7 - J.tail])); - movups(xmask_lo, ptr[imm_addr64]); - movups(xmask_hi, ptr[imm_addr64 + 4 * sizeof(float)]); - } - // init xa, xb - xorps(xa_lo, xa_lo); - xorps(xa_hi, xa_hi); - xorps(xb_lo, xb_lo); - xorps(xb_hi, xb_hi); - - // read xc, xd - if (J.tail != 0) { - movups(xc_lo, ptr[src + J.HW * 0]); - movups(xc_hi, ptr[src + J.HW * 0 + 4 * sizeof(float)]); - andps(xc_lo, xmask_lo); - andps(xc_hi, xmask_hi); - } - else { - movups(xc_lo, ptr[src + J.HW * 0]); - movups(xc_hi, ptr[src + J.HW * 0 + 4 * sizeof(float)]); - } - if (J.tail != 0) { - movups(xd_lo, ptr[src + J.HW * 4]); - movups(xd_hi, ptr[src + J.HW * 4 + 4 * sizeof(float)]); - andps(xd_lo, xmask_lo); - andps(xd_hi, xmask_hi); - } - else { - movups(xd_lo, ptr[src + J.HW * 4]); - movups(xd_hi, ptr[src + J.HW * 4 + 4 * sizeof(float)]); - } - - // put xa, xb, xc, xd into store to free-up regs - movaps(ptr[store_addr + 2 * 4 * sizeof(float)], xa_lo); - movaps(ptr[store_addr + 3 * 4 * sizeof(float)], xa_hi); - movaps(ptr[store_addr + 4 * 4 * sizeof(float)], xb_lo); - movaps(ptr[store_addr + 5 * 4 * sizeof(float)], xb_hi); - movaps(ptr[store_addr + 6 * 4 * sizeof(float)], xc_lo); - movaps(ptr[store_addr + 7 * 4 * sizeof(float)], xc_hi); - movaps(ptr[store_addr + 8 * 4 * sizeof(float)], xd_lo); - movaps(ptr[store_addr + 9 * 4 * sizeof(float)], xd_hi); - - xorps(xsum_lo, xsum_lo); - xorps(xsum_hi, xsum_hi); - mulps(xc_lo, xc_lo); - mulps(xc_hi, xc_hi); - addps(xsum_lo, xc_lo); - addps(xsum_hi, xc_hi); - mulps(xd_lo, xd_lo); - mulps(xd_hi, xd_hi); - addps(xsum_lo, xd_lo); - addps(xsum_hi, xd_hi); // xsum <- xsum + xa^2+xb^2+xc^2+xd^2+xe^2 - - mov(c, J.C - 2); - Label lrn_loop; - L(lrn_loop); - - if (J.tail != 0) { - movups(xe_lo, ptr[src + J.HW * 8]); - movups(xe_hi, ptr[src + J.HW * 8 + 4 * sizeof(float)]); - andps(xe_lo, xmask_lo); - andps(xe_hi, xmask_hi); - } - else { - movups(xe_lo, ptr[src + J.HW * 8]); - movups(xe_hi, ptr[src + J.HW * 8 + 4 * sizeof(float)]); - } - - nchw_body_sse42(J.tail, J.HW, pk, xmask_lo, xmask_hi, - xe_lo, xe_hi, - xsum_lo, xsum_hi); - - add(src, J.HW * 4); - add(dst, J.HW * 4); - if (pk != prop_kind::forward_inference) - add(scratch, J.HW * 4); - dec(c); - cmp(c, 0); - jne(lrn_loop, T_NEAR); - - xorps(xe_lo, xe_lo); - xorps(xe_hi, xe_hi); - - nchw_body_sse42(J.tail, J.HW, pk, xmask_lo, xmask_hi, - xe_lo, xe_hi, - xsum_lo, xsum_hi); - add(src, J.HW * 4); - add(dst, J.HW * 4); - if (pk != prop_kind::forward_inference) - add(scratch, J.HW * 4); - - nchw_body_sse42(J.tail, J.HW, pk, xmask_lo, xmask_hi, - xe_lo, xe_hi, - xsum_lo, xsum_hi); - - add(rsp, stack_space_needed); - - this->postamble(); - - ker = reinterpret_cast(const_cast( - this->getCode())); -} - -////////////////////////////////////////////////////////////////////////////// -// backward kernel -template -jit_uni_lrn_bwd_kernel_f32::jit_uni_lrn_bwd_kernel_f32( - const struct nchw8c_across &J, - float A, - float B, - int use_h_parallel, - void *code_ptr, - size_t code_size) - : jit_generator(code_ptr, code_size) - , nalphabeta(-2 * A*B) - , 
use_h_parallelizm(use_h_parallel) -{ - Xbyak::Reg64 t = rsp; - Xbyak::Reg64 hw = r10; - - Xbyak::Xmm xsrc_prev = xmm1; - Xbyak::Xmm xws_prev = xmm2; - Xbyak::Xmm xdiffdst_prev = xmm3; - Xbyak::Ymm ysrc = ymm4; - Xbyak::Ymm yws = ymm5; - Xbyak::Ymm ydiffdst = ymm6; - Xbyak::Xmm xsrc_next = xmm7; - Xbyak::Xmm xws_next = xmm8; - Xbyak::Xmm xdiffdst_next = xmm9; - Xbyak::Ymm ya = ymm10; - Xbyak::Xmm xa = xmm10; - Xbyak::Ymm yb = ymm11; - Xbyak::Ymm yd = ymm12; - Xbyak::Ymm ye = ymm13; - Xbyak::Ymm ysum = ymm14; - Xbyak::Ymm ydiffsrc = ymm15; - - this->preamble(); - - mov(src, ptr[this->param1 + 0]); - mov(diffdst, ptr[this->param1 + 8]); - mov(workspace, ptr[this->param1 + 16]); - mov(diffsrc, ptr[this->param1 + 24]); - - sub(t, 64); - mov(imm_addr64, float2int(this->nalphabeta)); - movq(xnalphabeta, imm_addr64); - vbroadcastss(ynalphabeta, xnalphabeta); - - bool is_single = J.version == 3; - bool is_first = J.version == -1 || J.version == -2; - bool is_last = J.version == +1 || J.version == -2; - - if (is_first || is_single) { - vxorps(xsrc_prev, xsrc_prev, xsrc_prev); - vmovups(ptr[t + 0], xsrc_prev); - } - if (is_last || is_single) { - vxorps(xsrc_next, xsrc_next, xsrc_next); - vmovups(ptr[t + 48], xsrc_next); - } - mov(hw, this->use_h_parallelizm ? J.W : J.H*J.W); - Label lrn_loop; - L(lrn_loop); - { - if (!is_first && !is_single) { - vmovups(xws_prev, ptr[workspace - J.H*J.W * 32 + 16]); - vmovups(xsrc_prev, ptr[src - J.H*J.W * 32 + 16]); - vmovups(xdiffdst_prev, ptr[diffdst - J.H*J.W * 32 + 16]); - vmulps(xa, xws_prev, xws_prev); - vmulps(xa, xa, xws_prev); - vsqrtps(xa, xa); - vsqrtps(xa, xa); - vmulps(xa, xa, xws_prev); - vdivps(xsrc_prev, xsrc_prev, xa); - vmulps(xdiffdst_prev, xdiffdst_prev, xsrc_prev); - } - - vmovups(ysrc, ptr[src]); - vmovups(yws, ptr[workspace]); - vmovups(ydiffdst, ptr[diffdst]); - vmulps(ya, yws, yws); - vmulps(ya, ya, yws); - vsqrtps(ya, ya); - vsqrtps(ya, ya); - vdivps(ydiffsrc, ydiffdst, ya); - vdivps(ysum, ydiffsrc, yws); - vmulps(ysum, ysum, ysrc); - - if (!is_last && !is_single) { - vmovups(xws_next, ptr[workspace + J.H*J.W * 32]); - vmovups(xsrc_next, ptr[src + J.H*J.W * 32]); - vmovups(xdiffdst_next, ptr[diffdst + J.H*J.W * 32]); - vmulps(xa, xws_next, xws_next); - vmulps(xa, xa, xws_next); - vsqrtps(xa, xa); - vsqrtps(xa, xa); - vmulps(xa, xa, xws_next); - vdivps(xsrc_next, xsrc_next, xa); - vdivps(xsrc_next, xsrc_next, xws_next); - vmulps(xdiffdst_next, xdiffdst_next, xsrc_next); - } - - if (!is_first && !is_single) vmovups(ptr[t + 0], xdiffdst_prev); - vmovups(ptr[t + 16], ysum); - if (!is_last && !is_single) vmovups(ptr[t + 48], xdiffdst_next); - - vmovups(ya, ptr[t + 16 - 8]); - vmovups(yb, ptr[t + 16 - 4]); - vaddps(ysum, ysum, ya); - vmulps(ysrc, ysrc, ynalphabeta); - vaddps(ysum, ysum, yb); - - vmovups(yd, ptr[t + 16 + 4]); - vmovups(ye, ptr[t + 16 + 8]); - vaddps(ysum, ysum, yd); - vaddps(ysum, ysum, ye); - - vfmadd231ps(ydiffsrc, ysum, ysrc); - - vmovups(ptr[diffsrc], ydiffsrc); - - add(src, 32); - add(diffsrc, 32); - add(diffdst, 32); - add(workspace, 32); - - dec(hw); - cmp(hw, 0); - jne(lrn_loop, T_NEAR); - } - - add(t, 64); - this->postamble(); - - ker = reinterpret_cast(const_cast( - this->getCode())); -} - -template struct jit_uni_lrn_fwd_kernel_f32; -template struct jit_uni_lrn_fwd_kernel_f32; -template struct jit_uni_lrn_bwd_kernel_f32; - -} -} -} - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn_kernel_f32.hpp deleted 
file mode 100644 index 2b3ed43cd..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn_kernel_f32.hpp +++ /dev/null @@ -1,183 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef CPU_JIT_UNI_LRN_KERNEL_F32_HPP -#define CPU_JIT_UNI_LRN_KERNEL_F32_HPP - -#include "c_types_map.hpp" -#include "type_helpers.hpp" - -#include "jit_generator.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace Xbyak; - -enum params { VECTOR_LENGTH = 8, MAX_LOCAL_SIZE = 32 }; - -typedef struct { - const float *src; - float *dst, *scratch; -} jit_args_fwd_t; - -typedef struct { - const float *src, *diff_dst, *scratch; - float *diff_src; -} jit_args_bwd_t; - -struct nchw8c_across { - /* version: - * -1: channels 0..7, - * 1: channels C-8 .. C-1, - * 0: other channels - * 3: channels only for this kernel(without prev and next) - */ - int H, W, version; - nchw8c_across(int h, int w, int v) : H(h), W(w), version(v) {} -}; - -struct nchw8c_within { - int H, W, size; - nchw8c_within(int h, int w, int s) : H(h), W(w), size(s) {} -}; - -struct nchw_across { - int C, HW, tail; - nchw_across(int c, int hw, int t) : C(c), HW(hw), tail(t) {} -}; - -struct nhwc_across { - int C; - nhwc_across(int c) : C(c) {} -}; - -template -struct jit_uni_lrn_fwd_kernel_f32 : public jit_generator { - Xbyak::Reg64 src = rax; - Xbyak::Reg64 dst = r8; - Xbyak::Reg64 scratch = rdx; - Xbyak::Reg64 imm_addr64 = rbx; - Xbyak::Reg64 store_addr = rbp; - - Xbyak::Xmm xalpha = xmm0; - Xbyak::Ymm yalpha = ymm0; - Xbyak::Xmm xk = xmm1; - Xbyak::Ymm yk = ymm1; - - float alpha; - float k; - - int stack_space_needed = 11 * 4 * sizeof(float) + 16; - - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_lrn_fwd_kernel_f32) - - /* cpu specific part */ - using Vmm = typename utils::conditional::type; - - jit_uni_lrn_fwd_kernel_f32( - const struct nchw8c_within &J, - float A, - float K, - prop_kind_t pk, - void *code_ptr = nullptr, - size_t code_size = 4 * Xbyak::DEFAULT_MAX_CODE_SIZE); - jit_uni_lrn_fwd_kernel_f32( - const struct nchw8c_across &J, - float A, - float K, - prop_kind_t pk, - void *code_ptr = nullptr, - size_t code_size = 1 * Xbyak::DEFAULT_MAX_CODE_SIZE); - jit_uni_lrn_fwd_kernel_f32( - const struct nhwc_across &J, - float A, - float K, - prop_kind_t pk, - void *code_ptr = nullptr, - size_t code_size = 1 * Xbyak::DEFAULT_MAX_CODE_SIZE); - jit_uni_lrn_fwd_kernel_f32( - struct nchw_across J, - float A, - float K, - prop_kind_t pk, - void* code_ptr = nullptr, - size_t code_size = 2 * Xbyak::DEFAULT_MAX_CODE_SIZE); - - void within_body( - int hoff, int Hoff, int woff, int Woff, int stride, - Xbyak::Ymm ysum, Xbyak::Ymm ydst, Xbyak::Ymm ytmp, Xbyak::Ymm ysum2, - prop_kind_t pk); - void within_body_sse42( - int hoff, int Hoff, int woff, int Woff, int stride, prop_kind_t pk); - - - void nchw_body(int tail, int HW, 
prop_kind_t pk, - Xbyak::Ymm ymask, - Xbyak::Ymm ya, - Xbyak::Ymm yb, - Xbyak::Ymm yc, - Xbyak::Ymm yd, - Xbyak::Ymm ye, - Xbyak::Ymm ysum); - void nchw_body_sse42(int tail, int HW, prop_kind_t pk, - Xbyak::Xmm xmask_lo, Xbyak::Xmm xmask_hi, - Xbyak::Xmm xe_lo, Xbyak::Xmm xe_hi, - Xbyak::Xmm xsum_lo, Xbyak::Xmm xsum_hi); - void nchw_tail_sse42(int tail, Xbyak::Reg64 reg_dst, - Xbyak::Xmm xtail_lo, Xbyak::Xmm xtail_hi); - - void operator()(jit_args_fwd_t *arg) { ker(arg); } - void(*ker)(jit_args_fwd_t *); -}; - -template -struct jit_uni_lrn_bwd_kernel_f32 : public jit_generator { - Xbyak::Reg64 src = rax; - Xbyak::Reg64 diffsrc = r8; - Xbyak::Reg64 diffdst = r9; - Xbyak::Reg64 workspace = rdx; - Xbyak::Reg64 imm_addr64 = rsi; - - Xbyak::Xmm xnalphabeta = xmm0; - Xbyak::Ymm ynalphabeta = ymm0; - - float nalphabeta; - - int use_h_parallelizm; - - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_lrn_bwd_kernel_f32) - - jit_uni_lrn_bwd_kernel_f32( - const struct nchw8c_across &J, - float A, - float B, - int use_h_parallel, - void *code_ptr = nullptr, - size_t code_size = 1 * Xbyak::DEFAULT_MAX_CODE_SIZE); - - void operator()(jit_args_bwd_t *arg) { ker(arg); } - void(*ker)(jit_args_bwd_t *); -}; - -} -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.cpp deleted file mode 100644 index bf8e609d2..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.cpp +++ /dev/null @@ -1,699 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* Copyright 2018 YANDEX LLC -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include "c_types_map.hpp" -#include "nstl.hpp" -#include "utils.hpp" -#include "cpu_pooling_pd.hpp" - -#include "jit_uni_pool_kernel_f32.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace Xbyak; -using namespace alg_kind; - -#define GET_OFF(field) offsetof(jit_pool_call_s, field) - -template -status_t jit_uni_pool_kernel_f32::init_conf(jit_pool_conf_t &jpp, - const pooling_pd_t *ppd) { - const auto &pd = *ppd->desc(); - const memory_desc_wrapper src_d( - ppd->is_fwd() ? ppd->src_md() : ppd->diff_src_md()); - const memory_desc_wrapper dst_d( - ppd->is_fwd() ? ppd->dst_md() : ppd->diff_dst_md()); - - bool args_ok = true - && mayiuse(isa) - && utils::one_of(pd.alg_kind, pooling_max, - pooling_avg_include_padding, - pooling_avg_exclude_padding); - if (!args_ok) return status::unimplemented; - - const int simd_w = isa == avx512_common ? 16 : 8; - const int ndims = src_d.ndims(); - - jpp.ndims = ndims; - jpp.mb = src_d.dims()[0]; - - jpp.c = utils::rnd_up(src_d.dims()[1], simd_w); - if (jpp.c > src_d.padded_dims()[1]) - return status::unimplemented; - - jpp.id = (ndims == 5) ? 
src_d.dims()[2] : 1; - jpp.ih = src_d.dims()[ndims-2]; - jpp.iw = src_d.dims()[ndims-1]; - jpp.od = (ndims == 5) ? dst_d.dims()[2] : 1; - jpp.oh = dst_d.dims()[ndims-2]; - jpp.ow = dst_d.dims()[ndims-1]; - - jpp.stride_d = (ndims == 5 ) ? pd.strides[0] : 1; - jpp.stride_h = pd.strides[ndims-4]; - jpp.stride_w = pd.strides[ndims-3]; - jpp.kd = (ndims == 5) ? pd.kernel[0] : 1; - jpp.kh = pd.kernel[ndims-4]; - jpp.kw = pd.kernel[ndims-3]; - - jpp.f_pad = (ndims == 5 ) ? pd.padding[0][0] : 0; - jpp.t_pad = pd.padding[0][ndims-4]; - jpp.l_pad = pd.padding[0][ndims-3]; - - jpp.alg = pd.alg_kind; - - jpp.is_training = pd.prop_kind == prop_kind::forward_training; - jpp.is_backward = pd.prop_kind == prop_kind::backward_data; - jpp.ind_dt = ppd->workspace_md() - ? ppd->workspace_md()->data_type : data_type::undef; - - jpp.simple_alg = jpp.is_training - || IMPLICATION(jpp.is_backward, jpp.kd <= jpp.stride_d); - - jpp.c_block = simd_w; - - jpp.nb_c = jpp.c / jpp.c_block; - if (jpp.alg == pooling_max) { - jpp.ur_w = isa == avx512_common ? 16 : 4; - if (jpp.is_training) - jpp.ur_w = isa == avx512_common ? 9 : 3; - else if (jpp.is_backward) - jpp.ur_w = isa == avx512_common ? 6 : 3; - } else { - if (jpp.is_backward) - jpp.ur_w = isa == avx512_common ? 12 : 6; - else - jpp.ur_w = isa == avx512_common ? 24 : 12; - } - if (jpp.ow < jpp.ur_w) jpp.ur_w = jpp.ow; - if (jpp.l_pad > jpp.ur_w) return status::unimplemented; - - jpp.ur_w_tail = jpp.ow % jpp.ur_w; - - return status::success; -} - -template -inline void jit_uni_pool_kernel_f32::maybe_recalculate_divisor(int jj, - int ur_w, int pad_l, int pad_r) { - if (jpp.alg == pooling_avg_exclude_padding) { - int kw = jpp.kw; - int stride_w = jpp.stride_w; - - int non_zero_kw = kw; - non_zero_kw -= nstl::max(0, pad_l - jj*stride_w); - non_zero_kw -= nstl::max(0, pad_r - (ur_w - 1 - jj)*stride_w); - - if (non_zero_kw != prev_kw) { - mov(tmp_gpr, float2int((float)non_zero_kw)); - movq(xmm_tmp, tmp_gpr); - uni_vbroadcastss(vmm_tmp, xmm_tmp); - uni_vmulps(vmm_tmp, vmm_tmp, vmm_ker_area_h); - prev_kw = non_zero_kw; - } - } -} - -template -inline void jit_uni_pool_kernel_f32::avg_step(int ur_w, int pad_l, - int pad_r) { - - int iw = jpp.iw; - int kw = jpp.kw; - int stride_w = jpp.stride_w; - int c_block = jpp.c_block; - Label kd_label, kh_label; - - for (int jj = 0; jj < ur_w; jj++) { - if (jpp.is_backward) { - uni_vmovups(vreg(jj), ptr[reg_output + sizeof(float)*jj*c_block]); - maybe_recalculate_divisor(jj, ur_w, pad_l, pad_r); - uni_vdivps(vreg(jj), vreg(jj), vmm_tmp); - } else { - uni_vpxor(vreg(jj), vreg(jj), vreg(jj)); - } - } - - if (jpp.simple_alg && jpp.ndims == 5) { - push(reg_input); - push(reg_output); - mov(aux_reg_input_d, reg_input); - mov(ki, ptr[reg_param + GET_OFF(kd_padding)]); - L(kd_label); - mov(aux_reg_input, aux_reg_input_d); - } else { - mov(aux_reg_input, reg_input); - } - - xor_(kj, kj); - L(kh_label); - { - for (int ki = 0; ki < kw; ki++) { - int jj_start = nstl::max(0, pad_l - ki); - int jj_end = ur_w - - utils::div_up(nstl::max(0, ki + pad_r - (kw-1)), stride_w); - for (int jj = jj_start; jj < jj_end; jj++) { - int aux_input_offset = (ki+jj*stride_w-pad_l)* c_block; - if (aux_input_offset > iw * c_block) - continue; - int input_offset = sizeof(float)*aux_input_offset; - if (jpp.is_backward) { - uni_vmovups(vreg(ur_w+jj), - ptr[aux_reg_input + input_offset]); - uni_vaddps(vreg(ur_w+jj), vreg(ur_w+jj), vreg(jj)); - uni_vmovups(vmmword[aux_reg_input + input_offset], - vreg(ur_w+jj)); - } else { - uni_vaddps(vreg(jj), vreg(jj), - 
ptr[aux_reg_input + input_offset]); - } - } - } - add(aux_reg_input, sizeof(float) * iw * c_block); - inc(kj); - cmp(kj, reg_kh); - jl(kh_label, T_NEAR); - } - - if (jpp.simple_alg && jpp.ndims == 5) - { - add(aux_reg_input_d, sizeof(float) * jpp.ih * iw * c_block); - dec(ki); - cmp(ki, 0); - jg(kd_label, T_NEAR); - pop(reg_output); - pop(reg_input); - } - - if (!jpp.is_backward) { - for (int jj = 0; jj < ur_w; jj++) { - maybe_recalculate_divisor(jj, ur_w, pad_l, pad_r); - uni_vdivps(vreg(jj), vreg(jj), vmm_tmp); - uni_vmovups(vmmword[reg_output + sizeof(float)*jj*c_block], - vreg(jj)); - } - } -} - -template -inline void jit_uni_pool_kernel_f32::max_step_fwd(int ur_w, int pad_l, - int pad_r) { - int iw = jpp.iw; - int kw = jpp.kw; - int stride_w = jpp.stride_w; - int c_block = jpp.c_block; - Label kd_label, kh_label; - - mov(tmp_gpr, float2int(nstl::numeric_limits::lowest())); - movq(xmm_tmp, tmp_gpr); - uni_vbroadcastss(vmm_tmp, xmm_tmp); - - for (int jj = 0; jj < ur_w; jj++) { - uni_vmovups(vreg(jj), vmm_tmp); - if (jpp.is_training) - uni_vpxor(vreg(2*ur_w+jj), vreg(2*ur_w+jj), vreg(2*ur_w+jj)); - } - if (jpp.is_training) - { - movq(xmm_tmp, reg_k_shift); - uni_vpbroadcastd(vmm_k_offset, xmm_tmp); - } - - if (jpp.ndims == 5) { - push(reg_input); - push(reg_output); - mov(aux_reg_input_d, reg_input); - mov(ki, ptr[reg_param + GET_OFF(kd_padding)]); - L(kd_label); - mov(aux_reg_input, aux_reg_input_d); - } else { - mov(aux_reg_input, reg_input); - } - xor_(kj, kj); - L(kh_label); - { - for (int ki = 0; ki < kw; ki++) { - int jj_start = nstl::max(0, pad_l - ki); - int jj_end = ur_w - - utils::div_up(nstl::max(0, ki + pad_r - (kw-1)), stride_w); - for (int jj = jj_start; jj < jj_end; jj++) { - int aux_input_offset = (ki+jj*stride_w-pad_l)* c_block; - if (aux_input_offset > iw * c_block) - continue; - int input_offset = sizeof(float)*aux_input_offset; - uni_vmovups(vreg(ur_w+jj), ptr[aux_reg_input + input_offset]); - if (isa == sse42) { - movups(vmm_mask, vreg(jj)); - cmpps(vmm_mask, vreg(ur_w+jj), _cmp_lt_os); - blendvps(vreg(jj), vreg(ur_w+jj)); - if (jpp.is_training) - blendvps(vreg(2*ur_w+jj), vmm_k_offset); - } else if (isa == avx) { - vcmpps(vreg(3*ur_w+jj), vreg(jj), vreg(ur_w+jj), - _cmp_lt_os); - vblendvps(vreg(jj), vreg(jj), vreg(ur_w+jj), - vreg(3*ur_w+jj)); - if (jpp.is_training) - vblendvps(vreg(2*ur_w+jj), vreg(2*ur_w+jj), - vmm_k_offset, vreg(3*ur_w+jj)); - } else { - vcmpps(k_store_mask, vreg(jj), vreg(ur_w+jj), _cmp_lt_os); - vblendmps(vreg(jj) | k_store_mask, vreg(jj), vreg(ur_w+jj)); - if (jpp.is_training) - vblendmps(vreg(2*ur_w+jj) | k_store_mask, - vreg(2*ur_w+jj), vmm_k_offset); - } - } - if (jpp.is_training) { - if (isa == avx && !mayiuse(avx2)) { - avx_vpadd1(vmm_k_offset, vmm_one, xmm_tmp); - } else { - uni_vpaddd(vmm_k_offset, vmm_k_offset, vmm_one); - } - } - } - add(aux_reg_input, sizeof(float) * iw * c_block); - inc(kj); - cmp(kj, reg_kh); - jl(kh_label, T_NEAR); - } - - if (jpp.ndims == 5) - { - add(aux_reg_input_d, sizeof(float) * jpp.ih * iw * c_block); - if (jpp.is_training) { - mov(tmp_gpr, ptr[reg_param + GET_OFF(kd_padding_shift)]); - movq(xmm_tmp, tmp_gpr); - uni_vpbroadcastd(vmm_tmp, xmm_tmp); - if (isa == avx && !mayiuse(avx2)) { - Xmm t(vmm_mask.getIdx()); - avx_vpadd1(vmm_k_offset, xmm_tmp, t); - } else { - uni_vpaddd(vmm_k_offset, vmm_k_offset, vmm_tmp); - } - } - - dec(ki); - cmp(ki, 0); - jg(kd_label, T_NEAR); - pop(reg_output); - pop(reg_input); - } - - for (int jj = 0; jj < ur_w; jj++) { - uni_vmovups(vmmword[reg_output + 
sizeof(float)*jj*c_block], vreg(jj)); - if (jpp.is_training) { - const size_t step_index - = jj * c_block * types::data_type_size(jpp.ind_dt); - - auto x = xreg(2 * ur_w + jj); - if (jpp.ind_dt == data_type::u8) { - if (isa == sse42) { - for (int i = 0; i < 4; ++i) - pextrb(ptr[reg_index + step_index + i], x, 4*i); - } else if (isa == avx) { - auto y = yreg(2 * ur_w + jj); - if (jj == 0) { - movd(xmm_tmp, reg_shuf_mask); - uni_vpbroadcastd(vmm_tmp, xmm_tmp); - } - if (mayiuse(avx2)) { - vpshufb(y, y, vmm_tmp); - movd(ptr[reg_index + step_index], x); - vperm2i128(y, y, y, 0x1u); - movd(ptr[reg_index + step_index + 4], x); - } else { - Xmm t(vmm_mask.getIdx()); - vextractf128(t, y, 0); - vpshufb(t, t, xmm_tmp); - movd(ptr[reg_index + step_index], t); - vextractf128(t, y, 1); - vpshufb(t, t, xmm_tmp); // ymm_tmp[:128]==ymm_tmp[127:0] - movd(ptr[reg_index + step_index + 4], t); - } - } else { - auto v = vreg(2 * ur_w + jj); - vpmovusdb(x, v); - vmovups(ptr[reg_index + step_index], v | k_index_mask); - } - } else { - uni_vmovups(ptr[reg_index + step_index], vreg(2*ur_w+jj)); - } - } - } -} - -template -inline void jit_uni_pool_kernel_f32::max_step_bwd(int ur_w, int pad_l, - int pad_r) { - - int iw = jpp.iw; - int kw = jpp.kw; - int stride_w = jpp.stride_w; - int c_block = jpp.c_block; - Label kd_label, kh_label; - - for (int jj = 0; jj < ur_w; jj++) { - uni_vmovups(vreg(jj), ptr[reg_output + sizeof(float)*jj*c_block]); - - const size_t step_index - = jj * c_block * types::data_type_size(jpp.ind_dt); - if (jpp.ind_dt == data_type::u8) { - if (isa == sse42) { - movd(xreg(ur_w+jj), ptr[reg_index + step_index]); - pmovzxbd(vreg(ur_w+jj), xreg(ur_w+jj)); - } else if (isa == avx) { - movq(xreg(ur_w+jj), ptr[reg_index + step_index]); - if (!mayiuse(avx2)) { - avx_pmovzxbd(vreg(ur_w+jj), xreg(ur_w+jj), xmm_tmp); - } else { - vpmovzxbd(vreg(ur_w+jj), xreg(ur_w+jj)); - } - } else { - vmovups(vreg(ur_w+jj) | k_index_mask, - ptr[reg_index + step_index]); - vpmovzxbd(vreg(ur_w+jj), xreg(ur_w+jj)); - } - } else { - uni_vmovups(vreg(ur_w+jj), ptr[reg_index + step_index]); - } - } - movq(xmm_tmp, reg_k_shift); - uni_vpbroadcastd(vmm_k_offset, xmm_tmp); - - if (jpp.simple_alg && jpp.ndims == 5) { - push(reg_input); - push(reg_output); - if (isa == sse42) { - // Save rdi since it is used in maskmovdqu - assert(dst_ptr == rdi); - push(dst_ptr); - } - mov(aux_reg_input_d, reg_input); - mov(ki, ptr[reg_param + GET_OFF(kd_padding)]); - mov(reg_kd_pad_shift, ptr[reg_param + GET_OFF(kd_padding_shift)]); - L(kd_label); - mov(aux_reg_input, aux_reg_input_d); - } else { - mov(aux_reg_input, reg_input); - } - - xor_(kj, kj); - L(kh_label); - { - for (int ki = 0; ki < kw; ki++) { - int jj_start = nstl::max(0, pad_l - ki); - int jj_end = ur_w - - utils::div_up(nstl::max(0, ki + pad_r - (kw-1)), stride_w); - for (int jj = jj_start; jj < jj_end; jj++) { - int aux_input_offset = (ki+jj*stride_w-pad_l)* c_block; - if (aux_input_offset > iw * c_block) - continue; - int input_offset = sizeof(float)*aux_input_offset; - uni_vmovups(vreg(2*ur_w+jj), ptr[aux_reg_input + input_offset]); - if (isa == sse42) { - mov(dst_ptr, aux_reg_input); - add(dst_ptr, input_offset); - - movups(vreg(3*ur_w+jj), vreg(ur_w+jj)); - pcmpeqd(vreg(3*ur_w+jj), vmm_k_offset); - addps(vreg(2*ur_w+jj), vreg(jj)); - maskmovdqu(vreg(2*ur_w+jj), vreg(3*ur_w+jj)); - } else if (isa == avx) { - if (mayiuse(avx2)) { - vpcmpeqd(vreg(3*ur_w+jj), vreg(ur_w+jj), vmm_k_offset); - } else { - avx_pcmpeqd(vreg(3*ur_w+jj), vreg(ur_w+jj), vmm_k_offset, xmm_tmp); - } - 
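// The index compare just above and the masked add/store just below implement the
// scatter step of max-pooling backward: the forward pass stored, for every output
// pixel, the flattened kernel offset of its argmax, and the gradient is routed back
// to exactly that input location. A 1-D scalar sketch (padding and the kh/kd loops
// omitted; names are illustrative, not taken from the library):
static void max_pool_bwd_row(const float *diff_dst, const int *argmax,
        float *diff_src, int OW, int IW, int KW, int stride_w) {
    for (int ow = 0; ow < OW; ++ow) {
        for (int kw = 0; kw < KW; ++kw) {
            const int iw = ow * stride_w + kw;
            if (iw < IW && argmax[ow] == kw)   // pcmpeqd against the running vmm_k_offset
                diff_src[iw] += diff_dst[ow];  // add + mask-selected store
        }
    }
}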
vaddps(vreg(2*ur_w+jj), vreg(2*ur_w+jj), vreg(jj)); - vmaskmovps(vmmword[aux_reg_input + input_offset], - vreg(3*ur_w+jj), vreg(2*ur_w+jj)); - } else { - vpcmpeqd(k_store_mask, vreg(ur_w+jj), vmm_k_offset); - vblendmps(vmm_tmp | k_store_mask | T_z, vreg(jj), vreg(jj)); - vaddps(vreg(2*ur_w+jj), vreg(2*ur_w+jj), vmm_tmp); - vmovups(vmmword[aux_reg_input + - sizeof(float)*aux_input_offset], vreg(2*ur_w+jj)); - } - } - if (isa == avx && !mayiuse(avx2)) { - avx_vpadd1(vmm_k_offset, vmm_one, xmm_tmp); - } else { - uni_vpaddd(vmm_k_offset, vmm_k_offset, vmm_one); - } - } - add(aux_reg_input, sizeof(float) * iw * c_block); - inc(kj); - cmp(kj, reg_kh); - jl(kh_label, T_NEAR); - } - if (jpp.simple_alg && jpp.ndims == 5) - { - add(aux_reg_input_d, sizeof(float) * jpp.ih * iw * c_block); - - mov(tmp_gpr, reg_kd_pad_shift); - movq(xmm_tmp, tmp_gpr); - uni_vpbroadcastd(vmm_tmp, xmm_tmp); - if (isa == avx && !mayiuse(avx2)) { - Xmm t(vmm_mask.getIdx()); - avx_vpadd1(vmm_k_offset, vmm_tmp, t); - } else { - uni_vpaddd(vmm_k_offset, vmm_k_offset, vmm_tmp); - } - - dec(ki); - cmp(ki, 0); - jg(kd_label, T_NEAR); - if (isa == sse42) { - // Save rdi since it is used in maskmovdqu - assert(dst_ptr == rdi); - pop(dst_ptr); - } - pop(reg_output); - pop(reg_input); - } -} - -template -void jit_uni_pool_kernel_f32::maybe_zero_diff_src() { - assert(jpp.c_block * sizeof(float) % cpu_isa_traits::vlen == 0); - Label l_skip, l_zero; - - auto reg_oh = tmp_gpr; - mov(reg_oh, ptr[reg_param + GET_OFF(oh)]); - cmp(reg_oh, 0); - jz(l_skip, T_NEAR); - - if (jpp.ndims == 5) { - mov(zero_size, ptr[reg_param + GET_OFF(oh)]); - mov(tmp_gpr, jpp.ih * jpp.iw * jpp.c_block * sizeof(float)); - imul(zero_size, tmp_gpr); - } - - auto vzero = vmm_tmp; - uni_vpxor(vzero, vzero, vzero); - - auto reg_off = tmp_gpr; - xor_(reg_off, reg_off); - - L(l_zero); - { - const int dim = jpp.iw * jpp.c_block * sizeof(float); - for (int i = 0; i < dim; i += cpu_isa_traits::vlen) - uni_vmovups(ptr[reg_input + reg_off + i], vzero); - add(reg_off, dim); - if (jpp.ndims == 5) cmp(reg_off, zero_size); - else cmp(reg_off, jpp.ih * dim); - jl(l_zero, T_NEAR); - } - - L(l_skip); -} - -template -void jit_uni_pool_kernel_f32::generate() { - - this->preamble(); - - int ow = jpp.ow; - int iw = jpp.iw; - int kw = jpp.kw; - int kh = jpp.kh; - int ur_w = jpp.ur_w; - int c_block = jpp.c_block; - int stride_w = jpp.stride_w; - int l_pad = jpp.l_pad; - int ur_w_tail = jpp.ur_w_tail; - - int n_oi = ow / ur_w; - - prev_kw = 0; - - int vlen = cpu_isa_traits::vlen; - -#if defined(_WIN32) - // Always mimic the Unix ABI (see the note about maskmovdqu in the header - // file). 
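// On Windows the first integer argument arrives in rcx rather than rdi (the
// System V register), so the three xor instructions below exchange rdi and rcx
// without needing a scratch register. The same trick on plain integers:
static inline void xor_swap_u64(unsigned long long &a, unsigned long long &b) {
    a ^= b;   // a now holds a^b
    b ^= a;   // b = b^(a^b) = original a
    a ^= b;   // a = (a^b)^(original a) = original b
}
// Note: the xor trick zeroes both values if a and b alias the same object;
// two distinct registers never do.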
- xor_(rdi, rcx); - xor_(rcx, rdi); - xor_(rdi, rcx); -#endif - - mov(reg_input, ptr[reg_param + GET_OFF(src)]); - mov(reg_output, ptr[reg_param + GET_OFF(dst)]); - if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward)) - mov(reg_index, ptr[reg_param + GET_OFF(indices)]); - mov(reg_kh, ptr[reg_param + GET_OFF(kh_padding)]); - mov(reg_k_shift, ptr[reg_param + GET_OFF(kh_padding_shift)]); - mov(reg_ker_area_h, ptr[reg_param + GET_OFF(ker_area_h)]); - - if (jpp.is_backward) - maybe_zero_diff_src(); - - if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward)) { - mov(tmp_gpr, 1); - movq(xmm_one, tmp_gpr); - uni_vpbroadcastd(vmm_one, xmm_one); - - if (isa == avx) { - mov(reg_shuf_mask, 0x0c080400); - } else if (isa >= avx512_common) { - mov(tmp_gpr.cvt32(), 0x000f); - kmovw(k_index_mask, tmp_gpr.cvt32()); - } - } - - int r_pad = nstl::max(0, ((ow-1)*stride_w) + kw - 1 - (iw + l_pad - 1)); - int r_pad1 = (ur_w*n_oi - 1)*stride_w + kw - 1 - (iw + l_pad - 1); - if (r_pad1 > 0) n_oi--; - - if (jpp.alg == pooling_avg_exclude_padding) { - movq(xmm_ker_area_h, reg_ker_area_h); - uni_vpbroadcastd(vmm_ker_area_h, xmm_ker_area_h); - } - - if (jpp.alg == pooling_avg_include_padding) { - mov(tmp_gpr, float2int((float)(kw * kh * jpp.kd))); - movq(xmm_tmp, tmp_gpr); - uni_vpbroadcastd(vmm_tmp, xmm_tmp); - } - if (l_pad > 0) { - n_oi--; - if (n_oi < 0 && r_pad1 > 0) { - step(ur_w, l_pad, r_pad1); - } else { - step(ur_w, l_pad, 0); - } - - if (isa == sse42) { - if (n_oi < 0 && r_pad1 > 0) { - step_high_half(ur_w, l_pad, r_pad1); - } else { - step_high_half(ur_w, l_pad, 0); - } - } - - if (isa == sse42) { - add(reg_input, sizeof(float)*(ur_w*stride_w-l_pad)*c_block - vlen); - add(reg_output, sizeof(float)*ur_w*c_block - vlen); - if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward)) - add(reg_index, (2 * ur_w - 1) * c_block / 2 - * types::data_type_size(jpp.ind_dt)); - } else { - add(reg_input, sizeof(float)*(ur_w*stride_w - l_pad)*c_block); - add(reg_output, sizeof(float)*ur_w*c_block); - if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward)) - add(reg_index, ur_w * c_block - * types::data_type_size(jpp.ind_dt)); - } - } - - xor_(oi_iter, oi_iter); - if (n_oi > 0) { - Label ow_loop; - L(ow_loop); { - step(ur_w, 0, 0); - - if (isa == sse42) { - step_high_half(ur_w, 0, 0); - } - - if (isa == sse42) { - add(reg_input, sizeof(float)*ur_w*stride_w*c_block - vlen); - add(reg_output, sizeof(float)*ur_w*c_block - vlen); - if (jpp.alg == pooling_max && - (jpp.is_training || jpp.is_backward)) - add(reg_index, (2 * ur_w - 1) * c_block / 2 - * types::data_type_size(jpp.ind_dt)); - } else { - add(reg_input, sizeof(float)*ur_w*stride_w*c_block); - add(reg_output, sizeof(float)*ur_w*c_block); - if (jpp.alg == pooling_max && - (jpp.is_training || jpp.is_backward)) - add(reg_index, ur_w * c_block - * types::data_type_size(jpp.ind_dt)); - } - - inc(oi_iter); - cmp(oi_iter, n_oi); - jl(ow_loop, T_NEAR); - } - } - - if (r_pad1 > 0 && n_oi >= 0) { - step(ur_w, 0, r_pad1); - - if (isa == sse42) { - step_high_half(ur_w, 0, r_pad1); - } - - if (isa == sse42) { - add(reg_input, sizeof(float)*ur_w*stride_w*c_block - vlen); - add(reg_output, sizeof(float)*ur_w*c_block - vlen); - if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward)) - add(reg_index, (2 * ur_w - 1) * c_block / 2 - * types::data_type_size(jpp.ind_dt)); - } else { - add(reg_input, sizeof(float)*ur_w*stride_w*c_block); - add(reg_output, sizeof(float)*ur_w*c_block); - if (jpp.alg == pooling_max && 
(jpp.is_training || jpp.is_backward)) - add(reg_index, ur_w * c_block - * types::data_type_size(jpp.ind_dt)); - } - } - - if (ur_w_tail != 0) { - step(ur_w_tail, 0, r_pad); - - if (isa == sse42) { - step_high_half(ur_w_tail, 0, r_pad); - } - } - - this->postamble(); -} - -template struct jit_uni_pool_kernel_f32; -template struct jit_uni_pool_kernel_f32; // implements both and -template struct jit_uni_pool_kernel_f32; - -} -} -} - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.hpp deleted file mode 100644 index 992b52658..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.hpp +++ /dev/null @@ -1,192 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* Copyright 2018 YANDEX LLC -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef JIT_UNI_POOL_KERNEL_F32_HPP -#define JIT_UNI_POOL_KERNEL_F32_HPP - -#include - -#include "c_types_map.hpp" -#include "pooling_pd.hpp" -#include "type_helpers.hpp" - -#include "jit_generator.hpp" -#include "jit_primitive_conf.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace Xbyak; - -template -struct jit_uni_pool_kernel_f32: public jit_generator { - jit_uni_pool_kernel_f32(jit_pool_conf_t ajpp): jpp(ajpp) - { - this->generate(); - jit_ker = (decltype(jit_ker))this->getCode(); - } - - jit_pool_conf_t jpp; - - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_pool_kernel_f32) - - void operator()(jit_pool_call_s *arg) { jit_ker(arg); } - static status_t init_conf(jit_pool_conf_t &jbp, const pooling_pd_t *ppd); - -private: - using Vmm = typename utils::conditional3::type; - Xmm xreg(int idx) { return Xmm((isa == avx512_common ? 31 : 15) - idx); } - Ymm yreg(int idx) { return Ymm(xreg(idx).getIdx()); } - Vmm vreg(int idx) { return Vmm(xreg(idx).getIdx()); } - - const AddressFrame &vmmword = (isa == sse42) ? xword : - (isa == avx) ? yword : zword; - - Xmm vmm_mask = Xmm(0); - Xmm xmm_ker_area_h = Xmm(2); - Xmm xmm_one = Xmm(2); - Xmm xmm_tmp = Xmm(3); - - Vmm vmm_ker_area_h = Vmm(2); - Vmm vmm_one = Vmm(2); - Vmm vmm_tmp = Vmm(3); - - Vmm vmm_k_offset = Vmm(1); - - Opmask k_index_mask = Opmask(6); - Opmask k_store_mask = Opmask(7); - - // Here be some (tame) dragons. This kernel does not follow the regular - // OS-agnostic ABI pattern because when isa is sse42 it uses maskmovdqu - // instruction which has its destination hardcoded in rdi. Therefore: - // - all registers are hardcoded - // - on Windows rdi and rcx are swapped to mimic the Unix x86_64 ABI - // - // While this is only required by the backward pass, the quirk above - // is applied to the forward pass as well to keep things simpler. 
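// The constraint described above comes from the instruction itself: maskmovdqu
// always writes through rdi, which the intrinsic form hides behind an explicit
// pointer argument. An illustrative sketch of the same masked byte store (the
// helper name is made up for this example):
#include <emmintrin.h> // SSE2
static inline void masked_store_16(__m128i data, __m128i byte_mask, char *dst) {
    // Stores only the bytes of `data` whose corresponding mask byte has its high
    // bit set; compiles to maskmovdqu, whose destination address is implicitly rdi.
    _mm_maskmoveu_si128(data, byte_mask, dst);
}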
- - using reg64_t = const Xbyak::Reg64; - reg64_t reg_param = rdi; // Always mimic the Unix ABI - reg64_t reg_input = r8; - reg64_t aux_reg_input = r9; - reg64_t reg_index = r10; - reg64_t reg_output = r12; - reg64_t reg_kd_pad_shift = r13; - reg64_t dst_ptr = rdi; // Must be rdi due to maskmovdqu - - reg64_t kj = r14; - reg64_t oi_iter = r15; - reg64_t reg_kh = rax; - reg64_t reg_k_shift = rbx; - reg64_t tmp_gpr = rcx; // Must be rcx because rdi is used above - reg64_t reg_ker_area_h = rdx; - - reg64_t zero_size = r15; - reg64_t ki = r12; - reg64_t aux_reg_input_d = r8; - - Xbyak::Reg32 reg_shuf_mask = esi; - - int prev_kw; - void (*jit_ker)(jit_pool_call_s *); - - void maybe_recalculate_divisor(int jj, int ur_w, int pad_l, int pad_r); - void avg_step(int ur_w, int pad_l, int pad_r); - void max_step_fwd(int ur_w, int pad_l, int pad_r); - void max_step_bwd(int ur_w, int pad_l, int pad_r); - - void maybe_zero_diff_src(); - - void step(int ur_w, int pad_l, int pad_r) { - if (jpp.alg == alg_kind::pooling_max) { - if(jpp.is_backward) - max_step_bwd(ur_w, pad_l, pad_r); - else - max_step_fwd(ur_w, pad_l, pad_r); - } - else - avg_step(ur_w, pad_l, pad_r); - } - - void step_high_half(int ur_w, int pad_l, int pad_r) { - add(reg_input, sizeof(float) * 4); - add(reg_output, sizeof(float) * 4); - if (jpp.alg == alg_kind::pooling_max && - (jpp.is_training || jpp.is_backward)) - add(reg_index, types::data_type_size(jpp.ind_dt) * 4); - - step(ur_w, pad_l, pad_r); - } - - void generate(); - - void avx_vpadd1(const Ymm& y0, const Xmm& x1, const Xmm& xtmp) { - assert(y0.getIdx() != x1.getIdx()); - vextractf128(xtmp, y0, 0); - vpaddd(xtmp, xtmp, x1); - vinsertf128(y0, y0, xtmp, 0); - vextractf128(xtmp, y0, 1); - vpaddd(xtmp, xtmp, x1); - vinsertf128(y0, y0, xtmp, 1); - } - - void avx_vpadd1(const Xmm& x0, const Xmm& x1, const Xmm&) { - assert(false /*function should not be used*/); - paddd(x0, x1); - } - - void avx_pmovzxbd(const Ymm& y0, const Xmm& x1, const Xmm& xtmp) { - Xmm x0(y0.getIdx()); - pshufd(xmm_tmp, x1, 1); - pmovzxbd(x0, x1); - pmovzxbd(xmm_tmp, xmm_tmp); - vinsertf128(y0, y0, xmm_tmp, 1); - } - - void avx_pmovzxbd(const Xmm& x0, const Xmm& x1, const Xmm&) { - assert(false /*function should not be used*/); - pmovzxbd(x0, x1); - } - - void avx_pcmpeqd(const Ymm& y0, const Ymm& y1, const Ymm& y2, const Xmm& xtmp) { - assert(y0.getIdx() != y1.getIdx()); - assert(y0.getIdx() != y2.getIdx()); - Xmm x0(y0.getIdx()); - Xmm x2(y2.getIdx()); - vextractf128(x0, y1, 1); - vextractf128(xtmp, y2, 1); - pcmpeqd(xtmp, x0); - vextractf128(x0, y1, 0); - pcmpeqd(x0, x2); - vinsertf128(y0, y0, xtmp, 1); - } - - void avx_pcmpeqd(const Xmm& x0, const Xmm& x1, const Xmm&, const Xmm&) { - assert(false /*function should not be used*/); - pcmpeqd(x0, x1); - } -}; - -} -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pooling.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pooling.cpp deleted file mode 100644 index afbcf996d..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pooling.cpp +++ /dev/null @@ -1,264 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. 
-* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include "mkldnn_types.h" - -#include "c_types_map.hpp" -#include "type_helpers.hpp" -#include "nstl.hpp" - -#include "jit_uni_pooling.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -template -void jit_uni_pooling_fwd_t::execute_forward(const data_t *src, - data_t *dst, char *indices) const { - const memory_desc_wrapper src_d(pd()->src_md()); - const memory_desc_wrapper dst_d(pd()->dst_md()); - const memory_desc_wrapper indices_d(pd()->workspace_md()); - const size_t ind_dt_size = indices - ? types::data_type_size(indices_d.data_type()) : 0; - - const auto &jpp = pd()->jpp_; - - auto ker = [&](int n, int b_c, int oh) { - auto arg = jit_pool_call_s(); - - const int ij = oh * jpp.stride_h; - const int i_t_overflow = nstl::max(0, jpp.t_pad-ij); - const int i_b_overflow = nstl::max(jpp.ih, ij+jpp.kh-jpp.t_pad)-jpp.ih; - const int ih = nstl::max(ij - jpp.t_pad, 0); - - arg.src = &src[src_d.blk_off(n, b_c, ih)]; - arg.dst = &dst[dst_d.blk_off(n, b_c, oh)]; - if (indices) { - const size_t ind_off = indices_d.blk_off(n, b_c, oh); - arg.indices = &indices[ind_off * ind_dt_size]; - } - arg.oh = oh == 0; - arg.kh_padding = jpp.kh - i_t_overflow - i_b_overflow; - arg.kh_padding_shift = i_t_overflow*jpp.kw; - arg.kw_padding = 0; - arg.ker_area_h = (float)(jpp.kh - - nstl::max(0, oh*jpp.stride_h - jpp.t_pad + jpp.kh - jpp.ih) - - nstl::max(0, jpp.t_pad - oh*jpp.stride_h)); - (*kernel_)(&arg); - }; - - parallel_nd(jpp.mb, jpp.nb_c, jpp.oh, - [&](int n, int b_c, int oh) { - ker(n, b_c, oh); - }); -} - -template -void jit_uni_pooling_fwd_t::execute_forward_3d(const data_t *src, - data_t *dst, char *indices) const { - const memory_desc_wrapper src_d(pd()->src_md()); - const memory_desc_wrapper dst_d(pd()->dst_md()); - const memory_desc_wrapper indices_d(pd()->workspace_md()); - const size_t ind_dt_size = indices - ? 
types::data_type_size(indices_d.data_type()) : 0; - - const auto &jpp = pd()->jpp_; - - auto ker = [&](int n, int b_c, int od, int oh, int id, int d_t_overflow, - int d_b_overflow) { - auto arg = jit_pool_call_s(); - - const int ij = oh * jpp.stride_h; - const int i_t_overflow = nstl::max(0, jpp.t_pad-ij); - const int i_b_overflow = nstl::max(jpp.ih, ij+jpp.kh-jpp.t_pad)-jpp.ih; - const int ih = nstl::max(ij - jpp.t_pad, 0); - - arg.src = &src[src_d.blk_off(n, b_c, id, ih)]; - arg.dst = &dst[dst_d.blk_off(n, b_c, od, oh)]; - if (indices) { - const size_t ind_off = indices_d.blk_off(n, b_c, od, oh); - arg.indices = &indices[ind_off * ind_dt_size]; - } - arg.oh = (oh + od == 0); - arg.kd_padding = jpp.kd - d_t_overflow - d_b_overflow; - arg.kh_padding = jpp.kh - i_t_overflow - i_b_overflow; - arg.kh_padding_shift = i_t_overflow*jpp.kw + d_t_overflow*jpp.kw*jpp.kh; - arg.kd_padding_shift = (i_t_overflow + i_b_overflow)*jpp.kw; - arg.kw_padding = 0; - arg.ker_area_h = (float)(jpp.kh - - nstl::max(0, oh*jpp.stride_h - jpp.t_pad + jpp.kh - jpp.ih) - - nstl::max(0, jpp.t_pad - oh*jpp.stride_h)) * (jpp.kd - - nstl::max(0, od*jpp.stride_d - jpp.f_pad + jpp.kd - jpp.id) - - nstl::max(0, jpp.f_pad - od*jpp.stride_d)); - - - (*kernel_)(&arg); - }; - - parallel_nd(jpp.mb, jpp.nb_c, jpp.od, - [&](int n, int b_c, int od) { - const int ik = od * jpp.stride_d; - const int d_t_overflow = nstl::max(0, jpp.f_pad-ik); - const int d_b_overflow = nstl::max(jpp.id, ik+jpp.kd-jpp.f_pad) - -jpp.id; - const int id = nstl::max(ik - jpp.f_pad, 0); - for (int oh = 0; oh < jpp.oh; ++oh) { - ker(n, b_c, od, oh, id, d_t_overflow, d_b_overflow); - } - }); -} - -template -void jit_uni_pooling_bwd_t::execute_backward(const data_t *diff_dst, - const char *indices, data_t *diff_src) const { - const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); - const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); - const memory_desc_wrapper indices_d(pd()->workspace_md()); - const size_t ind_dt_size = indices - ? types::data_type_size(indices_d.data_type()) : 0; - - const auto &jpp = pd()->jpp_; - - auto ker = [&](int n, int b_c, int oh) { - auto arg = jit_pool_call_s(); - - const int ij = oh * jpp.stride_h; - const int i_t_overflow = nstl::max(0, jpp.t_pad-ij); - const int i_b_overflow = nstl::max(jpp.ih, ij+jpp.kh-jpp.t_pad)-jpp.ih; - const int ih = nstl::max(ij - jpp.t_pad, 0); - - arg.src = &diff_src[diff_src_d.blk_off(n, b_c, ih)]; - arg.dst = &diff_dst[diff_dst_d.blk_off(n, b_c, oh)]; - if (indices) { - const size_t ind_off = indices_d.blk_off(n, b_c, oh); - arg.indices = &indices[ind_off * ind_dt_size]; - } - arg.oh = (oh == 0); - arg.kh_padding = jpp.kh - i_t_overflow - i_b_overflow; - arg.kh_padding_shift = i_t_overflow*jpp.kw; - arg.kw_padding = 0; - arg.ker_area_h = (float)(jpp.kh - - nstl::max(0, oh*jpp.stride_h - jpp.t_pad + jpp.kh - jpp.ih) - - nstl::max(0, jpp.t_pad - oh*jpp.stride_h)); - - (*kernel_)(&arg); - }; - - parallel_nd(jpp.mb, jpp.nb_c, [&](int n, int b_c) { - for (int oh = 0; oh < jpp.oh; ++oh) { - ker(n, b_c, oh); - } - }); -} - -template -void jit_uni_pooling_bwd_t::execute_backward_3d(const data_t *diff_dst, - const char *indices, data_t *diff_src) const { - const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); - const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); - const memory_desc_wrapper indices_d(pd()->workspace_md()); - const size_t ind_dt_size = indices - ? 
types::data_type_size(indices_d.data_type()) : 0; - - const auto &jpp = pd()->jpp_; - - auto ker = [&](int n, int b_c, int od, int oh, int id, int d_t_overflow, - int d_b_overflow, int zero_size, int kd) { - auto arg = jit_pool_call_s(); - - const int ij = oh * jpp.stride_h; - const int i_t_overflow = nstl::max(0, jpp.t_pad-ij); - const int i_b_overflow = nstl::max(jpp.ih, ij+jpp.kh-jpp.t_pad)-jpp.ih; - const int ih = nstl::max(ij - jpp.t_pad, 0); - - arg.src = &diff_src[diff_src_d.blk_off(n, b_c, id + kd, ih)]; - arg.dst = &diff_dst[diff_dst_d.blk_off(n, b_c, od, oh)]; - if (indices) { - const size_t ind_off = indices_d.blk_off(n, b_c, od, oh); - arg.indices = &indices[ind_off * ind_dt_size]; - } - arg.oh = zero_size; - arg.kd_padding = jpp.kd - d_t_overflow - d_b_overflow; - arg.kh_padding = jpp.kh - i_t_overflow - i_b_overflow; - arg.kh_padding_shift = i_t_overflow*jpp.kw + d_t_overflow*jpp.kw*jpp.kh - + kd * jpp.kw * jpp.kh; - arg.kd_padding_shift = (i_t_overflow + i_b_overflow)*jpp.kw; - arg.kw_padding = 0; - arg.ker_area_h = (float)(jpp.kh - - nstl::max(0, oh*jpp.stride_h - jpp.t_pad + jpp.kh - jpp.ih) - - nstl::max(0, jpp.t_pad - oh*jpp.stride_h)) * (jpp.kd - - nstl::max(0, od*jpp.stride_d - jpp.f_pad + jpp.kd - jpp.id) - - nstl::max(0, jpp.f_pad - od*jpp.stride_d)); - - (*kernel_)(&arg); - }; - - if (jpp.simple_alg) { - - parallel_nd(jpp.mb, jpp.nb_c, jpp.od, - [&](int n, int b_c, int od) { - const int ik = od * jpp.stride_d; - const int d_t_overflow = nstl::max(0, jpp.f_pad - ik); - const int d_b_overflow = nstl::max(jpp.id, ik + jpp.kd - - jpp.f_pad) - jpp.id; - const int id = nstl::max(ik - jpp.f_pad, 0); - int zero_s = jpp.stride_d - d_t_overflow - (nstl::max( - jpp.id, ik + jpp.stride_d - jpp.f_pad) - jpp.id); - for (int oh = 0; oh < jpp.oh; ++oh) { - ker(n, b_c, od, oh, id, d_t_overflow, d_b_overflow, - (oh == 0) ? zero_s : 0, 0); - } - }); - } else { - ptrdiff_t nelems = (ptrdiff_t)jpp.mb * (ptrdiff_t)jpp.c - * (ptrdiff_t)jpp.id * (ptrdiff_t)jpp.ih * (ptrdiff_t)jpp.iw; - - parallel_nd(nelems, [&](ptrdiff_t i) { diff_src[i] = 0.f; }); - - for (int kd = 0; kd < jpp.kd; ++kd) { - parallel_nd(jpp.mb, jpp.nb_c, [&](int n, int b_c) { - for (int od = 0; od < jpp.od; ++od) { - const int ik = od * jpp.stride_d; - const int d_t_overflow = nstl::max(0, jpp.f_pad-ik); - const int d_b_overflow = nstl::max(jpp.id, ik + jpp.kd - - jpp.f_pad) - jpp.id; - if (kd >= jpp.kd - d_t_overflow - d_b_overflow) - continue; - const int id = nstl::max(ik - jpp.f_pad, 0); - for (int oh = 0; oh < jpp.oh; ++oh) { - ker(n, b_c, od, oh, id, d_t_overflow, d_b_overflow, - 0, kd); - } - } - }); - } - } -} - - -template struct jit_uni_pooling_fwd_t; -template struct jit_uni_pooling_bwd_t; -template struct jit_uni_pooling_fwd_t; -template struct jit_uni_pooling_bwd_t; -template struct jit_uni_pooling_fwd_t; -template struct jit_uni_pooling_bwd_t; - -} -} -} - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pooling.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pooling.hpp deleted file mode 100644 index 57bebacde..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pooling.hpp +++ /dev/null @@ -1,182 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. 
-* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef CPU_JIT_UNI_POOLING_HPP -#define CPU_JIT_UNI_POOLING_HPP - -#include - -#include "c_types_map.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -#include "cpu_pooling_pd.hpp" -#include "cpu_primitive.hpp" - -#include "jit_uni_pool_kernel_f32.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -template -struct jit_uni_pooling_fwd_t: public cpu_primitive_t { - struct pd_t: public cpu_pooling_fwd_pd_t { - using cpu_pooling_fwd_pd_t::cpu_pooling_fwd_pd_t; - - DECLARE_COMMON_PD_T( - JIT_IMPL_NAME_HELPER("jit:", isa, ""), - jit_uni_pooling_fwd_t); - - status_t init() { - using namespace utils; - - bool ok = true - && set_default_params() == status::success - && is_fwd() - && !has_zero_dim_memory() - && everyone_is(data_type::f32, - src_md()->data_type, - dst_md()->data_type) - && attr()->has_default_values() - && memory_desc_matches_tag(*src_md(), desired_fmt_tag()) - && memory_desc_matches_tag(*dst_md(), desired_fmt_tag()); - if (!ok) return status::unimplemented; - - bool is_training = desc_.prop_kind == prop_kind::forward_training; - if (desc()->alg_kind == alg_kind::pooling_max && is_training) - init_default_ws(); - - return jit_uni_pool_kernel_f32::init_conf(jpp_, this); - } - - format_tag_t desired_fmt_tag() { - using namespace format_tag; - return ndims() == 4 - ? isa == avx512_common ? nChw16c : nChw8c - : isa == avx512_common ? 
nCdhw16c : nCdhw8c; - } - - jit_pool_conf_t jpp_; - }; - - jit_uni_pooling_fwd_t(const pd_t *apd): cpu_primitive_t(apd) - { kernel_ = new jit_uni_pool_kernel_f32(pd()->jpp_); } - - ~jit_uni_pooling_fwd_t() { delete kernel_; } - - typedef typename prec_traits::type data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); - auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); - auto ws = CTX_OUT_MEM(char *, MKLDNN_ARG_WORKSPACE); - - if (pd()->ndims() == 5) - execute_forward_3d(src, dst, ws); - else - execute_forward(src, dst, ws); - - return status::success; - } - -private: - void execute_forward(const data_t *src, data_t *dst, char *indices) const; - void execute_forward_3d(const data_t *src, data_t *dst, - char *indices) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - jit_uni_pool_kernel_f32 *kernel_; -}; - -template -struct jit_uni_pooling_bwd_t: public cpu_primitive_t { - struct pd_t: public cpu_pooling_bwd_pd_t { - using cpu_pooling_bwd_pd_t::cpu_pooling_bwd_pd_t; - - DECLARE_COMMON_PD_T( - JIT_IMPL_NAME_HELPER("jit:", isa, ""), - jit_uni_pooling_bwd_t); - - status_t init() { - using namespace utils; - - bool ok = true - && set_default_params() == status::success - && !is_fwd() - && !has_zero_dim_memory() - && everyone_is(data_type::f32, - diff_src_md()->data_type, - diff_dst_md()->data_type) - && attr()->has_default_values() - && memory_desc_matches_tag(*diff_dst_md(), desired_fmt_tag()) - && memory_desc_matches_tag(*diff_src_md(), desired_fmt_tag()); - if (!ok) return status::unimplemented; - - if (desc()->alg_kind == alg_kind::pooling_max) { - init_default_ws(); - if (!compare_ws(hint_fwd_pd_)) - return status::unimplemented; - } - - return jit_uni_pool_kernel_f32::init_conf(jpp_, this); - } - - format_tag_t desired_fmt_tag() { - using namespace format_tag; - return ndims() - ? isa == avx512_common ? nChw16c : nChw8c - : isa == avx512_common ? 
nCdhw16c : nCdhw8c; - } - - jit_pool_conf_t jpp_; - }; - - jit_uni_pooling_bwd_t(const pd_t *apd): cpu_primitive_t(apd) - { kernel_ = new jit_uni_pool_kernel_f32(pd()->jpp_); } - - ~jit_uni_pooling_bwd_t() { delete kernel_; } - - typedef typename prec_traits::type data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); - auto ws = CTX_IN_MEM(const char *, MKLDNN_ARG_WORKSPACE); - auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); - - if (pd()->ndims() == 5) - execute_backward_3d(diff_dst, ws, diff_src); - else - execute_backward(diff_dst, ws, diff_src); - - return status::success; - } - -private: - void execute_backward(const data_t *diff_dst, const char *indices, - data_t *diff_src) const; - void execute_backward_3d(const data_t *diff_dst, const char *indices, - data_t *diff_src) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - jit_uni_pool_kernel_f32 *kernel_; -}; - -} -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder.cpp deleted file mode 100644 index 98796503b..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder.cpp +++ /dev/null @@ -1,1006 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include - -#include "c_types_map.hpp" -#include "memory_desc_wrapper.hpp" -#include "mkldnn_debug.h" -#include "nstl.hpp" -#include "type_helpers.hpp" - -#include "cpu_primitive.hpp" -#include "cpu_reorder_pd.hpp" -#include "jit_uni_reorder.hpp" - -#include "jit_generator.hpp" - -// #define TR_DEBUG -#if defined(TR_DEBUG) -#define DEBUg(...) do { __VA_ARGS__ } while (0) -#else -#define DEBUg(...) -#endif -#define DEBUG(...) DEBUg(__VA_ARGS__) - -#ifdef _WIN32 -/* seems like s_addr is a reserved macro on Windows */ -#undef s_addr -#endif - -using namespace Xbyak; -using namespace mkldnn::impl::types; - -namespace mkldnn { -namespace impl { -namespace cpu { - -namespace tr { - -/** Minimal reasonable/desirable kernel size. - * The constant might be used to determine how a problem should be split - * between kernel and threading driver. 
*/ -const size_t ker_prb_size_min = 64; - -/* kernel */ -struct jit_uni_reorder_kernel_f32: public kernel_t, public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_reorder_kernel_f32) - - enum { - len_unroll_max = 256, - ndims_jit_loop_max = 3, - }; - - struct simple_impl_desc_t { - int ndims_full_unroll; - int len_last_dim_unroll; - int len_unroll; - }; - - static bool simple_impl_desc_init(const prb_t &prb, - simple_impl_desc_t *desc) { - const int ndims = prb.ndims; - - int ndims_full_unroll = 0; - int len_last_dim_unroll = 1; - int len_unroll = 1; - - for (int d = 0; d < ndims; ++d) { - auto &node = prb.nodes[d]; - if (len_unroll * node.n <= len_unroll_max) { - ndims_full_unroll++; - len_unroll *= node.n; - } else { - len_last_dim_unroll = len_unroll_max / len_unroll; - while (node.n % len_last_dim_unroll) - --len_last_dim_unroll; - len_unroll *= len_last_dim_unroll; - break; - } - } - - if (prb.ndims - ndims_full_unroll > ndims_jit_loop_max) - return false; - - if (desc) { - desc->ndims_full_unroll = ndims_full_unroll; - desc->len_last_dim_unroll = len_last_dim_unroll; - desc->len_unroll = len_unroll; - } - - return true; - } - - static bool applicable(const prb_t &p) { - using namespace data_type; - - bool ok = true - && p.ndims > 0 - && utils::one_of(p.itype, f32, s32, s8, u8) - && utils::one_of(p.otype, f32, s32, s8, u8) - && utils::everyone_is(0, p.ioff, p.ooff) /* do we need this? */ - && utils::one_of(p.beta, 0.f, 1.f) /* anything else? */ - && simple_impl_desc_init(p, nullptr) - && mayiuse(sse42) - && IMPLICATION(!utils::everyone_is(f32, p.itype, p.otype), - mayiuse(avx)); - if (!ok) return false; - - const ptrdiff_t max_stride = (1LL<<31) - 1; - for (int d = 0; d < p.ndims; ++d) { - const ptrdiff_t cms = max_stride / p.nodes[d].n; - bool strides_ok = true - && p.nodes[d].is < cms / (int)data_type_size(p.itype) - && p.nodes[d].os < cms / (int)data_type_size(p.otype); - if (!strides_ok) return false; - } - - return true; - } - - int n(int d) { assert(d < prb_.ndims); return (int)prb_.nodes[d].n; } - int is(int d) { assert(d < prb_.ndims); return (int)prb_.nodes[d].is; } - int os(int d) { assert(d < prb_.ndims); return (int)prb_.nodes[d].os; } - int ss(int d) { assert(d < prb_.ndims); return (int)prb_.nodes[d].ss; } - - Address i_addr(int i_off) - { return ptr[reg_ptr_in + reg_off_in + i_off * itype_sz]; } - - Address o_addr(int o_off) - { return ptr[reg_ptr_out + reg_off_out + o_off * otype_sz]; } - - Address s_addr(int s_off) - { return ptr[reg_ptr_scale + reg_off_scale + s_off * stype_sz]; } - - void step(int off, int prev_i_off, int prev_o_off, int prev_s_off, - int &i_off, int &o_off, int &s_off, int step_size = 1) { - i_off = prev_i_off; - o_off = prev_o_off; - s_off = prev_s_off; - - if (off == 0) return; - - int start_dim = 0, dims_prod = 1; - for (; start_dim < prb_.ndims && dims_prod != step_size; ++start_dim) - dims_prod *= n(start_dim); - assert(start_dim < prb_.ndims); - off /= step_size; - - for (int d = start_dim; d < prb_.ndims; ++d) { - i_off += is(d); - o_off += os(d); - s_off += ss(d); - - if (off % n(d)) break; - - i_off += - n(d) * is(d); - o_off += - n(d) * os(d); - s_off += - n(d) * ss(d); - off /= n(d); - - if (off == 0) break; /* FIXME: is it really required? 
*/ - } - } - - void step(int off, int prev_i_off, int prev_o_off, int &i_off, int &o_off, - int step_size = 1) { - int dummy = 0; - step(off, prev_i_off, prev_o_off, dummy, i_off, o_off, dummy, - step_size); - } - - void tr8x8_avx2(int i_off, int o_off) { - for (int i = 0; i < 8; i++) - vmovups(Ymm(i), i_addr(i_off + i * 8)); - - for (int i = 0; i < 8 / 2; i++) { - vunpcklps(Ymm(8 + i), Ymm(2 * i), Ymm(2 * i + 1)); - vunpckhps(Ymm(i), Ymm(2 * i), Ymm(2 * i + 1)); - } - - const unsigned int lfloat = 0x44; - const unsigned int ufloat = 0xee; - for (int i = 0; i < 8 / 2; i++) { - int j = i % 2 == 0 ? 8 + i : i - 1; - vshufps(Ymm(8 / 2 + 2 * i), Ymm(j), Ymm(j + 1), lfloat); - vshufps(Ymm(8 / 2 + 2 * i + 1), Ymm(j), Ymm(j + 1), ufloat); - } - - const unsigned int lquad = 0x20; - for (int i = 0; i < 8 / 2; i++) - vperm2f128(Ymm(i), Ymm(8 / 2 + i), Ymm(8 + i), lquad); - - const unsigned int uquad = 0x31; - for (int i = 8 / 2; i < 8; i++) - vperm2f128(Ymm(i), Ymm(i), Ymm(8 / 2 + i), uquad); - - for (int i = 0; i < 8; i++) - vmovups(o_addr(o_off + i * 8), Ymm(i)); - } - - bool process_unroll_tr8x8(int len) { - bool can_do = true - && mayiuse(avx2) - && prb_.ndims >= 2 - && utils::everyone_is(4, itype_sz, otype_sz) - && utils::everyone_is(8, n(0), n(1)) - && utils::everyone_is(1, os(0), is(1)) - && utils::everyone_is(8, os(1), is(0)) - && prb_.scale_type == scale_type_t::NONE - && prb_.beta == 0.f; - if (!can_do) return false; - - const int step_size = n(0) * n(1); - int i_off = 0, o_off = 0; - for (int off = 0; off < len; off += step_size) { - step(off, i_off, o_off, i_off, o_off, step_size); - tr8x8_avx2(i_off, o_off); - } - - return true; - } - - template - bool process_direct_copy(int len) { - using namespace data_type; - - using Vmm = typename cpu_isa_traits::Vmm; - const int simd_w = cpu_isa_traits::vlen / itype_sz; - - bool can_do = true - && mayiuse(isa) - && utils::everyone_is(1, os(0), is(0)) - && (false - || prb_.itype == prb_.otype - || (prb_.itype == s32 && prb_.otype == f32) - || (prb_.itype == f32 && prb_.otype == s32) - ) - && len % simd_w == 0 - && n(0) % len == 0 - && prb_.scale_type == scale_type_t::NONE - && prb_.beta == 0.f; - if (!can_do) return false; - - for (int off = 0; off < len;) { - const int unroll = nstl::min(16, (len - off) / simd_w); - - for (int ur = 0; ur < unroll; ++ur) - uni_vmovups(Vmm(ur), i_addr(off + ur * simd_w)); - - if (prb_.itype != prb_.otype) { - for (int ur = 0; ur < unroll; ++ur) { - if (prb_.itype == s32 && prb_.otype == f32) - uni_vcvtdq2ps(Vmm(ur), Vmm(ur)); - else if (prb_.itype == f32 && prb_.otype == s32) - uni_vcvtps2dq(Vmm(ur), Vmm(ur)); - else assert(!"unreachable"); - } - } - - for (int ur = 0; ur < unroll; ++ur) - uni_vmovups(o_addr(off + ur * simd_w), Vmm(ur)); - - off += unroll * simd_w; - } - - return true; - } - - void process_unroll_generic_step(int reg_unroll, const int *i_off, - const int *o_off, const int *s_off) { - using namespace data_type; - - auto cvt2ps = [=](const Xmm &dst, const Operand &src, data_type_t idt) { - Xmm dst_pure = Xmm(dst.getIdx()); - switch (idt) { - case f32: - if (src.isMEM() || src.getIdx() != dst.getIdx()) - vmovups(dst, src); - break; - case s32: vcvtdq2ps(dst, src); break; - case s8: vpmovsxbd(dst, src); vcvtdq2ps(dst_pure, dst); break; - case u8: vpmovzxbd(dst, src); vcvtdq2ps(dst_pure, dst); break; - default: assert(!"unreachable"); - } - }; - - auto cvt2int = [=](const Xmm &xmm, data_type_t odt, data_type_t idt) { - switch (odt) { - case s32: - if (idt == f32) vcvtps2dq(xmm, xmm); - else if (idt == 
s8) vpmovsxbd(xmm, xmm); - else if (idt == u8) vpmovzxbd(xmm, xmm); - break; - case s8: - if (idt == f32) vcvtps2dq(xmm, xmm); - if (idt == f32 || idt == s32) { - if (mayiuse(avx512_core)) { - vpmovsdb(xmm, xmm); - } else { - vpackssdw(xmm, xmm, xmm_zero); - vpacksswb(xmm, xmm, xmm_zero); - } - } - if (idt == u8) vpminub(xmm, xmm, xmm_4x127b); - break; - case u8: - if (idt == f32) vcvtps2dq(xmm, xmm); - if (idt == f32 || idt == s32) { - if (mayiuse(avx512_core)) { - vpmaxsd(xmm, xmm, xmm_zero); - vpmovusdb(xmm, xmm); - } else { - vpackssdw(xmm, xmm, xmm_zero); - vpackuswb(xmm, xmm, xmm_zero); - } - } - if (idt == s8) vpmaxsb(xmm, xmm, xmm_zero); - break; - default: assert(!"unreachable"); - } - }; - - auto load = [=](const Xmm &xmm, const Address &addr, int size) { - switch (size) { - case 16: movups(xmm, addr); break; - case 4: movss(xmm, addr); break; - case 1: pinsrb(xmm, addr, 0x0); break; - default: assert(!"unreachable"); - } - }; - - auto store = [=](const Address &addr, const Xmm &xmm, int size) { - switch (size) { - case 16: movups(addr, xmm); break; - case 4: movss(addr, xmm); break; - case 1: pextrb(addr, xmm, 0x0); break; - default: assert(!"unreachable"); - } - }; - - /* check whether loading 4 values at once is possible */ - bool can_load_xmm = mayiuse(avx) && reg_unroll % 4 == 0; - for (int ur = 1; ur < reg_unroll; ++ur) - if (i_off[ur] != i_off[ur - 1] + 1) - can_load_xmm = false; - const int load_step = can_load_xmm ? 4 : 1; - - /* check whether storing 4 values at once is possible */ - bool can_store_xmm = reg_unroll % 4 == 0; - for (int ur = 1; ur < reg_unroll; ++ur) - if (o_off[ur] != o_off[ur - 1] + 1) - can_store_xmm = false; - const int ur_step = can_store_xmm ? 4 : 1; - - const bool interim_f32 = false - || utils::one_of(f32, prb_.itype, prb_.otype) - || prb_.scale_type != scale_type_t::NONE - || prb_.beta != 0.f; - - if (!can_load_xmm && can_store_xmm) { - assert(ur_step == 4); - /* load with stride */ - for (int ur = 0; ur < reg_unroll; ur += ur_step) { - for (int r = 0; r < ur_step; ++r) { - if (itype_sz == 4) - pinsrd(Xmm(ur), i_addr(i_off[ur + r]), r); - else - pinsrb(Xmm(ur), i_addr(i_off[ur + r]), r); - } - } - } else { - for (int ur = 0; ur < reg_unroll; ur += load_step) - load(Xmm(ur), i_addr(i_off[ur]), load_step * itype_sz); - } - - /* xmm[:] <-- (f32)xmm[:] */ - if (interim_f32) { - const int cvt_step = nstl::max(load_step, ur_step); - for (int ur = 0; ur < reg_unroll; ur += cvt_step) - cvt2ps(Xmm(ur), Xmm(ur), prb_.itype); - } - - if (can_load_xmm && !can_store_xmm) { - const bool fast_return = true // transposition on the fly - && prb_.scale_type != scale_type_t::MANY - && prb_.beta == 0.f; - if (fast_return) { - for (int ur = 0; ur < reg_unroll; ur += load_step) { - if (prb_.scale_type == scale_type_t::COMMON) - mulps(Xmm(ur), xmm_scale); - if (prb_.otype != f32) - cvt2int(Xmm(ur), prb_.otype, - interim_f32 ? 
f32 : prb_.itype); - for (int r = 0; r < load_step; ++r) { - if (otype_sz == 4) - pextrd(o_addr(o_off[ur + r]), Xmm(ur), r); - else - pextrb(o_addr(o_off[ur + r]), Xmm(ur), r); - } - } - return; - } - - /* scatter elements of xmm into 4 xmms */ - if (itype_sz == 4 || interim_f32) { - for (int ur = 0; ur < reg_unroll; ur += load_step) - for (int r = 1; r < load_step; ++r) - vshufps(Xmm(ur + r), Xmm(ur), Xmm(ur), r); - } else { - for (int ur = 0; ur < reg_unroll; ur += load_step) - for (int r = 1; r < load_step; ++r) - vpalignr(Xmm(ur + r), Xmm(ur), Xmm(ur), r); - } - } - - /* scale and beta processing */ - if (can_store_xmm) { - /* xmm <-- scale * xmm[:] */ - if (prb_.scale_type == scale_type_t::COMMON) { - for (int ur = 0; ur < reg_unroll; ur += ur_step) - mulps(Xmm(ur), xmm_scale); - } else if (prb_.scale_type == scale_type_t::MANY) { - enum class scale_load_type_t { bcast, load, gather }; - - for (int ur = 0; ur < reg_unroll; ur += ur_step) { - scale_load_type_t scale_load_type = - scale_load_type_t::bcast; // the best case - - for (int r = ur + 1; r < ur + ur_step; ++r) - if (s_off[r] != s_off[r - 1] + 0) - scale_load_type = scale_load_type_t::load; - - if (scale_load_type == scale_load_type_t::bcast) { - movss(xmm_scale, s_addr(s_off[ur])); - shufps(xmm_scale, xmm_scale, 0x0); - mulps(Xmm(ur), xmm_scale); - continue; - } - - // bcast doesn't work, the next try -- load - for (int r = ur + 1; r < ur + ur_step; ++r) - if (s_off[r] != s_off[r - 1] + 1) - scale_load_type = scale_load_type_t::gather; - - if (scale_load_type == scale_load_type_t::load) { - movups(xmm_scale, s_addr(s_off[ur])); - mulps(Xmm(ur), xmm_scale); - continue; - } - - // load doesn't work as well - // so gather the scale factors one by one - for (int r = ur; r < ur + ur_step; ++r) - pinsrd(xmm_scale, s_addr(s_off[r]), r - ur); - mulps(Xmm(ur), xmm_scale); - } - } - - /* dst <-- beta * dst + xmm[:] */ - assert(prb_.beta == 0.f || prb_.beta == 1.f); - if (prb_.beta == 1.f) { - for (int ur = 0; ur < reg_unroll; ur += ur_step) { - if (prb_.otype == f32) { - /* non VEX instructions do not support unaligned - * memory for instructions other than movups. */ - if (mayiuse(avx)) { - vaddps(Xmm(ur), o_addr(o_off[ur])); - } else { - /* register xmm(1) is unused */ - movups(Xmm(1), o_addr(o_off[ur])); - addps(Xmm(ur), Xmm(1)); - } - } else { - cvt2ps(Xmm(1), o_addr(o_off[ur]), prb_.otype); - vaddps(Xmm(ur), Xmm(1)); - } - } - } - } else { - /* xmm[0] <-- scale * xmm[0] */ - if (prb_.scale_type == scale_type_t::COMMON) { - for (int ur = 0; ur < reg_unroll; ur += ur_step) - mulss(Xmm(ur), xmm_scale); - } else if (prb_.scale_type == scale_type_t::MANY) { - for (int ur = 0; ur < reg_unroll; ur += ur_step) { - mulss(Xmm(ur), s_addr(s_off[ur])); - } - } - - /* dst <-- beta * dst + xmm[0] */ - assert(prb_.beta == 0.f || prb_.beta == 1.f); - if (prb_.beta == 1.f) { - for (int ur = 0; ur < reg_unroll; ur += ur_step) { - if (prb_.otype == f32) { - addss(Xmm(ur), o_addr(o_off[ur])); - } else { - if (prb_.otype == s32) { - vmovss(xmm_tmp, o_addr(o_off[ur])); - } else if (utils::one_of(prb_.otype, s8, u8)) { - pinsrb(xmm_tmp, o_addr(o_off[ur]), 0x0); - } else { - assert(!"unsupported o_type"); - } - cvt2ps(xmm_tmp, xmm_tmp, prb_.otype); - addps(Xmm(ur), xmm_tmp); - } - } - } - } - - for (int ur = 0; ur < reg_unroll; ur += ur_step) { - if (prb_.otype != f32) - cvt2int(Xmm(ur), prb_.otype, interim_f32 ? 
f32 : prb_.itype); - store(o_addr(o_off[ur]), Xmm(ur), ur_step * otype_sz); - } - } - - void process_unroll_generic(int len) { - const int blk = 8; - - int i_off[2 * blk] = {0}; - int o_off[2 * blk] = {0}; - int s_off[2 * blk] = {0}; - - int curr = 0; // will switch between 0 and 1 - - for (int off = 0; off < len; off += blk) { - const int reg_unroll = nstl::min(off + blk, len) - off; - - /* compute offsets */ - for (int ur = off != 0 ? 0 : 1; ur < reg_unroll; ++ur) { - const int ur_c = curr * blk + ur; - const int ur_p = (ur_c - 1 + 2 * blk) % (2 * blk); // prev ur - step(off + ur, - i_off[ur_p], o_off[ur_p], s_off[ur_p], - i_off[ur_c], o_off[ur_c], s_off[ur_c]); - } - - process_unroll_generic_step(reg_unroll, i_off + curr * blk, - o_off + curr * blk, s_off + curr * blk); - - curr = 1 - curr; - } - } - - void loop_begin(Label &l, Reg64 reg_cnt, int len) { - mov(reg_cnt, len); - L(l); - } - - void loop_end(Label &l, Reg64 reg_cnt, int len, - int i_step, int o_step, int s_step) { - add(reg_off_in, i_step * itype_sz); - add(reg_off_out, o_step * otype_sz); - if (prb_.scale_type == scale_type_t::MANY) - add(reg_off_scale, s_step * stype_sz); - dec(reg_cnt); - jnz(l); - - sub(reg_off_in, len * i_step * itype_sz); - sub(reg_off_out, len * o_step * otype_sz); - if (prb_.scale_type == scale_type_t::MANY) - sub(reg_off_scale, len * s_step * stype_sz); - } - - bool simple_impl() { - simple_impl_desc_t d; - if (!simple_impl_desc_init(prb_, &d)) return false; - - const int nfu = d.ndims_full_unroll; - const int ldu = d.len_last_dim_unroll; - const int n_jit_loops = prb_.ndims - d.ndims_full_unroll; - assert(n_jit_loops <= ndims_jit_loop_max); - - xor_(reg_off_in, reg_off_in); - xor_(reg_off_out, reg_off_out); - if (prb_.scale_type == scale_type_t::MANY) - xor_(reg_off_scale, reg_off_scale); - - Label l_loop[3]; - Reg64 reg_cnt[3] = {r15, r14, r13}; - - if (n_jit_loops > 2) - loop_begin(l_loop[2], reg_cnt[2], n(nfu + 2)); - - if (n_jit_loops > 1) - loop_begin(l_loop[1], reg_cnt[1], n(nfu + 1)); - - if (n_jit_loops > 0) - loop_begin(l_loop[0], reg_cnt[0], n(nfu + 0) / ldu); - - const bool optimized = false - || process_direct_copy(d.len_unroll) - || process_direct_copy(d.len_unroll) - || process_unroll_tr8x8(d.len_unroll); - if (!optimized) - process_unroll_generic(d.len_unroll); - - if (n_jit_loops > 0) - loop_end(l_loop[0], reg_cnt[0], - n(nfu + 0) / ldu, is(nfu + 0) * ldu, os(nfu + 0) * ldu, - ss(nfu + 0) * ldu); - - if (n_jit_loops > 1) - loop_end(l_loop[1], reg_cnt[1], - n(nfu + 1), is(nfu + 1), os(nfu + 1), ss(nfu + 1)); - - if (n_jit_loops > 2) - loop_end(l_loop[2], reg_cnt[2], - n(nfu + 2), is(nfu + 2), os(nfu + 2), ss(nfu + 2)); - - return true; - } - - void impl() { - if (simple_impl()) return; - assert(!"no implementation available"); - } - - jit_uni_reorder_kernel_f32(const desc_t &desc) - : kernel_t(desc), jit_generator() { - itype_sz = data_type_size(prb_.itype); - otype_sz = data_type_size(prb_.otype); - stype_sz = sizeof(float); - - preamble(); -# define PARAM(x) ptr[abi_param1 + offsetof(call_param_t, x)] - if (prb_.scale_type == scale_type_t::COMMON) { - auto reg_ptr_scale_tmp = reg_ptr_in; - mov(reg_ptr_scale_tmp, PARAM(scale)); - movups(xmm_scale, ptr[reg_ptr_scale_tmp]); - } else if (prb_.scale_type == scale_type_t::MANY) { - mov(reg_ptr_scale, PARAM(scale)); - } - mov(reg_ptr_in, PARAM(in)); - mov(reg_ptr_out, PARAM(out)); -# undef PARAM - - if (mayiuse(avx)) { - vxorps(xmm_zero, xmm_zero, xmm_zero); - - if (prb_.itype == data_type::u8 && prb_.otype == data_type::s8) { - 
mov(reg_tmp.cvt32(), 0x7f7f7f7f); - movd(xmm_4x127b, reg_tmp.cvt32()); - } - } - - impl(); - postamble(); - ker_ = (void (*)(const call_param_t *))getCode(); - } - -private: - int itype_sz; - int otype_sz; - int stype_sz; - - Reg64 reg_ptr_in = rsi; - Reg64 reg_ptr_out = rdx; - Reg64 reg_ptr_scale = abi_not_param1; - - Reg64 reg_off_in = r8; - Reg64 reg_off_out = r9; - Reg64 reg_off_scale = r10; - - Reg64 reg_tmp = rax; - - Xmm xmm_scale = xmm15; - Xmm xmm_zero = xmm14; - Xmm xmm_4x127b = xmm13; // TODO: unite with xmm_zero - Xmm xmm_tmp = xmm12; -}; - -status_t kernel_t::desc_init(kernel_t::desc_t &desc, const prb_t &prb, - int ndims_ker_max) { - desc.prb = prb; - desc.prb.ioff = desc.prb.ooff = 0; - - if (ndims_ker_max > prb.ndims) - return status::invalid_arguments; - - auto ndims_ker_max_f = [&]() { - size_t cur_size = 1; - for (int d = 0; d < prb.ndims; cur_size *= prb.nodes[d++].n) - if (cur_size >= ker_prb_size_min) return d; - return prb.ndims; - }; - - if (ndims_ker_max <= 0) - ndims_ker_max = ndims_ker_max_f(); - - /* traverse through kernel implementations */ - /* TODO: find a better way to do that... */ - desc.id = 0; - for (int ndims_ker = ndims_ker_max; ndims_ker > 0; --ndims_ker) { - desc.prb.ndims = ndims_ker; - if (jit_uni_reorder_kernel_f32::applicable(desc.prb)) - return status::success; - } - - return status::unimplemented; -} - -kernel_t *kernel_t::create(const kernel_t::desc_t &desc) { - switch (desc.id) { - case 0: return new jit_uni_reorder_kernel_f32(desc); - default: assert(!"unknown kernel id"); return nullptr; - } - - return nullptr; -} - -} - -static void prb_block_for_cache(tr::prb_t &prb) { - if (prb.nodes[0].is % 64 == 0 && prb.nodes[0].n > 16) { - /** an attempt to use caches more efficient and - * address the 4K-aliasing issue */ - /* TODO: improve the logic around here */ - int j = 1; - for (; j < prb.ndims && prb.nodes[j].is != 1; ++j); - if (j == prb.ndims) return; - - /* it makes sense to re-prioritize sequential read over - * sequential write if the former would not trash the - * cache, i.e. is == 1 and os % 2^smth != 0. Smth is - * set to 2 at the moment */ - const int move_to = prb.nodes[j].os % 4 != 0 ? 0 : 1; - if (j == move_to) return; - - if (prb.nodes[j].n > 16 && prb.nodes[j].n % 16 == 0) - prb_node_split(prb, j, 16); - - prb_node_move(prb, j, move_to); - DEBUG({ printf("cache: "); prb_dump(prb); }); - } -} - -/** finds the maximum number of dimension the kernel should process and - * optionally splits one of the dimension to achieve better balance between - * parallel driver and the kernel. */ -static void prb_thread_kernel_balance(tr::prb_t &prb, int &ndims_ker_max) { - size_t sz_total = 1; - for (int d = 0; d < prb.ndims; ++d) - sz_total *= prb.nodes[d].n; - - /* sz_drv_min is the minimal size for the parallel - * driver required for good parallelization */ - const size_t sz_drv_min = nstl::min( - 16 * mkldnn_get_max_threads(), - utils::div_up(sz_total, 1024)); - - /* kdims -- # of dimensions processed by a kernel - * sz_ker_cur -- product of the dimension processed by a kernel - * sz_drv_cur -- product of the dimension processed by a driver */ - - int kdims = prb.ndims; - size_t sz_drv_cur = 1; - for (; kdims > 1 && sz_drv_cur < sz_drv_min; --kdims) - sz_drv_cur *= prb.nodes[kdims - 1].n; - - size_t sz_ker_cur = 1; - for (int d = 0; d < kdims; ++d) - sz_ker_cur *= prb.nodes[d].n; - - /* Initially kdims is chosen so that sz_drv_cur >= sz_drv_min. 
- * - * It might happen that for chosen kdims the sz_ker_cur is too small - * (less than tr::ker_prb_size_min). In that case try to split the - * innermost driver dimension into two, to increase sz_ker_cur. */ - bool want_borrow_ker_from_drv = true - && kdims < prb.ndims - && sz_ker_cur < tr::ker_prb_size_min - && sz_drv_cur > sz_drv_min; - if (want_borrow_ker_from_drv) { - /* sz_want_borrow is the minimal sz, so that: - * o) sz_ker_cur * sz_want_borrow >= tr::ker_prb_size_min - * o) current innermost driver dimension is divisible by - * sz_want_borrow (so that we can evenly split that - * dimension into two) - * - * In the worst case the minimal sz_want_borrow is equal - * to the innermost driver dimension itself. In that case - * we will sacrifice it in favor of kernel (is it fine?). */ - size_t sz_want_borrow - = utils::div_up(tr::ker_prb_size_min, sz_ker_cur); - for (; prb.nodes[kdims].n % sz_want_borrow; ++sz_want_borrow); - if (sz_want_borrow != prb.nodes[kdims].n) - prb_node_split(prb, kdims, sz_want_borrow); - kdims += 1; - } - - /* On the other hand it might happen that for chosen kdims - * the sz_drv_cur is too small (less than sz_drv_min). In that case - * try to split the outermost kernel dimension into two, to increase - * sz_drv_cur. */ - bool want_borrow_drv_from_ker = true - && sz_ker_cur > tr::ker_prb_size_min - && sz_drv_cur < sz_drv_min; - if (want_borrow_drv_from_ker) { - size_t sz_want_borrow = utils::div_up(sz_drv_min, sz_drv_cur); - for (; prb.nodes[kdims - 1].n % sz_want_borrow; ++sz_want_borrow); - if (sz_want_borrow != prb.nodes[kdims - 1].n) - prb_node_split(prb, kdims - 1, - prb.nodes[kdims - 1].n / sz_want_borrow); - } - - ndims_ker_max = kdims; - - if (want_borrow_ker_from_drv || want_borrow_drv_from_ker) { - DEBUG({ printf("split: "); prb_dump(prb); - printf("ndims_ker_max = %d\n", ndims_ker_max); }); - } -} - -struct jit_uni_reorder_t : public cpu_primitive_t { - struct pd_t : public cpu_reorder_pd_t { - using cpu_reorder_pd_t::cpu_reorder_pd_t; - - DECLARE_COMMON_PD_T("jit:uni", jit_uni_reorder_t); - - static status_t create(reorder_pd_t **reorder_pd, - engine_t *engine, const primitive_attr_t *attr, - engine_t *src_engine, const memory_desc_t *src_md, - engine_t *dst_engine, const memory_desc_t *dst_md) { - auto prb = tr::prb_t(); - - status_t prb_init_status = prb_init(prb, *src_md, *dst_md, attr); - if (prb_init_status != status::success) return prb_init_status; - - DEBUG({ printf("init : "); prb_dump(prb); }); - prb_normalize(prb); - DEBUG({ printf("norm : "); prb_dump(prb); }); - prb_simplify(prb); - DEBUG({ printf("smpl : "); prb_dump(prb); }); - - prb_block_for_cache(prb); - - int ndims_ker_max; - prb_thread_kernel_balance(prb, ndims_ker_max); - - tr::kernel_t::desc_t ker_desc; - status_t ker_init_status - = tr::kernel_t::desc_init(ker_desc, prb, ndims_ker_max); - if (ker_init_status != status::success) return ker_init_status; - - const int ndims_driver = prb.ndims - ker_desc.prb.ndims; - if (ndims_driver > jit_uni_reorder_t::ndims_driver_max) - return status::unimplemented; - - DEBUG({ printf("ker : "); prb_dump(ker_desc.prb); }); - - auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine, - dst_md); - if (_pd == nullptr) return status::out_of_memory; - if (_pd->init() != status::success) { - delete _pd; - return status::unimplemented; - } - _pd->prb_ = prb; - _pd->ker_desc_ = ker_desc; - return safe_ptr_assign(*reorder_pd, _pd); - } - - tr::prb_t prb_; - tr::kernel_t::desc_t ker_desc_; - }; - - jit_uni_reorder_t(const pd_t *apd): 
cpu_primitive_t(apd) { - kernel_ = tr::kernel_t::create(pd()->ker_desc_); - assert(kernel_); - } - ~jit_uni_reorder_t() { delete kernel_; } - - void omp_driver_0d(int off, const char *in, char *out, - const float *scale) const { - tr::call_param_t c{in, out, scale}; - (*kernel_)(&c); - } - - void omp_driver_1d(int ithr, int nthr, int off, const char *in, char *out, - const float *scale) const { - const tr::node_t *ns = pd()->prb_.nodes + off; - for_nd(ithr, nthr, (ptrdiff_t)ns[0].n, [&](ptrdiff_t d0) { - auto c = tr::call_param_t(); - c.in = in + d0 * ns[0].is * data_type_size(pd()->prb_.itype); - c.out = out + d0 * ns[0].os * data_type_size(pd()->prb_.otype); - c.scale = scale + d0 * ns[0].ss; - (*kernel_)(&c); - }); - } - - void omp_driver_2d(int ithr, int nthr, int off, const char *in, char *out, - const float *scale) const { - const tr::node_t *ns = pd()->prb_.nodes + off; - for_nd(ithr, nthr, (ptrdiff_t)ns[1].n, (ptrdiff_t)ns[0].n, - [&](ptrdiff_t d1, ptrdiff_t d0) { - auto c = tr::call_param_t(); - c.in = in + (d0 * ns[0].is + d1 * ns[1].is) - * data_type_size(pd()->prb_.itype); - c.out = out + (d0 * ns[0].os + d1 * ns[1].os) - * data_type_size(pd()->prb_.otype); - c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss; - (*kernel_)(&c); - }); - } - - void omp_driver_3d(int ithr, int nthr, int off, const char *in, char *out, - const float *scale) const { - const tr::node_t *ns = pd()->prb_.nodes + off; - for_nd(ithr, nthr, (ptrdiff_t)ns[2].n, (ptrdiff_t)ns[1].n, - (ptrdiff_t)ns[0].n, - [&](ptrdiff_t d2, ptrdiff_t d1, ptrdiff_t d0) { - auto c = tr::call_param_t(); - c.in = in + (d0 * ns[0].is + d1 * ns[1].is + d2 * ns[2].is) - * data_type_size(pd()->prb_.itype); - c.out = out + (d0 * ns[0].os + d1 * ns[1].os + d2 * ns[2].os) - * data_type_size(pd()->prb_.otype); - c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss + d2 * ns[2].ss; - (*kernel_)(&c); - }); - } - - void omp_driver_4d(int ithr, int nthr, int off, const char *in, char *out, - const float *scale) const { - const tr::node_t *ns = pd()->prb_.nodes + off; - for_nd(ithr, nthr, (ptrdiff_t)ns[3].n, (ptrdiff_t)ns[2].n, - (ptrdiff_t)ns[1].n, (ptrdiff_t)ns[0].n, - [&](ptrdiff_t d3, ptrdiff_t d2, ptrdiff_t d1, ptrdiff_t d0) { - auto c = tr::call_param_t(); - c.in = in + (d0 * ns[0].is + d1 * ns[1].is + d2 * ns[2].is - + d3 * ns[3].is) * data_type_size(pd()->prb_.itype); - c.out = out + (d0 * ns[0].os + d1 * ns[1].os + d2 * ns[2].os - + d3 * ns[3].os) * data_type_size(pd()->prb_.otype); - c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss + d2 * ns[2].ss - + d3 * ns[3].ss; - (*kernel_)(&c); - }); - } - - void omp_driver(const char *in, char *out, const float *scale) const { - in += pd()->prb_.ioff * data_type_size(pd()->prb_.itype); - out += pd()->prb_.ooff * data_type_size(pd()->prb_.otype); - - DEBUG({ printf("prb : "); tr::prb_dump(pd()->prb_); }); - DEBUG({ printf("ker : "); tr::prb_dump(pd()->ker_desc_.prb); }); - - int ndims = pd()->prb_.ndims; - int ndims_ker = pd()->ker_desc_.prb.ndims; - assert(ndims - ndims_ker <= ndims_driver_max); - - if (ndims - ndims_ker == 0) { - omp_driver_0d(ndims_ker, in, out, scale); - } else { - parallel(0, [&](const int ithr, const int nthr) { - switch (ndims - ndims_ker) { - case 1: omp_driver_1d(ithr, nthr, ndims_ker, in, out, scale); break; - case 2: omp_driver_2d(ithr, nthr, ndims_ker, in, out, scale); break; - case 3: omp_driver_3d(ithr, nthr, ndims_ker, in, out, scale); break; - case 4: omp_driver_4d(ithr, nthr, ndims_ker, in, out, scale); break; - default: assert(!"unimplemented"); - } - }); - } - } - - 
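// ---------------------------------------------------------------------------
// Editorial sketch (not part of the deleted file, and not mkl-dnn API): the
// omp_driver_{1..4}d helpers above walk the outer "driver" dimensions of the
// reorder problem and hand the innermost kernel dimensions to the JIT kernel,
// offsetting the in/out pointers by each node's stride times the element size.
// Below is a minimal, serial illustration of that idea using only the standard
// library; the names driver_node and run_driver are assumptions introduced for
// illustration and do not exist in mkl-dnn. The real code additionally carries
// a per-node scale stride (ss) and parallelizes the walk with for_nd().

#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical per-dimension descriptor: size plus input/output element strides.
struct driver_node {
    size_t n;
    std::ptrdiff_t is, os;
};

// Walk the outer ("driver") dimensions and invoke the kernel once per index
// combination, offsetting the byte pointers by stride * element size.
inline void run_driver(const std::vector<driver_node> &outer,
                       size_t itype_sz, size_t otype_sz,
                       const char *in, char *out,
                       const std::function<void(const char *, char *)> &kernel) {
    std::vector<size_t> idx(outer.size(), 0);
    for (;;) {
        std::ptrdiff_t ioff = 0, ooff = 0;
        for (size_t d = 0; d < outer.size(); ++d) {
            ioff += static_cast<std::ptrdiff_t>(idx[d]) * outer[d].is;
            ooff += static_cast<std::ptrdiff_t>(idx[d]) * outer[d].os;
        }
        kernel(in + ioff * static_cast<std::ptrdiff_t>(itype_sz),
               out + ooff * static_cast<std::ptrdiff_t>(otype_sz));
        // Odometer-style increment over the outer dimensions; stop once all wrap.
        size_t d = 0;
        while (d < outer.size() && ++idx[d] == outer[d].n)
            idx[d++] = 0;
        if (d == outer.size())
            break;
    }
}
// End of editorial sketch.
// ---------------------------------------------------------------------------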
virtual status_t execute(const exec_ctx_t &ctx) const override { - auto in = CTX_IN_MEM(const char *, MKLDNN_ARG_FROM); - auto out = CTX_OUT_MEM(char *, MKLDNN_ARG_TO); - - omp_driver(in, out, pd()->attr()->output_scales_.scales_); - - return status::success; - } - - enum { ndims_driver_max = 4 }; - -private: - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - tr::kernel_t *kernel_; -}; - -status_t jit_uni_reorder_create(reorder_pd_t **reorder_pd, - engine_t *engine, const primitive_attr_t *attr, - engine_t *src_engine, const memory_desc_t *src_md, - engine_t *dst_engine, const memory_desc_t *dst_md) { - return jit_uni_reorder_t::pd_t::create(reorder_pd, engine, attr, - src_engine, src_md, dst_engine, dst_md); -} - -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder.hpp deleted file mode 100644 index 0746ea61d..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder.hpp +++ /dev/null @@ -1,127 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef _JIT_UNI_REORDER_HPP -#define _JIT_UNI_REORDER_HPP - -#include - -#include "c_types_map.hpp" -#include "type_helpers.hpp" - -#include "cpu_primitive.hpp" -#include "cpu_reorder_pd.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -namespace tr { - -constexpr int max_ndims = MKLDNN_MAX_NDIMS; - -struct node_t { - size_t n; - ptrdiff_t is; // input stride - ptrdiff_t os; // output stride - ptrdiff_t ss; // scale stride -}; - -enum class scale_type_t { NONE, COMMON, MANY }; - -struct prb_t { - data_type_t itype; - data_type_t otype; - int ndims; - node_t nodes[max_ndims]; - ptrdiff_t ioff; - ptrdiff_t ooff; - scale_type_t scale_type; - float beta; -}; - -status_t prb_init(prb_t &prb, const memory_desc_t &imd, - const memory_desc_t &omd, const primitive_attr_t *attr); - -/** sorts the problem nodes so that output strides come in ascending order */ -void prb_normalize(prb_t &p); - -/** folds nodes together if possible */ -void prb_simplify(prb_t &p); - -/** splits the node dim into two of sizes n1 and n / n1 - * @warning n must be multiple of n1 */ -void prb_node_split(prb_t &p, int dim, size_t n1); - -/** swaps d0 and d1 nodes */ -void prb_node_swap(prb_t &p, int d0, int d1); - -/** moves node d0 to the d1 position. 
- * nodes (d0, d1] are shifted to the left if d0 < d1 or - * to the right if d0 > d1 */ -void prb_node_move(prb_t &p, int d0, int d1); - -/** dumps the problem to stdout */ -void prb_dump(const prb_t &p); - -struct call_param_t { - const void *in; - void *out; - const float *scale; -}; - -struct kernel_t { - struct desc_t { - int id; - prb_t prb; - }; - - kernel_t(const desc_t &desc): desc_(desc), ker_(nullptr) {} - void operator()(const call_param_t *c) const { assert(ker_); ker_(c); } - virtual ~kernel_t() {} - - /** inits kernel descriptor: - * desc -- kernel descriptor (output) - * prb -- transposition problem (input) - * ndims_ker_max -- limit the maximum number of dimensions kernel - * will process (optional, 0 -- no limitation) */ - static status_t desc_init(desc_t &desc, const prb_t &prb, - int ndims_ker_max = 0); - - /** creates kernel for the problem described in desc */ - static kernel_t *create(const desc_t &desc); - -protected: - const desc_t desc_; - const prb_t &prb_ = desc_.prb; - void (*ker_)(const call_param_t *); -}; - -/* TODO: add trans_t class */ - -} - -/* for cpu reorder list */ -status_t jit_uni_reorder_create(reorder_pd_t **reorder_pd, - engine_t *engine, const primitive_attr_t *attr, - engine_t *src_engine, const memory_desc_t *src_md, - engine_t *dst_engine, const memory_desc_t *dst_md); - -} -} -} - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder_utils.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder_utils.cpp deleted file mode 100644 index 69b7a3360..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder_utils.cpp +++ /dev/null @@ -1,313 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#include - -#include "c_types_map.hpp" -#include "memory_desc_wrapper.hpp" -#include "mkldnn_debug.h" -#include "nstl.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -#include "jit_uni_reorder.hpp" - -using namespace mkldnn::impl::types; -using namespace mkldnn::impl::status; - -namespace mkldnn { -namespace impl { -namespace cpu { - -namespace tr { - -/** ad-hoc structure to describe blocked memory layout */ -struct layout_desc_t { - data_type_t dt; - int ndims; - dims_t id; - dims_t dims; - strides_t strides; -}; - -status_t cvt_mem_desc_to_layout_desc(const memory_desc_t &md_, - layout_desc_t &ld) { - const auto md = memory_desc_wrapper(md_); - - bool ok = true - && md.is_blocking_desc() - && md.extra().flags == 0; - if (!ok) return invalid_arguments; - - const auto &bd = md.blocking_desc(); - - ld.ndims = 0; - ld.dt = md.data_type(); - - auto P = [&ld](int id, int dim, ptrdiff_t stride) { - assert((size_t)ld.ndims < sizeof(ld.dims) / sizeof(ld.dims[0])); - ld.id[ld.ndims] = id; - ld.dims[ld.ndims] = dim; - ld.strides[ld.ndims] = stride; - ++ld.ndims; - }; - - dims_t blocks; - md.compute_blocks(blocks); - - for (int d = 0; d < md.ndims(); ++d) { - const int ld_ndims_start = ld.ndims; - if (blocks[d] != 1) { - stride_t stride = 1; - for (int iblk = bd.inner_nblks - 1; iblk >= 0; --iblk) { - if (bd.inner_idxs[iblk] == d) - P(d, bd.inner_blks[iblk], stride); - stride *= bd.inner_blks[iblk]; - } - } - P(d, md.padded_dims()[d] / blocks[d], bd.strides[d]); - - // TODO: NOW: revisit, do we need a reverse? - // TODO: NOW: consider using strides instead of block sizes in md - // reverse the order of dims - for (int ld_d = 0; ld_d < (ld.ndims - ld_ndims_start) / 2; ++ld_d) { - const int idx0 = ld_ndims_start + ld_d; - const int idx1 = ld.ndims - 1 - ld_d; - nstl::swap(ld.dims[idx0], ld.dims[idx1]); - nstl::swap(ld.strides[idx0], ld.strides[idx1]); - } - } - - return success; -} - -status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd, - const primitive_attr_t *attr) { - auto im_d = memory_desc_wrapper(imd); - auto om_d = memory_desc_wrapper(omd); - - bool ok = true - && im_d.is_blocking_desc() - && om_d.is_blocking_desc() - && !im_d.has_zero_dim() - && !om_d.has_zero_dim(); - if (!ok) - return unimplemented; - - dims_t iblocks, oblocks; - im_d.compute_blocks(iblocks); - om_d.compute_blocks(oblocks); - - /* padding_dim consistency check */ - for (int d = 0; d < im_d.ndims(); ++d) { - const auto pdim = im_d.padded_dims()[d]; - bool ok = true - && pdim == om_d.padded_dims()[d] - && pdim % iblocks[d] == 0 - && pdim % oblocks[d] == 0; - if (!ok) return unimplemented; - } - - layout_desc_t ild, old; - status_t status = cvt_mem_desc_to_layout_desc(imd, ild); - if (status != success) return status; - status = cvt_mem_desc_to_layout_desc(omd, old); - if (status != success) return status; - - p.itype = ild.dt; - p.otype = old.dt; - - p.scale_type = attr->output_scales_.has_default_values() - ? scale_type_t::NONE - : (attr->output_scales_.mask_ == 0 - ? 
scale_type_t::COMMON - : scale_type_t::MANY); - - ptrdiff_t ss[max_ndims] = {0}; - if (p.scale_type == scale_type_t::MANY) { - ptrdiff_t last_ss = 1; - for (int d = old.ndims - 1; d >=0; --d) { - assert((d == 0 || old.id[d - 1] <= old.id[d]) - && "logical dimensions should be in ascending order"); - if (attr->output_scales_.mask_ & (1 << old.id[d])) { - ss[d] = last_ss; - last_ss *= old.dims[d]; - } - } - } - - int ndims = 0; - - int i_pos = 0; /* state for input -- current dimension */ - int o_pos = 0; /* state for output -- current dimension */ - - while (i_pos < ild.ndims && o_pos < old.ndims) { - assert(ild.id[i_pos] == old.id[o_pos]); - if (ild.id[i_pos] != old.id[o_pos]) - return runtime_error; - - assert(ndims < max_ndims); - if (ndims == max_ndims) - return runtime_error; - - if (ild.dims[i_pos] == old.dims[o_pos]) { - p.nodes[ndims].n = ild.dims[i_pos]; - p.nodes[ndims].is = ild.strides[i_pos]; - p.nodes[ndims].os = old.strides[o_pos]; - p.nodes[ndims].ss = ss[o_pos]; - ++ndims; - ++i_pos; - ++o_pos; - } else if (ild.dims[i_pos] < old.dims[o_pos]) { - assert(old.dims[o_pos] % ild.dims[i_pos] == 0); - int factor = old.dims[o_pos] / ild.dims[i_pos]; - p.nodes[ndims].n = ild.dims[i_pos]; - p.nodes[ndims].is = ild.strides[i_pos]; - p.nodes[ndims].os = old.strides[o_pos] * factor; - p.nodes[ndims].ss = ss[o_pos] * factor; - ++ndims; - ++i_pos; - old.dims[o_pos] = factor; - } else if (ild.dims[i_pos] > old.dims[o_pos]) { - assert(ild.dims[i_pos] % old.dims[o_pos] == 0); - int factor = ild.dims[i_pos] / old.dims[o_pos]; - p.nodes[ndims].n = old.dims[o_pos]; - p.nodes[ndims].is = ild.strides[i_pos] * factor; - p.nodes[ndims].os = old.strides[o_pos]; - p.nodes[ndims].ss = ss[o_pos]; - ++ndims; - ++o_pos; - ild.dims[i_pos] = factor; - } - } - p.ndims = ndims; - - dims_t zero_pos = {0}; - p.ioff = memory_desc_wrapper(imd).off_v(zero_pos); - p.ooff = memory_desc_wrapper(omd).off_v(zero_pos); - - const int sum_idx = attr->post_ops_.find(primitive_kind::sum); - p.beta = sum_idx == -1 ? 0.f : attr->post_ops_.entry_[sum_idx].sum.scale; - - return success; -} - -void prb_normalize(prb_t &p) { - for (int d = 0; d < p.ndims; ++d) { - int min_pos = d; - for (int j = d + 1; j < p.ndims; ++j) { - bool new_min = false - || p.nodes[j].os < p.nodes[min_pos].os - || (true - && p.nodes[j].os == p.nodes[min_pos].os - && p.nodes[j].n < p.nodes[min_pos].n); - if (new_min) min_pos = j; - } - if (min_pos != d) - nstl::swap(p.nodes[d], p.nodes[min_pos]); - } -} - -void prb_simplify(prb_t &p) { -#if defined(__GNUC__) && __GNUC__ >= 4 -/* GCC produces bogus array subscript is above array bounds warning for - * the `p.nodes[j - 1] = p.nodes[j]` line below, so disable it for now. 
*/ -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Warray-bounds" -#endif - for (int d = 0; d < p.ndims - 1; ++d) { - auto &this_node = p.nodes[d + 0]; - auto &next_node = p.nodes[d + 1]; - const bool fold = false - || next_node.n == (size_t)1 // trivial case, just drop next node - || (true // or real folding if possible - && next_node.is == (ptrdiff_t)this_node.n * this_node.is - && next_node.os == (ptrdiff_t)this_node.n * this_node.os - && next_node.ss == (ptrdiff_t)this_node.n * this_node.ss); - if (fold) { - this_node.n *= next_node.n; - for (int j = d + 2; j < p.ndims; ++j) - p.nodes[j - 1] = p.nodes[j]; - --p.ndims; - --d; // make another try - } - } -#if defined(__GNUC__) && __GNUC__ >= 4 -#pragma GCC diagnostic pop -#endif -} - -void prb_node_split(prb_t &p, int dim, size_t n1) { - assert(dim < p.ndims); - assert(p.ndims < max_ndims); - assert(p.nodes[dim].n % n1 == 0); - - p.ndims += 1; - - for (int d = p.ndims; d > dim + 1; --d) - p.nodes[d] = p.nodes[d - 1]; - - p.nodes[dim + 1].n = p.nodes[dim].n / n1; - p.nodes[dim + 1].is = p.nodes[dim].is * n1; - p.nodes[dim + 1].os = p.nodes[dim].os * n1; - p.nodes[dim + 1].ss = p.nodes[dim].ss * n1; - - p.nodes[dim].n = n1; -} - -void prb_node_swap(prb_t &p, int d0, int d1) { - assert(d0 < p.ndims); - assert(d1 < p.ndims); - assert(p.ndims < max_ndims); - - if (d0 == d1) return; - - nstl::swap(p.nodes[d0], p.nodes[d1]); -} - -void prb_node_move(prb_t &p, int d0, int d1) { - assert(d0 < p.ndims); - assert(d1 < p.ndims); - assert(p.ndims < max_ndims); - - if (d0 == d1) return; - - node_t node = p.nodes[d0]; - - if (d0 < d1) - for (int d = d0; d < d1; ++d) - p.nodes[d] = p.nodes[d + 1]; - else - for (int d = d0; d > d1; --d) - p.nodes[d] = p.nodes[d - 1]; - - p.nodes[d1] = node; -} - -void prb_dump(const prb_t &p) { - printf("@@@ type:%s:%s ndims:%d ", mkldnn_dt2str(p.itype), - mkldnn_dt2str(p.otype), p.ndims); - for (int d = 0; d < p.ndims; ++d) - printf("[%zu:%td:%td:%td]", - p.nodes[d].n, p.nodes[d].is, p.nodes[d].os, p.nodes[d].ss); - printf(" off:%zu:%zu\n", p.ioff, p.ooff); -} - -} - -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jit_utils.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jit_utils.cpp deleted file mode 100644 index 08747aa89..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jit_utils.cpp +++ /dev/null @@ -1,115 +0,0 @@ -/******************************************************************************* -* Copyright 2019 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#include - -#include "utils.hpp" - -#ifndef MKLDNN_ENABLE_JIT_PROFILING -#define MKLDNN_ENABLE_JIT_PROFILING 1 -#endif - -#ifndef MKLDNN_ENABLE_JIT_DUMP -#define MKLDNN_ENABLE_JIT_DUMP 1 -#endif - -#if MKLDNN_ENABLE_JIT_PROFILING -#include "jitprofiling/jitprofiling.h" -#endif - -namespace mkldnn { -namespace impl { -namespace cpu { -namespace jit_utils { - -// WARNING: These functions are not thread safe and must be protected by a -// mutex - -void dump_jit_code(const void *code, size_t code_size, const char *code_name) -{ -#if MKLDNN_ENABLE_JIT_DUMP - if (code && jit_dump_enabled()) { - static int counter = 0; -#define MAX_FNAME_LEN 256 - char fname[MAX_FNAME_LEN + 1]; - // TODO (Roma): support prefix for code / linux perf dumps - snprintf(fname, MAX_FNAME_LEN, "mkldnn_dump_%s.%d.bin", code_name, - counter); - counter++; - - FILE *fp = fopen(fname, "w+"); - // Failure to dump code is not fatal - if (fp) { - size_t unused = fwrite(code, code_size, 1, fp); - UNUSED(unused); - fclose(fp); - } - } -#undef MAX_FNAME_LEN -#else - UNUSED(code); - UNUSED(code_size); - UNUSED(code_name); -#endif -} - -void register_jit_code_vtune(const void *code, size_t code_size, - const char *code_name, const char *source_file_name) -{ -#if MKLDNN_ENABLE_JIT_PROFILING - if (iJIT_IsProfilingActive() == iJIT_SAMPLING_ON) { - auto jmethod = iJIT_Method_Load(); - jmethod.method_id = iJIT_GetNewMethodID(); // XXX: not thread-safe - jmethod.method_name = (char *)code_name; // XXX: dropping const - jmethod.class_file_name = NULL; - jmethod.source_file_name = (char *)source_file_name; // XXX: dropping const - jmethod.method_load_address = (void *)code; - jmethod.method_size = (unsigned int)code_size; - - iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED, - (void*)&jmethod); - } -#else - UNUSED(code); - UNUSED(code_size); - UNUSED(code_name); - UNUSED(source_file_name); -#endif -} - -void register_jit_code(const void *code, size_t code_size, - const char *code_name, const char *source_file_name) -{ - // The #ifdef guards are required to avoid generating a function that only - // consists of lock and unlock code -#if MKLDNN_ENABLE_JIT_PROFILING || MKLDNN_ENABLE_JIT_DUMP - static std::mutex m; - std::lock_guard guard(m); - - dump_jit_code(code, code_size, code_name); - register_jit_code_vtune(code, code_size, code_name, source_file_name); -#else - UNUSED(code); - UNUSED(code_size); - UNUSED(code_name); - UNUSED(source_file_name); -#endif -} - -} -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jit_utils.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jit_utils.hpp deleted file mode 100644 index 2f52dba4a..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jit_utils.hpp +++ /dev/null @@ -1,32 +0,0 @@ -/******************************************************************************* -* Copyright 2019 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#ifndef JIT_SUPPORT_HPP -#define JIT_SUPPORT_HPP - -namespace mkldnn { -namespace impl { -namespace cpu { -namespace jit_utils { - -void register_jit_code(const void *code, size_t code_size, - const char *code_name, const char *source_file_name); - -} -} -} -} -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/LICENSE.BSD b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/LICENSE.BSD deleted file mode 100644 index 4fd21cea5..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/LICENSE.BSD +++ /dev/null @@ -1,27 +0,0 @@ -Copyright (c) 2011, Intel Corporation -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -1. Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. - -2. Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - -3. Neither the name of the copyright holder nor the names of its - contributors may be used to endorse or promote products derived from - this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/README.md b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/README.md deleted file mode 100644 index fc67c4f13..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/README.md +++ /dev/null @@ -1 +0,0 @@ -This code is from [Intel SEAPI library](https://github.com/intel/IntelSEAPI) diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/ittnotify_config.h b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/ittnotify_config.h deleted file mode 100644 index edbf4a15f..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/ittnotify_config.h +++ /dev/null @@ -1,595 +0,0 @@ -/* - - Contact Information: - http://software.intel.com/en-us/articles/intel-vtune-amplifier-xe/ - - BSD LICENSE - - Copyright (c) 2005-2014 Intel Corporation. All rights reserved. - All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions - are met: - - * Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. 
- * Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in - the documentation and/or other materials provided with the - distribution. - * Neither the name of Intel Corporation nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - */ -#ifndef _ITTNOTIFY_CONFIG_H_ -#define _ITTNOTIFY_CONFIG_H_ - -/** @cond exclude_from_documentation */ -#ifndef ITT_OS_WIN -# define ITT_OS_WIN 1 -#endif /* ITT_OS_WIN */ - -#ifndef ITT_OS_LINUX -# define ITT_OS_LINUX 2 -#endif /* ITT_OS_LINUX */ - -#ifndef ITT_OS_MAC -# define ITT_OS_MAC 3 -#endif /* ITT_OS_MAC */ - -#ifndef ITT_OS_FREEBSD -# define ITT_OS_FREEBSD 4 -#endif /* ITT_OS_FREEBSD */ - -#ifndef ITT_OS -# if defined WIN32 || defined _WIN32 -# define ITT_OS ITT_OS_WIN -# elif defined( __APPLE__ ) && defined( __MACH__ ) -# define ITT_OS ITT_OS_MAC -# elif defined( __FreeBSD__ ) -# define ITT_OS ITT_OS_FREEBSD -# else -# define ITT_OS ITT_OS_LINUX -# endif -#endif /* ITT_OS */ - -#ifndef ITT_PLATFORM_WIN -# define ITT_PLATFORM_WIN 1 -#endif /* ITT_PLATFORM_WIN */ - -#ifndef ITT_PLATFORM_POSIX -# define ITT_PLATFORM_POSIX 2 -#endif /* ITT_PLATFORM_POSIX */ - -#ifndef ITT_PLATFORM_MAC -# define ITT_PLATFORM_MAC 3 -#endif /* ITT_PLATFORM_MAC */ - -#ifndef ITT_PLATFORM_FREEBSD -# define ITT_PLATFORM_FREEBSD 4 -#endif /* ITT_PLATFORM_FREEBSD */ - -#ifndef ITT_PLATFORM -# if ITT_OS==ITT_OS_WIN -# define ITT_PLATFORM ITT_PLATFORM_WIN -# elif ITT_OS==ITT_OS_MAC -# define ITT_PLATFORM ITT_PLATFORM_MAC -# elif ITT_OS==ITT_OS_FREEBSD -# define ITT_PLATFORM ITT_PLATFORM_FREEBSD -# else -# define ITT_PLATFORM ITT_PLATFORM_POSIX -# endif -#endif /* ITT_PLATFORM */ - -#if defined(_UNICODE) && !defined(UNICODE) -#define UNICODE -#endif - -#include -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#include -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#include -#if defined(UNICODE) || defined(_UNICODE) -#include -#endif /* UNICODE || _UNICODE */ -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ - -#ifndef ITTAPI_CDECL -# if ITT_PLATFORM==ITT_PLATFORM_WIN -# define ITTAPI_CDECL __cdecl -# else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -# if defined _M_IX86 || defined __i386__ -# define ITTAPI_CDECL __attribute__ ((cdecl)) -# else /* _M_IX86 || __i386__ */ -# define ITTAPI_CDECL /* actual only on x86 platform */ -# endif /* _M_IX86 || __i386__ */ -# endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#endif /* ITTAPI_CDECL */ - -#ifndef STDCALL -# if ITT_PLATFORM==ITT_PLATFORM_WIN -# define STDCALL __stdcall -# else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -# if defined _M_IX86 || defined __i386__ -# define STDCALL __attribute__ ((stdcall)) -# else /* _M_IX86 || 
__i386__ */ -# define STDCALL /* supported only on x86 platform */ -# endif /* _M_IX86 || __i386__ */ -# endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#endif /* STDCALL */ - -#define ITTAPI ITTAPI_CDECL -#define LIBITTAPI ITTAPI_CDECL - -/* TODO: Temporary for compatibility! */ -#define ITTAPI_CALL ITTAPI_CDECL -#define LIBITTAPI_CALL ITTAPI_CDECL - -#if ITT_PLATFORM==ITT_PLATFORM_WIN -/* use __forceinline (VC++ specific) */ -#define ITT_INLINE __forceinline -#define ITT_INLINE_ATTRIBUTE /* nothing */ -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -/* - * Generally, functions are not inlined unless optimization is specified. - * For functions declared inline, this attribute inlines the function even - * if no optimization level was specified. - */ -#ifdef __STRICT_ANSI__ -#define ITT_INLINE static -#define ITT_INLINE_ATTRIBUTE __attribute__((unused)) -#else /* __STRICT_ANSI__ */ -#define ITT_INLINE static inline -#define ITT_INLINE_ATTRIBUTE __attribute__((always_inline, unused)) -#endif /* __STRICT_ANSI__ */ -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -/** @endcond */ - -#ifndef ITT_ARCH_IA32 -# define ITT_ARCH_IA32 1 -#endif /* ITT_ARCH_IA32 */ - -#ifndef ITT_ARCH_IA32E -# define ITT_ARCH_IA32E 2 -#endif /* ITT_ARCH_IA32E */ - -#ifndef ITT_ARCH_ARM -# define ITT_ARCH_ARM 4 -#endif /* ITT_ARCH_ARM */ - -#ifndef ITT_ARCH_PPC64 -# define ITT_ARCH_PPC64 5 -#endif /* ITT_ARCH_PPC64 */ - -#ifndef ITT_ARCH -# if defined _M_IX86 || defined __i386__ -# define ITT_ARCH ITT_ARCH_IA32 -# elif defined _M_X64 || defined _M_AMD64 || defined __x86_64__ -# define ITT_ARCH ITT_ARCH_IA32E -# elif defined _M_IA64 || defined __ia64__ -# define ITT_ARCH ITT_ARCH_IA64 -# elif defined _M_ARM || defined __arm__ -# define ITT_ARCH ITT_ARCH_ARM -# elif defined __powerpc64__ -# define ITT_ARCH ITT_ARCH_PPC64 -# endif -#endif - -#ifdef __cplusplus -# define ITT_EXTERN_C extern "C" -# define ITT_EXTERN_C_BEGIN extern "C" { -# define ITT_EXTERN_C_END } -#else -# define ITT_EXTERN_C /* nothing */ -# define ITT_EXTERN_C_BEGIN /* nothing */ -# define ITT_EXTERN_C_END /* nothing */ -#endif /* __cplusplus */ - -#define ITT_TO_STR_AUX(x) #x -#define ITT_TO_STR(x) ITT_TO_STR_AUX(x) - -#define __ITT_BUILD_ASSERT(expr, suffix) do { \ - static char __itt_build_check_##suffix[(expr) ? 1 : -1]; \ - __itt_build_check_##suffix[0] = 0; \ -} while(0) -#define _ITT_BUILD_ASSERT(expr, suffix) __ITT_BUILD_ASSERT((expr), suffix) -#define ITT_BUILD_ASSERT(expr) _ITT_BUILD_ASSERT((expr), __LINE__) - -#define ITT_MAGIC { 0xED, 0xAB, 0xAB, 0xEC, 0x0D, 0xEE, 0xDA, 0x30 } - -/* Replace with snapshot date YYYYMMDD for promotion build. 
*/ -#define API_VERSION_BUILD 20151119 - -#ifndef API_VERSION_NUM -#define API_VERSION_NUM 0.0.0 -#endif /* API_VERSION_NUM */ - -#define API_VERSION "ITT-API-Version " ITT_TO_STR(API_VERSION_NUM) \ - " (" ITT_TO_STR(API_VERSION_BUILD) ")" - -/* OS communication functions */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#include -typedef HMODULE lib_t; -typedef DWORD TIDT; -typedef CRITICAL_SECTION mutex_t; -#define MUTEX_INITIALIZER { 0 } -#define strong_alias(name, aliasname) /* empty for Windows */ -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#include -#if defined(UNICODE) || defined(_UNICODE) -#include -#endif /* UNICODE */ -#ifndef _GNU_SOURCE -#define _GNU_SOURCE 1 /* need for PTHREAD_MUTEX_RECURSIVE */ -#endif /* _GNU_SOURCE */ -#ifndef __USE_UNIX98 -#define __USE_UNIX98 1 /* need for PTHREAD_MUTEX_RECURSIVE, on SLES11.1 with gcc 4.3.4 wherein pthread.h missing dependency on __USE_XOPEN2K8 */ -#endif /*__USE_UNIX98*/ -#include -typedef void* lib_t; -typedef pthread_t TIDT; -typedef pthread_mutex_t mutex_t; -#define MUTEX_INITIALIZER PTHREAD_MUTEX_INITIALIZER -#define _strong_alias(name, aliasname) \ - extern __typeof (name) aliasname __attribute__ ((alias (#name))); -#define strong_alias(name, aliasname) _strong_alias(name, aliasname) -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ - -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define __itt_get_proc(lib, name) GetProcAddress(lib, name) -#define __itt_mutex_init(mutex) InitializeCriticalSection(mutex) -#define __itt_mutex_lock(mutex) EnterCriticalSection(mutex) -#define __itt_mutex_unlock(mutex) LeaveCriticalSection(mutex) -#define __itt_load_lib(name) LoadLibraryA(name) -#define __itt_unload_lib(handle) FreeLibrary(handle) -#define __itt_system_error() (int)GetLastError() -#define __itt_fstrcmp(s1, s2) lstrcmpA(s1, s2) -#define __itt_fstrnlen(s, l) strnlen_s(s, l) -#define __itt_fstrcpyn(s1, b, s2, l) strncpy_s(s1, b, s2, l) -#define __itt_fstrdup(s) _strdup(s) -#define __itt_thread_id() GetCurrentThreadId() -#define __itt_thread_yield() SwitchToThread() -#ifndef ITT_SIMPLE_INIT -ITT_INLINE long -__itt_interlocked_increment(volatile long* ptr) ITT_INLINE_ATTRIBUTE; -ITT_INLINE long __itt_interlocked_increment(volatile long* ptr) -{ - return InterlockedIncrement(ptr); -} -#endif /* ITT_SIMPLE_INIT */ - -#define DL_SYMBOLS (1) -#define PTHREAD_SYMBOLS (1) - -#else /* ITT_PLATFORM!=ITT_PLATFORM_WIN */ -#define __itt_get_proc(lib, name) dlsym(lib, name) -#define __itt_mutex_init(mutex) {\ - pthread_mutexattr_t mutex_attr; \ - int error_code = pthread_mutexattr_init(&mutex_attr); \ - if (error_code) \ - __itt_report_error(__itt_error_system, "pthread_mutexattr_init", \ - error_code); \ - error_code = pthread_mutexattr_settype(&mutex_attr, \ - PTHREAD_MUTEX_RECURSIVE); \ - if (error_code) \ - __itt_report_error(__itt_error_system, "pthread_mutexattr_settype", \ - error_code); \ - error_code = pthread_mutex_init(mutex, &mutex_attr); \ - if (error_code) \ - __itt_report_error(__itt_error_system, "pthread_mutex_init", \ - error_code); \ - error_code = pthread_mutexattr_destroy(&mutex_attr); \ - if (error_code) \ - __itt_report_error(__itt_error_system, "pthread_mutexattr_destroy", \ - error_code); \ -} -#define __itt_mutex_lock(mutex) pthread_mutex_lock(mutex) -#define __itt_mutex_unlock(mutex) pthread_mutex_unlock(mutex) -#define __itt_load_lib(name) dlopen(name, RTLD_LAZY) -#define __itt_unload_lib(handle) dlclose(handle) -#define __itt_system_error() errno -#define __itt_fstrcmp(s1, s2) strcmp(s1, s2) - -/* makes customer code define safe APIs for 
SDL_STRNLEN_S and SDL_STRNCPY_S */ -#ifdef SDL_STRNLEN_S -#define __itt_fstrnlen(s, l) SDL_STRNLEN_S(s, l) -#else -#define __itt_fstrnlen(s, l) strlen(s) -#endif /* SDL_STRNLEN_S */ -#ifdef SDL_STRNCPY_S -#define __itt_fstrcpyn(s1, b, s2, l) SDL_STRNCPY_S(s1, b, s2, l) -#else -#define __itt_fstrcpyn(s1, b, s2, l) strncpy(s1, s2, l) -#endif /* SDL_STRNCPY_S */ - -#define __itt_fstrdup(s) strdup(s) -#define __itt_thread_id() pthread_self() -#define __itt_thread_yield() sched_yield() -#if ITT_ARCH==ITT_ARCH_IA64 -#ifdef __INTEL_COMPILER -#define __TBB_machine_fetchadd4(addr, val) __fetchadd4_acq((void *)addr, val) -#else /* __INTEL_COMPILER */ -/* TODO: Add Support for not Intel compilers for IA-64 architecture */ -#endif /* __INTEL_COMPILER */ -#elif ITT_ARCH==ITT_ARCH_IA32 || ITT_ARCH==ITT_ARCH_IA32E /* ITT_ARCH!=ITT_ARCH_IA64 */ -ITT_INLINE long -__TBB_machine_fetchadd4(volatile void* ptr, long addend) ITT_INLINE_ATTRIBUTE; -ITT_INLINE long __TBB_machine_fetchadd4(volatile void* ptr, long addend) -{ - long result; - __asm__ __volatile__("lock\nxadd %0,%1" - : "=r"(result),"=m"(*(int*)ptr) - : "0"(addend), "m"(*(int*)ptr) - : "memory"); - return result; -} -#elif ITT_ARCH==ITT_ARCH_ARM || ITT_ARCH==ITT_ARCH_PPC64 -#define __TBB_machine_fetchadd4(addr, val) __sync_fetch_and_add(addr, val) -#endif /* ITT_ARCH==ITT_ARCH_IA64 */ -#ifndef ITT_SIMPLE_INIT -ITT_INLINE long -__itt_interlocked_increment(volatile long* ptr) ITT_INLINE_ATTRIBUTE; -ITT_INLINE long __itt_interlocked_increment(volatile long* ptr) -{ - return __TBB_machine_fetchadd4(ptr, 1) + 1L; -} -#endif /* ITT_SIMPLE_INIT */ - -void* dlopen(const char*, int) __attribute__((weak)); -void* dlsym(void*, const char*) __attribute__((weak)); -int dlclose(void*) __attribute__((weak)); -#define DL_SYMBOLS (dlopen && dlsym && dlclose) - -int pthread_mutex_init(pthread_mutex_t*, const pthread_mutexattr_t*) __attribute__((weak)); -int pthread_mutex_lock(pthread_mutex_t*) __attribute__((weak)); -int pthread_mutex_unlock(pthread_mutex_t*) __attribute__((weak)); -int pthread_mutex_destroy(pthread_mutex_t*) __attribute__((weak)); -int pthread_mutexattr_init(pthread_mutexattr_t*) __attribute__((weak)); -int pthread_mutexattr_settype(pthread_mutexattr_t*, int) __attribute__((weak)); -int pthread_mutexattr_destroy(pthread_mutexattr_t*) __attribute__((weak)); -pthread_t pthread_self(void) __attribute__((weak)); -#define PTHREAD_SYMBOLS (pthread_mutex_init && pthread_mutex_lock && pthread_mutex_unlock && pthread_mutex_destroy && pthread_mutexattr_init && pthread_mutexattr_settype && pthread_mutexattr_destroy && pthread_self) - -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ - -typedef enum { - __itt_collection_normal = 0, - __itt_collection_paused = 1 -} __itt_collection_state; - -typedef enum { - __itt_thread_normal = 0, - __itt_thread_ignored = 1 -} __itt_thread_state; - -#pragma pack(push, 8) - -typedef struct ___itt_thread_info -{ - const char* nameA; /*!< Copy of original name in ASCII. */ -#if defined(UNICODE) || defined(_UNICODE) - const wchar_t* nameW; /*!< Copy of original name in UNICODE. 
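On the POSIX side, __itt_mutex_init above configures a recursive mutex (hence the _GNU_SOURCE/__USE_UNIX98 defines needed for PTHREAD_MUTEX_RECURSIVE), so the runtime can re-acquire its own lock. A minimal sketch of what that macro expands to, written as a plain function with error handling reduced to a boolean (the function name is illustrative):

    #include <pthread.h>

    // Sketch of the recursive-mutex setup done by the POSIX __itt_mutex_init:
    // attr init -> PTHREAD_MUTEX_RECURSIVE -> mutex init -> attr destroy.
    static bool init_recursive_mutex(pthread_mutex_t* mutex) {
        pthread_mutexattr_t attr;
        if (pthread_mutexattr_init(&attr) != 0)
            return false;
        bool ok = pthread_mutexattr_settype(&attr, PTHREAD_MUTEX_RECURSIVE) == 0
               && pthread_mutex_init(mutex, &attr) == 0;
        pthread_mutexattr_destroy(&attr);
        return ok;
    }

The header also declares the dl* and pthread_* entry points as weak symbols, so the DL_SYMBOLS and PTHREAD_SYMBOLS checks can detect at run time whether those libraries were linked in at all.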
*/ -#else /* UNICODE || _UNICODE */ - void* nameW; -#endif /* UNICODE || _UNICODE */ - TIDT tid; - __itt_thread_state state; /*!< Thread state (paused or normal) */ - int extra1; /*!< Reserved to the runtime */ - void* extra2; /*!< Reserved to the runtime */ - struct ___itt_thread_info* next; -} __itt_thread_info; - -#include "ittnotify_types.h" /* For __itt_group_id definition */ - -typedef struct ___itt_api_info_20101001 -{ - const char* name; - void** func_ptr; - void* init_func; - __itt_group_id group; -} __itt_api_info_20101001; - -typedef struct ___itt_api_info -{ - const char* name; - void** func_ptr; - void* init_func; - void* null_func; - __itt_group_id group; -} __itt_api_info; - -typedef struct __itt_counter_info -{ - const char* nameA; /*!< Copy of original name in ASCII. */ -#if defined(UNICODE) || defined(_UNICODE) - const wchar_t* nameW; /*!< Copy of original name in UNICODE. */ -#else /* UNICODE || _UNICODE */ - void* nameW; -#endif /* UNICODE || _UNICODE */ - const char* domainA; /*!< Copy of original name in ASCII. */ -#if defined(UNICODE) || defined(_UNICODE) - const wchar_t* domainW; /*!< Copy of original name in UNICODE. */ -#else /* UNICODE || _UNICODE */ - void* domainW; -#endif /* UNICODE || _UNICODE */ - int type; - long index; - int extra1; /*!< Reserved to the runtime */ - void* extra2; /*!< Reserved to the runtime */ - struct __itt_counter_info* next; -} __itt_counter_info_t; - -struct ___itt_domain; -struct ___itt_string_handle; - -typedef struct ___itt_global -{ - unsigned char magic[8]; - unsigned long version_major; - unsigned long version_minor; - unsigned long version_build; - volatile long api_initialized; - volatile long mutex_initialized; - volatile long atomic_counter; - mutex_t mutex; - lib_t lib; - void* error_handler; - const char** dll_path_ptr; - __itt_api_info* api_list_ptr; - struct ___itt_global* next; - /* Joinable structures below */ - __itt_thread_info* thread_list; - struct ___itt_domain* domain_list; - struct ___itt_string_handle* string_list; - __itt_collection_state state; - __itt_counter_info_t* counter_list; -} __itt_global; - -#pragma pack(pop) - -#define NEW_THREAD_INFO_W(gptr,h,h_tail,t,s,n) { \ - h = (__itt_thread_info*)malloc(sizeof(__itt_thread_info)); \ - if (h != NULL) { \ - h->tid = t; \ - h->nameA = NULL; \ - h->nameW = n ? _wcsdup(n) : NULL; \ - h->state = s; \ - h->extra1 = 0; /* reserved */ \ - h->extra2 = NULL; /* reserved */ \ - h->next = NULL; \ - if (h_tail == NULL) \ - (gptr)->thread_list = h; \ - else \ - h_tail->next = h; \ - } \ -} - -#define NEW_THREAD_INFO_A(gptr,h,h_tail,t,s,n) { \ - h = (__itt_thread_info*)malloc(sizeof(__itt_thread_info)); \ - if (h != NULL) { \ - h->tid = t; \ - h->nameA = n ? __itt_fstrdup(n) : NULL; \ - h->nameW = NULL; \ - h->state = s; \ - h->extra1 = 0; /* reserved */ \ - h->extra2 = NULL; /* reserved */ \ - h->next = NULL; \ - if (h_tail == NULL) \ - (gptr)->thread_list = h; \ - else \ - h_tail->next = h; \ - } \ -} - -#define NEW_DOMAIN_W(gptr,h,h_tail,name) { \ - h = (__itt_domain*)malloc(sizeof(__itt_domain)); \ - if (h != NULL) { \ - h->flags = 1; /* domain is enabled by default */ \ - h->nameA = NULL; \ - h->nameW = name ? 
_wcsdup(name) : NULL; \ - h->extra1 = 0; /* reserved */ \ - h->extra2 = NULL; /* reserved */ \ - h->next = NULL; \ - if (h_tail == NULL) \ - (gptr)->domain_list = h; \ - else \ - h_tail->next = h; \ - } \ -} - -#define NEW_DOMAIN_A(gptr,h,h_tail,name) { \ - h = (__itt_domain*)malloc(sizeof(__itt_domain)); \ - if (h != NULL) { \ - h->flags = 1; /* domain is enabled by default */ \ - h->nameA = name ? __itt_fstrdup(name) : NULL; \ - h->nameW = NULL; \ - h->extra1 = 0; /* reserved */ \ - h->extra2 = NULL; /* reserved */ \ - h->next = NULL; \ - if (h_tail == NULL) \ - (gptr)->domain_list = h; \ - else \ - h_tail->next = h; \ - } \ -} - -#define NEW_STRING_HANDLE_W(gptr,h,h_tail,name) { \ - h = (__itt_string_handle*)malloc(sizeof(__itt_string_handle)); \ - if (h != NULL) { \ - h->strA = NULL; \ - h->strW = name ? _wcsdup(name) : NULL; \ - h->extra1 = 0; /* reserved */ \ - h->extra2 = NULL; /* reserved */ \ - h->next = NULL; \ - if (h_tail == NULL) \ - (gptr)->string_list = h; \ - else \ - h_tail->next = h; \ - } \ -} - -#define NEW_STRING_HANDLE_A(gptr,h,h_tail,name) { \ - h = (__itt_string_handle*)malloc(sizeof(__itt_string_handle)); \ - if (h != NULL) { \ - h->strA = name ? __itt_fstrdup(name) : NULL; \ - h->strW = NULL; \ - h->extra1 = 0; /* reserved */ \ - h->extra2 = NULL; /* reserved */ \ - h->next = NULL; \ - if (h_tail == NULL) \ - (gptr)->string_list = h; \ - else \ - h_tail->next = h; \ - } \ -} - -#define NEW_COUNTER_W(gptr,h,h_tail,name,domain,type) { \ - h = (__itt_counter_info_t*)malloc(sizeof(__itt_counter_info_t)); \ - if (h != NULL) { \ - h->nameA = NULL; \ - h->nameW = name ? _wcsdup(name) : NULL; \ - h->domainA = NULL; \ - h->domainW = name ? _wcsdup(domain) : NULL; \ - h->type = type; \ - h->index = 0; \ - h->next = NULL; \ - if (h_tail == NULL) \ - (gptr)->counter_list = h; \ - else \ - h_tail->next = h; \ - } \ -} - -#define NEW_COUNTER_A(gptr,h,h_tail,name,domain,type) { \ - h = (__itt_counter_info_t*)malloc(sizeof(__itt_counter_info_t)); \ - if (h != NULL) { \ - h->nameA = name ? __itt_fstrdup(name) : NULL; \ - h->nameW = NULL; \ - h->domainA = domain ? __itt_fstrdup(domain) : NULL; \ - h->domainW = NULL; \ - h->type = type; \ - h->index = 0; \ - h->next = NULL; \ - if (h_tail == NULL) \ - (gptr)->counter_list = h; \ - else \ - h_tail->next = h; \ - } \ -} - -#endif /* _ITTNOTIFY_CONFIG_H_ */ diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/ittnotify_types.h b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/ittnotify_types.h deleted file mode 100644 index 99fbc2405..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/ittnotify_types.h +++ /dev/null @@ -1,94 +0,0 @@ -/* - - Contact Information: - http://software.intel.com/en-us/articles/intel-vtune-amplifier-xe/ - - BSD LICENSE - - Copyright (c) 2005-2014 Intel Corporation. All rights reserved. - All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions - are met: - - * Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in - the documentation and/or other materials provided with the - distribution. 
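Each of the NEW_*_A/NEW_*_W macros above performs the same tail-append into one of the singly linked lists rooted in __itt_global, using the caller-supplied tail pointer so the list is never re-walked. A reduced sketch of that append pattern with a generic node type (the type and function names are illustrative, and the string duplication done by __itt_fstrdup/_wcsdup is omitted):

    #include <cstdlib>

    // Illustrative node; the real lists use __itt_thread_info, __itt_domain, etc.
    struct list_node {
        const char* name;
        list_node*  next;
    };

    // Append one node the way the NEW_*_A macros do: allocate, fill, then either
    // become the list head (empty list) or hang off the caller-tracked tail.
    static list_node* append_node(list_node** head, list_node* tail, const char* name) {
        list_node* n = static_cast<list_node*>(std::malloc(sizeof(list_node)));
        if (n == NULL)
            return NULL;
        n->name = name;   // the real macros duplicate this with __itt_fstrdup/_wcsdup
        n->next = NULL;
        if (tail == NULL)
            *head = n;    // first element becomes the new head
        else
            tail->next = n;
        return n;
    }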
- * Neither the name of Intel Corporation nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - */ - -#ifndef _ITTNOTIFY_TYPES_H_ -#define _ITTNOTIFY_TYPES_H_ - -typedef enum ___itt_group_id -{ - __itt_group_none = 0, - __itt_group_legacy = 1<<0, - __itt_group_control = 1<<1, - __itt_group_thread = 1<<2, - __itt_group_mark = 1<<3, - __itt_group_sync = 1<<4, - __itt_group_fsync = 1<<5, - __itt_group_jit = 1<<6, - __itt_group_model = 1<<7, - __itt_group_splitter_min = 1<<7, - __itt_group_counter = 1<<8, - __itt_group_frame = 1<<9, - __itt_group_stitch = 1<<10, - __itt_group_heap = 1<<11, - __itt_group_splitter_max = 1<<12, - __itt_group_structure = 1<<12, - __itt_group_suppress = 1<<13, - __itt_group_arrays = 1<<14, - __itt_group_all = -1 -} __itt_group_id; - -#pragma pack(push, 8) - -typedef struct ___itt_group_list -{ - __itt_group_id id; - const char* name; -} __itt_group_list; - -#pragma pack(pop) - -#define ITT_GROUP_LIST(varname) \ - static __itt_group_list varname[] = { \ - { __itt_group_all, "all" }, \ - { __itt_group_control, "control" }, \ - { __itt_group_thread, "thread" }, \ - { __itt_group_mark, "mark" }, \ - { __itt_group_sync, "sync" }, \ - { __itt_group_fsync, "fsync" }, \ - { __itt_group_jit, "jit" }, \ - { __itt_group_model, "model" }, \ - { __itt_group_counter, "counter" }, \ - { __itt_group_frame, "frame" }, \ - { __itt_group_stitch, "stitch" }, \ - { __itt_group_heap, "heap" }, \ - { __itt_group_structure, "structure" }, \ - { __itt_group_suppress, "suppress" }, \ - { __itt_group_arrays, "arrays" }, \ - { __itt_group_none, NULL } \ - } - -#endif /* _ITTNOTIFY_TYPES_H_ */ diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.c b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.c deleted file mode 100644 index 15f4b9929..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.c +++ /dev/null @@ -1,293 +0,0 @@ -/* - - Contact Information: - http://software.intel.com/en-us/articles/intel-vtune-amplifier-xe/ - - BSD LICENSE - - Copyright (c) 2005-2014 Intel Corporation. All rights reserved. - All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions - are met: - - * Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in - the documentation and/or other materials provided with the - distribution. 
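Because every __itt_group_id value above (apart from the aggregates) is a distinct 1 << n bit, a set of enabled API groups fits in a single integer and membership is a bitwise test. A small illustrative example, assuming nothing beyond the enumerator values shown above:

    #include <cstdio>

    // The groups are 1 << n flags (values copied from ___itt_group_id above),
    // so several of them combine into one mask and membership is a bitwise AND.
    enum group_id {
        group_none    = 0,
        group_control = 1 << 1,
        group_thread  = 1 << 2,
        group_jit     = 1 << 6,
        group_all     = -1
    };

    int main() {
        int enabled = group_control | group_jit;
        std::printf("jit group enabled: %s\n",
                    (enabled & group_jit) ? "yes" : "no");   // prints "yes"
        return 0;
    }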
- * Neither the name of Intel Corporation nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - */ - -#include "ittnotify_config.h" - -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#include -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ -#if ITT_PLATFORM != ITT_PLATFORM_MAC && ITT_PLATFORM != ITT_PLATFORM_FREEBSD -#include -#endif -#include - -#include "jitprofiling.h" - -static const char rcsid[] = "\n@(#) $Revision: 471937 $\n"; - -#define DLL_ENVIRONMENT_VAR "VS_PROFILER" - -#ifndef NEW_DLL_ENVIRONMENT_VAR -#if ITT_ARCH==ITT_ARCH_IA32 -#define NEW_DLL_ENVIRONMENT_VAR "INTEL_JIT_PROFILER32" -#else -#define NEW_DLL_ENVIRONMENT_VAR "INTEL_JIT_PROFILER64" -#endif -#endif /* NEW_DLL_ENVIRONMENT_VAR */ - -#if ITT_PLATFORM==ITT_PLATFORM_WIN -#define DEFAULT_DLLNAME "JitPI.dll" -HINSTANCE m_libHandle = NULL; -#elif ITT_PLATFORM==ITT_PLATFORM_MAC -#define DEFAULT_DLLNAME "libJitPI.dylib" -void* m_libHandle = NULL; -#else -#define DEFAULT_DLLNAME "libJitPI.so" -void* m_libHandle = NULL; -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ - -/* default location of JIT profiling agent on Android */ -#define ANDROID_JIT_AGENT_PATH "/data/intel/libittnotify.so" - -/* the function pointers */ -typedef unsigned int(JITAPI *TPInitialize)(void); -static TPInitialize FUNC_Initialize=NULL; - -typedef unsigned int(JITAPI *TPNotify)(unsigned int, void*); -static TPNotify FUNC_NotifyEvent=NULL; - -static iJIT_IsProfilingActiveFlags executionMode = iJIT_NOTHING_RUNNING; - -/* end collector dll part. */ - -/* loadiJIT_Funcs() : this function is called just in the beginning - * and is responsible to load the functions from BistroJavaCollector.dll - * result: - * on success: the functions loads, iJIT_DLL_is_missing=0, return value = 1 - * on failure: the functions are NULL, iJIT_DLL_is_missing=1, return value = 0 - */ -static int loadiJIT_Funcs(void); - -/* global representing whether the collector can't be loaded */ -static int iJIT_DLL_is_missing = 0; - -ITT_EXTERN_C int JITAPI -iJIT_NotifyEvent(iJIT_JVM_EVENT event_type, void *EventSpecificData) -{ - int ReturnValue = 0; - - /* initialization part - the collector has not been loaded yet. 
*/ - if (!FUNC_NotifyEvent) - { - if (iJIT_DLL_is_missing) - return 0; - - if (!loadiJIT_Funcs()) - return 0; - } - - if (event_type == iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED || - event_type == iJVM_EVENT_TYPE_METHOD_UPDATE) - { - if (((piJIT_Method_Load)EventSpecificData)->method_id == 0) - return 0; - } - else if (event_type == iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED_V2) - { - if (((piJIT_Method_Load_V2)EventSpecificData)->method_id == 0) - return 0; - } - else if (event_type == iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED_V3) - { - if (((piJIT_Method_Load_V3)EventSpecificData)->method_id == 0) - return 0; - } - else if (event_type == iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED) - { - if (((piJIT_Method_Inline_Load)EventSpecificData)->method_id == 0 || - ((piJIT_Method_Inline_Load)EventSpecificData)->parent_method_id == 0) - return 0; - } - - ReturnValue = (int)FUNC_NotifyEvent(event_type, EventSpecificData); - - return ReturnValue; -} - -ITT_EXTERN_C iJIT_IsProfilingActiveFlags JITAPI iJIT_IsProfilingActive() -{ - if (!iJIT_DLL_is_missing) - { - loadiJIT_Funcs(); - } - - return executionMode; -} - -/* This function loads the collector dll and the relevant functions. - * on success: all functions load, iJIT_DLL_is_missing = 0, return value = 1 - * on failure: all functions are NULL, iJIT_DLL_is_missing = 1, return value = 0 - */ -static int loadiJIT_Funcs() -{ - static int bDllWasLoaded = 0; - char *dllName = (char*)rcsid; /* !! Just to avoid unused code elimination */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN - DWORD dNameLength = 0; -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ - - if(bDllWasLoaded) - { - /* dll was already loaded, no need to do it for the second time */ - return 1; - } - - /* Assumes that the DLL will not be found */ - iJIT_DLL_is_missing = 1; - FUNC_NotifyEvent = NULL; - - if (m_libHandle) - { -#if ITT_PLATFORM==ITT_PLATFORM_WIN - FreeLibrary(m_libHandle); -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ - dlclose(m_libHandle); -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ - m_libHandle = NULL; - } - - /* Try to get the dll name from the environment */ -#if ITT_PLATFORM==ITT_PLATFORM_WIN - dNameLength = GetEnvironmentVariableA(NEW_DLL_ENVIRONMENT_VAR, NULL, 0); - if (dNameLength) - { - DWORD envret = 0; - dllName = (char*)malloc(sizeof(char) * (dNameLength + 1)); - if(dllName != NULL) - { - envret = GetEnvironmentVariableA(NEW_DLL_ENVIRONMENT_VAR, - dllName, dNameLength); - if (envret) - { - /* Try to load the dll from the PATH... */ - m_libHandle = LoadLibraryExA(dllName, - NULL, LOAD_WITH_ALTERED_SEARCH_PATH); - } - free(dllName); - } - } else { - /* Try to use old VS_PROFILER variable */ - dNameLength = GetEnvironmentVariableA(DLL_ENVIRONMENT_VAR, NULL, 0); - if (dNameLength) - { - DWORD envret = 0; - dllName = (char*)malloc(sizeof(char) * (dNameLength + 1)); - if(dllName != NULL) - { - envret = GetEnvironmentVariableA(DLL_ENVIRONMENT_VAR, - dllName, dNameLength); - if (envret) - { - /* Try to load the dll from the PATH... */ - m_libHandle = LoadLibraryA(dllName); - } - free(dllName); - } - } - } -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ - dllName = getenv(NEW_DLL_ENVIRONMENT_VAR); - if (!dllName) - dllName = getenv(DLL_ENVIRONMENT_VAR); -#if defined(__ANDROID__) || defined(ANDROID) - if (!dllName) - dllName = ANDROID_JIT_AGENT_PATH; -#endif - if (dllName) - { - /* Try to load the dll from the PATH... 
*/ - m_libHandle = dlopen(dllName, RTLD_LAZY); - } -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ - - if (!m_libHandle) - { -#if ITT_PLATFORM==ITT_PLATFORM_WIN - m_libHandle = LoadLibraryA(DEFAULT_DLLNAME); -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ - m_libHandle = dlopen(DEFAULT_DLLNAME, RTLD_LAZY); -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ - } - - /* if the dll wasn't loaded - exit. */ - if (!m_libHandle) - { - iJIT_DLL_is_missing = 1; /* don't try to initialize - * JIT agent the second time - */ - return 0; - } - -#if ITT_PLATFORM==ITT_PLATFORM_WIN - FUNC_NotifyEvent = (TPNotify)GetProcAddress(m_libHandle, "NotifyEvent"); -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ - FUNC_NotifyEvent = (TPNotify)dlsym(m_libHandle, "NotifyEvent"); -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ - if (!FUNC_NotifyEvent) - { - FUNC_Initialize = NULL; - return 0; - } - -#if ITT_PLATFORM==ITT_PLATFORM_WIN - FUNC_Initialize = (TPInitialize)GetProcAddress(m_libHandle, "Initialize"); -#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ - FUNC_Initialize = (TPInitialize)dlsym(m_libHandle, "Initialize"); -#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ - if (!FUNC_Initialize) - { - FUNC_NotifyEvent = NULL; - return 0; - } - - executionMode = (iJIT_IsProfilingActiveFlags)FUNC_Initialize(); - - bDllWasLoaded = 1; - iJIT_DLL_is_missing = 0; /* DLL is ok. */ - - return 1; -} - -ITT_EXTERN_C unsigned int JITAPI iJIT_GetNewMethodID() -{ - static unsigned int methodID = 1; - - if (methodID == 0) - return 0; /* ERROR : this is not a valid value */ - - return methodID++; -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.h b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.h deleted file mode 100644 index bf0489b1a..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.h +++ /dev/null @@ -1,673 +0,0 @@ -/* - - Contact Information: - http://software.intel.com/en-us/articles/intel-vtune-amplifier-xe/ - - BSD LICENSE - - Copyright (c) 2005-2014 Intel Corporation. All rights reserved. - All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions - are met: - - * Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in - the documentation and/or other materials provided with the - distribution. - * Neither the name of Intel Corporation nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
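loadiJIT_Funcs() above resolves the collector library in a fixed order: the INTEL_JIT_PROFILER32/64 environment variable, then the legacy VS_PROFILER variable, then (on Android) the hard-coded agent path, and finally the platform default name on the loader search path. A condensed sketch of the non-Windows branch of that lookup, leaving out the subsequent NotifyEvent/Initialize symbol resolution (the function name is illustrative):

    #include <cstdlib>
    #include <dlfcn.h>

    // Mirror of the POSIX lookup order in loadiJIT_Funcs(): newest environment
    // variable first, then the legacy one, then the built-in default name.
    static void* open_jit_collector() {
        const char* name = std::getenv("INTEL_JIT_PROFILER64");  // "...32" on IA-32
        if (name == NULL)
            name = std::getenv("VS_PROFILER");                   // legacy variable
        if (name != NULL) {
            if (void* handle = dlopen(name, RTLD_LAZY))
                return handle;
        }
        return dlopen("libJitPI.so", RTLD_LAZY);                  // DEFAULT_DLLNAME on Linux
    }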
- */ - -#ifndef __JITPROFILING_H__ -#define __JITPROFILING_H__ - -/** - * @brief JIT Profiling APIs - * - * The JIT Profiling API is used to report information about just-in-time - * generated code that can be used by performance tools. The user inserts - * calls in the code generator to report information before JIT-compiled - * code goes to execution. This information is collected at runtime and used - * by tools like Intel(R) VTune(TM) Amplifier to display performance metrics - * associated with JIT-compiled code. - * - * These APIs can be used to\n - * - **Profile trace-based and method-based JIT-compiled - * code**. Some examples of environments that you can profile with these APIs: - * dynamic JIT compilation of JavaScript code traces, JIT execution in OpenCL(TM) - * software technology, Java/.NET managed execution environments, and custom - * ISV JIT engines. - * @code - * #include - * - * if (iJIT_IsProfilingActive != iJIT_SAMPLING_ON) { - * return; - * } - * - * iJIT_Method_Load jmethod = {0}; - * jmethod.method_id = iJIT_GetNewMethodID(); - * jmethod.method_name = "method_name"; - * jmethod.class_file_name = "class_name"; - * jmethod.source_file_name = "source_file_name"; - * jmethod.method_load_address = code_addr; - * jmethod.method_size = code_size; - * - * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED, (void*)&jmethod); - * iJIT_NotifyEvent(iJVM_EVENT_TYPE_SHUTDOWN, NULL); - * @endcode - * - * * Expected behavior: - * * If any iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED event overwrites an - * already reported method, then such a method becomes invalid and its - * memory region is treated as unloaded. VTune Amplifier displays the metrics - * collected by the method until it is overwritten. - * * If supplied line number information contains multiple source lines for - * the same assembly instruction (code location), then VTune Amplifier picks up - * the first line number. - * * Dynamically generated code can be associated with a module name. - * Use the iJIT_Method_Load_V2 structure.\n - * Clarification of some cases: - * * If you register a function with the same method ID multiple times, - * specifying different module names, then the VTune Amplifier picks up - * the module name registered first. If you want to distinguish the same - * function between different JIT engines, supply different method IDs for - * each function. Other symbolic information (for example, source file) - * can be identical. - * - * - **Analyze split functions** (multiple joint or disjoint code regions - * belonging to the same function) **including re-JIT** - * with potential overlapping of code regions in time, which is common in - * resource-limited environments. - * @code - * #include - * - * unsigned int method_id = iJIT_GetNewMethodID(); - * - * iJIT_Method_Load a = {0}; - * a.method_id = method_id; - * a.method_load_address = 0x100; - * a.method_size = 0x20; - * - * iJIT_Method_Load b = {0}; - * b.method_id = method_id; - * b.method_load_address = 0x200; - * b.method_size = 0x30; - * - * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED, (void*)&a); - * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED, (void*)&b); - * @endcode - * - * * Expected behaviour: - * * If a iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED event overwrites an - * already reported method, then such a method becomes invalid and - * its memory region is treated as unloaded. - * * All code regions reported with the same method ID are considered as - * belonging to the same method. 
Symbolic information (method name, - * source file name) will be taken from the first notification, and all - * subsequent notifications with the same method ID will be processed - * only for line number table information. So, the VTune Amplifier will map - * samples to a source line using the line number table from the current - * notification while taking the source file name from the very first one.\n - * Clarification of some cases:\n - * * If you register a second code region with a different source file - * name and the same method ID, then this information will be saved and - * will not be considered as an extension of the first code region, but - * VTune Amplifier will use the source file of the first code region and map - * performance metrics incorrectly. - * * If you register a second code region with the same source file as - * for the first region and the same method ID, then the source file will be - * discarded but VTune Amplifier will map metrics to the source file correctly. - * * If you register a second code region with a null source file and - * the same method ID, then provided line number info will be associated - * with the source file of the first code region. - * - * - **Explore inline functions** including multi-level hierarchy of - * nested inline methods which shows how performance metrics are distributed through them. - * @code - * #include - * - * // method_id parent_id - * // [-- c --] 3000 2000 - * // [---- d -----] 2001 1000 - * // [---- b ----] 2000 1000 - * // [------------ a ----------------] 1000 n/a - * - * iJIT_Method_Load a = {0}; - * a.method_id = 1000; - * - * iJIT_Method_Inline_Load b = {0}; - * b.method_id = 2000; - * b.parent_method_id = 1000; - * - * iJIT_Method_Inline_Load c = {0}; - * c.method_id = 3000; - * c.parent_method_id = 2000; - * - * iJIT_Method_Inline_Load d = {0}; - * d.method_id = 2001; - * d.parent_method_id = 1000; - * - * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED, (void*)&a); - * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED, (void*)&b); - * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED, (void*)&c); - * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED, (void*)&d); - * @endcode - * - * * Requirements: - * * Each inline (iJIT_Method_Inline_Load) method should be associated - * with two method IDs: one for itself; one for its immediate parent. - * * Address regions of inline methods of the same parent method cannot - * overlap each other. - * * Execution of the parent method must not be started until it and all - * its inline methods are reported. - * * Expected behaviour: - * * In case of nested inline methods an order of - * iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED events is not important. - * * If any event overwrites either inline method or top parent method, - * then the parent, including inline methods, becomes invalid and its memory - * region is treated as unloaded. - * - * **Life time of allocated data**\n - * The client sends an event notification to the agent with event-specific - * data, which is a structure. The pointers in the structure refer to memory - * allocated by the client, which responsible for releasing it. The pointers are - * used by the iJIT_NotifyEvent method to copy client's data in a trace file, - * and they are not used after the iJIT_NotifyEvent method returns. 
- */ - -/** - * @defgroup jitapi JIT Profiling - * @ingroup internal - * @{ - */ - -/** - * @brief Enumerator for the types of notifications - */ -typedef enum iJIT_jvm_event -{ - iJVM_EVENT_TYPE_SHUTDOWN = 2, /**<\brief Send this to shutdown the agent. - * Use NULL for event data. */ - - iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED = 13, /**<\brief Send when dynamic code is - * JIT compiled and loaded into - * memory by the JIT engine, but - * before the code is executed. - * Use iJIT_Method_Load as event - * data. */ -/** @cond exclude_from_documentation */ - iJVM_EVENT_TYPE_METHOD_UNLOAD_START, /**<\brief Send when compiled dynamic - * code is being unloaded from memory. - * Use iJIT_Method_Load as event data.*/ -/** @endcond */ - - iJVM_EVENT_TYPE_METHOD_UPDATE, /**<\brief Send to provide new content for - * a previously reported dynamic code. - * The previous content will be invalidated - * starting from the time of the notification. - * Use iJIT_Method_Load as event data but - * required fields are following: - * - method_id identify the code to update. - * - method_load_address specify start address - * within identified code range - * where update should be started. - * - method_size specify length of updated code - * range. */ - - - iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED, /**<\brief Send when an inline dynamic - * code is JIT compiled and loaded - * into memory by the JIT engine, - * but before the parent code region - * starts executing. - * Use iJIT_Method_Inline_Load as event data.*/ - -/** @cond exclude_from_documentation */ - iJVM_EVENT_TYPE_METHOD_UPDATE_V2, -/** @endcond */ - - iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED_V2 = 21, /**<\brief Send when a dynamic code is - * JIT compiled and loaded into - * memory by the JIT engine, but - * before the code is executed. - * Use iJIT_Method_Load_V2 as event data. */ - - iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED_V3 /**<\brief Send when a dynamic code is - * JIT compiled and loaded into - * memory by the JIT engine, but - * before the code is executed. - * Use iJIT_Method_Load_V3 as event data. */ -} iJIT_JVM_EVENT; - -/** - * @brief Enumerator for the agent's mode - */ -typedef enum _iJIT_IsProfilingActiveFlags -{ - iJIT_NOTHING_RUNNING = 0x0000, /**<\brief The agent is not running; - * iJIT_NotifyEvent calls will - * not be processed. */ - iJIT_SAMPLING_ON = 0x0001, /**<\brief The agent is running and - * ready to process notifications. */ -} iJIT_IsProfilingActiveFlags; - -/** - * @brief Description of a single entry in the line number information of a code region. - * @details A table of line number entries gives information about how the reported code region - * is mapped to source file. - * Intel(R) VTune(TM) Amplifier uses line number information to attribute - * the samples (virtual address) to a line number. \n - * It is acceptable to report different code addresses for the same source line: - * @code - * Offset LineNumber - * 1 2 - * 12 4 - * 15 2 - * 18 1 - * 21 30 - * - * VTune Amplifier constructs the following table using the client data - * - * Code subrange Line number - * 0-1 2 - * 1-12 4 - * 12-15 2 - * 15-18 1 - * 18-21 30 - * @endcode - */ -typedef struct _LineNumberInfo -{ - unsigned int Offset; /**<\brief Offset from the begining of the code region. */ - unsigned int LineNumber; /**<\brief Matching source line number offset (from beginning of source file). */ - -} *pLineNumberInfo, LineNumberInfo; - -/** - * @brief Enumerator for the code architecture. 
- */ -typedef enum _iJIT_CodeArchitecture -{ - iJIT_CA_NATIVE = 0, /**<\brief Native to the process architecture that is calling it. */ - - iJIT_CA_32, /**<\brief 32-bit machine code. */ - - iJIT_CA_64 /**<\brief 64-bit machine code. */ - -} iJIT_CodeArchitecture; - -#pragma pack(push, 8) - -/** - * @brief Description of a JIT-compiled method - * @details When you use the iJIT_Method_Load structure to describe - * the JIT compiled method, use iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED - * as an event type to report it. - */ -typedef struct _iJIT_Method_Load -{ - unsigned int method_id; /**<\brief Unique method ID. Cannot be 0. - * You must either use the API function - * iJIT_GetNewMethodID to get a valid and unique - * method ID, or else manage ID uniqueness - * and correct range by yourself.\n - * You must use the same method ID for all code - * regions of the same method, otherwise different - * method IDs specify different methods. */ - - char* method_name; /**<\brief The name of the method. It can be optionally - * prefixed with its class name and appended with - * its complete signature. Can't be NULL. */ - - void* method_load_address; /**<\brief The start virtual address of the method code - * region. If NULL, data provided with - * event are not accepted. */ - - unsigned int method_size; /**<\brief The code size of the method in memory. - * If 0, then data provided with the event are not - * accepted. */ - - unsigned int line_number_size; /**<\brief The number of entries in the line number - * table.0 if none. */ - - pLineNumberInfo line_number_table; /**<\brief Pointer to the line numbers info - * array. Can be NULL if - * line_number_size is 0. See - * LineNumberInfo Structure for a - * description of a single entry in - * the line number info array */ - - unsigned int class_id; /**<\brief This field is obsolete. */ - - char* class_file_name; /**<\brief Class name. Can be NULL.*/ - - char* source_file_name; /**<\brief Source file name. Can be NULL.*/ - -} *piJIT_Method_Load, iJIT_Method_Load; - -/** - * @brief Description of a JIT-compiled method - * @details When you use the iJIT_Method_Load_V2 structure to describe - * the JIT compiled method, use iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED_V2 - * as an event type to report it. - */ -typedef struct _iJIT_Method_Load_V2 -{ - unsigned int method_id; /**<\brief Unique method ID. Cannot be 0. - * You must either use the API function - * iJIT_GetNewMethodID to get a valid and unique - * method ID, or else manage ID uniqueness - * and correct range by yourself.\n - * You must use the same method ID for all code - * regions of the same method, otherwise different - * method IDs specify different methods. */ - - char* method_name; /**<\brief The name of the method. It can be optionally - * prefixed with its class name and appended with - * its complete signature. Can't be NULL. */ - - void* method_load_address; /**<\brief The start virtual address of the method code - * region. If NULL, then data provided with the - * event are not accepted. */ - - unsigned int method_size; /**<\brief The code size of the method in memory. - * If 0, then data provided with the event are not - * accepted. */ - - unsigned int line_number_size; /**<\brief The number of entries in the line number - * table. 0 if none. */ - - pLineNumberInfo line_number_table; /**<\brief Pointer to the line numbers info - * array. Can be NULL if - * line_number_size is 0. See - * LineNumberInfo Structure for a - * description of a single entry in - * the line number info array. 
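The LineNumberInfo description above explains how code offsets map to source lines, but none of the header's own examples attach the table to a load event. A hedged sketch of reporting a JIT-compiled buffer together with a two-entry line number table (the addresses, sizes and line numbers are made up for illustration):

    #include "jitprofiling.h"

    // Report one JIT-compiled method together with a two-entry offset -> line
    // table, following the iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED flow above.
    static void report_jitted_method(void* code_addr, unsigned int code_size) {
        if (iJIT_IsProfilingActive() != iJIT_SAMPLING_ON)
            return;  // no collector attached, nothing to report

        LineNumberInfo lines[2];
        lines[0].Offset = 0x00; lines[0].LineNumber = 10;  // first bytes -> line 10
        lines[1].Offset = 0x20; lines[1].LineNumber = 12;  // remainder   -> line 12

        iJIT_Method_Load jmethod = {0};
        jmethod.method_id           = iJIT_GetNewMethodID();
        jmethod.method_name         = (char*)"my_jitted_function";
        jmethod.source_file_name    = (char*)"my_source.cpp";
        jmethod.method_load_address = code_addr;
        jmethod.method_size         = code_size;
        jmethod.line_number_size    = 2;
        jmethod.line_number_table   = lines;

        iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED, (void*)&jmethod);
    }

A stack-allocated table is sufficient here because, as the lifetime note above states, the agent copies the event data before iJIT_NotifyEvent returns.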
*/ - - char* class_file_name; /**<\brief Class name. Can be NULL. */ - - char* source_file_name; /**<\brief Source file name. Can be NULL. */ - - char* module_name; /**<\brief Module name. Can be NULL. - The module name can be useful for distinguishing among - different JIT engines. VTune Amplifier will display - reported methods grouped by specific module. */ - -} *piJIT_Method_Load_V2, iJIT_Method_Load_V2; - -/** - * @brief Description of a JIT-compiled method - * @details The iJIT_Method_Load_V3 structure is the same as iJIT_Method_Load_V2 - * with a newly introduced 'arch' field that specifies architecture of the code region. - * When you use the iJIT_Method_Load_V3 structure to describe - * the JIT compiled method, use iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED_V3 - * as an event type to report it. - */ -typedef struct _iJIT_Method_Load_V3 -{ - unsigned int method_id; /**<\brief Unique method ID. Cannot be 0. - * You must either use the API function - * iJIT_GetNewMethodID to get a valid and unique - * method ID, or manage ID uniqueness - * and correct range by yourself.\n - * You must use the same method ID for all code - * regions of the same method, otherwise they are - * treated as regions of different methods. */ - - char* method_name; /**<\brief The name of the method. It can be optionally - * prefixed with its class name and appended with - * its complete signature. Cannot be NULL. */ - - void* method_load_address; /**<\brief The start virtual address of the method code - * region. If NULL, then data provided with the - * event are not accepted. */ - - unsigned int method_size; /**<\brief The code size of the method in memory. - * If 0, then data provided with the event are not - * accepted. */ - - unsigned int line_number_size; /**<\brief The number of entries in the line number - * table. 0 if none. */ - - pLineNumberInfo line_number_table; /**<\brief Pointer to the line numbers info - * array. Can be NULL if - * line_number_size is 0. See - * LineNumberInfo Structure for a - * description of a single entry in - * the line number info array. */ - - char* class_file_name; /**<\brief Class name. Can be NULL. */ - - char* source_file_name; /**<\brief Source file name. Can be NULL. */ - - char* module_name; /**<\brief Module name. Can be NULL. - * The module name can be useful for distinguishing among - * different JIT engines. VTune Amplifier will display - * reported methods grouped by specific module. */ - - iJIT_CodeArchitecture module_arch; /**<\brief Architecture of the method's code region. - * By default, it is the same as the process - * architecture that is calling it. - * For example, you can use it if your 32-bit JIT - * engine generates 64-bit code. - * - * If JIT engine reports both 32-bit and 64-bit types - * of methods then VTune Amplifier splits the methods - * with the same module name but with different - * architectures in two different modules. VTune Amplifier - * modifies the original name provided with a 64-bit method - * version by ending it with '(64)' */ - -} *piJIT_Method_Load_V3, iJIT_Method_Load_V3; - -/** - * @brief Description of an inline JIT-compiled method - * @details When you use the_iJIT_Method_Inline_Load structure to describe - * the JIT compiled method, use iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED - * as an event type to report it. - */ -typedef struct _iJIT_Method_Inline_Load -{ - unsigned int method_id; /**<\brief Unique method ID. Cannot be 0. 
- * You must either use the API function - * iJIT_GetNewMethodID to get a valid and unique - * method ID, or else manage ID uniqueness - * and correct range by yourself. */ - - unsigned int parent_method_id; /**<\brief Unique immediate parent's method ID. - * Cannot be 0. - * You must either use the API function - * iJIT_GetNewMethodID to get a valid and unique - * method ID, or else manage ID uniqueness - * and correct range by yourself. */ - - char* method_name; /**<\brief The name of the method. It can be optionally - * prefixed with its class name and appended with - * its complete signature. Can't be NULL. */ - - void* method_load_address; /** <\brief The virtual address on which the method - * is inlined. If NULL, then data provided with - * the event are not accepted. */ - - unsigned int method_size; /**<\brief The code size of the method in memory. - * If 0, then data provided with the event are not - * accepted. */ - - unsigned int line_number_size; /**<\brief The number of entries in the line number - * table. 0 if none. */ - - pLineNumberInfo line_number_table; /**<\brief Pointer to the line numbers info - * array. Can be NULL if - * line_number_size is 0. See - * LineNumberInfo Structure for a - * description of a single entry in - * the line number info array */ - - char* class_file_name; /**<\brief Class name. Can be NULL.*/ - - char* source_file_name; /**<\brief Source file name. Can be NULL.*/ - -} *piJIT_Method_Inline_Load, iJIT_Method_Inline_Load; - -/** @cond exclude_from_documentation */ -/** - * @brief Description of a segment type - * @details Use the segment type to specify a type of data supplied - * with the iJVM_EVENT_TYPE_METHOD_UPDATE_V2 event to be applied to - * a certain code trace. - */ -typedef enum _iJIT_SegmentType -{ - iJIT_CT_UNKNOWN = 0, - - iJIT_CT_CODE, /**<\brief Executable code. */ - - iJIT_CT_DATA, /**<\brief Data (not executable code). - * VTune Amplifier uses the format string - * (see iJIT_Method_Update) to represent - * this data in the VTune Amplifier GUI */ - - iJIT_CT_KEEP, /**<\brief Use the previous markup for the trace. - * Can be used for the following - * iJVM_EVENT_TYPE_METHOD_UPDATE_V2 events, - * if the type of the previously reported segment - * type is the same. */ - iJIT_CT_EOF -} iJIT_SegmentType; - -/** - * @brief Description of a dynamic update of the content within JIT-compiled method - * @details The JIT engine may generate the methods that are updated at runtime - * partially by mixed (data + executable code) content. When you use the iJIT_Method_Update - * structure to describe the update of the content within a JIT-compiled method, - * use iJVM_EVENT_TYPE_METHOD_UPDATE_V2 as an event type to report it. - * - * On the first Update event, VTune Amplifier copies the original code range reported by - * the iJVM_EVENT_TYPE_METHOD_LOAD event, then modifies it with the supplied bytes and - * adds the modified range to the original method. For next update events, VTune Amplifier - * does the same but it uses the latest modified version of a code region for update. - * Eventually, VTune Amplifier GUI displays multiple code ranges for the method reported by - * the iJVM_EVENT_TYPE_METHOD_LOAD event. - * Notes: - * - Multiple update events with different types for the same trace are allowed - * but they must be reported for the same code ranges. 
- * Example, - * @code - * [-- data---] Allowed - * [-- code --] Allowed - * [code] Ignored - * [-- data---] Allowed - * [-- code --] Allowed - * [------------ trace ---------] - * @endcode - * - The types of previously reported events can be changed but they must be reported - * for the same code ranges. - * Example, - * @code - * [-- data---] Allowed - * [-- code --] Allowed - * [-- data---] Allowed - * [-- code --] Allowed - * [------------ trace ---------] - * @endcode - */ - -typedef struct _iJIT_Method_Update -{ - void* load_address; /**<\brief Start address of the update within a method */ - - unsigned int size; /**<\brief The update size */ - - iJIT_SegmentType type; /**<\brief Type of the update */ - - const char* data_format; /**<\brief C string that contains a format string - * that follows the same specifications as format in printf. - * The format string is used for iJIT_CT_CODE only - * and cannot be NULL. - * Format can be changed on the fly. */ -} *piJIT_Method_Update, iJIT_Method_Update; - -/** @endcond */ - -#pragma pack(pop) - -/** @cond exclude_from_documentation */ -#ifdef __cplusplus -extern "C" { -#endif /* __cplusplus */ - -#ifndef JITAPI_CDECL -# if defined WIN32 || defined _WIN32 -# define JITAPI_CDECL __cdecl -# else /* defined WIN32 || defined _WIN32 */ -# if defined _M_IX86 || defined __i386__ -# define JITAPI_CDECL __attribute__ ((cdecl)) -# else /* _M_IX86 || __i386__ */ -# define JITAPI_CDECL /* actual only on x86_64 platform */ -# endif /* _M_IX86 || __i386__ */ -# endif /* defined WIN32 || defined _WIN32 */ -#endif /* JITAPI_CDECL */ - -#define JITAPI JITAPI_CDECL -/** @endcond */ - -/** - * @brief Generates a new unique method ID. - * - * You must use this API to obtain unique and valid method IDs for methods or - * traces reported to the agent if you don't have your own mechanism to generate - * unique method IDs. - * - * @return a new unique method ID. When out of unique method IDs, this API - * returns 0, which is not an accepted value. - */ -unsigned int JITAPI iJIT_GetNewMethodID(void); - -/** - * @brief Returns the current mode of the agent. - * - * @return iJIT_SAMPLING_ON, indicating that agent is running, or - * iJIT_NOTHING_RUNNING if no agent is running. - */ -iJIT_IsProfilingActiveFlags JITAPI iJIT_IsProfilingActive(void); - -/** - * @brief Reports infomation about JIT-compiled code to the agent. - * - * The reported information is used to attribute samples obtained from any - * Intel(R) VTune(TM) Amplifier collector. This API needs to be called - * after JIT compilation and before the first entry into the JIT-compiled - * code. - * - * @param[in] event_type - type of the data sent to the agent - * @param[in] EventSpecificData - pointer to event-specific data - * - * @returns 1 on success, otherwise 0. - */ -int JITAPI iJIT_NotifyEvent(iJIT_JVM_EVENT event_type, void *EventSpecificData); - -#ifdef __cplusplus -} -#endif /* __cplusplus */ -/** @endcond */ - -/** @} jitapi group */ - -#endif /* __JITPROFILING_H__ */ diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/nchw_pooling.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/nchw_pooling.cpp deleted file mode 100644 index ef4c42bac..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/nchw_pooling.cpp +++ /dev/null @@ -1,317 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. 
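iJVM_EVENT_TYPE_METHOD_UPDATE, per the event enumerator above, reuses iJIT_Method_Load but only requires the method ID, the start address of the changed range and its length. A minimal hedged sketch of reporting such an update (identifiers are illustrative):

    #include "jitprofiling.h"

    // Tell the agent that part of an already reported method was re-generated.
    // Per the enumerator documentation, only method_id, method_load_address and
    // method_size are required for iJVM_EVENT_TYPE_METHOD_UPDATE.
    static void report_code_patch(unsigned int method_id, void* patched_addr,
                                  unsigned int patched_size) {
        iJIT_Method_Load update = {0};
        update.method_id           = method_id;     // identifies the code to update
        update.method_load_address = patched_addr;  // start of the updated range
        update.method_size         = patched_size;  // length of the updated range
        iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_UPDATE, (void*)&update);
    }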
-* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include -#include - -#include "c_types_map.hpp" -#include "type_helpers.hpp" -#include "math_utils.hpp" -#include "mkldnn_thread.hpp" -#include "nstl.hpp" - -#include "nchw_pooling.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -template -void nchw_pooling_fwd_t::execute_forward( - const exec_ctx_t &ctx) const { - using namespace alg_kind; - - auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); - auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); - auto ws = CTX_OUT_MEM(unsigned char *, MKLDNN_ARG_WORKSPACE); - - const memory_desc_wrapper ws_d(pd()->workspace_md()); - const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef; - - const int MB = pd()->MB(); - const int C = pd()->C(); - const int OD = pd()->OD(); - const int OH = pd()->OH(); - const int OW = pd()->OW(); - const int ID = pd()->ID(); - const int IH = pd()->IH(); - const int IW = pd()->IW(); - const int KD = pd()->KD(); - const int KH = pd()->KH(); - const int KW = pd()->KW(); - const int SD = pd()->KSD(); - const int SH = pd()->KSH(); - const int SW = pd()->KSW(); - const int padF = pd()->padFront(); - const int padT = pd()->padT(); - const int padL = pd()->padL(); - - auto alg = pd()->desc()->alg_kind; - - auto apply_offset = [=](int index, int offset) { - return (index > offset) ? index - offset : 0; - }; - - auto set_ws = [=](int mb, int c, int od, int oh, int ow, int value) { - if (ws) { - assert(ws_dt == data_type::u8 || ws_dt == data_type::s32); - size_t ws_offset - = (size_t)OW * OH * OD * C * mb - + (size_t)OW * OH * OD * c - + (size_t)OW * OH * od - + (size_t)OW * oh - + (size_t)ow; - if (ws_dt == data_type::u8) { - assert(0 <= value && value <= 255); - ws[ws_offset] = value; - } else - reinterpret_cast(ws)[ws_offset] = value; - } - }; - - auto ker_max = [=](data_t *d, int mb, int c, int od, int oh, int ow) { - for (int kd = 0; kd < KD; ++kd) { - for (int kh = 0; kh < KH; ++kh) { - for (int kw = 0; kw < KW; ++kw) { - const int id = od * SD - padF + kd; - const int ih = oh * SH - padT + kh; - const int iw = ow * SW - padL + kw; - - if (id < 0 || id >= ID) continue; - if (ih < 0 || ih >= IH) continue; - if (iw < 0 || iw >= IW) continue; - - auto src_offset - = (size_t)IW * IH * ID * C * mb - + (size_t)IW * IH * ID * c - + (size_t)IW * IH * id - + (size_t)IW * ih - + (size_t)iw; - auto s = src[src_offset]; - if (s > d[0]) { - d[0] = s; - set_ws(mb, c, od, oh, ow, kd*KH*KW + kh*KW + kw); - } - } - } - } - }; - - auto ker_avg = [=](data_t *d, int mb, int c, int od, int oh, int ow) { - auto id_start = apply_offset(od*SD, padF); - auto ih_start = apply_offset(oh*SH, padT); - auto iw_start = apply_offset(ow*SW, padL); - auto id_end = nstl::min(od*SD - padF + KD, ID); - auto ih_end = nstl::min(oh*SH - padT + KH, IH); - auto iw_end = nstl::min(ow*SW - padL + KW, IW); - - auto num_summands = (alg == pooling_avg_include_padding) ? 
KD*KW*KH - : (id_end - id_start)*(ih_end - ih_start)*(iw_end - iw_start); - - for (int id = id_start; id < id_end; ++id) { - for (int ih = ih_start; ih < ih_end; ++ih) { - for (int iw = iw_start; iw < iw_end; ++iw) { - auto src_offset - = (size_t)IW * IH * ID * C * mb - + (size_t)IW * IH * ID * c - + (size_t)IW * IH * id - + (size_t)IW * ih - + (size_t)iw; - d[0] += src[src_offset]; - } - } - } - - d[0] = math::out_round((float)d[0] / num_summands); - }; - - - if (pd()->desc()->alg_kind == pooling_max) { - parallel_nd(MB, C, OD, OH, OW, - [&](int mb, int c, int od, int oh, int ow) { - size_t dst_offset - = (size_t)OW * OH * OD * C * mb - + (size_t)OW * OH * OD * c - + (size_t)OW * OH * od - + (size_t)OW * oh - + (size_t)ow; - data_t *d = &dst[dst_offset]; - d[0] = nstl::numeric_limits::lowest(); - set_ws(mb, c, od, oh, ow, 0); - ker_max(d, mb, c, od, oh, ow); - }); - } else { - parallel_nd(MB, C, OD, OH, OW, - [&](int mb, int c, int od, int oh, int ow) { - size_t dst_offset - = (size_t)OW * OH * OD * C * mb - + (size_t)OW * OH * OD * c - + (size_t)OW * OH * od - + (size_t)OW * oh - + (size_t)ow; - data_t *d = &dst[dst_offset]; - d[0] = 0; - ker_avg(d, mb, c, od, oh, ow); - }); - } -} - -template -void nchw_pooling_bwd_t::execute_backward( - const exec_ctx_t &ctx) const { - using namespace alg_kind; - - auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); - auto ws = CTX_IN_MEM(const unsigned char *, MKLDNN_ARG_WORKSPACE); - auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); - - const memory_desc_wrapper ws_d(pd()->workspace_md()); - - const int MB = pd()->MB(); - const int C = pd()->C(); - const int OD = pd()->OD(); - const int OH = pd()->OH(); - const int OW = pd()->OW(); - const int ID = pd()->ID(); - const int IH = pd()->IH(); - const int IW = pd()->IW(); - const int KD = pd()->KD(); - const int KH = pd()->KH(); - const int KW = pd()->KW(); - const int SD = pd()->KSD(); - const int SH = pd()->KSH(); - const int SW = pd()->KSW(); - const int padF = pd()->padFront(); - const int padT = pd()->padT(); - const int padL = pd()->padL(); - - const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5; - - auto alg = pd()->desc()->alg_kind; - - auto apply_offset = [=](int index, int offset) { - return (index > offset) ? index - offset : 0; - }; - - auto ker_zero = [=](int mb, int c) { - size_t diff_src_offset = (size_t)mb*C*ID*IH*IW + (size_t)c*ID*IH*IW; - for (int id = 0; id < ID; ++id) { - for (int ih = 0; ih < IH; ++ih) { - for (int iw = 0; iw < IW; ++iw) { - diff_src[diff_src_offset++] = 0; - } - } - } - }; - - auto ker_max = [=](const data_t *d, int mb, int c, int od, int oh, int ow) { - auto b_c = ws_d.blocking_desc().inner_nblks == 0 - ? 1 : ws_d.blocking_desc().inner_blks[0]; - auto ws_offset = is_3d - ? ws_d.blk_off(mb, c / b_c, od, oh, ow) + c % b_c - : ws_d.blk_off(mb, c / b_c, oh, ow) + c % b_c; - - const int index = ws_d.data_type() == data_type::u8 - ? (int)ws[ws_offset] : ((const int *)ws)[ws_offset]; - const int kw = index % KW; - const int kh = (index / KW) % KH; - const int kd = (index / KW) / KH; - - const int id = od * SD - padF + kd; - const int ih = oh * SH - padT + kh; - const int iw = ow * SW - padL + kw; - - // If padding area could fit the kernel, - // then input displacement would be out of bounds. - // No need to back propagate there as padding is - // virtual in pooling_max case. 
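ker_avg in the forward pass above divides the window sum either by the full kernel volume (pooling_avg_include_padding) or by the number of taps that actually land inside the input (pooling_avg_exclude_padding). A reduced, self-contained 1-D illustration of that divisor difference at a padded border (the input values are made up):

    #include <algorithm>
    #include <cstdio>

    // One output element of 1-D average pooling at the left border: kernel KW=3,
    // stride 1, left padding 1. include_padding divides by the kernel size,
    // exclude_padding by the number of in-bounds taps, as ker_avg does above.
    int main() {
        const float src[4] = {2.f, 4.f, 6.f, 8.f};
        const int IW = 4, KW = 3, SW = 1, padL = 1, ow = 0;

        const int iw_start = std::max(ow * SW - padL, 0);       // clamp to the input
        const int iw_end   = std::min(ow * SW - padL + KW, IW);

        float sum = 0.f;
        for (int iw = iw_start; iw < iw_end; ++iw)
            sum += src[iw];

        std::printf("include_padding: %g, exclude_padding: %g\n",
                    sum / KW, sum / (iw_end - iw_start));        // prints 2 and 3
        return 0;
    }

The backward kernel applies the same divisor when spreading the gradient back over the window.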
- if (id < 0 || id >= ID) - return; - if (ih < 0 || ih >= IH) - return; - if (iw < 0 || iw >= IW) - return; - - size_t diff_src_offset = - (size_t)mb*C*ID*IH*IW + (size_t)c*ID*IH*IW + (size_t)id*IH*IW - + (size_t)ih*IW + (size_t)iw; - diff_src[diff_src_offset] += d[0]; - }; - - auto ker_avg = [=](const data_t *d, int mb, int c, int od, int oh, int ow) { - auto id_start = apply_offset(od*SD, padF); - auto ih_start = apply_offset(oh*SH, padT); - auto iw_start = apply_offset(ow*SW, padL); - auto id_end = nstl::min(od*SD - padF + KD, ID); - auto ih_end = nstl::min(oh*SH - padT + KH, IH); - auto iw_end = nstl::min(ow*SW - padL + KW, IW); - - size_t num_summands = (alg == pooling_avg_include_padding) - ? (size_t)KW*KH*KD - : (size_t)(id_end - id_start)*(ih_end - ih_start) - *(iw_end - iw_start); - - for (int id = id_start; id < id_end; ++id) { - for (int ih = ih_start; ih < ih_end; ++ih) { - for (int iw = iw_start; iw < iw_end; ++iw) { - size_t diff_src_offset = (size_t)mb*C*ID*IH*IW - + (size_t)c*ID*IH*IW + (size_t)id*IH*IW - + (size_t)ih*IW + (size_t)iw; - diff_src[diff_src_offset] += d[0] / num_summands; - } - } - } - }; - - if (pd()->desc()->alg_kind == pooling_max) { - parallel_nd(MB, C, [&](int mb, int c) { - size_t diff_dst_offset = (size_t)mb*C*OD*OH*OW - + (size_t)c*OD*OH*OW; - ker_zero(mb, c); - for (int od = 0; od < OD; ++od) { - for (int oh = 0; oh < OH; ++oh) { - for (int ow = 0; ow < OW; ++ow) { - const data_t *d = &diff_dst[diff_dst_offset++]; - ker_max(d, mb, c, od, oh, ow); - } - } - } - }); - } else { - parallel_nd(MB, C, [&](int mb, int c) { - size_t diff_dst_offset = (size_t)mb*C*OD*OH*OW - + (size_t)c*OD*OH*OW; - ker_zero(mb, c); - for (int od = 0; od < OD; ++od) { - for (int oh = 0; oh < OH; ++oh) { - for (int ow = 0; ow < OW; ++ow) { - const data_t *d = &diff_dst[diff_dst_offset++]; - ker_avg(d, mb, c, od, oh, ow); - } - } - } - }); - } -} - -template struct nchw_pooling_fwd_t; -template struct nchw_pooling_bwd_t; - -} -} -} - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/nchw_pooling.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/nchw_pooling.hpp deleted file mode 100644 index bbdd04f6b..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/nchw_pooling.hpp +++ /dev/null @@ -1,147 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#ifndef CPU_NCHW_POOLING_HPP -#define CPU_NCHW_POOLING_HPP - -#include - -#include "c_types_map.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -#include "cpu_pooling_pd.hpp" -#include "cpu_primitive.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -template -struct nchw_pooling_fwd_t: public cpu_primitive_t { - struct pd_t: public cpu_pooling_fwd_pd_t { - using cpu_pooling_fwd_pd_t::cpu_pooling_fwd_pd_t; - - DECLARE_COMMON_PD_T("nchw_pooling:any", nchw_pooling_fwd_t); - - status_t init() { - const format_tag_t desired_fmt_tag = - ndims() == 4 ? format_tag::nchw : format_tag::ncdhw; - - bool ok = true - && set_default_params() == status::success - && is_fwd() - && utils::one_of(desc()->alg_kind, alg_kind::pooling_max, - alg_kind::pooling_avg_include_padding, - alg_kind::pooling_avg_exclude_padding) - && !has_zero_dim_memory() - && utils::everyone_is(data_type, src_md()->data_type, - dst_md()->data_type) - && attr()->has_default_values() - && memory_desc_matches_tag(*src_md(), desired_fmt_tag) - && memory_desc_matches_tag(*dst_md(), desired_fmt_tag); - if (!ok) return status::unimplemented; - - bool is_training = desc_.prop_kind == prop_kind::forward_training; - if (desc()->alg_kind == alg_kind::pooling_max && is_training) - init_default_ws(); - - return status::success; - } - }; - - nchw_pooling_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} - typedef typename prec_traits::type data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_forward(ctx); - return status::success; - } - -private: - void execute_forward(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } -}; - -template -struct nchw_pooling_bwd_t: public cpu_primitive_t { - struct pd_t: public cpu_pooling_bwd_pd_t { - using cpu_pooling_bwd_pd_t::cpu_pooling_bwd_pd_t; - - DECLARE_COMMON_PD_T("nchw:any", nchw_pooling_bwd_t); - - status_t init() { - const format_tag_t desired_fmt_tag = - ndims() == 4 ? 
format_tag::nchw : format_tag::ncdhw; - - bool ok = true - && set_default_params() == status::success - && !is_fwd() - && utils::one_of(desc()->alg_kind, alg_kind::pooling_max, - alg_kind::pooling_avg_include_padding, - alg_kind::pooling_avg_exclude_padding) - && !has_zero_dim_memory() - && utils::everyone_is(data_type, - diff_dst_md()->data_type, - diff_src_md()->data_type) - && attr()->has_default_values() - && memory_desc_matches_tag(*diff_dst_md(), desired_fmt_tag) - && memory_desc_matches_tag(*diff_src_md(), desired_fmt_tag); - if (!ok) return status::unimplemented; - - if (desc()->alg_kind == alg_kind::pooling_max) { - bool ws_ok = true - && hint_fwd_pd_ - && hint_fwd_pd_->workspace_md(); - if (!ws_ok) - return status::unimplemented; - - const auto &ws_blk = - hint_fwd_pd_->workspace_md()->format_desc.blocking; - ws_ok = ws_ok - && ws_blk.inner_nblks < 1 - && IMPLICATION(ws_blk.inner_nblks == 1, - ws_blk.inner_idxs[0] == 1); - if (!ws_ok) - return status::unimplemented; - - ws_md_ = *hint_fwd_pd_->workspace_md(); - } - - return status::success; - } - }; - - nchw_pooling_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {} - typedef typename prec_traits::type data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_backward(ctx); - return status::success; - } - -private: - void execute_backward(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } -}; - -} -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ncsp_batch_normalization.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ncsp_batch_normalization.cpp deleted file mode 100644 index c0e93fefe..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/ncsp_batch_normalization.cpp +++ /dev/null @@ -1,382 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#include -#include - -#include "c_types_map.hpp" -#include "type_helpers.hpp" - -#include "cpu_batch_normalization_utils.hpp" -#include "jit_generator.hpp" - -#include "ncsp_batch_normalization.hpp" - -// clang 6 and 7 generate incorrect code with OMP_SIMD in some particular cases -#if (defined __clang_major__) && (__clang_major__ >= 6) -#define SAFE_TO_USE_OMP_SIMD 0 -#else -#define SAFE_TO_USE_OMP_SIMD 1 -#endif - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace memory_tracking::names; - -void ncsp_batch_normalization_fwd_t::execute_forward( - const exec_ctx_t &ctx) const { - const bool calculate_stats = !pd()->stats_is_src(); - const bool save_stats = pd()->is_training(); - const bool is_training = pd()->is_training(); - const bool fuse_bn_relu = pd()->fuse_bn_relu(); - - auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); - auto scaleshift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT); - - auto scratchpad = this->scratchpad(ctx); - auto *ws_reduce = scratchpad.get(key_bnorm_reduction); - - data_t *mean, *variance; - if (!calculate_stats) { - mean = const_cast( - CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN)); - variance = const_cast( - CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE)); - } else { - if (save_stats) { - mean = CTX_OUT_MEM(data_t *, MKLDNN_ARG_MEAN); - variance = CTX_OUT_MEM(data_t *, MKLDNN_ARG_VARIANCE); - } else { - mean = scratchpad.get(key_bnorm_tmp_mean); - variance = scratchpad.get(key_bnorm_tmp_var); - } - } - - auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); - auto ws = CTX_OUT_MEM(uint8_t *, MKLDNN_ARG_WORKSPACE); - - const float eps = pd()->desc()->batch_norm_epsilon; - const bool use_scaleshift = pd()->use_scaleshift(); - const bool with_relu = pd()->with_relu_post_op(); - auto maybe_post_op - = [&](data_t res) { return (with_relu && res < 0) ? 0 : res; }; - const bool has_spatial = utils::one_of(pd()->ndims(), 4, 5); - dim_t SP = (has_spatial) ? pd()->H() * pd()->W() * pd()->D() : 1; - dim_t N = pd()->MB(); - dim_t C = pd()->C(); - - int nthr = mkldnn_get_max_threads(); - size_t l3_size_ = get_cache_size(3, true) * nthr / 2; - size_t data_size = N * C * SP * sizeof(data_t); - bool do_blocking = (data_size >= l3_size_ / 2 && l3_size_ > 0); - - parallel(0, [&](const int ithr, const int nthr) { - int C_ithr = 0, C_nthr = 0; - int N_ithr = 0, N_nthr = 0; - int S_ithr = 0, S_nthr = 0; - - dim_t C_blk_gl_s = 0, C_blk_gl_e = 0, C_blk_s = 0, C_blk_e = 0; - dim_t N_s = 0, N_e = 0; - dim_t S_s = 0, S_e = 0; - - dim_t C_blks_per_iter = 1; - int64_t iters = 1; - - if (do_blocking) { - size_t working_set_size = N * SP * sizeof(data_t); - bnorm_utils::cache_balance( - working_set_size, C, C_blks_per_iter, iters); - } else - C_blks_per_iter = C; - int64_t last_iter_blks = C - (iters - 1) * C_blks_per_iter; - bool spatial_thr_allowed - = bnorm_utils::thread_balance(do_blocking, true, ithr, nthr, N, - C_blks_per_iter, SP, C_ithr, C_nthr, C_blk_s, C_blk_e, - N_ithr, N_nthr, N_s, N_e, S_ithr, S_nthr, S_s, S_e); - balance211(C_blks_per_iter, nthr, ithr, C_blk_gl_s, C_blk_gl_e); - int SP_N_ithr = N_ithr * S_nthr + S_ithr; - int SP_N_nthr = N_nthr * S_nthr; - for (int64_t it = 0; it < iters; ++it) { - if (it == iters - 1 && iters > 1) { - // On the last iteration the access pattern to ws_reduce - // might change (due to re-balance on C). So sync the - // threads if they are not synced by the algorithm. 
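// A serialized sketch of the two-pass statistics computation used above
// (hypothetical names; the channel blocking and barriers are collapsed).
// Each worker accumulates a per-channel partial sum into a ws_reduce-style
// buffer, the partials are merged after a barrier to get the mean, and a
// second pass over (x - mean)^2 produces the variance the same way.
#include <cstddef>
#include <vector>

// x is laid out [worker][element] for a single channel.
void channel_mean_variance(const std::vector<std::vector<float>> &x,
        float &mean, float &variance) {
    std::size_t total = 0;
    std::vector<float> partial(x.size(), 0.f); // stands in for ws_reduce
    for (std::size_t t = 0; t < x.size(); ++t) {
        for (float v : x[t]) partial[t] += v;
        total += x[t].size();
    }
    mean = 0.f; // reduction that follows the first barrier
    for (float p : partial) mean += p;
    mean /= total;
    for (std::size_t t = 0; t < x.size(); ++t) { // second pass
        partial[t] = 0.f;
        for (float v : x[t]) partial[t] += (v - mean) * (v - mean);
    }
    variance = 0.f;
    for (float p : partial) variance += p;
    variance /= total;
}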
- if (SP_N_nthr == 1 && mkldnn_thr_syncable()) - mkldnn_thr_barrier(); - - S_s = S_e = C_blk_s = C_blk_e = N_s = N_e = 0; - spatial_thr_allowed = bnorm_utils::thread_balance(do_blocking, - spatial_thr_allowed, ithr, nthr, N, last_iter_blks, SP, - C_ithr, C_nthr, C_blk_s, C_blk_e, N_ithr, N_nthr, N_s, - N_e, S_ithr, S_nthr, S_s, S_e); - balance211(last_iter_blks, nthr, ithr, C_blk_gl_s, C_blk_gl_e); - SP_N_ithr = N_ithr * S_nthr + S_ithr; - SP_N_nthr = N_nthr * S_nthr; - } - size_t C_off = it * C_blks_per_iter; - // On the last iteration the access pattern to ws_reduce - // might change (due to re-balance on C). Since sync is not always - // possible (in case of TBB) use different parts of ws for each - // iteration if threads are not synced by the algorithm. - size_t ws_iter_off = (mkldnn_thr_syncable() ? 0 : 1) * C_off; - - if (calculate_stats) { - data_t *mean_blk = mean + C_off; - data_t *variance_blk = variance + C_off; - for (dim_t c = C_blk_s; c < C_blk_e; c++) { - size_t off = (c + C_off) * SP; - data_t sum = 0; - for (dim_t n = N_s; n < N_e; ++n) - PRAGMA_OMP_SIMD(reduction(+ : sum)) - for (dim_t sp = S_s; sp < S_e; ++sp) { - sum += src[off + n * C * SP + sp]; - } - ws_reduce[ws_iter_off + SP_N_ithr * C_blks_per_iter + c] - = sum; - } - - if (SP_N_nthr > 1) mkldnn_thr_barrier(); - - for (dim_t c = C_blk_gl_s; c < C_blk_gl_e; c++) { - mean_blk[c] = 0.; - for (dim_t n = 0; n < SP_N_nthr; n++) - mean_blk[c] += ws_reduce[ws_iter_off - + n * C_blks_per_iter + c]; - mean_blk[c] /= (N * SP); - } - - if (SP_N_nthr > 1) mkldnn_thr_barrier(); - - for (dim_t c = C_blk_s; c < C_blk_e; c++) { - size_t off = c + C_off; - data_t sum = 0.; - for (dim_t n = N_s; n < N_e; ++n) - PRAGMA_OMP_SIMD(reduction(+ : sum)) - for (dim_t sp = S_s; sp < S_e; ++sp) { - data_t m = src[off * SP + n * C * SP + sp] - - mean[off]; - sum += m * m; - } - ws_reduce[ws_iter_off + SP_N_ithr * C_blks_per_iter + c] - = sum; - } - - if (SP_N_nthr > 1) mkldnn_thr_barrier(); - - for (dim_t c = C_blk_gl_s; c < C_blk_gl_e; c++) { - variance_blk[c] = 0.; - for (dim_t n = 0; n < SP_N_nthr; n++) - variance_blk[c] += ws_reduce[ws_iter_off - + n * C_blks_per_iter + c]; - variance_blk[c] /= (N * SP); - } - - if (SP_N_nthr > 1) mkldnn_thr_barrier(); - } - - for (dim_t c = C_blk_s; c < C_blk_e; c++) { - size_t off = c + C_off; - data_t sqrt_variance - = static_cast(sqrtf(variance[off] + eps)); - data_t sm = (use_scaleshift ? scaleshift[off] : 1.0f) / sqrt_variance; - data_t sv = use_scaleshift ? 
scaleshift[C + off] : 0; - for (dim_t n = N_s; n < N_e; ++n) -#if SAFE_TO_USE_OMP_SIMD - PRAGMA_OMP_SIMD() -#endif - for (dim_t sp = S_s; sp < S_e; ++sp) { - size_t d_off = off * SP + n * C * SP + sp; - data_t bn_res - = sm * (src[d_off] - mean[off]) + sv; - if (fuse_bn_relu) { - if (bn_res <= 0) { - bn_res = 0; - if (is_training) - ws[d_off] = 0; - } else { - if (is_training) - ws[d_off] = 1; - } - } - dst[d_off] = maybe_post_op(bn_res); - } - } - } - }); -} - -void ncsp_batch_normalization_bwd_t::execute_backward( - const exec_ctx_t &ctx) const { - auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); - auto mean = CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN); - auto variance = CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE); - auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); - auto scaleshift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT); - auto ws = CTX_IN_MEM(const uint8_t *, MKLDNN_ARG_WORKSPACE); - - auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); - auto diff_scaleshift = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SCALE_SHIFT); - - auto scratchpad = this->scratchpad(ctx); - auto *ws_reduce = scratchpad.get(key_bnorm_reduction); - - if (diff_scaleshift == nullptr) - diff_scaleshift = scratchpad.get(key_bnorm_tmp_diff_ss); - - const bool has_spatial = utils::one_of(pd()->ndims(), 4, 5); - dim_t SP = (has_spatial) ? pd()->H() * pd()->W() * pd()->D() : 1; - dim_t C = pd()->C(), N = pd()->MB(); - const bool use_scaleshift = pd()->use_scaleshift(); - const float eps = pd()->desc()->batch_norm_epsilon; - const bool calculate_diff_stats = !pd()->use_global_stats(); - const bool fuse_bn_relu = pd()->fuse_bn_relu(); - - int nthr = mkldnn_get_max_threads(); - size_t l3_size_ = get_cache_size(3, true) * nthr / 2; - size_t data_size = N * C * SP * sizeof(data_t); - bool do_blocking = (data_size >= l3_size_ / 2 && l3_size_ > 0); - - parallel(0, [&](const int ithr, const int nthr) { - int C_ithr = 0, C_nthr = 0; - int N_ithr = 0, N_nthr = 0; - int S_ithr = 0, S_nthr = 0; - - dim_t C_blk_gl_s = 0, C_blk_gl_e = 0, C_blk_s = 0, C_blk_e = 0; - dim_t N_s = 0, N_e = 0; - dim_t S_s = 0, S_e = 0; - - dim_t C_blks_per_iter = 1; - int64_t iters = 1; - - if (do_blocking) { - size_t working_set_size = 2 * N * SP * sizeof(data_t); - bnorm_utils::cache_balance( - working_set_size, C, C_blks_per_iter, iters); - } else - C_blks_per_iter = C; - int64_t last_iter_blks = C - (iters - 1) * C_blks_per_iter; - bool spatial_thr_allowed - = bnorm_utils::thread_balance(do_blocking, true, ithr, nthr, N, - C_blks_per_iter, SP, C_ithr, C_nthr, C_blk_s, C_blk_e, - N_ithr, N_nthr, N_s, N_e, S_ithr, S_nthr, S_s, S_e); - balance211(C_blks_per_iter, nthr, ithr, C_blk_gl_s, C_blk_gl_e); - int SP_N_ithr = N_ithr * S_nthr + S_ithr; - int SP_N_nthr = N_nthr * S_nthr; - - for (int64_t it = 0; it < iters; ++it) { - if (it == iters - 1 && iters > 1) { - // On the last iteration the access pattern to ws_reduce - // might change (due to re-balance on C). So sync the - // threads if they are not synced by the algorithm. 
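// The forward loop above folds scale/shift into sm = gamma / sqrt(var + eps)
// and sv = beta, and when ReLU is fused it records a 0/1 mask in the
// workspace; the backward code below uses that mask to zero diff_dst.
// A minimal per-channel sketch (hypothetical names, post-ops omitted):
#include <cmath>
#include <cstddef>
#include <cstdint>

void bnorm_fwd_channel(const float *x, float *y, uint8_t *ws, std::size_t n,
        float mean, float var, float gamma, float beta, float eps,
        bool fuse_relu, bool is_training) {
    const float sm = gamma / std::sqrt(var + eps);
    const float sv = beta;
    for (std::size_t i = 0; i < n; ++i) {
        float v = sm * (x[i] - mean) + sv;
        if (fuse_relu) {
            if (is_training) ws[i] = v > 0 ? 1 : 0;
            if (v <= 0) v = 0;
        }
        y[i] = v;
    }
}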
- if (SP_N_nthr == 1 && mkldnn_thr_syncable()) - mkldnn_thr_barrier(); - - C_blk_s = C_blk_e = N_s = N_e = 0; - spatial_thr_allowed = bnorm_utils::thread_balance(do_blocking, - spatial_thr_allowed, ithr, nthr, N, last_iter_blks, SP, - C_ithr, C_nthr, C_blk_s, C_blk_e, N_ithr, N_nthr, N_s, - N_e, S_ithr, S_nthr, S_s, S_e); - balance211(last_iter_blks, nthr, ithr, C_blk_gl_s, C_blk_gl_e); - SP_N_ithr = N_ithr * S_nthr + S_ithr; - SP_N_nthr = N_nthr * S_nthr; - } - size_t C_off = it * C_blks_per_iter; - // On the last iteration the access pattern to ws_reduce - // might change (due to re-balance on C). Since sync is not always - // possible (in case of TBB) use different parts of ws for each - // iteration if threads are not synced by the algorithm. - size_t ws_iter_off = (mkldnn_thr_syncable() ? 0 : 1) * 2 * C_off; - - data_t *diff_gamma_blk = diff_scaleshift + C_off; - data_t *diff_beta_blk = diff_scaleshift + C + C_off; - for (dim_t c = C_blk_s; c < C_blk_e; c++) { - size_t off = c + C_off; - data_t diff_gamma = 0.0, diff_beta = 0.0; - data_t v_mean = mean[off]; - for (dim_t n = N_s; n < N_e; ++n) - PRAGMA_OMP_SIMD(reduction(+ : diff_gamma, diff_beta)) - for (dim_t sp = S_s; sp < S_e; ++sp) { - const size_t d_off = off * SP + n * C * SP + sp; - data_t dd; - if (fuse_bn_relu) - dd = (!ws[d_off]) ? 0 : diff_dst[d_off]; - else - dd = diff_dst[d_off]; - diff_gamma += (src[d_off] - v_mean) * dd; - diff_beta += dd; - } - ws_reduce[ws_iter_off + SP_N_ithr * C_blks_per_iter + c] - = diff_gamma; - ws_reduce[ws_iter_off + SP_N_nthr * C_blks_per_iter - + SP_N_ithr * C_blks_per_iter + c] = diff_beta; - } - - if (SP_N_nthr > 1) mkldnn_thr_barrier(); - - for (dim_t c = C_blk_gl_s; c < C_blk_gl_e; c++) { - data_t sqrt_variance = static_cast( - 1.0f / sqrtf(variance[c + C_off] + eps)); - diff_gamma_blk[c] = 0.; - diff_beta_blk[c] = 0.; - for (dim_t n = 0; n < SP_N_nthr; n++) { - diff_gamma_blk[c] += ws_reduce[ws_iter_off - + n * C_blks_per_iter + c]; - diff_beta_blk[c] += ws_reduce[ws_iter_off - + SP_N_nthr * C_blks_per_iter + n * C_blks_per_iter - + c]; - } - diff_gamma_blk[c] *= sqrt_variance; - } - - if (SP_N_nthr > 1) mkldnn_thr_barrier(); - - for (dim_t c = C_blk_s; c < C_blk_e; c++) { - size_t off = c + C_off; - data_t gamma = use_scaleshift ? scaleshift[off] : 1; - data_t sqrt_variance - = static_cast(1.0f / sqrtf(variance[off] + eps)); - data_t v_mean = mean[off]; - for (dim_t n = N_s; n < N_e; ++n) -#if SAFE_TO_USE_OMP_SIMD - PRAGMA_OMP_SIMD() -#endif - for (dim_t sp = S_s; sp < S_e; ++sp) { - const size_t d_off = off * SP + n * C * SP + sp; - - data_t v_diff_src; - if (fuse_bn_relu) - v_diff_src = (!ws[d_off]) ? 
0 : diff_dst[d_off]; - else - v_diff_src = diff_dst[d_off]; - if (calculate_diff_stats) { - v_diff_src -= diff_beta_blk[c] / (SP * N) - + (src[d_off] - v_mean) * diff_gamma_blk[c] - * sqrt_variance / (SP * N); - } - v_diff_src *= gamma * sqrt_variance; - diff_src[d_off] = v_diff_src; - } - } - } - }); -} -} -} -} - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ncsp_batch_normalization.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ncsp_batch_normalization.hpp deleted file mode 100644 index 97ca3b003..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/ncsp_batch_normalization.hpp +++ /dev/null @@ -1,160 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef CPU_NCSP_BATCH_NORMALIZATION_HPP -#define CPU_NCSP_BATCH_NORMALIZATION_HPP - -#include - -#include "c_types_map.hpp" -#include "memory_tracking.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -#include "cpu_batch_normalization_pd.hpp" -#include "cpu_primitive.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -struct ncsp_batch_normalization_fwd_t : public cpu_primitive_t { - struct pd_t : public cpu_batch_normalization_fwd_pd_t { - using cpu_batch_normalization_fwd_pd_t::cpu_batch_normalization_fwd_pd_t; - - DECLARE_COMMON_PD_T("ncsp_bnorm:any", ncsp_batch_normalization_fwd_t); - - status_t init() { - using namespace data_type; - using namespace prop_kind; - using namespace format_tag; - - bool ok = true - && is_fwd() - && !has_zero_dim_memory() - && src_md()->data_type == f32 - && IMPLICATION(use_scaleshift(), weights_md()->data_type == f32) - && memory_desc_matches_one_of_tag(*src_md(), ncdhw, nchw, nc) - && (attr()->has_default_values() || this->with_relu_post_op()); - if (!ok) return status::unimplemented; - - if (is_training() && fuse_bn_relu()) init_default_ws(8); - - init_scratchpad(); - - return status::success; - } - - private: - void init_scratchpad() { - using namespace memory_tracking::names; - auto scratchpad = scratchpad_registry().registrar(); - if (!stats_is_src()) { - scratchpad.book(key_bnorm_reduction, - sizeof(data_t) * C() * mkldnn_get_max_threads()); - - if (!is_training()) { - scratchpad.book(key_bnorm_tmp_mean, sizeof(data_t) * C()); - scratchpad.book(key_bnorm_tmp_var, sizeof(data_t) * C()); - } - } - } - }; - - typedef typename prec_traits::type data_t; - - ncsp_batch_normalization_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} - ~ncsp_batch_normalization_fwd_t() {} - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_forward(ctx); - return status::success; - } - -private: - void execute_forward(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } -}; - -struct ncsp_batch_normalization_bwd_t : public cpu_primitive_t { - struct pd_t : public cpu_batch_normalization_bwd_pd_t { - 
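// The backward kernel in the .cpp above stores each worker's partial sums for
// diff_gamma and then diff_beta back to back in ws_reduce, which is why
// key_bnorm_reduction is booked as 2 * C * nthreads below. A serialized sketch
// of the merge that runs after the barrier (hypothetical names, channel
// blocking omitted):
#include <cmath>
#include <cstddef>

void reduce_diff_scaleshift(const float *ws_reduce, std::size_t C,
        std::size_t nthr, const float *variance, float eps,
        float *diff_gamma, float *diff_beta) {
    for (std::size_t c = 0; c < C; ++c) {
        float dg = 0.f, db = 0.f;
        for (std::size_t t = 0; t < nthr; ++t) {
            dg += ws_reduce[t * C + c];            // diff_gamma partials
            db += ws_reduce[nthr * C + t * C + c]; // diff_beta partials
        }
        diff_gamma[c] = dg / std::sqrt(variance[c] + eps);
        diff_beta[c] = db;
    }
}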
using cpu_batch_normalization_bwd_pd_t::cpu_batch_normalization_bwd_pd_t; - - DECLARE_COMMON_PD_T("ncsp_bnorm:any", ncsp_batch_normalization_bwd_t); - - status_t init() { - using namespace data_type; - using namespace format_tag; - - bool ok = true - && is_bwd() - && !has_zero_dim_memory() - && utils::everyone_is(f32, src_md()->data_type, - diff_src_md()->data_type) - && IMPLICATION(use_scaleshift(), - utils::everyone_is(f32, - weights_md()->data_type, - diff_weights_md()->data_type)) - && memory_desc_matches_one_of_tag(*src_md(), ncdhw, nchw, nc) - && memory_desc_matches_one_of_tag(*diff_src_md(), ncdhw, nchw, nc) - && attr()->has_default_values(); - if (!ok) return status::unimplemented; - - if (fuse_bn_relu()) { - init_default_ws(8); - if (!compare_ws(hint_fwd_pd_)) - return status::unimplemented; - } - - init_scratchpad(); - - return status::success; - } - - private: - void init_scratchpad() { - using namespace memory_tracking::names; - auto scratchpad = scratchpad_registry().registrar(); - scratchpad.book(key_bnorm_reduction, - sizeof(data_t) * 2 * C() * mkldnn_get_max_threads()); - if (!(use_scaleshift() && desc()->prop_kind == prop_kind::backward)) - scratchpad.book(key_bnorm_tmp_diff_ss, - sizeof(data_t) * 2 * C()); - } - }; - - typedef typename prec_traits::type data_t; - - ncsp_batch_normalization_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {} - ~ncsp_batch_normalization_bwd_t() {} - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_backward(ctx); - return status::success; - } - -private: - void execute_backward(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } -}; - -} -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/nhwc_pooling.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/nhwc_pooling.cpp deleted file mode 100644 index 38cfb28dc..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/nhwc_pooling.cpp +++ /dev/null @@ -1,392 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include -#include - -#include "c_types_map.hpp" -#include "type_helpers.hpp" -#include "math_utils.hpp" -#include "mkldnn_thread.hpp" -#include "nstl.hpp" - -#include "nhwc_pooling.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -#define MEM_D(name) name##_d - -#define DECLARE_READ_STRIDES(name) \ - const size_t name##_n_stride = MEM_D(name).blocking_desc().strides[0]; \ - const size_t name##_d_stride = (!is_3d) \ - ? 0 \ - : MEM_D(name).blocking_desc().strides[2]; \ - const size_t name##_h_stride = (!is_3d) \ - ? MEM_D(name).blocking_desc().strides[2] \ - : MEM_D(name).blocking_desc().strides[3]; \ - const size_t name##_w_stride = (!is_3d) \ - ? 
MEM_D(name).blocking_desc().strides[3] \ - : MEM_D(name).blocking_desc().strides[4]; - -namespace nhwc_pooling { - size_t strided_offset(const int _n, const size_t _sn, - const int _d, const size_t _sd, - const int _h, const size_t _sh, - const int _w, const size_t _sw) - { - return _n * _sn - + _d * _sd - + _h * _sh - + _w * _sw; - } -} - -template -void nhwc_pooling_fwd_t::array_div_by_const(const int n, - const data_t *src, const size_t num, data_t *dst) const -{ - for (int i = 0; i < n; ++i) - { - float ftmp = (float)src[i]; - ftmp = ftmp / num; - dst[i] = math::out_round(ftmp); - } -} - -template -void nhwc_pooling_fwd_t::array_add(const int n, const data_t *src, - data_t *dst) const -{ - for (int i = 0; i < n; ++i) - { - dst[i] += src[i]; - } -} - -template -void nhwc_pooling_fwd_t::execute_forward( - const exec_ctx_t &ctx) const { - using namespace alg_kind; - using namespace prop_kind; - using namespace nhwc_pooling; - - auto alg = pd()->desc()->alg_kind; - - auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); - auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); - auto ws = CTX_OUT_MEM(unsigned char *, MKLDNN_ARG_WORKSPACE); - - const memory_desc_wrapper MEM_D(src)(pd()->src_md()); - const memory_desc_wrapper MEM_D(dst)(pd()->dst_md()); - const memory_desc_wrapper MEM_D(ws)(pd()->workspace_md()); - - const int ID = pd()->ID(); - const int IH = pd()->IH(); - const int IW = pd()->IW(); - const int KD = pd()->KD(); - const int KH = pd()->KH(); - const int KW = pd()->KW(); - const int SD = pd()->KSD(); - const int SH = pd()->KSH(); - const int SW = pd()->KSW(); - const int padF = pd()->padFront(); - const int padT = pd()->padT(); - const int padL = pd()->padL(); - const int MB = pd()->MB(); - const int OC = pd()->C(); - const int OD = pd()->OD(); - const int OH = pd()->OH(); - const int OW = pd()->OW(); - - const bool is_3d = pd()->desc()->src_desc.ndims == 5; - const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef; - - DECLARE_READ_STRIDES(src); - DECLARE_READ_STRIDES(dst); - - auto apply_offset = [=](int index, int offset) { - return (index > offset) ? 
index - offset : 0; - }; - - parallel_nd(MB, OD, OH, OW, - [&](int mb, int od, int oh, int ow) { - size_t dst_offset_init = strided_offset(mb, dst_n_stride, - od, dst_d_stride, - oh, dst_h_stride, - ow, dst_w_stride); - if (alg == pooling_max) { - size_t ws_offset_init = 0; - if (ws) - { - DECLARE_READ_STRIDES(ws); - ws_offset_init = strided_offset(mb, ws_n_stride, - od, ws_d_stride, - oh, ws_h_stride, - ow, ws_w_stride); - } - // Note: GCC 4.8.5 won't vectorize below - // simple loops unless they are singled out - // into separate helper routines: - // array_nhwc_initialize, array_nhwc_max - if (!ws) - array_nhwc_initialize(OC, dst + dst_offset_init, - ws, ws_offset_init, ws_dt); - else - array_nhwc_initialize(OC, dst + dst_offset_init, - ws, ws_offset_init, ws_dt); - - - for (int kd = 0; kd < KD; ++kd) - for (int kh = 0; kh < KH; ++kh) - for (int kw = 0; kw < KW; ++kw) { - const int id = od * SD - padF + kd; - const int ih = oh * SH - padT + kh; - const int iw = ow * SW - padL + kw; - - if (id < 0 || id >= ID) - continue; - if (ih < 0 || ih >= IH) - continue; - if (iw < 0 || iw >= IW) - continue; - - size_t src_offset_init = strided_offset(mb, src_n_stride, - id, src_d_stride, - ih, src_h_stride, - iw, src_w_stride); - - if (!ws) - array_nhwc_max(OC, - dst + dst_offset_init, - src + src_offset_init, - ws, ws_offset_init, - ws_dt, - kd * KH * KW + kh * KW + kw - ); - else - array_nhwc_max(OC, - dst + dst_offset_init, - src + src_offset_init, - ws, ws_offset_init, - ws_dt, - kd * KH * KW + kh * KW + kw - ); - } - } else { - // pooling_avg - auto d = dst + dst_offset_init; - - utils::array_set(d, 0, OC); - - auto id_start = apply_offset(od * SD, padF); - auto ih_start = apply_offset(oh * SH, padT); - auto iw_start = apply_offset(ow * SW, padL); - auto id_end = nstl::min(od * SD - padF + KD, ID); - auto ih_end = nstl::min(oh * SH - padT + KH, IH); - auto iw_end = nstl::min(ow * SW - padL + KW, IW); - - // it is cheaper to actually count this in a loop - // as the typical kernel is small - size_t num_summands = 0; - - for (int id = id_start; id < id_end; ++id) - for (int ih = ih_start; ih < ih_end; ++ih) - for (int iw = iw_start; iw < iw_end; ++iw) { - size_t src_offset_init = strided_offset(mb, src_n_stride, - id, src_d_stride, - ih, src_h_stride, - iw, src_w_stride); - auto s = src + src_offset_init; - - // need to move the loop to separate function - // for GCC 4.8.5 to vectorize - array_add(OC, s, d); - - num_summands++; - } - - num_summands = (alg == pooling_avg_include_padding) ? 
- KW * KH * KD : num_summands; - - // need to move the loop to separate function - // for GCC 4.8.5 to vectorize - array_div_by_const(OC, d, num_summands, d); - } - }); -} - -template -void nhwc_pooling_bwd_t::execute_backward( - const exec_ctx_t &ctx) const { - using namespace alg_kind; - using namespace nhwc_pooling; - - auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); - auto ws = CTX_IN_MEM(const unsigned char *, MKLDNN_ARG_WORKSPACE); - auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); - - const memory_desc_wrapper MEM_D(diff_src)(pd()->diff_src_md()); - const memory_desc_wrapper MEM_D(diff_dst)(pd()->diff_dst_md()); - const memory_desc_wrapper MEM_D(ws)(pd()->workspace_md()); - - const int ID = pd()->ID(); - const int IH = pd()->IH(); - const int IW = pd()->IW(); - const int KD = pd()->KD(); - const int KH = pd()->KH(); - const int KW = pd()->KW(); - const int SD = pd()->KSD(); - const int SH = pd()->KSH(); - const int SW = pd()->KSW(); - const int OC = pd()->C(); - const int padF = pd()->padFront(); - const int padT = pd()->padT(); - const int padL = pd()->padL(); - const int OD = pd()->OD(); - const int OH = pd()->OH(); - const int OW = pd()->OW(); - - const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5; - auto alg = pd()->desc()->alg_kind; - - DECLARE_READ_STRIDES(diff_src); - DECLARE_READ_STRIDES(diff_dst); - - auto apply_offset = [=](int index, int offset) { - return (index > offset) ? index - offset : 0; - }; - - const int MB = pd()->MB(); - - parallel_nd(MB, ID, IH, IW, - [&](int mb, int id, int ih, int iw) { - size_t src_offset_init = strided_offset(mb, diff_src_n_stride, - id, diff_src_d_stride, - ih, diff_src_h_stride, - iw, diff_src_w_stride); - - // check if kernel windows are disjoint, in this case there's no - // update needed and we just write there once, no initialization - // required. - if (!(KD == SD && KH == SH && KW == SW)) - for (int oc = 0; oc < OC; ++oc) - diff_src[src_offset_init + oc] = data_type_t(0); - - // Find out which output cells may correspond to current - // input position. Current input postition divided by - // stride, with integer divide rounding down, is the - // right-most output. - // Left-most output may be computed if we decrement input - // by (kernel_size - 1) and then do the same division by - // stride. - int od_left = nstl::max((id + padF - KD + 1) / SD, 0); - int oh_left = nstl::max((ih + padT - KH + 1) / SH, 0); - int ow_left = nstl::max((iw + padL - KW + 1) / SW, 0); - // Notice +1 here to preserve the C loop "less than" - // condition for continuing the for loop. 
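// For one spatial axis, output position o covers input position i iff
// o*S - pad <= i < o*S - pad + K. The right edge below is exact; the left edge
// is a conservative (possibly slightly low) bound, which is safe because the
// loops that follow re-check kd/kh/kw anyway. The "+ 1" keeps the usual
// half-open `o < o_right` loop condition. Hypothetical helper mirroring the
// arithmetic used below:
#include <algorithm>

void covering_outputs(int i, int pad, int K, int S, int O,
        int &o_left, int &o_right) {
    o_left = std::max((i + pad - K + 1) / S, 0);
    o_right = std::min((i + pad) / S + 1, O);
}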
- int od_right = nstl::min((id + padF) / SD + 1 , OD); - int oh_right = nstl::min((ih + padT) / SH + 1 , OH); - int ow_right = nstl::min((iw + padL) / SW + 1 , OW); - - for (int od = od_left; od < od_right; ++od) - for (int oh = oh_left; oh < oh_right; ++oh) - for (int ow = ow_left; ow < ow_right; ++ow) { - const int kd = id - od*SD + padF; - const int kh = ih - oh*SH + padT; - const int kw = iw - ow*SW + padL; - - if (kd < 0 || kd >= KD) - continue; - if (kh < 0 || kh >= KH) - continue; - if (kw < 0 || kw >= KW) - continue; - - size_t dst_offset_init = strided_offset(mb, diff_dst_n_stride, - od, diff_dst_d_stride, - oh, diff_dst_h_stride, - ow, diff_dst_w_stride); - - if (alg == pooling_max) { - DECLARE_READ_STRIDES(ws); - size_t ws_offset_init = strided_offset(mb, ws_n_stride, - od, ws_d_stride, - oh, ws_h_stride, - ow, ws_w_stride); - const int index = kd * KH * KW + kh * KW + kw; - - PRAGMA_OMP_SIMD() - for (int oc = 0; oc < OC; ++oc) { - const int index_from_ws = - (MEM_D(ws).data_type() == data_type::u8) - ? (int)ws[ws_offset_init + oc] - : ((int *)ws)[ws_offset_init + oc]; - - const data_t d = diff_dst[dst_offset_init + oc]; - - // Check if kernel windows are disjoint, in this case - // there's no update needed and we just write there once - // otherwise we add value to the contents. - if (!(KD == SD && KH == SH && KW == SW)) - diff_src[src_offset_init + oc] += - (index_from_ws == index) - ? d - : data_type_t(0); - else - diff_src[src_offset_init + oc] = - (index_from_ws == index) - ? d - : data_type_t(0); - } - } else { - // pooling_avg - auto id_start = apply_offset(od*SD, padF); - auto ih_start = apply_offset(oh*SH, padT); - auto iw_start = apply_offset(ow*SW, padL); - auto id_end = nstl::min(od*SD - padF + KD, ID); - auto ih_end = nstl::min(oh*SH - padT + KH, IH); - auto iw_end = nstl::min(ow*SW - padL + KW, IW); - - auto num_summands = (alg == pooling_avg_include_padding) - ? KW*KH*KD - : (ih_end - ih_start)*(iw_end - iw_start)*(id_end - id_start); - - PRAGMA_OMP_SIMD() - for (int oc = 0; oc < OC; ++oc) { - const data_t d = diff_dst[dst_offset_init + oc]; - // Check if kernel windows are disjoint, in this case - // there's no update needed and we just write there once - // otherwise we add value to the contents. - if (!(KD == SD && KH == SH && KW == SW)) - diff_src[src_offset_init + oc] += d / num_summands; - else - diff_src[src_offset_init + oc] = d / num_summands; - } - } - } - }); -} - -template struct nhwc_pooling_fwd_t; -template struct nhwc_pooling_bwd_t; - -} -} -} - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/nhwc_pooling.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/nhwc_pooling.hpp deleted file mode 100644 index 7e33b6869..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/nhwc_pooling.hpp +++ /dev/null @@ -1,210 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#ifndef CPU_NHWC_POOLING_HPP -#define CPU_NHWC_POOLING_HPP - -#include - -#include "c_types_map.hpp" -#include "mkldnn_thread.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -#include "cpu_pooling_pd.hpp" -#include "cpu_primitive.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -namespace nhwc_pooling { -size_t strided_offset(const int _n, const size_t _sn, const int _d, - const size_t _sd, const int _h, const size_t _sh, const int _w, - const size_t _sw); -} - -template -struct nhwc_pooling_fwd_t: public cpu_primitive_t { - struct pd_t: public cpu_pooling_fwd_pd_t { - using cpu_pooling_fwd_pd_t::cpu_pooling_fwd_pd_t; - - DECLARE_COMMON_PD_T("nhwc_pooling:any", nhwc_pooling_fwd_t); - - status_t init() { - const format_tag_t desired_fmt_tag = - ndims() == 4 ? format_tag::nhwc : format_tag::ndhwc; - - bool ok = true - && set_default_params() == status::success - && is_fwd() - && utils::one_of(desc()->alg_kind, alg_kind::pooling_max, - alg_kind::pooling_avg_include_padding, - alg_kind::pooling_avg_exclude_padding) - && utils::everyone_is(data_type, - src_md()->data_type, - dst_md()->data_type) - && attr()->has_default_values() - && memory_desc_matches_tag(*src_md(), desired_fmt_tag) - && memory_desc_matches_tag(*dst_md(), desired_fmt_tag); - if (!ok) return status::unimplemented; - - bool is_training = desc_.prop_kind == prop_kind::forward_training; - if (desc()->alg_kind == alg_kind::pooling_max && is_training) - init_default_ws(); - - return status::success; - } - }; - - nhwc_pooling_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} - - typedef typename prec_traits::type data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_forward(ctx); - return status::success; - } - -private: - void execute_forward(const exec_ctx_t &ctx) const; - void array_div_by_const(const int n, const data_t *src, const size_t num, - data_t *dst) const; - void array_add(const int n, const data_t *src, data_t *dst) const; - - template - void array_nhwc_max(const int n, data_t *dst, const data_t *src, - unsigned char *ws, const size_t ws_offset, const data_type_t ws_dt, - const int index) const { - assert(!((use_workspace == false) ^ (!ws))); // ensure ws pointer exists - PRAGMA_OMP_SIMD() - for (int oc = 0; oc < n; ++oc) { - auto s = src[oc]; - data_t mv = dst[oc]; - - // update index of maximum -#if defined __INTEL_COMPILER - if ((use_workspace) && (s > mv)) { - assert(ws_dt == data_type::u8 || ws_dt == data_type::s32); - if (ws_dt == data_type::u8) { - assert(0 <= index && index <= 255); - ws[ws_offset + oc] = index; - } else - reinterpret_cast(ws)[ws_offset + oc] = index; - } -#else - // Need to add explicit predicates for GCC to vectorize this. - // And although the resulting code is ugly, it is still 4 times - // faster than scalar - if (use_workspace) { - assert(ws_dt == data_type::u8 || ws_dt == data_type::s32); - - if (ws_dt == data_type::u8) { - assert(0 <= index && index <= 255); - unsigned char predicate = (s > mv) ? 0xff : 0; - unsigned char current_value = ws[ws_offset + oc]; - current_value = (predicate & (unsigned char)index) - | ((~predicate) & current_value); - ws[ws_offset + oc] = current_value; - } else { - auto wint = reinterpret_cast(ws); - unsigned int predicate = (s > mv) ? 
0xffffffff : 0; - unsigned int current_value = wint[ws_offset + oc]; - current_value = (predicate & (unsigned int)index) - | ((~predicate) & current_value); - wint[ws_offset + oc] = current_value; - } - } -#endif - // update maximum - dst[oc] = nstl::max(s, mv); - } - } - - template - void array_nhwc_initialize(const int n, data_t *dst, unsigned char *ws, - const size_t ws_offset, const data_type_t ws_dt) const { - assert(!((use_workspace == false) ^ (!ws))); // ensure ws pointer exists - for (int oc = 0; oc < n; ++oc) { - if (use_workspace) { - assert(ws_dt == data_type::u8 || ws_dt == data_type::s32); - if (ws_dt == data_type::u8) { - ws[ws_offset + oc] = 0; - } else - reinterpret_cast(ws)[ws_offset + oc] = 0; - } - dst[oc] = nstl::numeric_limits::lowest(); - } - } - - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } -}; - -template -struct nhwc_pooling_bwd_t: public cpu_primitive_t { - struct pd_t: public cpu_pooling_bwd_pd_t { - using cpu_pooling_bwd_pd_t::cpu_pooling_bwd_pd_t; - - DECLARE_COMMON_PD_T("nhwc:any", nhwc_pooling_bwd_t); - - status_t init() { - const format_tag_t desired_fmt_tag = - ndims() == 4 ? format_tag::nchw : format_tag::ncdhw; - - bool ok = true - && set_default_params() == status::success - && !is_fwd() - && utils::one_of(desc()->alg_kind, alg_kind::pooling_max, - alg_kind::pooling_avg_include_padding, - alg_kind::pooling_avg_exclude_padding) - && utils::everyone_is(data_type, - diff_dst_md()->data_type, - diff_src_md()->data_type) - && attr()->has_default_values() - && memory_desc_matches_tag(*diff_dst_md(), desired_fmt_tag) - && memory_desc_matches_tag(*diff_src_md(), desired_fmt_tag); - if (!ok) return status::unimplemented; - - if (desc()->alg_kind == alg_kind::pooling_max) { - init_default_ws(); - if (!compare_ws(hint_fwd_pd_)) - return status::unimplemented; - } - - return status::success; - } - }; - - nhwc_pooling_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {} - typedef typename prec_traits::type data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_backward(ctx); - return status::success; - } - -private: - void execute_backward(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } -}; - -}// namespace cpu -}// namespace impl -}// namespace mkldnn - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.cpp deleted file mode 100644 index e20333e66..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.cpp +++ /dev/null @@ -1,288 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
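// The array_nhwc_max helper above replaces "if (s > best) index_slot = index"
// with a mask blend so GCC can keep the per-channel loop vectorized. A
// standalone sketch of that branch-free select (hypothetical name):
#include <cstdint>

static inline uint32_t select_index(float s, float best,
        uint32_t index, uint32_t stored) {
    const uint32_t predicate = (s > best) ? 0xffffffffu : 0u;
    return (predicate & index) | (~predicate & stored);
}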
-*******************************************************************************/ - -#include -#include - -#include "c_types_map.hpp" -#include "type_helpers.hpp" - -#include "cpu_batch_normalization_utils.hpp" -#include "jit_generator.hpp" - -#include "nspc_batch_normalization.hpp" - -// clang 6 and 7 generate incorrect code with OMP_SIMD in some particular cases -#if (defined __clang_major__) && (__clang_major__ >= 6) -#define SAFE_TO_USE_OMP_SIMD 0 -#else -#define SAFE_TO_USE_OMP_SIMD 1 -#endif - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace memory_tracking::names; - -void nspc_batch_normalization_fwd_t::execute_forward( - const exec_ctx_t &ctx) const { - const bool save_stats = pd()->is_training(); - const bool is_training = pd()->is_training(); - const bool fuse_bn_relu = pd()->fuse_bn_relu(); - const bool calculate_stats = !pd()->stats_is_src(); - const bool with_relu = pd()->with_relu_post_op(); - - auto scratchpad = this->scratchpad(ctx); - auto tmp_mean = scratchpad.get(key_bnorm_tmp_mean); - auto tmp_var = scratchpad.get(key_bnorm_tmp_var); - auto *ws_reduce = scratchpad.get(key_bnorm_reduction); - - auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); - auto scaleshift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT); - - data_t *mean, *variance; - if (!calculate_stats) { - mean = const_cast( - CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN)); - variance = const_cast( - CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE)); - } else { - if (save_stats) { - mean = CTX_OUT_MEM(data_t *, MKLDNN_ARG_MEAN); - variance = CTX_OUT_MEM(data_t *, MKLDNN_ARG_VARIANCE); - } else { - mean = tmp_mean; - variance = tmp_var; - } - } - - auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); - auto ws = CTX_OUT_MEM(uint8_t *, MKLDNN_ARG_WORKSPACE); - - const dim_t N = pd()->MB(); - const dim_t C = pd()->C(); - const dim_t SP = pd()->H() * pd()->W() * pd()->D(); - - const float eps = pd()->desc()->batch_norm_epsilon; - const bool use_scaleshift = pd()->use_scaleshift(); - auto maybe_post_op - = [&](data_t res) { return (with_relu && res < 0) ? 
0 : res; }; - - assert(mkldnn_thr_syncable()); - parallel(0, [&](const int ithr, const int nthr) { - dim_t N_s = 0, N_e = 0, C_s = 0, C_e = 0; - balance211(N, nthr, ithr, N_s, N_e); - balance211(C, nthr, ithr, C_s, C_e); - data_t *mean_loc = tmp_mean + nstl::max(C, (dim_t)16) * ithr; - data_t *variance_loc = tmp_var + nstl::max(C, (dim_t)16) * ithr; - - if (calculate_stats) { - for (dim_t c = 0; c < C; c++) - ws_reduce[C * ithr + c] = 0.; - - for (dim_t n = N_s; n < N_e; n++) - for (dim_t sp = 0; sp < SP; sp++) - PRAGMA_OMP_SIMD() - for (dim_t c = 0; c < C; c++) - ws_reduce[C * ithr + c] += src[(size_t)n * SP * C - + sp * C + c]; - - mkldnn_thr_barrier(); - - for (dim_t c = C_s; c < C_e; c++) { - mean[c] = 0; - for (dim_t n = 0; n < nthr; n++) - mean[c] += ws_reduce[C * n + c]; - mean[c] /= SP * N; - } - - mkldnn_thr_barrier(); - - for (dim_t c = 0; c < C; c++) { - mean_loc[c] = mean[c]; - ws_reduce[C * ithr + c] = 0.; - } - - for (dim_t n = N_s; n < N_e; n++) - for (dim_t sp = 0; sp < SP; sp++) - PRAGMA_OMP_SIMD() - for (dim_t c = 0; c < C; c++) { - data_t m = src[(size_t)n * SP * C + sp * C + c] - - mean_loc[c]; - ws_reduce[C * ithr + c] += m * m; - } - - mkldnn_thr_barrier(); - - for (dim_t c = C_s; c < C_e; c++) { - variance[c] = 0; - for (dim_t n = 0; n < nthr; n++) - variance[c] += ws_reduce[C * n + c]; - variance[c] /= SP * N; - } - - mkldnn_thr_barrier(); - - for (dim_t c = 0; c < C; c++) - variance_loc[c] = variance[c]; - } else { - variance_loc = variance; - mean_loc = mean; - } - - for (dim_t n = N_s; n < N_e; n++) { - for (dim_t sp = 0; sp < SP; sp++) { -#if SAFE_TO_USE_OMP_SIMD - PRAGMA_OMP_SIMD() -#endif - for (dim_t c = 0; c < C; c++) { - data_t sqrt_variance = static_cast( - sqrtf(variance_loc[c] + eps)); - data_t sm = (use_scaleshift ? scaleshift[c] : 1.0f) / sqrt_variance; - data_t sv = use_scaleshift ? 
scaleshift[C + c] : 0; - size_t d_off = (size_t)n * SP * C + sp * C + c; - data_t bn_res = sm * (src[d_off] - mean_loc[c]) + sv; - if (fuse_bn_relu) { - if (bn_res <= 0) { - bn_res = 0; - if (is_training) - ws[d_off] = 0; - } else { - if (is_training) - ws[d_off] = 1; - } - } - dst[d_off] = maybe_post_op(bn_res); - } - } - } - }); -} - -void nspc_batch_normalization_bwd_t::execute_backward( - const exec_ctx_t &ctx) const { - auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); - auto mean = CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN); - auto variance = CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE); - auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); - auto scaleshift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT); - auto ws = CTX_IN_MEM(const uint8_t *, MKLDNN_ARG_WORKSPACE); - - auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); - auto diff_scaleshift = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SCALE_SHIFT); - - auto scratchpad = this->scratchpad(ctx); - auto tmp_diff_ss = scratchpad.get(key_bnorm_tmp_diff_ss); - - if (diff_scaleshift == nullptr) - diff_scaleshift = tmp_diff_ss; - - const dim_t N = pd()->MB(); - const dim_t C = pd()->C(); - const dim_t SP = pd()->D() * pd()->H() * pd()->W(); - data_t *diff_gamma = diff_scaleshift, *diff_beta = diff_scaleshift + C; - auto *ws_reduce = scratchpad.get(key_bnorm_reduction); - - const float eps = pd()->desc()->batch_norm_epsilon; - const bool use_scaleshift = pd()->use_scaleshift(); - const bool calculate_diff_stats = !pd()->use_global_stats(); - const bool fuse_bn_relu = pd()->fuse_bn_relu(); - - assert(mkldnn_thr_syncable()); - parallel(0, [&](const int ithr, const int nthr) { - dim_t N_s = 0, N_e = 0, C_s = 0, C_e = 0; - balance211(N, nthr, ithr, N_s, N_e); - balance211(C, nthr, ithr, C_s, C_e); - - data_t *diff_gamma_loc = tmp_diff_ss + 2 * C + C * ithr; - data_t *diff_beta_loc = tmp_diff_ss + 2 * C + C * (nthr + ithr); - - for (dim_t c = 0; c < C; c++) { - ws_reduce[C * ithr + c] = 0.; - ws_reduce[C * nthr + C * ithr + c] = 0.; - } - - for (dim_t n = N_s; n < N_e; n++) - for (dim_t sp = 0; sp < SP; sp++) -#if SAFE_TO_USE_OMP_SIMD - PRAGMA_OMP_SIMD() -#endif - for (dim_t c = 0; c < C; c++) { - const size_t d_off = (size_t)n * SP * C + sp * C + c; - data_t dd; - if (fuse_bn_relu) - dd = (!ws[d_off]) ? 0 : diff_dst[d_off]; - else - dd = diff_dst[d_off]; - ws_reduce[C * ithr + c] += (src[d_off] - mean[c]) * dd; - ws_reduce[C * nthr + C * ithr + c] += dd; - } - - mkldnn_thr_barrier(); - - for (dim_t c = C_s; c < C_e; c++) { - data_t sqrt_variance - = static_cast(1.0f / sqrtf(variance[c] + eps)); - diff_gamma[c] = 0; - diff_beta[c] = 0; - for (dim_t n = 0; n < nthr; n++) { - diff_gamma[c] += ws_reduce[C * n + c]; - diff_beta[c] += ws_reduce[C * nthr + C * n + c]; - } - diff_gamma[c] *= sqrt_variance; - } - - mkldnn_thr_barrier(); - - for (dim_t c = 0; c < C; c++) { - diff_gamma_loc[c] = diff_gamma[c]; - diff_beta_loc[c] = diff_beta[c]; - } - - for (dim_t n = N_s; n < N_e; n++) { - for (dim_t sp = 0; sp < SP; sp++) { -#if SAFE_TO_USE_OMP_SIMD - PRAGMA_OMP_SIMD() -#endif - for (dim_t c = 0; c < C; c++) { - const size_t d_off = (size_t)n * SP * C + sp * C + c; - data_t gamma = use_scaleshift ? scaleshift[c] : 1; - data_t sqrt_variance - = static_cast(1.0f / sqrtf(variance[c] + eps)); - data_t v_diff_src; - if (fuse_bn_relu) - v_diff_src = (!ws[d_off]) ? 
0 : diff_dst[d_off]; - else - v_diff_src = diff_dst[d_off]; - if (calculate_diff_stats) { - v_diff_src -= diff_beta_loc[c] / (SP * N) - + (src[d_off] - mean[c]) * diff_gamma_loc[c] - * sqrt_variance / (SP * N); - } - v_diff_src *= gamma * sqrt_variance; - diff_src[d_off] = v_diff_src; - } - } - } - }); -} - -} -} -} - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.hpp deleted file mode 100644 index aad86b05a..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.hpp +++ /dev/null @@ -1,169 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef CPU_NSPC_BATCH_NORMALIZATION_HPP -#define CPU_NSPC_BATCH_NORMALIZATION_HPP - -#include - -#include "c_types_map.hpp" -#include "memory_tracking.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -#include "cpu_batch_normalization_pd.hpp" -#include "cpu_primitive.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -struct nspc_batch_normalization_fwd_t : public cpu_primitive_t { - struct pd_t : public cpu_batch_normalization_fwd_pd_t { - pd_t(engine_t *engine, const batch_normalization_desc_t *adesc, - const primitive_attr_t *attr, - const batch_normalization_fwd_pd_t *hint_fwd_pd) - : cpu_batch_normalization_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) - {} - - DECLARE_COMMON_PD_T("nspc_bnorm:any", nspc_batch_normalization_fwd_t); - - status_t init() { - using namespace data_type; - using namespace prop_kind; - - bool ok = true - /* the algorithm requires barriers while switching - * between parallelization over N and C dimensions */ - && mkldnn_thr_syncable() - && is_fwd() - && !has_zero_dim_memory() - && src_md()->data_type == f32 - && IMPLICATION(use_scaleshift(), weights_md()->data_type == f32) - && memory_desc_matches_tag(*src_md(), format_tag::nhwc) - && (attr()->has_default_values() || this->with_relu_post_op()); - if (!ok) return status::unimplemented; - - if (is_training() && fuse_bn_relu()) init_default_ws(8); - - init_scratchpad(); - - return status::success; - } - - private: - void init_scratchpad() { - using namespace memory_tracking::names; - auto scratchpad = scratchpad_registry().registrar(); - if (!stats_is_src()) { - dim_t sz = nstl::max(C(), 16) * mkldnn_get_max_threads(); - scratchpad.book(key_bnorm_reduction, sizeof(data_t) * sz); - scratchpad.book(key_bnorm_tmp_mean, sizeof(data_t) * sz); - scratchpad.book(key_bnorm_tmp_var, sizeof(data_t) * sz); - } - } - }; - - typedef typename prec_traits::type data_t; - - nspc_batch_normalization_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} - ~nspc_batch_normalization_fwd_t() {} - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_forward(ctx); - return status::success; - } - -private: - void 
execute_forward(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } -}; - -struct nspc_batch_normalization_bwd_t : public cpu_primitive_t { - struct pd_t : public cpu_batch_normalization_bwd_pd_t { - pd_t(engine_t *engine, const batch_normalization_desc_t *adesc, - const primitive_attr_t *attr, - const batch_normalization_fwd_pd_t *hint_fwd_pd) - : cpu_batch_normalization_bwd_pd_t(engine, adesc, attr, hint_fwd_pd) - {} - - DECLARE_COMMON_PD_T("nspc_bnorm:any", nspc_batch_normalization_bwd_t); - - status_t init() { - using namespace data_type; - using namespace prop_kind; - - bool ok = true - /* the algorithm requires barriers while switching - * between parallelization over N and C dimensions */ - && mkldnn_thr_syncable() - && is_bwd() - && !has_zero_dim_memory() - && utils::everyone_is(f32, src_md()->data_type, - diff_src_md()->data_type) - && IMPLICATION(use_scaleshift(), - utils::everyone_is(f32, - weights_md()->data_type, - diff_weights_md()->data_type)) - && memory_desc_matches_tag(*src_md(), format_tag::nhwc) - && memory_desc_matches_tag(*diff_src_md(), format_tag::nhwc) - && attr()->has_default_values(); - if (!ok) return status::unimplemented; - - if (fuse_bn_relu()) { - init_default_ws(8); - if (!compare_ws(hint_fwd_pd_)) - return status::unimplemented; - } - - init_scratchpad(); - - return status::success; - } - - private: - void init_scratchpad() { - using namespace memory_tracking::names; - auto scratchpad = scratchpad_registry().registrar(); - scratchpad.book(key_bnorm_reduction, - sizeof(data_t) * 2 * C() * mkldnn_get_max_threads()); - scratchpad.book(key_bnorm_tmp_diff_ss, sizeof(data_t) * 2 * C() - * (mkldnn_get_max_threads() + 1)); - } - }; - - typedef typename prec_traits::type data_t; - - nspc_batch_normalization_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {} - ~nspc_batch_normalization_bwd_t() {} - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_backward(ctx); - return status::success; - } - -private: - void execute_backward(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } -}; - -} -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.cpp deleted file mode 100644 index d79b1a034..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.cpp +++ /dev/null @@ -1,265 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
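// Layout implied by the tmp_diff_ss booking above (2 * C * (nthreads + 1)
// floats): the first 2*C entries can stand in for diff_scaleshift when the
// user did not supply one, followed by one C-sized slab of local diff_gamma
// and one of local diff_beta per thread. Hypothetical offset helpers:
#include <cstddef>

static inline std::size_t diff_gamma_local_off(std::size_t C, std::size_t ithr) {
    return 2 * C + C * ithr;
}
static inline std::size_t diff_beta_local_off(std::size_t C, std::size_t nthr,
        std::size_t ithr) {
    return 2 * C + C * (nthr + ithr);
}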
-*******************************************************************************/ - -#include -#include - -#include "c_types_map.hpp" -#include "type_helpers.hpp" -#include "mkldnn_thread.hpp" -#include "simple_q10n.hpp" - -#include "ref_batch_normalization.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -template -void ref_batch_normalization_fwd_t::execute_forward( - const exec_ctx_t &ctx) const { - /* fast return */ - if (this->pd()->has_zero_dim_memory()) return; - - auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); - auto scaleshift = CTX_IN_MEM(const float *, MKLDNN_ARG_SCALE_SHIFT); - - auto mean = pd()->stats_is_src() - ? const_cast(CTX_IN_MEM(const float *, MKLDNN_ARG_MEAN)) - : CTX_OUT_MEM(float *, MKLDNN_ARG_MEAN); - auto variance = pd()->stats_is_src() - ? const_cast(CTX_IN_MEM(const float *, MKLDNN_ARG_VARIANCE)) - : CTX_OUT_MEM(float *, MKLDNN_ARG_VARIANCE); - - auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); - auto ws = CTX_OUT_MEM(uint8_t *, MKLDNN_ARG_WORKSPACE); - - const memory_desc_wrapper data_d(pd()->src_md()); - const memory_desc_wrapper scaleshift_d(pd()->weights_md()); - - const dim_t N = pd()->MB(); - const dim_t C = pd()->C(); - dim_t H = 1, W = 1, D = 1; - const bool has_spatial = utils::one_of(data_d.ndims(), 4, 5); - if (has_spatial) { - D = pd()->D(); - H = pd()->H(); - W = pd()->W(); - } - - const float eps = pd()->desc()->batch_norm_epsilon; - const bool use_scaleshift = pd()->use_scaleshift();; - const bool save_stats = pd()->is_training(); - const bool is_training = pd()->is_training(); - const bool fuse_bn_relu = pd()->fuse_bn_relu(); - const bool calculate_stats = !pd()->stats_is_src(); - - const bool with_relu = pd()->with_relu_post_op(); - auto maybe_post_op = [&](float res) { - return (with_relu && res < 0.0f) ? 0.0f : res; - }; - const bool is_3d = data_d.ndims() == 5; - - auto data_offset = [&](const memory_desc_wrapper &data_d, dim_t n, dim_t c, - dim_t d, dim_t h, dim_t w) { - if (has_spatial) { - if (is_3d) - return data_d.off(n, c, d, h, w); - else - return data_d.off(n, c, h, w); - } else - return data_d.off(n, c); - }; - - parallel_nd(C, [&](dim_t c) { - float v_mean = calculate_stats ? 0 : mean[c]; - float v_variance = calculate_stats ? 0 : variance[c]; - - if (calculate_stats) { - for (dim_t n = 0; n < N; ++n) - for (dim_t d = 0; d < D; ++d) - for (dim_t h = 0; h < H; ++h) - for (dim_t w = 0; w < W; ++w) - v_mean += src[data_offset(data_d, n, c, d, h, w)]; - v_mean /= W*N*H*D; - - for (dim_t n = 0; n < N; ++n) - for (dim_t d = 0; d < D; ++d) - for (dim_t h = 0; h < H; ++h) - for (dim_t w = 0; w < W; ++w) { - float m = src[data_offset(data_d, n, c, d, h, w)] - v_mean; - v_variance += m*m; - } - v_variance /= W*H*N*D; - } - - float sqrt_variance = sqrtf(v_variance + eps); - float sm = (use_scaleshift - ? scaleshift[scaleshift_d.off(0, c)] - : 1.0f) / sqrt_variance; - float sv = use_scaleshift ? 
-            scaleshift[scaleshift_d.off(1, c)] : 0;
-
-        for (dim_t n = 0; n < N; ++n)
-        for (dim_t d = 0; d < D; ++d)
-        for (dim_t h = 0; h < H; ++h)
-        for (dim_t w = 0; w < W; ++w) {
-            auto d_off = data_offset(data_d,n,c,d,h,w);
-            float bn_res = sm * ((float)src[d_off] - v_mean) + sv;
-            if (fuse_bn_relu) {
-                if (bn_res <= 0) {
-                    bn_res = 0;
-                    if (is_training)
-                        ws[d_off] = 0;
-                } else {
-                    if (is_training)
-                        ws[d_off] = 1;
-                }
-            }
-            if (data_type == data_type::s8) {
-                dst[d_off] = qz_a1b0<float, data_t>()(maybe_post_op(bn_res));
-            } else {
-                dst[d_off] = static_cast<data_t>(maybe_post_op(bn_res));
-            }
-        }
-
-        if (calculate_stats) {
-            if (save_stats) {
-                mean[c] = v_mean;
-                variance[c] = v_variance;
-            }
-        }
-    });
-}
-
-template struct ref_batch_normalization_fwd_t<data_type::f32>;
-template struct ref_batch_normalization_fwd_t<data_type::s8>;
-
-template <impl::data_type_t data_type>
-void ref_batch_normalization_bwd_t<data_type>::execute_backward(
-        const exec_ctx_t &ctx) const {
-    auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
-    auto mean = CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN);
-    auto variance = CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE);
-    auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
-    auto scaleshift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT);
-    auto ws = CTX_IN_MEM(const uint8_t *, MKLDNN_ARG_WORKSPACE);
-
-    auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);
-    auto diff_scaleshift = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SCALE_SHIFT);
-
-    const memory_desc_wrapper data_d(pd()->src_md());
-    const memory_desc_wrapper diff_data_d(pd()->diff_src_md());
-    const memory_desc_wrapper scaleshift_d(pd()->weights_md());
-    const memory_desc_wrapper diff_scaleshift_d(pd()->diff_weights_md());
-
-    const dim_t C = pd()->C();
-
-    /* fast return */
-    if (this->pd()->has_zero_dim_memory()) {
-        if (diff_scaleshift) {
-            for (dim_t c = 0; c < C; ++c) {
-                diff_scaleshift[diff_scaleshift_d.off(0, c)] = 0;
-                diff_scaleshift[diff_scaleshift_d.off(1, c)] = 0;
-            }
-        }
-        return;
-    }
-
-    const dim_t N = pd()->MB();
-    dim_t H = 1, W = 1, D = 1;
-    const bool has_spatial = utils::one_of(data_d.ndims(), 4, 5);
-    if (has_spatial) {
-        D = pd()->D();
-        H = pd()->H();
-        W = pd()->W();
-    }
-
-    const float eps = pd()->desc()->batch_norm_epsilon;
-    const bool use_scaleshift = pd()->use_scaleshift();
-    const bool calculate_diff_stats = !pd()->use_global_stats();
-    const bool fuse_bn_relu = pd()->fuse_bn_relu();
-
-    const bool is_3d = data_d.ndims() == 5;
-
-    auto data_offset = [&](const memory_desc_wrapper &data_d, dim_t n, dim_t c,
-            dim_t d, dim_t h, dim_t w) {
-        if (has_spatial) {
-            if (is_3d)
-                return data_d.off(n, c, d, h, w);
-            else
-                return data_d.off(n, c, h, w);
-        } else
-            return data_d.off(n, c);
-    };
-
-    parallel_nd(C, [&](dim_t c) {
-        data_t v_mean = mean[c];
-        data_t v_variance = variance[c];
-        data_t sqrt_variance = static_cast<data_t>(1.0f / sqrtf(v_variance + eps));
-        data_t gamma = use_scaleshift ?
scaleshift[scaleshift_d.off(0, c)] : 1; - data_t diff_gamma = data_t(0); - data_t diff_beta = data_t(0); - diff_gamma = 0.0; - diff_beta = 0.0; - - for (dim_t n = 0; n < N; ++n) - for (dim_t d = 0; d < D; ++d) - for (dim_t h = 0; h < H; ++h) - for (dim_t w = 0; w < W; ++w) { - const size_t s_off = data_offset(data_d, n, c, d, h, w); - data_t dd = diff_dst[data_offset(diff_data_d, n, c, d, h, w)]; - if (fuse_bn_relu && !ws[s_off]) - dd = 0; - - diff_gamma += (src[s_off] - v_mean) * dd; - diff_beta += dd; - } - diff_gamma *= sqrt_variance; - - if (diff_scaleshift) { - diff_scaleshift[diff_scaleshift_d.off(0, c)] = diff_gamma; - diff_scaleshift[diff_scaleshift_d.off(1, c)] = diff_beta; - } - - for (dim_t n = 0; n < N; ++n) - for (dim_t d = 0; d < D; ++d) - for (dim_t h = 0; h < H; ++h) - for (dim_t w = 0; w < W; ++w) { - const size_t s_off = data_offset(data_d, n, c, d, h, w); - const size_t dd_off = data_offset(diff_data_d, n, c, d, h, w); - data_t dd = diff_dst[dd_off]; - if (fuse_bn_relu && !ws[s_off]) - dd = 0; - - data_t v_diff_src = dd; - if (calculate_diff_stats) { - v_diff_src -= diff_beta/(D*W*H*N) + - (src[s_off] - v_mean) * - diff_gamma*sqrt_variance/(D*W*H*N); - } - v_diff_src *= gamma*sqrt_variance; - diff_src[dd_off] = v_diff_src; - } - }); -} - -template struct ref_batch_normalization_bwd_t; - -} -} -} - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.hpp deleted file mode 100644 index aa9f74125..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.hpp +++ /dev/null @@ -1,127 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#ifndef CPU_REF_BATCH_NORMALIZATION_HPP -#define CPU_REF_BATCH_NORMALIZATION_HPP - -#include - -#include "c_types_map.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -#include "cpu_batch_normalization_pd.hpp" -#include "cpu_primitive.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -template -struct ref_batch_normalization_fwd_t: public cpu_primitive_t { - struct pd_t: public cpu_batch_normalization_fwd_pd_t { - pd_t(engine_t *engine, const batch_normalization_desc_t *adesc, - const primitive_attr_t *attr, - const batch_normalization_fwd_pd_t *hint_fwd_pd) - : cpu_batch_normalization_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) - {} - - DECLARE_COMMON_PD_T("ref:any", ref_batch_normalization_fwd_t); - - status_t init() { - bool ok = true - && is_fwd() - && src_md()->data_type == data_type - && IMPLICATION(use_scaleshift(), - weights_md()->data_type == data_type::f32) - && (attr()->has_default_values() || with_relu_post_op()); - if (!ok) return status::unimplemented; - - if (src_md()->data_type == data_type::s8 && !stats_is_src()) - return status::unimplemented; - - if (is_training() && fuse_bn_relu()) init_default_ws(8); - - return status::success; - } - }; - - ref_batch_normalization_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} - - typedef typename prec_traits::type data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_forward(ctx); - return status::success; - } - -private: - void execute_forward(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } -}; - -template -struct ref_batch_normalization_bwd_t: public cpu_primitive_t { - struct pd_t: public cpu_batch_normalization_bwd_pd_t { - pd_t(engine_t *engine, const batch_normalization_desc_t *adesc, - const primitive_attr_t *attr, - const batch_normalization_fwd_pd_t *hint_fwd_pd) - : cpu_batch_normalization_bwd_pd_t(engine, adesc, attr, hint_fwd_pd) - {} - - DECLARE_COMMON_PD_T("ref:any", ref_batch_normalization_bwd_t); - - status_t init() { - bool ok = true - && is_bwd() - && utils::everyone_is(data_type, src_md()->data_type, - diff_src_md()->data_type) - && IMPLICATION(use_scaleshift(), utils::everyone_is(data_type, - weights_md()->data_type, - diff_weights_md()->data_type)) - && attr()->has_default_values(); - if (!ok) return status::unimplemented; - - if (fuse_bn_relu()) { - init_default_ws(8); - if (!compare_ws(hint_fwd_pd_)) - return status::unimplemented; - } - - return status::success; - } - }; - - ref_batch_normalization_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {} - typedef typename prec_traits::type data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_backward(ctx); - return status::success; - } - -private: - void execute_backward(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } -}; - -} -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_concat.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_concat.hpp deleted file mode 100644 index 4c534b550..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/ref_concat.hpp +++ /dev/null @@ -1,97 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. 
-* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef REF_CONCAT_HPP -#define REF_CONCAT_HPP - -#include "reorder_pd.hpp" - -#include "cpu_concat_pd.hpp" -#include "cpu_primitive.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -struct ref_concat_t: public cpu_primitive_t { - struct pd_t: public cpu_concat_pd_t { - using cpu_concat_pd_t::cpu_concat_pd_t; - - pd_t(const pd_t &rhs): cpu_concat_pd_t(rhs) { - for (size_t i = 0; i < rhs.reorder_pds_.size(); ++i) - reorder_pds_.push_back( - (const reorder_pd_t *)rhs.reorder_pds_[i]->clone()); - } - ~pd_t() { for (auto &rpd: reorder_pds_) delete rpd; } - - DECLARE_CONCAT_PD_T("ref:any", ref_concat_t); - - status_t init() { - bool ok = cpu_concat_pd_t::init() == status::success; - if (!ok) return status::unimplemented; - - for (int i = 0; i < n_; ++i) { - auto r_impls = engine_->get_reorder_implementation_list(); - for (auto r = r_impls; *r; ++r) { - const primitive_attr_t attr; /* alpha == 1. */ - reorder_pd_t *r_pd = nullptr; - if ((*r)(&r_pd, engine_, &attr, engine_, src_md(i), - engine_, src_image_md(i)) == status::success) { - r_pd->init_info(); - reorder_pds_.push_back(r_pd); - break; - } - } - } - - ok = reorder_pds_.size() == (size_t)n_; - return ok ? status::success : status::unimplemented; - } - - nstl::vector reorder_pds_; - }; - - ref_concat_t(const pd_t *apd): cpu_primitive_t(apd) { - const int n = pd()->n_inputs(); - reorders_.resize(n); - for (int i = 0; i < n; ++i) - pd()->reorder_pds_[i]->create_primitive(&reorders_[i]); - } - - ~ref_concat_t() { for (auto &r: reorders_) delete r; } - - virtual status_t execute(const exec_ctx_t &ctx) const override { - const auto n = pd()->n_inputs(); - for (int i = 0; i < n; ++i) { - exec_args_t r_args; - r_args[MKLDNN_ARG_SRC] = ctx.args().at(MKLDNN_ARG_MULTIPLE_SRC + i); - r_args[MKLDNN_ARG_DST] = ctx.args().at(MKLDNN_ARG_DST); - exec_ctx_t r_ctx(ctx.stream(), std::move(r_args)); - reorders_[i]->execute(r_ctx); - } - return status::success; - } - -private: - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - nstl::vector reorders_; -}; - -} -} -} - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_convolution.cpp deleted file mode 100644 index c0a979c4c..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/ref_convolution.cpp +++ /dev/null @@ -1,395 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#include "c_types_map.hpp" -#include "math_utils.hpp" -#include "mkldnn_thread.hpp" -#include "mkldnn_traits.hpp" -#include "type_helpers.hpp" - -#include "ref_convolution.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -using math::saturate; -using math::get_bias; - -template -void ref_convolution_fwd_t:: -execute_forward(const exec_ctx_t &ctx) const { - auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); - auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); - auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); - auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); - - const memory_desc_wrapper src_d(pd()->src_md()); - const memory_desc_wrapper dst_d(pd()->dst_md()); - const memory_desc_wrapper weights_d(pd()->weights_md(0)); - const memory_desc_wrapper bias_d(pd()->weights_md(1)); - - const bool with_groups = pd()->with_groups(); - - const int G = pd()->G(); - const int MB = pd()->MB(); - const int OD = pd()->OD(); - const int OH = pd()->OH(); - const int OW = pd()->OW(); - const int ID = pd()->ID(); - const int IH = pd()->IH(); - const int IW = pd()->IW(); - - const int OC = pd()->OC() / G; - const int IC = pd()->IC() / G; - const int KD = pd()->KD(); - const int KH = pd()->KH(); - const int KW = pd()->KW(); - - const int KSD = pd()->KSD(); - const int KSH = pd()->KSH(); - const int KSW = pd()->KSW(); - - const int KDD = pd()->KDD(); - const int KDH = pd()->KDH(); - const int KDW = pd()->KDW(); - - const int padFront = pd()->padFront(); - const int padT = pd()->padT(); - const int padL = pd()->padL(); - - const bool with_relu = 0; // TODO: change if support post_ops - const float nslope = 0.f; - - const int ndims = pd()->desc()->src_desc.ndims; - - auto ker = [=](int g, int mb, int oc, int od, int oh, - int ow) { - acc_data_t d = 0; - for (int ic = 0; ic < IC; ++ic) - for (int kd = 0; kd < KD; ++kd) - for (int kh = 0; kh < KH; ++kh) - for (int kw = 0; kw < KW; ++kw) { - const int id = od * KSD - padFront + kd * (1 + KDD); - const int ih = oh * KSH - padT + kh * (1 + KDH); - const int iw = ow * KSW - padL + kw * (1 + KDW); - - if (id < 0 || id >= ID) continue; - if (ih < 0 || ih >= IH) continue; - if (iw < 0 || iw >= IW) continue; - - if (ndims == 5) - d += (acc_data_t)src[src_d.off(mb, g*IC + ic, id, ih, iw)] - * (with_groups - ? weights[weights_d.off(g, oc, ic, kd, kh, kw)] - : weights[weights_d.off(oc, ic, kd, kh, kw)]); - else if (ndims == 4) - d += (acc_data_t)src[src_d.off(mb, g*IC + ic, ih, iw)] - * (with_groups - ? weights[weights_d.off(g, oc, ic, kh, kw)] - : weights[weights_d.off(oc, ic, kh, kw)]); - else if (ndims == 3) - d += (acc_data_t)src[src_d.off(mb, g*IC + ic, iw)] - * (with_groups - ? weights[weights_d.off(g, oc, ic, kw)] - : weights[weights_d.off(oc, ic, kw)]); - else - assert(false); - - } - return d; - }; - - parallel_nd(G, MB, OC, OD, OH, OW, - [&](int g, int mb, int oc, int od, int oh, int ow) { - float a = bias - ? 
get_bias(bias, bias_d.off(g * OC + oc), - pd()->desc()->bias_desc.data_type) - : 0; - a += ker(g, mb, oc, od, oh, ow); - if (with_relu && a < 0) - a = a * nslope; - if (ndims == 5) - dst[dst_d.off(mb, g*OC + oc, od, oh, ow)] = saturate(a); - else if (ndims == 4) - dst[dst_d.off(mb, g*OC + oc, oh, ow)] = saturate(a); - else if (ndims == 3) - dst[dst_d.off(mb, g*OC + oc, ow)] = saturate(a); - else - assert(false); - }); -} - -template -void ref_convolution_bwd_data_t::execute_backward_data(const exec_ctx_t &ctx) const { - auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST); - auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); - auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); - auto diff_src = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC); - - const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); - const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); - const memory_desc_wrapper weights_d(pd()->weights_md(0)); - const memory_desc_wrapper bias_d(pd()->weights_md(1)); - - const bool with_groups = pd()->with_groups(); - - const int G = pd()->G(); - const int MB = pd()->MB(); - const int OD = pd()->OD(); - const int OH = pd()->OH(); - const int OW = pd()->OW(); - const int ID = pd()->ID(); - const int IH = pd()->IH(); - const int IW = pd()->IW(); - - const int OC = pd()->OC() / G; - const int IC = pd()->IC() / G; - const int KD = pd()->KD(); - const int KH = pd()->KH(); - const int KW = pd()->KW(); - - const int KSD = pd()->KSD(); - const int KSH = pd()->KSH(); - const int KSW = pd()->KSW(); - - const int KDD = pd()->KDD(); - const int KDH = pd()->KDH(); - const int KDW = pd()->KDW(); - - const int padFront = pd()->padFront(); - const int padT = pd()->padT(); - const int padL = pd()->padL(); - - const int ndims = pd()->desc()->diff_src_desc.ndims; - - auto ker = [=](int g, int mb, int ic, int id, int ih, - int iw) { - acc_data_t d = 0; - for (int oc = 0; oc < OC; ++oc) - for (int kd = 0; kd < KD; ++kd) - for (int kh = 0; kh < KH; ++kh) - for (int kw = 0; kw < KW; ++kw) { - if (iw + padL < kw * (1 + KDW) - || ih + padT < kh * (1 + KDH) - || id + padFront < kd * (1 + KDD)) - continue; - int ow = iw - kw * (1 + KDW) + padL; - int oh = ih - kh * (1 + KDH) + padT; - int od = id - kd * (1 + KDD) + padFront; - if (ow % KSW != 0 || oh % KSH != 0 || od % KSD != 0) - continue; - - ow /= KSW; - oh /= KSH; - od /= KSD; - - if (od < OD && oh < OH && ow < OW) { - if (ndims == 5) - d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC - + oc, od, oh, ow)] * (with_groups - ? weights[weights_d.off(g, oc, ic, kd, kh, kw)] - : weights[weights_d.off(oc, ic, kd, kh, kw)]); - else if (ndims == 4) - d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC - + oc, oh, ow)] * (with_groups - ? weights[weights_d.off(g, oc, ic, kh, kw)] - : weights[weights_d.off(oc, ic, kh, kw)]); - else if (ndims == 3) - d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC - + oc, ow)] * (with_groups - ? weights[weights_d.off(g, oc, ic, kw)] - : weights[weights_d.off(oc, ic, kw)]); - else - assert(false); - } - } - return d; - }; - - parallel_nd(G, MB, IC, ID, IH, IW, - [&](int g, int mb, int ic, int id, int ih, int iw) { - auto ds_idx = (ndims == 5) - ? diff_src_d.off(mb, g*IC + ic, id, ih, iw) - : (ndims == 4) - ? diff_src_d.off(mb, g*IC + ic, ih, iw) - : diff_src_d.off(mb, g*IC + ic, iw); - float a = bias - ? 
get_bias(bias, bias_d.off(g * IC + ic), - pd()->desc()->bias_desc.data_type) - : 0; - a += ker(g, mb, ic, id, ih, iw); - diff_src[ds_idx] = saturate(a); - }); -} - -template -void ref_convolution_bwd_weights_t::execute_backward_weights(const exec_ctx_t &ctx) const { - auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST); - auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); - auto diff_weights = CTX_OUT_MEM(diff_wei_data_t *, MKLDNN_ARG_DIFF_WEIGHTS); - auto diff_bias = CTX_OUT_MEM(diff_wei_data_t *, MKLDNN_ARG_DIFF_BIAS); - - const memory_desc_wrapper src_d(pd()->src_md()); - const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); - const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); - const memory_desc_wrapper diff_bias_d(pd()->diff_weights_md(1)); - - const bool with_groups = pd()->with_groups(); - - const int G = pd()->G(); - const int MB = pd()->MB(); - const int OD = pd()->OD(); - const int OH = pd()->OH(); - const int OW = pd()->OW(); - const int ID = pd()->ID(); - const int IH = pd()->IH(); - const int IW = pd()->IW(); - - const int OC = pd()->OC() / G; - const int IC = pd()->IC() / G; - const int KD = pd()->KD(); - const int KH = pd()->KH(); - const int KW = pd()->KW(); - - const int KSD = pd()->KSD(); - const int KSH = pd()->KSH(); - const int KSW = pd()->KSW(); - - const int KDD = pd()->KDD(); - const int KDH = pd()->KDH(); - const int KDW = pd()->KDW(); - - const int padFront = pd()->padFront(); - const int padT = pd()->padT(); - const int padL = pd()->padL(); - - const int ndims = pd()->desc()->src_desc.ndims; - -auto ker = [=](acc_data_t &d, int g, int oc, int ic, int kd, int kh, int kw) { - for (int mb = 0; mb < MB; ++mb) - for (int od = 0; od < OD; ++od) - for (int oh = 0; oh < OH; ++oh) - for (int ow = 0; ow < OW; ++ow) { - if (ow*KSW + kw * (1 + KDW) < padL - || oh*KSH + kh * (1 + KDH) < padT - || od*KSD + kd * (1 + KDD) < padFront - || ow*KSW + kw * (1 + KDW) >= IW + padL - || oh*KSH + kh * (1 + KDH) >= IH + padT - || od*KSD + kd * (1 + KDD) >= ID + padFront) - continue; - - int id = od*KSD - padFront + kd * (1 + KDD); - int ih = oh*KSH - padT + kh * (1 + KDH); - int iw = ow*KSW - padL + kw * (1 + KDW); - if (ndims == 5) - d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, od, - oh, ow)] * src[src_d.off(mb, g*IC + ic, id, ih, iw)]; - else if (ndims == 4) - d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, oh, ow)] - * src[src_d.off(mb, g*IC + ic, ih, iw)]; - else if (ndims == 3) - d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, ow)] - * src[src_d.off(mb, g*IC + ic, iw)]; - else - assert(false); - } - }; - - auto ker_bias = [=](acc_data_t &d, int g, int oc) { - for (int mb = 0; mb < MB; ++mb) - for (int od = 0; od < OD; ++od) - for (int oh = 0; oh < OH; ++oh) - for (int ow = 0; ow < OW; ++ow) { - if (ndims == 5) - d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, od, oh, - ow)]; - else if (ndims == 4) - d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, oh, - ow)]; - else if (ndims == 3) - d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, ow)]; - else - assert(false); - } - }; - - parallel_nd(G, OC, [&](int g, int oc) { - if (diff_bias) { - // XXX: loss of precision when bias is a float... 
- acc_data_t db = 0; - ker_bias(db, g, oc); - diff_bias[diff_bias_d.off(g*OC+oc)] - = saturate(db); - } - - for (int ic = 0; ic < IC; ++ic) - for (int kd = 0; kd < KD; ++kd) - for (int kh = 0; kh < KH; ++kh) - for (int kw = 0; kw < KW; ++kw) { - acc_data_t dw = 0; - ker(dw, g, oc, ic, kd, kh, kw); - - if (ndims == 5) { - auto idx = with_groups - ? diff_weights_d.off(g, oc, ic, kd, kh, kw) - : diff_weights_d.off(oc, ic, kd, kh, kw); - diff_weights[idx] = saturate(dw); - } else if (ndims == 4) { - auto idx = with_groups - ? diff_weights_d.off(g, oc, ic, kh, kw) - : diff_weights_d.off(oc, ic, kh, kw); - diff_weights[idx] = saturate(dw); - } else if (ndims == 3) { - auto idx = with_groups - ? diff_weights_d.off(g, oc, ic, kw) - : diff_weights_d.off(oc, ic, kw); - diff_weights[idx] = saturate(dw); - } else { - assert(false); - } - } - }); -} - -using namespace data_type; - -template struct ref_convolution_fwd_t; - -template struct ref_convolution_fwd_t; -template struct ref_convolution_fwd_t; -template struct ref_convolution_fwd_t; -template struct ref_convolution_fwd_t; - -template struct ref_convolution_bwd_data_t; - -template struct ref_convolution_bwd_data_t; -template struct ref_convolution_bwd_data_t; -template struct ref_convolution_bwd_data_t; -template struct ref_convolution_bwd_data_t; - -template struct ref_convolution_bwd_weights_t; - -} -} -} - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_convolution.hpp deleted file mode 100644 index 7c83d0c6d..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/ref_convolution.hpp +++ /dev/null @@ -1,194 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef CPU_REF_CONVOLUTION_HPP -#define CPU_REF_CONVOLUTION_HPP - -#include - -#include "c_types_map.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -#include "cpu_convolution_pd.hpp" -#include "cpu_primitive.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -template -struct ref_convolution_fwd_t: public cpu_primitive_t { - struct pd_t: public cpu_convolution_fwd_pd_t { - using cpu_convolution_fwd_pd_t::cpu_convolution_fwd_pd_t; - - DECLARE_COMMON_PD_T("ref:any", ref_convolution_fwd_t); - - status_t init() { - using namespace data_type; - - bool ok = true - && is_fwd() - && set_default_alg_kind(alg_kind::convolution_direct) - && expect_data_types(src_type, wei_type, data_type::undef, - dst_type, acc_type) - && IMPLICATION(with_bias(), true - && IMPLICATION(src_type == u8, - utils::one_of(bias_md_.data_type, f32, s32, s8, u8)) - && IMPLICATION(src_type == f32, - bias_md_.data_type == f32)) - && set_default_formats() - && attr()->has_default_values(); - return ok ? 
status::success : status::unimplemented; - } - - protected: - bool set_default_formats() { - using namespace format_tag; - auto dat_tag = utils::pick(ndims() - 3, ncw, nchw, ncdhw); - auto wei_tag = with_groups() - ? utils::pick(ndims() - 3, goiw, goihw, goidhw) - : utils::pick(ndims() - 3, oiw, oihw, oidhw); - return set_default_formats_common(dat_tag, wei_tag, dat_tag); - } - }; - - ref_convolution_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} - - typedef typename prec_traits::type src_data_t; - typedef typename prec_traits::type wei_data_t; - typedef typename prec_traits::type dst_data_t; - typedef typename prec_traits::type acc_data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_forward(ctx); - return status::success; - } - -private: - void execute_forward(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } -}; - -template -struct ref_convolution_bwd_data_t: public cpu_primitive_t { - struct pd_t: public cpu_convolution_bwd_data_pd_t { - using cpu_convolution_bwd_data_pd_t::cpu_convolution_bwd_data_pd_t; - - DECLARE_COMMON_PD_T("ref:any", ref_convolution_bwd_data_t); - - status_t init() { - bool ok = true - && desc()->prop_kind == prop_kind::backward_data - && set_default_alg_kind(alg_kind::convolution_direct) - && expect_data_types(diff_src_type, wei_type, data_type::undef, - diff_dst_type, acc_type) - && set_default_formats() - && attr()->has_default_values(); - - return ok ? status::success : status::unimplemented; - } - - virtual bool support_bias() const override { return true; } - - protected: - bool set_default_formats() { - using namespace format_tag; - auto dat_tag = utils::pick(ndims() - 3, ncw, nchw, ncdhw); - auto wei_tag = with_groups() - ? utils::pick(ndims() - 3, goiw, goihw, goidhw) - : utils::pick(ndims() - 3, oiw, oihw, oidhw); - return set_default_formats_common(dat_tag, wei_tag, dat_tag); - } - }; - - ref_convolution_bwd_data_t(const pd_t *apd): cpu_primitive_t(apd) {} - - typedef typename prec_traits::type diff_src_data_t; - typedef typename prec_traits::type wei_data_t; - typedef typename prec_traits::type diff_dst_data_t; - typedef typename prec_traits::type acc_data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_backward_data(ctx); - return status::success; - } - -private: - void execute_backward_data(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } -}; - -template -struct ref_convolution_bwd_weights_t: public cpu_primitive_t { - struct pd_t: public cpu_convolution_bwd_weights_pd_t { - using cpu_convolution_bwd_weights_pd_t::cpu_convolution_bwd_weights_pd_t; - - DECLARE_COMMON_PD_T("ref:any", ref_convolution_bwd_weights_t); - - status_t init() { - bool ok = true - && desc()->prop_kind == prop_kind::backward_weights - && set_default_alg_kind(alg_kind::convolution_direct) - && expect_data_types(src_type, diff_wei_type, diff_wei_type, - diff_dst_type, acc_type) - && set_default_formats() - && attr()->has_default_values(); - return ok ? status::success : status::unimplemented; - } - - protected: - bool set_default_formats() { - using namespace format_tag; - auto dat_tag = utils::pick(ndims() - 3, ncw, nchw, ncdhw); - auto wei_tag = with_groups() - ? 
utils::pick(ndims() - 3, goiw, goihw, goidhw) - : utils::pick(ndims() - 3, oiw, oihw, oidhw); - return set_default_formats_common(dat_tag, wei_tag, dat_tag); - } - }; - - ref_convolution_bwd_weights_t(const pd_t *apd): cpu_primitive_t(apd) {} - - typedef typename prec_traits::type src_data_t; - typedef typename prec_traits::type diff_wei_data_t; - typedef typename prec_traits::type diff_dst_data_t; - typedef typename prec_traits::type acc_data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_backward_weights(ctx); - return status::success; - } - -private: - void execute_backward_weights(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } -}; - -} -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.cpp deleted file mode 100644 index 541a303aa..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.cpp +++ /dev/null @@ -1,199 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include "c_types_map.hpp" -#include "type_helpers.hpp" -#include "mkldnn_thread.hpp" -#include "mkldnn_traits.hpp" -#include "math_utils.hpp" - -#include "ref_deconvolution.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -void ref_deconvolution_fwd_t::compute_fwd_bias(const data_t *bias, - data_t *dst) const { - const memory_desc_wrapper dst_d(pd()->dst_md()); - - const int G = pd()->G(); - const int MB = pd()->MB(); - const int OH = pd()->OH(); - const int OW = pd()->OW(); - const int OD = pd()->OD(); - const int OC = pd()->OC() / G; - const int ndims = pd()->desc()->src_desc.ndims; - - parallel_nd(MB, G, OC, OD, OH, OW, - [&](int mb, int g, int oc, int od, int oh, int ow) { - auto b = bias[g * OC + oc]; - switch (ndims) { - case 5: dst[dst_d.off(mb, g * OC + oc, od, oh, ow)] += b; break; - case 4: dst[dst_d.off(mb, g * OC + oc, oh, ow)] += b; break; - case 3: dst[dst_d.off(mb, g * OC + oc, ow)] += b; break; - default: assert(!"invalid dimension size"); - } - }); -} - -void ref_deconvolution_fwd_t::compute_fwd_bias_ncdhw(const data_t *bias, - data_t *dst) const { - const memory_desc_wrapper dst_d(pd()->dst_md()); - - const int MB = pd()->MB(); - const int OC = pd()->OC(); - const int SP = pd()->OW()*pd()->OH()*pd()->OD(); - - parallel_nd(MB, OC, [&](int mb, int oc) { - PRAGMA_OMP_SIMD() - for (int sp = 0; sp < SP; ++sp) { - auto offset = (size_t)(mb * OC + oc) * SP + sp; - dst[offset] += bias[oc]; - } - }); -} - -template -void ref_deconvolution_fwd_t::compute_fwd_bias_nCdhwXc(const data_t *bias, - data_t *dst) const { - const memory_desc_wrapper dst_d(pd()->dst_md()); - - const int MB = pd()->MB(); - const int OC = pd()->OC(); - const int SP = pd()->OW() * pd()->OH() * pd()->OD(); - 
- const ptrdiff_t stride_mb = dst_d.blocking_desc().strides[0]; - - parallel_nd(MB, utils::div_up(OC, blksize), SP, - [&](int mb, int oc_blk, int sp) { - int oc = oc_blk * blksize; - auto offset = mb * stride_mb + oc * SP + sp * blksize; - const int blk = nstl::min(blksize, OC - oc); - - PRAGMA_OMP_SIMD() - for (int i = 0; i < blk; ++i) - dst[offset + i] += bias[oc + i]; - }); -} - -void ref_deconvolution_bwd_weights_t::compute_bwd_bias(const data_t *diff_dst, - data_t *diff_bias) const { - const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); - - const int G = pd()->G(); - const int MB = pd()->MB(); - const int OH = pd()->OH(); - const int OW = pd()->OW(); - const int OC = pd()->OC() / G; - const int OD = pd()->OD(); - const int ndims = pd()->desc()->src_desc.ndims; - - parallel_nd(G, OC, [&](int g, int oc) { - data_t db = 0; - for (int mb = 0; mb < MB; ++mb) { - for (int od = 0; od < OD; ++od) { - for (int oh = 0; oh < OH; ++oh) { - for (int ow = 0; ow < OW; ++ow) { - switch (ndims) { - case 5: - db += diff_dst[diff_dst_d.off( - mb, g * OC + oc, od, oh, ow)]; - break; - case 4: - db += diff_dst[diff_dst_d.off( - mb, g * OC + oc, oh, ow)]; - break; - case 3: - db += diff_dst[diff_dst_d.off(mb, g * OC + oc, ow)]; - break; - default: assert(!"invalid dimension size"); - } - } - } - } - } - diff_bias[g * OC + oc] = db; - }); -} - -void ref_deconvolution_bwd_weights_t::compute_bwd_bias_ncdhw( - const data_t *diff_dst, data_t *diff_bias) const { - const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); - - const int OC = pd()->OC(); - const int MB = pd()->MB(); - const int SP = pd()->OH()*pd()->OW()*pd()->OD(); - - parallel_nd(OC, [&](int oc) { - data_t db = 0; - for (int mb = 0; mb < MB; ++mb) { - PRAGMA_OMP_SIMD() - for (int sp = 0; sp < SP; ++sp) { - auto offset = (size_t)(mb * OC + oc) * SP + sp; - db += diff_dst[offset]; - } - } - diff_bias[oc] = db; - }); -} - -template -void ref_deconvolution_bwd_weights_t::compute_bwd_bias_nCdhwXc( - const data_t *diff_dst, data_t *diff_bias) const { - const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); - - const int OC = pd()->OC(); - const int MB = pd()->MB(); - const int SP = pd()->OH() * pd()->OW() * pd()->OD(); - - const ptrdiff_t stride_mb = diff_dst_d.blocking_desc().strides[0]; - - parallel_nd(utils::div_up(OC, blksize), [&](int ocb) { - data_t db[blksize] = {0}; - - for (int mb = 0; mb < MB; ++mb) { - for (int sp = 0; sp < SP; ++sp) { - auto offset = mb * stride_mb + (ocb * SP + sp) * blksize; - - PRAGMA_OMP_SIMD() - for (int i = 0; i < blksize; ++i) - db[i] += diff_dst[offset+i]; - } - } - - const int blk = nstl::min(blksize, OC - ocb * blksize); - - PRAGMA_OMP_SIMD() - for (int i = 0; i < blk; ++i) - diff_bias[ocb * blksize + i] = db[i]; - }); -} - -template void ref_deconvolution_fwd_t::compute_fwd_bias_nCdhwXc<8>( - const data_t *diff_dst, data_t *diff_bias) const; -template void ref_deconvolution_fwd_t::compute_fwd_bias_nCdhwXc<16>( - const data_t *diff_dst, data_t *diff_bias) const; -template void ref_deconvolution_bwd_weights_t::compute_bwd_bias_nCdhwXc<8>( - const data_t *diff_dst, data_t *diff_bias) const; -template void ref_deconvolution_bwd_weights_t::compute_bwd_bias_nCdhwXc<16>( - const data_t *diff_dst, data_t *diff_bias) const; - -} -} -} - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.hpp deleted file mode 100644 index d61903c32..000000000 --- 
a/thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.hpp +++ /dev/null @@ -1,502 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef CPU_REF_DECONVOLUTION_HPP -#define CPU_REF_DECONVOLUTION_HPP - -#include -#include - -#include "c_types_map.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" -#include "primitive_iterator.hpp" - -#include "cpu_convolution_pd.hpp" -#include "cpu_deconvolution_pd.hpp" -#include "cpu_primitive.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -static status_t compute_blocked_format(bool with_groups, - const memory_desc_t *oi_md, memory_desc_t *io_md) -{ - /* Computes blocking for *i*o* format from *o*i* format */ - - bool sanity_check_ok = true - && oi_md->ndims == io_md->ndims - && oi_md->format_kind == format_kind::blocked; - if (!sanity_check_ok) return status::invalid_arguments; - - const blocking_desc_t &oi_blk = oi_md->format_desc.blocking; - blocking_desc_t io_blk = io_md->format_desc.blocking; - - io_md->format_kind = format_kind::blocked; - io_blk = oi_blk; - - const int ID_OC = 0 + with_groups; - const int ID_IC = 1 + with_groups; - - nstl::swap(io_blk.strides[ID_OC], io_blk.strides[ID_IC]); - for (int i_blk = 0; i_blk < io_blk.inner_nblks; ++i_blk) { - if (utils::one_of(io_blk.inner_idxs[i_blk], ID_OC, ID_IC)) { - io_blk.inner_idxs[i_blk] = - (io_blk.inner_idxs[i_blk] == ID_OC ? ID_IC : ID_OC); - } - } - - return memory_desc_init_by_blocking_desc(*io_md, io_blk); -} - -static status_t conv_descr_create(const deconvolution_desc_t *dd, - convolution_desc_t *cd) -{ - using namespace prop_kind; - alg_kind_t alg_kind = dd->alg_kind == alg_kind::deconvolution_direct - ? 
alg_kind::convolution_direct : alg_kind::convolution_winograd; - - const memory_desc_t *src_md, *dst_md, *d_weights_d; - prop_kind_t prop_kind; - memory_desc_t c_weights_d; - if (utils::one_of(dd->prop_kind, forward_training, forward_inference)) { - prop_kind = backward_data; - src_md = &dd->dst_desc; - dst_md = &dd->src_desc; - d_weights_d = &dd->weights_desc; - } else if (dd->prop_kind == backward_data) { - prop_kind = forward_training; - src_md = &dd->diff_dst_desc; - dst_md = &dd->diff_src_desc; - d_weights_d = &dd->weights_desc; - } else { - prop_kind = dd->prop_kind; - src_md = &dd->diff_dst_desc; - dst_md = &dd->src_desc; - d_weights_d = &dd->diff_weights_desc; - } - - const bool with_groups = d_weights_d->ndims == src_md->ndims + 1; - - /* create weights desc for convolution */ - c_weights_d = *d_weights_d; - - const int ID_OC = 0 + with_groups; - const int ID_IC = 1 + with_groups; - - nstl::swap(c_weights_d.dims[ID_OC], c_weights_d.dims[ID_IC]); - nstl::swap(c_weights_d.padded_dims[ID_OC], c_weights_d.padded_dims[ID_IC]); - nstl::swap(c_weights_d.padded_offsets[ID_OC], c_weights_d.padded_offsets[ID_IC]); - - if (c_weights_d.format_kind != format_kind::any) - CHECK(compute_blocked_format(with_groups, d_weights_d, &c_weights_d)); - - return conv_desc_init(cd, prop_kind, alg_kind, src_md, &c_weights_d, - prop_kind != backward_weights ? &dd->bias_desc : nullptr, - dst_md, dd->strides, dd->dilates, - dd->padding[0], dd->padding[1], dd->padding_kind); -} - -struct ref_deconvolution_fwd_t: public cpu_primitive_t { - struct pd_t: public cpu_deconvolution_fwd_pd_t { - pd_t(engine_t *engine, - const deconvolution_desc_t *adesc, - const primitive_attr_t *attr, - const deconvolution_fwd_pd_t *hint_fwd_pd) - : cpu_deconvolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) - , conv_pd_(nullptr) - {} - - pd_t(const pd_t &other) - : cpu_deconvolution_fwd_pd_t(other) - , conv_pd_(other.conv_pd_->clone()) - , conv_supports_bias_(other.conv_supports_bias_) - , dst_tag_(other.dst_tag_) - {} - - ~pd_t() { delete conv_pd_; } - - DECLARE_COMMON_PD_T(conv_pd_->name(), ref_deconvolution_fwd_t); - - status_t init_convolution() { - using namespace types; - - convolution_desc_t cd; - CHECK(conv_descr_create(desc(), &cd)); - - mkldnn_primitive_desc_iterator it(engine_, (op_desc_t *)&cd, - &attr_, nullptr); - while (++it != it.end()) { - conv_pd_ = *it; - conv_supports_bias_ = - static_cast(conv_pd_) - ->support_bias(); - bool output_f32 = utils::everyone_is(data_type::f32, - desc()->accum_data_type, desc()->dst_desc.data_type); - - bool ok = true - && conv_pd_->weights_md()->extra.flags == 0 - /* deconv reference code can process only f32 bias */ - && IMPLICATION(with_bias(), - conv_supports_bias_ || output_f32); - if (ok) return status::success; - - delete conv_pd_; - } - conv_pd_ = nullptr; - return status::unimplemented; - } - - status_t init() { - using namespace format_tag; - bool ok = true - && is_fwd() - && utils::one_of(desc()->alg_kind, - alg_kind::deconvolution_direct, - alg_kind::deconvolution_winograd) - && attr()->post_ops_.has_default_values(); - - if (ok) { - CHECK(init_convolution()); - if (weights_md_.format_kind == format_kind::any) { - CHECK(compute_blocked_format(with_groups(), - conv_pd_->weights_md(), &desc_.weights_desc)); - weights_md_ = desc_.weights_desc; - } - if (src_md_.format_kind == format_kind::any) - src_md_ = *conv_pd_->diff_dst_md(); - if (dst_md_.format_kind == format_kind::any) - dst_md_ = *conv_pd_->diff_src_md(); - if (bias_md_.format_kind == format_kind::any) - 
CHECK(memory_desc_init_by_tag(bias_md_, x)); - - dst_tag_ = memory_desc_matches_one_of_tag(dst_md_, - utils::pick(ndims() - 3, ncw, nchw, ncdhw), - utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c), - utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c)); - - return status::success; - } - - return status::unimplemented; - } - - virtual void init_scratchpad_md() override { - scratchpad_md_ = *conv_pd_->scratchpad_md(); - } - - primitive_desc_t *conv_pd_; - bool conv_supports_bias_; - format_tag_t dst_tag_; - }; - - typedef typename prec_traits::type data_t; - - ref_deconvolution_fwd_t(const pd_t *apd): cpu_primitive_t(apd) - { pd()->conv_pd_->create_primitive((primitive_t **)&conv_p_); } - ~ref_deconvolution_fwd_t() { delete conv_p_; } - - virtual status_t execute(const exec_ctx_t &ctx) const override { - const auto &args = ctx.args(); - exec_args_t conv_args; - conv_args[MKLDNN_ARG_DIFF_DST] = args.at(MKLDNN_ARG_SRC); - conv_args[MKLDNN_ARG_WEIGHTS] = args.at(MKLDNN_ARG_WEIGHTS); - if (pd()->with_bias() && pd()->conv_supports_bias_) - conv_args[MKLDNN_ARG_BIAS] = args.at(MKLDNN_ARG_BIAS); - conv_args[MKLDNN_ARG_DIFF_SRC] = args.at(MKLDNN_ARG_DST); - if (!types::is_zero_md(pd()->scratchpad_md())) - conv_args[MKLDNN_ARG_SCRATCHPAD] = args.at(MKLDNN_ARG_SCRATCHPAD); - const exec_ctx_t conv_ctx(ctx.stream(), std::move(conv_args)); - - conv_p_->execute(conv_ctx); - - if (pd()->with_bias() && !pd()->conv_supports_bias_) { - using namespace format_tag; - - auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS); - auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); - - switch (pd()->dst_tag_) { - case ncdhw: case nchw: case ncw: - compute_fwd_bias_ncdhw(bias, dst); - break; - case nCdhw8c: case nChw8c: case nCw8c: - compute_fwd_bias_nCdhwXc<8>(bias, dst); - break; - case nCdhw16c: case nChw16c: case nCw16c: - compute_fwd_bias_nCdhwXc<16>(bias, dst); - break; - default: - compute_fwd_bias(bias, dst); - break; - } - } - return status::success; - } - -private: - void compute_fwd_bias(const data_t *bias, data_t *dst) const; - void compute_fwd_bias_ncdhw(const data_t *bias, data_t *dst) const; - template void compute_fwd_bias_nCdhwXc(const data_t *bias, - data_t *dst) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - primitive_t *conv_p_; -}; - -struct ref_deconvolution_bwd_data_t: public cpu_primitive_t { - struct pd_t: public cpu_deconvolution_bwd_data_pd_t { - pd_t(engine_t *engine, const deconvolution_desc_t *adesc, - const primitive_attr_t *attr, - const deconvolution_fwd_pd_t *hint_fwd_pd) - : cpu_deconvolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) - , conv_pd_(nullptr) - {} - - pd_t(const pd_t &other) - : cpu_deconvolution_bwd_data_pd_t(other) - , conv_pd_(other.conv_pd_->clone()) {} - - ~pd_t() { delete conv_pd_; } - - DECLARE_COMMON_PD_T(conv_pd_->name(), ref_deconvolution_bwd_data_t); - - status_t init_convolution() { - using namespace types; - - convolution_desc_t cd; - status_t status = conv_descr_create(desc(), &cd); - if (status != status::success) return status; - - mkldnn_primitive_desc_iterator it(engine_, (op_desc_t *)&cd, - &attr_, nullptr); - while (++it != it.end()) { - conv_pd_ = *it; - if (conv_pd_->weights_md()->extra.flags == 0) - return status::success; - delete conv_pd_; - } - - return status::unimplemented; - } - - status_t init() { - using namespace data_type; - bool ok = true - && desc()->prop_kind == prop_kind::backward_data - && utils::everyone_is(data_type::f32, - desc()->diff_src_desc.data_type, - desc()->weights_desc.data_type, - 
desc()->diff_dst_desc.data_type) - && utils::one_of(desc()->alg_kind, - alg_kind::deconvolution_direct, - alg_kind::deconvolution_winograd); - - if (ok) { - CHECK(init_convolution()); - if (weights_md_.format_kind == format_kind::any) { - CHECK(compute_blocked_format(with_groups(), - conv_pd_->weights_md(), &desc_.weights_desc)); - weights_md_ = desc_.weights_desc; - } - if (diff_src_md_.format_kind == format_kind::any) - diff_src_md_ = *conv_pd_->dst_md(); - if (diff_dst_md_.format_kind == format_kind::any) - diff_dst_md_ = *conv_pd_->src_md(); - - return status::success; - } - - return status::unimplemented; - } - - virtual void init_scratchpad_md() override { - scratchpad_md_ = *conv_pd_->scratchpad_md(); - } - - primitive_desc_t *conv_pd_; - }; - - typedef typename prec_traits::type data_t; - - ref_deconvolution_bwd_data_t(const pd_t *apd): cpu_primitive_t(apd) - { pd()->conv_pd_->create_primitive((primitive_t **)&conv_p_); } - ~ref_deconvolution_bwd_data_t() { delete conv_p_; } - - virtual status_t execute(const exec_ctx_t &ctx) const override { - const auto &args = ctx.args(); - exec_args_t conv_args; - conv_args[MKLDNN_ARG_SRC] = args.at(MKLDNN_ARG_DIFF_DST); - conv_args[MKLDNN_ARG_WEIGHTS] = args.at(MKLDNN_ARG_WEIGHTS); - conv_args[MKLDNN_ARG_DST] = args.at(MKLDNN_ARG_DIFF_SRC); - if (!types::is_zero_md(pd()->scratchpad_md())) - conv_args[MKLDNN_ARG_SCRATCHPAD] = args.at(MKLDNN_ARG_SCRATCHPAD); - const exec_ctx_t conv_ctx(ctx.stream(), std::move(conv_args)); - - conv_p_->execute(conv_ctx); - return status::success; - } - -private: - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - primitive_t *conv_p_; -}; - -struct ref_deconvolution_bwd_weights_t: public cpu_primitive_t { - struct pd_t: public cpu_deconvolution_bwd_weights_pd_t { - pd_t(engine_t *engine, - const deconvolution_desc_t *adesc, - const primitive_attr_t *attr, - const deconvolution_fwd_pd_t *hint_fwd_pd) - : cpu_deconvolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd) - , conv_pd_(nullptr) - {} - - pd_t(const pd_t &other) - : cpu_deconvolution_bwd_weights_pd_t(other) - , conv_pd_(other.conv_pd_->clone()) - , dst_tag_(other.dst_tag_) - {} - - ~pd_t() { delete conv_pd_; } - - DECLARE_COMMON_PD_T(conv_pd_->name(), ref_deconvolution_bwd_weights_t); - - status_t init_convolution() { - using namespace types; - - convolution_desc_t cd; - status_t status = conv_descr_create(desc(), &cd); - if (status != status::success) return status; - - mkldnn_primitive_desc_iterator it(engine_, (op_desc_t *)&cd, - &attr_, nullptr); - while (++it != it.end()) { - conv_pd_ = *it; - if (conv_pd_->diff_weights_md()->extra.flags == 0) - return status::success; - delete conv_pd_; - } - return status::unimplemented; - } - - status_t init() { - using namespace format_tag; - bool ok = true - && desc()->prop_kind == prop_kind::backward_weights - && utils::everyone_is(data_type::f32, - desc()->src_desc.data_type, - desc()->diff_weights_desc.data_type, - desc()->diff_dst_desc.data_type) - && utils::one_of(desc()->alg_kind, - alg_kind::deconvolution_direct, - alg_kind::deconvolution_winograd) - && attr()->has_default_values(); - if (ok) { - CHECK(init_convolution()); - if (diff_weights_md_.format_kind == format_kind::any) { - CHECK(compute_blocked_format(with_groups(), - conv_pd_->diff_weights_md(), - &desc_.diff_weights_desc)); - diff_weights_md_ = desc_.diff_weights_desc; - } - if (src_md_.format_kind == format_kind::any) - src_md_ = *conv_pd_->diff_dst_md(); - if (diff_dst_md_.format_kind == format_kind::any) - 
diff_dst_md_ = *conv_pd_->src_md(); - if (diff_bias_md_.format_kind == format_kind::any) - CHECK(memory_desc_init_by_tag(diff_bias_md_, x)); - - dst_tag_ = memory_desc_matches_one_of_tag(diff_dst_md_, - utils::pick(ndims() - 3, ncw, nchw, ncdhw), - utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c), - utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c)); - - return status::success; - } - - return status::unimplemented; - } - - virtual void init_scratchpad_md() override { - scratchpad_md_ = *conv_pd_->scratchpad_md(); - } - - primitive_desc_t *conv_pd_; - format_tag_t dst_tag_; - }; - - typedef typename prec_traits::type data_t; - - ref_deconvolution_bwd_weights_t(const pd_t *apd): cpu_primitive_t(apd) - { pd()->conv_pd_->create_primitive((primitive_t **)&conv_p_); } - ~ref_deconvolution_bwd_weights_t() { delete conv_p_; } - - virtual status_t execute(const exec_ctx_t &ctx) const override { - const auto &args = ctx.args(); - exec_args_t conv_args; - conv_args[MKLDNN_ARG_DIFF_DST] = args.at(MKLDNN_ARG_SRC); - conv_args[MKLDNN_ARG_SRC] = args.at(MKLDNN_ARG_DIFF_DST); - conv_args[MKLDNN_ARG_DIFF_WEIGHTS] = args.at(MKLDNN_ARG_DIFF_WEIGHTS); - if (!types::is_zero_md(pd()->scratchpad_md())) - conv_args[MKLDNN_ARG_SCRATCHPAD] = args.at(MKLDNN_ARG_SCRATCHPAD); - const exec_ctx_t conv_ctx(ctx.stream(), std::move(conv_args)); - - status_t status = conv_p_->execute(conv_ctx); - if (status != status::success) return status; - - if (pd()->with_bias()) { - using namespace format_tag; - - auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); - auto diff_bias = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS); - - switch (pd()->dst_tag_) { - case ncdhw: case nchw: case ncw: - compute_bwd_bias_ncdhw(diff_dst, diff_bias); - break; - case nCdhw8c: case nChw8c: case nCw8c: - compute_bwd_bias_nCdhwXc<8>(diff_dst, diff_bias); - break; - case nCdhw16c: case nChw16c: case nCw16c: - compute_bwd_bias_nCdhwXc<16>(diff_dst, diff_bias); - break; - default: - compute_bwd_bias(diff_dst, diff_bias); - break; - } - } - return status::success; - } - -private: - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - void compute_bwd_bias(const data_t *diff_dst, data_t *diff_bias) const; - void compute_bwd_bias_ncdhw(const data_t *diff_dst, - data_t *diff_bias) const; - template void compute_bwd_bias_nCdhwXc( - const data_t *diff_dst, data_t *diff_bias) const; - - primitive_t *conv_p_; -}; - -} -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_eltwise.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_eltwise.cpp deleted file mode 100644 index 7beee8d32..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/ref_eltwise.cpp +++ /dev/null @@ -1,297 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#include - -#include "c_types_map.hpp" -#include "type_helpers.hpp" -#include "math_utils.hpp" -#include "mkldnn_thread.hpp" - -#include "ref_eltwise.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace alg_kind; -using namespace math; - -ref_eltwise_scalar_fwd_t::ref_eltwise_scalar_fwd_t(alg_kind_t alg, float alpha, - float beta): alg_(alg), alpha_(alpha), beta_(beta) { - assert(utils::one_of(alg_, eltwise_relu, eltwise_tanh, eltwise_elu, - eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear, - eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic)); -} - -ref_eltwise_scalar_fwd_t::ref_eltwise_scalar_fwd_t( - const post_ops_t::entry_t::eltwise_t &eltwise) - : ref_eltwise_scalar_fwd_t(eltwise.alg, eltwise.alpha, eltwise.beta) {} - -float ref_eltwise_scalar_fwd_t::compute_scalar(float s) { - switch (alg_) { - case eltwise_relu: return relu_fwd(s, alpha_); - case eltwise_tanh: return tanh_fwd(s); - case eltwise_elu: return elu_fwd(s, alpha_); - case eltwise_square: return square_fwd(s); - case eltwise_abs: return abs_fwd(s); - case eltwise_sqrt: return sqrt_fwd(s); - case eltwise_linear: return linear_fwd(s, alpha_, beta_); - case eltwise_bounded_relu: return bounded_relu_fwd(s, alpha_); - case eltwise_soft_relu: return soft_relu_fwd(s); - case eltwise_logistic: return logistic_fwd(s); - default: assert(!"unknown eltwise alg_kind"); - } - - return 0.f; -} - -template -void ref_eltwise_fwd_t::execute_forward_nCspBc_padded( - const exec_ctx_t &ctx) const { - auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); - auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); - - const memory_desc_wrapper data_d(pd()->src_md()); - const blocking_desc_t &blk = data_d.blocking_desc(); - const int block = blk.inner_blks[0]; - - const int MB = pd()->MB(); - const int C = pd()->C() / block; - const int C_PADDED = data_d.padded_dims()[1] / block; - const int tail = pd()->C() % block; - const int SP = pd()->D() * pd()->H() * pd()->W(); - const auto alg_kind = pd()->desc()->alg_kind; - const float alpha = pd()->desc()->alpha; - const float beta = pd()->desc()->beta; - - auto ker = [=] (data_t &d, data_t s) { - switch (alg_kind) { - case eltwise_linear: d = linear_fwd(s, alpha, beta); break; - case eltwise_bounded_relu: - d = bounded_relu_fwd(s, alpha); break; - case eltwise_soft_relu: d = soft_relu_fwd(s); break; - case eltwise_logistic: d = logistic_fwd(s); break; - default: assert(!"unknown eltwise alg_kind"); - } - }; - - // FIXME: integer overflow? 
- - parallel_nd(MB, C_PADDED, SP, - [&](int n, int c, int sp) { - auto d_off = (n*C_PADDED*SP + c*SP + sp) * block; - if (c < C) { - for (int v = 0; v < block; v++) - ker(dst[d_off + v], src[d_off + v]); - } else { - for (int v = 0; v < tail; v++) - ker(dst[d_off + v], src[d_off + v]); - } - }); -} - -template -void ref_eltwise_fwd_t::execute_forward_generic( - const exec_ctx_t &ctx) const { - /* fast return */ - if (pd()->has_zero_dim_memory()) return; - - auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); - auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); - - const memory_desc_wrapper data_d(pd()->src_md()); - - const int MB = pd()->MB(); - const int C = pd()->C(); - const int D = pd()->D(); - const int H = pd()->H(); - const int W = pd()->W(); - const auto alg_kind = pd()->desc()->alg_kind; - const float alpha = pd()->desc()->alpha; - const float beta = pd()->desc()->beta; - const bool is_3d = pd()->desc()->data_desc.ndims == 5; - - parallel_nd(MB, C, D, H, W, - [&](int n, int c, int id, int h, int w) { - auto d_off = is_3d - ? data_d.off(n, c, id, h, w) : data_d.off(n, c, h, w); - data_t s = src[d_off]; - data_t &d = dst[d_off]; - switch (alg_kind) { - case eltwise_relu: d = relu_fwd(s, alpha); break; - case eltwise_tanh: d = tanh_fwd(s); break; - case eltwise_elu: d = elu_fwd(s, alpha); break; - case eltwise_square: d = square_fwd(s); break; - case eltwise_abs: d = abs_fwd(s); break; - case eltwise_sqrt: d = sqrt_fwd(s); break; - case eltwise_linear: d = linear_fwd(s, alpha, beta); break; - case eltwise_bounded_relu: - d = bounded_relu_fwd(s, alpha); break; - case eltwise_soft_relu: d = soft_relu_fwd(s); break; - case eltwise_logistic: d = logistic_fwd(s); break; - default: assert(!"unknown eltwise alg_kind"); - } - }); -} - -template -void ref_eltwise_fwd_t::execute_forward_dense( - const exec_ctx_t &ctx) const { - auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); - auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); - - const memory_desc_wrapper data_d(pd()->src_md()); - - const ptrdiff_t nelems = static_cast(data_d.nelems(true)); - const auto alg_kind = pd()->desc()->alg_kind; - const float alpha = pd()->desc()->alpha; - const float beta = pd()->desc()->beta; - - src += data_d.offset0(); - dst += data_d.offset0(); - - if (alg_kind == eltwise_relu) { - // a fast path for relu as the most popular activation - parallel_nd(nelems, [&](ptrdiff_t e) { - dst[e] = relu_fwd(src[e], alpha); - }); - return; - } - - parallel_nd(nelems, [&](ptrdiff_t e) { - const data_t s = src[e]; - data_t &d = dst[e]; - - switch (alg_kind) { - case eltwise_tanh: d = tanh_fwd(s); break; - case eltwise_elu: d = elu_fwd(s, alpha); break; - case eltwise_square: d = square_fwd(s); break; - case eltwise_abs: d = abs_fwd(s); break; - case eltwise_sqrt: d = sqrt_fwd(s); break; - case eltwise_linear: d = linear_fwd(s, alpha, beta); break; - case eltwise_bounded_relu: d = bounded_relu_fwd(s, alpha); break; - case eltwise_soft_relu: d = soft_relu_fwd(s); break; - case eltwise_logistic: d = logistic_fwd(s); break; - default: assert(!"unknown eltwise alg_kind"); - } - }); -} - -template -void ref_eltwise_bwd_t::execute_backward_generic( - const exec_ctx_t &ctx) const { - /* fast return */ - if (pd()->has_zero_dim_memory()) return; - - auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); - auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); - auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); - - const memory_desc_wrapper data_d(pd()->src_md()); - const memory_desc_wrapper 
diff_data_d(pd()->diff_src_md()); - - const int MB = pd()->MB(); - const int C = pd()->C(); - const int D = pd()->D(); - const int H = pd()->H(); - const int W = pd()->W(); - const auto alg_kind = pd()->desc()->alg_kind; - const float alpha = pd()->desc()->alpha; - const float beta = pd()->desc()->beta; - const bool is_3d = pd()->desc()->data_desc.ndims == 5; - - parallel_nd(MB, C, D, H, W, - [&](int n, int c, int d, int h, int w) { - auto data_off = is_3d - ? data_d.off(n, c, d, h, w) : data_d.off(n, c, h, w); - auto diff_data_off = is_3d - ? diff_data_d.off(n, c, d, h, w) - : diff_data_d.off(n, c, h, w); - data_t s = src[data_off]; - data_t dd = diff_dst[diff_data_off]; - data_t &ds = diff_src[diff_data_off]; - switch (alg_kind) { - case eltwise_relu: ds = relu_bwd(dd, s, alpha); break; - case eltwise_tanh: ds = tanh_bwd(dd, s); break; - case eltwise_elu: ds = elu_bwd(dd, s, alpha); break; - case eltwise_square: ds = square_bwd(dd, s); break; - case eltwise_abs: ds = abs_bwd(dd, s); break; - case eltwise_sqrt: ds = sqrt_bwd(dd, s); break; - case eltwise_linear: - ds = linear_bwd(dd, s, alpha, beta); break; - case eltwise_bounded_relu: - ds = bounded_relu_bwd(dd, s, alpha); break; - case eltwise_soft_relu: ds = soft_relu_bwd(dd, s); break; - case eltwise_logistic: ds = logistic_bwd(dd, s); break; - default: assert(!"unknown eltwise alg_kind"); - } - }); -} - -template -void ref_eltwise_bwd_t::execute_backward_dense( - const exec_ctx_t &ctx) const { - auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); - auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); - auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); - - const memory_desc_wrapper data_d(pd()->src_md()); - const memory_desc_wrapper diff_data_d(pd()->diff_src_md()); - - const ptrdiff_t nelems = static_cast(data_d.nelems(true)); - const auto alg_kind = pd()->desc()->alg_kind; - const float alpha = pd()->desc()->alpha; - const float beta = pd()->desc()->beta; - - src += data_d.offset0(); - diff_dst += diff_data_d.offset0(); - diff_src += diff_data_d.offset0(); - - parallel_nd(nelems, [&](ptrdiff_t e) { - const data_t dd = diff_dst[e]; - const data_t s = src[e]; - data_t &ds = diff_src[e]; - - switch (alg_kind) { - case eltwise_relu: ds = relu_bwd(dd, s, alpha); break; - case eltwise_tanh: ds = tanh_bwd(dd, s); break; - case eltwise_elu: ds = elu_bwd(dd, s, alpha); break; - case eltwise_square: ds = square_bwd(dd, s); break; - case eltwise_abs: ds = abs_bwd(dd, s); break; - case eltwise_sqrt: ds = sqrt_bwd(dd, s); break; - case eltwise_linear: ds = linear_bwd(dd, s, alpha, beta); break; - case eltwise_bounded_relu: ds = bounded_relu_bwd(dd, s, alpha); break; - case eltwise_soft_relu: ds = soft_relu_bwd(dd, s); break; - case eltwise_logistic: ds = logistic_bwd(dd, s); break; - default: assert(!"unknown eltwise alg_kind"); - } - }); -} - -template struct ref_eltwise_fwd_t; -template struct ref_eltwise_fwd_t; -template struct ref_eltwise_fwd_t; -template struct ref_eltwise_fwd_t; - -template struct ref_eltwise_bwd_t; -template struct ref_eltwise_bwd_t; - -} -} -} - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_eltwise.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_eltwise.hpp deleted file mode 100644 index 8f4ab3541..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/ref_eltwise.hpp +++ /dev/null @@ -1,168 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, 
Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef CPU_REF_ELTWISE_HPP -#define CPU_REF_ELTWISE_HPP - -#include - -#include "c_types_map.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -#include "cpu_eltwise_pd.hpp" -#include "cpu_primitive.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -struct ref_eltwise_scalar_fwd_t { -public: - ref_eltwise_scalar_fwd_t(alg_kind_t alg, float alpha, float beta); - - // note that eltwise.scale is ignored - ref_eltwise_scalar_fwd_t(const post_ops_t::entry_t::eltwise_t &eltwise); - - float compute_scalar(float s); - - const alg_kind_t alg_; - const float alpha_; - const float beta_; -}; - -template -struct ref_eltwise_fwd_t: public cpu_primitive_t { - struct pd_t: public cpu_eltwise_fwd_pd_t { - using cpu_eltwise_fwd_pd_t::cpu_eltwise_fwd_pd_t; - - DECLARE_COMMON_PD_T("ref:any", ref_eltwise_fwd_t); - - status_t init() { - using namespace utils; - - auto src_d = memory_desc_wrapper(src_md()); - - use_dense_ = false - || src_d.is_dense() - || (src_d.is_dense(true) && is_zero_preserved()); - - use_nCspBc_padded_ = !use_dense_ - && src_d.blocking_desc().inner_nblks == 1 - && one_of(src_d.blocking_desc().inner_blks[0], 8, 16) - && src_d.blocking_desc().inner_idxs[0] == 1 - && src_d.only_padded_dim(1) - && src_d.is_dense(true); - - if (has_zero_dim_memory()) - use_dense_ = use_nCspBc_padded_ = false; - - const bool use_generic = !use_dense_ && !use_nCspBc_padded_; - - bool ok = true - && is_fwd() - && everyone_is(data_type, desc()->data_desc.data_type) - && IMPLICATION(use_generic, one_of(src_d.ndims(), 4, 5)) - && attr()->has_default_values(); - if (!ok) return status::unimplemented; - - return status::success; - } - - bool use_dense_, use_nCspBc_padded_; - }; - - ref_eltwise_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} - typedef typename prec_traits::type data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - if (pd()->use_dense_) - execute_forward_dense(ctx); - else if (pd()->use_nCspBc_padded_) - execute_forward_nCspBc_padded(ctx); - else - execute_forward_generic(ctx); - return status::success; - } - -private: - void execute_forward_nCspBc_padded(const exec_ctx_t &ctx) const; - void execute_forward_dense(const exec_ctx_t &ctx) const; - void execute_forward_generic(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } -}; - -template -struct ref_eltwise_bwd_t: public cpu_primitive_t { - struct pd_t: public cpu_eltwise_bwd_pd_t { - using cpu_eltwise_bwd_pd_t::cpu_eltwise_bwd_pd_t; - - DECLARE_COMMON_PD_T("ref:any", ref_eltwise_bwd_t); - - status_t init() { - using namespace utils; - - bool ok = true - && !is_fwd() - && everyone_is(data_type, - desc()->data_desc.data_type, - desc()->diff_data_desc.data_type) - && attr()->has_default_values(); - if (!ok) return status::unimplemented; - - auto diff_dst_d = memory_desc_wrapper(diff_dst_md()); - const bool same_fmt_ = diff_dst_d == memory_desc_wrapper(src_md()); - - 
use_dense_ = true - && same_fmt_ - && diff_dst_d.is_dense(true) - && is_zero_preserved() - && !has_zero_dim_memory(); - const bool use_generic = !use_dense_; - - if (use_generic && !one_of(diff_dst_d.ndims(), 4, 5)) - return status::unimplemented; - - return status::success; - } - - bool use_dense_; - }; - - ref_eltwise_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {} - typedef typename prec_traits::type data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - if (pd()->use_dense_) - execute_backward_dense(ctx); - else - execute_backward_generic(ctx); - return status::success; - } - -private: - void execute_backward_dense(const exec_ctx_t &ctx) const; - void execute_backward_generic(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } -}; - -} -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_inner_product.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_inner_product.cpp deleted file mode 100644 index c807a9ffd..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/ref_inner_product.cpp +++ /dev/null @@ -1,285 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include "c_types_map.hpp" -#include "type_helpers.hpp" -#include "mkldnn_thread.hpp" -#include "mkldnn_traits.hpp" -#include "math_utils.hpp" - -#include "ref_inner_product.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -using math::saturate; -using math::get_bias; - -template -void ref_inner_product_fwd_t:: -execute_forward(const exec_ctx_t &ctx) const { - auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); - auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); - auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); - auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); - - const memory_desc_wrapper src_d(pd()->src_md()); - const memory_desc_wrapper dst_d(pd()->dst_md()); - const memory_desc_wrapper weights_d(pd()->weights_md(0)); - const memory_desc_wrapper bias_d(pd()->weights_md(1)); - - const int MB = pd()->MB(); - const int OC = pd()->OC(); - const int IC = pd()->IC(); - - const bool src_has_spatial = utils::one_of(src_d.ndims(), 3, 4, 5); - const int ndims = src_d.ndims() - 2; - - const auto &post_ops = pd()->attr()->post_ops_; - const bool do_relu = post_ops.len_ == 1; - const float nslope = do_relu ? 
post_ops.entry_[0].eltwise.alpha : 0.f; - - auto ker_has_spatial = [=](int mb, int oc) { - acc_data_t d = 0; - const int KD = pd()->KD(); - const int KH = pd()->KH(); - const int KW = pd()->KW(); - for (int ic = 0; ic < IC; ++ic) { - for (int kd = 0; kd < KD; ++kd) { - for (int kh = 0; kh < KH; ++kh) { - for (int kw = 0; kw < KW; ++kw) { - switch (ndims) { - case 3: - d += (acc_data_t)src[src_d.off(mb, ic, kd, kh, kw)] - * weights[weights_d.off( - oc, ic, kd, kh, kw)]; - break; - case 2: - d += (acc_data_t)src[src_d.off(mb, ic, kh, kw)] - * weights[weights_d.off(oc, ic, kh, kw)]; - break; - case 1: - d += (acc_data_t)src[src_d.off(mb, ic, kw)] - * weights[weights_d.off(oc, ic, kw)]; - break; - default: assert(!"unsupported ndims size"); - } - } - } - } - } - return d; - }; - - auto ker_no_spatial = [=](int mb, int oc) { - acc_data_t d = 0; - for (int ic = 0; ic < IC; ++ic) { - d += (acc_data_t)src[src_d.off(mb, ic)] - * weights[weights_d.off(oc, ic)]; - } - return d; - }; - - parallel_nd(MB, OC, [&](int mb, int oc) { - float a = bias - ? get_bias(bias, bias_d.off(oc), pd()->desc()->bias_desc.data_type) - : 0; - if (src_has_spatial) - a += ker_has_spatial(mb, oc); - else - a += ker_no_spatial(mb, oc); - if (do_relu && a < (acc_data_t)0) - a *= nslope; - dst[dst_d.off(mb, oc)] = saturate(a); - }); -} - -using namespace data_type; -template struct ref_inner_product_fwd_t; -template struct ref_inner_product_fwd_t; -template struct ref_inner_product_fwd_t; -template struct ref_inner_product_fwd_t; -template struct ref_inner_product_fwd_t; - -template -void ref_inner_product_bwd_data_t::execute_backward_data(const exec_ctx_t &ctx) const { - auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST); - auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); - auto diff_src = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC); - - const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); - const memory_desc_wrapper weights_d(pd()->weights_md(0)); - const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); - - const int MB = pd()->MB(); - const int OC = pd()->OC(); - const int IC = pd()->IC(); - - const bool diff_src_has_spatial - = utils::one_of(diff_src_d.ndims(), 3, 4, 5); - const int ndims = diff_src_d.ndims() - 2; - - parallel_nd(MB, IC, [&](int mb, int ic) { - if (diff_src_has_spatial) { - const int KD = pd()->KD(); - const int KH = pd()->KH(); - const int KW = pd()->KW(); - for (int kd = 0; kd < KD; ++kd) - for (int kh = 0; kh < KH; ++kh) - for (int kw = 0; kw < KW; ++kw) { - acc_data_t ds = acc_data_t(0); - for (int oc = 0; oc < OC; ++oc) { - switch (ndims) { - case 3: - ds += (acc_data_t)(diff_dst[diff_dst_d.off(mb, oc)] - * weights[weights_d.off(oc, ic, kd, kh, kw)]); - break; - case 2: - ds += (acc_data_t)(diff_dst[diff_dst_d.off(mb, oc)] - * weights[weights_d.off(oc, ic, kh, kw)]); - break; - case 1: - ds += (acc_data_t)(diff_dst[diff_dst_d.off(mb, oc)] - * weights[weights_d.off(oc, ic, kw)]); - break; - default: assert(!"unsupported ndims size"); - } - } - switch (ndims) { - case 3: - diff_src[diff_src_d.off(mb, ic, kd, kh, kw)] - = (diff_src_data_t)ds; - break; - case 2: - diff_src[diff_src_d.off(mb, ic, kh, kw)] - = (diff_src_data_t)ds; - break; - case 1: - diff_src[diff_src_d.off(mb, ic, kw)] = (diff_src_data_t)ds; - break; - default: assert(!"unsupported ndims size"); - } - } - } else { - acc_data_t ds = acc_data_t(0); - for (int oc = 0; oc < OC; ++oc) { - ds += (acc_data_t)(diff_dst[diff_dst_d.off(mb, oc)] * - weights[weights_d.off(oc, ic)]); - } - 
diff_src[diff_src_d.off(mb, ic)] = (diff_src_data_t)ds; - } - }); -} - -template struct ref_inner_product_bwd_data_t; - -template -void ref_inner_product_bwd_weights_t::execute_backward_weights( - const exec_ctx_t &ctx) const { - auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); - auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); - auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS); - auto diff_bias = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS); - - const memory_desc_wrapper src_d(pd()->src_md()); - const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); - const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); - const memory_desc_wrapper diff_bias_d(pd()->diff_weights_md(1)); - - const int MB = pd()->MB(); - const int OC = pd()->OC(); - const int IC = pd()->IC(); - - const bool src_has_spatial = utils::one_of(src_d.ndims(), 3, 4 ,5); - const int ndims = src_d.ndims() - 2; - - parallel_nd(OC, IC, [&](int oc, int ic) { - if (src_has_spatial) { - const int KD = pd()->KD(); - const int KH = pd()->KH(); - const int KW = pd()->KW(); - for (int kd = 0; kd < KD; ++kd) { - for (int kh = 0; kh < KH; ++kh) { - for (int kw = 0; kw < KW; ++kw) { - data_t *dw(nullptr); - switch (ndims) { - case 3: - dw = &diff_weights[diff_weights_d.off( - oc, ic, kd, kh, kw)]; - break; - case 2: - dw = &diff_weights[diff_weights_d.off( - oc, ic, kh, kw)]; - break; - case 1: - dw = &diff_weights[diff_weights_d.off(oc, ic, kw)]; - break; - default: assert(!"unsupported ndims size"); - } - *dw = data_t(0); - for (int mb = 0; mb < MB; ++mb) { - switch (ndims) { - case 3: - *dw += diff_dst[diff_dst_d.off(mb, oc)] - * src[src_d.off(mb, ic, kd, kh, kw)]; - break; - case 2: - *dw += diff_dst[diff_dst_d.off(mb, oc)] - * src[src_d.off(mb, ic, kh, kw)]; - break; - case 1: - *dw += diff_dst[diff_dst_d.off(mb, oc)] - * src[src_d.off(mb, ic, kw)]; - break; - default: assert(!"unsupported ndims size"); - } - } - } - } - } - } else { - data_t *dw = &diff_weights[diff_weights_d.off(oc, ic)]; - *dw = data_t(0); - for (int mb = 0; mb < MB; ++mb) { - *dw += diff_dst[diff_dst_d.off(mb, oc)] * - src[src_d.off(mb, ic)]; - } - } - }); - - if (diff_bias) { - diff_bias += diff_bias_d.offset0(); - - parallel_nd(OC, [&](int oc) { - data_t *db = &diff_bias[oc]; - *db = data_t(0); - for (int mb = 0; mb < MB; ++mb) - *db += diff_dst[diff_dst_d.off(mb, oc)]; - }); - } -} - -template struct ref_inner_product_bwd_weights_t; - -} -} -} - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_inner_product.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_inner_product.hpp deleted file mode 100644 index bf87dbd51..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/ref_inner_product.hpp +++ /dev/null @@ -1,159 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
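The forward kernel in the ref_inner_product.cpp deletion above is a plain fully-connected layer: dst[mb][oc] = bias[oc] + sum over ic of src[mb][ic] * weights[oc][ic], with an optional ReLU post-op whose negative slope is nslope. A rough standalone sketch of that arithmetic on dense f32 buffers (hypothetical names; no memory descriptors or data-type saturation):

#include <cstdio>
#include <vector>

// dst[mb][oc] = bias[oc] + sum_ic src[mb][ic] * weights[oc][ic], then optional leaky ReLU.
static void inner_product_fwd_ex(int MB, int OC, int IC,
        const std::vector<float> &src, const std::vector<float> &wei,
        const std::vector<float> &bias, std::vector<float> &dst,
        bool do_relu, float nslope) {
    for (int mb = 0; mb < MB; ++mb)
        for (int oc = 0; oc < OC; ++oc) {
            float a = bias.empty() ? 0.f : bias[oc];
            for (int ic = 0; ic < IC; ++ic)
                a += src[mb * IC + ic] * wei[oc * IC + ic];
            if (do_relu && a < 0.f)
                a *= nslope;
            dst[mb * OC + oc] = a;
        }
}

int main() {
    const int MB = 1, OC = 2, IC = 3;
    std::vector<float> src = {1.f, 2.f, 3.f};
    std::vector<float> wei = {0.1f, 0.2f, 0.3f, -1.f, -1.f, -1.f};
    std::vector<float> bias = {0.5f, 0.f};
    std::vector<float> dst(MB * OC);
    inner_product_fwd_ex(MB, OC, IC, src, wei, bias, dst, true, 0.f);
    std::printf("dst = {%g, %g}\n", dst[0], dst[1]); // {1.9, 0}
    return 0;
}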
-*******************************************************************************/ - -#ifndef CPU_REF_INNER_PRODUCT_HPP -#define CPU_REF_INNER_PRODUCT_HPP - -#include - -#include "c_types_map.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -#include "cpu_inner_product_pd.hpp" -#include "cpu_primitive.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -template -struct ref_inner_product_fwd_t: public cpu_primitive_t { - struct pd_t: public cpu_inner_product_fwd_pd_t { - using cpu_inner_product_fwd_pd_t::cpu_inner_product_fwd_pd_t; - - DECLARE_COMMON_PD_T("ref:any", ref_inner_product_fwd_t); - - status_t init() { - using namespace data_type; - - bool ok = true - && set_default_params() == status::success - && is_fwd() - && src_md()->data_type == src_type - && weights_md()->data_type == wei_type - && desc()->accum_data_type == acc_type - && dst_md()->data_type == dst_type - && IMPLICATION(with_bias(), utils::one_of( - weights_md(1)->data_type, f32, s32, s8, u8)) - && attr()->output_scales_.has_default_values() - && attr()->post_ops_.len_ <= 1 - && IMPLICATION(attr()->post_ops_.len_ == 1, - attr()->post_ops_.entry_[0].is_relu(true, false)); - return ok ? status::success : status::unimplemented; - } - }; - - ref_inner_product_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} - - typedef typename prec_traits::type src_data_t; - typedef typename prec_traits::type wei_data_t; - typedef typename prec_traits::type dst_data_t; - typedef typename prec_traits::type acc_data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_forward(ctx); - return status::success; - } - -private: - void execute_forward(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } -}; - -template -struct ref_inner_product_bwd_data_t: public cpu_primitive_t { - struct pd_t: public cpu_inner_product_bwd_data_pd_t { - using cpu_inner_product_bwd_data_pd_t::cpu_inner_product_bwd_data_pd_t; - - DECLARE_COMMON_PD_T("ref:any", ref_inner_product_bwd_data_t); - - status_t init() { - bool ok = true - && set_default_params() == status::success - && desc()->prop_kind == prop_kind::backward_data - && diff_src_md()->data_type == diff_src_type - && weights_md()->data_type == wei_type - && desc()->accum_data_type == acc_type - && diff_dst_md()->data_type == diff_dst_type - && attr()->has_default_values(); - return ok ? 
status::success : status::unimplemented; - } - }; - - ref_inner_product_bwd_data_t(const pd_t *apd): cpu_primitive_t(apd) {} - - typedef typename prec_traits::type diff_src_data_t; - typedef typename prec_traits::type wei_data_t; - typedef typename prec_traits::type diff_dst_data_t; - typedef typename prec_traits::type acc_data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_backward_data(ctx); - return status::success; - } - -private: - void execute_backward_data(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } -}; - -template -struct ref_inner_product_bwd_weights_t: public cpu_primitive_t { - struct pd_t: public cpu_inner_product_bwd_weights_pd_t { - using cpu_inner_product_bwd_weights_pd_t::cpu_inner_product_bwd_weights_pd_t; - - DECLARE_COMMON_PD_T("ref:any", ref_inner_product_bwd_weights_t); - - status_t init() { - bool ok = true - && set_default_params() == status::success - && desc()->prop_kind == prop_kind::backward_weights - && utils::everyone_is(data_type, - src_md()->data_type, - diff_dst_md()->data_type, - diff_weights_md()->data_type) - && IMPLICATION(with_bias(), - data_type == diff_weights_md(1)->data_type) - && attr()->has_default_values(); - return ok ? status::success : status::unimplemented; - } - }; - - ref_inner_product_bwd_weights_t(const pd_t *apd): cpu_primitive_t(apd) {} - typedef typename prec_traits::type data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_backward_weights(ctx); - return status::success; - } - -private: - void execute_backward_weights(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } -}; - -} -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_lrn.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_lrn.cpp deleted file mode 100644 index 325e97963..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/ref_lrn.cpp +++ /dev/null @@ -1,252 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#include -#include - -#include "c_types_map.hpp" -#include "mkldnn_thread.hpp" -#include "type_helpers.hpp" - -#include "ref_lrn.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -static inline float fast_negative_powf(float omega, float beta) { - float Y; -/* - * Y = omega^(-3/4) = - * = 1.0f / sqrtf(omega) * sqrtf(1.0f / sqrtf(omega)) - * = sqrtf(1.0f / sqrtf(omega)) * 1.0f / sqrtf(omega) - * = sqrtf(1.0f / sqrtf(omega)) / sqrtf(omega) - * = sqrtf(1.0f / sqrtf(omega) / omega) - * = sqrtf(1.0f / (sqrtf(omega) * omega)) - */ - if (beta == 0.75f) { - Y = sqrtf(1.0f / (sqrtf(omega) * omega)); - } else { - Y = 1.0f / powf(omega, beta); - } - return Y; -}; - -template -template -void ref_lrn_fwd_t::execute_forward(const exec_ctx_t &ctx) const { - using namespace alg_kind; - using namespace format_tag; - - auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); - auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); - - const memory_desc_wrapper data_d(pd()->src_md()); - - const int C = pd()->C(); - const int H = pd()->H(); - const int W = pd()->W(); - const size_t stride_mb = data_d.blocking_desc().strides[0]; - const bool across_channels = pd()->desc()->alg_kind == lrn_across_channels; - constexpr int blksize = tag == nChw16c ? 16 : 8; - - auto data_off = [&](int mb, int c, int h, int w) -> size_t { - switch (tag) { - case nChw16c: - case nChw8c: return mb * stride_mb + c / blksize * H * W * blksize - + h * W * blksize + w * blksize + c % blksize; - case nchw: return mb * stride_mb + c * H * W + h * W + w; - case nhwc: return mb * stride_mb + h * W * C + w * C + c; - default: return data_d.off(mb, c, h, w); - } - }; - - auto ker = [=](data_t *d, int mb, int oc, int oh, int ow) { - const float alpha = static_cast(pd()->desc()->lrn_alpha); - const float beta = static_cast(pd()->desc()->lrn_beta); - const float k = static_cast(pd()->desc()->lrn_k); - - const int size = pd()->desc()->local_size; - const int half_size = (size - 1) / 2; - - float sum = 0; - if (across_channels) { - const int c_st = nstl::max(oc - half_size + 0, 0); - const int c_en = nstl::min(oc + half_size + 1, C); - - for (int c = c_st; c < c_en; ++c) { - const float s = src[data_off(mb, c, oh, ow)]; - sum += s * s; - } - } else { - int h_st = nstl::max(oh - half_size + 0, 0); - int h_en = nstl::min(oh + half_size + 1, H); - int w_st = nstl::max(ow - half_size + 0, 0); - int w_en = nstl::min(ow + half_size + 1, W); - for (int h = h_st; h < h_en; ++h) { - for (int w = w_st; w < w_en; ++w) { - const float s = src[data_off(mb, oc, h, w)]; - sum += s * s; - } - } - } - const int summands = across_channels ? 
size : size * size; - sum = k + alpha * sum / summands; - size_t off = data_off(mb, oc, oh, ow); - d[0] = static_cast(src[off] * fast_negative_powf(sum, beta)); - }; - - const int MB = pd()->MB(); - if (tag == nChw16c || tag == nChw8c) { - parallel_nd(MB, utils::div_up(C, blksize), H, W, - [&](int mb, int c_blk, int h, int w) { - int c = c_blk * blksize; - const size_t off = mb * stride_mb + c * H * W - + (h * W + w) * blksize; - PRAGMA_OMP_SIMD() - for (int cc = 0; cc < nstl::min(blksize, C - c); ++cc) - ker(&dst[off + cc], mb, c + cc, h, w); - }); - } else if (tag == nhwc) { - parallel_nd(MB, H, W, C, - [&](int mb, int h, int w, int c) { - const size_t off = mb * stride_mb + h * W * C + w * C + c; - ker(&dst[off], mb, c, h, w); - }); - } else { - parallel_nd(MB, C, H, W, - [&](int mb, int c, int h, int w) { - const size_t off = data_off(mb, c, h, w); - ker(&dst[off], mb, c, h, w); - }); - } -} - -template -template -void ref_lrn_bwd_t::execute_backward(const exec_ctx_t &ctx) const { - using namespace alg_kind; - using namespace format_tag; - - auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); - auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); - auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); - - const memory_desc_wrapper data_d(pd()->src_md()); - - const int MB = pd()->MB(); - const int C = pd()->C(); - const int H = pd()->H(); - const int W = pd()->W(); - const size_t stride_mb = data_d.blocking_desc().strides[0]; - constexpr int blksize = tag == nChw16c ? 16 : 8; - - const float alpha = static_cast(pd()->desc()->lrn_alpha); - const float beta = static_cast(pd()->desc()->lrn_beta); - const float k = static_cast(pd()->desc()->lrn_k); - const int kernel_size = pd()->desc()->local_size; - const int half_ksize = (kernel_size - 1) / 2; - - auto data_off = [&](int mb, int c, int h, int w) -> size_t { - switch (tag) { - case nChw16c: - case nChw8c: return mb * stride_mb + c/blksize * H * W * blksize - + h * W * blksize + w * blksize + c%blksize; - case nchw: return mb * stride_mb + c * H * W + h * W + w; - case nhwc: return mb * stride_mb + h * W * C + w * C + c; - default: return data_d.off(mb, c, h, w); - } - }; - - auto ker = [=](data_t *d, int mb, int oc, int oh, int ow) { - const int c_st = nstl::max(oc - half_ksize + 0, 0); - const int c_en = nstl::min(oc + half_ksize + 1, C); - - float A = 0, B = 0, omega_mid = 0; - for (int c = c_st; c < c_en; c++) { - float sum = 0.0; - const int i_st = nstl::max(c - half_ksize, 0); - const int i_en = nstl::min(c + kernel_size - half_ksize, C); - - for (int i = i_st; i < i_en; ++i) { - const float value = src[data_off(mb, i, oh, ow)]; - sum += value * value; - } - const float omega = static_cast(k + sum * alpha / kernel_size); - if (c == oc) omega_mid = omega; - float t = src[data_off(mb, c, oh, ow)] - * fast_negative_powf(omega, beta); - B += 1.0f / omega * t * diff_dst[data_off(mb, c, oh, ow)]; - } - - const size_t off = data_off(mb, oc, oh, ow); - A = fast_negative_powf(omega_mid, beta) * diff_dst[off]; - B *= src[off]; - B *= (2.0f * alpha * beta) / kernel_size; - *d = static_cast(A - B); // final cast down to data_t - }; - - if (tag == nChw16c || tag == nChw8c) { - parallel_nd(MB, utils::div_up(C, blksize), H, W, - [&](int mb, int c_blk, int h, int w) { - int c = c_blk * blksize; - const size_t off = mb * stride_mb + c * H * W + - (h * W + w) * blksize; - PRAGMA_OMP_SIMD() - for (int cc = 0; cc < nstl::min(blksize, C - c); ++cc) - ker(&diff_src[off + cc], mb, c + cc, h, w); - }); - } else if (tag == nhwc) { - 
parallel_nd(MB, H, W, C, - [&](int mb, int h, int w, int c) { - const size_t off = mb * stride_mb + h * W * C + w * C + c; - ker(&diff_src[off], mb, c, h, w); - }); - } else { - parallel_nd(MB, C, H, W, - [&](int mb, int c, int h, int w) { - const size_t off = data_off(mb, c, h, w); - ker(&diff_src[off], mb, c, h, w); - }); - } -} - -template void ref_lrn_fwd_t:: -execute_forward(const exec_ctx_t &ctx) const; -template void ref_lrn_fwd_t:: -execute_forward(const exec_ctx_t &ctx) const; -template void ref_lrn_fwd_t:: -execute_forward(const exec_ctx_t &ctx) const; -template void ref_lrn_fwd_t:: -execute_forward(const exec_ctx_t &ctx) const; -template void ref_lrn_fwd_t:: -execute_forward(const exec_ctx_t &ctx) const; -template void ref_lrn_bwd_t:: -execute_backward(const exec_ctx_t &ctx) const; -template void ref_lrn_bwd_t:: -execute_backward(const exec_ctx_t &ctx) const; -template void ref_lrn_bwd_t:: -execute_backward(const exec_ctx_t &ctx) const; -template void ref_lrn_bwd_t:: -execute_backward(const exec_ctx_t &ctx) const; -template void ref_lrn_bwd_t:: -execute_backward(const exec_ctx_t &ctx) const; - -} -} -} - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_lrn.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_lrn.hpp deleted file mode 100644 index f25cfb7fa..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/ref_lrn.hpp +++ /dev/null @@ -1,136 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
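The forward path in the ref_lrn.cpp deletion above normalizes each value by a power of the local energy: dst = src * omega^(-beta) with omega = k + alpha/size * (sum of squares over the channel window), and fast_negative_powf uses the identity omega^(-3/4) = sqrt(1 / (sqrt(omega) * omega)) for the common beta = 0.75. A small standalone sketch of the across-channel case at one spatial position (hypothetical *_ex names, 1-D channel vector only):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// omega^(-beta); for beta == 0.75 this is sqrt(1 / (sqrt(omega) * omega)).
static float fast_negative_powf_ex(float omega, float beta) {
    return beta == 0.75f ? std::sqrt(1.0f / (std::sqrt(omega) * omega))
                         : 1.0f / std::pow(omega, beta);
}

// Across-channel LRN at a single spatial position: window of `size` channels centred on c.
static void lrn_across_channels_ex(const std::vector<float> &src, std::vector<float> &dst,
        int size, float alpha, float beta, float k) {
    const int C = static_cast<int>(src.size());
    const int half = (size - 1) / 2;
    for (int c = 0; c < C; ++c) {
        float sum = 0.f;
        for (int i = std::max(c - half, 0); i < std::min(c + half + 1, C); ++i)
            sum += src[i] * src[i];
        const float omega = k + alpha * sum / size;
        dst[c] = src[c] * fast_negative_powf_ex(omega, beta);
    }
}

int main() {
    std::vector<float> src = {1.f, 2.f, 3.f, 4.f}, dst(4);
    lrn_across_channels_ex(src, dst, 5, 1e-4f, 0.75f, 2.f);
    for (float v : dst)
        std::printf("%g ", v);
    std::printf("\n");
    return 0;
}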
-*******************************************************************************/ - -#ifndef CPU_REF_LRN_HPP -#define CPU_REF_LRN_HPP - -#include - -#include "c_types_map.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -#include "cpu_lrn_pd.hpp" -#include "cpu_primitive.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -template -struct ref_lrn_fwd_t: public cpu_primitive_t { - struct pd_t: public cpu_lrn_fwd_pd_t { - using cpu_lrn_fwd_pd_t::cpu_lrn_fwd_pd_t; - - DECLARE_COMMON_PD_T("ref:any", ref_lrn_fwd_t); - - status_t init() { - using namespace format_tag; - - bool ok = true - && is_fwd() - && src_md()->data_type == data_type - && attr()->has_default_values(); - if (!ok) return status::unimplemented; - - dat_tag_ = memory_desc_matches_one_of_tag( - *src_md(), nChw16c, nChw8c, nchw, nhwc); - - return status::success; - } - - format_tag_t dat_tag_; - }; - - ref_lrn_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} - typedef typename prec_traits::type data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - using namespace format_tag; - switch (pd()->dat_tag_) { - case nChw16c: execute_forward(ctx); break; - case nChw8c: execute_forward(ctx); break; - case nchw: execute_forward(ctx); break; - case nhwc: execute_forward(ctx); break; - default: execute_forward(ctx); - } - return status::success; - } - -private: - template - void execute_forward(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } -}; - -template -struct ref_lrn_bwd_t: public cpu_primitive_t { - struct pd_t: public cpu_lrn_bwd_pd_t { - using cpu_lrn_bwd_pd_t::cpu_lrn_bwd_pd_t; - - DECLARE_COMMON_PD_T("ref:any", ref_lrn_bwd_t); - - status_t init() { - using namespace format_tag; - using namespace alg_kind; - - bool ok = true - && !is_fwd() - && utils::one_of(desc()->alg_kind, lrn_across_channels - /*, lrn_within_channel */) // not supported yet - && utils::everyone_is(data_type, - src_md()->data_type, - diff_src_md()->data_type) - && attr()->has_default_values(); - if (!ok) return status::unimplemented; - - dat_tag_ = memory_desc_matches_one_of_tag( - *src_md(), nChw16c, nChw8c, nchw, nhwc); - - return status::success; - } - - format_tag_t dat_tag_; - }; - - ref_lrn_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {} - typedef typename prec_traits::type data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - using namespace format_tag; - switch (pd()->dat_tag_) { - case nChw16c: execute_backward(ctx); break; - case nChw8c: execute_backward(ctx); break; - case nchw: execute_backward(ctx); break; - case nhwc: execute_backward(ctx); break; - default: execute_backward(ctx); - } - return status::success; - } - -private: - template - void execute_backward(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } -}; - -} -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_pooling.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_pooling.cpp deleted file mode 100644 index 65b934e12..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/ref_pooling.cpp +++ /dev/null @@ -1,381 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. 
-* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include -#include - -#include "c_types_map.hpp" -#include "math_utils.hpp" -#include "mkldnn_thread.hpp" -#include "nstl.hpp" -#include "type_helpers.hpp" - -#include "ref_pooling.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -template -void ref_pooling_fwd_t::execute_forward( - const exec_ctx_t &ctx) const { - using namespace alg_kind; - using namespace prop_kind; - - auto alg = pd()->desc()->alg_kind; - - auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); - auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); - auto ws = CTX_OUT_MEM(unsigned char *, MKLDNN_ARG_WORKSPACE); - - const memory_desc_wrapper src_d(pd()->src_md()); - const memory_desc_wrapper dst_d(pd()->dst_md()); - const memory_desc_wrapper ws_d(pd()->workspace_md()); - const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef; - - const int ID = pd()->ID(); - const int IH = pd()->IH(); - const int IW = pd()->IW(); - const int KD = pd()->KD(); - const int KH = pd()->KH(); - const int KW = pd()->KW(); - const int SD = pd()->KSD(); - const int SH = pd()->KSH(); - const int SW = pd()->KSW(); - const int padF = pd()->padFront(); - const int padT = pd()->padT(); - const int padL = pd()->padL(); - - const bool is_3d = pd()->desc()->src_desc.ndims == 5; - - auto apply_offset = [=](int index, int offset) { - return (index > offset) ? index - offset : 0; - }; - - auto set_ws = [=](int mb, int oc, int od, int oh, int ow, int value) { - if (ws) { - assert(ws_dt == data_type::u8 || ws_dt == data_type::s32); - size_t offset = is_3d - ? ws_d.off(mb, oc, od, oh, ow) : ws_d.off(mb, oc, oh, ow);; - if (ws_dt == data_type::u8) { - assert(0 <= value && value <= 255); - ws[offset] = value; - } else - reinterpret_cast(ws)[offset] = value; - } - }; - - auto ker_max = [=](data_t *d, int mb, int oc, int oh, int ow) { - for (int kh = 0; kh < KH; ++kh) { - for (int kw = 0; kw < KW; ++kw) { - const int ih = oh * SH - padT + kh; - const int iw = ow * SW - padL + kw; - - if (ih < 0 || ih >= IH) continue; - if (iw < 0 || iw >= IW) continue; - - auto s = src[src_d.off(mb, oc, ih, iw)]; - if (s > d[0]) { - d[0] = s; - set_ws(mb, oc, 1, oh, ow, kh*KW + kw); - } - } - } - }; - - auto ker_avg = [=](data_t *d, int mb, int oc, int oh, int ow) { - auto ih_start = apply_offset(oh*SH, padT); - auto iw_start = apply_offset(ow*SW, padL); - auto ih_end = nstl::min(oh*SH - padT + KH, IH); - auto iw_end = nstl::min(ow*SW - padL + KW, IW); - - auto num_summands = (alg == pooling_avg_include_padding) ? 
KW*KH - : (ih_end - ih_start)*(iw_end - iw_start); - - acc_data_t dst = 0; - for (int ih = ih_start; ih < ih_end; ++ih) { - for (int iw = iw_start; iw < iw_end; ++iw) { - dst += src[src_d.off(mb, oc, ih, iw)]; - } - } - - d[0] = math::out_round((float)dst / num_summands); - }; - - auto ker_max_3d = [=](data_t *d, int mb, int oc, int od, int oh, int ow) { - for (int kd = 0; kd < KD; ++kd) { - for (int kh = 0; kh < KH; ++kh) { - for (int kw = 0; kw < KW; ++kw) { - const int id = od * SD - padF + kd; - const int ih = oh * SH - padT + kh; - const int iw = ow * SW - padL + kw; - - if (id < 0 || id >= ID) continue; - if (ih < 0 || ih >= IH) continue; - if (iw < 0 || iw >= IW) continue; - - auto s = src[src_d.off(mb, oc, id, ih, iw)]; - if (s > d[0]) { - d[0] = s; - set_ws(mb, oc, od, oh, ow, kd * KH * KW + kh*KW + kw); - } - } - } - } - }; - - auto ker_avg_3d = [=](data_t *d, int mb, int oc, int od, int oh, int ow) { - auto id_start = apply_offset(od*SD, padF); - auto ih_start = apply_offset(oh*SH, padT); - auto iw_start = apply_offset(ow*SW, padL); - auto id_end = nstl::min(od*SD - padF + KD, ID); - auto ih_end = nstl::min(oh*SH - padT + KH, IH); - auto iw_end = nstl::min(ow*SW - padL + KW, IW); - - auto num_summands = (alg == pooling_avg_include_padding) ? KW*KH*KD - : (ih_end - ih_start)*(iw_end - iw_start)*(id_end - id_start); - - acc_data_t dst = 0; - for (int id = id_start; id < id_end; ++id) { - for (int ih = ih_start; ih < ih_end; ++ih) { - for (int iw = iw_start; iw < iw_end; ++iw) { - dst += src[src_d.off(mb, oc, id, ih, iw)]; - } - } - } - - d[0] = math::out_round((float)dst / num_summands); - }; - - const int MB = pd()->MB(); - const int OC = pd()->C(); - const int OD = pd()->OD(); - const int OH = pd()->OH(); - const int OW = pd()->OW(); - - if (alg == pooling_max) { - parallel_nd(MB, OC, OD, OH, OW, - [&](int mb, int oc, int od, int oh, int ow) { - data_t *d = is_3d - ? &dst[dst_d.off(mb, oc, od, oh, ow)] - : &dst[dst_d.off(mb, oc, oh, ow)]; - d[0] = nstl::numeric_limits::lowest(); - set_ws(mb, oc, od, oh, ow, 0); - if (is_3d) ker_max_3d(d, mb, oc, od, oh, ow); - else ker_max(d, mb, oc, oh, ow); - }); - } else { - parallel_nd(MB, OC, OD, OH, OW, - [&](int mb, int oc, int od, int oh, int ow) { - data_t *d = is_3d - ? &dst[dst_d.off(mb, oc, od, oh, ow)] - : &dst[dst_d.off(mb, oc, oh, ow)]; - d[0] = 0; - if (is_3d) ker_avg_3d(d, mb, oc, od, oh, ow); - else ker_avg(d, mb, oc, oh, ow); - }); - } -} - -template -void ref_pooling_bwd_t::execute_backward( - const exec_ctx_t &ctx) const { - using namespace alg_kind; - - auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); - auto ws = CTX_IN_MEM(const unsigned char *, MKLDNN_ARG_WORKSPACE); - auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); - - const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); - const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); - const memory_desc_wrapper ws_d(pd()->workspace_md()); - - const int ID = pd()->ID(); - const int IH = pd()->IH(); - const int IW = pd()->IW(); - const int KD = pd()->KD(); - const int KH = pd()->KH(); - const int KW = pd()->KW(); - const int SD = pd()->KSD(); - const int SH = pd()->KSH(); - const int SW = pd()->KSW(); - const int padF = pd()->padFront(); - const int padT = pd()->padT(); - const int padL = pd()->padL(); - - const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5; - - auto alg = pd()->desc()->alg_kind; - - auto apply_offset = [=](int index, int offset) { - return (index > offset) ? 
index - offset : 0; - }; - - auto ker_zero = [=](int _mb, int _oc) { - for (int ih = 0; ih < IH; ++ih) { - for (int iw = 0; iw < IW; ++iw) { - diff_src[diff_src_d.off(_mb, _oc, ih, iw)] = data_type_t(0); - } - } - }; - - auto ker_max = [=](const data_t *d, int mb, int oc, int oh, int ow) { - const size_t ws_off = ws_d.off(mb, oc, oh, ow); - const int index = ws_d.data_type() == data_type::u8 - ? (int)ws[ws_off] : ((int *)ws)[ws_off]; - const int kw = index % KW; - const int kh = index / KW; - const int ih = oh * SH - padT + kh; - const int iw = ow * SW - padL + kw; - - // If padding area could fit the kernel, - // then input displacement would be out of bounds. - // No need to back propagate there as padding is - // virtual in pooling_max case. - if (ih < 0 || ih >= IH) - return; - if (iw < 0 || iw >= IW) - return; - - diff_src[diff_src_d.off(mb, oc, ih, iw)] += d[0]; - }; - - auto ker_avg = [=](const data_t *d, int mb, int oc, int oh, int ow) { - auto ih_start = apply_offset(oh*SH, padT); - auto iw_start = apply_offset(ow*SW, padL); - auto ih_end = nstl::min(oh*SH - padT + KH, IH); - auto iw_end = nstl::min(ow*SW - padL + KW, IW); - - auto num_summands = (alg == pooling_avg_include_padding) ? KW*KH - : (ih_end - ih_start)*(iw_end - iw_start); - - for (int ih = ih_start; ih < ih_end; ++ih) { - for (int iw = iw_start; iw < iw_end; ++iw) { - diff_src[diff_src_d.off(mb, oc, ih, iw)] += d[0] / num_summands; - } - } - }; - - auto ker_zero_3d = [=](int _mb, int _oc) { - for (int id = 0; id < ID; ++id) { - for (int ih = 0; ih < IH; ++ih) { - for (int iw = 0; iw < IW; ++iw) { - diff_src[diff_src_d.off(_mb, _oc, id, ih, iw)] = - data_type_t(0); - } - } - } - }; - - auto ker_max_3d = [=](const data_t *d, int mb, int oc, int od, int oh, - int ow) { - const size_t ws_off = ws_d.off(mb, oc, od, oh, ow); - const int index = ws_d.data_type() == data_type::u8 - ? (int)ws[ws_off] : ((int *)ws)[ws_off]; - const int kw = index % KW; - const int kh = (index / KW) % KH; - const int kd = (index / KW) / KH; - const int id = od * SD - padF + kd; - const int ih = oh * SH - padT + kh; - const int iw = ow * SW - padL + kw; - - // If padding area could fit the kernel, - // then input displacement would be out of bounds. - // No need to back propagate there as padding is - // virtual in pooling_max case. - if (id < 0 || id >= ID) - return; - if (ih < 0 || ih >= IH) - return; - if (iw < 0 || iw >= IW) - return; - - diff_src[diff_src_d.off(mb, oc, id, ih, iw)] += d[0]; - }; - - auto ker_avg_3d = [=](const data_t *d, int mb, int oc, int od, int oh, - int ow) { - auto id_start = apply_offset(od*SD, padF); - auto ih_start = apply_offset(oh*SH, padT); - auto iw_start = apply_offset(ow*SW, padL); - auto id_end = nstl::min(od*SD - padF + KD, ID); - auto ih_end = nstl::min(oh*SH - padT + KH, IH); - auto iw_end = nstl::min(ow*SW - padL + KW, IW); - - auto num_summands = (alg == pooling_avg_include_padding) ? 
KW*KH*KD - : (ih_end - ih_start)*(iw_end - iw_start)*(id_end - id_start); - - for (int id = id_start; id < id_end; ++id) - for (int ih = ih_start; ih < ih_end; ++ih) - for (int iw = iw_start; iw < iw_end; ++iw) { - diff_src[diff_src_d.off(mb, oc, id, ih, iw)] += d[0] / num_summands; - } - }; - - const int MB = pd()->MB(); - const int OC = pd()->C(); - const int OD = pd()->OD(); - const int OH = pd()->OH(); - const int OW = pd()->OW(); - - if (pd()->desc()->alg_kind == alg_kind::pooling_max) { - parallel_nd(MB, OC, [&](int mb, int oc) { - if (is_3d) ker_zero_3d(mb, oc); - else ker_zero(mb, oc); - for (int od = 0; od < OD; ++od) { - for (int oh = 0; oh < OH; ++oh) { - for (int ow = 0; ow < OW; ++ow) { - const data_t *d = is_3d - ? &diff_dst[diff_dst_d.off(mb, oc, od, oh, ow)] - : &diff_dst[diff_dst_d.off(mb, oc, oh, ow)]; - if (is_3d) ker_max_3d(d, mb, oc, od, oh, ow); - else ker_max(d, mb, oc, oh, ow); - } - } - } - }); - } else { - parallel_nd(MB, OC, [&](int mb, int oc) { - if (is_3d) ker_zero_3d(mb, oc); - else ker_zero(mb, oc); - for (int od = 0; od < OD; ++od) { - for (int oh = 0; oh < OH; ++oh) { - for (int ow = 0; ow < OW; ++ow) { - const data_t *d = is_3d - ? &diff_dst[diff_dst_d.off(mb, oc, od, oh, ow)] - : &diff_dst[diff_dst_d.off(mb, oc, oh, ow)]; - if (is_3d) ker_avg_3d(d, mb, oc, od, oh, ow); - else ker_avg(d, mb, oc, oh, ow); - } - } - } - }); - } -} - -template struct ref_pooling_fwd_t; -template struct ref_pooling_fwd_t; -template struct ref_pooling_fwd_t; -template struct ref_pooling_fwd_t; - -template struct ref_pooling_bwd_t; -template struct ref_pooling_bwd_t; - -} -} -} - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_pooling.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_pooling.hpp deleted file mode 100644 index e43ceaa82..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/ref_pooling.hpp +++ /dev/null @@ -1,119 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
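The main subtlety in the average-pooling kernels deleted above is the divisor: pooling_avg_include_padding always divides by the full window size (KW*KH, times KD in 3-D), while the exclude-padding variant divides by the count of window positions that actually land inside the input. A minimal 1-D sketch of that difference (hypothetical names, not the mkl-dnn API):

#include <algorithm>
#include <cstdio>
#include <vector>

// 1-D average pooling: window of K, stride S, left padding `pad`.
static float avg_pool_1d_ex(const std::vector<float> &src, int ow,
        int K, int S, int pad, bool include_padding) {
    const int IW = static_cast<int>(src.size());
    const int iw_start = std::max(ow * S - pad, 0);
    const int iw_end = std::min(ow * S - pad + K, IW);
    float sum = 0.f;
    for (int iw = iw_start; iw < iw_end; ++iw)
        sum += src[iw];
    const int num_summands = include_padding ? K : (iw_end - iw_start);
    return sum / num_summands;
}

int main() {
    std::vector<float> src = {1.f, 2.f, 3.f, 4.f};
    // With pad = 1 the first window covers one padded position plus src[0] and src[1].
    std::printf("include_padding: %g  exclude_padding: %g\n",
                avg_pool_1d_ex(src, 0, 3, 1, 1, true),   // (1 + 2) / 3 = 1
                avg_pool_1d_ex(src, 0, 3, 1, 1, false)); // (1 + 2) / 2 = 1.5
    return 0;
}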
-*******************************************************************************/ - -#ifndef CPU_REF_POOLING_HPP -#define CPU_REF_POOLING_HPP - -#include - -#include "c_types_map.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -#include "cpu_pooling_pd.hpp" -#include "cpu_primitive.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -template -struct ref_pooling_fwd_t: public cpu_primitive_t { - struct pd_t: public cpu_pooling_fwd_pd_t { - using cpu_pooling_fwd_pd_t::cpu_pooling_fwd_pd_t; - - DECLARE_COMMON_PD_T("ref:any", ref_pooling_fwd_t); - - status_t init() { - bool ok = true - && set_default_params() == status::success - && is_fwd() - && utils::everyone_is(data_type, src_md()->data_type, - dst_md()->data_type) - && desc()->accum_data_type == acc_type - && attr()->has_default_values(); - if (!ok) return status::unimplemented; - - bool is_training = desc_.prop_kind == prop_kind::forward_training; - if (desc()->alg_kind == alg_kind::pooling_max && is_training) - init_default_ws(); - - return status::success; - } - }; - - ref_pooling_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} - - typedef typename prec_traits::type data_t; - typedef typename prec_traits::type acc_data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_forward(ctx); - return status::success; - } - -private: - void execute_forward(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } -}; - -template -struct ref_pooling_bwd_t: public cpu_primitive_t { - struct pd_t: public cpu_pooling_bwd_pd_t { - using cpu_pooling_bwd_pd_t::cpu_pooling_bwd_pd_t; - - DECLARE_COMMON_PD_T("ref:any", ref_pooling_bwd_t); - - status_t init() { - bool ok = true - && set_default_params() == status::success - && !is_fwd() - && utils::everyone_is(data_type, diff_dst_md()->data_type, - diff_src_md()->data_type) - && attr()->has_default_values(); - if (!ok) return status::unimplemented; - - if (desc()->alg_kind == alg_kind::pooling_max) { - init_default_ws(); - if (!compare_ws(hint_fwd_pd_)) - return status::unimplemented; - } - - return status::success; - } - }; - - ref_pooling_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {} - typedef typename prec_traits::type data_t; - typedef typename prec_traits::type acc_data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_backward(ctx); - return status::success; - } - -private: - void execute_backward(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } -}; - -} -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_shuffle.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_shuffle.cpp deleted file mode 100644 index af2774311..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/ref_shuffle.cpp +++ /dev/null @@ -1,153 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#include -#include - -#include "c_types_map.hpp" -#include "mkldnn_thread.hpp" -#include "type_helpers.hpp" - -#include "ref_shuffle.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace format_tag; - -template -template -void ref_shuffle_t::execute_(const exec_ctx_t &ctx) const { - using namespace prop_kind; - using namespace utils; - - const memory_desc_wrapper data_d(pd()->data_md()); - - auto i_arg = pd()->is_fwd() ? MKLDNN_ARG_SRC : MKLDNN_ARG_DIFF_DST; - auto o_arg = pd()->is_fwd() ? MKLDNN_ARG_DST : MKLDNN_ARG_DIFF_SRC; - auto input = CTX_IN_MEM(const data_t *, i_arg); - auto output = CTX_OUT_MEM(data_t *, o_arg); - - const int axis = pd()->axis(); - const int axis_size = pd()->axis_size(); - - const int MB = pd()->MB(); - const int C = pd()->C(); - int H = 1, W = 1, D = 1, HW = 1, SP = 1; - const bool has_spatial = utils::one_of(data_d.ndims(), 3, 4 ,5); - if (has_spatial) - { - D = pd()->D(); - H = pd()->H(); - W = pd()->W(); - HW = H * W; - SP = D * HW; - } - const size_t stride_mb = data_d.blocking_desc().strides[0]; - constexpr int blksize = one_of(tag, nChw16c, nCdhw16c) ? 16 : 8; - - if (axis == 1 && one_of(tag, nChw16c, nChw8c, nCdhw16c, nCdhw16c)) { -#if MKLDNN_THR == MKLDNN_THR_OMP -# pragma omp parallel for collapse(3) schedule(static) - for (int mb = 0; mb < MB; ++mb) - for (int cb = 0; cb < C; cb += blksize) - for (int sp = 0; sp < SP; ++sp) { - const size_t off = mb * stride_mb + sp * blksize; - const size_t output_off = off + cb * SP; - PRAGMA_OMP_SIMD() - for (int cc = 0; cc < nstl::min(blksize, C - cb); ++cc) - { - int input_c = rev_transposed_[cb + cc]; - const size_t input_off = off + input_c / blksize * SP * blksize - + input_c % blksize; - output[output_off + cc] = input[input_off]; - } - } -#else - parallel_nd(MB, utils::div_up(C, blksize), SP, [&](int mb, int c, - int sp) { - const size_t off = mb * stride_mb + sp * blksize; - const int cb = c * blksize; - const size_t output_off = off + cb * SP; - for (int cc = 0; cc < nstl::min(blksize, C - cb); ++cc) - { - int input_c = rev_transposed_[cb + cc]; - const size_t input_off = off + input_c / blksize * SP * blksize - + input_c % blksize; - output[output_off + cc] = input[input_off]; - } - }); -#endif - } else if (axis == 1 && one_of(tag, nhwc, ndhwc)) { - parallel_nd(MB, SP, [&](int mb, int sp) { - const size_t off = mb * stride_mb + sp * C; - PRAGMA_OMP_SIMD() - for (int c = 0; c < C; ++c) - output[off + c] = input[off + rev_transposed_[c]]; - }); - } else if (axis == 1 && one_of(tag, nchw, ncdhw)) { - parallel_nd(MB, C, [&](int mb, int c) { - const size_t output_off = mb * stride_mb + c * SP; - const size_t input_off = mb * stride_mb + rev_transposed_[c] * SP; - PRAGMA_OMP_SIMD() - for (int sp = 0; sp < SP; ++sp) { - output[output_off + sp] = input[input_off + sp]; - } - }); - } else { - auto dims = pd()->desc()->data_desc.dims; - auto ndims = pd()->desc()->data_desc.ndims; - const size_t outer_size = utils::array_product(dims, axis); - const size_t inner_size = utils::array_product(dims + axis + 1, - ndims - axis - 1); - const size_t dim = axis_size * inner_size; - - parallel_nd(outer_size, axis_size, inner_size, [&](size_t ou, int a, - size_t in) - { - const size_t off = ou * dim + in; - auto &o = output[data_d.off_l(off + a * inner_size)]; - o = input[data_d.off_l(off + rev_transposed_[a] * inner_size)]; - }); - } -} - -template void ref_shuffle_t<4>::execute_(const exec_ctx_t &ctx) const; 
-template void ref_shuffle_t<4>::execute_(const exec_ctx_t &ctx) const; -template void ref_shuffle_t<4>::execute_(const exec_ctx_t &ctx) const; -template void ref_shuffle_t<4>::execute_(const exec_ctx_t &ctx) const; -template void ref_shuffle_t<4>::execute_(const exec_ctx_t &ctx) const; -template void ref_shuffle_t<4>::execute_(const exec_ctx_t &ctx) const; -template void ref_shuffle_t<4>::execute_(const exec_ctx_t &ctx) const; -template void ref_shuffle_t<4>::execute_(const exec_ctx_t &ctx) const; -template void ref_shuffle_t<4>::execute_(const exec_ctx_t &ctx) const; - -template void ref_shuffle_t<1>::execute_(const exec_ctx_t &ctx) const; -template void ref_shuffle_t<1>::execute_(const exec_ctx_t &ctx) const; -template void ref_shuffle_t<1>::execute_(const exec_ctx_t &ctx) const; -template void ref_shuffle_t<1>::execute_(const exec_ctx_t &ctx) const; -template void ref_shuffle_t<1>::execute_(const exec_ctx_t &ctx) const; -template void ref_shuffle_t<1>::execute_(const exec_ctx_t &ctx) const; -template void ref_shuffle_t<1>::execute_(const exec_ctx_t &ctx) const; -template void ref_shuffle_t<1>::execute_(const exec_ctx_t &ctx) const; -template void ref_shuffle_t<1>::execute_(const exec_ctx_t &ctx) const; - -} -} -} - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_shuffle.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_shuffle.hpp deleted file mode 100644 index 5e09a1a69..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/ref_shuffle.hpp +++ /dev/null @@ -1,111 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef CPU_REF_SHUFFLE_HPP -#define CPU_REF_SHUFFLE_HPP - -#include - -#include "c_types_map.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -#include "cpu_shuffle_pd.hpp" -#include "cpu_primitive.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -template -struct ref_shuffle_t : public cpu_primitive_t { - using shuffle_class = ref_shuffle_t; - - struct pd_t: public cpu_shuffle_pd_t { - using cpu_shuffle_pd_t::cpu_shuffle_pd_t; - - DECLARE_COMMON_PD_T("ref:any", shuffle_class); - - status_t init() { - using namespace format_tag; - - bool ok = true - && data_type_size - == types::data_type_size(data_md()->data_type); - if (!ok) return status::unimplemented; - - if (ndims() == 5) { - dat_tag_ = memory_desc_matches_one_of_tag( - *data_md(), nCdhw16c, nCdhw8c, ncdhw, ndhwc); - } else if (ndims() == 4) { - dat_tag_ = memory_desc_matches_one_of_tag( - *data_md(), nChw16c, nChw8c, nchw, nhwc); - } else - dat_tag_ = any; - - return status::success; - } - - format_tag_t dat_tag_; - }; - - ref_shuffle_t(const pd_t *apd): cpu_primitive_t(apd) { - const int axis_size = pd()->axis_size(); - const int group_size = pd()->group_size(); - const int transpose_row = pd()->is_fwd() ? 
group_size - : axis_size / group_size; - const int transpose_col = pd()->is_fwd() ? axis_size / group_size - : group_size; - rev_transposed_ = (int *)malloc(axis_size * sizeof(int), 64); - parallel_nd(transpose_col, transpose_row, [&](int i, int j) { - rev_transposed_[j * transpose_col + i] = i * transpose_row + j; - }); - } - - ~ref_shuffle_t() { free(rev_transposed_); } - - typedef typename typesize_traits::type data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - using namespace format_tag; - switch (pd()->dat_tag_) { - case nCdhw16c: execute_(ctx); break; - case nChw16c: execute_(ctx); break; - case nCdhw8c: execute_(ctx); break; - case nChw8c: execute_(ctx); break; - case ncdhw: execute_(ctx); break; - case nchw: execute_(ctx); break; - case ndhwc: execute_(ctx); break; - case nhwc: execute_(ctx); break; - default: execute_(ctx); break; - } - return status::success; - } - -private: - template - void execute_(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - int *rev_transposed_; -}; - -} -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_softmax.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_softmax.cpp deleted file mode 100644 index 36d5237f5..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/ref_softmax.cpp +++ /dev/null @@ -1,264 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
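rev_transposed_ in the shuffle constructor above is the inverse of the group-shuffle permutation, so the execute_ paths can gather with output[c] = input[rev_transposed_[c]]. A tiny standalone sketch of that index math (hypothetical *_ex names; the rows/cols convention follows the deleted constructor's forward branch, simplified here):

#include <cstdio>
#include <vector>

// Build the gather table used by a channel shuffle: rev[j * cols + i] = i * rows + j,
// i.e. the transpose of viewing the C channels as a rows x cols matrix.
static std::vector<int> make_rev_transposed_ex(int C, int group_size) {
    const int rows = group_size;      // forward case of the deleted constructor
    const int cols = C / group_size;
    std::vector<int> rev(C);
    for (int i = 0; i < cols; ++i)
        for (int j = 0; j < rows; ++j)
            rev[j * cols + i] = i * rows + j;
    return rev;
}

int main() {
    const int C = 6, group_size = 2;
    const std::vector<int> rev = make_rev_transposed_ex(C, group_size);
    std::vector<float> in = {0.f, 1.f, 2.f, 3.f, 4.f, 5.f}, out(C);
    for (int c = 0; c < C; ++c)
        out[c] = in[rev[c]]; // out = {0, 2, 4, 1, 3, 5}
    for (float v : out)
        std::printf("%g ", v);
    std::printf("\n");
    return 0;
}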
-*******************************************************************************/ - -#include -#include -#include - -#include "c_types_map.hpp" -#include "mkldnn_thread.hpp" -#include "type_helpers.hpp" - -#include "ref_softmax.hpp" -#include "gemm/os_blas.hpp" - -#ifdef USE_MKL -#include "mkl_vml_functions.h" -#endif - -namespace mkldnn { -namespace impl { -namespace cpu { - -template -void ref_softmax_fwd_t::execute_forward_dense( - const exec_ctx_t &ctx) const { - auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); - auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); - - parallel_nd(outer_size_, [&](int ou) { - const data_t *src_data = src + ou * channels_; - data_t *dst_data = dst + ou * channels_; - data_t scalar = 0; - - _max(channels_, src_data, &scalar); - _sub(channels_, scalar, src_data, dst_data); - _exp(channels_, dst_data, dst_data); - _sum(channels_, dst_data, &scalar); - _scal(channels_, data_t(1)/scalar, dst_data); - }); -} - -template -void ref_softmax_fwd_t::execute_forward_generic( - const exec_ctx_t &ctx) const { - auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); - auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); - - data_t space_max_val = 0, space_denom_val = 0; - data_t *space_max = &space_max_val, *space_denom = &space_denom_val; - if (inner_size_ > 1) { - using namespace memory_tracking::names; - space_max = scratchpad(ctx).template get(key_softmax_reduction); - space_denom = space_max + inner_size_; - } - - const memory_desc_wrapper data_d(pd()->src_md()); - const size_t dim = channels_ * inner_size_; - - for (int ou = 0; ou < outer_size_; ou++) { - utils::array_set(space_max, -FLT_MAX, inner_size_); - utils::array_set(space_denom, 0, inner_size_); - - for (int c = 0; c < channels_; c++) { - for(int in = 0; in < inner_size_; in++) { - size_t off = data_d.off_l(ou * dim + c * inner_size_ + in); - space_max[in] = nstl::max(space_max[in], src[off]); - } - } - - for (int c = 0; c < channels_; c++) { - for(int in = 0; in < inner_size_; in++) { - size_t off = data_d.off_l(ou * dim + c * inner_size_ + in); - space_denom[in] += dst[off] = exp(src[off] - space_max[in]); - } - } - - for (int c = 0; c < channels_; c++) { - for (int in = 0; in < inner_size_; in++) { - size_t off = data_d.off_l(ou * dim + c * inner_size_ + in); - dst[off] /= space_denom[in]; - } - } - } -} - -template -void ref_softmax_fwd_t::_max(int n, const data_t *x, - data_t *max_data) const { -// Intel(R) C++ Compiler generates the maxps + shuffle pattern -// for the max search which works faster -#if !defined(__INTEL_COMPILER) - // The code below makes a compiler to generate maxps instruction - // rather than maxss, which is generated for the 'else' code path - auto max_wrapper = [](data_t a, data_t b) { return nstl::max(a, b); }; - auto min_wrapper = [](int a, int b) { return nstl::min(a, b); }; - - constexpr int unroll_factor = 32; - data_t max_values[unroll_factor]; - - if (n < unroll_factor) { - data_t max_val = x[0]; - for (int i = 1; i < n; i++) { - max_val = max_wrapper(max_val, x[i]); - } - max_data[0] = max_val; - return; - } - for (int i = 0; i < unroll_factor; i++) { - max_values[i] = x[i]; - } - for (int i = unroll_factor; i < n; i += unroll_factor) { - int offset = min_wrapper(i, n - unroll_factor); - for (int j = 0; j < unroll_factor; j++) { - max_values[j] = max_wrapper(max_values[j], x[offset + j]); - } - } - data_t max_val = max_values[0]; - for (int i = 1; i < unroll_factor; i++) { - max_val = max_wrapper(max_val, max_values[i]); - } - max_data[0] = max_val; -#else - 
max_data[0] = x[0]; - for (int c = 1; c < n; ++c) - max_data[0] = nstl::max(max_data[0], x[c]); -#endif -} - -template -void ref_softmax_fwd_t::_sub(int n, data_t alpha, const data_t *x, - data_t *y) const { - constexpr int unroll_factor = 32; - int tail = n % unroll_factor; - for (int i = 0; i < n - tail; i += unroll_factor) { - PRAGMA_OMP_SIMD() - for (int j = 0; j < unroll_factor; j++) { - y[i + j] = x[i + j] - alpha; - } - } - PRAGMA_OMP_SIMD() - for (int i = n - tail; i < n; i++) { - y[i] = x[i] - alpha; - } -} - -template -void ref_softmax_fwd_t::_exp(int n, const data_t *a, - data_t *r) const { -#ifdef USE_MKL - if (data_type == data_type::f32) { - vsExp(n, a, r); - return; - } -#endif - parallel_nd(n, [&](int c) { r[c] = expf(a[c]); }); -} - -template -void ref_softmax_fwd_t::_sum(int n, const data_t *x, - data_t *sum_data) const { -#ifdef USE_CBLAS - // Here we are summing x's eg. e^z , which are positives - // so we can use BLAS ASUM - if (data_type == data_type::f32) { - sum_data[0] = cblas_sasum(n, x, 1); - return; - } -#endif - data_t tsum = static_cast(0); - PRAGMA_OMP_SIMD(reduction(+ : tsum)) - for (int c = 0; c < n; ++c) - tsum += x[c]; - sum_data[0] = tsum; -} - -template -void ref_softmax_fwd_t::_scal(int n, data_t alpha, data_t *x) const { -#ifdef USE_CBLAS - if (data_type == data_type::f32) { - cblas_sscal(n, alpha, x, 1); - return; - } -#endif - parallel_nd(n, [&](int c) { x[c] *= alpha; }); -} - -template struct ref_softmax_fwd_t; - - -// NC/NCHW softmax for along final axe (1 for NC, 3 for NCHW) -template -void ref_softmax_bwd_t::execute_backward_dense( - const exec_ctx_t &ctx) const { - auto dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DST); - auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); - auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); - - parallel_nd(outer_size_, [&](int ou) { - data_t sbr = 0; - size_t off = channels_*ou; - for (int c = 0; c < channels_; c++) { - size_t loff = off + c; - data_t ldata = dst[loff]; - sbr += diff_dst[loff]*ldata; - diff_src[loff] = ldata; - } - - for(int c=0; c < channels_ ; ++c) { - size_t loff = off + c; - diff_src[loff] *= (diff_dst[loff] - sbr); - } - }); -} - -template -void ref_softmax_bwd_t::execute_backward_generic( - const exec_ctx_t &ctx) const { - auto dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DST); - auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); - auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); - - const memory_desc_wrapper diff_d(pd()->diff_src_md()); - const memory_desc_wrapper data_d(pd()->dst_md()); - - const size_t dim = channels_ * inner_size_; - - parallel_nd(outer_size_, [&](int ou) { - for (int in = 0; in < inner_size_; in++) { - data_t sbr = 0; - for (int c = 0; c < channels_; c++) { - size_t off_diff = diff_d.off_l(ou * dim + c * inner_size_ + in); - size_t off_data = diff_d.off_l(ou * dim + c * inner_size_ + in); - sbr += diff_dst[off_diff] * dst[off_data]; - } - - for(int c=0; c < channels_ ; ++c) { - size_t off_diff = diff_d.off_l(ou * dim + c * inner_size_ + in); - size_t off_data = data_d.off_l(ou * dim + c * inner_size_ + in); - diff_src[off_diff] = dst[off_data] * (diff_dst[off_diff] - sbr); - } - } - }); -} - -template struct ref_softmax_bwd_t; - -} -} -} - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_softmax.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_softmax.hpp deleted file mode 100644 index 5cb74d800..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/ref_softmax.hpp +++ /dev/null @@ 
-1,186 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef CPU_REF_SOFTMAX_HPP -#define CPU_REF_SOFTMAX_HPP - -#include - -#include "c_types_map.hpp" -#include "memory_tracking.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -#include "cpu_softmax_pd.hpp" -#include "cpu_primitive.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -template -struct ref_softmax_fwd_t: public cpu_primitive_t { - struct pd_t: public cpu_softmax_fwd_pd_t { - using cpu_softmax_fwd_pd_t::cpu_softmax_fwd_pd_t; - - DECLARE_COMMON_PD_T("ref:any", ref_softmax_fwd_t); - - status_t init() { - bool ok = true - && is_fwd() - && src_md()->data_type == data_type - && attr()->has_default_values(); - if (!ok) return status::unimplemented; - - init_scratchpad(); - - return status::success; - } - - private: - void init_scratchpad() { - const int inner_size = utils::array_product( - desc()->data_desc.dims + desc()->softmax_axis + 1, - desc()->data_desc.ndims - desc()->softmax_axis - 1); - - if (inner_size > 1) { - auto scratchpad = scratchpad_registry().registrar(); - scratchpad.book(memory_tracking::names::key_softmax_reduction, - sizeof(data_t) * 2 * inner_size); - } - } - }; - - ref_softmax_fwd_t(const pd_t *apd): cpu_primitive_t(apd) - { - auto ndims = pd()->desc()->data_desc.ndims; - auto dims = pd()->desc()->data_desc.dims; - auto axis = pd()->desc()->softmax_axis; - - outer_size_ = utils::array_product(dims, axis); - channels_ = dims[axis]; - inner_size_ = utils::array_product(dims + axis + 1, ndims - axis - 1); - - const memory_desc_wrapper data_d(pd()->src_md()); - - bool no_axis_blocking = true; - for (int iblk = 0; iblk < data_d.blocking_desc().inner_nblks; ++iblk) - if (data_d.blocking_desc().inner_idxs[iblk] == axis) - no_axis_blocking = false; - - use_dense_ = inner_size_ == 1 && data_d.is_dense() - && no_axis_blocking - && data_d.blocking_desc().strides[axis] == 1; - } - - typedef typename prec_traits::type data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - if (use_dense_) - execute_forward_dense(ctx); - else - execute_forward_generic(ctx); - return status::success; - } - -private: - void execute_forward_dense(const exec_ctx_t &ctx) const; - void execute_forward_generic(const exec_ctx_t &ctx) const; - - void _max(int n, const data_t *x, data_t *max_data) const; - void _sub(int n, data_t alpha, const data_t *x, data_t *y) const; - void _exp(int n, const data_t *a, data_t *r) const; - void _sum(int n, const data_t *x, data_t *sum_data) const; - void _scal(int n, data_t alpha, data_t *x) const; - - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - - bool use_dense_; - int outer_size_, channels_, inner_size_; -}; - -template -struct ref_softmax_bwd_t: public cpu_primitive_t { - struct pd_t: public cpu_softmax_bwd_pd_t { - using 
cpu_softmax_bwd_pd_t::cpu_softmax_bwd_pd_t; - - DECLARE_COMMON_PD_T("ref:any", ref_softmax_bwd_t); - - status_t init() { - bool ok = true - && !is_fwd() - && utils::everyone_is(data_type, - dst_md()->data_type, - diff_src_md()->data_type) - && attr()->has_default_values(); - if (!ok) return status::unimplemented; - - return status::success; - } - }; - - ref_softmax_bwd_t(const pd_t *apd): cpu_primitive_t(apd) { - auto dims = pd()->desc()->diff_desc.dims; - auto axis = pd()->desc()->softmax_axis; - auto ndims = pd()->desc()->diff_desc.ndims; - - outer_size_ = utils::array_product(dims, axis); - channels_ = dims[axis]; - inner_size_ = utils::array_product(dims + axis + 1, ndims - axis - 1); - - const memory_desc_wrapper data_d(pd()->dst_md()); - const memory_desc_wrapper diff_d(pd()->diff_dst_md()); - - bool no_axis_blocking = true; - for (int iblk = 0; iblk < diff_d.blocking_desc().inner_nblks; ++iblk) - if (diff_d.blocking_desc().inner_idxs[iblk] == axis) - no_axis_blocking = false; - - use_dense_ = true - && inner_size_ == 1 - && diff_d == data_d - && diff_d.is_dense() - && no_axis_blocking - && diff_d.blocking_desc().strides[axis] == 1; - } - - typedef typename prec_traits::type data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - if (use_dense_) - execute_backward_dense(ctx); - else - execute_backward_generic(ctx); - return status::success; - } - -private: - void execute_backward_dense(const exec_ctx_t &ctx) const; - void execute_backward_generic(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - - bool use_dense_; - int outer_size_, channels_, inner_size_; -}; - - -} -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_sum.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_sum.hpp deleted file mode 100644 index 3b2a75d99..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/ref_sum.hpp +++ /dev/null @@ -1,101 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/
-
-#ifndef REF_SUM_HPP
-#define REF_SUM_HPP
-
-#include "reorder_pd.hpp"
-
-#include "cpu_sum_pd.hpp"
-#include "cpu_primitive.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-struct ref_sum_t: public cpu_primitive_t {
-    struct pd_t: public cpu_sum_pd_t {
-        using cpu_sum_pd_t::cpu_sum_pd_t;
-
-        pd_t(const pd_t &rhs): cpu_sum_pd_t(rhs) {
-            for (size_t i = 0; i < rhs.reorder_pds_.size(); ++i)
-                reorder_pds_.push_back(
-                        (const reorder_pd_t *)rhs.reorder_pds_[i]->clone());
-        }
-
-        ~pd_t() { for (auto &rpd: reorder_pds_) delete rpd; }
-
-        DECLARE_SUM_PD_T("ref:any", ref_sum_t);
-
-        status_t init() {
-            bool ok = cpu_sum_pd_t::init() == status::success;
-            if (!ok) return status::unimplemented;
-
-            for (int i = 0; i < n_; ++i) {
-                auto r_impls = engine_->get_reorder_implementation_list();
-                for (auto r = r_impls; *r; ++r) {
-                    primitive_attr_t attr;
-                    attr.output_scales_.set(scales_[i]);
-                    if (i != 0) attr.post_ops_.append_sum(1.0);
-
-                    reorder_pd_t *r_pd;
-                    if ((*r)(&r_pd, engine_, &attr, engine_, src_md(i),
-                                engine_, dst_md()) == status::success) {
-                        r_pd->init_info();
-                        reorder_pds_.push_back(r_pd);
-                        break;
-                    }
-                }
-            }
-
-            ok = reorder_pds_.size() == (size_t)n_;
-            return ok ? status::success : status::unimplemented;
-        }
-
-        nstl::vector<const reorder_pd_t *> reorder_pds_;
-    };
-
-    ref_sum_t(const pd_t *apd): cpu_primitive_t(apd) {
-        const int n = pd()->n_inputs();
-        reorders_.resize(n);
-        for (int i = 0; i < n; ++i)
-            pd()->reorder_pds_[i]->create_primitive(&reorders_[i]);
-    }
-
-    ~ref_sum_t() { for (auto &r: reorders_) delete r; }
-
-    virtual status_t execute(const exec_ctx_t &ctx) const override {
-        const auto n = pd()->n_inputs();
-        for (int i = 0; i < n; ++i) {
-            exec_args_t r_args;
-            r_args[MKLDNN_ARG_SRC] = ctx.args().at(MKLDNN_ARG_MULTIPLE_SRC + i);
-            r_args[MKLDNN_ARG_DST] = ctx.args().at(MKLDNN_ARG_DST);
-            exec_ctx_t r_ctx(ctx.stream(), std::move(r_args));
-            reorders_[i]->execute(r_ctx);
-        }
-        return status::success;
-    }
-
-private:
-    const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-    nstl::vector<primitive_t *> reorders_;
-};
-
-}
-}
-}
-
-#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_common.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_common.cpp
deleted file mode 100644
index 537084db9..000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_common.cpp
+++ /dev/null
@@ -1,90 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-/*
- * Common for RNN and LSTM cell execution
- */
-#include "ref_rnn.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-using namespace rnn_utils;
-
-template
-rnn_cell_execution_sig(
-        (_ref_rnn_common_t::cell_execution)) {
-    if (!rnn.merge_gemm_layer) {
-        (this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb,
-                rnn.slc, 1.0, w_layer_[0], rnn.weights_layer_ld,
-                states_t_lm1_, rnn.states_ws_ld, 0.0, ws_gates_,
-                rnn.gates_ws_ld);
-    }
-    (this->*gemm_iter_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb, rnn.sic,
-            1.0, w_iter_[0], rnn.weights_iter_ld, states_tm1_l_,
-            rnn.states_ws_ld, 1.0, ws_gates_, rnn.gates_ws_ld);
-
-    if (rnn_postgemm_ != nullptr)
-        rnn_postgemm_->execute(rnn, ws_gates_, states_t_l_, c_states_t_l_,
-                states_tm1_l_, c_states_tm1_l_, diff_states_t_l_,
-                diff_states_t_lp1_, diff_states_tp1_l_, bias_[0], ws_grid_,
-                ws_cell_);
-    else
-        (this->*elemwise_func)(rnn, ws_gates_, states_t_l_, c_states_t_l_,
-                states_tm1_l_, c_states_tm1_l_, diff_states_t_l_,
-                diff_states_t_lp1_, diff_states_tp1_l_, bias_[0], ws_grid_,
-                ws_cell_);
-}
-template rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution);
-template rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution);
-
-template <>
-rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution) {
-    ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_);
-    (this->*elemwise_func)(rnn, ws_gates_, states_t_l_, c_states_t_l_,
-            states_tm1_l_, c_states_tm1_l_, diff_states_t_l_,
-            diff_states_t_lp1_, diff_states_tp1_l_, bias_[0], ws_grid_,
-            ws_cell_);
-
-    /// bwd by data on the cell
-    (this->*gemm_iter_func)('N', 'N', rnn.sic, rnn.mb, rnn.n_gates * rnn.dic,
-            1.0, w_iter_[0], rnn.weights_iter_ld, ws_gates_, rnn.gates_ws_ld,
-            0.0, diff_states_t_l_, rnn.states_ws_ld);
-
-    if (!rnn.merge_gemm_layer) {
-        (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb,
-                rnn.n_gates * rnn.dic, 1.0, w_layer_[0],
-                rnn.weights_layer_ld, ws_gates_, rnn.gates_ws_ld, 0.0,
-                &diff_states_t_l(rnn.n_states, 0, 0), rnn.states_ws_ld);
-
-        /// bwd by weights on the cell
-        gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.slc, rnn.mb, 1.0, ws_gates_,
-                rnn.gates_ws_ld, states_t_lm1_, rnn.states_ws_ld, 1.0,
-                diff_w_layer_, rnn.diff_weights_layer_ld);
-    }
-
-    if (!rnn.merge_gemm_iter)
-        gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.sic, rnn.mb, 1.0, ws_gates_,
-                rnn.gates_ws_ld, states_tm1_l_, rnn.states_ws_ld, 1.0,
-                diff_w_iter_, rnn.diff_weights_iter_ld);
-
-    /// bwd by bias we just accumulate diffs from the gates
-    gates_reduction(rnn, ws_gates_, diff_bias_);
-}
-
-}
-}
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru.cpp
deleted file mode 100644
index e1a61d4c6..000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru.cpp
+++ /dev/null
@@ -1,180 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -/* - * Cell execution GRU - */ - -#include "math_utils.hpp" -#include "mkldnn_thread.hpp" - -#include "ref_rnn.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace mkldnn::impl::utils; -using namespace mkldnn::impl::math; -using namespace rnn_utils; - -#define AOC array_offset_calculator -template <> -rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru) { - ws_gates_aoc_t ws_gates(rnn, ws_gates_); - bias_aoc_t bias(rnn, bias_[0]); - ws_states_aoc_t states_t_l(rnn, states_t_l_); - ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_); - - // 1. gemm Wx[0-2],x - if (!rnn.merge_gemm_layer) { - (this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb, - rnn.slc, 1.0, w_layer_[0], rnn.weights_layer_ld, - states_t_lm1_, rnn.states_ws_ld, 0.0, ws_gates_, - rnn.gates_ws_ld); - } - - // 2. gemm Wh[0-1],h - (this->*gemm_iter_func)('N', 'N', (rnn.n_gates - 1) * rnn.dic, rnn.mb, - rnn.sic, 1.0, w_iter_[0], rnn.weights_iter_ld, states_tm1_l_, - rnn.states_ws_ld, 1.0, ws_gates_, rnn.gates_ws_ld); - - // 3. activation zt and rt + elemwise multiplication rt,ht-1 - parallel_nd(rnn.mb, [&](int i) { - PRAGMA_OMP_SIMD() - for (int j = 0; j < rnn.dic; j++) { - ws_gates(i, 0, j) = logistic_fwd(ws_gates(i, 0, j) + bias(0, j)); - ws_gates(i, 1, j) = logistic_fwd(ws_gates(i, 1, j) + bias(1, j)); - states_t_l(i, j) = states_tm1_l(i, j) * ws_gates(i, 1, j); - } - }); - - // 4. gemm Wh[2],h~t - (this->*gemm_iter_func)('N', 'N', rnn.dic, rnn.mb, rnn.sic, 1.0, w_iter_[1], - rnn.weights_iter_ld, states_t_l_, rnn.states_ws_ld, 1.0, - &(ws_gates(0, 2, 0)), rnn.gates_ws_ld); - - // 5. activation h~t + calculate ht - parallel_nd(rnn.mb, [&](int i) { - PRAGMA_OMP_SIMD() - for (int j = 0; j < rnn.dic; j++) { - ws_gates(i, 2, j) = tanh_fwd(ws_gates(i, 2, j) + bias(2, j)); - states_t_l(i, j) = states_tm1_l(i, j) * ws_gates(i, 0, j) - + (1.0f - ws_gates(i, 0, j)) * ws_gates(i, 2, j); - } - }); -} - -template <> -rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru) { - assert(!"GRU int8 is not supported"); -} - -template <> -rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru) { - ws_gates_aoc_t ws_gates(rnn, ws_gates_); - ws_states_aoc_t states_t_l(rnn, states_t_l_); - ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_); - ws_diff_w_iter_aoc_t diff_w_iter(rnn, diff_w_iter_); - ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_); - ws_diff_states_aoc_t diff_states_tp1_l(rnn, diff_states_tp1_l_); - ws_diff_states_aoc_t diff_states_t_lp1(rnn, diff_states_t_lp1_); - - // use state memory for intermediate computations - // TODO: use cell ws for that - float *dhG1_ = &(diff_states_t_l(rnn.n_states, 0, 0)); - float *hG1_ = dhG1_; - AOC dhG1(dhG1_, rnn.states_nld, rnn.states_ws_ld); - AOC hG1(hG1_, rnn.states_nld, rnn.states_ws_ld); - - // 1. 
calculate dG2, dG1, and part of dht-1 - // dG2^ = dh * (1 - G0) * (1 - G2^2) - // dG0^ = dh * (ht-1 - G2) * u * (1 - G0) - // dht-1 (part) = dh * G0 - parallel_nd(rnn.mb, [&](int i) { - PRAGMA_OMP_SIMD() - for (int j = 0; j < rnn.dic; j++) { - float h = states_tm1_l(i, j); - float dHt = diff_states_tp1_l(0, i, j) - + diff_states_t_lp1(rnn.n_states, i, j); - float dG2 = (1.0f - ws_gates(i, 0, j)) * dHt - * one_m_square(ws_gates(i, 2, j)); - float dG0 = (h - ws_gates(i, 2, j)) * dHt - * x_m_square(ws_gates(i, 0, j)); - - diff_states_t_l(0, i, j) = dHt * ws_gates(i, 0, j); - ws_gates(i, 0, j) = dG0; - ws_gates(i, 2, j) = dG2; - } - }); - - // 2. calculate intermediate d(hG1) - // d(hG1) = dG2 * W2h^t - (this->*gemm_iter_func)('N', 'N', rnn.sic, rnn.mb, rnn.dic, 1.0, w_iter_[1], - rnn.weights_iter_ld, &(ws_gates(0, 2, 0)), rnn.gates_ws_ld, 0.0, - dhG1_, rnn.states_ws_ld); - - // 3. calculate dG1^ and part of dht-1 - // dG1^ = d(hG1) * h * G1 * (1 - G1) - // dht-1 (part) += d(hG1) * G1 - // h * G1 (required for dWh) - parallel_nd(rnn.mb, [&](int i) { - PRAGMA_OMP_SIMD() - for (int j = 0; j < rnn.dic; j++) { - float h = states_tm1_l(i, j); - float G1 = ws_gates(i, 1, j); - diff_states_t_l(0, i, j) += dhG1(i, j) * G1; - ws_gates(i, 1, j) = dhG1(i, j) * h * x_m_square(G1); - hG1(i, j) = G1 * h; - } - }); - - // 4. calculate diff weights - // dWh1 += dG1 * h, dWh2 += dG2 * h, dWh3 += dG3 * (G1(*)h) - gemm('N', 'T', (rnn.n_gates - 1) * rnn.dic, rnn.sic, rnn.mb, 1.0, ws_gates_, - rnn.gates_ws_ld, states_tm1_l_, rnn.states_ws_ld, 1.0, diff_w_iter_, - rnn.diff_weights_iter_ld); - gemm('N', 'T', rnn.dic, rnn.sic, rnn.mb, 1.0, &(ws_gates(0, 2, 0)), - rnn.gates_ws_ld, hG1_, rnn.states_ws_ld, 1.0, - &(diff_w_iter(0, 2, 0)), rnn.diff_weights_iter_ld); - - // 5. calculate diff states - // dht-1 += dG1 * W1h + dG0 * W0h - (this->*gemm_iter_func)('N', 'N', rnn.sic, rnn.mb, - (rnn.n_gates - 1) * rnn.dic, 1.0, w_iter_[0], - rnn.weights_iter_ld, ws_gates_, rnn.gates_ws_ld, 1.0, - diff_states_t_l_, rnn.states_ws_ld); - - if (!rnn.merge_gemm_layer) { - // dWx += [dG0 dG1 dG2] * [x] - gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.slc, rnn.mb, 1.0, ws_gates_, - rnn.gates_ws_ld, states_t_lm1_, rnn.states_ws_ld, 1.0, - diff_w_layer_, rnn.diff_weights_layer_ld); - // dx = dG2 * W2x + dG1 * W1x + dG0 * W0x - (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb, - rnn.n_gates * rnn.dic, 1.0, w_layer_[0], - rnn.weights_layer_ld, ws_gates_, rnn.gates_ws_ld, 0.0, - &(diff_states_t_l(rnn.n_states, 0, 0)), rnn.states_ws_ld); - } - - // 6. calculate diff bias - gates_reduction(rnn, ws_gates_, diff_bias_); -} -#undef AOC - -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru_lbr.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru_lbr.cpp deleted file mode 100644 index 8dea8c90a..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru_lbr.cpp +++ /dev/null @@ -1,170 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -/* - * Cell execution GRU with linear before reset - */ - -#include "math_utils.hpp" -#include "mkldnn_thread.hpp" - -#include "ref_rnn.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace mkldnn::impl::utils; -using namespace mkldnn::impl::math; -using namespace rnn_utils; -#define AOC array_offset_calculator - -template <> -rnn_elemwise_sig(ref_rnn_fwd_f32_t::gru_lbr_elemwise) { - ws_gates_aoc_t ws_gates(rnn, ws_gates_); - bias_aoc_t bias(rnn, bias_); - ws_states_aoc_t states_t_l(rnn, states_t_l_); - ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_); - ws_gates_aoc_t ws_gemm_state(rnn, ws_cell_); - AOC ws_Wh_b(ws_grid_, rnn.mb, rnn.dic); - - parallel_nd(rnn.mb, [&](int i) { - PRAGMA_OMP_SIMD() - for (int j = 0; j < rnn.dic; j++) { - float Wh_b = ws_gemm_state(i, 2, j) + bias(3, j); - ws_gates(i, 0, j) = logistic_fwd( - ws_gates(i, 0, j) + ws_gemm_state(i, 0, j) + bias(0, j)); - ws_gates(i, 1, j) = logistic_fwd( - ws_gates(i, 1, j) + ws_gemm_state(i, 1, j) + bias(1, j)); - ws_gates(i, 2, j) = tanh_fwd( - ws_gates(i, 2, j) + ws_gates(i, 1, j) * Wh_b + bias(2, j)); - states_t_l(i, j) = states_tm1_l(i, j) * ws_gates(i, 0, j) - + (1.0f - ws_gates(i, 0, j)) * ws_gates(i, 2, j); - if (rnn.is_training) - ws_Wh_b(i, j) = Wh_b; - } - }); -} - -template <> -rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::gru_lbr_elemwise) { - assert(!"GRU LBR int8 is not supported"); -} - -template <> -rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru_lbr) { - if (!rnn.merge_gemm_layer) { - (this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb, - rnn.slc, 1.0, w_layer_[0], rnn.weights_layer_ld, - states_t_lm1_, rnn.states_ws_ld, 0.0, ws_gates_, - rnn.gates_ws_ld); - } - (this->*gemm_iter_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb, rnn.sic, - 1.0, w_iter_[0], rnn.weights_iter_ld, states_tm1_l_, - rnn.states_ws_ld, 0.0, ws_cell_, rnn.gates_ws_ld); - (this->*elemwise_func)(rnn, ws_gates_, states_t_l_, c_states_t_l_, - states_tm1_l_, c_states_tm1_l_, diff_states_t_l_, - diff_states_t_lp1_, diff_states_tp1_l_, bias_[0], ws_grid_, - ws_cell_); -} - -template <> -rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru_lbr) { - assert(!"GRU LBR int8 is not supported"); -} - -template <> -rnn_elemwise_sig(ref_rnn_bwd_f32_t::gru_lbr_elemwise) { - ws_gates_aoc_t ws_gates(rnn, ws_gates_); - ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_); - ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_); - ws_diff_states_aoc_t diff_states_tp1_l(rnn, diff_states_tp1_l_); - ws_diff_states_aoc_t diff_states_t_lp1(rnn, diff_states_t_lp1_); - ws_gates_aoc_t ws_gates_r(rnn, ws_cell_); - AOC ws_Wh_b(ws_grid_, rnn.mb, rnn.dic); - - // 1. 
calculate dG1 dG2 dG3 - // dG0 = (dht - G2) * dht * (1 - G0) * G0 - // dG1 = (W*h + b) * dG2 * (1 - G1) * G1 - // dG2 = (1 - G0) * dht * (1 - G2*G2) - parallel_nd(rnn.mb, [&](int i) { - PRAGMA_OMP_SIMD() - for (int j = 0; j < rnn.dic; j++) { - float h = states_tm1_l(i, j); - float dHt = diff_states_tp1_l(0, i, j) - + diff_states_t_lp1(rnn.n_states, i, j); - float dG0 = (h - ws_gates(i, 2, j)) * dHt - * x_m_square(ws_gates(i, 0, j)); - float dG2 = (1.0f - ws_gates(i, 0, j)) - * one_m_square(ws_gates(i, 2, j)) * dHt; - float dG1 = ws_Wh_b(i, j) * dG2 * x_m_square(ws_gates(i, 1, j)); - - diff_states_t_l(0, i, j) = dHt * ws_gates(i, 0, j); - ws_gates(i, 2, j) = dG2; - ws_gates_r(i, 2, j) = dG2 * ws_gates(i, 1, j); - ws_gates(i, 0, j) = ws_gates_r(i, 0, j) = dG0; - ws_gates(i, 1, j) = ws_gates_r(i, 1, j) = dG1; - } - }); -} - -template <> -rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru_lbr) { - ws_gates_aoc_t ws_gates_r(rnn, ws_cell_); - ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_); - - (this->*elemwise_func)(rnn, ws_gates_, states_t_l_, c_states_t_l_, - states_tm1_l_, c_states_tm1_l_, diff_states_t_l_, - diff_states_t_lp1_, diff_states_tp1_l_, bias_[0], ws_grid_, - ws_cell_); - - if (!rnn.merge_gemm_layer) { - // dx = dG * Wx^t - (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb, - rnn.n_gates * rnn.dic, 1.0, w_layer_[0], - rnn.weights_layer_ld, ws_gates_, rnn.gates_ws_ld, 0.0, - &diff_states_t_l(rnn.n_states, 0, 0), rnn.states_ws_ld); - // dWx += dG^t * x - gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.slc, rnn.mb, 1.0, ws_gates_, - rnn.gates_ws_ld, states_t_lm1_, rnn.states_ws_ld, 1.0, - diff_w_layer_, rnn.diff_weights_layer_ld); - } - // dh += dGr * Wh^t - (this->*gemm_iter_func)('N', 'N', rnn.sic, rnn.mb, rnn.n_gates * rnn.dic, - 1.0, w_iter_[0], rnn.weights_iter_ld, ws_cell_, rnn.gates_ws_ld, - 1.0, diff_states_t_l_, rnn.states_ws_ld); - - // dWh += dGr^t * h - gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.sic, rnn.mb, 1.0, ws_cell_, - rnn.gates_ws_ld, states_tm1_l_, rnn.states_ws_ld, 1.0, diff_w_iter_, - rnn.diff_weights_layer_ld); - - // db1-3 += e * dG - // db4 += e * (r * dG2) - gates_reduction(rnn, ws_gates_, diff_bias_); - - parallel_nd(rnn.dic, [&](int j) { - for (int i = 0; i < rnn.mb; i++) { - diff_bias_[3 * rnn.dic + j] += ws_gates_r(i, 2, j); - } - }); -} - -#undef AOC - -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_lstm.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_lstm.cpp deleted file mode 100644 index a15ba00d4..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_lstm.cpp +++ /dev/null @@ -1,143 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -/* - * Cell execution LSTM - */ - -#include "math_utils.hpp" -#include "mkldnn_thread.hpp" - -#include "../simple_q10n.hpp" -#include "ref_rnn.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace mkldnn::impl::utils; -using namespace mkldnn::impl::math; -using namespace rnn_utils; - -template <> -rnn_elemwise_sig(ref_rnn_fwd_f32_t::lstm_elemwise) { - ws_gates_aoc_t ws_gates(rnn, ws_gates_); - bias_aoc_t bias(rnn, bias_); - ws_states_aoc_t states_t_l(rnn, states_t_l_); - ws_states_aoc_t c_states_t_l(rnn, c_states_t_l_); - ws_states_aoc_t c_states_tm1_l(rnn, c_states_tm1_l_); - - parallel_nd(rnn.mb, [&](int i) { - PRAGMA_OMP_SIMD() - for (int j = 0; j < rnn.dic; j++) { - ws_gates(i, 0, j) = logistic_fwd(ws_gates(i, 0, j) + bias(0, j)); - ws_gates(i, 1, j) = logistic_fwd(ws_gates(i, 1, j) + bias(1, j)); - ws_gates(i, 2, j) = tanh_fwd(ws_gates(i, 2, j) + bias(2, j)); - ws_gates(i, 3, j) = logistic_fwd(ws_gates(i, 3, j) + bias(3, j)); - - float tmp = ws_gates(i, 1, j) * c_states_tm1_l(i, j) - + ws_gates(i, 0, j) * ws_gates(i, 2, j); - states_t_l(i, j) = ws_gates(i, 3, j) * tanh_fwd(tmp); - c_states_t_l(i, j) = tmp; - } - }); -} - -template <> -rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::lstm_elemwise) { - ws_gates_aoc_s32_t ws_gates_s32(rnn, ws_gates_); - bias_aoc_t bias(rnn, bias_); - ws_states_aoc_u8_t states_t_l(rnn, states_t_l_); - ws_states_aoc_t c_states_t_l(rnn, c_states_t_l_); - ws_states_aoc_t c_states_tm1_l(rnn, c_states_tm1_l_); - - float *weights_scales = pd()->attr()->rnn_weights_qparams_.scales_; - float data_shift = pd()->attr()->rnn_data_qparams_.shift_; - float data_scale = pd()->attr()->rnn_data_qparams_.scale_; - - auto q_d = [&](float f) { - float qf = f * data_scale + data_shift; - return qz_a1b0()(qf); - }; - - auto deq_w = [&](acc_data_t s, int gate, int j) { - return pd()->attr()->rnn_weights_qparams_.mask_ == 0 ? 
- saturate(s) * (1.f / (weights_scales[0] * data_scale)) : - saturate(s) * (1.f / (weights_scales[gate * rnn.dic + j] - * data_scale)); - }; - - parallel_nd(rnn.mb, [&](int i) { - PRAGMA_OMP_SIMD() - for (int j = 0; j < rnn.dic; j++) { - float G0 = logistic_fwd( - deq_w(ws_gates_s32(i, 0, j), 0, j) + bias(0, j)); - float G1 = logistic_fwd( - deq_w(ws_gates_s32(i, 1, j), 1, j) + bias(1, j)); - float G2 = tanh_fwd( - deq_w(ws_gates_s32(i, 2, j), 2, j) + bias(2, j)); - float G3 = logistic_fwd( - deq_w(ws_gates_s32(i, 3, j), 3, j) + bias(3, j)); - float tmp = G1 * c_states_tm1_l(i, j) + G0 * G2; - states_t_l(i, j) = q_d(G3 * tanh_fwd(tmp)); - c_states_t_l(i, j) = tmp; - } - }); -} - -template <> -rnn_elemwise_sig(ref_rnn_bwd_f32_t::lstm_elemwise) { - ws_gates_aoc_t ws_gates(rnn, ws_gates_); - bias_aoc_t bias(rnn, bias_); - ws_states_aoc_t c_states_t_l(rnn, c_states_t_l_); - ws_states_aoc_t c_states_tm1_l(rnn, c_states_tm1_l_); - ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_); - ws_diff_states_aoc_t diff_states_tp1_l(rnn, diff_states_tp1_l_); - ws_diff_states_aoc_t diff_states_t_lp1(rnn, diff_states_t_lp1_); - - parallel_nd(rnn.mb, [&](int i) { - PRAGMA_OMP_SIMD() - for (int j = 0; j < rnn.dic; j++) { - float Ct = c_states_t_l(i, j); - /// @todo save it in the workspace in fwd pass or recompute it to - /// save bw - float tanhCt = tanh_fwd(Ct); - // we have 2 incoming diffs on Ht - float dHt = diff_states_tp1_l(0, i, j) - + diff_states_t_lp1(rnn.n_states, i, j); - float dCt = diff_states_tp1_l(1, i, j) - + one_m_square(tanhCt) * ws_gates(i, 3, j) * dHt; - - float dG1 = c_states_tm1_l(i, j) * dCt - * x_m_square(ws_gates(i, 1, j)); - float dG0 = ws_gates(i, 2, j) * dCt * x_m_square(ws_gates(i, 0, j)); - float dG3 = tanhCt * dHt * x_m_square(ws_gates(i, 3, j)); - float dG2 - = ws_gates(i, 0, j) * dCt * one_m_square(ws_gates(i, 2, j)); - - diff_states_t_l(1, i, j) = dCt * ws_gates(i, 1, j); - - ws_gates(i, 0, j) = dG0; - ws_gates(i, 1, j) = dG1; - ws_gates(i, 2, j) = dG2; - ws_gates(i, 3, j) = dG3; - } - }); -} - -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_rnn.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_rnn.cpp deleted file mode 100644 index 4536e8dfa..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_rnn.cpp +++ /dev/null @@ -1,113 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/
-
-/*
- * Cell execution of Vanilla RNN
- */
-
-#include "math_utils.hpp"
-#include "mkldnn_thread.hpp"
-
-#include "ref_rnn.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace mkldnn::impl::utils;
-using namespace mkldnn::impl::math;
-using namespace rnn_utils;
-
-template <>
-float activation(
-        float dd, float s, float alpha, float cliping) {
-    return relu_fwd(s, alpha);
-}
-
-template <>
-float activation(
-        float dd, float s, float alpha, float cliping) {
-    return relu_bwd(dd, s, alpha);
-}
-
-template <>
-float activation(
-        float dd, float s, float alpha, float cliping) {
-    return tanh_fwd(s);
-}
-
-template <>
-float activation(
-        float dd, float s, float alpha, float cliping) {
-    return dd * one_m_square(s);
-}
-
-template <>
-float activation(
-        float dd, float s, float alpha, float cliping) {
-    return logistic_fwd(s);
-}
-
-template <>
-float activation(
-        float dd, float s, float alpha, float cliping) {
-    return dd * x_m_square(s);
-}
-
-template <>
-rnn_elemwise_sig(ref_rnn_fwd_f32_t::rnn_elemwise) {
-    ws_gates_aoc_t ws_gates(rnn, ws_gates_);
-    bias_aoc_t bias(rnn, bias_);
-    ws_states_aoc_t states_t_l(rnn, states_t_l_);
-    ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_);
-
-    parallel_nd(rnn.mb, [&](int i) {
-        for (int j = 0; j < rnn.dic; j++) {
-            const float h
-                    = activation_func(0, ws_gates(i, 0, j) + bias(0, j), 0, 0);
-            ws_gates(i, 0, j) = states_t_l(i, j) = h;
-        }
-    });
-}
-
-template <>
-rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::rnn_elemwise) {
-    assert(!"VANILLA RNN int8 is not supported");
-}
-
-template <>
-rnn_elemwise_sig(ref_rnn_bwd_f32_t::rnn_elemwise) {
-    ws_gates_aoc_t ws_gates(rnn, ws_gates_);
-    bias_aoc_t bias(rnn, bias_);
-    ws_states_aoc_t states_t_l(rnn, states_t_l_);
-    ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_);
-    ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_);
-    ws_diff_states_aoc_t diff_states_tp1_l(rnn, diff_states_tp1_l_);
-    ws_diff_states_aoc_t diff_states_t_lp1(rnn, diff_states_t_lp1_);
-
-    parallel_nd(rnn.mb, [&](int i) {
-        for (int j = 0; j < rnn.dic; ++j) {
-            const float dH = diff_states_t_lp1(rnn.n_states, i, j)
-                    + diff_states_tp1_l(0, i, j);
-            auto g = ws_gates(i, 0, j);
-            ws_gates(i, 0, j) = activation_func(dH, g, 0, 0);
-        }
-    });
-}
-
-}
-}
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cpu_rnn_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cpu_rnn_pd.hpp
deleted file mode 100644
index b39427caf..000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cpu_rnn_pd.hpp
+++ /dev/null
@@ -1,191 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/ - -#ifndef CPU_RNN_PD_HPP -#define CPU_RNN_PD_HPP - -#include "c_types_map.hpp" -#include "nstl.hpp" -#include "rnn_pd.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" -#include "rnn_utils.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -struct cpu_rnn_fwd_pd_t : public rnn_fwd_pd_t { - using rnn_fwd_pd_t::rnn_fwd_pd_t; - -protected: - status_t set_default_params() { - using namespace format_tag; - if (src_layer_md_.format_kind == format_kind::any) - CHECK(memory_desc_init_by_tag(src_layer_md_, tnc)); - if (dst_layer_md_.format_kind == format_kind::any) - CHECK(memory_desc_init_by_tag(dst_layer_md_, tnc)); - - // Optional parameters - if (with_src_iter() && src_iter_md_.format_kind == format_kind::any) - CHECK(memory_desc_init_by_tag(src_iter_md_, ldsnc)); - if (with_bias() && bias_md_.format_kind == format_kind::any) - CHECK(memory_desc_init_by_tag(bias_md_, ldgo)); - if (with_dst_iter() && dst_iter_md_.format_kind == format_kind::any) - CHECK(memory_desc_init_by_tag(dst_iter_md_, ldsnc)); - - return status::success; - } - - status_t check_layout_consistency() { - using namespace format_tag; - using namespace data_type; - using namespace types; - - auto is_blocked = [&](memory_desc_t md, int ndims) { - return md.format_kind == format_kind::blocked && md.ndims == ndims; - }; - - bool ok = true; - ok = ok && is_blocked(src_layer_md_, 3) - && is_blocked(dst_layer_md_, 3); - ok = ok && IMPLICATION(!is_zero_md(&src_iter_md_), - is_blocked(src_iter_md_, 5)) - && IMPLICATION(!is_zero_md(&dst_iter_md_), - is_blocked(dst_iter_md_, 5)); - - if (weights_layer_md_.format_kind == format_kind::rnn_packed) - ok = ok && (weights_layer_md_.format_desc.rnn_packed_desc.format - == mkldnn_ldigo_p); - else - ok = ok && rnn_utils::is_ldigo(&weights_layer_md_); - - if (weights_iter_md_.format_kind == format_kind::rnn_packed) - ok = ok && (weights_iter_md_.format_desc.rnn_packed_desc.format - == mkldnn_ldigo_p); - else - ok = ok && rnn_utils::is_ldigo(&weights_iter_md_); - - ok = ok && IMPLICATION(!is_zero_md(&bias_md_), - memory_desc_matches_tag(bias_md_, ldgo)); - - /* Int8 is supported only for packed weights */ - data_type_t weights_iter_dt = weights_iter_md_.data_type; - data_type_t weights_layer_dt = weights_layer_md_.data_type; - ok = ok && IMPLICATION( - weights_iter_dt == s8, weights_iter_md_.format_kind - == format_kind::rnn_packed); - ok = ok && IMPLICATION( - weights_layer_dt == s8, weights_layer_md_.format_kind - == format_kind::rnn_packed); - - return ok ? 
status::success : status::unimplemented; - } -}; - -struct cpu_rnn_bwd_pd_t : public rnn_bwd_pd_t { - using rnn_bwd_pd_t::rnn_bwd_pd_t; - -protected: - status_t set_default_params() { - using namespace format_tag; - if (src_layer_md_.format_kind == format_kind::any) - CHECK(memory_desc_init_by_tag(src_layer_md_, tnc)); - if (dst_layer_md_.format_kind == format_kind::any) - CHECK(memory_desc_init_by_tag(dst_layer_md_, tnc)); - - if (diff_src_layer_md_.format_kind == format_kind::any) - CHECK(memory_desc_init_by_tag(diff_src_layer_md_, tnc)); - if (diff_weights_layer_md_.format_kind == format_kind::any) { - CHECK(memory_desc_init_by_tag(diff_weights_layer_md_, ldigo)); - CHECK(rnn_utils::set_good_strides(diff_weights_layer_md_, ldigo)); - } - if (diff_weights_iter_md_.format_kind == format_kind::any) { - CHECK(memory_desc_init_by_tag(diff_weights_iter_md_, ldigo)); - CHECK(rnn_utils::set_good_strides(diff_weights_iter_md_, ldigo)); - } - if (diff_dst_layer_md_.format_kind == format_kind::any) - CHECK(memory_desc_init_by_tag(diff_dst_layer_md_, tnc)); - - // Optional parameters - if (with_src_iter() && src_iter_md_.format_kind == format_kind::any) - CHECK(memory_desc_init_by_tag(src_iter_md_, ldsnc)); - if (with_bias() && bias_md_.format_kind == format_kind::any) - CHECK(memory_desc_init_by_tag(bias_md_, ldgo)); - if (with_dst_iter() && dst_iter_md_.format_kind == format_kind::any) - CHECK(memory_desc_init_by_tag(dst_iter_md_, ldsnc)); - - if (with_src_iter() && diff_src_iter_md_.format_kind == format_kind::any) - CHECK(memory_desc_init_by_tag(diff_src_iter_md_, ldsnc)); - if (with_bias() && diff_bias_md_.format_kind == format_kind::any) - CHECK(memory_desc_init_by_tag(diff_bias_md_, ldgo)); - if (with_dst_iter() && diff_dst_iter_md_.format_kind == format_kind::any) - CHECK(memory_desc_init_by_tag(diff_dst_iter_md_, ldsnc)); - - return status::success; - } - - status_t check_layout_consistency() { - using namespace format_tag; - using namespace types; - - auto is_blocked = [&](memory_desc_t md, int ndims) { - return md.format_kind == format_kind::blocked && md.ndims == ndims; - }; - - bool ok = true; - ok = ok && is_blocked(src_layer_md_, 3) - && is_blocked(dst_layer_md_, 3); - ok = ok && IMPLICATION(!is_zero_md(&src_iter_md_), - is_blocked(src_iter_md_, 5)) - && IMPLICATION(!is_zero_md(&dst_iter_md_), - is_blocked(dst_iter_md_, 5)); - - if (weights_layer_md_.format_kind == format_kind::rnn_packed) - ok = ok && (weights_layer_md_.format_desc.rnn_packed_desc.format - == mkldnn_ldgoi_p); - else - ok = ok && rnn_utils::is_ldgoi(&weights_layer_md_); - - if (weights_iter_md_.format_kind == format_kind::rnn_packed) - ok = ok && (weights_iter_md_.format_desc.rnn_packed_desc.format - == mkldnn_ldgoi_p); - else - ok = ok && rnn_utils::is_ldgoi(&weights_iter_md_); - - ok = ok && IMPLICATION(!is_zero_md(&bias_md_), - memory_desc_matches_tag(bias_md_, ldgo)); - - ok = ok && is_blocked(diff_src_layer_md_, 3) - && is_blocked(diff_dst_layer_md_, 3); - ok = ok && IMPLICATION(!is_zero_md(&diff_src_iter_md_), - is_blocked(diff_src_iter_md_, 5)) - && IMPLICATION(!is_zero_md(&diff_dst_iter_md_), - is_blocked(diff_dst_iter_md_, 5)); - - ok = ok && rnn_utils::is_ldigo(&diff_weights_layer_md_) - && rnn_utils::is_ldigo(&diff_weights_iter_md_); - ok = ok && IMPLICATION(!is_zero_md(&diff_bias_md_), - memory_desc_matches_tag(diff_bias_md_, ldgo)); - - return ok ? 
status::success : status::unimplemented; - } -}; -} -} -} - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/jit_uni_rnn_postgemm.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/jit_uni_rnn_postgemm.hpp deleted file mode 100644 index 09445648a..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/jit_uni_rnn_postgemm.hpp +++ /dev/null @@ -1,401 +0,0 @@ -/******************************************************************************* -* Copyright 2019 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -/* - * Cell execution LSTM - */ - -#include "rnn_utils.hpp" -#include "../jit_generator.hpp" -#include "../jit_uni_eltwise.hpp" -#include "c_types_map.hpp" -#include "utils.hpp" - -#include "mkldnn_thread.hpp" - - -namespace mkldnn { -namespace impl { -namespace cpu { - -struct jit_uni_rnn_postgemm_kernel : public jit_generator { - - typedef void (*kernel_t)(void *gates_, const void *bias, void *states_t_l_, - void *c_states_t_l_, void *c_states_tm1_l_); - - jit_uni_rnn_postgemm_kernel(const rnn_utils::rnn_conf_t &rnn, const primitive_attr_t *attr): rnn_(rnn), attr_(attr){} - - virtual void init() = 0; - -template - rnn_elemwise_sig(execute) { - rnn_utils::ws_gates_aoc ws_gates(rnn, ws_gates_); - rnn_utils::bias_aoc_t bias(rnn, bias_); - rnn_utils::ws_states_aoc states_t_l(rnn, states_t_l_); - rnn_utils::ws_states_aoc_t c_states_t_l(rnn, c_states_t_l_); - rnn_utils::ws_states_aoc_t c_states_tm1_l(rnn, c_states_tm1_l_); - - // Todo: add parallelization on dic for the batch 1 case - // Assumption: the kernel runs a loop on dic elements - parallel_nd(rnn.mb, [&](int i) { - auto b_ = &bias(0, 0); - auto g_ = &ws_gates(i, 0, 0); - auto s_tl_ = &states_t_l(i, 0); - auto c_tl_ = &c_states_t_l(i, 0); - auto c_tm1l_ = &c_states_tm1_l(i, 0); - kernel_(g_, b_, s_tl_, c_tm1l_, c_tl_); - }); - } - -protected: - kernel_t kernel_; - const rnn_utils::rnn_conf_t &rnn_; - const primitive_attr_t *attr_; -}; - -template -struct jit_uni_lstm_postgemm_kernel_fwd: public jit_uni_rnn_postgemm_kernel -{ - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_lstm_postgemm_kernel_fwd) - - typedef typename utils::conditional::type acc_data_t; - typedef typename utils::conditional, - jit_uni_eltwise_injector_f32>::type injector_t; - - jit_uni_lstm_postgemm_kernel_fwd(const rnn_utils::rnn_conf_t &rnn, const primitive_attr_t *attr) - : jit_uni_rnn_postgemm_kernel(rnn, attr){} - - void init() override { - // we use rax for both constant tables as they use the same table - sigmoid_injector_ = new injector_t(this, - alg_kind::eltwise_logistic, 0.0f, 0.0f, true, rax); - tanh_injector_ = new injector_t(this, - alg_kind::eltwise_tanh, 0.0f, 0.0f, true, rax); - generate(); - kernel_ = (kernel_t) this->getCode(); - } - -protected: - injector_t *sigmoid_injector_; - injector_t *tanh_injector_; - - // register size in bytes - using Vmm = typename jit_uni_eltwise_injector_f32::Vmm; - size_t vlen = cpu_isa_traits::vlen; - size_t vlen_dst = (src_data_t == 
data_type::u8) ? vlen/4 : vlen; - size_t cstate_dt_size = sizeof(float); - size_t hstate_dt_size = (src_data_t == data_type::u8) ? sizeof(uint8_t) : sizeof(float); - size_t gate_dt_size = (src_data_t == data_type::u8) ? sizeof(uint32_t) : sizeof(float); - size_t qscale_dt_size = sizeof(float); - size_t bias_dt_size = sizeof(float); - - void generate() { - using namespace Xbyak; - - int mask = attr_->rnn_weights_qparams_.mask_; - float *weights_scales = attr_->rnn_weights_qparams_.scales_; - float data_scale = attr_->rnn_data_qparams_.scale_; - float data_shift = attr_->rnn_data_qparams_.shift_; - - // Labels declaration - Label vector_loop_start_label, vector_loop_end_label; - Label rem_loop_start_label, rem_loop_end_label; - Label table_label; - - // Register map - Reg64 loop_cnt(r11); // loop counter - Reg64 table_reg(rbx); // table is used for data scale and shifts - Reg64 weights_scales_reg(r13); - // We skip vmm0 as it can be used by the injector for masks on sse4.2 - Vmm G0(1), G1(2), G2(3), G3(4), tmp1_vmm(5), tmp2_vmm(6), zero_vmm(7); - - // constant table map - Address dscale_off_addr = ptr[table_reg]; - Address dshift_off_addr = ptr[table_reg + vlen]; - Address ymm_perm_mask_addr = ptr[table_reg + 2*vlen]; - Address zmm_perm_mask_addr = ptr[table_reg + 2*vlen + cpu_isa_traits::vlen]; - - // quantize from float to u8 - auto q_d = [&](Vmm f, Vmm tmp_vmm) { - uni_vpxor(tmp_vmm, tmp_vmm, tmp_vmm); - uni_vmulps(f, f, dscale_off_addr); // apply scale - uni_vaddps(f, f, dshift_off_addr); // apply shift - uni_vcvtps2dq(f, f); // convert to int32 - uni_vpackssdw(f, f, tmp_vmm); // convert from s32 to s16 - uni_vpackuswb(f, f, tmp_vmm); // convert from s16 to u8 with saturation - // Note that the results are interleaved by 128 bit chunks, so we need to merge them together - switch (vlen) { - case 64: { //avx512 - Zmm fz(f.getIdx()), tmpz(tmp_vmm.getIdx()); - uni_vmovups(tmpz, zmm_perm_mask_addr); - vpermd(fz, tmpz, fz); - break; } - case 32: { //avx - Ymm fy(f.getIdx()), tmpy(tmp_vmm.getIdx()); - uni_vmovups(tmpy, ymm_perm_mask_addr); - vpermd(fy, tmpy, fy); - break; } - case 16: // sse: nothing to do - break; - default: assert(!"Unsupported case"); - }; - }; - - auto fast_recip =[&](Vmm s, Vmm tmp, bool packed) { - if (packed) - uni_vrcpps(tmp, s); - else - uni_vrcpss(tmp, s); // prevent divide by zero - // we add one Newton iteration - uni_vmulps(s, s, tmp); - uni_vmulps(s, s, tmp); // s <- s * tmp^2 - uni_vaddps(tmp, tmp, tmp); - uni_vsubps(tmp, tmp, s); - uni_vmovups(s, tmp); // s <- 2 * tmp - s * tmp^2 - }; - - // dequantize from s32 to float - auto deq_w = [&](Vmm s, Vmm tmp1, Vmm tmp2, int gate, bool packed) { - // TODO: if mask is 0 precompute mul and inverse - if (mask == 0) - uni_vbroadcastss(tmp1, ptr[weights_scales_reg]); - else - uni_vmovups(tmp1, ptr[weights_scales_reg + gate * rnn_.dic * qscale_dt_size]); - uni_vcvtdq2ps(s, s); - uni_vmulps(tmp1, tmp1, dscale_off_addr); - fast_recip(tmp1, tmp2, packed); - uni_vmulps(s, s, tmp1); - }; - - // We start code generations here - preamble(); - - // extract addresses passed as parameter -#ifdef _WIN32 - auto addr_ws_gates_reg = abi_param1; - auto addr_bias_reg = abi_param2; - auto addr_states_t_l_reg = abi_param3; - auto addr_c_states_tm1_l_reg = abi_param4; - auto addr_c_states_t_l_reg = r10; - // Here we cannot use rbp to have initial stack pointer so we - // use rsp and offset it with the size of pushed registers in - // preamble - mov(addr_c_states_t_l_reg, ptr[rsp + get_size_of_abi_save_regs() + 40]); -#else - auto 
addr_ws_gates_reg = abi_param1; - auto addr_bias_reg = abi_param2; - auto addr_states_t_l_reg = abi_param3; - auto addr_c_states_tm1_l_reg = abi_param4; - auto addr_c_states_t_l_reg = abi_param5; -#endif - - // initialize registers with addresses and constants - mov(table_reg, table_label); - mov(weights_scales_reg, size_t(weights_scales)); - // both sigmoid and tanh use the same table so load address just once in rax - sigmoid_injector_->load_table_addr(); - - mov(loop_cnt, rnn_.dic * gate_dt_size); - cmp(loop_cnt, vlen); - jl(vector_loop_end_label, Xbyak::CodeGenerator::T_NEAR); - - L(vector_loop_start_label); - { - // load G0 G1 G2 G3 - uni_vmovups(G0, ptr[addr_ws_gates_reg + 0 * rnn_.dic * gate_dt_size]); - uni_vmovups(G1, ptr[addr_ws_gates_reg + 1 * rnn_.dic * gate_dt_size]); - uni_vmovups(G2, ptr[addr_ws_gates_reg + 2 * rnn_.dic * gate_dt_size]); - uni_vmovups(G3, ptr[addr_ws_gates_reg + 3 * rnn_.dic * gate_dt_size]); - - // dequantize the gates from s32 to f32 if needed - if (src_data_t == data_type::u8){ - deq_w(G0, tmp1_vmm, tmp2_vmm, 0, true); - deq_w(G1, tmp1_vmm, tmp2_vmm, 1, true); - deq_w(G2, tmp1_vmm, tmp2_vmm, 2, true); - deq_w(G3, tmp1_vmm, tmp2_vmm, 3, true); - } - - // add biases - uni_vaddps(G0, G0, ptr[addr_bias_reg + 0 * rnn_.dic * bias_dt_size]); - uni_vaddps(G1, G1, ptr[addr_bias_reg + 1 * rnn_.dic * bias_dt_size]); - uni_vaddps(G2, G2, ptr[addr_bias_reg + 2 * rnn_.dic * bias_dt_size]); - uni_vaddps(G3, G3, ptr[addr_bias_reg + 3 * rnn_.dic * bias_dt_size]); - - // inject eltwise code - sigmoid_injector_->compute_vector(G0.getIdx()); - sigmoid_injector_->compute_vector(G1.getIdx()); - tanh_injector_->compute_vector(G2.getIdx()); - sigmoid_injector_->compute_vector(G3.getIdx()); - - // compute c_states_t_l = G1 * c_tm1_l + G0 * G2 - uni_vmovups(tmp1_vmm, ptr[addr_c_states_tm1_l_reg]); - uni_vmulps(tmp1_vmm, tmp1_vmm, G1); - uni_vfmadd231ps(tmp1_vmm, G0, G2); - uni_vmovups(ptr[addr_c_states_t_l_reg], tmp1_vmm); - - // states_t_l = G3 * tanh(c_states_t_l) - tanh_injector_->compute_vector(tmp1_vmm.getIdx()); - uni_vmulps(tmp1_vmm, tmp1_vmm, G3); - - // if int8, we quantize the resulting state - if (src_data_t == data_type::u8) - q_d(tmp1_vmm, tmp2_vmm); - - // write back the result - if(vlen_dst == vlen) - uni_vmovups(ptr[addr_states_t_l_reg], tmp1_vmm); - else - // we write only 1/4 of the register - switch(vlen_dst){ - case 16: uni_vmovups(ptr[addr_states_t_l_reg], Xmm(tmp1_vmm.getIdx())); break; - case 8: uni_vmovsd(ptr[addr_states_t_l_reg], Xmm(tmp1_vmm.getIdx())); break; - case 4: uni_vmovss(ptr[addr_states_t_l_reg], Xmm(tmp1_vmm.getIdx())); break; - default: - assert(!"Unsuported vector length for quantization"); - } - - // increment address pointers - add(addr_ws_gates_reg, vlen); - add(addr_bias_reg, vlen); - add(addr_states_t_l_reg, vlen_dst); - add(addr_c_states_tm1_l_reg, vlen); - add(addr_c_states_t_l_reg, vlen); - if (mask != 0) - add(weights_scales_reg, vlen); - - // increment loop counter - sub(loop_cnt, vlen); - cmp(loop_cnt, vlen); - jge(vector_loop_start_label); - } - L(vector_loop_end_label); - - cmp(loop_cnt, 0); - je(rem_loop_end_label, Xbyak::CodeGenerator::T_NEAR); - // Same code as above, we just use movuss for accessing inputs - // TODO: smarter handling of tails with Zmm -> Ymm -> Xmm -> scalar - L(rem_loop_start_label); - { - // remaping registers to Xmms - Xmm G0s(G0.getIdx()), G1s(G1.getIdx()), G2s(G2.getIdx()), G3s(G3.getIdx()); - Xmm tmp1s_vmm(tmp1_vmm.getIdx()); - - // load G0 G1 G2 G3 - uni_vmovss(G0s, ptr[addr_ws_gates_reg + 0 * rnn_.dic 
* gate_dt_size]); - uni_vmovss(G1s, ptr[addr_ws_gates_reg + 1 * rnn_.dic * gate_dt_size]); - uni_vmovss(G2s, ptr[addr_ws_gates_reg + 2 * rnn_.dic * gate_dt_size]); - uni_vmovss(G3s, ptr[addr_ws_gates_reg + 3 * rnn_.dic * gate_dt_size]); - - // dequantize the gates from s32 to f32 if needed - if (src_data_t == data_type::u8){ - deq_w(G0, tmp1_vmm, tmp2_vmm, 0, false); - deq_w(G1, tmp1_vmm, tmp2_vmm, 1, false); - deq_w(G2, tmp1_vmm, tmp2_vmm, 2, false); - deq_w(G3, tmp1_vmm, tmp2_vmm, 3, false); - } - - // add biases - uni_vmovss(tmp1s_vmm, ptr[addr_bias_reg + 0 * rnn_.dic * bias_dt_size]); - uni_vaddps(G0s, G0s, tmp1s_vmm); - uni_vmovss(tmp1s_vmm, ptr[addr_bias_reg + 1 * rnn_.dic * bias_dt_size]); - uni_vaddps(G1s, G1s, tmp1s_vmm); - uni_vmovss(tmp1s_vmm, ptr[addr_bias_reg + 2 * rnn_.dic * bias_dt_size]); - uni_vaddps(G2s, G2s, tmp1s_vmm); - uni_vmovss(tmp1s_vmm, ptr[addr_bias_reg + 3 * rnn_.dic * bias_dt_size]); - uni_vaddps(G3s, G3s, tmp1s_vmm); - - // inject eltwise code - sigmoid_injector_->compute_vector(G0s.getIdx()); - sigmoid_injector_->compute_vector(G1s.getIdx()); - tanh_injector_->compute_vector(G2s.getIdx()); - sigmoid_injector_->compute_vector(G3s.getIdx()); - - // compute c_states_t_l = G1 * c_tm1_l + G0s * G2 - uni_vmovups(tmp1s_vmm, ptr[addr_c_states_tm1_l_reg]); - uni_vmulps(tmp1s_vmm, tmp1s_vmm, G1s); - uni_vfmadd231ps(tmp1s_vmm, G0s, G2s); - uni_vmovss(ptr[addr_c_states_t_l_reg], tmp1s_vmm); - - // states_t_l = G3 * tanh(c_states_t_l) - tanh_injector_->compute_vector(tmp1s_vmm.getIdx()); - uni_vmulps(tmp1s_vmm, tmp1s_vmm, G3s); - - // if int8, we quantize the resulting state - if (src_data_t == data_type::u8) - q_d(tmp1_vmm, tmp2_vmm); - - // write back the result - if(vlen_dst == vlen) - uni_vmovups(ptr[addr_states_t_l_reg], tmp1s_vmm); - else - // we write only 1/4 of the register - switch(vlen_dst){ - case 16: uni_vmovups(ptr[addr_states_t_l_reg], Xmm(tmp1s_vmm.getIdx())); break; - case 8: uni_vmovsd(ptr[addr_states_t_l_reg], Xmm(tmp1s_vmm.getIdx())); break; - case 4: uni_vmovss(ptr[addr_states_t_l_reg], Xmm(tmp1s_vmm.getIdx())); break; - default: - assert(!"Unsuported vector length for quantization"); - } - - // increment address pointers - add(addr_ws_gates_reg, gate_dt_size); - add(addr_bias_reg, bias_dt_size); - add(addr_states_t_l_reg, hstate_dt_size); - add(addr_c_states_tm1_l_reg, cstate_dt_size); - add(addr_c_states_t_l_reg, cstate_dt_size); - if (mask != 0) - add(weights_scales_reg, qscale_dt_size); - - // increment loop counter - sub(loop_cnt, gate_dt_size); - cmp(loop_cnt, 0); - jg(rem_loop_start_label); - - } - L(rem_loop_end_label); - - postamble(); - - // Again, only one table is needed and shared between sigmoid and tanh - sigmoid_injector_->prepare_table(false); - tanh_injector_->prepare_table(true); - - L(table_label); - { - for (size_t i = 0; i < vlen / sizeof(float); i++) dd(float2int(data_scale)); - for (size_t i = 0; i < vlen / sizeof(float); i++) dd(float2int(data_shift)); - // perm mask for ymm - dd(0); dd(4); dd(2); dd(3); dd(1); dd(5); dd(6); dd(7); - // perm mask for zmm - dd(0); dd(4); dd(8); dd(12); dd(1); dd(5); dd(6); dd(7); - dd(2); dd(9); dd(10); dd(11); dd(3); dd(12); dd(13); dd(14); - } - } - -}; - -template struct jit_uni_lstm_postgemm_kernel_fwd; -template struct jit_uni_lstm_postgemm_kernel_fwd; -template struct jit_uni_lstm_postgemm_kernel_fwd; - -template struct jit_uni_lstm_postgemm_kernel_fwd; -template struct jit_uni_lstm_postgemm_kernel_fwd; -template struct jit_uni_lstm_postgemm_kernel_fwd; -} -} -} diff --git 
a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.cpp deleted file mode 100644 index ead536816..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.cpp +++ /dev/null @@ -1,788 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -/* - General architecture - - for diff states, we have n_states + 1 as we have n_states diff - to propagate to the previous iteration and 1 states to propagate - to the previous layer - index 0 is dh for cell(t-1, l) to consume - index 1 is dc for cell(t-1, l) to consume - index 2 is dh for cell(t, l-1) to consume - this indexing enables to have the same indexing for states in elemwise - function - only the cell execution function should be impacted - - */ - -#include "math_utils.hpp" -#include "mkldnn_thread.hpp" - -#include "ref_rnn.hpp" -#include "../gemm/gemm.hpp" -#include "../simple_q10n.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace mkldnn::impl::utils; -using namespace mkldnn::impl::memory_tracking::names; -using namespace rnn_utils; -#define AOC array_offset_calculator - -template -void _ref_rnn_common_t::gates_reduction( - const rnn_conf_t &rnn, const acc_data_t *ws_gates_, - float *diff_bias_) const { - auto body = [&](int i, int k) { - for (int j = 0; j < rnn.mb; j++) - diff_bias_[i * rnn.dic + k] - += ws_gates_[j * rnn.gates_ws_ld + i * rnn.dic + k]; - }; - - // @todo block k on simd-width -#if MKLDNN_THR == MKLDNN_THR_OMP && _OPENMP >= 201307 \ - /* icc 17.0 has a problem with simd collapse */ \ - && !((defined __INTEL_COMPILER) && (__INTEL_COMPILER == 1700)) -#pragma omp parallel for simd collapse(2) - for (int i = 0; i < rnn.n_gates; i++) - for (int k = 0; k < rnn.dic; k++) - body(i, k); -#else - parallel_nd(rnn.n_gates, rnn.dic, body); -#endif -} - -template -rnn_gemm_sig((_ref_rnn_common_t::gemm)) { - assert(ldA * ldB * ldC != 0); - extended_sgemm(&transA, &transB, &m, &n, &k, &alpha, a_, &ldA, b_, &ldB, - &beta, c_, &ldC, nullptr, pd()->rnn_.use_jit_gemm); -} - -template <> -rnn_gemm_sig((ref_rnn_fwd_u8s8_t::gemm)) { - assert(!"non packed gemm is disabled for int8"); -} - -template -rnn_gemm_sig((_ref_rnn_common_t::packed_gemm)) { -#if (USE_MKL_PACKED_GEMM) - assert(transA == 'N'); - cblas_sgemm_compute(CblasColMajor, CblasPacked, - (transB == 'T') ? 
CblasTrans : CblasNoTrans, m, n, k, a_, ldA, b_, - ldB, beta, c_, ldC); -#else - UNUSED(transA); - UNUSED(transB); - UNUSED(m); - UNUSED(n); - UNUSED(k); - UNUSED(alpha); - UNUSED(ldA); - UNUSED(b_); - UNUSED(ldB); - UNUSED(beta); - UNUSED(c_); - UNUSED(ldC); - assert(!"packed gemm is disabled"); -#endif -} - -template <> -rnn_gemm_sig((ref_rnn_fwd_u8s8_t::packed_gemm)) { -#if (USE_MKL_PACKED_GEMM) - int8_t offseta = 0, offsetb = 0; - int32_t offsetc = 0; - cblas_gemm_s8u8s32_compute(CblasColMajor, (CBLAS_TRANSPOSE)CblasPacked, - CblasNoTrans, CblasFixOffset, m, n, k, alpha, a_, ldA, offseta, b_, - ldB, offsetb, beta, c_, ldC, &offsetc); -#else - UNUSED(transA); - UNUSED(transB); - UNUSED(m); - UNUSED(n); - UNUSED(k); - UNUSED(alpha); - UNUSED(ldA); - UNUSED(b_); - UNUSED(ldB); - UNUSED(beta); - UNUSED(c_); - UNUSED(ldC); - assert(!"packed gemm is disabled"); -#endif -} - -//*************** Grid computations strategy: linear ***************// -template -rnn_grid_execution_sig( - (_ref_rnn_common_t::linear_execution)) { - AOC ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir, - rnn.n_iter + 1, rnn.states_nld * rnn.states_ws_ld); - AOC ws_c_states(ws_c_states_, rnn.n_layer + 1, rnn.n_dir, - rnn.n_iter + 1, rnn.states_nld * rnn.states_ws_ld); - AOC ws_diff_states(ws_diff_states_, rnn.n_layer + 1, rnn.n_dir, - (rnn.n_states + 1), rnn.n_iter + 1, - rnn.states_nld * rnn.states_ws_ld); - AOC ws_gates(ws_gates_, rnn.n_layer, rnn.n_dir, rnn.n_iter, - rnn.gates_nld * rnn.gates_ws_ld); - AOC weights_input( - weights_layer_, rnn.n_layer, rnn.n_dir, rnn.n_parts_weights_layer); - AOC weights_states( - weights_states_, rnn.n_layer, rnn.n_dir, rnn.n_parts_weights_iter); - AOC bias( - bias_, rnn.n_layer, rnn.n_dir, rnn.n_parts_bias); - AOC diff_weights_layer(diff_weights_layer_, rnn.n_layer, - rnn.n_dir, - rnn.diff_weights_layer_nld * rnn.diff_weights_layer_ld); - AOC diff_weights_iter(diff_weights_iter_, rnn.n_layer, rnn.n_dir, - rnn.diff_weights_iter_nld * rnn.diff_weights_iter_ld); - AOC diff_bias( - diff_bias_, rnn.n_layer, rnn.n_dir, rnn.n_bias * rnn.dic); - AOC ws_grid( - ws_grid_, rnn.n_layer, rnn.n_dir, rnn.n_iter, (int)rnn.ws_per_cell); - - // We run the grid of computation - for (int dir = 0; dir < rnn.n_dir; dir++) { - for (int j = 0; j < rnn.n_layer; j++) { - int lay = (aprop == prop_kind::forward) ? j : rnn.n_layer - j - 1; - - if ((aprop == prop_kind::forward) && rnn.merge_gemm_layer) { - (this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dic, - rnn.mb * rnn.n_iter, rnn.slc, 1.0, - weights_input(lay, dir, 0), rnn.weights_iter_ld, - &(ws_states(lay, dir, 1, 0)), rnn.states_ws_ld, 0.0, - &(ws_gates(lay, dir, 0, 0)), rnn.gates_ws_ld); - } - - for (int i = 0; i < rnn.n_iter; i++) { - int iter = (aprop == prop_kind::forward) ? 
i : rnn.n_iter - i - 1; - (this->*cell_func)(rnn, - &(ws_states(lay + 1, dir, iter + 1, 0)), - &(ws_c_states(lay + 1, dir, iter + 1, 0)), - &(ws_diff_states(lay, dir, 0, iter, 0)), - &(weights_input(lay, dir, 0)), - &(weights_states(lay, dir, 0)), - &(bias(lay, dir, 0)), - &(ws_states(lay, dir, iter + 1, 0)), - &(ws_states(lay + 1, dir, iter, 0)), - &(ws_c_states(lay + 1, dir, iter, 0)), - &(ws_diff_states(lay + 1, dir, 0, iter, 0)), - &(ws_diff_states(lay, dir, 0, iter + 1, 0)), - &(diff_weights_layer(lay, dir, 0)), - &(diff_weights_iter(lay, dir, 0)), - &(diff_bias(lay, dir, 0)), - &(ws_gates(lay, dir, iter, 0)), - &(ws_grid(lay, dir, iter, 0)), - ws_cell_); - } - - if ((aprop == prop_kind::backward) && rnn.merge_gemm_layer) { - (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb * rnn.n_iter, - rnn.n_gates * rnn.dic, 1.0, weights_input(lay, dir, 0), - rnn.weights_layer_ld, - (src_data_t *)(&(ws_gates(lay, dir, 0, 0))), - rnn.gates_ws_ld, 0.0, - (acc_data_t *)(&(ws_diff_states( - lay, dir, rnn.n_states, 0, 0))), - rnn.states_ws_ld); - gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.slc, - rnn.mb * rnn.n_iter, 1.0, - (weights_data_t *)(&(ws_gates(lay, dir, 0, 0))), - rnn.gates_ws_ld, - (src_data_t *)(&(ws_states(lay, dir, 1, 0))), - rnn.states_ws_ld, 1.0, - (acc_data_t *)(&(diff_weights_layer(lay, dir, 0))), - rnn.diff_weights_layer_ld); - } - if ((aprop == prop_kind::backward) && rnn.merge_gemm_iter) { - gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.sic, - rnn.mb * rnn.n_iter, 1.0, - (weights_data_t *)(&(ws_gates(lay, dir, 0, 0))), - rnn.gates_ws_ld, - (src_data_t *)(&(ws_states(lay + 1, dir, 0, 0))), - rnn.states_ws_ld, 1.0, - (acc_data_t *)(&(diff_weights_iter(lay, dir, 0))), - rnn.diff_weights_iter_ld); - } - } - } -} - -//********* GRID computations strategy: utility functions **********// - -template -void _ref_rnn_common_t::copy_init_layer( - const rnn_conf_t &rnn, src_data_t *__restrict ws_states_, - float *__restrict ws_diff_states_, const src_data_t *__restrict xt_, - const float *__restrict diff_dst_layer_) const { - - AOC ws_states( - ws_states_, rnn.n_dir, rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld); - auto xt_d = memory_desc_wrapper(pd()->src_md(0)); - - parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) { - auto xxt = xt_ + xt_d.blk_off(it, b); - src_data_t *ws_l2r_ptr = &(ws_states(0, it + 1, b, 0)); - src_data_t *ws_r2l_ptr = &(ws_states(rnn.n_dir - 1, rnn.n_iter - it, b, 0)); - if (rnn.exec_dir != r2l) - for (int c = 0; c < rnn.slc; c++) - ws_l2r_ptr[c] = xxt[c]; - if (rnn.exec_dir != l2r) - for (int c = 0; c < rnn.slc; c++) - ws_r2l_ptr[c] = xxt[c]; - }); -} - -template <> -void ref_rnn_bwd_f32_t::copy_init_layer(const rnn_conf_t &rnn, - src_data_t *ws_states_, float *ws_diff_states_, const src_data_t *xt_, - const float *diff_dst_layer_) const { - AOC ws_diff_states(ws_diff_states_, rnn.n_layer + 1, rnn.n_dir, - (rnn.n_states + 1), rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld); - auto diff_dst_layer_d = memory_desc_wrapper(pd()->diff_dst_md(0)); - - switch (rnn.exec_dir) { - case bi_concat: - parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) { - auto diff_dst_layer_x - = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b); - for (int s = 0; s < rnn.dic; s++) { - ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s) - = diff_dst_layer_x[s]; - ws_diff_states( - rnn.n_layer, 1, rnn.n_states, rnn.n_iter - it - 1, b, s) - = diff_dst_layer_x[rnn.dic + s]; - } - }); - break; - case bi_sum: - parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) { - auto diff_dst_layer_x - = diff_dst_layer_ + 
diff_dst_layer_d.blk_off(it, b); - for (int s = 0; s < rnn.dic; s++) { - ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s) - = diff_dst_layer_x[s]; - ws_diff_states( - rnn.n_layer, 1, rnn.n_states, rnn.n_iter - it - 1, b, s) - = diff_dst_layer_x[s]; - } - }); - break; - case l2r: - parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) { - auto diff_dst_layer_x - = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b); - for (int s = 0; s < rnn.dic; s++) { - ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s) - = diff_dst_layer_x[s]; - } - }); - break; - case r2l: - parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) { - auto diff_dst_layer_x = diff_dst_layer_ - + diff_dst_layer_d.blk_off(rnn.n_iter - it - 1, b); - for (int s = 0; s < rnn.dic; s++) { - ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s) - = diff_dst_layer_x[s]; - } - }); - break; - default: assert(!"Unsupported direction"); break; - } -} - -/* For int8 configuration, input iteration states may be of types f32 or u8 - * Internally h_state is always stored in u8 and c_state is always stored in f32 - * If input states are of type u8 then h state is copied and c state is dequantized - * If input states are of type f32 then h state is quantized and c_state is copied - * */ -template -template -void _ref_rnn_common_t::copy_init_iter( - const rnn_conf_t &rnn, src_data_t *__restrict ws_states_, - float *__restrict ws_c_states_, float *__restrict ws_diff_states_, - const input_data_t *__restrict firstit_states_, - const float *__restrict diff_dst_iter_) const { - AOC ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir, - rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld); - AOC ws_c_states(ws_c_states_, rnn.n_layer + 1, rnn.n_dir, - rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld); - float data_shift = pd()->attr()->rnn_data_qparams_.shift_; - float data_scale = pd()->attr()->rnn_data_qparams_.scale_; - - const bool quantize = pd()->with_src_iter() - && pd()->src_md(1)->data_type == data_type::f32 - && rnn.dt_conf != all_f32; - auto maybe_q = [&](input_data_t f) { - if (quantize) { - float qf = f * data_scale + data_shift; - return qz_a1b0()(qf); - } else - return (src_data_t)f; - }; - - const bool dequantize = pd()->with_src_iter() - && pd()->src_md(1)->data_type == data_type::u8; - auto maybe_deq = [&](input_data_t s) { - if (dequantize) - return (((float)s - data_shift) / data_scale); - else - return (float)s; - }; - auto firstit_states_d = memory_desc_wrapper(pd()->src_md(1)); - if (firstit_states_) { - parallel_nd( - rnn.n_layer, rnn.n_dir, rnn.mb, [&](int lay, int dir, int b) { - for (int s = 0; s < rnn.sic; s++) - ws_states(lay + 1, dir, 0, b, s) = maybe_q( - firstit_states_[firstit_states_d.blk_off( - lay, dir, 0, b, s)]); - if (pd()->cell_kind() == alg_kind::vanilla_lstm) - for (int s = 0; s < rnn.sic; s++) - ws_c_states(lay + 1, dir, 0, b, s) = maybe_deq( - firstit_states_[firstit_states_d.blk_off( - lay, dir, 1, b, s)]); - }); - } else { - parallel_nd( - rnn.n_layer, rnn.n_dir, rnn.mb, [&](int lay, int dir, int b) { - for (int j = 0; j < rnn.sic; j++) { - ws_states(lay + 1, dir, 0, b, j) = (src_data_t)0; - ws_c_states(lay + 1, dir, 0, b, j) = 0.0f; - } - }); - } -} - -template <> -template -void ref_rnn_bwd_f32_t::copy_init_iter(const rnn_conf_t &rnn, - src_data_t *ws_states_, float *ws_c_states_, float *ws_diff_states_, - const input_data_t *firstit_states_, - const float *diff_dst_iter_) const { - AOC ws_diff_states(ws_diff_states_, rnn.n_layer + 1, rnn.n_dir, - rnn.n_states + 1, rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld); - auto 
diff_dst_iter_d = memory_desc_wrapper(pd()->diff_dst_md(1)); - if (diff_dst_iter_) { - parallel_nd(rnn.n_layer, rnn.n_dir, rnn.n_states, rnn.mb, - [&](int lay, int dir, int state, int b) { - array_copy(&(ws_diff_states( - lay, dir, state, rnn.n_iter, b, 0)), - diff_dst_iter_ - + diff_dst_iter_d.blk_off( - lay, dir, state, b), - rnn.dic); - }); - } else { - parallel_nd(rnn.n_layer, rnn.n_dir, rnn.n_states, rnn.mb, - [&](int lay, int dir, int state, int i) { - for (int j = 0; j < rnn.dic; j++) - ws_diff_states(lay, dir, state, rnn.n_iter, i, j) - = 0.0f; - }); - } -} - -template -template -void _ref_rnn_common_t::copy_res_layer( - const rnn_conf_t &rnn, dst_data_t *dst_layer_, float *diff_src_layer, - const src_data_t *ws_states_, const float *ws_diff_states_) const { - - auto dst_layer_d = memory_desc_wrapper(pd()->dst_md(0)); - AOC ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir, - rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld); - float shift = (pd()->attr()->rnn_data_qparams_.shift_); - float scale = (pd()->attr()->rnn_data_qparams_.scale_); - - const bool dequantize = pd()->dst_md(0)->data_type == data_type::f32 - && rnn.dt_conf != all_f32; - auto maybe_deq = [&](src_data_t s) { - if (dequantize) - return (dst_data_t)(((float)s - shift) / scale); - else - return (dst_data_t)s; - }; - parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) { - int dir = 0; - if (rnn.exec_dir != r2l) { - for (int s = 0; s < rnn.dic; s++) { - dst_layer_[dst_layer_d.blk_off(it, b, dir * rnn.dic + s)] - = maybe_deq(ws_states(rnn.n_layer, dir, it + 1, b, s)); - } - dir = 1; - } - if (rnn.exec_dir != l2r) { - for (int s = 0; s < rnn.dic; s++) - switch (rnn.exec_dir) { - case bi_sum: - dst_layer_[dst_layer_d.blk_off(it, b, s)] - += maybe_deq(ws_states( - rnn.n_layer, dir, rnn.n_iter - it, b, s)); - break; - default: - dst_layer_[dst_layer_d.blk_off(it, b, dir * rnn.dic + s)] - = maybe_deq(ws_states( - rnn.n_layer, dir, rnn.n_iter - it, b, s)); - } - } - }); -} - -template <> -template -void ref_rnn_bwd_f32_t::copy_res_layer( - const rnn_conf_t &rnn, dst_data_t *dst_layer_, float *diff_src_layer_, - const src_data_t *ws_states_, const float *ws_diff_states_) const { - auto diff_src_layer_d = memory_desc_wrapper(pd()->diff_src_md(0)); - AOC ws_diff_states(ws_diff_states_, rnn.n_layer + 1, - rnn.n_dir, rnn.n_states + 1, rnn.n_iter + 1, rnn.mb, - rnn.states_ws_ld); - - parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) { - int dir = 0; - for (int s = 0; s < rnn.slc; s++) { - float *dst_addr = diff_src_layer_ - + diff_src_layer_d.blk_off( - (rnn.exec_dir == r2l) ? 
rnn.n_iter - 1 - it : it, - b, dir * rnn.slc + s); - float res = ws_diff_states(0, 0, rnn.n_states, it, b, s); - if (rnn.n_dir - 1) - res += ws_diff_states( - 0, 1, rnn.n_states, rnn.n_iter - 1 - it, b, s); - dst_addr[0] = res; - } - }); -} - -template -template -void _ref_rnn_common_t::copy_res_iter( - const rnn_conf_t &rnn, output_data_t *dst_iter_, float *diff_src_iter_, - const src_data_t *ws_states_, float *ws_c_states_, - const float *ws_diff_states_) const { - auto dst_iter_d = memory_desc_wrapper(pd()->dst_md(1)); - AOC ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir, - rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld); - AOC ws_c_states(ws_c_states_, rnn.n_layer + 1, rnn.n_dir, - rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld); - float data_shift = pd()->attr()->rnn_data_qparams_.shift_; - float data_scale = pd()->attr()->rnn_data_qparams_.scale_; - - const bool quantize = pd()->with_dst_iter() - && pd()->dst_md(1)->data_type == data_type::u8 - && rnn.dt_conf != all_f32; - auto maybe_q = [&](float f) { - if (quantize) { - float qf = f * data_scale + data_shift; - return qz_a1b0()(qf); - } else - return (output_data_t)f; - }; - - const bool dequantize = pd()->with_dst_iter() - && pd()->dst_md(1)->data_type == data_type::f32 - && rnn.dt_conf != all_f32; - auto maybe_deq = [&](src_data_t s) { - if (dequantize) - return (output_data_t)(((float)s - data_shift) / data_scale); - else - return (output_data_t)s; - }; - if (dst_iter_) { - parallel_nd(rnn.n_layer, rnn.n_dir, rnn.mb, - [&](int lay, int dir, int b) { - for (int s = 0; s < rnn.dic; s++) { - dst_iter_[dst_iter_d.blk_off(lay, dir, 0, b, s)] - = maybe_deq(ws_states(lay + 1, dir, rnn.n_iter, b, s)); - } - if (pd()->cell_kind() == alg_kind::vanilla_lstm) - for (int s = 0; s < rnn.dic; s++) { - dst_iter_[dst_iter_d.blk_off(lay, dir, 1, b, s)] - = maybe_q(ws_c_states( - lay + 1, dir, rnn.n_iter, b, s)); - } - }); - } -} - -template <> -template -void ref_rnn_bwd_f32_t::copy_res_iter( - const rnn_conf_t &rnn, output_data_t *dst_iter_, float *diff_src_iter_, - const src_data_t *ws_states_, float *ws_c_states_, - const float *ws_diff_states_) const { - auto diff_src_iter_d = memory_desc_wrapper(pd()->diff_src_md(1)); - AOC ws_diff_states(ws_diff_states_, rnn.n_layer + 1, - rnn.n_dir, rnn.n_states + 1, rnn.n_iter + 1, rnn.mb, - rnn.states_ws_ld); - if (diff_src_iter_) { - parallel_nd(rnn.n_layer, rnn.n_dir, rnn.n_states, rnn.mb, - [&](int lay, int dir, int state, int b) { - for (int s = 0; s < rnn.sic; s++) { - diff_src_iter_[diff_src_iter_d.blk_off( - lay, dir, state, b, s)] - = ws_diff_states(lay, dir, state, 0, b, s); - } - }); - } -} - -template -rnn_bias_prepare_sig((_ref_rnn_common_t::bias_prepare)) { - /* Original set of bias provided by the user */ - AOC b( - b_, rnn.n_layer, rnn.n_dir, rnn.n_bias * rnn.dic); - /* Array of pointers initialized in packing */ - AOC bias(bias_, rnn.n_layer, rnn.n_dir, rnn.n_parts_bias); - AOC scratch_bias( - scratch_bias_, rnn.n_layer, rnn.n_dir, rnn.n_bias * rnn.dic); - - if (rnn.copy_bias) { - parallel_nd(rnn.n_layer * rnn.n_dir * rnn.n_bias * rnn.dic, - [&](size_t i) { scratch_bias_[i] = b_[i]; }); - } - - for (int i = 0; i < rnn.n_layer; i++) { - for (int d = 0; d < rnn.n_dir; d++) { - int offset_bias = 0; - for (int p = 0; p < rnn.n_parts_bias; p++) { - bias(i, d, p) = rnn.copy_bias - ? 
(float *) &scratch_bias(i, d, offset_bias) - : (float *) &b(i, d, offset_bias); - offset_bias += rnn.parts_bias[p] * rnn.dic; - } - } - } - -} - -template -rnn_bias_finalize_sig( - (_ref_rnn_common_t::bias_finalize)) { - if (rnn.dt_conf != all_f32) { - float data_shift = pd()->attr()->rnn_data_qparams_.shift_; - float data_scale = pd()->attr()->rnn_data_qparams_.scale_; - float *weights_scales = pd()->attr()->rnn_weights_qparams_.scales_; - bool scale_per_oc = pd()->attr()->rnn_weights_qparams_.mask_ != 0; - for (int i = 0; i < rnn.n_layer * rnn.n_dir; i++) - for (int j = 0; j < rnn.n_bias * rnn.dic; j++) { - size_t off = i * rnn.n_bias * rnn.dic + j; - float weights_scale - = scale_per_oc ? weights_scales[j] : weights_scales[0]; - scratch_bias_[off] -= (w_iter_comp[off] + w_layer_comp[off]) - * data_shift / (weights_scale * data_scale); - } - } -} - -template -rnn_weights_assign_sig((_ref_rnn_common_t::assign_packed_weights)) { - assert(md->format_kind == format_kind::rnn_packed); - const auto packed_desc = md->format_desc.rnn_packed_desc; - AOC weights(weights_, - rnn.n_layer, rnn.n_dir, packed_desc.n_parts); - - size_t offset_packed = 0; - for (int l = 0; l < rnn.n_layer; l++) - for (int d = 0; d < rnn.n_dir; d++) { - for (int p = 0; p < packed_desc.n_parts; p++) { - weights(l, d, p) = (weights_data_t *)&w_[offset_packed]; - offset_packed - += packed_desc.part_pack_size[p] / sizeof(weights_data_t); - } - } -} - -template -rnn_weights_assign_sig( - (_ref_rnn_common_t::assign_weights)) { - assert(md->format_kind == format_kind::blocked); - const auto &blk = md->format_desc.blocking; - /* Original set of weights provided by the user */ - AOC w(w_, - rnn.n_layer, rnn.n_dir, (int)blk.strides[1]); - /* Array of pointers for each part of weights */ - AOC weights(weights_, rnn.n_layer, rnn.n_dir, n_parts); - - for (int i = 0; i < rnn.n_layer; i++) - for (int d = 0; d < rnn.n_dir; d++) { - size_t offset_weights = 0; - for (int p = 0; p < n_parts; p++) { - weights(i, d, p) = (weights_data_t *)&w(i, d, offset_weights); - offset_weights += gates_per_part[p] * blk.strides[3]; - } - } -} - -//********************* Execution function *********************// -template -void _ref_rnn_common_t::execute_( - const exec_ctx_t &ctx) const { - const rnn_conf_t &rnn = this->pd()->rnn_; - auto input = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC_LAYER); - auto states = CTX_IN_MEM(const char *, MKLDNN_ARG_SRC_ITER); - auto layer_weights_n_comp = CTX_IN_MEM(const char *, MKLDNN_ARG_WEIGHTS_LAYER); - auto iter_weights_n_comp = CTX_IN_MEM(const char *, MKLDNN_ARG_WEIGHTS_ITER); - auto bias = CTX_IN_MEM(const float *, MKLDNN_ARG_BIAS); - - auto dst_last_layer = rnn.is_fwd - ? CTX_OUT_MEM(char *, MKLDNN_ARG_DST_LAYER) - : const_cast(CTX_IN_MEM(const char *, MKLDNN_ARG_DST_LAYER)); - auto dst_last_iter = rnn.is_fwd - ? 
CTX_OUT_MEM(char *, MKLDNN_ARG_DST_ITER) - : const_cast(CTX_IN_MEM(const char *, MKLDNN_ARG_DST_ITER)); - - auto diff_dst_layer = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST_LAYER); - auto diff_dst_iter = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST_ITER); - - auto w_layer = reinterpret_cast(layer_weights_n_comp); - auto w_iter = reinterpret_cast(iter_weights_n_comp); - auto w_iter_comp = reinterpret_cast( - iter_weights_n_comp + rnn.weights_iter_comp_offset); - auto w_layer_comp = reinterpret_cast( - layer_weights_n_comp + rnn.weights_layer_comp_offset); - - auto scratchpad = this->scratchpad(ctx); - - auto ptr_wei_layer - = scratchpad.template get(key_rnn_ptrs_wei_layer); - auto ptr_wei_iter - = scratchpad.template get(key_rnn_ptrs_wei_iter); - auto ptr_bias = - scratchpad.template get(key_rnn_ptrs_bia); - - // fetchihg buffers from the workspace - // if no workspace was provided we use the scratchpad - char *scratch_ptr = scratchpad.template get(key_rnn_space); - char *ws_ptr = nullptr; - if (rnn.use_workspace) - ws_ptr = rnn.is_fwd - ? CTX_OUT_MEM(char *, MKLDNN_ARG_WORKSPACE) - : const_cast(CTX_IN_MEM(const char *, MKLDNN_ARG_WORKSPACE)); - - char *base_ptr = rnn.use_workspace ? ws_ptr : scratch_ptr; - acc_data_t *ws_gates = (acc_data_t *)(base_ptr + ws_gates_offset_); - src_data_t *ws_states = (src_data_t *)(base_ptr + ws_states_offset_); - float *ws_c_states = (float *)(base_ptr + ws_c_states_offset_); - float *ws_diff_states = (float *)(base_ptr + ws_diff_states_offset_); - float *ws_grid = (float *)(base_ptr + ws_grid_comp_offset_); - float *ws_cell = (float *)(base_ptr + ws_cell_comp_offset_); - - auto diff_src_layer = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_SRC_LAYER); - auto diff_src_iter = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_SRC_ITER); - - auto diff_weights_layer = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_WEIGHTS_LAYER); - auto diff_weights_iter = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_WEIGHTS_ITER); - auto diff_bias = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_BIAS); - - // Fetching extra buffers from scratchpad - float *ws_bias = (float *)(scratch_ptr + ws_bias_offset_); - - // initialize diff_states to 0 - if (aprop == prop_kind::backward) - array_set(ws_diff_states, 0.0f, rnn.ws_diff_states_size / sizeof(float)); - - /* Pack(if using packed gemm API) or copy(if input arrays have bad leading - * dimension */ - (this->*bias_preparation_func)(rnn, ptr_bias, bias, ws_bias); - - (this->*weights_iter_assign_func)(rnn, pd()->weights_md(1), - rnn.weights_iter_nld, rnn.weights_iter_ld, rnn.dic, - rnn.sic, rnn.n_parts_weights_iter, rnn.parts_weights_iter, - rnn.part_weights_iter_pack_size, ptr_wei_iter, w_iter, - ptr_bias, bias, ws_bias); - (this->*weights_layer_assign_func)(rnn, pd()->weights_md(0), - rnn.weights_layer_nld, rnn.weights_layer_ld, rnn.dic, rnn.slc, - rnn.n_parts_weights_layer, rnn.parts_weights_layer, - rnn.part_weights_layer_pack_size, ptr_wei_layer, w_layer, ptr_bias, - bias, ws_bias); - - (this->*bias_finalization_func)(rnn, ws_bias, w_iter_comp, w_layer_comp); - - // we first need to copy the initial states and input into ws - copy_init_layer(rnn, ws_states, ws_diff_states, input, diff_dst_layer); - if (rnn.dt_conf == f32u8f32u8 || rnn.dt_conf == f32u8f32f32 - || rnn.dt_conf == all_f32) - copy_init_iter(rnn, ws_states, ws_c_states, ws_diff_states, - (const float *)states, diff_dst_iter); - else if (rnn.dt_conf == u8u8u8u8 || rnn.dt_conf == u8u8u8f32) - copy_init_iter(rnn, ws_states, ws_c_states, ws_diff_states, - (const uint8_t *)states, diff_dst_iter); - else - 
assert(!"unimplemented"); - - // run the execution on the grid - (this->*grid_computation)(rnn, ptr_wei_layer, ptr_wei_iter, ptr_bias, - ws_states, ws_c_states, ws_diff_states, ws_gates, ws_cell, ws_grid, - diff_weights_layer, diff_weights_iter, diff_bias); - - // Finally we copy the results to the result buffers - if (rnn.dt_conf == u8u8u8f32 || rnn.dt_conf == f32u8f32f32 - || rnn.dt_conf == all_f32) - copy_res_layer(rnn, (float *)dst_last_layer, diff_src_layer, ws_states, - ws_diff_states); - else if (rnn.dt_conf == u8u8u8u8 || rnn.dt_conf == f32u8f32u8) - copy_res_layer(rnn, (uint8_t *)dst_last_layer, diff_src_layer, - ws_states, ws_diff_states); - else - assert(!"unimplemented"); - - if (rnn.dt_conf == f32u8f32u8 || rnn.dt_conf == f32u8f32f32 - || rnn.dt_conf == all_f32) - copy_res_iter(rnn, (float *)dst_last_iter, diff_src_iter, ws_states, - ws_c_states, ws_diff_states); - else if (rnn.dt_conf == u8u8u8u8 || rnn.dt_conf == u8u8u8f32) - copy_res_iter(rnn, (uint8_t *)dst_last_iter, diff_src_iter, ws_states, - ws_c_states, ws_diff_states); - else - assert(!"unimplemented"); -}; - -/* Fix for MSVS warning C4661 */ -template<> rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution); -template<> rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution); -template<> rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution); -template<> rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru); -template<> rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru); -template<> rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru); -template<> rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru_lbr); -template<> rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru_lbr); -template<> rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru_lbr); -template<> rnn_elemwise_sig(ref_rnn_fwd_f32_t::rnn_elemwise); -template<> rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::rnn_elemwise); -template<> rnn_elemwise_sig(ref_rnn_bwd_f32_t::rnn_elemwise); -template<> rnn_elemwise_sig(ref_rnn_fwd_f32_t::lstm_elemwise); -template<> rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::lstm_elemwise); -template<> rnn_elemwise_sig(ref_rnn_bwd_f32_t::lstm_elemwise); -template<> rnn_elemwise_sig(ref_rnn_fwd_f32_t::gru_lbr_elemwise); -template<> rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::gru_lbr_elemwise); -template<> rnn_elemwise_sig(ref_rnn_bwd_f32_t::gru_lbr_elemwise); - -template struct _ref_rnn_common_t; -template struct _ref_rnn_common_t; -template struct _ref_rnn_common_t; - -#undef AOC -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.hpp deleted file mode 100644 index 6f449a901..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.hpp +++ /dev/null @@ -1,328 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#ifndef CPU_REF_RNN_HPP -#define CPU_REF_RNN_HPP - -#include - -#include "c_types_map.hpp" -#include "memory_tracking.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -#include "../cpu_isa_traits.hpp" -#include "../gemm/os_blas.hpp" - -#include "cpu_rnn_pd.hpp" -#include "../cpu_primitive.hpp" -#include "rnn_utils.hpp" -#include "jit_uni_rnn_postgemm.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -template -float activation(float s, float alpha, float cliping, float dd); - -template -struct _ref_rnn_common_t : public cpu_primitive_t { - typedef typename prec_traits::type src_data_t; - typedef typename prec_traits::type weights_data_t; - typedef typename utils::conditional::type acc_data_t; - - using class_name = _ref_rnn_common_t; - - typedef rnn_elemwise_sig((class_name::*elemwise_f)); - typedef rnn_cell_execution_sig((class_name::*cell_execution_f)); - typedef rnn_grid_execution_sig((class_name::*grid_execution_f)); - - typedef rnn_gemm_sig((class_name::*gemm_t)); - typedef rnn_bias_prepare_sig((class_name::*bias_prepare_t)); - typedef rnn_bias_finalize_sig((class_name::*bias_finalize_t)); - typedef rnn_weights_assign_sig((class_name::*weights_assign_t)); - - using base_pd_t = - typename utils::conditional::type; - - struct pd_t : public base_pd_t { - using base_pd_t::base_pd_t; - - DECLARE_COMMON_PD_T("ref:any", class_name); - - status_t init() { - using namespace prop_kind; - using namespace utils; - using namespace format_tag; - using namespace rnn_utils; - const alg_kind_t cell_kind = this->desc()->cell_desc.cell_kind; - - data_type_t src_layer_dt = this->desc()->src_layer_desc.data_type; - data_type_t weights_iter_dt - = this->desc()->weights_iter_desc.data_type; - data_type_t weights_layer_dt - = this->desc()->weights_layer_desc.data_type; - - bool ok = true - && one_of(cell_kind, alg_kind::vanilla_rnn, - alg_kind::vanilla_lstm, alg_kind::vanilla_gru, - alg_kind::gru_linear_before_reset) - && IMPLICATION(aprop == prop_kind::forward, - one_of(this->desc()->prop_kind, forward_training, - forward_inference)) - && IMPLICATION(aprop == backward, - one_of(this->desc()->prop_kind, backward)) - && src_layer_dt == src_type - && everyone_is( - weights_type, weights_iter_dt, weights_layer_dt) - && this->set_default_params() == status::success - && this->with_bias(); - if (!ok) - return status::unimplemented; - - init_conf(rnn_, *this->desc(), this->src_md(0), this->src_md(1), - this->weights_md(0), this->weights_md(1), this->dst_md(0)); - - if (rnn_.dt_conf == all_f32) - ok = ok && this->attr()->has_default_values(); - - // Set weights descriptors to desired format - memory_desc_t new_weights_layer_md = *this->weights_md(0); - CHECK(set_expected_desc(rnn_, new_weights_layer_md, false)); - if (this->weights_layer_md_.format_kind == format_kind::any) { - this->weights_layer_md_ = new_weights_layer_md; - } else if (this->weights_layer_md_.format_kind - == format_kind::rnn_packed) { - if (this->weights_layer_md_ != new_weights_layer_md) - return status::unimplemented; - } - - memory_desc_t new_weights_iter_md = *this->weights_md(1); - CHECK(set_expected_desc(rnn_, new_weights_iter_md, true)); - if (this->weights_iter_md_.format_kind == format_kind::any) { - this->weights_iter_md_ = new_weights_iter_md; - } else if (this->weights_iter_md_.format_kind - == format_kind::rnn_packed) { - if (this->weights_iter_md_ != new_weights_iter_md) - return status::unimplemented; - } - - 
CHECK(this->check_layout_consistency()); - - set_conf(rnn_, *this->desc(), this->weights_md(0), - this->weights_md(1), this->diff_weights_md(0), - this->diff_weights_md(1)); - - size_t scratchpad_sz{0}, ws_sz{0}; - get_scratchpad_and_workspace_sizes(rnn_, scratchpad_sz, ws_sz); - - // initialize the workspace if needed - if (rnn_.is_training) { - dims_t ws_dims = { (int)ws_sz }; - mkldnn_memory_desc_init_by_tag(&this->ws_md_, 1, ws_dims, - data_type::u8, format_tag::x); - } - - init_scratchpad(scratchpad_sz); - - return status::success; - } - - rnn_utils::rnn_conf_t rnn_; - - private: - void init_scratchpad(size_t scratchpad_sz) { - using namespace memory_tracking::names; - auto scratchpad = this->scratchpad_registry().registrar(); - scratchpad.book(key_rnn_space, sizeof(float) * scratchpad_sz, 4096); - - int max_nparts = this->cell_kind() == alg_kind::vanilla_gru ? 2 : 1; - int ptr_wei_sz = rnn_.n_layer * rnn_.n_dir * max_nparts; - scratchpad.book(key_rnn_ptrs_wei_layer, - sizeof(float *) * ptr_wei_sz); - scratchpad.book(key_rnn_ptrs_wei_iter, - sizeof(float *) * ptr_wei_sz); - scratchpad.book(key_rnn_ptrs_bia, - sizeof(float *) * ptr_wei_sz); - } - }; - - _ref_rnn_common_t(const pd_t *apd) - : cpu_primitive_t(apd, true), rnn_postgemm_(nullptr) { - /// @todo set max_feature_size assuming that we limit the number of - /// iterations and layer to one if slc != dic and sic != dic - /// respectively - - bias_preparation_func = &class_name::bias_prepare; - bias_finalization_func = &class_name::bias_finalize; - - auto set_gemm_funcs - = [](bool packed_gemm, gemm_t &g, weights_assign_t &a) { - if (packed_gemm) { - g = &class_name::packed_gemm; - a = &class_name::assign_packed_weights; - } else { - g = &class_name::gemm; - a = &class_name::assign_weights; - } - }; - set_gemm_funcs(pd()->rnn_.use_iter_packed_gemm, gemm_iter_func, - weights_iter_assign_func); - - set_gemm_funcs(pd()->rnn_.use_layer_packed_gemm, gemm_layer_func, - weights_layer_assign_func); - - switch (pd()->cell_kind()) { - case alg_kind::vanilla_lstm: - cell_func = &class_name::cell_execution; - if (aprop == prop_kind::forward) { - if (mayiuse(avx512_core)) - rnn_postgemm_ = new jit_uni_lstm_postgemm_kernel_fwd( - pd()->rnn_, pd()->attr()); - else if (mayiuse(avx2)) - rnn_postgemm_ = new jit_uni_lstm_postgemm_kernel_fwd( - pd()->rnn_, pd()->attr()); - else if (mayiuse(sse42)) - rnn_postgemm_ = new jit_uni_lstm_postgemm_kernel_fwd( - pd()->rnn_, pd()->attr()); - assert(rnn_postgemm_ != nullptr); - rnn_postgemm_->init(); - } - elemwise_func = &class_name::lstm_elemwise; - break; - case alg_kind::vanilla_rnn: // @todo switch on cell kind - cell_func = &class_name::cell_execution; - elemwise_func = &class_name::rnn_elemwise; - switch (pd()->activation_kind()) { - case alg_kind::eltwise_relu: - activation_func = &activation; - break; - case alg_kind::eltwise_tanh: - activation_func = &activation; - break; - case alg_kind::eltwise_logistic: - activation_func = &activation; - break; - default: break; - } - break; - case alg_kind::vanilla_gru: - cell_func = &class_name::cell_execution_gru; - break; - case alg_kind::gru_linear_before_reset: - cell_func = &class_name::cell_execution_gru_lbr; - elemwise_func = &class_name::gru_lbr_elemwise; - break; - default: break; - } - - grid_computation = &class_name::linear_execution; - - size_t scratchpad_size, workspace_size; - rnn_utils::set_offsets(pd()->rnn_, ws_gates_offset_, ws_states_offset_, - ws_c_states_offset_, ws_diff_states_offset_, - ws_grid_comp_offset_, ws_cell_comp_offset_, - 
ws_bias_offset_, scratchpad_size, workspace_size); - } - - ~_ref_rnn_common_t() {} - - // typedef typename prec_traits::type data_t; - - virtual status_t execute(const exec_ctx_t &ctx) const override { - execute_(ctx); - return status::success; - } - -private: - void execute_(const exec_ctx_t &ctx) const; - rnn_grid_execution_sig(linear_execution); - rnn_cell_execution_sig(cell_execution); - rnn_cell_execution_sig(cell_execution_gru); - rnn_cell_execution_sig(cell_execution_gru_lbr); - rnn_elemwise_sig(rnn_elemwise); - rnn_elemwise_sig(lstm_elemwise); - rnn_elemwise_sig(gru_lbr_elemwise); - rnn_gemm_sig(gemm); - rnn_gemm_sig(packed_gemm); - rnn_bias_prepare_sig(bias_prepare); - rnn_bias_finalize_sig(bias_finalize); - rnn_weights_assign_sig(assign_weights); - rnn_weights_assign_sig(assign_packed_weights); - - float (*activation_func)(float dd, float s, float alpha, float cliping); - - void copy_init_layer(const rnn_utils::rnn_conf_t &rnn, - src_data_t *ws_states_, float *ws_diff_states_, - const src_data_t *xt_, const float *diff_dst_layer) const; - - template - void copy_init_iter(const rnn_utils::rnn_conf_t &rnn, - src_data_t *ws_states_, float *ws_c_states, float *ws_diff_states_, - const input_data_t *firstit_states_, - const float *diff_dst_iter) const; - - template - void copy_res_layer(const rnn_utils::rnn_conf_t &rnn, - dst_data_t *dst_layer_, float *diff_src_layer, - const src_data_t *ws_states_, const float *ws_diff_states_) const; - - template - void copy_res_iter(const rnn_utils::rnn_conf_t &rnn, - output_data_t *dst_iter_, float *diff_src_iter, - const src_data_t *ws_states_, float *ws_c_states, - const float *ws_diff_states_) const; - - void gates_reduction(const rnn_utils::rnn_conf_t &rnn, - const acc_data_t *ws_gates_, float *diff_bias_) const; - - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - - size_t ws_gates_offset_; - size_t ws_states_offset_; - size_t ws_c_states_offset_; - size_t ws_bias_offset_; - size_t ws_diff_states_offset_; - size_t ws_grid_comp_offset_; - size_t ws_cell_comp_offset_; - jit_uni_rnn_postgemm_kernel *rnn_postgemm_; - - grid_execution_f grid_computation; - cell_execution_f cell_func; - - bias_prepare_t bias_preparation_func; - bias_finalize_t bias_finalization_func; - weights_assign_t weights_layer_assign_func; - weights_assign_t weights_iter_assign_func; - - gemm_t gemm_layer_func; - gemm_t gemm_iter_func; - elemwise_f elemwise_func; -}; - -using ref_rnn_fwd_f32_t = _ref_rnn_common_t; -using ref_rnn_bwd_f32_t = _ref_rnn_common_t; -using ref_rnn_fwd_u8s8_t = _ref_rnn_common_t; -} -} -} -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_reorders.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_reorders.hpp deleted file mode 100644 index 78cdedbae..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_reorders.hpp +++ /dev/null @@ -1,380 +0,0 @@ -/******************************************************************************* - * Copyright 2018 Intel Corporation - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - *******************************************************************************/ - -#ifndef CPU_RNN_REORDERS_HPP -#define CPU_RNN_REORDERS_HPP - -#include - -#include "type_helpers.hpp" -#include "mkldnn_thread.hpp" -#include "utils.hpp" -#include "simple_q10n.hpp" -#include "cpu_reorder_pd.hpp" -#include "../gemm/os_blas.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -template -struct rnn_data_reorder_t : public cpu_primitive_t { - struct pd_t : public cpu_reorder_pd_t { - using cpu_reorder_pd_t::cpu_reorder_pd_t; - - DECLARE_COMMON_PD_T("rnn_data_reorder", rnn_data_reorder_t); - - static status_t create(reorder_pd_t **reorder_pd, - engine_t *engine, const primitive_attr_t *attr, - engine_t *src_engine, const memory_desc_t *src_md, - engine_t *dst_engine, const memory_desc_t *dst_md) { - const memory_desc_wrapper id(src_md), od(dst_md); - bool args_ok = true - && id.data_type() == type_i - && od.data_type() == type_o - && id.matches_one_of_tag(format_tag::tnc, format_tag::ldsnc) - && od == id; - if (!args_ok) return status::invalid_arguments; - - auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine, - dst_md); - if (_pd == nullptr) return out_of_memory; - if (_pd->init() != success) { delete _pd; return unimplemented; } - return safe_ptr_assign(*reorder_pd, _pd); - } - }; - -private: - typedef typename prec_traits::type in_data_t; - typedef typename prec_traits::type out_data_t; - - rnn_data_reorder_t(const pd_t *apd): cpu_primitive_t(apd) {} - - virtual status_t execute(const exec_ctx_t &ctx) const override { - auto input = CTX_IN_MEM(const in_data_t *, MKLDNN_ARG_FROM); - auto output = CTX_OUT_MEM(out_data_t *, MKLDNN_ARG_TO); - const memory_desc_wrapper &input_d = pd()->src_md(); - const memory_desc_wrapper &output_d = pd()->dst_md(); - const size_t nelems = input_d.nelems(); - const float scale = pd()->attr()->rnn_data_qparams_.scale_; - const float shift = pd()->attr()->rnn_data_qparams_.shift_; - - parallel_nd(nelems, [&](size_t i) { - float in = (float)input[input_d.off_l(i)] * scale + shift; - output[output_d.off_l(i)] = qz_a1b0()(in); - }); - - return status::success; - } - - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } -}; - -template -struct rnn_weights_reorder_t : public cpu_primitive_t { - struct pd_t : public cpu_reorder_pd_t { - using cpu_reorder_pd_t::cpu_reorder_pd_t; - - DECLARE_COMMON_PD_T("rnn_weights_reorder", rnn_weights_reorder_t); - - static status_t create(reorder_pd_t **reorder_pd, - engine_t *engine, const primitive_attr_t *attr, - engine_t *src_engine, const memory_desc_t *src_md, - engine_t *dst_engine, const memory_desc_t *dst_md) { -#if !USE_MKL_PACKED_GEMM - return status::unimplemented; -#endif - const memory_desc_wrapper id(src_md), od(dst_md); - bool args_ok = true - && id.data_type() == type_i - && od.data_type() == type_o - && od.format_kind() == format_kind::rnn_packed - && od.rnn_packed_desc().format == mkldnn_ldigo_p - && od.rnn_packed_desc().n_parts == 1 - && attr != nullptr; - if (!args_ok) return status::invalid_arguments; - - format_tag_t itag = id.matches_one_of_tag( - format_tag::ldigo, format_tag::ldgoi); - if (itag == format_tag::undef) return status::invalid_arguments; - - const int mask = attr->rnn_weights_qparams_.mask_; - if (!utils::one_of(mask, 0, 3)) return status::unimplemented; - - auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine, - dst_md); - if (_pd == 
nullptr) return out_of_memory; - _pd->itag_ = itag; - if (_pd->init() != success) { delete _pd; return unimplemented; } - return safe_ptr_assign(*reorder_pd, _pd); - } - - status_t init() { - status_t status = cpu_reorder_pd_t::init(); - if (status != status::success) return status; - - init_scratchpad(); - - return status::success; - } - - format_tag_t itag_ = mkldnn_format_tag_undef; - - private: - void init_scratchpad() { - const memory_desc_wrapper id(src_md()); - const size_t nelems = id.nelems(); - const auto &dims = id.dims(); - - using namespace memory_tracking::names; - auto scratchpad = scratchpad_registry().registrar(); - size_t quantization_size = sizeof(int8_t) * nelems; - size_t reduction_size = itag_ == ldigo - ? sizeof(int32_t) * mkldnn_get_max_threads() * dims[0] - * dims[1] * dims[3] * dims[4] - : 0; - scratchpad.book( - key_reorder_rnn_weights_quantization, quantization_size); - scratchpad.book(key_reorder_rnn_weights_reduction, reduction_size); - } - }; - -private: - typedef typename prec_traits::type in_data_t; - typedef typename prec_traits::type out_data_t; - - rnn_weights_reorder_t(const pd_t *apd): cpu_primitive_t(apd) {} - - virtual status_t execute(const exec_ctx_t &ctx) const override { -#if USE_MKL_PACKED_GEMM - auto input = CTX_IN_MEM(const in_data_t *, MKLDNN_ARG_FROM); - auto output = CTX_OUT_MEM(char *, MKLDNN_ARG_TO); - const memory_desc_wrapper &input_d = pd()->src_md(); - const memory_desc_wrapper &output_d = pd()->dst_md(); - const auto &dims = input_d.dims(); - - const int L = dims[0]; - const int D = dims[1]; - const int I = dims[2]; - const int G = dims[3]; - const int O = dims[4]; - - const bool is_igo = pd()->itag_ == format_tag::ldigo; - - /* Quantize input & compute compensation */ - auto quantized = (int8_t * __restrict)scratchpad(ctx).template get( - memory_tracking::names::key_reorder_rnn_weights_quantization); - auto reduction = (int32_t * __restrict)scratchpad(ctx).template get( - memory_tracking::names::key_reorder_rnn_weights_reduction); - float *comp = reinterpret_cast( - output + output_d.rnn_packed_desc().offset_compensation); - const float *scales = pd()->attr()->rnn_weights_qparams_.scales_; - const int mask = pd()->attr()->rnn_weights_qparams_.mask_; - - if (is_igo) { - int nthr = mkldnn_get_max_threads(); - int LD_nthr = nstl::min(L * D, nthr); - int I_nthr = nstl::min(I, nthr / LD_nthr); - parallel(nthr, [&](const int ithr, const int nthr) { - int LD_ithr = -1, LD_s = -1, LD_e = -1; - int I_ithr = -1, I_s = -1, I_e = -1; - if (ithr < LD_nthr * I_nthr) { - LD_ithr = ithr % LD_nthr; - I_ithr = ithr / LD_nthr; - balance211(L * D, LD_nthr, LD_ithr, LD_s, LD_e); - balance211(I, I_nthr, I_ithr, I_s, I_e); - } - int32_t *comp_ithr = reduction + I_ithr * L * D * G * O; - for (int ld = LD_s; ld < LD_e; ld++) { - for (int go = 0; go < G * O; go++) - comp_ithr[ld * G * O + go] = 0; - for (int i = I_s; i < I_e; i++) { - PRAGMA_OMP_SIMD() - for (int go = 0; go < G * O; go++) { - const float s = scales[(mask == 0) ? 
0 : go]; - int8_t q = qz_b0()( - input[ld * I * G * O + i * G * O + go], s); - quantized[ld * I * G * O + i * G * O + go] - = (int32_t)q; - comp_ithr[ld * G * O + go] += (int32_t)q; - } - } - } - }); - parallel_nd(L * D * G * O, - [&](int s) { comp[s] = saturate(reduction[s]); }); - for (int i = 1; i < I_nthr; i++) { - parallel_nd(L * D * G * O, [&](int s) { - comp[s] += saturate( - reduction[i * L * D * G * O + s]); - }); - } - } else { - parallel_nd(L * D, G * O, [&](int ld, int go) { - int32_t compensation = 0; - const float s = scales[(mask == 0) ? 0 : go]; - PRAGMA_OMP_SIMD() - for (int i = 0; i < I; i++) { - int8_t q = qz_b0()( - input[ld * G * O * I + go * I + i], s); - compensation += (int32_t)q; - quantized[ld * G * O * I + go * I + i] = q; - } - comp[ld * G * O + go] = saturate(compensation); - }); - } - - /* Pack */ - auto off_igo = [&](int l, int d, int i, int g, int o) { - return l * D * I * G * O + d * I * G * O + i * G * O + g * O + o; - }; - auto off_goi = [&](int l, int d, int i, int g, int o) { - return l * D * G * O * I + d * G * O * I + g * O * I + o * I + i; - }; - int n_parts = output_d.rnn_packed_desc().n_parts; - const size_t *size_packed_cell - = output_d.rnn_packed_desc().part_pack_size; - const int *parts = output_d.rnn_packed_desc().parts; - const int n = output_d.rnn_packed_desc().n; - char *to_pack = output; - for (int l = 0; l < L; l++) { - for (int d = 0; d < D; d++) { - for (int p = 0; p < n_parts; p++) { - int g = (p > 0) ? parts[p - 1] : 0; - int m_p = parts[p] * O; - int k_p = I; - cblas_gemm_s8u8s32_pack(CblasColMajor, CblasAMatrix, - is_igo ? CblasNoTrans : CblasTrans, m_p, n, k_p, - &quantized[is_igo ? off_igo(l, d, 0, g, 0) : - off_goi(l, d, g, 0, 0)], - is_igo ? G * O : I, to_pack); - to_pack += size_packed_cell[p]; - } - } - } -#endif - return status::success; - } - - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } -}; - -template <> -struct rnn_weights_reorder_t - : public cpu_primitive_t { - struct pd_t : public cpu_reorder_pd_t { - using cpu_reorder_pd_t::cpu_reorder_pd_t; - - DECLARE_COMMON_PD_T("rnn_weights_reorder", rnn_weights_reorder_t); - - static status_t create(reorder_pd_t **reorder_pd, - engine_t *engine, const primitive_attr_t *attr, - engine_t *src_engine, const memory_desc_t *src_md, - engine_t *dst_engine, const memory_desc_t *dst_md) { -#if !USE_MKL_PACKED_GEMM - return status::unimplemented; -#endif - const memory_desc_wrapper id(src_md), od(dst_md); - bool args_ok = true - && id.data_type() == data_type::f32 - && od.data_type() == data_type::f32 - && od.format_kind() == format_kind::rnn_packed - && utils::one_of(od.rnn_packed_desc().format, - mkldnn_ldigo_p, mkldnn_ldgoi_p) - && attr->has_default_values(); - if (!args_ok) return status::invalid_arguments; - - format_tag_t itag = id.matches_one_of_tag( - format_tag::ldigo, format_tag::ldgoi); - if (itag == format_tag::undef) return status::invalid_arguments; - - const int mask = attr->rnn_weights_qparams_.mask_; - if (!utils::one_of(mask, 0, 3)) return status::unimplemented; - - auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine, - dst_md); - if (_pd == nullptr) return out_of_memory; - if (_pd->init() != success) { delete _pd; return unimplemented; } - _pd->itag_ = itag; - return safe_ptr_assign(*reorder_pd, _pd); - } - - format_tag_t itag_; - }; - -private: - rnn_weights_reorder_t(const pd_t *apd): cpu_primitive_t(apd) {} - - virtual status_t execute(const exec_ctx_t &ctx) const override { -#if USE_MKL_PACKED_GEMM - auto input = CTX_IN_MEM(const 
float *, MKLDNN_ARG_FROM); - auto output = CTX_OUT_MEM(float *, MKLDNN_ARG_TO); - const memory_desc_wrapper &input_d = pd()->src_md(); - const memory_desc_wrapper &output_d = pd()->dst_md(); - const auto &dims = input_d.dims(); - const rnn_packed_desc_t &rnn_pdata = output_d.rnn_packed_desc(); - const int L = dims[0]; - const int D = dims[1]; - const int I = dims[2]; - const int G = dims[3]; - const int O = dims[4]; - - /* Pack */ - bool cross_case = false - || (pd()->itag_ == format_tag::ldigo - && rnn_pdata.format == mkldnn_ldgoi_p) - || (pd()->itag_ == format_tag::ldgoi - && rnn_pdata.format == mkldnn_ldigo_p); - auto trans = cross_case ? CblasTrans : CblasNoTrans; - int n_parts = rnn_pdata.n_parts; - const size_t *size_packed_cell = rnn_pdata.part_pack_size; - const int *parts = rnn_pdata.parts; - const int n = rnn_pdata.n; - - const bool is_igo = pd()->itag_ == format_tag::ldigo; - auto off_igo = [&](int l, int d, int i, int g, int o) { - return l * D * I * G * O + d * I * G * O + i * G * O + g * O + o; - }; - auto off_goi = [&](int l, int d, int i, int g, int o) { - return l * D * G * O * I + d * G * O * I + g * O * I + o * I + i; - }; - for (int l = 0; l < L; l++) { - for (int d = 0; d < D; d++) { - for (int p = 0; p < n_parts; p++) { - int g = (p > 0) ? parts[p - 1] : 0; - int m_p = is_igo ? parts[p] * O : I; - int k_p = is_igo ? I : parts[p] * O; - int ld = is_igo ? G * O : I; - cblas_sgemm_pack(CblasColMajor, CblasAMatrix, trans, m_p, n, - k_p, 1.0f, &input[is_igo ? off_igo(l, d, 0, g, 0) : - off_goi(l, d, 0, g, 0)], - ld, output); - output += size_packed_cell[p] / sizeof(float); - } - } - } -#endif - return status::success; - } - - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } -}; - -} // namespace cpu -} // namespace impl -} // namespace mkldnn - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.cpp deleted file mode 100644 index 1d60415cb..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.cpp +++ /dev/null @@ -1,426 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#include "c_types_map.hpp" -#include "math_utils.hpp" -#include "mkldnn_thread.hpp" - -#include "ref_rnn.hpp" -#include "rnn_utils.hpp" -#include "type_helpers.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace mkldnn::impl::utils; -using namespace rnn_utils; -using namespace format_tag; -using namespace rnn_packed_format; -using namespace data_type; - -bool rnn_utils::is_ldigo(const memory_desc_wrapper &md) { - if (md.format_kind() != format_kind::blocked) - return false; - - auto blk = md.blocking_desc(); - auto str = blk.strides; - auto dims = md.dims(); - return md.ndims() == 5 && blk.inner_nblks == 0 && str[4] == 1 - && str[3] == dims[4] && str[1] == str[2] * dims[2] - && str[0] == str[1] * dims[1]; -}; - -bool rnn_utils::is_ldgoi(const memory_desc_wrapper &md) { - if (md.format_kind() != format_kind::blocked) - return false; - - auto blk = md.blocking_desc(); - auto str = blk.strides; - auto dims = md.dims(); - return md.ndims() == 5 && blk.inner_nblks == 0 && str[2] == 1 - && str[3] == dims[4] * str[4] && str[1] == str[3] * dims[3] - && str[0] == str[1] * dims[1]; -}; - -void rnn_utils::init_conf(rnn_conf_t &rnn, const rnn_desc_t &rd, - const memory_desc_wrapper &src_layer_d, - const memory_desc_wrapper &src_iter_d, - const memory_desc_wrapper &weights_layer_d, - const memory_desc_wrapper &weights_iter_d, - const memory_desc_wrapper &dst_layer_d) { - rnn.is_fwd = utils::one_of(rd.prop_kind, prop_kind::forward_training, - prop_kind::forward_inference); - rnn.is_training = utils::one_of( - rd.prop_kind, prop_kind::forward_training, prop_kind::backward); - rnn.is_lbr = rd.cell_desc.cell_kind == mkldnn_gru_linear_before_reset; - - switch (rd.direction) { - case mkldnn_unidirectional_left2right: rnn.exec_dir = l2r; break; - case mkldnn_unidirectional_right2left: rnn.exec_dir = r2l; break; - case mkldnn_bidirectional_concat: rnn.exec_dir = bi_concat; break; - case mkldnn_bidirectional_sum: rnn.exec_dir = bi_sum; break; - default: break; - } - - if (everyone_is(f32, src_layer_d.data_type(), dst_layer_d.data_type(), - weights_layer_d.data_type())) - rnn.dt_conf = all_f32; - else if (dst_layer_d.data_type() == u8) { - if (IMPLICATION(src_iter_d.md_, src_iter_d.data_type() == u8)) - rnn.dt_conf = u8u8u8u8; - else - rnn.dt_conf = f32u8f32u8; - } else { - if (IMPLICATION(src_iter_d.md_, src_iter_d.data_type() == u8)) - rnn.dt_conf = u8u8u8f32; - else - rnn.dt_conf = f32u8f32f32; - } - - rnn.n_layer = weights_layer_d.dims()[0]; - rnn.n_iter = src_layer_d.dims()[0]; - rnn.n_dir = weights_layer_d.dims()[1]; - rnn.n_gates = weights_layer_d.dims()[3]; - rnn.n_states = mkldnn_rnn_cell_get_states_count(&rd.cell_desc); - rnn.n_bias = rnn.n_gates + rnn.is_lbr; - rnn.mb = src_layer_d.dims()[1]; - rnn.sic = weights_iter_d.dims()[2]; - rnn.slc = weights_layer_d.dims()[2]; - rnn.dic = weights_layer_d.dims()[4]; - rnn.dlc = dst_layer_d.dims()[2]; - - rnn.gates_ld = rnn.dic * rnn.n_gates; - rnn.gates_nld = rnn.mb; - rnn.states_nld = rnn.mb; - - /* Set the correct number of weights parts */ - bool is_orig_gru = rd.cell_desc.cell_kind == alg_kind::vanilla_gru; - rnn.n_parts_weights_layer = 1; - rnn.parts_weights_layer[0] = rnn.n_gates; - rnn.parts_weights_layer[1] = 0; - - rnn.n_parts_weights_iter = is_orig_gru ? 2 : 1; - rnn.parts_weights_iter[0] = is_orig_gru ? 2 : rnn.n_gates; - rnn.parts_weights_iter[1] = is_orig_gru ? 
1 : 0; - - rnn.n_parts_bias = 1; - rnn.parts_bias[0] = rnn.n_bias; - rnn.parts_bias[1] = 0; - - /* Decide wich gemm implementation to use: packed/nonpacked jit/cblas - * and if to mergre gemm across iterations */ - bool is_int8 = rnn.dt_conf != all_f32; - rnn.merge_gemm_layer = ((rnn.is_fwd && rnn.mb < 128) || !rnn.is_fwd) - || is_int8; - bool is_gru = utils::one_of(rd.cell_desc.cell_kind, alg_kind::vanilla_gru, - alg_kind::gru_linear_before_reset); - rnn.merge_gemm_iter = !(rnn.is_fwd || is_gru) || is_int8; - bool is_inference = !rnn.is_training; - - rnn.use_jit_gemm = !mayiuse(avx512_mic) - && ((is_inference && (rnn.n_layer > 1 || rnn.mb < 100)) - || (rnn.is_training && rnn.dic < 500)); - - /* Decide to copy bias */ - rnn.copy_bias = rnn.dt_conf != all_f32; - -#if USE_MKL_PACKED_GEMM - rnn.use_layer_packed_gemm - = (weights_layer_d.format_kind() == format_kind::any - && rnn.slc > 760 && rnn.dic > 760 && is_inference) - || is_int8; // packed gemm is the only supported option for int8 - rnn.use_iter_packed_gemm - = (weights_iter_d.format_kind() == format_kind::any && rnn.sic > 760 - && rnn.dic > 760 && is_inference) - || is_int8; -#else - rnn.use_layer_packed_gemm = false; - rnn.use_iter_packed_gemm = false; -#endif - - /* Set packed gemm sizes */ - if (rnn.use_layer_packed_gemm) { - rnn.weights_layer_pack_size = 0; - for (int p = 0; p < rnn.n_parts_weights_layer; p++) { - int m_p = rnn.is_fwd - ? (rnn.parts_weights_layer[p] * rnn.dic) - : rnn.slc; - int k_p = rnn.is_fwd - ? rnn.slc - : (rnn.parts_weights_layer[p] * rnn.dic); - int n_p = rnn.merge_gemm_layer ? rnn.mb * rnn.n_iter : rnn.mb; - -#if USE_MKL_PACKED_GEMM - if (rnn.dt_conf == all_f32) - rnn.part_weights_layer_pack_size[p] = cblas_sgemm_pack_get_size( - CblasAMatrix, m_p, n_p, k_p); - else - rnn.part_weights_layer_pack_size[p] - = cblas_gemm_s8u8s32_pack_get_size( - CblasAMatrix, m_p, n_p, k_p); -#else - UNUSED(m_p); - UNUSED(k_p); - UNUSED(n_p); - rnn.part_weights_layer_pack_size[p] = 0; -#endif - rnn.weights_layer_pack_size += rnn.n_layer * rnn.n_dir - * rnn.part_weights_layer_pack_size[p]; - } - rnn.weights_layer_comp_offset = rnn.weights_layer_pack_size; - rnn.weights_layer_pack_size += rnn.dt_conf == all_f32 ? 0 : rnn.n_layer - * rnn.n_dir * rnn.n_gates * rnn.dlc * sizeof(float); - } - - if (rnn.use_iter_packed_gemm) { - rnn.weights_iter_pack_size = 0; - for (int p = 0; p < rnn.n_parts_weights_iter; p++) { - int m_p = rnn.is_fwd ? (rnn.parts_weights_iter[p] * rnn.dic) : - rnn.sic; - int k_p = rnn.is_fwd ? rnn.sic : - (rnn.parts_weights_iter[p] * rnn.dic); - int n_p = rnn.merge_gemm_iter ? rnn.mb * rnn.n_iter : rnn.mb; - -#if USE_MKL_PACKED_GEMM - if (rnn.dt_conf == all_f32) - rnn.part_weights_iter_pack_size[p] = cblas_sgemm_pack_get_size( - CblasAMatrix, m_p, n_p, k_p); - else - rnn.part_weights_iter_pack_size[p] - = cblas_gemm_s8u8s32_pack_get_size( - CblasAMatrix, m_p, n_p, k_p); -#else - UNUSED(m_p); - UNUSED(k_p); - UNUSED(n_p); - rnn.part_weights_iter_pack_size[p] = 0; -#endif - rnn.weights_iter_pack_size += rnn.n_layer * rnn.n_dir - * rnn.part_weights_iter_pack_size[p]; - } - rnn.weights_iter_comp_offset = rnn.weights_iter_pack_size; - rnn.weights_iter_pack_size += rnn.dt_conf == all_f32 ? 
0 : rnn.n_layer - * rnn.n_dir * rnn.n_gates * rnn.dic * sizeof(float); - } - -} - -void rnn_utils::set_conf(rnn_conf_t &rnn, const rnn_desc_t &rd, - const memory_desc_wrapper &weights_layer_d, - const memory_desc_wrapper &weights_iter_d, - const memory_desc_wrapper &diff_weights_layer_d, - const memory_desc_wrapper &diff_weights_iter_d) { - - /* Set leading dimensions for input weights arrays depending on input format - */ - rnn.weights_layer_is_packed - = weights_layer_d.format_kind() == format_kind::rnn_packed; - rnn.weights_iter_is_packed - = weights_iter_d.format_kind() == format_kind::rnn_packed; - - auto set_dims = [&](const memory_desc_wrapper &md, int &ld, int &nld) { - ld = 0; nld = 0; - if (md.is_blocking_desc()) { - if (is_ldigo(md)) { - ld = (int)md.blocking_desc().strides[2]; - nld = md.dims()[2]; - } else if (is_ldgoi(md)) { - ld = (int)md.blocking_desc().strides[4]; - nld = md.dims()[3] * md.dims()[4]; - } else - assert(!"unsupported weights format"); - } - }; - set_dims(weights_layer_d, rnn.weights_layer_ld, rnn.weights_layer_nld); - set_dims(weights_iter_d, rnn.weights_iter_ld, rnn.weights_iter_nld); - if (!rnn.is_fwd) { - set_dims(diff_weights_layer_d, rnn.diff_weights_layer_ld, - rnn.diff_weights_layer_nld); - set_dims(diff_weights_iter_d, rnn.diff_weights_iter_ld, - rnn.diff_weights_iter_nld); - } - - int sizeof_states_dt - = rnn.dt_conf == all_f32 ? sizeof(float) : sizeof(uint8_t); - rnn.states_ws_ld - = get_good_ld(nstl::max(rnn.slc, nstl::max(rnn.sic, rnn.dic)), - sizeof_states_dt); - rnn.gates_ws_ld = get_good_ld(rnn.gates_ld, sizeof(float)); - - /* Set workspace sizes to store: - * states to copmute a pass - * diff states to copmute bwd pass (training only) - * intermediate results from the gates - */ - rnn.use_workspace = rnn.is_training; - rnn.ws_states_size = (size_t)(rnn.n_layer + 1) * rnn.n_dir - * (rnn.n_iter + 1) * rnn.mb * rnn.states_ws_ld * sizeof_states_dt; - bool is_lstm = rd.cell_desc.cell_kind == mkldnn_vanilla_lstm; - rnn.ws_c_states_size = is_lstm - ? (size_t)(rnn.n_layer + 1) * rnn.n_dir * (rnn.n_iter + 1) * rnn.mb - * rnn.states_ws_ld * sizeof(float) - : 0; - rnn.ws_diff_states_size = rnn.is_training - ? (size_t)(rnn.n_layer + 1) * rnn.n_dir * (rnn.n_iter + 1) - * (rnn.n_states + 1) * rnn.mb * rnn.states_ws_ld - * sizeof(float) - : (size_t)0; - rnn.ws_gates_size = (size_t)rnn.n_layer * rnn.n_dir * rnn.n_iter * rnn.mb - * rnn.gates_ws_ld * sizeof(float); - - /* set other sizes */ - rnn.ws_per_cell = (size_t)rnn.is_lbr * rnn.mb * rnn.dic * sizeof(float); - rnn.ws_cell_comp_size - = rnn.is_lbr || rnn.dt_conf != all_f32 - ? (size_t) rnn.gates_nld * rnn.gates_ws_ld * sizeof(float) - : 0; - rnn.ws_grid_comp_size = (size_t)rnn.is_lbr * rnn.is_training * rnn.n_layer - * rnn.n_dir * rnn.n_iter * rnn.ws_per_cell * sizeof(float); - rnn.ws_bias_size = (size_t)rnn.n_layer * rnn.n_dir * rnn.n_bias * rnn.dic - * sizeof(float); -} - -int rnn_utils::get_good_ld(int dim, int sizeof_dt) { - // we want matrices leading dimentions to be 64-byte aligned, - // and not divisible by 256 to avoid 4K aliasing effects - int ld = rnd_up(dim, 64 / sizeof_dt); - return (ld % 256 == 0) ? 
ld + 64 / sizeof_dt : ld; -} - -void rnn_utils::set_offsets(const rnn_conf_t &rnn, size_t &ws_gates_offset, - size_t &ws_states_offset, size_t &ws_c_states_offset, - size_t &ws_diff_states_offset, size_t &ws_grid_comp_offset, - size_t &ws_cell_comp_offset, size_t &ws_bias_offset, - size_t &scratchpad_size, size_t &workspace_size) { - - const size_t page_size = 4096; // 2097152; - size_t current_offset; - /* Mandatory workspaces: go to workspace if use_workspace, scratchpad - * otherwise */ - current_offset = 0; // assumes the workspace base pointer is page aligned - ws_gates_offset = current_offset; - current_offset += rnn.ws_gates_size; - - current_offset = utils::rnd_up(current_offset, page_size); - ws_states_offset = current_offset; - current_offset += rnn.ws_states_size; - - current_offset = utils::rnd_up(current_offset, page_size); - ws_c_states_offset = current_offset; - current_offset += rnn.ws_c_states_size; - - current_offset = utils::rnd_up(current_offset, page_size); - ws_diff_states_offset = current_offset; - current_offset += rnn.ws_diff_states_size; - - current_offset = utils::rnd_up(current_offset, page_size); - ws_grid_comp_offset = current_offset; - current_offset += rnn.ws_grid_comp_size; - - current_offset = utils::rnd_up(current_offset, page_size); - ws_cell_comp_offset = current_offset; - current_offset += rnn.ws_cell_comp_size; - - workspace_size = rnn.use_workspace ? current_offset : 0; - - /* Optional scratchpads */ - // Assumes the scratchpad base pointer is page aligned. - // If use_workspace, the following goes to scratchpad alone, - // otherwise, all goes to scratchpad and continue incrementing offset - current_offset = rnn.use_workspace ? 0 : current_offset; - - if (rnn.copy_bias) { - current_offset = utils::rnd_up(current_offset, page_size); - ws_bias_offset = current_offset; - current_offset += rnn.ws_bias_size; - } - - scratchpad_size = current_offset; -} - -void rnn_utils::get_scratchpad_and_workspace_sizes(const rnn_conf_t &rnn, - size_t &scratchpad_size, size_t &workspace_size) { - size_t ws_gates_offset, ws_states_offset, ws_c_states_offset, - ws_diff_states_offset, ws_grid_comp_offset, ws_cell_comp_offset, - ws_bias_offset; - set_offsets(rnn, ws_gates_offset, ws_states_offset, ws_diff_states_offset, - ws_c_states_offset, ws_grid_comp_offset, ws_cell_comp_offset, - ws_bias_offset, scratchpad_size, workspace_size); -} - -status_t rnn_utils::set_good_strides( - memory_desc_t &weights_md, format_tag_t tag) { - auto &strides = weights_md.format_desc.blocking.strides; - auto dims = weights_md.dims; - - if (tag == ldigo) { - strides[2] = rnn_utils::get_good_ld((int)strides[2], - (int)types::data_type_size(weights_md.data_type)); - strides[1] = dims[2] * strides[2]; - strides[0] = dims[1] * strides[1]; - } else if (tag == ldgoi) { - strides[4] = rnn_utils::get_good_ld((int)strides[4], - (int)types::data_type_size(weights_md.data_type)); - strides[3] = dims[4] * strides[4]; - strides[1] = dims[3] * strides[3]; - strides[0] = dims[1] * strides[1]; - } else - return status::unimplemented; - - return status::success; -} - -status_t rnn_utils::set_expected_desc(rnn_conf_t &rnn, - memory_desc_t &weights_md, bool is_iter) { - using namespace format_tag; - bool use_packed_gemm = is_iter - ? rnn.use_iter_packed_gemm - : rnn.use_layer_packed_gemm; - if (use_packed_gemm) { - weights_md.format_kind = format_kind::rnn_packed; - rnn_packed_desc_t &rnn_pdata = weights_md.format_desc.rnn_packed_desc; - rnn_pdata.format = rnn.is_fwd ? 
mkldnn_ldigo_p : mkldnn_ldgoi_p; - if (is_iter) { - rnn_pdata.n = rnn.mb; - rnn_pdata.n_parts = rnn.n_parts_weights_iter; - array_copy(rnn_pdata.parts, rnn.parts_weights_iter, - MKLDNN_RNN_MAX_N_PARTS); - array_copy(rnn_pdata.part_pack_size, - rnn.part_weights_iter_pack_size, MKLDNN_RNN_MAX_N_PARTS); - rnn_pdata.offset_compensation = rnn.weights_iter_comp_offset; - rnn_pdata.size = rnn.weights_iter_pack_size; - } else { - rnn_pdata.n = rnn.merge_gemm_layer ? rnn.n_iter * rnn.mb : rnn.mb; - rnn_pdata.n_parts = rnn.n_parts_weights_layer; - array_copy(rnn_pdata.parts, rnn.parts_weights_layer, - MKLDNN_RNN_MAX_N_PARTS); - array_copy(rnn_pdata.part_pack_size, - rnn.part_weights_layer_pack_size, MKLDNN_RNN_MAX_N_PARTS); - rnn_pdata.offset_compensation = rnn.weights_layer_comp_offset; - rnn_pdata.size = rnn.weights_layer_pack_size; - } - } else { - CHECK(memory_desc_init_by_tag(weights_md, rnn.is_fwd ? ldigo : ldgoi)); - // Adjust strides for good leading dimension in GEMM - CHECK(set_good_strides(weights_md, rnn.is_fwd ? ldigo : ldgoi)); - } - return status::success; -} - -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.hpp deleted file mode 100644 index 99eb787a6..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.hpp +++ /dev/null @@ -1,225 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
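rnn_utils::get_good_ld above chooses the leading dimension for GEMM operands: it rounds the dimension up to a 64-byte multiple, then bumps it by one more alignment unit whenever the result is a multiple of 256 elements, to sidestep the 4K-aliasing concern noted in its comment. A minimal standalone sketch of that policy, assuming f32 data (sizeof_dt == 4):

    #include <cstdio>

    // Round 'dim' up to the nearest multiple of 'mult' (mirrors utils::rnd_up).
    static int rnd_up(int dim, int mult) {
        return ((dim + mult - 1) / mult) * mult;
    }

    // Same policy as rnn_utils::get_good_ld: a 64-byte-aligned leading
    // dimension that is not a multiple of 256 elements.
    static int good_ld(int dim, int sizeof_dt) {
        int ld = rnd_up(dim, 64 / sizeof_dt);
        return (ld % 256 == 0) ? ld + 64 / sizeof_dt : ld;
    }

    int main() {
        // For f32 data (sizeof_dt == 4) the alignment unit is 16 elements.
        std::printf("%d\n", good_ld(500, 4)); // 500 -> 512, multiple of 256 -> 528
        std::printf("%d\n", good_ld(100, 4)); // 100 -> 112
        return 0;
    }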
-*******************************************************************************/ - -#ifndef RNN_UTILS_HPP -#define RNN_UTILS_HPP - -#include "mkldnn.h" - -#include "cpu_rnn_pd.hpp" - - -#define rnn_elemwise_sig(f) \ - void f(const rnn_utils::rnn_conf_t &rnn, acc_data_t *ws_gates_, \ - src_data_t *states_t_l_, float *c_states_t_l_, \ - src_data_t *states_tm1_l_, float *c_states_tm1_l_, \ - float *diff_states_t_l_, float *diff_states_t_lp1_, \ - float *diff_states_tp1_l_, float *bias_, float *ws_grid_, \ - float *ws_cell_) const - -#define rnn_cell_execution_sig(f) \ - void f(const rnn_utils::rnn_conf_t &rnn, src_data_t *states_t_l_, \ - float *c_states_t_l_, float *diff_states_t_l_, \ - weights_data_t **w_layer_, weights_data_t **w_iter_, \ - float **bias_, src_data_t *states_t_lm1_, \ - src_data_t *states_tm1_l_, float *c_states_tm1_l_, \ - float *diff_states_t_lp1_, float *diff_states_tp1_l_, \ - float *diff_w_layer_, float *diff_w_iter_, float *diff_bias_, \ - acc_data_t *ws_gates_, float *ws_grid_, float *ws_cell_) const - -#define rnn_grid_execution_sig(f) \ - void f(const rnn_utils::rnn_conf_t &rnn, weights_data_t **weights_layer_, \ - weights_data_t **weights_states_, float **bias_, \ - src_data_t *ws_states_, float *ws_c_states_, \ - float *ws_diff_states_, acc_data_t *ws_gates_, float *ws_cell_, \ - float *ws_grid_, float *diff_weights_layer_, \ - float *diff_weights_iter_, float *diff_bias_) const - -#define rnn_gemm_sig(f) \ - void f(const char transA, const char transB, int m, int n, int k, \ - const float alpha, const weights_data_t *a_, const int ldA, \ - const src_data_t *b_, const int ldB, const float beta, \ - acc_data_t *c_, const int ldC) const - -#define rnn_bias_prepare_sig(f) \ - void f(const rnn_utils::rnn_conf_t &rnn, float **bias_, const float *b_, \ - float *scratch_bias_) const - -#define rnn_bias_finalize_sig(f) \ - void f(const rnn_utils::rnn_conf_t &rnn, float *scratch_bias_, \ - const float *w_iter_comp, const float *w_layer_comp) const - -#define rnn_weights_assign_sig(f) \ - void f(const rnn_utils::rnn_conf_t &rnn, const memory_desc_t *md, int nld, \ - int ld, int OC_size, int IC_size, const int n_parts, \ - const int *gates_per_part, const size_t *part_weights_pack_size, \ - weights_data_t **weights_, const weights_data_t *w_, \ - float **bias_, const float *b_, float *scratch_bias_) const - - -namespace mkldnn { -namespace impl { -namespace cpu { - -namespace rnn_utils { - -using namespace mkldnn::impl::utils; - -enum execution_direction_t { - l2r, - r2l, - bi_concat, - bi_sum, -}; - -enum data_type_conf_t { - all_f32, - u8u8u8f32, - f32u8f32f32, - u8u8u8u8, - f32u8f32u8 -}; - -struct rnn_conf_t { - execution_direction_t exec_dir; - data_type_conf_t dt_conf; - int n_layer, n_iter, n_dir, n_gates, n_states; - int mb; - int slc, sic, dic, dlc; - int gates_ld, gates_nld, gates_ws_ld; - int n_parts_weights_layer, parts_weights_layer[MKLDNN_RNN_MAX_N_PARTS]; - int n_parts_weights_iter, parts_weights_iter[MKLDNN_RNN_MAX_N_PARTS]; - int n_bias, n_parts_bias, parts_bias[MKLDNN_RNN_MAX_N_PARTS]; - size_t part_weights_iter_pack_size[MKLDNN_RNN_MAX_N_PARTS], - part_weights_layer_pack_size[MKLDNN_RNN_MAX_N_PARTS]; - bool weights_layer_is_packed, weights_iter_is_packed; - /* Size of packed data in bytes */ - size_t weights_layer_comp_offset, weights_layer_pack_size, - weights_iter_comp_offset, weights_iter_pack_size; - - bool copy_bias; - int weights_layer_ld, weights_layer_nld; - int diff_weights_layer_ld, diff_weights_layer_nld; - int weights_iter_ld, 
weights_iter_nld; - int diff_weights_iter_ld, diff_weights_iter_nld; - int states_nld, states_ws_ld; - int weights_iter_compensation_size, weights_layer_compensation_size; - bool is_fwd, is_training, is_lbr; - bool use_workspace; - - /* Size of workspace for each tensor in bytes */ - size_t ws_gates_size, ws_states_size, ws_c_states_size, ws_diff_states_size, - ws_cell_comp_size, ws_grid_comp_size, ws_per_cell, ws_bias_size; - bool merge_gemm_iter, merge_gemm_layer, use_jit_gemm, use_layer_packed_gemm, - use_iter_packed_gemm; -}; - -bool is_ldigo(const memory_desc_wrapper &md); -bool is_ldgoi(const memory_desc_wrapper &md); - -int get_good_ld(int dim, int sizeof_dt); - -void init_conf(rnn_conf_t &rnn, const rnn_desc_t &rd, - const memory_desc_wrapper &src_layer_d, - const memory_desc_wrapper &src_iter_d, - const memory_desc_wrapper &weights_layer_d, - const memory_desc_wrapper &weights_iter_d, - const memory_desc_wrapper &dst_layer_d); - -void set_conf(rnn_conf_t &rnn, const rnn_desc_t &rd, - const memory_desc_wrapper &weights_layer_d, - const memory_desc_wrapper &weights_iter_d, - const memory_desc_wrapper &diff_weights_layer_d, - const memory_desc_wrapper &diff_weights_iter_d); - -void set_offsets(const rnn_conf_t &rnn, size_t &ws_gates_offset, - size_t &ws_h_state_offset, size_t &ws_c_state_offset, - size_t &ws_diff_states_offset, size_t &ws_grid_comp_offset, - size_t &ws_cell_comp_offset, size_t &ws_bias_offset, - size_t &scratchpad_size, size_t &workspace_size); - -void get_scratchpad_and_workspace_sizes(const rnn_conf_t &rnn, - size_t &scratchpad_size, size_t &workspace_size); -status_t set_expected_desc( - rnn_conf_t &rnn, memory_desc_t &weights_md, bool is_iter); -status_t set_good_strides(memory_desc_t &weights_md, format_tag_t tag); - -template -struct ws_gates_aoc { - ws_gates_aoc(const rnn_conf_t &rnn, T *data) - : gates_(data, rnn.gates_nld, rnn.gates_ws_ld), DIC_(rnn.dic) {} - T &operator()(int batch, int gate, int dic) { - return gates_(batch, gate * DIC_ + dic); - } - -private: - mkldnn::impl::utils::array_offset_calculator gates_; - int DIC_; -}; -using ws_gates_aoc_t = ws_gates_aoc; -using ws_gates_aoc_s32_t = ws_gates_aoc; - -struct bias_aoc_t { - bias_aoc_t(const rnn_conf_t &rnn, const float *data) - : bias_(data, rnn.n_bias, rnn.dic) {} - const float &operator()(int bias_n, int dic) { return bias_(bias_n, dic); } - -private: - mkldnn::impl::utils::array_offset_calculator bias_; -}; - -template -struct ws_states_aoc { - ws_states_aoc(const rnn_conf_t &rnn, T *data) - : state_(data, rnn.states_nld, rnn.states_ws_ld) {} - T &operator()(int batch, int dic) { return state_(batch, dic); } - -private: - mkldnn::impl::utils::array_offset_calculator state_; -}; -using ws_states_aoc_t = ws_states_aoc; -using ws_states_aoc_u8_t = ws_states_aoc; - -struct ws_diff_states_aoc_t { - ws_diff_states_aoc_t(const rnn_conf_t &rnn, float *data) - : diff_states_(data, rnn.n_states + 1, rnn.n_iter + 1, rnn.states_nld, - rnn.states_ws_ld) {} - float &operator()(int state_n, int batch, int dic) { - return diff_states_(state_n, 0, batch, dic); - } - -private: - mkldnn::impl::utils::array_offset_calculator diff_states_; -}; - -struct ws_diff_w_iter_aoc_t { - ws_diff_w_iter_aoc_t(const rnn_conf_t &rnn, float *data) - : diff_weights_iter_( - data, rnn.diff_weights_iter_nld, rnn.diff_weights_iter_ld) - , DIC_(rnn.dic) {} - float &operator()(int sic, int gate, int dic) { - return diff_weights_iter_(sic, gate * DIC_ + dic); - } - -private: - mkldnn::impl::utils::array_offset_calculator 
diff_weights_iter_; - int DIC_; -}; -} -} -} -} -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.cpp deleted file mode 100644 index 0420f87aa..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.cpp +++ /dev/null @@ -1,126 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include "mkldnn_thread.hpp" - -#include "simple_concat.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace memory_tracking::names; - -template -status_t simple_concat_t::execute(const exec_ctx_t &ctx) const { - auto scratchpad = this->scratchpad(ctx); - auto iptrs = scratchpad.template get(key_concat_iptrs); - auto optrs = scratchpad.template get(key_concat_optrs); - auto nelems_to_copy = scratchpad.template get(key_concat_nelems); - auto is = scratchpad.template get(key_concat_istrides); - - const int num_arrs = pd()->n_inputs(); - const int *perm = pd()->perm_, *iperm = pd()->iperm_; - const int concat_dim = pd()->concat_dim(); - auto o_base_ptr = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); - - for (int a = 0; a < num_arrs; ++a) { - const memory_desc_wrapper i_d(pd()->src_md(a)); - const memory_desc_wrapper o_d(pd()->src_image_md(a)); - - iptrs[a] = CTX_IN_MEM(const data_t *, MKLDNN_ARG_MULTIPLE_SRC + a) - + i_d.blk_off(0); - optrs[a] = o_base_ptr + o_d.blk_off(0); - nelems_to_copy[a] = pd()->nelems_to_concat(i_d); - for (int i = 0; i < MKLDNN_MAX_NDIMS; i++) { - if (i < perm[concat_dim]) - is[a][i] = size_t(i_d.blocking_desc().strides[iperm[i]]); - else - is[a][i] = 0; - } - } - - const memory_desc_wrapper o_d(pd()->src_image_md(0)); - - strides_t os = { 0 }; - for (int i = 0; i < perm[concat_dim]; i++) - os[i] = o_d.blocking_desc().strides[iperm[i]]; - - dims_t phys_dims; - for (size_t i = 0; i < sizeof(phys_dims)/sizeof(phys_dims[0]); i++) - phys_dims[i] = (i < (size_t)perm[concat_dim]) - ? 
o_d.dims()[iperm[i]] / pd()->blocks_[iperm[i]] : 1; - - if (perm[concat_dim] == 0) { - for (int a = 0; a < num_arrs; ++a) { - const data_t *i = &iptrs[a][0]; - data_t *o = &optrs[a][0]; - parallel_nd((ptrdiff_t)nelems_to_copy[a], - [&](ptrdiff_t e) { o[e] = i[e]; }); - } - } else { - parallel_nd(phys_dims[0], phys_dims[1], phys_dims[2], phys_dims[3], - phys_dims[4], num_arrs, - [&](dim_t n0, dim_t n1, dim_t n2, dim_t n3, dim_t n4, int a) { - // XXX: this code may access uninitialized values in is[*][0-4] -- - // that's why we have to set them to zero although this is - // probably benign - size_t in_off = is[a][0] * n0 + is[a][1] * n1 + is[a][2] * n2 - + is[a][3] * n3 + is[a][4] * n4; - size_t out_off = os[0] * n0 + os[1] * n1 + os[2] * n2 - + os[3] * n3 + os[4] * n4; - const data_t *i = &iptrs[a][in_off]; - data_t *o = &optrs[a][out_off]; -#if defined(__GNUC__) && !defined(__INTEL_COMPILER) - // The code below performs data copying: o[e] = i[e] - // and uses a workaround to make GNU compilers optimize it - uint8_t *ptro = reinterpret_cast(o); - const uint8_t *ptri = reinterpret_cast(i); - const dim_t main_part = - nelems_to_copy[a] * sizeof(data_t) / sizeof(uint32_t); - const dim_t tail_part = - nelems_to_copy[a] % sizeof(data_t) / sizeof(uint32_t); - - PRAGMA_OMP_SIMD() - for (dim_t e = 0; e < main_part; ++e) { - *(reinterpret_cast(ptro)) - = *(reinterpret_cast(ptri)); - ptro += sizeof(uint32_t); - ptri += sizeof(uint32_t); - } - for (dim_t e = 0; e < tail_part; ++e) { - *ptro = *ptri; - ++ptro; - ++ptri; - } -#else - PRAGMA_OMP_SIMD() - for (dim_t e = 0; e < nelems_to_copy[a]; ++e) o[e] = i[e]; -#endif - }); - } - - return status::success; -} - -template struct simple_concat_t; -template struct simple_concat_t; -template struct simple_concat_t; -template struct simple_concat_t; - -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.hpp deleted file mode 100644 index 057cc3c4c..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.hpp +++ /dev/null @@ -1,155 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
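simple_concat_t::execute above takes a fast path when the concatenation axis is the outermost (largest-stride) dimension: each input then occupies one contiguous region of the destination, so the copy reduces to a flat element-wise transfer per source. A deliberately simplified sketch of that fast path for plain float buffers, ignoring blocking, padding and the strided general case:

    #include <cstdio>
    #include <cstring>
    #include <vector>

    // Concatenate 'srcs' along the outermost axis: each source is one
    // contiguous block, so the copy degenerates to back-to-back memcpy calls.
    static void concat_outer(const std::vector<std::vector<float>> &srcs,
                             std::vector<float> &dst) {
        size_t off = 0;
        for (const auto &s : srcs) {
            std::memcpy(dst.data() + off, s.data(), s.size() * sizeof(float));
            off += s.size();
        }
    }

    int main() {
        std::vector<std::vector<float>> srcs = {{1, 2, 3}, {4, 5}, {6}};
        std::vector<float> dst(6);
        concat_outer(srcs, dst);
        for (float v : dst) std::printf("%g ", v); // 1 2 3 4 5 6
        std::printf("\n");
        return 0;
    }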
-*******************************************************************************/ - -#ifndef SIMPLE_CONCAT_HPP -#define SIMPLE_CONCAT_HPP - -#include "memory_tracking.hpp" - -#include "cpu_concat_pd.hpp" -#include "cpu_primitive.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -template -struct simple_concat_t: public cpu_primitive_t { - struct pd_t: public cpu_concat_pd_t { - using cpu_concat_pd_t::cpu_concat_pd_t; - - pd_t(const pd_t &rhs): cpu_concat_pd_t(rhs) { - int ndims = rhs.dst_md_.ndims; - utils::array_copy(perm_, rhs.perm_, ndims); - utils::array_copy(iperm_, rhs.iperm_, ndims); - utils::array_copy(blocks_, rhs.blocks_, ndims); - } - - DECLARE_CONCAT_PD_T("simple:any", simple_concat_t); - - status_t init() { - const memory_desc_wrapper dst_d(dst_md()); - bool ok = true - && cpu_concat_pd_t::init() == status::success - && dst_d.ndims() <= 6; - if (!ok) return status::unimplemented; - - for (size_t i = 0; i < src_mds_.size(); ++i) { - const memory_desc_wrapper i_d(&src_mds_[i]); - const memory_desc_wrapper o_d(&src_image_mds_[i]); - - const int ignore_strides = 0; - - ok = ok - && utils::everyone_is(data_type, i_d.data_type(), - o_d.data_type()) - && utils::everyone_is(format_kind::blocked, - i_d.format_kind(), o_d.format_kind()) - && types::blocking_desc_is_equal(i_d.blocking_desc(), - o_d.blocking_desc(), ignore_strides) - && types::blocking_desc_is_equal(i_d.blocking_desc(), - dst_d.blocking_desc(), ignore_strides) - && !i_d.is_additional_buffer(); - if (!ok) return status::unimplemented; - } - - dst_d.compute_blocks(blocks_); - format_perm(); - - // start dim is the first dimension after which the concatenation - // would happen contiguously - const int start_dim = perm_[concat_dim()]; - - // check that contiguous part is indeed contiguous (i.e. dense) - if (nelems_to_concat(dst_d) != - dst_d.padded_dims()[concat_dim()] / blocks_[concat_dim()] - * dst_d.blocking_desc().strides[concat_dim()]) - return status::unimplemented; - - // check that all inputs have the same strides for the - // contiguous part [concat_dim .. ndims] for the *major* dims. 
- // the block part is already checked above - for (size_t i = 0; i < src_mds_.size(); ++i) { - const memory_desc_wrapper i_d(&src_mds_[i]); - for (int d = start_dim; d < dst_d.ndims(); ++d) { - if (dst_d.blocking_desc().strides[iperm_[d]] - != i_d.blocking_desc().strides[iperm_[d]]) - return status::unimplemented; - } - } - - init_scratchpad(); - - return status::success; - } - - int perm_[MKLDNN_MAX_NDIMS] {}; - int iperm_[MKLDNN_MAX_NDIMS] {}; - dims_t blocks_ {}; - - dim_t nelems_to_concat(const memory_desc_wrapper &data_d) const { - const int ndims = data_d.ndims(); - - dim_t nelems = 1; - for (int i = perm_[concat_dim()]; i < ndims; i++) - nelems *= data_d.dims()[iperm_[i]] / blocks_[iperm_[i]]; - for (int i = 0; i < ndims; i++) - nelems *= blocks_[i]; - - return nelems; - } - - private: - void format_perm() { - const memory_desc_wrapper dst_d(dst_md()); - const int ndims = dst_d.ndims(); - - strides_t strides; - utils::array_copy(strides, dst_d.blocking_desc().strides, ndims); - for (int i = 0; i < ndims; i++) iperm_[i] = i; - - utils::simultaneous_sort(strides, iperm_, ndims, - [](stride_t a, stride_t b) { return b - a; }); - - for (int i = 0; i < ndims; i++) perm_[iperm_[i]] = i; - } - - void init_scratchpad() { - using namespace memory_tracking::names; - auto scratchpad = scratchpad_registry().registrar(); - scratchpad.book(key_concat_iptrs, sizeof(data_t *) * n_inputs()); - scratchpad.book(key_concat_optrs, sizeof(data_t *) * n_inputs()); - scratchpad.book(key_concat_nelems, sizeof(dim_t) * n_inputs()); - scratchpad.book(key_concat_istrides, - sizeof(strides_t) * n_inputs()); - } - }; - - simple_concat_t(const pd_t *apd): cpu_primitive_t(apd) {} - - virtual status_t execute(const exec_ctx_t &ctx) const override; - - typedef typename prec_traits::type data_t; - -private: - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } -}; - -} -} -} - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/simple_q10n.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/simple_q10n.hpp deleted file mode 100644 index e6c3b8d7a..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/simple_q10n.hpp +++ /dev/null @@ -1,98 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
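pd_t::format_perm() above derives perm_/iperm_ by sorting the destination dimensions by decreasing stride, so the kernel can walk the layout as if it were plain row-major. A small sketch of the same idea, using hypothetical NHWC-style strides written in NCHW dimension order:

    #include <algorithm>
    #include <cstdio>

    int main() {
        const int ndims = 4;
        // Strides of an NHWC tensor of shape 2x3x4x5, listed in NCHW dim order.
        const long strides[] = {60, 1, 15, 3};
        int iperm[4], perm[4];
        for (int i = 0; i < ndims; ++i) iperm[i] = i;
        // Order dimensions from largest to smallest stride (outer to inner).
        std::sort(iperm, iperm + ndims,
                  [&](int a, int b) { return strides[a] > strides[b]; });
        for (int i = 0; i < ndims; ++i) perm[iperm[i]] = i;
        // Expected: iperm = {0, 2, 3, 1}, perm = {0, 3, 1, 2}.
        for (int i = 0; i < ndims; ++i)
            std::printf("iperm[%d]=%d perm[%d]=%d\n", i, iperm[i], i, perm[i]);
        return 0;
    }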
-*******************************************************************************/ - -#ifndef CPU_SIMPLE_Q10N_HPP -#define CPU_SIMPLE_Q10N_HPP - -#include - -#include "c_types_map.hpp" -#include "math_utils.hpp" -#include "nstl.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace mkldnn::impl::math; - -template -inline out_t round_and_saturate(float f) -{ return math::saturate(out_round(f)); } - -/* Quantization with alpha == 1 and beta == 0 */ -template -struct qz_a1b0 { - out_t operator()(in_t in) - { return round_and_saturate((float)in); } -}; - -template -struct qz_a1b0::value - && !is_subset::value - >::type> { - out_t operator()(in_t in) { return math::saturate(in); } -}; - -template -struct qz_a1b0::value>::type> { - out_t operator()(in_t in) { return (out_t)in; } -}; - -/* Quantization with alpha == 1 */ -template struct qz_a1 { - out_t operator()(in_t in, out_t out, float beta) - { return round_and_saturate((float)in + beta * out); } -}; - -template struct qz_a1 { - float operator()(in_t in, float out, float beta) - { return (float)in + beta * out; } -}; - -/* Quantization with beta == 0 */ -template struct qz_b0 { - out_t operator()(in_t in, float alpha) - { return round_and_saturate(alpha * in); } -}; - -template struct qz_b0 { - float operator()(in_t in, float alpha) { return alpha * in; } -}; - -/* Quantization */ -template struct qz { - out_t operator()(in_t in, out_t out, float alpha, float beta) { - return round_and_saturate( - alpha * in + (beta ? beta * out : 0)); - } -}; - -template struct qz { - float operator()(in_t in, float out, float alpha, float beta) - { return alpha * in + (beta ? beta * out : 0); } -}; - -} -} -} - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/simple_reorder.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/simple_reorder.hpp deleted file mode 100644 index ff845f5bd..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/simple_reorder.hpp +++ /dev/null @@ -1,1022 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
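simple_q10n.hpp above defines the quantization functors used by the reorder-like kernels: in the general case qz computes saturate(round(alpha * in + beta * out)), and the qz_a1b0 / qz_a1 / qz_b0 specializations drop the parts that are identically 1 or 0. A standalone float-to-int8 sketch of the general rule, using std::nearbyintf round-to-nearest as a stand-in for the library's out_round:

    #include <cmath>
    #include <cstdint>
    #include <cstdio>
    #include <limits>

    // out = saturate<int8_t>(round(alpha * in + beta * out))
    static int8_t qz_f32_s8(float in, int8_t out, float alpha, float beta) {
        float v = alpha * in + (beta ? beta * (float)out : 0.f);
        v = std::nearbyintf(v); // round to nearest
        const float lo = (float)std::numeric_limits<int8_t>::min();
        const float hi = (float)std::numeric_limits<int8_t>::max();
        if (v < lo) v = lo;
        if (v > hi) v = hi;
        return (int8_t)v;
    }

    int main() {
        std::printf("%d\n", (int)qz_f32_s8(3.7f, 0, 1.f, 0.f));   // 4
        std::printf("%d\n", (int)qz_f32_s8(1000.f, 0, 1.f, 0.f)); // saturates to 127
        std::printf("%d\n", (int)qz_f32_s8(2.f, 10, 0.5f, 1.f));  // round(1 + 10) = 11
        return 0;
    }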
-*******************************************************************************/ - -#ifndef CPU_SIMPLE_REORDER_HPP -#define CPU_SIMPLE_REORDER_HPP - -#include - -#include "c_types_map.hpp" -#include "type_helpers.hpp" -#include "math_utils.hpp" -#include "mkldnn_thread.hpp" -#include "utils.hpp" - -#include "tag_traits.hpp" -#include "cpu_reorder_pd.hpp" -#include "cpu_primitive.hpp" - -#include "simple_q10n.hpp" -#include "cpu_isa_traits.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -using namespace mkldnn::impl::status; -using namespace mkldnn::impl::format_tag; -using namespace mkldnn::impl::data_type; - -using bd = block_dim_t; -using ib = inner_blk_t; - -using namespace mkldnn::impl::utils; -using math::saturate; - -template -using data_t = typename prec_traits::type; - -template -using _qz_a1b0 = qz_a1b0, data_t>; - -template -using _qz = qz, data_t>; - -namespace fmt_order { - const bool keep = true; - const bool reverse = false; - const bool any = keep; -} - -namespace spec { -struct direct_copy {}; -struct direct_copy_except_dim_0 {}; -struct reference {}; -struct conv_s8s8 {}; -} - -#define SIMPLE_REORDER_TEMPL_DECL \ - impl::data_type_t type_i, impl::format_tag_t tag_i, \ - impl::data_type_t type_o, impl::format_tag_t tag_o, bool order_keep -#define SIMPLE_REORDER_TEMPL_CALL \ - type_i, tag_i, type_o, tag_o, order_keep - -#define DECLARE_COMMON_PARAMS() \ - const memory_desc_wrapper &input_d = pd->src_md(); \ - const memory_desc_wrapper &output_d = pd->dst_md(); \ - const float alpha = pd->alpha(); MAYBE_UNUSED(alpha); \ - const float beta = pd->beta(); MAYBE_UNUSED(beta); - -/* specific reorders: common template */ -template -struct simple_reorder_impl {}; - -namespace { -inline bool simple_fmt_check(bool order_keep, impl::format_tag_t tag_i, - impl::format_tag_t tag_o, const memory_desc_wrapper &input_d, - const memory_desc_wrapper &output_d) { - return input_d.matches_tag(order_keep ? tag_i : tag_o) - && output_d.matches_tag(order_keep ? tag_o : tag_i); -} -inline bool simple_attr_check(const primitive_attr_t *attr, bool many_scales_support) { - if (many_scales_support) - return true; - return IMPLICATION(attr, attr->output_scales_.mask_ == 0); -} -} - -/* specific reorders: implementation */ -template -struct simple_reorder_impl::type> -{ - static bool is_applicable(const memory_desc_wrapper &input_d, - const memory_desc_wrapper &output_d, const primitive_attr_t *attr) - { - const size_t D_mask = utils::array_product(input_d.dims(), - math::ilog2q(attr->output_scales_.mask_ + 1)); - const int oc = (input_d.dims()[tag_o == hwigo + 0]); - const int g = (tag_o == hwigo) ? (input_d.dims()[0]) : 1; - - return output_d.matches_tag(tag_o) - && (output_d.extra().flags & memory_extra_flags::compensation_conv_s8s8) - && (input_d.data_type() == f32 || input_d.data_type() == s8) - && output_d.data_type() == s8 - && (D_mask == 1 || D_mask == (size_t)g * oc); - } - - static status_t execute(const cpu_reorder_pd_t *pd, - const data_t *input, data_t *output) { - DECLARE_COMMON_PARAMS(); - - static constexpr bool w_groups = tag_o == hwigo; - - const auto &dims = input_d.dims(); - const auto &pdims = output_d.padded_dims(); - - const int G = w_groups ? 
dims[0] : 1; - const int OC = dims[w_groups + 0]; - const int IC = dims[w_groups + 1]; - const int H = dims[w_groups + 2]; - const int W = dims[w_groups + 3]; - - const float *scales = pd->attr()->output_scales_.scales_; - const size_t D_mask = utils::array_product(input_d.dims(), - math::ilog2q(pd->attr()->output_scales_.mask_ + 1)); - - assert(output_d.extra().flags - & memory_extra_flags::compensation_conv_s8s8); - float adj_scale = - (output_d.extra().flags & memory_extra_flags::scale_adjust) - ? output_d.extra().scale_adjust : 1.f; - - size_t offset = G * pdims[w_groups + 0] * pdims[w_groups + 1] * H * W; - int32_t *cp = reinterpret_cast(output + offset); - - parallel_nd(G, OC, [&](int g, int oc) { - cp[g * OC + oc] = 0; - for (int ic = 0; ic < IC; ic++) - for (int h = 0; h < H; h++) - for (int w = 0; w < W; w++) { - auto i = input[input_d.blk_off(g, oc, ic, h, w)]; - auto &o = output[output_d.blk_off(g, oc, ic, h, w)]; - const float s = scales[(D_mask == 1) ? 0 : g * OC + oc]; - - o = qz_b0, data_t>()( - i, s * adj_scale); - cp[g * OC + oc] -= (int32_t)o; - } - cp [g * OC + oc] *= 128; - }); - return success; - } -}; - -template -struct simple_reorder_impl::type> -{ - static bool is_applicable(const memory_desc_wrapper &input_d, - const memory_desc_wrapper &output_d, const primitive_attr_t *attr) - { - const size_t D_mask = utils::array_product(input_d.dims(), - math::ilog2q(attr->output_scales_.mask_ + 1)); - const bool w_groups = !utils::one_of(tag_o, OIw4i16o4i, OIhw4i16o4i); - const int oc = (input_d.dims()[w_groups ? 1 : 0]); - const int g = w_groups ? input_d.dims()[0] : 1; - - return input_d.matches_tag(tag_i) - && output_d.matches_tag(tag_o) - && (output_d.extra().flags & memory_extra_flags::compensation_conv_s8s8) - && (input_d.data_type() == f32 || input_d.data_type() == s8) - && output_d.data_type() == s8 - && (D_mask == 1 || D_mask == (size_t)g * oc); - } - - static status_t execute(const cpu_reorder_pd_t *pd, - const data_t *input, data_t *output) { - DECLARE_COMMON_PARAMS(); - - static constexpr bool w_groups = - !utils::one_of(tag_o, OIw4i16o4i, OIhw4i16o4i); - constexpr int is_1d = - utils::one_of(tag_o, gOIw4i16o4i, OIw4i16o4i); - constexpr int blksize = tag_traits::inner_blks == ib::_4b4c - ? 4 - : tag_traits::inner_blks == ib::_2c8b4c - ? 8 - : 16; - - const auto &_g_oihw_d = order_keep ? input_d : output_d; - const auto &dims = input_d.dims(); - const auto &pdims = order_keep - ? output_d.padded_dims() - : input_d.padded_dims(); - - const int G = w_groups ? dims[0] : 1; - const int OC = dims[w_groups + 0]; - const int NB_OC = pdims[w_groups + 0] / blksize; - const int IC = dims[w_groups + 1]; - const int NB_IC = pdims[w_groups + 1] / blksize; - const int H = is_1d ? 1 : dims[w_groups + 2]; - const int W = dims[w_groups + 3 - is_1d]; - - const float *scales = pd->attr()->output_scales_.scales_; - const size_t D_mask = utils::array_product(input_d.dims(), - math::ilog2q(pd->attr()->output_scales_.mask_ + 1)); - - assert(output_d.extra().flags - & memory_extra_flags::compensation_conv_s8s8); - float adj_scale = - (output_d.extra().flags & memory_extra_flags::scale_adjust) - ? 
output_d.extra().scale_adjust : 1.f; - - auto ker = [&](const data_t *inp, data_t *out, - int32_t *c, const float *s, const int oc_block, const int ic_block) { -# define index AB_or_BC_blk_off::inner_blks> - - for (int ic = 0; ic < ic_block; ++ic) { - for (int oc = 0; oc < oc_block; ++oc) { - const auto _g_oihw_off = - oc * _g_oihw_d.blocking_desc().strides[w_groups + 0] - + ic * _g_oihw_d.blocking_desc().strides[w_groups + 1]; - out[index(oc, ic)] - = qz_b0, data_t>()( - inp[_g_oihw_off], s[oc] * adj_scale); - c[oc] -= (128 * (int32_t)(out[index(oc, ic)])); - } - } -# undef index - }; - - constexpr int i_mult = blksize; - constexpr int o_mult = 1; - - size_t offset = G * pdims[w_groups+0] * pdims[w_groups+1] * H * W; - int32_t *cp = reinterpret_cast(output + offset); - parallel_nd(G * NB_OC * blksize, [&](int i) { - cp[i] = 0; - }); - -# define wei_blk_off(md, g, o, i, h, w) \ - (is_1d ? (md).blk_off(g, o, i, w) \ - : (md).blk_off(g, o, i, h, w)) - - parallel_nd(G, NB_OC, [&](int g, int O) { - for (int I = 0; I < NB_IC; I++) - for (int h = 0; h < H; h++) - for (int w = 0; w < W; w++) { - auto i = &input[wei_blk_off( - input_d, g, i_mult * O, i_mult * I, h, w)]; - auto o = &output[wei_blk_off( - output_d, g, o_mult * O, o_mult * I, h, w)]; - const int oc_block = nstl::min(blksize, OC - O * blksize); - const int ic_block = nstl::min(blksize, IC - I * blksize); - - int _offset = (g * NB_OC + O) * blksize; - ker(i, o, (order_keep) ? &cp[_offset] : nullptr, - &scales[(D_mask == 1) ? 0 : _offset], - oc_block, ic_block); - } - }); - -# undef wei_blk_off - - return success; - } -}; - -template -struct simple_reorder_impl::type> -{ - static bool is_applicable(const memory_desc_wrapper &input_d, - const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { - const size_t D_mask = utils::array_product(input_d.dims(), - math::ilog2q(attr->output_scales_.mask_ + 1)); - const int oc = input_d.dims()[1]; - const int g = input_d.dims()[0]; - - return true - && order_keep - && input_d.matches_tag(tag_i) - && output_d.matches_tag(tag_o) - && (output_d.extra().flags & memory_extra_flags::compensation_conv_s8s8) - && (input_d.data_type() == f32 || input_d.data_type() == s8) - && output_d.data_type() == s8 - && (D_mask == 1 || D_mask == (size_t)g * oc); - } - - static status_t execute(const cpu_reorder_pd_t *pd, - const data_t *input, data_t *output) { - DECLARE_COMMON_PARAMS(); - - constexpr bool is_1d = tag_i == goiw; - constexpr int blksize = 16; - - const auto &dims = input_d.dims(); - const auto &pdims = output_d.padded_dims(); - const int G = dims[0]; - const int Gp = pdims[0]; - const int OC = dims[1]; - const int IC = dims[2]; - const int H = is_1d ? 1 : dims[3]; - const int W = dims[4 - is_1d]; - - const size_t D_mask = utils::array_product(input_d.dims(), - math::ilog2q(pd->attr()->output_scales_.mask_ + 1)); - const float *scales = pd->attr()->output_scales_.scales_; - - assert(output_d.extra().flags - & memory_extra_flags::compensation_conv_s8s8); - float adj_scale = - (output_d.extra().flags & memory_extra_flags::scale_adjust) - ? 
output_d.extra().scale_adjust : 1.f; - - auto ker = [&](const data_t *inp, data_t *out, - int32_t *cp, const float *s, const int g_block) { - PRAGMA_OMP_SIMD() - for (int g = 0; g < g_block; g++) { - const auto i_off = g * input_d.blocking_desc().strides[0]; - out[g] = qz_b0, data_t>()( - inp[i_off], s[g * OC] * adj_scale); - cp[g * OC] -= 128 * (int32_t)(out[g]); - } - }; - - size_t cp_offset = output_d.size() - output_d.additional_buffer_size(); - int32_t *cp = reinterpret_cast(output + cp_offset); - parallel_nd((Gp/blksize) * OC, [&](int ib) { - PRAGMA_OMP_SIMD() - for (int i = 0; i < blksize; i++) - cp[ib * blksize + i] = 0; - }); - -# define wei_blk_off(md, g, o, i, h, w) \ - (is_1d ? (md).blk_off(g, o, i, w) : (md).blk_off(g, o, i, h, w)) - - parallel_nd(Gp/blksize, OC, [&](int gb, int O) { - for (int I = 0; I < IC; I++) { - for (int h = 0; h < H; h++) - for (int w = 0; w < W; w++) - { - const int g_block = nstl::min(G - gb * blksize, blksize); - const auto inp = &input[wei_blk_off( - input_d, gb * blksize, O, I, h, w)]; - const auto out = &output[wei_blk_off( - output_d, gb, O, I, h, w)]; - int offset = gb * blksize + O; - ker(inp, out, &cp[offset], - &scales[(D_mask == 1) ? 0 : offset], g_block); - } - } - }); - -# undef wei_blk_off - - return success; - } -}; - -/* reorders with tail support */ - -template -struct simple_reorder_impl::type> -{ - static bool is_applicable(const memory_desc_wrapper &input_d, - const memory_desc_wrapper &output_d, const primitive_attr_t *attr) - { - return simple_fmt_check(order_keep, tag_i, tag_o, input_d, output_d) - && simple_attr_check(attr, false); - } - - static status_t execute(const cpu_reorder_pd_t *pd, - const data_t *input, data_t *output) { - DECLARE_COMMON_PARAMS(); - - constexpr int is_1d = tag_i == nCw8c; - constexpr int is_3d = tag_i == nCdhw8c; - constexpr int blksize_16 = 16; - constexpr int blksize_8 = 8; - constexpr int ic_mult = order_keep ? 2 : 1; - constexpr int oc_mult = order_keep ? 1 : 2; - - const auto &dims = input_d.dims(); - const auto &pdims = order_keep ? output_d.padded_dims() - : input_d.padded_dims(); - - const int C = dims[1]; - const int D = is_3d ? dims[2] : 1; - const int H = is_1d ? 1 : dims[2 + is_3d]; - const int W = dims[3 + is_3d - is_1d]; - - auto ker = [&](const data_t *i, data_t *o, - const int block_16) { - const int nb = (block_16 - 1) / blksize_8 + 1; - if (alpha == 1.0 && beta == 0.0) { - for (int b = 0; b < nb; ++b) { - const ptrdiff_t i_off = order_keep ? b : b * blksize_8; - const ptrdiff_t o_off = order_keep ? b * blksize_8 : b; - const int block_8 = nstl::min(blksize_8, - block_16 - b * blksize_8); - for (int c = 0; c < block_8; ++c) { - o[o_off + c] = _qz_a1b0()( - i[i_off + c]); - } - } - } else { - for (int b = 0; b < nb; ++b) { - const ptrdiff_t i_off = order_keep ? b : b * blksize_8; - const ptrdiff_t o_off = order_keep ? b * blksize_8 : b; - const int block_8 = nstl::min(blksize_8, - block_16 - b * blksize_8); - for (int c = 0; c < block_8; ++c) { - o[o_off + c] = _qz()(i[i_off + c], - o[o_off + c], alpha, beta); - } - } - } - }; - -# define data_blk_off(md, n, c, d, h, w) \ - ( is_1d ? (md).blk_off(n, c, w) \ - : is_3d ? 
(md).blk_off(n, c, d, h, w) : (md).blk_off(n, c, h, w)) - - parallel_nd(dims[0], pdims[1] / blksize_16, D, H, W, - [&](int n, int nb_c, int d, int h, int w) { - auto i = &input[data_blk_off(input_d, n, ic_mult * nb_c, d, h, w)]; - auto o = &output[data_blk_off(output_d, n, oc_mult * nb_c, d, h, w)]; - const int block_16 = nstl::min(blksize_16, C - nb_c * blksize_16); - ker(i, o, block_16); - }); - -# undef data_blk_off - - return success; - } -}; - -#define PLAIN_TO_BLOCKED_IS_APPLICABLE() \ - static bool is_applicable(const memory_desc_wrapper &input_d, \ - const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { \ - return simple_attr_check(attr, false) && (order_keep \ - ? output_d.matches_tag(tag_o) && input_d.is_plain() \ - : input_d.matches_tag(tag_o) && output_d.is_plain()); \ - } - -template -struct simple_reorder_impl::block_dims == bd::_A - || tag_traits::block_dims == bd::_B) - && tag_traits::ndims >= 3 - && tag_traits::ndims <= 6 - >::type> -{ - PLAIN_TO_BLOCKED_IS_APPLICABLE(); - - static status_t execute(const cpu_reorder_pd_t *pd, - const data_t *input, data_t *output) { - DECLARE_COMMON_PARAMS(); - - const auto &flat_d = order_keep ? input_d : output_d; - const auto &block_d = order_keep ? output_d : input_d; - const auto &dims = input_d.dims(); - const auto &pdims = block_d.padded_dims(); - - constexpr int ndims = tag_traits::ndims; - constexpr int blk_idx = tag_traits::block_dims == bd::_A ? 0 : 1; - - const dim_t H0 = dims[0]; - const dim_t H1 = dims[1]; - const dim_t M0 = ndims >= 6 ? dims[ndims - 4] : 1; - const dim_t M1 = ndims >= 5 ? dims[ndims - 3] : 1; - const dim_t M2 = ndims >= 4 ? dims[ndims - 2] : 1; - const dim_t L = dims[ndims - 1]; - const dim_t l_blk_stride = block_d.blocking_desc().strides[ndims - 1]; - - constexpr int blksize = false ? 0 - : utils::one_of(tag_traits::inner_blks, ib::_4a, ib::_4b) ? 4 - : utils::one_of(tag_traits::inner_blks, ib::_8a, ib::_8b) ? 8 - : 16; - - auto ker = [&](const data_t *i, data_t *o, int block) { - if (alpha == 1.0 && beta == 0.0) { - for (int l = 0; l < L; ++l) - for (int blk = 0; blk < block; ++blk) { - const dim_t flat_off = 0 - + blk * flat_d.blocking_desc().strides[blk_idx] - + l * flat_d.blocking_desc().strides[ndims - 1]; - if (order_keep) { - o[l * l_blk_stride + blk] = _qz_a1b0()( - i[flat_off]); - } else { - o[flat_off] = _qz_a1b0()( - i[l * l_blk_stride + blk]); - } - } - } else { - for (int l = 0; l < L; ++l) - for (int blk = 0; blk < block; ++blk) { - const dim_t flat_off = 0 - + blk * flat_d.blocking_desc().strides[blk_idx] - + l * flat_d.blocking_desc().strides[ndims - 1]; - if (order_keep) { - o[l * l_blk_stride + blk] = _qz()( - i[flat_off], o[l * blksize + blk], - alpha, beta); - } else { - o[flat_off] = _qz()( - i[l * l_blk_stride + blk], o[flat_off], - alpha, beta); - } - } - } - }; - -# define off(md, h0, h1, m0, m1, m2) \ - (ndims >= 6 ? (md).blk_off(h0, h1, m0, m1, m2) \ - : ndims >= 5 ? (md).blk_off(h0, h1, m1, m2) \ - : ndims >= 4 ? (md).blk_off(h0, h1, m2) \ - : /* ndims >= 3 ? */ (md).blk_off(h0, h1)) - - constexpr int i_mult = order_keep ? blksize : 1; - constexpr int o_mult = order_keep ? 
1 : blksize; - - if (blk_idx == 0) { - const dim_t BH0 = pdims[0] / blksize; - parallel_nd(BH0, H1, M0, M1, M2, - [&](dim_t bh0, dim_t h1, dim_t m0, dim_t m1, dim_t m2) { - auto i = &input[off(input_d, bh0 * i_mult, h1, m0, m1, m2)]; - auto o = &output[off(output_d, bh0 * o_mult, h1, m0, m1, m2)]; - const int block = nstl::min(blksize, H0 - bh0 * blksize); - ker(i, o, block); - }); - } else if (blk_idx == 1) { - const dim_t BH1 = pdims[1] / blksize; - parallel_nd(H0, BH1, M0, M1, M2, - [&](dim_t h0, dim_t bh1, dim_t m0, dim_t m1, dim_t m2) { - auto i = &input[off(input_d, h0, bh1 * i_mult, m0, m1, m2)]; - auto o = &output[off(output_d, h0, bh1 * o_mult, m0, m1, m2)]; - const int block = nstl::min(blksize, H1 - bh1 * blksize); - ker(i, o, block); - }); - } else { - assert(!"unimplemented"); - } - -# undef off - - return success; - } -}; - -template -struct simple_reorder_impl::block_dims == bd::_AB - || tag_traits::block_dims == bd::_BC) - && IMPLICATION(tag_traits::block_dims == bd::_AB, - tag_traits::ndims >= 3 && tag_traits::ndims <= 5) - && IMPLICATION(tag_traits::block_dims == bd::_BC, - tag_traits::ndims >= 4 && tag_traits::ndims <= 6) - >::type> -{ - PLAIN_TO_BLOCKED_IS_APPLICABLE(); - - static status_t execute(const cpu_reorder_pd_t *pd, - const data_t *input, data_t *output) { - DECLARE_COMMON_PARAMS(); - - const auto &flat_d = order_keep ? input_d : output_d; - const auto &dims = input_d.dims(); - const auto &pdims = order_keep - ? output_d.padded_dims() - : input_d.padded_dims(); - - constexpr int ndims = tag_traits::ndims; - - static constexpr bool with_g = tag_traits::block_dims == bd::_BC; - const dim_t G = with_g ? dims[0] : 1; - - const dim_t H0 = dims[0 + with_g]; - const dim_t H1 = dims[1 + with_g]; - - const dim_t M0 = ndims >= 5 + with_g ? dims[ndims - 3] : 1; - const dim_t M1 = ndims >= 4 + with_g ? dims[ndims - 2] : 1; - const dim_t M2 = ndims >= 3 + with_g ? dims[ndims - 1] : 1; - - constexpr int blksize_0 = false ? 0 - : utils::one_of(tag_traits::inner_blks, - ib::_4b4a, ib::_4b4c, ib::_4c4b) - ? 4 - : utils::one_of(tag_traits::inner_blks, - ib::_8a8b, ib::_8b8a, ib::_8b8c, ib::_8c8b, ib::_2c8b4c) - ? 8 - : utils::one_of(tag_traits::inner_blks, - ib::_16a16b, ib::_16a4b, ib::_16b16a, ib::_16b4c, - ib::_16b16c, ib::_16c16b, ib::_8a16b2a, ib::_4b16a4b, - ib::_8b16a2b, ib::_8b16c2b, ib::_4c16b4c, ib::_8c16b2c) - ? 16 : INT_MIN; - - constexpr int blksize_1 = utils::one_of(tag_traits::inner_blks, - ib::_8a8b, ib::_8b8a, ib::_8b8c, ib::_8c8b, ib::_2c8b4c) - ? 8 - : utils::one_of(tag_traits::inner_blks, - ib::_16a16b, ib::_16b16a, ib::_16b16c, ib::_16c16b, - ib::_8a16b2a, ib::_4b16a4b, ib::_8b16a2b, ib::_8b16c2b, - ib::_4c16b4c, ib::_8c16b2c) - ? 16 - : utils::one_of(tag_traits::inner_blks, - ib::_4b4a, ib::_4b4c, ib::_4c4b, - ib::_16a4b, ib::_16b4c) - ? 
4 - : INT_MIN; - - const dim_t NB_H0 = pdims[0 + with_g] / blksize_0; - const dim_t NB_H1 = pdims[1 + with_g] / blksize_1; - - auto ker = [&](const data_t *i, data_t *o, - const int block_h0, const int block_h1) { -# define blk_off AB_or_BC_blk_off::inner_blks> - - if (alpha == 1.0 && beta == 0.0) { - for (int h0 = 0; h0 < block_h0; ++h0) - for (int h1 = 0; h1 < block_h1; ++h1) { - const dim_t flat_off = 0 - + h0 * flat_d.blocking_desc().strides[with_g + 0] - + h1 * flat_d.blocking_desc().strides[with_g + 1]; - if (order_keep) { - o[blk_off(h0, h1)] = _qz_a1b0()( - i[flat_off]); - } else { - o[flat_off] = _qz_a1b0()( - i[blk_off(h0, h1)]); - } - } - } else { - for (int h0 = 0; h0 < block_h0; ++h0) - for (int h1 = 0; h1 < block_h1; ++h1) { - const dim_t flat_off = 0 - + h0 * flat_d.blocking_desc().strides[with_g + 0] - + h1 * flat_d.blocking_desc().strides[with_g + 1]; - if (order_keep) { - o[blk_off(h0, h1)] = _qz()(i[flat_off], - o[blk_off(h0, h1)], alpha, beta); - } else { - o[flat_off] = _qz()(i[blk_off(h0, h1)], - o[flat_off], alpha, beta); - } - } - } - -# undef blk_off - }; - - constexpr int i_mult_0 = order_keep ? blksize_0 : 1; - constexpr int o_mult_0 = order_keep ? 1 : blksize_0; - - constexpr int i_mult_1 = order_keep ? blksize_1 : 1; - constexpr int o_mult_1 = order_keep ? 1 : blksize_1; - -# define off(md, g, h0, h1, m0, m1, m2) \ - (ndims >= 5 + with_g ? (md).blk_off(g, h0, h1, m0, m1, m2) \ - : ndims >= 4 + with_g ? (md).blk_off(g, h0, h1, m1, m2) \ - : /* ndims >= 3 + with_g ? */ (md).blk_off(g, h0, h1, m2)) - - parallel_nd(G, NB_H0, NB_H1, M0, M1, M2, - [&](dim_t g, dim_t nb_h0, dim_t nb_h1, dim_t m0, dim_t m1, dim_t m2) { - auto i = &input[off(input_d, - g, i_mult_0 * nb_h0, i_mult_1 * nb_h1, m0, m1, m2)]; - auto o = &output[off(output_d, - g, o_mult_0 * nb_h0, o_mult_1 * nb_h1, m0, m1, m2)]; - const int block_h0 = nstl::min(blksize_0, H0 - nb_h0 * blksize_0); - const int block_h1 = nstl::min(blksize_1, H1 - nb_h1 * blksize_1); - ker(i, o, block_h0, block_h1); - }); - -# undef off - - return success; - } -}; - -/* generic and direct-copy reorders */ - -template -struct simple_reorder_impl::type> -{ - static bool is_applicable(const memory_desc_wrapper &input_d, - const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { - /* FIXME: is the formula correct? 
*/ - return input_d.similar_to(output_d, true, false, 0) - && input_d.is_dense() && output_d.is_dense() - && simple_attr_check(attr, false); - } - - static status_t execute(const cpu_reorder_pd_t *pd, - const data_t *input, data_t *output) { - DECLARE_COMMON_PARAMS(); - - assert(input_d.is_dense()); - - input += input_d.blk_off(0); - output += output_d.blk_off(0); - - const size_t nelems = input_d.nelems(); - - constexpr int block_size = 16; - const auto num_blocks = nelems / block_size; - const auto rem_elems = nelems % block_size; - - parallel(0, [&](const int ithr, const int nthr) { - size_t start{0}, end{0}; - balance211(num_blocks, nthr, ithr, start, end); - start = start * block_size; - end = end * block_size; - - if (alpha == 1.0 && beta == 0.0) { - PRAGMA_OMP_SIMD() - for (size_t e = start; e < end; ++e) { - output[e] = qz_a1b0, data_t>() - (input[e]); - } - } else if (alpha == 1.0) { - PRAGMA_OMP_SIMD() - for (size_t e = start; e < end; ++e) { - output[e] = qz_a1, data_t>() - (input[e], output[e], beta); - } - } else if (beta == 0.0) { - PRAGMA_OMP_SIMD() - for (size_t e = start; e < end; ++e) { - output[e] = qz_b0, data_t>() - (input[e], alpha); - } - } else { - PRAGMA_OMP_SIMD() - for (size_t e = start; e < end; ++e) { - output[e] = qz, data_t>() - (input[e], output[e], alpha, beta); - } - } - - if (rem_elems != 0 && ithr == nthr - 1){ - if (alpha == 1.0 && beta == 0.0) { - PRAGMA_OMP_SIMD() - for (size_t e = nelems - rem_elems; e < nelems; ++e) { - output[e] = qz_a1b0, - data_t>()(input[e]); - } - } else if (alpha == 1.0) { - PRAGMA_OMP_SIMD() - for (size_t e = nelems - rem_elems; e < nelems; ++e) { - output[e] = qz_a1, - data_t>()(input[e], output[e], beta); - } - } else if (beta == 0.0) { - PRAGMA_OMP_SIMD() - for (size_t e = nelems - rem_elems; e < nelems; ++e) { - output[e] = qz_b0, - data_t>()(input[e], alpha); - } - } else { - PRAGMA_OMP_SIMD() - for (size_t e = nelems - rem_elems; e < nelems; ++e) { - output[e] = qz, data_t>() - (input[e], output[e], alpha, beta); - } - } - } - }); - return success; - } -}; - -template -struct simple_reorder_impl::type> -{ - static bool is_applicable(const memory_desc_wrapper &input_d, - const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { - auto is_dense_no_0 = [](const memory_desc_wrapper &data_d) { - return nelems_no_dim_0(data_d) == _size_no_dim_0(data_d); - }; - /* FIXME: is the formula correct? */ - return input_d.similar_to(output_d, true, false, 1) - && is_dense_no_0(input_d) && is_dense_no_0(output_d) - && simple_attr_check(attr, false); - } - - static status_t execute(const cpu_reorder_pd_t *pd, - const data_t *input, data_t *output) { - DECLARE_COMMON_PARAMS(); - - input += input_d.blk_off(0); - output += output_d.blk_off(0); - - const int N = input_d.dims()[0]; - const dim_t is = input_d.blocking_desc().strides[0]; - const dim_t os = output_d.blocking_desc().strides[0]; - const dim_t nelems_no_d0 = nelems_no_dim_0(input_d); - const dim_t work_amount = N * nelems_no_d0; - - if (alpha == 1.0 && beta == 0.0) { - parallel(0, [&](const int ithr, const int nthr) { - dim_t n{0}, dim1_s{0}; - dim_t start{0}, end{0}; - balance211(work_amount, nthr, ithr, start, end); - nd_iterator_init(start, n, N, dim1_s, nelems_no_d0); - while(start < end) { - dim_t work_rem = end - start; - dim_t dim1_e = dim1_s + work_rem > nelems_no_d0 - ? 
nelems_no_d0 : dim1_s + work_rem; - PRAGMA_OMP_SIMD() - for (dim_t e = dim1_s; e < dim1_e; ++e) { - output[os * n + e] = _qz_a1b0()( - input[is * n + e]); - } - nd_iterator_jump(start, end, n, N, dim1_s, nelems_no_d0); - } - }); - } else { - parallel(0, [&](const int ithr, const int nthr) { - dim_t n{0}, dim1_s{0}; - dim_t start{0}, end{0}; - balance211(work_amount, nthr, ithr, start, end); - nd_iterator_init(start, n, N, dim1_s, nelems_no_d0); - while(start < end) { - dim_t work_rem = end - start; - dim_t dim1_e = - dim1_s + work_rem > nelems_no_d0 ? nelems_no_d0 - : dim1_s + work_rem; - PRAGMA_OMP_SIMD() - for (dim_t e = dim1_s; e < dim1_e; ++e){ - output[os * n + e] = _qz()( - input[is * n + e], output[os * n + e], alpha, - beta); - } - nd_iterator_jump(start, end, n, N, dim1_s, nelems_no_d0); - } - }); - } - - return success; - } - -private: - static dim_t nelems_no_dim_0(const memory_desc_wrapper &data_d) { - const int ndims = data_d.ndims(); - if (ndims <= 1) return 1; - return utils::array_product(data_d.dims() + 1, data_d.ndims() - 1); - } - - static dim_t _size_no_dim_0(const memory_desc_wrapper &data_d) { - dims_t blocks; - data_d.compute_blocks(blocks); - - const auto &blk = data_d.blocking_desc(); - - dim_t blk_size = 1; - for (int iblk = 0; iblk < blk.inner_nblks; ++iblk) - blk_size *= blk.inner_blks[iblk]; - - dim_t max_size = blk_size; - for (int d = 1; d < data_d.ndims(); ++d) { - max_size = nstl::max(max_size, - data_d.padded_dims()[d] / blocks[d] * blk.strides[d]); - } - - return max_size; - } -}; - -template -struct simple_reorder_impl::type> -{ - static bool is_applicable(const memory_desc_wrapper &input_d, - const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { - /* supported smask: 0x0...011..10...0, - * i.e. 1 should be contiguous */ - int smask = attr ? 
attr->output_scales_.mask_ : 0; - for (; smask > 0 && !(smask & 0x1); smask >>= 1); - for (; smask > 0 && smask & 0x1; smask >>= 1); - return true - && input_d.is_blocking_desc() - && output_d.is_blocking_desc() - && !output_d.is_additional_buffer() - && !input_d.is_additional_buffer() - && smask == 0; - } - - static status_t execute(const cpu_reorder_pd_t *pd, - const data_t *input, data_t *output) { - DECLARE_COMMON_PARAMS(); - - const size_t nelems = input_d.nelems(); - - int ndims_start = 0, ndims_mask = 0; - int smask = pd->attr()->output_scales_.mask_; - for (; smask > 0 && !(smask & 0x1); smask >>= 1) ++ndims_start; - for (; smask > 0 && smask & 0x1; smask >>= 1) ++ndims_mask; - assert(smask == 0); - - const ptrdiff_t D_start - = utils::array_product(input_d.dims(), ndims_start); - const ptrdiff_t D_mask - = utils::array_product(input_d.dims() + ndims_start, ndims_mask); - const ptrdiff_t D_rest = nelems / D_start / D_mask; - - const float *scales = pd->attr()->output_scales_.scales_; - - parallel_nd(D_start, D_mask, D_rest, - [&](ptrdiff_t ds, ptrdiff_t dm, ptrdiff_t dr) { - const float scale = scales[dm]; - - const size_t e = (ds * D_mask + dm) * D_rest + dr; - const auto &i = input[input_d.off_l(e)]; - auto &o = output[output_d.off_l(e)]; - - o = _qz()(i, o, scale, beta); - }); - - return success; - } -}; - - -/* high level class declaration */ - -template -struct simple_reorder_t: public cpu_primitive_t { - struct pd_t: public cpu_reorder_pd_t { - using cpu_reorder_pd_t::cpu_reorder_pd_t; - - DECLARE_COMMON_PD_T("simple:any", simple_reorder_t); - - static status_t create(reorder_pd_t **reorder_pd, - engine_t *engine, const primitive_attr_t *attr, - engine_t *src_engine, const memory_desc_t *src_md, - engine_t *dst_engine, const memory_desc_t *dst_md) { - bool args_ok = true - && src_md->data_type == type_i - && dst_md->data_type == type_o - && simple_reorder_impl:: - is_applicable(src_md, dst_md, attr); - if (!args_ok) - return status::invalid_arguments; - - auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine, - dst_md); - if (_pd == nullptr) return status::out_of_memory; - if (_pd->init() != status::success) { - delete _pd; - return status::unimplemented; - } - return safe_ptr_assign(*reorder_pd, _pd); - } - }; - - simple_reorder_t(const pd_t *apd): cpu_primitive_t(apd) {} - - virtual status_t execute(const exec_ctx_t &ctx) const override { - auto input = CTX_IN_MEM(const data_t *, MKLDNN_ARG_FROM); - auto output = CTX_OUT_MEM(data_t *, MKLDNN_ARG_TO); - simple_reorder_impl::execute( - pd(), input, output); - return status::success; - } - -private: - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } -}; - -#undef SIMPLE_REORDER_TEMPL_DECL -#undef SIMPLE_REORDER_TEMPL_CALL - -} -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/simple_sum.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/simple_sum.cpp deleted file mode 100644 index f0947573a..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/simple_sum.cpp +++ /dev/null @@ -1,91 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. 
-* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include "mkldnn_thread.hpp" - -#include "simple_sum.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -template -status_t simple_sum_t::execute(const exec_ctx_t &ctx) const { - auto output = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); - - const memory_desc_wrapper o_d(pd()->dst_md()); - output += o_d.blk_off(0); - - const int num_arrs = pd()->n_inputs(); - const data_t *input_ptrs[max_num_arrs]; - const size_t nelems = o_d.nelems(); - - for (int a = 0; a < num_arrs; ++a) { - const memory_desc_wrapper i_d(pd()->src_md(a)); - input_ptrs[a] = CTX_IN_MEM(const data_t *, MKLDNN_ARG_MULTIPLE_SRC + a) - + i_d.blk_off(0); - } - - const size_t block_size = 16 * 1024 / sizeof(data_type); - const size_t blocks_number = nelems / block_size; - const size_t tail = nelems % block_size; - - const auto scales = pd()->scales(); - parallel(0, [&](const int ithr, const int nthr) { - size_t start{0}, end{0}; - balance211(blocks_number, nthr, ithr, start, end); - - for (size_t nb = start; nb < end; ++nb) { - size_t start_e = nb * block_size; - size_t end_e = start_e + block_size; - - PRAGMA_OMP_SIMD() - for (size_t e = start_e; e < end_e; e++) { - output[e] = data_t(scales[0] * input_ptrs[0][e]); - } - for (int a = 1; a < num_arrs; a++) { - PRAGMA_OMP_SIMD() - for (size_t e = start_e; e < end_e; e++) { - output[e] += data_t(scales[a] * input_ptrs[a][e]); - } - } - } - - if (tail != 0 && ithr == nthr - 1) { - size_t start_e = nelems - tail; - size_t end_e = nelems; - - PRAGMA_OMP_SIMD() - for (size_t e = start_e; e < end_e; e++) { - output[e] = data_t(scales[0] * input_ptrs[0][e]); - } - for (int a = 1; a < num_arrs; a++) { - PRAGMA_OMP_SIMD() - for (size_t e = start_e; e < end_e; e++) { - output[e] += data_t(scales[a] * input_ptrs[a][e]); - } - } - } - }); - - return status::success; -} - -template struct simple_sum_t; - -} -} -} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/simple_sum.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/simple_sum.hpp deleted file mode 100644 index 2a0187a18..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/simple_sum.hpp +++ /dev/null @@ -1,74 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#ifndef SIMPLE_SUM_HPP -#define SIMPLE_SUM_HPP - -#include "cpu_sum_pd.hpp" -#include "cpu_primitive.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -template -struct simple_sum_t: public cpu_primitive_t { - struct pd_t: public cpu_sum_pd_t { - using cpu_sum_pd_t::cpu_sum_pd_t; - - DECLARE_SUM_PD_T("simple:any", simple_sum_t); - - status_t init() { - const int n = n_inputs(); - - bool ok = true - && cpu_sum_pd_t::init() == status::success - && n <= max_num_arrs; - if (!ok) return status::unimplemented; - - const memory_desc_wrapper o_d(dst_md()); - ok = ok - && o_d.data_type() == data_type - && o_d.is_dense(); - if (!ok) return status::unimplemented; - - for (int i = 0; i < n; ++i) { - const memory_desc_wrapper i_d(src_md(i)); - if (i_d != o_d) return status::unimplemented; - } - - return status::success; - } - }; - - simple_sum_t(const pd_t *apd): cpu_primitive_t(apd) {} - - virtual status_t execute(const exec_ctx_t &ctx) const override; - - enum {max_num_arrs = 16 }; - typedef typename prec_traits::type data_t; - -private: - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } -}; - -} -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/wino_reorder.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/wino_reorder.hpp deleted file mode 100644 index c2082d7d6..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/wino_reorder.hpp +++ /dev/null @@ -1,376 +0,0 @@ -/******************************************************************************* - * Copyright 2017-2018 Intel Corporation - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- *******************************************************************************/ - -#ifndef CPU_WINO_REORDER_HPP -#define CPU_WINO_REORDER_HPP - -#include "mkldnn_thread.hpp" - -#include "simple_q10n.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -template -struct wino_reorder_t : public cpu_primitive_t { - struct pd_t : public cpu_reorder_pd_t { - using cpu_reorder_pd_t::cpu_reorder_pd_t; - - DECLARE_COMMON_PD_T("wino_reorder", wino_reorder_t); - - static status_t create(reorder_pd_t **reorder_pd, - engine_t *engine, const primitive_attr_t *attr, - engine_t *src_engine, const memory_desc_t *src_md, - engine_t *dst_engine, const memory_desc_t *dst_md) { - const memory_desc_wrapper id(src_md), od(dst_md); - bool args_ok = true - && id.data_type() == type_i - && od.data_type() == type_o - && id.matches_tag(utils::pick(id.ndims() - 4, - format_tag::oihw, format_tag::goihw)) - && od.format_kind() == format_kind::wino - && utils::one_of(od.wino_desc().wino_format, - mkldnn_wino_wei_aaOIoi, mkldnn_wino_wei_aaOio, - mkldnn_wino_wei_aaOBiOo, mkldnn_wino_wei_OBaaIBOIio); - if (!args_ok) return status::invalid_arguments; - - auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine, - dst_md); - if (_pd == nullptr) return status::out_of_memory; - if (_pd->init() != status::success) { - delete _pd; - return status::unimplemented; - } - return safe_ptr_assign(*reorder_pd, _pd); - } - - status_t init() { - status_t status = cpu_reorder_pd_t::init(); - if (status != status::success) return status; - - init_scratchpad(); - - return status::success; - } - - private: - void init_scratchpad() { - auto &o = memory_desc_wrapper(dst_md()).wino_desc(); - size_t transform_space_size = (size_t)o.r * o.alpha * o.oc_block; - size_t plain_size = (size_t)o.alpha * o.alpha * o.oc * o.ic; - - using namespace memory_tracking::names; - auto scratchpad = scratchpad_registry().registrar(); - scratchpad.book(key_reorder_wino_transform_space, - sizeof(in_data_t) * transform_space_size); - scratchpad.book(key_reorder_wino_plain, - sizeof(out_data_t) * plain_size); - } - }; - -private: - typedef typename prec_traits::type in_data_t; - typedef typename prec_traits::type out_data_t; - const int unsign_val_in_wino_domain_ = 5; - - wino_reorder_t(const pd_t *apd): cpu_primitive_t(apd) { - const memory_desc_wrapper src_d(pd()->src_md()); - const memory_desc_wrapper dst_d(pd()->dst_md()); - - r_ = dst_d.wino_desc().r; - w_alpha_ = dst_d.wino_desc().alpha; - wino_format_ = dst_d.wino_desc().wino_format; - - const auto &in_dims = src_d.dims(); - int groups; - int groups_offset; - if (src_d.ndims() == 5) { - groups = in_dims[0]; - groups_offset = 1; - } else { - groups = 1; - groups_offset = 0; - } - assert(groups == 1); // groups are not supported now - MAYBE_UNUSED(groups); - - or_oc_ = in_dims[0 + groups_offset]; - or_ic_ = in_dims[1 + groups_offset]; - kh_ = in_dims[2 + groups_offset]; - kw_ = in_dims[3 + groups_offset]; - - oc_ = dst_d.wino_desc().oc; - ic_ = dst_d.wino_desc().ic; - oc_block_ = dst_d.wino_desc().oc_block; - ic_block_ = dst_d.wino_desc().ic_block; - assert(oc_ % oc_block_ == 0 && ic_ % ic_block_ == 0); - nb_oc_ = oc_ / oc_block_; - nb_ic_ = ic_ / ic_block_; - ic2_block_ = 1; - if (wino_format_ == mkldnn_wino_wei_OBaaIBOIio) - ic2_block_ = dst_d.wino_desc().ic2_block; - oc2_block_ = dst_d.wino_desc().oc2_block; - assert(nb_ic_ % ic2_block_ == 0 && nb_oc_ % oc2_block_ == 0); - - adj_scale_ = dst_d.wino_desc().adj_scale; - - size_wino_wei_ = w_alpha_ * w_alpha_ * oc_ * ic_; - size_wspace_ = 
r_ * w_alpha_ * oc_block_; - } - - void transform(out_data_t *__restrict tmp_wei, - const in_data_t *__restrict input, - in_data_t *__restrict wspace) const { - const memory_desc_wrapper src_d(pd()->src_md()); - - const int smask = pd()->attr()->output_scales_.mask_; - const int ndims_mask = math::ilog2q(smask + 1); - const size_t D_mask = utils::array_product(src_d.dims(), ndims_mask); - const float *__restrict scales = pd()->attr()->output_scales_.scales_; - assert(D_mask == 1 || D_mask == (size_t)oc_); - - /* transform weights to winograd domain */ - const float G_2x2_3x3[4][3] = { { 1.0, 0.0, 0.0 }, { 0.5, 0.5, 0.5 }, - { 0.5, -0.5, 0.5 }, { 0.0, 0.0, 1.0 } }; - - const float G_4x4_3x3[6][3] = { { 1.13777777777778f, 0.f, 0.f }, - { -0.688403361344538f, -0.430252100840336f, -0.26890756302521f }, - { -0.688403361344538f, 0.430252100840336f, -0.26890756302521f }, - { 0.119514472455649f, 0.179271708683473f, 0.26890756302521f }, - { 0.119514472455649f, -0.179271708683473f, 0.26890756302521f }, - { 0.f, 0.f, 1.f } }; - - float *__restrict g; - if (utils::one_of(wino_format_, mkldnn_wino_wei_aaOIoi, - mkldnn_wino_wei_aaOio, mkldnn_wino_wei_aaOBiOo)) - g = (float *)G_2x2_3x3; - else if (wino_format_ == mkldnn_wino_wei_OBaaIBOIio) - g = (float *)G_4x4_3x3; - else { - assert("Unknown winograd weights target layout"); - return; - } - - int Z = oc_ * ic_; - assert(r_ == kh_ && r_ == kw_); - - for (int iic = 0; iic < ic_; iic++) { - for (int ob = 0; ob < nb_oc_; ob++) { - const in_data_t *__restrict _inp - = input + (ob * oc_block_ * or_ic_ + iic) * kh_ * kw_; - out_data_t *__restrict _out - = tmp_wei + (iic * nb_oc_ + ob) * oc_block_; - - for_nd(0, 1, size_wspace_, [&](int i) { wspace[i] = 0.f; }); - - for_nd(0, 1, r_, w_alpha_, oc_block_, - [&](int ih, int j, int ioc) { - for (int iw = 0; iw < r_; ++iw) { - int inp_oc = ob * oc_block_ + ioc; - int inp_ic = iic; - in_data_t inp_v = (inp_ic < or_ic_ && inp_oc < or_oc_) - ? _inp[ioc * or_ic_ * kh_ * kw_ + ih * kw_ + iw] - : 0.f; - wspace[(ih * w_alpha_ + j) * oc_block_ + ioc] - += inp_v * g[j * r_ + iw]; - } - }); - - for_nd(0, 1, w_alpha_, w_alpha_, oc_block_, - [&](int i, int j, int ioc) { - float t = 0; - for (int k = 0; k < r_; ++k) - t += g[i * r_ + k] - * wspace[(k * w_alpha_ + j) * oc_block_ + ioc]; - if (type_o == data_type::s8) { - const float scale = (D_mask == 1) - ? 
scales[0] - : scales[ob * oc_block_ + ioc]; - _out[(i * w_alpha_ + j) * Z + ioc] - = qz_b0()( - (in_data_t)t, scale * adj_scale_); - } else { - _out[(i * w_alpha_ + j) * Z + ioc] = (out_data_t)t; - } - }); - }} - } - - void reorder_to_aaOIoi(out_data_t *__restrict output, - const out_data_t *__restrict tmp_wei) const { - int32_t *__restrict dst_bias = nullptr; - if (type_o == data_type::s8) { - const auto bias_shift = sizeof(out_data_t) * size_wino_wei_; - const size_t bias_size = w_alpha_ * w_alpha_ * oc_; - - dst_bias = (int32_t *)(output + bias_shift); - utils::array_set((int32_t *)dst_bias, 0, bias_size); - } - int index = 0; - for (int u_h = 0; u_h < w_alpha_; u_h++) { - for (int u_w = 0; u_w < w_alpha_; u_w++) { - for_nd(0, 1, nb_oc_, oc_block_, [&](int ob, int o) { - int u_h_shift = u_h * w_alpha_ * ic_ * oc_; - int u_w_shift = u_w * ic_ * oc_; - int u_h_shift_b = u_h * w_alpha_ * oc_; - int u_w_shift_b = u_w * oc_; - int oc_block_shift = ob * oc_block_ * ic_ + o * ic_block_; - for (int ib = 0; ib < nb_ic_; ib++) { - for (int i = 0; i < ic_block_; i++) { - int _i = ib * ic_block_; - int _o = ob * oc_block_; - int ic_shift = (_i + i) * oc_; - int oc_shift = (_o + o); - int ic_block_shift = ib * oc_block_ * ic_block_ + i; - int src_offset = - u_h_shift + u_w_shift + ic_shift + oc_shift; - int dst_offset = u_h_shift + u_w_shift + oc_block_shift - + ic_block_shift; - - output[dst_offset] = tmp_wei[src_offset]; - if (type_o == data_type::s8) { - int bias_offset = u_h_shift_b + u_w_shift_b + oc_shift; - if (index != unsign_val_in_wino_domain_) - dst_bias[bias_offset] - -= (128 * (int32_t)output[dst_offset]); - else - dst_bias[bias_offset] = 0; - } - }} - }); - index++; - }} - } - - void reorder_to_aaOio(out_data_t *__restrict output, - const out_data_t *__restrict tmp_wei) const { - for_nd(0, 1, w_alpha_, w_alpha_, nb_oc_, - [&](int u_h, int u_w, int ob) { - for (int ib = 0; ib < nb_ic_; ib++) { - for (int i = 0; i < ic_block_; i++) { - for (int o = 0; o < oc_block_; o++) { - int src_offset = u_h * w_alpha_ * ic_ * oc_ + u_w * ic_ * oc_ - + (ib * ic_block_ + i) * oc_ + (ob * oc_block_ + o); - - int dst_offset - = u_h * w_alpha_ * nb_oc_ * nb_ic_ * ic_block_ * oc_block_ - + u_w * nb_oc_ * nb_ic_ * ic_block_ * oc_block_ - + ob * nb_ic_ * ic_block_ * oc_block_ - + ib * ic_block_ * oc_block_ + i * oc_block_ + o; - output[dst_offset] = tmp_wei[src_offset]; - }}} - }); - } - - void reorder_to_aaOBiOo(out_data_t *__restrict output, - const out_data_t *__restrict tmp_wei) const { - int oc_chunks = nb_oc_ / oc2_block_; - - for_nd(0, 1, w_alpha_, w_alpha_, oc_chunks, - [&](int u_h, int u_w, int occ) { - for (int ib = 0; ib < nb_ic_; ib++) { - out_data_t *__restrict wei_ptr = output - + (((u_h * w_alpha_ + u_w) * oc_chunks + occ) * nb_ic_ + ib) - * oc2_block_ * ic_block_ * oc_block_; - int wei_offset = 0; - for (int i = 0; i < ic_block_; i++) { - for (int ob2 = 0; ob2 < oc2_block_; ob2++) { - for (int o = 0; o < oc_block_; o++) { - int icp = ib * ic_block_ + i; - int ocp = - occ * oc2_block_ * oc_block_ + ob2 * oc_block_ + o; - - int src_offset = u_h * w_alpha_ * ic_ * oc_ - + u_w * ic_ * oc_ + icp * oc_ + ocp; - wei_ptr[wei_offset + o] = tmp_wei[src_offset]; - } - wei_offset += oc_block_; - }} - } - }); - } - - void reorder_to_OBaaIBOIio(out_data_t *__restrict output, - const out_data_t *__restrict tmp_wei) const { - int ic_chunks = nb_ic_ / ic2_block_; - int oc_chunks = nb_oc_ / oc2_block_; - - for_nd(0, 1, oc_chunks, w_alpha_, w_alpha_, - [&](int occ, int u_h, int u_w) { - for (int icc = 0; icc < 
ic_chunks; icc++) { - for (int ob = 0; ob < oc2_block_; ob++) { - int ocp = (occ * oc2_block_ + ob) * oc_block_; - for (int ib = 0; ib < ic2_block_; ib++) { - for (int i = 0; i < ic_block_; i++) { - int icp = (icc * ic2_block_ + ib) * ic_block_ + i; - - int src_offset = u_h * w_alpha_ * ic_ * oc_ - + u_w * ic_ * oc_ + icp * oc_ + ocp; - int wei_offset - = ((((((occ * w_alpha_ + u_h) * w_alpha_ + u_w) - * ic_chunks + icc) * oc2_block_ + ob) * ic2_block_ - + ib) * ic_block_ + i) * oc_block_; - for (int o = 0; o < oc_block_; o++) - output[wei_offset + o] = tmp_wei[src_offset + o]; - }} - }} - }); - } - - virtual status_t execute(const exec_ctx_t &ctx) const override { - auto input = CTX_IN_MEM(const in_data_t *, MKLDNN_ARG_FROM); - auto output = CTX_OUT_MEM(out_data_t *, MKLDNN_ARG_TO); - - auto wspace = (in_data_t *__restrict)scratchpad(ctx).template get( - memory_tracking::names::key_reorder_wino_transform_space); - auto tmp_wei = (out_data_t *__restrict)scratchpad(ctx).template get( - memory_tracking::names::key_reorder_wino_plain); - - transform(tmp_wei, input, wspace); - - /* reorder to winograd domain */ - switch (wino_format_) { - case mkldnn_wino_wei_aaOIoi: - reorder_to_aaOIoi(output, tmp_wei); break; - case mkldnn_wino_wei_aaOio: - reorder_to_aaOio(output, tmp_wei); break; - case mkldnn_wino_wei_aaOBiOo: - reorder_to_aaOBiOo(output, tmp_wei); break; - case mkldnn_wino_wei_OBaaIBOIio: - reorder_to_OBaaIBOIio(output, tmp_wei); break; - default: assert("Unknown wino format"); break; - } - - return status::success; - } - - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - int r_, w_alpha_; - int ic_, oc_, or_ic_, or_oc_, kh_, kw_; - int oc_block_, ic_block_, oc2_block_, ic2_block_; - float adj_scale_; - int nb_oc_, nb_ic_; - mkldnn_wino_memory_format_t wino_format_; - int size_wino_wei_; - int size_wspace_; -}; - -} // namespace cpu -} // namespace impl -} // namespace mkldnn - -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/COPYRIGHT b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/COPYRIGHT deleted file mode 100644 index 66b6ea55d..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/COPYRIGHT +++ /dev/null @@ -1,47 +0,0 @@ - -Copyright (c) 2007 MITSUNARI Shigeo -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -Redistributions of source code must retain the above copyright notice, this -list of conditions and the following disclaimer. -Redistributions in binary form must reproduce the above copyright notice, -this list of conditions and the following disclaimer in the documentation -and/or other materials provided with the distribution. -Neither the name of the copyright owner nor the names of its contributors may -be used to endorse or promote products derived from this software without -specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE -LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF -THE POSSIBILITY OF SUCH DAMAGE. ------------------------------------------------------------------------------ -ソースコード形式かバイナリ形式か、変更するかしないかを問わず、以下の条件を満た -す場合に限り、再頒布および使用が許可されます。 - -ソースコードを再頒布する場合、上記の著作権表示、本条件一覧、および下記免責条項 -を含めること。 -バイナリ形式で再頒布する場合、頒布物に付属のドキュメント等の資料に、上記の著作 -権表示、本条件一覧、および下記免責条項を含めること。 -書面による特別の許可なしに、本ソフトウェアから派生した製品の宣伝または販売促進 -に、著作権者の名前またはコントリビューターの名前を使用してはならない。 -本ソフトウェアは、著作権者およびコントリビューターによって「現状のまま」提供さ -れており、明示黙示を問わず、商業的な使用可能性、および特定の目的に対する適合性 -に関する暗黙の保証も含め、またそれに限定されない、いかなる保証もありません。 -著作権者もコントリビューターも、事由のいかんを問わず、 損害発生の原因いかんを -問わず、かつ責任の根拠が契約であるか厳格責任であるか(過失その他の)不法行為で -あるかを問わず、仮にそのような損害が発生する可能性を知らされていたとしても、 -本ソフトウェアの使用によって発生した(代替品または代用サービスの調達、使用の -喪失、データの喪失、利益の喪失、業務の中断も含め、またそれに限定されない)直接 -損害、間接損害、偶発的な損害、特別損害、懲罰的損害、または結果損害について、 -一切責任を負わないものとします。 diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak.h b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak.h deleted file mode 100644 index cf5771332..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak.h +++ /dev/null @@ -1,2658 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2019 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -/******************************************************************************* -* Copyright (c) 2007 MITSUNARI Shigeo -* All rights reserved. -* -* Redistribution and use in source and binary forms, with or without -* modification, are permitted provided that the following conditions are met: -* -* Redistributions of source code must retain the above copyright notice, this -* list of conditions and the following disclaimer. -* Redistributions in binary form must reproduce the above copyright notice, -* this list of conditions and the following disclaimer in the documentation -* and/or other materials provided with the distribution. -* Neither the name of the copyright owner nor the names of its contributors may -* be used to endorse or promote products derived from this software without -* specific prior written permission. -* -* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -* ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE -* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF -* THE POSSIBILITY OF SUCH DAMAGE. -*******************************************************************************/ - -#pragma once -#ifndef XBYAK_XBYAK_H_ -#define XBYAK_XBYAK_H_ -/*! - @file xbyak.h - @brief Xbyak ; JIT assembler for x86(IA32)/x64 by C++ - @author herumi - @url https://github.com/herumi/xbyak - @note modified new BSD license - http://opensource.org/licenses/BSD-3-Clause -*/ -#ifndef XBYAK_NO_OP_NAMES - #if not +0 // trick to detect whether 'not' is operator or not - #error "use -fno-operator-names option if you want to use and(), or(), xor(), not() as function names, Or define XBYAK_NO_OP_NAMES and use and_(), or_(), xor_(), not_()." - #endif -#endif - -#include // for debug print -#include -#include -#include -#include -#ifndef NDEBUG -#include -#endif - -// #define XBYAK_DISABLE_AVX512 - -//#define XBYAK_USE_MMAP_ALLOCATOR -#if !defined(__GNUC__) || defined(__MINGW32__) - #undef XBYAK_USE_MMAP_ALLOCATOR -#endif - -#ifdef __GNUC__ - #define XBYAK_GNUC_PREREQ(major, minor) ((__GNUC__) * 100 + (__GNUC_MINOR__) >= (major) * 100 + (minor)) -#else - #define XBYAK_GNUC_PREREQ(major, minor) 0 -#endif - -// This covers -std=(gnu|c)++(0x|11|1y), -stdlib=libc++, and modern Microsoft. -#if ((defined(_MSC_VER) && (_MSC_VER >= 1600)) || defined(_LIBCPP_VERSION) ||\ - ((__cplusplus >= 201103) || defined(__GXX_EXPERIMENTAL_CXX0X__))) - #include - #define XBYAK_STD_UNORDERED_SET std::unordered_set - #include - #define XBYAK_STD_UNORDERED_MAP std::unordered_map - #define XBYAK_STD_UNORDERED_MULTIMAP std::unordered_multimap - -/* - Clang/llvm-gcc and ICC-EDG in 'GCC-mode' always claim to be GCC 4.2, using - libstdcxx 20070719 (from GCC 4.2.1, the last GPL 2 version). 
-*/ -#elif XBYAK_GNUC_PREREQ(4, 5) || (XBYAK_GNUC_PREREQ(4, 2) && __GLIBCXX__ >= 20070719) || defined(__INTEL_COMPILER) || defined(__llvm__) - #include - #define XBYAK_STD_UNORDERED_SET std::tr1::unordered_set - #include - #define XBYAK_STD_UNORDERED_MAP std::tr1::unordered_map - #define XBYAK_STD_UNORDERED_MULTIMAP std::tr1::unordered_multimap - -#elif defined(_MSC_VER) && (_MSC_VER >= 1500) && (_MSC_VER < 1600) - #include - #define XBYAK_STD_UNORDERED_SET std::tr1::unordered_set - #include - #define XBYAK_STD_UNORDERED_MAP std::tr1::unordered_map - #define XBYAK_STD_UNORDERED_MULTIMAP std::tr1::unordered_multimap - -#else - #include - #define XBYAK_STD_UNORDERED_SET std::set - #include - #define XBYAK_STD_UNORDERED_MAP std::map - #define XBYAK_STD_UNORDERED_MULTIMAP std::multimap -#endif -#ifdef _WIN32 - #include - #include - #include -#elif defined(__GNUC__) - #include - #include - #include -#endif -#if !defined(_MSC_VER) || (_MSC_VER >= 1600) - #include -#endif - -#if defined(_WIN64) || defined(__MINGW64__) || (defined(__CYGWIN__) && defined(__x86_64__)) - #define XBYAK64_WIN -#elif defined(__x86_64__) - #define XBYAK64_GCC -#endif -#if !defined(XBYAK64) && !defined(XBYAK32) - #if defined(XBYAK64_GCC) || defined(XBYAK64_WIN) - #define XBYAK64 - #else - #define XBYAK32 - #endif -#endif - -#if (__cplusplus >= 201103) || (_MSC_VER >= 1800) - #define XBYAK_VARIADIC_TEMPLATE -#endif - -#ifdef _MSC_VER - #pragma warning(push) - #pragma warning(disable : 4514) /* remove inline function */ - #pragma warning(disable : 4786) /* identifier is too long */ - #pragma warning(disable : 4503) /* name is too long */ - #pragma warning(disable : 4127) /* constant expresison */ -#endif - -namespace Xbyak { - -enum { - DEFAULT_MAX_CODE_SIZE = 4096, - VERSION = 0x5760 /* 0xABCD = A.BC(D) */ -}; - -#ifndef MIE_INTEGER_TYPE_DEFINED -#define MIE_INTEGER_TYPE_DEFINED -#ifdef _MSC_VER - typedef unsigned __int64 uint64; - typedef __int64 sint64; -#else - typedef uint64_t uint64; - typedef int64_t sint64; -#endif -typedef unsigned int uint32; -typedef unsigned short uint16; -typedef unsigned char uint8; -#endif - -#ifndef MIE_ALIGN - #ifdef _MSC_VER - #define MIE_ALIGN(x) __declspec(align(x)) - #else - #define MIE_ALIGN(x) __attribute__((aligned(x))) - #endif -#endif -#ifndef MIE_PACK // for shufps - #define MIE_PACK(x, y, z, w) ((x) * 64 + (y) * 16 + (z) * 4 + (w)) -#endif - -enum { - ERR_NONE = 0, - ERR_BAD_ADDRESSING, - ERR_CODE_IS_TOO_BIG, - ERR_BAD_SCALE, - ERR_ESP_CANT_BE_INDEX, - ERR_BAD_COMBINATION, - ERR_BAD_SIZE_OF_REGISTER, - ERR_IMM_IS_TOO_BIG, - ERR_BAD_ALIGN, - ERR_LABEL_IS_REDEFINED, - ERR_LABEL_IS_TOO_FAR, - ERR_LABEL_IS_NOT_FOUND, - ERR_CODE_ISNOT_COPYABLE, - ERR_BAD_PARAMETER, - ERR_CANT_PROTECT, - ERR_CANT_USE_64BIT_DISP, - ERR_OFFSET_IS_TOO_BIG, - ERR_MEM_SIZE_IS_NOT_SPECIFIED, - ERR_BAD_MEM_SIZE, - ERR_BAD_ST_COMBINATION, - ERR_OVER_LOCAL_LABEL, // not used - ERR_UNDER_LOCAL_LABEL, - ERR_CANT_ALLOC, - ERR_ONLY_T_NEAR_IS_SUPPORTED_IN_AUTO_GROW, - ERR_BAD_PROTECT_MODE, - ERR_BAD_PNUM, - ERR_BAD_TNUM, - ERR_BAD_VSIB_ADDRESSING, - ERR_CANT_CONVERT, - ERR_LABEL_ISNOT_SET_BY_L, - ERR_LABEL_IS_ALREADY_SET_BY_L, - ERR_BAD_LABEL_STR, - ERR_MUNMAP, - ERR_OPMASK_IS_ALREADY_SET, - ERR_ROUNDING_IS_ALREADY_SET, - ERR_K0_IS_INVALID, - ERR_EVEX_IS_INVALID, - ERR_SAE_IS_INVALID, - ERR_ER_IS_INVALID, - ERR_INVALID_BROADCAST, - ERR_INVALID_OPMASK_WITH_MEMORY, - ERR_INVALID_ZERO, - ERR_INVALID_RIP_IN_AUTO_GROW, - ERR_INVALID_MIB_ADDRESS, - ERR_INTERNAL, - ERR_X2APIC_IS_NOT_SUPPORTED -}; - -class Error : public 
std::exception { - int err_; -public: - explicit Error(int err) : err_(err) - { - if (err_ < 0 || err_ > ERR_INTERNAL) { - fprintf(stderr, "bad err=%d in Xbyak::Error\n", err_); - //exit(1); - } - } - operator int() const { return err_; } - const char *what() const throw() - { - static const char *errTbl[] = { - "none", - "bad addressing", - "code is too big", - "bad scale", - "esp can't be index", - "bad combination", - "bad size of register", - "imm is too big", - "bad align", - "label is redefined", - "label is too far", - "label is not found", - "code is not copyable", - "bad parameter", - "can't protect", - "can't use 64bit disp(use (void*))", - "offset is too big", - "MEM size is not specified", - "bad mem size", - "bad st combination", - "over local label", - "under local label", - "can't alloc", - "T_SHORT is not supported in AutoGrow", - "bad protect mode", - "bad pNum", - "bad tNum", - "bad vsib addressing", - "can't convert", - "label is not set by L()", - "label is already set by L()", - "bad label string", - "err munmap", - "opmask is already set", - "rounding is already set", - "k0 is invalid", - "evex is invalid", - "sae(suppress all exceptions) is invalid", - "er(embedded rounding) is invalid", - "invalid broadcast", - "invalid opmask with memory", - "invalid zero", - "invalid rip in AutoGrow", - "invalid mib address", - "internal error", - "x2APIC is not supported" - }; - assert((size_t)err_ < sizeof(errTbl) / sizeof(*errTbl)); - return errTbl[err_]; - } -}; - -inline const char *ConvertErrorToString(const Error& err) -{ - return err.what(); -} - -inline void *AlignedMalloc(size_t size, size_t alignment) -{ -#ifdef __MINGW32__ - return __mingw_aligned_malloc(size, alignment); -#elif defined(_WIN32) - return _aligned_malloc(size, alignment); -#else - void *p; - int ret = posix_memalign(&p, alignment, size); - return (ret == 0) ? 
p : 0; -#endif -} - -inline void AlignedFree(void *p) -{ -#ifdef __MINGW32__ - __mingw_aligned_free(p); -#elif defined(_MSC_VER) - _aligned_free(p); -#else - free(p); -#endif -} - -template -inline const To CastTo(From p) throw() -{ - return (const To)(size_t)(p); -} -namespace inner { - -static const size_t ALIGN_PAGE_SIZE = 4096; - -inline bool IsInDisp8(uint32 x) { return 0xFFFFFF80 <= x || x <= 0x7F; } -inline bool IsInInt32(uint64 x) { return ~uint64(0x7fffffffu) <= x || x <= 0x7FFFFFFFU; } - -inline uint32 VerifyInInt32(uint64 x) -{ -#ifdef XBYAK64 - if (!IsInInt32(x)) throw Error(ERR_OFFSET_IS_TOO_BIG); -#endif - return static_cast(x); -} - -enum LabelMode { - LasIs, // as is - Labs, // absolute - LaddTop // (addr + top) for mov(reg, label) with AutoGrow -}; - -} // inner - -/* - custom allocator -*/ -struct Allocator { - virtual uint8 *alloc(size_t size) { return reinterpret_cast(AlignedMalloc(size, inner::ALIGN_PAGE_SIZE)); } - virtual void free(uint8 *p) { AlignedFree(p); } - virtual ~Allocator() {} - /* override to return false if you call protect() manually */ - virtual bool useProtect() const { return true; } -}; - -#ifdef XBYAK_USE_MMAP_ALLOCATOR -class MmapAllocator : Allocator { - typedef XBYAK_STD_UNORDERED_MAP SizeList; - SizeList sizeList_; -public: - uint8 *alloc(size_t size) - { - const size_t alignedSizeM1 = inner::ALIGN_PAGE_SIZE - 1; - size = (size + alignedSizeM1) & ~alignedSizeM1; -#ifdef MAP_ANONYMOUS - const int mode = MAP_PRIVATE | MAP_ANONYMOUS; -#elif defined(MAP_ANON) - const int mode = MAP_PRIVATE | MAP_ANON; -#else - #error "not supported" -#endif - void *p = mmap(NULL, size, PROT_READ | PROT_WRITE, mode, -1, 0); - if (p == MAP_FAILED) throw Error(ERR_CANT_ALLOC); - assert(p); - sizeList_[(uintptr_t)p] = size; - return (uint8*)p; - } - void free(uint8 *p) - { - if (p == 0) return; - SizeList::iterator i = sizeList_.find((uintptr_t)p); - if (i == sizeList_.end()) throw Error(ERR_BAD_PARAMETER); - if (munmap((void*)i->first, i->second) < 0) throw Error(ERR_MUNMAP); - sizeList_.erase(i); - } -}; -#endif - -class Address; -class Reg; - -class Operand { - static const uint8 EXT8BIT = 0x20; - unsigned int idx_:6; // 0..31 + EXT8BIT = 1 if spl/bpl/sil/dil - unsigned int kind_:9; - unsigned int bit_:10; -protected: - unsigned int zero_:1; - unsigned int mask_:3; - unsigned int rounding_:3; - void setIdx(int idx) { idx_ = idx; } -public: - enum Kind { - NONE = 0, - MEM = 1 << 0, - REG = 1 << 1, - MMX = 1 << 2, - FPU = 1 << 3, - XMM = 1 << 4, - YMM = 1 << 5, - ZMM = 1 << 6, - OPMASK = 1 << 7, - BNDREG = 1 << 8 - }; - enum Code { -#ifdef XBYAK64 - RAX = 0, RCX, RDX, RBX, RSP, RBP, RSI, RDI, R8, R9, R10, R11, R12, R13, R14, R15, - R8D = 8, R9D, R10D, R11D, R12D, R13D, R14D, R15D, - R8W = 8, R9W, R10W, R11W, R12W, R13W, R14W, R15W, - R8B = 8, R9B, R10B, R11B, R12B, R13B, R14B, R15B, - SPL = 4, BPL, SIL, DIL, -#endif - EAX = 0, ECX, EDX, EBX, ESP, EBP, ESI, EDI, - AX = 0, CX, DX, BX, SP, BP, SI, DI, - AL = 0, CL, DL, BL, AH, CH, DH, BH - }; - Operand() : idx_(0), kind_(0), bit_(0), zero_(0), mask_(0), rounding_(0) { } - Operand(int idx, Kind kind, int bit, bool ext8bit = 0) - : idx_(static_cast(idx | (ext8bit ? 
EXT8BIT : 0))) - , kind_(kind) - , bit_(bit) - , zero_(0), mask_(0), rounding_(0) - { - assert((bit_ & (bit_ - 1)) == 0); // bit must be power of two - } - Kind getKind() const { return static_cast(kind_); } - int getIdx() const { return idx_ & (EXT8BIT - 1); } - bool isNone() const { return kind_ == 0; } - bool isMMX() const { return is(MMX); } - bool isXMM() const { return is(XMM); } - bool isYMM() const { return is(YMM); } - bool isZMM() const { return is(ZMM); } - bool isXMEM() const { return is(XMM | MEM); } - bool isYMEM() const { return is(YMM | MEM); } - bool isZMEM() const { return is(ZMM | MEM); } - bool isOPMASK() const { return is(OPMASK); } - bool isBNDREG() const { return is(BNDREG); } - bool isREG(int bit = 0) const { return is(REG, bit); } - bool isMEM(int bit = 0) const { return is(MEM, bit); } - bool isFPU() const { return is(FPU); } - bool isExt8bit() const { return (idx_ & EXT8BIT) != 0; } - bool isExtIdx() const { return (getIdx() & 8) != 0; } - bool isExtIdx2() const { return (getIdx() & 16) != 0; } - bool hasEvex() const { return isZMM() || isExtIdx2() || getOpmaskIdx() || getRounding(); } - bool hasRex() const { return isExt8bit() || isREG(64) || isExtIdx(); } - bool hasZero() const { return zero_; } - int getOpmaskIdx() const { return mask_; } - int getRounding() const { return rounding_; } - void setKind(Kind kind) - { - if ((kind & (XMM|YMM|ZMM)) == 0) return; - kind_ = kind; - bit_ = kind == XMM ? 128 : kind == YMM ? 256 : 512; - } - void setBit(int bit) { bit_ = bit; } - void setOpmaskIdx(int idx, bool ignore_idx0 = false) - { - if (!ignore_idx0 && idx == 0) throw Error(ERR_K0_IS_INVALID); - if (mask_) throw Error(ERR_OPMASK_IS_ALREADY_SET); - mask_ = idx; - } - void setRounding(int idx) - { - if (rounding_) throw Error(ERR_ROUNDING_IS_ALREADY_SET); - rounding_ = idx; - } - void setZero() { zero_ = true; } - // ah, ch, dh, bh? - bool isHigh8bit() const - { - if (!isBit(8)) return false; - if (isExt8bit()) return false; - const int idx = getIdx(); - return AH <= idx && idx <= BH; - } - // any bit is accetable if bit == 0 - bool is(int kind, uint32 bit = 0) const - { - return (kind == 0 || (kind_ & kind)) && (bit == 0 || (bit_ & bit)); // cf. you can set (8|16) - } - bool isBit(uint32 bit) const { return (bit_ & bit) != 0; } - uint32 getBit() const { return bit_; } - const char *toString() const - { - const int idx = getIdx(); - if (kind_ == REG) { - if (isExt8bit()) { - static const char *tbl[4] = { "spl", "bpl", "sil", "dil" }; - return tbl[idx - 4]; - } - static const char *tbl[4][16] = { - { "al", "cl", "dl", "bl", "ah", "ch", "dh", "bh", "r8b", "r9b", "r10b", "r11b", "r12b", "r13b", "r14b", "r15b" }, - { "ax", "cx", "dx", "bx", "sp", "bp", "si", "di", "r8w", "r9w", "r10w", "r11w", "r12w", "r13w", "r14w", "r15w" }, - { "eax", "ecx", "edx", "ebx", "esp", "ebp", "esi", "edi", "r8d", "r9d", "r10d", "r11d", "r12d", "r13d", "r14d", "r15d" }, - { "rax", "rcx", "rdx", "rbx", "rsp", "rbp", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15" }, - }; - return tbl[bit_ == 8 ? 0 : bit_ == 16 ? 1 : bit_ == 32 ? 
2 : 3][idx]; - } else if (isOPMASK()) { - static const char *tbl[8] = { "k0", "k1", "k2", "k3", "k4", "k5", "k6", "k7" }; - return tbl[idx]; - } else if (isZMM()) { - static const char *tbl[32] = { - "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", "zmm6", "zmm7", "zmm8", "zmm9", "zmm10", "zmm11", "zmm12", "zmm13", "zmm14", "zmm15", - "zmm16", "zmm17", "zmm18", "zmm19", "zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29", "zmm30", "zmm31" - }; - return tbl[idx]; - } else if (isYMM()) { - static const char *tbl[32] = { - "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", "ymm15", - "ymm16", "ymm17", "ymm18", "ymm19", "ymm20", "ymm21", "ymm22", "ymm23", "ymm24", "ymm25", "ymm26", "ymm27", "ymm28", "ymm29", "ymm30", "ymm31" - }; - return tbl[idx]; - } else if (isXMM()) { - static const char *tbl[32] = { - "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", - "xmm16", "xmm17", "xmm18", "xmm19", "xmm20", "xmm21", "xmm22", "xmm23", "xmm24", "xmm25", "xmm26", "xmm27", "xmm28", "xmm29", "xmm30", "xmm31" - }; - return tbl[idx]; - } else if (isMMX()) { - static const char *tbl[8] = { "mm0", "mm1", "mm2", "mm3", "mm4", "mm5", "mm6", "mm7" }; - return tbl[idx]; - } else if (isFPU()) { - static const char *tbl[8] = { "st0", "st1", "st2", "st3", "st4", "st5", "st6", "st7" }; - return tbl[idx]; - } else if (isBNDREG()) { - static const char *tbl[4] = { "bnd0", "bnd1", "bnd2", "bnd3" }; - return tbl[idx]; - } - throw Error(ERR_INTERNAL); - } - bool isEqualIfNotInherited(const Operand& rhs) const { return idx_ == rhs.idx_ && kind_ == rhs.kind_ && bit_ == rhs.bit_ && zero_ == rhs.zero_ && mask_ == rhs.mask_ && rounding_ == rhs.rounding_; } - bool operator==(const Operand& rhs) const; - bool operator!=(const Operand& rhs) const { return !operator==(rhs); } - const Address& getAddress() const; - const Reg& getReg() const; -}; - -class Label; - -struct Reg8; -struct Reg16; -struct Reg32; -#ifdef XBYAK64 -struct Reg64; -#endif -class Reg : public Operand { -public: - Reg() { } - Reg(int idx, Kind kind, int bit = 0, bool ext8bit = false) : Operand(idx, kind, bit, ext8bit) { } - Reg changeBit(int bit) const { return Reg(getIdx(), getKind(), bit, isExt8bit()); } - uint8 getRexW() const { return isREG(64) ? 8 : 0; } - uint8 getRexR() const { return isExtIdx() ? 4 : 0; } - uint8 getRexX() const { return isExtIdx() ? 2 : 0; } - uint8 getRexB() const { return isExtIdx() ? 
1 : 0; } - uint8 getRex(const Reg& base = Reg()) const - { - uint8 rex = getRexW() | getRexR() | base.getRexW() | base.getRexB(); - if (rex || isExt8bit() || base.isExt8bit()) rex |= 0x40; - return rex; - } - Reg8 cvt8() const; - Reg16 cvt16() const; - Reg32 cvt32() const; -#ifdef XBYAK64 - Reg64 cvt64() const; -#endif -}; - -inline const Reg& Operand::getReg() const -{ - assert(!isMEM()); - return static_cast(*this); -} - -struct Reg8 : public Reg { - explicit Reg8(int idx = 0, bool ext8bit = false) : Reg(idx, Operand::REG, 8, ext8bit) { } -}; - -struct Reg16 : public Reg { - explicit Reg16(int idx = 0) : Reg(idx, Operand::REG, 16) { } -}; - -struct Mmx : public Reg { - explicit Mmx(int idx = 0, Kind kind = Operand::MMX, int bit = 64) : Reg(idx, kind, bit) { } -}; - -struct EvexModifierRounding { - enum { - T_RN_SAE = 1, - T_RD_SAE = 2, - T_RU_SAE = 3, - T_RZ_SAE = 4, - T_SAE = 5 - }; - explicit EvexModifierRounding(int rounding) : rounding(rounding) {} - int rounding; -}; -struct EvexModifierZero{EvexModifierZero() {}}; - -struct Xmm : public Mmx { - explicit Xmm(int idx = 0, Kind kind = Operand::XMM, int bit = 128) : Mmx(idx, kind, bit) { } - Xmm(Kind kind, int idx) : Mmx(idx, kind, kind == XMM ? 128 : kind == YMM ? 256 : 512) { } - Xmm operator|(const EvexModifierRounding& emr) const { Xmm r(*this); r.setRounding(emr.rounding); return r; } - Xmm copyAndSetIdx(int idx) const { Xmm ret(*this); ret.setIdx(idx); return ret; } - Xmm copyAndSetKind(Operand::Kind kind) const { Xmm ret(*this); ret.setKind(kind); return ret; } -}; - -struct Ymm : public Xmm { - explicit Ymm(int idx = 0, Kind kind = Operand::YMM, int bit = 256) : Xmm(idx, kind, bit) { } - Ymm operator|(const EvexModifierRounding& emr) const { Ymm r(*this); r.setRounding(emr.rounding); return r; } -}; - -struct Zmm : public Ymm { - explicit Zmm(int idx = 0) : Ymm(idx, Operand::ZMM, 512) { } - Zmm operator|(const EvexModifierRounding& emr) const { Zmm r(*this); r.setRounding(emr.rounding); return r; } -}; - -struct Opmask : public Reg { - explicit Opmask(int idx = 0) : Reg(idx, Operand::OPMASK, 64) {} -}; - -struct BoundsReg : public Reg { - explicit BoundsReg(int idx = 0) : Reg(idx, Operand::BNDREG, 128) {} -}; - -templateT operator|(const T& x, const Opmask& k) { T r(x); r.setOpmaskIdx(k.getIdx()); return r; } -templateT operator|(const T& x, const EvexModifierZero&) { T r(x); r.setZero(); return r; } -templateT operator|(const T& x, const EvexModifierRounding& emr) { T r(x); r.setRounding(emr.rounding); return r; } - -struct Fpu : public Reg { - explicit Fpu(int idx = 0) : Reg(idx, Operand::FPU, 32) { } -}; - -struct Reg32e : public Reg { - explicit Reg32e(int idx, int bit) : Reg(idx, Operand::REG, bit) {} -}; -struct Reg32 : public Reg32e { - explicit Reg32(int idx = 0) : Reg32e(idx, 32) {} -}; -#ifdef XBYAK64 -struct Reg64 : public Reg32e { - explicit Reg64(int idx = 0) : Reg32e(idx, 64) {} -}; -struct RegRip { - sint64 disp_; - const Label* label_; - bool isAddr_; - explicit RegRip(sint64 disp = 0, const Label* label = 0, bool isAddr = false) : disp_(disp), label_(label), isAddr_(isAddr) {} - friend const RegRip operator+(const RegRip& r, int disp) { - return RegRip(r.disp_ + disp, r.label_, r.isAddr_); - } - friend const RegRip operator-(const RegRip& r, int disp) { - return RegRip(r.disp_ - disp, r.label_, r.isAddr_); - } - friend const RegRip operator+(const RegRip& r, sint64 disp) { - return RegRip(r.disp_ + disp, r.label_, r.isAddr_); - } - friend const RegRip operator-(const RegRip& r, sint64 disp) { - return 
RegRip(r.disp_ - disp, r.label_, r.isAddr_); - } - friend const RegRip operator+(const RegRip& r, const Label& label) { - if (r.label_ || r.isAddr_) throw Error(ERR_BAD_ADDRESSING); - return RegRip(r.disp_, &label); - } - friend const RegRip operator+(const RegRip& r, const void *addr) { - if (r.label_ || r.isAddr_) throw Error(ERR_BAD_ADDRESSING); - return RegRip(r.disp_ + (sint64)addr, 0, true); - } -}; -#endif - -inline Reg8 Reg::cvt8() const -{ - const int idx = getIdx(); - if (isBit(8)) return Reg8(idx, isExt8bit()); -#ifdef XBYAK32 - if (idx >= 4) throw Error(ERR_CANT_CONVERT); -#endif - return Reg8(idx, 4 <= idx && idx < 8); -} - -inline Reg16 Reg::cvt16() const -{ - const int idx = getIdx(); - if (isBit(8) && (4 <= idx && idx < 8) && !isExt8bit()) throw Error(ERR_CANT_CONVERT); - return Reg16(idx); -} - -inline Reg32 Reg::cvt32() const -{ - const int idx = getIdx(); - if (isBit(8) && (4 <= idx && idx < 8) && !isExt8bit()) throw Error(ERR_CANT_CONVERT); - return Reg32(idx); -} - -#ifdef XBYAK64 -inline Reg64 Reg::cvt64() const -{ - const int idx = getIdx(); - if (isBit(8) && (4 <= idx && idx < 8) && !isExt8bit()) throw Error(ERR_CANT_CONVERT); - return Reg64(idx); -} -#endif - -#ifndef XBYAK_DISABLE_SEGMENT -// not derived from Reg -class Segment { - int idx_; -public: - enum { - es, cs, ss, ds, fs, gs - }; - explicit Segment(int idx) : idx_(idx) { assert(0 <= idx_ && idx_ < 6); } - int getIdx() const { return idx_; } - const char *toString() const - { - static const char tbl[][3] = { - "es", "cs", "ss", "ds", "fs", "gs" - }; - return tbl[idx_]; - } -}; -#endif - -class RegExp { -public: -#ifdef XBYAK64 - enum { i32e = 32 | 64 }; -#else - enum { i32e = 32 }; -#endif - RegExp(size_t disp = 0) : scale_(0), disp_(disp) { } - RegExp(const Reg& r, int scale = 1) - : scale_(scale) - , disp_(0) - { - if (!r.isREG(i32e) && !r.is(Reg::XMM|Reg::YMM|Reg::ZMM)) throw Error(ERR_BAD_SIZE_OF_REGISTER); - if (scale == 0) return; - if (scale != 1 && scale != 2 && scale != 4 && scale != 8) throw Error(ERR_BAD_SCALE); - if (r.getBit() >= 128 || scale != 1) { // xmm/ymm is always index - index_ = r; - } else { - base_ = r; - } - } - bool isVsib(int bit = 128 | 256 | 512) const { return index_.isBit(bit); } - RegExp optimize() const - { - RegExp exp = *this; - // [reg * 2] => [reg + reg] - if (index_.isBit(i32e) && !base_.getBit() && scale_ == 2) { - exp.base_ = index_; - exp.scale_ = 1; - } - return exp; - } - bool operator==(const RegExp& rhs) const - { - return base_ == rhs.base_ && index_ == rhs.index_ && disp_ == rhs.disp_ && scale_ == rhs.scale_; - } - const Reg& getBase() const { return base_; } - const Reg& getIndex() const { return index_; } - int getScale() const { return scale_; } - size_t getDisp() const { return disp_; } - void verify() const - { - if (base_.getBit() >= 128) throw Error(ERR_BAD_SIZE_OF_REGISTER); - if (index_.getBit() && index_.getBit() <= 64) { - if (index_.getIdx() == Operand::ESP) throw Error(ERR_ESP_CANT_BE_INDEX); - if (base_.getBit() && base_.getBit() != index_.getBit()) throw Error(ERR_BAD_SIZE_OF_REGISTER); - } - } - friend RegExp operator+(const RegExp& a, const RegExp& b); - friend RegExp operator-(const RegExp& e, size_t disp); - uint8 getRex() const - { - uint8 rex = index_.getRexX() | base_.getRexB(); - return rex ? 
uint8(rex | 0x40) : 0; - } -private: - /* - [base_ + index_ * scale_ + disp_] - base : Reg32e, index : Reg32e(w/o esp), Xmm, Ymm - */ - Reg base_; - Reg index_; - int scale_; - size_t disp_; -}; - -inline RegExp operator+(const RegExp& a, const RegExp& b) -{ - if (a.index_.getBit() && b.index_.getBit()) throw Error(ERR_BAD_ADDRESSING); - RegExp ret = a; - if (!ret.index_.getBit()) { ret.index_ = b.index_; ret.scale_ = b.scale_; } - if (b.base_.getBit()) { - if (ret.base_.getBit()) { - if (ret.index_.getBit()) throw Error(ERR_BAD_ADDRESSING); - // base + base => base + index * 1 - ret.index_ = b.base_; - // [reg + esp] => [esp + reg] - if (ret.index_.getIdx() == Operand::ESP) std::swap(ret.base_, ret.index_); - ret.scale_ = 1; - } else { - ret.base_ = b.base_; - } - } - ret.disp_ += b.disp_; - return ret; -} -inline RegExp operator*(const Reg& r, int scale) -{ - return RegExp(r, scale); -} -inline RegExp operator-(const RegExp& e, size_t disp) -{ - RegExp ret = e; - ret.disp_ -= disp; - return ret; -} - -// 2nd parameter for constructor of CodeArray(maxSize, userPtr, alloc) -void *const AutoGrow = (void*)1; //-V566 -void *const DontSetProtectRWE = (void*)2; //-V566 - -class CodeArray { - enum Type { - USER_BUF = 1, // use userPtr(non alignment, non protect) - ALLOC_BUF, // use new(alignment, protect) - AUTO_GROW // automatically move and grow memory if necessary - }; - CodeArray(const CodeArray& rhs); - void operator=(const CodeArray&); - bool isAllocType() const { return type_ == ALLOC_BUF || type_ == AUTO_GROW; } - struct AddrInfo { - size_t codeOffset; // position to write - size_t jmpAddr; // value to write - int jmpSize; // size of jmpAddr - inner::LabelMode mode; - AddrInfo(size_t _codeOffset, size_t _jmpAddr, int _jmpSize, inner::LabelMode _mode) - : codeOffset(_codeOffset), jmpAddr(_jmpAddr), jmpSize(_jmpSize), mode(_mode) {} - uint64 getVal(const uint8 *top) const - { - uint64 disp = (mode == inner::LaddTop) ? jmpAddr + size_t(top) : (mode == inner::LasIs) ? jmpAddr : jmpAddr - size_t(top); - if (jmpSize == 4) disp = inner::VerifyInInt32(disp); - return disp; - } - }; - typedef std::list AddrInfoList; - AddrInfoList addrInfoList_; - const Type type_; -#ifdef XBYAK_USE_MMAP_ALLOCATOR - MmapAllocator defaultAllocator_; -#else - Allocator defaultAllocator_; -#endif - Allocator *alloc_; -protected: - size_t maxSize_; - uint8 *top_; - size_t size_; - bool isCalledCalcJmpAddress_; - - bool useProtect() const { return alloc_->useProtect(); } - /* - allocate new memory and copy old data to the new area - */ - void growMemory() - { - const size_t newSize = (std::max)(DEFAULT_MAX_CODE_SIZE, maxSize_ * 2); - uint8 *newTop = alloc_->alloc(newSize); - if (newTop == 0) throw Error(ERR_CANT_ALLOC); - for (size_t i = 0; i < size_; i++) newTop[i] = top_[i]; - alloc_->free(top_); - top_ = newTop; - maxSize_ = newSize; - } - /* - calc jmp address for AutoGrow mode - */ - void calcJmpAddress() - { - if (isCalledCalcJmpAddress_) return; - for (AddrInfoList::const_iterator i = addrInfoList_.begin(), ie = addrInfoList_.end(); i != ie; ++i) { - uint64 disp = i->getVal(top_); - rewrite(i->codeOffset, disp, i->jmpSize); - } - isCalledCalcJmpAddress_ = true; - } -public: - enum ProtectMode { - PROTECT_RW = 0, // read/write - PROTECT_RWE = 1, // read/write/exec - PROTECT_RE = 2 // read/exec - }; - explicit CodeArray(size_t maxSize, void *userPtr = 0, Allocator *allocator = 0) - : type_(userPtr == AutoGrow ? AUTO_GROW : (userPtr == 0 || userPtr == DontSetProtectRWE) ? 
ALLOC_BUF : USER_BUF) - , alloc_(allocator ? allocator : (Allocator*)&defaultAllocator_) - , maxSize_(maxSize) - , top_(type_ == USER_BUF ? reinterpret_cast(userPtr) : alloc_->alloc((std::max)(maxSize, 1))) - , size_(0) - , isCalledCalcJmpAddress_(false) - { - if (maxSize_ > 0 && top_ == 0) throw Error(ERR_CANT_ALLOC); - if ((type_ == ALLOC_BUF && userPtr != DontSetProtectRWE && useProtect()) && !setProtectMode(PROTECT_RWE, false)) { - alloc_->free(top_); - throw Error(ERR_CANT_PROTECT); - } - } - virtual ~CodeArray() - { - if (isAllocType()) { - if (useProtect()) setProtectModeRW(false); - alloc_->free(top_); - } - } - bool setProtectMode(ProtectMode mode, bool throwException = true) - { - bool isOK = protect(top_, maxSize_, mode); - if (isOK) return true; - if (throwException) throw Error(ERR_CANT_PROTECT); - return false; - } - bool setProtectModeRE(bool throwException = true) { return setProtectMode(PROTECT_RE, throwException); } - bool setProtectModeRW(bool throwException = true) { return setProtectMode(PROTECT_RW, throwException); } - void resetSize() - { - size_ = 0; - addrInfoList_.clear(); - isCalledCalcJmpAddress_ = false; - } - void db(int code) - { - if (size_ >= maxSize_) { - if (type_ == AUTO_GROW) { - growMemory(); - } else { - throw Error(ERR_CODE_IS_TOO_BIG); - } - } - top_[size_++] = static_cast(code); - } - void db(const uint8 *code, size_t codeSize) - { - for (size_t i = 0; i < codeSize; i++) db(code[i]); - } - void db(uint64 code, size_t codeSize) - { - if (codeSize > 8) throw Error(ERR_BAD_PARAMETER); - for (size_t i = 0; i < codeSize; i++) db(static_cast(code >> (i * 8))); - } - void dw(uint32 code) { db(code, 2); } - void dd(uint32 code) { db(code, 4); } - void dq(uint64 code) { db(code, 8); } - const uint8 *getCode() const { return top_; } - template - const F getCode() const { return reinterpret_cast(top_); } - const uint8 *getCurr() const { return &top_[size_]; } - template - const F getCurr() const { return reinterpret_cast(&top_[size_]); } - size_t getSize() const { return size_; } - void setSize(size_t size) - { - if (size > maxSize_) throw Error(ERR_OFFSET_IS_TOO_BIG); - size_ = size; - } - void dump() const - { - const uint8 *p = getCode(); - size_t bufSize = getSize(); - size_t remain = bufSize; - for (int i = 0; i < 4; i++) { - size_t disp = 16; - if (remain < 16) { - disp = remain; - } - for (size_t j = 0; j < 16; j++) { - if (j < disp) { - printf("%02X", p[i * 16 + j]); - } - } - putchar('\n'); - remain -= disp; - if (remain == 0) { - break; - } - } - } - /* - @param offset [in] offset from top - @param disp [in] offset from the next of jmp - @param size [in] write size(1, 2, 4, 8) - */ - void rewrite(size_t offset, uint64 disp, size_t size) - { - assert(offset < maxSize_); - if (size != 1 && size != 2 && size != 4 && size != 8) throw Error(ERR_BAD_PARAMETER); - uint8 *const data = top_ + offset; - for (size_t i = 0; i < size; i++) { - data[i] = static_cast(disp >> (i * 8)); - } - } - void save(size_t offset, size_t val, int size, inner::LabelMode mode) - { - addrInfoList_.push_back(AddrInfo(offset, val, size, mode)); - } - bool isAutoGrow() const { return type_ == AUTO_GROW; } - bool isCalledCalcJmpAddress() const { return isCalledCalcJmpAddress_; } - /** - change exec permission of memory - @param addr [in] buffer address - @param size [in] buffer size - @param protectMode [in] mode(RW/RWE/RE) - @return true(success), false(failure) - */ - static inline bool protect(const void *addr, size_t size, int protectMode) - { -#if defined(_WIN32) - const DWORD 
c_rw = PAGE_READWRITE; - const DWORD c_rwe = PAGE_EXECUTE_READWRITE; - const DWORD c_re = PAGE_EXECUTE_READ; - DWORD mode; -#else - const int c_rw = PROT_READ | PROT_WRITE; - const int c_rwe = PROT_READ | PROT_WRITE | PROT_EXEC; - const int c_re = PROT_READ | PROT_EXEC; - int mode; -#endif - switch (protectMode) { - case PROTECT_RW: mode = c_rw; break; - case PROTECT_RWE: mode = c_rwe; break; - case PROTECT_RE: mode = c_re; break; - default: - return false; - } -#if defined(_WIN32) - DWORD oldProtect; - return VirtualProtect(const_cast(addr), size, mode, &oldProtect) != 0; -#elif defined(__GNUC__) - size_t pageSize = sysconf(_SC_PAGESIZE); - size_t iaddr = reinterpret_cast(addr); - size_t roundAddr = iaddr & ~(pageSize - static_cast(1)); -#ifndef NDEBUG - if (pageSize != 4096) fprintf(stderr, "large page(%zd) is used. not tested enough.\n", pageSize); -#endif - return mprotect(reinterpret_cast(roundAddr), size + (iaddr - roundAddr), mode) == 0; -#else - return true; -#endif - } - /** - get aligned memory pointer - @param addr [in] address - @param alignedSize [in] power of two - @return aligned addr by alingedSize - */ - static inline uint8 *getAlignedAddress(uint8 *addr, size_t alignedSize = 16) - { - return reinterpret_cast((reinterpret_cast(addr) + alignedSize - 1) & ~(alignedSize - static_cast(1))); - } -}; - -class Address : public Operand { -public: - enum Mode { - M_ModRM, - M_64bitDisp, - M_rip, - M_ripAddr - }; - Address(uint32 sizeBit, bool broadcast, const RegExp& e) - : Operand(0, MEM, sizeBit), e_(e), label_(0), mode_(M_ModRM), broadcast_(broadcast) - { - e_.verify(); - } -#ifdef XBYAK64 - explicit Address(size_t disp) - : Operand(0, MEM, 64), e_(disp), label_(0), mode_(M_64bitDisp), broadcast_(false){ } - Address(uint32 sizeBit, bool broadcast, const RegRip& addr) - : Operand(0, MEM, sizeBit), e_(addr.disp_), label_(addr.label_), mode_(addr.isAddr_ ? M_ripAddr : M_rip), broadcast_(broadcast) { } -#endif - RegExp getRegExp(bool optimize = true) const - { - return optimize ? 
e_.optimize() : e_; - } - Mode getMode() const { return mode_; } - bool is32bit() const { return e_.getBase().getBit() == 32 || e_.getIndex().getBit() == 32; } - bool isOnlyDisp() const { return !e_.getBase().getBit() && !e_.getIndex().getBit(); } // for mov eax - size_t getDisp() const { return e_.getDisp(); } - uint8 getRex() const - { - if (mode_ != M_ModRM) return 0; - return getRegExp().getRex(); - } - bool is64bitDisp() const { return mode_ == M_64bitDisp; } // for moffset - bool isBroadcast() const { return broadcast_; } - const Label* getLabel() const { return label_; } - bool operator==(const Address& rhs) const - { - return getBit() == rhs.getBit() && e_ == rhs.e_ && label_ == rhs.label_ && mode_ == rhs.mode_ && broadcast_ == rhs.broadcast_; - } - bool operator!=(const Address& rhs) const { return !operator==(rhs); } - bool isVsib() const { return e_.isVsib(); } -private: - RegExp e_; - const Label* label_; - Mode mode_; - bool broadcast_; -}; - -inline const Address& Operand::getAddress() const -{ - assert(isMEM()); - return static_cast(*this); -} - -inline bool Operand::operator==(const Operand& rhs) const -{ - if (isMEM() && rhs.isMEM()) return this->getAddress() == rhs.getAddress(); - return isEqualIfNotInherited(rhs); -} - -class AddressFrame { - void operator=(const AddressFrame&); - AddressFrame(const AddressFrame&); -public: - const uint32 bit_; - const bool broadcast_; - explicit AddressFrame(uint32 bit, bool broadcast = false) : bit_(bit), broadcast_(broadcast) { } - Address operator[](const RegExp& e) const - { - return Address(bit_, broadcast_, e); - } - Address operator[](const void *disp) const - { - return Address(bit_, broadcast_, RegExp(reinterpret_cast(disp))); - } -#ifdef XBYAK64 - Address operator[](uint64 disp) const { return Address(disp); } - Address operator[](const RegRip& addr) const { return Address(bit_, broadcast_, addr); } -#endif -}; - -struct JmpLabel { - size_t endOfJmp; /* offset from top to the end address of jmp */ - int jmpSize; - inner::LabelMode mode; - size_t disp; // disp for [rip + disp] - explicit JmpLabel(size_t endOfJmp = 0, int jmpSize = 0, inner::LabelMode mode = inner::LasIs, size_t disp = 0) - : endOfJmp(endOfJmp), jmpSize(jmpSize), mode(mode), disp(disp) - { - } -}; - -class LabelManager; - -class Label { - mutable LabelManager *mgr; - mutable int id; - friend class LabelManager; -public: - Label() : mgr(0), id(0) {} - Label(const Label& rhs); - Label& operator=(const Label& rhs); - ~Label(); - void clear() { mgr = 0; id = 0; } - int getId() const { return id; } - const uint8 *getAddress() const; - - // backward compatibility - static inline std::string toStr(int num) - { - char buf[16]; -#if defined(_MSC_VER) && (_MSC_VER < 1900) - _snprintf_s -#else - snprintf -#endif - (buf, sizeof(buf), ".%08x", num); - return buf; - } -}; - -class LabelManager { - // for string label - struct SlabelVal { - size_t offset; - SlabelVal(size_t offset) : offset(offset) {} - }; - typedef XBYAK_STD_UNORDERED_MAP SlabelDefList; - typedef XBYAK_STD_UNORDERED_MULTIMAP SlabelUndefList; - struct SlabelState { - SlabelDefList defList; - SlabelUndefList undefList; - }; - typedef std::list StateList; - // for Label class - struct ClabelVal { - ClabelVal(size_t offset = 0) : offset(offset), refCount(1) {} - size_t offset; - int refCount; - }; - typedef XBYAK_STD_UNORDERED_MAP ClabelDefList; - typedef XBYAK_STD_UNORDERED_MULTIMAP ClabelUndefList; - typedef XBYAK_STD_UNORDERED_SET LabelPtrList; - - CodeArray *base_; - // global : stateList_.front(), local : 
stateList_.back() - StateList stateList_; - mutable int labelId_; - ClabelDefList clabelDefList_; - ClabelUndefList clabelUndefList_; - LabelPtrList labelPtrList_; - - int getId(const Label& label) const - { - if (label.id == 0) label.id = labelId_++; - return label.id; - } - template - void define_inner(DefList& defList, UndefList& undefList, const T& labelId, size_t addrOffset) - { - // add label - typename DefList::value_type item(labelId, addrOffset); - std::pair ret = defList.insert(item); - if (!ret.second) throw Error(ERR_LABEL_IS_REDEFINED); - // search undefined label - for (;;) { - typename UndefList::iterator itr = undefList.find(labelId); - if (itr == undefList.end()) break; - const JmpLabel *jmp = &itr->second; - const size_t offset = jmp->endOfJmp - jmp->jmpSize; - size_t disp; - if (jmp->mode == inner::LaddTop) { - disp = addrOffset; - } else if (jmp->mode == inner::Labs) { - disp = size_t(base_->getCurr()); - } else { - disp = addrOffset - jmp->endOfJmp + jmp->disp; -#ifdef XBYAK64 - if (jmp->jmpSize <= 4 && !inner::IsInInt32(disp)) throw Error(ERR_OFFSET_IS_TOO_BIG); -#endif - if (jmp->jmpSize == 1 && !inner::IsInDisp8((uint32)disp)) throw Error(ERR_LABEL_IS_TOO_FAR); - } - if (base_->isAutoGrow()) { - base_->save(offset, disp, jmp->jmpSize, jmp->mode); - } else { - base_->rewrite(offset, disp, jmp->jmpSize); - } - undefList.erase(itr); - } - } - template - bool getOffset_inner(const DefList& defList, size_t *offset, const T& label) const - { - typename DefList::const_iterator i = defList.find(label); - if (i == defList.end()) return false; - *offset = i->second.offset; - return true; - } - friend class Label; - void incRefCount(int id, Label *label) - { - clabelDefList_[id].refCount++; - labelPtrList_.insert(label); - } - void decRefCount(int id, Label *label) - { - labelPtrList_.erase(label); - ClabelDefList::iterator i = clabelDefList_.find(id); - if (i == clabelDefList_.end()) return; - if (i->second.refCount == 1) { - clabelDefList_.erase(id); - } else { - --i->second.refCount; - } - } - template - bool hasUndefinedLabel_inner(const T& list) const - { -#ifndef NDEBUG - for (typename T::const_iterator i = list.begin(); i != list.end(); ++i) { - std::cerr << "undefined label:" << i->first << std::endl; - } -#endif - return !list.empty(); - } - // detach all labels linked to LabelManager - void resetLabelPtrList() - { - for (LabelPtrList::iterator i = labelPtrList_.begin(), ie = labelPtrList_.end(); i != ie; ++i) { - (*i)->clear(); - } - labelPtrList_.clear(); - } -public: - LabelManager() - { - reset(); - } - ~LabelManager() - { - resetLabelPtrList(); - } - void reset() - { - base_ = 0; - labelId_ = 1; - stateList_.clear(); - stateList_.push_back(SlabelState()); - stateList_.push_back(SlabelState()); - clabelDefList_.clear(); - clabelUndefList_.clear(); - resetLabelPtrList(); - } - void enterLocal() - { - stateList_.push_back(SlabelState()); - } - void leaveLocal() - { - if (stateList_.size() <= 2) throw Error(ERR_UNDER_LOCAL_LABEL); - if (hasUndefinedLabel_inner(stateList_.back().undefList)) throw Error(ERR_LABEL_IS_NOT_FOUND); - stateList_.pop_back(); - } - void set(CodeArray *base) { base_ = base; } - void defineSlabel(std::string label) - { - if (label == "@b" || label == "@f") throw Error(ERR_BAD_LABEL_STR); - if (label == "@@") { - SlabelDefList& defList = stateList_.front().defList; - SlabelDefList::iterator i = defList.find("@f"); - if (i != defList.end()) { - defList.erase(i); - label = "@b"; - } else { - i = defList.find("@b"); - if (i != defList.end()) { - 
defList.erase(i); - } - label = "@f"; - } - } - SlabelState& st = *label.c_str() == '.' ? stateList_.back() : stateList_.front(); - define_inner(st.defList, st.undefList, label, base_->getSize()); - } - void defineClabel(Label& label) - { - define_inner(clabelDefList_, clabelUndefList_, getId(label), base_->getSize()); - label.mgr = this; - labelPtrList_.insert(&label); - } - void assign(Label& dst, const Label& src) - { - ClabelDefList::const_iterator i = clabelDefList_.find(src.id); - if (i == clabelDefList_.end()) throw Error(ERR_LABEL_ISNOT_SET_BY_L); - define_inner(clabelDefList_, clabelUndefList_, dst.id, i->second.offset); - dst.mgr = this; - labelPtrList_.insert(&dst); - } - bool getOffset(size_t *offset, std::string& label) const - { - const SlabelDefList& defList = stateList_.front().defList; - if (label == "@b") { - if (defList.find("@f") != defList.end()) { - label = "@f"; - } else if (defList.find("@b") == defList.end()) { - throw Error(ERR_LABEL_IS_NOT_FOUND); - } - } else if (label == "@f") { - if (defList.find("@f") != defList.end()) { - label = "@b"; - } - } - const SlabelState& st = *label.c_str() == '.' ? stateList_.back() : stateList_.front(); - return getOffset_inner(st.defList, offset, label); - } - bool getOffset(size_t *offset, const Label& label) const - { - return getOffset_inner(clabelDefList_, offset, getId(label)); - } - void addUndefinedLabel(const std::string& label, const JmpLabel& jmp) - { - SlabelState& st = *label.c_str() == '.' ? stateList_.back() : stateList_.front(); - st.undefList.insert(SlabelUndefList::value_type(label, jmp)); - } - void addUndefinedLabel(const Label& label, const JmpLabel& jmp) - { - clabelUndefList_.insert(ClabelUndefList::value_type(label.id, jmp)); - } - bool hasUndefSlabel() const - { - for (StateList::const_iterator i = stateList_.begin(), ie = stateList_.end(); i != ie; ++i) { - if (hasUndefinedLabel_inner(i->undefList)) return true; - } - return false; - } - bool hasUndefClabel() const { return hasUndefinedLabel_inner(clabelUndefList_); } - const uint8 *getCode() const { return base_->getCode(); } - bool isReady() const { return !base_->isAutoGrow() || base_->isCalledCalcJmpAddress(); } -}; - -inline Label::Label(const Label& rhs) -{ - id = rhs.id; - mgr = rhs.mgr; - if (mgr) mgr->incRefCount(id, this); -} -inline Label& Label::operator=(const Label& rhs) -{ - if (id) throw Error(ERR_LABEL_IS_ALREADY_SET_BY_L); - id = rhs.id; - mgr = rhs.mgr; - if (mgr) mgr->incRefCount(id, this); - return *this; -} -inline Label::~Label() -{ - if (id && mgr) mgr->decRefCount(id, this); -} -inline const uint8* Label::getAddress() const -{ - if (mgr == 0 || !mgr->isReady()) return 0; - size_t offset; - if (!mgr->getOffset(&offset, *this)) return 0; - return mgr->getCode() + offset; -} - -class CodeGenerator : public CodeArray { -public: - enum LabelType { - T_SHORT, - T_NEAR, - T_AUTO // T_SHORT if possible - }; -private: - CodeGenerator operator=(const CodeGenerator&); // don't call -#ifdef XBYAK64 - enum { i32e = 32 | 64, BIT = 64 }; - static const size_t dummyAddr = (size_t(0x11223344) << 32) | 55667788; - typedef Reg64 NativeReg; -#else - enum { i32e = 32, BIT = 32 }; - static const size_t dummyAddr = 0x12345678; - typedef Reg32 NativeReg; -#endif - // (XMM, XMM|MEM) - static inline bool isXMM_XMMorMEM(const Operand& op1, const Operand& op2) - { - return op1.isXMM() && (op2.isXMM() || op2.isMEM()); - } - // (MMX, MMX|MEM) or (XMM, XMM|MEM) - static inline bool isXMMorMMX_MEM(const Operand& op1, const Operand& op2) - { - return 
(op1.isMMX() && (op2.isMMX() || op2.isMEM())) || isXMM_XMMorMEM(op1, op2); - } - // (XMM, MMX|MEM) - static inline bool isXMM_MMXorMEM(const Operand& op1, const Operand& op2) - { - return op1.isXMM() && (op2.isMMX() || op2.isMEM()); - } - // (MMX, XMM|MEM) - static inline bool isMMX_XMMorMEM(const Operand& op1, const Operand& op2) - { - return op1.isMMX() && (op2.isXMM() || op2.isMEM()); - } - // (XMM, REG32|MEM) - static inline bool isXMM_REG32orMEM(const Operand& op1, const Operand& op2) - { - return op1.isXMM() && (op2.isREG(i32e) || op2.isMEM()); - } - // (REG32, XMM|MEM) - static inline bool isREG32_XMMorMEM(const Operand& op1, const Operand& op2) - { - return op1.isREG(i32e) && (op2.isXMM() || op2.isMEM()); - } - // (REG32, REG32|MEM) - static inline bool isREG32_REG32orMEM(const Operand& op1, const Operand& op2) - { - return op1.isREG(i32e) && ((op2.isREG(i32e) && op1.getBit() == op2.getBit()) || op2.isMEM()); - } - void rex(const Operand& op1, const Operand& op2 = Operand()) - { - uint8 rex = 0; - const Operand *p1 = &op1, *p2 = &op2; - if (p1->isMEM()) std::swap(p1, p2); - if (p1->isMEM()) throw Error(ERR_BAD_COMBINATION); - if (p2->isMEM()) { - const Address& addr = p2->getAddress(); - if (BIT == 64 && addr.is32bit()) db(0x67); - rex = addr.getRex() | p1->getReg().getRex(); - } else { - // ModRM(reg, base); - rex = op2.getReg().getRex(op1.getReg()); - } - // except movsx(16bit, 32/64bit) - if ((op1.isBit(16) && !op2.isBit(i32e)) || (op2.isBit(16) && !op1.isBit(i32e))) db(0x66); - if (rex) db(rex); - } - enum AVXtype { - // low 3 bit - T_N1 = 1, - T_N2 = 2, - T_N4 = 3, - T_N8 = 4, - T_N16 = 5, - T_N32 = 6, - T_NX_MASK = 7, - // - T_N_VL = 1 << 3, // N * (1, 2, 4) for VL - T_DUP = 1 << 4, // N = (8, 32, 64) - T_66 = 1 << 5, - T_F3 = 1 << 6, - T_F2 = 1 << 7, - T_0F = 1 << 8, - T_0F38 = 1 << 9, - T_0F3A = 1 << 10, - T_L0 = 1 << 11, - T_L1 = 1 << 12, - T_W0 = 1 << 13, - T_W1 = 1 << 14, - T_EW0 = 1 << 15, - T_EW1 = 1 << 16, - T_YMM = 1 << 17, // support YMM, ZMM - T_EVEX = 1 << 18, - T_ER_X = 1 << 19, // xmm{er} - T_ER_Y = 1 << 20, // ymm{er} - T_ER_Z = 1 << 21, // zmm{er} - T_SAE_X = 1 << 22, // xmm{sae} - T_SAE_Y = 1 << 23, // ymm{sae} - T_SAE_Z = 1 << 24, // zmm{sae} - T_MUST_EVEX = 1 << 25, // contains T_EVEX - T_B32 = 1 << 26, // m32bcst - T_B64 = 1 << 27, // m64bcst - T_M_K = 1 << 28, // mem{k} - T_VSIB = 1 << 29, - T_MEM_EVEX = 1 << 30, // use evex if mem - T_XXX - }; - void vex(const Reg& reg, const Reg& base, const Operand *v, int type, int code, bool x = false) - { - int w = (type & T_W1) ? 1 : 0; - bool is256 = (type & T_L1) ? true : (type & T_L0) ? false : reg.isYMM(); - bool r = reg.isExtIdx(); - bool b = base.isExtIdx(); - int idx = v ? v->getIdx() : 0; - if ((idx | reg.getIdx() | base.getIdx()) >= 16) throw Error(ERR_BAD_COMBINATION); - uint32 pp = (type & T_66) ? 1 : (type & T_F3) ? 2 : (type & T_F2) ? 3 : 0; - uint32 vvvv = (((~idx) & 15) << 3) | (is256 ? 4 : 0) | pp; - if (!b && !x && !w && (type & T_0F)) { - db(0xC5); db((r ? 0 : 0x80) | vvvv); - } else { - uint32 mmmm = (type & T_0F) ? 1 : (type & T_0F38) ? 2 : (type & T_0F3A) ? 3 : 0; - db(0xC4); db((r ? 0 : 0x80) | (x ? 0 : 0x40) | (b ? 
0 : 0x20) | mmmm); db((w << 7) | vvvv); - } - db(code); - } - void verifySAE(const Reg& r, int type) const - { - if (((type & T_SAE_X) && r.isXMM()) || ((type & T_SAE_Y) && r.isYMM()) || ((type & T_SAE_Z) && r.isZMM())) return; - throw Error(ERR_SAE_IS_INVALID); - } - void verifyER(const Reg& r, int type) const - { - if (((type & T_ER_X) && r.isXMM()) || ((type & T_ER_Y) && r.isYMM()) || ((type & T_ER_Z) && r.isZMM())) return; - throw Error(ERR_ER_IS_INVALID); - } - // (a, b, c) contains non zero two or three values then err - int verifyDuplicate(int a, int b, int c, int err) - { - int v = a | b | c; - if ((a > 0 && a != v) + (b > 0 && b != v) + (c > 0 && c != v) > 0) return Error(err); - return v; - } - int evex(const Reg& reg, const Reg& base, const Operand *v, int type, int code, bool x = false, bool b = false, int aaa = 0, uint32 VL = 0, bool Hi16Vidx = false) - { - if (!(type & (T_EVEX | T_MUST_EVEX))) throw Error(ERR_EVEX_IS_INVALID); - int w = (type & T_EW1) ? 1 : 0; - uint32 mm = (type & T_0F) ? 1 : (type & T_0F38) ? 2 : (type & T_0F3A) ? 3 : 0; - uint32 pp = (type & T_66) ? 1 : (type & T_F3) ? 2 : (type & T_F2) ? 3 : 0; - - int idx = v ? v->getIdx() : 0; - uint32 vvvv = ~idx; - - bool R = !reg.isExtIdx(); - bool X = x ? false : !base.isExtIdx2(); - bool B = !base.isExtIdx(); - bool Rp = !reg.isExtIdx2(); - int LL; - int rounding = verifyDuplicate(reg.getRounding(), base.getRounding(), v ? v->getRounding() : 0, ERR_ROUNDING_IS_ALREADY_SET); - int disp8N = 1; - if (rounding) { - if (rounding == EvexModifierRounding::T_SAE) { - verifySAE(base, type); LL = 0; - } else { - verifyER(base, type); LL = rounding - 1; - } - b = true; - } else { - if (v) VL = (std::max)(VL, v->getBit()); - VL = (std::max)((std::max)(reg.getBit(), base.getBit()), VL); - LL = (VL == 512) ? 2 : (VL == 256) ? 1 : 0; - if (b) { - disp8N = (type & T_B32) ? 4 : 8; - } else if (type & T_DUP) { - disp8N = VL == 128 ? 8 : VL == 256 ? 32 : 64; - } else { - if ((type & (T_NX_MASK | T_N_VL)) == 0) { - type |= T_N16 | T_N_VL; // default - } - int low = type & T_NX_MASK; - if (low > 0) { - disp8N = 1 << (low - 1); - if (type & T_N_VL) disp8N *= (VL == 512 ? 4 : VL == 256 ? 2 : 1); - } - } - } - bool Vp = !((v ? v->isExtIdx2() : 0) | Hi16Vidx); - bool z = reg.hasZero() || base.hasZero() || (v ? v->hasZero() : false); - if (aaa == 0) aaa = verifyDuplicate(base.getOpmaskIdx(), reg.getOpmaskIdx(), (v ? v->getOpmaskIdx() : 0), ERR_OPMASK_IS_ALREADY_SET); - db(0x62); - db((R ? 0x80 : 0) | (X ? 0x40 : 0) | (B ? 0x20 : 0) | (Rp ? 0x10 : 0) | (mm & 3)); - db((w == 1 ? 0x80 : 0) | ((vvvv & 15) << 3) | 4 | (pp & 3)); - db((z ? 0x80 : 0) | ((LL & 3) << 5) | (b ? 0x10 : 0) | (Vp ? 
8 : 0) | (aaa & 7)); - db(code); - return disp8N; - } - void setModRM(int mod, int r1, int r2) - { - db(static_cast((mod << 6) | ((r1 & 7) << 3) | (r2 & 7))); - } - void setSIB(const RegExp& e, int reg, int disp8N = 0) - { - size_t disp64 = e.getDisp(); -#ifdef XBYAK64 - size_t high = disp64 >> 32; - if (high != 0 && high != 0xFFFFFFFF) throw Error(ERR_OFFSET_IS_TOO_BIG); -#endif - uint32 disp = static_cast(disp64); - const Reg& base = e.getBase(); - const Reg& index = e.getIndex(); - const int baseIdx = base.getIdx(); - const int baseBit = base.getBit(); - const int indexBit = index.getBit(); - enum { - mod00 = 0, mod01 = 1, mod10 = 2 - }; - int mod = mod10; // disp32 - if (!baseBit || ((baseIdx & 7) != Operand::EBP && disp == 0)) { - mod = mod00; - } else { - if (disp8N == 0) { - if (inner::IsInDisp8(disp)) { - mod = mod01; - } - } else { - // disp must be casted to signed - uint32 t = static_cast(static_cast(disp) / disp8N); - if ((disp % disp8N) == 0 && inner::IsInDisp8(t)) { - disp = t; - mod = mod01; - } - } - } - const int newBaseIdx = baseBit ? (baseIdx & 7) : Operand::EBP; - /* ModR/M = [2:3:3] = [Mod:reg/code:R/M] */ - bool hasSIB = indexBit || (baseIdx & 7) == Operand::ESP; -#ifdef XBYAK64 - if (!baseBit && !indexBit) hasSIB = true; -#endif - if (hasSIB) { - setModRM(mod, reg, Operand::ESP); - /* SIB = [2:3:3] = [SS:index:base(=rm)] */ - const int idx = indexBit ? (index.getIdx() & 7) : Operand::ESP; - const int scale = e.getScale(); - const int SS = (scale == 8) ? 3 : (scale == 4) ? 2 : (scale == 2) ? 1 : 0; - setModRM(SS, idx, newBaseIdx); - } else { - setModRM(mod, reg, newBaseIdx); - } - if (mod == mod01) { - db(disp); - } else if (mod == mod10 || (mod == mod00 && !baseBit)) { - dd(disp); - } - } - LabelManager labelMgr_; - bool isInDisp16(uint32 x) const { return 0xFFFF8000 <= x || x <= 0x7FFF; } - void opModR(const Reg& reg1, const Reg& reg2, int code0, int code1 = NONE, int code2 = NONE) - { - rex(reg2, reg1); - db(code0 | (reg1.isBit(8) ? 0 : 1)); if (code1 != NONE) db(code1); if (code2 != NONE) db(code2); - setModRM(3, reg1.getIdx(), reg2.getIdx()); - } - void opModM(const Address& addr, const Reg& reg, int code0, int code1 = NONE, int code2 = NONE, int immSize = 0) - { - if (addr.is64bitDisp()) throw Error(ERR_CANT_USE_64BIT_DISP); - rex(addr, reg); - db(code0 | (reg.isBit(8) ? 0 : 1)); if (code1 != NONE) db(code1); if (code2 != NONE) db(code2); - opAddr(addr, reg.getIdx(), immSize); - } - void opMIB(const Address& addr, const Reg& reg, int code0, int code1) - { - if (addr.is64bitDisp()) throw Error(ERR_CANT_USE_64BIT_DISP); - if (addr.getMode() != Address::M_ModRM) throw Error(ERR_INVALID_MIB_ADDRESS); - if (BIT == 64 && addr.is32bit()) db(0x67); - const RegExp& regExp = addr.getRegExp(false); - uint8 rex = regExp.getRex(); - if (rex) db(rex); - db(code0); db(code1); - setSIB(regExp, reg.getIdx()); - } - void makeJmp(uint32 disp, LabelType type, uint8 shortCode, uint8 longCode, uint8 longPref) - { - const int shortJmpSize = 2; - const int longHeaderSize = longPref ? 
2 : 1; - const int longJmpSize = longHeaderSize + 4; - if (type != T_NEAR && inner::IsInDisp8(disp - shortJmpSize)) { - db(shortCode); db(disp - shortJmpSize); - } else { - if (type == T_SHORT) throw Error(ERR_LABEL_IS_TOO_FAR); - if (longPref) db(longPref); - db(longCode); dd(disp - longJmpSize); - } - } - template - void opJmp(T& label, LabelType type, uint8 shortCode, uint8 longCode, uint8 longPref) - { - if (isAutoGrow() && size_ + 16 >= maxSize_) growMemory(); /* avoid splitting code of jmp */ - size_t offset = 0; - if (labelMgr_.getOffset(&offset, label)) { /* label exists */ - makeJmp(inner::VerifyInInt32(offset - size_), type, shortCode, longCode, longPref); - } else { - int jmpSize = 0; - if (type == T_NEAR) { - jmpSize = 4; - if (longPref) db(longPref); - db(longCode); dd(0); - } else { - jmpSize = 1; - db(shortCode); db(0); - } - JmpLabel jmp(size_, jmpSize, inner::LasIs); - labelMgr_.addUndefinedLabel(label, jmp); - } - } - void opJmpAbs(const void *addr, LabelType type, uint8 shortCode, uint8 longCode, uint8 longPref = 0) - { - if (isAutoGrow()) { - if (type != T_NEAR) throw Error(ERR_ONLY_T_NEAR_IS_SUPPORTED_IN_AUTO_GROW); - if (size_ + 16 >= maxSize_) growMemory(); - if (longPref) db(longPref); - db(longCode); - dd(0); - save(size_ - 4, size_t(addr) - size_, 4, inner::Labs); - } else { - makeJmp(inner::VerifyInInt32(reinterpret_cast(addr) - getCurr()), type, shortCode, longCode, longPref); - } - - } - // reg is reg field of ModRM - // immSize is the size for immediate value - // disp8N = 0(normal), disp8N = 1(force disp32), disp8N = {2, 4, 8} ; compressed displacement - void opAddr(const Address &addr, int reg, int immSize = 0, int disp8N = 0, bool permitVisb = false) - { - if (!permitVisb && addr.isVsib()) throw Error(ERR_BAD_VSIB_ADDRESSING); - if (addr.getMode() == Address::M_ModRM) { - setSIB(addr.getRegExp(), reg, disp8N); - } else if (addr.getMode() == Address::M_rip || addr.getMode() == Address::M_ripAddr) { - setModRM(0, reg, 5); - if (addr.getLabel()) { // [rip + Label] - putL_inner(*addr.getLabel(), true, addr.getDisp() - immSize); - } else { - size_t disp = addr.getDisp(); - if (addr.getMode() == Address::M_ripAddr) { - if (isAutoGrow()) throw Error(ERR_INVALID_RIP_IN_AUTO_GROW); - disp -= (size_t)getCurr() + 4 + immSize; - } - dd(inner::VerifyInInt32(disp)); - } - } - } - /* preCode is for SSSE3/SSE4 */ - void opGen(const Operand& reg, const Operand& op, int code, int pref, bool isValid(const Operand&, const Operand&), int imm8 = NONE, int preCode = NONE) - { - if (isValid && !isValid(reg, op)) throw Error(ERR_BAD_COMBINATION); - if (pref != NONE) db(pref); - if (op.isMEM()) { - opModM(op.getAddress(), reg.getReg(), 0x0F, preCode, code, (imm8 != NONE) ? 1 : 0); - } else { - opModR(reg.getReg(), op.getReg(), 0x0F, preCode, code); - } - if (imm8 != NONE) db(imm8); - } - void opMMX_IMM(const Mmx& mmx, int imm8, int code, int ext) - { - if (mmx.isXMM()) db(0x66); - opModR(Reg32(ext), mmx, 0x0F, code); - db(imm8); - } - void opMMX(const Mmx& mmx, const Operand& op, int code, int pref = 0x66, int imm8 = NONE, int preCode = NONE) - { - opGen(mmx, op, code, mmx.isXMM() ? 
pref : NONE, isXMMorMMX_MEM, imm8, preCode); - } - void opMovXMM(const Operand& op1, const Operand& op2, int code, int pref) - { - if (pref != NONE) db(pref); - if (op1.isXMM() && op2.isMEM()) { - opModM(op2.getAddress(), op1.getReg(), 0x0F, code); - } else if (op1.isMEM() && op2.isXMM()) { - opModM(op1.getAddress(), op2.getReg(), 0x0F, code | 1); - } else { - throw Error(ERR_BAD_COMBINATION); - } - } - void opExt(const Operand& op, const Mmx& mmx, int code, int imm, bool hasMMX2 = false) - { - if (hasMMX2 && op.isREG(i32e)) { /* pextrw is special */ - if (mmx.isXMM()) db(0x66); - opModR(op.getReg(), mmx, 0x0F, 0xC5); db(imm); - } else { - opGen(mmx, op, code, 0x66, isXMM_REG32orMEM, imm, 0x3A); - } - } - void opR_ModM(const Operand& op, int bit, int ext, int code0, int code1 = NONE, int code2 = NONE, bool disableRex = false, int immSize = 0) - { - int opBit = op.getBit(); - if (disableRex && opBit == 64) opBit = 32; - if (op.isREG(bit)) { - opModR(Reg(ext, Operand::REG, opBit), op.getReg().changeBit(opBit), code0, code1, code2); - } else if (op.isMEM()) { - opModM(op.getAddress(), Reg(ext, Operand::REG, opBit), code0, code1, code2, immSize); - } else { - throw Error(ERR_BAD_COMBINATION); - } - } - void opShift(const Operand& op, int imm, int ext) - { - verifyMemHasSize(op); - opR_ModM(op, 0, ext, (0xC0 | ((imm == 1 ? 1 : 0) << 4)), NONE, NONE, false, (imm != 1) ? 1 : 0); - if (imm != 1) db(imm); - } - void opShift(const Operand& op, const Reg8& _cl, int ext) - { - if (_cl.getIdx() != Operand::CL) throw Error(ERR_BAD_COMBINATION); - opR_ModM(op, 0, ext, 0xD2); - } - void opModRM(const Operand& op1, const Operand& op2, bool condR, bool condM, int code0, int code1 = NONE, int code2 = NONE, int immSize = 0) - { - if (condR) { - opModR(op1.getReg(), op2.getReg(), code0, code1, code2); - } else if (condM) { - opModM(op2.getAddress(), op1.getReg(), code0, code1, code2, immSize); - } else { - throw Error(ERR_BAD_COMBINATION); - } - } - void opShxd(const Operand& op, const Reg& reg, uint8 imm, int code, const Reg8 *_cl = 0) - { - if (_cl && _cl->getIdx() != Operand::CL) throw Error(ERR_BAD_COMBINATION); - opModRM(reg, op, (op.isREG(16 | i32e) && op.getBit() == reg.getBit()), op.isMEM() && (reg.isREG(16 | i32e)), 0x0F, code | (_cl ? 1 : 0), NONE, _cl ? 0 : 1); - if (!_cl) db(imm); - } - // (REG, REG|MEM), (MEM, REG) - void opRM_RM(const Operand& op1, const Operand& op2, int code) - { - if (op1.isREG() && op2.isMEM()) { - opModM(op2.getAddress(), op1.getReg(), code | 2); - } else { - opModRM(op2, op1, op1.isREG() && op1.getKind() == op2.getKind(), op1.isMEM() && op2.isREG(), code); - } - } - // (REG|MEM, IMM) - void opRM_I(const Operand& op, uint32 imm, int code, int ext) - { - verifyMemHasSize(op); - uint32 immBit = inner::IsInDisp8(imm) ? 8 : isInDisp16(imm) ? 16 : 32; - if (op.isBit(8)) immBit = 8; - if (op.getBit() < immBit) throw Error(ERR_IMM_IS_TOO_BIG); - if (op.isBit(32|64) && immBit == 16) immBit = 32; /* don't use MEM16 if 32/64bit mode */ - if (op.isREG() && op.getIdx() == 0 && (op.getBit() == immBit || (op.isBit(64) && immBit == 32))) { // rax, eax, ax, al - rex(op); - db(code | 4 | (immBit == 8 ? 0 : 1)); - } else { - int tmp = immBit < (std::min)(op.getBit(), 32U) ? 
2 : 0; - opR_ModM(op, 0, ext, 0x80 | tmp, NONE, NONE, false, immBit / 8); - } - db(imm, immBit / 8); - } - void opIncDec(const Operand& op, int code, int ext) - { - verifyMemHasSize(op); -#ifndef XBYAK64 - if (op.isREG() && !op.isBit(8)) { - rex(op); db(code | op.getIdx()); - return; - } -#endif - code = 0xFE; - if (op.isREG()) { - opModR(Reg(ext, Operand::REG, op.getBit()), op.getReg(), code); - } else { - opModM(op.getAddress(), Reg(ext, Operand::REG, op.getBit()), code); - } - } - void opPushPop(const Operand& op, int code, int ext, int alt) - { - int bit = op.getBit(); - if (bit == 16 || bit == BIT) { - if (bit == 16) db(0x66); - if (op.isREG()) { - if (op.getReg().getIdx() >= 8) db(0x41); - db(alt | (op.getIdx() & 7)); - return; - } - if (op.isMEM()) { - opModM(op.getAddress(), Reg(ext, Operand::REG, 32), code); - return; - } - } - throw Error(ERR_BAD_COMBINATION); - } - void verifyMemHasSize(const Operand& op) const - { - if (op.isMEM() && op.getBit() == 0) throw Error(ERR_MEM_SIZE_IS_NOT_SPECIFIED); - } - /* - mov(r, imm) = db(imm, mov_imm(r, imm)) - */ - int mov_imm(const Reg& reg, size_t imm) - { - int bit = reg.getBit(); - const int idx = reg.getIdx(); - int code = 0xB0 | ((bit == 8 ? 0 : 1) << 3); - if (bit == 64 && (imm & ~size_t(0xffffffffu)) == 0) { - rex(Reg32(idx)); - bit = 32; - } else { - rex(reg); - if (bit == 64 && inner::IsInInt32(imm)) { - db(0xC7); - code = 0xC0; - bit = 32; - } - } - db(code | (idx & 7)); - return bit / 8; - } - template - void putL_inner(T& label, bool relative = false, size_t disp = 0) - { - const int jmpSize = relative ? 4 : (int)sizeof(size_t); - if (isAutoGrow() && size_ + 16 >= maxSize_) growMemory(); - size_t offset = 0; - if (labelMgr_.getOffset(&offset, label)) { - if (relative) { - db(inner::VerifyInInt32(offset + disp - size_ - jmpSize), jmpSize); - } else if (isAutoGrow()) { - db(uint64(0), jmpSize); - save(size_ - jmpSize, offset, jmpSize, inner::LaddTop); - } else { - db(size_t(top_) + offset, jmpSize); - } - return; - } - db(uint64(0), jmpSize); - JmpLabel jmp(size_, jmpSize, (relative ? inner::LasIs : isAutoGrow() ? inner::LaddTop : inner::Labs), disp); - labelMgr_.addUndefinedLabel(label, jmp); - } - void opMovxx(const Reg& reg, const Operand& op, uint8 code) - { - if (op.isBit(32)) throw Error(ERR_BAD_COMBINATION); - int w = op.isBit(16); -#ifdef XBYAK64 - if (op.isHigh8bit()) throw Error(ERR_BAD_COMBINATION); -#endif - bool cond = reg.isREG() && (reg.getBit() > op.getBit()); - opModRM(reg, op, cond && op.isREG(), cond && op.isMEM(), 0x0F, code | w); - } - void opFpuMem(const Address& addr, uint8 m16, uint8 m32, uint8 m64, uint8 ext, uint8 m64ext) - { - if (addr.is64bitDisp()) throw Error(ERR_CANT_USE_64BIT_DISP); - uint8 code = addr.isBit(16) ? m16 : addr.isBit(32) ? m32 : addr.isBit(64) ? m64 : 0; - if (!code) throw Error(ERR_BAD_MEM_SIZE); - if (m64ext && addr.isBit(64)) ext = m64ext; - - rex(addr, st0); - db(code); - opAddr(addr, ext); - } - // use code1 if reg1 == st0 - // use code2 if reg1 != st0 && reg2 == st0 - void opFpuFpu(const Fpu& reg1, const Fpu& reg2, uint32 code1, uint32 code2) - { - uint32 code = reg1.getIdx() == 0 ? code1 : reg2.getIdx() == 0 ? 
code2 : 0; - if (!code) throw Error(ERR_BAD_ST_COMBINATION); - db(uint8(code >> 8)); - db(uint8(code | (reg1.getIdx() | reg2.getIdx()))); - } - void opFpu(const Fpu& reg, uint8 code1, uint8 code2) - { - db(code1); db(code2 | reg.getIdx()); - } - void opVex(const Reg& r, const Operand *p1, const Operand& op2, int type, int code, int imm8 = NONE) - { - if (op2.isMEM()) { - const Address& addr = op2.getAddress(); - const RegExp& regExp = addr.getRegExp(); - const Reg& base = regExp.getBase(); - const Reg& index = regExp.getIndex(); - if (BIT == 64 && addr.is32bit()) db(0x67); - int disp8N = 0; - bool x = index.isExtIdx(); - if ((type & (T_MUST_EVEX|T_MEM_EVEX)) || r.hasEvex() || (p1 && p1->hasEvex()) || addr.isBroadcast() || addr.getOpmaskIdx()) { - int aaa = addr.getOpmaskIdx(); - if (aaa && !(type & T_M_K)) throw Error(ERR_INVALID_OPMASK_WITH_MEMORY); - bool b = false; - if (addr.isBroadcast()) { - if (!(type & (T_B32 | T_B64))) throw Error(ERR_INVALID_BROADCAST); - b = true; - } - int VL = regExp.isVsib() ? index.getBit() : 0; - disp8N = evex(r, base, p1, type, code, x, b, aaa, VL, index.isExtIdx2()); - } else { - vex(r, base, p1, type, code, x); - } - opAddr(addr, r.getIdx(), (imm8 != NONE) ? 1 : 0, disp8N, (type & T_VSIB) != 0); - } else { - const Reg& base = op2.getReg(); - if ((type & T_MUST_EVEX) || r.hasEvex() || (p1 && p1->hasEvex()) || base.hasEvex()) { - evex(r, base, p1, type, code); - } else { - vex(r, base, p1, type, code); - } - setModRM(3, r.getIdx(), base.getIdx()); - } - if (imm8 != NONE) db(imm8); - } - // (r, r, r/m) if isR_R_RM - // (r, r/m, r) - void opGpr(const Reg32e& r, const Operand& op1, const Operand& op2, int type, uint8 code, bool isR_R_RM, int imm8 = NONE) - { - const Operand *p1 = &op1; - const Operand *p2 = &op2; - if (!isR_R_RM) std::swap(p1, p2); - const unsigned int bit = r.getBit(); - if (p1->getBit() != bit || (p2->isREG() && p2->getBit() != bit)) throw Error(ERR_BAD_COMBINATION); - type |= (bit == 64) ? T_W1 : T_W0; - opVex(r, p1, *p2, type, code, imm8); - } - void opAVX_X_X_XM(const Xmm& x1, const Operand& op1, const Operand& op2, int type, int code0, int imm8 = NONE) - { - const Xmm *x2 = static_cast(&op1); - const Operand *op = &op2; - if (op2.isNone()) { // (x1, op1) -> (x1, x1, op1) - x2 = &x1; - op = &op1; - } - // (x1, x2, op) - if (!((x1.isXMM() && x2->isXMM()) || ((type & T_YMM) && ((x1.isYMM() && x2->isYMM()) || (x1.isZMM() && x2->isZMM()))))) throw Error(ERR_BAD_COMBINATION); - opVex(x1, x2, *op, type, code0, imm8); - } - void opAVX_K_X_XM(const Opmask& k, const Xmm& x2, const Operand& op3, int type, int code0, int imm8 = NONE) - { - if (!op3.isMEM() && (x2.getKind() != op3.getKind())) throw Error(ERR_BAD_COMBINATION); - opVex(k, &x2, op3, type, code0, imm8); - } - // (x, x/m), (y, x/m256), (z, y/m) - void checkCvt1(const Operand& x, const Operand& op) const - { - if (!op.isMEM() && !(x.is(Operand::XMM | Operand::YMM) && op.isXMM()) && !(x.isZMM() && op.isYMM())) throw Error(ERR_BAD_COMBINATION); - } - // (x, x/m), (x, y/m256), (y, z/m) - void checkCvt2(const Xmm& x, const Operand& op) const - { - if (!(x.isXMM() && op.is(Operand::XMM | Operand::YMM | Operand::MEM)) && !(x.isYMM() && op.is(Operand::ZMM | Operand::MEM))) throw Error(ERR_BAD_COMBINATION); - } - void opCvt2(const Xmm& x, const Operand& op, int type, int code) - { - checkCvt2(x, op); - Operand::Kind kind = x.isXMM() ? (op.isBit(256) ? 
Operand::YMM : Operand::XMM) : Operand::ZMM; - opVex(x.copyAndSetKind(kind), &xm0, op, type, code); - } - void opCvt3(const Xmm& x1, const Xmm& x2, const Operand& op, int type, int type64, int type32, uint8 code) - { - if (!(x1.isXMM() && x2.isXMM() && (op.isREG(i32e) || op.isMEM()))) throw Error(ERR_BAD_SIZE_OF_REGISTER); - Xmm x(op.getIdx()); - const Operand *p = op.isREG() ? &x : &op; - opVex(x1, &x2, *p, type | (op.isBit(64) ? type64 : type32), code); - } - const Xmm& cvtIdx0(const Operand& x) const - { - return x.isZMM() ? zm0 : x.isYMM() ? ym0 : xm0; - } - // support (x, x/m, imm), (y, y/m, imm) - void opAVX_X_XM_IMM(const Xmm& x, const Operand& op, int type, int code, int imm8 = NONE) - { - opAVX_X_X_XM(x, cvtIdx0(x), op, type, code, imm8); - } - // QQQ:need to refactor - void opSp1(const Reg& reg, const Operand& op, uint8 pref, uint8 code0, uint8 code1) - { - if (reg.isBit(8)) throw Error(ERR_BAD_SIZE_OF_REGISTER); - bool is16bit = reg.isREG(16) && (op.isREG(16) || op.isMEM()); - if (!is16bit && !(reg.isREG(i32e) && (op.isREG(reg.getBit()) || op.isMEM()))) throw Error(ERR_BAD_COMBINATION); - if (is16bit) db(0x66); - db(pref); opModRM(reg.changeBit(i32e == 32 ? 32 : reg.getBit()), op, op.isREG(), true, code0, code1); - } - void opGather(const Xmm& x1, const Address& addr, const Xmm& x2, int type, uint8 code, int mode) - { - const RegExp& regExp = addr.getRegExp(); - if (!regExp.isVsib(128 | 256)) throw Error(ERR_BAD_VSIB_ADDRESSING); - const int y_vx_y = 0; - const int y_vy_y = 1; -// const int x_vy_x = 2; - const bool isAddrYMM = regExp.getIndex().getBit() == 256; - if (!x1.isXMM() || isAddrYMM || !x2.isXMM()) { - bool isOK = false; - if (mode == y_vx_y) { - isOK = x1.isYMM() && !isAddrYMM && x2.isYMM(); - } else if (mode == y_vy_y) { - isOK = x1.isYMM() && isAddrYMM && x2.isYMM(); - } else { // x_vy_x - isOK = !x1.isYMM() && isAddrYMM && !x2.isYMM(); - } - if (!isOK) throw Error(ERR_BAD_VSIB_ADDRESSING); - } - opAVX_X_X_XM(isAddrYMM ? Ymm(x1.getIdx()) : x1, isAddrYMM ? 
Ymm(x2.getIdx()) : x2, addr, type, code); - } - enum { - xx_yy_zz = 0, - xx_yx_zy = 1, - xx_xy_yz = 2 - }; - void checkGather2(const Xmm& x1, const Reg& x2, int mode) const - { - if (x1.isXMM() && x2.isXMM()) return; - switch (mode) { - case xx_yy_zz: if ((x1.isYMM() && x2.isYMM()) || (x1.isZMM() && x2.isZMM())) return; - break; - case xx_yx_zy: if ((x1.isYMM() && x2.isXMM()) || (x1.isZMM() && x2.isYMM())) return; - break; - case xx_xy_yz: if ((x1.isXMM() && x2.isYMM()) || (x1.isYMM() && x2.isZMM())) return; - break; - } - throw Error(ERR_BAD_VSIB_ADDRESSING); - } - void opGather2(const Xmm& x, const Address& addr, int type, uint8 code, int mode) - { - if (x.hasZero()) throw Error(ERR_INVALID_ZERO); - checkGather2(x, addr.getRegExp().getIndex(), mode); - opVex(x, 0, addr, type, code); - } - /* - xx_xy_yz ; mode = true - xx_xy_xz ; mode = false - */ - void opVmov(const Operand& op, const Xmm& x, int type, uint8 code, bool mode) - { - if (mode) { - if (!op.isMEM() && !((op.isXMM() && x.isXMM()) || (op.isXMM() && x.isYMM()) || (op.isYMM() && x.isZMM()))) throw Error(ERR_BAD_COMBINATION); - } else { - if (!op.isMEM() && !op.isXMM()) throw Error(ERR_BAD_COMBINATION); - } - opVex(x, 0, op, type, code); - } - void opGatherFetch(const Address& addr, const Xmm& x, int type, uint8 code, Operand::Kind kind) - { - if (addr.hasZero()) throw Error(ERR_INVALID_ZERO); - if (addr.getRegExp().getIndex().getKind() != kind) throw Error(ERR_BAD_VSIB_ADDRESSING); - opVex(x, 0, addr, type, code); - } -public: - unsigned int getVersion() const { return VERSION; } - using CodeArray::db; - const Mmx mm0, mm1, mm2, mm3, mm4, mm5, mm6, mm7; - const Xmm xmm0, xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7; - const Ymm ymm0, ymm1, ymm2, ymm3, ymm4, ymm5, ymm6, ymm7; - const Zmm zmm0, zmm1, zmm2, zmm3, zmm4, zmm5, zmm6, zmm7; - const Xmm &xm0, &xm1, &xm2, &xm3, &xm4, &xm5, &xm6, &xm7; - const Ymm &ym0, &ym1, &ym2, &ym3, &ym4, &ym5, &ym6, &ym7; - const Ymm &zm0, &zm1, &zm2, &zm3, &zm4, &zm5, &zm6, &zm7; - const Reg32 eax, ecx, edx, ebx, esp, ebp, esi, edi; - const Reg16 ax, cx, dx, bx, sp, bp, si, di; - const Reg8 al, cl, dl, bl, ah, ch, dh, bh; - const AddressFrame ptr, byte, word, dword, qword, xword, yword, zword; // xword is same as oword of NASM - const AddressFrame ptr_b, xword_b, yword_b, zword_b; // broadcast such as {1to2}, {1to4}, {1to8}, {1to16}, {b} - const Fpu st0, st1, st2, st3, st4, st5, st6, st7; - const Opmask k0, k1, k2, k3, k4, k5, k6, k7; - const BoundsReg bnd0, bnd1, bnd2, bnd3; - const EvexModifierRounding T_sae, T_rn_sae, T_rd_sae, T_ru_sae, T_rz_sae; // {sae}, {rn-sae}, {rd-sae}, {ru-sae}, {rz-sae} - const EvexModifierZero T_z; // {z} -#ifdef XBYAK64 - const Reg64 rax, rcx, rdx, rbx, rsp, rbp, rsi, rdi, r8, r9, r10, r11, r12, r13, r14, r15; - const Reg32 r8d, r9d, r10d, r11d, r12d, r13d, r14d, r15d; - const Reg16 r8w, r9w, r10w, r11w, r12w, r13w, r14w, r15w; - const Reg8 r8b, r9b, r10b, r11b, r12b, r13b, r14b, r15b; - const Reg8 spl, bpl, sil, dil; - const Xmm xmm8, xmm9, xmm10, xmm11, xmm12, xmm13, xmm14, xmm15; - const Xmm xmm16, xmm17, xmm18, xmm19, xmm20, xmm21, xmm22, xmm23; - const Xmm xmm24, xmm25, xmm26, xmm27, xmm28, xmm29, xmm30, xmm31; - const Ymm ymm8, ymm9, ymm10, ymm11, ymm12, ymm13, ymm14, ymm15; - const Ymm ymm16, ymm17, ymm18, ymm19, ymm20, ymm21, ymm22, ymm23; - const Ymm ymm24, ymm25, ymm26, ymm27, ymm28, ymm29, ymm30, ymm31; - const Zmm zmm8, zmm9, zmm10, zmm11, zmm12, zmm13, zmm14, zmm15; - const Zmm zmm16, zmm17, zmm18, zmm19, zmm20, zmm21, zmm22, zmm23; - const Zmm zmm24, zmm25, 
zmm26, zmm27, zmm28, zmm29, zmm30, zmm31; - const Xmm &xm8, &xm9, &xm10, &xm11, &xm12, &xm13, &xm14, &xm15; // for my convenience - const Xmm &xm16, &xm17, &xm18, &xm19, &xm20, &xm21, &xm22, &xm23; - const Xmm &xm24, &xm25, &xm26, &xm27, &xm28, &xm29, &xm30, &xm31; - const Ymm &ym8, &ym9, &ym10, &ym11, &ym12, &ym13, &ym14, &ym15; - const Ymm &ym16, &ym17, &ym18, &ym19, &ym20, &ym21, &ym22, &ym23; - const Ymm &ym24, &ym25, &ym26, &ym27, &ym28, &ym29, &ym30, &ym31; - const Zmm &zm8, &zm9, &zm10, &zm11, &zm12, &zm13, &zm14, &zm15; - const Zmm &zm16, &zm17, &zm18, &zm19, &zm20, &zm21, &zm22, &zm23; - const Zmm &zm24, &zm25, &zm26, &zm27, &zm28, &zm29, &zm30, &zm31; - const RegRip rip; -#endif -#ifndef XBYAK_DISABLE_SEGMENT - const Segment es, cs, ss, ds, fs, gs; -#endif - void L(const std::string& label) { labelMgr_.defineSlabel(label); } - void L(Label& label) { labelMgr_.defineClabel(label); } - Label L() { Label label; L(label); return label; } - void inLocalLabel() { labelMgr_.enterLocal(); } - void outLocalLabel() { labelMgr_.leaveLocal(); } - /* - assign src to dst - require - dst : does not used by L() - src : used by L() - */ - void assignL(Label& dst, const Label& src) { labelMgr_.assign(dst, src); } - /* - put address of label to buffer - @note the put size is 4(32-bit), 8(64-bit) - */ - void putL(std::string label) { putL_inner(label); } - void putL(const Label& label) { putL_inner(label); } - - void jmp(const Operand& op) { opR_ModM(op, BIT, 4, 0xFF, NONE, NONE, true); } - void jmp(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0xEB, 0xE9, 0); } - void jmp(const char *label, LabelType type = T_AUTO) { jmp(std::string(label), type); } - void jmp(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0xEB, 0xE9, 0); } - void jmp(const void *addr, LabelType type = T_AUTO) { opJmpAbs(addr, type, 0xEB, 0xE9); } - - void call(const Operand& op) { opR_ModM(op, 16 | i32e, 2, 0xFF, NONE, NONE, true); } - // call(string label), not const std::string& - void call(std::string label) { opJmp(label, T_NEAR, 0, 0xE8, 0); } - void call(const char *label) { call(std::string(label)); } - void call(const Label& label) { opJmp(label, T_NEAR, 0, 0xE8, 0); } - // call(function pointer) -#ifdef XBYAK_VARIADIC_TEMPLATE - template - void call(Ret(*func)(Params...)) { call(reinterpret_cast(func)); } -#endif - void call(const void *addr) { opJmpAbs(addr, T_NEAR, 0, 0xE8); } - - void test(const Operand& op, const Reg& reg) - { - opModRM(reg, op, op.isREG() && (op.getKind() == reg.getKind()), op.isMEM(), 0x84); - } - void test(const Operand& op, uint32 imm) - { - verifyMemHasSize(op); - int immSize = (std::min)(op.getBit() / 8, 4U); - if (op.isREG() && op.getIdx() == 0) { // al, ax, eax - rex(op); - db(0xA8 | (op.isBit(8) ? 0 : 1)); - } else { - opR_ModM(op, 0, 0, 0xF6, NONE, NONE, false, immSize); - } - db(imm, immSize); - } - void imul(const Reg& reg, const Operand& op) - { - opModRM(reg, op, op.isREG() && (reg.getKind() == op.getKind()), op.isMEM(), 0x0F, 0xAF); - } - void imul(const Reg& reg, const Operand& op, int imm) - { - int s = inner::IsInDisp8(imm) ? 1 : 0; - int immSize = s ? 1 : reg.isREG(16) ? 
2 : 4; - opModRM(reg, op, op.isREG() && (reg.getKind() == op.getKind()), op.isMEM(), 0x69 | (s << 1), NONE, NONE, immSize); - db(imm, immSize); - } - void push(const Operand& op) { opPushPop(op, 0xFF, 6, 0x50); } - void pop(const Operand& op) { opPushPop(op, 0x8F, 0, 0x58); } - void push(const AddressFrame& af, uint32 imm) - { - if (af.bit_ == 8 && inner::IsInDisp8(imm)) { - db(0x6A); db(imm); - } else if (af.bit_ == 16 && isInDisp16(imm)) { - db(0x66); db(0x68); dw(imm); - } else { - db(0x68); dd(imm); - } - } - /* use "push(word, 4)" if you want "push word 4" */ - void push(uint32 imm) - { - if (inner::IsInDisp8(imm)) { - push(byte, imm); - } else { - push(dword, imm); - } - } - void mov(const Operand& reg1, const Operand& reg2) - { - const Reg *reg = 0; - const Address *addr = 0; - uint8 code = 0; - if (reg1.isREG() && reg1.getIdx() == 0 && reg2.isMEM()) { // mov eax|ax|al, [disp] - reg = ®1.getReg(); - addr= ®2.getAddress(); - code = 0xA0; - } else - if (reg1.isMEM() && reg2.isREG() && reg2.getIdx() == 0) { // mov [disp], eax|ax|al - reg = ®2.getReg(); - addr= ®1.getAddress(); - code = 0xA2; - } -#ifdef XBYAK64 - if (addr && addr->is64bitDisp()) { - if (code) { - rex(*reg); - db(reg1.isREG(8) ? 0xA0 : reg1.isREG() ? 0xA1 : reg2.isREG(8) ? 0xA2 : 0xA3); - db(addr->getDisp(), 8); - } else { - throw Error(ERR_BAD_COMBINATION); - } - } else -#else - if (code && addr->isOnlyDisp()) { - rex(*reg, *addr); - db(code | (reg->isBit(8) ? 0 : 1)); - dd(static_cast(addr->getDisp())); - } else -#endif - { - opRM_RM(reg1, reg2, 0x88); - } - } - void mov(const Operand& op, size_t imm) - { - if (op.isREG()) { - const int size = mov_imm(op.getReg(), imm); - db(imm, size); - } else if (op.isMEM()) { - verifyMemHasSize(op); - int immSize = op.getBit() / 8; - if (immSize <= 4) { - sint64 s = sint64(imm) >> (immSize * 8); - if (s != 0 && s != -1) throw Error(ERR_IMM_IS_TOO_BIG); - } else { - if (!inner::IsInInt32(imm)) throw Error(ERR_IMM_IS_TOO_BIG); - immSize = 4; - } - opModM(op.getAddress(), Reg(0, Operand::REG, op.getBit()), 0xC6, NONE, NONE, immSize); - db(static_cast(imm), immSize); - } else { - throw Error(ERR_BAD_COMBINATION); - } - } - void mov(const NativeReg& reg, const char *label) // can't use std::string - { - if (label == 0) { - mov(static_cast(reg), 0); // call imm - return; - } - mov_imm(reg, dummyAddr); - putL(label); - } - void mov(const NativeReg& reg, const Label& label) - { - mov_imm(reg, dummyAddr); - putL(label); - } - void xchg(const Operand& op1, const Operand& op2) - { - const Operand *p1 = &op1, *p2 = &op2; - if (p1->isMEM() || (p2->isREG(16 | i32e) && p2->getIdx() == 0)) { - p1 = &op2; p2 = &op1; - } - if (p1->isMEM()) throw Error(ERR_BAD_COMBINATION); - if (p2->isREG() && (p1->isREG(16 | i32e) && p1->getIdx() == 0) -#ifdef XBYAK64 - && (p2->getIdx() != 0 || !p1->isREG(32)) -#endif - ) { - rex(*p2, *p1); db(0x90 | (p2->getIdx() & 7)); - return; - } - opModRM(*p1, *p2, (p1->isREG() && p2->isREG() && (p1->getBit() == p2->getBit())), p2->isMEM(), 0x86 | (p1->isBit(8) ? 
0 : 1)); - } - -#ifndef XBYAK_DISABLE_SEGMENT - void push(const Segment& seg) - { - switch (seg.getIdx()) { - case Segment::es: db(0x06); break; - case Segment::cs: db(0x0E); break; - case Segment::ss: db(0x16); break; - case Segment::ds: db(0x1E); break; - case Segment::fs: db(0x0F); db(0xA0); break; - case Segment::gs: db(0x0F); db(0xA8); break; - default: - assert(0); - } - } - void pop(const Segment& seg) - { - switch (seg.getIdx()) { - case Segment::es: db(0x07); break; - case Segment::cs: throw Error(ERR_BAD_COMBINATION); - case Segment::ss: db(0x17); break; - case Segment::ds: db(0x1F); break; - case Segment::fs: db(0x0F); db(0xA1); break; - case Segment::gs: db(0x0F); db(0xA9); break; - default: - assert(0); - } - } - void putSeg(const Segment& seg) - { - switch (seg.getIdx()) { - case Segment::es: db(0x2E); break; - case Segment::cs: db(0x36); break; - case Segment::ss: db(0x3E); break; - case Segment::ds: db(0x26); break; - case Segment::fs: db(0x64); break; - case Segment::gs: db(0x65); break; - default: - assert(0); - } - } - void mov(const Operand& op, const Segment& seg) - { - opModRM(Reg8(seg.getIdx()), op, op.isREG(16|i32e), op.isMEM(), 0x8C); - } - void mov(const Segment& seg, const Operand& op) - { - opModRM(Reg8(seg.getIdx()), op.isREG(16|i32e) ? static_cast(op.getReg().cvt32()) : op, op.isREG(16|i32e), op.isMEM(), 0x8E); - } -#endif - - enum { NONE = 256 }; - // constructor - CodeGenerator(size_t maxSize = DEFAULT_MAX_CODE_SIZE, void *userPtr = 0, Allocator *allocator = 0) - : CodeArray(maxSize, userPtr, allocator) - , mm0(0), mm1(1), mm2(2), mm3(3), mm4(4), mm5(5), mm6(6), mm7(7) - , xmm0(0), xmm1(1), xmm2(2), xmm3(3), xmm4(4), xmm5(5), xmm6(6), xmm7(7) - , ymm0(0), ymm1(1), ymm2(2), ymm3(3), ymm4(4), ymm5(5), ymm6(6), ymm7(7) - , zmm0(0), zmm1(1), zmm2(2), zmm3(3), zmm4(4), zmm5(5), zmm6(6), zmm7(7) - // for my convenience - , xm0(xmm0), xm1(xmm1), xm2(xmm2), xm3(xmm3), xm4(xmm4), xm5(xmm5), xm6(xmm6), xm7(xmm7) - , ym0(ymm0), ym1(ymm1), ym2(ymm2), ym3(ymm3), ym4(ymm4), ym5(ymm5), ym6(ymm6), ym7(ymm7) - , zm0(zmm0), zm1(zmm1), zm2(zmm2), zm3(zmm3), zm4(zmm4), zm5(zmm5), zm6(zmm6), zm7(zmm7) - - , eax(Operand::EAX), ecx(Operand::ECX), edx(Operand::EDX), ebx(Operand::EBX), esp(Operand::ESP), ebp(Operand::EBP), esi(Operand::ESI), edi(Operand::EDI) - , ax(Operand::AX), cx(Operand::CX), dx(Operand::DX), bx(Operand::BX), sp(Operand::SP), bp(Operand::BP), si(Operand::SI), di(Operand::DI) - , al(Operand::AL), cl(Operand::CL), dl(Operand::DL), bl(Operand::BL), ah(Operand::AH), ch(Operand::CH), dh(Operand::DH), bh(Operand::BH) - , ptr(0), byte(8), word(16), dword(32), qword(64), xword(128), yword(256), zword(512) - , ptr_b(0, true), xword_b(128, true), yword_b(256, true), zword_b(512, true) - , st0(0), st1(1), st2(2), st3(3), st4(4), st5(5), st6(6), st7(7) - , k0(0), k1(1), k2(2), k3(3), k4(4), k5(5), k6(6), k7(7) - , bnd0(0), bnd1(1), bnd2(2), bnd3(3) - , T_sae(EvexModifierRounding::T_SAE), T_rn_sae(EvexModifierRounding::T_RN_SAE), T_rd_sae(EvexModifierRounding::T_RD_SAE), T_ru_sae(EvexModifierRounding::T_RU_SAE), T_rz_sae(EvexModifierRounding::T_RZ_SAE) - , T_z() -#ifdef XBYAK64 - , rax(Operand::RAX), rcx(Operand::RCX), rdx(Operand::RDX), rbx(Operand::RBX), rsp(Operand::RSP), rbp(Operand::RBP), rsi(Operand::RSI), rdi(Operand::RDI), r8(Operand::R8), r9(Operand::R9), r10(Operand::R10), r11(Operand::R11), r12(Operand::R12), r13(Operand::R13), r14(Operand::R14), r15(Operand::R15) - , r8d(8), r9d(9), r10d(10), r11d(11), r12d(12), r13d(13), r14d(14), r15d(15) - , r8w(8), r9w(9), 
r10w(10), r11w(11), r12w(12), r13w(13), r14w(14), r15w(15) - , r8b(8), r9b(9), r10b(10), r11b(11), r12b(12), r13b(13), r14b(14), r15b(15) - , spl(Operand::SPL, true), bpl(Operand::BPL, true), sil(Operand::SIL, true), dil(Operand::DIL, true) - , xmm8(8), xmm9(9), xmm10(10), xmm11(11), xmm12(12), xmm13(13), xmm14(14), xmm15(15) - , xmm16(16), xmm17(17), xmm18(18), xmm19(19), xmm20(20), xmm21(21), xmm22(22), xmm23(23) - , xmm24(24), xmm25(25), xmm26(26), xmm27(27), xmm28(28), xmm29(29), xmm30(30), xmm31(31) - , ymm8(8), ymm9(9), ymm10(10), ymm11(11), ymm12(12), ymm13(13), ymm14(14), ymm15(15) - , ymm16(16), ymm17(17), ymm18(18), ymm19(19), ymm20(20), ymm21(21), ymm22(22), ymm23(23) - , ymm24(24), ymm25(25), ymm26(26), ymm27(27), ymm28(28), ymm29(29), ymm30(30), ymm31(31) - , zmm8(8), zmm9(9), zmm10(10), zmm11(11), zmm12(12), zmm13(13), zmm14(14), zmm15(15) - , zmm16(16), zmm17(17), zmm18(18), zmm19(19), zmm20(20), zmm21(21), zmm22(22), zmm23(23) - , zmm24(24), zmm25(25), zmm26(26), zmm27(27), zmm28(28), zmm29(29), zmm30(30), zmm31(31) - // for my convenience - , xm8(xmm8), xm9(xmm9), xm10(xmm10), xm11(xmm11), xm12(xmm12), xm13(xmm13), xm14(xmm14), xm15(xmm15) - , xm16(xmm16), xm17(xmm17), xm18(xmm18), xm19(xmm19), xm20(xmm20), xm21(xmm21), xm22(xmm22), xm23(xmm23) - , xm24(xmm24), xm25(xmm25), xm26(xmm26), xm27(xmm27), xm28(xmm28), xm29(xmm29), xm30(xmm30), xm31(xmm31) - , ym8(ymm8), ym9(ymm9), ym10(ymm10), ym11(ymm11), ym12(ymm12), ym13(ymm13), ym14(ymm14), ym15(ymm15) - , ym16(ymm16), ym17(ymm17), ym18(ymm18), ym19(ymm19), ym20(ymm20), ym21(ymm21), ym22(ymm22), ym23(ymm23) - , ym24(ymm24), ym25(ymm25), ym26(ymm26), ym27(ymm27), ym28(ymm28), ym29(ymm29), ym30(ymm30), ym31(ymm31) - , zm8(zmm8), zm9(zmm9), zm10(zmm10), zm11(zmm11), zm12(zmm12), zm13(zmm13), zm14(zmm14), zm15(zmm15) - , zm16(zmm16), zm17(zmm17), zm18(zmm18), zm19(zmm19), zm20(zmm20), zm21(zmm21), zm22(zmm22), zm23(zmm23) - , zm24(zmm24), zm25(zmm25), zm26(zmm26), zm27(zmm27), zm28(zmm28), zm29(zmm29), zm30(zmm30), zm31(zmm31) - , rip() -#endif -#ifndef XBYAK_DISABLE_SEGMENT - , es(Segment::es), cs(Segment::cs), ss(Segment::ss), ds(Segment::ds), fs(Segment::fs), gs(Segment::gs) -#endif - { - labelMgr_.set(this); - } - void reset() - { - resetSize(); - labelMgr_.reset(); - labelMgr_.set(this); - } - bool hasUndefinedLabel() const { return labelMgr_.hasUndefSlabel() || labelMgr_.hasUndefClabel(); } - /* - MUST call ready() to complete generating code if you use AutoGrow mode. - It is not necessary for the other mode if hasUndefinedLabel() is true. 
- */ - void ready(ProtectMode mode = PROTECT_RWE) - { - if (hasUndefinedLabel()) throw Error(ERR_LABEL_IS_NOT_FOUND); - if (isAutoGrow()) { - calcJmpAddress(); - if (useProtect()) setProtectMode(mode); - } - } - // set read/exec - void readyRE() { return ready(PROTECT_RE); } -#ifdef XBYAK_TEST - void dump(bool doClear = true) - { - CodeArray::dump(); - if (doClear) size_ = 0; - } -#endif - -#ifdef XBYAK_UNDEF_JNL - #undef jnl -#endif - - /* - use single byte nop if useMultiByteNop = false - */ - void nop(size_t size = 1, bool useMultiByteNop = true) - { - if (!useMultiByteNop) { - for (size_t i = 0; i < size; i++) { - db(0x90); - } - return; - } - /* - Intel Architectures Software Developer's Manual Volume 2 - recommended multi-byte sequence of NOP instruction - AMD and Intel seem to agree on the same sequences for up to 9 bytes: - https://support.amd.com/TechDocs/55723_SOG_Fam_17h_Processors_3.00.pdf - */ - static const uint8 nopTbl[9][9] = { - {0x90}, - {0x66, 0x90}, - {0x0F, 0x1F, 0x00}, - {0x0F, 0x1F, 0x40, 0x00}, - {0x0F, 0x1F, 0x44, 0x00, 0x00}, - {0x66, 0x0F, 0x1F, 0x44, 0x00, 0x00}, - {0x0F, 0x1F, 0x80, 0x00, 0x00, 0x00, 0x00}, - {0x0F, 0x1F, 0x84, 0x00, 0x00, 0x00, 0x00, 0x00}, - {0x66, 0x0F, 0x1F, 0x84, 0x00, 0x00, 0x00, 0x00, 0x00}, - }; - const size_t n = sizeof(nopTbl) / sizeof(nopTbl[0]); - while (size > 0) { - size_t len = (std::min)(n, size); - const uint8 *seq = nopTbl[len - 1]; - db(seq, len); - size -= len; - } - } - -#ifndef XBYAK_DONT_READ_LIST -#include "xbyak_mnemonic.h" - /* - use single byte nop if useMultiByteNop = false - */ - void align(size_t x = 16, bool useMultiByteNop = true) - { - if (x == 1) return; - if (x < 1 || (x & (x - 1))) throw Error(ERR_BAD_ALIGN); - if (isAutoGrow() && x > inner::ALIGN_PAGE_SIZE) fprintf(stderr, "warning:autoGrow mode does not support %d align\n", (int)x); - size_t remain = size_t(getCurr()) % x; - if (remain) { - nop(x - remain, useMultiByteNop); - } - } -#endif -}; - -namespace util { -static const Mmx mm0(0), mm1(1), mm2(2), mm3(3), mm4(4), mm5(5), mm6(6), mm7(7); -static const Xmm xmm0(0), xmm1(1), xmm2(2), xmm3(3), xmm4(4), xmm5(5), xmm6(6), xmm7(7); -static const Ymm ymm0(0), ymm1(1), ymm2(2), ymm3(3), ymm4(4), ymm5(5), ymm6(6), ymm7(7); -static const Zmm zmm0(0), zmm1(1), zmm2(2), zmm3(3), zmm4(4), zmm5(5), zmm6(6), zmm7(7); -static const Reg32 eax(Operand::EAX), ecx(Operand::ECX), edx(Operand::EDX), ebx(Operand::EBX), esp(Operand::ESP), ebp(Operand::EBP), esi(Operand::ESI), edi(Operand::EDI); -static const Reg16 ax(Operand::AX), cx(Operand::CX), dx(Operand::DX), bx(Operand::BX), sp(Operand::SP), bp(Operand::BP), si(Operand::SI), di(Operand::DI); -static const Reg8 al(Operand::AL), cl(Operand::CL), dl(Operand::DL), bl(Operand::BL), ah(Operand::AH), ch(Operand::CH), dh(Operand::DH), bh(Operand::BH); -static const AddressFrame ptr(0), byte(8), word(16), dword(32), qword(64), xword(128), yword(256), zword(512); -static const AddressFrame ptr_b(0, true), xword_b(128, true), yword_b(256, true), zword_b(512, true); -static const Fpu st0(0), st1(1), st2(2), st3(3), st4(4), st5(5), st6(6), st7(7); -static const Opmask k0(0), k1(1), k2(2), k3(3), k4(4), k5(5), k6(6), k7(7); -static const BoundsReg bnd0(0), bnd1(1), bnd2(2), bnd3(3); -static const EvexModifierRounding T_sae(EvexModifierRounding::T_SAE), T_rn_sae(EvexModifierRounding::T_RN_SAE), T_rd_sae(EvexModifierRounding::T_RD_SAE), T_ru_sae(EvexModifierRounding::T_RU_SAE), T_rz_sae(EvexModifierRounding::T_RZ_SAE); -static const EvexModifierZero T_z; -#ifdef XBYAK64 -static const 
Reg64 rax(Operand::RAX), rcx(Operand::RCX), rdx(Operand::RDX), rbx(Operand::RBX), rsp(Operand::RSP), rbp(Operand::RBP), rsi(Operand::RSI), rdi(Operand::RDI), r8(Operand::R8), r9(Operand::R9), r10(Operand::R10), r11(Operand::R11), r12(Operand::R12), r13(Operand::R13), r14(Operand::R14), r15(Operand::R15); -static const Reg32 r8d(8), r9d(9), r10d(10), r11d(11), r12d(12), r13d(13), r14d(14), r15d(15); -static const Reg16 r8w(8), r9w(9), r10w(10), r11w(11), r12w(12), r13w(13), r14w(14), r15w(15); -static const Reg8 r8b(8), r9b(9), r10b(10), r11b(11), r12b(12), r13b(13), r14b(14), r15b(15), spl(Operand::SPL, true), bpl(Operand::BPL, true), sil(Operand::SIL, true), dil(Operand::DIL, true); -static const Xmm xmm8(8), xmm9(9), xmm10(10), xmm11(11), xmm12(12), xmm13(13), xmm14(14), xmm15(15); -static const Xmm xmm16(16), xmm17(17), xmm18(18), xmm19(19), xmm20(20), xmm21(21), xmm22(22), xmm23(23); -static const Xmm xmm24(24), xmm25(25), xmm26(26), xmm27(27), xmm28(28), xmm29(29), xmm30(30), xmm31(31); -static const Ymm ymm8(8), ymm9(9), ymm10(10), ymm11(11), ymm12(12), ymm13(13), ymm14(14), ymm15(15); -static const Ymm ymm16(16), ymm17(17), ymm18(18), ymm19(19), ymm20(20), ymm21(21), ymm22(22), ymm23(23); -static const Ymm ymm24(24), ymm25(25), ymm26(26), ymm27(27), ymm28(28), ymm29(29), ymm30(30), ymm31(31); -static const Zmm zmm8(8), zmm9(9), zmm10(10), zmm11(11), zmm12(12), zmm13(13), zmm14(14), zmm15(15); -static const Zmm zmm16(16), zmm17(17), zmm18(18), zmm19(19), zmm20(20), zmm21(21), zmm22(22), zmm23(23); -static const Zmm zmm24(24), zmm25(25), zmm26(26), zmm27(27), zmm28(28), zmm29(29), zmm30(30), zmm31(31); -static const RegRip rip; -#endif -#ifndef XBYAK_DISABLE_SEGMENT -static const Segment es(Segment::es), cs(Segment::cs), ss(Segment::ss), ds(Segment::ds), fs(Segment::fs), gs(Segment::gs); -#endif -} // util - -#ifdef _MSC_VER - #pragma warning(pop) -#endif - -} // end of namespace - -#endif // XBYAK_XBYAK_H_ diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_bin2hex.h b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_bin2hex.h deleted file mode 100644 index a22e5224c..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_bin2hex.h +++ /dev/null @@ -1,303 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2019 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -/******************************************************************************* -* Copyright (c) 2007 MITSUNARI Shigeo -* All rights reserved. -* -* Redistribution and use in source and binary forms, with or without -* modification, are permitted provided that the following conditions are met: -* -* Redistributions of source code must retain the above copyright notice, this -* list of conditions and the following disclaimer. 
-* Redistributions in binary form must reproduce the above copyright notice, -* this list of conditions and the following disclaimer in the documentation -* and/or other materials provided with the distribution. -* Neither the name of the copyright owner nor the names of its contributors may -* be used to endorse or promote products derived from this software without -* specific prior written permission. -* -* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE -* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF -* THE POSSIBILITY OF SUCH DAMAGE. -*******************************************************************************/ - -enum { - B00000000= 0, - B00000001= 1, - B00000010= 2, - B00000011= 3, - B00000100= 4, - B00000101= 5, - B00000110= 6, - B00000111= 7, - B00001000= 8, - B00001001= 9, - B00001010= 10, - B00001011= 11, - B00001100= 12, - B00001101= 13, - B00001110= 14, - B00001111= 15, - B00010000= 16, - B00010001= 17, - B00010010= 18, - B00010011= 19, - B00010100= 20, - B00010101= 21, - B00010110= 22, - B00010111= 23, - B00011000= 24, - B00011001= 25, - B00011010= 26, - B00011011= 27, - B00011100= 28, - B00011101= 29, - B00011110= 30, - B00011111= 31, - B00100000= 32, - B00100001= 33, - B00100010= 34, - B00100011= 35, - B00100100= 36, - B00100101= 37, - B00100110= 38, - B00100111= 39, - B00101000= 40, - B00101001= 41, - B00101010= 42, - B00101011= 43, - B00101100= 44, - B00101101= 45, - B00101110= 46, - B00101111= 47, - B00110000= 48, - B00110001= 49, - B00110010= 50, - B00110011= 51, - B00110100= 52, - B00110101= 53, - B00110110= 54, - B00110111= 55, - B00111000= 56, - B00111001= 57, - B00111010= 58, - B00111011= 59, - B00111100= 60, - B00111101= 61, - B00111110= 62, - B00111111= 63, - B01000000= 64, - B01000001= 65, - B01000010= 66, - B01000011= 67, - B01000100= 68, - B01000101= 69, - B01000110= 70, - B01000111= 71, - B01001000= 72, - B01001001= 73, - B01001010= 74, - B01001011= 75, - B01001100= 76, - B01001101= 77, - B01001110= 78, - B01001111= 79, - B01010000= 80, - B01010001= 81, - B01010010= 82, - B01010011= 83, - B01010100= 84, - B01010101= 85, - B01010110= 86, - B01010111= 87, - B01011000= 88, - B01011001= 89, - B01011010= 90, - B01011011= 91, - B01011100= 92, - B01011101= 93, - B01011110= 94, - B01011111= 95, - B01100000= 96, - B01100001= 97, - B01100010= 98, - B01100011= 99, - B01100100= 100, - B01100101= 101, - B01100110= 102, - B01100111= 103, - B01101000= 104, - B01101001= 105, - B01101010= 106, - B01101011= 107, - B01101100= 108, - B01101101= 109, - B01101110= 110, - B01101111= 111, - B01110000= 112, - B01110001= 113, - B01110010= 114, - B01110011= 115, - B01110100= 116, - B01110101= 117, - B01110110= 118, - B01110111= 119, - B01111000= 120, - B01111001= 121, - B01111010= 122, - B01111011= 123, - B01111100= 124, - B01111101= 125, - B01111110= 126, - B01111111= 127, - B10000000= 128, - B10000001= 129, - B10000010= 130, - B10000011= 
131, - B10000100= 132, - B10000101= 133, - B10000110= 134, - B10000111= 135, - B10001000= 136, - B10001001= 137, - B10001010= 138, - B10001011= 139, - B10001100= 140, - B10001101= 141, - B10001110= 142, - B10001111= 143, - B10010000= 144, - B10010001= 145, - B10010010= 146, - B10010011= 147, - B10010100= 148, - B10010101= 149, - B10010110= 150, - B10010111= 151, - B10011000= 152, - B10011001= 153, - B10011010= 154, - B10011011= 155, - B10011100= 156, - B10011101= 157, - B10011110= 158, - B10011111= 159, - B10100000= 160, - B10100001= 161, - B10100010= 162, - B10100011= 163, - B10100100= 164, - B10100101= 165, - B10100110= 166, - B10100111= 167, - B10101000= 168, - B10101001= 169, - B10101010= 170, - B10101011= 171, - B10101100= 172, - B10101101= 173, - B10101110= 174, - B10101111= 175, - B10110000= 176, - B10110001= 177, - B10110010= 178, - B10110011= 179, - B10110100= 180, - B10110101= 181, - B10110110= 182, - B10110111= 183, - B10111000= 184, - B10111001= 185, - B10111010= 186, - B10111011= 187, - B10111100= 188, - B10111101= 189, - B10111110= 190, - B10111111= 191, - B11000000= 192, - B11000001= 193, - B11000010= 194, - B11000011= 195, - B11000100= 196, - B11000101= 197, - B11000110= 198, - B11000111= 199, - B11001000= 200, - B11001001= 201, - B11001010= 202, - B11001011= 203, - B11001100= 204, - B11001101= 205, - B11001110= 206, - B11001111= 207, - B11010000= 208, - B11010001= 209, - B11010010= 210, - B11010011= 211, - B11010100= 212, - B11010101= 213, - B11010110= 214, - B11010111= 215, - B11011000= 216, - B11011001= 217, - B11011010= 218, - B11011011= 219, - B11011100= 220, - B11011101= 221, - B11011110= 222, - B11011111= 223, - B11100000= 224, - B11100001= 225, - B11100010= 226, - B11100011= 227, - B11100100= 228, - B11100101= 229, - B11100110= 230, - B11100111= 231, - B11101000= 232, - B11101001= 233, - B11101010= 234, - B11101011= 235, - B11101100= 236, - B11101101= 237, - B11101110= 238, - B11101111= 239, - B11110000= 240, - B11110001= 241, - B11110010= 242, - B11110011= 243, - B11110100= 244, - B11110101= 245, - B11110110= 246, - B11110111= 247, - B11111000= 248, - B11111001= 249, - B11111010= 250, - B11111011= 251, - B11111100= 252, - B11111101= 253, - B11111110= 254, - B11111111= 255 -}; diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_mnemonic.h b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_mnemonic.h deleted file mode 100644 index 28d2d222f..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_mnemonic.h +++ /dev/null @@ -1,2017 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2019 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -/******************************************************************************* -* Copyright (c) 2007 MITSUNARI Shigeo -* All rights reserved. 
-* -* Redistribution and use in source and binary forms, with or without -* modification, are permitted provided that the following conditions are met: -* -* Redistributions of source code must retain the above copyright notice, this -* list of conditions and the following disclaimer. -* Redistributions in binary form must reproduce the above copyright notice, -* this list of conditions and the following disclaimer in the documentation -* and/or other materials provided with the distribution. -* Neither the name of the copyright owner nor the names of its contributors may -* be used to endorse or promote products derived from this software without -* specific prior written permission. -* -* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE -* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF -* THE POSSIBILITY OF SUCH DAMAGE. -*******************************************************************************/ - -const char *getVersionString() const { return "5.76"; } -void adc(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x10, 2); } -void adc(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x10); } -void adcx(const Reg32e& reg, const Operand& op) { opGen(reg, op, 0xF6, 0x66, isREG32_REG32orMEM, NONE, 0x38); } -void add(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x00, 0); } -void add(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x00); } -void addpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x58, 0x66, isXMM_XMMorMEM); } -void addps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x58, 0x100, isXMM_XMMorMEM); } -void addsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x58, 0xF2, isXMM_XMMorMEM); } -void addss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x58, 0xF3, isXMM_XMMorMEM); } -void addsubpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xD0, 0x66, isXMM_XMMorMEM); } -void addsubps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xD0, 0xF2, isXMM_XMMorMEM); } -void adox(const Reg32e& reg, const Operand& op) { opGen(reg, op, 0xF6, 0xF3, isREG32_REG32orMEM, NONE, 0x38); } -void aesdec(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xDE, 0x66, isXMM_XMMorMEM, NONE, 0x38); } -void aesdeclast(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xDF, 0x66, isXMM_XMMorMEM, NONE, 0x38); } -void aesenc(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xDC, 0x66, isXMM_XMMorMEM, NONE, 0x38); } -void aesenclast(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xDD, 0x66, isXMM_XMMorMEM, NONE, 0x38); } -void aesimc(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xDB, 0x66, isXMM_XMMorMEM, NONE, 0x38); } -void aeskeygenassist(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0xDF, 0x66, isXMM_XMMorMEM, imm, 0x3A); } -void and_(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x20, 4); } -void and_(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x20); 
} -void andn(const Reg32e& r1, const Reg32e& r2, const Operand& op) { opGpr(r1, r2, op, T_0F38, 0xf2, true); } -void andnpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x55, 0x66, isXMM_XMMorMEM); } -void andnps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x55, 0x100, isXMM_XMMorMEM); } -void andpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x54, 0x66, isXMM_XMMorMEM); } -void andps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x54, 0x100, isXMM_XMMorMEM); } -void bextr(const Reg32e& r1, const Operand& op, const Reg32e& r2) { opGpr(r1, op, r2, T_0F38, 0xf7, false); } -void blendpd(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x0D, 0x66, isXMM_XMMorMEM, static_cast<uint8>(imm), 0x3A); } -void blendps(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x0C, 0x66, isXMM_XMMorMEM, static_cast<uint8>(imm), 0x3A); } -void blendvpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x15, 0x66, isXMM_XMMorMEM, NONE, 0x38); } -void blendvps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x14, 0x66, isXMM_XMMorMEM, NONE, 0x38); } -void blsi(const Reg32e& r, const Operand& op) { opGpr(Reg32e(3, r.getBit()), op, r, T_0F38, 0xf3, false); } -void blsmsk(const Reg32e& r, const Operand& op) { opGpr(Reg32e(2, r.getBit()), op, r, T_0F38, 0xf3, false); } -void blsr(const Reg32e& r, const Operand& op) { opGpr(Reg32e(1, r.getBit()), op, r, T_0F38, 0xf3, false); } -void bnd() { db(0xF2); } -void bndcl(const BoundsReg& bnd, const Operand& op) { db(0xF3); opR_ModM(op, i32e, bnd.getIdx(), 0x0F, 0x1A, NONE, !op.isMEM()); } -void bndcn(const BoundsReg& bnd, const Operand& op) { db(0xF2); opR_ModM(op, i32e, bnd.getIdx(), 0x0F, 0x1B, NONE, !op.isMEM()); } -void bndcu(const BoundsReg& bnd, const Operand& op) { db(0xF2); opR_ModM(op, i32e, bnd.getIdx(), 0x0F, 0x1A, NONE, !op.isMEM()); } -void bndldx(const BoundsReg& bnd, const Address& addr) { opMIB(addr, bnd, 0x0F, 0x1A); } -void bndmk(const BoundsReg& bnd, const Address& addr) { db(0xF3); opModM(addr, bnd, 0x0F, 0x1B); } -void bndmov(const Address& addr, const BoundsReg& bnd) { db(0x66); opModM(addr, bnd, 0x0F, 0x1B); } -void bndmov(const BoundsReg& bnd, const Operand& op) { db(0x66); opModRM(bnd, op, op.isBNDREG(), op.isMEM(), 0x0F, 0x1A); } -void bndstx(const Address& addr, const BoundsReg& bnd) { opMIB(addr, bnd, 0x0F, 0x1B); } -void bsf(const Reg&reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0xBC); } -void bsr(const Reg&reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0xBD); } -void bswap(const Reg32e& reg) { opModR(Reg32(1), reg, 0x0F); } -void bt(const Operand& op, const Reg& reg) { opModRM(reg, op, op.isREG(16|32|64) && op.getBit() == reg.getBit(), op.isMEM(), 0x0f, 0xA3); } -void bt(const Operand& op, uint8 imm) { opR_ModM(op, 16|32|64, 4, 0x0f, 0xba, NONE, false, 1); db(imm); } -void btc(const Operand& op, const Reg& reg) { opModRM(reg, op, op.isREG(16|32|64) && op.getBit() == reg.getBit(), op.isMEM(), 0x0f, 0xBB); } -void btc(const Operand& op, uint8 imm) { opR_ModM(op, 16|32|64, 7, 0x0f, 0xba, NONE, false, 1); db(imm); } -void btr(const Operand& op, const Reg& reg) { opModRM(reg, op, op.isREG(16|32|64) && op.getBit() == reg.getBit(), op.isMEM(), 0x0f, 0xB3); } -void btr(const Operand& op, uint8 imm) { opR_ModM(op, 16|32|64, 6, 0x0f, 0xba, NONE, false, 1); db(imm); } -void bts(const Operand& op, const Reg& reg) { opModRM(reg, op, op.isREG(16|32|64) && op.getBit() == reg.getBit(), op.isMEM(), 0x0f, 0xAB); } -void bts(const Operand& op, uint8 imm) { 
opR_ModM(op, 16|32|64, 5, 0x0f, 0xba, NONE, false, 1); db(imm); } -void bzhi(const Reg32e& r1, const Operand& op, const Reg32e& r2) { opGpr(r1, op, r2, T_0F38, 0xf5, false); } -void cbw() { db(0x66); db(0x98); } -void cdq() { db(0x99); } -void clc() { db(0xF8); } -void cld() { db(0xFC); } -void clflush(const Address& addr) { opModM(addr, Reg32(7), 0x0F, 0xAE); } -void cli() { db(0xFA); } -void cmc() { db(0xF5); } -void cmova(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 7); }//-V524 -void cmovae(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 3); }//-V524 -void cmovb(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 2); }//-V524 -void cmovbe(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 6); }//-V524 -void cmovc(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 2); }//-V524 -void cmove(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 4); }//-V524 -void cmovg(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 15); }//-V524 -void cmovge(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 13); }//-V524 -void cmovl(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 12); }//-V524 -void cmovle(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 14); }//-V524 -void cmovna(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 6); }//-V524 -void cmovnae(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 2); }//-V524 -void cmovnb(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 3); }//-V524 -void cmovnbe(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 7); }//-V524 -void cmovnc(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 3); }//-V524 -void cmovne(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 5); }//-V524 -void cmovng(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 14); }//-V524 -void cmovnge(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 12); }//-V524 -void cmovnl(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 13); }//-V524 -void cmovnle(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 15); }//-V524 -void cmovno(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 1); }//-V524 -void cmovnp(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 11); }//-V524 -void cmovns(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 9); }//-V524 -void cmovnz(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 5); }//-V524 -void cmovo(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 
0x40 | 0); }//-V524 -void cmovp(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 10); }//-V524 -void cmovpe(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 10); }//-V524 -void cmovpo(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 11); }//-V524 -void cmovs(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 8); }//-V524 -void cmovz(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 4); }//-V524 -void cmp(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x38, 7); } -void cmp(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x38); } -void cmpeqpd(const Xmm& x, const Operand& op) { cmppd(x, op, 0); } -void cmpeqps(const Xmm& x, const Operand& op) { cmpps(x, op, 0); } -void cmpeqsd(const Xmm& x, const Operand& op) { cmpsd(x, op, 0); } -void cmpeqss(const Xmm& x, const Operand& op) { cmpss(x, op, 0); } -void cmplepd(const Xmm& x, const Operand& op) { cmppd(x, op, 2); } -void cmpleps(const Xmm& x, const Operand& op) { cmpps(x, op, 2); } -void cmplesd(const Xmm& x, const Operand& op) { cmpsd(x, op, 2); } -void cmpless(const Xmm& x, const Operand& op) { cmpss(x, op, 2); } -void cmpltpd(const Xmm& x, const Operand& op) { cmppd(x, op, 1); } -void cmpltps(const Xmm& x, const Operand& op) { cmpps(x, op, 1); } -void cmpltsd(const Xmm& x, const Operand& op) { cmpsd(x, op, 1); } -void cmpltss(const Xmm& x, const Operand& op) { cmpss(x, op, 1); } -void cmpneqpd(const Xmm& x, const Operand& op) { cmppd(x, op, 4); } -void cmpneqps(const Xmm& x, const Operand& op) { cmpps(x, op, 4); } -void cmpneqsd(const Xmm& x, const Operand& op) { cmpsd(x, op, 4); } -void cmpneqss(const Xmm& x, const Operand& op) { cmpss(x, op, 4); } -void cmpnlepd(const Xmm& x, const Operand& op) { cmppd(x, op, 6); } -void cmpnleps(const Xmm& x, const Operand& op) { cmpps(x, op, 6); } -void cmpnlesd(const Xmm& x, const Operand& op) { cmpsd(x, op, 6); } -void cmpnless(const Xmm& x, const Operand& op) { cmpss(x, op, 6); } -void cmpnltpd(const Xmm& x, const Operand& op) { cmppd(x, op, 5); } -void cmpnltps(const Xmm& x, const Operand& op) { cmpps(x, op, 5); } -void cmpnltsd(const Xmm& x, const Operand& op) { cmpsd(x, op, 5); } -void cmpnltss(const Xmm& x, const Operand& op) { cmpss(x, op, 5); } -void cmpordpd(const Xmm& x, const Operand& op) { cmppd(x, op, 7); } -void cmpordps(const Xmm& x, const Operand& op) { cmpps(x, op, 7); } -void cmpordsd(const Xmm& x, const Operand& op) { cmpsd(x, op, 7); } -void cmpordss(const Xmm& x, const Operand& op) { cmpss(x, op, 7); } -void cmppd(const Xmm& xmm, const Operand& op, uint8 imm8) { opGen(xmm, op, 0xC2, 0x66, isXMM_XMMorMEM, imm8); } -void cmpps(const Xmm& xmm, const Operand& op, uint8 imm8) { opGen(xmm, op, 0xC2, 0x100, isXMM_XMMorMEM, imm8); } -void cmpsb() { db(0xA6); } -void cmpsd() { db(0xA7); } -void cmpsd(const Xmm& xmm, const Operand& op, uint8 imm8) { opGen(xmm, op, 0xC2, 0xF2, isXMM_XMMorMEM, imm8); } -void cmpss(const Xmm& xmm, const Operand& op, uint8 imm8) { opGen(xmm, op, 0xC2, 0xF3, isXMM_XMMorMEM, imm8); } -void cmpsw() { db(0x66); db(0xA7); } -void cmpunordpd(const Xmm& x, const Operand& op) { cmppd(x, op, 3); } -void cmpunordps(const Xmm& x, const Operand& op) { cmpps(x, op, 3); } -void cmpunordsd(const Xmm& x, const Operand& op) { cmpsd(x, op, 3); } -void cmpunordss(const Xmm& x, const Operand& op) { cmpss(x, op, 3); } 
-void cmpxchg(const Operand& op, const Reg& reg) { opModRM(reg, op, (op.isREG() && reg.isREG() && op.getBit() == reg.getBit()), op.isMEM(), 0x0F, 0xB0 | (reg.isBit(8) ? 0 : 1)); } -void cmpxchg8b(const Address& addr) { opModM(addr, Reg32(1), 0x0F, 0xC7); } -void comisd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x2F, 0x66, isXMM_XMMorMEM); } -void comiss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x2F, 0x100, isXMM_XMMorMEM); } -void cpuid() { db(0x0F); db(0xA2); } -void crc32(const Reg32e& reg, const Operand& op) { if (reg.isBit(32) && op.isBit(16)) db(0x66); db(0xF2); opModRM(reg, op, op.isREG(), op.isMEM(), 0x0F, 0x38, 0xF0 | (op.isBit(8) ? 0 : 1)); } -void cvtdq2pd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xE6, 0xF3, isXMM_XMMorMEM); } -void cvtdq2ps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5B, 0x100, isXMM_XMMorMEM); } -void cvtpd2dq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xE6, 0xF2, isXMM_XMMorMEM); } -void cvtpd2pi(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2D, 0x66, isMMX_XMMorMEM); } -void cvtpd2ps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5A, 0x66, isXMM_XMMorMEM); } -void cvtpi2pd(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2A, 0x66, isXMM_MMXorMEM); } -void cvtpi2ps(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2A, 0x100, isXMM_MMXorMEM); } -void cvtps2dq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5B, 0x66, isXMM_XMMorMEM); } -void cvtps2pd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5A, 0x100, isXMM_XMMorMEM); } -void cvtps2pi(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2D, 0x100, isMMX_XMMorMEM); } -void cvtsd2si(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2D, 0xF2, isREG32_XMMorMEM); } -void cvtsd2ss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5A, 0xF2, isXMM_XMMorMEM); } -void cvtsi2sd(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2A, 0xF2, isXMM_REG32orMEM); } -void cvtsi2ss(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2A, 0xF3, isXMM_REG32orMEM); } -void cvtss2sd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5A, 0xF3, isXMM_XMMorMEM); } -void cvtss2si(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2D, 0xF3, isREG32_XMMorMEM); } -void cvttpd2dq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xE6, 0x66, isXMM_XMMorMEM); } -void cvttpd2pi(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2C, 0x66, isMMX_XMMorMEM); } -void cvttps2dq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5B, 0xF3, isXMM_XMMorMEM); } -void cvttps2pi(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2C, 0x100, isMMX_XMMorMEM); } -void cvttsd2si(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2C, 0xF2, isREG32_XMMorMEM); } -void cvttss2si(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2C, 0xF3, isREG32_XMMorMEM); } -void cwd() { db(0x66); db(0x99); } -void cwde() { db(0x98); } -void dec(const Operand& op) { opIncDec(op, 0x48, 1); } -void div(const Operand& op) { opR_ModM(op, 0, 6, 0xF6); } -void divpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5E, 0x66, isXMM_XMMorMEM); } -void divps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5E, 0x100, isXMM_XMMorMEM); } -void divsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5E, 0xF2, isXMM_XMMorMEM); } -void divss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5E, 0xF3, isXMM_XMMorMEM); } -void dppd(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x41, 0x66, 
isXMM_XMMorMEM, static_cast(imm), 0x3A); } -void dpps(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x40, 0x66, isXMM_XMMorMEM, static_cast(imm), 0x3A); } -void emms() { db(0x0F); db(0x77); } -void extractps(const Operand& op, const Xmm& xmm, uint8 imm) { opExt(op, xmm, 0x17, imm); } -void f2xm1() { db(0xD9); db(0xF0); } -void fabs() { db(0xD9); db(0xE1); } -void fadd(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 0, 0); } -void fadd(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xD8C0, 0xDCC0); } -void fadd(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xD8C0, 0xDCC0); } -void faddp() { db(0xDE); db(0xC1); } -void faddp(const Fpu& reg1) { opFpuFpu(reg1, st0, 0x0000, 0xDEC0); } -void faddp(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0x0000, 0xDEC0); } -void fchs() { db(0xD9); db(0xE0); } -void fcmovb(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDAC0, 0x00C0); } -void fcmovb(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDAC0, 0x00C0); } -void fcmovbe(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDAD0, 0x00D0); } -void fcmovbe(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDAD0, 0x00D0); } -void fcmove(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDAC8, 0x00C8); } -void fcmove(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDAC8, 0x00C8); } -void fcmovnb(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDBC0, 0x00C0); } -void fcmovnb(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDBC0, 0x00C0); } -void fcmovnbe(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDBD0, 0x00D0); } -void fcmovnbe(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDBD0, 0x00D0); } -void fcmovne(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDBC8, 0x00C8); } -void fcmovne(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDBC8, 0x00C8); } -void fcmovnu(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDBD8, 0x00D8); } -void fcmovnu(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDBD8, 0x00D8); } -void fcmovu(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDAD8, 0x00D8); } -void fcmovu(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDAD8, 0x00D8); } -void fcom() { db(0xD8); db(0xD1); } -void fcom(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 2, 0); } -void fcom(const Fpu& reg) { opFpu(reg, 0xD8, 0xD0); } -void fcomi(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDBF0, 0x00F0); } -void fcomi(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDBF0, 0x00F0); } -void fcomip(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDFF0, 0x00F0); } -void fcomip(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDFF0, 0x00F0); } -void fcomp() { db(0xD8); db(0xD9); } -void fcomp(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 3, 0); } -void fcomp(const Fpu& reg) { opFpu(reg, 0xD8, 0xD8); } -void fcompp() { db(0xDE); db(0xD9); } -void fcos() { db(0xD9); db(0xFF); } -void fdecstp() { db(0xD9); db(0xF6); } -void fdiv(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 6, 0); } -void fdiv(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xD8F0, 0xDCF8); } -void fdiv(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xD8F0, 0xDCF8); } -void fdivp() { db(0xDE); db(0xF9); } -void fdivp(const Fpu& reg1) { opFpuFpu(reg1, st0, 0x0000, 0xDEF8); } -void fdivp(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0x0000, 0xDEF8); } -void fdivr(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 7, 0); } -void fdivr(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xD8F8, 0xDCF0); } -void fdivr(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 
0xD8F8, 0xDCF0); } -void fdivrp() { db(0xDE); db(0xF1); } -void fdivrp(const Fpu& reg1) { opFpuFpu(reg1, st0, 0x0000, 0xDEF0); } -void fdivrp(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0x0000, 0xDEF0); } -void ffree(const Fpu& reg) { opFpu(reg, 0xDD, 0xC0); } -void fiadd(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 0, 0); } -void ficom(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 2, 0); } -void ficomp(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 3, 0); } -void fidiv(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 6, 0); } -void fidivr(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 7, 0); } -void fild(const Address& addr) { opFpuMem(addr, 0xDF, 0xDB, 0xDF, 0, 5); } -void fimul(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 1, 0); } -void fincstp() { db(0xD9); db(0xF7); } -void finit() { db(0x9B); db(0xDB); db(0xE3); } -void fist(const Address& addr) { opFpuMem(addr, 0xDF, 0xDB, 0x00, 2, 0); } -void fistp(const Address& addr) { opFpuMem(addr, 0xDF, 0xDB, 0xDF, 3, 7); } -void fisttp(const Address& addr) { opFpuMem(addr, 0xDF, 0xDB, 0xDD, 1, 0); } -void fisub(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 4, 0); } -void fisubr(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 5, 0); } -void fld(const Address& addr) { opFpuMem(addr, 0x00, 0xD9, 0xDD, 0, 0); } -void fld(const Fpu& reg) { opFpu(reg, 0xD9, 0xC0); } -void fld1() { db(0xD9); db(0xE8); } -void fldcw(const Address& addr) { opModM(addr, Reg32(5), 0xD9, 0x100); } -void fldl2e() { db(0xD9); db(0xEA); } -void fldl2t() { db(0xD9); db(0xE9); } -void fldlg2() { db(0xD9); db(0xEC); } -void fldln2() { db(0xD9); db(0xED); } -void fldpi() { db(0xD9); db(0xEB); } -void fldz() { db(0xD9); db(0xEE); } -void fmul(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 1, 0); } -void fmul(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xD8C8, 0xDCC8); } -void fmul(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xD8C8, 0xDCC8); } -void fmulp() { db(0xDE); db(0xC9); } -void fmulp(const Fpu& reg1) { opFpuFpu(reg1, st0, 0x0000, 0xDEC8); } -void fmulp(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0x0000, 0xDEC8); } -void fninit() { db(0xDB); db(0xE3); } -void fnop() { db(0xD9); db(0xD0); } -void fpatan() { db(0xD9); db(0xF3); } -void fprem() { db(0xD9); db(0xF8); } -void fprem1() { db(0xD9); db(0xF5); } -void fptan() { db(0xD9); db(0xF2); } -void frndint() { db(0xD9); db(0xFC); } -void fscale() { db(0xD9); db(0xFD); } -void fsin() { db(0xD9); db(0xFE); } -void fsincos() { db(0xD9); db(0xFB); } -void fsqrt() { db(0xD9); db(0xFA); } -void fst(const Address& addr) { opFpuMem(addr, 0x00, 0xD9, 0xDD, 2, 0); } -void fst(const Fpu& reg) { opFpu(reg, 0xDD, 0xD0); } -void fstcw(const Address& addr) { db(0x9B); opModM(addr, Reg32(7), 0xD9, NONE); } -void fstp(const Address& addr) { opFpuMem(addr, 0x00, 0xD9, 0xDD, 3, 0); } -void fstp(const Fpu& reg) { opFpu(reg, 0xDD, 0xD8); } -void fsub(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 4, 0); } -void fsub(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xD8E0, 0xDCE8); } -void fsub(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xD8E0, 0xDCE8); } -void fsubp() { db(0xDE); db(0xE9); } -void fsubp(const Fpu& reg1) { opFpuFpu(reg1, st0, 0x0000, 0xDEE8); } -void fsubp(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0x0000, 0xDEE8); } -void fsubr(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 5, 0); } -void fsubr(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xD8E8, 0xDCE0); } -void fsubr(const Fpu& 
reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xD8E8, 0xDCE0); } -void fsubrp() { db(0xDE); db(0xE1); } -void fsubrp(const Fpu& reg1) { opFpuFpu(reg1, st0, 0x0000, 0xDEE0); } -void fsubrp(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0x0000, 0xDEE0); } -void ftst() { db(0xD9); db(0xE4); } -void fucom() { db(0xDD); db(0xE1); } -void fucom(const Fpu& reg) { opFpu(reg, 0xDD, 0xE0); } -void fucomi(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDBE8, 0x00E8); } -void fucomi(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDBE8, 0x00E8); } -void fucomip(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDFE8, 0x00E8); } -void fucomip(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDFE8, 0x00E8); } -void fucomp() { db(0xDD); db(0xE9); } -void fucomp(const Fpu& reg) { opFpu(reg, 0xDD, 0xE8); } -void fucompp() { db(0xDA); db(0xE9); } -void fwait() { db(0x9B); } -void fxam() { db(0xD9); db(0xE5); } -void fxch() { db(0xD9); db(0xC9); } -void fxch(const Fpu& reg) { opFpu(reg, 0xD9, 0xC8); } -void fxtract() { db(0xD9); db(0xF4); } -void fyl2x() { db(0xD9); db(0xF1); } -void fyl2xp1() { db(0xD9); db(0xF9); } -void gf2p8affineinvqb(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0xCF, 0x66, isXMM_XMMorMEM, static_cast(imm), 0x3A); } -void gf2p8affineqb(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0xCE, 0x66, isXMM_XMMorMEM, static_cast(imm), 0x3A); } -void gf2p8mulb(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xCF, 0x66, isXMM_XMMorMEM, NONE, 0x38); } -void haddpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x7C, 0x66, isXMM_XMMorMEM); } -void haddps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x7C, 0xF2, isXMM_XMMorMEM); } -void hsubpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x7D, 0x66, isXMM_XMMorMEM); } -void hsubps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x7D, 0xF2, isXMM_XMMorMEM); } -void idiv(const Operand& op) { opR_ModM(op, 0, 7, 0xF6); } -void imul(const Operand& op) { opR_ModM(op, 0, 5, 0xF6); } -void inc(const Operand& op) { opIncDec(op, 0x40, 0); } -void insertps(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x21, 0x66, isXMM_XMMorMEM, imm, 0x3A); } -void ja(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x77, 0x87, 0x0F); }//-V524 -void ja(const char *label, LabelType type = T_AUTO) { ja(std::string(label), type); }//-V524 -void ja(const void *addr) { opJmpAbs(addr, T_NEAR, 0x77, 0x87, 0x0F); }//-V524 -void ja(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x77, 0x87, 0x0F); }//-V524 -void jae(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x73, 0x83, 0x0F); }//-V524 -void jae(const char *label, LabelType type = T_AUTO) { jae(std::string(label), type); }//-V524 -void jae(const void *addr) { opJmpAbs(addr, T_NEAR, 0x73, 0x83, 0x0F); }//-V524 -void jae(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x73, 0x83, 0x0F); }//-V524 -void jb(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x72, 0x82, 0x0F); }//-V524 -void jb(const char *label, LabelType type = T_AUTO) { jb(std::string(label), type); }//-V524 -void jb(const void *addr) { opJmpAbs(addr, T_NEAR, 0x72, 0x82, 0x0F); }//-V524 -void jb(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x72, 0x82, 0x0F); }//-V524 -void jbe(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x76, 0x86, 0x0F); }//-V524 -void jbe(const char *label, LabelType type = T_AUTO) { jbe(std::string(label), type); }//-V524 -void jbe(const 
void *addr) { opJmpAbs(addr, T_NEAR, 0x76, 0x86, 0x0F); }//-V524 -void jbe(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x76, 0x86, 0x0F); }//-V524 -void jc(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x72, 0x82, 0x0F); }//-V524 -void jc(const char *label, LabelType type = T_AUTO) { jc(std::string(label), type); }//-V524 -void jc(const void *addr) { opJmpAbs(addr, T_NEAR, 0x72, 0x82, 0x0F); }//-V524 -void jc(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x72, 0x82, 0x0F); }//-V524 -void je(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x74, 0x84, 0x0F); }//-V524 -void je(const char *label, LabelType type = T_AUTO) { je(std::string(label), type); }//-V524 -void je(const void *addr) { opJmpAbs(addr, T_NEAR, 0x74, 0x84, 0x0F); }//-V524 -void je(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x74, 0x84, 0x0F); }//-V524 -void jg(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7F, 0x8F, 0x0F); }//-V524 -void jg(const char *label, LabelType type = T_AUTO) { jg(std::string(label), type); }//-V524 -void jg(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7F, 0x8F, 0x0F); }//-V524 -void jg(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7F, 0x8F, 0x0F); }//-V524 -void jge(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7D, 0x8D, 0x0F); }//-V524 -void jge(const char *label, LabelType type = T_AUTO) { jge(std::string(label), type); }//-V524 -void jge(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7D, 0x8D, 0x0F); }//-V524 -void jge(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7D, 0x8D, 0x0F); }//-V524 -void jl(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7C, 0x8C, 0x0F); }//-V524 -void jl(const char *label, LabelType type = T_AUTO) { jl(std::string(label), type); }//-V524 -void jl(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7C, 0x8C, 0x0F); }//-V524 -void jl(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7C, 0x8C, 0x0F); }//-V524 -void jle(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7E, 0x8E, 0x0F); }//-V524 -void jle(const char *label, LabelType type = T_AUTO) { jle(std::string(label), type); }//-V524 -void jle(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7E, 0x8E, 0x0F); }//-V524 -void jle(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7E, 0x8E, 0x0F); }//-V524 -void jna(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x76, 0x86, 0x0F); }//-V524 -void jna(const char *label, LabelType type = T_AUTO) { jna(std::string(label), type); }//-V524 -void jna(const void *addr) { opJmpAbs(addr, T_NEAR, 0x76, 0x86, 0x0F); }//-V524 -void jna(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x76, 0x86, 0x0F); }//-V524 -void jnae(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x72, 0x82, 0x0F); }//-V524 -void jnae(const char *label, LabelType type = T_AUTO) { jnae(std::string(label), type); }//-V524 -void jnae(const void *addr) { opJmpAbs(addr, T_NEAR, 0x72, 0x82, 0x0F); }//-V524 -void jnae(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x72, 0x82, 0x0F); }//-V524 -void jnb(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x73, 0x83, 0x0F); }//-V524 -void jnb(const char *label, LabelType type = T_AUTO) { jnb(std::string(label), type); }//-V524 -void jnb(const void *addr) { opJmpAbs(addr, T_NEAR, 0x73, 0x83, 0x0F); }//-V524 -void jnb(std::string label, LabelType type = 
T_AUTO) { opJmp(label, type, 0x73, 0x83, 0x0F); }//-V524 -void jnbe(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x77, 0x87, 0x0F); }//-V524 -void jnbe(const char *label, LabelType type = T_AUTO) { jnbe(std::string(label), type); }//-V524 -void jnbe(const void *addr) { opJmpAbs(addr, T_NEAR, 0x77, 0x87, 0x0F); }//-V524 -void jnbe(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x77, 0x87, 0x0F); }//-V524 -void jnc(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x73, 0x83, 0x0F); }//-V524 -void jnc(const char *label, LabelType type = T_AUTO) { jnc(std::string(label), type); }//-V524 -void jnc(const void *addr) { opJmpAbs(addr, T_NEAR, 0x73, 0x83, 0x0F); }//-V524 -void jnc(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x73, 0x83, 0x0F); }//-V524 -void jne(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x75, 0x85, 0x0F); }//-V524 -void jne(const char *label, LabelType type = T_AUTO) { jne(std::string(label), type); }//-V524 -void jne(const void *addr) { opJmpAbs(addr, T_NEAR, 0x75, 0x85, 0x0F); }//-V524 -void jne(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x75, 0x85, 0x0F); }//-V524 -void jng(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7E, 0x8E, 0x0F); }//-V524 -void jng(const char *label, LabelType type = T_AUTO) { jng(std::string(label), type); }//-V524 -void jng(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7E, 0x8E, 0x0F); }//-V524 -void jng(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7E, 0x8E, 0x0F); }//-V524 -void jnge(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7C, 0x8C, 0x0F); }//-V524 -void jnge(const char *label, LabelType type = T_AUTO) { jnge(std::string(label), type); }//-V524 -void jnge(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7C, 0x8C, 0x0F); }//-V524 -void jnge(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7C, 0x8C, 0x0F); }//-V524 -void jnl(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7D, 0x8D, 0x0F); }//-V524 -void jnl(const char *label, LabelType type = T_AUTO) { jnl(std::string(label), type); }//-V524 -void jnl(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7D, 0x8D, 0x0F); }//-V524 -void jnl(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7D, 0x8D, 0x0F); }//-V524 -void jnle(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7F, 0x8F, 0x0F); }//-V524 -void jnle(const char *label, LabelType type = T_AUTO) { jnle(std::string(label), type); }//-V524 -void jnle(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7F, 0x8F, 0x0F); }//-V524 -void jnle(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7F, 0x8F, 0x0F); }//-V524 -void jno(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x71, 0x81, 0x0F); }//-V524 -void jno(const char *label, LabelType type = T_AUTO) { jno(std::string(label), type); }//-V524 -void jno(const void *addr) { opJmpAbs(addr, T_NEAR, 0x71, 0x81, 0x0F); }//-V524 -void jno(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x71, 0x81, 0x0F); }//-V524 -void jnp(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7B, 0x8B, 0x0F); }//-V524 -void jnp(const char *label, LabelType type = T_AUTO) { jnp(std::string(label), type); }//-V524 -void jnp(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7B, 0x8B, 0x0F); }//-V524 -void jnp(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7B, 0x8B, 0x0F); }//-V524 -void jns(const Label& 
label, LabelType type = T_AUTO) { opJmp(label, type, 0x79, 0x89, 0x0F); }//-V524 -void jns(const char *label, LabelType type = T_AUTO) { jns(std::string(label), type); }//-V524 -void jns(const void *addr) { opJmpAbs(addr, T_NEAR, 0x79, 0x89, 0x0F); }//-V524 -void jns(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x79, 0x89, 0x0F); }//-V524 -void jnz(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x75, 0x85, 0x0F); }//-V524 -void jnz(const char *label, LabelType type = T_AUTO) { jnz(std::string(label), type); }//-V524 -void jnz(const void *addr) { opJmpAbs(addr, T_NEAR, 0x75, 0x85, 0x0F); }//-V524 -void jnz(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x75, 0x85, 0x0F); }//-V524 -void jo(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x70, 0x80, 0x0F); }//-V524 -void jo(const char *label, LabelType type = T_AUTO) { jo(std::string(label), type); }//-V524 -void jo(const void *addr) { opJmpAbs(addr, T_NEAR, 0x70, 0x80, 0x0F); }//-V524 -void jo(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x70, 0x80, 0x0F); }//-V524 -void jp(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7A, 0x8A, 0x0F); }//-V524 -void jp(const char *label, LabelType type = T_AUTO) { jp(std::string(label), type); }//-V524 -void jp(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7A, 0x8A, 0x0F); }//-V524 -void jp(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7A, 0x8A, 0x0F); }//-V524 -void jpe(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7A, 0x8A, 0x0F); }//-V524 -void jpe(const char *label, LabelType type = T_AUTO) { jpe(std::string(label), type); }//-V524 -void jpe(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7A, 0x8A, 0x0F); }//-V524 -void jpe(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7A, 0x8A, 0x0F); }//-V524 -void jpo(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7B, 0x8B, 0x0F); }//-V524 -void jpo(const char *label, LabelType type = T_AUTO) { jpo(std::string(label), type); }//-V524 -void jpo(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7B, 0x8B, 0x0F); }//-V524 -void jpo(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7B, 0x8B, 0x0F); }//-V524 -void js(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x78, 0x88, 0x0F); }//-V524 -void js(const char *label, LabelType type = T_AUTO) { js(std::string(label), type); }//-V524 -void js(const void *addr) { opJmpAbs(addr, T_NEAR, 0x78, 0x88, 0x0F); }//-V524 -void js(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x78, 0x88, 0x0F); }//-V524 -void jz(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x74, 0x84, 0x0F); }//-V524 -void jz(const char *label, LabelType type = T_AUTO) { jz(std::string(label), type); }//-V524 -void jz(const void *addr) { opJmpAbs(addr, T_NEAR, 0x74, 0x84, 0x0F); }//-V524 -void jz(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x74, 0x84, 0x0F); }//-V524 -void lahf() { db(0x9F); } -void lddqu(const Xmm& xmm, const Address& addr) { db(0xF2); opModM(addr, xmm, 0x0F, 0xF0); } -void ldmxcsr(const Address& addr) { opModM(addr, Reg32(2), 0x0F, 0xAE); } -void lea(const Reg& reg, const Address& addr) { if (!reg.isBit(16 | i32e)) throw Error(ERR_BAD_SIZE_OF_REGISTER); opModM(addr, reg, 0x8D); } -void lfence() { db(0x0F); db(0xAE); db(0xE8); } -void lock() { db(0xF0); } -void lzcnt(const Reg&reg, const Operand& op) { opSp1(reg, op, 0xF3, 0x0F, 0xBD); } -void maskmovdqu(const Xmm&
reg1, const Xmm& reg2) { db(0x66); opModR(reg1, reg2, 0x0F, 0xF7); } -void maskmovq(const Mmx& reg1, const Mmx& reg2) { if (!reg1.isMMX() || !reg2.isMMX()) throw Error(ERR_BAD_COMBINATION); opModR(reg1, reg2, 0x0F, 0xF7); } -void maxpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5F, 0x66, isXMM_XMMorMEM); } -void maxps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5F, 0x100, isXMM_XMMorMEM); } -void maxsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5F, 0xF2, isXMM_XMMorMEM); } -void maxss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5F, 0xF3, isXMM_XMMorMEM); } -void mfence() { db(0x0F); db(0xAE); db(0xF0); } -void minpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5D, 0x66, isXMM_XMMorMEM); } -void minps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5D, 0x100, isXMM_XMMorMEM); } -void minsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5D, 0xF2, isXMM_XMMorMEM); } -void minss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5D, 0xF3, isXMM_XMMorMEM); } -void monitor() { db(0x0F); db(0x01); db(0xC8); } -void movapd(const Address& addr, const Xmm& xmm) { db(0x66); opModM(addr, xmm, 0x0F, 0x29); } -void movapd(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x28, 0x66); } -void movaps(const Address& addr, const Xmm& xmm) { opModM(addr, xmm, 0x0F, 0x29); } -void movaps(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x28, 0x100); } -void movbe(const Address& addr, const Reg& reg) { opModM(addr, reg, 0x0F, 0x38, 0xF1); } -void movbe(const Reg& reg, const Address& addr) { opModM(addr, reg, 0x0F, 0x38, 0xF0); } -void movd(const Address& addr, const Mmx& mmx) { if (mmx.isXMM()) db(0x66); opModM(addr, mmx, 0x0F, 0x7E); } -void movd(const Mmx& mmx, const Address& addr) { if (mmx.isXMM()) db(0x66); opModM(addr, mmx, 0x0F, 0x6E); } -void movd(const Mmx& mmx, const Reg32& reg) { if (mmx.isXMM()) db(0x66); opModR(mmx, reg, 0x0F, 0x6E); } -void movd(const Reg32& reg, const Mmx& mmx) { if (mmx.isXMM()) db(0x66); opModR(mmx, reg, 0x0F, 0x7E); } -void movddup(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x12, 0xF2, isXMM_XMMorMEM, NONE, NONE); } -void movdq2q(const Mmx& mmx, const Xmm& xmm) { db(0xF2); opModR(mmx, xmm, 0x0F, 0xD6); } -void movdqa(const Address& addr, const Xmm& xmm) { db(0x66); opModM(addr, xmm, 0x0F, 0x7F); } -void movdqa(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x6F, 0x66); } -void movdqu(const Address& addr, const Xmm& xmm) { db(0xF3); opModM(addr, xmm, 0x0F, 0x7F); } -void movdqu(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x6F, 0xF3); } -void movhlps(const Xmm& reg1, const Xmm& reg2) { opModR(reg1, reg2, 0x0F, 0x12); } -void movhpd(const Operand& op1, const Operand& op2) { opMovXMM(op1, op2, 0x16, 0x66); } -void movhps(const Operand& op1, const Operand& op2) { opMovXMM(op1, op2, 0x16, 0x100); } -void movlhps(const Xmm& reg1, const Xmm& reg2) { opModR(reg1, reg2, 0x0F, 0x16); } -void movlpd(const Operand& op1, const Operand& op2) { opMovXMM(op1, op2, 0x12, 0x66); } -void movlps(const Operand& op1, const Operand& op2) { opMovXMM(op1, op2, 0x12, 0x100); } -void movmskpd(const Reg32e& reg, const Xmm& xmm) { db(0x66); movmskps(reg, xmm); } -void movmskps(const Reg32e& reg, const Xmm& xmm) { opModR(reg, xmm, 0x0F, 0x50); } -void movntdq(const Address& addr, const Xmm& reg) { opModM(addr, Reg16(reg.getIdx()), 0x0F, 0xE7); } -void movntdqa(const Xmm& xmm, const Address& addr) { db(0x66); opModM(addr, xmm, 0x0F, 0x38, 0x2A); } -void movnti(const Address& addr, const Reg32e& reg) { opModM(addr, 
reg, 0x0F, 0xC3); } -void movntpd(const Address& addr, const Xmm& reg) { opModM(addr, Reg16(reg.getIdx()), 0x0F, 0x2B); } -void movntps(const Address& addr, const Xmm& xmm) { opModM(addr, Mmx(xmm.getIdx()), 0x0F, 0x2B); } -void movntq(const Address& addr, const Mmx& mmx) { if (!mmx.isMMX()) throw Error(ERR_BAD_COMBINATION); opModM(addr, mmx, 0x0F, 0xE7); } -void movq(const Address& addr, const Mmx& mmx) { if (mmx.isXMM()) db(0x66); opModM(addr, mmx, 0x0F, mmx.isXMM() ? 0xD6 : 0x7F); } -void movq(const Mmx& mmx, const Operand& op) { if (mmx.isXMM()) db(0xF3); opModRM(mmx, op, (mmx.getKind() == op.getKind()), op.isMEM(), 0x0F, mmx.isXMM() ? 0x7E : 0x6F); } -void movq2dq(const Xmm& xmm, const Mmx& mmx) { db(0xF3); opModR(xmm, mmx, 0x0F, 0xD6); } -void movsb() { db(0xA4); } -void movsd() { db(0xA5); } -void movsd(const Address& addr, const Xmm& xmm) { db(0xF2); opModM(addr, xmm, 0x0F, 0x11); } -void movsd(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x10, 0xF2); } -void movshdup(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x16, 0xF3, isXMM_XMMorMEM, NONE, NONE); } -void movsldup(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x12, 0xF3, isXMM_XMMorMEM, NONE, NONE); } -void movss(const Address& addr, const Xmm& xmm) { db(0xF3); opModM(addr, xmm, 0x0F, 0x11); } -void movss(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x10, 0xF3); } -void movsw() { db(0x66); db(0xA5); } -void movsx(const Reg& reg, const Operand& op) { opMovxx(reg, op, 0xBE); } -void movupd(const Address& addr, const Xmm& xmm) { db(0x66); opModM(addr, xmm, 0x0F, 0x11); } -void movupd(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x10, 0x66); } -void movups(const Address& addr, const Xmm& xmm) { opModM(addr, xmm, 0x0F, 0x11); } -void movups(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x10, 0x100); } -void movzx(const Reg& reg, const Operand& op) { opMovxx(reg, op, 0xB6); } -void mpsadbw(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x42, 0x66, isXMM_XMMorMEM, static_cast(imm), 0x3A); } -void mul(const Operand& op) { opR_ModM(op, 0, 4, 0xF6); } -void mulpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x59, 0x66, isXMM_XMMorMEM); } -void mulps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x59, 0x100, isXMM_XMMorMEM); } -void mulsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x59, 0xF2, isXMM_XMMorMEM); } -void mulss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x59, 0xF3, isXMM_XMMorMEM); } -void mulx(const Reg32e& r1, const Reg32e& r2, const Operand& op) { opGpr(r1, r2, op, T_F2 | T_0F38, 0xf6, true); } -void mwait() { db(0x0F); db(0x01); db(0xC9); } -void neg(const Operand& op) { opR_ModM(op, 0, 3, 0xF6); } -void not_(const Operand& op) { opR_ModM(op, 0, 2, 0xF6); } -void or_(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x08, 1); } -void or_(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x08); } -void orpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x56, 0x66, isXMM_XMMorMEM); } -void orps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x56, 0x100, isXMM_XMMorMEM); } -void pabsb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x1C, 0x66, NONE, 0x38); } -void pabsd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x1E, 0x66, NONE, 0x38); } -void pabsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x1D, 0x66, NONE, 0x38); } -void packssdw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x6B); } -void packsswb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x63); } -void packusdw(const Xmm& xmm, 
const Operand& op) { opGen(xmm, op, 0x2B, 0x66, isXMM_XMMorMEM, NONE, 0x38); } -void packuswb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x67); } -void paddb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xFC); } -void paddd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xFE); } -void paddq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD4); } -void paddsb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xEC); } -void paddsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xED); } -void paddusb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xDC); } -void paddusw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xDD); } -void paddw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xFD); } -void palignr(const Mmx& mmx, const Operand& op, int imm) { opMMX(mmx, op, 0x0f, 0x66, static_cast(imm), 0x3a); } -void pand(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xDB); } -void pandn(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xDF); } -void pause() { db(0xF3); db(0x90); } -void pavgb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE0); } -void pavgw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE3); } -void pblendvb(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x10, 0x66, isXMM_XMMorMEM, NONE, 0x38); } -void pblendw(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x0E, 0x66, isXMM_XMMorMEM, static_cast(imm), 0x3A); } -void pclmulhqhdq(const Xmm& xmm, const Operand& op) { pclmulqdq(xmm, op, 0x11); } -void pclmulhqlqdq(const Xmm& xmm, const Operand& op) { pclmulqdq(xmm, op, 0x01); } -void pclmullqhdq(const Xmm& xmm, const Operand& op) { pclmulqdq(xmm, op, 0x10); } -void pclmullqlqdq(const Xmm& xmm, const Operand& op) { pclmulqdq(xmm, op, 0x00); } -void pclmulqdq(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x44, 0x66, isXMM_XMMorMEM, static_cast(imm), 0x3A); } -void pcmpeqb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x74); } -void pcmpeqd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x76); } -void pcmpeqq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x29, 0x66, isXMM_XMMorMEM, NONE, 0x38); } -void pcmpeqw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x75); } -void pcmpestri(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x61, 0x66, isXMM_XMMorMEM, imm, 0x3A); } -void pcmpestrm(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x60, 0x66, isXMM_XMMorMEM, imm, 0x3A); } -void pcmpgtb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x64); } -void pcmpgtd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x66); } -void pcmpgtq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x37, 0x66, isXMM_XMMorMEM, NONE, 0x38); } -void pcmpgtw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x65); } -void pcmpistri(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x63, 0x66, isXMM_XMMorMEM, imm, 0x3A); } -void pcmpistrm(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x62, 0x66, isXMM_XMMorMEM, imm, 0x3A); } -void pdep(const Reg32e& r1, const Reg32e& r2, const Operand& op) { opGpr(r1, r2, op, T_F2 | T_0F38, 0xf5, true); } -void pext(const Reg32e& r1, const Reg32e& r2, const Operand& op) { opGpr(r1, r2, op, T_F3 | T_0F38, 0xf5, true); } -void pextrb(const Operand& op, const Xmm& xmm, uint8 imm) { opExt(op, xmm, 0x14, imm); } -void pextrd(const Operand& op, const Xmm& xmm, uint8 imm) { opExt(op, xmm, 0x16, imm); } -void pextrw(const Operand& op, const Mmx& xmm, uint8 imm) { opExt(op, xmm, 0x15, imm, true); } 
-void phaddd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x02, 0x66, NONE, 0x38); } -void phaddsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x03, 0x66, NONE, 0x38); } -void phaddw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x01, 0x66, NONE, 0x38); } -void phminposuw(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x41, 0x66, isXMM_XMMorMEM, NONE, 0x38); } -void phsubd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x06, 0x66, NONE, 0x38); } -void phsubsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x07, 0x66, NONE, 0x38); } -void phsubw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x05, 0x66, NONE, 0x38); } -void pinsrb(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x20, 0x66, isXMM_REG32orMEM, imm, 0x3A); } -void pinsrd(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x22, 0x66, isXMM_REG32orMEM, imm, 0x3A); } -void pinsrw(const Mmx& mmx, const Operand& op, int imm) { if (!op.isREG(32) && !op.isMEM()) throw Error(ERR_BAD_COMBINATION); opGen(mmx, op, 0xC4, mmx.isXMM() ? 0x66 : NONE, 0, imm); } -void pmaddubsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x04, 0x66, NONE, 0x38); } -void pmaddwd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF5); } -void pmaxsb(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x3C, 0x66, isXMM_XMMorMEM, NONE, 0x38); } -void pmaxsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x3D, 0x66, isXMM_XMMorMEM, NONE, 0x38); } -void pmaxsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xEE); } -void pmaxub(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xDE); } -void pmaxud(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x3F, 0x66, isXMM_XMMorMEM, NONE, 0x38); } -void pmaxuw(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x3E, 0x66, isXMM_XMMorMEM, NONE, 0x38); } -void pminsb(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x38, 0x66, isXMM_XMMorMEM, NONE, 0x38); } -void pminsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x39, 0x66, isXMM_XMMorMEM, NONE, 0x38); } -void pminsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xEA); } -void pminub(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xDA); } -void pminud(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x3B, 0x66, isXMM_XMMorMEM, NONE, 0x38); } -void pminuw(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x3A, 0x66, isXMM_XMMorMEM, NONE, 0x38); } -void pmovmskb(const Reg32e& reg, const Mmx& mmx) { if (mmx.isXMM()) db(0x66); opModR(reg, mmx, 0x0F, 0xD7); } -void pmovsxbd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x21, 0x66, isXMM_XMMorMEM, NONE, 0x38); } -void pmovsxbq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x22, 0x66, isXMM_XMMorMEM, NONE, 0x38); } -void pmovsxbw(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x20, 0x66, isXMM_XMMorMEM, NONE, 0x38); } -void pmovsxdq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x25, 0x66, isXMM_XMMorMEM, NONE, 0x38); } -void pmovsxwd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x23, 0x66, isXMM_XMMorMEM, NONE, 0x38); } -void pmovsxwq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x24, 0x66, isXMM_XMMorMEM, NONE, 0x38); } -void pmovzxbd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x31, 0x66, isXMM_XMMorMEM, NONE, 0x38); } -void pmovzxbq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x32, 0x66, isXMM_XMMorMEM, NONE, 0x38); } -void pmovzxbw(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x30, 0x66, isXMM_XMMorMEM, NONE, 0x38); } -void pmovzxdq(const Xmm& xmm, 
const Operand& op) { opGen(xmm, op, 0x35, 0x66, isXMM_XMMorMEM, NONE, 0x38); } -void pmovzxwd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x33, 0x66, isXMM_XMMorMEM, NONE, 0x38); } -void pmovzxwq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x34, 0x66, isXMM_XMMorMEM, NONE, 0x38); } -void pmuldq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x28, 0x66, isXMM_XMMorMEM, NONE, 0x38); } -void pmulhrsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x0B, 0x66, NONE, 0x38); } -void pmulhuw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE4); } -void pmulhw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE5); } -void pmulld(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x40, 0x66, isXMM_XMMorMEM, NONE, 0x38); } -void pmullw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD5); } -void pmuludq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF4); } -void popcnt(const Reg&reg, const Operand& op) { opSp1(reg, op, 0xF3, 0x0F, 0xB8); } -void popf() { db(0x9D); } -void por(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xEB); } -void prefetchnta(const Address& addr) { opModM(addr, Reg32(0), 0x0F, 0x18); } -void prefetcht0(const Address& addr) { opModM(addr, Reg32(1), 0x0F, 0x18); } -void prefetcht1(const Address& addr) { opModM(addr, Reg32(2), 0x0F, 0x18); } -void prefetcht2(const Address& addr) { opModM(addr, Reg32(3), 0x0F, 0x18); } -void prefetchw(const Address& addr) { opModM(addr, Reg32(1), 0x0F, 0x0D); } -void prefetchwt1(const Address& addr) { opModM(addr, Reg32(2), 0x0F, 0x0D); } -void psadbw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF6); } -void pshufb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x00, 0x66, NONE, 0x38); } -void pshufd(const Mmx& mmx, const Operand& op, uint8 imm8) { opMMX(mmx, op, 0x70, 0x66, imm8); } -void pshufhw(const Mmx& mmx, const Operand& op, uint8 imm8) { opMMX(mmx, op, 0x70, 0xF3, imm8); } -void pshuflw(const Mmx& mmx, const Operand& op, uint8 imm8) { opMMX(mmx, op, 0x70, 0xF2, imm8); } -void pshufw(const Mmx& mmx, const Operand& op, uint8 imm8) { opMMX(mmx, op, 0x70, 0x00, imm8); } -void psignb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x08, 0x66, NONE, 0x38); } -void psignd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x0A, 0x66, NONE, 0x38); } -void psignw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x09, 0x66, NONE, 0x38); } -void pslld(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF2); } -void pslld(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x72, 6); } -void pslldq(const Xmm& xmm, int imm8) { opMMX_IMM(xmm, imm8, 0x73, 7); } -void psllq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF3); } -void psllq(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x73, 6); } -void psllw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF1); } -void psllw(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x71, 6); } -void psrad(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE2); } -void psrad(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x72, 4); } -void psraw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE1); } -void psraw(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x71, 4); } -void psrld(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD2); } -void psrld(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x72, 2); } -void psrldq(const Xmm& xmm, int imm8) { opMMX_IMM(xmm, imm8, 0x73, 3); } -void psrlq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD3); } -void psrlq(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8,
0x73, 2); } -void psrlw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD1); } -void psrlw(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x71, 2); } -void psubb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF8); } -void psubd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xFA); } -void psubq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xFB); } -void psubsb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE8); } -void psubsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE9); } -void psubusb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD8); } -void psubusw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD9); } -void psubw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF9); } -void ptest(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x17, 0x66, isXMM_XMMorMEM, NONE, 0x38); } -void punpckhbw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x68); } -void punpckhdq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x6A); } -void punpckhqdq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x6D, 0x66, isXMM_XMMorMEM); } -void punpckhwd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x69); } -void punpcklbw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x60); } -void punpckldq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x62); } -void punpcklqdq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x6C, 0x66, isXMM_XMMorMEM); } -void punpcklwd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x61); } -void pushf() { db(0x9C); } -void pxor(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xEF); } -void rcl(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 2); } -void rcl(const Operand& op, int imm) { opShift(op, imm, 2); } -void rcpps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x53, 0x100, isXMM_XMMorMEM); } -void rcpss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x53, 0xF3, isXMM_XMMorMEM); } -void rcr(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 3); } -void rcr(const Operand& op, int imm) { opShift(op, imm, 3); } -void rdmsr() { db(0x0F); db(0x32); } -void rdpmc() { db(0x0F); db(0x33); } -void rdrand(const Reg& r) { if (r.isBit(8)) throw Error(ERR_BAD_SIZE_OF_REGISTER); opModR(Reg(6, Operand::REG, r.getBit()), r, 0x0F, 0xC7); } -void rdseed(const Reg& r) { if (r.isBit(8)) throw Error(ERR_BAD_SIZE_OF_REGISTER); opModR(Reg(7, Operand::REG, r.getBit()), r, 0x0F, 0xC7); } -void rdtsc() { db(0x0F); db(0x31); } -void rdtscp() { db(0x0F); db(0x01); db(0xF9); } -void rep() { db(0xF3); } -void ret(int imm = 0) { if (imm) { db(0xC2); dw(imm); } else { db(0xC3); } } -void rol(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 0); } -void rol(const Operand& op, int imm) { opShift(op, imm, 0); } -void ror(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 1); } -void ror(const Operand& op, int imm) { opShift(op, imm, 1); } -void rorx(const Reg32e& r, const Operand& op, uint8 imm) { opGpr(r, op, Reg32e(0, r.getBit()), T_0F3A | T_F2, 0xF0, false, imm); } -void roundpd(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x09, 0x66, isXMM_XMMorMEM, imm, 0x3A); } -void roundps(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x08, 0x66, isXMM_XMMorMEM, imm, 0x3A); } -void roundsd(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x0B, 0x66, isXMM_XMMorMEM, static_cast<uint8>(imm), 0x3A); } -void roundss(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x0A, 0x66, isXMM_XMMorMEM, static_cast<uint8>(imm), 0x3A); } -void rsqrtps(const Xmm&
xmm, const Operand& op) { opGen(xmm, op, 0x52, 0x100, isXMM_XMMorMEM); } -void rsqrtss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x52, 0xF3, isXMM_XMMorMEM); } -void sahf() { db(0x9E); } -void sal(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 4); } -void sal(const Operand& op, int imm) { opShift(op, imm, 4); } -void sar(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 7); } -void sar(const Operand& op, int imm) { opShift(op, imm, 7); } -void sarx(const Reg32e& r1, const Operand& op, const Reg32e& r2) { opGpr(r1, op, r2, T_F3 | T_0F38, 0xf7, false); } -void sbb(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x18, 3); } -void sbb(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x18); } -void scasb() { db(0xAE); } -void scasd() { db(0xAF); } -void scasw() { db(0x66); db(0xAF); } -void seta(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 7); }//-V524 -void setae(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 3); }//-V524 -void setb(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 2); }//-V524 -void setbe(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 6); }//-V524 -void setc(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 2); }//-V524 -void sete(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 4); }//-V524 -void setg(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 15); }//-V524 -void setge(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 13); }//-V524 -void setl(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 12); }//-V524 -void setle(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 14); }//-V524 -void setna(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 6); }//-V524 -void setnae(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 2); }//-V524 -void setnb(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 3); }//-V524 -void setnbe(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 7); }//-V524 -void setnc(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 3); }//-V524 -void setne(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 5); }//-V524 -void setng(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 14); }//-V524 -void setnge(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 12); }//-V524 -void setnl(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 13); }//-V524 -void setnle(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 15); }//-V524 -void setno(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 1); }//-V524 -void setnp(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 11); }//-V524 -void setns(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 9); }//-V524 -void setnz(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 5); }//-V524 -void seto(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 0); }//-V524 -void setp(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 10); }//-V524 -void setpe(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 10); }//-V524 -void setpo(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 11); }//-V524 -void sets(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 8); }//-V524 -void setz(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 4); }//-V524 -void sfence() { db(0x0F); db(0xAE); db(0xF8); } -void sha1msg1(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xC9, NONE, isXMM_XMMorMEM, NONE, 0x38); } -void sha1msg2(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xCA, NONE, isXMM_XMMorMEM, NONE, 0x38); } -void sha1nexte(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xC8, NONE, isXMM_XMMorMEM, NONE, 0x38); } 
-void sha1rnds4(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0xCC, NONE, isXMM_XMMorMEM, imm, 0x3A); } -void sha256msg1(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xCC, NONE, isXMM_XMMorMEM, NONE, 0x38); } -void sha256msg2(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xCD, NONE, isXMM_XMMorMEM, NONE, 0x38); } -void sha256rnds2(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xCB, NONE, isXMM_XMMorMEM, NONE, 0x38); } -void shl(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 4); } -void shl(const Operand& op, int imm) { opShift(op, imm, 4); } -void shld(const Operand& op, const Reg& reg, const Reg8& _cl) { opShxd(op, reg, 0, 0xA4, &_cl); } -void shld(const Operand& op, const Reg& reg, uint8 imm) { opShxd(op, reg, imm, 0xA4); } -void shlx(const Reg32e& r1, const Operand& op, const Reg32e& r2) { opGpr(r1, op, r2, T_66 | T_0F38, 0xf7, false); } -void shr(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 5); } -void shr(const Operand& op, int imm) { opShift(op, imm, 5); } -void shrd(const Operand& op, const Reg& reg, const Reg8& _cl) { opShxd(op, reg, 0, 0xAC, &_cl); } -void shrd(const Operand& op, const Reg& reg, uint8 imm) { opShxd(op, reg, imm, 0xAC); } -void shrx(const Reg32e& r1, const Operand& op, const Reg32e& r2) { opGpr(r1, op, r2, T_F2 | T_0F38, 0xf7, false); } -void shufpd(const Xmm& xmm, const Operand& op, uint8 imm8) { opGen(xmm, op, 0xC6, 0x66, isXMM_XMMorMEM, imm8); } -void shufps(const Xmm& xmm, const Operand& op, uint8 imm8) { opGen(xmm, op, 0xC6, 0x100, isXMM_XMMorMEM, imm8); } -void sqrtpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x51, 0x66, isXMM_XMMorMEM); } -void sqrtps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x51, 0x100, isXMM_XMMorMEM); } -void sqrtsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x51, 0xF2, isXMM_XMMorMEM); } -void sqrtss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x51, 0xF3, isXMM_XMMorMEM); } -void stac() { db(0x0F); db(0x01); db(0xCB); } -void stc() { db(0xF9); } -void std() { db(0xFD); } -void sti() { db(0xFB); } -void stmxcsr(const Address& addr) { opModM(addr, Reg32(3), 0x0F, 0xAE); } -void stosb() { db(0xAA); } -void stosd() { db(0xAB); } -void stosw() { db(0x66); db(0xAB); } -void sub(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x28, 5); } -void sub(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x28); } -void subpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5C, 0x66, isXMM_XMMorMEM); } -void subps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5C, 0x100, isXMM_XMMorMEM); } -void subsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5C, 0xF2, isXMM_XMMorMEM); } -void subss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5C, 0xF3, isXMM_XMMorMEM); } -void tzcnt(const Reg&reg, const Operand& op) { opSp1(reg, op, 0xF3, 0x0F, 0xBC); } -void ucomisd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x2E, 0x66, isXMM_XMMorMEM); } -void ucomiss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x2E, 0x100, isXMM_XMMorMEM); } -void ud2() { db(0x0F); db(0x0B); } -void unpckhpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x15, 0x66, isXMM_XMMorMEM); } -void unpckhps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x15, 0x100, isXMM_XMMorMEM); } -void unpcklpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x14, 0x66, isXMM_XMMorMEM); } -void unpcklps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x14, 0x100, isXMM_XMMorMEM); } -void vaddpd(const Xmm& xmm, const Operand& op1, const Operand& op2 =
Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x58); } -void vaddps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x58); } -void vaddsd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F2 | T_EW1 | T_EVEX | T_ER_Z | T_N8, 0x58); } -void vaddss(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F3 | T_EW0 | T_EVEX | T_ER_Z | T_N4, 0x58); } -void vaddsubpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66 | T_0F | T_YMM, 0xD0); } -void vaddsubps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_F2 | T_0F | T_YMM, 0xD0); } -void vaesdec(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66 | T_0F38 | T_YMM | T_EVEX, 0xDE); } -void vaesdeclast(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66 | T_0F38 | T_YMM | T_EVEX, 0xDF); } -void vaesenc(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66 | T_0F38 | T_YMM | T_EVEX, 0xDC); } -void vaesenclast(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66 | T_0F38 | T_YMM | T_EVEX, 0xDD); } -void vaesimc(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38 | T_W0, 0xDB); } -void vaeskeygenassist(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A, 0xDF, imm); } -void vandnpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x55); } -void vandnps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x55); } -void vandpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x54); } -void vandps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x54); } -void vblendpd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0 | T_YMM, 0x0D, imm); } -void vblendps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0 | T_YMM, 0x0C, imm); } -void vblendvpd(const Xmm& x1, const Xmm& x2, const Operand& op, const Xmm& x4) { opAVX_X_X_XM(x1, x2, op, T_0F3A | T_66 | T_YMM, 0x4B, x4.getIdx() << 4); } -void vblendvps(const Xmm& x1, const Xmm& x2, const Operand& op, const Xmm& x4) { opAVX_X_X_XM(x1, x2, op, T_0F3A | T_66 | T_YMM, 0x4A, x4.getIdx() << 4); } -void vbroadcastf128(const Ymm& y, const Address& addr) { opAVX_X_XM_IMM(y, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x1A); } -void vbroadcasti128(const Ymm& y, const Address& addr) { opAVX_X_XM_IMM(y, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x5A); } -void vbroadcastsd(const Ymm& y, const Operand& op) { if (!op.isMEM() && !(y.isYMM() && op.isXMM()) && !(y.isZMM() && op.isXMM())) throw Error(ERR_BAD_COMBINATION); opAVX_X_XM_IMM(y, op, T_0F38 | T_66 | T_W0 | T_YMM | T_EVEX | T_EW1 | T_N8, 0x19); } -void vbroadcastss(const Xmm& 
x, const Operand& op) { if (!(op.isXMM() || op.isMEM())) throw Error(ERR_BAD_COMBINATION); opAVX_X_XM_IMM(x, op, T_N4 | T_66 | T_0F38 | T_W0 | T_YMM | T_EVEX, 0x18); } -void vcmpeq_ospd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 16); } -void vcmpeq_osps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 16); } -void vcmpeq_ossd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 16); } -void vcmpeq_osss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 16); } -void vcmpeq_uqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 8); } -void vcmpeq_uqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 8); } -void vcmpeq_uqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 8); } -void vcmpeq_uqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 8); } -void vcmpeq_uspd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 24); } -void vcmpeq_usps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 24); } -void vcmpeq_ussd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 24); } -void vcmpeq_usss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 24); } -void vcmpeqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 0); } -void vcmpeqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 0); } -void vcmpeqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 0); } -void vcmpeqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 0); } -void vcmpfalse_ospd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 27); } -void vcmpfalse_osps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 27); } -void vcmpfalse_ossd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 27); } -void vcmpfalse_osss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 27); } -void vcmpfalsepd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 11); } -void vcmpfalseps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 11); } -void vcmpfalsesd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 11); } -void vcmpfalsess(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 11); } -void vcmpge_oqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 29); } -void vcmpge_oqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 29); } -void vcmpge_oqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 29); } -void vcmpge_oqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 29); } -void vcmpgepd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 13); } -void vcmpgeps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 13); } -void vcmpgesd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 13); } -void vcmpgess(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 13); } -void vcmpgt_oqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 30); } -void vcmpgt_oqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 30); } -void vcmpgt_oqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 30); } -void vcmpgt_oqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, 
x2, op, 30); } -void vcmpgtpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 14); } -void vcmpgtps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 14); } -void vcmpgtsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 14); } -void vcmpgtss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 14); } -void vcmple_oqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 18); } -void vcmple_oqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 18); } -void vcmple_oqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 18); } -void vcmple_oqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 18); } -void vcmplepd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 2); } -void vcmpleps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 2); } -void vcmplesd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 2); } -void vcmpless(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 2); } -void vcmplt_oqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 17); } -void vcmplt_oqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 17); } -void vcmplt_oqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 17); } -void vcmplt_oqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 17); } -void vcmpltpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 1); } -void vcmpltps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 1); } -void vcmpltsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 1); } -void vcmpltss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 1); } -void vcmpneq_oqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 12); } -void vcmpneq_oqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 12); } -void vcmpneq_oqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 12); } -void vcmpneq_oqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 12); } -void vcmpneq_ospd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 28); } -void vcmpneq_osps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 28); } -void vcmpneq_ossd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 28); } -void vcmpneq_osss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 28); } -void vcmpneq_uspd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 20); } -void vcmpneq_usps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 20); } -void vcmpneq_ussd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 20); } -void vcmpneq_usss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 20); } -void vcmpneqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 4); } -void vcmpneqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 4); } -void vcmpneqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 4); } -void vcmpneqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 4); } -void vcmpnge_uqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 25); } -void vcmpnge_uqps(const Xmm& x1, const Xmm& x2, const Operand& op) 
{ vcmpps(x1, x2, op, 25); } -void vcmpnge_uqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 25); } -void vcmpnge_uqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 25); } -void vcmpngepd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 9); } -void vcmpngeps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 9); } -void vcmpngesd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 9); } -void vcmpngess(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 9); } -void vcmpngt_uqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 26); } -void vcmpngt_uqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 26); } -void vcmpngt_uqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 26); } -void vcmpngt_uqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 26); } -void vcmpngtpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 10); } -void vcmpngtps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 10); } -void vcmpngtsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 10); } -void vcmpngtss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 10); } -void vcmpnle_uqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 22); } -void vcmpnle_uqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 22); } -void vcmpnle_uqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 22); } -void vcmpnle_uqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 22); } -void vcmpnlepd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 6); } -void vcmpnleps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 6); } -void vcmpnlesd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 6); } -void vcmpnless(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 6); } -void vcmpnlt_uqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 21); } -void vcmpnlt_uqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 21); } -void vcmpnlt_uqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 21); } -void vcmpnlt_uqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 21); } -void vcmpnltpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 5); } -void vcmpnltps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 5); } -void vcmpnltsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 5); } -void vcmpnltss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 5); } -void vcmpord_spd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 23); } -void vcmpord_sps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 23); } -void vcmpord_ssd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 23); } -void vcmpord_sss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 23); } -void vcmpordpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 7); } -void vcmpordps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 7); } -void vcmpordsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 7); } -void vcmpordss(const Xmm& x1, const Xmm& x2, 
const Operand& op) { vcmpss(x1, x2, op, 7); } -void vcmppd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0xC2, imm); } -void vcmpps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_0F | T_YMM, 0xC2, imm); } -void vcmpsd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_F2 | T_0F, 0xC2, imm); } -void vcmpss(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_F3 | T_0F, 0xC2, imm); } -void vcmptrue_uspd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 31); } -void vcmptrue_usps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 31); } -void vcmptrue_ussd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 31); } -void vcmptrue_usss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 31); } -void vcmptruepd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 15); } -void vcmptrueps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 15); } -void vcmptruesd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 15); } -void vcmptruess(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 15); } -void vcmpunord_spd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 19); } -void vcmpunord_sps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 19); } -void vcmpunord_ssd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 19); } -void vcmpunord_sss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 19); } -void vcmpunordpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 3); } -void vcmpunordps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 3); } -void vcmpunordsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 3); } -void vcmpunordss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 3); } -void vcomisd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_66 | T_0F | T_EW1 | T_EVEX | T_SAE_X, 0x2F); } -void vcomiss(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N4 | T_0F | T_EW0 | T_EVEX | T_SAE_X, 0x2F); } -void vcvtdq2pd(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_0F | T_F3 | T_YMM | T_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL, 0xE6); } -void vcvtdq2ps(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x5B); } -void vcvtpd2dq(const Xmm& x, const Operand& op) { opCvt2(x, op, T_0F | T_F2 | T_YMM | T_EVEX | T_EW1 | T_B64 | T_ER_Z, 0xE6); } -void vcvtpd2ps(const Xmm& x, const Operand& op) { opCvt2(x, op, T_0F | T_66 | T_YMM | T_EVEX | T_EW1 | T_B64 | T_ER_Z, 0x5A); } -void vcvtph2ps(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_0F38 | T_66 | T_W0 | T_EVEX | T_EW0 | T_N8 | T_N_VL | T_SAE_Y, 0x13); } -void vcvtps2dq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x5B); } -void vcvtps2pd(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_0F | T_YMM | T_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL | T_SAE_Y, 0x5A); } -void vcvtps2ph(const Operand& op, const Xmm& x, uint8 imm) { checkCvt1(x, op); opVex(x, 0, op, T_0F3A | T_66 | T_W0 | T_EVEX | T_EW0 | T_N8 | T_N_VL | T_SAE_Y, 0x1D, imm); } -void vcvtsd2si(const Reg32& r, 
const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F2 | T_W0 | T_EVEX | T_EW0 | T_N4 | T_ER_X, 0x2D); } -void vcvtsd2ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_F2 | T_0F | T_EW1 | T_EVEX | T_ER_X, 0x5A); } -void vcvtsi2sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opCvt3(x1, x2, op, T_0F | T_F2 | T_EVEX, T_W1 | T_EW1 | T_ER_X | T_N8, T_W0 | T_EW0 | T_N4, 0x2A); } -void vcvtsi2ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opCvt3(x1, x2, op, T_0F | T_F3 | T_EVEX | T_ER_X, T_W1 | T_EW1 | T_N8, T_W0 | T_EW0 | T_N4, 0x2A); } -void vcvtss2sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_F3 | T_0F | T_EW0 | T_EVEX | T_SAE_X, 0x5A); } -void vcvtss2si(const Reg32& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F3 | T_W0 | T_EVEX | T_EW0 | T_ER_X | T_N8, 0x2D); } -void vcvttpd2dq(const Xmm& x, const Operand& op) { opCvt2(x, op, T_66 | T_0F | T_YMM | T_EVEX |T_EW1 | T_B64 | T_ER_Z, 0xE6); } -void vcvttps2dq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_F3 | T_0F | T_EW0 | T_YMM | T_EVEX | T_SAE_Z | T_B32, 0x5B); } -void vcvttsd2si(const Reg32& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F2 | T_W0 | T_EVEX | T_EW0 | T_N4 | T_SAE_X, 0x2C); } -void vcvttss2si(const Reg32& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F3 | T_W0 | T_EVEX | T_EW0 | T_SAE_X | T_N8, 0x2C); } -void vdivpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x5E); } -void vdivps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x5E); } -void vdivsd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F2 | T_EW1 | T_EVEX | T_ER_Z | T_N8, 0x5E); } -void vdivss(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F3 | T_EW0 | T_EVEX | T_ER_Z | T_N4, 0x5E); } -void vdppd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0, 0x41, imm); } -void vdpps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0 | T_YMM, 0x40, imm); } -void vextractf128(const Operand& op, const Ymm& y, uint8 imm) { if (!(op.isXMEM() && y.isYMM())) throw Error(ERR_BAD_COMBINATION); opVex(y, 0, op, T_0F3A | T_66 | T_W0 | T_YMM, 0x19, imm); } -void vextracti128(const Operand& op, const Ymm& y, uint8 imm) { if (!(op.isXMEM() && y.isYMM())) throw Error(ERR_BAD_COMBINATION); opVex(y, 0, op, T_0F3A | T_66 | T_W0 | T_YMM, 0x39, imm); } -void vextractps(const Operand& op, const Xmm& x, uint8 imm) { if (!((op.isREG(32) || op.isMEM()) && x.isXMM())) throw Error(ERR_BAD_COMBINATION); opVex(x, 0, op, T_0F3A | T_66 | T_W0 | T_EVEX | T_N4, 0x17, imm); } -void vfmadd132pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x98); } -void vfmadd132ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x98); } -void vfmadd132sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0x99); } -void vfmadd132ss(const Xmm& x1, const 
Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0x99); } -void vfmadd213pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xA8); } -void vfmadd213ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xA8); } -void vfmadd213sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xA9); } -void vfmadd213ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xA9); } -void vfmadd231pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xB8); } -void vfmadd231ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xB8); } -void vfmadd231sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xB9); } -void vfmadd231ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xB9); } -void vfmaddsub132pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x96); } -void vfmaddsub132ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x96); } -void vfmaddsub213pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xA6); } -void vfmaddsub213ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xA6); } -void vfmaddsub231pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xB6); } -void vfmaddsub231ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xB6); } -void vfmsub132pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x9A); } -void vfmsub132ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x9A); } -void vfmsub132sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0x9B); } -void vfmsub132ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0x9B); } -void vfmsub213pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xAA); } -void vfmsub213ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xAA); } -void vfmsub213sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xAB); } -void vfmsub213ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 
| T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xAB); } -void vfmsub231pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xBA); } -void vfmsub231ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xBA); } -void vfmsub231sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xBB); } -void vfmsub231ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xBB); } -void vfmsubadd132pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x97); } -void vfmsubadd132ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x97); } -void vfmsubadd213pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xA7); } -void vfmsubadd213ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xA7); } -void vfmsubadd231pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xB7); } -void vfmsubadd231ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xB7); } -void vfnmadd132pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x9C); } -void vfnmadd132ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x9C); } -void vfnmadd132sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0x9D); } -void vfnmadd132ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0x9D); } -void vfnmadd213pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xAC); } -void vfnmadd213ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xAC); } -void vfnmadd213sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xAD); } -void vfnmadd213ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xAD); } -void vfnmadd231pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xBC); } -void vfnmadd231ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xBC); } -void vfnmadd231sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xBD); } -void vfnmadd231ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | 
T_ER_X, 0xBD); } -void vfnmsub132pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x9E); } -void vfnmsub132ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x9E); } -void vfnmsub132sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0x9F); } -void vfnmsub132ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0x9F); } -void vfnmsub213pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xAE); } -void vfnmsub213ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xAE); } -void vfnmsub213sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xAF); } -void vfnmsub213ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xAF); } -void vfnmsub231pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xBE); } -void vfnmsub231ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xBE); } -void vfnmsub231sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xBF); } -void vfnmsub231ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xBF); } -void vgatherdpd(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W1, 0x92, 0); } -void vgatherdps(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W0, 0x92, 1); } -void vgatherqpd(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W1, 0x93, 1); } -void vgatherqps(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W0, 0x93, 2); } -void vgf2p8affineinvqb(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W1 | T_EW1 | T_YMM | T_EVEX | T_SAE_Z | T_B64, 0xCF, imm); } -void vgf2p8affineqb(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W1 | T_EW1 | T_YMM | T_EVEX | T_SAE_Z | T_B64, 0xCE, imm); } -void vgf2p8mulb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_SAE_Z, 0xCF); } -void vhaddpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66 | T_0F | T_YMM, 0x7C); } -void vhaddps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_F2 | T_0F | T_YMM, 0x7C); } -void vhsubpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66 | T_0F | T_YMM, 0x7D); } -void vhsubps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { 
opAVX_X_X_XM(xmm, op1, op2, T_F2 | T_0F | T_YMM, 0x7D); } -void vinsertf128(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { if (!(y1.isYMM() && y2.isYMM() && op.isXMEM())) throw Error(ERR_BAD_COMBINATION); opVex(y1, &y2, op, T_0F3A | T_66 | T_W0 | T_YMM, 0x18, imm); } -void vinserti128(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { if (!(y1.isYMM() && y2.isYMM() && op.isXMEM())) throw Error(ERR_BAD_COMBINATION); opVex(y1, &y2, op, T_0F3A | T_66 | T_W0 | T_YMM, 0x38, imm); } -void vinsertps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F3A | T_W0 | T_EW0 | T_EVEX, 0x21, imm); } -void vlddqu(const Xmm& x, const Address& addr) { opAVX_X_X_XM(x, cvtIdx0(x), addr, T_0F | T_F2 | T_W0 | T_YMM, 0xF0); } -void vldmxcsr(const Address& addr) { opAVX_X_X_XM(xm2, xm0, addr, T_0F, 0xAE); } -void vmaskmovdqu(const Xmm& x1, const Xmm& x2) { opAVX_X_X_XM(x1, xm0, x2, T_0F | T_66, 0xF7); } -void vmaskmovpd(const Address& addr, const Xmm& x1, const Xmm& x2) { opAVX_X_X_XM(x2, x1, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x2F); } -void vmaskmovpd(const Xmm& x1, const Xmm& x2, const Address& addr) { opAVX_X_X_XM(x1, x2, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x2D); } -void vmaskmovps(const Address& addr, const Xmm& x1, const Xmm& x2) { opAVX_X_X_XM(x2, x1, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x2E); } -void vmaskmovps(const Xmm& x1, const Xmm& x2, const Address& addr) { opAVX_X_X_XM(x1, x2, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x2C); } -void vmaxpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x5F); } -void vmaxps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x5F); } -void vmaxsd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F2 | T_EW1 | T_EVEX | T_ER_Z | T_N8, 0x5F); } -void vmaxss(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F3 | T_EW0 | T_EVEX | T_ER_Z | T_N4, 0x5F); } -void vminpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x5D); } -void vminps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x5D); } -void vminsd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F2 | T_EW1 | T_EVEX | T_ER_Z | T_N8, 0x5D); } -void vminss(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F3 | T_EW0 | T_EVEX | T_ER_Z | T_N4, 0x5D); } -void vmovapd(const Address& addr, const Xmm& xmm) { opAVX_X_XM_IMM(xmm, addr, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_M_K, 0x29); } -void vmovapd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX, 0x28); } -void vmovaps(const Address& addr, const Xmm& xmm) { opAVX_X_XM_IMM(xmm, addr, T_0F | T_EW0 | T_YMM | T_EVEX | T_M_K, 0x29); } -void vmovaps(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_0F | T_EW0 | T_YMM | T_EVEX, 0x28); } -void vmovd(const Operand& op, const Xmm& x) { if (!op.isREG(32) && !op.isMEM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x, xm0, op, T_0F | T_66 | T_W0 | T_EVEX | T_N4, 0x7E); } -void 
vmovd(const Xmm& x, const Operand& op) { if (!op.isREG(32) && !op.isMEM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x, xm0, op, T_0F | T_66 | T_W0 | T_EVEX | T_N4, 0x6E); } -void vmovddup(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_DUP | T_F2 | T_0F | T_EW1 | T_YMM | T_EVEX | T_ER_X | T_ER_Y | T_ER_Z, 0x12); } -void vmovdqa(const Address& addr, const Xmm& xmm) { opAVX_X_XM_IMM(xmm, addr, T_66 | T_0F | T_YMM, 0x7F); } -void vmovdqa(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F | T_YMM, 0x6F); } -void vmovdqu(const Address& addr, const Xmm& xmm) { opAVX_X_XM_IMM(xmm, addr, T_F3 | T_0F | T_YMM, 0x7F); } -void vmovdqu(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_F3 | T_0F | T_YMM, 0x6F); } -void vmovhlps(const Xmm& x1, const Xmm& x2, const Operand& op = Operand()) { if (!op.isNone() && !op.isXMM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x1, x2, op, T_0F | T_EVEX | T_EW0, 0x12); } -void vmovhpd(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_0F | T_66 | T_EVEX | T_EW1 | T_N8, 0x17); } -void vmovhpd(const Xmm& x, const Operand& op1, const Operand& op2 = Operand()) { if (!op2.isNone() && !op2.isMEM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x, op1, op2, T_0F | T_66 | T_EVEX | T_EW1 | T_N8, 0x16); } -void vmovhps(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_0F | T_EVEX | T_EW0 | T_N8, 0x17); } -void vmovhps(const Xmm& x, const Operand& op1, const Operand& op2 = Operand()) { if (!op2.isNone() && !op2.isMEM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x, op1, op2, T_0F | T_EVEX | T_EW0 | T_N8, 0x16); } -void vmovlhps(const Xmm& x1, const Xmm& x2, const Operand& op = Operand()) { if (!op.isNone() && !op.isXMM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x1, x2, op, T_0F | T_EVEX | T_EW0, 0x16); } -void vmovlpd(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_0F | T_66 | T_EVEX | T_EW1 | T_N8, 0x13); } -void vmovlpd(const Xmm& x, const Operand& op1, const Operand& op2 = Operand()) { if (!op2.isNone() && !op2.isMEM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x, op1, op2, T_0F | T_66 | T_EVEX | T_EW1 | T_N8, 0x12); } -void vmovlps(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_0F | T_EVEX | T_EW0 | T_N8, 0x13); } -void vmovlps(const Xmm& x, const Operand& op1, const Operand& op2 = Operand()) { if (!op2.isNone() && !op2.isMEM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x, op1, op2, T_0F | T_EVEX | T_EW0 | T_N8, 0x12); } -void vmovmskpd(const Reg& r, const Xmm& x) { if (!r.isBit(i32e)) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x.isXMM() ? Xmm(r.getIdx()) : Ymm(r.getIdx()), cvtIdx0(x), x, T_0F | T_66 | T_W0 | T_YMM, 0x50); } -void vmovmskps(const Reg& r, const Xmm& x) { if (!r.isBit(i32e)) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x.isXMM() ? Xmm(r.getIdx()) : Ymm(r.getIdx()), cvtIdx0(x), x, T_0F | T_W0 | T_YMM, 0x50); } -void vmovntdq(const Address& addr, const Xmm& x) { opVex(x, 0, addr, T_0F | T_66 | T_YMM | T_EVEX | T_EW0, 0xE7); } -void vmovntdqa(const Xmm& x, const Address& addr) { opVex(x, 0, addr, T_0F38 | T_66 | T_YMM | T_EVEX | T_EW0, 0x2A); } -void vmovntpd(const Address& addr, const Xmm& x) { opVex(x, 0, addr, T_0F | T_66 | T_YMM | T_EVEX | T_EW1, 0x2B); } -void vmovntps(const Address& addr, const Xmm& x) { opVex(x, 0, addr, T_0F | T_YMM | T_EVEX | T_EW0, 0x2B); } -void vmovq(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_0F | T_66 | T_EVEX | T_EW1 | T_N8, x.getIdx() < 16 ? 
0xD6 : 0x7E); } -void vmovq(const Xmm& x, const Address& addr) { int type, code; if (x.getIdx() < 16) { type = T_0F | T_F3; code = 0x7E; } else { type = T_0F | T_66 | T_EVEX | T_EW1 | T_N8; code = 0x6E; } opAVX_X_X_XM(x, xm0, addr, type, code); } -void vmovq(const Xmm& x1, const Xmm& x2) { opAVX_X_X_XM(x1, xm0, x2, T_0F | T_F3 | T_EVEX | T_EW1 | T_N8, 0x7E); } -void vmovsd(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_N8 | T_F2 | T_0F | T_EW1 | T_EVEX | T_M_K, 0x11); } -void vmovsd(const Xmm& x, const Address& addr) { opAVX_X_X_XM(x, xm0, addr, T_N8 | T_F2 | T_0F | T_EW1 | T_EVEX, 0x10); } -void vmovsd(const Xmm& x1, const Xmm& x2, const Operand& op = Operand()) { if (!op.isNone() && !op.isXMM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x1, x2, op, T_N8 | T_F2 | T_0F | T_EW1 | T_EVEX, 0x10); } -void vmovshdup(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_F3 | T_0F | T_EW0 | T_YMM | T_EVEX, 0x16); } -void vmovsldup(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_F3 | T_0F | T_EW0 | T_YMM | T_EVEX, 0x12); } -void vmovss(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_N4 | T_F3 | T_0F | T_EW0 | T_EVEX | T_M_K, 0x11); } -void vmovss(const Xmm& x, const Address& addr) { opAVX_X_X_XM(x, xm0, addr, T_N4 | T_F3 | T_0F | T_EW0 | T_EVEX, 0x10); } -void vmovss(const Xmm& x1, const Xmm& x2, const Operand& op = Operand()) { if (!op.isNone() && !op.isXMM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x1, x2, op, T_N4 | T_F3 | T_0F | T_EW0 | T_EVEX, 0x10); } -void vmovupd(const Address& addr, const Xmm& xmm) { opAVX_X_XM_IMM(xmm, addr, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_M_K, 0x11); } -void vmovupd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX, 0x10); } -void vmovups(const Address& addr, const Xmm& xmm) { opAVX_X_XM_IMM(xmm, addr, T_0F | T_EW0 | T_YMM | T_EVEX | T_M_K, 0x11); } -void vmovups(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_0F | T_EW0 | T_YMM | T_EVEX, 0x10); } -void vmpsadbw(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0 | T_YMM, 0x42, imm); } -void vmulpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x59); } -void vmulps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x59); } -void vmulsd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F2 | T_EW1 | T_EVEX | T_ER_Z | T_N8, 0x59); } -void vmulss(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F3 | T_EW0 | T_EVEX | T_ER_Z | T_N4, 0x59); } -void vorpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x56); } -void vorps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x56); } -void vpabsb(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x1C); } -void vpabsd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x1E); } -void vpabsw(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x1D); } 
-void vpackssdw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0x6B); } -void vpacksswb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0x63); } -void vpackusdw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x2B); } -void vpackuswb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0x67); } -void vpaddb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xFC); } -void vpaddd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0xFE); } -void vpaddq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0xD4); } -void vpaddsb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xEC); } -void vpaddsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xED); } -void vpaddusb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xDC); } -void vpaddusw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xDD); } -void vpaddw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xFD); } -void vpalignr(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_YMM | T_EVEX, 0x0F, imm); } -void vpand(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0xDB); } -void vpandn(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0xDF); } -void vpavgb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xE0); } -void vpavgw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xE3); } -void vpblendd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0 | T_YMM, 0x02, imm); } -void vpblendvb(const Xmm& x1, const Xmm& x2, const Operand& op, const Xmm& x4) { opAVX_X_X_XM(x1, x2, op, T_0F3A | T_66 | T_YMM, 0x4C, x4.getIdx() << 4); } -void vpblendw(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0 | T_YMM, 0x0E, imm); } -void vpbroadcastb(const Xmm& x, const Operand& op) { if (!(op.isXMM() || op.isMEM())) throw Error(ERR_BAD_COMBINATION); opAVX_X_XM_IMM(x, op, T_N1 | T_66 | T_0F38 | T_W0 | T_YMM | T_EVEX, 0x78); } -void vpbroadcastd(const Xmm& x, const Operand& op) { if (!(op.isXMM() || op.isMEM())) throw Error(ERR_BAD_COMBINATION); opAVX_X_XM_IMM(x, op, T_N4 | T_66 | T_0F38 | T_W0 | T_YMM | T_EVEX, 0x58); } -void vpbroadcastq(const Xmm& x, const Operand& op) { if (!(op.isXMM() || op.isMEM())) throw Error(ERR_BAD_COMBINATION); opAVX_X_XM_IMM(x, op, T_N8 | T_66 | T_0F38 | T_W0 | T_EW1 | T_YMM | T_EVEX, 0x59); } -void vpbroadcastw(const Xmm& x, const Operand& op) { if (!(op.isXMM() || op.isMEM())) throw Error(ERR_BAD_COMBINATION); opAVX_X_XM_IMM(x, op, T_N2 | T_66 | T_0F38 | T_W0 | T_YMM | T_EVEX, 0x79); } -void vpclmulqdq(const Xmm& x1, const Xmm& 
x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0 | T_YMM | T_EVEX, 0x44, imm); } -void vpcmpeqb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0x74); } -void vpcmpeqd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0x76); } -void vpcmpeqq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x29); } -void vpcmpeqw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0x75); } -void vpcmpestri(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A, 0x61, imm); } -void vpcmpestrm(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A, 0x60, imm); } -void vpcmpgtb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0x64); } -void vpcmpgtd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0x66); } -void vpcmpgtq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x37); } -void vpcmpgtw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0x65); } -void vpcmpistri(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A, 0x63, imm); } -void vpcmpistrm(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A, 0x62, imm); } -void vperm2f128(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { if (!(y1.isYMM() && y2.isYMM() && op.isYMEM())) throw Error(ERR_BAD_COMBINATION); opVex(y1, &y2, op, T_0F3A | T_66 | T_W0 | T_YMM, 0x06, imm); } -void vperm2i128(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { if (!(y1.isYMM() && y2.isYMM() && op.isYMEM())) throw Error(ERR_BAD_COMBINATION); opVex(y1, &y2, op, T_0F3A | T_66 | T_W0 | T_YMM, 0x46, imm); } -void vpermd(const Ymm& y1, const Ymm& y2, const Operand& op) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x36); } -void vpermilpd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x0D); } -void vpermilpd(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_EVEX | T_B64, 0x05, imm); } -void vpermilps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x0C); } -void vpermilps(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_EVEX | T_B32, 0x04, imm); } -void vpermpd(const Ymm& y, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(y, op, T_66 | T_0F3A | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x01, imm); } -void vpermpd(const Ymm& y1, const Ymm& y2, const Operand& op) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x16); } -void vpermps(const Ymm& y1, const Ymm& y2, const Operand& op) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x16); } -void vpermq(const Ymm& y, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(y, op, T_66 | T_0F3A | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x00, imm); } -void vpermq(const Ymm& y1, const Ymm& y2, const Operand& op) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F38 | T_W0 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x36); } -void vpextrb(const Operand& op, 
const Xmm& x, uint8 imm) { if (!((op.isREG(8|16|i32e) || op.isMEM()) && x.isXMM())) throw Error(ERR_BAD_COMBINATION); opVex(x, 0, op, T_0F3A | T_66 | T_EVEX | T_N1, 0x14, imm); } -void vpextrd(const Operand& op, const Xmm& x, uint8 imm) { if (!((op.isREG(32) || op.isMEM()) && x.isXMM())) throw Error(ERR_BAD_COMBINATION); opVex(x, 0, op, T_0F3A | T_66 | T_W0 | T_EVEX | T_EW0 | T_N4, 0x16, imm); } -void vpextrq(const Operand& op, const Xmm& x, uint8 imm) { if (!((op.isREG(64) || op.isMEM()) && x.isXMM())) throw Error(ERR_BAD_COMBINATION); opVex(x, 0, op, T_0F3A | T_66 | T_W1 | T_EVEX | T_EW1 | T_N8, 0x16, imm); } -void vpextrw(const Operand& op, const Xmm& x, uint8 imm) { if (!((op.isREG(16|i32e) || op.isMEM()) && x.isXMM())) throw Error(ERR_BAD_COMBINATION); if (op.isREG() && x.getIdx() < 16) { opAVX_X_X_XM(Xmm(op.getIdx()), xm0, x, T_0F | T_66, 0xC5, imm); } else { opVex(x, 0, op, T_0F3A | T_66 | T_EVEX | T_N2, 0x15, imm); } } -void vpgatherdd(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W0, 0x90, 1); } -void vpgatherdq(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W1, 0x90, 0); } -void vpgatherqd(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W0, 0x91, 2); } -void vpgatherqq(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W1, 0x91, 1); } -void vphaddd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x02); } -void vphaddsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x03); } -void vphaddw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x01); } -void vphminposuw(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38, 0x41); } -void vphsubd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x06); } -void vphsubsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x07); } -void vphsubw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x05); } -void vpinsrb(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { if (!(x1.isXMM() && x2.isXMM() && (op.isREG(32) || op.isMEM()))) throw Error(ERR_BAD_COMBINATION); opVex(x1, &x2, op, T_0F3A | T_66 | T_EVEX | T_N1, 0x20, imm); } -void vpinsrd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { if (!(x1.isXMM() && x2.isXMM() && (op.isREG(32) || op.isMEM()))) throw Error(ERR_BAD_COMBINATION); opVex(x1, &x2, op, T_0F3A | T_66 | T_W0 | T_EVEX | T_EW0 | T_N4, 0x22, imm); } -void vpinsrq(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { if (!(x1.isXMM() && x2.isXMM() && (op.isREG(64) || op.isMEM()))) throw Error(ERR_BAD_COMBINATION); opVex(x1, &x2, op, T_0F3A | T_66 | T_W1 | T_EVEX | T_EW1 | T_N8, 0x22, imm); } -void vpinsrw(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { if (!(x1.isXMM() && x2.isXMM() && (op.isREG(32) || op.isMEM()))) throw Error(ERR_BAD_COMBINATION); opVex(x1, &x2, op, T_0F | T_66 | T_EVEX | T_N2, 0xC4, imm); } -void vpmaddubsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x04); } -void vpmaddwd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, 
x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xF5); } -void vpmaskmovd(const Address& addr, const Xmm& x1, const Xmm& x2) { opAVX_X_X_XM(x2, x1, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x8E); } -void vpmaskmovd(const Xmm& x1, const Xmm& x2, const Address& addr) { opAVX_X_X_XM(x1, x2, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x8C); } -void vpmaskmovq(const Address& addr, const Xmm& x1, const Xmm& x2) { opAVX_X_X_XM(x2, x1, addr, T_0F38 | T_66 | T_W1 | T_YMM, 0x8E); } -void vpmaskmovq(const Xmm& x1, const Xmm& x2, const Address& addr) { opAVX_X_X_XM(x1, x2, addr, T_0F38 | T_66 | T_W1 | T_YMM, 0x8C); } -void vpmaxsb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x3C); } -void vpmaxsd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x3D); } -void vpmaxsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xEE); } -void vpmaxub(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xDE); } -void vpmaxud(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x3F); } -void vpmaxuw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x3E); } -void vpminsb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x38); } -void vpminsd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x39); } -void vpminsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xEA); } -void vpminub(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xDA); } -void vpminud(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x3B); } -void vpminuw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x3A); } -void vpmovmskb(const Reg32e& r, const Xmm& x) { if (!x.is(Operand::XMM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(x.isYMM() ? 
Ymm(r.getIdx()) : Xmm(r.getIdx()), 0, x, T_0F | T_66 | T_YMM, 0xD7); } -void vpmovsxbd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N4 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x21); } -void vpmovsxbq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N2 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x22); } -void vpmovsxbw(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x20); } -void vpmovsxdq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_N_VL | T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX, 0x25); } -void vpmovsxwd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x23); } -void vpmovsxwq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N4 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x24); } -void vpmovzxbd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N4 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x31); } -void vpmovzxbq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N2 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x32); } -void vpmovzxbw(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x30); } -void vpmovzxdq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_N_VL | T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX, 0x35); } -void vpmovzxwd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x33); } -void vpmovzxwq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N4 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x34); } -void vpmuldq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x28); } -void vpmulhrsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x0B); } -void vpmulhuw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xE4); } -void vpmulhw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xE5); } -void vpmulld(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x40); } -void vpmullw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xD5); } -void vpmuludq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0xF4); } -void vpor(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0xEB); } -void vpsadbw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xF6); } -void vpshufb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x00); } -void vpshufd(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0x70, imm); } -void vpshufhw(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_F3 | T_0F | T_YMM | T_EVEX, 0x70, imm); } -void vpshuflw(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_F2 | T_0F | T_YMM | T_EVEX, 0x70, imm); } -void vpsignb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x08); } -void vpsignd(const 
Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x0A); } -void vpsignw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x09); } -void vpslld(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 6), x, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32 | T_MEM_EVEX, 0x72, imm); } -void vpslld(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_EW0 | T_YMM | T_EVEX, 0xF2); } -void vpslldq(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 7), x, op, T_66 | T_0F | T_YMM | T_EVEX | T_MEM_EVEX, 0x73, imm); } -void vpsllq(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 6), x, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64 | T_MEM_EVEX, 0x73, imm); } -void vpsllq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_EW1 | T_YMM | T_EVEX, 0xF3); } -void vpsllvd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x47); } -void vpsllvq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x47); } -void vpsllw(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 6), x, op, T_66 | T_0F | T_YMM | T_EVEX | T_MEM_EVEX, 0x71, imm); } -void vpsllw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_YMM | T_EVEX, 0xF1); } -void vpsrad(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 4), x, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32 | T_MEM_EVEX, 0x72, imm); } -void vpsrad(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_EW0 | T_YMM | T_EVEX, 0xE2); } -void vpsravd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x46); } -void vpsraw(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 4), x, op, T_66 | T_0F | T_YMM | T_EVEX | T_MEM_EVEX, 0x71, imm); } -void vpsraw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_YMM | T_EVEX, 0xE1); } -void vpsrld(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 2), x, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32 | T_MEM_EVEX, 0x72, imm); } -void vpsrld(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_EW0 | T_YMM | T_EVEX, 0xD2); } -void vpsrldq(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 3), x, op, T_66 | T_0F | T_YMM | T_EVEX | T_MEM_EVEX, 0x73, imm); } -void vpsrlq(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 2), x, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64 | T_MEM_EVEX, 0x73, imm); } -void vpsrlq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_EW1 | T_YMM | T_EVEX, 0xD3); } -void vpsrlvd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x45); } -void vpsrlvq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x45); } -void vpsrlw(const Xmm& x, const Operand& op, uint8 imm) { 
opAVX_X_X_XM(Xmm(x.getKind(), 2), x, op, T_66 | T_0F | T_YMM | T_EVEX | T_MEM_EVEX, 0x71, imm); } -void vpsrlw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_YMM | T_EVEX, 0xD1); } -void vpsubb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xF8); } -void vpsubd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0xFA); } -void vpsubq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0xFB); } -void vpsubsb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xE8); } -void vpsubsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xE9); } -void vpsubusb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xD8); } -void vpsubusw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xD9); } -void vpsubw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xF9); } -void vptest(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38 | T_YMM, 0x17); } -void vpunpckhbw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0x68); } -void vpunpckhdq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0x6A); } -void vpunpckhqdq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0x6D); } -void vpunpckhwd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0x69); } -void vpunpcklbw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0x60); } -void vpunpckldq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0x62); } -void vpunpcklqdq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0x6C); } -void vpunpcklwd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0x61); } -void vpxor(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0xEF); } -void vrcpps(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_0F | T_YMM, 0x53); } -void vrcpss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_F3 | T_0F, 0x53); } -void vroundpd(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A | T_YMM, 0x09, imm); } -void vroundps(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A | T_YMM, 0x08, imm); } -void vroundsd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0, 0x0B, imm); } -void vroundss(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0, 0x0A, imm); } -void vrsqrtps(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_0F | T_YMM, 0x52); } -void vrsqrtss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, 
T_F3 | T_0F, 0x52); } -void vshufpd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0xC6, imm); } -void vshufps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0xC6, imm); } -void vsqrtpd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x51); } -void vsqrtps(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x51); } -void vsqrtsd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_F2 | T_0F | T_EW1 | T_EVEX | T_ER_X, 0x51); } -void vsqrtss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_F3 | T_0F | T_EW0 | T_EVEX | T_ER_X, 0x51); } -void vstmxcsr(const Address& addr) { opAVX_X_X_XM(xm3, xm0, addr, T_0F, 0xAE); } -void vsubpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x5C); } -void vsubps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x5C); } -void vsubsd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F2 | T_EW1 | T_EVEX | T_ER_Z | T_N8, 0x5C); } -void vsubss(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F3 | T_EW0 | T_EVEX | T_ER_Z | T_N4, 0x5C); } -void vtestpd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38 | T_YMM, 0x0F); } -void vtestps(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38 | T_YMM, 0x0E); } -void vucomisd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_66 | T_0F | T_EW1 | T_EVEX | T_SAE_X, 0x2E); } -void vucomiss(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N4 | T_0F | T_EW0 | T_EVEX | T_SAE_X, 0x2E); } -void vunpckhpd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0x15); } -void vunpckhps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0x15); } -void vunpcklpd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0x14); } -void vunpcklps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0x14); } -void vxorpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x57); } -void vxorps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x57); } -void vzeroall() { db(0xC5); db(0xFC); db(0x77); } -void vzeroupper() { db(0xC5); db(0xF8); db(0x77); } -void wait() { db(0x9B); } -void wbinvd() { db(0x0F); db(0x09); } -void wrmsr() { db(0x0F); db(0x30); } -void xadd(const Operand& op, const Reg& reg) { opModRM(reg, op, (op.isREG() && reg.isREG() && op.getBit() == reg.getBit()), op.isMEM(), 0x0F, 0xC0 | (reg.isBit(8) ? 
0 : 1)); } -void xgetbv() { db(0x0F); db(0x01); db(0xD0); } -void xlatb() { db(0xD7); } -void xor_(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x30, 6); } -void xor_(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x30); } -void xorpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x57, 0x66, isXMM_XMMorMEM); } -void xorps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x57, 0x100, isXMM_XMMorMEM); } -#ifdef XBYAK_ENABLE_OMITTED_OPERAND -void vblendpd(const Xmm& x, const Operand& op, uint8 imm) { vblendpd(x, x, op, imm); } -void vblendps(const Xmm& x, const Operand& op, uint8 imm) { vblendps(x, x, op, imm); } -void vblendvpd(const Xmm& x1, const Operand& op, const Xmm& x4) { vblendvpd(x1, x1, op, x4); } -void vblendvps(const Xmm& x1, const Operand& op, const Xmm& x4) { vblendvps(x1, x1, op, x4); } -void vcmpeq_ospd(const Xmm& x, const Operand& op) { vcmpeq_ospd(x, x, op); } -void vcmpeq_osps(const Xmm& x, const Operand& op) { vcmpeq_osps(x, x, op); } -void vcmpeq_ossd(const Xmm& x, const Operand& op) { vcmpeq_ossd(x, x, op); } -void vcmpeq_osss(const Xmm& x, const Operand& op) { vcmpeq_osss(x, x, op); } -void vcmpeq_uqpd(const Xmm& x, const Operand& op) { vcmpeq_uqpd(x, x, op); } -void vcmpeq_uqps(const Xmm& x, const Operand& op) { vcmpeq_uqps(x, x, op); } -void vcmpeq_uqsd(const Xmm& x, const Operand& op) { vcmpeq_uqsd(x, x, op); } -void vcmpeq_uqss(const Xmm& x, const Operand& op) { vcmpeq_uqss(x, x, op); } -void vcmpeq_uspd(const Xmm& x, const Operand& op) { vcmpeq_uspd(x, x, op); } -void vcmpeq_usps(const Xmm& x, const Operand& op) { vcmpeq_usps(x, x, op); } -void vcmpeq_ussd(const Xmm& x, const Operand& op) { vcmpeq_ussd(x, x, op); } -void vcmpeq_usss(const Xmm& x, const Operand& op) { vcmpeq_usss(x, x, op); } -void vcmpeqpd(const Xmm& x, const Operand& op) { vcmpeqpd(x, x, op); } -void vcmpeqps(const Xmm& x, const Operand& op) { vcmpeqps(x, x, op); } -void vcmpeqsd(const Xmm& x, const Operand& op) { vcmpeqsd(x, x, op); } -void vcmpeqss(const Xmm& x, const Operand& op) { vcmpeqss(x, x, op); } -void vcmpfalse_ospd(const Xmm& x, const Operand& op) { vcmpfalse_ospd(x, x, op); } -void vcmpfalse_osps(const Xmm& x, const Operand& op) { vcmpfalse_osps(x, x, op); } -void vcmpfalse_ossd(const Xmm& x, const Operand& op) { vcmpfalse_ossd(x, x, op); } -void vcmpfalse_osss(const Xmm& x, const Operand& op) { vcmpfalse_osss(x, x, op); } -void vcmpfalsepd(const Xmm& x, const Operand& op) { vcmpfalsepd(x, x, op); } -void vcmpfalseps(const Xmm& x, const Operand& op) { vcmpfalseps(x, x, op); } -void vcmpfalsesd(const Xmm& x, const Operand& op) { vcmpfalsesd(x, x, op); } -void vcmpfalsess(const Xmm& x, const Operand& op) { vcmpfalsess(x, x, op); } -void vcmpge_oqpd(const Xmm& x, const Operand& op) { vcmpge_oqpd(x, x, op); } -void vcmpge_oqps(const Xmm& x, const Operand& op) { vcmpge_oqps(x, x, op); } -void vcmpge_oqsd(const Xmm& x, const Operand& op) { vcmpge_oqsd(x, x, op); } -void vcmpge_oqss(const Xmm& x, const Operand& op) { vcmpge_oqss(x, x, op); } -void vcmpgepd(const Xmm& x, const Operand& op) { vcmpgepd(x, x, op); } -void vcmpgeps(const Xmm& x, const Operand& op) { vcmpgeps(x, x, op); } -void vcmpgesd(const Xmm& x, const Operand& op) { vcmpgesd(x, x, op); } -void vcmpgess(const Xmm& x, const Operand& op) { vcmpgess(x, x, op); } -void vcmpgt_oqpd(const Xmm& x, const Operand& op) { vcmpgt_oqpd(x, x, op); } -void vcmpgt_oqps(const Xmm& x, const Operand& op) { vcmpgt_oqps(x, x, op); } -void vcmpgt_oqsd(const Xmm& x, const Operand& op) { vcmpgt_oqsd(x, x, op); } -void 
vcmpgt_oqss(const Xmm& x, const Operand& op) { vcmpgt_oqss(x, x, op); } -void vcmpgtpd(const Xmm& x, const Operand& op) { vcmpgtpd(x, x, op); } -void vcmpgtps(const Xmm& x, const Operand& op) { vcmpgtps(x, x, op); } -void vcmpgtsd(const Xmm& x, const Operand& op) { vcmpgtsd(x, x, op); } -void vcmpgtss(const Xmm& x, const Operand& op) { vcmpgtss(x, x, op); } -void vcmple_oqpd(const Xmm& x, const Operand& op) { vcmple_oqpd(x, x, op); } -void vcmple_oqps(const Xmm& x, const Operand& op) { vcmple_oqps(x, x, op); } -void vcmple_oqsd(const Xmm& x, const Operand& op) { vcmple_oqsd(x, x, op); } -void vcmple_oqss(const Xmm& x, const Operand& op) { vcmple_oqss(x, x, op); } -void vcmplepd(const Xmm& x, const Operand& op) { vcmplepd(x, x, op); } -void vcmpleps(const Xmm& x, const Operand& op) { vcmpleps(x, x, op); } -void vcmplesd(const Xmm& x, const Operand& op) { vcmplesd(x, x, op); } -void vcmpless(const Xmm& x, const Operand& op) { vcmpless(x, x, op); } -void vcmplt_oqpd(const Xmm& x, const Operand& op) { vcmplt_oqpd(x, x, op); } -void vcmplt_oqps(const Xmm& x, const Operand& op) { vcmplt_oqps(x, x, op); } -void vcmplt_oqsd(const Xmm& x, const Operand& op) { vcmplt_oqsd(x, x, op); } -void vcmplt_oqss(const Xmm& x, const Operand& op) { vcmplt_oqss(x, x, op); } -void vcmpltpd(const Xmm& x, const Operand& op) { vcmpltpd(x, x, op); } -void vcmpltps(const Xmm& x, const Operand& op) { vcmpltps(x, x, op); } -void vcmpltsd(const Xmm& x, const Operand& op) { vcmpltsd(x, x, op); } -void vcmpltss(const Xmm& x, const Operand& op) { vcmpltss(x, x, op); } -void vcmpneq_oqpd(const Xmm& x, const Operand& op) { vcmpneq_oqpd(x, x, op); } -void vcmpneq_oqps(const Xmm& x, const Operand& op) { vcmpneq_oqps(x, x, op); } -void vcmpneq_oqsd(const Xmm& x, const Operand& op) { vcmpneq_oqsd(x, x, op); } -void vcmpneq_oqss(const Xmm& x, const Operand& op) { vcmpneq_oqss(x, x, op); } -void vcmpneq_ospd(const Xmm& x, const Operand& op) { vcmpneq_ospd(x, x, op); } -void vcmpneq_osps(const Xmm& x, const Operand& op) { vcmpneq_osps(x, x, op); } -void vcmpneq_ossd(const Xmm& x, const Operand& op) { vcmpneq_ossd(x, x, op); } -void vcmpneq_osss(const Xmm& x, const Operand& op) { vcmpneq_osss(x, x, op); } -void vcmpneq_uspd(const Xmm& x, const Operand& op) { vcmpneq_uspd(x, x, op); } -void vcmpneq_usps(const Xmm& x, const Operand& op) { vcmpneq_usps(x, x, op); } -void vcmpneq_ussd(const Xmm& x, const Operand& op) { vcmpneq_ussd(x, x, op); } -void vcmpneq_usss(const Xmm& x, const Operand& op) { vcmpneq_usss(x, x, op); } -void vcmpneqpd(const Xmm& x, const Operand& op) { vcmpneqpd(x, x, op); } -void vcmpneqps(const Xmm& x, const Operand& op) { vcmpneqps(x, x, op); } -void vcmpneqsd(const Xmm& x, const Operand& op) { vcmpneqsd(x, x, op); } -void vcmpneqss(const Xmm& x, const Operand& op) { vcmpneqss(x, x, op); } -void vcmpnge_uqpd(const Xmm& x, const Operand& op) { vcmpnge_uqpd(x, x, op); } -void vcmpnge_uqps(const Xmm& x, const Operand& op) { vcmpnge_uqps(x, x, op); } -void vcmpnge_uqsd(const Xmm& x, const Operand& op) { vcmpnge_uqsd(x, x, op); } -void vcmpnge_uqss(const Xmm& x, const Operand& op) { vcmpnge_uqss(x, x, op); } -void vcmpngepd(const Xmm& x, const Operand& op) { vcmpngepd(x, x, op); } -void vcmpngeps(const Xmm& x, const Operand& op) { vcmpngeps(x, x, op); } -void vcmpngesd(const Xmm& x, const Operand& op) { vcmpngesd(x, x, op); } -void vcmpngess(const Xmm& x, const Operand& op) { vcmpngess(x, x, op); } -void vcmpngt_uqpd(const Xmm& x, const Operand& op) { vcmpngt_uqpd(x, x, op); } -void vcmpngt_uqps(const Xmm& x, const 
Operand& op) { vcmpngt_uqps(x, x, op); } -void vcmpngt_uqsd(const Xmm& x, const Operand& op) { vcmpngt_uqsd(x, x, op); } -void vcmpngt_uqss(const Xmm& x, const Operand& op) { vcmpngt_uqss(x, x, op); } -void vcmpngtpd(const Xmm& x, const Operand& op) { vcmpngtpd(x, x, op); } -void vcmpngtps(const Xmm& x, const Operand& op) { vcmpngtps(x, x, op); } -void vcmpngtsd(const Xmm& x, const Operand& op) { vcmpngtsd(x, x, op); } -void vcmpngtss(const Xmm& x, const Operand& op) { vcmpngtss(x, x, op); } -void vcmpnle_uqpd(const Xmm& x, const Operand& op) { vcmpnle_uqpd(x, x, op); } -void vcmpnle_uqps(const Xmm& x, const Operand& op) { vcmpnle_uqps(x, x, op); } -void vcmpnle_uqsd(const Xmm& x, const Operand& op) { vcmpnle_uqsd(x, x, op); } -void vcmpnle_uqss(const Xmm& x, const Operand& op) { vcmpnle_uqss(x, x, op); } -void vcmpnlepd(const Xmm& x, const Operand& op) { vcmpnlepd(x, x, op); } -void vcmpnleps(const Xmm& x, const Operand& op) { vcmpnleps(x, x, op); } -void vcmpnlesd(const Xmm& x, const Operand& op) { vcmpnlesd(x, x, op); } -void vcmpnless(const Xmm& x, const Operand& op) { vcmpnless(x, x, op); } -void vcmpnlt_uqpd(const Xmm& x, const Operand& op) { vcmpnlt_uqpd(x, x, op); } -void vcmpnlt_uqps(const Xmm& x, const Operand& op) { vcmpnlt_uqps(x, x, op); } -void vcmpnlt_uqsd(const Xmm& x, const Operand& op) { vcmpnlt_uqsd(x, x, op); } -void vcmpnlt_uqss(const Xmm& x, const Operand& op) { vcmpnlt_uqss(x, x, op); } -void vcmpnltpd(const Xmm& x, const Operand& op) { vcmpnltpd(x, x, op); } -void vcmpnltps(const Xmm& x, const Operand& op) { vcmpnltps(x, x, op); } -void vcmpnltsd(const Xmm& x, const Operand& op) { vcmpnltsd(x, x, op); } -void vcmpnltss(const Xmm& x, const Operand& op) { vcmpnltss(x, x, op); } -void vcmpord_spd(const Xmm& x, const Operand& op) { vcmpord_spd(x, x, op); } -void vcmpord_sps(const Xmm& x, const Operand& op) { vcmpord_sps(x, x, op); } -void vcmpord_ssd(const Xmm& x, const Operand& op) { vcmpord_ssd(x, x, op); } -void vcmpord_sss(const Xmm& x, const Operand& op) { vcmpord_sss(x, x, op); } -void vcmpordpd(const Xmm& x, const Operand& op) { vcmpordpd(x, x, op); } -void vcmpordps(const Xmm& x, const Operand& op) { vcmpordps(x, x, op); } -void vcmpordsd(const Xmm& x, const Operand& op) { vcmpordsd(x, x, op); } -void vcmpordss(const Xmm& x, const Operand& op) { vcmpordss(x, x, op); } -void vcmppd(const Xmm& x, const Operand& op, uint8 imm) { vcmppd(x, x, op, imm); } -void vcmpps(const Xmm& x, const Operand& op, uint8 imm) { vcmpps(x, x, op, imm); } -void vcmpsd(const Xmm& x, const Operand& op, uint8 imm) { vcmpsd(x, x, op, imm); } -void vcmpss(const Xmm& x, const Operand& op, uint8 imm) { vcmpss(x, x, op, imm); } -void vcmptrue_uspd(const Xmm& x, const Operand& op) { vcmptrue_uspd(x, x, op); } -void vcmptrue_usps(const Xmm& x, const Operand& op) { vcmptrue_usps(x, x, op); } -void vcmptrue_ussd(const Xmm& x, const Operand& op) { vcmptrue_ussd(x, x, op); } -void vcmptrue_usss(const Xmm& x, const Operand& op) { vcmptrue_usss(x, x, op); } -void vcmptruepd(const Xmm& x, const Operand& op) { vcmptruepd(x, x, op); } -void vcmptrueps(const Xmm& x, const Operand& op) { vcmptrueps(x, x, op); } -void vcmptruesd(const Xmm& x, const Operand& op) { vcmptruesd(x, x, op); } -void vcmptruess(const Xmm& x, const Operand& op) { vcmptruess(x, x, op); } -void vcmpunord_spd(const Xmm& x, const Operand& op) { vcmpunord_spd(x, x, op); } -void vcmpunord_sps(const Xmm& x, const Operand& op) { vcmpunord_sps(x, x, op); } -void vcmpunord_ssd(const Xmm& x, const Operand& op) { vcmpunord_ssd(x, x, op); } 
-void vcmpunord_sss(const Xmm& x, const Operand& op) { vcmpunord_sss(x, x, op); } -void vcmpunordpd(const Xmm& x, const Operand& op) { vcmpunordpd(x, x, op); } -void vcmpunordps(const Xmm& x, const Operand& op) { vcmpunordps(x, x, op); } -void vcmpunordsd(const Xmm& x, const Operand& op) { vcmpunordsd(x, x, op); } -void vcmpunordss(const Xmm& x, const Operand& op) { vcmpunordss(x, x, op); } -void vcvtsd2ss(const Xmm& x, const Operand& op) { vcvtsd2ss(x, x, op); } -void vcvtsi2sd(const Xmm& x, const Operand& op) { vcvtsi2sd(x, x, op); } -void vcvtsi2ss(const Xmm& x, const Operand& op) { vcvtsi2ss(x, x, op); } -void vcvtss2sd(const Xmm& x, const Operand& op) { vcvtss2sd(x, x, op); } -void vdppd(const Xmm& x, const Operand& op, uint8 imm) { vdppd(x, x, op, imm); } -void vdpps(const Xmm& x, const Operand& op, uint8 imm) { vdpps(x, x, op, imm); } -void vinsertps(const Xmm& x, const Operand& op, uint8 imm) { vinsertps(x, x, op, imm); } -void vmpsadbw(const Xmm& x, const Operand& op, uint8 imm) { vmpsadbw(x, x, op, imm); } -void vpackssdw(const Xmm& x, const Operand& op) { vpackssdw(x, x, op); } -void vpacksswb(const Xmm& x, const Operand& op) { vpacksswb(x, x, op); } -void vpackusdw(const Xmm& x, const Operand& op) { vpackusdw(x, x, op); } -void vpackuswb(const Xmm& x, const Operand& op) { vpackuswb(x, x, op); } -void vpaddb(const Xmm& x, const Operand& op) { vpaddb(x, x, op); } -void vpaddd(const Xmm& x, const Operand& op) { vpaddd(x, x, op); } -void vpaddq(const Xmm& x, const Operand& op) { vpaddq(x, x, op); } -void vpaddsb(const Xmm& x, const Operand& op) { vpaddsb(x, x, op); } -void vpaddsw(const Xmm& x, const Operand& op) { vpaddsw(x, x, op); } -void vpaddusb(const Xmm& x, const Operand& op) { vpaddusb(x, x, op); } -void vpaddusw(const Xmm& x, const Operand& op) { vpaddusw(x, x, op); } -void vpaddw(const Xmm& x, const Operand& op) { vpaddw(x, x, op); } -void vpalignr(const Xmm& x, const Operand& op, uint8 imm) { vpalignr(x, x, op, imm); } -void vpand(const Xmm& x, const Operand& op) { vpand(x, x, op); } -void vpandn(const Xmm& x, const Operand& op) { vpandn(x, x, op); } -void vpavgb(const Xmm& x, const Operand& op) { vpavgb(x, x, op); } -void vpavgw(const Xmm& x, const Operand& op) { vpavgw(x, x, op); } -void vpblendd(const Xmm& x, const Operand& op, uint8 imm) { vpblendd(x, x, op, imm); } -void vpblendvb(const Xmm& x1, const Operand& op, const Xmm& x4) { vpblendvb(x1, x1, op, x4); } -void vpblendw(const Xmm& x, const Operand& op, uint8 imm) { vpblendw(x, x, op, imm); } -void vpclmulqdq(const Xmm& x, const Operand& op, uint8 imm) { vpclmulqdq(x, x, op, imm); } -void vpcmpeqb(const Xmm& x, const Operand& op) { vpcmpeqb(x, x, op); } -void vpcmpeqd(const Xmm& x, const Operand& op) { vpcmpeqd(x, x, op); } -void vpcmpeqq(const Xmm& x, const Operand& op) { vpcmpeqq(x, x, op); } -void vpcmpeqw(const Xmm& x, const Operand& op) { vpcmpeqw(x, x, op); } -void vpcmpgtb(const Xmm& x, const Operand& op) { vpcmpgtb(x, x, op); } -void vpcmpgtd(const Xmm& x, const Operand& op) { vpcmpgtd(x, x, op); } -void vpcmpgtq(const Xmm& x, const Operand& op) { vpcmpgtq(x, x, op); } -void vpcmpgtw(const Xmm& x, const Operand& op) { vpcmpgtw(x, x, op); } -void vphaddd(const Xmm& x, const Operand& op) { vphaddd(x, x, op); } -void vphaddsw(const Xmm& x, const Operand& op) { vphaddsw(x, x, op); } -void vphaddw(const Xmm& x, const Operand& op) { vphaddw(x, x, op); } -void vphsubd(const Xmm& x, const Operand& op) { vphsubd(x, x, op); } -void vphsubsw(const Xmm& x, const Operand& op) { vphsubsw(x, x, op); } -void 
vphsubw(const Xmm& x, const Operand& op) { vphsubw(x, x, op); } -void vpinsrb(const Xmm& x, const Operand& op, uint8 imm) { vpinsrb(x, x, op, imm); } -void vpinsrd(const Xmm& x, const Operand& op, uint8 imm) { vpinsrd(x, x, op, imm); } -void vpinsrq(const Xmm& x, const Operand& op, uint8 imm) { vpinsrq(x, x, op, imm); } -void vpinsrw(const Xmm& x, const Operand& op, uint8 imm) { vpinsrw(x, x, op, imm); } -void vpmaddubsw(const Xmm& x, const Operand& op) { vpmaddubsw(x, x, op); } -void vpmaddwd(const Xmm& x, const Operand& op) { vpmaddwd(x, x, op); } -void vpmaxsb(const Xmm& x, const Operand& op) { vpmaxsb(x, x, op); } -void vpmaxsd(const Xmm& x, const Operand& op) { vpmaxsd(x, x, op); } -void vpmaxsw(const Xmm& x, const Operand& op) { vpmaxsw(x, x, op); } -void vpmaxub(const Xmm& x, const Operand& op) { vpmaxub(x, x, op); } -void vpmaxud(const Xmm& x, const Operand& op) { vpmaxud(x, x, op); } -void vpmaxuw(const Xmm& x, const Operand& op) { vpmaxuw(x, x, op); } -void vpminsb(const Xmm& x, const Operand& op) { vpminsb(x, x, op); } -void vpminsd(const Xmm& x, const Operand& op) { vpminsd(x, x, op); } -void vpminsw(const Xmm& x, const Operand& op) { vpminsw(x, x, op); } -void vpminub(const Xmm& x, const Operand& op) { vpminub(x, x, op); } -void vpminud(const Xmm& x, const Operand& op) { vpminud(x, x, op); } -void vpminuw(const Xmm& x, const Operand& op) { vpminuw(x, x, op); } -void vpmuldq(const Xmm& x, const Operand& op) { vpmuldq(x, x, op); } -void vpmulhrsw(const Xmm& x, const Operand& op) { vpmulhrsw(x, x, op); } -void vpmulhuw(const Xmm& x, const Operand& op) { vpmulhuw(x, x, op); } -void vpmulhw(const Xmm& x, const Operand& op) { vpmulhw(x, x, op); } -void vpmulld(const Xmm& x, const Operand& op) { vpmulld(x, x, op); } -void vpmullw(const Xmm& x, const Operand& op) { vpmullw(x, x, op); } -void vpmuludq(const Xmm& x, const Operand& op) { vpmuludq(x, x, op); } -void vpor(const Xmm& x, const Operand& op) { vpor(x, x, op); } -void vpsadbw(const Xmm& x, const Operand& op) { vpsadbw(x, x, op); } -void vpsignb(const Xmm& x, const Operand& op) { vpsignb(x, x, op); } -void vpsignd(const Xmm& x, const Operand& op) { vpsignd(x, x, op); } -void vpsignw(const Xmm& x, const Operand& op) { vpsignw(x, x, op); } -void vpslld(const Xmm& x, const Operand& op) { vpslld(x, x, op); } -void vpslld(const Xmm& x, uint8 imm) { vpslld(x, x, imm); } -void vpslldq(const Xmm& x, uint8 imm) { vpslldq(x, x, imm); } -void vpsllq(const Xmm& x, const Operand& op) { vpsllq(x, x, op); } -void vpsllq(const Xmm& x, uint8 imm) { vpsllq(x, x, imm); } -void vpsllw(const Xmm& x, const Operand& op) { vpsllw(x, x, op); } -void vpsllw(const Xmm& x, uint8 imm) { vpsllw(x, x, imm); } -void vpsrad(const Xmm& x, const Operand& op) { vpsrad(x, x, op); } -void vpsrad(const Xmm& x, uint8 imm) { vpsrad(x, x, imm); } -void vpsraw(const Xmm& x, const Operand& op) { vpsraw(x, x, op); } -void vpsraw(const Xmm& x, uint8 imm) { vpsraw(x, x, imm); } -void vpsrld(const Xmm& x, const Operand& op) { vpsrld(x, x, op); } -void vpsrld(const Xmm& x, uint8 imm) { vpsrld(x, x, imm); } -void vpsrldq(const Xmm& x, uint8 imm) { vpsrldq(x, x, imm); } -void vpsrlq(const Xmm& x, const Operand& op) { vpsrlq(x, x, op); } -void vpsrlq(const Xmm& x, uint8 imm) { vpsrlq(x, x, imm); } -void vpsrlw(const Xmm& x, const Operand& op) { vpsrlw(x, x, op); } -void vpsrlw(const Xmm& x, uint8 imm) { vpsrlw(x, x, imm); } -void vpsubb(const Xmm& x, const Operand& op) { vpsubb(x, x, op); } -void vpsubd(const Xmm& x, const Operand& op) { vpsubd(x, x, op); } -void vpsubq(const 
Xmm& x, const Operand& op) { vpsubq(x, x, op); } -void vpsubsb(const Xmm& x, const Operand& op) { vpsubsb(x, x, op); } -void vpsubsw(const Xmm& x, const Operand& op) { vpsubsw(x, x, op); } -void vpsubusb(const Xmm& x, const Operand& op) { vpsubusb(x, x, op); } -void vpsubusw(const Xmm& x, const Operand& op) { vpsubusw(x, x, op); } -void vpsubw(const Xmm& x, const Operand& op) { vpsubw(x, x, op); } -void vpunpckhbw(const Xmm& x, const Operand& op) { vpunpckhbw(x, x, op); } -void vpunpckhdq(const Xmm& x, const Operand& op) { vpunpckhdq(x, x, op); } -void vpunpckhqdq(const Xmm& x, const Operand& op) { vpunpckhqdq(x, x, op); } -void vpunpckhwd(const Xmm& x, const Operand& op) { vpunpckhwd(x, x, op); } -void vpunpcklbw(const Xmm& x, const Operand& op) { vpunpcklbw(x, x, op); } -void vpunpckldq(const Xmm& x, const Operand& op) { vpunpckldq(x, x, op); } -void vpunpcklqdq(const Xmm& x, const Operand& op) { vpunpcklqdq(x, x, op); } -void vpunpcklwd(const Xmm& x, const Operand& op) { vpunpcklwd(x, x, op); } -void vpxor(const Xmm& x, const Operand& op) { vpxor(x, x, op); } -void vrcpss(const Xmm& x, const Operand& op) { vrcpss(x, x, op); } -void vroundsd(const Xmm& x, const Operand& op, uint8 imm) { vroundsd(x, x, op, imm); } -void vroundss(const Xmm& x, const Operand& op, uint8 imm) { vroundss(x, x, op, imm); } -void vrsqrtss(const Xmm& x, const Operand& op) { vrsqrtss(x, x, op); } -void vshufpd(const Xmm& x, const Operand& op, uint8 imm) { vshufpd(x, x, op, imm); } -void vshufps(const Xmm& x, const Operand& op, uint8 imm) { vshufps(x, x, op, imm); } -void vsqrtsd(const Xmm& x, const Operand& op) { vsqrtsd(x, x, op); } -void vsqrtss(const Xmm& x, const Operand& op) { vsqrtss(x, x, op); } -void vunpckhpd(const Xmm& x, const Operand& op) { vunpckhpd(x, x, op); } -void vunpckhps(const Xmm& x, const Operand& op) { vunpckhps(x, x, op); } -void vunpcklpd(const Xmm& x, const Operand& op) { vunpcklpd(x, x, op); } -void vunpcklps(const Xmm& x, const Operand& op) { vunpcklps(x, x, op); } -#endif -#ifdef XBYAK64 -void jecxz(std::string label) { db(0x67); opJmp(label, T_SHORT, 0xe3, 0, 0); } -void jecxz(const Label& label) { db(0x67); opJmp(label, T_SHORT, 0xe3, 0, 0); } -void jrcxz(std::string label) { opJmp(label, T_SHORT, 0xe3, 0, 0); } -void jrcxz(const Label& label) { opJmp(label, T_SHORT, 0xe3, 0, 0); } -void cdqe() { db(0x48); db(0x98); } -void cqo() { db(0x48); db(0x99); } -void cmpsq() { db(0x48); db(0xA7); } -void movsq() { db(0x48); db(0xA5); } -void scasq() { db(0x48); db(0xAF); } -void stosq() { db(0x48); db(0xAB); } -void cmpxchg16b(const Address& addr) { opModM(addr, Reg64(1), 0x0F, 0xC7); } -void movq(const Reg64& reg, const Mmx& mmx) { if (mmx.isXMM()) db(0x66); opModR(mmx, reg, 0x0F, 0x7E); } -void movq(const Mmx& mmx, const Reg64& reg) { if (mmx.isXMM()) db(0x66); opModR(mmx, reg, 0x0F, 0x6E); } -void movsxd(const Reg64& reg, const Operand& op) { if (!op.isBit(32)) throw Error(ERR_BAD_COMBINATION); opModRM(reg, op, op.isREG(), op.isMEM(), 0x63); } -void pextrq(const Operand& op, const Xmm& xmm, uint8 imm) { if (!op.isREG(64) && !op.isMEM()) throw Error(ERR_BAD_COMBINATION); opGen(Reg64(xmm.getIdx()), op, 0x16, 0x66, 0, imm, 0x3A); } -void pinsrq(const Xmm& xmm, const Operand& op, uint8 imm) { if (!op.isREG(64) && !op.isMEM()) throw Error(ERR_BAD_COMBINATION); opGen(Reg64(xmm.getIdx()), op, 0x22, 0x66, 0, imm, 0x3A); } -void vcvtss2si(const Reg64& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F3 | T_W1 | T_EVEX | T_EW1 | T_ER_X | T_N8, 0x2D); } -void vcvttss2si(const 
Reg64& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F3 | T_W1 | T_EVEX | T_EW1 | T_SAE_X | T_N8, 0x2C); } -void vcvtsd2si(const Reg64& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F2 | T_W1 | T_EVEX | T_EW1 | T_N4 | T_ER_X, 0x2D); } -void vcvttsd2si(const Reg64& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F2 | T_W1 | T_EVEX | T_EW1 | T_N4 | T_SAE_X, 0x2C); } -void vmovq(const Xmm& x, const Reg64& r) { opAVX_X_X_XM(x, xm0, Xmm(r.getIdx()), T_66 | T_0F | T_W1 | T_EVEX | T_EW1, 0x6E); } -void vmovq(const Reg64& r, const Xmm& x) { opAVX_X_X_XM(x, xm0, Xmm(r.getIdx()), T_66 | T_0F | T_W1 | T_EVEX | T_EW1, 0x7E); } -#else -void jcxz(std::string label) { db(0x67); opJmp(label, T_SHORT, 0xe3, 0, 0); } -void jcxz(const Label& label) { db(0x67); opJmp(label, T_SHORT, 0xe3, 0, 0); } -void jecxz(std::string label) { opJmp(label, T_SHORT, 0xe3, 0, 0); } -void jecxz(const Label& label) { opJmp(label, T_SHORT, 0xe3, 0, 0); } -void aaa() { db(0x37); } -void aad() { db(0xD5); db(0x0A); } -void aam() { db(0xD4); db(0x0A); } -void aas() { db(0x3F); } -void daa() { db(0x27); } -void das() { db(0x2F); } -void popad() { db(0x61); } -void popfd() { db(0x9D); } -void pusha() { db(0x60); } -void pushad() { db(0x60); } -void pushfd() { db(0x9C); } -void popa() { db(0x61); } -#endif -#ifndef XBYAK_NO_OP_NAMES -void and(const Operand& op1, const Operand& op2) { and_(op1, op2); } -void and(const Operand& op, uint32 imm) { and_(op, imm); } -void or(const Operand& op1, const Operand& op2) { or_(op1, op2); } -void or(const Operand& op, uint32 imm) { or_(op, imm); } -void xor(const Operand& op1, const Operand& op2) { xor_(op1, op2); } -void xor(const Operand& op, uint32 imm) { xor_(op, imm); } -void not(const Operand& op) { not_(op); } -#endif -#ifndef XBYAK_DISABLE_AVX512 -void kaddb(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x4A); } -void kaddd(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W1, 0x4A); } -void kaddq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x4A); } -void kaddw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x4A); } -void kandb(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x41); } -void kandd(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W1, 0x41); } -void kandnb(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x42); } -void kandnd(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W1, 0x42); } -void kandnq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x42); } -void kandnw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x42); } -void kandq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x41); } -void kandw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x41); } -void kmovb(const Address& addr, const Opmask& k) { opVex(k, 0, addr, T_L0 | T_0F | T_66 | T_W0, 0x91); } -void kmovb(const Opmask& k, const Operand& op) { opVex(k, 0, op, T_L0 | T_0F | T_66 | T_W0, 0x90); } -void kmovb(const Opmask& k, const Reg32& r) { opVex(k, 
0, r, T_L0 | T_0F | T_66 | T_W0, 0x92); } -void kmovb(const Reg32& r, const Opmask& k) { opVex(r, 0, k, T_L0 | T_0F | T_66 | T_W0, 0x93); } -void kmovd(const Address& addr, const Opmask& k) { opVex(k, 0, addr, T_L0 | T_0F | T_66 | T_W1, 0x91); } -void kmovd(const Opmask& k, const Operand& op) { opVex(k, 0, op, T_L0 | T_0F | T_66 | T_W1, 0x90); } -void kmovd(const Opmask& k, const Reg32& r) { opVex(k, 0, r, T_L0 | T_0F | T_F2 | T_W0, 0x92); } -void kmovd(const Reg32& r, const Opmask& k) { opVex(r, 0, k, T_L0 | T_0F | T_F2 | T_W0, 0x93); } -void kmovq(const Address& addr, const Opmask& k) { opVex(k, 0, addr, T_L0 | T_0F | T_W1, 0x91); } -void kmovq(const Opmask& k, const Operand& op) { opVex(k, 0, op, T_L0 | T_0F | T_W1, 0x90); } -void kmovw(const Address& addr, const Opmask& k) { opVex(k, 0, addr, T_L0 | T_0F | T_W0, 0x91); } -void kmovw(const Opmask& k, const Operand& op) { opVex(k, 0, op, T_L0 | T_0F | T_W0, 0x90); } -void kmovw(const Opmask& k, const Reg32& r) { opVex(k, 0, r, T_L0 | T_0F | T_W0, 0x92); } -void kmovw(const Reg32& r, const Opmask& k) { opVex(r, 0, k, T_L0 | T_0F | T_W0, 0x93); } -void knotb(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_66 | T_W0, 0x44); } -void knotd(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_66 | T_W1, 0x44); } -void knotq(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_W1, 0x44); } -void knotw(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_W0, 0x44); } -void korb(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x45); } -void kord(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W1, 0x45); } -void korq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x45); } -void kortestb(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_66 | T_W0, 0x98); } -void kortestd(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_66 | T_W1, 0x98); } -void kortestq(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_W1, 0x98); } -void kortestw(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_W0, 0x98); } -void korw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x45); } -void kshiftlb(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W0, 0x32, imm); } -void kshiftld(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W0, 0x33, imm); } -void kshiftlq(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W1, 0x33, imm); } -void kshiftlw(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W1, 0x32, imm); } -void kshiftrb(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W0, 0x30, imm); } -void kshiftrd(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W0, 0x31, imm); } -void kshiftrq(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W1, 0x31, imm); } -void kshiftrw(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W1, 0x30, imm); } -void ktestb(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_66 | T_W0, 0x99); } -void ktestd(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_66 | T_W1, 0x99); } -void ktestq(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_W1, 
0x99); } -void ktestw(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_W0, 0x99); } -void kunpckbw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x4B); } -void kunpckdq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x4B); } -void kunpckwd(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x4B); } -void kxnorb(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x46); } -void kxnord(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W1, 0x46); } -void kxnorq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x46); } -void kxnorw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x46); } -void kxorb(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x47); } -void kxord(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W1, 0x47); } -void kxorq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x47); } -void kxorw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x47); } -void v4fmaddps(const Zmm& z1, const Zmm& z2, const Address& addr) { opAVX_X_X_XM(z1, z2, addr, T_0F38 | T_F2 | T_EW0 | T_YMM | T_MUST_EVEX | T_N16, 0x9A); } -void v4fmaddss(const Xmm& x1, const Xmm& x2, const Address& addr) { opAVX_X_X_XM(x1, x2, addr, T_0F38 | T_F2 | T_EW0 | T_MUST_EVEX | T_N16, 0x9B); } -void v4fnmaddps(const Zmm& z1, const Zmm& z2, const Address& addr) { opAVX_X_X_XM(z1, z2, addr, T_0F38 | T_F2 | T_EW0 | T_YMM | T_MUST_EVEX | T_N16, 0xAA); } -void v4fnmaddss(const Xmm& x1, const Xmm& x2, const Address& addr) { opAVX_X_X_XM(x1, x2, addr, T_0F38 | T_F2 | T_EW0 | T_MUST_EVEX | T_N16, 0xAB); } -void valignd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x03, imm); } -void valignq(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x03, imm); } -void vblendmpd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x65); } -void vblendmps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x65); } -void vbroadcastf32x2(const Ymm& y, const Operand& op) { opAVX_X_XM_IMM(y, op, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0 | T_N8, 0x19); } -void vbroadcastf32x4(const Ymm& y, const Address& addr) { opAVX_X_XM_IMM(y, addr, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0 | T_N16, 0x1A); } -void vbroadcastf32x8(const Zmm& y, const Address& addr) { opAVX_X_XM_IMM(y, addr, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0 | T_N32, 0x1B); } -void vbroadcastf64x2(const Ymm& y, const Address& addr) { opAVX_X_XM_IMM(y, addr, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW1 | T_N16, 0x1A); } -void vbroadcastf64x4(const Zmm& y, const Address& addr) { opAVX_X_XM_IMM(y, addr, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW1 | T_N32, 0x1B); } -void vbroadcasti32x2(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0 | T_N8, 0x59); } -void 
vbroadcasti32x4(const Ymm& y, const Operand& op) { opAVX_X_XM_IMM(y, op, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0 | T_N16, 0x5A); } -void vbroadcasti32x8(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0 | T_N32, 0x5B); } -void vbroadcasti64x2(const Ymm& y, const Operand& op) { opAVX_X_XM_IMM(y, op, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW1 | T_N16, 0x5A); } -void vbroadcasti64x4(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW1 | T_N32, 0x5B); } -void vcmppd(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0xC2, imm); } -void vcmpps(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_0F | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0xC2, imm); } -void vcmpsd(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_N8 | T_F2 | T_0F | T_EW1 | T_SAE_Z | T_MUST_EVEX, 0xC2, imm); } -void vcmpss(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_N4 | T_F3 | T_0F | T_EW0 | T_SAE_Z | T_MUST_EVEX, 0xC2, imm); } -void vcompressb(const Operand& op, const Xmm& x) { opAVX_X_XM_IMM(x, op, T_N1 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x63); } -void vcompresspd(const Operand& op, const Xmm& x) { opAVX_X_XM_IMM(x, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x8A); } -void vcompressps(const Operand& op, const Xmm& x) { opAVX_X_XM_IMM(x, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x8A); } -void vcompressw(const Operand& op, const Xmm& x) { opAVX_X_XM_IMM(x, op, T_N2 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x63); } -void vcvtpd2qq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F | T_EW1 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B64, 0x7B); } -void vcvtpd2udq(const Xmm& x, const Operand& op) { opCvt2(x, op, T_0F | T_YMM | T_MUST_EVEX | T_EW1 | T_B64 | T_ER_Z, 0x79); } -void vcvtpd2uqq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F | T_EW1 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B64, 0x79); } -void vcvtps2qq(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_66 | T_0F | T_YMM | T_MUST_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL | T_ER_Y, 0x7B); } -void vcvtps2udq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_0F | T_EW0 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B32, 0x79); } -void vcvtps2uqq(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_66 | T_0F | T_YMM | T_MUST_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL | T_ER_Y, 0x79); } -void vcvtqq2pd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F3 | T_0F | T_EW1 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B64, 0xE6); } -void vcvtqq2ps(const Xmm& x, const Operand& op) { opCvt2(x, op, T_0F | T_YMM | T_MUST_EVEX | T_EW1 | T_B64 | T_ER_Z, 0x5B); } -void vcvtsd2usi(const Reg32e& r, const Operand& op) { int type = (T_F2 | T_0F | T_MUST_EVEX | T_N8 | T_ER_X) | (r.isREG(64) ? T_EW1 : T_EW0); opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, type, 0x79); } -void vcvtss2usi(const Reg32e& r, const Operand& op) { int type = (T_F3 | T_0F | T_MUST_EVEX | T_N4 | T_ER_X) | (r.isREG(64) ? 
T_EW1 : T_EW0); opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, type, 0x79); } -void vcvttpd2qq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x7A); } -void vcvttpd2udq(const Xmm& x, const Operand& op) { opCvt2(x, op, T_0F | T_YMM | T_MUST_EVEX | T_EW1 | T_B64 | T_SAE_Z, 0x78); } -void vcvttpd2uqq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x78); } -void vcvttps2qq(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_66 | T_0F | T_YMM | T_MUST_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL | T_SAE_Y, 0x7A); } -void vcvttps2udq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_0F | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x78); } -void vcvttps2uqq(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_66 | T_0F | T_YMM | T_MUST_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL | T_SAE_Y, 0x78); } -void vcvttsd2usi(const Reg32e& r, const Operand& op) { int type = (T_F2 | T_0F | T_MUST_EVEX | T_N8 | T_SAE_X) | (r.isREG(64) ? T_EW1 : T_EW0); opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, type, 0x78); } -void vcvttss2usi(const Reg32e& r, const Operand& op) { int type = (T_F3 | T_0F | T_MUST_EVEX | T_N4 | T_SAE_X) | (r.isREG(64) ? T_EW1 : T_EW0); opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, type, 0x78); } -void vcvtudq2pd(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_F3 | T_0F | T_YMM | T_MUST_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL, 0x7A); } -void vcvtudq2ps(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F2 | T_0F | T_EW0 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B32, 0x7A); } -void vcvtuqq2pd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F3 | T_0F | T_EW1 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B64, 0x7A); } -void vcvtuqq2ps(const Xmm& x, const Operand& op) { opCvt2(x, op, T_F2 | T_0F | T_YMM | T_MUST_EVEX | T_EW1 | T_B64 | T_ER_Z, 0x7A); } -void vcvtusi2sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opCvt3(x1, x2, op, T_F2 | T_0F | T_MUST_EVEX, T_W1 | T_EW1 | T_ER_X | T_N8, T_W0 | T_EW0 | T_N4, 0x7B); } -void vcvtusi2ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opCvt3(x1, x2, op, T_F3 | T_0F | T_MUST_EVEX | T_ER_X, T_W1 | T_EW1 | T_N8, T_W0 | T_EW0 | T_N4, 0x7B); } -void vdbpsadbw(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x42, imm); } -void vexp2pd(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1 | T_B64 | T_SAE_Z, 0xC8); } -void vexp2ps(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0 | T_B32 | T_SAE_Z, 0xC8); } -void vexpandpd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x88); } -void vexpandps(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x88); } -void vextractf32x4(const Operand& op, const Ymm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::XMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N16 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x19, imm); } -void vextractf32x8(const Operand& op, const Zmm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N32 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x1B, imm); } -void vextractf64x2(const Operand& op, const Ymm& r, uint8 imm) { if (!op.is(Operand::MEM 
| Operand::XMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N16 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x19, imm); } -void vextractf64x4(const Operand& op, const Zmm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N32 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x1B, imm); } -void vextracti32x4(const Operand& op, const Ymm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::XMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N16 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x39, imm); } -void vextracti32x8(const Operand& op, const Zmm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N32 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x3B, imm); } -void vextracti64x2(const Operand& op, const Ymm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::XMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N16 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x39, imm); } -void vextracti64x4(const Operand& op, const Zmm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N32 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x3B, imm); } -void vfixupimmpd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x54, imm); } -void vfixupimmps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x54, imm); } -void vfixupimmsd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F3A | T_EW1 | T_SAE_Z | T_MUST_EVEX, 0x55, imm); } -void vfixupimmss(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F3A | T_EW0 | T_SAE_Z | T_MUST_EVEX, 0x55, imm); } -void vfpclasspd(const Opmask& k, const Operand& op, uint8 imm) { if (!op.isBit(128|256|512)) throw Error(ERR_BAD_MEM_SIZE); Reg x = k; x.setBit(op.getBit()); opVex(x, 0, op, T_66 | T_0F3A | T_MUST_EVEX | T_YMM | T_EW1 | T_B64, 0x66, imm); } -void vfpclassps(const Opmask& k, const Operand& op, uint8 imm) { if (!op.isBit(128|256|512)) throw Error(ERR_BAD_MEM_SIZE); Reg x = k; x.setBit(op.getBit()); opVex(x, 0, op, T_66 | T_0F3A | T_MUST_EVEX | T_YMM | T_EW0 | T_B32, 0x66, imm); } -void vfpclasssd(const Opmask& k, const Operand& op, uint8 imm) { if (!op.isXMEM()) throw Error(ERR_BAD_MEM_SIZE); opVex(k, 0, op, T_66 | T_0F3A | T_MUST_EVEX | T_EW1 | T_N8, 0x67, imm); } -void vfpclassss(const Opmask& k, const Operand& op, uint8 imm) { if (!op.isXMEM()) throw Error(ERR_BAD_MEM_SIZE); opVex(k, 0, op, T_66 | T_0F3A | T_MUST_EVEX | T_EW0 | T_N4, 0x67, imm); } -void vgatherdpd(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_VSIB, 0x92, 1); } -void vgatherdps(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_VSIB, 0x92, 0); } -void vgatherpf0dpd(const Address& addr) { opGatherFetch(addr, zm1, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::YMM); } -void vgatherpf0dps(const Address& addr) { opGatherFetch(addr, zm1, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::ZMM); } -void vgatherpf0qpd(const Address& addr) { opGatherFetch(addr, zm1, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 
0xC7, Operand::ZMM); } -void vgatherpf0qps(const Address& addr) { opGatherFetch(addr, zm1, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); } -void vgatherpf1dpd(const Address& addr) { opGatherFetch(addr, zm2, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::YMM); } -void vgatherpf1dps(const Address& addr) { opGatherFetch(addr, zm2, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::ZMM); } -void vgatherpf1qpd(const Address& addr) { opGatherFetch(addr, zm2, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); } -void vgatherpf1qps(const Address& addr) { opGatherFetch(addr, zm2, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); } -void vgatherqpd(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_VSIB, 0x93, 0); } -void vgatherqps(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_VSIB, 0x93, 2); } -void vgetexppd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x42); } -void vgetexpps(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x42); } -void vgetexpsd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_SAE_X | T_MUST_EVEX, 0x43); } -void vgetexpss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_SAE_X | T_MUST_EVEX, 0x43); } -void vgetmantpd(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(x, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x26, imm); } -void vgetmantps(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(x, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x26, imm); } -void vgetmantsd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F3A | T_EW1 | T_SAE_X | T_MUST_EVEX, 0x27, imm); } -void vgetmantss(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F3A | T_EW0 | T_SAE_X | T_MUST_EVEX, 0x27, imm); } -void vinsertf32x4(const Ymm& r1, const Ymm& r2, const Operand& op, uint8 imm) {if (!(r1.getKind() == r2.getKind() && op.is(Operand::MEM | Operand::XMM))) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N16 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x18, imm); } -void vinsertf32x8(const Zmm& r1, const Zmm& r2, const Operand& op, uint8 imm) {if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N32 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x1A, imm); } -void vinsertf64x2(const Ymm& r1, const Ymm& r2, const Operand& op, uint8 imm) {if (!(r1.getKind() == r2.getKind() && op.is(Operand::MEM | Operand::XMM))) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N16 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x18, imm); } -void vinsertf64x4(const Zmm& r1, const Zmm& r2, const Operand& op, uint8 imm) {if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N32 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x1A, imm); } -void vinserti32x4(const Ymm& r1, const Ymm& r2, const Operand& op, uint8 imm) {if (!(r1.getKind() == r2.getKind() && op.is(Operand::MEM | Operand::XMM))) throw 
Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N16 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x38, imm); } -void vinserti32x8(const Zmm& r1, const Zmm& r2, const Operand& op, uint8 imm) {if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N32 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x3A, imm); } -void vinserti64x2(const Ymm& r1, const Ymm& r2, const Operand& op, uint8 imm) {if (!(r1.getKind() == r2.getKind() && op.is(Operand::MEM | Operand::XMM))) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N16 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x38, imm); } -void vinserti64x4(const Zmm& r1, const Zmm& r2, const Operand& op, uint8 imm) {if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N32 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x3A, imm); } -void vmovdqa32(const Address& addr, const Xmm& x) { opAVX_X_XM_IMM(x, addr, T_66 | T_0F | T_EW0 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX | T_M_K, 0x7F); } -void vmovdqa32(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F | T_EW0 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX, 0x6F); } -void vmovdqa64(const Address& addr, const Xmm& x) { opAVX_X_XM_IMM(x, addr, T_66 | T_0F | T_EW1 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX | T_M_K, 0x7F); } -void vmovdqa64(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F | T_EW1 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX, 0x6F); } -void vmovdqu16(const Address& addr, const Xmm& x) { opAVX_X_XM_IMM(x, addr, T_F2 | T_0F | T_EW1 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX | T_M_K, 0x7F); } -void vmovdqu16(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F2 | T_0F | T_EW1 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX, 0x6F); } -void vmovdqu32(const Address& addr, const Xmm& x) { opAVX_X_XM_IMM(x, addr, T_F3 | T_0F | T_EW0 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX | T_M_K, 0x7F); } -void vmovdqu32(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F3 | T_0F | T_EW0 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX, 0x6F); } -void vmovdqu64(const Address& addr, const Xmm& x) { opAVX_X_XM_IMM(x, addr, T_F3 | T_0F | T_EW1 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX | T_M_K, 0x7F); } -void vmovdqu64(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F3 | T_0F | T_EW1 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX, 0x6F); } -void vmovdqu8(const Address& addr, const Xmm& x) { opAVX_X_XM_IMM(x, addr, T_F2 | T_0F | T_EW0 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX | T_M_K, 0x7F); } -void vmovdqu8(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F2 | T_0F | T_EW0 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX, 0x6F); } -void vp4dpwssd(const Zmm& z1, const Zmm& z2, const Address& addr) { opAVX_X_X_XM(z1, z2, addr, T_0F38 | T_F2 | T_EW0 | T_YMM | T_MUST_EVEX | T_N16, 0x52); } -void vp4dpwssds(const Zmm& z1, const Zmm& z2, const Address& addr) { opAVX_X_X_XM(z1, z2, addr, T_0F38 | T_F2 | T_EW0 | T_YMM | T_MUST_EVEX | T_N16, 0x53); } -void vpabsq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_MUST_EVEX | T_EW1 | T_B64 | T_YMM, 0x1F); } -void vpandd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0xDB); } -void vpandnd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0xDF); } -void vpandnq(const Xmm& x1, const Xmm& 
x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0xDF); } -void vpandq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0xDB); } -void vpblendmb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x66); } -void vpblendmd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x64); } -void vpblendmq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x64); } -void vpblendmw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x66); } -void vpbroadcastb(const Xmm& x, const Reg8& r) { opVex(x, 0, r, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x7A); } -void vpbroadcastd(const Xmm& x, const Reg32& r) { opVex(x, 0, r, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x7C); } -void vpbroadcastmb2q(const Xmm& x, const Opmask& k) { opVex(x, 0, k, T_F3 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW1, 0x2A); } -void vpbroadcastmw2d(const Xmm& x, const Opmask& k) { opVex(x, 0, k, T_F3 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0, 0x3A); } -void vpbroadcastw(const Xmm& x, const Reg16& r) { opVex(x, 0, r, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x7B); } -void vpcmpb(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x3F, imm); } -void vpcmpd(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x1F, imm); } -void vpcmpeqb(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F | T_YMM | T_MUST_EVEX, 0x74); } -void vpcmpeqd(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F | T_YMM | T_MUST_EVEX | T_B32, 0x76); } -void vpcmpeqq(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x29); } -void vpcmpeqw(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F | T_YMM | T_MUST_EVEX, 0x75); } -void vpcmpgtb(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F | T_YMM | T_MUST_EVEX, 0x64); } -void vpcmpgtd(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x66); } -void vpcmpgtq(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x37); } -void vpcmpgtw(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F | T_YMM | T_MUST_EVEX, 0x65); } -void vpcmpq(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x1F, imm); } -void vpcmpub(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x3E, imm); } -void vpcmpud(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x1E, imm); } -void vpcmpuq(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW1 | T_YMM | 
T_MUST_EVEX | T_B64, 0x1E, imm); } -void vpcmpuw(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x3E, imm); } -void vpcmpw(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x3F, imm); } -void vpcompressd(const Operand& op, const Xmm& x) { opAVX_X_XM_IMM(x, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x8B); } -void vpcompressq(const Operand& op, const Xmm& x) { opAVX_X_XM_IMM(x, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x8B); } -void vpconflictd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0xC4); } -void vpconflictq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0xC4); } -void vpdpbusd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x50); } -void vpdpbusds(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x51); } -void vpdpwssd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x52); } -void vpdpwssds(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x53); } -void vpermb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x8D); } -void vpermi2b(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x75); } -void vpermi2d(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x76); } -void vpermi2pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x77); } -void vpermi2ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x77); } -void vpermi2q(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x76); } -void vpermi2w(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x75); } -void vpermt2b(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x7D); } -void vpermt2d(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x7E); } -void vpermt2pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x7F); } -void vpermt2ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x7F); } -void vpermt2q(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x7E); } -void vpermt2w(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x7D); } -void vpermw(const Xmm& x1, const Xmm& x2, const 
Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x8D); } -void vpexpandb(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_N1 | T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x62); } -void vpexpandd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x89); } -void vpexpandq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x89); } -void vpexpandw(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_N2 | T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x62); } -void vpgatherdd(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_VSIB, 0x90, 0); } -void vpgatherdq(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_VSIB, 0x90, 1); } -void vpgatherqd(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_VSIB, 0x91, 2); } -void vpgatherqq(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_VSIB, 0x91, 0); } -void vplzcntd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x44); } -void vplzcntq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x44); } -void vpmadd52huq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0xB5); } -void vpmadd52luq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0xB4); } -void vpmaxsq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x3D); } -void vpmaxuq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x3F); } -void vpminsq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x39); } -void vpminuq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x3B); } -void vpmovb2m(const Opmask& k, const Xmm& x) { opVex(k, 0, x, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0, 0x29); } -void vpmovd2m(const Opmask& k, const Xmm& x) { opVex(k, 0, x, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0, 0x39); } -void vpmovdb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N4 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x31, false); } -void vpmovdw(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x33, true); } -void vpmovm2b(const Xmm& x, const Opmask& k) { opVex(x, 0, k, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0, 0x28); } -void vpmovm2d(const Xmm& x, const Opmask& k) { opVex(x, 0, k, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0, 0x38); } -void vpmovm2q(const Xmm& x, const Opmask& k) { opVex(x, 0, k, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1, 0x38); } -void vpmovm2w(const Xmm& x, const Opmask& k) { opVex(x, 0, k, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1, 0x28); } -void vpmovq2m(const Opmask& k, const Xmm& x) { opVex(k, 0, x, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1, 0x39); } -void vpmovqb(const Operand& op, 
const Xmm& x) { opVmov(op, x, T_N2 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x32, false); } -void vpmovqd(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x35, true); } -void vpmovqw(const Operand& op, const Xmm& x) { opVmov(op, x, T_N4 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x34, false); } -void vpmovsdb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N4 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x21, false); } -void vpmovsdw(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x23, true); } -void vpmovsqb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N2 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x22, false); } -void vpmovsqd(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x25, true); } -void vpmovsqw(const Operand& op, const Xmm& x) { opVmov(op, x, T_N4 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x24, false); } -void vpmovswb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x20, true); } -void vpmovusdb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N4 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x11, false); } -void vpmovusdw(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x13, true); } -void vpmovusqb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N2 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x12, false); } -void vpmovusqd(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x15, true); } -void vpmovusqw(const Operand& op, const Xmm& x) { opVmov(op, x, T_N4 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x14, false); } -void vpmovuswb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x10, true); } -void vpmovw2m(const Opmask& k, const Xmm& x) { opVex(k, 0, x, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1, 0x29); } -void vpmovwb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x30, true); } -void vpmullq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x40); } -void vpmultishiftqb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x83); } -void vpopcntb(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x54); } -void vpopcntd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x55); } -void vpopcntq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x55); } -void vpopcntw(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x54); } -void vpord(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0xEB); } -void vporq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0xEB); } -void vprold(const Xmm& x, const Operand& op, uint8 
imm) { opAVX_X_X_XM(Xmm(x.getKind(), 1), x, op, T_66 | T_0F | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x72, imm); } -void vprolq(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 1), x, op, T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x72, imm); } -void vprolvd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x15); } -void vprolvq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x15); } -void vprord(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 0), x, op, T_66 | T_0F | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x72, imm); } -void vprorq(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 0), x, op, T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x72, imm); } -void vprorvd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x14); } -void vprorvq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x14); } -void vpscatterdd(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA0, 0); } -void vpscatterdq(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA0, 1); } -void vpscatterqd(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA1, 2); } -void vpscatterqq(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA1, 0); } -void vpshldd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x71, imm); } -void vpshldq(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x71, imm); } -void vpshldvd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x71); } -void vpshldvq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x71); } -void vpshldvw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x70); } -void vpshldw(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x70, imm); } -void vpshrdd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x73, imm); } -void vpshrdq(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x73, imm); } -void vpshrdvd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x73); } -void vpshrdvq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x73); } -void vpshrdvw(const Xmm& x1, 
const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x72); } -void vpshrdw(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x72, imm); } -void vpshufbitqmb(const Opmask& k, const Xmm& x, const Operand& op) { opVex(k, &x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x8F); } -void vpsllvw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x12); } -void vpsraq(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 4), x, op, T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x72, imm); } -void vpsraq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX, 0xE2); } -void vpsravq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x46); } -void vpsravw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x11); } -void vpsrlvw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x10); } -void vpternlogd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x25, imm); } -void vpternlogq(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x25, imm); } -void vptestmb(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x26); } -void vptestmd(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x27); } -void vptestmq(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x27); } -void vptestmw(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x26); } -void vptestnmb(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x26); } -void vptestnmd(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x27); } -void vptestnmq(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_F3 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x27); } -void vptestnmw(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_F3 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x26); } -void vpxord(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0xEF); } -void vpxorq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0xEF); } -void vrangepd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x50, imm); } -void vrangeps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x50, imm); } -void vrangesd(const Xmm& x1, 
const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F3A | T_EW1 | T_SAE_X | T_MUST_EVEX, 0x51, imm); } -void vrangess(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F3A | T_EW0 | T_SAE_X | T_MUST_EVEX, 0x51, imm); } -void vrcp14pd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x4C); } -void vrcp14ps(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x4C); } -void vrcp14sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX, 0x4D); } -void vrcp14ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX, 0x4D); } -void vrcp28pd(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1 | T_B64 | T_SAE_Z, 0xCA); } -void vrcp28ps(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0 | T_B32 | T_SAE_Z, 0xCA); } -void vrcp28sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_SAE_X | T_MUST_EVEX, 0xCB); } -void vrcp28ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_SAE_X | T_MUST_EVEX, 0xCB); } -void vreducepd(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(x, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x56, imm); } -void vreduceps(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(x, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x56, imm); } -void vreducesd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F3A | T_EW1 | T_SAE_X | T_MUST_EVEX, 0x57, imm); } -void vreducess(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F3A | T_EW0 | T_SAE_X | T_MUST_EVEX, 0x57, imm); } -void vrndscalepd(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(x, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x09, imm); } -void vrndscaleps(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(x, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x08, imm); } -void vrndscalesd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F3A | T_EW1 | T_MUST_EVEX, 0x0B, imm); } -void vrndscaless(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F3A | T_EW0 | T_MUST_EVEX, 0x0A, imm); } -void vrsqrt14pd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x4E); } -void vrsqrt14ps(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x4E); } -void vrsqrt14sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x4F); } -void vrsqrt14ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x4F); } -void vrsqrt28pd(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1 | T_B64 | T_SAE_Z, 0xCC); } -void vrsqrt28ps(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, 
op, T_66 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0 | T_B32 | T_SAE_Z, 0xCC); } -void vrsqrt28sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_SAE_X | T_MUST_EVEX, 0xCD); } -void vrsqrt28ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_SAE_X | T_MUST_EVEX, 0xCD); } -void vscalefpd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B64, 0x2C); } -void vscalefps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B32, 0x2C); } -void vscalefsd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_ER_X | T_MUST_EVEX, 0x2D); } -void vscalefss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_ER_X | T_MUST_EVEX, 0x2D); } -void vscatterdpd(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA2, 1); } -void vscatterdps(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA2, 0); } -void vscatterpf0dpd(const Address& addr) { opGatherFetch(addr, zm5, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::YMM); } -void vscatterpf0dps(const Address& addr) { opGatherFetch(addr, zm5, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::ZMM); } -void vscatterpf0qpd(const Address& addr) { opGatherFetch(addr, zm5, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); } -void vscatterpf0qps(const Address& addr) { opGatherFetch(addr, zm5, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); } -void vscatterpf1dpd(const Address& addr) { opGatherFetch(addr, zm6, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::YMM); } -void vscatterpf1dps(const Address& addr) { opGatherFetch(addr, zm6, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::ZMM); } -void vscatterpf1qpd(const Address& addr) { opGatherFetch(addr, zm6, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); } -void vscatterpf1qps(const Address& addr) { opGatherFetch(addr, zm6, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); } -void vscatterqpd(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA3, 0); } -void vscatterqps(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA3, 2); } -void vshuff32x4(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F3A | T_YMM | T_MUST_EVEX | T_EW0 | T_B32, 0x23, imm); } -void vshuff64x2(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F3A | T_YMM | T_MUST_EVEX | T_EW1 | T_B64, 0x23, imm); } -void vshufi32x4(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F3A | T_YMM | T_MUST_EVEX | T_EW0 | T_B32, 0x43, imm); } -void vshufi64x2(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F3A | T_YMM | T_MUST_EVEX | T_EW1 | T_B64, 0x43, imm); } 
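Each of the AVX-512 mnemonics removed here is a thin wrapper: it forwards the operands plus a mask of encoding flags (T_66, T_0F38, T_EW1, T_MUST_EVEX, ...) and the opcode byte to a shared emitter such as opAVX_X_X_XM or opVex, while xbyak_util.h (deleted below) supplies the runtime CPU-feature checks that decide whether such code may be emitted and executed. A minimal usage sketch, assuming an x86-64 build against the vendored headers; the struct name ZeroCacheLine, the include paths, and main() are illustrative only and are not taken from the deleted files:

#include <cstdio>
#include "xbyak/xbyak.h"       // assumed vendored location of the headers above
#include "xbyak/xbyak_util.h"

// Illustrative sketch: JIT-emit one AVX-512 routine through the wrappers
// declared above (vpxord, vmovdqu32), guarded by the Cpu feature bits from
// xbyak_util.h. Not part of the deleted Godot/OIDN sources.
struct ZeroCacheLine : Xbyak::CodeGenerator {
    ZeroCacheLine() {
        using namespace Xbyak;
        util::StackFrame sf(this, 1);   // one pointer parameter; prologue/epilogue and ABI handled for us
        vpxord(zmm0, zmm0, zmm0);       // zmm0 = 0
        vmovdqu32(ptr[sf.p[0]], zmm0);  // store 64 zero bytes to the argument
        // StackFrame's destructor emits the epilogue and ret
    }
};

int main() {
    Xbyak::util::Cpu cpu;
    if (!cpu.has(Xbyak::util::Cpu::tAVX512F)) {
        std::puts("AVX-512F not available, skipping");
        return 0;
    }
    alignas(64) unsigned char buf[64] = {1};
    ZeroCacheLine code;
    auto fn = code.getCode<void (*)(void *)>();
    fn(buf);
    return buf[0];  // 0 if the generated store ran
}

Xbyak::util::StackFrame, defined further down in the deleted xbyak_util.h, is what lets the same sketch map the first parameter to rdi on System V and rcx on Win64 without writing the prologue by hand.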
-#ifdef XBYAK64 -void kmovq(const Opmask& k, const Reg64& r) { opVex(k, 0, r, T_L0 | T_0F | T_F2 | T_W1, 0x92); } -void kmovq(const Reg64& r, const Opmask& k) { opVex(r, 0, k, T_L0 | T_0F | T_F2 | T_W1, 0x93); } -void vpbroadcastq(const Xmm& x, const Reg64& r) { opVex(x, 0, r, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x7C); } -#endif -#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_util.h b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_util.h deleted file mode 100644 index 8ef076e68..000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_util.h +++ /dev/null @@ -1,772 +0,0 @@ -/******************************************************************************* -* Copyright 2016-2019 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -/******************************************************************************* -* Copyright (c) 2007 MITSUNARI Shigeo -* All rights reserved. -* -* Redistribution and use in source and binary forms, with or without -* modification, are permitted provided that the following conditions are met: -* -* Redistributions of source code must retain the above copyright notice, this -* list of conditions and the following disclaimer. -* Redistributions in binary form must reproduce the above copyright notice, -* this list of conditions and the following disclaimer in the documentation -* and/or other materials provided with the distribution. -* Neither the name of the copyright owner nor the names of its contributors may -* be used to endorse or promote products derived from this software without -* specific prior written permission. -* -* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE -* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF -* THE POSSIBILITY OF SUCH DAMAGE. -*******************************************************************************/ - -#ifndef XBYAK_XBYAK_UTIL_H_ -#define XBYAK_XBYAK_UTIL_H_ - -/** - utility class and functions for Xbyak - Xbyak::util::Clock ; rdtsc timer - Xbyak::util::Cpu ; detect CPU - @note this header is UNDER CONSTRUCTION! 
-*/ -#include "xbyak.h" - -#if defined(__i386__) || defined(__x86_64__) || defined(_M_IX86) || defined(_M_X64) - #define XBYAK_INTEL_CPU_SPECIFIC -#endif - -#ifdef XBYAK_INTEL_CPU_SPECIFIC -#ifdef _MSC_VER - #if (_MSC_VER < 1400) && defined(XBYAK32) - static inline __declspec(naked) void __cpuid(int[4], int) - { - __asm { - push ebx - push esi - mov eax, dword ptr [esp + 4 * 2 + 8] // eaxIn - cpuid - mov esi, dword ptr [esp + 4 * 2 + 4] // data - mov dword ptr [esi], eax - mov dword ptr [esi + 4], ebx - mov dword ptr [esi + 8], ecx - mov dword ptr [esi + 12], edx - pop esi - pop ebx - ret - } - } - #else - #include // for __cpuid - #endif -#else - #ifndef __GNUC_PREREQ - #define __GNUC_PREREQ(major, minor) ((((__GNUC__) << 16) + (__GNUC_MINOR__)) >= (((major) << 16) + (minor))) - #endif - #if __GNUC_PREREQ(4, 3) && !defined(__APPLE__) - #include - #else - #if defined(__APPLE__) && defined(XBYAK32) // avoid err : can't find a register in class `BREG' while reloading `asm' - #define __cpuid(eaxIn, a, b, c, d) __asm__ __volatile__("pushl %%ebx\ncpuid\nmovl %%ebp, %%esi\npopl %%ebx" : "=a"(a), "=S"(b), "=c"(c), "=d"(d) : "0"(eaxIn)) - #define __cpuid_count(eaxIn, ecxIn, a, b, c, d) __asm__ __volatile__("pushl %%ebx\ncpuid\nmovl %%ebp, %%esi\npopl %%ebx" : "=a"(a), "=S"(b), "=c"(c), "=d"(d) : "0"(eaxIn), "2"(ecxIn)) - #else - #define __cpuid(eaxIn, a, b, c, d) __asm__ __volatile__("cpuid\n" : "=a"(a), "=b"(b), "=c"(c), "=d"(d) : "0"(eaxIn)) - #define __cpuid_count(eaxIn, ecxIn, a, b, c, d) __asm__ __volatile__("cpuid\n" : "=a"(a), "=b"(b), "=c"(c), "=d"(d) : "0"(eaxIn), "2"(ecxIn)) - #endif - #endif -#endif -#endif - -namespace Xbyak { namespace util { - -typedef enum { - SmtLevel = 1, - CoreLevel = 2 -} IntelCpuTopologyLevel; - -/** - CPU detection class -*/ -class Cpu { - uint64 type_; - //system topology - bool x2APIC_supported_; - static const size_t maxTopologyLevels = 2; - unsigned int numCores_[maxTopologyLevels]; - - static const unsigned int maxNumberCacheLevels = 10; - unsigned int dataCacheSize_[maxNumberCacheLevels]; - unsigned int coresSharignDataCache_[maxNumberCacheLevels]; - unsigned int dataCacheLevels_; - - unsigned int get32bitAsBE(const char *x) const - { - return x[0] | (x[1] << 8) | (x[2] << 16) | (x[3] << 24); - } - unsigned int mask(int n) const - { - return (1U << n) - 1; - } - void setFamily() - { - unsigned int data[4] = {}; - getCpuid(1, data); - stepping = data[0] & mask(4); - model = (data[0] >> 4) & mask(4); - family = (data[0] >> 8) & mask(4); - // type = (data[0] >> 12) & mask(2); - extModel = (data[0] >> 16) & mask(4); - extFamily = (data[0] >> 20) & mask(8); - if (family == 0x0f) { - displayFamily = family + extFamily; - } else { - displayFamily = family; - } - if (family == 6 || family == 0x0f) { - displayModel = (extModel << 4) + model; - } else { - displayModel = model; - } - } - unsigned int extractBit(unsigned int val, unsigned int base, unsigned int end) - { - return (val >> base) & ((1u << (end - base)) - 1); - } - void setNumCores() - { - if ((type_ & tINTEL) == 0) return; - - unsigned int data[4] = {}; - - /* CAUTION: These numbers are configuration as shipped by Intel. 
*/ - getCpuidEx(0x0, 0, data); - if (data[0] >= 0xB) { - /* - if leaf 11 exists(x2APIC is supported), - we use it to get the number of smt cores and cores on socket - - leaf 0xB can be zeroed-out by a hypervisor - */ - x2APIC_supported_ = true; - for (unsigned int i = 0; i < maxTopologyLevels; i++) { - getCpuidEx(0xB, i, data); - IntelCpuTopologyLevel level = (IntelCpuTopologyLevel)extractBit(data[2], 8, 15); - if (level == SmtLevel || level == CoreLevel) { - numCores_[level - 1] = extractBit(data[1], 0, 15); - } - } - } else { - /* - Failed to deremine num of cores without x2APIC support. - TODO: USE initial APIC ID to determine ncores. - */ - numCores_[SmtLevel - 1] = 0; - numCores_[CoreLevel - 1] = 0; - } - - } - void setCacheHierarchy() - { - if ((type_ & tINTEL) == 0) return; - const unsigned int NO_CACHE = 0; - const unsigned int DATA_CACHE = 1; -// const unsigned int INSTRUCTION_CACHE = 2; - const unsigned int UNIFIED_CACHE = 3; - unsigned int smt_width = 0; - unsigned int logical_cores = 0; - unsigned int data[4] = {}; - - if (x2APIC_supported_) { - smt_width = numCores_[0]; - logical_cores = numCores_[1]; - } - - /* - Assumptions: - the first level of data cache is not shared (which is the - case for every existing architecture) and use this to - determine the SMT width for arch not supporting leaf 11. - when leaf 4 reports a number of core less than numCores_ - on socket reported by leaf 11, then it is a correct number - of cores not an upperbound. - */ - for (int i = 0; dataCacheLevels_ < maxNumberCacheLevels; i++) { - getCpuidEx(0x4, i, data); - unsigned int cacheType = extractBit(data[0], 0, 4); - if (cacheType == NO_CACHE) break; - if (cacheType == DATA_CACHE || cacheType == UNIFIED_CACHE) { - unsigned int actual_logical_cores = extractBit(data[0], 14, 25) + 1; - if (logical_cores != 0) { // true only if leaf 0xB is supported and valid - actual_logical_cores = (std::min)(actual_logical_cores, logical_cores); - } - assert(actual_logical_cores != 0); - dataCacheSize_[dataCacheLevels_] = - (extractBit(data[1], 22, 31) + 1) - * (extractBit(data[1], 12, 21) + 1) - * (extractBit(data[1], 0, 11) + 1) - * (data[2] + 1); - if (cacheType == DATA_CACHE && smt_width == 0) smt_width = actual_logical_cores; - assert(smt_width != 0); - // FIXME: check and fix number of cores sharing L3 cache for different configurations - // (HT-, 2 sockets), (HT-, 1 socket), (HT+, 2 sockets), (HT+, 1 socket) - coresSharignDataCache_[dataCacheLevels_] = (std::max)(actual_logical_cores / smt_width, 1u); - dataCacheLevels_++; - } - } - } - -public: - int model; - int family; - int stepping; - int extModel; - int extFamily; - int displayFamily; // family + extFamily - int displayModel; // model + extModel - - unsigned int getNumCores(IntelCpuTopologyLevel level) { - if (level != SmtLevel && level != CoreLevel) throw Error(ERR_BAD_PARAMETER); - if (!x2APIC_supported_) throw Error(ERR_X2APIC_IS_NOT_SUPPORTED); - return (level == CoreLevel) - ? 
numCores_[level - 1] / numCores_[SmtLevel - 1] : numCores_[level - 1]; - } - - unsigned int getDataCacheLevels() const { return dataCacheLevels_; } - unsigned int getCoresSharingDataCache(unsigned int i) const - { - if (i >= dataCacheLevels_) throw Error(ERR_BAD_PARAMETER); - return coresSharignDataCache_[i]; - } - unsigned int getDataCacheSize(unsigned int i) const - { - if (i >= dataCacheLevels_) throw Error(ERR_BAD_PARAMETER); - return dataCacheSize_[i]; - } - - /* - data[] = { eax, ebx, ecx, edx } - */ - static inline void getCpuid(unsigned int eaxIn, unsigned int data[4]) - { -#ifdef XBYAK_INTEL_CPU_SPECIFIC - #ifdef _MSC_VER - __cpuid(reinterpret_cast<int*>(data), eaxIn); - #else - __cpuid(eaxIn, data[0], data[1], data[2], data[3]); - #endif -#else - (void)eaxIn; - (void)data; -#endif - } - static inline void getCpuidEx(unsigned int eaxIn, unsigned int ecxIn, unsigned int data[4]) - { -#ifdef XBYAK_INTEL_CPU_SPECIFIC - #ifdef _MSC_VER - __cpuidex(reinterpret_cast<int*>(data), eaxIn, ecxIn); - #else - __cpuid_count(eaxIn, ecxIn, data[0], data[1], data[2], data[3]); - #endif -#else - (void)eaxIn; - (void)ecxIn; - (void)data; -#endif - } - static inline uint64 getXfeature() - { -#ifdef XBYAK_INTEL_CPU_SPECIFIC - #ifdef _MSC_VER - return _xgetbv(0); - #else - unsigned int eax, edx; - // xgetvb is not support on gcc 4.2 -// __asm__ volatile("xgetbv" : "=a"(eax), "=d"(edx) : "c"(0)); - __asm__ volatile(".byte 0x0f, 0x01, 0xd0" : "=a"(eax), "=d"(edx) : "c"(0)); - return ((uint64)edx << 32) | eax; - #endif -#else - return 0; -#endif - } - typedef uint64 Type; - - static const Type NONE = 0; - static const Type tMMX = 1 << 0; - static const Type tMMX2 = 1 << 1; - static const Type tCMOV = 1 << 2; - static const Type tSSE = 1 << 3; - static const Type tSSE2 = 1 << 4; - static const Type tSSE3 = 1 << 5; - static const Type tSSSE3 = 1 << 6; - static const Type tSSE41 = 1 << 7; - static const Type tSSE42 = 1 << 8; - static const Type tPOPCNT = 1 << 9; - static const Type tAESNI = 1 << 10; - static const Type tSSE5 = 1 << 11; - static const Type tOSXSAVE = 1 << 12; - static const Type tPCLMULQDQ = 1 << 13; - static const Type tAVX = 1 << 14; - static const Type tFMA = 1 << 15; - - static const Type t3DN = 1 << 16; - static const Type tE3DN = 1 << 17; - static const Type tSSE4a = 1 << 18; - static const Type tRDTSCP = 1 << 19; - static const Type tAVX2 = 1 << 20; - static const Type tBMI1 = 1 << 21; // andn, bextr, blsi, blsmsk, blsr, tzcnt - static const Type tBMI2 = 1 << 22; // bzhi, mulx, pdep, pext, rorx, sarx, shlx, shrx - static const Type tLZCNT = 1 << 23; - - static const Type tINTEL = 1 << 24; - static const Type tAMD = 1 << 25; - - static const Type tENHANCED_REP = 1 << 26; // enhanced rep movsb/stosb - static const Type tRDRAND = 1 << 27; - static const Type tADX = 1 << 28; // adcx, adox - static const Type tRDSEED = 1 << 29; // rdseed - static const Type tSMAP = 1 << 30; // stac - static const Type tHLE = uint64(1) << 31; // xacquire, xrelease, xtest - static const Type tRTM = uint64(1) << 32; // xbegin, xend, xabort - static const Type tF16C = uint64(1) << 33; // vcvtph2ps, vcvtps2ph - static const Type tMOVBE = uint64(1) << 34; // mobve - static const Type tAVX512F = uint64(1) << 35; - static const Type tAVX512DQ = uint64(1) << 36; - static const Type tAVX512_IFMA = uint64(1) << 37; - static const Type tAVX512IFMA = tAVX512_IFMA; - static const Type tAVX512PF = uint64(1) << 38; - static const Type tAVX512ER = uint64(1) << 39; - static const Type tAVX512CD = uint64(1) << 40; - static const Type 
tAVX512BW = uint64(1) << 41; - static const Type tAVX512VL = uint64(1) << 42; - static const Type tAVX512_VBMI = uint64(1) << 43; - static const Type tAVX512VBMI = tAVX512_VBMI; // changed by Intel's manual - static const Type tAVX512_4VNNIW = uint64(1) << 44; - static const Type tAVX512_4FMAPS = uint64(1) << 45; - static const Type tPREFETCHWT1 = uint64(1) << 46; - static const Type tPREFETCHW = uint64(1) << 47; - static const Type tSHA = uint64(1) << 48; - static const Type tMPX = uint64(1) << 49; - static const Type tAVX512_VBMI2 = uint64(1) << 50; - static const Type tGFNI = uint64(1) << 51; - static const Type tVAES = uint64(1) << 52; - static const Type tVPCLMULQDQ = uint64(1) << 53; - static const Type tAVX512_VNNI = uint64(1) << 54; - static const Type tAVX512_BITALG = uint64(1) << 55; - static const Type tAVX512_VPOPCNTDQ = uint64(1) << 56; - - Cpu() - : type_(NONE) - , x2APIC_supported_(false) - , numCores_() - , dataCacheSize_() - , coresSharignDataCache_() - , dataCacheLevels_(0) - { - unsigned int data[4] = {}; - const unsigned int& EAX = data[0]; - const unsigned int& EBX = data[1]; - const unsigned int& ECX = data[2]; - const unsigned int& EDX = data[3]; - getCpuid(0, data); - const unsigned int maxNum = EAX; - static const char intel[] = "ntel"; - static const char amd[] = "cAMD"; - if (ECX == get32bitAsBE(amd)) { - type_ |= tAMD; - getCpuid(0x80000001, data); - if (EDX & (1U << 31)) type_ |= t3DN; - if (EDX & (1U << 15)) type_ |= tCMOV; - if (EDX & (1U << 30)) type_ |= tE3DN; - if (EDX & (1U << 22)) type_ |= tMMX2; - if (EDX & (1U << 27)) type_ |= tRDTSCP; - } - if (ECX == get32bitAsBE(intel)) { - type_ |= tINTEL; - getCpuid(0x80000001, data); - if (EDX & (1U << 27)) type_ |= tRDTSCP; - if (ECX & (1U << 5)) type_ |= tLZCNT; - if (ECX & (1U << 8)) type_ |= tPREFETCHW; - } - getCpuid(1, data); - if (ECX & (1U << 0)) type_ |= tSSE3; - if (ECX & (1U << 9)) type_ |= tSSSE3; - if (ECX & (1U << 19)) type_ |= tSSE41; - if (ECX & (1U << 20)) type_ |= tSSE42; - if (ECX & (1U << 22)) type_ |= tMOVBE; - if (ECX & (1U << 23)) type_ |= tPOPCNT; - if (ECX & (1U << 25)) type_ |= tAESNI; - if (ECX & (1U << 1)) type_ |= tPCLMULQDQ; - if (ECX & (1U << 27)) type_ |= tOSXSAVE; - if (ECX & (1U << 30)) type_ |= tRDRAND; - if (ECX & (1U << 29)) type_ |= tF16C; - - if (EDX & (1U << 15)) type_ |= tCMOV; - if (EDX & (1U << 23)) type_ |= tMMX; - if (EDX & (1U << 25)) type_ |= tMMX2 | tSSE; - if (EDX & (1U << 26)) type_ |= tSSE2; - - if (type_ & tOSXSAVE) { - // check XFEATURE_ENABLED_MASK[2:1] = '11b' - uint64 bv = getXfeature(); - if ((bv & 6) == 6) { - if (ECX & (1U << 28)) type_ |= tAVX; - if (ECX & (1U << 12)) type_ |= tFMA; - if (((bv >> 5) & 7) == 7) { - getCpuidEx(7, 0, data); - if (EBX & (1U << 16)) type_ |= tAVX512F; - if (type_ & tAVX512F) { - if (EBX & (1U << 17)) type_ |= tAVX512DQ; - if (EBX & (1U << 21)) type_ |= tAVX512_IFMA; - if (EBX & (1U << 26)) type_ |= tAVX512PF; - if (EBX & (1U << 27)) type_ |= tAVX512ER; - if (EBX & (1U << 28)) type_ |= tAVX512CD; - if (EBX & (1U << 30)) type_ |= tAVX512BW; - if (EBX & (1U << 31)) type_ |= tAVX512VL; - if (ECX & (1U << 1)) type_ |= tAVX512_VBMI; - if (ECX & (1U << 6)) type_ |= tAVX512_VBMI2; - if (ECX & (1U << 8)) type_ |= tGFNI; - if (ECX & (1U << 9)) type_ |= tVAES; - if (ECX & (1U << 10)) type_ |= tVPCLMULQDQ; - if (ECX & (1U << 11)) type_ |= tAVX512_VNNI; - if (ECX & (1U << 12)) type_ |= tAVX512_BITALG; - if (ECX & (1U << 14)) type_ |= tAVX512_VPOPCNTDQ; - if (EDX & (1U << 2)) type_ |= tAVX512_4VNNIW; - if (EDX & (1U << 3)) type_ |= 
tAVX512_4FMAPS; - } - } - } - } - if (maxNum >= 7) { - getCpuidEx(7, 0, data); - if (type_ & tAVX && (EBX & (1U << 5))) type_ |= tAVX2; - if (EBX & (1U << 3)) type_ |= tBMI1; - if (EBX & (1U << 8)) type_ |= tBMI2; - if (EBX & (1U << 9)) type_ |= tENHANCED_REP; - if (EBX & (1U << 18)) type_ |= tRDSEED; - if (EBX & (1U << 19)) type_ |= tADX; - if (EBX & (1U << 20)) type_ |= tSMAP; - if (EBX & (1U << 4)) type_ |= tHLE; - if (EBX & (1U << 11)) type_ |= tRTM; - if (EBX & (1U << 14)) type_ |= tMPX; - if (EBX & (1U << 29)) type_ |= tSHA; - if (ECX & (1U << 0)) type_ |= tPREFETCHWT1; - } - setFamily(); - setNumCores(); - setCacheHierarchy(); - } - void putFamily() const - { - printf("family=%d, model=%X, stepping=%d, extFamily=%d, extModel=%X\n", - family, model, stepping, extFamily, extModel); - printf("display:family=%X, model=%X\n", displayFamily, displayModel); - } - bool has(Type type) const - { - return (type & type_) != 0; - } -}; - -class Clock { -public: - static inline uint64 getRdtsc() - { -#ifdef XBYAK_INTEL_CPU_SPECIFIC - #ifdef _MSC_VER - return __rdtsc(); - #else - unsigned int eax, edx; - __asm__ volatile("rdtsc" : "=a"(eax), "=d"(edx)); - return ((uint64)edx << 32) | eax; - #endif -#else - // TODO: Need another impl of Clock or rdtsc-equivalent for non-x86 cpu - return 0; -#endif - } - Clock() - : clock_(0) - , count_(0) - { - } - void begin() - { - clock_ -= getRdtsc(); - } - void end() - { - clock_ += getRdtsc(); - count_++; - } - int getCount() const { return count_; } - uint64 getClock() const { return clock_; } - void clear() { count_ = 0; clock_ = 0; } -private: - uint64 clock_; - int count_; -}; - -#ifdef XBYAK64 -const int UseRCX = 1 << 6; -const int UseRDX = 1 << 7; - -class Pack { - static const size_t maxTblNum = 15; - const Xbyak::Reg64 *tbl_[maxTblNum]; - size_t n_; -public: - Pack() : tbl_(), n_(0) {} - Pack(const Xbyak::Reg64 *tbl, size_t n) { init(tbl, n); } - Pack(const Pack& rhs) - : n_(rhs.n_) - { - for (size_t i = 0; i < n_; i++) tbl_[i] = rhs.tbl_[i]; - } - Pack& operator=(const Pack& rhs) - { - n_ = rhs.n_; - for (size_t i = 0; i < n_; i++) tbl_[i] = rhs.tbl_[i]; - return *this; - } - Pack(const Xbyak::Reg64& t0) - { n_ = 1; tbl_[0] = &t0; } - Pack(const Xbyak::Reg64& t1, const Xbyak::Reg64& t0) - { n_ = 2; tbl_[0] = &t0; tbl_[1] = &t1; } - Pack(const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0) - { n_ = 3; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; } - Pack(const Xbyak::Reg64& t3, const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0) - { n_ = 4; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; tbl_[3] = &t3; } - Pack(const Xbyak::Reg64& t4, const Xbyak::Reg64& t3, const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0) - { n_ = 5; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; tbl_[3] = &t3; tbl_[4] = &t4; } - Pack(const Xbyak::Reg64& t5, const Xbyak::Reg64& t4, const Xbyak::Reg64& t3, const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0) - { n_ = 6; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; tbl_[3] = &t3; tbl_[4] = &t4; tbl_[5] = &t5; } - Pack(const Xbyak::Reg64& t6, const Xbyak::Reg64& t5, const Xbyak::Reg64& t4, const Xbyak::Reg64& t3, const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0) - { n_ = 7; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; tbl_[3] = &t3; tbl_[4] = &t4; tbl_[5] = &t5; tbl_[6] = &t6; } - Pack(const Xbyak::Reg64& t7, const Xbyak::Reg64& t6, const Xbyak::Reg64& t5, const Xbyak::Reg64& t4, const Xbyak::Reg64& t3, const Xbyak::Reg64& t2, const 
Xbyak::Reg64& t1, const Xbyak::Reg64& t0) - { n_ = 8; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; tbl_[3] = &t3; tbl_[4] = &t4; tbl_[5] = &t5; tbl_[6] = &t6; tbl_[7] = &t7; } - Pack(const Xbyak::Reg64& t8, const Xbyak::Reg64& t7, const Xbyak::Reg64& t6, const Xbyak::Reg64& t5, const Xbyak::Reg64& t4, const Xbyak::Reg64& t3, const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0) - { n_ = 9; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; tbl_[3] = &t3; tbl_[4] = &t4; tbl_[5] = &t5; tbl_[6] = &t6; tbl_[7] = &t7; tbl_[8] = &t8; } - Pack(const Xbyak::Reg64& t9, const Xbyak::Reg64& t8, const Xbyak::Reg64& t7, const Xbyak::Reg64& t6, const Xbyak::Reg64& t5, const Xbyak::Reg64& t4, const Xbyak::Reg64& t3, const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0) - { n_ = 10; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; tbl_[3] = &t3; tbl_[4] = &t4; tbl_[5] = &t5; tbl_[6] = &t6; tbl_[7] = &t7; tbl_[8] = &t8; tbl_[9] = &t9; } - Pack& append(const Xbyak::Reg64& t) - { - if (n_ == maxTblNum) { - fprintf(stderr, "ERR Pack::can't append\n"); - throw Error(ERR_BAD_PARAMETER); - } - tbl_[n_++] = &t; - return *this; - } - void init(const Xbyak::Reg64 *tbl, size_t n) - { - if (n > maxTblNum) { - fprintf(stderr, "ERR Pack::init bad n=%d\n", (int)n); - throw Error(ERR_BAD_PARAMETER); - } - n_ = n; - for (size_t i = 0; i < n; i++) { - tbl_[i] = &tbl[i]; - } - } - const Xbyak::Reg64& operator[](size_t n) const - { - if (n >= n_) { - fprintf(stderr, "ERR Pack bad n=%d(%d)\n", (int)n, (int)n_); - throw Error(ERR_BAD_PARAMETER); - } - return *tbl_[n]; - } - size_t size() const { return n_; } - /* - get tbl[pos, pos + num) - */ - Pack sub(size_t pos, size_t num = size_t(-1)) const - { - if (num == size_t(-1)) num = n_ - pos; - if (pos + num > n_) { - fprintf(stderr, "ERR Pack::sub bad pos=%d, num=%d\n", (int)pos, (int)num); - throw Error(ERR_BAD_PARAMETER); - } - Pack pack; - pack.n_ = num; - for (size_t i = 0; i < num; i++) { - pack.tbl_[i] = tbl_[pos + i]; - } - return pack; - } - void put() const - { - for (size_t i = 0; i < n_; i++) { - printf("%s ", tbl_[i]->toString()); - } - printf("\n"); - } -}; - -class StackFrame { -#ifdef XBYAK64_WIN - static const int noSaveNum = 6; - static const int rcxPos = 0; - static const int rdxPos = 1; -#else - static const int noSaveNum = 8; - static const int rcxPos = 3; - static const int rdxPos = 2; -#endif - static const int maxRegNum = 14; // maxRegNum = 16 - rsp - rax - Xbyak::CodeGenerator *code_; - int pNum_; - int tNum_; - bool useRcx_; - bool useRdx_; - int saveNum_; - int P_; - bool makeEpilog_; - Xbyak::Reg64 pTbl_[4]; - Xbyak::Reg64 tTbl_[maxRegNum]; - Pack p_; - Pack t_; - StackFrame(const StackFrame&); - void operator=(const StackFrame&); -public: - const Pack& p; - const Pack& t; - /* - make stack frame - @param sf [in] this - @param pNum [in] num of function parameter(0 <= pNum <= 4) - @param tNum [in] num of temporary register(0 <= tNum, with UseRCX, UseRDX) #{pNum + tNum [+rcx] + [rdx]} <= 14 - @param stackSizeByte [in] local stack size - @param makeEpilog [in] automatically call close() if true - - you can use - rax - gp0, ..., gp(pNum - 1) - gt0, ..., gt(tNum-1) - rcx if tNum & UseRCX - rdx if tNum & UseRDX - rsp[0..stackSizeByte - 1] - */ - StackFrame(Xbyak::CodeGenerator *code, int pNum, int tNum = 0, int stackSizeByte = 0, bool makeEpilog = true) - : code_(code) - , pNum_(pNum) - , tNum_(tNum & ~(UseRCX | UseRDX)) - , useRcx_((tNum & UseRCX) != 0) - , useRdx_((tNum & UseRDX) != 0) - , saveNum_(0) - , P_(0) - , 
makeEpilog_(makeEpilog) - , p(p_) - , t(t_) - { - using namespace Xbyak; - if (pNum < 0 || pNum > 4) throw Error(ERR_BAD_PNUM); - const int allRegNum = pNum + tNum_ + (useRcx_ ? 1 : 0) + (useRdx_ ? 1 : 0); - if (tNum_ < 0 || allRegNum > maxRegNum) throw Error(ERR_BAD_TNUM); - const Reg64& _rsp = code->rsp; - saveNum_ = (std::max)(0, allRegNum - noSaveNum); - const int *tbl = getOrderTbl() + noSaveNum; - for (int i = 0; i < saveNum_; i++) { - code->push(Reg64(tbl[i])); - } - P_ = (stackSizeByte + 7) / 8; - if (P_ > 0 && (P_ & 1) == (saveNum_ & 1)) P_++; // (rsp % 16) == 8, then increment P_ for 16 byte alignment - P_ *= 8; - if (P_ > 0) code->sub(_rsp, P_); - int pos = 0; - for (int i = 0; i < pNum; i++) { - pTbl_[i] = Xbyak::Reg64(getRegIdx(pos)); - } - for (int i = 0; i < tNum_; i++) { - tTbl_[i] = Xbyak::Reg64(getRegIdx(pos)); - } - if (useRcx_ && rcxPos < pNum) code_->mov(code_->r10, code_->rcx); - if (useRdx_ && rdxPos < pNum) code_->mov(code_->r11, code_->rdx); - p_.init(pTbl_, pNum); - t_.init(tTbl_, tNum_); - } - /* - make epilog manually - @param callRet [in] call ret() if true - */ - void close(bool callRet = true) - { - using namespace Xbyak; - const Reg64& _rsp = code_->rsp; - const int *tbl = getOrderTbl() + noSaveNum; - if (P_ > 0) code_->add(_rsp, P_); - for (int i = 0; i < saveNum_; i++) { - code_->pop(Reg64(tbl[saveNum_ - 1 - i])); - } - - if (callRet) code_->ret(); - } - ~StackFrame() - { - if (!makeEpilog_) return; - try { - close(); - } catch (std::exception& e) { - printf("ERR:StackFrame %s\n", e.what()); - //exit(1); - } - } -private: - const int *getOrderTbl() const - { - using namespace Xbyak; - static const int tbl[] = { -#ifdef XBYAK64_WIN - Operand::RCX, Operand::RDX, Operand::R8, Operand::R9, Operand::R10, Operand::R11, Operand::RDI, Operand::RSI, -#else - Operand::RDI, Operand::RSI, Operand::RDX, Operand::RCX, Operand::R8, Operand::R9, Operand::R10, Operand::R11, -#endif - Operand::RBX, Operand::RBP, Operand::R12, Operand::R13, Operand::R14, Operand::R15 - }; - return &tbl[0]; - } - int getRegIdx(int& pos) const - { - assert(pos < maxRegNum); - using namespace Xbyak; - const int *tbl = getOrderTbl(); - int r = tbl[pos++]; - if (useRcx_) { - if (r == Operand::RCX) { return Operand::R10; } - if (r == Operand::R10) { r = tbl[pos++]; } - } - if (useRdx_) { - if (r == Operand::RDX) { return Operand::R11; } - if (r == Operand::R11) { return tbl[pos++]; } - } - return r; - } -}; -#endif - -} } // end of util -#endif diff --git a/thirdparty/oidn/patches/godot-changes-c58c5216.patch b/thirdparty/oidn/patches/godot-changes-c58c5216.patch deleted file mode 100644 index c01f00187..000000000 --- a/thirdparty/oidn/patches/godot-changes-c58c5216.patch +++ /dev/null @@ -1,337 +0,0 @@ -diff --git a/common/platform.h b/common/platform.h -index be14bc7..9373b61 100644 ---- a/common/platform.h -+++ b/common/platform.h -@@ -19,7 +19,7 @@ - #if defined(_WIN32) - #define WIN32_LEAN_AND_MEAN - #define NOMINMAX -- #include -+ #include - #elif defined(__APPLE__) - #include - #endif -@@ -129,4 +129,3 @@ namespace oidn { - std::string getBuildName(); - - } // namespace oidn -- -diff --git a/core/autoencoder.cpp b/core/autoencoder.cpp -index d6915e6..d8da684 100644 ---- a/core/autoencoder.cpp -+++ b/core/autoencoder.cpp -@@ -90,13 +90,19 @@ namespace oidn { - if (!dirty) - return; - -- device->executeTask([&]() -- { -+ // -- GODOT start -- -+ //device->executeTask([&]() -+ //{ -+ // GODOT end -- -+ - if (mayiuse(avx512_common)) - net = buildNet<16>(); - else - net = buildNet<8>(); -- }); -+ 
-+ // GODOT start -- -+ //}); -+ // GODOT end -- - - dirty = false; - } -@@ -108,9 +114,10 @@ namespace oidn { - - if (!net) - return; -- -- device->executeTask([&]() -- { -+ // -- GODOT start -- -+ //device->executeTask([&]() -+ //{ -+ // -- GODOT end -- - Progress progress; - progress.func = progressFunc; - progress.userPtr = progressUserPtr; -@@ -156,7 +163,9 @@ namespace oidn { - tileIndex++; - } - } -- }); -+ // -- GODOT start -- -+ //}); -+ // -- GODOT end -- - } - - void AutoencoderFilter::computeTileSize() -@@ -464,6 +473,11 @@ namespace oidn { - return std::make_shared(); - } - -+// -- GODOT start -- -+// Godot doesn't need Raytracing filters. Removing them saves space in the weights files. -+#if 0 -+// -- GODOT end -- -+ - // -------------------------------------------------------------------------- - // RTFilter - // -------------------------------------------------------------------------- -@@ -491,6 +505,9 @@ namespace oidn { - weightData.hdr_alb = weights::rt_hdr_alb; - weightData.hdr_alb_nrm = weights::rt_hdr_alb_nrm; - } -+// -- GODOT start -- -+#endif -+// -- GODOT end -- - - // -------------------------------------------------------------------------- - // RTLightmapFilter -diff --git a/core/autoencoder.h b/core/autoencoder.h -index c199052..98b6108 100644 ---- a/core/autoencoder.h -+++ b/core/autoencoder.h -@@ -93,11 +93,18 @@ namespace oidn { - // RTFilter - Generic ray tracing denoiser - // -------------------------------------------------------------------------- - -+// -- GODOT start -- -+// Godot doesn't need Raytracing filters. Removing them saves space in the weights files. -+#if 0 -+// -- GODOT end -- - class RTFilter : public AutoencoderFilter - { - public: - explicit RTFilter(const Ref& device); - }; -+// -- GODOT start -- -+#endif -+// -- GODOT end -- - - // -------------------------------------------------------------------------- - // RTLightmapFilter - Ray traced lightmap denoiser -diff --git a/core/common.h b/core/common.h -index a3a7e8a..a35dd90 100644 ---- a/core/common.h -+++ b/core/common.h -@@ -27,7 +27,9 @@ - #include "common/ref.h" - #include "common/exception.h" - #include "common/thread.h" --#include "common/tasking.h" -+// -- GODOT start -- -+//#include "common/tasking.h" -+// -- GODOT end -- - #include "math.h" - - namespace oidn { -diff --git a/core/device.cpp b/core/device.cpp -index c455695..3cd658b 100644 ---- a/core/device.cpp -+++ b/core/device.cpp -@@ -29,7 +29,9 @@ namespace oidn { - - Device::~Device() - { -- observer.reset(); -+ // -- GODOT start -- -+ //observer.reset(); -+ // -- GODOT end -- - } - - void Device::setError(Device* device, Error code, const std::string& message) -@@ -141,6 +143,9 @@ namespace oidn { - if (isCommitted()) - throw Exception(Error::InvalidOperation, "device can be committed only once"); - -+ // -- GODOT start -- -+ #if 0 -+ // -- GODOT end -- - // Get the optimal thread affinities - if (setAffinity) - { -@@ -157,7 +162,10 @@ namespace oidn { - // Automatically set the thread affinities - if (affinity) - observer = std::make_shared(affinity, *arena); -- -+ // -- GODOT start -- -+ #endif -+ numThreads = 1; -+ // -- GODOT end -- - dirty = false; - - if (isVerbose()) -@@ -191,9 +199,17 @@ namespace oidn { - - Ref filter; - -+// -- GODOT start -- -+// Godot doesn't need Raytracing filters. Removing them saves space in the weights files. 
-+#if 0 -+// -- GODOT end -- - if (type == "RT") - filter = makeRef(Ref(this)); -- else if (type == "RTLightmap") -+// -- GODOT start -- -+// Godot doesn't need Raytracing filters. Removing them saves space in the weights files. -+#endif -+ if (type == "RTLightmap") -+// -- GODOT end -- - filter = makeRef(Ref(this)); - else - throw Exception(Error::InvalidArgument, "unknown filter type"); -@@ -210,11 +226,12 @@ namespace oidn { - std::cout << " Build : " << getBuildName() << std::endl; - std::cout << " Platform: " << getPlatformName() << std::endl; - -- std::cout << " Tasking :"; -- std::cout << " TBB" << TBB_VERSION_MAJOR << "." << TBB_VERSION_MINOR; -- std::cout << " TBB_header_interface_" << TBB_INTERFACE_VERSION << " TBB_lib_interface_" << tbb::TBB_runtime_interface_version(); -- std::cout << std::endl; -- -+// -- GODOT start -- -+// std::cout << " Tasking :"; -+// std::cout << " TBB" << TBB_VERSION_MAJOR << "." << TBB_VERSION_MINOR; -+// std::cout << " TBB_header_interface_" << TBB_INTERFACE_VERSION << " TBB_lib_interface_" << tbb::TBB_runtime_interface_version(); -+// std::cout << std::endl; -+// -- GODOT end -- - std::cout << std::endl; - } - -diff --git a/core/device.h b/core/device.h -index c2df714..d9cfd85 100644 ---- a/core/device.h -+++ b/core/device.h -@@ -41,10 +41,12 @@ namespace oidn { - ErrorFunction errorFunc = nullptr; - void* errorUserPtr = nullptr; - -- // Tasking -- std::shared_ptr arena; -- std::shared_ptr observer; -- std::shared_ptr affinity; -+// -- GODOT start -- -+// // Tasking -+// std::shared_ptr arena; -+// std::shared_ptr observer; -+// std::shared_ptr affinity; -+// -- GODOT end -- - - // Parameters - int numThreads = 0; // autodetect by default -@@ -66,17 +68,19 @@ namespace oidn { - - void commit(); - -- template -- void executeTask(F& f) -- { -- arena->execute(f); -- } -+// -- GODOT start -- -+// template -+// void executeTask(F& f) -+// { -+// arena->execute(f); -+// } - -- template -- void executeTask(const F& f) -- { -- arena->execute(f); -- } -+// template -+// void executeTask(const F& f) -+// { -+// arena->execute(f); -+// } -+// -- GODOT end -- - - Ref newBuffer(size_t byteSize); - Ref newBuffer(void* ptr, size_t byteSize); -@@ -86,7 +90,10 @@ namespace oidn { - __forceinline std::mutex& getMutex() { return mutex; } - - private: -- bool isCommitted() const { return bool(arena); } -+// -- GODOT start -- -+ //bool isCommitted() const { return bool(arena); } -+ bool isCommitted() const { return false; } -+// -- GODOT end -- - void checkCommitted(); - - void print(); -diff --git a/core/network.cpp b/core/network.cpp -index 8c2de09..ed8328c 100644 ---- a/core/network.cpp -+++ b/core/network.cpp -@@ -17,6 +17,9 @@ - #include "upsample.h" - #include "weights_reorder.h" - #include "network.h" -+// -- GODOT start -- -+#include -+// -- GODOT end -- - - namespace oidn { - -diff --git a/core/transfer_function.cpp b/core/transfer_function.cpp -index 601f814..ce5deca 100644 ---- a/core/transfer_function.cpp -+++ b/core/transfer_function.cpp -@@ -38,16 +38,24 @@ namespace oidn { - // Compute the average log luminance of the downsampled image - using Sum = std::pair; - -- Sum sum = -- tbb::parallel_reduce( -- tbb::blocked_range2d(0, HK, 0, WK), -- Sum(0.f, 0), -- [&](const tbb::blocked_range2d& r, Sum sum) -> Sum -+ // -- GODOT start -- -+ // Sum sum = -+ // tbb::parallel_reduce( -+ // tbb::blocked_range2d(0, HK, 0, WK), -+ // Sum(0.f, 0), -+ // [&](const tbb::blocked_range2d& r, Sum sum) -> Sum -+ // { -+ // // Iterate over blocks -+ // for (int i = 
r.rows().begin(); i != r.rows().end(); ++i) -+ // { -+ // for (int j = r.cols().begin(); j != r.cols().end(); ++j) -+ // { -+ -+ Sum sum = Sum(0.0f, 0); -+ -+ for (int i = 0; i != HK; ++i) - { -- // Iterate over blocks -- for (int i = r.rows().begin(); i != r.rows().end(); ++i) -- { -- for (int j = r.cols().begin(); j != r.cols().end(); ++j) -+ for (int j = 0; j != WK; ++j) - { - // Compute the average luminance in the current block - const int beginH = int(ptrdiff_t(i) * H / HK); -@@ -82,11 +90,12 @@ namespace oidn { - } - } - -- return sum; -- }, -- [](Sum a, Sum b) -> Sum { return Sum(a.first+b.first, a.second+b.second); }, -- tbb::static_partitioner() -- ); -+ // return sum; -+ // }, -+ // [](Sum a, Sum b) -> Sum { return Sum(a.first+b.first, a.second+b.second); }, -+ // tbb::static_partitioner() -+ // ); -+ // -- GODOT end -- - - return (sum.second > 0) ? (key / exp2(sum.first / float(sum.second))) : 1.f; - } diff --git a/thirdparty/oidn/patches/mkl-dnn-fix-vs2017-build.patch b/thirdparty/oidn/patches/mkl-dnn-fix-vs2017-build.patch deleted file mode 100644 index 50d94ebff..000000000 --- a/thirdparty/oidn/patches/mkl-dnn-fix-vs2017-build.patch +++ /dev/null @@ -1,45 +0,0 @@ -Rediffed by @akien-mga to match oidn 1.1.0 source. - -From 1e42e6db81e1a5270ecc0191c5385ce7e7d978e9 Mon Sep 17 00:00:00 2001 -From: Jeremy Wong -Date: Wed, 11 Sep 2019 04:46:53 +0800 -Subject: [PATCH] src: initialize members in some structures to prevent compile - errors with VS2017 - -addresses "error C3615: constexpr function '...' cannot result in a constant expression" with VS2017 ---- - src/cpu/rnn/rnn_reorders.hpp | 2 +- - src/cpu/simple_concat.hpp | 6 +++--- - src/cpu/simple_sum.hpp | 2 +- - 3 files changed, 5 insertions(+), 5 deletions(-) - -diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_reorders.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_reorders.hpp -index 597c63e3f8..ae1551390a 100644 ---- a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_reorders.hpp -+++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_reorders.hpp -@@ -131,7 +131,7 @@ struct rnn_weights_reorder_t : public cpu_primitive_t { - return status::success; - } - -- format_tag_t itag_; -+ format_tag_t itag_ = mkldnn_format_tag_undef; - - private: - void init_scratchpad() { -diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.hpp -index 5177275452..057cc3c4c7 100644 ---- a/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.hpp -+++ b/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.hpp -@@ -96,9 +96,9 @@ struct simple_concat_t: public cpu_primitive_t { - return status::success; - } - -- int perm_[MKLDNN_MAX_NDIMS]; -- int iperm_[MKLDNN_MAX_NDIMS]; -- dims_t blocks_; -+ int perm_[MKLDNN_MAX_NDIMS] {}; -+ int iperm_[MKLDNN_MAX_NDIMS] {}; -+ dims_t blocks_ {}; - - dim_t nelems_to_concat(const memory_desc_wrapper &data_d) const { - const int ndims = data_d.ndims(); diff --git a/thirdparty/oidn/weights/LICENSE.txt b/thirdparty/oidn/weights/LICENSE.txt deleted file mode 100644 index d64569567..000000000 --- a/thirdparty/oidn/weights/LICENSE.txt +++ /dev/null @@ -1,202 +0,0 @@ - - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. 
- - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. 
Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. 
This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
diff --git a/thirdparty/oidn/weights/rtlightmap_hdr.tza b/thirdparty/oidn/weights/rtlightmap_hdr.tza deleted file mode 100644 index 12459a33b..000000000 Binary files a/thirdparty/oidn/weights/rtlightmap_hdr.tza and /dev/null differ
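The removed thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_util.h header above detects ISA extensions by issuing CPUID and XGETBV directly (see the Cpu constructor and getXfeature() in the deleted hunks). The following is a minimal standalone sketch of that detection approach, not part of the removed sources; it assumes a GCC/Clang toolchain whose <cpuid.h> provides __get_cpuid/__get_cpuid_count and an assembler that accepts the xgetbv mnemonic.

// Illustrative sketch only: CPUID/XGETBV feature detection in the spirit of
// the deleted Xbyak::util::Cpu class (simplified to a single AVX2 check).
#include <cpuid.h>
#include <cstdio>

static unsigned long long xgetbv0() {
    // Same XGETBV(0) read as the deleted getXfeature(); only safe to execute
    // after confirming OSXSAVE support via CPUID.
    unsigned int eax, edx;
    __asm__ volatile("xgetbv" : "=a"(eax), "=d"(edx) : "c"(0));
    return ((unsigned long long)edx << 32) | eax;
}

static bool os_saves_ymm() {
    // AVX/AVX2 need OS support: OSXSAVE (CPUID.1:ECX bit 27) and XCR0[2:1] = 11b,
    // mirroring the "(bv & 6) == 6" check in the deleted constructor.
    unsigned int eax, ebx, ecx, edx;
    if (!__get_cpuid(1, &eax, &ebx, &ecx, &edx)) return false;
    if (!(ecx & (1u << 27))) return false; // OSXSAVE
    return (xgetbv0() & 0x6) == 0x6;       // XMM and YMM state enabled
}

static bool cpu_has_avx2() {
    // Leaf 7, subleaf 0: structured extended feature flags; EBX bit 5 = AVX2.
    unsigned int eax, ebx, ecx, edx;
    if (!__get_cpuid_count(7, 0, &eax, &ebx, &ecx, &edx)) return false;
    return os_saves_ymm() && (ebx & (1u << 5)) != 0;
}

int main() {
    std::printf("AVX2 usable: %s\n", cpu_has_avx2() ? "yes" : "no");
    return 0;
}

The deleted class performs the same sequence (CPUID leaf 1 for OSXSAVE, XGETBV for XCR0, CPUID leaf 7 for the AVX2 bit) but caches every feature bit in a single 64-bit mask queried through Cpu::has().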