Browse Source

!1427 fix check bprop attr error

Merge pull request !1427 from panyifeng/fix_check_bprop_attr_error
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
0b191615a9
7 changed files with 38 additions and 8 deletions
  1. +14
    -4
      mindspore/ccsrc/optimizer/ad/kprim.cc
  2. +3
    -1
      mindspore/ccsrc/pipeline/init.cc
  3. +3
    -0
      mindspore/ccsrc/utils/context/ms_context.h
  4. +10
    -1
      mindspore/context.py
  5. +2
    -1
      mindspore/ops/operations/other_ops.py
  6. +3
    -0
      tests/ut/python/pynative_mode/test_cell_bprop.py
  7. +3
    -1
      tests/ut/python/pynative_mode/test_framstruct.py

+ 14
- 4
mindspore/ccsrc/optimizer/ad/kprim.cc View File

@@ -32,6 +32,7 @@
#include "operator/composite/composite.h"
#include "utils/symbolic.h"
#include "utils/primitive_utils.h"
#include "utils/context/ms_context.h"
#include "debug/info.h"
#include "debug/trace.h"

@@ -181,10 +182,19 @@ void KPrim::TransformArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bp
}

void KPrim::CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check) {
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
bool check_bprop_flag = context->check_bprop_flag();
// Skip checking if check_bprop not set
if (!check_bprop_flag) {
return;
}

// bprop_fg has been checked in caller
auto check_bprop = prim::GetPythonOps("check_bprop", "mindspore.ops.functional")->cast<PrimitivePtr>();
MS_EXCEPTION_IF_NULL(check_bprop);
check_bprop->set_attr("prim_to_check", std::make_shared<StringImm>(prim_to_check));
auto check_bprop_class = prim::GetPythonOps("CheckBprop", "mindspore.ops.operations.other_ops");
MS_EXCEPTION_IF_NULL(check_bprop_class);
auto check_bprop =
bprop_fg->NewCNode({NewValueNode(check_bprop_class), NewValueNode(std::make_shared<StringImm>(prim_to_check))});

std::vector<AnfNodePtr> inputs;
inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
@@ -192,7 +202,7 @@ void KPrim::CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check
AnfNodePtr params = bprop_fg->NewCNode(inputs);

inputs.clear();
inputs.push_back(NewValueNode(check_bprop));
inputs.push_back(check_bprop);
inputs.push_back(bprop_fg->output());
inputs.push_back(params);
AnfNodePtr bprop_out = bprop_fg->NewCNode(inputs);


+ 3
- 1
mindspore/ccsrc/pipeline/init.cc View File

@@ -141,7 +141,9 @@ PYBIND11_MODULE(_c_expression, m) {
.def("get_enable_profiling", &mindspore::MsContext::enable_profiling, "Get whether to open profiling.")
.def("set_enable_profiling", &mindspore::MsContext::set_enable_profiling, "Set whether to open profiling.")
.def("get_profiling_options", &mindspore::MsContext::profiling_options, "Get options to profiling.")
.def("set_profiling_options", &mindspore::MsContext::set_profiling_options, "Set options to profiling.");
.def("set_profiling_options", &mindspore::MsContext::set_profiling_options, "Set options to profiling.")
.def("get_check_bprop_flag", &mindspore::MsContext::check_bprop_flag, "Get whether to check bprop.")
.def("set_check_bprop_flag", &mindspore::MsContext::set_check_bprop_flag, "Set whether to check bprop.");

(void)py::class_<ParallelContext, std::shared_ptr<ParallelContext>>(m, "AutoParallelContext")
.def_static("get_instance", &ParallelContext::GetInstance, "Get auto parallel context instance.")


+ 3
- 0
mindspore/ccsrc/utils/context/ms_context.h View File

@@ -140,6 +140,8 @@ class MsContext {

void set_profiling_options(const std::string &options) { profiling_options_ = options; }
std::string profiling_options() const { return profiling_options_; }
bool check_bprop_flag() const { return check_bprop_flag_; }
void set_check_bprop_flag(bool check_bprop_flag) { check_bprop_flag_ = check_bprop_flag; }

private:
MsContext(const std::string &backend_policy, const std::string &target);
@@ -179,6 +181,7 @@ class MsContext {
std::thread tdt_print_;
bool profiling_mode_;
std::string profiling_options_;
bool check_bprop_flag_;
};

} // namespace mindspore


+ 10
- 1
mindspore/context.py View File

@@ -324,6 +324,13 @@ class _Context:
thread_info = self._thread_local_info
thread_info.debug_runtime = enable

@property
def check_bprop(self):
return self._context_handle.get_check_bprop_flag()

@check_bprop.setter
def check_bprop(self, check_bprop_flag):
self._context_handle.set_check_bprop_flag(check_bprop_flag)

def check_input_format(x):
import re
@@ -449,7 +456,8 @@ def reset_auto_parallel_context():
@args_type_check(mode=int, precompile_only=bool, device_target=str, device_id=int, save_graphs=bool,
save_graphs_path=str, save_ms_model=bool, save_ms_model_path=str, enable_dump=bool,
save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str,
enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool)
enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool,
check_bprop=bool)
def set_context(**kwargs):
"""
Sets context for running environment.
@@ -500,6 +508,7 @@ def set_context(**kwargs):
The profiling can choose training_trace, task_trace, training_trace and task_trace combination and
separated by colons; single operator can choose op_trace, op_trace cannot be combined with
training_trace and task_trace. Default: "training_trace".
check_bprop (bool): Whether to check bprop. Default: False.

Raises:
ValueError: If input key is not an attribute in context.


+ 2
- 1
mindspore/ops/operations/other_ops.py View File

@@ -323,8 +323,9 @@ class CheckBprop(PrimitiveWithInfer):
"""

@prim_attr_register
def __init__(self):
def __init__(self, prim_to_check=""):
"""init CheckBprop"""
self.prim_to_check = prim_to_check

def infer_shape(self, xshapes, yshapes):
tips = f'Bprop of {self.prim_to_check}'


+ 3
- 0
tests/ut/python/pynative_mode/test_cell_bprop.py View File

@@ -353,6 +353,7 @@ class MulAddWithWrongOutputNum(nn.Cell):


def test_grad_mul_add_with_wrong_output_num():
context.set_context(check_bprop=True)
mul_add = MulAddWithWrongOutputNum()
with pytest.raises(TypeError):
C.grad_all(mul_add)(1, 2)
@@ -370,6 +371,7 @@ class MulAddWithWrongOutputType(nn.Cell):


def test_grad_mul_add_with_wrong_output_type():
context.set_context(check_bprop=True)
mul_add = MulAddWithWrongOutputType()
with pytest.raises(TypeError):
C.grad_all(mul_add)(1, Tensor(np.ones([2, 2])))
@@ -388,6 +390,7 @@ class MulAddWithWrongOutputShape(nn.Cell):


def test_grad_mul_add_with_wrong_output_shape():
context.set_context(check_bprop=True)
mul_add = MulAddWithWrongOutputShape()
with pytest.raises(TypeError):
C.grad_all(mul_add)(1, Tensor(np.ones([2, 2])))

+ 3
- 1
tests/ut/python/pynative_mode/test_framstruct.py View File

@@ -893,6 +893,7 @@ def test_grad_if_defer_inline():


def test_bprop_with_wrong_output_num():
context.set_context(check_bprop=True)
class BpropWithWrongOutputNum(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
@@ -926,8 +927,8 @@ def test_bprop_with_wrong_output_num():
with pytest.raises(TypeError):
C.grad_all(BpropWithWrongOutputNumCell())(1, 2)


def test_bprop_with_wrong_output_type():
context.set_context(check_bprop=True)
class BpropWithWrongOutputType(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
@@ -963,6 +964,7 @@ def test_bprop_with_wrong_output_type():


def test_bprop_with_wrong_output_shape():
context.set_context(check_bprop=True)
class BpropWithWrongOutputShape(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):


Loading…
Cancel
Save