You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

graph_runner.cc 7.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * Limitations under the License.
  15. */
  16. #include "transform/graph_runner.h"
  17. #include <algorithm>
  18. #include <string>
  19. #include <memory>
  20. #include "utils/log_adapter.h"
  21. #include "utils/config_manager.h"
  22. #include "sys/time.h"
  23. #include "utils/callbacks.h"
  24. #include "utils/utils.h"
  25. #include "./common.h"
  26. #ifdef ENABLE_GE
  27. #include "utils/callbacks_ge.h"
  28. #endif
  29. #ifdef NO_GE_CLIENT
  30. namespace ge {
  31. Session::Session(const std::map<std::string, std::string> &options) {
  32. if (options.empty()) {
  33. MS_LOG(ERROR) << "session input options is empty";
  34. }
  35. sessionId_ = 0;
  36. }
  37. Session::~Session() {}
  38. } // namespace ge
  39. #endif
  40. namespace mindspore {
  41. namespace transform {
  42. std::shared_ptr<ge::Session> GraphRunner::NewSession(const SessionOptions &sess_options) {
  43. std::shared_ptr<ge::Session> ret = std::make_shared<ge::Session>(sess_options);
  44. if (ret == nullptr) {
  45. MS_LOG(ERROR) << "Create GE session failed";
  46. return nullptr;
  47. }
  48. MS_LOG(INFO) << "Create new GE session success";
  49. return ret;
  50. }
  51. GraphRunner::GraphRunner(const GraphRunnerOptions &options)
  52. : options_(options), graph_manager_(DfGraphManager::GetInstance()) {
  53. if (ConfigManager::GetInstance().parallel_strategy() == ParallelStrategy::ONE_DEVICE) {
  54. MS_LOG(INFO) << "ME run in ONE_DEVICE strategy mode";
  55. }
  56. if (options.sess_ptr != nullptr) {
  57. sess_ = options.sess_ptr;
  58. } else {
  59. sess_ = NewSession(options.options);
  60. if (sess_ == nullptr) {
  61. MS_LOG(EXCEPTION) << "GraphRunner initialize failed!!";
  62. return;
  63. }
  64. }
  65. #if (defined ENABLE_GE)
  66. // register the callback function
  67. if (sess_->RegisterCallBackFunc(callbacks::kCheckPoint, callbacks::CheckpointSaveCallback) != ge::GRAPH_SUCCESS) {
  68. MS_LOG(EXCEPTION) << "register callback failed!";
  69. return;
  70. }
  71. if (sess_->RegisterCallBackFunc(callbacks::kSummary, callbacks::SummarySaveCallback) != ge::GRAPH_SUCCESS) {
  72. MS_LOG(EXCEPTION) << "register summary callback failed!";
  73. return;
  74. }
  75. #endif
  76. std::vector<DfGraphWrapperPtr> wrappers = graph_manager_.GetAllGraphs();
  77. if (wrappers.empty()) {
  78. MS_LOG(INFO) << "The GraphManager is empty!!";
  79. return;
  80. }
  81. #ifdef ENABLE_GE
  82. for (auto &it : wrappers) {
  83. std::set<string> saved_graph = graph_manager_.GetSavedGraphs();
  84. auto iter_find = saved_graph.find(std::to_string(it->id_));
  85. if (iter_find != saved_graph.end()) {
  86. continue;
  87. }
  88. MS_LOG(INFO) << "Add the graph " << (*it).name_ << " to GE, it's id is: " << (*it).id_;
  89. graph_manager_.AddSavedGraphs(std::to_string(it->id_));
  90. (void)sess_->AddGraph(it->id_, *(it->graph_ptr_), it->options_);
  91. }
  92. #endif
  93. }
  94. Status GraphRunner::RunGraph(const RunOptions &options, const std::vector<GeTensorPtr> &inputs,
  95. std::vector<GeTensorPtr> *outputs) {
  96. std::string name = options.name;
  97. if (name.empty()) {
  98. MS_LOG(ERROR) << "The graph name is null";
  99. return Status::INVALID_ARGUMENT;
  100. }
  101. DfGraphWrapperPtr wrap_ptr = graph_manager_.GetGraphByName(name);
  102. if (wrap_ptr == nullptr) {
  103. MS_LOG(ERROR) << "Get graph form DfGraphManager failed!";
  104. return Status::NOT_FOUND;
  105. }
  106. if (wrap_ptr->graph_ptr_ == nullptr) {
  107. MS_LOG(WARNING) << "The graph is null";
  108. return Status::NOT_FOUND;
  109. }
  110. // call ge::RunGraph() to exec a graph;
  111. std::vector<GeTensor> ge_inputs;
  112. std::vector<GeTensor> ge_outputs;
  113. (void)std::transform(inputs.begin(), inputs.end(), std::back_inserter(ge_inputs),
  114. [](const GeTensorPtr &i) { return *i; });
  115. MS_LOG(INFO) << "Run the graph in GE with " << ge_inputs.size() << " inputs";
  116. struct timeval start_time, end_time;
  117. (void)gettimeofday(&start_time, nullptr);
  118. #ifdef ENABLE_GE
  119. if (sess_ == nullptr) {
  120. MS_LOG(ERROR) << "The GE session is null, can't run the graph!";
  121. return Status::FAILED;
  122. }
  123. // The information of some nodes could be changed after fusion in some cases
  124. // Therefore a graph needs to be rebuilt in above situation
  125. if (sess_->IsGraphNeedRebuild(wrap_ptr->id_)) {
  126. sess_->RemoveGraph(wrap_ptr->id_);
  127. sess_->AddGraph(wrap_ptr->id_, *(wrap_ptr->graph_ptr_), wrap_ptr->options_);
  128. }
  129. ge::Status ret = sess_->RunGraph(wrap_ptr->id_, ge_inputs, ge_outputs);
  130. if (ret != ge::GRAPH_SUCCESS) {
  131. MS_LOG(ERROR) << "Call GE RunGraph Failed, ret is: " << ret;
  132. return Status::FAILED;
  133. }
  134. #else
  135. ge_outputs.swap(ge_inputs);
  136. #endif
  137. (void)gettimeofday(&end_time, nullptr);
  138. const uint64_t kUSecondInSecond = 1000000;
  139. uint64_t cost = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
  140. cost += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
  141. MS_LOG(INFO) << "Call GE RunGraph Success in " << cost << " us, the GE outputs num is: " << ge_outputs.size();
  142. (void)std::transform(ge_outputs.begin(), ge_outputs.end(), std::back_inserter(*outputs),
  143. [](const GeTensor &ge_tensor) { return std::make_shared<GeTensor>(ge_tensor); });
  144. return Status::SUCCESS;
  145. }
  146. Status GraphRunner::RunGraph(const RunOptions &options, const std::vector<MeTensorPtr> &inputs,
  147. std::vector<MeTensorPtr> *const outputs) {
  148. std::vector<GeTensorPtr> ge_inputs;
  149. for (auto it : inputs) {
  150. MS_LOG(INFO) << "inputs tensor's data size is: " << (*it).DataSize();
  151. auto shape = (*it).shape();
  152. std::string shape_str;
  153. for (const auto &elem : shape) {
  154. shape_str += std::to_string(elem);
  155. shape_str += " ";
  156. }
  157. MS_LOG(INFO) << "inputs tensor's shape is: { " << shape_str << "}";
  158. auto ge_tensor_ptr = TransformUtil::ConvertTensor(it, kOpFormat_NCHW);
  159. if (ge_tensor_ptr != nullptr) {
  160. ge_inputs.emplace_back(ge_tensor_ptr);
  161. } else {
  162. MS_LOG(INFO) << "Convert input Me tensor to Ge tensor failed. Abort this graph";
  163. return Status::FAILED;
  164. }
  165. }
  166. std::vector<GeTensorPtr> ge_outputs;
  167. Status ret;
  168. {
  169. // Release GIL before calling into (potentially long-running) C++ code
  170. py::gil_scoped_release release;
  171. ret = RunGraph(options, ge_inputs, &ge_outputs);
  172. }
  173. if (ret != Status::SUCCESS) {
  174. return ret;
  175. } else {
  176. // conver GeTensor to MeTensor
  177. for (auto &it : ge_outputs) {
  178. auto tensor = TransformUtil::ConvertGeTensor(it);
  179. if (tensor != nullptr) {
  180. outputs->emplace_back(tensor);
  181. }
  182. }
  183. MS_LOG(INFO) << "Return Me tensor outputs num is: " << outputs->size();
  184. return Status::SUCCESS;
  185. }
  186. }
  187. } // namespace transform
  188. } // namespace mindspore