From e9bbec3b4ce7993d4a6134896d011e28e9f51eca Mon Sep 17 00:00:00 2001
From: chauneahhin <2645168370@qq.com>
Date: Sun, 19 Dec 2021 23:51:00 +0800
Subject: [PATCH] [feat] [assistant] [I4CRJN] [I4CRJM] [I4CRJL] Add MatrixDiagV3, MatrixSetDiagV3 and MatrixDiagPartV3

---
 .../kernel/matrix_diag_part_v3_cpu_kernel.cc  | 300 +++++++++++++++++
 .../kernel/matrix_diag_part_v3_cpu_kernel.h   |  73 ++++
 .../cpu/kernel/matrix_diag_v3_cpu_kernel.cc   | 314 ++++++++++++++++++
 .../cpu/kernel/matrix_diag_v3_cpu_kernel.h    |  73 ++++
 .../kernel/matrix_set_diag_v3_cpu_kernel.cc   | 307 +++++++++++++++++
 .../kernel/matrix_set_diag_v3_cpu_kernel.h    |  75 +++++
 .../core/abstract/primitive_infer_map.cc      |   9 +
 mindspore/core/base/core_ops.h                |   7 +-
 mindspore/core/ops/matrix_diag_part.cc        |  63 ----
 mindspore/core/ops/matrix_diag_part_v3.cc     | 174 ++++++++++
 ...trix_diag_part.h => matrix_diag_part_v3.h} |  28 +-
 mindspore/core/ops/matrix_diag_v3.cc          | 224 +++++++++++++
 mindspore/core/ops/matrix_diag_v3.h           |  49 +++
 mindspore/core/ops/matrix_set_diag_v3.cc      | 182 ++++++++++
 mindspore/core/ops/matrix_set_diag_v3.h       |  47 +++
 .../ops/_grad_experimental/grad_array_ops.py  |  84 +++++
 .../mindspore/ops/_op_impl/aicpu/__init__.py  |   3 +
 .../ops/_op_impl/aicpu/matrix_diag_part_v3.py |  54 +++
 .../ops/_op_impl/aicpu/matrix_diag_v3.py      |  56 ++++
 .../ops/_op_impl/aicpu/matrix_set_diag_v3.py  |  54 +++
 .../mindspore/ops/operations/array_ops.py     | 233 +++++++++++++
 tests/st/scipy_st/matrix_diag_part_test.py    | 262 ---------------
 tests/ut/python/ops/test_ops.py               | 100 ++++++
 23 files changed, 2434 insertions(+), 337 deletions(-)
 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/matrix_diag_part_v3_cpu_kernel.cc
 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/matrix_diag_part_v3_cpu_kernel.h
 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/matrix_diag_v3_cpu_kernel.cc
 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/matrix_diag_v3_cpu_kernel.h
 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/matrix_set_diag_v3_cpu_kernel.cc
 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/matrix_set_diag_v3_cpu_kernel.h
 delete mode 100644 mindspore/core/ops/matrix_diag_part.cc
 create mode 100644 mindspore/core/ops/matrix_diag_part_v3.cc
 rename mindspore/core/ops/{matrix_diag_part.h => matrix_diag_part_v3.h} (58%)
 create mode 100644 mindspore/core/ops/matrix_diag_v3.cc
 create mode 100644 mindspore/core/ops/matrix_diag_v3.h
 create mode 100644 mindspore/core/ops/matrix_set_diag_v3.cc
 create mode 100644 mindspore/core/ops/matrix_set_diag_v3.h
 create mode 100644 mindspore/python/mindspore/ops/_op_impl/aicpu/matrix_diag_part_v3.py
 create mode 100644 mindspore/python/mindspore/ops/_op_impl/aicpu/matrix_diag_v3.py
 create mode 100644 mindspore/python/mindspore/ops/_op_impl/aicpu/matrix_set_diag_v3.py
 delete mode 100644 tests/st/scipy_st/matrix_diag_part_test.py

diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/matrix_diag_part_v3_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/matrix_diag_part_v3_cpu_kernel.cc
new file mode 100644
index 000000000000..2b7c8e93c774
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/matrix_diag_part_v3_cpu_kernel.cc
@@ -0,0 +1,300 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/cpu/kernel/matrix_diag_part_v3_cpu_kernel.h" +#include +#include +#include +#include +#include +#include +#include +#include "plugin/device/cpu/hal/device/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +namespace { +constexpr size_t kMatrixDiagPartV3InputsNum = 3; +constexpr size_t kMatrixDiagPartV3OutputsNum = 1; +constexpr int64_t kParallelArrayNumSameShape = 2048; // all cores running if data size is too large +constexpr size_t kIndexPaddingValue = 2; +constexpr int64_t ZERO = 0; +static std::pair ComputeTwo(int64_t diag_index, int64_t max_diag_len, int64_t num_rows, + int64_t num_cols, bool align_superdiag, bool align_subdiag) { + bool left_align = (diag_index >= ZERO && align_superdiag) || (diag_index <= ZERO && align_subdiag); + int64_t diag_len = std::min(num_rows + std::min(ZERO, diag_index), num_cols + std::min(ZERO, -diag_index)); + int64_t offset = (left_align) ? ZERO : (max_diag_len - diag_len); + return {diag_len, offset}; +} +} // namespace + +void MatrixDiagPartV3CpuKernelMod::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + + kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node); + + if (common::AnfAlgo::HasNodeAttr("align", kernel_node)) { + align_ = common::AnfAlgo::GetNodeAttr(kernel_node, "align"); + if (!(align_ == "" || align_ == "RIGHT_LEFT" || align_ == "RIGHT_RIGHT" || align_ == "LEFT_LEFT" || + align_ == "LEFT_RIGHT")) { + MS_LOG(EXCEPTION) << "Attr 'align' of 'MatrixDiagPartV3' is not in: 'LEFT_RIGHT', " + "'RIGHT_LEFT', 'LEFT_LEFT', 'RIGHT_RIGHT'."; + } + if (align_ == "") align_ = "RIGHT_LEFT"; + } else { + align_ = "RIGHT_LEFT"; + } + + auto padding_data_type = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndexPaddingValue); + input_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); + auto output_data_type = AnfAlgo::GetOutputDeviceDataType(kernel_node, 0); + + if (padding_data_type != input_dtype_) { + MS_LOG(EXCEPTION) << "For MatrixDiagPartV3, the data type of x need be same with padding_value."; + } + + if (input_dtype_ != output_data_type) { + MS_LOG(EXCEPTION) << "For MatrixDiagPartV3, the data type of x need be same with output."; + } + + x_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + k_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1); + size_t k_dim_size = k_shape_.size(); + const size_t k_dim_size_max = 1; + if (k_dim_size > k_dim_size_max) { + MS_LOG(EXCEPTION) << "For MatrixDiagPartV3, k_dim_size must not be greater than 1, received " << k_dim_size << "."; + } + + auto kernel_attr = GetKernelAttrFromNode(kernel_node); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(EXCEPTION) << "MatrixDiagPartV3 does not support this kernel data type: " << kernel_attr; + } + kernel_func_ = func_list_[index].second; +} + +template +bool MatrixDiagPartV3CpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &outputs) { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kMatrixDiagPartV3InputsNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), 
kMatrixDiagPartV3OutputsNum, kernel_name_); + // k + int64_t lower_diag_index = 0; + upper_diag_index_ = 0; + size_t k_len = static_cast(inputs[1]->size / sizeof(int32_t)); + auto k_Data = reinterpret_cast(inputs[1]->addr); + MS_EXCEPTION_IF_NULL(k_Data); + const size_t k_len_max = 2; + if (k_len == 0 || k_len > k_len_max) { + MS_LOG(EXCEPTION) << "For MatrixDiagPartV3, k must have one or two elements, but received " << k_len << "elements."; + } + lower_diag_index = k_Data[0]; + upper_diag_index_ = k_Data[0]; + if (k_len == k_len_max) { + upper_diag_index_ = k_Data[1]; + } + if (!(lower_diag_index <= upper_diag_index_)) { + MS_LOG(EXCEPTION) << "For MatrixDiagPartV3, k[0] must not be larger than k[1] . ,received " << lower_diag_index + << " is larger than " << upper_diag_index_; + } + // x + size_t input_dims = x_shape_.size(); + const size_t input_dim_min = 2; + if (input_dims < input_dim_min) { + MS_LOG(EXCEPTION) << "For MatrixDiagPartV3, input x dims must be greater equal than 2 while got " << input_dims + << "."; + } + num_cols_ = SizeToLong(x_shape_[input_dims - 1]); + const size_t toCalRow = 2; + num_rows_ = SizeToLong(x_shape_[input_dims - toCalRow]); + size_t input_numelements = static_cast(inputs[0]->size / sizeof(T)); + num_array_ = (SizeToLong(input_numelements)) / (num_rows_ * num_cols_); + + if (align_ == "LEFT_LEFT" || align_ == "LEFT_RIGHT") { + align_superdiag_ = true; + } else { + align_superdiag_ = false; + } + if (align_ == "LEFT_LEFT" || align_ == "RIGHT_LEFT") { + align_subdiag_ = true; + } else { + align_subdiag_ = false; + } + num_diags_ = upper_diag_index_ - lower_diag_index + 1; + max_diag_len_ = std::min(num_rows_ + std::min(upper_diag_index_, ZERO), num_cols_ - std::max(lower_diag_index, ZERO)); + output_elements_in_batch_ = num_diags_ * max_diag_len_; + data_num_ = num_array_ * output_elements_in_batch_; + return DoLaunch(inputs, outputs); +} + +template +bool MatrixDiagPartV3CpuKernelMod::DoLaunch(const std::vector &inputs, + const std::vector &outputs) { + // padding_value + size_t padding_value_num = static_cast(inputs[kIndexPaddingValue]->size / sizeof(T)); + if (!(padding_value_num == 1)) { + MS_LOG(EXCEPTION) << "For MatrixDiagPartV3, padding_value must have only one element, received " + << padding_value_num << " elements. 
"; + } + auto *padding_value_data = reinterpret_cast(inputs[kIndexPaddingValue]->addr); + MS_EXCEPTION_IF_NULL(padding_value_data); + T padding_value = padding_value_data[0]; + auto output_data = reinterpret_cast(outputs[0]->addr); + MS_EXCEPTION_IF_NULL(output_data); + auto input_data = reinterpret_cast(inputs[0]->addr); + MS_EXCEPTION_IF_NULL(input_data); + size_t Num_array = LongToSize(num_array_); + + if (data_num_ >= kParallelArrayNumSameShape) { + auto task = [this, &output_data, &input_data, padding_value](size_t start, size_t end) { + int64_t out_begin_index = SizeToLong(start * output_elements_in_batch_); + for (size_t index_array = start; index_array < end; index_array++) { + for (int64_t i = 0; i < num_diags_; i++) { + int64_t offset = 0; + int64_t diag_len = 0; + int64_t diag_index = upper_diag_index_ - i; + int64_t col_offset = std::max(ZERO, -diag_index); + int64_t row_offset = std::max(ZERO, diag_index); + std::tie(diag_len, offset) = + ComputeTwo(diag_index, max_diag_len_, num_rows_, num_cols_, align_superdiag_, align_subdiag_); + + for (int64_t n = 0; n < diag_len; n++) { + output_data[LongToSize(out_begin_index + offset + n)] = input_data[LongToSize( + index_array * num_rows_ * num_cols_ + (n + col_offset) * num_cols_ + n + row_offset)]; + } + const bool left_align = (offset == 0); + const int64_t padding_start = (left_align) ? diag_len : 0; + const int64_t padding_end = (left_align) ? max_diag_len_ : offset; + int64_t n = padding_start; + while (n < padding_end) { + output_data[LongToSize(out_begin_index + n)] = padding_value; + n += 1; + } + out_begin_index += max_diag_len_; + } + } + }; + CPUKernelUtils::ParallelFor(task, Num_array); + } else { + // single core used if data size is not too large + int64_t out_begin_index = 0; + for (int64_t index_array = 0; index_array < num_array_; index_array++) { + for (int64_t i = 0; i < num_diags_; i++) { + int64_t offset = 0; + int64_t diag_len = 0; + int64_t diag_index = upper_diag_index_ - i; + int64_t col_offset = std::max(ZERO, -diag_index); + int64_t row_offset = std::max(ZERO, diag_index); + std::tie(diag_len, offset) = + ComputeTwo(diag_index, max_diag_len_, num_rows_, num_cols_, align_superdiag_, align_subdiag_); + + for (int64_t n = 0; n < diag_len; n++) { + output_data[LongToSize(out_begin_index + offset + n)] = + input_data[LongToSize(index_array * num_rows_ * num_cols_ + (n + col_offset) * num_cols_ + n + row_offset)]; + } + const bool left_align = (offset == 0); + const int64_t padding_start = (left_align) ? diag_len : 0; + const int64_t padding_end = (left_align) ? 
max_diag_len_ : offset; + int64_t n = padding_start; + while (n < padding_end) { + output_data[LongToSize(out_begin_index + n)] = padding_value; + n += 1; + } + out_begin_index += max_diag_len_; + } + } + } + return true; +} + +std::vector> + MatrixDiagPartV3CpuKernelMod::func_list_ = {{KernelAttr() + .AddInputAttr(kNumberTypeInt8) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt8) + .AddOutputAttr(kNumberTypeInt8), + &MatrixDiagPartV3CpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt16) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt16) + .AddOutputAttr(kNumberTypeInt16), + &MatrixDiagPartV3CpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32), + &MatrixDiagPartV3CpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt64), + &MatrixDiagPartV3CpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt8) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeUInt8) + .AddOutputAttr(kNumberTypeUInt8), + &MatrixDiagPartV3CpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt16) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeUInt16) + .AddOutputAttr(kNumberTypeUInt16), + &MatrixDiagPartV3CpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeUInt32) + .AddOutputAttr(kNumberTypeUInt32), + &MatrixDiagPartV3CpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt64) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeUInt64) + .AddOutputAttr(kNumberTypeUInt64), + &MatrixDiagPartV3CpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + &MatrixDiagPartV3CpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + &MatrixDiagPartV3CpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeFloat64), + &MatrixDiagPartV3CpuKernelMod::LaunchKernel}}; + +std::vector MatrixDiagPartV3CpuKernelMod::GetOpSupport() { + std::vector support_list; + std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + return support_list; +} + +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, MatrixDiagPartV3, MatrixDiagPartV3CpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/matrix_diag_part_v3_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/matrix_diag_part_v3_cpu_kernel.h new file mode 100644 index 000000000000..7013e3b7ebc8 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/matrix_diag_part_v3_cpu_kernel.h @@ -0,0 +1,73 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MATRIX_DIAG_PART_V3_CPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MATRIX_DIAG_PART_V3_CPU_KERNEL_H_
+
+#include <functional>
+#include <string>
+#include <utility>
+#include <vector>
+#include "plugin/device/cpu/kernel/cpu_kernel.h"
+#include "plugin/factory/ms_factory.h"
+
+namespace mindspore {
+namespace kernel {
+class MatrixDiagPartV3CpuKernelMod : public DeprecatedNativeCpuKernelMod {
+ public:
+  MatrixDiagPartV3CpuKernelMod() = default;
+  ~MatrixDiagPartV3CpuKernelMod() override = default;
+
+  void InitKernel(const CNodePtr &kernel_node) override;
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+              const std::vector<AddressPtr> &outputs) override {
+    return kernel_func_(this, inputs, outputs);
+  }
+
+ protected:
+  std::vector<KernelAttr> GetOpSupport() override;
+
+ private:
+  template <typename T>
+  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
+  using MatrixDiagPartV3Func = std::function<bool(MatrixDiagPartV3CpuKernelMod *, const std::vector<AddressPtr> &,
+                                                  const std::vector<AddressPtr> &)>;
+  static std::vector<std::pair<KernelAttr, MatrixDiagPartV3Func>> func_list_;
+  MatrixDiagPartV3Func kernel_func_;
+
+  template <typename T>
+  bool DoLaunch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
+
+  std::vector<size_t> x_shape_;
+  std::vector<size_t> k_shape_;
+  TypeId input_dtype_;
+  std::string align_;
+  int64_t num_diags_ = 1;
+  int64_t max_diag_len_ = 0;
+  int64_t output_elements_in_batch_ = 0;
+  bool align_superdiag_ = true;
+  bool align_subdiag_ = true;
+  int64_t num_cols_ = 1;
+  int64_t num_rows_ = 1;
+  int64_t upper_diag_index_ = 0;
+  int64_t data_num_ = 0;
+  int64_t num_array_ = 0;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MATRIX_DIAG_PART_V3_CPU_KERNEL_H_
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/matrix_diag_v3_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/matrix_diag_v3_cpu_kernel.cc
new file mode 100644
index 000000000000..43dfc4f4da6a
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/matrix_diag_v3_cpu_kernel.cc
@@ -0,0 +1,314 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "plugin/device/cpu/kernel/matrix_diag_v3_cpu_kernel.h" +#include +#include +#include +#include +#include +#include +#include "plugin/device/cpu/hal/device/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +namespace { +constexpr size_t kMatrixDiagV3InputsNum = 5; +constexpr size_t kMatrixDiagV3OutputsNum = 1; +constexpr size_t kIndexNumRow = 2; +constexpr size_t kIndexNumCol = 3; +constexpr size_t kIndexPaddingValue = 4; +static std::pair ComputeTwo(int64_t diag_index, int64_t max_diag_len, int32_t num_rows, + int32_t num_cols, bool align_superdiag, bool align_subdiag) { + const int64_t zero = 0; + bool left_align = (diag_index >= zero && align_superdiag) || (diag_index <= zero && align_subdiag); + int64_t diag_len = std::min(num_rows + std::min(zero, diag_index), num_cols + std::min(zero, -diag_index)); + int64_t offset = (left_align) ? zero : (max_diag_len - diag_len); + return {diag_len, offset}; +} +} // namespace + +void MatrixDiagV3CpuKernelMod::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + + kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node); + + if (common::AnfAlgo::HasNodeAttr("align", kernel_node)) { + align_ = common::AnfAlgo::GetNodeAttr(kernel_node, "align"); + if (!(align_ == "" || align_ == "RIGHT_LEFT" || align_ == "RIGHT_RIGHT" || align_ == "LEFT_LEFT" || + align_ == "LEFT_RIGHT")) { + MS_LOG(EXCEPTION) << "Attr 'align' of 'MatrixDiagV3' is not in: 'LEFT_RIGHT', " + "'RIGHT_LEFT', 'LEFT_LEFT', 'RIGHT_RIGHT'."; + } + if (align_ == "") align_ = "RIGHT_LEFT"; + } else { + align_ = "RIGHT_LEFT"; + } + + diagonal_data_type_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); + auto padding_type = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndexPaddingValue); + auto output_data_type = AnfAlgo::GetOutputDeviceDataType(kernel_node, 0); + + if (diagonal_data_type_ != padding_type) { + MS_LOG(EXCEPTION) << "For MatrixDiagV3, the data type of x need be same with padding_value."; + } + + if (diagonal_data_type_ != output_data_type) { + MS_LOG(EXCEPTION) << "For MatrixDiagV3, The data type of x need be same with output."; + } + + diagonal_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + k_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1); + size_t k_dim_size = k_shape_.size(); + const size_t k_dim_size_max = 1; + if (k_dim_size > k_dim_size_max) { + MS_LOG(EXCEPTION) << "For MatrixDiagV3, k_dim_size must not be greater than 1, received " << k_dim_size << "."; + } + + auto kernel_attr = GetKernelAttrFromNode(kernel_node); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(EXCEPTION) << "MatrixDiagV3 does not support this kernel data type: " << kernel_attr; + } + kernel_func_ = func_list_[index].second; +} + +template +bool MatrixDiagV3CpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &outputs) { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kMatrixDiagV3InputsNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kMatrixDiagV3OutputsNum, kernel_name_); + lower_diag_index_ = 0; + upper_diag_index_ = 0; + num_rows_ = -1; + num_cols_ = -1; + const size_t diag_rank = diagonal_shape_.size(); + if (diag_rank < 1) { + MS_LOG(EXCEPTION) << "For MatrixDiagV3, input x dims must be greater equal than 1 while got " << diag_rank << "."; + } + max_diag_len_ = SizeToLong(diagonal_shape_[diag_rank - 1]); + // k + auto *k_data = reinterpret_cast(inputs[1]->addr); + MS_EXCEPTION_IF_NULL(k_data); + lower_diag_index_ = k_data[0]; + 
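+  // If k holds a single element, the upper diagonal index defaults to the lower one; it is overwritten below when a second element is given.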
upper_diag_index_ = lower_diag_index_; + size_t k_num = static_cast(inputs[1]->size / sizeof(int32_t)); + const size_t k_num_max = 2; + if (k_num == 0 || k_num > k_num_max) { + MS_LOG(EXCEPTION) << "For MatrixDiagV3, k must have one or two elements, but received " << k_num << "elements."; + } + if (k_num == k_num_max) { + upper_diag_index_ = k_data[1]; + } + if (!(lower_diag_index_ <= upper_diag_index_)) { + MS_LOG(EXCEPTION) << "For MatrixDiagV3, lower_diag_index must be smaller than upper_diag_index,received " + << lower_diag_index_ << " is larger than " << upper_diag_index_; + } + const int64_t num_diags = upper_diag_index_ - lower_diag_index_ + 1; + // num_rows + size_t num_rows_num = static_cast(inputs[kIndexNumRow]->size / sizeof(int32_t)); + if (!(num_rows_num == 1)) { + MS_LOG(EXCEPTION) << "For MatrixDiagV3, num_rows must have only one element, received " << num_rows_num + << " elements. "; + } + auto *num_rows_data = reinterpret_cast(inputs[kIndexNumRow]->addr); + MS_EXCEPTION_IF_NULL(num_rows_data); + num_rows_ = num_rows_data[0]; + // num_cols + size_t num_cols_num = static_cast(inputs[kIndexNumCol]->size / sizeof(int32_t)); + if (!(num_cols_num == 1)) { + MS_LOG(EXCEPTION) << "For MatrixDiagV3, num_cols must have only one element, received " << num_cols_num + << " elements. "; + } + auto *num_cols_data = reinterpret_cast(inputs[kIndexNumCol]->addr); + MS_EXCEPTION_IF_NULL(num_cols_data); + num_cols_ = num_cols_data[0]; + + const int32_t min_rows = max_diag_len_ + std::max(-upper_diag_index_, 0); + const int32_t min_cols = max_diag_len_ + std::max(lower_diag_index_, 0); + if (num_rows_ != -1 && num_rows_ < min_rows) { + MS_LOG(EXCEPTION) << "For MatrixDiagV3, the number of rows is too small."; + } + if (num_cols_ != -1 && num_cols_ < min_cols) { + MS_LOG(EXCEPTION) << "For MatrixDiagV3, the number of columns is too small."; + } + if (num_rows_ == -1 && num_cols_ == -1) { + num_rows_ = std::max(min_rows, min_cols); + num_cols_ = num_rows_; + } + if (num_rows_ == -1) { + num_rows_ = min_rows; + } + if (num_cols_ == -1) { + num_cols_ = min_cols; + } + if (num_rows_ != min_rows && num_cols_ != min_cols) { + MS_LOG(EXCEPTION) << "For MatrixDiagV3, the number of rows or columns is not consistent with " + "the specified d_lower, d_upper, and diagonal."; + } + diag_elements_in_batch_ = num_diags * max_diag_len_; + diag_batch_base_index_ = 0 * diag_elements_in_batch_; + size_t num_element = static_cast(outputs[0]->size / sizeof(T)); + num_batches_ = (SizeToLong(num_element)) / (num_rows_ * num_cols_); + + return DoLaunch(inputs, outputs); +} + +template +bool MatrixDiagV3CpuKernelMod::DoLaunch(const std::vector &inputs, + const std::vector &outputs) { + align_superdiag_ = align_ == "LEFT_LEFT" || align_ == "LEFT_RIGHT"; + align_subdiag_ = align_ == "LEFT_LEFT" || align_ == "RIGHT_LEFT"; + // padding_value + size_t padding_value_num = static_cast(inputs[kIndexPaddingValue]->size / sizeof(T)); + if (!(padding_value_num == 1)) { + MS_LOG(EXCEPTION) << "For MatrixDiagV3, padding_value must have only one element, received " << padding_value_num + << " elements. 
"; + } + auto *padding_value_data = reinterpret_cast(inputs[kIndexPaddingValue]->addr); + MS_EXCEPTION_IF_NULL(padding_value_data); + T padding_value = padding_value_data[0]; + + auto *diagonal_data = reinterpret_cast(inputs[0]->addr); + MS_EXCEPTION_IF_NULL(diagonal_data); + auto *output_data = reinterpret_cast(outputs[0]->addr); + MS_EXCEPTION_IF_NULL(output_data); + int64_t elem = 0; + for (int64_t index_array = 0; index_array < num_batches_; index_array++) { + for (int64_t i = 0; i < num_rows_; i++) { + for (int64_t j = 0; j < num_cols_; j++) { + int64_t diag_index = j - i; + int64_t diag_index_in_input = upper_diag_index_ - diag_index; + int64_t diag_len, offset; + std::tie(diag_len, offset) = + ComputeTwo(diag_index, max_diag_len_, num_rows_, num_cols_, align_superdiag_, align_subdiag_); + int64_t index_in_the_diagonal = j - std::max(diag_index, 0) + offset; + if (lower_diag_index_ <= diag_index && diag_index <= upper_diag_index_) { + size_t index = + LongToSize(diag_batch_base_index_ + diag_index_in_input * max_diag_len_ + index_in_the_diagonal); + output_data[LongToSize(elem)] = diagonal_data[index]; + elem++; + } else { + output_data[LongToSize(elem)] = padding_value; + elem++; + } + } + } + diag_batch_base_index_ += diag_elements_in_batch_; + } + return true; +} + +std::vector> MatrixDiagV3CpuKernelMod::func_list_ = { + {KernelAttr() + .AddInputAttr(kNumberTypeInt8) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt8) + .AddOutputAttr(kNumberTypeInt8), + &MatrixDiagV3CpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt16) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt16) + .AddOutputAttr(kNumberTypeInt16), + &MatrixDiagV3CpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32), + &MatrixDiagV3CpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt64), + &MatrixDiagV3CpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt8) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeUInt8) + .AddOutputAttr(kNumberTypeUInt8), + &MatrixDiagV3CpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt16) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeUInt16) + .AddOutputAttr(kNumberTypeUInt16), + &MatrixDiagV3CpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeUInt32) + .AddOutputAttr(kNumberTypeUInt32), + &MatrixDiagV3CpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt64) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeUInt64) + .AddOutputAttr(kNumberTypeUInt64), + &MatrixDiagV3CpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + 
.AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + &MatrixDiagV3CpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + &MatrixDiagV3CpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeFloat64), + &MatrixDiagV3CpuKernelMod::LaunchKernel}}; + +std::vector MatrixDiagV3CpuKernelMod::GetOpSupport() { + std::vector support_list; + std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + return support_list; +} + +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, MatrixDiagV3, MatrixDiagV3CpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/matrix_diag_v3_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/matrix_diag_v3_cpu_kernel.h new file mode 100644 index 000000000000..d3014b8ed160 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/matrix_diag_v3_cpu_kernel.h @@ -0,0 +1,73 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MATRIX_DIAG_V3_CPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MATRIX_DIAG_V3_CPU_KERNEL_H_
+
+#include <functional>
+#include <string>
+#include <utility>
+#include <vector>
+#include "plugin/device/cpu/kernel/cpu_kernel.h"
+#include "plugin/factory/ms_factory.h"
+
+namespace mindspore {
+namespace kernel {
+class MatrixDiagV3CpuKernelMod : public DeprecatedNativeCpuKernelMod {
+ public:
+  MatrixDiagV3CpuKernelMod() = default;
+  ~MatrixDiagV3CpuKernelMod() override = default;
+
+  void InitKernel(const CNodePtr &kernel_node) override;
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+              const std::vector<AddressPtr> &outputs) override {
+    return kernel_func_(this, inputs, outputs);
+  }
+
+ protected:
+  std::vector<KernelAttr> GetOpSupport() override;
+
+ private:
+  template <typename T>
+  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
+  using MatrixDiagV3Func = std::function<bool(MatrixDiagV3CpuKernelMod *, const std::vector<AddressPtr> &,
+                                              const std::vector<AddressPtr> &)>;
+  static std::vector<std::pair<KernelAttr, MatrixDiagV3Func>> func_list_;
+  MatrixDiagV3Func kernel_func_;
+
+  template <typename T>
+  bool DoLaunch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
+
+  std::vector<size_t> diagonal_shape_;
+  std::vector<size_t> k_shape_;
+  TypeId diagonal_data_type_;
+  std::string align_;
+  bool align_superdiag_ = true;
+  bool align_subdiag_ = true;
+  int64_t num_batches_ = 0;
+  int32_t lower_diag_index_ = 0;
+  int32_t upper_diag_index_ = 0;
+  int32_t num_rows_ = -1;
+  int32_t num_cols_ = -1;
+  int64_t max_diag_len_ = 1;
+  int64_t diag_batch_base_index_ = 0;
+  int64_t diag_elements_in_batch_ = 0;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MATRIX_DIAG_V3_CPU_KERNEL_H_
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/matrix_set_diag_v3_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/matrix_set_diag_v3_cpu_kernel.cc
new file mode 100644
index 000000000000..58f437dded54
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/matrix_set_diag_v3_cpu_kernel.cc
@@ -0,0 +1,307 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "plugin/device/cpu/kernel/matrix_set_diag_v3_cpu_kernel.h" +#include +#include +#include +#include +#include +#include +#include "plugin/device/cpu/hal/device/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +namespace { +constexpr size_t kMatrixSetDiagV3InputsNum = 3; +constexpr size_t kMatrixSetDiagV3OutputsNum = 1; +constexpr size_t kParallelDataNum = 64 * 1024; +constexpr size_t kKLengthMax = 2; +constexpr size_t kIndexK = 2; +constexpr int64_t ZERO = 0; +} // namespace + +void MatrixSetDiagV3CpuKernelMod::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + + kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node); + + if (common::AnfAlgo::HasNodeAttr("align", kernel_node)) { + align_ = common::AnfAlgo::GetNodeAttr(kernel_node, "align"); + if (!(align_ == "" || align_ == "RIGHT_LEFT" || align_ == "RIGHT_RIGHT" || align_ == "LEFT_LEFT" || + align_ == "LEFT_RIGHT")) { + MS_LOG(EXCEPTION) << "Attr 'align' of 'MatrixSetDiagV3' is not in: 'LEFT_RIGHT', " + "'RIGHT_LEFT', 'LEFT_LEFT', 'RIGHT_RIGHT'."; + } + if (align_ == "") align_ = "RIGHT_LEFT"; + } else { + align_ = "RIGHT_LEFT"; + } + + auto diagonal_data_type = AnfAlgo::GetInputDeviceDataType(kernel_node, 1); + input_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); + auto output_data_type = AnfAlgo::GetOutputDeviceDataType(kernel_node, 0); + + if (diagonal_data_type != input_dtype_) { + MS_LOG(EXCEPTION) << "For MatrixSetDiagV3, the data type of x need be same diagonal."; + } + + if (input_dtype_ != output_data_type) { + MS_LOG(EXCEPTION) << "For MatrixSetDiagV3, the data type of x need be same with output."; + } + + x_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + diagonal_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1); + k_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndexK); + size_t k_dim_size = k_shape_.size(); + const size_t k_dim_size_max = 1; + if (k_dim_size > k_dim_size_max) { + MS_LOG(EXCEPTION) << "For MatrixSetDiagV3, k_dim_size must not be greater than 1, received " << k_dim_size << "."; + } + + auto kernel_attr = GetKernelAttrFromNode(kernel_node); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(EXCEPTION) << "MatrixSetDiagV3 does not support this kernel data type: " << kernel_attr; + } + kernel_func_ = func_list_[index].second; +} + +template +bool MatrixSetDiagV3CpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &outputs) { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kMatrixSetDiagV3InputsNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kMatrixSetDiagV3OutputsNum, kernel_name_); + size_t input_dims = x_shape_.size(); + const size_t input_dim_min = 2; + const size_t toCalRow = 2; + if (input_dims < input_dim_min) { + MS_LOG(EXCEPTION) << "For MatrixSetDiagV3, input x dims must be greater equal than 2 while got " << input_dims + << "."; + } + input_columns_ = x_shape_[input_dims - 1]; + input_rows_ = x_shape_[input_dims - toCalRow]; + input_numelements_ = static_cast(inputs[0]->size / sizeof(T)); + + size_t diagonal_dims = diagonal_shape_.size(); + diagonal_columns_ = diagonal_shape_[diagonal_dims - 1]; + diagonal_rows_ = 1; + if (diagonal_dims > 1) { + diagonal_rows_ = diagonal_shape_[diagonal_dims - toCalRow]; + } + + k_len_ = static_cast(inputs[kIndexK]->size / sizeof(int32_t)); + k_lower_ = 0; + k_upper_ = 0; + auto k_Data = reinterpret_cast(inputs[kIndexK]->addr); + MS_EXCEPTION_IF_NULL(k_Data); + if (k_len_ == 0 || k_len_ > 
kKLengthMax) { + MS_LOG(EXCEPTION) << "For MatrixSetDiagV3, k must have only one or two elements, received " << k_len_ + << "elements."; + } + k_lower_ = k_Data[0]; + k_upper_ = k_Data[0]; + if (k_len_ == kKLengthMax) { + k_upper_ = k_Data[1]; + } + if (!(k_lower_ <= k_upper_)) { + MS_LOG(EXCEPTION) << "For MatrixSetDiagV3, k[0] must not be larger than k[1] ,received " << k_lower_ + << " is larger than " << k_upper_; + } + max_diag_len_ = std::min(input_rows_ + std::min(k_upper_, ZERO), input_columns_ + std::min(-k_lower_, ZERO)); + + return DoLaunch(inputs, outputs); +} + +template +void MatrixSetDiagV3CpuKernelMod::singleCal(const std::vector &inputs, + const std::vector &outputs) { + auto output_data = reinterpret_cast(outputs[0]->addr); + MS_EXCEPTION_IF_NULL(output_data); + auto diagonal_data = reinterpret_cast(inputs[1]->addr); + MS_EXCEPTION_IF_NULL(diagonal_data); + auto input_data = reinterpret_cast(inputs[0]->addr); + MS_EXCEPTION_IF_NULL(input_data); + if (k_len_ == 1 || (k_len_ == kKLengthMax && k_lower_ == k_upper_)) { + for (size_t elem = 0; elem < input_numelements_; ++elem) { + int64_t t = SizeToLong(elem % (input_rows_ * input_columns_)); + int64_t index = SizeToLong(elem / (input_rows_ * input_columns_)); + int64_t m = t / input_columns_; + int64_t n = t % input_columns_; + int64_t x = n - std::max(k_upper_, ZERO); + if (n - m == k_upper_) + output_data[elem] = diagonal_data[LongToSize(index * diagonal_columns_ + x)]; + else + output_data[elem] = input_data[elem]; + } + } else { + for (size_t elem = 0; elem < input_numelements_; ++elem) { + int64_t t = SizeToLong(elem % (input_rows_ * input_columns_)); + int64_t index = SizeToLong(elem / (input_rows_ * input_columns_)); + int64_t m = t / input_columns_; + int64_t n = t % input_columns_; + int64_t d = n - m; + if (d >= k_lower_ && d <= k_upper_) { + int64_t x = k_upper_ - d; + int64_t offset = 0; + if (((align_ == "RIGHT_LEFT" || align_ == "RIGHT_RIGHT") && d >= 0) || + ((align_ == "LEFT_RIGHT" || align_ == "RIGHT_RIGHT") && d <= 0)) { + offset = max_diag_len_ - std::min(input_columns_ - std::max(d, ZERO), input_rows_ + std::min(d, ZERO)); + } + int64_t y = n - std::max(d, ZERO) + offset; + size_t position = LongToSize(index * diagonal_rows_ * diagonal_columns_ + x * diagonal_columns_ + y); + output_data[elem] = diagonal_data[position]; + } else { + output_data[elem] = input_data[elem]; + } + } + } +} + +template +bool MatrixSetDiagV3CpuKernelMod::DoLaunch(const std::vector &inputs, + const std::vector &outputs) { + auto output_data = reinterpret_cast(outputs[0]->addr); + MS_EXCEPTION_IF_NULL(output_data); + auto diagonal_data = reinterpret_cast(inputs[1]->addr); + MS_EXCEPTION_IF_NULL(diagonal_data); + auto input_data = reinterpret_cast(inputs[0]->addr); + MS_EXCEPTION_IF_NULL(input_data); + + // 64K boundary value to determine whether to use all cores + size_t input_size = inputs[0]->size; + if (input_size < kParallelDataNum) { + singleCal(inputs, outputs); + } else { + auto task = [this, &diagonal_data, &output_data, &input_data](size_t start, size_t end) { + if (k_len_ == 1 || (k_len_ == kKLengthMax && k_lower_ == k_upper_)) { + for (size_t elem = start; elem < end; ++elem) { + int64_t t = SizeToLong(elem % (input_rows_ * input_columns_)); + int64_t index = SizeToLong(elem / (input_rows_ * input_columns_)); + int64_t m = t / input_columns_; + int64_t n = t % input_columns_; + int64_t x = n - std::max(k_upper_, ZERO); + if (n - m == k_upper_) + output_data[elem] = diagonal_data[LongToSize(index * diagonal_columns_ + 
x)]; + else + output_data[elem] = input_data[elem]; + } + } else { + for (size_t elem = start; elem < end; ++elem) { + int64_t t = SizeToLong(elem % (input_rows_ * input_columns_)); + int64_t index = SizeToLong(elem / (input_rows_ * input_columns_)); + int64_t m = t / input_columns_; + int64_t n = t % input_columns_; + int64_t d = n - m; + if (d >= k_lower_ && d <= k_upper_) { + int64_t x = k_upper_ - d; + int64_t offset = 0; + if (((align_ == "RIGHT_LEFT" || align_ == "RIGHT_RIGHT") && d >= 0) || + ((align_ == "LEFT_RIGHT" || align_ == "RIGHT_RIGHT") && d <= 0)) { + offset = max_diag_len_ - std::min(input_columns_ - std::max(d, ZERO), input_rows_ + std::min(d, ZERO)); + } + int64_t y = n - std::max(d, ZERO) + offset; + size_t position = LongToSize(index * diagonal_rows_ * diagonal_columns_ + x * diagonal_columns_ + y); + output_data[elem] = diagonal_data[position]; + } else { + output_data[elem] = input_data[elem]; + } + } + } + }; + CPUKernelUtils::ParallelFor(task, input_numelements_); + } + return true; +} + +std::vector> + MatrixSetDiagV3CpuKernelMod::func_list_ = {{KernelAttr() + .AddInputAttr(kNumberTypeInt8) + .AddInputAttr(kNumberTypeInt8) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt8), + &MatrixSetDiagV3CpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt16) + .AddInputAttr(kNumberTypeInt16) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt16), + &MatrixSetDiagV3CpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32), + &MatrixSetDiagV3CpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt64), + &MatrixSetDiagV3CpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt8) + .AddInputAttr(kNumberTypeUInt8) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeUInt8), + &MatrixSetDiagV3CpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt16) + .AddInputAttr(kNumberTypeUInt16) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeUInt16), + &MatrixSetDiagV3CpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt32) + .AddInputAttr(kNumberTypeUInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeUInt32), + &MatrixSetDiagV3CpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt64) + .AddInputAttr(kNumberTypeUInt64) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeUInt64), + &MatrixSetDiagV3CpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat16), + &MatrixSetDiagV3CpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat32), + &MatrixSetDiagV3CpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat64), + &MatrixSetDiagV3CpuKernelMod::LaunchKernel}}; + +std::vector MatrixSetDiagV3CpuKernelMod::GetOpSupport() { + std::vector support_list; + std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + 
return support_list; +} + +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, MatrixSetDiagV3, MatrixSetDiagV3CpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/matrix_set_diag_v3_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/matrix_set_diag_v3_cpu_kernel.h new file mode 100644 index 000000000000..e472c73aa8a0 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/matrix_set_diag_v3_cpu_kernel.h @@ -0,0 +1,75 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MATRIX_SET_DIAG_V3_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MATRIX_SET_DIAG_V3_CPU_KERNEL_H_ + +#include +#include +#include +#include +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore { +namespace kernel { +class MatrixSetDiagV3CpuKernelMod : public DeprecatedNativeCpuKernelMod { + public: + MatrixSetDiagV3CpuKernelMod() = default; + ~MatrixSetDiagV3CpuKernelMod() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs) override { + return kernel_func_(this, inputs, outputs); + } + + protected: + std::vector GetOpSupport() override; + + private: + template + bool LaunchKernel(const std::vector &inputs, const std::vector &outputs); + using MatrixSetDiagV3Func = std::function &, + const std::vector &)>; + static std::vector> func_list_; + MatrixSetDiagV3Func kernel_func_; + + template + bool DoLaunch(const std::vector &inputs, const std::vector &outputs); + template + void singleCal(const std::vector &inputs, const std::vector &outputs); + + std::vector diagonal_shape_; + std::vector k_shape_; + std::vector x_shape_; + TypeId input_dtype_; + std::string align_; + size_t input_columns_ = 1; + size_t input_rows_ = 1; + size_t diagonal_columns_ = 1; + size_t diagonal_rows_ = 1; + size_t k_len_ = 0; + int64_t k_lower_ = 0; + int64_t k_upper_ = 0; + int64_t max_diag_len_ = 0; + size_t input_numelements_ = 0; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MATRIX_SET_DIAG_V3_CPU_KERNEL_H_ diff --git a/mindspore/core/abstract/primitive_infer_map.cc b/mindspore/core/abstract/primitive_infer_map.cc index ab2a42431290..5323f32139c8 100644 --- a/mindspore/core/abstract/primitive_infer_map.cc +++ b/mindspore/core/abstract/primitive_infer_map.cc @@ -40,6 +40,9 @@ #include "utils/ms_context.h" #include "ops/tile.h" #include "ops/slice.h" +#include "ops/matrix_diag_part_v3.h" +#include "ops/matrix_diag_v3.h" +#include "ops/matrix_set_diag_v3.h" #include "ops/grad/slice_grad.h" #include "ops/lstm.h" @@ -54,6 +57,9 @@ std::set GetDependsFormMap(const std::string &prim_name, size_t input_n static const auto &kStridedSlice = prim::kPrimStridedSlice->name(); static const auto &kStridedSliceGrad = 
prim::kPrimStridedSliceGrad->name();
   static const auto &kReduceSum = prim::kPrimReduceSum->name();
+  static const auto &kMatrixDiagV3 = prim::kPrimMatrixDiagV3->name();
+  static const auto &kMatrixDiagPartV3 = prim::kPrimMatrixDiagPartV3->name();
+  static const auto &kMatrixSetDiagV3 = prim::kPrimMatrixSetDiagV3->name();
   static const auto &kDynamicBroadcastTo = prim::kPrimDynamicBroadcastTo->name();
   static const auto &kUnsortedSegmentSum = prim::kPrimUnsortedSegmentSum->name();
   static const auto &kUnsortedSegmentMin = prim::kPrimUnsortedSegmentMin->name();
@@ -74,6 +80,9 @@ std::set<int64_t> GetDependsFormMap(const std::string &prim_name, size_t input_n
   static const PrimShapeDependMap dynamic_shape_depends{{kUnsortedSegmentSum, ShapeSet{2}},
                                                          {kUnsortedSegmentMin, ShapeSet{2}},
                                                          {kUnsortedSegmentMax, ShapeSet{2}},
+                                                         {kMatrixDiagV3, ShapeSet{1, 2, 3, 4}},
+                                                         {kMatrixDiagPartV3, ShapeSet{1, 2}},
+                                                         {kMatrixSetDiagV3, ShapeSet{2}},
                                                          {kGather, ShapeSet{2}},
                                                          {kGatherV2, ShapeSet{2}},
                                                          {kSparseGatherV2, ShapeSet{2}},
diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h
index 6f721648ec77..fd599df98912 100644
--- a/mindspore/core/base/core_ops.h
+++ b/mindspore/core/base/core_ops.h
@@ -125,6 +125,9 @@ constexpr auto kConcat = "Concat";
 constexpr auto kRightShift = "RightShift";
 constexpr auto kDiag = "Diag";
 constexpr auto kDiagPart = "DiagPart";
+constexpr auto kMatrixDiagV3 = "MatrixDiagV3";
+constexpr auto kMatrixDiagPartV3 = "MatrixDiagPartV3";
+constexpr auto kMatrixSetDiagV3 = "MatrixSetDiagV3";
 constexpr auto kDynamicBroadcastGradientArgs = "DynamicBroadcastGradientArgs";
 constexpr auto kTranspose = "Transpose";
 constexpr auto kSplitV = "SplitV";
@@ -369,6 +372,9 @@ GVAR_DEF(PrimitivePtr, kPrimMaskedFill, std::make_shared<Primitive>("MaskedFill"
 GVAR_DEF(PrimitivePtr, kPrimMaskedSelect, std::make_shared<Primitive>("MaskedSelect"));
 GVAR_DEF(PrimitivePtr, kPrimDiag, std::make_shared<Primitive>(kDiag));
 GVAR_DEF(PrimitivePtr, kPrimDiagPart, std::make_shared<Primitive>(kDiagPart));
+GVAR_DEF(PrimitivePtr, kPrimMatrixDiagV3, std::make_shared<Primitive>(kMatrixDiagV3));
+GVAR_DEF(PrimitivePtr, kPrimMatrixDiagPartV3, std::make_shared<Primitive>(kMatrixDiagPartV3));
+GVAR_DEF(PrimitivePtr, kPrimMatrixSetDiagV3, std::make_shared<Primitive>(kMatrixSetDiagV3));
 GVAR_DEF(PrimitivePtr, kPrimNonZero, std::make_shared<Primitive>("NonZero"));
 GVAR_DEF(PrimitivePtr, kPrimRealInner, std::make_shared<Primitive>(kRealInner));
 GVAR_DEF(PrimitivePtr, kPrimReal, std::make_shared<Primitive>(kReal));
@@ -661,7 +667,6 @@ GVAR_DEF(PrimitivePtr, kPrimAddcmul, std::make_shared<Primitive>(kAddcmul));
 GVAR_DEF(PrimitivePtr, kPrimMatMul, std::make_shared<Primitive>("MatMul"));
 GVAR_DEF(PrimitivePtr, kPrimMatMulV2, std::make_shared<Primitive>("MatMulV2"));
 GVAR_DEF(PrimitivePtr, kPrimMatrixDiag, std::make_shared<Primitive>("MatrixDiag"));
-GVAR_DEF(PrimitivePtr, kPrimMatrixDiagPart, std::make_shared<Primitive>("MatrixDiagPartV3"));
 GVAR_DEF(PrimitivePtr, kPrimBatchMatMul, std::make_shared<Primitive>("BatchMatMul"));
 GVAR_DEF(PrimitivePtr, kPrimBatchMatMulV2, std::make_shared<Primitive>("BatchMatMulV2"));
 GVAR_DEF(PrimitivePtr, kPrimMaximumGrad, std::make_shared<Primitive>("MaximumGrad"));
diff --git a/mindspore/core/ops/matrix_diag_part.cc b/mindspore/core/ops/matrix_diag_part.cc
deleted file mode 100644
index bea264a1ab9b..000000000000
--- a/mindspore/core/ops/matrix_diag_part.cc
+++ /dev/null
@@ -1,63 +0,0 @@
-/**
- * Copyright 2021 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "ops/matrix_diag_part.h" -#include -#include "abstract/primitive_infer_map.h" -#include "utils/check_convert_utils.h" -#include "abstract/utils.h" -#include "ops/op_utils.h" -#include "mindapi/src/helper.h" - -namespace mindspore { -namespace ops { -namespace { -abstract::ShapePtr MatrixDiagPartInferShape(const PrimitivePtr &primitive, - const std::vector &input_args) { - MS_EXCEPTION_IF_NULL(primitive); - auto input_shape = input_args[0]->BuildShape(); - auto shape_element = input_shape->cast(); - ShapeVector shape = shape_element->shape(); - ShapeVector min_shape = shape_element->shape(); - ShapeVector max_shape = shape_element->shape(); - const constexpr int64_t kShape2 = 2; - max_shape[shape.size() - 1] = kShape2 * shape[shape.size() - 1] - 1; - min_shape[shape.size() - 1] = 1; - shape[shape.size() - 1] = abstract::Shape::SHP_ANY; - return std::make_shared(shape, min_shape, max_shape); -} - -TypePtr MatrixDiagPartInferType(const PrimitivePtr &prim, const std::vector &input_args) { - for (const auto &item : input_args) { - MS_EXCEPTION_IF_NULL(item); - } - auto infer_type = input_args[0]->BuildType(); - MS_EXCEPTION_IF_NULL(infer_type); - const std::set valid_types = {kFloat16, kFloat32, kFloat64, kInt8, kInt16, kInt32, kInt64}; - CheckAndConvertUtils::CheckTensorTypeValid("input", infer_type, valid_types, prim->name()); - return infer_type; -} -} // namespace - -MIND_API_OPERATOR_IMPL(MatrixDiagPartV3, BaseOperator); -AbstractBasePtr MatrixDiagPartInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, - const std::vector &input_args) { - return abstract::MakeAbstract(MatrixDiagPartInferShape(primitive, input_args), - MatrixDiagPartInferType(primitive, input_args)); -} -REGISTER_PRIMITIVE_EVAL_IMPL(MatrixDiagPartV3, prim::kPrimMatrixDiagPart, MatrixDiagPartInfer, nullptr, true); -} // namespace ops -} // namespace mindspore diff --git a/mindspore/core/ops/matrix_diag_part_v3.cc b/mindspore/core/ops/matrix_diag_part_v3.cc new file mode 100644 index 000000000000..c1eeb345d631 --- /dev/null +++ b/mindspore/core/ops/matrix_diag_part_v3.cc @@ -0,0 +1,174 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "ops/matrix_diag_part_v3.h" +#include +#include +#include +#include +#include +#include +#include "ops/op_utils.h" +#include "utils/check_convert_utils.h" +#include "abstract/primitive_infer_map.h" +#include "abstract/param_validator.h" +#include "abstract/utils.h" +#include "mindapi/src/helper.h" + +namespace mindspore { +namespace ops { +namespace { +int64_t TrueValueCal(const std::vector &input_args) { + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape]; + auto rank = SizeToLong(x_shape.size()); + int64_t true_value = 1; + const int64_t number_two = 2; + for (int64_t i = 0; i < rank - number_two; i++) { + true_value *= x_shape[i]; + } + return true_value; +} +abstract::ShapePtr MatrixDiagPartV3InferShape(const PrimitivePtr &primitive, + const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(primitive); + auto prim_name = primitive->name(); + const int64_t kNumber1 = 1; + const int64_t kNumber2 = 2; + auto k_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape]; + auto k_rank = SizeToLong(k_shape.size()); + CheckAndConvertUtils::CheckInRange("k rank", k_rank, kIncludeBoth, {0, kNumber1}, prim_name); + auto padding_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape]; + auto padding_value_rank = SizeToLong(padding_shape.size()); + CheckAndConvertUtils::CheckInteger("padding_value rank", padding_value_rank, kEqual, 0, prim_name); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape]; + auto rank = SizeToLong(x_shape.size()); + CheckAndConvertUtils::CheckInteger("x rank", rank, kGreaterEqual, kNumber2, prim_name); + int64_t row = x_shape[rank - kNumber2]; + int64_t col = x_shape[rank - 1]; + if (input_args[kInputIndex1]->isa() && + input_args[kInputIndex1]->BuildValue()->isa()) { + auto k = input_args[kInputIndex1]->cast(); + MS_EXCEPTION_IF_NULL(k); + auto k_value_ptr = k->BuildValue(); + MS_EXCEPTION_IF_NULL(k_value_ptr); + auto k_tensor = k_value_ptr->cast(); + MS_EXCEPTION_IF_NULL(k_tensor); + auto k_val = reinterpret_cast(k_tensor->data_c()); + size_t k_val_size = LongToSize(k_tensor->DataSize()); + CheckAndConvertUtils::CheckInRange("k size", SizeToLong(k_val_size), kIncludeBoth, {kNumber1, kNumber2}, + prim_name); + if (input_args[kInputIndex2]->isa() && + input_args[kInputIndex2]->BuildValue()->isa()) { + auto padding_value = input_args[kInputIndex2]->cast(); + MS_EXCEPTION_IF_NULL(padding_value); + auto padding_value_ptr = padding_value->BuildValue(); + MS_EXCEPTION_IF_NULL(padding_value_ptr); + auto padding_value_tensor = padding_value_ptr->cast(); + MS_EXCEPTION_IF_NULL(padding_value_tensor); + size_t padding_value_size = LongToSize(padding_value_tensor->DataSize()); + CheckAndConvertUtils::CheckInteger("padding_value size", SizeToLong(padding_value_size), kEqual, kNumber1, + prim_name); + } else { + MS_EXCEPTION(TypeError) << "For " << prim_name << ", input k and padding_value must be const Tensor."; + } + std::vector out_shape; + (void)out_shape.insert(out_shape.end(), x_shape.begin(), x_shape.end() - kNumber2); + int64_t max_diag_len = 0; + int64_t true_value = TrueValueCal(input_args); + if (!(k_val[0] > -row && k_val[0] < col)) { + MS_EXCEPTION(ValueError) << "For " << prim_name << ", the value of k must be in (-x.shape[-2], x.shape[-1])," + << " meaning the value of k must be in (" << -row << ", " << col << ") in this case" + << ", but got " 
<< k_val[0] << "."; + } + if (k_val_size == 1 || k_val[0] == k_val[1]) { + max_diag_len = std::min(row + std::min(k_val[0], 0), col + std::min(-k_val[0], 0)); + out_shape.push_back(max_diag_len); + true_value *= max_diag_len; + } else { + if (!(k_val[1] > -row && k_val[1] < col)) { + MS_EXCEPTION(ValueError) << "For " << prim_name << ", the value of k must be in (-x.shape[-2], x.shape[-1])," + << " meaning the value of k must be in (" << -row << ", " << col << ") in this case" + << ", but got " << k_val[1] << "."; + } + if (!(k_val[0] <= k_val[1])) { + MS_EXCEPTION(ValueError) << "For " << prim_name << ", k[0] must not be greater than k[1]."; + } + max_diag_len = std::min(row + std::min(k_val[1], 0), col + std::min(-k_val[0], 0)); + out_shape.push_back(k_val[1] - k_val[0] + 1); + out_shape.push_back(max_diag_len); + true_value *= max_diag_len; + true_value *= (k_val[1] - k_val[0] + 1); + } + auto max_length_ptr = primitive->GetAttr("max_length"); + MS_EXCEPTION_IF_NULL(max_length_ptr); + int64_t max_value = GetValue(max_length_ptr); + if (true_value > max_value) { + MS_EXCEPTION(ValueError) << "For " << prim_name + << ", the number of elements of output must be less than max length: " << max_value + << ", but got " << true_value + << "! The shape of output should be reduced or max_length should be increased."; + } + return std::make_shared(out_shape); + } else { + ShapeVector out_shape = {-2}; + ShapeVector infer_shape_min = {0}; + int64_t max_value = (row + col) * std::max(row, col); + for (int64_t i = 0; i < rank - kNumber2; i++) { + max_value *= x_shape[i]; + } + ShapeVector infer_shape_max = {max_value}; + return std::make_shared(out_shape, infer_shape_min, infer_shape_max); + } +} + +TypePtr MatrixDiagPartV3InferType(const PrimitivePtr &prim, const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(prim); + auto prim_name = prim->name(); + + auto x = CheckAndConvertUtils::CheckArgs(prim_name, input_args, kInputIndex0); + CheckAndConvertUtils::CheckArgs(prim_name, input_args, kInputIndex1); + auto padding_value = CheckAndConvertUtils::CheckArgs(prim_name, input_args, kInputIndex2); + + (void)abstract::CheckDtypeSame(prim_name, x, padding_value); + + auto x_type = input_args[kInputIndex0]->BuildType(); + MS_EXCEPTION_IF_NULL(x_type); + (void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, common_valid_types, prim_name); + + const std::set valid_type = {kInt32}; + auto k_type = input_args[kInputIndex1]->BuildType(); + MS_EXCEPTION_IF_NULL(k_type); + (void)CheckAndConvertUtils::CheckTensorTypeValid("k", k_type, valid_type, prim_name); + + return x_type; +} +} // namespace + +MIND_API_OPERATOR_IMPL(MatrixDiagPartV3, BaseOperator); +AbstractBasePtr MatrixDiagPartV3Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(primitive); + const int64_t input_num = 3; + CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name()); + for (const auto &item : input_args) { + MS_EXCEPTION_IF_NULL(item); + } + auto infer_type = MatrixDiagPartV3InferType(primitive, input_args); + auto infer_shape = MatrixDiagPartV3InferShape(primitive, input_args); + return abstract::MakeAbstract(infer_shape, infer_type); +} +REGISTER_PRIMITIVE_EVAL_IMPL(MatrixDiagPartV3, prim::kPrimMatrixDiagPartV3, MatrixDiagPartV3Infer, nullptr, true); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/matrix_diag_part.h b/mindspore/core/ops/matrix_diag_part_v3.h similarity index 58% rename from 
mindspore/core/ops/matrix_diag_part.h rename to mindspore/core/ops/matrix_diag_part_v3.h index 33f09eb38798..3ba40a301561 100644 --- a/mindspore/core/ops/matrix_diag_part.h +++ b/mindspore/core/ops/matrix_diag_part_v3.h @@ -1,5 +1,5 @@ /** - * Copyright 2021-2022 Huawei Technologies Co., Ltd + * Copyright 2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,28 +14,34 @@ * limitations under the License. */ -#ifndef MINDSPORE_CORE_OPS_MATRIX_DIAG_PART_H_ -#define MINDSPORE_CORE_OPS_MATRIX_DIAG_PART_H_ -#include -#include +#ifndef MINDSPORE_CORE_OPS_MATRIX_DIAG_PART_V3_H_ +#define MINDSPORE_CORE_OPS_MATRIX_DIAG_PART_V3_H_ +#include +#include +#include +#include #include "ops/base_operator.h" #include "mindapi/base/types.h" namespace mindspore { namespace ops { constexpr auto kNameMatrixDiagPartV3 = "MatrixDiagPartV3"; -/// \brief get the specified part of the inner most diag matrix of a matrix, fill with padding value . -/// Refer to Python API @ref mindspore.ops.MatrixDiagPart for more details. + +/// \brief Returns the batched diagonal part of a batched tensor. +/// Refer to Python API @ref mindspore.ops.MatrixDiagPartV3 for more details. class MIND_API MatrixDiagPartV3 : public BaseOperator { public: MIND_API_BASE_MEMBER(MatrixDiagPartV3); /// \brief Constructor. - MatrixDiagPartV3() : BaseOperator(kNameMatrixDiagPartV3) { InitIOName({"input", "k", "padding_value"}, {"output"}); } + MatrixDiagPartV3() : BaseOperator(kNameMatrixDiagPartV3) { InitIOName({"x", "k", "padding_value"}, {"y"}); } }; -abstract::AbstractBasePtr MatrixDiagPartInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, - const std::vector &input_args); +abstract::AbstractBasePtr MatrixDiagPartV3Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); + +using PrimMatrixDiagPartV3Ptr = std::shared_ptr; } // namespace ops } // namespace mindspore -#endif // MINDSPORE_CORE_OPS_MATRIX_DIAG_PART_H_ + +#endif // MINDSPORE_CORE_OPS_MATRIX_DIAG_PART_V3_H_ diff --git a/mindspore/core/ops/matrix_diag_v3.cc b/mindspore/core/ops/matrix_diag_v3.cc new file mode 100644 index 000000000000..817f41c606bc --- /dev/null +++ b/mindspore/core/ops/matrix_diag_v3.cc @@ -0,0 +1,224 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
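As a quick reference for the shape rule that MatrixDiagPartV3InferShape above enforces, here is a minimal Python sketch; the helper name and the standalone example are illustrative only and are not part of this patch.

def matrix_diag_part_v3_out_shape(x_shape, k):
    # Hypothetical helper mirroring MatrixDiagPartV3InferShape; illustration only.
    rows, cols = x_shape[-2], x_shape[-1]
    k0, k1 = (k[0], k[-1]) if isinstance(k, (tuple, list)) else (k, k)
    # Both offsets must lie strictly inside (-rows, cols) and satisfy k0 <= k1.
    assert -rows < k0 <= k1 < cols
    max_diag_len = min(rows + min(k1, 0), cols + min(-k0, 0))
    num_diags = k1 - k0 + 1
    batch = list(x_shape[:-2])
    # A single diagonal drops the num_diags axis, matching the first branch of the C++ code.
    return batch + [max_diag_len] if num_diags == 1 else batch + [num_diags, max_diag_len]

# For a 3 x 4 input with k = (1, 3) this gives [3, 3], the shape seen in the
# MatrixDiagPartV3 docstring example further below in this patch.
print(matrix_diag_part_v3_out_shape((3, 4), (1, 3)))  # [3, 3]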
+ */ +#include "ops/matrix_diag_v3.h" +#include +#include +#include +#include +#include +#include +#include "ops/op_utils.h" +#include "utils/check_convert_utils.h" +#include "abstract/primitive_infer_map.h" +#include "abstract/param_validator.h" +#include "abstract/utils.h" +#include "mindapi/src/helper.h" + +namespace mindspore { +namespace ops { +namespace { +const int64_t kNumber1 = 1; +const int64_t kNumber2 = 2; +void CheckTrueValueValidAndKValue(const std::vector &input_args, int64_t row_val, int64_t col_val, + int64_t additional_value, int64_t max_value, int *k_val, size_t k_val_size) { + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape]; + auto rank = SizeToLong(x_shape.size()); + int64_t true_value = 1; + for (int64_t i = 0; i < rank - kNumber2; i++) { + true_value *= x_shape[i]; + } + true_value *= additional_value; + true_value *= (row_val * col_val); + if (true_value > max_value) { + MS_EXCEPTION(ValueError) << "For MatrixDiagV3, the number of elements of output must be less than max length: " + << max_value << ", but got " << true_value + << "! The shape of output should be reduced or max_length should be increased."; + } + if (!(k_val[0] > -row_val && k_val[0] < col_val)) { + MS_EXCEPTION(ValueError) << "For MatrixDiagV3, the value of k must be in (-num_rows, num_cols), " + << "meaning the value of k must be in (" << -row_val << ", " << col_val + << ") in this case, but got " << k_val[0] << "."; + } + if (k_val_size == kNumber2 && k_val[0] != k_val[1]) { + if (!(k_val[1] > -row_val && k_val[1] < col_val)) { + MS_EXCEPTION(ValueError) << "For MatrixDiagV3, the value of k must be in (-num_rows, num_cols), " + << "meaning the value of k must be in (" << -row_val << ", " << col_val + << ") in this case, but got " << k_val[1] << "."; + } + } +} +int64_t GetValAndCheckSize(const PrimitivePtr &primitive, const std::vector &input_args, + size_t index) { + // get value of specified input and check its size + auto prim_name = primitive->name(); + if (input_args[index]->isa() && input_args[index]->BuildValue()->isa()) { + auto abstract_tensor = input_args[index]->cast(); + MS_EXCEPTION_IF_NULL(abstract_tensor); + auto tensor_value_ptr = abstract_tensor->BuildValue(); + MS_EXCEPTION_IF_NULL(tensor_value_ptr); + auto specified_tensor = tensor_value_ptr->cast(); + MS_EXCEPTION_IF_NULL(specified_tensor); + size_t tensor_val_size = LongToSize(specified_tensor->DataSize()); + if (index == kInputIndex2) { + CheckAndConvertUtils::CheckInteger("num_rows size", SizeToLong(tensor_val_size), kEqual, kNumber1, prim_name); + } else if (index == kInputIndex3) { + CheckAndConvertUtils::CheckInteger("num_cols size", SizeToLong(tensor_val_size), kEqual, kNumber1, prim_name); + } else if (index == kInputIndex4) { + CheckAndConvertUtils::CheckInteger("padding_value size", SizeToLong(tensor_val_size), kEqual, kNumber1, + prim_name); + return 0; + } + auto tensor_ptr = reinterpret_cast(specified_tensor->data_c()); + int64_t tensor_val = static_cast(*tensor_ptr); + return tensor_val; + } else { + MS_EXCEPTION(TypeError) << "For " << prim_name + << ", input k, num_rows, num_cols and padding_value must be const Tensor."; + } +} +abstract::ShapePtr MatrixDiagV3InferShape(const PrimitivePtr &primitive, + const std::vector &input_args) { + auto prim_name = primitive->name(); // then get shape and check rank + auto k_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape]; + auto row_shape = 
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape]; + auto col_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape]; + auto padding_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex4]->BuildShape())[kShape]; + auto k_rank = SizeToLong(k_shape.size()); + auto row_rank = SizeToLong(row_shape.size()); + auto col_rank = SizeToLong(col_shape.size()); + auto padding_value_rank = SizeToLong(padding_shape.size()); + CheckAndConvertUtils::CheckInteger("num_rows rank", row_rank, kEqual, 0, prim_name); + CheckAndConvertUtils::CheckInteger("num_cols rank", col_rank, kEqual, 0, prim_name); + CheckAndConvertUtils::CheckInteger("padding_value rank", padding_value_rank, kEqual, 0, prim_name); + CheckAndConvertUtils::CheckInRange("k rank", k_rank, kIncludeBoth, {0, kNumber1}, prim_name); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape]; + auto rank = SizeToLong(x_shape.size()); + CheckAndConvertUtils::CheckInteger("x rank", rank, kGreaterEqual, kNumber1, prim_name); + int64_t max_diag_len = x_shape[rank - 1]; + auto max_length_ptr = primitive->GetAttr("max_length"); + MS_EXCEPTION_IF_NULL(max_length_ptr); + int64_t max_value = GetValue(max_length_ptr); + if (input_args[kInputIndex1]->isa() && + input_args[kInputIndex1]->BuildValue()->isa()) { + auto k = input_args[kInputIndex1]->cast(); // get k value and check its size + MS_EXCEPTION_IF_NULL(k); + auto k_value_ptr = k->BuildValue(); + MS_EXCEPTION_IF_NULL(k_value_ptr); + auto k_tensor = k_value_ptr->cast(); + MS_EXCEPTION_IF_NULL(k_tensor); + auto k_val = reinterpret_cast(k_tensor->data_c()); + size_t k_val_size = LongToSize(k_tensor->DataSize()); + CheckAndConvertUtils::CheckInRange("k size", SizeToLong(k_val_size), kIncludeBoth, {kNumber1, kNumber2}, + prim_name); + int64_t row_val = GetValAndCheckSize(primitive, input_args, kInputIndex2); // get row value and check its size + int64_t col_val = GetValAndCheckSize(primitive, input_args, kInputIndex3); // get col value and check its size + (void)GetValAndCheckSize(primitive, input_args, kInputIndex4); // check size of padding_value + std::vector out_shape; // calculate out_shape + int64_t min_num_rows, min_num_cols; + int64_t additional_value = 1; + if (k_val_size == 1 || k_val[0] == k_val[1]) { + min_num_rows = max_diag_len - std::min(k_val[0], 0); + min_num_cols = max_diag_len + std::max(k_val[0], 0); + (void)out_shape.insert(out_shape.end(), x_shape.begin(), x_shape.end() - 1); + additional_value = x_shape[rank - kNumber2]; + } else { + if (!(k_val[0] <= k_val[1])) + MS_EXCEPTION(ValueError) << "For " << prim_name << ", k[0] must not be greater than k[1]."; + int64_t num_diags = k_val[1] - k_val[0] + 1; + CheckAndConvertUtils::CheckInteger("x rank", rank, kGreaterEqual, kNumber2, prim_name); + if (x_shape[rank - kNumber2] != num_diags) + MS_EXCEPTION(ValueError) << "For " << prim_name << ", the input x_shape[-2] doesn't match with k value."; + min_num_rows = max_diag_len - std::min(k_val[1], 0); + min_num_cols = max_diag_len + std::max(k_val[0], 0); + (void)out_shape.insert(out_shape.end(), x_shape.begin(), x_shape.end() - kNumber2); + } + if (row_val != -1 && row_val < min_num_rows) + MS_EXCEPTION(ValueError) << "For " << prim_name << ", the number of rows is too small."; + if (col_val != -1 && col_val < min_num_cols) + MS_EXCEPTION(ValueError) << "For " << prim_name << ", the number of columns is too small."; + if 
(row_val == -1 && col_val == -1) { + row_val = std::max(min_num_rows, min_num_cols); + col_val = row_val; + } else if (row_val == -1) { + row_val = min_num_rows; + } else if (col_val == -1) { + col_val = min_num_cols; + } + if (!(row_val == min_num_rows || col_val == min_num_cols)) + MS_EXCEPTION(ValueError) << "For " << prim_name << ", the number of rows or columns is not consistent with " + << "the specified k and x."; + CheckTrueValueValidAndKValue(input_args, row_val, col_val, additional_value, max_value, k_val, k_val_size); + out_shape.push_back(row_val); + out_shape.push_back(col_val); + return std::make_shared(out_shape); + } else { + ShapeVector out_shape = {-2}; + ShapeVector infer_shape_min = {0}; + ShapeVector infer_shape_max = {max_value}; + return std::make_shared(out_shape, infer_shape_min, infer_shape_max); + } +} + +TypePtr MatrixDiagV3InferType(const PrimitivePtr &prim, const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(prim); + auto prim_name = prim->name(); + + auto x = CheckAndConvertUtils::CheckArgs(prim_name, input_args, kInputIndex0); + CheckAndConvertUtils::CheckArgs(prim_name, input_args, kInputIndex1); + CheckAndConvertUtils::CheckArgs(prim_name, input_args, kInputIndex2); + CheckAndConvertUtils::CheckArgs(prim_name, input_args, kInputIndex3); + auto padding_value = CheckAndConvertUtils::CheckArgs(prim_name, input_args, kInputIndex4); + + (void)abstract::CheckDtypeSame(prim_name, x, padding_value); + + auto x_type = input_args[kInputIndex0]->BuildType(); + MS_EXCEPTION_IF_NULL(x_type); + (void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, common_valid_types, prim_name); + + const std::set valid_type = {kInt32}; + + auto k_type = input_args[kInputIndex1]->BuildType(); + MS_EXCEPTION_IF_NULL(k_type); + (void)CheckAndConvertUtils::CheckTensorTypeValid("k", k_type, valid_type, prim_name); + + auto row_type = input_args[kInputIndex2]->BuildType(); + MS_EXCEPTION_IF_NULL(row_type); + (void)CheckAndConvertUtils::CheckTensorTypeValid("num_rows", row_type, valid_type, prim_name); + + auto col_type = input_args[kInputIndex3]->BuildType(); + MS_EXCEPTION_IF_NULL(col_type); + (void)CheckAndConvertUtils::CheckTensorTypeValid("num_cols", col_type, valid_type, prim_name); + + return x_type; +} +} // namespace + +MIND_API_OPERATOR_IMPL(MatrixDiagV3, BaseOperator); +AbstractBasePtr MatrixDiagV3Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(primitive); + const int64_t input_num = 5; + CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name()); + for (const auto &item : input_args) { + MS_EXCEPTION_IF_NULL(item); + } + auto infer_type = MatrixDiagV3InferType(primitive, input_args); + auto infer_shape = MatrixDiagV3InferShape(primitive, input_args); + return abstract::MakeAbstract(infer_shape, infer_type); +} +REGISTER_PRIMITIVE_EVAL_IMPL(MatrixDiagV3, prim::kPrimMatrixDiagV3, MatrixDiagV3Infer, nullptr, true); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/matrix_diag_v3.h b/mindspore/core/ops/matrix_diag_v3.h new file mode 100644 index 000000000000..b3e4c95c13e0 --- /dev/null +++ b/mindspore/core/ops/matrix_diag_v3.h @@ -0,0 +1,49 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
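The resolution of num_rows and num_cols when they are passed as -1, implemented in MatrixDiagV3InferShape above, can be summarized by this minimal, hypothetical sketch (the function name is illustrative, not part of this patch):

def resolve_rows_cols(max_diag_len, k0, k1, num_rows=-1, num_cols=-1):
    # Hypothetical mirror of the -1 handling in MatrixDiagV3InferShape; illustration only.
    min_num_rows = max_diag_len - min(k1, 0)
    min_num_cols = max_diag_len + max(k0, 0)
    if num_rows == -1 and num_cols == -1:
        num_rows = num_cols = max(min_num_rows, min_num_cols)
    elif num_rows == -1:
        num_rows = min_num_rows
    elif num_cols == -1:
        num_cols = min_num_cols
    # At least one of the two must match its minimum, as the C++ consistency check requires.
    assert num_rows == min_num_rows or num_cols == min_num_cols
    return num_rows, num_cols

# Diagonals of length 3 with k = (-1, 1) resolve to a 3 x 3 output by default,
# consistent with the 3 x 3 case in the MatrixDiagV3 docstring example later in this patch.
print(resolve_rows_cols(3, -1, 1))  # (3, 3)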
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_MATRIX_DIAG_V3_H_ +#define MINDSPORE_CORE_OPS_MATRIX_DIAG_V3_H_ + +#include +#include +#include +#include +#include "ops/base_operator.h" +#include "mindapi/base/types.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameMatrixDiagV3 = "MatrixDiagV3"; + +/// \brief Returns a batched diagonal tensor with given batched diagonal values. +/// Refer to Python API @ref mindspore.ops.MatrixDiagV3 for more details. +class MIND_API MatrixDiagV3 : public BaseOperator { + public: + MIND_API_BASE_MEMBER(MatrixDiagV3); + /// \brief Constructor. + MatrixDiagV3() : BaseOperator(kNameMatrixDiagV3) { + InitIOName({"x", "k", "num_rows", "num_cols", "padding_value"}, {"y"}); + } +}; + +abstract::AbstractBasePtr MatrixDiagV3Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); + +using PrimMatrixDiagV3Ptr = std::shared_ptr; +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_MATRIX_DIAG_V3_H_ diff --git a/mindspore/core/ops/matrix_set_diag_v3.cc b/mindspore/core/ops/matrix_set_diag_v3.cc new file mode 100644 index 000000000000..aadadb7fca71 --- /dev/null +++ b/mindspore/core/ops/matrix_set_diag_v3.cc @@ -0,0 +1,182 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ops/matrix_set_diag_v3.h" +#include +#include +#include +#include +#include +#include +#include "ops/op_utils.h" +#include "utils/check_convert_utils.h" +#include "abstract/primitive_infer_map.h" +#include "abstract/param_validator.h" +#include "abstract/utils.h" +#include "mindapi/src/helper.h" + +namespace mindspore { +namespace ops { +namespace { +void TrueValueCalAndCheck(const std::vector &input_args, int64_t max_value) { + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape]; + auto rank = SizeToLong(x_shape.size()); + int64_t true_value = 1; + for (int64_t i = 0; i < rank; i++) { + true_value *= x_shape[i]; + } + if (true_value > max_value) { + MS_EXCEPTION(ValueError) << "For MatrixSetDiagV3" + << ", the number of elements of output must be less than max length: " << max_value + << ", but got " << true_value + << "! 
The shape of output should be reduced or max_length should be increased."; + } +} +abstract::ShapePtr MatrixSetDiagV3InferShape(const PrimitivePtr &primitive, + const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(primitive); + auto prim_name = primitive->name(); + const int64_t kNumber2 = 2; + const int64_t kNumber1 = 1; + auto k_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape]; + auto k_rank = SizeToLong(k_shape.size()); + CheckAndConvertUtils::CheckInRange("k rank", k_rank, kIncludeBoth, {0, kNumber1}, prim_name); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape]; + auto rank = SizeToLong(x_shape.size()); + CheckAndConvertUtils::CheckInteger("x rank", rank, kGreaterEqual, kNumber2, prim_name); + auto diagonal_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape]; + auto diagonal_rank = SizeToLong(diagonal_shape.size()); + auto max_length_ptr = primitive->GetAttr("max_length"); + MS_EXCEPTION_IF_NULL(max_length_ptr); + int64_t max_value = GetValue(max_length_ptr); + TrueValueCalAndCheck(input_args, max_value); + if (input_args[kInputIndex2]->isa() && + input_args[kInputIndex2]->BuildValue()->isa()) { + int64_t row = x_shape[rank - kNumber2]; + int64_t col = x_shape[rank - 1]; + auto k = input_args[kInputIndex2]->cast(); + MS_EXCEPTION_IF_NULL(k); + auto k_value_ptr = k->BuildValue(); + MS_EXCEPTION_IF_NULL(k_value_ptr); + auto k_tensor = k_value_ptr->cast(); + MS_EXCEPTION_IF_NULL(k_tensor); + auto k_val = reinterpret_cast(k_tensor->data_c()); + size_t k_val_size = LongToSize(k_tensor->DataSize()); + CheckAndConvertUtils::CheckInRange("k size", SizeToLong(k_val_size), kIncludeBoth, {kNumber1, kNumber2}, + prim_name); + int64_t max_diag_len = 0; + CheckAndConvertUtils::CheckInteger("diagonal rank", diagonal_rank, kGreaterEqual, kNumber1, prim_name); + int64_t last_shape_diagonal = diagonal_shape[diagonal_rank - 1]; + if (!(k_val[0] > -row && k_val[0] < col)) { + MS_EXCEPTION(ValueError) << "For " << prim_name << ", the value of k must be in (-x.shape[-2], x.shape[-1])," + << " meaning the value of k must be in (" << -row << ", " << col << ") in this case" + << ", but got " << k_val[0] << "."; + } + if (k_val_size == 1 || k_val[0] == k_val[1]) { + if (SizeToLong(diagonal_rank) != rank - 1) { + MS_EXCEPTION(ValueError) << "For " << prim_name << ", diagonal rank size don't match with x rank size."; + } + for (int64_t i = 0; i < rank - kNumber2; i++) { + if (diagonal_shape[i] != x_shape[i]) + MS_EXCEPTION(ValueError) << "For " << prim_name << ", diagonal shape value don't match with x shape value."; + } + max_diag_len = std::min(row + std::min(k_val[0], 0), col + std::min(-k_val[0], 0)); + } else { + if (!(k_val[1] > -row && k_val[1] < col)) { + MS_EXCEPTION(ValueError) << "For " << prim_name << ", the value of k must be in (-x.shape[-2], x.shape[-1])," + << " meaning the value of k must be in (" << -row << ", " << col << ") in this case" + << ", but got " << k_val[1] << "."; + } + if (!(k_val[0] <= k_val[1])) { + MS_EXCEPTION(ValueError) << "For " << prim_name << ", k[0] must not be greater than k[1]."; + } + if (SizeToLong(diagonal_rank) != rank) { + MS_EXCEPTION(ValueError) << "For " << prim_name << ", diagonal rank size don't match with x rank size."; + } + for (int64_t i = 0; i < rank - kNumber2; i++) { + if (diagonal_shape[i] != x_shape[i]) + MS_EXCEPTION(ValueError) << "For " << prim_name << ", diagonal shape value 
don't match with x shape value."; + } + max_diag_len = std::min(row + std::min(k_val[1], 0), col + std::min(-k_val[0], 0)); + int64_t in_row_diagonal = diagonal_shape[diagonal_rank - kNumber2]; + int64_t num_diags = k_val[1] - k_val[0] + 1; + if (num_diags != in_row_diagonal) { + MS_EXCEPTION(ValueError) << "For " << prim_name + << ", diagonal.shape[-2] is not equal to num_diags calculated by k[1] - k[0] + 1, " + << "which value is " << num_diags + << " in this case, but got diagonal.shape[-2]: " << in_row_diagonal + << " in this case."; + } + } + if (max_diag_len != last_shape_diagonal) { + MS_EXCEPTION(ValueError) << "For " << prim_name << ", diagonal.shape[-1] is not equal to " + << "max_diag_len calculated by min(x.shape[-2] + min(k[1], 0), x.shape[-1] + " + << "min(-k[0], 0)), which value is " << max_diag_len + << " in this case, but got diagonal.shape[-1]: " << last_shape_diagonal + << " in this case."; + } + return std::make_shared(x_shape); + } else { + ShapeVector out_shape; + ShapeVector infer_shape_min; + ShapeVector infer_shape_max; + (void)infer_shape_max.insert(infer_shape_max.end(), x_shape.begin(), x_shape.end()); + for (int64_t i = 0; i < rank; i++) { + out_shape.push_back(-1); + infer_shape_min.push_back(0); + } + return std::make_shared(out_shape, infer_shape_min, infer_shape_max); + } +} + +TypePtr MatrixSetDiagV3InferType(const PrimitivePtr &prim, const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(prim); + auto prim_name = prim->name(); + + auto x = CheckAndConvertUtils::CheckArgs(prim_name, input_args, kInputIndex0); + auto diagonal = CheckAndConvertUtils::CheckArgs(prim_name, input_args, kInputIndex1); + CheckAndConvertUtils::CheckArgs(prim_name, input_args, kInputIndex2); + + (void)abstract::CheckDtypeSame(prim_name, x, diagonal); + + auto x_type = input_args[kInputIndex0]->BuildType(); + MS_EXCEPTION_IF_NULL(x_type); + (void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, common_valid_types, prim_name); + + const std::set valid_type = {kInt32}; + auto k_type = input_args[kInputIndex2]->BuildType(); + MS_EXCEPTION_IF_NULL(k_type); + (void)CheckAndConvertUtils::CheckTensorTypeValid("k", k_type, valid_type, prim_name); + + return x_type; +} +} // namespace + +MIND_API_OPERATOR_IMPL(MatrixSetDiagV3, BaseOperator); +AbstractBasePtr MatrixSetDiagV3Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(primitive); + const int64_t input_num = 3; + CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name()); + for (const auto &item : input_args) { + MS_EXCEPTION_IF_NULL(item); + } + auto infer_type = MatrixSetDiagV3InferType(primitive, input_args); + auto infer_shape = MatrixSetDiagV3InferShape(primitive, input_args); + return abstract::MakeAbstract(infer_shape, infer_type); +} +REGISTER_PRIMITIVE_EVAL_IMPL(MatrixSetDiagV3, prim::kPrimMatrixSetDiagV3, MatrixSetDiagV3Infer, nullptr, true); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/matrix_set_diag_v3.h b/mindspore/core/ops/matrix_set_diag_v3.h new file mode 100644 index 000000000000..5c745ca827b8 --- /dev/null +++ b/mindspore/core/ops/matrix_set_diag_v3.h @@ -0,0 +1,47 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
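The diagonal-shape validation performed by MatrixSetDiagV3InferShape above amounts to the rule sketched below; the helper is hypothetical and only illustrates the check, it is not part of this patch.

def expected_diagonal_shape(x_shape, k0, k1):
    # Hypothetical mirror of the checks in MatrixSetDiagV3InferShape; illustration only.
    rows, cols = x_shape[-2], x_shape[-1]
    max_diag_len = min(rows + min(k1, 0), cols + min(-k0, 0))
    num_diags = k1 - k0 + 1
    batch = list(x_shape[:-2])
    # The diagonal input must carry one row per extracted diagonal plus the diagonal length;
    # a single diagonal drops the num_diags axis.
    return batch + [max_diag_len] if num_diags == 1 else batch + [num_diags, max_diag_len]

# A 3 x 4 input with k = (-1, 2) expects a diagonal of shape [4, 3], which is the
# diagonal used in the MatrixSetDiagV3 docstring example later in this patch.
print(expected_diagonal_shape((3, 4), -1, 2))  # [4, 3]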
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_MATRIX_SET_DIAG_V3_H_ +#define MINDSPORE_CORE_OPS_MATRIX_SET_DIAG_V3_H_ + +#include +#include +#include +#include +#include "ops/base_operator.h" +#include "mindapi/base/types.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameMatrixSetDiagV3 = "MatrixSetDiagV3"; + +/// \brief Returns a batched matrix tensor with new batched diagonal values. +/// Refer to Python API @ref mindspore.ops.MatrixSetDiagV3 for more details. +class MIND_API MatrixSetDiagV3 : public BaseOperator { + public: + MIND_API_BASE_MEMBER(MatrixSetDiagV3); + /// \brief Constructor. + MatrixSetDiagV3() : BaseOperator(kNameMatrixSetDiagV3) { InitIOName({"x", "diagonal", "k"}, {"y"}); } +}; + +abstract::AbstractBasePtr MatrixSetDiagV3Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); + +using PrimMatrixSetDiagV3Ptr = std::shared_ptr; +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_MATRIX_SET_DIAG_V3_H_ diff --git a/mindspore/python/mindspore/ops/_grad_experimental/grad_array_ops.py b/mindspore/python/mindspore/ops/_grad_experimental/grad_array_ops.py index e84b6d0aa409..cb13c9ffd683 100644 --- a/mindspore/python/mindspore/ops/_grad_experimental/grad_array_ops.py +++ b/mindspore/python/mindspore/ops/_grad_experimental/grad_array_ops.py @@ -15,14 +15,19 @@ """array_ops""" +from mindspore import Tensor from ...common import dtype as mstype from .._grad.grad_math_ops import binop_grad_common from .._grad.grad_base import bprop_getters from ..composite.multitype_ops.zeros_like_impl import zeros_like from ..operations.array_ops import Tril +from ..operations.array_ops import MatrixDiagV3 +from ..operations.array_ops import MatrixDiagPartV3 +from ..operations.array_ops import MatrixSetDiagV3 from ..operations.array_ops import Triu from .. import functional as F from .. 
import operations as P +from .._utils.utils import is_shape_unknown @bprop_getters.register(P.MaskedFill) @@ -61,6 +66,85 @@ def get_bprop_tensor_scatter_sub(self): return bprop +@bprop_getters.register(MatrixDiagV3) +def get_bprop_matrix_diag_v3(self): + """Generate bprop for MatrixDiagV3""" + align = self.align + matrix_diag_part_v3 = MatrixDiagPartV3(align=align) + zeros = P.Zeros() + + def bprop(x, k, num_rows, num_cols, padding_value, out, dout): + result = (matrix_diag_part_v3(dout, k, zeros((), dout.dtype)), zeros_like(k), zeros_like(num_rows), + zeros_like(num_cols), zeros_like(padding_value)) + return result + + return bprop + + +@bprop_getters.register(MatrixDiagPartV3) +def get_bprop_matrix_diag_part_v3(self): + """Generate bprop for MatrixDiagPartV3""" + align = self.align + matrix_diag_v3 = MatrixDiagV3(align=align) + matrix_set_diag_v3 = MatrixSetDiagV3(align=align) + zeros = P.Zeros() + + def bprop(x, k, padding_value, out, dout): + shape_this = P.Shape()(x)[-2:] + if not is_shape_unknown(shape_this): + row = shape_this[0] + col = shape_this[1] + result = (matrix_diag_v3(dout, k, Tensor(row, dtype=mstype.int32), Tensor(col, dtype=mstype.int32), + zeros((), dout.dtype)), zeros_like(k), zeros_like(padding_value)) + else: + result = (matrix_set_diag_v3(zeros_like(x), dout, k), zeros_like(k), zeros_like(padding_value)) + return result + + return bprop + + +@bprop_getters.register(MatrixSetDiagV3) +def get_bprop_matrix_set_diag_v3(self): + """Generate bprop for MatrixSetDiagV3""" + align = self.align + matrix_diag_part_v3 = MatrixDiagPartV3(align=align) + matrix_set_diag_v3 = MatrixSetDiagV3(align=align) + resha = P.Reshape() + zeros = P.Zeros() + minimum = P.Minimum() + concat = P.Concat() + + def bprop(x, diagonal, k, out, dout): + diagonal_cal = matrix_diag_part_v3(dout, k, zeros((), dout.dtype)) + + diagonal_shape = P.Shape()(diagonal) + if is_shape_unknown(diagonal_shape): + shape_dout = P.Shape()(dout) + pre_shape = shape_dout[:-2] + back_shape = shape_dout[-2:] + + site_dia = resha(k, (-1)) + index_min = -1 * site_dia[0] + index_max = site_dia[-1] + col = 0 + if index_max < 0: + col = index_max + row = 0 + if index_min < 0: + row = index_min + max_diag_len = minimum(back_shape[0] + col, back_shape[1] + row) + + back = [max_diag_len] + if index_max != index_min: + back = [index_max-index_min+1, max_diag_len] + diagonal_shape = concat([pre_shape, back]) + x_cal = matrix_set_diag_v3(dout, zeros(diagonal_shape, dout.dtype), k) + + return x_cal, diagonal_cal, zeros_like(k) + + return bprop + + def tensor_scatter_possible_replacement(x, indices, updates, out, dout): """bpropr for any TensorScatter* op that possibly replaces values in the input tensor""" gather_nd = P.GatherNd() diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py index 6ab85c3f874b..bb44ad40d65e 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py @@ -109,6 +109,9 @@ from .stack_push_pop import _stack_pop_aicpu from .asinh import _asinh_aicpu from .asinh_grad import _asinh_grad_aicpu from .stack_push_pop import _stack_destroy_aicpu +from .matrix_diag_v3 import _matrix_diag_v3_aicpu +from .matrix_diag_part_v3 import _matrix_diag_part_v3_aicpu +from .matrix_set_diag_v3 import _matrix_set_diag_v3_aicpu from .ctc_greedy_decoder import _ctc_greedy_decoder_aicpu from .resize_bilinear import _resize_bilinear_aicpu from .resize_bilinear_grad import 
_resize_bilinear_grad_aicpu diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/matrix_diag_part_v3.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/matrix_diag_part_v3.py new file mode 100644 index 000000000000..d087fbe307ef --- /dev/null +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/matrix_diag_part_v3.py @@ -0,0 +1,54 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""MatrixDiagPartV3 op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +matrix_diag_part_v3_op_info = AiCPURegOp("MatrixDiagPartV3") \ + .fusion_type("OPAQUE") \ + .input(0, "x", "required") \ + .input(1, "k", "required") \ + .input(2, "padding_value", "required") \ + .output(0, "y", "required") \ + .attr("align", "str") \ + .dtype_format(DataType.I8_Default, DataType.I32_Default, + DataType.I8_Default, DataType.I8_Default) \ + .dtype_format(DataType.U8_Default, DataType.I32_Default, + DataType.U8_Default, DataType.U8_Default) \ + .dtype_format(DataType.I16_Default, DataType.I32_Default, + DataType.I16_Default, DataType.I16_Default) \ + .dtype_format(DataType.U16_Default, DataType.I32_Default, + DataType.U16_Default, DataType.U16_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, + DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.U32_Default, DataType.I32_Default, + DataType.U32_Default, DataType.U32_Default) \ + .dtype_format(DataType.I64_Default, DataType.I32_Default, + DataType.I64_Default, DataType.I64_Default) \ + .dtype_format(DataType.U64_Default, DataType.I32_Default, + DataType.U64_Default, DataType.U64_Default) \ + .dtype_format(DataType.F16_Default, DataType.I32_Default, + DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.I32_Default, + DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F64_Default, DataType.I32_Default, + DataType.F64_Default, DataType.F64_Default) \ + .get_op_info() + + +@op_info_register(matrix_diag_part_v3_op_info) +def _matrix_diag_part_v3_aicpu(): + """MatrixDiagPartV3 AiCPU register""" + return diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/matrix_diag_v3.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/matrix_diag_v3.py new file mode 100644 index 000000000000..5f51682272aa --- /dev/null +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/matrix_diag_v3.py @@ -0,0 +1,56 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""MatrixDiagV3 op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +matrix_diag_v3_op_info = AiCPURegOp("MatrixDiagV3") \ + .fusion_type("OPAQUE") \ + .input(0, "x", "required") \ + .input(1, "k", "required") \ + .input(2, "num_rows", "required") \ + .input(3, "num_cols", "required") \ + .input(4, "padding_value", "required") \ + .output(0, "y", "required") \ + .attr("align", "str") \ + .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, + DataType.I8_Default, DataType.I8_Default) \ + .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, + DataType.U8_Default, DataType.U8_Default) \ + .dtype_format(DataType.I16_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, + DataType.I16_Default, DataType.I16_Default) \ + .dtype_format(DataType.U16_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, + DataType.U16_Default, DataType.U16_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, + DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.U32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, + DataType.U32_Default, DataType.U32_Default) \ + .dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, + DataType.I64_Default, DataType.I64_Default) \ + .dtype_format(DataType.U64_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, + DataType.U64_Default, DataType.U64_Default) \ + .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, + DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, + DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, + DataType.F64_Default, DataType.F64_Default) \ + .get_op_info() + + +@op_info_register(matrix_diag_v3_op_info) +def _matrix_diag_v3_aicpu(): + """MatrixDiagV3 AiCPU register""" + return diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/matrix_set_diag_v3.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/matrix_set_diag_v3.py new file mode 100644 index 000000000000..cef120d5f1d1 --- /dev/null +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/matrix_set_diag_v3.py @@ -0,0 +1,54 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
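A note on the AiCPU registrations in this patch: each dtype_format row lists one dtype per declared input, in declaration order, followed by the output dtype, and k (plus num_rows and num_cols for MatrixDiagV3) is always int32. As a hedged illustration only, the float32 row of the MatrixSetDiagV3 registration below reads as:

# Illustration only, not part of the patch: the float32 dtype_format row of MatrixSetDiagV3
# maps to y(float32) = MatrixSetDiagV3(x: float32, diagonal: float32, k: int32).
f32_row = {"x": "float32", "diagonal": "float32", "k": "int32", "y": "float32"}
print(f32_row)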
+# ============================================================================ + +"""MatrixSetDiagV3 op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +matrix_set_diag_v3_op_info = AiCPURegOp("MatrixSetDiagV3") \ + .fusion_type("OPAQUE") \ + .input(0, "x", "required") \ + .input(1, "diagonal", "required") \ + .input(2, "k", "required") \ + .output(0, "y", "required") \ + .attr("align", "str") \ + .dtype_format(DataType.I8_Default, DataType.I8_Default, + DataType.I32_Default, DataType.I8_Default) \ + .dtype_format(DataType.U8_Default, DataType.U8_Default, + DataType.I32_Default, DataType.U8_Default) \ + .dtype_format(DataType.I16_Default, DataType.I16_Default, + DataType.I32_Default, DataType.I16_Default) \ + .dtype_format(DataType.U16_Default, DataType.U16_Default, + DataType.I32_Default, DataType.U16_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, + DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.U32_Default, DataType.U32_Default, + DataType.I32_Default, DataType.U32_Default) \ + .dtype_format(DataType.I64_Default, DataType.I64_Default, + DataType.I32_Default, DataType.I64_Default) \ + .dtype_format(DataType.U64_Default, DataType.U64_Default, + DataType.I32_Default, DataType.U64_Default) \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, + DataType.I32_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, + DataType.I32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F64_Default, DataType.F64_Default, + DataType.I32_Default, DataType.F64_Default) \ + .get_op_info() + + +@op_info_register(matrix_set_diag_v3_op_info) +def _matrix_set_diag_v3_aicpu(): + """MatrixSetDiagV3 AiCPU register""" + return diff --git a/mindspore/python/mindspore/ops/operations/array_ops.py b/mindspore/python/mindspore/ops/operations/array_ops.py index e36be17c5964..63c960e1f863 100755 --- a/mindspore/python/mindspore/ops/operations/array_ops.py +++ b/mindspore/python/mindspore/ops/operations/array_ops.py @@ -1327,6 +1327,239 @@ class Size(PrimitiveWithInfer): return out +class MatrixDiagV3(Primitive): + r""" + Returns a batched diagonal tensor with given batched diagonal values. + Returns a tensor with the contents in x as k[0]-th to k[1]-th diagonals of a matrix, with everything else padded + with padding_value. num_rows and num_cols specify the dimension of the innermost matrix of the output. Some + diagonals are shorter than max_diag_len and need to be padded. At least one of the num_rows and num_cols is equal to + the calculated value as below. Input k, num_rows, num_cols and padding_value must be const Tensor when taking Graph + mode. + + Args: + align (string): An optional string from: "RIGHT_LEFT"(default), "LEFT_RIGHT", "LEFT_LEFT", "RIGHT_RIGHT". Align + is a string specifying how superdiagonals and subdiagonals should be aligned, respectively. "RIGHT_LEFT" + aligns superdiagonals to the right (left-pads the row) and subdiagonals to the left (right-pads the row). + + Inputs: + - **x** (Tensor) - The diagonal tensor. Rank r, where r >= 1. And its rank must be greater equal than 2 if k + have two values. Moreover, x.shape[-2] must be equal to num_diags calculated by k[1] - k[0] + 1 when its rank + is greater than 1. + - **k** (Tensor) - A Tensor of type int32. Diagonal offset(s). Positive value means superdiagonal, 0 refers to + the main diagonal, and negative value means subdiagonals. 
k can be a single integer (for a single diagonal) or + a pair of integers specifying the low and high ends of a matrix band. k[0] must not be larger than k[1]. The + value must be in the range of given or calculated num_rows and num_cols, meaning value of k must be in + (-num_rows, num_cols). + - **num_rows** (Tensor) - A Tensor of type int32. The number of rows of the output matrix. It can be -1 to + indicate that num_rows should be calculated by other inputs. There must be only one value. And it can be + calculated by x.shape[-1] - min(k[1], 0) when specifying num_rows as -1. Moreover, the value must be greater + equal than x.shape[-1] - min(k[1], 0) when its value is not -1. + - **num_cols** (Tensor) - A Tensor of type int32. The number of columns of the output matrix. It can be -1 to + indicate that num_cols should be calculated by other inputs. There must be only one value. And it can be + calculated by x.shape[-1] + max(k[0], 0) when specifying num_cols as -1. Moreover, the value must be greater + equal than x.shape[-1] + max(k[0], 0) when its value is not -1. + - **padding_value** (Tensor) - A Tensor. Have the same dtype as x. The number to fill the area outside the + specified diagonal band with. There must be only one value. + + Outputs: + A Tensor. Has the same type as x. + Let x have r dimensions [I, J, ..., L, M, N]. The output tensor has rank r + 1 with shape + [I, J, ..., L, M, num_rows, num_cols] when only one diagonal is given (k is an integer or k[0] == k[1]). + Otherwise, it has rank r with shape [I, J, ..., L, num_rows, num_cols]. + + Raises: + TypeError: If any input is not Tensor. + TypeError: If input `x` and `padding_value` are not the same dtype. + TypeError: If `k`, `num_rows` or `num_cols` is not int32 dtype. + ValueError: If `align` is not a string or not in the valid range. + ValueError: If rank of `num_rows`, `num_cols` or `padding_value` is not equal to 0. + ValueError: If rank of `k` is not equal to 0 or 1. + ValueError: If rank of `x` is not greater equal to 1. Or the rank of `x` is not greater equal to 2 in case the + size of `k` is 2. + ValueError: If size of `k` is not equal to 1 or 2. + ValueError: If k[1] is not greater equal to k[0] in case the size of `k` is 2. + ValueError: If the number of rows or columns is too small. + ValueError: If the number of rows or columns is not consistent with the specified `k` and `x`. + ValueError: If the value of `k` is not in (-num_rows, num_cols). + ValueError: If the x.shape[-2] is not equal to num_diags calculated by k[1] - k[0] + 1. + + Supported Platforms: + ``Ascend`` ``CPU`` + + Examples: + >>> x = Tensor(np.array([[8, 9, 0], + ... [1, 2, 3], + ... [0, 4, 5]]), mindspore.float32) + >>> k =Tensor(np.array([-1, 1]), mindspore.int32) + >>> num_rows = Tensor(np.array(3), mindspore.int32) + >>> num_cols = Tensor(np.array(3), mindspore.int32) + >>> padding_value = Tensor(np.array(11), mindspore.float32) + >>> matrix_diag_v3 = ops.MatrixDiagV3(align='LEFT_RIGHT') + >>> output = matrix_diag_v3(x, k, num_rows, num_cols, padding_value) + >>> print(output) + [[ 1. 8. 11.] + [ 4. 2. 9.] + [11. 5. 
 3.]]
+        >>> print(output.shape)
+        (3, 3)
+    """
+
+    @prim_attr_register
+    def __init__(self, align="RIGHT_LEFT"):
+        """Initialize MatrixDiagV3"""
+        self.add_prim_attr("max_length", 200000000)
+        validator.check_value_type("align", align, [str], self.name)
+        validator.check_string(align, ['LEFT_RIGHT', 'RIGHT_LEFT', 'LEFT_LEFT', 'RIGHT_RIGHT'], 'align', self.name)
+        self.init_prim_io_names(inputs=['x', 'k', 'num_rows', 'num_cols', 'padding_value'], outputs=['y'])
+
+
+class MatrixDiagPartV3(Primitive):
+    r"""
+    Returns the batched diagonal part of a batched tensor.
+    Returns a tensor with the k[0]-th to k[1]-th diagonals of the batched x. Some diagonals are shorter than
+    max_diag_len and need to be padded. Inputs k and padding_value must be const Tensors when running in graph mode.
+
+    Args:
+        align (string): An optional string from: "RIGHT_LEFT"(default), "LEFT_RIGHT", "LEFT_LEFT", "RIGHT_RIGHT".
+            Align is a string specifying how superdiagonals and subdiagonals should be aligned, respectively.
+            "RIGHT_LEFT" aligns superdiagonals to the right (left-pads the row) and subdiagonals to the left
+            (right-pads the row).
+
+    Inputs:
+        - **x** (Tensor) - Rank r, where r >= 2.
+        - **k** (Tensor) - A Tensor of type int32. Diagonal offset(s). A positive value means superdiagonal, 0 refers
+          to the main diagonal, and a negative value means subdiagonals. k can be a single integer (for a single
+          diagonal) or a pair of integers specifying the low and high ends of a matrix band. k[0] must not be larger
+          than k[1]. The value of k has restrictions, meaning the value of k must be in (-x.shape[-2], x.shape[-1]).
+        - **padding_value** (Tensor) - A Tensor with the same dtype as x. The number to fill the area outside the
+          specified diagonal band with. There must be only one value.
+
+    Outputs:
+        A Tensor. Has the same type as x.
+        Assume x has r dimensions [I, J, ..., L, M, N]. Let max_diag_len be the maximum length among all
+        diagonals to be extracted, max_diag_len = min(M + min(k[1], 0), N + min(-k[0], 0)).
+        Let num_diags be the number of diagonals to extract, num_diags = k[1] - k[0] + 1.
+        If num_diags == 1, the output tensor is of rank r - 1 with shape [I, J, ..., L, max_diag_len].
+        Otherwise, the output tensor has rank r with dimensions [I, J, ..., L, num_diags, max_diag_len].
+
+    Raises:
+        TypeError: If any input is not a Tensor.
+        TypeError: If inputs `x` and `padding_value` are not the same dtype.
+        TypeError: If `k` is not int32 dtype.
+        ValueError: If `align` is not a string or not in the valid range.
+        ValueError: If rank of `k` is not equal to 0 or 1.
+        ValueError: If rank of `padding_value` is not equal to 0.
+        ValueError: If rank of `x` is not greater than or equal to 2.
+        ValueError: If size of `k` is not equal to 1 or 2.
+        ValueError: If k[1] is less than k[0] in case the size of `k` is 2.
+        ValueError: If the value of `k` is not in (-x.shape[-2], x.shape[-1]).
+
+    Supported Platforms:
+        ``Ascend`` ``CPU``
+
+    Examples:
+        >>> x = Tensor(np.array([[1, 2, 3, 4],
+        ...                      [5, 6, 7, 8],
+        ...                      [9, 8, 7, 6]]), mindspore.float32)
+        >>> k = Tensor(np.array([1, 3]), mindspore.int32)
+        >>> padding_value = Tensor(np.array(9), mindspore.float32)
+        >>> matrix_diag_part_v3 = ops.MatrixDiagPartV3(align='RIGHT_LEFT')
+        >>> output = matrix_diag_part_v3(x, k, padding_value)
+        >>> print(output)
+        [[9. 9. 4.]
+         [9. 3. 8.]
+         [2. 7. 6.]]
+        >>> print(output.shape)
+        (3, 3)
+    """
+
+    @prim_attr_register
+    def __init__(self, align="RIGHT_LEFT"):
+        """Initialize MatrixDiagPartV3"""
+        self.add_prim_attr("max_length", 200000000)
+        validator.check_value_type("align", align, [str], self.name)
+        validator.check_string(align, ['LEFT_RIGHT', 'RIGHT_LEFT', 'LEFT_LEFT', 'RIGHT_RIGHT'], 'align', self.name)
+        self.init_prim_io_names(inputs=['x', 'k', 'padding_value'], outputs=['y'])
+
+
+class MatrixSetDiagV3(Primitive):
+    r"""
+    Returns a batched matrix tensor with new batched diagonal values.
+    Given x and diagonal, this operation returns a tensor with the same shape and values as x, except for the
+    specified diagonals of the innermost matrices. These will be overwritten by the values in diagonal. Some
+    diagonals are shorter than max_diag_len and need to be padded.
+    The diagonal.shape[-2] must be equal to num_diags calculated by k[1] - k[0] + 1. The diagonal.shape[-1] must be
+    equal to the longest diagonal value max_diag_len calculated by min(x.shape[-2] + min(k[1], 0), x.shape[-1] +
+    min(-k[0], 0)). Let x have r + 1 dimensions [I, J, ..., L, M, N]. The diagonal tensor has rank r with shape
+    [I, J, ..., L, max_diag_len] when k is an integer or k[0] == k[1]. Otherwise, it has rank r + 1 with shape
+    [I, J, ..., L, num_diags, max_diag_len].
+
+    Args:
+        align (string): An optional string from: "RIGHT_LEFT"(default), "LEFT_RIGHT", "LEFT_LEFT", "RIGHT_RIGHT".
+            Align is a string specifying how superdiagonals and subdiagonals should be aligned, respectively.
+            "RIGHT_LEFT" aligns superdiagonals to the right (left-pads the row) and subdiagonals to the left
+            (right-pads the row).
+
+    Inputs:
+        - **x** (Tensor) - Rank r + 1, where r >= 1.
+        - **diagonal** (Tensor) - A Tensor with the same dtype as x. Rank r when k is an integer or k[0] == k[1].
+          Otherwise, it has rank r + 1.
+        - **k** (Tensor) - A Tensor of type int32. Diagonal offset(s). A positive value means superdiagonal, 0 refers
+          to the main diagonal, and a negative value means subdiagonals. k can be a single integer (for a single
+          diagonal) or a pair of integers specifying the low and high ends of a matrix band. k[0] must not be larger
+          than k[1]. The value of k has restrictions, meaning the value of k must be in (-x.shape[-2], x.shape[-1]).
+          Input k must be a const Tensor when running in graph mode.
+
+    Outputs:
+        A Tensor. Has the same type as x.
+        Let x have r + 1 dimensions [I, J, ..., L, M, N].
+        The output is a tensor of rank r + 1 with dimensions [I, J, ..., L, M, N], the same as input x.
+
+    Raises:
+        TypeError: If any input is not a Tensor.
+        TypeError: If inputs `x` and `diagonal` are not the same dtype.
+        TypeError: If `k` is not int32 dtype.
+        ValueError: If `align` is not a string or not in the valid range.
+        ValueError: If rank of `k` is not equal to 0 or 1.
+        ValueError: If rank of `x` is not greater than or equal to 2.
+        ValueError: If size of `k` is not equal to 1 or 2.
+        ValueError: If k[1] is less than k[0] in case the size of `k` is 2.
+        ValueError: If the rank of `diagonal` does not match the rank of input `x`.
+        ValueError: If the shape of `diagonal` does not match the shape of input `x`.
+        ValueError: If diagonal.shape[-2] is not equal to num_diags calculated by k[1] - k[0] + 1.
+        ValueError: If the value of `k` is not in (-x.shape[-2], x.shape[-1]).
+        ValueError: If diagonal.shape[-1] is not equal to the max_diag_len calculated by
+          min(x.shape[-2] + min(k[1], 0), x.shape[-1] + min(-k[0], 0)).
+ + Supported Platforms: + ``Ascend`` ``CPU`` + + Examples: + >>> x = Tensor(np.array([[7, 7, 7, 7], + ... [7, 7, 7, 7], + ... [7, 7, 7, 7]]), mindspore.float32) + >>> diagonal = Tensor(np.array([[0, 9, 1], + ... [6, 5, 8], + ... [1, 2, 3], + ... [4, 5, 0]]), mindspore.float32) + >>> k =Tensor(np.array([-1, 2]), mindspore.int32) + >>> matrix_set_diag_v3 = ops.MatrixSetDiagV3(align='RIGHT_LEFT') + >>> output = matrix_set_diag_v3(x, diagonal, k) + >>> print(output) + [[1. 6. 9. 7.] + [4. 2. 5. 1.] + [7. 5. 3. 8.]] + >>> print(output.shape) + (3, 4) + """ + + @prim_attr_register + def __init__(self, align="RIGHT_LEFT"): + """"Initialize MatrixSetDiagV3""" + self.add_prim_attr("max_length", 200000000) + validator.check_value_type("align", align, [str], self.name) + validator.check_string(align, ['LEFT_RIGHT', 'RIGHT_LEFT', 'LEFT_LEFT', 'RIGHT_RIGHT'], 'align', self.name) + self.init_prim_io_names(inputs=['x', 'diagonal', 'k'], outputs=['y']) + + class Fill(PrimitiveWithInfer): """ Create a Tensor of the specified shape and fill it with the specified value. diff --git a/tests/st/scipy_st/matrix_diag_part_test.py b/tests/st/scipy_st/matrix_diag_part_test.py deleted file mode 100644 index e152ab148d3e..000000000000 --- a/tests/st/scipy_st/matrix_diag_part_test.py +++ /dev/null @@ -1,262 +0,0 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
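The scipy-based system test below is removed by this patch. A minimal, hypothetical sanity check against the new MatrixDiagPartV3 primitive, reusing the values from its docstring example above, could look like the following (assumes a pytest-style module and a build in which the op is registered):

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

def test_matrix_diag_part_v3_basic():
    """Hypothetical sanity check mirroring the MatrixDiagPartV3 docstring example."""
    x = Tensor(np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 8, 7, 6]]), ms.float32)
    k = Tensor(np.array([1, 3]), ms.int32)
    padding_value = Tensor(np.array(9), ms.float32)
    output = ops.MatrixDiagPartV3(align='RIGHT_LEFT')(x, k, padding_value)
    expected = np.array([[9., 9., 4.], [9., 3., 8.], [2., 7., 6.]], np.float32)
    assert np.allclose(output.asnumpy(), expected)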
-# ============================================================================ -"""st for scipy.ops_wrapper.""" -import numpy -import pytest -import mindspore.scipy as msp -from mindspore import context, Tensor -from mindspore import dtype -from tests.st.scipy_st.utils import match_array - -aligndict = {0: "LEFT_RIGHT", 1: "LEFT_LEFT", 2: "RIGHT_LEFT", 3: "RIGHT_RIGHT"} -PAD_VALUE = -1 - - -@pytest.mark.level0 -@pytest.mark.platform_x86_cpu -@pytest.mark.platform_x86_gpu_training -@pytest.mark.env_onecard -@pytest.mark.parametrize('array_dict', [([[[5]]], {}), - ([[[3, 1, 1], [6, 4, 4], [1, 6, 4]]], - {(-2, -2, 0): [[1]], (-2, -1, 3): [[[6, 6], [-1, 1]]], - (-2, 0, 2): [[[3, 4, 4], [6, 6, -1], [1, -1, -1]]], - (-2, 1, 3): [[[-1, 1, 4], [3, 4, 4], [-1, 6, 6], [-1, -1, 1]]], - (-2, 2, 0): [[[1, -1, -1], [1, 4, -1], [3, 4, 4], [-1, 6, 6], [-1, -1, 1]]], - (-1, -1, 2): [[6, 6]], (-1, 0, 1): [[[3, 4, 4], [6, 6, -1]]], - (-1, 1, 2): [[[-1, 1, 4], [3, 4, 4], [6, 6, -1]]], - (-1, 2, 3): [[[-1, -1, 1], [-1, 1, 4], [3, 4, 4], [-1, 6, 6]]], - (0, 0, 0): [[3, 4, 4]], (0, 1, 1): [[[1, 4, -1], [3, 4, 4]]], - (0, 2, 2): [[[-1, -1, 1], [-1, 1, 4], [3, 4, 4]]], (1, 1, 2): [[1, 4]], - (1, 2, 3): [[[-1, 1], [1, 4]]]}), - ([[[6, 1]]], {}), - ([[[2, 2, 4, 3, 0], [8, 5, 3, 0, 3], [6, 3, 2, 6, 7]]], - {(-2, -2, 0): [[6]], (-2, -1, 3): [[[8, 3], [-1, 6]]], - (-2, 0, 2): [[[2, 5, 2], [8, 3, -1], [6, -1, -1]]], - (-2, 1, 3): [[[2, 3, 6], [2, 5, 2], [-1, 8, 3], [-1, -1, 6]]], - (-2, 2, 0): [[[4, 0, 7], [2, 3, 6], [2, 5, 2], [-1, 8, 3], [-1, -1, 6]]], - (-2, 3, 1): [ - [[3, 3, -1], [4, 0, 7], [2, 3, 6], [2, 5, 2], [8, 3, -1], [6, -1, -1]]], - (-2, 4, 2): [ - [[-1, -1, 0], [-1, 3, 3], [4, 0, 7], [2, 3, 6], [2, 5, 2], [8, 3, -1], - [6, -1, -1]]], (-1, -1, 2): [[8, 3]], - (-1, 0, 1): [[[2, 5, 2], [8, 3, -1]]], - (-1, 1, 2): [[[2, 3, 6], [2, 5, 2], [8, 3, -1]]], - (-1, 2, 3): [[[4, 0, 7], [2, 3, 6], [2, 5, 2], [-1, 8, 3]]], - (-1, 3, 0): [[[3, 3, -1], [4, 0, 7], [2, 3, 6], [2, 5, 2], [-1, 8, 3]]], - (-1, 4, 1): [ - [[0, -1, -1], [3, 3, -1], [4, 0, 7], [2, 3, 6], [2, 5, 2], [8, 3, -1]]], - (0, 0, 0): [[2, 5, 2]], (0, 1, 1): [[[2, 3, 6], [2, 5, 2]]], - (0, 2, 2): [[[4, 0, 7], [2, 3, 6], [2, 5, 2]]], - (0, 3, 3): [[[-1, 3, 3], [4, 0, 7], [2, 3, 6], [2, 5, 2]]], - (0, 4, 0): [[[0, -1, -1], [3, 3, -1], [4, 0, 7], [2, 3, 6], [2, 5, 2]]], - (1, 1, 2): [[2, 3, 6]], (1, 2, 3): [[[4, 0, 7], [2, 3, 6]]], - (1, 3, 0): [[[3, 3, -1], [4, 0, 7], [2, 3, 6]]], - (1, 4, 1): [[[0, -1, -1], [3, 3, -1], [4, 0, 7], [2, 3, 6]]]}), - ([[[5], [5]]], {(-1, -1, 2): [[5]], (-1, 0, 1): [[[5], [5]]], - (0, 0, 0): [[5]]}), - ([[[2, 4, 1], [6, 4, 1], [0, 5, 2], [1, 6, 0], [1, 0, 7]]], - {(-4, -4, 0): [[1]], (-4, -3, 3): [[[1, 0], [-1, 1]]], - (-4, -2, 2): [[[0, 6, 7], [1, 0, -1], [1, -1, -1]]], - (-4, -1, 1): [[[6, 5, 0], [0, 6, 7], [1, 0, -1], [1, -1, -1]]], - (-4, 0, 0): [[[2, 4, 2], [6, 5, 0], [0, 6, 7], [-1, 1, 0], [-1, -1, 1]]], - (-4, 1, 1): [ - [[4, 1, -1], [2, 4, 2], [6, 5, 0], [0, 6, 7], [1, 0, -1], [1, -1, -1]]], - (-4, 2, 2): [ - [[-1, -1, 1], [-1, 4, 1], [2, 4, 2], [6, 5, 0], [0, 6, 7], [1, 0, -1], - [1, -1, -1]]], (-3, -3, 2): [[1, 0]], - (-3, -2, 1): [[[0, 6, 7], [1, 0, -1]]], - (-3, -1, 0): [[[6, 5, 0], [0, 6, 7], [-1, 1, 0]]], - (-3, 0, 3): [[[2, 4, 2], [6, 5, 0], [0, 6, 7], [-1, 1, 0]]], - (-3, 1, 0): [[[4, 1, -1], [2, 4, 2], [6, 5, 0], [0, 6, 7], [-1, 1, 0]]], - (-3, 2, 1): [ - [[1, -1, -1], [4, 1, -1], [2, 4, 2], [6, 5, 0], [0, 6, 7], [1, 0, -1]]], - (-2, -2, 0): [[0, 6, 7]], (-2, -1, 3): [[[6, 5, 0], [0, 6, 7]]], - (-2, 0, 2): 
[[[2, 4, 2], [6, 5, 0], [0, 6, 7]]], - (-2, 1, 3): [[[-1, 4, 1], [2, 4, 2], [6, 5, 0], [0, 6, 7]]], - (-2, 2, 0): [[[1, -1, -1], [4, 1, -1], [2, 4, 2], [6, 5, 0], [0, 6, 7]]], - (-1, -1, 2): [[6, 5, 0]], (-1, 0, 1): [[[2, 4, 2], [6, 5, 0]]], - (-1, 1, 2): [[[-1, 4, 1], [2, 4, 2], [6, 5, 0]]], - (-1, 2, 3): [[[-1, -1, 1], [-1, 4, 1], [2, 4, 2], [6, 5, 0]]], - (0, 0, 0): [[2, 4, 2]], (0, 1, 1): [[[4, 1, -1], [2, 4, 2]]], - (0, 2, 2): [[[-1, -1, 1], [-1, 4, 1], [2, 4, 2]]], (1, 1, 2): [[4, 1]], - (1, 2, 3): [[[-1, 1], [4, 1]]], (2, 2, 0): [[1]]}), - ([[[6]], [[4]]], {}), - ([[[2, 4, 8], [3, 4, 2], [1, 6, 3]], [[6, 7, 2], [8, 2, 1], [4, 5, 5]]], - {(-2, -2, 0): [[1], [4]], (-2, -1, 3): [[[3, 6], [-1, 1]], [[8, 5], [-1, 4]]], - (-2, 0, 2): [[[2, 4, 3], [3, 6, -1], [1, -1, -1]], - [[6, 2, 5], [8, 5, -1], [4, -1, -1]]], - (-2, 1, 3): [[[-1, 4, 2], [2, 4, 3], [-1, 3, 6], [-1, -1, 1]], - [[-1, 7, 1], [6, 2, 5], [-1, 8, 5], [-1, -1, 4]]], - (-2, 2, 0): [[[8, -1, -1], [4, 2, -1], [2, 4, 3], [-1, 3, 6], [-1, -1, 1]], - [[2, -1, -1], [7, 1, -1], [6, 2, 5], [-1, 8, 5], [-1, -1, 4]]], - (-1, -1, 2): [[3, 6], [8, 5]], - (-1, 0, 1): [[[2, 4, 3], [3, 6, -1]], [[6, 2, 5], [8, 5, -1]]], - (-1, 1, 2): [[[-1, 4, 2], [2, 4, 3], [3, 6, -1]], - [[-1, 7, 1], [6, 2, 5], [8, 5, -1]]], - (-1, 2, 3): [[[-1, -1, 8], [-1, 4, 2], [2, 4, 3], [-1, 3, 6]], - [[-1, -1, 2], [-1, 7, 1], [6, 2, 5], [-1, 8, 5]]], - (0, 0, 0): [[2, 4, 3], [6, 2, 5]], - (0, 1, 1): [[[4, 2, -1], [2, 4, 3]], [[7, 1, -1], [6, 2, 5]]], - (0, 2, 2): [[[-1, -1, 8], [-1, 4, 2], [2, 4, 3]], - [[-1, -1, 2], [-1, 7, 1], [6, 2, 5]]], - (1, 1, 2): [[4, 2], [7, 1]], - (1, 2, 3): [[[-1, 8], [4, 2]], [[-1, 2], [7, 1]]]}), - ([[[4, 0]], [[7, 4]]], {}), - ([[[3, 5, 8, 3, 5], [7, 8, 1, 0, 6], [5, 4, 0, 3, 6]], - [[7, 4, 8, 7, 3], [4, 6, 5, 7, 1], [5, 3, 1, 1, 0]]], - {(-2, -2, 0): [[5], [5]], (-2, -1, 3): [[[7, 4], [-1, 5]], [[4, 3], [-1, 5]]], - (-2, 0, 2): [[[3, 8, 0], [7, 4, -1], [5, -1, -1]], - [[7, 6, 1], [4, 3, -1], [5, -1, -1]]], - (-2, 1, 3): [[[5, 1, 3], [3, 8, 0], [-1, 7, 4], [-1, -1, 5]], - [[4, 5, 1], [7, 6, 1], [-1, 4, 3], [-1, -1, 5]]], - (-2, 2, 0): [[[8, 0, 6], [5, 1, 3], [3, 8, 0], [-1, 7, 4], [-1, -1, 5]], - [[8, 7, 0], [4, 5, 1], [7, 6, 1], [-1, 4, 3], [-1, -1, 5]]], - (-2, 3, 1): [ - [[3, 6, -1], [8, 0, 6], [5, 1, 3], [3, 8, 0], [7, 4, -1], [5, -1, -1]], - [[7, 1, -1], [8, 7, 0], [4, 5, 1], [7, 6, 1], [4, 3, -1], [5, -1, -1]]], - (-2, 4, 2): [ - [[-1, -1, 5], [-1, 3, 6], [8, 0, 6], [5, 1, 3], [3, 8, 0], [7, 4, -1], - [5, -1, -1]], - [[-1, -1, 3], [-1, 7, 1], [8, 7, 0], [4, 5, 1], [7, 6, 1], [4, 3, -1], - [5, -1, -1]]], - (-1, -1, 2): [[7, 4], [4, 3]], - (-1, 0, 1): [[[3, 8, 0], [7, 4, -1]], [[7, 6, 1], [4, 3, -1]]], - (-1, 1, 2): [[[5, 1, 3], [3, 8, 0], [7, 4, -1]], - [[4, 5, 1], [7, 6, 1], [4, 3, -1]]], - (-1, 2, 3): [[[8, 0, 6], [5, 1, 3], [3, 8, 0], [-1, 7, 4]], - [[8, 7, 0], [4, 5, 1], [7, 6, 1], [-1, 4, 3]]], - (-1, 3, 0): [[[3, 6, -1], [8, 0, 6], [5, 1, 3], [3, 8, 0], [-1, 7, 4]], - [[7, 1, -1], [8, 7, 0], [4, 5, 1], [7, 6, 1], [-1, 4, 3]]], - (-1, 4, 1): [ - [[5, -1, -1], [3, 6, -1], [8, 0, 6], [5, 1, 3], [3, 8, 0], [7, 4, -1]], - [[3, -1, -1], [7, 1, -1], [8, 7, 0], [4, 5, 1], [7, 6, 1], [4, 3, -1]]], - (0, 0, 0): [[3, 8, 0], [7, 6, 1]], - (0, 1, 1): [[[5, 1, 3], [3, 8, 0]], [[4, 5, 1], [7, 6, 1]]], - (0, 2, 2): [[[8, 0, 6], [5, 1, 3], [3, 8, 0]], - [[8, 7, 0], [4, 5, 1], [7, 6, 1]]], - (0, 3, 3): [[[-1, 3, 6], [8, 0, 6], [5, 1, 3], [3, 8, 0]], - [[-1, 7, 1], [8, 7, 0], [4, 5, 1], [7, 6, 1]]], - (0, 4, 0): [[[5, -1, -1], [3, 6, -1], [8, 
0, 6], [5, 1, 3], [3, 8, 0]], - [[3, -1, -1], [7, 1, -1], [8, 7, 0], [4, 5, 1], [7, 6, 1]]], - (1, 1, 2): [[5, 1, 3], [4, 5, 1]], - (1, 2, 3): [[[8, 0, 6], [5, 1, 3]], [[8, 7, 0], [4, 5, 1]]], - (1, 3, 0): [[[3, 6, -1], [8, 0, 6], [5, 1, 3]], - [[7, 1, -1], [8, 7, 0], [4, 5, 1]]], - (1, 4, 1): [[[5, -1, -1], [3, 6, -1], [8, 0, 6], [5, 1, 3]], - [[3, -1, -1], [7, 1, -1], [8, 7, 0], [4, 5, 1]]]}), - ([[[4], [7]], [[3], [5]]], - {(-1, -1, 2): [[7], [5]], (-1, 0, 1): [[[4], [7]], [[3], [5]]], - (0, 0, 0): [[4], [3]]}), - ([[[0, 2, 2], [0, 0, 5], [6, 5, 5], [5, 8, 5], [3, 8, 0]], - [[2, 8, 3], [4, 4, 1], [0, 4, 2], [0, 7, 0], [0, 7, 4]]], - {(-4, -4, 0): [[3], [0]], (-4, -3, 3): [[[5, 8], [-1, 3]], [[0, 7], [-1, 0]]], - (-4, -2, 2): [[[6, 8, 0], [5, 8, -1], [3, -1, -1]], - [[0, 7, 4], [0, 7, -1], [0, -1, -1]]], - (-4, -1, 1): [[[0, 5, 5], [6, 8, 0], [5, 8, -1], [3, -1, -1]], - [[4, 4, 0], [0, 7, 4], [0, 7, -1], [0, -1, -1]]], - (-4, 0, 0): [[[0, 0, 5], [0, 5, 5], [6, 8, 0], [-1, 5, 8], [-1, -1, 3]], - [[2, 4, 2], [4, 4, 0], [0, 7, 4], [-1, 0, 7], [-1, -1, 0]]], - (-4, 1, 1): [ - [[2, 5, -1], [0, 0, 5], [0, 5, 5], [6, 8, 0], [5, 8, -1], [3, -1, -1]], - [[8, 1, -1], [2, 4, 2], [4, 4, 0], [0, 7, 4], [0, 7, -1], [0, -1, -1]]], - (-4, 2, 2): [ - [[-1, -1, 2], [-1, 2, 5], [0, 0, 5], [0, 5, 5], [6, 8, 0], [5, 8, -1], - [3, -1, -1]], - [[-1, -1, 3], [-1, 8, 1], [2, 4, 2], [4, 4, 0], [0, 7, 4], [0, 7, -1], - [0, -1, -1]]], (-3, -3, 2): [[5, 8], [0, 7]], - (-3, -2, 1): [[[6, 8, 0], [5, 8, -1]], [[0, 7, 4], [0, 7, -1]]], - (-3, -1, 0): [[[0, 5, 5], [6, 8, 0], [-1, 5, 8]], - [[4, 4, 0], [0, 7, 4], [-1, 0, 7]]], - (-3, 0, 3): [[[0, 0, 5], [0, 5, 5], [6, 8, 0], [-1, 5, 8]], - [[2, 4, 2], [4, 4, 0], [0, 7, 4], [-1, 0, 7]]], - (-3, 1, 0): [[[2, 5, -1], [0, 0, 5], [0, 5, 5], [6, 8, 0], [-1, 5, 8]], - [[8, 1, -1], [2, 4, 2], [4, 4, 0], [0, 7, 4], [-1, 0, 7]]], - (-3, 2, 1): [ - [[2, -1, -1], [2, 5, -1], [0, 0, 5], [0, 5, 5], [6, 8, 0], [5, 8, -1]], - [[3, -1, -1], [8, 1, -1], [2, 4, 2], [4, 4, 0], [0, 7, 4], [0, 7, -1]]], - (-2, -2, 0): [[6, 8, 0], [0, 7, 4]], - (-2, -1, 3): [[[0, 5, 5], [6, 8, 0]], [[4, 4, 0], [0, 7, 4]]], - (-2, 0, 2): [[[0, 0, 5], [0, 5, 5], [6, 8, 0]], - [[2, 4, 2], [4, 4, 0], [0, 7, 4]]], - (-2, 1, 3): [[[-1, 2, 5], [0, 0, 5], [0, 5, 5], [6, 8, 0]], - [[-1, 8, 1], [2, 4, 2], [4, 4, 0], [0, 7, 4]]], - (-2, 2, 0): [[[2, -1, -1], [2, 5, -1], [0, 0, 5], [0, 5, 5], [6, 8, 0]], - [[3, -1, -1], [8, 1, -1], [2, 4, 2], [4, 4, 0], [0, 7, 4]]], - (-1, -1, 2): [[0, 5, 5], [4, 4, 0]], - (-1, 0, 1): [[[0, 0, 5], [0, 5, 5]], [[2, 4, 2], [4, 4, 0]]], - (-1, 1, 2): [[[-1, 2, 5], [0, 0, 5], [0, 5, 5]], - [[-1, 8, 1], [2, 4, 2], [4, 4, 0]]], - (-1, 2, 3): [[[-1, -1, 2], [-1, 2, 5], [0, 0, 5], [0, 5, 5]], - [[-1, -1, 3], [-1, 8, 1], [2, 4, 2], [4, 4, 0]]], - (0, 0, 0): [[0, 0, 5], [2, 4, 2]], - (0, 1, 1): [[[2, 5, -1], [0, 0, 5]], [[8, 1, -1], [2, 4, 2]]], - (0, 2, 2): [[[-1, -1, 2], [-1, 2, 5], [0, 0, 5]], - [[-1, -1, 3], [-1, 8, 1], [2, 4, 2]]], - (1, 1, 2): [[2, 5], [8, 1]], - (1, 2, 3): [[[-1, 2], [2, 5]], [[-1, 3], [8, 1]]], (2, 2, 0): [[2], [3]]})]) -def test_matrix_diag_part(array_dict): - """ - testcase generate from below - from tf.python.ops import array_ops - import numpy as np - f = open (r'dict.tst','w') - aligndict = {0: "LEFT_RIGHT", 1:"LEFT_LEFT", 2:"RIGHT_LEFT", 3:"RIGHT_RIGHT"} - Adict=[] - for i in [1, 2]: - for m,n in [(1, 1), (3,3),(1, 2),(3, 5),(2, 1),(5, 3)]: - A = np.array(np.random.randint(20, size=(i, m, n))) - kadict={} - for k0 in range(-m + 1, m - 1): - for k1 in range(k0, n): - k 
= (k0, k1) - align_= (abs(k0)+ abs(k1)) % 4 - ka = (k,align_) - B = array_ops.matrix_diag_part(A, k=k, align=aligndict[align_], padding_value=-1) - kadict[ka] = B.numpy() - Adict.append(A, kadict) - print(Adict, file= f) - f.close() - Feature: ALL To ALL - Description: - Expectation: the result match to numpy - """ - context.set_context(mode=context.PYNATIVE_MODE) - a, kadict = array_dict - for key1, b in kadict.items(): - k0, k1, align_ = key1 - if k0 == k1: - r_b = msp.ops_wrapper.matrix_diag_part(Tensor(a), k0, PAD_VALUE, align=aligndict.get(align_)) - else: - r_b = msp.ops_wrapper.matrix_diag_part(Tensor(a), (k0, k1), PAD_VALUE, align=aligndict.get(align_)) - match_array(b, r_b.asnumpy()) - - -def test_matrix_diag_part_valid(): - """ - test case for pad different type - Description: test cases for default/none default padding value, if padding value type not eq to a, - will raise exception - Expectation: the result match to numpy - """ - context.set_context(mode=context.PYNATIVE_MODE) - a = [[1, 2, 3], [3, 4, 5], [4, 5, 6]] - padding_value = 0 - k = [-1, 0] - b = msp.ops_wrapper.matrix_diag_part(Tensor(a).astype(dtype.float32), k, padding_value, align="LEFT_RIGHT") - match_array(b.asnumpy(), numpy.array([[1.0, 4.0, 6.0], [0.0, 3.0, 5.0]]).astype(numpy.float32)) - b = msp.ops_wrapper.matrix_diag_part(Tensor(a).astype(dtype.int32), k, padding_value, align="LEFT_RIGHT") - match_array(b.asnumpy(), numpy.array([[1, 4, 6], [0, 3, 5]]).astype(numpy.int32)) - b = msp.ops_wrapper.matrix_diag_part(Tensor(a).astype(dtype.float32), k, padding_value=1.1, align="LEFT_RIGHT") - match_array(b.asnumpy(), numpy.array([[1.0, 4.0, 6.0], [1.1, 3.0, 5.0]]).astype(numpy.float32)) diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 5410cd953dc0..764db6cfbc50 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -35,6 +35,9 @@ from mindspore.ops.operations import nn_ops as nps from mindspore.ops.operations.array_ops import Tril from mindspore.ops.operations.random_ops import NonDeterministicInts from mindspore.ops.operations.array_ops import Triu +from mindspore.ops.operations.array_ops import MatrixDiagV3 +from mindspore.ops.operations.array_ops import MatrixDiagPartV3 +from mindspore.ops.operations.array_ops import MatrixSetDiagV3 from mindspore.ops.operations.nn_ops import FractionalMaxPool from mindspore.ops.operations._grad_ops import FractionalMaxPoolGrad from mindspore.nn.layer import normalization @@ -1066,6 +1069,40 @@ class ApplyAdagradDANet(nn.Cell): return out +class MatrixDiagV3Net(nn.Cell): + def __init__(self, k, num_rows, num_cols, padding_value, align='LEFT_RIGHT'): + super(MatrixDiagV3Net, self).__init__() + self.k = k + self.num_rows = num_rows + self.num_cols = num_cols + self.padding_value = padding_value + self.matrix_diag_v3 = MatrixDiagV3(align=align) + + def construct(self, x, k, num_rows, num_cols, padding_value): + return self.matrix_diag_v3(x, self.k, self.num_rows, self.num_cols, self.padding_value) + + +class MatrixDiagPartV3Net(nn.Cell): + def __init__(self, k, padding_value, align='LEFT_RIGHT'): + super(MatrixDiagPartV3Net, self).__init__() + self.k = k + self.padding_value = padding_value + self.matrix_diag_dart_v3 = MatrixDiagPartV3(align=align) + + def construct(self, x, k, padding_value): + return self.matrix_diag_dart_v3(x, self.k, self.padding_value) + + +class MatrixSetDiagV3Net(nn.Cell): + def __init__(self, k, align='LEFT_RIGHT'): + super(MatrixSetDiagV3Net, self).__init__() + self.k = k + 
self.matrix_set_diag_v3 = MatrixSetDiagV3(align=align) + + def construct(self, x, diagonal, k): + return self.matrix_set_diag_v3(x, diagonal, self.k) + + class SparseApplyRMSPropNet(nn.Cell): def __init__(self, rho, momentum, epsilon, use_locking=False): super(SparseApplyRMSPropNet, self).__init__() @@ -2807,6 +2844,69 @@ test_case_array_ops = [ Tensor(np.arange(-12, 0).reshape(3, 2, 2), mstype.float32)], 'skip': ['backward'], }), + ('MatrixDiagV3', { + 'block': MatrixDiagV3Net(k=Tensor(np.array([-1, 1]), mstype.int32), num_rows=Tensor(np.array(3), mstype.int32) + , num_cols=Tensor(np.array(3), mstype.int32), + padding_value=Tensor(np.array(11), mstype.float32), align='LEFT_RIGHT'), + 'desc_inputs': [Tensor(np.array([[[8, 9, 0], + [1, 2, 3], + [0, 4, 5]], + [[2, 3, 0], + [6, 7, 9], + [0, 9, 1]]]), mstype.float32), + Tensor(np.array([-1, 1]), mstype.int32), + Tensor(np.array(3), mstype.int32), + Tensor(np.array(3), mstype.int32), + Tensor(np.array(11), mstype.float32)], + 'desc_bprop': [(Tensor(np.array([[[1, 8, 11], + [4, 2, 9], + [11, 5, 3]], + [[6, 2, 11], + [9, 7, 3], + [11, 1, 9]]]), mstype.float32))], + }), + ('MatrixDiagPartV3', { + 'block': MatrixDiagPartV3Net(k=Tensor(np.array([1, 3]), mstype.int32), + padding_value=Tensor(np.array(9), mstype.float32), align='RIGHT_LEFT'), + 'desc_inputs': [Tensor(np.array([[[1, 2, 3, 4], + [5, 6, 7, 8], + [9, 8, 7, 6]], + [[5, 4, 3, 2], + [1, 2, 3, 4], + [5, 6, 7, 8]]]), mstype.float32), + Tensor(np.array([1, 3]), mstype.int32), + Tensor(np.array(9), mstype.float32)], + 'desc_bprop': [(Tensor(np.array([[[9, 9, 4], + [9, 3, 8], + [2, 7, 6]], + [[9, 9, 2], + [9, 3, 4], + [4, 3, 8]]]), mstype.float32))], + }), + ('MatrixSetDiagV3', { + 'block': MatrixSetDiagV3Net(k=Tensor(np.array([-1, 2]), mstype.int32), align='RIGHT_LEFT'), + 'desc_inputs': [Tensor(np.array([[[7, 7, 7, 7], + [7, 7, 7, 7], + [7, 7, 7, 7]], + [[7, 7, 7, 7], + [7, 7, 7, 7], + [7, 7, 7, 7]]]), mstype.float32), + Tensor(np.array([[[0, 9, 1], + [6, 5, 8], + [1, 2, 3], + [4, 5, 0]], + [[0, 1, 2], + [5, 6, 4], + [6, 1, 2], + [3, 4, 0]]]), mstype.float32), + Tensor(np.array([-1, 2]), mstype.int32)], + 'desc_bprop': [(Tensor(np.array([[[1, 6, 9, 7], + [4, 2, 5, 1], + [7, 5, 3, 8]], + [[6, 5, 1, 7], + [3, 1, 6, 2], + [7, 4, 2, 4]]]), mstype.float32))], + }), ('TransShape', { 'block': P.TransShape(), 'desc_const': [(1, 12, 24, 24)], -- Gitee
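For reference, the MatrixSetDiagV3 example used in the docstring and in the 'MatrixSetDiagV3' UT entry above
(k = [-1, 2], align = 'RIGHT_LEFT', i.e. superdiagonals right-aligned and subdiagonals left-aligned inside the
packed `diagonal` input) can be reproduced with a minimal NumPy sketch. The helper name `matrix_set_diag_ref`
and its loop structure are illustrative only and are not part of the operator implementation:

import numpy as np

def matrix_set_diag_ref(x, diagonal, k, align="RIGHT_LEFT"):
    """Write the packed diagonals `diagonal` (band k[0]..k[1]) into a copy of `x` (sketch, not the kernel)."""
    num_rows, num_cols = x.shape
    k_lo, k_hi = int(k[0]), int(k[1])
    max_diag_len = diagonal.shape[-1]
    sup_align, sub_align = align.split("_")        # e.g. "RIGHT", "LEFT"
    out = x.copy()
    for i in range(k_hi - k_lo + 1):
        d = k_hi - i                               # row i of `diagonal` packs diagonal d
        diag_len = min(num_rows + min(0, d), num_cols - max(0, d))
        right_aligned = (sup_align == "RIGHT") if d >= 0 else (sub_align == "RIGHT")
        offset = max_diag_len - diag_len if right_aligned else 0
        for j in range(diag_len):                  # walk along diagonal d
            r, c = (j, j + d) if d >= 0 else (j - d, j)
            out[r, c] = diagonal[i, offset + j]
    return out

x = np.full((3, 4), 7.0)
diagonal = np.array([[0, 9, 1], [6, 5, 8], [1, 2, 3], [4, 5, 0]], dtype=float)
print(matrix_set_diag_ref(x, diagonal, (-1, 2)))
# [[1. 6. 9. 7.]
#  [4. 2. 5. 1.]
#  [7. 5. 3. 8.]]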