Browse Source

Fix a CSE bug for operations like `DropoutGenMask`, which must not

be merged by CSE because they generate different values on each execution.
tags/v0.3.0-alpha
seatea 5 years ago
parent
commit
981b013f81
6 changed files with 23 additions and 2 deletions
  1. +17
    -1
      mindspore/ccsrc/optimizer/cse.cc
  2. +2
    -0
      mindspore/ccsrc/optimizer/cse.h
  3. +1
    -0
      mindspore/ccsrc/pybind_api/export_flags.cc
  4. +1
    -0
      mindspore/ccsrc/pybind_api/export_flags.h
  5. +1
    -0
      mindspore/ops/operations/nn_ops.py
  6. +1
    -1
      tests/st/networks/models/bert/bert_tdt_lossscale.py

+ 17
- 1
mindspore/ccsrc/optimizer/cse.cc View File

@@ -90,6 +90,22 @@ bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const {
return changed;
}

bool CSE::CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const {
bool has_random_effect = false;
auto prim_main = GetCNodePrimitive(main);
auto prim_node = GetCNodePrimitive(node);
if (prim_main == prim_node) {
return false;
}
if (prim_main != nullptr) {
auto effect_val = prim_main->GetAttr(GRAPH_FLAG_RANDOM_EFFECT);
if (effect_val != nullptr && effect_val->isa<BoolImm>()) {
has_random_effect = GetValue<bool>(effect_val);
}
}
return has_random_effect;
}

bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const {
MS_EXCEPTION_IF_NULL(main);
MS_EXCEPTION_IF_NULL(node);
@@ -122,7 +138,7 @@ bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const {
break;
}
}
if (IsPrimitiveCNode(c_main, prim::kPrimDropoutGenMask)) {
if (CheckRandomEffect(c_main, c_node)) {
appsame = false;
}
replace = appsame;


+ 2
- 0
mindspore/ccsrc/optimizer/cse.h View File

@@ -43,6 +43,8 @@ class CSE {

virtual bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const;

virtual bool CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const;

bool Cse(const FuncGraphPtr root, const FuncGraphManagerPtr manager) const;

private:


+ 1
- 0
mindspore/ccsrc/pybind_api/export_flags.cc View File

@@ -32,5 +32,6 @@ const char GRAPH_FLAG_MIX_PRECISION_FP32[] = "fp32";
const char GRAPH_FLAG_LOOP_CAN_UNROLL[] = "loop_can_unroll";
const char GRAPH_FLAG_HAS_EFFECT[] = "has_effect";
const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[] = "_effect_patial_order";
const char GRAPH_FLAG_RANDOM_EFFECT[] = "_random_effect";

} // namespace mindspore

+ 1
- 0
mindspore/ccsrc/pybind_api/export_flags.h View File

@@ -33,6 +33,7 @@ extern const char GRAPH_FLAG_MIX_PRECISION_FP32[];
extern const char GRAPH_FLAG_LOOP_CAN_UNROLL[];
extern const char GRAPH_FLAG_HAS_EFFECT[];
extern const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[];
extern const char GRAPH_FLAG_RANDOM_EFFECT[];

} // namespace mindspore



+ 1
- 0
mindspore/ops/operations/nn_ops.py View File

@@ -1877,6 +1877,7 @@ class DropoutGenMask(Primitive):
self.init_prim_io_names(inputs=['shape', 'keep_prob'], outputs=['output'])
validator.check_value_type("Seed0", Seed0, [int], self.name)
validator.check_value_type("Seed1", Seed1, [int], self.name)
self.add_prim_attr("_random_effect", True)


class DropoutDoMask(PrimitiveWithInfer):


+ 1
- 1
tests/st/networks/models/bert/bert_tdt_lossscale.py View File

@@ -162,7 +162,7 @@ def test_bert_tdt():

# assertion occurs while the loss value, overflow state or loss_scale value is wrong
loss_value = np.array(callback.loss_list)
expect_loss_value = [12.1918125, 11.966035, 11.972114, 11.982188, 11.974092, 12.610916, 12.17565, 12.840416, 12.40291, 12.621661]
expect_loss_value = [12.1918125, 11.966035, 11.972114, 11.982189, 11.973948, 12.610932, 12.17564, 12.840248, 12.40294, 12.621653]
print("loss value: {}".format(loss_value))
assert np.allclose(loss_value, expect_loss_value, 0.00001, 0.00001)


Loading…
Cancel
Save