|
|
|
@@ -525,105 +525,88 @@ GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_ |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// NOTE(review): this span reads like a diff hunk whose +/- markers were lost, so two versions of |
// GradOperation::GetGrad are interleaved row by row and share a single closing brace: |
//   * GetGrad(node, weights, params_list, args, applyJ), which builds and returns the graph `ret`; |
//   * GetGrad(k, weights, forward_graph_params, weight_args), which builds and returns the graph `k_child`. |
// Which rows were removed and which were added cannot be told from here - confirm against the real file. |
// |
// Both versions follow the same outline: create a new FuncGraph flagged FUNC_GRAPH_FLAG_CORE, pick the |
// weights node, add one graph parameter per entry of the forward-parameter list, apply the incoming node to |
// those parameters, split that application with TupleGetItem index 0 and index 1, and pass both pieces to |
// doGetGrad / GradByParameter before returning the new graph. |
FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr &weights, |
|
|
|
// NOTE(review): "¶ms_list" is presumably "&params_list" damaged by an HTML-entity decode ("&para" -> "¶"); |
// the loop further down reads params_list.size(). TODO confirm. |
const std::vector<AnfNodePtr> ¶ms_list, const std::vector<AnfNodePtr> &args, |
|
|
|
bool applyJ) { |
|
|
|
FuncGraphPtr ret = std::make_shared<FuncGraph>(); |
|
|
|
ret->set_flag(FUNC_GRAPH_FLAG_CORE, true); |
|
|
|
FuncGraphPtr GradOperation::GetGrad(const AnfNodePtr &k, const AnfNodePtr &weights, |
|
|
|
const std::vector<AnfNodePtr> &forward_graph_params, |
|
|
|
const std::vector<AnfNodePtr> &weight_args) { |
|
|
|
FuncGraphPtr k_child = std::make_shared<FuncGraph>(); |
|
|
|
k_child->set_flag(FUNC_GRAPH_FLAG_CORE, true); |
|
|
|
|
|
|
|
|
// Weights selection (both versions): use `weights` when it is non-null; otherwise, when the explicit |
// argument list (`args` / `weight_args`) is non-empty, build a CNode from that list; otherwise the |
// weights node stays null. |
auto weights_node = weights; |
|
|
|
if (weights == nullptr && !args.empty()) { |
|
|
|
weights_node = ret->NewCNode(args); |
|
|
|
AnfNodePtr weights_node = nullptr; |
|
|
|
if (weights != nullptr) { |
|
|
|
weights_node = weights; |
|
|
|
} else if (!weight_args.empty()) { |
|
|
|
weights_node = k_child->NewCNode(weight_args); |
|
|
|
} |
|
|
|
|
|
|
|
|
ValueNodePtr opsJ = NewValueNode(prim::kPrimJ); |
|
|
|
ValueNodePtr opsTupleItem = NewValueNode(prim::kPrimTupleGetItem); |
|
|
|
|
|
|
|
|
// `inputs` is declared once here but is used by rows of both versions (presumably an unchanged diff row). |
std::vector<AnfNodePtr> inputs; |
|
|
|
// `ret` version only: when applyJ is set, replace `node` with the application of the J primitive to it. |
if (applyJ) { |
|
|
|
inputs.push_back(opsJ); |
|
|
|
inputs.push_back(node); |
|
|
|
node = ret->NewCNode(inputs); |
|
|
|
} |
|
|
|
|
|
|
|
|
// One new graph parameter is added per forward-graph parameter. The `ret` version collects them in |
// `params`; the `k_child` version appends them to `inputs` directly after `k` and then builds `k_app`. |
std::vector<AnfNodePtr> params; |
|
|
|
for (size_t i = 0; i < params_list.size(); ++i) { |
|
|
|
params.push_back(ret->add_parameter()); |
|
|
|
inputs.push_back(k); |
|
|
|
for (size_t i = 0; i < forward_graph_params.size(); ++i) { |
|
|
|
inputs.push_back(k_child->add_parameter()); |
|
|
|
} |
|
|
|
auto k_app = k_child->NewCNode(inputs); |
|
|
|
|
|
|
|
|
// `ret` version: apply `node` to the collected parameters. |
inputs.clear(); |
|
|
|
inputs.push_back(node); |
|
|
|
(void)std::copy(params.begin(), params.end(), std::back_inserter(inputs)); |
|
|
|
AnfNodePtr cnode = ret->NewCNode(inputs); |
|
|
|
|
|
|
|
|
// `ret` version: element 0 of the application becomes `out`. |
inputs.clear(); |
|
|
|
inputs.push_back(opsTupleItem); |
|
|
|
inputs.push_back(cnode); |
|
|
|
inputs.push_back(NewValueNode(static_cast<int64_t>(0))); |
|
|
|
auto out = ret->NewCNode(inputs); |
|
|
|
|
|
|
|
|
// `ret` version: element 1 of the application becomes `ptr_bprop`. |
inputs.clear(); |
|
|
|
inputs.push_back(opsTupleItem); |
|
|
|
inputs.push_back(cnode); |
|
|
|
inputs.push_back(NewValueNode(static_cast<int64_t>(1))); |
|
|
|
AnfNodePtr ptr_bprop = ret->NewCNode(inputs); |
|
|
|
// `k_child` version: the same two TupleGetItem extractions (index 0 -> f_app, index 1 -> bprop), |
// written with brace-initialised input lists. |
auto tuple_get_item = NewValueNode(prim::kPrimTupleGetItem); |
|
|
|
auto f_app = k_child->NewCNode({tuple_get_item, k_app, NewValueNode(static_cast<int64_t>(0))}); |
|
|
|
auto bprop = k_child->NewCNode({tuple_get_item, k_app, NewValueNode(static_cast<int64_t>(1))}); |
|
|
|
|
|
|
|
|
// Hand the extracted nodes to the helper for the matching version, then return the new graph. |
doGetGrad(ret, out, ptr_bprop, weights_node, opsTupleItem); |
|
|
|
return ret; |
|
|
|
GradByParameter(k_child, f_app, bprop, weights_node); |
|
|
|
return k_child; |
|
|
|
} |
|
|
|
|
|
|
|
// NOTE(review): this span reads like a diff hunk whose +/- markers were lost, so two versions of one helper |
// are interleaved row by row and share a single closing brace: |
//   * doGetGrad(func_graph, out, ptr_bprop, weights, opsTupleItem); |
//   * GradByParameter(k_child, f_app, bprop, weights). |
// Statements appear in pairs that differ only in those names (func_graph/out/ptr_bprop/ptr_bapp versus |
// k_child/f_app/bprop/b_app). Which rows were removed and which were added cannot be told from here. |
// |
// In both versions the helper sets the output of the given graph and returns nothing. The output is chosen |
// from the members get_by_list_ and get_all_; the member sens_param_ decides what is passed to the bprop node. |
void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, AnfNodePtr ptr_bprop, AnfNodePtr weights, |
|
|
|
ValueNodePtr opsTupleItem) { |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
|
// Do grad by the parameter of GradOperation. |
|
|
|
void GradOperation::GradByParameter(const FuncGraphPtr &k_child, const AnfNodePtr &f_app, const AnfNodePtr &bprop, |
|
|
|
const AnfNodePtr &weights) { |
|
|
|
MS_EXCEPTION_IF_NULL(k_child); |
|
|
|
|
|
|
|
|
// Argument for the bprop call: with sens_param_ set it is a newly added graph parameter; otherwise it is |
// the "ones_like" Python op applied to the forward result (`out` / `f_app`). |
AnfNodePtr ptr_bprop_arg = nullptr; |
|
|
|
AnfNodePtr bprop_arg = nullptr; |
|
|
|
if (sens_param_) { |
|
|
|
ptr_bprop_arg = func_graph->add_parameter(); |
|
|
|
bprop_arg = k_child->add_parameter(); |
|
|
|
} else { |
|
|
|
auto ones_like = prim::GetPythonOps("ones_like"); |
|
|
|
ptr_bprop_arg = func_graph->NewCNode({NewValueNode(ones_like), out}); |
|
|
|
bprop_arg = k_child->NewCNode({NewValueNode(ones_like), f_app}); |
|
|
|
} |
|
|
|
|
|
|
|
|
// Apply the bprop node to that argument; every branch below reads from this application. |
AnfNodePtr ptr_bapp = func_graph->NewCNode({ptr_bprop, ptr_bprop_arg}); |
|
|
|
AnfNodePtr b_app = k_child->NewCNode({bprop, bprop_arg}); |
|
|
|
|
|
|
|
|
// fv_bprop stays null unless get_by_list_ is set. When set, element 0 of the bprop application is bound |
// to "env_get" through Partial, and a HyperMap applies that partial to `weights`. |
CNodePtr fv_bprop = nullptr; |
|
|
|
if (get_by_list_) { |
|
|
|
// python code: grads = hyper_map(F.partial(env_get, env), weights) |
|
|
|
AnfNodePtr env = |
|
|
|
func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), ptr_bapp, NewValueNode(static_cast<int64_t>(0))}); |
|
|
|
k_child->NewCNode({NewValueNode(prim::kPrimTupleGetItem), b_app, NewValueNode(static_cast<int64_t>(0))}); |
|
|
|
AnfNodePtr partial_env_get = |
|
|
|
func_graph->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(prim::GetPythonOps("env_get")), env}); |
|
|
|
k_child->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(prim::GetPythonOps("env_get")), env}); |
|
|
|
MetaFuncGraphPtr hyper_map = std::make_shared<HyperMap>(); |
|
|
|
fv_bprop = func_graph->NewCNode({NewValueNode(hyper_map), partial_env_get, weights}); |
|
|
|
fv_bprop = k_child->NewCNode({NewValueNode(hyper_map), partial_env_get, weights}); |
|
|
|
} |
|
|
|
|
|
|
|
|
// inputs_bprop stays null unless get_all_ is set. When set, a Tail meta graph (constructed with "tail", true) |
// is applied to the bprop application; presumably this drops the leading env element - verify in Tail. |
CNodePtr inputs_bprop = nullptr; |
|
|
|
if (get_all_) { |
|
|
|
TailPtr tail = std::make_shared<Tail>("tail", true); |
|
|
|
inputs_bprop = func_graph->NewCNode({NewValueNode(tail), ptr_bapp}); |
|
|
|
inputs_bprop = k_child->NewCNode({NewValueNode(tail), b_app}); |
|
|
|
} |
|
|
|
|
|
|
|
|
// Output selection: the first matching case below sets the graph output and returns. |
// Gradients wrt inputs and parameters |
|
|
|
if (fv_bprop != nullptr && inputs_bprop != nullptr) { |
|
|
|
func_graph->set_output(func_graph->NewCNode({NewValueNode(kPrimMakeTuple), inputs_bprop, fv_bprop})); |
|
|
|
k_child->set_output(k_child->NewCNode({NewValueNode(kPrimMakeTuple), inputs_bprop, fv_bprop})); |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
|
// Gradients wrt parameters |
|
|
|
if (fv_bprop != nullptr) { |
|
|
|
func_graph->set_output(fv_bprop); |
|
|
|
k_child->set_output(fv_bprop); |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
|
// Gradients wrt inputs |
|
|
|
if (inputs_bprop != nullptr) { |
|
|
|
func_graph->set_output(inputs_bprop); |
|
|
|
k_child->set_output(inputs_bprop); |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
|
// Gradients wrt first input. |
|
|
|
// ptr_bapp returns (EnvInstance(grads wrt params), grads wrt input0, grads wrt input1, ...), so 1 is for first input |
|
|
|
func_graph->set_output(func_graph->NewCNode({opsTupleItem, ptr_bapp, NewValueNode(static_cast<int64_t>(1))})); |
|
|
|
// b_app returns (EnvInstance(grads wrt params), grads wrt input0, grads wrt input1, ...), so 1 is for first input |
|
|
|
k_child->set_output( |
|
|
|
k_child->NewCNode({NewValueNode(prim::kPrimTupleGetItem), b_app, NewValueNode(static_cast<int64_t>(1))})); |
|
|
|
} |
|
|
|
|
|
|
|
// Generate the graph. |
|
|
|
@@ -643,39 +626,39 @@ FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_sp |
|
|
|
auto real_fn = dyn_cast<FuncGraphAbstractClosure>(fn); |
|
|
|
MS_EXCEPTION_IF_NULL(real_fn); |
|
|
|
|
|
|
|
FuncGraphPtr ptr_graph = real_fn->func_graph(); |
|
|
|
MS_EXCEPTION_IF_NULL(ptr_graph); |
|
|
|
FuncGraphPtr df_builder = nullptr; |
|
|
|
FuncGraphPtr forward_graph = real_fn->func_graph(); |
|
|
|
MS_EXCEPTION_IF_NULL(forward_graph); |
|
|
|
FuncGraphPtr grad_fg = nullptr; |
|
|
|
{ |
|
|
|
TraceGuard g(std::make_shared<TraceGradOperation>(ptr_graph->debug_info())); |
|
|
|
df_builder = std::make_shared<FuncGraph>(); |
|
|
|
TraceGuard g(std::make_shared<TraceGradOperation>(forward_graph->debug_info())); |
|
|
|
grad_fg = std::make_shared<FuncGraph>(); |
|
|
|
} |
|
|
|
auto nparam = ptr_graph->parameters().size(); |
|
|
|
auto nparam = forward_graph->parameters().size(); |
|
|
|
|
|
|
|
std::ostringstream ss; |
|
|
|
ss << "grad{" << nparam << "}"; |
|
|
|
df_builder->set_flag(FUNC_GRAPH_FLAG_CORE, true); |
|
|
|
df_builder->debug_info()->set_name(ss.str()); |
|
|
|
ParameterPtr param_graph = df_builder->add_parameter(); |
|
|
|
grad_fg->set_flag(FUNC_GRAPH_FLAG_CORE, true); |
|
|
|
grad_fg->debug_info()->set_name(ss.str()); |
|
|
|
ParameterPtr param_graph = grad_fg->add_parameter(); |
|
|
|
|
|
|
|
AnfNodePtr weights = nullptr; |
|
|
|
if (get_by_list_) { |
|
|
|
weights = df_builder->add_parameter(); |
|
|
|
weights = grad_fg->add_parameter(); |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<AnfNodePtr> inputs; |
|
|
|
inputs.push_back(NewValueNode(prim::kPrimJ)); |
|
|
|
inputs.push_back(param_graph); |
|
|
|
auto jf = df_builder->NewCNode(inputs); |
|
|
|
auto j = grad_fg->NewCNode(inputs); |
|
|
|
// df is checked in GetGrad |
|
|
|
FuncGraphPtr df = nullptr; |
|
|
|
FuncGraphPtr k_child = nullptr; |
|
|
|
{ |
|
|
|
TraceGuard guard(std::make_shared<TraceGradOperation>(ptr_graph->debug_info())); |
|
|
|
df = GetGrad(jf, weights, ptr_graph->parameters()); |
|
|
|
TraceGuard guard(std::make_shared<TraceGradOperation>(forward_graph->debug_info())); |
|
|
|
k_child = GetGrad(j, weights, forward_graph->parameters()); |
|
|
|
} |
|
|
|
df_builder->set_output(NewValueNode(df)); |
|
|
|
grad_fg->set_output(NewValueNode(k_child)); |
|
|
|
|
|
|
|
return df_builder; |
|
|
|
return grad_fg; |
|
|
|
} |
|
|
|
|
|
|
|
REGISTER_PYBIND_DEFINE(GradOperation_, ([](const py::module *m) { |
|
|
|
|