| @@ -0,0 +1,54 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ops/dtype.h" | |||
| #include <string> | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <set> | |||
| #include <vector> | |||
| #include "ops/op_utils.h" | |||
| #include "utils/check_convert_utils.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "abstract/primitive_infer_map.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| ValuePtr DTypeInferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args, | |||
| const AbstractBasePtr &infer) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto op_name = primitive->name(); | |||
| CheckAndConvertUtils::CheckInteger("dtype infer", input_args.size(), kEqual, 1, op_name); | |||
| MS_EXCEPTION_IF_NULL(input_args[0]); | |||
| const std::set<TypePtr> valid_types = {kTensorType}; | |||
| auto type = | |||
| CheckAndConvertUtils::CheckTensorTypeValid("infer type", input_args[0]->BuildType(), valid_types, op_name); | |||
| return type; | |||
| } | |||
| AbstractBasePtr DTypeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| auto value = DTypeInferValue(primitive, input_args, nullptr); | |||
| MS_EXCEPTION_IF_NULL(value); | |||
| auto type = value->cast<TypePtr>(); | |||
| MS_EXCEPTION_IF_NULL(type); | |||
| auto abstract = std::make_shared<abstract::AbstractType>(type); | |||
| return abstract; | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(DType, prim::kPrimDType, DTypeInfer, DTypeInferValue, false); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,40 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_OPS_DTYPE_H_ | |||
| #define MINDSPORE_CORE_OPS_DTYPE_H_ | |||
| #include <map> | |||
| #include <vector> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "ops/primitive_c.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| class DType : public PrimitiveC { | |||
| public: | |||
| DType() : PrimitiveC(prim::kPrimDType->name()) { InitIOName({"x"}, {"output"}); } | |||
| ~DType() = default; | |||
| MS_DECLARE_PARENT(DType, PrimitiveC); | |||
| void Init() {} | |||
| }; | |||
| using PrimDTypePtr = std::shared_ptr<DType>; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_OPS_DTYPE_H_ | |||
| @@ -31,7 +31,10 @@ AbstractBasePtr ShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP | |||
| // infer shape | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto op_name = primitive->name(); | |||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), op_name); | |||
| CheckAndConvertUtils::CheckInteger("shape infer", input_args.size(), kEqual, 1, op_name); | |||
| MS_EXCEPTION_IF_NULL(input_args[0]); | |||
| auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape()); | |||
| auto in_shape = shape_map[kShape]; | |||
| // infer type | |||
| AbstractBasePtrList abs_list; | |||
| std::transform(in_shape.begin(), in_shape.end(), std::back_inserter(abs_list), | |||
| @@ -39,9 +42,20 @@ AbstractBasePtr ShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP | |||
| return std::make_shared<abstract::AbstractScalar>(item); | |||
| }); | |||
| auto abs = std::make_shared<abstract::AbstractTuple>(abs_list); | |||
| abs->set_value(MakeValue(in_shape)); | |||
| return abs; | |||
| } | |||
| REGISTER_PRIMITIVE_C(kNameShape, Shape); | |||
| ValuePtr ShapeInferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args, | |||
| const AbstractBasePtr &infer) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto op_name = primitive->name(); | |||
| CheckAndConvertUtils::CheckInteger("shape infer", input_args.size(), kEqual, 1, op_name); | |||
| MS_EXCEPTION_IF_NULL(input_args[0]); | |||
| auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape()); | |||
| auto inshape = shape_map[kShape]; | |||
| auto value = MakeValue(inshape); | |||
| return value; | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(Shape, prim::kPrimShape, ShapeInfer, ShapeInferValue, false); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -26,17 +26,13 @@ | |||
| namespace mindspore { | |||
| namespace ops { | |||
| constexpr auto kNameShape = "Shape"; | |||
| class Shape : public PrimitiveC { | |||
| public: | |||
| Shape() : PrimitiveC(kNameShape) {} | |||
| Shape() : PrimitiveC(prim::kPrimShape->name()) {} | |||
| ~Shape() = default; | |||
| MS_DECLARE_PARENT(Shape, PrimitiveC); | |||
| void Init() {} | |||
| }; | |||
| AbstractBasePtr ShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| using PrimShapePtr = std::shared_ptr<Shape>; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -197,7 +197,7 @@ class ExpandDims(PrimitiveWithInfer): | |||
| return out | |||
| class DType(PrimitiveWithInfer): | |||
| class DType(Primitive): | |||
| """ | |||
| Returns the data type of the input tensor as mindspore.dtype. | |||
| @@ -224,14 +224,6 @@ class DType(PrimitiveWithInfer): | |||
| def __init__(self): | |||
| """Initialize DType""" | |||
| def __infer__(self, x): | |||
| addition_error_info = 'Perhaps you are using a mixture of tensors and scalars to operate.' | |||
| validator.check_subclass("input_x", x['dtype'], mstype.tensor, self.name, addition_error_info) | |||
| out = {'shape': (), | |||
| 'dtype': mstype.type_type, | |||
| 'value': x['dtype'].element_type()} | |||
| return out | |||
| class SameTypeShape(PrimitiveWithInfer): | |||
| """ | |||
| @@ -549,7 +541,7 @@ class Reshape(PrimitiveWithInfer): | |||
| return out | |||
| class Shape(PrimitiveWithInfer): | |||
| class Shape(Primitive): | |||
| """ | |||
| Returns the shape of the input tensor. | |||
| @@ -578,13 +570,6 @@ class Shape(PrimitiveWithInfer): | |||
| def __init__(self): | |||
| """Initialize Shape""" | |||
| def __infer__(self, x): | |||
| validator.check_subclass("input_x", x['dtype'], mstype.tensor, self.name) | |||
| out = {'shape': (), | |||
| 'dtype': mstype.tuple_, | |||
| 'value': tuple(x['shape'])} | |||
| return out | |||
| class DynamicShape(Primitive): | |||
| """ | |||
| @@ -129,7 +129,7 @@ class Flatten(PrimitiveWithInfer): | |||
| return input_x | |||
| class Softmax(PrimitiveWithInfer): | |||
| class Softmax(Primitive): | |||
| r""" | |||
| Softmax operation. | |||