Browse Source

support non tensor inputs in model predict

tags/v1.1.0
buxue 5 years ago
parent
commit
f90f77c16e
3 changed files with 5 additions and 5 deletions
  1. +1
    -1
      mindspore/_checkparam.py
  2. +2
    -2
      mindspore/train/model.py
  3. +2
    -2
      tests/ut/python/train/test_training.py

+ 1
- 1
mindspore/_checkparam.py View File

@@ -635,7 +635,7 @@ def check_input_data(*data, data_class):
f' either a single' f' either a single'
f' or a list of {data_class.__name__},' f' or a list of {data_class.__name__},'
f' but got part data type is {str(type(item))}.') f' but got part data type is {str(type(item))}.')
if item.size() == 0:
if hasattr(item, "size") and item.size() == 0:
msg = "Please provide non-empty data." msg = "Please provide non-empty data."
logger.error(msg) logger.error(msg)
raise ValueError(msg) raise ValueError(msg)


+ 2
- 2
mindspore/train/model.py View File

@@ -724,7 +724,7 @@ class Model:
Batch data should be put together in one tensor. Batch data should be put together in one tensor.


Args: Args:
predict_data (Tensor): Tensor of predict data. can be array, list or tuple.
predict_data: The predict data, can be array, number, str, dict, list or tuple.


Returns: Returns:
Tensor, array(s) of predictions. Tensor, array(s) of predictions.
@@ -735,7 +735,7 @@ class Model:
>>> result = model.predict(input_data) >>> result = model.predict(input_data)
""" """
self._predict_network.set_train(False) self._predict_network.set_train(False)
check_input_data(*predict_data, data_class=Tensor)
check_input_data(*predict_data, data_class=(int, float, str, tuple, list, dict, Tensor))
_parallel_predict_check() _parallel_predict_check()
result = self._predict_network(*predict_data) result = self._predict_network(*predict_data)




+ 2
- 2
tests/ut/python/train/test_training.py View File

@@ -214,8 +214,8 @@ def test_model_build_abnormal_string():
err = False err = False
try: try:
model.predict('aaa') model.predict('aaa')
except ValueError as e:
log.error("Find value error: %r ", e)
except TypeError as e:
log.error("Find type error: %r ", e)
err = True err = True
finally: finally:
assert err assert err


Loading…
Cancel
Save