diff --git a/mindspore/ccsrc/device/ascend/ascend_label_assign.cc b/mindspore/ccsrc/device/ascend/ascend_label_assign.cc new file mode 100644 index 0000000000..e4239117c2 --- /dev/null +++ b/mindspore/ccsrc/device/ascend/ascend_label_assign.cc @@ -0,0 +1,88 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "device/ascend/ascend_label_assign.h" +#include "session/anf_runtime_algorithm.h" + +static constexpr uint32_t kLabelGotoLabelId = 1; +static constexpr uint32_t kLabelSwitchLabelId = 2; + +namespace mindspore { +namespace device { +namespace ascend { + +static void UpdateLabelGoto(NotNull node) { + if (node->size() <= kLabelGotoLabelId) { + MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " has invalid input size " << node->size(); + } + auto label_set = AnfAlgo::GetCNodePrimitive(node->input(kLabelGotoLabelId)); + MS_EXCEPTION_IF_NULL(label_set); + auto value = label_set->GetAttr(kAttrLabelIndex); + MS_EXCEPTION_IF_NULL(value); + uint32_t goto_label_id = GetValue(value); + AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue(goto_label_id), node.get()); + MS_LOG(INFO) << "Node " << node->DebugString() << " goto label id " << goto_label_id; +} + +static void UpdateLabelSwitch(NotNull node) { + if (node->size() <= kLabelGotoLabelId) { + MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " has invalid input size " << node->size(); + } + std::vector label_list; + for (size_t i = kLabelSwitchLabelId; i < node->size(); ++i) { + auto input = node->input(i); + if (!input->isa() || AnfAlgo::GetCNodeName(input) != kLabelSetOpName) { + break; + } + + auto label_set = AnfAlgo::GetCNodePrimitive(input); + MS_EXCEPTION_IF_NULL(label_set); + auto value = label_set->GetAttr(kAttrLabelIndex); + MS_EXCEPTION_IF_NULL(value); + uint32_t goto_label_id = GetValue(value); + label_list.push_back(goto_label_id); + MS_LOG(INFO) << "Switch " << node->DebugString() << " case " << i - kLabelSwitchLabelId << ": id " << goto_label_id; + } + AnfAlgo::SetNodeAttr(kAttrLabelSwitchList, MakeValue>(label_list), node.get()); +} + +void AscendLabelAssign::AssignLabel(NotNull &> graph) { + auto cnode_list = graph->execution_order(); + // 1 assign label id to label_set + uint32_t cur_label_id = 0; + for (auto &node : cnode_list) { + if (AnfAlgo::GetCNodeName(node) == kLabelSetOpName) { + AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue(cur_label_id), node); + MS_LOG(INFO) << "Node " << node->DebugString() << " assign label id " << cur_label_id; + ++cur_label_id; + } + } + // 2 update label_switch / label_goto + for (auto &node : cnode_list) { + if (AnfAlgo::GetCNodeName(node) == kLabelGotoOpName) { + UpdateLabelGoto(NOT_NULL(node)); + } + + if (AnfAlgo::GetCNodeName(node) == kLabelSwitchOpName) { + UpdateLabelSwitch(NOT_NULL(node)); + } + } +} + +} // namespace ascend +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/ascend_label_assign.h b/mindspore/ccsrc/device/ascend/ascend_label_assign.h new file mode 100644 index 0000000000..1cc0351c60 --- /dev/null +++ b/mindspore/ccsrc/device/ascend/ascend_label_assign.h @@ -0,0 +1,48 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_LABEL_ASSIGN_H_ +#define MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_LABEL_ASSIGN_H_ + +#include +#include "session/kernel_graph.h" +#include "utils/contract.h" + +namespace mindspore { +namespace device { +namespace ascend { + +class AscendLabelAssign { + public: + static AscendLabelAssign &GetInstance() { + static AscendLabelAssign instance; // Guaranteed to be destroyed. + return instance; + } + + AscendLabelAssign(const AscendLabelAssign &) = delete; + AscendLabelAssign &operator=(const AscendLabelAssign &) = delete; + + void AssignLabel(NotNull &> graph); + + private: + AscendLabelAssign() = default; + ~AscendLabelAssign() = default; +}; +} // namespace ascend +} // namespace device +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_LABEL_ASSIGN_H_ diff --git a/mindspore/ccsrc/session/ascend_session.cc b/mindspore/ccsrc/session/ascend_session.cc index 3a0595269b..de2256300e 100644 --- a/mindspore/ccsrc/session/ascend_session.cc +++ b/mindspore/ccsrc/session/ascend_session.cc @@ -30,6 +30,7 @@ #include "pre_activate/ascend/ascend_backend_optimization.h" #include "device/kernel_adjust.h" #include "device/ascend/ascend_stream_assign.h" +#include "device/ascend/ascend_label_assign.h" #include "predict/predict.h" #include "session/anf_runtime_algorithm.h" #include "ir/scalar.h" @@ -189,6 +190,8 @@ GraphId AscendSession::CompileGraph(NotNull func_graph) { RootGraphExecutorValidate(graph.get()); // assign stream AssignStream(graph); + // assign label + AssignLabel(NOT_NULL(graph)); // build kernel if node is cnode BuildKernel(graph); // alloc mem @@ -469,6 +472,12 @@ void AscendSession::AssignStream(const std::shared_ptr &kernel_grap MS_LOG(INFO) << "Finish!"; } +void AscendSession::AssignLabel(NotNull kernel_graph) const { + MS_LOG(INFO) << "Start!"; + device::ascend::AscendLabelAssign::GetInstance().AssignLabel(kernel_graph); + MS_LOG(INFO) << "Finish!"; +} + void AscendSession::BuildKernel(const std::shared_ptr &kernel_graph) const { MS_LOG(INFO) << "Start!"; struct timeval start_time, end_time; diff --git a/mindspore/ccsrc/session/ascend_session.h b/mindspore/ccsrc/session/ascend_session.h index 7752e2bbdc..5a7b50d7c5 100755 --- a/mindspore/ccsrc/session/ascend_session.h +++ b/mindspore/ccsrc/session/ascend_session.h @@ -74,6 +74,7 @@ class AscendSession : public SessionBasic { void AdjustKernel(const std::shared_ptr &kernel_graph) const; void RunOpAdjustKernel(const std::shared_ptr &kernel_graph) const; void AssignStream(const std::shared_ptr &kernel_graph) const; + void AssignLabel(NotNull kernel_graph) const; void BuildKernel(const std::shared_ptr &kernel_graph) const; void MemoryAlloc(KernelGraph *kernel_graph) const; void RunOpMemoryAlloc(const std::vector &input_tensors, KernelGraph *kernel_graph) const; diff --git a/tests/ut/cpp/stub/tasksink/ascend_stream_assign_stub.cc b/tests/ut/cpp/stub/tasksink/ascend_stream_assign_stub.cc index 9c4fe2539d..5f195d6b3a 100755 --- a/tests/ut/cpp/stub/tasksink/ascend_stream_assign_stub.cc +++ b/tests/ut/cpp/stub/tasksink/ascend_stream_assign_stub.cc @@ -14,12 +14,16 @@ * limitations under the License. */ #include "device/ascend/ascend_stream_assign.h" +#include "device/ascend/ascend_label_assign.h" #include "device/ascend/tasksink/task_generator.h" #include "device/kernel_adjust.h" namespace mindspore { namespace device { namespace ascend { + +void AscendLabelAssign::AssignLabel(NotNull &>) {} + void AscendStreamAssign::AssignStreamNew(const KernelGraphPtr &graph) { return; } uint32_t AscendStreamAssign::GetTotalStreamNum() const { return 1; }