From 236bfb75e3b756ded916ccf2ebb8d11f19ba7b32 Mon Sep 17 00:00:00 2001 From: Payne Date: Mon, 21 Dec 2020 02:31:06 +0800 Subject: [PATCH] change the int32 restrict to int --- .../multitype_ops/_constexpr_utils.py | 4 +- tests/ut/python/ops/test_int64_support.py | 39 +++++++++++++++++++ 2 files changed, 41 insertions(+), 2 deletions(-) create mode 100644 tests/ut/python/ops/test_int64_support.py diff --git a/mindspore/ops/composite/multitype_ops/_constexpr_utils.py b/mindspore/ops/composite/multitype_ops/_constexpr_utils.py index ddfdfc6927..aae375072e 100644 --- a/mindspore/ops/composite/multitype_ops/_constexpr_utils.py +++ b/mindspore/ops/composite/multitype_ops/_constexpr_utils.py @@ -348,7 +348,7 @@ def get_index_tensor_dtype(dtype): def check_index_tensors_dtype(dtypes, op_name): """Check a tuple of tensor data type.""" for ele in dtypes: - if not ele == mstype.int32: + if not ele in mstype.int_type: raise IndexError(f"For '{op_name}', the all index tensor " f"data types should be mstype.int32, but got {dtypes}.") return True @@ -357,7 +357,7 @@ def check_index_tensors_dtype(dtypes, op_name): @constexpr def check_index_tensor_dtype(dtype, op_name): """Check a tensor data type.""" - if dtype == mstype.int32: + if dtype in mstype.int_type: return True raise IndexError( f"For '{op_name}', the index tensor data type should be mstype.int32, but got {dtype}.") diff --git a/tests/ut/python/ops/test_int64_support.py b/tests/ut/python/ops/test_int64_support.py new file mode 100644 index 0000000000..e083b81583 --- /dev/null +++ b/tests/ut/python/ops/test_int64_support.py @@ -0,0 +1,39 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test_int64_support """ +import numpy as np +import mindspore.nn as nn +from mindspore import context +from mindspore.common.tensor import Tensor +import mindspore as ms + + +def test_parser_support_int64_normal_graph(): + """ test tensor index support int64 -index, graph mode""" + class Net(nn.Cell): + def __init__(self): + super().__init__() + + def construct(self, inputs, tensor_in): + result = inputs[tensor_in] + return result + + context.set_context(mode=context.GRAPH_MODE) + input_np_x = np.random.randn(2, 3, 4, 5).astype(np.float32) + input_me_x = Tensor(input_np_x, ms.float32) + input_np_y = np.random.randint(2, size=[1, 2]).astype(np.int64) + tensor = Tensor(input_np_y, ms.int64) + net = Net() + net(input_me_x, tensor).asnumpy()