You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

generate_graph.cc 6.2 kB

5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "parallel/graph_util/generate_graph.h"
  17. #include <algorithm>
  18. #include <memory>
  19. #include <string>
  20. #include <utility>
  21. using mindspore::tensor::Tensor;
  22. namespace mindspore {
  23. namespace parallel {
  24. std::string GetOpPythonPath(const OperatorName &op_name) {
  25. // almost all ops are defined in two main paths
  26. const std::string ops_module = OP_PATH;
  27. const std::string inner_ops_module = INNER_OP_PATH;
  28. py::module mod = py::module::import(common::SafeCStr(ops_module));
  29. py::module inner_mod = py::module::import(common::SafeCStr(inner_ops_module));
  30. if (!py::hasattr(mod, common::SafeCStr(op_name))) {
  31. if (!py::hasattr(inner_mod, common::SafeCStr(op_name))) {
  32. MS_LOG(EXCEPTION) << ops_module << " or " << inner_ops_module << " don't have op:" << op_name;
  33. }
  34. return inner_ops_module;
  35. }
  36. return ops_module;
  37. }
  38. ValuePtr CreatOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name) {
  39. std::string op_path = GetOpPythonPath(op_name);
  40. py::module mod = py::module::import(common::SafeCStr(op_path));
  41. if (!py::hasattr(mod, common::SafeCStr(op_name))) {
  42. MS_LOG(ERROR) << "Failure: op_path:" << op_path << " don't have attr " << op_name;
  43. return nullptr;
  44. }
  45. std::vector<py::object> arg_list;
  46. (void)std::transform(attrs.begin(), attrs.end(), std::back_inserter(arg_list),
  47. [](const Attr &attr) { return ValuePtrToPyData(attr.second); });
  48. py::object obj =
  49. parse::python_adapter::CallPyFn(GET_OP_FUNCTION_PATH, GET_OP_FUNCTION, op_name, op_path, instance_name, arg_list);
  50. ValuePtr op_instance = nullptr;
  51. bool succ = parse::ConvertData(obj, &op_instance);
  52. if (!succ) {
  53. MS_LOG(ERROR) << "Failure:get Python op " << op_path << " from " << op_name << " fail";
  54. return nullptr;
  55. }
  56. return op_instance;
  57. }
  58. AnfNodePtr ValuePtrToAnfNodePtr(const ValuePtr &value_ptr) {
  59. auto value_node = NewValueNode(value_ptr);
  60. MS_EXCEPTION_IF_NULL(value_node);
  61. return value_node->cast<AnfNodePtr>();
  62. }
  63. static std::unordered_map<int32_t, AnfNodePtr> int_tensor_map = {};
  64. AnfNodePtr CreateInt32Tensor(int32_t value) {
  65. auto it = int_tensor_map.find(value);
  66. if (it != int_tensor_map.end()) {
  67. return it->second;
  68. }
  69. mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<tensor::Tensor>(py::int_(value), kInt32);
  70. ValuePtr value_ptr = MakeValue(tensor_ptr);
  71. auto anf_node_ptr = ValuePtrToAnfNodePtr(value_ptr);
  72. int_tensor_map[value] = anf_node_ptr;
  73. return anf_node_ptr;
  74. }
  75. AnfNodePtr CreatTypeInt(int32_t value) {
  76. ValuePtr value_ptr = MakeValue(std::make_shared<Int>(value));
  77. return ValuePtrToAnfNodePtr(value_ptr);
  78. }
  79. AnfNodePtr CreatInt32Imm(int32_t value) {
  80. ValuePtr value_ptr = MakeValue(std::make_shared<Int32Imm>(value));
  81. return ValuePtrToAnfNodePtr(value_ptr);
  82. }
  83. std::string GetInstanceNameByCNode(const CNodePtr &cnode) {
  84. PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
  85. if (!prim) {
  86. MS_LOG(EXCEPTION) << "The first input of the cnode is not a PrimitivePtr.";
  87. }
  88. std::string instance_name = prim->instance_name();
  89. return HashInstanceName(instance_name);
  90. }
  91. std::string HashInstanceName(const std::string &name) {
  92. auto using_hash_name = common::GetEnv(USING_HASH_NAME);
  93. std::string instance_name;
  94. if ((using_hash_name.empty()) || (using_hash_name == "on")) {
  95. instance_name = HashName(name);
  96. } else {
  97. instance_name = name;
  98. }
  99. return instance_name;
  100. }
  101. Status GenerateGraph::Init(const CNodePtr &cnode) {
  102. if (!cnode) {
  103. MS_LOG(ERROR) << "Init:cnode is nullptr";
  104. return FAILED;
  105. }
  106. cnode_ = cnode;
  107. func_graph_ = cnode->func_graph();
  108. if (!func_graph_) {
  109. MS_LOG(ERROR) << "Init:func_graph_ is nullptr";
  110. return FAILED;
  111. }
  112. manager_ = func_graph_->manager();
  113. if (!manager_) {
  114. MS_LOG(ERROR) << "Init:manager_ is nullptr";
  115. return FAILED;
  116. }
  117. scope_ = cnode_->scope();
  118. if (!scope_) {
  119. MS_LOG(ERROR) << "Init:scope_ is nullptr";
  120. return FAILED;
  121. }
  122. virtual_input_node_ = std::make_shared<AnfNode>(nullptr);
  123. virtual_input_node_->set_scope(scope_);
  124. instance_name_base_ = GetInstanceNameByCNode(cnode_);
  125. name_idx_ = 0;
  126. return SUCCESS;
  127. }
  128. AnfNodePtr GenerateGraph::PushBack(const std::vector<AnfNodePtr> &inputs) {
  129. CNodePtr cnode = func_graph_->NewCNode(inputs); // using NewCNode to creat anfnode
  130. MS_EXCEPTION_IF_NULL(cnode);
  131. cnode->set_scope(scope_);
  132. if (inputs.size() < 2) {
  133. MS_LOG(EXCEPTION) << "inputs.size() must be more than 1";
  134. }
  135. (void)manager_->Replace(inputs.at(1), cnode); // using Replace function to insert cnode after inputs[0]
  136. auto new_anf_node_ptr = cnode->cast<AnfNodePtr>();
  137. MS_EXCEPTION_IF_NULL(new_anf_node_ptr);
  138. return new_anf_node_ptr;
  139. }
  140. AnfNodePtr GenerateGraph::NewOpInst(const OperatorName &op_name, const OperatorAttrs &attrs) {
  141. name_idx_++;
  142. ValuePtr pyop_instance = CreatOpInstance(attrs, op_name, instance_name_base_ + op_name + std::to_string(name_idx_));
  143. if (pyop_instance == nullptr) {
  144. MS_LOG(EXCEPTION) << "Failure:" << op_name << " CreatOpInstance failed";
  145. }
  146. auto value_node = NewValueNode(pyop_instance);
  147. return value_node->cast<AnfNodePtr>();
  148. }
  149. AnfNodePtr GenerateGraph::NewOpInst(const OperatorName &op_name) {
  150. name_idx_++;
  151. OperatorAttrs attrs;
  152. ValuePtr pyop_instance = CreatOpInstance(attrs, op_name, instance_name_base_ + std::to_string(name_idx_));
  153. if (pyop_instance == nullptr) {
  154. MS_LOG(EXCEPTION) << "Failure:" << op_name << " CreatOpInstance failed";
  155. }
  156. auto value_node = NewValueNode(pyop_instance);
  157. return value_node->cast<AnfNodePtr>();
  158. }
  159. } // namespace parallel
  160. } // namespace mindspore