You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

anf.cc 7.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "ir/anf.h"
  19. #include <algorithm>
  20. #include <sstream>
  21. #include <vector>
  22. #include <unordered_map>
  23. #include "ir/visitor.h"
  24. #include "pipeline/static_analysis/static_analysis.h"
  25. #include "operator/ops.h"
  26. #include "parallel/ops_info/ops_utils.h"
  27. namespace mindspore {
  28. // namespace to support intermediate representation definition
  29. // Methods of AnfNode
  30. TypePtr AnfNode::Type() const { return (abstract_ == nullptr) ? nullptr : abstract_->BuildType(); }
  31. BaseShapePtr AnfNode::Shape() const { return (abstract_ == nullptr) ? nullptr : abstract_->BuildShape(); }
  32. std::string AnfNode::ToString() const {
  33. return mindspore::label_manage::Label(const_cast<AnfNode *>(this)->shared_from_base<AnfNode>()->debug_info());
  34. }
  35. CNode::CNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph)
  36. : AnfNode(func_graph), inputs_(inputs), stop_gradient_(false) {}
  37. // Check if CNode is an apply with the specific Primitive.
  38. bool CNode::IsApply(const PrimitivePtr &value) const {
  39. if (value == nullptr) {
  40. return false;
  41. }
  42. if (inputs_.size() != 0 && IsValueNode<Primitive>(inputs_[0])) {
  43. PrimitivePtr fn_value = GetValueNode<PrimitivePtr>(inputs_[0]);
  44. if (fn_value->Hash() == value->Hash() && fn_value->name() == value->name()) {
  45. return true;
  46. }
  47. }
  48. return false;
  49. }
  50. void CNode::set_input(size_t i, const AnfNodePtr &new_input) { inputs_[i] = new_input; }
  51. std::string CNode::DebugString(int recursive_level) const {
  52. std::ostringstream buffer;
  53. if (recursive_level > 0) {
  54. if (func_graph() != nullptr) {
  55. buffer << func_graph()->ToString() << ":";
  56. }
  57. buffer << ToString() << "{";
  58. bool is_first_node = true;
  59. int idx = 0;
  60. for (auto &node : inputs_) {
  61. MS_EXCEPTION_IF_NULL(node);
  62. if (is_first_node) {
  63. is_first_node = false;
  64. } else {
  65. buffer << ", ";
  66. }
  67. buffer << "[" << idx << "]: " << node->DebugString(recursive_level - 1);
  68. idx++;
  69. }
  70. buffer << "}";
  71. } else {
  72. buffer << ToString();
  73. }
  74. return buffer.str();
  75. }
  76. OperatorInfoPtr CNode::set_operator_info(const OperatorInfoPtr &operator_info) {
  77. if (operator_info_ != nullptr) {
  78. MS_LOG(WARNING) << "The CNode: " << ToString() << " has already been set OperatorInfo: " << operator_info_->name()
  79. << ", using the new one: " << operator_info->name();
  80. auto old_ptr = operator_info_;
  81. operator_info_ = operator_info;
  82. return old_ptr;
  83. }
  84. operator_info_ = operator_info;
  85. return nullptr;
  86. }
  87. std::string CNode::fullname_with_scope() {
  88. // if full name is set, return its name immediately
  89. if (!fullname_with_scope_.empty()) {
  90. return fullname_with_scope_;
  91. }
  92. if (IsApply(prim::kPrimScalarSummary) || IsApply(prim::kPrimTensorSummary) || IsApply(prim::kPrimImageSummary) ||
  93. IsApply(prim::kPrimHistogramSummary)) {
  94. std::string tag = GetValue<std::string>(GetValueNode(input(1)));
  95. if (tag == "") {
  96. MS_LOG(EXCEPTION) << "The tag name is null, should be valid string";
  97. }
  98. std::string name;
  99. if (IsApply(prim::kPrimScalarSummary)) {
  100. name = tag + "[:Scalar]";
  101. } else if (IsApply(prim::kPrimImageSummary)) {
  102. name = tag + "[:Image]";
  103. } else if (IsApply(prim::kPrimHistogramSummary)) {
  104. name = tag + "[:Histogram]";
  105. } else {
  106. name = tag + "[:Tensor]";
  107. }
  108. fullname_with_scope_ = name;
  109. } else {
  110. // cnode input 0 should be primitive ptr
  111. auto value_ptr = input(0)->cast<ValueNodePtr>();
  112. if (value_ptr == nullptr) {
  113. MS_LOG(WARNING) << "Input 0 of cnode is not a value node, its type is " << input(0)->type_name() << ".";
  114. fullname_with_scope_ = id_generator::get_id(shared_from_base<CNode>());
  115. return fullname_with_scope_;
  116. }
  117. auto input_value = value_ptr->value();
  118. if (input_value == nullptr) {
  119. MS_LOG(WARNING) << "Value of input 0 of cnode is nullptr.";
  120. fullname_with_scope_ = id_generator::get_id(shared_from_base<CNode>());
  121. return fullname_with_scope_;
  122. }
  123. PrimitivePtr prim = GetValue<PrimitivePtr>(input_value);
  124. MS_EXCEPTION_IF_NULL(scope());
  125. MS_EXCEPTION_IF_NULL(prim);
  126. fullname_with_scope_ =
  127. scope()->name() + "/" + prim->name() + "-op" + id_generator::get_id(shared_from_base<CNode>());
  128. }
  129. return fullname_with_scope_;
  130. }
  131. std::string ValueNode::ToString() const {
  132. MS_EXCEPTION_IF_NULL(value_);
  133. if (value_->isa<FuncGraph>()) {
  134. return value_->cast<FuncGraphPtr>()->ToString();
  135. }
  136. std::ostringstream buffer;
  137. buffer << AnfNode::ToString();
  138. buffer << "(" << value_->ToString() << ")";
  139. return buffer.str();
  140. }
  141. std::string ValueNode::DebugString(int) const {
  142. MS_EXCEPTION_IF_NULL(value_);
  143. std::ostringstream buffer;
  144. buffer << "ValueNode<" << value_->type_name() << "> " << value_->ToString();
  145. return buffer.str();
  146. }
  147. std::string ValueNode::fullname_with_scope() {
  148. if (!fullname_with_scope_.empty()) {
  149. return fullname_with_scope_;
  150. }
  151. MS_EXCEPTION_IF_NULL(scope());
  152. fullname_with_scope_ = scope()->name() + "/" + "data-" + id_generator::get_id(shared_from_base<ValueNode>());
  153. return fullname_with_scope_;
  154. }
  155. void CNode::accept(AnfVisitor *v) { v->Visit(shared_from_base<CNode>()); }
  156. void ValueNode::accept(AnfVisitor *v) { v->Visit(shared_from_base<ValueNode>()); }
  157. void Parameter::accept(AnfVisitor *v) { v->Visit(shared_from_base<Parameter>()); }
  158. bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value) {
  159. MS_EXCEPTION_IF_NULL(node);
  160. auto cnode = node->cast<CNodePtr>();
  161. if (cnode != nullptr) {
  162. return cnode->IsApply(value);
  163. }
  164. return false;
  165. }
  166. PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node) {
  167. if (node == nullptr) {
  168. return nullptr;
  169. }
  170. auto cnode = node->cast<CNodePtr>();
  171. if (cnode != nullptr) {
  172. if (cnode->size() > 0) {
  173. auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
  174. return prim;
  175. }
  176. }
  177. return nullptr;
  178. }
  179. std::string GetCNodeFuncName(const CNodePtr cnode) {
  180. if (cnode->inputs().empty()) {
  181. return "";
  182. }
  183. AnfNodePtr valuenode = cnode->input(0);
  184. if (valuenode->isa<ValueNode>()) {
  185. auto value = GetValueNode(valuenode);
  186. // check whether the valuenode is primitive
  187. if (value->isa<Primitive>()) {
  188. return value->cast<PrimitivePtr>()->name();
  189. }
  190. return value->ToString();
  191. }
  192. return "";
  193. }
  194. bool IsPrimitive(const AnfNodePtr &node, const PrimitivePtr &value) {
  195. if (IsValueNode<Primitive>(node)) {
  196. PrimitivePtr fn_value = GetValueNode<PrimitivePtr>(node);
  197. MS_EXCEPTION_IF_NULL(value);
  198. if (fn_value->Hash() == value->Hash() && fn_value->name() == value->name()) {
  199. return true;
  200. }
  201. }
  202. return false;
  203. }
  204. namespace id_generator {
  205. static std::unordered_map<std::string, int> node_ids;
  206. std::string get_id(const AnfNodePtr &node) {
  207. auto type_name = node->type_name();
  208. if (node_ids.find(type_name) == node_ids.end()) {
  209. node_ids[type_name] = 0;
  210. } else {
  211. node_ids[type_name]++;
  212. }
  213. return std::to_string(node_ids[type_name]);
  214. }
  215. void reset_id() { node_ids.clear(); }
  216. } // namespace id_generator
  217. } // namespace mindspore