
add reorder of node after opt

tags/v1.2.0-rc1
LianLiguang, 4 years ago
commit d2fe2aa7cc
4 changed files with 70 additions and 4 deletions

  1. +60 -2  mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
  2. +1  -1  mindspore/ccsrc/backend/session/ascend_session.cc
  3. +1  -1  mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc
  4. +8  -0  mindspore/ops/_grad/grad_nn_ops.py

+60 -2  mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc

@@ -1259,10 +1259,65 @@ bool AnfRuntimeAlgorithm::IsScalarOutput(const CNodePtr &cnode, size_t index) {
 void AnfRuntimeAlgorithm::ReorderExecList(NotNull<std::vector<CNodePtr> *> node_list) {
   std::vector<CNodePtr> all_opt_list;
   std::vector<CNodePtr> non_opt_list;
+  std::vector<CNodePtr> trans_list;
+  std::vector<CNodePtr> transpose_list;
+  std::vector<CNodePtr> cast_list;
   for (const auto &node : *node_list) {
     MS_EXCEPTION_IF_NULL(node);
-    if (kOptOperatorSet.find(AnfAlgo::GetCNodeName(node)) != kOptOperatorSet.end()) {
+    auto trans_pose_func = [&](const CNodePtr &node) -> bool {
+      MS_EXCEPTION_IF_NULL(node);
+      if (AnfAlgo::GetCNodeName(node) == prim::kPrimTranspose->name()) {
+        auto kernel_index = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(node, 0), 0);
+        MS_EXCEPTION_IF_NULL(kernel_index.first);
+        if (kernel_index.first->isa<CNode>() &&
+            kOptOperatorSet.find(AnfAlgo::GetCNodeName(kernel_index.first->cast<CNodePtr>())) != kOptOperatorSet.end()) {
+          return true;
+        }
+      }
+      return false;
+    };
+
+    auto trans_data_func = [&](const CNodePtr &node) -> bool {
+      MS_EXCEPTION_IF_NULL(node);
+      if (AnfAlgo::GetCNodeName(node) == prim::KPrimTransData->name()) {
+        auto kernel_index = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(node, 0), 0);
+        MS_EXCEPTION_IF_NULL(kernel_index.first);
+        if (kernel_index.first->isa<CNode>() &&
+            kOptOperatorSet.find(AnfAlgo::GetCNodeName(kernel_index.first->cast<CNodePtr>())) != kOptOperatorSet.end()) {
+          return true;
+        }
+        if (!kernel_index.first->isa<CNode>()) {
+          return false;
+        }
+        return trans_pose_func(kernel_index.first->cast<CNodePtr>());
+      }
+      return false;
+    };
+
+    auto cast_func = [&](const CNodePtr &node) -> bool {
+      MS_EXCEPTION_IF_NULL(node);
+      if (AnfAlgo::GetCNodeName(node) == prim::kPrimCast->name()) {
+        auto kernel_index = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(node, 0), 0);
+        MS_EXCEPTION_IF_NULL(kernel_index.first);
+        if (kernel_index.first->isa<CNode>() &&
+            kOptOperatorSet.find(AnfAlgo::GetCNodeName(kernel_index.first->cast<CNodePtr>())) != kOptOperatorSet.end()) {
+          return true;
+        }
+        if (!kernel_index.first->isa<CNode>()) {
+          return false;
+        }
+        return trans_data_func(kernel_index.first->cast<CNodePtr>());
+      }
+      return false;
+    };
+
+    if (trans_pose_func(node)) {
+      transpose_list.emplace_back(node);
+    } else if (trans_data_func(node)) {
+      trans_list.emplace_back(node);
+    } else if (cast_func(node)) {
+      cast_list.emplace_back(node);
+    } else if (kOptOperatorSet.find(AnfAlgo::GetCNodeName(node)) != kOptOperatorSet.end()) {
       all_opt_list.emplace_back(node);
     } else {
       non_opt_list.emplace_back(node);
@@ -1271,6 +1326,9 @@ void AnfRuntimeAlgorithm::ReorderExecList(NotNull<std::vector<CNodePtr> *> node_
   node_list->clear();
   std::copy(non_opt_list.begin(), non_opt_list.end(), std::back_inserter(*node_list));
   std::copy(all_opt_list.begin(), all_opt_list.end(), std::back_inserter(*node_list));
+  std::copy(transpose_list.begin(), transpose_list.end(), std::back_inserter(*node_list));
+  std::copy(trans_list.begin(), trans_list.end(), std::back_inserter(*node_list));
+  std::copy(cast_list.begin(), cast_list.end(), std::back_inserter(*node_list));
 }

 TypeId AnfRuntimeAlgorithm::GetCNodeOutputPrecision(const AnfNodePtr &node) {
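Net effect of this hunk: the execution list is stably partitioned so that ordinary kernels run first, optimizer kernels next, and last any Transpose, TransData, or Cast whose input chain leads back to an optimizer output (Cast may sit behind a TransData, which may sit behind a Transpose). A minimal Python sketch of the same partition, for illustration only; the dict-based node representation and the contents of OPT_OPERATOR_SET are hypothetical stand-ins for CNodePtr and kOptOperatorSet:

OPT_OPERATOR_SET = {"ApplyMomentum", "Adam"}  # stand-in; the real set lives in C++

def follows_opt(node):
    # True if the node's first input is produced by an optimizer op.
    src = node.get("input")
    return src is not None and src["name"] in OPT_OPERATOR_SET

def transpose_after_opt(node):
    return node["name"] == "Transpose" and follows_opt(node)

def transdata_after_opt(node):
    if node["name"] != "TransData":
        return False
    src = node.get("input")
    return follows_opt(node) or (src is not None and transpose_after_opt(src))

def cast_after_opt(node):
    if node["name"] != "Cast":
        return False
    src = node.get("input")
    return follows_opt(node) or (src is not None and transdata_after_opt(src))

def reorder_exec_list(node_list):
    non_opt, opt, transpose, transdata, cast = [], [], [], [], []
    for node in node_list:
        if transpose_after_opt(node):
            transpose.append(node)
        elif transdata_after_opt(node):
            transdata.append(node)
        elif cast_after_opt(node):
            cast.append(node)
        elif node["name"] in OPT_OPERATOR_SET:
            opt.append(node)
        else:
            non_opt.append(node)
    # Same emission order as the C++ hunk: ordinary kernels, optimizer
    # kernels, then the conversions hanging off the optimizer outputs.
    return non_opt + opt + transpose + transdata + cast

For example, a Cast fed by ApplyMomentum is pushed behind both buckets:

nodes = [{"name": "Conv2D"}, {"name": "ApplyMomentum"}]
nodes.append({"name": "Cast", "input": nodes[1]})
print([n["name"] for n in reorder_exec_list(nodes)])  # ['Conv2D', 'ApplyMomentum', 'Cast']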


+1 -1  mindspore/ccsrc/backend/session/ascend_session.cc

@@ -267,7 +267,7 @@ void GetOpInputTensors(const CNodePtr &cnode, const std::map<KernelWithIndex, te
   MS_EXCEPTION_IF_NULL(real_input);
   tensor::TensorPtr tensor = nullptr;
   if (real_input->isa<ValueNode>()) {
-    auto value_node = input->cast<ValueNodePtr>();
+    auto value_node = real_input->cast<ValueNodePtr>();
     MS_EXCEPTION_IF_NULL(value_node);
     auto value = GetValueNode(value_node);
     MS_EXCEPTION_IF_NULL(value_node);
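This hunk is a straightforward bug fix: the branch tests real_input->isa<ValueNode>(), so the cast to ValueNodePtr must also be applied to real_input; casting the outer input could dereference the wrong node.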


+1 -1  mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc

@@ -377,7 +377,7 @@ void SetWeightFormat(const AnfNodePtr &real_input_node, const std::vector<string
   }
   if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) {
     builder->SetOutputsFormat(output_format);
-    std::vector<TypeId> output_type = {selected_kernel_info->GetInputDeviceType(input_index)};
+    std::vector<TypeId> output_type = {AnfAlgo::GetOutputInferDataType(real_input_node, 0)};
    builder->SetOutputsDeviceType(output_type);
     AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get());
   }
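Here the device type recorded for the weight's output changes from the consuming kernel's expected input device type to the weight node's own inferred data type; presumably this keeps the parameter's kernel build info consistent with the parameter itself rather than inheriting a possibly narrowed type from one consumer.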


+8 -0  mindspore/ops/_grad/grad_nn_ops.py

@@ -14,6 +14,7 @@
 # ============================================================================

 """Define the grad rules of neural network related operations."""
+import os
 import numpy as np
 from mindspore.ops import _selected_grad_ops as SG
 from mindspore.ops.primitive import constexpr
@@ -28,6 +29,7 @@ from ..operations import _grad_ops as G
 from ..operations import _inner_ops as inner
 from ... import context

+env_force_bprop_seq = os.getenv("ENV_FORCE_BPROP_SEQ")

 @bprop_getters.register(P.BiasAdd)
 def get_bprop_bias_add(self):
@@ -55,6 +57,8 @@ def get_bprop_conv2d(self):

     def bprop(x, w, out, dout):
         dx = input_grad(dout, w, get_shape(x))
+        if env_force_bprop_seq == '1':
+            x = F.depend(x, dx)
         dw = filter_grad(dout, x, get_shape(w))
         return dx, dw

@@ -173,6 +177,8 @@ def get_bprop_depthwise_conv2d_native(self):

     def bprop(x, w, out, dout):
         dx = input_grad(get_shape(x), w, dout)
+        if env_force_bprop_seq == '1':
+            x = F.depend(x, dx)
         dw = filter_grad(x, get_shape(w), dout)
         return dx, dw

@@ -1047,6 +1053,8 @@ def get_bprop_conv2d_backprop_input(self):

     def bprop(x, w, f_sizes, out, dout):
         dx = input_grad(dout, w)
+        if env_force_bprop_seq == '1':
+            x = F.depend(x, dx)
         dw = filter_grad(x, dout, F.shape(w))
         return dx, dw, zeros_like(f_sizes)
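The Python change is opt-in: when ENV_FORCE_BPROP_SEQ is '1', F.depend makes x depend on dx, forcing each conv bprop to compute the input gradient before the filter gradient instead of letting the two run in an arbitrary order. A minimal usage sketch (the surrounding training-script details are hypothetical):

import os

# The flag is read once at module import time, so it must be set before
# mindspore.ops._grad.grad_nn_ops is first imported.
os.environ["ENV_FORCE_BPROP_SEQ"] = "1"

import mindspore  # training script continues from here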



