|
|
|
@@ -28,11 +28,6 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A |
|
|
|
MS_EXCEPTION_IF_NULL(reverseV2_prim); |
|
|
|
auto prim_name = reverseV2_prim->name(); |
|
|
|
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); |
|
|
|
// auto axis = reverseV2_prim->get_axis(); |
|
|
|
// int dim = x_shape.size(); |
|
|
|
// for (auto &axis_value : axis) { |
|
|
|
// CheckAndConvertUtils::CheckInRange("axis value", axis_value, kIncludeLeft, {-dim, dim}, prim_name); |
|
|
|
// } |
|
|
|
return std::make_shared<abstract::Shape>(x_shape); |
|
|
|
} |
|
|
|
|
|
|
|
@@ -40,17 +35,10 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> & |
|
|
|
for (const auto &item : input_args) { |
|
|
|
MS_EXCEPTION_IF_NULL(item); |
|
|
|
} |
|
|
|
const std::set<TypeId> valid_types = {kNumberTypeInt8, kNumberTypeInt16, kNumberTypeInt32, kNumberTypeInt64, |
|
|
|
kNumberTypeUInt8, kNumberTypeUInt16, kNumberTypeUInt32, kNumberTypeUInt64, |
|
|
|
kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64, kNumberTypeBool}; |
|
|
|
const std::set<TypePtr> valid_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, |
|
|
|
kUInt32, kUInt64, kFloat16, kFloat32, kFloat64, kBool}; |
|
|
|
auto infer_type = input_args[0]->BuildType(); |
|
|
|
CheckAndConvertUtils::CheckTensorTypeValid("x type", infer_type, valid_types, prim->name()); |
|
|
|
MS_EXCEPTION_IF_NULL(infer_type); |
|
|
|
auto tensor_type = infer_type->cast<TensorTypePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(tensor_type); |
|
|
|
auto data_type = tensor_type->element(); |
|
|
|
MS_EXCEPTION_IF_NULL(data_type); |
|
|
|
return data_type; |
|
|
|
return CheckAndConvertUtils::CheckTensorTypeValid("x type", infer_type, valid_types, prim->name()); |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|
|
|
|
|