Browse Source

eliminate_forward_cnode_in_grad_graph_decorated_by_ms_function

tags/v1.5.0-rc1
7347157+joylvliang@user.noreply.gitee.com 4 years ago
parent
commit
0fb07a6377
19 changed files with 981 additions and 194 deletions
  1. +189
    -89
      mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc
  2. +9
    -4
      mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h
  3. +2
    -0
      mindspore/ccsrc/frontend/optimizer/ad/kpynative.cc
  4. +8
    -0
      mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.cc
  5. +3
    -0
      mindspore/ccsrc/frontend/optimizer/irpass/partial_eliminate.h
  6. +94
    -4
      mindspore/ccsrc/pipeline/jit/action.cc
  7. +2
    -0
      mindspore/ccsrc/pipeline/jit/base.h
  8. +32
    -0
      mindspore/ccsrc/pipeline/jit/pipeline.cc
  9. +3
    -0
      mindspore/ccsrc/pipeline/jit/pipeline.h
  10. +204
    -68
      mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
  11. +23
    -14
      mindspore/ccsrc/pipeline/pynative/pynative_execute.h
  12. +3
    -4
      mindspore/common/api.py
  13. +3
    -3
      mindspore/core/ir/anf.h
  14. +7
    -0
      mindspore/core/ir/func_graph.cc
  15. +8
    -0
      mindspore/core/ir/func_graph.h
  16. +2
    -2
      mindspore/core/ir/func_graph_cloner.cc
  17. +1
    -1
      mindspore/core/ir/func_graph_cloner.h
  18. +199
    -0
      tests/st/pynative/ms_function/test_pynative_lenet_ms_function.py
  19. +189
    -5
      tests/st/pynative/ms_function/test_pynative_ms_function.py

+ 189
- 89
mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc View File

@@ -31,6 +31,7 @@
#include "utils/ms_context.h"
#include "pipeline/jit/action.h"
#include "pipeline/jit/parse/resolve.h"
#include "pipeline/pynative/pynative_execute.h"
#include "debug/anf_ir_dump.h"

namespace mindspore {
@@ -276,13 +277,13 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
TraceGuard guard(std::make_shared<TraceGradFpropApp>(cnode_morph->debug_info()));
k_app = k_graph_->NewCNode(inputs);
}
// Run in pynative mode, when @ms_function is used.
ReplaceEquivdout(k_app, cnode_morph);
cnode_morph->clear_inputs_value();
cnode_morph->set_forward(nullptr, "");
for (size_t i = 0; i < param_adjoints.size(); ++i) {
param_adjoints[i]->RegisterKUser(k_app, i);
}

// Do forward computation
auto foward_app =
k_graph_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), k_app, NewValueNode(static_cast<int64_t>(0))});
@@ -301,116 +302,215 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
return node_adjoint;
}

ValuePtr DFunctor::GenNewTensorInner(const ValuePtr &value) {
std::vector<ValuePtr> value_list;
if (value->isa<tensor::Tensor>()) {
auto tensor = value->cast<tensor::TensorPtr>();
auto new_tensor = std::make_shared<tensor::Tensor>(*tensor);
new_tensor->set_device_address(nullptr);
return new_tensor;
}
if (value->isa<ValueTuple>()) {
auto tuple = value->cast<ValueTuplePtr>();
for (size_t i = 0; i < tuple->size(); i++) {
value_list.push_back(GenNewTensorInner((*tuple)[i]));
}
return std::make_shared<ValueTuple>(value_list);
}
return value;
tensor::TensorPtr DFunctor::GenNewTensorInner(const TypePtr &type_elem, const BaseShapePtr &shape_elem) {
MS_EXCEPTION_IF_NULL(type_elem);
MS_EXCEPTION_IF_NULL(shape_elem);
auto shape = shape_elem->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(shape);
auto tensor_type = type_elem->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type);
auto type = tensor_type->element();
MS_EXCEPTION_IF_NULL(type);
return std::make_shared<tensor::Tensor>(type->type_id(), shape->shape());
}

