From fcbd85d79d18b586f6d9280cb7e03a7f31e42f31 Mon Sep 17 00:00:00 2001 From: Sheng Date: Sat, 12 Sep 2020 20:09:22 +0800 Subject: [PATCH] check dim for refer tensor when broadcasting --- .../rec_core/rec_generate_strategy.cc | 33 ++++++++++++++----- .../rec_core/rec_generate_strategy.h | 2 +- 2 files changed, 26 insertions(+), 9 deletions(-) diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc index 46cb6a6e50..b519f2dc0c 100644 --- a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc @@ -746,19 +746,30 @@ Strategys CheckBroadcast(const std::vector> &ops, size_t first_tensor_dim = ops[iter_ops]->inputs_tensor_info()[0].shape().size(); size_t second_tensor_dim = ops[iter_ops]->inputs_tensor_info()[1].shape().size(); + size_t s_dim = s.size(); // Do Broadcasting in the second tensor. if (second_tensor_dim < first_tensor_dim) { - bool braoadcast_first_tensor = false; + bool broadcast_first_tensor = false; // Push back the first tensor's strategy. - stra.push_back(s); + if (s_dim == first_tensor_dim) { + stra.push_back(s); + } else { + Dimensions broadcast_revise_s(first_tensor_dim, 1); + stra.push_back(broadcast_revise_s); + } // Push back the second tensor's strategy after applying broadcast. - stra.push_back(ApplyBroadcast(ops, iter_ops, s, second_tensor_dim, first_tensor_dim, braoadcast_first_tensor)); + stra.push_back(ApplyBroadcast(ops, iter_ops, s, first_tensor_dim, second_tensor_dim, broadcast_first_tensor)); } else if (second_tensor_dim > first_tensor_dim) { // Do Broadcasting in the first tensor. - bool braoadcast_first_tensor = true; + bool broadcast_first_tensor = true; // Push back the first tensor's strategy after applying broadcast. - stra.push_back(ApplyBroadcast(ops, iter_ops, s, first_tensor_dim, second_tensor_dim, braoadcast_first_tensor)); + stra.push_back(ApplyBroadcast(ops, iter_ops, s, first_tensor_dim, second_tensor_dim, broadcast_first_tensor)); // Push back the second tensor's strategy. - stra.push_back(s); + if (s_dim == second_tensor_dim) { + stra.push_back(s); + } else { + Dimensions broadcast_revise_s(second_tensor_dim, 1); + stra.push_back(broadcast_revise_s); + } } else { // Broadcasting can be ignored or No broadcasting needs to be applied. stra = CheckDivisible(ops, iter_ops, s); } @@ -767,19 +778,25 @@ Strategys CheckBroadcast(const std::vector> &ops, } Dimensions ApplyBroadcast(const std::vector> &ops, const size_t iter_ops, Dimensions s, - size_t target_tensor_dim, size_t refer_tensor_dim, bool braoadcast_first_tensor) { + size_t first_tensor_dim, size_t second_tensor_dim, bool broadcast_first_tensor) { Dimensions s_empty = {}; Dimensions s_broadcast; int target_tensor_index = 0; int refer_tensor_index = 0; + size_t target_tensor_dim; + size_t refer_tensor_dim; // Indexing target and refer tensor. - if (braoadcast_first_tensor) { + if (broadcast_first_tensor) { target_tensor_index = 0; refer_tensor_index = 1; + target_tensor_dim = first_tensor_dim; + refer_tensor_dim = second_tensor_dim; } else { target_tensor_index = 1; refer_tensor_index = 0; + target_tensor_dim = second_tensor_dim; + refer_tensor_dim = first_tensor_dim; } // When target tensor with an empty dim. diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h index 2263deb588..0c034edd84 100644 --- a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h @@ -47,7 +47,7 @@ Strategys MakeRecSearchStrategy(const std::shared_ptr &graph, const size_t iter_ops); Strategys CheckBroadcast(const std::vector> &ops, const size_t iter_ops, Dimensions s); Dimensions ApplyBroadcast(const std::vector> &ops, const size_t iter_ops, Dimensions s, - size_t target_tensor_dim, size_t refer_tensor_dim, bool braoadcast_first_tensor); + size_t first_tensor_dim, size_t second_tensor_dim, bool broadcast_first_tensor); Strategys CheckDivisible(const std::vector> &ops, const size_t iter_ops, Dimensions s); Strategys MakeDataParallelStrategy(const std::shared_ptr &graph, const std::vector> &ops, const size_t iter_graph,