You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

plugin_loader.cc 4.8 kB

4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "minddata/dataset/plugin/plugin_loader.h"
  17. #include <algorithm>
  18. #include <numeric>
  19. #include <set>
  20. #include <vector>
  21. #include "mindspore/core/utils/log_adapter.h"
  22. #include "minddata/dataset/plugin/shared_lib_util.h"
  23. namespace mindspore {
  24. namespace dataset {
  25. PluginLoader *PluginLoader::GetInstance() noexcept {
  26. static PluginLoader pl;
  27. return &pl;
  28. }
  29. PluginLoader::~PluginLoader() {
  30. std::vector<std::string> keys;
  31. // get the keys from map, this is to avoid concurrent iteration and delete
  32. std::transform(plugins_.begin(), plugins_.end(), std::back_inserter(keys), [](const auto &p) { return p.first; });
  33. for (std::string &key : keys) {
  34. Status rc = UnloadPlugin(key);
  35. MSLOG_IF(ERROR, rc.IsError(), mindspore::NoExceptionType) << rc.ToString();
  36. }
  37. }
  38. // LoadPlugin() is NOT thread-safe. It is supposed to be called when Ops are being built. E.g. PluginOp should call this
  39. // within constructor instead of in its Compute() which is parallel.
  40. Status PluginLoader::LoadPlugin(const std::string &filename, plugin::PluginManagerBase **singleton_plugin) {
  41. RETURN_UNEXPECTED_IF_NULL(singleton_plugin);
  42. auto itr = plugins_.find(filename);
  43. // return ok if this module is already loaded
  44. if (itr != plugins_.end()) {
  45. *singleton_plugin = itr->second.first;
  46. return Status::OK();
  47. }
  48. // Open the .so file
  49. void *handle = SharedLibUtil::Load(filename);
  50. CHECK_FAIL_RETURN_UNEXPECTED(handle != nullptr,
  51. "[Internal ERROR] Fail to load:" + filename + ".\n" + SharedLibUtil::ErrMsg());
  52. // Load GetInstance function ptr from the so file, so needs to be compiled with -fPIC
  53. void *func_handle = SharedLibUtil::FindSym(handle, "GetInstance");
  54. CHECK_FAIL_RETURN_UNEXPECTED(func_handle != nullptr,
  55. "[Internal ERROR] Fail to find GetInstance()\n" + SharedLibUtil::ErrMsg());
  56. // cast the returned function ptr of type void* to the type of GetInstance
  57. plugin::PluginManagerBase *(*get_instance)(plugin::MindDataManagerBase *) =
  58. reinterpret_cast<plugin::PluginManagerBase *(*)(plugin::MindDataManagerBase *)>(func_handle);
  59. RETURN_UNEXPECTED_IF_NULL(get_instance);
  60. *singleton_plugin = get_instance(nullptr); // call function ptr to get instance
  61. RETURN_UNEXPECTED_IF_NULL(*singleton_plugin);
  62. std::string v1 = (*singleton_plugin)->GetPluginVersion(), v2(plugin::kSharedIncludeVersion);
  63. if (v1 != v2) {
  64. std::string err_msg = "[Internal ERROR] expected:" + v2 + ", received:" + v1 + " please recompile.";
  65. if (SharedLibUtil::Close(handle) != 0) err_msg += ("\ndlclose() error, err_msg:" + SharedLibUtil::ErrMsg() + ".");
  66. RETURN_STATUS_UNEXPECTED(err_msg);
  67. }
  68. const std::map<std::string, std::set<std::string>> module_names = (*singleton_plugin)->GetModuleNames();
  69. for (auto &p : module_names) {
  70. std::string msg = "Plugin " + p.first + " has module:";
  71. MS_LOG(DEBUG) << std::accumulate(p.second.begin(), p.second.end(), msg,
  72. [](const std::string &msg, const std::string &nm) { return msg + " " + nm; });
  73. }
  74. // save the name and handle
  75. std::pair<plugin::PluginManagerBase *, void *> plugin_new = std::make_pair(*singleton_plugin, handle);
  76. plugins_.insert({filename, plugin_new});
  77. return Status::OK();
  78. }
  79. Status PluginLoader::UnloadPlugin(const std::string &filename) {
  80. auto itr = plugins_.find(filename);
  81. RETURN_OK_IF_TRUE(itr == plugins_.end()); // return true if this plugin was never loaded or already removed
  82. void *func_handle = SharedLibUtil::FindSym(itr->second.second, "DestroyInstance");
  83. CHECK_FAIL_RETURN_UNEXPECTED(func_handle != nullptr,
  84. "[Internal ERROR] Fail to find DestroyInstance()\n" + SharedLibUtil::ErrMsg());
  85. void (*destroy_instance)() = reinterpret_cast<void (*)()>(func_handle);
  86. RETURN_UNEXPECTED_IF_NULL(destroy_instance);
  87. destroy_instance();
  88. CHECK_FAIL_RETURN_UNEXPECTED(SharedLibUtil::Close(itr->second.second) == 0,
  89. "[Internal ERROR] dlclose() error: " + SharedLibUtil::ErrMsg());
  90. plugins_.erase(filename);
  91. return Status::OK();
  92. }
  93. } // namespace dataset
  94. } // namespace mindspore