Merge pull request !5472 from ling/conv1x1 (tags/v1.0.0)
| @@ -65,7 +65,7 @@ class MS_API Context { | |||||
| virtual ~Context(); | virtual ~Context(); | ||||
| public: | public: | ||||
| bool float16_priority = false; /**< allow priority select float16 kernel */ | |||||
| bool float16_priority = false; /**< prior enable float16 inference */ | |||||
| DeviceContext device_ctx_{DT_CPU}; | DeviceContext device_ctx_{DT_CPU}; | ||||
| int thread_num_ = 2; /**< thread number config for thread pool */ | int thread_num_ = 2; /**< thread number config for thread pool */ | ||||
| std::shared_ptr<Allocator> allocator = nullptr; | std::shared_ptr<Allocator> allocator = nullptr; | ||||
| @@ -50,10 +50,10 @@ class OptimizeModule { | |||||
| #ifdef ENABLE_ARM64 | #ifdef ENABLE_ARM64 | ||||
| if (hwcap & HWCAP_ASIMDDP) { | if (hwcap & HWCAP_ASIMDDP) { | ||||
| printf("Hw cap support SMID Dot Product, hwcap: 0x%x \n", hwcap); | |||||
| MS_LOG(INFO) << "Hw cap support SMID Dot Product, hwcap: 0x" << hwcap; | |||||
| support_optimize_ops = true; | support_optimize_ops = true; | ||||
| } else { | } else { | ||||
| printf("Hw cap NOT support SIMD Dot Product, hwcap: 0x%x\n", hwcap); | |||||
| MS_LOG(INFO) << "Hw cap NOT support SIMD Dot Product, hwcap: 0x" << hwcap; | |||||
| } | } | ||||
| #endif | #endif | ||||
| #endif | #endif | ||||
| @@ -63,7 +63,7 @@ class OptimizeModule { | |||||
| #ifndef _WIN32 | #ifndef _WIN32 | ||||
| optimized_op_handler_ = dlopen(OPTIMIZE_SHARED_LIBRARY_PATH, RTLD_LAZY); | optimized_op_handler_ = dlopen(OPTIMIZE_SHARED_LIBRARY_PATH, RTLD_LAZY); | ||||
| if (optimized_op_handler_ == nullptr) { | if (optimized_op_handler_ == nullptr) { | ||||
| printf("Open optimize shared library failed: %s\n", dlerror()); | |||||
| MS_LOG(INFO) << "Open optimize shared library failed: " << dlerror(); | |||||
| } | } | ||||
| #endif | #endif | ||||
| } | } | ||||
| @@ -178,7 +178,8 @@ void Convolution1x1FP16CPUKernel::Pre1x1Trans(float16_t *src_input, float16_t *s | |||||
| } | } | ||||
| int Convolution1x1FP16CPUKernel::RunImpl(int task_id) { | int Convolution1x1FP16CPUKernel::RunImpl(int task_id) { | ||||
| int cur_oc = MSMIN(thread_stride_, matmul_param_->col_ - task_id * thread_stride_); | |||||
| int cur_stride = matmul_param_->col_ - task_id * thread_stride_; | |||||
| int cur_oc = MSMIN(thread_stride_, cur_stride); | |||||
| if (cur_oc <= 0) { | if (cur_oc <= 0) { | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -148,8 +148,10 @@ static int DeConvFp16Run(void *cdata, int task_id) { | |||||
| } | } | ||||
| int DeConvolutionFp16CPUKernel::DoDeconv(int task_id) { | int DeConvolutionFp16CPUKernel::DoDeconv(int task_id) { | ||||
| int oc = MSMIN(thread_stride_, UP_DIV(conv_param_->output_channel_, C8NUM) - task_id * thread_stride_); | |||||
| int oc_res = MSMIN(thread_stride_ * C8NUM, conv_param_->output_channel_ - task_id * thread_stride_ * C8NUM); | |||||
| int cur_stride = UP_DIV(conv_param_->output_channel_, C8NUM) - task_id * thread_stride_; | |||||
| int oc = MSMIN(thread_stride_, cur_stride); | |||||
| cur_stride = conv_param_->output_channel_ - task_id * thread_stride_ * C8NUM; | |||||
| int oc_res = MSMIN(thread_stride_ * C8NUM, cur_stride); | |||||
| if (oc <= 0) { | if (oc <= 0) { | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -137,7 +137,8 @@ int Convolution1x1CPUKernel::Init() { | |||||
| } | } | ||||
| int Convolution1x1CPUKernel::DoConv1x1(int task_id) { | int Convolution1x1CPUKernel::DoConv1x1(int task_id) { | ||||
| int cur_oc = MSMIN(thread_stride_, matmul_param_->col_ - task_id * thread_stride_); | |||||
| int res_stride = matmul_param_->col_ - task_id * thread_stride_; | |||||
| int cur_oc = MSMIN(thread_stride_, res_stride); | |||||
| if (cur_oc <= 0) { | if (cur_oc <= 0) { | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -103,8 +103,11 @@ int DeConvFp32Run(void *cdata, int task_id) { | |||||
| } | } | ||||
| int DeConvolutionCPUKernel::DoDeconv(int task_id) { | int DeConvolutionCPUKernel::DoDeconv(int task_id) { | ||||
| int oc = MSMIN(thread_stride_, UP_DIV(conv_param_->output_channel_, C8NUM) - task_id * thread_stride_); | |||||
| int oc_res = MSMIN(thread_stride_ * C8NUM, conv_param_->output_channel_ - task_id * thread_stride_ * C8NUM); | |||||
| int res_stride = UP_DIV(conv_param_->output_channel_, C8NUM) - task_id * thread_stride_; | |||||
| int oc = MSMIN(thread_stride_, res_stride); | |||||
| int cur_stride = thread_stride_ * C8NUM; | |||||
| res_stride = conv_param_->output_channel_ - task_id * thread_stride_ * C8NUM; | |||||
| int oc_res = MSMIN(cur_stride, res_stride); | |||||
| if (oc <= 0 || oc_res <= 0) { | if (oc <= 0 || oc_res <= 0) { | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -272,7 +272,9 @@ int Convolution1x1Int8CPUKernel::RunImpl(int task_id) { | |||||
| } | } | ||||
| int Convolution1x1Int8CPUKernel::RunPre(int task_id) { | int Convolution1x1Int8CPUKernel::RunPre(int task_id) { | ||||
| int cur_hw = MSMIN(thread_stride_hw_ * C8NUM, matmul_param_->row_ - task_id * thread_stride_hw_ * C8NUM); | |||||
| int cur_stride = thread_stride_hw_ * C8NUM; | |||||
| int res_stride = matmul_param_->row_ - task_id * thread_stride_hw_ * C8NUM; | |||||
| int cur_hw = MSMIN(cur_stride, res_stride); | |||||
| if (cur_hw <= 0) { | if (cur_hw <= 0) { | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -227,8 +227,13 @@ int DeConvInt8Run(void *cdata, int task_id) { | |||||
| } | } | ||||
| int DeConvInt8CPUKernel::DoDeconv(int task_id) { | int DeConvInt8CPUKernel::DoDeconv(int task_id) { | ||||
| int cur_oc = MSMIN(thread_stride_, UP_DIV(conv_param_->output_channel_, C8NUM) - task_id * thread_stride_); | |||||
| int cur_oc_res = MSMIN(thread_stride_ * C4NUM, conv_param_->output_channel_ - task_id * thread_stride_ * C4NUM); | |||||
| int cur_stride = thread_stride_; | |||||
| int res_stride = UP_DIV(conv_param_->output_channel_, C8NUM) - task_id * thread_stride_; | |||||
| int cur_oc = MSMIN(cur_stride, res_stride); | |||||
| cur_stride = thread_stride_ * C4NUM; | |||||
| res_stride = conv_param_->output_channel_ - task_id * thread_stride_ * C4NUM; | |||||
| int cur_oc_res = MSMIN(cur_stride, res_stride); | |||||
| if (cur_oc <= 0) { | if (cur_oc <= 0) { | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||