| @@ -21,19 +21,10 @@ namespace opt { | |||
| namespace irpass { | |||
| #define UPPER_FLT_LIMIT (FLT_MAX / 2.0) | |||
| #define LOWER_FLT_LIMIT (-FLT_MAX / 2.0) | |||
| // Define the checking mode | |||
| enum ScalarCheckingMode : int { GREATER_EQUAL = 0, LESS }; | |||
| bool IsCNodePositive(const AnfNodePtr &node) { | |||
| if (IsPrimitiveCNode(node, prim::kPrimReduceSum) || IsPrimitiveCNode(node, prim::kPrimSqueeze)) { | |||
| return IsCNodePositive(node->cast<CNodePtr>()->input(1)); | |||
| } | |||
| if (IsPrimitiveCNode(node, prim::kPrimSquare) || IsPrimitiveCNode(node, prim::kPrimSqrt)) { | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| // check if a value is bigger than UPPER_FLT_LIMIT | |||
| bool IsNodeScalarMaxFLT(const AnfNodePtr &node) { | |||
| bool IsNodeScalarTrueWith(const AnfNodePtr &node, const ScalarCheckingMode &checking_mode, const float &check_value) { | |||
| auto value_node = node->cast<ValueNodePtr>(); | |||
| if (value_node == nullptr) { | |||
| return false; | |||
| @@ -47,7 +38,10 @@ bool IsNodeScalarMaxFLT(const AnfNodePtr &node) { | |||
| auto scalar = value->cast<ScalarPtr>(); | |||
| if (scalar != nullptr) { | |||
| if (scalar->isa<FloatImm>()) { | |||
| return GetValue<float>(scalar) > UPPER_FLT_LIMIT; | |||
| if (checking_mode == GREATER_EQUAL) { | |||
| return GetValue<float>(scalar) >= check_value; | |||
| } | |||
| return GetValue<float>(scalar) < check_value; | |||
| } | |||
| } | |||
| // Check for Tensor [] or Tensor [1] | |||
| @@ -62,48 +56,42 @@ bool IsNodeScalarMaxFLT(const AnfNodePtr &node) { | |||
| TypeId tensor_type = tensor_ptr->Dtype()->type_id(); | |||
| if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) { | |||
| float *data = reinterpret_cast<float *>(tensor_ptr->data_c()); | |||
| return data[0] > UPPER_FLT_LIMIT; | |||
| if (checking_mode == GREATER_EQUAL) { | |||
| return data[0] >= check_value; | |||
| } | |||
| return data[0] < check_value; | |||
| } | |||
| return false; | |||
| } | |||
| // check if a value is smaller than LOWER_FLT_LIMIT | |||
| bool IsNodeScalarMinFLT(const AnfNodePtr &node) { | |||
| auto value_node = node->cast<ValueNodePtr>(); | |||
| if (value_node == nullptr) { | |||
| return false; | |||
| } | |||
| // check if a value is greater or equal 0.0 | |||
| bool IsNodeScalarPositive(const AnfNodePtr &node) { return IsNodeScalarTrueWith(node, GREATER_EQUAL, 0.0); } | |||
| auto value = value_node->value(); | |||
| if (value == nullptr) { | |||
| return false; | |||
| } | |||
| auto scalar = value->cast<ScalarPtr>(); | |||
| if (scalar != nullptr) { | |||
| if (scalar->isa<FloatImm>()) { | |||
| return GetValue<float>(scalar) < LOWER_FLT_LIMIT; | |||
| } | |||
| } | |||
| // Check for Tensor [] or Tensor [1] | |||
| auto tensor_ptr = value->cast<tensor::TensorPtr>(); | |||
| if (tensor_ptr == nullptr) { | |||
| return false; | |||
| bool IsCNodePositive(const AnfNodePtr &node) { | |||
| if (IsPrimitiveCNode(node, prim::kPrimReduceSum) || IsPrimitiveCNode(node, prim::kPrimSqueeze)) { | |||
| return IsCNodePositive(node->cast<CNodePtr>()->input(1)); | |||
| } | |||
| if (tensor_ptr->DataSize() > 1) { | |||
| return false; | |||
| if (IsPrimitiveCNode(node, prim::kPrimSquare) || IsPrimitiveCNode(node, prim::kPrimSqrt)) { | |||
| return true; | |||
| } | |||
| TypeId tensor_type = tensor_ptr->Dtype()->type_id(); | |||
| if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) { | |||
| float *data = reinterpret_cast<float *>(tensor_ptr->data_c()); | |||
| return data[0] < LOWER_FLT_LIMIT; | |||
| if (IsPrimitiveCNode(node, prim::kPrimMinimum) || IsPrimitiveCNode(node, prim::kPrimRealDiv)) { | |||
| auto first_node_positive = | |||
| IsCNodePositive(node->cast<CNodePtr>()->input(1)) || IsNodeScalarPositive(node->cast<CNodePtr>()->input(1)); | |||
| auto second_node_positive = | |||
| IsCNodePositive(node->cast<CNodePtr>()->input(2)) || IsNodeScalarPositive(node->cast<CNodePtr>()->input(2)); | |||
| return first_node_positive && second_node_positive; | |||
| } | |||
| return false; | |||
| } | |||
| // check if a value is greater or equal UPPER_FLT_LIMIT | |||
| bool IsNodeScalarMaxFLT(const AnfNodePtr &node) { return IsNodeScalarTrueWith(node, GREATER_EQUAL, UPPER_FLT_LIMIT); } | |||
| // check if a value is smaller than LOWER_FLT_LIMIT | |||
| bool IsNodeScalarMinFLT(const AnfNodePtr &node) { return IsNodeScalarTrueWith(node, LESS, LOWER_FLT_LIMIT); } | |||
| AnfNodePtr ValueBasedEliminate::operator()(const OptimizerPtr &, const AnfNodePtr &node) { | |||
| PatternNode x, y, z; | |||
| PConstant zero_(node, false, 0); | |||
| @@ -116,10 +104,15 @@ AnfNodePtr ValueBasedEliminate::operator()(const OptimizerPtr &, const AnfNodePt | |||
| MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimSelect, PPrimitive(prim::kPrimGreater, x, zero_scalar_), y, z), y, | |||
| IsCNodePositive(x.GetNode(node))); | |||
| // {prim::kPrimMaximum, X, LOWER_FLT_LIMIT}} -> X | |||
| MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimMaximum, x, y), x, IsNodeScalarMinFLT(y.GetNode(node))); | |||
| // {prim::kPrimMinimum, X, UPPER_FLT_LIMIT}} -> X | |||
| MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimMinimum, x, y), x, IsNodeScalarMaxFLT(y.GetNode(node))); | |||
| // {prim::kPrimMaximum, X, 0}} -> X when X is always greater or equal 0 | |||
| MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimMaximum, x, zero_), x, IsCNodePositive(x.GetNode(node))); | |||
| return nullptr; | |||
| } | |||