|
|
|
@@ -64,8 +64,7 @@ bool IsPersistentDeviceTensor(const AnfNodePtr &node) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
KernelActor *FindKernelActor(const std::unordered_map<std::string, KernelActorPtr> &kernel_actors_map, |
|
|
|
const std::string &name) { |
|
|
|
KernelActor *FindKernelActor(const KernelMapActor &kernel_actors_map, const std::string &name) { |
|
|
|
auto iter = kernel_actors_map.find(name); |
|
|
|
if (iter != kernel_actors_map.end()) { |
|
|
|
return iter->second.get(); |
|
|
|
@@ -185,6 +184,16 @@ void PrepareDataForWeightNode(const AnfNodePtr &node, const TensorPtr &tensor, c |
|
|
|
MS_EXCEPTION_IF_NULL(tensor); |
|
|
|
const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(node, 0); |
|
|
|
MS_EXCEPTION_IF_NULL(device_tensor); |
|
|
|
const auto &host_tensor_address = std::dynamic_pointer_cast<DeviceTensor>(tensor->device_address()); |
|
|
|
// If the host tensor has the device address, it indicates that the device address of host tensor is new. |
|
|
|
if (host_tensor_address != nullptr) { |
|
|
|
if (host_tensor_address != device_tensor) { |
|
|
|
AnfAlgo::SetOutputAddr(host_tensor_address, 0, node.get()); |
|
|
|
DeviceTensorStore::GetInstance().Insert(node.get(), host_tensor_address); |
|
|
|
} |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
// If the ptr of device tensor is not nullptr, it indicates that the device data has been prepared. |
|
|
|
if (device_tensor->GetPtr() != nullptr) { |
|
|
|
return; |
|
|
|
@@ -382,8 +391,9 @@ void GraphScheduler::PrepareRun(const KernelGraphPtr &graph, const std::vector<T |
|
|
|
|
|
|
|
// 2.Prepare the data of host tensor queue(non weighted parameters of graph). |
|
|
|
const auto &host_tensor_queue = FetchHostQueue(graph); |
|
|
|
MS_EXCEPTION_IF_NULL(host_tensor_queue); |
|
|
|
host_tensor_queue->PushData(host_tensors); |
|
|
|
if (host_tensor_queue != nullptr) { |
|
|
|
host_tensor_queue->PushData(host_tensors); |
|
|
|
} |
|
|
|
|
|
|
|
// 3.Prepare the output tensor of graph. |
|
|
|
for (const auto &output_node : graph->outputs()) { |
|
|
|
@@ -472,7 +482,7 @@ ActorSetPtr GraphScheduler::Build(const KernelGraphPtr &graph, const DeviceConte |
|
|
|
void GraphScheduler::Link(ActorSet *actor_set, const KernelGraphPtr &graph, GraphExecutionStrategy strategy) { |
|
|
|
MS_EXCEPTION_IF_NULL(actor_set); |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
std::unordered_map<std::string, KernelActorPtr> kernel_actors_temp_map; |
|
|
|
KernelMapActor kernel_actors_temp_map; |
|
|
|
for (auto &actor : actor_set->kernel_actors_) { |
|
|
|
MS_EXCEPTION_IF_NULL(actor); |
|
|
|
kernel_actors_temp_map.emplace(actor->GetAID().Name(), actor); |
|
|
|
@@ -488,8 +498,15 @@ void GraphScheduler::Link(ActorSet *actor_set, const KernelGraphPtr &graph, Grap |
|
|
|
// Link the control arrows of kernel actor. |
|
|
|
LinkControlArrowForKernelActor(kernel_actor, actor_set->loop_count_actor_.get(), graph, strategy); |
|
|
|
|
|
|
|
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { |
|
|
|
KernelWithIndex from_kernel_with_output_idx = AnfAlgo::GetPrevNodeOutput(kernel, i, true); |
|
|
|
for (size_t i = 0; i < AnfAlgo::GetInputNum(kernel); ++i) { |
|
|
|
auto input_node = AnfAlgo::GetInputNode(kernel, i); |
|
|
|
// Link the control arrows of kernel actor by the auto monad, the inputs include monad node. |
|
|
|
LinkControlArrowByAutoMonad(kernel_actor, input_node, kernel_actors_temp_map); |
|
|
|
if (HasAbstractMonad(input_node)) { |
|
|
|
continue; // No data arrow for monad input. |
|
|
|
} |
|
|
|
|
|
|
|
KernelWithIndex from_kernel_with_output_idx = AnfAlgo::VisitKernelWithReturnType(input_node, 0, true); |
|
|
|
KernelWithIndex to_kernel_with_input_idx = std::make_pair(kernel, i); |
|
|
|
auto from_kernel = from_kernel_with_output_idx.first; |
|
|
|
|
|
|
|
@@ -583,6 +600,8 @@ std::vector<KernelActorPtr> GraphScheduler::BuildNoInputKernelActor(const Kernel |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_actor); |
|
|
|
if ((kernel_actor->input_datas_num_ == 0) && (kernel_actor->input_controls_num_ == 0)) { |
|
|
|
no_input_kernel_actors.emplace_back(kernel_actor); |
|
|
|
// The no input kernel actor will be triggered by loop count actor, so need set the input_controls_num_. |
|
|
|
kernel_actor->input_controls_num_ = 1; |
|
|
|
} |
|
|
|
} |
|
|
|
return no_input_kernel_actors; |
|
|
|
@@ -689,6 +708,58 @@ void GraphScheduler::LinkControlArrowForKernelActor(KernelActor *from_actor, Loo |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// Links control arrows into `to_actor` for the dependency chain reachable from `from_node`.
//
// Starting at `from_node`, the function visits through to the nearest UpdateState, Load or MakeTuple node:
//   - UpdateState: continue with its input at index kUpdateStateRealInput.
//   - Load:        continue with its input at index kLoadStateInput.
//   - MakeTuple:   recurse into every input from index 1 onwards, then stop.
//   - otherwise:   nothing to link.
// When the selected input is itself an UpdateState/Load/MakeTuple CNode, the function recurses on it. Otherwise the
// kernel actor named by that input's fullname_with_scope() is looked up in `kernel_actors_map`, `to_actor`'s AID is
// appended to that actor's output_op_controls_, and `to_actor->input_controls_num_` is incremented.
//
// `to_actor`, `from_node`, every intermediate node and the looked-up source actor are checked with
// MS_EXCEPTION_IF_NULL. Inputs that are not CNodes end the traversal silently.
void GraphScheduler::LinkControlArrowByAutoMonad(KernelActor *to_actor, const AnfNodePtr &from_node,
                                                 const KernelMapActor &kernel_actors_map) {
  MS_EXCEPTION_IF_NULL(to_actor);
  MS_EXCEPTION_IF_NULL(from_node);
  if (!from_node->isa<CNode>()) {
    return;
  }

  // Visit through to the real input node, stopping at monad nodes and make tuple nodes.
  const std::vector<PrimitivePtr> stop_types{prim::kPrimUpdateState, prim::kPrimLoad, prim::kPrimMakeTuple};
  const auto visited_node = AnfAlgo::VisitKernelWithReturnType(from_node, 0, true, stop_types).first;
  MS_EXCEPTION_IF_NULL(visited_node);
  if (!visited_node->isa<CNode>()) {
    return;
  }
  const auto monad_cnode = visited_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(monad_cnode);

  // Pick the input of the monad node that the control arrow really depends on.
  AnfNodePtr depend_node = nullptr;
  if (AnfAlgo::CheckPrimitiveType(monad_cnode, prim::kPrimUpdateState)) {
    depend_node = monad_cnode->input(kUpdateStateRealInput);
  } else if (AnfAlgo::CheckPrimitiveType(monad_cnode, prim::kPrimLoad)) {
    depend_node = monad_cnode->input(kLoadStateInput);
  } else {
    if (AnfAlgo::CheckPrimitiveType(monad_cnode, prim::kPrimMakeTuple)) {
      // A make tuple node is expanded: each input after index 0 is processed on its own.
      for (size_t input_idx = 1; input_idx < monad_cnode->inputs().size(); ++input_idx) {
        LinkControlArrowByAutoMonad(to_actor, monad_cnode->input(input_idx), kernel_actors_map);
      }
    }
    // Either the tuple has been expanded above, or the node is of no interest here.
    return;
  }

  MS_EXCEPTION_IF_NULL(depend_node);
  if (!depend_node->isa<CNode>()) {
    return;
  }

  // Another monad node or make tuple node is followed recursively instead of being linked directly.
  const bool needs_further_expansion = AnfAlgo::CheckPrimitiveType(depend_node, prim::kPrimUpdateState) ||
                                       AnfAlgo::CheckPrimitiveType(depend_node, prim::kPrimLoad) ||
                                       AnfAlgo::CheckPrimitiveType(depend_node, prim::kPrimMakeTuple);
  if (needs_further_expansion) {
    LinkControlArrowByAutoMonad(to_actor, depend_node, kernel_actors_map);
    return;
  }

  // Link the control arrow from the depended-on kernel actor to `to_actor`.
  auto source_actor = FindKernelActor(kernel_actors_map, depend_node->fullname_with_scope());
  MS_EXCEPTION_IF_NULL(source_actor);
  source_actor->output_op_controls_.emplace_back(to_actor->GetAID());
  ++to_actor->input_controls_num_;
}
|
|
|
|
|
|
|
void GraphScheduler::LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const KernelGraphPtr &graph) { |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
MS_EXCEPTION_IF_NULL(loop_count_actor); |
|
|
|
@@ -784,7 +855,6 @@ HostTensorQueue *GraphScheduler::FetchHostQueue(const KernelGraphPtr &graph) con |
|
|
|
if (iter != graph_to_host_queue_.end()) { |
|
|
|
return iter->second.get(); |
|
|
|
} else { |
|
|
|
MS_LOG(ERROR) << "Can't find the host tensor queue map of graph: " << graph->ToString(); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
} |
|
|
|
|