| @@ -48,7 +48,7 @@ class PrintGpuKernel : public GpuKernel { | |||
| } | |||
| int *output_address = GetDeviceAddress<int>(outputs, 0); | |||
| // host initialization | |||
| std::vector<std::unique_ptr<T[]> > input_host_data; | |||
| std::vector<std::unique_ptr<T[]>> input_host_data; | |||
| for (size_t i = 0; i < input_size_.size(); i++) { | |||
| std::unique_ptr<T[]> value = std::make_unique<T[]>(input_size_[i]); | |||
| input_host_data.push_back(std::move(value)); | |||
| @@ -60,19 +60,25 @@ class PrintGpuKernel : public GpuKernel { | |||
| MS_LOG(EXCEPTION) << "GPU print does not support the input type."; | |||
| } | |||
| // print core function | |||
| for (size_t i = 0; i < input_host_data.size(); i++) { | |||
| std::string error_msg = "cudaMemcpy print loop failed at input_device_data["; | |||
| error_msg.append(std::to_string(i)); | |||
| error_msg.append("]."); | |||
| CHECK_CUDA_RET_WITH_EXCEPT( | |||
| kernel_node_, | |||
| cudaMemcpy(input_host_data[i].get(), input_device_data_[i], input_size_[i] * sizeof(T), cudaMemcpyDeviceToHost), | |||
| error_msg); | |||
| ShapeVector shape; | |||
| (void)std::transform(input_shape_[i].begin(), input_shape_[i].end(), std::back_inserter(shape), | |||
| [](const size_t &value) { return static_cast<int64_t>(value); }); | |||
| Tensor current_tensor(type_id, shape, input_host_data[i].get(), input_size_[i] * sizeof(T)); | |||
| std::cout << current_tensor.ToString() << std::endl; | |||
| size_t string_idx = 0; | |||
| for (size_t i = 0; i < input_flag_.size(); i++) { | |||
| if (input_flag_[i] == -1) { | |||
| std::cout << string_value_[string_idx] << std::endl; | |||
| string_idx++; | |||
| } else { | |||
| size_t tensor_idx = LongToSize(input_flag_[i]); | |||
| std::string error_msg = "cudaMemcpyAsync print loop failed at input_device_data["; | |||
| error_msg.append(std::to_string(tensor_idx)); | |||
| error_msg.append("]."); | |||
| CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, | |||
| cudaMemcpyAsync(input_host_data[tensor_idx].get(), input_device_data_[tensor_idx], | |||
| input_size_[tensor_idx] * sizeof(T), cudaMemcpyDeviceToHost, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)), | |||
| error_msg); | |||
| CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaDeviceSynchronize(), "cudaDeviceSyncFailed - Print"); | |||
| auto current_string = GetTensorString(&input_shape_, tensor_idx, type_id, &input_host_data, &input_size_); | |||
| std::cout << current_string << std::endl; | |||
| } | |||
| } | |||
| int output = 1; | |||
| CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, | |||
| @@ -84,7 +90,12 @@ class PrintGpuKernel : public GpuKernel { | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| kernel_node_ = kernel_node; | |||
| if (AnfAlgo::HasNodeAttr("string_pos", kernel_node)) { | |||
| string_value_ = GetAttr<std::vector<std::string>>(kernel_node, "string_value"); | |||
| string_pos_ = GetAttr<std::vector<int64_t>>(kernel_node, "string_pos"); | |||
| } | |||
| size_t input_tensor_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| input_flag_ = SetInputFlag(&string_pos_, input_tensor_num); | |||
| input_device_data_ = std::make_unique<T *[]>(input_tensor_num); | |||
| std::vector<size_t> value_shape; | |||
| for (size_t i = 0; i < input_tensor_num; i++) { | |||
| @@ -103,6 +114,9 @@ class PrintGpuKernel : public GpuKernel { | |||
| } | |||
| void ResetResource() noexcept override { | |||
| string_value_.clear(); | |||
| string_pos_.clear(); | |||
| input_flag_.clear(); | |||
| input_device_data_ = nullptr; | |||
| input_size_.clear(); | |||
| input_shape_.clear(); | |||
| @@ -146,14 +160,52 @@ class PrintGpuKernel : public GpuKernel { | |||
| return kTypeUnknown; | |||
| } | |||
| std::vector<int64_t> SetInputFlag(std::vector<int64_t> *string_pos, size_t input_tensor_num) { | |||
| // -1 -> string position | |||
| // others -> input tensor position | |||
| std::vector<int64_t> res(string_pos->size() + input_tensor_num); | |||
| // without string inputs | |||
| int64_t value = 0; | |||
| if (res.size() == input_tensor_num) { | |||
| std::generate(res.begin(), res.end(), [&value]() { return value++; }); | |||
| return res; | |||
| } | |||
| for (size_t i = 0; i < string_pos->size(); i++) { | |||
| if ((*string_pos)[i] < 0) { | |||
| MS_LOG(EXCEPTION) << "string_pos cannot be a negative value"; | |||
| } | |||
| auto index = IntToSize((*string_pos)[i]); | |||
| res[index] = -1; | |||
| } | |||
| for (size_t i = 0; i < res.size(); i++) { | |||
| if (res[i] != -1) { | |||
| res[i] += value; | |||
| value++; | |||
| } | |||
| } | |||
| return res; | |||
| } | |||
| std::string GetTensorString(std::vector<std::vector<size_t>> *input_shape, size_t index, TypeId type_id, | |||
| std::vector<std::unique_ptr<T[]>> *input_host_data, std::vector<size_t> *input_size) { | |||
| ShapeVector shape; | |||
| (void)std::transform((*input_shape)[index].begin(), (*input_shape)[index].end(), std::back_inserter(shape), | |||
| [](const size_t &value) { return static_cast<int64_t>(value); }); | |||
| Tensor current_tensor(type_id, shape, (*input_host_data)[index].get(), (*input_size)[index] * sizeof(T)); | |||
| return current_tensor.ToStringNoLimit(); | |||
| } | |||
| private: | |||
| std::vector<std::string> string_value_; | |||
| std::vector<int64_t> string_pos_; | |||
| std::vector<int64_t> input_flag_; | |||
| std::unique_ptr<T *[]> input_device_data_; | |||
| std::vector<size_t> input_size_; | |||
| std::vector<std::vector<size_t> > input_shape_; | |||
| std::vector<std::vector<size_t>> input_shape_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| }; // namespace kernel | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DEBUG_PRINT_GPU_KERNEL_H_ | |||
| @@ -0,0 +1,154 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/optimizer/gpu/print_reduce_fusion.h" | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <string> | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "ir/primitive.h" | |||
| #include "utils/utils.h" | |||
| #include "backend/optimizer/common/helper.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) { | |||
| std::vector<std::string> inputs_format; | |||
| std::vector<std::string> outputs_format; | |||
| std::vector<TypeId> inputs_type; | |||
| std::vector<TypeId> outputs_type; | |||
| kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(node); | |||
| for (size_t input_index = 0; input_index < input_num; input_index++) { | |||
| inputs_format.push_back(kOpFormat_DEFAULT); | |||
| inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index)); | |||
| } | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(node); | |||
| for (size_t output_index = 0; output_index < output_num; output_index++) { | |||
| outputs_format.push_back(kOpFormat_DEFAULT); | |||
| outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index)); | |||
| } | |||
| builder.SetInputsFormat(inputs_format); | |||
| builder.SetOutputsFormat(outputs_format); | |||
| builder.SetInputsDeviceType(inputs_type); | |||
| builder.SetOutputsDeviceType(outputs_type); | |||
| return builder.Build(); | |||
| } | |||
| bool GetOptList(const std::vector<AnfNodePtr> &node_list, std::vector<AnfNodePtr> *opt_list, | |||
| std::vector<std::vector<int64_t>> *string_pos_vec, | |||
| std::vector<std::vector<std::string>> *string_value_vec) { | |||
| for (auto &node : node_list) { | |||
| // {prim::kPrimPrint} only print with string will be reduced | |||
| std::vector<int64_t> string_pos; | |||
| std::vector<std::string> string_value; | |||
| if (IsPrimitiveCNode(node, prim::kPrimPrint)) { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(node); | |||
| for (size_t i = 0; i < input_num; i++) { | |||
| auto current_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), i); | |||
| // not a string | |||
| if (current_node->cast<ValueNodePtr>() == nullptr) { | |||
| continue; | |||
| } | |||
| auto value_node = current_node->cast<ValueNodePtr>()->value(); | |||
| if (value_node->type()->generic_type_id() == kObjectTypeString) { | |||
| auto current_string_value = GetValue<std::string>(value_node); | |||
| string_pos.push_back(i); | |||
| string_value.push_back(std::string(current_string_value)); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Current value node is not string or tensor"; | |||
| } | |||
| } | |||
| if (string_pos.size() != 0) { | |||
| opt_list->push_back(node); | |||
| string_pos_vec->push_back(string_pos); | |||
| string_value_vec->push_back(string_value); | |||
| } | |||
| } | |||
| } | |||
| if (opt_list->size() == 0) { | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| bool PrintReduceFusion::Run(const FuncGraphPtr &graph) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| auto manager = graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| std::vector<AnfNodePtr> node_list = TopoSort(graph->get_return()); | |||
| std::vector<AnfNodePtr> opt_list; | |||
| std::vector<std::vector<int64_t>> string_pos_vec; | |||
| std::vector<std::vector<std::string>> string_value_vec; | |||
| if (!GetOptList(node_list, &opt_list, &string_pos_vec, &string_value_vec)) { | |||
| return false; | |||
| } | |||
| for (size_t idx = 0; idx < opt_list.size(); idx++) { | |||
| auto node = opt_list[idx]; | |||
| CNodePtr cnode = utils::cast<CNodePtr>(node); | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(cnode); | |||
| auto prim = std::make_shared<Primitive>("Print"); | |||
| std::vector<AnfNodePtr> inputs = {NewValueNode(prim)}; | |||
| auto string_pos = string_pos_vec[idx]; | |||
| std::vector<int64_t> input_flag(input_num); | |||
| for (size_t i = 0; i < string_pos.size(); i++) { | |||
| if (string_pos[i] < 0) { | |||
| MS_LOG(EXCEPTION) << "string_pos cannot be a negative value"; | |||
| } | |||
| size_t index = LongToSize(string_pos[i]); | |||
| input_flag[index] = -1; | |||
| } | |||
| for (size_t i = 0; i < input_flag.size(); i++) { | |||
| if (input_flag[i] == -1) { | |||
| continue; | |||
| } | |||
| auto input_tensor = AnfAlgo::GetInputNode(cnode, i); | |||
| MS_EXCEPTION_IF_NULL(input_tensor); | |||
| inputs.push_back(input_tensor); | |||
| } | |||
| // add monad | |||
| auto monad_node = AnfAlgo::GetInputNode(cnode, input_flag.size()); | |||
| MS_EXCEPTION_IF_NULL(monad_node); | |||
| inputs.push_back(monad_node); | |||
| auto string_value = string_value_vec[idx]; | |||
| // create new cnode | |||
| auto print_fused = graph->NewCNode(inputs); | |||
| // hand over the attrs to new print | |||
| AnfAlgo::SetNodeAttr("string_pos", MakeValue<std::vector<int64_t>>(string_pos), print_fused); | |||
| AnfAlgo::SetNodeAttr("string_value", MakeValue<std::vector<std::string>>(string_value), print_fused); | |||
| // set output type and shape | |||
| std::vector<TypeId> types; | |||
| std::vector<std::vector<size_t>> shapes; | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(cnode); | |||
| for (size_t i = 0; i < output_num; i++) { | |||
| types.push_back(AnfAlgo::GetOutputInferDataType(cnode, i)); | |||
| shapes.push_back(AnfAlgo::GetOutputInferShape(cnode, i)); | |||
| } | |||
| AnfAlgo::SetOutputInferTypeAndShape(types, shapes, print_fused.get()); | |||
| // add build info | |||
| auto build_info = GenerateKernelBuildInfo(print_fused); | |||
| AnfAlgo::SetSelectKernelBuildInfo(build_info, print_fused.get()); | |||
| if (!manager->Replace(cnode, print_fused)) { | |||
| MS_LOG(EXCEPTION) << "manager replace node failed in print reduce fusion."; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,32 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_PRINT_REDUCE_FUSION_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_PRINT_REDUCE_FUSION_H_ | |||
| #include <string> | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class PrintReduceFusion : public Pass { | |||
| public: | |||
| explicit PrintReduceFusion(const std::string &name) : Pass("print_reduce") {} | |||
| ~PrintReduceFusion() override = default; | |||
| bool Run(const FuncGraphPtr &graph) override; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_PRINT_REDUCE_FUSION_H_ | |||
| @@ -37,6 +37,7 @@ | |||
| #include "backend/optimizer/gpu/insert_format_transform_op.h" | |||
| #include "backend/optimizer/gpu/replace_momentum_cast_fusion.h" | |||
| #include "backend/optimizer/gpu/replace_addn_fusion.h" | |||
| #include "backend/optimizer/gpu/print_reduce_fusion.h" | |||
| #include "backend/optimizer/gpu/remove_format_transform_pair.h" | |||
| #include "backend/optimizer/gpu/remove_redundant_format_transform.h" | |||
| #include "backend/optimizer/gpu/reduce_precision_fusion.h" | |||
| @@ -141,6 +142,7 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) { | |||
| pm->AddPass(std::make_shared<opt::CombineMomentumFusion>("combine_momentum")); | |||
| pm->AddPass(std::make_shared<opt::ReplaceMomentumCastFusion>()); | |||
| pm->AddPass(std::make_shared<opt::ReplaceAddNFusion>()); | |||
| pm->AddPass(std::make_shared<opt::PrintReduceFusion>("print_reduce")); | |||
| optimizer->AddPassManager(pm); | |||
| (void)optimizer->Optimize(kernel_graph); | |||
| kernel_graph->SetExecOrderByDefault(); | |||
| @@ -346,11 +346,10 @@ class Print(PrimitiveWithInfer): | |||
| In pynative mode, please use python print function. | |||
| In graph mode, the bool, int, float, tuple, and list would be converted into Tensor to print, | |||
| str remains unchanged. | |||
| In GPU, all input elements should be the same type and string is not supported. | |||
| Inputs: | |||
| - **input_x** (Union[Tensor, bool, int, float, str, tuple, list]) - The graph node to attach to. | |||
| Supports multiple inputs which are separated by ','. GPU does not support string as an input. | |||
| Supports multiple inputs which are separated by ','. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` | |||
| @@ -71,6 +71,27 @@ def print_testcase(nptype): | |||
| net_2(x, y) | |||
| net_3(x) | |||
| class PrintNetString(nn.Cell): | |||
| def __init__(self): | |||
| super(PrintNetString, self).__init__() | |||
| self.op = P.Print() | |||
| def construct(self, x, y): | |||
| self.op("The first Tensor is", x) | |||
| self.op("The second Tensor is", y) | |||
| self.op("This line only prints string", "Another line") | |||
| self.op("The first Tensor is", x, y, "is the second Tensor") | |||
| return x | |||
| def print_testcase_string(nptype): | |||
| x = np.ones(18).astype(nptype) | |||
| y = np.arange(9).reshape(3, 3).astype(nptype) | |||
| x = Tensor(x) | |||
| y = Tensor(y) | |||
| # graph mode | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| net = PrintNetString() | |||
| net(x, y) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @@ -147,3 +168,10 @@ def test_print_float16(): | |||
| @pytest.mark.env_onecard | |||
| def test_print_float32(): | |||
| print_testcase(np.float32) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_print_string(): | |||
| print_testcase_string(np.float32) | |||