From: @peilin-wang Reviewed-by: @robingrosman,@tom__chen Signed-off-by: @robingrosmantags/v1.1.0
| @@ -19,13 +19,28 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_ONE(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||
| DynamicShapeGpuKernel, int32_t) | |||
| MS_REG_GPU_KERNEL_ONE(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||
| DynamicShapeGpuKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||
| DynamicShapeGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||
| DynamicShapeGpuKernel, bool) | |||
| MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||
| DynamicShapeGpuKernel, int32_t, int32_t) | |||
| MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32), | |||
| DynamicShapeGpuKernel, half, int32_t) | |||
| MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32), | |||
| DynamicShapeGpuKernel, float, int32_t) | |||
| MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32), | |||
| DynamicShapeGpuKernel, bool, int32_t) | |||
| MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64), | |||
| DynamicShapeGpuKernel, int32_t, int64_t) | |||
| MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt64), | |||
| DynamicShapeGpuKernel, half, int64_t) | |||
| MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt64), | |||
| DynamicShapeGpuKernel, float, int64_t) | |||
| MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt64), | |||
| DynamicShapeGpuKernel, bool, int64_t) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -26,7 +26,7 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| template <typename T, typename S> | |||
| class DynamicShapeGpuKernel : public GpuKernel { | |||
| public: | |||
| DynamicShapeGpuKernel() { ResetResource(); } | |||
| @@ -38,8 +38,8 @@ class DynamicShapeGpuKernel : public GpuKernel { | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||
| int *output_device_address = GetDeviceAddress<int>(outputs, 0); | |||
| size_t prev_node_output_shape_size = prev_node_output_shape_.size() * sizeof(int); | |||
| S *output_device_address = GetDeviceAddress<S>(outputs, 0); | |||
| size_t prev_node_output_shape_size = prev_node_output_shape_.size() * sizeof(S); | |||
| CHECK_CUDA_RET_WITH_EXCEPT( | |||
| cudaMemcpyAsync(output_device_address, prev_node_output_shape_.data(), prev_node_output_shape_size, | |||
| cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)), | |||
| @@ -58,9 +58,10 @@ class DynamicShapeGpuKernel : public GpuKernel { | |||
| input_size_ = 1; | |||
| for (const size_t &e : prev_node_output_shape_tmp) { | |||
| input_size_ *= e; | |||
| // shapes are Tensors with elements of type int32, but GetPrevNodeOutputInferShape returns vector of size_t, | |||
| // so we use an int* for allocated output memory and cast to an int here, otherwise the memcpy will fail with a | |||
| // silently. | |||
| // shapes are Tensors with elements of type S (int32, or int64) but | |||
| // GetPrevNodeOutputInferShape returns vector of size_t, so we use | |||
| // an S* for allocated output memory and cast to an integral type here, | |||
| // otherwise the memcpy will fail silently. | |||
| prev_node_output_shape_.push_back(e); | |||
| } | |||
| @@ -83,13 +84,13 @@ class DynamicShapeGpuKernel : public GpuKernel { | |||
| protected: | |||
| void InitSizeLists() override { | |||
| input_size_list_.push_back(input_size_ * sizeof(T)); | |||
| output_size_list_.push_back(output_size_ * sizeof(int)); | |||
| output_size_list_.push_back(output_size_ * sizeof(S)); | |||
| } | |||
| private: | |||
| size_t input_size_; | |||
| size_t output_size_; | |||
| std::vector<int> prev_node_output_shape_; | |||
| std::vector<S> prev_node_output_shape_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| @@ -0,0 +1,117 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops.operations import _inner_ops as inner | |||
| import mindspore.nn as nn | |||
| import mindspore.context as context | |||
| class DynamicShapeNet(nn.Cell): | |||
| def __init__(self): | |||
| super(DynamicShapeNet, self).__init__() | |||
| self.convert_to_dynamic_shape_op = inner.GpuConvertToDynamicShape() | |||
| self.dynamic_shape_op = P.DynamicShape() | |||
| def construct(self, x): | |||
| x_dynamic_shape = self.convert_to_dynamic_shape_op(x) | |||
| return self.dynamic_shape_op(x_dynamic_shape) | |||
| def dynamic_shape(np_type): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| dynamic_shape_net = DynamicShapeNet() | |||
| shape = (1,) | |||
| x = Tensor(np.zeros(shape).astype(np_type)) | |||
| ms_out = dynamic_shape_net(x).asnumpy() | |||
| expected = np.array(shape) | |||
| np.testing.assert_array_equal(ms_out, expected) | |||
| shape = (7,) | |||
| x = Tensor(np.zeros(shape).astype(np_type)) | |||
| ms_out = dynamic_shape_net(x).asnumpy() | |||
| expected = np.array(shape) | |||
| np.testing.assert_array_equal(ms_out, expected) | |||
| shape = (1, 1) | |||
| x = Tensor(np.zeros(shape).astype(np_type)) | |||
| ms_out = dynamic_shape_net(x).asnumpy() | |||
| expected = np.array(shape) | |||
| np.testing.assert_array_equal(ms_out, expected) | |||
| shape = (1, 7) | |||
| x = Tensor(np.zeros(shape).astype(np_type)) | |||
| ms_out = dynamic_shape_net(x).asnumpy() | |||
| expected = np.array(shape) | |||
| np.testing.assert_array_equal(ms_out, expected) | |||
| shape = (3, 1) | |||
| x = Tensor(np.zeros(shape).astype(np_type)) | |||
| ms_out = dynamic_shape_net(x).asnumpy() | |||
| expected = np.array(shape) | |||
| np.testing.assert_array_equal(ms_out, expected) | |||
| shape = (2, 4) | |||
| x = Tensor(np.zeros(shape).astype(np_type)) | |||
| ms_out = dynamic_shape_net(x).asnumpy() | |||
| expected = np.array(shape) | |||
| np.testing.assert_array_equal(ms_out, expected) | |||
| shape = (1, 1, 1) | |||
| x = Tensor(np.zeros(shape).astype(np_type)) | |||
| ms_out = dynamic_shape_net(x).asnumpy() | |||
| expected = np.array(shape) | |||
| np.testing.assert_array_equal(ms_out, expected) | |||
| shape = (1, 5, 3) | |||
| x = Tensor(np.zeros(shape).astype(np_type)) | |||
| ms_out = dynamic_shape_net(x).asnumpy() | |||
| expected = np.array(shape) | |||
| np.testing.assert_array_equal(ms_out, expected) | |||
| shape = (2, 3, 1, 3, 1) | |||
| x = Tensor(np.zeros(shape).astype(np_type)) | |||
| ms_out = dynamic_shape_net(x).asnumpy() | |||
| expected = np.array(shape) | |||
| np.testing.assert_array_equal(ms_out, expected) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_dynamic_shape_int32(): | |||
| dynamic_shape(np.int32) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_dynamic_shape_float16(): | |||
| dynamic_shape(np.float16) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_dynamic_shape_float32(): | |||
| dynamic_shape(np.float32) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_dynamic_shape_bool(): | |||
| dynamic_shape(np.bool) | |||