/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "coder/train.h"
#include <algorithm>
#include <array>
#include <memory>
#include <queue>
#include <set>
#include <string>
#include <vector>
#include "schema/ops_generated.h"
#include "src/common/prim_util.h"

namespace mindspore::lite::micro {
/**
 * Collects every operator coder that `edge` transitively depends on through `input_ops()`.
 *
 * Walks the producer graph breadth-first starting at `edge`. Each operator is enqueued at
 * most once, and null entries in `input_ops()` are skipped.
 *
 * @param edge operator whose upstream operators are wanted; may be nullptr.
 * @return the set of upstream operator coders, excluding `edge` itself; empty when `edge`
 *         is nullptr.
 */
std::set<OperatorCoder *> FindInferenceOpcoders(OperatorCoder *edge) {
  std::set<OperatorCoder *> subgraph;
  if (edge == nullptr) {
    MS_LOG(ERROR) << "the edge operator coder is nullptr";
    return subgraph;
  }
  std::queue<OperatorCoder *> to_visit;
  to_visit.push(edge);
  subgraph.insert(edge);
  while (!to_visit.empty()) {
    OperatorCoder *curr = to_visit.front();
    to_visit.pop();
    for (const auto &op : curr->input_ops()) {
      // Mark an op as visited when it is enqueued, so it is never queued twice.
      if (op != nullptr && subgraph.insert(op).second) {
        to_visit.push(op);
      }
    }
  }
  // `edge` was inserted as the BFS root; it is not one of its own upstream operators.
  subgraph.erase(edge);
  return subgraph;
}

/**
 * Splits the generated code blocks into an inference part and a training part.
 *
 * The first operator (in `op_coders` order) whose primitive type is one of the loss types
 * below is taken as the loss operator. The blocks of every operator upstream of it (the loss
 * operator itself excluded) become the inference blocks. All blocks, in their original order,
 * become the training blocks.
 *
 * @param context coder context; its `code_blocks()` must hold one block per op coder, in
 *        the same order as `op_coders`.
 * @param op_coders operator coders of the graph.
 * @return RET_OK on success. Returns RET_ERROR when the block and op-coder counts differ.
 *         MS_CHECK_PTR produces the error return when `context`, a node, or the loss
 *         operator is null or missing.
 */
int Train::TransformGraphForTrain(CoderContext *context, const std::vector<std::unique_ptr<OperatorCoder>> &op_coders) {
  MS_CHECK_PTR(context);
  const std::array<int, 6> loss_types = {schema::PrimitiveType_SparseSoftmaxCrossEntropyWithLogits,
                                         schema::PrimitiveType_BinaryCrossEntropy,
                                         schema::PrimitiveType_SmoothL1Loss,
                                         schema::PrimitiveType_SmoothL1LossGrad,
                                         schema::PrimitiveType_SigmoidCrossEntropyWithLogits,
                                         schema::PrimitiveType_SigmoidCrossEntropyWithLogitsGrad};
  OperatorCoder *loss_op = nullptr;
  for (const auto &opcoder : op_coders) {
    const Model::Node *node = opcoder->node();
    MS_CHECK_PTR(node);
    const int primitive_type = GetPrimitiveType(node->primitive_);
    if (std::find(loss_types.begin(), loss_types.end(), primitive_type) != loss_types.end()) {
      loss_op = opcoder.get();
      break;
    }
  }
  MS_CHECK_PTR(loss_op);

  const size_t op_num = op_coders.size();
  std::vector<std::string> code_blocks = context->code_blocks();
  if (op_num != code_blocks.size()) {
    // This aborts the transformation, so report it as an error rather than info.
    MS_LOG(ERROR) << "the number of code blocks and op coders is not equal";
    return RET_ERROR;
  }

  const std::set<OperatorCoder *> inference_ops = FindInferenceOpcoders(loss_op);
  std::vector<std::string> inferences_blocks;
  inferences_blocks.reserve(inference_ops.size());
  for (size_t i = 0; i < op_num; ++i) {
    // code_blocks[i] was generated by op_coders[i] (sizes checked above).
    if (inference_ops.find(op_coders.at(i).get()) != inference_ops.end()) {
      inferences_blocks.push_back(code_blocks.at(i));
    }
  }
  context->set_inference_blocks(inferences_blocks);
  // Training executes the whole graph, so the training blocks are all blocks in order.
  context->set_train_blocks(code_blocks);
  return RET_OK;
}
}  // namespace mindspore::lite::micro