| @@ -100,6 +100,13 @@ int Convolution3x3CPUKernel::InitTmpBuffer() { | |||||
| const int k_plane = 16; | const int k_plane = 16; | ||||
| MS_ASSERT(ctx_->allocator != nullptr); | MS_ASSERT(ctx_->allocator != nullptr); | ||||
| size_t tile_buffer_size = thread_count_ * C12NUM * C16NUM * ic4 * C4NUM * sizeof(float); | |||||
| tile_buffer_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(tile_buffer_size)); | |||||
| if (tile_buffer_ == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc tile buffer failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| size_t block_unit_buffer_size = thread_count_ * k_plane * C4NUM * sizeof(float); | size_t block_unit_buffer_size = thread_count_ * k_plane * C4NUM * sizeof(float); | ||||
| block_unit_buffer_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(block_unit_buffer_size)); | block_unit_buffer_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(block_unit_buffer_size)); | ||||
| if (block_unit_buffer_ == nullptr) { | if (block_unit_buffer_ == nullptr) { | ||||
| @@ -171,10 +178,6 @@ int Convolution3x3CPUKernel::ReSize() { | |||||
| free(nhwc4_input_); | free(nhwc4_input_); | ||||
| nhwc4_input_ = nullptr; | nhwc4_input_ = nullptr; | ||||
| } | } | ||||
| if (tile_buffer_ != nullptr) { | |||||
| free(tile_buffer_); | |||||
| tile_buffer_ = nullptr; | |||||
| } | |||||
| ret = ConvolutionBaseCPUKernel::Init(); | ret = ConvolutionBaseCPUKernel::Init(); | ||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| @@ -192,13 +195,6 @@ int Convolution3x3CPUKernel::ReSize() { | |||||
| } | } | ||||
| memset(nhwc4_input_, 0, nhwc4_input_size); | memset(nhwc4_input_, 0, nhwc4_input_size); | ||||
| size_t tile_buffer_size = thread_count_ * C12NUM * C16NUM * iC4 * C4NUM * sizeof(float); | |||||
| tile_buffer_ = reinterpret_cast<float *>(malloc(tile_buffer_size)); | |||||
| if (tile_buffer_ == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc tile buffer failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| memset(tile_buffer_, 0, tile_buffer_size); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -33,10 +33,6 @@ class Convolution3x3CPUKernel : public ConvolutionBaseCPUKernel { | |||||
| if (transformed_filter_addr_ != nullptr) { | if (transformed_filter_addr_ != nullptr) { | ||||
| free(transformed_filter_addr_); | free(transformed_filter_addr_); | ||||
| } | } | ||||
| if (tile_buffer_ != nullptr) { | |||||
| free(tile_buffer_); | |||||
| tile_buffer_ = nullptr; | |||||
| } | |||||
| } | } | ||||
| int Init() override; | int Init() override; | ||||
| int ReSize() override; | int ReSize() override; | ||||
| @@ -49,6 +45,10 @@ class Convolution3x3CPUKernel : public ConvolutionBaseCPUKernel { | |||||
| private: | private: | ||||
| void FreeTmpBuffer() { | void FreeTmpBuffer() { | ||||
| if (tile_buffer_ != nullptr) { | |||||
| ctx_->allocator->Free(tile_buffer_); | |||||
| tile_buffer_ = nullptr; | |||||
| } | |||||
| if (block_unit_buffer_ != nullptr) { | if (block_unit_buffer_ != nullptr) { | ||||
| ctx_->allocator->Free(block_unit_buffer_); | ctx_->allocator->Free(block_unit_buffer_); | ||||
| block_unit_buffer_ = nullptr; | block_unit_buffer_ = nullptr; | ||||
| @@ -200,6 +200,13 @@ int ConvolutionWinogradCPUKernel::InitTmpBuffer() { | |||||
| int ic4 = UP_DIV(conv_param_->input_channel_, C4NUM); | int ic4 = UP_DIV(conv_param_->input_channel_, C4NUM); | ||||
| MS_ASSERT(ctx_->allocator != nullptr); | MS_ASSERT(ctx_->allocator != nullptr); | ||||
| size_t tile_buffer_size = thread_count_ * C12NUM * input_unit_ * input_unit_ * ic4 * C4NUM * sizeof(float); | |||||
| trans_input_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(tile_buffer_size)); | |||||
| if (trans_input_ == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc trans_input_ failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| gemm_out_ = reinterpret_cast<float *>( | gemm_out_ = reinterpret_cast<float *>( | ||||
| ctx_->allocator->Malloc(thread_count_ * C12NUM * input_unit_ * input_unit_ * oc8 * C8NUM * sizeof(float))); | ctx_->allocator->Malloc(thread_count_ * C12NUM * input_unit_ * input_unit_ * oc8 * C8NUM * sizeof(float))); | ||||
| if (gemm_out_ == nullptr) { | if (gemm_out_ == nullptr) { | ||||
| @@ -290,10 +297,6 @@ int ConvolutionWinogradCPUKernel::ReSize() { | |||||
| free(nhwc4_input_); | free(nhwc4_input_); | ||||
| nhwc4_input_ = nullptr; | nhwc4_input_ = nullptr; | ||||
| } | } | ||||
| if (trans_input_ != nullptr) { | |||||
| free(trans_input_); | |||||
| trans_input_ = nullptr; | |||||
| } | |||||
| ret = ConvolutionBaseCPUKernel::Init(); | ret = ConvolutionBaseCPUKernel::Init(); | ||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| @@ -316,14 +319,6 @@ int ConvolutionWinogradCPUKernel::ReSize() { | |||||
| } | } | ||||
| memset(nhwc4_input_, 0, nhwc4_input_size); | memset(nhwc4_input_, 0, nhwc4_input_size); | ||||
| size_t tile_buffer_size = thread_count_ * C12NUM * input_unit_ * input_unit_ * ic4 * C4NUM * sizeof(float); | |||||
| trans_input_ = reinterpret_cast<float *>(malloc(tile_buffer_size)); | |||||
| if (trans_input_ == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc trans_input_ failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| memset(trans_input_, 0, tile_buffer_size); | |||||
| ret = ConfigInputOutput(); | ret = ConfigInputOutput(); | ||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| MS_LOG(ERROR) << "ConfigInputOutput failed."; | MS_LOG(ERROR) << "ConfigInputOutput failed."; | ||||
| @@ -38,10 +38,6 @@ class ConvolutionWinogradCPUKernel : public ConvolutionBaseCPUKernel { | |||||
| delete trans_weight_; | delete trans_weight_; | ||||
| trans_weight_ = nullptr; | trans_weight_ = nullptr; | ||||
| } | } | ||||
| if (trans_input_ != nullptr) { | |||||
| free(trans_input_); | |||||
| trans_input_ = nullptr; | |||||
| } | |||||
| }; | }; | ||||
| int Init() override; | int Init() override; | ||||
| int ReSize() override; | int ReSize() override; | ||||
| @@ -55,6 +51,10 @@ class ConvolutionWinogradCPUKernel : public ConvolutionBaseCPUKernel { | |||||
| private: | private: | ||||
| void FreeTmpBuffer() { | void FreeTmpBuffer() { | ||||
| if (trans_input_ != nullptr) { | |||||
| ctx_->allocator->Free(trans_input_); | |||||
| trans_input_ = nullptr; | |||||
| } | |||||
| if (tmp_data_ != nullptr) { | if (tmp_data_ != nullptr) { | ||||
| ctx_->allocator->Free(tmp_data_); | ctx_->allocator->Free(tmp_data_); | ||||
| tmp_data_ = nullptr; | tmp_data_ = nullptr; | ||||