Merge pull request !700 from penn/validate_bprop_rules
tags/v0.3.0-alpha
| @@ -695,6 +695,7 @@ REGISTER_PYBIND_DEFINE( | |||
| (void)py::class_<String, Type, std::shared_ptr<String>>(m_sub, "String").def(py::init()); | |||
| (void)py::class_<RefKeyType, Type, std::shared_ptr<RefKeyType>>(m_sub, "RefKeyType").def(py::init()); | |||
| (void)py::class_<RefType, Type, std::shared_ptr<RefType>>(m_sub, "RefType").def(py::init()); | |||
| (void)py::class_<TypeAnything, Type, std::shared_ptr<TypeAnything>>(m_sub, "TypeAnything").def(py::init()); | |||
| })); | |||
| const TypePtr kTypeExternal = std::make_shared<External>(); | |||
| @@ -213,6 +213,7 @@ const PrimitivePtr kPrimGetRefOrigin = std::make_shared<Primitive>("get_ref_orig | |||
| const PrimitivePtr kPrimInsertGradientOf = std::make_shared<Primitive>("InsertGradientOf"); | |||
| const PrimitivePtr kPrimPrintShapeType = std::make_shared<Primitive>("PrintShapeType"); | |||
| const PrimitivePtr kPrimSameTypeShape = std::make_shared<Primitive>("SameTypeShape"); | |||
| const PrimitivePtr kPrimCheckBprop = std::make_shared<Primitive>("CheckBprop"); | |||
| const PrimitivePtr kPrimPrint = std::make_shared<Primitive>("Print"); | |||
| const PrimitivePtr kPrimMakeRef = std::make_shared<Primitive>("make_ref"); | |||
| @@ -220,6 +220,7 @@ extern const PrimitivePtr kPrimInsertGradientOf; | |||
| extern const PrimitivePtr kPrimPrintShapeType; | |||
| extern const PrimitivePtr kPrimPrint; | |||
| extern const PrimitivePtr kPrimSameTypeShape; | |||
| extern const PrimitivePtr kPrimCheckBprop; | |||
| extern const PrimitivePtr kPrimDepend; | |||
| extern const PrimitivePtr kPrimStateSetItem; | |||
| extern const PrimitivePtr kPrimScalarSummary; | |||
| @@ -309,14 +309,6 @@ FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) { | |||
| auto bprop = primal->transforms().find("bprop"); | |||
| if (bprop != primal->transforms().end()) { | |||
| FuncGraphPtr bprop_graph = bprop->second.func_graph(); | |||
| const size_t param_diff = 1; | |||
| if (bprop_graph->output()->isa<CNode>() && | |||
| bprop_graph->output()->cast<CNodePtr>()->size() + param_diff != bprop_graph->parameters().size()) { | |||
| // It does not matter with the final tangents, just a tip for debugging | |||
| MS_LOG(DEBUG) << "User defined Cell bprop " << primal->ToString() << " in scope " | |||
| << primal->output()->scope()->name() | |||
| << " output must be a tuple and output number should be the same with inputs."; | |||
| } | |||
| resources_->manager()->AddFuncGraph(bprop_graph); | |||
| if (bprop_graph->free_variables_nodes().size() != 0 || primal->free_variables_nodes().size() != 0) { | |||
| @@ -127,7 +127,7 @@ class KPrim { | |||
| AnfNodePtr BuildOutput(const FuncGraphPtr &bprop_fg); | |||
| void TransformArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg, const FuncGraphPtr &outer, | |||
| std::vector<AnfNodePtr> *const transf_args); | |||
| void AddCheckTypeShapeOp(const FuncGraphPtr &bprop_fg); | |||
| void CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check); | |||
| Registry bprop_registry_; | |||
| std::unordered_map<PrimitivePtr, MetaFuncGraphPtr> bprop_registry_meta_; | |||
| @@ -137,10 +137,7 @@ template <typename T> | |||
| FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg) { | |||
| MS_EXCEPTION_IF_NULL(primal); | |||
| MS_EXCEPTION_IF_NULL(bprop_fg); | |||
| if (IsPrimitiveCNode(bprop_fg->output(), prim::kPrimMakeTuple)) { | |||
| AddCheckTypeShapeOp(bprop_fg); | |||
| } | |||
| CheckBprop(bprop_fg, primal->ToString()); | |||
| auto debug_info = std::make_shared<GraphDebugInfo>(); | |||
| debug_info->set_name(primal->ToString()); | |||
| @@ -50,9 +50,13 @@ FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) { | |||
| grad_op_child_scope_prefix + prim->name()); | |||
| ScopeGuard scope_guard(scope); | |||
| py::function fn = prim->GetBpropFunction(); | |||
| if (fn == nullptr || py::isinstance<py::none>(fn)) { | |||
| MS_LOG(DEBUG) << "Fail to find bprop function for " << prim->name() << "."; | |||
| return nullptr; | |||
| } | |||
| FuncGraphPtr func_graph = parse::ParsePythonCode(fn); | |||
| if (func_graph == nullptr) { | |||
| MS_LOG(WARNING) << "Fail to find bprop function for " << prim->name() << "."; | |||
| MS_LOG(ERROR) << "Fail to parse bprop function for " << prim->name() << "."; | |||
| return nullptr; | |||
| } | |||
| return func_graph; | |||
| @@ -153,31 +157,23 @@ void KPrim::TransformArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bp | |||
| } | |||
| } | |||
| void KPrim::AddCheckTypeShapeOp(const FuncGraphPtr &bprop_fg) { | |||
| void KPrim::CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check) { | |||
| // bprop_fg has been checked in caller | |||
| auto same_type_shape = prim::GetPythonOps("same_type_shape", "mindspore.ops.functional")->cast<PrimitivePtr>(); | |||
| MS_EXCEPTION_IF_NULL(same_type_shape); | |||
| std::vector<AnfNodePtr> bout_input; | |||
| bout_input.push_back(NewValueNode(prim::kPrimMakeTuple)); | |||
| auto fg_out = bprop_fg->output(); | |||
| MS_EXCEPTION_IF_NULL(fg_out); | |||
| auto cnode = fg_out->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto &inputs = cnode->inputs(); | |||
| auto params = bprop_fg->parameters(); | |||
| std::vector<AnfNodePtr> sub_input; | |||
| for (size_t i = 1; i < inputs.size(); ++i) { | |||
| sub_input.clear(); | |||
| sub_input.push_back(NewValueNode(same_type_shape)); | |||
| sub_input.push_back(inputs[i]); | |||
| sub_input.push_back(params[i - 1]); | |||
| bout_input.push_back(bprop_fg->NewCNode(sub_input)); | |||
| } | |||
| AnfNodePtr cbout = bprop_fg->NewCNode(bout_input); | |||
| bprop_fg->set_output(cbout); | |||
| auto check_bprop = prim::GetPythonOps("check_bprop", "mindspore.ops.functional")->cast<PrimitivePtr>(); | |||
| MS_EXCEPTION_IF_NULL(check_bprop); | |||
| check_bprop->set_attr("prim_to_check", std::make_shared<StringImm>(prim_to_check)); | |||
| std::vector<AnfNodePtr> inputs; | |||
| inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple)); | |||
| inputs.insert(inputs.begin() + 1, bprop_fg->parameters().begin(), bprop_fg->parameters().end() - 2); | |||
| AnfNodePtr params = bprop_fg->NewCNode(inputs); | |||
| inputs.clear(); | |||
| inputs.push_back(NewValueNode(check_bprop)); | |||
| inputs.push_back(bprop_fg->output()); | |||
| inputs.push_back(params); | |||
| AnfNodePtr bprop_out = bprop_fg->NewCNode(inputs); | |||
| bprop_fg->set_output(bprop_out); | |||
| } | |||
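In Python terms, the rewrite above amounts to the following (a hedged sketch for a hypothetical two-input primitive; check_bprop is the functional added to mindspore.ops.functional in this PR, and the wrapper is removed again by the new check_bprop_eliminate pass once inference has run):

from mindspore.ops import functional as F

# Registered bprop for a hypothetical primitive with forward inputs (x, y).
def bprop(x, y, out, dout):
    return dout, dout

# What KPrim::CheckBprop builds, expressed as Python: the bprop outputs are routed
# through check_bprop together with a tuple of all bprop parameters except the
# trailing (out, dout), i.e. the forward inputs, so mismatches surface at inference.
def bprop_checked(x, y, out, dout):
    return F.check_bprop(bprop(x, y, out, dout), (x, y))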
| FuncGraphPtr KPrim::KUserDefinedCellBprop(const FuncGraphPtr bprop_fg) { | |||
| @@ -67,6 +67,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() { | |||
| {prim::kPrimReduceMean, prim::kPrimReduceAll, prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin}); | |||
| partial_eliminate_ = MakeSubstitution(PartialEliminater(), "partial_eliminate", IsCNodeDup); | |||
| same_eliminate_ = MakeSubstitution(SameEliminater(), "same_eliminate", prim::kPrimSameTypeShape); | |||
| check_bprop_eliminate_ = MakeSubstitution(CheckBpropEliminater(), "check_bprop_eliminate", prim::kPrimCheckBprop); | |||
| reset_defer_inline_ = MakeSubstitution(ResetDeferInline(), "reset_defer_inline", IsValueNode<FuncGraph>); | |||
| // Env Item Eliminate | |||
| @@ -45,6 +45,7 @@ class OptimizeIRPassLib { | |||
| SubstitutionPtr reduce_eliminate_; | |||
| SubstitutionPtr partial_eliminate_; | |||
| SubstitutionPtr same_eliminate_; | |||
| SubstitutionPtr check_bprop_eliminate_; | |||
| SubstitutionPtr reset_defer_inline_; | |||
| // Env Item Eliminate | |||
| @@ -109,6 +109,25 @@ class SameEliminater : public AnfVisitor { | |||
| AnfNodePtr x_{nullptr}; | |||
| }; | |||
| // {prim::kPrimCheckBprop, X, Y} -> X | |||
| class CheckBpropEliminater : public AnfVisitor { | |||
| public: | |||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||
| x_ = nullptr; | |||
| AnfVisitor::Match(prim::kPrimCheckBprop, {IsNode, IsNode})(node); | |||
| return x_; | |||
| } | |||
| void Visit(const AnfNodePtr &node) override { | |||
| if (x_ == nullptr) { | |||
| x_ = node; | |||
| } | |||
| } | |||
| private: | |||
| AnfNodePtr x_{nullptr}; | |||
| }; | |||
| // Reset defer_inline flag | |||
| class ResetDeferInline : public AnfVisitor { | |||
| public: | |||
| @@ -108,6 +108,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { | |||
| }); | |||
| opt::OptPassConfig a_3 = opt::OptPassConfig({ | |||
| irpass.same_eliminate_, | |||
| irpass.check_bprop_eliminate_, | |||
| irpass.replace_applicator_, | |||
| }); | |||
| opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_}); | |||
| @@ -295,6 +295,9 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) { | |||
| dic["shape"] = shape; | |||
| dic["dtype"] = arg_slice->BuildType(); | |||
| dic["value"] = BuildValue(arg_slice->BuildValue()); | |||
| } else if (abs_base->isa<AbstractRef>()) { | |||
| auto value = abs_base->cast<AbstractRefPtr>()->ref(); | |||
| dic = ConvertAbstractToPython(value); | |||
| } else if (abs_base->isa<AbstractTuple>()) { | |||
| auto arg_tuple = dyn_cast<AbstractTuple>(abs_base); | |||
| size_t len = arg_tuple->size(); | |||
| @@ -327,6 +330,10 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) { | |||
| dic["shape"] = py::none(); | |||
| dic["dtype"] = py::none(); | |||
| dic["value"] = py::none(); | |||
| } else if (abs_base->isa<AbstractFunction>()) { | |||
| dic["shape"] = py::none(); | |||
| dic["dtype"] = abs_base->BuildType(); | |||
| dic["value"] = py::none(); | |||
| } else { | |||
| auto value = abs_base->BuildValue(); | |||
| if ((*value == *kAnyValue)) { | |||
| @@ -85,13 +85,16 @@ list_ = typing.List() | |||
| tuple_ = typing.Tuple() | |||
| tensor = typing.TensorType() | |||
| function = typing.Function() | |||
| function_type = typing.Function | |||
| symbolic_key = typing.SymbolicKeyType() | |||
| env_type = typing.EnvType() | |||
| env_type_type = typing.EnvType | |||
| type_type = typing.TypeType() | |||
| type_none = typing.TypeNone() | |||
| string = typing.String() | |||
| type_refkey = typing.RefKeyType() | |||
| tensor_type = typing.TensorType | |||
| anything_type = typing.TypeAnything | |||
| number_type = (int8, | |||
| int16, | |||
| @@ -211,11 +211,11 @@ def get_bprop_slice(self): | |||
| def bprop(x, begin, size, out, dout): | |||
| dx = P.Pad(_slice_grad_pad(begin, size, shape_op(x)))(dout) | |||
| return (dx,) | |||
| return (dx, zeros_like(begin), zeros_like(size)) | |||
| def bprop_gpu(x, begin, size, out, dout): | |||
| dx = G.SliceGrad()(dout, x, begin, size) | |||
| return (dx,) | |||
| return (dx, zeros_like(begin), zeros_like(size)) | |||
| if context.get_context('device_target') == "GPU": | |||
| return bprop_gpu | |||
| @@ -262,7 +262,7 @@ def get_bprop_gather_v2(self): | |||
| # Example: out_shape:(3,2,3) axis 2 -> (1,2,0) | |||
| perm_2 = _generate_inverse_index(x_shp, axis) | |||
| params_grad = transpose(params_grad, perm_2) | |||
| return params_grad, zeros_like(indices) | |||
| return params_grad, zeros_like(indices), zeros_like(axis) | |||
| return bprop | |||
| @@ -505,7 +505,7 @@ def get_bprop_reducemax(self): | |||
| def bprop(x, axis, out, dout): | |||
| dx = _min_or_max_grad(x, axis, out, dout) | |||
| return (dx,) | |||
| return (dx, zeros_like(axis)) | |||
| return bprop | |||
| @@ -528,7 +528,7 @@ def get_bprop_reducemin(self): | |||
| def bprop(x, axis, out, dout): | |||
| dx = _min_or_max_grad(x, axis, out, dout) | |||
| return (dx,) | |||
| return (dx, zeros_like(axis)) | |||
| return bprop | |||
| @@ -436,7 +436,7 @@ def get_bprop_onehot(self): | |||
| """Grad definition for `OneHot` operation.""" | |||
| def bprop(indices, depth, on_value, off_value, out, dout): | |||
| return zeros_like(indices), zeros_like(depth) | |||
| return zeros_like(indices), zeros_like(depth), zeros_like(on_value), zeros_like(off_value) | |||
| return bprop | |||
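The bprop updates above all follow the rule that CheckBprop enforces: the returned tuple needs one entry per forward input, with zeros_like filling the slots of non-differentiable inputs such as begin, size, axis, on_value and off_value. A minimal self-contained sketch of the pattern (numpy stand-ins, not one of the registered bprops):

import numpy as np

def zeros_like(v):
    # Stand-in for the composite zeros_like used in the grad definitions above.
    return np.zeros_like(v) if isinstance(v, np.ndarray) else 0

# For an op called as op(x, axis), the bprop takes (x, axis, out, dout) and must
# return a 2-tuple; returning only (dx,) is exactly what CheckBprop now rejects.
def bprop(x, axis, out, dout):
    dx = np.ones_like(x) * dout   # placeholder gradient w.r.t. x
    return dx, zeros_like(axis)   # the non-differentiable axis gets a zero gradient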
| @@ -31,6 +31,10 @@ def _zeros_like_scala(x): | |||
| """Returns 0 which has the same dtype as x where x is a scalar.""" | |||
| return 0 | |||
| @zeros_like_leaf.register("Bool") | |||
| def _zeros_like_bool(x): | |||
| """Returns False if x is a bool.""" | |||
| return False | |||
| newenv = base.EnvInstance_() | |||
| @@ -56,6 +56,7 @@ tensor_pow = P.Pow() | |||
| tensor_mod = P.FloorMod() | |||
| strided_slice = P.StridedSlice() | |||
| same_type_shape = P.SameTypeShape() | |||
| check_bprop = P.CheckBprop() | |||
| equal = P.Equal() | |||
| not_equal = P.NotEqual() | |||
| assign_sub = P.AssignSub() | |||
| @@ -67,7 +67,7 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm, | |||
| SparseSoftmaxCrossEntropyWithLogits, Tanh, | |||
| TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, | |||
| ApplyRMSProp, ApplyCenteredRMSProp) | |||
| from .other_ops import Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, CheckValid, MakeRefKey | |||
| from .other_ops import Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, CheckValid, MakeRefKey, CheckBprop | |||
| from . import _quant_ops | |||
| from ._quant_ops import * | |||
| @@ -179,6 +179,7 @@ __all__ = [ | |||
| 'GeSwitch', | |||
| 'Merge', | |||
| 'SameTypeShape', | |||
| 'CheckBprop', | |||
| 'CheckValid', | |||
| 'BoundingBoxEncode', | |||
| 'BoundingBoxDecode', | |||
| @@ -269,3 +269,66 @@ class MakeRefKey(Primitive): | |||
| def __call__(self): | |||
| pass | |||
| class CheckBprop(PrimitiveWithInfer): | |||
| """ | |||
| Checks whether the data types and shapes of corresponding elements in tuples `input_x` and `input_y` are the same. | |||
| Raises: | |||
| TypeError: If the data type or shape of any corresponding element pair does not match. | |||
| Inputs: | |||
| - **input_x** (tuple[Tensor]) - The input_x contains the outputs of bprop to be checked. | |||
| - **input_y** (tuple[Tensor]) - The input_y contains the inputs of bprop to check against. | |||
| Outputs: | |||
| (tuple[Tensor]), the input_x, | |||
| if data type and shape of corresponding elements from `input_x` and `input_y` are the same. | |||
| Examples: | |||
| >>> input_x = (Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32),) | |||
| >>> input_y = (Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32),) | |||
| >>> out = P.CheckBprop()(input_x, input_y) | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| """init CheckBprop""" | |||
| def infer_shape(self, xshapes, yshapes): | |||
| tips = f'Bprop of {self.prim_to_check}' | |||
| if len(xshapes) < len(yshapes): | |||
| raise TypeError(f"{tips}, the size of output should be {len(yshapes)}," | |||
| f" but got {len(xshapes)}.") | |||
| checking_range = len(yshapes) | |||
| for i in range(checking_range): | |||
| xshape = xshapes[i] | |||
| yshape = yshapes[i] | |||
| if not xshape or not yshape: | |||
| continue | |||
| if xshape != yshape: | |||
| raise TypeError(f"{tips}, the shape of {i}th output should be {yshape}," | |||
| f" but got {xshape}.") | |||
| return xshapes | |||
| def infer_dtype(self, xdtypes, ydtypes): | |||
| tips = f'Bprop of {self.prim_to_check}' | |||
| if len(xdtypes) < len(ydtypes): | |||
| raise TypeError(f"{tips}, the size of output should be {len(ydtypes)}," | |||
| f" but got {len(xdtypes)}.") | |||
| checking_range = len(ydtypes) | |||
| for i in range(checking_range): | |||
| xdtype = xdtypes[i] | |||
| ydtype = ydtypes[i] | |||
| if isinstance(xdtype, mstype.anything_type) or isinstance(ydtype, mstype.anything_type): | |||
| continue | |||
| if isinstance(ydtype, mstype.function_type): | |||
| if not isinstance(xdtype, mstype.env_type_type): | |||
| raise TypeError(f"{tips}, the dtype of {i}th output should be {mstype.env_type_type}," | |||
| f" but got {xdtype}.") | |||
| continue | |||
| if xdtype != ydtype: | |||
| raise TypeError(f"{tips}, the dtype of {i}th output should be {ydtype}," | |||
| f" but got {xdtype}.") | |||
| return xdtypes | |||
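The docstring example shows the matching case; the failure path can be seen by driving the infer methods directly (a hedged illustration only, since in practice the check runs during graph compilation and prim_to_check is attached by KPrim::CheckBprop):

from mindspore.ops import operations as P

check = P.CheckBprop()
check.prim_to_check = "MyOp"                      # normally set as a primitive attr
print(check.infer_shape(([2, 2],), ([2, 2],)))    # matching shapes pass through unchanged

try:
    check.infer_shape(([2, 3],), ([2, 2],))       # mismatched shape for the 0th output
except TypeError as e:
    print(e)   # Bprop of MyOp, the shape of 0th output should be [2, 2], but got [2, 3].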
| @@ -317,7 +317,7 @@ test_case_cell_ops = [ | |||
| initializer_range=0.02, | |||
| dropout_prob=0.1), | |||
| 'desc_inputs': [[1, 768], [1, 768]], | |||
| 'desc_bprop': [[1, 128, 768]]}), # maybe not right | |||
| 'desc_bprop': [[1, 768]]}), | |||
| ('BertTransformer_2', { | |||
| 'block': bert_trans(), | |||
| 'desc_inputs': [[1, 128, 768], [1, 128, 128]]}), | |||
| @@ -331,7 +331,7 @@ test_case_cell_ops = [ | |||
| 'desc_inputs': [Tensor(np.random.rand(128).astype(np.int32)), | |||
| Tensor(np.random.rand(128).astype(np.int32)), [128]], | |||
| 'desc_bprop': [[1, 128, 768], [1, 128, 768], [1, 128, 768]], | |||
| 'num_output': 3}), # maybe not right | |||
| 'num_output': 3}), | |||
| ('BertModel_1', { | |||
| 'block': BertModel(config=BertConfig(batch_size=1, | |||
| @@ -342,7 +342,7 @@ test_case_cell_ops = [ | |||
| 'desc_inputs': [Tensor(np.random.rand(128).astype(np.int32)), | |||
| Tensor(np.random.rand(128).astype(np.int32)), [128]], | |||
| 'desc_bprop': [[1, 128, 768], [1, 128, 768], [1, 128, 768]], | |||
| 'num_output': 3}), # maybe not right | |||
| 'num_output': 3}), | |||
| ('BertModel_2', { | |||
| 'block': BertModel(config=BertConfig(batch_size=1, | |||
| @@ -354,7 +354,7 @@ test_case_cell_ops = [ | |||
| 'desc_inputs': [Tensor(np.random.rand(128).astype(np.int32)), | |||
| Tensor(np.random.rand(128).astype(np.int32)), [128]], | |||
| 'desc_bprop': [[1, 128, 768], [1, 128, 768], [1, 128, 768]], | |||
| 'num_output': 3}), # maybe not right | |||
| 'num_output': 3}), | |||
| ('BertPretrainingLoss', { | |||
| 'block': BertPretrainingLoss(config=BertConfig(batch_size=1)), | |||
| @@ -175,7 +175,7 @@ class GetParamGrad(nn.Cell): | |||
| def test_grad_conv_prelu(): | |||
| shapes = [[64, 64, 112, 112]] | |||
| outshape = [[64, 64, 56, 56]] | |||
| outshape = [[64, 64, 112, 112]] | |||
| net = IRBlockZ(inplanes=64, planes=64).add_flags_recursive(fp16=True) | |||
| inputs = [convert(shp, dtype=np.float16) for shp in shapes] | |||
| sens_shape = outshape[0] | |||
| @@ -585,7 +585,7 @@ test_case_nn_ops = [ | |||
| ('ReLUV2', { | |||
| 'block': P.ReLUV2(), | |||
| 'desc_inputs': [[1, 3, 4, 4]], | |||
| 'desc_bprop': [[1, 3, 4, 4], [1, 3, 4, 4]]}), | |||
| 'desc_bprop': [[1, 3, 4, 4], ([1, 1, 4, 4, 2], {'dtype': np.uint8})]}), | |||
| ('ReLUGrad', { | |||
| 'block': G.ReluGrad(), | |||
| 'desc_inputs': [[1, 3, 4, 4], [1, 3, 4, 4]], | |||
| @@ -626,7 +626,7 @@ test_case_nn_ops = [ | |||
| ('MaxPoolWithArgmax', { | |||
| 'block': P.MaxPoolWithArgmax(ksize=2, strides=2), | |||
| 'desc_inputs': [[128, 32, 32, 64]], | |||
| 'desc_bprop': [[128, 32, 8, 16], [128, 32, 8, 16]]}), | |||
| 'desc_bprop': [[128, 32, 16, 32], ([128, 32, 4, 33], {'dtype': np.uint16})]}), | |||
| ('SoftmaxCrossEntropyWithLogits', { | |||
| 'block': P.SoftmaxCrossEntropyWithLogits(), | |||
| 'desc_inputs': [[1, 10], [1, 10]], | |||
| @@ -639,7 +639,7 @@ test_case_nn_ops = [ | |||
| ('LogSoftmax', { | |||
| 'block': P.LogSoftmax(), | |||
| 'desc_inputs': [[64, 2]], | |||
| 'desc_bprop': [[160, 30522]]}), | |||
| 'desc_bprop': [[64, 2]]}), | |||
| ('LogSoftmaxGrad', { | |||
| 'block': G.LogSoftmaxGrad(), | |||
| 'desc_inputs': [[16, 1234], [16, 1234]], | |||
| @@ -648,7 +648,7 @@ test_case_nn_ops = [ | |||
| ('LayerNorm', { | |||
| 'block': P.LayerNorm(), | |||
| 'desc_inputs': [[2, 16], [16], [16]], | |||
| 'desc_bprop': [[2, 16], [2, 16], [2, 16]]}), | |||
| 'desc_bprop': [[2, 16], [2, 1], [2, 1]]}), | |||
| ('LayerNormGrad', { | |||
| 'block': G.LayerNormGrad(), | |||
| 'desc_inputs': [[2, 16], [2, 16], [2, 16], [2, 16], [16]], | |||
| @@ -845,7 +845,7 @@ test_case_nn_ops = [ | |||
| 'block': P.OneHot(), | |||
| 'desc_const': [3, Tensor(1.0, mstype.float32), Tensor(0.0, mstype.float32)], | |||
| 'desc_inputs': [Tensor(np.array([64]).astype(np.int32))], | |||
| 'desc_bprop': [[64, 2]]}), | |||
| 'desc_bprop': [[1, 3]]}), | |||
| ('ReduceProd_0', { | |||
| 'block': P.ReduceProd(), | |||
| 'desc_const': [0], | |||
| @@ -950,7 +950,7 @@ test_case_array_ops = [ | |||
| 'block': P.Cast(), | |||
| 'desc_const': [mstype.int32], | |||
| 'desc_inputs': [[2, 3, 4, 5]], | |||
| 'desc_bprop': [Tensor(np.ones((2, 3, 3, 5)).astype(np.int32))]}), | |||
| 'desc_bprop': [Tensor(np.ones((2, 3, 4, 5)).astype(np.int32))]}), | |||
| ('ExpandDims', { | |||
| 'block': P.ExpandDims(), | |||
| 'desc_const': [0], | |||
| @@ -1002,12 +1002,12 @@ test_case_array_ops = [ | |||
| 'desc_inputs': [ | |||
| (Tensor(np.array([[0, 1], [2, 1]]).astype(np.int32)), | |||
| Tensor(np.array([[0, 1], [2, 1]]).astype(np.int32)))], | |||
| 'desc_bprop': [[4, 2]]}), | |||
| 'desc_bprop': [([4, 2], {'dtype': np.int32})]}), | |||
| ('ConcatV2_1', { | |||
| 'block': P.Concat(axis=2), | |||
| 'desc_inputs': [(Tensor(np.array([[[0, 1, 2]], [[2, 1, 2]]]).astype(np.int32)), | |||
| Tensor(np.array([[[0, 1]], [[2, 1]]]).astype(np.int32)))], | |||
| 'desc_bprop': [[2, 1, 5]]}), | |||
| 'desc_bprop': [([2, 1, 5], {'dtype': np.int32})]}), | |||
| ('ConcatV2_2', { | |||
| 'block': NetForConcat(), | |||
| 'desc_inputs': [[2, 2]], | |||
| @@ -1042,7 +1042,7 @@ test_case_array_ops = [ | |||
| ('Pack_2', { | |||
| 'block': NetForPackInput(P.Pack()), | |||
| 'desc_inputs':[[2, 2]], | |||
| 'desc_bprop':[[2, 2, 2]], | |||
| 'desc_bprop':[[1, 2, 2]], | |||
| }), | |||
| ('Pack_3', { | |||
| 'block': NetForPackInput(P.Pack()), | |||
| @@ -1077,7 +1077,7 @@ test_case_array_ops = [ | |||
| ('SpaceToBatch_2', { | |||
| 'block': P.SpaceToBatch(2, [[1, 1], [0, 4]]), | |||
| 'desc_inputs': [[1, 3, 2, 2]], | |||
| 'desc_bprop': [[4, 3, 2, 4]], | |||
| 'desc_bprop': [[4, 3, 2, 3]], | |||
| }), | |||
| ('BatchToSpace_1', { | |||
| 'block': P.BatchToSpace(2, [[0, 0], [0, 0]]), | |||
| @@ -1124,7 +1124,7 @@ test_case_other_ops = [ | |||
| 'desc_const': [(3, 3)], | |||
| 'desc_inputs': (Tensor(np.ones((2, 2), np.int32)), | |||
| Tensor(np.ones((2,), np.int32))), | |||
| 'desc_bprop': [[3, 3]]}), | |||
| 'desc_bprop': [([3, 3], {'dtype': np.int32})]}), | |||
| ('SmoothL1Loss', { | |||
| 'block': P.SmoothL1Loss(), | |||
| 'desc_inputs': [[256, 4], [256, 4]], | |||
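Because the sens values described by desc_bprop are now checked against the op's real output types and shapes, the entries above must match exactly; where the default float32 is not correct, the test framework's (shape, {'dtype': ...}) form is used instead. A small hedged illustration of the two entry forms (the float32 default is an assumption about the framework's convert helper):

import numpy as np

plain_entry = [4, 2]                           # bare shape: a float32 sens of this shape
typed_entry = ([4, 2], {'dtype': np.int32})    # shape plus dtype override, e.g. for the
                                               # int32 ScatterNd output or the uint8 ReLUV2 mask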
| @@ -229,12 +229,6 @@ class TwoInputBprop(nn.Cell): | |||
| def bprop(self, x, y, out, dout): | |||
| return 5 * x, 8 * y | |||
| class TwoInput(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.op = P.Mul() | |||
| def construct(self, x, y): | |||
| return self.op(x, y) | |||
| class TwoInputWithParameter(nn.Cell): | |||
| def __init__(self): | |||
| @@ -301,8 +295,37 @@ class MulAddWithWrongOutputNum(nn.Cell): | |||
| def construct(self, x, y): | |||
| return 2 * x + y | |||
| def bprop(self, x, y, out, dout): | |||
| return 2 * dout, 2 * y, out | |||
| return 2 * dout, | |||
| def test_grad_mul_add_with_wrong_output_num(): | |||
| mul_add = MulAddWithWrongOutputNum() | |||
| C.grad_all(mul_add)(1, 2) | |||
| with pytest.raises(TypeError): | |||
| C.grad_all(mul_add)(1, 2) | |||
| class MulAddWithWrongOutputType(nn.Cell): | |||
| def __init__(self): | |||
| super(MulAddWithWrongOutputType, self).__init__() | |||
| def construct(self, x, y): | |||
| return 2 * x + y | |||
| def bprop(self, x, y, out, dout): | |||
| return 2 * dout, 2 | |||
| def test_grad_mul_add_with_wrong_output_type(): | |||
| mul_add = MulAddWithWrongOutputType() | |||
| with pytest.raises(TypeError): | |||
| C.grad_all(mul_add)(1, Tensor(np.ones([2, 2]))) | |||
| class MulAddWithWrongOutputShape(nn.Cell): | |||
| def __init__(self): | |||
| super(MulAddWithWrongOutputShape, self).__init__() | |||
| self.ones = Tensor(np.ones([2,])) | |||
| def construct(self, x, y): | |||
| return 2 * x + y | |||
| def bprop(self, x, y, out, dout): | |||
| return 2, self.ones | |||
| def test_grad_mul_add_with_wrong_output_shape(): | |||
| mul_add = MulAddWithWrongOutputShape() | |||
| with pytest.raises(TypeError): | |||
| C.grad_all(mul_add)(1, Tensor(np.ones([2, 2]))) | |||
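For contrast with the three failing cells above, a bprop that satisfies the new check returns one gradient per construct input, each matching the input's type and shape (a hedged sketch in the same style as these tests, relying on the same imports):

class MulAddWithRightOutput(nn.Cell):
    def __init__(self):
        super(MulAddWithRightOutput, self).__init__()

    def construct(self, x, y):
        return 2 * x + y

    def bprop(self, x, y, out, dout):
        # One entry per input, with matching type and shape.
        return 2 * dout, 1 * dout


def test_grad_mul_add_with_right_output():
    mul_add = MulAddWithRightOutput()
    assert C.grad_all(mul_add)(1, 2) == (2, 1)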
| @@ -32,6 +32,8 @@ from ....mindspore_test_framework.utils.check_gradient import ( | |||
| OperationGradChecker, check_gradient, ScalarGradChecker) | |||
| from ....mindspore_test_framework.utils.bprop_util import bprop | |||
| import mindspore.context as context | |||
| from mindspore.ops._grad.grad_base import bprop_getters | |||
| from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer | |||
| def setup_module(module): | |||
| @@ -721,3 +723,94 @@ def test_grad_if_defer_inline(): | |||
| inp = Tensor(np.ones([128, 96]).astype(np.float32)) | |||
| grads = C.grad_all(network)(inp) | |||
| assert grads == (Tensor(np.full([128, 96], 0.6, dtype=np.float32)),) | |||
| def test_bprop_with_wrong_output_num(): | |||
| class BpropWithWrongOutputNum(PrimitiveWithInfer): | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| super(BpropWithWrongOutputNum, self).__init__('BpropWithWrongOutputNum') | |||
| def __call__(self, x, y): | |||
| return x | |||
| def infer_shape(self, x_shape, yshape): | |||
| return x_shape | |||
| def infer_dtype(self, x_type, y_type): | |||
| return x_type | |||
| @bprop_getters.register(BpropWithWrongOutputNum) | |||
| def get_bprop_with_wrong_output_num(self): | |||
| """Generate bprop for BpropWithWrongOutputNum""" | |||
| def bprop(x, y, out, dout): | |||
| return (dout,) | |||
| return bprop | |||
| class BpropWithWrongOutputNumCell(nn.Cell): | |||
| def __init__(self): | |||
| super(BpropWithWrongOutputNumCell, self).__init__() | |||
| def construct(self, x, y): | |||
| return BpropWithWrongOutputNum()(x, y) | |||
| with pytest.raises(TypeError): | |||
| C.grad_all(BpropWithWrongOutputNumCell())(1, 2) | |||
| def test_bprop_with_wrong_output_type(): | |||
| class BpropWithWrongOutputType(PrimitiveWithInfer): | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| super(BpropWithWrongOutputType, self).__init__('BpropWithWrongOutputType') | |||
| def __call__(self, x): | |||
| return x | |||
| def infer_shape(self, x_shape): | |||
| return x_shape | |||
| def infer_dtype(self, x_type): | |||
| return x_type | |||
| @bprop_getters.register(BpropWithWrongOutputType) | |||
| def get_bprop_with_wrong_output_type(self): | |||
| """Generate bprop for BpropWithWrongOutputType""" | |||
| def bprop(x, out, dout): | |||
| return (1,) | |||
| return bprop | |||
| class BpropWithWrongOutputTypeCell(nn.Cell): | |||
| def __init__(self): | |||
| super(BpropWithWrongOutputTypeCell, self).__init__() | |||
| def construct(self, x): | |||
| return BpropWithWrongOutputType()(x) | |||
| with pytest.raises(TypeError): | |||
| C.grad_all(BpropWithWrongOutputTypeCell())(Tensor(np.ones([64, 10]).astype(np.int32))) | |||
| def test_bprop_with_wrong_output_shape(): | |||
| class BpropWithWrongOutputShape(PrimitiveWithInfer): | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| super(BpropWithWrongOutputShape, self).__init__('BpropWithWrongOutputShape') | |||
| def __call__(self, x): | |||
| return x | |||
| def infer_shape(self, x_shape): | |||
| return x_shape | |||
| def infer_dtype(self, x_type): | |||
| return x_type | |||
| @bprop_getters.register(BpropWithWrongOutputShape) | |||
| def get_bprop_with_wrong_output_shape(self): | |||
| """Generate bprop for BpropWithWrongOutputShape""" | |||
| ones = Tensor(np.ones([2,]).astype(np.int32)) | |||
| def bprop(x, out, dout): | |||
| return (ones,) | |||
| return bprop | |||
| class BpropWithWrongOutputShapeCell(nn.Cell): | |||
| def __init__(self): | |||
| super(BpropWithWrongOutputShapeCell, self).__init__() | |||
| def construct(self, x): | |||
| return BpropWithWrongOutputShape()(x) | |||
| with pytest.raises(TypeError): | |||
| C.grad_all(BpropWithWrongOutputShapeCell())(Tensor(np.ones([64, 10]).astype(np.int32))) | |||
| @@ -79,7 +79,7 @@ def test_InsertGradientOf_2(): | |||
| summary = P.ScalarSummary() | |||
| def debug_gradient(dx): | |||
| """ debug_gradient """ | |||
| dx = summary("dx: ", dx) | |||
| summary("dx: ", dx) | |||
| return dx | |||
| debug = P.InsertGradientOf(debug_gradient) | |||