|
|
|
@@ -268,200 +268,109 @@ AbstractBasePtr InferImplFusedSparseAdam(const AnalysisEnginePtr &, const Primit |
|
|
|
return std::make_shared<AbstractTuple>(rets); |
|
|
|
} |
|
|
|
|
|
|
|
// Computes the output height/width and the padding list of a 2D convolution for the given pad mode.
//
// output_hw: two values are appended: the output height, then the output width.
// pad_list:  four values are produced, two for the height axis followed by two for the width axis.
//            "valid" and "pad" insert them at the front of the vector, "same" appends them.
// x_h, x_w:  input height and width.
// kernel, stride, dilation: element [0] is applied to the height axis and element [1] to the width axis.
//            Each must hold at least two elements and stride must be non-zero; neither is checked here.
// pad_mode:  "valid", "same" or "pad". Any other value leaves output_hw and pad_list untouched.
// padding:   only read when pad_mode == "pad"; must hold at least four elements, otherwise at() throws
//            std::out_of_range.
//
// All values are derived from the arguments and local variables only, never by re-reading absolute positions
// of output_hw / pad_list, so the results do not depend on what those vectors contain on entry.
void Conv2DPadFunction(std::vector<int64_t> *output_hw, std::vector<int64_t> *pad_list, const int64_t x_h,
                       const int64_t x_w, const std::vector<int64_t> &kernel, const std::vector<int64_t> &stride,
                       const std::vector<int64_t> &dilation, const std::string &pad_mode,
                       const std::vector<int64_t> &padding) {
  if (pad_mode == "valid") {
    // No padding: only positions where the dilated kernel fits entirely inside the input are counted.
    const double span_h = (x_h * 1.0) - dilation[0] * (kernel[0] - 1);
    const double span_w = (x_w * 1.0) - dilation[1] * (kernel[1] - 1);
    output_hw->push_back(static_cast<int64_t>(std::ceil(span_h / stride[0])));
    output_hw->push_back(static_cast<int64_t>(std::ceil(span_w / stride[1])));
    pad_list->insert(pad_list->begin(), 4, 0);
  } else if (pad_mode == "same") {
    // Output size depends on the stride only; the padding is whatever is needed to reach that size.
    const int64_t out_h = static_cast<int64_t>(std::ceil((x_h * 1.0) / stride[0]));
    const int64_t out_w = static_cast<int64_t>(std::ceil((x_w * 1.0) / stride[1]));
    output_hw->push_back(out_h);
    output_hw->push_back(out_w);

    // Total padding per axis, clamped at zero when the input already covers the kernel footprint.
    const int64_t pad_needed_h =
      std::max<int64_t>(0, (out_h - 1) * stride[0] + dilation[0] * (kernel[0] - 1) + 1 - x_h);
    const int64_t pad_needed_w =
      std::max<int64_t>(0, (out_w - 1) * stride[1] + dilation[1] * (kernel[1] - 1) + 1 - x_w);

    // Split each total in two; integer division rounds down, so an odd remainder goes to the second value.
    const int64_t pad_h_first = pad_needed_h / 2;
    const int64_t pad_w_first = pad_needed_w / 2;
    pad_list->push_back(pad_h_first);
    pad_list->push_back(pad_needed_h - pad_h_first);
    pad_list->push_back(pad_w_first);
    pad_list->push_back(pad_needed_w - pad_w_first);
  } else if (pad_mode == "pad") {
    // Explicit padding supplied by the caller is forwarded unchanged.
    pad_list->insert(pad_list->begin(), padding.begin(), padding.end());
    const double padded_h = (x_h * 1.0) + padding.at(0) + padding.at(1);
    const double padded_w = (x_w * 1.0) + padding.at(2) + padding.at(3);
    // kernel + (kernel - 1) * (dilation - 1) is the extent of the dilated kernel along one axis.
    output_hw->push_back(static_cast<int64_t>(
      std::floor(1 + (padded_h - kernel[0] - (kernel[0] - 1) * (dilation[0] - 1)) / stride[0])));
    output_hw->push_back(static_cast<int64_t>(
      std::floor(1 + (padded_w - kernel[1] - (kernel[1] - 1) * (dilation[1] - 1)) / stride[1])));
  }
}
|
|
|
|
|
|
|
AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &primitive, |
|
|
|
const AbstractBasePtrList &args_spec_list) { |
|
|
|
const std::string op_name = primitive->name(); |
|
|
|
CheckArgsSize(op_name, args_spec_list, 2); |
|
|
|
|
|
|
|
AbstractTensorPtr input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); |
|
|
|
MS_EXCEPTION_IF_NULL(input_x); |
|
|
|
MS_EXCEPTION_IF_NULL(input_x->shape()); |
|
|
|
ShapeVector x_shape = input_x->shape()->shape(); |
|
|
|
ShapeVector x_min_shape = input_x->shape()->min_shape(); |
|
|
|
ShapeVector x_max_shape = input_x->shape()->max_shape(); |
|
|
|
(void)CheckMinMaxShape(x_shape, &x_min_shape, &x_max_shape); |
|
|
|
for (size_t i = 0; i < x_shape.size(); ++i) { |
|
|
|
if ((x_shape[i] < 0) && (x_shape[i] != Shape::SHP_ANY)) { |
|
|
|
MS_LOG(EXCEPTION) << "Shape element x_shape[" << i << "] must be positive integer, but got " << x_shape[i]; |
|
|
|
} |
|
|
|
if (x_min_shape[i] < 0) { |
|
|
|
MS_LOG(EXCEPTION) << "Min Shape element x_min_shape[" << i << "] must be positive integer, but got " |
|
|
|
<< x_min_shape[i]; |
|
|
|
} |
|
|
|
if (x_max_shape[i] < 0) { |
|
|
|
MS_LOG(EXCEPTION) << "Max Shape element x_max_shape[" << i << "] must be positive integer, but got " |
|
|
|
<< x_max_shape[i]; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
CheckMinMaxShape(x_shape, &x_min_shape, &x_max_shape); |
|
|
|
CheckShapeAnyAndPositive(op_name + " x_shape", x_shape); |
|
|
|
CheckShapeAllPositive(op_name + " x_min_shape", x_min_shape); |
|
|
|
CheckShapeAllPositive(op_name + " x_max_shape", x_max_shape); |
|
|
|
AbstractTensorPtr input_w = CheckArg<AbstractTensor>(op_name, args_spec_list, 1); |
|
|
|
MS_EXCEPTION_IF_NULL(input_w); |
|
|
|
MS_EXCEPTION_IF_NULL(input_w->shape()); |
|
|
|
ShapeVector w_shape = input_w->shape()->shape(); |
|
|
|
ShapeVector w_min_shape = input_w->shape()->min_shape(); |
|
|
|
ShapeVector w_max_shape = input_w->shape()->max_shape(); |
|
|
|
(void)CheckMinMaxShape(w_shape, &w_min_shape, &w_max_shape); |
|
|
|
for (size_t i = 0; i < w_shape.size(); ++i) { |
|
|
|
if ((w_shape[i] < 0) && (w_shape[i] != Shape::SHP_ANY)) { |
|
|
|
MS_LOG(EXCEPTION) << "Shape element w_shape[" << i << "] must be positive integer, but got " << w_shape[i]; |
|
|
|
} |
|
|
|
if (w_min_shape[i] < 0) { |
|
|
|
MS_LOG(EXCEPTION) << "Min Shape element w_min_shape[" << i << "] must be positive integer, but got " |
|
|
|
<< w_min_shape[i]; |
|
|
|
} |
|
|
|
if (w_max_shape[i] < 0) { |
|
|
|
MS_LOG(EXCEPTION) << "Max Shape element w_max_shape[" << i << "] must be positive integer, but got " |
|
|
|
<< w_max_shape[i]; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
std::set<std::string> available_data_format{"NCHW", "NHWC"}; |
|
|
|
CheckMinMaxShape(w_shape, &w_min_shape, &w_max_shape); |
|
|
|
CheckShapeAnyAndPositive(op_name + " w_shape", w_shape); |
|
|
|
CheckShapeAllPositive(op_name + " w_min_shape", w_min_shape); |
|
|
|
CheckShapeAllPositive(op_name + " w_max_shape", w_max_shape); |
|
|
|
std::string data_format = CheckAttrStringSet(op_name, primitive->GetAttr("format"), "format", {"NCHW", "NHWC"}); |
|
|
|
int64_t n_axis = 0; |
|
|
|
int64_t c_axis = 1; |
|
|
|
int64_t h_axis = 2; |
|
|
|
int64_t w_axis = 3; |
|
|
|
auto data_format_ptr = primitive->GetAttr("format"); |
|
|
|
std::string data_format = "NCHW"; |
|
|
|
if ((data_format_ptr != nullptr) && data_format_ptr->isa<StringImm>()) { |
|
|
|
data_format = data_format_ptr->cast<StringImmPtr>()->value(); |
|
|
|
if (available_data_format.find(data_format) == available_data_format.end()) { |
|
|
|
MS_LOG(EXCEPTION) << "Unsupported data format: " << data_format << ". use NCHW or NHWC"; |
|
|
|
} |
|
|
|
if (data_format == "NHWC") { |
|
|
|
c_axis = 3; |
|
|
|
h_axis = 1; |
|
|
|
w_axis = 2; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
int64_t group = primitive->GetAttr("group")->cast<Int64ImmPtr>()->value(); |
|
|
|
if (group <= 0) { |
|
|
|
MS_LOG(EXCEPTION) << "Invalid group value: " << group << ", should be greater then 0"; |
|
|
|
if (data_format == "NHWC") { |
|
|
|
c_axis = 3; |
|
|
|
h_axis = 1; |
|
|
|
w_axis = 2; |
|
|
|
} |
|
|
|
int64_t group = CheckAttrPositiveInt64(op_name, primitive->GetAttr("group"), "group"); |
|
|
|
if ((x_shape[c_axis] != Shape::SHP_ANY) && (x_shape[c_axis] % group != 0)) { |
|
|
|
MS_LOG(EXCEPTION) << "x_shape[" << c_axis << "] = " << x_shape[c_axis] |
|
|
|
<< " (channels) must be divisible by group = " << group; |
|
|
|
} |
|
|
|
|
|
|
|
int64_t out_channel = primitive->GetAttr("out_channel")->cast<Int64ImmPtr>()->value(); |
|
|
|
if (out_channel <= 0) { |
|
|
|
MS_LOG(EXCEPTION) << "Invalid out_channel value: " << out_channel << ", should be greater then 0"; |
|
|
|
} |
|
|
|
if ((w_shape[0] != Shape::SHP_ANY) && (w_shape[0] != out_channel)) { |
|
|
|
MS_LOG(EXCEPTION) << "w_shape[0] = " << w_shape[0] << " must equal to = " << out_channel; |
|
|
|
} |
|
|
|
|
|
|
|
int64_t kernel_h = 0; |
|
|
|
int64_t kernel_w = 0; |
|
|
|
ValuePtr kernel_size_attr = primitive->GetAttr("kernel_size"); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_size_attr); |
|
|
|
if (kernel_size_attr->isa<ValueTuple>()) { |
|
|
|
std::vector<ValuePtr> kernel_size_vec = kernel_size_attr->cast<ValueTuplePtr>()->value(); |
|
|
|
kernel_h = GetValue<int64_t>(kernel_size_vec[0]); |
|
|
|
kernel_w = GetValue<int64_t>(kernel_size_vec[1]); |
|
|
|
} else { |
|
|
|
int64_t kernel_size = kernel_size_attr->cast<Int64ImmPtr>()->value(); |
|
|
|
kernel_h = kernel_size; |
|
|
|
kernel_w = kernel_size; |
|
|
|
} |
|
|
|
if ((w_shape[2] != Shape::SHP_ANY) && (w_shape[2] != kernel_h)) { |
|
|
|
MS_LOG(EXCEPTION) << "weight height, w_shape[2] = " << w_shape[2] << ", must equal to = " << kernel_h; |
|
|
|
int64_t out_channel = CheckAttrPositiveInt64(op_name, primitive->GetAttr("out_channel"), "out_channel"); |
|
|
|
if ((w_shape[n_axis] != Shape::SHP_ANY) && (w_shape[n_axis] != out_channel)) { |
|
|
|
MS_LOG(EXCEPTION) << "w_shape[" << n_axis << "] = " << w_shape[n_axis] << " must equal to = " << out_channel; |
|
|
|
} |
|
|
|
if ((w_shape[3] != Shape::SHP_ANY) && (w_shape[3] != kernel_w)) { |
|
|
|
MS_LOG(EXCEPTION) << "weight width, w_shape[3] = " << w_shape[3] << ", must equal to = " << kernel_w; |
|
|
|
std::vector<int64_t> kernel_size = CheckAttrIntOrTuple(op_name, primitive->GetAttr("kernel_size"), 0, 2); |
|
|
|
if ((w_shape[h_axis] != Shape::SHP_ANY) && (w_shape[h_axis] != kernel_size[0])) { |
|
|
|
MS_LOG(EXCEPTION) << "weight height = " << w_shape[h_axis] << ", must equal to = " << kernel_size[0]; |
|
|
|
} |
|
|
|
|
|
|
|
int64_t stride_h = 0; |
|
|
|
int64_t stride_w = 0; |
|
|
|
ValuePtr stride_attr = primitive->GetAttr("stride"); |
|
|
|
MS_EXCEPTION_IF_NULL(stride_attr); |
|
|
|
if (stride_attr->isa<ValueTuple>()) { |
|
|
|
std::vector<ValuePtr> stride_vec = stride_attr->cast<ValueTuplePtr>()->value(); |
|
|
|
stride_h = GetValue<int64_t>(stride_vec[2]); |
|
|
|
stride_w = GetValue<int64_t>(stride_vec[3]); |
|
|
|
} else { |
|
|
|
int64_t stride = stride_attr->cast<Int64ImmPtr>()->value(); |
|
|
|
stride_h = stride; |
|
|
|
stride_w = stride; |
|
|
|
} |
|
|
|
|
|
|
|
int64_t dilation_h = 0; |
|
|
|
int64_t dilation_w = 0; |
|
|
|
ValuePtr dilation_attr = primitive->GetAttr("dilation"); |
|
|
|
MS_EXCEPTION_IF_NULL(dilation_attr); |
|
|
|
if (dilation_attr->isa<ValueTuple>()) { |
|
|
|
std::vector<ValuePtr> dilation_vec = dilation_attr->cast<ValueTuplePtr>()->value(); |
|
|
|
dilation_h = GetValue<int64_t>(dilation_vec[2]); |
|
|
|
dilation_w = GetValue<int64_t>(dilation_vec[3]); |
|
|
|
} else { |
|
|
|
int64_t dilation = dilation_attr->cast<Int64ImmPtr>()->value(); |
|
|
|
dilation_h = dilation; |
|
|
|
dilation_w = dilation; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<int64_t> padding; |
|
|
|
ValuePtr padding_attr = primitive->GetAttr("pad"); |
|
|
|
MS_EXCEPTION_IF_NULL(padding_attr); |
|
|
|
if (padding_attr->isa<ValueTuple>()) { |
|
|
|
std::vector<ValuePtr> padding_vec = padding_attr->cast<ValueTuplePtr>()->value(); |
|
|
|
(void)std::transform(std::begin(padding_vec), std::end(padding_vec), std::back_inserter(padding), |
|
|
|
[](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); }); |
|
|
|
} else { |
|
|
|
int64_t padding_val = padding_attr->cast<Int64ImmPtr>()->value(); |
|
|
|
padding = {padding_val, padding_val, padding_val, padding_val}; |
|
|
|
if ((w_shape[w_axis] != Shape::SHP_ANY) && (w_shape[w_axis] != kernel_size[1])) { |
|
|
|
MS_LOG(EXCEPTION) << "weight width = " << w_shape[w_axis] << ", must equal to = " << kernel_size[1]; |
|
|
|
} |
|
|
|
|
|
|
|
std::set<std::string> available_pad_mode{"pad", "same", "valid"}; |
|
|
|
ValuePtr pad_mode_attr = primitive->GetAttr("pad_mode"); |
|
|
|
MS_EXCEPTION_IF_NULL(pad_mode_attr); |
|
|
|
auto pad_mode = pad_mode_attr->cast<StringImmPtr>()->value(); |
|
|
|
if (available_pad_mode.find(pad_mode) == available_pad_mode.end()) { |
|
|
|
MS_LOG(EXCEPTION) << "Unsupported pad mode: " << pad_mode << ". use pad, same, valid"; |
|
|
|
} |
|
|
|
|
|
|
|
std::function<void(int64_t, int64_t, std::vector<int64_t> &, std::vector<int64_t> &)> pad_function = |
|
|
|
[kernel_h, kernel_w, stride_h, stride_w, dilation_h, dilation_w, pad_mode, padding]( |
|
|
|
int64_t x_h, int64_t x_w, std::vector<int64_t> &output_hw, std::vector<int64_t> &pad_list) { |
|
|
|
if (pad_mode == "valid") { |
|
|
|
output_hw.push_back(std::ceil(((x_h * 1.0) - dilation_h * (kernel_h - 1)) / stride_h)); |
|
|
|
output_hw.push_back(std::ceil(((x_w * 1.0) - dilation_w * (kernel_w - 1)) / stride_w)); |
|
|
|
pad_list = {0, 0, 0, 0}; |
|
|
|
} else if (pad_mode == "same") { |
|
|
|
output_hw.push_back(std::ceil((x_h * 1.0) / stride_h)); |
|
|
|
output_hw.push_back(std::ceil((x_w * 1.0) / stride_w)); |
|
|
|
int64_t pad_needed_h = (output_hw[0] - 1) * stride_h + dilation_h * (kernel_h - 1) + 1 - x_h; |
|
|
|
pad_needed_h = std::max((int64_t)0, pad_needed_h); |
|
|
|
pad_list.push_back(std::floor(pad_needed_h / 2)); |
|
|
|
pad_list.push_back(pad_needed_h - pad_list[0]); |
|
|
|
int64_t pad_needed_w = (output_hw[1] - 1) * stride_w + dilation_w * (kernel_w - 1) + 1 - x_w; |
|
|
|
pad_needed_w = std::max((int64_t)0, pad_needed_w); |
|
|
|
pad_list.push_back(std::floor(pad_needed_w / 2)); |
|
|
|
pad_list.push_back(pad_needed_w - pad_list[2]); |
|
|
|
} else if (pad_mode == "pad") { |
|
|
|
pad_list = padding; |
|
|
|
output_hw.push_back(std::floor( |
|
|
|
1 + ((x_h * 1.0) + pad_list[0] + pad_list[1] - kernel_h - (kernel_h - 1) * (dilation_h - 1)) / stride_h)); |
|
|
|
output_hw.push_back(std::floor( |
|
|
|
1 + ((x_w * 1.0) + pad_list[2] + pad_list[3] - kernel_w - (kernel_w - 1) * (dilation_w - 1)) / stride_w)); |
|
|
|
} |
|
|
|
}; |
|
|
|
|
|
|
|
std::vector<int64_t> stride = CheckAttrIntOrTuple(op_name, primitive->GetAttr("stride"), 2, 2); |
|
|
|
std::vector<int64_t> dilation = CheckAttrIntOrTuple(op_name, primitive->GetAttr("dilation"), 2, 2); |
|
|
|
std::vector<int64_t> padding = CheckAttrIntOrTuple(op_name, primitive->GetAttr("pad"), 0, 4); |
|
|
|
std::string pad_mode = |
|
|
|
CheckAttrStringSet(op_name, primitive->GetAttr("pad_mode"), "pad_mode", {"pad", "same", "valid"}); |
|
|
|
std::vector<int64_t> output_hw; |
|
|
|
std::vector<int64_t> pad_list; |
|
|
|
std::vector<int64_t> output_hw_min; |
|
|
|
std::vector<int64_t> pad_list_min; |
|
|
|
std::vector<int64_t> output_hw_max; |
|
|
|
std::vector<int64_t> pad_list_max; |
|
|
|
pad_function(x_shape[h_axis], x_shape[w_axis], output_hw, pad_list); |
|
|
|
Conv2DPadFunction(&output_hw, &pad_list, x_shape[h_axis], x_shape[w_axis], kernel_size, stride, dilation, pad_mode, |
|
|
|
padding); |
|
|
|
if (x_shape[h_axis] == Shape::SHP_ANY) { |
|
|
|
output_hw[0] = Shape::SHP_ANY; |
|
|
|
} |
|
|
|
if (x_shape[w_axis] == Shape::SHP_ANY) { |
|
|
|
output_hw[1] = Shape::SHP_ANY; |
|
|
|
} |
|
|
|
pad_function(x_min_shape[h_axis], x_min_shape[w_axis], output_hw_min, pad_list_min); |
|
|
|
pad_function(x_max_shape[h_axis], x_max_shape[w_axis], output_hw_max, pad_list_max); |
|
|
|
|
|
|
|
Conv2DPadFunction(&output_hw_min, &pad_list_min, x_min_shape[h_axis], x_min_shape[w_axis], kernel_size, stride, |
|
|
|
dilation, pad_mode, padding); |
|
|
|
Conv2DPadFunction(&output_hw_max, &pad_list_max, x_max_shape[h_axis], x_max_shape[w_axis], kernel_size, stride, |
|
|
|
dilation, pad_mode, padding); |
|
|
|
std::vector<ValuePtr> pad_list_val = {MakeValue(pad_list[0]), MakeValue(pad_list[1]), MakeValue(pad_list[2]), |
|
|
|
MakeValue(pad_list[3])}; |
|
|
|
primitive->set_attr("pad_list", MakeValue(pad_list_val)); |
|
|
|
@@ -477,28 +386,15 @@ AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &p |
|
|
|
output_shape_min = {x_min_shape[n_axis], out_channel, output_hw_min[0], output_hw_min[1]}; |
|
|
|
output_shape_max = {x_max_shape[n_axis], out_channel, output_hw_max[0], output_hw_max[1]}; |
|
|
|
} |
|
|
|
for (size_t i = 0; i < output_shape.size(); ++i) { |
|
|
|
if ((output_shape[i] < 0) && (output_shape[i] != Shape::SHP_ANY)) { |
|
|
|
MS_LOG(EXCEPTION) << "Shape element output_shape[" << i << "] must be positive integer, but got " |
|
|
|
<< output_shape[i]; |
|
|
|
} |
|
|
|
if (output_shape_min[i] < 0) { |
|
|
|
MS_LOG(EXCEPTION) << "Min Shape element output_shape_min[" << i << "] must be positive integer, but got " |
|
|
|
<< output_shape_min[i]; |
|
|
|
} |
|
|
|
if (output_shape_max[i] < 0) { |
|
|
|
MS_LOG(EXCEPTION) << "Max Shape element output_shape_max[" << i << "] must be positive integer, but got " |
|
|
|
<< output_shape_max[i]; |
|
|
|
} |
|
|
|
CheckShapeAnyAndPositive(op_name + " output_shape", output_shape); |
|
|
|
CheckShapeAllPositive(op_name + " output_shape_min", output_shape_min); |
|
|
|
CheckShapeAllPositive(op_name + " output_shape_max", output_shape_max); |
|
|
|
TypePtr x_type = input_x->element()->GetTypeTrack(); |
|
|
|
if (x_type->type_id() == TypeId::kNumberTypeInt8) { |
|
|
|
x_type = kInt32; |
|
|
|
} |
|
|
|
|
|
|
|
ShapePtr output_shape_ptr = std::make_shared<Shape>(output_shape, output_shape_min, output_shape_max); |
|
|
|
if (input_x->element()->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt8) { |
|
|
|
auto output = std::make_shared<AbstractTensor>(kInt32, output_shape); |
|
|
|
output->set_shape(output_shape_ptr); |
|
|
|
return output; |
|
|
|
} |
|
|
|
return std::make_shared<AbstractTensor>(input_x->element(), output_shape_ptr); |
|
|
|
return std::make_shared<AbstractTensor>(x_type, output_shape_ptr); |
|
|
|
} |
|
|
|
|
|
|
|
AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const PrimitivePtr &primitive, |
|
|
|
|