|
|
|
@@ -1,5 +1,5 @@ |
|
|
|
/** |
|
|
|
* Copyright 2020 Huawei Technologies Co., Ltd |
|
|
|
* Copyright 2021 Huawei Technologies Co., Ltd |
|
|
|
* |
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
|
* you may not use this file except in compliance with the License. |
|
|
|
@@ -21,6 +21,7 @@ |
|
|
|
#include "ops/ones_like.h" |
|
|
|
#include "ops/op_utils.h" |
|
|
|
#include "utils/check_convert_utils.h" |
|
|
|
#include "utils/tensor_construct_utils.h" |
|
|
|
#include "abstract/primitive_infer_map.h" |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
@@ -34,14 +35,17 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A |
|
|
|
|
|
|
|
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { |
|
|
|
auto infer_type = input_args[0]->BuildType(); |
|
|
|
return CheckAndConvertUtils::CheckTensorTypeValid("infer_type", infer_type, common_valid_types, "OnesLike"); |
|
|
|
auto valid_type = common_valid_types; |
|
|
|
valid_type.insert(kBool); |
|
|
|
return CheckAndConvertUtils::CheckTensorTypeValid("infer_type", infer_type, valid_type, "OnesLike"); |
|
|
|
} |
|
|
|
|
|
|
|
} // namespace |
|
|
|
AbstractBasePtr OnesLikeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, |
|
|
|
const std::vector<AbstractBasePtr> &input_args) { |
|
|
|
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), |
|
|
|
InferShape(primitive, input_args)->shape()); |
|
|
|
InferShape(primitive, input_args)); |
|
|
|
} |
|
|
|
REGISTER_PRIMITIVE_C(kNameOnesLike, OnesLike); |
|
|
|
REGISTER_PRIMITIVE_EVAL_IMPL(OnesLike, prim::kPrimOnesLike, OnesLikeInfer, nullptr, true); |
|
|
|
} // namespace ops |
|
|
|
} // namespace mindspore |