|
|
|
@@ -746,19 +746,30 @@ Strategys CheckBroadcast(const std::vector<std::shared_ptr<OperatorInfo>> &ops, |
|
|
|
|
|
|
|
size_t first_tensor_dim = ops[iter_ops]->inputs_tensor_info()[0].shape().size(); |
|
|
|
size_t second_tensor_dim = ops[iter_ops]->inputs_tensor_info()[1].shape().size(); |
|
|
|
size_t s_dim = s.size(); |
|
|
|
// Do Broadcasting in the second tensor. |
|
|
|
if (second_tensor_dim < first_tensor_dim) { |
|
|
|
bool braoadcast_first_tensor = false; |
|
|
|
bool broadcast_first_tensor = false; |
|
|
|
// Push back the first tensor's strategy. |
|
|
|
stra.push_back(s); |
|
|
|
if (s_dim == first_tensor_dim) { |
|
|
|
stra.push_back(s); |
|
|
|
} else { |
|
|
|
Dimensions broadcast_revise_s(first_tensor_dim, 1); |
|
|
|
stra.push_back(broadcast_revise_s); |
|
|
|
} |
|
|
|
// Push back the second tensor's strategy after applying broadcast. |
|
|
|
stra.push_back(ApplyBroadcast(ops, iter_ops, s, second_tensor_dim, first_tensor_dim, braoadcast_first_tensor)); |
|
|
|
stra.push_back(ApplyBroadcast(ops, iter_ops, s, first_tensor_dim, second_tensor_dim, broadcast_first_tensor)); |
|
|
|
} else if (second_tensor_dim > first_tensor_dim) { // Do Broadcasting in the first tensor. |
|
|
|
bool braoadcast_first_tensor = true; |
|
|
|
bool broadcast_first_tensor = true; |
|
|
|
// Push back the first tensor's strategy after applying broadcast. |
|
|
|
stra.push_back(ApplyBroadcast(ops, iter_ops, s, first_tensor_dim, second_tensor_dim, braoadcast_first_tensor)); |
|
|
|
stra.push_back(ApplyBroadcast(ops, iter_ops, s, first_tensor_dim, second_tensor_dim, broadcast_first_tensor)); |
|
|
|
// Push back the second tensor's strategy. |
|
|
|
stra.push_back(s); |
|
|
|
if (s_dim == second_tensor_dim) { |
|
|
|
stra.push_back(s); |
|
|
|
} else { |
|
|
|
Dimensions broadcast_revise_s(second_tensor_dim, 1); |
|
|
|
stra.push_back(broadcast_revise_s); |
|
|
|
} |
|
|
|
} else { // Broadcasting can be ignored or No broadcasting needs to be applied. |
|
|
|
stra = CheckDivisible(ops, iter_ops, s); |
|
|
|
} |
|
|
|
@@ -767,19 +778,25 @@ Strategys CheckBroadcast(const std::vector<std::shared_ptr<OperatorInfo>> &ops, |
|
|
|
} |
|
|
|
|
|
|
|
Dimensions ApplyBroadcast(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s, |
|
|
|
size_t target_tensor_dim, size_t refer_tensor_dim, bool braoadcast_first_tensor) { |
|
|
|
size_t first_tensor_dim, size_t second_tensor_dim, bool broadcast_first_tensor) { |
|
|
|
Dimensions s_empty = {}; |
|
|
|
Dimensions s_broadcast; |
|
|
|
int target_tensor_index = 0; |
|
|
|
int refer_tensor_index = 0; |
|
|
|
size_t target_tensor_dim; |
|
|
|
size_t refer_tensor_dim; |
|
|
|
|
|
|
|
// Indexing target and refer tensor. |
|
|
|
if (braoadcast_first_tensor) { |
|
|
|
if (broadcast_first_tensor) { |
|
|
|
target_tensor_index = 0; |
|
|
|
refer_tensor_index = 1; |
|
|
|
target_tensor_dim = first_tensor_dim; |
|
|
|
refer_tensor_dim = second_tensor_dim; |
|
|
|
} else { |
|
|
|
target_tensor_index = 1; |
|
|
|
refer_tensor_index = 0; |
|
|
|
target_tensor_dim = second_tensor_dim; |
|
|
|
refer_tensor_dim = first_tensor_dim; |
|
|
|
} |
|
|
|
|
|
|
|
// When target tensor with an empty dim. |
|
|
|
|