|
|
|
@@ -172,6 +172,7 @@ FuncGraphPtr BuildFakeBProp(const PrimitivePtr &prim, size_t inputs_num) { |
|
|
|
for (size_t i = 0; i < inputs_num; ++i) { |
|
|
|
// Mock params for inputs |
|
|
|
auto param = func_graph->add_parameter(); |
|
|
|
MS_EXCEPTION_IF_NULL(param); |
|
|
|
// Mock derivatives for each inputs |
|
|
|
outputs.push_back(fake_input_sens); |
|
|
|
} |
|
|
|
@@ -191,10 +192,10 @@ class PynativeAdjoint { |
|
|
|
: tape_(tape), op_args_(op_args), out_(out), fg_(fg), fg_type_(fg_type) {} |
|
|
|
|
|
|
|
// Returns a mutable reference to users_, so callers can modify the list in place
// (e.g. `adjoint->users().push_back(cnode)`).
AnfNodePtrList &users() { return users_; } |
|
|
|
// Returns a read-only reference to op_args_, the member initialised from the constructor's `op_args` argument.
// NOTE(review): this overload is not const-qualified; a const-qualified twin follows in this view,
// so this is presumably the pre-change side of the diff — confirm.
const ValuePtrList &op_args() { return op_args_; } |
|
|
|
// Returns a read-only reference to out_, the member initialised from the constructor's `out` argument.
// NOTE(review): not const-qualified; presumably the pre-change side of the diff — confirm.
const ValuePtr &out() { return out_; } |
|
|
|
// Returns a read-only reference to fg_, the member initialised from the constructor's `fg` argument.
// NOTE(review): not const-qualified; presumably the pre-change side of the diff — confirm.
const FuncGraphPtr &fg() { return fg_; } |
|
|
|
// Returns a read-only reference to fg_type_, the member initialised from the constructor's `fg_type` argument.
// NOTE(review): not const-qualified; presumably the pre-change side of the diff — confirm.
const FuncGraphType &fg_type() { return fg_type_; } |
|
|
|
// Const-qualified accessor: returns a read-only reference to op_args_ and can be called on a const object.
const ValuePtrList &op_args() const { return op_args_; } |
|
|
|
// Const-qualified accessor: returns a read-only reference to out_ and can be called on a const object.
const ValuePtr &out() const { return out_; } |
|
|
|
// Const-qualified accessor: returns a read-only reference to fg_ and can be called on a const object.
const FuncGraphPtr &fg() const { return fg_; } |
|
|
|
// Const-qualified accessor: returns a read-only reference to fg_type_ and can be called on a const object.
// Callers in this view compare the result against PynativeAdjoint::kBackwardPropagate.
const FuncGraphType &fg_type() const { return fg_type_; } |
|
|
|
AnfNodePtr RealDout() { |
|
|
|
if (dout_ != nullptr) { |
|
|
|
return dout_; |
|
|
|
@@ -249,11 +250,11 @@ class KPynativeCellImpl : public KPynativeCell { |
|
|
|
tape_->debug_info()->set_name("grad_top"); |
|
|
|
for (size_t i = 0; i < cell_inputs.size(); ++i) { |
|
|
|
TraceGuard trace_guard(std::make_shared<TraceCopy>(cell_inputs[i]->debug_info())); |
|
|
|
tape_->add_parameter(); |
|
|
|
(void)tape_->add_parameter(); |
|
|
|
// Build adjoint for every input parameter |
|
|
|
auto input_adjoint = |
|
|
|
std::make_shared<PynativeAdjoint>(tape_, ValuePtrList{}, input_param_values[i], FuncGraphPtr(nullptr)); |
|
|
|
anfnode_to_adjoin_.insert(std::make_pair(cell_inputs[i], input_adjoint)); |
|
|
|
(void)anfnode_to_adjoin_.insert(std::make_pair(cell_inputs[i], input_adjoint)); |
|
|
|
} |
|
|
|
} |
|
|
|
// Compiler-generated destructor; `override` requires the base class destructor to be virtual
// (the hunk header shows KPynativeCellImpl deriving publicly from KPynativeCell).
~KPynativeCellImpl() override = default; |
|
|
|
@@ -348,11 +349,11 @@ FuncGraphPtr KPynativeCellImpl::Finish(const AnfNodePtrList &weights, bool grad_ |
|
|
|
SetSensAndWeights(weights, has_sens_arg); |
|
|
|
// Build forward CNode; |
|
|
|
if (build_formal_param) { |
|
|
|
BuildKNode(); |
|
|
|
(void)BuildKNode(); |
|
|
|
} |
|
|
|
// BackPropagate sensitivity, except when the last node is a valuenode which may be obtained by constant folding; |
|
|
|
if (!last_node_->isa<ValueNode>()) { |
|
|
|
BackPropagate(!build_formal_param); |
|
|
|
(void)BackPropagate(!build_formal_param); |
|
|
|
} |
|
|
|
// Return the gradient; |
|
|
|
SetOutput(weights, grad_inputs, grad_weights); |
|
|
|
@@ -394,7 +395,7 @@ bool KPynativeCellImpl::KPynativeOp(const CNodePtr &cnode, const ValuePtrList &o |
|
|
|
} |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(bprop_fg); |
|
|
|
BuildAdjoint(cnode, op_args, out, bprop_fg); |
|
|
|
(void)BuildAdjoint(cnode, op_args, out, bprop_fg); |
|
|
|
|
|
|
|
return true; |
|
|
|
} |
|
|
|
@@ -413,7 +414,7 @@ bool KPynativeCellImpl::KPynativeWithBProp(const CNodePtr &cnode, const ValuePtr |
|
|
|
MS_LOG(EXCEPTION) << "Should be func graph, but: " << cnode->DebugString(); |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(bprop_fg); |
|
|
|
BuildAdjoint(cnode, op_args, out, bprop_fg); |
|
|
|
(void)BuildAdjoint(cnode, op_args, out, bprop_fg); |
|
|
|
|
|
|
|
return true; |
|
|
|
} |
|
|
|
@@ -423,7 +424,7 @@ bool KPynativeCellImpl::KPynativeWithFProp(const CNodePtr &cnode, const ValuePtr |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
MS_EXCEPTION_IF_NULL(fprop_fg); |
|
|
|
|
|
|
|
BuildAdjoint(cnode, op_args, out, fprop_fg, PynativeAdjoint::kForwardPropagate); |
|
|
|
(void)BuildAdjoint(cnode, op_args, out, fprop_fg, PynativeAdjoint::kForwardPropagate); |
|
|
|
|
|
|
|
return true; |
|
|
|
} |
|
|
|
@@ -446,7 +447,7 @@ void KPynativeCellImpl::UpdateOutputNodeOfTopCell(const AnfNodePtr &output_node) |
|
|
|
MS_LOG(DEBUG) << "Build adjoint for valuenode: " << v_node->ToString(); |
|
|
|
auto v_node_pynative_adjoint = |
|
|
|
std::make_shared<PynativeAdjoint>(tape_, ValuePtrList{}, v_node->value(), FuncGraphPtr(nullptr)); |
|
|
|
anfnode_to_adjoin_.insert(std::make_pair(output_node, v_node_pynative_adjoint)); |
|
|
|
(void)anfnode_to_adjoin_.insert(std::make_pair(output_node, v_node_pynative_adjoint)); |
|
|
|
return; |
|
|
|
} |
|
|
|
MS_LOG(EXCEPTION) << "BackPropagate adjoint does not exist for input: " << last_node_->DebugString(); |
|
|
|
@@ -461,8 +462,8 @@ ValuePtr ShallowCopyValue(const ValuePtr &value) { |
|
|
|
} else if (value->isa<ValueTuple>()) { |
|
|
|
std::vector<ValuePtr> values; |
|
|
|
auto value_tuple = value->cast<ValueTuplePtr>(); |
|
|
|
std::transform(value_tuple->value().begin(), value_tuple->value().end(), std::back_inserter(values), |
|
|
|
[](const ValuePtr &elem) { return ShallowCopyValue(elem); }); |
|
|
|
(void)std::transform(value_tuple->value().begin(), value_tuple->value().end(), std::back_inserter(values), |
|
|
|
[](const ValuePtr &elem) { return ShallowCopyValue(elem); }); |
|
|
|
return std::make_shared<ValueTuple>(values); |
|
|
|
} else { |
|
|
|
return value; |
|
|
|
@@ -504,7 +505,7 @@ PynativeAdjointPtr KPynativeCellImpl::ForgeGetItemAdjoint(const CNodePtr &cnode) |
|
|
|
if (index_value->value() < 0) { |
|
|
|
MS_LOG(EXCEPTION) << "CNode input 2 should not less than 0, CNode: " << cnode->DebugString(); |
|
|
|
} |
|
|
|
size_t index_value_imm = index_value->value(); |
|
|
|
size_t index_value_imm = LongToSize(index_value->value()); |
|
|
|
if (index_value_imm >= input_1_out->size()) { |
|
|
|
MS_LOG(EXCEPTION) << "CNode input 2 should be index between [0, " << input_1_out->size() |
|
|
|
<< ", but: " << index_value->ToString(); |
|
|
|
@@ -608,7 +609,7 @@ void KPynativeCellImpl::BuildAdjointForInput(const CNodePtr &cnode, const ValueP |
|
|
|
} else { |
|
|
|
auto input_adjoint = |
|
|
|
std::make_shared<PynativeAdjoint>(tape_, ValuePtrList{}, op_args[i - 1], FuncGraphPtr(nullptr)); |
|
|
|
anfnode_to_adjoin_.insert(std::make_pair(input, input_adjoint)); |
|
|
|
(void)anfnode_to_adjoin_.insert(std::make_pair(input, input_adjoint)); |
|
|
|
input_adjoint->users().push_back(cnode); |
|
|
|
} |
|
|
|
} else { |
|
|
|
@@ -623,8 +624,8 @@ bool KPynativeCellImpl::BuildAdjoint(const CNodePtr &cnode, const ValuePtrList & |
|
|
|
// Clone op_args and out, so the address of tensor data can be reset to nullptr if the value of tensor |
|
|
|
// is not used in bprop_fg; |
|
|
|
ValuePtrList cloned_op_args; |
|
|
|
std::transform(op_args.begin(), op_args.end(), std::back_inserter(cloned_op_args), |
|
|
|
[](const ValuePtr &value) { return ShallowCopyValue(value); }); |
|
|
|
(void)std::transform(op_args.begin(), op_args.end(), std::back_inserter(cloned_op_args), |
|
|
|
[](const ValuePtr &value) { return ShallowCopyValue(value); }); |
|
|
|
ValuePtr cloned_out = ShallowCopyValue(out); |
|
|
|
PynativeAdjointPtr cnode_adjoint; |
|
|
|
if (fg_type == PynativeAdjoint::kBackwardPropagate) { |
|
|
|
@@ -636,7 +637,7 @@ bool KPynativeCellImpl::BuildAdjoint(const CNodePtr &cnode, const ValuePtrList & |
|
|
|
|
|
|
|
BuildAdjointForInput(cnode, op_args); |
|
|
|
|
|
|
|
anfnode_to_adjoin_.insert(std::make_pair(cnode, cnode_adjoint)); |
|
|
|
(void)anfnode_to_adjoin_.insert(std::make_pair(cnode, cnode_adjoint)); |
|
|
|
|
|
|
|
return true; |
|
|
|
} |
|
|
|
@@ -718,7 +719,7 @@ const AnfNodePtrList KPynativeCellImpl::BuildKNodeListFromPrimalCNode(const CNod |
|
|
|
MS_EXCEPTION_IF_NULL(adjoint); |
|
|
|
AnfNodePtrList node_list; |
|
|
|
for (size_t i = 1; i < cnode->inputs().size(); ++i) { |
|
|
|
node_list.emplace_back(BuildKNodeForCNodeInput(adjoint, cnode->input(i), i)); |
|
|
|
(void)node_list.emplace_back(BuildKNodeForCNodeInput(adjoint, cnode->input(i), i)); |
|
|
|
} |
|
|
|
return node_list; |
|
|
|
} |
|
|
|
@@ -758,7 +759,7 @@ bool KPynativeCellImpl::BackPropagateOneCNodeWithBPropFuncGraph(const CNodePtr & |
|
|
|
node_list.push_back(adjoint->RealDout()); |
|
|
|
} else { |
|
|
|
const auto &k_node_list = BuildKNodeListFromPrimalCNode(cnode, adjoint); |
|
|
|
node_list.insert(node_list.end(), k_node_list.begin(), k_node_list.end()); |
|
|
|
(void)node_list.insert(node_list.end(), k_node_list.begin(), k_node_list.end()); |
|
|
|
// out; |
|
|
|
node_list.push_back(adjoint->k_node()); |
|
|
|
// dout; |
|
|
|
@@ -767,7 +768,7 @@ bool KPynativeCellImpl::BackPropagateOneCNodeWithBPropFuncGraph(const CNodePtr & |
|
|
|
// Back propagate process |
|
|
|
auto bprop_app = tape_->NewCNode(node_list); |
|
|
|
bprop_app->set_abstract(bprop_output_abs); |
|
|
|
BackPropagate(cnode, bprop_app); |
|
|
|
(void)BackPropagate(cnode, bprop_app); |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -780,12 +781,12 @@ bool KPynativeCellImpl::BackPropagateOneCNodeWithFPropFuncGraph(const CNodePtr & |
|
|
|
CNodePtr bprop_cnode; |
|
|
|
if (by_value) { |
|
|
|
AnfNodePtrList args_node_list; |
|
|
|
std::transform(adjoint->op_args().begin(), adjoint->op_args().end(), std::back_inserter(args_node_list), |
|
|
|
[](const ValuePtr &value) { |
|
|
|
auto v_node = NewValueNode(value); |
|
|
|
v_node->set_abstract(value->ToAbstract()->Broaden()); |
|
|
|
return v_node; |
|
|
|
}); |
|
|
|
(void)std::transform(adjoint->op_args().begin(), adjoint->op_args().end(), std::back_inserter(args_node_list), |
|
|
|
[](const ValuePtr &value) { |
|
|
|
auto v_node = NewValueNode(value); |
|
|
|
v_node->set_abstract(value->ToAbstract()->Broaden()); |
|
|
|
return v_node; |
|
|
|
}); |
|
|
|
|
|
|
|
bprop_cnode = GetBPropFromFProp(fprop_fg, args_node_list); |
|
|
|
} else { |
|
|
|
@@ -797,7 +798,7 @@ bool KPynativeCellImpl::BackPropagateOneCNodeWithFPropFuncGraph(const CNodePtr & |
|
|
|
node_list.push_back(adjoint->RealDout()); |
|
|
|
// Back propagate process |
|
|
|
auto bprop_app = tape_->NewCNode(node_list); |
|
|
|
BackPropagate(cnode, bprop_app); |
|
|
|
(void)BackPropagate(cnode, bprop_app); |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -829,9 +830,9 @@ bool KPynativeCellImpl::BackPropagate(bool by_value) { |
|
|
|
auto fg_type = iter->second->fg_type(); |
|
|
|
|
|
|
|
if (fg_type == PynativeAdjoint::kBackwardPropagate) { |
|
|
|
BackPropagateOneCNodeWithBPropFuncGraph(cnode, iter->second, fg, by_value); |
|
|
|
(void)BackPropagateOneCNodeWithBPropFuncGraph(cnode, iter->second, fg, by_value); |
|
|
|
} else { |
|
|
|
BackPropagateOneCNodeWithFPropFuncGraph(cnode, iter->second, fg, by_value); |
|
|
|
(void)BackPropagateOneCNodeWithFPropFuncGraph(cnode, iter->second, fg, by_value); |
|
|
|
} |
|
|
|
} |
|
|
|
return true; |
|
|
|
@@ -921,6 +922,7 @@ FuncGraphPtr KPynativeCellImpl::BuildMakeSequenceBprop(const PrimitivePtr &prim, |
|
|
|
b->debug_info()->set_name(ss.str()); |
|
|
|
for (size_t i = 0; i < inputs_num; ++i) { |
|
|
|
auto param = b->add_parameter(); |
|
|
|
MS_EXCEPTION_IF_NULL(param); |
|
|
|
} |
|
|
|
// out, dout |
|
|
|
auto p1 = b->add_parameter(); |
|
|
|
@@ -974,7 +976,7 @@ void KPynativeCellImpl::SetSensAndWeights(const AnfNodePtrList &weights, bool ha |
|
|
|
for (const auto &weight : weights) { |
|
|
|
TraceGuard trace_guard(std::make_shared<TraceCopy>(weight->debug_info())); |
|
|
|
auto p = tape_->add_parameter(); |
|
|
|
need_grad_weights_.emplace(weight); |
|
|
|
(void)need_grad_weights_.emplace(weight); |
|
|
|
auto input_w = weight->cast<ParameterPtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(input_w); |
|
|
|
// Use name to match weight parameter in high order |
|
|
|
@@ -1066,7 +1068,7 @@ bool KPynativeCellImpl::BuildKNode() { |
|
|
|
auto cnode = iter->first->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
for (size_t i = 0; i < cnode->inputs().size(); ++i) { |
|
|
|
node_list.emplace_back(BuildKNodeForCNodeInput(iter->second, cnode->input(i), i)); |
|
|
|
(void)node_list.emplace_back(BuildKNodeForCNodeInput(iter->second, cnode->input(i), i)); |
|
|
|
} |
|
|
|
auto k_node = tape_->NewCNode(node_list); |
|
|
|
k_node->set_abstract(iter->second->out()->ToAbstract()->Broaden()); |
|
|
|
@@ -1091,7 +1093,7 @@ CNodePtr KPynativeCellImpl::GetBPropFromFProp(const FuncGraphPtr &fprop_fg, cons |
|
|
|
auto get_bprop = |
|
|
|
bprop_builder->NewCNode({NewValueNode(prim::kPrimTupleGetItem), fprop_app, NewValueNode(static_cast<int64_t>(1))}); |
|
|
|
bprop_builder->set_output(get_bprop); |
|
|
|
bprop_builder_inputs.insert(bprop_builder_inputs.begin(), NewValueNode(bprop_builder)); |
|
|
|
(void)bprop_builder_inputs.insert(bprop_builder_inputs.begin(), NewValueNode(bprop_builder)); |
|
|
|
get_bprop = tape_->NewCNode(bprop_builder_inputs); |
|
|
|
|
|
|
|
return get_bprop; |
|
|
|
@@ -1103,7 +1105,7 @@ void KPynativeCellImpl::ReplacePrimalParameter(const AnfNodePtrList &weights, bo |
|
|
|
const auto ¶meters = tape_->parameters(); |
|
|
|
auto cell_inputs_size = cell_inputs_.size(); |
|
|
|
for (size_t i = 0; i < cell_inputs_size; ++i) { |
|
|
|
tr.Replace(cell_inputs_[i], parameters[i]); |
|
|
|
(void)tr.Replace(cell_inputs_[i], parameters[i]); |
|
|
|
} |
|
|
|
// (Inputs, sens, weights) or (Inputs, weights) |
|
|
|
size_t weight_offset = cell_inputs_size; |
|
|
|
@@ -1111,7 +1113,7 @@ void KPynativeCellImpl::ReplacePrimalParameter(const AnfNodePtrList &weights, bo |
|
|
|
weight_offset = weight_offset + 1; |
|
|
|
} |
|
|
|
for (size_t i = 0; i < weights.size(); ++i) { |
|
|
|
tr.Replace(weights[i], parameters[weight_offset + i]); |
|
|
|
(void)tr.Replace(weights[i], parameters[weight_offset + i]); |
|
|
|
} |
|
|
|
tr.Commit(); |
|
|
|
} |
|
|
|
|