From: @caojian05 Reviewed-by: @wuxuejian,@oacjiewen Signed-off-by: @wuxuejiantags/v1.2.0-rc1
| @@ -167,6 +167,24 @@ void ArithmeticCPUKernel::SquaredDifference(const T *input1, const T *input2, T | |||
| } | |||
| } | |||
| template <typename T> | |||
| void ArithmeticCPUKernel::Greater(const T *input1, const T *input2, bool *out, size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| std::vector<size_t> idx; | |||
| GenIndex(i, &idx); | |||
| out[i] = input1[idx[0]] > input2[idx[1]]; | |||
| } | |||
| } | |||
| template <typename T> | |||
| void ArithmeticCPUKernel::GreaterEqual(const T *input1, const T *input2, bool *out, size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| std::vector<size_t> idx; | |||
| GenIndex(i, &idx); | |||
| out[i] = input1[idx[0]] >= input2[idx[1]]; | |||
| } | |||
| } | |||
| void ArithmeticCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); | |||
| @@ -190,6 +208,10 @@ void ArithmeticCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| operate_type_ = EQUAL; | |||
| } else if (kernel_name == prim::kPrimNotEqual->name()) { | |||
| operate_type_ = NOTEQUAL; | |||
| } else if (kernel_name == prim::kPrimGreater->name()) { | |||
| operate_type_ = GREATER; | |||
| } else if (kernel_name == prim::kPrimGreaterEqual->name()) { | |||
| operate_type_ = GREATEREQUAL; | |||
| } else if (kernel_name == prim::kPrimAssignAdd->name()) { | |||
| operate_type_ = ASSIGNADD; | |||
| } else if (kernel_name == prim::kPrimSquaredDifference->name()) { | |||
| @@ -301,6 +323,11 @@ void ArithmeticCPUKernel::LaunchKernelLogic(const std::vector<AddressPtr> &input | |||
| threads.emplace_back(std::thread(&ArithmeticCPUKernel::Equal<T>, this, input1, input2, output, start, end)); | |||
| } else if (operate_type_ == NOTEQUAL) { | |||
| threads.emplace_back(std::thread(&ArithmeticCPUKernel::NotEqual<T>, this, input1, input2, output, start, end)); | |||
| } else if (operate_type_ == GREATER) { | |||
| threads.emplace_back(std::thread(&ArithmeticCPUKernel::Greater<T>, this, input1, input2, output, start, end)); | |||
| } else if (operate_type_ == GREATEREQUAL) { | |||
| threads.emplace_back( | |||
| std::thread(&ArithmeticCPUKernel::GreaterEqual<T>, this, input1, input2, output, start, end)); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Not support " << operate_type_; | |||
| } | |||
| @@ -63,6 +63,10 @@ class ArithmeticCPUKernel : public CPUKernel { | |||
| void NotEqual(const T *input1, const T *input2, bool *out, size_t start, size_t end); | |||
| template <typename T> | |||
| void SquaredDifference(const T *input1, const T *input2, T *out, size_t start, size_t end); | |||
| template <typename T> | |||
| void Greater(const T *input1, const T *input2, bool *out, size_t start, size_t end); | |||
| template <typename T> | |||
| void GreaterEqual(const T *input1, const T *input2, bool *out, size_t start, size_t end); | |||
| std::vector<size_t> input_shape0_; | |||
| std::vector<size_t> input_shape1_; | |||
| std::vector<size_t> input_element_num0_; | |||
| @@ -213,6 +217,28 @@ MS_REG_CPU_KERNEL( | |||
| SquaredDifference, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| ArithmeticCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| Greater, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool), | |||
| ArithmeticCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| Greater, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), | |||
| ArithmeticCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| Greater, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool), | |||
| ArithmeticCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| GreaterEqual, | |||
| KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool), | |||
| ArithmeticCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| GreaterEqual, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), | |||
| ArithmeticCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| GreaterEqual, | |||
| KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool), | |||
| ArithmeticCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -53,6 +53,9 @@ const char END[] = "end"; | |||
| const char SIZE[] = "size"; | |||
| const char USE_NESTEROV[] = "use_nesterov"; | |||
| const char GROUP[] = "group"; | |||
| const char START[] = "start"; | |||
| const char LIMIT[] = "limit"; | |||
| const char DELTA[] = "delta"; | |||
| enum OperateType { | |||
| ADD = 0, | |||
| @@ -79,7 +82,9 @@ enum OperateType { | |||
| EQUAL, | |||
| NOTEQUAL, | |||
| FLOOR, | |||
| SQUAREDDIFFERENCE | |||
| SQUAREDDIFFERENCE, | |||
| GREATER, | |||
| GREATEREQUAL, | |||
| }; | |||
| class CPUKernel : public kernel::KernelMod { | |||
| @@ -0,0 +1,104 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/cpu/gathernd_cpu_kernel.h" | |||
| #include "runtime/device/cpu/cpu_device_address.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| void GatherNdCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| input_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| indices_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); | |||
| output_shapes_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); | |||
| dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); | |||
| // ReShape() | |||
| size_t dim_of_indices = 1; | |||
| for (size_t i = 0; i < indices_shapes_.size() - IntToSize(1); ++i) { | |||
| dim_of_indices *= indices_shapes_[i]; | |||
| } | |||
| size_t dim_after_indices = 1; | |||
| size_t dim_indices_last = indices_shapes_[indices_shapes_.size() - IntToSize(1)]; | |||
| for (size_t i = dim_indices_last; i < input_shapes_.size(); i++) { | |||
| dim_after_indices *= input_shapes_[i]; | |||
| } | |||
| dims_.emplace_back(dim_of_indices); | |||
| dims_.emplace_back(dim_after_indices); | |||
| dims_.emplace_back(dim_indices_last); | |||
| batch_strides_.resize(dim_indices_last, 0); | |||
| batch_indices_.resize(dim_indices_last, 0); | |||
| if (dim_indices_last > 0) { | |||
| batch_strides_[dim_indices_last - 1] = input_shapes_[dim_indices_last - 1]; | |||
| batch_indices_[dim_indices_last - 1] = dims_[1]; | |||
| } | |||
| for (size_t i = dim_indices_last - 1; i > 0; --i) { | |||
| batch_strides_[i - 1] = input_shapes_[i - 1]; | |||
| batch_indices_[i - 1] = batch_indices_[i] * input_shapes_[i]; | |||
| } | |||
| } | |||
| bool GatherNdCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| const std::vector<kernel::AddressPtr> & /*workspace*/, | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| if (dtype_ == kNumberTypeInt32) { | |||
| return LaunchKernel<int32_t>(inputs, outputs); | |||
| } else if (dtype_ == kNumberTypeInt64) { | |||
| return LaunchKernel<int64_t>(inputs, outputs); | |||
| } else if (dtype_ == kNumberTypeFloat32) { | |||
| return LaunchKernel<float>(inputs, outputs); | |||
| } else if (dtype_ == kNumberTypeFloat64) { | |||
| return LaunchKernel<double>(inputs, outputs); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Only support int, float, but actual data type is " << TypeIdLabel(dtype_); | |||
| } | |||
| } | |||
| template <typename T> | |||
| bool GatherNdCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) { | |||
| auto input_addr = reinterpret_cast<T *>(inputs[0]->addr); | |||
| auto indices_addr = reinterpret_cast<int *>(inputs[1]->addr); | |||
| auto output_addr = reinterpret_cast<T *>(outputs[0]->addr); | |||
| // | |||
| size_t output_dim0 = dims_[0]; | |||
| size_t output_dim1 = dims_[1]; | |||
| size_t indices_dim1 = dims_[2]; | |||
| int num = output_dim0 * output_dim1; | |||
| for (int write_index = 0; write_index < num; write_index++) { | |||
| int i = write_index / output_dim1 % output_dim0; | |||
| int j = write_index % output_dim1; | |||
| int read_index = 0; | |||
| for (size_t k = 0; k < indices_dim1; k++) { | |||
| size_t ind = indices_dim1 * i + k; | |||
| int indices_i = indices_addr[ind]; | |||
| read_index += indices_i * batch_indices_[k]; | |||
| } | |||
| read_index += j; | |||
| output_addr[write_index] = input_addr[read_index]; | |||
| } | |||
| return true; | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,67 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GATHERND_CPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GATHERND_CPU_KERNEL_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel.h" | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| class GatherNdCPUKernel : public CPUKernel { | |||
| public: | |||
| GatherNdCPUKernel() = default; | |||
| ~GatherNdCPUKernel() override = default; | |||
| void InitKernel(const CNodePtr &kernel_node) override; | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs) override; | |||
| template <typename T> | |||
| bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs); | |||
| private: | |||
| std::vector<size_t> input_shapes_; | |||
| std::vector<size_t> indices_shapes_; | |||
| std::vector<size_t> output_shapes_; | |||
| std::vector<size_t> dims_; | |||
| std::vector<int> batch_indices_; | |||
| std::vector<int> batch_strides_; | |||
| TypeId dtype_{kTypeUnknown}; | |||
| }; | |||
| MS_REG_CPU_KERNEL( | |||
| GatherNd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||
| GatherNdCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| GatherNd, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64), | |||
| GatherNdCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| GatherNd, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), | |||
| GatherNdCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| GatherNd, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64), | |||
| GatherNdCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GATHERND_CPU_KERNEL_H_ | |||
| @@ -0,0 +1,56 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/cpu/range_cpu_kernel.h" | |||
| #include "runtime/device/cpu/cpu_device_address.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| void RangeCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); | |||
| start_ = AnfAlgo::GetNodeAttr<float>(kernel_node, START); | |||
| limit_ = AnfAlgo::GetNodeAttr<float>(kernel_node, LIMIT); | |||
| delta_ = AnfAlgo::GetNodeAttr<float>(kernel_node, DELTA); | |||
| } | |||
| bool RangeCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| const std::vector<kernel::AddressPtr> & /*workspace*/, | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| if (dtype_ == kNumberTypeInt32) { | |||
| return LaunchKernel<int32_t>(inputs, outputs); | |||
| } else if (dtype_ == kNumberTypeInt64) { | |||
| return LaunchKernel<int64_t>(inputs, outputs); | |||
| } else if (dtype_ == kNumberTypeFloat32) { | |||
| return LaunchKernel<float>(inputs, outputs); | |||
| } else if (dtype_ == kNumberTypeFloat64) { | |||
| return LaunchKernel<double>(inputs, outputs); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Only support int, float, but actual data type is " << TypeIdLabel(dtype_); | |||
| } | |||
| } | |||
| template <typename T> | |||
| bool RangeCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) { | |||
| auto output_addr = reinterpret_cast<T *>(outputs[0]->addr); | |||
| size_t elem_num = outputs[0]->size / sizeof(T); | |||
| for (size_t i = 0; i < elem_num; i++) { | |||
| output_addr[i] = start_ + i * delta_; | |||
| } | |||
| return true; | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,54 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RANGE_CPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RANGE_CPU_KERNEL_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel.h" | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| class RangeCPUKernel : public CPUKernel { | |||
| public: | |||
| RangeCPUKernel() = default; | |||
| ~RangeCPUKernel() override = default; | |||
| void InitKernel(const CNodePtr &kernel_node) override; | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs) override; | |||
| template <typename T> | |||
| bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs); | |||
| private: | |||
| TypeId dtype_{kTypeUnknown}; | |||
| int64_t start_; | |||
| int64_t limit_; | |||
| int64_t delta_; | |||
| }; | |||
| MS_REG_CPU_KERNEL(Range, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), RangeCPUKernel); | |||
| MS_REG_CPU_KERNEL(Range, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), RangeCPUKernel); | |||
| MS_REG_CPU_KERNEL(Range, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| RangeCPUKernel); | |||
| MS_REG_CPU_KERNEL(Range, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), | |||
| RangeCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RANGE_CPU_KERNEL_H_ | |||
| @@ -116,7 +116,7 @@ class Range(Cell): | |||
| Tensor, the dtype is int if the dtype of `start`, `limit` and `delta` all are int. Otherwise, dtype is float. | |||
| Supported Platforms: | |||
| ``Ascend`` | |||
| ``Ascend`` ``CPU`` | |||
| Examples: | |||
| >>> net = nn.Range(1, 8, 2) | |||
| @@ -3078,7 +3078,7 @@ class GatherNd(PrimitiveWithInfer): | |||
| Tensor, has the same type as `input_x` and the shape is indices_shape[:-1] + x_shape[indices_shape[-1]:]. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32) | |||
| @@ -2698,7 +2698,7 @@ class Greater(_LogicBinaryOp): | |||
| Tensor, the shape is the same as the one after broadcasting,and the data type is bool. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> input_x = Tensor(np.array([1, 2, 3]), mindspore.int32) | |||
| @@ -2739,7 +2739,7 @@ class GreaterEqual(_LogicBinaryOp): | |||
| Tensor, the shape is the same as the one after broadcasting,and the data type is bool. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> input_x = Tensor(np.array([1, 2, 3]), mindspore.int32) | |||
| @@ -0,0 +1,188 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| context.set_context(mode=context.GRAPH_MODE, device_target='CPU') | |||
| class OpNetWrapper(nn.Cell): | |||
| def __init__(self, op): | |||
| super(OpNetWrapper, self).__init__() | |||
| self.op = op | |||
| def construct(self, *inputs): | |||
| return self.op(*inputs) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_case1_basic_func(): | |||
| op = P.GatherNd() | |||
| op_wrapper = OpNetWrapper(op) | |||
| indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32) | |||
| params = Tensor(np.array([[0, 1], [2, 3]]), mindspore.float32) | |||
| outputs = op_wrapper(params, indices) | |||
| print(outputs) | |||
| expected = [0, 3] | |||
| assert np.allclose(outputs.asnumpy(), np.array(expected)) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_case2_indices_to_matrix(): | |||
| op = P.GatherNd() | |||
| op_wrapper = OpNetWrapper(op) | |||
| indices = Tensor(np.array([[1], [0]]), mindspore.int32) | |||
| params = Tensor(np.array([[0, 1], [2, 3]]), mindspore.float32) | |||
| outputs = op_wrapper(params, indices) | |||
| print(outputs) | |||
| expected = [[2, 3], [0, 1]] | |||
| assert np.allclose(outputs.asnumpy(), np.array(expected)) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_case3_indices_to_3d_tensor(): | |||
| op = P.GatherNd() | |||
| op_wrapper = OpNetWrapper(op) | |||
| indices = Tensor(np.array([[1]]), mindspore.int32) # (1, 1) | |||
| params = Tensor(np.array([[[0, 1], [2, 3]], | |||
| [[4, 5], [6, 7]]]), mindspore.float32) # (2, 2, 2) | |||
| outputs = op_wrapper(params, indices) | |||
| print(outputs) | |||
| expected = [[[4, 5], [6, 7]]] # (1, 2, 2) | |||
| assert np.allclose(outputs.asnumpy(), np.array(expected)) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_case4(): | |||
| op = P.GatherNd() | |||
| op_wrapper = OpNetWrapper(op) | |||
| indices = Tensor(np.array([[0, 1], [1, 0]]), mindspore.int32) # (2, 2) | |||
| params = Tensor(np.array([[[0, 1], [2, 3]], | |||
| [[4, 5], [6, 7]]]), mindspore.float32) # (2, 2, 2) | |||
| outputs = op_wrapper(params, indices) | |||
| print(outputs) | |||
| expected = [[2, 3], [4, 5]] # (2, 2) | |||
| assert np.allclose(outputs.asnumpy(), np.array(expected)) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_case5(): | |||
| op = P.GatherNd() | |||
| op_wrapper = OpNetWrapper(op) | |||
| indices = Tensor(np.array([[0, 0, 1], [1, 0, 1]]), mindspore.int32) # (2, 3) | |||
| params = Tensor(np.array([[[0, 1], [2, 3]], | |||
| [[4, 5], [6, 7]]]), mindspore.float32) # (2, 2, 2) | |||
| outputs = op_wrapper(params, indices) | |||
| print(outputs) | |||
| expected = [1, 5] # (2,) | |||
| assert np.allclose(outputs.asnumpy(), np.array(expected)) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_case6(): | |||
| op = P.GatherNd() | |||
| op_wrapper = OpNetWrapper(op) | |||
| indices = Tensor(np.array([[[0, 0]], [[0, 1]]]), mindspore.int32) # (2, 1, 2) | |||
| params = Tensor(np.array([[[0, 1], [2, 3]], | |||
| [[4, 5], [6, 7]]]), mindspore.float32) # (2, 2, 2) | |||
| outputs = op_wrapper(params, indices) | |||
| print(outputs) | |||
| expected = [[[0, 1]], [[2, 3]]] # (2, 1, 2) | |||
| assert np.allclose(outputs.asnumpy(), np.array(expected)) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_case7(): | |||
| op = P.GatherNd() | |||
| op_wrapper = OpNetWrapper(op) | |||
| indices = Tensor(np.array([[[1]], [[0]]]), mindspore.int32) # (2, 1, 1) | |||
| params = Tensor(np.array([[[0, 1], [2, 3]], | |||
| [[4, 5], [6, 7]]]), mindspore.float32) # (2, 2, 2) | |||
| outputs = op_wrapper(params, indices) | |||
| print(outputs) | |||
| expected = [[[[4, 5], [6, 7]]], [[[0, 1], [2, 3]]]] # (2, 1, 2, 2) | |||
| assert np.allclose(outputs.asnumpy(), np.array(expected)) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_case8(): | |||
| op = P.GatherNd() | |||
| op_wrapper = OpNetWrapper(op) | |||
| indices = Tensor(np.array([[[0, 1], [1, 0]], [[0, 0], [1, 1]]]), mindspore.int32) # (2, 2, 2) | |||
| params = Tensor(np.array([[[0, 1], [2, 3]], | |||
| [[4, 5], [6, 7]]]), mindspore.float32) # (2, 2, 2) | |||
| outputs = op_wrapper(params, indices) | |||
| print(outputs) | |||
| expected = [[[2, 3], [4, 5]], [[0, 1], [6, 7]]] # (2, 2, 2) | |||
| assert np.allclose(outputs.asnumpy(), np.array(expected)) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_case9(): | |||
| op = P.GatherNd() | |||
| op_wrapper = OpNetWrapper(op) | |||
| indices = Tensor(np.array([[[0, 0, 1], [1, 0, 1]], [[0, 1, 1], [1, 1, 0]]]), mindspore.int32) # (2, 2, 3) | |||
| params = Tensor(np.array([[[0, 1], [2, 3]], | |||
| [[4, 5], [6, 7]]]), mindspore.int64) # (2, 2, 2) | |||
| outputs = op_wrapper(params, indices) | |||
| print(outputs) | |||
| expected = [[1, 5], [3, 6]] # (2, 2, 2) | |||
| assert np.allclose(outputs.asnumpy(), np.array(expected)) | |||
| if __name__ == '__main__': | |||
| test_case1_basic_func() | |||
| test_case2_indices_to_matrix() | |||
| test_case3_indices_to_3d_tensor() | |||
| test_case4() | |||
| test_case5() | |||
| test_case6() | |||
| test_case7() | |||
| test_case8() | |||
| test_case9() | |||
| @@ -0,0 +1,70 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| context.set_context(mode=context.GRAPH_MODE, device_target='CPU') | |||
| class OpNetWrapper(nn.Cell): | |||
| def __init__(self, op): | |||
| super(OpNetWrapper, self).__init__() | |||
| self.op = op | |||
| def construct(self, *inputs): | |||
| return self.op(*inputs) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_int32(): | |||
| op = P.GreaterEqual() | |||
| op_wrapper = OpNetWrapper(op) | |||
| input_x = Tensor(np.array([1, 2, 3]).astype(np.int32)) | |||
| input_y = Tensor(np.array([3, 2, 1]).astype(np.int32)) | |||
| outputs = op_wrapper(input_x, input_y) | |||
| print(outputs) | |||
| assert outputs.shape == (3,) | |||
| assert np.allclose(outputs.asnumpy(), [False, True, True]) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_float32(): | |||
| op = P.GreaterEqual() | |||
| op_wrapper = OpNetWrapper(op) | |||
| input_x = Tensor(np.array([1, 2, -1]).astype(np.float32)) | |||
| input_y = Tensor(np.array([-3, 2, -1]).astype(np.float32)) | |||
| outputs = op_wrapper(input_x, input_y) | |||
| print(outputs) | |||
| assert outputs.shape == (3,) | |||
| assert np.allclose(outputs.asnumpy(), [True, True, True]) | |||
| if __name__ == '__main__': | |||
| test_int32() | |||
| test_float32() | |||
| @@ -0,0 +1,70 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| context.set_context(mode=context.GRAPH_MODE, device_target='CPU') | |||
| class OpNetWrapper(nn.Cell): | |||
| def __init__(self, op): | |||
| super(OpNetWrapper, self).__init__() | |||
| self.op = op | |||
| def construct(self, *inputs): | |||
| return self.op(*inputs) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_int32(): | |||
| op = P.Greater() | |||
| op_wrapper = OpNetWrapper(op) | |||
| input_x = Tensor(np.array([1, 2, 3]).astype(np.int32)) | |||
| input_y = Tensor(np.array([3, 2, 1]).astype(np.int32)) | |||
| outputs = op_wrapper(input_x, input_y) | |||
| print(outputs) | |||
| assert outputs.shape == (3,) | |||
| assert np.allclose(outputs.asnumpy(), [False, False, True]) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_float32(): | |||
| op = P.Greater() | |||
| op_wrapper = OpNetWrapper(op) | |||
| input_x = Tensor(np.array([1, 2, -1]).astype(np.float32)) | |||
| input_y = Tensor(np.array([-3, 2, -1]).astype(np.float32)) | |||
| outputs = op_wrapper(input_x, input_y) | |||
| print(outputs) | |||
| assert outputs.shape == (3,) | |||
| assert np.allclose(outputs.asnumpy(), [True, False, False]) | |||
| if __name__ == '__main__': | |||
| test_int32() | |||
| test_float32() | |||
| @@ -0,0 +1,62 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| context.set_context(mode=context.GRAPH_MODE, device_target='CPU') | |||
| class OpNetWrapper(nn.Cell): | |||
| def __init__(self, op): | |||
| super(OpNetWrapper, self).__init__() | |||
| self.op = op | |||
| def construct(self, *inputs): | |||
| return self.op(*inputs) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_int(): | |||
| op = nn.Range(0, 100, 10) | |||
| op_wrapper = OpNetWrapper(op) | |||
| outputs = op_wrapper() | |||
| print(outputs) | |||
| assert outputs.shape == (10,) | |||
| assert np.allclose(outputs.asnumpy(), range(0, 100, 10)) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_float(): | |||
| op = nn.Range(10., 100., 20.) | |||
| op_wrapper = OpNetWrapper(op) | |||
| outputs = op_wrapper() | |||
| print(outputs) | |||
| assert outputs.shape == (5,) | |||
| assert np.allclose(outputs.asnumpy(), [10., 30., 50., 70., 90.]) | |||
| if __name__ == '__main__': | |||
| test_int() | |||
| test_float() | |||