|
|
|
@@ -67,19 +67,8 @@ int64_t windowed_output_size(int64_t input_size, int64_t ksize, int64_t stride, |
|
|
|
return output; |
|
|
|
} |
|
|
|
|
|
|
|
ValueNodePtr CreateMeanMatrixValueNode(const FuncGraphPtr &func_graph, const std::vector<int64_t> &x_shape, |
|
|
|
const std::vector<int64_t> &k_size, const std::vector<int64_t> &stride, |
|
|
|
const PadMode pad_mode, const TypeId x_dtype) { |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
|
auto kernel_graph = func_graph->cast<KernelGraphPtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_graph); |
|
|
|
if (x_shape.size() != kShapeDimNum || k_size.size() != kShapeDimNum || stride.size() != kShapeDimNum) { |
|
|
|
MS_LOG(EXCEPTION) << "The dim of x_shape or kernel_size or strides of AvgPoolGrad should be 4."; |
|
|
|
} |
|
|
|
int64_t pad_top, pad_bottom, pad_left, pad_right; |
|
|
|
int64_t h_output = windowed_output_size(x_shape[2], k_size[2], stride[2], pad_mode, &pad_top, &pad_bottom); |
|
|
|
int64_t w_output = windowed_output_size(x_shape[3], k_size[3], stride[3], pad_mode, &pad_left, &pad_right); |
|
|
|
|
|
|
|
std::vector<std::vector<float>> GetAssistInputMatrix(const std::vector<int64_t> &x_shape, int64_t pad_top, |
|
|
|
int64_t pad_bottom, int64_t pad_left, int64_t pad_right) { |
|
|
|
// `assist_input_matrix` is a 2d matrix with input_shape after padding, |
|
|
|
// the value of element which is padded is 0, else are 1. |
|
|
|
// For each element of output, it is mapped for slide window: `[h*h_stride : h*h_stride + h_ksize, |
|
|
|
@@ -102,6 +91,22 @@ ValueNodePtr CreateMeanMatrixValueNode(const FuncGraphPtr &func_graph, const std |
|
|
|
assist_input_matrix.emplace_back(tmp_one_vector); |
|
|
|
} |
|
|
|
} |
|
|
|
return assist_input_matrix; |
|
|
|
} |
|
|
|
|
|
|
|
ValueNodePtr CreateMeanMatrixValueNode(const FuncGraphPtr &func_graph, const std::vector<int64_t> &x_shape, |
|
|
|
const std::vector<int64_t> &k_size, const std::vector<int64_t> &stride, |
|
|
|
const PadMode pad_mode, const TypeId x_dtype) { |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
|
auto kernel_graph = func_graph->cast<KernelGraphPtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_graph); |
|
|
|
if (x_shape.size() != kShapeDimNum || k_size.size() != kShapeDimNum || stride.size() != kShapeDimNum) { |
|
|
|
MS_LOG(EXCEPTION) << "The dim of x_shape or kernel_size or strides of AvgPoolGrad should be 4."; |
|
|
|
} |
|
|
|
int64_t pad_top, pad_bottom, pad_left, pad_right; |
|
|
|
int64_t h_output = windowed_output_size(x_shape[2], k_size[2], stride[2], pad_mode, &pad_top, &pad_bottom); |
|
|
|
int64_t w_output = windowed_output_size(x_shape[3], k_size[3], stride[3], pad_mode, &pad_left, &pad_right); |
|
|
|
auto assist_input_matrix = GetAssistInputMatrix(x_shape, pad_top, pad_bottom, pad_left, pad_right); |
|
|
|
|
|
|
|
// calculate output |
|
|
|
std::vector<float> hw_output(h_output * w_output, 0.0); |
|
|
|
|