Browse Source

rectify the overlength of function line

pull/14858/head
yujianfeng 4 years ago
parent
commit
f580031bdd
2 changed files with 55 additions and 43 deletions
  1. +12
    -6
      mindspore/ccsrc/frontend/optimizer/recompute.cc
  2. +43
    -37
      mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc

+ 12
- 6
mindspore/ccsrc/frontend/optimizer/recompute.cc View File

@@ -293,6 +293,17 @@ void SetRecomputedAttr(const FuncGraphPtr &graph, const std::vector<CNodePtr> &o
}
}

CNodePtr CreateNewRecomputedNode(const FuncGraphPtr &graph, const CNodePtr &origin_node,
const std::vector<AnfNodePtr> &new_inputs) {
auto recomputed_node = graph->NewCNode(new_inputs);
MS_EXCEPTION_IF_NULL(recomputed_node);
recomputed_node->AddAttr("duplicated", MakeValue(true));
recomputed_node->AddAttr(kAttrNeedCseAfterRecompute, MakeValue(true));
recomputed_node->set_abstract(origin_node->abstract());
recomputed_node->set_scope(origin_node->scope());
return recomputed_node;
}

CNodePtr NewRecomputedNode(const FuncGraphPtr &graph, const CNodePtr &origin_node,
const std::vector<AnfNodePtr> &first_target_inputs,
const std::unordered_set<CNodePtr> &recomputed_origin_nodes,
@@ -336,12 +347,7 @@ CNodePtr NewRecomputedNode(const FuncGraphPtr &graph, const CNodePtr &origin_nod
depend_node->set_abstract(first_input->abstract());
new_inputs[1] = depend_node;
}
auto recomputed_node = graph->NewCNode(new_inputs);
MS_EXCEPTION_IF_NULL(recomputed_node);
recomputed_node->AddAttr("duplicated", MakeValue(true));
recomputed_node->AddAttr(kAttrNeedCseAfterRecompute, MakeValue(true));
recomputed_node->set_abstract(origin_node->abstract());
recomputed_node->set_scope(origin_node->scope());
auto recomputed_node = CreateNewRecomputedNode(graph, origin_node, new_inputs);
origin_to_recomputed_nodes->insert(std::make_pair(origin_node, recomputed_node));
return recomputed_node;
}


+ 43
- 37
mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc View File

