Merge pull request !29352 from wuwenbing/masterfeature/build-system-rewrite
| @@ -25,7 +25,7 @@ void MatrixBandPartCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) { | |||
| shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| dim_size_ = shapes_.size(); | |||
| if (shapes_.size() < kDim2) { | |||
| MS_LOG(EXCEPTION) << "wrong array shape, A should be a matrix max than 2."; | |||
| MS_LOG(EXCEPTION) << "Wrong array shape, A should be a matrix max than 2."; | |||
| } | |||
| m_ = shapes_[dim_size_ - kDim2]; | |||
| n_ = shapes_[dim_size_ - kDim1]; | |||
| @@ -53,7 +53,7 @@ bool MatrixBandPartCPUKernel<T>::Launch(const std::vector<AddressPtr> &inputs, c | |||
| for (size_t k = 0; k < out_range_size_; k++) { | |||
| for (size_t i = 0; i < std::min(m_, l + n_); i++) { | |||
| const size_t s = i < l ? 0 : i - l; | |||
| // when i = n - u, end is n -1, because end pos is start from 0 | |||
| // When i = n - u, end is n -1, because end pos is start from 0 | |||
| const size_t e = i >= n_ - u ? n_ - 1 : i + u; | |||
| const size_t offset = k * m_ * n_ + i * n_; | |||
| memcpy_s(out_value + offset + s, matrix_size_ * sizeof(T), in_value + offset + s, (e - s + 1) * sizeof(T)); | |||
| @@ -0,0 +1,107 @@ | |||
| /** | |||
| * Copyright 2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/cpu/matrix_diag_part_cpu_kernel.h" | |||
| #include <algorithm> | |||
| #include <string> | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| void MatrixDiagPartCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) { | |||
| shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| out_shapes_ = shapes_; | |||
| dim_size_ = shapes_.size(); | |||
| if (shapes_.size() < kDim2) { | |||
| MS_LOG(EXCEPTION) << "Wrong array shape, A should be a matrix max than 2."; | |||
| } | |||
| m_ = shapes_[dim_size_ - kDim2]; | |||
| n_ = shapes_[dim_size_ - kDim1]; | |||
| for (size_t i = 0; i < shapes_.size() - kDim2; i++) { | |||
| out_range_size_ *= shapes_[i]; | |||
| } | |||
| // Invalid alignment will throw an exception. | |||
| auto alignment = AnfAlgo::GetNodeAttr<std::string>(kernel_node, ALIGNMENT); | |||
| alignment_ = GetAlignments(alignment); | |||
| node_wpt_ = kernel_node; | |||
| } | |||
| template <typename T> | |||
| bool MatrixDiagPartCPUKernel<T>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs) { | |||
| T *in_value = reinterpret_cast<T *>(inputs[0]->addr); | |||
| // K is 2 elements vector, k[0] is lower part, k[0]<0, k[1] is upper part, | |||
| int64_t *k_range = reinterpret_cast<int64_t *>(inputs[1]->addr); | |||
| T *padding_value = reinterpret_cast<T *>(inputs[2]->addr); | |||
| T *out_value = reinterpret_cast<T *>(outputs[0]->addr); | |||
| int64_t l = k_range[0]; | |||
| int64_t u = k_range[1]; | |||
| // New diagonal matrix m*n matrix, m dimension ; | |||
| if (l > u) { | |||
| MS_LOG(EXCEPTION) << "The k[1] must not less than k[0]."; | |||
| } | |||
| u = std::min(u, n_ - 1); | |||
| l = std::max(-(m_ - 1), l); | |||
| int64_t num_diags = u - l + 1; | |||
| // New diagonal matrix m * n matrix, n dimension | |||
| int64_t max_diag_len = | |||
| std::min(m_ + std::min(u, static_cast<int64_t>(0)), n_ + std::min(-l, static_cast<int64_t>(0))); | |||
| MS_LOG(DEBUG) << "Num_diags:" << num_diags << ",max_diag_len:" << max_diag_len; | |||
| int64_t dest_inner_matrix_size = num_diags * max_diag_len; | |||
| out_shapes_ = shapes_; | |||
| // Set dynamic shape and dtype | |||
| if (!node_wpt_.expired()) { | |||
| auto node_ = node_wpt_.lock(); | |||
| out_shapes_[shapes_.size() - kDim1] = max_diag_len; | |||
| // If the out shape m' * n', the m' dimension is 1, then remove this dimension | |||
| out_shapes_[shapes_.size() - kDim2] = num_diags; | |||
| if (num_diags == 1) { | |||
| out_shapes_.erase(out_shapes_.begin() + shapes_.size() - kDim2); | |||
| } | |||
| auto dtype = AnfAlgo::GetOutputDeviceDataType(node_, 0); | |||
| AnfAlgo::SetOutputInferTypeAndShape({dtype}, {out_shapes_}, node_.get()); | |||
| } | |||
| for (int64_t i = 0; i < out_range_size_; i++) { | |||
| // The j_index means current dest row index | |||
| for (int64_t j = u; j >= l; j--) { | |||
| int64_t current_diag_len = j >= 0 ? std::min(n_ - j, m_) : std::min(m_ + j, n_); | |||
| int64_t current_pad_len = max_diag_len - current_diag_len; | |||
| // Pad left by default | |||
| bool pad_left = (alignment_.first == MatrixDiag::Alignment::RIGHT && j > 0) || | |||
| (alignment_.second == MatrixDiag::Alignment::RIGHT && j < 0); | |||
| // Set none-padding values, l means current diag col index | |||
| for (int64_t k = 0; k < max_diag_len; k++) { | |||
| // Source pos, k offset, only effective when pad left | |||
| int64_t k_offset = (pad_left && k >= current_pad_len) ? k - current_pad_len : k; | |||
| // Calculate source offset row/col offset | |||
| size_t row_index = j >= 0 ? j + k_offset : k_offset; | |||
| size_t col_index = j >= 0 ? k_offset : k_offset - j; | |||
| size_t source_offset = i * m_ * n_ + col_index * n_ + row_index; | |||
| // If current pos need pad, then the value is pad value | |||
| bool current_pad_flag = (pad_left && k < current_pad_len) || (!pad_left && k >= current_diag_len); | |||
| T current_pad_value = current_pad_flag ? *padding_value : *(in_value + source_offset); | |||
| int64_t j_index = u - j; | |||
| size_t dest_offset = dest_inner_matrix_size * i + j_index * max_diag_len + k; | |||
| MS_LOG(DEBUG) << "the diag j:" << j << ",k:" << k << ",k_offset:" << k_offset << ",row:" << row_index | |||
| << ",col:" << col_index << ",j_index:" << j_index << ",current_pad_value:" << current_pad_value; | |||
| *(out_value + dest_offset) = current_pad_value; | |||
| } | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,82 @@ | |||
| /** | |||
| * Copyright 2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_MATRIX_DIAG_PART_H | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_MATRIX_DIAG_PART_H | |||
| #include <vector> | |||
| #include <complex> | |||
| #include <utility> | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel.h" | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| class MatrixDiagPartCPUKernel : public CPUKernel { | |||
| public: | |||
| MatrixDiagPartCPUKernel() = default; | |||
| ~MatrixDiagPartCPUKernel() override = default; | |||
| void InitKernel(const CNodePtr &kernel_node) override; | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs) override; | |||
| private: | |||
| // <Super_matrix_diag_align, Sub_matrix_diag_align> | |||
| std::pair<MatrixDiag::Alignment, MatrixDiag::Alignment> alignment_{MatrixDiag::RIGHT, MatrixDiag::LEFT}; | |||
| std::vector<size_t> shapes_{}; | |||
| int64_t out_range_size_{1}; | |||
| size_t dim_size_{1}; | |||
| int64_t m_{1}; | |||
| int64_t n_{1}; | |||
| std::vector<size_t> out_shapes_{}; | |||
| CNodeWeakPtr node_wpt_; | |||
| }; | |||
| MS_REG_CPU_KERNEL_T(MatrixDiagPart, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddOutputAttr(kNumberTypeInt32), | |||
| MatrixDiagPartCPUKernel, int32_t); | |||
| MS_REG_CPU_KERNEL_T(MatrixDiagPart, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddOutputAttr(kNumberTypeInt64), | |||
| MatrixDiagPartCPUKernel, int64_t); | |||
| MS_REG_CPU_KERNEL_T(MatrixDiagPart, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| MatrixDiagPartCPUKernel, float); | |||
| MS_REG_CPU_KERNEL_T(MatrixDiagPart, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat64) | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddInputAttr(kNumberTypeFloat64) | |||
| .AddOutputAttr(kNumberTypeFloat64), | |||
| MatrixDiagPartCPUKernel, double); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_MATRIX_DIAG_PART_H | |||
| @@ -575,6 +575,7 @@ inline const PrimitivePtr kPrimAddcmul = std::make_shared<Primitive>(kAddcmul); | |||
| inline const PrimitivePtr kPrimMatMul = std::make_shared<Primitive>("MatMul"); | |||
| inline const PrimitivePtr kPrimMatMulV2 = std::make_shared<Primitive>("MatMulV2"); | |||
| inline const PrimitivePtr kPrimMatrixDiag = std::make_shared<Primitive>("MatrixDiag"); | |||
| inline const PrimitivePtr kPrimMatrixDiagPart = std::make_shared<Primitive>("MatrixDiagPart"); | |||
| inline const PrimitivePtr kPrimBatchMatMul = std::make_shared<Primitive>("BatchMatMul"); | |||
| inline const PrimitivePtr kPrimBatchMatMulV2 = std::make_shared<Primitive>("BatchMatMulV2"); | |||
| inline const PrimitivePtr kPrimMaximumGrad = std::make_shared<Primitive>("MaximumGrad"); | |||
| @@ -0,0 +1,58 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ops/matrix_diag_part.h" | |||
| #include <set> | |||
| #include "abstract/primitive_infer_map.h" | |||
| #include "utils/check_convert_utils.h" | |||
| #include "abstract/utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| namespace { | |||
| const constexpr int64_t kShape2 = 2; | |||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto input_shape = input_args[0]->BuildShape(); | |||
| auto shape_element = input_shape->cast<abstract::ShapePtr>(); | |||
| ShapeVector shape = shape_element->shape(); | |||
| ShapeVector min_shape = shape_element->shape(); | |||
| ShapeVector max_shape = shape_element->shape(); | |||
| max_shape[shape.size() - 1] = kShape2 * shape[shape.size() - 1] - 1; | |||
| min_shape[shape.size() - 1] = 1; | |||
| shape[shape.size() - 1] = abstract::Shape::SHP_ANY; | |||
| return std::make_shared<abstract::Shape>(shape, min_shape, max_shape); | |||
| } | |||
| TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||
| for (const auto &item : input_args) { | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| } | |||
| auto infer_type = input_args[0]->BuildType(); | |||
| MS_EXCEPTION_IF_NULL(infer_type); | |||
| const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64, kInt8, kInt16, kInt32, kInt64}; | |||
| CheckAndConvertUtils::CheckTensorTypeValid("input", infer_type, valid_types, prim->name()); | |||
| return infer_type; | |||
| } | |||
| } // namespace | |||
| AbstractBasePtr MatrixDiagPartInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args)); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(MatrixDiagPart, prim::kPrimMatrixDiagPart, MatrixDiagPartInfer, nullptr, true); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,45 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_OPS_MATRIX_DIAG_PART_H_ | |||
| #define MINDSPORE_CORE_OPS_MATRIX_DIAG_PART_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "ops/primitive_c.h" | |||
| #include "ops/op_utils.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| constexpr auto kNameMatrixDiagPart = "MatrixDiagPart"; | |||
| /// \brief get the specified part of the inner most diag matrix of a matrix, fill with padding value . | |||
| /// Refer to Python API @ref mindspore.ops.MatrixDiagPart for more details. | |||
| class MatrixDiagPart : public PrimitiveC { | |||
| public: | |||
| /// \brief Constructor. | |||
| MatrixDiagPart() : PrimitiveC(kNameMatrixDiagPart) { InitIOName({"input", "k", "padding_value"}, {"output"}); } | |||
| /// \brief Destructor. | |||
| ~MatrixDiagPart() = default; | |||
| MS_DECLARE_PARENT(MatrixDiagPart, PrimitiveC); | |||
| }; | |||
| AbstractBasePtr MatrixDiagPartInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_OPS_MATRIX_DIAG_PART_H_ | |||
| @@ -390,4 +390,33 @@ class MatrixBandPartNet(nn.Cell): | |||
| return self.matrix_band_part(a, num_lower, num_upper) | |||
| class MatrixDiagPart(PrimitiveWithInfer): | |||
| """ | |||
| MatrixDiagPart() | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, align="RIGHT_LEFT"): | |||
| super().__init__(name="MatrixDiagPart") | |||
| self.add_prim_attr('alignment', align) | |||
| self.init_prim_io_names(inputs=['A', 'k', 'padding_value'], outputs=['output']) | |||
| class MatrixDiagPartNet(nn.Cell): | |||
| """ | |||
| Returns: | |||
| batched diagonal part of a batched tensor, the part between, k[0] to k[1], the shape is dynamic | |||
| Raises: | |||
| k[1] should not less then k[0] | |||
| """ | |||
| def __init__(self, align="RIGHT_LEFT"): | |||
| super(MatrixDiagPartNet, self).__init__() | |||
| self.matrix_diag_part = MatrixDiagPart(align) | |||
| def construct(self, a, k, padding_value): | |||
| return self.matrix_diag_part(a, k, padding_value) | |||
| from .ops_grad import get_bprpo_eigh | |||
| @@ -0,0 +1,204 @@ | |||
| # Copyright 2022 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """st for scipy.ops_wrapper.""" | |||
| import pytest | |||
| from mindspore import context, Tensor | |||
| from mindspore.scipy.ops import MatrixDiagPartNet | |||
| from tests.st.scipy_st.utils import match_matrix | |||
| aligndict = {0: "LEFT_RIGHT", 1: "LEFT_LEFT", 2: "RIGHT_LEFT", 3: "RIGHT_RIGHT"} | |||
| PAD_VALUE = -1 | |||
| Adict = {(1, 1, 1): (([[[1]]]), {}), (1, 3, 3): (([[[8, 2, 1], [5, 3, 7], [0, 3, 4]]]), | |||
| {(-2, -2, 0): ([[0]]), (-2, -1, 3): ([[[5, 3], [-1, 0]]]), | |||
| (-2, 0, 2): ([[[8, 3, 4], [5, 3, -1], [0, -1, -1]]]), | |||
| (-2, 1, 3): ([[[-1, 2, 7], [8, 3, 4], [-1, 5, 3], [-1, -1, 0]]]), | |||
| (-2, 2, 0): (\ | |||
| [[[1, -1, -1], [2, 7, -1], [8, 3, 4], [-1, 5, 3], [-1, -1, 0]]]), | |||
| (-1, -1, 2): ([[5, 3]]), (-1, 0, 1): ([[[8, 3, 4], [5, 3, -1]]]), | |||
| (-1, 1, 2): ([[[-1, 2, 7], [8, 3, 4], [5, 3, -1]]]), | |||
| (-1, 2, 3): ([[[-1, -1, 1], [-1, 2, 7], [8, 3, 4], [-1, 5, 3]]]), | |||
| (0, 0, 0): ([[8, 3, 4]]), (0, 1, 1): ([[[2, 7, -1], [8, 3, 4]]]), | |||
| (0, 2, 2): ([[[-1, -1, 1], [-1, 2, 7], [8, 3, 4]]]), | |||
| (1, 1, 2): ([[2, 7]]), (1, 2, 3): ([[[-1, 1], [2, 7]]])}), | |||
| (1, 1, 2): (([[[3, 2]]]), {}), (1, 3, 5): (([[[3, 5, 5, 2, 5], [0, 6, 2, 4, 7], [7, 3, 3, 6, 8]]]), | |||
| {(-2, -2, 0): ([[7]]), (-2, -1, 3): ([[[0, 3], [-1, 7]]]), | |||
| (-2, 0, 2): ([[[3, 6, 3], [0, 3, -1], [7, -1, -1]]]), | |||
| (-2, 1, 3): ([[[5, 2, 6], [3, 6, 3], [-1, 0, 3], [-1, -1, 7]]]), | |||
| (-2, 2, 0): (\ | |||
| [[[5, 4, 8], [5, 2, 6], [3, 6, 3], [-1, 0, 3], [-1, -1, 7]]]), | |||
| (-2, 3, 1): ([ | |||
| [[2, 7, -1], [5, 4, 8], [5, 2, 6], [3, 6, 3], [0, 3, -1], | |||
| [7, -1, -1]]]), (-2, 4, 2): ([\ | |||
| [[-1, -1, 5], [-1, 2, 7], [5, 4, 8], [5, 2, 6], [3, 6, 3], | |||
| [0, 3, -1], [7, -1, -1]]]), (-1, -1, 2): ([[0, 3]]), | |||
| (-1, 0, 1): ([[[3, 6, 3], [0, 3, -1]]]), | |||
| (-1, 1, 2): ([[[5, 2, 6], [3, 6, 3], [0, 3, -1]]]), | |||
| (-1, 2, 3): ([[[5, 4, 8], [5, 2, 6], [3, 6, 3], [-1, 0, 3]]]), | |||
| (-1, 3, 0): (\ | |||
| [[[2, 7, -1], [5, 4, 8], [5, 2, 6], [3, 6, 3], [-1, 0, 3]]]), | |||
| (-1, 4, 1): ([ | |||
| [[5, -1, -1], [2, 7, -1], [5, 4, 8], [5, 2, 6], [3, 6, 3], | |||
| [0, 3, -1]]]), (0, 0, 0): ([[3, 6, 3]]), | |||
| (0, 1, 1): ([[[5, 2, 6], [3, 6, 3]]]), | |||
| (0, 2, 2): ([[[5, 4, 8], [5, 2, 6], [3, 6, 3]]]), | |||
| (0, 3, 3): ([[[-1, 2, 7], [5, 4, 8], [5, 2, 6], [3, 6, 3]]]), | |||
| (0, 4, 0): (\ | |||
| [[[5, -1, -1], [2, 7, -1], [5, 4, 8], [5, 2, 6], [3, 6, 3]]]), | |||
| (1, 1, 2): ([[5, 2, 6]]), (1, 2, 3): ([[[5, 4, 8], [5, 2, 6]]]), | |||
| (1, 3, 0): ([[[2, 7, -1], [5, 4, 8], [5, 2, 6]]]), | |||
| (1, 4, 1): ([[[5, -1, -1], [2, 7, -1], [5, 4, 8], [5, 2, 6]]])}), | |||
| (1, 2, 1): (([[[4], [1]]]), {(-1, -1, 2): ([[1]]), (-1, 0, 1): ([[[4], [1]]]), (0, 0, 0): ([[4]])}), | |||
| (1, 5, 3): (([[[7, 8, 8], [3, 5, 6], [0, 4, 4], [8, 4, 5], [0, 4, 6]]]), | |||
| {(-4, -4, 0): ([[0]]), (-4, -3, 3): ([[[8, 4], [-1, 0]]]), | |||
| (-4, -2, 2): ([[[0, 4, 6], [8, 4, -1], [0, -1, -1]]]), | |||
| (-4, -1, 1): ([[[3, 4, 5], [0, 4, 6], [8, 4, -1], [0, -1, -1]]]), | |||
| (-4, 0, 0): ([[[7, 5, 4], [3, 4, 5], [0, 4, 6], [-1, 8, 4], [-1, -1, 0]]]), | |||
| (-4, 1, 1): ([[[8, 6, -1], [7, 5, 4], [3, 4, 5], [0, 4, 6], [8, 4, -1], [0, -1, -1]]]), | |||
| (-4, 2, 2): (\ | |||
| [[[-1, -1, 8], [-1, 8, 6], [7, 5, 4], [3, 4, 5], [0, 4, 6], [8, 4, -1], [0, -1, -1]]]), | |||
| (-3, -3, 2): ([[8, 4]]), (-3, -2, 1): ([[[0, 4, 6], [8, 4, -1]]]), | |||
| (-3, -1, 0): ([[[3, 4, 5], [0, 4, 6], [-1, 8, 4]]]), | |||
| (-3, 0, 3): ([[[7, 5, 4], [3, 4, 5], [0, 4, 6], [-1, 8, 4]]]), | |||
| (-3, 1, 0): ([[[8, 6, -1], [7, 5, 4], [3, 4, 5], [0, 4, 6], [-1, 8, 4]]]), | |||
| (-3, 2, 1): ([[[8, -1, -1], [8, 6, -1], [7, 5, 4], [3, 4, 5], [0, 4, 6], [8, 4, -1]]]), | |||
| (-2, -2, 0): ([[0, 4, 6]]), (-2, -1, 3): ([[[3, 4, 5], [0, 4, 6]]]), | |||
| (-2, 0, 2): ([[[7, 5, 4], [3, 4, 5], [0, 4, 6]]]), | |||
| (-2, 1, 3): ([[[-1, 8, 6], [7, 5, 4], [3, 4, 5], [0, 4, 6]]]), | |||
| (-2, 2, 0): ([[[8, -1, -1], [8, 6, -1], [7, 5, 4], [3, 4, 5], [0, 4, 6]]]), | |||
| (-1, -1, 2): ([[3, 4, 5]]), (-1, 0, 1): ([[[7, 5, 4], [3, 4, 5]]]), | |||
| (-1, 1, 2): ([[[-1, 8, 6], [7, 5, 4], [3, 4, 5]]]), | |||
| (-1, 2, 3): ([[[-1, -1, 8], [-1, 8, 6], [7, 5, 4], [3, 4, 5]]]), (0, 0, 0): ([[7, 5, 4]]), | |||
| (0, 1, 1): ([[[8, 6, -1], [7, 5, 4]]]), (0, 2, 2): ([[[-1, -1, 8], [-1, 8, 6], [7, 5, 4]]]), | |||
| (1, 1, 2): ([[8, 6]]), (1, 2, 3): ([[[-1, 8], [8, 6]]]), (2, 2, 0): ([[8]])}), | |||
| (2, 1, 1): (([[[5]], [[6]]]), {}), (2, 3, 3): (\ | |||
| ([[[7, 6, 3], [5, 8, 5], [5, 0, 2]], [[1, 8, 1], [5, 5, 8], [8, 4, 0]]]), | |||
| {(-2, -2, 0): ([[5], [8]]), (-2, -1, 3): ([[[5, 0], [-1, 5]], [[5, 4], [-1, 8]]]), | |||
| (-2, 0, 2): ([[[7, 8, 2], [5, 0, -1], [5, -1, -1]], [[1, 5, 0], [5, 4, -1], [8, -1, -1]]]), | |||
| (-2, 1, 3): ([[[-1, 6, 5], [7, 8, 2], [-1, 5, 0], [-1, -1, 5]], [[-1, 8, 8], [1, 5, 0], [-1, 5, 4], [-1, -1, 8]]]), | |||
| (-2, 2, 0): ([[[3, -1, -1], [6, 5, -1], [7, 8, 2], [-1, 5, 0], [-1, -1, 5]], | |||
| [[1, -1, -1], [8, 8, -1], [1, 5, 0], [-1, 5, 4], [-1, -1, 8]]]), (-1, -1, 2): ([[5, 0], [5, 4]]), | |||
| (-1, 0, 1): ([[[7, 8, 2], [5, 0, -1]], [[1, 5, 0], [5, 4, -1]]]), | |||
| (-1, 1, 2): ([[[-1, 6, 5], [7, 8, 2], [5, 0, -1]], [[-1, 8, 8], [1, 5, 0], [5, 4, -1]]]), | |||
| (-1, 2, 3): ([[[-1, -1, 3], [-1, 6, 5], [7, 8, 2], [-1, 5, 0]], [[-1, -1, 1], [-1, 8, 8], [1, 5, 0], [-1, 5, 4]]]), | |||
| (0, 0, 0): ([[7, 8, 2], [1, 5, 0]]), (0, 1, 1): ([[[6, 5, -1], [7, 8, 2]], [[8, 8, -1], [1, 5, 0]]]), | |||
| (0, 2, 2): ([[[-1, -1, 3], [-1, 6, 5], [7, 8, 2]], [[-1, -1, 1], [-1, 8, 8], [1, 5, 0]]]), | |||
| (1, 1, 2): ([[6, 5], [8, 8]]), (1, 2, 3): ([[[-1, 3], [6, 5]], [[-1, 1], [8, 8]]])}), | |||
| (2, 1, 2): (([[[6, 3]], [[5, 5]]]), {}), (2, 3, 5): (\ | |||
| ([[[1, 2, 1, 2, 7], [0, 3, 5, 0, 2], [0, 5, 1, 7, 5]], [[3, 4, 3, 5, 7], [2, 5, 2, 7, 5], [7, 5, 1, 1, 7]]]), | |||
| {(-2, -2, 0): ([[0], [7]]), (-2, -1, 3): ([[[0, 5], [-1, 0]], [[2, 5], [-1, 7]]]), | |||
| (-2, 0, 2): ([[[1, 3, 1], [0, 5, -1], [0, -1, -1]], [[3, 5, 1], [2, 5, -1], [7, -1, -1]]]), | |||
| (-2, 1, 3): ([[[2, 5, 7], [1, 3, 1], [-1, 0, 5], [-1, -1, 0]], [[4, 2, 1], [3, 5, 1], [-1, 2, 5], [-1, -1, 7]]]), | |||
| (-2, 2, 0): ([[[1, 0, 5], [2, 5, 7], [1, 3, 1], [-1, 0, 5], [-1, -1, 0]], | |||
| [[3, 7, 7], [4, 2, 1], [3, 5, 1], [-1, 2, 5], [-1, -1, 7]]]), (-2, 3, 1): (\ | |||
| [[[2, 2, -1], [1, 0, 5], [2, 5, 7], [1, 3, 1], [0, 5, -1], [0, -1, -1]], | |||
| [[5, 5, -1], [3, 7, 7], [4, 2, 1], [3, 5, 1], [2, 5, -1], [7, -1, -1]]]), (-2, 4, 2): (\ | |||
| [[[-1, -1, 7], [-1, 2, 2], [1, 0, 5], [2, 5, 7], [1, 3, 1], [0, 5, -1], [0, -1, -1]], | |||
| [[-1, -1, 7], [-1, 5, 5], [3, 7, 7], [4, 2, 1], [3, 5, 1], [2, 5, -1], [7, -1, -1]]]), | |||
| (-1, -1, 2): ([[0, 5], [2, 5]]), (-1, 0, 1): ([[[1, 3, 1], [0, 5, -1]], [[3, 5, 1], [2, 5, -1]]]), | |||
| (-1, 1, 2): ([[[2, 5, 7], [1, 3, 1], [0, 5, -1]], [[4, 2, 1], [3, 5, 1], [2, 5, -1]]]), | |||
| (-1, 2, 3): ([[[1, 0, 5], [2, 5, 7], [1, 3, 1], [-1, 0, 5]], [[3, 7, 7], [4, 2, 1], [3, 5, 1], [-1, 2, 5]]]), | |||
| (-1, 3, 0): ([[[2, 2, -1], [1, 0, 5], [2, 5, 7], [1, 3, 1], [-1, 0, 5]], | |||
| [[5, 5, -1], [3, 7, 7], [4, 2, 1], [3, 5, 1], [-1, 2, 5]]]), (-1, 4, 1): (\ | |||
| [[[7, -1, -1], [2, 2, -1], [1, 0, 5], [2, 5, 7], [1, 3, 1], [0, 5, -1]], | |||
| [[7, -1, -1], [5, 5, -1], [3, 7, 7], [4, 2, 1], [3, 5, 1], [2, 5, -1]]]), (0, 0, 0): ([[1, 3, 1], [3, 5, 1]]), | |||
| (0, 1, 1): ([[[2, 5, 7], [1, 3, 1]], [[4, 2, 1], [3, 5, 1]]]), | |||
| (0, 2, 2): ([[[1, 0, 5], [2, 5, 7], [1, 3, 1]], [[3, 7, 7], [4, 2, 1], [3, 5, 1]]]), | |||
| (0, 3, 3): ([[[-1, 2, 2], [1, 0, 5], [2, 5, 7], [1, 3, 1]], [[-1, 5, 5], [3, 7, 7], [4, 2, 1], [3, 5, 1]]]), | |||
| (0, 4, 0): ([[[7, -1, -1], [2, 2, -1], [1, 0, 5], [2, 5, 7], [1, 3, 1]], | |||
| [[7, -1, -1], [5, 5, -1], [3, 7, 7], [4, 2, 1], [3, 5, 1]]]), (1, 1, 2): ([[2, 5, 7], [4, 2, 1]]), | |||
| (1, 2, 3): ([[[1, 0, 5], [2, 5, 7]], [[3, 7, 7], [4, 2, 1]]]), | |||
| (1, 3, 0): ([[[2, 2, -1], [1, 0, 5], [2, 5, 7]], [[5, 5, -1], [3, 7, 7], [4, 2, 1]]]), | |||
| (1, 4, 1): ([[[7, -1, -1], [2, 2, -1], [1, 0, 5], [2, 5, 7]], [[7, -1, -1], [5, 5, -1], [3, 7, 7], [4, 2, 1]]])}), | |||
| (2, 2, 1): (([[[4], [8]], [[3], [5]]]), | |||
| {(-1, -1, 2): ([[8], [5]]), (-1, 0, 1): ([[[4], [8]], [[3], [5]]]), (0, 0, 0): ([[4], [3]])}), | |||
| (2, 5, 3): (([[[6, 8, 5], [7, 2, 7], [2, 2, 5], [5, 6, 7], [5, 0, 2]], | |||
| [[3, 8, 7], [7, 8, 2], [8, 1, 0], [0, 6, 5], [6, 3, 1]]]), | |||
| {(-4, -4, 0): ([[5], [6]]), (-4, -3, 3): ([[[5, 0], [-1, 5]], [[0, 3], [-1, 6]]]), | |||
| (-4, -2, 2): ([[[2, 6, 2], [5, 0, -1], [5, -1, -1]], [[8, 6, 1], [0, 3, -1], [6, -1, -1]]]), | |||
| (-4, -1, 1): ([[[7, 2, 7], [2, 6, 2], [5, 0, -1], [5, -1, -1]], | |||
| [[7, 1, 5], [8, 6, 1], [0, 3, -1], [6, -1, -1]]]), (-4, 0, 0): (\ | |||
| [[[6, 2, 5], [7, 2, 7], [2, 6, 2], [-1, 5, 0], [-1, -1, 5]], | |||
| [[3, 8, 0], [7, 1, 5], [8, 6, 1], [-1, 0, 3], [-1, -1, 6]]]), (-4, 1, 1): (\ | |||
| [[[8, 7, -1], [6, 2, 5], [7, 2, 7], [2, 6, 2], [5, 0, -1], [5, -1, -1]], | |||
| [[8, 2, -1], [3, 8, 0], [7, 1, 5], [8, 6, 1], [0, 3, -1], [6, -1, -1]]]), (-4, 2, 2): (\ | |||
| [[[-1, -1, 5], [-1, 8, 7], [6, 2, 5], [7, 2, 7], [2, 6, 2], [5, 0, -1], [5, -1, -1]], | |||
| [[-1, -1, 7], [-1, 8, 2], [3, 8, 0], [7, 1, 5], [8, 6, 1], [0, 3, -1], [6, -1, -1]]]), | |||
| (-3, -3, 2): ([[5, 0], [0, 3]]), | |||
| (-3, -2, 1): ([[[2, 6, 2], [5, 0, -1]], [[8, 6, 1], [0, 3, -1]]]), | |||
| (-3, -1, 0): ([[[7, 2, 7], [2, 6, 2], [-1, 5, 0]], [[7, 1, 5], [8, 6, 1], [-1, 0, 3]]]), | |||
| (-3, 0, 3): (\ | |||
| [[[6, 2, 5], [7, 2, 7], [2, 6, 2], [-1, 5, 0]], [[3, 8, 0], [7, 1, 5], [8, 6, 1], [-1, 0, 3]]]), | |||
| (-3, 1, 0): ([[[8, 7, -1], [6, 2, 5], [7, 2, 7], [2, 6, 2], [-1, 5, 0]], | |||
| [[8, 2, -1], [3, 8, 0], [7, 1, 5], [8, 6, 1], [-1, 0, 3]]]), (-3, 2, 1): (\ | |||
| [[[5, -1, -1], [8, 7, -1], [6, 2, 5], [7, 2, 7], [2, 6, 2], [5, 0, -1]], | |||
| [[7, -1, -1], [8, 2, -1], [3, 8, 0], [7, 1, 5], [8, 6, 1], [0, 3, -1]]]), | |||
| (-2, -2, 0): ([[2, 6, 2], [8, 6, 1]]), | |||
| (-2, -1, 3): ([[[7, 2, 7], [2, 6, 2]], [[7, 1, 5], [8, 6, 1]]]), | |||
| (-2, 0, 2): ([[[6, 2, 5], [7, 2, 7], [2, 6, 2]], [[3, 8, 0], [7, 1, 5], [8, 6, 1]]]), | |||
| (-2, 1, 3): (\ | |||
| [[[-1, 8, 7], [6, 2, 5], [7, 2, 7], [2, 6, 2]], [[-1, 8, 2], [3, 8, 0], [7, 1, 5], [8, 6, 1]]]), | |||
| (-2, 2, 0): ([[[5, -1, -1], [8, 7, -1], [6, 2, 5], [7, 2, 7], [2, 6, 2]], | |||
| [[7, -1, -1], [8, 2, -1], [3, 8, 0], [7, 1, 5], [8, 6, 1]]]), | |||
| (-1, -1, 2): ([[7, 2, 7], [7, 1, 5]]), | |||
| (-1, 0, 1): ([[[6, 2, 5], [7, 2, 7]], [[3, 8, 0], [7, 1, 5]]]), | |||
| (-1, 1, 2): ([[[-1, 8, 7], [6, 2, 5], [7, 2, 7]], [[-1, 8, 2], [3, 8, 0], [7, 1, 5]]]), | |||
| (-1, 2, 3): ([[[-1, -1, 5], [-1, 8, 7], [6, 2, 5], [7, 2, 7]], | |||
| [[-1, -1, 7], [-1, 8, 2], [3, 8, 0], [7, 1, 5]]]), | |||
| (0, 0, 0): ([[6, 2, 5], [3, 8, 0]]), | |||
| (0, 1, 1): ([[[8, 7, -1], [6, 2, 5]], [[8, 2, -1], [3, 8, 0]]]), | |||
| (0, 2, 2): ([[[-1, -1, 5], [-1, 8, 7], [6, 2, 5]], [[-1, -1, 7], [-1, 8, 2], [3, 8, 0]]]), | |||
| (1, 1, 2): ([[8, 7], [8, 2]]), (1, 2, 3): ([[[-1, 5], [8, 7]], [[-1, 7], [8, 2]]]), | |||
| (2, 2, 0): ([[5], [7]])})} | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_matrix_diag_part_net_cpu(): | |||
| """ | |||
| testcase generate from below | |||
| from tensorflow.python.ops import array_ops | |||
| import numpy as np | |||
| f = open (r'dict.tst','w') | |||
| aligndict = {0: "LEFT_RIGHT", 1:"LEFT_LEFT", 2:"RIGHT_LEFT", 3:"RIGHT_RIGHT"} | |||
| Adict={} | |||
| for i in [1, 2]: | |||
| for m,n in [(1, 1), (3,3),(1, 2),(3, 5),(2, 1),(5, 3)]: | |||
| A = np.array(np.random.randint(20, size=(i, m, n))) | |||
| Adict[i,m,n]=A | |||
| p = -1 | |||
| kadict={} | |||
| for k0 in range(-m + 1, m - 1): | |||
| for k1 in range(k0, n): | |||
| k = (k0, k1) | |||
| align_= (abs(k0)+ abs(k1)) % 4 | |||
| ka = (k,align_) | |||
| B = array_ops.matrix_diag_part(A, k=k, align=aligndict[align_], padding_value=-1) | |||
| kadict[ka] = B.numpy() | |||
| Adict[i,m,n]=(A, kadict) | |||
| print(Adict, file= f) | |||
| f.close() | |||
| Feature: ALL To ALL | |||
| Description: test cases for eigen decomposition test cases for Ax= lambda * x /( A- lambda * E)X=0 | |||
| Expectation: the result match to numpy | |||
| """ | |||
| context.set_context(mode=context.PYNATIVE_MODE) | |||
| for _, value in Adict.items(): | |||
| a, kadict = value | |||
| for key1, b in kadict.items(): | |||
| k0, k1, align_ = key1 | |||
| msp_matrixdiagpart = MatrixDiagPartNet(align=aligndict[align_]) | |||
| r_b = msp_matrixdiagpart(Tensor(a), Tensor([k0, k1]), Tensor(PAD_VALUE)) | |||
| match_matrix(Tensor(b), Tensor(r_b)) | |||