Browse Source

getitem pack debug

tags/v1.2.0-rc1
yepei6 4 years ago
parent
commit
19859b74f5
1 changed files with 4 additions and 2 deletions
  1. +4
    -2
      mindspore/ops/composite/multitype_ops/_compile_utils.py

+ 4
- 2
mindspore/ops/composite/multitype_ops/_compile_utils.py View File

@@ -265,8 +265,10 @@ def _generate_indices_from_tuple_of_tensor(data, tuple_index, op_name):
tensor_index_shape = hyper_map(F.shape, tuple_index)
broadcast_shape = const_utils.generate_broadcast_shape(tensor_index_shape, op_name)
broadcast_tensors = hyper_map(F.partial(_broadcast, broadcast_shape), tuple_index)
indices = pack(broadcast_tensors)
indices = F.cast(indices, mstype.int64)
new_broadcast_tensors = ()
for tensor in broadcast_tensors:
new_broadcast_tensors += (F.cast(tensor, mstype.int64),)
indices = pack(new_broadcast_tensors)
return indices




Loading…
Cancel
Save