From: @liangzhibo Reviewed-by: @zh_qh,@ginfung Signed-off-by: @zh_qhpull/14948/MERGE
| @@ -0,0 +1,54 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ops/dtype.h" | |||||
| #include <string> | |||||
| #include <algorithm> | |||||
| #include <memory> | |||||
| #include <set> | |||||
| #include <vector> | |||||
| #include "ops/op_utils.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| #include "abstract/abstract_value.h" | |||||
| #include "abstract/primitive_infer_map.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| ValuePtr DTypeInferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args, | |||||
| const AbstractBasePtr &infer) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto op_name = primitive->name(); | |||||
| CheckAndConvertUtils::CheckInteger("dtype infer", input_args.size(), kEqual, 1, op_name); | |||||
| MS_EXCEPTION_IF_NULL(input_args[0]); | |||||
| const std::set<TypePtr> valid_types = {kTensorType}; | |||||
| auto type = | |||||
| CheckAndConvertUtils::CheckTensorTypeValid("infer type", input_args[0]->BuildType(), valid_types, op_name); | |||||
| return type; | |||||
| } | |||||
| AbstractBasePtr DTypeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args) { | |||||
| auto value = DTypeInferValue(primitive, input_args, nullptr); | |||||
| MS_EXCEPTION_IF_NULL(value); | |||||
| auto type = value->cast<TypePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(type); | |||||
| auto abstract = std::make_shared<abstract::AbstractType>(type); | |||||
| return abstract; | |||||
| } | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(DType, prim::kPrimDType, DTypeInfer, DTypeInferValue, false); | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,40 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CORE_OPS_DTYPE_H_ | |||||
| #define MINDSPORE_CORE_OPS_DTYPE_H_ | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include "ops/primitive_c.h" | |||||
| #include "abstract/abstract_value.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| namespace mindspore { | |||||
| namespace ops { | |||||
| class DType : public PrimitiveC { | |||||
| public: | |||||
| DType() : PrimitiveC(prim::kPrimDType->name()) { InitIOName({"x"}, {"output"}); } | |||||
| ~DType() = default; | |||||
| MS_DECLARE_PARENT(DType, PrimitiveC); | |||||
| void Init() {} | |||||
| }; | |||||
| using PrimDTypePtr = std::shared_ptr<DType>; | |||||
| } // namespace ops | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CORE_OPS_DTYPE_H_ | |||||
| @@ -31,7 +31,10 @@ AbstractBasePtr ShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP | |||||
| // infer shape | // infer shape | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto op_name = primitive->name(); | auto op_name = primitive->name(); | ||||
| auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), op_name); | |||||
| CheckAndConvertUtils::CheckInteger("shape infer", input_args.size(), kEqual, 1, op_name); | |||||
| MS_EXCEPTION_IF_NULL(input_args[0]); | |||||
| auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape()); | |||||
| auto in_shape = shape_map[kShape]; | |||||
| // infer type | // infer type | ||||
| AbstractBasePtrList abs_list; | AbstractBasePtrList abs_list; | ||||
| std::transform(in_shape.begin(), in_shape.end(), std::back_inserter(abs_list), | std::transform(in_shape.begin(), in_shape.end(), std::back_inserter(abs_list), | ||||
| @@ -39,9 +42,20 @@ AbstractBasePtr ShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP | |||||
| return std::make_shared<abstract::AbstractScalar>(item); | return std::make_shared<abstract::AbstractScalar>(item); | ||||
| }); | }); | ||||
| auto abs = std::make_shared<abstract::AbstractTuple>(abs_list); | auto abs = std::make_shared<abstract::AbstractTuple>(abs_list); | ||||
| abs->set_value(MakeValue(in_shape)); | |||||
| return abs; | return abs; | ||||
| } | } | ||||
| REGISTER_PRIMITIVE_C(kNameShape, Shape); | |||||
| ValuePtr ShapeInferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args, | |||||
| const AbstractBasePtr &infer) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto op_name = primitive->name(); | |||||
| CheckAndConvertUtils::CheckInteger("shape infer", input_args.size(), kEqual, 1, op_name); | |||||
| MS_EXCEPTION_IF_NULL(input_args[0]); | |||||
| auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape()); | |||||
| auto inshape = shape_map[kShape]; | |||||
| auto value = MakeValue(inshape); | |||||
| return value; | |||||
| } | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(Shape, prim::kPrimShape, ShapeInfer, ShapeInferValue, false); | |||||
| } // namespace ops | } // namespace ops | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -26,17 +26,13 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | namespace ops { | ||||
| constexpr auto kNameShape = "Shape"; | |||||
| class Shape : public PrimitiveC { | class Shape : public PrimitiveC { | ||||
| public: | public: | ||||
| Shape() : PrimitiveC(kNameShape) {} | |||||
| Shape() : PrimitiveC(prim::kPrimShape->name()) {} | |||||
| ~Shape() = default; | ~Shape() = default; | ||||
| MS_DECLARE_PARENT(Shape, PrimitiveC); | MS_DECLARE_PARENT(Shape, PrimitiveC); | ||||
| void Init() {} | void Init() {} | ||||
| }; | }; | ||||
| AbstractBasePtr ShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args); | |||||
| using PrimShapePtr = std::shared_ptr<Shape>; | using PrimShapePtr = std::shared_ptr<Shape>; | ||||
| } // namespace ops | } // namespace ops | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -197,7 +197,7 @@ class ExpandDims(PrimitiveWithInfer): | |||||
| return out | return out | ||||
| class DType(PrimitiveWithInfer): | |||||
| class DType(Primitive): | |||||
| """ | """ | ||||
| Returns the data type of the input tensor as mindspore.dtype. | Returns the data type of the input tensor as mindspore.dtype. | ||||
| @@ -224,14 +224,6 @@ class DType(PrimitiveWithInfer): | |||||
| def __init__(self): | def __init__(self): | ||||
| """Initialize DType""" | """Initialize DType""" | ||||
| def __infer__(self, x): | |||||
| addition_error_info = 'Perhaps you are using a mixture of tensors and scalars to operate.' | |||||
| validator.check_subclass("input_x", x['dtype'], mstype.tensor, self.name, addition_error_info) | |||||
| out = {'shape': (), | |||||
| 'dtype': mstype.type_type, | |||||
| 'value': x['dtype'].element_type()} | |||||
| return out | |||||
| class SameTypeShape(PrimitiveWithInfer): | class SameTypeShape(PrimitiveWithInfer): | ||||
| """ | """ | ||||
| @@ -549,7 +541,7 @@ class Reshape(PrimitiveWithInfer): | |||||
| return out | return out | ||||
| class Shape(PrimitiveWithInfer): | |||||
| class Shape(Primitive): | |||||
| """ | """ | ||||
| Returns the shape of the input tensor. | Returns the shape of the input tensor. | ||||
| @@ -578,13 +570,6 @@ class Shape(PrimitiveWithInfer): | |||||
| def __init__(self): | def __init__(self): | ||||
| """Initialize Shape""" | """Initialize Shape""" | ||||
| def __infer__(self, x): | |||||
| validator.check_subclass("input_x", x['dtype'], mstype.tensor, self.name) | |||||
| out = {'shape': (), | |||||
| 'dtype': mstype.tuple_, | |||||
| 'value': tuple(x['shape'])} | |||||
| return out | |||||
| class DynamicShape(Primitive): | class DynamicShape(Primitive): | ||||
| """ | """ | ||||
| @@ -129,7 +129,7 @@ class Flatten(PrimitiveWithInfer): | |||||
| return input_x | return input_x | ||||
| class Softmax(PrimitiveWithInfer): | |||||
| class Softmax(Primitive): | |||||
| r""" | r""" | ||||
| Softmax operation. | Softmax operation. | ||||