Browse Source

!8272 dynamic shape judging bug fix

Merge pull request !8272 from liubuyu/bug_fix
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
da96aaff41
2 changed files with 4 additions and 31 deletions
  1. +2
    -6
      mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
  2. +2
    -25
      mindspore/ccsrc/backend/session/session_basic.cc

+ 2
- 6
mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc View File

@@ -1361,10 +1361,6 @@ std::vector<int> AnfRuntimeAlgorithm::GetOutputMinShape(const AnfNodePtr &anf_no
}
}

bool CheckDynamic(const NotNull<abstract::ShapePtr> &shape) {
return !std::all_of(shape->shape().begin(), shape->shape().end(), [](int s) { return s > 0; });
}

bool AnfRuntimeAlgorithm::IsNodeDynamicShape(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto base_shape = node->Shape();
@@ -1373,7 +1369,7 @@ bool AnfRuntimeAlgorithm::IsNodeDynamicShape(const AnfNodePtr &node) {
return false;
}
if (base_shape->isa<abstract::Shape>()) {
if (CheckDynamic(NOT_NULL(base_shape->cast<abstract::ShapePtr>()))) {
if (IsShapeDynamic(base_shape->cast<abstract::ShapePtr>())) {
return true;
}
} else if (base_shape->isa<abstract::TupleShape>()) {
@@ -1384,7 +1380,7 @@ bool AnfRuntimeAlgorithm::IsNodeDynamicShape(const AnfNodePtr &node) {
if (!b_shape->isa<abstract::Shape>()) {
continue;
}
if (CheckDynamic(NOT_NULL(b_shape->cast<abstract::ShapePtr>()))) {
if (IsShapeDynamic(b_shape->cast<abstract::ShapePtr>())) {
return true;
}
}


+ 2
- 25
mindspore/ccsrc/backend/session/session_basic.cc View File

@@ -1427,35 +1427,12 @@ void SessionBasic::RunGraphAsync(const GraphId &graph_id, const std::vector<tens
}

bool IsDynamicShape(const NotNull<abstract::ShapePtr> &shape) {
return !std::all_of(shape->shape().begin(), shape->shape().end(), [](int s) { return s > 0; });
return std::any_of(shape->shape().begin(), shape->shape().end(), [](int s) { return s < 0; });
}

bool IsNodeOutputDynamicShape(const CNodePtr &anf_node_ptr) {
MS_EXCEPTION_IF_NULL(anf_node_ptr);
auto base_shape = anf_node_ptr->Shape();
if (base_shape == nullptr) {
MS_LOG(INFO) << "Invalid bash shape ptr, node:" << anf_node_ptr->fullname_with_scope();
return false;
}
if (base_shape->isa<abstract::Shape>()) {
if (IsDynamicShape(NOT_NULL(base_shape->cast<abstract::ShapePtr>()))) {
return true;
}
} else if (base_shape->isa<abstract::TupleShape>()) {
auto tuple_shape = base_shape->cast<abstract::TupleShapePtr>();
MS_EXCEPTION_IF_NULL(tuple_shape);

for (size_t i = 0; i < tuple_shape->size(); ++i) {
auto b_shp = (*tuple_shape)[i];
if (!b_shp->isa<abstract::Shape>()) {
continue;
}
if (IsDynamicShape(NOT_NULL(b_shp->cast<abstract::ShapePtr>()))) {
return true;
}
}
}
return false;
return AnfAlgo::IsNodeDynamicShape(anf_node_ptr);
}

bool IsNodeInputDynamicShape(const CNodePtr &anf_node_ptr) {


Loading…
Cancel
Save