Browse Source

!12949 Add TopoSort Rhs First attribute for special CNode, such as Depend CNode with isolated nodes.

From: @zh_qh
Reviewed-by: @hwhewei,@zhunaipan
Signed-off-by: @zhunaipan
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
48d4cca512
7 changed files with 30 additions and 14 deletions
  1. +3
    -3
      mindspore/ccsrc/frontend/optimizer/opt.cc
  2. +8
    -2
      mindspore/ccsrc/pipeline/jit/parse/function_block.cc
  3. +1
    -0
      mindspore/ccsrc/utils/utils.h
  4. +11
    -2
      mindspore/core/ir/graph_utils.cc
  5. +2
    -2
      tests/st/ops/cpu/test_dot_op.py
  6. +1
    -1
      tests/ut/python/nn/test_nn_embedding.py
  7. +4
    -4
      tests/ut/python/nn/test_ssim.py

+ 3
- 3
mindspore/ccsrc/frontend/optimizer/opt.cc View File

@@ -228,13 +228,13 @@ bool SubstitutionList::ApplySubstitutionToIR(const OptimizerPtr &optimizer, cons
bool change = false; bool change = false;
auto res = DoTransform(optimizer, node, substitution); auto res = DoTransform(optimizer, node, substitution);
if (res != nullptr) { if (res != nullptr) {
if (is_once_) {
return true;
}
change = true; change = true;
changes = true; changes = true;
node = res; node = res;
} }
if (change && is_once_) {
return true;
}
UpdateTransformingList(optimizer, node, &todo, change, seen); UpdateTransformingList(optimizer, node, &todo, change, seen);
} }




+ 8
- 2
mindspore/ccsrc/pipeline/jit/parse/function_block.cc View File

@@ -17,14 +17,17 @@
*/ */


#include "pipeline/jit/parse/function_block.h" #include "pipeline/jit/parse/function_block.h"

#include <string> #include <string>
#include <memory> #include <memory>

#include "pybind11/pybind11.h"
#include "pipeline/jit/parse/resolve.h" #include "pipeline/jit/parse/resolve.h"
#include "pipeline/jit/parse/parse.h" #include "pipeline/jit/parse/parse.h"
#include "frontend/operator/ops.h" #include "frontend/operator/ops.h"
#include "utils/info.h" #include "utils/info.h"
#include "debug/trace.h" #include "debug/trace.h"
#include "pybind11/pybind11.h"
#include "utils/utils.h"


