diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/dynamic_shape_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/dynamic_shape_gpu_kernel.cc index 4f571090fa..d45feea854 100755 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/dynamic_shape_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/dynamic_shape_gpu_kernel.cc @@ -19,13 +19,28 @@ namespace mindspore { namespace kernel { -MS_REG_GPU_KERNEL_ONE(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - DynamicShapeGpuKernel, int32_t) -MS_REG_GPU_KERNEL_ONE(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - DynamicShapeGpuKernel, half) -MS_REG_GPU_KERNEL_ONE(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - DynamicShapeGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - DynamicShapeGpuKernel, bool) +MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + DynamicShapeGpuKernel, int32_t, int32_t) + +MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32), + DynamicShapeGpuKernel, half, int32_t) + +MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32), + DynamicShapeGpuKernel, float, int32_t) + +MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32), + DynamicShapeGpuKernel, bool, int32_t) + +MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64), + DynamicShapeGpuKernel, int32_t, int64_t) + +MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt64), + DynamicShapeGpuKernel, half, int64_t) + +MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt64), + DynamicShapeGpuKernel, float, int64_t) + +MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt64), + DynamicShapeGpuKernel, bool, int64_t) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/dynamic_shape_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/dynamic_shape_gpu_kernel.h index 20b4f60a9b..3721621444 100755 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/dynamic_shape_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/dynamic_shape_gpu_kernel.h @@ -26,7 +26,7 @@ namespace mindspore { namespace kernel { -template +template class DynamicShapeGpuKernel : public GpuKernel { public: DynamicShapeGpuKernel() { ResetResource(); } @@ -38,8 +38,8 @@ class DynamicShapeGpuKernel : public GpuKernel { bool Launch(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs, void *stream_ptr) override { - int *output_device_address = GetDeviceAddress(outputs, 0); - size_t prev_node_output_shape_size = prev_node_output_shape_.size() * sizeof(int); + S *output_device_address = GetDeviceAddress(outputs, 0); + size_t prev_node_output_shape_size = prev_node_output_shape_.size() * sizeof(S); CHECK_CUDA_RET_WITH_EXCEPT( cudaMemcpyAsync(output_device_address, prev_node_output_shape_.data(), prev_node_output_shape_size, cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), @@ -58,9 +58,10 @@ class DynamicShapeGpuKernel : public GpuKernel { input_size_ = 1; for (const size_t &e : prev_node_output_shape_tmp) { input_size_ *= e; - // shapes are Tensors with elements of type int32, but GetPrevNodeOutputInferShape returns vector of size_t, - // so we use an int* for allocated output memory and cast to an int here, otherwise the memcpy will fail with a - // silently. + // shapes are Tensors with elements of type S (int32, or int64) but + // GetPrevNodeOutputInferShape returns vector of size_t, so we use + // an S* for allocated output memory and cast to an integral type here, + // otherwise the memcpy will fail silently. prev_node_output_shape_.push_back(e); } @@ -83,13 +84,13 @@ class DynamicShapeGpuKernel : public GpuKernel { protected: void InitSizeLists() override { input_size_list_.push_back(input_size_ * sizeof(T)); - output_size_list_.push_back(output_size_ * sizeof(int)); + output_size_list_.push_back(output_size_ * sizeof(S)); } private: size_t input_size_; size_t output_size_; - std::vector prev_node_output_shape_; + std::vector prev_node_output_shape_; std::vector input_size_list_; std::vector output_size_list_; diff --git a/tests/st/ops/gpu/test_dynamic_shape_op.py b/tests/st/ops/gpu/test_dynamic_shape_op.py new file mode 100644 index 0000000000..a69202290e --- /dev/null +++ b/tests/st/ops/gpu/test_dynamic_shape_op.py @@ -0,0 +1,117 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +from mindspore import Tensor +from mindspore.ops import operations as P +from mindspore.ops.operations import _inner_ops as inner +import mindspore.nn as nn +import mindspore.context as context + +class DynamicShapeNet(nn.Cell): + def __init__(self): + super(DynamicShapeNet, self).__init__() + self.convert_to_dynamic_shape_op = inner.GpuConvertToDynamicShape() + self.dynamic_shape_op = P.DynamicShape() + + def construct(self, x): + x_dynamic_shape = self.convert_to_dynamic_shape_op(x) + return self.dynamic_shape_op(x_dynamic_shape) + + +def dynamic_shape(np_type): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + + dynamic_shape_net = DynamicShapeNet() + + shape = (1,) + x = Tensor(np.zeros(shape).astype(np_type)) + ms_out = dynamic_shape_net(x).asnumpy() + expected = np.array(shape) + np.testing.assert_array_equal(ms_out, expected) + + shape = (7,) + x = Tensor(np.zeros(shape).astype(np_type)) + ms_out = dynamic_shape_net(x).asnumpy() + expected = np.array(shape) + np.testing.assert_array_equal(ms_out, expected) + + shape = (1, 1) + x = Tensor(np.zeros(shape).astype(np_type)) + ms_out = dynamic_shape_net(x).asnumpy() + expected = np.array(shape) + np.testing.assert_array_equal(ms_out, expected) + + shape = (1, 7) + x = Tensor(np.zeros(shape).astype(np_type)) + ms_out = dynamic_shape_net(x).asnumpy() + expected = np.array(shape) + np.testing.assert_array_equal(ms_out, expected) + + shape = (3, 1) + x = Tensor(np.zeros(shape).astype(np_type)) + ms_out = dynamic_shape_net(x).asnumpy() + expected = np.array(shape) + np.testing.assert_array_equal(ms_out, expected) + + shape = (2, 4) + x = Tensor(np.zeros(shape).astype(np_type)) + ms_out = dynamic_shape_net(x).asnumpy() + expected = np.array(shape) + np.testing.assert_array_equal(ms_out, expected) + + shape = (1, 1, 1) + x = Tensor(np.zeros(shape).astype(np_type)) + ms_out = dynamic_shape_net(x).asnumpy() + expected = np.array(shape) + np.testing.assert_array_equal(ms_out, expected) + + shape = (1, 5, 3) + x = Tensor(np.zeros(shape).astype(np_type)) + ms_out = dynamic_shape_net(x).asnumpy() + expected = np.array(shape) + np.testing.assert_array_equal(ms_out, expected) + + shape = (2, 3, 1, 3, 1) + x = Tensor(np.zeros(shape).astype(np_type)) + ms_out = dynamic_shape_net(x).asnumpy() + expected = np.array(shape) + np.testing.assert_array_equal(ms_out, expected) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_dynamic_shape_int32(): + dynamic_shape(np.int32) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_dynamic_shape_float16(): + dynamic_shape(np.float16) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_dynamic_shape_float32(): + dynamic_shape(np.float32) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_dynamic_shape_bool(): + dynamic_shape(np.bool)