| @@ -38,7 +38,7 @@ void GatherActor::Init() { | |||
| auto data = std::make_unique<OpData<DeviceTensor>>(data_arrow->to_op_id_, nullptr, data_arrow->to_input_index_); | |||
| output_data_.emplace_back(data.get()); | |||
| output_data_by_output_index_[data_arrow->from_output_index_].emplace_back(std::move(data)); | |||
| output_data_by_output_index_[IntToSize(data_arrow->from_output_index_)].emplace_back(std::move(data)); | |||
| } | |||
| } | |||
| @@ -113,7 +113,7 @@ void GatherActor::SendOutput(OpContext<DeviceTensor> *context) const { | |||
| // 2.Send output result. | |||
| for (const auto &result_arrow : output_result_arrows_) { | |||
| MS_EXCEPTION_IF_NULL(result_arrow); | |||
| size_t from_index = result_arrow->from_output_index_; | |||
| size_t from_index = IntToSize(result_arrow->from_output_index_); | |||
| const auto &front_node = data_nodes_[from_index].first; | |||
| for (const auto &backend_node : front_to_backend_parameter_.at(front_node)) { | |||
| if (AnfAlgo::GetMutableOutputAddr(backend_node.first, backend_node.second, false).get() == | |||
| @@ -85,7 +85,7 @@ class GatherActor : public OpActor<DeviceTensor> { | |||
| // The device tensors for launch. | |||
| std::vector<DeviceTensor *> input_device_tensors_; | |||
| // The branch if for current step. | |||
| int input_branch_id_; | |||
| int input_branch_id_{kInvalidBranchID}; | |||
| // Input data. | |||
| std::unordered_map<uuids::uuid *, std::unordered_map<size_t, std::stack<DeviceTensor *>>> input_data_; | |||
| @@ -208,7 +208,7 @@ void SwitchActor::AddInput(const KernelWithIndex node_with_index, const size_t b | |||
| branch_inputs_pos_[branch].push_back(iter - input_nodes_.begin()); | |||
| return; | |||
| } | |||
| device_tensor_store_keys_.push_back({input_nodes_.size(), node.get()}); | |||
| device_tensor_store_keys_.emplace_back(input_nodes_.size(), node.get()); | |||
| branch_inputs_pos_[branch].push_back(input_nodes_.size()); | |||
| input_nodes_.push_back(node_with_index); | |||
| return; | |||
| @@ -245,7 +245,7 @@ void SwitchActor::AddInput(const AnfNodePtr &node, const size_t branch) { | |||
| } else if (IsCallNode(real_input.first)) { | |||
| std::vector<AnfNodePtr> call_nodes; | |||
| const auto call_output_num = FetchOutputSizebyCallNode(real_input.first, &call_nodes); | |||
| if (call_output_num <= 0) { | |||
| if (call_output_num == 0) { | |||
| MS_LOG(EXCEPTION) << "Invalid output num for call input:" << AnfAlgo::GetNodeDebugString(real_input.first); | |||
| } | |||
| for (size_t i = 0; i < call_output_num; ++i) { | |||
| @@ -268,7 +268,7 @@ size_t SwitchActor::GetIndex(OpContext<DeviceTensor> *context) { | |||
| input_branch_ids_[context->sequential_num_].empty()) { | |||
| MS_LOG(ERROR) << "Invalid branch id for actor:" + GetAID().Name(); | |||
| } | |||
| size_t branch_id = input_branch_ids_[context->sequential_num_].top(); | |||
| auto branch_id = input_branch_ids_[context->sequential_num_].top(); | |||
| input_branch_ids_[context->sequential_num_].pop(); | |||
| if (branch_id_to_index_.find(branch_id) == branch_id_to_index_.end()) { | |||
| MS_LOG(ERROR) << "Invalid branch id for switch actor:" + GetAID().Name() + | |||
| @@ -278,9 +278,8 @@ size_t SwitchActor::GetIndex(OpContext<DeviceTensor> *context) { | |||
| } | |||
| DeviceTensor *device_tensor = input_device_tensors_[0]; | |||
| if (device_tensor == nullptr) { | |||
| MS_LOG(ERROR) << "Index of switch actor is empty:" + GetAID().Name(); | |||
| } | |||
| MS_EXCEPTION_IF_NULL(device_tensor); | |||
| auto inputs = node_->inputs(); | |||
| TypeId type_id = AnfAlgo::GetOutputInferDataType(inputs[kSwitchCondPos], 0); | |||
| size_t size = abstract::TypeIdSize(type_id); | |||
| @@ -291,7 +290,10 @@ size_t SwitchActor::GetIndex(OpContext<DeviceTensor> *context) { | |||
| int64_t index = 0; | |||
| char buf[kMaxSwitchCondSize] = {0}; | |||
| ShapeVector host_shape; | |||
| device_tensor->SyncDeviceToHost(host_shape, size, type_id, static_cast<void *>(buf)); | |||
| if (!device_tensor->SyncDeviceToHost(host_shape, size, type_id, static_cast<void *>(buf))) { | |||
| MS_LOG(ERROR) << GetAID().Name() + " get index from device address failed, type id:" + std::to_string(type_id) + | |||
| ", device type:" + std::to_string(static_cast<int>(device_context_->GetDeviceAddressType())); | |||
| } | |||
| if (type_id == TypeId::kNumberTypeInt32) { | |||
| index = static_cast<int64_t>((static_cast<int32_t *>(static_cast<void *>(buf)))[0]); | |||
| @@ -306,7 +308,7 @@ size_t SwitchActor::GetIndex(OpContext<DeviceTensor> *context) { | |||
| // SwitchLayer node support negative index range [-size, -1]. | |||
| if (index < 0) { | |||
| index += branch_func_graph_.size(); | |||
| index += SizeToInt(branch_func_graph_.size()); | |||
| } | |||
| return static_cast<size_t>(index); | |||
| } | |||
| @@ -403,7 +405,7 @@ void SwitchActor::SendOutput(OpContext<DeviceTensor> *context) { | |||
| " total:" + std::to_string(branch_inputs_pos_[index].size()) + " actor:" + GetAID().Name(); | |||
| SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info); | |||
| } | |||
| size_t from_index = branch_inputs_pos_[index][result_arrow->from_output_index_]; | |||
| size_t from_index = branch_inputs_pos_[index][IntToSize(result_arrow->from_output_index_)]; | |||
| MS_LOG(DEBUG) << "Switch actor:" << GetAID() << " send result addr:" << input_device_tensors_[from_index]; | |||
| bool is_send = false; | |||
| @@ -556,7 +556,7 @@ bool IsSubCallNode(const AnfNodePtr &node) { | |||
| std::vector<KernelWithIndex> FetchAllRealInputNodeByParameter(const KernelWithIndex &node) { | |||
| std::vector<KernelWithIndex> parameters; | |||
| const auto &real_node_with_index = AnfAlgo::VisitKernelWithReturnType(node.first, node.second); | |||
| const auto &real_node_with_index = AnfAlgo::VisitKernelWithReturnType(node.first, SizeToInt(node.second)); | |||
| const auto &real_node = real_node_with_index.first; | |||
| if (real_node->isa<Parameter>()) { | |||
| if (!HasAbstractRef(real_node) && !HasAbstractMonad(real_node)) { | |||
| @@ -749,7 +749,7 @@ void ControlNodeParser::Parse(const std::vector<AnfNodePtr> &control_nodes, cons | |||
| } | |||
| } | |||
| FetchFrontToBackendParameter(graphs, device_contexts, control_nodes, real_to_formal_front_parameters, | |||
| FetchFrontToBackendParameter(graphs, device_contexts, real_to_formal_front_parameters, | |||
| formal_to_real_front_parameters); | |||
| FetchFuncGraphToParameter(control_nodes); | |||
| @@ -924,14 +924,14 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &contr | |||
| << AnfAlgo::GetNodeDebugString(parameters[i - kCallInputStartPos]) | |||
| << ", used the default format"; | |||
| CreateDeviceTensorForFrontParameter(inputs[i], device_contexts[0]); | |||
| front_value_nodes_.push_back({inputs[i], device_contexts[0]}); | |||
| front_value_nodes_.emplace_back(inputs[i], device_contexts[0]); | |||
| continue; | |||
| } | |||
| const auto &backend_node = front_to_backend_parameters_[parameters[i - kCallInputStartPos]].first; | |||
| const auto &device_context = front_to_backend_parameters_[parameters[i - kCallInputStartPos]].second; | |||
| CreateDeviceTensorForValueNode(inputs[i], backend_node, device_context); | |||
| front_value_nodes_.push_back({inputs[i], device_context}); | |||
| front_value_nodes_.emplace_back(inputs[i], device_context); | |||
| } | |||
| } | |||
| } | |||
| @@ -940,18 +940,15 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &contr | |||
| for (size_t index = 0; index < graphs.size(); ++index) { | |||
| const auto &graph = graphs[index]; | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| auto execution_order = graph->execution_order(); | |||
| for (const auto ¶meter : graph->input_nodes()) { | |||
| const auto &front_node = graph->GetFrontAnfByBackendAnf(parameter); | |||
| const auto &internal_node = graph->GetFrontNodeByInternalParameter(parameter); | |||
| MS_EXCEPTION_IF_NULL(parameter); | |||
| if (IsInternalParameter(parameter, graph)) { | |||
| auto front_node_with_index = graph->GetFrontNodeByInternalParameter(parameter); | |||
| MS_EXCEPTION_IF_NULL(front_node_with_index.first); | |||
| const auto &front_output_with_index = | |||
| AnfAlgo::VisitKernelWithReturnType(front_node_with_index.first, front_node_with_index.second, false); | |||
| const auto &front_output_with_index = AnfAlgo::VisitKernelWithReturnType( | |||
| front_node_with_index.first, SizeToInt(front_node_with_index.second), false); | |||
| auto front_output_node = front_output_with_index.first; | |||
| MS_EXCEPTION_IF_NULL(front_output_node); | |||
| if (AnfAlgo::CheckPrimitiveType(front_output_node, prim::kPrimSwitch)) { | |||
| @@ -959,7 +956,7 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &contr | |||
| FetchValueNodeBySwitchNode(front_output_node, &value_nodes); | |||
| for (const auto value_node : value_nodes) { | |||
| CreateDeviceTensorForValueNode(value_node, parameter, device_contexts[index]); | |||
| front_value_nodes_.push_back({value_node, device_contexts[index]}); | |||
| front_value_nodes_.emplace_back(value_node, device_contexts[index]); | |||
| } | |||
| } | |||
| } | |||
| @@ -975,7 +972,7 @@ void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &contr | |||
| if (output->isa<ValueNode>() && GetFrontValueNodeDeviceContext(output) == nullptr) { | |||
| const auto &device_context = call_node_to_backend_parameter.second.second; | |||
| CreateDeviceTensorForValueNode(output, call_node_to_backend_parameter.second.first, device_context); | |||
| front_value_nodes_.push_back({output, device_context}); | |||
| front_value_nodes_.emplace_back(output, device_context); | |||
| } | |||
| } | |||
| } | |||
| @@ -1054,7 +1051,7 @@ void ControlNodeParser::FetchFrontToFrontParameter( | |||
| if (AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitch) || | |||
| AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitchLayer)) { | |||
| std::vector<AnfNodePtr> call_inputs; | |||
| call_inputs.assign(inputs.begin() + kCallInputStartPos, inputs.end()); | |||
| call_inputs.assign(inputs.begin() + SizeToInt(kCallInputStartPos), inputs.end()); | |||
| switch_input_parse(inputs[0], call_inputs); | |||
| } else if (IsCallNode(inputs[0])) { | |||
| continue; | |||
| @@ -1095,7 +1092,7 @@ std::vector<AnfNodePtr> ControlNodeParser::FetchControlNodeParameter(const std:: | |||
| if (backend_iter == front_to_backend_parameters_.end()) { | |||
| CreateDeviceTensorForFrontParameter(parameter, device_context); | |||
| front_to_backend_parameters_[parameter] = {parameter, device_context}; | |||
| front_parameters_.push_back({parameter, device_context}); | |||
| front_parameters_.emplace_back(parameter, device_context); | |||
| } | |||
| } | |||
| @@ -1215,7 +1212,6 @@ std::vector<KernelWithIndex> FetchParameterbyKernelGraph(const KernelGraphPtr &g | |||
| void ControlNodeParser::FetchFrontToBackendParameter(const std::vector<KernelGraphPtr> &graphs, | |||
| const std::vector<DeviceContext *> &device_contexts, | |||
| const std::vector<AnfNodePtr> &control_nodes, | |||
| const RealToFormalNode &real_to_formal_front_parameters, | |||
| const RealToFormalNode &formal_to_real_front_parameters) { | |||
| if (graphs.size() != device_contexts.size()) { | |||
| @@ -1356,12 +1352,12 @@ void ControlNodeParser::FetchBackendOutputByFrontOutput(const AnfNodePtr &front_ | |||
| std::set<AnfNodePtr> *switch_nodes, | |||
| std::set<KernelWithIndex> *results) { | |||
| if (front_output->isa<ValueNode>()) { | |||
| (*results).insert({front_output, 0}); | |||
| (*results).emplace(front_output, 0); | |||
| const auto &iter = formal_to_real_parameters_.find(front_output); | |||
| if (iter != formal_to_real_parameters_.end()) { | |||
| for (const auto &node : iter->second) { | |||
| (*results).insert(node); | |||
| (*results).emplace(node); | |||
| } | |||
| } | |||
| } else if (front_output->isa<Parameter>()) { | |||
| @@ -1370,7 +1366,7 @@ void ControlNodeParser::FetchBackendOutputByFrontOutput(const AnfNodePtr &front_ | |||
| if (iter != formal_to_real_parameters_.end()) { | |||
| for (const auto &node : iter->second) { | |||
| (*results).insert(node); | |||
| (*results).emplace(node); | |||
| } | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Cannot find backend node for front parameter:" << AnfAlgo::GetNodeDebugString(front_output); | |||
| @@ -1402,7 +1398,7 @@ void ControlNodeParser::FetchBackendOutputByFrontOutput(const AnfNodePtr &front_ | |||
| const auto iter = front_to_backend_kernels_.find(AnfAlgo::VisitKernelWithReturnType(front_output, 0)); | |||
| if (iter != front_to_backend_kernels_.end()) { | |||
| (*results).insert(iter->second.first); | |||
| (*results).emplace(iter->second.first); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Cannot find backend node for front kernel:" << AnfAlgo::GetNodeDebugString(front_output); | |||
| } | |||
| @@ -1429,7 +1425,7 @@ void ControlNodeParser::FetchBackendInputNodebyFrontNode( | |||
| MS_LOG(WARNING) << "Cannot find backend node of node:" << AnfAlgo::GetNodeDebugString(node_with_index.first); | |||
| continue; | |||
| } | |||
| formal_to_real_parameters_[formal_parameter].push_back({iter->second.first, 0}); | |||
| formal_to_real_parameters_[formal_parameter].emplace_back(iter->second.first, 0); | |||
| } else { | |||
| const auto iter = front_to_backend_kernels_.find(node_with_index); | |||
| if (iter == front_to_backend_kernels_.end()) { | |||
| @@ -1439,7 +1435,7 @@ void ControlNodeParser::FetchBackendInputNodebyFrontNode( | |||
| } | |||
| } | |||
| } else if (real_parameter->isa<ValueNode>()) { | |||
| formal_to_real_parameters_[formal_parameter].push_back({real_parameter, 0}); | |||
| formal_to_real_parameters_[formal_parameter].emplace_back(real_parameter, 0); | |||
| } else if (IsCallNode(real_parameter)) { | |||
| const auto func_graphs = FetchFuncGraphbyCallNode(real_parameter); | |||
| for (const auto func_graph : func_graphs) { | |||
| @@ -1512,7 +1508,7 @@ void ControlNodeParser::FetchBackendInputNode(const std::vector<KernelGraphPtr> | |||
| auto front_node = graph->GetFrontAnfByBackendAnf(value_node); | |||
| if (front_node != nullptr) { | |||
| formal_to_real_parameters_[front_node].push_back({value_node, 0}); | |||
| formal_to_real_parameters_[front_node].emplace_back(value_node, 0); | |||
| } | |||
| } | |||
| } | |||
| @@ -1521,7 +1517,7 @@ void ControlNodeParser::FetchBackendInputNode(const std::vector<KernelGraphPtr> | |||
| for (const auto &front_weight : host_parameter_to_weight.second) { | |||
| const auto &iter = front_to_backend_parameters_.find(host_parameter_to_weight.first); | |||
| if (iter != front_to_backend_parameters_.end()) { | |||
| formal_to_real_parameters_[front_weight].push_back({iter->second.first, 0}); | |||
| formal_to_real_parameters_[front_weight].emplace_back(iter->second.first, 0); | |||
| } | |||
| } | |||
| } | |||
| @@ -1552,10 +1548,10 @@ void ControlNodeParser::FetchBackendInputNode(const std::vector<KernelGraphPtr> | |||
| } | |||
| } | |||
| for (const auto parameter_pair : front_to_backend_parameters) { | |||
| formal_to_real_parameters_[parameter_pair.first].push_back({parameter_pair.second.first, 0}); | |||
| formal_to_real_parameters_[parameter_pair.first].emplace_back(parameter_pair.second.first, 0); | |||
| } | |||
| for (const auto parameter_pair : front_to_backend_parameters_) { | |||
| formal_to_real_parameters_[parameter_pair.first].push_back({parameter_pair.second.first, 0}); | |||
| formal_to_real_parameters_[parameter_pair.first].emplace_back(parameter_pair.second.first, 0); | |||
| } | |||
| } | |||
| @@ -150,7 +150,6 @@ class ControlNodeParser { | |||
| // 2. The parameter from control nodes. | |||
| void FetchFrontToBackendParameter(const std::vector<KernelGraphPtr> &graphs, | |||
| const std::vector<DeviceContext *> &device_contexts, | |||
| const std::vector<AnfNodePtr> &control_nodes, | |||
| const RealToFormalNode &real_to_formal_front_parameters, | |||
| const RealToFormalNode &formal_to_real_front_parameters); | |||
| // Get the relationship between the front and backend of the executable kernel in all kernel graphs. | |||
| @@ -1262,7 +1262,7 @@ std::vector<GatherActorPtr> GraphScheduler::BuildGatherActor(const GraphCompiler | |||
| if (HasAbstractMonad(parameter) || HasAbstractRef(parameter)) { | |||
| continue; | |||
| } | |||
| parameters.push_back({parameter, 0}); | |||
| parameters.emplace_back(parameter, 0); | |||
| } | |||
| const auto branch_id = parser->GetBranchIDByFuncGraph(func_graph); | |||
| @@ -1291,7 +1291,7 @@ std::vector<GatherActorPtr> GraphScheduler::BuildGatherActor(const GraphCompiler | |||
| if (HasAbstractMonad(inputs[i]) || (inputs[i]->isa<Parameter>() && HasAbstractRef(inputs[i]))) { | |||
| continue; | |||
| } | |||
| parameters.push_back({inputs[i], 0}); | |||
| parameters.emplace_back(inputs[i], 0); | |||
| } | |||
| auto func_graph = control_node->func_graph(); | |||
| @@ -1335,7 +1335,7 @@ void GraphScheduler::LinkDataArrow(KernelActor *to_actor, const GraphCompilerInf | |||
| if (from_kernel->isa<Parameter>() && graph_compiler_info.control_node_parser_->IsCallInputKernelGraph(graph)) { | |||
| const auto &kernel_with_index = GetFrontNodeByKernelGraph(from_kernel, graph); | |||
| const auto &real_front_node_with_index = | |||
| AnfAlgo::VisitKernelWithReturnType(kernel_with_index.first, kernel_with_index.second); | |||
| AnfAlgo::VisitKernelWithReturnType(kernel_with_index.first, SizeToInt(kernel_with_index.second)); | |||
| if (HasAbstractRef(real_front_node_with_index.first)) { | |||
| to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second, | |||
| real_front_node_with_index.first.get()); | |||
| @@ -1966,6 +1966,7 @@ void GraphScheduler::LinkOutputResultArrowForSwitchActor(const GraphCompilerInfo | |||
| auto actor = FetchActor(actor_name); | |||
| MS_EXCEPTION_IF_NULL(actor); | |||
| auto switch_actor = dynamic_cast<SwitchActor *>(actor); | |||
| MS_EXCEPTION_IF_NULL(switch_actor); | |||
| // Set branch index into switch actor. | |||
| size_t branch_index = switch_actor->branch_id_to_index_.size(); | |||
| @@ -2071,7 +2072,6 @@ void GraphScheduler::LinkDeviceTensorStoreForAutoMonadActor(const std::vector<Ke | |||
| void GraphScheduler::PrepareInputNodeForSwitchActor(const std::vector<AnfNodePtr> &control_nodes) { | |||
| for (const auto &node : control_nodes) { | |||
| CNodePtr cnode = node->cast<CNodePtr>(); | |||
| const auto &from_func_graph = node->func_graph(); | |||
| auto inputs = cnode->inputs(); | |||
| // Before link data arrow, parameters of the call node in switch-call need to be add to the switch actor. | |||
| @@ -2079,6 +2079,8 @@ void GraphScheduler::PrepareInputNodeForSwitchActor(const std::vector<AnfNodePtr | |||
| auto actor = FetchActor(inputs[0]->DebugString()); | |||
| MS_EXCEPTION_IF_NULL(actor); | |||
| auto switch_actor = dynamic_cast<SwitchActor *>(actor); | |||
| MS_EXCEPTION_IF_NULL(switch_actor); | |||
| for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) { | |||
| if (HasAbstractMonad(inputs[i])) { | |||
| continue; | |||
| @@ -2107,6 +2109,7 @@ void GraphScheduler::LinkArrowByControlNode(const GraphCompilerInfo &graph_compi | |||
| auto actor = FetchActor(actor_name); | |||
| MS_EXCEPTION_IF_NULL(actor); | |||
| auto gather_actor = dynamic_cast<GatherActor *>(actor); | |||
| MS_EXCEPTION_IF_NULL(gather_actor); | |||
| const auto &func_graph = GetValueNode<FuncGraphPtr>(inputs[0]); | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| @@ -2123,8 +2126,8 @@ void GraphScheduler::LinkArrowByControlNode(const GraphCompilerInfo &graph_compi | |||
| continue; | |||
| } | |||
| gather_actor->device_tensor_store_keys_.push_back( | |||
| {i - kCallInputStartPos - persist_input_num, inputs[i].get()}); | |||
| gather_actor->device_tensor_store_keys_.emplace_back(i - kCallInputStartPos - persist_input_num, | |||
| inputs[i].get()); | |||
| gather_actor->device_contexts_[i - kCallInputStartPos - persist_input_num] = | |||
| graph_compiler_info.control_node_parser_->GetFrontValueNodeDeviceContext(inputs[i]); | |||
| } else if ((inputs[i]->isa<Parameter>() && HasAbstractRef(inputs[i]->cast<ParameterPtr>())) || | |||
| @@ -2150,7 +2153,9 @@ void GraphScheduler::LinkArrowByControlNode(const GraphCompilerInfo &graph_compi | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| auto actor = FetchActor(func_graph->get_return()->DebugString()); | |||
| MS_EXCEPTION_IF_NULL(actor); | |||
| LinkDataArrowForSwitchActor(graph_compiler_info, dynamic_cast<SwitchActor *>(actor)); | |||
| auto switch_actor = dynamic_cast<SwitchActor *>(actor); | |||
| MS_EXCEPTION_IF_NULL(switch_actor); | |||
| LinkDataArrowForSwitchActor(graph_compiler_info, switch_actor); | |||
| } | |||
| // Link arrow for gather actor for call input kernel graph. | |||
| @@ -2160,6 +2165,7 @@ void GraphScheduler::LinkArrowByControlNode(const GraphCompilerInfo &graph_compi | |||
| auto actor = FetchActor(kernel_graph->ToString()); | |||
| MS_EXCEPTION_IF_NULL(actor); | |||
| auto gather_actor = dynamic_cast<GatherActor *>(actor); | |||
| MS_EXCEPTION_IF_NULL(gather_actor); | |||
| for (size_t i = 0; i < gather_actor->data_nodes_.size(); ++i) { | |||
| const auto &input_with_index = gather_actor->data_nodes_[i]; | |||
| @@ -2286,8 +2292,10 @@ void GraphScheduler::LinkDataArrowByControlNode(const GraphCompilerInfo &graph_c | |||
| const auto &actor_name = backend_node.first->fullname_with_scope(); | |||
| const auto &actor = FetchActor(actor_name); | |||
| MS_EXCEPTION_IF_NULL(actor); | |||
| auto op_arrow = std::make_shared<DataArrow>(backend_node.second, to_actor->GetAID(), to_index); | |||
| auto from_actor = dynamic_cast<KernelActor *>(actor); | |||
| MS_EXCEPTION_IF_NULL(from_actor); | |||
| auto op_arrow = std::make_shared<DataArrow>(backend_node.second, to_actor->GetAID(), to_index); | |||
| from_actor->output_data_arrows_.emplace_back(op_arrow); | |||
| auto device_tensor = AnfAlgo::GetMutableOutputAddr(from_actor->kernel_, backend_node.second, false); | |||
| UpdateRefCount(device_tensor.get(), true); | |||
| @@ -2416,6 +2424,8 @@ void GraphScheduler::LinkControlArrowForGatherActor(std::vector<KernelActorPtr> | |||
| actor = FetchActor(func_graph->get_return()->DebugString()); | |||
| if (actor != nullptr) { | |||
| auto switch_actor = dynamic_cast<SwitchActor *>(actor); | |||
| MS_EXCEPTION_IF_NULL(switch_actor); | |||
| kernel_actor->output_control_arrows_.emplace_back(switch_actor->GetAID()); | |||
| switch_actor->input_controls_num_++; | |||
| } | |||
| @@ -2458,6 +2468,7 @@ void GraphScheduler::LinkControlArrowForSwitchActor(std::vector<SwitchActorPtr> | |||
| auto actor = FetchActor(actor_name); | |||
| MS_EXCEPTION_IF_NULL(actor); | |||
| auto gather_actor = dynamic_cast<GatherActor *>(actor); | |||
| MS_EXCEPTION_IF_NULL(gather_actor); | |||
| gather_actor->output_control_arrows_.emplace_back(switch_actor->GetAID()); | |||
| switch_actor->input_controls_num_++; | |||
| } | |||
| @@ -2470,6 +2481,7 @@ void GraphScheduler::LinkControlArrowForSwitchActor(std::vector<SwitchActorPtr> | |||
| const auto &actor = FetchActor(actor_name); | |||
| if (actor != nullptr) { | |||
| const auto &gather_actor = dynamic_cast<GatherActor *>(actor); | |||
| MS_EXCEPTION_IF_NULL(gather_actor); | |||
| switch_actor->output_branch_control_arrows_[i].emplace_back(gather_actor->GetAID()); | |||
| gather_actor->input_controls_num_++; | |||
| } | |||
| @@ -2495,6 +2507,7 @@ void GraphScheduler::LinkControlArrowForSwitchActor(std::vector<SwitchActorPtr> | |||
| auto actor = FetchActor(actor_name); | |||
| MS_EXCEPTION_IF_NULL(actor); | |||
| auto switch_actor = dynamic_cast<SwitchActor *>(actor); | |||
| MS_EXCEPTION_IF_NULL(switch_actor); | |||
| size_t branch_index = switch_actor->branch_id_to_index_.size(); | |||
| if (switch_actor->branch_id_to_index_.find(kMainBranchID) != switch_actor->branch_id_to_index_.end()) { | |||
| @@ -2517,6 +2530,8 @@ void GraphScheduler::LinkBranchArrowForSwitchActor(const GraphCompilerInfo &grap | |||
| auto actor = FetchActor(actor_name); | |||
| MS_EXCEPTION_IF_NULL(actor); | |||
| auto switch_actor = dynamic_cast<SwitchActor *>(actor); | |||
| MS_EXCEPTION_IF_NULL(switch_actor); | |||
| for (size_t i = 0; i < switch_actor->branch_func_graph_.size(); ++i) { | |||
| const auto &func_graph = switch_actor->branch_func_graph_[i]; | |||
| if (func_graph == nullptr) { | |||
| @@ -2556,6 +2571,7 @@ void GraphScheduler::LinkBranchArrowForGatherActor(const GraphCompilerInfo &grap | |||
| auto actor = FetchActor(actor_name); | |||
| MS_EXCEPTION_IF_NULL(actor); | |||
| auto gather_actor = dynamic_cast<GatherActor *>(actor); | |||
| MS_EXCEPTION_IF_NULL(gather_actor); | |||
| gather_actor->output_branch_arrows_.emplace_back(gather_actor->switch_aid_); | |||
| } | |||
| } | |||