You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

validator.cc 5.5 kB

5 years ago
5 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "pipeline/jit/validator.h"
  19. #include <memory>
  20. #include <mutex>
  21. #include "ir/manager.h"
  22. #include "ir/dtype.h"
  23. #include "pipeline/jit/static_analysis/prim.h"
  24. #include "pipeline/jit/parse/resolve.h"
  25. namespace mindspore {
  26. namespace validator {
  27. using mindspore::abstract::AbstractBase;
  28. using mindspore::abstract::AbstractClass;
  29. using mindspore::abstract::AbstractCSRTensor;
  30. using mindspore::abstract::AbstractError;
  31. using mindspore::abstract::AbstractFunction;
  32. using mindspore::abstract::AbstractJTagged;
  33. using mindspore::abstract::AbstractList;
  34. using mindspore::abstract::AbstractRef;
  35. using mindspore::abstract::AbstractRowTensor;
  36. using mindspore::abstract::AbstractScalar;
  37. using mindspore::abstract::AbstractSparseTensor;
  38. using mindspore::abstract::AbstractTensor;
  39. using mindspore::abstract::AbstractTuple;
  40. using mindspore::abstract::AbstractType;
  41. void ValidateOperation(const AnfNodePtr &node) {
  42. if (!IsValueNode<Primitive>(node)) {
  43. return;
  44. }
  45. // Primitive must in whitelist
  46. auto prim = GetValueNode<PrimitivePtr>(node);
  47. MS_EXCEPTION_IF_NULL(prim);
  48. if (abstract::IsInWhiteList(prim)) {
  49. return;
  50. }
  51. if (prim->HasAttr("is_load")) {
  52. return;
  53. }
  54. if (prim->HasPyEvaluator()) {
  55. MS_LOG(DEBUG) << "Primitive " << prim->name() << " has python evaluator.";
  56. return;
  57. }
  58. if (prim->prim_type() == PrimType::kPrimTypePyCheck) {
  59. MS_LOG(DEBUG) << "Primitive " << prim->name() << " has python inference checking method.";
  60. return;
  61. }
  62. if (prim->name() == "fake_bprop") {
  63. MS_LOG(EXCEPTION) << "Illegal primitive: " << GetValue<std::string>(prim->GetAttr("info"));
  64. }
  65. MS_LOG(EXCEPTION) << "Illegal primitive: " << prim->name();
  66. }
  67. bool CheckAbstractScalar(const AnfNodePtr &node) {
  68. MS_EXCEPTION_IF_NULL(node);
  69. AbstractBasePtr abstract = node->abstract();
  70. if (abstract->isa<AbstractScalar>()) {
  71. TypePtr type = abstract->GetTypeTrack();
  72. MS_EXCEPTION_IF_NULL(type);
  73. if (type->isa<EnvType>()) {
  74. MS_LOG(EXCEPTION) << "Illegal type in the graph: " << abstract->ToString() << ", node: " << node->DebugString();
  75. }
  76. if (type->isa<Problem>() || type->isa<External>()) {
  77. // Only allow string type from external.
  78. if (!IsValueNode<StringImm>(node)) {
  79. // Validate a type.
  80. MS_LOG(EXCEPTION) << "Illegal type in the graph: " << abstract->ToString() << ", node: " << node->DebugString();
  81. }
  82. }
  83. return true;
  84. }
  85. return false;
  86. }
  87. void ValidateAbstract(const AnfNodePtr &node) {
  88. if (node == nullptr) {
  89. MS_LOG(DEBUG) << "Node to validate is invalid";
  90. return;
  91. }
  92. AbstractBasePtr abstract = node->abstract();
  93. if (abstract == nullptr) {
  94. MS_LOG(DEBUG) << "Abstract is null in node: " << node->DebugString();
  95. return;
  96. }
  97. if (abstract->isa<AbstractClass>() || abstract->isa<AbstractJTagged>()) {
  98. // Validate a type.
  99. MS_LOG(EXCEPTION) << "Illegal type in the graph: " << abstract->ToString() << ", node: " << node->DebugString();
  100. }
  101. if (CheckAbstractScalar(node)) {
  102. return;
  103. }
  104. if (abstract->isa<AbstractError>()) {
  105. // NOTICE: validate dead code?
  106. MS_LOG(DEBUG) << "AbstractError in the graph: " << abstract->ToString();
  107. return;
  108. }
  109. bool is_legal_abstract = abstract->isa<AbstractType>() || abstract->isa<AbstractFunction>() ||
  110. abstract->isa<AbstractTuple>() || abstract->isa<AbstractList>() ||
  111. abstract->isa<AbstractTensor>() || abstract->isa<AbstractRowTensor>() ||
  112. abstract->isa<AbstractSparseTensor>() || abstract->isa<AbstractCSRTensor>() ||
  113. abstract->isa<abstract::AbstractRefKey>() || abstract->isa<AbstractRef>() ||
  114. abstract->isa<abstract::AbstractNone>() || abstract->isa<abstract::AbstractMonad>();
  115. if (is_legal_abstract) {
  116. return;
  117. }
  118. // Other types show exception
  119. MS_LOG(EXCEPTION) << "Illegal type in the graph: " << abstract->ToString();
  120. }
  121. void ValidateValueNode(const AnfNodePtr &node) {
  122. if (node == nullptr) {
  123. MS_LOG(DEBUG) << "Node to validate is invalid";
  124. return;
  125. }
  126. // InterpretedNode should be consumed during compile, not left to Runtime.
  127. if (IsValueNode<parse::InterpretedObject>(node)) {
  128. MS_LOG(EXCEPTION) << "Should not use Python object in runtime, node: " << node->DebugString()
  129. << "\n\nWe suppose all nodes generated by JIT Fallback not return to outside of graph.";
  130. }
  131. }
  132. void Validate(const FuncGraphPtr &fg) {
  133. FuncGraphManagerPtr mgr = Manage(fg, false);
  134. MS_EXCEPTION_IF_NULL(mgr);
  135. AnfNodeSet &all_nodes = mgr->all_nodes();
  136. for (const auto &node : all_nodes) {
  137. ValidateOperation(node);
  138. ValidateValueNode(node);
  139. }
  140. for (const auto &node : all_nodes) {
  141. ValidateAbstract(node);
  142. }
  143. }
  144. } // namespace validator
  145. } // namespace mindspore