Merge pull request !2148 from wangqiuliang/fix-hook-bprop-issue (tag: v0.5.0-beta)
@@ -113,6 +113,24 @@ def bool_or(x, y):
    """Implement `bool_or`."""
    return x or y
def vm_compare(*args):
    """Implement `vm_compare` for tensor."""
    obj_str = args[-1]
    if obj_str == "shape":
        fn = getattr(args[0].asnumpy(), obj_str)
        return fn
    if len(args) == 2:
        fn = getattr(args[0].asnumpy(), obj_str)
        return Tensor(fn())
    if isinstance(args[0], Tensor):
        fn = getattr(args[0].asnumpy(), obj_str)
        y = args[1].asnumpy() if isinstance(args[1], Tensor) else args[1]
    else:
        obj_str = "__r" + obj_str[2:]
        fn = getattr(args[1].asnumpy(), obj_str)
        y = args[0]
    return Tensor(np.array(fn(y)))
def make_list(*xs):
    """Implement `make_list`."""
@@ -41,6 +41,35 @@ using TensorPtr = mindspore::tensor::TensorPtr;
using MetaTensor = mindspore::tensor::MetaTensor;
using MetaTensorPtr = mindspore::tensor::MetaTensorPtr;
FuncGraphPtr ConvertToBpropCut(const py::object &obj) {
  std::vector<std::string> results = data_converter::GetObjKey(obj);
  std::string obj_key = results[0];
  py::function bprop_func = py::getattr(obj, CUSTOM_BPROP_NAME);
  auto bprop_graph = std::make_shared<FuncGraph>();
  std::vector<AnfNodePtr> outputs;
  auto fake_bprop = std::make_shared<PrimitivePy>("bprop_cut", py::object());
  fake_bprop->set_hook(bprop_func);
  (void)fake_bprop->AddAttr(CUSTOM_BPROP_NAME, MakeValue(true));
  outputs.push_back(NewValueNode(fake_bprop));
  py::object code_obj = py::getattr(bprop_func, "__code__");
  size_t inputs_num = py::cast<int>(py::getattr(code_obj, "co_argcount")) - 3;
  for (size_t i = 0; i < inputs_num; ++i) {
    auto param = bprop_graph->add_parameter();
    outputs.push_back(param);
  }
  auto p1 = bprop_graph->add_parameter();
  auto p2 = bprop_graph->add_parameter();
  outputs.push_back(p1);
  outputs.push_back(p2);
  bprop_graph->set_output(bprop_graph->NewCNode(outputs));
  data_converter::SetObjGraphValue(obj_key, bprop_graph);
  return bprop_graph;
}
namespace {
bool ConvertTuple(const py::object &obj, ValuePtr *const data, bool use_signature) {
  MS_LOG(DEBUG) << "Converting python tuple";
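`ConvertToBpropCut` (moved above the anonymous namespace so the PyNative executor can call it too) wraps the cell's Python `bprop` in a `bprop_cut` primitive hook: `co_argcount - 3` strips `self`, `out` and `dout`, leaving one graph parameter per forward input, and the extra parameters `p1`/`p2` carry `out` and `dout`. The cell-side contract it assumes looks like this (hypothetical cell, mirroring the `Block` added to the hook test at the end of this diff):

```python
import mindspore.nn as nn
from mindspore import Tensor

class MulByTwo(nn.Cell):                 # hypothetical example cell
    def construct(self, x):
        return x * 2

    def bprop(self, x, out, dout):       # self + 1 forward input + out + dout -> co_argcount == 4
        # must return a tuple with one gradient per forward input
        return (Tensor(dout.asnumpy() * 2),)
```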
@@ -231,35 +260,6 @@ bool ConvertSlice(const py::object &obj, ValuePtr *const data) {
  return true;
}
FuncGraphPtr ConvertToBpropCut(py::object obj) {
  std::vector<std::string> results = data_converter::GetObjKey(obj);
  std::string obj_key = results[0];
  py::function bprop_func = py::getattr(obj, "bprop");
  FuncGraphPtr bprop_graph = std::make_shared<FuncGraph>();
  std::vector<AnfNodePtr> outputs;
  auto fake_bprop = std::make_shared<PrimitivePy>("bprop_cut", py::object());
  fake_bprop->set_hook(bprop_func);
  (void)fake_bprop->AddAttr("bprop", MakeValue(true));
  outputs.push_back(NewValueNode(fake_bprop));
  py::object code_obj = py::getattr(bprop_func, "__code__");
  size_t inputs_num = py::cast<int>(py::getattr(code_obj, "co_argcount")) - 3;
  for (size_t i = 0; i < inputs_num; ++i) {
    auto param = bprop_graph->add_parameter();
    outputs.push_back(param);
  }
  auto p1 = bprop_graph->add_parameter();
  auto p2 = bprop_graph->add_parameter();
  outputs.push_back(p1);
  outputs.push_back(p2);
  bprop_graph->set_output(bprop_graph->NewCNode(outputs));
  data_converter::SetObjGraphValue(obj_key, bprop_graph);
  return bprop_graph;
}
bool ConvertCellObjToFuncGraph(py::object obj, ValuePtr *const data) {
  FuncGraphPtr func_graph = ConvertToFuncGraph(obj);
  if (func_graph == nullptr) {
@@ -267,7 +267,7 @@ bool ConvertCellObjToFuncGraph(py::object obj, ValuePtr *const data) {
    return false;
  }
  // if the cell object has specified bprop, it has user-defined bprop function parse and record it
  if (py::hasattr(obj, "bprop")) {
  if (py::hasattr(obj, CUSTOM_BPROP_NAME)) {
    FuncGraphPtr bprop_graph = nullptr;
    bool enable_bprop_debug = py::cast<bool>(py::getattr(obj, "bprop_debug"));
    if (enable_bprop_debug) {
@@ -276,7 +276,7 @@ bool ConvertCellObjToFuncGraph(py::object obj, ValuePtr *const data) {
      bprop_graph = ConvertToFuncGraph(obj, PYTHON_MOD_GET_BPROP_METHOD);
    }
    if (bprop_graph != nullptr) {
      (void)func_graph->transforms().insert(std::make_pair("bprop", FuncGraphTransform(bprop_graph)));
      (void)func_graph->transforms().insert(std::make_pair(CUSTOM_BPROP_NAME, FuncGraphTransform(bprop_graph)));
      (void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(func_graph)));
      func_graph->set_flags(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
    }
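For reference, `bprop_debug` is what selects between the two branches above: when it is true the user bprop stays a Python hook (`ConvertToBpropCut`), otherwise it is parsed into a graph. A hedged sketch, assuming the flag is writable on the cell exactly as the `getattr` here implies:

```python
# Hedged sketch; assumes `bprop_debug` is writable on Cell, matching the attribute read above.
net = MulByTwo()           # the hypothetical cell from the earlier sketch
net.bprop_debug = True     # keep the user bprop as a bprop_cut hook even in graph mode
```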
@@ -51,6 +51,7 @@ void ClearObjectCache();
}  // namespace data_converter
ClassPtr ParseDataClass(const py::object &cls_obj);
FuncGraphPtr ConvertToBpropCut(const py::object &obj);
void CleanDataClassToClassMap();
@@ -109,6 +109,7 @@ const char PYTHON_EXTERN_MINDSPORE_FLAG[] = "_mindspore_flags";
// define the parse constant
const int MAX_COMPARISON_OPS_SUPPORTED = 1;
const char CUSTOM_BPROP_NAME[] = "bprop";
// define the Namespace name
const char RESOLVE_NAMESPACE_NAME_AST[] = "Ast";  // for ast type namespace
@@ -45,7 +45,7 @@ enum PynativeStatusCode {
  PYNATIVE_UNKNOWN_STATE = 0XFF
};
enum RunOpArgsEnum { PY_PRIM = 0, PY_NAME, PY_INPUTS, PY_INPUT_MASK, PY_ARGS_NUM };
enum RunOpArgsEnum { PY_PRIM = 0, PY_NAME, PY_INPUTS, PY_ARGS_NUM };
struct OpExecInfo {
  PrimitivePyPtr py_primitive;
@@ -110,9 +110,15 @@ py::object GetTupleObj(const py::object &obj) {
  return obj_tuple;
}
void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *out_args) {
py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *out_args) {
  auto &py_args = *out_args;
  py::tuple input_mask(args.size());
  for (size_t i = 0; i < args.size(); ++i) {
    if (py::hasattr(args[i], "__parameter__")) {
      input_mask[i] = true;
    } else {
      input_mask[i] = false;
    }
    py_args[i] = GetTupleObj(args[i]);
  }
  auto signature = prim->signatures();
@@ -121,7 +127,7 @@ void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *
                 [](const Signature &sig) { return sig.dtype; });
  int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue);
  if (dtypes.size() == 0 || static_cast<int>(dtypes.size()) == empty_dtype_count) {
    return;
    return input_mask;
  }
  std::map<SignatureEnumDType, std::vector<size_t>> type_indexs;
  for (size_t i = 0; i < dtypes.size(); ++i) {
@@ -160,6 +166,7 @@ void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *
      continue;
    }
  }
  return input_mask;
}
void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecInfo *const op_exec_info) {
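`ConvertInputs` now derives the parameter mask itself and returns it, so the Python caller no longer passes one in. The mask it computes is the moral equivalent of the code removed from `_run_op` further down in this diff:

```python
# Python rendering of the mask ConvertInputs now builds: one boolean per input,
# true when the object carries the `__parameter__` attribute (as mindspore.Parameter does).
def compute_input_mask(args):
    return tuple(hasattr(arg, "__parameter__") for arg in args)
```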
@@ -167,7 +174,7 @@ void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecIn
  AbstractBasePtrList args_spec_list;
  for (size_t i = 0; i < size; i++) {
    ValuePtr input_value = PyAttrValue(py_args[i]);
    if (input_value->isa<tensor::Tensor>()) {
    if (!py::hasattr(prim->GetPyObj(), "const_value") && input_value->isa<tensor::Tensor>()) {
      args_spec_list.emplace_back(abstract::FromValueInside(input_value, true));
    } else {
      args_spec_list.emplace_back(abstract::FromValueInside(input_value, false));
@@ -179,7 +186,7 @@ void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecIn
OpExecInfoPtr GenerateOpExecInfo(const py::args &args) {
  if (args.size() != PY_ARGS_NUM) {
    MS_LOG(ERROR) << "Four args are needed by RunOp";
    MS_LOG(ERROR) << "Three args are needed by RunOp";
    return nullptr;
  }
  auto op_exec_info = std::make_shared<OpExecInfo>();
@@ -195,14 +202,13 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args &args) {
  size_t input_num = a.size();
  op_exec_info->op_inputs = py::tuple(input_num);
  ConvertInputs(prim, args[PY_INPUTS], &op_exec_info->op_inputs);
  op_exec_info->inputs_mask = ConvertInputs(prim, args[PY_INPUTS], &op_exec_info->op_inputs);
  // use python infer method
  if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) {
    PynativeInfer(prim, op_exec_info->op_inputs, op_exec_info.get());
  }
  op_exec_info->py_primitive = prim;
  op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs");
  op_exec_info->inputs_mask = args[PY_INPUT_MASK];
  if (op_exec_info->op_inputs.size() != op_exec_info->inputs_mask.size()) {
    MS_LOG(ERROR) << "Op:" << op_exec_info->op_name << " inputs size not equal op_mask";
    return nullptr;
@@ -488,14 +494,14 @@ py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecIn
  return result;
}
AnfNodePtr PynativeExecutor::MakeCNode(const py::args &args, const py::tuple &out) {
AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, const py::args &args, const py::tuple &out) {
  if (!grad_flag_ || graph_info_map_.size() == 0) {
    return nullptr;
  }
  std::vector<AnfNodePtr> inputs;
  auto prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]);
  auto prim = op_exec_info->py_primitive;
  inputs.push_back(NewValueNode(prim));
  py::tuple op_masks = args[PY_INPUT_MASK];
  py::tuple op_masks = op_exec_info->inputs_mask;
  py::list op_args = args[PY_INPUTS];
  AbstractBasePtrList args_spec_list;
  for (size_t i = 0; i < op_args.size(); i++) {
@@ -584,7 +590,7 @@ py::tuple RunOp(const py::args &args) {
    return err_ret;
  }
  auto node = PynativeExecutor::GetInstance()->MakeCNode(args, result);
  auto node = PynativeExecutor::GetInstance()->MakeCNode(op_exec_info, args, result);
  if (node != nullptr) {
    node->set_abstract(op_exec_info->abstract);
    MS_LOG(DEBUG) << "RunOp MakeCnode,new node is: " << node->DebugString();
@@ -705,7 +711,7 @@ void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, c
  }
  cell_graph_map_[cell_id] = curr_g_;
  auto out_id = GetId(out);
  if (!graph_info_map_[curr_g_].obj_node_map.count(out_id)) {
  if (!graph_info_map_[curr_g_].obj_node_map.count(out_id) && !graph_info_map_[curr_g_].param_map.count(out_id)) {
    // cell construct return x, y
    if (py::isinstance<py::tuple>(out)) {
      std::vector<AnfNodePtr> args;
@@ -727,12 +733,26 @@ void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, c
    }
  }
  auto output_node = GetObjNode(out);
  AnfNodePtr output_node;
  if (graph_info_map_[curr_g_].param_map.count(out_id)) {
    output_node = graph_info_map_[curr_g_].param_map[out_id];
  } else {
    output_node = GetObjNode(out);
  }
  curr_g_->set_output(output_node);
  std::vector<AnfNodePtr> inputs;
  inputs.push_back(NewValueNode(curr_g_));
  MS_LOG(DEBUG) << "Current graph" << curr_g_->output()->DebugString();
  resource_->manager()->AddFuncGraph(curr_g_);
  // custom bprop debug
  if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
    MS_LOG(DEBUG) << "Use cell custom bprop function.";
    FuncGraphPtr bprop_graph = parse::ConvertToBpropCut(cell);
    if (bprop_graph != nullptr) {
      (void)curr_g_->transforms().insert(std::make_pair(parse::CUSTOM_BPROP_NAME, FuncGraphTransform(bprop_graph)));
      (void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(curr_g_)));
    }
  }
  auto newfg = ad::Grad(curr_g_, resource_, curr_g_ == top_g_);
  if (curr_g_ != top_g_) {
    Popp();
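This is the PyNative-side counterpart of the graph-mode handling above: when the cell being closed defines `bprop`, the bprop_cut graph is attached as a transform before `ad::Grad` runs, so custom bprops are honored during PyNative differentiation. A minimal sketch of the user-facing effect (it mirrors the `Block` cell added to the hook test below and needs a working PyNative backend to actually run):

```python
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.ops import composite as C

context.set_context(mode=context.PYNATIVE_MODE)

class Block(nn.Cell):
    def __init__(self):
        super(Block, self).__init__()
        self.relu = nn.ReLU()

    def construct(self, x):
        return self.relu(x)

    def bprop(self, x, out, dout):
        # user-defined gradient, executed as a bprop_cut hook during PyNative autodiff
        return (Tensor(out.asnumpy() * dout.asnumpy()),)

x = Tensor(np.ones((2, 3)).astype(np.float32))
grad_all = C.GradOperation('get_all', get_all=True)   # constructor signature as used in this codebase
dx = grad_all(Block())(x)                              # dx comes from the custom bprop above
```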
@@ -44,7 +44,7 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
py::tuple RunOp(const py::args &args);
void ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args, py::tuple *out_args);
py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args, py::tuple *out_args);
void ClearPyNativeSession();
@@ -83,7 +83,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
  void set_obj_node_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node, int index) {
    graph_info_map_[g].obj_node_map[obj] = std::make_pair(node, index);
  }
  AnfNodePtr MakeCNode(const py::args &args, const py::tuple &out);
  AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, const py::args &args, const py::tuple &out);
  py::object Run(const py::tuple &args, const py::object &phase);
  void Pushp();
@@ -16,6 +16,7 @@
"""Registry the relation."""
from collections import UserDict
from .. import context
class Registry(UserDict):
@@ -27,9 +28,16 @@ class Registry(UserDict):
    def get(self, obj_str):
        """Get the value by str."""
        if isinstance(obj_str, str):
        if not isinstance(obj_str, str):
            raise TypeError("key for tensor registry must be string.")
        if context.get_context("enable_ge"):
            def wrap(*args):
                new_args = list(args)
                new_args.append(obj_str)
                return self["vm_compare"](*new_args)
            obj = wrap
        else:
            obj = self[obj_str]
        return obj
tensor_operator_registry = Registry()
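With this change, `Registry.get` validates the key and, when the GE backend is enabled, hands back a wrapper that routes the lookup through the registered `vm_compare` with the key appended. Roughly (import path taken from the tensor.py hunk below):

```python
from mindspore.common._register_for_tensor import tensor_operator_registry

# On non-GE backends this returns the functional op registered under '__gt__'.
# With context.get_context("enable_ge") true, it returns a closure equivalent to:
#     lambda *args: tensor_operator_registry['vm_compare'](*args, '__gt__')
gt = tensor_operator_registry.get('__gt__')
```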
@@ -19,7 +19,6 @@ from .._c_expression import Tensor as Tensor_
from .._c_expression import MetaTensor
from .._checkparam import check_type, check_typename
from . import dtype as mstype
from .. import context
from ._register_for_tensor import tensor_operator_registry
__all__ = ['Tensor', 'MetaTensor']
@@ -76,17 +75,19 @@ class Tensor(Tensor_):
        return out
    def __eq__(self, other):
        if not isinstance(other, Tensor):
        if not isinstance(other, (int, float, Tensor)):
            return False
        # The GE backend don't support single `Equal` operator execution.
        # bool type is not supported for `Equal` operator in backend.
        if context.get_context("enable_ge") or self.dtype == mstype.bool_ or other.dtype == mstype.bool_:
        if self.dtype == mstype.bool_ or (isinstance(other, Tensor) and other.dtype == mstype.bool_):
            return Tensor(np.array(self.asnumpy() == other.asnumpy()))
        return tensor_operator_registry.get('__eq__')(self, other)
    def __ne__(self, other):
        if not isinstance(other, Tensor):
        if not isinstance(other, (int, float, Tensor)):
            return True
        # bool type is not supported for `NotEqual` operator in backend.
        if self.dtype == mstype.bool_ or (isinstance(other, Tensor) and other.dtype == mstype.bool_):
            return Tensor(np.array(self.asnumpy() != other.asnumpy()))
        return tensor_operator_registry.get('__ne__')(self, other)
    def __hash__(self):
@@ -105,7 +106,7 @@ class Tensor(Tensor_):
        return out
    def __radd__(self, other):
        out = tensor_operator_registry.get('__add__')(other, self)
        out = tensor_operator_registry.get('__add__')(self, other)
        return out
    def __imul__(self, other):
@@ -113,15 +114,15 @@ class Tensor(Tensor_):
        return out
    def __rmul__(self, other):
        out = tensor_operator_registry.get('__mul__')(other, self)
        out = tensor_operator_registry.get('__mul__')(self, other)
        return out
    def __truediv__(self, other):
        out = tensor_operator_registry.get('__div__')(self, other)
        out = tensor_operator_registry.get('__truediv__')(self, other)
        return out
    def __rtruediv__(self, other):
        out = tensor_operator_registry.get('__div__')(other, self)
        out = tensor_operator_registry.get('__truediv__')(other, self)
        return out
    def __sub__(self, other):
@@ -160,7 +161,7 @@ class Tensor(Tensor_):
        return out
    def __len__(self):
        out = tensor_operator_registry.get('__shape__')(self)
        out = tensor_operator_registry.get('shape')(self)
        if not out:
            return 1
        return out[0]
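Illustrative expectations for the dunder fixes above (a sketch that needs a configured backend to actually run; the division case mirrors the existing `test_tensor_operation` assertions near the end of this diff):

```python
import numpy as np
from mindspore import Tensor

x = Tensor((np.ones((3, 3)) * 4).astype(np.float32))

res = 8 / x                    # __rtruediv__ now resolves '__truediv__' with (other, self)
assert np.all(res.asnumpy() == 2)
res = 2 + x                    # __radd__ reuses '__add__' as (self, other); safe since + is commutative
assert np.all(res.asnumpy() == 6)
assert len(x) == 3             # __len__ reads the registered 'shape' operator
assert (x == "4") is False     # non-numeric, non-Tensor operands still short-circuit to False
eq = x == 4.0                  # int/float operands now go through the Equal path
```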
@@ -819,4 +819,4 @@ class Cell:
        """
        self._backward_hook = HookBackward(fn, self.cls_name + "(" + str(id(self)) + ")")
        self._enable_hook = True
        self.enable_hook = True
@@ -140,6 +140,11 @@ class SequentialCell(Cell):
    def __len__(self):
        return len(self._cells)
    def set_grad(self, flag=True):
        self.requires_grad = flag
        for cell in self._cells.values():
            cell.set_grad(flag)
    def construct(self, input_data):
        for cell in self.cell_list:
            input_data = cell(input_data)
@@ -246,5 +251,10 @@ class CellList(_CellListBase, Cell):
        self._cells[str(len(self))] = cell
        return self
    def set_grad(self, flag=True):
        self.requires_grad = flag
        for cell in self._cells.values():
            cell.set_grad(flag)
    def construct(self, *inputs):
        raise NotImplementedError
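`set_grad` on the container cells now propagates to their children, so wrapping layers in `SequentialCell` or `CellList` no longer loses the `requires_grad` flag that the hook and bprop machinery relies on. Quick sketch:

```python
import mindspore.nn as nn

seq = nn.SequentialCell([nn.ReLU(), nn.Dense(4, 3)])
seq.set_grad()                                         # flag reaches seq and every child cell
assert all(cell.requires_grad for cell in seq.cells())
```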
@@ -112,7 +112,7 @@ class GradOperation(GradOperation_):
        grad_ = GradOperation('grad', self.get_all, self.get_by_list, self.sens_param)
        if self.grad_fn is None or self.fn != fn:
            if self.get_by_list:
                if context.get_context("mode") == context.GRAPH_MODE or fn.bprop_debug:
                if context.get_context("mode") == context.GRAPH_MODE:
                    @ms_function(obj=fn)
                    def after_grad(*args):
                        return grad_(fn, weights)(*args)
@@ -21,6 +21,7 @@ from mindspore.common._register_for_tensor import tensor_operator_registry
from .primitive import Primitive
from . import operations as P
from .operations import _grad_ops
from .._extends import builtin_operations as BP
typeof = Primitive('typeof')
hastype = Primitive('hastype')
@@ -155,7 +156,7 @@ stop_gradient = Primitive("stop_gradient")
tensor_operator_registry.register('__add__', tensor_add)
tensor_operator_registry.register('__sub__', tensor_sub)
tensor_operator_registry.register('__mul__', tensor_mul)
tensor_operator_registry.register('__div__', tensor_div)
tensor_operator_registry.register('__truediv__', tensor_div)
#ms cannot support Tensor(True) compare
tensor_operator_registry.register('__eq__', equal)
tensor_operator_registry.register('__ne__', not_equal)
@@ -164,4 +165,6 @@ tensor_operator_registry.register('__lt__', tensor_lt)
tensor_operator_registry.register('__le__', tensor_le)
tensor_operator_registry.register('__gt__', tensor_gt)
tensor_operator_registry.register('__ge__', tensor_ge)
tensor_operator_registry.register('__shape__', shape)
tensor_operator_registry.register('shape', shape)
#support GE backend for no compare operators
tensor_operator_registry.register('vm_compare', BP.vm_compare)
@@ -863,6 +863,8 @@ class TupleToArray(PrimitiveWithInfer):
        args = list()
        if isinstance(x, range):
            args.append(tuple(x))
        else:
            args.append(x)
        return _run_op(self, self.name, args)
@@ -341,13 +341,7 @@ def constexpr(fn=None, get_instance=True, name=None):
@_wrap_func
def _run_op(obj, op_name, args):
    """Single op execution function supported by ge in PyNative mode."""
    op_mask = [0] * len(args)
    op_inputs = []
    for i, arg in enumerate(args):
        if hasattr(arg, '__parameter__'):
            op_mask[i] = 1
        op_inputs.append(arg)
    output = real_run_op(obj, op_name, args, tuple(op_mask))
    output = real_run_op(obj, op_name, args)
    if not output:
        raise RuntimeError("Pynative run op %s failed!" % op_name)
    if len(output) == 1:
@@ -63,8 +63,7 @@ OpExecInfoPtr ConstructOpExecInfo() {
  auto conv_obj = prim::GetPythonOps("conv2d_prim", "gtest_input.pynative");
  py::none py_none;
  py::tuple op_mask = py::make_tuple(0, 1);
  return GenerateOpExecInfo(py::make_tuple(conv_obj, op_name, op_inputs, op_mask));
  return GenerateOpExecInfo(py::make_tuple(conv_obj, op_name, op_inputs));
}
TEST_F(TestPynativeExecute, TestRunOpInVM) {
@@ -79,7 +78,7 @@ TEST_F(TestPynativeExecute, TestRunOp) {
  py::none py_none;
  auto op_exec_info_ptr = ConstructOpExecInfo();
  py::tuple outputs = pynative::RunOp(py::make_tuple(op_exec_info_ptr->py_primitive, op_exec_info_ptr->op_name,
                                                     op_exec_info_ptr->op_inputs, op_exec_info_ptr->inputs_mask));
                                                     op_exec_info_ptr->op_inputs));
  if (outputs.size() == 0) {
    FAIL();
  } else {
@@ -452,5 +452,5 @@ def test_tensor_operation():
    assert np.all(res.asnumpy() == np.ones((3, 3)) * 2)
    res = 8 / x
    assert np.all(res.asnumpy() == np.ones((3, 3)) * 2)
    with pytest.raises(TypeError):
    with pytest.raises(ValueError):
        res = x * (2, 3)
@@ -8,6 +8,9 @@ from mindspore.nn import WithLossCell, Momentum
from mindspore.ops import composite as C
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
cell_hook_done = False
var_hook_done = False
cell_bprop_done = False
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
@@ -32,15 +35,35 @@ def weight_variable():
def cell_hook_function(cell_id, grad_input, grad_output):
    print(cell_id)
    global cell_hook_done
    cell_hook_done = True
    assert (grad_output[0].asnumpy().shape == (32, 6, 14, 14))
    assert (grad_input[0].asnumpy().shape == (32, 16, 10, 10))
def var_hook_function(grad_out):
    print("grad:", grad_out)
    global var_hook_done
    var_hook_done = True
    assert (grad_out[0].asnumpy().shape == (32, 120))
class Block(nn.Cell):
    def __init__(self):
        super(Block, self).__init__()
        self.relu = nn.ReLU()
    def construct(self, x):
        x = self.relu(x)
        return x
    def bprop(self, x, out, dout):
        global cell_bprop_done
        cell_bprop_done = True
        grad = out.asnumpy() * dout.asnumpy()
        grad = Tensor(grad)
        return (grad,)
class LeNet5(nn.Cell):
    """
    Lenet network
@@ -59,6 +82,7 @@ class LeNet5(nn.Cell):
        self.conv1 = conv(1, 6, 5)
        self.conv2 = conv(6, 16, 5)
        self.conv2.register_backward_hook(cell_hook_function)
        self.block = Block()
        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
        self.fc2 = fc_with_initialize(120, 84)
        self.fc3 = fc_with_initialize(84, self.num_class)
@@ -72,7 +96,7 @@ class LeNet5(nn.Cell):
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.block(x)
        x = self.max_pool2d(x)
        x = self.reshape(x, (self.batch_size, -1))
        x = self.fc1(x)
@@ -110,6 +134,9 @@ def test_hook():
    loss_output = criterion(output, label)
    grads = train_network(input_data, label)
    success = optimizer(grads)
    assert cell_hook_done
    assert var_hook_done
    assert cell_bprop_done
    print(loss_output.asnumpy().shape)