Browse Source

Fix a bug after old side-effect flag removed

The problem:
After the old 'side_effect' flag was removed from the Push primitive,
CSE will incorrectly merge some Push nodes.

To fix this:
We rename the "_random_effect" flag to the more appropriate "side_effect_hidden",
and add this flag to the Push primitive to prevent it from being merged by CSE.
feature/build-system-rewrite
He Wei 4 years ago
parent
commit
227dbcb101
8 changed files with 21 additions and 21 deletions
  1. +1
    -1
      mindspore/ccsrc/backend/common/pass/common_subexpression_elimination.cc
  2. +3
    -4
      mindspore/ccsrc/frontend/optimizer/cse.cc
  3. +1
    -1
      mindspore/ccsrc/frontend/optimizer/cse.h
  4. +4
    -4
      mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
  5. +1
    -1
      mindspore/core/utils/flags.h
  6. +1
    -1
      mindspore/python/mindspore/ops/operations/nn_ops.py
  7. +1
    -0
      mindspore/python/mindspore/ops/operations/other_ops.py
  8. +9
    -9
      mindspore/python/mindspore/ops/operations/random_ops.py

+ 1
- 1
mindspore/ccsrc/backend/common/pass/common_subexpression_elimination.cc View File

@@ -111,7 +111,7 @@ bool BackendCSE::CheckCNode(const CNodePtr &main, const CNodePtr &node) const {
if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_LOOP_SINK) && CheckIgnoreCase(main)) {
return false;
}
if (HasRandomEffect(main) || HasRandomEffect(node)) {
if (HasHiddenSideEffect(main) || HasHiddenSideEffect(node)) {
return false;
}
if (!CheckEqualKernelBuildInfo(main, node)) {


+ 3
- 4
mindspore/ccsrc/frontend/optimizer/cse.cc View File

@@ -133,13 +133,12 @@ bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const {
return changed;
}

bool CSE::HasRandomEffect(const AnfNodePtr &node) {
bool CSE::HasHiddenSideEffect(const AnfNodePtr &node) {
auto prim = GetCNodePrimitive(node);
if (prim == nullptr) {
return false;
}
auto attr = prim->GetAttr(GRAPH_FLAG_RANDOM_EFFECT);
return (attr != nullptr) && attr->isa<BoolImm>() && GetValue<bool>(attr);
return prim->HasAttr(GRAPH_FLAG_SIDE_EFFECT_HIDDEN);
}

bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const {
@@ -181,7 +180,7 @@ bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const {
return false;
}
// We don't merge primitive cnodes with random effect.
return !HasRandomEffect(c_main);
return !HasHiddenSideEffect(c_main);
}
// a parameter node.
return false;


+ 1
- 1
mindspore/ccsrc/frontend/optimizer/cse.h View File

@@ -38,7 +38,7 @@ class CSE {

virtual bool Cse(const FuncGraphPtr root, const FuncGraphManagerPtr manager) const;

static bool HasRandomEffect(const AnfNodePtr &node);
static bool HasHiddenSideEffect(const AnfNodePtr &node);

protected:
bool BuildOrderGroupAndDoReplaceForOneGraph(const FuncGraphPtr &fg, const FuncGraphManagerPtr &manager) const;


+ 4
- 4
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc View File

@@ -357,7 +357,7 @@ void GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info, const std::vector<t
bool has_const_input = false;
const auto &op_prim = op_exec_info->py_primitive;
MS_EXCEPTION_IF_NULL(op_prim);
bool is_random_effect_op = op_prim->HasAttr(GRAPH_FLAG_RANDOM_EFFECT);
bool has_hidden_side_effect = op_prim->HasAttr(GRAPH_FLAG_SIDE_EFFECT_HIDDEN);
for (size_t index = 0; index < input_tensors.size(); ++index) {
MS_EXCEPTION_IF_NULL(input_tensors[index]);
buf << input_tensors[index]->shape();
@@ -365,7 +365,7 @@ void GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info, const std::vector<t
buf << input_tensors[index]->padding_type();
// In the case of the same shape, but dtype and format are inconsistent
auto tensor_addr = input_tensors[index]->device_address();
if (tensor_addr != nullptr && !is_random_effect_op) {
if (tensor_addr != nullptr && !has_hidden_side_effect) {
auto p_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor_addr);
MS_EXCEPTION_IF_NULL(p_address);
buf << p_address->type_id();
@@ -407,8 +407,8 @@ void GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info, const std::vector<t
buf << build_type->type_id();
}

// Random effect operator
if (is_random_effect_op) {
// Operator with hidden side effect.
if (has_hidden_side_effect) {
buf << "_" << std::to_string(op_prim->id());
}



+ 1
- 1
mindspore/core/utils/flags.h View File

@@ -22,9 +22,9 @@ namespace mindspore {
inline const char GRAPH_FLAG_MIX_PRECISION_FP16[] = "fp16";
inline const char GRAPH_FLAG_MIX_PRECISION_FP32[] = "fp32";
inline const char GRAPH_FLAG_CACHE_ENABLE[] = "cache_enable";
inline const char GRAPH_FLAG_RANDOM_EFFECT[] = "_random_effect";
inline const char GRAPH_FLAG_SIDE_EFFECT_IO[] = "side_effect_io";
inline const char GRAPH_FLAG_SIDE_EFFECT_MEM[] = "side_effect_mem";
inline const char GRAPH_FLAG_SIDE_EFFECT_HIDDEN[] = "side_effect_hidden";
inline const char GRAPH_FLAG_SIDE_EFFECT_EXCEPTION[] = "side_effect_exception";
inline const char GRAPH_FLAG_SIDE_EFFECT_PROPAGATE[] = "side_effect_propagate";
inline const char GRAPH_FLAG_SIDE_EFFECT_BACKPROP[] = "side_effect_backprop";


+ 1
- 1
mindspore/python/mindspore/ops/operations/nn_ops.py View File

@@ -3178,7 +3178,7 @@ class DropoutGenMask(Primitive):
self.init_prim_io_names(inputs=['shape', 'keep_prob'], outputs=['output'])
validator.check_value_type("Seed0", Seed0, [int], self.name)
validator.check_value_type("Seed1", Seed1, [int], self.name)
self.add_prim_attr("_random_effect", True)
self.add_prim_attr("side_effect_hidden", True)


class DropoutDoMask(Primitive):


+ 1
- 0
mindspore/python/mindspore/ops/operations/other_ops.py View File

@@ -678,6 +678,7 @@ class Push(PrimitiveWithInfer):
"""Initialize Push"""
self.add_prim_attr("primitive_target", "CPU")
self.init_prim_io_names(inputs=['optim_inputs', 'optim_input_shapes'], outputs=['key'])
self.add_prim_attr("side_effect_hidden", True)

def infer_shape(self, inputs, shapes):
return [1]


+ 9
- 9
mindspore/python/mindspore/ops/operations/random_ops.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -62,7 +62,7 @@ class StandardNormal(PrimitiveWithInfer):
def __init__(self, seed=0, seed2=0):
"""Initialize StandardNormal"""
self.init_prim_io_names(inputs=['shape'], outputs=['output'])
self.add_prim_attr("_random_effect", True)
self.add_prim_attr("side_effect_hidden", True)
Validator.check_non_negative_int(seed, "seed", self.name)
Validator.check_non_negative_int(seed2, "seed2", self.name)

@@ -119,7 +119,7 @@ class StandardLaplace(PrimitiveWithInfer):
def __init__(self, seed=0, seed2=0):
"""Initialize StandardLaplace"""
self.init_prim_io_names(inputs=['shape'], outputs=['output'])
self.add_prim_attr("_random_effect", True)
self.add_prim_attr("side_effect_hidden", True)
Validator.check_value_type('seed', seed, [int], self.name)
Validator.check_value_type('seed2', seed2, [int], self.name)

@@ -196,7 +196,7 @@ class Gamma(PrimitiveWithInfer):
def __init__(self, seed=0, seed2=0):
"""Initialize Gamma"""
self.init_prim_io_names(inputs=['shape', 'alpha', 'beta'], outputs=['output'])
self.add_prim_attr("_random_effect", True)
self.add_prim_attr("side_effect_hidden", True)
Validator.check_non_negative_int(seed, "seed", self.name)
Validator.check_non_negative_int(seed2, "seed2", self.name)

@@ -262,7 +262,7 @@ class Poisson(PrimitiveWithInfer):
def __init__(self, seed=0, seed2=0):
"""Initialize Poisson"""
self.init_prim_io_names(inputs=['shape', 'mean'], outputs=['output'])
self.add_prim_attr("_random_effect", True)
self.add_prim_attr("side_effect_hidden", True)
Validator.check_non_negative_int(seed, "seed", self.name)
Validator.check_non_negative_int(seed2, "seed2", self.name)

@@ -334,7 +334,7 @@ class UniformInt(PrimitiveWithInfer):
def __init__(self, seed=0, seed2=0):
"""Initialize UniformInt"""
self.init_prim_io_names(inputs=['shape', 'minval', 'maxval'], outputs=['output'])
self.add_prim_attr("_random_effect", True)
self.add_prim_attr("side_effect_hidden", True)
Validator.check_non_negative_int(seed, "seed", self.name)
Validator.check_non_negative_int(seed2, "seed2", self.name)

@@ -451,7 +451,7 @@ class RandomChoiceWithMask(PrimitiveWithInfer):
Validator.check_positive_int(count, "count", self.name)
Validator.check_value_type('seed', seed, [int], self.name)
Validator.check_value_type('seed2', seed2, [int], self.name)
self.add_prim_attr("_random_effect", True)
self.add_prim_attr("side_effect_hidden", True)

def infer_shape(self, x_shape):
Validator.check_int(len(x_shape), 1, Rel.GE, "input_x rank", self.name)
@@ -513,7 +513,7 @@ class RandomCategorical(PrimitiveWithInfer):
Validator.check_type_name("dtype", dtype, valid_values, self.name)
self.init_prim_io_names(inputs=['logits', 'num_samples', 'seed'],
outputs=['output'])
self.add_prim_attr("_random_effect", True)
self.add_prim_attr("side_effect_hidden", True)

def __infer__(self, logits, num_samples, seed):
logits_dtype = logits['dtype']
@@ -580,7 +580,7 @@ class Multinomial(PrimitiveWithInfer):
Validator.check_non_negative_int(seed, "seed", self.name)
Validator.check_non_negative_int(seed2, "seed2", self.name)
self.init_prim_io_names(inputs=['input', 'num_sample'], outputs=['output'])
self.add_prim_attr("_random_effect", True)
self.add_prim_attr("side_effect_hidden", True)

def __infer__(self, inputs, num_samples):
input_shape = inputs["shape"]


Loading…
Cancel
Save