Merge pull request !1160 from wangqiuliang/add-backward-hook-in-pynative-mode (tag: v0.3.0-alpha)
@@ -102,7 +102,10 @@ def get_parse_method_of_class(obj, parse_method=None):
         method_name = parse_method
     else:
         if isinstance(obj, nn.Cell):
-            method_name = "construct"
+            if obj.enable_hook:
+                method_name = "_hook_construct"
+            else:
+                method_name = "construct"
     if method_name is not None:
         if hasattr(obj, method_name):
             method = getattr(obj, method_name)
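Note: parsing now targets _hook_construct whenever a backward hook is registered on the cell. A minimal sketch of the dispatch above (pick_entry_method is a hypothetical helper, not part of the diff):

import mindspore.nn as nn

def pick_entry_method(obj):
    # mirrors the branch above: hooked cells are parsed from _hook_construct
    if isinstance(obj, nn.Cell):
        return "_hook_construct" if obj.enable_hook else "construct"
    return None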
@@ -115,6 +115,7 @@ REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) {
                            .def("get_attr_dict", &PrimitivePy::GetAttrDict, "get primitive attr")
                            .def("set_prim_type", &PrimitivePy::set_prim_type, "Set primitive type.")
                            .def("set_signatures", &PrimitivePy::set_signatures, "Set primitive inputs signature.")
+                           .def("register_hook", &PrimitivePy::set_hook, "Set primitive hook function.")
                            .def("set_instance_name", &PrimitivePy::set_instance_name, "Set primitive instance name.");
                        }));
}  // namespace mindspore
@@ -23,7 +23,6 @@
 #include <string>
 #include <tuple>
-#include "pybind11/pybind11.h"
 #include "pipeline/static_analysis/abstract_value.h"
 #include "utils/misc.h"
 #include "utils/log_adapter.h"
@@ -31,8 +30,6 @@
 #include "ir/signature.h"
 #include "parallel/ops_info/operator_info.h"
-
-namespace py = pybind11;
 namespace mindspore {
 class PrimitivePy : public Primitive {
  public:
@@ -24,6 +24,9 @@
 #include <tuple>
 #include "ir/dtype/type.h"
+#include "pybind11/pybind11.h"
+
+namespace py = pybind11;
 namespace mindspore {
 // Supported meta type
@@ -73,6 +76,9 @@ class Primitive : public Named {
     return iter == attrs_.cend() ? nullptr : iter->second;
   }
+
+  void set_hook(const py::function &hook) { hook_ = hook; }
+  py::function hook() const { return hook_; }
   const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
   // if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute.
@@ -103,6 +109,7 @@ class Primitive : public Named {
  private:
   std::string instance_name_;
+  py::function hook_;
   bool is_base_;
   bool has_signature_;
   PrimType prim_type_;
@@ -213,6 +213,7 @@ const PrimitivePtr kPrimRelu = std::make_shared<Primitive>("ReLU");
 const PrimitivePtr kPrimReluV2 = std::make_shared<Primitive>("ReLUV2");
 const PrimitivePtr kPrimZerosLikeTensor = std::make_shared<Primitive>("zeros_like_tensor");
 const PrimitivePtr kPrimFakeBprop = std::make_shared<Primitive>("fake_bprop");
+const PrimitivePtr kPrimBpropCut = std::make_shared<Primitive>("bprop_cut");

 // Other miscellaneous
 const PrimitivePtr kPrimIdentity = std::make_shared<Primitive>("identity");
@@ -226,6 +227,7 @@ const PrimitivePtr kPrimGetRefKey = std::make_shared<Primitive>("get_ref_key");
 const PrimitivePtr kPrimGetRefValue = std::make_shared<Primitive>("get_ref_value");
 const PrimitivePtr kPrimGetRefOrigin = std::make_shared<Primitive>("get_ref_origin");
 const PrimitivePtr kPrimInsertGradientOf = std::make_shared<Primitive>("InsertGradientOf");
+const PrimitivePtr kPrimHookBackward = std::make_shared<Primitive>("HookBackward");
 const PrimitivePtr kPrimPrintShapeType = std::make_shared<Primitive>("PrintShapeType");
 const PrimitivePtr kPrimSameTypeShape = std::make_shared<Primitive>("SameTypeShape");
 const PrimitivePtr kPrimCheckBprop = std::make_shared<Primitive>("CheckBprop");
@@ -218,6 +218,7 @@ extern const PrimitivePtr kPrimReluV2;
 extern const PrimitivePtr kPrimActivation;
 extern const PrimitivePtr kPrimZerosLikeTensor;
 extern const PrimitivePtr kPrimFakeBprop;
+extern const PrimitivePtr kPrimBpropCut;

 // Other Miscellaneous
 extern const PrimitivePtr kPrimIdentity;
@@ -232,6 +233,7 @@ extern const PrimitivePtr kPrimGetRefKey;
 extern const PrimitivePtr kPrimGetRefValue;
 extern const PrimitivePtr kPrimGetRefOrigin;
 extern const PrimitivePtr kPrimInsertGradientOf;
+extern const PrimitivePtr kPrimHookBackward;
 extern const PrimitivePtr kPrimPrintShapeType;
 extern const PrimitivePtr kPrimPrint;
 extern const PrimitivePtr kPrimSameTypeShape;
@@ -285,6 +285,16 @@ AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr
   return args_spec_list[0]->Broaden();
 }

+AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                  const AbstractBasePtrList &args_spec_list) {
+  // Inputs: the forward inputs followed by out and dout; the result is a tuple
+  // of broadened gradients, one per forward input.
+  AbstractBasePtrList args_list;
+  for (size_t i = 0; i < args_spec_list.size() - 2; i++) {
+    args_list.push_back(args_spec_list[i]->Broaden());
+  }
+  return std::make_shared<AbstractTuple>(args_list);
+}
+
 AbstractBasePtr InferImplLayerNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                    const AbstractBasePtrList &args_spec_list) {
   // Inputs: three tensors(x, gamma, beta).
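The rule encodes the bprop contract: for forward inputs (x1, ..., xk) followed by out and dout, bprop_cut produces a tuple of gradients shaped like the forward inputs. The same rule as a Python sketch (Broaden() omitted):

def infer_bprop_cut(arg_abstracts):
    # drop the trailing out and dout; one gradient per forward input
    return tuple(arg_abstracts[:-2])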
@@ -32,6 +32,7 @@
 #include "operator/ops.h"
 #include "operator/composite/composite.h"
 #include "utils/symbolic.h"
+#include "utils/context/ms_context.h"
 #include "./common.h"

 namespace mindspore {
@@ -125,6 +125,7 @@ class KPrim {
   FuncGraphPtr GetBprop(const PrimitivePtr &prim);
   FuncGraphPtr GetFprop(const PrimitivePtr &prim);
   FuncGraphPtr FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources);
+  FuncGraphPtr BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources);
   // Given a bprop rule, do the K mapping.
   template <typename T>
   FuncGraphPtr BpropToK(const T &primal, const FuncGraphPtr &bprop_g);
@@ -115,10 +115,15 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R
   }

   bool is_faked_bprop = false;
-  auto bprop_fg = GetBprop(prim);
-  if (bprop_fg == nullptr) {
-    bprop_fg = FakeBprop(value_node, resources);
-    is_faked_bprop = true;
+  FuncGraphPtr bprop_fg = nullptr;
+  if (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == "HookBackward") {
+    bprop_fg = BpropCut(value_node, resources);
+  } else {
+    bprop_fg = GetBprop(prim);
+    if (bprop_fg == nullptr) {
+      bprop_fg = FakeBprop(value_node, resources);
+      is_faked_bprop = true;
+    }
   }

   auto expanded_fg = BpropToK(prim, bprop_fg);
@@ -206,6 +211,45 @@ FuncGraphPtr KPrim::KUserDefinedCellBprop(const FuncGraphPtr bprop_fg) {
   return expanded_fg;
 }

+FuncGraphPtr KPrim::BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) {
+  auto prim = GetValueNode<PrimitivePtr>(value_node);
+  MS_EXCEPTION_IF_NULL(prim);
+  auto &node_users = resources->manager()->node_users();
+
+  auto &users = node_users[value_node];
+  auto cnode = std::find_if(users.begin(), users.end(), [&prim](const std::pair<AnfNodePtr, int> &user) -> bool {
+    return IsPrimitiveCNode(user.first, prim);
+  });
+  if (cnode == users.end()) {
+    MS_LOG(EXCEPTION) << "Fail to find cnode.";
+  }
+  auto inputs_num = cnode->first->cast<CNodePtr>()->size() - 1;
+
+  auto func_graph = std::make_shared<FuncGraph>();
+  std::vector<AnfNodePtr> outputs;
+
+  auto bprop_cut = std::make_shared<Primitive>("bprop_cut");
+  bprop_cut->set_hook(prim->hook());
+  auto cell_id = GetValue<std::string>(prim->GetAttr("cell_id"));
+  if (cell_id != "") {
+    (void)bprop_cut->AddAttr("cell_hook", MakeValue(true));
+    (void)bprop_cut->AddAttr("cell_id", MakeValue(cell_id));
+  }
+
+  outputs.push_back(NewValueNode(bprop_cut));
+  for (size_t i = 0; i < inputs_num; ++i) {
+    auto param = func_graph->add_parameter();
+    outputs.push_back(param);
+  }
+  auto p1 = func_graph->add_parameter();
+  auto p2 = func_graph->add_parameter();
+  outputs.push_back(p1);
+  outputs.push_back(p2);
+
+  func_graph->set_output(func_graph->NewCNode(outputs));
+  return func_graph;
+}
+
 FuncGraphPtr KPrim::FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) {
   auto prim = value_node->value()->cast<PrimitivePtr>();
   MS_EXCEPTION_IF_NULL(prim);
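The generated bprop graph has parameters (x1, ..., xk, out, dout) and feeds them all into one bprop_cut node that carries the Python hook. Roughly, for a single forward input (a sketch; bprop_cut here is a Python stand-in for the primitive node, not a real import):

def bprop_cut(x, out, dout):
    return (dout,)                    # gradients mirror the forward inputs

def generated_bprop(x, out, dout):    # shape of the graph built above
    return bprop_cut(x, out, dout)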
@@ -49,9 +49,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
   arithmetic_simplify_ = MakeSubstitution(ArithmeticSimplify(), "arithmetic_simplify",
                                           {prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd,
                                            prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul});
-  special_op_eliminate_ = MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate",
-                                           {prim::kPrimInsertGradientOf, prim::kPrimPrintShapeType,
-                                            prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv});
+  special_op_eliminate_ =
+    MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate",
+                     {prim::kPrimInsertGradientOf, prim::kPrimHookBackward, prim::kPrimPrintShapeType,
+                      prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv});
   zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLikeTensor);
   adjust_all_reduce_mul_add_ = MakeSubstitution(AdjustAllReduceMulAdd(), "adjust_all_reduce_mul_add", prim::kPrimAddN);
@@ -35,11 +35,13 @@ class SpecialOpEliminater {
  public:
   SpecialOpEliminater()
       : insert_gradient_of_(prim::kPrimInsertGradientOf),
+        hook_backward_(prim::kPrimHookBackward),
         print_shape_type_(prim::kPrimPrintShapeType),
         get_ref_value_(prim::kPrimGetRefValue),
         mirror_(prim::kPrimMirror),
         virtual_div_(prim::kPrimVirtualDiv) {
     eliminaters_.emplace_back(insert_gradient_of_);
+    eliminaters_.emplace_back(hook_backward_);
     eliminaters_.emplace_back(print_shape_type_);
     eliminaters_.emplace_back(get_ref_value_);
     eliminaters_.emplace_back(mirror_);
@@ -59,7 +61,7 @@ class SpecialOpEliminater {
   }

  private:
-  PrimEliminater insert_gradient_of_, print_shape_type_, get_ref_value_, mirror_, virtual_div_;
+  PrimEliminater insert_gradient_of_, hook_backward_, print_shape_type_, get_ref_value_, mirror_, virtual_div_;
   std::vector<TransformFuncType> eliminaters_{};
 };
@@ -30,6 +30,7 @@
 #include "operator/composite/composite.h"
 #include "ir/func_graph_cloner.h"
 #include "utils/symbolic.h"
+#include "utils/context/ms_context.h"
 #include "debug/trace.h"

 namespace mindspore {
@@ -207,6 +208,35 @@ bool ConvertTensor(const py::object &obj, ValuePtr *const data) {
   return true;
 }

+FuncGraphPtr ConvertToBpropCut(py::object obj) {
+  std::vector<std::string> results = data_converter::GetObjKey(obj);
+  std::string obj_key = results[0];
+  py::function bprop_func = py::getattr(obj, "bprop");
+
+  FuncGraphPtr bprop_graph = std::make_shared<FuncGraph>();
+  std::vector<AnfNodePtr> outputs;
+
+  auto fake_bprop = std::make_shared<Primitive>("bprop_cut");
+  fake_bprop->set_hook(bprop_func);
+  (void)fake_bprop->AddAttr("bprop", MakeValue(true));
+  outputs.push_back(NewValueNode(fake_bprop));
+
+  py::object code_obj = py::getattr(bprop_func, "__code__");
+  size_t inputs_num = py::cast<int>(py::getattr(code_obj, "co_argcount")) - 3;
+  for (size_t i = 0; i < inputs_num; ++i) {
+    auto param = bprop_graph->add_parameter();
+    outputs.push_back(param);
+  }
+  auto p1 = bprop_graph->add_parameter();
+  auto p2 = bprop_graph->add_parameter();
+  outputs.push_back(p1);
+  outputs.push_back(p2);
+
+  bprop_graph->set_output(bprop_graph->NewCNode(outputs));
+  data_converter::SetObjGraphValue(obj_key, bprop_graph);
+  return bprop_graph;
+}
+
 bool ConvertOtherObj(py::object obj, ValuePtr *const data) {
   auto obj_type = data_converter::GetObjType(obj);
   MS_LOG(DEBUG) << "Converting the object(" << ((std::string)py::str(obj)) << ") detail type: " << obj_type << " ";
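The co_argcount - 3 arithmetic encodes the Cell bprop signature def bprop(self, x1, ..., xk, out, dout): the forward-input count k is the argument count minus self, out and dout. For example, with the MulAdd cell from the test at the end of this diff:

import mindspore.nn as nn

class MulAdd(nn.Cell):
    def construct(self, x, y):
        return 2 * x + y

    def bprop(self, x, y, out, dout):   # co_argcount == 5
        return 3 * dout, 2 * y

assert MulAdd.bprop.__code__.co_argcount - 3 == 2   # two forward inputs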
@@ -238,7 +268,13 @@ bool ConvertOtherObj(py::object obj, ValuePtr *const data) {
     }
     // If the cell object has a user-defined bprop function, parse and record it.
     if (py::hasattr(obj, "bprop")) {
-      FuncGraphPtr bprop_graph = ConvertToFuncGraph(obj, PYTHON_MOD_GET_BPROP_METHOD);
+      FuncGraphPtr bprop_graph = nullptr;
+      bool enable_bprop_debug = py::cast<bool>(py::getattr(obj, "bprop_debug"));
+      if (enable_bprop_debug) {
+        bprop_graph = ConvertToBpropCut(obj);
+      } else {
+        bprop_graph = ConvertToFuncGraph(obj, PYTHON_MOD_GET_BPROP_METHOD);
+      }
       if (bprop_graph != nullptr) {
         (void)func_graph->transforms().insert(std::make_pair("bprop", FuncGraphTransform(bprop_graph)));
         (void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(func_graph)));
@@ -108,6 +108,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimRelu, {InferImplRelu, true}},
     {prim::kPrimZerosLikeTensor, {InferImplZerosLikeTensor, true}},
     {prim::kPrimFakeBprop, {InferImplFakeBprop, false}},
+    {prim::kPrimBpropCut, {InferImplBpropCut, true}},
     {prim::kPrimLayerNorm, {InferImplLayerNorm, true}},
     {prim::kPrimLayerNormGrad, {InferImplLayerNormGrad, true}},
     {prim::kPrimDropoutGenMask, {InferImplDropoutGenMask, true}},
@@ -210,6 +210,8 @@ AbstractBasePtr InferImplZerosLikeTensor(const AnalysisEnginePtr &, const Primit
                                          const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                    const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                  const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplLayerNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                    const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplLayerNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                        const AbstractBasePtrList &args_spec_list);
@@ -64,6 +64,9 @@ LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst) {
   result.outputs = outputs;
   result.graph_id = kInvalidGraphId;
   auto graph_id = sess_->CompileGraph(lst, outputs);
+  if (MsContext::GetInstance()->execution_mode() == kPynativeMode) {
+    sess_->BuildGraph(graph_id);
+  }
   if (MsContext::GetInstance()->precompile_only()) {
     MS_LOG(INFO) << "PrecompileOnly, stop run graph";
     return result;
@@ -40,9 +40,10 @@ using MapPrimTypeFuncGraph = std::map<PrimTypePair, FuncGraphPtr>;
 using TypedPrimitiveAbstractClosurePtr = std::shared_ptr<abstract::TypedPrimitiveAbstractClosure>;

 std::vector<PrimitivePtr> nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch,
-                                           prim::kPrimMakeTuple};
+                                           prim::kPrimMakeTuple, prim::kPrimBpropCut};
 const std::vector<PrimitivePtr> &GetMsNonlinearOps() {
-  static const std::vector<PrimitivePtr> ms_nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch};
+  static const std::vector<PrimitivePtr> ms_nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch,
+                                                             prim::kPrimBpropCut};
   return ms_nonlinear_ops;
 }
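Adding bprop_cut to the nonlinear op lists is what makes the hook reachable at runtime: graph partitioning cuts the compiled graph at every bprop_cut node, execution falls back to the VM there, and FinalVM::InstPushPrim (below) can detour into RunHook to call the Python hook before the backend graphs resume.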
@@ -646,8 +647,13 @@ BackendPtr CreateBackend() {
     auto backend = std::make_shared<MsBackend>(name, target, device_id);
     std::string device_target = MsContext::GetInstance()->device_target();
     if (device_target == kAscendDevice) {
-      backend->set_is_multi_graph_sink(true);
-      context_ptr->set_is_multi_graph_sink(true);
+      if (MsContext::GetInstance()->execution_mode() == kPynativeMode) {
+        backend->set_is_multi_graph_sink(false);
+        context_ptr->set_is_multi_graph_sink(false);
+      } else {
+        backend->set_is_multi_graph_sink(true);
+        context_ptr->set_is_multi_graph_sink(true);
+      }
     }
     return backend;
 }
@@ -587,15 +587,65 @@ void FinalVM::InstPushPrim(const VectorRef &args) {
   VectorRef tuple;
   auto prim = utils::cast<PrimitivePtr>(args[0]);
   for (size_t i = 1; i < args.size(); ++i) {
-    auto index = utils::cast<int>(args[1]);
+    auto index = utils::cast<int>(args[i]);
     tuple.push_back(Ref(index));
   }

-  auto outs = RunOperation(prim, tuple);
-  Push(outs);
+  if (prim->name() == "bprop_cut") {
+    auto outs = RunHook(prim, tuple);
+    Push(outs);
+  } else {
+    auto outs = RunOperation(prim, tuple);
+    Push(outs);
+  }

   MS_LOG(DEBUG) << "End";
 }

+BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) {
+  py::tuple py_args = py::tuple(args.size());
+  MS_LOG(DEBUG) << "input for operation:";
+  size_t i = 0;
+  for (auto &arg : args) {
+    py_args[i] = BaseRefToPyData(arg);
+    MS_LOG(DEBUG) << "arg: " << i << ":";
+    i++;
+  }
+
+  py::object obj;
+  bool is_bprop = prim->HasAttr("bprop");
+  if (is_bprop) {
+    py::function fn_bprop = prim->hook();
+    obj = fn_bprop(*py_args);
+    return obj;
+  }
+
+  bool is_cell = prim->HasAttr("cell_hook");
+  if (is_cell) {
+    std::string cell_id = GetValue<std::string>(prim->GetAttr("cell_id"));
+    if (_hook_grad.find(cell_id) != _hook_grad.end()) {
+      py::tuple hook_args = py::tuple(3);
+      hook_args[0] = cell_id;
+      hook_args[1] = _hook_grad[cell_id];
+      hook_args[2] = py_args[2];
+      py::function fn_hook = prim->hook();
+      obj = fn_hook(*hook_args);
+      if (py::isinstance<py::none>(obj)) {
+        obj = py_args[2];
+      }
+      _hook_grad.erase(cell_id);
+    } else {
+      _hook_grad[cell_id] = py_args[2];
+      obj = py_args[2];
+    }
+  } else {
+    py::function fn_hook = prim->hook();
+    obj = fn_hook(py_args[2]);
+    if (py::isinstance<py::none>(obj)) {
+      obj = py_args[2];
+    }
+  }
+  obj = py::make_tuple(obj);
+  return obj;
+}
+
 }  // namespace compile
 }  // namespace mindspore
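A Python model of RunHook's three calling conventions (a sketch under assumed simplifications; dout corresponds to py_args[2] and cache to the _hook_grad map in the C++ code above):

def run_hook_sketch(kind, hook, dout, all_args=(), cell_id=None, cache=None):
    if kind == "bprop":
        # user-defined Cell.bprop: called with all bprop_cut arguments
        return hook(*all_args)
    if kind == "cell_hook":
        # a cell hook fires twice per backward pass: the first visit caches
        # the gradient w.r.t. the cell outputs, the second visit calls the
        # user hook with both gradients and may replace the result
        if cell_id in cache:
            result = hook(cell_id, cache.pop(cell_id), dout)
            return (dout if result is None else result,)
        cache[cell_id] = dout
        return (dout,)
    # plain HookBackward on a variable: hook observes (or replaces) the gradient
    result = hook(dout)
    return (dout if result is None else result,)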
@@ -115,6 +115,7 @@ class FinalVM {
   void InstPushPrim(const VectorRef &args);
   void InstSwitchReturn(const VectorRef &args);
   void set_insts(const InstSet &value) { insts_ = value; }
+  BaseRef RunHook(const PrimitivePtr &prim, const VectorRef &args);

  protected:
   BaseRef Ref(int i);
@@ -156,6 +157,7 @@ class FinalVM {
     {Instruction::kPrim, [this](const VectorRef &args) { InstPushPrim(args); }},
     {Instruction::kSwitchReturn, [this](const VectorRef &args) { InstSwitchReturn(args); }},
   };
+  std::map<std::string, py::object> _hook_grad;
 };

 using FinalVMPtr = std::shared_ptr<FinalVM>;
@@ -24,6 +24,7 @@ from .._checkparam import _check_str_by_regular
 from ..common.parameter import Parameter, ParameterTuple
 from .._c_expression import init_backend
 from ..ops.primitive import Primitive
+from ..ops.operations import HookBackward
 from ..parallel._tensor import _load_tensor_by_layout
 from ..common.tensor import Tensor
@@ -75,6 +76,9 @@ class Cell:
         self._parallel_inputs_run = None
         if flags:
             self.add_flags(**flags)
+        self._backward_hook = None
+        self._enable_hook = False
+        self._bprop_debug = False

     @property
     def create_time(self):
@@ -91,6 +95,16 @@ class Cell:
         """
         return self._param_prefix

+    @property
+    def bprop_debug(self):
+        return self._bprop_debug
+
+    @bprop_debug.setter
+    def bprop_debug(self, value):
+        if not isinstance(value, bool):
+            raise TypeError("'bprop_debug' value must be bool type.")
+        self._bprop_debug = value
+
     def update_cell_prefix(self):
         """
         Update the all child cells' self.param_prefix.
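Setter contract in short: only a bool is accepted, and enabling it routes the cell's user-defined bprop through bprop_cut so it runs in Python. A minimal sketch (Net is a placeholder cell):

import mindspore.nn as nn

class Net(nn.Cell):
    def construct(self, x):
        return x

net = Net()
net.bprop_debug = True       # accepted
try:
    net.bprop_debug = 1      # rejected: value must be a bool
except TypeError as e:
    print(e)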
@@ -728,3 +742,25 @@ class Cell:
         self._auto_parallel_mode = True
         self.add_flags(auto_parallel=True)
         self._get_construct_inputs_number_and_name()
+
+    def _hook_construct(self, inputs):
+        """Hook construct method to replace original construct method when hook function enabled."""
+        inputs = self._backward_hook(inputs)
+        inputs = self.construct(inputs)
+        outputs = self._backward_hook(inputs)
+        return outputs
+
+    @property
+    def enable_hook(self):
+        """Whether the cell has a registered backward hook function."""
+        return self._enable_hook
+
+    def register_backward_hook(self, fn):
+        """
+        Set the cell backward hook function.
+
+        Args:
+            fn (function): Specifies the hook function with grad as input.
+        """
+        self._backward_hook = HookBackward(fn, str(id(self)))
+        self._enable_hook = True
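Usage sketch: since _hook_construct wraps both the cell's inputs and outputs, the registered function is called once per backward pass with (cell_id, grad_input, grad_output) and may return a replacement gradient (returning None keeps it unchanged), matching cell_hook_function in the test at the end of this diff:

import mindspore.nn as nn

def cell_hook(cell_id, grad_input, grad_output):
    # grad_input: gradient w.r.t. the cell's output (arrives first in backward)
    # grad_output: gradient w.r.t. the cell's input (flows out of the cell)
    print("cell:", cell_id)

conv = nn.Conv2d(6, 16, 5, pad_mode="valid")
conv.register_backward_hook(cell_hook)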
@@ -33,7 +33,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
 from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast,
                        _MirrorOperator, ReduceOp, _VirtualDataset,
                        _VirtualDiv, _GetTensorSlice)
-from .debug_ops import (ImageSummary, InsertGradientOf, ScalarSummary,
+from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
                         TensorSummary, HistogramSummary, Print)
 from .control_ops import ControlDepend, GeSwitch, Merge
 from .inner_ops import ScalarCast
@@ -155,6 +155,7 @@ __all__ = [
     'HistogramSummary',
     "Print",
     'InsertGradientOf',
+    'HookBackward',
     'InvertPermutation',
     'Shape',
     'DropoutDoMask',
@@ -14,6 +14,7 @@
 # ============================================================================

 """debug_ops"""
+from types import FunctionType
 from ..._checkparam import Validator as validator
 from ...common import dtype as mstype
 from ..primitive import prim_attr_register, PrimitiveWithInfer
@@ -241,6 +242,65 @@ class InsertGradientOf(PrimitiveWithInfer):
         return x_type


+class HookBackward(PrimitiveWithInfer):
+    """
+    Used as a tag to hook the gradient of an intermediate variable.
+
+    Note:
+        The hook function takes the gradient of the variable as its single
+        input. It is executed in the Python environment, while the callback
+        of InsertGradientOf is parsed and added to the graph.
+
+    Args:
+        hook_fn (Function): Python hook function with the gradient as input.
+
+    Inputs:
+        - **inputs** (Tensor) - The variable to hook.
+
+    Examples:
+        >>> def hook_fn(grad_out):
+        >>>     print(grad_out)
+        >>>
+        >>> hook = P.HookBackward(hook_fn)
+        >>>
+        >>> def hook_test(x, y):
+        >>>     z = x * y
+        >>>     z = hook(z)
+        >>>     z = z * y
+        >>>     return z
+        >>>
+        >>> def backward(x, y):
+        >>>     return C.grad_all(hook_test)(x, y)
+        >>>
+        >>> backward(1, 2)
+    """
+
+    def __init__(self, hook_fn, cell_id=""):
+        super(HookBackward, self).__init__(self.__class__.__name__)
+        self.add_prim_attr("cell_id", cell_id)
+        self.init_attrs["cell_id"] = cell_id
+        if not isinstance(hook_fn, FunctionType):
+            raise TypeError("Hook function should be python function type.")
+        self.register_hook(hook_fn)
+        self.cell_id = cell_id
+
+    def __call__(self, *inputs):
+        """Run in PyNative mode: act as an identity on the forward pass."""
+        if len(inputs) == 1:
+            return inputs[0]
+        return inputs
+
+    def infer_shape(self, *inputs_shape):
+        if len(inputs_shape) == 1:
+            return inputs_shape[0]
+        return inputs_shape
+
+    def infer_dtype(self, *inputs_type):
+        if len(inputs_type) == 1:
+            return inputs_type[0]
+        return inputs_type
+
+
 class Print(PrimitiveWithInfer):
     """
     Output tensor or string to stdout.
@@ -0,0 +1,133 @@
+import numpy as np
+
+import mindspore.nn as nn
+import mindspore.ops.operations as P
+from mindspore import context, Tensor, ParameterTuple
+from mindspore.ops import composite as C
+from mindspore.common.initializer import TruncatedNormal
+from mindspore.nn import WithLossCell, Momentum
+
+context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+
+
+def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
+    """weight initial for conv layer"""
+    weight = weight_variable()
+    return nn.Conv2d(in_channels, out_channels,
+                     kernel_size=kernel_size, stride=stride, padding=padding,
+                     weight_init=weight, has_bias=False, pad_mode="valid")
+
+
+def fc_with_initialize(input_channels, out_channels):
+    """weight initial for fc layer"""
+    weight = weight_variable()
+    bias = weight_variable()
+    return nn.Dense(input_channels, out_channels, weight, bias)
+
+
+def weight_variable():
+    """weight initial"""
+    return TruncatedNormal(0.02)
+
+
+def cell_hook_function(cell_id, grad_input, grad_output):
+    print(cell_id)
+    assert grad_output.asnumpy().shape == (32, 6, 14, 14)
+    assert grad_input.asnumpy().shape == (32, 16, 10, 10)
+
+
+def var_hook_function(grad_out):
+    print("grad:", grad_out)
+    assert grad_out.asnumpy().shape == (32, 120)
+
+
+class LeNet5(nn.Cell):
+    """
+    Lenet network
+
+    Args:
+        num_class (int): Num classes. Default: 10.
+
+    Returns:
+        Tensor, output tensor
+
+    Examples:
+        >>> LeNet5(num_class=10)
+    """
+    def __init__(self, num_class=10):
+        super(LeNet5, self).__init__()
+        self.num_class = num_class
+        self.batch_size = 32
+        self.conv1 = conv(1, 6, 5)
+        self.conv2 = conv(6, 16, 5)
+        self.conv2.register_backward_hook(cell_hook_function)
+        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
+        self.fc2 = fc_with_initialize(120, 84)
+        self.fc3 = fc_with_initialize(84, self.num_class)
+        self.relu = nn.ReLU()
+        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
+        self.reshape = P.Reshape()
+        self.hook = P.HookBackward(var_hook_function)
+
+    def construct(self, x):
+        x = self.conv1(x)
+        x = self.relu(x)
+        x = self.max_pool2d(x)
+        x = self.conv2(x)
+        x = self.relu(x)
+        x = self.max_pool2d(x)
+        x = self.reshape(x, (self.batch_size, -1))
+        x = self.fc1(x)
+        x = self.hook(x)
+        x = self.relu(x)
+        x = self.fc2(x)
+        x = self.relu(x)
+        x = self.fc3(x)
+        return x
+
+
+class GradWrap(nn.Cell):
+    """ GradWrap definition """
+    def __init__(self, network):
+        super(GradWrap, self).__init__(auto_prefix=False)
+        self.network = network
+        self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters()))
+
+    def construct(self, x, label):
+        weights = self.weights
+        return C.GradOperation('get_by_list', get_by_list=True)(self.network, weights)(x, label)
+
+
+def test_hook():
+    net = LeNet5()
+    optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.1, 0.9)
+    criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=False)
+    net_with_criterion = WithLossCell(net, criterion)
+    train_network = GradWrap(net_with_criterion)
+    train_network.set_train()
+
+    input_data = Tensor(np.ones([net.batch_size, 1, 32, 32]).astype(np.float32) * 0.01)
+    label = Tensor(np.ones([net.batch_size, net.num_class]).astype(np.float32))
+    output = net(input_data)
+    loss_output = criterion(output, label)
+    grads = train_network(input_data, label)
+    success = optimizer(grads)
+    print(loss_output.asnumpy().shape)
+
+
+class MulAdd(nn.Cell):
+    def __init__(self):
+        super(MulAdd, self).__init__()
+
+    def construct(self, x, y):
+        return 2 * x + y
+
+    def bprop(self, x, y, out, dout):
+        assert x == 1
+        assert y == 2
+        assert out == 4
+        assert dout == 1
+        return 3 * dout, 2 * y
+
+
+def test_custom_bprop():
+    mul_add = MulAdd()
+    mul_add.bprop_debug = True
+    assert C.grad_all(mul_add)(1, 2) == (3, 4)
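Worked check for test_custom_bprop: with x = 1 and y = 2, the forward pass gives out = 2*1 + 2 = 4 and dout = 1, so the custom bprop returns (3*dout, 2*y) = (3, 4). The second gradient is 2*y = 4 rather than the analytic d(2x + y)/dy = 1, which is precisely how the assertion proves the user-defined bprop ran through bprop_cut instead of autodiff.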