From: @caojian05 Reviewed-by: Signed-off-by:tags/v1.1.0
| @@ -99,6 +99,24 @@ void ArithmeticCPUKernel::Less(const T *input1, const T *input2, bool *out, size | |||
| } | |||
| } | |||
| template <typename T> | |||
| void ArithmeticCPUKernel::Equal(const T *input1, const T *input2, bool *out, size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| std::vector<size_t> idx; | |||
| GenIndex(i, &idx); | |||
| out[i] = input1[idx[0]] == input2[idx[1]]; | |||
| } | |||
| } | |||
| template <typename T> | |||
| void ArithmeticCPUKernel::NotEqual(const T *input1, const T *input2, bool *out, size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| std::vector<size_t> idx; | |||
| GenIndex(i, &idx); | |||
| out[i] = input1[idx[0]] != input2[idx[1]]; | |||
| } | |||
| } | |||
| void ArithmeticCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); | |||
| @@ -114,6 +132,10 @@ void ArithmeticCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| operate_type_ = POW; | |||
| } else if (kernel_name == prim::kPrimLess->name()) { | |||
| operate_type_ = LESS; | |||
| } else if (kernel_name == prim::kPrimEqual->name()) { | |||
| operate_type_ = EQUAL; | |||
| } else if (kernel_name == prim::kPrimNotEqual->name()) { | |||
| operate_type_ = NOTEQUAL; | |||
| } else if (kernel_name == prim::kPrimAssignAdd->name()) { | |||
| operate_type_ = ASSIGNADD; | |||
| } else { | |||
| @@ -141,19 +163,22 @@ void ArithmeticCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| if (dtype_ != AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 1)) { | |||
| MS_LOG(EXCEPTION) << "Input0 and input1 must has the same data type"; | |||
| } | |||
| target_dtype_ = AnfAlgo::GetOutputInferDataType(kernel_node, 0); | |||
| } | |||
| bool ArithmeticCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| const std::vector<kernel::AddressPtr> & /*workspace*/, | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| if (dtype_ == kNumberTypeInt32 || dtype_ == kNumberTypeInt16) { | |||
| if (dtype_ == kNumberTypeInt32 || dtype_ == kNumberTypeInt16 || dtype_ == kNumberTypeInt8) { | |||
| LaunchKernel<int>(inputs, outputs); | |||
| } else if (dtype_ == kNumberTypeFloat32 || dtype_ == kNumberTypeFloat16 || dtype_ == kNumberTypeFloat64) { | |||
| LaunchKernel<float>(inputs, outputs); | |||
| } else if (dtype_ == kNumberTypeInt64) { | |||
| LaunchKernel<int64_t>(inputs, outputs); | |||
| } else if (dtype_ == kNumberTypeBool) { | |||
| LaunchKernelLogic<bool>(inputs, outputs); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Data type is " << TypeIdLabel(dtype_) << "is not support."; | |||
| MS_LOG(EXCEPTION) << "Data type " << TypeIdLabel(dtype_) << "is not support."; | |||
| } | |||
| return true; | |||
| } | |||
| @@ -190,7 +215,8 @@ void ArithmeticCPUKernel::GenIndex(size_t num, std::vector<size_t> *idx) { | |||
| } | |||
| template <typename T> | |||
| void ArithmeticCPUKernel::LaunchLess(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) { | |||
| void ArithmeticCPUKernel::LaunchKernelLogic(const std::vector<AddressPtr> &inputs, | |||
| const std::vector<AddressPtr> &outputs) { | |||
| T *input1 = reinterpret_cast<T *>(inputs[0]->addr); | |||
| T *input2 = reinterpret_cast<T *>(inputs[1]->addr); | |||
| bool *output = reinterpret_cast<bool *>(outputs[0]->addr); | |||
| @@ -213,7 +239,15 @@ void ArithmeticCPUKernel::LaunchLess(const std::vector<AddressPtr> &inputs, cons | |||
| } | |||
| while (start < lens) { | |||
| size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size); | |||
| threads.emplace_back(std::thread(&ArithmeticCPUKernel::Less<T>, this, input1, input2, output, start, end)); | |||
| if (operate_type_ == LESS) { | |||
| threads.emplace_back(std::thread(&ArithmeticCPUKernel::Less<T>, this, input1, input2, output, start, end)); | |||
| } else if (operate_type_ == EQUAL) { | |||
| threads.emplace_back(std::thread(&ArithmeticCPUKernel::Equal<T>, this, input1, input2, output, start, end)); | |||
| } else if (operate_type_ == NOTEQUAL) { | |||
| threads.emplace_back(std::thread(&ArithmeticCPUKernel::NotEqual<T>, this, input1, input2, output, start, end)); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Not support " << operate_type_; | |||
| } | |||
| start += once_compute_size; | |||
| } | |||
| for (size_t i = 0; i < threads.size(); ++i) { | |||
| @@ -223,8 +257,8 @@ void ArithmeticCPUKernel::LaunchLess(const std::vector<AddressPtr> &inputs, cons | |||
| template <typename T> | |||
| void ArithmeticCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) { | |||
| if (operate_type_ == LESS) { | |||
| LaunchLess<T>(inputs, outputs); | |||
| if (target_dtype_ == kNumberTypeBool) { | |||
| LaunchKernelLogic<T>(inputs, outputs); | |||
| return; | |||
| } | |||
| T *input1 = reinterpret_cast<T *>(inputs[0]->addr); | |||
| @@ -17,6 +17,7 @@ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ARITHMETIC_CPU_KERNEL_H_ | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <limits> | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel.h" | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" | |||
| @@ -32,7 +33,7 @@ class ArithmeticCPUKernel : public CPUKernel { | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs) override; | |||
| template <typename T> | |||
| void LaunchLess(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs); | |||
| void LaunchKernelLogic(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs); | |||
| template <typename T> | |||
| void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs); | |||
| @@ -52,6 +53,10 @@ class ArithmeticCPUKernel : public CPUKernel { | |||
| void AssignAdd(T *input1, const T *input2, T *out, size_t start, size_t end); | |||
| template <typename T> | |||
| void Less(const T *input1, const T *input2, bool *out, size_t start, size_t end); | |||
| template <typename T> | |||
| void Equal(const T *input1, const T *input2, bool *out, size_t start, size_t end); | |||
| template <typename T> | |||
| void NotEqual(const T *input1, const T *input2, bool *out, size_t start, size_t end); | |||
| std::vector<size_t> input_shape0_; | |||
| std::vector<size_t> input_shape1_; | |||
| std::vector<size_t> input_element_num0_; | |||
| @@ -60,6 +65,7 @@ class ArithmeticCPUKernel : public CPUKernel { | |||
| std::vector<size_t> output_element_num_; | |||
| OperateType operate_type_{ADD}; | |||
| TypeId dtype_{kTypeUnknown}; | |||
| TypeId target_dtype_{kTypeUnknown}; | |||
| }; | |||
| MS_REG_CPU_KERNEL( | |||
| @@ -108,6 +114,70 @@ MS_REG_CPU_KERNEL( | |||
| MS_REG_CPU_KERNEL( | |||
| Mul, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||
| ArithmeticCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| Equal, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), | |||
| ArithmeticCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| Equal, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool), | |||
| ArithmeticCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| Equal, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool), | |||
| ArithmeticCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| Equal, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool), | |||
| ArithmeticCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| Equal, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool), | |||
| ArithmeticCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| Equal, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeBool), | |||
| ArithmeticCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| Equal, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeBool), | |||
| ArithmeticCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| Equal, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), | |||
| ArithmeticCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| Equal, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), | |||
| ArithmeticCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| Equal, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool), | |||
| ArithmeticCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| NotEqual, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), | |||
| ArithmeticCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| NotEqual, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool), | |||
| ArithmeticCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| NotEqual, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool), | |||
| ArithmeticCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| NotEqual, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool), | |||
| ArithmeticCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| NotEqual, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool), | |||
| ArithmeticCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| NotEqual, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeBool), | |||
| ArithmeticCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| NotEqual, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeBool), | |||
| ArithmeticCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| NotEqual, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), | |||
| ArithmeticCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| NotEqual, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), | |||
| ArithmeticCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| NotEqual, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool), | |||
| ArithmeticCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -29,6 +29,19 @@ void Square(const T *in, T *out, size_t start, size_t end) { | |||
| } | |||
| } | |||
| template <typename T> | |||
| void Sign(const T *in, T *out, size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| if (in[i] < 0) { | |||
| out[i] = -1; | |||
| } else if (in[i] > 0) { | |||
| out[i] = 1; | |||
| } else { | |||
| out[i] = 0; | |||
| } | |||
| } | |||
| } | |||
| template <typename T> | |||
| void Neg(const T *in, T *out, size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| @@ -62,6 +75,8 @@ void ArithmeticSelfCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| operate_type_ = ZEROSLIKE; | |||
| } else if (kernel_name == prim::kPrimNeg->name()) { | |||
| operate_type_ = NEG; | |||
| } else if (kernel_name == prim::kPrimSign->name()) { | |||
| operate_type_ = SIGN; | |||
| } | |||
| dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); | |||
| } | |||
| @@ -111,6 +126,8 @@ void ArithmeticSelfCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs | |||
| threads.emplace_back(std::thread(OnesLike<T>, input, output, start, end)); | |||
| } else if (operate_type_ == ZEROSLIKE) { | |||
| threads.emplace_back(std::thread(ZerosLike<T>, input, output, start, end)); | |||
| } else if (operate_type_ == SIGN) { | |||
| threads.emplace_back(std::thread(Sign<T>, input, output, start, end)); | |||
| } | |||
| start += once_compute_size; | |||
| } | |||
| @@ -54,6 +54,10 @@ MS_REG_CPU_KERNEL(OnesLike, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOut | |||
| ArithmeticSelfCPUKernel); | |||
| MS_REG_CPU_KERNEL(OnesLike, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||
| ArithmeticSelfCPUKernel); | |||
| MS_REG_CPU_KERNEL(Sign, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| ArithmeticSelfCPUKernel); | |||
| MS_REG_CPU_KERNEL(Sign, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||
| ArithmeticSelfCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -71,7 +71,10 @@ enum OperateType { | |||
| SQRTGRAD, | |||
| SIGMOIDGRAD, | |||
| ONESLIKE, | |||
| ZEROSLIKE | |||
| ZEROSLIKE, | |||
| SIGN, | |||
| EQUAL, | |||
| NOTEQUAL, | |||
| }; | |||
| class CPUKernel : public kernel::KernelMod { | |||
| @@ -1,99 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/cpu/equal_cpu_kernel.h" | |||
| #include "runtime/device/cpu/cpu_device_address.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| void EqualCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| CheckParam(kernel_node); | |||
| dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); | |||
| if (dtype_ != AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 1)) { | |||
| MS_LOG(EXCEPTION) << "Input0 and input1 must has the same data type"; | |||
| } | |||
| } | |||
| bool EqualCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| const std::vector<kernel::AddressPtr> & /*workspace*/, | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| if (dtype_ == kNumberTypeBool) { | |||
| LaunchKernel<bool>(inputs, outputs); | |||
| } else if (dtype_ == kNumberTypeInt8) { | |||
| LaunchKernel<int8_t>(inputs, outputs); | |||
| } else if (dtype_ == kNumberTypeInt16) { | |||
| LaunchKernel<int16_t>(inputs, outputs); | |||
| } else if (dtype_ == kNumberTypeInt32 || dtype_ == kNumberTypeInt) { | |||
| LaunchKernel<int32_t>(inputs, outputs); | |||
| } else if (dtype_ == kNumberTypeInt64) { | |||
| LaunchKernel<int64_t>(inputs, outputs); | |||
| } else if (dtype_ == kNumberTypeUInt8) { | |||
| LaunchKernel<uint8_t>(inputs, outputs); | |||
| } else if (dtype_ == kNumberTypeUInt16) { | |||
| LaunchKernel<uint16_t>(inputs, outputs); | |||
| } else if (dtype_ == kNumberTypeUInt32 || dtype_ == kNumberTypeUInt) { | |||
| LaunchKernel<uint32_t>(inputs, outputs); | |||
| } else if (dtype_ == kNumberTypeUInt64) { | |||
| LaunchKernel<uint64_t>(inputs, outputs); | |||
| } else if (dtype_ == kNumberTypeFloat16) { | |||
| LaunchKernel<float16>(inputs, outputs); | |||
| } else if (dtype_ == kNumberTypeFloat32 || dtype_ == kNumberTypeFloat) { | |||
| LaunchKernel<float>(inputs, outputs); | |||
| } else if (dtype_ == kNumberTypeFloat64) { | |||
| LaunchKernel<double>(inputs, outputs); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Only support bool, int, uint, float, but actual data type is " << TypeIdLabel(dtype_); | |||
| } | |||
| return true; | |||
| } | |||
| template <typename T> | |||
| void EqualCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) { | |||
| T *left = reinterpret_cast<T *>(inputs[0]->addr); | |||
| T *right = reinterpret_cast<T *>(inputs[1]->addr); | |||
| bool *output = reinterpret_cast<bool *>(outputs[0]->addr); | |||
| size_t elem_num = inputs[0]->size / sizeof(T); | |||
| for (size_t i = 0; i < elem_num; i++) { | |||
| if (left[i] == right[i]) { | |||
| output[i] = true; | |||
| } else { | |||
| output[i] = false; | |||
| } | |||
| } | |||
| } | |||
| void EqualCPUKernel::CheckParam(const CNodePtr &kernel_node) { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 2) { | |||
| MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but EqualCPUKernel needs 2 inputs."; | |||
| } | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| if (output_num != 1) { | |||
| MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but EqualCPUKernel needs 1 output."; | |||
| } | |||
| auto input_shape0 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| auto input_shape1 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); | |||
| if (input_shape0.size() != input_shape1.size()) { | |||
| MS_LOG(EXCEPTION) << "Input0 and Input1 must have the same shape"; | |||
| } | |||
| for (size_t i = 0; i < input_shape0.size(); ++i) { | |||
| if (input_shape0[i] != input_shape1[i]) { | |||
| MS_LOG(EXCEPTION) << "Input0 and Input1 must have the same shape"; | |||
| } | |||
| } | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -1,75 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EQUAL_CPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EQUAL_CPU_KERNEL_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel.h" | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| class EqualCPUKernel : public CPUKernel { | |||
| public: | |||
| EqualCPUKernel() = default; | |||
| ~EqualCPUKernel() override = default; | |||
| void InitKernel(const CNodePtr &kernel_node) override; | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs) override; | |||
| template <typename T> | |||
| void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs); | |||
| private: | |||
| void CheckParam(const CNodePtr &kernel_node); | |||
| TypeId dtype_{kTypeUnknown}; | |||
| }; | |||
| MS_REG_CPU_KERNEL( | |||
| Equal, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), | |||
| EqualCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| Equal, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool), | |||
| EqualCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| Equal, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool), | |||
| EqualCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| Equal, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool), | |||
| EqualCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| Equal, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool), | |||
| EqualCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| Equal, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeBool), | |||
| EqualCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| Equal, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeBool), | |||
| EqualCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| Equal, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), | |||
| EqualCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| Equal, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), | |||
| EqualCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| Equal, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool), | |||
| EqualCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EQUAL_CPU_KERNEL_H_ | |||
| @@ -0,0 +1,89 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/cpu/mkldnn/pooling_avg_grad_cpu_kernel.h" | |||
| #include <string> | |||
| #include <utility> | |||
| #include <algorithm> | |||
| #include "utils/ms_utils.h" | |||
| #include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" | |||
| #include "runtime/device/cpu/cpu_device_address.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| void AvgPoolingGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| std::vector<size_t> src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); | |||
| std::vector<size_t> dst_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); | |||
| dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape); | |||
| dnnl::memory::desc dst_desc = GetDefaultMemDesc(dst_shape); | |||
| std::vector<int> origin_kernel_sizes; | |||
| std::vector<int> strides; | |||
| std::vector<int64_t> kernel_sizes_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, KSIZE); | |||
| std::vector<int64_t> strides_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, STRIDES); | |||
| (void)std::transform(kernel_sizes_me.begin(), kernel_sizes_me.end(), std::back_inserter(origin_kernel_sizes), | |||
| [](const int64_t &value) { return static_cast<int>(value); }); | |||
| (void)std::transform(strides_me.begin(), strides_me.end(), std::back_inserter(strides), | |||
| [](const int64_t &value) { return static_cast<int>(value); }); | |||
| if (origin_kernel_sizes.size() != 4 || strides.size() != 4) { | |||
| MS_LOG(EXCEPTION) << "Invalid kernel size " << origin_kernel_sizes.size() << " or stride size " << strides.size(); | |||
| } | |||
| dnnl::memory::dims strides_dims{strides[2], strides[3]}; | |||
| dnnl::memory::dims kernels_dims{origin_kernel_sizes[2], origin_kernel_sizes[3]}; | |||
| const std::string pad_mode = AnfAlgo::GetNodeAttr<std::string>(kernel_node, PADDING); | |||
| std::vector<int> int_padding_l; | |||
| std::vector<int> int_padding_r; | |||
| std::vector<size_t> kernel_size({IntToSize(origin_kernel_sizes[2]), IntToSize(origin_kernel_sizes[3])}); | |||
| GetPadding(kernel_node, pad_mode, src_shape, kernel_size, strides[3], &int_padding_l, &int_padding_r); | |||
| if (int_padding_l.size() != 2 || int_padding_r.size() != 2) { | |||
| MS_LOG(EXCEPTION) << "Pooling avg get padding failed"; | |||
| } | |||
| dnnl::memory::dims padding_l{int_padding_l[0], int_padding_l[1]}; | |||
| dnnl::memory::dims padding_r{int_padding_r[0], int_padding_r[1]}; | |||
| // pooling_avg forward description | |||
| dnnl::pooling_forward::desc desc = | |||
| dnnl::pooling_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::pooling_avg, src_desc, dst_desc, | |||
| strides_dims, kernels_dims, padding_l, padding_r); | |||
| auto prim_desc = dnnl::pooling_forward::primitive_desc(desc, MKLKernelEngine::Get().engine()); | |||
| // pooling_avg backward description | |||
| dnnl::pooling_backward::desc backward_desc = dnnl::pooling_backward::desc( | |||
| dnnl::algorithm::pooling_avg, src_desc, dst_desc, strides_dims, kernels_dims, padding_l, padding_r); | |||
| auto backward_prim_desc = | |||
| dnnl::pooling_backward::primitive_desc(backward_desc, MKLKernelEngine::Get().engine(), prim_desc); | |||
| primitive_ = std::make_shared<dnnl::pooling_backward>(backward_prim_desc); | |||
| AddArgument(DNNL_ARG_SRC, src_desc); | |||
| AddArgument(DNNL_ARG_DST, dst_desc); | |||
| AddArgument(DNNL_ARG_DIFF_SRC, src_desc); | |||
| AddArgument(DNNL_ARG_DIFF_DST, dst_desc); | |||
| } | |||
| bool AvgPoolingGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| const std::vector<kernel::AddressPtr> & /*workspace*/, | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| if (inputs.size() < 3 || outputs.empty()) { | |||
| MS_LOG(EXCEPTION) << "Pooling avg grad error input output size!"; | |||
| } | |||
| SetArgumentHandle(DNNL_ARG_SRC, inputs[0]->addr); | |||
| SetArgumentHandle(DNNL_ARG_DST, inputs[1]->addr); | |||
| SetArgumentHandle(DNNL_ARG_DIFF_DST, inputs[2]->addr); | |||
| SetArgumentHandle(DNNL_ARG_DIFF_SRC, outputs[0]->addr); | |||
| ExecutePrimitive(); | |||
| return true; | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,51 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_POOLING_AVG_GRAD_CPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_POOLING_AVG_GRAD_CPU_KERNEL_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <utility> | |||
| #include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| class AvgPoolingGradCPUKernel : public MKLCPUKernel { | |||
| public: | |||
| AvgPoolingGradCPUKernel() = default; | |||
| ~AvgPoolingGradCPUKernel() override = default; | |||
| void InitKernel(const CNodePtr &kernel_node) override; | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs) override; | |||
| private: | |||
| int stride_{0}; | |||
| std::vector<size_t> kernel_size_; | |||
| }; | |||
| MS_REG_CPU_KERNEL(AvgPoolGradCpu, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| AvgPoolingGradCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_POOLING_AVG_GRAD_CPU_KERNEL_H_ | |||
| @@ -54,6 +54,11 @@ void PoolingCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| dnnl::pooling_forward::desc desc = | |||
| dnnl::pooling_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::pooling_max, src_desc, dst_desc, | |||
| strides_dims, kernels_dims, padding_l, padding_r); | |||
| std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); | |||
| if (kernel_name == prim::kPrimAvgPool->name()) { | |||
| desc = dnnl::pooling_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::pooling_avg, src_desc, | |||
| dst_desc, strides_dims, kernels_dims, padding_l, padding_r); | |||
| } | |||
| auto prim_desc = dnnl::pooling_forward::primitive_desc(desc, MKLKernelEngine::Get().engine()); | |||
| primitive_ = std::make_shared<dnnl::pooling_forward>(prim_desc); | |||
| AddArgument(DNNL_ARG_SRC, src_desc); | |||
| @@ -35,6 +35,8 @@ class PoolingCPUKernel : public MKLCPUKernel { | |||
| MS_REG_CPU_KERNEL(MaxPool, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| PoolingCPUKernel); | |||
| MS_REG_CPU_KERNEL(AvgPool, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| PoolingCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -13,7 +13,7 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/cpu/mkldnn/pooling_grad_cpu_kernel.h" | |||
| #include "backend/kernel_compiler/cpu/mkldnn/pooling_max_grad_cpu_kernel.h" | |||
| #include <string> | |||
| #include <utility> | |||
| #include <algorithm> | |||
| @@ -23,7 +23,7 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| void PoolingGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| void MaxPoolingGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| src_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); | |||
| dst_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1); | |||
| @@ -45,9 +45,9 @@ void PoolingGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| GetPadding(kernel_node, pad_mode, src_shape_, kernel_size_, stride_, &padding_l_, &padding_r); | |||
| } | |||
| void PoolingGradCPUKernel::RowPoolingGrad(const float *input, float *output, float diff, | |||
| const std::vector<std::pair<size_t, size_t>> &box, | |||
| std::vector<std::pair<size_t, float>> *row_max_pair) { | |||
| void MaxPoolingGradCPUKernel::RowPoolingGrad(const float *input, float *output, float diff, | |||
| const std::vector<std::pair<size_t, size_t>> &box, | |||
| std::vector<std::pair<size_t, float>> *row_max_pair) { | |||
| float max_value = 0; | |||
| size_t max_index = box[1].second; | |||
| size_t src_width = src_shape_[3]; | |||
| @@ -74,7 +74,7 @@ void PoolingGradCPUKernel::RowPoolingGrad(const float *input, float *output, flo | |||
| output[(*row_max_pair)[max_index].first] += diff; | |||
| } | |||
| void PoolingGradCPUKernel::ChannelPoolingGrad(const float *input, const float *diff, float *output) { | |||
| void MaxPoolingGradCPUKernel::ChannelPoolingGrad(const float *input, const float *diff, float *output) { | |||
| int src_width = SizeToInt(src_shape_[3]); | |||
| int src_height = SizeToInt(src_shape_[2]); | |||
| std::vector<std::pair<size_t, float>> row_max_pair(src_shape_[3]); | |||
| @@ -100,9 +100,9 @@ void PoolingGradCPUKernel::ChannelPoolingGrad(const float *input, const float *d | |||
| } | |||
| } | |||
| bool PoolingGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| const std::vector<kernel::AddressPtr> & /*workspace*/, | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| bool MaxPoolingGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| const std::vector<kernel::AddressPtr> & /*workspace*/, | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| if (inputs.size() < 3 || outputs.empty()) { | |||
| MS_LOG(EXCEPTION) << "pooling grad error input output size!"; | |||
| } | |||
| @@ -13,8 +13,8 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_POOLING_GRAD_CPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_POOLING_GRAD_CPU_KERNEL_H_ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_POOLING_MAX_GRAD_CPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_POOLING_MAX_GRAD_CPU_KERNEL_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| @@ -23,10 +23,10 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| class PoolingGradCPUKernel : public MKLCPUKernel { | |||
| class MaxPoolingGradCPUKernel : public MKLCPUKernel { | |||
| public: | |||
| PoolingGradCPUKernel() = default; | |||
| ~PoolingGradCPUKernel() override = default; | |||
| MaxPoolingGradCPUKernel() = default; | |||
| ~MaxPoolingGradCPUKernel() override = default; | |||
| void InitKernel(const CNodePtr &kernel_node) override; | |||
| @@ -50,8 +50,8 @@ MS_REG_CPU_KERNEL(MaxPoolGrad, | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| PoolingGradCPUKernel); | |||
| MaxPoolingGradCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_POOLING_GRAD_CPU_KERNEL_H_ | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_POOLING_MAX_GRAD_CPU_KERNEL_H_ | |||
| @@ -0,0 +1,124 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/cpu/split_cpu_kernel.h" | |||
| #include "runtime/device/cpu/cpu_device_address.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| void SplitCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| CheckParam(kernel_node); | |||
| axis_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, AXIS); | |||
| auto output_1_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); | |||
| if (axis_ < 0) { | |||
| axis_ = axis_ + SizeToLong(output_1_shape.size()); | |||
| } | |||
| axis_ += 4 - SizeToLong(output_1_shape.size()); | |||
| auto output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| for (size_t i = 0; i < output_num; i++) { | |||
| auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, i); | |||
| CPUKernelUtils::ExpandDimsTo4(&output_shape); | |||
| output_shape_list_.push_back(output_shape); | |||
| } | |||
| input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| CPUKernelUtils::ExpandDimsTo4(&input_shape_); | |||
| dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); | |||
| } | |||
| bool SplitCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| const std::vector<kernel::AddressPtr> & /*workspace*/, | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| if (dtype_ == kNumberTypeInt32 || dtype_ == kNumberTypeInt) { | |||
| return LaunchKernel<int32_t>(inputs, outputs); | |||
| } else if (dtype_ == kNumberTypeInt64) { | |||
| return LaunchKernel<int64_t>(inputs, outputs); | |||
| } else if (dtype_ == kNumberTypeFloat32 || dtype_ == kNumberTypeFloat) { | |||
| return LaunchKernel<float>(inputs, outputs); | |||
| } else if (dtype_ == kNumberTypeFloat64) { | |||
| return LaunchKernel<double>(inputs, outputs); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Only support int, float, but actual data type is " << TypeIdLabel(dtype_); | |||
| } | |||
| } | |||
| template <typename T> | |||
| bool SplitCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) { | |||
| auto input_addr = reinterpret_cast<T *>(inputs[0]->addr); | |||
| auto buff_size = inputs[0]->size; | |||
| size_t dim0 = input_shape_[0]; | |||
| size_t dim1 = input_shape_[1]; | |||
| size_t dim2 = input_shape_[2]; | |||
| if (axis_ == 3) { | |||
| for (size_t i = 0; i < dim0; ++i) { | |||
| for (size_t j = 0; j < dim1; ++j) { | |||
| for (size_t k = 0; k < dim2; ++k) { | |||
| CopyDataToOutput(outputs, i, j, k, &input_addr, &buff_size); | |||
| } | |||
| } | |||
| } | |||
| } else if (axis_ == 2) { | |||
| for (size_t i = 0; i < dim0; ++i) { | |||
| for (size_t j = 0; j < dim1; ++j) { | |||
| CopyDataToOutput(outputs, i, j, 0, &input_addr, &buff_size); | |||
| } | |||
| } | |||
| } else if (axis_ == 1) { | |||
| for (size_t i = 0; i < dim0; ++i) { | |||
| CopyDataToOutput(outputs, i, 0, 0, &input_addr, &buff_size); | |||
| } | |||
| } else if (axis_ == 0) { | |||
| CopyDataToOutput(outputs, 0, 0, 0, &input_addr, &buff_size); | |||
| } | |||
| return true; | |||
| } | |||
| template <typename T> | |||
| void SplitCPUKernel::CopyDataToOutput(const std::vector<kernel::AddressPtr> &outputs, size_t dim0, size_t dim1, | |||
| size_t dim2, T **input_addr, size_t *buff_size) { | |||
| for (size_t i = 0; i < output_shape_list_.size(); ++i) { | |||
| auto output_i_shape = output_shape_list_[i]; | |||
| auto output_i_addr = reinterpret_cast<float *>(outputs[i]->addr); | |||
| size_t num = CPUKernelUtils::GetElementNumOnAxis(output_i_shape, axis_); | |||
| num *= output_i_shape[axis_]; | |||
| auto pos = CPUKernelUtils::CalcOffset(output_i_shape, dim0, dim1, dim2, 0); | |||
| auto ret = memcpy_s(output_i_addr + pos, *buff_size, *input_addr, num * sizeof(T)); | |||
| if (ret != EOK) { | |||
| MS_LOG(EXCEPTION) << "memcpy failed."; | |||
| } | |||
| *input_addr += num; | |||
| *buff_size -= num * sizeof(T); | |||
| } | |||
| } | |||
| void SplitCPUKernel::CheckParam(const CNodePtr &kernel_node) { | |||
| auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); | |||
| if (output_shape.size() > 4) { | |||
| MS_LOG(EXCEPTION) << "Output dims is " << output_shape.size() << ", but SplitCPUKernel only support 4d or lower."; | |||
| } | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 1) { | |||
| MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but SplitCPUKernel needs 1 input."; | |||
| } | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,58 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPLIT_CPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPLIT_CPU_KERNEL_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel.h" | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| class SplitCPUKernel : public CPUKernel { | |||
| public: | |||
| SplitCPUKernel() : axis_(0) {} | |||
| ~SplitCPUKernel() override = default; | |||
| void InitKernel(const CNodePtr &kernel_node) override; | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs) override; | |||
| template <typename T> | |||
| bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs); | |||
| private: | |||
| static void CheckParam(const CNodePtr &kernel_node); | |||
| template <typename T> | |||
| void CopyDataToOutput(const std::vector<kernel::AddressPtr> &inputs, size_t dim0, size_t dim1, size_t dim2, | |||
| T **output_addr, size_t *buff_size); | |||
| int64_t axis_; | |||
| std::vector<std::vector<size_t>> output_shape_list_; | |||
| std::vector<size_t> input_shape_; | |||
| TypeId dtype_{kTypeUnknown}; | |||
| }; | |||
| MS_REG_CPU_KERNEL(Split, | |||
| KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| SplitCPUKernel); | |||
| MS_REG_CPU_KERNEL(Split, | |||
| KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||
| SplitCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPLIT_CPU_KERNEL_H_ | |||
| @@ -141,6 +141,7 @@ inline const PrimitivePtr kPrimApplyCenteredRMSProp = std::make_shared<Primitive | |||
| inline const PrimitivePtr kPrimAvgPool = std::make_shared<Primitive>("AvgPool"); | |||
| inline const PrimitivePtr kPrimAvgPoolGrad = std::make_shared<Primitive>("AvgPoolGrad"); | |||
| inline const PrimitivePtr kPrimAvgPoolGradVm = std::make_shared<Primitive>("AvgPoolGradVm"); | |||
| inline const PrimitivePtr kPrimAvgPoolGradCpu = std::make_shared<Primitive>("AvgPoolGradCpu"); | |||
| inline const PrimitivePtr kPrimFusedSparseAdam = std::make_shared<Primitive>("FusedSparseAdam"); | |||
| inline const PrimitivePtr kPrimFusedBatchNorm = std::make_shared<Primitive>("FusedBatchNorm"); | |||
| inline const PrimitivePtr kPrimFusedBatchNormEx = std::make_shared<Primitive>("FusedBatchNormEx"); | |||
| @@ -263,6 +264,7 @@ inline const PrimitivePtr kPrimLog = std::make_shared<Primitive>("Log"); | |||
| inline const PrimitivePtr kPrimRsqrt = std::make_shared<Primitive>("Rsqrt"); | |||
| inline const PrimitivePtr kPrimSplitV = std::make_shared<Primitive>("SplitV"); | |||
| inline const PrimitivePtr kPrimLinSpace = std::make_shared<Primitive>("LinSpace"); | |||
| inline const PrimitivePtr kPrimSign = std::make_shared<Primitive>("Sign"); | |||
| // Statements | |||
| inline const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("return"); | |||
| @@ -338,6 +338,19 @@ def get_bprop_avg_pool_grad(self): | |||
| bprop_fn = bprop_gpu | |||
| elif self.target == "CPU": | |||
| avgpool_grad_cpu = G.AvgPoolGradCpu( | |||
| ksize=self.ksize, | |||
| strides=self.strides, | |||
| padding=self.padding, | |||
| data_format=self.format) | |||
| def bprop_cpu(x, out, dout): | |||
| dx = avgpool_grad_cpu(x, out, dout) | |||
| return (dx,) | |||
| bprop_fn = bprop_cpu | |||
| elif self.target == "GE": | |||
| avgpool_grad_ge = G.AvgPoolGrad( | |||
| ksize=self.ksize, | |||
| @@ -885,6 +885,20 @@ class AvgPoolGradGpu(_PoolGrad): | |||
| return x1_dtype | |||
| class AvgPoolGradCpu(_PoolGrad): | |||
| """Gradients of the avg pool operation for cpu.""" | |||
| @prim_attr_register | |||
| def __init__(self, ksize=1, strides=1, padding="VALID", data_format="NCHW"): | |||
| super(AvgPoolGradCpu, self).__init__(ksize, strides, padding, data_format) | |||
| def infer_shape(self, x1_shape, x2_shape, grad_shape): | |||
| return x1_shape | |||
| def infer_dtype(self, x1_dtype, x2_dtype, grad_dtype): | |||
| return x1_dtype | |||
| class MaxPoolGrad(_PoolGrad): | |||
| """Performs gradients of the max pool operation.""" | |||
| @@ -951,7 +951,7 @@ class Split(PrimitiveWithCheck): | |||
| :math:`(y_1, y_2, ..., y_S)`. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> split = ops.Split(1, 2) | |||
| @@ -2519,7 +2519,7 @@ class Equal(_LogicBinaryOp): | |||
| Tensor, the shape is the same as the one after broadcasting,and the data type is bool. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> input_x = Tensor(np.array([1, 2, 3]), mindspore.float32) | |||
| @@ -2656,7 +2656,7 @@ class NotEqual(_LogicBinaryOp): | |||
| Tensor, the shape is the same as the one after broadcasting,and the data type is bool. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> input_x = Tensor(np.array([1, 2, 3]), mindspore.float32) | |||
| @@ -3459,7 +3459,7 @@ class Sign(PrimitiveWithInfer): | |||
| Tensor, has the same shape and type as the `input_x`. | |||
| Supported Platforms: | |||
| ``Ascend`` | |||
| ``Ascend`` ``CPU`` | |||
| Examples: | |||
| >>> input_x = Tensor(np.array([[2.0, 0.0, -1.0]]), mindspore.float32) | |||
| @@ -1639,7 +1639,7 @@ class AvgPool(_Pool): | |||
| Tensor, with shape :math:`(N, C_{out}, H_{out}, W_{out})`. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> import mindspore | |||
| @@ -1672,6 +1672,8 @@ class AvgPool(_Pool): | |||
| def __init__(self, ksize=1, strides=1, padding="valid", data_format="NCHW"): | |||
| if context.get_context("device_target") == "GPU": | |||
| self.target = "GPU" | |||
| elif context.get_context("device_target") == "CPU": | |||
| self.target = "CPU" | |||
| elif context.get_context("enable_ge"): | |||
| self.target = "GE" | |||
| else: | |||
| @@ -0,0 +1,51 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.common.api import ms_function | |||
| from mindspore.ops.composite import GradOperation | |||
| context.set_context(mode=context.GRAPH_MODE, device_target='CPU') | |||
| class Grad(nn.Cell): | |||
| def __init__(self, network): | |||
| super(Grad, self).__init__() | |||
| self.grad = GradOperation(get_all=True, sens_param=True) | |||
| self.network = network | |||
| @ms_function | |||
| def construct(self, input_, output_grad): | |||
| return self.grad(self.network)(input_, output_grad) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_net(): | |||
| x = np.arange(1 * 1 * 6 * 6).reshape((1, 1, 6, 6)).astype(np.float32) | |||
| net = nn.AvgPool2d(kernel_size=3, stride=2, pad_mode='valid') | |||
| out = net(Tensor(x)) | |||
| out_shape = out.asnumpy().shape | |||
| sens = np.arange(int(np.prod(out_shape))).reshape(out_shape).astype(np.float32) | |||
| backword_net = Grad(net) | |||
| output = backword_net(Tensor(x), Tensor(sens)) | |||
| print(len(output)) | |||
| print(output[0].asnumpy()) | |||
| @@ -0,0 +1,94 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| context.set_context(mode=context.GRAPH_MODE, device_target='CPU') | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_avgpool_k2s1pv(): | |||
| x = np.arange(1 * 1 * 6 * 6).reshape((1, 1, 6, 6)).astype(np.float32) | |||
| net = nn.AvgPool2d(kernel_size=2, stride=1, pad_mode='valid') | |||
| out = net(Tensor(x)) | |||
| print(out) | |||
| expect_result = np.array( | |||
| [[[[3.5, 4.5, 5.5, 6.5, 7.5], | |||
| [9.5, 10.5, 11.5, 12.5, 13.5], | |||
| [15.5, 16.5, 17.5, 18.5, 19.5], | |||
| [21.5, 22.5, 23.5, 24.5, 25.5], | |||
| [27.5, 28.5, 29.5, 30.5, 31.5]]]] | |||
| ) | |||
| assert np.allclose(out.asnumpy(), expect_result) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_avgpool_k2s2pv(): | |||
| x = np.arange(1 * 1 * 6 * 6).reshape((1, 1, 6, 6)).astype(np.float32) | |||
| net = nn.AvgPool2d(kernel_size=2, stride=2, pad_mode='valid') | |||
| out = net(Tensor(x)) | |||
| print(out) | |||
| expect_result = np.array( | |||
| [[[[3.5, 5.5, 7.5], | |||
| [15.5, 17.5, 19.5], | |||
| [27.5, 29.5, 31.5]]]] | |||
| ) | |||
| assert np.allclose(out.asnumpy(), expect_result) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_avgpool_k3s2pv(): | |||
| x = np.arange(1 * 1 * 6 * 6).reshape((1, 1, 6, 6)).astype(np.float32) | |||
| net = nn.AvgPool2d(kernel_size=3, stride=2, pad_mode='valid') | |||
| out = net(Tensor(x)) | |||
| print(out) | |||
| expect_result = np.array( | |||
| [[[[7., 9.], | |||
| [19., 21.]]]] | |||
| ) | |||
| assert np.allclose(out.asnumpy(), expect_result) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_avgpool_k3s2ps(): | |||
| x = np.arange(1 * 1 * 6 * 6).reshape((1, 1, 6, 6)).astype(np.float32) | |||
| net = nn.AvgPool2d(kernel_size=3, stride=2, pad_mode='same') | |||
| out = net(Tensor(x)) | |||
| print(out) | |||
| expect_result = np.array( | |||
| [[[[7., 9., 10.5], | |||
| [19., 21., 22.5], | |||
| [28., 30., 31.5]]]] | |||
| ) | |||
| assert np.allclose(out.asnumpy(), expect_result) | |||
| if __name__ == '__main__': | |||
| test_avgpool_k2s1pv() | |||
| test_avgpool_k2s2pv() | |||
| test_avgpool_k3s2pv() | |||
| test_avgpool_k3s2ps() | |||
| @@ -0,0 +1,68 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| context.set_context(mode=context.GRAPH_MODE, device_target='CPU') | |||
| class OpNetWrapper(nn.Cell): | |||
| def __init__(self, op): | |||
| super(OpNetWrapper, self).__init__() | |||
| self.op = op | |||
| def construct(self, *inputs): | |||
| return self.op(*inputs) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_notequal_int(): | |||
| op = P.NotEqual() | |||
| op_wrapper = OpNetWrapper(op) | |||
| input_x = Tensor(np.array([1, 2, 3]).astype(np.int32)) | |||
| input_y = Tensor(np.array([11, 2, 13]).astype(np.int32)) | |||
| outputs = op_wrapper(input_x, input_y) | |||
| print(outputs) | |||
| assert np.allclose(outputs.asnumpy(), (True, False, True)) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_notequal_float(): | |||
| op = P.NotEqual() | |||
| op_wrapper = OpNetWrapper(op) | |||
| input_x = Tensor(np.array([1, 2, 3]).astype(np.float32)) | |||
| input_y = Tensor(np.array([-1, 0, 3]).astype(np.float32)) | |||
| outputs = op_wrapper(input_x, input_y) | |||
| print(outputs) | |||
| assert np.allclose(outputs.asnumpy(), (True, True, False)) | |||
| if __name__ == '__main__': | |||
| test_notequal_int() | |||
| test_notequal_float() | |||
| @@ -0,0 +1,66 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| context.set_context(mode=context.GRAPH_MODE, device_target='CPU') | |||
| class OpNetWrapper(nn.Cell): | |||
| def __init__(self, op): | |||
| super(OpNetWrapper, self).__init__() | |||
| self.op = op | |||
| def construct(self, *inputs): | |||
| return self.op(*inputs) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_sign_float32(): | |||
| op = P.Sign() | |||
| op_wrapper = OpNetWrapper(op) | |||
| input_x = Tensor(np.array([[2.0, 0.0, -1.0]]).astype(np.float32)) | |||
| outputs = op_wrapper(input_x) | |||
| print(outputs) | |||
| assert np.allclose(outputs.asnumpy(), [[1., 0., -1.]]) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_sign_int32(): | |||
| op = P.Sign() | |||
| op_wrapper = OpNetWrapper(op) | |||
| input_x = Tensor(np.array([[20, 0, -10]]).astype(np.int32)) | |||
| outputs = op_wrapper(input_x) | |||
| print(outputs) | |||
| assert np.allclose(outputs.asnumpy(), [[1, 0, -1]]) | |||
| if __name__ == '__main__': | |||
| test_sign_float32() | |||
| test_sign_int32() | |||
| @@ -0,0 +1,86 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| context.set_context(mode=context.GRAPH_MODE, device_target='CPU') | |||
| class OpNetWrapper(nn.Cell): | |||
| def __init__(self, op): | |||
| super(OpNetWrapper, self).__init__() | |||
| self.op = op | |||
| def construct(self, *inputs): | |||
| return self.op(*inputs) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_out1_axis0(): | |||
| op = P.Split(0, 1) | |||
| op_wrapper = OpNetWrapper(op) | |||
| input_x = Tensor(np.arange(24).astype(np.int32).reshape((2, 2, 6))) | |||
| outputs = op_wrapper(input_x) | |||
| print(outputs) | |||
| assert outputs[0].shape == (2, 2, 6) | |||
| assert np.allclose(outputs[0].asnumpy()[0, 0, :], [0, 1, 2, 3, 4, 5]) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_out2_axis2(): | |||
| op = P.Split(2, 2) | |||
| op_wrapper = OpNetWrapper(op) | |||
| input_x = Tensor(np.arange(24).astype(np.int32).reshape((2, 2, 6))) | |||
| outputs = op_wrapper(input_x) | |||
| print(outputs) | |||
| assert outputs[0].shape == (2, 2, 3) | |||
| assert outputs[1].shape == (2, 2, 3) | |||
| assert np.allclose(outputs[0].asnumpy()[0, 0, :], [0, 1, 2]) | |||
| assert np.allclose(outputs[1].asnumpy()[0, 0, :], [3, 4, 5]) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_out2_axis1neg(): | |||
| op = P.Split(-1, 2) | |||
| op_wrapper = OpNetWrapper(op) | |||
| input_x = Tensor(np.arange(24).astype(np.float32).reshape((2, 2, 6))) | |||
| outputs = op_wrapper(input_x) | |||
| print(outputs) | |||
| assert np.allclose(outputs[0].asnumpy()[0, :, :], [[0., 1., 2.], [6., 7., 8.]]) | |||
| assert np.allclose(outputs[1].asnumpy()[0, :, :], [[3., 4., 5.], [9., 10., 11.]]) | |||
| if __name__ == '__main__': | |||
| test_out1_axis0() | |||
| test_out2_axis2() | |||
| test_out2_axis1neg() | |||