| @@ -291,8 +291,8 @@ size_t SwitchActor::GetIndex(OpContext<DeviceTensor> *context) { | |||
| char buf[kMaxSwitchCondSize] = {0}; | |||
| ShapeVector host_shape; | |||
| if (!device_tensor->SyncDeviceToHost(host_shape, size, type_id, static_cast<void *>(buf))) { | |||
| MS_LOG(ERROR) << GetAID().Name() + " get index from device address failed, type id:" + std::to_string(type_id) + | |||
| ", device type:" + std::to_string(static_cast<int>(device_context_->GetDeviceAddressType())); | |||
| MS_LOG(ERROR) << GetAID().Name() << " get index from device address failed, type id:" << std::to_string(type_id) | |||
| << ", device type:" << std::to_string(static_cast<int>(device_context_->GetDeviceAddressType())); | |||
| } | |||
| if (type_id == TypeId::kNumberTypeInt32) { | |||
| @@ -413,7 +413,8 @@ void SwitchActor::SendOutput(OpContext<DeviceTensor> *context) { | |||
| for (size_t j = 0; j < AnfAlgo::GetOutputTensorNum(backend_node.first); ++j) { | |||
| if (backend_node.first->kernel_info() != nullptr && AnfAlgo::OutputAddrExist(backend_node.first, j, false) && | |||
| AnfAlgo::GetMutableOutputAddr(backend_node.first, j, false).get() == input_device_tensors_[from_index]) { | |||
| Async(result_arrow->to_op_id_, &OutputActor::CollectOutput, backend_node.first, j, | |||
| auto output_index = j; | |||
| Async(result_arrow->to_op_id_, &OutputActor::CollectOutput, backend_node.first, output_index, | |||
| result_arrow->to_input_index_, context); | |||
| is_send = true; | |||
| MS_LOG(DEBUG) << "Switch actor:" << GetAID() << " send result addr:" << input_device_tensors_[from_index] | |||
| @@ -528,7 +528,6 @@ bool IsSubCallNode(const AnfNodePtr &node) { | |||
| } | |||
| const auto inputs = node->cast<CNodePtr>()->inputs(); | |||
| if (!AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitchLayer)) { | |||
| return false; | |||
| } | |||
| @@ -670,7 +669,6 @@ size_t FetchOutputSizebyCallNode(const AnfNodePtr &node, std::vector<AnfNodePtr> | |||
| FuncGraphPtr FetchFuncGraphByNode(const AnfNodePtr &node) { | |||
| auto front_node = GetFrontNodeByBackendNode(node); | |||
| // If the front node is nullptr, we can check its inputs. | |||
| if (front_node == nullptr) { | |||
| if (node->isa<CNode>()) { | |||
| @@ -839,7 +837,6 @@ DeviceContext *ControlNodeParser::GetFrontValueNodeDeviceContext(const AnfNodePt | |||
| auto iter = std::find_if( | |||
| front_value_nodes_.begin(), front_value_nodes_.end(), | |||
| [value_node](const auto &front_node_with_context) { return front_node_with_context.first == value_node; }); | |||
| if (iter != front_value_nodes_.end()) { | |||
| return iter->second; | |||
| } | |||
| @@ -1224,7 +1221,6 @@ void ControlNodeParser::FetchFrontToBackendParameter(const std::vector<KernelGra | |||
| auto device_context = device_contexts[i]; | |||
| for (const auto ¶meter : graph->input_nodes()) { | |||
| auto front_node = graph->GetFrontAnfByBackendAnf(parameter); | |||
| if (front_node != nullptr && front_node->isa<Parameter>() && | |||
| front_to_backend_parameters_.find(front_node) == front_to_backend_parameters_.end()) { | |||
| front_to_backend_parameters_[front_node] = {parameter, device_context}; | |||
| @@ -1363,7 +1359,6 @@ void ControlNodeParser::FetchBackendOutputByFrontOutput(const AnfNodePtr &front_ | |||
| } else if (front_output->isa<Parameter>()) { | |||
| // Output is a parameter. | |||
| const auto iter = formal_to_real_parameters_.find(front_output); | |||
| if (iter != formal_to_real_parameters_.end()) { | |||
| for (const auto &node : iter->second) { | |||
| (void)(*results).emplace(node); | |||
| @@ -1396,7 +1391,6 @@ void ControlNodeParser::FetchBackendOutputByFrontOutput(const AnfNodePtr &front_ | |||
| } else if (front_output->isa<CNode>()) { | |||
| // Output is a kernel. | |||
| const auto iter = front_to_backend_kernels_.find(AnfAlgo::VisitKernelWithReturnType(front_output, 0)); | |||
| if (iter != front_to_backend_kernels_.end()) { | |||
| (void)(*results).emplace(iter->second.first); | |||
| } else { | |||
| @@ -1418,7 +1412,6 @@ void ControlNodeParser::FetchBackendInputNodebyFrontNode( | |||
| for (const auto &front_input : front_inputs) { | |||
| const auto node_with_index = AnfAlgo::VisitKernelWithReturnType(front_input, 0); | |||
| if (node_with_index.first->isa<Parameter>()) { | |||
| const auto &iter = front_to_backend_parameters.find(real_parameter); | |||
| if (iter == front_to_backend_parameters.end()) { | |||
| @@ -1465,7 +1458,6 @@ void ControlNodeParser::FetchBackendParameterNode(const std::vector<KernelGraphP | |||
| } | |||
| for (const auto ¶meter : graph->input_nodes()) { | |||
| auto front_node = graph->GetFrontAnfByBackendAnf(parameter); | |||
| if (front_node != nullptr && front_node->isa<Parameter>() && | |||
| (*front_to_backend_parameters).find(front_node) == (*front_to_backend_parameters).end()) { | |||
| (*front_to_backend_parameters)[front_node] = {parameter, device_context}; | |||
| @@ -1506,7 +1498,6 @@ void ControlNodeParser::FetchBackendInputNode(const std::vector<KernelGraphPtr> | |||
| const auto &graph = graphs[i]; | |||
| for (const auto &value_node : graph->graph_value_nodes()) { | |||
| auto front_node = graph->GetFrontAnfByBackendAnf(value_node); | |||
| if (front_node != nullptr) { | |||
| (void)formal_to_real_parameters_[front_node].emplace_back(value_node, 0); | |||
| } | |||
| @@ -1038,7 +1038,6 @@ std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const Graph | |||
| const auto &backend_node = backend_iter->second.first; | |||
| auto iter = find(host_queue_ds_actor->data_nodes_.begin(), host_queue_ds_actor->data_nodes_.end(), backend_node); | |||
| if (iter != host_queue_ds_actor->data_nodes_.end()) { | |||
| (void)host_queue_ds_actor->data_node_position_map_.emplace(parameter, | |||
| iter - host_queue_ds_actor->data_nodes_.begin()); | |||
| @@ -2080,7 +2079,6 @@ void GraphScheduler::PrepareInputNodeForSwitchActor(const std::vector<AnfNodePtr | |||
| for (const auto &node : control_nodes) { | |||
| CNodePtr cnode = node->cast<CNodePtr>(); | |||
| auto inputs = cnode->inputs(); | |||
| // Before link data arrow, parameters of the call node in switch-call need to be add to the switch actor. | |||
| if (inputs[0]->isa<CNode>()) { | |||
| auto actor = FetchActor(inputs[0]->DebugString()); | |||
| @@ -2109,7 +2107,10 @@ void GraphScheduler::LinkArrowByControlNode(const GraphCompilerInfo &graph_compi | |||
| if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch) || | |||
| AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitchLayer)) { | |||
| auto actor = actor_name_to_actor_[node->DebugString()]; | |||
| LinkDataArrowForSwitchActor(graph_compiler_info, dynamic_cast<SwitchActor *>(actor)); | |||
| MS_EXCEPTION_IF_NULL(actor); | |||
| auto switch_actor = dynamic_cast<SwitchActor *>(actor); | |||
| MS_EXCEPTION_IF_NULL(switch_actor); | |||
| LinkDataArrowForSwitchActor(graph_compiler_info, switch_actor); | |||
| } else if (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0])) { | |||
| // Link the data arrow for the input of the call node. | |||
| const auto &actor_name = node->DebugString(); | |||
| @@ -2566,6 +2567,7 @@ void GraphScheduler::LinkBranchArrowForGatherActor(const GraphCompilerInfo &grap | |||
| auto actor = FetchActor(actor_name); | |||
| MS_EXCEPTION_IF_NULL(actor); | |||
| auto gather_actor = dynamic_cast<GatherActor *>(actor); | |||
| MS_EXCEPTION_IF_NULL(gather_actor); | |||
| (void)gather_actor->output_branch_arrows_.emplace_back(gather_actor->gather_aid_); | |||
| } | |||
| } | |||