You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

callbacks_ge.cc 7.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "utils/callbacks_ge.h"
  17. #include "pybind11/pybind11.h"
  18. #include "ir/param_value_py.h"
  19. #include "transform/df_graph_manager.h"
  20. #include "transform/util.h"
  21. #include "pipeline/parse/data_converter.h"
  22. #include "pipeline/parse/python_adapter.h"
  23. #include "utils/visible.h"
  24. namespace mindspore {
  25. namespace callbacks {
  // Python module and the two functions inside it that the GE save callbacks below call into
  // (checkpoint saving and summary saving respectively).
  const char PYTHON_MOD_CALLBACK_MODULE[]{"mindspore.train.callback._callback"};
  const char PYTHON_FUN_PROCESS_CHECKPOINT[]{"checkpoint_cb_for_save_op"};
  const char PYTHON_FUN_PROCESS_SUMMARY[]{"summary_cb_for_save_op"};

  // Operator-kind names; not referenced in the visible part of this file (presumably used by other code).
  const char kSummary[]{"Summary"};
  const char kCheckPoint[]{"Save"};

  // Size of the single dimension used for the fallback / scalar shape (1,).
  const int ONE_SHAPE{1};
  32. using mindspore::transform::Status;
  33. using mindspore::transform::TransformUtil;
  34. bool GetParameterShape(const FuncGraphPtr &graph, const std::string &param_name,
  35. const std::shared_ptr<std::vector<int>> &shape) {
  36. if (graph == nullptr) {
  37. MS_LOG(ERROR) << "Graph is null, can not get graph parameter";
  38. return false;
  39. }
  40. auto parameter_nodes = graph->parameters();
  41. for (auto &node : parameter_nodes) {
  42. ParameterPtr param_node = std::static_pointer_cast<Parameter>(node);
  43. if (param_node == nullptr) {
  44. MS_LOG(ERROR) << "Parameter node is null, can not get graph parameter";
  45. return false;
  46. }
  47. if (param_node->name() == param_name) {
  48. py::object parameter;
  49. if (param_node->has_default()) {
  50. auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param_node->default_param());
  51. parameter = param_value->value();
  52. }
  53. ValuePtr value = parse::data_converter::PyDataToValue(parameter);
  54. TensorPtr tensor = std::dynamic_pointer_cast<tensor::Tensor>(value);
  55. if (tensor == nullptr) {
  56. shape->push_back(ONE_SHAPE);
  57. } else {
  58. *shape = tensor->shape();
  59. }
  60. return true;
  61. }
  62. }
  63. MS_LOG(ERROR) << "Can not find parameter of name:" << param_name;
  64. return false;
  65. }
  66. static TensorPtr GetMeTensorTransformed(uint32_t graph_id, const std::string &parameter_name,
  67. const std::shared_ptr<ge::Tensor> &ge_tensor_ptr) {
  68. FuncGraphPtr anf_graph = transform::DfGraphManager::GetInstance().GetAnfGraph(graph_id);
  69. if (anf_graph == nullptr) {
  70. MS_LOG(ERROR) << "Get anf graph failed during callback";
  71. return nullptr;
  72. }
  73. std::shared_ptr<std::vector<int>> parameter_shape_ptr = std::make_shared<std::vector<int>>();
  74. if (!GetParameterShape(anf_graph, parameter_name, parameter_shape_ptr)) {
  75. MS_LOG(ERROR) << "Can not get parameter shape during callback";
  76. return nullptr;
  77. }
  78. return TransformUtil::ConvertGeTensor(ge_tensor_ptr, *parameter_shape_ptr);
  79. }
  80. uint32_t CheckpointSaveCallback(uint32_t graph_id, const std::map<std::string, ge::Tensor> &params_list) {
  81. // Acquire GIL before calling Python code
  82. py::gil_scoped_acquire acquire;
  83. MS_LOG(DEBUG) << "Start the checkpoint save callback function in checkpoint save process.";
  84. py::list parameter_list = py::list();
  85. for (auto &item : params_list) {
  86. std::string name = item.first;
  87. std::shared_ptr<ge::Tensor> ge_tensor_ptr = std::make_shared<ge::Tensor>(item.second);
  88. if (name.size() > 5 && name.compare(name.size() - 5, 5, "_temp") == 0) {
  89. continue;
  90. } else {
  91. TensorPtr tensor_ptr = GetMeTensorTransformed(graph_id, name, ge_tensor_ptr);
  92. if (tensor_ptr == nullptr) {
  93. MS_LOG(EXCEPTION) << "Transform ge tensor to me tensor failed";
  94. }
  95. py::dict param_dict;
  96. param_dict["name"] = name;
  97. param_dict["data"] = tensor_ptr;
  98. parameter_list.append(param_dict);
  99. }
  100. }
  101. py::bool_ ret =
  102. parse::python_adapter::CallPyFn(PYTHON_MOD_CALLBACK_MODULE, PYTHON_FUN_PROCESS_CHECKPOINT, parameter_list);
  103. auto bool_ret = py::cast<bool>(ret);
  104. uint32_t status = Status::SUCCESS;
  105. if (!bool_ret) {
  106. status = Status::FAILED;
  107. MS_LOG(ERROR) << "Python checkpoint return false during callback";
  108. }
  109. return status;
  110. }
  111. static TensorPtr GetMeTensorForSummary(const std::string &name, const std::shared_ptr<ge::Tensor> &ge_tensor_ptr) {
  112. // confirm the type by name
  113. // Format: xxx[:Scalar] xxx[:Image] xxx[:Tensor]
  114. if (name.empty()) {
  115. MS_LOG(EXCEPTION) << "The summary name is empty.";
  116. }
  117. auto bpos = name.rfind("[:");
  118. if (bpos >= name.size()) {
  119. MS_LOG(EXCEPTION) << "The summary name(" << name << ") is invalid.";
  120. }
  121. auto tname = name.substr(bpos);
  122. if (tname == "[:Scalar]") {
  123. MS_LOG(DEBUG) << "The summary(" << name << ") is Scalar";
  124. // process the scalar type summary
  125. // Because the ge tensor is dim = 4, so set the (1,1,1,1)-->(1,)
  126. // We do the (1,) shape is scalar
  127. auto shape = std::vector<int>({ONE_SHAPE});
  128. return TransformUtil::ConvertGeTensor(ge_tensor_ptr, shape);
  129. }
  130. if (tname == "[:Tensor]" || tname == "[:Histogram]") {
  131. MS_LOG(DEBUG) << "The summary(" << name << ") is Tensor";
  132. // process the tensor summary
  133. // Now we can't get the real shape, so we keep same shape with GE
  134. return TransformUtil::ConvertGeTensor(ge_tensor_ptr);
  135. }
  136. if (tname == "[:Image]") {
  137. MS_LOG(DEBUG) << "The summary(" << name << ") is Image";
  138. // process the Image summary
  139. // Image dim = 4, is same with ge, so we keep same shape with GE
  140. return TransformUtil::ConvertGeTensor(ge_tensor_ptr);
  141. }
  142. MS_LOG(EXCEPTION) << "The summary name(" << name << ") is invalid.";
  143. }
  144. // Cache the summary callback data
  145. // Output Format: [{"name": tag_name, "data": tensor}, {"name": tag_name, "data": tensor},...]
  146. uint32_t MS_EXPORT SummarySaveCallback(uint32_t graph_id, const std::map<std::string, ge::Tensor> &params_list) {
  147. // Acquire GIL before calling Python code
  148. py::gil_scoped_acquire acquire;
  149. MS_LOG(DEBUG) << "Start the summary save callback function for graph " << graph_id << ".";
  150. py::list summary_list = py::list();
  151. MS_LOG(DEBUG) << "Param list size = " << params_list.size();
  152. for (auto &item : params_list) {
  153. std::string tag_name = item.first;
  154. std::shared_ptr<ge::Tensor> ge_tensor_ptr = std::make_shared<ge::Tensor>(item.second);
  155. TensorPtr tensor_ptr = GetMeTensorForSummary(tag_name, ge_tensor_ptr);
  156. if (tensor_ptr == nullptr) {
  157. MS_LOG(EXCEPTION) << "ConvertGeTensor return tensor is null";
  158. }
  159. py::dict summary_value_dict;
  160. summary_value_dict["name"] = tag_name;
  161. summary_value_dict["data"] = tensor_ptr;
  162. summary_list.append(summary_value_dict);
  163. }
  164. py::bool_ ret = parse::python_adapter::CallPyFn(PYTHON_MOD_CALLBACK_MODULE, PYTHON_FUN_PROCESS_SUMMARY, summary_list);
  165. auto bool_ret = py::cast<bool>(ret);
  166. if (!bool_ret) {
  167. MS_LOG(ERROR) << "Python checkpoint return false during callback";
  168. return Status::FAILED;
  169. }
  170. MS_LOG(DEBUG) << "End the summary save callback function.";
  171. return Status::SUCCESS;
  172. }
  173. } // namespace callbacks
  174. } // namespace mindspore