From 3a71a6218aa8988e9366c9dc870db03e8b68e094 Mon Sep 17 00:00:00 2001 From: l00591931 Date: Thu, 1 Apr 2021 19:40:58 +0800 Subject: [PATCH] Add ConvertShapePtrToShapeMap --- mindspore/core/utils/check_convert_utils.cc | 14 ++++++++++++++ mindspore/core/utils/check_convert_utils.h | 7 +++++++ 2 files changed, 21 insertions(+) diff --git a/mindspore/core/utils/check_convert_utils.cc b/mindspore/core/utils/check_convert_utils.cc index c72a6cf163..0aa3d3d706 100644 --- a/mindspore/core/utils/check_convert_utils.cc +++ b/mindspore/core/utils/check_convert_utils.cc @@ -396,6 +396,20 @@ std::vector CheckAndConvertUtils::ConvertShapePtrToShape(const std::str return shape_element->shape(); } +ShapeMap CheckAndConvertUtils::ConvertShapePtrToShapeMap(const BaseShapePtr &shape) { + MS_EXCEPTION_IF_NULL(shape); + if (!shape->isa()) { + return std::map>(); + } + auto shape_element = shape->cast(); + MS_EXCEPTION_IF_NULL(shape_element); + ShapeMap shape_map; + shape_map[kShape] = shape_element->shape(); + shape_map[kMinShape] = shape_element->min_shape(); + shape_map[kMaxShape] = shape_element->max_shape(); + return shape_map; +} + void CheckAndConvertUtils::Check(const string &arg_name, int64_t arg_value, CompareEnum compare_type, const string &value_name, int64_t value, const string &prim_name, ExceptionType exception_type) { diff --git a/mindspore/core/utils/check_convert_utils.h b/mindspore/core/utils/check_convert_utils.h index bc07d40312..03bf225164 100644 --- a/mindspore/core/utils/check_convert_utils.h +++ b/mindspore/core/utils/check_convert_utils.h @@ -30,6 +30,10 @@ #include "utils/log_adapter.h" namespace mindspore { typedef std::pair, std::map> AttrConverterPair; +typedef std::map> ShapeMap; +constexpr auto kShape = "shape"; +constexpr auto kMinShape = "min_shape"; +constexpr auto kMaxShape = "max_shape"; enum CompareEnum : int64_t { kEqual = 1, // == @@ -234,6 +238,9 @@ class CheckAndConvertUtils { static std::vector ConvertShapePtrToShape(const std::string &arg_name, const BaseShapePtr &shape, const std::string &prim_name); + + static ShapeMap ConvertShapePtrToShapeMap(const BaseShapePtr &shape); + static void Check(const std::string &arg_name, int64_t arg_value, CompareEnum compare_type, const std::string &value_name, int64_t value, const std::string &prim_name = "", ExceptionType exception_type = ValueError);