Browse Source

stride_slice-5

tags/v1.2.0-rc1
yefeng 4 years ago
parent
commit
152992d3a9
3 changed files with 15 additions and 76 deletions
  1. +3
    -1
      mindspore/lite/src/ops/strided_slice.cc
  2. +8
    -71
      mindspore/lite/tools/converter/parser/tf/tf_stride_slice_parser.cc
  3. +4
    -4
      mindspore/lite/tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.cc

+ 3
- 1
mindspore/lite/src/ops/strided_slice.cc View File

@@ -359,7 +359,9 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
MS_ASSERT(input != nullptr); MS_ASSERT(input != nullptr);
auto input_shape = input->shape(); auto input_shape = input->shape();
auto inferflag = infer_flag(); auto inferflag = infer_flag();

if (!infer_flag()) {
return RET_INFER_INVALID;
}
in_shape_.clear(); in_shape_.clear();
if (inferflag) { if (inferflag) {
in_shape_.assign(input_shape.begin(), input_shape.end()); in_shape_.assign(input_shape.begin(), input_shape.end());


+ 8
- 71
mindspore/lite/tools/converter/parser/tf/tf_stride_slice_parser.cc View File

@@ -72,76 +72,6 @@ STATUS TFStrideSliceParser::Parse(const tensorflow::NodeDef &tf_op,
return RET_ERROR; return RET_ERROR;
} }
attr->shrinkAxisMask = attr_value.i(); attr->shrinkAxisMask = attr_value.i();

// begin
auto begin_node = GetConstInputNode(tf_node_map, tf_op.input(1));
if (begin_node == nullptr) {
MS_LOG(ERROR) << "Find StridedSlice input begin failed";
return RET_ERROR;
}
if (!TensorFlowUtils::FindAttrValue(*begin_node, "value", &attr_value)) {
MS_LOG(ERROR) << "The value attr should be specified";
return RET_ERROR;
}
auto tensor_proto = attr_value.tensor();
if (tensor_proto.int_val_size() > 0) {
for (int i = 0; i < tensor_proto.int_val_size(); ++i) {
attr->begin.push_back(tensor_proto.int_val(i));
}
} else {
auto data_num = tensor_proto.tensor_content().size() / sizeof(int32_t);
auto data = reinterpret_cast<const int32_t *>(tensor_proto.tensor_content().data());
for (size_t i = 0; i < data_num; ++i) {
attr->begin.push_back(data[i]);
}
}

// end
auto end_node = GetConstInputNode(tf_node_map, tf_op.input(2));
if (end_node == nullptr) {
MS_LOG(ERROR) << "Find StridedSlice input end failed";
return RET_ERROR;
}
if (!TensorFlowUtils::FindAttrValue(*end_node, "value", &attr_value)) {
MS_LOG(ERROR) << "The value attr should be specified";
return RET_ERROR;
}
tensor_proto = attr_value.tensor();
if (tensor_proto.int_val_size() > 0) {
for (int i = 0; i < tensor_proto.int_val_size(); ++i) {
attr->end.push_back(tensor_proto.int_val(i));
}
} else {
auto data_num = tensor_proto.tensor_content().size() / sizeof(int32_t);
auto data = reinterpret_cast<const int32_t *>(tensor_proto.tensor_content().data());
for (size_t i = 0; i < data_num; ++i) {
attr->end.push_back(data[i]);
}
}

// strides
auto stride_node = GetConstInputNode(tf_node_map, tf_op.input(3));
if (stride_node == nullptr) {
MS_LOG(ERROR) << "Find StridedSlice input strides failed";
return RET_ERROR;
}
if (!TensorFlowUtils::FindAttrValue(*stride_node, "value", &attr_value)) {
MS_LOG(ERROR) << "The value attr should be specified";
return RET_ERROR;
}
tensor_proto = attr_value.tensor();
if (tensor_proto.int_val_size() > 0) {
for (int i = 0; i < tensor_proto.int_val_size(); ++i) {
attr->stride.push_back(tensor_proto.int_val(i));
}
} else {
auto data_num = tensor_proto.tensor_content().size() / sizeof(int32_t);
auto data = reinterpret_cast<const int32_t *>(tensor_proto.tensor_content().data());
for (size_t i = 0; i < data_num; ++i) {
attr->stride.push_back(data[i]);
}
}

primitive->value.type = schema::PrimitiveType_StridedSlice; primitive->value.type = schema::PrimitiveType_StridedSlice;
primitive->value.value = attr.release(); primitive->value.value = attr.release();
*primitiveC = PrimitiveC::Create(primitive.release()); *primitiveC = PrimitiveC::Create(primitive.release());
@@ -151,7 +81,14 @@ STATUS TFStrideSliceParser::Parse(const tensorflow::NodeDef &tf_op,
} }


*output_size = 1; *output_size = 1;
auto status = AddOpInput(tf_op, 0, inputs);
STATUS status = RET_OK;
for (int i = 0; i < tf_op.input_size(); i++) {
status = AddOpInput(tf_op, i, inputs);
if (status != RET_OK) {
MS_LOG(ERROR) << "Add Op input failed.";
return status;
}
}
return status; return status;
} }
TFNodeRegistrar g_tfStrideSliceParser("StridedSlice", new TFStrideSliceParser()); TFNodeRegistrar g_tfStrideSliceParser("StridedSlice", new TFStrideSliceParser());


+ 4
- 4
mindspore/lite/tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.cc View File

@@ -71,8 +71,8 @@ const BaseRef BiDirectionTfGruCellFusion::DefinePattern() const {


auto fw_shape = auto fw_shape =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Shape)), transpose_input_}); VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Shape)), transpose_input_});
auto fw_stride =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_StridedSlice)), fw_shape});
auto fw_stride = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_StridedSlice)),
fw_shape, std::make_shared<SeqVar>()});
auto fw_min = auto fw_min =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Minimum)), fw_stride, fw_max2}); VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Minimum)), fw_stride, fw_max2});


@@ -106,8 +106,8 @@ const BaseRef BiDirectionTfGruCellFusion::DefinePattern() const {
bw_reverse_seq, std::make_shared<Var>()}); bw_reverse_seq, std::make_shared<Var>()});
auto bw_shape = auto bw_shape =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Shape)), bw_trans}); VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Shape)), bw_trans});
auto bw_stride =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_StridedSlice)), bw_shape});
auto bw_stride = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_StridedSlice)),
bw_shape, std::make_shared<SeqVar>()});
auto bw_min = auto bw_min =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Minimum)), bw_stride, bw_max2}); VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Minimum)), bw_stride, bw_max2});
auto bw_reserve = auto bw_reserve =


Loading…
Cancel
Save