From d2fe2aa7cc847cc7fa4a9780d92f426bccecb05d Mon Sep 17 00:00:00 2001
From: LianLiguang
Date: Thu, 7 Jan 2021 10:19:01 +0800
Subject: [PATCH] add reorder of node after opt

---
 .../backend/session/anf_runtime_algorithm.cc  | 62 ++++++++++++++++++-
 .../ccsrc/backend/session/ascend_session.cc   |  2 +-
 .../device/ascend/kernel_select_ascend.cc     |  2 +-
 mindspore/ops/_grad/grad_nn_ops.py            |  8 +++
 4 files changed, 70 insertions(+), 4 deletions(-)
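Notes:

The ReorderExecList change sinks optimizer ops to the tail of the execution
list together with any Transpose/TransData/Cast node whose input chain leads
back to an op in kOptOperatorSet (directly, or via the
Cast -> TransData -> Transpose chain). Below is a minimal runnable sketch of
that bucketing policy, in Python for brevity; Node, OPT_OPS and fed_by_opt
are illustrative stand-ins, not the MindSpore API:

    from typing import List, Optional

    class Node:
        """Toy stand-in for a CNode: an op name plus the node feeding its input 0."""
        def __init__(self, name: str, src: Optional["Node"] = None):
            self.name, self.src = name, src

    OPT_OPS = {"ApplyMomentum", "AdamApplyOne"}  # stand-in for kOptOperatorSet

    def fed_by_opt(node: Optional[Node], chain: List[str]) -> bool:
        # True when `node` is chain[0] and its producer is either an optimizer
        # op or the next converter in the chain that is itself fed by one.
        if node is None or node.name != chain[0]:
            return False
        src = node.src
        if src is not None and src.name in OPT_OPS:
            return True
        return len(chain) > 1 and fed_by_opt(src, chain[1:])

    def reorder(exec_list: List[Node]) -> List[Node]:
        non_opt, opt, transpose, trans_data, cast = [], [], [], [], []
        for n in exec_list:
            if fed_by_opt(n, ["Transpose"]):
                transpose.append(n)
            elif fed_by_opt(n, ["TransData", "Transpose"]):
                trans_data.append(n)
            elif fed_by_opt(n, ["Cast", "TransData", "Transpose"]):
                cast.append(n)
            elif n.name in OPT_OPS:
                opt.append(n)
            else:
                non_opt.append(n)
        # same tail order as the std::copy calls in ReorderExecList:
        # optimizer ops first, then Transpose, TransData, Cast consumers
        return non_opt + opt + transpose + trans_data + cast

The final order mirrors the five std::copy calls: non-optimizer nodes keep
their relative order, then optimizer ops, then their converter consumers.
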
diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
index 412c10ff40..00ba037e77 100644
--- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
+++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
@@ -1259,10 +1259,65 @@ bool AnfRuntimeAlgorithm::IsScalarOutput(const CNodePtr &cnode, size_t index) {
 void AnfRuntimeAlgorithm::ReorderExecList(NotNull<std::vector<CNodePtr> *> node_list) {
   std::vector<CNodePtr> all_opt_list;
   std::vector<CNodePtr> non_opt_list;
-
+  std::vector<CNodePtr> trans_list;
+  std::vector<CNodePtr> transpose_list;
+  std::vector<CNodePtr> cast_list;
   for (const auto &node : *node_list) {
     MS_EXCEPTION_IF_NULL(node);
-    if (kOptOperatorSet.find(AnfAlgo::GetCNodeName(node)) != kOptOperatorSet.end()) {
+    auto trans_pose_func = [&](const CNodePtr &node) -> bool {
+      MS_EXCEPTION_IF_NULL(node);
+      if (AnfAlgo::GetCNodeName(node) == prim::kPrimTranspose->name()) {
+        auto kernel_index = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(node, 0), 0);
+        MS_EXCEPTION_IF_NULL(kernel_index.first);
+        if (kernel_index.first->isa<CNode>() && kOptOperatorSet.find(AnfAlgo::GetCNodeName(
+                                                  kernel_index.first->cast<CNodePtr>())) != kOptOperatorSet.end()) {
+          return true;
+        }
+      }
+      return false;
+    };
+
+    auto trans_data_func = [&](const CNodePtr &node) -> bool {
+      MS_EXCEPTION_IF_NULL(node);
+      if (AnfAlgo::GetCNodeName(node) == prim::KPrimTransData->name()) {
+        auto kernel_index = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(node, 0), 0);
+        MS_EXCEPTION_IF_NULL(kernel_index.first);
+        if (kernel_index.first->isa<CNode>() && kOptOperatorSet.find(AnfAlgo::GetCNodeName(
+                                                  kernel_index.first->cast<CNodePtr>())) != kOptOperatorSet.end()) {
+          return true;
+        }
+        if (!kernel_index.first->isa<CNode>()) {
+          return false;
+        }
+        return trans_pose_func(kernel_index.first->cast<CNodePtr>());
+      }
+      return false;
+    };
+
+    auto cast_func = [&](const CNodePtr &node) -> bool {
+      MS_EXCEPTION_IF_NULL(node);
+      if (AnfAlgo::GetCNodeName(node) == prim::kPrimCast->name()) {
+        auto kernel_index = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(node, 0), 0);
+        MS_EXCEPTION_IF_NULL(kernel_index.first);
+        if (kernel_index.first->isa<CNode>() && kOptOperatorSet.find(AnfAlgo::GetCNodeName(
+                                                  kernel_index.first->cast<CNodePtr>())) != kOptOperatorSet.end()) {
+          return true;
+        }
+        if (!kernel_index.first->isa<CNode>()) {
+          return false;
+        }
+        return trans_data_func(kernel_index.first->cast<CNodePtr>());
+      }
+      return false;
+    };
+
+    if (trans_pose_func(node)) {
+      transpose_list.emplace_back(node);
+    } else if (trans_data_func(node)) {
+      trans_list.emplace_back(node);
+    } else if (cast_func(node)) {
+      cast_list.emplace_back(node);
+    } else if (kOptOperatorSet.find(AnfAlgo::GetCNodeName(node)) != kOptOperatorSet.end()) {
       all_opt_list.emplace_back(node);
     } else {
       non_opt_list.emplace_back(node);
@@ -1271,6 +1326,9 @@ void AnfRuntimeAlgorithm::ReorderExecList(NotNull<std::vector<CNodePtr> *> node_
   node_list->clear();
   std::copy(non_opt_list.begin(), non_opt_list.end(), std::back_inserter(*node_list));
   std::copy(all_opt_list.begin(), all_opt_list.end(), std::back_inserter(*node_list));
+  std::copy(transpose_list.begin(), transpose_list.end(), std::back_inserter(*node_list));
+  std::copy(trans_list.begin(), trans_list.end(), std::back_inserter(*node_list));
+  std::copy(cast_list.begin(), cast_list.end(), std::back_inserter(*node_list));
 }
 
 TypeId AnfRuntimeAlgorithm::GetCNodeOutputPrecision(const AnfNodePtr &node) {
diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc
index 16ad0dd496..0d96022918 100644
--- a/mindspore/ccsrc/backend/session/ascend_session.cc
+++ b/mindspore/ccsrc/backend/session/ascend_session.cc
@@ -267,7 +267,7 @@ void GetOpInputTensors(const CNodePtr &cnode, const std::map
     if (real_input->isa<ValueNode>()) {
-      auto value_node = input->cast<ValueNodePtr>();
+      auto value_node = real_input->cast<ValueNodePtr>();
       MS_EXCEPTION_IF_NULL(value_node);
       auto value = GetValueNode(value_node);
       MS_EXCEPTION_IF_NULL(value_node);
diff --git a/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc b/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc
index b33957c10a..100841e6e5 100644
--- a/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc
+++ b/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc
@@ -377,7 +377,7 @@ void SetWeightFormat(const AnfNodePtr &real_input_node, const std::vector
   builder->SetOutputsFormat(output_format);
-  std::vector<TypeId> output_type = {selected_kernel_info->GetInputDeviceType(input_index)};
+  std::vector<TypeId> output_type = {AnfAlgo::GetOutputInferDataType(real_input_node, 0)};
   builder->SetOutputsDeviceType(output_type);
   AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get());
 }
diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py
index 2ed2e3ad58..b0a098663f 100755
--- a/mindspore/ops/_grad/grad_nn_ops.py
+++ b/mindspore/ops/_grad/grad_nn_ops.py
@@ -14,6 +14,7 @@
 # ============================================================================
 """Define the grad rules of neural network related operations."""
+import os
 import numpy as np
 from mindspore.ops import _selected_grad_ops as SG
 from mindspore.ops.primitive import constexpr
@@ -28,6 +29,7 @@
 from ..operations import _grad_ops as G
 from ..operations import _inner_ops as inner
 from ... import context
+env_force_bprop_seq = os.getenv("ENV_FORCE_BPROP_SEQ")
 
 @bprop_getters.register(P.BiasAdd)
 def get_bprop_bias_add(self):
@@ -55,6 +57,8 @@ def get_bprop_conv2d(self):
 
     def bprop(x, w, out, dout):
         dx = input_grad(dout, w, get_shape(x))
+        if env_force_bprop_seq == '1':
+            x = F.depend(x, dx)
         dw = filter_grad(dout, x, get_shape(w))
         return dx, dw
 
@@ -173,6 +177,8 @@ def get_bprop_depthwise_conv2d_native(self):
 
     def bprop(x, w, out, dout):
         dx = input_grad(get_shape(x), w, dout)
+        if env_force_bprop_seq == '1':
+            x = F.depend(x, dx)
         dw = filter_grad(x, get_shape(w), dout)
         return dx, dw
 
@@ -1047,6 +1053,8 @@ def get_bprop_conv2d_backprop_input(self):
 
     def bprop(x, w, f_sizes, out, dout):
         dx = input_grad(dout, w)
+        if env_force_bprop_seq == '1':
+            x = F.depend(x, dx)
         dw = filter_grad(x, dout, F.shape(w))
         return dx, dw, zeros_like(f_sizes)
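
Note on ENV_FORCE_BPROP_SEQ: when the variable is '1', the conv bprops route
x through F.depend(x, dx) before computing dw, so the filter gradient is
ordered after the input gradient instead of the two kernels racing. The flag
is captured once at module import (the os.getenv call above is at module
level), so it has to be in the environment before MindSpore is imported. A
usage sketch; the surrounding script is assumed, not part of this patch:

    import os
    os.environ["ENV_FORCE_BPROP_SEQ"] = "1"  # force dx before dw in conv bprops

    import mindspore  # grad rules read the flag when this import runs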