Browse Source

add null byte check in the api of get_plugins

tags/v0.5.0-beta
wangshuide2020 5 years ago
parent
commit
0fad2218fd
6 changed files with 61 additions and 24 deletions
  1. +9
    -0
      mindinsight/datavisual/common/exceptions.py
  2. +4
    -22
      mindinsight/datavisual/data_transform/summary_watcher.py
  3. +4
    -0
      mindinsight/datavisual/processors/train_task_manager.py
  4. +21
    -0
      mindinsight/datavisual/utils/utils.py
  5. +1
    -0
      mindinsight/utils/constant.py
  6. +22
    -2
      tests/st/func/datavisual/taskmanager/test_plugins_restful_api.py

+ 9
- 0
mindinsight/datavisual/common/exceptions.py View File

@@ -107,6 +107,15 @@ class TrainJobNotExistError(MindInsightException):
http_code=400)


class QueryStringContainsNullByteError(MindInsightException):
"""Query string contains null byte error."""
def __init__(self, error_detail):
error_msg = f"Query string contains null byte error. Detail: {error_detail}"
super(QueryStringContainsNullByteError, self).__init__(DataVisualErrors.QUERY_STRING_CONTAINS_NULL_BYTE,
error_msg,
http_code=400)


class PluginNotAvailableError(MindInsightException):
"""The given plugin is not available."""
def __init__(self, error_detail):


+ 4
- 22
mindinsight/datavisual/data_transform/summary_watcher.py View File

@@ -22,6 +22,7 @@ from pathlib import Path
from mindinsight.datavisual.common.log import logger
from mindinsight.datavisual.common.validation import Validation
from mindinsight.datavisual.utils.tools import Counter
from mindinsight.datavisual.utils.utils import contains_null_byte
from mindinsight.datavisual.common.exceptions import MaxCountExceededError
from mindinsight.utils.exceptions import FileSystemPermissionError
@@ -61,7 +62,7 @@ class SummaryWatcher:
>>> summary_watcher = SummaryWatcher()
>>> directories = summary_watcher.list_summary_directories('/summary/base/dir')
"""
if self._contains_null_byte(summary_base_dir=summary_base_dir):
if contains_null_byte(summary_base_dir=summary_base_dir):
return []
relative_path = os.path.join('.', '')
@@ -148,25 +149,6 @@ class SummaryWatcher:
pass
self._update_summary_dict(summary_dict, summary_base_dir, subdir_relative_path, subdir_entry)
def _contains_null_byte(self, **kwargs):
"""
Check if arg contains null byte.
Args:
kwargs (Any): Check if arg contains null byte.
Returns:
bool, indicates if any arg contains null byte.
"""
for key, value in kwargs.items():
if not isinstance(value, str):
continue
if '\x00' in value:
logger.warning('%s contains null byte \\x00.', key)
return True
return False
def _is_valid_summary_directory(self, summary_base_dir, relative_path):
"""
Check if the given summary directory is valid.
@@ -276,7 +258,7 @@ class SummaryWatcher:
>>> summary_watcher = SummaryWatcher()
>>> summaries = summary_watcher.is_summary_directory('/summary/base/dir', './job-01')
"""
if self._contains_null_byte(summary_base_dir=summary_base_dir, relative_path=relative_path):
if contains_null_byte(summary_base_dir=summary_base_dir, relative_path=relative_path):
return False
if not self._is_valid_summary_directory(summary_base_dir, relative_path):
@@ -371,7 +353,7 @@ class SummaryWatcher:
>>> summary_watcher = SummaryWatcher()
>>> summaries = summary_watcher.list_summaries('/summary/base/dir', './job-01')
"""
if self._contains_null_byte(summary_base_dir=summary_base_dir, relative_path=relative_path):
if contains_null_byte(summary_base_dir=summary_base_dir, relative_path=relative_path):
return []
if not self._is_valid_summary_directory(summary_base_dir, relative_path):


