Browse Source

!10064 Fix bprop second derivative

From: @zjun3021
Reviewed-by: @kisnwang
Signed-off-by:
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
8928170b75
4 changed files with 338 additions and 165 deletions
  1. +1
    -1
      mindspore/ccsrc/pipeline/jit/parse/data_converter.cc
  2. +265
    -144
      mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
  3. +47
    -17
      mindspore/ccsrc/pipeline/pynative/pynative_execute.h
  4. +25
    -3
      mindspore/ccsrc/pybind_api/ir/primitive_py.cc

+ 1
- 1
mindspore/ccsrc/pipeline/jit/parse/data_converter.cc View File

@@ -46,7 +46,7 @@ FuncGraphPtr ConvertToBpropCut(const py::object &obj) {
auto bprop_graph = std::make_shared<FuncGraph>();
std::vector<AnfNodePtr> outputs;

auto fake_bprop = std::make_shared<PrimitivePy>("bprop_cut", py::object());
auto fake_bprop = std::make_shared<PrimitivePy>("bprop_cut", obj);
fake_bprop->set_hook(bprop_func);
(void)fake_bprop->AddAttr(CUSTOM_BPROP_NAME, MakeValue(true));
outputs.push_back(NewValueNode(fake_bprop));


+ 265
- 144
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc View File

@@ -624,7 +624,11 @@ OpExecInfoPtr PynativeExecutor::GenerateOpExecInfo(const py::args &args) {
auto resource = GetResource();
MS_EXCEPTION_IF_NULL(resource);
MS_LOG(DEBUG) << "Get resource ptr " << resource.get();
int64_t graph_id = resource->results()[pipeline::kPynativeGraphId].cast<int64_t>();
int64_t graph_id = graph_id_;
auto it = resource->results().find(pipeline::kPynativeGraphId);
if (it != resource->results().end()) {
graph_id = it->second.cast<int64_t>();
}
op_exec_info->op_index = std::to_string(graph_id) + op_name + std::to_string(op_index_map_[op_name]);
op_index_map_[op_name]++;
}
@@ -943,8 +947,7 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) {
auto param_name = py::cast<std::string>(name_attr);
auto df_builder = GetDfbuilder();
MS_EXCEPTION_IF_NULL(df_builder);
if (graph_info_map_.at(df_builder).second.params.find(obj_id) ==
graph_info_map_.at(df_builder).second.params.end()) {
if (graph_info_map_.at(df_builder).params.find(obj_id) == graph_info_map_.at(df_builder).params.end()) {
auto free_param = df_builder->add_parameter();
free_param->set_name(param_name);
free_param->debug_info()->set_name(param_name);
@@ -957,12 +960,12 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) {
SetNodeMapInGraphInfoMap(curr_g_, obj_id, free_param);
return free_param;
}
node = graph_info_map_.at(df_builder).second.node_map[obj_id].first;
MS_LOG(DEBUG) << "Get input node " << node->ToString() << obj_id;
node = graph_info_map_.at(df_builder).node_map[obj_id].first;
MS_LOG(DEBUG) << "Get input param node " << node->ToString() << " " << obj_id;
return node;
}

if (graph_info_map_.at(curr_g_).second.node_map.find(obj_id) != graph_info_map_.at(curr_g_).second.node_map.end()) {
if (graph_info_map_.at(curr_g_).node_map.find(obj_id) != graph_info_map_.at(curr_g_).node_map.end()) {
// op(x, y)
// out = op(op1(x, y))
// out = op(cell1(x, y))
@@ -989,7 +992,7 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) {
node = MakeValueNode(obj, obj_id);
}
node == nullptr ? MS_LOG(DEBUG) << "Get node is nullptr"
: MS_LOG(DEBUG) << "Get input node " << node->ToString() << obj_id;
: MS_LOG(DEBUG) << "Get input node " << node->ToString() << " " << obj_id;
return node;
}

@@ -1077,14 +1080,14 @@ void PynativeExecutor::CleanTensorsInValueNode() {
}

AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj, const std::string &obj_id) {
auto &out = graph_info_map_.at(curr_g_).second.node_map[obj_id];
auto &out = graph_info_map_.at(curr_g_).node_map[obj_id];
if (out.second.size() == 1 && out.second[0] == -1) {
return out.first;
}
MS_LOG(DEBUG) << "Output size " << out.second.size();

// Params node
if (graph_info_map_.at(curr_g_).second.params.find(obj_id) != graph_info_map_.at(curr_g_).second.params.end()) {
if (graph_info_map_.at(curr_g_).params.find(obj_id) != graph_info_map_.at(curr_g_).params.end()) {
auto para_node = out.first;
for (auto &idx : out.second) {
std::vector<AnfNodePtr> tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), para_node,
@@ -1441,35 +1444,50 @@ bool PynativeExecutor::IsNotNestedGrad() const {
}

bool PynativeExecutor::IsTopGraph(const std::string &cell_id) {
return std::any_of(
top_cell_list_.begin(), top_cell_list_.end(),
[&cell_id](const std::pair<std::string, std::pair<ResourcePtr, std::pair<FuncGraphPtr, FuncGraphPtr>>> &value) {
return value.first == cell_id;
});
return std::any_of(top_cell_list_.begin(), top_cell_list_.end(),
[&cell_id](const TopCellInfo &value) { return value.cell_id == cell_id; });
}

void PynativeExecutor::SubNestedGradCount() {
bool PynativeExecutor::IsBpropGraph(const std::string &cell_id) {
return std::any_of(cell_graph_list_.begin(), cell_graph_list_.end(), [&cell_id](const CellInfo &value) {
return value.is_custom_bprop && cell_id.find(value.cell_id) != std::string::npos;
});
}

void PynativeExecutor::SubNestedGradOrder() {
if (grad_order_ > 0) {
--grad_order_;
}
}

bool PynativeExecutor::CheckCellGraph(const std::string &cell_id, bool is_grad) {
return std::any_of(cell_graph_list_.begin(), cell_graph_list_.end(),
[&cell_id, is_grad](const std::pair<std::string, std::pair<FuncGraphPtr, bool>> &value) {
return value.first == cell_id && (!is_grad || value.second.second);
});
return std::any_of(cell_graph_list_.begin(), cell_graph_list_.end(), [&cell_id, is_grad](const CellInfo &value) {
return value.cell_id == cell_id && (!is_grad || value.is_grad);
});
}

FuncGraphPtr PynativeExecutor::GetDfbuilder(const std::string &cell_id) {
// Cell is empty, get nearest dfbuilder
if (cell_id.empty() && !top_cell_list_.empty()) {
return top_cell_list_.back().second.second.first;
if (top_cell_list_.size() == 1) {
return top_cell_list_.begin()->df_builder;
}
if (grad_order_ == 0 || grad_order_ == 1) {
return top_cell_list_.back().df_builder;
}
if (top_cell_list_.size() < grad_order_) {
MS_LOG(EXCEPTION) << "Get wrong grad order";
}
MS_LOG(DEBUG) << "Get grad order " << grad_order_ << " top cell list size " << top_cell_list_.size();
// Grad order greater than 2
auto it = top_cell_list_.end();
std::advance(it, -2);
return it->df_builder;
}
// If top graph hold
for (const auto &it : top_cell_list_) {
if (cell_id.find(it.first) != std::string::npos) {
return it.second.second.first;
if (cell_id.find(it.cell_id) != std::string::npos) {
return it.df_builder;
}
}
return nullptr;
@@ -1478,15 +1496,31 @@ FuncGraphPtr PynativeExecutor::GetDfbuilder(const std::string &cell_id) {
ResourcePtr PynativeExecutor::GetResource(const std::string &cell_id) {
// Cell is empty, get nearest resource
if (cell_id.empty() && !top_cell_list_.empty()) {
return top_cell_list_.back().second.first;
if (top_cell_list_.size() == 1) {
return top_cell_list_.begin()->resource;
}
if (grad_order_ == 0 || grad_order_ == 1) {
return top_cell_list_.back().resource;
}
if (top_cell_list_.size() < grad_order_) {
MS_LOG(EXCEPTION) << "Get wrong grad order";
}
MS_LOG(DEBUG) << "Get grad order " << grad_order_ << " top cell list size " << top_cell_list_.size();
// Grad order greater than 2
auto it = top_cell_list_.end();
std::advance(it, -2);
return it->resource;
}
for (const auto &it : top_cell_list_) {
if (cell_id.find(it.first) != std::string::npos) {
return it.second.first;
if (cell_id.find(it.cell_id) != std::string::npos) {
return it.resource;
}
}
// Current cell is not top graph, get first top cell
return top_cell_list_.front().second.first;
if (!top_cell_list_.empty()) {
return top_cell_list_.front().resource;
}
return nullptr;
}

std::string PynativeExecutor::ParseNodeName(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node,
@@ -1674,7 +1708,8 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg
}
PushCurrentGraphToStack();
if (graph_info_map_.find(curr_g_) == graph_info_map_.end()) {
graph_info_map_.emplace(curr_g_, std::make_pair(cell_id, GraphInfo()));
GraphInfo graph_info = GraphInfo(cell_id);
graph_info_map_.emplace(curr_g_, graph_info);
}
for (size_t i = 0; i < args.size(); ++i) {
auto param = args[i];
@@ -1682,13 +1717,23 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg
std::string param_id = GetId(param);
SetTupleArgsToGraphInfoMap(curr_g_, param, new_param, true);
SetNodeMapInGraphInfoMap(curr_g_, param_id, new_param);
SetParamNodeMapInGraphInfoMap(curr_g_, param_id, nullptr);
SetParamNodeMapInGraphInfoMap(curr_g_, param_id, new_param);
}
// check whether the construct of cell will be changed
if (!dynamic_cell_) {
dynamic_cell_ = IsDynamicCell(cell);
MS_LOG(DEBUG) << "cell id: " << cell_id << ", is dynamic cell: " << dynamic_cell_;
}
// Make bprop graph
if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
auto it = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(), [&cell_id](const CellInfo &value) {
return value.is_custom_bprop && cell_id.find(value.cell_id) != std::string::npos;
});
if (it != cell_graph_list_.end()) {
MS_LOG(INFO) << "Make bprop graph";
it->custom_bprop_graph = true;
}
}
}

void PynativeExecutor::MakeNewTopGraph(const string &cell_id, const py::args &args, const FuncGraphPtr &g) {
@@ -1701,11 +1746,8 @@ void PynativeExecutor::MakeNewTopGraph(const string &cell_id, const py::args &ar
}
}
// Clear runop pre
auto it = std::find_if(
top_cell_list_.begin(), top_cell_list_.end(),
[&cell_id](const std::pair<std::string, std::pair<ResourcePtr, std::pair<FuncGraphPtr, FuncGraphPtr>>> &value) {
return value.first == cell_id;
});
auto it = std::find_if(top_cell_list_.begin(), top_cell_list_.end(),
[&cell_id](const TopCellInfo &value) { return value.cell_id == cell_id; });
if (it != top_cell_list_.end()) {
top_cell_list_.erase(it);
}
@@ -1714,10 +1756,12 @@ void PynativeExecutor::MakeNewTopGraph(const string &cell_id, const py::args &ar
op_index_with_tensor_id_.clear();

auto df_builder = std::make_shared<FuncGraph>();
graph_info_map_.emplace(df_builder, std::make_pair(cell_id, GraphInfo()));
GraphInfo graph_info = GraphInfo(cell_id);
graph_info_map_.emplace(df_builder, graph_info);
auto resource = std::make_shared<pipeline::Resource>();
resource->results()[pipeline::kPynativeGraphId] = graph_id_++;
top_cell_list_.emplace_back(std::make_pair(cell_id, std::make_pair(resource, std::make_pair(df_builder, nullptr))));
auto top_cell_info = TopCellInfo(resource, df_builder, nullptr, cell_id);
top_cell_list_.emplace_back(top_cell_info);
MS_LOG(DEBUG) << "New top graph, df_builder ptr " << df_builder.get() << " resource ptr " << resource.get();
}

@@ -1730,8 +1774,10 @@ void PynativeExecutor::SetTupleArgsToGraphInfoMap(const FuncGraphPtr &g, const p
auto tuple_size = static_cast<int64_t>(tuple.size());
for (int64_t i = 0; i < tuple_size; ++i) {
auto id = GetId(tuple[i]);
if (is_param) {
SetParamNodeMapInGraphInfoMap(g, id, nullptr);
if (is_param && node->isa<Parameter>()) {
auto param = node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param);
SetParamNodeMapInGraphInfoMap(g, id, param);
}
SetNodeMapInGraphInfoMap(g, id, node, i);
SetTupleItemArgsToGraphInfoMap(g, tuple[i], node, std::vector<int64_t>{i}, is_param);
@@ -1750,8 +1796,10 @@ void PynativeExecutor::SetTupleItemArgsToGraphInfoMap(const FuncGraphPtr &g, con
std::vector<int64_t> tmp = index_sequence;
tmp.emplace_back(i);
auto id = GetId(tuple[i]);
if (is_param) {
SetParamNodeMapInGraphInfoMap(g, id, nullptr);
if (is_param && node->isa<Parameter>()) {
auto param = node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param);
SetParamNodeMapInGraphInfoMap(g, id, param);
}
SetNodeMapInGraphInfoMap(g, id, node, tmp);
SetTupleItemArgsToGraphInfoMap(g, tuple[i], node, tmp, is_param);
@@ -1767,7 +1815,7 @@ void PynativeExecutor::EndGraphInner(const py::object &cell, const py::object &o
}
auto out_id = GetId(out);
// x =op1, y =op2, return (x, y)
if (graph_info_map_.at(curr_g_).second.node_map.find(out_id) == graph_info_map_.at(curr_g_).second.node_map.end()) {
if (graph_info_map_.at(curr_g_).node_map.find(out_id) == graph_info_map_.at(curr_g_).node_map.end()) {
if (py::isinstance<py::tuple>(out) || py::isinstance<py::list>(out)) {
auto tuple = out.cast<py::tuple>();
auto tuple_size = static_cast<int64_t>(tuple.size());
@@ -1795,27 +1843,33 @@ void PynativeExecutor::EndGraphByOutId(const py::object &cell, const std::string
MS_LOG(DEBUG) << "Current graph " << curr_g_->output()->DebugString();
auto resource = GetResource(cell_id);
MS_EXCEPTION_IF_NULL(resource);
resource->manager()->AddFuncGraph(curr_g_);
UpdateCellGraph(cell_id, true, false);

set_need_replace_forward(IsNotNestedGrad());
auto newfg = MakeGradGraph(cell, args, curr_g_, resource, IsTopGraph(cell_id));
auto is_bprop_graph = IsBpropGraph(cell_id);
auto is_bprop_cell = py::hasattr(cell, parse::CUSTOM_BPROP_NAME);
if (!is_bprop_cell || !is_bprop_graph) {
resource->manager()->AddFuncGraph(curr_g_);
}
if (!is_bprop_cell) {
UpdateCellGraph(cell, curr_g_, cell_id, true, false);
}
auto newfg = MakeGradGraph(cell, curr_g_, resource, cell_id, args);

if (graph_stack_.size() > 1) {
std::vector<AnfNodePtr> inputs;
inputs.emplace_back(NewValueNode(curr_g_));

PopGraphStack();
// connect the previous graph to the inside graph
auto graph_prev = graph_stack_.top();
for (size_t i = 0; i < args.size(); i++) {
auto input = GetInput(args[i], false);
inputs.emplace_back(input);
if (!is_bprop_cell || !is_bprop_graph) {
std::vector<AnfNodePtr> inputs;
inputs.emplace_back(NewValueNode(curr_g_));

PopGraphStack();
// connect the previous graph to the inside graph
auto graph_prev = graph_stack_.top();
for (size_t i = 0; i < args.size(); i++) {
auto input = GetInput(args[i], false);
inputs.emplace_back(input);
}
auto out_cnode = graph_prev->NewCNode(inputs);
SetPyObjInGraphInfoMap(graph_prev, GetCellId(cell, args));
SetTupleArgsToGraphInfoMap(graph_prev, out, out_cnode);
SetNodeMapInGraphInfoMap(graph_prev, GetId(out), out_cnode);
}
auto out_cnode = graph_prev->NewCNode(inputs);
SetPyObjInGraphInfoMap(graph_prev, GetCellId(cell, args));
SetTupleArgsToGraphInfoMap(graph_prev, out, out_cnode);
SetNodeMapInGraphInfoMap(graph_prev, GetId(out), out_cnode);
} else {
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
DumpIR("before_resolve.ir", newfg);
@@ -1829,40 +1883,54 @@ void PynativeExecutor::EndGraphByOutId(const py::object &cell, const std::string
}
}

void PynativeExecutor::UpdateCellGraph(const std::string &cell_id, bool need_cloned, bool is_grad) {
FuncGraphPtr tmp = curr_g_;
void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id,
bool need_cloned, bool is_grad) {
if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
// Bprop just save backward graph
auto it = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(),
[&cell_id](const CellInfo &value) { return value.cell_id == cell_id; });
if (it != cell_graph_list_.end()) {
it->is_grad = is_grad;
it->fg = g;
MS_LOG(DEBUG) << "Update bprop bg";
} else {
auto cell_info = CellInfo(false, true, false, g, cell_id);
cell_graph_list_.insert(cell_graph_list_.begin(), cell_info);
}
return;
}
FuncGraphPtr tmp = g;
if (need_cloned && !IsNotNestedGrad()) {
auto cloned_curr_g = BasicClone(curr_g_);
graph_info_map_[cloned_curr_g] = graph_info_map_.at(curr_g_);
auto cloned_curr_g = BasicClone(g);
graph_info_map_[cloned_curr_g] = graph_info_map_.at(g);
tmp = cloned_curr_g;
MS_LOG(DEBUG) << "Replace cur graph " << curr_g_.get() << " with cloned new " << cloned_curr_g.get();
MS_LOG(DEBUG) << "Replace cur graph " << g.get() << " with cloned new " << cloned_curr_g.get();
}
for (auto &it : cell_graph_list_) {
if (it.first != cell_id) {
if (it.cell_id != cell_id) {
continue;
}
it.second.second = is_grad;
it.is_grad = is_grad;
if (need_cloned) {
it.second.first = tmp;
it.fg = tmp;
}
if (!need_cloned && !is_grad) {
graph_info_map_[curr_g_] = graph_info_map_.at(it.second.first);
graph_info_map_.erase(it.second.first);
it.second.first = curr_g_;
MS_LOG(DEBUG) << "Replace cur graph " << it.second.first.get() << " with new " << curr_g_.get();
graph_info_map_[g] = graph_info_map_.at(it.fg);
graph_info_map_.erase(it.fg);
it.fg = g;
MS_LOG(DEBUG) << "Replace cur graph " << it.fg.get() << " with new " << g.get();
}
return;
}
MS_LOG(DEBUG) << "Add new cell graph " << cell_id;
cell_graph_list_.insert(cell_graph_list_.begin(), std::make_pair(cell_id, std::make_pair(tmp, false)));
auto cell_info = CellInfo(false, true, false, tmp, cell_id);
cell_graph_list_.insert(cell_graph_list_.begin(), cell_info);
}

FuncGraphPtr PynativeExecutor::MakeGradGraph(const py::object &cell, const py::args &args, const FuncGraphPtr &g,
const ResourcePtr &r, bool is_top) {
// custom bprop debug
bool need_replace_param = false;
if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
need_replace_param = true;
FuncGraphPtr PynativeExecutor::MakeGradGraph(const py::object &cell, const FuncGraphPtr &g, const ResourcePtr &r,
const string &cell_id, const py::args &args) {
bool is_custom_bprop = py::hasattr(cell, parse::CUSTOM_BPROP_NAME) && !IsBpropGraph(cell_id);
if (is_custom_bprop) {
size_t par_number = py::tuple(parse::python_adapter::CallPyObjMethod(cell, "get_parameters")).size();
if (par_number > 0) {
MS_LOG(EXCEPTION) << "When user defines the net bprop, there are " << par_number
@@ -1875,14 +1943,18 @@ FuncGraphPtr PynativeExecutor::MakeGradGraph(const py::object &cell, const py::a
(void)bprop_graph->transforms().emplace(std::make_pair("primal", FuncGraphTransform(g)));
}
}
// Obtain grad graph
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
DumpIR("fg.ir", g);
FuncGraphPtr newfg = nullptr;
if (!py::hasattr(cell, parse::CUSTOM_BPROP_NAME) || is_custom_bprop) {
// Obtain grad graph
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
DumpIR("fg.ir", g);
}
auto is_top = IsTopGraph(cell_id);
MS_LOG(DEBUG) << "Grad top cell " << is_top;
set_need_replace_forward(IsNotNestedGrad());
newfg = ad::Grad(g, r, is_top);
}
MS_LOG(DEBUG) << "Grad top cell " << is_top;
auto newfg = ad::Grad(g, r, is_top);

if (need_replace_param) {
if (is_custom_bprop) {
auto params = newfg->parameters();
auto manager = Manage({newfg}, false);
if (args.size() > params.size()) {
@@ -1894,6 +1966,7 @@ FuncGraphPtr PynativeExecutor::MakeGradGraph(const py::object &cell, const py::a
auto v_node = NewValueNode(value);
manager->Replace(params[i], v_node);
}
UpdateCellGraph(cell, newfg, cell_id, false, false);
}
return newfg;
}
@@ -1965,12 +2038,12 @@ void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::obje
resource->manager()->KeepRoots({df_builder});
resource->results()[pipeline::kBackend] = compile::CreateBackend();

MS_LOG(DEBUG) << "Start opt";
MS_LOG(INFO) << "Start opt";
PynativeOptimizeAction(resource);
SaveTensorsInValueNode(resource);
TaskEmitAction(resource);
ExecuteAction(resource);
UpdateCellGraph(cell_id, false, true);
UpdateCellGraph(cell, curr_g_, cell_id, false, true);
UpdateGraphInfoMap(cell_id);
resource->Clean();
}
@@ -2018,13 +2091,10 @@ void PynativeExecutor::SetNestedTopGraph(const py::object &cell, const py::args
return;
}
ResourcePtr resource = nullptr;
auto ia = std::find_if(
top_cell_list_.begin(), top_cell_list_.end(),
[&cell_id](const std::pair<std::string, std::pair<ResourcePtr, std::pair<FuncGraphPtr, FuncGraphPtr>>> &value) {
return value.first == cell_id;
});
auto ia = std::find_if(top_cell_list_.begin(), top_cell_list_.end(),
[&cell_id](const TopCellInfo &value) { return value.cell_id == cell_id; });
if (ia != top_cell_list_.end()) {
resource = GetResource(ia->first);
resource = GetResource(ia->cell_id);
MS_EXCEPTION_IF_NULL(resource);
MS_LOG(DEBUG) << "Find old resource " << resource.get();
}
@@ -2035,22 +2105,33 @@ void PynativeExecutor::SetNestedTopGraph(const py::object &cell, const py::args
}
MS_EXCEPTION_IF_NULL(resource);
FuncGraphPtr df_builder = std::make_shared<FuncGraph>();
graph_info_map_.emplace(df_builder, std::make_pair(cell_id, GraphInfo()));
top_cell_list_.emplace_back(std::make_pair(cell_id, std::make_pair(resource, std::make_pair(df_builder, nullptr))));
GraphInfo graph_info = GraphInfo(cell_id);
graph_info_map_.emplace(df_builder, graph_info);
auto top_cell_info = TopCellInfo(resource, df_builder, nullptr, cell_id);
top_cell_list_.emplace_back(top_cell_info);
FuncGraphPtr forward_graph = nullptr;
auto ib = std::find_if(
cell_graph_list_.begin(), cell_graph_list_.end(),
[&cell_id](const std::pair<std::string, std::pair<FuncGraphPtr, bool>> &value) { return value.first == cell_id; });
auto ib = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(),
[&cell_id](const CellInfo &value) { return value.cell_id == cell_id; });
if (ib != cell_graph_list_.end()) {
forward_graph = ib->second.first;
forward_graph = ib->fg;
}
MS_EXCEPTION_IF_NULL(forward_graph);
if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
DumpIR("nested_bprop.ir", forward_graph);
}
// Custom bprop get backward graph(before opt), which use like other forward graph
curr_g_ = forward_graph;
resource->set_func_graph(forward_graph);
return;
}

// Copy weights
std::vector<AnfNodePtr> weights_params{};
for (const auto &it : graph_info_map_.at(forward_graph).second.params) {
if (it.second != nullptr) {
for (const auto &it : graph_info_map_.at(forward_graph).params) {
if (it.second->has_default()) {
weights_params.emplace_back(it.second);
graph_info_map_.at(df_builder).second.params.emplace(it.first, it.second);
graph_info_map_.at(df_builder).params.emplace(it.first, it.second);
SetNodeMapInGraphInfoMap(df_builder, it.first, it.second);
}
}
@@ -2061,7 +2142,7 @@ void PynativeExecutor::SetNestedTopGraph(const py::object &cell, const py::args
DumpIR("nested_fg.ir", forward_graph);
}
set_need_replace_forward(false);
auto newfg = MakeGradGraph(cell, args, forward_graph, resource, IsTopGraph(cell_id));
auto newfg = MakeGradGraph(cell, forward_graph, resource, cell_id, args);
resource->set_func_graph(newfg);
}

@@ -2091,11 +2172,9 @@ std::vector<AnfNodePtr> PynativeExecutor::GetWeightsArgs(const py::object &weigh
auto param = tuple[it];
auto param_id = GetId(param);
AnfNodePtr para_node = nullptr;
if (graph_info_map_.at(df_builder).second.params.find(param_id) !=
graph_info_map_.at(df_builder).second.params.end() &&
graph_info_map_.at(df_builder).second.node_map.find(param_id) !=
graph_info_map_.at(df_builder).second.node_map.end()) {
para_node = graph_info_map_.at(df_builder).second.node_map[param_id].first;
if (graph_info_map_.at(df_builder).params.find(param_id) != graph_info_map_.at(df_builder).params.end() &&
graph_info_map_.at(df_builder).node_map.find(param_id) != graph_info_map_.at(df_builder).node_map.end()) {
para_node = graph_info_map_.at(df_builder).node_map[param_id].first;
} else {
auto name_attr = parse::python_adapter::GetPyObjAttr(param, "name");
if (py::isinstance<py::none>(name_attr)) {
@@ -2117,6 +2196,10 @@ std::vector<AnfNodePtr> PynativeExecutor::GetWeightsArgs(const py::object &weigh
abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args, const FuncGraphPtr &df_builder) {
abstract::AbstractBasePtrList args_spec;
std::size_t size = args.size();
auto df_params = df_builder->parameters();
if (df_params.size() < size) {
MS_LOG(EXCEPTION) << "Df parameters size " << df_params.size() << " less than " << size;
}
// input params
for (std::size_t i = 0; i < size; i++) {
ValuePtr converted = nullptr;
@@ -2127,11 +2210,11 @@ abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args
bool broaden = true;
auto abs = abstract::FromValue(converted, broaden);
args_spec.emplace_back(abs);
auto param_node = std::static_pointer_cast<Parameter>(df_builder->parameters()[i]);
auto param_node = std::static_pointer_cast<Parameter>(df_params[i]);
param_node->set_abstract(abs);
}
// weights params
for (const auto &param : df_builder->parameters()) {
for (const auto &param : df_params) {
auto param_node = std::static_pointer_cast<Parameter>(param);
if (param_node->has_default()) {
ValuePtr value = param_node->default_param();
@@ -2148,24 +2231,18 @@ abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args
bool PynativeExecutor::CloneDfbuiler(const std::string &cell_id, const FuncGraphPtr &df_builder,
const ResourcePtr &resource) {
bool is_cloned = false;
std::pair<ResourcePtr, std::pair<FuncGraphPtr, FuncGraphPtr>> r(
std::make_pair(nullptr, std::make_pair(nullptr, nullptr)));
auto it = std::find_if(
top_cell_list_.begin(), top_cell_list_.end(),
[&cell_id](const std::pair<std::string, std::pair<ResourcePtr, std::pair<FuncGraphPtr, FuncGraphPtr>>> &value) {
return value.first == cell_id;
});
if (it != top_cell_list_.end()) {
r = it->second;
auto it = std::find_if(top_cell_list_.begin(), top_cell_list_.end(),
[&cell_id](const TopCellInfo &value) { return value.cell_id == cell_id; });
if (it == top_cell_list_.end()) {
MS_LOG(EXCEPTION) << "Get top cell failed";
}
MS_EXCEPTION_IF_NULL(r.first);
if (r.second.second == nullptr) {
if (it->bg == nullptr) {
auto cloned_df_newfg = BasicClone(resource->func_graph());
r.second = std::make_pair(df_builder, cloned_df_newfg);
it->bg = cloned_df_newfg;
MS_LOG(DEBUG) << "Cloned df newfg";
is_cloned = false;
} else {
resource->set_func_graph(r.second.second);
resource->set_func_graph(it->bg);
MS_LOG(DEBUG) << "Used cloned df newfg";
}
return is_cloned;
@@ -2174,11 +2251,10 @@ bool PynativeExecutor::CloneDfbuiler(const std::string &cell_id, const FuncGraph
void PynativeExecutor::GradGraph(const FuncGraphPtr &g, const GradOperationPtr &grad_op,
const std::vector<AnfNodePtr> &weights, size_t arg_size, const std::string &cell_id) {
FuncGraphPtr top_g = nullptr;
auto it = std::find_if(
cell_graph_list_.begin(), cell_graph_list_.end(),
[&cell_id](const std::pair<std::string, std::pair<FuncGraphPtr, bool>> &value) { return value.first == cell_id; });
auto it = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(),
[&cell_id](const CellInfo &value) { return value.cell_id == cell_id; });
if (it != cell_graph_list_.end()) {
top_g = it->second.first;
top_g = it->fg;
}
MS_EXCEPTION_IF_NULL(top_g);
auto nparam = top_g->parameters().size();
@@ -2194,8 +2270,12 @@ void PynativeExecutor::GradGraph(const FuncGraphPtr &g, const GradOperationPtr &

auto df = grad_op->GetGrad(NewValueNode(g), nullptr, top_g->parameters(), weights);
std::vector<AnfNodePtr> inputs = {NewValueNode(df)};
auto df_params = df_builder->parameters();
if (df_params.size() < arg_size) {
MS_LOG(EXCEPTION) << "Df parameters size " << df_params.size() << " less than " << arg_size;
}
for (size_t i = 0; i < arg_size; ++i) {
inputs.emplace_back(df_builder->parameters()[i]);
inputs.emplace_back(df_params[i]);
}
auto out = df_builder->NewCNode(inputs);
df_builder->set_output(out);
@@ -2208,17 +2288,17 @@ void PynativeExecutor::UpdateGraphInfoMap(const std::string &cell_id) {
bool index_find = false;
for (const auto &it : cell_graph_list_) {
if (index_find) {
l.emplace_back(it.first);
l.emplace_back(it.cell_id);
continue;
}
if (it.first == cell_id) {
if (it.cell_id == cell_id) {
index_find = true;
l.emplace_back(it.first);
l.emplace_back(it.cell_id);
}
}
for (const auto &it : l) {
for (auto ic = graph_info_map_.begin(); ic != graph_info_map_.end();) {
if (ic->second.first.find(it) != std::string::npos) {
if (ic->second.cell_id.find(it) != std::string::npos) {
ic = graph_info_map_.erase(ic);
} else {
++ic;
@@ -2229,7 +2309,7 @@ void PynativeExecutor::UpdateGraphInfoMap(const std::string &cell_id) {

py::object PynativeExecutor::CheckGraph(const py::object &cell, const py::args &args) {
BaseRef ret = false;
AddNestedGradCount();
AddNestedGradOrder();
if (!grad_running()) {
MS_LOG(DEBUG) << "Grad not running yet";
return BaseRefToPyData(ret);
@@ -2238,8 +2318,8 @@ py::object PynativeExecutor::CheckGraph(const py::object &cell, const py::args &
string key = cell_id.substr(0, std::min(PTR_LEN, cell_id.size()));
MS_LOG(DEBUG) << "Key is " << key;
for (auto it = cell_graph_list_.begin(); it != cell_graph_list_.end(); ++it) {
MS_LOG(DEBUG) << "Cur cell id " << it->first;
if (key != it->first.substr(0, std::min(PTR_LEN, it->first.size()))) {
MS_LOG(DEBUG) << "Cur cell id " << it->cell_id;
if (key != it->cell_id.substr(0, std::min(PTR_LEN, it->cell_id.size()))) {
continue;
}
MS_LOG(DEBUG) << "Delete cellid from cell graph list";
@@ -2255,7 +2335,7 @@ py::object PynativeExecutor::Run(const py::object &cell, const py::tuple &args,
MS_LOG(DEBUG) << "Run start cell id " << cell_id;
bool has_sens = false;
for (const auto &it : top_cell_list_) {
if (cell_id.find(it.first) != std::string::npos && cell_id != it.first) {
if (cell_id.find(it.cell_id) != std::string::npos && cell_id != it.cell_id) {
has_sens = true;
break;
}
@@ -2285,12 +2365,42 @@ py::object PynativeExecutor::Run(const py::object &cell, const py::tuple &args,
BaseRef value = (*run)(arg_list);
CleanTensorsInValueNode();
set_grad_runing(false);
MS_LOG(DEBUG) << "Run end " << value.ToString();
MS_LOG(DEBUG) << "Eval run end " << value.ToString();
auto out = BaseRefToPyData(value);
if (MakeBpropNestedCnode(cell, out, cell_id)) {
return out;
}
MakeNestedCnode(cell_id, args, resource, out, has_sens);
return out;
}

bool PynativeExecutor::MakeBpropNestedCnode(const py::object &cell, const py::object &out, const std::string &cell_id) {
if (graph_stack_.empty() || !py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
MS_LOG(DEBUG) << "No nested bprop grad find";
return false;
}
auto it = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(), [&cell_id](const CellInfo &value) {
return value.custom_bprop_graph && value.is_custom_bprop && cell_id.find(value.cell_id) != std::string::npos;
});
if (it != cell_graph_list_.end()) {
MS_LOG(DEBUG) << "Make bprop graph end";
}
auto out_id = GetId(out);
std::vector<AnfNodePtr> inputs;
inputs.emplace_back(NewValueNode(curr_g_));
PopGraphStack();
for (const auto &ig : graph_info_map_.at(curr_g_).params) {
if (!ig.second->has_default()) {
inputs.emplace_back(ig.second);
}
}
auto cnode = curr_g_->NewCNode(inputs);
SetTupleArgsToGraphInfoMap(curr_g_, out, cnode);
SetNodeMapInGraphInfoMap(curr_g_, out_id, cnode);
MS_LOG(DEBUG) << "Custom bprop make nested node is " << cnode->DebugString(4);
return true;
}

void PynativeExecutor::MakeNestedCnode(const std::string &cell_id, const py::args &args, const ResourcePtr &resource,
const py::object &out, bool has_sens) {
if (graph_stack_.empty()) {
@@ -2313,15 +2423,15 @@ void PynativeExecutor::MakeNestedCnode(const std::string &cell_id, const py::arg
}
auto out_id = GetId(out);
auto cnode = graph_prev->NewCNode(inputs);
MS_LOG(DEBUG) << "Nested make cnode is " << cnode->DebugString(4);
SetTupleArgsToGraphInfoMap(graph_prev, out, cnode);
SetNodeMapInGraphInfoMap(graph_prev, out_id, cnode);
MS_LOG(DEBUG) << "Nested make cnode is " << cnode->DebugString(4);
}

template <typename T>
void MapClear(T *map, const std::string &flag) {
void MapClear(T *map, const std::string &cell_id) {
for (auto it = map->begin(); it != map->end();) {
if (it->first.find(flag) != std::string::npos) {
if (it->first.find(cell_id) != std::string::npos) {
it = map->erase(it);
} else {
it++;
@@ -2329,6 +2439,17 @@ void MapClear(T *map, const std::string &flag) {
}
}

template <typename T>
void VectorClear(T *vec, const std::string &cell_id) {
for (auto it = vec->begin(); it != vec->end();) {
if (it->cell_id.find(cell_id) != std::string::npos) {
it = vec->erase(it);
} else {
it++;
}
}
}

void PynativeExecutor::Clear(const std::string &cell_id) {
if (cell_id.empty()) {
Clean();
@@ -2337,7 +2458,7 @@ void PynativeExecutor::Clear(const std::string &cell_id) {

MS_LOG(DEBUG) << "Clear cell res, cell id " << cell_id;
for (auto it = graph_info_map_.begin(); it != graph_info_map_.end();) {
if (it->second.first.find(cell_id) != std::string::npos) {
if (it->second.cell_id.find(cell_id) != std::string::npos) {
it = graph_info_map_.erase(it);
} else {
++it;
@@ -2351,14 +2472,14 @@ void PynativeExecutor::Clear(const std::string &cell_id) {
ConfigManager::GetInstance().ResetIterNum();
MapClear<std::unordered_map<std::string, bool>>(&cell_dynamic_map_, cell_id);
MapClear<std::unordered_map<std::string, std::pair<std::string, std::string>>>(&cell_sw_map_, cell_id);
MapClear<std::vector<std::pair<std::string, std::pair<FuncGraphPtr, bool>>>>(&cell_graph_list_, cell_id);
MapClear<std::vector<std::pair<std::string, std::pair<ResourcePtr, std::pair<FuncGraphPtr, FuncGraphPtr>>>>>(
&top_cell_list_, cell_id);
VectorClear<std::vector<CellInfo>>(&cell_graph_list_, cell_id);
VectorClear<std::vector<TopCellInfo>>(&top_cell_list_, cell_id);
node_abs_map_.clear();
}

void PynativeExecutor::Clean() {
MS_LOG(DEBUG) << "Clean";
SubNestedGradCount();
SubNestedGradOrder();
node_abs_map_.clear();
obj_to_forward_id_.clear();
ad::CleanRes();


+ 47
- 17
mindspore/ccsrc/pipeline/pynative/pynative_execute.h View File

@@ -58,10 +58,38 @@ py::tuple RunOp(const py::args &args);
void ClearPyNativeSession();

struct GraphInfo {
std::unordered_map<std::string, AnfNodePtr> params; // hold input parameters and cell weigths
std::unordered_map<std::string, std::pair<AnfNodePtr, std::vector<int64_t>>> node_map;
std::string cell_id;
AnfNodePtr output;
std::unordered_map<std::string, ParameterPtr> params; // hold input parameters and cell weigths
std::unordered_map<std::string, std::pair<AnfNodePtr, std::vector<int64_t>>> node_map;
std::vector<std::string> objects;
GraphInfo() = default;
explicit GraphInfo(std::string id) : cell_id(std::move((id))) {}
};

struct CellInfo {
bool is_grad{false}; // Derivative is calculated
bool is_custom_bprop{false}; // Custom bprop
bool custom_bprop_graph{false}; // Custom bprop make forward graph
FuncGraphPtr fg; // Forward graph
std::string cell_id;
CellInfo() = default;
CellInfo(bool isgrad, bool custom_bprop, bool bprop_graph, FuncGraphPtr foward_graph, std::string cellid)
: is_grad(isgrad),
is_custom_bprop(custom_bprop),
custom_bprop_graph(bprop_graph),
fg(std::move(foward_graph)),
cell_id(std::move(cellid)) {}
};

struct TopCellInfo {
ResourcePtr resource;
FuncGraphPtr df_builder;
FuncGraphPtr bg; // Backward graph
std::string cell_id;
TopCellInfo() = default;
TopCellInfo(ResourcePtr r, FuncGraphPtr df, FuncGraphPtr backward_graph, std::string cellid)
: resource(std::move(r)), df_builder(std::move(df)), bg(std::move(backward_graph)), cell_id(std::move(cellid)) {}
};

class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
@@ -147,23 +175,25 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
void PopGraphStack();
FuncGraphPtr GetDfbuilder(const std::string &cell_id = "");
ResourcePtr GetResource(const std::string &cell_id = "");
void AddNestedGradCount() { ++grad_order_; }
void SubNestedGradCount();
void AddNestedGradOrder() { ++grad_order_; }
void SubNestedGradOrder();
bool IsNotNestedGrad() const;
bool IsTopGraph(const std::string &cell_id);
bool IsBpropGraph(const std::string &cell_id);
bool grad_running() const { return grad_is_running_; }
void set_grad_runing(bool grad_runing) { grad_is_running_ = grad_runing; }
void set_need_replace_forward(bool need_replace_forward) { need_replace_forward_ = need_replace_forward; }
bool need_construct_graph() { return !graph_stack_.empty() && grad_flag_; }
bool CheckCellGraph(const std::string &cell_id, bool is_grad = false);
void UpdateCellGraph(const std::string &cell_id, bool need_cloned = false, bool is_grad = false);
void UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id,
bool need_cloned = false, bool is_grad = false);
void NewGraphInner(const py::object &cell, const py::args &args);
void MakeNewTopGraph(const string &cell_id, const py::args &args, const FuncGraphPtr &g);
void EndGraphInner(const py::object &cell, const py::object &out, const py::args &args);
void EndGraphByOutId(const py::object &cell, const std::string &cell_id, const py::object &out,
const std::string &out_id, const py::args &args);
FuncGraphPtr MakeGradGraph(const py::object &cell, const py::args &args, const FuncGraphPtr &g, const ResourcePtr &r,
bool is_top);
FuncGraphPtr MakeGradGraph(const py::object &cell, const FuncGraphPtr &g, const ResourcePtr &r, const string &cell_id,
const py::args &args);
std::string GetGradCellId(bool has_sens, const py::object &cell, const py::args &args, py::object *forward_args,
py::object *sens = nullptr);
void GradNetInner(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
@@ -180,23 +210,24 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
void SetNestedTopGraph(const py::object &cell, const py::args &args, const std::string &cell_id);
void MakeNestedCnode(const std::string &cell_id, const py::args &args, const ResourcePtr &resource,
const py::object &out, bool has_sens);
bool MakeBpropNestedCnode(const py::object &cell, const py::object &out, const std::string &cell_id);

// Hold graph(forward and grad) info
void SetPyObjInGraphInfoMap(const FuncGraphPtr &g, const std::string &obj) {
graph_info_map_[g].second.objects.push_back(obj);
graph_info_map_[g].objects.push_back(obj);
}
void SetTupleArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &args, const AnfNodePtr &node,
bool is_param = false);
void SetParamNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const ParameterPtr &param) {
graph_info_map_[g].second.params.emplace(std::make_pair(id, param));
graph_info_map_[g].params.emplace(std::make_pair(id, param));
}
void SetNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const AnfNodePtr &node,
int64_t index = -1) {
graph_info_map_[g].second.node_map[id] = std::make_pair(node, std::vector<int64_t>{index});
graph_info_map_[g].node_map[id] = std::make_pair(node, std::vector<int64_t>{index});
}
void SetNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const AnfNodePtr &node,
const std::vector<int64_t> &index) {
graph_info_map_[g].second.node_map[id] = std::make_pair(node, index);
graph_info_map_[g].node_map[id] = std::make_pair(node, index);
}
void SetTupleItemArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &id, const AnfNodePtr &node,
const std::vector<int64_t> &index_sequence, bool is_param = false);
@@ -204,7 +235,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
static std::shared_ptr<PynativeExecutor> executor_;
static std::mutex instance_lock_;
static int64_t graph_id_;
int64_t grad_order_{0};
size_t grad_order_{0};
bool grad_flag_{false};
bool dynamic_cell_{false};
bool grad_is_running_{false};
@@ -218,13 +249,12 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
std::unordered_set<std::string> cell_input_args_;
std::unordered_map<std::string, bool> cell_dynamic_map_;
// Record all info for all cells
std::unordered_map<FuncGraphPtr, std::pair<std::string, GraphInfo>> graph_info_map_;
std::unordered_map<FuncGraphPtr, GraphInfo> graph_info_map_;
// Use vector for keep order
std::vector<CellInfo> cell_graph_list_;
std::vector<TopCellInfo> top_cell_list_;
// key: cell_id, value: (send_id, weighs_id), cache for sens and weight change
std::unordered_map<std::string, std::pair<std::string, std::string>> cell_sw_map_;
// key: cell_id, value: (forward graph, whether grad), use vector for keep order
std::vector<std::pair<std::string, std::pair<FuncGraphPtr, bool>>> cell_graph_list_;
// key: cell_id, value: (resource, (df_builder, grad graph), use vector for keep order
std::vector<std::pair<std::string, std::pair<ResourcePtr, std::pair<FuncGraphPtr, FuncGraphPtr>>>> top_cell_list_;

// Used for runop and replace forward result of grad graph
std::unordered_map<std::string, size_t> op_index_map_;


+ 25
- 3
mindspore/ccsrc/pybind_api/ir/primitive_py.cc View File

@@ -29,6 +29,7 @@
#include "utils/ms_context.h"
#include "utils/primitive_utils.h"
#include "pipeline/jit/resource.h"
#include "pipeline/pynative/pynative_execute.h"

namespace mindspore {
namespace {
@@ -171,11 +172,32 @@ BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const {
bool is_bprop = this->HasAttr(kBpropAttrName);
if (is_bprop) {
SyncData(py_args);
auto size = py_args.size();
py::tuple input_args(size - 2);
for (size_t i = 0; i < size - 2; ++i) {
input_args[i] = py_args[i];
}
py::tuple convert_args(py_args.size());
ConvertCTensorToPyTensor(py_args, &convert_args);
py::object grads_obj = hook_(*convert_args);
py::tuple grads = check_bprop_out(grads_obj, py_args);
return std::make_shared<PyObjectRef>(grads);
auto inst = pynative::PynativeExecutor::GetInstance();
MS_EXCEPTION_IF_NULL(inst);
try {
inst->NewGraph(GetPyObj(), input_args.cast<py::args>());
py::object grads_obj = hook_(*convert_args);
py::tuple grads = check_bprop_out(grads_obj, py_args);
inst->EndGraph(GetPyObj(), grads_obj, input_args.cast<py::args>());
return std::make_shared<PyObjectRef>(grads);
} catch (const py::type_error &ex) {
inst->ClearRes();
throw py::type_error(ex);
} catch (const py::value_error &ex) {
inst->ClearRes();
throw py::value_error(ex);
} catch (...) {
inst->ClearRes();
std::string exName(abi::__cxa_current_exception_type()->name());
MS_LOG(EXCEPTION) << "Error occurred in run bprop. Exception name: " << exName;
}
}
SyncData(py_args[2]);
bool is_cell = this->HasAttr(kCellHookAttrName);


Loading…
Cancel
Save