You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

backend.cc 14 kB

6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "vm/backend.h"
  17. #include <algorithm>
  18. #include <vector>
  19. #include "utils/log_adapter.h"
  20. #include "ir/anf.h"
  21. #include "utils/callbacks.h"
  22. #include "utils/graph_utils.h"
  23. #include "utils/base_ref_extends.h"
  24. #include "session/session_factory.h"
  25. #include "common/utils.h"
  26. #ifdef ENABLE_GE
  27. #include "utils/callbacks_ge.h"
  28. #endif
  29. namespace mindspore {
  30. namespace compile {
  31. bool Backend::GetCond(const BaseRef &c, bool *const value) { return BaseRefToBool(c, value); }
  32. LinConvertResult MsBackend::GetMultiGraphRun(const FuncGraphPtr &g) {
  33. // multi_graph merge to one, big graph have paramters in begin and only have one output
  34. MS_LOG(DEBUG) << "graph:" << g->ToString() << " parameter size:" << g->parameters().size();
  35. multi_result_.inputs = g->parameters();
  36. final_output_ = NewValueNode("fake_output");
  37. multi_result_.outputs = {final_output_};
  38. GraphId final_g = target_sess_->GetFinalRunGraph();
  39. multi_result_.run = std::make_shared<RunFunc>(
  40. [final_g, this](const VectorRef &args) -> VectorRef { return MsRunGraph(final_g, args, ""); });
  41. return multi_result_;
  42. }
  43. LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::string &target) {
  44. MS_LOG(DEBUG) << "MsConvert";
  45. MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
  46. auto cached = g_ConvertCache.find(lst);
  47. if (cached != g_ConvertCache.end()) {
  48. return cached->second;
  49. }
  50. LinConvertResult result;
  51. FuncGraphPtr fg;
  52. AnfNodePtrList inputs;
  53. AnfNodePtrList outputs;
  54. std::tie(fg, inputs, outputs) = TransformSegmentToAnfGraph(lst);
  55. result.inputs = inputs;
  56. result.outputs = outputs;
  57. result.graph_id = kInvalidGraphId;
  58. GraphId graph_id = kInvalidGraphId;
  59. if (target != target_device_ && !target.empty()) {
  60. CreateOtherSession(target);
  61. graph_id = other_sess_->CompileGraph(lst, outputs);
  62. } else {
  63. graph_id = target_sess_->CompileGraph(lst, outputs);
  64. }
  65. if (MsContext::GetInstance()->precompile_only()) {
  66. MS_LOG(INFO) << "PrecompileOnly, stop run graph";
  67. return result;
  68. }
  69. if (target != target_device_ && !target.empty()) {
  70. other_sess_->BuildGraph(graph_id);
  71. } else if (!is_multi_graph_sink_) {
  72. target_sess_->BuildGraph(graph_id);
  73. }
  74. result.run = std::make_shared<RunFunc>(
  75. [graph_id, target, this](const VectorRef &args) -> VectorRef { return MsRunGraph(graph_id, args, target); });
  76. MS_EXCEPTION_IF_NULL(result.run);
  77. result.simu_run = std::make_shared<RunFunc>(
  78. [graph_id, this](const VectorRef &args) -> VectorRef { return MsSimuRunGraph(graph_id, args); });
  79. MS_EXCEPTION_IF_NULL(result.simu_run);
  80. result.graph_id = graph_id;
  81. graph_id_map_[graph_id] = result;
  82. (void)g_ConvertCache.emplace(lst, result);
  83. return result;
  84. }
  85. void MsBackend::SetSwitchActive(const BaseRef &c, bool cond) {
  86. GraphId active_g = simu_cond_map_[c].cond_graph_map[cond];
  87. GraphId cond_g = kInvalidGraphId;
  88. if (utils::isa<AnfNodePtr>(c)) {
  89. cond_g = target_sess_->GetGraphIdByNode(utils::cast<AnfNodePtr>(c));
  90. } else {
  91. MS_LOG(EXCEPTION) << "cond not a anf node:" << c.ToString();
  92. }
  93. auto before_cond = curr_switch_;
  94. if (curr_switch_.hash() != c.hash()) {
  95. // invoke while false->before true call
  96. if (simu_cond_map_[before_cond].cond_graph_map.count(false)) {
  97. active_g = simu_cond_map_[before_cond].cond_graph_map[false];
  98. } else {
  99. active_g = kInvalidGraphId;
  100. }
  101. // while x < y:
  102. // z = y + 1
  103. // while z < c2:
  104. // out = out + 1
  105. // z = z + 1
  106. if (active_g == cond_g) {
  107. active_g = kInvalidGraphId;
  108. simu_cond_map_[before_cond].cond_graph_map[false] = kInvalidGraphId;
  109. }
  110. MS_LOG(DEBUG) << "invoke set active:" << active_g;
  111. }
  112. MS_LOG(DEBUG) << "switch set active:" << active_g << ", " << cond_g;
  113. target_sess_->SetActive(active_g, cond_g);
  114. }
  115. void MsBackend::SetSwitchGraph() {
  116. MS_LOG(DEBUG) << "SetSwitchGraph curr_switch:" << curr_switch_.ToString();
  117. if (is_switch_call_) {
  118. GraphId false_g = kInvalidGraphId;
  119. GraphId true_g = kInvalidGraphId;
  120. MS_LOG(DEBUG) << "start SetSwitchGraph";
  121. true_g = simu_cond_map_[curr_switch_].cond_graph_map[true];
  122. bool curr_cond = simu_cond_map_[curr_switch_].curr_cond;
  123. if (!curr_cond) {
  124. if (simu_cond_map_[curr_switch_].cond_graph_map.count(curr_cond)) {
  125. // has false branch
  126. false_g = simu_cond_map_[curr_switch_].cond_graph_map[false];
  127. }
  128. GraphId cond_g = kInvalidGraphId;
  129. if (utils::isa<AnfNodePtr>(curr_switch_)) {
  130. cond_g = target_sess_->GetGraphIdByNode(utils::cast<AnfNodePtr>(curr_switch_));
  131. } else {
  132. MS_LOG(EXCEPTION) << "cond not a anf node:" << curr_switch_.ToString();
  133. }
  134. MS_LOG(DEBUG) << "switch compile:" << cond_g << ", " << true_g << ", " << false_g;
  135. target_sess_->SwitchCompile(cond_g, true_g, false_g, utils::cast<AnfNodePtr>(curr_switch_));
  136. }
  137. is_switch_call_ = false;
  138. MS_LOG(DEBUG) << "end SetSwitchGraph:" << curr_cond << ", " << is_switch_call_;
  139. }
  140. }
  141. // convert node from formal parameter to actual parameter,
  142. // and actual parameter is graph user's formal parameter.
  143. // get top while graph's parameter in recall while.
  144. AnfNodePtr MsBackend::ConvertGraphInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
  145. std::unordered_map<AnfNodePtr, size_t> params_index;
  146. auto result = node;
  147. auto graph = result->func_graph();
  148. while (func_graph != graph) {
  149. auto iter = graph_user_inputs_.find(graph);
  150. if (iter == graph_user_inputs_.end()) {
  151. break;
  152. }
  153. params_index.clear();
  154. auto &params = graph->parameters();
  155. for (size_t i = 0; i < params.size(); ++i) {
  156. params_index[params[i]] = i;
  157. }
  158. graph = iter->second.first;
  159. auto &inputs = iter->second.second;
  160. result = inputs[params_index[result]];
  161. }
  162. return result;
  163. }
  164. void MsBackend::SetGraphUserInputs(const FuncGraphPtr &func_graph, const FuncGraphPtr &user,
  165. const AnfNodePtrList &inputs) {
  166. if (graph_user_inputs_.find(func_graph) != graph_user_inputs_.end()) {
  167. return;
  168. }
  169. graph_user_inputs_[func_graph] = {user, inputs};
  170. }
  171. void MsBackend::RecallGraphInput(const FuncGraphPtr &func_graph, const VectorRef &args, const BaseRef &c) {
  172. std::unordered_map<AnfNodePtr, size_t> params_index;
  173. auto &params = func_graph->parameters();
  174. for (size_t i = 0; i < params.size(); ++i) {
  175. params_index[params[i]] = i;
  176. }
  177. // recall all child graphs in this while
  178. auto &graph_inputs = graph_inputs_[c];
  179. for (auto &iter : graph_inputs) {
  180. auto &graph = iter.first;
  181. auto &old_args = iter.second;
  182. auto &result = graph_id_map_[graph];
  183. auto &inputs = result.inputs;
  184. for (size_t i = 0; i < inputs.size(); ++i) {
  185. auto input = ConvertGraphInput(func_graph, inputs[i]);
  186. auto it = params_index.find(input);
  187. if (it != params_index.end()) {
  188. old_args[i] = args[it->second];
  189. }
  190. }
  191. target_sess_->SetChildGraphInput(graph, old_args);
  192. }
  193. graph_inputs_.erase(c);
  194. }
  195. // compile set input output
  196. VectorRef MsBackend::MsSimuRunGraph(const GraphId &g, const VectorRef &args) {
  197. MS_LOG(DEBUG) << "set graph input:" << g;
  198. // switch maybe twice
  199. target_sess_->SetChildGraphInput(g, args);
  200. if (is_switch_call_) {
  201. if (!curr_switch_.is_null()) {
  202. // push this {g, args} to all user while graph_inputs for nest while,
  203. // when current condition recall over delete this cond in graph_inputs.
  204. for (auto &iter : graph_inputs_) {
  205. iter.second.push_back({g, args});
  206. }
  207. if (graph_inputs_.find(curr_switch_) == graph_inputs_.end()) {
  208. graph_inputs_[curr_switch_].push_back({g, args});
  209. }
  210. }
  211. bool curr_cond = simu_cond_map_[curr_switch_].curr_cond;
  212. MS_LOG(DEBUG) << "switch call MsSimuRunGraph:" << curr_cond << ", " << g;
  213. simu_cond_map_[curr_switch_].cond_graph_map[curr_cond] = g;
  214. SetSwitchGraph();
  215. }
  216. std::vector<BaseRef> outputs;
  217. (void)std::transform(graph_id_map_[g].outputs.begin(), graph_id_map_[g].outputs.end(), std::back_inserter(outputs),
  218. [](const AnfNodePtr &v) { return v; });
  219. return VectorRef(outputs);
  220. }
  221. VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target) {
  222. MS_LOG(DEBUG) << "start ms graph run:" << args.size() << ", g:" << g;
  223. // Run graph
  224. std::vector<tensor::TensorPtr> inputs;
  225. for (const auto &arg : args) {
  226. if (utils::isa<tensor::TensorPtr>(arg)) {
  227. auto value = utils::cast<tensor::TensorPtr>(arg);
  228. inputs.push_back(value);
  229. } else if (utils::isa<ValuePtr>(arg)) {
  230. auto value = utils::cast<ValuePtr>(arg);
  231. if (value->isa<ValueTuple>()) {
  232. (void)std::transform(value->cast<ValueTuplePtr>()->value().begin(), value->cast<ValueTuplePtr>()->value().end(),
  233. std::back_inserter(inputs),
  234. [](const ValuePtr &v) { return v->cast<tensor::TensorPtr>(); });
  235. } else if (value->isa<Scalar>()) {
  236. tensor::TensorPtr scalar_tensor = ScalarToTensor(value->cast<ScalarPtr>());
  237. MS_EXCEPTION_IF_NULL(scalar_tensor);
  238. inputs.push_back(scalar_tensor);
  239. } else {
  240. inputs.push_back(value->cast<tensor::TensorPtr>());
  241. }
  242. } else if (utils::isa<PyObjectRef>(arg)) {
  243. auto value = utils::cast<PyObjectRef>(arg).object_;
  244. inputs.push_back(py::cast<tensor::TensorPtr>(value));
  245. } else if (utils::isa<VectorRefPtr>(arg)) {
  246. auto args_new = utils::cast<VectorRef>(arg);
  247. (void)std::transform(args_new.begin(), args_new.end(), std::back_inserter(inputs),
  248. [](const BaseRef &v) { return utils::cast<tensor::TensorPtr>(v); });
  249. } else {
  250. MS_LOG(WARNING) << "Invalid input type.";
  251. }
  252. }
  253. VectorRef outputs;
  254. // call ms rungraph (graphId, input ,output)
  255. if (target != target_device_ && !target.empty()) {
  256. other_sess_->RunGraph(g, inputs, &outputs);
  257. } else {
  258. target_sess_->RunGraph(g, inputs, &outputs);
  259. }
  260. MS_LOG(DEBUG) << "RunGraph finished:" << outputs.size();
  261. return outputs;
  262. }
  263. SwitchCondStatus MsBackend::SetSimuCond(const BaseRef &c, bool value) {
  264. MS_LOG(DEBUG) << "set cond :" << c.ToString() << ", " << simu_cond_map_.size();
  265. CondGraph cond_graph;
  266. cond_graph.curr_cond = value;
  267. if (simu_cond_map_.find(c) == simu_cond_map_.end()) {
  268. simu_cond_map_[c] = cond_graph;
  269. }
  270. if (simu_cond_map_[c].cond_graph_map.count(value)) {
  271. return kCondAlreadyRun;
  272. }
  273. simu_cond_map_[c].curr_cond = value;
  274. MS_LOG(DEBUG) << "end set cond ";
  275. return kCondOk;
  276. }
  277. void MsBackend::SimulateRun(FinalVMPtr rt, FuncGraphPtr root) {
  278. MS_LOG(DEBUG) << "Simulate run,root:" << root->ToString() << ", " << root->parameters().size();
  279. std::vector<BaseRef> args;
  280. auto parameters = root->parameters();
  281. (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args),
  282. [](const AnfNodePtr &v) { return v; });
  283. MS_LOG(DEBUG) << "Simulate start";
  284. (void)target_sess_->SetFinalGraphInput(parameters);
  285. BaseRef output = rt->Eval(VectorRef(args));
  286. target_sess_->SetFinalGraphOutput(output);
  287. MS_LOG(DEBUG) << "Simulate Eval end";
  288. }
  289. void MsBackend::Link(GraphId graph_id) {
  290. if (graph_id == kInvalidGraphId) {
  291. graph_id = target_sess_->GetFinalRunGraph();
  292. }
  293. target_sess_->BuildGraph(graph_id);
  294. }
  295. Backend::Backend(const std::string &name) : name_(name) {
  296. MS_LOG(DEBUG) << "select backend:" << name;
  297. convert_fn_ = backends[name_];
  298. is_switch_call_ = false;
  299. is_multi_graph_sink_ = false;
  300. simu_flag_ = false;
  301. }
  302. MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_t device_id) : Backend(name) {
  303. convert_fn_ = std::bind(&MsBackend::MsConvert, this, std::placeholders::_1, std::placeholders::_2);
  304. target_sess_ = session::SessionFactory::Get().Create(target);
  305. if (target_sess_ == nullptr) {
  306. MS_LOG(EXCEPTION) << "Session create failed!, please make sure target device:" << target << " is available.";
  307. }
  308. target_sess_->Init(device_id);
  309. target_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
  310. target_device_ = target;
  311. }
  312. void MsBackend::CreateOtherSession(const std::string &target) {
  313. if (other_sess_ != nullptr && other_device_ == target) {
  314. return;
  315. }
  316. other_sess_ = session::SessionFactory::Get().Create(kCPUDevice);
  317. if (other_sess_ == nullptr) {
  318. MS_LOG(EXCEPTION) << "Session create failed!, please make sure target device:" << target << " is available.";
  319. }
  320. other_sess_->Init(0);
  321. other_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
  322. other_device_ = target;
  323. }
  324. GraphId MsBackend::CompileGraph(NotNull<FuncGraphPtr> fg) { return target_sess_->CompileGraph(fg); }
  325. VectorRef MsBackend::RunGraph(GraphId graph_id, const VectorRef &args) { return MsRunGraph(graph_id, args); }
  326. } // namespace compile
  327. } // namespace mindspore