Browse Source

!1615 convert constant bool tensor to bool

Merge pull request !1615 from amongo/GetBoolValueFromConstantTensor
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
b236beae28
18 changed files with 168 additions and 11 deletions
  1. +1
    -1
      mindspore/_extends/parse/resources.py
  2. +37
    -2
      mindspore/_extends/parse/standard_method.py
  3. +4
    -0
      mindspore/_extends/parse/trope.py
  4. +10
    -0
      mindspore/ccsrc/debug/trace_info.h
  5. +1
    -0
      mindspore/ccsrc/operator/ops.cc
  6. +1
    -0
      mindspore/ccsrc/operator/ops.h
  7. +13
    -1
      mindspore/ccsrc/operator/prim_statement.cc
  8. +7
    -0
      mindspore/ccsrc/pipeline/parse/function_block.cc
  9. +1
    -0
      mindspore/ccsrc/pipeline/parse/function_block.h
  10. +1
    -0
      mindspore/ccsrc/pipeline/parse/parse.cc
  11. +1
    -0
      mindspore/ccsrc/pipeline/static_analysis/prim.cc
  12. +2
    -0
      mindspore/ccsrc/pipeline/static_analysis/prim.h
  13. +1
    -0
      mindspore/ops/functional.py
  14. +1
    -1
      mindspore/ops/operations/math_ops.py
  15. +2
    -2
      tests/ut/cpp/optimizer/lib_test.cc
  16. +80
    -0
      tests/ut/python/ops/test_control_ops.py
  17. +1
    -1
      tests/ut/python/ops/test_math_ops.py
  18. +4
    -3
      tests/ut/python/pynative_mode/test_framstruct.py

+ 1
- 1
mindspore/_extends/parse/resources.py View File

