diff --git a/ge/graph/passes/dynamic_single_op_reset_shape_pass.cc b/ge/graph/passes/dynamic_single_op_reset_shape_pass.cc index 3e6377c7..d50b6df9 100644 --- a/ge/graph/passes/dynamic_single_op_reset_shape_pass.cc +++ b/ge/graph/passes/dynamic_single_op_reset_shape_pass.cc @@ -113,6 +113,17 @@ Status DynamicSingleOpResetShapePass::ResetOpShape(OpDescPtr &op_desc) { GE_CHECK_NOTNULL(op_desc); std::vector dynamic_shape_dims = {kDynamicShapeDim}; GeShape dynamic_shape(dynamic_shape_dims); + bool reset_shape_flag = false; + if (ResetInputTensorShape(op_desc, dynamic_shape, reset_shape_flag) == SUCCESS && reset_shape_flag) { + (void)ResetOutputTensorShape(op_desc, dynamic_shape); + } + return SUCCESS; +} + +Status DynamicSingleOpResetShapePass::ResetInputTensorShape(OpDescPtr &op_desc, const GeShape &dynamic_shape, + bool &reset_shape_flag) { + reset_shape_flag = false; + GE_CHECK_NOTNULL(op_desc); for (size_t i = 0; i < op_desc->GetAllInputsDesc().size(); i++) { auto input_desc = op_desc->MutableInputDesc(static_cast(i)); GE_CHECK_NOTNULL(input_desc); @@ -125,8 +136,14 @@ Status DynamicSingleOpResetShapePass::ResetOpShape(OpDescPtr &op_desc) { if (CheckIfConstInput(input_desc)) { continue; } + reset_shape_flag = true; input_desc->SetShape(dynamic_shape); } + return SUCCESS; +} + +Status DynamicSingleOpResetShapePass::ResetOutputTensorShape(OpDescPtr &op_desc, const GeShape &dynamic_shape) { + GE_CHECK_NOTNULL(op_desc); for (size_t i = 0; i < op_desc->GetAllOutputsDesc().size(); i++) { auto output_desc = op_desc->MutableOutputDesc(static_cast(i)); GE_CHECK_NOTNULL(output_desc); diff --git a/ge/graph/passes/dynamic_single_op_reset_shape_pass.h b/ge/graph/passes/dynamic_single_op_reset_shape_pass.h index 659bed9c..897fcac6 100644 --- a/ge/graph/passes/dynamic_single_op_reset_shape_pass.h +++ b/ge/graph/passes/dynamic_single_op_reset_shape_pass.h @@ -27,6 +27,8 @@ class DynamicSingleOpResetShapePass : public GraphPass { private: Status ResetOpShape(OpDescPtr &op_desc); + Status ResetInputTensorShape(OpDescPtr &op_desc, const GeShape &dynamic_shape, bool &reset_shape_flag); + Status ResetOutputTensorShape(OpDescPtr &op_desc, const GeShape &dynamic_shape); Status CheckAllAicpuNodes(const ComputeGraphPtr &graph, bool &is_not_aicpu); bool CheckIfConstInput(const GeTensorDescPtr &input_tensor_desc); };