Browse Source

Add supoort resolving outer lambda function for ops.Partial.

tags/v1.6.0
hezhenhao1 4 years ago
parent
commit
4af312d17e
3 changed files with 261 additions and 51 deletions
  1. +45
    -22
      mindspore/ccsrc/pipeline/jit/parse/parse.cc
  2. +3
    -1
      mindspore/ccsrc/pipeline/jit/parse/parse.h
  3. +213
    -28
      tests/ut/python/pipeline/parse/test_partial.py

+ 45
- 22
mindspore/ccsrc/pipeline/jit/parse/parse.cc View File

@@ -175,7 +175,19 @@ void CheckFuncReturn(const FuncGraphPtr &fn, const std::shared_ptr<ParseFunction
FuncGraphPtr Parser::ParseFuncGraph() {
// Get ast FunctionDef node
py::object node = ast_->GetAstNode();
FunctionBlockPtr fn_block = ParseFunction(node);
constexpr char function_def_name[] = "FunctionDef";
constexpr char lambda_name[] = "Lambda";
FunctionBlockPtr fn_block = nullptr;
if (ast_->GetNodeType(node)->node_name() == function_def_name) {
fn_block = ParseDefFunction(node);
} else {
auto lambda_node = python_adapter::GetPyObjAttr(node, "value");
if (py::isinstance<py::none>(lambda_node) || ast_->GetNodeType(lambda_node)->node_name() != lambda_name) {
MS_EXCEPTION(TypeError) << "Parse Lambda Function Fail. Node type must be Lambda, but got "
<< ast_->GetNodeType(lambda_node)->node_name() << ".";
}
fn_block = ParseLambdaFunction(lambda_node);
}
if (errcode() != PARSE_SUCCESS) {
MS_LOG(ERROR) << "Parse function error, code is " << errcode();
return nullptr;
@@ -259,7 +271,7 @@ ScopePtr Parser::GetScopeForParseFunction() {
return scope;
}

FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlockPtr &block) {
FunctionBlockPtr Parser::ParseDefFunction(const py::object &node, const FunctionBlockPtr &block) {
ScopePtr scope = GetScopeForParseFunction();
// The node created in the parsefunction context, will inherit the scope created using scope_guard
ScopeGuard scope_guard(scope);
@@ -323,6 +335,33 @@ FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlo
return func_block;
}

FunctionBlockPtr Parser::ParseLambdaFunction(const py::object &node, const FunctionBlockPtr &block) {
MS_EXCEPTION_IF_NULL(ast_);
ScopePtr scope = GetScopeForParseFunction();
ScopeGuard scope_guard(scope);
TraceGuard trace_guard(data_converter::GetObjKey(ast_->obj())[0], GetLocation(node));

FunctionBlockPtr func_block = MakeFunctionBlock(*this);
if (block != nullptr) {
func_block->AddPrevBlock(block);
} else {
func_graph_ = func_block->func_graph();
}
func_block->Mature();
auto current_fg = func_block->func_graph();

auto function_name = ast_->function_name();
MS_LOG(DEBUG) << "The function name is " << function_name;
current_fg->debug_info()->set_name(function_name);
GenerateArgsNodeForFunction(func_block, node);

py::object body_node = python_adapter::GetPyObjAttr(node, "body");
AnfNodePtr lambda_body_node = ParseExprNode(func_block, body_node);
current_fg->set_output(lambda_body_node);
GenerateArgsDefaultValueForFunction(func_block, node);
return func_block;
}

FunctionBlockPtr Parser::ParseStatements(FunctionBlockPtr block, const py::object &nodes) {
auto node_list = py::cast<py::list>(nodes);
size_t count = py::len(node_list);
@@ -919,7 +958,7 @@ AnfNodePtr Parser::ParseBoolOp(const FunctionBlockPtr &block, const py::object &
// Process a function def
FunctionBlockPtr Parser::ParseFunctionDef(const FunctionBlockPtr &block, const py::object &node) {
MS_LOG(DEBUG) << "Process ast FunctionDef";
FunctionBlockPtr function_block = ParseFunction(node, block);
FunctionBlockPtr function_block = ParseDefFunction(node, block);
MS_EXCEPTION_IF_NULL(function_block);

// Get function name
@@ -933,26 +972,10 @@ FunctionBlockPtr Parser::ParseFunctionDef(const FunctionBlockPtr &block, const p
// Process a lambda expression . like lambda x,y: x + y
AnfNodePtr Parser::ParseLambda(const FunctionBlockPtr &block, const py::object &node) {
MS_LOG(DEBUG) << "Process ast Lambda";
FunctionBlockPtr func_block = MakeFunctionBlock(*this);
func_block->AddPrevBlock(block);
func_block->Mature();

// Get lambda args
py::list args = ast_->GetArgs(node);
auto block_fg = func_block->func_graph();
for (std::size_t i = 0; i < args.size(); i++) {
std::string arg_name = py::cast<std::string>(args[i].attr("arg"));
TraceGuard guard(GetLocation(args[i]));
auto para_node = std::make_shared<Parameter>(block_fg);
para_node->debug_info()->set_name(arg_name);
block_fg->add_parameter(para_node);
func_block->WriteVariable(arg_name, para_node);
MS_LOG(DEBUG) << "The arg[" << i << "] is " << arg_name;
}
FunctionBlockPtr function_block = ParseLambdaFunction(node, block);
MS_EXCEPTION_IF_NULL(function_block);

py::object body_node = python_adapter::GetPyObjAttr(node, "body");
AnfNodePtr lambda_body_node = ParseExprNode(func_block, body_node);
block_fg->set_output(lambda_body_node);
auto block_fg = function_block->func_graph();
ValueNodePtr const_graph = NewValueNode(block_fg);
return const_graph;
}


+ 3
- 1
mindspore/ccsrc/pipeline/jit/parse/parse.h View File

@@ -196,7 +196,9 @@ class Parser {
// Generate argument default value for ast function node
void GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block, const py::object &function_node);
// Parse ast function node
FunctionBlockPtr ParseFunction(const py::object &function_node, const FunctionBlockPtr &block = nullptr);
FunctionBlockPtr ParseDefFunction(const py::object &function_node, const FunctionBlockPtr &block = nullptr);
// Parse lambda function node
FunctionBlockPtr ParseLambdaFunction(const py::object &function_node, const FunctionBlockPtr &block = nullptr);
// Parse ast statements
FunctionBlockPtr ParseStatements(FunctionBlockPtr block, const py::object &stmt_node);
// Parse one ast statement node


+ 213
- 28
tests/ut/python/pipeline/parse/test_partial.py View File

@@ -22,7 +22,14 @@ from mindspore import nn, Tensor, context

context.set_context(mode=context.GRAPH_MODE)


def test_partial_pos_arg():
"""
Feature: ALL TO ALL
Description: test cases for partial_pos_arg
Expectation: the result match given one
"""

class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
@@ -35,13 +42,31 @@ def test_partial_pos_arg():
ret = f(y, z)
return ret

class Net2(nn.Cell):
def __init__(self):
super(Net2, self).__init__()
self.show = lambda x, y, z: (x, y, z)

def construct(self, x, y, z):
f = partial(self.show, x)
ret = f(y, z)
return ret

x = Tensor(np.arange(3).reshape((3,)).astype(np.float32))
y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32))
z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32))
net = Net()
net(x, y, z)

for net in [Net(), Net2()]:
net(x, y, z)


def test_partial_key_ward_arg():
"""
Feature: ALL TO ALL
Description: test cases for partial_key_ward_arg
Expectation: the result match given one
"""

class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
@@ -54,13 +79,31 @@ def test_partial_key_ward_arg():
ret = f(y=y, z=z)
return ret

class Net2(nn.Cell):
def __init__(self):
super(Net2, self).__init__()
self.show = lambda x, y, z: (x, y, z)

def construct(self, x, y, z):
f = partial(self.show, x=x)
ret = f(y=y, z=z)
return ret

x = Tensor(np.arange(3).reshape((3,)).astype(np.float32))
y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32))
z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32))
net = Net()
net(x, y, z)

