|
|
|
@@ -21,19 +21,10 @@ namespace opt { |
|
|
|
namespace irpass { |
|
|
|
#define UPPER_FLT_LIMIT (FLT_MAX / 2.0) |
|
|
|
#define LOWER_FLT_LIMIT (-FLT_MAX / 2.0) |
|
|
|
// Define the checking mode |
|
|
|
enum ScalarCheckingMode : int { GREATER_EQUAL = 0, LESS }; |
|
|
|
|
|
|
|
bool IsCNodePositive(const AnfNodePtr &node) { |
|
|
|
if (IsPrimitiveCNode(node, prim::kPrimReduceSum) || IsPrimitiveCNode(node, prim::kPrimSqueeze)) { |
|
|
|
return IsCNodePositive(node->cast<CNodePtr>()->input(1)); |
|
|
|
} |
|
|
|
if (IsPrimitiveCNode(node, prim::kPrimSquare) || IsPrimitiveCNode(node, prim::kPrimSqrt)) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
// check if a value is bigger than UPPER_FLT_LIMIT |
|
|
|
bool IsNodeScalarMaxFLT(const AnfNodePtr &node) { |
|
|
|
bool IsNodeScalarTrueWith(const AnfNodePtr &node, const ScalarCheckingMode &checking_mode, const float &check_value) { |
|
|
|
auto value_node = node->cast<ValueNodePtr>(); |
|
|
|
if (value_node == nullptr) { |
|
|
|
return false; |
|
|
|
@@ -47,7 +38,10 @@ bool IsNodeScalarMaxFLT(const AnfNodePtr &node) { |
|
|
|
auto scalar = value->cast<ScalarPtr>(); |
|
|
|
if (scalar != nullptr) { |
|
|
|
if (scalar->isa<FloatImm>()) { |
|
|
|
return GetValue<float>(scalar) > UPPER_FLT_LIMIT; |
|
|
|
if (checking_mode == GREATER_EQUAL) { |
|
|
|
return GetValue<float>(scalar) >= check_value; |
|
|
|
} |
|
|
|
return GetValue<float>(scalar) < check_value; |
|
|
|
} |
|
|
|
} |
|
|
|
// Check for Tensor [] or Tensor [1] |
|
|
|
@@ -62,48 +56,42 @@ bool IsNodeScalarMaxFLT(const AnfNodePtr &node) { |
|
|
|
TypeId tensor_type = tensor_ptr->Dtype()->type_id(); |
|
|
|
if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) { |
|
|
|
float *data = reinterpret_cast<float *>(tensor_ptr->data_c()); |
|
|
|
return data[0] > UPPER_FLT_LIMIT; |
|
|
|
if (checking_mode == GREATER_EQUAL) { |
|
|
|
return data[0] >= check_value; |
|
|
|
} |
|
|
|
return data[0] < check_value; |
|
|
|
} |
|
|
|
|
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
// check if a value is smaller than LOWER_FLT_LIMIT |
|
|
|
bool IsNodeScalarMinFLT(const AnfNodePtr &node) { |
|
|
|
auto value_node = node->cast<ValueNodePtr>(); |
|
|
|
if (value_node == nullptr) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
// check if a value is greater or equal 0.0 |
|
|
|
bool IsNodeScalarPositive(const AnfNodePtr &node) { return IsNodeScalarTrueWith(node, GREATER_EQUAL, 0.0); } |
|
|
|
|
|
|
|
auto value = value_node->value(); |
|
|
|
if (value == nullptr) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
auto scalar = value->cast<ScalarPtr>(); |
|
|
|
if (scalar != nullptr) { |
|
|
|
if (scalar->isa<FloatImm>()) { |
|
|
|
return GetValue<float>(scalar) < LOWER_FLT_LIMIT; |
|
|
|
} |
|
|
|
} |
|
|
|
// Check for Tensor [] or Tensor [1] |
|
|
|
auto tensor_ptr = value->cast<tensor::TensorPtr>(); |
|
|
|
if (tensor_ptr == nullptr) { |
|
|
|
return false; |
|
|
|
bool IsCNodePositive(const AnfNodePtr &node) { |
|
|
|
if (IsPrimitiveCNode(node, prim::kPrimReduceSum) || IsPrimitiveCNode(node, prim::kPrimSqueeze)) { |
|
|
|
return IsCNodePositive(node->cast<CNodePtr>()->input(1)); |
|
|
|
} |
|
|
|
if (tensor_ptr->DataSize() > 1) { |
|
|
|
return false; |
|
|
|
if (IsPrimitiveCNode(node, prim::kPrimSquare) || IsPrimitiveCNode(node, prim::kPrimSqrt)) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
TypeId tensor_type = tensor_ptr->Dtype()->type_id(); |
|
|
|
if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) { |
|
|
|
float *data = reinterpret_cast<float *>(tensor_ptr->data_c()); |
|
|
|
return data[0] < LOWER_FLT_LIMIT; |
|
|
|
if (IsPrimitiveCNode(node, prim::kPrimMinimum) || IsPrimitiveCNode(node, prim::kPrimRealDiv)) { |
|
|
|
auto first_node_positive = |
|
|
|
IsCNodePositive(node->cast<CNodePtr>()->input(1)) || IsNodeScalarPositive(node->cast<CNodePtr>()->input(1)); |
|
|
|
auto second_node_positive = |
|
|
|
IsCNodePositive(node->cast<CNodePtr>()->input(2)) || IsNodeScalarPositive(node->cast<CNodePtr>()->input(2)); |
|
|
|
return first_node_positive && second_node_positive; |
|
|
|
} |
|
|
|
|
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
// check if a value is greater or equal UPPER_FLT_LIMIT |
|
|
|
bool IsNodeScalarMaxFLT(const AnfNodePtr &node) { return IsNodeScalarTrueWith(node, GREATER_EQUAL, UPPER_FLT_LIMIT); } |
|
|
|
|
|
|
|
// check if a value is smaller than LOWER_FLT_LIMIT |
|
|
|
bool IsNodeScalarMinFLT(const AnfNodePtr &node) { return IsNodeScalarTrueWith(node, LESS, LOWER_FLT_LIMIT); } |
|
|
|
|
|
|
|
AnfNodePtr ValueBasedEliminate::operator()(const OptimizerPtr &, const AnfNodePtr &node) { |
|
|
|
PatternNode x, y, z; |
|
|
|
PConstant zero_(node, false, 0); |
|
|
|
@@ -116,10 +104,15 @@ AnfNodePtr ValueBasedEliminate::operator()(const OptimizerPtr &, const AnfNodePt |
|
|
|
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimSelect, PPrimitive(prim::kPrimGreater, x, zero_scalar_), y, z), y, |
|
|
|
IsCNodePositive(x.GetNode(node))); |
|
|
|
|
|
|
|
// {prim::kPrimMaximum, X, LOWER_FLT_LIMIT}} -> X |
|
|
|
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimMaximum, x, y), x, IsNodeScalarMinFLT(y.GetNode(node))); |
|
|
|
|
|
|
|
// {prim::kPrimMinimum, X, UPPER_FLT_LIMIT}} -> X |
|
|
|
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimMinimum, x, y), x, IsNodeScalarMaxFLT(y.GetNode(node))); |
|
|
|
|
|
|
|
// {prim::kPrimMaximum, X, 0}} -> X when X is always greater or equal 0 |
|
|
|
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimMaximum, x, zero_), x, IsCNodePositive(x.GetNode(node))); |
|
|
|
|
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
} // namespace irpass |
|
|
|
|