You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

coder.cc 7.3 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "coder/coder.h"
#include <getopt.h>
#include <algorithm>
#include <cctype>
#include <functional>
#include <iomanip>
#include <iostream>
#include <map>
#include <string>
#include <vector>
#include "schema/inner/model_generated.h"
#include "tools/common/flag_parser.h"
#include "coder/session.h"
#include "coder/context.h"
#include "utils/dir_utils.h"
#include "securec/include/securec.h"
#include "src/common/file_utils.h"
#include "src/common/utils.h"
#include "coder/coder_config.h"
  31. namespace mindspore::lite::micro {
  32. class CoderFlags : public virtual FlagParser {
  33. public:
  34. CoderFlags() {
  35. AddFlag(&CoderFlags::is_weight_file_, "isWeightFile", "whether generating weight binary file, true| false", false);
  36. AddFlag(&CoderFlags::model_path_, "modelPath", "Input model path", "");
  37. AddFlag(&CoderFlags::code_path_, "codePath", "Input code path", ".");
  38. AddFlag(&CoderFlags::code_module_name_, "moduleName", "Input code module name", "");
  39. AddFlag(&CoderFlags::target_, "target", "generated code target, x86| ARM32M| ARM32A| ARM64", "x86");
  40. AddFlag(&CoderFlags::code_mode_, "codeMode", "generated code mode, Inference | Train", "Inference");
  41. AddFlag(&CoderFlags::support_parallel_, "supportParallel", "whether support parallel launch, true | false", false);
  42. AddFlag(&CoderFlags::debug_mode_, "debugMode", "dump the tensors data for debugging, true | false", false);
  43. }
  44. ~CoderFlags() override = default;
  45. std::string model_path_;
  46. bool support_parallel_{false};
  47. bool is_weight_file_{false};
  48. std::string code_module_name_;
  49. std::string code_path_;
  50. std::string code_mode_;
  51. bool debug_mode_{false};
  52. std::string target_;
  53. };
  54. int Coder::Run(const std::string &model_path) {
  55. session_ = CreateCoderSession();
  56. if (session_ == nullptr) {
  57. MS_LOG(ERROR) << "new session failed while running";
  58. return RET_ERROR;
  59. }
  60. STATUS status = session_->Init(model_path);
  61. if (status != RET_OK) {
  62. MS_LOG(ERROR) << "Init session failed.";
  63. return RET_ERROR;
  64. }
  65. status = session_->Build();
  66. if (status != RET_OK) {
  67. MS_LOG(ERROR) << "Set Input resize shapes error";
  68. return status;
  69. }
  70. status = session_->Run();
  71. if (status != RET_OK) {
  72. MS_LOG(ERROR) << "Generate Code Files error. " << status;
  73. return status;
  74. }
  75. status = session_->GenerateCode();
  76. if (status != RET_OK) {
  77. MS_LOG(ERROR) << "Generate Code Files error " << status;
  78. }
  79. return status;
  80. }
  81. int Coder::Init(const CoderFlags &flags) const {
  82. static const std::map<std::string, Target> kTargetMap = {
  83. {"x86", kX86}, {"ARM32M", kARM32M}, {"ARM32A", kARM32A}, {"ARM64", kARM64}, {"All", kAllTargets}};
  84. static const std::map<std::string, CodeMode> kCodeModeMap = {{"Inference", Inference}, {"Train", Train}};
  85. Configurator *config = Configurator::GetInstance();
  86. std::vector<std::function<bool()>> parsers;
  87. parsers.emplace_back([flags, config]() -> bool {
  88. config->set_is_weight_file(flags.is_weight_file_);
  89. return true;
  90. });
  91. parsers.emplace_back([&flags, config]() -> bool {
  92. auto target_item = kTargetMap.find(flags.target_);
  93. MS_CHECK_TRUE_RET_BOOL(target_item != kTargetMap.end(), "unsupported target: " + flags.target_);
  94. config->set_target(target_item->second);
  95. return true;
  96. });
  97. parsers.emplace_back([&flags, config]() -> bool {
  98. auto code_item = kCodeModeMap.find(flags.code_mode_);
  99. MS_CHECK_TRUE_RET_BOOL(code_item != kCodeModeMap.end(), "unsupported code mode: " + flags.code_mode_);
  100. config->set_code_mode(code_item->second);
  101. return true;
  102. });
  103. parsers.emplace_back([&flags, config]() -> bool {
  104. config->set_support_parallel(flags.support_parallel_);
  105. return true;
  106. });
  107. parsers.emplace_back([&flags, config]() -> bool {
  108. config->set_debug_mode(flags.debug_mode_);
  109. return true;
  110. });
  111. parsers.emplace_back([&flags, config]() -> bool {
  112. if (!FileExists(flags.model_path_)) {
  113. MS_LOG(ERROR) << "code_gen model_path " << flags.model_path_ << " is not valid";
  114. return false;
  115. }
  116. if (flags.code_module_name_.empty() || isdigit(flags.code_module_name_.at(0))) {
  117. MS_LOG(ERROR) << "code_gen code module name " << flags.code_module_name_
  118. << " not valid: it must be given and the first char could not be number";
  119. return false;
  120. }
  121. config->set_module_name(flags.code_module_name_);
  122. return true;
  123. });
  124. parsers.emplace_back([&flags, config]() -> bool {
  125. const std::string slash = std::string(kSlash);
  126. if (!flags.code_path_.empty() && !DirExists(flags.code_path_)) {
  127. MS_LOG(ERROR) << "code_gen code path " << flags.code_path_ << " is not valid";
  128. return false;
  129. }
  130. config->set_code_path(flags.code_path_);
  131. if (flags.code_path_.empty()) {
  132. std::string path = ".." + slash + config->module_name();
  133. config->set_code_path(path);
  134. } else {
  135. if (flags.code_path_.substr(flags.code_path_.size() - 1, 1) != slash) {
  136. std::string path = flags.code_path_ + slash + config->module_name();
  137. config->set_code_path(path);
  138. } else {
  139. std::string path = flags.code_path_ + config->module_name();
  140. config->set_code_path(path);
  141. }
  142. }
  143. return InitProjDirs(flags.code_path_, config->module_name()) != RET_ERROR;
  144. });
  145. if (!std::all_of(parsers.begin(), parsers.end(), [](auto &parser) -> bool { return parser(); })) {
  146. if (!flags.help) {
  147. std::cerr << flags.Usage() << std::endl;
  148. return 0;
  149. }
  150. return RET_ERROR;
  151. }
  152. auto print_parameter = [](auto name, auto value) {
  153. MS_LOG(INFO) << std::setw(20) << std::left << name << "= " << value;
  154. };
  155. print_parameter("modelPath", flags.model_path_);
  156. print_parameter("target", config->target());
  157. print_parameter("codePath", config->code_path());
  158. print_parameter("codeMode", config->code_mode());
  159. print_parameter("codeModuleName", config->module_name());
  160. print_parameter("isWeightFile", config->is_weight_file());
  161. print_parameter("debugMode", config->debug_mode());
  162. return RET_OK;
  163. }
  164. int RunCoder(int argc, const char **argv) {
  165. CoderFlags flags;
  166. Option<std::string> err = flags.ParseFlags(argc, argv, false, false);
  167. if (err.IsSome()) {
  168. std::cerr << err.Get() << std::endl;
  169. std::cerr << flags.Usage() << std::endl;
  170. return RET_ERROR;
  171. }
  172. if (flags.help) {
  173. std::cerr << flags.Usage() << std::endl;
  174. return RET_OK;
  175. }
  176. Coder code_gen;
  177. STATUS status = code_gen.Init(flags);
  178. if (status != RET_OK) {
  179. MS_LOG(ERROR) << "Coder init Error : " << status;
  180. return status;
  181. }
  182. status = code_gen.Run(flags.model_path_);
  183. if (status != RET_OK) {
  184. MS_LOG(ERROR) << "Run Coder Error : " << status;
  185. return status;
  186. }
  187. MS_LOG(INFO) << "end of Coder";
  188. return RET_OK;
  189. }
  190. } // namespace mindspore::lite::micro