From: @yangjie159 Reviewed-by: @wangchengyuan,@HilbertDavid Signed-off-by: @wangchengyuantags/v1.2.0-rc1
| @@ -4,6 +4,7 @@ set(CODER_SRC | |||
| ${MICRO_DIR}/coder/context.cc | |||
| ${MICRO_DIR}/coder/graph.cc | |||
| ${MICRO_DIR}/coder/session.cc | |||
| ${MICRO_DIR}/coder/train.cc | |||
| ) | |||
| set(CODER_ALLOCATOR_SRC | |||
| @@ -14,10 +15,12 @@ set(CODER_ALLOCATOR_SRC | |||
| set(CODER_GENERATOR_SRC | |||
| ${MICRO_DIR}/coder/generator/generator.cc | |||
| ${MICRO_DIR}/coder/generator/inference/inference_generator.cc | |||
| ${MICRO_DIR}/coder/generator/train/train_generator.cc | |||
| ${MICRO_DIR}/coder/generator/component/benchmark_component.cc | |||
| ${MICRO_DIR}/coder/generator/component/common_component.cc | |||
| ${MICRO_DIR}/coder/generator/component/weight_component.cc | |||
| ${MICRO_DIR}/coder/generator/component/cmake_component.cc | |||
| ${MICRO_DIR}/coder/generator/component/train_component.cc | |||
| ) | |||
| set(CODER_OPCODERS_SRC | |||
| @@ -39,7 +39,7 @@ class CoderFlags : public virtual FlagParser { | |||
| AddFlag(&CoderFlags::code_path_, "codePath", "Input code path", "."); | |||
| AddFlag(&CoderFlags::code_module_name_, "moduleName", "Input code module name", ""); | |||
| AddFlag(&CoderFlags::target_, "target", "generateed code target, x86| ARM32M| ARM32A| ARM64", "x86"); | |||
| AddFlag(&CoderFlags::code_mode_, "codeMode", "generated code mode, Normal | Android ", "Normal"); | |||
| AddFlag(&CoderFlags::code_mode_, "codeMode", "generated code mode, Normal | Inference | Train", "Normal"); | |||
| AddFlag(&CoderFlags::debug_mode_, "debugMode", "dump perlayer's time cost and tensor, true | false", false); | |||
| } | |||
| @@ -87,7 +87,8 @@ int Coder::Run(const std::string &model_path) { | |||
| int Coder::Init(const CoderFlags &flags) const { | |||
| static const std::map<std::string, Target> kTargetMap = { | |||
| {"x86", kX86}, {"ARM32M", kARM32M}, {"ARM32A", kARM32A}, {"ARM64", kARM64}, {"All", kAllTargets}}; | |||
| static const std::map<std::string, CodeMode> kCodeModeMap = {{"Normal", Code_Normal}, {"Android", Code_Android}}; | |||
| static const std::map<std::string, CodeMode> kCodeModeMap = { | |||
| {"Normal", Code_Normal}, {"Inference", Code_Inference}, {"Train", Code_Train}}; | |||
| Configurator *config = Configurator::GetInstance(); | |||
| @@ -21,7 +21,7 @@ | |||
| namespace mindspore::lite::micro { | |||
| enum Target { kX86 = 0, kARM32M = 1, kARM32A = 2, kARM64 = 3, kAllTargets = 4, kTargetUnknown = 99 }; | |||
| enum CodeMode { Code_Normal = 0, Code_Android = 1, Code_Unknown = 99 }; | |||
| enum CodeMode { Code_Normal = 0, Code_Inference = 1, Code_Train = 2, Code_Unknown = 99 }; | |||
| class Configurator { | |||
| public: | |||
| @@ -0,0 +1,180 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "coder/generator/component/train_component.h" | |||
| #include <string> | |||
| #include "coder/utils/type_cast.h" | |||
| namespace mindspore::lite::micro { | |||
| void CodeTrainParams(std::ofstream &ofs) { | |||
| ofs << "struct TrainParameter {\n" | |||
| " float beta1_;\n" | |||
| " float beta2_;\n" | |||
| " float epsilon_;\n" | |||
| "};\n" | |||
| "\n" | |||
| "enum EarlyStopType {\n" | |||
| " Diff = 0,\n" | |||
| " WeigthDiff = 1,\n" | |||
| " Abs = 2,\n" | |||
| "};\n" | |||
| "\n" | |||
| "struct EarlyStop {\n" | |||
| " enum EarlyStopType type;\n" | |||
| " float tolerate;\n" | |||
| "};\n\n"; | |||
| } | |||
| void CodeFeaturesState(std::ofstream &ofs, const std::string &module_name) { | |||
| ofs << "/**\n" | |||
| " *\n" | |||
| " * @param size, return the number of features\n" | |||
| " * @return, the address of features\n" | |||
| " */\n" | |||
| << "FeatureParam *" << module_name << "_GetFeatures(int *size);\n\n"; | |||
| ofs << "/**\n" | |||
| " *\n" | |||
| " * @param features, the address of features\n" | |||
| " * @param size, the number of features\n" | |||
| " * @return, status\n" | |||
| " */\n" | |||
| << "int " << module_name << "_UpdateFeatures(FeatureParam *features, int size);\n\n"; | |||
| } | |||
| void CodeFeaturesImplement(std::ofstream &ofs, const std::string &module_name, | |||
| const std::unique_ptr<CoderContext> &ctx) { | |||
| size_t features_num = 0; | |||
| ofs << "static FeatureParam feature_params[] = {\n"; | |||
| for (const auto &item : ctx->saved_weights()) { | |||
| std::string addr = item.first; | |||
| Tensor *tensor = item.second; | |||
| if (tensor->tensor_name().empty()) { | |||
| MS_LOG(ERROR) << "exist empty feature"; | |||
| continue; | |||
| } | |||
| ofs << "\t{\"" << tensor->tensor_name() << "\", " << addr << ", " << tensor->ElementsNum() << ", " | |||
| << EnumMicroTensorDataType(tensor->data_type()) << "}, \n"; | |||
| features_num++; | |||
| } | |||
| ofs << "};\n"; | |||
| ofs << "FeatureParam *" << module_name << "_GetFeatures(int *size) {\n" | |||
| << " *size = " << features_num << ";\n" | |||
| << " return feature_params;\n" | |||
| "}\n\n"; | |||
| ofs << "int " << module_name << "_UpdateFeatures(FeatureParam *features, int size) {\n" | |||
| << " for (int i = 0; i < size; ++i) {\n" | |||
| " FeatureParam *src = features + i;\n" | |||
| " FeatureParam dst;\n" | |||
| " // find the dst feature\n" | |||
| " bool is_find = false;\n" | |||
| << " for (int j = 0; j < " << features_num << "; ++j) {\n" | |||
| << " if (strcmp(src->name, feature_params[j].name) == 0) {\n" | |||
| " dst = feature_params[j];\n" | |||
| " is_find = true;\n" | |||
| " break;\n" | |||
| " }\n" | |||
| " }\n" | |||
| " if (!is_find) {\n" | |||
| " MICRO_ERROR(\"invalid feature param: %s\", src->name);\n" | |||
| " return RET_ERROR;\n" | |||
| " }\n" | |||
| " if (src->elenums != dst.elenums) {\n" | |||
| " MICRO_ERROR(\"feature %s elenums is mismatch, src: %lu, dst: %lu\", src->name, src->elenums, " | |||
| "dst.elenums);\n" | |||
| " return RET_ERROR;\n" | |||
| " }\n" | |||
| " memcpy(dst.data, src->data, src->elenums * sizeof(float));\n" | |||
| " }\n" | |||
| " MICRO_INFO(\"update features map success\");\n" | |||
| " return RET_OK;\n" | |||
| "}\n\n"; | |||
| } | |||
| void CodeTrainState(std::ofstream &ofs, const std::string &module_name) { | |||
| ofs << "/**\n" | |||
| " * Train Function\n" | |||
| " * @param epoch, the train epoch\n" | |||
| " * @param iterations, which is equal to batch_num, the number of iterations of each epoch\n" | |||
| " * @param use_train_param, default parameters already exists, such as the momentum, user can update these\n" | |||
| " * parameters to improve the accuracy\n" | |||
| " * @param parameter, the TrainParameter contains epsilon/beta1/beta2\n" | |||
| " * @return status\n" | |||
| " */\n" | |||
| << "int " << module_name | |||
| << "_Train(const int epoch, const int iterations, bool use_train_param, const struct TrainParameter *parameter, " | |||
| "const struct EarlyStop *early_stop);\n\n"; | |||
| } | |||
| void CodeTrainImplement(std::ofstream &ofs, const std::string &module_name, const std::unique_ptr<CoderContext> &ctx) { | |||
| std::vector<Tensor *> inputs = ctx->graph_inputs(); | |||
| size_t inputs_num = inputs.size(); | |||
| auto inputs_tostring = [&]() { | |||
| std::string result; | |||
| result += "{"; | |||
| for (size_t i = 0; i < inputs.size(); ++i) { | |||
| result += ctx->input_name() + std::to_string(i) + ", "; | |||
| } | |||
| result += "}"; | |||
| return result; | |||
| }; | |||
| auto wrap = [](int i) { return "[" + std::to_string(i) + "]"; }; | |||
| auto offset_inputs = [&]() { | |||
| std::string src = "origin_inputs"; | |||
| std::string dst = "input_ptr"; | |||
| std::string result; | |||
| for (size_t i = 0; i < inputs.size(); ++i) { | |||
| result += dst + wrap(i) += " = " + src + wrap(i) + " + j * " + std::to_string(inputs[i]->Size()) + ";\n"; | |||
| } | |||
| return result; | |||
| }; | |||
| auto varify_inputs = [&]() { | |||
| std::string result; | |||
| for (size_t i = 0; i < inputs.size(); ++i) { | |||
| result += "origin_input" + wrap(i) + " + iterations * " + std::to_string(inputs[i]->Size()) + " == NULL"; | |||
| i < inputs.size() - 1 ? result += " || " : result += ""; | |||
| } | |||
| return result; | |||
| }; | |||
| ofs << "int " << module_name | |||
| << "_Train(const int epoch, const int iterations, bool use_train_param, const struct TrainParameter " | |||
| "*parameter, const struct EarlyStop *early_stop) {\n" | |||
| " if (iterations <= 0 || epoch <= 0) {\n" | |||
| " MICRO_ERROR(\"error iterations or epoch!, epoch:%d, iterations:%d\", epoch, iterations);\n" | |||
| " return RET_ERROR;\n" | |||
| " }\n" | |||
| " MICRO_INFO(\"train epoch: %d, batch_num: %d\", epoch, iterations);\n" | |||
| << " const void *origin_input[] = " << inputs_tostring() << ";\n"; | |||
| ofs << " if (" << varify_inputs() << ") {\n" | |||
| << " MICRO_ERROR(\"input data is invalid, epoch: %d, iterations: %d\", epoch, iterations);\n" | |||
| " return RET_ERROR;\n" | |||
| " }\n"; | |||
| ofs << " for (int i = 0; i < epoch; ++i) {\n" | |||
| << " const void *input_ptr[" << inputs_num << "];\n" | |||
| << " float loss = 0;\n" | |||
| << " for (int j = 0; j < iterations; ++j) {\n" | |||
| << " " << offset_inputs() << "\n" | |||
| << " " << module_name << "_SetInputs(input_ptr, " << inputs_num << ");\n" | |||
| << " " << module_name << "_Inference();\n" | |||
| << " loss = " << module_name << "_ComputeLossAndGradient();\n" | |||
| << " }\n" | |||
| " }\n" | |||
| " return RET_OK;\n" | |||
| "};\n\n"; | |||
| } | |||
| } // namespace mindspore::lite::micro | |||
| @@ -0,0 +1,39 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_MICRO_CODER_GENERATOR_TRAIN_COMPONENT_H_ | |||
| #define MINDSPORE_LITE_MICRO_CODER_GENERATOR_TRAIN_COMPONENT_H_ | |||
| #include <map> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <fstream> | |||
| #include "src/tensor.h" | |||
| #include "coder/context.h" | |||
| namespace mindspore::lite::micro { | |||
| void CodeTrainParams(std::ofstream &ofs); | |||
| void CodeFeaturesState(std::ofstream &ofs, const std::string &module_name); | |||
| void CodeFeaturesImplement(std::ofstream &ofs, const std::string &module_name, | |||
| const std::unique_ptr<CoderContext> &ctx); | |||
| void CodeTrainState(std::ofstream &ofs, const std::string &module_name); | |||
| void CodeTrainImplement(std::ofstream &ofs, const std::string &module_name, const std::unique_ptr<CoderContext> &ctx); | |||
| } // namespace mindspore::lite::micro | |||
| #endif // MINDSPORE_LITE_MICRO_CODER_GENERATOR_TRAIN_COMPONENT_H_ | |||
| @@ -59,15 +59,13 @@ Generator::Generator(std::unique_ptr<CoderContext> ctx) { | |||
| Generator::~Generator() { (void)umask(origin_umask_); } | |||
| void Generator::CodeNetRunFunc(std::ofstream &ofs) { | |||
| // generate net predict code | |||
| // generate net inference code | |||
| ofs << "void " << config_->module_name() << "_Inference() {\n"; | |||
| if (config_->code_mode() == CodeMode::Code_Android) { | |||
| if (config_->code_mode() == CodeMode::Code_Inference) { | |||
| ofs << "int thread_num = GetCurrentThreadNum(THREAD_POOL_DEFAULT);\n"; | |||
| } | |||
| for (const auto &codeBlock : ctx_->code_blocks()) { | |||
| ofs << "\t{\n"; | |||
| ofs << codeBlock; | |||
| ofs << "\t}\n"; | |||
| for (const auto &block : ctx_->code_blocks()) { | |||
| ofs << "\t{\n" << block << "\t}\n"; | |||
| } | |||
| ofs << "}\n"; | |||
| } | |||
| @@ -28,7 +28,7 @@ int InferenceGenerator::CodeNetHFile() { | |||
| MS_CHECK_TRUE(!ofs.bad(), "filed to open file"); | |||
| MS_LOG(INFO) << "write " << net_include_file; | |||
| ofs << g_hwLicense; | |||
| if (config_->code_mode() == CodeMode::Code_Android) { | |||
| if (config_->code_mode() == CodeMode::Code_Inference) { | |||
| ofs << "#include \"src/runtime/thread_pool.h\"\n"; | |||
| } | |||
| ofs << "#include \"microtensor.h\"\n\n"; | |||
| @@ -78,7 +78,7 @@ int InferenceGenerator::CodeBenchmarkFile() { | |||
| if (config_->is_weight_file()) { | |||
| CodeBenchmarkInitWeight(ofs, config_->module_name()); | |||
| } | |||
| if (config_->code_mode() == CodeMode::Code_Android) { | |||
| if (config_->code_mode() == CodeMode::Code_Inference) { | |||
| CodeBenchmarkConfigThread(ofs); | |||
| } | |||
| CodeBenchmarkInference(ofs, config_->module_name()); | |||
| @@ -0,0 +1,108 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "coder/generator/train/train_generator.h" | |||
| #include <vector> | |||
| #include <string> | |||
| #include "coder/generator/component/common_component.h" | |||
| #include "coder/generator/component/benchmark_component.h" | |||
| #include "coder/generator/component/train_component.h" | |||
| #include "coder/generator/component/const_blocks/license.h" | |||
| namespace mindspore::lite::micro { | |||
| void TrainGenerator::CodeGradientFunc(std::ofstream &ofs) const { | |||
| ofs << "float " << config_->module_name() << "_ComputeLossAndGradient() {\n"; | |||
| ofs << " float loss = 0;\n"; | |||
| for (const auto &block : ctx_->train_blocks()) { | |||
| ofs << " {\n" << block << " }\n"; | |||
| } | |||
| ofs << " return loss;\n"; | |||
| ofs << "}\n"; | |||
| } | |||
| int TrainGenerator::CodeNetHFile() { | |||
| std::string net_include_file = net_inc_file_path_ + net_inc_hfile_; | |||
| std::ofstream ofs(net_include_file); | |||
| MS_CHECK_TRUE(!ofs.bad(), "filed to open file"); | |||
| MS_LOG(INFO) << "write " << net_include_file; | |||
| ofs << g_hwLicense; | |||
| if (config_->code_mode() == CodeMode::Code_Inference) { | |||
| ofs << "#include \"src/runtime/thread_pool.h\"\n"; | |||
| } | |||
| ofs << "#include \"microtensor.h\"\n\n"; | |||
| CodeTrainParams(ofs); | |||
| CodeInputAndOutputState(ofs, config_->module_name()); | |||
| if (is_get_quant_args_) { | |||
| CodeGraphQuantArgsState(ofs, config_->module_name()); | |||
| } | |||
| if (config_->is_weight_file()) { | |||
| CodeInitWeightState(ofs, config_->module_name()); | |||
| } | |||
| CodeManageResourceState(ofs, config_->module_name()); | |||
| CodeInferenceState(ofs, config_->module_name()); | |||
| CodeFeaturesState(ofs, config_->module_name()); | |||
| CodeTrainState(ofs, config_->module_name()); | |||
| return RET_OK; | |||
| } | |||
| int TrainGenerator::CodeNetCFile() { | |||
| std::string net_impl_file = net_src_file_path_ + net_src_cfile_; | |||
| std::ofstream ofs(net_impl_file); | |||
| MS_CHECK_TRUE(!ofs.bad(), "filed to open file"); | |||
| MS_LOG(INFO) << "write " << net_impl_file; | |||
| CodeSourceFileInclude(ofs, net_weight_hfile_, net_inc_hfile_); | |||
| CodeInputAndOutputImplement(ofs, config_->module_name(), ctx_); | |||
| CodeInitResourceImplement(ofs, config_->module_name(), ctx_); | |||
| CodeFreeResourceImplement(ofs, config_->module_name(), ctx_); | |||
| CodeFeaturesImplement(ofs, config_->module_name(), ctx_); | |||
| if (is_get_quant_args_) { | |||
| CodeGraphQuantArgsImplement(ofs, config_->module_name(), ctx_); | |||
| } | |||
| CodeNetRunFunc(ofs); | |||
| CodeGradientFunc(ofs); | |||
| CodeTrainImplement(ofs, config_->module_name(), ctx_); | |||
| ofs.close(); | |||
| return RET_OK; | |||
| } | |||
| int TrainGenerator::CodeBenchmarkFile() { | |||
| std::string net_main_impl_file = net_main_file_path_ + net_main_cfile_; | |||
| std::ofstream ofs(net_main_impl_file); | |||
| MS_LOG(INFO) << "write " << net_main_impl_file; | |||
| MS_CHECK_TRUE(!ofs.bad(), "filed to open file"); | |||
| std::vector<Tensor *> inputs = ctx_->graph_inputs(); | |||
| size_t inputs_num = inputs.size(); | |||
| CodeBenchmarkHeader(ofs, net_inc_hfile_); | |||
| CodeBenchmarkUsage(ofs); | |||
| CodeBenchmarkWarmup(ofs, config_->module_name()); | |||
| CodeBenchmarkSetInputs(ofs, config_->module_name(), ctx_); | |||
| CodeBenchmarkSetBuffer(ofs, config_->module_name()); | |||
| if (config_->is_weight_file()) { | |||
| CodeBenchmarkInitWeight(ofs, config_->module_name()); | |||
| } | |||
| if (config_->code_mode() == CodeMode::Code_Inference) { | |||
| CodeBenchmarkConfigThread(ofs); | |||
| } | |||
| CodeBenchmarkInference(ofs, config_->module_name()); | |||
| CodeBenchmarkPrintOutputs(ofs, config_->module_name()); | |||
| CodeBenchmarkFreeResourse(ofs, config_->module_name(), inputs_num); | |||
| ofs.close(); | |||
| return RET_OK; | |||
| } | |||
| } // namespace mindspore::lite::micro | |||
| @@ -0,0 +1,39 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_MICRO_CODER_GENERATOR_TRAIN_GENERATOR_H_ | |||
| #define MINDSPORE_LITE_MICRO_CODER_GENERATOR_TRAIN_GENERATOR_H_ | |||
| #include <utility> | |||
| #include <memory> | |||
| #include "micro/coder/generator/generator.h" | |||
| namespace mindspore::lite::micro { | |||
| class TrainGenerator : public Generator { | |||
| public: | |||
| explicit TrainGenerator(std::unique_ptr<CoderContext> ctx) : Generator(std::move(ctx)) {} | |||
| ~TrainGenerator() override = default; | |||
| private: | |||
| int CodeNetHFile() override; | |||
| int CodeNetCFile() override; | |||
| int CodeBenchmarkFile() override; | |||
| void CodeGradientFunc(std::ofstream &ofs) const; | |||
| }; | |||
| } // namespace mindspore::lite::micro | |||
| #endif // MINDSPORE_LITE_MICRO_CODER_GENERATOR_TRAIN_GENERATOR_H_ | |||
| @@ -28,8 +28,8 @@ | |||
| #include "securec/include/securec.h" | |||
| #include "coder/opcoders/op_coder_register.h" | |||
| #include "coder/log.h" | |||
| namespace mindspore::lite::micro { | |||
| class CoderContext; | |||
| constexpr int kPrecision = 19; | |||
| #define CODE_PARALLEL_FUNC(func) code << "ParallelLaunch(THREAD_POOL_DEFAULT, " << func << ", &args, thread_num);\n" | |||
| @@ -61,7 +61,7 @@ std::unique_ptr<OperatorCoder> OpCoderBuilder::build() { | |||
| } | |||
| op_coder->set_input_tensor_indices(input_indices_); | |||
| op_coder->set_output_tensor_indices(output_indices_); | |||
| int thread_num = this->mode_ == CodeMode::Code_Android ? kMAX_THREAD_NUM_SUPPORT : 1; | |||
| int thread_num = this->mode_ == CodeMode::Code_Inference ? kMAX_THREAD_NUM_SUPPORT : 1; | |||
| op_coder->set_thread_num(thread_num); | |||
| parameter->thread_num_ = thread_num; | |||
| op_coder->set_parameter(parameter); | |||
| @@ -16,13 +16,14 @@ | |||
| #include "coder/session.h" | |||
| #include <set> | |||
| #include <queue> | |||
| #include <vector> | |||
| #include <utility> | |||
| #include "coder/allocator/allocator.h" | |||
| #include "coder/context.h" | |||
| #include "coder/train.h" | |||
| #include "coder/allocator/allocator.h" | |||
| #include "coder/generator/generator.h" | |||
| #include "coder/generator/inference/inference_generator.h" | |||
| #include "coder/generator/train/train_generator.h" | |||
| #include "coder/opcoders/op_coder_builder.h" | |||
| #include "coder/utils/coder_utils.h" | |||
| #include "coder/log.h" | |||
| @@ -89,6 +90,9 @@ void CoderSession::EndCode() { | |||
| blocks = AddDumpDataInfo(context_->code_blocks(), op_coders_); | |||
| context_->set_code_blocks(blocks); | |||
| } | |||
| if (config->code_mode() == Code_Train) { | |||
| Train::TransformGraphForTrain(context_.get(), op_coders_); | |||
| } | |||
| } | |||
| int CoderSession::Run() { | |||
| @@ -123,10 +127,14 @@ int CoderSession::GenerateCode() { | |||
| CodeMode code_mode = config->code_mode(); | |||
| switch (code_mode) { | |||
| case Code_Normal: | |||
| case Code_Android: | |||
| MS_LOG(INFO) << "generate code for Android"; | |||
| case Code_Inference: | |||
| MS_LOG(INFO) << "generate code for Inference"; | |||
| generator = std::make_shared<InferenceGenerator>(std::move(context_)); | |||
| break; | |||
| case Code_Train: | |||
| MS_LOG(INFO) << "generate code for Inference"; | |||
| generator = std::make_shared<TrainGenerator>(std::move(context_)); | |||
| break; | |||
| default: | |||
| MS_LOG(ERROR) << "unsupported generator code mode, " << code_mode; | |||
| return RET_ERROR; | |||
| @@ -0,0 +1,95 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "coder/train.h" | |||
| #include <memory> | |||
| #include <set> | |||
| #include <map> | |||
| #include <queue> | |||
| #include <string> | |||
| #include <vector> | |||
| namespace mindspore::lite::micro { | |||
| std::set<OperatorCoder *> FindInferenceOpcoders(OperatorCoder *edge) { | |||
| std::set<OperatorCoder *> subgraph; | |||
| std::queue<OperatorCoder *> to_visit; | |||
| to_visit.push(edge); | |||
| while (!to_visit.empty()) { | |||
| size_t size = to_visit.size(); | |||
| for (size_t i = 0; i < size; ++i) { | |||
| OperatorCoder *curr = to_visit.front(); | |||
| to_visit.pop(); | |||
| if (subgraph.find(curr) != subgraph.end()) { | |||
| continue; | |||
| } | |||
| subgraph.insert(curr); | |||
| for (const auto &op : curr->input_ops()) { | |||
| to_visit.push(op); | |||
| } | |||
| } | |||
| } | |||
| auto item = subgraph.find(edge); | |||
| if (item == subgraph.end()) { | |||
| MS_LOG(ERROR) << "failed to find the edge in the subgraph"; | |||
| return subgraph; | |||
| } | |||
| // erase edge operator coder from subgraph | |||
| subgraph.erase(item); | |||
| return subgraph; | |||
| } | |||
| int Train::TransformGraphForTrain(CoderContext *context, const std::vector<std::unique_ptr<OperatorCoder>> &op_coders) { | |||
| const std::set<schema::PrimitiveType> loss_types = {schema::PrimitiveType_SoftmaxCrossEntropy, | |||
| schema::PrimitiveType_SparseSoftmaxCrossEntropy, | |||
| schema::PrimitiveType_BinaryCrossEntropy, | |||
| schema::PrimitiveType_SmoothL1Loss, | |||
| schema::PrimitiveType_SmoothL1LossGrad, | |||
| schema::PrimitiveType_SigmoidCrossEntropyWithLogits, | |||
| schema::PrimitiveType_SigmoidCrossEntropyWithLogitsGrad}; | |||
| OperatorCoder *loss_op = nullptr; | |||
| for (const auto &opcoder : op_coders) { | |||
| auto primitive_type = static_cast<schema::PrimitiveType>(opcoder->primitive()->Type()); | |||
| auto item = loss_types.find(primitive_type); | |||
| if (item != loss_types.end()) { | |||
| loss_op = opcoder.get(); | |||
| break; | |||
| } | |||
| } | |||
| MS_CHECK_PTR(loss_op); | |||
| size_t op_num = op_coders.size(); | |||
| std::vector<std::string> code_blocks = context->code_blocks(); | |||
| if (op_num != code_blocks.size()) { | |||
| MS_LOG(INFO) << "the number of code blocks and op coders is not equal"; | |||
| return RET_ERROR; | |||
| } | |||
| std::set<OperatorCoder *> inference_ops = FindInferenceOpcoders(loss_op); | |||
| std::vector<std::string> inferences_blocks; | |||
| std::vector<std::string> train_blocks; | |||
| for (size_t i = 0; i < op_num; ++i) { | |||
| auto &opcoder = op_coders.at(i); | |||
| std::string block = code_blocks.at(i); | |||
| if (inference_ops.find(opcoder.get()) != inference_ops.end()) { | |||
| inferences_blocks.push_back(block); | |||
| } | |||
| train_blocks.push_back(block); | |||
| } | |||
| context->set_inference_blocks(inferences_blocks); | |||
| context->set_train_blocks(train_blocks); | |||
| return RET_OK; | |||
| } | |||
| } // namespace mindspore::lite::micro | |||
| @@ -0,0 +1,33 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_MICRO_CODER_CODER_TRAIN_H_ | |||
| #define MINDSPORE_LITE_MICRO_CODER_CODER_TRAIN_H_ | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "coder/context.h" | |||
| #include "coder/opcoders/op_coder.h" | |||
| namespace mindspore::lite::micro { | |||
| class Train { | |||
| public: | |||
| static int TransformGraphForTrain(CoderContext *context, | |||
| const std::vector<std::unique_ptr<OperatorCoder>> &op_coders); | |||
| }; | |||
| } // namespace mindspore::lite::micro | |||
| #endif // MINDSPORE_LITE_MICRO_CODER_CODER_TRAIN_H_ | |||
| @@ -142,32 +142,4 @@ std::vector<std::string> SplitString(std::string str, const std::string &pattern | |||
| } | |||
| return results; | |||
| } | |||
| std::set<OperatorCoder *> FindInferenceOpcoders(OperatorCoder *edge) { | |||
| std::set<OperatorCoder *> subgraph; | |||
| std::queue<OperatorCoder *> to_visit; | |||
| to_visit.push(edge); | |||
| while (!to_visit.empty()) { | |||
| size_t size = to_visit.size(); | |||
| for (size_t i = 0; i < size; ++i) { | |||
| OperatorCoder *curr = to_visit.front(); | |||
| to_visit.pop(); | |||
| if (subgraph.find(curr) != subgraph.end()) { | |||
| continue; | |||
| } | |||
| subgraph.insert(curr); | |||
| for (const auto &op : curr->input_ops()) { | |||
| to_visit.push(op); | |||
| } | |||
| } | |||
| } | |||
| auto item = subgraph.find(edge); | |||
| if (item == subgraph.end()) { | |||
| MS_LOG(ERROR) << "failed to find the edge in the subgraph"; | |||
| return subgraph; | |||
| } | |||
| // erase edge operator coder from subgraph | |||
| subgraph.erase(item); | |||
| return subgraph; | |||
| } | |||
| } // namespace mindspore::lite::micro | |||
| @@ -35,8 +35,6 @@ std::vector<std::string> AddDumpDataInfo(const std::vector<std::string> &blocks, | |||
| void PrintTensorData(const lite::Tensor *tensor, std::ofstream &ofs); | |||
| std::set<OperatorCoder *> FindInferenceOpcoders(OperatorCoder *edge); | |||
| } // namespace mindspore::lite::micro | |||
| #endif // MINDSPORE_LITE_MICRO_CODER_UTILS_CODER_UTILS_H_ | |||