/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/device/stream_synchronizer.h"
#include "utils/ms_context.h"
#include "distributed/collective/collective_manager.h"
#include "runtime/recovery/recovery_context.h"

namespace mindspore {
namespace device {
std::mutex StreamSynchronizer::instance_lock_;
std::shared_ptr<StreamSynchronizer> StreamSynchronizer::instance_ = nullptr;

// Starts the background worker thread that runs DoSyncStreamTask().
// Does nothing when recovery is disabled in RecoveryContext.
void StreamSynchronizer::Initialize() {
  // Non disaster recovery mode does not need to start thread and timeout mechanisms.
  if (!runtime::recovery::RecoveryContext::GetInstance()->enable_recovery()) {
    return;
  }
  worker_thread_ = std::thread(&StreamSynchronizer::DoSyncStreamTask, this);
}

// Requests the worker thread to stop, joins it if it was started, and clears the pending device context.
void StreamSynchronizer::Finalize() {
  {
    std::unique_lock<std::mutex> lock(task_mutex_);
    stop_ = true;
  }
  // Wake the worker so that it observes stop_ and leaves its loop.
  do_sync_stream_cv_.notify_all();
  if (worker_thread_.joinable()) {
    worker_thread_.join();
  }
  device_context_ = nullptr;
}

// Synchronizes the stream of the device context identified by `device_name` and the device id stored in MsContext.
//
// device_name: name used together with MS_CTX_DEVICE_ID to look up (or create) the device context.
// timeout: maximum number of seconds to wait for the worker thread; 0 selects the direct, untimed path.
//
// Returns:
//   - the result of DeviceContext::SyncStream() when recovery is disabled or `timeout` is 0;
//   - the result reported by the worker thread when it finishes within the timeout;
//   - on timeout: false if finalizing the collective manager fails, otherwise true once the worker has finished
//     the outstanding synchronization.
// Logs at EXCEPTION level if the worker has already been stopped.
bool StreamSynchronizer::SyncStream(const std::string &device_name, uint32_t timeout) {
  // Held for the whole call, so only one SyncStream request is in flight at a time.
  std::unique_lock<std::mutex> reentrant_lock(reentrant_mutex_);
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);

  const auto &device_context =
    device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name, device_id});
  MS_EXCEPTION_IF_NULL(device_context);

  // If disable recovery or timeout==0, sync stream directly to improve performance.
  if (!runtime::recovery::RecoveryContext::GetInstance()->enable_recovery() || timeout == 0) {
    device_context->Initialize();
    return device_context->SyncStream();
  }

  std::unique_lock<std::mutex> lock(task_mutex_);
  if (stop_) {
    MS_LOG(EXCEPTION) << "The synchronization stream task has stopped";
  }
  // Publish the request; the worker treats a non-null device_context_ as pending work.
  device_context_ = device_context;
  do_sync_stream_cv_.notify_one();

  if (sync_stream_time_out_) {
    // If sync stream timeout has happened, increase the timeout by 4 times.
    const uint32_t kTimeOutScaleFactor = 4;
    timeout *= kTimeOutScaleFactor;
  }

  // The worker resets device_context_ to nullptr once the synchronization has completed. Waiting on that
  // condition (instead of on the bare notification) keeps a spurious wakeup from being mistaken for completion
  // and returning a stale sync_stream_ret_.
  const bool finished_in_time =
    time_out_cv_.wait_for(lock, std::chrono::seconds(timeout), [this]() { return device_context_ == nullptr; });
  if (finished_in_time) {
    if (!sync_stream_ret_) {
      MS_LOG(ERROR) << "Synchronize stream failed.";
    }
    return sync_stream_ret_;
  }

  // Timed out: remember it (later calls use a longer timeout), flag the collective for re-initialization and
  // finalize the collective manager.
  sync_stream_time_out_ = true;
  runtime::recovery::RecoveryContext::GetInstance()->set_need_reinit_collective(true);
  if (!distributed::collective::CollectiveManager::instance()->Finalize()) {
    MS_LOG(ERROR) << "Finalize collective manager failed.";
    // NOTE(review): this path returns while the worker may still be processing the request - confirm that callers
    // treat the failure as fatal.
    return false;
  }
  // Block until the worker has finished the outstanding synchronization.
  time_out_cv_.wait(lock, [this]() { return device_context_ == nullptr; });
  MS_LOG(WARNING) << "Synchronize stream time out.";
  return true;
}

// Worker thread body: waits for a request published by SyncStream(), synchronizes the stream of the requested
// device context, stores the result in sync_stream_ret_ and notifies the waiting caller. Returns once stop_ is set.
void StreamSynchronizer::DoSyncStreamTask() {
  for (;;) {
    // Take a snapshot of the request while holding the lock, so the worker never reads device_context_
    // unsynchronized.
    decltype(device_context_) device_context = nullptr;
    {
      std::unique_lock<std::mutex> lock(task_mutex_);
      do_sync_stream_cv_.wait(lock, [this]() { return stop_ || device_context_ != nullptr; });
      if (stop_) {
        return;
      }
      device_context = device_context_;
    }

    device_context->Initialize();
    // Really sync stream.
    const bool sync_ret = device_context->SyncStream();

    {
      // Publish the result and the completion marker together under the lock, so a waiter that observes
      // device_context_ == nullptr also observes the matching sync_stream_ret_.
      std::unique_lock<std::mutex> lock(task_mutex_);
      sync_stream_ret_ = sync_ret;
      device_context_ = nullptr;
    }
    time_out_cv_.notify_one();
  }
}
}  // namespace device
}  // namespace mindspore