Signed-off-by: zhoufeng <zhoufeng54@huawei.com>tags/v0.3.0-alpha
| @@ -0,0 +1,88 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <vector> | |||||
| #include "device/ascend/ascend_label_assign.h" | |||||
| #include "session/anf_runtime_algorithm.h" | |||||
| static constexpr uint32_t kLabelGotoLabelId = 1; | |||||
| static constexpr uint32_t kLabelSwitchLabelId = 2; | |||||
| namespace mindspore { | |||||
| namespace device { | |||||
| namespace ascend { | |||||
| static void UpdateLabelGoto(NotNull<CNodePtr> node) { | |||||
| if (node->size() <= kLabelGotoLabelId) { | |||||
| MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " has invalid input size " << node->size(); | |||||
| } | |||||
| auto label_set = AnfAlgo::GetCNodePrimitive(node->input(kLabelGotoLabelId)); | |||||
| MS_EXCEPTION_IF_NULL(label_set); | |||||
| auto value = label_set->GetAttr(kAttrLabelIndex); | |||||
| MS_EXCEPTION_IF_NULL(value); | |||||
| uint32_t goto_label_id = GetValue<uint32_t>(value); | |||||
| AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue<uint32_t>(goto_label_id), node.get()); | |||||
| MS_LOG(INFO) << "Node " << node->DebugString() << " goto label id " << goto_label_id; | |||||
| } | |||||
| static void UpdateLabelSwitch(NotNull<CNodePtr> node) { | |||||
| if (node->size() <= kLabelGotoLabelId) { | |||||
| MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " has invalid input size " << node->size(); | |||||
| } | |||||
| std::vector<uint32_t> label_list; | |||||
| for (size_t i = kLabelSwitchLabelId; i < node->size(); ++i) { | |||||
| auto input = node->input(i); | |||||
| if (!input->isa<CNode>() || AnfAlgo::GetCNodeName(input) != kLabelSetOpName) { | |||||
| break; | |||||
| } | |||||
| auto label_set = AnfAlgo::GetCNodePrimitive(input); | |||||
| MS_EXCEPTION_IF_NULL(label_set); | |||||
| auto value = label_set->GetAttr(kAttrLabelIndex); | |||||
| MS_EXCEPTION_IF_NULL(value); | |||||
| uint32_t goto_label_id = GetValue<uint32_t>(value); | |||||
| label_list.push_back(goto_label_id); | |||||
| MS_LOG(INFO) << "Switch " << node->DebugString() << " case " << i - kLabelSwitchLabelId << ": id " << goto_label_id; | |||||
| } | |||||
| AnfAlgo::SetNodeAttr(kAttrLabelSwitchList, MakeValue<std::vector<uint32_t>>(label_list), node.get()); | |||||
| } | |||||
| void AscendLabelAssign::AssignLabel(NotNull<const std::shared_ptr<session::KernelGraph> &> graph) { | |||||
| auto cnode_list = graph->execution_order(); | |||||
| // 1 assign label id to label_set | |||||
| uint32_t cur_label_id = 0; | |||||
| for (auto &node : cnode_list) { | |||||
| if (AnfAlgo::GetCNodeName(node) == kLabelSetOpName) { | |||||
| AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue<uint32_t>(cur_label_id), node); | |||||
| MS_LOG(INFO) << "Node " << node->DebugString() << " assign label id " << cur_label_id; | |||||
| ++cur_label_id; | |||||
| } | |||||
| } | |||||
| // 2 update label_switch / label_goto | |||||
| for (auto &node : cnode_list) { | |||||
| if (AnfAlgo::GetCNodeName(node) == kLabelGotoOpName) { | |||||
| UpdateLabelGoto(NOT_NULL(node)); | |||||
| } | |||||
| if (AnfAlgo::GetCNodeName(node) == kLabelSwitchOpName) { | |||||
| UpdateLabelSwitch(NOT_NULL(node)); | |||||
| } | |||||
| } | |||||
| } | |||||
| } // namespace ascend | |||||
| } // namespace device | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,48 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_LABEL_ASSIGN_H_ | |||||
| #define MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_LABEL_ASSIGN_H_ | |||||
| #include <memory> | |||||
| #include "session/kernel_graph.h" | |||||
| #include "utils/contract.h" | |||||
| namespace mindspore { | |||||
| namespace device { | |||||
| namespace ascend { | |||||
| class AscendLabelAssign { | |||||
| public: | |||||
| static AscendLabelAssign &GetInstance() { | |||||
| static AscendLabelAssign instance; // Guaranteed to be destroyed. | |||||
| return instance; | |||||
| } | |||||
| AscendLabelAssign(const AscendLabelAssign &) = delete; | |||||
| AscendLabelAssign &operator=(const AscendLabelAssign &) = delete; | |||||
| void AssignLabel(NotNull<const std::shared_ptr<session::KernelGraph> &> graph); | |||||
| private: | |||||
| AscendLabelAssign() = default; | |||||
| ~AscendLabelAssign() = default; | |||||
| }; | |||||
| } // namespace ascend | |||||
| } // namespace device | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_LABEL_ASSIGN_H_ | |||||
| @@ -30,6 +30,7 @@ | |||||
| #include "pre_activate/ascend/ascend_backend_optimization.h" | #include "pre_activate/ascend/ascend_backend_optimization.h" | ||||
| #include "device/kernel_adjust.h" | #include "device/kernel_adjust.h" | ||||
| #include "device/ascend/ascend_stream_assign.h" | #include "device/ascend/ascend_stream_assign.h" | ||||
| #include "device/ascend/ascend_label_assign.h" | |||||
| #include "predict/predict.h" | #include "predict/predict.h" | ||||
| #include "session/anf_runtime_algorithm.h" | #include "session/anf_runtime_algorithm.h" | ||||
| #include "ir/scalar.h" | #include "ir/scalar.h" | ||||
| @@ -189,6 +190,8 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) { | |||||
| RootGraphExecutorValidate(graph.get()); | RootGraphExecutorValidate(graph.get()); | ||||
| // assign stream | // assign stream | ||||
| AssignStream(graph); | AssignStream(graph); | ||||
| // assign label | |||||
| AssignLabel(NOT_NULL(graph)); | |||||
| // build kernel if node is cnode | // build kernel if node is cnode | ||||
| BuildKernel(graph); | BuildKernel(graph); | ||||
| // alloc mem | // alloc mem | ||||
| @@ -469,6 +472,12 @@ void AscendSession::AssignStream(const std::shared_ptr<KernelGraph> &kernel_grap | |||||
| MS_LOG(INFO) << "Finish!"; | MS_LOG(INFO) << "Finish!"; | ||||
| } | } | ||||
| void AscendSession::AssignLabel(NotNull<const KernelGraphPtr &> kernel_graph) const { | |||||
| MS_LOG(INFO) << "Start!"; | |||||
| device::ascend::AscendLabelAssign::GetInstance().AssignLabel(kernel_graph); | |||||
| MS_LOG(INFO) << "Finish!"; | |||||
| } | |||||
| void AscendSession::BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const { | void AscendSession::BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const { | ||||
| MS_LOG(INFO) << "Start!"; | MS_LOG(INFO) << "Start!"; | ||||
| struct timeval start_time, end_time; | struct timeval start_time, end_time; | ||||
| @@ -74,6 +74,7 @@ class AscendSession : public SessionBasic { | |||||
| void AdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const; | void AdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const; | ||||
| void RunOpAdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const; | void RunOpAdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const; | ||||
| void AssignStream(const std::shared_ptr<KernelGraph> &kernel_graph) const; | void AssignStream(const std::shared_ptr<KernelGraph> &kernel_graph) const; | ||||
| void AssignLabel(NotNull<const KernelGraphPtr &> kernel_graph) const; | |||||
| void BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const; | void BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const; | ||||
| void MemoryAlloc(KernelGraph *kernel_graph) const; | void MemoryAlloc(KernelGraph *kernel_graph) const; | ||||
| void RunOpMemoryAlloc(const std::vector<tensor::TensorPtr> &input_tensors, KernelGraph *kernel_graph) const; | void RunOpMemoryAlloc(const std::vector<tensor::TensorPtr> &input_tensors, KernelGraph *kernel_graph) const; | ||||
| @@ -14,12 +14,16 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "device/ascend/ascend_stream_assign.h" | #include "device/ascend/ascend_stream_assign.h" | ||||
| #include "device/ascend/ascend_label_assign.h" | |||||
| #include "device/ascend/tasksink/task_generator.h" | #include "device/ascend/tasksink/task_generator.h" | ||||
| #include "device/kernel_adjust.h" | #include "device/kernel_adjust.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace device { | namespace device { | ||||
| namespace ascend { | namespace ascend { | ||||
| void AscendLabelAssign::AssignLabel(NotNull<const std::shared_ptr<session::KernelGraph> &>) {} | |||||
| void AscendStreamAssign::AssignStreamNew(const KernelGraphPtr &graph) { return; } | void AscendStreamAssign::AssignStreamNew(const KernelGraphPtr &graph) { return; } | ||||
| uint32_t AscendStreamAssign::GetTotalStreamNum() const { return 1; } | uint32_t AscendStreamAssign::GetTotalStreamNum() const { return 1; } | ||||