@@ -278,7 +278,6 @@ py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base) {
py::tuple max_value_tuple(len);
py::tuple min_shape_tuple(len);
py::tuple max_shape_tuple(len);
auto dic = py::dict();
bool dyn_shape = false;
bool dyn_value = false;

@@ -303,7 +302,7 @@ py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base) {
dyn_shape = true;
}
}
auto dic = py::dict();
dic[ATTR_SHAPE] = shape_tuple;
dic[ATTR_DTYPE] = dtype_tuple;
if (arg_tuple->BuildValue()->isa<AnyValue>()) {
@@ -332,7 +331,6 @@ py::dict AbstractListToPython(const AbstractBasePtr &abs_base) {
py::list value_list(len);
py::list min_shape_list(len);
py::list max_shape_list(len);
auto dic = py::dict();
bool dyn_shape = false;

for (size_t i = 0; i < len; i++) {
@@ -348,7 +346,7 @@ py::dict AbstractListToPython(const AbstractBasePtr &abs_base) {
dyn_shape = true;
}
}
auto dic = py::dict();
dic[ATTR_SHAPE] = shape_list;
dic[ATTR_DTYPE] = dtype_list;
if (arg_list->BuildValue()->isa<AnyValue>()) {
@@ -364,32 +362,52 @@ py::dict AbstractListToPython(const AbstractBasePtr &abs_base) {

return dic;
}

void ConvertAbstractTensorToPython(const AbstractBasePtr &abs_base, py::dict *dic) {
auto arg_tensor = dyn_cast<AbstractTensor>(abs_base);
(*dic)[ATTR_SHAPE] = arg_tensor->shape()->shape();
if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
const auto &min_shape = arg_tensor->shape()->min_shape();
const auto &max_shape = arg_tensor->shape()->max_shape();
if (!min_shape.empty() && !max_shape.empty()) {
(*dic)[ATTR_MIN_SHAPE] = min_shape;
(*dic)[ATTR_MAX_SHAPE] = max_shape;
}
}

auto min_value = arg_tensor->get_min_value();
auto max_value = arg_tensor->get_max_value();
if (min_value != nullptr && max_value != nullptr) {
(*dic)[ATTR_MIN_VALUE] = BuildValue(min_value);
(*dic)[ATTR_MAX_VALUE] = BuildValue(max_value);
}

(*dic)[ATTR_DTYPE] = arg_tensor->BuildType();
(*dic)[ATTR_VALUE] = BuildValue(arg_tensor->BuildValue());
}

void ConvertAbstractFunctionToPython(const AbstractBasePtr &abs_base, py::dict *dic) {
(*dic)[ATTR_SHAPE] = py::none();
(*dic)[ATTR_DTYPE] = abs_base->BuildType();
(*dic)[ATTR_VALUE] = py::none();
if (abs_base->isa<PartialAbstractClosure>()) {
AbstractBasePtrList args = abs_base->cast<PartialAbstractClosurePtr>()->args();
if (!args.empty()) {
auto value = args[0]->BuildValue()->cast<parse::ClassTypePtr>();
if (value != nullptr) {
(*dic)[ATTR_DTYPE] = std::make_shared<TypeType>();
(*dic)[ATTR_VALUE] = value->obj();
}
}
}
}
} // end anonymous namespace

py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
MS_EXCEPTION_IF_NULL(abs_base);
auto dic = py::dict();
if (abs_base->isa<AbstractTensor>()) {
auto arg_tensor = dyn_cast<AbstractTensor>(abs_base);
dic[ATTR_SHAPE] = arg_tensor->shape()->shape();
if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
const auto &min_shape = arg_tensor->shape()->min_shape();
const auto &max_shape = arg_tensor->shape()->max_shape();
if (!min_shape.empty() && !max_shape.empty()) {
dic[ATTR_MIN_SHAPE] = min_shape;
dic[ATTR_MAX_SHAPE] = max_shape;
}
}

auto min_value = arg_tensor->get_min_value();
auto max_value = arg_tensor->get_max_value();
if (min_value != nullptr && max_value != nullptr) {
dic[ATTR_MIN_VALUE] = BuildValue(min_value);
dic[ATTR_MAX_VALUE] = BuildValue(max_value);
}

dic[ATTR_DTYPE] = arg_tensor->BuildType();
dic[ATTR_VALUE] = BuildValue(arg_tensor->BuildValue());
ConvertAbstractTensorToPython(abs_base, &dic);
} else if (abs_base->isa<AbstractRowTensor>()) {
auto arg = dyn_cast<AbstractRowTensor>(abs_base);
dic[ATTR_SHAPE] = arg->shape()->shape();
@@ -424,19 +442,7 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
dic[ATTR_DTYPE] = py::none();
dic[ATTR_VALUE] = py::none();
} else if (abs_base->isa<AbstractFunction>()) {
dic[ATTR_SHAPE] = py::none();
dic[ATTR_DTYPE] = abs_base->BuildType();
dic[ATTR_VALUE] = py::none();
if (abs_base->isa<PartialAbstractClosure>()) {
AbstractBasePtrList args = abs_base->cast<PartialAbstractClosurePtr>()->args();
if (!args.empty()) {
auto value = args[0]->BuildValue()->cast<parse::ClassTypePtr>();
if (value != nullptr) {
dic[ATTR_DTYPE] = std::make_shared<TypeType>();
dic[ATTR_VALUE] = value->obj();
}
}
}
ConvertAbstractFunctionToPython(abs_base, &dic);
} else if (abs_base->isa<AbstractUndetermined>()) {
auto arg = dyn_cast<AbstractUndetermined>(abs_base);
dic[ATTR_SHAPE] = py::none();


Loading…
Cancel
Save