Merge pull request !700 from penn/validate_bprop_rules (tag: v0.3.0-alpha)
@@ -695,6 +695,7 @@ REGISTER_PYBIND_DEFINE(
   (void)py::class_<String, Type, std::shared_ptr<String>>(m_sub, "String").def(py::init());
   (void)py::class_<RefKeyType, Type, std::shared_ptr<RefKeyType>>(m_sub, "RefKeyType").def(py::init());
   (void)py::class_<RefType, Type, std::shared_ptr<RefType>>(m_sub, "RefType").def(py::init());
+  (void)py::class_<TypeAnything, Type, std::shared_ptr<TypeAnything>>(m_sub, "TypeAnything").def(py::init());
 }));

 const TypePtr kTypeExternal = std::make_shared<External>();
@@ -213,6 +213,7 @@ const PrimitivePtr kPrimGetRefOrigin = std::make_shared<Primitive>("get_ref_orig
 const PrimitivePtr kPrimInsertGradientOf = std::make_shared<Primitive>("InsertGradientOf");
 const PrimitivePtr kPrimPrintShapeType = std::make_shared<Primitive>("PrintShapeType");
 const PrimitivePtr kPrimSameTypeShape = std::make_shared<Primitive>("SameTypeShape");
+const PrimitivePtr kPrimCheckBprop = std::make_shared<Primitive>("CheckBprop");
 const PrimitivePtr kPrimPrint = std::make_shared<Primitive>("Print");
 const PrimitivePtr kPrimMakeRef = std::make_shared<Primitive>("make_ref");
@@ -220,6 +220,7 @@ extern const PrimitivePtr kPrimInsertGradientOf;
 extern const PrimitivePtr kPrimPrintShapeType;
 extern const PrimitivePtr kPrimPrint;
 extern const PrimitivePtr kPrimSameTypeShape;
+extern const PrimitivePtr kPrimCheckBprop;
 extern const PrimitivePtr kPrimDepend;
 extern const PrimitivePtr kPrimStateSetItem;
 extern const PrimitivePtr kPrimScalarSummary;
@@ -309,14 +309,6 @@ FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) {
   auto bprop = primal->transforms().find("bprop");
   if (bprop != primal->transforms().end()) {
     FuncGraphPtr bprop_graph = bprop->second.func_graph();
-    const size_t param_diff = 1;
-    if (bprop_graph->output()->isa<CNode>() &&
-        bprop_graph->output()->cast<CNodePtr>()->size() + param_diff != bprop_graph->parameters().size()) {
-      // It does not matter with the final tangents, just a tip for debugging
-      MS_LOG(DEBUG) << "User defined Cell bprop " << primal->ToString() << " in scope "
-                    << primal->output()->scope()->name()
-                    << " output must be a tuple and output number should be the same with inputs.";
-    }
     resources_->manager()->AddFuncGraph(bprop_graph);
     if (bprop_graph->free_variables_nodes().size() != 0 || primal->free_variables_nodes().size() != 0) {
@@ -127,7 +127,7 @@ class KPrim {
   AnfNodePtr BuildOutput(const FuncGraphPtr &bprop_fg);
   void TransformArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg, const FuncGraphPtr &outer,
                      std::vector<AnfNodePtr> *const transf_args);
-  void AddCheckTypeShapeOp(const FuncGraphPtr &bprop_fg);
+  void CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check);

   Registry bprop_registry_;
   std::unordered_map<PrimitivePtr, MetaFuncGraphPtr> bprop_registry_meta_;
@@ -137,10 +137,7 @@ template <typename T>
 FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg) {
   MS_EXCEPTION_IF_NULL(primal);
   MS_EXCEPTION_IF_NULL(bprop_fg);
-  if (IsPrimitiveCNode(bprop_fg->output(), prim::kPrimMakeTuple)) {
-    AddCheckTypeShapeOp(bprop_fg);
-  }
+  CheckBprop(bprop_fg, primal->ToString());

   auto debug_info = std::make_shared<GraphDebugInfo>();
   debug_info->set_name(primal->ToString());
@@ -50,9 +50,13 @@ FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) {
                                          grad_op_child_scope_prefix + prim->name());
   ScopeGuard scope_guard(scope);
   py::function fn = prim->GetBpropFunction();
+  if (fn == nullptr || py::isinstance<py::none>(fn)) {
+    MS_LOG(DEBUG) << "Fail to find bprop function for " << prim->name() << ".";
+    return nullptr;
+  }
   FuncGraphPtr func_graph = parse::ParsePythonCode(fn);
   if (func_graph == nullptr) {
-    MS_LOG(WARNING) << "Fail to find bprop function for " << prim->name() << ".";
+    MS_LOG(ERROR) << "Fail to parse bprop function for " << prim->name() << ".";
     return nullptr;
   }
   return func_graph;
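A note for readers on where `prim->GetBpropFunction()` resolves from: gradient definitions live on the Python side and are registered per primitive through the `bprop_getters` registry (the same mechanism the new tests at the bottom of this change use). A minimal sketch of that convention, using a hypothetical `MyIdentity` primitive rather than a real op:

```python
# Hedged sketch of the bprop registration convention; MyIdentity is a
# hypothetical primitive used only for illustration.
from mindspore.ops._grad.grad_base import bprop_getters
from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer

class MyIdentity(PrimitiveWithInfer):
    @prim_attr_register
    def __init__(self):
        super(MyIdentity, self).__init__('MyIdentity')

    def infer_shape(self, x_shape):
        return x_shape

    def infer_dtype(self, x_type):
        return x_type

@bprop_getters.register(MyIdentity)
def get_bprop_my_identity(self):
    """Identity passes the output gradient straight through."""
    def bprop(x, out, dout):
        return (dout,)  # one gradient per forward input
    return bprop
```

If no getter (and no built-in bprop) exists, `GetBpropFunction` yields `None`, which the new early-return above now handles with a DEBUG log instead of letting parsing fail later.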
@@ -153,31 +157,23 @@ void KPrim::TransformArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bp
   }
 }

-void KPrim::AddCheckTypeShapeOp(const FuncGraphPtr &bprop_fg) {
+void KPrim::CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check) {
   // bprop_fg has been checked in caller
-  auto same_type_shape = prim::GetPythonOps("same_type_shape", "mindspore.ops.functional")->cast<PrimitivePtr>();
-  MS_EXCEPTION_IF_NULL(same_type_shape);
-  std::vector<AnfNodePtr> bout_input;
-  bout_input.push_back(NewValueNode(prim::kPrimMakeTuple));
-  auto fg_out = bprop_fg->output();
-  MS_EXCEPTION_IF_NULL(fg_out);
-  auto cnode = fg_out->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(cnode);
-  auto &inputs = cnode->inputs();
-  auto params = bprop_fg->parameters();
-  std::vector<AnfNodePtr> sub_input;
-  for (size_t i = 1; i < inputs.size(); ++i) {
-    sub_input.clear();
-    sub_input.push_back(NewValueNode(same_type_shape));
-    sub_input.push_back(inputs[i]);
-    sub_input.push_back(params[i - 1]);
-    bout_input.push_back(bprop_fg->NewCNode(sub_input));
-  }
-  AnfNodePtr cbout = bprop_fg->NewCNode(bout_input);
-  bprop_fg->set_output(cbout);
+  auto check_bprop = prim::GetPythonOps("check_bprop", "mindspore.ops.functional")->cast<PrimitivePtr>();
+  MS_EXCEPTION_IF_NULL(check_bprop);
+  check_bprop->set_attr("prim_to_check", std::make_shared<StringImm>(prim_to_check));
+  std::vector<AnfNodePtr> inputs;
+  inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
+  inputs.insert(inputs.begin() + 1, bprop_fg->parameters().begin(), bprop_fg->parameters().end() - 2);
+  AnfNodePtr params = bprop_fg->NewCNode(inputs);
+  inputs.clear();
+  inputs.push_back(NewValueNode(check_bprop));
+  inputs.push_back(bprop_fg->output());
+  inputs.push_back(params);
+  AnfNodePtr bprop_out = bprop_fg->NewCNode(inputs);
+  bprop_fg->set_output(bprop_out);
 }

 FuncGraphPtr KPrim::KUserDefinedCellBprop(const FuncGraphPtr bprop_fg) {
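What the new `KPrim::CheckBprop` builds, in words: it replaces the bprop graph's output with `CheckBprop(old_output, make_tuple(params))`, where `params` are all graph parameters except the trailing `out` and `dout` (hence `parameters().end() - 2`). A rough functional analogue in plain Python; none of these names are MindSpore APIs:

```python
# Conceptual sketch (not the C++ graph API) of the rewrite KPrim::CheckBprop
# performs, expressed over a plain Python function.
def wrap_with_check(bprop, check):
    """Rewrap bprop (*inputs, out, dout) -> grads so its output is validated
    against the forward inputs before being returned."""
    def wrapped(*args):
        forward_inputs = args[:-2]           # parameters minus (out, dout)
        grads = bprop(*args)                 # original bprop output
        return check(grads, forward_inputs)  # CheckBprop(output, make_tuple(params))
    return wrapped

def check(grads, inputs):
    # Stand-in for P.CheckBprop: at least one gradient per forward input.
    if len(grads) < len(inputs):
        raise TypeError(f"expected {len(inputs)} gradients, got {len(grads)}")
    return grads

mul_bprop = lambda x, y, out, dout: (dout * y, dout * x)
checked = wrap_with_check(mul_bprop, check)
print(checked(2.0, 3.0, 6.0, 1.0))  # (3.0, 2.0)
```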
@@ -67,6 +67,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
     {prim::kPrimReduceMean, prim::kPrimReduceAll, prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin});
   partial_eliminate_ = MakeSubstitution(PartialEliminater(), "partial_eliminate", IsCNodeDup);
   same_eliminate_ = MakeSubstitution(SameEliminater(), "same_eliminate", prim::kPrimSameTypeShape);
+  check_bprop_eliminate_ = MakeSubstitution(CheckBpropEliminater(), "check_bprop_eliminate", prim::kPrimCheckBprop);
   reset_defer_inline_ = MakeSubstitution(ResetDeferInline(), "reset_defer_inline", IsValueNode<FuncGraph>);

   // Env Item Eliminate
@@ -45,6 +45,7 @@ class OptimizeIRPassLib {
   SubstitutionPtr reduce_eliminate_;
   SubstitutionPtr partial_eliminate_;
   SubstitutionPtr same_eliminate_;
+  SubstitutionPtr check_bprop_eliminate_;
   SubstitutionPtr reset_defer_inline_;

   // Env Item Eliminate
@@ -109,6 +109,25 @@ class SameEliminater : public AnfVisitor {
   AnfNodePtr x_{nullptr};
 };

+// {prim::kPrimCheckBprop, X, Y} -> X
+class CheckBpropEliminater : public AnfVisitor {
+ public:
+  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
+    x_ = nullptr;
+    AnfVisitor::Match(prim::kPrimCheckBprop, {IsNode, IsNode})(node);
+    return x_;
+  }
+
+  void Visit(const AnfNodePtr &node) override {
+    if (x_ == nullptr) {
+      x_ = node;
+    }
+  }
+
+ private:
+  AnfNodePtr x_{nullptr};
+};
+
 // Reset defer_inline flag
 class ResetDeferInline : public AnfVisitor {
  public:
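The leading comment is the whole story: once type and shape checking has run during inference, a `CheckBprop(X, Y)` node is redundant and is rewritten to its first input `X`. A toy version of the substitution over tuples standing in for CNodes (illustrative only, not the AnfVisitor API):

```python
# Toy sketch of the substitution {prim::kPrimCheckBprop, X, Y} -> X.
# Nodes here are plain tuples ("CheckBprop", x, y).
def check_bprop_eliminate(node):
    if isinstance(node, tuple) and len(node) == 3 and node[0] == "CheckBprop":
        return node[1]   # keep X (the gradients), drop the check
    return node          # no match: leave the node unchanged

assert check_bprop_eliminate(("CheckBprop", "grads", "params")) == "grads"
assert check_bprop_eliminate(("MakeTuple", "a", "b")) == ("MakeTuple", "a", "b")
```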
@@ -108,6 +108,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
   });
   opt::OptPassConfig a_3 = opt::OptPassConfig({
     irpass.same_eliminate_,
+    irpass.check_bprop_eliminate_,
     irpass.replace_applicator_,
   });
   opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_});
@@ -295,6 +295,9 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
     dic["shape"] = shape;
     dic["dtype"] = arg_slice->BuildType();
     dic["value"] = BuildValue(arg_slice->BuildValue());
+  } else if (abs_base->isa<AbstractRef>()) {
+    auto value = abs_base->cast<AbstractRefPtr>()->ref();
+    dic = ConvertAbstractToPython(value);
   } else if (abs_base->isa<AbstractTuple>()) {
     auto arg_tuple = dyn_cast<AbstractTuple>(abs_base);
     size_t len = arg_tuple->size();
@@ -327,6 +330,10 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
     dic["shape"] = py::none();
     dic["dtype"] = py::none();
     dic["value"] = py::none();
+  } else if (abs_base->isa<AbstractFunction>()) {
+    dic["shape"] = py::none();
+    dic["dtype"] = abs_base->BuildType();
+    dic["value"] = py::none();
   } else {
     auto value = abs_base->BuildValue();
     if ((*value == *kAnyValue)) {
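Both new branches keep the function's existing contract of filling a `{shape, dtype, value}` dict for Python infer functions: an `AbstractRef` is unwrapped to the abstract value it refers to, and an `AbstractFunction` contributes only a dtype. A hedged Python model of that behavior, with stand-in classes (not MindSpore's):

```python
class AbstractRef:
    """Stand-in for AbstractRef: wraps the abstract value a Ref points at."""
    def __init__(self, value):
        self._value = value
    def ref(self):
        return self._value

class AbstractFunction:
    """Stand-in for AbstractFunction: has a type but no shape or value."""
    def build_type(self):
        return "Func"

class AbstractTensor:
    def __init__(self, shape, dtype):
        self.shape, self.dtype = shape, dtype

def convert_abstract(abs_base):
    if isinstance(abs_base, AbstractRef):       # new branch 1: unwrap the ref
        return convert_abstract(abs_base.ref())
    if isinstance(abs_base, AbstractFunction):  # new branch 2: dtype only
        return {"shape": None, "dtype": abs_base.build_type(), "value": None}
    return {"shape": abs_base.shape, "dtype": abs_base.dtype, "value": None}

print(convert_abstract(AbstractRef(AbstractTensor([2, 2], "float32"))))
# {'shape': [2, 2], 'dtype': 'float32', 'value': None}
```

This is what lets `CheckBprop.infer_dtype` below see a function-typed input (a bprop slot holding a closure) and accept an `EnvType` gradient for it.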
@@ -85,13 +85,16 @@ list_ = typing.List()
 tuple_ = typing.Tuple()
 tensor = typing.TensorType()
 function = typing.Function()
+function_type = typing.Function
 symbolic_key = typing.SymbolicKeyType()
 env_type = typing.EnvType()
+env_type_type = typing.EnvType
 type_type = typing.TypeType()
 type_none = typing.TypeNone()
 string = typing.String()
 type_refkey = typing.RefKeyType()
 tensor_type = typing.TensorType
+anything_type = typing.TypeAnything

 number_type = (int8,
                int16,
@@ -211,11 +211,11 @@ def get_bprop_slice(self):
     def bprop(x, begin, size, out, dout):
         dx = P.Pad(_slice_grad_pad(begin, size, shape_op(x)))(dout)
-        return (dx,)
+        return (dx, zeros_like(begin), zeros_like(size))

     def bprop_gpu(x, begin, size, out, dout):
         dx = G.SliceGrad()(dout, x, begin, size)
-        return (dx,)
+        return (dx, zeros_like(begin), zeros_like(size))

     if context.get_context('device_target') == "GPU":
         return bprop_gpu
@@ -262,7 +262,7 @@ def get_bprop_gather_v2(self):
         # Example: out_shape:(3,2,3) axis 2 -> (1,2,0)
         perm_2 = _generate_inverse_index(x_shp, axis)
         params_grad = transpose(params_grad, perm_2)
-        return params_grad, zeros_like(indices)
+        return params_grad, zeros_like(indices), zeros_like(axis)

     return bprop
@@ -505,7 +505,7 @@ def get_bprop_reducemax(self):
     def bprop(x, axis, out, dout):
         dx = _min_or_max_grad(x, axis, out, dout)
-        return (dx,)
+        return (dx, zeros_like(axis))

     return bprop
@@ -528,7 +528,7 @@ def get_bprop_reducemin(self):
     def bprop(x, axis, out, dout):
         dx = _min_or_max_grad(x, axis, out, dout)
-        return (dx,)
+        return (dx, zeros_like(axis))

     return bprop
@@ -436,7 +436,7 @@ def get_bprop_onehot(self):
     """Grad definition for `OneHot` operation."""
     def bprop(indices, depth, on_value, off_value, out, dout):
-        return zeros_like(indices), zeros_like(depth)
+        return zeros_like(indices), zeros_like(depth), zeros_like(on_value), zeros_like(off_value)

     return bprop
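These bprop fixes all enforce the contract that `CheckBprop` now validates: a bprop returns exactly one gradient per forward input, emitting a `zeros_like` placeholder for non-differentiable inputs such as `begin`, `size`, `axis`, `on_value`, and `off_value`. A self-contained sketch of the pattern with a toy `zeros_like`:

```python
def zeros_like(x):
    """Toy stand-in for the composite zeros_like used in the grad definitions."""
    return type(x)(0)

def onehot_bprop(indices, depth, on_value, off_value, out, dout):
    # OneHot has no differentiable inputs, so every slot gets a zero grad:
    return (zeros_like(indices), zeros_like(depth),
            zeros_like(on_value), zeros_like(off_value))

print(onehot_bprop(1, 3, 1.0, 0.0, None, None))  # (0, 0, 0.0, 0.0)
```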
@@ -31,6 +31,10 @@ def _zeros_like_scala(x):
     """Returns 0 which has the same dtype as x where x is a scalar."""
     return 0

+@zeros_like_leaf.register("Bool")
+def _zeros_like_bool(x):
+    """Returns False if x is a bool."""
+    return False
+
 newenv = base.EnvInstance_()
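The `Bool` registration follows from the same contract: bprops must now produce a placeholder gradient for every input, including boolean flags, so `zeros_like_leaf` needs a notion of "zero" for bools. A toy model of the dispatch (the real registry is a multitype function graph keyed by type name):

```python
# Toy dispatch table standing in for the type-name registrations above.
_zeros_like_impls = {
    "Bool": lambda x: False,  # the branch this diff adds
    "Number": lambda x: 0,
}

def zeros_like_leaf(x):
    kind = "Bool" if isinstance(x, bool) else "Number"
    return _zeros_like_impls[kind](x)

assert zeros_like_leaf(True) is False
assert zeros_like_leaf(3) == 0
```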
@@ -56,6 +56,7 @@ tensor_pow = P.Pow()
 tensor_mod = P.FloorMod()
 strided_slice = P.StridedSlice()
 same_type_shape = P.SameTypeShape()
+check_bprop = P.CheckBprop()
 equal = P.Equal()
 not_equal = P.NotEqual()
 assign_sub = P.AssignSub()
@@ -67,7 +67,7 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
                      SparseSoftmaxCrossEntropyWithLogits, Tanh,
                      TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl,
                      ApplyRMSProp, ApplyCenteredRMSProp)
-from .other_ops import Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, CheckValid, MakeRefKey
+from .other_ops import Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, CheckValid, MakeRefKey, CheckBprop
 from . import _quant_ops
 from ._quant_ops import *
@@ -179,6 +179,7 @@ __all__ = [
     'GeSwitch',
     'Merge',
     'SameTypeShape',
+    'CheckBprop',
     'CheckValid',
     'BoundingBoxEncode',
     'BoundingBoxDecode',
@@ -269,3 +269,66 @@ class MakeRefKey(Primitive):
     def __call__(self):
         pass
+
+
+class CheckBprop(PrimitiveWithInfer):
+    """
+    Checks whether the data types and shapes of corresponding elements from tuples x and y are the same.
+
+    Raises:
+        TypeError: If the data types or shapes of corresponding elements are not the same.
+
+    Inputs:
+        - **input_x** (tuple[Tensor]) - The input_x contains the outputs of bprop to be checked.
+        - **input_y** (tuple[Tensor]) - The input_y contains the inputs of bprop to check against.
+
+    Outputs:
+        (tuple[Tensor]), the input_x,
+        if data types and shapes of corresponding elements from `input_x` and `input_y` are the same.
+
+    Examples:
+        >>> input_x = (Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32),)
+        >>> input_y = (Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32),)
+        >>> out = P.CheckBprop()(input_x, input_y)
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """init CheckBprop"""
+
+    def infer_shape(self, xshapes, yshapes):
+        tips = f'Bprop of {self.prim_to_check}'
+        if len(xshapes) < len(yshapes):
+            raise TypeError(f"{tips}, the size of output should be {len(yshapes)},"
+                            f" but got {len(xshapes)}.")
+        checking_range = len(yshapes)
+        for i in range(checking_range):
+            xshape = xshapes[i]
+            yshape = yshapes[i]
+            if not xshape or not yshape:
+                continue
+            if xshape != yshape:
+                raise TypeError(f"{tips}, the shape of {i}th output should be {yshape},"
+                                f" but got {xshape}.")
+        return xshapes
+
+    def infer_dtype(self, xdtypes, ydtypes):
+        tips = f'Bprop of {self.prim_to_check}'
+        if len(xdtypes) < len(ydtypes):
+            raise TypeError(f"{tips}, the size of output should be {len(ydtypes)},"
+                            f" but got {len(xdtypes)}.")
+        checking_range = len(ydtypes)
+        for i in range(checking_range):
+            xdtype = xdtypes[i]
+            ydtype = ydtypes[i]
+            if isinstance(xdtype, mstype.anything_type) or isinstance(ydtype, mstype.anything_type):
+                continue
+            if isinstance(ydtype, mstype.function_type):
+                if not isinstance(xdtype, mstype.env_type_type):
+                    raise TypeError(f"{tips}, the dtype of {i}th output should be {mstype.env_type_type},"
+                                    f" but got {xdtype}.")
+                continue
+            if xdtype != ydtype:
+                raise TypeError(f"{tips}, the dtype of {i}th output should be {ydtype},"
+                                f" but got {xdtype}.")
+        return xdtypes
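Extending the Examples above: matching tuples pass straight through, while a mismatch surfaces as a `TypeError` naming the offending element. A hedged usage sketch that invokes the primitive directly, as the docstring does; `add_prim_attr` stands in for the `prim_to_check` attribute the compiler normally sets, and `'MyOp'` is a placeholder name:

```python
import numpy as np
import mindspore
from mindspore import Tensor
from mindspore.ops import operations as P

check = P.CheckBprop()
check.add_prim_attr('prim_to_check', 'MyOp')  # normally set by the compiler

input_x = (Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32),)
input_y = (Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32),)
out = check(input_x, input_y)  # same dtypes/shapes: returns input_x

bad_x = (Tensor(np.array([[2, 2], [2, 2]]), mindspore.int32),)
try:
    check(bad_x, input_y)      # dtype mismatch on element 0
except TypeError as e:
    print(e)                   # "Bprop of MyOp, the dtype of 0th output should be ..."
```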
@@ -317,7 +317,7 @@ test_case_cell_ops = [
                            initializer_range=0.02,
                            dropout_prob=0.1),
         'desc_inputs': [[1, 768], [1, 768]],
-        'desc_bprop': [[1, 128, 768]]}),  # maybe not right
+        'desc_bprop': [[1, 768]]}),
     ('BertTransformer_2', {
         'block': bert_trans(),
         'desc_inputs': [[1, 128, 768], [1, 128, 128]]}),
@@ -331,7 +331,7 @@ test_case_cell_ops = [
         'desc_inputs': [Tensor(np.random.rand(128).astype(np.int32)),
                         Tensor(np.random.rand(128).astype(np.int32)), [128]],
         'desc_bprop': [[1, 128, 768], [1, 128, 768], [1, 128, 768]],
-        'num_output': 3}),  # maybe not right
+        'num_output': 3}),
     ('BertModel_1', {
         'block': BertModel(config=BertConfig(batch_size=1,
@@ -342,7 +342,7 @@ test_case_cell_ops = [
         'desc_inputs': [Tensor(np.random.rand(128).astype(np.int32)),
                         Tensor(np.random.rand(128).astype(np.int32)), [128]],
         'desc_bprop': [[1, 128, 768], [1, 128, 768], [1, 128, 768]],
-        'num_output': 3}),  # maybe not right
+        'num_output': 3}),
     ('BertModel_2', {
         'block': BertModel(config=BertConfig(batch_size=1,
@@ -354,7 +354,7 @@ test_case_cell_ops = [
         'desc_inputs': [Tensor(np.random.rand(128).astype(np.int32)),
                         Tensor(np.random.rand(128).astype(np.int32)), [128]],
         'desc_bprop': [[1, 128, 768], [1, 128, 768], [1, 128, 768]],
-        'num_output': 3}),  # maybe not right
+        'num_output': 3}),
     ('BertPretrainingLoss', {
         'block': BertPretrainingLoss(config=BertConfig(batch_size=1)),
@@ -175,7 +175,7 @@ class GetParamGrad(nn.Cell):
 def test_grad_conv_prelu():
     shapes = [[64, 64, 112, 112]]
-    outshape = [[64, 64, 56, 56]]
+    outshape = [[64, 64, 112, 112]]
     net = IRBlockZ(inplanes=64, planes=64).add_flags_recursive(fp16=True)
     inputs = [convert(shp, dtype=np.float16) for shp in shapes]
     sens_shape = outshape[0]
@@ -585,7 +585,7 @@ test_case_nn_ops = [
     ('ReLUV2', {
         'block': P.ReLUV2(),
         'desc_inputs': [[1, 3, 4, 4]],
-        'desc_bprop': [[1, 3, 4, 4], [1, 3, 4, 4]]}),
+        'desc_bprop': [[1, 3, 4, 4], ([1, 1, 4, 4, 2], {'dtype': np.uint8})]}),
     ('ReLUGrad', {
         'block': G.ReluGrad(),
         'desc_inputs': [[1, 3, 4, 4], [1, 3, 4, 4]],
@@ -626,7 +626,7 @@ test_case_nn_ops = [
     ('MaxPoolWithArgmax', {
         'block': P.MaxPoolWithArgmax(ksize=2, strides=2),
         'desc_inputs': [[128, 32, 32, 64]],
-        'desc_bprop': [[128, 32, 8, 16], [128, 32, 8, 16]]}),
+        'desc_bprop': [[128, 32, 16, 32], ([128, 32, 4, 33], {'dtype': np.uint16})]}),
     ('SoftmaxCrossEntropyWithLogits', {
         'block': P.SoftmaxCrossEntropyWithLogits(),
         'desc_inputs': [[1, 10], [1, 10]],
@@ -639,7 +639,7 @@ test_case_nn_ops = [
     ('LogSoftmax', {
         'block': P.LogSoftmax(),
         'desc_inputs': [[64, 2]],
-        'desc_bprop': [[160, 30522]]}),
+        'desc_bprop': [[64, 2]]}),
     ('LogSoftmaxGrad', {
         'block': G.LogSoftmaxGrad(),
         'desc_inputs': [[16, 1234], [16, 1234]],
@@ -648,7 +648,7 @@ test_case_nn_ops = [
     ('LayerNorm', {
         'block': P.LayerNorm(),
         'desc_inputs': [[2, 16], [16], [16]],
-        'desc_bprop': [[2, 16], [2, 16], [2, 16]]}),
+        'desc_bprop': [[2, 16], [2, 1], [2, 1]]}),
     ('LayerNormGrad', {
         'block': G.LayerNormGrad(),
         'desc_inputs': [[2, 16], [2, 16], [2, 16], [2, 16], [16]],
@@ -845,7 +845,7 @@ test_case_nn_ops = [
         'block': P.OneHot(),
         'desc_const': [3, Tensor(1.0, mstype.float32), Tensor(0.0, mstype.float32)],
         'desc_inputs': [Tensor(np.array([64]).astype(np.int32))],
-        'desc_bprop': [[64, 2]]}),
+        'desc_bprop': [[1, 3]]}),
     ('ReduceProd_0', {
         'block': P.ReduceProd(),
         'desc_const': [0],
@@ -950,7 +950,7 @@ test_case_array_ops = [
         'block': P.Cast(),
         'desc_const': [mstype.int32],
         'desc_inputs': [[2, 3, 4, 5]],
-        'desc_bprop': [Tensor(np.ones((2, 3, 3, 5)).astype(np.int32))]}),
+        'desc_bprop': [Tensor(np.ones((2, 3, 4, 5)).astype(np.int32))]}),
     ('ExpandDims', {
         'block': P.ExpandDims(),
         'desc_const': [0],
@@ -1002,12 +1002,12 @@ test_case_array_ops = [
         'desc_inputs': [
             (Tensor(np.array([[0, 1], [2, 1]]).astype(np.int32)),
              Tensor(np.array([[0, 1], [2, 1]]).astype(np.int32)))],
-        'desc_bprop': [[4, 2]]}),
+        'desc_bprop': [([4, 2], {'dtype': np.int32})]}),
     ('ConcatV2_1', {
         'block': P.Concat(axis=2),
         'desc_inputs': [(Tensor(np.array([[[0, 1, 2]], [[2, 1, 2]]]).astype(np.int32)),
                          Tensor(np.array([[[0, 1]], [[2, 1]]]).astype(np.int32)))],
-        'desc_bprop': [[2, 1, 5]]}),
+        'desc_bprop': [([2, 1, 5], {'dtype': np.int32})]}),
     ('ConcatV2_2', {
         'block': NetForConcat(),
         'desc_inputs': [[2, 2]],
@@ -1042,7 +1042,7 @@ test_case_array_ops = [
     ('Pack_2', {
         'block': NetForPackInput(P.Pack()),
         'desc_inputs': [[2, 2]],
-        'desc_bprop': [[2, 2, 2]],
+        'desc_bprop': [[1, 2, 2]],
     }),
     ('Pack_3', {
         'block': NetForPackInput(P.Pack()),
@@ -1077,7 +1077,7 @@ test_case_array_ops = [
     ('SpaceToBatch_2', {
         'block': P.SpaceToBatch(2, [[1, 1], [0, 4]]),
         'desc_inputs': [[1, 3, 2, 2]],
-        'desc_bprop': [[4, 3, 2, 4]],
+        'desc_bprop': [[4, 3, 2, 3]],
     }),
     ('BatchToSpace_1', {
         'block': P.BatchToSpace(2, [[0, 0], [0, 0]]),
@@ -1124,7 +1124,7 @@ test_case_other_ops = [
         'desc_const': [(3, 3)],
         'desc_inputs': (Tensor(np.ones((2, 2), np.int32)),
                         Tensor(np.ones((2,), np.int32))),
-        'desc_bprop': [[3, 3]]}),
+        'desc_bprop': [([3, 3], {'dtype': np.int32})]}),
     ('SmoothL1Loss', {
         'block': P.SmoothL1Loss(),
         'desc_inputs': [[256, 4], [256, 4]],
@@ -229,12 +229,6 @@ class TwoInputBprop(nn.Cell):
     def bprop(self, x, y, out, dout):
         return 5 * x, 8 * y

-class TwoInput(nn.Cell):
-    def __init__(self):
-        super().__init__()
-        self.op = P.Mul()
-
-    def construct(self, x, y):
-        return self.op(x, y)

 class TwoInputWithParameter(nn.Cell):
     def __init__(self):
@@ -301,8 +295,37 @@ class MulAddWithWrongOutputNum(nn.Cell):
     def construct(self, x, y):
         return 2 * x + y

     def bprop(self, x, y, out, dout):
-        return 2 * dout, 2 * y, out
+        return 2 * dout,


 def test_grad_mul_add_with_wrong_output_num():
     mul_add = MulAddWithWrongOutputNum()
-    C.grad_all(mul_add)(1, 2)
+    with pytest.raises(TypeError):
+        C.grad_all(mul_add)(1, 2)
+
+
+class MulAddWithWrongOutputType(nn.Cell):
+    def __init__(self):
+        super(MulAddWithWrongOutputType, self).__init__()
+
+    def construct(self, x, y):
+        return 2 * x + y
+
+    def bprop(self, x, y, out, dout):
+        return 2 * dout, 2
+
+
+def test_grad_mul_add_with_wrong_output_type():
+    mul_add = MulAddWithWrongOutputType()
+    with pytest.raises(TypeError):
+        C.grad_all(mul_add)(1, Tensor(np.ones([2, 2])))
+
+
+class MulAddWithWrongOutputShape(nn.Cell):
+    def __init__(self):
+        super(MulAddWithWrongOutputShape, self).__init__()
+        self.ones = Tensor(np.ones([2,]))
+
+    def construct(self, x, y):
+        return 2 * x + y
+
+    def bprop(self, x, y, out, dout):
+        return 2, self.ones
+
+
+def test_grad_mul_add_with_wrong_output_shape():
+    mul_add = MulAddWithWrongOutputShape()
+    with pytest.raises(TypeError):
+        C.grad_all(mul_add)(1, Tensor(np.ones([2, 2])))
@@ -32,6 +32,8 @@ from ....mindspore_test_framework.utils.check_gradient import (
     OperationGradChecker, check_gradient, ScalarGradChecker)
 from ....mindspore_test_framework.utils.bprop_util import bprop
 import mindspore.context as context
+from mindspore.ops._grad.grad_base import bprop_getters
+from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer


 def setup_module(module):
@@ -721,3 +723,94 @@ def test_grad_if_defer_inline():
     inp = Tensor(np.ones([128, 96]).astype(np.float32))
     grads = C.grad_all(network)(inp)
     assert grads == (Tensor(np.full([128, 96], 0.6, dtype=np.float32)),)
+
+
+def test_bprop_with_wrong_output_num():
+    class BpropWithWrongOutputNum(PrimitiveWithInfer):
+        @prim_attr_register
+        def __init__(self):
+            super(BpropWithWrongOutputNum, self).__init__('BpropWithWrongOutputNum')
+
+        def __call__(self, x, y):
+            return x
+
+        def infer_shape(self, x_shape, yshape):
+            return x_shape
+
+        def infer_dtype(self, x_type, y_type):
+            return x_type
+
+    @bprop_getters.register(BpropWithWrongOutputNum)
+    def get_bprop_with_wrong_output_num(self):
+        """Generate bprop for BpropWithWrongOutputNum"""
+        def bprop(x, y, out, dout):
+            return (dout,)
+        return bprop
+
+    class BpropWithWrongOutputNumCell(nn.Cell):
+        def __init__(self):
+            super(BpropWithWrongOutputNumCell, self).__init__()
+
+        def construct(self, x, y):
+            return BpropWithWrongOutputNum()(x, y)
+
+    with pytest.raises(TypeError):
+        C.grad_all(BpropWithWrongOutputNumCell())(1, 2)
+
+
+def test_bprop_with_wrong_output_type():
+    class BpropWithWrongOutputType(PrimitiveWithInfer):
+        @prim_attr_register
+        def __init__(self):
+            super(BpropWithWrongOutputType, self).__init__('BpropWithWrongOutputType')
+
+        def __call__(self, x):
+            return x
+
+        def infer_shape(self, x_shape):
+            return x_shape
+
+        def infer_dtype(self, x_type):
+            return x_type
+
+    @bprop_getters.register(BpropWithWrongOutputType)
+    def get_bprop_with_wrong_output_type(self):
+        """Generate bprop for BpropWithWrongOutputType"""
+        def bprop(x, out, dout):
+            return (1,)
+        return bprop
+
+    class BpropWithWrongOutputTypeCell(nn.Cell):
+        def __init__(self):
+            super(BpropWithWrongOutputTypeCell, self).__init__()
+
+        def construct(self, x):
+            return BpropWithWrongOutputType()(x)
+
+    with pytest.raises(TypeError):
+        C.grad_all(BpropWithWrongOutputTypeCell())(Tensor(np.ones([64, 10]).astype(np.int32)))
+
+
+def test_bprop_with_wrong_output_shape():
+    class BpropWithWrongOutputShape(PrimitiveWithInfer):
+        @prim_attr_register
+        def __init__(self):
+            super(BpropWithWrongOutputShape, self).__init__('BpropWithWrongOutputShape')
+
+        def __call__(self, x):
+            return x
+
+        def infer_shape(self, x_shape):
+            return x_shape
+
+        def infer_dtype(self, x_type):
+            return x_type
+
+    @bprop_getters.register(BpropWithWrongOutputShape)
+    def get_bprop_with_wrong_output_shape(self):
+        """Generate bprop for BpropWithWrongOutputShape"""
+        ones = Tensor(np.ones([2,]).astype(np.int32))
+
+        def bprop(x, out, dout):
+            return (ones,)
+        return bprop
+
+    class BpropWithWrongOutputShapeCell(nn.Cell):
+        def __init__(self):
+            super(BpropWithWrongOutputShapeCell, self).__init__()
+
+        def construct(self, x):
+            return BpropWithWrongOutputShape()(x)
+
+    with pytest.raises(TypeError):
+        C.grad_all(BpropWithWrongOutputShapeCell())(Tensor(np.ones([64, 10]).astype(np.int32)))
@@ -79,7 +79,7 @@ def test_InsertGradientOf_2():
     summary = P.ScalarSummary()

     def debug_gradient(dx):
         """ debug_gradient """
-        dx = summary("dx: ", dx)
+        summary("dx: ", dx)
         return dx
     debug = P.InsertGradientOf(debug_gradient)