support tensor slice using constexpr remove tensorslice metagraph add pynative testcasestags/v0.5.0-beta
| @@ -22,7 +22,7 @@ from .parser import (Parser, create_obj_instance, generate_scope, | |||||
| get_dataclass_attributes, get_dataclass_methods, get_obj_id, | get_dataclass_attributes, get_dataclass_methods, get_obj_id, | ||||
| get_module_namespace, get_obj_type, get_object_key, | get_module_namespace, get_obj_type, get_object_key, | ||||
| get_default_input, get_parse_method_of_class, get_scope_name, | get_default_input, get_parse_method_of_class, get_scope_name, | ||||
| is_class_member, parse_cb, resolve_symbol, create_ellipsis_obj) | |||||
| is_class_member, parse_cb, resolve_symbol) | |||||
| from .serialize import * | from .serialize import * | ||||
| __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol', | __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol', | ||||
| @@ -30,4 +30,4 @@ __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', | |||||
| 'get_obj_type', 'get_obj_id', 'create_obj_instance', 'get_module_namespace', | 'get_obj_type', 'get_obj_id', 'create_obj_instance', 'get_module_namespace', | ||||
| 'get_class_member_namespace_symbol', 'get_obj_id', 'Parser', 'get_dataclass_attributes', | 'get_class_member_namespace_symbol', 'get_obj_id', 'Parser', 'get_dataclass_attributes', | ||||
| 'get_dataclass_methods', 'dump_obj', 'load_obj', 'get_dataclass_methods', 'get_scope_name', | 'get_dataclass_methods', 'dump_obj', 'load_obj', 'get_dataclass_methods', 'get_scope_name', | ||||
| 'create_slice_obj', 'create_ellipsis_obj'] | |||||
| 'create_slice_obj'] | |||||
| @@ -29,7 +29,6 @@ from mindspore.common.dtype import pytype_to_dtype | |||||
| from mindspore.common.api import _MindSporeFunction | from mindspore.common.api import _MindSporeFunction | ||||
| from .namespace import CellNamespace, ClosureNamespace, ClassMemberNamespace | from .namespace import CellNamespace, ClosureNamespace, ClassMemberNamespace | ||||
| from .resources import parse_object_map, convert_object_map, trope_ns, SYMBOL_UNDEFINE, NO_IMPLEMENT | from .resources import parse_object_map, convert_object_map, trope_ns, SYMBOL_UNDEFINE, NO_IMPLEMENT | ||||
| from ..utils import Slice, Ellipsis_ | |||||
| # define return value | # define return value | ||||
| RET_SUCCESS = 0 | RET_SUCCESS = 0 | ||||
| @@ -70,14 +69,9 @@ parse_expr_statement_white_list = ( | |||||
| "append", | "append", | ||||
| ) | ) | ||||
| def create_ellipsis_obj(): | |||||
| """Create Slice object""" | |||||
| return Ellipsis_() | |||||
| def create_slice_obj(start, end, step): | def create_slice_obj(start, end, step): | ||||
| """Create Slice object""" | |||||
| return Slice(start, end, step) | |||||
| """Create slice object""" | |||||
| return slice(start, end, step) | |||||
| def parse_cb(func, parse_method=None): | def parse_cb(func, parse_method=None): | ||||
| @@ -19,7 +19,6 @@ import logging | |||||
| import os | import os | ||||
| import inspect | import inspect | ||||
| from functools import wraps | from functools import wraps | ||||
| from dataclasses import dataclass | |||||
| def cal_sha256(file_path): | def cal_sha256(file_path): | ||||
| @@ -100,20 +99,3 @@ def cell_attr_register(fn=None, attrs=None): | |||||
| if fn is not None: | if fn is not None: | ||||
| return wrap_cell(fn) | return wrap_cell(fn) | ||||
| return wrap_cell | return wrap_cell | ||||
| @dataclass | |||||
| class Slice: | |||||
| """ | |||||
| Slice class | |||||
| """ | |||||
| start: int | |||||
| end: int | |||||
| step: int | |||||
| @dataclass | |||||
| class Ellipsis_: | |||||
| """ | |||||
| Ellipsis class | |||||
| """ | |||||
| @@ -932,206 +932,6 @@ FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_ | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| int ConvertBinaryToDecimal(const std::vector<unsigned int> &number_bin) { | |||||
| unsigned int number_dec = 0; | |||||
| for (size_t index = 0; index < number_bin.size(); index++) { | |||||
| number_dec |= number_bin[index] << index; | |||||
| } | |||||
| return static_cast<int>(number_dec); | |||||
| } | |||||
| void ParseSlice(const AbstractSlicePtr &slice, std::vector<int> *begin, std::vector<int> *end, | |||||
| std::vector<int> *strides, int length) { | |||||
| MS_EXCEPTION_IF_NULL(slice); | |||||
| MS_EXCEPTION_IF_NULL(begin); | |||||
| MS_EXCEPTION_IF_NULL(end); | |||||
| MS_EXCEPTION_IF_NULL(strides); | |||||
| if (length <= 0) { | |||||
| MS_LOG(EXCEPTION) << "Could not slice a dim when it's length less than 1"; | |||||
| } | |||||
| int start_default = 0; | |||||
| int stop_default = length; | |||||
| int step_default = 1; | |||||
| int step_value = CheckSliceMember(slice->step(), step_default, "step"); | |||||
| if (step_value < 0) { | |||||
| start_default = -1; | |||||
| stop_default = -(length + 1); | |||||
| } | |||||
| begin->push_back(CheckSliceMember(slice->start(), start_default, "begin")); | |||||
| end->push_back(CheckSliceMember(slice->stop(), stop_default, "stop")); | |||||
| strides->push_back(step_value); | |||||
| } | |||||
| int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr &slice_tuple, const std::vector<int> &shape, | |||||
| std::vector<int> *begin, std::vector<int> *end, std::vector<int> *strides) { | |||||
| MS_EXCEPTION_IF_NULL(slice_tuple); | |||||
| MS_EXCEPTION_IF_NULL(begin); | |||||
| MS_EXCEPTION_IF_NULL(end); | |||||
| MS_EXCEPTION_IF_NULL(strides); | |||||
| size_t slice_tuple_size = slice_tuple->size(); | |||||
| size_t shape_size = shape.size(); | |||||
| if (slice_tuple_size > shape_size) { | |||||
| MS_LOG(EXCEPTION) << "The number of slice data to slice tensor should be less than the rank of tensor," | |||||
| "when the rank of tensor is " | |||||
| << shape_size << ", the number of slice is " << slice_tuple_size; | |||||
| } | |||||
| std::vector<unsigned int> shrink; | |||||
| auto slice_tuple_eles = slice_tuple->elements(); | |||||
| size_t ellipsis_num = 0; | |||||
| for (size_t index = 0; index < slice_tuple_size; index++) { | |||||
| if (slice_tuple_eles[index]->isa<AbstractSlice>()) { | |||||
| AbstractSlicePtr slice = dyn_cast<AbstractSlice>(slice_tuple_eles[index]); | |||||
| ParseSlice(slice, begin, end, strides, shape[index]); | |||||
| shrink.push_back(0); | |||||
| continue; | |||||
| } | |||||
| if (slice_tuple_eles[index]->isa<AbstractScalar>()) { | |||||
| int ele_index = GetArgScalarValue(dyn_cast<AbstractScalar>(slice_tuple_eles[index]), "slice_tuple"); | |||||
| begin->push_back(ele_index); | |||||
| end->push_back(ele_index + 1); | |||||
| strides->push_back(1); | |||||
| shrink.push_back(1); | |||||
| continue; | |||||
| } | |||||
| if (slice_tuple_eles[index]->isa<AbstractEllipsis>()) { | |||||
| ellipsis_num++; | |||||
| if (ellipsis_num > 1) { | |||||
| MS_LOG(EXCEPTION) << "Tensor slice supports at most one ellipsis"; | |||||
| } | |||||
| size_t ellipsis_len = shape_size - (slice_tuple_size - 1); | |||||
| begin->insert(begin->end(), ellipsis_len, 0); | |||||
| end->insert(end->end(), shape.begin() + index, shape.begin() + index + ellipsis_len); | |||||
| strides->insert(strides->end(), ellipsis_len, 1); | |||||
| shrink.insert(shrink.end(), ellipsis_len, 0); | |||||
| continue; | |||||
| } | |||||
| MS_LOG(EXCEPTION) << "Slice tuple only could contain slice, int number or ellipsis, but got " | |||||
| << slice_tuple_eles[index]->ToString(); | |||||
| } | |||||
| if (ellipsis_num == 0) { | |||||
| for (size_t index = slice_tuple_size; index < shape_size; index++) { | |||||
| begin->push_back(0); | |||||
| end->push_back(shape[index]); | |||||
| strides->push_back(1); | |||||
| } | |||||
| } | |||||
| return ConvertBinaryToDecimal(shrink); | |||||
| } | |||||
| int GenerateStridedSliceParametersFromSlice(const AbstractSlicePtr &slice, const std::vector<int> &shape, | |||||
| std::vector<int> *begin, std::vector<int> *end, std::vector<int> *strides) { | |||||
| MS_EXCEPTION_IF_NULL(begin); | |||||
| MS_EXCEPTION_IF_NULL(end); | |||||
| MS_EXCEPTION_IF_NULL(strides); | |||||
| size_t shape_size = shape.size(); | |||||
| if (shape_size == 0) { | |||||
| MS_LOG(EXCEPTION) << "Could slice a scalar tensor"; | |||||
| } | |||||
| ParseSlice(slice, begin, end, strides, shape[0]); | |||||
| for (size_t index = 1; index < shape_size; index++) { | |||||
| begin->push_back(0); | |||||
| end->push_back(shape[index]); | |||||
| strides->push_back(1); | |||||
| } | |||||
| return 0; | |||||
| } | |||||
| int GenerateStridedSliceParametersFromNumber(const AbstractScalarPtr &scalar, const std::vector<int> &shape, | |||||
| std::vector<int> *begin, std::vector<int> *end, | |||||
| std::vector<int> *strides) { | |||||
| MS_EXCEPTION_IF_NULL(begin); | |||||
| MS_EXCEPTION_IF_NULL(end); | |||||
| MS_EXCEPTION_IF_NULL(strides); | |||||
| int ele_index = GetArgScalarValue(scalar, "slice_tuple"); | |||||
| begin->push_back(ele_index); | |||||
| end->push_back(ele_index + 1); | |||||
| strides->push_back(1); | |||||
| for (size_t index = 1; index < shape.size(); index++) { | |||||
| begin->push_back(0); | |||||
| end->push_back(shape[index]); | |||||
| strides->push_back(1); | |||||
| } | |||||
| return 1; | |||||
| } | |||||
| FuncGraphPtr ExpandADim(const FuncGraphPtr &ret_graph, const AnfNodePtr &tensor_node) { | |||||
| auto PrimExpandDims = GetPythonOps("expand_dims", "mindspore.ops.functional"); | |||||
| ret_graph->set_output(NewCNode({NewValueNode(PrimExpandDims), tensor_node, NewValueNode(0)}, ret_graph)); | |||||
| return ret_graph; | |||||
| } | |||||
| FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { | |||||
| // slice a tensor | |||||
| // args: tensor, slice or slice tuple | |||||
| const std::string op_name = std::string("TensorSlice"); | |||||
| abstract::CheckArgsSize(op_name, args_spec_list, 2); | |||||
| AbstractTensorPtr tensorPtr = abstract::CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||||
| FuncGraphPtr ret_graph = std::make_shared<FuncGraph>(); | |||||
| ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true); | |||||
| AnfNodePtr tensor_node = ret_graph->add_parameter(); | |||||
| (void)ret_graph->add_parameter(); | |||||
| auto shape = tensorPtr->shape()->shape(); | |||||
| std::vector<int> begin; | |||||
| std::vector<int> end; | |||||
| std::vector<int> strides; | |||||
| int shrink_axis_mask; | |||||
| if (args_spec_list[1]->isa<AbstractTuple>()) { | |||||
| AbstractTuplePtr tuple_ptr = dyn_cast<AbstractTuple>(args_spec_list[1]); | |||||
| shrink_axis_mask = GenerateStridedSliceParametersFromTuple(tuple_ptr, shape, &begin, &end, &strides); | |||||
| } else if (args_spec_list[1]->isa<AbstractSlice>()) { | |||||
| AbstractSlicePtr slice_ptr = dyn_cast<AbstractSlice>(args_spec_list[1]); | |||||
| shrink_axis_mask = GenerateStridedSliceParametersFromSlice(slice_ptr, shape, &begin, &end, &strides); | |||||
| } else if (args_spec_list[1]->isa<AbstractScalar>()) { | |||||
| AbstractScalarPtr scalar_ptr = dyn_cast<AbstractScalar>(args_spec_list[1]); | |||||
| if (scalar_ptr->BuildValue()->isa<BoolImm>()) { | |||||
| if (scalar_ptr->BuildValue()->cast<BoolImmPtr>()->value()) { | |||||
| return ExpandADim(ret_graph, tensor_node); | |||||
| } | |||||
| MS_LOG(EXCEPTION) << "TensorSlice not support the index is False."; | |||||
| } | |||||
| shrink_axis_mask = GenerateStridedSliceParametersFromNumber(scalar_ptr, shape, &begin, &end, &strides); | |||||
| } else if (args_spec_list[1]->isa<AbstractEllipsis>()) { | |||||
| ret_graph->set_output(tensor_node); | |||||
| return ret_graph; | |||||
| } else if (args_spec_list[1]->isa<AbstractNone>()) { | |||||
| return ExpandADim(ret_graph, tensor_node); | |||||
| } else { | |||||
| std::ostringstream args_info; | |||||
| for (const auto &arg : args_spec_list) { | |||||
| MS_EXCEPTION_IF_NULL(arg); | |||||
| args_info << arg->ToString() << "\n"; | |||||
| } | |||||
| MS_LOG(EXCEPTION) | |||||
| << "TensorSlice requires the input should be one of [slice, ellipsis, int number, bool, none, tuple] , but got " | |||||
| << args_info.str(); | |||||
| } | |||||
| auto PrimStridedSliceClass = prim::GetPythonOps("StridedSlice", "mindspore.ops.operations"); | |||||
| auto PrimStridedSlice = ret_graph->NewCNode({NewValueNode(PrimStridedSliceClass), NewValueNode(0), NewValueNode(0), | |||||
| NewValueNode(0), NewValueNode(0), NewValueNode(shrink_axis_mask)}); | |||||
| ret_graph->set_output(ret_graph->NewCNode( | |||||
| {PrimStridedSlice, tensor_node, NewValueNode(begin), NewValueNode(end), NewValueNode(strides)})); | |||||
| return ret_graph; | |||||
| } | |||||
| FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { | FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { | ||||
| // select indexed item | // select indexed item | ||||
| // args: tuple of items, index | // args: tuple of items, index | ||||
| @@ -1162,11 +962,6 @@ REGISTER_PYBIND_DEFINE(TupleSlice_, ([](const py::module *m) { | |||||
| .def(py::init<std::string &>()); | .def(py::init<std::string &>()); | ||||
| })); | })); | ||||
| REGISTER_PYBIND_DEFINE(TensorSlice_, ([](const py::module *m) { | |||||
| (void)py::class_<TensorSlice, MetaFuncGraph, std::shared_ptr<TensorSlice>>(*m, "TensorSlice_") | |||||
| .def(py::init<std::string &>()); | |||||
| })); | |||||
| REGISTER_PYBIND_DEFINE(TupleGetItemTensor_, ([](const py::module *m) { | REGISTER_PYBIND_DEFINE(TupleGetItemTensor_, ([](const py::module *m) { | ||||
| (void)py::class_<TupleGetItemTensor, MetaFuncGraph, std::shared_ptr<TupleGetItemTensor>>( | (void)py::class_<TupleGetItemTensor, MetaFuncGraph, std::shared_ptr<TupleGetItemTensor>>( | ||||
| *m, "TupleGetItemTensor_") | *m, "TupleGetItemTensor_") | ||||
| @@ -175,16 +175,6 @@ class TupleSlice : public MetaFuncGraph { | |||||
| }; | }; | ||||
| using TupleSlicePtr = std::shared_ptr<TupleSlice>; | using TupleSlicePtr = std::shared_ptr<TupleSlice>; | ||||
| class TensorSlice : public MetaFuncGraph { | |||||
| public: | |||||
| explicit TensorSlice(const std::string &name) : MetaFuncGraph(name) {} | |||||
| ~TensorSlice() override = default; | |||||
| MS_DECLARE_PARENT(TensorSlice, MetaFuncGraph) | |||||
| FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; | |||||
| friend bool operator==(const TensorSlice &lhs, const TensorSlice &rhs) { return lhs.name_ == rhs.name_; } | |||||
| }; | |||||
| using TensorSlicePtr = std::shared_ptr<TensorSlice>; | |||||
| class TupleGetItemTensor : public MetaFuncGraph { | class TupleGetItemTensor : public MetaFuncGraph { | ||||
| public: | public: | ||||
| explicit TupleGetItemTensor(const std::string &name) : MetaFuncGraph(name) {} | explicit TupleGetItemTensor(const std::string &name) : MetaFuncGraph(name) {} | ||||
| @@ -209,6 +209,28 @@ bool ConvertTensor(const py::object &obj, ValuePtr *const data) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool ConvertSlice(const py::object &obj, ValuePtr *const data) { | |||||
| MS_LOG(DEBUG) << "Converting slice object"; | |||||
| py::slice slice_obj = obj.cast<py::slice>(); | |||||
| auto convert_func = [obj](std::string attr) -> ValuePtr { | |||||
| auto py_attr = py::getattr(obj, attr.c_str()); | |||||
| if (py::isinstance<py::none>(py_attr)) { | |||||
| return kNone; | |||||
| } else if (py::isinstance<py::int_>(py_attr)) { | |||||
| int value = py::cast<int>(py_attr); | |||||
| return MakeValue(value); | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "Slice should contain only int or none"; | |||||
| } | |||||
| }; | |||||
| ValuePtr start = convert_func("start"); | |||||
| ValuePtr stop = convert_func("stop"); | |||||
| ValuePtr step = convert_func("step"); | |||||
| *data = std::make_shared<ValueSlice>(start, stop, step); | |||||
| return true; | |||||
| } | |||||
| FuncGraphPtr ConvertToBpropCut(py::object obj) { | FuncGraphPtr ConvertToBpropCut(py::object obj) { | ||||
| std::vector<std::string> results = data_converter::GetObjKey(obj); | std::vector<std::string> results = data_converter::GetObjKey(obj); | ||||
| std::string obj_key = results[0]; | std::string obj_key = results[0]; | ||||
| @@ -321,6 +343,10 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature | |||||
| converted = std::make_shared<StringImm>(py::cast<std::string>(obj)); | converted = std::make_shared<StringImm>(py::cast<std::string>(obj)); | ||||
| } else if (py::isinstance<py::dict>(obj)) { | } else if (py::isinstance<py::dict>(obj)) { | ||||
| ret = ConvertDict(obj, &converted, use_signature); | ret = ConvertDict(obj, &converted, use_signature); | ||||
| } else if (py::isinstance<py::slice>(obj)) { | |||||
| ret = ConvertSlice(obj, &converted); | |||||
| } else if (py::isinstance<py::ellipsis>(obj)) { | |||||
| converted = kEllipsis; | |||||
| } else if (py::isinstance<py::tuple>(obj)) { | } else if (py::isinstance<py::tuple>(obj)) { | ||||
| ret = ConvertTuple(obj, &converted, use_signature); | ret = ConvertTuple(obj, &converted, use_signature); | ||||
| } else if (py::hasattr(obj, PYTHON_CELL_AS_LIST)) { | } else if (py::hasattr(obj, PYTHON_CELL_AS_LIST)) { | ||||
| @@ -353,11 +353,9 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) { | |||||
| auto value = abs_base->cast<AbstractRefPtr>()->ref(); | auto value = abs_base->cast<AbstractRefPtr>()->ref(); | ||||
| dic = ConvertAbstractToPython(value); | dic = ConvertAbstractToPython(value); | ||||
| } else if (abs_base->isa<AbstractEllipsis>()) { | } else if (abs_base->isa<AbstractEllipsis>()) { | ||||
| auto arg_slice = dyn_cast<AbstractEllipsis>(abs_base); | |||||
| std::vector<int> shape; | |||||
| dic["shape"] = shape; | |||||
| dic["dtype"] = arg_slice->BuildType(); | |||||
| dic["value"] = BuildValue(arg_slice->BuildValue()); | |||||
| dic["shape"] = py::none(); | |||||
| dic["dtype"] = py::ellipsis(); | |||||
| dic["value"] = py::ellipsis(); | |||||
| } else if (abs_base->isa<AbstractTuple>()) { | } else if (abs_base->isa<AbstractTuple>()) { | ||||
| auto arg_tuple = dyn_cast<AbstractTuple>(abs_base); | auto arg_tuple = dyn_cast<AbstractTuple>(abs_base); | ||||
| size_t len = arg_tuple->size(); | size_t len = arg_tuple->size(); | ||||
| @@ -106,7 +106,7 @@ py::object ValuePtrToPyData(const ValuePtr &value) { | |||||
| } | } | ||||
| ret = rets; | ret = rets; | ||||
| } else if (value->isa<EllipsisObj>()) { | } else if (value->isa<EllipsisObj>()) { | ||||
| ret = parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_PARSE_CLASS_ELLIPSIS); | |||||
| ret = py::ellipsis(); | |||||
| } else if (value->isa<ValueSlice>()) { | } else if (value->isa<ValueSlice>()) { | ||||
| auto slice = value->cast<ValueSlicePtr>(); | auto slice = value->cast<ValueSlicePtr>(); | ||||
| auto start = ValuePtrToPyData(slice->start()); | auto start = ValuePtrToPyData(slice->start()); | ||||
| @@ -206,6 +206,9 @@ class Parameter: | |||||
| res.default_input = res.default_input / other | res.default_input = res.default_input / other | ||||
| return res | return res | ||||
| def __setitem__(self, index, value): | |||||
| return self | |||||
| def set_parameter_data(self, data): | def set_parameter_data(self, data): | ||||
| """Set `default_input` of current `Parameter`.""" | """Set `default_input` of current `Parameter`.""" | ||||
| if isinstance(data, bool): | if isinstance(data, bool): | ||||
| @@ -144,6 +144,13 @@ class Tensor(Tensor_): | |||||
| out = tensor_operator_registry.get('__le__')(self, other) | out = tensor_operator_registry.get('__le__')(self, other) | ||||
| return out | return out | ||||
| def __getitem__(self, index): | |||||
| out = tensor_operator_registry.get('__getitem__')(self, index) | |||||
| return out | |||||
| def __setitem__(self, index, value): | |||||
| return self | |||||
| def __gt__(self, other): | def __gt__(self, other): | ||||
| out = tensor_operator_registry.get('__gt__')(self, other) | out = tensor_operator_registry.get('__gt__')(self, other) | ||||
| return out | return out | ||||
| @@ -19,7 +19,7 @@ | |||||
| from functools import partial | from functools import partial | ||||
| from mindspore import context | from mindspore import context | ||||
| from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, TensorSlice_, \ | |||||
| from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, \ | |||||
| TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_ | TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_ | ||||
| from ...common import dtype as mstype | from ...common import dtype as mstype | ||||
| from ...common.api import ms_function, _pynative_exec, _wrap_func | from ...common.api import ms_function, _pynative_exec, _wrap_func | ||||
| @@ -27,7 +27,7 @@ from .. import functional as F | |||||
| from ...common.parameter import Parameter | from ...common.parameter import Parameter | ||||
| __all__ = [EnvInstance_, TensorSlice_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_] | |||||
| __all__ = [EnvInstance_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_] | |||||
| def add_flags(fn, **flags): | def add_flags(fn, **flags): | ||||
| @@ -18,7 +18,9 @@ from . import _constexpr_utils as const_utils | |||||
| from ... import functional as F | from ... import functional as F | ||||
| from ... import operations as P | from ... import operations as P | ||||
| from ...composite import base | from ...composite import base | ||||
| from ....common.tensor import Tensor | |||||
| from ....common import dtype as mstype | from ....common import dtype as mstype | ||||
| from ....common._register_for_tensor import tensor_operator_registry | |||||
| hyper_map = base.HyperMap() | hyper_map = base.HyperMap() | ||||
| pack = P.Pack(axis=-1) | pack = P.Pack(axis=-1) | ||||
| @@ -152,3 +154,101 @@ def generate_updates_from_tensor(data, index, value, op_type): | |||||
| if need_broadcast: | if need_broadcast: | ||||
| return broadcast(updates_shape, value) | return broadcast(updates_shape, value) | ||||
| return value | return value | ||||
| def tensor_getitem(self, index): | |||||
| """Handle tensor getitem""" | |||||
| if isinstance(index, Tensor): | |||||
| return tensor_index_by_tensor(self, index) | |||||
| if isinstance(index, tuple): | |||||
| return tensor_index_by_tuple(self, index) | |||||
| if isinstance(index, int): | |||||
| return tensor_index_by_number(self, index) | |||||
| if isinstance(index, slice): | |||||
| return tensor_index_by_slice(self, index) | |||||
| if isinstance(index, bool): | |||||
| return tensor_index_by_bool(self, index) | |||||
| if index is ...: | |||||
| return self | |||||
| raise IndexError("Only support integers, slices(`:`), ellipsis(`...`), None, bool and tensor with int32,\ | |||||
| got {} with type{}".format(index, type(index))) | |||||
| tensor_operator_registry.register("__getitem__", tensor_getitem) | |||||
| def tensor_getitem_by_tuple_of_tensor(data, tuple_index): | |||||
| """Tensor getitem by a tuple of tensor.""" | |||||
| indices = generate_indices_from_tuple_of_tensor(data, | |||||
| tuple_index, | |||||
| const_utils.TENSOR_GETITEM) | |||||
| result = F.gather_nd(data, indices) | |||||
| return result | |||||
| def tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index): | |||||
| """Tensor getitem by a tuple of mixed tensor.""" | |||||
| indices = generate_indices_from_tuple_of_mixed_tensors(data, | |||||
| tuple_index, | |||||
| const_utils.TENSOR_GETITEM) | |||||
| result = F.gather_nd(data, indices) | |||||
| return result | |||||
| def tensor_index_by_slice(data, slice_index): | |||||
| """Tensor getitem by a single slice""" | |||||
| begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_slice(F.shape(data), slice_index) | |||||
| return F.strided_slice(data, begin_strides, end_strides, step_strides) | |||||
| def tensor_index_by_integer(data, number): | |||||
| """Tensor getitem by a single integer number""" | |||||
| begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_integer(F.shape(data), number) | |||||
| shrink_axis_mask = 1 | |||||
| return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides) | |||||
| def tensor_index_by_bool(data, bool_value): | |||||
| """Tensor getitem by a single bool value""" | |||||
| if bool_value: | |||||
| return F.expand_dims(data, 0) | |||||
| return const_utils.raise_index_error("bool value as indexing ,false is not supported") | |||||
| def tensor_index_by_number(data, number): | |||||
| """Tensor getitem by a Number which may be integer/float/bool value""" | |||||
| number_type = const_utils.check_number_index_type(number) | |||||
| if number_type == const_utils.BOOL_: | |||||
| return tensor_index_by_bool(data, number) | |||||
| if number_type == const_utils.INT_: | |||||
| return tensor_index_by_integer(data, number) | |||||
| return const_utils.raise_index_error("Only support integers, slices(`:`), ellipsis(`...`), None and bool") | |||||
| def tensor_index_by_tensor(data, tensor_index): | |||||
| """Tensor getitem by a single tensor""" | |||||
| dtype_valid = const_utils.check_index_tensor_dtype(F.dtype(tensor_index), | |||||
| const_utils.TENSOR_GETITEM) | |||||
| if dtype_valid: | |||||
| return F.gather(data, tensor_index, 0) | |||||
| return const_utils.raise_index_error("Only support integers, slices(`:`), ellipsis(`...`), None and bool") | |||||
| def tensor_index_by_tuple_slice(data, t): | |||||
| """Tensor getitem by a tuple of slice""" | |||||
| begin_strides, end_strides, step_strides, shrink_axis_mask = \ | |||||
| const_utils.get_stride_info_from_tuple(F.shape(data), t) | |||||
| return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides) | |||||
| def tensor_index_by_tuple(data, tuple_index): | |||||
| """Tensor getitem by tuple of various types""" | |||||
| indexes_types = hyper_map(F.typeof, tuple_index) | |||||
| index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_GETITEM) | |||||
| if index_elements_type == const_utils.NO_TENSOR: | |||||
| return tensor_index_by_tuple_slice(data, tuple_index) | |||||
| if index_elements_type == const_utils.ALL_TENSOR: | |||||
| return tensor_getitem_by_tuple_of_tensor(data, tuple_index) | |||||
| return tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index) | |||||
| @@ -20,7 +20,6 @@ import numpy as np | |||||
| from ...primitive import constexpr | from ...primitive import constexpr | ||||
| from .... import log as logger | from .... import log as logger | ||||
| from ...._extends.utils import Slice, Ellipsis_ | |||||
| from ....common import dtype as mstype | from ....common import dtype as mstype | ||||
| from ....common.tensor import Tensor | from ....common.tensor import Tensor | ||||
| from ....ops import _utils as op_utils | from ....ops import _utils as op_utils | ||||
| @@ -41,6 +40,11 @@ SET_ITEM_BY_ONE_TENSOR = 0 | |||||
| SET_ITEM_BY_TUPLE_OF_TENSOR = 1 | SET_ITEM_BY_TUPLE_OF_TENSOR = 1 | ||||
| @constexpr | |||||
| def raise_index_error(msg): | |||||
| raise IndexError(msg) | |||||
| @constexpr | @constexpr | ||||
| def check_equal(param1, param2, msg="{},{}"): | def check_equal(param1, param2, msg="{},{}"): | ||||
| """Checks whether the two parameters are equal or not.""" | """Checks whether the two parameters are equal or not.""" | ||||
| @@ -54,7 +58,8 @@ def check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size): | |||||
| """Checks the shape and size of the sensor and value.""" | """Checks the shape and size of the sensor and value.""" | ||||
| if data_shape == value_shape or data_size == value_size or value_size == 1: | if data_shape == value_shape or data_size == value_size or value_size == 1: | ||||
| return True | return True | ||||
| raise ValueError("The value(shape={}), can not assign to tensor(shape={}).".format(value_shape, data_shape)) | |||||
| raise ValueError("The value(shape={}), can not assign to tensor(shape={}).".format( | |||||
| value_shape, data_shape)) | |||||
| @constexpr | @constexpr | ||||
| @@ -63,16 +68,18 @@ def check_tensor_setitem_index(index, element_type=None): | |||||
| if index is None: | if index is None: | ||||
| raise IndexError("Tensor's index cannot be None.") | raise IndexError("Tensor's index cannot be None.") | ||||
| # eg. Tensor[Slice] = u | # eg. Tensor[Slice] = u | ||||
| if isinstance(index, Slice): | |||||
| if isinstance(index, slice): | |||||
| return True | return True | ||||
| # eg. Tensor[tuple] = u | # eg. Tensor[tuple] = u | ||||
| if isinstance(index, tuple): | if isinstance(index, tuple): | ||||
| if not index: | if not index: | ||||
| raise IndexError("Tensor's index cannot be empty.") | raise IndexError("Tensor's index cannot be empty.") | ||||
| # eg. Tensor[tuple(Slice...)] = u | |||||
| if isinstance(index[0], (Slice, Ellipsis_, int)): | |||||
| return True | |||||
| raise IndexError("Index of type '{}' is not supported yet.".format(type(index[0]))) | |||||
| # eg. Tensor[tuple(Slice,...)] = u | |||||
| for item in index: | |||||
| if not isinstance(item, (slice, type(...), int)): | |||||
| raise IndexError( | |||||
| "Index of type '{}' is not supported yet.".format(type(item))) | |||||
| return True | |||||
| # eg. Tensor[Tensor[dtype=bool]] = u | # eg. Tensor[Tensor[dtype=bool]] = u | ||||
| if isinstance(index, mstype.tensor_type): | if isinstance(index, mstype.tensor_type): | ||||
| if element_type is None or element_type != mstype.bool_: | if element_type is None or element_type != mstype.bool_: | ||||
| @@ -81,7 +88,8 @@ def check_tensor_setitem_index(index, element_type=None): | |||||
| "{} type is not supported yet.".format(element_type)) | "{} type is not supported yet.".format(element_type)) | ||||
| return True | return True | ||||
| raise IndexError("Index of type '{}' is not supported yet.".format(type(index))) | |||||
| raise IndexError( | |||||
| "Index of type '{}' is not supported yet.".format(type(index))) | |||||
| @constexpr | @constexpr | ||||
| @@ -116,12 +124,12 @@ def slice_expand(input_slices, shape): | |||||
| index = 0 | index = 0 | ||||
| slices = None | slices = None | ||||
| # Slice or tuple(Slice...) | # Slice or tuple(Slice...) | ||||
| if isinstance(input_slices, Slice): | |||||
| if isinstance(input_slices, slice): | |||||
| slices = (input_slices,) | slices = (input_slices,) | ||||
| elif isinstance(input_slices, (tuple, list)) and input_slices and isinstance(input_slices[0], (Slice, Ellipsis_)): | |||||
| elif isinstance(input_slices, (tuple, list)) and input_slices and isinstance(input_slices[0], (slice, type(...))): | |||||
| is_have_ellipsis = False | is_have_ellipsis = False | ||||
| for _, element in enumerate(input_slices): | for _, element in enumerate(input_slices): | ||||
| if isinstance(element, Ellipsis_): | |||||
| if isinstance(element, type(...)): | |||||
| is_have_ellipsis = True | is_have_ellipsis = True | ||||
| break | break | ||||
| if is_have_ellipsis: | if is_have_ellipsis: | ||||
| @@ -130,10 +138,9 @@ def slice_expand(input_slices, shape): | |||||
| slices = input_slices | slices = input_slices | ||||
| else: | else: | ||||
| raise IndexError("Tensor's index type is not supported yet.") | raise IndexError("Tensor's index type is not supported yet.") | ||||
| for s in slices: | for s in slices: | ||||
| start = 0 if (s.start is None) else s.start | start = 0 if (s.start is None) else s.start | ||||
| stop = shape[index] if (s.end is None) else s.end | |||||
| stop = shape[index] if (s.stop is None) else s.stop | |||||
| step = 1 if (s.step is None) else s.step | step = 1 if (s.step is None) else s.step | ||||
| begin.append(start) | begin.append(start) | ||||
| end.append(stop) | end.append(stop) | ||||
| @@ -151,11 +158,11 @@ def ellipsis2slice(input_, shape): | |||||
| """Converts ellipsis to slice.""" | """Converts ellipsis to slice.""" | ||||
| input_slice = input_ | input_slice = input_ | ||||
| result = [] | result = [] | ||||
| if isinstance(input_, Ellipsis_): | |||||
| if isinstance(input_, type(...)): | |||||
| input_slice = (input_,) | input_slice = (input_,) | ||||
| ell_count = 0 | ell_count = 0 | ||||
| for _, element in enumerate(input_slice): | for _, element in enumerate(input_slice): | ||||
| if not isinstance(element, Ellipsis_): | |||||
| if not isinstance(element, type(...)): | |||||
| result.append(element) | result.append(element) | ||||
| continue | continue | ||||
| ell_count += 1 | ell_count += 1 | ||||
| @@ -163,7 +170,7 @@ def ellipsis2slice(input_, shape): | |||||
| raise IndexError("There cannot be more than one ellisis (...) in the index of the tensor, " | raise IndexError("There cannot be more than one ellisis (...) in the index of the tensor, " | ||||
| "but it is currently {}".format(input_slice)) | "but it is currently {}".format(input_slice)) | ||||
| for _ in range(len(shape) - len(input_slice) + 1): | for _ in range(len(shape) - len(input_slice) + 1): | ||||
| result.append(Slice(None, None, None)) | |||||
| result.append(slice(None, None, None)) | |||||
| return tuple(result) | return tuple(result) | ||||
| @@ -196,7 +203,8 @@ def slice2indices(input_slices, shape): | |||||
| def check_indices(indices_size, index): | def check_indices(indices_size, index): | ||||
| """Checks indices whether is empty.""" | """Checks indices whether is empty.""" | ||||
| if indices_size < 1: | if indices_size < 1: | ||||
| raise IndexError("The tensor's index is unreasonable. index:{}".format(index)) | |||||
| raise IndexError( | |||||
| "The tensor's index is unreasonable. index:{}".format(index)) | |||||
| return indices_size | return indices_size | ||||
| @@ -230,7 +238,7 @@ def tuple_element_is_slice(indexs): | |||||
| raise IndexError("Tensor's index cannot be empty.") | raise IndexError("Tensor's index cannot be empty.") | ||||
| if isinstance(indexs, tuple): | if isinstance(indexs, tuple): | ||||
| for _, ele in enumerate(indexs): | for _, ele in enumerate(indexs): | ||||
| if not isinstance(ele, Slice): | |||||
| if not isinstance(ele, slice): | |||||
| return False | return False | ||||
| return True | return True | ||||
| return False | return False | ||||
| @@ -285,7 +293,8 @@ def check_value_elements(data_dtype, types): | |||||
| return ALL_TENSOR | return ALL_TENSOR | ||||
| if scalars_number == len(types): | if scalars_number == len(types): | ||||
| return ALL_SCALAR | return ALL_SCALAR | ||||
| raise TypeError(f"For '{TENSOR_SETITEM}', the value does not support scalar and tensor mixing, but got {types}.") | |||||
| raise TypeError( | |||||
| f"For '{TENSOR_SETITEM}', the value does not support scalar and tensor mixing, but got {types}.") | |||||
| @constexpr | @constexpr | ||||
| @@ -295,7 +304,8 @@ def get_index_tensor_dtype(dtype): | |||||
| return INT_ | return INT_ | ||||
| if dtype == mstype.bool_: | if dtype == mstype.bool_: | ||||
| return BOOL_ | return BOOL_ | ||||
| raise IndexError(f"For '{TENSOR_SETITEM}', the index tensor data type '{dtype}' is not supported.") | |||||
| raise IndexError( | |||||
| f"For '{TENSOR_SETITEM}', the index tensor data type '{dtype}' is not supported.") | |||||
| @constexpr | @constexpr | ||||
| @@ -313,7 +323,8 @@ def check_index_tensor_dtype(dtype, op_name): | |||||
| """Check a tensor data type.""" | """Check a tensor data type.""" | ||||
| if dtype == mstype.int32: | if dtype == mstype.int32: | ||||
| return True | return True | ||||
| raise IndexError(f"For '{op_name}', the index tensor data type should be mstype.int32, but got {dtype}.") | |||||
| raise IndexError( | |||||
| f"For '{op_name}', the index tensor data type should be mstype.int32, but got {dtype}.") | |||||
| @constexpr | @constexpr | ||||
| @@ -332,7 +343,8 @@ def generate_broadcast_shape(shapes, op_name): | |||||
| for i, shape in enumerate(shapes): | for i, shape in enumerate(shapes): | ||||
| logger.debug(f"Broadcasts the {i}th tensor, the shape is {shape}.") | logger.debug(f"Broadcasts the {i}th tensor, the shape is {shape}.") | ||||
| try: | try: | ||||
| broadcast_shape = op_utils.get_broadcast_shape(broadcast_shape, shape, op_name) | |||||
| broadcast_shape = op_utils.get_broadcast_shape( | |||||
| broadcast_shape, shape, op_name) | |||||
| except ValueError as ex: | except ValueError as ex: | ||||
| raise IndexError(ex) | raise IndexError(ex) | ||||
| return tuple(broadcast_shape) | return tuple(broadcast_shape) | ||||
| @@ -398,7 +410,8 @@ def convert_ellipsis_to_tensors(slice_number, | |||||
| if isinstance(ele, tuple): | if isinstance(ele, tuple): | ||||
| shape.extend([1] * len(ele)) | shape.extend([1] * len(ele)) | ||||
| if array is None: | if array is None: | ||||
| raise ValueError(f"For '{op_name}', generate tensors from ellipsis failed.") | |||||
| raise ValueError( | |||||
| f"For '{op_name}', generate tensors from ellipsis failed.") | |||||
| array = np.reshape(array, shape) | array = np.reshape(array, shape) | ||||
| reps = compute_multiples(shape, final_shape) | reps = compute_multiples(shape, final_shape) | ||||
| tensor = Tensor(np.tile(array, reps)) | tensor = Tensor(np.tile(array, reps)) | ||||
| @@ -428,7 +441,8 @@ def convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_n | |||||
| else: | else: | ||||
| shape.append(1) | shape.append(1) | ||||
| if array is None: | if array is None: | ||||
| raise ValueError(f"For '{op_name}', generate tensor from 'slice' failed.") | |||||
| raise ValueError( | |||||
| f"For '{op_name}', generate tensor from 'slice' failed.") | |||||
| array = np.reshape(array, shape) | array = np.reshape(array, shape) | ||||
| reps = compute_multiples(shape, final_shape) | reps = compute_multiples(shape, final_shape) | ||||
| tensor = Tensor(np.tile(array, reps)) | tensor = Tensor(np.tile(array, reps)) | ||||
| @@ -523,14 +537,15 @@ def generate_index_info_from_tuple_of_mixed_tensors(data_shape, | |||||
| tensor_count += 1 | tensor_count += 1 | ||||
| elif isinstance(ele_type, mstype.slice_type): | elif isinstance(ele_type, mstype.slice_type): | ||||
| slice_obj = slice(slice_indexes[slice_count].start, | slice_obj = slice(slice_indexes[slice_count].start, | ||||
| slice_indexes[slice_count].end, | |||||
| slice_indexes[slice_count].stop, | |||||
| slice_indexes[slice_count].step) | slice_indexes[slice_count].step) | ||||
| # Use list to represent slicing result. | # Use list to represent slicing result. | ||||
| indexes_info[pos] = list(range(data_shape[pos]))[slice_obj] | indexes_info[pos] = list(range(data_shape[pos]))[slice_obj] | ||||
| slice_count += 1 | slice_count += 1 | ||||
| elif isinstance(ele_type, mstype.ellipsis_type): | elif isinstance(ele_type, mstype.ellipsis_type): | ||||
| if ellipsis_num != 0: | if ellipsis_num != 0: | ||||
| raise IndexError(f"For '{op_name}', the index could only contain one ellipsis.") | |||||
| raise IndexError( | |||||
| f"For '{op_name}', the index could only contain one ellipsis.") | |||||
| ellipsis_occupied_dims = data_rank - indexes_size + 1 | ellipsis_occupied_dims = data_rank - indexes_size + 1 | ||||
| for j in range(pos, pos + ellipsis_occupied_dims): | for j in range(pos, pos + ellipsis_occupied_dims): | ||||
| # Use list to represent slicing result. | # Use list to represent slicing result. | ||||
| @@ -540,7 +555,8 @@ def generate_index_info_from_tuple_of_mixed_tensors(data_shape, | |||||
| raise IndexError(f"For '{op_name}', the index elements only support " | raise IndexError(f"For '{op_name}', the index elements only support " | ||||
| f"'Tensor', 'int', 'Slice', 'Ellipsis', but got {ele_type}.") | f"'Tensor', 'int', 'Slice', 'Ellipsis', but got {ele_type}.") | ||||
| broadcast_shape, final_shape, indexes_shapes_info = \ | broadcast_shape, final_shape, indexes_shapes_info = \ | ||||
| _derive_result_shape_info_from_tuple_of_mixed_tensors(indexes_info, index_tensors_info, op_name) | |||||
| _derive_result_shape_info_from_tuple_of_mixed_tensors( | |||||
| indexes_info, index_tensors_info, op_name) | |||||
| return broadcast_shape, final_shape, indexes_shapes_info, ellipsis_occupied_dims | return broadcast_shape, final_shape, indexes_shapes_info, ellipsis_occupied_dims | ||||
| @@ -556,10 +572,12 @@ def _derive_result_shape_info_from_tuple_of_mixed_tensors(indexes_info, index_te | |||||
| """Derive the resulting shape information from the a tuple index of mixed tensors.""" | """Derive the resulting shape information from the a tuple index of mixed tensors.""" | ||||
| index_tensor_info_key = list(index_tensors_info.keys()) | index_tensor_info_key = list(index_tensors_info.keys()) | ||||
| index_tensor_info_value = list(index_tensors_info.values()) | index_tensor_info_value = list(index_tensors_info.values()) | ||||
| broadcast_shape = generate_broadcast_shape(index_tensor_info_value, op_name) | |||||
| broadcast_shape = generate_broadcast_shape( | |||||
| index_tensor_info_value, op_name) | |||||
| final_shape = [] | final_shape = [] | ||||
| indexes_shapes_info = [] | indexes_shapes_info = [] | ||||
| mixed_tensors_continuous = _judge_tuple_of_mixed_tensors_continuous(index_tensor_info_key) | |||||
| mixed_tensors_continuous = _judge_tuple_of_mixed_tensors_continuous( | |||||
| index_tensor_info_key) | |||||
| if mixed_tensors_continuous: | if mixed_tensors_continuous: | ||||
| tensor_shape_dealt = False | tensor_shape_dealt = False | ||||
| for ele in indexes_info.values(): | for ele in indexes_info.values(): | ||||
| @@ -638,3 +656,98 @@ def get_np_eps(input_dtype): | |||||
| nptype = mstype.dtype_to_nptype(input_dtype) | nptype = mstype.dtype_to_nptype(input_dtype) | ||||
| eps = np.finfo(nptype).eps | eps = np.finfo(nptype).eps | ||||
| return float(eps) | return float(eps) | ||||
| @constexpr | |||||
| def check_number_index_type(number): | |||||
| """Check if it is int or bool number""" | |||||
| if isinstance(number, bool): | |||||
| return BOOL_ | |||||
| if isinstance(number, int): | |||||
| return INT_ | |||||
| raise IndexError("Only support integers, slices(`:`), ellipsis(`...`), None and bool, got {0} type is {1} " | |||||
| .format(number, type(number))) | |||||
| @constexpr | |||||
| def get_stride_info_from_slice(data_shape, slice_index): | |||||
| """Get stride info from a python slice""" | |||||
| begin, end, step = get_slice_stride(data_shape[0], slice_index) | |||||
| begin_strides = [begin] | |||||
| end_strides = [end] | |||||
| step_strides = [step] | |||||
| for end in data_shape[1:]: | |||||
| begin_strides.append(0) | |||||
| end_strides.append(end) | |||||
| step_strides.append(1) | |||||
| return tuple(begin_strides), tuple(end_strides), tuple(step_strides) | |||||
| @constexpr | |||||
| def get_stride_info_from_integer(data_shape, number): | |||||
| """Get stride info from a integer""" | |||||
| begin_strides = [number] | |||||
| end_strides = [number+1] | |||||
| step_strides = [1] | |||||
| for end in data_shape[1:]: | |||||
| begin_strides.append(0) | |||||
| end_strides.append(end) | |||||
| step_strides.append(1) | |||||
| return tuple(begin_strides), tuple(end_strides), tuple(step_strides) | |||||
| def get_slice_stride(dim_size, index_slice): | |||||
| """Get slice stride info""" | |||||
| step = 1 if index_slice.step is None else index_slice.step | |||||
| start_default = 0 | |||||
| stop_default = dim_size | |||||
| if step < 0: | |||||
| start_default = -1 | |||||
| stop_default = -(dim_size+1) | |||||
| start = start_default if index_slice.start is None else index_slice.start | |||||
| stop = stop_default if index_slice.stop is None else index_slice.stop | |||||
| return start, stop, step | |||||
| @constexpr | |||||
| def get_stride_info_from_tuple(data_shape, index_tuple): | |||||
| """Get stride info from a tuple""" | |||||
| begin_strides = [] | |||||
| end_strides = [] | |||||
| step_strides = [] | |||||
| index_size = len(index_tuple) | |||||
| data_shape_size = len(data_shape) | |||||
| shrink_axis = 0 | |||||
| index_count = 0 | |||||
| ellipsis_count = 0 | |||||
| for idx, item in enumerate(index_tuple): | |||||
| if isinstance(item, slice): | |||||
| start, stop, step = get_slice_stride(data_shape[idx], item) | |||||
| begin_strides.append(start) | |||||
| end_strides.append(stop) | |||||
| step_strides.append(step) | |||||
| index_count = index_count + 1 | |||||
| elif isinstance(item, int): | |||||
| begin_strides.append(item) | |||||
| end_strides.append(item + 1) | |||||
| step_strides.append(1) | |||||
| shrink_axis = shrink_axis + (1 << index_count) | |||||
| index_count = index_count + 1 | |||||
| elif item is ...: | |||||
| ellipsis_count = ellipsis_count + 1 | |||||
| if ellipsis_count > 1: | |||||
| raise IndexError("An index can have only one ellipsis (...)") | |||||
| ellipsis_range_size = data_shape_size - (index_size - 1) | |||||
| begin_strides.extend([0] * (ellipsis_range_size)) | |||||
| end_strides.extend( | |||||
| [i for i in data_shape[index_count: index_count + (ellipsis_range_size)]]) | |||||
| step_strides.extend([1] * (ellipsis_range_size)) | |||||
| index_count = index_count + ellipsis_range_size | |||||
| else: | |||||
| raise IndexError("Not supported index data type, got ", | |||||
| item, " type is ", type(item)) | |||||
| for item in range(index_count, data_shape_size): | |||||
| begin_strides.append(0) | |||||
| end_strides.append(data_shape[item]) | |||||
| step_strides.append(1) | |||||
| return tuple(begin_strides), tuple(end_strides), tuple(step_strides), shrink_axis | |||||
| @@ -15,7 +15,6 @@ | |||||
| """Implementation for getitem.""" | """Implementation for getitem.""" | ||||
| from . import _compile_utils as compile_utils | from . import _compile_utils as compile_utils | ||||
| from . import _constexpr_utils as const_utils | |||||
| from .. import base | from .. import base | ||||
| from ... import functional as F | from ... import functional as F | ||||
| @@ -50,29 +49,6 @@ _tuple_slice = _TupleSlice('tuple_slice') | |||||
| """_tuple_slice is an metafuncgraph object which will slice a tuple.""" | """_tuple_slice is an metafuncgraph object which will slice a tuple.""" | ||||
| class _TensorSlice(base.TensorSlice_): | |||||
| """ | |||||
| Slices a tensor. | |||||
| Inputs: | |||||
| data (Tensor): A tensor to be sliced. | |||||
| s (slice): The index to slice tuple data. | |||||
| Outputs: | |||||
| Tensor, consists of some elements of data. | |||||
| """ | |||||
| def __init__(self, name): | |||||
| base.TensorSlice_.__init__(self, name) | |||||
| def __call__(self, *args): | |||||
| pass | |||||
| _tensor_slice = _TensorSlice('tensor_slice') | |||||
| """_tensor_slice is an metafuncgraph object which will slice a tensor.""" | |||||
| class _TupleGetItemTensor(base.TupleGetItemTensor_): | class _TupleGetItemTensor(base.TupleGetItemTensor_): | ||||
| """ | """ | ||||
| Getting item of tuple by tensor index. | Getting item of tuple by tensor index. | ||||
| @@ -182,13 +158,13 @@ def _tensor_getitem_by_number(data, number_index): | |||||
| Outputs: | Outputs: | ||||
| Tensor, element type is as same as the element type of data. | Tensor, element type is as same as the element type of data. | ||||
| """ | """ | ||||
| return _tensor_slice(data, number_index) | |||||
| return compile_utils.tensor_index_by_number(data, number_index) | |||||
| @getitem.register("Tensor", "None") | @getitem.register("Tensor", "None") | ||||
| def _tensor_getitem_by_none(data, index): | def _tensor_getitem_by_none(data, index): | ||||
| """ | """ | ||||
| Getting item of tensor by None. | |||||
| For none indexing , expand data with one dim | |||||
| Inputs: | Inputs: | ||||
| data (Tensor): A tensor. | data (Tensor): A tensor. | ||||
| @@ -197,7 +173,7 @@ def _tensor_getitem_by_none(data, index): | |||||
| Outputs: | Outputs: | ||||
| Tensor, element type is as same as the element type of data. | Tensor, element type is as same as the element type of data. | ||||
| """ | """ | ||||
| return _tensor_slice(data, index) | |||||
| return F.expand_dims(data, 0) | |||||
| @getitem.register("Tensor", "Slice") | @getitem.register("Tensor", "Slice") | ||||
| @@ -212,13 +188,13 @@ def _tensor_getitem_by_slice(data, slice_index): | |||||
| Outputs: | Outputs: | ||||
| Tensor, element type is same as the element type of data. | Tensor, element type is same as the element type of data. | ||||
| """ | """ | ||||
| return _tensor_slice(data, slice_index) | |||||
| return compile_utils.tensor_index_by_slice(data, slice_index) | |||||
| @getitem.register("Tensor", "Tensor") | @getitem.register("Tensor", "Tensor") | ||||
| def _tensor_getitem_by_tensor(data, tensor_index): | def _tensor_getitem_by_tensor(data, tensor_index): | ||||
| """ | """ | ||||
| Getting item of tensor by slice. | |||||
| Getting item of tensor by tensor indice. | |||||
| Inputs: | Inputs: | ||||
| data (Tensor): A tensor. | data (Tensor): A tensor. | ||||
| @@ -227,18 +203,13 @@ def _tensor_getitem_by_tensor(data, tensor_index): | |||||
| Outputs: | Outputs: | ||||
| Tensor, element type is same as the element type of data. | Tensor, element type is same as the element type of data. | ||||
| """ | """ | ||||
| check_dtypes = const_utils.check_index_tensor_dtype(F.dtype(tensor_index), | |||||
| const_utils.TENSOR_GETITEM) | |||||
| result = None | |||||
| if check_dtypes: | |||||
| result = F.gather(data, tensor_index, 0) | |||||
| return result | |||||
| return compile_utils.tensor_index_by_tensor(data, tensor_index) | |||||
| @getitem.register("Tensor", "Tuple") | @getitem.register("Tensor", "Tuple") | ||||
| def _tensor_getitem_by_tuple(data, tuple_index): | def _tensor_getitem_by_tuple(data, tuple_index): | ||||
| """ | """ | ||||
| Getting item of tensor by slice tuple. | |||||
| Getting item of tensor by tuple. | |||||
| Inputs: | Inputs: | ||||
| data (Tensor): A tensor. | data (Tensor): A tensor. | ||||
| @@ -247,13 +218,7 @@ def _tensor_getitem_by_tuple(data, tuple_index): | |||||
| Outputs: | Outputs: | ||||
| Tensor, element type is same as the element type of data. | Tensor, element type is same as the element type of data. | ||||
| """ | """ | ||||
| indexes_types = compile_utils.hyper_map(F.typeof, tuple_index) | |||||
| index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_GETITEM) | |||||
| if index_elements_type == const_utils.NO_TENSOR: | |||||
| return _tensor_slice(data, tuple_index) | |||||
| if index_elements_type == const_utils.ALL_TENSOR: | |||||
| return _tensor_getitem_by_tuple_of_tensor(data, tuple_index) | |||||
| return _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index) | |||||
| return compile_utils.tensor_index_by_tuple(data, tuple_index) | |||||
| @getitem.register("Tensor", "Ellipsis") | @getitem.register("Tensor", "Ellipsis") | ||||
| @@ -268,22 +233,4 @@ def _tensor_getitem_by_ellipsis(data, ellipsis_index): | |||||
| Outputs: | Outputs: | ||||
| Tensor, same as data. | Tensor, same as data. | ||||
| """ | """ | ||||
| return _tensor_slice(data, ellipsis_index) | |||||
| def _tensor_getitem_by_tuple_of_tensor(data, tuple_index): | |||||
| """Tensor getitem by a tuple of tensor.""" | |||||
| indices = compile_utils.generate_indices_from_tuple_of_tensor(data, | |||||
| tuple_index, | |||||
| const_utils.TENSOR_GETITEM) | |||||
| result = F.gather_nd(data, indices) | |||||
| return result | |||||
| def _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index): | |||||
| """Tensor getitem by a tuple of mixed tensor.""" | |||||
| indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data, | |||||
| tuple_index, | |||||
| const_utils.TENSOR_GETITEM) | |||||
| result = F.gather_nd(data, indices) | |||||
| return result | |||||
| return data | |||||
| @@ -0,0 +1,741 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ test_tensor_slice """ | |||||
| import numpy as np | |||||
| import pytest | |||||
| from mindspore import Tensor, Parameter | |||||
| from mindspore import context | |||||
| from mindspore import dtype as mstype | |||||
| from mindspore.nn import Cell | |||||
| def setup_module(): | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||||
| class NetWorkSlicePositive(Cell): | |||||
| def __init__(self): | |||||
| super(NetWorkSlicePositive, self).__init__() | |||||
| self.tensor_ret0 = Tensor(np.ones([1, 2, 3], np.int32)) | |||||
| self.tensor_ret1 = Tensor(np.ones([4, 8, 10], np.int32)) | |||||
| self.tensor_ret2 = Tensor(np.ones([6, 8, 10], np.int32)) | |||||
| self.tensor_ret3 = Tensor(np.ones([3, 8, 10], np.int32)) | |||||
| def construct(self, tensor): | |||||
| ret0 = tensor[3:4:1, 1:5:2, 3:6:1] + self.tensor_ret0 | |||||
| ret1 = tensor[-6:4:1, 0:8:1, ::1] + self.tensor_ret1 | |||||
| ret2 = tensor[::, ::, ::] + self.tensor_ret2 | |||||
| ret3 = tensor[::2] + self.tensor_ret3 | |||||
| return ret0, ret1, ret2, ret3 | |||||
| def test_slice_positive(): | |||||
| net = NetWorkSlicePositive() | |||||
| input_np = np.arange(6*8*10).reshape(6, 8, 10).astype(np.int32) | |||||
| input_0 = Tensor(input_np) | |||||
| output0, output1, output2, output3 = net(input_0) | |||||
| assert np.all(output0.asnumpy() == input_np[3:4:1, 1:5:2, 3:6:1] + np.ones([1, 2, 3])) | |||||
| assert np.all(output1.asnumpy() == input_np[-6:4:1, 0:8:1, ::1] + np.ones([4, 8, 10])) | |||||
| assert np.all(output2.asnumpy() == input_np[::, ::, ::] + np.ones([6, 8, 10])) | |||||
| assert np.all(output3.asnumpy() == input_np[::2] + np.ones([3, 8, 10])) | |||||
| class NetWorkSliceEllipsis(Cell): | |||||
| def __init__(self): | |||||
| super(NetWorkSliceEllipsis, self).__init__() | |||||
| self.tensor_ret0 = Tensor(np.ones([2, 7, 8], np.int32)) | |||||
| self.tensor_ret1 = Tensor(np.ones([6, 7, 8, 9], np.int32)) | |||||
| self.tensor_ret2 = Tensor(np.ones([1, 6, 7, 8, 9], np.int32)) | |||||
| def construct(self, tensor): | |||||
| ret0 = tensor[0:4:2, ..., 1] + self.tensor_ret0 | |||||
| ret1 = tensor[...] + self.tensor_ret1 | |||||
| ret2 = tensor[None] + self.tensor_ret2 | |||||
| ret3 = tensor[True] + self.tensor_ret2 | |||||
| return ret0, ret1, ret2, ret3 | |||||
| def Xtest_slice_ellipsis(): | |||||
| net = NetWorkSliceEllipsis() | |||||
| input_np = np.arange(6*7*8*9).reshape(6, 7, 8, 9).astype(np.int32) | |||||
| input_0 = Tensor(input_np) | |||||
| output0, output1, output2, output3 = net(input_0) | |||||
| assert np.all(output0.asnumpy() == input_np[0:4:2, ..., 1] + np.ones([1, 2, 3])) | |||||
| assert np.all(output1.asnumpy() == input_np[...] + np.ones([6, 7, 8, 9])) | |||||
| assert np.all(output2.asnumpy() == input_np[None] + np.ones([6, 7, 8, 9])) | |||||
| assert np.all(output3.asnumpy() == input_np[True] + np.ones([1, 6, 7, 8, 9])) | |||||
| class NetWorkReduceDimension(Cell): | |||||
| def __init__(self): | |||||
| super(NetWorkReduceDimension, self).__init__() | |||||
| self.tensor_ret1 = Tensor(np.ones([3, 10], np.int32)) | |||||
| self.tensor_ret2 = Tensor(np.ones([6, 8], np.int32)) | |||||
| self.tensor_ret3 = Tensor(np.array(8, np.int32)) | |||||
| self.tensor_ret4 = Tensor(np.ones([8, 10], np.int32)) | |||||
| def construct(self, tensor): | |||||
| ret1 = tensor[::2, 1, ::1] + self.tensor_ret1 | |||||
| ret2 = tensor[::, ::, 0] + self.tensor_ret2 | |||||
| ret3 = tensor[3, 2, 5] + self.tensor_ret3 | |||||
| ret4 = tensor[1] + self.tensor_ret4 | |||||
| return ret1, ret2, ret3, ret4 | |||||
| def Xtest_reduce_dimension(): | |||||
| net = NetWorkReduceDimension() | |||||
| input_np = np.arange(6*8*10).reshape(6, 8, 10).astype(np.int32) | |||||
| input_0 = Tensor(input_np) | |||||
| output1, output2, output3, output4 = net(input_0) | |||||
| assert np.all(output1.asnumpy() == input_np[::2, 1, ::1] + np.ones([3, 10])) | |||||
| assert np.all(output2.asnumpy() == input_np[::, ::, 0] + np.ones([6, 8])) | |||||
| assert np.all(output3.asnumpy() == input_np[3, 2, 5] + np.array(8, np.int32)) | |||||
| assert np.all(output4.asnumpy() == input_np[1] + np.ones([8, 10])) | |||||
| class NetWorkSliceStep(Cell): | |||||
| def __init__(self): | |||||
| super(NetWorkSliceStep, self).__init__() | |||||
| self.tensor_ret1 = Tensor(np.ones([6, 5, 10], np.int32)) | |||||
| self.tensor_ret2 = Tensor(np.ones([3, 5, 5], np.int32)) | |||||
| def construct(self, tensor): | |||||
| ret1 = tensor[::1, -5::, ::-1] + self.tensor_ret1 | |||||
| ret2 = tensor[::2, -5::, ::2] + self.tensor_ret2 | |||||
| return ret1, ret2 | |||||
| def Xtest_step_negative(): | |||||
| net = NetWorkSliceEllipsis() | |||||
| input_np = np.arange(6*8*10).reshape(6, 8, 10).astype(np.int32) | |||||
| input_0 = Tensor(input_np) | |||||
| output1, output2 = net(input_0) | |||||
| assert np.all(output1.asnumpy() == input_np[::1, -5::, ::-1] + np.ones([6, 8, 10])) | |||||
| assert np.all(output2.asnumpy() == input_np[::2, -5::, ::2] + np.ones([3, 5, 5])) | |||||
| class TensorGetItemByThreeTensors(Cell): | |||||
| def __init__(self): | |||||
| super(TensorGetItemByThreeTensors, self).__init__() | |||||
| self.const0 = Tensor(np.ones((4, 5, 8, 10)), mstype.int32) | |||||
| self.const1 = Tensor(np.ones((3, 4, 5, 10)), mstype.int32) | |||||
| self.const2 = Tensor(np.ones((5, 3, 4, 5)), mstype.int32) | |||||
| def construct(self, x, index_0, index_1, index_2): | |||||
| ret0 = x[index_0] + self.const0 | |||||
| ret1 = x[index_0, index_1] + self.const1 | |||||
| ret2 = x[index_0, index_1, index_2] + self.const2 | |||||
| return ret0, ret1, ret2 | |||||
| def Xtest_getitem_by_tensors(): | |||||
| net = TensorGetItemByThreeTensors() | |||||
| input_x = np.arange(6*8*10).reshape(6, 8, 10).astype(np.int32) | |||||
| index_0 = np.random.randint(6, size=(3, 4, 5)).astype(np.int32) | |||||
| index_1 = np.random.randint(6, size=(4, 5)).astype(np.int32) | |||||
| index_2 = np.random.randint(6, size=(5, 3, 4, 5)).astype(np.int32) | |||||
| input_x_ms = Tensor(input_x) | |||||
| index_0_ms = Tensor(index_0) | |||||
| index_1_ms = Tensor(index_1) | |||||
| input_2_ms = Tensor(index_2) | |||||
| output0, output1, output2 = net(input_x_ms, index_0_ms, index_1_ms, input_2_ms) | |||||
| assert np.all(output0.asnumpy() == input_x[index_0] + np.ones([4, 5, 8, 10])) | |||||
| assert np.all(output1.asnumpy() == input_x[index_0, index_1] + np.ones([3, 4, 5, 10])) | |||||
| assert np.all(output2.asnumpy() == input_x[index_0, index_1, index_2] + np.ones([5, 3, 4, 5])) | |||||
| class TensorGetItemByMixedTensors_0(Cell): | |||||
| def __init__(self): | |||||
| super(TensorGetItemByMixedTensors_0, self).__init__() | |||||
| self.const = Tensor(np.ones((3, 4, 5, 3, 6, 5), np.float32)) | |||||
| def construct(self, tensor, index_0, index_1): | |||||
| ret = tensor[index_0, index_1, 0:3, ..., 0:5, 3] + self.const | |||||
| return ret | |||||
| class TensorGetItemByMixedTensors_1(Cell): | |||||
| def __init__(self): | |||||
| super(TensorGetItemByMixedTensors_1, self).__init__() | |||||
| self.const = Tensor(np.ones((3, 4, 5, 3, 5, 5), np.float32)) | |||||
| def construct(self, tensor, index_0, index_1): | |||||
| ret = tensor[0:3, index_0, ..., index_1, 3, 0:5] + self.const | |||||
| return ret | |||||
| class TensorGetItemByMixedTensors_2(Cell): | |||||
| def __init__(self): | |||||
| super(TensorGetItemByMixedTensors_2, self).__init__() | |||||
| self.const = Tensor(np.ones((3, 4, 5, 6, 7), np.float32)) | |||||
| def construct(self, tensor, index_0, index_1): | |||||
| ret = tensor[0, index_0, index_1, ..., 3] + self.const | |||||
| return ret | |||||
| class TensorGetItemByMixedTensors_3(Cell): | |||||
| def __init__(self): | |||||
| super(TensorGetItemByMixedTensors_3, self).__init__() | |||||
| self.const = Tensor(np.ones((3, 4, 5, 3, 4, 3, 5), np.float32)) | |||||
| def construct(self, tensor, index_0, index_1): | |||||
| ret = tensor[..., index_0, 0:3, index_1, 0:5] + self.const | |||||
| return ret | |||||
| class TensorGetItemByMixedTensors_4(Cell): | |||||
| def __init__(self): | |||||
| super(TensorGetItemByMixedTensors_4, self).__init__() | |||||
| self.const = Tensor(np.ones((2, 2, 3, 4, 5, 3, 9), np.float32)) | |||||
| def construct(self, tensor, index_0, index_1, index_2): | |||||
| ret = tensor[0:2, index_0, index_1, 2, index_2, 0:3, ...] + self.const | |||||
| return ret | |||||
| class TensorGetItemByMixedTensors_5(Cell): | |||||
| def __init__(self): | |||||
| super(TensorGetItemByMixedTensors_5, self).__init__() | |||||
| self.const = Tensor(np.ones((2, 3, 4, 5, 2, 6), np.float32)) | |||||
| def construct(self, tensor, index_0, index_1, index_2): | |||||
| ret = tensor[0:2, index_0, index_1, ..., index_2, 2] + self.const | |||||
| return ret | |||||
| class TensorGetItemByMixedTensors_6(Cell): | |||||
| def __init__(self): | |||||
| super(TensorGetItemByMixedTensors_6, self).__init__() | |||||
| self.const = Tensor(np.ones((3, 4, 2, 3, 4, 5), np.float32)) | |||||
| def construct(self, tensor, index_0, index_1, index_2): | |||||
| ret = tensor[..., index_0, index_1, index_2, 3] + self.const | |||||
| return ret | |||||
| class TensorSetItemByMixedTensors_0(Cell): | |||||
| def __init__(self, value): | |||||
| super(TensorSetItemByMixedTensors_0, self).__init__() | |||||
| self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8, 9), np.float32)) | |||||
| self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), | |||||
| mstype.float32), | |||||
| name="x") | |||||
| self.value = value | |||||
| def construct(self, index_0, index_1, index_2): | |||||
| self.param[0:2, index_0, index_1, 2, index_2, 0:3, ...] = self.value | |||||
| ret = self.param + self.const | |||||
| return ret | |||||
| class TensorSetItemByMixedTensors_1(Cell): | |||||
| def __init__(self, value): | |||||
| super(TensorSetItemByMixedTensors_1, self).__init__() | |||||
| self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8), np.float32)) | |||||
| self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32), | |||||
| name="x") | |||||
| self.value = value | |||||
| def construct(self, index_0, index_1, index_2): | |||||
| self.param[0:2, index_0, index_1, ..., index_2, 2] = self.value | |||||
| ret = self.param + self.const | |||||
| return ret | |||||
| class TensorSetItemByMixedTensors_2(Cell): | |||||
| def __init__(self, value): | |||||
| super(TensorSetItemByMixedTensors_2, self).__init__() | |||||
| self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8), np.float16)) | |||||
| self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float16), | |||||
| name="x") | |||||
| self.value = value | |||||
| def construct(self, index_0, index_1, index_2): | |||||
| self.param[..., index_0, index_1, index_2, 3] = self.value | |||||
| ret = self.param + self.const | |||||
| return ret | |||||
| class TensorGetItemByMixedTensorsTypeError(Cell): | |||||
| def __init__(self): | |||||
| super(TensorGetItemByMixedTensorsTypeError, self).__init__() | |||||
| def construct(self, x, index_0, index_1): | |||||
| ret = x[index_0, index_1, 0:3, ..., 0:5, [1, 2, 3, 4]] | |||||
| return ret | |||||
| class TensorGetItemByMixedTensorsNumberError(Cell): | |||||
| def __init__(self): | |||||
| super(TensorGetItemByMixedTensorsNumberError, self).__init__() | |||||
| def construct(self, x, index_0, index_1): | |||||
| ret = x[index_0, index_1, 0:3, ..., index_1, index_0] | |||||
| return ret | |||||
| class TensorSetItemByOneTensorWithNumber(Cell): | |||||
| def __init__(self, value): | |||||
| super(TensorSetItemByOneTensorWithNumber, self).__init__() | |||||
| self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | |||||
| self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") | |||||
| self.value = value | |||||
| def construct(self, index): | |||||
| self.param[index] = self.value | |||||
| ret = self.param + self.const | |||||
| return ret | |||||
| class TensorSetItemByOneTensorWithTensor(Cell): | |||||
| def __init__(self): | |||||
| super(TensorSetItemByOneTensorWithTensor, self).__init__() | |||||
| self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | |||||
| self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") | |||||
| def construct(self, index, value): | |||||
| self.param[index] = value | |||||
| ret = self.param + self.const | |||||
| return ret | |||||
| class TensorSetItemByOneTensorWithTupleOfNumber(Cell): | |||||
| def __init__(self, value): | |||||
| super(TensorSetItemByOneTensorWithTupleOfNumber, self).__init__() | |||||
| self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | |||||
| self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") | |||||
| self.value = value | |||||
| def construct(self, index): | |||||
| self.param[index] = self.value | |||||
| ret = self.param + self.const | |||||
| return ret | |||||
| class TensorSetItemByOneTensorWithTupleOfTensor(Cell): | |||||
| def __init__(self): | |||||
| super(TensorSetItemByOneTensorWithTupleOfTensor, self).__init__() | |||||
| self.const = Tensor(np.ones((6, 3, 8)), mstype.float32) | |||||
| self.param = Parameter(Tensor(np.arange(6 * 3 * 8).reshape((6, 3, 8)), mstype.float32), name="x") | |||||
| def construct(self, index, value_0, value_1, value_2): | |||||
| self.param[index] = (value_0, value_1, value_2) | |||||
| ret = self.param + self.const | |||||
| return ret | |||||
| class TensorSetItemByTensorsWithNumber(Cell): | |||||
| def __init__(self, value): | |||||
| super(TensorSetItemByTensorsWithNumber, self).__init__() | |||||
| self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | |||||
| self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") | |||||
| self.value = value | |||||
| def construct(self, index_0, index_1, index_2): | |||||
| self.param[index_0, index_1, index_2] = self.value | |||||
| ret = self.param + self.const | |||||
| return ret | |||||
| class TensorSetItemByTensorsWithTensor(Cell): | |||||
| def __init__(self): | |||||
| super(TensorSetItemByTensorsWithTensor, self).__init__() | |||||
| self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | |||||
| self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") | |||||
| def construct(self, index_0, index_1, index_2, value): | |||||
| self.param[index_0, index_1, index_2] = value | |||||
| ret = self.param + self.const | |||||
| return ret | |||||
| class TensorSetItemByTensorsWithTensorNumberError(Cell): | |||||
| def __init__(self): | |||||
| super(TensorSetItemByTensorsWithTensorNumberError, self).__init__() | |||||
| self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | |||||
| self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") | |||||
| def construct(self, index_0, index_1, index_2, index_3, value): | |||||
| self.param[index_0, index_1, index_2, index_3] = value | |||||
| ret = self.param + self.const | |||||
| return ret | |||||
| class TensorSetItemByTensorsWithTupleOfNumber(Cell): | |||||
| def __init__(self, value): | |||||
| super(TensorSetItemByTensorsWithTupleOfNumber, self).__init__() | |||||
| self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | |||||
| self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") | |||||
| self.value = value | |||||
| def construct(self, index_0, index_1, index_2): | |||||
| self.param[index_0, index_1, index_2] = self.value | |||||
| ret = self.param + self.const | |||||
| return ret | |||||
| class TensorSetItemByTensorsWithTupleOfTensor(Cell): | |||||
| def __init__(self): | |||||
| super(TensorSetItemByTensorsWithTupleOfTensor, self).__init__() | |||||
| self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | |||||
| self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") | |||||
| def construct(self, index_0, index_1, index_2, value_0, value_1, value_2): | |||||
| self.param[index_0, index_1, index_2] = (value_0, value_1, value_2) | |||||
| ret = self.param + self.const | |||||
| return ret | |||||
| class TensorSetItemByTensorsWithTupleOfTensorNumberError(Cell): | |||||
| def __init__(self): | |||||
| super(TensorSetItemByTensorsWithTupleOfTensorNumberError, self).__init__() | |||||
| self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | |||||
| self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") | |||||
| def construct(self, index_0, index_1, index_2, value_0, value_1): | |||||
| self.param[index_0, index_1, index_2] = (value_0, value_1) | |||||
| ret = self.param + self.const | |||||
| return ret | |||||
| class TensorSetItemByMixedTensors(Cell): | |||||
| def __init__(self): | |||||
| super(TensorSetItemByMixedTensors, self).__init__() | |||||
| self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) | |||||
| self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") | |||||
| self.value = 99.0 | |||||
| def construct(self, index_0, index_1): | |||||
| self.param[index_0, index_1, 0:6] = self.value | |||||
| ret = self.param + self.const | |||||
| return ret | |||||
| class TensorAssignWithSliceError1(Cell): | |||||
| def __init__(self): | |||||
| super(TensorAssignWithSliceError1, self).__init__() | |||||
| def construct(self, a, b): | |||||
| a[1:3:-1, ::] = b | |||||
| return a | |||||
| class TensorAssignWithSliceError2(Cell): | |||||
| def __init__(self): | |||||
| super(TensorAssignWithSliceError2, self).__init__() | |||||
| def construct(self, a, b): | |||||
| a[1:3:-1] = b | |||||
| return a | |||||
| class TensorAssignWithSlice2(Cell): | |||||
| def __init__(self): | |||||
| super(TensorAssignWithSlice2, self).__init__() | |||||
| def construct(self, a, b, ck): | |||||
| a[1:5] = b | |||||
| a[3:4] = 5 | |||||
| a[-1:1:-1] = b | |||||
| a[-1:3:-1] = 5 | |||||
| a[::] = b | |||||
| a[::] = 9 | |||||
| z = a + ck | |||||
| return z | |||||
| class TensorAssignWithSlice(Cell): | |||||
| def __init__(self): | |||||
| super(TensorAssignWithSlice, self).__init__() | |||||
| self.c = 2 | |||||
| def construct(self, a, b, ck): | |||||
| a[1:3, ::] = b | |||||
| a[2:3:, 3:] = b | |||||
| a[::] = b | |||||
| a[::] = self.c | |||||
| a[::, ::] = b | |||||
| a[::, ::] = self.c | |||||
| a[2:3:, 0:, 4:1:-1] = b | |||||
| a[2:3:, 0:, 4:1:-1] = self.c | |||||
| z = a + ck | |||||
| return z | |||||
| def test_tensor_assign(): | |||||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | |||||
| net = TensorAssignWithSlice() | |||||
| net2 = TensorAssignWithSlice2() | |||||
| net_e1 = TensorAssignWithSliceError1() | |||||
| net_e2 = TensorAssignWithSliceError2() | |||||
| a = np.arange(60).reshape(3, 4, 5) | |||||
| ck = np.arange(60).reshape(3, 4, 5) | |||||
| b = Tensor([1], dtype=mstype.float32) | |||||
| Ta = Tensor(a, dtype=mstype.float32) | |||||
| Tck = Tensor(ck, dtype=mstype.float32) | |||||
| Ta4d = Tensor(a.reshape(1, 3, 4, 5), dtype=mstype.float32) | |||||
| Ta4d_ck = Tensor(ck.reshape(1, 3, 4, 5), dtype=mstype.float32) | |||||
| Tb = Tensor([1, 3], dtype=mstype.float32) | |||||
| Tc = Tensor([], dtype=mstype.float32) | |||||
| t = Tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=mstype.float32) | |||||
| tck = Tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=mstype.float32) | |||||
| net(Ta, b, Tck) | |||||
| net2(t, b, tck) | |||||
| # Error for A[Slice] = Number | |||||
| # 1. A[Slice] = Number, Slice error | |||||
| with pytest.raises(IndexError): | |||||
| net_e2(t, 2) | |||||
| # Error for A[Slice] = U, U is a Tensor | |||||
| # 1. A[Slice] = U, u.size is error | |||||
| with pytest.raises(ValueError): | |||||
| net2(t, Tb, tck) | |||||
| # 2. A[Slice] = U, U is empty | |||||
| with pytest.raises(ValueError): | |||||
| net2(t, Tc, tck) | |||||
| # 3. A[Slice] = U, U.size error | |||||
| with pytest.raises(ValueError): | |||||
| net2(t, Tb, tck) | |||||
| # Error for A[Tuple(Slice...)] = Tensor | |||||
| # 1. A[Tuple(Slice...)] = U, U is empty | |||||
| with pytest.raises(ValueError): | |||||
| net(Ta, Tc, Tck) | |||||
| # 2. A[Tuple(Slice...)] = U, U.size error | |||||
| with pytest.raises(ValueError): | |||||
| net(Ta, Tb, Tck) | |||||
| # 3. A[Tuple(Slice...)] = U, Slice error | |||||
| with pytest.raises(IndexError): | |||||
| net_e1(Ta, b) | |||||
| # Error for A[Tuple(Slice...)] = Number | |||||
| # 1. A[Tuple(Slice...)] = Number, Slice error | |||||
| with pytest.raises(IndexError): | |||||
| net_e1(Ta, 2) | |||||
| net = TensorAssignWithInteger() | |||||
| # Error for A[Number] = scalar/Tensor | |||||
| # 1. A[Number] = U, U is a Tensor, u.size not match | |||||
| with pytest.raises(ValueError): | |||||
| net(Ta, Tb, Tck) | |||||
| with pytest.raises(ValueError): | |||||
| net(Ta, Tc, Tck) | |||||
| # 2. A[Number] = U, the number index error | |||||
| with pytest.raises(IndexError): | |||||
| net(Ta4d, b, Ta4d_ck) | |||||
| # Error for A[(n,m)] = scalar/Tensor | |||||
| # 1. A[(n,m)] = U, U is a tensor. u.size not match | |||||
| net = TensorAssignWithTupleInteger() | |||||
| with pytest.raises(ValueError): | |||||
| net(Ta, Tc, Tck) | |||||
| with pytest.raises(ValueError): | |||||
| net(Ta, Tb, Tck) | |||||
| # 2. A[(n,m)] = U, the number index error | |||||
| with pytest.raises(IndexError): | |||||
| net(Ta4d, b, Ta4d_ck) | |||||
| # Error for A[...] = U or A[1:, ...] = u | |||||
| # 1. A[...] = scalar/tensor | |||||
| net = TensorAssignWithEllipsis() | |||||
| net(Ta, Ta4d) | |||||
| with pytest.raises(ValueError): | |||||
| net(Ta, Tc) | |||||
| with pytest.raises(ValueError): | |||||
| net(Ta, Tb) | |||||
| # 2. A[::, 1:, ...] = scalar/tensor | |||||
| net = TensorAssignWithTupleEllipsis() | |||||
| net(Ta, b) | |||||
| Tc = Tensor(1, mstype.float32) | |||||
| with pytest.raises(ValueError): | |||||
| net(Ta, Tc) | |||||
| with pytest.raises(ValueError): | |||||
| net(Ta, Tb) | |||||
| class TensorAssignWithTupleEllipsis2(Cell): | |||||
| def __init__(self): | |||||
| super(TensorAssignWithTupleEllipsis2, self).__init__() | |||||
| def construct(self, a, b): | |||||
| a[1:, ..., ::] = b | |||||
| return a | |||||
| class TensorAssignWithTupleEllipsis(Cell): | |||||
| def __init__(self): | |||||
| super(TensorAssignWithTupleEllipsis, self).__init__() | |||||
| def construct(self, a, b): | |||||
| a[:2, ...] = 1 | |||||
| a[1:, ...] = b | |||||
| return a | |||||
| class TensorAssignWithEllipsis(Cell): | |||||
| def __init__(self): | |||||
| super(TensorAssignWithEllipsis, self).__init__() | |||||
| def construct(self, a, b): | |||||
| a[...] = 1 | |||||
| a[...] = b | |||||
| return a | |||||
| class TensorAssignWithInteger(Cell): | |||||
| def __init__(self): | |||||
| super(TensorAssignWithInteger, self).__init__() | |||||
| def construct(self, a, b, ck): | |||||
| a[1] = 1 | |||||
| a[0] = b | |||||
| z = a + ck | |||||
| return z | |||||
| class TensorAssignWithTupleInteger(Cell): | |||||
| def __init__(self): | |||||
| super(TensorAssignWithTupleInteger, self).__init__() | |||||
| def construct(self, a, b, ck): | |||||
| a[(1)] = 1 | |||||
| a[(1)] = b | |||||
| a[(1, 1)] = b | |||||
| a[(1, 1)] = 1 | |||||
| z = a + ck | |||||
| return z | |||||
| class TensorAssignWithBoolTensorIndex(Cell): | |||||
| def __init__(self): | |||||
| super(TensorAssignWithBoolTensorIndex, self).__init__() | |||||
| self.t = Tensor(np.arange(60).reshape([3, 4, 5]), dtype=mstype.float32) | |||||
| self.u_scalar = 5 | |||||
| def construct(self, a, b, c, u_tensor): | |||||
| a[c] = self.u_scalar | |||||
| a[b] = u_tensor | |||||
| z = a + self.t | |||||
| return z | |||||
| class TensorAssignWithBoolTensorIndexError(Cell): | |||||
| def __init__(self): | |||||
| super(TensorAssignWithBoolTensorIndexError, self).__init__() | |||||
| def construct(self, a, b, c, u_tensor): | |||||
| a[b][c] = u_tensor | |||||
| return a | |||||
| class TensorAssignWithBoolTensorIndex2(Cell): | |||||
| def __init__(self): | |||||
| super(TensorAssignWithBoolTensorIndex2, self).__init__() | |||||
| self.t = Tensor(np.arange(6).reshape([2, 3]), dtype=mstype.float32) | |||||
| self.t = Tensor(np.arange(60).reshape([3, 4, 5]), dtype=mstype.float32) | |||||
| self.u_scalar = 5 | |||||
| def construct(self, a, u_tensor): | |||||
| a[a > 8] = u_tensor | |||||
| a[a >= 6] = self.u_scalar | |||||
| a[a < 3] = self.u_scalar | |||||
| a[a <= 5] = u_tensor | |||||
| a[a == 5] = self.u_scalar | |||||
| z = a + self.t | |||||
| return z | |||||
| class TensorAssignWithBoolTensorIndex2Error(Cell): | |||||
| def __init__(self): | |||||
| super(TensorAssignWithBoolTensorIndex2Error, self).__init__() | |||||
| def construct(self, a, u_tensor): | |||||
| a[a > 8][a > 5] = u_tensor | |||||
| return a | |||||
| def Xtest_tensor_assign_bool_index(): | |||||
| a = np.arange(60).reshape(3, 4, 5) | |||||
| b = a > 5 | |||||
| c = a < 3 | |||||
| Ta = Tensor(a, dtype=mstype.float32) | |||||
| Tb = Tensor(b) | |||||
| Tc = Tensor(c) | |||||
| Td = Tensor([True, True]) | |||||
| u_tensor = Tensor([1], dtype=mstype.float32) | |||||
| u_tensor_error = Tensor([1, 2], dtype=mstype.float32) | |||||
| u_scalar = 5 | |||||
| net1 = TensorAssignWithBoolTensorIndex() | |||||
| net2 = TensorAssignWithBoolTensorIndex2() | |||||
| net1(Ta, Tb, Tc, u_tensor) | |||||
| net1(Ta, Tb, Tc, u_tensor) | |||||
| with pytest.raises(ValueError): | |||||
| net1(Ta, Td, Tc, u_tensor) | |||||
| with pytest.raises(IndexError): | |||||
| net1(Ta, u_tensor, Tc, u_tensor) | |||||
| with pytest.raises(ValueError): | |||||
| net1(Ta, Tb, Td, u_tensor) | |||||
| with pytest.raises(IndexError): | |||||
| net1(Ta, Tb, Ta, u_tensor) | |||||
| with pytest.raises(ValueError): | |||||
| net1(Ta, Tb, Tc, u_tensor_error) | |||||
| # net1(Ta, u_tensor, Tc, u_tensor_error, u_scalar) | |||||
| with pytest.raises(ValueError): | |||||
| net2(Ta, u_tensor_error) | |||||
| net3 = TensorAssignWithBoolTensorIndexError() | |||||
| with pytest.raises(AttributeError): | |||||
| net3(Ta, Tb, Tc, u_tensor) | |||||
| with pytest.raises(AttributeError): | |||||
| net3(Ta, Tb, Tc, u_scalar) | |||||
| net4 = TensorAssignWithBoolTensorIndex2Error() | |||||
| with pytest.raises(AttributeError): | |||||
| net4(Ta, u_tensor) | |||||
| with pytest.raises(AttributeError): | |||||
| net4(Ta, u_scalar) | |||||
| def Xtest_tensor_slice_reduce_out_of_bounds_neg(): | |||||
| class NetWork(Cell): | |||||
| def __init__(self): | |||||
| super(NetWork, self).__init__() | |||||
| self.tensor_ret = Tensor(np.array(9, np.int32)) | |||||
| def construct(self, tensor): | |||||
| ret = tensor[-7, 3, 4] | |||||
| return ret | |||||
| input_tensor = Tensor(np.ones([6, 8, 10], np.int32)) | |||||
| net = NetWork() | |||||
| with pytest.raises(ValueError) as ex: | |||||
| net(input_tensor) | |||||
| assert "For 'StridedSlice' the `begin[0]` should be an int and must greater or equal to -6, but got `-7`" in str( | |||||
| ex.value) | |||||
| def Xtest_tensor_slice_reduce_out_of_bounds_positive(): | |||||
| class NetWork(Cell): | |||||
| def __init__(self): | |||||
| super(NetWork, self).__init__() | |||||
| self.tensor_ret = Tensor(np.array(9, np.int32)) | |||||
| def construct(self, tensor): | |||||
| ret = tensor[6, 3, 4] | |||||
| return ret | |||||
| input_tensor = Tensor(np.ones([6, 8, 10], np.int32)) | |||||
| net = NetWork() | |||||
| with pytest.raises(ValueError) as ex: | |||||
| net(input_tensor) | |||||
| assert "For 'StridedSlice' the `begin[0]` should be an int and must less than 6, but got `6`" in str(ex.value) | |||||
| @@ -240,156 +240,6 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice_step_positive) { | |||||
| ASSERT_EQ(real, expect); | ASSERT_EQ(real, expect); | ||||
| } | } | ||||
| TEST_F(TestComposite, test_TensorSliceBySlice) { | |||||
| MetaFuncGraphPtr tensorSlicePtr = std::make_shared<prim::TensorSlice>("tensor_slice"); | |||||
| FuncGraphPtr tensorSlicePtrGraphPtr = UTCompositeUtils::MakeFuncGraph(tensorSlicePtr, 2); | |||||
| AbstractBasePtrList eles; | |||||
| AbstractScalarPtr start_index = std::make_shared<AbstractScalar>(1); | |||||
| AbstractScalarPtr stop_index = std::make_shared<AbstractScalar>(6); | |||||
| AbstractScalarPtr step = std::make_shared<AbstractScalar>(2); | |||||
| AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({6, 7, 8}); | |||||
| AbstractSlicePtr slice = std::make_shared<AbstractSlice>(start_index, stop_index, step); | |||||
| AbstractBasePtrList args_spec_list = {tensor, slice}; | |||||
| AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSlicePtrGraphPtr, args_spec_list).inferred->abstract()); | |||||
| if (ret == nullptr) { | |||||
| FAIL() << "Cast ret to abstract array failed."; | |||||
| } | |||||
| AbstractTensorPtr expect = UTCompositeUtils::ArrayInt32Of({3, 7, 8}); | |||||
| ASSERT_EQ(*ret, *expect); | |||||
| } | |||||
| TEST_F(TestComposite, test_TensorSliceBySliceTuple) { | |||||
| MetaFuncGraphPtr tensorSlicePtr = std::make_shared<prim::TensorSlice>("tensor_slice"); | |||||
| FuncGraphPtr tensorSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tensorSlicePtr, 2); | |||||
| AbstractBasePtrList eles; | |||||
| AbstractScalarPtr start_index = std::make_shared<AbstractScalar>(0); | |||||
| AbstractScalarPtr stop_index = std::make_shared<AbstractScalar>(6); | |||||
| AbstractScalarPtr step = std::make_shared<AbstractScalar>(2); | |||||
| AbstractSlicePtr slice = std::make_shared<AbstractSlice>(start_index, stop_index, step); | |||||
| eles.push_back(slice); | |||||
| start_index = std::make_shared<AbstractScalar>(1); | |||||
| stop_index = std::make_shared<AbstractScalar>(5); | |||||
| step = std::make_shared<AbstractScalar>(1); | |||||
| slice = std::make_shared<AbstractSlice>(start_index, stop_index, step); | |||||
| eles.push_back(slice); | |||||
| start_index = std::make_shared<AbstractScalar>(2); | |||||
| stop_index = std::make_shared<AbstractScalar>(8); | |||||
| step = std::make_shared<AbstractScalar>(3); | |||||
| slice = std::make_shared<AbstractSlice>(start_index, stop_index, step); | |||||
| eles.push_back(slice); | |||||
| AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({6, 7, 8}); | |||||
| AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles); | |||||
| AbstractBasePtrList args_spec_list = {tensor, slice_tuple}; | |||||
| AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract()); | |||||
| if (ret == nullptr) { | |||||
| FAIL() << "Cast ret to abstract array failed."; | |||||
| } | |||||
| AbstractTensorPtr expect = UTCompositeUtils::ArrayInt32Of({3, 4, 2}); | |||||
| ASSERT_EQ(*ret, *expect); | |||||
| } | |||||
| TEST_F(TestComposite, test_TensorSliceBySliceTupleToReduceDimension) { | |||||
| MetaFuncGraphPtr tensorSlicePtr = std::make_shared<prim::TensorSlice>("tensor_slice"); | |||||
| FuncGraphPtr tensorSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tensorSlicePtr, 2); | |||||
| AbstractBasePtrList eles; | |||||
| AbstractScalarPtr start_index = std::make_shared<AbstractScalar>(1); | |||||
| AbstractScalarPtr stop_index = std::make_shared<AbstractScalar>(5); | |||||
| AbstractScalarPtr step = std::make_shared<AbstractScalar>(2); | |||||
| AbstractSlicePtr slice = std::make_shared<AbstractSlice>(start_index, stop_index, step); | |||||
| eles.push_back(slice); | |||||
| AbstractScalarPtr elem_index = std::make_shared<AbstractScalar>(1); | |||||
| eles.push_back(elem_index); | |||||
| start_index = std::make_shared<AbstractScalar>(2); | |||||
| stop_index = std::make_shared<AbstractScalar>(6); | |||||
| step = std::make_shared<AbstractScalar>(1); | |||||
| slice = std::make_shared<AbstractSlice>(start_index, stop_index, step); | |||||
| eles.push_back(slice); | |||||
| AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({6, 7, 8}); | |||||
| AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles); | |||||
| AbstractBasePtrList args_spec_list = {tensor, slice_tuple}; | |||||
| AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract()); | |||||
| if (ret == nullptr) { | |||||
| FAIL() << "Cast ret to abstract array failed."; | |||||
| } | |||||
| AbstractTensorPtr expect = UTCompositeUtils::ArrayInt32Of({2, 4}); | |||||
| ASSERT_EQ(*ret, *expect); | |||||
| } | |||||
| TEST_F(TestComposite, test_TensorSliceByScalar) { | |||||
| MetaFuncGraphPtr tensorSlicePtr = std::make_shared<prim::TensorSlice>("tensor_slice"); | |||||
| FuncGraphPtr tensorSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tensorSlicePtr, 2); | |||||
| AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({6, 7, 8}); | |||||
| AbstractScalarPtr start_index = std::make_shared<AbstractScalar>(2); | |||||
| AbstractBasePtrList args_spec_list = {tensor, start_index}; | |||||
| AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract()); | |||||
| if (ret == nullptr) { | |||||
| FAIL() << "Cast ret to abstract array failed."; | |||||
| } | |||||
| AbstractTensorPtr expect = UTCompositeUtils::ArrayInt32Of({7, 8}); | |||||
| ASSERT_EQ(*ret, *expect); | |||||
| } | |||||
| TEST_F(TestComposite, test_TensorSliceByScalarTuple) { | |||||
| MetaFuncGraphPtr tensorSlicePtr = std::make_shared<prim::TensorSlice>("tensor_slice"); | |||||
| FuncGraphPtr tensorSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tensorSlicePtr, 2); | |||||
| AbstractBasePtrList eles; | |||||
| AbstractScalarPtr elem_index = std::make_shared<AbstractScalar>(1); | |||||
| eles.push_back(elem_index); | |||||
| elem_index = std::make_shared<AbstractScalar>(3); | |||||
| eles.push_back(elem_index); | |||||
| AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({6, 7, 8}); | |||||
| AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles); | |||||
| AbstractBasePtrList args_spec_list = {tensor, slice_tuple}; | |||||
| AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract()); | |||||
| if (ret == nullptr) { | |||||
| FAIL() << "Cast ret to abstract array failed."; | |||||
| } | |||||
| AbstractTensorPtr expect = UTCompositeUtils::ArrayInt32Of({8}); | |||||
| ASSERT_EQ(*ret, *expect); | |||||
| } | |||||
| TEST_F(TestComposite, test_TensorSliceByScalarTupleToScalar) { | |||||
| MetaFuncGraphPtr tensorSlicePtr = std::make_shared<prim::TensorSlice>("tensor_slice"); | |||||
| FuncGraphPtr tensorSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tensorSlicePtr, 2); | |||||
| AbstractBasePtrList eles; | |||||
| AbstractScalarPtr elem_index = std::make_shared<AbstractScalar>(3); | |||||
| eles.push_back(elem_index); | |||||
| elem_index = std::make_shared<AbstractScalar>(0); | |||||
| eles.push_back(elem_index); | |||||
| elem_index = std::make_shared<AbstractScalar>(6); | |||||
| eles.push_back(elem_index); | |||||
| AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({6, 7, 8}); | |||||
| AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles); | |||||
| AbstractBasePtrList args_spec_list = {tensor, slice_tuple}; | |||||
| AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract()); | |||||
| if (ret == nullptr) { | |||||
| FAIL() << "Cast ret to abstract array failed."; | |||||
| } | |||||
| AbstractTensorPtr expect = UTCompositeUtils::ArrayInt32Of({}); | |||||
| ASSERT_EQ(*ret, *expect); | |||||
| } | |||||
| TEST_F(TestComposite, test_UnpackCall_3args) { | TEST_F(TestComposite, test_UnpackCall_3args) { | ||||
| MetaFuncGraphPtr unPackCallPtr = std::make_shared<prim::UnpackCall>("UnPackCall"); | MetaFuncGraphPtr unPackCallPtr = std::make_shared<prim::UnpackCall>("UnPackCall"); | ||||
| FuncGraphPtr unPackCallGraphPtr = UTCompositeUtils::MakeFuncGraph(unPackCallPtr, 3); | FuncGraphPtr unPackCallGraphPtr = UTCompositeUtils::MakeFuncGraph(unPackCallPtr, 3); | ||||
| @@ -107,5 +107,5 @@ class TestUnsupportParam(): | |||||
| def test_Sgd_init(self): | def test_Sgd_init(self): | ||||
| with pytest.raises(TypeError): | with pytest.raises(TypeError): | ||||
| paramsTensor = Tensor(np.zeros([1, 2, 3])) | |||||
| paramsTensor = Parameter(Tensor(np.zeros([1, 2, 3])), "x") | |||||
| SGD(paramsTensor) | SGD(paramsTensor) | ||||
| @@ -25,7 +25,6 @@ from ....mindspore_test_framework.pipeline.forward.compile_forward \ | |||||
| import pipeline_for_compile_forward_ge_graph_for_case_by_case_config, \ | import pipeline_for_compile_forward_ge_graph_for_case_by_case_config, \ | ||||
| pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception | pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception | ||||
| class NetWorkSlicePositive(Cell): | class NetWorkSlicePositive(Cell): | ||||
| def __init__(self): | def __init__(self): | ||||
| super(NetWorkSlicePositive, self).__init__() | super(NetWorkSlicePositive, self).__init__() | ||||
| @@ -528,6 +527,7 @@ def test_tensor_assign(): | |||||
| # 2. A[::, 1:, ...] = scalar/tensor | # 2. A[::, 1:, ...] = scalar/tensor | ||||
| net = TensorAssignWithTupleEllipsis() | net = TensorAssignWithTupleEllipsis() | ||||
| net(Ta, b) | net(Ta, b) | ||||
| Tc = Tensor(1, mstype.float32) | |||||
| with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
| net(Ta, Tc) | net(Ta, Tc) | ||||
| with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
| @@ -168,7 +168,7 @@ def test_select_grad(): | |||||
| sens = Tensor(np.ones_like(out.asnumpy()).astype(np.float32)) | sens = Tensor(np.ones_like(out.asnumpy()).astype(np.float32)) | ||||
| args = [cond, x, y, sens] | args = [cond, x, y, sens] | ||||
| gout = gfn(*args) | gout = gfn(*args) | ||||
| expect_cond = np.zeros_like(cond) | |||||
| expect_cond = np.zeros_like(cond.asnumpy()) | |||||
| expect_x = np.array([[1, 0, 0], [0, 1, 1]]) | expect_x = np.array([[1, 0, 0], [0, 1, 1]]) | ||||
| expect_y = np.array([[0, 1, 1], [1, 0, 0]]) | expect_y = np.array([[0, 1, 1], [1, 0, 0]]) | ||||
| assert np.all(gout[0].asnumpy() == expect_cond) | assert np.all(gout[0].asnumpy() == expect_cond) | ||||