Merge pull request !3555 from fary86/support_dynamic_min_max_shape (tags/v0.7.0-beta)
@@ -113,6 +113,8 @@ inline const PrimitivePtr KPrimTransData = std::make_shared<Primitive>("TransDat
 inline const PrimitivePtr kPrimNMSWithMask = std::make_shared<Primitive>("NMSWithMask");
 inline const PrimitivePtr kPrimPad = std::make_shared<Primitive>("Pad");
 inline const PrimitivePtr kPrimArgMaxWithValue = std::make_shared<Primitive>("ArgMaxWithValue");
+inline const PrimitivePtr kPrimUnique = std::make_shared<Primitive>("Unique");
+inline const PrimitivePtr kPrimUniqueGrad = std::make_shared<Primitive>("UniqueGrad");
 // NN
 inline const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten");
@@ -148,5 +148,47 @@ AbstractBasePtr InferImplPack(const AnalysisEnginePtr &, const PrimitivePtr &pri
   ret->set_shape(std::make_shared<Shape>(shape));
   return ret;
 }
+
+AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                const AbstractBasePtrList &args_spec_list) {
+  // inputs: a 1-d Tensor
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 1);
+  AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+
+  auto shape = input->shape();
+  if (shape->shape().size() != 1) {
+    MS_LOG(EXCEPTION) << "Rank of " << op_name << "'s input must be 1.";
+  }
+  std::vector<int> ids_shape = {Shape::SHP_ANY};
+  std::vector<int> min_shape = {1};
+  std::vector<int> max_shape = shape->shape();
+  auto ids =
+    std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(ids_shape, min_shape, max_shape));
+  auto ids_idx = std::make_shared<AbstractTensor>(std::make_shared<Int>(32), shape->shape());
+  // outputs: ids, ids_idx
+  AbstractBasePtrList elements = {ids, ids_idx};
+  return std::make_shared<AbstractTuple>(elements);
+}
+
+AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                    const AbstractBasePtrList &args_spec_list) {
+  // inputs: a 1-d Tensor
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 2);
+  AbstractTuplePtr dout = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
+  CheckArgsSize(op_name + " dout", dout->elements(), 2);
+  auto ids = CheckArg<AbstractTensor>(op_name, dout->elements(), 0);
+  auto ids_idx = CheckArg<AbstractTensor>(op_name, dout->elements(), 1);
+  if (ids->shape()->shape().size() != 1) {
+    MS_LOG(EXCEPTION) << "Dims of dout[0] of " << op_name << "'s input must be 1.";
+  }
+  if (ids_idx->shape()->shape().size() != 1) {
+    MS_LOG(EXCEPTION) << "Dims of dout[1] of " << op_name << "'s input must be 1.";
+  }
+  // outputs: dx
+  return std::make_shared<AbstractTensor>(ids->element(), ids_idx->shape());
+}
+
 }  // namespace abstract
 }  // namespace mindspore
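
For context on the `InferImplUnique` hunk above: the first output (`ids`) gets a dynamic shape (`Shape::SHP_ANY`, i.e. -1) bounded below by 1 and above by the input length, while the index output keeps the static input shape. A minimal illustration in plain Python (not part of the diff; the `-1` convention mirrors `Shape::SHP_ANY`):

```python
# Hypothetical illustration of the abstract shapes inferred for Unique on a 1-D input of length 4.
SHP_ANY = -1  # mirrors Shape::SHP_ANY in the C++ code

input_shape = [4]

ids_shape = [SHP_ANY]        # number of unique values is unknown at compile time
ids_min_shape = [1]          # at least one unique value
ids_max_shape = input_shape  # at most as many unique values as input elements

idx_shape = input_shape      # one int32 index per input element

print(ids_shape, ids_min_shape, ids_max_shape, idx_shape)
# [-1] [1] [4] [4]
```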
@@ -23,6 +23,7 @@
 #include <mutex>
 #include <string>
 #include <utility>
+#include <unordered_set>
 #include "frontend/operator/cc_implementations.h"
 #include "frontend/operator/ops.h"
@@ -62,6 +63,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}},
     {prim::kPrimBroadcastShape, {InferImplBroadCastShape, true}},
     {prim::kPrimPack, {InferImplPack, true}},
+    {prim::kPrimUnique, {InferImplUnique, true}},
+    {prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}},
     // Structure
     {prim::kPrimMakeTuple, {InferImplMakeTuple, true}},
     {prim::kPrimMakeList, {InferImplMakeList, true}},
@@ -389,6 +392,14 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
   if (abs_base->isa<AbstractTensor>()) {
     auto arg_tensor = dyn_cast<AbstractTensor>(abs_base);
     dic["shape"] = arg_tensor->shape()->shape();
+    if (MsContext::GetInstance()->execution_mode() == kGraphMode) {
+      const auto &min_shape = arg_tensor->shape()->min_shape();
+      const auto &max_shape = arg_tensor->shape()->max_shape();
+      if (!min_shape.empty() && !max_shape.empty()) {
+        dic["min_shape"] = min_shape;
+        dic["max_shape"] = max_shape;
+      }
+    }
     dic["dtype"] = arg_tensor->BuildType();
     dic["value"] = BuildValue(arg_tensor->BuildValue());
   } else if (abs_base->isa<AbstractIndexedSlices>()) {
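
A sketch of the dictionary that `ConvertAbstractToPython` hands to a Python infer function after this change, assuming graph mode and a dynamic first dimension; the keys come from the diff, the concrete values below are illustrative only:

```python
# Illustrative only: the abstract-tensor dict seen by a Python __infer__ in graph mode.
arg = {
    'shape': [-1, 32],     # -1 marks a dynamic dimension (Shape::SHP_ANY)
    'min_shape': [1, 32],  # present only when both min and max shapes are known
    'max_shape': [128, 32],
    'dtype': 'float32',    # stands in for the real mindspore type object
    'value': None,         # value is unknown for non-constant tensors
}
```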
@@ -503,7 +514,10 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic
   if (output["value"].is_none()) {
     auto out_shape = output["shape"];
     auto out_dtype = output["dtype"];
-    return PyListDtype2AbstractTensor(out_shape, out_dtype);
+    py::object min_shape = output.contains("min_shape") ? (py::object)output["min_shape"] : (py::object)py::none();
+    py::object max_shape = output.contains("max_shape") ? (py::object)output["max_shape"] : (py::object)py::none();
+    return PyListDtype2AbstractTensor(out_shape, out_dtype, min_shape, max_shape);
   }
   // Convert pyobject to Value, then to AbstractValue
   ValuePtr converted_ret = nullptr;
@@ -244,6 +244,10 @@ AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const Primiti
                                         const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplPack(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                    const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplMakeTuple(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                    const AbstractBasePtrList &args_spec_list);
@@ -371,7 +371,8 @@ py::object VectorRefToPyData(const VectorRef &value_list) {
   return ret;
 }
 
-AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py::object &type_obj) {
+AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py::object &type_obj,
+                                           const py::object &min_shape, const py::object &max_shape) {
   if ((py::isinstance<py::list>(shape_obj) || py::isinstance<py::tuple>(shape_obj)) && py::isinstance<Type>(type_obj)) {
     auto ret_vec = shape_obj.cast<std::vector<int>>();
     auto ret_dtype = type_obj.cast<TypePtr>();
@@ -382,12 +383,23 @@ AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py
       return abs_scalar;
     }
     AbstractBasePtr tensor = nullptr;
+    std::vector<int> min_shape_vec;
+    std::vector<int> max_shape_vec;
+    if (!min_shape.is_none()) {
+      min_shape_vec = min_shape.cast<std::vector<int>>();
+    }
+    if (!max_shape.is_none()) {
+      max_shape_vec = max_shape.cast<std::vector<int>>();
+    }
+    auto ret_shape = std::make_shared<abstract::Shape>(ret_vec, min_shape_vec, max_shape_vec);
     if (ret_dtype->isa<TensorType>()) {
       auto tensor_type = type_obj.cast<TensorTypePtr>();
       MS_EXCEPTION_IF_NULL(tensor_type);
-      tensor = std::make_shared<abstract::AbstractTensor>(tensor_type->element(), ret_vec);
+      auto element = std::make_shared<abstract::AbstractScalar>(kAnyValue, tensor_type->element());
+      tensor = std::make_shared<abstract::AbstractTensor>(element, ret_shape);
     } else {
-      tensor = std::make_shared<abstract::AbstractTensor>(ret_dtype, ret_vec);
+      auto element = std::make_shared<abstract::AbstractScalar>(kAnyValue, ret_dtype);
+      tensor = std::make_shared<abstract::AbstractTensor>(element, ret_shape);
     }
     return tensor;
   } else if (py::isinstance<py::tuple>(shape_obj) && py::isinstance<py::tuple>(type_obj)) {
@@ -47,7 +47,9 @@ bool BaseRefToInt(const ValuePtr &v, int *value);
 bool ValueToBool(const ValuePtr &in, bool *out);
 py::object ValuePtrToPyData(const ValuePtr &value);
-AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py::object &type_obj);
+AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py::object &type_obj,
+                                           const py::object &min_shape = py::none(),
+                                           const py::object &max_shape = py::none());
 bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple &args,
                                        const std::shared_ptr<py::object> &ret_val);
@@ -67,6 +67,9 @@ std::string Shape::DumpText() const {
   buffer << "[";
   for (size_t i = 0; i < shape_.size(); i++) {
     buffer << (i > 0 ? ", " : "") << shape_[i];
+    if (shape_[i] == SHP_ANY && min_shape_.size() == shape_.size() && max_shape_.size() == shape_.size()) {
+      buffer << "_" << min_shape_[i] << "^" << max_shape_[i];
+    }
   }
   buffer << "]";
   return buffer.str();
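
With this change a dynamic dimension is dumped as `-1_<min>^<max>`. A rough Python sketch of the same formatting rule (illustrative only, not the real implementation):

```python
def dump_shape(shape, min_shape, max_shape, shp_any=-1):
    """Format a shape the way Shape::DumpText does after this patch."""
    parts = []
    for i, dim in enumerate(shape):
        text = str(dim)
        if dim == shp_any and len(min_shape) == len(shape) and len(max_shape) == len(shape):
            text += "_%d^%d" % (min_shape[i], max_shape[i])
        parts.append(text)
    return "[" + ", ".join(parts) + "]"

print(dump_shape([-1, 32], [1, 32], [128, 32]))  # [-1_1^128, 32]
```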
@@ -74,16 +74,22 @@ class Shape : public BaseShape {
     (void)std::transform(list.begin(), list.end(), std::back_inserter(shape_),
                          [](const int64_t &value) { return static_cast<int>(value); });
   }
+  Shape(const std::vector<int> &list, const std::vector<int> &min_shape, const std::vector<int> &max_shape)
+      : shape_(list), min_shape_(min_shape), max_shape_(max_shape) {}
   ~Shape() override = default;
   MS_DECLARE_PARENT(Shape, BaseShape)
   std::string ToString() const override;
   std::string DumpText() const override;
   bool operator==(const BaseShape &other) const override;
-  BaseShapePtr Clone() const override { return std::make_shared<Shape>(shape_); }
+  BaseShapePtr Clone() const override { return std::make_shared<Shape>(shape_, min_shape_, max_shape_); }
   void Broaden() override;
   std::vector<int> &shape() { return shape_; }
+  std::vector<int> &min_shape() { return min_shape_; }
+  std::vector<int> &max_shape() { return max_shape_; }
 
-  std::vector<int> shape_;  // use SHP_ANY to implement the any shape in python
+  std::vector<int> shape_;      // use SHP_ANY to implement the any shape in python
+  std::vector<int> min_shape_;  // records the minimum length of each dynamic dimension
+  std::vector<int> max_shape_;  // records the maximum length of each dynamic dimension
 };
 using ShapePtr = std::shared_ptr<Shape>;
 using ShapePtrList = std::vector<ShapePtr>;
@@ -55,15 +55,66 @@ ShapePtr ShapeJoin(const ShapePtr &shape1, const ShapePtr &shape2) {
     return shape1;
   }
   std::vector<int> dims;
+  bool has_dynamic_shape = false;
   dims.resize(shape1->shape().size());
   for (std::size_t i = 0; i < shape1->shape().size(); i++) {
     if (shape1->shape()[i] == shape2->shape()[i]) {
       dims[i] = shape1->shape()[i];
+      if (shape1->shape()[i] == Shape::SHP_ANY) {
+        has_dynamic_shape = true;
+      }
     } else {
       dims[i] = Shape::SHP_ANY;
+      has_dynamic_shape = true;
     }
   }
-  return std::make_shared<Shape>(dims);
+  if (!has_dynamic_shape) {
+    return std::make_shared<Shape>(dims);
+  }
+  // calculate dynamic shape
+  std::vector<int> min_dims(dims.size());
+  std::vector<int> max_dims(dims.size());
+  for (size_t i = 0; i < dims.size(); ++i) {
+    if (dims[i] != Shape::SHP_ANY) {
+      min_dims[i] = max_dims[i] = dims[i];
+      continue;
+    }
+    if (shape1->shape()[i] != Shape::SHP_ANY && shape2->shape()[i] != Shape::SHP_ANY) {
+      min_dims[i] = std::min(shape1->shape()[i], shape2->shape()[i]);
+      max_dims[i] = std::max(shape1->shape()[i], shape2->shape()[i]);
+      continue;
+    }
+    if (shape1->shape()[i] == Shape::SHP_ANY && shape2->shape()[i] != Shape::SHP_ANY) {
+      if (shape1->min_shape().empty() || shape1->max_shape().empty()) {
+        MS_EXCEPTION(ValueError) << "Shape " << shape1->ToString()
+                                 << " has dynamic shape, but does not have min/max shape info.";
+      }
+      min_dims[i] = std::min(shape1->min_shape()[i], shape2->shape()[i]);
+      max_dims[i] = std::max(shape1->max_shape()[i], shape2->shape()[i]);
+      continue;
+    }
+    if (shape1->shape()[i] != Shape::SHP_ANY && shape2->shape()[i] == Shape::SHP_ANY) {
+      if (shape2->min_shape().empty() || shape2->max_shape().empty()) {
+        MS_EXCEPTION(ValueError) << "Shape " << shape2->ToString()
+                                 << " has dynamic shape, but does not have min/max shape info.";
+      }
+      min_dims[i] = std::min(shape1->shape()[i], shape2->min_shape()[i]);
+      max_dims[i] = std::max(shape1->shape()[i], shape2->max_shape()[i]);
+      continue;
+    }
+    // both shapes contain dynamic dimensions
+    if (shape1->min_shape().empty() || shape1->max_shape().empty()) {
+      MS_EXCEPTION(ValueError) << "Shape " << shape1->ToString()
+                               << " has dynamic shape, but does not have min/max shape info.";
+    }
+    if (shape2->min_shape().empty() || shape2->max_shape().empty()) {
+      MS_EXCEPTION(ValueError) << "Shape " << shape2->ToString()
+                               << " has dynamic shape, but does not have min/max shape info.";
+    }
+    min_dims[i] = std::min(shape1->min_shape()[i], shape2->min_shape()[i]);
+    max_dims[i] = std::max(shape1->max_shape()[i], shape2->max_shape()[i]);
+  }
+  return std::make_shared<Shape>(dims, min_dims, max_dims);
 }
 
 AbstractBasePtr AbstractJoin(const AbstractBasePtrList &args_spec_list) {
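
The join rule above can be summarized as: equal static dimensions stay as-is; differing or already-dynamic dimensions become `SHP_ANY`, with the joined min/max taken as the loosest bounds of the two operands. A standalone Python sketch of that rule (the function name and the `-1` convention are assumptions for illustration, not the C++ API):

```python
SHP_ANY = -1

def shape_join(s1, s2, s1_min=None, s1_max=None, s2_min=None, s2_max=None):
    """Pure-Python sketch of the ShapeJoin rule for two same-rank shapes."""
    dims, min_dims, max_dims = [], [], []
    for i, (a, b) in enumerate(zip(s1, s2)):
        if a == b and a != SHP_ANY:
            dims.append(a); min_dims.append(a); max_dims.append(a)
            continue
        dims.append(SHP_ANY)
        lo_a, hi_a = (a, a) if a != SHP_ANY else (s1_min[i], s1_max[i])
        lo_b, hi_b = (b, b) if b != SHP_ANY else (s2_min[i], s2_max[i])
        min_dims.append(min(lo_a, lo_b))
        max_dims.append(max(hi_a, hi_b))
    return dims, min_dims, max_dims

# joining a static [4] with a dynamic [-1] bounded by [1, 4] gives [-1] with bounds [1, 4]
print(shape_join([4], [SHP_ANY], s2_min=[1], s2_max=[4]))  # ([-1], [1], [4])
```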
@@ -807,3 +807,23 @@ def get_bprop_trans_shape(self):
         dx = op(dout, shape_op(x))
         return (dx, zeros_like(shape))
     return bprop
+
+
+@bprop_getters.register(P.Unique)
+def get_bprop_unique(self):
+    """Generate bprop for Unique"""
+    op = G.UniqueGrad()
+
+    def bprop(x, out, dout):
+        dx = op(dout, out)
+        return (dx,)
+    return bprop
+
+
+@bprop_getters.register(P.UnsortedSegmentSum)
+def get_bprop_unsorted_segment_sum(self):
+    """Generate bprop for UnsortedSegmentSum"""
+    op = G.UnsortedSegmentSumGrad()
+
+    def bprop(x, segment_ids, num_segments, out, dout):
+        dx = op(dout, segment_ids)
+        return (dx, zeros_like(segment_ids), zeros_like(num_segments))
+    return bprop
@@ -82,5 +82,8 @@ def get_concat_offset(x_shp, x_type, axis, prim_name):
             if j != axis and v[j] != x_shp[0][j]:
                 raise ValueError(f"For \'{prim_name}\' element {i} shape in input can not concat with first element")
         offset.append(all_shp)
-        all_shp += v[axis]
+        if all_shp == -1 or v[axis] == -1:
+            all_shp = -1
+        else:
+            all_shp += v[axis]
     return offset, all_shp, axis
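
Once a dynamic (`-1`) size appears along the concat axis, the accumulated total can no longer be computed, so it is poisoned to `-1`. A quick illustration of that accumulation rule with plain lists (a sketch, not the real helper):

```python
def accumulate_concat_dim(axis_dims):
    """Sketch: fold the concat-axis sizes, propagating -1 (dynamic) once it appears."""
    all_shp = 0
    offsets = []
    for v in axis_dims:
        offsets.append(all_shp)
        if all_shp == -1 or v == -1:
            all_shp = -1
        else:
            all_shp += v
    return offsets, all_shp

print(accumulate_concat_dim([2, -1, 3]))  # ([0, 2, -1], -1)
```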
@@ -32,7 +32,8 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
                         Squeeze, StridedSlice, Tile, TensorScatterUpdate,
                         Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentProd,
                         UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
-                        SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup)
+                        SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup,
+                        Unique)
 from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast,
                        _MirrorOperator, ReduceOp, _VirtualDataset,
                        _VirtualDiv, _GetTensorSlice,
@@ -491,6 +491,31 @@ class FusedBatchNormGrad(Primitive):
         raise NotImplementedError
 
 
+class UniqueGrad(Primitive):
+    """Gradients of Unique operation."""
+
+    @prim_attr_register
+    def __init__(self):
+        self.init_prim_io_names(inputs=['dy', 'y'], outputs=['dx'])
+
+    def __call__(self, dy, x, scale, save_mean, save_inv_variance):
+        raise NotImplementedError
+
+
+class UnsortedSegmentSumGrad(PrimitiveWithInfer):
+    """Gradients of UnsortedSegmentSum operation."""
+
+    @prim_attr_register
+    def __init__(self):
+        self.init_prim_io_names(inputs=['grads', 'ids'], outputs=['y'])
+
+    def infer_shape(self, grads, ids):
+        return ids + grads[len(ids):]
+
+    def infer_dtype(self, grads, ids):
+        return grads
+
+
 class BNTrainingReduceGrad(PrimitiveWithInfer):
     """Gradients of FusedBatchNorm operation."""
@@ -27,7 +27,7 @@ import numpy as np
 from .._utils import get_concat_offset
 from ..operations.math_ops import _infer_shape_reduce
-from ..primitive import PrimitiveWithInfer, prim_attr_register, _run_op
+from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register, _run_op
 from ..._c_expression import signature_dtype as sig_dtype
 from ..._c_expression import signature_kind as sig_kind
 from ..._c_expression import signature_rw as sig_rw
@@ -556,6 +556,28 @@ class Transpose(PrimitiveWithInfer):
         return out
 
 
+class Unique(Primitive):
+    """
+    Returns the unique elements of the input tensor, and also returns a tensor containing the index of each value
+    of the input tensor in the output unique tensor.
+
+    Inputs:
+        - **x** (Tensor) - The input tensor.
+
+    Outputs:
+        Tuple of tensor objects `(y, idx)`. `y` is a tensor with the same type as `x`, and `idx` is a tensor
+        containing the index of each element of the input in the output tensor.
+
+    Examples:
+        >>> x = Tensor(np.array([1, 2, 5, 2]), mindspore.float32)
+        >>> out = P.Unique()(x)
+        (Tensor([1, 2, 5], mindspore.float32), Tensor([0, 1, 2, 1], mindspore.int32))
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        self.init_prim_io_names(inputs=['x'], outputs=['output'])
+
+
 class GatherV2(PrimitiveWithInfer):
     """
     Returns a slice of input tensor based on the specified indices and axis.
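
Because the size of `Unique`'s first output is only known at run time, the primitive is registered with a C++ infer (`InferImplUnique`) rather than a Python `infer_shape`, and is intended for graph mode. A hedged usage sketch (the cell name is made up for illustration, and a backend with Unique support is assumed):

```python
import numpy as np
import mindspore
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE)

class UniqueNet(nn.Cell):
    """Hypothetical cell wrapping the new Unique primitive."""
    def __init__(self):
        super(UniqueNet, self).__init__()
        self.unique = P.Unique()

    def construct(self, x):
        y, idx = self.unique(x)
        return y, idx

x = Tensor(np.array([1, 2, 5, 2]), mindspore.float32)
y, idx = UniqueNet()(x)  # expected: y = [1, 2, 5], idx = [0, 1, 2, 1]
```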
@@ -20,6 +20,7 @@ import copy
 from mindspore.common.api import _wrap_func
 from mindspore.common import Parameter
 from mindspore.common._register_for_tensor import tensor_operator_registry
+from mindspore import context
 from .._c_expression import Primitive_, real_run_op, prim_type
 from .._c_expression import signature_rw as sig_rw
 from .._c_expression import signature_kind as sig_kind
@@ -138,6 +139,8 @@ class Primitive(Primitive_):
         return self
 
     def __getattr__(self, item):
+        if item == 'infer_dynamic_shape':
+            return None
         if item in super().get_attr_dict():
             return super().get_attr_dict()[item]
         if item in self.attrs:
@@ -282,13 +285,49 @@ class PrimitiveWithInfer(Primitive):
 
     def __infer__(self, *args):
         """Infer shape, type, and value at the same time by using dictionary as arguments."""
+        is_graph_mode = context.get_context("mode") == context.GRAPH_MODE
+        fn_infer_dynamic_shape = getattr(self, 'infer_dynamic_shape', None)
+        if is_graph_mode and fn_infer_dynamic_shape is not None:
+            out = fn_infer_dynamic_shape(*args)
+            tracks = ['dtype', 'value']
+            for track in tracks:
+                fn = getattr(self, 'infer_' + track)
+                # fn may return None
+                out[track] = fn(*(x[track] for x in args))
+            return out
+
         tracks = ['dtype', 'shape', 'value']
         out = {}
         for track in tracks:
             fn = getattr(self, 'infer_' + track)
             # fn may return None
             out[track] = fn(*(x[track] for x in args))
-        return out
+
+        # in non-graph_mode, it is not necessary to infer min/max shape
+        if not is_graph_mode:
+            return out
+
+        def get_specified_shape(elems, attr):
+            has_specified_shape = False
+            ret_vals = []
+            for elem in elems:
+                if attr in elem:
+                    has_specified_shape = True
+                    ret_vals.append(elem[attr])
+                else:
+                    ret_vals.append(elem['shape'])
+            return has_specified_shape, tuple(ret_vals)
+
+        has_min_shape, min_shapes = get_specified_shape(args, 'min_shape')
+        has_max_shape, max_shapes = get_specified_shape(args, 'max_shape')
+        if not (has_min_shape or has_max_shape):
+            return out
+        if has_min_shape and has_max_shape:
+            fn_infer_shape = getattr(self, 'infer_shape')
+            out['min_shape'] = fn_infer_shape(*min_shapes)
+            out['max_shape'] = fn_infer_shape(*max_shapes)
+            return out
+        raise ValueError(f'Input args has invalid dynamic shape, args info: {args}')
 
 
 def prim_attr_register(fn):
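
Putting the Python side together: in graph mode, when the input dicts carry `min_shape` and `max_shape`, `__infer__` simply re-runs the primitive's own `infer_shape` on the min shapes and again on the max shapes. A stripped-down sketch of that flow (function and variable names are illustrative, not real MindSpore API beyond the `infer_shape` idea):

```python
def infer_with_min_max(infer_shape, args):
    """Sketch of the min/max handling added to PrimitiveWithInfer.__infer__ (graph mode only)."""
    out = {'shape': infer_shape(*(a['shape'] for a in args))}
    min_shapes = tuple(a.get('min_shape', a['shape']) for a in args)
    max_shapes = tuple(a.get('max_shape', a['shape']) for a in args)
    if any('min_shape' in a for a in args) and any('max_shape' in a for a in args):
        out['min_shape'] = infer_shape(*min_shapes)
        out['max_shape'] = infer_shape(*max_shapes)
    return out

def concat_shape(a, b):
    """Toy 1-D concat shape rule: a dynamic (-1) dimension propagates."""
    return [-1] if -1 in (a[0], b[0]) else [a[0] + b[0]]

args = [{'shape': [-1], 'min_shape': [1], 'max_shape': [4]}, {'shape': [3]}]
print(infer_with_min_max(concat_shape, args))
# {'shape': [-1], 'min_shape': [4], 'max_shape': [7]}
```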