
!4176 [MS][LITE][Develop]optimize infershape when running graph

Merge pull request !4176 from chenjianping/lite_dev2
tags/v0.7.0-beta
mindspore-ci-bot (Gitee) committed 5 years ago
commit 05f405c0bc
19 changed files with 96 additions and 37 deletions
  1. +13 -4   mindspore/lite/src/lite_kernel.h
  2. +5 -4    mindspore/lite/src/runtime/kernel/arm/base/arg_min_max_base.cc
  3. +2 -4    mindspore/lite/src/runtime/kernel/arm/base/arg_min_max_base.h
  4. +8 -4    mindspore/lite/src/runtime/kernel/arm/base/batch_to_space_base.cc
  5. +1 -1    mindspore/lite/src/runtime/kernel/arm/base/batch_to_space_base.h
  6. +3 -5    mindspore/lite/src/runtime/kernel/arm/base/depth_to_space_base.cc
  7. +1 -1    mindspore/lite/src/runtime/kernel/arm/base/depth_to_space_base.h
  8. +9 -1    mindspore/lite/src/runtime/kernel/arm/fp32/argminmax.cc
  9. +1 -1    mindspore/lite/src/runtime/kernel/arm/fp32/argminmax.h
  10. +13 -1  mindspore/lite/src/runtime/kernel/arm/fp32/batch_to_space.cc
  11. +1 -1   mindspore/lite/src/runtime/kernel/arm/fp32/batch_to_space.h
  12. +9 -1   mindspore/lite/src/runtime/kernel/arm/fp32/depth_to_space.cc
  13. +1 -1   mindspore/lite/src/runtime/kernel/arm/fp32/depth_to_space.h
  14. +10 -3  mindspore/lite/src/runtime/kernel/arm/int8/argminmax_int8.cc
  15. +1 -1   mindspore/lite/src/runtime/kernel/arm/int8/argminmax_int8.h
  16. +8 -1   mindspore/lite/src/runtime/kernel/arm/int8/batch_to_space_int8.cc
  17. +1 -1   mindspore/lite/src/runtime/kernel/arm/int8/batch_to_space_int8.h
  18. +8 -1   mindspore/lite/src/runtime/kernel/arm/int8/depth_to_space_int8.cc
  19. +1 -1   mindspore/lite/src/runtime/kernel/arm/int8/depth_to_space_int8.h

+13 -4  mindspore/lite/src/lite_kernel.h

@@ -62,6 +62,7 @@ class LiteKernel {
const lite::Primitive *primitive)
: opParameter(parameter), inputs_(inputs), outputs_(outputs), primitive_(primitive),
context_(ctx) {
opParameter->thread_num_ = ctx->thread_num_;
this->in_kernel_.clear();
this->out_kernel_.clear();
}
@@ -69,12 +70,13 @@ class LiteKernel {
virtual ~LiteKernel() { delete opParameter; }

virtual int Prepare() {
if (primitive_ != nullptr && !primitive_->GetInferFlag()) {
if (!InferShapeDone()) {
(const_cast<lite::Primitive *>(primitive_))->InferShape(inputs_, outputs_);
if (need_reinit) {
Init();
}
}
if (need_reinit) {
Init();
}

auto &outputs = this->GetOutputs();
for (auto *output : outputs) {
MS_ASSERT(output != nullptr);
@@ -126,6 +128,13 @@ class LiteKernel {
}

protected:
bool InferShapeDone() {
if (primitive_ != nullptr && !primitive_->GetInferFlag()) {
return false;
}
return true;
}

KernelKey desc;
std::string name;
OpParameter *opParameter = nullptr;
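
Read as a whole, the lite_kernel.h change establishes a small contract that the kernel files below follow: the protected InferShapeDone() helper reports whether output shapes are already known, and Prepare() only re-runs shape inference (and the optional re-init) when they are not. The following is a minimal standalone sketch of that contract, using simplified stand-in types (PrimitiveStub, KernelSketch) rather than the real MindSpore Lite classes:

```cpp
// Standalone sketch of the Prepare()/InferShapeDone() contract.
// PrimitiveStub and KernelSketch are illustrative stand-ins only.
#include <iostream>

struct PrimitiveStub {            // stand-in for lite::Primitive
  bool infer_flag = false;        // set to true once InferShape() has succeeded
  bool GetInferFlag() const { return infer_flag; }
  void InferShape() { infer_flag = true; }   // real code also takes the tensors
};

class KernelSketch {
 public:
  explicit KernelSketch(PrimitiveStub *primitive) : primitive_(primitive) {}

  int Prepare() {
    if (!InferShapeDone()) {
      primitive_->InferShape();   // shapes were unknown at graph build time
      if (need_reinit_) {
        Init();                   // redo any setup that depends on shapes
      }
    }
    // ... allocate output tensors, etc.
    return 0;
  }

  int Init() {
    std::cout << "Init: shape-dependent setup\n";
    return 0;
  }

 protected:
  // Mirrors the helper added to lite_kernel.h: false only when a primitive
  // exists and its infer flag is not yet set.
  bool InferShapeDone() const {
    if (primitive_ != nullptr && !primitive_->GetInferFlag()) {
      return false;
    }
    return true;
  }

 private:
  PrimitiveStub *primitive_;
  bool need_reinit_ = true;
};

int main() {
  PrimitiveStub prim;             // infer_flag == false: inference was deferred
  KernelSketch kernel(&prim);
  kernel.Prepare();               // runs InferShape() and then Init()
  kernel.Prepare();               // shapes known now, the check is a no-op
  return 0;
}
```

As far as the diff shows, the point of the design is that kernels whose shapes were already inferred at graph build time skip the extra work at run time, while interrupted graphs defer shape-dependent setup until Prepare().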


+5 -4  mindspore/lite/src/runtime/kernel/arm/base/arg_min_max_base.cc

@@ -32,10 +32,6 @@ using mindspore::schema::PrimitiveType_ArgMin;

namespace mindspore::kernel {
int ArgMinMaxBaseCPUKernel::Init() {
if (context_->infer_shape_interrupt_ && !context_->running_) {
SetNeedReInit();
return RET_OK;
}
auto param = reinterpret_cast<ArgMinMaxParameter *>(opParameter);
switch (opParameter->type_) {
case PrimitiveType_ArgMax:
@@ -49,8 +45,13 @@ int ArgMinMaxBaseCPUKernel::Init() {
return RET_ERROR;
}

return RET_OK;
}

int ArgMinMaxBaseCPUKernel::ReSize() {
auto in_shape = inputs_.at(0)->shape();
auto dims_size = in_shape.size();
auto param = reinterpret_cast<ArgMinMaxParameter *>(opParameter);
int axis = param->axis_ < 0 ? param->axis_ + dims_size : param->axis_;
param->axis_ = axis;
param->dims_size_ = dims_size;


+2 -4  mindspore/lite/src/runtime/kernel/arm/base/arg_min_max_base.h

@@ -26,15 +26,13 @@ class ArgMinMaxBaseCPUKernel : public LiteKernel {
ArgMinMaxBaseCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
const lite::Primitive *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive), data_from_allocator_(false) {
opParameter->thread_num_ = ctx->thread_num_;
}
: LiteKernel(parameter, inputs, outputs, ctx, primitive), data_from_allocator_(false) {}

virtual ~ArgMinMaxBaseCPUKernel() { FreeTmpMemory(); }

int Init() override;

int ReSize() override { return 0; }
int ReSize() override;

int Run() override;



+8 -4  mindspore/lite/src/runtime/kernel/arm/base/batch_to_space_base.cc

@@ -30,10 +30,6 @@ using mindspore::schema::PrimitiveType_BatchToSpace;

namespace mindspore::kernel {
int BatchToSpaceBaseCPUKernel::Init() {
if (inputs_[0]->GetFormat() != schema::Format_NHWC) {
MS_LOG(ERROR) << "batch_to_space only support NHWC now!";
return RET_FORMAT_ERR;
}
BatchToSpaceParameter *param = reinterpret_cast<BatchToSpaceParameter *>(this->opParameter);
for (int i = 0; i < BATCH_TO_SPACE_CROPS_SIZE; ++i) {
if (param->crops_[i] != 0) {
@@ -43,6 +39,14 @@ int BatchToSpaceBaseCPUKernel::Init() {
return RET_OK;
}

int BatchToSpaceBaseCPUKernel::ReSize() {
if (inputs_[0]->GetFormat() != schema::Format_NHWC) {
MS_LOG(ERROR) << "batch_to_space only support NHWC now!";
return RET_FORMAT_ERR;
}
return RET_OK;
}

kernel::LiteKernel *CpuBatchToSpaceInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *op_parameter, const lite::Context *ctx,


+1 -1  mindspore/lite/src/runtime/kernel/arm/base/batch_to_space_base.h

@@ -35,7 +35,7 @@ class BatchToSpaceBaseCPUKernel : public LiteKernel {

int Init() override;

int ReSize() override { return 0; }
int ReSize() override;

int Run() override { return 0; }



+3 -5  mindspore/lite/src/runtime/kernel/arm/base/depth_to_space_base.cc

@@ -31,11 +31,9 @@ using mindspore::lite::RET_PARAM_INVALID;
using mindspore::schema::PrimitiveType_DepthToSpace;

namespace mindspore::kernel {
int DepthToSpaceBaseCPUKernel::Init() {
if (context_->infer_shape_interrupt_ && !context_->running_) {
SetNeedReInit();
return RET_OK;
}
int DepthToSpaceBaseCPUKernel::Init() { return RET_OK; }

int DepthToSpaceBaseCPUKernel::ReSize() {
if (inputs_[0]->GetFormat() != schema::Format_NHWC) {
MS_LOG(ERROR) << "depth_to_space only support NHWC now!";
return RET_FORMAT_ERR;


+1 -1  mindspore/lite/src/runtime/kernel/arm/base/depth_to_space_base.h

@@ -35,7 +35,7 @@ class DepthToSpaceBaseCPUKernel : public LiteKernel {

int Init() override;

int ReSize() override { return 0; }
int ReSize() override;

int Run() override { return 0; }
};


+9 -1  mindspore/lite/src/runtime/kernel/arm/fp32/argminmax.cc

@@ -36,7 +36,15 @@ int ArgMinMaxCPUKernel::Init() {
}
auto param = reinterpret_cast<ArgMinMaxParameter *>(opParameter);
param->data_type_ = kNumberTypeFloat32;
return RET_OK;
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}

int ArgMinMaxCPUKernel::ReSize() {
ArgMinMaxBaseCPUKernel::FreeTmpMemory();
return ArgMinMaxBaseCPUKernel::ReSize();
}

int ArgMinMaxCPUKernel::Run() {
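
The fp32 and int8 kernels below all repeat the same Init()/ReSize() split seen here. A condensed sketch of that per-kernel convention, using illustrative names (KernelBaseSketch, ArgMinMaxSketch) rather than the real *BaseCPUKernel / *CPUKernel pairs changed in this pull request:

```cpp
// Sketch of the per-kernel convention: Init() keeps only shape-independent
// configuration and defers to ReSize() once shapes are known.
#include <vector>

class KernelBaseSketch {
 public:
  virtual ~KernelBaseSketch() = default;

  // Shape-dependent setup, e.g. normalizing a negative axis against the
  // current input rank (compare ArgMinMaxBaseCPUKernel::ReSize above).
  virtual int ReSize() { return 0; }

  void FreeTmpMemory() { scratch_.clear(); }

 protected:
  bool InferShapeDone() const { return infer_shape_done_; }
  bool infer_shape_done_ = false;   // true once output shapes are valid
  std::vector<float> scratch_;      // buffers sized from the inferred shapes
};

class ArgMinMaxSketch : public KernelBaseSketch {
 public:
  int Init() {
    // Shape-independent configuration (data type, quant params, ...) stays here.
    if (!InferShapeDone()) {
      return 0;                     // ReSize() will be called later, from Prepare()
    }
    return ReSize();                // shapes already known: finish setup now
  }

  int ReSize() override {
    FreeTmpMemory();                // drop buffers sized for the previous shape
    return KernelBaseSketch::ReSize();
  }
};
```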


+1 -1  mindspore/lite/src/runtime/kernel/arm/fp32/argminmax.h

@@ -30,7 +30,7 @@ class ArgMinMaxCPUKernel : public ArgMinMaxBaseCPUKernel {
~ArgMinMaxCPUKernel() = default;

int Init() override;
int ReSize() override { return 0; }
int ReSize() override;
int Run() override;
};
} // namespace mindspore::kernel


+13 -1  mindspore/lite/src/runtime/kernel/arm/fp32/batch_to_space.cc

@@ -24,7 +24,19 @@ using mindspore::lite::RET_OK;

namespace mindspore::kernel {
int BatchToSpaceCPUKernel::Init() {
return BatchToSpaceBaseCPUKernel::Init();
auto ret = BatchToSpaceBaseCPUKernel::Init();
if (ret != RET_OK) {
return ret;
}

if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}

int BatchToSpaceCPUKernel::ReSize() {
return BatchToSpaceBaseCPUKernel::ReSize();
}

int BatchToSpaceCPUKernel::Run() {


+1 -1  mindspore/lite/src/runtime/kernel/arm/fp32/batch_to_space.h

@@ -29,7 +29,7 @@ class BatchToSpaceCPUKernel : public BatchToSpaceBaseCPUKernel {
~BatchToSpaceCPUKernel() = default;

int Init() override;
int ReSize() override { return 0; }
int ReSize() override;
int Run() override;
};
} // namespace mindspore::kernel


+9 -1  mindspore/lite/src/runtime/kernel/arm/fp32/depth_to_space.cc

@@ -37,7 +37,15 @@ int DepthToSpaceCPUKernel::Init() {
}
DepthToSpaceParameter *param = reinterpret_cast<DepthToSpaceParameter *>(opParameter);
param->data_type_size_ = sizeof(float);
return RET_OK;
if (!InferShapeDone()) {
return RET_OK;
}

return ReSize();
}

int DepthToSpaceCPUKernel::ReSize() {
return DepthToSpaceBaseCPUKernel::ReSize();
}

int DepthToSpaceCPUKernel::Run() {


+1 -1  mindspore/lite/src/runtime/kernel/arm/fp32/depth_to_space.h

@@ -29,7 +29,7 @@ class DepthToSpaceCPUKernel : public DepthToSpaceBaseCPUKernel {
~DepthToSpaceCPUKernel() = default;

int Init() override;
int ReSize() override { return 0; }
int ReSize() override;
int Run() override;
};
} // namespace mindspore::kernel


+10 -3  mindspore/lite/src/runtime/kernel/arm/int8/argminmax_int8.cc

@@ -40,14 +40,21 @@ int ArgMinMaxInt8CPUKernel::Init() {
auto out_quant_args = out_tensor->GetQuantParams();
out_quant_arg_.scale_ = out_quant_args.front().scale;
out_quant_arg_.zp_ = out_quant_args.front().zeroPoint;
return RET_OK;
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}

int ArgMinMaxInt8CPUKernel::ReSize() {
return ArgMinMaxBaseCPUKernel::ReSize();
}

int ArgMinMaxInt8CPUKernel::Run() {
auto ret = Prepare();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Prepare failed.";
return RET_ERROR;
MS_LOG(ERROR) << "Prepare fail!ret: " << ret;
return ret;
}
auto input = inputs_.at(0);



+1 -1  mindspore/lite/src/runtime/kernel/arm/int8/argminmax_int8.h

@@ -31,7 +31,7 @@ class ArgMinMaxInt8CPUKernel : public ArgMinMaxBaseCPUKernel {
~ArgMinMaxInt8CPUKernel() = default;

int Init() override;
int ReSize() override { return 0; }
int ReSize() override;
int Run() override;
private:
QuantArg in_quant_arg_;


+8 -1  mindspore/lite/src/runtime/kernel/arm/int8/batch_to_space_int8.cc

@@ -38,7 +38,14 @@ int BatchToSpaceInt8CPUKernel::Init() {
auto out_quant_args = out_tensor->GetQuantParams();
out_quant_arg_.scale_ = out_quant_args.front().scale;
out_quant_arg_.zp_ = out_quant_args.front().zeroPoint;
return RET_OK;
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}

int BatchToSpaceInt8CPUKernel::ReSize() {
return BatchToSpaceBaseCPUKernel::ReSize();
}

int BatchToSpaceInt8CPUKernel::Run() {


+1 -1  mindspore/lite/src/runtime/kernel/arm/int8/batch_to_space_int8.h

@@ -30,7 +30,7 @@ class BatchToSpaceInt8CPUKernel : public BatchToSpaceBaseCPUKernel {
~BatchToSpaceInt8CPUKernel() = default;

int Init() override;
int ReSize() override { return 0; }
int ReSize() override;
int Run() override;
private:
QuantArg in_quant_arg_;


+8 -1  mindspore/lite/src/runtime/kernel/arm/int8/depth_to_space_int8.cc

@@ -42,7 +42,14 @@ int DepthToSpaceInt8CPUKernel::Init() {
auto out_quant_args = out_tensor->GetQuantParams();
out_quant_arg_.scale_ = out_quant_args.front().scale;
out_quant_arg_.zp_ = out_quant_args.front().zeroPoint;
return RET_OK;
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}

int DepthToSpaceInt8CPUKernel::ReSize() {
return DepthToSpaceBaseCPUKernel::ReSize();
}

int DepthToSpaceInt8CPUKernel::Run() {


+1 -1  mindspore/lite/src/runtime/kernel/arm/int8/depth_to_space_int8.h

@@ -30,7 +30,7 @@ class DepthToSpaceInt8CPUKernel : public DepthToSpaceBaseCPUKernel {
~DepthToSpaceInt8CPUKernel() = default;

int Init() override;
int ReSize() override { return 0; }
int ReSize() override;
int Run() override;
private:
QuantArg in_quant_arg_;

