You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

plugin_loader.cc 4.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <algorithm>
  17. #include <numeric>
  18. #include <set>
  19. #include <string>
  20. #include <vector>
  21. #include "minddata/dataset/plugin/plugin_loader.h"
  22. #include "minddata/dataset/plugin/shared_lib_util.h"
  23. #include "mindspore/core/utils/log_adapter.h"
  24. namespace mindspore {
  25. namespace dataset {
  26. PluginLoader *PluginLoader::GetInstance() noexcept {
  27. static PluginLoader pl;
  28. return &pl;
  29. }
  30. PluginLoader::~PluginLoader() {
  31. std::vector<std::string> keys;
  32. // get the keys from map, this is to avoid concurrent iteration and delete
  33. std::transform(plugins_.begin(), plugins_.end(), std::back_inserter(keys), [](const auto &p) { return p.first; });
  34. for (std::string &key : keys) {
  35. Status rc = UnloadPlugin(key);
  36. MSLOG_IF(ERROR, rc.IsError(), mindspore::NoExceptionType) << rc.ToString();
  37. }
  38. }
  39. // LoadPlugin() is NOT thread-safe. It is supposed to be called when Ops are being built. E.g. PluginOp should call this
  40. // within constructor instead of in its Compute() which is parallel.
  41. Status PluginLoader::LoadPlugin(const std::string &filename, plugin::PluginManagerBase **singleton_plugin) {
  42. RETURN_UNEXPECTED_IF_NULL(singleton_plugin);
  43. auto itr = plugins_.find(filename);
  44. // return ok if this module is already loaded
  45. if (itr != plugins_.end()) {
  46. *singleton_plugin = itr->second.first;
  47. return Status::OK();
  48. }
  49. // Open the .so file
  50. void *handle = SharedLibUtil::Load(filename);
  51. CHECK_FAIL_RETURN_UNEXPECTED(handle != nullptr, "fail to load:" + filename + ".\n" + SharedLibUtil::ErrMsg());
  52. // Load GetInstance function ptr from the so file, so needs to be compiled with -fPIC
  53. void *func_handle = SharedLibUtil::FindSym(handle, "GetInstance");
  54. CHECK_FAIL_RETURN_UNEXPECTED(func_handle != nullptr, "fail to find GetInstance()\n" + SharedLibUtil::ErrMsg());
  55. // cast the returned function ptr of type void* to the type of GetInstance
  56. plugin::PluginManagerBase *(*get_instance)(plugin::MindDataManagerBase *) =
  57. reinterpret_cast<plugin::PluginManagerBase *(*)(plugin::MindDataManagerBase *)>(func_handle);
  58. RETURN_UNEXPECTED_IF_NULL(get_instance);
  59. *singleton_plugin = get_instance(nullptr); // call function ptr to get instance
  60. RETURN_UNEXPECTED_IF_NULL(*singleton_plugin);
  61. std::string v1 = (*singleton_plugin)->GetPluginVersion(), v2(plugin::kSharedIncludeVersion);
  62. // Version check, if version are not the same, log the error and return fail
  63. if (v1 != v2) {
  64. std::string err_msg = "[Plugin Version Error] expected:" + v2 + ", received:" + v1 + " please recompile.";
  65. if (SharedLibUtil::Close(handle) != 0) err_msg += ("\ndlclose() error, err_msg:" + SharedLibUtil::ErrMsg() + ".");
  66. RETURN_STATUS_UNEXPECTED(err_msg);
  67. }
  68. const std::map<std::string, std::set<std::string>> module_names = (*singleton_plugin)->GetModuleNames();
  69. for (auto &p : module_names) {
  70. std::string msg = "Plugin " + p.first + " has module:";
  71. MS_LOG(DEBUG) << std::accumulate(p.second.begin(), p.second.end(), msg,
  72. [](const std::string &msg, const std::string &nm) { return msg + " " + nm; });
  73. }
  74. // save the name and handle
  75. plugins_.insert({filename, {*singleton_plugin, handle}});
  76. return Status::OK();
  77. }
  78. Status PluginLoader::UnloadPlugin(const std::string &filename) {
  79. auto itr = plugins_.find(filename);
  80. RETURN_OK_IF_TRUE(itr == plugins_.end()); // return true if this plugin was never loaded or already removed
  81. void *func_handle = SharedLibUtil::FindSym(itr->second.second, "DestroyInstance");
  82. CHECK_FAIL_RETURN_UNEXPECTED(func_handle != nullptr, "fail to find DestroyInstance()\n" + SharedLibUtil::ErrMsg());
  83. void (*destroy_instance)() = reinterpret_cast<void (*)()>(func_handle);
  84. RETURN_UNEXPECTED_IF_NULL(destroy_instance);
  85. destroy_instance();
  86. CHECK_FAIL_RETURN_UNEXPECTED(SharedLibUtil::Close(itr->second.second) == 0,
  87. "dlclose() error: " + SharedLibUtil::ErrMsg());
  88. plugins_.erase(filename);
  89. return Status::OK();
  90. }
  91. } // namespace dataset
  92. } // namespace mindspore