Browse Source

fix_codes_clean_and_pclint

tags/v1.3.0
lvliang 4 years ago
parent
commit
0d5d4935d4
2 changed files with 55 additions and 38 deletions
  1. +52
    -37
      mindspore/ccsrc/pybind_api/ir/primitive_py.cc
  2. +3
    -1
      mindspore/ccsrc/pybind_api/ir/primitive_py.h

+ 52
- 37
mindspore/ccsrc/pybind_api/ir/primitive_py.cc View File

@@ -182,11 +182,12 @@ void PrimitivePy::CheckHookConsistency(const py::object &grad_out, const py::obj
}
}

BaseRef PrimitivePy::RunBpropHookFunction(const py::tuple &py_args) const {
BaseRef PrimitivePy::RunCellBpropFunction(const py::tuple &py_args) const {
SyncData(py_args);
auto size = py_args.size();
py::tuple input_args(size - 2);
for (size_t i = 0; i < size - 2; ++i) {
constexpr size_t grad_param_nums = 2;
py::tuple input_args(size - grad_param_nums);
for (size_t i = 0; i < size - grad_param_nums; ++i) {
input_args[i] = py_args[i];
}
py::tuple convert_args(py_args.size());
@@ -207,48 +208,62 @@ BaseRef PrimitivePy::RunBpropHookFunction(const py::tuple &py_args) const {
}
}

BaseRef PrimitivePy::RunCellHookFunction(const py::tuple &py_args) const {
constexpr size_t grad_input_index = 1;
constexpr size_t grad_output_index = 2;
constexpr size_t input_param_nums = 3;
SyncData(py_args[grad_output_index]);

py::object obj;
auto cell_id = GetValue<std::string>(this->GetAttr(kCellIDAttrName));
auto iter = hook_grad_.find(cell_id);
if (iter != hook_grad_.end()) {
py::tuple convert_args(input_param_nums - 1);
py::tuple input_args(input_param_nums - 1);
input_args[0] = iter->second;
input_args[1] = py_args[grad_output_index];
ConvertCTensorToPyTensor(input_args, &convert_args);
auto hook_args = py::tuple(input_param_nums);
hook_args[0] = cell_id;
hook_args[grad_input_index] = py::make_tuple(convert_args[0]);
hook_args[grad_output_index] = py::make_tuple(convert_args[1]);
obj = hook_(*hook_args);
if (py::isinstance<py::none>(obj)) {
obj = py_args[grad_output_index];
}
CheckHookConsistency(obj, py_args[grad_output_index]);
(void)hook_grad_.erase(cell_id);
} else {
hook_grad_[cell_id] = py_args[grad_output_index];
obj = py_args[grad_output_index];
}
obj = py::make_tuple(obj);
return std::make_shared<PyObjectRef>(obj);
}

BaseRef PrimitivePy::RunVariableHookFunction(const py::tuple &py_args) const {
constexpr size_t grad_output_index = 2;
SyncData(py_args[grad_output_index]);
py::object obj = hook_(py::make_tuple(py_args[grad_output_index]));
if (py::isinstance<py::none>(obj)) {
obj = py_args[grad_output_index];
}
CheckHookConsistency(obj, py_args[grad_output_index]);
obj = py::make_tuple(obj);
return std::make_shared<PyObjectRef>(obj);
}

BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const {
py::tuple py_args = ConvertDatatoPyTuple(args);
bool is_bprop = this->HasAttr(kBpropAttrName);
if (is_bprop) {
return RunBpropHookFunction(py_args);
return RunCellBpropFunction(py_args);
}
SyncData(py_args[2]);
bool is_cell = this->HasAttr(kCellHookAttrName);
py::object obj;
if (is_cell) {
auto cell_id = GetValue<std::string>(this->GetAttr(kCellIDAttrName));
auto iter = hook_grad_.find(cell_id);
if (iter != hook_grad_.end()) {
py::tuple convert_args(2);
py::tuple input_args(2);
input_args[0] = iter->second;
input_args[1] = py_args[2];
ConvertCTensorToPyTensor(input_args, &convert_args);
auto hook_args = py::tuple(3);
hook_args[0] = cell_id;
hook_args[1] = py::make_tuple(convert_args[0]);
hook_args[2] = py::make_tuple(convert_args[1]);
obj = hook_(*hook_args);
if (py::isinstance<py::none>(obj)) {
obj = py_args[2];
}
CheckHookConsistency(obj, py_args[2]);
(void)hook_grad_.erase(cell_id);
} else {
hook_grad_[cell_id] = py_args[2];
obj = py_args[2];
}
} else {
// Hook operator for execute variable hook function
obj = hook_(py::make_tuple(py_args[2]));
if (py::isinstance<py::none>(obj)) {
obj = py_args[2];
}
CheckHookConsistency(obj, py_args[2]);
return RunCellHookFunction(py_args);
}
obj = py::make_tuple(obj);
return std::make_shared<PyObjectRef>(obj);
return RunVariableHookFunction(py_args);
}

py::function PrimitivePy::GetComputeFunction() const {


+ 3
- 1
mindspore/ccsrc/pybind_api/ir/primitive_py.h View File

@@ -60,7 +60,9 @@ class PrimitivePy : public Primitive {
void set_hook(const py::function &hook) { hook_ = hook; }
py::function hook() const { return hook_; }
BaseRef RunHookFunction(const VectorRef &args) const override;
BaseRef RunBpropHookFunction(const py::tuple &py_args) const;
BaseRef RunCellBpropFunction(const py::tuple &py_args) const;
BaseRef RunCellHookFunction(const py::tuple &py_args) const;
BaseRef RunVariableHookFunction(const py::tuple &py_args) const;
BaseRef RunComputeFunction(const VectorRef &args) const override;
py::object RunPyComputeFunction(const py::tuple &py_args) const;
bool HasComputeFunction() const;


Loading…
Cancel
Save