|
|
|
@@ -2739,7 +2739,7 @@ struct DelTransposeInfo { |
|
|
|
int inputIdx; |
|
|
|
}; |
|
|
|
|
|
|
|
Status GetTransposeInfo(GraphDef *graph_def, std::map<std::string, std::string> &softmaxInfo, |
|
|
|
Status GetTransposeInfo(domi::tensorflow::GraphDef *graph_def, std::map<std::string, std::string> &softmaxInfo, |
|
|
|
std::map<std::string, DelTransposeInfo> &transposeInfo) { |
|
|
|
GE_CHECK_NOTNULL(graph_def); |
|
|
|
for (int i = 0; i < graph_def->node_size(); ++i) { |
|
|
|
@@ -2813,7 +2813,7 @@ void TensorFlowModelParser::OptimizeTranspose(std::map<std::string, DelTranspose |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void TensorFlowModelParser::SoftmaxAddAttr(GraphDef *const graph_def) { |
|
|
|
void TensorFlowModelParser::SoftmaxAddAttr(domi::tensorflow::GraphDef *const graph_def) { |
|
|
|
// The caller guarantees that the pointer is not null |
|
|
|
for (int i = 0; i < graph_def->node_size(); ++i) { |
|
|
|
auto node_def = graph_def->mutable_node(i); |
|
|
|
@@ -3021,7 +3021,7 @@ Status TensorFlowModelParser::GetFormatTranspose(const NodeDef *transpose_node, |
|
|
|
GE_IF_BOOL_EXEC(ge::TensorFlowUtil::CheckAttrHasType(attr_value, TENSORFLOW_ATTR_TYPE_TENSOR) != SUCCESS, |
|
|
|
return FAILED); |
|
|
|
const TensorProto &tensor = attr_value.tensor(); |
|
|
|
const TensorShapeProto &tensor_shape = tensor.tensor_shape(); |
|
|
|
const domi::tensorflow::TensorShapeProto &tensor_shape = tensor.tensor_shape(); |
|
|
|
GE_IF_BOOL_EXEC(tensor_shape.dim_size() != 1 || tensor_shape.dim(0).size() != parser::DIM_DEFAULT_SIZE, |
|
|
|
return SUCCESS); |
|
|
|
GE_IF_BOOL_EXEC(tensor.tensor_content().empty(), return SUCCESS); |
|
|
|
@@ -3108,7 +3108,7 @@ Status TensorFlowModelParser::TrimGraphByInput(const domi::tensorflow::GraphDef |
|
|
|
placeholder_node.clear_input(); |
|
|
|
GE_IF_BOOL_EXEC(node.op() != "Placeholder", placeholder_node.set_op("Placeholder")); |
|
|
|
domi::tensorflow::AttrValue attr_value; |
|
|
|
TensorShapeProto *data_shape = attr_value.mutable_shape(); |
|
|
|
domi::tensorflow::TensorShapeProto *data_shape = attr_value.mutable_shape(); |
|
|
|
GE_CHECK_NOTNULL(data_shape); |
|
|
|
const ge::ParserContext &ctx = ge::GetParserContext(); |
|
|
|
std::map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims; |
|
|
|
@@ -3181,7 +3181,7 @@ Status TensorFlowModelParser::TrimGraphByOutput(const domi::tensorflow::GraphDef |
|
|
|
placeholder_node.clear_input(); |
|
|
|
GE_IF_BOOL_EXEC(node.op() != "Placeholder", placeholder_node.set_op("Placeholder")); |
|
|
|
domi::tensorflow::AttrValue attr_value; |
|
|
|
TensorShapeProto *data_shape = attr_value.mutable_shape(); |
|
|
|
domi::tensorflow::TensorShapeProto *data_shape = attr_value.mutable_shape(); |
|
|
|
GE_CHECK_NOTNULL(data_shape); |
|
|
|
const ge::ParserContext &ctx = ge::GetParserContext(); |
|
|
|
std::map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims; |
|
|
|
|