Browse Source

!5583 [MS][LITE][Develop]fix concat and slice

Merge pull request !5583 from sunsuodong/fix_concat_slice
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
d27178bf2b
2 changed files with 7 additions and 7 deletions
  1. +0
    -6
      mindspore/lite/src/ops/concat.cc
  2. +7
    -1
      mindspore/lite/src/ops/slice.cc

+ 0
- 6
mindspore/lite/src/ops/concat.cc View File

@@ -107,14 +107,8 @@ int Concat::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
} }
auto input0_shape_without_axis = input0_shape; auto input0_shape_without_axis = input0_shape;
input0_shape_without_axis.erase(input0_shape_without_axis.begin() + axis); input0_shape_without_axis.erase(input0_shape_without_axis.begin() + axis);
auto input0_data_type = inputs_.at(0)->data_type();
int output_axis_dim = input0_shape.at(axis); int output_axis_dim = input0_shape.at(axis);
for (size_t i = 1; i < inputs_.size(); ++i) { for (size_t i = 1; i < inputs_.size(); ++i) {
if (inputs_.at(i)->data_type() != input0_data_type) {
MS_LOG(ERROR) << "All inputs should have the same data type!";
return RET_PARAM_INVALID;
}

auto shape_tmp = inputs_.at(i)->shape(); auto shape_tmp = inputs_.at(i)->shape();
if (shape_tmp.size() != input0_shape.size()) { if (shape_tmp.size() != input0_shape.size()) {
MS_LOG(ERROR) << "All inputs should have the same dim num!"; MS_LOG(ERROR) << "All inputs should have the same dim num!";


+ 7
- 1
mindspore/lite/src/ops/slice.cc View File

@@ -60,6 +60,12 @@ int Slice::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::
return RET_ERROR; return RET_ERROR;
} }


std::vector<int32_t> axes;
if (attr->axes() != nullptr) {
for (int i = 0; i < static_cast<int>(attr->axes()->size()); i++) {
axes.push_back(attr->axes()->data()[i]);
}
}
std::vector<int32_t> begin; std::vector<int32_t> begin;
if (attr->begin() != nullptr) { if (attr->begin() != nullptr) {
for (int i = 0; i < static_cast<int>(attr->begin()->size()); i++) { for (int i = 0; i < static_cast<int>(attr->begin()->size()); i++) {
@@ -73,7 +79,7 @@ int Slice::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::
} }
} }


auto val_offset = schema::CreateSliceDirect(*fbb, attr->format(), &begin, &size);
auto val_offset = schema::CreateSliceDirect(*fbb, attr->format(), &axes, &begin, &size);
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Slice, val_offset.o); auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Slice, val_offset.o);
fbb->Finish(prim_offset); fbb->Finish(prim_offset);
return RET_OK; return RET_OK;


Loading…
Cancel
Save