|
|
|
@@ -27,9 +27,14 @@ using mindspore::lite::RET_ERROR; |
|
|
|
using mindspore::lite::RET_FORMAT_ERR; |
|
|
|
using mindspore::lite::RET_OK; |
|
|
|
using mindspore::schema::PrimitiveType_BatchToSpace; |
|
|
|
using mindspore::schema::PrimitiveType_BatchToSpaceND; |
|
|
|
|
|
|
|
namespace mindspore::kernel { |
|
|
|
int BatchToSpaceBaseCPUKernel::Init() { |
|
|
|
if (in_tensors_[0]->GetFormat() != schema::Format::Format_NHWC) { |
|
|
|
MS_LOG(ERROR) << "batch_to_space only support NHWC now!"; |
|
|
|
return RET_FORMAT_ERR; |
|
|
|
} |
|
|
|
BatchToSpaceParameter *param = reinterpret_cast<BatchToSpaceParameter *>(this->op_parameter_); |
|
|
|
for (int i = 0; i < BATCH_TO_SPACE_CROPS_SIZE; ++i) { |
|
|
|
if (param->crops_[i] != 0) { |
|
|
|
@@ -40,9 +45,10 @@ int BatchToSpaceBaseCPUKernel::Init() { |
|
|
|
} |
|
|
|
|
|
|
|
int BatchToSpaceBaseCPUKernel::ReSize() { |
|
|
|
if (in_tensors_[0]->GetFormat() != schema::Format::Format_NHWC) { |
|
|
|
MS_LOG(ERROR) << "batch_to_space only support NHWC now!"; |
|
|
|
return RET_FORMAT_ERR; |
|
|
|
auto shape = in_tensors_[0]->shape(); |
|
|
|
if (shape.size() != 4) { |
|
|
|
MS_LOG(ERROR) << "Unsupport shape size: " << shape.size(); |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
@@ -52,7 +58,6 @@ kernel::LiteKernel *CpuBatchToSpaceInt8KernelCreator(const std::vector<lite::Ten |
|
|
|
OpParameter *op_parameter, const lite::Context *ctx, |
|
|
|
const kernel::KernelKey &desc, |
|
|
|
const mindspore::lite::PrimitiveC *primitive) { |
|
|
|
MS_ASSERT(desc.type == schema::PrimitiveType_BatchToSpace); |
|
|
|
if (op_parameter == nullptr) { |
|
|
|
MS_LOG(ERROR) << "Input op_parameter is nullptr!"; |
|
|
|
return nullptr; |
|
|
|
@@ -78,7 +83,6 @@ kernel::LiteKernel *CpuBatchToSpaceFp32KernelCreator(const std::vector<lite::Ten |
|
|
|
OpParameter *op_parameter, const lite::Context *ctx, |
|
|
|
const kernel::KernelKey &desc, |
|
|
|
const mindspore::lite::PrimitiveC *primitive) { |
|
|
|
MS_ASSERT(desc.type == schema::PrimitiveType_BatchToSpace); |
|
|
|
if (op_parameter == nullptr) { |
|
|
|
MS_LOG(ERROR) << "Input op_parameter is nullptr!"; |
|
|
|
return nullptr; |
|
|
|
@@ -100,5 +104,7 @@ kernel::LiteKernel *CpuBatchToSpaceFp32KernelCreator(const std::vector<lite::Ten |
|
|
|
} |
|
|
|
|
|
|
|
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_BatchToSpace, CpuBatchToSpaceInt8KernelCreator) |
|
|
|
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_BatchToSpaceND, CpuBatchToSpaceInt8KernelCreator) |
|
|
|
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_BatchToSpace, CpuBatchToSpaceFp32KernelCreator) |
|
|
|
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_BatchToSpaceND, CpuBatchToSpaceFp32KernelCreator) |
|
|
|
} // namespace mindspore::kernel |