diff --git a/mindspore/ccsrc/frontend/optimizer/recompute.cc b/mindspore/ccsrc/frontend/optimizer/recompute.cc index 65f93ad7a7..b374297cee 100644 --- a/mindspore/ccsrc/frontend/optimizer/recompute.cc +++ b/mindspore/ccsrc/frontend/optimizer/recompute.cc @@ -293,6 +293,17 @@ void SetRecomputedAttr(const FuncGraphPtr &graph, const std::vector &o } } +CNodePtr CreateNewRecomputedNode(const FuncGraphPtr &graph, const CNodePtr &origin_node, + const std::vector &new_inputs) { + auto recomputed_node = graph->NewCNode(new_inputs); + MS_EXCEPTION_IF_NULL(recomputed_node); + recomputed_node->AddAttr("duplicated", MakeValue(true)); + recomputed_node->AddAttr(kAttrNeedCseAfterRecompute, MakeValue(true)); + recomputed_node->set_abstract(origin_node->abstract()); + recomputed_node->set_scope(origin_node->scope()); + return recomputed_node; +} + CNodePtr NewRecomputedNode(const FuncGraphPtr &graph, const CNodePtr &origin_node, const std::vector &first_target_inputs, const std::unordered_set &recomputed_origin_nodes, @@ -336,12 +347,7 @@ CNodePtr NewRecomputedNode(const FuncGraphPtr &graph, const CNodePtr &origin_nod depend_node->set_abstract(first_input->abstract()); new_inputs[1] = depend_node; } - auto recomputed_node = graph->NewCNode(new_inputs); - MS_EXCEPTION_IF_NULL(recomputed_node); - recomputed_node->AddAttr("duplicated", MakeValue(true)); - recomputed_node->AddAttr(kAttrNeedCseAfterRecompute, MakeValue(true)); - recomputed_node->set_abstract(origin_node->abstract()); - recomputed_node->set_scope(origin_node->scope()); + auto recomputed_node = CreateNewRecomputedNode(graph, origin_node, new_inputs); origin_to_recomputed_nodes->insert(std::make_pair(origin_node, recomputed_node)); return recomputed_node; } diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc index 0efd4341e5..3debf5309c 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc @@ -278,7 +278,6 @@ py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base) { py::tuple max_value_tuple(len); py::tuple min_shape_tuple(len); py::tuple max_shape_tuple(len); - auto dic = py::dict(); bool dyn_shape = false; bool dyn_value = false; @@ -303,7 +302,7 @@ py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base) { dyn_shape = true; } } - + auto dic = py::dict(); dic[ATTR_SHAPE] = shape_tuple; dic[ATTR_DTYPE] = dtype_tuple; if (arg_tuple->BuildValue()->isa()) { @@ -332,7 +331,6 @@ py::dict AbstractListToPython(const AbstractBasePtr &abs_base) { py::list value_list(len); py::list min_shape_list(len); py::list max_shape_list(len); - auto dic = py::dict(); bool dyn_shape = false; for (size_t i = 0; i < len; i++) { @@ -348,7 +346,7 @@ py::dict AbstractListToPython(const AbstractBasePtr &abs_base) { dyn_shape = true; } } - + auto dic = py::dict(); dic[ATTR_SHAPE] = shape_list; dic[ATTR_DTYPE] = dtype_list; if (arg_list->BuildValue()->isa()) { @@ -364,32 +362,52 @@ py::dict AbstractListToPython(const AbstractBasePtr &abs_base) { return dic; } + +void ConvertAbstractTensorToPython(const AbstractBasePtr &abs_base, py::dict *dic) { + auto arg_tensor = dyn_cast(abs_base); + (*dic)[ATTR_SHAPE] = arg_tensor->shape()->shape(); + if (MsContext::GetInstance()->get_param(MS_CTX_EXECUTION_MODE) == kGraphMode) { + const auto &min_shape = arg_tensor->shape()->min_shape(); + const auto &max_shape = arg_tensor->shape()->max_shape(); + if (!min_shape.empty() && !max_shape.empty()) { + (*dic)[ATTR_MIN_SHAPE] = min_shape; + (*dic)[ATTR_MAX_SHAPE] = max_shape; + } + } + + auto min_value = arg_tensor->get_min_value(); + auto max_value = arg_tensor->get_max_value(); + if (min_value != nullptr && max_value != nullptr) { + (*dic)[ATTR_MIN_VALUE] = BuildValue(min_value); + (*dic)[ATTR_MAX_VALUE] = BuildValue(max_value); + } + + (*dic)[ATTR_DTYPE] = arg_tensor->BuildType(); + (*dic)[ATTR_VALUE] = BuildValue(arg_tensor->BuildValue()); +} + +void ConvertAbstractFunctionToPython(const AbstractBasePtr &abs_base, py::dict *dic) { + (*dic)[ATTR_SHAPE] = py::none(); + (*dic)[ATTR_DTYPE] = abs_base->BuildType(); + (*dic)[ATTR_VALUE] = py::none(); + if (abs_base->isa()) { + AbstractBasePtrList args = abs_base->cast()->args(); + if (!args.empty()) { + auto value = args[0]->BuildValue()->cast(); + if (value != nullptr) { + (*dic)[ATTR_DTYPE] = std::make_shared(); + (*dic)[ATTR_VALUE] = value->obj(); + } + } + } +} } // end anonymous namespace py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) { MS_EXCEPTION_IF_NULL(abs_base); auto dic = py::dict(); if (abs_base->isa()) { - auto arg_tensor = dyn_cast(abs_base); - dic[ATTR_SHAPE] = arg_tensor->shape()->shape(); - if (MsContext::GetInstance()->get_param(MS_CTX_EXECUTION_MODE) == kGraphMode) { - const auto &min_shape = arg_tensor->shape()->min_shape(); - const auto &max_shape = arg_tensor->shape()->max_shape(); - if (!min_shape.empty() && !max_shape.empty()) { - dic[ATTR_MIN_SHAPE] = min_shape; - dic[ATTR_MAX_SHAPE] = max_shape; - } - } - - auto min_value = arg_tensor->get_min_value(); - auto max_value = arg_tensor->get_max_value(); - if (min_value != nullptr && max_value != nullptr) { - dic[ATTR_MIN_VALUE] = BuildValue(min_value); - dic[ATTR_MAX_VALUE] = BuildValue(max_value); - } - - dic[ATTR_DTYPE] = arg_tensor->BuildType(); - dic[ATTR_VALUE] = BuildValue(arg_tensor->BuildValue()); + ConvertAbstractTensorToPython(abs_base, &dic); } else if (abs_base->isa()) { auto arg = dyn_cast(abs_base); dic[ATTR_SHAPE] = arg->shape()->shape(); @@ -424,19 +442,7 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) { dic[ATTR_DTYPE] = py::none(); dic[ATTR_VALUE] = py::none(); } else if (abs_base->isa()) { - dic[ATTR_SHAPE] = py::none(); - dic[ATTR_DTYPE] = abs_base->BuildType(); - dic[ATTR_VALUE] = py::none(); - if (abs_base->isa()) { - AbstractBasePtrList args = abs_base->cast()->args(); - if (!args.empty()) { - auto value = args[0]->BuildValue()->cast(); - if (value != nullptr) { - dic[ATTR_DTYPE] = std::make_shared(); - dic[ATTR_VALUE] = value->obj(); - } - } - } + ConvertAbstractFunctionToPython(abs_base, &dic); } else if (abs_base->isa()) { auto arg = dyn_cast(abs_base); dic[ATTR_SHAPE] = py::none();