Browse Source

!31423 Shallow copy for sens args when run grad graph in pynative mode.

Merge pull request !31423 from JoyLvliang/shallow_copy_tensor_value
r1.7
i-robot Gitee 4 years ago
parent
commit
5b83c9e003
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
5 changed files with 53 additions and 21 deletions
  1. +2
    -20
      mindspore/ccsrc/frontend/optimizer/ad/kpynative.cc
  2. +2
    -0
      mindspore/ccsrc/include/common/utils/convert_utils.h
  3. +30
    -1
      mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
  4. +1
    -0
      mindspore/ccsrc/pipeline/pynative/pynative_execute.h
  5. +18
    -0
      mindspore/ccsrc/utils/convert_utils.cc

+ 2
- 20
mindspore/ccsrc/frontend/optimizer/ad/kpynative.cc View File

@@ -469,24 +469,6 @@ void KPynativeCellImpl::UpdateOutputNodeOfTopCell(const AnfNodePtr &output_node)
}
}

namespace {
// Shallow-copies a Value: a Tensor is duplicated via the Tensor copy
// constructor (a fresh Tensor object; the underlying device data is
// presumably shared — hence "shallow"; confirm against Tensor's copy-ctor
// semantics), a ValueTuple is rebuilt element-wise recursively, and any
// other kind of Value is returned unchanged.
ValuePtr ShallowCopyValue(const ValuePtr &value) {
MS_EXCEPTION_IF_NULL(value);
if (value->isa<mindspore::tensor::Tensor>()) {
auto tensor_value = value->cast<mindspore::tensor::TensorPtr>();
// Copy-construct a new Tensor so later in-place updates of the original
// do not affect the stored copy.
return std::make_shared<mindspore::tensor::Tensor>(*tensor_value);
} else if (value->isa<ValueTuple>()) {
std::vector<ValuePtr> values;
auto value_tuple = value->cast<ValueTuplePtr>();
// Recurse into each element so nested tuples/tensors are copied as well.
(void)std::transform(value_tuple->value().begin(), value_tuple->value().end(), std::back_inserter(values),
[](const ValuePtr &elem) { return ShallowCopyValue(elem); });
return std::make_shared<ValueTuple>(values);
} else {
// Scalars and other value kinds are shared, not copied.
return value;
}
}
} // namespace