for net in [Net(), Net2()]:
net(x, y, z)


def test_partial_key_ward_arg_update():
"""
Feature: ALL TO ALL
Description: test cases for partial_key_ward_arg_update
Expectation: the result match given one
"""

class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
@@ -73,14 +116,31 @@ def test_partial_key_ward_arg_update():
ret = f(y=y, z=z)
return ret

class Net2(nn.Cell):
def __init__(self):
super(Net2, self).__init__()
self.show = lambda x, y, z: (x, y, z)

def construct(self, x, y, z):
f = partial(self.show, x=x, y=y)
ret = f(y=y, z=z)
return ret

x = Tensor(np.arange(3).reshape((3,)).astype(np.float32))
y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32))
z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32))
net = Net()
net(x, y, z)

for net in [Net(), Net2()]:
net(x, y, z)


def test_partial_key_ward_arg_and_pos_arg():
"""
Feature: ALL TO ALL
Description: test cases for partial_key_ward_arg_and_pos_arg
Expectation: the result match given one
"""

class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
@@ -93,14 +153,31 @@ def test_partial_key_ward_arg_and_pos_arg():
ret = f(2, z=z)
return ret

class Net2(nn.Cell):
def __init__(self):
super(Net2, self).__init__()
self.show = lambda x, y, z: (x, y, z)

def construct(self, x, y, z):
f = partial(self.show, y=y)
ret = f(2, z=z)
return ret

x = Tensor(np.arange(3).reshape((3,)).astype(np.float32))
y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32))
z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32))
net = Net()
net(x, y, z)

for net in [Net(), Net2()]:
net(x, y, z)


def test_partial_pos_arg_const():
"""
Feature: ALL TO ALL
Description: test cases for partial_pos_arg_const
Expectation: the result match given one
"""

