Browse Source

Feature: reset shape of dynamic single op

pull/618/head
l00444296 5 years ago
parent
commit
784e2449fc
2 changed files with 19 additions and 0 deletions
  1. +17
    -0
      ge/graph/passes/dynamic_single_op_reset_shape_pass.cc
  2. +2
    -0
      ge/graph/passes/dynamic_single_op_reset_shape_pass.h

+ 17
- 0
ge/graph/passes/dynamic_single_op_reset_shape_pass.cc View File

@@ -113,6 +113,17 @@ Status DynamicSingleOpResetShapePass::ResetOpShape(OpDescPtr &op_desc) {
GE_CHECK_NOTNULL(op_desc);
std::vector<int64_t> dynamic_shape_dims = {kDynamicShapeDim};
GeShape dynamic_shape(dynamic_shape_dims);
bool reset_shape_flag = false;
if (ResetInputTensorShape(op_desc, dynamic_shape, reset_shape_flag) == SUCCESS && reset_shape_flag) {
(void)ResetOutputTensorShape(op_desc, dynamic_shape);
}
return SUCCESS;
}

Status DynamicSingleOpResetShapePass::ResetInputTensorShape(OpDescPtr &op_desc, const GeShape &dynamic_shape,
bool &reset_shape_flag) {
reset_shape_flag = false;
GE_CHECK_NOTNULL(op_desc);
for (size_t i = 0; i < op_desc->GetAllInputsDesc().size(); i++) {
auto input_desc = op_desc->MutableInputDesc(static_cast<uint32_t>(i));
GE_CHECK_NOTNULL(input_desc);
@@ -125,8 +136,14 @@ Status DynamicSingleOpResetShapePass::ResetOpShape(OpDescPtr &op_desc) {
if (CheckIfConstInput(input_desc)) {
continue;
}
reset_shape_flag = true;
input_desc->SetShape(dynamic_shape);
}
return SUCCESS;
}

Status DynamicSingleOpResetShapePass::ResetOutputTensorShape(OpDescPtr &op_desc, const GeShape &dynamic_shape) {
GE_CHECK_NOTNULL(op_desc);
for (size_t i = 0; i < op_desc->GetAllOutputsDesc().size(); i++) {
auto output_desc = op_desc->MutableOutputDesc(static_cast<uint32_t>(i));
GE_CHECK_NOTNULL(output_desc);


+ 2
- 0
ge/graph/passes/dynamic_single_op_reset_shape_pass.h View File

@@ -27,6 +27,8 @@ class DynamicSingleOpResetShapePass : public GraphPass {

private:
Status ResetOpShape(OpDescPtr &op_desc);
Status ResetInputTensorShape(OpDescPtr &op_desc, const GeShape &dynamic_shape, bool &reset_shape_flag);
Status ResetOutputTensorShape(OpDescPtr &op_desc, const GeShape &dynamic_shape);
Status CheckAllAicpuNodes(const ComputeGraphPtr &graph, bool &is_not_aicpu);
bool CheckIfConstInput(const GeTensorDescPtr &input_tensor_desc);
};


Loading…
Cancel
Save