| @@ -107,6 +107,15 @@ class TrainJobNotExistError(MindInsightException): | |||
| http_code=400) | |||
| class QueryStringContainsNullByteError(MindInsightException): | |||
| """Query string contains null byte error.""" | |||
| def __init__(self, error_detail): | |||
| error_msg = f"Query string contains null byte error. Detail: {error_detail}" | |||
| super(QueryStringContainsNullByteError, self).__init__(DataVisualErrors.QUERY_STRING_CONTAINS_NULL_BYTE, | |||
| error_msg, | |||
| http_code=400) | |||
| class PluginNotAvailableError(MindInsightException): | |||
| """The given plugin is not available.""" | |||
| def __init__(self, error_detail): | |||
| @@ -22,6 +22,7 @@ from pathlib import Path | |||
| from mindinsight.datavisual.common.log import logger | |||
| from mindinsight.datavisual.common.validation import Validation | |||
| from mindinsight.datavisual.utils.tools import Counter | |||
| from mindinsight.datavisual.utils.utils import contains_null_byte | |||
| from mindinsight.datavisual.common.exceptions import MaxCountExceededError | |||
| from mindinsight.utils.exceptions import FileSystemPermissionError | |||
| @@ -61,7 +62,7 @@ class SummaryWatcher: | |||
| >>> summary_watcher = SummaryWatcher() | |||
| >>> directories = summary_watcher.list_summary_directories('/summary/base/dir') | |||
| """ | |||
| if self._contains_null_byte(summary_base_dir=summary_base_dir): | |||
| if contains_null_byte(summary_base_dir=summary_base_dir): | |||
| return [] | |||
| relative_path = os.path.join('.', '') | |||
| @@ -148,25 +149,6 @@ class SummaryWatcher: | |||
| pass | |||
| self._update_summary_dict(summary_dict, summary_base_dir, subdir_relative_path, subdir_entry) | |||
| def _contains_null_byte(self, **kwargs): | |||
| """ | |||
| Check if arg contains null byte. | |||
| Args: | |||
| kwargs (Any): Check if arg contains null byte. | |||
| Returns: | |||
| bool, indicates if any arg contains null byte. | |||
| """ | |||
| for key, value in kwargs.items(): | |||
| if not isinstance(value, str): | |||
| continue | |||
| if '\x00' in value: | |||
| logger.warning('%s contains null byte \\x00.', key) | |||
| return True | |||
| return False | |||
| def _is_valid_summary_directory(self, summary_base_dir, relative_path): | |||
| """ | |||
| Check if the given summary directory is valid. | |||
| @@ -276,7 +258,7 @@ class SummaryWatcher: | |||
| >>> summary_watcher = SummaryWatcher() | |||
| >>> summaries = summary_watcher.is_summary_directory('/summary/base/dir', './job-01') | |||
| """ | |||
| if self._contains_null_byte(summary_base_dir=summary_base_dir, relative_path=relative_path): | |||
| if contains_null_byte(summary_base_dir=summary_base_dir, relative_path=relative_path): | |||
| return False | |||
| if not self._is_valid_summary_directory(summary_base_dir, relative_path): | |||
| @@ -371,7 +353,7 @@ class SummaryWatcher: | |||
| >>> summary_watcher = SummaryWatcher() | |||
| >>> summaries = summary_watcher.list_summaries('/summary/base/dir', './job-01') | |||
| """ | |||
| if self._contains_null_byte(summary_base_dir=summary_base_dir, relative_path=relative_path): | |||
| if contains_null_byte(summary_base_dir=summary_base_dir, relative_path=relative_path): | |||
| return [] | |||
| if not self._is_valid_summary_directory(summary_base_dir, relative_path): | |||
| @@ -19,7 +19,9 @@ from mindinsight.datavisual.common.log import logger | |||
| from mindinsight.datavisual.common import exceptions | |||
| from mindinsight.datavisual.common.enums import PluginNameEnum | |||
| from mindinsight.datavisual.common.enums import CacheStatus | |||
| from mindinsight.datavisual.common.exceptions import QueryStringContainsNullByteError | |||
| from mindinsight.datavisual.common.validation import Validation | |||
| from mindinsight.datavisual.utils.utils import contains_null_byte | |||
| from mindinsight.datavisual.processors.base_processor import BaseProcessor | |||
| from mindinsight.datavisual.data_transform.data_manager import DATAVISUAL_PLUGIN_KEY, DATAVISUAL_CACHE_KEY | |||
| @@ -57,6 +59,8 @@ class TrainTaskManager(BaseProcessor): | |||
| dict, refer to restful api. | |||
| """ | |||
| Validation.check_param_empty(train_id=train_id) | |||
| if contains_null_byte(train_id=train_id): | |||
| raise QueryStringContainsNullByteError("train job id: {} contains null byte.".format(train_id)) | |||
| if manual_update: | |||
| self._data_manager.cache_train_job(train_id) | |||
| @@ -14,6 +14,7 @@ | |||
| # ============================================================================ | |||
| """Utils.""" | |||
| import math | |||
| from mindinsight.datavisual.common.log import logger | |||
| def calc_histogram_bins(count): | |||
| @@ -45,3 +46,23 @@ def calc_histogram_bins(count): | |||
| return math.ceil(count / number_per_bucket) + 1 | |||
| return max_bins | |||
| def contains_null_byte(**kwargs): | |||
| """ | |||
| Check if arg contains null byte. | |||
| Args: | |||
| kwargs (Any): Check if arg contains null byte. | |||
| Returns: | |||
| bool, indicates if any arg contains null byte. | |||
| """ | |||
| for key, value in kwargs.items(): | |||
| if not isinstance(value, str): | |||
| continue | |||
| if '\x00' in value: | |||
| logger.warning('%s contains null byte \\x00.', key) | |||
| return True | |||
| return False | |||
| @@ -70,6 +70,7 @@ class DataVisualErrors(Enum): | |||
| SCALAR_NOT_EXIST = 14 | |||
| HISTOGRAM_NOT_EXIST = 15 | |||
| TRAIN_JOB_DETAIL_NOT_IN_CACHE = 16 | |||
| QUERY_STRING_CONTAINS_NULL_BYTE = 17 | |||
| class ScriptConverterErrors(Enum): | |||
| @@ -79,9 +79,9 @@ class TestPlugins: | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.usefixtures("init_summary_logs") | |||
| @pytest.mark.parametrize("train_id", ["@#$", "./\x00home", "././/not_exist_id", dict()]) | |||
| @pytest.mark.parametrize("train_id", ["@#$", "././/not_exist_id", dict()]) | |||
| def test_plugins_with_special_train_id(self, client, train_id): | |||
| """Test passing train_id with special character, null_byte, invalid id, and wrong type.""" | |||
| """Test passing train_id with special character, invalid id, and wrong type.""" | |||
| params = dict(train_id=train_id) | |||
| url = get_url(BASE_URL, params) | |||
| @@ -92,6 +92,26 @@ class TestPlugins: | |||
| assert response['error_code'] == '50545005' | |||
| assert response['error_msg'] == "Train job is not exist. Detail: Can not find the train job in data manager." | |||
| @pytest.mark.level1 | |||
| @pytest.mark.env_single | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.usefixtures("init_summary_logs") | |||
| @pytest.mark.parametrize("train_id", ["./\x00home"]) | |||
| def test_plugins_with_null_byte_train_id(self, client, train_id): | |||
| """Test passing train_id with null_byte.""" | |||
| params = dict(train_id=train_id, manual_update=True) | |||
| url = get_url(BASE_URL, params) | |||
| response = client.get(url) | |||
| assert response.status_code == 400 | |||
| response = response.get_json() | |||
| assert response['error_code'] == '50545011' | |||
| assert "Query string contains null byte error. " in response['error_msg'] | |||
| @pytest.mark.level1 | |||
| @pytest.mark.env_single | |||
| @pytest.mark.platform_x86_cpu | |||