diff --git a/mindspore/ccsrc/ir/anf_extends.cc b/mindspore/ccsrc/ir/anf_extends.cc index 51b4ed28e6..42ef6b44e4 100644 --- a/mindspore/ccsrc/ir/anf_extends.cc +++ b/mindspore/ccsrc/ir/anf_extends.cc @@ -57,9 +57,6 @@ std::string CNode::fullname_with_scope() { if (IsApply(prim::kPrimScalarSummary) || IsApply(prim::kPrimTensorSummary) || IsApply(prim::kPrimImageSummary) || IsApply(prim::kPrimHistogramSummary)) { std::string tag = GetValue(GetValueNode(input(1))); - if (tag == "") { - MS_LOG(EXCEPTION) << "The tag name is null, should be valid string"; - } std::string name; if (IsApply(prim::kPrimScalarSummary)) { name = tag + "[:Scalar]"; diff --git a/mindspore/ccsrc/operator/prim_debug.cc b/mindspore/ccsrc/operator/prim_debug.cc index d73c34bf85..a9962c6d14 100644 --- a/mindspore/ccsrc/operator/prim_debug.cc +++ b/mindspore/ccsrc/operator/prim_debug.cc @@ -21,64 +21,5 @@ #include "utils/symbolic.h" namespace mindspore { -namespace abstract { -AbstractBasePtr InferImplScalarSummary(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a scalar and a tensor or scalar. - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 2); - - // check the tag - AbstractScalarPtr descriptions = CheckArg(op_name, args_spec_list, 0); - - // check the value: scalar or shape = (1,) - auto scalar_value = dyn_cast(args_spec_list[1]); - if (scalar_value == nullptr) { - auto tensor_value = dyn_cast(args_spec_list[1]); - if (tensor_value == nullptr) { - MS_LOG(EXCEPTION) << "Input must be scalar or shape(1,)"; - } - } else { - auto item_v = scalar_value->BuildValue(); - if (item_v->isa()) { - auto value = item_v->cast()->value(); - if (value.empty()) { - MS_LOG(EXCEPTION) << "Input summary value can't be null"; - } - } - } - - // Reomve the force check to support batch set summary use 'for' loop - auto item_v = descriptions->BuildValue(); - if (!item_v->isa()) { - MS_EXCEPTION(TypeError) << "Summary first parameter should be string"; - } - - return std::make_shared(kAnyValue, kBool); -} - -AbstractBasePtr InferImplTensorSummary(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a scalar(tag) and a tensor(value) - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 2); - - // check the tag - auto descriptions = CheckArg(op_name, args_spec_list, 0); - auto tensor_value = CheckArg(op_name, args_spec_list, 1); - - int tensor_rank = SizeToInt(tensor_value->shape()->shape().size()); - if (tensor_rank == 0) { - MS_LOG(EXCEPTION) << op_name << " summary evaluator second arg should be an tensor, but got a scalar, rank is 0"; - } - - // Reomve the force check to support batch set summary use 'for' loop - auto item_v = descriptions->BuildValue(); - if (!item_v->isa()) { - MS_EXCEPTION(TypeError) << "Summary first parameter should be string"; - } - - return std::make_shared(kAnyValue, std::make_shared()); -} -} // namespace abstract +namespace abstract {} // namespace abstract } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/static_analysis/prim.cc index d37657ce48..21426e1268 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/static_analysis/prim.cc @@ -128,11 +128,6 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { {prim::kPrimDepend, {InferImplDepend, true}}, {prim::kPrimBroadcastGradientArgs, {InferImplBroadcastGradientArgs, false}}, {prim::kPrimControlDepend, {InferImplControlDepend, true}}, - // Debug - {prim::kPrimScalarSummary, {InferImplScalarSummary, true}}, - {prim::kPrimImageSummary, {InferImplTensorSummary, true}}, - {prim::kPrimTensorSummary, {InferImplTensorSummary, true}}, - {prim::kPrimHistogramSummary, {InferImplTensorSummary, true}}, }; return prim_eval_implement_map; } diff --git a/mindspore/ccsrc/pipeline/static_analysis/prim.h b/mindspore/ccsrc/pipeline/static_analysis/prim.h index 3969e76bf3..f72fdd257c 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/prim.h +++ b/mindspore/ccsrc/pipeline/static_analysis/prim.h @@ -326,11 +326,6 @@ AbstractBasePtr InferImplBroadcastGradientArgs(const AnalysisEnginePtr &, const const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); - -AbstractBasePtr InferImplScalarSummary(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplTensorSummary(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); } // namespace abstract } // namespace mindspore diff --git a/mindspore/ops/operations/debug_ops.py b/mindspore/ops/operations/debug_ops.py index 6887c778ed..166ebdc395 100644 --- a/mindspore/ops/operations/debug_ops.py +++ b/mindspore/ops/operations/debug_ops.py @@ -16,10 +16,22 @@ """debug_ops""" from ..._checkparam import Validator as validator from ...common import dtype as mstype -from ..primitive import Primitive, prim_attr_register, PrimitiveWithInfer +from ..primitive import prim_attr_register, PrimitiveWithInfer -class ScalarSummary(Primitive): +def _check_summary_param(name, value, class_name): + """Check the name and value is valid for summary.""" + n_type = name['dtype'] + n_value = name['value'] + validator.check_value_type('name', n_type, [type(mstype.string)], class_name) + if not n_value: + raise ValueError(f"For 'name' the value should by valid string in {class_name}, but got {n_value}.") + + v_type = value['dtype'] + validator.check_value_type('value', v_type, [type(mstype.tensor)], class_name) + + +class ScalarSummary(PrimitiveWithInfer): """ Output scalar to protocol buffer through scalar summary operator. @@ -45,11 +57,19 @@ class ScalarSummary(Primitive): def __init__(self): """init""" - def __call__(self, *args, **kwargs): - pass + def __infer__(self, name, value): + _check_summary_param(name, value, self.__class__.__name__) + v_shape = value['shape'] + # In the summary, the value whose shape is [1] is also considered as a scalar. + if v_shape and v_shape != [1]: + raise ValueError(f"For 'value' the type should be scalar, " + f"shape should be [] or [1] in {self.__class__.__name__}, but got {v_shape}.") -class ImageSummary(Primitive): + return value + + +class ImageSummary(PrimitiveWithInfer): """ Output image tensor to protocol buffer through image summary operator. @@ -73,11 +93,20 @@ class ImageSummary(Primitive): def __init__(self): """init""" - def __call__(self, *args, **kwargs): - pass + def __infer__(self, name, value): + _check_summary_param(name, value, self.__class__.__name__) + + # The shape dim of image should be 4. + v_shape = value['shape'] + image_dim = 4 + if len(v_shape) != image_dim: + raise ValueError(f"For 'value' the dim should be {image_dim} in {self.__class__.__name__}," + f" but got {len(v_shape)}.") + return value -class TensorSummary(Primitive): + +class TensorSummary(PrimitiveWithInfer): """ Output tensor to protocol buffer through tensor summary operator. @@ -103,11 +132,19 @@ class TensorSummary(Primitive): def __init__(self): """init""" - def __call__(self, *args, **kwargs): - pass + def __infer__(self, name, value): + _check_summary_param(name, value, self.__class__.__name__) + + v_shape = value['shape'] + # In the summary, the value whose shape is [] is not considered as a tensor. + if not v_shape: + raise ValueError(f"For 'value' the type should be tensor in {self.__class__.__name__}, " + f"shape should not be [].") + + return value -class HistogramSummary(Primitive): +class HistogramSummary(PrimitiveWithInfer): """ Output tensor to protocol buffer through histogram summary operator. @@ -133,6 +170,17 @@ class HistogramSummary(Primitive): def __init__(self): """init""" + def __infer__(self, name, value): + _check_summary_param(name, value, self.__class__.__name__) + + v_shape = value['shape'] + # In the summary, the histogram value should be a tensor whose shape is not []. + if not v_shape: + raise ValueError(f"For 'value' the type should be tensor in {self.__class__.__name__}, " + f"shape should not be [].") + + return value + class InsertGradientOf(PrimitiveWithInfer): """ diff --git a/tests/st/summary/test_gpu_summary.py b/tests/st/summary/test_gpu_summary.py index 7712b213a2..a1e8ca17d8 100644 --- a/tests/st/summary/test_gpu_summary.py +++ b/tests/st/summary/test_gpu_summary.py @@ -12,11 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ +"""Summary gpu st.""" import os import random +import tempfile import shutil -import pytest + import numpy as np +import pytest import mindspore.context as context import mindspore.nn as nn @@ -26,36 +29,9 @@ from mindspore.train.summary.summary_record import SummaryRecord context.set_context(mode=context.GRAPH_MODE, device_target="GPU") -CUR_DIR = os.getcwd() -SUMMARY_DIR_ME = CUR_DIR + "/test_me_summary_event_file/" -SUMMARY_DIR_ME_TEMP = CUR_DIR + "/test_me_temp_summary_event_file/" - - -def clean_environment_file(srcDir): - if os.path.exists(srcDir): - ls = os.listdir(srcDir) - for line in ls: - filePath = os.path.join(srcDir, line) - os.remove(filePath) - os.removedirs(srcDir) - - -def save_summary_events_file(srcDir, desDir): - if not os.path.exists(desDir): - print("-- create desDir") - os.makedirs(desDir) - - ls = os.listdir(srcDir) - for line in ls: - filePath = os.path.join(srcDir, line) - if os.path.isfile(filePath): - print("-- move events file : {}".format(filePath)) - shutil.copy(filePath, desDir) - os.remove(filePath) - os.removedirs(srcDir) - class SummaryNet(nn.Cell): + """Summary net.""" def __init__(self, tag_tuple=None, scalar=1): super(SummaryNet, self).__init__() self.summary_s = P.ScalarSummary() @@ -66,8 +42,9 @@ class SummaryNet(nn.Cell): self.tag_tuple = tag_tuple self.scalar = scalar - def construct(self, x, y): - self.summary_i("image", x) + def construct(self, x, y, image): + """Run summary net.""" + self.summary_i("image", image) self.summary_s("x1", x) z = self.add(x, y) self.summary_t("z1", z) @@ -75,32 +52,38 @@ class SummaryNet(nn.Cell): return z -def train_summary_record_scalar_for_1(test_writer, steps): +def train_summary_record(test_writer, steps): + """Train and record summary.""" net = SummaryNet() out_me_dict = {} for i in range(0, steps): x = Tensor(np.array([1.1 + random.uniform(1, 10)]).astype(np.float32)) y = Tensor(np.array([1.2 + random.uniform(1, 10)]).astype(np.float32)) - out_put = net(x, y) + image = Tensor(np.array([[[[1.2]]]]).astype(np.float32)) + out_put = net(x, y, image) test_writer.record(i) - print("-----------------output: %s-------------\n", out_put.asnumpy()) out_me_dict[i] = out_put.asnumpy() return out_me_dict -def me_scalar_summary(steps): - with SummaryRecord(SUMMARY_DIR_ME_TEMP) as test_writer: - out_me_dict = train_summary_record_scalar_for_1(test_writer, steps) +class TestGpuSummary: + """Test Gpu summary.""" + summary_dir = tempfile.mkdtemp(suffix='_gpu_summary') - return out_me_dict + def setup_method(self): + """Run before method.""" + if not os.path.exists(self.summary_dir): + os.mkdir(self.summary_dir) + def teardown_emthod(self): + """Run after method.""" + if os.path.exists(self.summary_dir): + shutil.rmtree(self.summary_dir) -@pytest.mark.level0 -@pytest.mark.platform_x86_gpu_training -@pytest.mark.env_onecard -def test_scalarsummary_scalar1_step10_summaryrecord1(): - clean_environment_file(SUMMARY_DIR_ME_TEMP) - output_dict = me_scalar_summary(10) - print("test_scalarsummary_scalar1_step10_summaryrecord1 \n", output_dict) - save_summary_events_file(SUMMARY_DIR_ME_TEMP, SUMMARY_DIR_ME) - clean_environment_file(SUMMARY_DIR_ME) + @pytest.mark.level0 + @pytest.mark.platform_x86_gpu_training + @pytest.mark.env_onecard + def test_summary_step10_summaryrecord1(self): + """Test record 10 step summary.""" + with SummaryRecord(self.summary_dir) as test_writer: + train_summary_record(test_writer, steps=10) diff --git a/tests/ut/python/ops/test_nn_ops.py b/tests/ut/python/ops/test_nn_ops.py index b7905aa12c..8f3fae1d71 100644 --- a/tests/ut/python/ops/test_nn_ops.py +++ b/tests/ut/python/ops/test_nn_ops.py @@ -492,7 +492,7 @@ test_cases = [ }), ('ScalarSummary', { 'block': ScalarSummaryNet(), - 'desc_inputs': [2.2], + 'desc_inputs': [Tensor(2.2)], }), ('L2Normalize', { 'block': L2NormalizeNet(), diff --git a/tests/ut/python/pynative_mode/test_insert_grad_of.py b/tests/ut/python/pynative_mode/test_insert_grad_of.py index 558ede7834..9c17c5dcd0 100644 --- a/tests/ut/python/pynative_mode/test_insert_grad_of.py +++ b/tests/ut/python/pynative_mode/test_insert_grad_of.py @@ -112,7 +112,7 @@ def test_InsertGradientOf_3(): def f(x, y): return C.grad_all(debug_test)(x, y) - print("debug_gradient:", f(1, 2)) + print("debug_gradient:", f(Tensor(1.0), Tensor(2.0))) def test_print_shape_type(): diff --git a/tests/ut/python/train/summary/test_summary_ops_params_valid_check.py b/tests/ut/python/train/summary/test_summary_ops_params_valid_check.py index 56ea9e8d17..4b5180b963 100644 --- a/tests/ut/python/train/summary/test_summary_ops_params_valid_check.py +++ b/tests/ut/python/train/summary/test_summary_ops_params_valid_check.py @@ -12,260 +12,139 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -""" -@File : test_summary.py -@Author: -@Date : 2019-08-5 -@Desc : test summary function of ops params valid check -""" -import logging -import numpy as np +"""Test summary function of ops params valid check.""" import os +import tempfile +import shutil +from enum import Enum + +import numpy as np import pytest -import random import mindspore.nn as nn from mindspore.common.tensor import Tensor from mindspore.ops import operations as P from mindspore.train.summary.summary_record import SummaryRecord -CUR_DIR = os.getcwd() -SUMMARY_DIR = CUR_DIR + "/test_temp_summary_event_file/" - -log = logging.getLogger("test") -log.setLevel(level=logging.ERROR) - - -class SummaryDemoTag(nn.Cell): - """ SummaryDemoTag definition """ - - def __init__(self, tag1, tag2, tag3): - super(SummaryDemoTag, self).__init__() - self.s = P.ScalarSummary() - self.histogram_summary = P.HistogramSummary() - self.add = P.TensorAdd() - self.tag1 = tag1 - self.tag2 = tag2 - self.tag3 = tag3 - - def construct(self, x, y): - self.s(self.tag1, x) - z = self.add(x, y) - self.s(self.tag2, z) - self.s(self.tag3, y) - self.histogram_summary(self.tag1, x) - return z +class SummaryEnum(Enum): + """Summary enum.""" + IMAGE = P.ImageSummary.__name__ + SCALAR = P.ScalarSummary.__name__ + TENSOR = P.TensorSummary.__name__ + HISTOGRAM = P.HistogramSummary.__name__ -class SummaryDemoTagForSet(nn.Cell): - """ SummaryDemoTagForSet definition """ - def __init__(self, tag_tuple): - super(SummaryDemoTagForSet, self).__init__() - self.s = P.ScalarSummary() - self.histogram_summary = P.HistogramSummary() +class SummaryNet(nn.Cell): + """Summary net definition.""" + def __init__(self, summary_type, tag, data): + super(SummaryNet, self).__init__() + self.tag = tag + self.data = data + self.summary_fn = getattr(P, summary_type)() + self.one = Tensor(np.array([1]).astype(np.float32)) self.add = P.TensorAdd() - self.tag_tuple = tag_tuple - - def construct(self, x, y): - z = self.add(x, y) - for tag in self.tag_tuple: - self.s(tag, x) - self.histogram_summary(tag, x) - return z - - -class SummaryDemoValue(nn.Cell): - """ SummaryDemoValue definition """ - def __init__(self, value): - super(SummaryDemoValue, self).__init__() - self.s = P.ScalarSummary() - self.add = P.TensorAdd() - self.v = value - - def construct(self, x, y): - self.s("x", self.v) - z = self.add(x, y) - self.s("z", self.v) - self.s("y", self.v) - return z - - -class SummaryDemoValueForSet(nn.Cell): - """ SummaryDemoValueForSet definition """ - - def __init__(self, value, tag_tuple): - super(SummaryDemoValueForSet, self).__init__() - self.s = P.ScalarSummary() - self.add = P.TensorAdd() - self.tag_tuple = tag_tuple - self.v = value - - def construct(self, x, y): - z = self.add(x, y) - for tag in self.tag_tuple: - self.s(tag, self.v) - return z - - -class HistogramSummaryNet(nn.Cell): - "HistogramSummaryNet definition" - - def __init__(self, value): - self.histogram_summary = P.HistogramSummary() - self.add = P.TensorAdd() - self.value = value + def construct(self): + self.summary_fn(self.tag, self.data) + return self.add(self.one, self.one) - def construct(self, tensors1, tensor2): - self.histogram_summary("value", self.value) - return self.add(tensors1, tensor2) +class TestSummaryOps: + """Test summary operators.""" + summary_dir = '' -def run_case(net): - """ run_case """ - # step 0: create the thread - with SummaryRecord(SUMMARY_DIR) as test_writer: - # step 1: create the network for summary - x = Tensor(np.array([1.1]).astype(np.float32)) - y = Tensor(np.array([1.2]).astype(np.float32)) + @classmethod + def run_case(cls, net): + """ run_case """ net.set_train() - - # step 2: create the Event - steps = 100 - for i in range(1, steps): - x = Tensor(np.array([1.1 + random.uniform(1, 10)]).astype(np.float32)) - y = Tensor(np.array([1.2 + random.uniform(1, 10)]).astype(np.float32)) - net(x, y) - test_writer.record(i) - - -# Test 1: use the repeat tag -def test_summary_use_repeat_tag(): - log.debug("begin test_summary_use_repeat_tag") - net = SummaryDemoTag("x", "x", "x") - try: - run_case(net) - except: - assert False - else: - assert True - log.debug("finished test_summary_use_repeat_tag") - - -# Test 2: repeat tag use for set summary -def test_summary_use_repeat_tag_for_set(): - log.debug("begin test_summary_use_repeat_tag_for_set") - net = SummaryDemoTagForSet(("x", "x", "x")) - try: - run_case(net) - except: - assert False - else: - assert True - log.debug("finished test_summary_use_repeat_tag_for_set") - - -# Test3: test with invalid tag(None, bool, "", int) -def test_summary_use_invalid_tag_None(): - log.debug("begin test_summary_use_invalid_tag_None") - net = SummaryDemoTag(None, None, None) - try: - run_case(net) - except: - assert True - else: - assert False - log.debug("finished test_summary_use_invalid_tag_None") - - -# Test4: test with invalid tag(None, bool, "", int) -def test_summary_use_invalid_tag_Bool(): - log.debug("begin test_summary_use_invalid_tag_Bool") - net = SummaryDemoTag(True, True, True) - with pytest.raises(TypeError): - run_case(net) - log.debug("finished test_summary_use_invalid_tag_Bool") - - -# Test5: test with invalid tag(None, bool, "", int) -def test_summary_use_invalid_tag_null(): - log.debug("begin test_summary_use_invalid_tag_null") - net = SummaryDemoTag("", "", "") - run_case(net) - log.debug("finished test_summary_use_invalid_tag_null") - - -# Test6: test with invalid tag(None, bool, "", int) -def test_summary_use_invalid_tag_Int(): - log.debug("begin test_summary_use_invalid_tag_Int") - net = SummaryDemoTag(1, 2, 3) - with pytest.raises(TypeError): - run_case(net) - log.debug("finished test_summary_use_invalid_tag_Int") - - -# Test7: test with invalid value(None, "") -def test_scalar_summary_use_invalid_value_None(): - log.debug("begin test_scalar_summary_use_invalid_tag_Int") - net = SummaryDemoValue(None) - try: - run_case(net) - except: - assert True - else: - assert False - log.debug("finished test_scalar_summary_use_invalid_tag_Int") - - -# Test8: test with invalid value(None, "") -def test_scalar_summary_use_invalid_value_None_ForSet(): - log.debug("begin test_scalar_summary_use_invalid_value_None_ForSet") - try: - net = SummaryDemoValueForSet(None, ("x1", "x2")) - run_case(net) - except: - assert True - else: - assert False - log.debug("finished test_scalar_summary_use_invalid_value_None_ForSet") - - -# Test9: test with invalid value(None, "") -def test_scalar_summary_use_invalid_value_null(): - log.debug("begin test_scalar_summary_use_invalid_value_null") - try: - net = SummaryDemoValue("") - run_case(net) - except: - assert True - else: - assert False - log.debug("finished test_scalar_summary_use_invalid_value_null") - - -def test_histogram_summary_use_valid_value(): - """Test histogram summary with valid value""" - log.debug("Begin test_histogram_summary_use_valid_value") - try: - net = HistogramSummaryNet(Tensor(np.array([1, 2, 3]))) - run_case(net) - except: - assert True - else: - assert False - log.debug("Finished test_histogram_summary_use_valid_value") - - -def test_histogram_summary_use_scalar_value(): - """Test histogram summary use scalar value""" - log.debug("Begin test_histogram_summary_use_scalar_value") - try: - scalar = Tensor(1) - net = HistogramSummaryNet(scalar) - run_case(net) - except: - assert True - else: - assert False - log.debug("Finished test_histogram_summary_use_scalar_value") + steps = 10 + with SummaryRecord(cls.summary_dir) as test_writer: + for i in range(1, steps): + net() + test_writer.record(i) + + @classmethod + def setup_class(cls): + """Run before class.""" + if not os.path.exists(cls.summary_dir): + cls.summary_dir = tempfile.mkdtemp(suffix='_summary') + + @classmethod + def teardown_class(cls): + """Run after class.""" + if os.path.exists(cls.summary_dir): + shutil.rmtree(cls.summary_dir) + + @pytest.mark.parametrize( + "summary_type, value", + [ + (SummaryEnum.SCALAR.value, Tensor(1)), + (SummaryEnum.SCALAR.value, Tensor(np.array([1]))), + (SummaryEnum.IMAGE.value, Tensor(np.array([[[[1], [2], [3], [4]]]]))), + (SummaryEnum.TENSOR.value, Tensor(np.array([[1], [2], [3], [4]]))), + (SummaryEnum.HISTOGRAM.value, Tensor(np.array([[1], [2], [3], [4]]))), + ]) + def test_summary_success(self, summary_type, value): + """Test summary success with valid tag and valid data.""" + net = SummaryNet(summary_type, tag='tag', data=value) + TestSummaryOps.run_case(net) + + @pytest.mark.parametrize( + "summary_type", + [ + SummaryEnum.SCALAR.value, + SummaryEnum.IMAGE.value, + SummaryEnum.HISTOGRAM.value, + SummaryEnum.TENSOR.value + ]) + def test_summary_tag_is_none(self, summary_type): + """Test summary tag is None, all summary operator validation rules are consistent.""" + net = SummaryNet(summary_type, tag=None, data=Tensor(0)) + with pytest.raises(TypeError): + TestSummaryOps.run_case(net) + + + @pytest.mark.parametrize( + "summary_type", + [ + SummaryEnum.SCALAR.value, + SummaryEnum.IMAGE.value, + SummaryEnum.HISTOGRAM.value, + SummaryEnum.TENSOR.value + ]) + def test_summary_tag_is_empty_string(self, summary_type): + """Test summary tag is a empty string, all summary operator validation rules are consistent.""" + net = SummaryNet(summary_type, tag='', data=Tensor(0)) + with pytest.raises(ValueError): + TestSummaryOps.run_case(net) + + @pytest.mark.parametrize("tag", [123, True, Tensor(0)]) + def test_summary_tag_is_not_string(self, tag): + """Test summary tag is not a string, all summary operator validation rules are consistent.""" + # All summary operator validation rules are consistent, so we only test scalar summary. + net = SummaryNet(SummaryEnum.SCALAR.value, tag=tag, data=Tensor(0)) + with pytest.raises(TypeError): + TestSummaryOps.run_case(net) + + @pytest.mark.parametrize("value", [123, True, 'data']) + def test_summary_value_type_invalid(self, value): + """Test the type of summary value is invalid, all summary operator validation rules are consistent.""" + # All summary operator validation rules are consistent, so we only test scalar summary. + net = SummaryNet(SummaryEnum.SCALAR.value, tag='tag', data=value) + with pytest.raises(TypeError): + TestSummaryOps.run_case(net) + + @pytest.mark.parametrize( + "summary_type, value", + [ + (SummaryEnum.IMAGE.value, Tensor(np.array([1, 2]))), + (SummaryEnum.SCALAR.value, Tensor(np.array([1, 2]))), + (SummaryEnum.TENSOR.value, Tensor(0)), + (SummaryEnum.HISTOGRAM.value, Tensor(0)) + ]) + def test_value_shape_invalid(self, summary_type, value): + """Test invalid shape of every summary operators.""" + net = SummaryNet(summary_type, tag='tag', data=value) + with pytest.raises(ValueError): + TestSummaryOps.run_case(net)