
!1790 remove transdata only connected with control depend

Merge pull request !1790 from lianliguang/remove-the-useless-transdata-connected-with-the-control-depend
tags/v0.5.0-beta
mindspore-ci-bot (Gitee) · 5 years ago
commit 971f10d222
7 changed files with 179 additions and 46 deletions
  1. +14 -9   mindspore/ccsrc/pre_activate/ascend/format_type/merge_cast_to_op.cc
  2. +26 -4   mindspore/ccsrc/pre_activate/common/helper.cc
  3. +4  -0   mindspore/ccsrc/pre_activate/common/helper.h
  4. +44 -24  mindspore/ccsrc/pre_activate/pass/optimize_dependence.cc
  5. +9  -9   tests/st/networks/models/bert/test_bert_tdt_lossscale.py
  6. +42 -0   tests/ut/cpp/pre_activate/pass/optimize_dependence_test.cc
  7. +40 -0   tests/ut/cpp/python_input/gtest_input/pre_activate/optimize_dependence_test.py

+14 -9  mindspore/ccsrc/pre_activate/ascend/format_type/merge_cast_to_op.cc

@@ -61,16 +61,14 @@ bool AlternativeKernelInfoForInput(const CNodePtr &node, const TypeId dst_type,

bool GetNextNodeAndCastIndex(const FuncGraphPtr &graph, const AnfNodePtr &node, AnfNodePtr *next_node,
size_t *cast_index) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
// Check whether the cast node is used for input by only one another node.
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
if (manager->node_users().find(node) == manager->node_users().end() || manager->node_users()[node].size() != 1) {
auto output_node_list = GetRealNodeUsedList(graph, node);
MS_EXCEPTION_IF_NULL(output_node_list);
if (output_node_list->size() != 1) {
return false;
}
*next_node = manager->node_users()[node].begin()->first;
*cast_index = IntToSize(manager->node_users()[node].begin()->second - 1);
auto node_pair = output_node_list->at(0);
*next_node = node_pair.first;
*cast_index = node_pair.second - 1;
return true;
}

@@ -148,7 +146,10 @@ AnfNodePtr MergeCastToNextOp(const FuncGraphPtr &graph, const CNodePtr &node, co
if (alternative_kernel_info == kernel_info_list.end()) {
return nullptr;
}
MS_LOG(INFO) << "Found alternative kernel info for current anf kernel " << next_op_name;
auto ori_kernel_info = AnfAlgo::GetSelectKernelBuildInfo(next_node);
MS_LOG(INFO) << "Found alternative kernel info for current anf kernel " << next_cnode->DebugString()
<< "ori kernel info" << ori_kernel_info->ToString() << "alternative kernel info"
<< (*alternative_kernel_info)->ToString();
AnfAlgo::SetSelectKernelBuildInfo(*alternative_kernel_info, next_cnode.get());
if (node->inputs().size() < kCastInputNum) {
auto op_name = AnfAlgo::GetCNodeName(node);
@@ -217,6 +218,10 @@ AnfNodePtr MergeCastToPriorOp(const FuncGraphPtr &graph, const CNodePtr &cur_nod
if (kernel_info_it == kernel_info_list.end()) {
return nullptr;
}
auto ori_kernel_info = AnfAlgo::GetSelectKernelBuildInfo(prior_op);
MS_LOG(INFO) << "Found alternative kernel info for current anf kernel " << prior_op->DebugString()
<< "ori kernel info" << ori_kernel_info->ToString() << "alternative kernel info"
<< (*kernel_info_it)->ToString();
AnfAlgo::SetSelectKernelBuildInfo(*kernel_info_it, prior_op.get());

auto prior_name = AnfAlgo::GetCNodeName(prior_op);

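In effect, the precondition for merging a Cast into its consumer is now "exactly one real user", where control-only edges (ControlDepend users, and the attach input of Depend) no longer count. Below is a minimal standalone sketch of that eligibility check, with the user list reduced to plain (name, index) pairs rather than real AnfNode pointers; the filtering rule itself is sketched after the helper.cc diff below. Names ending in Sketch are illustrative, not part of the MindSpore API.

#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// (user op name, 1-based input position) pairs, already filtered down to the
// real users -- i.e. the shape of what GetRealNodeUsedList returns, with each
// node reduced to its op name for this standalone sketch.
using RealUserList = std::vector<std::pair<std::string, int>>;

// The Cast is a merge candidate only when exactly one real user remains; that
// user's 1-based input position yields the zero-based cast index.
bool GetNextNodeAndCastIndexSketch(const RealUserList &real_users,
                                   std::string *next_node, std::size_t *cast_index) {
  if (real_users.size() != 1) {
    return false;
  }
  *next_node = real_users[0].first;
  *cast_index = static_cast<std::size_t>(real_users[0].second - 1);
  return true;
}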

+26 -4  mindspore/ccsrc/pre_activate/common/helper.cc

@@ -16,6 +16,7 @@

#include "pre_activate/common/helper.h"
#include <string>
#include <utility>
#include <unordered_set>
#include <algorithm>
#include <map>
@@ -475,15 +476,36 @@ void RemoveNopNode(session::KernelGraph *const graph) {
}
}

bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) {
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph,
const AnfNodePtr &node) {
auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
if (manager->node_users().find(node) == manager->node_users().end()) {
auto iter = manager->node_users().find(node);
if (iter == manager->node_users().end()) {
MS_LOG(EXCEPTION) << "node has no output in manager";
}
return manager->node_users()[node].size() > 1;
auto output_info_list = iter->second;
for (const auto &output_info : output_info_list) {
if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimControlDepend->name()) {
continue;
}
if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimDepend->name() &&
output_info.second == kDependAttachNodeIndex) {
continue;
}
output_node_list->push_back(output_info);
}
return output_node_list;
}

bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
auto output_node_list = GetRealNodeUsedList(graph, node);
MS_EXCEPTION_IF_NULL(output_node_list);
return output_node_list->size() > 1;
}

AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx) {

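A minimal standalone sketch of the rule GetRealNodeUsedList implements: a user does not count when it is a ControlDepend node, or a Depend node that reaches this node only through its attach input. The types below are simplified stand-ins for the MindSpore IR, and the attach-input index of 2 is an assumption inferred from the kDependAttachNodeIndex constant name, not taken from this diff.

#include <string>
#include <vector>

// Simplified stand-in for one entry of manager->node_users(): the consuming
// op's name and the 1-based input position at which it consumes the node.
struct UserEdge {
  std::string op_name;
  int input_index;
};

// Assumed value of kDependAttachNodeIndex: the attach (control-only) operand
// position of a Depend node.
constexpr int kDependAttachIndexSketch = 2;

// Drop control-only edges: ControlDepend users, and Depend users that consume
// the node through the attach input. Whatever remains are the "real" users.
std::vector<UserEdge> FilterRealUsers(const std::vector<UserEdge> &users) {
  std::vector<UserEdge> real;
  for (const auto &edge : users) {
    if (edge.op_name == "ControlDepend") {
      continue;
    }
    if (edge.op_name == "Depend" && edge.input_index == kDependAttachIndexSketch) {
      continue;
    }
    real.push_back(edge);
  }
  return real;
}

// IsUsedByOthers then reduces to: more than one real user remains.
bool IsUsedByOthersSketch(const std::vector<UserEdge> &users) {
  return FilterRealUsers(users).size() > 1;
}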

+4 -0  mindspore/ccsrc/pre_activate/common/helper.h

@@ -18,6 +18,7 @@

#include <vector>
#include <memory>
#include <utility>
#include <string>
#include <unordered_set>
#include "ir/func_graph.h"
@@ -163,6 +164,9 @@ AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePt

bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node);

std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph,
const AnfNodePtr &node);

void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &input_attrs);

bool AnfEqual(const BaseRef &a, const BaseRef &b);


+44 -24  mindspore/ccsrc/pre_activate/pass/optimize_dependence.cc

@@ -44,11 +44,11 @@ AnfNodePtr GetReplaceNode(const AnfNodePtr &node) {
return cnode->input(kSingleInputIndex);
}

bool ReplaceMakeTuple(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
AnfNodePtr ReplaceMakeTuple(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimMakeTuple->name()) {
return false;
return nullptr;
}
std::vector<AnfNodePtr> new_make_tuple_inputs;
bool need_update = false;
@@ -75,17 +75,16 @@ bool ReplaceMakeTuple(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
auto manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
manager->Replace(cnode, new_make_tuple);
return new_make_tuple;
}
return true;
return nullptr;
}
} // namespace

const BaseRef OptimizeDependence::DefinePattern() const {
VarPtr X = std::make_shared<Var>("X");
MS_EXCEPTION_IF_NULL(X);
VarPtr Y = std::make_shared<Var>("Y");
MS_EXCEPTION_IF_NULL(Y);
return VectorRef({prim::kPrimDepend, X, Y});
VarPtr X = std::make_shared<Var>();
VarPtr Xs = std::make_shared<SeqVar>();
return VectorRef({X, Xs});
}

const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
@@ -95,27 +94,48 @@ const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, con
if (!node->isa<CNode>()) {
return nullptr;
}
auto node_name = AnfAlgo::GetCNodeName(node);
if (node_name != prim::kPrimControlDepend->name() && node_name != prim::kPrimDepend->name()) {
return nullptr;
}
size_t index = 0;
auto depend_cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(depend_cnode);
CheckCNodeInputSize(depend_cnode, kDependInputNum);
auto replacing_node = depend_cnode->input(kDependInputNum - 1);
MS_EXCEPTION_IF_NULL(replacing_node);
if (!replacing_node->isa<CNode>()) {
return nullptr;
std::vector<AnfNodePtr> new_depend_inputs = {depend_cnode->input(kAnfPrimitiveIndex)};
if (node_name == prim::kPrimDepend->name()) {
index = 1;
new_depend_inputs.push_back(depend_cnode->input(kRealInputIndexInDepend));
}
auto replacing_cnode = replacing_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(replacing_cnode);
// Deal with the make_tuple with TransData or Cast inputs.
if (ReplaceMakeTuple(func_graph, replacing_cnode)) {
return nullptr;
if (AnfAlgo::GetInputTensorNum(depend_cnode) < 2) {
MS_LOG(EXCEPTION) << "The depend node input size is at less size 2,but got "
<< AnfAlgo::GetInputTensorNum(depend_cnode) << depend_cnode->DebugString();
}
AnfNodePtr replace_node = GetReplaceNode(replacing_cnode);
if (replace_node == nullptr) {
MS_LOG(DEBUG) << "Can not find the TransData or Cast with single output node. Depend node: " << node->DebugString();
return nullptr;
auto input_num = AnfAlgo::GetInputTensorNum(depend_cnode);
while (index < input_num) {
auto replacing_node = AnfAlgo::GetInputNode(depend_cnode, index);
++index;
MS_EXCEPTION_IF_NULL(replacing_node);
if (!replacing_node->isa<CNode>()) {
new_depend_inputs.push_back(replacing_node);
continue;
}
auto replacing_cnode = replacing_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(replacing_cnode);
// Deal with the make_tuple with TransData or Cast inputs.
auto make_tuple_replace_node = ReplaceMakeTuple(func_graph, replacing_cnode);
if (make_tuple_replace_node != nullptr) {
new_depend_inputs.push_back(make_tuple_replace_node);
continue;
}
AnfNodePtr replace_node = GetReplaceNode(replacing_cnode);
if (replace_node == nullptr) {
new_depend_inputs.push_back(replacing_node);
MS_LOG(DEBUG) << "Can not find the TransData or Cast with single output node. Depend node: "
<< node->DebugString();
continue;
}
new_depend_inputs.push_back(replace_node);
}
std::vector<AnfNodePtr> new_depend_inputs = {depend_cnode->input(kAnfPrimitiveIndex),
depend_cnode->input(kRealInputIndexInDepend), replace_node};
auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
CNodePtr new_depend;
if (kernel_graph == nullptr) {

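The pattern now matches every CNode; Process bails out unless the node is Depend or ControlDepend, then rebuilds the node's inputs, bypassing any single-input TransData or Cast that only feeds a control edge (for plain Depend, the first real input is copied through untouched; make_tuple inputs go through ReplaceMakeTuple). Below is a minimal standalone sketch of the per-input substitution, with simplified node types in place of AnfNodePtr/CNodePtr and without the make_tuple handling or kernel-graph bookkeeping of the real pass.

#include <cstddef>
#include <memory>
#include <string>
#include <vector>

// Simplified node: an op name plus its data inputs (the real IR also carries
// the primitive as input 0; omitted here for brevity).
struct SketchNode {
  std::string op_name;
  std::vector<std::shared_ptr<SketchNode>> inputs;
};
using SketchNodePtr = std::shared_ptr<SketchNode>;

// Mirror of GetReplaceNode: a TransData or Cast with a single data input can
// be bypassed by returning that input; anything else returns nullptr.
SketchNodePtr GetReplaceNodeSketch(const SketchNodePtr &node) {
  if (node == nullptr) {
    return nullptr;
  }
  if (node->op_name != "TransData" && node->op_name != "Cast") {
    return nullptr;
  }
  if (node->inputs.size() != 1) {
    return nullptr;
  }
  return node->inputs[0];
}

// Rebuild the data inputs of a Depend/ControlDepend node. start_index is 1
// for Depend (its real input is kept as-is) and 0 for ControlDepend (both of
// its operands are control-only edges).
std::vector<SketchNodePtr> RebuildDependInputs(const SketchNodePtr &depend, std::size_t start_index) {
  std::vector<SketchNodePtr> new_inputs;
  for (std::size_t i = 0; i < depend->inputs.size(); ++i) {
    const auto &input = depend->inputs[i];
    SketchNodePtr replace = (i >= start_index) ? GetReplaceNodeSketch(input) : nullptr;
    new_inputs.push_back(replace != nullptr ? replace : input);
  }
  return new_inputs;
}

As the PR title suggests, the rationale is that a control edge only needs the producer's execution order, so a TransData or Cast inserted purely to feed that edge is redundant and can be dropped.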

+9 -9  tests/st/networks/models/bert/test_bert_tdt_lossscale.py

@@ -201,18 +201,18 @@ def test_bert_percision():
loss_value = np.array(callback.loss_list)
assert np.allclose(loss_value[0], 12.206575, 0, 0.000001)

expect_loss_value = [12.206575, 11.980493, 11.984225, 11.878742, 11.832555, 12.410444, 12.008799,
12.620619, 12.22254, 12.4261055]
expect_loss_value = [12.206575, 11.865044, 11.828129, 11.826707, 11.82108, 12.407423, 12.005459,
12.621225, 12.222903, 12.427446]
print("loss value: {}".format(loss_value))
assert np.allclose(loss_value, expect_loss_value, 0, 0.0005)

overflow = np.array(callback.overflow_list)
expect_overflow = [True, True, False, False, False, True, False, False, False, True]
expect_overflow = [False, False, False, True, False, False, False, True, False, False]
print("overflow: {}".format(overflow))
assert (overflow == expect_overflow).all()

loss_scale = np.array(callback.lossscale_list)
expect_loss_scale = [32768.0, 16384.0, 16384.0, 16384.0, 32768.0, 16384.0, 16384.0, 16384.0, 32768.0, 16384.0]
expect_loss_scale = [65536.0, 65536.0, 131072.0, 65536.0, 65536.0, 65536.0, 131072.0, 65536.0, 65536.0, 65536.0]
print("loss scale: {}".format(loss_scale))
assert np.allclose(loss_scale, expect_loss_scale, 0, 0)

@@ -259,27 +259,27 @@ def test_bert_performance():

# assertion occurs while the loss value, overflow state or loss_scale value is wrong
loss_value = np.array(callback.loss_list)
expect_loss_value = [10.237753, 10.213153, 10.212972]
expect_loss_value = [10.235566, 10.207392, 10.206976]
print("loss value: {}".format(loss_value))
assert np.allclose(loss_value, expect_loss_value, 0, 0.0005)

overflow = np.array(callback.overflow_list)
expect_overflow = [False, False, False]
expect_overflow = [True, True, True]
print("overflow: {}".format(overflow))
assert (overflow == expect_overflow).all()

loss_scale = np.array(callback.lossscale_list)
expect_loss_scale = [16384.0, 16384.0, 16384.0]
expect_loss_scale = [262144.0, 262144.0, 262144.0]
print("loss scale: {}".format(loss_scale))
assert np.allclose(loss_scale, expect_loss_scale, 0, 0)

epoch_mseconds = np.array(time_monitor_callback.epoch_mseconds_list)[2]
expect_epoch_mseconds = 1726
expect_epoch_mseconds = 1600
print("epoch mseconds: {}".format(epoch_mseconds))
assert epoch_mseconds <= expect_epoch_mseconds + 5

per_step_mseconds = np.array(time_monitor_callback.per_step_mseconds_list)[2]
expect_per_step_mseconds = 17
expect_per_step_mseconds = 16
print("per step mseconds: {}".format(per_step_mseconds))
assert per_step_mseconds <= expect_per_step_mseconds + 1



+42 -0  tests/ut/cpp/pre_activate/pass/optimize_dependence_test.cc

@@ -68,5 +68,47 @@ TEST_F(TestHWOptimizeDependence, test_optimize_dependence_with_make_tuple) {
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_optimize_dependence_with_make_tuple", "after");
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}


TEST_F(TestHWOptimizeDependence, test_optimize_control_dependence_with_make_tuple) {
/*
* def before(x, y, a, b):
* z = make_tuple(TransData(a), TransData(b))
* depend_intput = control_depend(y, z)
* sum = add(x, depend_intput)
* return sum
*/
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_optimize_control_dependence_with_make_tuple", "before");

auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::OptimizeDependence>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(g);

FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_optimize_control_dependence_with_make_tuple", "after");
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}


TEST_F(TestHWOptimizeDependence, test_optimize_control_dependence) {
/*
* def before(x, y, z):
* new_z = TransData(z)
* depend_intput = control_depend(y, new_z)
* sum = add(x, depend_intput)
* return sum
*/
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_optimize_control_dependence", "before");

auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::OptimizeDependence>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(g);

FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_optimize_control_dependence", "after");
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
} // namespace opt
} // namespace mindspore

+40 -0  tests/ut/cpp/python_input/gtest_input/pre_activate/optimize_dependence_test.py

@@ -16,6 +16,7 @@ from mindspore.ops import Primitive
from mindspore.ops import operations as P

depend = P.Depend()
controldepend = Primitive("ControlDepend")
TransData = Primitive('TransData')
add = P.TensorAdd()
make_tuple = Primitive('make_tuple')
@@ -69,3 +70,42 @@ def test_optimize_dependence_with_make_tuple(tag):
return sum_add

return fns[tag]


def test_optimize_control_dependence(tag):
fns = FnDict()

@fns
def before(x, y, z):
new_z = TransData(z)
depend_intput = controldepend(y, new_z)
sum_add = add(x, depend_intput)
return sum_add

@fns
def after(x, y, z):
depend_intput = controldepend(y, z)
sum_add = add(x, depend_intput)
return sum_add

return fns[tag]


def test_optimize_control_dependence_with_make_tuple(tag):
fns = FnDict()

@fns
def before(x, y, a, b):
z = make_tuple(TransData(a), TransData(b))
depend_intput = controldepend(y, z)
sum_add = add(x, depend_intput)
return sum_add

@fns
def after(x, y, a, b):
z = make_tuple(a, b)
depend_intput = controldepend(y, z)
sum_add = add(x, depend_intput)
return sum_add

return fns[tag]
