Browse Source

oneslike op infer in c++

pull/15693/head
simson 4 years ago
parent
commit
f8fff098c3
3 changed files with 10 additions and 14 deletions
  1. +8
    -4
      mindspore/core/ops/ones_like.cc
  2. +1
    -2
      mindspore/core/ops/ones_like.h
  3. +1
    -8
      mindspore/ops/operations/array_ops.py

+ 8
- 4
mindspore/core/ops/ones_like.cc View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -21,6 +21,7 @@
#include "ops/ones_like.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/tensor_construct_utils.h"
#include "abstract/primitive_infer_map.h"

namespace mindspore {
@@ -34,14 +35,17 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A

TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto infer_type = input_args[0]->BuildType();
return CheckAndConvertUtils::CheckTensorTypeValid("infer_type", infer_type, common_valid_types, "OnesLike");
auto valid_type = common_valid_types;
valid_type.insert(kBool);
return CheckAndConvertUtils::CheckTensorTypeValid("infer_type", infer_type, valid_type, "OnesLike");
}

} // namespace
AbstractBasePtr OnesLikeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
InferShape(primitive, input_args)->shape());
InferShape(primitive, input_args));
}
REGISTER_PRIMITIVE_C(kNameOnesLike, OnesLike);
REGISTER_PRIMITIVE_EVAL_IMPL(OnesLike, prim::kPrimOnesLike, OnesLikeInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

+ 1
- 2
mindspore/core/ops/ones_like.h View File

@@ -24,10 +24,9 @@

namespace mindspore {
namespace ops {
constexpr auto kNameOnesLike = "OnesLike";
class OnesLike : public PrimitiveC {
public:
OnesLike() : PrimitiveC(kNameOnesLike) {}
OnesLike() : PrimitiveC(prim::kPrimOnesLike->name()) {}
~OnesLike() = default;
MS_DECLARE_PARENT(OnesLike, PrimitiveC);
void Init() {}


+ 1
- 8
mindspore/ops/operations/array_ops.py View File

@@ -1345,7 +1345,7 @@ class Zeros(Primitive):
"""Initialize Zeros"""


class OnesLike(PrimitiveWithInfer):
class OnesLike(Primitive):
"""
Creates a new tensor. The values of all elements are 1.

@@ -1376,13 +1376,6 @@ class OnesLike(PrimitiveWithInfer):
def __init__(self):
"""Initialize OnesLike"""

def infer_shape(self, x_shape):
return x_shape

def infer_dtype(self, x_dtype):
validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type + (mstype.bool_,), self.name)
return x_dtype


class ZerosLike(PrimitiveWithCheck):
"""


Loading…
Cancel
Save