You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

plugin_loader.cc 4.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <numeric>
  17. #include "minddata/dataset/plugin/plugin_loader.h"
  18. #include "minddata/dataset/plugin/shared_lib_util.h"
  19. #include "mindspore/core/utils/log_adapter.h"
  20. namespace mindspore {
  21. namespace dataset {
  22. PluginLoader *PluginLoader::GetInstance() noexcept {
  23. static PluginLoader pl;
  24. return &pl;
  25. }
  26. PluginLoader::~PluginLoader() {
  27. std::vector<std::string> keys;
  28. // get the keys from map, this is to avoid concurrent iteration and delete
  29. std::transform(plugins_.begin(), plugins_.end(), std::back_inserter(keys), [](const auto &p) { return p.first; });
  30. for (std::string &key : keys) {
  31. Status rc = UnloadPlugin(key);
  32. MSLOG_IF(ERROR, rc.IsError(), mindspore::NoExceptionType) << rc.ToString();
  33. }
  34. }
  35. // LoadPlugin() is NOT thread-safe. It is supposed to be called when Ops are being built. E.g. PluginOp should call this
  36. // within constructor instead of in its Compute() which is parallel.
  37. Status PluginLoader::LoadPlugin(const std::string &filename, plugin::PluginManagerBase **singleton_plugin) {
  38. RETURN_UNEXPECTED_IF_NULL(singleton_plugin);
  39. auto itr = plugins_.find(filename);
  40. // return ok if this module is already loaded
  41. if (itr != plugins_.end()) {
  42. *singleton_plugin = itr->second.first;
  43. return Status::OK();
  44. }
  45. // Open the .so file
  46. void *handle = SharedLibUtil::Load(filename);
  47. CHECK_FAIL_RETURN_UNEXPECTED(handle != nullptr, "fail to load:" + filename + ".\n" + SharedLibUtil::ErrMsg());
  48. // Load GetInstance function ptr from the so file, so needs to be compiled with -fPIC
  49. void *func_handle = SharedLibUtil::FindSym(handle, "GetInstance");
  50. CHECK_FAIL_RETURN_UNEXPECTED(func_handle != nullptr, "fail to find GetInstance()\n" + SharedLibUtil::ErrMsg());
  51. // cast the returned function ptr of type void* to the type of GetInstance
  52. plugin::PluginManagerBase *(*get_instance)(plugin::MindDataManagerBase *) =
  53. reinterpret_cast<plugin::PluginManagerBase *(*)(plugin::MindDataManagerBase *)>(func_handle);
  54. RETURN_UNEXPECTED_IF_NULL(get_instance);
  55. *singleton_plugin = get_instance(nullptr); // call function ptr to get instance
  56. RETURN_UNEXPECTED_IF_NULL(*singleton_plugin);
  57. std::string v1 = (*singleton_plugin)->GetPluginVersion(), v2(plugin::kSharedIncludeVersion);
  58. if (v1 != v2) {
  59. std::string err_msg = "[Plugin Version Error] expected:" + v2 + ", received:" + v1 + " please recompile.";
  60. if (SharedLibUtil::Close(handle) != 0) err_msg += ("\ndlclose() error, err_msg:" + SharedLibUtil::ErrMsg() + ".");
  61. RETURN_STATUS_UNEXPECTED(err_msg);
  62. }
  63. const std::map<std::string, std::set<std::string>> module_names = (*singleton_plugin)->GetModuleNames();
  64. for (auto &p : module_names) {
  65. std::string msg = "Plugin " + p.first + " has module:";
  66. MS_LOG(DEBUG) << std::accumulate(p.second.begin(), p.second.end(), msg,
  67. [](const std::string &msg, const std::string &nm) { return msg + " " + nm; });
  68. }
  69. // save the name and handle
  70. std::pair<plugin::PluginManagerBase *, void *> plugin_new = std::make_pair(*singleton_plugin, handle);
  71. plugins_.insert({filename, plugin_new});
  72. return Status::OK();
  73. }
  74. Status PluginLoader::UnloadPlugin(const std::string &filename) {
  75. auto itr = plugins_.find(filename);
  76. RETURN_OK_IF_TRUE(itr == plugins_.end()); // return true if this plugin was never loaded or already removed
  77. void *func_handle = SharedLibUtil::FindSym(itr->second.second, "DestroyInstance");
  78. CHECK_FAIL_RETURN_UNEXPECTED(func_handle != nullptr, "fail to find DestroyInstance()\n" + SharedLibUtil::ErrMsg());
  79. void (*destroy_instance)() = reinterpret_cast<void (*)()>(func_handle);
  80. RETURN_UNEXPECTED_IF_NULL(destroy_instance);
  81. destroy_instance();
  82. CHECK_FAIL_RETURN_UNEXPECTED(SharedLibUtil::Close(itr->second.second) == 0,
  83. "dlclose() error: " + SharedLibUtil::ErrMsg());
  84. plugins_.erase(filename);
  85. return Status::OK();
  86. }
  87. } // namespace dataset
  88. } // namespace mindspore