|
|
|
@@ -230,28 +230,28 @@ void Conv2DPadFunction(std::vector<int64_t> *output_hw, std::vector<int64_t> *pa |
|
|
|
const std::vector<int64_t> &dilation, const int64_t &pad_mode, |
|
|
|
const std::vector<int64_t> &padding) { |
|
|
|
if (pad_mode == PadMode::VALID) { |
|
|
|
output_hw->push_back(std::ceil(((x_h * 1.0) - dilation[0] * (kernel[0] - 1)) / stride[0])); |
|
|
|
output_hw->push_back(std::ceil(((x_w * 1.0) - dilation[1] * (kernel[1] - 1)) / stride[1])); |
|
|
|
pad_list->insert(pad_list->begin(), 4, 0); |
|
|
|
output_hw->push_back(static_cast<int64_t>(std::ceil(((x_h * 1.0) - dilation[0] * (kernel[0] - 1)) / stride[0]))); |
|
|
|
output_hw->push_back(static_cast<int64_t>(std::ceil(((x_w * 1.0) - dilation[1] * (kernel[1] - 1)) / stride[1]))); |
|
|
|
(void)pad_list->insert(pad_list->begin(), 4, 0); |
|
|
|
} else if (pad_mode == PadMode::SAME) { |
|
|
|
output_hw->push_back(std::ceil((x_h * 1.0) / stride[0])); |
|
|
|
output_hw->push_back(std::ceil((x_w * 1.0) / stride[1])); |
|
|
|
output_hw->push_back(static_cast<int64_t>(std::ceil((x_h * 1.0) / stride[0]))); |
|
|
|
output_hw->push_back(static_cast<int64_t>(std::ceil((x_w * 1.0) / stride[1]))); |
|
|
|
int64_t pad_needed_h = (output_hw->at(0) - 1) * stride[0] + dilation[0] * (kernel[0] - 1) + 1 - x_h; |
|
|
|
pad_needed_h = std::max((int64_t)0, pad_needed_h); |
|
|
|
pad_list->push_back(std::floor(pad_needed_h / 2)); |
|
|
|
pad_list->push_back(static_cast<int64_t>(std::floor(pad_needed_h / 2))); |
|
|
|
pad_list->push_back(pad_needed_h - pad_list->at(0)); |
|
|
|
int64_t pad_needed_w = (output_hw->at(1) - 1) * stride[1] + dilation[1] * (kernel[1] - 1) + 1 - x_w; |
|
|
|
pad_needed_w = std::max((int64_t)0, pad_needed_w); |
|
|
|
pad_list->push_back(std::floor(pad_needed_w / 2)); |
|
|
|
pad_list->push_back(static_cast<int64_t>(std::floor(pad_needed_w / 2))); |
|
|
|
pad_list->push_back(pad_needed_w - pad_list->at(2)); |
|
|
|
} else if (pad_mode == PadMode::PAD) { |
|
|
|
pad_list->insert(pad_list->begin(), padding.begin(), padding.end()); |
|
|
|
output_hw->push_back(std::floor( |
|
|
|
1 + |
|
|
|
((x_h * 1.0) + pad_list->at(0) + pad_list->at(1) - kernel[0] - (kernel[0] - 1) * (dilation[0] - 1)) / stride[0])); |
|
|
|
output_hw->push_back(std::floor( |
|
|
|
1 + |
|
|
|
((x_w * 1.0) + pad_list->at(2) + pad_list->at(3) - kernel[1] - (kernel[1] - 1) * (dilation[1] - 1)) / stride[1])); |
|
|
|
(void)pad_list->insert(pad_list->begin(), padding.begin(), padding.end()); |
|
|
|
output_hw->push_back(static_cast<int64_t>(std::floor( |
|
|
|
1 + ((x_h * 1.0) + pad_list->at(0) + pad_list->at(1) - kernel[0] - (kernel[0] - 1) * (dilation[0] - 1)) / |
|
|
|
stride[0]))); |
|
|
|
output_hw->push_back(static_cast<int64_t>(std::floor( |
|
|
|
1 + ((x_w * 1.0) + pad_list->at(2) + pad_list->at(3) - kernel[1] - (kernel[1] - 1) * (dilation[1] - 1)) / |
|
|
|
stride[1]))); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@@ -279,10 +279,10 @@ AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &p |
|
|
|
CheckShapeAnyAndPositive(op_name + " w_shape", w_shape); |
|
|
|
CheckShapeAllPositive(op_name + " w_min_shape", w_min_shape); |
|
|
|
CheckShapeAllPositive(op_name + " w_max_shape", w_max_shape); |
|
|
|
const int64_t n_axis = 0; |
|
|
|
int64_t c_axis = 1; |
|
|
|
int64_t h_axis = 2; |
|
|
|
int64_t w_axis = 3; |
|
|
|
const uint64_t n_axis = 0; |
|
|
|
uint64_t c_axis = 1; |
|
|
|
uint64_t h_axis = 2; |
|
|
|
uint64_t w_axis = 3; |
|
|
|
int64_t data_format = GetAndCheckFormat(primitive->GetAttr("format")); |
|
|
|
if (data_format == Format::NHWC) { |
|
|
|
c_axis = 3; |
|
|
|
|