diff --git a/mindspore/_checkparam.py b/mindspore/_checkparam.py index 2412ab8e18..5cb4209698 100644 --- a/mindspore/_checkparam.py +++ b/mindspore/_checkparam.py @@ -635,7 +635,7 @@ def check_input_data(*data, data_class): f' either a single' f' or a list of {data_class.__name__},' f' but got part data type is {str(type(item))}.') - if item.size() == 0: + if hasattr(item, "size") and item.size() == 0: msg = "Please provide non-empty data." logger.error(msg) raise ValueError(msg) diff --git a/mindspore/train/model.py b/mindspore/train/model.py index ee2c18be46..f6cd82cdd6 100755 --- a/mindspore/train/model.py +++ b/mindspore/train/model.py @@ -724,7 +724,7 @@ class Model: Batch data should be put together in one tensor. Args: - predict_data (Tensor): Tensor of predict data. can be array, list or tuple. + predict_data: The predict data, can be array, number, str, dict, list or tuple. Returns: Tensor, array(s) of predictions. @@ -735,7 +735,7 @@ class Model: >>> result = model.predict(input_data) """ self._predict_network.set_train(False) - check_input_data(*predict_data, data_class=Tensor) + check_input_data(*predict_data, data_class=(int, float, str, tuple, list, dict, Tensor)) _parallel_predict_check() result = self._predict_network(*predict_data) diff --git a/tests/ut/python/train/test_training.py b/tests/ut/python/train/test_training.py index 184077195e..c92ff5010b 100644 --- a/tests/ut/python/train/test_training.py +++ b/tests/ut/python/train/test_training.py @@ -214,8 +214,8 @@ def test_model_build_abnormal_string(): err = False try: model.predict('aaa') - except ValueError as e: - log.error("Find value error: %r ", e) + except TypeError as e: + log.error("Find type error: %r ", e) err = True finally: assert err