From: @TFbunny
Tag: tags/v1.2.0-rc1
@@ -19,37 +19,37 @@
 namespace mindspore {
 namespace kernel {
 MS_REG_GPU_KERNEL_ONE(Print,
-                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
+                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32),
+                      PrintGpuKernel, bool)
+MS_REG_GPU_KERNEL_ONE(Print,
+                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt32),
                       PrintGpuKernel, int8_t)
 MS_REG_GPU_KERNEL_ONE(Print,
-                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
+                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt32),
                       PrintGpuKernel, int16_t)
 MS_REG_GPU_KERNEL_ONE(Print,
                       KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
                       PrintGpuKernel, int)
 MS_REG_GPU_KERNEL_ONE(Print,
-                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
                       PrintGpuKernel, int64_t)
 MS_REG_GPU_KERNEL_ONE(Print,
-                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
+                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt32),
                       PrintGpuKernel, uint8_t)
 MS_REG_GPU_KERNEL_ONE(Print,
-                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
-                      PrintGpuKernel, bool)
-MS_REG_GPU_KERNEL_ONE(
-  Print, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
-  PrintGpuKernel, uint16_t)
-MS_REG_GPU_KERNEL_ONE(
-  Print, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
-  PrintGpuKernel, uint32_t)
-MS_REG_GPU_KERNEL_ONE(
-  Print, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
-  PrintGpuKernel, uint64_t)
+                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt32),
+                      PrintGpuKernel, uint16_t)
+MS_REG_GPU_KERNEL_ONE(Print,
+                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt32),
+                      PrintGpuKernel, uint32_t)
+MS_REG_GPU_KERNEL_ONE(Print,
+                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt32),
+                      PrintGpuKernel, uint64_t)
 MS_REG_GPU_KERNEL_ONE(
-  Print, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  Print, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32),
   PrintGpuKernel, half)
 MS_REG_GPU_KERNEL_ONE(
-  Print, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  Print, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32),
   PrintGpuKernel, float)
 }  // namespace kernel
 }  // namespace mindspore
@@ -17,11 +17,17 @@
 #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DEBUG_PRINT_GPU_KERNEL_H_
 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DEBUG_PRINT_GPU_KERNEL_H_
+#include <utility>
+#include <string>
+#include <algorithm>
 #include <vector>
+#include <memory>
+#include "ir/tensor.h"
 #include "backend/kernel_compiler/gpu/gpu_kernel.h"
 #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+
+using mindspore::tensor::Tensor;
+
 namespace mindspore {
 namespace kernel {
 template <typename T>
@@ -37,19 +43,42 @@ class PrintGpuKernel : public GpuKernel {
   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
               const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
     VARIABLE_NOT_USED(workspace);
-    VARIABLE_NOT_USED(outputs);
     for (size_t i = 0; i < inputs.size(); i++) {
       input_device_data_[i] = GetDeviceAddress<T>(inputs, i);
     }
-    CHECK_CUDA_RET_WITH_EXCEPT(
-      kernel_node_,
-      cudaMemcpy(&input_host_data_[0], &input_device_data_[0], input_size_ * sizeof(T), cudaMemcpyDeviceToHost),
-      "cudaMemcpy output failed");
-    for (size_t i = 0; i < input_num_.size(); i++) {
-      for (size_t j = 0; j < input_num_[i]; j++) {
-        std::cout << input_host_data_[i][j];
-      }
+    int *output_address = GetDeviceAddress<int>(outputs, 0);
+    // host initialization
+    std::vector<std::unique_ptr<T[]> > input_host_data;
+    for (size_t i = 0; i < input_size_.size(); i++) {
+      std::unique_ptr<T[]> value = std::make_unique<T[]>(input_size_[i]);
+      input_host_data.push_back(std::move(value));
+    }
+    // check type
+    T type_value = static_cast<T>(0.0f);
+    auto type_id = CheckType(type_value);
+    if (type_id == kTypeUnknown) {
+      MS_LOG(EXCEPTION) << "GPU print does not support the input type.";
+    }
+    // print core function
+    for (size_t i = 0; i < input_host_data.size(); i++) {
+      std::string error_msg = "cudaMemcpy print loop failed at input_device_data[";
+      error_msg.append(std::to_string(i));
+      error_msg.append("].");
+      CHECK_CUDA_RET_WITH_EXCEPT(
+        kernel_node_,
+        cudaMemcpy(input_host_data[i].get(), input_device_data_[i], input_size_[i] * sizeof(T), cudaMemcpyDeviceToHost),
+        error_msg);
+      ShapeVector shape;
+      (void)std::transform(input_shape_[i].begin(), input_shape_[i].end(), std::back_inserter(shape),
+                           [](const size_t &value) { return static_cast<int64_t>(value); });
+      Tensor current_tensor(type_id, shape, input_host_data[i].get(), input_size_[i] * sizeof(T));
+      std::cout << current_tensor.ToString() << std::endl;
     }
+    int output = 1;
+    CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
+                               cudaMemcpyAsync(output_address, &output, sizeof(int), cudaMemcpyHostToDevice,
+                                               reinterpret_cast<cudaStream_t>(stream_ptr)),
+                               "cudaMemcpyAsync output failed");
     return true;
   }
@@ -57,38 +86,70 @@ class PrintGpuKernel : public GpuKernel {
     kernel_node_ = kernel_node;
     size_t input_tensor_num = AnfAlgo::GetInputTensorNum(kernel_node);
     input_device_data_ = std::make_unique<T *[]>(input_tensor_num);
-    input_host_data_ = std::make_unique<T *[]>(input_tensor_num);
+    std::vector<size_t> value_shape;
     for (size_t i = 0; i < input_tensor_num; i++) {
-      size_t counter = 0;
+      size_t value = 1;
       auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, i);
       for (size_t j = 0; j < input_shape.size(); j++) {
-        input_size_ *= input_shape[j];
-        counter++;
+        value *= input_shape[j];
+        value_shape.push_back(input_shape[j]);
       }
-      input_num_.push_back(counter);
+      input_size_.push_back(value);
+      input_shape_.push_back(value_shape);
+      value_shape.clear();
     }
     InitSizeLists();
     return true;
   }

   void ResetResource() noexcept override {
-    input_size_ = 1;
     input_device_data_ = nullptr;
-    input_host_data_ = nullptr;
-    input_num_.clear();
+    input_size_.clear();
+    input_shape_.clear();
     input_size_list_.clear();
     output_size_list_.clear();
     workspace_size_list_.clear();
   }

 protected:
-  void InitSizeLists() override { input_size_list_.push_back(input_size_ * sizeof(T)); }
+  void InitSizeLists() override {
+    for (size_t i = 0; i < input_size_.size(); i++) {
+      input_size_list_.push_back(input_size_[i] * sizeof(T));
+    }
+    output_size_list_.push_back(sizeof(int));
+  }
+
+  TypeId CheckType(T value) {
+    if (std::is_same<T, bool>::value) {
+      return kNumberTypeBool;
+    } else if (std::is_same<T, int8_t>::value) {
+      return kNumberTypeInt8;
+    } else if (std::is_same<T, int16_t>::value) {
+      return kNumberTypeInt16;
+    } else if (std::is_same<T, int>::value) {
+      return kNumberTypeInt32;
+    } else if (std::is_same<T, int64_t>::value) {
+      return kNumberTypeInt64;
+    } else if (std::is_same<T, uint8_t>::value) {
+      return kNumberTypeUInt8;
+    } else if (std::is_same<T, uint16_t>::value) {
+      return kNumberTypeUInt16;
+    } else if (std::is_same<T, uint32_t>::value) {
+      return kNumberTypeUInt32;
+    } else if (std::is_same<T, uint64_t>::value) {
+      return kNumberTypeUInt64;
+    } else if (std::is_same<T, half>::value) {
+      return kNumberTypeFloat16;
+    } else if (std::is_same<T, float>::value) {
+      return kNumberTypeFloat32;
+    }
+    return kTypeUnknown;
+  }

 private:
-  size_t input_size_;
   std::unique_ptr<T *[]> input_device_data_;
-  std::unique_ptr<T *[]> input_host_data_;
-  std::vector<size_t> input_num_;
+  std::vector<size_t> input_size_;
+  std::vector<std::vector<size_t> > input_shape_;
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;
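For orientation, here is a rough host-side analogue of what the new `Launch` does, written as plain Python/NumPy (illustration only; `launch_analogue` and its arguments are hypothetical, not MindSpore API): each input is copied back to the host, reshaped with its recorded shape, printed, and a scalar 1 is written to the kernel's int32 output.

```python
import numpy as np

def launch_analogue(device_inputs, input_shapes, dtype):
    """Hypothetical NumPy sketch of PrintGpuKernel::Launch."""
    for flat, shape in zip(device_inputs, input_shapes):
        # stands in for the per-input cudaMemcpy device-to-host copy
        host = np.asarray(flat, dtype=dtype).reshape(shape)
        # stands in for wrapping the buffer in a Tensor and printing ToString()
        print(host)
    # the kernel reports success through a dummy int32 output of 1
    return np.int32(1)
```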
@@ -341,10 +341,11 @@ class Print(PrimitiveWithInfer):
     In PyNative mode, please use the Python print function.
     In graph mode, bool, int, float, tuple, and list inputs are converted into Tensor to print;
     str remains unchanged.
+    On GPU, all inputs must be of the same type, and string is not supported.

     Inputs:
         - **input_x** (Union[Tensor, bool, int, float, str, tuple, list]) - The graph node to attach to.
-          Supports multiple inputs, separated by ','.
+          Supports multiple inputs, separated by ','. GPU does not support string as an input.

     Supported Platforms:
         ``Ascend`` ``GPU``
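For reference, a minimal GPU usage sketch in graph mode, mirroring the new tests below (the `PrintDemo` name and the tensor shapes are illustrative; strings are left out because the GPU kernel does not accept them, and both inputs share one dtype):

```python
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

class PrintDemo(nn.Cell):
    def __init__(self):
        super(PrintDemo, self).__init__()
        self.op = P.Print()

    def construct(self, x, y):
        # prints both tensors from the GPU kernel, then passes x through
        self.op(x, y)
        return x

demo = PrintDemo()
out = demo(Tensor(np.ones((2, 3)).astype(np.float32)),
           Tensor(np.arange(6).reshape(2, 3).astype(np.float32)))
```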
@@ -0,0 +1,135 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import pytest
+
+from mindspore import Tensor
+import mindspore.nn as nn
+from mindspore.ops import operations as P
+import mindspore.context as context
+
+
+class PrintNetOneInput(nn.Cell):
+    def __init__(self):
+        super(PrintNetOneInput, self).__init__()
+        self.op = P.Print()
+
+    def construct(self, x):
+        self.op(x)
+        return x
+
+
+class PrintNetTwoInputs(nn.Cell):
+    def __init__(self):
+        super(PrintNetTwoInputs, self).__init__()
+        self.op = P.Print()
+
+    def construct(self, x, y):
+        self.op(x, y)
+        return x
+
+
+def print_testcase(nptype):
+    # large shape
+    x = np.arange(20808).reshape(6, 3, 34, 34).astype(nptype)
+    # small shape
+    y = np.arange(9).reshape(3, 3).astype(nptype)
+    x = Tensor(x)
+    y = Tensor(y)
+    # graph mode
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    net_1 = PrintNetOneInput()
+    net_2 = PrintNetTwoInputs()
+    net_1(x)
+    net_2(x, y)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_print_bool():
+    print_testcase(np.bool)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_print_int8():
+    print_testcase(np.int8)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_print_int16():
+    print_testcase(np.int16)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_print_int32():
+    print_testcase(np.int32)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_print_int64():
+    print_testcase(np.int64)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_print_uint8():
+    print_testcase(np.uint8)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_print_uint16():
+    print_testcase(np.uint16)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_print_uint32():
+    print_testcase(np.uint32)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_print_uint64():
+    print_testcase(np.uint64)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_print_float16():
+    print_testcase(np.float16)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_print_float32():
+    print_testcase(np.float32)
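To see the kernel's output when running these cases locally, disable pytest's capture, e.g. `pytest -s test_print_op.py` (file name assumed from the test content; the `-s` flag keeps the C++ side's stdout visible).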