|
|
@@ -265,8 +265,10 @@ def _generate_indices_from_tuple_of_tensor(data, tuple_index, op_name): |
|
|
tensor_index_shape = hyper_map(F.shape, tuple_index) |
|
|
tensor_index_shape = hyper_map(F.shape, tuple_index) |
|
|
broadcast_shape = const_utils.generate_broadcast_shape(tensor_index_shape, op_name) |
|
|
broadcast_shape = const_utils.generate_broadcast_shape(tensor_index_shape, op_name) |
|
|
broadcast_tensors = hyper_map(F.partial(_broadcast, broadcast_shape), tuple_index) |
|
|
broadcast_tensors = hyper_map(F.partial(_broadcast, broadcast_shape), tuple_index) |
|
|
indices = pack(broadcast_tensors) |
|
|
|
|
|
indices = F.cast(indices, mstype.int64) |
|
|
|
|
|
|
|
|
new_broadcast_tensors = () |
|
|
|
|
|
for tensor in broadcast_tensors: |
|
|
|
|
|
new_broadcast_tensors += (F.cast(tensor, mstype.int64),) |
|
|
|
|
|
indices = pack(new_broadcast_tensors) |
|
|
return indices |
|
|
return indices |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|