diff --git a/mindspore/core/ops/ones_like.cc b/mindspore/core/ops/ones_like.cc index 1aaf4fb277..7f19ef7f9e 100644 --- a/mindspore/core/ops/ones_like.cc +++ b/mindspore/core/ops/ones_like.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,6 +21,7 @@ #include "ops/ones_like.h" #include "ops/op_utils.h" #include "utils/check_convert_utils.h" +#include "utils/tensor_construct_utils.h" #include "abstract/primitive_infer_map.h" namespace mindspore { @@ -34,14 +35,17 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { auto infer_type = input_args[0]->BuildType(); - return CheckAndConvertUtils::CheckTensorTypeValid("infer_type", infer_type, common_valid_types, "OnesLike"); + auto valid_type = common_valid_types; + valid_type.insert(kBool); + return CheckAndConvertUtils::CheckTensorTypeValid("infer_type", infer_type, valid_type, "OnesLike"); } + } // namespace AbstractBasePtr OnesLikeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { return std::make_shared(InferType(primitive, input_args), - InferShape(primitive, input_args)->shape()); + InferShape(primitive, input_args)); } -REGISTER_PRIMITIVE_C(kNameOnesLike, OnesLike); +REGISTER_PRIMITIVE_EVAL_IMPL(OnesLike, prim::kPrimOnesLike, OnesLikeInfer, nullptr, true); } // namespace ops } // namespace mindspore diff --git a/mindspore/core/ops/ones_like.h b/mindspore/core/ops/ones_like.h index b0af5ef2a7..cff0b8650a 100644 --- a/mindspore/core/ops/ones_like.h +++ b/mindspore/core/ops/ones_like.h @@ -24,10 +24,9 @@ namespace mindspore { namespace ops { -constexpr auto kNameOnesLike = "OnesLike"; class OnesLike : public PrimitiveC { public: - OnesLike() : PrimitiveC(kNameOnesLike) {} + OnesLike() : PrimitiveC(prim::kPrimOnesLike->name()) {} ~OnesLike() = default; MS_DECLARE_PARENT(OnesLike, PrimitiveC); void Init() {} diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index ae251d332d..0a56913d08 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -1345,7 +1345,7 @@ class Zeros(Primitive): """Initialize Zeros""" -class OnesLike(PrimitiveWithInfer): +class OnesLike(Primitive): """ Creates a new tensor. The values of all elements are 1. @@ -1376,13 +1376,6 @@ class OnesLike(PrimitiveWithInfer): def __init__(self): """Initialize OnesLike""" - def infer_shape(self, x_shape): - return x_shape - - def infer_dtype(self, x_dtype): - validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type + (mstype.bool_,), self.name) - return x_dtype - class ZerosLike(PrimitiveWithCheck): """