You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

backend.cc 12 kB

6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "vm/backend.h"
  17. #include <algorithm>
  18. #include <vector>
  19. #include "utils/log_adapter.h"
  20. #include "ir/anf.h"
  21. #include "utils/callbacks.h"
  22. #include "utils/graph_utils.h"
  23. #include "session/session_factory.h"
  24. #include "common/utils.h"
  25. #ifdef ENABLE_GE
  26. #include "utils/callbacks_ge.h"
  27. #endif
  28. namespace mindspore {
  29. namespace compile {
  30. bool Backend::GetCond(const BaseRef &c, bool *const value) { return BaseRefToBool(c, value); }
  31. LinConvertResult MsBackend::GetMultiGraphRun(const FuncGraphPtr &g) {
  32. // multi_graph merge to one, big graph have paramters in begin and only have one output
  33. MS_LOG(DEBUG) << "graph:" << g->ToString() << " parameter size:" << g->parameters().size();
  34. multi_result_.inputs = g->parameters();
  35. final_output_ = NewValueNode("fake_output");
  36. multi_result_.outputs = {final_output_};
  37. GraphId final_g = sess_->GetFinalRunGraph();
  38. multi_result_.run = std::make_shared<RunFunc>(
  39. [final_g, this](const VectorRef &args) -> VectorRef { return MsRunGraph(final_g, args); });
  40. return multi_result_;
  41. }
  42. LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst) {
  43. MS_LOG(DEBUG) << "MsConvert";
  44. MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
  45. auto cached = g_ConvertCache.find(lst);
  46. if (cached != g_ConvertCache.end()) {
  47. return cached->second;
  48. }
  49. LinConvertResult result;
  50. FuncGraphPtr fg;
  51. AnfNodePtrList inputs;
  52. AnfNodePtrList outputs;
  53. std::tie(fg, inputs, outputs) = TransformSegmentToAnfGraph(lst);
  54. result.inputs = inputs;
  55. result.outputs = outputs;
  56. result.graph_id = kInvalidGraphId;
  57. auto graph_id = sess_->CompileGraph(lst, outputs);
  58. if (MsContext::GetInstance()->precompile_only()) {
  59. MS_LOG(INFO) << "PrecompileOnly, stop run graph";
  60. return result;
  61. }
  62. result.run = std::make_shared<RunFunc>(
  63. [graph_id, this](const VectorRef &args) -> VectorRef { return MsRunGraph(graph_id, args); });
  64. MS_EXCEPTION_IF_NULL(result.run);
  65. result.simu_run = std::make_shared<RunFunc>(
  66. [graph_id, this](const VectorRef &args) -> VectorRef { return MsSimuRunGraph(graph_id, args); });
  67. MS_EXCEPTION_IF_NULL(result.simu_run);
  68. result.graph_id = graph_id;
  69. graph_id_map_[graph_id] = result;
  70. (void)g_ConvertCache.emplace(lst, result);
  71. return result;
  72. }
  73. void MsBackend::SetSwitchActive(const BaseRef &c, bool cond) {
  74. GraphId active_g = simu_cond_map_[c].cond_graph_map[cond];
  75. GraphId cond_g = kInvalidGraphId;
  76. if (utils::isa<AnfNodePtr>(c)) {
  77. cond_g = sess_->GetGraphIdByNode(utils::cast<AnfNodePtr>(c));
  78. } else {
  79. MS_LOG(EXCEPTION) << "cond not a anf node:" << c.ToString();
  80. }
  81. auto before_cond = curr_switch_;
  82. if (curr_switch_.hash() != c.hash()) {
  83. // invoke while false->before true call
  84. if (simu_cond_map_[before_cond].cond_graph_map.count(false)) {
  85. active_g = simu_cond_map_[before_cond].cond_graph_map[false];
  86. } else {
  87. active_g = kInvalidGraphId;
  88. }
  89. // while x < y:
  90. // z = y + 1
  91. // while z < c2:
  92. // out = out + 1
  93. // z = z + 1
  94. if (active_g == cond_g) {
  95. active_g = kInvalidGraphId;
  96. simu_cond_map_[before_cond].cond_graph_map[false] = kInvalidGraphId;
  97. }
  98. MS_LOG(DEBUG) << "invoke set active:" << active_g;
  99. }
  100. MS_LOG(DEBUG) << "switch set active:" << active_g << ", " << cond_g;
  101. sess_->SetActive(active_g, cond_g);
  102. }
  103. void MsBackend::SetSwitchGraph() {
  104. MS_LOG(DEBUG) << "SetSwitchGraph curr_switch:" << curr_switch_.ToString();
  105. if (is_switch_call_) {
  106. GraphId false_g = kInvalidGraphId;
  107. GraphId true_g = kInvalidGraphId;
  108. MS_LOG(DEBUG) << "start SetSwitchGraph";
  109. true_g = simu_cond_map_[curr_switch_].cond_graph_map[true];
  110. bool curr_cond = simu_cond_map_[curr_switch_].curr_cond;
  111. if (!curr_cond) {
  112. if (simu_cond_map_[curr_switch_].cond_graph_map.count(curr_cond)) {
  113. // has false branch
  114. false_g = simu_cond_map_[curr_switch_].cond_graph_map[false];
  115. }
  116. GraphId cond_g = kInvalidGraphId;
  117. if (utils::isa<AnfNodePtr>(curr_switch_)) {
  118. cond_g = sess_->GetGraphIdByNode(utils::cast<AnfNodePtr>(curr_switch_));
  119. } else {
  120. MS_LOG(EXCEPTION) << "cond not a anf node:" << curr_switch_.ToString();
  121. }
  122. MS_LOG(DEBUG) << "switch compile:" << cond_g << ", " << true_g << ", " << false_g;
  123. sess_->SwitchCompile(cond_g, true_g, false_g, utils::cast<AnfNodePtr>(curr_switch_));
  124. }
  125. is_switch_call_ = false;
  126. MS_LOG(DEBUG) << "end SetSwitchGraph:" << curr_cond << ", " << is_switch_call_;
  127. }
  128. }
  129. // convert node from formal parameter to actual parameter,
  130. // and actual parameter is graph user's formal parameter.
  131. // get top while graph's parameter in recall while.
  132. AnfNodePtr MsBackend::ConvertGraphInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
  133. std::unordered_map<AnfNodePtr, size_t> params_index;
  134. auto result = node;
  135. auto graph = result->func_graph();
  136. while (func_graph != graph) {
  137. auto iter = graph_user_inputs_.find(graph);
  138. if (iter == graph_user_inputs_.end()) {
  139. break;
  140. }
  141. params_index.clear();
  142. auto &params = graph->parameters();
  143. for (size_t i = 0; i < params.size(); ++i) {
  144. params_index[params[i]] = i;
  145. }
  146. graph = iter->second.first;
  147. auto &inputs = iter->second.second;
  148. result = inputs[params_index[result]];
  149. }
  150. return result;
  151. }
  152. void MsBackend::SetGraphUserInputs(const FuncGraphPtr &func_graph, const FuncGraphPtr &user,
  153. const AnfNodePtrList &inputs) {
  154. if (graph_user_inputs_.find(func_graph) != graph_user_inputs_.end()) {
  155. return;
  156. }
  157. graph_user_inputs_[func_graph] = {user, inputs};
  158. }
  159. void MsBackend::RecallGraphInput(const FuncGraphPtr &func_graph, const VectorRef &args, const BaseRef &c) {
  160. std::unordered_map<AnfNodePtr, size_t> params_index;
  161. auto &params = func_graph->parameters();
  162. for (size_t i = 0; i < params.size(); ++i) {
  163. params_index[params[i]] = i;
  164. }
  165. // recall all child graphs in this while
  166. auto &graph_inputs = graph_inputs_[c];
  167. for (auto &iter : graph_inputs) {
  168. auto &graph = iter.first;
  169. auto &old_args = iter.second;
  170. auto &result = graph_id_map_[graph];
  171. auto &inputs = result.inputs;
  172. for (size_t i = 0; i < inputs.size(); ++i) {
  173. auto input = ConvertGraphInput(func_graph, inputs[i]);
  174. auto it = params_index.find(input);
  175. if (it != params_index.end()) {
  176. old_args[i] = args[it->second];
  177. }
  178. }
  179. sess_->SetChildGraphInput(graph, old_args);
  180. }
  181. graph_inputs_.erase(c);
  182. }
  183. // compile set input output
  184. VectorRef MsBackend::MsSimuRunGraph(const GraphId &g, const VectorRef &args) {
  185. MS_LOG(DEBUG) << "set graph input:" << g;
  186. // switch maybe twice
  187. sess_->SetChildGraphInput(g, args);
  188. if (is_switch_call_) {
  189. if (!curr_switch_.is_null()) {
  190. // push this {g, args} to all user while graph_inputs for nest while,
  191. // when current condition recall over delete this cond in graph_inputs.
  192. for (auto &iter : graph_inputs_) {
  193. iter.second.push_back({g, args});
  194. }
  195. if (graph_inputs_.find(curr_switch_) == graph_inputs_.end()) {
  196. graph_inputs_[curr_switch_].push_back({g, args});
  197. }
  198. }
  199. bool curr_cond = simu_cond_map_[curr_switch_].curr_cond;
  200. MS_LOG(DEBUG) << "switch call MsSimuRunGraph:" << curr_cond << ", " << g;
  201. simu_cond_map_[curr_switch_].cond_graph_map[curr_cond] = g;
  202. SetSwitchGraph();
  203. }
  204. std::vector<BaseRef> outputs;
  205. (void)std::transform(graph_id_map_[g].outputs.begin(), graph_id_map_[g].outputs.end(), std::back_inserter(outputs),
  206. [](const AnfNodePtr &v) { return v; });
  207. return VectorRef(outputs);
  208. }
  209. VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args) {
  210. MS_LOG(DEBUG) << "start ms graph run:" << args.size() << ", g:" << g;
  211. // Run graph
  212. std::vector<tensor::TensorPtr> inputs;
  213. for (const auto &arg : args) {
  214. if (utils::isa<tensor::TensorPtr>(arg)) {
  215. auto value = utils::cast<tensor::TensorPtr>(arg);
  216. inputs.push_back(value);
  217. } else if (utils::isa<ValuePtr>(arg)) {
  218. auto value = utils::cast<ValuePtr>(arg);
  219. if (value->isa<ValueTuple>()) {
  220. (void)std::transform(value->cast<ValueTuplePtr>()->value().begin(), value->cast<ValueTuplePtr>()->value().end(),
  221. std::back_inserter(inputs),
  222. [](const ValuePtr &v) { return v->cast<tensor::TensorPtr>(); });
  223. } else if (value->isa<Scalar>()) {
  224. tensor::TensorPtr scalar_tensor = ScalarToTensor(value->cast<ScalarPtr>());
  225. MS_EXCEPTION_IF_NULL(scalar_tensor);
  226. inputs.push_back(scalar_tensor);
  227. } else {
  228. inputs.push_back(value->cast<tensor::TensorPtr>());
  229. }
  230. } else if (utils::isa<PyObjectRef>(arg)) {
  231. auto value = utils::cast<PyObjectRef>(arg).object_;
  232. inputs.push_back(py::cast<tensor::TensorPtr>(value));
  233. } else if (utils::isa<VectorRefPtr>(arg)) {
  234. auto args_new = utils::cast<VectorRef>(arg);
  235. (void)std::transform(args_new.begin(), args_new.end(), std::back_inserter(inputs),
  236. [](const BaseRef &v) { return utils::cast<tensor::TensorPtr>(v); });
  237. } else {
  238. MS_LOG(WARNING) << "Invalid input type.";
  239. }
  240. }
  241. VectorRef outputs;
  242. // call ms rungraph (graphId, input ,output)
  243. sess_->RunGraph(g, inputs, &outputs);
  244. MS_LOG(DEBUG) << "RunGraph finished:" << outputs.size();
  245. return outputs;
  246. }
  247. SwitchCondStatus MsBackend::SetSimuCond(const BaseRef &c, bool value) {
  248. MS_LOG(DEBUG) << "set cond :" << c.ToString() << ", " << simu_cond_map_.size();
  249. CondGraph cond_graph;
  250. cond_graph.curr_cond = value;
  251. if (simu_cond_map_.find(c) == simu_cond_map_.end()) {
  252. simu_cond_map_[c] = cond_graph;
  253. }
  254. if (simu_cond_map_[c].cond_graph_map.count(value)) {
  255. return kCondAlreadyRun;
  256. }
  257. simu_cond_map_[c].curr_cond = value;
  258. MS_LOG(DEBUG) << "end set cond ";
  259. return kCondOk;
  260. }
  261. void MsBackend::SimulateRun(FinalVMPtr rt, FuncGraphPtr root) {
  262. MS_LOG(DEBUG) << "Simulate run,root:" << root->ToString() << ", " << root->parameters().size();
  263. std::vector<BaseRef> args;
  264. auto parameters = root->parameters();
  265. (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args),
  266. [](const AnfNodePtr &v) { return v; });
  267. MS_LOG(DEBUG) << "Simulate start";
  268. (void)sess_->SetFinalGraphInput(parameters);
  269. BaseRef output = rt->Eval(VectorRef(args));
  270. sess_->SetFinalGraphOutput(output);
  271. MS_LOG(DEBUG) << "Simulate Eval end";
  272. }
  273. void MsBackend::Link(GraphId graph_id) {
  274. if (graph_id == kInvalidGraphId) {
  275. graph_id = sess_->GetFinalRunGraph();
  276. }
  277. sess_->BuildGraph(graph_id);
  278. }
  279. Backend::Backend(const std::string &name) : name_(name) {
  280. MS_LOG(DEBUG) << "select backend:" << name;
  281. convert_fn_ = backends[name_];
  282. is_switch_call_ = false;
  283. is_multi_graph_sink_ = false;
  284. simu_flag_ = false;
  285. }
  286. MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_t device_id) : Backend(name) {
  287. convert_fn_ = std::bind(&MsBackend::MsConvert, this, std::placeholders::_1);
  288. sess_ = session::SessionFactory::Get().Create(target);
  289. if (sess_ == nullptr) {
  290. MS_LOG(EXCEPTION) << "Session create failed!, please make sure target device:" << target << " is available.";
  291. }
  292. sess_->Init(device_id);
  293. sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
  294. }
  295. } // namespace compile
  296. } // namespace mindspore