| @@ -940,7 +940,7 @@ table Unique { | |||||
| } | } | ||||
| table Unstack { | table Unstack { | ||||
| num: int; | |||||
| num: int; // deprecated | |||||
| axis: int; | axis: int; | ||||
| } | } | ||||
| @@ -31,7 +31,6 @@ OpParameter *PopulateUnstackParameter(const mindspore::lite::PrimitiveC *primiti | |||||
| memset(unstack_param, 0, sizeof(UnstackParameter)); | memset(unstack_param, 0, sizeof(UnstackParameter)); | ||||
| auto param = reinterpret_cast<mindspore::lite::Unstack *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | auto param = reinterpret_cast<mindspore::lite::Unstack *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | ||||
| unstack_param->op_parameter_.type_ = primitive->Type(); | unstack_param->op_parameter_.type_ = primitive->Type(); | ||||
| unstack_param->num_ = param->GetNum(); | |||||
| unstack_param->axis_ = param->GetAxis(); | unstack_param->axis_ = param->GetAxis(); | ||||
| return reinterpret_cast<OpParameter *>(unstack_param); | return reinterpret_cast<OpParameter *>(unstack_param); | ||||
| } | } | ||||
| @@ -22,15 +22,12 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| int Unstack::GetNum() const { return this->primitive_->value.AsUnstack()->num; } | |||||
| int Unstack::GetAxis() const { return this->primitive_->value.AsUnstack()->axis; } | int Unstack::GetAxis() const { return this->primitive_->value.AsUnstack()->axis; } | ||||
| void Unstack::SetNum(int num) { this->primitive_->value.AsUnstack()->num = num; } | |||||
| void Unstack::SetAxis(int axis) { this->primitive_->value.AsUnstack()->axis = axis; } | void Unstack::SetAxis(int axis) { this->primitive_->value.AsUnstack()->axis = axis; } | ||||
| #else | #else | ||||
| int Unstack::GetNum() const { return this->primitive_->value_as_Unstack()->num(); } | |||||
| int Unstack::GetAxis() const { return this->primitive_->value_as_Unstack()->axis(); } | int Unstack::GetAxis() const { return this->primitive_->value_as_Unstack()->axis(); } | ||||
| int Unstack::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | int Unstack::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | ||||
| MS_ASSERT(nullptr != primitive); | MS_ASSERT(nullptr != primitive); | ||||
| @@ -32,13 +32,11 @@ class Unstack : public PrimitiveC { | |||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| MS_DECLARE_PARENT(Unstack, PrimitiveC); | MS_DECLARE_PARENT(Unstack, PrimitiveC); | ||||
| explicit Unstack(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | explicit Unstack(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | ||||
| void SetNum(int num); | |||||
| void SetAxis(int axis); | void SetAxis(int axis); | ||||
| #else | #else | ||||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | ||||
| #endif | #endif | ||||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | ||||
| int GetNum() const; | |||||
| int GetAxis() const; | int GetAxis() const; | ||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -42,6 +42,7 @@ int UnstackCPUKernel::ReSize() { | |||||
| if (para->axis_ < 0) { | if (para->axis_ < 0) { | ||||
| para->axis_ += shape_size; | para->axis_ += shape_size; | ||||
| } | } | ||||
| for (size_t i = 0; i < shape_size; i++) { | for (size_t i = 0; i < shape_size; i++) { | ||||
| if (static_cast<int>(i) < para->axis_) { | if (static_cast<int>(i) < para->axis_) { | ||||
| para->pre_dims_ *= input->DimensionSize(i); | para->pre_dims_ *= input->DimensionSize(i); | ||||
| @@ -71,7 +72,9 @@ int UnstackCPUKernel::Run() { | |||||
| output_addr_array_[i] = reinterpret_cast<float *>(out_tensors_.at(i)->MutableData()); | output_addr_array_[i] = reinterpret_cast<float *>(out_tensors_.at(i)->MutableData()); | ||||
| } | } | ||||
| MS_ASSERT(output_addr_array_); | MS_ASSERT(output_addr_array_); | ||||
| Unistack(input, output_addr_array_, reinterpret_cast<UnstackParameter *>(op_parameter_)); | |||||
| auto para = reinterpret_cast<UnstackParameter *>(op_parameter_); | |||||
| para->num_ = out_num; | |||||
| Unistack(input, output_addr_array_, para); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -40,7 +40,6 @@ PrimitiveC *TfliteUnstackParser::ParseLitePrimitive(const std::unique_ptr<tflite | |||||
| MS_LOG(ERROR) << "get op unstack attr failed"; | MS_LOG(ERROR) << "get op unstack attr failed"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| attr->num = tflite_attr->num; | |||||
| attr->axis = tflite_attr->axis; | attr->axis = tflite_attr->axis; | ||||
| primitive->value.type = schema::PrimitiveType_Unstack; | primitive->value.type = schema::PrimitiveType_Unstack; | ||||