|
|
|
@@ -26,30 +26,28 @@ |
|
|
|
namespace mindspore { |
|
|
|
void AvgPool::set_padding(const std::string &pad) { this->AddAttr("padding", MakeValue(pad)); } |
|
|
|
|
|
|
|
void AvgPool::set_kernel_size(const std::vector<int> &kernel_size) { this->AddAttr("ksize", MakeValue(kernel_size)); } |
|
|
|
|
|
|
|
void AvgPool::set_strides(const std::vector<int> &strides) { this->AddAttr("strides", MakeValue(strides)); } |
|
|
|
|
|
|
|
std::vector<int> AvgPool::get_strides() const { |
|
|
|
auto value_ptr = GetAttr("strides"); |
|
|
|
return GetValue<std::vector<int>>(value_ptr); |
|
|
|
std::string AvgPool::get_padding() const { |
|
|
|
auto value_ptr = GetAttr("padding"); |
|
|
|
return GetValue<std::string>(value_ptr); |
|
|
|
} |
|
|
|
void AvgPool::set_kernel_size(const std::vector<int> &kernel_size) { this->AddAttr("k_size", MakeValue(kernel_size)); } |
|
|
|
|
|
|
|
std::vector<int> AvgPool::get_kernel_size() const { |
|
|
|
auto value_ptr = GetAttr("ksize"); |
|
|
|
auto value_ptr = GetAttr("k_size"); |
|
|
|
return GetValue<std::vector<int>>(value_ptr); |
|
|
|
} |
|
|
|
void AvgPool::set_strides(const std::vector<int> &strides) { this->AddAttr("strides", MakeValue(strides)); } |
|
|
|
|
|
|
|
std::string AvgPool::get_padding() const { |
|
|
|
auto value_ptr = GetAttr("padding"); |
|
|
|
return GetValue<std::string>(value_ptr); |
|
|
|
std::vector<int> AvgPool::get_strides() const { |
|
|
|
auto value_ptr = GetAttr("strides"); |
|
|
|
return GetValue<std::vector<int>>(value_ptr); |
|
|
|
} |
|
|
|
|
|
|
|
void AvgPool::Init(const std::vector<int> &kernel_size, const std::vector<int> &stride, const std::string &padding) { |
|
|
|
auto prim_name = this->name(); |
|
|
|
this->AddAttr("data_format", MakeValue("NCHW")); |
|
|
|
this->set_padding(CheckAndConvertUtils::CheckString("padding", padding, {"valid", "same"}, prim_name)); |
|
|
|
this->set_kernel_size(CheckAndConvertUtils::CheckPositiveVector("ksize", kernel_size, prim_name, false, true)); |
|
|
|
this->set_kernel_size(CheckAndConvertUtils::CheckPositiveVector("k_size", kernel_size, prim_name, false, true)); |
|
|
|
this->set_strides(CheckAndConvertUtils::CheckPositiveVector("strides", stride, this->name(), false, true)); |
|
|
|
} |
|
|
|
|
|
|
|
|