diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc index ed24d7fc41..8787ece01c 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc @@ -185,6 +185,15 @@ void ArithmeticCPUKernel::GreaterEqual(const T *input1, const T *input2, bool *o } } +template +void ArithmeticCPUKernel::LessEqual(const T *input1, const T *input2, bool *out, size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + std::vector idx; + GenIndex(i, &idx); + out[i] = input1[idx[0]] <= input2[idx[1]]; + } +} + void ArithmeticCPUKernel::InitKernel(const CNodePtr &kernel_node) { MS_EXCEPTION_IF_NULL(kernel_node); std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); @@ -212,6 +221,8 @@ void ArithmeticCPUKernel::InitKernel(const CNodePtr &kernel_node) { operate_type_ = GREATER; } else if (kernel_name == prim::kPrimGreaterEqual->name()) { operate_type_ = GREATEREQUAL; + } else if (kernel_name == prim::kPrimLessEqual->name()) { + operate_type_ = LESSEQUAL; } else if (kernel_name == prim::kPrimAssignAdd->name()) { operate_type_ = ASSIGNADD; } else if (kernel_name == prim::kPrimSquaredDifference->name()) { @@ -328,6 +339,8 @@ void ArithmeticCPUKernel::LaunchKernelLogic(const std::vector &input } else if (operate_type_ == GREATEREQUAL) { threads.emplace_back( std::thread(&ArithmeticCPUKernel::GreaterEqual, this, input1, input2, output, start, end)); + } else if (operate_type_ == LESSEQUAL) { + threads.emplace_back(std::thread(&ArithmeticCPUKernel::LessEqual, this, input1, input2, output, start, end)); } else { MS_LOG(EXCEPTION) << "Not support " << operate_type_; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h index ae63e43857..aeee5eda18 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h @@ -67,6 +67,8 @@ class ArithmeticCPUKernel : public CPUKernel { void Greater(const T *input1, const T *input2, bool *out, size_t start, size_t end); template void GreaterEqual(const T *input1, const T *input2, bool *out, size_t start, size_t end); + template + void LessEqual(const T *input1, const T *input2, bool *out, size_t start, size_t end); std::vector input_shape0_; std::vector input_shape1_; std::vector input_element_num0_; @@ -239,6 +241,41 @@ MS_REG_CPU_KERNEL( GreaterEqual, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool), ArithmeticCPUKernel); +MS_REG_CPU_KERNEL( + LessEqual, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), + ArithmeticCPUKernel); +MS_REG_CPU_KERNEL( + LessEqual, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool), + ArithmeticCPUKernel); +MS_REG_CPU_KERNEL( + LessEqual, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool), + ArithmeticCPUKernel); +MS_REG_CPU_KERNEL( + LessEqual, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool), + ArithmeticCPUKernel); +MS_REG_CPU_KERNEL( + LessEqual, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool), + ArithmeticCPUKernel); +MS_REG_CPU_KERNEL( + LessEqual, + KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeBool), + ArithmeticCPUKernel); +MS_REG_CPU_KERNEL( + LessEqual, + KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeBool), + ArithmeticCPUKernel); +MS_REG_CPU_KERNEL( + LessEqual, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), + ArithmeticCPUKernel); +MS_REG_CPU_KERNEL( + LessEqual, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), + ArithmeticCPUKernel); +MS_REG_CPU_KERNEL( + LessEqual, + KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool), + ArithmeticCPUKernel); } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc index 049c7d823b..883cc7ff22 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc @@ -69,6 +69,13 @@ void Floor(const T *in, T *out, size_t start, size_t end) { out[i] = static_cast(floor(in[i])); } } + +template +void Reciprocal(const T *in, T *out, size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + out[i] = static_cast(1.0 / in[i]); + } +} } // namespace void ArithmeticSelfCPUKernel::InitKernel(const CNodePtr &kernel_node) { @@ -86,6 +93,8 @@ void ArithmeticSelfCPUKernel::InitKernel(const CNodePtr &kernel_node) { operate_type_ = SIGN; } else if (kernel_name == prim::kPrimFloor->name()) { operate_type_ = FLOOR; + } else if (kernel_name == prim::kPrimReciprocal->name()) { + operate_type_ = RECIPROCAL; } dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); } @@ -139,6 +148,8 @@ void ArithmeticSelfCPUKernel::LaunchKernel(const std::vector &inputs threads.emplace_back(std::thread(Sign, input, output, start, end)); } else if (operate_type_ == FLOOR) { threads.emplace_back(std::thread(Floor, input, output, start, end)); + } else if (operate_type_ == RECIPROCAL) { + threads.emplace_back(std::thread(Reciprocal, input, output, start, end)); } start += once_compute_size; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h index 9b4b3f36c1..51f88f4103 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h @@ -60,6 +60,8 @@ MS_REG_CPU_KERNEL(Sign, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAtt ArithmeticSelfCPUKernel); MS_REG_CPU_KERNEL(Floor, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), ArithmeticSelfCPUKernel); +MS_REG_CPU_KERNEL(Reciprocal, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ArithmeticSelfCPUKernel); } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h index 359840be4a..b097e6f40c 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h @@ -48,6 +48,7 @@ const char IS_GRAD[] = "is_grad"; const char TRANSPOSE_NO = 'N'; const char TRANSPOSE_YES = 'T'; const char AXIS[] = "axis"; +const char DIM[] = "dim"; const char BEGIN[] = "begin"; const char END[] = "end"; const char SIZE[] = "size"; @@ -81,10 +82,12 @@ enum OperateType { SIGN, EQUAL, NOTEQUAL, + LESSEQUAL, FLOOR, SQUAREDDIFFERENCE, GREATER, GREATEREQUAL, + RECIPROCAL, }; class CPUKernel : public kernel::KernelMod { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/gather_d_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/gather_d_cpu_kernel.cc new file mode 100644 index 0000000000..1258b24941 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/gather_d_cpu_kernel.cc @@ -0,0 +1,135 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/gather_d_cpu_kernel.h" +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +namespace { +size_t get_element_num(const std::vector &shape) { + size_t size = 1; + for (size_t i = 0; i < shape.size(); i++) { + size *= shape[i]; + } + return size; +} + +template +void CopyTask(size_t cur, std::vector *pos, T *input, I *index, const int &dim, T *output, + const std::vector &output_shape, const std::vector &out_cargo_size, + const std::vector &input_cargo_size, bool reverse) { + for (size_t i = 0; i < output_shape[cur]; ++i) { + (*pos)[cur] = i; + if (cur == output_shape.size() - 1) { + size_t input_offset = 0; + size_t out_offset = 0; + // out offset + for (size_t j = 0; j < output_shape.size(); ++j) { + out_offset += (*pos)[j] * out_cargo_size[j]; + } + // input offset + size_t cur_index = (*pos)[dim]; + (*pos)[dim] = index[out_offset]; + for (size_t j = 0; j < output_shape.size(); ++j) { + input_offset += (*pos)[j] * input_cargo_size[j]; + } + // do copy + if (reverse) { + input[input_offset] = output[out_offset]; + } else { + output[out_offset] = input[input_offset]; + } + (*pos)[dim] = cur_index; + } else { + // CopyTask + CopyTask(cur + 1, pos, input, index, dim, output, output_shape, out_cargo_size, input_cargo_size, reverse); + } + } +} +} // namespace + +template +void GatherDCPUKernel::InitKernel(const CNodePtr &kernel_node) { + input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + index_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 2); + + if (input_shape_.size() != index_shape_.size()) { + MS_LOG(EXCEPTION) << "Invalid shape size, shape size of input: " << input_shape_.size() + << ", and index: " << index_shape_.size() << " should be equal"; + } + output_shape_ = index_shape_; +} + +template +bool GatherDCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + size_t input_size = get_element_num(input_shape_) * sizeof(T); + size_t index_size = get_element_num(index_shape_) * sizeof(I); + size_t dim_size = sizeof(int); + size_t output_size = get_element_num(output_shape_) * sizeof(T); + + if (inputs[0]->size != input_size || inputs[1]->size != dim_size || inputs[2]->size != index_size || + outputs[0]->size != output_size) { + MS_LOG(EXCEPTION) << "invalid input or output data size!"; + return false; + } + auto input = reinterpret_cast(inputs[0]->addr); + auto dim = reinterpret_cast(inputs[1]->addr); + auto index = reinterpret_cast(inputs[2]->addr); + auto output = reinterpret_cast(outputs[0]->addr); + int32_t input_rank = SizeToInt(input_shape_.size()); + + if (dim[0] >= input_rank || dim[0] < -input_rank) { + MS_LOG(EXCEPTION) << "The value of 'dim' should be in [" << -input_rank << ", " << input_rank + << "], but got: " << dim[0]; + return false; + } + if (dim[0] < 0) { + dim[0] = static_cast(dim[0] + input_rank); + } + // check index + int max_index = SizeToInt(input_shape_[dim[0]]); + index_size = get_element_num(index_shape_); + for (size_t i = 0; i < index_size; ++i) { + if (index[i] >= max_index || index[i] < -max_index) { + MS_LOG(EXCEPTION) << "The value of index should be in [" << -max_index << ", " << max_index + << "], but got: " << index[i]; + return false; + } + if (index[i] < 0) { + index[i] = max_index + index[i]; + } + } + + // out_cargo_size + std::vector out_cargo_size = std::vector(output_shape_.size(), 1); + for (int i = out_cargo_size.size() - 2; i >= 0; --i) { + out_cargo_size[i] = output_shape_[i + 1] * out_cargo_size[i + 1]; + } + // input_cargo_size + std::vector input_cargo_size = std::vector(input_shape_.size(), 1); + for (int i = input_cargo_size.size() - 2; i >= 0; --i) { + input_cargo_size[i] = input_shape_[i + 1] * input_cargo_size[i + 1]; + } + // copy task + std::vector pos(index_shape_.size(), 0); + int copy_dim = *dim; + CopyTask(0, &pos, input, index, copy_dim, output, output_shape_, out_cargo_size, input_cargo_size, false); + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/gather_d_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/gather_d_cpu_kernel.h new file mode 100644 index 0000000000..521903b26a --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/gather_d_cpu_kernel.h @@ -0,0 +1,116 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GATHERD_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GATHERD_CPU_KERNEL_H_ +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +template +class GatherDCPUKernel : public CPUKernel { + public: + GatherDCPUKernel() = default; + ~GatherDCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + std::vector input_shape_; + std::vector index_shape_; + std::vector output_shape_; + int32_t axis_; +}; + +MS_REG_CPU_KERNEL_T_S(GatherD, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat32), + GatherDCPUKernel, float, int32_t); +MS_REG_CPU_KERNEL_T_S(GatherD, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat32), + GatherDCPUKernel, float, int64_t); +MS_REG_CPU_KERNEL_T_S(GatherD, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat16), + GatherDCPUKernel, float16, int32_t); +MS_REG_CPU_KERNEL_T_S(GatherD, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat16), + GatherDCPUKernel, float16, int64_t); +MS_REG_CPU_KERNEL_T_S(GatherD, + KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32), + GatherDCPUKernel, int32_t, int32_t); +MS_REG_CPU_KERNEL_T_S(GatherD, + KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt32), + GatherDCPUKernel, int32_t, int64_t); +MS_REG_CPU_KERNEL_T_S(GatherD, + KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt64), + GatherDCPUKernel, int64_t, int32_t); +MS_REG_CPU_KERNEL_T_S(GatherD, + KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt64), + GatherDCPUKernel, int64_t, int64_t); +MS_REG_CPU_KERNEL_T_S(GatherD, + KernelAttr() + .AddInputAttr(kNumberTypeBool) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeBool), + GatherDCPUKernel, bool, int32_t); +MS_REG_CPU_KERNEL_T_S(GatherD, + KernelAttr() + .AddInputAttr(kNumberTypeBool) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeBool), + GatherDCPUKernel, bool, int64_t); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GATHERD_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/gather_d_grad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/gather_d_grad_cpu_kernel.cc new file mode 100644 index 0000000000..8a31a5a22e --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/gather_d_grad_cpu_kernel.cc @@ -0,0 +1,131 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/gather_d_grad_cpu_kernel.h" +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +namespace { +size_t get_element_num(const std::vector &shape) { + size_t size = 1; + for (size_t i = 0; i < shape.size(); i++) { + size *= shape[i]; + } + return size; +} + +template +void GatherDGradCopyTask(size_t cur, std::vector *pos, T *input, I *index, const int &dim, T *output, + const std::vector &output_shape, const std::vector &out_cargo_size, + const std::vector &input_cargo_size) { + for (size_t i = 0; i < output_shape[cur]; ++i) { + (*pos)[cur] = i; + if (cur == output_shape.size() - 1) { + size_t input_offset = 0; + size_t out_offset = 0; + // out offset + for (size_t j = 0; j < output_shape.size(); ++j) { + out_offset += (*pos)[j] * out_cargo_size[j]; + } + // input offset + size_t cur_index = (*pos)[dim]; + (*pos)[dim] = index[out_offset]; + for (size_t j = 0; j < output_shape.size(); ++j) { + input_offset += (*pos)[j] * input_cargo_size[j]; + } + // do copy + input[input_offset] += output[out_offset]; + (*pos)[dim] = cur_index; + } else { + // CopyTask + GatherDGradCopyTask(cur + 1, pos, input, index, dim, output, output_shape, out_cargo_size, input_cargo_size); + } + } +} +} // namespace + +template +void GatherDGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { + index_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1); + if (input_shape_ != index_shape_) { + MS_LOG(EXCEPTION) << "Invalid shape size, input and index shape should be equal"; + } + axis_ = AnfAlgo::GetNodeAttr(kernel_node, DIM); + output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); +} + +template +bool GatherDGradCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + size_t input_size = get_element_num(input_shape_) * sizeof(T); + size_t index_size = get_element_num(index_shape_) * sizeof(I); + size_t output_size = get_element_num(output_shape_) * sizeof(T); + if (inputs[0]->size != index_size || inputs[1]->size != input_size || outputs[0]->size != output_size) { + MS_LOG(EXCEPTION) << "invalid input or output data size!"; + return false; + } + + auto index = reinterpret_cast(inputs[0]->addr); + auto input = reinterpret_cast(inputs[1]->addr); + auto out = reinterpret_cast(outputs[0]->addr); + + int output_rank = SizeToInt(output_shape_.size()); + if (axis_ >= output_rank || axis_ < -output_rank) { + MS_LOG(EXCEPTION) << "The value of 'axis_' should be in [" << -output_rank << ", " << output_rank + << "], but got: " << axis_; + return false; + } + + if (axis_ < 0) { + axis_ = axis_ + SizeToInt(output_shape_.size()); + } + + // check index + index_size = get_element_num(index_shape_); + int max_index = SizeToInt(output_shape_[axis_]); + for (size_t i = 0; i < index_size; ++i) { + if (index[i] >= max_index || index[i] < -max_index) { + MS_LOG(EXCEPTION) << "The value of index should be in [" << -max_index << ", " << max_index + << "], but got: " << index[i]; + return false; + } + if (index[i] < 0) { + index[i] = max_index + index[i]; + } + } + auto out_size = get_element_num(output_shape_); + memset_s(out, out_size * sizeof(T), 0x00, out_size * sizeof(T)); + + // out_cargo_size + std::vector out_cargo_size = std::vector(output_shape_.size(), 1); + for (int i = out_cargo_size.size() - 2; i >= 0; --i) { + out_cargo_size[i] = output_shape_[i + 1] * out_cargo_size[i + 1]; + } + // input_cargo_size + std::vector input_cargo_size = std::vector(input_shape_.size(), 1); + for (int i = input_cargo_size.size() - 2; i >= 0; --i) { + input_cargo_size[i] = input_shape_[i + 1] * input_cargo_size[i + 1]; + } + + // copy task + std::vector pos(index_shape_.size(), 0); + GatherDGradCopyTask(0, &pos, out, index, axis_, input, index_shape_, input_cargo_size, out_cargo_size); + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/gather_d_grad_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/gather_d_grad_cpu_kernel.h new file mode 100644 index 0000000000..b86fa93a67 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/gather_d_grad_cpu_kernel.h @@ -0,0 +1,84 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GATHERDGRAD_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GATHERDGRAD_CPU_KERNEL_H_ +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +template +class GatherDGradCPUKernel : public CPUKernel { + public: + GatherDGradCPUKernel() = default; + ~GatherDGradCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + std::vector input_shape_; + std::vector index_shape_; + std::vector output_shape_; + int32_t axis_; +}; + +MS_REG_CPU_KERNEL_T_S( + GatherDGrad, + KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + GatherDGradCPUKernel, int32_t, int32_t); +MS_REG_CPU_KERNEL_T_S( + GatherDGrad, + KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + GatherDGradCPUKernel, int32_t, int64_t); +MS_REG_CPU_KERNEL_T_S( + GatherDGrad, + KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + GatherDGradCPUKernel, int32_t, float); +MS_REG_CPU_KERNEL_T_S( + GatherDGrad, + KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + GatherDGradCPUKernel, int32_t, float16); +MS_REG_CPU_KERNEL_T_S( + GatherDGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), + GatherDGradCPUKernel, int32_t, bool); +MS_REG_CPU_KERNEL_T_S( + GatherDGrad, + KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + GatherDGradCPUKernel, int64_t, int32_t); +MS_REG_CPU_KERNEL_T_S( + GatherDGrad, + KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + GatherDGradCPUKernel, int64_t, int64_t); +MS_REG_CPU_KERNEL_T_S( + GatherDGrad, + KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + GatherDGradCPUKernel, int64_t, float); +MS_REG_CPU_KERNEL_T_S( + GatherDGrad, + KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + GatherDGradCPUKernel, int64_t, float16); +MS_REG_CPU_KERNEL_T_S( + GatherDGrad, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), + GatherDGradCPUKernel, int64_t, bool); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_GATHERDGRAD_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/minimum_grad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/minimum_grad_cpu_kernel.cc new file mode 100644 index 0000000000..ede0ad4e19 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/minimum_grad_cpu_kernel.cc @@ -0,0 +1,149 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/cpu/minimum_grad_cpu_kernel.h" +#include +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +namespace { +void GetCargo(std::vector *cargo, const std::vector &shape, const std::vector &dout_shape) { + int i = dout_shape.size() - 1; + int j = shape.size() - 1; + (*cargo)[i] = 1; + for (--i; j >= 1; --i, --j) { + (*cargo)[i] = shape[j] * (*cargo)[i + 1]; + } + for (; i >= 0; i--) { + (*cargo)[i] = 1; + } +} + +size_t GetTensorLen(const std::vector &shape) { + size_t len = 1; + for (size_t i = 0; i < shape.size(); i++) { + len *= shape[i]; + } + return len; +} + +void GetShape(std::vector *shape, const std::vector &shape_, const std::vector &dout_shape) { + int k = dout_shape.size() - 1; + int i = shape_.size() - 1; + for (; i >= 0; i--, k--) { + (*shape)[k] = shape_[i]; + } +} +} // namespace + +void MinimumGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { + CheckParam(kernel_node); + x_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + y_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + dout_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); + dx_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + dy_shape = AnfAlgo::GetOutputInferShape(kernel_node, 1); + dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); + if (!x_shape_.size() || !y_shape_.size() || !dout_shape.size()) { + MS_LOG(EXCEPTION) << "Input NULL"; + } +} + +bool MinimumGradCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + if (dtype_ == kNumberTypeInt32) { + LaunchKernel(inputs, outputs); + } else if (dtype_ == kNumberTypeUInt32) { + LaunchKernel(inputs, outputs); + } else if (dtype_ == kNumberTypeFloat32) { + LaunchKernel(inputs, outputs); + } else if (dtype_ == kNumberTypeInt64) { + LaunchKernel(inputs, outputs); + } else if (dtype_ == kNumberTypeUInt64) { + LaunchKernel(inputs, outputs); + } else if (dtype_ == kNumberTypeFloat64) { + LaunchKernel(inputs, outputs); + } + return true; +} + +template +void MinimumGradRecTask(T *x, T *y, T *dout, T *dx, T *dy, size_t dim, size_t x_index, size_t y_index, + size_t dout_index, const std::vector &x_cargo, const std::vector &y_cargo, + const std::vector &dout_cargo, const std::vector &x_shape, + const std::vector &y_shape, const std::vector &dout_shape) { + for (size_t i = 0; i < dout_shape[dim]; i++) { + size_t x_i = x_shape[dim] == dout_shape[dim] ? i * x_cargo[dim] : 0; + size_t y_i = y_shape[dim] == dout_shape[dim] ? i * y_cargo[dim] : 0; + size_t dout_i = i * dout_cargo[dim]; + + if (dim == dout_shape.size() - 1) { + if (*(x + x_index + x_i) <= *(y + y_index + y_i)) { + *(dx + x_index + x_i) += *(dout + dout_index + i); + } else { + *(dy + y_index + y_i) += *(dout + dout_index + i); + } + } else { + MinimumGradRecTask(x, y, dout, dx, dy, dim + 1, x_index + x_i, y_index + y_i, dout_index + dout_i, x_cargo, + y_cargo, dout_cargo, x_shape, y_shape, dout_shape); + } + } +} + +template +void MinimumGradCPUKernel::LaunchKernel(const std::vector &inputs, const std::vector &outputs) { + auto x_addr = reinterpret_cast(inputs[0]->addr); + auto y_addr = reinterpret_cast(inputs[1]->addr); + auto dout_addr = reinterpret_cast(inputs[2]->addr); + auto dx_addr = reinterpret_cast(outputs[0]->addr); + auto dy_addr = reinterpret_cast(outputs[1]->addr); + + size_t x_tensor_len = GetTensorLen(x_shape_); + size_t y_tensor_len = GetTensorLen(y_shape_); + memset(dx_addr, 0, x_tensor_len * sizeof(T)); + memset(dy_addr, 0, y_tensor_len * sizeof(T)); + + std::vector x_shape(dout_shape.size(), 1); + std::vector y_shape(dout_shape.size(), 1); + std::vector x_cargo(dout_shape.size(), 0); + std::vector y_cargo(dout_shape.size(), 0); + std::vector dout_cargo(dout_shape.size(), 0); + + GetShape(&x_shape, x_shape_, dout_shape); + GetShape(&y_shape, y_shape_, dout_shape); + + GetCargo(&x_cargo, x_shape, dout_shape); + GetCargo(&y_cargo, y_shape, dout_shape); + GetCargo(&dout_cargo, dout_shape, dout_shape); + + MinimumGradRecTask(x_addr, y_addr, dout_addr, dx_addr, dy_addr, 0, 0, 0, 0, x_cargo, y_cargo, dout_cargo, x_shape, + y_shape, dout_shape); +} + +void MinimumGradCPUKernel::CheckParam(const CNodePtr &kernel_node) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 3) { + MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but MinimumGradCPUKernel needs 3 input."; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 2) { + MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but MinimumGradCPUKernel needs 2 output."; + } +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/minimum_grad_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/minimum_grad_cpu_kernel.h new file mode 100644 index 0000000000..fab616497f --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/minimum_grad_cpu_kernel.h @@ -0,0 +1,105 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MINIMUMGRAD_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MINIMUMGRAD_CPU_KERNEL_H_ +#include +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class MinimumGradCPUKernel : public CPUKernel { + public: + MinimumGradCPUKernel() = default; + ~MinimumGradCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + template + void LaunchKernel(const std::vector &inputs, const std::vector &outputs); + + private: + void CheckParam(const CNodePtr &kernel_node); + std::vector x_shape_; + std::vector y_shape_; + std::vector dout_shape; + std::vector dx_shape; + std::vector dy_shape; + TypeId dtype_{kTypeUnknown}; +}; + +MS_REG_CPU_KERNEL(MinimumGrad, + KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32), + MinimumGradCPUKernel); + +MS_REG_CPU_KERNEL(MinimumGrad, + KernelAttr() + .AddInputAttr(kNumberTypeUInt32) + .AddInputAttr(kNumberTypeUInt32) + .AddInputAttr(kNumberTypeUInt32) + .AddOutputAttr(kNumberTypeUInt32) + .AddOutputAttr(kNumberTypeUInt32), + MinimumGradCPUKernel); + +MS_REG_CPU_KERNEL(MinimumGrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + MinimumGradCPUKernel); + +MS_REG_CPU_KERNEL(MinimumGrad, + KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt64), + MinimumGradCPUKernel); + +MS_REG_CPU_KERNEL(MinimumGrad, + KernelAttr() + .AddInputAttr(kNumberTypeUInt64) + .AddInputAttr(kNumberTypeUInt64) + .AddInputAttr(kNumberTypeUInt64) + .AddOutputAttr(kNumberTypeUInt64) + .AddOutputAttr(kNumberTypeUInt64), + MinimumGradCPUKernel); + +MS_REG_CPU_KERNEL(MinimumGrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeFloat64), + MinimumGradCPUKernel); +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MinimumGrad_CPU_KERNEL_H_ diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 6edf59067c..2e259fbc38 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -4648,7 +4648,7 @@ class GatherD(PrimitiveWithInfer): Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`. Supported Platforms: - ``Ascend`` ``GPU`` + ``Ascend`` ``GPU`` ``CPU`` Examples: >>> x = Tensor(np.array([[1, 2], [3, 4]]), mindspore.int32) @@ -4680,7 +4680,7 @@ class GatherD(PrimitiveWithInfer): if dim_v < 0: dim['value'] = dim_v + x_rank for i in range(x_rank): - if i == dim_v: + if i == dim['value']: continue validator.check("x_shp[{0}], idx_shp[{0}]".format(i), x_shp[i], "expected", idx_shp[i], Rel.EQ, self.name) diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index 88aee88cc5..60f09ed963 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ -1459,7 +1459,7 @@ class Reciprocal(PrimitiveWithInfer): Tensor, has the same shape as the `input_x`. Supported Platforms: - ``Ascend`` ``GPU`` + ``Ascend`` ``GPU`` ``CPU`` Examples: >>> input_x = Tensor(np.array([1.0, 2.0, 4.0]), mindspore.float32) diff --git a/tests/st/ops/cpu/test_arithmetic_self_op.py b/tests/st/ops/cpu/test_arithmetic_self_op.py index 508f707859..02e520839e 100644 --- a/tests/st/ops/cpu/test_arithmetic_self_op.py +++ b/tests/st/ops/cpu/test_arithmetic_self_op.py @@ -41,6 +41,15 @@ class FloorNet(nn.Cell): return self.floor(x) +class ReciprocalNet(nn.Cell): + def __init__(self): + super(ReciprocalNet, self).__init__() + self.reciprocal = P.Reciprocal() + + def construct(self, x): + return self.reciprocal(x) + + @pytest.mark.level0 @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard @@ -108,5 +117,27 @@ def test_floor(): print(output.asnumpy()) assert np.all(output.asnumpy() == expect_output) + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_reciprocal(): + net = ReciprocalNet() + prop = 100 if np.random.random() > 0.5 else -100 + x = np.random.randn(3, 4, 5, 6).astype(np.float16) * prop + output = net(Tensor(x)) + expect_output = (1. / x).astype(np.float16) + diff = output.asnumpy() - expect_output + error = np.ones(shape=expect_output.shape) * 1.0e-5 + assert np.all(np.abs(diff) < error) + + x = np.random.randn(3, 4, 5, 6).astype(np.float32) * prop + output = net(Tensor(x)) + expect_output = (1. / x).astype(np.float32) + diff = output.asnumpy() - expect_output + error = np.ones(shape=expect_output.shape) * 1.0e-5 + assert np.all(np.abs(diff) < error) + test_square() test_floor() +test_reciprocal() diff --git a/tests/st/ops/cpu/test_gather_d_grad_op.py b/tests/st/ops/cpu/test_gather_d_grad_op.py new file mode 100644 index 0000000000..3260ad5da1 --- /dev/null +++ b/tests/st/ops/cpu/test_gather_d_grad_op.py @@ -0,0 +1,121 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P +from mindspore.common.api import ms_function +from mindspore.ops.composite import GradOperation + +context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + +class NetGatherD(nn.Cell): + def __init__(self, dim=1): + super(NetGatherD, self).__init__() + self.gatherd = P.GatherD() + self.dim = int(dim) + + def construct(self, x, index): + return self.gatherd(x, self.dim, index) + +class NetGatherDGrad(nn.Cell): + def __init__(self, network): + super(NetGatherDGrad, self).__init__() + self.grad = GradOperation(get_all=True, sens_param=True) + self.network = network + + @ms_function + def construct(self, inputx, index, output_grad): + return self.grad(self.network)(inputx, index, output_grad) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_gatherd_grad_fp32(): + prop = 100 if np.random.random() > 0.5 else -100 + x = np.random.randn(5, 5, 5).astype(np.float32) * prop + index = np.random.randint(0, 5, (5, 3, 5)).astype(np.int32) + dim = 1 + + gatherd = NetGatherD(dim) + grad = NetGatherDGrad(gatherd) + dout = np.random.randint(0, 5, index.shape).astype(np.float32) * prop + output_grad = grad(Tensor(x), Tensor(index), Tensor(dout)) + if isinstance(output_grad, (tuple, list)): + output_grad = output_grad[0] + print(output_grad.asnumpy()) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_gatherd_grad_fp16(): + prop = 100 if np.random.random() > 0.5 else -100 + x = np.random.randn(5, 5, 5).astype(np.float16) * prop + index = np.random.randint(0, 5, (3, 5, 5)).astype(np.int32) + dim = 0 + + gatherd = NetGatherD(dim) + grad = NetGatherDGrad(gatherd) + dout = np.random.randint(0, 5, index.shape).astype(np.float16) * prop + output_grad = grad(Tensor(x), Tensor(index), Tensor(dout)) + if isinstance(output_grad, (tuple, list)): + output_grad = output_grad[0] + print(output_grad.asnumpy()) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_gatherd_grad_int32(): + prop = 100 if np.random.random() > 0.5 else -100 + x = np.random.randn(5, 5, 5).astype(np.int32) * prop + index = np.random.randint(0, 5, (5, 5, 7)).astype(np.int64) + dim = -1 + + gatherd = NetGatherD(dim) + grad = NetGatherDGrad(gatherd) + dout = np.random.randint(0, 5, index.shape).astype(np.int32) * prop + output_grad = grad(Tensor(x), Tensor(index), Tensor(dout)) + if isinstance(output_grad, (tuple, list)): + output_grad = output_grad[0] + print(output_grad.asnumpy()) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_gatherd_grad_checkresult(): + x = np.array([[[-146.76097, 119.84371], [91.22607, -166.12923]], + [[37.67479, -8.696029], [43.804962, -23.369316]]], np.float32) + index = np.array([[[0, 1], [0, 0]], [[0, 0], [0, 1]]], np.int32) + dim = 1 + + gatherd = NetGatherD(dim) + grad = NetGatherDGrad(gatherd) + dout = np.array([[[-1.23, 119.84], [91.22607, -145.67]], [[37.67479, -8.696029], [100.89, -23.369316]]], np.float32) + output = grad(Tensor(x), Tensor(index), Tensor(dout)) + + if isinstance(output, (tuple, list)): + output = output[0] + expect = np.array([[[89.99606, -145.67], [0., 119.84]], [[138.56479, -8.696029], [0., -23.369316]]], np.float32) + error = np.ones(shape=expect.shape) * 1.0e-6 + assert np.all(np.abs(output.asnumpy() - expect) < error) diff --git a/tests/st/ops/cpu/test_gather_d_op.py b/tests/st/ops/cpu/test_gather_d_op.py new file mode 100644 index 0000000000..a889f9d7dc --- /dev/null +++ b/tests/st/ops/cpu/test_gather_d_op.py @@ -0,0 +1,117 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + +class NetGatherD(nn.Cell): + def __init__(self, dim=1): + super(NetGatherD, self).__init__() + self.gatherd = P.GatherD() + self.dim = int(dim) + + def construct(self, x, index): + return self.gatherd(x, self.dim, index) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_gatherd_fp32(): + prop = 100 if np.random.random() > 0.5 else -100 + x = np.random.randn(5, 5, 5).astype(np.float32) * prop + index = np.random.randint(0, 5, (5, 3, 5)).astype(np.int32) + dim = 1 + + gatherd = NetGatherD(dim) + output = gatherd(Tensor(x), Tensor(index)) + + expect = np.zeros(index.shape).astype(np.float32) + for i in range(index.shape[0]): + for j in range(index.shape[1]): + for k in range(index.shape[2]): + expect[i, j, k] = x[i, index[i, j, k], k] + error = np.ones(shape=expect.shape) * 1.0e-6 + assert np.all(np.abs(output.asnumpy() - expect) < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_gatherd_fp16(): + prop = 100 if np.random.random() > 0.5 else -100 + x = np.random.randn(5, 5, 5).astype(np.float16) * prop + index = np.random.randint(0, 5, (3, 5, 5)).astype(np.int64) + dim = 0 + + gatherd = NetGatherD(dim) + output = gatherd(Tensor(x), Tensor(index)) + + expect = np.zeros(index.shape).astype(np.float16) + for i in range(index.shape[0]): + for j in range(index.shape[1]): + for k in range(index.shape[2]): + expect[i, j, k] = x[index[i, j, k], j, k] + error = np.ones(shape=expect.shape) * 1.0e-6 + assert np.all(np.abs(output.asnumpy() - expect) < error) + + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_gatherd_int32(): + prop = 100 if np.random.random() > 0.5 else -100 + x = np.random.randn(5, 5, 5).astype(np.int32) * prop + index = np.random.randint(0, 5, (5, 5, 8)).astype(np.int32) + dim = -1 + + gatherd = NetGatherD(dim) + output = gatherd(Tensor(x), Tensor(index)) + + expect = np.zeros(index.shape).astype(np.int32) + for i in range(index.shape[0]): + for j in range(index.shape[1]): + for k in range(index.shape[2]): + expect[i, j, k] = x[i, j, index[i, j, k]] + assert np.all(output.asnumpy() == expect) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_gatherd_bool(): + prop = 100 if np.random.random() > 0.5 else -100 + x = np.random.randn(5, 5, 5).astype(np.int32) * prop + x = (x >= 0).astype(np.bool) + index = np.random.randint(0, 5, (5, 5, 8)).astype(np.int32) + dim = -1 + + gatherd = NetGatherD(dim) + output = gatherd(Tensor(x), Tensor(index)) + + expect = np.zeros(index.shape).astype(np.bool) + for i in range(index.shape[0]): + for j in range(index.shape[1]): + for k in range(index.shape[2]): + expect[i, j, k] = x[i, j, index[i, j, k]] + assert np.all(output.asnumpy() == expect) diff --git a/tests/st/ops/cpu/test_less_equal_op.py b/tests/st/ops/cpu/test_less_equal_op.py new file mode 100644 index 0000000000..17217f8b7a --- /dev/null +++ b/tests/st/ops/cpu/test_less_equal_op.py @@ -0,0 +1,83 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import pytest +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P + + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.ops = P.LessEqual() + + def construct(self, x, y): + return self.ops(x, y) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu_training +@pytest.mark.env_onecard +def test_net(): + x0_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float32) + y0_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float32) + x1_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float32) + y1_np = np.random.randint(1, 5, (2, 1, 4, 4)).astype(np.float32) + x2_np = np.random.randint(1, 5, (2, 1, 1, 4)).astype(np.float32) + y2_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float32) + x3_np = np.random.randint(1, 5, 1).astype(np.float32) + y3_np = np.random.randint(1, 5, 1).astype(np.float32) + x4_np = np.array(768).astype(np.float32) + y4_np = np.array(3072.5).astype(np.float32) + + x0 = Tensor(x0_np) + y0 = Tensor(y0_np) + x1 = Tensor(x1_np) + y1 = Tensor(y1_np) + x2 = Tensor(x2_np) + y2 = Tensor(y2_np) + x3 = Tensor(x3_np) + y3 = Tensor(y3_np) + x4 = Tensor(x4_np) + y4 = Tensor(y4_np) + + context.set_context(mode=context.GRAPH_MODE, device_target='CPU') + net = Net() + out = net(x0, y0).asnumpy() + expect = x0_np <= y0_np + assert np.all(out == expect) + assert out.shape == expect.shape + + out = net(x1, y1).asnumpy() + expect = x1_np <= y1_np + assert np.all(out == expect) + assert out.shape == expect.shape + + out = net(x2, y2).asnumpy() + expect = x2_np <= y2_np + assert np.all(out == expect) + assert out.shape == expect.shape + + out = net(x3, y3).asnumpy() + expect = x3_np <= y3_np + assert np.all(out == expect) + assert out.shape == expect.shape + + out = net(x4, y4).asnumpy() + expect = x4_np <= y4_np + assert np.all(out == expect) + assert out.shape == expect.shape diff --git a/tests/st/ops/cpu/test_minimum_grad_op.py b/tests/st/ops/cpu/test_minimum_grad_op.py new file mode 100644 index 0000000000..0f0d7eb096 --- /dev/null +++ b/tests/st/ops/cpu/test_minimum_grad_op.py @@ -0,0 +1,171 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import pytest + +import mindspore.context as context +from mindspore import Tensor +from mindspore.nn import Cell +from mindspore.ops import composite as C +from mindspore.ops.operations import Minimum + +context.set_context(mode=context.GRAPH_MODE, device_target="CPU") +grad = C.GradOperation(get_all=True, sens_param=True) + + +class MinNetMe(Cell): + def __init__(self): + super(MinNetMe, self).__init__() + self.min = Minimum() + + def construct(self, inputA, inputB): + x = self.min(inputA, inputB) + return x + + +class GradWrap(Cell): + def __init__(self, network): + super(GradWrap, self).__init__() + self.network = network + + def construct(self, inputA, inputB, sens): + gout = grad(self.network)(inputA, inputB, sens) + return gout + + +def gen_data(inputA_np, inputB_np, grad_=None): + inputA_me = inputA_np + if isinstance(inputA_np, np.ndarray): + inputA_me = Tensor(inputA_me) + + inputB_me = inputB_np + if isinstance(inputB_np, np.ndarray): + inputB_me = Tensor(inputB_np) + + if grad_ is None: + grad_ = Tensor(grad_) + + net_me = GradWrap(MinNetMe()) + net_me.set_train() + output = net_me(inputA_me, inputB_me, Tensor(grad_)) + return output + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu_training +@pytest.mark.env_onecard +def test_min_tensor_grad_4d(): + inputA_np = np.random.randn(1, 3, 2, 2).astype(np.float32) + inputB_np = np.random.randn(1, 3, 2, 2).astype(np.float32) + grad_ = np.random.randn(1, 3, 2, 2).astype(np.float32) + output = gen_data(inputA_np, inputB_np, grad_) + print(output[0].asnumpy()) + print(output[1].asnumpy()) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu_training +@pytest.mark.env_onecard +def test_min_tensor_grad_result(): + inputA = np.array([[[[0.659578], [0.49113268], [0.75909054], [0.71681815], [0.30421826]]], + [[[0.30322495], [0.02858258], [0.06398096], [0.09519596], [0.12498625]]], + [[[0.7347768], [0.166469], [0.328553], [0.54908437], [0.23673844]]]]).astype(np.float32) + inputB = np.array([[[[0.9154968, 0.29014662, 0.6492294, 0.39918253, 0.1648203, 0.00861965]], + [[0.996885, 0.24152198, 0.3601213, 0.51664376, 0.7933056, 0.84706444]], + [[0.75606346, 0.974512, 0.3939527, 0.69697475, 0.83400667, 0.6348955]], + [[0.68492866, 0.24609096, 0.4924665, 0.22500521, 0.38474053, 0.5586104]]]]).astype(np.float32) + grad_ = np.array([[[[0.42891738, 0.03434946, 0.06192983, 0.21216309, 0.37450036, 0.6619524], + [0.8583447, 0.5765161, 0.1468952, 0.9975385, 0.6908136, 0.4903796], + [0.68952006, 0.39336833, 0.9049695, 0.66886294, 0.2338471, 0.913618], + [0.0428149, 0.6243054, 0.8519898, 0.12088962, 0.9735885, 0.45661286], + [0.41563734, 0.41607043, 0.4754915, 0.32207987, 0.33823156, 0.47422352]], + + [[0.64478457, 0.22430937, 0.7682554, 0.46082005, 0.8938723, 0.20490853], + [0.44393885, 0.08278944, 0.4734108, 0.5543551, 0.39428464, 0.44424313], + [0.12612297, 0.76566416, 0.71133816, 0.81280327, 0.20583127, 0.54058075], + [0.41341263, 0.48118508, 0.00401995, 0.37259838, 0.05435474, 0.5240658], + [0.4081956, 0.48718935, 0.9132831, 0.67969185, 0.0119757, 0.8328054]], + + [[0.91695577, 0.95370644, 0.263782, 0.7477626, 0.6448147, 0.8080634], + [0.15576603, 0.9104615, 0.3778708, 0.6912833, 0.2092224, 0.67462957], + [0.7087075, 0.7888326, 0.4672294, 0.98221505, 0.25210258, 0.98920417], + [0.7466197, 0.22702982, 0.01991269, 0.6846591, 0.7515228, 0.5890395], + [0.04531088, 0.21740614, 0.8406235, 0.36480767, 0.37733936, 0.02914464]], + + [[0.33069974, 0.5497569, 0.9896345, 0.4167176, 0.78057563, 0.04659131], + [0.7747768, 0.21427679, 0.29893255, 0.7706969, 0.9755185, 0.42388415], + [0.3910244, 0.39381978, 0.37065396, 0.15558061, 0.05012341, 0.15870963], + [0.17791101, 0.47219893, 0.13899496, 0.32323205, 0.3628809, 0.02580585], + [0.30274773, 0.62890774, 0.11024303, 0.6980051, 0.35346958, 0.062852]]], + + [[[0.6925081, 0.74668753, 0.80145043, 0.06598313, 0.665123, 0.15073007], + [0.11784806, 0.6385372, 0.5228278, 0.5349848, 0.84671104, 0.8096436], + [0.09516156, 0.63298017, 0.52382874, 0.36734378, 0.66497755, 0.6019127], + [0.46438488, 0.0194377, 0.9388292, 0.7286089, 0.29178405, 0.11872514], + [0.22101837, 0.6164887, 0.6139798, 0.11711904, 0.6227745, 0.09701069]], + + [[0.80480653, 0.90034056, 0.8633447, 0.97415197, 0.08309154, 0.8446033], + [0.9473769, 0.791024, 0.26339203, 0.01155075, 0.2673186, 0.7116369], + [0.9687511, 0.24281934, 0.37777108, 0.09802654, 0.2421312, 0.87095344], + [0.6311381, 0.23368953, 0.0998995, 0.4364419, 0.9187446, 0.5043872], + [0.35226053, 0.09357589, 0.41317305, 0.85930043, 0.16249318, 0.5478765]], + + [[0.14338651, 0.24859418, 0.4246941, 0.73034066, 0.47172204, 0.8717199], + [0.05415315, 0.78556925, 0.99214983, 0.7415298, 0.673708, 0.87817156], + [0.616975, 0.42843062, 0.05179814, 0.1566958, 0.04536059, 0.70166487], + [0.15493333, 0.776598, 0.4361967, 0.40253627, 0.89210516, 0.8144414], + [0.04816005, 0.29696834, 0.4586605, 0.3419852, 0.5595613, 0.74093205]], + + [[0.1388035, 0.9168704, 0.64287645, 0.83864623, 0.48026922, 0.78323376], + [0.12724937, 0.83034366, 0.42557436, 0.50578654, 0.25630295, 0.15349793], + [0.27256685, 0.04547984, 0.5385756, 0.39270344, 0.7661698, 0.23722854], + [0.24620503, 0.25431684, 0.71564585, 0.01161419, 0.846467, 0.7043044], + [0.63272387, 0.11857849, 0.3772076, 0.16758402, 0.46743023, 0.05919575]]], + + [[[0.18827082, 0.8912264, 0.6841404, 0.74436826, 0.9582085, 0.1083683], + [0.60695344, 0.09742349, 0.25074378, 0.87940735, 0.21116392, 0.39418384], + [0.744686, 0.35679692, 0.01308284, 0.45166633, 0.68166, 0.8634658], + [0.7331758, 0.21113694, 0.3935488, 0.87934476, 0.70728546, 0.09309767], + [0.12128611, 0.93696386, 0.81177396, 0.85402405, 0.5827289, 0.9776509]], + + [[0.54069614, 0.66651285, 0.10646132, 0.17342485, 0.88795924, 0.03551182], + [0.25531697, 0.87946486, 0.74267226, 0.89230734, 0.95171434, 0.94697934], + [0.3708397, 0.507355, 0.97099817, 0.4918163, 0.17212386, 0.5008048], + [0.62530744, 0.25210327, 0.73966664, 0.71555346, 0.82484317, 0.6094874], + [0.4589691, 0.1386695, 0.27448782, 0.20373994, 0.27805242, 0.23292768]], + + [[0.7414099, 0.2270226, 0.90431255, 0.47035843, 0.9581062, 0.5359226], + [0.79603523, 0.45549425, 0.80858237, 0.7705133, 0.017761, 0.98001194], + [0.06013146, 0.99240226, 0.33515573, 0.04110833, 0.41470334, 0.7130743], + [0.5687417, 0.5788611, 0.00722461, 0.6603336, 0.3420471, 0.75181854], + [0.4699261, 0.51390815, 0.343182, 0.81498754, 0.8942413, 0.46532857]], + + [[0.4589523, 0.5534698, 0.2825786, 0.8205943, 0.78258514, 0.43154418], + [0.27020997, 0.01667354, 0.60871965, 0.90670526, 0.3208025, 0.96995634], + [0.85337156, 0.9711295, 0.1381724, 0.53670496, 0.7347996, 0.73380876], + [0.6137464, 0.54751194, 0.9037335, 0.23134394, 0.61411524, 0.26583543], + [0.70770144, 0.01813207, 0.24718016, 0.70329237, 0.7062925, 0.14399007]]]]).astype(np.float32) + output = gen_data(inputA, inputB, grad_) + expect0 = np.array([[[[5.7664223], [6.9810176], [2.6029902], [2.7598205], [6.763105]]], + [[[10.065580], [12.077245], [9.3383940], [11.522709], [8.889048]]], + [[[3.5789766], [13.424448], [8.7327460], [6.9677467], [9.635764]]]], np.float32) + expect1 = np.array([[[[0., 4.2504573, 2.5030296, 3.623167, 6.417151, 7.2115746]], + [[0., 4.3674493, 2.8031523, 2.5352, 0., 0.]], + [[0.7087075, 0., 2.040332, 2.1372325, 0., 2.9222295]], + [[1.0278877, 5.247942, 2.6855955, 5.494814, 3.565799, 0.66265094]]]], np.float32) + error0 = np.ones(shape=expect0.shape) * 1.0e-5 + error1 = np.ones(shape=expect1.shape) * 1.0e-5 + assert np.all(np.abs(output[0].asnumpy() - expect0) < error0) + assert np.all(np.abs(output[1].asnumpy() - expect1) < error1)