Merge pull request !3555 from fary86/support_dynamic_min_max_shape
@@ -113,6 +113,8 @@ inline const PrimitivePtr KPrimTransData = std::make_shared<Primitive>("TransDat
 inline const PrimitivePtr kPrimNMSWithMask = std::make_shared<Primitive>("NMSWithMask");
 inline const PrimitivePtr kPrimPad = std::make_shared<Primitive>("Pad");
 inline const PrimitivePtr kPrimArgMaxWithValue = std::make_shared<Primitive>("ArgMaxWithValue");
+inline const PrimitivePtr kPrimUnique = std::make_shared<Primitive>("Unique");
+inline const PrimitivePtr kPrimUniqueGrad = std::make_shared<Primitive>("UniqueGrad");
 // NN
 inline const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten");
@@ -148,5 +148,47 @@ AbstractBasePtr InferImplPack(const AnalysisEnginePtr &, const PrimitivePtr &pri
   ret->set_shape(std::make_shared<Shape>(shape));
   return ret;
 }
+
+AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                const AbstractBasePtrList &args_spec_list) {
+  // inputs: a 1-d Tensor
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 1);
+  AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  auto shape = input->shape();
+  if (shape->shape().size() != 1) {
+    MS_LOG(EXCEPTION) << "Rank of " << op_name << "'s input must be 1.";
+  }
+  std::vector<int> ids_shape = {Shape::SHP_ANY};
+  std::vector<int> min_shape = {1};
+  std::vector<int> max_shape = shape->shape();
+  auto ids =
+    std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(ids_shape, min_shape, max_shape));
+  auto ids_idx = std::make_shared<AbstractTensor>(std::make_shared<Int>(32), shape->shape());
+  // outputs: ids, ids_idx
+  AbstractBasePtrList elements = {ids, ids_idx};
+  return std::make_shared<AbstractTuple>(elements);
+}
+
+AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                    const AbstractBasePtrList &args_spec_list) {
+  // inputs: dout (a tuple of two 1-d Tensors) and the forward input
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 2);
+  AbstractTuplePtr dout = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
+  CheckArgsSize(op_name + " dout", dout->elements(), 2);
+  auto ids = CheckArg<AbstractTensor>(op_name, dout->elements(), 0);
+  auto ids_idx = CheckArg<AbstractTensor>(op_name, dout->elements(), 1);
+  if (ids->shape()->shape().size() != 1) {
+    MS_LOG(EXCEPTION) << "Rank of dout[0] of " << op_name << "'s input must be 1.";
+  }
+  if (ids_idx->shape()->shape().size() != 1) {
+    MS_LOG(EXCEPTION) << "Rank of dout[1] of " << op_name << "'s input must be 1.";
+  }
+  // outputs: dx
+  return std::make_shared<AbstractTensor>(ids->element(), ids_idx->shape());
+}
+
 }  // namespace abstract
 }  // namespace mindspore
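For reference, a minimal numpy sketch of the contract InferImplUnique encodes (the numpy call is only an illustration, not the kernel): for a 1-d input of length n, `ids` is dynamic with bounds [1, n], while `ids_idx` keeps the static length n.

import numpy as np

x = np.array([1, 2, 5, 2])
ids, ids_idx = np.unique(x, return_inverse=True)
print(ids)       # [1 2 5]   -> abstract shape (-1,), min_shape (1,), max_shape (4,)
print(ids_idx)   # [0 1 2 1] -> static shape (4,), int32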
@@ -23,6 +23,7 @@
 #include <mutex>
 #include <string>
 #include <utility>
+#include <unordered_set>
 #include "frontend/operator/cc_implementations.h"
 #include "frontend/operator/ops.h"
@@ -62,6 +63,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}},
     {prim::kPrimBroadcastShape, {InferImplBroadCastShape, true}},
     {prim::kPrimPack, {InferImplPack, true}},
+    {prim::kPrimUnique, {InferImplUnique, true}},
+    {prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}},
     // Structure
     {prim::kPrimMakeTuple, {InferImplMakeTuple, true}},
     {prim::kPrimMakeList, {InferImplMakeList, true}},
@@ -389,6 +392,14 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
   if (abs_base->isa<AbstractTensor>()) {
     auto arg_tensor = dyn_cast<AbstractTensor>(abs_base);
     dic["shape"] = arg_tensor->shape()->shape();
+    if (MsContext::GetInstance()->execution_mode() == kGraphMode) {
+      const auto &min_shape = arg_tensor->shape()->min_shape();
+      const auto &max_shape = arg_tensor->shape()->max_shape();
+      if (!min_shape.empty() && !max_shape.empty()) {
+        dic["min_shape"] = min_shape;
+        dic["max_shape"] = max_shape;
+      }
+    }
     dic["dtype"] = arg_tensor->BuildType();
     dic["value"] = BuildValue(arg_tensor->BuildValue());
   } else if (abs_base->isa<AbstractIndexedSlices>()) {
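With this change, in graph mode the per-argument dict handed to the Python-side infer functions can carry the two extra keys. A hypothetical argument with one dynamic dimension would look roughly like this (the keys follow the code above; the concrete values are made up):

import mindspore

arg = {
    'shape': [-1, 3],            # SHP_ANY (-1) marks the dynamic dimension
    'min_shape': [1, 3],         # only present when min/max shape info exists
    'max_shape': [32, 3],
    'dtype': mindspore.float32,  # a mindspore type object
    'value': None,               # unknown at compile time
}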
@@ -503,7 +514,10 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic
   if (output["value"].is_none()) {
     auto out_shape = output["shape"];
     auto out_dtype = output["dtype"];
-    return PyListDtype2AbstractTensor(out_shape, out_dtype);
+    py::object min_shape = output.contains("min_shape") ? (py::object)output["min_shape"] : (py::object)py::none();
+    py::object max_shape = output.contains("max_shape") ? (py::object)output["max_shape"] : (py::object)py::none();
+    return PyListDtype2AbstractTensor(out_shape, out_dtype, min_shape, max_shape);
   }
   // Convert pyobject to Value, then to AbstractValue
   ValuePtr converted_ret = nullptr;
@@ -244,6 +244,10 @@ AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const Primiti
                                         const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplPack(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                    const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplMakeTuple(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                    const AbstractBasePtrList &args_spec_list);
@@ -371,7 +371,8 @@ py::object VectorRefToPyData(const VectorRef &value_list) {
   return ret;
 }
-AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py::object &type_obj) {
+AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py::object &type_obj,
+                                           const py::object &min_shape, const py::object &max_shape) {
   if ((py::isinstance<py::list>(shape_obj) || py::isinstance<py::tuple>(shape_obj)) && py::isinstance<Type>(type_obj)) {
     auto ret_vec = shape_obj.cast<std::vector<int>>();
     auto ret_dtype = type_obj.cast<TypePtr>();
@@ -382,12 +383,23 @@ AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py
       return abs_scalar;
     }
     AbstractBasePtr tensor = nullptr;
+    std::vector<int> min_shape_vec;
+    std::vector<int> max_shape_vec;
+    if (!min_shape.is_none()) {
+      min_shape_vec = min_shape.cast<std::vector<int>>();
+    }
+    if (!max_shape.is_none()) {
+      max_shape_vec = max_shape.cast<std::vector<int>>();
+    }
+    auto ret_shape = std::make_shared<abstract::Shape>(ret_vec, min_shape_vec, max_shape_vec);
     if (ret_dtype->isa<TensorType>()) {
       auto tensor_type = type_obj.cast<TensorTypePtr>();
       MS_EXCEPTION_IF_NULL(tensor_type);
-      tensor = std::make_shared<abstract::AbstractTensor>(tensor_type->element(), ret_vec);
+      auto element = std::make_shared<abstract::AbstractScalar>(kAnyValue, tensor_type->element());
+      tensor = std::make_shared<abstract::AbstractTensor>(element, ret_shape);
     } else {
-      tensor = std::make_shared<abstract::AbstractTensor>(ret_dtype, ret_vec);
+      auto element = std::make_shared<abstract::AbstractScalar>(kAnyValue, ret_dtype);
+      tensor = std::make_shared<abstract::AbstractTensor>(element, ret_shape);
     }
     return tensor;
   } else if (py::isinstance<py::tuple>(shape_obj) && py::isinstance<py::tuple>(type_obj)) {
@@ -47,7 +47,9 @@ bool BaseRefToInt(const ValuePtr &v, int *value);
 bool ValueToBool(const ValuePtr &in, bool *out);
 py::object ValuePtrToPyData(const ValuePtr &value);
-AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py::object &type_obj);
+AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py::object &type_obj,
+                                           const py::object &min_shape = py::none(),
+                                           const py::object &max_shape = py::none());
 bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple &args,
                                        const std::shared_ptr<py::object> &ret_val);
@@ -67,6 +67,9 @@ std::string Shape::DumpText() const {
   buffer << "[";
   for (size_t i = 0; i < shape_.size(); i++) {
     buffer << (i > 0 ? ", " : "") << shape_[i];
+    if (shape_[i] == SHP_ANY && min_shape_.size() == shape_.size() && max_shape_.size() == shape_.size()) {
+      buffer << "_" << min_shape_[i] << "^" << max_shape_[i];
+    }
   }
   buffer << "]";
   return buffer.str();
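A small Python sketch of the dump format this adds (illustrative only; the function name and sample values are made up): a dynamic dimension prints as `-1_min^max`, so a shape with one dynamic dim bounded by 1..8 renders as `[-1_1^8, 3]`.

def dump_shape(shape, min_shape, max_shape):
    """Mirror Shape::DumpText for illustration: append "_min^max" to dynamic (-1) dims."""
    parts = []
    for i, dim in enumerate(shape):
        text = str(dim)
        if dim == -1 and len(min_shape) == len(shape) and len(max_shape) == len(shape):
            text += "_{}^{}".format(min_shape[i], max_shape[i])
        parts.append(text)
    return "[" + ", ".join(parts) + "]"

print(dump_shape([-1, 3], [1, 3], [8, 3]))  # [-1_1^8, 3]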
@@ -74,16 +74,22 @@ class Shape : public BaseShape {
     (void)std::transform(list.begin(), list.end(), std::back_inserter(shape_),
                          [](const int64_t &value) { return static_cast<int>(value); });
   }
+  Shape(const std::vector<int> &list, const std::vector<int> &min_shape, const std::vector<int> &max_shape)
+      : shape_(list), min_shape_(min_shape), max_shape_(max_shape) {}
   ~Shape() override = default;
   MS_DECLARE_PARENT(Shape, BaseShape)
   std::string ToString() const override;
   std::string DumpText() const override;
   bool operator==(const BaseShape &other) const override;
-  BaseShapePtr Clone() const override { return std::make_shared<Shape>(shape_); }
+  BaseShapePtr Clone() const override { return std::make_shared<Shape>(shape_, min_shape_, max_shape_); }
   void Broaden() override;
   std::vector<int> &shape() { return shape_; }
+  std::vector<int> &min_shape() { return min_shape_; }
+  std::vector<int> &max_shape() { return max_shape_; }

-  std::vector<int> shape_;  // use SHP_ANY to implement the any shape in python
+  std::vector<int> shape_;      // use SHP_ANY to implement the any shape in python
+  std::vector<int> min_shape_;  // record minimum length for each dynamic dimension
+  std::vector<int> max_shape_;  // record maximum length for each dynamic dimension
 };
 using ShapePtr = std::shared_ptr<Shape>;
 using ShapePtrList = std::vector<ShapePtr>;
@@ -55,15 +55,66 @@ ShapePtr ShapeJoin(const ShapePtr &shape1, const ShapePtr &shape2) {
     return shape1;
   }
   std::vector<int> dims;
+  bool has_dynamic_shape = false;
   dims.resize(shape1->shape().size());
   for (std::size_t i = 0; i < shape1->shape().size(); i++) {
     if (shape1->shape()[i] == shape2->shape()[i]) {
       dims[i] = shape1->shape()[i];
+      if (shape1->shape()[i] == Shape::SHP_ANY) {
+        has_dynamic_shape = true;
+      }
     } else {
       dims[i] = Shape::SHP_ANY;
+      has_dynamic_shape = true;
     }
   }
-  return std::make_shared<Shape>(dims);
+  if (!has_dynamic_shape) {
+    return std::make_shared<Shape>(dims);
+  }
+  // calculate dynamic shape
+  std::vector<int> min_dims(dims.size());
+  std::vector<int> max_dims(dims.size());
+  for (size_t i = 0; i < dims.size(); ++i) {
+    if (dims[i] != Shape::SHP_ANY) {
+      min_dims[i] = max_dims[i] = dims[i];
+      continue;
+    }
+    if (shape1->shape()[i] != Shape::SHP_ANY && shape2->shape()[i] != Shape::SHP_ANY) {
+      min_dims[i] = std::min(shape1->shape()[i], shape2->shape()[i]);
+      max_dims[i] = std::max(shape1->shape()[i], shape2->shape()[i]);
+      continue;
+    }
+    if (shape1->shape()[i] == Shape::SHP_ANY && shape2->shape()[i] != Shape::SHP_ANY) {
+      if (shape1->min_shape().empty() || shape1->max_shape().empty()) {
+        MS_EXCEPTION(ValueError) << "Shape " << shape1->ToString()
+                                 << " has dynamic shape, but does not have min/max shape info.";
+      }
+      min_dims[i] = std::min(shape1->min_shape()[i], shape2->shape()[i]);
+      max_dims[i] = std::max(shape1->max_shape()[i], shape2->shape()[i]);
+      continue;
+    }
+    if (shape1->shape()[i] != Shape::SHP_ANY && shape2->shape()[i] == Shape::SHP_ANY) {
+      if (shape2->min_shape().empty() || shape2->max_shape().empty()) {
+        MS_EXCEPTION(ValueError) << "Shape " << shape2->ToString()
+                                 << " has dynamic shape, but does not have min/max shape info.";
+      }
+      min_dims[i] = std::min(shape1->shape()[i], shape2->min_shape()[i]);
+      max_dims[i] = std::max(shape1->shape()[i], shape2->max_shape()[i]);
+      continue;
+    }
+    // both shapes contain dynamic dimensions
+    if (shape1->min_shape().empty() || shape1->max_shape().empty()) {
+      MS_EXCEPTION(ValueError) << "Shape " << shape1->ToString()
+                               << " has dynamic shape, but does not have min/max shape info.";
+    }
+    if (shape2->min_shape().empty() || shape2->max_shape().empty()) {
+      MS_EXCEPTION(ValueError) << "Shape " << shape2->ToString()
+                               << " has dynamic shape, but does not have min/max shape info.";
+    }
+    min_dims[i] = std::min(shape1->min_shape()[i], shape2->min_shape()[i]);
+    max_dims[i] = std::max(shape1->max_shape()[i], shape2->max_shape()[i]);
+  }
+  return std::make_shared<Shape>(dims, min_dims, max_dims);
 }

 AbstractBasePtr AbstractJoin(const AbstractBasePtrList &args_spec_list) {
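A minimal Python sketch of the per-dimension join rule above, assuming -1 stands in for Shape::SHP_ANY (function and variable names are made up):

SHP_ANY = -1

def join_dim(d1, lo1, hi1, d2, lo2, hi2):
    """Return (dim, min, max) for one joined dimension of two shapes."""
    if d1 == d2 and d1 != SHP_ANY:                    # equal static dims
        return d1, d1, d1
    if d1 != SHP_ANY and d2 != SHP_ANY:               # differing static dims
        return SHP_ANY, min(d1, d2), max(d1, d2)
    if d1 == SHP_ANY and d2 != SHP_ANY:               # only shape1 is dynamic
        return SHP_ANY, min(lo1, d2), max(hi1, d2)
    if d1 != SHP_ANY and d2 == SHP_ANY:               # only shape2 is dynamic
        return SHP_ANY, min(d1, lo2), max(d1, hi2)
    return SHP_ANY, min(lo1, lo2), max(hi1, hi2)      # both are dynamic

# Joining the dynamic dim (-1, min 1, max 8) with the static dim 5 keeps the wider bounds:
print(join_dim(-1, 1, 8, 5, 5, 5))  # (-1, 1, 8)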
@@ -807,3 +807,23 @@ def get_bprop_trans_shape(self):
         dx = op(dout, shape_op(x))
         return (dx, zeros_like(shape))
     return bprop
+
+
+@bprop_getters.register(P.Unique)
+def get_bprop_unique(self):
+    """Generate bprop for Unique"""
+    op = G.UniqueGrad()
+
+    def bprop(x, out, dout):
+        dx = op(dout, out)
+        return (dx,)
+
+    return bprop
+
+
+@bprop_getters.register(P.UnsortedSegmentSum)
+def get_bprop_unsorted_segment_sum(self):
+    """Generate bprop for UnsortedSegmentSum"""
+    op = G.UnsortedSegmentSumGrad()
+
+    def bprop(x, segment_ids, num_segments, out, dout):
+        dx = op(dout, segment_ids)
+        return (dx, zeros_like(segment_ids), zeros_like(num_segments))
+
+    return bprop
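A hedged numpy reference for what the two new bprops compute (illustrative only; the real work is done by the UniqueGrad and UnsortedSegmentSumGrad kernels):

import numpy as np

def unique_grad_ref(dout, out):
    """dx[i] = dy[idx[i]]: route the gradient of each unique value back to every input position."""
    dy, _ = dout     # gradient w.r.t. (y, idx); idx carries no gradient
    _, idx = out     # idx maps each input element to its slot in y
    return dy[idx]

def unsorted_segment_sum_grad_ref(dout, segment_ids):
    """dx[i] = dout[segment_ids[i]]: each input element receives its segment's gradient."""
    return dout[segment_ids]

x = np.array([1.0, 2.0, 5.0, 2.0])
y, idx = np.unique(x, return_inverse=True)
dy = np.array([0.1, 0.2, 0.3])
print(unique_grad_ref((dy, None), (y, idx)))  # [0.1 0.2 0.3 0.2]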
@@ -82,5 +82,8 @@ def get_concat_offset(x_shp, x_type, axis, prim_name):
             if j != axis and v[j] != x_shp[0][j]:
                 raise ValueError(f"For \'{prim_name}\' element {i} shape in input can not concat with first element")
         offset.append(all_shp)
-        all_shp += v[axis]
+        if all_shp == -1 or v[axis] == -1:
+            all_shp = -1
+        else:
+            all_shp += v[axis]
     return offset, all_shp, axis
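Worked example of the change, with hypothetical shapes: once any input has a dynamic (-1) length along the concat axis, the running total stays -1 instead of being accumulated.

shapes = [(2, 3), (-1, 3), (4, 3)]   # concat along axis 0; the second input is dynamic
axis = 0
offset, all_shp = [], 0
for v in shapes:
    offset.append(all_shp)
    if all_shp == -1 or v[axis] == -1:
        all_shp = -1
    else:
        all_shp += v[axis]
print(offset, all_shp)  # [0, 2, -1] -1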
@@ -32,7 +32,8 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
                         Squeeze, StridedSlice, Tile, TensorScatterUpdate,
                         Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentProd,
                         UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
-                        SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup)
+                        SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup,
+                        Unique)
 from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast,
                        _MirrorOperator, ReduceOp, _VirtualDataset,
                        _VirtualDiv, _GetTensorSlice,
@@ -491,6 +491,31 @@ class FusedBatchNormGrad(Primitive):
         raise NotImplementedError


+class UniqueGrad(Primitive):
+    """Gradients of Unique operation."""
+
+    @prim_attr_register
+    def __init__(self):
+        self.init_prim_io_names(inputs=['dy', 'y'], outputs=['dx'])
+
+    def __call__(self, dy, y):
+        raise NotImplementedError
+
+
+class UnsortedSegmentSumGrad(PrimitiveWithInfer):
+    """Gradients of UnsortedSegmentSum operation."""
+
+    @prim_attr_register
+    def __init__(self):
+        self.init_prim_io_names(inputs=['grads', 'ids'], outputs=['y'])
+
+    def infer_shape(self, grads, ids):
+        return ids + grads[len(ids):]
+
+    def infer_dtype(self, grads, ids):
+        return grads
+
+
 class BNTrainingReduceGrad(PrimitiveWithInfer):
     """Gradients of FusedBatchNorm operation."""
@@ -27,7 +27,7 @@ import numpy as np
 from .._utils import get_concat_offset
 from ..operations.math_ops import _infer_shape_reduce
-from ..primitive import PrimitiveWithInfer, prim_attr_register, _run_op
+from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register, _run_op
 from ..._c_expression import signature_dtype as sig_dtype
 from ..._c_expression import signature_kind as sig_kind
 from ..._c_expression import signature_rw as sig_rw
@@ -556,6 +556,28 @@ class Transpose(PrimitiveWithInfer):
         return out


+class Unique(Primitive):
+    """
+    Returns the unique elements of the input tensor, together with a tensor containing the index of each value
+    of the input tensor in the output unique tensor.
+
+    Inputs:
+        - **x** (Tensor) - The input tensor.
+
+    Outputs:
+        Tuple of tensor objects `(y, idx)`. `y` is a tensor with the same type as `x`, and `idx` is a tensor
+        containing the index of each element of the input corresponding to the output tensor.
+
+    Examples:
+        >>> x = Tensor(np.array([1, 2, 5, 2]), mindspore.float32)
+        >>> out = P.Unique()(x)
+        (Tensor([1, 2, 5], mindspore.float32), Tensor([0, 1, 2, 1], mindspore.int32))
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        self.init_prim_io_names(inputs=['x'], outputs=['output'])
+
+
 class GatherV2(PrimitiveWithInfer):
     """
     Returns a slice of input tensor based on the specified indices and axis.
@@ -20,6 +20,7 @@ import copy
 from mindspore.common.api import _wrap_func
 from mindspore.common import Parameter
 from mindspore.common._register_for_tensor import tensor_operator_registry
+from mindspore import context
 from .._c_expression import Primitive_, real_run_op, prim_type
 from .._c_expression import signature_rw as sig_rw
 from .._c_expression import signature_kind as sig_kind
@@ -138,6 +139,8 @@ class Primitive(Primitive_):
         return self

     def __getattr__(self, item):
+        if item == 'infer_dynamic_shape':
+            return None
         if item in super().get_attr_dict():
             return super().get_attr_dict()[item]
         if item in self.attrs:
@@ -282,13 +285,49 @@ class PrimitiveWithInfer(Primitive):

     def __infer__(self, *args):
         """Infer shape, type, and value at the same time by using dictionary as arguments."""
+        is_graph_mode = context.get_context("mode") == context.GRAPH_MODE
+        fn_infer_dynamic_shape = getattr(self, 'infer_dynamic_shape', None)
+        if is_graph_mode and fn_infer_dynamic_shape is not None:
+            out = fn_infer_dynamic_shape(*args)
+            tracks = ['dtype', 'value']
+            for track in tracks:
+                fn = getattr(self, 'infer_' + track)
+                # fn may return None
+                out[track] = fn(*(x[track] for x in args))
+            return out
+
         tracks = ['dtype', 'shape', 'value']
         out = {}
         for track in tracks:
             fn = getattr(self, 'infer_' + track)
             # fn may return None
             out[track] = fn(*(x[track] for x in args))
-        return out
+
+        # in non-graph mode, it is not necessary to infer min/max shape
+        if not is_graph_mode:
+            return out
+
+        def get_specified_shape(elems, attr):
+            has_specified_shape = False
+            ret_vals = []
+            for elem in elems:
+                if attr in elem:
+                    has_specified_shape = True
+                    ret_vals.append(elem[attr])
+                else:
+                    ret_vals.append(elem['shape'])
+            return has_specified_shape, tuple(ret_vals)
+
+        has_min_shape, min_shapes = get_specified_shape(args, 'min_shape')
+        has_max_shape, max_shapes = get_specified_shape(args, 'max_shape')
+        if not (has_min_shape or has_max_shape):
+            return out
+        if has_min_shape and has_max_shape:
+            fn_infer_shape = getattr(self, 'infer_shape')
+            out['min_shape'] = fn_infer_shape(*min_shapes)
+            out['max_shape'] = fn_infer_shape(*max_shapes)
+            return out
+        raise ValueError(f'Input args have invalid dynamic shape, args info: {args}')


 def prim_attr_register(fn):
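A minimal sketch of how a primitive picks up the new mechanism, assuming a hypothetical element-wise operator (only the propagation rule comes from the diff): when every argument dict carries both 'min_shape' and 'max_shape', __infer__ simply reruns infer_shape on those bounds.

from mindspore.ops.primitive import PrimitiveWithInfer, prim_attr_register

class _ElementwiseAdd(PrimitiveWithInfer):
    """Hypothetical op used only to illustrate min/max shape propagation."""

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])

    def infer_shape(self, x_shape, y_shape):
        return x_shape           # element-wise: output follows the first input's shape

    def infer_dtype(self, x_dtype, y_dtype):
        return x_dtype

    def infer_value(self, x, y):
        return None

# With args like [{'shape': [-1, 3], 'min_shape': [1, 3], 'max_shape': [8, 3], ...}, ...],
# __infer__ fills out['min_shape'] = infer_shape([1, 3], [1, 3]) and
# out['max_shape'] = infer_shape([8, 3], [8, 3]), so the bounds flow through unchanged.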