ValuePtr DFunctor::GenNewTensor(const FuncGraphManagerPtr &mng, const AnfNodePtr &node, const ValuePtr &value,
bool need_replace_forward) {
ValuePtr out = value;
auto ref_size = mng->node_users()[node].size();
if (ref_size < 2) {
if (need_replace_forward) {
out = GenNewTensorInner(value);
} else {
auto tensor = value->cast<tensor::TensorPtr>();
tensor->set_device_address(nullptr);
return tensor;
ValueNodePtr DFunctor::GenNewTensor(const CNodePtr &cnode_morph) {
MS_EXCEPTION_IF_NULL(cnode_morph);
if (cnode_morph->forward().first != nullptr) {
return cnode_morph->forward().first;
}
if (IsPrimitiveCNode(cnode_morph, prim::kPrimUpdateState)) {
ValueNodePtr out_vnode = NewValueNode(std::make_shared<UMonad>());
out_vnode->set_abstract(std::make_shared<abstract::AbstractUMonad>());
return out_vnode;
}

auto cnode_shape = cnode_morph->Shape();
MS_EXCEPTION_IF_NULL(cnode_shape);
auto cnode_type = cnode_morph->Type();
MS_EXCEPTION_IF_NULL(cnode_type);
// Create output values.
if (cnode_type->isa<Tuple>()) {
auto tuple_shape = cnode_shape->cast<abstract::TupleShapePtr>();
MS_EXCEPTION_IF_NULL(tuple_shape);
auto tuple_type = cnode_type->cast<TuplePtr>();
MS_EXCEPTION_IF_NULL(tuple_type);
size_t output_num = tuple_type->elements().size();
std::vector<ValuePtr> output_values;
for (size_t i = 0; i < output_num; ++i) {
auto shape_elem = tuple_shape->shape()[i];
auto type_elem = tuple_type->elements()[i];
output_values.push_back(GenNewTensorInner(type_elem, shape_elem));
}
if (output_values.empty()) {
MS_LOG(EXCEPTION) << "The output values is empty, cnode morph: " << cnode_morph->DebugString();
}
return NewValueNode(std::make_shared<ValueTuple>(output_values));
} else if (cnode_type->isa<TensorType>()) {
return NewValueNode(GenNewTensorInner(cnode_type, cnode_shape));
} else if (cnode_shape->isa<abstract::NoShape>()) {
ShapeVector NoShape;
return NewValueNode(std::make_shared<tensor::Tensor>(cnode_type->type_id(), NoShape));
}
return out;

MS_LOG(EXCEPTION) << "Unknown shape: " << cnode_shape->ToString() << ", type: " << cnode_type->ToString();
}

void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_morph) {
auto forward = cnode_morph->forward().first;
if (forward == nullptr) {
void DFunctor::GetForwardOutNodeAndBpropGraph(const CNodePtr &k_app, CNodePtr *forward_node, FuncGraphPtr *bprop_graph,
FuncGraphPtr *fprop_graph) {
MS_EXCEPTION_IF_NULL(k_app);
MS_EXCEPTION_IF_NULL(fprop_graph);
const auto &prim = k_app->input(0);
if (!IsValueNode<FuncGraph>(prim)) {
return;
}
auto &input = cnode->input(0);
if (!IsValueNode<FuncGraph>(input)) {
return;
}
auto fg = GetValueNode<FuncGraphPtr>(input);
// Clone a new fprop graph for different k_app.
auto original_fprop = GetValueNode<FuncGraphPtr>(prim);
MS_EXCEPTION_IF_NULL(original_fprop);
*fprop_graph = BasicClone(original_fprop);
k_app->set_input(0, NewValueNode(*fprop_graph));

// {prim::maketuple, forward_output, bprop_graph}
auto output = fg->output();
auto output = (*fprop_graph)->output();
MS_EXCEPTION_IF_NULL(output);
if (!output->isa<CNode>()) {
return;
}
auto cnode_output = output->cast<CNodePtr>();
auto make_tuple_node = output->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(make_tuple_node);
constexpr size_t input_size = 3;
if (cnode_output->size() < input_size) {
MS_LOG(EXCEPTION) << "The inputs size of node " << cnode_output->DebugString() << " is less than " << input_size;
}
constexpr size_t forward_output_index = 1;
auto &cnode_input = cnode_output->input(forward_output_index);
if (!cnode_input->isa<CNode>()) {
return;
if (make_tuple_node->size() != input_size) {
MS_LOG(EXCEPTION) << "The inputs size of make tuple node " << make_tuple_node->DebugString() << " is not equal to "
<< input_size;
}
constexpr size_t bprop_graph_index = 2;
auto &input_fg = cnode_output->input(bprop_graph_index);
if (!IsValueNode<FuncGraph>(input_fg)) {

// Get forward CNode.
const size_t forward_output_index = 1;
const auto &output_node = make_tuple_node->input(forward_output_index);
MS_EXCEPTION_IF_NULL(output_node);
if (!output_node->isa<CNode>()) {
return;
}
// replace forward output with value node
auto equivdout = cnode_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(equivdout);
auto func_graph = GetValueNode<FuncGraphPtr>(input_fg);
MS_EXCEPTION_IF_NULL(func_graph);
auto manager = Manage({fg, func_graph}, false);
auto forward_value = GenNewTensor(manager, equivdout, forward, true);
MS_LOG(DEBUG) << "Replace: " << equivdout->ToString() << " with " << forward;
auto value_node = NewValueNode(forward_value);
value_node->set_has_new_value(true);
manager->Replace(equivdout, value_node);
// replace input object with value node
auto paras = fg->parameters();
auto inputs_value = cnode_morph->inputs_value();
if (inputs_value.empty()) {

// Get bprop graph of forward CNode.
const size_t bprop_graph_index = 2;
const auto &bprop_vnode = make_tuple_node->input(bprop_graph_index);
if (!IsValueNode<FuncGraph>(bprop_vnode)) {
return;
}
if (inputs_value.size() > paras.size()) {
MS_LOG(EXCEPTION) << "Parameter size:" << paras.size() << " but inputs size:" << inputs_value.size();
}
for (size_t i = 0; i < inputs_value.size(); i++) {

MS_EXCEPTION_IF_NULL(forward_node);
MS_EXCEPTION_IF_NULL(bprop_graph);
*forward_node = output_node->cast<CNodePtr>();
*bprop_graph = GetValueNode<FuncGraphPtr>(bprop_vnode);
}

std::vector<AnfNodePtr> DFunctor::RunOutputReplace(const CNodePtr &forward_node, const FuncGraphPtr &bprop_graph,
const FuncGraphPtr &fprop_graph, const CNodePtr &cnode_morph) {
MS_EXCEPTION_IF_NULL(cnode_morph);
if (IsPrimitiveCNode(cnode_morph, prim::kPrimStopGradient)) {
return {};
}
// Use manager to get the link relation among nodes.
MS_EXCEPTION_IF_NULL(bprop_graph);
MS_EXCEPTION_IF_NULL(fprop_graph);
auto manager = Manage({fprop_graph, bprop_graph}, false);

// Replace output node.
MS_EXCEPTION_IF_NULL(forward_node);
auto ref_size = manager->node_users().at(forward_node).size();
MS_LOG(DEBUG) << "Ref size: " << ref_size;
auto output_vnode = GenNewTensor(cnode_morph);
MS_EXCEPTION_IF_NULL(output_vnode);
output_vnode->set_has_new_value(true);
manager->Replace(forward_node, output_vnode);
MS_LOG(DEBUG) << "Replace: " << forward_node->DebugString() << " with " << output_vnode->ToString();

// Save forward output node when it used in its bprop graph.
std::vector<AnfNodePtr> used_forward_nodes;
if (ref_size >= 2) {
cnode_morph->set_forward(output_vnode, "");
used_forward_nodes.push_back(cnode_morph);
MS_LOG(DEBUG) << "node has been used in grad graph: " << cnode_morph->DebugString()
<< ", its output value: " << output_vnode->ToString();
}
return used_forward_nodes;
}

std::vector<AnfNodePtr> DFunctor::RunInputReplace(const FuncGraphPtr &bprop_graph, const FuncGraphPtr &fprop_graph,
const CNodePtr &cnode_morph) {
// Use manager to get the link relation among nodes.
MS_EXCEPTION_IF_NULL(bprop_graph);
MS_EXCEPTION_IF_NULL(fprop_graph);
auto manager = Manage({fprop_graph, bprop_graph}, false);

MS_EXCEPTION_IF_NULL(cnode_morph);
const auto &paras = fprop_graph->parameters();
if (cnode_morph->size() - 1 != paras.size() && !IsPrimitiveCNode(cnode_morph, prim::kPrimUpdateState)) {
MS_LOG(EXCEPTION) << "The size of parameters in fprop graph:" << paras.size()
<< ", but the size of input tensors of forward node: " << cnode_morph->inputs().size() - 1;
}

std::vector<AnfNodePtr> used_input_nodes;
for (size_t i = 0; i < paras.size(); ++i) {
const auto &input_node = cnode_morph->input(i + 1);
MS_EXCEPTION_IF_NULL(input_node);
// Parameter, ValueNode and StopGradient CNode no need to replace.
if (input_node->isa<Parameter>() || input_node->isa<ValueNode>() ||
IsPrimitiveCNode(input_node, prim::kPrimStopGradient)) {
continue;
}
// Replace forward input node by its output value.
auto para_ref_size = manager->node_users()[paras[i]].size();
auto input_value = inputs_value[i];
if (para_ref_size > 0 && input_value.first != nullptr) {
MS_LOG(DEBUG) << "Replace: " << paras[i]->ToString() << " with " << input_value.first;
auto input_value_node = NewValueNode(input_value.first);
input_value_node->set_has_new_value(true);
input_value_node->set_used_graph_count(para_ref_size);
manager->Replace(paras[i], input_value_node);
CNodePtr cnode_i = input_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode_i);
auto output_vnode_i = GenNewTensor(cnode_i);
MS_EXCEPTION_IF_NULL(output_vnode_i);
output_vnode_i->set_has_new_value(true);
manager->Replace(paras[i], output_vnode_i);
MS_LOG(DEBUG) << "Replace: " << paras[i]->DebugString() << " with " << output_vnode_i->ToString();
// Save forward input node when it used in bprop graph.
if (para_ref_size > 0 && !IsPrimitiveCNode(input_node, prim::kPrimUpdateState)) {
cnode_i->set_forward(output_vnode_i, "");
used_input_nodes.push_back(cnode_i);
MS_LOG(DEBUG) << "Input CNode has been used in grad graph: " << cnode_i->DebugString()
<< ", its output value: " << output_vnode_i->ToString();
}
}
MS_LOG(DEBUG) << "Start opt node" << fg->output()->DebugString(4);
auto res = std::make_shared<pipeline::Resource>();
res->set_manager(manager);
res->set_func_graph(fg);
PynativeElimOpt(res);
auto out = fg->output()->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(out);
auto c_input = out->input(1);
MS_EXCEPTION_IF_NULL(c_input);
if (!c_input->isa<ValueNode>()) {
return used_input_nodes;
}
void DFunctor::ReplaceEquivdout(const CNodePtr &k_app, const CNodePtr &cnode_morph) {
// The process of replacing forward node only works in pynative mode, when @ms_function is used.
auto pynative_exec = pynative::PynativeExecutor::GetInstance();
auto grad_exec = pynative_exec->grad_executor();
if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode ||
!grad_exec->eliminate_forward()) {
return;
}
auto out_node = c_input->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(out_node);
out_node->set_value(GenNewTensor(manager, out_node, out_node->value(), true));

MS_EXCEPTION_IF_NULL(cnode_morph);
MS_LOG(DEBUG) << "Run replace for cnode morph: " << cnode_morph->DebugString(2);
// Get forward node and its fprop graph, bprop graph.
MS_EXCEPTION_IF_NULL(k_app);
CNodePtr forward_node = nullptr;
FuncGraphPtr bprop_graph = nullptr;
FuncGraphPtr fprop_graph = nullptr;
GetForwardOutNodeAndBpropGraph(k_app, &forward_node, &bprop_graph, &fprop_graph);
if (forward_node == nullptr || bprop_graph == nullptr || fprop_graph == nullptr) {
return;
}

// Replace forward node used in bprop graph by its output tensors. The same process for its input node.
auto used_forward_nodes = RunOutputReplace(forward_node, bprop_graph, fprop_graph, cnode_morph);
auto used_input_nodes = RunInputReplace(bprop_graph, fprop_graph, cnode_morph);

// Save used forward input and output nodes to func_graph.
auto ms_func_graph = cnode_morph->func_graph();
MS_EXCEPTION_IF_NULL(ms_func_graph);
ms_func_graph->set_used_forward_nodes(used_forward_nodes);
ms_func_graph->set_used_forward_nodes(used_input_nodes);
}

bool DFunctor::IsFreeMorphism(const AnfNodePtr &node) {


+ 9
- 4
mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h View File

@@ -98,10 +98,15 @@ class DFunctor : public std::enable_shared_from_this<DFunctor> {
// Replace the primal graph with k graph
void EliminatePrimalGraph();
// Pynative specialize
void ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_morph);
ValuePtr GenNewTensorInner(const ValuePtr &value);
ValuePtr GenNewTensor(const FuncGraphManagerPtr &mng, const AnfNodePtr &node, const ValuePtr &value,
bool need_replace_forward);
ValueNodePtr GenNewTensor(const CNodePtr &forward_node);
tensor::TensorPtr GenNewTensorInner(const TypePtr &type_elem, const BaseShapePtr &shape_elem);
void GetForwardOutNodeAndBpropGraph(const CNodePtr &k_app, CNodePtr *forward_node, FuncGraphPtr *bprop_graph,
FuncGraphPtr *fprop_graph);
std::vector<AnfNodePtr> RunOutputReplace(const CNodePtr &forward_node, const FuncGraphPtr &bprop_graph,
const FuncGraphPtr &fprop_graph, const CNodePtr &cnode_morph);
std::vector<AnfNodePtr> RunInputReplace(const FuncGraphPtr &bprop_graph, const FuncGraphPtr &fprop_graph,
const CNodePtr &cnode_morph);
void ReplaceEquivdout(const CNodePtr &k_app, const CNodePtr &cnode_morph);

std::unordered_map<AnfNodePtr, AdjointPtr> anfnode_to_adjoin_;
// Cache for indirect fv backpropagation, K o K can only do backprop layer by layer.


+ 2
- 0
mindspore/ccsrc/frontend/optimizer/ad/kpynative.cc View File

@@ -462,6 +462,7 @@ void KPynativeCellImpl::UpdateOutputNodeOfTopCell(const AnfNodePtr &output_node)

namespace {
ValuePtr ShallowCopyValue(const ValuePtr &value) {
MS_EXCEPTION_IF_NULL(value);
if (value->isa<mindspore::tensor::Tensor>()) {
auto tensor_value = value->cast<mindspore::tensor::TensorPtr>();
return std::make_shared<mindspore::tensor::Tensor>(*tensor_value);
@@ -613,6 +614,7 @@ void KPynativeCellImpl::BuildAdjointForInput(const CNodePtr &cnode, const ValueP
}
forged_adjoint->users().push_back(cnode);
} else {
MS_EXCEPTION_IF_NULL(op_args[i - 1]);
auto input_adjoint =
std::make_shared<PynativeAdjoint>(tape_, ValuePtrList{}, op_args[i - 1], FuncGraphPtr(nullptr));
(void)anfnode_to_adjoin_.insert(std::make_pair(input, input_adjoint));


+ 8
- 0
mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.cc View File

@@ -15,6 +15,7 @@
*/

#include "frontend/optimizer/irpass/gradient_eliminate.h"
#include "pipeline/pynative/pynative_execute.h"

namespace mindspore {
namespace opt {
@@ -74,6 +75,13 @@ bool ExpandJPrim::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr
// ExpandJ innermost graph or primitive first.
std::copy_if(j_nodes_.begin(), j_nodes_.end(), std::back_inserter(todo),
[](const CNodePtr &j_node) { return !internal::CheckIfEmbedJ(j_node); });
// Check whether need to eliminate forward cnodes in pynative mode.
if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
auto pynative_exec = pynative::PynativeExecutor::GetInstance();
auto grad_exec = pynative_exec->grad_executor();
bool eliminate_forward = grad_exec->eliminate_forward();
grad_exec->set_eliminate_forward(eliminate_forward && todo.empty());
}
// Expand j nodes that don't have embed j nodes.
bool change = false;
auto manager = optimizer->manager();


+ 3
- 0
mindspore/ccsrc/frontend/optimizer/irpass/partial_eliminate.h View File

@@ -80,6 +80,7 @@ class PartialEliminater : public AnfVisitor {
(void)std::copy(Xs_.begin(), Xs_.end(), std::back_inserter(args));
TraceGuard guard(std::make_shared<TracePartialTransform>(node->debug_info()));
auto new_node = node->func_graph()->NewCNode(args);
new_node->set_abstract(node->abstract());

// reorder the formal parameter of fg.
AnfNodePtrList new_params;
@@ -357,6 +358,7 @@ class SwitchPartialEliminater : public ChoicePartialEliminater {
}
TraceGuard guard2(std::make_shared<TraceCopy>(old_cnode->debug_info()));
auto new_node = old_cnode->func_graph()->NewCNode(args);
new_node->set_abstract(old_cnode->abstract());
return new_node;
}
};
@@ -445,6 +447,7 @@ class SwitchLayerPartialEliminater : public ChoicePartialEliminater {
}
TraceGuard guard3(std::make_shared<TraceCopy>(old_cnode->debug_info()));
auto new_node = old_cnode->func_graph()->NewCNode(args);
new_node->set_abstract(old_cnode->abstract());
return new_node;
}
};


+ 94
- 4
mindspore/ccsrc/pipeline/jit/action.cc View File

@@ -26,6 +26,8 @@
#include "ir/func_graph_cloner.h"
#include "ir/param_info.h"
#include "ir/cell.h"
#include "parse/python_adapter.h"
#include "abstract/abstract_value.h"
#include "frontend/parallel/costmodel_context.h"
#include "frontend/parallel/context.h"
#include "pipeline/jit/pass.h"
@@ -34,17 +36,17 @@
#include "pipeline/jit/static_analysis/auto_monad.h"
#include "pipeline/jit/static_analysis/order_enforce.h"
#include "pipeline/jit/static_analysis/remove_monad.h"
#include "abstract/abstract_value.h"
#include "pipeline/jit/static_analysis/static_analysis.h"
#include "pipeline/jit/static_analysis/async_eval_result.h"
#include "pipeline/jit/static_analysis/program_specialize.h"
#include "pipeline/jit/resource.h"
#include "utils/ms_context.h"
#include "pipeline/jit/remove_value_node_dup.h"
#include "pipeline/pynative/pynative_execute.h"
#include "frontend/optimizer/optimizer.h"
#include "vm/transform.h"
#include "parse/python_adapter.h"
#include "frontend/optimizer/ad/grad.h"
#include "frontend/optimizer/py_pass_manager.h"
#include "utils/ms_context.h"
#include "vm/transform.h"
#if ((defined ENABLE_CPU) && (!defined _WIN32))
#include "ps/parameter_server.h"
#include "ps/scheduler.h"
@@ -117,6 +119,50 @@ void ExecuteActionForMindRT(const ResourcePtr &res) {
});
res->results()[kOutput] = run;
}

// Modify the output node of func_graph to add forward nodes used in bprop graph.
void ModifyOutputNode(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
const auto &used_forward_nodes = func_graph->used_forward_nodes();

// Get original output node and abstract
auto original_output_node = func_graph->output();
MS_EXCEPTION_IF_NULL(original_output_node);
auto original_output_abs = original_output_node->abstract();
MS_EXCEPTION_IF_NULL(original_output_abs);

// Create a new make tuple node to hold all forward used nodes.
abstract::AbstractBasePtrList added_abs_list;
std::vector<AnfNodePtr> added_node_list{NewValueNode(prim::kPrimMakeTuple)};
std::for_each(used_forward_nodes.begin(), used_forward_nodes.end(),
[&added_abs_list, &added_node_list](const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
added_node_list.push_back(node);
added_abs_list.push_back(node->abstract());
});
AnfNodePtr added_output_node = nullptr;
AbstractBasePtr added_output_abs = nullptr;
if (added_abs_list.empty()) {
added_output_node = NewValueNode(MakeValue<int32_t>(1));
added_output_abs = std::make_shared<abstract::AbstractScalar>(std::make_shared<Int32Imm>(1));
} else {
added_output_node = func_graph->NewCNode(added_node_list);
added_output_abs = std::make_shared<abstract::AbstractTuple>(added_abs_list);
}
added_output_node->set_abstract(added_output_abs);
MS_LOG(DEBUG) << "Added output node info: " << added_output_node->DebugString();

// Merge original output node and used forward nodes to return node.
std::vector<AnfNodePtr> new_output_nodes{NewValueNode(prim::kPrimMakeTuple), original_output_node, added_output_node};
auto merge_node = func_graph->NewCNode(new_output_nodes);
abstract::AbstractBasePtrList new_output_abs{original_output_abs, added_output_abs};
merge_node->set_abstract(std::make_shared<abstract::AbstractTuple>(new_output_abs));
MS_LOG(DEBUG) << "Merge node info: " << merge_node->DebugString(2);
func_graph->set_output(merge_node);

// Clear
func_graph->ClearUsedForwardNodes();
}
} // namespace
using CompileGraphs = compile::CompileGraphs;
using abstract::AnalysisResult;
@@ -590,6 +636,47 @@ bool CheckGraphOutputConstOrParameter(const FuncGraphPtr &func_graph) {
return false;
}

bool EliminateForwardCNode(const ResourcePtr &res) {
// This function only works in Pynative mode. The func_graph is decorated by ms_function.
if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
return true;
}

auto graph_executor = pipeline::ExecutorPy::GetInstance();
MS_EXCEPTION_IF_NULL(graph_executor);
auto phase = graph_executor->phase();
MS_LOG(DEBUG) << "The phase of current pipeline graph is: " << phase;
// Export graph run in pynative mode no need to do this action.
if (phase.find("export") != std::string::npos) {
auto pynative_exec = pynative::PynativeExecutor::GetInstance();
auto grad_exec = pynative_exec->grad_executor();
grad_exec->set_eliminate_forward(true);
return true;
}

// Run grad process for func_graph and replace forward nodes with its output tensors.
MS_LOG(DEBUG) << "Run eliminate forward nodes action.";
MS_EXCEPTION_IF_NULL(res);
auto ms_func_graph = res->func_graph();
MS_EXCEPTION_IF_NULL(ms_func_graph);
auto pynative_exec = pynative::PynativeExecutor::GetInstance();
auto grad_exec = pynative_exec->grad_executor();
bool eliminate_forward = grad_exec->eliminate_forward();
grad_exec->set_eliminate_forward(eliminate_forward && ms_func_graph->func_graphs_used().empty());
auto grad_graph = ad::Grad(ms_func_graph, res);
MS_EXCEPTION_IF_NULL(grad_graph);
graph_executor->SetGradGraph(grad_graph, phase);
ModifyOutputNode(ms_func_graph);

// Keep roots for only keeping forward func graph in resource.
auto manager = res->manager();
MS_EXCEPTION_IF_NULL(manager);
manager->KeepRoots({ms_func_graph});

grad_exec->set_eliminate_forward(true);
return true;
}

bool TaskEmitAction(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res);
if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode &&
@@ -995,6 +1082,9 @@ std::vector<ActionItem> VmPipeline() {

(void)actions.emplace_back(std::make_pair("remove_monad_from_random_op", RemoveRandomOpMonadAction));

// eliminate forward cnode for grad graph
(void)actions.emplace_back(std::make_pair("eliminate_forward_cnode", EliminateForwardCNode));

(void)actions.emplace_back(std::make_pair("validate", ValidateAction));
#if ((defined ENABLE_CPU) && (!defined _WIN32))
if (ps::PSContext::instance()->is_worker()) {


+ 2
- 0
mindspore/ccsrc/pipeline/jit/base.h View File

@@ -29,6 +29,8 @@ namespace mindspore {
namespace pipeline {
struct ExecutorInfo {
FuncGraphPtr func_graph;
// The grad graph of func_graph, it will create in PyNative mode when @ms_function is used.
FuncGraphPtr grad_graph;
ResourcePtr resource;
// The num of input data.
std::size_t arg_list_size;


+ 32
- 0
mindspore/ccsrc/pipeline/jit/pipeline.cc View File

@@ -317,6 +317,38 @@ FuncGraphPtr ExecutorPy::GetFuncGraph(const std::string &phase) {
return info_[phase]->func_graph;
}

FuncGraphPtr ExecutorPy::GetGradGraph(const std::string &phase) {
if (phase.empty()) {
MS_LOG(EXCEPTION) << "The input phase is empty.";
}
if (info_.count(phase) == 0) {
MS_LOG(EXCEPTION) << "No phase in executor:" << phase;
}

auto execute_info = info_[phase];
MS_EXCEPTION_IF_NULL(execute_info);
auto grad_graph = execute_info->grad_graph;
MS_EXCEPTION_IF_NULL(grad_graph);
return grad_graph;
}

void ExecutorPy::SetGradGraph(const FuncGraphPtr &grad_graph, const std::string &phase) {
if (phase.empty()) {
MS_LOG(EXCEPTION) << "The input phase is empty.";
}
if (info_.count(phase) == 0) {
MS_LOG(EXCEPTION) << "No phase in executor: " << phase;
}

auto execute_info = info_[phase];
MS_EXCEPTION_IF_NULL(execute_info);
if (execute_info->grad_graph != nullptr) {
MS_LOG(DEBUG) << "The grad graph has existed, phase is: " << phase;
}
MS_EXCEPTION_IF_NULL(grad_graph);
execute_info->grad_graph = grad_graph;
}

compile::VmEvalFuncPtr ExecutorPy::GetVmEvalFunc(const std::string &phase) {
ResourcePtr res = GetResource(phase);
MS_EXCEPTION_IF_NULL(res);


+ 3
- 0
mindspore/ccsrc/pipeline/jit/pipeline.h View File

@@ -71,6 +71,7 @@ class ExecutorPy : public std::enable_shared_from_this<ExecutorPy> {

~ExecutorPy();

const std::string &phase() const { return phase_; }
void SaveCompiledGraph(const std::string &phase_s);
bool CompileInner(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm,
const std::string &queue_name);
@@ -83,6 +84,8 @@ class ExecutorPy : public std::enable_shared_from_this<ExecutorPy> {
py::object Run(const py::tuple &args, const py::object &phase);
ResourcePtr GetResource(const std::string &phase);
FuncGraphPtr GetFuncGraph(const std::string &phase);
FuncGraphPtr GetGradGraph(const std::string &phase);
void SetGradGraph(const FuncGraphPtr &grad_graph, const std::string &phase);
py::bytes GetFuncGraphProto(const std::string &phase, const std::string &type);
compile::VmEvalFuncPtr GetVmEvalFunc(const std::string &phase);
bool HasCompiled(const std::string &phase) const;


+ 204
- 68
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc View File

@@ -619,6 +619,76 @@ void ResetTopCellInfo(const TopCellInfoPtr &top_cell, const py::args &args) {
top_cell->set_input_args_id(input_args_id);
}

void CreateNewTensorsInGradGraph(const TopCellInfoPtr &top_cell, const OpExecInfoPtr &op_exec_info,
const py::object &added_out, const FuncGraphPtr &ms_func_graph,
const FuncGraphPtr &grad_graph) {
MS_EXCEPTION_IF_NULL(top_cell);
MS_EXCEPTION_IF_NULL(grad_graph);
MS_EXCEPTION_IF_NULL(op_exec_info);
MS_EXCEPTION_IF_NULL(ms_func_graph);
// Get Added forward nodes.
auto merge_node = ms_func_graph->output();
MS_EXCEPTION_IF_NULL(merge_node);
auto merge_make_tuple = merge_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(merge_make_tuple);
if (merge_make_tuple->size() != 3) {
MS_LOG(EXCEPTION) << "The input size of merge make tuple node should be 3, but it is: " << merge_make_tuple->size();
}
const auto &added_forward_node = merge_make_tuple->input(2);
MS_EXCEPTION_IF_NULL(added_forward_node);
if (added_forward_node->isa<ValueNode>()) {
MS_LOG(DEBUG) << "The added forward make tuple node info: " << added_forward_node->DebugString();
std::vector<tensor::TensorPtr> total_output_tensors;
TensorValueToTensor(PyAttrValue(added_out), &total_output_tensors);
top_cell->set_op_info_with_ms_func_forward_tensors(op_exec_info->op_info, total_output_tensors);
return;
}
auto added_make_tuple = added_forward_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(added_make_tuple);
MS_LOG(DEBUG) << "The added forward make tuple node info: " << added_make_tuple->DebugString();

// Get Added forward output tensors when forward func graph has been ran.
std::vector<tensor::TensorPtr> total_output_tensors;
TensorValueToTensor(PyAttrValue(added_out), &total_output_tensors);
// Create new output tensors for forward nodes, it will also work in grad graph with same value node.
size_t index = 0;
for (size_t i = 1; i < added_make_tuple->size(); ++i) {
auto cnode = added_make_tuple->input(i)->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
MS_LOG(DEBUG) << "Create New output tensor for cnode: " << cnode->DebugString();
auto output_vnode = cnode->forward().first;
MS_EXCEPTION_IF_NULL(output_vnode);
grad_graph->AddValueNode(output_vnode);
MS_LOG(DEBUG) << "Original output value node: " << output_vnode << " info: " << output_vnode->ToString();
// Create new tensor.
size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
if (index + output_num > total_output_tensors.size()) {
MS_LOG(EXCEPTION) << "The size of total_output_tensors: " << total_output_tensors.size()
<< ", but the current index: " << index << ", output num: " << output_num;
}
std::vector<ValuePtr> new_values;
std::for_each(total_output_tensors.begin() + index, total_output_tensors.begin() + index + output_num,
[&new_values](const auto &tensor) { new_values.push_back(tensor); });
index = index + output_num;
// Create new value.
if (output_num == 1) {
output_vnode->set_value(new_values[0]);
} else if (output_num > 1) {
output_vnode->set_value(std::make_shared<ValueTuple>(new_values));
} else {
MS_LOG(EXCEPTION) << "The output value of forward cnode is empty, forward cnode info: " << cnode->ToString();
}
MS_LOG(DEBUG) << "New output value node: " << output_vnode << " info: " << output_vnode->ToString();
}

// Save op info with new tensors for current running ms_function func graph.
if (index != total_output_tensors.size()) {
MS_LOG(EXCEPTION) << "The index: " << index
<< " should be equal to the size of total_output_tensors: " << total_output_tensors.size();
}
top_cell->set_op_info_with_ms_func_forward_tensors(op_exec_info->op_info, total_output_tensors);
}

void SaveOpInfo(const TopCellInfoPtr &top_cell, const std::string &op_info,
const std::vector<tensor::TensorPtr> &op_out_tensors) {
MS_EXCEPTION_IF_NULL(top_cell);
@@ -636,6 +706,10 @@ void SaveOpInfo(const TopCellInfoPtr &top_cell, const std::string &op_info,

void UpdateTensorInfo(const tensor::TensorPtr &new_tensor, const std::vector<tensor::TensorPtr> &pre_tensors) {
MS_EXCEPTION_IF_NULL(new_tensor);
if (pre_tensors.empty()) {
MS_LOG(EXCEPTION) << "The size of pre tensors is empty.";
}

auto device_target = MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET);
for (auto &pre_tensor : pre_tensors) {
MS_EXCEPTION_IF_NULL(pre_tensor);
@@ -761,6 +835,7 @@ void TopCellInfo::ClearDeviceMemory() {
}
for (const auto &elem : tensors_in_bprop_graph) {
MS_EXCEPTION_IF_NULL(elem);
MS_LOG(DEBUG) << "Clear device address for tensor: " << elem->ToString();
elem->set_device_address(nullptr);
}
}
@@ -785,9 +860,9 @@ void TopCellInfo::Clear() {
k_pynative_cell_ptr_ = nullptr;
graph_info_map_.clear();
sub_cell_list_.clear();
ms_function_grad_cache_.clear();
op_info_with_tensor_id_.clear();
tensor_id_with_tensor_object_.clear();
op_info_with_ms_func_forward_tensors_.clear();
}

void ForwardExecutor::RunOpInner(py::object *ret, const OpExecInfoPtr &op_exec_info) {
@@ -1292,7 +1367,7 @@ AnfNodePtr GradExecutor::GetObjNode(const py::object &obj, const std::string &ob
auto abs = node->abstract();
ValuePtr out_obj = nullptr;
if (node->forward().first != nullptr) {
out_obj = node->forward().first;
out_obj = node->forward().first->value();
} else {
out_obj = PyAttrValue(obj);
}
@@ -1303,7 +1378,7 @@ AnfNodePtr GradExecutor::GetObjNode(const py::object &obj, const std::string &ob
node->add_input_value(out_obj, "");
node->add_input_value(MakeValue(idx), "");
out_obj = (*out_obj->cast<ValueTuplePtr>())[static_cast<size_t>(idx)];
node->set_forward(out_obj, "");
node->set_forward(NewValueNode(out_obj), "");
}
if (abs != nullptr && abs->isa<abstract::AbstractTuple>()) {
auto prim_abs = dyn_cast<abstract::AbstractTuple>(abs)->elements()[static_cast<size_t>(idx)];
@@ -1426,11 +1501,33 @@ void GradExecutor::DoOpGrad(const OpExecInfoPtr &op_exec_info, const AnfNodePtr
}
}

void GradExecutor::MakeCNodeForMsFunction(const FuncGraphPtr &ms_func_graph, const py::args &args,
const OpExecInfoPtr &op_exec_info, ValuePtrList *input_values,
CNodePtr *ms_function_cnode) {
// Refresh the cached forward tensors of a previously-run ms_function graph with the
// values produced by the current run, so the bprop graph reads up-to-date data.
//
// @param op_exec_info  Op info of this ms_function call; op_info keys the tensors
//                      cached by a previous run in the top cell.
// @param added_out     Python value holding outputs of the added forward nodes.
void GradExecutor::UpdateMsFunctionForwardTensors(const OpExecInfoPtr &op_exec_info, const py::object &added_out) {
  MS_LOG(DEBUG) << "Ms func graph has already ran before. The graph phase is: " << graph_phase();
  auto new_forward_value = PyAttrValue(added_out);
  MS_EXCEPTION_IF_NULL(new_forward_value);
  MS_LOG(DEBUG) << "The output values of added forward nodes are: " << new_forward_value->ToString();
  std::vector<tensor::TensorPtr> new_tensors;
  TensorValueToTensor(new_forward_value, &new_tensors);
  if (new_tensors.empty()) {
    MS_LOG(DEBUG) << "The size of added forward tensors is zero, no need to update.";
    return;
  }

  MS_EXCEPTION_IF_NULL(op_exec_info);
  // NOTE(review): removed stray statement `op_exec_info->op_inputs = args;` — `args`
  // is not in scope for this function (leftover from an older signature).
  const auto &old_tensors = top_cell()->op_info_with_ms_func_forward_tensors().at(op_exec_info->op_info);
  if (old_tensors.size() != new_tensors.size()) {
    MS_LOG(EXCEPTION) << "The size of old tensors is: " << old_tensors.size()
                      << ", but the size of new tensors is: " << new_tensors.size()
                      << ", the current op info is: " << op_exec_info->op_info;
  }
  // Copy the new device info into each cached tensor, then mark it so the host
  // copy is re-synced from device on the next read.
  for (size_t i = 0; i < new_tensors.size(); ++i) {
    UpdateTensorInfo(new_tensors[i], {old_tensors[i]});
    old_tensors[i]->set_sync_status(kNeedSyncDeviceToHost);
  }
}

void GradExecutor::MakeCNodeForMsFunction(const FuncGraphPtr &ms_func_graph, const py::args &args,
ValuePtrList *input_values, CNodePtr *ms_function_cnode) {
// Get input node info of ms_function
MS_EXCEPTION_IF_NULL(ms_func_graph);
std::vector<AnfNodePtr> input_nodes{NewValueNode(ms_func_graph)};
@@ -1446,25 +1543,37 @@ void GradExecutor::MakeCNodeForMsFunction(const FuncGraphPtr &ms_func_graph, con
(*input_values).emplace_back(inp_i_value);
}

// Get weights info of ms_function
// Get dfbuilder and graph info map
auto df_builder = GetDfbuilder(top_cell()->cell_id());
MS_EXCEPTION_IF_NULL(df_builder);
const auto &graph_info = top_cell()->graph_info_map().at(df_builder);
MS_EXCEPTION_IF_NULL(graph_info);
// Get weights info of ms_function
std::vector<AnfNodePtr> new_params;
auto manage = Manage(ms_func_graph, false);
for (const auto &anf_node : ms_func_graph->parameters()) {
MS_EXCEPTION_IF_NULL(anf_node);
auto param = anf_node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param);
if (param->has_default()) {
input_nodes.emplace_back(param);
auto default_value = param->default_param();
MS_EXCEPTION_IF_NULL(default_value);
(*input_values).emplace_back(default_value);
op_exec_info->op_inputs.append(default_value);
// Add weights to df_builder
SetParamNodeMapInGraphInfoMap(df_builder, param->name(), param);
MS_LOG(DEBUG) << "Top graph set free parameter " << param->DebugString() << ". Its default value is "
<< default_value->ToString() << ". Its name is: " << param->name();
if (!param->has_default()) {
new_params.push_back(param);
continue;
}
if (graph_info->params.count(param->name())) {
// Share same weight parameter in different ms_function call.
auto same_param = graph_info->params.at(param->name());
manage->Replace(anf_node, same_param);
param = same_param;
}
new_params.push_back(param);
input_nodes.emplace_back(param);
(*input_values).emplace_back(param->default_param());
SetParamNodeMapInGraphInfoMap(df_builder, param->name(), param);
MS_LOG(DEBUG) << "Top graph set free parameter " << param->DebugString() << ". Its default value is "
<< param->default_param()->ToString() << ". Its name is: " << param->name();
}
ms_func_graph->set_parameters(new_params);
manage->Clear();

// Make a CNode which includes ms_function fprop graph and inputs node
MS_EXCEPTION_IF_NULL(ms_function_cnode);
@@ -1473,28 +1582,23 @@ void GradExecutor::MakeCNodeForMsFunction(const FuncGraphPtr &ms_func_graph, con
}

// Make adjoint for ms_function fprop graph and connect it with previous op
void GradExecutor::MakeAdjointForMsFunction(const FuncGraphPtr &ms_func_graph, const FuncGraphPtr &fprop_g,
const py::object &out, const py::args &args,
void GradExecutor::MakeAdjointForMsFunction(const FuncGraphPtr &ms_func_graph, const FuncGraphPtr &grad_graph,
const py::object &actual_out, const py::args &args,
const std::string &graph_phase) {
ValuePtrList input_values;
CNodePtr ms_function_cnode = nullptr;
OpExecInfoPtr op_exec_info = std::make_shared<OpExecInfo>();
MakeCNodeForMsFunction(ms_func_graph, args, op_exec_info, &input_values, &ms_function_cnode);
MakeCNodeForMsFunction(ms_func_graph, args, &input_values, &ms_function_cnode);
MS_EXCEPTION_IF_NULL(ms_function_cnode);
SetTupleArgsToGraphInfoMap(curr_g_, out, ms_function_cnode);
SetNodeMapInGraphInfoMap(curr_g_, GetId(out), ms_function_cnode);
// Record ms_function cnode info and update forward tensors
op_exec_info->op_name = graph_phase;
RecordGradOpInfo(op_exec_info, out);
MS_LOG(DEBUG) << "Ms_function cnode op info: " << op_exec_info->op_info;
UpdateForwardTensorInfoInBpropGraph(op_exec_info, out);
// Add out and dout
auto out_value = parse::data_converter::PyDataToValue(out);
SetTupleArgsToGraphInfoMap(curr_g_, actual_out, ms_function_cnode);
SetNodeMapInGraphInfoMap(curr_g_, GetId(actual_out), ms_function_cnode);

// Connect grad graph of ms_function to context.
auto out_value = parse::data_converter::PyDataToValue(actual_out);
MS_EXCEPTION_IF_NULL(out_value);
// Do ad grad for ms_function cnode
auto k_pynative_cell_ptr = top_cell()->k_pynative_cell_ptr();
MS_EXCEPTION_IF_NULL(k_pynative_cell_ptr);
if (!k_pynative_cell_ptr->KPynativeWithFProp(ms_function_cnode, input_values, out_value, fprop_g)) {
MS_EXCEPTION_IF_NULL(grad_graph);
if (!k_pynative_cell_ptr->KPynativeWithFProp(ms_function_cnode, input_values, out_value, grad_graph)) {
MS_LOG(EXCEPTION) << "Failed to make adjoint for ms_function cnode, ms_function cnode info: "
<< ms_function_cnode->DebugString();
}
@@ -2458,11 +2562,13 @@ FuncGraphPtr GradExecutor::GetBpropGraph(const prim::GradOperationPtr &grad, con
const std::vector<AnfNodePtr> &weights, size_t arg_size,
const py::args &args) {
bool build_formal_param = false;
if ((!py::hasattr(cell, parse::CUSTOM_BPROP_NAME) && !cell_stack_.empty() && IsNestedGrad()) ||
top_cell()->ms_function_flag()) {
if (!py::hasattr(cell, parse::CUSTOM_BPROP_NAME) && !cell_stack_.empty() && IsNestedGrad()) {
build_formal_param = true;
need_renormalize_ = true;
}
if (top_cell()->ms_function_flag()) {
need_renormalize_ = true;
}

auto k_pynative_cell_ptr = top_cell()->k_pynative_cell_ptr();
MS_EXCEPTION_IF_NULL(k_pynative_cell_ptr);
@@ -2769,43 +2875,72 @@ void GradExecutor::EraseTopCellFromTopCellList(const TopCellInfoPtr &top_cell) {
(void)top_cell_list_.erase(iter);
}

void GradExecutor::GradMsFunction(const py::object &out, const py::args &args) {
void GradExecutor::GradMsFunctionInner(const std::string &phase, const py::object &out, const py::args &args,
const FuncGraphPtr &ms_func_graph, const FuncGraphPtr &grad_graph) {
// Get actual output value and added output value.
if (!py::isinstance<py::tuple>(out)) {
MS_LOG(EXCEPTION) << "The output value of ms_function func graph should be a tuple.";
}
auto tuple_out = py::cast<py::tuple>(out);
if (tuple_out.size() != 2) {
MS_LOG(EXCEPTION) << "The tuple size of output value of ms_function func graph should be 2.";
}
py::object actual_out = tuple_out[0];
py::object added_out = tuple_out[1];
MS_LOG(DEBUG) << "Added output value is: " << PyAttrValue(added_out)->ToString();

// Identity op info for current running ms_func graph.
OpExecInfoPtr op_exec_info = std::make_shared<OpExecInfo>();
op_exec_info->op_name = phase;
RecordGradOpInfo(op_exec_info, actual_out);
MS_LOG(DEBUG) << "ms_function cnode op info: " << op_exec_info->op_info;

// Step 1: Update actual output tensors used in grad graph.
MS_LOG(DEBUG) << "ms_function actual output value: " << PyAttrValue(actual_out)->ToString();
UpdateForwardTensorInfoInBpropGraph(op_exec_info, actual_out);

// Step 2: Update output tensors of added forward nodes, which are added to return node of ms_function func graph.
if (top_cell()->op_info_with_ms_func_forward_tensors().count(op_exec_info->op_info)) {
UpdateMsFunctionForwardTensors(op_exec_info, added_out);
return;
}

MS_LOG(DEBUG) << "Ms func graph run firstly. The graph phase is: " << graph_phase();
if (!need_construct_graph()) {
MS_LOG(DEBUG) << "The grad flag is set to false or the cell stack is empty. No need to make grad for ms_function";
set_graph_phase("");
MS_LOG(EXCEPTION) << "The flag of need construct graph is False.";
}
CreateNewTensorsInGradGraph(top_cell(), op_exec_info, added_out, ms_func_graph, grad_graph);

// Clone new ms_function func graph and grad graph.
auto new_ms_func_graph = BasicClone(ms_func_graph);
auto new_grad_graph = BasicClone(grad_graph, true);
auto new_make_tuple = new_ms_func_graph->output()->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(new_make_tuple);
new_ms_func_graph->set_output(new_make_tuple->input(1));

// Make Adjoint for grad graph
MakeAdjointForMsFunction(new_ms_func_graph, new_grad_graph, actual_out, args, phase);
}

void GradExecutor::GradMsFunction(const py::object &out, const py::args &args) {
if (!grad_flag_) {
MS_LOG(DEBUG) << "Only run forward infer computation, no need to construct grad graph.";
return;
}
// Get ms_function graph by phase
if (graph_phase().empty()) {
MS_LOG(EXCEPTION) << "The graph phase is empty, can not obtain backend graph which is complied by ms_function";
MS_LOG(EXCEPTION) << "The graph phase is empty, can not obtain ms_function func graph.";
}
MS_LOG(DEBUG) << "Ms_function graph phase: " << graph_phase();
// Get fprop graph of ms_function

// Get ms_function func graph and grad graph.
const auto &phase = graph_phase();
MS_LOG(DEBUG) << "ms_function func graph phase: " << phase;
auto executor = pipeline::ExecutorPy::GetInstance();
MS_EXCEPTION_IF_NULL(executor);
FuncGraphPtr fprop_g = nullptr;
FuncGraphPtr ms_func_graph = nullptr;
const auto &ms_function_grad_cache = top_cell()->ms_function_grad_cache();
auto iter = ms_function_grad_cache.find(graph_phase());
if (iter == ms_function_grad_cache.end()) {
ms_func_graph = BasicClone(executor->GetFuncGraph(graph_phase()));
MS_EXCEPTION_IF_NULL(ms_func_graph);
pipeline::ResourcePtr res = std::make_shared<pipeline::Resource>();
res->set_func_graph(ms_func_graph);
res->manager()->AddFuncGraph(ms_func_graph, true);
fprop_g = ad::Grad(ms_func_graph, res, true);
MS_EXCEPTION_IF_NULL(fprop_g);
top_cell()->set_ms_function_grad_cache(graph_phase(), ms_func_graph, fprop_g);
} else {
ms_func_graph = iter->second.first;
MS_EXCEPTION_IF_NULL(ms_func_graph);
fprop_g = iter->second.second;
MS_EXCEPTION_IF_NULL(fprop_g);
}
DumpGraphIR("before_grad_ms_function.ir", ms_func_graph);
DumpGraphIR("after_grad_ms_function.ir", fprop_g);
// Make adjoint for fprop graph of ms function graph
MakeAdjointForMsFunction(ms_func_graph, fprop_g, out, args, graph_phase());
FuncGraphPtr ms_func_graph = executor->GetFuncGraph(phase);
MS_EXCEPTION_IF_NULL(ms_func_graph);
FuncGraphPtr grad_graph = executor->GetGradGraph(phase);
MS_EXCEPTION_IF_NULL(grad_graph);

GradMsFunctionInner(phase, out, args, ms_func_graph, grad_graph);
set_graph_phase("");
}

@@ -2824,9 +2959,10 @@ void GradExecutor::ClearRes() {
grad_order_ = 0;
top_cell_switch_counts_ = 0;
grad_flag_ = false;
need_renormalize_ = false;
grad_is_running_ = false;
enable_op_cache_ = true;
grad_is_running_ = false;
need_renormalize_ = false;
eliminate_forward_ = true;
top_cell_ = nullptr;
curr_g_ = nullptr;
bprop_cell_list_.clear();


+ 23
- 14
mindspore/ccsrc/pipeline/pynative/pynative_execute.h View File

@@ -42,9 +42,9 @@

namespace mindspore::pynative {
namespace py = pybind11;
using MsFunctionGradCache = std::unordered_map<std::string, std::pair<FuncGraphPtr, FuncGraphPtr>>;
using OpInfoWithTensorId = std::unordered_map<std::string, std::vector<std::string>>;
using TensorIdWithTensorObject = std::unordered_map<std::string, std::vector<tensor::TensorPtr>>;
using OpInfoWithMsFuncForwardTensors = std::unordered_map<std::string, std::vector<tensor::TensorPtr>>;

py::object RealRunOp(const py::args &args);

@@ -103,10 +103,12 @@ class TopCellInfo {
void set_k_pynative_cell_ptr(const ad::KPynativeCellPtr &k_pynative_cell_ptr) {
k_pynative_cell_ptr_ = k_pynative_cell_ptr;
}
const MsFunctionGradCache &ms_function_grad_cache() const { return ms_function_grad_cache_; }
void set_ms_function_grad_cache(const std::string &graph_phase, const FuncGraphPtr &func_graph,
const FuncGraphPtr &grad_graph) {
ms_function_grad_cache_[graph_phase] = std::make_pair(func_graph, grad_graph);
const OpInfoWithMsFuncForwardTensors &op_info_with_ms_func_forward_tensors() const {
return op_info_with_ms_func_forward_tensors_;
}
void set_op_info_with_ms_func_forward_tensors(const std::string &op_info,
const std::vector<tensor::TensorPtr> &forward_tensors) {
op_info_with_ms_func_forward_tensors_[op_info] = forward_tensors;
}
void ClearDeviceMemory();
void Clear();
@@ -132,7 +134,7 @@ class TopCellInfo {
std::unordered_set<std::string> sub_cell_list_;
OpInfoWithTensorId op_info_with_tensor_id_;
TensorIdWithTensorObject tensor_id_with_tensor_object_;
MsFunctionGradCache ms_function_grad_cache_;
OpInfoWithMsFuncForwardTensors op_info_with_ms_func_forward_tensors_;
};
using TopCellInfoPtr = std::shared_ptr<TopCellInfo>;

@@ -189,17 +191,23 @@ class GradExecutor {
bool need_construct_graph() const { return !cell_stack_.empty() && grad_flag_; }
void SaveOutputNodeMap(const std::string &obj_id, const py::object &out_real, const AnfNodePtr &cnode);
void DoOpGrad(const OpExecInfoPtr &op_exec_info, const AnfNodePtr &node, const py::object &op_out);
void MakeAdjointForMsFunction(const FuncGraphPtr &ms_func_graph, const FuncGraphPtr &fprop_g, const py::object &out,
const py::args &args, const std::string &graph_phase);
void MakeCNodeForMsFunction(const FuncGraphPtr &ms_func_graph, const py::args &args,
const OpExecInfoPtr &op_exec_info, ValuePtrList *input_values,
// Construct grad graph for ms_function
bool eliminate_forward() const { return eliminate_forward_; }
void set_eliminate_forward(bool eliminate_forward) { eliminate_forward_ = eliminate_forward; }
void GradMsFunction(const py::object &out, const py::args &args);
void GradMsFunctionInner(const std::string &phase, const py::object &out, const py::args &args,
const FuncGraphPtr &ms_func_graph, const FuncGraphPtr &grad_graph);
void UpdateMsFunctionForwardTensors(const OpExecInfoPtr &op_exec_info, const py::object &added_out);
void MakeAdjointForMsFunction(const FuncGraphPtr &ms_func_graph, const FuncGraphPtr &grad_graph,
const py::object &actual_out, const py::args &args, const std::string &graph_phase);
void MakeCNodeForMsFunction(const FuncGraphPtr &ms_func_graph, const py::args &args, ValuePtrList *input_values,
CNodePtr *ms_function_cnode);
// Update forward tensors info
void UpdateForwardTensorInfoInBpropGraph(const OpExecInfoPtr &op_exec_info, const py::object &out_real);
void SaveForwardTensorInfoInBpropGraph(const pipeline::ResourcePtr &resource) const;
py::object CheckGraph(const py::object &cell, const py::args &args);
void RunGradGraph(py::object *ret, const py::object &cell, const py::tuple &args);
void EraseTopCellFromTopCellList(const TopCellInfoPtr &top_cell);
void GradMsFunction(const py::object &out, const py::args &args);
void ClearGrad(const py::object &cell, const py::args &args);
void ClearRes();
void ClearCellRes(const std::string &cell_id = "");
@@ -239,7 +247,7 @@ class GradExecutor {
std::vector<AnfNodePtr> GetWeightsArgs(const py::object &weights, const FuncGraphPtr &df_builder);
abstract::AbstractBasePtrList GetArgsSpec(const py::list &args, const FuncGraphPtr &bprop_graph);
// Manage resource for construct forward graph.
std::string &graph_phase() { return graph_phase_; }
const std::string &graph_phase() const { return graph_phase_; }
AnfNodePtr GetObjNode(const py::object &obj, const std::string &obj_id);
AnfNodePtr MakeValueNode(const py::object &obj, const std::string &obj_id);
void SetTupleItemArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &id, const AnfNodePtr &node,
@@ -261,9 +269,10 @@ class GradExecutor {

private:
bool grad_flag_{false};
bool need_renormalize_{false};
bool grad_is_running_{false};
bool enable_op_cache_{true};
bool grad_is_running_{false};
bool need_renormalize_{false};
bool eliminate_forward_{true};
int custom_bprop_cell_count_{0};
size_t grad_order_{0};
size_t top_cell_switch_counts_{0};


+ 3
- 4
mindspore/common/api.py View File

@@ -209,8 +209,11 @@ class _MindSporeFunction:
elif context.get_context("grad_for_scalar") and isinstance(i, (int, float)):
new_inputs.append(i)
output = self._executor(tuple(new_inputs), phase)

if context.get_context("mode") == context.PYNATIVE_MODE:
_pynative_exec.set_graph_phase(phase)
_pynative_exec.grad_ms_function(output, *new_inputs)
output = output[0]
return output


@@ -274,17 +277,13 @@ def ms_function(fn=None, obj=None, input_signature=None):
def wrap_mindspore(func):
@wraps(func)
def staging_specialize(*args):
input_args = args
if obj is not None:
logger.warning("Obj is no longer in use, and the function's own object has been used to \
distinguish whether it has been compiled.")
process_obj = None
if args and not isinstance(args[0], MsTensor) and hasattr(args[0], func.__name__):
input_args = args[1:]
process_obj = args[0]
out = _MindSporeFunction(func, input_signature, process_obj)(*args)
if context.get_context("mode") == context.PYNATIVE_MODE:
_pynative_exec.grad_ms_function(out, *input_args)
return out

return staging_specialize


+ 3
- 3
mindspore/core/ir/anf.h View File

@@ -268,8 +268,8 @@ class MS_CORE_API CNode : public AnfNode, public EffectInfoHolder {
void set_inputs_value(const std::vector<std::pair<ValuePtr, std::string>> &values) { inputs_value_ = values; }
const std::vector<std::pair<ValuePtr, std::string>> &inputs_value() const { return inputs_value_; }

void set_forward(const ValuePtr &forward, const std::string &id) { output_value_ = std::make_pair(forward, id); }
const std::pair<ValuePtr, std::string> &forward() const { return output_value_; }
void set_forward(const ValueNodePtr &forward, const std::string &id) { output_value_ = std::make_pair(forward, id); }
const std::pair<ValueNodePtr, std::string> &forward() const { return output_value_; }

bool stop_gradient() const { return stop_gradient_; }
void set_stop_gradient(bool stop_gradient) { stop_gradient_ = stop_gradient; }
@@ -357,7 +357,7 @@ class MS_CORE_API CNode : public AnfNode, public EffectInfoHolder {
// inputs_value_ store cnode input value and id in pynative mode
// output_value_ store cnode value and id in pynative mode
std::vector<std::pair<ValuePtr, std::string>> inputs_value_;
std::pair<ValuePtr, std::string> output_value_;
std::pair<ValueNodePtr, std::string> output_value_;
std::unordered_map<std::string, ValuePtr> attrs_;
std::unordered_map<std::string, ValuePtr> primal_attrs_;
std::vector<NodeDebugInfoPtr> primal_debug_infos_;


+ 7
- 0
mindspore/core/ir/func_graph.cc View File

@@ -735,6 +735,13 @@ bool FuncGraph::ContainMultiTarget() const {
return false;
}

void FuncGraph::set_used_forward_nodes(const std::vector<AnfNodePtr> &used_forward_nodes) {
std::for_each(used_forward_nodes.begin(), used_forward_nodes.end(), [this](const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
used_forward_nodes_.emplace(node);
});
}

size_t NewFgSeenGeneration() {
static size_t fg_seen_generation = 0;
return ++fg_seen_generation;


+ 8
- 0
mindspore/core/ir/func_graph.h View File

@@ -405,6 +405,10 @@ class FuncGraph : public api::FuncGraph, public FuncGraphBase, public EffectInfo
std::string bprop_hash() const { return bprop_hash_; }
void set_bprop_hash(const std::string &bprop_hash) { bprop_hash_ = bprop_hash; }

const std::unordered_set<AnfNodePtr> &used_forward_nodes() const { return used_forward_nodes_; }
void set_used_forward_nodes(const std::vector<AnfNodePtr> &used_forward_nodes);
void ClearUsedForwardNodes() { used_forward_nodes_.clear(); }

private:
// Only used for func_graph manager to control resource free.
int attached_mng_cnt() const { return attached_mng_cnt_; }
@@ -489,6 +493,10 @@ class FuncGraph : public api::FuncGraph, public FuncGraphBase, public EffectInfo
bool dropped_ = false;
// If the graph is a bprop graph, it should has a hash of the bprop directory.
std::string bprop_hash_;

// If the graph is decorated by @ms_function and runs grad process in pynative mode,
// forward nodes used in grad graph will be added to output for holding output values.
std::unordered_set<AnfNodePtr> used_forward_nodes_;
};

inline CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg) {


+ 2
- 2
mindspore/core/ir/func_graph_cloner.cc View File

@@ -688,9 +688,9 @@ FuncGraphPtr Cloner::operator[](const FuncGraphPtr &func_graph) {
return ((repl_func_graph_.count(func_graph) == 0) ? func_graph : repl_func_graph_[func_graph]);
}

FuncGraphPtr BasicClone(const FuncGraphPtr &func_graph) {
FuncGraphPtr BasicClone(const FuncGraphPtr &func_graph, bool clone_value_nodes) {
MS_EXCEPTION_IF_NULL(func_graph);
Cloner cloner({func_graph}, false, true, true, std::make_shared<TraceCopy>(), nullptr);
Cloner cloner({func_graph}, clone_value_nodes, true, true, std::make_shared<TraceCopy>(), nullptr);
return cloner[func_graph];
}



+ 1
- 1
mindspore/core/ir/func_graph_cloner.h View File

@@ -129,7 +129,7 @@ ClonerPtr SpecializerClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &r

FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph,
const TraceInfoPtr &relation = std::make_shared<TraceTransform>());
FuncGraphPtr BasicClone(const FuncGraphPtr &func_graph);
FuncGraphPtr BasicClone(const FuncGraphPtr &func_graph, bool clone_value_nodes = false);
} // namespace mindspore

#endif // MINDSPORE_CORE_IR_FUNC_GRAPH_CLONER_H_

+ 199
- 0
tests/st/pynative/ms_function/test_pynative_lenet_ms_function.py View File

@@ -0,0 +1,199 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import time
import numpy as np
import pytest

import mindspore.nn as nn
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.nn.optim import Momentum
from mindspore.common import dtype as mstype
from mindspore.common.api import ms_function
from mindspore import context, Tensor, ParameterTuple
from mindspore.nn.wrap.cell_wrapper import WithLossCell
from mindspore.common.initializer import TruncatedNormal

np.random.seed(1)
grad_by_list = C.GradOperation(get_by_list=True)


def weight_variable(stddev=0.02):
    """Build a TruncatedNormal weight initializer.

    Args:
        stddev (float): standard deviation of the truncated normal
            distribution. Defaults to 0.02, the value previously hard-coded.

    Returns:
        TruncatedNormal initializer instance.
    """
    return TruncatedNormal(stddev)


class conv_relu_maxpool2d_1(nn.Cell):
    """Conv2d(1->6) -> ReLU -> MaxPool2d block; construct is compiled by @ms_function."""

    def __init__(self):
        super(conv_relu_maxpool2d_1, self).__init__()
        self.weight_variable = weight_variable()
        self.conv = nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=0,
                              weight_init=self.weight_variable, pad_mode="valid")
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)

    @ms_function
    def construct(self, x):
        # Chain the three layers: conv -> relu -> pool.
        return self.max_pool2d(self.relu(self.conv(x)))


class conv_relu_maxpool2d_2(nn.Cell):
    """Conv2d(6->16) -> ReLU -> MaxPool2d block; construct is compiled by @ms_function."""

    def __init__(self):
        super(conv_relu_maxpool2d_2, self).__init__()
        self.weight_variable = weight_variable()
        self.conv = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0,
                              weight_init=self.weight_variable, pad_mode="valid")
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)

    @ms_function
    def construct(self, x):
        # Chain the three layers: conv -> relu -> pool.
        return self.max_pool2d(self.relu(self.conv(x)))


class fc(nn.Cell):
    """Single Dense layer (16*5*5 -> 120); construct is compiled by @ms_function."""

    def __init__(self):
        super(fc, self).__init__()
        self.weight_variable = weight_variable()
        # Same initializer object is used for both weight and bias.
        self.dense = nn.Dense(16 * 5 * 5, 120, self.weight_variable, self.weight_variable)

    @ms_function
    def construct(self, x):
        return self.dense(x)


def fc_with_initialize(input_channels, out_channels):
    """weight initial for fc layer"""
    # Two separate initializer instances: one for weight, one for bias.
    return nn.Dense(input_channels, out_channels, weight_variable(), weight_variable())


class LeNet(nn.Cell):
    """LeNet-style network assembled from ms_function-compiled sub-cells.

    Args:
        num_class (int): number of output classes. Default: 10.

    Returns:
        Tensor, logits of shape (batch_size, num_class).

    Examples:
        >>> LeNet(num_class=10)
    """

    def __init__(self, num_class=10):
        super(LeNet, self).__init__()
        self.num_class = num_class
        self.batch_size = 32
        self.conv_relu_maxpool2d_1 = conv_relu_maxpool2d_1()
        self.conv_relu_maxpool2d_2 = conv_relu_maxpool2d_2()
        self.fc1 = fc()
        self.fc2 = fc_with_initialize(120, 84)
        self.fc3 = fc_with_initialize(84, self.num_class)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.reshape = P.Reshape()

    def construct(self, x):
        # Two conv blocks, flatten, then three fully-connected layers.
        out = self.conv_relu_maxpool2d_1(x)
        out = self.conv_relu_maxpool2d_2(out)
        out = self.reshape(out, (self.batch_size, -1))
        out = self.relu(self.fc1(out))
        out = self.relu(self.fc2(out))
        return self.fc3(out)


class CrossEntropyLoss(nn.Cell):
    """Softmax cross-entropy loss, summed then divided by a fixed batch size of 32."""

    def __init__(self):
        super(CrossEntropyLoss, self).__init__()
        self.cross_entropy = P.SoftmaxCrossEntropyWithLogits()
        self.mean = P.ReduceMean()
        self.one_hot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.num = Tensor(32.0, mstype.float32)

    def construct(self, logits, label):
        # One-hot encode labels to match the logits' class dimension.
        one_hot_label = self.one_hot(label, F.shape(logits)[1], self.on_value, self.off_value)
        # SoftmaxCrossEntropyWithLogits returns a tuple; index 0 is the per-sample loss.
        per_sample_loss = self.cross_entropy(logits, one_hot_label)[0]
        summed = P.ReduceSum()(per_sample_loss, -1)
        return P.RealDiv()(summed, self.num)


class GradWrap(nn.Cell):
    """Wrap a network and return gradients w.r.t. its trainable parameters."""

    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network
        # Collect only parameters that require gradients.
        self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters()))

    def construct(self, x, label):
        return grad_by_list(self.network, self.weights)(x, label)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_lenet_ms_func():
    """Train LeNet (built from ms_function sub-cells) in pynative mode; check final loss."""
    context.set_context(mode=context.PYNATIVE_MODE)

    epoch_size = 20
    batch_size = 32
    inputs = Tensor(np.ones([batch_size, 1, 32, 32]).astype(np.float32))
    labels = Tensor(np.ones([batch_size]).astype(np.int32))

    net = LeNet()
    criterion = CrossEntropyLoss()
    optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.1, 0.9)

    net_with_criterion = WithLossCell(net, criterion)
    train_network = GradWrap(net_with_criterion)
    train_network.set_train()
    total_time = 0

    for epoch in range(epoch_size):
        start_time = time.time()
        # Forward pass, loss, backward pass, then parameter update.
        fw_output = net(inputs)
        loss_output = criterion(fw_output, labels)
        grads = train_network(inputs, labels)
        optimizer(grads)
        cost_time = time.time() - start_time
        total_time = total_time + cost_time
        print("======epoch: ", epoch, " loss: ", loss_output.asnumpy(), " cost time: ", cost_time)

    # The loss after 20 epochs on the constant input should land in a narrow band.
    assert loss_output.asnumpy() < 0.004
    assert loss_output.asnumpy() > 0.003

+ 189
- 5
tests/st/pynative/ms_function/test_pynative_ms_function.py View File

@@ -15,22 +15,206 @@
import pytest
import numpy as np
import mindspore as ms
from mindspore import Tensor
from mindspore.common.api import ms_function
import mindspore.nn as nn
import mindspore.ops as P
import mindspore.context as context
from mindspore import Tensor
from mindspore.nn.optim import Momentum
from mindspore.common.api import ms_function
from mindspore.common import Parameter, ParameterTuple
import mindspore.context as context
context.set_context(mode=context.PYNATIVE_MODE)

@ms_function
def ConvBnReLU(x):
    """Apply freshly-constructed Conv2d, BatchNorm2d and ReLU layers to x."""
    conv = nn.Conv2d(1, 2, kernel_size=2, stride=1, padding=0, weight_init="ones", pad_mode="valid")
    bn = nn.BatchNorm2d(2, momentum=0.99, eps=0.00001, gamma_init="ones")
    # Chain the layers in order: conv -> bn -> relu.
    return nn.ReLU()(bn(conv(x)))

@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_ms_function():
context.set_context(mode=context.PYNATIVE_MODE)
def test_call_single_func():
inputs = Tensor(np.ones([1, 1, 2, 2]).astype(np.float32))
out = ConvBnReLU(inputs)
assert np.allclose(out[0][0][0][0].asnumpy(), 3.9999797, 0.0001, 0.0001)
assert np.allclose(out[0][1][0][0].asnumpy(), 3.9999797, 0.0001, 0.0001)
grad = P.GradOperation(get_all=True, get_by_list=True, sens_param=False)
out_grad = grad(ConvBnReLU)(inputs)
assert np.allclose(out_grad[0][0][0][0][0][0].asnumpy(), 1.99998, 0.0001, 0.0001)


class CellConvBnReLU(nn.Cell):
    """Conv2d -> BatchNorm2d -> ReLU cell; construct is compiled by @ms_function."""

    def __init__(self):
        super(CellConvBnReLU, self).__init__()
        self.conv = nn.Conv2d(1, 2, kernel_size=2, stride=1, padding=0, weight_init="ones", pad_mode="valid")
        self.bn = nn.BatchNorm2d(2, momentum=0.99, eps=0.00001, gamma_init="ones")
        self.relu = nn.ReLU()

    @ms_function
    def construct(self, x):
        return self.relu(self.bn(self.conv(x)))


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_call_single_cell():
    """Forward plus two grad passes (optimizer step in between) of an ms_function cell."""
    input_data = Tensor(np.ones([1, 1, 2, 2]).astype(np.float32))
    # Forward pass.
    net = CellConvBnReLU()
    fw_out = net(input_data)
    assert np.allclose(fw_out[0][0][0][0].asnumpy(), 3.9999797, 0.0001, 0.0001)
    assert np.allclose(fw_out[0][1][0][0].asnumpy(), 3.9999797, 0.0001, 0.0001)
    # First grad pass.
    grad = P.GradOperation(get_all=True, get_by_list=True, sens_param=False)
    optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.1, 0.9)
    grad_first = grad(net, ParameterTuple(net.trainable_params()))(input_data)
    assert np.allclose(grad_first[0][0][0][0][0][0].asnumpy(), 1.99998, 0.0001, 0.0001)
    assert np.allclose(grad_first[1][0][0][0][0][0].asnumpy(), 0.99999, 0.0001, 0.0001)
    assert np.allclose(grad_first[1][1][0].asnumpy(), 3.99997, 0.0001, 0.0001)
    assert np.allclose(grad_first[1][2][0].asnumpy(), 1.00000, 0.0001, 0.0001)
    # Apply gradients, then run the second grad pass on the updated weights.
    optimizer(grad_first[1])
    grad_second = grad(net, ParameterTuple(net.trainable_params()))(input_data)
    assert np.allclose(grad_second[0][0][0][0][0][0].asnumpy(), 1.07999, 0.0001, 0.0001)
    assert np.allclose(grad_second[1][0][0][0][0][0].asnumpy(), 0.59999, 0.0001, 0.0001)
    assert np.allclose(grad_second[1][1][0].asnumpy(), 3.59998, 0.0001, 0.0001)
    assert np.allclose(grad_second[1][2][0].asnumpy(), 1.00000, 0.0001, 0.0001)


class AddMulMul(nn.Cell):
    """Compute square((x + x) * param) with a learnable scalar; compiled by @ms_function."""

    def __init__(self):
        super(AddMulMul, self).__init__()
        self.param = Parameter(Tensor(0.5, ms.float32))

    @ms_function
    def construct(self, x):
        doubled = x + x
        scaled = doubled * self.param
        return scaled * scaled


class CellCallSingleCell(nn.Cell):
    """Cell that invokes an ms_function-decorated sub-cell between conv/bn and relu."""

    def __init__(self):
        super(CellCallSingleCell, self).__init__()
        self.conv = nn.Conv2d(1, 2, kernel_size=2, stride=1, padding=0, weight_init="ones", pad_mode="valid")
        self.bn = nn.BatchNorm2d(2, momentum=0.99, eps=0.00001, gamma_init="ones")
        self.relu = nn.ReLU()
        self.add_mul_mul = AddMulMul()

    def construct(self, x):
        out = self.bn(self.conv(x))
        # The ms_function sub-cell runs between the bn and relu layers.
        out = self.add_mul_mul(out)
        return self.relu(out)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_cell_call_cell():
    """Forward plus two backward passes through a cell that calls an ms_function sub-cell."""
    x = Tensor(np.ones([1, 1, 2, 2]).astype(np.float32))
    net = CellCallSingleCell()
    # Forward check.
    out = net(x)
    assert np.allclose(out[0][0][0][0].asnumpy(), 15.9998, 0.0001, 0.0001)
    assert np.allclose(out[0][1][0][0].asnumpy(), 15.9998, 0.0001, 0.0001)
    # First backward pass.
    grad_op = P.GradOperation(get_all=True, get_by_list=True, sens_param=False)
    optimizer = Momentum(filter(lambda p: p.requires_grad, net.get_parameters()), 0.1, 0.9)
    first = grad_op(net, ParameterTuple(net.trainable_params()))(x)
    assert np.allclose(first[0][0][0][0][0][0].asnumpy(), 16.0, 0.0001, 0.0001)
    assert np.allclose(first[1][0][0][0][0][0].asnumpy(), 8.0, 0.0001, 0.0001)
    assert np.allclose(first[1][1][0].asnumpy(), 3.1999e+01, 0.0001, 0.0001)
    assert np.allclose(first[1][2][0].asnumpy(), 7.9999e+00, 0.0001, 0.0001)
    assert np.allclose(first[1][3].asnumpy(), 127.999, 0.0001, 0.0001)
    # Apply the gradients, then take gradients again with the updated weights.
    optimizer(first[1])
    second = grad_op(net, ParameterTuple(net.trainable_params()))(x)
    assert np.allclose(second[0][0][0][0][0][0].asnumpy(), 2.726e+03, 1, 1)
    assert np.allclose(second[1][0][0][0][0][0].asnumpy(), 6.816e+03, 1, 1)
    assert np.allclose(second[1][1][0].asnumpy(), -2.477e+03, 1, 1)
    assert np.allclose(second[1][2][0].asnumpy(), -3.097e+03, 1, 1)
    assert np.allclose(second[1][3].asnumpy(), -1289, 1, 1)


class Mul(nn.Cell):
    """Scales the input by a trainable scalar parameter (initialized to 1.5)."""

    def __init__(self):
        super(Mul, self).__init__()
        self.param = Parameter(Tensor(1.5, ms.float32))

    @ms_function
    def construct(self, x):
        return x * self.param


class CallSameFunc(nn.Cell):
    """Invokes the same ms_function cell (Mul) three times, then CellConvBnReLU."""

    def __init__(self):
        super(CallSameFunc, self).__init__()
        self.conv_bn_relu = CellConvBnReLU()
        self.mul = Mul()

    def construct(self, x):
        # Three consecutive calls to the same ms_function sub-cell.
        for _ in range(3):
            x = self.mul(x)
        return self.conv_bn_relu(x)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_call_same_func():
    """Forward plus two backward passes through repeated calls to one ms_function cell."""
    x = Tensor(np.ones([1, 1, 2, 2]).astype(np.float32))
    net = CallSameFunc()
    # Forward check.
    out = net(x)
    assert np.allclose(out[0][0][0][0].asnumpy(), 13.4999, 0.0001, 0.0001)
    assert np.allclose(out[0][1][0][0].asnumpy(), 13.4999, 0.0001, 0.0001)
    # First backward pass.
    grad_op = P.GradOperation(get_all=True, get_by_list=True, sens_param=False)
    optimizer = Momentum(filter(lambda p: p.requires_grad, net.get_parameters()), 0.1, 0.9)
    first = grad_op(net, ParameterTuple(net.trainable_params()))(x)
    assert np.allclose(first[0][0][0][0][0][0].asnumpy(), 6.75, 0.01, 0.01)
    assert np.allclose(first[1][0][0][0][0][0].asnumpy(), 3.375, 0.001, 0.001)
    assert np.allclose(first[1][1][0].asnumpy(), 13.4999, 0.0001, 0.0001)
    assert np.allclose(first[1][2][0].asnumpy(), 1.0000, 0.0001, 0.0001)
    assert np.allclose(first[1][3].asnumpy(), 54.0000, 0.0001, 0.0001)
    # Apply the gradients, then take gradients again with the updated weights.
    optimizer(first[1])
    second = grad_op(net, ParameterTuple(net.trainable_params()))(x)
    assert np.allclose(second[0][0][0][0][0][0].asnumpy(), 27.5, 0.1, 0.1)
    assert np.allclose(second[1][0][0][0][0][0].asnumpy(), 20.76, 0.01, 0.01)
    assert np.allclose(second[1][1][0].asnumpy(), -157, 1, 1)
    assert np.allclose(second[1][2][0].asnumpy(), 1.0000, 0.0001, 0.0001)
    assert np.allclose(second[1][3].asnumpy(), -84.6, 0.1, 0.1)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_ms_function():
class MsFunctionCell(nn.Cell):
def __init__(self):
super().__init__()


Loading…
Cancel
Save