@@ -50,3 +50,35 @@ void backwardAll(const float *restrict in, const float *restrict yt, const float
     }
   }
 }
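+// Pass 1 of the split batch-norm backward: accumulate the per-channel sums
+//   dbias[c]     = sum(dy),       dscale[c]       = sum(dy * x_hat),
+//   dxhat_sum[c] = sum(dx_hat),   dxhathat_sum[c] = sum(dx_hat * x_hat),
+// where x_hat = (x - mean) * invar and dx_hat = dy * scale.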
+void backwardP1(const float *restrict in, const float *restrict yt, const float *restrict mean,
+                const float *restrict invar, const float *restrict scale, int size, int ch, float *restrict dxhat_sum,
+                float *restrict dxhathat_sum, float *restrict dbias, float *restrict dscale) {
+  for (int i = 0; i < size; i++) {
+    for (int c = 0; c < ch; c++) {
+      int ix = i * ch + c;
+      dbias[c] += yt[ix];
+      // dscale
+      float x_hat = (in[ix] - mean[c]) * invar[c];
+      dscale[c] += (yt[ix] * x_hat);
+      // dx_1
+      float dx_hat = yt[ix] * scale[c];
+      dxhat_sum[c] += dx_hat;
+      dxhathat_sum[c] += dx_hat * x_hat;
+    }
+  }
+}
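+// Pass 2: with the pass-1 sums available, each input gradient is
+//   dx = invar / N * (N * dx_hat - dxhat_sum - x_hat * dxhathat_sum).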
+void backwardP2(const float *restrict in, const float *restrict yt, const float *restrict mean,
+                const float *restrict invar, const float *restrict scale, int size, int total_size, int ch,
+                const float *dxhat_sum, const float *dxhathat_sum, float *restrict dx) {
+  float N = (float)total_size;
+  for (int i = 0; i < size; i++) {
+    for (int c = 0; c < ch; c++) {
+      // dx_2
+      int ix = i * ch + c;
+      float x_hat = (in[ix] - mean[c]) * invar[c];
+      float dx_hat = yt[ix] * scale[c];
+      dx[ix] = 1.0f / N * (invar[c]) * (N * dx_hat - dxhat_sum[c] - x_hat * dxhathat_sum[c]);
+    }
+  }
+}
@@ -32,6 +32,10 @@ extern "C" {
 void var2Invar(float *save_var, int size, float eps);
 void backwardAll(const float *in, const float *yt, const float *mean, const float *invar, const float *scale, int size,
                  int ch, float *dxhat_sum, float *dxhathat_sum, float *dbias, float *dscale, float *dx);
+void backwardP1(const float *in, const float *yt, const float *mean, const float *invar, const float *scale, int size,
+                int ch, float *dxhat_sum, float *dxhathat_sum, float *dbias, float *dscale);
+void backwardP2(const float *in, const float *yt, const float *mean, const float *invar, const float *scale, int size,
+                int total_size, int ch, const float *dxhat_sum, const float *dxhathat_sum, float *dx);
 #ifdef __cplusplus
 }
 #endif
@@ -0,0 +1,379 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "nnacl/fp32_grad/convolution_grad_filter.h"
+#ifdef ENABLE_ARM
+#include <arm_neon.h>
+#endif
+#ifdef ENABLE_ARM
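+// Depthwise filter gradient for one kernel tap (k_idx = i_kh * k_w + i_kw):
+//   dw[c][i_kh][i_kw] = sum over batches and output positions of
+//                       x[b][ih][iw][c] * dy[b][oh][ow][c].
+// The ...Arm helpers below process 16/12/8/4/2 channels per iteration and
+// return the first channel index they did not cover.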
+static int FilterGrad16Arm(const float *x, const float *dy, int i_c, int k_idx, float *dw,
+                           const ConvParameter *conv_param) {
+  int in_h = conv_param->input_h_;
+  int in_w = conv_param->input_w_;
+  int k_h = conv_param->kernel_h_;
+  int k_w = conv_param->kernel_w_;
+  int batch = conv_param->output_batch_;
+  int out_ch = conv_param->output_channel_;
+  int in_ch = conv_param->input_channel_;
+  int out_h = conv_param->output_h_;
+  int out_w = conv_param->output_w_;
+  int m = out_h * out_w;
+  int x_size = in_h * in_w * in_ch;
+  int y_size = out_ch * out_h * out_w;
+  int k_spatial = k_w * k_h;
+  int i_kh = k_idx / k_w;
+  int i_kw = k_idx % k_w;
+  for (; i_c < (out_ch & ~15); i_c += 16) {
+    float32x4_t sum_03_4 = vdupq_n_f32(0.0f);
+    float32x4_t sum_47_4 = vdupq_n_f32(0.0f);
+    float32x4_t sum_9x_4 = vdupq_n_f32(0.0f);
+    float32x4_t sum_12x_4 = vdupq_n_f32(0.0f);
+    for (int b = 0; b < batch; ++b) {
+      const float *x_addr = &x[b * x_size];
+      const float *dy_addr = &dy[b * y_size];
+      for (int i = 0; i < m; i++) {
+        int idx = i;
+        int input_h = idx / out_w * conv_param->stride_h_;
+        int input_w = idx % out_w * conv_param->stride_w_;
+        int input_row = -conv_param->pad_u_ + i_kh + input_h;
+        int input_col = -conv_param->pad_l_ + i_kw + input_w;
+        if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) {
+          int offset_x = (input_row * in_w + input_col) * out_ch + i_c;
+          int offset_dy = idx * out_ch + i_c;
+          float32x4_t x_03_4 = vld1q_f32(x_addr + offset_x);
+          float32x4_t dy_03_4 = vld1q_f32(dy_addr + offset_dy);
+          sum_03_4 = vmlaq_f32(sum_03_4, x_03_4, dy_03_4);
+          float32x4_t x_47_4 = vld1q_f32(x_addr + offset_x + 4);
+          float32x4_t dy_47_4 = vld1q_f32(dy_addr + offset_dy + 4);
+          sum_47_4 = vmlaq_f32(sum_47_4, x_47_4, dy_47_4);
+          float32x4_t x_9x_4 = vld1q_f32(x_addr + offset_x + 8);
+          float32x4_t dy_9x_4 = vld1q_f32(dy_addr + offset_dy + 8);
+          sum_9x_4 = vmlaq_f32(sum_9x_4, x_9x_4, dy_9x_4);
+          float32x4_t x_12x_4 = vld1q_f32(x_addr + offset_x + 12);
+          float32x4_t dy_12x_4 = vld1q_f32(dy_addr + offset_dy + 12);
+          sum_12x_4 = vmlaq_f32(sum_12x_4, x_12x_4, dy_12x_4);
+        }
+      }
+    }
+    dw[(i_c + 0) * k_spatial + k_idx] = sum_03_4[0];
+    dw[(i_c + 1) * k_spatial + k_idx] = sum_03_4[1];
+    dw[(i_c + 2) * k_spatial + k_idx] = sum_03_4[2];
+    dw[(i_c + 3) * k_spatial + k_idx] = sum_03_4[3];
+    dw[(i_c + 4) * k_spatial + k_idx] = sum_47_4[0];
+    dw[(i_c + 5) * k_spatial + k_idx] = sum_47_4[1];
+    dw[(i_c + 6) * k_spatial + k_idx] = sum_47_4[2];
+    dw[(i_c + 7) * k_spatial + k_idx] = sum_47_4[3];
+    dw[(i_c + 8) * k_spatial + k_idx] = sum_9x_4[0];
+    dw[(i_c + 9) * k_spatial + k_idx] = sum_9x_4[1];
+    dw[(i_c + 10) * k_spatial + k_idx] = sum_9x_4[2];
+    dw[(i_c + 11) * k_spatial + k_idx] = sum_9x_4[3];
+    dw[(i_c + 12) * k_spatial + k_idx] = sum_12x_4[0];
+    dw[(i_c + 13) * k_spatial + k_idx] = sum_12x_4[1];
+    dw[(i_c + 14) * k_spatial + k_idx] = sum_12x_4[2];
+    dw[(i_c + 15) * k_spatial + k_idx] = sum_12x_4[3];
+  }
+  return i_c;
+}
+static int FilterGrad12Arm(const float *x, const float *dy, int i_c, int k_idx, float *dw,
+                           const ConvParameter *conv_param) {
+  int in_h = conv_param->input_h_;
+  int in_w = conv_param->input_w_;
+  int k_h = conv_param->kernel_h_;
+  int k_w = conv_param->kernel_w_;
+  int batch = conv_param->output_batch_;
+  int out_ch = conv_param->output_channel_;
+  int in_ch = conv_param->input_channel_;
+  int out_h = conv_param->output_h_;
+  int out_w = conv_param->output_w_;
+  int m = out_h * out_w;
+  int x_size = in_h * in_w * in_ch;
+  int y_size = out_ch * out_h * out_w;
+  int k_spatial = k_w * k_h;
+  int i_kh = k_idx / k_w;
+  int i_kw = k_idx % k_w;
+  if ((out_ch - i_c) >= 12) {
+    float32x4_t sum_03_4 = vdupq_n_f32(0.0f);
+    float32x4_t sum_47_4 = vdupq_n_f32(0.0f);
+    float32x4_t sum_9x_4 = vdupq_n_f32(0.0f);
+    for (int b = 0; b < batch; ++b) {
+      const float *x_addr = &x[b * x_size];
+      const float *dy_addr = &dy[b * y_size];
+      for (int i = 0; i < m; i++) {
+        int idx = i;
+        int input_h = idx / out_w * conv_param->stride_h_;
+        int input_w = idx % out_w * conv_param->stride_w_;
+        int input_row = -conv_param->pad_u_ + i_kh + input_h;
+        int input_col = -conv_param->pad_l_ + i_kw + input_w;
+        if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) {
+          int offset_x = (input_row * in_w + input_col) * out_ch + i_c;
+          int offset_dy = idx * out_ch + i_c;
+          float32x4_t x_03_4 = vld1q_f32(x_addr + offset_x);
+          float32x4_t dy_03_4 = vld1q_f32(dy_addr + offset_dy);
+          sum_03_4 = vmlaq_f32(sum_03_4, x_03_4, dy_03_4);
+          float32x4_t x_47_4 = vld1q_f32(x_addr + offset_x + 4);
+          float32x4_t dy_47_4 = vld1q_f32(dy_addr + offset_dy + 4);
+          sum_47_4 = vmlaq_f32(sum_47_4, x_47_4, dy_47_4);
+          float32x4_t x_9x_4 = vld1q_f32(x_addr + offset_x + 8);
+          float32x4_t dy_9x_4 = vld1q_f32(dy_addr + offset_dy + 8);
+          sum_9x_4 = vmlaq_f32(sum_9x_4, x_9x_4, dy_9x_4);
+        }
+      }
+    }
+    dw[(i_c + 0) * k_spatial + k_idx] = sum_03_4[0];
+    dw[(i_c + 1) * k_spatial + k_idx] = sum_03_4[1];
+    dw[(i_c + 2) * k_spatial + k_idx] = sum_03_4[2];
+    dw[(i_c + 3) * k_spatial + k_idx] = sum_03_4[3];
+    dw[(i_c + 4) * k_spatial + k_idx] = sum_47_4[0];
+    dw[(i_c + 5) * k_spatial + k_idx] = sum_47_4[1];
+    dw[(i_c + 6) * k_spatial + k_idx] = sum_47_4[2];
+    dw[(i_c + 7) * k_spatial + k_idx] = sum_47_4[3];
+    dw[(i_c + 8) * k_spatial + k_idx] = sum_9x_4[0];
+    dw[(i_c + 9) * k_spatial + k_idx] = sum_9x_4[1];
+    dw[(i_c + 10) * k_spatial + k_idx] = sum_9x_4[2];
+    dw[(i_c + 11) * k_spatial + k_idx] = sum_9x_4[3];
+    i_c += 12;
+  }
+  return i_c;
+}
+static int FilterGrad8Arm(const float *x, const float *dy, int i_c, int k_idx, float *dw,
+                          const ConvParameter *conv_param) {
+  int in_h = conv_param->input_h_;
+  int in_w = conv_param->input_w_;
+  int k_h = conv_param->kernel_h_;
+  int k_w = conv_param->kernel_w_;
+  int batch = conv_param->output_batch_;
+  int out_ch = conv_param->output_channel_;
+  int in_ch = conv_param->input_channel_;
+  int out_h = conv_param->output_h_;
+  int out_w = conv_param->output_w_;
+  int m = out_h * out_w;
+  int x_size = in_h * in_w * in_ch;
+  int y_size = out_ch * out_h * out_w;
+  int k_spatial = k_w * k_h;
+  int i_kh = k_idx / k_w;
+  int i_kw = k_idx % k_w;
+  if ((out_ch - i_c) >= 8) {
+    float32x4_t sum_03_4 = vdupq_n_f32(0.0f);
+    float32x4_t sum_47_4 = vdupq_n_f32(0.0f);
+    for (int b = 0; b < batch; ++b) {
+      const float *x_addr = &x[b * x_size];
+      const float *dy_addr = &dy[b * y_size];
+      for (int i = 0; i < m; i++) {
+        int idx = i;
+        int input_h = idx / out_w * conv_param->stride_h_;
+        int input_w = idx % out_w * conv_param->stride_w_;
+        int input_row = -conv_param->pad_u_ + i_kh + input_h;
+        int input_col = -conv_param->pad_l_ + i_kw + input_w;
+        if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) {
+          int offset_x = (input_row * in_w + input_col) * out_ch + i_c;
+          int offset_dy = idx * out_ch + i_c;
+          float32x4_t x_03_4 = vld1q_f32(x_addr + offset_x);
+          float32x4_t dy_03_4 = vld1q_f32(dy_addr + offset_dy);
+          sum_03_4 = vmlaq_f32(sum_03_4, x_03_4, dy_03_4);
+          float32x4_t x_47_4 = vld1q_f32(x_addr + offset_x + 4);
+          float32x4_t dy_47_4 = vld1q_f32(dy_addr + offset_dy + 4);
+          sum_47_4 = vmlaq_f32(sum_47_4, x_47_4, dy_47_4);
+        }
+      }
+    }
+    dw[(i_c + 0) * k_spatial + k_idx] = sum_03_4[0];
+    dw[(i_c + 1) * k_spatial + k_idx] = sum_03_4[1];
+    dw[(i_c + 2) * k_spatial + k_idx] = sum_03_4[2];
+    dw[(i_c + 3) * k_spatial + k_idx] = sum_03_4[3];
+    dw[(i_c + 4) * k_spatial + k_idx] = sum_47_4[0];
+    dw[(i_c + 5) * k_spatial + k_idx] = sum_47_4[1];
+    dw[(i_c + 6) * k_spatial + k_idx] = sum_47_4[2];
+    dw[(i_c + 7) * k_spatial + k_idx] = sum_47_4[3];
+    i_c += 8;
+  }
+  return i_c;
+}
+static int FilterGrad4Arm(const float *x, const float *dy, int i_c, int k_idx, float *dw,
+                          const ConvParameter *conv_param) {
+  int in_h = conv_param->input_h_;
+  int in_w = conv_param->input_w_;
+  int k_h = conv_param->kernel_h_;
+  int k_w = conv_param->kernel_w_;
+  int batch = conv_param->output_batch_;
+  int out_ch = conv_param->output_channel_;
+  int in_ch = conv_param->input_channel_;
+  int out_h = conv_param->output_h_;
+  int out_w = conv_param->output_w_;
+  int m = out_h * out_w;
+  int x_size = in_h * in_w * in_ch;
+  int y_size = out_ch * out_h * out_w;
+  int k_spatial = k_w * k_h;
+  int i_kh = k_idx / k_w;
+  int i_kw = k_idx % k_w;
+  if ((out_ch - i_c) >= 4) {
+    float32x4_t sum_4 = vdupq_n_f32(0.0f);
+    for (int b = 0; b < batch; ++b) {
+      const float *x_addr = &x[b * x_size];
+      const float *dy_addr = &dy[b * y_size];
+      for (int i = 0; i < m; i++) {
+        int idx = i;
+        int input_h = idx / out_w * conv_param->stride_h_;
+        int input_w = idx % out_w * conv_param->stride_w_;
+        int input_row = -conv_param->pad_u_ + i_kh + input_h;
+        int input_col = -conv_param->pad_l_ + i_kw + input_w;
+        if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) {
+          int offset_x = (input_row * in_w + input_col) * out_ch + i_c;
+          int offset_dy = idx * out_ch + i_c;
+          float32x4_t x_4 = vld1q_f32(x_addr + offset_x);
+          float32x4_t dy_4 = vld1q_f32(dy_addr + offset_dy);
+          sum_4 = vmlaq_f32(sum_4, x_4, dy_4);
+        }
+      }
+    }
+    dw[(i_c + 0) * k_spatial + k_idx] = sum_4[0];
+    dw[(i_c + 1) * k_spatial + k_idx] = sum_4[1];
+    dw[(i_c + 2) * k_spatial + k_idx] = sum_4[2];
+    dw[(i_c + 3) * k_spatial + k_idx] = sum_4[3];
+    i_c += 4;
+  }
+  return i_c;
+}
+static int Filtergrad2Arm(const float *x, const float *dy, int i_c, int k_idx, float *dw,
+                          const ConvParameter *conv_param) {
+  int in_h = conv_param->input_h_;
+  int in_w = conv_param->input_w_;
+  int k_h = conv_param->kernel_h_;
+  int k_w = conv_param->kernel_w_;
+  int batch = conv_param->output_batch_;
+  int out_ch = conv_param->output_channel_;
+  int in_ch = conv_param->input_channel_;
+  int out_h = conv_param->output_h_;
+  int out_w = conv_param->output_w_;
+  int m = out_h * out_w;
+  int x_size = in_h * in_w * in_ch;
+  int y_size = out_ch * out_h * out_w;
+  int k_spatial = k_w * k_h;
+  int i_kh = k_idx / k_w;
+  int i_kw = k_idx % k_w;
+  if ((out_ch - i_c) >= 2) {
+    float32x2_t sum_2 = vdup_n_f32(0.0f);
+    for (int b = 0; b < batch; ++b) {
+      const float *x_addr = &x[b * x_size];
+      const float *dy_addr = &dy[b * y_size];
+      for (int i = 0; i < m; i++) {
+        int idx = i;
+        int input_h = idx / out_w * conv_param->stride_h_;
+        int input_w = idx % out_w * conv_param->stride_w_;
+        int input_row = -conv_param->pad_u_ + i_kh + input_h;
+        int input_col = -conv_param->pad_l_ + i_kw + input_w;
+        if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) {
+          int offset_x = (input_row * in_w + input_col) * out_ch + i_c;
+          int offset_dy = idx * out_ch + i_c;
+          float32x2_t x_2 = vld1_f32(x_addr + offset_x);
+          float32x2_t dy_2 = vld1_f32(dy_addr + offset_dy);
+          sum_2 = vmla_f32(sum_2, x_2, dy_2);
+        }
+      }
+    }
+    dw[(i_c + 0) * k_spatial + k_idx] = sum_2[0];
+    dw[(i_c + 1) * k_spatial + k_idx] = sum_2[1];
+    i_c += 2;
+  }
+  return i_c;
+}
+#endif
+int ConvDwFilterGrad(const float *x, const float *dy, float *dw, int start, int count,
+                     const ConvParameter *conv_param) {
+  int in_h = conv_param->input_h_;
+  int in_w = conv_param->input_w_;
+  int k_h = conv_param->kernel_h_;
+  int k_w = conv_param->kernel_w_;
+  int batch = conv_param->output_batch_;
+  int out_ch = conv_param->output_channel_;
+  int in_ch = conv_param->input_channel_;
+  int out_h = conv_param->output_h_;
+  int out_w = conv_param->output_w_;
+  int m = out_h * out_w;
+  int x_size = in_h * in_w * in_ch;
+  int y_size = out_ch * out_h * out_w;
+  int k_spatial = k_w * k_h;
+  for (int i_k = 0; i_k < count; i_k++) {
+    int k_idx = start + i_k;
+    int i_kh = k_idx / k_w;
+    int i_kw = k_idx % k_w;
+    int i_c = 0;
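+    // Consume channels in ever-smaller NEON tiles (16/12/8/4/2 lanes); each
+    // helper advances i_c past the channels it handled, and the scalar loop
+    // below finishes whatever remains.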
+#ifdef ENABLE_ARM
+    i_c = FilterGrad16Arm(x, dy, i_c, k_idx, dw, conv_param);
+    i_c = FilterGrad12Arm(x, dy, i_c, k_idx, dw, conv_param);
+    i_c = FilterGrad8Arm(x, dy, i_c, k_idx, dw, conv_param);
+    i_c = FilterGrad4Arm(x, dy, i_c, k_idx, dw, conv_param);
+    i_c = Filtergrad2Arm(x, dy, i_c, k_idx, dw, conv_param);
+#endif
+    for (; i_c < out_ch; i_c++) {
+      float sum = 0;
+      for (int b = 0; b < batch; ++b) {
+        const float *x_addr = &x[b * x_size];
+        const float *dy_addr = &dy[b * y_size];
+        for (int i = 0; i < m; i++) {
+          int idx = i;
+          int input_h = idx / out_w * conv_param->stride_h_;
+          int input_w = idx % out_w * conv_param->stride_w_;
+          int input_row = -conv_param->pad_u_ + i_kh + input_h;
+          int input_col = -conv_param->pad_l_ + i_kw + input_w;
+          if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) {
+            int offset_x = (input_row * in_w + input_col) * out_ch + i_c;
+            int offset_dy = idx * out_ch + i_c;
+            sum += x_addr[offset_x] * dy_addr[offset_dy];
+          }
+        }
+      }
+      dw[i_c * k_spatial + k_idx] = sum;
+    }
+  }
+  return 0;
+}
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_NNACL_FP32_GRAD_CONVOLUTION_GRAD_FILTER_H_
+#define MINDSPORE_LITE_NNACL_FP32_GRAD_CONVOLUTION_GRAD_FILTER_H_
+#include <stddef.h>
+#include "nnacl/conv_parameter.h"
+#ifdef __cplusplus
+extern "C" {
+#endif
+int ConvDwFilterGrad(const float *x, const float *dy, float *dw, int start, int count, const ConvParameter *conv_param);
+#ifdef __cplusplus
+}
+#endif
+#endif  // MINDSPORE_LITE_NNACL_FP32_GRAD_CONVOLUTION_GRAD_FILTER_H_
@@ -18,6 +18,56 @@
 #include "nnacl/fp32_grad/pack_ext.h"
 #include "nnacl/pack.h"
+void RollingIm2ColPackDwUnitFp32(const float *in_data, const ConvParameter *conv_param, float *data_col_orig,
+                                 int real_cal_num, int start) {
+  const int pad_left = conv_param->pad_l_;
+  const int pad_up = conv_param->pad_u_;
+  const int stride_h = conv_param->stride_h_;
+  const int stride_w = conv_param->stride_w_;
+  const int dilation_h = conv_param->dilation_h_;
+  const int dilation_w = conv_param->dilation_w_;
+  const int kernel_h = conv_param->kernel_h_;
+  const int kernel_w = conv_param->kernel_w_;
+  const int in_height = conv_param->input_h_;
+  const int in_width = conv_param->input_w_;
+  const int output_w = conv_param->output_w_;
+  const int channels = conv_param->input_channel_;
+  const int stride = kernel_h * kernel_w;
+  int kernel_row, kernel_col;
+  for (int i = 0; i < real_cal_num; i++) {
+    int block_start = start + i;
+    int input_h = block_start / output_w * stride_h;
+    int input_w = block_start % output_w * stride_w;
+    float *data_col = data_col_orig + i * channels * stride;
+    for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
+      int input_row = -pad_up + kernel_row * dilation_h + input_h;
+      for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
+        int input_col = -pad_left + kernel_col * dilation_w + input_w;
+        if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) {
+          const int offset = (input_row * in_width + input_col) * channels;
+          for (int c = 0; c < channels; c++) {
+            data_col[c * stride] = in_data[offset + c];
+          }
+          data_col++;
+        } else {
+          for (int c = 0; c < channels; c++) {
+            data_col[c * stride] = 0;
+          }
+          data_col++;
+        }
+      }
+    }
+  }
+}
 void rolling_im2col_hwc(const float *in_data, float *data_col, const ConvParameter *conv_param, int real_cal_num,
                         int start) {
   const int pad_left = conv_param->pad_l_;
@@ -90,85 +140,6 @@ void RollingIm2ColPackUnitFp32(const float *input_data, const ConvParameter *con
   rolling_im2col_hwc(input_data, packed_input, conv_param, real_cal_num, block_index);
 }
-void im2row_hwc(const float *in_data, float *data_row, const ConvParameter *conv_param, bool transpose) {
-  const int pad_left = conv_param->pad_l_;
-  const int pad_up = conv_param->pad_u_;
-  const int stride_h = conv_param->stride_h_;
-  const int stride_w = conv_param->stride_w_;
-  const int dilation_h = conv_param->dilation_h_;
-  const int dilation_w = conv_param->dilation_w_;
-  const int kernel_h = conv_param->kernel_h_;
-  const int kernel_w = conv_param->kernel_w_;
-  const int in_height = (transpose) ? conv_param->output_h_ : conv_param->input_h_;
-  const int in_width = (transpose) ? conv_param->output_w_ : conv_param->input_w_;
-  const int output_h = (transpose) ? conv_param->input_h_ : conv_param->output_h_;
-  const int output_w = (transpose) ? conv_param->input_w_ : conv_param->output_w_;
-  const int tot_channels = (transpose) ? conv_param->output_channel_ : conv_param->input_channel_;
-  const int channels = tot_channels / conv_param->group_;
-  int channel, kernel_row, kernel_col, output_rows, output_col;
-  if (transpose) {
-    for (channel = 0; channel < channels; channel++) {
-      for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
-        for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
-          int input_row = -pad_up + kernel_row * dilation_h;
-          for (output_rows = output_h; output_rows; output_rows--) {
-            if (!((unsigned)(input_row) < (unsigned)(in_height))) {
-              for (output_col = output_w; output_col; output_col--) {
-                *(data_row++) = 0;
-              }
-            } else {
-              int input_col = -pad_left + kernel_col * dilation_w;
-              for (output_col = output_w; output_col; output_col--) {
-                if (((unsigned)(input_col) < (unsigned)(in_width))) {
-                  const int offset = (input_row * in_width + input_col) * tot_channels + channel;
-                  *(data_row++) = in_data[offset];
-                } else {
-                  *(data_row++) = 0;
-                }
-                input_col += stride_w;
-              }
-            }
-            input_row += stride_h;
-          }
-        }
-      }
-    }
-  } else {
-    for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
-      for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
-        for (channel = 0; channel < channels; channel++) {
-          int input_row = -pad_up + kernel_row * dilation_h;
-          for (output_rows = output_h; output_rows; output_rows--) {
-            if (!((unsigned)(input_row) < (unsigned)(in_height))) {
-              for (output_col = output_w; output_col; output_col--) {
-                *(data_row++) = 0;
-              }
-            } else {
-              int input_col = -pad_left + kernel_col * dilation_w;
-              for (output_col = output_w; output_col; output_col--) {
-                if (((unsigned)(input_col) < (unsigned)(in_width))) {
-                  const int offset = (input_row * in_width + input_col) * tot_channels + channel;
-                  *(data_row++) = in_data[offset];
-                } else {
-                  *(data_row++) = 0;
-                }
-                input_col += stride_w;
-              }
-            }
-            input_row += stride_h;
-          }
-        }
-      }
-    }
-  }
-}
 void rolling_im2row_hwc(const float *in_data, float *data_row, const ConvParameter *conv_param, int rows, int start) {
   const int pad_left = conv_param->pad_l_;
   const int pad_up = conv_param->pad_u_;
@@ -26,6 +26,9 @@ extern "C" {
 void RollingIm2ColPackUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input,
                                int real_cal_num, int block_index);
+void RollingIm2ColPackDwUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input,
+                                 int real_cal_num, int block_index);
 void rolling_im2col_hwc(const float *in_data, float *data_col, const ConvParameter *conv_param, int rows, int start);
 void rolling_im2row_hwc(const float *in_data, float *data_row, const ConvParameter *conv_param, int rows, int start);
 void rolling_col2im_hwc(const float *data_col, float *data_im, const ConvParameter *conv_param, int rows, int start);
@@ -18,7 +18,7 @@
 #include <float.h>
 #include "nnacl/fp32_grad/pooling_grad.h"
-void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id) {
+void AvgPoolingGrad(const float *input_ptr, float *output_ptr, int count, PoolingParameter *pooling_param) {
   int stride_w = pooling_param->stride_w_;
   int stride_h = pooling_param->stride_h_;
   int pad_w = pooling_param->pad_l_;
@@ -30,29 +30,58 @@ void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter
   int in_h = pooling_param->input_h_;
   int output_w = pooling_param->output_w_;
   int output_h = pooling_param->output_h_;
-  int output_batch = pooling_param->output_batch_;
-  memset(output_ptr, 0, in_h * in_w * channel * output_batch * sizeof(float));
-  float kk = (float)(win_h * win_w);
-  for (int ib = 0; ib < output_batch; ib++) {
+  const float kk = 1.0f / (float)(win_h * win_w);
+#ifdef ENABLE_ARM
+  const float32x4_t factor = vdupq_n_f32(kk);
+#endif
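+  // The memset of output_ptr moved out of this routine: the caller now zeroes
+  // the buffer and passes the number of batches to process in `count`. The
+  // per-window division is replaced by one reciprocal multiply.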
+  for (int ib = 0; ib < count; ib++) {
     float *out = &output_ptr[(ib * in_h * in_w * channel)];
     const float *inPtr = &input_ptr[(ib * output_h * output_w * channel)];
     // iterate over yt
     for (int yh = 0; yh < output_h; yh++) {
+      int over_h = pad_h - yh * stride_h;
+      int kh_s = MSMAX(0, over_h);
+      int kh_e = MSMIN(win_h, in_h + over_h);
       for (int yw = 0; yw < output_w; yw++) {
-        for (int ic = 0; ic < channel; ic++) {
+        int over_w = pad_w - yw * stride_w;
+        int kw_s = MSMAX(0, over_w);
+        int kw_e = MSMIN(win_w, in_w + over_w);
+        int ic = 0;
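+        // Lane-parallel path: four channels per step. `delta` holds dy scaled
+        // by 1/(win_h*win_w) and is accumulated into every input position
+        // covered by the clipped pooling window [kh_s,kh_e) x [kw_s,kw_e).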
+        for (; ic < channel - 4; ic += 4) {
           int idx = (yw + yh * output_w) * channel + ic;
-          float delta = inPtr[idx] / kk;
-          for (int kh = 0; kh < win_h; kh++) {
+#ifdef ENABLE_ARM
+          float32x4_t in = vld1q_f32(inPtr + idx);
+          float32x4_t delta = vmulq_f32(in, factor);
+#else
+          float delta[4] = {inPtr[idx], inPtr[idx + 1], inPtr[idx + 2], inPtr[idx + 3]};
+          for (int i = 0; i < 4; i++) delta[i] *= kk;
+#endif
+          for (int kh = kh_s; kh < kh_e; kh++) {
             int xh = yh * stride_h + kh - pad_h;
-            if ((xh < 0) || (xh >= in_h)) {
-              continue;
-            }
-            for (int kw = 0; kw < win_w; kw++) {
+            for (int kw = kw_s; kw < kw_e; kw++) {
               int xw = yw * stride_w + kw - pad_w;
-              if ((xw < 0) || (xw >= in_w)) {
-                continue;
+#ifdef ENABLE_ARM
+              float *out_vec = out + (xw + in_w * xh) * channel + ic;
+              float32x4_t outr = vld1q_f32(out + (xw + in_w * xh) * channel + ic);
+              float32x4_t outs = vaddq_f32(outr, delta);
+              vst1q_f32(out_vec, outs);
+#else
+              for (int i = 0; i < 4; i++) {
+                out[(xw + in_w * xh) * channel + ic + i] += ((float *)&delta)[i];
               }
+#endif
+            }
+          }
+        }
+        for (; ic < channel; ic++) {
+          int idx = (yw + yh * output_w) * channel + ic;
+          float delta = inPtr[idx] * kk;
+          for (int kh = kh_s; kh < kh_e; kh++) {
+            int xh = yh * stride_h + kh - pad_h;
+            for (int kw = kw_s; kw < kw_e; kw++) {
+              int xw = yw * stride_w + kw - pad_w;
               out[(xw + in_w * xh) * channel + ic] += delta;
             }
           }
@@ -62,8 +91,17 @@ void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter
   }
 }
-void MaxPoolingGrad(const float *input_ptr, const float *dx_ptr, const float *dy_ptr, float *output_ptr,
-                    PoolingParameter *pooling_param, int task_id) {
+#ifdef ENABLE_ARM
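+// Branchless per-lane argmax update: `res` flags the lanes where `in` beats
+// the running max; bit-select then takes the new index and value in exactly
+// those lanes.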
+static int32x4_t MaxIndex(float32x4_t in, float32x4_t *max, int32x4_t index, int32x4_t prev_index) {
+  uint32x4_t res = vcgtq_f32(in, *max);
+  int32x4_t m_index = vbslq_s32(res, index, prev_index);
+  *max = vbslq_f32(res, in, *max);
+  return m_index;
+}
+#endif
+void MaxPoolingGrad(const float *input_ptr, const float *dy_ptr, float *output_ptr, int output_batch,
+                    PoolingParameter *pooling_param) {
   int stride_w = pooling_param->stride_w_;
   int stride_h = pooling_param->stride_h_;
   int pad_w = pooling_param->pad_l_;
@@ -75,36 +113,71 @@ void MaxPoolingGrad(const float *input_ptr, const float *dx_ptr, const float *dy
   int in_h = pooling_param->input_h_;
   int output_w = pooling_param->output_w_;
   int output_h = pooling_param->output_h_;
-  int output_batch = pooling_param->output_batch_;
-  memset(output_ptr, 0, in_h * in_w * channel * output_batch * sizeof(float));
   for (int ib = 0; ib < output_batch; ib++) {
     float *out = &output_ptr[(ib * in_h * in_w * channel)];
-    const float *inPtr = (const float *)(&input_ptr[(ib * in_h * in_w * channel)]);
-    const float *dyPtr = (const float *)(&dy_ptr[(ib * output_h * output_w * channel)]);
+    const float *inPtr = &input_ptr[(ib * in_h * in_w * channel)];
+    const float *dyPtr = &dy_ptr[(ib * output_h * output_w * channel)];
     for (int yh = 0; yh < output_h; yh++) {
+      int over_h = pad_h - yh * stride_h;
+      int kh_s = MSMAX(0, over_h);
+      int kh_e = MSMIN(win_h, in_h + over_h);
       for (int yw = 0; yw < output_w; yw++) {
-        for (int ic = 0; ic < channel; ic++) {
+        int over_w = pad_w - yw * stride_w;
+        int kw_s = MSMAX(0, over_w);
+        int kw_e = MSMIN(win_w, in_w + over_w);
+        int ic = 0;
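+        // Lane-parallel path: four channels per step. Track the running max
+        // and its flattened input index per lane over the clipped window,
+        // then scatter dy into each lane's argmax position.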
+        for (; ic < channel - 4; ic += 4) {
           int idx = (yw + yh * output_w) * channel + ic;
-          float delta = dyPtr[idx];
-          float max_val = -FLT_MAX;
-          int max_idx = 0;
-          for (int kh = 0; kh < win_h; kh++) {
+#ifdef ENABLE_ARM
+          int32x4_t max_idx = vdupq_n_s32(0);
+          float32x4_t max_val = vdupq_n_f32(-FLT_MAX);
+          float32x4_t delta = vld1q_f32(dyPtr + idx);
+#else
+          float delta[4] = {dyPtr[idx], dyPtr[idx + 1], dyPtr[idx + 2], dyPtr[idx + 3]};
+          float max_val[4] = {-FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX};
+          int max_idx[4] = {0};
+#endif
+          for (int kh = kh_s; kh < kh_e; kh++) {
             int xh = yh * stride_h + kh - pad_h;
-            if ((xh < 0) || (xh >= in_h)) {
-              continue;
-            }
-            for (int kw = 0; kw < win_w; kw++) {
+            for (int kw = kw_s; kw < kw_e; kw++) {
               int xw = yw * stride_w + kw - pad_w;
-              if ((xw < 0) || (xw >= in_w)) {
-                continue;
-              }
-              if (inPtr[(xw + in_w * xh) * channel + ic] > max_val) {
-                max_val = inPtr[(xw + in_w * xh) * channel + ic];
-                max_idx = (xw + in_w * xh) * channel + ic;
-              }
+              int val_idx = (xw + in_w * xh) * channel + ic;
+#ifdef ENABLE_ARM
+              int32_t val_idx_vec[] = {val_idx, val_idx + 1, val_idx + 2, val_idx + 3};
+              int32x4_t index = vld1q_s32(val_idx_vec);
+              float32x4_t in = vld1q_f32(inPtr + val_idx);
+              max_idx = MaxIndex(in, &max_val, index, max_idx);
+#else
+              float val[4] = {inPtr[val_idx], inPtr[val_idx + 1], inPtr[val_idx + 2], inPtr[val_idx + 3]};
+              for (int i = 0; i < 4; i++) {
+                if (val[i] > max_val[i]) {
+                  max_val[i] = val[i];
+                  max_idx[i] = val_idx + i;
+                }
+              }
+#endif
+            }
+          }
+          for (int i = 0; i < 4; i++) {
+            out[((int *)&max_idx)[i]] += ((float *)&delta)[i];
+          }
+        }
+        for (; ic < channel; ic++) {
+          float max_val = -FLT_MAX;
+          int max_idx = 0;
+          int idx = (yw + yh * output_w) * channel + ic;
+          float delta = dyPtr[idx];
+          for (int kh = kh_s; kh < kh_e; kh++) {
+            int xh = yh * stride_h + kh - pad_h;
+            int loop = kw_e - kw_s;
+            for (int kw = 0; kw < loop; kw++) {
+              int xw = yw * stride_w + kw + kw_s - pad_w;
+              int val_idx = (xw + in_w * xh) * channel + ic;
+              float val = inPtr[val_idx];
+              if (val > max_val) {
+                max_val = val;
+                max_idx = val_idx;
              }
            }
          }
+          out[max_idx] += delta;
@@ -22,9 +22,9 @@
 #ifdef __cplusplus
 extern "C" {
 #endif
-void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id);
-void MaxPoolingGrad(const float *input_ptr, const float *dx_ptr, const float *dy_ptr, float *output_ptr,
-                    PoolingParameter *pooling_param, int task_id);
+void AvgPoolingGrad(const float *input_ptr, float *output_ptr, int count, PoolingParameter *pooling_param);
+void MaxPoolingGrad(const float *input_ptr, const float *dy_ptr, float *output_ptr, int output_batch,
+                    PoolingParameter *pooling_param);
 #ifdef __cplusplus
 }
 #endif
@@ -0,0 +1,61 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "nnacl/fp32_grad/strided_slice_grad.h"
+#include "nnacl/errorcode.h"
+static size_t CalcIndex(const int *shape, size_t size, int i, size_t pos) {
+  size_t res = 1;
+  for (size_t j = 0; j < size; j++) {
+    res *= shape[(i + 1) + j];
+  }
+  return (pos / res % shape[i]);
+}
+int DoStridedSliceGrad(const float *inputs, float *output, const int *dx_shape, StridedSliceParameter *param) {
+  if (inputs == NULL || output == NULL || param == NULL) {
+    return NNACL_NULL_PTR;
+  }
+  if (param->num_axes_ > DIMENSION_7D) {
+    return NNACL_PARAM_INVALID;
+  }
+  size_t size = 1;
+  int *s = param->strides_;
+  int *b = param->begins_;
+  for (int i = 0; i < DIMENSION_7D; i++) {
+    size *= param->in_shape_[i];
+  }
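+  // Scatter: each dense gradient element at `pos` is written back to its
+  // strided location in dx, mapping coordinate c on axis a to
+  // c * strides[a] + begins[a].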
+  for (size_t pos = 0; pos < size; pos++) {
+    size_t i = CalcIndex(param->in_shape_, 6, 0, pos);
+    size_t j = CalcIndex(param->in_shape_, 5, 1, pos);
+    size_t k = CalcIndex(param->in_shape_, 4, 2, pos);
+    size_t l = CalcIndex(param->in_shape_, 3, 3, pos);
+    size_t m = CalcIndex(param->in_shape_, 2, 4, pos);
+    size_t n = CalcIndex(param->in_shape_, 1, 5, pos);
+    size_t o = CalcIndex(param->in_shape_, 0, 6, pos);
+    size_t input_idx =
+      (i * s[0] + b[0]) * dx_shape[1] * dx_shape[2] * dx_shape[3] * dx_shape[4] * dx_shape[5] * dx_shape[6] +
+      (j * s[1] + b[1]) * dx_shape[2] * dx_shape[3] * dx_shape[4] * dx_shape[5] * dx_shape[6] +
+      (k * s[2] + b[2]) * dx_shape[3] * dx_shape[4] * dx_shape[5] * dx_shape[6] +
+      (l * s[3] + b[3]) * dx_shape[4] * dx_shape[5] * dx_shape[6] + (m * s[4] + b[4]) * dx_shape[5] * dx_shape[6] +
+      (n * s[5] + b[5]) * dx_shape[6] + (o * s[6] + b[6]);
+    output[input_idx] = inputs[pos];
+  }
+  return NNACL_OK;
+}
@@ -0,0 +1,30 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_NNACL_FP32_GRAD_STRIDED_SLICE_GRAD_H_
+#define MINDSPORE_LITE_NNACL_FP32_GRAD_STRIDED_SLICE_GRAD_H_
+#include "nnacl/op_base.h"
+#include "nnacl/strided_slice_parameter.h"
+#ifdef __cplusplus
+extern "C" {
+#endif
+int DoStridedSliceGrad(const float *inputs, float *output, const int *dx_shape, StridedSliceParameter *param);
+#ifdef __cplusplus
+}
+#endif
+#endif  // MINDSPORE_LITE_NNACL_FP32_GRAD_STRIDED_SLICE_GRAD_H_
@@ -53,6 +53,7 @@
 #define DIMENSION_4D 4
 #define DIMENSION_6D 6
+#define DIMENSION_7D 7
 #define kInputIndex 0
 #define kWeightIndex 1
 #define kBiasIndex 2
@@ -273,6 +273,7 @@ union PrimitiveType {
   RandomStandardNormal,
   CropAndResize,
   Erf,
+  StridedSliceGrad
 }
 enum QuantType: int {
@@ -1259,6 +1259,18 @@ table RandomStandardNormal {
 table CropAndResize {
   method : ResizeMethod;
   extrapolation_value : float;
+}
+table StridedSliceGrad {
+  beginMask: int;
+  endMask: int;
+  ellipsisMask: int;
+  newAxisMask: int;
+  shrinkAxisMask: int;
+  begin: [int];
+  end: [int];
+  stride: [int];
+  isScale: [int];
 }
 table Erf {
@@ -31,7 +31,7 @@ int FlattenGrad::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *>
     MS_LOG(ERROR) << "FlattenGrad input or output is null!";
     return RET_ERROR;
   }
-  if (inputs_.size() != kSingleNum || outputs_.size() != kSingleNum) {
+  if (inputs_.size() != kDoubleNum || outputs_.size() != kSingleNum) {
     MS_LOG(ERROR) << "input size: " << inputs_.size() << ", output size: " << outputs_.size();
     return RET_INPUT_TENSOR_ERROR;
   }
@@ -42,16 +42,15 @@ int FlattenGrad::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *>
     return RET_INFER_INVALID;
   }
-  auto input_shape = input->shape();
-  std::vector<int> output_shape(2);
-  output_shape.at(0) = input_shape.at(0);
-  output_shape.at(1) = 1;
-  for (size_t i = 1; i < input_shape.size(); i++) {
-    output_shape.at(1) *= input_shape.at(i);
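+  // The output shape is no longer recomputed from the input; it is read from
+  // the second input tensor, which carries the shape dx must take.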
+  auto output_size = inputs_.at(1)->shape().at(0);
+  std::vector<int> output_shape(output_size);
+  for (int i = 0; i < output_size; i++) {
+    output_shape.at(i) = static_cast<int *>(inputs_.at(1)->data_c())[i];
   }
   output->set_shape(output_shape);
   return RET_OK;
 }
 #ifdef PRIMITIVE_WRITEABLE
 int FlattenGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
   if (this->primitive_ == nullptr) {
@@ -91,6 +91,8 @@ int PoolingGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr>
     attr->poolingMode = schema::PoolMode_MEAN_POOLING;
   } else if (prim.instance_name() == "AvgPoolGradGpu") {
     attr->poolingMode = schema::PoolMode_MEAN_POOLING;
+  } else if (prim.instance_name() == "AvgPoolGradCpu") {
+    attr->poolingMode = schema::PoolMode_MEAN_POOLING;
   } else {
     attr->poolingMode = schema::PoolMode_MAX_POOLING;
   }
@@ -202,6 +202,7 @@
 #include "src/ops/smooth_l1_loss_grad.h"
 #include "src/ops/sigmoid_cross_entropy_with_logits.h"
 #include "src/ops/sigmoid_cross_entropy_with_logits_grad.h"
+#include "src/ops/strided_slice_grad.h"
 #endif
 #endif
 namespace mindspore {
@@ -724,6 +725,8 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
     return NewPrimitiveC<SigmoidCrossEntropyWithLogitsGrad>(prim, inputs, quantType);
   } else if (op_type == "Pad") {
     return NewPrimitiveC<Pad>(prim, inputs, quantType);
+  } else if (op_type == "StridedSliceGrad") {
+    return NewPrimitiveC<StridedSliceGrad>(prim, inputs, quantType);
 #else
   } else if (op_type == "Conv2DBackpropInput") {
     return NewPrimitiveC<DeConv2D>(prim, inputs, quantType);
@@ -1102,6 +1105,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
       return new (std::nothrow) SigmoidCrossEntropyWithLogits(primitive);
     case schema::PrimitiveType_SigmoidCrossEntropyWithLogitsGrad:
       return new (std::nothrow) SigmoidCrossEntropyWithLogitsGrad(primitive);
+    case schema::PrimitiveType_StridedSliceGrad:
+      return new (std::nothrow) StridedSliceGrad(primitive);
 #endif
     default:
       MS_LOG(ERROR) << "Unsupported primitive type in Create : " << schema::EnumNamePrimitiveType(op_type);
@@ -0,0 +1,266 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "src/ops/strided_slice_grad.h"
+#ifndef PRIMITIVE_WRITEABLE
+#include "src/ops/ops_register.h"
+#endif
+namespace mindspore {
+namespace lite {
+#ifdef PRIMITIVE_WRITEABLE
+int StridedSliceGrad::GetBeginMask() const { return this->primitive_->value.AsStridedSliceGrad()->beginMask; }
+int StridedSliceGrad::GetEndMask() const { return this->primitive_->value.AsStridedSliceGrad()->endMask; }
+int StridedSliceGrad::GetEllipsisMask() const { return this->primitive_->value.AsStridedSliceGrad()->ellipsisMask; }
+int StridedSliceGrad::GetNewAxisMask() const { return this->primitive_->value.AsStridedSliceGrad()->newAxisMask; }
+int StridedSliceGrad::GetShrinkAxisMask() const { return this->primitive_->value.AsStridedSliceGrad()->shrinkAxisMask; }
+std::vector<int> StridedSliceGrad::GetBegin() const { return this->primitive_->value.AsStridedSliceGrad()->begin; }
+std::vector<int> StridedSliceGrad::GetEnd() const { return this->primitive_->value.AsStridedSliceGrad()->end; }
+std::vector<int> StridedSliceGrad::GetStride() const { return this->primitive_->value.AsStridedSliceGrad()->stride; }
+std::vector<int> StridedSliceGrad::GetIsScale() const { return this->primitive_->value.AsStridedSliceGrad()->isScale; }
+void StridedSliceGrad::SetBeginMask(int begin_mask) {
+  this->primitive_->value.AsStridedSliceGrad()->beginMask = begin_mask;
+}
+void StridedSliceGrad::SetEndMask(int end_mask) { this->primitive_->value.AsStridedSliceGrad()->endMask = end_mask; }
+void StridedSliceGrad::SetEllipsisMask(int ellipsis_mask) {
+  this->primitive_->value.AsStridedSliceGrad()->ellipsisMask = ellipsis_mask;
+}
+void StridedSliceGrad::SetNewAxisMask(int new_axis_mask) {
+  this->primitive_->value.AsStridedSliceGrad()->newAxisMask = new_axis_mask;
+}
+void StridedSliceGrad::SetShrinkAxisMask(int shrink_axis_mask) {
+  this->primitive_->value.AsStridedSliceGrad()->shrinkAxisMask = shrink_axis_mask;
+}
+void StridedSliceGrad::SetBegin(const std::vector<int> &begin) {
+  this->primitive_->value.AsStridedSliceGrad()->begin = begin;
+}
+void StridedSliceGrad::SetEnd(const std::vector<int> &end) { this->primitive_->value.AsStridedSliceGrad()->end = end; }
+void StridedSliceGrad::SetStride(const std::vector<int> &stride) {
+  this->primitive_->value.AsStridedSliceGrad()->stride = stride;
+}
+void StridedSliceGrad::SetIsScale(const std::vector<int> &is_scale) {
+  this->primitive_->value.AsStridedSliceGrad()->isScale = is_scale;
+}
+int StridedSliceGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
+  if (this->primitive_ == nullptr) {
+    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
+    if (this->primitive_ == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT failed";
+      return RET_ERROR;
+    }
+    this->primitive_->value.type = schema::PrimitiveType_StridedSliceGrad;
+  }
+  if (this->primitive_->value.type != schema::PrimitiveType_StridedSliceGrad) {
+    MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type;
+    return RET_ERROR;
+  }
+  if (this->primitive_->value.value == nullptr) {
+    auto attr = new (std::nothrow) schema::StridedSliceGradT();
+    if (attr == nullptr) {
+      MS_LOG(ERROR) << "new StridedSliceGrad failed";
+      return RET_ERROR;
+    }
+    attr->beginMask = CastToInt(prim.GetAttr("begin_mask")).front();
+    attr->endMask = CastToInt(prim.GetAttr("end_mask")).front();
+    attr->ellipsisMask = CastToInt(prim.GetAttr("ellipsis_mask")).front();
+    attr->newAxisMask = CastToInt(prim.GetAttr("new_axis_mask")).front();
+    attr->shrinkAxisMask = CastToInt(prim.GetAttr("shrink_axis_mask")).front();
+    auto inputNodeFirst = inputs[kAnfPopulaterInputNumOne];
+    std::vector<int> beginVec;
+    GetAttrDataFromInput(inputNodeFirst, &beginVec);
+    attr->begin = beginVec;
+    auto inputNodeSecond = inputs[kAnfPopulaterInputNumTwo];
+    std::vector<int> endVec;
+    GetAttrDataFromInput(inputNodeSecond, &endVec);
+    attr->end = endVec;
+    auto inputNodeThird = inputs[kAnfPopulaterInputNumThree];
+    std::vector<int> strideVec;
+    GetAttrDataFromInput(inputNodeThird, &strideVec);
+    attr->stride = strideVec;
+    this->primitive_->value.value = attr;
+    if (this->primitive_->value.value == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT value failed";
+      return RET_ERROR;
+    }
+  }
+  return RET_OK;
+}
+#else
+int StridedSliceGrad::GetBeginMask() const { return this->primitive_->value_as_StridedSliceGrad()->beginMask(); }
+int StridedSliceGrad::GetEndMask() const { return this->primitive_->value_as_StridedSliceGrad()->endMask(); }
+int StridedSliceGrad::GetEllipsisMask() const { return this->primitive_->value_as_StridedSliceGrad()->ellipsisMask(); }
+int StridedSliceGrad::GetNewAxisMask() const { return this->primitive_->value_as_StridedSliceGrad()->newAxisMask(); }
+int StridedSliceGrad::GetShrinkAxisMask() const {
+  return this->primitive_->value_as_StridedSliceGrad()->shrinkAxisMask();
+}
+std::vector<int> StridedSliceGrad::GetBegin() const {
+  auto fb_vector = this->primitive_->value_as_StridedSliceGrad()->begin();
+  return std::vector<int>(fb_vector->begin(), fb_vector->end());
+}
+std::vector<int> StridedSliceGrad::GetEnd() const {
+  auto fb_vector = this->primitive_->value_as_StridedSliceGrad()->end();
+  return std::vector<int>(fb_vector->begin(), fb_vector->end());
+}
+std::vector<int> StridedSliceGrad::GetStride() const {
+  auto fb_vector = this->primitive_->value_as_StridedSliceGrad()->stride();
+  return std::vector<int>(fb_vector->begin(), fb_vector->end());
+}
+std::vector<int> StridedSliceGrad::GetIsScale() const {
+  auto fb_vector = this->primitive_->value_as_StridedSliceGrad()->isScale();
+  return std::vector<int>(fb_vector->begin(), fb_vector->end());
+}
+int StridedSliceGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
+  MS_ASSERT(nullptr != primitive);
+  MS_ASSERT(nullptr != fbb);
+  auto attr = primitive->value_as_StridedSliceGrad();
+  if (attr == nullptr) {
+    MS_LOG(ERROR) << "value_as_StridedSliceGrad return nullptr";
+    return RET_ERROR;
+  }
+  std::vector<int32_t> begin;
+  if (attr->begin() != nullptr) {
+    for (int i = 0; i < static_cast<int>(attr->begin()->size()); i++) {
+      begin.push_back(attr->begin()->data()[i]);
+    }
+  }
+  std::vector<int32_t> end;
+  if (attr->end() != nullptr) {
+    for (int i = 0; i < static_cast<int>(attr->end()->size()); i++) {
+      end.push_back(attr->end()->data()[i]);
+    }
+  }
+  std::vector<int32_t> stride;
+  if (attr->stride() != nullptr) {
+    for (int i = 0; i < static_cast<int>(attr->stride()->size()); i++) {
+      stride.push_back(attr->stride()->data()[i]);
+    }
+  }
+  std::vector<int32_t> isScale;
+  if (attr->isScale() != nullptr) {
+    for (int i = 0; i < static_cast<int>(attr->isScale()->size()); i++) {
+      isScale.push_back(attr->isScale()->data()[i]);
+    }
+  }
+  auto val_offset =
+    schema::CreateStridedSliceGradDirect(*fbb, attr->beginMask(), attr->endMask(), attr->ellipsisMask(),
+                                         attr->newAxisMask(), attr->shrinkAxisMask(), &begin, &end, &stride, &isScale);
+  auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_StridedSliceGrad, val_offset.o);
+  fbb->Finish(prim_offset);
+  return RET_OK;
+}
+PrimitiveC *StridedSliceGradCreator(const schema::Primitive *primitive) {
+  return PrimitiveC::NewPrimitiveC<StridedSliceGrad>(primitive);
+}
+Registry StridedSliceGradRegistry(schema::PrimitiveType_StridedSliceGrad, StridedSliceGradCreator);
+#endif
| namespace { | |||||
| constexpr size_t kStridedSliceGradOutputNum = 1; | |||||
| constexpr size_t kStridedSliceGradMultiInputNumMax = 5; | |||||
| } // namespace | |||||
| int StridedSliceGrad::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) { | |||||
| MS_ASSERT(this->primitive_ != nullptr); | |||||
| if (outputs.size() != kStridedSliceGradOutputNum) { | |||||
| MS_LOG(ERROR) << "Invalid output size:" << outputs.size(); | |||||
| return RET_PARAM_INVALID; | |||||
| } | |||||
| if (inputs.size() != kStridedSliceGradMultiInputNumMax) { | |||||
| MS_LOG(ERROR) << "Invalid input size " << inputs.size(); | |||||
| return RET_PARAM_INVALID; | |||||
| } | |||||
| auto input = inputs.at(0); | |||||
| MS_ASSERT(input != nullptr); | |||||
| outputs.front()->set_data_type(input->data_type()); | |||||
| outputs.at(0)->set_format(input->format()); | |||||
| auto input_shape = input->shape(); | |||||
| auto inferflag = infer_flag(); | |||||
| in_shape_.clear(); | |||||
| if (inferflag) { | |||||
| in_shape_.assign(input_shape.begin(), input_shape.end()); | |||||
| } | |||||
| begins_.clear(); | |||||
| ends_.clear(); | |||||
| strides_.clear(); | |||||
| if (!CheckInputs(inputs)) { | |||||
| MS_LOG(DEBUG) << "Do infer shape in runtime."; | |||||
| return RET_INFER_INVALID; | |||||
| } | |||||
| // input order: dy, shapex, begins, ends, strides. | |||||
| auto begin_tensor = inputs.at(2); | |||||
| int *begin_data = reinterpret_cast<int *>(begin_tensor->MutableData()); | |||||
| auto end_tensor = inputs.at(3); | |||||
| int *end_data = reinterpret_cast<int *>(end_tensor->MutableData()); | |||||
| auto stride_tensor = inputs.at(4); | |||||
| int *stride_data = reinterpret_cast<int *>(stride_tensor->MutableData()); | |||||
| if (begin_data == nullptr || end_data == nullptr || stride_data == nullptr) { | |||||
| return RET_INFER_ERR; | |||||
| } | |||||
| ndim_ = begin_tensor->ElementsNum(); | |||||
| for (size_t i = 0; i < ndim_; ++i) { | |||||
| begins_.emplace_back(begin_data[i]); | |||||
| ends_.emplace_back(end_data[i]); | |||||
| strides_.emplace_back(stride_data[i]); | |||||
| } | |||||
| // expand each scalar bit mask into a per-dimension flag | |||||
| begins_mask_.resize(ndim_); | |||||
| ends_mask_.resize(ndim_); | |||||
| ellipsis_mask_.resize(ndim_); | |||||
| new_axis_mask_.resize(ndim_); | |||||
| shrink_axis_mask_.resize(ndim_); | |||||
| for (size_t i = 0; i < ndim_; i++) { | |||||
| begins_mask_.at(i) = static_cast<uint32_t>(GetBeginMask()) & (1 << i); | |||||
| ends_mask_.at(i) = static_cast<uint32_t>(GetEndMask()) & (1 << i); | |||||
| ellipsis_mask_.at(i) = static_cast<uint32_t>(GetEllipsisMask()) & (1 << i); | |||||
| new_axis_mask_.at(i) = static_cast<uint32_t>(GetNewAxisMask()) & (1 << i); | |||||
| shrink_axis_mask_.at(i) = static_cast<uint32_t>(GetShrinkAxisMask()) & (1 << i); | |||||
| } | |||||
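| // each mask is a bit field indexed by dimension: begin_mask = 0b101, for example, | |||||
| // flags dims 0 and 2 so that ApplyBeginMask() below ignores their begins_ values | |||||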
| ApplyNewAxisMask(); | |||||
| ApplyBeginMask(); | |||||
| ApplyEndMask(); | |||||
| ApplyEllipsisMask(); | |||||
| if (!inferflag) { | |||||
| return RET_OK; | |||||
| } | |||||
| auto output_size = inputs.at(1)->shape().at(0); | |||||
| std::vector<int> output_shape; | |||||
| MS_ASSERT(inputs.at(1)->MutableData() != nullptr); | |||||
| for (int i = 0; i < output_size; i++) { | |||||
| output_shape.push_back(static_cast<int *>(inputs.at(1)->MutableData())[i]); | |||||
| } | |||||
| outputs.front()->set_shape(output_shape); | |||||
| return RET_OK; | |||||
| } | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,64 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_OPS_STRIDED_SLICE_GRAD_H_ | |||||
| #define MINDSPORE_LITE_SRC_OPS_STRIDED_SLICE_GRAD_H_ | |||||
| #include <vector> | |||||
| #include <set> | |||||
| #include <cmath> | |||||
| #include <memory> | |||||
| #include "src/ops/strided_slice.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class StridedSliceGrad : public StridedSlice { | |||||
| public: | |||||
| StridedSliceGrad() = default; | |||||
| ~StridedSliceGrad() = default; | |||||
| #ifdef PRIMITIVE_WRITEABLE | |||||
| MS_DECLARE_PARENT(StridedSliceGrad, StridedSlice); | |||||
| explicit StridedSliceGrad(schema::PrimitiveT *primitive) : StridedSlice(primitive) {} | |||||
| void SetBeginMask(int begin_mask); | |||||
| void SetEndMask(int end_mask); | |||||
| void SetEllipsisMask(int ellipsis_mask); | |||||
| void SetNewAxisMask(int new_axis_mask); | |||||
| void SetShrinkAxisMask(int shrink_axis_mask); | |||||
| void SetBegin(const std::vector<int> &begin); | |||||
| void SetEnd(const std::vector<int> &end); | |||||
| void SetStride(const std::vector<int> &stride); | |||||
| void SetIsScale(const std::vector<int> &is_scale); | |||||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs); | |||||
| #else | |||||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||||
| #endif | |||||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||||
| int GetBeginMask() const; | |||||
| int GetEndMask() const; | |||||
| int GetEllipsisMask() const; | |||||
| int GetNewAxisMask() const; | |||||
| int GetShrinkAxisMask() const; | |||||
| std::vector<int> GetBegin() const; | |||||
| std::vector<int> GetEnd() const; | |||||
| std::vector<int> GetStride() const; | |||||
| std::vector<int> GetIsScale() const; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_SRC_OPS_STRIDED_SLICE_GRAD_H_ | |||||
| @@ -91,10 +91,12 @@ int ConvolutionWinogradCPUKernel::InitWeightBias() { | |||||
| // init bias | // init bias | ||||
| size_t new_bias_size = oc4 * C4NUM * sizeof(float); | size_t new_bias_size = oc4 * C4NUM * sizeof(float); | ||||
| bias_data_ = reinterpret_cast<float *>(malloc(new_bias_size)); | |||||
| if (bias_data_ == nullptr) { | if (bias_data_ == nullptr) { | ||||
| MS_LOG(ERROR) << "malloc bias_data_ failed."; | |||||
| return RET_MEMORY_FAILED; | |||||
| bias_data_ = reinterpret_cast<float *>(malloc(new_bias_size)); | |||||
| if (bias_data_ == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc bias_data_ failed."; | |||||
| return RET_MEMORY_FAILED; | |||||
| } | |||||
| } | } | ||||
| memset(bias_data_, 0, new_bias_size); | memset(bias_data_, 0, new_bias_size); | ||||
| if (in_tensors_.size() == kInputSize2) { | if (in_tensors_.size() == kInputSize2) { | ||||
| @@ -91,10 +91,6 @@ int FusedBatchnormCPUKernel::Run() { | |||||
| memcpy(scale_, scale, in_tensors_[1]->Size()); | memcpy(scale_, scale, in_tensors_[1]->Size()); | ||||
| memcpy(offset_, offset, in_tensors_[2]->Size()); | memcpy(offset_, offset, in_tensors_[2]->Size()); | ||||
| // save for next iteration | |||||
| memcpy(in_tensors_[3]->MutableData(), save_mean, in_tensors_[3]->Size()); | |||||
| memcpy(in_tensors_[4]->MutableData(), save_variance, in_tensors_[4]->Size()); | |||||
| trained_ = true; // trained at least once | trained_ = true; // trained at least once | ||||
| } | } | ||||
| auto ret = ParallelLaunch(this->context_->thread_pool_, BatchNormRun, this, op_parameter_->thread_num_); | auto ret = ParallelLaunch(this->context_->thread_pool_, BatchNormRun, this, op_parameter_->thread_num_); | ||||
| @@ -40,17 +40,16 @@ int ApplyMomentumCPUKernel::Execute(int task_id) { | |||||
| size_t stride = UP_DIV(length, thread_count_); | size_t stride = UP_DIV(length, thread_count_); | ||||
| size_t count = MSMIN(stride, length - stride * task_id); | size_t count = MSMIN(stride, length - stride * task_id); | ||||
| size_t start = stride * task_id; | size_t start = stride * task_id; | ||||
| size_t end = start + count; | size_t end = start + count; | ||||
| if (apply_momentum_param_->use_nesterov_) { | if (apply_momentum_param_->use_nesterov_) { | ||||
| for (size_t i = start; i < end; ++i) { | |||||
| for (size_t i = start; i < end; i++) { | |||||
| accumulate[i] = accumulate[i] * moment + gradient[i]; | accumulate[i] = accumulate[i] * moment + gradient[i]; | ||||
| weight[i] -= (accumulate[i] * moment + gradient[i]) * learning_rate; | weight[i] -= (accumulate[i] * moment + gradient[i]) * learning_rate; | ||||
| } | } | ||||
| } else { | } else { | ||||
| for (size_t i = start; i < end; ++i) { | |||||
| for (size_t i = start; i < end; i++) { | |||||
| accumulate[i] = accumulate[i] * moment + gradient[i]; | accumulate[i] = accumulate[i] * moment + gradient[i]; | ||||
| weight[i] -= accumulate[i] * learning_rate; | weight[i] -= accumulate[i] * learning_rate; | ||||
| } | } | ||||
| @@ -18,6 +18,10 @@ | |||||
| #include <math.h> | #include <math.h> | ||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <vector> | #include <vector> | ||||
| #include <thread> | |||||
| #include <fstream> | |||||
| #include "schema/model_generated.h" | #include "schema/model_generated.h" | ||||
| #include "src/kernel_registry.h" | #include "src/kernel_registry.h" | ||||
| #include "nnacl/fp32_grad/batch_norm.h" | #include "nnacl/fp32_grad/batch_norm.h" | ||||
| @@ -34,7 +38,8 @@ namespace mindspore::kernel { | |||||
| int BNGradCPUKernel::ReSize() { | int BNGradCPUKernel::ReSize() { | ||||
| auto *input_x = in_tensors_.at(1); | auto *input_x = in_tensors_.at(1); | ||||
| int channels = input_x->shape().at(kNHWC_C); | int channels = input_x->shape().at(kNHWC_C); | ||||
| set_workspace_size(2 * channels * sizeof(float)); | |||||
| ws_size_ = 2 * channels; | |||||
| set_workspace_size(ws_size_ * sizeof(float)); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -46,7 +51,9 @@ int BNGradCPUKernel::Execute(int task_id) { | |||||
| auto *input_scale = in_tensors_.at(2); | auto *input_scale = in_tensors_.at(2); | ||||
| auto *input_mean = in_tensors_.at(3); | auto *input_mean = in_tensors_.at(3); | ||||
| auto *input_var = in_tensors_.at(4); | auto *input_var = in_tensors_.at(4); | ||||
| auto bn_param = reinterpret_cast<BNGradParameter *>(op_parameter_); | |||||
| int stage = stage_; | |||||
| int thread_num = thread_num_; | |||||
| float *save_mean = reinterpret_cast<float *>(input_mean->MutableData()); | float *save_mean = reinterpret_cast<float *>(input_mean->MutableData()); | ||||
| float *save_var = reinterpret_cast<float *>(input_var->MutableData()); | float *save_var = reinterpret_cast<float *>(input_var->MutableData()); | ||||
| @@ -58,26 +65,57 @@ int BNGradCPUKernel::Execute(int task_id) { | |||||
| int32_t spatial = input_x->Height() * input_x->Width(); | int32_t spatial = input_x->Height() * input_x->Width(); | ||||
| float *workspace_temp = static_cast<float *>(workspace()); | float *workspace_temp = static_cast<float *>(workspace()); | ||||
| std::fill(workspace_temp, workspace_temp + workspace_size() / sizeof(*workspace_temp), 0.f); | |||||
| float *dxhat_sum = workspace_temp; | float *dxhat_sum = workspace_temp; | ||||
| float *dxhathat_sum = dxhat_sum + channels; | float *dxhathat_sum = dxhat_sum + channels; | ||||
| float *x = reinterpret_cast<float *>(input_x->MutableData()); | float *x = reinterpret_cast<float *>(input_x->MutableData()); | ||||
| float *yt = reinterpret_cast<float *>(input_yt->MutableData()); | float *yt = reinterpret_cast<float *>(input_yt->MutableData()); | ||||
| float *scale = reinterpret_cast<float *>(input_scale->MutableData()); | float *scale = reinterpret_cast<float *>(input_scale->MutableData()); | ||||
| float *dx = reinterpret_cast<float *>(output_dx->MutableData()); | float *dx = reinterpret_cast<float *>(output_dx->MutableData()); | ||||
| float *dbias = reinterpret_cast<float *>(output_bias->MutableData()); | float *dbias = reinterpret_cast<float *>(output_bias->MutableData()); | ||||
| float *dscale = reinterpret_cast<float *>(output_scale->MutableData()); | float *dscale = reinterpret_cast<float *>(output_scale->MutableData()); | ||||
| std::fill(dbias, dbias + channels, 0.f); | |||||
| std::fill(dscale, dscale + channels, 0.f); | |||||
| backwardAll(x, yt, save_mean, save_var, scale, batch * spatial, channels, dxhat_sum, dxhathat_sum, dbias, dscale, dx); | |||||
| int total = spatial * batch; | |||||
| int stride = UP_DIV(total, thread_num); | |||||
| int count = MSMIN(stride, total - stride * task_id); | |||||
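| // Run() schedules three stages: stage 0 zeroes the per-channel buffers and converts | |||||
| // the saved variance to an inverse stddev; stage 1 reduces the per-channel sums; | |||||
| // stage 2 computes dx, with the batch * spatial rows split across threads. | |||||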
| switch (stage) { | |||||
| case 0: { | |||||
| for (int job = task_id; job < 4; job += thread_num) { | |||||
| switch (job) { | |||||
| case 0: | |||||
| var2Invar(save_var, input_var->ElementsNum(), bn_param->epsilon_); | |||||
| break; | |||||
| case 1: | |||||
| std::fill(workspace_temp, workspace_temp + ws_size_, 0.f); | |||||
| break; | |||||
| case 2: | |||||
| std::fill(dbias, dbias + channels, 0.f); | |||||
| break; | |||||
| case 3: | |||||
| std::fill(dscale, dscale + channels, 0.f); | |||||
| break; | |||||
| } | |||||
| } | |||||
| if (thread_num == 1) { | |||||
| backwardAll(x, yt, save_mean, save_var, scale, total, channels, dxhat_sum, dxhathat_sum, dbias, dscale, dx); | |||||
| } | |||||
| break; | |||||
| } | |||||
| case 1: { | |||||
| backwardP1(x, yt, save_mean, save_var, scale, total, channels, dxhat_sum, dxhathat_sum, dbias, dscale); | |||||
| break; | |||||
| } | |||||
| case 2: { | |||||
| backwardP2(x + task_id * stride * channels, yt + task_id * stride * channels, save_mean, save_var, scale, count, | |||||
| total, channels, dxhat_sum, dxhathat_sum, dx + task_id * stride * channels); | |||||
| break; | |||||
| } | |||||
| } | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int BNGradRun(void *cdata, int task_id) { | int BNGradRun(void *cdata, int task_id) { | ||||
| MS_ASSERT(cdata != nullptr); | MS_ASSERT(cdata != nullptr); | ||||
| auto bn_kernel = reinterpret_cast<BNGradCPUKernel *>(cdata); | auto bn_kernel = reinterpret_cast<BNGradCPUKernel *>(cdata); | ||||
| auto error_code = bn_kernel->Execute(task_id); | auto error_code = bn_kernel->Execute(task_id); | ||||
| if (error_code != RET_OK) { | if (error_code != RET_OK) { | ||||
| MS_LOG(ERROR) << "BNGradRun error task_id[" << task_id << "] error_code[" << error_code << "]"; | MS_LOG(ERROR) << "BNGradRun error task_id[" << task_id << "] error_code[" << error_code << "]"; | ||||
| @@ -87,15 +125,24 @@ int BNGradRun(void *cdata, int task_id) { | |||||
| } | } | ||||
| int BNGradCPUKernel::Run() { | int BNGradCPUKernel::Run() { | ||||
| auto *input_var = in_tensors_.at(4); | |||||
| float *save_var = reinterpret_cast<float *>(input_var->MutableData()); | |||||
| auto bn_param = reinterpret_cast<BNGradParameter *>(op_parameter_); | |||||
| float eps = bn_param->epsilon_; | |||||
| var2Invar(save_var, input_var->ElementsNum(), eps); | |||||
| int error_code = ParallelLaunch(this->context_->thread_pool_, BNGradRun, this, 1); | |||||
| if (error_code != RET_OK) { | |||||
| MS_LOG(ERROR) << "BN function error error_code[" << error_code << "]"; | |||||
| return RET_ERROR; | |||||
| stage_ = 0; | |||||
| thread_num_ = context_->thread_num_; | |||||
| if (thread_num_ == 1) { | |||||
| int error_code = ParallelLaunch(this->context_->thread_pool_, BNGradRun, this, thread_num_); | |||||
| if (error_code != RET_OK) { | |||||
| MS_LOG(ERROR) << "BN function error error_code[" << error_code << "]"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } else { | |||||
| const std::vector<int> threads = {thread_num_, 1, thread_num_}; | |||||
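| // backwardP1 accumulates the channel-wise sums into shared buffers, so stage 1 | |||||
| // (the middle entry) runs single-threaded; stages 0 and 2 are data-parallel. | |||||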
| for (size_t stage = 0; stage < threads.size(); stage++) { | |||||
| stage_ = static_cast<int>(stage); | |||||
| int error_code = ParallelLaunch(this->context_->thread_pool_, BNGradRun, this, threads.at(stage)); | |||||
| if (error_code != RET_OK) { | |||||
| MS_LOG(ERROR) << "BN function error error_code[" << error_code << "]"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -33,6 +33,11 @@ class BNGradCPUKernel : public LiteKernel { | |||||
| int ReSize() override; | int ReSize() override; | ||||
| int Run() override; | int Run() override; | ||||
| int Execute(int task_id); | int Execute(int task_id); | ||||
| private: | |||||
| int thread_num_ = 1; | |||||
| int stage_ = 0; | |||||
| size_t ws_size_ = 0; | |||||
| }; | }; | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_BN_GRAD_H_ | #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_BN_GRAD_H_ | ||||
| @@ -54,9 +54,6 @@ int ConvolutionTrainCPUKernel::ReSize() { | |||||
| conv_param_->group_ = (conv_param_->group_ == 0) ? conv_param_->input_channel_ : conv_param_->group_; | conv_param_->group_ = (conv_param_->group_ == 0) ? conv_param_->input_channel_ : conv_param_->group_; | ||||
| const int n = conv_param_->output_channel_ * conv_param_->group_; | const int n = conv_param_->output_channel_ * conv_param_->group_; | ||||
| const int k = conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_ / conv_param_->group_; | const int k = conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_ / conv_param_->group_; | ||||
| ws_size_ = chunk_ * k; | |||||
| int mat_alloc = MatSizeTotal(chunk_, n, k, 0); | |||||
| set_workspace_size((ws_size_ + mat_alloc) * sizeof(float)); | |||||
| do_img2col_ = (conv_param_->kernel_h_ == 1) && (conv_param_->kernel_w_ == 1) && (conv_param_->pad_d_ == 0) && | do_img2col_ = (conv_param_->kernel_h_ == 1) && (conv_param_->kernel_w_ == 1) && (conv_param_->pad_d_ == 0) && | ||||
| (conv_param_->pad_u_ == 0) && (conv_param_->pad_l_ == 0) && (conv_param_->pad_r_ == 0) && | (conv_param_->pad_u_ == 0) && (conv_param_->pad_l_ == 0) && (conv_param_->pad_r_ == 0) && | ||||
| @@ -64,6 +61,16 @@ int ConvolutionTrainCPUKernel::ReSize() { | |||||
| (conv_param_->stride_h_ == 1) && (conv_param_->stride_w_ == 1) && (conv_param_->group_ == 1) | (conv_param_->stride_h_ == 1) && (conv_param_->stride_w_ == 1) && (conv_param_->group_ == 1) | ||||
| ? false | ? false | ||||
| : true; | : true; | ||||
| do_dw_ = (conv_param_->output_channel_ == conv_param_->group_) && | |||||
| (conv_param_->input_channel_ == conv_param_->output_channel_) && | |||||
| (conv_param_->dilation_h_ == 1) && (conv_param_->dilation_w_ == 1); | |||||
| ws_size_ = chunk_ * conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_; | |||||
| ws_size_ = do_dw_ ? ws_size_ : ws_size_ / conv_param_->group_; | |||||
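| // the depthwise im2col packs every input channel, so only the grouped path | |||||
| // divides the workspace by group | |||||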
| int mat_alloc = MatSizeTotal(chunk_, n, k, 0); | |||||
| set_workspace_size((ws_size_ + mat_alloc) * sizeof(float)); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -97,7 +104,25 @@ int ConvolutionTrainCPUKernel::Execute(int task_id) { | |||||
| float *workspace_temp = static_cast<float *>(workspace()); | float *workspace_temp = static_cast<float *>(workspace()); | ||||
| float *mat_workspace = workspace_temp + ws_size_; | float *mat_workspace = workspace_temp + ws_size_; | ||||
| if (do_img2col_) { | |||||
| if (do_dw_) { | |||||
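| // depthwise path: pack all channels of a chunk once, then run one small GEMM per | |||||
| // group, stepping kernel_h_ * kernel_w_ columns into the packed buffer each time | |||||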
| const int kernel_spatial = k_h * k_w; | |||||
| for (int i = 0; i < batch; ++i) { | |||||
| for (int ci = 0; ci < m; ci += chunk_) { | |||||
| int real_chunk = MSMIN(m - ci, chunk_); | |||||
| float *mat_a = workspace_temp; | |||||
| float *im = x_addr + (i * in_ch * in_h * in_w); | |||||
| RollingIm2ColPackDwUnitFp32(im, conv_param_, mat_a, real_chunk, ci); | |||||
| for (int j = 0; j < groups; ++j) { | |||||
| const float *mat_b = w_addr + j * nweights / groups; | |||||
| float *mat_c = y_addr + (i * groups) * n * m + j * (out_ch / groups) + ci * out_ch; | |||||
| GemmMatmul(0, 1, real_chunk, n, k, 1, mat_a + (j * kernel_spatial), k * groups, mat_b, k, 0, mat_c, out_ch, | |||||
| mat_workspace); | |||||
| } | |||||
| } | |||||
| } | |||||
| } else if (do_img2col_) { | |||||
| for (int i = 0; i < batch; ++i) { | for (int i = 0; i < batch; ++i) { | ||||
| for (int j = 0; j < groups; ++j) { | for (int j = 0; j < groups; ++j) { | ||||
| for (int ci = 0; ci < m; ci += chunk_) { | for (int ci = 0; ci < m; ci += chunk_) { | ||||
| @@ -37,6 +37,7 @@ class ConvolutionTrainCPUKernel : public LiteKernel { | |||||
| private: | private: | ||||
| int ws_size_ = 0; | int ws_size_ = 0; | ||||
| bool do_img2col_ = true; | bool do_img2col_ = true; | ||||
| bool do_dw_ = false; | |||||
| #ifdef ENABLE_ARM32 | #ifdef ENABLE_ARM32 | ||||
| const int chunk_ = C4NUM * 2; | const int chunk_ = C4NUM * 2; | ||||
| #else | #else | ||||
| @@ -17,6 +17,7 @@ | |||||
| #include "src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.h" | #include "src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.h" | ||||
| #include "src/kernel_registry.h" | #include "src/kernel_registry.h" | ||||
| #include "nnacl/pack.h" | #include "nnacl/pack.h" | ||||
| #include "nnacl/fp32_grad/convolution_grad_filter.h" | |||||
| #include "nnacl/fp32_grad/pack_ext.h" | #include "nnacl/fp32_grad/pack_ext.h" | ||||
| #include "nnacl/fp32_grad/gemm.h" | #include "nnacl/fp32_grad/gemm.h" | ||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| @@ -51,20 +52,25 @@ int ConvolutionGradFilterCPUKernel::ReSize() { | |||||
| conv_param->output_h_ = dy_tensor->shape()[kNHWC_H]; | conv_param->output_h_ = dy_tensor->shape()[kNHWC_H]; | ||||
| conv_param->output_w_ = dy_tensor->shape()[kNHWC_W]; | conv_param->output_w_ = dy_tensor->shape()[kNHWC_W]; | ||||
| ws_size_ = chunk_ * conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_ / conv_param->group_; | |||||
| int n = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_ / conv_param->group_; | |||||
| int k = conv_param->output_channel_ / conv_param->group_; | |||||
| int thread_num = context_->thread_num_; | |||||
| mat_alloc_ = MatSizeTotal(k, n, chunk_, 0); | |||||
| set_workspace_size((ws_size_ + mat_alloc_ + (k * n)) * thread_num * sizeof(float)); | |||||
| do_img2col_ = (conv_param->kernel_h_ == 1) && (conv_param->kernel_w_ == 1) && (conv_param->pad_d_ == 0) && | do_img2col_ = (conv_param->kernel_h_ == 1) && (conv_param->kernel_w_ == 1) && (conv_param->pad_d_ == 0) && | ||||
| (conv_param->pad_u_ == 0) && (conv_param->pad_l_ == 0) && (conv_param->pad_r_ == 0) && | (conv_param->pad_u_ == 0) && (conv_param->pad_l_ == 0) && (conv_param->pad_r_ == 0) && | ||||
| (conv_param->dilation_h_ == 1) && (conv_param->dilation_w_ == 1) && (conv_param->stride_h_ == 1) && | (conv_param->dilation_h_ == 1) && (conv_param->dilation_w_ == 1) && (conv_param->stride_h_ == 1) && | ||||
| (conv_param->stride_w_ == 1) && (conv_param->group_ == 1) | (conv_param->stride_w_ == 1) && (conv_param->group_ == 1) | ||||
| ? false | ? false | ||||
| : true; | : true; | ||||
| do_dw_ = (conv_param->output_channel_ == conv_param->group_) && | |||||
| (conv_param->input_channel_ == conv_param->output_channel_) && | |||||
| (conv_param->dilation_h_ == 1) && (conv_param->dilation_w_ == 1); | |||||
| ws_size_ = chunk_ * conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_; | |||||
| ws_size_ = do_dw_ ? ws_size_ : ws_size_ / conv_param->group_; | |||||
| int n = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_ / conv_param->group_; | |||||
| int k = conv_param->output_channel_ / conv_param->group_; | |||||
| int thread_num = context_->thread_num_; | |||||
| mat_alloc_ = MatSizeTotal(k, n, chunk_, 0); | |||||
| set_workspace_size((ws_size_ + mat_alloc_ + (k * n)) * thread_num * sizeof(float)); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -105,10 +111,38 @@ int ConvolutionGradFilterCPUKernel::Execute(int task_id) { | |||||
| int start = stride * task_id; | int start = stride * task_id; | ||||
| int end = start + count; | int end = start + count; | ||||
| if (do_img2col_) { | |||||
| if (do_dw_) { | |||||
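| // depthwise filter gradient: ARM builds use the dedicated ConvDwFilterGrad kernel, | |||||
| // splitting work over kernel elements; other targets fall back to im2col plus a | |||||
| // per-group GEMM, splitting work over groups | |||||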
| #ifdef ENABLE_ARM | |||||
| stride = UP_DIV(k_h * k_w, thread_num); | |||||
| count = MSMIN(stride, k_h * k_w - stride * task_id); | |||||
| start = stride * task_id; | |||||
| ConvDwFilterGrad(x_addr, dy_addr, dw_addr, start, count, conv_param); | |||||
| #else | |||||
| stride = UP_DIV(groups, thread_num); | |||||
| count = MSMIN(stride, groups - stride * task_id); | |||||
| start = stride * task_id; | |||||
| end = start + count; | |||||
| const int kernel_spatial = k_h * k_w; | |||||
| for (int i = 0; i < batch; ++i) { | |||||
| for (int ci = 0; ci < m; ci += chunk_) { | |||||
| int real_chunk = MSMIN(m - ci, chunk_); | |||||
| float *mat_b = workspace_temp + task_id * ws_size_; | |||||
| float *im = x_addr + (i * in_ch * in_h * in_w); | |||||
| RollingIm2ColPackDwUnitFp32(im, conv_param, mat_b, real_chunk, ci); | |||||
| for (int j = start; j < end; ++j) { | |||||
| float *mat_a = dy_addr + (i * groups) * m * k + j * (out_ch / groups) + ci * out_ch; | |||||
| float *mat_c = dw_addr + j * nweights / groups; | |||||
| GemmMatmul(1, 0, k, n, real_chunk, 1, mat_a, out_ch, mat_b + (j * kernel_spatial), n * groups, 1, mat_c, n, | |||||
| mat_workspace); | |||||
| } | |||||
| } | |||||
| } | |||||
| #endif | |||||
| } else if (do_img2col_) { | |||||
| for (int i = start; i < end; ++i) { | for (int i = start; i < end; ++i) { | ||||
| for (int j = 0; j < groups; ++j) { | |||||
| for (int ci = 0; ci < m; ci += chunk_) { | |||||
| for (int ci = 0; ci < m; ci += chunk_) { | |||||
| for (int j = 0; j < groups; ++j) { | |||||
| int real_chunk = MSMIN(m - ci, chunk_); | int real_chunk = MSMIN(m - ci, chunk_); | ||||
| float *mat_a = dy_addr + (i * groups) * m * k + j * (out_ch / groups) + ci * out_ch; | float *mat_a = dy_addr + (i * groups) * m * k + j * (out_ch / groups) + ci * out_ch; | ||||
| float *mat_b = workspace_temp + task_id * ws_size_; | float *mat_b = workspace_temp + task_id * ws_size_; | ||||
| @@ -38,6 +38,7 @@ class ConvolutionGradFilterCPUKernel : public LiteKernel { | |||||
| private: | private: | ||||
| size_t ws_size_ = 0; | size_t ws_size_ = 0; | ||||
| bool do_img2col_ = true; | bool do_img2col_ = true; | ||||
| bool do_dw_ = false; | |||||
| std::mutex lock_; | std::mutex lock_; | ||||
| size_t mat_alloc_ = 0; | size_t mat_alloc_ = 0; | ||||
| #ifdef ENABLE_ARM32 | #ifdef ENABLE_ARM32 | ||||
| @@ -66,13 +66,20 @@ int PoolingGradCPUKernel::Execute(int task_id) { | |||||
| auto input_ptr = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData()); | auto input_ptr = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData()); | ||||
| auto output_ptr = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData()); | auto output_ptr = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData()); | ||||
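| // split the output batches across threads; each task zeroes its own slice of the | |||||
| // gradient buffer and then accumulates into it | |||||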
| int stride = UP_DIV(pool_param->output_batch_, thread_num_); | |||||
| int count = MSMIN(stride, pool_param->output_batch_ - stride * task_id); | |||||
| int in_batch_size = pool_param->input_h_ * pool_param->input_w_ * pool_param->input_channel_; | |||||
| int out_batch_size = pool_param->output_h_ * pool_param->output_w_ * pool_param->input_channel_; | |||||
| std::fill(output_ptr + task_id * stride * in_batch_size, output_ptr + ((task_id * stride) + count) * in_batch_size, | |||||
| 0.f); | |||||
| if (pool_param->pool_mode_ == PoolMode_MaxPool) { | if (pool_param->pool_mode_ == PoolMode_MaxPool) { | ||||
| auto dx_ptr = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData()); | |||||
| auto dy_ptr = reinterpret_cast<float *>(in_tensors_.at(2)->MutableData()); | auto dy_ptr = reinterpret_cast<float *>(in_tensors_.at(2)->MutableData()); | ||||
| MaxPoolingGrad(input_ptr, dx_ptr, dy_ptr, output_ptr, pool_param, task_id); | |||||
| MaxPoolingGrad(input_ptr + task_id * stride * in_batch_size, dy_ptr + task_id * stride * out_batch_size, | |||||
| output_ptr + task_id * stride * in_batch_size, count, pool_param); | |||||
| } else { | } else { | ||||
| input_ptr = reinterpret_cast<float *>(in_tensors_.at(2)->MutableData()); | input_ptr = reinterpret_cast<float *>(in_tensors_.at(2)->MutableData()); | ||||
| AvgPoolingGrad(input_ptr, output_ptr, pool_param, task_id); | |||||
| AvgPoolingGrad(input_ptr + task_id * stride * out_batch_size, output_ptr + task_id * stride * in_batch_size, count, | |||||
| pool_param); | |||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -89,7 +96,8 @@ int PoolingGradImpl(void *cdata, int task_id) { | |||||
| } | } | ||||
| int PoolingGradCPUKernel::Run() { | int PoolingGradCPUKernel::Run() { | ||||
| int error_code = ParallelLaunch(this->context_->thread_pool_, PoolingGradImpl, this, 1); | |||||
| thread_num_ = context_->thread_num_; | |||||
| int error_code = ParallelLaunch(this->context_->thread_pool_, PoolingGradImpl, this, thread_num_); | |||||
| if (error_code != RET_OK) { | if (error_code != RET_OK) { | ||||
| MS_LOG(ERROR) << "pooling error error_code[" << error_code << "]"; | MS_LOG(ERROR) << "pooling error error_code[" << error_code << "]"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -40,6 +40,7 @@ class PoolingGradCPUKernel : public LiteKernel { | |||||
| int Execute(int task_id); | int Execute(int task_id); | ||||
| private: | private: | ||||
| int thread_num_ = 1; | |||||
| }; | }; | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -0,0 +1,150 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/runtime/kernel/arm/fp32_grad/strided_slice_grad.h" | |||||
| #include <vector> | |||||
| #include <algorithm> | |||||
| #include "schema/model_generated.h" | |||||
| #include "src/kernel_registry.h" | |||||
| #include "nnacl/fp32_grad/strided_slice_grad.h" | |||||
| #include "src/ops/populate/strided_slice_populate.h" | |||||
| #include "include/errorcode.h" | |||||
| #include "src/runtime/runtime_api.h" | |||||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||||
| using mindspore::lite::KernelRegistrar; | |||||
| using mindspore::lite::RET_ERROR; | |||||
| using mindspore::lite::RET_OK; | |||||
| using mindspore::schema::PrimitiveType_StridedSliceGrad; | |||||
| namespace mindspore::kernel { | |||||
| int StridedSliceGradCPUKernel::Init() { | |||||
| if (!InferShapeDone()) { | |||||
| return RET_OK; | |||||
| } | |||||
| param_ = reinterpret_cast<StridedSliceParameter *>(op_parameter_); | |||||
| auto input = in_tensors_.at(0); | |||||
| MS_ASSERT(input); | |||||
| switch (input->data_type()) { | |||||
| case kNumberTypeFloat32: | |||||
| param_->data_type = kDataTypeFloat; | |||||
| break; | |||||
| default: | |||||
| MS_LOG(ERROR) << "Not supported data type: " << input->data_type(); | |||||
| return RET_ERROR; | |||||
| } | |||||
| FillEmptyDims(); | |||||
| FillOutputDim(); | |||||
| return ReSize(); | |||||
| } | |||||
| void StridedSliceGradCPUKernel::FillEmptyDims() { | |||||
| int32_t begins[DIMENSION_7D]; | |||||
| int32_t ends[DIMENSION_7D]; | |||||
| int32_t strides[DIMENSION_7D]; | |||||
| int32_t input_shape[DIMENSION_7D]; | |||||
| int32_t i; | |||||
| for (i = 0; i < param_->num_axes_; ++i) { | |||||
| begins[i] = param_->begins_[i]; | |||||
| ends[i] = MSMIN(param_->ends_[i], param_->in_shape_[i]); | |||||
| strides[i] = param_->strides_[i]; | |||||
| input_shape[i] = param_->in_shape_[i]; | |||||
| } | |||||
| for (i = param_->num_axes_; i < param_->in_shape_length_; ++i) { | |||||
| input_shape[i] = param_->in_shape_[i]; | |||||
| begins[i] = 0; | |||||
| ends[i] = param_->in_shape_[i]; | |||||
| strides[i] = 1; | |||||
| } | |||||
| int32_t real_index = param_->in_shape_length_ - 1; | |||||
| for (i = DIMENSION_7D - 1; i >= 0; --i) { | |||||
| if (real_index >= 0) { | |||||
| param_->begins_[i] = begins[real_index]; | |||||
| param_->ends_[i] = ends[real_index]; | |||||
| param_->strides_[i] = strides[real_index]; | |||||
| param_->in_shape_[i] = input_shape[real_index--]; | |||||
| } else { | |||||
| param_->begins_[i] = 0; | |||||
| param_->ends_[i] = 1; | |||||
| param_->strides_[i] = 1; | |||||
| param_->in_shape_[i] = 1; | |||||
| } | |||||
| } | |||||
| param_->num_axes_ = DIMENSION_7D; | |||||
| param_->in_shape_length_ = DIMENSION_7D; | |||||
| for (i = 0; i < DIMENSION_7D; ++i) { | |||||
| if (param_->begins_[i] < 0) { | |||||
| param_->begins_[i] += param_->in_shape_[i]; | |||||
| } | |||||
| if (param_->ends_[i] < 0) { | |||||
| param_->ends_[i] += param_->in_shape_[i]; | |||||
| } | |||||
| } | |||||
| } | |||||
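| // FillEmptyDims right-aligns the slice into 7 dimensions and wraps negative indices; | |||||
| // e.g. in_shape {6, 4} with begins {1, -2} becomes in_shape {1,1,1,1,1,6,4} and | |||||
| // begins {0,0,0,0,0,1,2}, with ends and strides padded to match. | |||||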
| void StridedSliceGradCPUKernel::FillOutputDim() { | |||||
| auto output = out_tensors_.at(0); | |||||
| size_t out_size = output->shape().size(); | |||||
| for (size_t i = 0; i < DIMENSION_7D; i++) { | |||||
| if (i < out_size) { | |||||
| output_shape_.push_back(output->shape()[i]); | |||||
| } else { | |||||
| output_shape_.insert(output_shape_.begin(), 1); | |||||
| } | |||||
| } | |||||
| } | |||||
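| // FillOutputDim likewise left-pads the output shape with ones, e.g. {8, 3} becomes | |||||
| // {1,1,1,1,1,8,3}. | |||||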
| int StridedSliceGradCPUKernel::ReSize() { return RET_OK; } | |||||
| int StridedSliceGradImpl(void *cdata, int task_id) { | |||||
| MS_ASSERT(cdata != nullptr); | |||||
| auto slice = reinterpret_cast<StridedSliceGradCPUKernel *>(cdata); | |||||
| auto error_code = slice->Execute(task_id); | |||||
| if (error_code != RET_OK) { | |||||
| MS_LOG(ERROR) << "StridedSliceGrad Run error task_id[" << task_id << "] error_code[" << error_code << "]"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| int StridedSliceGradCPUKernel::Run() { | |||||
| int error_code = ParallelLaunch(this->context_->thread_pool_, StridedSliceGradImpl, this, 1); | |||||
| if (error_code != RET_OK) { | |||||
| MS_LOG(ERROR) << "Strided slice error error_code[" << error_code << "]"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| int StridedSliceGradCPUKernel::Execute(int task_id) { | |||||
| auto input = in_tensors_.at(0); | |||||
| auto output = out_tensors_.at(0); | |||||
| MS_ASSERT(output); | |||||
| int *po = output_shape_.data(); | |||||
| auto ret = DoStridedSliceGrad(reinterpret_cast<float *>(input->MutableData()), | |||||
| reinterpret_cast<float *>(output->MutableData()), po, param_); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "StridedSliceGrad error error_code[" << ret << "]"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_StridedSliceGrad, LiteKernelCreator<StridedSliceGradCPUKernel>) | |||||
| } // namespace mindspore::kernel | |||||
| @@ -0,0 +1,50 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_STRIDED_SLICE_GRAD_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_STRIDED_SLICE_GRAD_H_ | |||||
| #include <vector> | |||||
| #include "nnacl/fp32_grad/strided_slice_grad.h" | |||||
| #include "src/lite_kernel.h" | |||||
| namespace mindspore::kernel { | |||||
| class StridedSliceGradCPUKernel : public LiteKernel { | |||||
| public: | |||||
| StridedSliceGradCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||||
| const mindspore::lite::PrimitiveC *primitive) | |||||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive) { | |||||
| param_ = reinterpret_cast<StridedSliceParameter *>(parameter); | |||||
| } | |||||
| ~StridedSliceGradCPUKernel() override = default; | |||||
| int Init() override; | |||||
| int ReSize() override; | |||||
| int Run() override; | |||||
| int Execute(int task_id); | |||||
| private: | |||||
| void FillEmptyDims(); | |||||
| void FillOutputDim(); | |||||
| StridedSliceParameter *param_; | |||||
| std::vector<int> output_shape_; | |||||
| }; | |||||
| } // namespace mindspore::kernel | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_STRIDED_SLICE_GRAD_H_ | |||||
| @@ -49,6 +49,7 @@ | |||||
| #include "src/ops/smooth_l1_loss_grad.h" | #include "src/ops/smooth_l1_loss_grad.h" | ||||
| #include "nnacl/fp32_grad/smooth_l1_loss.h" | #include "nnacl/fp32_grad/smooth_l1_loss.h" | ||||
| #include "src/ops/arithmetic_grad.h" | #include "src/ops/arithmetic_grad.h" | ||||
| #include "src/ops/populate/strided_slice_populate.h" | |||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| OpParameter *DefaultPopulateParameter(const mindspore::lite::PrimitiveC *primitive) { | OpParameter *DefaultPopulateParameter(const mindspore::lite::PrimitiveC *primitive) { | ||||
| @@ -569,6 +570,9 @@ void PopulateTrainParameters() { | |||||
| DefaultPopulateParameter); | DefaultPopulateParameter); | ||||
| lite::Registry SigmoidCrossEntropyWithLogitsGradRegistry(schema::PrimitiveType_SigmoidCrossEntropyWithLogitsGrad, | lite::Registry SigmoidCrossEntropyWithLogitsGradRegistry(schema::PrimitiveType_SigmoidCrossEntropyWithLogitsGrad, | ||||
| DefaultPopulateParameter); | DefaultPopulateParameter); | ||||
| lite::Registry FlattenGradParameterRegistry(schema::PrimitiveType_FlattenGrad, DefaultPopulateParameter); | |||||
| lite::Registry StridedSliceGradParameterRegistry(schema::PrimitiveType_StridedSliceGrad, | |||||
| mindspore::lite::PopulateStridedSliceParameter); | |||||
| } | } | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -0,0 +1,72 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_TRAIN_TRANSFER_SESSION_H_ | |||||
| #define MINDSPORE_LITE_SRC_TRAIN_TRANSFER_SESSION_H_ | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include <tuple> | |||||
| #include <unordered_map> | |||||
| #include "src/ops/primitive_c.h" | |||||
| #include "include/train_session.h" | |||||
| #include "src/train/train_model.h" | |||||
| #include "src/lite_session.h" | |||||
| #include "src/train/train_session.h" | |||||
| /* | |||||
| Inheritance Diagram | |||||
| +-------------------------------+ | |||||
| | session::LiteSession | | |||||
| +--------+------------+---------+ | |||||
| / \ | |||||
| +-----------------+-----+ +-------+------------+ | |||||
| | session::TrainSession | | lite::LiteSession | | |||||
| +-----------------+-----+ +-------+------------+ | |||||
| \ / | |||||
| +--------+------------+---------+ | |||||
| | lite::TrainSession | | |||||
| +-------------------------------+ | |||||
| | | |||||
| +--------+------------+---------+ | |||||
| | lite::TransferSession | | |||||
| +-------------------------------+ | |||||
| */ | |||||
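| // Note: the diamond above collapses into a single session::LiteSession subobject | |||||
| // only if both middle classes inherit it virtually (assumed here; the base class | |||||
| // definitions are outside this diff). | |||||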
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TransferSession : public lite::TrainSession { | |||||
| public: | |||||
| TransferSession(); | |||||
| explicit TransferSession(lite::LiteSession *backend_session); | |||||
| ~TransferSession(); | |||||
| int RunGraph(const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr) override; | |||||
| void BindThread(bool if_bind) override; | |||||
| std::vector<tensor::MSTensor *> GetInputs() const override { return lite::LiteSession::GetInputs(); } | |||||
| mindspore::tensor::MSTensor *GetInputsByTensorName(const std::string &tensor_name) const override { | |||||
| return lite::LiteSession::GetInputsByTensorName(tensor_name); | |||||
| } | |||||
| protected: | |||||
| lite::LiteSession *backend_session_; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_SRC_TRAIN_TRANSFER_SESSION_H_ | |||||
| @@ -1,13 +1,15 @@ | |||||
| mini_alexnet | mini_alexnet | ||||
| # mobilenetv1 | |||||
| mobilenetv1 | |||||
| mobilenetv2 | mobilenetv2 | ||||
| mobilenetv3 | mobilenetv3 | ||||
| lenet | lenet | ||||
| effnet | effnet | ||||
| # effnet_tune | |||||
| # lenetv1 | |||||
| # resnet | |||||
| # googlenet | |||||
| effnet_tune | |||||
| resnet | |||||
| googlenet | |||||
| # densenet | # densenet | ||||
| # shufflenetv2 | |||||
| # nin | |||||
| # one_net | # one_net | ||||
| # lenetv1 | |||||
| #LAST | #LAST | ||||
| @@ -0,0 +1,82 @@ | |||||
| #!/bin/bash | |||||
| # Print a separator line around each block of testcase results | |||||
| function MS_PRINT_TESTCASE_END_MSG() { | |||||
| echo -e "-----------------------------------------------------------------------------------------------------------------------------------" | |||||
| } | |||||
| function Print_Result() { | |||||
| MS_PRINT_TESTCASE_END_MSG | |||||
| while read line; do | |||||
| arr=("${line}") | |||||
| printf "%-15s %-20s %-90s %-7s\n" ${arr[0]} ${arr[1]} ${arr[2]} ${arr[3]} | |||||
| done < $1 | |||||
| MS_PRINT_TESTCASE_END_MSG | |||||
| } | |||||
| basepath=$(pwd) | |||||
| echo ${basepath} | |||||
| # Example:run_net_export.sh -m /home/emir/Work/TestingEnv/train_models | |||||
| epoch_num=1 | |||||
| while getopts "m:t:" opt; do | |||||
| case ${opt} in | |||||
| m) | |||||
| models_path=${OPTARG}"/models_train" | |||||
| echo "models_path is ${OPTARG}" | |||||
| ;; | |||||
| t) | |||||
| epoch_num=${OPTARG} | |||||
| echo "train epoch num is ${OPTARG}" | |||||
| ;; | |||||
| ?) | |||||
| echo "unknown para" | |||||
| exit 1;; | |||||
| esac | |||||
| done | |||||
| # Set models config filepath | |||||
| models_mindspore_train_config=${basepath}/models_ms_train.cfg | |||||
| logs_path=${basepath}/logs_train | |||||
| rm -rf ${logs_path} | |||||
| mkdir -p ${logs_path} | |||||
| docker_image=mindspore/mindspore-gpu:1.1.0 | |||||
| # Export models | |||||
| echo "Start Exporting models ..." | |||||
| # Set log files | |||||
| export_log_file=${logs_path}/export_log.txt | |||||
| echo ' ' > ${export_log_file} | |||||
| export_result_file=${logs_path}/export_result.txt | |||||
| echo ' ' > ${export_result_file} | |||||
| # Run export according to config file | |||||
| cd $models_path || exit 1 | |||||
| if [[ -z "${CLOUD_MODEL_ZOO}" ]]; then | |||||
| echo "CLOUD_MODEL_ZOO is not defined - exiting export models" | |||||
| exit 1 | |||||
| fi | |||||
| # Export mindspore train models: | |||||
| while read line; do | |||||
| model_name=${line} | |||||
| if [[ $model_name == \#* ]]; then | |||||
| continue | |||||
| fi | |||||
| echo ${model_name}'_train_export.py' >> "${export_log_file}" | |||||
| echo 'exporting' ${model_name} | |||||
| echo 'docker run --user '"$(id -u):$(id -g)"' --env CLOUD_MODEL_ZOO=${CLOUD_MODEL_ZOO} -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER -v /opt/share:/opt/share --privileged=true '${docker_image}' python '${models_path}'/'${model_name}'_train_export.py' >> "${export_log_file}" | |||||
| docker run --user "$(id -u):$(id -g)" --env CLOUD_MODEL_ZOO=${CLOUD_MODEL_ZOO} -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER -v /opt/share:/opt/share --privileged=true "${docker_image}" python ${models_path}'/'${model_name}_train_export.py "${epoch_num}" | |||||
| if [ $? = 0 ]; then | |||||
| export_result='export mindspore '${model_name}'_train_export pass';echo ${export_result} >> ${export_result_file} | |||||
| else | |||||
| export_result='export mindspore '${model_name}'_train_export failed';echo ${export_result} >> ${export_result_file} | |||||
| fi | |||||
| done < ${models_mindspore_train_config} | |||||
| Print_Result ${export_result_file} | |||||
| @@ -1,7 +1,7 @@ | |||||
| #!/bin/bash | #!/bin/bash | ||||
| # Run Export on x86 platform and create output test files: | # Run Export on x86 platform and create output test files: | ||||
| docker_image=mindspore_dev:8 | |||||
| docker_image= | |||||
| function Run_Export(){ | function Run_Export(){ | ||||
| cd $models_path || exit 1 | cd $models_path || exit 1 | ||||
| if [[ -z "${CLOUD_MODEL_ZOO}" ]]; then | if [[ -z "${CLOUD_MODEL_ZOO}" ]]; then | ||||
| @@ -16,8 +16,13 @@ function Run_Export(){ | |||||
| fi | fi | ||||
| echo ${model_name}'_train_export.py' >> "${export_log_file}" | echo ${model_name}'_train_export.py' >> "${export_log_file}" | ||||
| echo 'exporting' ${model_name} | echo 'exporting' ${model_name} | ||||
| echo 'docker run --user '"$(id -u):$(id -g)"' --env CLOUD_MODEL_ZOO=${CLOUD_MODEL_ZOO} -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER -v /opt/share:/opt/share --privileged=true '${docker_image}' python '${models_path}'/'${model_name}'_train_export.py' >> "${export_log_file}" | |||||
| docker run --user "$(id -u):$(id -g)" --env CLOUD_MODEL_ZOO=${CLOUD_MODEL_ZOO} -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER -v /opt/share:/opt/share --privileged=true "${docker_image}" python ${models_path}'/'${model_name}_train_export.py "${epoch_num}" | |||||
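| # run the export inside the docker image when one is provided; otherwise fall | |||||
| # back to the local python environment | |||||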
| if [ -n "$docker_image" ]; then | |||||
| echo 'docker run --user '"$(id -u):$(id -g)"' --env CLOUD_MODEL_ZOO=${CLOUD_MODEL_ZOO} -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER -v /opt/share:/opt/share --privileged=true '${docker_image}' python '${models_path}'/'${model_name}'_train_export.py' >> "${export_log_file}" | |||||
| docker run --user "$(id -u):$(id -g)" --env CLOUD_MODEL_ZOO=${CLOUD_MODEL_ZOO} -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER -v /opt/share:/opt/share --privileged=true "${docker_image}" python ${models_path}'/'${model_name}_train_export.py "${epoch_num}" | |||||
| else | |||||
| echo 'CLOUD_MODEL_ZOO=${CLOUD_MODEL_ZOO} python '${models_path}'/'${model_name}'_train_export.py' >> "${export_log_file}" | |||||
| CLOUD_MODEL_ZOO=${CLOUD_MODEL_ZOO} python ${models_path}'/'${model_name}_train_export.py "${epoch_num}" | |||||
| fi | |||||
| if [ $? = 0 ]; then | if [ $? = 0 ]; then | ||||
| export_result='export mindspore '${model_name}'_train_export pass';echo ${export_result} >> ${export_result_file} | export_result='export mindspore '${model_name}'_train_export pass';echo ${export_result} >> ${export_result_file} | ||||
| else | else | ||||
| @@ -28,7 +33,7 @@ function Run_Export(){ | |||||
| # Run converter on x86 platform: | # Run converter on x86 platform: | ||||
| function Run_Converter() { | function Run_Converter() { | ||||
| # Unzip x86 runtime and convertor | |||||
| # Unzip x86 runtime and converter | |||||
| cd ${x86_path} || exit 1 | cd ${x86_path} || exit 1 | ||||
| tar -zxf mindspore-lite-${version}-train-linux-x64.tar.gz || exit 1 | tar -zxf mindspore-lite-${version}-train-linux-x64.tar.gz || exit 1 | ||||
| @@ -189,7 +194,7 @@ ENDM | |||||
| if [ $? = 0 ]; then | if [ $? = 0 ]; then | ||||
| run_result=$1': '${model_name}'_train pass'; echo ${run_result} >> ${run_benchmark_train_result_file} | run_result=$1': '${model_name}'_train pass'; echo ${run_result} >> ${run_benchmark_train_result_file} | ||||
| else | else | ||||
| run_result=$1': '${model_name}'_train failed'; echo ${run_result} >> ${run_benchmark_train_result_file}; return 1 | |||||
| run_result=$1': '${model_name}'_train failed'; echo ${run_result} >> ${run_benchmark_train_result_file}; | |||||
| fi | fi | ||||
| done < ${models_mindspore_train_config} | done < ${models_mindspore_train_config} | ||||
| } | } | ||||
| @@ -222,16 +227,15 @@ echo ${basepath} | |||||
| # Example:run_benchmark_train.sh -r /home/emir/Work/TestingEnv/release -m /home/emir/Work/TestingEnv/train_models -i /home/emir/Work/TestingEnv/train_io -d "8KE5T19620002408" | # Example:run_benchmark_train.sh -r /home/emir/Work/TestingEnv/release -m /home/emir/Work/TestingEnv/train_models -i /home/emir/Work/TestingEnv/train_io -d "8KE5T19620002408" | ||||
| # For running on arm64, use -t to set platform tools path (for using adb commands) | # For running on arm64, use -t to set platform tools path (for using adb commands) | ||||
| epoch_num=1 | epoch_num=1 | ||||
| threads=1 | |||||
| threads=2 | |||||
| train_io_path="" | train_io_path="" | ||||
| while getopts "r:m:d:i:e:vt:q:" opt; do | |||||
| while getopts "r:m:d:i:e:vt:q:D" opt; do | |||||
| case ${opt} in | case ${opt} in | ||||
| r) | r) | ||||
| release_path=${OPTARG} | release_path=${OPTARG} | ||||
| echo "release_path is ${OPTARG}" | echo "release_path is ${OPTARG}" | ||||
| ;; | ;; | ||||
| m) | m) | ||||
| models_path=${OPTARG}"/models_train" | models_path=${OPTARG}"/models_train" | ||||
| echo "models_path is ${OPTARG}" | echo "models_path is ${OPTARG}" | ||||
| ;; | ;; | ||||
| @@ -244,8 +248,9 @@ while getopts "r:m:d:i:e:vt:q:" opt; do | |||||
| echo "device_id is ${OPTARG}" | echo "device_id is ${OPTARG}" | ||||
| ;; | ;; | ||||
| e) | e) | ||||
| enable_export=${OPTARG} | |||||
| echo "enable_export = ${OPTARG}" | |||||
| enable_export=1 | |||||
| docker_image=${OPTARG} | |||||
| echo "enable_export = 1, docker_image = ${OPTARG}" | |||||
| ;; | ;; | ||||
| v) | v) | ||||
| run_valgrind="valgrind --log-file=valgrind.log " | run_valgrind="valgrind --log-file=valgrind.log " | ||||
| @@ -404,27 +409,27 @@ function Print_Benchmark_Result() { | |||||
| done < ${run_benchmark_train_result_file} | done < ${run_benchmark_train_result_file} | ||||
| MS_PRINT_TESTCASE_END_MSG | MS_PRINT_TESTCASE_END_MSG | ||||
| } | } | ||||
| result=0 | |||||
| # Check benchmark_train result and return value | # Check benchmark_train result and return value | ||||
| if [[ ${Run_x86_status} != 0 ]];then | if [[ ${Run_x86_status} != 0 ]];then | ||||
| echo "Run_x86 failed" | echo "Run_x86 failed" | ||||
| cat ${run_x86_log_file} | cat ${run_x86_log_file} | ||||
| exit 1 | |||||
| result=1 | |||||
| fi | fi | ||||
| if [[ ${Run_arm64_status} != 0 ]];then | if [[ ${Run_arm64_status} != 0 ]];then | ||||
| echo "Run_arm64 failed" | echo "Run_arm64 failed" | ||||
| cat ${run_arm64_log_file} | cat ${run_arm64_log_file} | ||||
| exit 1 | |||||
| result=1 | |||||
| fi | fi | ||||
| if [[ ${Run_arm32_status} != 0 ]];then | if [[ ${Run_arm32_status} != 0 ]];then | ||||
| echo "Run_arm32 failed" | echo "Run_arm32 failed" | ||||
| cat ${run_arm32_log_file} | cat ${run_arm32_log_file} | ||||
| exit 1 | |||||
| result=1 | |||||
| fi | fi | ||||
| echo "Test ended - Results:" | echo "Test ended - Results:" | ||||
| Print_Benchmark_Result | Print_Benchmark_Result | ||||
| echo "Test run Time:" $DIFF | echo "Test run Time:" $DIFF | ||||
| exit 0 | |||||
| exit ${result} | |||||
| @@ -79,15 +79,18 @@ TEST_F(TestPoolingGradFp32, AvgPoolingGradFp32) { | |||||
| auto output_data = new float[output_data_size]; | auto output_data = new float[output_data_size]; | ||||
| ASSERT_NE(output_data, nullptr); | ASSERT_NE(output_data, nullptr); | ||||
| // warm up loop | // warm up loop | ||||
| for (int i = 0; i < 3; i++) { | for (int i = 0; i < 3; i++) { | ||||
| AvgPoolingGrad(input_data, output_data, pooling_param, 1); | |||||
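| // the reworked pooling-grad kernels accumulate into output_data, so it is zeroed | |||||
| // before every call | |||||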
| std::fill(output_data, output_data + output_data_size, 0.f); | |||||
| AvgPoolingGrad(input_data, output_data, pooling_param->output_batch_, pooling_param); | |||||
| } | } | ||||
| int loop_count = 100; | int loop_count = 100; | ||||
| auto time_start = mindspore::lite::GetTimeUs(); | auto time_start = mindspore::lite::GetTimeUs(); | ||||
| for (int i = 0; i < loop_count; i++) { | for (int i = 0; i < loop_count; i++) { | ||||
| AvgPoolingGrad(input_data, output_data, pooling_param, 1); | |||||
| std::fill(output_data, output_data + output_data_size, 0.f); | |||||
| AvgPoolingGrad(input_data, output_data, pooling_param->output_batch_, pooling_param); | |||||
| } | } | ||||
| auto time_end = mindspore::lite::GetTimeUs(); | auto time_end = mindspore::lite::GetTimeUs(); | ||||
| auto cost = time_end - time_start; | auto cost = time_end - time_start; | ||||
| @@ -407,18 +410,21 @@ TEST_F(TestPoolingGradFp32, MaxPoolingGradFp32) { | |||||
| std::string dx_path = "./test_data/pooling/maxpoolgradfp32_1_dx_1_28_28_3.bin"; | std::string dx_path = "./test_data/pooling/maxpoolgradfp32_1_dx_1_28_28_3.bin"; | ||||
| auto dx_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(dx_path.c_str(), &input_size)); | auto dx_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(dx_path.c_str(), &input_size)); | ||||
| ASSERT_NE(dx_data, nullptr); | ASSERT_NE(dx_data, nullptr); | ||||
| int in_batch_size = | |||||
| pooling_param->input_h_ * pooling_param->input_w_ * pooling_param->input_channel_ * pooling_param->input_batch_; | |||||
| auto output_data = new float[output_data_size]; | auto output_data = new float[output_data_size]; | ||||
| ASSERT_NE(output_data, nullptr); | ASSERT_NE(output_data, nullptr); | ||||
| // warm up loop | // warm up loop | ||||
| for (int i = 0; i < 3; i++) { | for (int i = 0; i < 3; i++) { | ||||
| MaxPoolingGrad(in_data, dx_data, dy_data, output_data, pooling_param, 1); | |||||
| std::fill(output_data, output_data + in_batch_size, 0.f); | |||||
| MaxPoolingGrad(in_data, dy_data, output_data, pooling_param->output_batch_, pooling_param); | |||||
| } | } | ||||
| int loop_count = 100; | int loop_count = 100; | ||||
| auto time_start = mindspore::lite::GetTimeUs(); | auto time_start = mindspore::lite::GetTimeUs(); | ||||
| for (int i = 0; i < loop_count; i++) { | for (int i = 0; i < loop_count; i++) { | ||||
| MaxPoolingGrad(in_data, dx_data, dy_data, output_data, pooling_param, 1); | |||||
| std::fill(output_data, output_data + in_batch_size, 0.f); | |||||
| MaxPoolingGrad(in_data, dy_data, output_data, pooling_param->output_batch_, pooling_param); | |||||
| } | } | ||||
| auto time_end = mindspore::lite::GetTimeUs(); | auto time_end = mindspore::lite::GetTimeUs(); | ||||
| auto cost = time_end - time_start; | auto cost = time_end - time_start; | ||||
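The MaxPoolingGrad call loses its `dx_data` (saved forward output) argument, which suggests the kernel now recomputes each window's argmax directly from the input, and the test clears exactly `in_batch_size` elements of dx since that is the gradient's true extent. A 1-D sketch of the recompute-argmax approach (hypothetical function, illustration only):

    // Route dy[o] to the input position that won the forward max. dx must be
    // zero-filled by the caller, as the test's std::fill now does.
    void MaxPoolGrad1D(const float *in, const float *dy, float *dx,
                       int out_len, int stride, int win) {
      for (int o = 0; o < out_len; ++o) {
        int best = o * stride;
        for (int k = 1; k < win; ++k) {
          if (in[o * stride + k] > in[best]) best = o * stride + k;
        }
        dx[best] += dy[o];  // accumulate, matching the zero-fill contract
      }
    }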
| @@ -135,7 +135,6 @@ int NetTrain::ReadCalibData() { | |||||
| MS_LOG(INFO) << "Start reading calibData file"; | MS_LOG(INFO) << "Start reading calibData file"; | ||||
| std::string tensor_name; | std::string tensor_name; | ||||
| while (!in_file.eof()) { | while (!in_file.eof()) { | ||||
| getline(in_file, line); | getline(in_file, line); | ||||
| std::stringstream string_line1(line); | std::stringstream string_line1(line); | ||||
| @@ -79,7 +79,7 @@ class MS_API NetTrainFlags : public virtual FlagParser { | |||||
| std::vector<std::string> input_data_list_; | std::vector<std::string> input_data_list_; | ||||
| DataType in_data_type_; | DataType in_data_type_; | ||||
| std::string in_data_type_in_ = "bin"; | std::string in_data_type_in_ = "bin"; | ||||
| int cpu_bind_mode_ = 0; | |||||
| int cpu_bind_mode_ = 1; | |||||
| // MarkPerformance | // MarkPerformance | ||||
| int num_threads_ = 1; | int num_threads_ = 1; | ||||
| int warm_up_loop_count_ = 0; | int warm_up_loop_count_ = 0; | ||||
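The default bind mode moves from 0 to 1, i.e. the train benchmark now pins its worker threads to the faster cores by default. The values below follow MindSpore Lite's usual CpuBindMode convention, stated here as an assumption for the reader; the enum itself is not part of this diff:

    enum CpuBindMode {
      NO_BIND = 0,     // old default: let the OS scheduler place threads
      HIGHER_CPU = 1,  // new default: bind worker threads to the big cores
      MID_CPU = 2      // bind to the middle-frequency cores
    };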
| @@ -32,7 +32,6 @@ static const std::vector<schema::PrimitiveType> nhwcOpList = { | |||||
| schema::PrimitiveType_PoolingGrad, | schema::PrimitiveType_PoolingGrad, | ||||
| schema::PrimitiveType_BiasGrad, | schema::PrimitiveType_BiasGrad, | ||||
| schema::PrimitiveType_BNGrad, | schema::PrimitiveType_BNGrad, | ||||
| schema::PrimitiveType_ActivationGrad, | |||||
| schema::PrimitiveType_ApplyMomentum, | schema::PrimitiveType_ApplyMomentum, | ||||
| schema::PrimitiveType_Sgd, | schema::PrimitiveType_Sgd, | ||||
| schema::PrimitiveType_Adam, | schema::PrimitiveType_Adam, | ||||
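Dropping PrimitiveType_ActivationGrad from nhwcOpList is consistent with activation gradients being purely elementwise: each output element depends only on the same-index input elements, so no NHWC layout transform is needed. A minimal illustration with a hypothetical ReLU gradient:

    // Output element i depends only on input element i, so this loop computes
    // the same result whether the tensor is laid out NHWC or NCHW.
    void ReluGrad(const float *x, const float *dy, float *dx, int n) {
      for (int i = 0; i < n; ++i) dx[i] = (x[i] > 0.0f) ? dy[i] : 0.0f;
    }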
| @@ -219,6 +218,26 @@ STATUS NodeUtils::ConvertDims(mindspore::schema::Format src_format, const std::v | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| static bool IsKCHWSource(kTransFilterType type) { | |||||
| return (type == kKCHW2HWCK || type == kKCHW2HWKC || type == kKCHW2KHWC || type == kKCHW2CKHW); | |||||
| } | |||||
| static bool IsCKHWSource(kTransFilterType type) { | |||||
| return (type == kCKHW2HWCK || type == kCKHW2HWKC || type == kCKHW2KHWC); | |||||
| } | |||||
| static bool IsHWCKSource(kTransFilterType type) { return (type == kHWCK2KCHW || type == kHWCK2CKHW); } | |||||
| static bool IsHWKCSource(kTransFilterType type) { return (type == kHWKC2KCHW || type == kHWKC2CKHW); } | |||||
| static bool IsNHWCSource(kTransFilterType type) { | |||||
| return (type == kNHWC2KCHW || type == kNHWC2HWCK || type == kNHWC2CKHW); | |||||
| } | |||||
| static bool IsCHWKSource(kTransFilterType type) { return (type == kCHWK2HWCK || type == kCHWK2KHWC); } | |||||
| static bool IsKHWCSource(kTransFilterType type) { return (type == kKHWC2HWCK || type == kKHWC2CHWK); } | |||||
| STATUS GetFilterDim(const std::vector<int32_t> &oriDims, kTransFilterType type, int32_t *filterK, int32_t *filterC, | STATUS GetFilterDim(const std::vector<int32_t> &oriDims, kTransFilterType type, int32_t *filterK, int32_t *filterC, | ||||
| int32_t *filterH, int32_t *filterW) { | int32_t *filterH, int32_t *filterW) { | ||||
| if (filterK == nullptr || filterC == nullptr || filterH == nullptr || filterW == nullptr) { | if (filterK == nullptr || filterC == nullptr || filterH == nullptr || filterW == nullptr) { | ||||
| @@ -226,37 +245,37 @@ STATUS GetFilterDim(const std::vector<int32_t> &oriDims, kTransFilterType type, | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| MS_ASSERT(oriDims.size() == 4); | MS_ASSERT(oriDims.size() == 4); | ||||
| if (type == kKCHW2HWCK || type == kKCHW2HWKC || type == kKCHW2KHWC || type == kKCHW2CKHW) { | |||||
| if (IsKCHWSource(type)) { | |||||
| *filterK = oriDims.at(KCHW_K); | *filterK = oriDims.at(KCHW_K); | ||||
| *filterC = oriDims.at(KCHW_C); | *filterC = oriDims.at(KCHW_C); | ||||
| *filterH = oriDims.at(KCHW_H); | *filterH = oriDims.at(KCHW_H); | ||||
| *filterW = oriDims.at(KCHW_W); | *filterW = oriDims.at(KCHW_W); | ||||
| } else if (type == kCKHW2HWCK || type == kCKHW2HWKC || type == kCKHW2KHWC) { | |||||
| } else if (IsCKHWSource(type)) { | |||||
| *filterC = oriDims.at(CKHW_C); | *filterC = oriDims.at(CKHW_C); | ||||
| *filterK = oriDims.at(CKHW_K); | *filterK = oriDims.at(CKHW_K); | ||||
| *filterH = oriDims.at(CKHW_H); | *filterH = oriDims.at(CKHW_H); | ||||
| *filterW = oriDims.at(CKHW_W); | *filterW = oriDims.at(CKHW_W); | ||||
| } else if (type == kHWCK2KCHW || type == kHWCK2CKHW) { | |||||
| } else if (IsHWCKSource(type)) { | |||||
| *filterH = oriDims.at(HWCK_H); | *filterH = oriDims.at(HWCK_H); | ||||
| *filterW = oriDims.at(HWCK_W); | *filterW = oriDims.at(HWCK_W); | ||||
| *filterC = oriDims.at(HWCK_C); | *filterC = oriDims.at(HWCK_C); | ||||
| *filterK = oriDims.at(HWCK_K); | *filterK = oriDims.at(HWCK_K); | ||||
| } else if (type == kHWKC2KCHW || type == kHWKC2CKHW) { | |||||
| } else if (IsHWKCSource(type)) { | |||||
| *filterH = oriDims.at(HWKC_H); | *filterH = oriDims.at(HWKC_H); | ||||
| *filterW = oriDims.at(HWKC_W); | *filterW = oriDims.at(HWKC_W); | ||||
| *filterK = oriDims.at(HWKC_K); | *filterK = oriDims.at(HWKC_K); | ||||
| *filterC = oriDims.at(HWKC_C); | *filterC = oriDims.at(HWKC_C); | ||||
| } else if (type == kNHWC2KCHW || type == kNHWC2HWCK || type == kNHWC2CKHW) { | |||||
| } else if (IsNHWCSource(type)) { | |||||
| *filterK = oriDims.at(NHWC_N); | *filterK = oriDims.at(NHWC_N); | ||||
| *filterH = oriDims.at(NHWC_H); | *filterH = oriDims.at(NHWC_H); | ||||
| *filterW = oriDims.at(NHWC_W); | *filterW = oriDims.at(NHWC_W); | ||||
| *filterC = oriDims.at(NHWC_C); | *filterC = oriDims.at(NHWC_C); | ||||
| } else if (type == kCHWK2HWCK || type == kCHWK2KHWC) { | |||||
| } else if (IsCHWKSource(type)) { | |||||
| *filterC = oriDims.at(CHWK_C); | *filterC = oriDims.at(CHWK_C); | ||||
| *filterH = oriDims.at(CHWK_H); | *filterH = oriDims.at(CHWK_H); | ||||
| *filterW = oriDims.at(CHWK_W); | *filterW = oriDims.at(CHWK_W); | ||||
| *filterK = oriDims.at(CHWK_K); | *filterK = oriDims.at(CHWK_K); | ||||
| } else if (type == kKHWC2HWCK || type == kKHWC2CHWK) { | |||||
| } else if (IsKHWCSource(type)) { | |||||
| *filterK = oriDims.at(KHWC_K); | *filterK = oriDims.at(KHWC_K); | ||||
| *filterH = oriDims.at(KHWC_H); | *filterH = oriDims.at(KHWC_H); | ||||
| *filterW = oriDims.at(KHWC_W); | *filterW = oriDims.at(KHWC_W); | ||||
| @@ -290,6 +309,37 @@ STATUS SetFilterDim(schema::TensorT *tensor, kTransFilterType type, int32_t filt | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| static int Convert2KHWC(int srcFormat) { | |||||
| if (srcFormat == schema::Format::Format_KCHW) return kKCHW2KHWC; | |||||
| if (srcFormat == schema::Format::Format_CKHW) return kCKHW2KHWC; | |||||
| if (srcFormat == schema::Format::Format_CHWK) return kCHWK2KHWC; | |||||
| return -1; | |||||
| } | |||||
| static int Convert2HWCK(int srcFormat) { | |||||
| if (srcFormat == schema::Format::Format_KCHW) return kKCHW2HWCK; | |||||
| if (srcFormat == schema::Format::Format_KHWC) return kKHWC2HWCK; | |||||
| if (srcFormat == schema::Format::Format_CKHW) return kCKHW2HWCK; | |||||
| if (srcFormat == schema::Format::Format_CHWK) return kCHWK2HWCK; | |||||
| return -1; | |||||
| } | |||||
| static int Convert2KCHW(int srcFormat) { | |||||
| if (srcFormat == schema::Format::Format_HWCK) return kHWCK2KCHW; | |||||
| if (srcFormat == schema::Format::Format_HWKC) return kHWKC2KCHW; | |||||
| if (srcFormat == schema::Format::Format_KHWC) return kKHWC2KCHW; | |||||
| if (srcFormat == schema::Format::Format_CKHW) return kCKHW2KCHW; | |||||
| if (srcFormat == schema::Format::Format_CHWK) return kCHWK2KCHW; | |||||
| return -1; | |||||
| } | |||||
| static int Convert2CKHW(int srcFormat) { | |||||
| if (srcFormat == schema::Format::Format_HWCK) return kHWCK2CKHW; | |||||
| if (srcFormat == schema::Format::Format_HWKC) return kHWKC2CKHW; | |||||
| if (srcFormat == schema::Format::Format_KCHW) return kKCHW2CKHW; | |||||
| return -1; | |||||
| } | |||||
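Each Convert2* helper maps a source layout to the kTransFilterType code for one destination layout and returns -1 for pairs it does not support, giving TransFilterFormat a single shared error path. Spot-checks of that contract against the definitions above (a sketch, not a test from the PR; it assumes the converter's static helpers and enums are in scope):

    #include <cassert>
    void CheckConvertHelpers() {
      assert(Convert2KHWC(schema::Format::Format_KCHW) == kKCHW2KHWC);  // supported pair
      assert(Convert2KCHW(schema::Format::Format_HWKC) == kHWKC2KCHW);  // supported pair
      assert(Convert2CKHW(schema::Format::Format_KHWC) == -1);          // no such transform
    }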
| STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat) { | STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat) { | ||||
| if (tensor == nullptr) { | if (tensor == nullptr) { | ||||
| MS_LOG(ERROR) << "tensor is null"; | MS_LOG(ERROR) << "tensor is null"; | ||||
| @@ -303,231 +353,40 @@ STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat) { | |||||
| auto srcFormat = tensor->format; | auto srcFormat = tensor->format; | ||||
| auto dataType = tensor->dataType; | auto dataType = tensor->dataType; | ||||
| STATUS status; | STATUS status; | ||||
| int convert = -1; | |||||
| if (dstFormat == srcFormat) return RET_OK; | |||||
| switch (dstFormat) { | switch (dstFormat) { | ||||
| case schema::Format::Format_KHWC: { | |||||
| switch (srcFormat) { | |||||
| case schema::Format::Format_KCHW: | |||||
| if (dataType == kNumberTypeFloat32) { | |||||
| status = TransFilterFormat<float>(tensor, kKCHW2KHWC); | |||||
| } else if (dataType == kNumberTypeUInt8) { | |||||
| status = TransFilterFormat<uint8_t>(tensor, kKCHW2KHWC); | |||||
| } else if (dataType == kNumberTypeInt8) { | |||||
| status = TransFilterFormat<int8_t>(tensor, kKCHW2KHWC); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Unsupported dataType: " << dataType; | |||||
| return RET_ERROR; | |||||
| } | |||||
| break; | |||||
| case schema::Format::Format_CKHW: | |||||
| if (dataType == kNumberTypeFloat32) { | |||||
| status = TransFilterFormat<float>(tensor, kCKHW2KHWC); | |||||
| } else if (dataType == kNumberTypeUInt8) { | |||||
| status = TransFilterFormat<uint8_t>(tensor, kCKHW2KHWC); | |||||
| } else if (dataType == kNumberTypeInt8) { | |||||
| status = TransFilterFormat<int8_t>(tensor, kCKHW2KHWC); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Unsupported dataType: " << dataType; | |||||
| return RET_ERROR; | |||||
| } | |||||
| break; | |||||
| case schema::Format::Format_CHWK: | |||||
| if (dataType == kNumberTypeFloat32) { | |||||
| status = TransFilterFormat<float>(tensor, kCHWK2KHWC); | |||||
| } else if (dataType == kNumberTypeUInt8) { | |||||
| status = TransFilterFormat<uint8_t>(tensor, kCHWK2KHWC); | |||||
| } else if (dataType == kNumberTypeInt8) { | |||||
| status = TransFilterFormat<int8_t>(tensor, kCHWK2KHWC); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Unsupported dataType: " << dataType; | |||||
| return RET_ERROR; | |||||
| } | |||||
| break; | |||||
| case schema::Format::Format_KHWC: | |||||
| return RET_OK; | |||||
| default: | |||||
| MS_LOG(ERROR) << "Unsupported transform from " << EnumNameFormat(srcFormat) << " to " | |||||
| << EnumNameFormat(dstFormat); | |||||
| return RET_ERROR; | |||||
| } | |||||
| } break; | |||||
| case schema::Format::Format_HWCK: { | |||||
| switch (srcFormat) { | |||||
| case schema::Format::Format_KCHW: | |||||
| if (dataType == kNumberTypeFloat32) { | |||||
| status = TransFilterFormat<float>(tensor, kKCHW2HWCK); | |||||
| } else if (dataType == kNumberTypeUInt8) { | |||||
| status = TransFilterFormat<uint8_t>(tensor, kKCHW2HWCK); | |||||
| } else if (dataType == kNumberTypeInt8) { | |||||
| status = TransFilterFormat<int8_t>(tensor, kKCHW2HWCK); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Unsupported dataType: " << dataType; | |||||
| return RET_ERROR; | |||||
| } | |||||
| break; | |||||
| case schema::Format::Format_KHWC: | |||||
| if (dataType == kNumberTypeFloat32) { | |||||
| status = TransFilterFormat<float>(tensor, kKHWC2HWCK); | |||||
| } else if (dataType == kNumberTypeUInt8) { | |||||
| status = TransFilterFormat<uint8_t>(tensor, kKHWC2HWCK); | |||||
| } else if (dataType == kNumberTypeInt8) { | |||||
| status = TransFilterFormat<int8_t>(tensor, kKHWC2HWCK); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Unsupported dataType: " << dataType; | |||||
| return RET_ERROR; | |||||
| } | |||||
| break; | |||||
| case schema::Format::Format_CKHW: | |||||
| if (dataType == kNumberTypeFloat32) { | |||||
| status = TransFilterFormat<float>(tensor, kCKHW2HWCK); | |||||
| } else if (dataType == kNumberTypeUInt8) { | |||||
| status = TransFilterFormat<uint8_t>(tensor, kCKHW2HWCK); | |||||
| } else if (dataType == kNumberTypeInt8) { | |||||
| status = TransFilterFormat<int8_t>(tensor, kCKHW2HWCK); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Unsupported dataType: " << dataType; | |||||
| return RET_ERROR; | |||||
| } | |||||
| break; | |||||
| case schema::Format::Format_CHWK: | |||||
| if (dataType == kNumberTypeFloat32) { | |||||
| status = TransFilterFormat<float>(tensor, kCHWK2HWCK); | |||||
| } else if (dataType == kNumberTypeUInt8) { | |||||
| status = TransFilterFormat<uint8_t>(tensor, kCHWK2HWCK); | |||||
| } else if (dataType == kNumberTypeInt8) { | |||||
| status = TransFilterFormat<int8_t>(tensor, kCHWK2HWCK); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Unsupported dataType: " << dataType; | |||||
| return RET_ERROR; | |||||
| } | |||||
| break; | |||||
| case schema::Format::Format_HWCK: | |||||
| return RET_OK; | |||||
| default: | |||||
| MS_LOG(ERROR) << "Unsupported transform from " << EnumNameFormat(srcFormat) << " to " | |||||
| << EnumNameFormat(dstFormat); | |||||
| return RET_ERROR; | |||||
| } | |||||
| } break; | |||||
| case schema::Format::Format_KCHW: { | |||||
| switch (srcFormat) { | |||||
| case schema::Format::Format_KCHW: | |||||
| return RET_OK; | |||||
| case schema::Format::Format_HWCK: | |||||
| if (dataType == kNumberTypeFloat32) { | |||||
| status = TransFilterFormat<float>(tensor, kHWCK2KCHW); | |||||
| } else if (dataType == kNumberTypeUInt8) { | |||||
| status = TransFilterFormat<uint8_t>(tensor, kHWCK2KCHW); | |||||
| } else if (dataType == kNumberTypeInt8) { | |||||
| status = TransFilterFormat<int8_t>(tensor, kHWCK2KCHW); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Unsupported dataType: " << dataType; | |||||
| return RET_ERROR; | |||||
| } | |||||
| break; | |||||
| case schema::Format::Format_HWKC: | |||||
| if (dataType == kNumberTypeFloat32) { | |||||
| status = TransFilterFormat<float>(tensor, kHWKC2KCHW); | |||||
| } else if (dataType == kNumberTypeUInt8) { | |||||
| status = TransFilterFormat<uint8_t>(tensor, kHWKC2KCHW); | |||||
| } else if (dataType == kNumberTypeInt8) { | |||||
| status = TransFilterFormat<int8_t>(tensor, kHWKC2KCHW); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Unsupported dataType: " << dataType; | |||||
| return RET_ERROR; | |||||
| } | |||||
| break; | |||||
| case schema::Format::Format_KHWC: | |||||
| if (dataType == kNumberTypeFloat32) { | |||||
| status = TransFilterFormat<float>(tensor, kKHWC2KCHW); | |||||
| } else if (dataType == kNumberTypeUInt8) { | |||||
| status = TransFilterFormat<uint8_t>(tensor, kKHWC2KCHW); | |||||
| } else if (dataType == kNumberTypeInt8) { | |||||
| status = TransFilterFormat<int8_t>(tensor, kKHWC2KCHW); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Unsupported dataType: " << dataType; | |||||
| return RET_ERROR; | |||||
| } | |||||
| break; | |||||
| case schema::Format::Format_CKHW: | |||||
| if (dataType == kNumberTypeFloat32) { | |||||
| status = TransFilterFormat<float>(tensor, kCKHW2KCHW); | |||||
| } else if (dataType == kNumberTypeUInt8) { | |||||
| status = TransFilterFormat<uint8_t>(tensor, kCKHW2KCHW); | |||||
| } else if (dataType == kNumberTypeInt8) { | |||||
| status = TransFilterFormat<int8_t>(tensor, kCKHW2KCHW); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Unsupported dataType: " << dataType; | |||||
| return RET_ERROR; | |||||
| } | |||||
| break; | |||||
| case schema::Format::Format_CHWK: | |||||
| if (dataType == kNumberTypeFloat32) { | |||||
| status = TransFilterFormat<float>(tensor, kCHWK2KCHW); | |||||
| } else if (dataType == kNumberTypeUInt8) { | |||||
| status = TransFilterFormat<uint8_t>(tensor, kCHWK2KCHW); | |||||
| } else if (dataType == kNumberTypeInt8) { | |||||
| status = TransFilterFormat<int8_t>(tensor, kCHWK2KCHW); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Unsupported dataType: " << dataType; | |||||
| return RET_ERROR; | |||||
| } | |||||
| break; | |||||
| default: | |||||
| MS_LOG(ERROR) << "Unsupported transform from " << EnumNameFormat(srcFormat) << " to " | |||||
| << EnumNameFormat(dstFormat); | |||||
| return RET_ERROR; | |||||
| } | |||||
| } break; | |||||
| case schema::Format::Format_CKHW: { | |||||
| switch (srcFormat) { | |||||
| case schema::Format::Format_HWCK: | |||||
| if (dataType == kNumberTypeFloat32) { | |||||
| status = TransFilterFormat<float>(tensor, kHWCK2CKHW); | |||||
| } else if (dataType == kNumberTypeUInt8) { | |||||
| status = TransFilterFormat<uint8_t>(tensor, kHWCK2CKHW); | |||||
| } else if (dataType == kNumberTypeInt8) { | |||||
| status = TransFilterFormat<int8_t>(tensor, kHWCK2CKHW); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Unsupported dataType: " << dataType; | |||||
| return RET_ERROR; | |||||
| } | |||||
| break; | |||||
| case schema::Format::Format_HWKC: | |||||
| if (dataType == kNumberTypeFloat32) { | |||||
| status = TransFilterFormat<float>(tensor, kHWKC2CKHW); | |||||
| } else if (dataType == kNumberTypeUInt8) { | |||||
| status = TransFilterFormat<uint8_t>(tensor, kHWKC2CKHW); | |||||
| } else if (dataType == kNumberTypeInt8) { | |||||
| status = TransFilterFormat<int8_t>(tensor, kHWKC2CKHW); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Unsupported dataType: " << dataType; | |||||
| return RET_ERROR; | |||||
| } | |||||
| break; | |||||
| case schema::Format::Format_KCHW: | |||||
| if (dataType == kNumberTypeFloat32) { | |||||
| status = TransFilterFormat<float>(tensor, kKCHW2CKHW); | |||||
| } else if (dataType == kNumberTypeUInt8) { | |||||
| status = TransFilterFormat<uint8_t>(tensor, kKCHW2CKHW); | |||||
| } else if (dataType == kNumberTypeInt8) { | |||||
| status = TransFilterFormat<int8_t>(tensor, kKCHW2CKHW); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Unsupported dataType: " << dataType; | |||||
| return RET_ERROR; | |||||
| } | |||||
| break; | |||||
| case schema::Format::Format_CKHW: | |||||
| return RET_OK; | |||||
| default: | |||||
| MS_LOG(ERROR) << "Unsupported transform from " << EnumNameFormat(srcFormat) << " to " | |||||
| << EnumNameFormat(dstFormat); | |||||
| return RET_ERROR; | |||||
| } | |||||
| } break; | |||||
| case schema::Format::Format_KHWC: | |||||
| convert = Convert2KHWC(srcFormat); | |||||
| break; | |||||
| case schema::Format::Format_HWCK: | |||||
| convert = Convert2HWCK(srcFormat); | |||||
| break; | |||||
| case schema::Format::Format_KCHW: | |||||
| convert = Convert2KCHW(srcFormat); | |||||
| break; | |||||
| case schema::Format::Format_CKHW: | |||||
| convert = Convert2CKHW(srcFormat); | |||||
| break; | |||||
| default: | default: | ||||
| MS_LOG(ERROR) << "Unsupported transform from " << EnumNameFormat(srcFormat) << " to " | |||||
| << EnumNameFormat(dstFormat); | |||||
| return RET_ERROR; | |||||
| convert = -1; | |||||
| } | |||||
| if (convert == -1) { | |||||
| MS_LOG(ERROR) << "Unsupported transform from " << EnumNameFormat(srcFormat) << " to " << EnumNameFormat(dstFormat); | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (dataType == kNumberTypeFloat32) { | |||||
| status = TransFilterFormat<float>(tensor, static_cast<kTransFilterType>(convert)); | |||||
| } else if (dataType == kNumberTypeUInt8) { | |||||
| status = TransFilterFormat<uint8_t>(tensor, static_cast<kTransFilterType>(convert)); | |||||
| } else if (dataType == kNumberTypeInt8) { | |||||
| status = TransFilterFormat<int8_t>(tensor, static_cast<kTransFilterType>(convert)); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Unsupported dataType: " << dataType; | |||||
| return RET_ERROR; | |||||
| } | } | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "TransFilterData failed: " << status; | MS_LOG(ERROR) << "TransFilterData failed: " << status; | ||||