| @@ -77,6 +77,7 @@ constexpr auto kFastGeLUGrad = "FastGeLUGrad"; | |||
| constexpr auto kZerosLike = "ZerosLike"; | |||
| constexpr auto kOnesLike = "OnesLike"; | |||
| constexpr auto kDynamicBroadcastGradientArgs = "DynamicBroadcastGradientArgs"; | |||
| constexpr auto kTranspose = "Transpose"; | |||
| // NN | |||
| constexpr auto kCTCLoss = "CTCLoss"; | |||
| @@ -156,7 +157,7 @@ inline const PrimitivePtr kPrimCast = std::make_shared<Primitive>("Cast"); | |||
| inline const PrimitivePtr kPrimConcat = std::make_shared<Primitive>("Concat"); | |||
| inline const PrimitivePtr kPrimSqueeze = std::make_shared<Primitive>("Squeeze"); | |||
| inline const PrimitivePtr kPrimUnsqueeze = std::make_shared<Primitive>("Unsqueeze"); | |||
| inline const PrimitivePtr kPrimTranspose = std::make_shared<Primitive>("Transpose"); | |||
| inline const PrimitivePtr kPrimTranspose = std::make_shared<Primitive>(kTranspose); | |||
| inline const PrimitivePtr kPrimGatherV2 = std::make_shared<Primitive>("GatherV2"); | |||
| inline const PrimitivePtr kPrimGatherD = std::make_shared<Primitive>("GatherD"); | |||
| inline const PrimitivePtr kPrimGather = std::make_shared<Primitive>("Gather"); | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -15,12 +15,79 @@ | |||
| */ | |||
| #include "ops/transpose.h" | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <algorithm> | |||
| #include "ops/op_utils.h" | |||
| #include "utils/check_convert_utils.h" | |||
| #include "abstract/primitive_infer_map.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| REGISTER_PRIMITIVE_C(kNameTranspose, Transpose); | |||
| namespace { | |||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto op_name = primitive->name(); | |||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||
| auto x_min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kMinShape]; | |||
| auto x_max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kMaxShape]; | |||
| ShapeVector p_value; | |||
| if (input_args.size() == 1) { | |||
| ValuePtr perm = primitive->GetAttr("perm"); | |||
| auto perm_val = perm->cast<ValueTuplePtr>(); | |||
| MS_EXCEPTION_IF_NULL(perm_val); | |||
| auto perm_val_data = perm_val->value(); | |||
| (void)std::transform(std::begin(perm_val_data), std::end(perm_val_data), std::back_inserter(p_value), | |||
| [](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); }); | |||
| } else { | |||
| p_value = CheckAndConvertUtils::CheckAttrTupleInt("shape", input_args[1]->BuildValue(), op_name); | |||
| } | |||
| if (x_shape.size() != p_value.size()) { | |||
| MS_EXCEPTION(ValueError) << "The dimension of x " << x_shape.size() << " and perm " << p_value.size() | |||
| << " must be equal."; | |||
| } | |||
| for (auto i : p_value) { | |||
| CheckAndConvertUtils::CheckInteger("perm element", i, kGreaterEqual, 0, op_name); | |||
| CheckAndConvertUtils::CheckInteger("perm element", i, kLessThan, p_value.size(), op_name); | |||
| } | |||
| std::vector<int64_t> tmp(p_value); | |||
| for (auto it = tmp.begin(); it != tmp.end();) { | |||
| auto dim = *it; | |||
| if (!tmp.empty()) { | |||
| it = tmp.erase(it); | |||
| } | |||
| if (std::find(tmp.begin(), tmp.end(), dim) != tmp.end()) { | |||
| MS_EXCEPTION(ValueError) << "The value of perm is wrong"; | |||
| } | |||
| } | |||
| std::vector<int64_t> in_shape(p_value); | |||
| std::transform(in_shape.begin(), in_shape.end(), in_shape.begin(), [x_shape](int i) { return x_shape[i]; }); | |||
| if (!x_min_shape.empty() && !x_max_shape.empty()) { | |||
| std::vector<int64_t> min_shape; | |||
| std::vector<int64_t> max_shape; | |||
| for (auto i : p_value) { | |||
| min_shape.push_back(x_min_shape[i]); | |||
| max_shape.push_back(x_max_shape[i]); | |||
| } | |||
| return std::make_shared<abstract::Shape>(in_shape, min_shape, max_shape); | |||
| } else { | |||
| return std::make_shared<abstract::Shape>(in_shape); | |||
| } | |||
| } | |||
| TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| return CheckAndConvertUtils::CheckSubClass("x", input_args[0]->BuildType(), {kTensorType}, prim->name()); | |||
| } | |||
| } // namespace | |||
| AbstractBasePtr TransposeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| CheckAndConvertUtils::CheckInteger("Transpose infer", input_args.size(), kGreaterEqual, 1, primitive->name()); | |||
| auto abs = abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args)); | |||
| return abs; | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(Transpose, prim::kPrimTranspose, TransposeInfer, nullptr, true); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -16,20 +16,25 @@ | |||
| #ifndef MINDSPORE_CORE_OPS_TRANSPOSE_H_ | |||
| #define MINDSPORE_CORE_OPS_TRANSPOSE_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "ops/primitive_c.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| constexpr auto kNameTranspose = "Transpose"; | |||
| constexpr auto kNameTranspose = prim::kTranspose; | |||
| class Transpose : public PrimitiveC { | |||
| public: | |||
| Transpose() : PrimitiveC(kNameTranspose) { InitIOName({"x", "perm"}, {"output"}); } | |||
| Transpose() : PrimitiveC(prim::kTranspose) { InitIOName({"x", "perm"}, {"output"}); } | |||
| ~Transpose() = default; | |||
| MS_DECLARE_PARENT(Transpose, PrimitiveC); | |||
| void Init() {} | |||
| }; | |||
| AbstractBasePtr TransposeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| using PrimitiveTransposePtr = std::shared_ptr<Transpose>; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -627,6 +627,25 @@ std::vector<int64_t> CheckAndConvertUtils::CheckAttrIntOrTupleInt(const std::str | |||
| return result; | |||
| } | |||
| std::vector<int64_t> CheckAndConvertUtils::CheckAttrTupleInt(const std::string &arg_name, const ValuePtr &attr, | |||
| const std::string &prim_name) { | |||
| std::vector<int64_t> result; | |||
| MS_EXCEPTION_IF_NULL(attr); | |||
| if (attr->isa<ValueTuple>()) { | |||
| std::vector<ValuePtr> attr_vec = attr->cast<ValueTuplePtr>()->value(); | |||
| (void)std::transform( | |||
| attr_vec.begin(), attr_vec.end(), std::back_inserter(result), [=](const ValuePtr &e) -> int64_t { | |||
| if (!e->isa<Int64Imm>()) { | |||
| MS_EXCEPTION(TypeError) << "For " << prim_name << ", the element type of" << arg_name << " must be Int64"; | |||
| } | |||
| return GetValue<int64_t>(e); | |||
| }); | |||
| } else { | |||
| MS_EXCEPTION(TypeError) << "For " << prim_name << ", the type of" << arg_name << " must be Tuple"; | |||
| } | |||
| return result; | |||
| } | |||
| void CheckAndConvertUtils::CheckMinMaxShape(const ShapeVector &shape, ShapeVector *min_shape, ShapeVector *max_shape) { | |||
| *min_shape = (*min_shape).empty() ? shape : *min_shape; | |||
| *max_shape = (*max_shape).empty() ? shape : *max_shape; | |||
| @@ -303,6 +303,8 @@ class CheckAndConvertUtils { | |||
| static void CheckMode(const std::string &class_name); | |||
| static std::vector<int64_t> CheckAttrIntOrTupleInt(const std::string &prim_name, const ValuePtr &attr, | |||
| const std::string &arg_name); | |||
| static std::vector<int64_t> CheckAttrTupleInt(const std::string &prim_name, const ValuePtr &attr, | |||
| const std::string &arg_name); | |||
| static void CheckMinMaxShape(const ShapeVector &shape, ShapeVector *min_shape, ShapeVector *max_shape); | |||
| static int64_t GetAndCheckFormat(const ValuePtr &value); | |||
| static int64_t GetRemoveMonadAbsNum(const AbstractBasePtrList &abs_list); | |||
| @@ -684,7 +684,7 @@ class Squeeze(PrimitiveWithInfer): | |||
| return x_dtype | |||
| class Transpose(PrimitiveWithInfer): | |||
| class Transpose(Primitive): | |||
| """ | |||
| Permutes the dimensions of the input tensor according to input permutation. | |||
| @@ -725,37 +725,6 @@ class Transpose(PrimitiveWithInfer): | |||
| """Initialize Transpose""" | |||
| self.init_prim_io_names(inputs=['x', 'perm'], outputs=['output']) | |||
| def __infer__(self, x, perm): | |||
| x_shape = x['shape'] | |||
| p_value = perm['value'] | |||
| x_type = x['dtype'] | |||
| validator.check_value_type("p_value", p_value, [tuple], self.name) | |||
| validator.check_subclass("x_type", x_type, mstype.tensor, self.name) | |||
| if len(x_shape) != len(p_value): | |||
| raise ValueError('The dimension of x and perm must be equal.') | |||
| tmp = list(p_value) | |||
| for i, dim in enumerate(p_value): | |||
| validator.check_int(dim, 0, Rel.GE, f'perm[{i}]', self.name) | |||
| validator.check_int(dim, len(p_value), Rel.LT, f'perm[{i}]', self.name) | |||
| tmp.remove(dim) | |||
| if dim in tmp: | |||
| raise ValueError('The value of perm is wrong.') | |||
| out_shapes = [] | |||
| for i in p_value: | |||
| out_shapes.append(x_shape[i]) | |||
| out = {'shape': tuple(out_shapes), | |||
| 'dtype': x['dtype'], | |||
| 'value': None} | |||
| if 'min_shape' in x and 'max_shape' in x: | |||
| min_vec = [] | |||
| max_vec = [] | |||
| for i in p_value: | |||
| min_vec.append(x['min_shape'][i]) | |||
| max_vec.append(x['max_shape'][i]) | |||
| out['min_shape'] = tuple(min_vec) | |||
| out['max_shape'] = tuple(max_vec) | |||
| return out | |||
| class Unique(Primitive): | |||
| """ | |||