/**
 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 *
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "pipeline/jit/validator.h"

#include <memory>
#include <string>

#include "ir/manager.h"
#include "ir/dtype.h"
#include "pipeline/jit/static_analysis/prim.h"
#include "pipeline/jit/parse/resolve.h"

namespace mindspore {
namespace validator {
using mindspore::abstract::AbstractBase;
using mindspore::abstract::AbstractClass;
using mindspore::abstract::AbstractCSRTensor;
using mindspore::abstract::AbstractError;
using mindspore::abstract::AbstractFunction;
using mindspore::abstract::AbstractJTagged;
using mindspore::abstract::AbstractList;
using mindspore::abstract::AbstractRef;
using mindspore::abstract::AbstractRowTensor;
using mindspore::abstract::AbstractScalar;
using mindspore::abstract::AbstractSparseTensor;
using mindspore::abstract::AbstractTensor;
using mindspore::abstract::AbstractTuple;
using mindspore::abstract::AbstractType;

// Validates a single node that holds a Primitive value.
// Nodes that are not Primitive value nodes are ignored. A primitive is accepted when it is in
// abstract::IsInWhiteList, carries the "is_load" attribute, has a Python evaluator, or has
// prim_type kPrimTypePyCheck. Any other primitive is reported through MS_LOG(EXCEPTION);
// for "fake_bprop" the message shows the primitive's "info" attribute instead of its name.
void ValidateOperation(const AnfNodePtr &node) {
  if (!IsValueNode<Primitive>(node)) {
    return;
  }

  // The primitive must be in the white list, or match one of the exemptions below.
  auto prim = GetValueNode<PrimitivePtr>(node);
  MS_EXCEPTION_IF_NULL(prim);
  if (abstract::IsInWhiteList(prim)) {
    return;
  }
  if (prim->HasAttr("is_load")) {
    return;
  }
  if (prim->HasPyEvaluator()) {
    MS_LOG(DEBUG) << "Primitive " << prim->name() << " has python evaluator.";
    return;
  }
  if (prim->prim_type() == PrimType::kPrimTypePyCheck) {
    MS_LOG(DEBUG) << "Primitive " << prim->name() << " has python inference checking method.";
    return;
  }
  if (prim->name() == "fake_bprop") {
    MS_LOG(EXCEPTION) << "Illegal primitive: " << GetValue<std::string>(prim->GetAttr("info"));
  }

  MS_LOG(EXCEPTION) << "Illegal primitive: " << prim->name();
}

// Validates the abstract of `node` when it is an AbstractScalar.
// Returns true if the abstract is an AbstractScalar that passed validation, false if it is not a
// scalar at all (the caller then applies the remaining checks). Scalars whose tracked type is
// EnvType are rejected. Scalars whose type is Problem or External are rejected unless the node
// is a StringImm value node.
bool CheckAbstractScalar(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  AbstractBasePtr abstract = node->abstract();
  // This function has external linkage, so do not rely on callers having checked for a null abstract.
  MS_EXCEPTION_IF_NULL(abstract);
  if (!abstract->isa<AbstractScalar>()) {
    return false;
  }
  TypePtr type = abstract->GetTypeTrack();
  MS_EXCEPTION_IF_NULL(type);
  if (type->isa<EnvType>()) {
    MS_LOG(EXCEPTION) << "Illegal type in the graph: " << abstract->ToString() << ", node: " << node->DebugString();
  }
  if (type->isa<Problem>() || type->isa<External>()) {
    // Only allow string type from external.
    if (!IsValueNode<StringImm>(node)) {
      // Validate a type.
      MS_LOG(EXCEPTION) << "Illegal type in the graph: " << abstract->ToString() << ", node: " << node->DebugString();
    }
  }
  return true;
}

// Checks that the abstract attached to `node` is of a kind allowed in a compiled graph.
// A null node or a null abstract is only logged at DEBUG level. AbstractClass and AbstractJTagged
// are rejected. Scalars are delegated to CheckAbstractScalar. AbstractError is logged at DEBUG
// level and tolerated. Any abstract outside the legal list below is reported through
// MS_LOG(EXCEPTION).
void ValidateAbstract(const AnfNodePtr &node) {
  if (node == nullptr) {
    MS_LOG(DEBUG) << "Node to validate is invalid";
    return;
  }
  AbstractBasePtr abstract = node->abstract();
  if (abstract == nullptr) {
    MS_LOG(DEBUG) << "Abstract is null in node: " << node->DebugString();
    return;
  }
  if (abstract->isa<AbstractClass>() || abstract->isa<AbstractJTagged>()) {
    // Validate a type.
    MS_LOG(EXCEPTION) << "Illegal type in the graph: " << abstract->ToString() << ", node: " << node->DebugString();
  }
  if (CheckAbstractScalar(node)) {
    return;
  }
  if (abstract->isa<AbstractError>()) {
    // NOTICE: validate dead code?
    MS_LOG(DEBUG) << "AbstractError in the graph: " << abstract->ToString();
    return;
  }
  bool is_legal_abstract = abstract->isa<AbstractType>() || abstract->isa<AbstractFunction>() ||
                           abstract->isa<AbstractTuple>() || abstract->isa<AbstractList>() ||
                           abstract->isa<AbstractTensor>() || abstract->isa<AbstractRowTensor>() ||
                           abstract->isa<AbstractRef>() || abstract->isa<abstract::AbstractNone>() ||
                           abstract->isa<abstract::AbstractMonad>() || abstract->isa<abstract::AbstractScript>() ||
                           abstract->isa<AbstractSparseTensor>() || abstract->isa<AbstractCSRTensor>();
  if (is_legal_abstract) {
    return;
  }

  // Any other abstract kind is illegal; report the node as the other illegal-type messages do.
  MS_LOG(EXCEPTION) << "Illegal type in the graph: " << abstract->ToString() << ", node: " << node->DebugString();
}

// Rejects value nodes that still hold a parse::InterpretedObject.
// A null node is only logged at DEBUG level.
void ValidateValueNode(const AnfNodePtr &node) {
  if (node == nullptr) {
    MS_LOG(DEBUG) << "Node to validate is invalid";
    return;
  }
  // InterpretedNode should be consumed during compile, not left to Runtime.
  if (IsValueNode<parse::InterpretedObject>(node)) {
    MS_LOG(EXCEPTION) << "Should not use Python object in runtime, node: " << node->DebugString()
                      << "\n\nWe suppose all nodes generated by JIT Fallback not return to outside of graph.";
  }
}

// Validates every node reachable from `fg` through a manager obtained with Manage(fg, false).
// The operation and value-node checks run over all nodes first; the abstract checks then run
// in a second, separate pass over the same node set.
void Validate(const FuncGraphPtr &fg) {
  FuncGraphManagerPtr mgr = Manage(fg, false);
  MS_EXCEPTION_IF_NULL(mgr);
  AnfNodeSet &all_nodes = mgr->all_nodes();
  for (const auto &node : all_nodes) {
    ValidateOperation(node);
    ValidateValueNode(node);
  }
  for (const auto &node : all_nodes) {
    ValidateAbstract(node);
  }
}
}  // namespace validator
}  // namespace mindspore