From: @lianliguang Reviewed-by: @zh_qh,@ginfung Signed-off-by: @ginfungpull/13727/MERGE
| @@ -49,8 +49,6 @@ AbstractBasePtr InferImplFusedSparseAdam(const AnalysisEnginePtr &, const Primit | |||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| @@ -36,8 +36,8 @@ PadMode AvgPool::get_pad_mode() const { | |||||
| return PadMode(GetValue<int64_t>(value_ptr)); | return PadMode(GetValue<int64_t>(value_ptr)); | ||||
| } | } | ||||
| void AvgPool::set_kernel_size(const std::vector<int64_t> &kernel_size) { | void AvgPool::set_kernel_size(const std::vector<int64_t> &kernel_size) { | ||||
| this->AddAttr(kKernelSize, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, this->name(), | |||||
| false, true))); | |||||
| this->AddAttr(kKernelSize, | |||||
| MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, this->name()))); | |||||
| } | } | ||||
| std::vector<int64_t> AvgPool::get_kernel_size() const { | std::vector<int64_t> AvgPool::get_kernel_size() const { | ||||
| @@ -45,8 +45,7 @@ std::vector<int64_t> AvgPool::get_kernel_size() const { | |||||
| return GetValue<std::vector<int64_t>>(value_ptr); | return GetValue<std::vector<int64_t>>(value_ptr); | ||||
| } | } | ||||
| void AvgPool::set_strides(const std::vector<int64_t> &strides) { | void AvgPool::set_strides(const std::vector<int64_t> &strides) { | ||||
| this->AddAttr(kStrides, | |||||
| MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStrides, strides, this->name(), false, true))); | |||||
| this->AddAttr(kStrides, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStrides, strides, this->name()))); | |||||
| } | } | ||||
| std::vector<int64_t> AvgPool::get_strides() const { | std::vector<int64_t> AvgPool::get_strides() const { | ||||
| @@ -93,8 +93,7 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve | |||||
| w_out = floor(w_out); | w_out = floor(w_out); | ||||
| } | } | ||||
| CheckAndConvertUtils::CheckInteger("pad_size", pad_list.size(), kEqual, 4, prim_name); | CheckAndConvertUtils::CheckInteger("pad_size", pad_list.size(), kEqual, 4, prim_name); | ||||
| primitive->AddAttr(kPadList, | |||||
| MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad_list, prim_name, true, true))); | |||||
| primitive->AddAttr(kPadList, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad_list, prim_name))); | |||||
| std::vector<int64_t> out_shape = {x_shape[0], out_channel, h_out, w_out}; | std::vector<int64_t> out_shape = {x_shape[0], out_channel, h_out, w_out}; | ||||
| if (format == NHWC) { | if (format == NHWC) { | ||||
| out_shape = {x_shape[0], h_out, w_out, out_channel}; | out_shape = {x_shape[0], h_out, w_out, out_channel}; | ||||
| @@ -144,11 +143,11 @@ void Conv2D::set_kernel_size(const std::vector<int64_t> &kernel_size) { | |||||
| } | } | ||||
| void Conv2D::set_stride(const std::vector<int64_t> &stride) { | void Conv2D::set_stride(const std::vector<int64_t> &stride) { | ||||
| AddAttr(kStride, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStride, stride, name(), true, true))); | |||||
| AddAttr(kStride, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStride, stride, name()))); | |||||
| } | } | ||||
| void Conv2D::set_dilation(const std::vector<int64_t> &dilation) { | void Conv2D::set_dilation(const std::vector<int64_t> &dilation) { | ||||
| AddAttr(kDilation, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, name(), true, true))); | |||||
| AddAttr(kDilation, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, name()))); | |||||
| } | } | ||||
| void Conv2D::set_pad_mode(const PadMode &pad_mode) { | void Conv2D::set_pad_mode(const PadMode &pad_mode) { | ||||
| @@ -166,7 +165,7 @@ void Conv2D::set_pad_mode(const PadMode &pad_mode) { | |||||
| void Conv2D::set_pad(const std::vector<int64_t> &pad) { | void Conv2D::set_pad(const std::vector<int64_t> &pad) { | ||||
| CheckAndConvertUtils::CheckInteger("pad_size", pad.size(), kEqual, 4, name()); | CheckAndConvertUtils::CheckInteger("pad_size", pad.size(), kEqual, 4, name()); | ||||
| AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name(), true, true))); | |||||
| AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name()))); | |||||
| } | } | ||||
| void Conv2D::set_mode(int64_t mode) { | void Conv2D::set_mode(int64_t mode) { | ||||
| @@ -111,7 +111,7 @@ void Conv2dTranspose::set_pad_mode(const PadMode &pad_mode) { | |||||
| void Conv2dTranspose::set_pad(const std::vector<int64_t> &pad) { | void Conv2dTranspose::set_pad(const std::vector<int64_t> &pad) { | ||||
| CheckAndConvertUtils::CheckInteger("pad_size", pad.size(), kEqual, 4, name()); | CheckAndConvertUtils::CheckInteger("pad_size", pad.size(), kEqual, 4, name()); | ||||
| AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name(), true, true))); | |||||
| AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name()))); | |||||
| } | } | ||||
| void Conv2dTranspose::set_mode(int64_t mode) { | void Conv2dTranspose::set_mode(int64_t mode) { | ||||
| @@ -35,13 +35,13 @@ void DepthWiseConv2D::Init(const int64_t channel_multiplier, const std::vector<i | |||||
| this->set_mode(CheckAndConvertUtils::CheckInteger("mode", mode, kEqual, 3, prim_name)); | this->set_mode(CheckAndConvertUtils::CheckInteger("mode", mode, kEqual, 3, prim_name)); | ||||
| this->set_kernel_size(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, prim_name)); | this->set_kernel_size(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, prim_name)); | ||||
| auto strides = CheckAndConvertUtils::CheckPositiveVector(kStride, stride, this->name(), false, false); | |||||
| auto strides = CheckAndConvertUtils::CheckPositiveVector(kStride, stride, this->name()); | |||||
| if (strides[0] != strides[1]) { | if (strides[0] != strides[1]) { | ||||
| MS_EXCEPTION(ValueError) << "The height and width of stride should be equal, but got height " << strides[0] | MS_EXCEPTION(ValueError) << "The height and width of stride should be equal, but got height " << strides[0] | ||||
| << ", width " << strides[1]; | << ", width " << strides[1]; | ||||
| } | } | ||||
| this->set_stride(strides); | this->set_stride(strides); | ||||
| auto dilations = CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, this->name(), false, false); | |||||
| auto dilations = CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, this->name()); | |||||
| if (dilations[0] != dilations[1]) { | if (dilations[0] != dilations[1]) { | ||||
| MS_EXCEPTION(ValueError) << "The height and width of dilation should be equal, but got height " << dilations[0] | MS_EXCEPTION(ValueError) << "The height and width of dilation should be equal, but got height " << dilations[0] | ||||
| << ", width " << dilations[1]; | << ", width " << dilations[1]; | ||||
| @@ -57,7 +57,7 @@ void DepthWiseConv2D::Init(const int64_t channel_multiplier, const std::vector<i | |||||
| } else { | } else { | ||||
| CheckAndConvertUtils::Check(kPad, pad, kEqual, "zeros_list", {0, 0, 0, 0}, prim_name); | CheckAndConvertUtils::Check(kPad, pad, kEqual, "zeros_list", {0, 0, 0, 0}, prim_name); | ||||
| } | } | ||||
| this->set_pad(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, this->name(), true, true)); | |||||
| this->set_pad(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, this->name())); | |||||
| this->set_out_channel( | this->set_out_channel( | ||||
| CheckAndConvertUtils::CheckInteger("channel_multiplier", channel_multiplier, kGreaterThan, 0, prim_name)); | CheckAndConvertUtils::CheckInteger("channel_multiplier", channel_multiplier, kGreaterThan, 0, prim_name)); | ||||
| @@ -30,13 +30,13 @@ void DepthWiseConv2DFusion::Init(const int64_t channel_multiplier, const std::ve | |||||
| this->set_mode(CheckAndConvertUtils::CheckInteger("mode", mode, kEqual, 3, prim_name)); | this->set_mode(CheckAndConvertUtils::CheckInteger("mode", mode, kEqual, 3, prim_name)); | ||||
| this->set_kernel_size(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, prim_name)); | this->set_kernel_size(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, prim_name)); | ||||
| auto strides = CheckAndConvertUtils::CheckPositiveVector(kStride, stride, this->name(), false, false); | |||||
| auto strides = CheckAndConvertUtils::CheckPositiveVector(kStride, stride, this->name()); | |||||
| if (strides[0] != strides[1]) { | if (strides[0] != strides[1]) { | ||||
| MS_EXCEPTION(ValueError) << "The height and width of stride should be equal, but got height " << strides[0] | MS_EXCEPTION(ValueError) << "The height and width of stride should be equal, but got height " << strides[0] | ||||
| << ", width " << strides[1]; | << ", width " << strides[1]; | ||||
| } | } | ||||
| this->set_stride(strides); | this->set_stride(strides); | ||||
| auto dilations = CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, this->name(), false, false); | |||||
| auto dilations = CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, this->name()); | |||||
| if (dilations[0] != dilations[1]) { | if (dilations[0] != dilations[1]) { | ||||
| MS_EXCEPTION(ValueError) << "The height and width of dilation should be equal, but got height " << dilations[0] | MS_EXCEPTION(ValueError) << "The height and width of dilation should be equal, but got height " << dilations[0] | ||||
| << ", width " << dilations[1]; | << ", width " << dilations[1]; | ||||
| @@ -52,7 +52,7 @@ void DepthWiseConv2DFusion::Init(const int64_t channel_multiplier, const std::ve | |||||
| } else { | } else { | ||||
| CheckAndConvertUtils::Check(kPad, pad, kEqual, "zeros_list", {0, 0, 0, 0}, prim_name); | CheckAndConvertUtils::Check(kPad, pad, kEqual, "zeros_list", {0, 0, 0, 0}, prim_name); | ||||
| } | } | ||||
| this->set_pad(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, this->name(), true, true)); | |||||
| this->set_pad(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, this->name())); | |||||
| this->set_out_channel( | this->set_out_channel( | ||||
| CheckAndConvertUtils::CheckInteger("channel_multiplier", channel_multiplier, kGreaterThan, 0, prim_name)); | CheckAndConvertUtils::CheckInteger("channel_multiplier", channel_multiplier, kGreaterThan, 0, prim_name)); | ||||
| @@ -105,11 +105,11 @@ void Conv2DBackpropInput::set_kernel_size(const std::vector<int64_t> &kernel_siz | |||||
| } | } | ||||
| void Conv2DBackpropInput::set_stride(const std::vector<int64_t> &stride) { | void Conv2DBackpropInput::set_stride(const std::vector<int64_t> &stride) { | ||||
| AddAttr(kStride, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStride, stride, name(), true, true))); | |||||
| AddAttr(kStride, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStride, stride, name()))); | |||||
| } | } | ||||
| void Conv2DBackpropInput::set_dilation(const std::vector<int64_t> &dilation) { | void Conv2DBackpropInput::set_dilation(const std::vector<int64_t> &dilation) { | ||||
| AddAttr(kDilation, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, name(), true, true))); | |||||
| AddAttr(kDilation, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, name()))); | |||||
| } | } | ||||
| void Conv2DBackpropInput::set_pad_mode(const PadMode &pad_mode) { | void Conv2DBackpropInput::set_pad_mode(const PadMode &pad_mode) { | ||||
| @@ -127,7 +127,7 @@ void Conv2DBackpropInput::set_pad_mode(const PadMode &pad_mode) { | |||||
| void Conv2DBackpropInput::set_pad(const std::vector<int64_t> &pad) { | void Conv2DBackpropInput::set_pad(const std::vector<int64_t> &pad) { | ||||
| CheckAndConvertUtils::CheckInteger("pad_size", pad.size(), kEqual, 4, name()); | CheckAndConvertUtils::CheckInteger("pad_size", pad.size(), kEqual, 4, name()); | ||||
| AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name(), true, true))); | |||||
| AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name()))); | |||||
| } | } | ||||
| void Conv2DBackpropInput::set_mode(int64_t mode) { | void Conv2DBackpropInput::set_mode(int64_t mode) { | ||||
| @@ -36,8 +36,8 @@ PadMode MaxPool::get_pad_mode() const { | |||||
| return PadMode(GetValue<int64_t>(value_ptr)); | return PadMode(GetValue<int64_t>(value_ptr)); | ||||
| } | } | ||||
| void MaxPool::set_kernel_size(const std::vector<int64_t> &kernel_size) { | void MaxPool::set_kernel_size(const std::vector<int64_t> &kernel_size) { | ||||
| this->AddAttr(kKernelSize, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, this->name(), | |||||
| false, true))); | |||||
| this->AddAttr(kKernelSize, | |||||
| MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, this->name()))); | |||||
| } | } | ||||
| std::vector<int64_t> MaxPool::get_kernel_size() const { | std::vector<int64_t> MaxPool::get_kernel_size() const { | ||||
| @@ -45,8 +45,7 @@ std::vector<int64_t> MaxPool::get_kernel_size() const { | |||||
| return GetValue<std::vector<int64_t>>(value_ptr); | return GetValue<std::vector<int64_t>>(value_ptr); | ||||
| } | } | ||||
| void MaxPool::set_strides(const std::vector<int64_t> &strides) { | void MaxPool::set_strides(const std::vector<int64_t> &strides) { | ||||
| this->AddAttr(kStrides, | |||||
| MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStrides, strides, this->name(), false, true))); | |||||
| this->AddAttr(kStrides, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStrides, strides, this->name()))); | |||||
| } | } | ||||
| std::vector<int64_t> MaxPool::get_strides() const { | std::vector<int64_t> MaxPool::get_strides() const { | ||||
| @@ -330,24 +330,10 @@ bool CheckAndConvertUtils::IsEqualVector(const std::vector<int64_t> &vec_1, cons | |||||
| std::vector<int64_t> CheckAndConvertUtils::CheckPositiveVector(const std::string &arg_name, | std::vector<int64_t> CheckAndConvertUtils::CheckPositiveVector(const std::string &arg_name, | ||||
| const std::vector<int64_t> &arg_value, | const std::vector<int64_t> &arg_value, | ||||
| const std::string &prim_name, bool allow_four, | |||||
| bool ret_four) { | |||||
| auto raise_message = [allow_four, prim_name, arg_value, arg_name]() -> void { | |||||
| std::ostringstream buffer; | |||||
| buffer << "For " << prim_name << " attr " << arg_name << " should be a positive vector of size two "; | |||||
| if (allow_four) { | |||||
| buffer << "or four "; | |||||
| } | |||||
| buffer << " positive int64_t numbers , but got ["; | |||||
| for (auto item : arg_value) { | |||||
| buffer << item << ","; | |||||
| } | |||||
| buffer << "]"; | |||||
| MS_EXCEPTION(ValueError) << buffer.str(); | |||||
| }; | |||||
| const std::string &prim_name) { | |||||
| for (auto item : arg_value) { | for (auto item : arg_value) { | ||||
| if (item < 0) { | if (item < 0) { | ||||
| raise_message(); | |||||
| MS_EXCEPTION(ValueError) << "For " << prim_name << " attr " << arg_name << " should be a positive vector"; | |||||
| } | } | ||||
| } | } | ||||
| return arg_value; | return arg_value; | ||||
| @@ -162,8 +162,7 @@ const std::map<CompareRange, std::pair<std::string, std::string>> kCompareRangeT | |||||
| class CheckAndConvertUtils { | class CheckAndConvertUtils { | ||||
| public: | public: | ||||
| static std::vector<int64_t> CheckPositiveVector(const std::string &arg_name, const std::vector<int64_t> &arg_value, | static std::vector<int64_t> CheckPositiveVector(const std::string &arg_name, const std::vector<int64_t> &arg_value, | ||||
| const std::string &prim_name, bool allow_four = false, | |||||
| bool ret_four = false); | |||||
| const std::string &prim_name); | |||||
| static std::string CheckString(const std::string &arg_name, const std::string &arg_value, | static std::string CheckString(const std::string &arg_name, const std::string &arg_value, | ||||
| const std::set<std::string> &check_list, const std::string &prim_name); | const std::set<std::string> &check_list, const std::string &prim_name); | ||||