|
|
@@ -52,6 +52,8 @@ ValuePtr ShapeInferValue(const PrimitivePtr &primitive, const std::vector<Abstra |
|
|
auto op_name = primitive->name(); |
|
|
auto op_name = primitive->name(); |
|
|
CheckAndConvertUtils::CheckInteger("shape infer", input_args.size(), kEqual, 1, op_name); |
|
|
CheckAndConvertUtils::CheckInteger("shape infer", input_args.size(), kEqual, 1, op_name); |
|
|
MS_EXCEPTION_IF_NULL(input_args[0]); |
|
|
MS_EXCEPTION_IF_NULL(input_args[0]); |
|
|
|
|
|
std::set<TypePtr> valid_params_types = {kTensorType}; |
|
|
|
|
|
CheckAndConvertUtils::CheckSubClass("shape type", input_args[0]->BuildType(), valid_params_types, op_name); |
|
|
auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape()); |
|
|
auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape()); |
|
|
auto inshape = shape_map[kShape]; |
|
|
auto inshape = shape_map[kShape]; |
|
|
auto value = MakeValue(inshape); |
|
|
auto value = MakeValue(inshape); |
|
|
|