You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

coder.cc 7.4 kB

5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "coder/coder.h"
  17. #include <getopt.h>
  18. #include <iomanip>
  19. #include <string>
  20. #include <vector>
  21. #include <map>
  22. #include "schema/inner/model_generated.h"
  23. #include "tools/common/flag_parser.h"
  24. #include "coder/session.h"
  25. #include "coder/context.h"
  26. #include "utils/dir_utils.h"
  27. #include "securec/include/securec.h"
  28. #include "src/common/file_utils.h"
  29. #include "src/common/utils.h"
  30. #include "coder/config.h"
  31. #include "coder/generator/component/component.h"
  32. namespace mindspore::lite::micro {
  33. class CoderFlags : public virtual FlagParser {
  34. public:
  35. CoderFlags() {
  36. AddFlag(&CoderFlags::model_path_, "modelPath", "Input model path", "");
  37. AddFlag(&CoderFlags::code_path_, "codePath", "Input code path", ".");
  38. AddFlag(&CoderFlags::target_, "target", "generated code target, x86| ARM32M| ARM32A| ARM64", "x86");
  39. AddFlag(&CoderFlags::code_mode_, "codeMode", "generated code mode, Inference | Train", "Inference");
  40. AddFlag(&CoderFlags::support_parallel_, "supportParallel", "whether support parallel launch, true | false", false);
  41. AddFlag(&CoderFlags::debug_mode_, "debugMode", "dump the tensors data for debugging, true | false", false);
  42. }
  43. ~CoderFlags() override = default;
  44. std::string model_path_;
  45. bool support_parallel_{false};
  46. std::string code_path_;
  47. std::string code_mode_;
  48. bool debug_mode_{false};
  49. std::string target_;
  50. };
  51. int Coder::Run(const std::string &model_path) {
  52. session_ = CreateCoderSession();
  53. if (session_ == nullptr) {
  54. MS_LOG(ERROR) << "new session failed while running!";
  55. return RET_ERROR;
  56. }
  57. STATUS status = session_->Init(model_path);
  58. if (status != RET_OK) {
  59. MS_LOG(ERROR) << "Init session failed!";
  60. return RET_ERROR;
  61. }
  62. status = session_->Build();
  63. if (status != RET_OK) {
  64. MS_LOG(ERROR) << "Compile graph failed!";
  65. return status;
  66. }
  67. status = session_->Run();
  68. if (status != RET_OK) {
  69. MS_LOG(ERROR) << "Generate Code Files error!" << status;
  70. return status;
  71. }
  72. status = session_->GenerateCode();
  73. if (status != RET_OK) {
  74. MS_LOG(ERROR) << "Generate Code Files error!" << status;
  75. }
  76. return status;
  77. }
  78. int Configurator::ParseProjDir(std::string model_path) {
  79. // split model_path to get model file name
  80. proj_dir_ = model_path;
  81. size_t found = proj_dir_.find_last_of("/\\");
  82. if (found != std::string::npos) {
  83. proj_dir_ = proj_dir_.substr(found + 1);
  84. }
  85. found = proj_dir_.find(".ms");
  86. if (found != std::string::npos) {
  87. proj_dir_ = proj_dir_.substr(0, found);
  88. } else {
  89. MS_LOG(ERROR) << "model file's name must be end with \".ms\".";
  90. return RET_ERROR;
  91. }
  92. if (proj_dir_.size() == 0) {
  93. proj_dir_ = "net";
  94. MS_LOG(WARNING) << "parse model's name failed, use \"net\" instead.";
  95. }
  96. return RET_OK;
  97. }
  98. int Coder::Init(const CoderFlags &flags) const {
  99. static const std::map<std::string, Target> kTargetMap = {
  100. {"x86", kX86}, {"ARM32M", kARM32M}, {"ARM32A", kARM32A}, {"ARM64", kARM64}, {"All", kAllTargets}};
  101. static const std::map<std::string, CodeMode> kCodeModeMap = {{"Inference", Inference}, {"Train", Train}};
  102. Configurator *config = Configurator::GetInstance();
  103. std::vector<std::function<bool()>> parsers;
  104. parsers.emplace_back([&flags, config]() -> bool {
  105. if (!FileExists(flags.model_path_)) {
  106. MS_LOG(ERROR) << "model_path \"" << flags.model_path_ << "\" is not valid";
  107. return false;
  108. }
  109. if (config->ParseProjDir(flags.model_path_) != RET_OK) {
  110. return false;
  111. }
  112. return true;
  113. });
  114. parsers.emplace_back([&flags, config]() -> bool {
  115. auto target_item = kTargetMap.find(flags.target_);
  116. MS_CHECK_TRUE_RET_BOOL(target_item != kTargetMap.end(), "unsupported target: " + flags.target_);
  117. config->set_target(target_item->second);
  118. return true;
  119. });
  120. parsers.emplace_back([&flags, config]() -> bool {
  121. auto code_item = kCodeModeMap.find(flags.code_mode_);
  122. MS_CHECK_TRUE_RET_BOOL(code_item != kCodeModeMap.end(), "unsupported code mode: " + flags.code_mode_);
  123. config->set_code_mode(code_item->second);
  124. return true;
  125. });
  126. parsers.emplace_back([&flags, config]() -> bool {
  127. if (flags.support_parallel_ && config->target() == kARM32M) {
  128. MS_LOG(ERROR) << "arm32M cannot support parallel.";
  129. return false;
  130. }
  131. config->set_support_parallel(flags.support_parallel_);
  132. return true;
  133. });
  134. parsers.emplace_back([&flags, config]() -> bool {
  135. config->set_debug_mode(flags.debug_mode_);
  136. return true;
  137. });
  138. parsers.emplace_back([&flags, config]() -> bool {
  139. const std::string slash = std::string(kSlash);
  140. if (!flags.code_path_.empty() && !DirExists(flags.code_path_)) {
  141. MS_LOG(ERROR) << "code_gen code path " << flags.code_path_ << " is not valid";
  142. return false;
  143. }
  144. config->set_code_path(flags.code_path_);
  145. if (flags.code_path_.empty()) {
  146. std::string path = ".." + slash + config->proj_dir();
  147. config->set_code_path(path);
  148. } else {
  149. if (flags.code_path_.substr(flags.code_path_.size() - 1, 1) != slash) {
  150. std::string path = flags.code_path_ + slash + config->proj_dir();
  151. config->set_code_path(path);
  152. } else {
  153. std::string path = flags.code_path_ + config->proj_dir();
  154. config->set_code_path(path);
  155. }
  156. }
  157. return InitProjDirs(flags.code_path_, config->proj_dir()) != RET_ERROR;
  158. });
  159. if (!std::all_of(parsers.begin(), parsers.end(), [](auto &parser) -> bool { return parser(); })) {
  160. if (!flags.help) {
  161. std::cerr << flags.Usage() << std::endl;
  162. return 0;
  163. }
  164. return RET_ERROR;
  165. }
  166. auto print_parameter = [](auto name, auto value) {
  167. MS_LOG(INFO) << std::setw(20) << std::left << name << "= " << value;
  168. };
  169. print_parameter("modelPath", flags.model_path_);
  170. print_parameter("projectName", config->proj_dir());
  171. print_parameter("target", config->target());
  172. print_parameter("codePath", config->code_path());
  173. print_parameter("codeMode", config->code_mode());
  174. print_parameter("debugMode", config->debug_mode());
  175. return RET_OK;
  176. }
  177. int RunCoder(int argc, const char **argv) {
  178. CoderFlags flags;
  179. Option<std::string> err = flags.ParseFlags(argc, argv, false, false);
  180. if (err.IsSome()) {
  181. std::cerr << err.Get() << std::endl;
  182. std::cerr << flags.Usage() << std::endl;
  183. return RET_ERROR;
  184. }
  185. if (flags.help) {
  186. std::cerr << flags.Usage() << std::endl;
  187. return RET_OK;
  188. }
  189. Coder code_gen;
  190. STATUS status = code_gen.Init(flags);
  191. if (status != RET_OK) {
  192. MS_LOG(ERROR) << "Coder init Error";
  193. return status;
  194. }
  195. status = code_gen.Run(flags.model_path_);
  196. if (status != RET_OK) {
  197. MS_LOG(ERROR) << "Coder Run Error.";
  198. return status;
  199. }
  200. MS_LOG(INFO) << "end of Coder";
  201. return RET_OK;
  202. }
  203. } // namespace mindspore::lite::micro