From 9757a97d5b5e255e2e4cb1470b2c0e00f1099320 Mon Sep 17 00:00:00 2001 From: yefeng Date: Thu, 28 Jan 2021 10:50:03 +0800 Subject: [PATCH] batch_space --- mindspore/lite/src/ops/batch_to_space.cc | 127 ++++++++++++++---- mindspore/lite/src/ops/batch_to_space.h | 5 + mindspore/lite/src/ops/conv2d.cc | 3 + .../ops/populate/batch_to_space_populate.cc | 6 + .../populate/space_to_batch_nd_populate.cc | 6 + mindspore/lite/src/ops/space_to_batch_nd.cc | 110 ++++++++++++--- mindspore/lite/src/ops/space_to_batch_nd.h | 2 + .../kernel/arm/fp32/batch_to_space_fp32.cc | 59 ++++++-- .../kernel/arm/fp32/batch_to_space_fp32.h | 6 + .../kernel/arm/fp32/space_to_batch_fp32.cc | 38 ++++++ .../kernel/arm/fp32/space_to_batch_fp32.h | 1 + .../parser/tf/tf_batch_to_space_nd_parser.cc | 65 +++++++++ .../parser/tf/tf_batch_to_space_nd_parser.h | 36 +++++ .../parser/tf/tf_space_to_batch_nd_parser.cc | 64 +++++++++ .../parser/tf/tf_space_to_batch_nd_parser.h | 36 +++++ 15 files changed, 508 insertions(+), 56 deletions(-) create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_batch_to_space_nd_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_batch_to_space_nd_parser.h create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_space_to_batch_nd_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_space_to_batch_nd_parser.h diff --git a/mindspore/lite/src/ops/batch_to_space.cc b/mindspore/lite/src/ops/batch_to_space.cc index 8dc4b2d229..da7dcc3316 100644 --- a/mindspore/lite/src/ops/batch_to_space.cc +++ b/mindspore/lite/src/ops/batch_to_space.cc @@ -78,34 +78,19 @@ Registry BatchToSpaceRegistry(schema::PrimitiveType_BatchToSpace, BatchToSpaceCr namespace { constexpr int kBatchToSpaceOutputNum = 1; -constexpr int kBatchToSpaceInputNum = 1; +constexpr int kBatchToSpaceOneInput = 1; +constexpr int kBatchToSpaceThreeInput = 3; constexpr int kBlockShapeSize = 2; constexpr int kCropsSize = 4; } // namespace -int BatchToSpace::InferShape(std::vector inputs, std::vector outputs) { - MS_ASSERT(this->primitive_ != nullptr); - if (outputs.size() != kBatchToSpaceOutputNum || inputs.size() != kBatchToSpaceInputNum) { - MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs.size() << ",input size: " << inputs.size(); - return RET_PARAM_INVALID; - } - - auto input = inputs.at(0); - if (input->format() != schema::Format::Format_NHWC) { - MS_LOG(ERROR) << "batch_to_space only support NHWC now!"; - return RET_FORMAT_ERR; - } - outputs[0]->set_format(input->format()); - outputs[0]->set_data_type(input->data_type()); - if (!infer_flag()) { - return RET_INFER_INVALID; - } - auto input_shape = input->shape(); +int BatchToSpace::SetOutputShapeFromParam(const std::vector inputs, + std::vector outputs) { + auto input_shape = inputs[0]->shape(); if (input_shape.size() != kQuadrupleNum) { MS_LOG(ERROR) << "input shape dimension size should == " << kQuadrupleNum; return RET_PARAM_INVALID; } - auto block_shape = GetBlockShape(); if (block_shape.size() != kBlockShapeSize) { MS_LOG(ERROR) << "Block shape size should be " << kBlockShapeSize; @@ -116,7 +101,7 @@ int BatchToSpace::InferShape(std::vector inputs, std::vector inputs, std::vector inputs, std::vector output_shape(input_shape.size()); - output_shape[NHWC_N] = input_shape[NHWC_N] / mul_block_shape; + output_shape[NHWC_N] = input_shape[NHWC_N] / mul_block_shape_; output_shape[NHWC_H] = input_shape[NHWC_H] * block_shape[0] - crops[0] - crops[1]; output_shape[NHWC_W] = input_shape[NHWC_W] * block_shape[1] - crops[2] - crops[3]; - output_shape[NHWC_C] = input_shape[NHWC_C]; + if (input_shape.size() > 3) { + output_shape[NHWC_C] = input_shape[NHWC_C]; + } + outputs[0]->set_shape(output_shape); + return RET_OK; +} +int BatchToSpace::SetOutputShapeFromInput(const std::vector inputs, + std::vector outputs) { + auto input_shape = inputs[0]->shape(); + if (input_shape.size() != kQuadrupleNum) { + MS_LOG(ERROR) << "input shape dimension size should == " << kQuadrupleNum; + return RET_PARAM_INVALID; + } + auto block_shape_data = inputs[1]->data_c(); + auto crops_data = inputs[2]->data_c(); + auto block_shape = static_cast(block_shape_data); + auto crops = static_cast(crops_data); + if (inputs[1]->ElementsNum() != kBlockShapeSize) { + MS_LOG(ERROR) << "Block shape size should be " << kBlockShapeSize; + return RET_PARAM_INVALID; + } + if (inputs[2]->ElementsNum() != kCropsSize) { + MS_LOG(ERROR) << "Crops size should be " << kCropsSize; + return RET_PARAM_INVALID; + } + mul_block_shape_ = 1; + + for (size_t i = 0; i < kBlockShapeSize; ++i) { + if (block_shape[i] <= 0) { + MS_LOG(ERROR) << "Input block_shape should > 0!"; + return RET_PARAM_INVALID; + } + if (input_shape[NHWC_N] % block_shape[i]) { + MS_LOG(ERROR) << "Dimension n " << input_shape[NHWC_N] << " can not divide block_shape[" << i << "] " + << block_shape[i]; + return 1; + } + mul_block_shape_ *= block_shape[i]; + } + + if (input_shape[NHWC_N] < mul_block_shape_) { + MS_LOG(ERROR) << "Dimension n " << input_shape[NHWC_N] << " < product of block shape!"; + return RET_PARAM_INVALID; + } + for (size_t i = 0; i < kCropsSize; ++i) { + if (crops[i] < 0) { + MS_LOG(ERROR) << "Input crops should >= 0"; + return RET_PARAM_INVALID; + } + } + std::vector output_shape(input_shape.size()); + output_shape[NHWC_N] = input_shape[NHWC_N] / mul_block_shape_; + output_shape[NHWC_H] = input_shape[NHWC_H] * block_shape[0] - crops[0] - crops[1]; + output_shape[NHWC_W] = input_shape[NHWC_W] * block_shape[1] - crops[2] - crops[3]; + if (input_shape.size() > 3) { + output_shape[NHWC_C] = input_shape[NHWC_C]; + } outputs[0]->set_shape(output_shape); return RET_OK; } + +int BatchToSpace::InferShape(std::vector inputs, std::vector outputs) { + MS_ASSERT(this->primitive_ != nullptr); + if (outputs.size() != kBatchToSpaceOutputNum || + (inputs.size() != kBatchToSpaceOneInput && inputs.size() != kBatchToSpaceThreeInput)) { + MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs.size() << ",input size: " << inputs.size(); + return RET_PARAM_INVALID; + } + + auto input = inputs.at(0); + if (input->format() != schema::Format::Format_NHWC) { + MS_LOG(ERROR) << "batch_to_space only support NHWC now!"; + return RET_FORMAT_ERR; + } + outputs[0]->set_format(input->format()); + outputs[0]->set_data_type(input->data_type()); + if (!infer_flag()) { + return RET_INFER_INVALID; + } + + if (inputs.size() == kBatchToSpaceOneInput) { + auto ret = SetOutputShapeFromParam(inputs, outputs); + return ret; + } + if (inputs.size() == kBatchToSpaceThreeInput) { + if (inputs[0]->data_c() == nullptr) { + return RET_INFER_INVALID; + } + MS_ASSERT(inputs[1]->data_c() != nullptr); + MS_ASSERT(inputs[2]->data_c() != nullptr); + auto ret = SetOutputShapeFromInput(inputs, outputs); + return ret; + } + + return RET_OK; +} } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/batch_to_space.h b/mindspore/lite/src/ops/batch_to_space.h index ce3e5756e3..aa8a2433b8 100644 --- a/mindspore/lite/src/ops/batch_to_space.h +++ b/mindspore/lite/src/ops/batch_to_space.h @@ -40,6 +40,11 @@ class BatchToSpace : public PrimitiveC { int InferShape(std::vector inputs_, std::vector outputs_) override; std::vector GetBlockShape() const; std::vector GetCrops() const; + + private: + int SetOutputShapeFromParam(const std::vector inputs, std::vector outputs); + int SetOutputShapeFromInput(const std::vector inputs, std::vector outputs); + int mul_block_shape_; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/conv2d.cc b/mindspore/lite/src/ops/conv2d.cc index 3bb9fa81e6..66656db10f 100644 --- a/mindspore/lite/src/ops/conv2d.cc +++ b/mindspore/lite/src/ops/conv2d.cc @@ -396,6 +396,9 @@ int Conv2D::InferShape(std::vector inputs_, std::vector outp return RET_INFER_INVALID; } auto in_shape = input_tensor->shape(); + if (in_shape.size() == 0) { + return RET_INFER_INVALID; + } int input_h = in_shape.at(1); int input_w = in_shape.at(2); int output_w = 0, output_h = 0; diff --git a/mindspore/lite/src/ops/populate/batch_to_space_populate.cc b/mindspore/lite/src/ops/populate/batch_to_space_populate.cc index 5332e0cfd6..a3ae90ac9b 100644 --- a/mindspore/lite/src/ops/populate/batch_to_space_populate.cc +++ b/mindspore/lite/src/ops/populate/batch_to_space_populate.cc @@ -34,6 +34,9 @@ OpParameter *PopulateBatchToSpaceParameter(const mindspore::lite::PrimitiveC *pr batch_space_param->op_parameter_.type_ = primitive->Type(); auto param = reinterpret_cast(const_cast(primitive)); auto block_shape = param->GetBlockShape(); + if (block_shape.empty()) { + return reinterpret_cast(batch_space_param); + } if (block_shape.size() != BATCH_TO_SPACE_BLOCK_SHAPE_SIZE) { MS_LOG(ERROR) << "batch_to_space blockShape size should be " << BATCH_TO_SPACE_BLOCK_SHAPE_SIZE; free(batch_space_param); @@ -41,6 +44,9 @@ OpParameter *PopulateBatchToSpaceParameter(const mindspore::lite::PrimitiveC *pr } auto crops = param->GetCrops(); + if (crops.empty()) { + return reinterpret_cast(batch_space_param); + } if (crops.size() != COMM_SHAPE_SIZE) { MS_LOG(ERROR) << "batch_to_space crops size should be " << COMM_SHAPE_SIZE; free(batch_space_param); diff --git a/mindspore/lite/src/ops/populate/space_to_batch_nd_populate.cc b/mindspore/lite/src/ops/populate/space_to_batch_nd_populate.cc index 316602e728..0682c9fe31 100644 --- a/mindspore/lite/src/ops/populate/space_to_batch_nd_populate.cc +++ b/mindspore/lite/src/ops/populate/space_to_batch_nd_populate.cc @@ -31,6 +31,9 @@ OpParameter *PopulateSpaceToBatchNDParameter(const mindspore::lite::PrimitiveC * space_batch_param_nd->op_parameter_.type_ = primitive->Type(); auto block_sizes = ((mindspore::lite::SpaceToBatchND *)primitive)->GetBlockShape(); + if (block_sizes.empty()) { + return reinterpret_cast(space_batch_param_nd); + } space_batch_param_nd->m_ = block_sizes.size(); if (block_sizes.size() > std::numeric_limits::max() / sizeof(int)) { MS_LOG(ERROR) << "The value of block_sizes.size() is too big"; @@ -39,6 +42,9 @@ OpParameter *PopulateSpaceToBatchNDParameter(const mindspore::lite::PrimitiveC * } memcpy(space_batch_param_nd->block_sizes_, (block_sizes.data()), block_sizes.size() * sizeof(int)); auto paddings = ((mindspore::lite::SpaceToBatchND *)primitive)->GetPaddings(); + if (paddings.empty()) { + return reinterpret_cast(space_batch_param_nd); + } if (paddings.size() > std::numeric_limits::max() / sizeof(int)) { MS_LOG(ERROR) << "The value of paddings.size() is too big"; free(space_batch_param_nd); diff --git a/mindspore/lite/src/ops/space_to_batch_nd.cc b/mindspore/lite/src/ops/space_to_batch_nd.cc index fda4d4d2ae..ef2d33534e 100644 --- a/mindspore/lite/src/ops/space_to_batch_nd.cc +++ b/mindspore/lite/src/ops/space_to_batch_nd.cc @@ -26,7 +26,8 @@ namespace mindspore { namespace lite { namespace { constexpr int kSpaceToBatchNDOutputNum = 1; -constexpr int kSpaceToBatchNDInputNum = 1; +constexpr int kSpaceToBatchNDOneInput = 1; +constexpr int kSpaceToBatchNDThreeInput = 3; } // namespace #ifdef PRIMITIVE_WRITEABLE @@ -86,23 +87,9 @@ Registry SpaceToBatchNDRegistry(schema::PrimitiveType_SpaceToBatchND, SpaceToBat #endif // PRIMITIVE_WRITEABLE -int SpaceToBatchND::InferShape(std::vector inputs, std::vector outputs) { - if (outputs.size() != kSpaceToBatchNDOutputNum || inputs.size() != kSpaceToBatchNDInputNum) { - MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs.size() << ",input size: " << inputs.size(); - return 1; - } - - auto input = inputs.at(0); - if (input->format() != schema::Format::Format_NHWC) { - MS_LOG(ERROR) << "space_to_batch_nd only support NHWC now!"; - return RET_ERROR; - } - outputs.at(0)->set_data_type(input->data_type()); - outputs.at(0)->set_format(input->format()); - if (!infer_flag()) { - return RET_INFER_INVALID; - } - auto input_shape = input->shape(); +int SpaceToBatchND::SetOutputShapeFromParam(const std::vector inputs, + std::vector outputs) { + auto input_shape = inputs[0]->shape(); if (input_shape.size() != kQuadrupleNum) { MS_LOG(ERROR) << "input shape dimension size only support " << kQuadrupleNum << " now!"; return RET_ERROR; @@ -133,9 +120,94 @@ int SpaceToBatchND::InferShape(std::vector inputs, std::vector 3) { + output_shape.at(NHWC_C) = input_shape.at(NHWC_C); + } outputs.at(0)->set_shape(output_shape); return RET_OK; } + +int SpaceToBatchND::SetOutputShapeFromInput(const std::vector inputs, + std::vector outputs) { + auto input_shape = inputs[0]->shape(); + if (input_shape.size() != kQuadrupleNum) { + MS_LOG(ERROR) << "input shape dimension size only support " << kQuadrupleNum << " now!"; + return RET_ERROR; + } + MS_ASSERT(inputs[2]->ElementsNum() == 4); + auto block_shape_data = inputs[1]->data_c(); + auto block_shape = static_cast(block_shape_data); + auto padding_data = inputs[2]->data_c(); + auto padding = static_cast(padding_data); + int padding_left = 0; + int padding_right = 0; + int block_w = 1; + if (inputs[1]->ElementsNum() == 2) { + padding_left = padding[2]; + padding_right = padding[3]; + block_w = block_shape[1]; + } + std::vector output_shape(input_shape.size()); + if (block_shape[0] * block_w > std::numeric_limits::max() / input_shape.at(NHWC_N)) { + MS_LOG(ERROR) << "The value of block_shape.at(0) * block_w is too big"; + return RET_ERROR; + } + output_shape.at(NHWC_N) = input_shape.at(NHWC_N) * block_shape[0] * block_w; + if (padding[0] + padding[1] > std::numeric_limits::max() - input_shape.at(NHWC_H)) { + MS_LOG(ERROR) << "The value of padding.at(0) + padding.at(1) is too big"; + return RET_ERROR; + } + output_shape.at(NHWC_H) = (input_shape.at(NHWC_H) + padding[0] + padding[1]) / block_shape[0]; + if (padding_left + padding_right > std::numeric_limits::max() - input_shape.at(NHWC_W)) { + MS_LOG(ERROR) << "The value of padding_left + padding_right is too big"; + return RET_ERROR; + } + output_shape.at(NHWC_W) = (input_shape.at(NHWC_W) + padding_left + padding_right) / block_w; + if (input_shape.size() > 3) { + output_shape.at(NHWC_C) = input_shape.at(NHWC_C); + } + outputs.at(0)->set_shape(output_shape); + return RET_OK; +} + +int SpaceToBatchND::InferShape(std::vector inputs, std::vector outputs) { + if (outputs.size() != kSpaceToBatchNDOutputNum || + (inputs.size() != kSpaceToBatchNDOneInput && inputs.size() != kSpaceToBatchNDThreeInput)) { + MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs.size() << ",input size: " << inputs.size(); + return 1; + } + + auto input = inputs.at(0); + if (input->format() != schema::Format::Format_NHWC) { + MS_LOG(ERROR) << "space_to_batch_nd only support NHWC now!"; + return RET_ERROR; + } + outputs.at(0)->set_data_type(input->data_type()); + outputs.at(0)->set_format(input->format()); + if (!infer_flag()) { + return RET_INFER_INVALID; + } + + if (inputs.size() == kSpaceToBatchNDOneInput) { + auto ret = SetOutputShapeFromParam(inputs, outputs); + if (ret != RET_OK) { + MS_LOG(ERROR) << "SetOutputShapeFromParam failed"; + return ret; + } + } + if (inputs.size() == kSpaceToBatchNDThreeInput) { + if (inputs[0]->data_c() == nullptr) { + return RET_INFER_INVALID; + } + MS_ASSERT(inputs[1]->data_c() != nullptr); + MS_ASSERT(inputs[2]->data_c() != nullptr); + auto ret = SetOutputShapeFromInput(inputs, outputs); + if (ret != RET_OK) { + MS_LOG(ERROR) << "SetOutputShapeFromInput failed"; + return ret; + } + } + return RET_OK; +} } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/space_to_batch_nd.h b/mindspore/lite/src/ops/space_to_batch_nd.h index 3b92211990..f77780b909 100644 --- a/mindspore/lite/src/ops/space_to_batch_nd.h +++ b/mindspore/lite/src/ops/space_to_batch_nd.h @@ -39,6 +39,8 @@ class SpaceToBatchND : public PrimitiveC { std::vector GetBlockShape() const; std::vector GetPaddings() const; int InferShape(std::vector inputs, std::vector outputs) override; + int SetOutputShapeFromParam(const std::vector inputs, std::vector outputs); + int SetOutputShapeFromInput(const std::vector inputs, std::vector outputs); }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/batch_to_space_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/batch_to_space_fp32.cc index dea637ac76..b61e6c5acd 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/batch_to_space_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/batch_to_space_fp32.cc @@ -16,23 +16,46 @@ #include "src/runtime/kernel/arm/fp32/batch_to_space_fp32.h" #include "schema/model_generated.h" #include "src/kernel_registry.h" +#include "src/ops/batch_to_space.h" using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; using mindspore::schema::PrimitiveType_BatchToSpace; using mindspore::schema::PrimitiveType_BatchToSpaceND; namespace mindspore::kernel { +int BatchToSpaceCPUKernel::Processinput() { + MS_ASSERT(in_tensors_[1]->data_c() != nullptr); + MS_ASSERT(in_tensors_[2]->data_c() != nullptr); + auto block_shape_data = in_tensors_[1]->data_c(); + auto crops_data = in_tensors_[2]->data_c(); + auto block_shape = static_cast(block_shape_data); + auto crops = static_cast(crops_data); + for (int i = 0; i < BATCH_TO_SPACE_BLOCK_SHAPE_SIZE; ++i) { + block_shape_[i] = block_shape[i]; + } + no_crop_ = true; + for (int i = 0; i < COMM_SHAPE_SIZE; ++i) { + crops_[i] = crops[i]; + if (crops_[i] != 0) { + no_crop_ = false; + } + } + return RET_OK; +} + int BatchToSpaceCPUKernel::Init() { MS_ASSERT(in_tensors_.at(0)->format() == schema::Format::Format_NHWC); if (!InferShapeDone()) { - return lite::RET_OK; + return RET_OK; } return ReSize(); } int BatchToSpaceCPUKernel::ReSize() { MS_ASSERT(in_tensors_.at(0)->shape().size() == 4); - return lite::RET_OK; + return RET_OK; } int BatchToSpaceCPUKernel::Run() { @@ -42,17 +65,29 @@ int BatchToSpaceCPUKernel::Run() { float *output_data = reinterpret_cast(output->MutableData()); auto in_shape = input->shape(); auto out_shape = output->shape(); - BatchToSpaceParameter *param = reinterpret_cast(this->op_parameter_); - - if (param->no_crop_) { - BatchToSpaceNoCropForNHWC(input_data, output_data, in_shape.data(), out_shape[0], param->block_shape_, - sizeof(float)); - } else { - BatchToSpaceForNHWC(input_data, output_data, in_shape.data(), out_shape[0], param->block_shape_, param->crops_, - sizeof(float)); + if (in_tensors_.size() == 1) { + BatchToSpaceParameter *param = reinterpret_cast(this->op_parameter_); + if (param->no_crop_) { + BatchToSpaceNoCropForNHWC(input_data, output_data, in_shape.data(), out_shape[0], param->block_shape_, + sizeof(float)); + } else { + BatchToSpaceForNHWC(input_data, output_data, in_shape.data(), out_shape[0], param->block_shape_, param->crops_, + sizeof(float)); + } } - - return lite::RET_OK; + if (in_tensors_.size() == 3) { + auto ret = Processinput(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Processinput failed in BatchToSpace."; + return ret; + } + if (no_crop_) { + BatchToSpaceNoCropForNHWC(input_data, output_data, in_shape.data(), out_shape[0], block_shape_, sizeof(float)); + } else { + BatchToSpaceForNHWC(input_data, output_data, in_shape.data(), out_shape[0], block_shape_, crops_, sizeof(float)); + } + } + return RET_OK; } REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_BatchToSpace, LiteKernelCreator) diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/batch_to_space_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/batch_to_space_fp32.h index f3aa4d027b..46994b5278 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/batch_to_space_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/batch_to_space_fp32.h @@ -34,6 +34,12 @@ class BatchToSpaceCPUKernel : public LiteKernel { int Init() override; int ReSize() override; int Run() override; + int Processinput(); + + private: + int32_t block_shape_[BATCH_TO_SPACE_BLOCK_SHAPE_SIZE]; + int32_t crops_[COMM_SHAPE_SIZE]; + bool no_crop_; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/space_to_batch_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/space_to_batch_fp32.cc index e01f02e94c..d590ab7d9f 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/space_to_batch_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/space_to_batch_fp32.cc @@ -25,6 +25,32 @@ using mindspore::schema::PrimitiveType_SpaceToBatch; using mindspore::schema::PrimitiveType_SpaceToBatchND; namespace mindspore::kernel { +void SpaceToBatchCPUKernel::ProcessInput() { + MS_ASSERT(in_tensors_[1] != nullptr); + MS_ASSERT(in_tensors_[2] != nullptr); + auto input_tensor = in_tensors_.at(0); + MS_ASSERT(input_tensor); + auto output_tensor = out_tensors_.at(0); + MS_ASSERT(output_tensor); + MS_ASSERT(param_); + for (size_t i = 0; i < DIMENSION_4D; i++) { + param_->input_shape_[i] = input_tensor->shape().at(i); + param_->output_shape_[i] = output_tensor->shape().at(i); + } + ComputeStrides(param_->input_shape_, param_->in_stride_, DIMENSION_4D); + ComputeStrides(param_->output_shape_, param_->out_stride_, DIMENSION_4D); + auto block_shape_data = in_tensors_[1]->data_c(); + auto block_shape = static_cast(block_shape_data); + for (int i = 0; i < in_tensors_[1]->ElementsNum(); i++) { + param_->block_sizes_[i] = block_shape[i]; + } + auto padding_data = in_tensors_[2]->data_c(); + auto padding = static_cast(padding_data); + for (int i = 0; i < in_tensors_[2]->ElementsNum(); i++) { + param_->paddings_[i] = padding[i]; + } +} + int SpaceToBatchCPUKernel::Init() { if (!InferShapeDone()) { return RET_OK; @@ -39,6 +65,12 @@ int SpaceToBatchFp32Run(void *cdata, int task_id) { } int SpaceToBatchCPUKernel::ReSize() { + if (in_tensors_.size() == 3) { + if (in_tensors_[1] != nullptr && in_tensors_[1]->IsConst() && in_tensors_[2] != nullptr && + in_tensors_[2]->IsConst()) { + ProcessInput(); + } + } auto input_tensor = in_tensors_.at(0); MS_ASSERT(input_tensor); auto output_tensor = out_tensors_.at(0); @@ -61,8 +93,14 @@ void SpaceToBatchCPUKernel::DoRun(int task_id) { } int SpaceToBatchCPUKernel::Run() { + MS_ASSERT(in_tensors_[0] != nullptr); input_ptr_ = reinterpret_cast(in_tensors_.at(0)->data_c()); output_ptr_ = reinterpret_cast(out_tensors_.at(0)->data_c()); + if (in_tensors_.size() == 3) { + if (!in_tensors_[1]->IsConst() || !in_tensors_[2]->IsConst()) { + ProcessInput(); + } + } ParallelLaunch(this->context_->thread_pool_, SpaceToBatchFp32Run, this, op_parameter_->thread_num_); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/space_to_batch_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/space_to_batch_fp32.h index a778ef63ae..fa0a3ce895 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/space_to_batch_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/space_to_batch_fp32.h @@ -35,6 +35,7 @@ class SpaceToBatchCPUKernel : public LiteKernel { int Init() override; int ReSize() override; int Run() override; + void ProcessInput(); public: void DoRun(int task_id); diff --git a/mindspore/lite/tools/converter/parser/tf/tf_batch_to_space_nd_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_batch_to_space_nd_parser.cc new file mode 100644 index 0000000000..75f8db75c4 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_batch_to_space_nd_parser.cc @@ -0,0 +1,65 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_batch_to_space_nd_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFBatchToSpaceNDParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) { + MS_LOG(WARNING) << "TF BatchToSpaceNDParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "New PrimitiveT failed"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + + primitive->value.type = schema::PrimitiveType_BatchToSpace; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = 1; + for (int i = 0; i < tf_op.input_size(); ++i) { + auto status = AddOpInput(tf_op, i, inputs); + if (status != RET_OK) { + return status; + } + } + return RET_OK; +} +TFNodeRegistrar g_tfBatchToSpaceNDParser("BatchToSpaceND", new TFBatchToSpaceNDParser()); +TFNodeRegistrar g_tfBatchToSpaceParser("BatchToSpace", new TFBatchToSpaceNDParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_batch_to_space_nd_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_batch_to_space_nd_parser.h new file mode 100644 index 0000000000..c28f194088 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_batch_to_space_nd_parser.h @@ -0,0 +1,36 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_BATCH_TO_SPACE_ND_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_BATCH_TO_SPACE_ND_PARSER_H_ +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFBatchToSpaceNDParser : public TFNodeParser { + public: + TFBatchToSpaceNDParser() = default; + ~TFBatchToSpaceNDParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_BATCH_TO_SPACE_ND_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_space_to_batch_nd_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_space_to_batch_nd_parser.cc new file mode 100644 index 0000000000..89eeb0a067 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_space_to_batch_nd_parser.cc @@ -0,0 +1,64 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_space_to_batch_nd_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFSpaceToBatchNDParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) { + MS_LOG(WARNING) << "TF SpaceToBatchNDParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "New PrimitiveT failed"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + + primitive->value.type = schema::PrimitiveType_SpaceToBatchND; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = 1; + for (int i = 0; i < tf_op.input_size(); ++i) { + auto status = AddOpInput(tf_op, i, inputs); + if (status != RET_OK) { + return status; + } + } + return RET_OK; +} +TFNodeRegistrar g_tfSpaceToBatchNDParser("SpaceToBatchND", new TFSpaceToBatchNDParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_space_to_batch_nd_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_space_to_batch_nd_parser.h new file mode 100644 index 0000000000..339c76034a --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_space_to_batch_nd_parser.h @@ -0,0 +1,36 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SPACE_TO_BATCH_ND_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SPACE_TO_BATCH_ND_PARSER_H_ +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFSpaceToBatchNDParser : public TFNodeParser { + public: + TFSpaceToBatchNDParser() = default; + ~TFSpaceToBatchNDParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SPACE_TO_BATCH_ND_PARSER_H_