|
|
@@ -722,7 +722,9 @@ Status UpdateDataOpShapeRange(const OpDescPtr &op, |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
auto tensor_input = op->MutableInputDesc(0); |
|
|
auto tensor_input = op->MutableInputDesc(0); |
|
|
|
|
|
auto tensor_output = op->MutableOutputDesc(0); |
|
|
GE_CHECK_NOTNULL(tensor_input); |
|
|
GE_CHECK_NOTNULL(tensor_input); |
|
|
|
|
|
GE_CHECK_NOTNULL(tensor_output); |
|
|
string data_op_name = op->GetName(); |
|
|
string data_op_name = op->GetName(); |
|
|
auto origin_shape = tensor_input->GetShape(); |
|
|
auto origin_shape = tensor_input->GetShape(); |
|
|
auto iter = shape_range_map.find(data_op_name); |
|
|
auto iter = shape_range_map.find(data_op_name); |
|
|
@@ -741,6 +743,8 @@ Status UpdateDataOpShapeRange(const OpDescPtr &op, |
|
|
} |
|
|
} |
|
|
tensor_input->SetShape(origin_shape); |
|
|
tensor_input->SetShape(origin_shape); |
|
|
tensor_input->SetShapeRange(cur_shape_range); |
|
|
tensor_input->SetShapeRange(cur_shape_range); |
|
|
|
|
|
tensor_output->SetShape(origin_shape); |
|
|
|
|
|
tensor_output->SetShapeRange(cur_shape_range); |
|
|
GELOGI("Update input [%s] shape range info", data_op_name.c_str()); |
|
|
GELOGI("Update input [%s] shape range info", data_op_name.c_str()); |
|
|
} else { |
|
|
} else { |
|
|
GELOGI("No need to update input [%s] attr because not found from input_shape_range.", data_op_name.c_str()); |
|
|
GELOGI("No need to update input [%s] attr because not found from input_shape_range.", data_op_name.c_str()); |
|
|
@@ -749,6 +753,29 @@ Status UpdateDataOpShapeRange(const OpDescPtr &op, |
|
|
return SUCCESS; |
|
|
return SUCCESS; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
static Status CheckInputShapeRangeNode(const ComputeGraphPtr &compute_graph, |
|
|
|
|
|
map<string, vector<pair<int64_t, int64_t>>> &shape_range_map) { |
|
|
|
|
|
for (const auto &it : shape_range_map) { |
|
|
|
|
|
std::string node_name = it.first; |
|
|
|
|
|
ge::NodePtr node = compute_graph->FindNode(node_name); |
|
|
|
|
|
if (node == nullptr) { |
|
|
|
|
|
REPORT_INPUT_ERROR("E10016", std::vector<std::string>({"parameter", "opname"}), |
|
|
|
|
|
std::vector<std::string>({"input_shape_range", node_name})); |
|
|
|
|
|
GELOGE(PARAM_INVALID, "[Check][InputNode]Input parameter[--input_shape_range]'s opname[%s] is not exist in model", |
|
|
|
|
|
node_name.c_str()); |
|
|
|
|
|
return PARAM_INVALID; |
|
|
|
|
|
} |
|
|
|
|
|
if (node->GetType() != DATA) { |
|
|
|
|
|
REPORT_INPUT_ERROR("E10017", std::vector<std::string>({"parameter", "opname"}), |
|
|
|
|
|
std::vector<std::string>({"input_shape_range", node_name})); |
|
|
|
|
|
GELOGE(PARAM_INVALID, "[Check][InputNode]Input parameter[--input_shape_range]'s opname[%s] is not a input opname", |
|
|
|
|
|
node_name.c_str()); |
|
|
|
|
|
return PARAM_INVALID; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
return SUCCESS; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
Status UpdateDynamicInputShapeRange(const ge::ComputeGraphPtr &compute_graph, const string &input_shape_range) { |
|
|
Status UpdateDynamicInputShapeRange(const ge::ComputeGraphPtr &compute_graph, const string &input_shape_range) { |
|
|
if (input_shape_range.empty()) { |
|
|
if (input_shape_range.empty()) { |
|
|
return SUCCESS; |
|
|
return SUCCESS; |
|
|
@@ -757,7 +784,12 @@ Status UpdateDynamicInputShapeRange(const ge::ComputeGraphPtr &compute_graph, co |
|
|
|
|
|
|
|
|
map<string, vector<pair<int64_t, int64_t>>> shape_range_map; |
|
|
map<string, vector<pair<int64_t, int64_t>>> shape_range_map; |
|
|
if (!ParseInputShapeRange(input_shape_range, shape_range_map)) { |
|
|
if (!ParseInputShapeRange(input_shape_range, shape_range_map)) { |
|
|
GELOGE(PARAM_INVALID, "Parse input shape range failed."); |
|
|
|
|
|
|
|
|
GELOGE(PARAM_INVALID, "[Update][InputShapeRange]Parse input shape range failed."); |
|
|
|
|
|
return PARAM_INVALID; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if (CheckInputShapeRangeNode(compute_graph, shape_range_map) != SUCCESS) { |
|
|
|
|
|
GELOGE(PARAM_INVALID, "[Update][InputShapeRange]Parse input shape range failed."); |
|
|
return PARAM_INVALID; |
|
|
return PARAM_INVALID; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
@@ -767,7 +799,7 @@ Status UpdateDynamicInputShapeRange(const ge::ComputeGraphPtr &compute_graph, co |
|
|
GE_CHECK_NOTNULL(op); |
|
|
GE_CHECK_NOTNULL(op); |
|
|
if (op->GetType() == DATA) { |
|
|
if (op->GetType() == DATA) { |
|
|
if (UpdateDataOpShapeRange(op, shape_range_map) != SUCCESS) { |
|
|
if (UpdateDataOpShapeRange(op, shape_range_map) != SUCCESS) { |
|
|
GELOGE(FAILED, "Update data op [%s] input shape range failed.", op->GetName().c_str()); |
|
|
|
|
|
|
|
|
GELOGE(FAILED, "[Update][InputShapeRange]Update data op [%s] input shape range failed.", op->GetName().c_str()); |
|
|
return FAILED; |
|
|
return FAILED; |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|