| @@ -469,24 +469,6 @@ void KPynativeCellImpl::UpdateOutputNodeOfTopCell(const AnfNodePtr &output_node) | |||
| } | |||
| } | |||
| namespace { | |||
| ValuePtr ShallowCopyValue(const ValuePtr &value) { | |||
| MS_EXCEPTION_IF_NULL(value); | |||
| if (value->isa<mindspore::tensor::Tensor>()) { | |||
| auto tensor_value = value->cast<mindspore::tensor::TensorPtr>(); | |||
| return std::make_shared<mindspore::tensor::Tensor>(*tensor_value); | |||
| } else if (value->isa<ValueTuple>()) { | |||
| std::vector<ValuePtr> values; | |||
| auto value_tuple = value->cast<ValueTuplePtr>(); | |||
| (void)std::transform(value_tuple->value().begin(), value_tuple->value().end(), std::back_inserter(values), | |||
| [](const ValuePtr &elem) { return ShallowCopyValue(elem); }); | |||
| return std::make_shared<ValueTuple>(values); | |||
| } else { | |||
| return value; | |||
| } | |||
| } | |||
| } // namespace | |||
| PynativeAdjointPtr KPynativeCellImpl::ForgeGetItemAdjoint(const CNodePtr &cnode) { | |||
| if (cnode->size() != 3) { | |||
| MS_LOG(EXCEPTION) << "TupleGetItem/ListGetItem CNode should have 3 inputs, but CNode: " << cnode->DebugString(); | |||
| @@ -642,8 +624,8 @@ bool KPynativeCellImpl::BuildAdjoint(const CNodePtr &cnode, const ValuePtrList & | |||
| // is not used in bprop_fg; | |||
| ValuePtrList cloned_op_args; | |||
| (void)std::transform(op_args.begin(), op_args.end(), std::back_inserter(cloned_op_args), | |||
| [](const ValuePtr &value) { return ShallowCopyValue(value); }); | |||
| ValuePtr cloned_out = ShallowCopyValue(out); | |||
| [](const ValuePtr &value) { return ShallowCopyTensorValue(value); }); | |||
| ValuePtr cloned_out = ShallowCopyTensorValue(out); | |||
| PynativeAdjointPtr cnode_adjoint; | |||
| if (fg_type == PynativeAdjoint::FuncGraphType::kBackwardPropagate) { | |||
| auto optimized_bprop_fg = OptimizeBPropFuncGraph(fg, cnode, cloned_op_args, cloned_out); | |||
| @@ -79,6 +79,8 @@ std::vector<T> TensorValueToVector(const tensor::TensorPtr &tensor) { | |||
| COMMON_EXPORT void TensorValueToTensor(const ValuePtr &value, std::vector<tensor::TensorPtr> *tensors); | |||
| COMMON_EXPORT ValuePtr ShallowCopyTensorValue(const ValuePtr &value); | |||
| COMMON_EXPORT size_t CountValueNum(const ValueTuplePtr &value_tuple); | |||
| // sparse_attr_map converts CNode{kPrimSparseGetAttr, SparseTensor} | |||
| @@ -2843,6 +2843,33 @@ std::vector<size_t> GradExecutor::GetGradPositionArgs(const py::object &grad_pos | |||
| MS_LOG(EXCEPTION) << "Grad position only support tuple."; | |||
| } | |||
| void GradExecutor::ShallowCopySensValue(const py::tuple &input_args, bool has_sens, VectorRef *run_args) { | |||
| if (!has_sens) { | |||
| return; | |||
| } | |||
| // Get index and number of sens args. | |||
| size_t sens_index = input_args.size() - 1; | |||
| size_t sens_num = 1; | |||
| if (py::isinstance<py::tuple>(input_args[sens_index])) { | |||
| py::tuple tuple_sens = py::cast<py::tuple>(input_args[sens_index]); | |||
| sens_num = ConvertArgs(tuple_sens).size(); | |||
| } | |||
| // Shallow copy sens args to new sens args. | |||
| MS_EXCEPTION_IF_NULL(run_args); | |||
| for (size_t i = sens_index; i < sens_index + sens_num; ++i) { | |||
| const auto &original_sens = (*run_args)[i]; | |||
| if (utils::isa<ValuePtr>(original_sens)) { | |||
| auto sens_value = utils::cast<ValuePtr>(original_sens); | |||
| MS_EXCEPTION_IF_NULL(sens_value); | |||
| auto new_sens_value = ShallowCopyTensorValue(sens_value); | |||
| MS_EXCEPTION_IF_NULL(new_sens_value); | |||
| MS_LOG(DEBUG) << "sens args [" << sens_value->ToString() << "] has been shallow copied to [" | |||
| << new_sens_value->ToString() << "]."; | |||
| (*run_args)[i] = new_sens_value; | |||
| } | |||
| } | |||
| } | |||
| void GradExecutor::UpdateParamAbsByArgs(const py::list &args, const FuncGraphPtr &bprop_graph) { | |||
| MS_EXCEPTION_IF_NULL(bprop_graph); | |||
| const auto &bprop_params = bprop_graph->parameters(); | |||
| @@ -3061,8 +3088,10 @@ void GradExecutor::RunGradGraph(py::object *ret, const py::object &cell, const p | |||
| MS_LOG(DEBUG) << "Run resource ptr " << resource.get(); | |||
| VectorRef arg_list; | |||
| py::tuple converted_args = ConvertArgs(FilterTensorArgs(args, has_sens)); | |||
| auto filter_args = FilterTensorArgs(args, has_sens); | |||
| py::tuple converted_args = ConvertArgs(filter_args); | |||
| pipeline::ProcessVmArgInner(converted_args, resource, &arg_list); | |||
| ShallowCopySensValue(filter_args, has_sens, &arg_list); | |||
| MS_LOG(DEBUG) << "Convert args size " << converted_args.size() << ", graph param size " << arg_list.size(); | |||
| compile::VmEvalFuncPtr run = resource->GetResult(pipeline::kOutput).cast<compile::VmEvalFuncPtr>(); | |||
| MS_EXCEPTION_IF_NULL(run); | |||
| @@ -270,6 +270,7 @@ class GradExecutor { | |||
| std::vector<AnfNodePtr> GetWeightsArgs(const py::object &weights, const FuncGraphPtr &df_builder); | |||
| void UpdateParamAbsByArgs(const py::list &args, const FuncGraphPtr &bprop_graph); | |||
| std::vector<size_t> GetGradPositionArgs(const py::object &grad_position); | |||
| void ShallowCopySensValue(const py::tuple &input_args, bool has_sens, VectorRef *run_args); | |||
| // Manage resource for construct forward graph. | |||
| const std::string &graph_phase() const { return graph_phase_; } | |||
| AnfNodePtr GetObjNode(const py::object &obj, const std::string &obj_id); | |||
| @@ -302,6 +302,24 @@ void TensorValueToTensor(const ValuePtr &value, std::vector<tensor::TensorPtr> * | |||
| } | |||
| } | |||
| ValuePtr ShallowCopyTensorValue(const ValuePtr &value) { | |||
| MS_EXCEPTION_IF_NULL(value); | |||
| if (value->isa<tensor::Tensor>()) { | |||
| auto tensor_value = value->cast<tensor::TensorPtr>(); | |||
| MS_EXCEPTION_IF_NULL(tensor_value); | |||
| return std::make_shared<tensor::Tensor>(*tensor_value); | |||
| } else if (value->isa<ValueTuple>()) { | |||
| std::vector<ValuePtr> values; | |||
| auto value_tuple = value->cast<ValueTuplePtr>(); | |||
| MS_EXCEPTION_IF_NULL(value_tuple); | |||
| (void)std::transform(value_tuple->value().begin(), value_tuple->value().end(), std::back_inserter(values), | |||
| [](const ValuePtr &elem) { return ShallowCopyTensorValue(elem); }); | |||
| return std::make_shared<ValueTuple>(values); | |||
| } else { | |||
| return value; | |||
| } | |||
| } | |||
| size_t CountValueNum(const ValueTuplePtr &value_tuple) { | |||
| MS_EXCEPTION_IF_NULL(value_tuple); | |||
| size_t cnt = 0; | |||