
!2148 fix hook and bprop debug issue in pynative

Merge pull request !2148 from wangqiuliang/fix-hook-bprop-issue
tags/v0.5.0-beta
mindspore-ci-bot committed 5 years ago
commit 9ba6f61d01
18 changed files with 158 additions and 74 deletions
  1. mindspore/_extends/builtin_operations.py (+18 -0)
  2. mindspore/ccsrc/pipeline/parse/data_converter.cc (+31 -31)
  3. mindspore/ccsrc/pipeline/parse/data_converter.h (+1 -0)
  4. mindspore/ccsrc/pipeline/parse/parse_base.h (+1 -0)
  5. mindspore/ccsrc/pynative/base.h (+1 -1)
  6. mindspore/ccsrc/pynative/pynative_execute.cc (+32 -12)
  7. mindspore/ccsrc/pynative/pynative_execute.h (+2 -2)
  8. mindspore/common/_register_for_tensor.py (+10 -2)
  9. mindspore/common/tensor.py (+11 -10)
  10. mindspore/nn/cell.py (+1 -1)
  11. mindspore/nn/layer/container.py (+10 -0)
  12. mindspore/ops/composite/base.py (+1 -1)
  13. mindspore/ops/functional.py (+5 -2)
  14. mindspore/ops/operations/array_ops.py (+2 -0)
  15. mindspore/ops/primitive.py (+1 -7)
  16. tests/ut/cpp/pynative/pynative_execute_test.cc (+2 -3)
  17. tests/ut/python/ir/test_tensor.py (+1 -1)
  18. tests/ut/python/pynative_mode/test_hook.py (+28 -1)

mindspore/_extends/builtin_operations.py (+18 -0)

@@ -113,6 +113,24 @@ def bool_or(x, y):
"""Implement `bool_or`."""
return x or y

def vm_compare(*args):
"""Implement `vm_compare` for tensor."""
obj_str = args[-1]
if obj_str == "shape":
fn = getattr(args[0].asnumpy(), obj_str)
return fn
if len(args) == 2:
fn = getattr(args[0].asnumpy(), obj_str)
return Tensor(fn())
if isinstance(args[0], Tensor):
fn = getattr(args[0].asnumpy(), obj_str)
y = args[1].asnumpy() if isinstance(args[1], Tensor) else args[1]
else:
obj_str = "__r" + obj_str[2:]
fn = getattr(args[1].asnumpy(), obj_str)
y = args[0]
return Tensor(np.array(fn(y)))


def make_list(*xs):
"""Implement `make_list`."""

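For reference, a minimal sketch of how the new `vm_compare` fallback is meant to be invoked once registered (the registry wrapper added in mindspore/common/_register_for_tensor.py below appends the method name as the trailing argument; the concrete tensors here are illustrative and a working MindSpore install is assumed):

import numpy as np
from mindspore import Tensor
from mindspore._extends import builtin_operations as BP

x = Tensor(np.array([1, 2, 3]))
y = Tensor(np.array([1, 0, 3]))

# Three-argument form: dispatch "__eq__" to numpy on the host.
print(BP.vm_compare(x, y, "__eq__").asnumpy())   # [ True False  True]
# Two-argument "shape" form returns the numpy shape tuple directly.
print(BP.vm_compare(x, "shape"))                 # (3,)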

mindspore/ccsrc/pipeline/parse/data_converter.cc (+31 -31)

@@ -41,6 +41,35 @@ using TensorPtr = mindspore::tensor::TensorPtr;
 using MetaTensor = mindspore::tensor::MetaTensor;
 using MetaTensorPtr = mindspore::tensor::MetaTensorPtr;
 
+FuncGraphPtr ConvertToBpropCut(const py::object &obj) {
+  std::vector<std::string> results = data_converter::GetObjKey(obj);
+  std::string obj_key = results[0];
+  py::function bprop_func = py::getattr(obj, CUSTOM_BPROP_NAME);
+
+  auto bprop_graph = std::make_shared<FuncGraph>();
+  std::vector<AnfNodePtr> outputs;
+
+  auto fake_bprop = std::make_shared<PrimitivePy>("bprop_cut", py::object());
+  fake_bprop->set_hook(bprop_func);
+  (void)fake_bprop->AddAttr(CUSTOM_BPROP_NAME, MakeValue(true));
+  outputs.push_back(NewValueNode(fake_bprop));
+
+  py::object code_obj = py::getattr(bprop_func, "__code__");
+  size_t inputs_num = py::cast<int>(py::getattr(code_obj, "co_argcount")) - 3;
+  for (size_t i = 0; i < inputs_num; ++i) {
+    auto param = bprop_graph->add_parameter();
+    outputs.push_back(param);
+  }
+  auto p1 = bprop_graph->add_parameter();
+  auto p2 = bprop_graph->add_parameter();
+  outputs.push_back(p1);
+  outputs.push_back(p2);
+
+  bprop_graph->set_output(bprop_graph->NewCNode(outputs));
+  data_converter::SetObjGraphValue(obj_key, bprop_graph);
+  return bprop_graph;
+}
+
 namespace {
 bool ConvertTuple(const py::object &obj, ValuePtr *const data, bool use_signature) {
   MS_LOG(DEBUG) << "Converting python tuple";
@@ -231,35 +260,6 @@ bool ConvertSlice(const py::object &obj, ValuePtr *const data) {
   return true;
 }
 
-FuncGraphPtr ConvertToBpropCut(py::object obj) {
-  std::vector<std::string> results = data_converter::GetObjKey(obj);
-  std::string obj_key = results[0];
-  py::function bprop_func = py::getattr(obj, "bprop");
-
-  FuncGraphPtr bprop_graph = std::make_shared<FuncGraph>();
-  std::vector<AnfNodePtr> outputs;
-
-  auto fake_bprop = std::make_shared<PrimitivePy>("bprop_cut", py::object());
-  fake_bprop->set_hook(bprop_func);
-  (void)fake_bprop->AddAttr("bprop", MakeValue(true));
-  outputs.push_back(NewValueNode(fake_bprop));
-
-  py::object code_obj = py::getattr(bprop_func, "__code__");
-  size_t inputs_num = py::cast<int>(py::getattr(code_obj, "co_argcount")) - 3;
-  for (size_t i = 0; i < inputs_num; ++i) {
-    auto param = bprop_graph->add_parameter();
-    outputs.push_back(param);
-  }
-  auto p1 = bprop_graph->add_parameter();
-  auto p2 = bprop_graph->add_parameter();
-  outputs.push_back(p1);
-  outputs.push_back(p2);
-
-  bprop_graph->set_output(bprop_graph->NewCNode(outputs));
-  data_converter::SetObjGraphValue(obj_key, bprop_graph);
-  return bprop_graph;
-}
-
 bool ConvertCellObjToFuncGraph(py::object obj, ValuePtr *const data) {
   FuncGraphPtr func_graph = ConvertToFuncGraph(obj);
   if (func_graph == nullptr) {
@@ -267,7 +267,7 @@ bool ConvertCellObjToFuncGraph(py::object obj, ValuePtr *const data) {
     return false;
   }
   // if the cell object has specified bprop, it has user-defined bprop function parse and record it
-  if (py::hasattr(obj, "bprop")) {
+  if (py::hasattr(obj, CUSTOM_BPROP_NAME)) {
     FuncGraphPtr bprop_graph = nullptr;
     bool enable_bprop_debug = py::cast<bool>(py::getattr(obj, "bprop_debug"));
     if (enable_bprop_debug) {
@@ -276,7 +276,7 @@ bool ConvertCellObjToFuncGraph(py::object obj, ValuePtr *const data) {
       bprop_graph = ConvertToFuncGraph(obj, PYTHON_MOD_GET_BPROP_METHOD);
     }
     if (bprop_graph != nullptr) {
-      (void)func_graph->transforms().insert(std::make_pair("bprop", FuncGraphTransform(bprop_graph)));
+      (void)func_graph->transforms().insert(std::make_pair(CUSTOM_BPROP_NAME, FuncGraphTransform(bprop_graph)));
       (void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(func_graph)));
       func_graph->set_flags(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
     }

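The relocated `ConvertToBpropCut` builds a `bprop_cut` graph for any cell that defines its own `bprop`. A minimal sketch of such a cell (modeled on the `Block` cell added to tests/ut/python/pynative_mode/test_hook.py below; the cell name and gradient formula are illustrative):

import mindspore.nn as nn
from mindspore import Tensor

class Square(nn.Cell):
    """Cell with a user-defined bprop; the converter wraps it into a bprop_cut node."""
    def construct(self, x):
        return x * x

    def bprop(self, x, out, dout):
        # Host-side custom gradient: d(x*x)/dx = 2 * x.
        return (Tensor(2 * x.asnumpy() * dout.asnumpy()),)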

mindspore/ccsrc/pipeline/parse/data_converter.h (+1 -0)

@@ -51,6 +51,7 @@ void ClearObjectCache();
 }  // namespace data_converter
 
 ClassPtr ParseDataClass(const py::object &cls_obj);
+FuncGraphPtr ConvertToBpropCut(const py::object &obj);
 
 void CleanDataClassToClassMap();



mindspore/ccsrc/pipeline/parse/parse_base.h (+1 -0)

@@ -109,6 +109,7 @@ const char PYTHON_EXTERN_MINDSPORE_FLAG[] = "_mindspore_flags";

 // define the parse constant
 const int MAX_COMPARISON_OPS_SUPPORTED = 1;
+const char CUSTOM_BPROP_NAME[] = "bprop";
 
 // define the Namespace name
 const char RESOLVE_NAMESPACE_NAME_AST[] = "Ast";  // for ast type namespace


mindspore/ccsrc/pynative/base.h (+1 -1)

@@ -45,7 +45,7 @@ enum PynativeStatusCode {
   PYNATIVE_UNKNOWN_STATE = 0XFF
 };
 
-enum RunOpArgsEnum { PY_PRIM = 0, PY_NAME, PY_INPUTS, PY_INPUT_MASK, PY_ARGS_NUM };
+enum RunOpArgsEnum { PY_PRIM = 0, PY_NAME, PY_INPUTS, PY_ARGS_NUM };
 
 struct OpExecInfo {
   PrimitivePyPtr py_primitive;


mindspore/ccsrc/pynative/pynative_execute.cc (+32 -12)

@@ -110,9 +110,15 @@ py::object GetTupleObj(const py::object &obj) {
   return obj_tuple;
 }
 
-void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *out_args) {
+py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *out_args) {
   auto &py_args = *out_args;
+  py::tuple input_mask(args.size());
   for (size_t i = 0; i < args.size(); ++i) {
+    if (py::hasattr(args[i], "__parameter__")) {
+      input_mask[i] = true;
+    } else {
+      input_mask[i] = false;
+    }
     py_args[i] = GetTupleObj(args[i]);
   }
   auto signature = prim->signatures();
@@ -121,7 +127,7 @@ void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *
                        [](const Signature &sig) { return sig.dtype; });
   int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue);
   if (dtypes.size() == 0 || static_cast<int>(dtypes.size()) == empty_dtype_count) {
-    return;
+    return input_mask;
   }
   std::map<SignatureEnumDType, std::vector<size_t>> type_indexs;
   for (size_t i = 0; i < dtypes.size(); ++i) {
@@ -160,6 +166,7 @@ void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *
       continue;
     }
   }
+  return input_mask;
 }
 
 void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecInfo *const op_exec_info) {
@@ -167,7 +174,7 @@ void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecIn
   AbstractBasePtrList args_spec_list;
   for (size_t i = 0; i < size; i++) {
     ValuePtr input_value = PyAttrValue(py_args[i]);
-    if (input_value->isa<tensor::Tensor>()) {
+    if (!py::hasattr(prim->GetPyObj(), "const_value") && input_value->isa<tensor::Tensor>()) {
       args_spec_list.emplace_back(abstract::FromValueInside(input_value, true));
     } else {
       args_spec_list.emplace_back(abstract::FromValueInside(input_value, false));
@@ -179,7 +186,7 @@ void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecIn

 OpExecInfoPtr GenerateOpExecInfo(const py::args &args) {
   if (args.size() != PY_ARGS_NUM) {
-    MS_LOG(ERROR) << "Four args are needed by RunOp";
+    MS_LOG(ERROR) << "Three args are needed by RunOp";
     return nullptr;
   }
   auto op_exec_info = std::make_shared<OpExecInfo>();
@@ -195,14 +202,13 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args &args) {
   size_t input_num = a.size();
   op_exec_info->op_inputs = py::tuple(input_num);
 
-  ConvertInputs(prim, args[PY_INPUTS], &op_exec_info->op_inputs);
+  op_exec_info->inputs_mask = ConvertInputs(prim, args[PY_INPUTS], &op_exec_info->op_inputs);
   // use python infer method
   if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) {
     PynativeInfer(prim, op_exec_info->op_inputs, op_exec_info.get());
   }
   op_exec_info->py_primitive = prim;
   op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs");
-  op_exec_info->inputs_mask = args[PY_INPUT_MASK];
   if (op_exec_info->op_inputs.size() != op_exec_info->inputs_mask.size()) {
     MS_LOG(ERROR) << "Op:" << op_exec_info->op_name << " inputs size not equal op_mask";
     return nullptr;
@@ -488,14 +494,14 @@ py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecIn
   return result;
 }
 
-AnfNodePtr PynativeExecutor::MakeCNode(const py::args &args, const py::tuple &out) {
+AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, const py::args &args, const py::tuple &out) {
   if (!grad_flag_ || graph_info_map_.size() == 0) {
     return nullptr;
   }
   std::vector<AnfNodePtr> inputs;
-  auto prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]);
+  auto prim = op_exec_info->py_primitive;
   inputs.push_back(NewValueNode(prim));
-  py::tuple op_masks = args[PY_INPUT_MASK];
+  py::tuple op_masks = op_exec_info->inputs_mask;
   py::list op_args = args[PY_INPUTS];
   AbstractBasePtrList args_spec_list;
   for (size_t i = 0; i < op_args.size(); i++) {
@@ -584,7 +590,7 @@ py::tuple RunOp(const py::args &args) {
     return err_ret;
   }
 
-  auto node = PynativeExecutor::GetInstance()->MakeCNode(args, result);
+  auto node = PynativeExecutor::GetInstance()->MakeCNode(op_exec_info, args, result);
   if (node != nullptr) {
     node->set_abstract(op_exec_info->abstract);
     MS_LOG(DEBUG) << "RunOp MakeCnode,new node is: " << node->DebugString();
@@ -705,7 +711,7 @@ void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, c
   }
   cell_graph_map_[cell_id] = curr_g_;
   auto out_id = GetId(out);
-  if (!graph_info_map_[curr_g_].obj_node_map.count(out_id)) {
+  if (!graph_info_map_[curr_g_].obj_node_map.count(out_id) && !graph_info_map_[curr_g_].param_map.count(out_id)) {
     // cell construct return x, y
     if (py::isinstance<py::tuple>(out)) {
       std::vector<AnfNodePtr> args;
@@ -727,12 +733,26 @@ void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, c
     }
   }
 
-  auto output_node = GetObjNode(out);
+  AnfNodePtr output_node;
+  if (graph_info_map_[curr_g_].param_map.count(out_id)) {
+    output_node = graph_info_map_[curr_g_].param_map[out_id];
+  } else {
+    output_node = GetObjNode(out);
+  }
   curr_g_->set_output(output_node);
   std::vector<AnfNodePtr> inputs;
   inputs.push_back(NewValueNode(curr_g_));
   MS_LOG(DEBUG) << "Current graph" << curr_g_->output()->DebugString();
   resource_->manager()->AddFuncGraph(curr_g_);
+  // custom bprop debug
+  if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
+    MS_LOG(DEBUG) << "Use cell custom bprop function.";
+    FuncGraphPtr bprop_graph = parse::ConvertToBpropCut(cell);
+    if (bprop_graph != nullptr) {
+      (void)curr_g_->transforms().insert(std::make_pair(parse::CUSTOM_BPROP_NAME, FuncGraphTransform(bprop_graph)));
+      (void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(curr_g_)));
+    }
+  }
   auto newfg = ad::Grad(curr_g_, resource_, curr_g_ == top_g_);
   if (curr_g_ != top_g_) {
     Popp();


mindspore/ccsrc/pynative/pynative_execute.h (+2 -2)

@@ -44,7 +44,7 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat

 py::tuple RunOp(const py::args &args);
 
-void ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args, py::tuple *out_args);
+py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args, py::tuple *out_args);
 
 void ClearPyNativeSession();

@@ -83,7 +83,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   void set_obj_node_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node, int index) {
     graph_info_map_[g].obj_node_map[obj] = std::make_pair(node, index);
   }
-  AnfNodePtr MakeCNode(const py::args &args, const py::tuple &out);
+  AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, const py::args &args, const py::tuple &out);
   py::object Run(const py::tuple &args, const py::object &phase);
 
   void Pushp();


mindspore/common/_register_for_tensor.py (+10 -2)

@@ -16,6 +16,7 @@
"""Registry the relation."""

from collections import UserDict
from .. import context


class Registry(UserDict):
@@ -27,9 +28,16 @@ class Registry(UserDict):

     def get(self, obj_str):
         """Get the value by str."""
-        if isinstance(obj_str, str):
+        if not isinstance(obj_str, str):
+            raise TypeError("key for tensor registry must be string.")
+        if context.get_context("enable_ge"):
+            def wrap(*args):
+                new_args = list(args)
+                new_args.append(obj_str)
+                return self["vm_compare"](*new_args)
+            obj = wrap
+        else:
             obj = self[obj_str]
         return obj
 
 
 tensor_operator_registry = Registry()

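With the GE backend enabled, the lookup no longer returns the registered primitive directly; it returns a closure that appends the key and defers to the registered `vm_compare` entry. A rough sketch of the resulting call flow (illustrative only; it assumes `enable_ge` is set in the context):

from mindspore.common._register_for_tensor import tensor_operator_registry

# Under enable_ge this returns a wrapper instead of the Equal primitive
# registered in mindspore/ops/functional.py.
eq = tensor_operator_registry.get('__eq__')
# The wrapper forwards with the key appended:
#   eq(x, y)  ->  tensor_operator_registry['vm_compare'](x, y, '__eq__')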
mindspore/common/tensor.py (+11 -10)

@@ -19,7 +19,6 @@ from .._c_expression import Tensor as Tensor_
 from .._c_expression import MetaTensor
 from .._checkparam import check_type, check_typename
 from . import dtype as mstype
-from .. import context
 from ._register_for_tensor import tensor_operator_registry
 
 __all__ = ['Tensor', 'MetaTensor']
@@ -76,17 +75,19 @@ class Tensor(Tensor_):
         return out
 
     def __eq__(self, other):
-        if not isinstance(other, Tensor):
+        if not isinstance(other, (int, float, Tensor)):
             return False
-        # The GE backend don't support single `Equal` operator execution.
         # bool type is not supported for `Equal` operator in backend.
-        if context.get_context("enable_ge") or self.dtype == mstype.bool_ or other.dtype == mstype.bool_:
+        if self.dtype == mstype.bool_ or (isinstance(other, Tensor) and other.dtype == mstype.bool_):
             return Tensor(np.array(self.asnumpy() == other.asnumpy()))
         return tensor_operator_registry.get('__eq__')(self, other)
 
     def __ne__(self, other):
-        if not isinstance(other, Tensor):
+        if not isinstance(other, (int, float, Tensor)):
             return True
+        # bool type is not supported for `NotEqual` operator in backend.
+        if self.dtype == mstype.bool_ or (isinstance(other, Tensor) and other.dtype == mstype.bool_):
+            return Tensor(np.array(self.asnumpy() != other.asnumpy()))
         return tensor_operator_registry.get('__ne__')(self, other)
 
     def __hash__(self):
@@ -105,7 +106,7 @@ class Tensor(Tensor_):
         return out
 
     def __radd__(self, other):
-        out = tensor_operator_registry.get('__add__')(other, self)
+        out = tensor_operator_registry.get('__add__')(self, other)
         return out
 
     def __imul__(self, other):
@@ -113,15 +114,15 @@ class Tensor(Tensor_):
         return out
 
     def __rmul__(self, other):
-        out = tensor_operator_registry.get('__mul__')(other, self)
+        out = tensor_operator_registry.get('__mul__')(self, other)
         return out
 
     def __truediv__(self, other):
-        out = tensor_operator_registry.get('__div__')(self, other)
+        out = tensor_operator_registry.get('__truediv__')(self, other)
         return out
 
     def __rtruediv__(self, other):
-        out = tensor_operator_registry.get('__div__')(other, self)
+        out = tensor_operator_registry.get('__truediv__')(other, self)
         return out
 
     def __sub__(self, other):
@@ -160,7 +161,7 @@ class Tensor(Tensor_):
         return out
 
     def __len__(self):
-        out = tensor_operator_registry.get('__shape__')(self)
+        out = tensor_operator_registry.get('shape')(self)
         if not out:
             return 1
         return out[0]

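Taken together, the Tensor changes accept int/float operands in `__eq__`/`__ne__`, keep the tensor as the left operand in the reflected add/mul, route true division through the '__truediv__' key, and make `__len__` use the plain 'shape' entry. A small illustrative sketch of the resulting behavior (a working MindSpore install is assumed; bool tensors are compared on the host, so no device op is launched for them):

import numpy as np
from mindspore import Tensor

x = Tensor(np.ones((3, 3), np.float32))
print(len(x))        # 3, resolved through the 'shape' registry entry
print(x == "text")   # False: operands other than int, float or Tensor short-circuit

a = Tensor(np.array([True, False]))
b = Tensor(np.array([True, True]))
print((a == b).asnumpy())   # [ True False], computed with numpy because dtype is bool_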

mindspore/nn/cell.py (+1 -1)

@@ -819,4 +819,4 @@ class Cell:

"""
self._backward_hook = HookBackward(fn, self.cls_name + "(" + str(id(self)) + ")")
self._enable_hook = True
self.enable_hook = True

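For context, this assignment happens inside Cell.register_backward_hook; a minimal PyNative usage sketch (the hook body is illustrative):

import mindspore.nn as nn
from mindspore import context

context.set_context(mode=context.PYNATIVE_MODE)

def my_cell_hook(cell_id, grad_input, grad_output):
    # Invoked during backprop with the gradients entering and leaving the cell.
    print("backward hook fired for", cell_id)

relu = nn.ReLU()
relu.register_backward_hook(my_cell_hook)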
mindspore/nn/layer/container.py (+10 -0)

@@ -140,6 +140,11 @@ class SequentialCell(Cell):
     def __len__(self):
         return len(self._cells)
 
+    def set_grad(self, flag=True):
+        self.requires_grad = flag
+        for cell in self._cells.values():
+            cell.set_grad(flag)
+
     def construct(self, input_data):
         for cell in self.cell_list:
             input_data = cell(input_data)
@@ -246,5 +251,10 @@ class CellList(_CellListBase, Cell):
         self._cells[str(len(self))] = cell
         return self
 
+    def set_grad(self, flag=True):
+        self.requires_grad = flag
+        for cell in self._cells.values():
+            cell.set_grad(flag)
+
     def construct(self, *inputs):
         raise NotImplementedError

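The new `set_grad` overrides simply push the flag down to every child cell; a short usage sketch (layer sizes are illustrative):

import mindspore.nn as nn

net = nn.SequentialCell([nn.Dense(4, 3), nn.ReLU()])
net.set_grad()   # sets requires_grad on the container and on each child cell
assert net.requires_grad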
mindspore/ops/composite/base.py (+1 -1)

@@ -112,7 +112,7 @@ class GradOperation(GradOperation_):
         grad_ = GradOperation('grad', self.get_all, self.get_by_list, self.sens_param)
         if self.grad_fn is None or self.fn != fn:
             if self.get_by_list:
-                if context.get_context("mode") == context.GRAPH_MODE or fn.bprop_debug:
+                if context.get_context("mode") == context.GRAPH_MODE:
                     @ms_function(obj=fn)
                     def after_grad(*args):
                         return grad_(fn, weights)(*args)


mindspore/ops/functional.py (+5 -2)

@@ -21,6 +21,7 @@ from mindspore.common._register_for_tensor import tensor_operator_registry
 from .primitive import Primitive
 from . import operations as P
 from .operations import _grad_ops
+from .._extends import builtin_operations as BP
 
 typeof = Primitive('typeof')
 hastype = Primitive('hastype')
@@ -155,7 +156,7 @@ stop_gradient = Primitive("stop_gradient")
 tensor_operator_registry.register('__add__', tensor_add)
 tensor_operator_registry.register('__sub__', tensor_sub)
 tensor_operator_registry.register('__mul__', tensor_mul)
-tensor_operator_registry.register('__div__', tensor_div)
+tensor_operator_registry.register('__truediv__', tensor_div)
 #ms cannot support Tensor(True) compare
 tensor_operator_registry.register('__eq__', equal)
 tensor_operator_registry.register('__ne__', not_equal)
@@ -164,4 +165,6 @@ tensor_operator_registry.register('__lt__', tensor_lt)
 tensor_operator_registry.register('__le__', tensor_le)
 tensor_operator_registry.register('__gt__', tensor_gt)
 tensor_operator_registry.register('__ge__', tensor_ge)
-tensor_operator_registry.register('__shape__', shape)
+tensor_operator_registry.register('shape', shape)
+#support GE backend for no compare operators
+tensor_operator_registry.register('vm_compare', BP.vm_compare)

mindspore/ops/operations/array_ops.py (+2 -0)

@@ -863,6 +863,8 @@ class TupleToArray(PrimitiveWithInfer):
         args = list()
         if isinstance(x, range):
             args.append(tuple(x))
+        else:
+            args.append(x)
         return _run_op(self, self.name, args)




mindspore/ops/primitive.py (+1 -7)

@@ -341,13 +341,7 @@ def constexpr(fn=None, get_instance=True, name=None):
 @_wrap_func
 def _run_op(obj, op_name, args):
     """Single op execution function supported by ge in PyNative mode."""
-    op_mask = [0] * len(args)
-    op_inputs = []
-    for i, arg in enumerate(args):
-        if hasattr(arg, '__parameter__'):
-            op_mask[i] = 1
-        op_inputs.append(arg)
-    output = real_run_op(obj, op_name, args, tuple(op_mask))
+    output = real_run_op(obj, op_name, args)
     if not output:
         raise RuntimeError("Pynative run op %s failed!" % op_name)
     if len(output) == 1:


tests/ut/cpp/pynative/pynative_execute_test.cc (+2 -3)

@@ -63,8 +63,7 @@ OpExecInfoPtr ConstructOpExecInfo() {

   auto conv_obj = prim::GetPythonOps("conv2d_prim", "gtest_input.pynative");
   py::none py_none;
-  py::tuple op_mask = py::make_tuple(0, 1);
-  return GenerateOpExecInfo(py::make_tuple(conv_obj, op_name, op_inputs, op_mask));
+  return GenerateOpExecInfo(py::make_tuple(conv_obj, op_name, op_inputs));
 }

TEST_F(TestPynativeExecute, TestRunOpInVM) {
@@ -79,7 +78,7 @@ TEST_F(TestPynativeExecute, TestRunOp) {
   py::none py_none;
   auto op_exec_info_ptr = ConstructOpExecInfo();
   py::tuple outputs = pynative::RunOp(py::make_tuple(op_exec_info_ptr->py_primitive, op_exec_info_ptr->op_name,
-                                                     op_exec_info_ptr->op_inputs, op_exec_info_ptr->inputs_mask));
+                                                     op_exec_info_ptr->op_inputs));
   if (outputs.size() == 0) {
     FAIL();
   } else {


tests/ut/python/ir/test_tensor.py (+1 -1)

@@ -452,5 +452,5 @@ def test_tensor_operation():
     assert np.all(res.asnumpy() == np.ones((3, 3)) * 2)
     res = 8 / x
     assert np.all(res.asnumpy() == np.ones((3, 3)) * 2)
-    with pytest.raises(TypeError):
+    with pytest.raises(ValueError):
         res = x * (2, 3)

tests/ut/python/pynative_mode/test_hook.py (+28 -1)

@@ -8,6 +8,9 @@ from mindspore.nn import WithLossCell, Momentum
 from mindspore.ops import composite as C
 
 context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+cell_hook_done = False
+var_hook_done = False
+cell_bprop_done = False
 
 
 def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
@@ -32,15 +35,35 @@ def weight_variable():

 def cell_hook_function(cell_id, grad_input, grad_output):
     print(cell_id)
+    global cell_hook_done
+    cell_hook_done = True
     assert (grad_output[0].asnumpy().shape == (32, 6, 14, 14))
     assert (grad_input[0].asnumpy().shape == (32, 16, 10, 10))
 
 
 def var_hook_function(grad_out):
     print("grad:", grad_out)
+    global var_hook_done
+    var_hook_done = True
     assert (grad_out[0].asnumpy().shape == (32, 120))
 
 
+class Block(nn.Cell):
+    def __init__(self):
+        super(Block, self).__init__()
+        self.relu = nn.ReLU()
+
+    def construct(self, x):
+        x = self.relu(x)
+        return x
+
+    def bprop(self, x, out, dout):
+        global cell_bprop_done
+        cell_bprop_done = True
+        grad = out.asnumpy() * dout.asnumpy()
+        grad = Tensor(grad)
+        return (grad,)
+
 class LeNet5(nn.Cell):
     """
     Lenet network
@@ -59,6 +82,7 @@ class LeNet5(nn.Cell):
         self.conv1 = conv(1, 6, 5)
         self.conv2 = conv(6, 16, 5)
         self.conv2.register_backward_hook(cell_hook_function)
+        self.block = Block()
         self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
         self.fc2 = fc_with_initialize(120, 84)
         self.fc3 = fc_with_initialize(84, self.num_class)
@@ -72,7 +96,7 @@ class LeNet5(nn.Cell):
         x = self.relu(x)
         x = self.max_pool2d(x)
         x = self.conv2(x)
-        x = self.relu(x)
+        x = self.block(x)
         x = self.max_pool2d(x)
         x = self.reshape(x, (self.batch_size, -1))
         x = self.fc1(x)
@@ -110,6 +134,9 @@ def test_hook():
     loss_output = criterion(output, label)
     grads = train_network(input_data, label)
     success = optimizer(grads)
+    assert cell_hook_done
+    assert var_hook_done
+    assert cell_bprop_done
     print(loss_output.asnumpy().shape)