namespace mindspore { namespace mindspore {
namespace py = pybind11; namespace py = pybind11;
@@ -435,7 +438,10 @@ void FunctionBlock::AttachIsolatedNodesBeforeReturn() {
old_output = NewValueNode(kNone); old_output = NewValueNode(kNone);
} }
AnfNodePtr stop_grad_node = func_graph()->NewCNode({NewValueNode(prim::kPrimStopGradient), state}); AnfNodePtr stop_grad_node = func_graph()->NewCNode({NewValueNode(prim::kPrimStopGradient), state});
AnfNodePtr depend_node = func_graph()->NewCNode({NewValueNode(prim::kPrimDepend), old_output, stop_grad_node});
CNodePtr depend_node = func_graph()->NewCNode({NewValueNode(prim::kPrimDepend), old_output, stop_grad_node});
// We add this attribute for @constexpr use scene, since we must infer them before other nodes.
// That means isolated nodes will be evaluated first. It's not complete, but works in most scenes.
depend_node->AddAttr(kAttrTopoSortRhsFirst, MakeValue(true));
MS_LOG(INFO) << "Attached for side-effect nodes, depend_node: " << depend_node->DebugString() MS_LOG(INFO) << "Attached for side-effect nodes, depend_node: " << depend_node->DebugString()
<< ", state: " << state->DebugString(2); << ", state: " << state->DebugString(2);
func_graph()->set_output(depend_node, true); func_graph()->set_output(depend_node, true);


+ 1
- 0
mindspore/ccsrc/utils/utils.h View File

@@ -397,6 +397,7 @@ constexpr auto kAttrRecompute = "recompute";
constexpr auto kAttrNeedCseAfterRecompute = "need_cse_after_recompute"; constexpr auto kAttrNeedCseAfterRecompute = "need_cse_after_recompute";
constexpr auto kAttrParallelDimInfo = "parallel_dim_info"; constexpr auto kAttrParallelDimInfo = "parallel_dim_info";
constexpr auto kAttrStitch = "stitch"; constexpr auto kAttrStitch = "stitch";
constexpr auto kAttrTopoSortRhsFirst = "topo_sort_rhs_first";


// attr value // attr value
constexpr auto kValueTargetSwitch = "target_switch"; constexpr auto kValueTargetSwitch = "target_switch";


+ 11
- 2
mindspore/core/ir/graph_utils.cc View File

@@ -32,6 +32,7 @@
#include "ir/func_graph.h" #include "ir/func_graph.h"
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
#include "utils/ms_context.h" #include "utils/ms_context.h"
#include "mindspore/ccsrc/utils/utils.h"


namespace mindspore { namespace mindspore {
std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, const IncludeFunc &include) { std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, const IncludeFunc &include) {
@@ -170,8 +171,16 @@ std::vector<FuncGraphPtr> BroadFirstSearchGraphUsed(FuncGraphPtr root) {
static void PushSuccessors(const CNodePtr &cnode, std::vector<AnfNodePtr> *vecs) { static void PushSuccessors(const CNodePtr &cnode, std::vector<AnfNodePtr> *vecs) {
auto &inputs = cnode->inputs(); auto &inputs = cnode->inputs();
vecs->reserve(vecs->size() + inputs.size()); vecs->reserve(vecs->size() + inputs.size());
// To keep evaluate order from left to right, we push inputs in reversed order.
vecs->insert(vecs->end(), inputs.rbegin(), inputs.rend());

// To keep sort order from left to right in default, if kAttrTopoSortRhsFirst not set.
auto attr_sort_rhs_first = cnode->GetAttr(kAttrTopoSortRhsFirst);
auto sort_rhs_first =
attr_sort_rhs_first != nullptr && attr_sort_rhs_first->isa<BoolImm>() && GetValue<bool>(attr_sort_rhs_first);
if (sort_rhs_first) {
vecs->insert(vecs->end(), inputs.cbegin(), inputs.cend());
} else {
vecs->insert(vecs->end(), inputs.crbegin(), inputs.crend());
}
} }


std::vector<AnfNodePtr> SuccDeeper(const AnfNodePtr &node) { std::vector<AnfNodePtr> SuccDeeper(const AnfNodePtr &node) {


+ 2
- 2
tests/st/ops/cpu/test_dot_op.py View File

@@ -174,8 +174,8 @@ def test_dot_008():
network = NetDot() network = NetDot()
try: try:
network(x2_tensor, x1_tensor) network(x2_tensor, x1_tensor)
except IndexError as e:
assert IndexError == type(e)
except ValueError as e:
assert ValueError == type(e)




@pytest.mark.level0 @pytest.mark.level0


+ 1
- 1
tests/ut/python/nn/test_nn_embedding.py View File

@@ -82,7 +82,7 @@ def test_check_multifield_embedding_false_type_field_id():


@non_graph_engine @non_graph_engine
def test_check_multifield_embedding_false_input_shape(): def test_check_multifield_embedding_false_input_shape():
with pytest.raises(IndexError):
with pytest.raises(ValueError):
compile_multi_field_embedding((8,), (8, 200), (8, 200), compile_multi_field_embedding((8,), (8, 200), (8, 200),
dtype.int16, dtype.float32, dtype.int16) dtype.int16, dtype.float32, dtype.int16)




+ 4
- 4
tests/ut/python/nn/test_ssim.py View File

@@ -84,7 +84,7 @@ def test_ssim_different_shape():
img1 = Tensor(np.random.random(shape_1)) img1 = Tensor(np.random.random(shape_1))
img2 = Tensor(np.random.random(shape_2)) img2 = Tensor(np.random.random(shape_2))
net = SSIMNet() net = SSIMNet()
with pytest.raises(ValueError):
with pytest.raises(TypeError):
_executor.compile(net, img1, img2) _executor.compile(net, img1, img2)




@@ -108,9 +108,9 @@ def test_ssim_invalid_5d_input():
invalid_img2 = Tensor(np.random.random(invalid_shape)) invalid_img2 = Tensor(np.random.random(invalid_shape))


net = SSIMNet() net = SSIMNet()
with pytest.raises(ValueError):
with pytest.raises(TypeError):
_executor.compile(net, invalid_img1, img2) _executor.compile(net, invalid_img1, img2)
with pytest.raises(ValueError):
with pytest.raises(TypeError):
_executor.compile(net, img1, invalid_img2) _executor.compile(net, img1, invalid_img2)
with pytest.raises(ValueError):
with pytest.raises(TypeError):
_executor.compile(net, invalid_img1, invalid_img2) _executor.compile(net, invalid_img1, invalid_img2)

Loading…
Cancel
Save