@@ -126,7 +126,7 @@ convert_object_map = {
T.make_list: F.make_list,
T.make_slice: F.make_slice,
T.range: F.make_range,
T.while_cond: M.while_cond,
# lib function
math.floor: NO_IMPLEMENT,
math.trunc: NO_IMPLEMENT,


+ 37
- 2
mindspore/_extends/parse/standard_method.py View File

@@ -16,8 +16,10 @@
# ============================================================================
"""standard_method"""
from dataclasses import dataclass
from mindspore.common import dtype as mstype
from ...ops import functional as F
from ...ops import operations as P
from ...ops.primitive import constexpr
from ...ops.composite import tail, core, MultitypeFuncGraph, env_get, hyper_add, \
zeros_like, ones_like
from ...ops.composite.base import _append
@@ -102,11 +104,44 @@ def bool_(x):
return x.__bool__()


def tensor_bool(x):
"""return immedate x, x is a tensor of bool value"""
def while_cond(x):
"""For while condtion, if the condition is a tensor, the loop will not be unrolled"""
if F.issubclass_(F.typeof(x), F.typeof(mstype.tensor)):
is_cond = check_is_tensor_bool_cond(F.shape(x))
if is_cond:
return F.cast(x, mstype.bool_)
return x


@constexpr
def check_is_tensor_bool_cond(shp):
"""check if tensor is a bool condition"""
if shp in ((), (1,)):
return True
raise ValueError("tensor as bool condition, its shape should be () or (1,), but got ", shp)

@constexpr
def const_tensor_to_bool(x):
"""convert bool tensor to bool condition"""
if x is None:
raise ValueError("Only constant tensor bool can be converted to bool")
x = x.asnumpy()
if x.shape not in ((), (1,)):
raise ValueError("Tensor to bool should input shape () or (1), but got ", x.shape)
if x.shape == ():
value = bool(x)
else:
value = bool(x[0])
return value

def tensor_bool(x):
"""tensor as conditon, if is constant, return immediate bool value"""
is_cond = check_is_tensor_bool_cond(F.shape(x))
if is_cond and F.isconstant(x):
return const_tensor_to_bool(x)
return F.cast(x, mstype.bool_)


def and_(x, y):
"""Implementation of `and` (`&`)."""
return x.__and__(y)


+ 4
- 0
mindspore/_extends/parse/trope.py View File

@@ -91,3 +91,7 @@ def to_array(x): # pragma: no cover
def not_contains(x): # pragma: no cover
"""Not in function."""
raise RuntimeError('This operation is not meant to be called directly.')

def while_cond(x): # pragma: no cover
"""Not in function."""
raise RuntimeError('This operation is not meant to be called directly.')

+ 10
- 0
mindspore/ccsrc/debug/trace_info.h View File

@@ -281,6 +281,16 @@ class TraceForceBool : public TraceInfo {
TraceInfoPtr clone() override { return std::make_shared<TraceForceBool>(*shared_from_base<TraceForceBool>()); }
};

class TraceForceWhileCond : public TraceInfo {
public:
explicit TraceForceWhileCond(const DebugInfoPtr &info) : TraceInfo(info, "force_while_cond", "") {}
MS_DECLARE_PARENT(TraceForceWhileCond, TraceInfo);
~TraceForceWhileCond() override = default;
TraceInfoPtr clone() override {
return std::make_shared<TraceForceWhileCond>(*shared_from_base<TraceForceWhileCond>());
}
};

class TraceExpandJ : public TraceInfo {
public:
explicit TraceExpandJ(const DebugInfoPtr &info) : TraceInfo(info, "expand_j", "") {}


+ 1
- 0
mindspore/ccsrc/operator/ops.cc View File

@@ -243,6 +243,7 @@ const PrimitivePtr kPrimIsNot = std::make_shared<Primitive>("is_not");
const PrimitivePtr kPrimInDict = std::make_shared<Primitive>("in_dict");
const PrimitivePtr kPrimNotInDict = std::make_shared<Primitive>("not_in_dict");
const PrimitivePtr kPrimMixedPrecisionCast = std::make_shared<Primitive>("mixed_precision_cast");
const PrimitivePtr kPrimIsConsant = std::make_shared<Primitive>("is_constant");

// Comm ops
const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");


+ 1
- 0
mindspore/ccsrc/operator/ops.h View File

@@ -252,6 +252,7 @@ extern const PrimitivePtr kPrimIsNot;
extern const PrimitivePtr kPrimInDict;
extern const PrimitivePtr kPrimNotInDict;
extern const PrimitivePtr kPrimMixedPrecisionCast;
extern const PrimitivePtr kPrimIsConsant;

// Comm ops
extern const PrimitivePtr kPrimMirror;


+ 13
- 1
mindspore/ccsrc/operator/prim_statement.cc View File

@@ -110,7 +110,8 @@ AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &,

ValuePtr v = cond->GetValueTrack();
MS_EXCEPTION_IF_NULL(v);
if (v->isa<AnyValue>()) {
// for tensor as condition, keeps both true and false branch.
if (v->isa<AnyValue>() || cond->isa<AbstractTensor>()) {
MS_EXCEPTION_IF_NULL(tb);
return tb->Join(fb);
}
@@ -228,5 +229,16 @@ AbstractBasePtr InferImplNotInDict(const AnalysisEnginePtr &, const PrimitivePtr
// Inputs: x, t
return std::make_shared<AbstractScalar>(!IsInDict(primitive, args_spec_list));
}
AbstractBasePtr InferImplIsConstant(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// statement: isconstant(x)
// Inputs: x
if (args_spec_list.size() != 1) {
MS_LOG(EXCEPTION) << "IsConstant requires args input size = 1";
}
ValuePtr v = args_spec_list[0]->BuildValue();
return std::make_shared<AbstractScalar>(!v->isa<AnyValue>());
}

} // namespace abstract
} // namespace mindspore

+ 7
- 0
mindspore/ccsrc/pipeline/parse/function_block.cc View File

@@ -265,6 +265,13 @@ CNodePtr FunctionBlock::ForceToBoolNode(const AnfNodePtr &cond) {
return op_apply_node;
}

CNodePtr FunctionBlock::ForceToWhileCond(const AnfNodePtr &cond) {
TraceManager::DebugTrace(std::make_shared<TraceForceWhileCond>(cond->debug_info()));
CNodePtr op_apply_node = func_graph()->NewCNode({MakeResolveOperation("while_cond"), cond});
TraceManager::EndTrace();
return op_apply_node;
}

// Perform a jump from this block to target block
void FunctionBlock::Jump(const FunctionBlockPtr &target_block, AnfNodePtr node) {
if (func_graph()->get_return() != nullptr) {


+ 1
- 0
mindspore/ccsrc/pipeline/parse/function_block.h View File

@@ -55,6 +55,7 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
// A block is matured if all its predecessors is generated
void Mature();
CNodePtr ForceToBoolNode(const AnfNodePtr &cond);
CNodePtr ForceToWhileCond(const AnfNodePtr &cond);
void Jump(const FunctionBlockPtr &block, AnfNodePtr node);
AnfNodePtr SearchReplaceNode(const std::string &var, const ParameterPtr &phi);
void ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr &trueBlock, const FunctionBlockPtr &falseBlock);


+ 1
- 0
mindspore/ccsrc/pipeline/parse/parse.cc View File

@@ -967,6 +967,7 @@ FunctionBlockPtr Parser::ParseWhile(const FunctionBlockPtr &block, const py::obj

py::object test_node = python_adapter::GetPyObjAttr(node, "test");
AnfNodePtr condition_node = ParseExprNode(header_block, test_node);
condition_node = header_block->ForceToWhileCond(condition_node);
body_block->Mature();
header_block->ConditionalJump(condition_node, body_block, after_block);



+ 1
- 0
mindspore/ccsrc/pipeline/static_analysis/prim.cc View File

@@ -55,6 +55,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimIsNot, {InferImplIsNot, true}},
{prim::kPrimInDict, {InferImplInDict, true}},
{prim::kPrimNotInDict, {InferImplNotInDict, true}},
{prim::kPrimIsConsant, {InferImplIsConstant, true}},
// Maths
{prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}},
{prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}},


