Merge pull request !1329 from ougongchang/mastertags/v0.3.0-alpha
| @@ -57,9 +57,6 @@ std::string CNode::fullname_with_scope() { | |||
| if (IsApply(prim::kPrimScalarSummary) || IsApply(prim::kPrimTensorSummary) || IsApply(prim::kPrimImageSummary) || | |||
| IsApply(prim::kPrimHistogramSummary)) { | |||
| std::string tag = GetValue<std::string>(GetValueNode(input(1))); | |||
| if (tag == "") { | |||
| MS_LOG(EXCEPTION) << "The tag name is null, should be valid string"; | |||
| } | |||
| std::string name; | |||
| if (IsApply(prim::kPrimScalarSummary)) { | |||
| name = tag + "[:Scalar]"; | |||
| @@ -21,64 +21,5 @@ | |||
| #include "utils/symbolic.h" | |||
| namespace mindspore { | |||
| namespace abstract { | |||
| AbstractBasePtr InferImplScalarSummary(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a scalar and a tensor or scalar. | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 2); | |||
| // check the tag | |||
| AbstractScalarPtr descriptions = CheckArg<AbstractScalar>(op_name, args_spec_list, 0); | |||
| // check the value: scalar or shape = (1,) | |||
| auto scalar_value = dyn_cast<AbstractScalar>(args_spec_list[1]); | |||
| if (scalar_value == nullptr) { | |||
| auto tensor_value = dyn_cast<AbstractTensor>(args_spec_list[1]); | |||
| if (tensor_value == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Input must be scalar or shape(1,)"; | |||
| } | |||
| } else { | |||
| auto item_v = scalar_value->BuildValue(); | |||
| if (item_v->isa<StringImm>()) { | |||
| auto value = item_v->cast<StringImmPtr>()->value(); | |||
| if (value.empty()) { | |||
| MS_LOG(EXCEPTION) << "Input summary value can't be null"; | |||
| } | |||
| } | |||
| } | |||
| // Reomve the force check to support batch set summary use 'for' loop | |||
| auto item_v = descriptions->BuildValue(); | |||
| if (!item_v->isa<StringImm>()) { | |||
| MS_EXCEPTION(TypeError) << "Summary first parameter should be string"; | |||
| } | |||
| return std::make_shared<AbstractScalar>(kAnyValue, kBool); | |||
| } | |||
| AbstractBasePtr InferImplTensorSummary(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a scalar(tag) and a tensor(value) | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 2); | |||
| // check the tag | |||
| auto descriptions = CheckArg<AbstractScalar>(op_name, args_spec_list, 0); | |||
| auto tensor_value = CheckArg<AbstractTensor>(op_name, args_spec_list, 1); | |||
| int tensor_rank = SizeToInt(tensor_value->shape()->shape().size()); | |||
| if (tensor_rank == 0) { | |||
| MS_LOG(EXCEPTION) << op_name << " summary evaluator second arg should be an tensor, but got a scalar, rank is 0"; | |||
| } | |||
| // Reomve the force check to support batch set summary use 'for' loop | |||
| auto item_v = descriptions->BuildValue(); | |||
| if (!item_v->isa<StringImm>()) { | |||
| MS_EXCEPTION(TypeError) << "Summary first parameter should be string"; | |||
| } | |||
| return std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<Bool>()); | |||
| } | |||
| } // namespace abstract | |||
| namespace abstract {} // namespace abstract | |||
| } // namespace mindspore | |||
| @@ -128,11 +128,6 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||
| {prim::kPrimDepend, {InferImplDepend, true}}, | |||
| {prim::kPrimBroadcastGradientArgs, {InferImplBroadcastGradientArgs, false}}, | |||
| {prim::kPrimControlDepend, {InferImplControlDepend, true}}, | |||
| // Debug | |||
| {prim::kPrimScalarSummary, {InferImplScalarSummary, true}}, | |||
| {prim::kPrimImageSummary, {InferImplTensorSummary, true}}, | |||
| {prim::kPrimTensorSummary, {InferImplTensorSummary, true}}, | |||
| {prim::kPrimHistogramSummary, {InferImplTensorSummary, true}}, | |||
| }; | |||
| return prim_eval_implement_map; | |||
| } | |||
| @@ -326,11 +326,6 @@ AbstractBasePtr InferImplBroadcastGradientArgs(const AnalysisEnginePtr &, const | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplScalarSummary(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplTensorSummary(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| } // namespace abstract | |||
| } // namespace mindspore | |||
| @@ -16,10 +16,22 @@ | |||
| """debug_ops""" | |||
| from ..._checkparam import Validator as validator | |||
| from ...common import dtype as mstype | |||
| from ..primitive import Primitive, prim_attr_register, PrimitiveWithInfer | |||
| from ..primitive import prim_attr_register, PrimitiveWithInfer | |||
| class ScalarSummary(Primitive): | |||
| def _check_summary_param(name, value, class_name): | |||
| """Check the name and value is valid for summary.""" | |||
| n_type = name['dtype'] | |||
| n_value = name['value'] | |||
| validator.check_value_type('name', n_type, [type(mstype.string)], class_name) | |||
| if not n_value: | |||
| raise ValueError(f"For 'name' the value should by valid string in {class_name}, but got {n_value}.") | |||
| v_type = value['dtype'] | |||
| validator.check_value_type('value', v_type, [type(mstype.tensor)], class_name) | |||
| class ScalarSummary(PrimitiveWithInfer): | |||
| """ | |||
| Output scalar to protocol buffer through scalar summary operator. | |||
| @@ -45,11 +57,19 @@ class ScalarSummary(Primitive): | |||
| def __init__(self): | |||
| """init""" | |||
| def __call__(self, *args, **kwargs): | |||
| pass | |||
| def __infer__(self, name, value): | |||
| _check_summary_param(name, value, self.__class__.__name__) | |||
| v_shape = value['shape'] | |||
| # In the summary, the value whose shape is [1] is also considered as a scalar. | |||
| if v_shape and v_shape != [1]: | |||
| raise ValueError(f"For 'value' the type should be scalar, " | |||
| f"shape should be [] or [1] in {self.__class__.__name__}, but got {v_shape}.") | |||
| class ImageSummary(Primitive): | |||
| return value | |||
| class ImageSummary(PrimitiveWithInfer): | |||
| """ | |||
| Output image tensor to protocol buffer through image summary operator. | |||
| @@ -73,11 +93,20 @@ class ImageSummary(Primitive): | |||
| def __init__(self): | |||
| """init""" | |||
| def __call__(self, *args, **kwargs): | |||
| pass | |||
| def __infer__(self, name, value): | |||
| _check_summary_param(name, value, self.__class__.__name__) | |||
| # The shape dim of image should be 4. | |||
| v_shape = value['shape'] | |||
| image_dim = 4 | |||
| if len(v_shape) != image_dim: | |||
| raise ValueError(f"For 'value' the dim should be {image_dim} in {self.__class__.__name__}," | |||
| f" but got {len(v_shape)}.") | |||
| return value | |||
| class TensorSummary(Primitive): | |||
| class TensorSummary(PrimitiveWithInfer): | |||
| """ | |||
| Output tensor to protocol buffer through tensor summary operator. | |||
| @@ -103,11 +132,19 @@ class TensorSummary(Primitive): | |||
| def __init__(self): | |||
| """init""" | |||
| def __call__(self, *args, **kwargs): | |||
| pass | |||
| def __infer__(self, name, value): | |||
| _check_summary_param(name, value, self.__class__.__name__) | |||
| v_shape = value['shape'] | |||
| # In the summary, the value whose shape is [] is not considered as a tensor. | |||
| if not v_shape: | |||
| raise ValueError(f"For 'value' the type should be tensor in {self.__class__.__name__}, " | |||
| f"shape should not be [].") | |||
| return value | |||
| class HistogramSummary(Primitive): | |||
| class HistogramSummary(PrimitiveWithInfer): | |||
| """ | |||
| Output tensor to protocol buffer through histogram summary operator. | |||
| @@ -133,6 +170,17 @@ class HistogramSummary(Primitive): | |||
| def __init__(self): | |||
| """init""" | |||
| def __infer__(self, name, value): | |||
| _check_summary_param(name, value, self.__class__.__name__) | |||
| v_shape = value['shape'] | |||
| # In the summary, the histogram value should be a tensor whose shape is not []. | |||
| if not v_shape: | |||
| raise ValueError(f"For 'value' the type should be tensor in {self.__class__.__name__}, " | |||
| f"shape should not be [].") | |||
| return value | |||
| class InsertGradientOf(PrimitiveWithInfer): | |||
| """ | |||
| @@ -12,11 +12,14 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Summary gpu st.""" | |||
| import os | |||
| import random | |||
| import tempfile | |||
| import shutil | |||
| import pytest | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| @@ -26,36 +29,9 @@ from mindspore.train.summary.summary_record import SummaryRecord | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| CUR_DIR = os.getcwd() | |||
| SUMMARY_DIR_ME = CUR_DIR + "/test_me_summary_event_file/" | |||
| SUMMARY_DIR_ME_TEMP = CUR_DIR + "/test_me_temp_summary_event_file/" | |||
| def clean_environment_file(srcDir): | |||
| if os.path.exists(srcDir): | |||
| ls = os.listdir(srcDir) | |||
| for line in ls: | |||
| filePath = os.path.join(srcDir, line) | |||
| os.remove(filePath) | |||
| os.removedirs(srcDir) | |||
| def save_summary_events_file(srcDir, desDir): | |||
| if not os.path.exists(desDir): | |||
| print("-- create desDir") | |||
| os.makedirs(desDir) | |||
| ls = os.listdir(srcDir) | |||
| for line in ls: | |||
| filePath = os.path.join(srcDir, line) | |||
| if os.path.isfile(filePath): | |||
| print("-- move events file : {}".format(filePath)) | |||
| shutil.copy(filePath, desDir) | |||
| os.remove(filePath) | |||
| os.removedirs(srcDir) | |||
| class SummaryNet(nn.Cell): | |||
| """Summary net.""" | |||
| def __init__(self, tag_tuple=None, scalar=1): | |||
| super(SummaryNet, self).__init__() | |||
| self.summary_s = P.ScalarSummary() | |||
| @@ -66,8 +42,9 @@ class SummaryNet(nn.Cell): | |||
| self.tag_tuple = tag_tuple | |||
| self.scalar = scalar | |||
| def construct(self, x, y): | |||
| self.summary_i("image", x) | |||
| def construct(self, x, y, image): | |||
| """Run summary net.""" | |||
| self.summary_i("image", image) | |||
| self.summary_s("x1", x) | |||
| z = self.add(x, y) | |||
| self.summary_t("z1", z) | |||
| @@ -75,32 +52,38 @@ class SummaryNet(nn.Cell): | |||
| return z | |||
| def train_summary_record_scalar_for_1(test_writer, steps): | |||
| def train_summary_record(test_writer, steps): | |||
| """Train and record summary.""" | |||
| net = SummaryNet() | |||
| out_me_dict = {} | |||
| for i in range(0, steps): | |||
| x = Tensor(np.array([1.1 + random.uniform(1, 10)]).astype(np.float32)) | |||
| y = Tensor(np.array([1.2 + random.uniform(1, 10)]).astype(np.float32)) | |||
| out_put = net(x, y) | |||
| image = Tensor(np.array([[[[1.2]]]]).astype(np.float32)) | |||
| out_put = net(x, y, image) | |||
| test_writer.record(i) | |||
| print("-----------------output: %s-------------\n", out_put.asnumpy()) | |||
| out_me_dict[i] = out_put.asnumpy() | |||
| return out_me_dict | |||
| def me_scalar_summary(steps): | |||
| with SummaryRecord(SUMMARY_DIR_ME_TEMP) as test_writer: | |||
| out_me_dict = train_summary_record_scalar_for_1(test_writer, steps) | |||
| class TestGpuSummary: | |||
| """Test Gpu summary.""" | |||
| summary_dir = tempfile.mkdtemp(suffix='_gpu_summary') | |||
| return out_me_dict | |||
| def setup_method(self): | |||
| """Run before method.""" | |||
| if not os.path.exists(self.summary_dir): | |||
| os.mkdir(self.summary_dir) | |||
| def teardown_emthod(self): | |||
| """Run after method.""" | |||
| if os.path.exists(self.summary_dir): | |||
| shutil.rmtree(self.summary_dir) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_scalarsummary_scalar1_step10_summaryrecord1(): | |||
| clean_environment_file(SUMMARY_DIR_ME_TEMP) | |||
| output_dict = me_scalar_summary(10) | |||
| print("test_scalarsummary_scalar1_step10_summaryrecord1 \n", output_dict) | |||
| save_summary_events_file(SUMMARY_DIR_ME_TEMP, SUMMARY_DIR_ME) | |||
| clean_environment_file(SUMMARY_DIR_ME) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_summary_step10_summaryrecord1(self): | |||
| """Test record 10 step summary.""" | |||
| with SummaryRecord(self.summary_dir) as test_writer: | |||
| train_summary_record(test_writer, steps=10) | |||
| @@ -492,7 +492,7 @@ test_cases = [ | |||
| }), | |||
| ('ScalarSummary', { | |||
| 'block': ScalarSummaryNet(), | |||
| 'desc_inputs': [2.2], | |||
| 'desc_inputs': [Tensor(2.2)], | |||
| }), | |||
| ('L2Normalize', { | |||
| 'block': L2NormalizeNet(), | |||
| @@ -112,7 +112,7 @@ def test_InsertGradientOf_3(): | |||
| def f(x, y): | |||
| return C.grad_all(debug_test)(x, y) | |||
| print("debug_gradient:", f(1, 2)) | |||
| print("debug_gradient:", f(Tensor(1.0), Tensor(2.0))) | |||
| def test_print_shape_type(): | |||
| @@ -12,260 +12,139 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| @File : test_summary.py | |||
| @Author: | |||
| @Date : 2019-08-5 | |||
| @Desc : test summary function of ops params valid check | |||
| """ | |||
| import logging | |||
| import numpy as np | |||
| """Test summary function of ops params valid check.""" | |||
| import os | |||
| import tempfile | |||
| import shutil | |||
| from enum import Enum | |||
| import numpy as np | |||
| import pytest | |||
| import random | |||
| import mindspore.nn as nn | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore.ops import operations as P | |||
| from mindspore.train.summary.summary_record import SummaryRecord | |||
| CUR_DIR = os.getcwd() | |||
| SUMMARY_DIR = CUR_DIR + "/test_temp_summary_event_file/" | |||
| log = logging.getLogger("test") | |||
| log.setLevel(level=logging.ERROR) | |||
| class SummaryDemoTag(nn.Cell): | |||
| """ SummaryDemoTag definition """ | |||
| def __init__(self, tag1, tag2, tag3): | |||
| super(SummaryDemoTag, self).__init__() | |||
| self.s = P.ScalarSummary() | |||
| self.histogram_summary = P.HistogramSummary() | |||
| self.add = P.TensorAdd() | |||
| self.tag1 = tag1 | |||
| self.tag2 = tag2 | |||
| self.tag3 = tag3 | |||
| def construct(self, x, y): | |||
| self.s(self.tag1, x) | |||
| z = self.add(x, y) | |||
| self.s(self.tag2, z) | |||
| self.s(self.tag3, y) | |||
| self.histogram_summary(self.tag1, x) | |||
| return z | |||
| class SummaryEnum(Enum): | |||
| """Summary enum.""" | |||
| IMAGE = P.ImageSummary.__name__ | |||
| SCALAR = P.ScalarSummary.__name__ | |||
| TENSOR = P.TensorSummary.__name__ | |||
| HISTOGRAM = P.HistogramSummary.__name__ | |||
| class SummaryDemoTagForSet(nn.Cell): | |||
| """ SummaryDemoTagForSet definition """ | |||
| def __init__(self, tag_tuple): | |||
| super(SummaryDemoTagForSet, self).__init__() | |||
| self.s = P.ScalarSummary() | |||
| self.histogram_summary = P.HistogramSummary() | |||
| class SummaryNet(nn.Cell): | |||
| """Summary net definition.""" | |||
| def __init__(self, summary_type, tag, data): | |||
| super(SummaryNet, self).__init__() | |||
| self.tag = tag | |||
| self.data = data | |||
| self.summary_fn = getattr(P, summary_type)() | |||
| self.one = Tensor(np.array([1]).astype(np.float32)) | |||
| self.add = P.TensorAdd() | |||
| self.tag_tuple = tag_tuple | |||
| def construct(self, x, y): | |||
| z = self.add(x, y) | |||
| for tag in self.tag_tuple: | |||
| self.s(tag, x) | |||
| self.histogram_summary(tag, x) | |||
| return z | |||
| class SummaryDemoValue(nn.Cell): | |||
| """ SummaryDemoValue definition """ | |||
| def __init__(self, value): | |||
| super(SummaryDemoValue, self).__init__() | |||
| self.s = P.ScalarSummary() | |||
| self.add = P.TensorAdd() | |||
| self.v = value | |||
| def construct(self, x, y): | |||
| self.s("x", self.v) | |||
| z = self.add(x, y) | |||
| self.s("z", self.v) | |||
| self.s("y", self.v) | |||
| return z | |||
| class SummaryDemoValueForSet(nn.Cell): | |||
| """ SummaryDemoValueForSet definition """ | |||
| def __init__(self, value, tag_tuple): | |||
| super(SummaryDemoValueForSet, self).__init__() | |||
| self.s = P.ScalarSummary() | |||
| self.add = P.TensorAdd() | |||
| self.tag_tuple = tag_tuple | |||
| self.v = value | |||
| def construct(self, x, y): | |||
| z = self.add(x, y) | |||
| for tag in self.tag_tuple: | |||
| self.s(tag, self.v) | |||
| return z | |||
| class HistogramSummaryNet(nn.Cell): | |||
| "HistogramSummaryNet definition" | |||
| def __init__(self, value): | |||
| self.histogram_summary = P.HistogramSummary() | |||
| self.add = P.TensorAdd() | |||
| self.value = value | |||
| def construct(self): | |||
| self.summary_fn(self.tag, self.data) | |||
| return self.add(self.one, self.one) | |||
| def construct(self, tensors1, tensor2): | |||
| self.histogram_summary("value", self.value) | |||
| return self.add(tensors1, tensor2) | |||
| class TestSummaryOps: | |||
| """Test summary operators.""" | |||
| summary_dir = '' | |||
| def run_case(net): | |||
| """ run_case """ | |||
| # step 0: create the thread | |||
| with SummaryRecord(SUMMARY_DIR) as test_writer: | |||
| # step 1: create the network for summary | |||
| x = Tensor(np.array([1.1]).astype(np.float32)) | |||
| y = Tensor(np.array([1.2]).astype(np.float32)) | |||
| @classmethod | |||
| def run_case(cls, net): | |||
| """ run_case """ | |||
| net.set_train() | |||
| # step 2: create the Event | |||
| steps = 100 | |||
| for i in range(1, steps): | |||
| x = Tensor(np.array([1.1 + random.uniform(1, 10)]).astype(np.float32)) | |||
| y = Tensor(np.array([1.2 + random.uniform(1, 10)]).astype(np.float32)) | |||
| net(x, y) | |||
| test_writer.record(i) | |||
| # Test 1: use the repeat tag | |||
| def test_summary_use_repeat_tag(): | |||
| log.debug("begin test_summary_use_repeat_tag") | |||
| net = SummaryDemoTag("x", "x", "x") | |||
| try: | |||
| run_case(net) | |||
| except: | |||
| assert False | |||
| else: | |||
| assert True | |||
| log.debug("finished test_summary_use_repeat_tag") | |||
| # Test 2: repeat tag use for set summary | |||
| def test_summary_use_repeat_tag_for_set(): | |||
| log.debug("begin test_summary_use_repeat_tag_for_set") | |||
| net = SummaryDemoTagForSet(("x", "x", "x")) | |||
| try: | |||
| run_case(net) | |||
| except: | |||
| assert False | |||
| else: | |||
| assert True | |||
| log.debug("finished test_summary_use_repeat_tag_for_set") | |||
| # Test3: test with invalid tag(None, bool, "", int) | |||
| def test_summary_use_invalid_tag_None(): | |||
| log.debug("begin test_summary_use_invalid_tag_None") | |||
| net = SummaryDemoTag(None, None, None) | |||
| try: | |||
| run_case(net) | |||
| except: | |||
| assert True | |||
| else: | |||
| assert False | |||
| log.debug("finished test_summary_use_invalid_tag_None") | |||
| # Test4: test with invalid tag(None, bool, "", int) | |||
| def test_summary_use_invalid_tag_Bool(): | |||
| log.debug("begin test_summary_use_invalid_tag_Bool") | |||
| net = SummaryDemoTag(True, True, True) | |||
| with pytest.raises(TypeError): | |||
| run_case(net) | |||
| log.debug("finished test_summary_use_invalid_tag_Bool") | |||
| # Test5: test with invalid tag(None, bool, "", int) | |||
| def test_summary_use_invalid_tag_null(): | |||
| log.debug("begin test_summary_use_invalid_tag_null") | |||
| net = SummaryDemoTag("", "", "") | |||
| run_case(net) | |||
| log.debug("finished test_summary_use_invalid_tag_null") | |||
| # Test6: test with invalid tag(None, bool, "", int) | |||
| def test_summary_use_invalid_tag_Int(): | |||
| log.debug("begin test_summary_use_invalid_tag_Int") | |||
| net = SummaryDemoTag(1, 2, 3) | |||
| with pytest.raises(TypeError): | |||
| run_case(net) | |||
| log.debug("finished test_summary_use_invalid_tag_Int") | |||
| # Test7: test with invalid value(None, "") | |||
| def test_scalar_summary_use_invalid_value_None(): | |||
| log.debug("begin test_scalar_summary_use_invalid_tag_Int") | |||
| net = SummaryDemoValue(None) | |||
| try: | |||
| run_case(net) | |||
| except: | |||
| assert True | |||
| else: | |||
| assert False | |||
| log.debug("finished test_scalar_summary_use_invalid_tag_Int") | |||
| # Test8: test with invalid value(None, "") | |||
| def test_scalar_summary_use_invalid_value_None_ForSet(): | |||
| log.debug("begin test_scalar_summary_use_invalid_value_None_ForSet") | |||
| try: | |||
| net = SummaryDemoValueForSet(None, ("x1", "x2")) | |||
| run_case(net) | |||
| except: | |||
| assert True | |||
| else: | |||
| assert False | |||
| log.debug("finished test_scalar_summary_use_invalid_value_None_ForSet") | |||
| # Test9: test with invalid value(None, "") | |||
| def test_scalar_summary_use_invalid_value_null(): | |||
| log.debug("begin test_scalar_summary_use_invalid_value_null") | |||
| try: | |||
| net = SummaryDemoValue("") | |||
| run_case(net) | |||
| except: | |||
| assert True | |||
| else: | |||
| assert False | |||
| log.debug("finished test_scalar_summary_use_invalid_value_null") | |||
| def test_histogram_summary_use_valid_value(): | |||
| """Test histogram summary with valid value""" | |||
| log.debug("Begin test_histogram_summary_use_valid_value") | |||
| try: | |||
| net = HistogramSummaryNet(Tensor(np.array([1, 2, 3]))) | |||
| run_case(net) | |||
| except: | |||
| assert True | |||
| else: | |||
| assert False | |||
| log.debug("Finished test_histogram_summary_use_valid_value") | |||
| def test_histogram_summary_use_scalar_value(): | |||
| """Test histogram summary use scalar value""" | |||
| log.debug("Begin test_histogram_summary_use_scalar_value") | |||
| try: | |||
| scalar = Tensor(1) | |||
| net = HistogramSummaryNet(scalar) | |||
| run_case(net) | |||
| except: | |||
| assert True | |||
| else: | |||
| assert False | |||
| log.debug("Finished test_histogram_summary_use_scalar_value") | |||
| steps = 10 | |||
| with SummaryRecord(cls.summary_dir) as test_writer: | |||
| for i in range(1, steps): | |||
| net() | |||
| test_writer.record(i) | |||
| @classmethod | |||
| def setup_class(cls): | |||
| """Run before class.""" | |||
| if not os.path.exists(cls.summary_dir): | |||
| cls.summary_dir = tempfile.mkdtemp(suffix='_summary') | |||
| @classmethod | |||
| def teardown_class(cls): | |||
| """Run after class.""" | |||
| if os.path.exists(cls.summary_dir): | |||
| shutil.rmtree(cls.summary_dir) | |||
| @pytest.mark.parametrize( | |||
| "summary_type, value", | |||
| [ | |||
| (SummaryEnum.SCALAR.value, Tensor(1)), | |||
| (SummaryEnum.SCALAR.value, Tensor(np.array([1]))), | |||
| (SummaryEnum.IMAGE.value, Tensor(np.array([[[[1], [2], [3], [4]]]]))), | |||
| (SummaryEnum.TENSOR.value, Tensor(np.array([[1], [2], [3], [4]]))), | |||
| (SummaryEnum.HISTOGRAM.value, Tensor(np.array([[1], [2], [3], [4]]))), | |||
| ]) | |||
| def test_summary_success(self, summary_type, value): | |||
| """Test summary success with valid tag and valid data.""" | |||
| net = SummaryNet(summary_type, tag='tag', data=value) | |||
| TestSummaryOps.run_case(net) | |||
| @pytest.mark.parametrize( | |||
| "summary_type", | |||
| [ | |||
| SummaryEnum.SCALAR.value, | |||
| SummaryEnum.IMAGE.value, | |||
| SummaryEnum.HISTOGRAM.value, | |||
| SummaryEnum.TENSOR.value | |||
| ]) | |||
| def test_summary_tag_is_none(self, summary_type): | |||
| """Test summary tag is None, all summary operator validation rules are consistent.""" | |||
| net = SummaryNet(summary_type, tag=None, data=Tensor(0)) | |||
| with pytest.raises(TypeError): | |||
| TestSummaryOps.run_case(net) | |||
| @pytest.mark.parametrize( | |||
| "summary_type", | |||
| [ | |||
| SummaryEnum.SCALAR.value, | |||
| SummaryEnum.IMAGE.value, | |||
| SummaryEnum.HISTOGRAM.value, | |||
| SummaryEnum.TENSOR.value | |||
| ]) | |||
| def test_summary_tag_is_empty_string(self, summary_type): | |||
| """Test summary tag is a empty string, all summary operator validation rules are consistent.""" | |||
| net = SummaryNet(summary_type, tag='', data=Tensor(0)) | |||
| with pytest.raises(ValueError): | |||
| TestSummaryOps.run_case(net) | |||
| @pytest.mark.parametrize("tag", [123, True, Tensor(0)]) | |||
| def test_summary_tag_is_not_string(self, tag): | |||
| """Test summary tag is not a string, all summary operator validation rules are consistent.""" | |||
| # All summary operator validation rules are consistent, so we only test scalar summary. | |||
| net = SummaryNet(SummaryEnum.SCALAR.value, tag=tag, data=Tensor(0)) | |||
| with pytest.raises(TypeError): | |||
| TestSummaryOps.run_case(net) | |||
| @pytest.mark.parametrize("value", [123, True, 'data']) | |||
| def test_summary_value_type_invalid(self, value): | |||
| """Test the type of summary value is invalid, all summary operator validation rules are consistent.""" | |||
| # All summary operator validation rules are consistent, so we only test scalar summary. | |||
| net = SummaryNet(SummaryEnum.SCALAR.value, tag='tag', data=value) | |||
| with pytest.raises(TypeError): | |||
| TestSummaryOps.run_case(net) | |||
| @pytest.mark.parametrize( | |||
| "summary_type, value", | |||
| [ | |||
| (SummaryEnum.IMAGE.value, Tensor(np.array([1, 2]))), | |||
| (SummaryEnum.SCALAR.value, Tensor(np.array([1, 2]))), | |||
| (SummaryEnum.TENSOR.value, Tensor(0)), | |||
| (SummaryEnum.HISTOGRAM.value, Tensor(0)) | |||
| ]) | |||
| def test_value_shape_invalid(self, summary_type, value): | |||
| """Test invalid shape of every summary operators.""" | |||
| net = SummaryNet(summary_type, tag='tag', data=value) | |||
| with pytest.raises(ValueError): | |||
| TestSummaryOps.run_case(net) | |||