From: @peilin-wang
Reviewed-by: @robingrosman, @tom__chen
Signed-off-by: @robingrosman
tags/v1.1.0
@@ -19,13 +19,28 @@
 namespace mindspore {
 namespace kernel {
-MS_REG_GPU_KERNEL_ONE(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-                      DynamicShapeGpuKernel, int32_t)
-MS_REG_GPU_KERNEL_ONE(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-                      DynamicShapeGpuKernel, half)
-MS_REG_GPU_KERNEL_ONE(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-                      DynamicShapeGpuKernel, float)
-MS_REG_GPU_KERNEL_ONE(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-                      DynamicShapeGpuKernel, bool)
+MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+                      DynamicShapeGpuKernel, int32_t, int32_t)
+MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32),
+                      DynamicShapeGpuKernel, half, int32_t)
+MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32),
+                      DynamicShapeGpuKernel, float, int32_t)
+MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32),
+                      DynamicShapeGpuKernel, bool, int32_t)
+MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
+                      DynamicShapeGpuKernel, int32_t, int64_t)
+MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt64),
+                      DynamicShapeGpuKernel, half, int64_t)
+MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt64),
+                      DynamicShapeGpuKernel, float, int64_t)
+MS_REG_GPU_KERNEL_TWO(DynamicShape, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt64),
+                      DynamicShapeGpuKernel, bool, int64_t)
 }  // namespace kernel
 }  // namespace mindspore
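Taken together, the registration table is now a 4 x 2 matrix: each supported input element type (int32, float16, float32, bool) is paired with both shape output types (int32 and int64), so the same kernel class covers eight (T, S) combinations. The toy registry below only illustrates what dual-type registration accomplishes; DType, ShapeKernel, and kShapeKernelRegistry are hypothetical names, not MindSpore's MS_REG_GPU_KERNEL_TWO machinery.

#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <utility>

// Hypothetical stand-ins for the real dtype enum and kernel base class.
enum class DType { kInt32, kInt64, kFloat16, kFloat32, kBool };
struct KernelBase {
  virtual ~KernelBase() = default;
};

template <typename T, typename S>
struct ShapeKernel : KernelBase {};  // stands in for DynamicShapeGpuKernel<T, S>

using Factory = std::function<std::unique_ptr<KernelBase>()>;

// One factory per (input dtype, shape dtype) pair; eight pairs after this change.
const std::map<std::pair<DType, DType>, Factory> kShapeKernelRegistry = {
    {{DType::kFloat32, DType::kInt32}, [] { return std::make_unique<ShapeKernel<float, int32_t>>(); }},
    {{DType::kFloat32, DType::kInt64}, [] { return std::make_unique<ShapeKernel<float, int64_t>>(); }},
    {{DType::kBool, DType::kInt64}, [] { return std::make_unique<ShapeKernel<bool, int64_t>>(); }},
    // ... remaining (T, S) pairs elided
};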
@@ -26,7 +26,7 @@
 namespace mindspore {
 namespace kernel {
-template <typename T>
+template <typename T, typename S>
 class DynamicShapeGpuKernel : public GpuKernel {
  public:
   DynamicShapeGpuKernel() { ResetResource(); }
@@ -38,8 +38,8 @@ class DynamicShapeGpuKernel : public GpuKernel {
   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
               const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
-    int *output_device_address = GetDeviceAddress<int>(outputs, 0);
-    size_t prev_node_output_shape_size = prev_node_output_shape_.size() * sizeof(int);
+    S *output_device_address = GetDeviceAddress<S>(outputs, 0);
+    size_t prev_node_output_shape_size = prev_node_output_shape_.size() * sizeof(S);
     CHECK_CUDA_RET_WITH_EXCEPT(
       cudaMemcpyAsync(output_device_address, prev_node_output_shape_.data(), prev_node_output_shape_size,
                       cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
@@ -58,9 +58,10 @@ class DynamicShapeGpuKernel : public GpuKernel {
     input_size_ = 1;
     for (const size_t &e : prev_node_output_shape_tmp) {
       input_size_ *= e;
-      // shapes are Tensors with elements of type int32, but GetPrevNodeOutputInferShape returns vector of size_t,
-      // so we use an int* for allocated output memory and cast to an int here, otherwise the memcpy will fail with a
-      // silently.
+      // shapes are Tensors with elements of type S (int32 or int64), but
+      // GetPrevNodeOutputInferShape returns a vector of size_t, so we use
+      // an S* for the allocated output memory and cast to an integral type here,
+      // otherwise the memcpy will fail silently.
       prev_node_output_shape_.push_back(e);
     }
@@ -83,13 +84,13 @@
  protected:
   void InitSizeLists() override {
     input_size_list_.push_back(input_size_ * sizeof(T));
-    output_size_list_.push_back(output_size_ * sizeof(int));
+    output_size_list_.push_back(output_size_ * sizeof(S));
   }

  private:
   size_t input_size_;
   size_t output_size_;
-  std::vector<int> prev_node_output_shape_;
+  std::vector<S> prev_node_output_shape_;
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
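Stripped of the framework plumbing, the templated Launch above reduces to one pattern: the shape dims held on the host as size_t are narrowed to the output element type S and copied asynchronously into the pre-allocated device output buffer on the kernel's stream. The minimal standalone sketch below shows that pattern; CopyShapeToDevice, dims, and device_out are assumed names, and only the cudaMemcpyAsync call and the size_t-to-S cast come from the change itself.

#include <cuda_runtime.h>

#include <cstdint>
#include <vector>

template <typename S>
cudaError_t CopyShapeToDevice(const std::vector<size_t> &dims, S *device_out, cudaStream_t stream) {
  // Stage the dims in a host vector of the output element type so the source and
  // destination element widths match (the size_t -> S cast the kernel comment describes).
  std::vector<S> host_shape;
  host_shape.reserve(dims.size());
  for (const size_t &e : dims) {
    host_shape.push_back(static_cast<S>(e));
  }
  // Asynchronous host-to-device copy on the caller's stream. For pageable host memory
  // the call returns only after the data has been staged, so a local vector is safe here;
  // the real kernel keeps its staging vector (prev_node_output_shape_) as a class member anyway.
  return cudaMemcpyAsync(device_out, host_shape.data(), host_shape.size() * sizeof(S),
                         cudaMemcpyHostToDevice, stream);
}

// e.g. CopyShapeToDevice<int64_t>(shape_dims, output_device_address,
//                                 reinterpret_cast<cudaStream_t>(stream_ptr));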
@@ -0,0 +1,117 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import pytest
+
+from mindspore import Tensor
+from mindspore.ops import operations as P
+from mindspore.ops.operations import _inner_ops as inner
+import mindspore.nn as nn
+import mindspore.context as context
+
+
+class DynamicShapeNet(nn.Cell):
+    def __init__(self):
+        super(DynamicShapeNet, self).__init__()
+        self.convert_to_dynamic_shape_op = inner.GpuConvertToDynamicShape()
+        self.dynamic_shape_op = P.DynamicShape()
+
+    def construct(self, x):
+        x_dynamic_shape = self.convert_to_dynamic_shape_op(x)
+        return self.dynamic_shape_op(x_dynamic_shape)
+
+
+def dynamic_shape(np_type):
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    dynamic_shape_net = DynamicShapeNet()
+
+    shape = (1,)
+    x = Tensor(np.zeros(shape).astype(np_type))
+    ms_out = dynamic_shape_net(x).asnumpy()
+    expected = np.array(shape)
+    np.testing.assert_array_equal(ms_out, expected)
+
+    shape = (7,)
+    x = Tensor(np.zeros(shape).astype(np_type))
+    ms_out = dynamic_shape_net(x).asnumpy()
+    expected = np.array(shape)
+    np.testing.assert_array_equal(ms_out, expected)
+
+    shape = (1, 1)
+    x = Tensor(np.zeros(shape).astype(np_type))
+    ms_out = dynamic_shape_net(x).asnumpy()
+    expected = np.array(shape)
+    np.testing.assert_array_equal(ms_out, expected)
+
+    shape = (1, 7)
+    x = Tensor(np.zeros(shape).astype(np_type))
+    ms_out = dynamic_shape_net(x).asnumpy()
+    expected = np.array(shape)
+    np.testing.assert_array_equal(ms_out, expected)
+
+    shape = (3, 1)
+    x = Tensor(np.zeros(shape).astype(np_type))
+    ms_out = dynamic_shape_net(x).asnumpy()
+    expected = np.array(shape)
+    np.testing.assert_array_equal(ms_out, expected)
+
+    shape = (2, 4)
+    x = Tensor(np.zeros(shape).astype(np_type))
+    ms_out = dynamic_shape_net(x).asnumpy()
+    expected = np.array(shape)
+    np.testing.assert_array_equal(ms_out, expected)
+
+    shape = (1, 1, 1)
+    x = Tensor(np.zeros(shape).astype(np_type))
+    ms_out = dynamic_shape_net(x).asnumpy()
+    expected = np.array(shape)
+    np.testing.assert_array_equal(ms_out, expected)
+
+    shape = (1, 5, 3)
+    x = Tensor(np.zeros(shape).astype(np_type))
+    ms_out = dynamic_shape_net(x).asnumpy()
+    expected = np.array(shape)
+    np.testing.assert_array_equal(ms_out, expected)
+
+    shape = (2, 3, 1, 3, 1)
+    x = Tensor(np.zeros(shape).astype(np_type))
+    ms_out = dynamic_shape_net(x).asnumpy()
+    expected = np.array(shape)
+    np.testing.assert_array_equal(ms_out, expected)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_dynamic_shape_int32():
+    dynamic_shape(np.int32)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_dynamic_shape_float16():
+    dynamic_shape(np.float16)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_dynamic_shape_float32():
+    dynamic_shape(np.float32)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_dynamic_shape_bool():
+    dynamic_shape(np.bool)