Signed-off-by: zhoufeng <zhoufeng54@huawei.com>tags/v0.3.0-alpha
| @@ -0,0 +1,88 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <vector> | |||
| #include "device/ascend/ascend_label_assign.h" | |||
| #include "session/anf_runtime_algorithm.h" | |||
| static constexpr uint32_t kLabelGotoLabelId = 1; | |||
| static constexpr uint32_t kLabelSwitchLabelId = 2; | |||
| namespace mindspore { | |||
| namespace device { | |||
| namespace ascend { | |||
| static void UpdateLabelGoto(NotNull<CNodePtr> node) { | |||
| if (node->size() <= kLabelGotoLabelId) { | |||
| MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " has invalid input size " << node->size(); | |||
| } | |||
| auto label_set = AnfAlgo::GetCNodePrimitive(node->input(kLabelGotoLabelId)); | |||
| MS_EXCEPTION_IF_NULL(label_set); | |||
| auto value = label_set->GetAttr(kAttrLabelIndex); | |||
| MS_EXCEPTION_IF_NULL(value); | |||
| uint32_t goto_label_id = GetValue<uint32_t>(value); | |||
| AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue<uint32_t>(goto_label_id), node.get()); | |||
| MS_LOG(INFO) << "Node " << node->DebugString() << " goto label id " << goto_label_id; | |||
| } | |||
| static void UpdateLabelSwitch(NotNull<CNodePtr> node) { | |||
| if (node->size() <= kLabelGotoLabelId) { | |||
| MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " has invalid input size " << node->size(); | |||
| } | |||
| std::vector<uint32_t> label_list; | |||
| for (size_t i = kLabelSwitchLabelId; i < node->size(); ++i) { | |||
| auto input = node->input(i); | |||
| if (!input->isa<CNode>() || AnfAlgo::GetCNodeName(input) != kLabelSetOpName) { | |||
| break; | |||
| } | |||
| auto label_set = AnfAlgo::GetCNodePrimitive(input); | |||
| MS_EXCEPTION_IF_NULL(label_set); | |||
| auto value = label_set->GetAttr(kAttrLabelIndex); | |||
| MS_EXCEPTION_IF_NULL(value); | |||
| uint32_t goto_label_id = GetValue<uint32_t>(value); | |||
| label_list.push_back(goto_label_id); | |||
| MS_LOG(INFO) << "Switch " << node->DebugString() << " case " << i - kLabelSwitchLabelId << ": id " << goto_label_id; | |||
| } | |||
| AnfAlgo::SetNodeAttr(kAttrLabelSwitchList, MakeValue<std::vector<uint32_t>>(label_list), node.get()); | |||
| } | |||
| void AscendLabelAssign::AssignLabel(NotNull<const std::shared_ptr<session::KernelGraph> &> graph) { | |||
| auto cnode_list = graph->execution_order(); | |||
| // 1 assign label id to label_set | |||
| uint32_t cur_label_id = 0; | |||
| for (auto &node : cnode_list) { | |||
| if (AnfAlgo::GetCNodeName(node) == kLabelSetOpName) { | |||
| AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue<uint32_t>(cur_label_id), node); | |||
| MS_LOG(INFO) << "Node " << node->DebugString() << " assign label id " << cur_label_id; | |||
| ++cur_label_id; | |||
| } | |||
| } | |||
| // 2 update label_switch / label_goto | |||
| for (auto &node : cnode_list) { | |||
| if (AnfAlgo::GetCNodeName(node) == kLabelGotoOpName) { | |||
| UpdateLabelGoto(NOT_NULL(node)); | |||
| } | |||
| if (AnfAlgo::GetCNodeName(node) == kLabelSwitchOpName) { | |||
| UpdateLabelSwitch(NOT_NULL(node)); | |||
| } | |||
| } | |||
| } | |||
| } // namespace ascend | |||
| } // namespace device | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,48 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_LABEL_ASSIGN_H_ | |||
| #define MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_LABEL_ASSIGN_H_ | |||
| #include <memory> | |||
| #include "session/kernel_graph.h" | |||
| #include "utils/contract.h" | |||
| namespace mindspore { | |||
| namespace device { | |||
| namespace ascend { | |||
| class AscendLabelAssign { | |||
| public: | |||
| static AscendLabelAssign &GetInstance() { | |||
| static AscendLabelAssign instance; // Guaranteed to be destroyed. | |||
| return instance; | |||
| } | |||
| AscendLabelAssign(const AscendLabelAssign &) = delete; | |||
| AscendLabelAssign &operator=(const AscendLabelAssign &) = delete; | |||
| void AssignLabel(NotNull<const std::shared_ptr<session::KernelGraph> &> graph); | |||
| private: | |||
| AscendLabelAssign() = default; | |||
| ~AscendLabelAssign() = default; | |||
| }; | |||
| } // namespace ascend | |||
| } // namespace device | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_LABEL_ASSIGN_H_ | |||
| @@ -30,6 +30,7 @@ | |||
| #include "pre_activate/ascend/ascend_backend_optimization.h" | |||
| #include "device/kernel_adjust.h" | |||
| #include "device/ascend/ascend_stream_assign.h" | |||
| #include "device/ascend/ascend_label_assign.h" | |||
| #include "predict/predict.h" | |||
| #include "session/anf_runtime_algorithm.h" | |||
| #include "ir/scalar.h" | |||
| @@ -189,6 +190,8 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) { | |||
| RootGraphExecutorValidate(graph.get()); | |||
| // assign stream | |||
| AssignStream(graph); | |||
| // assign label | |||
| AssignLabel(NOT_NULL(graph)); | |||
| // build kernel if node is cnode | |||
| BuildKernel(graph); | |||
| // alloc mem | |||
| @@ -469,6 +472,12 @@ void AscendSession::AssignStream(const std::shared_ptr<KernelGraph> &kernel_grap | |||
| MS_LOG(INFO) << "Finish!"; | |||
| } | |||
| void AscendSession::AssignLabel(NotNull<const KernelGraphPtr &> kernel_graph) const { | |||
| MS_LOG(INFO) << "Start!"; | |||
| device::ascend::AscendLabelAssign::GetInstance().AssignLabel(kernel_graph); | |||
| MS_LOG(INFO) << "Finish!"; | |||
| } | |||
| void AscendSession::BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const { | |||
| MS_LOG(INFO) << "Start!"; | |||
| struct timeval start_time, end_time; | |||
| @@ -74,6 +74,7 @@ class AscendSession : public SessionBasic { | |||
| void AdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const; | |||
| void RunOpAdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const; | |||
| void AssignStream(const std::shared_ptr<KernelGraph> &kernel_graph) const; | |||
| void AssignLabel(NotNull<const KernelGraphPtr &> kernel_graph) const; | |||
| void BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const; | |||
| void MemoryAlloc(KernelGraph *kernel_graph) const; | |||
| void RunOpMemoryAlloc(const std::vector<tensor::TensorPtr> &input_tensors, KernelGraph *kernel_graph) const; | |||
| @@ -14,12 +14,16 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "device/ascend/ascend_stream_assign.h" | |||
| #include "device/ascend/ascend_label_assign.h" | |||
| #include "device/ascend/tasksink/task_generator.h" | |||
| #include "device/kernel_adjust.h" | |||
| namespace mindspore { | |||
| namespace device { | |||
| namespace ascend { | |||
| void AscendLabelAssign::AssignLabel(NotNull<const std::shared_ptr<session::KernelGraph> &>) {} | |||
| void AscendStreamAssign::AssignStreamNew(const KernelGraphPtr &graph) { return; } | |||
| uint32_t AscendStreamAssign::GetTotalStreamNum() const { return 1; } | |||