Browse Source

Transpose

tags/v1.3.0
shen_jingxing 4 years ago
parent
commit
413862059b
6 changed files with 101 additions and 38 deletions
  1. +2
    -1
      mindspore/core/base/core_ops.h
  2. +69
    -2
      mindspore/core/ops/transpose.cc
  3. +8
    -3
      mindspore/core/ops/transpose.h
  4. +19
    -0
      mindspore/core/utils/check_convert_utils.cc
  5. +2
    -0
      mindspore/core/utils/check_convert_utils.h
  6. +1
    -32
      mindspore/ops/operations/array_ops.py

+ 2
- 1
mindspore/core/base/core_ops.h View File

@@ -77,6 +77,7 @@ constexpr auto kFastGeLUGrad = "FastGeLUGrad";
constexpr auto kZerosLike = "ZerosLike";
constexpr auto kOnesLike = "OnesLike";
constexpr auto kDynamicBroadcastGradientArgs = "DynamicBroadcastGradientArgs";
constexpr auto kTranspose = "Transpose";

// NN
constexpr auto kCTCLoss = "CTCLoss";
@@ -156,7 +157,7 @@ inline const PrimitivePtr kPrimCast = std::make_shared<Primitive>("Cast");
inline const PrimitivePtr kPrimConcat = std::make_shared<Primitive>("Concat");
inline const PrimitivePtr kPrimSqueeze = std::make_shared<Primitive>("Squeeze");
inline const PrimitivePtr kPrimUnsqueeze = std::make_shared<Primitive>("Unsqueeze");
inline const PrimitivePtr kPrimTranspose = std::make_shared<Primitive>("Transpose");
inline const PrimitivePtr kPrimTranspose = std::make_shared<Primitive>(kTranspose);
inline const PrimitivePtr kPrimGatherV2 = std::make_shared<Primitive>("GatherV2");
inline const PrimitivePtr kPrimGatherD = std::make_shared<Primitive>("GatherD");
inline const PrimitivePtr kPrimGather = std::make_shared<Primitive>("Gather");


+ 69
- 2
mindspore/core/ops/transpose.cc View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -15,12 +15,79 @@
*/

#include "ops/transpose.h"
#include <vector>
#include <memory>
#include <algorithm>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"

namespace mindspore {
namespace ops {
REGISTER_PRIMITIVE_C(kNameTranspose, Transpose);
namespace {
// Infers the output shape of Transpose: out_shape[i] = x_shape[perm[i]].
// perm comes either from the "perm" attribute (single-input call) or from the
// value of the second input; it must be a permutation of [0, rank(x)).
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto op_name = primitive->name();
  MS_EXCEPTION_IF_NULL(input_args[0]);
  // Build the shape map once instead of re-deriving it for shape/min/max.
  auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
  auto x_shape = shape_map[kShape];
  auto x_min_shape = shape_map[kMinShape];
  auto x_max_shape = shape_map[kMaxShape];
  ShapeVector p_value;
  if (input_args.size() == 1) {
    // No perm input tensor: read the permutation from the primitive attribute.
    ValuePtr perm = primitive->GetAttr("perm");
    MS_EXCEPTION_IF_NULL(perm);  // GetAttr returns null when the attr is absent
    auto perm_val = perm->cast<ValueTuplePtr>();
    MS_EXCEPTION_IF_NULL(perm_val);
    auto perm_val_data = perm_val->value();
    (void)std::transform(std::begin(perm_val_data), std::end(perm_val_data), std::back_inserter(p_value),
                         [](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });
  } else {
    // The value being validated is the permutation, so report it as "perm".
    p_value = CheckAndConvertUtils::CheckAttrTupleInt("perm", input_args[1]->BuildValue(), op_name);
  }
  if (x_shape.size() != p_value.size()) {
    MS_EXCEPTION(ValueError) << "The dimension of x " << x_shape.size() << " and perm " << p_value.size()
                             << " must be equal.";
  }
  // Each perm element must be a valid axis index.
  for (auto i : p_value) {
    CheckAndConvertUtils::CheckInteger("perm element", i, kGreaterEqual, 0, op_name);
    CheckAndConvertUtils::CheckInteger("perm element", i, kLessThan, p_value.size(), op_name);
  }
  // No axis may repeat: sort a copy and look for adjacent duplicates.
  ShapeVector sorted_perm(p_value);
  std::sort(sorted_perm.begin(), sorted_perm.end());
  if (std::adjacent_find(sorted_perm.begin(), sorted_perm.end()) != sorted_perm.end()) {
    MS_EXCEPTION(ValueError) << "The value of perm is wrong";
  }
  // Permute the input dimensions (int64_t throughout; no narrowing to int).
  ShapeVector in_shape;
  (void)std::transform(p_value.begin(), p_value.end(), std::back_inserter(in_shape),
                       [&x_shape](int64_t axis) { return x_shape[axis]; });
  if (x_min_shape.empty() || x_max_shape.empty()) {
    return std::make_shared<abstract::Shape>(in_shape);
  }
  // Dynamic shape: permute the min/max bounds with the same permutation.
  ShapeVector min_shape;
  ShapeVector max_shape;
  for (auto axis : p_value) {
    min_shape.push_back(x_min_shape[axis]);
    max_shape.push_back(x_max_shape[axis]);
  }
  return std::make_shared<abstract::Shape>(in_shape, min_shape, max_shape);
}

// Validates that input x is a tensor type; returns the type approved by the
// subclass checker.
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(prim);
  const auto &prim_name = prim->name();
  auto x_type = input_args[0]->BuildType();
  return CheckAndConvertUtils::CheckSubClass("x", x_type, {kTensorType}, prim_name);
}
} // namespace

