Browse Source

unstack remote num

tags/v1.1.0
chenjianping 5 years ago
parent
commit
ef72f405e0
6 changed files with 5 additions and 9 deletions
  1. +1
    -1
      mindspore/lite/schema/ops.fbs
  2. +0
    -1
      mindspore/lite/src/ops/populate/unstack_populate.cc
  3. +0
    -3
      mindspore/lite/src/ops/unstack.cc
  4. +0
    -2
      mindspore/lite/src/ops/unstack.h
  5. +4
    -1
      mindspore/lite/src/runtime/kernel/arm/fp32/unstack_fp32.cc
  6. +0
    -1
      mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.cc

+ 1
- 1
mindspore/lite/schema/ops.fbs View File

@@ -940,7 +940,7 @@ table Unique {
}

table Unstack {
num: int;
num: int; // deprecated
axis: int;
}



+ 0
- 1
mindspore/lite/src/ops/populate/unstack_populate.cc View File

@@ -31,7 +31,6 @@ OpParameter *PopulateUnstackParameter(const mindspore::lite::PrimitiveC *primiti
memset(unstack_param, 0, sizeof(UnstackParameter));
auto param = reinterpret_cast<mindspore::lite::Unstack *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
unstack_param->op_parameter_.type_ = primitive->Type();
unstack_param->num_ = param->GetNum();
unstack_param->axis_ = param->GetAxis();
return reinterpret_cast<OpParameter *>(unstack_param);
}


+ 0
- 3
mindspore/lite/src/ops/unstack.cc View File

@@ -22,15 +22,12 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int Unstack::GetNum() const { return this->primitive_->value.AsUnstack()->num; }
int Unstack::GetAxis() const { return this->primitive_->value.AsUnstack()->axis; }

void Unstack::SetNum(int num) { this->primitive_->value.AsUnstack()->num = num; }
void Unstack::SetAxis(int axis) { this->primitive_->value.AsUnstack()->axis = axis; }

#else

int Unstack::GetNum() const { return this->primitive_->value_as_Unstack()->num(); }
int Unstack::GetAxis() const { return this->primitive_->value_as_Unstack()->axis(); }
int Unstack::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);


+ 0
- 2
mindspore/lite/src/ops/unstack.h View File

@@ -32,13 +32,11 @@ class Unstack : public PrimitiveC {
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Unstack, PrimitiveC);
explicit Unstack(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetNum(int num);
void SetAxis(int axis);
#else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
int GetNum() const;
int GetAxis() const;
};
} // namespace lite


+ 4
- 1
mindspore/lite/src/runtime/kernel/arm/fp32/unstack_fp32.cc View File

@@ -42,6 +42,7 @@ int UnstackCPUKernel::ReSize() {
if (para->axis_ < 0) {
para->axis_ += shape_size;
}

for (size_t i = 0; i < shape_size; i++) {
if (static_cast<int>(i) < para->axis_) {
para->pre_dims_ *= input->DimensionSize(i);
@@ -71,7 +72,9 @@ int UnstackCPUKernel::Run() {
output_addr_array_[i] = reinterpret_cast<float *>(out_tensors_.at(i)->MutableData());
}
MS_ASSERT(output_addr_array_);
Unistack(input, output_addr_array_, reinterpret_cast<UnstackParameter *>(op_parameter_));
auto para = reinterpret_cast<UnstackParameter *>(op_parameter_);
para->num_ = out_num;
Unistack(input, output_addr_array_, para);
return RET_OK;
}



+ 0
- 1
mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.cc View File

@@ -40,7 +40,6 @@ PrimitiveC *TfliteUnstackParser::ParseLitePrimitive(const std::unique_ptr<tflite
MS_LOG(ERROR) << "get op unstack attr failed";
return nullptr;
}
attr->num = tflite_attr->num;
attr->axis = tflite_attr->axis;

primitive->value.type = schema::PrimitiveType_Unstack;


Loading…
Cancel
Save