PynativeAdjointPtr KPynativeCellImpl::ForgeGetItemAdjoint(const CNodePtr &cnode) {
if (cnode->size() != 3) {
MS_LOG(EXCEPTION) << "TupleGetItem/ListGetItem CNode should have 3 inputs, but CNode: " << cnode->DebugString();
@@ -642,8 +624,8 @@ bool KPynativeCellImpl::BuildAdjoint(const CNodePtr &cnode, const ValuePtrList &
// is not used in bprop_fg;
ValuePtrList cloned_op_args;
(void)std::transform(op_args.begin(), op_args.end(), std::back_inserter(cloned_op_args),
[](const ValuePtr &value) { return ShallowCopyValue(value); });
ValuePtr cloned_out = ShallowCopyValue(out);
[](const ValuePtr &value) { return ShallowCopyTensorValue(value); });
ValuePtr cloned_out = ShallowCopyTensorValue(out);
PynativeAdjointPtr cnode_adjoint;
if (fg_type == PynativeAdjoint::FuncGraphType::kBackwardPropagate) {
auto optimized_bprop_fg = OptimizeBPropFuncGraph(fg, cnode, cloned_op_args, cloned_out);


+ 2
- 0
mindspore/ccsrc/include/common/utils/convert_utils.h View File

@@ -79,6 +79,8 @@ std::vector<T> TensorValueToVector(const tensor::TensorPtr &tensor) {

COMMON_EXPORT void TensorValueToTensor(const ValuePtr &value, std::vector<tensor::TensorPtr> *tensors);

COMMON_EXPORT ValuePtr ShallowCopyTensorValue(const ValuePtr &value);

COMMON_EXPORT size_t CountValueNum(const ValueTuplePtr &value_tuple);

// sparse_attr_map converts CNode{kPrimSparseGetAttr, SparseTensor}


+ 30
- 1
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc View File

@@ -2843,6 +2843,33 @@ std::vector<size_t> GradExecutor::GetGradPositionArgs(const py::object &grad_pos
MS_LOG(EXCEPTION) << "Grad position only support tuple.";
}

// Replace the sens (initial-gradient) entries of `run_args` with shallow
// copies, so running the grad graph cannot mutate the caller's original sens
// tensors in place.
//   input_args: the python-level positional args; sens, when present, is the
//               last element and may itself be a tuple.
//   has_sens:   whether the last element of input_args is a sens arg.
//   run_args:   the flattened VM argument list; the sens occupies `sens_num`
//               consecutive slots starting at `sens_index`. Mutated in place.
void GradExecutor::ShallowCopySensValue(const py::tuple &input_args, bool has_sens, VectorRef *run_args) {
  // The empty() guard also prevents the size()-1 underflow below when a
  // caller passes has_sens with no args.
  if (!has_sens || input_args.empty()) {
    return;
  }
  MS_EXCEPTION_IF_NULL(run_args);
  // Get index and number of sens args. A tuple sens is flattened by
  // ConvertArgs, so it spans several slots of run_args.
  size_t sens_index = input_args.size() - 1;
  size_t sens_num = 1;
  if (py::isinstance<py::tuple>(input_args[sens_index])) {
    py::tuple tuple_sens = py::cast<py::tuple>(input_args[sens_index]);
    sens_num = ConvertArgs(tuple_sens).size();
  }
  // Shallow copy each sens slot to a fresh Value.
  for (size_t i = sens_index; i < sens_index + sens_num; ++i) {
    const auto &original_sens = (*run_args)[i];
    if (utils::isa<ValuePtr>(original_sens)) {
      auto sens_value = utils::cast<ValuePtr>(original_sens);
      MS_EXCEPTION_IF_NULL(sens_value);
      auto new_sens_value = ShallowCopyTensorValue(sens_value);
      MS_EXCEPTION_IF_NULL(new_sens_value);
      MS_LOG(DEBUG) << "sens args [" << sens_value->ToString() << "] has been shallow copied to ["
                    << new_sens_value->ToString() << "].";
      (*run_args)[i] = new_sens_value;
    }
  }
}

void GradExecutor::UpdateParamAbsByArgs(const py::list &args, const FuncGraphPtr &bprop_graph) {
MS_EXCEPTION_IF_NULL(bprop_graph);
const auto &bprop_params = bprop_graph->parameters();
@@ -3061,8 +3088,10 @@ void GradExecutor::RunGradGraph(py::object *ret, const py::object &cell, const p
MS_LOG(DEBUG) << "Run resource ptr " << resource.get();

VectorRef arg_list;
py::tuple converted_args = ConvertArgs(FilterTensorArgs(args, has_sens));
auto filter_args = FilterTensorArgs(args, has_sens);
py::tuple converted_args = ConvertArgs(filter_args);
pipeline::ProcessVmArgInner(converted_args, resource, &arg_list);
ShallowCopySensValue(filter_args, has_sens, &arg_list);
MS_LOG(DEBUG) << "Convert args size " << converted_args.size() << ", graph param size " << arg_list.size();
compile::VmEvalFuncPtr run = resource->GetResult(pipeline::kOutput).cast<compile::VmEvalFuncPtr>();
MS_EXCEPTION_IF_NULL(run);


+ 1
- 0
mindspore/ccsrc/pipeline/pynative/pynative_execute.h View File

@@ -270,6 +270,7 @@ class GradExecutor {
std::vector<AnfNodePtr> GetWeightsArgs(const py::object &weights, const FuncGraphPtr &df_builder);
void UpdateParamAbsByArgs(const py::list &args, const FuncGraphPtr &bprop_graph);
std::vector<size_t> GetGradPositionArgs(const py::object &grad_position);
void ShallowCopySensValue(const py::tuple &input_args, bool has_sens, VectorRef *run_args);
// Manage resource for construct forward graph.
const std::string &graph_phase() const { return graph_phase_; }
AnfNodePtr GetObjNode(const py::object &obj, const std::string &obj_id);


+ 18
- 0
mindspore/ccsrc/utils/convert_utils.cc View File

@@ -302,6 +302,24 @@ void TensorValueToTensor(const ValuePtr &value, std::vector<tensor::TensorPtr> *
}
}

// Create a shallow copy of `value`: a Tensor is duplicated via the Tensor
// copy constructor, a ValueTuple is rebuilt with every element
// shallow-copied recursively, and any other kind of Value is returned
// unchanged.
ValuePtr ShallowCopyTensorValue(const ValuePtr &value) {
  MS_EXCEPTION_IF_NULL(value);
  if (value->isa<tensor::Tensor>()) {
    auto src_tensor = value->cast<tensor::TensorPtr>();
    MS_EXCEPTION_IF_NULL(src_tensor);
    // A fresh Tensor object; what it shares with the source is determined by
    // the Tensor copy constructor.
    return std::make_shared<tensor::Tensor>(*src_tensor);
  }
  if (value->isa<ValueTuple>()) {
    auto src_tuple = value->cast<ValueTuplePtr>();
    MS_EXCEPTION_IF_NULL(src_tuple);
    std::vector<ValuePtr> copied_elements;
    copied_elements.reserve(src_tuple->value().size());
    // Recurse so nested tuples/tensors are copied as well.
    for (const auto &element : src_tuple->value()) {
      copied_elements.emplace_back(ShallowCopyTensorValue(element));
    }
    return std::make_shared<ValueTuple>(copied_elements);
  }
  // Non-tensor, non-tuple values are shared as-is.
  return value;
}

size_t CountValueNum(const ValueTuplePtr &value_tuple) {
MS_EXCEPTION_IF_NULL(value_tuple);
size_t cnt = 0;


Loading…
Cancel
Save