From 35ff7c58f2f7c68cd2142aef69ab7e89f3f7924e Mon Sep 17 00:00:00 2001 From: zhuyuxiao Date: Wed, 23 Dec 2020 10:37:58 +0800 Subject: [PATCH] add pack for cpu kernel; bugfix adagrad gpu --- .../kernel_compiler/cpu/pack_cpu_kernel.cc | 111 ++++++++++++++++++ .../kernel_compiler/cpu/pack_cpu_kernel.h | 82 +++++++++++++ .../gpu/nn/adagrad_gpu_kernel.h | 1 + mindspore/ops/operations/array_ops.py | 4 +- tests/st/ops/cpu/test_pack_op.py | 100 ++++++++++++++++ 5 files changed, 296 insertions(+), 2 deletions(-) create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/pack_cpu_kernel.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/pack_cpu_kernel.h create mode 100644 tests/st/ops/cpu/test_pack_op.py diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/pack_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/pack_cpu_kernel.cc new file mode 100644 index 0000000000..9431f88578 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/pack_cpu_kernel.cc @@ -0,0 +1,111 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+// NOTE(review): in this copy of the patch the header names after the two bare `#include` directives and
+// every template argument list (`template`, `static_cast(...)`, `std::vector &`, `std::make_unique(...)`)
+// appear to be missing, presumably stripped as markup. Restore them from the original commit before applying.
+#include "backend/kernel_compiler/cpu/pack_cpu_kernel.h"
+#include
+#include
+
+namespace mindspore {
+namespace kernel {
+// Starts with neutral values; the real sizes are computed in InitKernel.
+template
+PackCpuFwdKernel::PackCpuFwdKernel()
+    : axis_(0), input_num_(1), output_size_(0), dims_behind_axis_(1), inputs_host_(nullptr) {}
+
+// Reads the AXIS attribute and the input count from kernel_node, then precomputes
+// dims_behind_axis_ and output_size_, which Launch/PackTensor use for index arithmetic.
+template
+void PackCpuFwdKernel::InitKernel(const CNodePtr &kernel_node) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
+
+  axis_ = AnfAlgo::GetNodeAttr(kernel_node, AXIS);
+  input_num_ = AnfAlgo::GetInputTensorNum(kernel_node);
+
+  // The shape of input 0 drives both the axis normalization and the stride computation,
+  // so it is queried once.
+  auto first_input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
+  if (axis_ < 0) {
+    // A negative axis is shifted by (input rank + 1).
+    axis_ += (SizeToInt(first_input_shape.size()) + 1);
+  }
+
+  // Number of elements of one input at dimensions >= axis. Reset first (like output_size_ below)
+  // so a repeated InitKernel call does not multiply into the value left by a previous call.
+  dims_behind_axis_ = 1;
+  for (size_t i = IntToSize(axis_); i < first_input_shape.size(); i++) {
+    dims_behind_axis_ *= first_input_shape[i];
+  }
+
+  // Total number of output elements.
+  auto output_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
+  output_size_ = 1;
+  for (size_t i = 0; i < output_shape.size(); i++) {
+    output_size_ *= output_shape[i];
+  }
+}
+
+// Fills outputs[0] from the input buffers. The output index range [0, output_size_) is split into
+// contiguous batches and each batch is handled by its own std::thread running PackTensor.
+// Returns false when the parameter checks fail, true once every worker has been joined.
+template
+bool PackCpuFwdKernel::Launch(const std::vector &inputs, const std::vector &workspace,
+                              const std::vector &outputs) {
+  if (!CheckParam(outputs)) {
+    return false;
+  }
+  // inputs_host_ is sized by input_num_ and PackTensor indexes it modulo input_num_, so a different
+  // number of runtime inputs would either write past the array or leave entries unset.
+  if (inputs.size() != input_num_) {
+    MS_LOG(EXCEPTION) << "Input number is " << inputs.size() << ", but PackCpuFwdKernel was initialized with "
+                      << input_num_ << " inputs.";
+    return false;
+  }
+  auto output = reinterpret_cast(outputs[0]->addr);
+
+  // Cache the typed base pointer of every input for the workers.
+  inputs_host_ = std::make_unique(input_num_);
+  for (size_t i = 0; i < inputs.size(); i++) {
+    inputs_host_[i] = reinterpret_cast(inputs[i]->addr);
+  }
+
+  // multi-threading: about one thread per 128 output elements, capped at the hardware concurrency
+  size_t input_size = output_size_;
+  size_t max_thread_num = std::max(std::thread::hardware_concurrency(), static_cast(1));
+  size_t use_thread_num =
+    input_size < 128 * max_thread_num ? std::ceil(static_cast(input_size / 128.0)) : max_thread_num;
+  if (use_thread_num < 1) {
+    // Keeps the batch_size division below well defined when there is nothing to copy.
+    use_thread_num = 1;
+  }
+
+  std::vector threads;
+  threads.reserve(use_thread_num);
+  size_t start = 0;
+  size_t batch_size = (input_size + use_thread_num - 1) / use_thread_num;
+
+  while (start < input_size) {
+    // The last batch is clipped to the end of the output.
+    size_t end = (start + batch_size) > input_size ? input_size : (start + batch_size);
+    threads.emplace_back(std::thread(&PackCpuFwdKernel::PackTensor, this, output, start, end));
+    start += batch_size;
+  }
+
+  for (auto &it : threads) {
+    it.join();
+  }
+  return true;
+}
+
+// Checks that exactly one output buffer was supplied.
+template
+bool PackCpuFwdKernel::CheckParam(const std::vector &outputs) {
+  if (outputs.size() != 1) {
+    MS_LOG(EXCEPTION) << "Output number is " << outputs.size() << ", but PackCpuFwdKernel needs 1 output.";
+    return false;
+  }
+  return true;
+}
+
+// Copies the output elements with flat index in [start, end).
+// With D = dims_behind_axis_ and N = input_num_, an output position decomposes as
+// pos = outer * (N * D) + n * D + inner; the element comes from input n at offset outer * D + inner.
+template
+void PackCpuFwdKernel::PackTensor(T *output, size_t start, size_t end) {
+  // Loop invariant: number of output elements covered by one pass over all inputs.
+  const size_t cycle_len = input_num_ * dims_behind_axis_;
+  for (size_t pos = start; pos < end; ++pos) {
+    size_t cur_input_index = pos / dims_behind_axis_ % input_num_;
+    size_t local_index = pos / cycle_len * dims_behind_axis_ + pos % cycle_len % dims_behind_axis_;
+    output[pos] = inputs_host_[cur_input_index][local_index];
+  }
+}
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/pack_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/pack_cpu_kernel.h
new file mode 100644
index 0000000000..2409fd2f3a
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/pack_cpu_kernel.h
@@ -0,0 +1,82 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_PACK_CPU_KERNEL_H
+#define MINDSPORE_PACK_CPU_KERNEL_H
+
+// NOTE(review): the header names after the two bare `#include` directives and the template argument
+// lists below (`template`, `std::vector &`, `std::unique_ptr`) appear to be missing from this copy of
+// the patch, presumably stripped as markup. Restore them from the original commit before applying.
+#include
+#include
+#include "backend/kernel_compiler/cpu/cpu_kernel.h"
+#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
+
+namespace mindspore {
+namespace kernel {
+// CPU kernel class registered for the Pack primitive (see the MS_REG_CPU_KERNEL_T lines below).
+// T is the element type of the tensors being packed.
+template
+class PackCpuFwdKernel : public CPUKernel {
+ public:
+  PackCpuFwdKernel();
+  ~PackCpuFwdKernel() override = default;
+
+  // Reads the node attributes/shapes and precomputes the sizes stored in the members below.
+  void InitKernel(const CNodePtr &kernel_node) override;
+  // Writes the packed result into outputs[0]; workspace is not used by the implementation.
+  bool Launch(const std::vector &inputs, const std::vector &workspace,
+              const std::vector &outputs) override;
+
+ private:
+  // Verifies that exactly one output buffer was passed to Launch.
+  bool CheckParam(const std::vector &outputs);
+  // Copies the output elements whose flat index lies in [start, end).
+  void PackTensor(T *output, size_t start, size_t end);
+
+  // Value of the AXIS node attribute; a negative value is shifted by (input rank + 1) in InitKernel.
+  int axis_;
+  // Number of input tensors of the node.
+  size_t input_num_;
+  // Total number of elements in the output tensor.
+  size_t output_size_;
+  // Product of the input dimensions at positions >= axis_.
+  size_t dims_behind_axis_;
+  // Typed base pointers of the input buffers; rebuilt at the start of every Launch.
+  std::unique_ptr inputs_host_;
+};
+
+// One registration per supported element type, each with matching input and output type.
+// NOTE(review): SetAllSameAttr(true) presumably applies the single input attr to every input of the
+// variable-length input list - confirm against KernelAttr.
+MS_REG_CPU_KERNEL_T(Pack,
+                    KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
+                    PackCpuFwdKernel, int8_t)
+MS_REG_CPU_KERNEL_T(Pack,
+                    KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
+                    PackCpuFwdKernel, int16_t)
+MS_REG_CPU_KERNEL_T(Pack,
+                    KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+                    PackCpuFwdKernel, int32_t)
+MS_REG_CPU_KERNEL_T(Pack,
+                    KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+                    PackCpuFwdKernel, int64_t)
+MS_REG_CPU_KERNEL_T(Pack,
+                    KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
+                    PackCpuFwdKernel, uint8_t)
+MS_REG_CPU_KERNEL_T(Pack,
+                    KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
+                    PackCpuFwdKernel, bool)
+MS_REG_CPU_KERNEL_T(Pack,
+                    KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
+                    PackCpuFwdKernel, uint16_t)
+MS_REG_CPU_KERNEL_T(Pack,
+                    KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
+                    PackCpuFwdKernel, uint32_t)
+MS_REG_CPU_KERNEL_T(Pack,
+                    KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
+                    PackCpuFwdKernel, uint64_t)
+MS_REG_CPU_KERNEL_T(
+  Pack, KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  PackCpuFwdKernel, float16)
+MS_REG_CPU_KERNEL_T(
+  Pack, KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  PackCpuFwdKernel, float)
+}  // namespace kernel
+}  // namespace mindspore
+#endif  // MINDSPORE_PACK_CPU_KERNEL_H
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adagrad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adagrad_gpu_kernel.h
index a6cf718b64..7f9a87b5bb 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adagrad_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adagrad_gpu_kernel.h
@@ -38,6 +38,7 @@ class AdagradGpuKernel : public GpuKernel {
 
   bool Init(const CNodePtr &kernel_node) override {
     size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    update_slots = AnfAlgo::GetNodeAttr(kernel_node, "update_slots");
     if (input_num != 4) {
       MS_LOG(ERROR) << "Input number is " << input_num << ", but adagrad needs 4 inputs.";
       return false;
diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py
index 2fe71579c1..405e61e007 100644
--- a/mindspore/ops/operations/array_ops.py
+++ b/mindspore/ops/operations/array_ops.py
@@ -2227,7 +2227,7 @@ class
Pack(PrimitiveWithInfer): or if the shapes of elements in input_x are not the same. Supported Platforms: - ``Ascend`` ``GPU`` + ``Ascend`` ``GPU`` ``CPU`` Examples: >>> data1 = Tensor(np.array([0, 1]).astype(np.float32)) @@ -2282,7 +2282,7 @@ class Unpack(PrimitiveWithInfer): ValueError: If axis is out of the range [-len(input_x.shape), len(input_x.shape)). Supported Platforms: - ``Ascend`` ``GPU`` + ``Ascend`` ``GPU`` ``CPU`` Examples: >>> unpack = ops.Unpack() diff --git a/tests/st/ops/cpu/test_pack_op.py b/tests/st/ops/cpu/test_pack_op.py new file mode 100644 index 0000000000..7de93b5eab --- /dev/null +++ b/tests/st/ops/cpu/test_pack_op.py @@ -0,0 +1,100 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+import mindspore.ops.operations.array_ops as P
+from mindspore import Tensor
+from mindspore.common.api import ms_function
+from mindspore.common.initializer import initializer
+from mindspore.common.parameter import Parameter
+
+
+class PackNet(nn.Cell):
+    """Cell that packs two fixed (2, 2, 2, 2) parameters along axis 2.
+
+    x1 is all zeros and x2 holds 0..15, both cast to `nptype`.
+    """
+
+    def __init__(self, nptype):
+        super(PackNet, self).__init__()
+        self.pack = P.Pack(axis=2)
+        self.data_np = np.array([0] * 16).astype(nptype)
+        self.data_np = np.reshape(self.data_np, (2, 2, 2, 2))
+        self.x1 = Parameter(initializer(
+            Tensor(self.data_np), [2, 2, 2, 2]), name='x1')
+        self.x2 = Parameter(initializer(
+            Tensor(np.arange(16).reshape(2, 2, 2, 2).astype(nptype)), [2, 2, 2, 2]), name='x2')
+
+    @ms_function
+    def construct(self):
+        return self.pack((self.x1, self.x2))
+
+
+def pack(nptype):
+    """Run PackNet in graph mode on the CPU target and compare with the expected literal.
+
+    The expected array has shape (2, 2, 2, 2, 2): index 0 on axis 2 is the zero input,
+    index 1 is the matching slice of the 0..15 input.
+    """
+    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
+    pack_ = PackNet(nptype)
+    output = pack_()
+    expect = np.array([[[[[0, 0],
+                          [0, 0]],
+                         [[0, 1],
+                          [2, 3]]],
+                        [[[0, 0],
+                          [0, 0]],
+                         [[4, 5],
+                          [6, 7]]]],
+                       [[[[0, 0],
+                          [0, 0]],
+                         [[8, 9],
+                          [10, 11]]],
+                        [[[0, 0],
+                          [0, 0]],
+                         [[12, 13],
+                          [14, 15]]]]]).astype(nptype)
+    assert (output.asnumpy() == expect).all()
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_pack_graph_float32():
+    pack(np.float32)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_pack_graph_float16():
+    pack(np.float16)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_pack_graph_int32():
+    pack(np.int32)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_pack_graph_int16():
+    pack(np.int16)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_pack_graph_uint8():
+    pack(np.uint8)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_pack_graph_bool():
+    # np.bool_ is the NumPy boolean scalar type; the bare `np.bool` alias was deprecated in
+    # NumPy 1.20 and removed in 1.24.
+    pack(np.bool_)