From 21686def1aa525a873e049f72b22383412d62c40 Mon Sep 17 00:00:00 2001 From: yangwei Date: Fri, 26 Mar 2021 10:16:44 +0800 Subject: [PATCH] support negative index --- mindspore/ccsrc/pipeline/jit/pass.cc | 24 ++++++++++--------- .../composite/multitype_ops/getitem_impl.py | 1 + 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc index 8e1f489758..41eb0fbdb4 100644 --- a/mindspore/ccsrc/pipeline/jit/pass.cc +++ b/mindspore/ccsrc/pipeline/jit/pass.cc @@ -155,17 +155,19 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { opt::OptPassConfig a_after_grad = opt::OptPassConfig({ irpass.inline_without_move_, }); - opt::OptPassConfig a_3 = opt::OptPassConfig({ - irpass.arithmetic_simplify2_, - irpass.same_eliminate_, - irpass.check_bprop_eliminate_, - irpass.switch_layer_defer_inline_, - irpass.replace_applicator_, - irpass.mirror_mini_step_elim_, - irpass.virtual_add_elim_, - irpass.row_tensor_add_zeros_like_, - irpass.mini_step_allgather_replace_, - }); + opt::OptPassConfig a_3 = opt::OptPassConfig( + { + irpass.arithmetic_simplify2_, + irpass.same_eliminate_, + irpass.check_bprop_eliminate_, + irpass.switch_layer_defer_inline_, + irpass.replace_applicator_, + irpass.mirror_mini_step_elim_, + irpass.virtual_add_elim_, + irpass.row_tensor_add_zeros_like_, + irpass.mini_step_allgather_replace_, + }, + false, true); opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_}); opt::irpass::ResolveIRPassLib resolve_irpass; diff --git a/mindspore/ops/composite/multitype_ops/getitem_impl.py b/mindspore/ops/composite/multitype_ops/getitem_impl.py index 03b394e1e1..189b991f41 100644 --- a/mindspore/ops/composite/multitype_ops/getitem_impl.py +++ b/mindspore/ops/composite/multitype_ops/getitem_impl.py @@ -113,6 +113,7 @@ def _tuple_getitem_by_tensor(data, tensor_index): Outputs: Type, is the same as the element type of data. """ + tensor_index = F.select(tensor_index >= 0, tensor_index, tensor_index + len(data)) return _tuple_get_item_tensor(data, tensor_index)