diff --git a/mindspore/core/ops/shape.cc b/mindspore/core/ops/shape.cc index f1cb67814c..764eda5007 100644 --- a/mindspore/core/ops/shape.cc +++ b/mindspore/core/ops/shape.cc @@ -52,6 +52,8 @@ ValuePtr ShapeInferValue(const PrimitivePtr &primitive, const std::vectorname(); CheckAndConvertUtils::CheckInteger("shape infer", input_args.size(), kEqual, 1, op_name); MS_EXCEPTION_IF_NULL(input_args[0]); + std::set valid_params_types = {kTensorType}; + CheckAndConvertUtils::CheckSubClass("shape type", input_args[0]->BuildType(), valid_params_types, op_name); auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape()); auto inshape = shape_map[kShape]; auto value = MakeValue(inshape);