Browse Source

Fix bug in weight sharing between host and device

tags/v1.3.0
limingqi107 4 years ago
parent
commit
85222572f3
2 changed files with 14 additions and 4 deletions
  1. +3
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h
  2. +11
    -3
      mindspore/ccsrc/runtime/framework/graph_scheduler.cc

+ 3
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h View File

@@ -90,7 +90,9 @@ class GpuKernel : public KernelMod {
if ((addr_list[index] == nullptr) || (addr_list[index]->size == 0)) {
return nullptr;
}
MS_EXCEPTION_IF_NULL(addr_list[index]->addr);
if (addr_list[index]->addr == nullptr) {
MS_LOG(EXCEPTION) << "The device address is empty, address index:" << index;
}
return reinterpret_cast<T *>(addr_list[index]->addr);
}


+ 11
- 3
mindspore/ccsrc/runtime/framework/graph_scheduler.cc View File

@@ -1640,9 +1640,11 @@ void GraphScheduler::LinkControlArrowByAutoMonad(KernelActor *to_actor, const An
}
// Link the control arrow between the kernel actors.
const auto &from_actor = dynamic_cast<KernelActor *>(FetchActor(real_depend_kernel->fullname_with_scope()));
if (from_actor == nullptr) {
continue;
}
MS_LOG(INFO) << "Link control arrow by auto monad, from actor: " << real_depend_kernel->fullname_with_scope()
<< ", to actor: " << to_actor->GetAID().Name();
MS_EXCEPTION_IF_NULL(from_actor);
from_actor->output_control_arrows_.emplace_back(to_actor->GetAID());
to_actor->input_controls_num_++;
}
@@ -2690,9 +2692,15 @@ void GraphScheduler::PersistDeviceTensor(const GraphCompilerInfo &graph_compiler

auto device_tensor = AnfAlgo::GetMutableOutputAddr(input_node, 0, false);
MS_EXCEPTION_IF_NULL(device_tensor);
DeviceTensorStore::GetInstance().Insert(front_node.get(), device_tensor);
UpdateRefCount(device_tensor.get(), true);
if (IsPersistentDeviceTensor(input_node)) {
DeviceTensorStore::GetInstance().Insert(front_node.get(), device_tensor);
UpdateRefCount(device_tensor.get(), true);
}

// When the weight is shared between the host and the device, input_node is an internal parameter and front_node is the weight.
if (!IsPersistentDeviceTensor(front_node)) {
continue;
}
// If the device tensor store has no device tensor of this device type, create a new device tensor of this type.
if (DeviceTensorStore::GetInstance().Fetch(front_node.get(), device_context->GetDeviceAddressType()) == nullptr) {
MS_LOG(INFO) << "Fetch no device tensor store by:" << front_node->fullname_with_scope()


Loading…
Cancel
Save