Browse Source

Bug fix for summary nodes and PyNative bp

tags/v1.3.0
lizhenyu 4 years ago
parent
commit
3d11311ec1
3 changed files with 12 additions and 1 deletion
  1. +6
    -0
      mindspore/ccsrc/runtime/framework/actor/memory_manager_actor.cc
  2. +1
    -0
      mindspore/ccsrc/runtime/framework/graph_compiler.cc
  3. +5
    -1
      mindspore/ccsrc/runtime/hardware/gpu/gpu_device_context.cc

+ 6
- 0
mindspore/ccsrc/runtime/framework/actor/memory_manager_actor.cc View File

@@ -83,6 +83,9 @@ void MemoryManagerActor::FreeMemory(std::vector<DeviceTensor *> *free_list, cons
MS_EXCEPTION_IF_NULL(device_context);
for (auto &device_tensor : *free_list) {
MS_EXCEPTION_IF_NULL(device_tensor);
if (device_tensor->original_ref_count() == SIZE_MAX) {
continue;
}
// The reference count is decremented to zero to free memory, and reset to the original count.
device_tensor->DecreaseRefCount();
if (device_tensor->ref_count() == 0) {
@@ -111,6 +114,9 @@ void MemoryManagerActor::FreeBatchMemory(std::vector<DeviceTensor *> *free_list,
auto &device_context = (*device_contexts)[i];
MS_EXCEPTION_IF_NULL(device_tensor);
MS_EXCEPTION_IF_NULL(device_context);
if (device_tensor->original_ref_count() == SIZE_MAX) {
continue;
}
// The reference count is decremented to zero to free memory, and reset to the original count.
device_tensor->DecreaseRefCount();
if (device_tensor->ref_count() == 0) {


+ 1
- 0
mindspore/ccsrc/runtime/framework/graph_compiler.cc View File

@@ -248,6 +248,7 @@ void SetSummaryNodesRefCount(const KernelGraph *graph) {
auto device_address = AnfAlgo::GetMutableOutputAddr(node, index, false);
MS_EXCEPTION_IF_NULL(device_address);
device_address->set_original_ref_count(SIZE_MAX);
device_address->ResetRefCount();
}
}
} // namespace


+ 5
- 1
mindspore/ccsrc/runtime/hardware/gpu/gpu_device_context.cc View File

@@ -200,6 +200,8 @@ void GPUDeviceContext::OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &grap
}

void GPUDeviceContext::OptimizeGraphWithDeviceInfo(const KernelGraphPtr &graph) const {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
// Graph optimization relevant to device data format
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
@@ -211,7 +213,9 @@ void GPUDeviceContext::OptimizeGraphWithDeviceInfo(const KernelGraphPtr &graph)
pm->AddPass(std::make_shared<opt::InsertFormatTransformOp>());
pm->AddPass(std::make_shared<opt::RemoveFormatTransformPair>());
pm->AddPass(std::make_shared<opt::RemoveRedundantFormatTransform>());
pm->AddPass(std::make_shared<opt::CudnnInplaceAggregate>());
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
pm->AddPass(std::make_shared<opt::CudnnInplaceAggregate>());
}
pm->AddPass(std::make_shared<opt::ReluV2Pass>());
pm->AddPass(std::make_shared<opt::AddReluV2Fusion>());
pm->AddPass(std::make_shared<opt::AddReluGradV2Fusion>());


Loading…
Cancel
Save