|
|
@@ -1352,6 +1352,16 @@ std::string PynativeExecutor::GetCellId(const py::object &cell, const py::args & |
|
|
return cell_id; |
|
|
return cell_id; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
std::string PynativeExecutor::GetCellInfo(const py::object &cell) { |
|
|
|
|
|
if (py::isinstance<Cell>(cell)) { |
|
|
|
|
|
auto c_cell = py::cast<CellPtr>(cell); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(c_cell); |
|
|
|
|
|
auto cell_info = c_cell->ToString(); |
|
|
|
|
|
return cell_info; |
|
|
|
|
|
} |
|
|
|
|
|
return ""; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
std::string PynativeExecutor::ParseNodeName(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node, |
|
|
std::string PynativeExecutor::ParseNodeName(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node, |
|
|
parse::AstMainType type) { |
|
|
parse::AstMainType type) { |
|
|
MS_EXCEPTION_IF_NULL(ast); |
|
|
MS_EXCEPTION_IF_NULL(ast); |
|
|
@@ -1372,6 +1382,16 @@ std::string PynativeExecutor::ParseNodeName(const std::shared_ptr<parse::ParseAs |
|
|
return node_name; |
|
|
return node_name; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
void PynativeExecutor::ParseInputArgs(const std::shared_ptr<parse::ParseAst> &ast, const py::object &fn_node) { |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(ast); |
|
|
|
|
|
py::list args = ast->GetArgs(fn_node); |
|
|
|
|
|
for (size_t i = 1; i < args.size(); i++) { |
|
|
|
|
|
std::string arg_name = py::cast<std::string>(args[i].attr("arg")); |
|
|
|
|
|
MS_LOG(DEBUG) << "Input arg name: " << arg_name; |
|
|
|
|
|
cell_input_args_.emplace(arg_name); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
bool PynativeExecutor::ParseIfWhileExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node) { |
|
|
bool PynativeExecutor::ParseIfWhileExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node) { |
|
|
MS_LOG(DEBUG) << "Parse if/while expr"; |
|
|
MS_LOG(DEBUG) << "Parse if/while expr"; |
|
|
py::object test_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_TEST); |
|
|
py::object test_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_TEST); |
|
|
@@ -1383,14 +1403,25 @@ bool PynativeExecutor::ParseIfWhileExprNode(const std::shared_ptr<parse::ParseAs |
|
|
MS_LOG(DEBUG) << "Get comparators node falied!"; |
|
|
MS_LOG(DEBUG) << "Get comparators node falied!"; |
|
|
return false; |
|
|
return false; |
|
|
} |
|
|
} |
|
|
const auto &left = ParseNodeName(ast, left_node, parse::AST_MAIN_TYPE_EXPR); |
|
|
|
|
|
const auto &right = ParseNodeName(ast, comparators_node[0], parse::AST_MAIN_TYPE_EXPR); |
|
|
|
|
|
MS_LOG(DEBUG) << "left is " << left << " right is " << right; |
|
|
|
|
|
|
|
|
auto left = ParseNodeName(ast, left_node, parse::AST_MAIN_TYPE_EXPR); |
|
|
|
|
|
auto right = ParseNodeName(ast, comparators_node[0], parse::AST_MAIN_TYPE_EXPR); |
|
|
|
|
|
if (left == parse::NAMED_PRIMITIVE_SUBSCRIPT) { |
|
|
|
|
|
py::object value_in_subscript = parse::python_adapter::GetPyObjAttr(left_node, parse::NAMED_PRIMITIVE_VALUE); |
|
|
|
|
|
left = ParseNodeName(ast, value_in_subscript, parse::AST_MAIN_TYPE_EXPR); |
|
|
|
|
|
} |
|
|
|
|
|
MS_LOG(DEBUG) << "Left is " << left << " Right is " << right; |
|
|
if (unchanged_named_primitive.find(left) == unchanged_named_primitive.end() || |
|
|
if (unchanged_named_primitive.find(left) == unchanged_named_primitive.end() || |
|
|
unchanged_named_primitive.find(right) == unchanged_named_primitive.end()) { |
|
|
unchanged_named_primitive.find(right) == unchanged_named_primitive.end()) { |
|
|
return true; |
|
|
return true; |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
// if flag: |
|
|
|
|
|
if (node_name == parse::NAMED_PRIMITIVE_NAME) { |
|
|
|
|
|
std::string id = py::cast<std::string>(test_node.attr("id")); |
|
|
|
|
|
if (cell_input_args_.find(id) != cell_input_args_.end()) { |
|
|
|
|
|
return true; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
return false; |
|
|
return false; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
@@ -1405,7 +1436,9 @@ bool PynativeExecutor::ParseAssignExprNode(const std::shared_ptr<parse::ParseAst |
|
|
py::object slice_node = parse::python_adapter::GetPyObjAttr(func_node, parse::NAMED_PRIMITIVE_SLICE); |
|
|
py::object slice_node = parse::python_adapter::GetPyObjAttr(func_node, parse::NAMED_PRIMITIVE_SLICE); |
|
|
py::object value_in_slice_node = parse::python_adapter::GetPyObjAttr(slice_node, parse::NAMED_PRIMITIVE_VALUE); |
|
|
py::object value_in_slice_node = parse::python_adapter::GetPyObjAttr(slice_node, parse::NAMED_PRIMITIVE_VALUE); |
|
|
const auto &node_name_in_slice_node = ParseNodeName(ast, value_in_slice_node, parse::AST_MAIN_TYPE_EXPR); |
|
|
const auto &node_name_in_slice_node = ParseNodeName(ast, value_in_slice_node, parse::AST_MAIN_TYPE_EXPR); |
|
|
return unchanged_named_primitive.find(node_name_in_slice_node) == unchanged_named_primitive.end(); |
|
|
|
|
|
|
|
|
if (cell_input_args_.find(node_name_in_slice_node) != cell_input_args_.end()) { |
|
|
|
|
|
return true; |
|
|
|
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
return false; |
|
|
return false; |
|
|
@@ -1414,9 +1447,13 @@ bool PynativeExecutor::ParseAssignExprNode(const std::shared_ptr<parse::ParseAst |
|
|
bool PynativeExecutor::ParseForExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node) { |
|
|
bool PynativeExecutor::ParseForExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node) { |
|
|
MS_LOG(DEBUG) << "Parse for expr"; |
|
|
MS_LOG(DEBUG) << "Parse for expr"; |
|
|
py::object body_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_BODY); |
|
|
py::object body_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_BODY); |
|
|
|
|
|
if (py::isinstance<py::none>(body_node)) { |
|
|
|
|
|
MS_LOG(DEBUG) << "Parse body of for expression is none!"; |
|
|
|
|
|
return false; |
|
|
|
|
|
} |
|
|
py::int_ pcount = parse::python_adapter::CallPyObjMethod(body_node, parse::PYTHON_GET_METHOD_LEN); |
|
|
py::int_ pcount = parse::python_adapter::CallPyObjMethod(body_node, parse::PYTHON_GET_METHOD_LEN); |
|
|
size_t count = LongToSize(pcount); |
|
|
size_t count = LongToSize(pcount); |
|
|
MS_LOG(DEBUG) << "The for nodes count is " << count; |
|
|
|
|
|
|
|
|
MS_LOG(DEBUG) << "The for nodes count in body is " << count; |
|
|
for (size_t i = 0; i < count; ++i) { |
|
|
for (size_t i = 0; i < count; ++i) { |
|
|
auto it = py::cast<py::list>(body_node)[i]; |
|
|
auto it = py::cast<py::list>(body_node)[i]; |
|
|
const auto &node_name = ParseNodeName(ast, it, parse::AST_MAIN_TYPE_STMT); |
|
|
const auto &node_name = ParseNodeName(ast, it, parse::AST_MAIN_TYPE_STMT); |
|
|
@@ -1427,28 +1464,16 @@ bool PynativeExecutor::ParseForExprNode(const std::shared_ptr<parse::ParseAst> & |
|
|
return false; |
|
|
return false; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
bool PynativeExecutor::IsDynamicCell(const py::object &cell) { |
|
|
|
|
|
std::string cell_info; |
|
|
|
|
|
if (py::isinstance<Cell>(cell)) { |
|
|
|
|
|
auto c_cell = py::cast<CellPtr>(cell); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(c_cell); |
|
|
|
|
|
cell_info = c_cell->ToString(); |
|
|
|
|
|
} |
|
|
|
|
|
if (cell_info.find("nn.layer.basic.Dense") != string::npos) { |
|
|
|
|
|
return false; |
|
|
|
|
|
} |
|
|
|
|
|
// using ast parse to check whether the construct of cell will be changed |
|
|
|
|
|
auto ast = std::make_shared<parse::ParseAst>(cell); |
|
|
|
|
|
bool success = ast->InitParseAstInfo(parse::PYTHON_MOD_GET_PARSE_METHOD); |
|
|
|
|
|
if (!success) { |
|
|
|
|
|
MS_LOG(ERROR) << "Parse code to ast tree failed"; |
|
|
|
|
|
|
|
|
bool PynativeExecutor::ParseBodyContext(const std::shared_ptr<parse::ParseAst> &ast, const py::object &fn_node) { |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(ast); |
|
|
|
|
|
py::object func_obj = parse::python_adapter::GetPyObjAttr(fn_node, parse::NAMED_PRIMITIVE_BODY); |
|
|
|
|
|
if (py::isinstance<py::none>(func_obj)) { |
|
|
|
|
|
MS_LOG(DEBUG) << "Parse body of cell is none!"; |
|
|
return false; |
|
|
return false; |
|
|
} |
|
|
} |
|
|
py::object nodes = ast->GetAstNode(); |
|
|
|
|
|
py::object func_obj = parse::python_adapter::GetPyObjAttr(nodes, parse::NAMED_PRIMITIVE_BODY); |
|
|
|
|
|
py::int_ pcount = parse::python_adapter::CallPyObjMethod(func_obj, parse::PYTHON_GET_METHOD_LEN); |
|
|
py::int_ pcount = parse::python_adapter::CallPyObjMethod(func_obj, parse::PYTHON_GET_METHOD_LEN); |
|
|
size_t count = IntToSize(pcount); |
|
|
size_t count = IntToSize(pcount); |
|
|
MS_LOG(DEBUG) << "The nodes count is " << count; |
|
|
|
|
|
|
|
|
MS_LOG(DEBUG) << "The nodes count in body is " << count; |
|
|
bool ret = false; |
|
|
bool ret = false; |
|
|
for (size_t i = 0; i < count; ++i) { |
|
|
for (size_t i = 0; i < count; ++i) { |
|
|
auto node = py::cast<py::list>(func_obj)[i]; |
|
|
auto node = py::cast<py::list>(func_obj)[i]; |
|
|
@@ -1461,13 +1486,35 @@ bool PynativeExecutor::IsDynamicCell(const py::object &cell) { |
|
|
ret = ParseIfWhileExprNode(ast, node); |
|
|
ret = ParseIfWhileExprNode(ast, node); |
|
|
} |
|
|
} |
|
|
if (ret) { |
|
|
if (ret) { |
|
|
MS_LOG(INFO) << "Cur cell is dynamic"; |
|
|
|
|
|
|
|
|
MS_LOG(INFO) << "Current cell is dynamic!"; |
|
|
break; |
|
|
break; |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
return ret; |
|
|
return ret; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
bool PynativeExecutor::IsDynamicCell(const py::object &cell) { |
|
|
|
|
|
std::string cell_info = GetCellInfo(cell); |
|
|
|
|
|
if (ignore_judge_dynamic_cell.find(cell_info) != ignore_judge_dynamic_cell.end()) { |
|
|
|
|
|
return false; |
|
|
|
|
|
} |
|
|
|
|
|
// using ast parse to check whether the construct of cell will be changed |
|
|
|
|
|
auto ast = std::make_shared<parse::ParseAst>(cell); |
|
|
|
|
|
bool success = ast->InitParseAstInfo(parse::PYTHON_MOD_GET_PARSE_METHOD); |
|
|
|
|
|
if (!success) { |
|
|
|
|
|
MS_LOG(ERROR) << "Parse code to ast tree failed"; |
|
|
|
|
|
return false; |
|
|
|
|
|
} |
|
|
|
|
|
py::object fn_node = ast->GetAstNode(); |
|
|
|
|
|
// get the name of input args as the initialize of dynamic_variables |
|
|
|
|
|
ParseInputArgs(ast, fn_node); |
|
|
|
|
|
// parse body context |
|
|
|
|
|
bool ret = false; |
|
|
|
|
|
ret = ParseBodyContext(ast, fn_node); |
|
|
|
|
|
cell_input_args_.clear(); |
|
|
|
|
|
return ret; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) { |
|
|
void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) { |
|
|
auto cell_id = GetCellId(cell, args); |
|
|
auto cell_id = GetCellId(cell, args); |
|
|
MS_LOG(DEBUG) << "NewGraphInner start, args size: " << args.size() << ", cell id: " << cell_id; |
|
|
MS_LOG(DEBUG) << "NewGraphInner start, args size: " << args.size() << ", cell id: " << cell_id; |
|
|
|