| @@ -125,7 +125,7 @@ def list_len(x): | |||||
| return len(x) | return len(x) | ||||
| # only used in PyNative modes | |||||
| # only used in PyNative mode | |||||
| def partial(*args): | def partial(*args): | ||||
| """Implement `partial`.""" | """Implement `partial`.""" | ||||
| func = args[0].__call__ | func = args[0].__call__ | ||||
| @@ -133,10 +133,14 @@ def partial(*args): | |||||
| return partial_func | return partial_func | ||||
| # only used in PyNative modes | |||||
| # only used in PyNative mode | |||||
| def depend(value, expr): | def depend(value, expr): | ||||
| return value | return value | ||||
| # only used in PyNative mode | |||||
| def make_ref(key, value, ref): | |||||
| return value | |||||
| def scalar_cast(x, t): | def scalar_cast(x, t): | ||||
| """Implement scalar_cast.""" | """Implement scalar_cast.""" | ||||
| @@ -616,17 +616,19 @@ py::object ExecutorPy::Run(const py::tuple& args, const py::object& phase) { | |||||
| return ExecDFGraph(info_, args, phase_s); | return ExecDFGraph(info_, args, phase_s); | ||||
| } | } | ||||
| #else | #else | ||||
| if (backend == "ge") { | |||||
| std::shared_ptr<py::object> ret_val = std::make_shared<py::object>(); | |||||
| if (backend == "ms" || backend == "ge") { | |||||
| auto ret_val = std::make_shared<py::object>(); | |||||
| if (info_.count(phase_s) != 0 && info_[phase_s]->func_graph != nullptr) { | if (info_.count(phase_s) != 0 && info_[phase_s]->func_graph != nullptr) { | ||||
| if (IsGraphOutputValueNodeOrParameter(info_[phase_s]->func_graph->output(), args, ret_val)) { | if (IsGraphOutputValueNodeOrParameter(info_[phase_s]->func_graph->output(), args, ret_val)) { | ||||
| return *ret_val; | return *ret_val; | ||||
| } | } | ||||
| } | } | ||||
| if (args.size() > 0) { | |||||
| return args[0]; | |||||
| if (backend == "ge") { | |||||
| if (args.size() > 0) { | |||||
| return args[0]; | |||||
| } | |||||
| return args; | |||||
| } | } | ||||
| return args; | |||||
| } | } | ||||
| #endif | #endif | ||||
| std::size_t full_arg_size = ArgListSize(phase_s); | std::size_t full_arg_size = ArgListSize(phase_s); | ||||
| @@ -20,11 +20,13 @@ | |||||
| #include <map> | #include <map> | ||||
| #include <set> | #include <set> | ||||
| #include <unordered_set> | #include <unordered_set> | ||||
| #include <algorithm> | |||||
| #include "utils/any.h" | #include "utils/any.h" | ||||
| #include "utils/utils.h" | #include "utils/utils.h" | ||||
| #include "utils/context/ms_context.h" | #include "utils/context/ms_context.h" | ||||
| #include "operator/ops.h" | #include "operator/ops.h" | ||||
| #include "operator/composite/do_signature.h" | |||||
| #include "pipeline/parse/data_converter.h" | #include "pipeline/parse/data_converter.h" | ||||
| #include "pipeline/static_analysis/prim.h" | #include "pipeline/static_analysis/prim.h" | ||||
| #include "session/session_factory.h" | #include "session/session_factory.h" | ||||
| @@ -50,6 +52,57 @@ inline ValuePtr PyAttrValue(const py::object& obj) { | |||||
| return converted_ret; | return converted_ret; | ||||
| } | } | ||||
| py::tuple ConvertInputs(const PrimitivePyPtr& prim, const py::tuple& py_args) { | |||||
| auto signature = prim->signatures(); | |||||
| std::vector<SignatureEnumDType> dtypes; | |||||
| (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes), | |||||
| [](const Signature& sig) { return sig.dtype; }); | |||||
| int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue); | |||||
| if (dtypes.size() == 0 || static_cast<int>(dtypes.size()) == empty_dtype_count) { | |||||
| return py_args; | |||||
| } | |||||
| std::map<SignatureEnumDType, std::vector<size_t>> type_indexs; | |||||
| for (size_t i = 0; i < dtypes.size(); ++i) { | |||||
| auto it = type_indexs.find(dtypes[i]); | |||||
| if (it == type_indexs.end()) { | |||||
| (void)type_indexs.insert(std::make_pair(dtypes[i], std::vector<size_t>{i})); | |||||
| } else { | |||||
| it->second.push_back(i); | |||||
| } | |||||
| } | |||||
| std::map<SignatureEnumDType, size_t> dst_type; | |||||
| for (auto it = type_indexs.begin(); it != type_indexs.end(); (void)++it) { | |||||
| auto type = it->first; | |||||
| auto indexs = it->second; | |||||
| if (indexs.size() < 2) { | |||||
| continue; | |||||
| } | |||||
| size_t m_index = indexs[0]; | |||||
| for (size_t i = 1; i < indexs.size(); ++i) { | |||||
| if (py::isinstance<tensor::Tensor>(py_args[indexs[i]])) { | |||||
| m_index = indexs[i]; | |||||
| } | |||||
| } | |||||
| (void)dst_type.insert(std::make_pair(type, m_index)); | |||||
| } | |||||
| py::tuple py_inputs(py_args.size()); | |||||
| for (size_t i = 0; i < py_args.size(); ++i) { | |||||
| auto it = dst_type.find(dtypes[i]); | |||||
| if (it != dst_type.end() && it->second != i && | |||||
| (py::isinstance<py::int_>(py_args[i]) || py::isinstance<py::float_>(py_args[i]))) { | |||||
| auto tensor_ptr = py::cast<tensor::TensorPtr>(py_args[it->second]); | |||||
| if (py::isinstance<py::int_>(py_args[i])) { | |||||
| py_inputs[i] = std::make_shared<tensor::Tensor>(py::cast<py::int_>(py_args[i]), tensor_ptr->Dtype()); | |||||
| } else { | |||||
| py_inputs[i] = std::make_shared<tensor::Tensor>(py::cast<py::float_>(py_args[i]), tensor_ptr->Dtype()); | |||||
| } | |||||
| continue; | |||||
| } | |||||
| py_inputs[i] = py_args[i]; | |||||
| } | |||||
| return py_inputs; | |||||
| } | |||||
| void PynativeInfer(const PrimitivePyPtr& prim, const py::tuple& py_args, OpExecInfo* const op_exec_info) { | void PynativeInfer(const PrimitivePyPtr& prim, const py::tuple& py_args, OpExecInfo* const op_exec_info) { | ||||
| size_t size = py_args.size(); | size_t size = py_args.size(); | ||||
| AbstractBasePtrList args_spec_list; | AbstractBasePtrList args_spec_list; | ||||
| @@ -73,30 +126,22 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args& args) { | |||||
| auto op_exec_info = std::make_shared<OpExecInfo>(); | auto op_exec_info = std::make_shared<OpExecInfo>(); | ||||
| MS_EXCEPTION_IF_NULL(op_exec_info); | MS_EXCEPTION_IF_NULL(op_exec_info); | ||||
| op_exec_info->op_name = py::cast<std::string>(args[PY_NAME]); | op_exec_info->op_name = py::cast<std::string>(args[PY_NAME]); | ||||
| if (py::isinstance<py::none>(args[PY_PRIM])) { | |||||
| py::module ops_mod = py::module::import("mindspore.ops.operations"); | |||||
| py::object py_primitive = ops_mod.attr(op_exec_info->op_name.c_str())(); | |||||
| op_exec_info->py_primitive = py::cast<PrimitivePyPtr>(py_primitive); | |||||
| py::dict none_attrs = py::dict(); | |||||
| op_exec_info->op_attrs = none_attrs; | |||||
| } else { | |||||
| PrimitivePyPtr prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]); | |||||
| auto pyobj = prim->GetPyObj(); | |||||
| if (pyobj == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "pyobj is empty"; | |||||
| } | |||||
| py::tuple py_args = args[PY_INPUTS]; | |||||
| // use python infer method | |||||
| if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) { | |||||
| PynativeInfer(prim, py_args, op_exec_info.get()); | |||||
| } | |||||
| op_exec_info->py_primitive = prim; | |||||
| op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs"); | |||||
| auto prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]); | |||||
| auto pyobj = prim->GetPyObj(); | |||||
| if (pyobj == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "pyobj is empty"; | |||||
| } | |||||
| py::tuple py_args = ConvertInputs(prim, args[PY_INPUTS]); | |||||
| // use python infer method | |||||
| if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) { | |||||
| PynativeInfer(prim, py_args, op_exec_info.get()); | |||||
| } | } | ||||
| op_exec_info->op_inputs = args[PY_INPUTS]; | |||||
| op_exec_info->py_primitive = prim; | |||||
| op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs"); | |||||
| op_exec_info->op_inputs = py_args; | |||||
| op_exec_info->inputs_mask = args[PY_INPUT_MASK]; | op_exec_info->inputs_mask = args[PY_INPUT_MASK]; | ||||
| if (op_exec_info->op_inputs.size() != op_exec_info->inputs_mask.size()) { | if (op_exec_info->op_inputs.size() != op_exec_info->inputs_mask.size()) { | ||||
| MS_LOG(ERROR) << "" << op_exec_info->op_name << " op_inputs size not equal op_mask"; | |||||
| MS_LOG(ERROR) << "op:" << op_exec_info->op_name << " inputs size not equal op_mask"; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| return op_exec_info; | return op_exec_info; | ||||
| @@ -14,7 +14,7 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| """Parameter for cell.""" | """Parameter for cell.""" | ||||
| from copy import copy | |||||
| from copy import copy, deepcopy | |||||
| import numpy as np | import numpy as np | ||||
| from .initializer import initializer | from .initializer import initializer | ||||
| from .tensor import Tensor | from .tensor import Tensor | ||||
| @@ -156,16 +156,24 @@ class Parameter: | |||||
| return self.default_input | return self.default_input | ||||
| def __add__(self, other): | def __add__(self, other): | ||||
| return self.default_input + other | |||||
| res = deepcopy(self) | |||||
| res.default_input = res.default_input + other | |||||
| return res | |||||
| def __sub__(self, other): | def __sub__(self, other): | ||||
| return self.default_input - other | |||||
| res = deepcopy(self) | |||||
| res.default_input = res.default_input - other | |||||
| return res | |||||
| def __mul__(self, other): | def __mul__(self, other): | ||||
| return self.default_input * other | |||||
| res = deepcopy(self) | |||||
| res.default_input = res.default_input * other | |||||
| return res | |||||
| def __truediv__(self, other): | def __truediv__(self, other): | ||||
| return self.default_input / other | |||||
| res = deepcopy(self) | |||||
| res.default_input = res.default_input / other | |||||
| return res | |||||
| def set_parameter_data(self, data): | def set_parameter_data(self, data): | ||||
| if isinstance(data, (Tensor, list, int, float, | if isinstance(data, (Tensor, list, int, float, | ||||
| @@ -70,45 +70,60 @@ class Tensor(Tensor_): | |||||
| return str(self.__str__()) | return str(self.__str__()) | ||||
| def __add__(self, other): | def __add__(self, other): | ||||
| if not isinstance(other, Tensor): | |||||
| raise TypeError("input_data must be a tensor") | |||||
| check_type('tensor input_data', other, (Tensor, float, int)) | |||||
| out = tensor_operator_registry.get('__add__')(self, other) | out = tensor_operator_registry.get('__add__')(self, other) | ||||
| return out | return out | ||||
| def __mul__(self, other): | def __mul__(self, other): | ||||
| if not isinstance(other, Tensor): | |||||
| raise TypeError("input_data must be a tensor") | |||||
| check_type('tensor input_data', other, (Tensor, float, int)) | |||||
| out = tensor_operator_registry.get('__mul__')(self, other) | out = tensor_operator_registry.get('__mul__')(self, other) | ||||
| return out | return out | ||||
| def __neg__(self): | |||||
| return Tensor(-self.asnumpy()) | |||||
| def __iadd__(self, other): | def __iadd__(self, other): | ||||
| out = self.__add__(other) | out = self.__add__(other) | ||||
| return out | return out | ||||
| def __radd__(self, other): | |||||
| check_type('tensor operation input', other, (Tensor, float, int)) | |||||
| out = tensor_operator_registry.get('__add__')(other, self) | |||||
| return out | |||||
| def __imul__(self, other): | def __imul__(self, other): | ||||
| out = self.__mul__(other) | out = self.__mul__(other) | ||||
| return out | return out | ||||
| def __rmul__(self, other): | |||||
| check_type('tensor operation input', other, (Tensor, float, int)) | |||||
| out = tensor_operator_registry.get('__mul__')(other, self) | |||||
| return out | |||||
| def __truediv__(self, other): | def __truediv__(self, other): | ||||
| if isinstance(other, (int, float)): | |||||
| other_tensor = Tensor(other, self.dtype()) | |||||
| elif isinstance(other, Tensor): | |||||
| other_tensor = other | |||||
| else: | |||||
| raise TypeError("unsupported type for div operation") | |||||
| out = tensor_operator_registry.get('__div__')(self, other_tensor) | |||||
| check_type('tensor operation input', other, (Tensor, float, int)) | |||||
| out = tensor_operator_registry.get('__div__')(self, other) | |||||
| return out | |||||
| def __rtruediv__(self, other): | |||||
| check_type('tensor operation input', other, (Tensor, float, int)) | |||||
| out = tensor_operator_registry.get('__div__')(other, self) | |||||
| return out | return out | ||||
| def __sub__(self, other): | def __sub__(self, other): | ||||
| if not isinstance(other, Tensor): | |||||
| raise TypeError("input_data must be a tensor") | |||||
| out = self.__add__(Tensor(-other.asnumpy())) | |||||
| check_type('tensor operation input', other, (Tensor, float, int)) | |||||
| out = self.__add__(-other) | |||||
| return out | return out | ||||
| def __isub__(self, other): | def __isub__(self, other): | ||||
| out = self.__sub__(other) | out = self.__sub__(other) | ||||
| return out | return out | ||||
| def __rsub__(self, other): | |||||
| check_type('tensor operation input', other, (Tensor, float, int)) | |||||
| out = tensor_operator_registry.get('__add__')(other, Tensor(-self.asnumpy())) | |||||
| return out | |||||
| def __str__(self): | def __str__(self): | ||||
| if self.dtype() == mstype.type_none: | if self.dtype() == mstype.type_none: | ||||
| return "Unknown Tensor type!" | return "Unknown Tensor type!" | ||||
| @@ -191,7 +191,7 @@ def get_bprop_concat(self): | |||||
| def bprop(x, out, dout): | def bprop(x, out, dout): | ||||
| dx = () | dx = () | ||||
| out_offset = P.ConcatOffset(F.tuple_len(x), axis)(x) | |||||
| out_offset = G.ConcatOffset(F.tuple_len(x), axis)(x) | |||||
| for i in range(F.tuple_len(x)): | for i in range(F.tuple_len(x)): | ||||
| slice_out = P.Slice()(dout, out_offset[i], shape_op(x[i])) | slice_out = P.Slice()(dout, out_offset[i], shape_op(x[i])) | ||||
| dx = dx + (slice_out,) | dx = dx + (slice_out,) | ||||
| @@ -14,6 +14,6 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| """ops utils.""" | """ops utils.""" | ||||
| from .broadcast import _get_broadcast_shape | |||||
| from .utils import _get_broadcast_shape, _get_concat_offset | |||||
| __all__ = ['_get_broadcast_shape'] | |||||
| __all__ = ['_get_broadcast_shape', '_get_concat_offset'] | |||||
| @@ -13,8 +13,11 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| """broadcast""" | |||||
| """utils for operator""" | |||||
| from ..._checkparam import ParamValidator as validator | |||||
| from ..._checkparam import Rel | |||||
| from ...common import dtype as mstype | |||||
| def _get_broadcast_shape(x_shape, y_shape, prim_name): | def _get_broadcast_shape(x_shape, y_shape, prim_name): | ||||
| """ | """ | ||||
| @@ -57,3 +60,27 @@ def _get_broadcast_shape(x_shape, y_shape, prim_name): | |||||
| broadcast_shape_front = y_shape[0: y_len - length] if length == x_len else x_shape[0: x_len - length] | broadcast_shape_front = y_shape[0: y_len - length] if length == x_len else x_shape[0: x_len - length] | ||||
| broadcast_shape = broadcast_shape_front + broadcast_shape_back | broadcast_shape = broadcast_shape_front + broadcast_shape_back | ||||
| return broadcast_shape | return broadcast_shape | ||||
| def _get_concat_offset(x_shp, x_type, axis): | |||||
| """for concat and concatoffset check args and compute offset""" | |||||
| validator.check_type("shape", x_shp, [tuple]) | |||||
| validator.check_integer("len of input_x shape", len(x_shp), 0, Rel.GT) | |||||
| validator.check_subclass("shape0", x_type[0], mstype.tensor) | |||||
| validator.check_integer("len of input_x0 shape", len(x_shp[0]), 0, Rel.GT) | |||||
| rank_base = len(x_shp[0]) | |||||
| validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH) | |||||
| if axis < 0: | |||||
| axis = axis + rank_base | |||||
| all_shp = x_shp[0][axis] | |||||
| offset = [0,] | |||||
| for i in range(1, len(x_shp)): | |||||
| v = x_shp[i] | |||||
| validator.check('len of x_shp[%d]' % i, len(v), 'len of base', len(x_shp[0])) | |||||
| validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0]) | |||||
| for j in range(rank_base): | |||||
| if j != axis and v[j] != x_shp[0][j]: | |||||
| raise ValueError("Concat evaluator element %d shape in input can not concat with first element" % i) | |||||
| offset.append(all_shp) | |||||
| all_shp += v[axis] | |||||
| return offset, all_shp, axis | |||||
| @@ -19,7 +19,7 @@ Primitive operator classes. | |||||
| A collection of operators to build nerual networks or computing functions. | A collection of operators to build nerual networks or computing functions. | ||||
| """ | """ | ||||
| from .array_ops import (Argmax, Argmin, Cast, ConcatOffset, Concat, Pack, Unpack, | |||||
| from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, | |||||
| Diag, DiagPart, DType, ExpandDims, Eye, | Diag, DiagPart, DType, ExpandDims, Eye, | ||||
| Fill, GatherNd, GatherV2, InvertPermutation, | Fill, GatherNd, GatherV2, InvertPermutation, | ||||
| IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike, | IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike, | ||||
| @@ -200,7 +200,6 @@ __all__ = [ | |||||
| 'LogicalOr', | 'LogicalOr', | ||||
| 'Size', | 'Size', | ||||
| 'DepthwiseConv2dNative', | 'DepthwiseConv2dNative', | ||||
| 'ConcatOffset', | |||||
| 'UnsortedSegmentSum', | 'UnsortedSegmentSum', | ||||
| "AllGather", | "AllGather", | ||||
| "AllReduce", | "AllReduce", | ||||
| @@ -20,6 +20,7 @@ from ..._c_expression import signature_kind as sig_kind | |||||
| from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register | from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register | ||||
| from ..._checkparam import ParamValidator as validator | from ..._checkparam import ParamValidator as validator | ||||
| from ..._checkparam import Rel, check_int_positive, check_bool | from ..._checkparam import Rel, check_int_positive, check_bool | ||||
| from .._utils import _get_concat_offset | |||||
| from ...common import dtype as mstype | from ...common import dtype as mstype | ||||
| @@ -107,6 +108,33 @@ class BinaryCrossEntropyGrad(PrimitiveWithInfer): | |||||
| validator.check_two_types_same('x_type', x_type, 'weight_type', weight_type) | validator.check_two_types_same('x_type', x_type, 'weight_type', weight_type) | ||||
| return x_type | return x_type | ||||
| class ConcatOffset(PrimitiveWithInfer): | |||||
| """primitive for computing Concat's gradient.""" | |||||
| @prim_attr_register | |||||
| def __init__(self, N=2, axis=0): | |||||
| """init ConcatOffset""" | |||||
| def __infer__(self, input_x): | |||||
| axis = self.axis | |||||
| x_shp = input_x['shape'] | |||||
| x_type = input_x['dtype'] | |||||
| offset, _, axis = _get_concat_offset(x_shp, x_type, axis) | |||||
| self.add_prim_attr('T', x_type[0].element_type()) | |||||
| offset_values = [] | |||||
| for i in range(len(x_shp)): | |||||
| values = [] | |||||
| for j in range(len(x_shp[0])): | |||||
| value = 0 | |||||
| if j == axis: | |||||
| value = offset[i] | |||||
| values.append(value) | |||||
| offset_values.append(tuple(values)) | |||||
| out = {'shape': None, | |||||
| 'dtype': None, | |||||
| 'value': tuple(offset_values)} | |||||
| return out | |||||
| class Conv2DBackpropFilter(PrimitiveWithInfer): | class Conv2DBackpropFilter(PrimitiveWithInfer): | ||||
| """ | """ | ||||
| @@ -29,6 +29,7 @@ from ..._checkparam import Rel | |||||
| from ...common import dtype as mstype | from ...common import dtype as mstype | ||||
| from ...common.tensor import Tensor | from ...common.tensor import Tensor | ||||
| from ..operations.math_ops import _infer_shape_reduce | from ..operations.math_ops import _infer_shape_reduce | ||||
| from .._utils import _get_concat_offset | |||||
| from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register | from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register | ||||
| def _check_infer_attr_reduce(axis, keep_dims): | def _check_infer_attr_reduce(axis, keep_dims): | ||||
| @@ -1275,30 +1276,6 @@ class UnsortedSegmentSum(PrimitiveWithInfer): | |||||
| return out | return out | ||||
| def _get_concat_offset(x_shp, x_type, axis): | |||||
| """for concat and concatoffset check args and compute offset""" | |||||
| validator.check_type("shape", x_shp, [tuple]) | |||||
| validator.check_integer("len of input_x shape", len(x_shp), 0, Rel.GT) | |||||
| validator.check_subclass("shape0", x_type[0], mstype.tensor) | |||||
| validator.check_integer("len of input_x0 shape", len(x_shp[0]), 0, Rel.GT) | |||||
| rank_base = len(x_shp[0]) | |||||
| validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH) | |||||
| if axis < 0: | |||||
| axis = axis + rank_base | |||||
| all_shp = x_shp[0][axis] | |||||
| offset = [0,] | |||||
| for i in range(1, len(x_shp)): | |||||
| v = x_shp[i] | |||||
| validator.check('len of x_shp[%d]' % i, len(v), 'len of base', len(x_shp[0])) | |||||
| validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0]) | |||||
| for j in range(rank_base): | |||||
| if j != axis and v[j] != x_shp[0][j]: | |||||
| raise ValueError("Concat evaluator element %d shape in input can not concat with first element" % i) | |||||
| offset.append(all_shp) | |||||
| all_shp += v[axis] | |||||
| return offset, all_shp, axis | |||||
| class Concat(PrimitiveWithInfer): | class Concat(PrimitiveWithInfer): | ||||
| r""" | r""" | ||||
| Concat tensor in specified axis. | Concat tensor in specified axis. | ||||
| @@ -1531,34 +1508,6 @@ class Slice(PrimitiveWithInfer): | |||||
| 'value': None} | 'value': None} | ||||
| class ConcatOffset(PrimitiveWithInfer): | |||||
| """primitive for computing Concat's gradient.""" | |||||
| @prim_attr_register | |||||
| def __init__(self, N=2, axis=0): | |||||
| """init ConcatOffset""" | |||||
| def __infer__(self, input_x): | |||||
| axis = self.axis | |||||
| x_shp = input_x['shape'] | |||||
| x_type = input_x['dtype'] | |||||
| offset, _, axis = _get_concat_offset(x_shp, x_type, axis) | |||||
| self.add_prim_attr('T', x_type[0].element_type()) | |||||
| offset_values = [] | |||||
| for i in range(len(x_shp)): | |||||
| values = [] | |||||
| for j in range(len(x_shp[0])): | |||||
| value = 0 | |||||
| if j == axis: | |||||
| value = offset[i] | |||||
| values.append(value) | |||||
| offset_values.append(tuple(values)) | |||||
| out = {'shape': None, | |||||
| 'dtype': None, | |||||
| 'value': tuple(offset_values)} | |||||
| return out | |||||
| class Select(PrimitiveWithInfer): | class Select(PrimitiveWithInfer): | ||||
| r""" | r""" | ||||
| @@ -271,3 +271,6 @@ class MakeRefKey(Primitive): | |||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, tag): | def __init__(self, tag): | ||||
| validator.check_type('tag', tag, (str,)) | validator.check_type('tag', tag, (str,)) | ||||
| def __call__(self): | |||||
| pass | |||||
| @@ -24,6 +24,7 @@ import pytest | |||||
| import mindspore as ms | import mindspore as ms | ||||
| import mindspore.common.api as me | import mindspore.common.api as me | ||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| from mindspore import Tensor | |||||
| from mindspore.common.parameter import Parameter | from mindspore.common.parameter import Parameter | ||||
| from mindspore.common.initializer import initializer | from mindspore.common.initializer import initializer | ||||
| from ..ut_filter import non_graph_engine | from ..ut_filter import non_graph_engine | ||||
| @@ -396,3 +397,24 @@ def test_tensor_dtype_fp32_to_bool(): | |||||
| input = ms.Tensor(input) | input = ms.Tensor(input) | ||||
| input_me = ms.Tensor(input, dtype=ms.bool_) | input_me = ms.Tensor(input, dtype=ms.bool_) | ||||
| def test_tensor_operation(): | |||||
| x = Tensor(np.ones((3,3)) * 4) | |||||
| res = x + 1 | |||||
| assert np.all(res.asnumpy() == np.ones((3, 3)) * 5) | |||||
| res = 1 + x | |||||
| assert np.all(res.asnumpy() == np.ones((3, 3)) * 5) | |||||
| res = x - 2 | |||||
| assert np.all(res.asnumpy() == np.ones((3, 3)) * 2) | |||||
| res = 6 - x | |||||
| assert np.all(res.asnumpy() == np.ones((3, 3)) * 2) | |||||
| res = x * 3 | |||||
| assert np.all(res.asnumpy() == np.ones((3, 3)) * 12) | |||||
| res = 3 * x | |||||
| assert np.all(res.asnumpy() == np.ones((3, 3)) * 12) | |||||
| res = x / 2 | |||||
| assert np.all(res.asnumpy() == np.ones((3, 3)) * 2) | |||||
| res = 8 / x | |||||
| assert np.all(res.asnumpy() == np.ones((3, 3)) * 2) | |||||
| with pytest.raises(TypeError): | |||||
| res = x * (2, 3) | |||||
| @@ -190,7 +190,7 @@ def vm_impl_slice(self): | |||||
| return vm_impl | return vm_impl | ||||
| @vm_impl_getters.register(P.ConcatOffset) | |||||
| @vm_impl_getters.register(P._grad_ops.ConcatOffset) | |||||
| def vm_impl_concatOffset(self): | def vm_impl_concatOffset(self): | ||||
| """Generate vm_impl function for ConcatOffset""" | """Generate vm_impl function for ConcatOffset""" | ||||
| def vm_impl(x): | def vm_impl(x): | ||||