@@ -198,6 +198,7 @@ std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGrap
 * │   └── MapPy
 * ├── Tail
 * ├── MakeTupleGradient
 * ├── MakeListGradient
 * ├── GradOperation
 * └── TupleAdd
 */
@@ -241,6 +242,8 @@ std::string AnfExporter::GetMetaFuncGraphText(const MetaFuncGraphPtr &meta_func_
    // do nothing
  } else if (meta_func_graph->isa<prim::MakeTupleGradient>()) {
    // do nothing
  } else if (meta_func_graph->isa<prim::MakeListGradient>()) {
    // do nothing
  } else if (meta_func_graph->isa<prim::TupleAdd>()) {
    // do nothing
  } else if (meta_func_graph->isa<prim::TupleSlice>()) {
@@ -490,6 +490,47 @@ FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList &arg
  return fg;
}

FuncGraphPtr MakeListGradient::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  int list_size = SizeToInt(args_spec_list.size());

  std::ostringstream ss;
  ss << "▶make_list_" << list_size;
  FuncGraphPtr fg = std::make_shared<FuncGraph>();
  fg->debug_info()->set_name(ss.str());

  std::vector<AnfNodePtr> params;
  params.push_back(NewValueNode(prim::kPrimMakeList));
  for (int i = 0; i < list_size; ++i) {
    params.push_back(fg->add_parameter());
  }
  // Make fprop's first result: make_list's forward result.
  AnfNodePtr out = fg->NewCNode(params);

  // Make fprop's second result: make_list's backward function.
  FuncGraphPtr b = std::make_shared<FuncGraph>();
  ss.str("");  // clear() alone only resets the stream's error flags, not its contents
  ss.clear();
  ss << "◀make_list_" << list_size;
  b->debug_info()->set_name(ss.str());
  AnfNodePtr dout = b->add_parameter();

  std::vector<AnfNodePtr> grads;
  grads.push_back(NewValueNode(prim::kPrimMakeTuple));
  grads.push_back(NewValueNode(newenv));
  for (int i = 0; i < list_size; ++i) {
    grads.push_back(b->NewCNode({NewValueNode(prim::kPrimListGetItem), dout, NewValueNode(i)}));
  }

  b->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  b->set_output(b->NewCNode(grads));

  fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  fg->set_output(fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(b)}));
  (void)fg->transforms().emplace("primal", FuncGraphTransform(prim::kPrimMakeList));
  return fg;
}
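For intuition, a minimal Python model of the fprop graph built above (hypothetical names; the real artifact is an ANF graph, and the string "newenv" stands in for the environment value node):

def make_list_fprop(*args):
    """Model of ▶make_list_n: the forward result plus its backward function."""
    out = list(args)  # forward: make_list(arg_0, ..., arg_{n-1})

    def bprop(dout):
        # ◀make_list_n: an environment placeholder, then one gradient per
        # input, each read back out of dout with list_getitem.
        return ("newenv",) + tuple(dout[i] for i in range(len(args)))

    return out, bprop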

GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_list, bool sens_param)
    : MetaFuncGraph(name), get_all_(get_all), get_by_list_(get_by_list), sens_param_(sens_param) {
  if (get_by_list) {
@@ -121,6 +121,16 @@ class MakeTupleGradient : public MetaFuncGraph {
};
using MakeTupleGradientPtr = std::shared_ptr<MakeTupleGradient>;

class MakeListGradient : public MetaFuncGraph {
 public:
  explicit MakeListGradient(const std::string &name) : MetaFuncGraph(name) {}
  ~MakeListGradient() override = default;
  MS_DECLARE_PARENT(MakeListGradient, MetaFuncGraph)
  FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
  friend bool operator==(const MakeListGradient &lhs, const MakeListGradient &rhs) { return lhs.name_ == rhs.name_; }
};
using MakeListGradientPtr = std::shared_ptr<MakeListGradient>;

class GradOperation : public MetaFuncGraph {
 public:
  explicit GradOperation(const std::string &name, bool get_all = false, bool get_by_list = false,
@@ -463,6 +463,10 @@ AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const Primi
    auto elem = GetValue<int>(e);
    return elem;
  });
  if (IntToSize(indices_shp[1]) != dense_shape_vec.size()) {
    MS_EXCEPTION(TypeError) << "The size of dense_shape must be equal to the second dimension of indices "
                            << indices_shp[1] << ", but got " << dense_shape_vec.size();
  }
  for (auto dense_shape_elem : dense_shape_vec) {
    if (dense_shape_elem < 0) {
      MS_EXCEPTION(TypeError) << "The element of dense_shape must be non-negative, but got "
@@ -88,6 +88,12 @@ MetaFuncGraphPtr KPrim::KMetaFuncGraph(const PrimitivePtr &prim) {
    return meta;
  }

  if (prim->Hash() == prim::kPrimMakeList->Hash() && prim->name() == prim::kPrimMakeList->name()) {
    MetaFuncGraphPtr meta = std::make_shared<prim::MakeListGradient>("make_list_gradient");
    bprop_registry_meta_[prim::kPrimMakeList] = meta;
    return meta;
  }

  MS_LOG(EXCEPTION) << "Fail to find bprop function for " << prim->name() << ".";
}
@@ -103,6 +109,8 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R
    return fprop;
  } else if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) {
    return nullptr;
  } else if (prim->Hash() == prim::kPrimMakeList->Hash() && prim->name() == prim::kPrimMakeList->name()) {
    return nullptr;
  }

  FuncGraphPtr bprop_fg = nullptr;
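A hedged Python sketch of the dispatch rule these two hunks implement (all names hypothetical): structural primitives such as make_tuple, and now make_list, skip the ordinary bprop lookup in KPrimitive and are answered by the meta func graph path in KMetaFuncGraph.

META_GRAD_PRIMS = {"make_tuple": "make_tuple_gradient",
                   "make_list": "make_list_gradient"}  # gradients built structurally

def k_primitive_grad(prim_name, registered_bprops):
    if prim_name in META_GRAD_PRIMS:
        return None  # defer to the KMetaFuncGraph path
    if prim_name not in registered_bprops:
        raise RuntimeError("Fail to find bprop function for " + prim_name + ".")
    return registered_bprops[prim_name]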
@@ -59,6 +59,15 @@ static AbstractBasePtr Reabs(const AbstractBasePtr &t) {
      [](const AbstractAttribute &item) { return item.second; });
    return std::make_shared<AbstractTuple>(baselist);
  }
  return nullptr;
}

static AbstractBasePtr AdaptAbs(const AbstractBasePtr &t) {
  if (t == nullptr) {
    return nullptr;
  }
  if (t->isa<AbstractList>()) {
    auto abs_list = dyn_cast<AbstractList>(t);
    return std::make_shared<AbstractTuple>(abs_list->elements());
@@ -358,7 +367,41 @@ bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr
      new_node = EraseMakeKeywordArgNode(cnode);
    } else if (IsPrimitiveCNode(node, prim::kPrimExtractKeywordArg)) {
      new_node = EraseExtractKeywordArg(cnode);
    } else if (IsPrimitiveCNode(node, prim::kPrimMakeList)) {
    }
    if (new_node != nullptr) {
      new_node->set_abstract(node->abstract());
      MS_LOG(DEBUG) << "Replace node: " << node->DebugString() << " with new_node: " << new_node->DebugString();
      (void)manager->Replace(node, new_node);
      changed = true;
    }
  }

  for (auto &node : manager->all_nodes()) {
    auto ret = Reabs(node->abstract());
    if (ret) {
      MS_LOG(DEBUG) << "Replace " << node->DebugString() << "'s abstract " << node->abstract()->ToString() << " with "
                    << ret->ToString();
      node->set_abstract(ret);
      changed = true;
    }
  }
  return changed;
}

bool CleanList(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
  MS_EXCEPTION_IF_NULL(manager);
  manager->AddFuncGraph(root);
  bool changed = false;
  // `manager->Replace(...)` modifies the member `all_nodes_`, so `all_node` must be a copy rather than a reference.
  AnfNodeSet all_node = manager->all_nodes();
  for (auto &node : all_node) {
    MS_EXCEPTION_IF_NULL(node);
    auto cnode = node->cast<CNodePtr>();
    AnfNodePtr new_node = nullptr;
    if (IsPrimitiveCNode(node, prim::kPrimMakeList)) {
      new_node = ConvertMakeListToMakeTuple(cnode);
    } else if (IsPrimitiveCNode(node, prim::kPrimListGetItem)) {
      new_node = ConvertListGetItemToTupleGetItem(cnode);
@@ -377,7 +420,7 @@ bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr
  }

  for (auto &node : manager->all_nodes()) {
    auto ret = Reabs(node->abstract());
    auto ret = AdaptAbs(node->abstract());
    if (ret) {
      MS_LOG(DEBUG) << "Replace " << node->DebugString() << "'s abstract " << node->abstract()->ToString() << " with "
                    << ret->ToString();
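Roughly, AdaptAbs is Reabs narrowed to lists. A small Python sketch under assumed class names, mirroring how CleanList's node rewrite (make_list to make_tuple, list_getitem to tuple_getitem) is reflected at the type level:

class AbstractList:
    def __init__(self, elements):
        self.elements = elements

class AbstractTuple:
    def __init__(self, elements):
        self.elements = elements

def adapt_abs(t):
    # Returning None signals "leave this node's abstract unchanged".
    if isinstance(t, AbstractList):
        return AbstractTuple(t.elements)
    return None

assert isinstance(adapt_abs(AbstractList([1, 2])), AbstractTuple)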
@@ -32,6 +32,7 @@ namespace opt {
// Remove the class type from graphs
bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager);
bool CleanList(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager);

// Remove most uses of tuples from the graph
// tuples that are returned will be kept
@@ -69,6 +69,24 @@ bool SimplifyDataStructuresPass(const ResourcePtr &res) {
  return true;
}

bool CleanListPass(const ResourcePtr &res) {
  MS_EXCEPTION_IF_NULL(res->func_graph());
  FuncGraphPtr func_graph = res->func_graph();
  bool changed = opt::CleanList(func_graph, res->manager());

  abstract::AbstractBasePtrList args_spec;
  auto parameters = func_graph->parameters();
  (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec),
                       [](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); });
  if (changed) {
    FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec);
    res->set_func_graph(new_fg);
  }
  res->set_args_spec(args_spec);
  return true;
}

namespace {
OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
  opt::OptPassConfig a_1 = opt::OptPassConfig({
@@ -100,6 +118,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
    // Safe inlining
    irpass.inline_,
    irpass.sparse_tensor_eliminate_,
  });
  opt::OptPassConfig a_2 = opt::OptPassConfig({
    irpass.merge_addn_,
@@ -157,7 +176,6 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
    irpass.make_ref_eliminate_,
    irpass.get_ref_param_eliminate_,
    irpass.indexed_slices_eliminate_,
    irpass.sparse_tensor_eliminate_,
  });
  OptPassGroupMap map({
    {"b_1", b_1},
@@ -322,19 +340,23 @@ bool InferenceOptPreparePass(const ResourcePtr &res) {
  return true;
}

std::vector<PassItem> kVmPasses = {{"opt_a", OptPassAGroup},
                                   {"simplify_data_structures", SimplifyDataStructuresPass},
std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
                                   {"opt_a", OptPassAGroup},
                                   {"clean_list", CleanListPass},
                                   {"opt_b", OptPassBGroup},
                                   {"cconv", CconvPass},
                                   {"opt_graph_kernel_a", OptPassGraphKernelGroupA},
                                   {"opt_graph_kernel_b", OptPassGraphKernelGroupB},
                                   {"add_control_depend", AddControlDependPass}};

std::vector<PassItem> kGePasses = {
    {"opt_a", OptPassAGroup},      {"simplify_data_structures", SimplifyDataStructuresPass},
    {"opt_b", OptPassBGroup},      {"add_control_depend", AddControlDependPass},
    {"opt_control", ControlGroup}, {"opt_prepare", PrepareGroup},
    {"cconv", CconvPass}};
std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
                                   {"opt_a", OptPassAGroup},
                                   {"clean_list", CleanListPass},
                                   {"opt_b", OptPassBGroup},
                                   {"add_control_depend", AddControlDependPass},
                                   {"opt_control", ControlGroup},
                                   {"opt_prepare", PrepareGroup},
                                   {"cconv", CconvPass}};

std::vector<PassItem> kPynativePasses = {{"opt_a", OptPassAGroup}, {"opt_b", OptPassBGroup}, {"cconv", CconvPass}};
}  // namespace pipeline
@@ -22,6 +22,7 @@
#include "ir/func_graph_cloner.h"
#include "abstract/utils.h"
#include "debug/trace.h"
#include "utils/context/ms_context.h"

namespace mindspore {
namespace abstract {
@@ -373,9 +374,16 @@ EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &arg
  // parameters: (sens_f, sens_x, ...) = (*bprop_f)(sens_y)
  AbstractBasePtrList bparams;
  bparams.push_back(SensitivityTransform(orig_func_));
  (void)std::transform(
      args_spec_list.begin(), args_spec_list.end(), std::back_inserter(bparams),
      [](const AbstractBasePtr &arg_spec) -> AbstractBasePtr { return SensitivityTransform(arg_spec); });
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  bool enable_sparse = context->enable_sparse();
  (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(bparams),
                       [&enable_sparse](const AbstractBasePtr &arg_spec) -> AbstractBasePtr {
                         if (enable_sparse && arg_spec->isa<AbstractTensor>()) {
                           return std::make_shared<AbstractUndetermined>();
                         }
                         return SensitivityTransform(arg_spec);
                       });
  AbstractBasePtr bparams_final = std::make_shared<AbstractTuple>(bparams);
  AbstractFunctionPtr bprop =
      std::make_shared<VirtualAbstractClosure>(SensitivityTransform(result->abstract()), bparams_final);
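The effect of the new lambda, sketched in plain Python (stub class names assumed): when enable_sparse is set, a tensor input's gradient slot is left undetermined rather than pinned to a dense tensor type, since the backward pass may produce a sparse value such as an IndexedSlices.

class AbstractTensor:
    pass

class AbstractUndetermined:
    pass

def sensitivity_transform(spec):
    return spec  # stand-in for the real SensitivityTransform

def bparam_abstract(arg_spec, enable_sparse):
    # Sparse mode: do not fix the gradient's abstract to a dense tensor.
    if enable_sparse and isinstance(arg_spec, AbstractTensor):
        return AbstractUndetermined()
    return sensitivity_transform(arg_spec)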
@@ -116,6 +116,11 @@ def bprop_tuple_getitem(data, idx, out, dout):
    """Backpropagator for primitive `tuple_getitem`."""
    return F.tuple_setitem(C.zeros_like(data), idx, dout), C.zeros_like(idx)


@bprops.register("list_getitem")
def bprop_list_getitem(data, idx, out, dout):
    """Backpropagator for primitive `list_getitem`."""
    return F.list_setitem(C.zeros_like(data), idx, dout), C.zeros_like(idx)
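A quick hand-check of what this bprop computes, as self-contained Python with zeros_like modeled by scalar zeros:

def bprop_list_getitem_model(data, idx, dout):
    grad_data = [0 for _ in data]  # zeros_like(data)
    grad_data[idx] = dout          # list_setitem places dout at position idx
    return grad_data, 0            # the index itself gets a zero gradient

assert bprop_list_getitem_model([5, 6, 7], 1, 3.0) == ([0, 3.0, 0], 0)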
| @bprops.register("identity") | |||
| def bprop_identity(x, out, dout): | |||
| @@ -17,6 +17,7 @@ | |||
| from functools import reduce | |||
| import numpy as np | |||
| import mindspore as ms | |||
| from mindspore.ops import _selected_grad_ops as SG | |||
| from .. import functional as F | |||
| from .. import operations as P | |||
| @@ -33,6 +34,7 @@ shape_op = P.Shape() | |||
| reduce_sum = P.ReduceSum() | |||
| reshape = P.Reshape() | |||
| tile = P.Tile() | |||
| is_sub_class = P.IsSubClass() | |||
| def binop_grad_common(x, y, dx, dy): | |||
@@ -990,6 +992,12 @@ def get_bprop_scalar_addn(self):
    """Generate bprop for AddN"""

    def bprop(x, out, dout):
        if is_sub_class(F.typeof(x), ms.list_):
            dx = []
            for _ in range(len(x)):
                dx.append(dout)
            return (dx,)
        dx = ()
        for _ in range(len(x)):
            dx = dx + (dout,)
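In plain Python terms, the new branch makes AddN's gradient preserve the container type of its input; a minimal model (dout is the gradient of the summed output):

def addn_bprop_model(x, dout):
    if isinstance(x, list):
        return ([dout for _ in x],)   # list in, list of per-addend gradients out
    return (tuple(dout for _ in x),)  # original tuple path

assert addn_bprop_model([1, 2], 1.0) == ([1.0, 1.0],)
assert addn_bprop_model((1, 2), 1.0) == ((1.0, 1.0),)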
@@ -16,6 +16,7 @@ import numpy as np
import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops.composite as C
from mindspore import Tensor
from mindspore.ops import operations as P
@@ -45,3 +46,17 @@ def test_net():
    add = Net()
    output = add(x, y)
    assert output == expect


def test_grad_addn_with_list():
    grad_op = C.GradOperation('get_all', get_all=True)

    class AddN(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add_n = P.AddN()

        def construct(self, a, b):
            return self.add_n([a, b])

    inp = Tensor(np.ones([128, 96]).astype(np.float32))
    grad_op(AddN())(inp, inp)
@@ -252,7 +252,7 @@ def test_indexed_slices_sparse_gatherv2_grad_all():
            self.network = network

        def construct(self, x, y):
            grad = grad_all(self.network)(x, y)
            return grad, grad[0], grad[1]
            return grad[0].indices(), grad[0].values(), grad[0].dense_shape()

    class SparseGatherV2(nn.Cell):
        def __init__(self):
            super(SparseGatherV2, self).__init__()
@@ -276,7 +276,7 @@ def test_indexed_slices_sparse_gatherv2_grad_with_pram():
            weights = self.weights
            grad = grad_by_list(self.network, weights)(x)
            x = grad[0]
            return x, x.values(), x.indices(), x.dense_shape()
            return x.values(), x.indices(), x.dense_shape()

    class SparseGatherV2(nn.Cell):
        def __init__(self):
            super(SparseGatherV2, self).__init__()
@@ -18,6 +18,9 @@
@Date : 2020-07-16
@Desc : test mindspore sparse_tensor's operation
"""
import numpy as np
import pytest

import mindspore as ms
import mindspore.nn as nn
from mindspore.ops import composite as C
@@ -25,17 +28,20 @@ from mindspore import Tensor, SparseTensor, context

context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)


class MakeSparseTensor(nn.Cell):
    def __init__(self, dense_shape):
        super(MakeSparseTensor, self).__init__()
        self.dense_shape = dense_shape

    def construct(self, indices, values):
        ret = (SparseTensor(indices, values, self.dense_shape),)
        return ret[0]


def test_sparse_tensor_make_sparse_tensor():
    class MakeSparseTensor(nn.Cell):
        def __init__(self):
            super(MakeSparseTensor, self).__init__()
            self.dense_shape = (3, 4)

        def construct(self, indices, values):
            ret = (SparseTensor(indices, values, self.dense_shape),)
            return ret[0]

    indices = Tensor([[0, 1], [1, 2]])
    values = Tensor([1, 2], dtype=ms.float32)
    MakeSparseTensor()(indices, values)
    MakeSparseTensor((3, 4))(indices, values)


def test_sparse_tensor_attr():
@@ -59,3 +65,20 @@ def test_sparse_tensor_attr():
    indices = Tensor([[0, 1], [1, 2]])
    values = Tensor([1, 2], dtype=ms.float32)
    SparseTensorGetAttr()(indices, values)
    grad_op(SparseTensorGetAttr())(indices, values)


def test_sparse_tensor_indices_dim_greater_than_dense_shape_dim():
    indices = Tensor(np.array([[0, 0, 0], [0, 0, 1]], dtype=np.int32))
    values = Tensor(np.array([100, 200], dtype=np.float32))
    dense_shape = (2, 2)
    with pytest.raises(TypeError):
        MakeSparseTensor(dense_shape)(indices, values)


def test_sparse_tensor_indices_dim_less_than_dense_shape_dim():
    indices = Tensor(np.array([[0, 0], [0, 1]], dtype=np.int32))
    values = Tensor(np.array([100, 200], dtype=np.float32))
    dense_shape = (2, 2, 2)
    with pytest.raises(TypeError):
        MakeSparseTensor(dense_shape)(indices, values)