diff --git a/mindspore/lite/src/ops/concat.cc b/mindspore/lite/src/ops/concat.cc index db73bac434..9742aab110 100644 --- a/mindspore/lite/src/ops/concat.cc +++ b/mindspore/lite/src/ops/concat.cc @@ -107,14 +107,8 @@ int Concat::InferShape(std::vector inputs_, std::vectordata_type(); int output_axis_dim = input0_shape.at(axis); for (size_t i = 1; i < inputs_.size(); ++i) { - if (inputs_.at(i)->data_type() != input0_data_type) { - MS_LOG(ERROR) << "All inputs should have the same data type!"; - return RET_PARAM_INVALID; - } - auto shape_tmp = inputs_.at(i)->shape(); if (shape_tmp.size() != input0_shape.size()) { MS_LOG(ERROR) << "All inputs should have the same dim num!"; diff --git a/mindspore/lite/src/ops/slice.cc b/mindspore/lite/src/ops/slice.cc index bfdd2ac039..b421a4c885 100644 --- a/mindspore/lite/src/ops/slice.cc +++ b/mindspore/lite/src/ops/slice.cc @@ -60,6 +60,12 @@ int Slice::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers:: return RET_ERROR; } + std::vector axes; + if (attr->axes() != nullptr) { + for (int i = 0; i < static_cast(attr->axes()->size()); i++) { + axes.push_back(attr->axes()->data()[i]); + } + } std::vector begin; if (attr->begin() != nullptr) { for (int i = 0; i < static_cast(attr->begin()->size()); i++) { @@ -73,7 +79,7 @@ int Slice::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers:: } } - auto val_offset = schema::CreateSliceDirect(*fbb, attr->format(), &begin, &size); + auto val_offset = schema::CreateSliceDirect(*fbb, attr->format(), &axes, &begin, &size); auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Slice, val_offset.o); fbb->Finish(prim_offset); return RET_OK;