You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

substitute_dropout.cc 5.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "common/graph_kernel/substitute_dropout.h"
  17. #include <vector>
  18. #include <string>
  19. #include <algorithm>
  20. #include <memory>
  21. #include "base/core_ops.h"
  22. #include "utils/utils.h"
  23. #include "backend/common/optimizer/helper.h"
  24. #include "common/graph_kernel/graph_kernel_helper.h"
  25. #include "backend/common/session/anf_runtime_algorithm.h"
  26. #include "ir/tensor.h"
  27. #include "kernel/kernel_build_info.h"
  28. #include "runtime/device/kernel_info.h"
  29. namespace mindspore {
  30. namespace prim {
  31. inline const PrimitivePtr kPrimGkDropout = std::make_shared<Primitive>("GkDropout");
  32. } // namespace prim
  33. namespace graphkernel {
  34. using opt::CheckCNodeInputSize;
  35. using opt::kDropoutInputTensorNum;
  36. int64_t DropoutExpander::seed_ = time(nullptr);
  37. AnfNodePtr DropoutExpander::PreProcess(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
  38. MS_EXCEPTION_IF_NULL(node);
  39. CNodePtr cnode = node->cast<CNodePtr>();
  40. MS_EXCEPTION_IF_NULL(cnode);
  41. CheckCNodeInputSize(cnode, kDropoutInputTensorNum);
  42. auto shape = AnfAlgo::GetInputDeviceShape(cnode, 0);
  43. ShapeVector shape_i64;
  44. std::transform(shape.begin(), shape.end(), std::back_inserter(shape_i64), [](size_t x) { return SizeToLong(x); });
  45. // Get seed from original dropout's attrs, rather than set seed by time.
  46. // Only seed0 and seed1 are all equal to 0, then set seed = time.
  47. auto node_prim = GetCNodePrimitive(node);
  48. MS_EXCEPTION_IF_NULL(node_prim);
  49. int64_t seed = GetValue<int64_t>(node_prim->GetAttr("Seed0"));
  50. if (seed == 0) {
  51. seed = GetValue<int64_t>(node_prim->GetAttr("Seed1"));
  52. if (seed == 0) {
  53. seed = seed_++;
  54. }
  55. }
  56. // Create a uniform_real kernel to generate random value.
  57. auto tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt64, ShapeVector(1, SizeToLong(shape.size())),
  58. static_cast<void *>(&shape[0]), kNumberTypeInt64);
  59. AnfNodePtrList uniform_real_input = {NewValueNode(prim::kPrimCudnnUniformReal), NewValueNode(tensor)};
  60. uniform_real_input[1]->set_abstract(tensor->ToAbstract());
  61. uniform_real_input[1]->set_kernel_info(std::make_shared<device::KernelInfo>());
  62. auto uniform_real_node = func_graph->NewCNode(uniform_real_input);
  63. SetNodeAttrSafely("seed", MakeValue(seed), uniform_real_node);
  64. AnfAlgo::SetNodeAttr("seed2", MakeValue(static_cast<int64_t>(0)), uniform_real_node);
  65. uniform_real_node->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat32, shape_i64));
  66. // Set kernel_info for uniform_real node
  67. auto uniform_real_kernel_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  68. uniform_real_kernel_info_builder->SetInputsFormat({kOpFormat_DEFAULT});
  69. uniform_real_kernel_info_builder->SetInputsDeviceType({kNumberTypeInt32});
  70. uniform_real_kernel_info_builder->SetOutputsFormat({kOpFormat_DEFAULT});
  71. uniform_real_kernel_info_builder->SetOutputsDeviceType({kNumberTypeFloat32});
  72. uniform_real_kernel_info_builder->SetKernelType(KernelType::UNKNOWN_KERNEL_TYPE);
  73. uniform_real_kernel_info_builder->SetProcessor(kernel::Processor::CUDA);
  74. AnfAlgo::SetSelectKernelBuildInfo(uniform_real_kernel_info_builder->Build(), uniform_real_node.get());
  75. // Create a GKDropout node with uniform_real as its second input.
  76. AnfNodePtrList gkdropout_inputs = {NewValueNode(prim::kPrimGkDropout), cnode->input(1), uniform_real_node};
  77. auto new_dropout_node = func_graph->NewCNode(gkdropout_inputs);
  78. SetNodeAttrSafely("keep_prob", MakeValue(AnfAlgo::GetNodeAttr<float>(cnode, "keep_prob")), new_dropout_node);
  79. // the output info is unchanged.
  80. new_dropout_node->set_abstract(node->abstract());
  81. auto old_kernel_info = AnfAlgo::GetSelectKernelBuildInfo(node);
  82. auto dropout_kernel_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(old_kernel_info);
  83. dropout_kernel_info_builder->SetInputsFormat({old_kernel_info->GetInputFormat(0), kOpFormat_DEFAULT});
  84. dropout_kernel_info_builder->SetInputsDeviceType({old_kernel_info->GetInputDeviceType(0), kNumberTypeFloat32});
  85. AnfAlgo::SetSelectKernelBuildInfo(dropout_kernel_info_builder->Build(), new_dropout_node.get());
  86. return new_dropout_node;
  87. }
  88. AnfNodePtr DropoutExpander::Run(const AnfNodePtr &node) {
  89. auto gkdropout_node = PreProcess(node->func_graph(), node);
  90. return PyExpander::Run(gkdropout_node);
  91. }
  92. } // namespace graphkernel
  93. } // namespace mindspore