| @@ -135,8 +135,8 @@ def validate_condition(search_condition): | |||
| sorted_type_param = ['ascending', 'descending', None] | |||
| if "sorted_type" in search_condition: | |||
| if "sorted_name" not in search_condition: | |||
| log.error("The sorted_name have to exist when sorted_type exists.") | |||
| raise LineageParamValueError("The sorted_name have to exist when sorted_type exists.") | |||
| log.error("The sorted_name must exist when sorted_type exists.") | |||
| raise LineageParamValueError("The sorted_name must exist when sorted_type exists.") | |||
| if search_condition.get("sorted_type") not in sorted_type_param: | |||
| err_msg = "The sorted_type must be ascending or descending." | |||
| @@ -395,7 +395,7 @@ class TestModelApi(TestCase): | |||
| } | |||
| self.assertRaisesRegex( | |||
| LineageSearchConditionParamError, | |||
| 'The sorted_name have to exist when sorted_type exists.', | |||
| 'The sorted_name must exist when sorted_type exists.', | |||
| filter_summary_lineage, | |||
| LINEAGE_DATA_MANAGER, | |||
| search_condition | |||
| @@ -17,14 +17,15 @@ from unittest import TestCase | |||
| from mindinsight.lineagemgr.common.exceptions.exceptions import LineageParamTypeError, LineageParamValueError | |||
| from mindinsight.lineagemgr.common.validator.model_parameter import SearchModelConditionParameter | |||
| from mindinsight.lineagemgr.common.validator.validate import validate_search_model_condition | |||
| from mindinsight.utils.exceptions import MindInsightException | |||
| from mindinsight.lineagemgr.common.validator.validate import \ | |||
| validate_search_model_condition, validate_condition, validate_train_id | |||
| from mindinsight.utils.exceptions import MindInsightException, ParamValueError | |||
| class TestValidateSearchModelCondition(TestCase): | |||
| """Test the mothod of validate_search_model_condition.""" | |||
| def test_validate_search_model_condition_param_type_error(self): | |||
| """Test the mothod of validate_search_model_condition with LineageParamTypeError.""" | |||
| """Test the method of validate_search_model_condition with LineageParamTypeError.""" | |||
| condition = { | |||
| 'summary_dir': 'xxx' | |||
| } | |||
| @@ -282,3 +283,65 @@ class TestValidateSearchModelCondition(TestCase): | |||
| condition (dict): The parameter of search condition. | |||
| """ | |||
| self._assert_raise(LineageParamTypeError, msg, condition) | |||
| def test_validate_condition(self): | |||
| """Test the method of validate_condition.""" | |||
| condition = [1, 2, 3] | |||
| self._assert_raise_2(LineageParamTypeError, "Invalid search_condition type, it should be dict.", condition) | |||
| condition = { | |||
| 'limit': False | |||
| } | |||
| self._assert_raise_2(LineageParamTypeError, "The limit must be int.", condition) | |||
| condition = { | |||
| 'offset': False | |||
| } | |||
| self._assert_raise_2(LineageParamTypeError, "The offset must be int.", condition) | |||
| condition = { | |||
| 'sorted_type': 'ascending' | |||
| } | |||
| msg = "The sorted_name must exist when sorted_type exists." | |||
| self._assert_raise_2(LineageParamValueError, msg, condition) | |||
| condition = { | |||
| 'sorted_type': 'invalid', | |||
| 'sorted_name': 'tag' | |||
| } | |||
| msg = "The sorted_type must be ascending or descending." | |||
| self._assert_raise_2(LineageParamValueError, msg, condition) | |||
| def _assert_raise_2(self, exception, msg, condition): | |||
| """ | |||
| Assert raise by unittest. | |||
| Args: | |||
| exception (Type): Exception class expected to be raised. | |||
| msg (msg): Expected error message. | |||
| condition (dict): The parameter of search condition. | |||
| """ | |||
| self.assertRaisesRegex( | |||
| exception, | |||
| msg, | |||
| validate_condition, | |||
| condition | |||
| ) | |||
| def test_validate_train_id(self): | |||
| """Test the test_validate_train_id function.""" | |||
| path = 'invalid' | |||
| self.assertRaisesRegex( | |||
| ParamValueError, | |||
| "Summary dir should be relative path starting with './'.", | |||
| validate_train_id, | |||
| path | |||
| ) | |||
| path = './a/b/c' | |||
| self.assertRaisesRegex( | |||
| ParamValueError, | |||
| "Summary dir should be relative path starting with './'.", | |||
| validate_train_id, | |||
| path | |||
| ) | |||
| @@ -0,0 +1,56 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Test the validate module.""" | |||
| import os | |||
| from unittest import mock | |||
| from marshmallow import ValidationError | |||
| import pytest | |||
| from mindinsight.lineagemgr.common.validator.validate_path import safe_normalize_path, validate_and_normalize_path | |||
| class TestValidatePath: | |||
| """Test the method of validate_path.""" | |||
| @pytest.mark.parametrize('path, check, allow', | |||
| [('', False, False), ('../', False, False), ('invalid', True, False)]) | |||
| def test_validate_and_normalize_path(self, path, check, allow): | |||
| """Test the method of validate_path with ValidationError.""" | |||
| key = 'path' | |||
| path = '' | |||
| with pytest.raises(ValidationError) as info: | |||
| validate_and_normalize_path(path, key, check, allow) | |||
| assert "The path is invalid!" in str(info.value) | |||
| path = '/a/b' | |||
| assert validate_and_normalize_path(path, key, False, True) == os.path.realpath(path) | |||
| @mock.patch('mindinsight.lineagemgr.common.validator.validate_path.validate_and_normalize_path') | |||
| @pytest.mark.parametrize('prefix', [None, ['/']]) | |||
| def test_safe_normalize_path(self, mock_validate_and_normalize_path, prefix): | |||
| """Test the method of safe_normalize_path.""" | |||
| key = 'path' | |||
| path = '/a/b' | |||
| mock_validate_and_normalize_path.return_value = os.path.realpath(path) | |||
| assert safe_normalize_path(path, key, prefix, False, True) == os.path.realpath(path) | |||
| def test_safe_normalize_path_exception(self): | |||
| """Test the method of safe_normalize_path with invalid prefix""" | |||
| key = 'path' | |||
| path = '/a/b' | |||
| prefix = ['invalid'] | |||
| with pytest.raises(ValidationError) as info: | |||
| safe_normalize_path(path, key, prefix, False, True) | |||
| assert "The path is invalid!" in str(info.value) | |||