You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

kernel.cc 7.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "kernel/kernel.h"
  17. #include <algorithm>
  18. #include <stack>
  19. #include <utility>
  20. #include "utils/ms_context.h"
  21. #include "utils/anf_utils.h"
  22. #include "utils/ms_device_shape_transfer.h"
  23. #include "backend/common/session/anf_runtime_algorithm.h"
  24. #include "backend/common/optimizer/helper.h"
  25. namespace mindspore {
  26. namespace kernel {
  // Dimension sentinel: InferShapeForDefiniteOutputNode skips its fast path when any dimension equals this value.
  // Presumably MindSpore's "unknown rank" marker -- TODO confirm against the shape conventions in abstract/.
  constexpr int64_t kInvalidShape{-2};
  28. void KernelMod::InferShape() {
  29. auto node = anf_node_.lock();
  30. MS_EXCEPTION_IF_NULL(node);
  31. if (!node->isa<CNode>()) {
  32. MS_LOG(EXCEPTION) << "anfnode is not a cnode";
  33. }
  34. auto cnode = node->cast<CNodePtr>();
  35. MS_EXCEPTION_IF_NULL(cnode);
  36. MS_LOG(INFO) << "InferShape start, node:" << cnode->fullname_with_scope();
  37. GetDepndLists(cnode);
  38. auto ret = InferShapeForDefiniteOutputNode(cnode);
  39. if (ret) {
  40. return;
  41. }
  42. depend_tensor_map_.clear();
  43. auto inputs = cnode->inputs();
  44. if (inputs.empty()) {
  45. MS_LOG(EXCEPTION) << "Invalid inputs";
  46. }
  47. auto context = MsContext::GetInstance();
  48. MS_EXCEPTION_IF_NULL(context);
  49. AbstractBasePtrList args_spec_list;
  50. auto primitive = GetValueNode<PrimitivePtr>(inputs[0]);
  51. auto input_size = AnfAlgo::GetInputTensorNum(cnode);
  52. std::vector<AnfNodePtr> input_nodes;
  53. for (size_t i = 0; i < input_size; i++) {
  54. auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(cnode, i);
  55. auto real_input = input_node_with_index.first;
  56. MS_EXCEPTION_IF_NULL(real_input);
  57. auto cnode_input = cnode->input(i + 1);
  58. MS_EXCEPTION_IF_NULL(cnode_input);
  59. InferShapeForNopNode(&real_input);
  60. if (depend_list_.find(i) != depend_list_.end()) {
  61. auto pre_node_with_index = AnfAlgo::GetPrevNodeOutput(cnode, i);
  62. bool skip_nop_node = !context->get_param<bool>(MS_CTX_ENABLE_MINDRT);
  63. auto output_addr = AnfAlgo::GetPrevNodeMutableOutputAddr(cnode, i, skip_nop_node);
  64. std::vector<int64_t> shapes =
  65. trans::GetRuntimePaddingShape(pre_node_with_index.first, pre_node_with_index.second);
  66. auto host_type = AnfAlgo::GetOutputInferDataType(pre_node_with_index.first, pre_node_with_index.second);
  67. auto out_tensor = std::make_shared<tensor::Tensor>(host_type, shapes);
  68. MS_EXCEPTION_IF_NULL(out_tensor);
  69. // The second parameter must be false, otherwise the device address cannot be released and allocated, and the
  70. // address size will be wrong in the dynamic shape scenario.
  71. out_tensor->set_device_address(output_addr, false);
  72. auto ret2 = depend_tensor_map_.try_emplace(i, out_tensor);
  73. if (!ret2.second) {
  74. MS_LOG(EXCEPTION) << "Insert map failed";
  75. }
  76. out_tensor->data_sync();
  77. auto lock = AnfUtils::GetAbstractLock(real_input.get());
  78. MS_EXCEPTION_IF_NULL(real_input->abstract());
  79. auto real_abs = real_input->abstract()->Clone();
  80. if (real_abs->isa<abstract::AbstractTensor>()) {
  81. real_abs->set_value(out_tensor);
  82. } else if (real_abs->isa<abstract::AbstractTuple>()) {
  83. auto tuple_get_item_index = AnfAlgo::GetTupleGetItemOutIndex(cnode_input->cast<CNodePtr>());
  84. auto abstract_tuple = real_abs->cast<abstract::AbstractTuplePtr>();
  85. MS_EXCEPTION_IF_NULL(abstract_tuple);
  86. auto tuple_elements = abstract_tuple->elements()[tuple_get_item_index];
  87. tuple_elements->set_value(out_tensor);
  88. }
  89. real_input->set_abstract(real_abs);
  90. }
  91. bool is_cnode_input = AnfAlgo::AddArgList(&args_spec_list, cnode_input, real_input, i);
  92. if (is_cnode_input) {
  93. input_nodes.push_back(cnode_input);
  94. } else {
  95. input_nodes.push_back(real_input);
  96. }
  97. }
  98. std::vector<AbstractScope> locks;
  99. std::transform(input_nodes.begin(), input_nodes.end(), std::back_inserter(locks),
  100. [](const AnfNodePtr &input) { return AnfUtils::GetAbstractLock(input.get()); });
  101. auto eval_result = opt::CppInferShape(primitive, args_spec_list);
  102. locks.clear();
  103. // cppcheck-suppress unreadVariable
  104. auto lock = AnfUtils::GetAbstractLock(cnode.get());
  105. cnode->set_abstract(eval_result);
  106. }
  107. bool KernelMod::InferShapeForDefiniteOutputNode(const CNodePtr &cnode) {
  108. MS_EXCEPTION_IF_NULL(cnode);
  109. if (!AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimShape)) {
  110. return false;
  111. }
  112. auto input_size = AnfAlgo::GetInputTensorNum(cnode);
  113. if (input_size != 1) {
  114. MS_LOG(EXCEPTION) << "Node only has one input: " << cnode->fullname_with_scope();
  115. }
  116. auto cur_shape = dynamic_cast<mindspore::abstract::Shape *>(cnode->Shape().get())->shape();
  117. if (std::any_of(cur_shape.begin(), cur_shape.end(), [](int64_t x) { return x == kInvalidShape; })) {
  118. return false;
  119. }
  120. std::vector<int64_t> output_shape = {static_cast<int64_t>(cur_shape.size())};
  121. mindspore::abstract::BaseShapePtr shape = std::make_shared<mindspore::abstract::Shape>(output_shape);
  122. auto lock = AnfUtils::GetAbstractLock(cnode.get());
  123. auto abstract = cnode->abstract()->Clone();
  124. MS_EXCEPTION_IF_NULL(abstract);
  125. abstract->set_shape(shape);
  126. cnode->set_abstract(abstract);
  127. return true;
  128. }
  129. void KernelMod::InferShapeForNopNode(AnfNodePtr *input_node) {
  130. MS_EXCEPTION_IF_NULL(*input_node);
  131. if (!opt::IsNopNode(*input_node) || !AnfAlgo::IsDynamicShape(*input_node)) {
  132. MS_LOG(INFO) << "Input node is not a nop node, no need infer.";
  133. return;
  134. }
  135. MS_LOG(INFO) << "Infer shape for nop node.";
  136. std::stack<AnfNodePtr> nop_road;
  137. nop_road.push(*input_node);
  138. /*lint -e716*/
  139. while (true) {
  140. auto input_node_with_idx = AnfAlgo::GetPrevNodeOutput(*input_node, 0);
  141. auto in_node = input_node_with_idx.first;
  142. MS_EXCEPTION_IF_NULL(in_node);
  143. if (opt::IsNopNode(in_node)) {
  144. nop_road.push(in_node);
  145. *input_node = in_node;
  146. } else {
  147. break;
  148. }
  149. }
  150. /*lint +e716*/
  151. while (!nop_road.empty()) {
  152. auto nop_node = nop_road.top();
  153. MS_EXCEPTION_IF_NULL(nop_node);
  154. AnfAlgo::InferShape(nop_node->cast<CNodePtr>());
  155. nop_road.pop();
  156. }
  157. }
  158. void KernelMod::GetDepndLists(const CNodePtr &cnode) {
  159. MS_EXCEPTION_IF_NULL(cnode);
  160. if (depend_list_.size() != 0) {
  161. return;
  162. }
  163. auto ret = abstract::GetDependsFormMap(cnode);
  164. if (ret.empty()) {
  165. MS_LOG(DEBUG) << "No dynamic_shape_depends found";
  166. return;
  167. }
  168. MS_LOG(INFO) << "Have depends";
  169. (void)std::transform(ret.begin(), ret.end(), std::inserter(depend_list_, depend_list_.begin()),
  170. [](const int64_t &value) { return static_cast<int>(value); });
  171. MS_LOG(INFO) << "Init End";
  172. }
  173. } // namespace kernel
  174. } // namespace mindspore