+ 2
- 0
mindspore/ccsrc/pipeline/static_analysis/prim.h View File

@@ -200,6 +200,8 @@ AbstractBasePtr InferImplInDict(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplNotInDict(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplIsConstant(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplPoolingGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,


+ 1
- 0
mindspore/ops/functional.py View File

@@ -26,6 +26,7 @@ typeof = Primitive('typeof')
hastype = Primitive('hastype')
cast = P.Cast()
dtype = P.DType()
isconstant = Primitive('is_constant')


issubclass_ = P.IsSubClass()


+ 1
- 1
mindspore/ops/operations/math_ops.py View File

@@ -2294,7 +2294,7 @@ class Abs(PrimitiveWithInfer):
def infer_value(self, x):
if x is not None:
x = x.asnumpy()
out = np.abs(x, dtype=x.dtype)
out = np.array(np.abs(x, dtype=x.dtype))
return Tensor(out)
return None



+ 2
- 2
tests/ut/cpp/optimizer/lib_test.cc View File

@@ -147,8 +147,8 @@ TEST_F(TestOptLib, test_inline_new_closure) {
TEST_F(TestOptLib, test_inline_while) {
FuncGraphPtr before = getPyFun.CallAndParseRet("test_inline_while", "before");
auto patterns = std::vector<SubstitutionPtr>({irpass.inline_});
FuncGraphPtr after_ = RunSubs(before, patterns);
ASSERT_TRUE(CheckOpt(before, before, patterns));
FuncGraphPtr after = RunSubs(before, patterns);
ASSERT_TRUE(CheckOpt(before, after, patterns, true));
}

TEST_F(TestOptLib, test_arithmetic) {


+ 80
- 0
tests/ut/python/ops/test_control_ops.py View File

@@ -520,3 +520,83 @@ def test_while_in_while():
out = out + 3
return out
while_in_while(c1, c2, c3, c4)


def test_tensor_cond():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.t = Tensor(np.array(0, np.bool))
self.t1 = Tensor(np.array([True], np.bool))
def construct(self, x, y):
t = 0
if self.t:
t = t - x * y
else:
t = t - x / y
if self.t1:
t = t + x / y
else:
t = t + x * y
return t
x = Tensor(np.ones([6, 8, 10], np.int32))
y = Tensor(np.ones([6, 8, 10], np.int32))
net = Net()
out = net(x, y)

def test_tensor_cond_exception():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.t = Tensor(np.array([True, False], np.bool))
def construct(self, x, y):
t = 0
if self.t:
t = t - x * y
else:
t = t - x / y
return t
x = Tensor(np.ones([6, 8, 10], np.int32))
y = Tensor(np.ones([6, 8, 10], np.int32))
net = Net()
with pytest.raises(ValueError):
out = net(x, y)

def test_while_scalar():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.x = 10
def construct(self, x, y):
i = 0
t = 0
while (i < 10):
t = t + x + y
i = i + 1
return t
net = Net()
x = Tensor(np.ones([6, 8, 10], np.int32))
y = Tensor(np.ones([6, 8, 10], np.int32))
out = net(x, y)

def test_while_tensor():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.t = Tensor(np.ones([6, 8, 10], np.int32))
self.count = Tensor(np.array([10], np.int32))
def construct(self, x, y):
i = 0
t = self.t
while (i < self.count):
t = t + x + y
i = i + 1
return t
net = Net()
x = Tensor(np.ones([6, 8, 10], np.int32))
y = Tensor(np.ones([6, 8, 10], np.int32))
out = net(x, y)

+ 1
- 1
tests/ut/python/ops/test_math_ops.py View File

@@ -31,7 +31,7 @@ from ....mindspore_test_framework.pipeline.forward.compile_forward \
import pipeline_for_compile_forward_ge_graph_for_case_by_case_config
from ....mindspore_test_framework.pipeline.forward.verify_exception \
import pipeline_for_verify_exception_for_case_by_case_config
context.set_context(mode=context.GRAPH_MODE)

# pylint: disable=W0613
# pylint: disable=W0231


+ 4
- 3
tests/ut/python/pynative_mode/test_framstruct.py View File

@@ -30,6 +30,7 @@ from ....mindspore_test_framework.utils.check_gradient import (
ms_function, check_jacobian, Tensor, NNGradChecker,
OperationGradChecker, check_gradient, ScalarGradChecker)

context.set_context(mode=context.PYNATIVE_MODE)

def setup_module(module):
context.set_context(mode=context.PYNATIVE_MODE)
@@ -257,8 +258,8 @@ def if_tensor(a, b):


def test_if_tensor():
res = if_tensor(Tensor(np.ones([64, 10]).astype(np.int32)), Tensor(np.ones([64, 10]).astype(np.int32)))
assert res == Tensor(np.ones([64, 10]).astype(np.int32) * 4)
res = if_tensor(Tensor(np.ones([1]).astype(np.int32)), Tensor(np.ones([1]).astype(np.int32)))
assert res == Tensor(np.ones([1]).astype(np.int32) * 4)


@ms_function
@@ -399,7 +400,7 @@ def if_while(a, b, x, z):
def test_if_while():
x = Tensor(np.random.randn(1, 16, 12, 12).astype(np.float32))
z = Tensor(np.random.randn(1, 16, 16, 16).astype(np.float32))
res = if_while(Tensor(np.ones([64, 10]).astype(np.float32)), Tensor(np.ones([64, 10]).astype(np.float32)), x, z)
res = if_while(Tensor(np.ones([1]).astype(np.float32)), Tensor(np.ones([1]).astype(np.float32)), x, z)
assert res == Tensor(np.ones([64, 10]).astype(np.float32) * 4.0)




Loading…
Cancel
Save