diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index 0c1fa797bf..ca11616e53 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -940,7 +940,7 @@ table Unique { } table Unstack { - num: int; + num: int; // deprecated axis: int; } diff --git a/mindspore/lite/src/ops/populate/unstack_populate.cc b/mindspore/lite/src/ops/populate/unstack_populate.cc index 09f9768d30..c2b7647003 100644 --- a/mindspore/lite/src/ops/populate/unstack_populate.cc +++ b/mindspore/lite/src/ops/populate/unstack_populate.cc @@ -31,7 +31,6 @@ OpParameter *PopulateUnstackParameter(const mindspore::lite::PrimitiveC *primiti memset(unstack_param, 0, sizeof(UnstackParameter)); auto param = reinterpret_cast(const_cast(primitive)); unstack_param->op_parameter_.type_ = primitive->Type(); - unstack_param->num_ = param->GetNum(); unstack_param->axis_ = param->GetAxis(); return reinterpret_cast(unstack_param); } diff --git a/mindspore/lite/src/ops/unstack.cc b/mindspore/lite/src/ops/unstack.cc index 52deeafabb..7913476f25 100644 --- a/mindspore/lite/src/ops/unstack.cc +++ b/mindspore/lite/src/ops/unstack.cc @@ -22,15 +22,12 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE -int Unstack::GetNum() const { return this->primitive_->value.AsUnstack()->num; } int Unstack::GetAxis() const { return this->primitive_->value.AsUnstack()->axis; } -void Unstack::SetNum(int num) { this->primitive_->value.AsUnstack()->num = num; } void Unstack::SetAxis(int axis) { this->primitive_->value.AsUnstack()->axis = axis; } #else -int Unstack::GetNum() const { return this->primitive_->value_as_Unstack()->num(); } int Unstack::GetAxis() const { return this->primitive_->value_as_Unstack()->axis(); } int Unstack::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { MS_ASSERT(nullptr != primitive); diff --git a/mindspore/lite/src/ops/unstack.h b/mindspore/lite/src/ops/unstack.h index b1e25f6a92..9dd73df784 100644 --- a/mindspore/lite/src/ops/unstack.h +++ b/mindspore/lite/src/ops/unstack.h @@ -32,13 +32,11 @@ class Unstack : public PrimitiveC { #ifdef PRIMITIVE_WRITEABLE MS_DECLARE_PARENT(Unstack, PrimitiveC); explicit Unstack(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} - void SetNum(int num); void SetAxis(int axis); #else int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; - int GetNum() const; int GetAxis() const; }; } // namespace lite diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/unstack_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/unstack_fp32.cc index 87535715c3..97d8f03f80 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/unstack_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/unstack_fp32.cc @@ -42,6 +42,7 @@ int UnstackCPUKernel::ReSize() { if (para->axis_ < 0) { para->axis_ += shape_size; } + for (size_t i = 0; i < shape_size; i++) { if (static_cast(i) < para->axis_) { para->pre_dims_ *= input->DimensionSize(i); @@ -71,7 +72,9 @@ int UnstackCPUKernel::Run() { output_addr_array_[i] = reinterpret_cast(out_tensors_.at(i)->MutableData()); } MS_ASSERT(output_addr_array_); - Unistack(input, output_addr_array_, reinterpret_cast(op_parameter_)); + auto para = reinterpret_cast(op_parameter_); + para->num_ = out_num; + Unistack(input, output_addr_array_, para); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.cc index 869e76bd73..7110da2adc 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.cc @@ -40,7 +40,6 @@ PrimitiveC *TfliteUnstackParser::ParseLitePrimitive(const std::unique_ptrnum = tflite_attr->num; attr->axis = tflite_attr->axis; primitive->value.type = schema::PrimitiveType_Unstack;