You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

kernel.cc 6.8 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "kernel/kernel.h"
  17. #include <algorithm>
  18. #include <stack>
  19. #include "utils/ms_context.h"
  20. #include "utils/anf_utils.h"
  21. #include "runtime/device/ms_device_shape_transfer.h"
  22. #include "backend/common/session/anf_runtime_algorithm.h"
  23. #include "include/common/utils/anfalgo.h"
  24. #include "backend/common/optimizer/helper.h"
  25. namespace mindspore {
  26. namespace kernel {
  // Sentinel dimension value: a shape containing this value is not handled by the definite-output
  // fast path in InferShapeForDefiniteOutputNode.
  // NOTE(review): presumably -2 marks a dynamic-rank shape in the framework's convention -- confirm.
  constexpr int64_t kInvalidShape = -2;
  28. void KernelMod::SetAtomicCleanNodes(const std::vector<CNodePtr> &atomic_clean_node) {
  29. atomic_clean_nodes_.resize(atomic_clean_node.size());
  30. for (size_t i = 0; i < atomic_clean_node.size(); ++i) {
  31. atomic_clean_nodes_[i] = atomic_clean_node[i];
  32. }
  33. }
  34. void KernelMod::InferShape() {
  35. auto node = anf_node_.lock();
  36. MS_EXCEPTION_IF_NULL(node);
  37. auto cnode = node->cast<CNodePtr>();
  38. MS_EXCEPTION_IF_NULL(cnode);
  39. MS_LOG(INFO) << "InferShape start, node:" << cnode->fullname_with_scope();
  40. GetDepndLists(cnode);
  41. auto ret = InferShapeForDefiniteOutputNode(cnode);
  42. if (ret) {
  43. return;
  44. }
  45. depend_tensor_map_.clear();
  46. auto &inputs = cnode->inputs();
  47. if (inputs.empty()) {
  48. MS_LOG(EXCEPTION) << "Invalid inputs.";
  49. }
  50. auto context = MsContext::GetInstance();
  51. MS_EXCEPTION_IF_NULL(context);
  52. AbstractBasePtrList args_spec_list;
  53. auto primitive = GetValueNode<PrimitivePtr>(inputs[0]);
  54. auto input_size = common::AnfAlgo::GetInputTensorNum(cnode);
  55. for (size_t i = 0; i < input_size; i++) {
  56. auto input_node_with_index = common::AnfAlgo::GetPrevNodeOutput(cnode, i);
  57. auto real_input = input_node_with_index.first;
  58. MS_EXCEPTION_IF_NULL(real_input);
  59. auto cnode_input = cnode->input(i + 1);
  60. MS_EXCEPTION_IF_NULL(cnode_input);
  61. InferShapeForNopNode(&real_input);
  62. if (depend_list_.find(i) != depend_list_.end()) {
  63. auto pre_node_with_index = common::AnfAlgo::GetPrevNodeOutput(cnode, i);
  64. bool skip_nop_node = !context->get_param<bool>(MS_CTX_ENABLE_MINDRT);
  65. auto output_addr = AnfAlgo::GetPrevNodeMutableOutputAddr(cnode, i, skip_nop_node);
  66. std::vector<int64_t> shapes =
  67. trans::GetRuntimePaddingShape(pre_node_with_index.first, pre_node_with_index.second);
  68. auto host_type = common::AnfAlgo::GetOutputInferDataType(pre_node_with_index.first, pre_node_with_index.second);
  69. auto out_tensor = std::make_shared<tensor::Tensor>(host_type, shapes);
  70. MS_EXCEPTION_IF_NULL(out_tensor);
  71. // The second parameter must be false, otherwise the device address cannot be released and allocated, and the
  72. // address size will be wrong in the dynamic shape scenario.
  73. out_tensor->set_device_address(output_addr, false);
  74. auto ret2 = depend_tensor_map_.try_emplace(i, out_tensor);
  75. if (!ret2.second) {
  76. MS_LOG(EXCEPTION) << "Insert map failed.";
  77. }
  78. out_tensor->data_sync();
  79. // cppcheck-suppress unreadVariable
  80. auto lock = AnfUtils::GetAbstractLock(real_input.get());
  81. auto real_abs = real_input->abstract();
  82. if (real_abs->isa<abstract::AbstractTensor>()) {
  83. real_abs->set_value(out_tensor);
  84. } else if (real_abs->isa<abstract::AbstractTuple>()) {
  85. auto tuple_get_item_index = common::AnfAlgo::GetTupleGetItemOutIndex(cnode_input->cast<CNodePtr>());
  86. auto abstract_tuple = real_abs->cast<abstract::AbstractTuplePtr>();
  87. MS_EXCEPTION_IF_NULL(abstract_tuple);
  88. auto tuple_elements = abstract_tuple->elements()[tuple_get_item_index];
  89. tuple_elements->set_value(out_tensor);
  90. }
  91. }
  92. common::AnfAlgo::AddArgList(&args_spec_list, cnode_input, real_input, i);
  93. }
  94. auto eval_result = opt::CppInferShape(primitive, args_spec_list);
  95. cnode->set_abstract(eval_result);
  96. }
  97. bool KernelMod::InferShapeForDefiniteOutputNode(const CNodePtr &cnode) {
  98. MS_EXCEPTION_IF_NULL(cnode);
  99. if (!common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimShape)) {
  100. return false;
  101. }
  102. auto input_size = common::AnfAlgo::GetInputTensorNum(cnode);
  103. if (input_size != 1) {
  104. MS_LOG(EXCEPTION) << "Node only has one input: " << cnode->fullname_with_scope();
  105. }
  106. auto cur_shape = dynamic_cast<mindspore::abstract::Shape *>(cnode->Shape().get())->shape();
  107. if (std::any_of(cur_shape.begin(), cur_shape.end(), [](int64_t x) { return x == kInvalidShape; })) {
  108. return false;
  109. }
  110. std::vector<int64_t> output_shape = {static_cast<int64_t>(cur_shape.size())};
  111. mindspore::abstract::BaseShapePtr shape = std::make_shared<mindspore::abstract::Shape>(output_shape);
  112. // cppcheck-suppress unreadVariable
  113. auto lock = AnfUtils::GetAbstractLock(cnode.get());
  114. auto abstract = cnode->abstract();
  115. MS_EXCEPTION_IF_NULL(abstract);
  116. abstract->set_shape(shape);
  117. return true;
  118. }
  119. void KernelMod::InferShapeForNopNode(AnfNodePtr *input_node) {
  120. MS_EXCEPTION_IF_NULL(*input_node);
  121. if (!common::AnfAlgo::IsNopNode(*input_node) || !common::AnfAlgo::IsDynamicShape(*input_node)) {
  122. MS_LOG(INFO) << "Input node is not a nop node, no need infer.";
  123. return;
  124. }
  125. MS_LOG(INFO) << "Infer shape for nop node.";
  126. std::stack<AnfNodePtr> nop_road;
  127. nop_road.push(*input_node);
  128. /*lint -e716*/
  129. while (true) {
  130. auto input_node_with_idx = common::AnfAlgo::GetPrevNodeOutput(*input_node, 0);
  131. auto in_node = input_node_with_idx.first;
  132. MS_EXCEPTION_IF_NULL(in_node);
  133. if (common::AnfAlgo::IsNopNode(in_node)) {
  134. nop_road.push(in_node);
  135. *input_node = in_node;
  136. } else {
  137. break;
  138. }
  139. }
  140. /*lint +e716*/
  141. while (!nop_road.empty()) {
  142. auto nop_node = nop_road.top();
  143. MS_EXCEPTION_IF_NULL(nop_node);
  144. AnfAlgo::InferShape(nop_node->cast<CNodePtr>());
  145. nop_road.pop();
  146. }
  147. }
  148. void KernelMod::GetDepndLists(const CNodePtr &cnode) {
  149. MS_EXCEPTION_IF_NULL(cnode);
  150. if (depend_list_.size() != 0) {
  151. return;
  152. }
  153. auto ret = abstract::GetDependsFormMap(cnode);
  154. if (ret.empty()) {
  155. MS_LOG(DEBUG) << "No dynamic_shape_depends found.";
  156. return;
  157. }
  158. MS_LOG(INFO) << "Have depends.";
  159. (void)std::transform(ret.begin(), ret.end(), std::inserter(depend_list_, depend_list_.begin()),
  160. [](const int64_t &value) { return static_cast<int>(value); });
  161. MS_LOG(INFO) << "Init End.";
  162. }
  163. } // namespace kernel
  164. } // namespace mindspore