/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/node.h"
// NOTE(review): the standard header names were missing from the copy under review;
// the list below is reconstructed from what this file uses - confirm against the
// upstream file before merging.
#include <algorithm>
#include <memory>
#include <set>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "schema/inner/ms_generated.h"
#include "common/mslog.h"
#include "common/op_utils.h"
#include "include/errorcode.h"
#include "src/op_factory.h"

namespace mindspore {
namespace predict {
namespace {
// Fills `selected` with the tensors of `allTensors` addressed by `indices`.
// An index outside `allTensors` yields a nullptr slot (and an error log) instead of
// an out-of-range read; MallocOutput/FreeInput already report nullptr entries.
template <typename IndexVector>
void CollectTensors(const IndexVector &indices, const std::vector<Tensor *> &allTensors,
                    std::vector<Tensor *> *selected) {
  selected->resize(indices.size());
  // Capture by reference: the lookup table only has to live for the duration of this call.
  std::transform(indices.begin(), indices.end(), selected->begin(), [&allTensors](int i) -> Tensor * {
    if (i < 0 || static_cast<size_t>(i) >= allTensors.size()) {
      MS_LOGE("tensor index %d is out of range", i);
      return nullptr;
    }
    return allTensors[i];
  });
}
}  // namespace

// Builds a node from its serialized definition: the id is the op's name and the type
// is whatever GetOpTypeName() reports for the definition.
// Precondition: nodeDef, nodeDef->opDef() and its name() are dereferenced unchecked,
// so all three must be non-null.
Node::Node(const NodeDef *nodeDef)
    : id(std::string(nodeDef->opDef()->name()->c_str())), type(GetOpTypeName(*nodeDef)) {}

// Releases the op created by InitOp(), if any.
Node::~Node() {
  delete op;  // deleting a null pointer is a no-op
  op = nullptr;
}

// Returns the node id (the op name taken from the definition).
NODE_ID Node::ID() { return id; }

// Returns the op type name of this node.
std::string Node::Type() { return type; }

// Binds the node's input and output tensor lists by resolving the indices stored in
// nodeDef.opDef() against allTensors. Leaves the node untouched when the op
// definition or either index list is missing.
void Node::SetTensors(const NodeDef &nodeDef, const std::vector<Tensor *> &allTensors) {
  if (nodeDef.opDef() == nullptr) {
    MS_LOGE("nodeDef.opDef() is null");
    return;
  }

  auto inputIndex = nodeDef.opDef()->inputIndex();
  auto outputIndex = nodeDef.opDef()->outputIndex();
  MS_ASSERT(inputIndex != nullptr);
  MS_ASSERT(outputIndex != nullptr);
  // Runtime guard in addition to the asserts, in case MS_ASSERT is compiled out.
  if (inputIndex == nullptr || outputIndex == nullptr) {
    MS_LOGE("inputIndex or outputIndex is null");
    return;
  }

  CollectTensors(*inputIndex, allTensors, &inputs);
  CollectTensors(*outputIndex, allTensors, &outputs);
}

// Replaces the stored dependency set with a copy of deps.
void Node::SetDepends(const std::unordered_set<NODE_ID> &deps) { depends = deps; }

// Returns a copy of the stored dependency set.
std::unordered_set<NODE_ID> Node::GetDepends() { return depends; }

// Records `node` as an incoming edge; a null node is logged and ignored.
void Node::AddInEdge(Node *node) {
  if (node == nullptr) {
    MS_LOGE("node is null");
    return;
  }
  inEdges.insert(node);
}

// Records `node` as an outgoing edge; a null node is logged and ignored.
void Node::AddOutEdge(Node *node) {
  if (node == nullptr) {
    MS_LOGE("node is null");
    return;
  }
  outEdges.insert(node);
}

// Mutable access to the set of incoming edges.
std::unordered_set<Node *> &Node::GetAllInEdges() { return inEdges; }

// Mutable access to the set of outgoing edges.
std::unordered_set<Node *> &Node::GetAllOutEdges() { return outEdges; }

// Mutable access to the output tensor list.
std::vector<Tensor *> &Node::GetOutputTensors() { return outputs; }

// Mutable access to the input tensor list.
std::vector<Tensor *> &Node::GetInputTensors() { return inputs; }

// Asks the OpFactory for an op matching opDef, bound to this node's tensors.
// Returns RET_OK on success, RET_ERROR when the factory yields no op.
int Node::InitOp(const OpDef &opDef, const Context &ctx) {
  OpDesc dst;
  dst.type = GetOpType(opDef);
  dst.arch = X86_FP32;  // the target arch is fixed here, not taken from ctx
  MS_ASSERT(OpFactory::GetInstance() != nullptr);
  op = OpFactory::GetInstance()->GetOp(inputs, outputs, opDef, ctx, dst);
  if (op == nullptr) {
    MS_LOGE("Can't find opName: %s, type: %s ", id.c_str(), type.c_str());
    return RET_ERROR;
  }
  return RET_OK;
}

// Executes the node: allocates the outputs, runs the op, then releases the inputs.
// Returns RET_OK on success; otherwise RET_ERROR (no op) or the failing step's code.
int Node::Run(const Context &ctx) {
  MS_LOGD("%s run start", id.c_str());
  // Check for the op first so no output memory is allocated for a node that cannot run.
  if (op == nullptr) {
    MS_LOGE("op is nullptr.");
    return RET_ERROR;
  }
  auto ret = MallocOutput(ctx);
  if (ret != RET_OK) {
    MS_LOGE("MallocOutput failed: %d", ret);
    return ret;
  }
  ret = op->Execute(inputs, outputs);
  if (ret != RET_OK) {
    MS_LOGE("%s execute failed: %d", id.c_str(), ret);
    return ret;
  }
  FreeInput();
  return RET_OK;
}

// Calls MallocData() on every output tensor, passing ctx.allocator and the number of
// outgoing edges as the reference count. Stops at the first null tensor (RET_ERROR)
// or the first failing allocation (its return code).
int Node::MallocOutput(const Context &ctx) {
  const size_t refCount = outEdges.size();
  for (auto *tensor : outputs) {
    if (tensor == nullptr) {
      MS_LOGE("tensor in outputs is nullptr");
      return RET_ERROR;
    }
    auto ret = tensor->MallocData(ctx.allocator, refCount);
    if (ret != RET_OK) {
      return ret;
    }
  }
  return RET_OK;
}

// Calls FreeData() on every input tensor whose RefCount() differs from
// MSConst_WEIGHT_REFCOUNT (presumably the marker for weight tensors - confirm).
void Node::FreeInput() {
  for (auto *tensor : inputs) {
    if (tensor == nullptr) {
      MS_LOGW("tensor in inputs is nullptr");
      continue;  // keep going so the remaining inputs are still released
    }
    if (tensor->RefCount() != MSConst_WEIGHT_REFCOUNT) {
      tensor->FreeData();
    }
  }
}
}  // namespace predict
}  // namespace mindspore