// Top-level infer entry for Transpose: validates the argument count, then
// combines the shape- and type-inference results into one abstract value.
AbstractBasePtr TransposeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  CheckAndConvertUtils::CheckInteger("Transpose infer", input_args.size(), kGreaterEqual, 1, primitive->name());
  auto out_shape = InferShape(primitive, input_args);
  auto out_type = InferType(primitive, input_args);
  return abstract::MakeAbstract(out_shape, out_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(Transpose, prim::kPrimTranspose, TransposeInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

+ 8
- 3
mindspore/core/ops/transpose.h View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -16,20 +16,25 @@

#ifndef MINDSPORE_CORE_OPS_TRANSPOSE_H_
#define MINDSPORE_CORE_OPS_TRANSPOSE_H_
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"

namespace mindspore {
namespace ops {
constexpr auto kNameTranspose = "Transpose";
constexpr auto kNameTranspose = prim::kTranspose;
class Transpose : public PrimitiveC {
public:
Transpose() : PrimitiveC(kNameTranspose) { InitIOName({"x", "perm"}, {"output"}); }
Transpose() : PrimitiveC(prim::kTranspose) { InitIOName({"x", "perm"}, {"output"}); }
~Transpose() = default;
MS_DECLARE_PARENT(Transpose, PrimitiveC);
void Init() {}
};
AbstractBasePtr TransposeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimitiveTransposePtr = std::shared_ptr<Transpose>;
} // namespace ops
} // namespace mindspore



+ 19
- 0
mindspore/core/utils/check_convert_utils.cc View File

@@ -627,6 +627,25 @@ std::vector<int64_t> CheckAndConvertUtils::CheckAttrIntOrTupleInt(const std::str
return result;
}

// Converts a ValueTuple of Int64 elements into std::vector<int64_t>.
// Raises TypeError when the value is not a tuple or any element is not Int64.
std::vector<int64_t> CheckAndConvertUtils::CheckAttrTupleInt(const std::string &arg_name, const ValuePtr &attr,
                                                             const std::string &prim_name) {
  MS_EXCEPTION_IF_NULL(attr);
  if (!attr->isa<ValueTuple>()) {
    // Fixed message: space before the argument name ("of" + name read as one word).
    MS_EXCEPTION(TypeError) << "For " << prim_name << ", the type of " << arg_name << " must be Tuple";
  }
  std::vector<int64_t> result;
  std::vector<ValuePtr> attr_vec = attr->cast<ValueTuplePtr>()->value();
  // Capture by reference: [=] copied both std::string parameters per element.
  (void)std::transform(attr_vec.begin(), attr_vec.end(), std::back_inserter(result),
                       [&](const ValuePtr &e) -> int64_t {
                         if (!e->isa<Int64Imm>()) {
                           MS_EXCEPTION(TypeError)
                             << "For " << prim_name << ", the element type of " << arg_name << " must be Int64";
                         }
                         return GetValue<int64_t>(e);
                       });
  return result;
}

void CheckAndConvertUtils::CheckMinMaxShape(const ShapeVector &shape, ShapeVector *min_shape, ShapeVector *max_shape) {
*min_shape = (*min_shape).empty() ? shape : *min_shape;
*max_shape = (*max_shape).empty() ? shape : *max_shape;


+ 2
- 0
mindspore/core/utils/check_convert_utils.h View File

@@ -303,6 +303,8 @@ class CheckAndConvertUtils {
static void CheckMode(const std::string &class_name);
static std::vector<int64_t> CheckAttrIntOrTupleInt(const std::string &prim_name, const ValuePtr &attr,
const std::string &arg_name);
static std::vector<int64_t> CheckAttrTupleInt(const std::string &prim_name, const ValuePtr &attr,
const std::string &arg_name);
static void CheckMinMaxShape(const ShapeVector &shape, ShapeVector *min_shape, ShapeVector *max_shape);
static int64_t GetAndCheckFormat(const ValuePtr &value);
static int64_t GetRemoveMonadAbsNum(const AbstractBasePtrList &abs_list);


+ 1
- 32
mindspore/ops/operations/array_ops.py View File

@@ -684,7 +684,7 @@ class Squeeze(PrimitiveWithInfer):
return x_dtype


class Transpose(PrimitiveWithInfer):
class Transpose(Primitive):
"""
Permutes the dimensions of the input tensor according to input permutation.

@@ -725,37 +725,6 @@ class Transpose(PrimitiveWithInfer):
"""Initialize Transpose"""
self.init_prim_io_names(inputs=['x', 'perm'], outputs=['output'])

def __infer__(self, x, perm):
x_shape = x['shape']
p_value = perm['value']
x_type = x['dtype']
validator.check_value_type("p_value", p_value, [tuple], self.name)
validator.check_subclass("x_type", x_type, mstype.tensor, self.name)
if len(x_shape) != len(p_value):
raise ValueError('The dimension of x and perm must be equal.')
tmp = list(p_value)
for i, dim in enumerate(p_value):
validator.check_int(dim, 0, Rel.GE, f'perm[{i}]', self.name)
validator.check_int(dim, len(p_value), Rel.LT, f'perm[{i}]', self.name)
tmp.remove(dim)
if dim in tmp:
raise ValueError('The value of perm is wrong.')
out_shapes = []
for i in p_value:
out_shapes.append(x_shape[i])
out = {'shape': tuple(out_shapes),
'dtype': x['dtype'],
'value': None}
if 'min_shape' in x and 'max_shape' in x:
min_vec = []
max_vec = []
for i in p_value:
min_vec.append(x['min_shape'][i])
max_vec.append(x['max_shape'][i])
out['min_shape'] = tuple(min_vec)
out['max_shape'] = tuple(max_vec)
return out


class Unique(Primitive):
"""


Loading…
Cancel
Save