|
|
|
@@ -23,12 +23,6 @@ |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace abstract { |
|
|
|
// Constants for slicing Conv-like attribute tuples (strides / dilations / paddings).
// NOTE(review): the consuming code lies outside this hunk; the names suggest
// {element count, start index} pairs — e.g. strides/dilations store their two
// spatial (H, W) values starting at tuple index 2, while the 4-element pad list
// (top, bottom, left, right) starts at index 0. Confirm at the use sites.
const size_t stride_num_element = 2;

const size_t stride_start_idx = 2;

const size_t dilation_num_element = 2;

const size_t dilation_start_idx = 2;

const size_t padding_num_element = 4;

const size_t padding_start_idx = 0;
|
|
|
int64_t GetAndCheckFormat(const ValuePtr &value) { |
|
|
|
int64_t data_format; |
|
|
|
bool result = CheckAndConvertUtils::GetDataFormatEnumValue(value, &data_format); |
|
|
|
@@ -82,7 +76,7 @@ AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr & |
|
|
|
auto pad_mode_ptr = primitive->GetAttr("pad_mode"); |
|
|
|
if (pad_mode_ptr != nullptr) { |
|
|
|
int64_t pad_mode; |
|
|
|
const size_t middle = 2; |
|
|
|
const int64_t middle = 2; |
|
|
|
CheckAndConvertUtils::GetPadModEnumValue(pad_mode_ptr, &pad_mode, true); |
|
|
|
if (pad_mode == static_cast<int64_t>(PadMode::VALID)) { |
|
|
|
padding = 0; |
|
|
|
@@ -98,7 +92,7 @@ AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr & |
|
|
|
MS_LOG(EXCEPTION) << "Unsupported pooling mode: " << mode << "."; |
|
|
|
} |
|
|
|
} |
|
|
|
const size_t twice = 2; |
|
|
|
const int64_t twice = 2; |
|
|
|
int64_t h_out = (((h_input + twice * padding - (window - 1)) - 1) / stride) + 1; |
|
|
|
int64_t w_out = (((w_input + twice * padding - (window - 1)) - 1) / stride) + 1; |
|
|
|
ShapeVector shape_out = {input_shape->shape()[0], input_shape->shape()[1], h_out, w_out}; |
|
|
|
@@ -203,50 +197,6 @@ AbstractBasePtr InferImplFusedSparseAdam(const AnalysisEnginePtr &, const Primit |
|
|
|
return std::make_shared<AbstractTuple>(rets); |
|
|
|
} |
|
|
|
|
|
|
|
// Computes the output spatial size and explicit pad list of a Conv2D.
//
// Appends {h_out, w_out} to *output_hw and the four pad values
// [pad_top, pad_bottom, pad_left, pad_right] to *pad_list, according to
// pad_mode:
//   VALID: no padding; out = ceil((x - dilation*(kernel-1)) / stride).
//   SAME:  out = ceil(x / stride); the padding needed to achieve that is
//          split evenly, with the odd pixel going to the bottom/right side.
//   PAD:   the caller-supplied `padding` [top, bottom, left, right] is used
//          and the output size follows the standard dilated-conv formula.
//
// x_h / x_w:                 input spatial dimensions.
// kernel / stride / dilation: two-element {H, W} vectors.
// pad_mode:                  numeric value of a PadMode enumerator.
// padding:                   explicit pads; read only when pad_mode == PAD.
void Conv2DPadFunction(std::vector<int64_t> *output_hw, std::vector<int64_t> *pad_list, const int64_t x_h,
                       const int64_t x_w, const std::vector<int64_t> &kernel, const std::vector<int64_t> &stride,
                       const std::vector<int64_t> &dilation, const int64_t &pad_mode,
                       const std::vector<int64_t> &padding) {
  const int64_t middle = 2;
  const size_t second_index = 2;
  const size_t third_index = 3;
  if (pad_mode == static_cast<int64_t>(PadMode::VALID)) {
    output_hw->push_back(static_cast<int64_t>(std::ceil(((x_h * 1.0) - dilation[0] * (kernel[0] - 1)) / stride[0])));
    output_hw->push_back(static_cast<int64_t>(std::ceil(((x_w * 1.0) - dilation[1] * (kernel[1] - 1)) / stride[1])));
    // All four pad entries (top, bottom, left, right) are zero in VALID mode.
    const size_t pad_entry_count = 4;
    (void)pad_list->insert(pad_list->begin(), pad_entry_count, 0);
  } else if (pad_mode == static_cast<int64_t>(PadMode::SAME)) {
    output_hw->push_back(static_cast<int64_t>(std::ceil((x_h * 1.0) / stride[0])));
    output_hw->push_back(static_cast<int64_t>(std::ceil((x_w * 1.0) / stride[1])));
    int64_t pad_needed_h = (output_hw->at(0) - 1) * stride[0] + dilation[0] * (kernel[0] - 1) + 1 - x_h;
    pad_needed_h = std::max<int64_t>(0, pad_needed_h);
    // pad_needed_h >= 0, so integer division already floors; the remainder
    // (odd pixel) lands on the bottom via the subtraction below.
    pad_list->push_back(pad_needed_h / middle);
    pad_list->push_back(pad_needed_h - pad_list->at(0));
    int64_t pad_needed_w = (output_hw->at(1) - 1) * stride[1] + dilation[1] * (kernel[1] - 1) + 1 - x_w;
    pad_needed_w = std::max<int64_t>(0, pad_needed_w);
    pad_list->push_back(pad_needed_w / middle);
    pad_list->push_back(pad_needed_w - pad_list->at(middle));
  } else if (pad_mode == static_cast<int64_t>(PadMode::PAD)) {
    (void)pad_list->insert(pad_list->begin(), padding.begin(), padding.end());
    // out = floor(1 + (x + pad_before + pad_after - kernel - (kernel-1)*(dilation-1)) / stride)
    output_hw->push_back(static_cast<int64_t>(std::floor(
      1 + (((x_h * 1.0) + pad_list->at(0) + pad_list->at(1) - kernel[0]) - (kernel[0] - 1) * (dilation[0] - 1)) /
            stride[0])));
    output_hw->push_back(static_cast<int64_t>(
      std::floor(1 + (((x_w * 1.0) + pad_list->at(second_index) + pad_list->at(third_index) - kernel[1]) -
                      (kernel[1] - 1) * (dilation[1] - 1)) /
                       stride[1])));
  }
}
|
|
|
|
|
|
|
void CheckShape(const std::string &op_name, const ShapeVector &w_shape, const AbstractTensorPtr &input_w) { |
|
|
|
ShapeVector w_min_shape = input_w->shape()->min_shape(); |
|
|
|
ShapeVector w_max_shape = input_w->shape()->max_shape(); |
|
|
|
CheckMinMaxShape(w_shape, &w_min_shape, &w_max_shape); |
|
|
|
CheckShapeAnyAndPositive(op_name + " w_shape", w_shape); |
|
|
|
CheckShapeAllPositive(op_name + " w_min_shape", w_min_shape); |
|
|
|
CheckShapeAllPositive(op_name + " w_max_shape", w_max_shape); |
|
|
|
} |
|
|
|
|
|
|
|
AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, |
|
|
|
const AbstractBasePtrList &args_spec_list) { |
|
|
|
// Inputs: at least one tensor(y_backprop) |
|
|
|
|