Browse Source

!14566 Add ConvertShapePtrToShapeMap

From: @liangzhibo
Reviewed-by: @zh_qh,@ginfung,@zh_qh
Signed-off-by: @zh_qh,@zh_qh
pull/14566/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
9563bd09e9
2 changed files with 21 additions and 0 deletions
  1. +14
    -0
      mindspore/core/utils/check_convert_utils.cc
  2. +7
    -0
      mindspore/core/utils/check_convert_utils.h

+ 14
- 0
mindspore/core/utils/check_convert_utils.cc View File

@@ -396,6 +396,20 @@ std::vector<int64_t> CheckAndConvertUtils::ConvertShapePtrToShape(const std::str
return shape_element->shape(); return shape_element->shape();
} }


ShapeMap CheckAndConvertUtils::ConvertShapePtrToShapeMap(const BaseShapePtr &shape) {
MS_EXCEPTION_IF_NULL(shape);
if (!shape->isa<abstract::Shape>()) {
return std::map<std::string, std::vector<int64_t>>();
}
auto shape_element = shape->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(shape_element);
ShapeMap shape_map;
shape_map[kShape] = shape_element->shape();
shape_map[kMinShape] = shape_element->min_shape();
shape_map[kMaxShape] = shape_element->max_shape();
return shape_map;
}

void CheckAndConvertUtils::Check(const string &arg_name, int64_t arg_value, CompareEnum compare_type, void CheckAndConvertUtils::Check(const string &arg_name, int64_t arg_value, CompareEnum compare_type,
const string &value_name, int64_t value, const string &prim_name, const string &value_name, int64_t value, const string &prim_name,
ExceptionType exception_type) { ExceptionType exception_type) {


+ 7
- 0
mindspore/core/utils/check_convert_utils.h View File

@@ -30,6 +30,10 @@
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
namespace mindspore { namespace mindspore {
typedef std::pair<std::map<std::string, int64_t>, std::map<int64_t, std::string>> AttrConverterPair; typedef std::pair<std::map<std::string, int64_t>, std::map<int64_t, std::string>> AttrConverterPair;
typedef std::map<std::string, std::vector<int64_t>> ShapeMap;
constexpr auto kShape = "shape";
constexpr auto kMinShape = "min_shape";
constexpr auto kMaxShape = "max_shape";


enum CompareEnum : int64_t { enum CompareEnum : int64_t {
kEqual = 1, // == kEqual = 1, // ==
@@ -234,6 +238,9 @@ class CheckAndConvertUtils {


static std::vector<int64_t> ConvertShapePtrToShape(const std::string &arg_name, const BaseShapePtr &shape, static std::vector<int64_t> ConvertShapePtrToShape(const std::string &arg_name, const BaseShapePtr &shape,
const std::string &prim_name); const std::string &prim_name);

static ShapeMap ConvertShapePtrToShapeMap(const BaseShapePtr &shape);

static void Check(const std::string &arg_name, int64_t arg_value, CompareEnum compare_type, static void Check(const std::string &arg_name, int64_t arg_value, CompareEnum compare_type,
const std::string &value_name, int64_t value, const std::string &prim_name = "", const std::string &value_name, int64_t value, const std::string &prim_name = "",
ExceptionType exception_type = ValueError); ExceptionType exception_type = ValueError);


Loading…
Cancel
Save