+ 4
- 0
mindinsight/datavisual/processors/train_task_manager.py View File

@@ -19,7 +19,9 @@ from mindinsight.datavisual.common.log import logger
from mindinsight.datavisual.common import exceptions
from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.datavisual.common.enums import CacheStatus
from mindinsight.datavisual.common.exceptions import QueryStringContainsNullByteError
from mindinsight.datavisual.common.validation import Validation
from mindinsight.datavisual.utils.utils import contains_null_byte
from mindinsight.datavisual.processors.base_processor import BaseProcessor
from mindinsight.datavisual.data_transform.data_manager import DATAVISUAL_PLUGIN_KEY, DATAVISUAL_CACHE_KEY

@@ -57,6 +59,8 @@ class TrainTaskManager(BaseProcessor):
dict, refer to restful api.
"""
Validation.check_param_empty(train_id=train_id)
if contains_null_byte(train_id=train_id):
raise QueryStringContainsNullByteError("train job id: {} contains null byte.".format(train_id))

if manual_update:
self._data_manager.cache_train_job(train_id)


+ 21
- 0
mindinsight/datavisual/utils/utils.py View File

@@ -14,6 +14,7 @@
# ============================================================================
"""Utils."""
import math
from mindinsight.datavisual.common.log import logger


def calc_histogram_bins(count):
@@ -45,3 +46,23 @@ def calc_histogram_bins(count):
return math.ceil(count / number_per_bucket) + 1

return max_bins


def contains_null_byte(**kwargs):
"""
Check if arg contains null byte.

Args:
kwargs (Any): Check if arg contains null byte.

Returns:
bool, indicates if any arg contains null byte.
"""
for key, value in kwargs.items():
if not isinstance(value, str):
continue
if '\x00' in value:
logger.warning('%s contains null byte \\x00.', key)
return True

return False

+ 1
- 0
mindinsight/utils/constant.py View File

@@ -70,6 +70,7 @@ class DataVisualErrors(Enum):
SCALAR_NOT_EXIST = 14
HISTOGRAM_NOT_EXIST = 15
TRAIN_JOB_DETAIL_NOT_IN_CACHE = 16
QUERY_STRING_CONTAINS_NULL_BYTE = 17


class ScriptConverterErrors(Enum):


+ 22
- 2
tests/st/func/datavisual/taskmanager/test_plugins_restful_api.py View File

@@ -79,9 +79,9 @@ class TestPlugins:
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.usefixtures("init_summary_logs")
@pytest.mark.parametrize("train_id", ["@#$", "./\x00home", "././/not_exist_id", dict()])
@pytest.mark.parametrize("train_id", ["@#$", "././/not_exist_id", dict()])
def test_plugins_with_special_train_id(self, client, train_id):
"""Test passing train_id with special character, null_byte, invalid id, and wrong type."""
"""Test passing train_id with special character, invalid id, and wrong type."""
params = dict(train_id=train_id)
url = get_url(BASE_URL, params)

@@ -92,6 +92,26 @@ class TestPlugins:
assert response['error_code'] == '50545005'
assert response['error_msg'] == "Train job is not exist. Detail: Can not find the train job in data manager."

@pytest.mark.level1
@pytest.mark.env_single
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.usefixtures("init_summary_logs")
@pytest.mark.parametrize("train_id", ["./\x00home"])
def test_plugins_with_null_byte_train_id(self, client, train_id):
"""Test passing train_id with null_byte."""
params = dict(train_id=train_id, manual_update=True)
url = get_url(BASE_URL, params)

response = client.get(url)
assert response.status_code == 400

response = response.get_json()
assert response['error_code'] == '50545011'
assert "Query string contains null byte error. " in response['error_msg']

@pytest.mark.level1
@pytest.mark.env_single
@pytest.mark.platform_x86_cpu


Loading…
Cancel
Save