/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "common/graph_kernel/substitute_dropout.h"

#include <algorithm>
#include <ctime>
#include <iterator>
#include <memory>
#include <string>
#include <vector>

#include "base/core_ops.h"
#include "utils/utils.h"
#include "backend/common/optimizer/helper.h"
#include "common/graph_kernel/graph_kernel_helper.h"
#include "backend/common/session/anf_runtime_algorithm.h"
#include "ir/tensor.h"
#include "kernel/kernel_build_info.h"
#include "runtime/device/kernel_info.h"

namespace mindspore {
namespace prim {
// Primitive of the graph-kernel dropout op that replaces the original Dropout node.
inline const PrimitivePtr kPrimGkDropout = std::make_shared<Primitive>("GkDropout");
}  // namespace prim

namespace graphkernel {
using opt::CheckCNodeInputSize;
using opt::kDropoutInputTensorNum;

// Fallback seed: initialised once from the wall clock and incremented every time a Dropout node
// carries Seed0 == 0 and Seed1 == 0.
// NOTE(review): plain static counter; concurrent PreProcess calls would race on it — confirm the
// pass only runs single-threaded.
int64_t DropoutExpander::seed_ = time(nullptr);

/**
 * Rewrites a Dropout CNode into GkDropout(x, CudnnUniformReal(shape(x))).
 *
 * A CudnnUniformReal node is created to generate the random values (float32, same shape as the
 * dropout data input). A GkDropout node is then built from the original data input plus that random
 * tensor, carrying over "keep_prob", the original abstract and the original input-0/output kernel info.
 *
 * @param func_graph Graph in which the new nodes are created; must not be null.
 * @param node       The original Dropout CNode; must not be null.
 * @return The newly created GkDropout CNode.
 */
AnfNodePtr DropoutExpander::PreProcess(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  CNodePtr cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  CheckCNodeInputSize(cnode, kDropoutInputTensorNum);
  auto shape = AnfAlgo::GetInputDeviceShape(cnode, 0);
  ShapeVector shape_i64;
  shape_i64.reserve(shape.size());
  std::transform(shape.begin(), shape.end(), std::back_inserter(shape_i64), [](size_t x) { return SizeToLong(x); });

  // Prefer the seeds recorded on the original Dropout primitive over a time-based one: use Seed0 if
  // it is non-zero, otherwise Seed1 if it is non-zero; only when both are 0 fall back to seed_.
  auto node_prim = GetCNodePrimitive(node);
  MS_EXCEPTION_IF_NULL(node_prim);
  int64_t seed = GetValue<int64_t>(node_prim->GetAttr("Seed0"));
  if (seed == 0) {
    seed = GetValue<int64_t>(node_prim->GetAttr("Seed1"));
    if (seed == 0) {
      seed = seed_++;
    }
  }

  // Create a uniform_real node to generate the random values. Its single input is a 1-D int64 tensor
  // holding the dropout input's device shape. The data comes from shape_i64 (int64_t elements) so it
  // really matches the declared kNumberTypeInt64 type regardless of sizeof(size_t), and data() stays
  // well-defined for a rank-0 shape where &shape[0] on an empty vector would not.
  auto shape_tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt64, ShapeVector(1, SizeToLong(shape_i64.size())),
                                                       static_cast<void *>(shape_i64.data()), kNumberTypeInt64);
  AnfNodePtrList uniform_real_input = {NewValueNode(prim::kPrimCudnnUniformReal), NewValueNode(shape_tensor)};
  uniform_real_input[1]->set_abstract(shape_tensor->ToAbstract());
  uniform_real_input[1]->set_kernel_info(std::make_shared<device::KernelInfo>());
  auto uniform_real_node = func_graph->NewCNode(uniform_real_input);
  SetNodeAttrSafely("seed", MakeValue(seed), uniform_real_node);
  AnfAlgo::SetNodeAttr("seed2", MakeValue(static_cast<int64_t>(0)), uniform_real_node);
  uniform_real_node->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat32, shape_i64));

  // Select a CUDA kernel for uniform_real: one DEFAULT-format input, one DEFAULT-format float32 output.
  // NOTE(review): the input device type is declared int32 while the shape tensor above is int64;
  // left unchanged to preserve kernel selection — confirm against the CudnnUniformReal GPU kernel.
  auto uniform_real_kernel_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  uniform_real_kernel_info_builder->SetInputsFormat({kOpFormat_DEFAULT});
  uniform_real_kernel_info_builder->SetInputsDeviceType({kNumberTypeInt32});
  uniform_real_kernel_info_builder->SetOutputsFormat({kOpFormat_DEFAULT});
  uniform_real_kernel_info_builder->SetOutputsDeviceType({kNumberTypeFloat32});
  uniform_real_kernel_info_builder->SetKernelType(KernelType::UNKNOWN_KERNEL_TYPE);
  uniform_real_kernel_info_builder->SetProcessor(kernel::Processor::CUDA);
  AnfAlgo::SetSelectKernelBuildInfo(uniform_real_kernel_info_builder->Build(), uniform_real_node.get());

  // Create the GkDropout node: input 1 is the original data input, input 2 is the uniform_real output.
  AnfNodePtrList gkdropout_inputs = {NewValueNode(prim::kPrimGkDropout), cnode->input(1), uniform_real_node};
  auto new_dropout_node = func_graph->NewCNode(gkdropout_inputs);
  SetNodeAttrSafely("keep_prob", MakeValue(AnfAlgo::GetNodeAttr<float>(cnode, "keep_prob")), new_dropout_node);

  // The outputs are unchanged, so reuse the original abstract and the original kernel info,
  // extending its inputs with the new DEFAULT-format float32 random tensor.
  new_dropout_node->set_abstract(node->abstract());
  auto old_kernel_info = AnfAlgo::GetSelectKernelBuildInfo(node);
  MS_EXCEPTION_IF_NULL(old_kernel_info);
  auto dropout_kernel_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(old_kernel_info);
  dropout_kernel_info_builder->SetInputsFormat({old_kernel_info->GetInputFormat(0), kOpFormat_DEFAULT});
  dropout_kernel_info_builder->SetInputsDeviceType({old_kernel_info->GetInputDeviceType(0), kNumberTypeFloat32});
  AnfAlgo::SetSelectKernelBuildInfo(dropout_kernel_info_builder->Build(), new_dropout_node.get());
  return new_dropout_node;
}

/**
 * Expands a Dropout node: rewrites it into GkDropout via PreProcess (within the node's own graph),
 * then delegates the rewritten node to PyExpander::Run.
 *
 * @param node The Dropout CNode to expand; must not be null.
 * @return Whatever PyExpander::Run returns for the GkDropout node.
 */
AnfNodePtr DropoutExpander::Run(const AnfNodePtr &node) {
  // Check before dereferencing: PreProcess's own null check would come too late here.
  MS_EXCEPTION_IF_NULL(node);
  auto gkdropout_node = PreProcess(node->func_graph(), node);
  return PyExpander::Run(gkdropout_node);
}
}  // namespace graphkernel
}  // namespace mindspore