Browse Source

Remove redundant Min/Max ops for Bert

Update threshold for rounding when checking expected value in input tensor node
tags/v0.7.0-beta
Hoai Linh Tran 5 years ago
parent
commit
eae5f28256
4 changed files with 86 additions and 4 deletions
  1. +2
    -2
      mindspore/ccsrc/frontend/optimizer/irpass.cc
  2. +80
    -0
      mindspore/ccsrc/frontend/optimizer/irpass/value_based_eliminate.cc
  3. +2
    -0
      mindspore/ccsrc/frontend/optimizer/irpass/value_based_eliminate.h
  4. +2
    -2
      mindspore/core/ir/pattern_matcher.h

+ 2
- 2
mindspore/ccsrc/frontend/optimizer/irpass.cc View File

@@ -168,8 +168,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
{prim::kPrimSparseTensorGetIndices, prim::kPrimSparseTensorGetValues, prim::kPrimSparseTensorGetDenseShape});

// Value_Based Eliminate
value_based_eliminate_ =
MakeSubstitution(std::make_shared<ValueBasedEliminate>(), "value_based_eliminate", {prim::kPrimSelect});
value_based_eliminate_ = MakeSubstitution(std::make_shared<ValueBasedEliminate>(), "value_based_eliminate",
{prim::kPrimSelect, prim::kPrimMinimum, prim::kPrimMaximum});
}

ResolveIRPassLib::ResolveIRPassLib() {


+ 80
- 0
mindspore/ccsrc/frontend/optimizer/irpass/value_based_eliminate.cc View File

@@ -19,6 +19,9 @@
namespace mindspore {
namespace opt {
namespace irpass {
#define UPPER_FLT_LIMIT (FLT_MAX / 2.0)
#define LOWER_FLT_LIMIT (-FLT_MAX / 2.0)

bool IsCNodePositive(const AnfNodePtr &node) {
if (IsPrimitiveCNode(node, prim::kPrimReduceSum) || IsPrimitiveCNode(node, prim::kPrimSqueeze)) {
return IsCNodePositive(node->cast<CNodePtr>()->input(1));
@@ -29,17 +32,94 @@ bool IsCNodePositive(const AnfNodePtr &node) {
return false;
}

// check if a value is bigger than UPPER_FLT_LIMIT
bool IsNodeScalarMaxFLT(const AnfNodePtr &node) {
auto value_node = node->cast<ValueNodePtr>();
if (value_node == nullptr) {
return false;
}

auto value = value_node->value();
if (value == nullptr) {
return false;
}

auto scalar = value->cast<ScalarPtr>();
if (scalar != nullptr) {
if (scalar->isa<FloatImm>()) {
return GetValue<float>(scalar) > UPPER_FLT_LIMIT;
}
}
// Check for Tensor [] or Tensor [1]
auto tensor_ptr = value->cast<tensor::TensorPtr>();
if (tensor_ptr == nullptr) {
return false;
}
if (tensor_ptr->DataSize() > 1) {
return false;
}

TypeId tensor_type = tensor_ptr->Dtype()->type_id();
if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) {
float *data = reinterpret_cast<float *>(tensor_ptr->data_c());
return data[0] > UPPER_FLT_LIMIT;
}

return false;
}

// check if a value is smaller than LOWER_FLT_LIMIT
bool IsNodeScalarMinFLT(const AnfNodePtr &node) {
auto value_node = node->cast<ValueNodePtr>();
if (value_node == nullptr) {
return false;
}

auto value = value_node->value();
if (value == nullptr) {
return false;
}

auto scalar = value->cast<ScalarPtr>();
if (scalar != nullptr) {
if (scalar->isa<FloatImm>()) {
return GetValue<float>(scalar) < LOWER_FLT_LIMIT;
}
}
// Check for Tensor [] or Tensor [1]
auto tensor_ptr = value->cast<tensor::TensorPtr>();
if (tensor_ptr == nullptr) {
return false;
}
if (tensor_ptr->DataSize() > 1) {
return false;
}

TypeId tensor_type = tensor_ptr->Dtype()->type_id();
if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) {
float *data = reinterpret_cast<float *>(tensor_ptr->data_c());
return data[0] < LOWER_FLT_LIMIT;
}

return false;
}

AnfNodePtr ValueBasedEliminate::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
PatternNode x, y, z;
PConstant zero_(node, false, 0);
PConstant zero_scalar_(node, false, 0, true);

// {prim::kPrimSelect, {prim::kPrimGreater, X, 0}, Y, Z}} -> Y when X is always greater than 0
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimSelect, PPrimitive(prim::kPrimGreater, x, zero_), y, z), y,
IsCNodePositive(x.GetNode(node)));

MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimSelect, PPrimitive(prim::kPrimGreater, x, zero_scalar_), y, z), y,
IsCNodePositive(x.GetNode(node)));

MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimMaximum, x, y), x, IsNodeScalarMinFLT(y.GetNode(node)));

MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimMinimum, x, y), x, IsNodeScalarMaxFLT(y.GetNode(node)));

return nullptr;
}



+ 2
- 0
mindspore/ccsrc/frontend/optimizer/irpass/value_based_eliminate.h View File

@@ -32,6 +32,8 @@ namespace opt {
namespace irpass {

// {prim::kPrimSelect, {prim::kPrimGreater, X, 0}, Y, Z}} -> Y when X is always greater than 0
// {prim::kPrimMaximum, X, Y} -> X when Y is smaller than LOWER_FLT_LIMIT
// {prim::kPrimMinimum, X, Y} -> X when Y is greater than UPPER_FLT_LIMIT
class ValueBasedEliminate : public OptimizerCaller {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;


+ 2
- 2
mindspore/core/ir/pattern_matcher.h View File

@@ -487,7 +487,7 @@ class PConstant : public PBase<PConstant<T> > {
TypeId tensor_type = tensor_ptr->Dtype()->type_id();
if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) {
float *data2 = reinterpret_cast<float *>(tensor_ptr->data_c());
auto threshold = FLT_EPSILON * FLT_EPSILON;
auto threshold = FLT_MIN;
for (int i = 0; i < tensor_ptr->DataSize(); i++) {
if (fabs(data2[i] - check_value_) > threshold) {
return false;
@@ -496,7 +496,7 @@ class PConstant : public PBase<PConstant<T> > {
return true;
} else if (tensor_type == TypeId::kNumberTypeFloat64) {
double *data2 = reinterpret_cast<double *>(tensor_ptr->data_c());
auto threshold = DBL_EPSILON * DBL_EPSILON;
auto threshold = DBL_MIN;
for (int i = 0; i < tensor_ptr->DataSize(); i++) {
if (fabs(data2[i] - check_value_) > threshold) {
return false;


Loading…
Cancel
Save