| @@ -271,6 +271,7 @@ inline const PrimitivePtr kPrimLrn = std::make_shared<Primitive>("LRN"); | |||||
| inline const PrimitivePtr kPrimLayerNormGrad = std::make_shared<Primitive>("LayerNormGrad"); | inline const PrimitivePtr kPrimLayerNormGrad = std::make_shared<Primitive>("LayerNormGrad"); | ||||
| inline const PrimitivePtr kPrimLayerNormXBackprop = std::make_shared<Primitive>("LayerNormXBackprop"); | inline const PrimitivePtr kPrimLayerNormXBackprop = std::make_shared<Primitive>("LayerNormXBackprop"); | ||||
| inline const PrimitivePtr kPrimLayerNormBetaGammaBackprop = std::make_shared<Primitive>("LayerNormBetaGammaBackprop"); | inline const PrimitivePtr kPrimLayerNormBetaGammaBackprop = std::make_shared<Primitive>("LayerNormBetaGammaBackprop"); | ||||
| inline const PrimitivePtr kPrimLog1p = std::make_shared<Primitive>("Log1p"); | |||||
| inline const PrimitivePtr kPrimDropoutGenMask = std::make_shared<Primitive>("DropoutGenMask"); | inline const PrimitivePtr kPrimDropoutGenMask = std::make_shared<Primitive>("DropoutGenMask"); | ||||
| inline const PrimitivePtr kPrimDropoutDoMask = std::make_shared<Primitive>("DropoutDoMask"); | inline const PrimitivePtr kPrimDropoutDoMask = std::make_shared<Primitive>("DropoutDoMask"); | ||||
| inline const PrimitivePtr kPrimDropoutGrad = std::make_shared<Primitive>("DropoutGrad"); | inline const PrimitivePtr kPrimDropoutGrad = std::make_shared<Primitive>("DropoutGrad"); | ||||
| @@ -287,6 +288,7 @@ inline const PrimitivePtr kPrimElu = std::make_shared<Primitive>("Elu"); | |||||
| inline const PrimitivePtr kPrimRelu6 = std::make_shared<Primitive>("ReLU6"); | inline const PrimitivePtr kPrimRelu6 = std::make_shared<Primitive>("ReLU6"); | ||||
| inline const PrimitivePtr kPrimReluV2 = std::make_shared<Primitive>("ReLUV2"); | inline const PrimitivePtr kPrimReluV2 = std::make_shared<Primitive>("ReLUV2"); | ||||
| inline const PrimitivePtr kPrimPRelu = std::make_shared<Primitive>("PReLU"); | inline const PrimitivePtr kPrimPRelu = std::make_shared<Primitive>("PReLU"); | ||||
| inline const PrimitivePtr kPrimSoftplus = std::make_shared<Primitive>("Softplus"); | |||||
| inline const PrimitivePtr kPrimZeros = std::make_shared<Primitive>("Zeros"); | inline const PrimitivePtr kPrimZeros = std::make_shared<Primitive>("Zeros"); | ||||
| inline const PrimitivePtr kPrimZerosLike = std::make_shared<Primitive>("ZerosLike"); | inline const PrimitivePtr kPrimZerosLike = std::make_shared<Primitive>("ZerosLike"); | ||||
| inline const PrimitivePtr kPrimOnesLike = std::make_shared<Primitive>("OnesLike"); | inline const PrimitivePtr kPrimOnesLike = std::make_shared<Primitive>("OnesLike"); | ||||
| @@ -23,13 +23,12 @@ namespace { | |||||
| abstract::ShapePtr BroadcastToInferShape(const PrimitivePtr &primitive, | abstract::ShapePtr BroadcastToInferShape(const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto broad_cast_to = primitive->cast<PrimBroadcastToPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(broad_cast_to); | |||||
| auto prim_name = broad_cast_to->name(); | |||||
| auto prim_name = primitive->name(); | |||||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); | ||||
| auto input_x = broad_cast_to->get_shape(); | |||||
| auto value_ptr = primitive->GetAttr(kShape); | |||||
| auto input_x = GetValue<std::vector<int64_t>>(value_ptr); | |||||
| int64_t outer_dim_offset = input_x.size() - x_shape.size(); | int64_t outer_dim_offset = input_x.size() - x_shape.size(); | ||||
| CheckAndConvertUtils::Check("x shape", x_shape, kLessEqual, "input_x", input_x, prim_name); | |||||
| CheckAndConvertUtils::Check("x shape", x_shape.size(), kLessEqual, "input_x", input_x.size(), prim_name); | |||||
| bool flag = true; | bool flag = true; | ||||
| if (input_x.end() == find(input_x.begin(), input_x.end(), -1)) { | if (input_x.end() == find(input_x.begin(), input_x.end(), -1)) { | ||||
| flag = false; | flag = false; | ||||
| @@ -49,7 +48,6 @@ abstract::ShapePtr BroadcastToInferShape(const PrimitivePtr &primitive, | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| std::reverse(input_x.begin(), input_x.end()); | |||||
| return std::make_shared<abstract::Shape>(input_x); | return std::make_shared<abstract::Shape>(input_x); | ||||
| } | } | ||||
| @@ -78,8 +76,8 @@ std::vector<int64_t> BroadcastTo::get_shape() const { | |||||
| AbstractBasePtr BroadcastToInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr BroadcastToInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| return std::make_shared<abstract::AbstractTensor>(BroadcastToInferType(primitive, input_args), | return std::make_shared<abstract::AbstractTensor>(BroadcastToInferType(primitive, input_args), | ||||
| BroadcastToInferShape(primitive, input_args)->shape()); | |||||
| BroadcastToInferShape(primitive, input_args)); | |||||
| } | } | ||||
| REGISTER_PRIMITIVE_C(kNameBroadcastTo, BroadcastTo); | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(BroadcastTo, prim::kPrimBroadcastTo, BroadcastToInfer, nullptr, true); | |||||
| } // namespace ops | } // namespace ops | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -69,6 +69,6 @@ AbstractBasePtr GatherDInfer(const abstract::AnalysisEnginePtr &, const Primitiv | |||||
| GatherDInferShape(primitive, input_args)); | GatherDInferShape(primitive, input_args)); | ||||
| return abs; | return abs; | ||||
| } | } | ||||
| REGISTER_PRIMITIVE_EVAL_IMPL(GatherD, prim::kPrimGatherD, GatherDInfer, nullptr, false); | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(GatherD, prim::kPrimGatherD, GatherDInfer, nullptr, true); | |||||
| } // namespace ops | } // namespace ops | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -0,0 +1,58 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ops/log1p.h" | |||||
| #include <memory> | |||||
| #include <set> | |||||
| #include "ops/op_utils.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| #include "utils/tensor_construct_utils.h" | |||||
| #include "abstract/primitive_infer_map.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| // log1p | |||||
| namespace { | |||||
| abstract::ShapePtr Log1pInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape()); | |||||
| auto in_shape = shape_map[kShape]; | |||||
| auto min_shape = shape_map[kMinShape]; | |||||
| auto max_shape = shape_map[kMaxShape]; | |||||
| return std::make_shared<abstract::Shape>(in_shape, min_shape, max_shape); | |||||
| } | |||||
| TypePtr Log1pInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| auto prim_name = prim->name(); | |||||
| // check | |||||
| std::set<TypePtr> valid_index_types = {kFloat16, kFloat32}; | |||||
| auto x_type = | |||||
| CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), valid_index_types, prim_name); | |||||
| return x_type; | |||||
| } | |||||
| } // namespace | |||||
| AbstractBasePtr Log1pInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto abs = std::make_shared<abstract::AbstractTensor>(Log1pInferType(primitive, input_args), | |||||
| Log1pInferShape(primitive, input_args)); | |||||
| return abs; | |||||
| } | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(Log1p, prim::kPrimLog1p, Log1pInfer, nullptr, true); | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,41 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CORE_OPS_LOG1P_H_ | |||||
| #define MINDSPORE_CORE_OPS_LOG1P_H_ | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include "ops/primitive_c.h" | |||||
| #include "abstract/abstract_value.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| #include "ops/op_utils.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| constexpr auto kNameLog1p = "Log1p"; | |||||
| class Log1p : public PrimitiveC { | |||||
| public: | |||||
| Log1p() : PrimitiveC(kNameLog1p) { InitIOName({"x"}, {"y"}); } | |||||
| ~Log1p() = default; | |||||
| MS_DECLARE_PARENT(Log1p, PrimitiveC); | |||||
| void Init() {} | |||||
| }; | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CORE_OPS_LOG1P_H_ | |||||
| @@ -0,0 +1,58 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ops/softplus.h" | |||||
| #include <memory> | |||||
| #include <set> | |||||
| #include "ops/op_utils.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| #include "utils/tensor_construct_utils.h" | |||||
| #include "abstract/primitive_infer_map.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| // softplus | |||||
| namespace { | |||||
| abstract::ShapePtr SoftplusInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape()); | |||||
| auto in_shape = shape_map[kShape]; | |||||
| auto min_shape = shape_map[kMinShape]; | |||||
| auto max_shape = shape_map[kMaxShape]; | |||||
| return std::make_shared<abstract::Shape>(in_shape, min_shape, max_shape); | |||||
| } | |||||
| TypePtr SoftplusInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| auto prim_name = prim->name(); | |||||
| // check | |||||
| std::set<TypePtr> valid_index_types = {kFloat16, kFloat32, kFloat64}; | |||||
| auto x_type = | |||||
| CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), valid_index_types, prim_name); | |||||
| return x_type; | |||||
| } | |||||
| } // namespace | |||||
| AbstractBasePtr SoftplusInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto abs = std::make_shared<abstract::AbstractTensor>(SoftplusInferType(primitive, input_args), | |||||
| SoftplusInferShape(primitive, input_args)); | |||||
| return abs; | |||||
| } | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(Softplus, prim::kPrimSoftplus, SoftplusInfer, nullptr, true); | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,41 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CORE_OPS_SOFTPLUS_H_ | |||||
| #define MINDSPORE_CORE_OPS_SOFTPLUS_H_ | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include "ops/primitive_c.h" | |||||
| #include "abstract/abstract_value.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| #include "ops/op_utils.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| constexpr auto kNameSoftplus = "Softplus"; | |||||
| class Softplus : public PrimitiveC { | |||||
| public: | |||||
| Softplus() : PrimitiveC(kNameSoftplus) { InitIOName({"x"}, {"output"}); } | |||||
| ~Softplus() = default; | |||||
| MS_DECLARE_PARENT(Softplus, PrimitiveC); | |||||
| void Init() {} | |||||
| }; | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CORE_OPS_SOFTPLUS_H_ | |||||
| @@ -1306,7 +1306,7 @@ class Ones(PrimitiveWithInfer): | |||||
| return out | return out | ||||
| class Zeros(PrimitiveWithInfer): | |||||
| class Zeros(Primitive): | |||||
| r""" | r""" | ||||
| Creates a tensor filled with value zeros. | Creates a tensor filled with value zeros. | ||||
| @@ -4570,7 +4570,7 @@ class BatchToSpaceND(PrimitiveWithInfer): | |||||
| return out_shape | return out_shape | ||||
| class BroadcastTo(PrimitiveWithInfer): | |||||
| class BroadcastTo(Primitive): | |||||
| """ | """ | ||||
| Broadcasts input tensor to a given shape. | Broadcasts input tensor to a given shape. | ||||
| @@ -4629,34 +4629,6 @@ class BroadcastTo(PrimitiveWithInfer): | |||||
| validator.check("shape element", i, "shape element min limit", -1, Rel.GE, self.name) | validator.check("shape element", i, "shape element min limit", -1, Rel.GE, self.name) | ||||
| self.shape = shape | self.shape = shape | ||||
| def infer_shape(self, x_shape): | |||||
| validator.check("input_x shape length", len(x_shape), "target shape", len(self.shape), Rel.LE, self.name) | |||||
| reversed_x_shape = tuple(reversed(x_shape)) | |||||
| reversed_filtered_target = [] | |||||
| for i, v in enumerate(tuple(reversed(self.shape))): | |||||
| if v == -1: | |||||
| if i >= len(reversed_x_shape): | |||||
| raise ValueError("-1 is not valid in a leading, non-existing dimension") | |||||
| reversed_filtered_target.append(reversed_x_shape[i]) | |||||
| else: | |||||
| reversed_filtered_target.append(v) | |||||
| self.shape = tuple(reversed(reversed_filtered_target)) | |||||
| self.add_prim_attr('shape', self.shape) | |||||
| for i, v in enumerate(reversed_x_shape): | |||||
| if v not in (reversed_filtered_target[i], 1): | |||||
| raise ValueError(f"Not supported shapes for broadcast, " | |||||
| f"x_shape: {tuple(x_shape)}, target shape {self.shape}.") | |||||
| return self.shape | |||||
| def infer_dtype(self, x_dtype): | |||||
| validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name) | |||||
| return x_dtype | |||||
| class Meshgrid(PrimitiveWithInfer): | class Meshgrid(PrimitiveWithInfer): | ||||
| """ | """ | ||||
| @@ -5121,7 +5093,7 @@ class EmbeddingLookup(PrimitiveWithCheck): | |||||
| raise ValueError("The dimension of 'params' in EmbeddingLookup must <= 2, but got %d." % len(params_shp)) | raise ValueError("The dimension of 'params' in EmbeddingLookup must <= 2, but got %d." % len(params_shp)) | ||||
| class GatherD(PrimitiveWithInfer): | |||||
| class GatherD(Primitive): | |||||
| """ | """ | ||||
| Gathers values along an axis specified by dim. | Gathers values along an axis specified by dim. | ||||
| @@ -20,7 +20,7 @@ from mindspore import context | |||||
| from ..._checkparam import Validator as validator | from ..._checkparam import Validator as validator | ||||
| from ..._checkparam import Rel | from ..._checkparam import Rel | ||||
| from ...common import dtype as mstype | from ...common import dtype as mstype | ||||
| from ..primitive import prim_attr_register, PrimitiveWithInfer | |||||
| from ..primitive import prim_attr_register, Primitive, PrimitiveWithInfer | |||||
| def _check_mode(class_name): | def _check_mode(class_name): | ||||
| @@ -50,7 +50,7 @@ def _check_summary_param(name, value, class_name): | |||||
| SUMMARY_RETURN_VALUE = {'dtype': mstype.int32, 'shape': [1], 'value': None} | SUMMARY_RETURN_VALUE = {'dtype': mstype.int32, 'shape': [1], 'value': None} | ||||
| class ScalarSummary(PrimitiveWithInfer): | |||||
| class ScalarSummary(Primitive): | |||||
| """ | """ | ||||
| Outputs a scalar to a protocol buffer through a scalar summary operator. | Outputs a scalar to a protocol buffer through a scalar summary operator. | ||||
| @@ -141,7 +141,7 @@ class ImageSummary(PrimitiveWithInfer): | |||||
| return SUMMARY_RETURN_VALUE | return SUMMARY_RETURN_VALUE | ||||
| class TensorSummary(PrimitiveWithInfer): | |||||
| class TensorSummary(Primitive): | |||||
| """ | """ | ||||
| Outputs a tensor to a protocol buffer through a tensor summary operator. | Outputs a tensor to a protocol buffer through a tensor summary operator. | ||||
| @@ -26,7 +26,7 @@ from ...common import dtype as mstype | |||||
| from ...common.tensor import Tensor | from ...common.tensor import Tensor | ||||
| from ...common._decorator import deprecated | from ...common._decorator import deprecated | ||||
| from .._utils import get_broadcast_shape | from .._utils import get_broadcast_shape | ||||
| from ..primitive import PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op | |||||
| from ..primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op | |||||
| def _infer_shape_reduce(x, axis, keep_dims, prim_name): | def _infer_shape_reduce(x, axis, keep_dims, prim_name): | ||||
| @@ -1873,7 +1873,7 @@ class Log(PrimitiveWithInfer): | |||||
| return None | return None | ||||
| class Log1p(PrimitiveWithInfer): | |||||
| class Log1p(Primitive): | |||||
| """ | """ | ||||
| Returns the natural logarithm of one plus the input tensor element-wise. | Returns the natural logarithm of one plus the input tensor element-wise. | ||||
| @@ -1901,14 +1901,6 @@ class Log1p(PrimitiveWithInfer): | |||||
| def __init__(self): | def __init__(self): | ||||
| self.init_prim_io_names(inputs=['x'], outputs=['y']) | self.init_prim_io_names(inputs=['x'], outputs=['y']) | ||||
| def infer_shape(self, x_shape): | |||||
| return x_shape | |||||
| def infer_dtype(self, x_dtype): | |||||
| validator.check_subclass("x", x_dtype, mstype.tensor, self.name) | |||||
| validator.check_tensor_dtype_valid("x", x_dtype, [mstype.float16, mstype.float32], self.name) | |||||
| return x_dtype | |||||
| class Erf(PrimitiveWithInfer): | class Erf(PrimitiveWithInfer): | ||||
| r""" | r""" | ||||
| @@ -230,7 +230,7 @@ class LogSoftmax(PrimitiveWithInfer): | |||||
| return logits | return logits | ||||
| class Softplus(PrimitiveWithInfer): | |||||
| class Softplus(Primitive): | |||||
| r""" | r""" | ||||
| Softplus activation function. | Softplus activation function. | ||||
| @@ -267,13 +267,6 @@ class Softplus(PrimitiveWithInfer): | |||||
| """Initialize Softplus""" | """Initialize Softplus""" | ||||
| self.init_prim_io_names(inputs=['x'], outputs=['output']) | self.init_prim_io_names(inputs=['x'], outputs=['output']) | ||||
| def infer_shape(self, x_shape): | |||||
| return x_shape | |||||
| def infer_dtype(self, x_dtype): | |||||
| validator.check_tensor_dtype_valid('x', x_dtype, mstype.float_type, self.name) | |||||
| return x_dtype | |||||
| class Softsign(PrimitiveWithInfer): | class Softsign(PrimitiveWithInfer): | ||||
| r""" | r""" | ||||