You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

callbacks.cc 2.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "utils/callbacks.h"
  17. #include <map>
  18. #include <string>
  19. #include <memory>
  20. #include "pybind11/pybind11.h"
  21. #include "pipeline/jit/parse/python_adapter.h"
  22. #include "utils/visible.h"
  23. namespace mindspore {
  24. namespace callbacks {
  25. const char PYTHON_MOD_CALLBACK_MODULE[] = "mindspore.train.callback._callback";
  26. const char PYTHON_FUN_PROCESS_CHECKPOINT[] = "checkpoint_cb_for_save_op";
  27. const char PYTHON_FUN_PROCESS_SUMMARY[] = "summary_cb_for_save_op";
  28. const char kSummary[] = "Summary";
  29. const char kCheckPoint[] = "Save";
  30. const int ONE_SHAPE = 1;
  31. // Cache the summary callback data from ME session
  32. // Remove the GE module on new architecture
  33. // Output Format: [{"name": tag_name, "data": tensor}, {"name": tag_name, "data": tensor},...]
  34. uint32_t MS_EXPORT SummarySaveCallback(uint32_t graph_id, const std::map<std::string, TensorPtr> &params_list) {
  35. // Acquire GIL before calling Python code
  36. py::gil_scoped_acquire acquire;
  37. py::list summary_list = py::list();
  38. MS_LOG(INFO) << "The Summary save callback function for graph " << graph_id
  39. << ", Param list size = " << params_list.size() << ".";
  40. for (auto &item : params_list) {
  41. std::string tag_name = item.first;
  42. auto tensor_ptr = item.second;
  43. if (tensor_ptr == nullptr) {
  44. MS_LOG(EXCEPTION) << "Summary tensor is null";
  45. }
  46. py::dict summary_value_dict;
  47. summary_value_dict["name"] = tag_name;
  48. summary_value_dict["data"] = tensor_ptr;
  49. summary_list.append(summary_value_dict);
  50. }
  51. py::bool_ ret = parse::python_adapter::CallPyFn(PYTHON_MOD_CALLBACK_MODULE, PYTHON_FUN_PROCESS_SUMMARY, summary_list);
  52. auto bool_ret = py::cast<bool>(ret);
  53. if (!bool_ret) {
  54. MS_LOG(ERROR) << "Python checkpoint return false during callback";
  55. return kCallbackFalied;
  56. }
  57. MS_LOG(DEBUG) << "End the summary save callback function.";
  58. return kCallbackOk;
  59. }
  60. } // namespace callbacks
  61. } // namespace mindspore