Merge pull request !1160 from wangqiuliang/add-backward-hook-in-pynative-mode
tags/v0.3.0-alpha
@@ -102,7 +102,10 @@ def get_parse_method_of_class(obj, parse_method=None):
        method_name = parse_method
    else:
        if isinstance(obj, nn.Cell):
            method_name = "construct"
            if obj.enable_hook:
                method_name = "_hook_construct"
            else:
                method_name = "construct"
    if method_name is not None:
        if hasattr(obj, method_name):
            method = getattr(obj, method_name)
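With this change, the parse entry point for a cell depends on whether a backward hook has been registered: `construct` stays the default, and `_hook_construct` is used once `enable_hook` is set. A minimal sketch of that switch, using the `nn.Cell` additions further down in this diff:

import mindspore.nn as nn

def cell_hook_function(cell_id, grad_input, grad_output):
    print(cell_id)

net = nn.ReLU()
assert not net.enable_hook                      # parser resolves "construct"
net.register_backward_hook(cell_hook_function)
assert net.enable_hook                          # parser now resolves "_hook_construct"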
@@ -115,6 +115,7 @@ REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) {
    .def("get_attr_dict", &PrimitivePy::GetAttrDict, "get primitive attr")
    .def("set_prim_type", &PrimitivePy::set_prim_type, "Set primitive type.")
    .def("set_signatures", &PrimitivePy::set_signatures, "Set primitive inputs signature.")
    .def("register_hook", &PrimitivePy::set_hook, "Set primitive hook function.")
    .def("set_instance_name", &PrimitivePy::set_instance_name, "Set primitive instance name.");
}));
}  // namespace mindspore
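The new binding exposes `PrimitivePy::set_hook` to Python as `register_hook`; this is what the `HookBackward` operator added in `debug_ops.py` below calls in its constructor to attach the user's Python callable to the primitive. A minimal sketch of that wiring:

from mindspore.ops.operations import HookBackward

def var_hook_function(grad_out):
    print("grad:", grad_out)

# HookBackward.__init__ forwards the callable into the C++ primitive via
# self.register_hook(hook_fn), i.e. PrimitivePy::set_hook.
hook = HookBackward(var_hook_function)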
@@ -23,7 +23,6 @@
#include <string>
#include <tuple>
#include "pybind11/pybind11.h"
#include "pipeline/static_analysis/abstract_value.h"
#include "utils/misc.h"
#include "utils/log_adapter.h"

@@ -31,8 +30,6 @@
#include "ir/signature.h"
#include "parallel/ops_info/operator_info.h"
namespace py = pybind11;
namespace mindspore {
class PrimitivePy : public Primitive {
 public:

@@ -24,6 +24,9 @@
#include <tuple>
#include "ir/dtype/type.h"
#include "pybind11/pybind11.h"
namespace py = pybind11;
namespace mindspore {
// Supported meta type

@@ -73,6 +76,9 @@ class Primitive : public Named {
    return iter == attrs_.cend() ? nullptr : iter->second;
  }
  void set_hook(const py::function &hook) { hook_ = hook; }
  py::function hook() const { return hook_; }
  const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
  // if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute.

@@ -103,6 +109,7 @@ class Primitive : public Named {
 private:
  std::string instance_name_;
  py::function hook_;
  bool is_base_;
  bool has_signature_;
  PrimType prim_type_;

@@ -213,6 +213,7 @@ const PrimitivePtr kPrimRelu = std::make_shared<Primitive>("ReLU");
const PrimitivePtr kPrimReluV2 = std::make_shared<Primitive>("ReLUV2");
const PrimitivePtr kPrimZerosLikeTensor = std::make_shared<Primitive>("zeros_like_tensor");
const PrimitivePtr kPrimFakeBprop = std::make_shared<Primitive>("fake_bprop");
const PrimitivePtr kPrimBpropCut = std::make_shared<Primitive>("bprop_cut");
// Other miscellaneous
const PrimitivePtr kPrimIdentity = std::make_shared<Primitive>("identity");

@@ -226,6 +227,7 @@ const PrimitivePtr kPrimGetRefKey = std::make_shared<Primitive>("get_ref_key");
const PrimitivePtr kPrimGetRefValue = std::make_shared<Primitive>("get_ref_value");
const PrimitivePtr kPrimGetRefOrigin = std::make_shared<Primitive>("get_ref_origin");
const PrimitivePtr kPrimInsertGradientOf = std::make_shared<Primitive>("InsertGradientOf");
const PrimitivePtr kPrimHookBackward = std::make_shared<Primitive>("HookBackward");
const PrimitivePtr kPrimPrintShapeType = std::make_shared<Primitive>("PrintShapeType");
const PrimitivePtr kPrimSameTypeShape = std::make_shared<Primitive>("SameTypeShape");
const PrimitivePtr kPrimCheckBprop = std::make_shared<Primitive>("CheckBprop");

@@ -218,6 +218,7 @@ extern const PrimitivePtr kPrimReluV2;
extern const PrimitivePtr kPrimActivation;
extern const PrimitivePtr kPrimZerosLikeTensor;
extern const PrimitivePtr kPrimFakeBprop;
extern const PrimitivePtr kPrimBpropCut;
// Other Miscellaneous
extern const PrimitivePtr kPrimIdentity;

@@ -232,6 +233,7 @@ extern const PrimitivePtr kPrimGetRefKey;
extern const PrimitivePtr kPrimGetRefValue;
extern const PrimitivePtr kPrimGetRefOrigin;
extern const PrimitivePtr kPrimInsertGradientOf;
extern const PrimitivePtr kPrimHookBackward;
extern const PrimitivePtr kPrimPrintShapeType;
extern const PrimitivePtr kPrimPrint;
extern const PrimitivePtr kPrimSameTypeShape;

@@ -285,6 +285,16 @@ AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr
  return args_spec_list[0]->Broaden();
}
AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                  const AbstractBasePtrList &args_spec_list) {
  // Inputs: a tensor.
  AbstractBasePtrList args_list;
  for (size_t i = 0; i < args_spec_list.size() - 2; i++) {
    args_list.push_back(args_spec_list[i]->Broaden());
  }
  return std::make_shared<AbstractTuple>(args_list);
}
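`InferImplBpropCut` broadens every input except the last two and packs the results into a tuple, which matches the bprop calling convention used throughout this change: a bprop node receives the forward inputs followed by the forward output and the incoming gradient, and must yield one gradient per forward input. A rough sketch of that convention in Python (names and placeholder values are illustrative only; the real gradients come from the registered hook at run time):

# For a call bprop_cut(x, y, out, dout), inference keeps len(args) - 2
# abstract entries, i.e. one slot per forward input.
def bprop_cut_like(x, y, out, dout):
    dx = dout   # placeholder gradient for x
    dy = dout   # placeholder gradient for y
    return (dx, dy)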
AbstractBasePtr InferImplLayerNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                   const AbstractBasePtrList &args_spec_list) {
  // Inputs: three tensors(x, gamma, beta).

@@ -32,6 +32,7 @@
#include "operator/ops.h"
#include "operator/composite/composite.h"
#include "utils/symbolic.h"
#include "utils/context/ms_context.h"
#include "./common.h"
namespace mindspore {

@@ -125,6 +125,7 @@ class KPrim {
  FuncGraphPtr GetBprop(const PrimitivePtr &prim);
  FuncGraphPtr GetFprop(const PrimitivePtr &prim);
  FuncGraphPtr FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources);
  FuncGraphPtr BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources);
  // Given a bprop rule, do the K mapping.
  template <typename T>
  FuncGraphPtr BpropToK(const T &primal, const FuncGraphPtr &bprop_g);

@@ -115,10 +115,15 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R
  }
  bool is_faked_bprop = false;
  auto bprop_fg = GetBprop(prim);
  if (bprop_fg == nullptr) {
    bprop_fg = FakeBprop(value_node, resources);
    is_faked_bprop = true;
  FuncGraphPtr bprop_fg = nullptr;
  if (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == "HookBackward") {
    bprop_fg = BpropCut(value_node, resources);
  } else {
    bprop_fg = GetBprop(prim);
    if (bprop_fg == nullptr) {
      bprop_fg = FakeBprop(value_node, resources);
      is_faked_bprop = true;
    }
  }
  auto expanded_fg = BpropToK(prim, bprop_fg);
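During the K transform, a `HookBackward` node no longer looks up a registered bprop rule; it receives a `bprop_cut` graph instead, so the user's Python hook stays outside the compiled graph and is invoked at run time. A minimal usage sketch of what ends up being differentiated, following the operator docstring added later in this diff:

import mindspore.ops.operations as P
from mindspore.ops import composite as C

def hook_fn(grad_out):
    # Runs in Python when the gradient flows back through z.
    print("grad of z:", grad_out)

hook = P.HookBackward(hook_fn)

def hook_test(x, y):
    z = x * y
    z = hook(z)      # lowered to a bprop_cut node in the backward graph
    return z * y

grads = C.grad_all(hook_test)(1, 2)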
@@ -206,6 +211,45 @@ FuncGraphPtr KPrim::KUserDefinedCellBprop(const FuncGraphPtr bprop_fg) {
  return expanded_fg;
}
FuncGraphPtr KPrim::BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) {
  auto prim = GetValueNode<PrimitivePtr>(value_node);
  MS_EXCEPTION_IF_NULL(prim);
  auto &node_users = resources->manager()->node_users();
  auto &users = node_users[value_node];
  auto cnode = std::find_if(users.begin(), users.end(), [&prim](const std::pair<AnfNodePtr, int> &user) -> bool {
    return IsPrimitiveCNode(user.first, prim);
  });
  if (cnode == users.end()) {
    MS_LOG(EXCEPTION) << "Fail to find cnode.";
  }
  auto inputs_num = cnode->first->cast<CNodePtr>()->size() - 1;
  auto func_graph = std::make_shared<FuncGraph>();
  std::vector<AnfNodePtr> outputs;
  auto bprop_cut = std::make_shared<Primitive>("bprop_cut");
  bprop_cut->set_hook(prim->hook());
  auto cell_id = GetValue<std::string>(prim->GetAttr("cell_id"));
  if (cell_id != "") {
    (void)bprop_cut->AddAttr("cell_hook", MakeValue(true));
    (void)bprop_cut->AddAttr("cell_id", MakeValue(cell_id));
  }
  outputs.push_back(NewValueNode(bprop_cut));
  for (size_t i = 0; i < inputs_num; ++i) {
    auto param = func_graph->add_parameter();
    outputs.push_back(param);
  }
  auto p1 = func_graph->add_parameter();
  auto p2 = func_graph->add_parameter();
  outputs.push_back(p1);
  outputs.push_back(p2);
  func_graph->set_output(func_graph->NewCNode(outputs));
  return func_graph;
}
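The graph built here has one parameter per forward input of the `HookBackward` call plus two trailing parameters for the forward output and the incoming gradient, and its body is a single `bprop_cut` application carrying the Python hook (and, for cell hooks, `cell_hook`/`cell_id`) as attributes. An illustrative Python stand-in for the generated graph with two forward inputs (purely a sketch; the real graph is built from ANF nodes):

def bprop_cut(*args):
    # Stand-in for the bprop_cut primitive: InferImplBpropCut gives it one
    # output slot per forward input, and FinalVM::RunHook (below) replaces
    # this call with the registered Python hook at run time.
    return args[:-2]

def generated_bprop(x, y, out, dout):
    return bprop_cut(x, y, out, dout)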
FuncGraphPtr KPrim::FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) {
  auto prim = value_node->value()->cast<PrimitivePtr>();
  MS_EXCEPTION_IF_NULL(prim);

@@ -49,9 +49,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
  arithmetic_simplify_ = MakeSubstitution(ArithmeticSimplify(), "arithmetic_simplify",
                                          {prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd,
                                           prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul});
  special_op_eliminate_ = MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate",
                                           {prim::kPrimInsertGradientOf, prim::kPrimPrintShapeType,
                                            prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv});
  special_op_eliminate_ =
    MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate",
                     {prim::kPrimInsertGradientOf, prim::kPrimHookBackward, prim::kPrimPrintShapeType,
                      prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv});
  zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLikeTensor);
  adjust_all_reduce_mul_add_ = MakeSubstitution(AdjustAllReduceMulAdd(), "adjust_all_reduce_mul_add", prim::kPrimAddN);

@@ -35,11 +35,13 @@ class SpecialOpEliminater {
 public:
  SpecialOpEliminater()
      : insert_gradient_of_(prim::kPrimInsertGradientOf),
        hook_backward_(prim::kPrimHookBackward),
        print_shape_type_(prim::kPrimPrintShapeType),
        get_ref_value_(prim::kPrimGetRefValue),
        mirror_(prim::kPrimMirror),
        virtual_div_(prim::kPrimVirtualDiv) {
    eliminaters_.emplace_back(insert_gradient_of_);
    eliminaters_.emplace_back(hook_backward_);
    eliminaters_.emplace_back(print_shape_type_);
    eliminaters_.emplace_back(get_ref_value_);
    eliminaters_.emplace_back(mirror_);

@@ -59,7 +61,7 @@ class SpecialOpEliminater {
  }
 private:
  PrimEliminater insert_gradient_of_, print_shape_type_, get_ref_value_, mirror_, virtual_div_;
  PrimEliminater insert_gradient_of_, hook_backward_, print_shape_type_, get_ref_value_, mirror_, virtual_div_;
  std::vector<TransformFuncType> eliminaters_{};
};

@@ -30,6 +30,7 @@
#include "operator/composite/composite.h"
#include "ir/func_graph_cloner.h"
#include "utils/symbolic.h"
#include "utils/context/ms_context.h"
#include "debug/trace.h"
namespace mindspore {

@@ -207,6 +208,35 @@ bool ConvertTensor(const py::object &obj, ValuePtr *const data) {
  return true;
}
FuncGraphPtr ConvertToBpropCut(py::object obj) {
  std::vector<std::string> results = data_converter::GetObjKey(obj);
  std::string obj_key = results[0];
  py::function bprop_func = py::getattr(obj, "bprop");
  FuncGraphPtr bprop_graph = std::make_shared<FuncGraph>();
  std::vector<AnfNodePtr> outputs;
  auto fake_bprop = std::make_shared<Primitive>("bprop_cut");
  fake_bprop->set_hook(bprop_func);
  (void)fake_bprop->AddAttr("bprop", MakeValue(true));
  outputs.push_back(NewValueNode(fake_bprop));
  py::object code_obj = py::getattr(bprop_func, "__code__");
  size_t inputs_num = py::cast<int>(py::getattr(code_obj, "co_argcount")) - 3;
  for (size_t i = 0; i < inputs_num; ++i) {
    auto param = bprop_graph->add_parameter();
    outputs.push_back(param);
  }
  auto p1 = bprop_graph->add_parameter();
  auto p2 = bprop_graph->add_parameter();
  outputs.push_back(p1);
  outputs.push_back(p2);
  bprop_graph->set_output(bprop_graph->NewCNode(outputs));
  data_converter::SetObjGraphValue(obj_key, bprop_graph);
  return bprop_graph;
}
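`ConvertToBpropCut` derives the number of forward inputs from the user-defined `bprop` signature: `co_argcount` minus three accounts for `self`, `out`, and `dout`. A debuggable custom bprop therefore keeps the usual signature, as in the `MulAdd` test added at the end of this change:

import mindspore.nn as nn

class MulAdd(nn.Cell):
    def __init__(self):
        super(MulAdd, self).__init__()

    def construct(self, x, y):
        return 2 * x + y

    # co_argcount is 5 (self, x, y, out, dout), so inputs_num = 5 - 3 = 2.
    def bprop(self, x, y, out, dout):
        return 3 * dout, 2 * y

mul_add = MulAdd()
mul_add.bprop_debug = True   # routes this bprop through ConvertToBpropCut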
bool ConvertOtherObj(py::object obj, ValuePtr *const data) {
  auto obj_type = data_converter::GetObjType(obj);
  MS_LOG(DEBUG) << "Converting the object(" << ((std::string)py::str(obj)) << ") detail type: " << obj_type << " ";

@@ -238,7 +268,13 @@ bool ConvertOtherObj(py::object obj, ValuePtr *const data) {
  }
  // if the cell object has specified bprop, it has user-defined bprop function parse and record it
  if (py::hasattr(obj, "bprop")) {
    FuncGraphPtr bprop_graph = ConvertToFuncGraph(obj, PYTHON_MOD_GET_BPROP_METHOD);
    FuncGraphPtr bprop_graph = nullptr;
    bool enable_bprop_debug = py::cast<bool>(py::getattr(obj, "bprop_debug"));
    if (enable_bprop_debug) {
      bprop_graph = ConvertToBpropCut(obj);
    } else {
      bprop_graph = ConvertToFuncGraph(obj, PYTHON_MOD_GET_BPROP_METHOD);
    }
    if (bprop_graph != nullptr) {
      (void)func_graph->transforms().insert(std::make_pair("bprop", FuncGraphTransform(bprop_graph)));
      (void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(func_graph)));

@@ -108,6 +108,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
    {prim::kPrimRelu, {InferImplRelu, true}},
    {prim::kPrimZerosLikeTensor, {InferImplZerosLikeTensor, true}},
    {prim::kPrimFakeBprop, {InferImplFakeBprop, false}},
    {prim::kPrimBpropCut, {InferImplBpropCut, true}},
    {prim::kPrimLayerNorm, {InferImplLayerNorm, true}},
    {prim::kPrimLayerNormGrad, {InferImplLayerNormGrad, true}},
    {prim::kPrimDropoutGenMask, {InferImplDropoutGenMask, true}},

@@ -210,6 +210,8 @@ AbstractBasePtr InferImplZerosLikeTensor(const AnalysisEnginePtr &, const Primit
                                         const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                   const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                  const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplLayerNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                   const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplLayerNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                       const AbstractBasePtrList &args_spec_list);

@@ -64,6 +64,9 @@ LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst) {
  result.outputs = outputs;
  result.graph_id = kInvalidGraphId;
  auto graph_id = sess_->CompileGraph(lst, outputs);
  if (MsContext::GetInstance()->execution_mode() == kPynativeMode) {
    sess_->BuildGraph(graph_id);
  }
  if (MsContext::GetInstance()->precompile_only()) {
    MS_LOG(INFO) << "PrecompileOnly, stop run graph";
    return result;

@@ -40,9 +40,10 @@ using MapPrimTypeFuncGraph = std::map<PrimTypePair, FuncGraphPtr>;
using TypedPrimitiveAbstractClosurePtr = std::shared_ptr<abstract::TypedPrimitiveAbstractClosure>;
std::vector<PrimitivePtr> nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch,
                                           prim::kPrimMakeTuple};
                                           prim::kPrimMakeTuple, prim::kPrimBpropCut};
const std::vector<PrimitivePtr> &GetMsNonlinearOps() {
  static const std::vector<PrimitivePtr> ms_nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch};
  static const std::vector<PrimitivePtr> ms_nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch,
                                                             prim::kPrimBpropCut};
  return ms_nonlinear_ops;
}

@@ -646,8 +647,13 @@ BackendPtr CreateBackend() {
    auto backend = std::make_shared<MsBackend>(name, target, device_id);
    std::string device_target = MsContext::GetInstance()->device_target();
    if (device_target == kAscendDevice) {
      backend->set_is_multi_graph_sink(true);
      context_ptr->set_is_multi_graph_sink(true);
      if (MsContext::GetInstance()->execution_mode() == kPynativeMode) {
        backend->set_is_multi_graph_sink(false);
        context_ptr->set_is_multi_graph_sink(false);
      } else {
        backend->set_is_multi_graph_sink(true);
        context_ptr->set_is_multi_graph_sink(true);
      }
    }
    return backend;
  }
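Two backend adjustments accompany the hook support: in PyNative mode `MsConvert` builds the graph right after compiling it, and `CreateBackend` keeps multi-graph sink disabled on Ascend, presumably so control can return to the Python side where the hook callbacks run. The execution mode itself is selected from Python:

from mindspore import context

# PyNative (eager) execution; with this change, multi-graph sink stays off in
# this mode even on Ascend.
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")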
@@ -587,15 +587,65 @@ void FinalVM::InstPushPrim(const VectorRef &args) {
  VectorRef tuple;
  auto prim = utils::cast<PrimitivePtr>(args[0]);
  for (size_t i = 1; i < args.size(); ++i) {
    auto index = utils::cast<int>(args[1]);
    auto index = utils::cast<int>(args[i]);
    tuple.push_back(Ref(index));
  }
  auto outs = RunOperation(prim, tuple);
  Push(outs);
  if (prim->name() == "bprop_cut") {
    auto outs = RunHook(prim, tuple);
    Push(outs);
  } else {
    auto outs = RunOperation(prim, tuple);
    Push(outs);
  }
  MS_LOG(DEBUG) << "End";
}
BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) {
  py::tuple py_args = py::tuple(args.size());
  MS_LOG(DEBUG) << "input for operation:";
  size_t i = 0;
  for (auto &arg : args) {
    py_args[i] = BaseRefToPyData(arg);
    MS_LOG(DEBUG) << "arg: " << i << ":";
    i++;
  }
  py::object obj;
  bool is_bprop = prim->HasAttr("bprop");
  if (is_bprop) {
    py::function fn_bprop = prim->hook();
    obj = fn_bprop(*py_args);
    return obj;
  }
  bool is_cell = prim->HasAttr("cell_hook");
  if (is_cell) {
    std::string cell_id = GetValue<std::string>(prim->GetAttr("cell_id"));
    if (_hook_grad.find(cell_id) != _hook_grad.end()) {
      py::tuple hook_args = py::tuple(3);
      hook_args[0] = cell_id;
      hook_args[1] = _hook_grad[cell_id];
      hook_args[2] = py_args[2];
      py::function fn_hook = prim->hook();
      obj = fn_hook(*hook_args);
      if (py::isinstance<py::none>(obj)) {
        obj = py_args[2];
      }
      _hook_grad.erase(cell_id);
    } else {
      _hook_grad[cell_id] = py_args[2];
      obj = py_args[2];
    }
  } else {
    py::function fn_hook = prim->hook();
    obj = fn_hook(py_args[2]);
    if (py::isinstance<py::none>(obj)) {
      obj = py_args[2];
    }
  }
  obj = py::make_tuple(obj);
  return obj;
}
}  // namespace compile
}  // namespace mindspore
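`RunHook` dispatches on the attributes carried by the `bprop_cut` primitive, which yields three Python callback shapes. Summarized as Python signatures (the function names are illustrative; only the argument lists are dictated by this code):

# 1. "bprop" attribute (Cell.bprop with bprop_debug=True): called with all
#    forwarded values and must itself return the full gradient tuple.
def bprop(self, x, y, out, dout):
    return 3 * dout, 2 * y

# 2. "cell_hook" attribute (Cell.register_backward_hook): the first bprop_cut
#    seen for a cell_id caches its gradient; the matching second call invokes
#    the hook with (cell_id, cached_grad, current_grad). Returning None keeps
#    the gradient unchanged.
def cell_hook_function(cell_id, grad_input, grad_output):
    print(cell_id, grad_input, grad_output)

# 3. No extra attribute (P.HookBackward on a variable): called with the single
#    gradient; returning None keeps the original gradient.
def var_hook_function(grad_out):
    print("grad:", grad_out)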
@@ -115,6 +115,7 @@ class FinalVM {
  void InstPushPrim(const VectorRef &args);
  void InstSwitchReturn(const VectorRef &args);
  void set_insts(const InstSet &value) { insts_ = value; }
  BaseRef RunHook(const PrimitivePtr &prim, const VectorRef &args);

 protected:
  BaseRef Ref(int i);

@@ -156,6 +157,7 @@ class FinalVM {
    {Instruction::kPrim, [this](const VectorRef &args) { InstPushPrim(args); }},
    {Instruction::kSwitchReturn, [this](const VectorRef &args) { InstSwitchReturn(args); }},
  };
  std::map<std::string, py::object> _hook_grad;
};
using FinalVMPtr = std::shared_ptr<FinalVM>;

@@ -24,6 +24,7 @@ from .._checkparam import _check_str_by_regular
from ..common.parameter import Parameter, ParameterTuple
from .._c_expression import init_backend
from ..ops.primitive import Primitive
from ..ops.operations import HookBackward
from ..parallel._tensor import _load_tensor_by_layout
from ..common.tensor import Tensor

@@ -75,6 +76,9 @@ class Cell:
        self._parallel_inputs_run = None
        if flags:
            self.add_flags(**flags)
        self._backward_hook = None
        self._enable_hook = False
        self._bprop_debug = False

    @property
    def create_time(self):

@@ -91,6 +95,16 @@
        """
        return self._param_prefix

    @property
    def bprop_debug(self):
        return self._bprop_debug

    @bprop_debug.setter
    def bprop_debug(self, value):
        if not isinstance(value, bool):
            raise TypeError("'bprop debug' value must be bool type.")
        self._bprop_debug = value

    def update_cell_prefix(self):
        """
        Update the all child cells' self.param_prefix.

@@ -728,3 +742,25 @@ class Cell:
        self._auto_parallel_mode = True
        self.add_flags(auto_parallel=True)
        self._get_construct_inputs_number_and_name()

    def _hook_construct(self, inputs):
        """Hook construct method to replace original construct method when hook function enabled."""
        inputs = self._backward_hook(inputs)
        inputs = self.construct(inputs)
        outputs = self._backward_hook(inputs)
        return outputs

    @property
    def enable_hook(self):
        """Whether the cell register hook function"""
        return self._enable_hook

    def register_backward_hook(self, fn):
        """
        Set the cell backward hook function.

        Args:
            fn (function): Specifies the hook function with grad as input.
        """
        self._backward_hook = HookBackward(fn, str(id(self)))
        self._enable_hook = True
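On the Python side, `register_backward_hook` wraps the cell's computation between two applications of one `HookBackward` primitive keyed by `str(id(self))`: conceptually, `_hook_construct(x)` evaluates `hook(construct(hook(x)))`, and `RunHook` later pairs the two resulting `bprop_cut` gradients by that cell id. Minimal usage, matching the LeNet test below:

import mindspore.nn as nn

def cell_hook_function(cell_id, grad_input, grad_output):
    # Receives the gradients recorded at the two HookBackward markers that
    # _hook_construct places around the cell's computation.
    print(cell_id, grad_input, grad_output)

conv = nn.Conv2d(6, 16, 5, pad_mode="valid")
conv.register_backward_hook(cell_hook_function)   # also sets enable_hook = True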
@@ -33,7 +33,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast,
                       _MirrorOperator, ReduceOp, _VirtualDataset,
                       _VirtualDiv, _GetTensorSlice)
from .debug_ops import (ImageSummary, InsertGradientOf, ScalarSummary,
from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
                        TensorSummary, HistogramSummary, Print)
from .control_ops import ControlDepend, GeSwitch, Merge
from .inner_ops import ScalarCast

@@ -155,6 +155,7 @@ __all__ = [
    'HistogramSummary',
    "Print",
    'InsertGradientOf',
    'HookBackward',
    'InvertPermutation',
    'Shape',
    'DropoutDoMask',

@@ -14,6 +14,7 @@
# ============================================================================
"""debug_ops"""
from types import FunctionType
from ..._checkparam import Validator as validator
from ...common import dtype as mstype
from ..primitive import prim_attr_register, PrimitiveWithInfer

@@ -241,6 +242,65 @@ class InsertGradientOf(PrimitiveWithInfer):
        return x_type


class HookBackward(PrimitiveWithInfer):
    """
    Used as tag to hook gradient in intermediate variables.

    Note:
        The hook function should have one input of gradient of the variable.
        hook function will be executed in python environment, while callback
        of InsertGradientOf will be parsed and added to the graph.

    Args:
        hook_fn (Function): Python function. hook function.

    Inputs:
        - **inputs** (Tensor) - The variable to hook.

    Examples:
        >>> def hook_fn(grad_out):
        >>>     print(grad_out)
        >>>
        >>> hook = P.HookBackward(hook_fn)
        >>>
        >>> def hook_test(x, y):
        >>>     z = x * y
        >>>     z = hook(z)
        >>>     z = z * y
        >>>     return z
        >>>
        >>> def backward(x, y):
        >>>     return C.grad_all(hook_test)(x, y)
        >>>
        >>> backward(1, 2)
    """

    def __init__(self, hook_fn, cell_id=""):
        super(HookBackward, self).__init__(self.__class__.__name__)
        self.add_prim_attr("cell_id", cell_id)
        self.init_attrs["cell_id"] = cell_id
        if not isinstance(hook_fn, FunctionType):
            raise TypeError("Hook function should be python function type.")
        self.register_hook(hook_fn)
        self.cell_id = cell_id

    def __call__(self, *inputs):
        """run in PyNative mode."""
        if len(inputs) == 1:
            return inputs[0]
        return inputs

    def infer_shape(self, *inputs_shape):
        if len(inputs_shape) == 1:
            return inputs_shape[0]
        return inputs_shape

    def infer_dtype(self, *inputs_type):
        if len(inputs_type) == 1:
            return inputs_type[0]
        return inputs_type
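In PyNative mode `HookBackward.__call__` is a pure passthrough of its inputs; the registered function only runs when gradients later flow back through the marked value. A small sanity sketch:

import numpy as np
import mindspore.ops.operations as P
from mindspore import Tensor

def var_hook_function(grad_out):
    print("grad:", grad_out)

hook = P.HookBackward(var_hook_function)
x = Tensor(np.ones([2, 2]).astype(np.float32))
y = hook(x)   # forward pass: identity, y is just x
# var_hook_function fires only when a gradient flows back through y.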
class Print(PrimitiveWithInfer):
    """
    Output tensor or string to stdout.

@@ -0,0 +1,133 @@
import numpy as np
import mindspore.nn as nn
import mindspore.ops.operations as P
from mindspore import context
from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
from mindspore import context, Tensor, ParameterTuple
from mindspore.common.initializer import TruncatedNormal
from mindspore.nn import Dense, WithLossCell, SoftmaxCrossEntropyWithLogits, Momentum

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")


def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """weight initial for conv layer"""
    weight = weight_variable()
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=kernel_size, stride=stride, padding=padding,
                     weight_init=weight, has_bias=False, pad_mode="valid")


def fc_with_initialize(input_channels, out_channels):
    """weight initial for fc layer"""
    weight = weight_variable()
    bias = weight_variable()
    return nn.Dense(input_channels, out_channels, weight, bias)


def weight_variable():
    """weight initial"""
    return TruncatedNormal(0.02)


def cell_hook_function(cell_id, grad_input, grad_output):
    print(cell_id)
    assert(grad_output.asnumpy().shape == (32, 6, 14, 14))
    assert(grad_input.asnumpy().shape == (32, 16, 10, 10))


def var_hook_function(grad_out):
    print("grad:", grad_out)
    assert(grad_out.asnumpy().shape == (32, 120))


class LeNet5(nn.Cell):
    """
    Lenet network
    Args:
        num_class (int): Num classes. Default: 10.
    Returns:
        Tensor, output tensor
    Examples:
        >>> LeNet(num_class=10)
    """
    def __init__(self, num_class=10):
        super(LeNet5, self).__init__()
        self.num_class = num_class
        self.batch_size = 32
        self.conv1 = conv(1, 6, 5)
        self.conv2 = conv(6, 16, 5)
        self.conv2.register_backward_hook(cell_hook_function)
        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
        self.fc2 = fc_with_initialize(120, 84)
        self.fc3 = fc_with_initialize(84, self.num_class)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.reshape = P.Reshape()
        self.hook = P.HookBackward(var_hook_function)

    def construct(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.reshape(x, (self.batch_size, -1))
        x = self.fc1(x)
        x = self.hook(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x


class GradWrap(nn.Cell):
    """ GradWrap definition """
    def __init__(self, network):
        super(GradWrap, self).__init__(auto_prefix=False)
        self.network = network
        self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters()))

    def construct(self, x, label):
        weights = self.weights
        return C.GradOperation('get_by_list', get_by_list=True)(self.network, weights)(x, label)


def test_hook():
    net = LeNet5()
    optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.1, 0.9)
    criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=False)
    net_with_criterion = WithLossCell(net, criterion)
    train_network = GradWrap(net_with_criterion)
    train_network.set_train()
    input_data = Tensor(np.ones([net.batch_size, 1, 32, 32]).astype(np.float32) * 0.01)
    label = Tensor(np.ones([net.batch_size, net.num_class]).astype(np.float32))
    output = net(Tensor(input_data))
    loss_output = criterion(output, label)
    grads = train_network(input_data, label)
    success = optimizer(grads)
    print(loss_output.asnumpy().shape)


class MulAdd(nn.Cell):
    def __init__(self):
        super(MulAdd, self).__init__()

    def construct(self, x, y):
        return 2 * x + y

    def bprop(self, x, y, out, dout):
        assert(x == 1)
        assert(y == 2)
        assert(out == 4)
        assert(dout == 1)
        return 3 * dout, 2 * y


def test_custom_bprop():
    mul_add = MulAdd()
    mul_add.bprop_debug = True
    assert C.grad_all(mul_add)(1, 2) == (3, 4)