Browse Source

fix reduce_eliminate bug

tags/v1.0.0
lichenever 5 years ago
parent
commit
954cf0aff6
1 changed files with 9 additions and 0 deletions
  1. +9
    -0
      mindspore/ccsrc/frontend/optimizer/irpass/reduce_eliminate.h

+ 9
- 0
mindspore/ccsrc/frontend/optimizer/irpass/reduce_eliminate.h View File

@@ -72,6 +72,15 @@ class ReduceOneEliminater : public AnfVisitor {
}
auto new_shape = std::make_shared<ValueTuple>(elements);
auto reshape_op = prim::GetPythonOps("reshape", "mindspore.ops.functional")->cast<PrimitivePtr>();
auto node_abstract = node->abstract();
// handle auto_parallel get nullptr abstract
if (node_abstract != nullptr) {
auto new_base_shape = std::make_shared<abstract::Shape>(GetValue<std::vector<int>>(new_shape));
node_abstract->set_shape(new_base_shape);
auto new_node = node->func_graph()->NewCNode({NewValueNode(reshape_op), x_, NewValueNode(new_shape)});
new_node->set_abstract(node_abstract);
return new_node;
}
return node->func_graph()->NewCNode({NewValueNode(reshape_op), x_, NewValueNode(new_shape)});
}



Loading…
Cancel
Save