| @@ -495,6 +495,8 @@ TypePtr StringToType(const std::string &type_name) { | |||||
| TypePtr type = nullptr; | TypePtr type = nullptr; | ||||
| if (type_name.compare("None") == 0) { | if (type_name.compare("None") == 0) { | ||||
| type = std::make_shared<TypeNone>(); | type = std::make_shared<TypeNone>(); | ||||
| } else if (type_name.compare("Ellipsis") == 0) { | |||||
| type = std::make_shared<Ellipsis>(); | |||||
| } else if (type_name.compare("TypeType") == 0) { | } else if (type_name.compare("TypeType") == 0) { | ||||
| type = std::make_shared<TypeType>(); | type = std::make_shared<TypeType>(); | ||||
| } else if (type_name.compare("SymbolicKeyType") == 0) { | } else if (type_name.compare("SymbolicKeyType") == 0) { | ||||
| @@ -18,6 +18,5 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| const TypePtr kTypeNone = std::make_shared<TypeNone>(); | const TypePtr kTypeNone = std::make_shared<TypeNone>(); | ||||
| const TypePtr kTypeAnything = std::make_shared<TypeAnything>(); | |||||
| const TypePtr kAnyType = std::make_shared<TypeAnything>(); | const TypePtr kAnyType = std::make_shared<TypeAnything>(); | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -71,8 +71,20 @@ class TypeNull : public Type { | |||||
| }; | }; | ||||
| using TypeNullPtr = std::shared_ptr<TypeNull>; | using TypeNullPtr = std::shared_ptr<TypeNull>; | ||||
| class Ellipsis : public Type { | |||||
| public: | |||||
| Ellipsis() : Type(kMetaTypeEllipsis) {} | |||||
| ~Ellipsis() override {} | |||||
| MS_DECLARE_PARENT(Ellipsis, Type) | |||||
| TypeId generic_type_id() const override { return kMetaTypeEllipsis; } | |||||
| TypePtr DeepCopy() const override { return std::make_shared<Ellipsis>(); } | |||||
| std::string ToReprString() const override { return "Ellipsis"; } | |||||
| std::string DumpText() const override { return "Ellipsis"; } | |||||
| }; | |||||
| using EllipsisPtr = std::shared_ptr<Ellipsis>; | |||||
| extern const TypePtr kTypeNone; | extern const TypePtr kTypeNone; | ||||
| extern const TypePtr kTypeAnything; | |||||
| extern const TypePtr kAnyType; | extern const TypePtr kAnyType; | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -49,6 +49,7 @@ enum TypeId : int { | |||||
| kMetaTypeExternal, | kMetaTypeExternal, | ||||
| kMetaTypeNone, | kMetaTypeNone, | ||||
| kMetaTypeNull, | kMetaTypeNull, | ||||
| kMetaTypeEllipsis, | |||||
| kMetaTypeEnd, | kMetaTypeEnd, | ||||
| // | // | ||||
| // Object types | // Object types | ||||
| @@ -31,5 +31,8 @@ abstract::AbstractBasePtr None::ToAbstract() { return std::make_shared<abstract: | |||||
| const NamedPtr kNone = std::make_shared<None>(); | const NamedPtr kNone = std::make_shared<None>(); | ||||
| abstract::AbstractBasePtr NullObj::ToAbstract() { return std::make_shared<abstract::AbstractNull>(); } | abstract::AbstractBasePtr NullObj::ToAbstract() { return std::make_shared<abstract::AbstractNull>(); } | ||||
| const NamedPtr kNullObj = std::make_shared<NullObj>(); | |||||
| const NamedPtr kNull = std::make_shared<NullObj>(); | |||||
| abstract::AbstractBasePtr EllipsisObj::ToAbstract() { return std::make_shared<abstract::AbstractEllipsis>(); } | |||||
| const NamedPtr kEllipsis = std::make_shared<EllipsisObj>(); | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -61,7 +61,6 @@ class Named : public Value { | |||||
| std::string name_; | std::string name_; | ||||
| std::size_t hash_id_; | std::size_t hash_id_; | ||||
| }; | }; | ||||
| using NamedPtr = std::shared_ptr<Named>; | using NamedPtr = std::shared_ptr<Named>; | ||||
| class None : public Named { | class None : public Named { | ||||
| @@ -71,7 +70,6 @@ class None : public Named { | |||||
| MS_DECLARE_PARENT(None, Named); | MS_DECLARE_PARENT(None, Named); | ||||
| abstract::AbstractBasePtr ToAbstract() override; | abstract::AbstractBasePtr ToAbstract() override; | ||||
| }; | }; | ||||
| extern const NamedPtr kNone; | extern const NamedPtr kNone; | ||||
| class NullObj : public Named { | class NullObj : public Named { | ||||
| @@ -81,7 +79,15 @@ class NullObj : public Named { | |||||
| MS_DECLARE_PARENT(NullObj, Named); | MS_DECLARE_PARENT(NullObj, Named); | ||||
| abstract::AbstractBasePtr ToAbstract() override; | abstract::AbstractBasePtr ToAbstract() override; | ||||
| }; | }; | ||||
| extern const NamedPtr kNull; | |||||
| extern const NamedPtr kNullObj; | |||||
| class EllipsisObj : public Named { | |||||
| public: | |||||
| EllipsisObj() : Named("Ellipsis") {} | |||||
| ~EllipsisObj() override = default; | |||||
| MS_DECLARE_PARENT(EllipsisObj, Named); | |||||
| abstract::AbstractBasePtr ToAbstract() override; | |||||
| }; | |||||
| extern const NamedPtr kEllipsis; | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_IR_NAMED_H_ | #endif // MINDSPORE_CCSRC_IR_NAMED_H_ | ||||
| @@ -135,9 +135,9 @@ T InnerScalarMod(T x, T y) { | |||||
| if (std::is_integral<T>::value) { | if (std::is_integral<T>::value) { | ||||
| return static_cast<int>(x) % static_cast<int>(y); | return static_cast<int>(x) % static_cast<int>(y); | ||||
| } | } | ||||
| float x_int = std::floor(x); | |||||
| float y_int = std::ceil(y); | |||||
| float max = x_int / y_int; | |||||
| int x_int = std::floor(x); | |||||
| int y_int = std::ceil(y); | |||||
| int max = x_int / y_int; | |||||
| float ret = x - y * max; | float ret = x - y * max; | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| @@ -46,6 +46,8 @@ using mindspore::abstract::AbstractBase; | |||||
| using mindspore::abstract::AbstractClass; | using mindspore::abstract::AbstractClass; | ||||
| using mindspore::abstract::AbstractDictionary; | using mindspore::abstract::AbstractDictionary; | ||||
| using mindspore::abstract::AbstractDictionaryPtr; | using mindspore::abstract::AbstractDictionaryPtr; | ||||
| using mindspore::abstract::AbstractEllipsis; | |||||
| using mindspore::abstract::AbstractEllipsisPtr; | |||||
| using mindspore::abstract::AbstractFunction; | using mindspore::abstract::AbstractFunction; | ||||
| using mindspore::abstract::AbstractFunctionPtr; | using mindspore::abstract::AbstractFunctionPtr; | ||||
| using mindspore::abstract::AbstractList; | using mindspore::abstract::AbstractList; | ||||
| @@ -1081,6 +1083,7 @@ int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr &slice_tuple, | |||||
| std::vector<unsigned int> shrink; | std::vector<unsigned int> shrink; | ||||
| auto slice_tuple_eles = slice_tuple->elements(); | auto slice_tuple_eles = slice_tuple->elements(); | ||||
| size_t ellipsis_num = 0; | |||||
| for (size_t index = 0; index < slice_tuple_size; index++) { | for (size_t index = 0; index < slice_tuple_size; index++) { | ||||
| if (slice_tuple_eles[index]->isa<AbstractSlice>()) { | if (slice_tuple_eles[index]->isa<AbstractSlice>()) { | ||||
| AbstractSlicePtr slice = dyn_cast<AbstractSlice>(slice_tuple_eles[index]); | AbstractSlicePtr slice = dyn_cast<AbstractSlice>(slice_tuple_eles[index]); | ||||
| @@ -1098,7 +1101,20 @@ int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr &slice_tuple, | |||||
| continue; | continue; | ||||
| } | } | ||||
| MS_LOG(EXCEPTION) << "Slice tuple only could contain slice or int number, but got " | |||||
| if (slice_tuple_eles[index]->isa<AbstractEllipsis>()) { | |||||
| ellipsis_num++; | |||||
| if (ellipsis_num > 1) { | |||||
| MS_LOG(EXCEPTION) << "Tensor slice supports at most one ellipsis"; | |||||
| } | |||||
| size_t ellipsis_len = shape_size - (slice_tuple_size - 1); | |||||
| begin->insert(begin->end(), ellipsis_len, 0); | |||||
| end->insert(end->end(), shape.begin() + index, shape.begin() + index + ellipsis_len); | |||||
| strides->insert(strides->end(), ellipsis_len, 1); | |||||
| shrink.insert(shrink.end(), ellipsis_len, 0); | |||||
| continue; | |||||
| } | |||||
| MS_LOG(EXCEPTION) << "Slice tuple only could contain slice, int number or ellipsis, but got " | |||||
| << slice_tuple_eles[index]->ToString(); | << slice_tuple_eles[index]->ToString(); | ||||
| } | } | ||||
| @@ -1160,6 +1176,11 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec | |||||
| abstract::CheckArgsSize(op_name, args_spec_list, 2); | abstract::CheckArgsSize(op_name, args_spec_list, 2); | ||||
| AbstractTensorPtr tensorPtr = abstract::CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | AbstractTensorPtr tensorPtr = abstract::CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | ||||
| FuncGraphPtr ret_graph = std::make_shared<FuncGraph>(); | |||||
| ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true); | |||||
| AnfNodePtr tensor_node = ret_graph->add_parameter(); | |||||
| (void)ret_graph->add_parameter(); | |||||
| auto shape = tensorPtr->shape()->shape(); | auto shape = tensorPtr->shape()->shape(); | ||||
| std::vector<int> begin; | std::vector<int> begin; | ||||
| std::vector<int> end; | std::vector<int> end; | ||||
| @@ -1174,23 +1195,28 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec | |||||
| shrink_axis_mask = GenerateStridedSliceParametersFromSlice(slice_ptr, shape, &begin, &end, &strides); | shrink_axis_mask = GenerateStridedSliceParametersFromSlice(slice_ptr, shape, &begin, &end, &strides); | ||||
| } else if (args_spec_list[1]->isa<AbstractScalar>()) { | } else if (args_spec_list[1]->isa<AbstractScalar>()) { | ||||
| AbstractScalarPtr scalar_ptr = dyn_cast<AbstractScalar>(args_spec_list[1]); | AbstractScalarPtr scalar_ptr = dyn_cast<AbstractScalar>(args_spec_list[1]); | ||||
| if (scalar_ptr->BuildValue()->isa<BoolImm>()) { | |||||
| if (scalar_ptr->BuildValue()->cast<BoolImmPtr>()->value()) { | |||||
| return ExpandADim(ret_graph, tensor_node); | |||||
| } | |||||
| } | |||||
| shrink_axis_mask = GenerateStridedSliceParametersFromNumber(scalar_ptr, shape, &begin, &end, &strides); | shrink_axis_mask = GenerateStridedSliceParametersFromNumber(scalar_ptr, shape, &begin, &end, &strides); | ||||
| } else if (args_spec_list[1]->isa<AbstractEllipsis>()) { | |||||
| ret_graph->set_output(tensor_node); | |||||
| return ret_graph; | |||||
| } else if (args_spec_list[1]->isa<AbstractNone>()) { | |||||
| return ExpandADim(ret_graph, tensor_node); | |||||
| } else { | } else { | ||||
| std::ostringstream args_info; | std::ostringstream args_info; | ||||
| for (const auto &arg : args_spec_list) { | for (const auto &arg : args_spec_list) { | ||||
| MS_EXCEPTION_IF_NULL(arg); | MS_EXCEPTION_IF_NULL(arg); | ||||
| args_info << arg->ToString() << "\n"; | args_info << arg->ToString() << "\n"; | ||||
| } | } | ||||
| MS_LOG(EXCEPTION) << "TensorSlice requires to input a tensor and a slice or slice tuple, but got " | |||||
| << args_info.str(); | |||||
| MS_LOG(EXCEPTION) | |||||
| << "TensorSlice requires the input should be one of [slice, ellipsis, int number, bool, none, tuple] , but got " | |||||
| << args_info.str(); | |||||
| } | } | ||||
| FuncGraphPtr ret_graph = std::make_shared<FuncGraph>(); | |||||
| ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true); | |||||
| AnfNodePtr tensor_node = ret_graph->add_parameter(); | |||||
| (void)ret_graph->add_parameter(); | |||||
| auto PrimStridedSliceClass = prim::GetPythonOps("StridedSlice", "mindspore.ops.operations"); | auto PrimStridedSliceClass = prim::GetPythonOps("StridedSlice", "mindspore.ops.operations"); | ||||
| auto PrimStridedSlice = ret_graph->NewCNode({NewValueNode(PrimStridedSliceClass), NewValueNode(0), NewValueNode(0), | auto PrimStridedSlice = ret_graph->NewCNode({NewValueNode(PrimStridedSliceClass), NewValueNode(0), NewValueNode(0), | ||||
| NewValueNode(0), NewValueNode(0), NewValueNode(shrink_axis_mask)}); | NewValueNode(0), NewValueNode(0), NewValueNode(shrink_axis_mask)}); | ||||
| @@ -1199,6 +1225,12 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec | |||||
| return ret_graph; | return ret_graph; | ||||
| } | } | ||||
| FuncGraphPtr TensorSlice::ExpandADim(const FuncGraphPtr &ret_graph, const AnfNodePtr &tensor_node) const { | |||||
| auto PrimExpandDims = GetPythonOps("expand_dims", "mindspore.ops.functional"); | |||||
| ret_graph->set_output(NewCNode({NewValueNode(PrimExpandDims), tensor_node, NewValueNode(0)}, ret_graph)); | |||||
| return ret_graph; | |||||
| } | |||||
| REGISTER_PYBIND_DEFINE(TupleAdd_, ([](const py::module *m) { | REGISTER_PYBIND_DEFINE(TupleAdd_, ([](const py::module *m) { | ||||
| (void)py::class_<TupleAdd, MetaFuncGraph, std::shared_ptr<TupleAdd>>(*m, "TupleAdd_") | (void)py::class_<TupleAdd, MetaFuncGraph, std::shared_ptr<TupleAdd>>(*m, "TupleAdd_") | ||||
| .def(py::init<std::string &>()); | .def(py::init<std::string &>()); | ||||
| @@ -206,6 +206,8 @@ class TensorSlice : public MetaFuncGraph { | |||||
| MS_DECLARE_PARENT(TensorSlice, MetaFuncGraph) | MS_DECLARE_PARENT(TensorSlice, MetaFuncGraph) | ||||
| FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; | FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; | ||||
| friend bool operator==(const TensorSlice &lhs, const TensorSlice &rhs) { return lhs.name_ == rhs.name_; } | friend bool operator==(const TensorSlice &lhs, const TensorSlice &rhs) { return lhs.name_ == rhs.name_; } | ||||
| FuncGraphPtr ExpandADim(const FuncGraphPtr &ret_graph, const AnfNodePtr &tensor_node) const; | |||||
| }; | }; | ||||
| using TensorSlicePtr = std::shared_ptr<TensorSlice>; | using TensorSlicePtr = std::shared_ptr<TensorSlice>; | ||||
| @@ -109,6 +109,7 @@ void Parser::BuildMethodMap() { | |||||
| expr_method_map_["Index"] = &Parser::ParseIndex; | expr_method_map_["Index"] = &Parser::ParseIndex; | ||||
| expr_method_map_["UnaryOp"] = &Parser::ParseUnaryOp; | expr_method_map_["UnaryOp"] = &Parser::ParseUnaryOp; | ||||
| expr_method_map_["Dict"] = &Parser::ParseDict; | expr_method_map_["Dict"] = &Parser::ParseDict; | ||||
| expr_method_map_["Ellipsis"] = &Parser::ParseEllipsis; | |||||
| } | } | ||||
| void Parser::UpdateTopFuncGraph(const FuncGraphPtr &func_graph) { top_func_graph_ = FuncGraphWeakPtr(func_graph); } | void Parser::UpdateTopFuncGraph(const FuncGraphPtr &func_graph) { top_func_graph_ = FuncGraphWeakPtr(func_graph); } | ||||
| @@ -187,7 +188,7 @@ void Parser::GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block, | |||||
| namelist_for_default_value.push_back(arg_name); | namelist_for_default_value.push_back(arg_name); | ||||
| if (py::isinstance<py::none>(defaults[i])) { | if (py::isinstance<py::none>(defaults[i])) { | ||||
| default_values.push_back(NewValueNode(kNullObj)); | |||||
| default_values.push_back(NewValueNode(kNull)); | |||||
| } else { | } else { | ||||
| default_values.push_back(ParseExprNode(block, defaults[i])); | default_values.push_back(ParseExprNode(block, defaults[i])); | ||||
| } | } | ||||
| @@ -437,6 +438,11 @@ AnfNodePtr Parser::ParseNone(const FunctionBlockPtr &, const py::object &) { | |||||
| return NewValueNode(kNone); | return NewValueNode(kNone); | ||||
| } | } | ||||
| AnfNodePtr Parser::ParseEllipsis(const FunctionBlockPtr &, const py::object &) { | |||||
| MS_LOG(DEBUG) << "Process ast Ellipsis"; | |||||
| return NewValueNode(kEllipsis); | |||||
| } | |||||
| AnfNodePtr Parser::ParseNum(const FunctionBlockPtr &, const py::object &node) { | AnfNodePtr Parser::ParseNum(const FunctionBlockPtr &, const py::object &node) { | ||||
| MS_LOG(DEBUG) << "Process ast Num"; | MS_LOG(DEBUG) << "Process ast Num"; | ||||
| py::object obj = python_adapter::GetPyObjAttr(node, "n"); | py::object obj = python_adapter::GetPyObjAttr(node, "n"); | ||||
| @@ -92,6 +92,8 @@ class Parser { | |||||
| AnfNodePtr ParseName(const FunctionBlockPtr &block, const py::object &node); | AnfNodePtr ParseName(const FunctionBlockPtr &block, const py::object &node); | ||||
| // process NoneType | // process NoneType | ||||
| AnfNodePtr ParseNone(const FunctionBlockPtr &block, const py::object &node); | AnfNodePtr ParseNone(const FunctionBlockPtr &block, const py::object &node); | ||||
| // process Ellipsis | |||||
| AnfNodePtr ParseEllipsis(const FunctionBlockPtr &block, const py::object &node); | |||||
| // process a integer or float number | // process a integer or float number | ||||
| AnfNodePtr ParseNum(const FunctionBlockPtr &block, const py::object &node); | AnfNodePtr ParseNum(const FunctionBlockPtr &block, const py::object &node); | ||||
| // process a string variable | // process a string variable | ||||
| @@ -892,10 +892,27 @@ bool AbstractNull::operator==(const AbstractBase &other) const { | |||||
| std::string AbstractNull::ToString() const { | std::string AbstractNull::ToString() const { | ||||
| std::ostringstream buffer; | std::ostringstream buffer; | ||||
| buffer << type_name() << "(" | |||||
| << "Value: " | |||||
| << "Null" | |||||
| << ")"; | |||||
| buffer << type_name() << "(Value: Null)"; | |||||
| return buffer.str(); | |||||
| } | |||||
| bool AbstractEllipsis::operator==(const AbstractEllipsis &) const { return true; } | |||||
| bool AbstractEllipsis::operator==(const AbstractBase &other) const { | |||||
| if (&other == this) { | |||||
| return true; | |||||
| } | |||||
| if (other.isa<AbstractEllipsis>()) { | |||||
| auto other_none = static_cast<const AbstractEllipsis *>(&other); | |||||
| return *this == *other_none; | |||||
| } else { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| std::string AbstractEllipsis::ToString() const { | |||||
| std::ostringstream buffer; | |||||
| buffer << type_name() << "(Value: Ellipsis)"; | |||||
| return buffer.str(); | return buffer.str(); | ||||
| } | } | ||||
| @@ -498,7 +498,7 @@ using AbstractNonePtr = std::shared_ptr<AbstractNone>; | |||||
| // the un assigned state value for variable, which means the variable is not assigned | // the un assigned state value for variable, which means the variable is not assigned | ||||
| class AbstractNull : public AbstractBase { | class AbstractNull : public AbstractBase { | ||||
| public: | public: | ||||
| AbstractNull() : AbstractBase(kNullObj) { set_type(std::make_shared<TypeNull>()); } | |||||
| AbstractNull() : AbstractBase(kNull) { set_type(std::make_shared<TypeNull>()); } | |||||
| ~AbstractNull() override = default; | ~AbstractNull() override = default; | ||||
| MS_DECLARE_PARENT(AbstractNull, AbstractBase) | MS_DECLARE_PARENT(AbstractNull, AbstractBase) | ||||
| @@ -510,6 +510,20 @@ class AbstractNull : public AbstractBase { | |||||
| }; | }; | ||||
| using AbstractNullPtr = std::shared_ptr<AbstractNull>; | using AbstractNullPtr = std::shared_ptr<AbstractNull>; | ||||
| class AbstractEllipsis : public AbstractBase { | |||||
| public: | |||||
| AbstractEllipsis() : AbstractBase(kEllipsis) { set_type(std::make_shared<Ellipsis>()); } | |||||
| ~AbstractEllipsis() override = default; | |||||
| MS_DECLARE_PARENT(AbstractEllipsis, AbstractBase) | |||||
| TypePtr BuildType() const override { return std::make_shared<Ellipsis>(); } | |||||
| bool operator==(const AbstractEllipsis &other) const; | |||||
| bool operator==(const AbstractBase &other) const override; | |||||
| AbstractBasePtr Clone() const override { return std::make_shared<AbstractEllipsis>(); } | |||||
| std::string ToString() const override; | |||||
| }; | |||||
| using AbstractEllipsisPtr = std::shared_ptr<AbstractEllipsis>; | |||||
| class AbstractRefKey : public AbstractBase { | class AbstractRefKey : public AbstractBase { | ||||
| public: | public: | ||||
| AbstractRefKey() : AbstractBase() { set_type(std::make_shared<RefKeyType>()); } | AbstractRefKey() : AbstractBase() { set_type(std::make_shared<RefKeyType>()); } | ||||
| @@ -150,7 +150,7 @@ def _tensor_getitem_by_number(data, number_index): | |||||
| @getitem.register("Tensor", "Slice") | @getitem.register("Tensor", "Slice") | ||||
| def _tensor_getitem_by_slice(data, slice_index): | def _tensor_getitem_by_slice(data, slice_index): | ||||
| """ | """ | ||||
| Getting item of tensor by slice index. | |||||
| Getting item of tensor by slice. | |||||
| Inputs: | Inputs: | ||||
| data (Tensor): A tensor. | data (Tensor): A tensor. | ||||
| @@ -165,7 +165,7 @@ def _tensor_getitem_by_slice(data, slice_index): | |||||
| @getitem.register("Tensor", "Tuple") | @getitem.register("Tensor", "Tuple") | ||||
| def _tensor_getitem_by_slice_tuple(data, slice_tuple_index): | def _tensor_getitem_by_slice_tuple(data, slice_tuple_index): | ||||
| """ | """ | ||||
| Getting item of tensor by slice tuple index. | |||||
| Getting item of tensor by slice tuple. | |||||
| Inputs: | Inputs: | ||||
| data (Tensor): A tensor. | data (Tensor): A tensor. | ||||
| @@ -175,3 +175,18 @@ def _tensor_getitem_by_slice_tuple(data, slice_tuple_index): | |||||
| Tensor, element type is same as the element type of data. | Tensor, element type is same as the element type of data. | ||||
| """ | """ | ||||
| return _tensor_slice(data, slice_tuple_index) | return _tensor_slice(data, slice_tuple_index) | ||||
| @getitem.register("Tensor", "Ellipsis") | |||||
| def _tensor_getitem_by_ellipsis(data, ellipsis_index): | |||||
| """ | |||||
| Getting item of tensor by Ellipsis. | |||||
| Inputs: | |||||
| data (Tensor): A tensor. | |||||
| ellipsis (Ellipsis): A Ellipsis object. | |||||
| Outputs: | |||||
| Tensor, same as data. | |||||
| """ | |||||
| return _tensor_slice(data, ellipsis_index) | |||||
| @@ -67,6 +67,7 @@ scalar_to_tensor = P.ScalarToTensor() | |||||
| tuple_to_array = P.TupleToArray() | tuple_to_array = P.TupleToArray() | ||||
| scalar_cast = P.ScalarCast() | scalar_cast = P.ScalarCast() | ||||
| print_ = P.Print() | print_ = P.Print() | ||||
| expand_dims = P.ExpandDims() | |||||
| tuple_setitem = Primitive('tuple_setitem') | tuple_setitem = Primitive('tuple_setitem') | ||||
| tuple_getitem = Primitive('tuple_getitem') | tuple_getitem = Primitive('tuple_getitem') | ||||
| @@ -42,6 +42,20 @@ class NetWorkSlicePositive(Cell): | |||||
| return ret0, ret1, ret2, ret3 | return ret0, ret1, ret2, ret3 | ||||
| class NetWorkSliceEllipsis(Cell): | |||||
| def __init__(self): | |||||
| super(NetWorkSliceEllipsis, self).__init__() | |||||
| self.tensor_ret0 = Tensor(np.ones([2, 7, 8], np.int32)) | |||||
| self.tensor_ret1 = Tensor(np.ones([6, 7, 8, 9], np.int32)) | |||||
| self.tensor_ret2 = Tensor(np.ones([1, 6, 7, 8, 9], np.int32)) | |||||
| def construct(self, tensor): | |||||
| ret0 = tensor[0:4:2, ..., 1] + self.tensor_ret0 | |||||
| ret1 = tensor[...] + self.tensor_ret1 | |||||
| ret2 = tensor[True] + self.tensor_ret2 | |||||
| return ret0, ret1, ret2 | |||||
| class NetWorkReduceDimension(Cell): | class NetWorkReduceDimension(Cell): | ||||
| def __init__(self): | def __init__(self): | ||||
| super(NetWorkReduceDimension, self).__init__() | super(NetWorkReduceDimension, self).__init__() | ||||
| @@ -83,7 +97,7 @@ class NetWorkReduceToScalar(Cell): | |||||
| class TensorAssignWithBoolTensorIndex(Cell): | class TensorAssignWithBoolTensorIndex(Cell): | ||||
| def __init__(self): | def __init__(self): | ||||
| super(TensorAssignWithBoolTensorIndex, self).__init__() | super(TensorAssignWithBoolTensorIndex, self).__init__() | ||||
| self.t = Tensor(np.arange(6).reshape([2,3]), dtype = mstype.float64) | |||||
| self.t = Tensor(np.arange(6).reshape([2, 3]), dtype=mstype.float64) | |||||
| def construct(self, a, b, c, u_tensor, _scalar): | def construct(self, a, b, c, u_tensor, _scalar): | ||||
| a[c] = u_scalar | a[c] = u_scalar | ||||
| @@ -104,14 +118,14 @@ class TensorAssignWithBoolTensorIndexError(Cell): | |||||
| class TensorAssignWithBoolTensorIndex2(Cell): | class TensorAssignWithBoolTensorIndex2(Cell): | ||||
| def __init__(self): | def __init__(self): | ||||
| super(TensorAssignWithBoolTensorIndex2, self).__init__() | super(TensorAssignWithBoolTensorIndex2, self).__init__() | ||||
| self.t = Tensor(np.arange(6).reshape([2,3]), dtype = mstype.float64) | |||||
| self.t = Tensor(np.arange(6).reshape([2, 3]), dtype=mstype.float64) | |||||
| def construct(self, a, u_tensor, _scalar): | def construct(self, a, u_tensor, _scalar): | ||||
| a[a>8] = u_tensor | |||||
| a[a>=6] = u_scalar | |||||
| a[a<3] = u_scalar | |||||
| a[a<=5] = u_tensor | |||||
| a[a==5] = u_scalar | |||||
| a[a > 8] = u_tensor | |||||
| a[a >= 6] = u_scalar | |||||
| a[a < 3] = u_scalar | |||||
| a[a <= 5] = u_tensor | |||||
| a[a == 5] = u_scalar | |||||
| z = a + self.t | z = a + self.t | ||||
| return z | return z | ||||
| @@ -121,11 +135,11 @@ class TensorAssignWithBoolTensorIndex2Error(Cell): | |||||
| super(TensorAssignWithBoolTensorIndex2Error, self).__init__() | super(TensorAssignWithBoolTensorIndex2Error, self).__init__() | ||||
| def construct(self, a, u_tensor): | def construct(self, a, u_tensor): | ||||
| a[a>8][a>5] = u_tensor | |||||
| a[a > 8][a > 5] = u_tensor | |||||
| return a | return a | ||||
| a = np.random.uniform(1,10,[2,3]) | |||||
| a = np.random.uniform(1, 10, [2, 3]) | |||||
| b = a > 5 | b = a > 5 | ||||
| c = a < 3 | c = a < 3 | ||||
| Ta = Tensor(a) | Ta = Tensor(a) | ||||
| @@ -152,7 +166,7 @@ def test_tensor_assign_bool_index(): | |||||
| net1(Ta, Tb, Ta, u_tensor, u_scalar) | net1(Ta, Tb, Ta, u_tensor, u_scalar) | ||||
| with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
| net1(Ta, Tb, Tc, u_tensor_error, u_scalar) | net1(Ta, Tb, Tc, u_tensor_error, u_scalar) | ||||
| #net1(Ta, u_tensor, Tc, u_tensor_error, u_scalar) | |||||
| # net1(Ta, u_tensor, Tc, u_tensor_error, u_scalar) | |||||
| with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
| net2(Ta, u_tensor_error, u_scalar) | net2(Ta, u_tensor_error, u_scalar) | ||||
| net3 = TensorAssignWithBoolTensorIndexError() | net3 = TensorAssignWithBoolTensorIndexError() | ||||
| @@ -192,7 +206,10 @@ test_cases = [ | |||||
| 'block': NetWorkReduceToScalar(), | 'block': NetWorkReduceToScalar(), | ||||
| 'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))], | 'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))], | ||||
| }), | }), | ||||
| ('NetWorkSliceEllipsis', { | |||||
| 'block': NetWorkSliceEllipsis(), | |||||
| 'desc_inputs': [Tensor(np.ones([6, 7, 8, 9], np.int32))], | |||||
| }), | |||||
| ] | ] | ||||
| @@ -162,14 +162,15 @@ def test_ops(): | |||||
| if self.int > self.float: | if self.int > self.float: | ||||
| if [1, 2, 3] != None: | if [1, 2, 3] != None: | ||||
| if self.str_a + self.str_b == "helloworld": | if self.str_a + self.str_b == "helloworld": | ||||
| print("hello world") | |||||
| return ret | |||||
| if q == 86: | |||||
| print("hello world") | |||||
| return ret | |||||
| return x | return x | ||||
| net = OpsNet(9, 2) | net = OpsNet(9, 2) | ||||
| x = Tensor(np.random.randint(low=1, high=10, size=(2, 3, 4), dtype=np.int32)) | x = Tensor(np.random.randint(low=1, high=10, size=(2, 3, 4), dtype=np.int32)) | ||||
| y = Tensor(np.random.randint(low=10, high=20, size=(2, 3, 4), dtype=np.int32)) | y = Tensor(np.random.randint(low=10, high=20, size=(2, 3, 4), dtype=np.int32)) | ||||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | |||||
| context.set_context(mode=context.GRAPH_MODE) | |||||
| net(x, y) | net(x, y) | ||||