You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

device_context_manager.cc 2.7 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "runtime/hardware/device_context_manager.h"
  17. namespace mindspore {
  18. namespace device {
  19. void DeviceContextManager::Register(const std::string &device_name, DeviceContextCreator &&device_context_creator) {
  20. if (device_context_creators_.find(device_name) == device_context_creators_.end()) {
  21. (void)device_context_creators_.emplace(device_name, device_context_creator);
  22. }
  23. }
  24. void DeviceContextManager::ClearDeviceContexts() {
  25. for (auto &iter : device_contexts_) {
  26. MS_LOG(INFO) << "Release device " << iter.first;
  27. MS_EXCEPTION_IF_NULL(iter.second);
  28. iter.second->Destroy();
  29. }
  30. device_contexts_.clear();
  31. }
  32. DeviceContext *DeviceContextManager::GetOrCreateDeviceContext(const DeviceContextKey &device_context_key) {
  33. std::string device_context_key_str = device_context_key.ToString();
  34. auto device_context_iter = device_contexts_.find(device_context_key_str);
  35. if (device_context_iter != device_contexts_.end()) {
  36. return device_context_iter->second.get();
  37. }
  38. std::shared_ptr<DeviceContext> device_context;
  39. auto creator_iter = device_context_creators_.find(device_context_key.device_name_);
  40. if (creator_iter != device_context_creators_.end()) {
  41. device_context = (creator_iter->second)(device_context_key);
  42. MS_EXCEPTION_IF_NULL(device_context);
  43. device_contexts_[device_context_key_str] = device_context;
  44. } else {
  45. MS_LOG(EXCEPTION) << "Create device context failed, please make sure target device:"
  46. << device_context_key.device_name_ << " is available.";
  47. }
  48. return device_context.get();
  49. }
  50. void DeviceContextManager::UpdateDeviceContextKey(const DeviceContextKey &old_key, const DeviceContextKey &new_key) {
  51. std::string old_key_str = old_key.ToString();
  52. std::string new_key_str = new_key.ToString();
  53. auto handle = device_contexts_.extract(old_key_str);
  54. if (handle.empty()) {
  55. MS_LOG(EXCEPTION) << "Can not find device context for: " << old_key_str;
  56. }
  57. handle.key() = new_key_str;
  58. (void)device_contexts_.insert(std::move(handle));
  59. }
  60. } // namespace device
  61. } // namespace mindspore