Browse Source

[bugfix]SyncDeviceToHost failed when device address size is zero

tags/v0.7.0-beta
lizhenyu 5 years ago
parent
commit
1becddf3a4
2 changed files with 15 additions and 3 deletions
  1. +7
    -3
      mindspore/ccsrc/backend/session/gpu_session.cc
  2. +8
    -0
      mindspore/ccsrc/runtime/device/gpu/gpu_device_address.cc

+ 7
- 3
mindspore/ccsrc/backend/session/gpu_session.cc View File

@@ -65,6 +65,8 @@ void GPUSession::StartKernelRT() const {

void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
MS_EXCEPTION_IF_NULL(kernel_graph);
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::AdamWeightDecayFusion>());
@@ -73,9 +75,11 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
pm->AddPass(std::make_shared<opt::ReplaceBNGradCastFusion>());
pm->AddPass(std::make_shared<opt::ReplaceMomentumCastFusion>());
pm->AddPass(std::make_shared<opt::ReplaceAddNFusion>());
pm->AddPass(std::make_shared<opt::BatchNormReluFusion>());
pm->AddPass(std::make_shared<opt::BatchNormReluGradFusion>());
pm->AddPass(std::make_shared<opt::BatchNormAddReluFusion>());
if (context_ptr->execution_mode() != kPynativeMode) {
pm->AddPass(std::make_shared<opt::BatchNormReluFusion>());
pm->AddPass(std::make_shared<opt::BatchNormReluGradFusion>());
pm->AddPass(std::make_shared<opt::BatchNormAddReluFusion>());
}
optimizer->AddPassManager(pm);
(void)optimizer->Optimize(kernel_graph);
kernel_graph->SetExecOrderByDefault();


+ 8
- 0
mindspore/ccsrc/runtime/device/gpu/gpu_device_address.cc View File

@@ -32,6 +32,10 @@ namespace device {
namespace gpu {
bool GPUDeviceAddress::SyncDeviceToHost(const std::vector<int> &, size_t size, TypeId, void *host_ptr) const {
MS_EXCEPTION_IF_NULL(host_ptr);
bool need_sync = (size != 0) && (size_ != 0);
if (!need_sync) {
return true;
}
auto &stream = GPUDeviceManager::GetInstance().default_stream();
MS_EXCEPTION_IF_NULL(stream);
auto ret = GPUDeviceManager::GetInstance().SyncStream(stream);
@@ -48,6 +52,10 @@ bool GPUDeviceAddress::SyncDeviceToHost(const std::vector<int> &, size_t size, T

bool GPUDeviceAddress::SyncHostToDevice(const std::vector<int> &, size_t size, TypeId, const void *host_ptr) const {
MS_EXCEPTION_IF_NULL(host_ptr);
bool need_sync = (size != 0) && (size_ != 0);
if (!need_sync) {
return true;
}
auto &stream = GPUDeviceManager::GetInstance().default_stream();
MS_EXCEPTION_IF_NULL(stream);
if (size != size_) {


Loading…
Cancel
Save