class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
@@ -113,10 +190,27 @@ def test_partial_pos_arg_const():
ret = f(2, 3)
return ret

net = Net()
assert net() == (1, 2, 3)
class Net2(nn.Cell):
def __init__(self):
super(Net2, self).__init__()
self.show = lambda x, y, z: (x, y, z)

def construct(self):
f = partial(self.show, 1)
ret = f(2, 3)
return ret

for net in [Net(), Net2()]:
assert net() == (1, 2, 3)


def test_partial_key_ward_arg_const():
"""
Feature: ALL TO ALL
Description: test cases for partial_key_ward_arg_const
Expectation: the result match given one
"""

class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
@@ -129,10 +223,27 @@ def test_partial_key_ward_arg_const():
ret = f(y=2, z=3)
return ret

net = Net()
assert net() == (1, 2, 3)
class Net2(nn.Cell):
def __init__(self):
super(Net2, self).__init__()
self.show = lambda x, y, z: (x, y, z)

def construct(self):
f = partial(self.show, x=1)
ret = f(y=2, z=3)
return ret

for net in [Net(), Net2()]:
assert net() == (1, 2, 3)


def test_partial_key_ward_arg_update_const():
"""
Feature: ALL TO ALL
Description: test cases for partial_key_ward_arg_update_const
Expectation: the result match given one
"""

class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
@@ -145,11 +256,27 @@ def test_partial_key_ward_arg_update_const():
ret = f(y=3, z=4)
return ret

net = Net()
assert net() == (1, 3, 4)
class Net2(nn.Cell):
def __init__(self):
super(Net2, self).__init__()
self.show = lambda x, y, z: (x, y, z)

def construct(self):
f = partial(self.show, x=1, y=2)
ret = f(y=3, z=4)
return ret

for net in [Net(), Net2()]:
assert net() == (1, 3, 4)


def test_partial_key_ward_arg_and_pos_arg_const():
"""
Feature: ALL TO ALL
Description: test cases for partial_key_ward_arg_and_pos_arg_const
Expectation: the result match given one
"""

class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
@@ -162,11 +289,27 @@ def test_partial_key_ward_arg_and_pos_arg_const():
ret = f(1, z=3)
return ret

net = Net()
assert net() == (1, 2, 3)
class Net2(nn.Cell):
def __init__(self):
super(Net2, self).__init__()
self.show = lambda x, y, z: (x, y, z)

def construct(self):
f = partial(self.show, y=2)
ret = f(1, z=3)
return ret

for net in [Net(), Net2()]:
assert net() == (1, 2, 3)


def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_x():
"""
Feature: ALL TO ALL
Description: test cases for partial_key_ward_arg_and_pos_arg_const_multi_assign_x
Expectation: the result match given one
"""

class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
@@ -179,13 +322,29 @@ def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_x():
ret = f(1, 2, 3)
return ret

net = Net()
with pytest.raises(TypeError) as ex:
net()
assert "Multiply values for specific argument: x" in str(ex.value)
class Net2(nn.Cell):
def __init__(self):
super(Net2, self).__init__()
self.show = lambda x, y, z: (x, y, z)

def construct(self):
f = partial(self.show, x=1)
ret = f(1, 2, 3)
return ret

for net in [Net(), Net2()]:
with pytest.raises(TypeError) as ex:
net()
assert "Multiply values for specific argument: x" in str(ex.value)


def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_y():
"""
Feature: ALL TO ALL
Description: test cases for partial_key_ward_arg_and_pos_arg_const_multi_assign_y
Expectation: the result match given one
"""

class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
@@ -198,13 +357,29 @@ def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_y():
ret = f(1, 2, z=3)
return ret

net = Net()
with pytest.raises(TypeError) as ex:
net()
assert "Multiply values for specific argument: y" in str(ex.value)
class Net2(nn.Cell):
def __init__(self):
super(Net2, self).__init__()
self.show = lambda x, y, z: (x, y, z)

def construct(self):
f = partial(self.show, y=2)
ret = f(1, 2, z=3)
return ret

for net in [Net(), Net2()]:
with pytest.raises(TypeError) as ex:
net()
assert "Multiply values for specific argument: y" in str(ex.value)


def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_z():
"""
Feature: ALL TO ALL
Description: test cases for partial_key_ward_arg_and_pos_arg_const_multi_assign_z
Expectation: the result match given one
"""

class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
@@ -217,7 +392,17 @@ def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_z():
ret = f(1, 2, 3)
return ret

net = Net()
with pytest.raises(TypeError) as ex:
net()
assert "Multiply values for specific argument: z" in str(ex.value)
class Net2(nn.Cell):
def __init__(self):
super(Net2, self).__init__()
self.show = lambda x, y, z: (x, y, z)

def construct(self):
f = partial(self.show, z=1)
ret = f(1, 2, 3)
return ret

for net in [Net(), Net2()]:
with pytest.raises(TypeError) as ex:
net()
assert "Multiply values for specific argument: z" in str(ex.value)

Loading…
Cancel
Save