|
- import builtins
- import copy
- import io
- import json
- import logging
- import sys
- import random
- import typing
- import unittest
-
- import jsonschema
- import numpy
-
- from d3m import container, types, utils
- from d3m.container import list
- from d3m.metadata import base as metadata_base
-
-
class TestUtils(unittest.TestCase):
    """Unit tests for the helpers in ``d3m.utils``."""

    # get_type_arguments resolves type variables through generic base classes;
    # variables that are never bound resolve to typing.Any.
    def test_get_type_arguments(self):
        A = typing.TypeVar('A')
        B = typing.TypeVar('B')
        C = typing.TypeVar('C')

        class Base(typing.Generic[A, B]):
            pass

        class Foo(Base[A, None]):
            pass

        class Bar(Foo[A], typing.Generic[A, C]):
            pass

        class Baz(Bar[float, int]):
            pass

        self.assertEqual(utils.get_type_arguments(Bar), {
            A: typing.Any,
            B: type(None),
            C: typing.Any,
        })
        self.assertEqual(utils.get_type_arguments(Baz), {
            A: float,
            B: type(None),
            C: int,
        })

        self.assertEqual(utils.get_type_arguments(Base), {
            A: typing.Any,
            B: typing.Any,
        })

        self.assertEqual(utils.get_type_arguments(Base[float, int]), {
            A: float,
            B: int,
        })

        self.assertEqual(utils.get_type_arguments(Foo), {
            A: typing.Any,
            B: type(None),
        })

        self.assertEqual(utils.get_type_arguments(Foo[float]), {
            A: float,
            B: type(None),
        })

    # is_subclass accepts a concrete class as well as a bounded TypeVar as the
    # second argument.
    def test_issubclass(self):
        self.assertTrue(utils.is_subclass(list.List, types.Container))

        T1 = typing.TypeVar('T1', bound=list.List)
        self.assertTrue(utils.is_subclass(list.List, T1))

    # create_enum_from_json_schema_enum builds an enum from the values selected
    # by a JSON path; repeated values (see 'foobar3') yield a single member.
    def test_create_enum(self):
        obj = {
            'definitions': {
                'foobar1': {
                    'type': 'array',
                    'items': {
                        'anyOf': [
                            {'enum': ['AAA']},
                            {'enum': ['BBB']},
                            {'enum': ['CCC']},
                            {'enum': ['DDD']},
                        ],
                    },
                },
                'foobar2': {
                    'type': 'array',
                    'items': {
                        'type': 'object',
                        'anyOf': [
                            {
                                'properties': {
                                    'type': {
                                        'type': 'string',
                                        'enum': ['EEE'],
                                    },
                                },
                            },
                            {
                                'properties': {
                                    'type': {
                                        'type': 'string',
                                        'enum': ['FFF'],
                                    },
                                },
                            },
                            {
                                'properties': {
                                    'type': {
                                        'type': 'string',
                                        'enum': ['GGG'],
                                    },
                                },
                            },
                        ],
                    },
                },
                'foobar3': {
                    'type': 'string',
                    'enum': ['HHH', 'HHH', 'III', 'JJJ'],
                },
            },
        }

        Foobar1 = utils.create_enum_from_json_schema_enum('Foobar1', obj, 'definitions.foobar1.items.anyOf[*].enum[*]')
        Foobar2 = utils.create_enum_from_json_schema_enum('Foobar2', obj, 'definitions.foobar2.items.anyOf[*].properties.type.enum[*]')
        Foobar3 = utils.create_enum_from_json_schema_enum('Foobar3', obj, 'definitions.foobar3.enum[*]')

        self.assertSequenceEqual(builtins.list(Foobar1.__members__.keys()), ['AAA', 'BBB', 'CCC', 'DDD'])
        self.assertSequenceEqual([value.value for value in Foobar1.__members__.values()], ['AAA', 'BBB', 'CCC', 'DDD'])

        self.assertSequenceEqual(builtins.list(Foobar2.__members__.keys()), ['EEE', 'FFF', 'GGG'])
        self.assertSequenceEqual([value.value for value in Foobar2.__members__.values()], ['EEE', 'FFF', 'GGG'])

        self.assertSequenceEqual(builtins.list(Foobar3.__members__.keys()), ['HHH', 'III', 'JJJ'])
        self.assertSequenceEqual([value.value for value in Foobar3.__members__.values()], ['HHH', 'III', 'JJJ'])

        # assertEqual (rather than assertTrue(a == b)) reports both operands on failure.
        self.assertEqual(Foobar1.AAA.name, 'AAA')
        self.assertEqual(Foobar1.AAA.value, 'AAA')
        self.assertEqual(Foobar1.AAA, Foobar1.AAA)
        self.assertEqual(Foobar1.AAA, 'AAA')

    # utils.Enum supports register_value: re-registering an existing name raises
    # AttributeError, and registering an existing value creates an alias of the
    # member that first had it.
    def test_extendable_enum(self):
        class Foobar(utils.Enum):
            AAA = 1
            BBB = 2
            CCC = 3

        self.assertSequenceEqual(builtins.list(Foobar.__members__.keys()), ['AAA', 'BBB', 'CCC'])
        self.assertSequenceEqual([value.value for value in Foobar.__members__.values()], [1, 2, 3])

        with self.assertRaises(AttributeError):
            Foobar.register_value('CCC', 5)

        self.assertSequenceEqual(builtins.list(Foobar.__members__.keys()), ['AAA', 'BBB', 'CCC'])
        self.assertSequenceEqual([value.value for value in Foobar.__members__.values()], [1, 2, 3])

        Foobar.register_value('DDD', 4)

        self.assertSequenceEqual(builtins.list(Foobar.__members__.keys()), ['AAA', 'BBB', 'CCC', 'DDD'])
        self.assertSequenceEqual([value.value for value in Foobar.__members__.values()], [1, 2, 3, 4])

        self.assertEqual(Foobar['DDD'], 'DDD')
        self.assertEqual(Foobar(4), 'DDD')

        Foobar.register_value('EEE', 4)

        self.assertSequenceEqual(builtins.list(Foobar.__members__.keys()), ['AAA', 'BBB', 'CCC', 'DDD', 'EEE'])
        self.assertSequenceEqual([value.value for value in Foobar.__members__.values()], [1, 2, 3, 4, 4])

        self.assertEqual(Foobar['EEE'], 'DDD')
        self.assertEqual(Foobar(4), 'DDD')

    # redirect_to_logging turns printed lines into log records (line-buffered,
    # trailing whitespace dropped), optionally passes output through, and
    # guards against recursion when a handler writes back to the redirected stream.
    def test_redirect(self):
        old_stdout = sys.stdout
        old_stderr = sys.stderr

        test_stream = io.StringIO()
        sys.stdout = test_stream
        sys.stderr = test_stream
        # Cleanups run last-in-first-out and also run when an assertion below
        # fails: the real streams are restored first, then the capture stream is
        # closed. Restoring only at the end of the method would leave
        # sys.stdout/sys.stderr pointing at test_stream for all later tests
        # whenever this test fails.
        self.addCleanup(test_stream.close)
        self.addCleanup(setattr, sys, 'stderr', old_stderr)
        self.addCleanup(setattr, sys, 'stdout', old_stdout)

        logger = logging.getLogger('test_logger')
        with self.assertLogs(logger=logger, level=logging.DEBUG) as cm:
            with utils.redirect_to_logging(logger=logger, pass_through=False):
                print("Test.")

        self.assertEqual(len(cm.records), 1)
        self.assertEqual(cm.records[0].message, "Test.")

        with self.assertLogs(logger=logger, level=logging.DEBUG) as cm:
            with utils.redirect_to_logging(logger=logger, pass_through=False):
                print("foo", "bar")

        self.assertEqual(len(cm.records), 1)
        self.assertEqual(cm.records[0].message, "foo bar")

        with self.assertLogs(logger=logger, level=logging.DEBUG) as cm:
            with utils.redirect_to_logging(logger=logger, pass_through=False):
                print("Test.\nTe", end="")
                print("st2.", end="")

                # The incomplete line should not be written to the logger.
                self.assertEqual(len(cm.records), 1)
                self.assertEqual(cm.records[0].message, "Test.")

        # Remaining contents should be written to logger upon closing.
        self.assertEqual(len(cm.records), 2)
        self.assertEqual(cm.records[0].message, "Test.")
        self.assertEqual(cm.records[1].message, "Test2.")

        with self.assertLogs(logger=logger, level=logging.DEBUG) as cm:
            with utils.redirect_to_logging(logger=logger, pass_through=False):
                print("Test. ")
                print(" ")
                print(" Test2.")
                print(" ")

        # Trailing whitespace and new lines should not be logged.
        self.assertEqual(len(cm.records), 2)
        self.assertEqual(cm.records[0].message, "Test.")
        self.assertEqual(cm.records[1].message, " Test2.")

        logger2 = logging.getLogger('test_logger2')
        with self.assertLogs(logger=logger, level=logging.DEBUG) as cm:
            with self.assertLogs(logger=logger2, level=logging.DEBUG) as cm2:
                with utils.redirect_to_logging(logger=logger, pass_through=True):
                    print("Test.")
                    with utils.redirect_to_logging(logger=logger2, pass_through=True):
                        print("Test2.")

        self.assertEqual(len(cm.records), 1)
        self.assertEqual(cm.records[0].message, "Test.")
        self.assertEqual(len(cm2.records), 1)
        self.assertEqual(cm2.records[0].message, "Test2.")

        pass_through_lines = test_stream.getvalue().split('\n')
        self.assertEqual(len(pass_through_lines), 3)
        self.assertEqual(pass_through_lines[0], "Test.")
        self.assertEqual(pass_through_lines[1], "Test2.")
        self.assertEqual(pass_through_lines[2], "")

        records = []

        def callback(record):
            # Only mutates the enclosing list, so no "nonlocal" is needed.
            records.append(record)

        # Test recursion prevention.
        with self.assertLogs(logger=logger, level=logging.DEBUG) as cm:
            with self.assertLogs(logger=logger2, level=logging.DEBUG) as cm2:
                # We add it twice so that we test that handler does not modify record while running.
                # (Handlers added inside assertLogs are discarded when it exits,
                # because assertLogs restores the logger's original handler list.)
                logger2.addHandler(utils.CallbackHandler(callback))
                logger2.addHandler(utils.CallbackHandler(callback))

                with utils.redirect_to_logging(logger=logger, pass_through=False):
                    print("Test.")
                    with utils.redirect_to_logging(logger=logger2, pass_through=False):
                        # We configure handler after redirecting.
                        handler = logging.StreamHandler(sys.stdout)
                        handler.setFormatter(logging.Formatter('Test format: %(message)s'))
                        logger2.addHandler(handler)
                        print("Test2.")

        # We use outer "redirect_to_logging" to make sure nothing from inner gets out.
        self.assertEqual(len(cm.records), 1)
        self.assertEqual(cm.records[0].message, "Test.")

        self.assertEqual(len(cm2.records), 2)
        # This one comes from the print.
        self.assertEqual(cm2.records[0].message, "Test2.")
        # And this one comes from the stream handler.
        self.assertEqual(cm2.records[1].message, "Test format: Test2.")

        self.assertEqual(len(records), 4)
        self.assertEqual(records[0]['message'], "Test2.")
        self.assertEqual(records[1]['message'], "Test2.")
        self.assertEqual(records[2]['message'], "Test format: Test2.")
        self.assertEqual(records[3]['message'], "Test format: Test2.")

    # columns_sum works on both a DataFrame and an ndarray and keeps column
    # names in the result's metadata.
    def test_columns_sum(self):
        dataframe = container.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}, generate_metadata=True)

        dataframe_sum = utils.columns_sum(dataframe)

        self.assertEqual(dataframe_sum.values.tolist(), [[6, 15]])
        self.assertEqual(dataframe_sum.metadata.query((metadata_base.ALL_ELEMENTS, 0))['name'], 'a')
        self.assertEqual(dataframe_sum.metadata.query((metadata_base.ALL_ELEMENTS, 1))['name'], 'b')

        array = container.ndarray(dataframe, generate_metadata=True)

        array_sum = utils.columns_sum(array)

        self.assertEqual(array_sum.tolist(), [[6, 15]])
        self.assertEqual(array_sum.metadata.query((metadata_base.ALL_ELEMENTS, 0))['name'], 'a')
        self.assertEqual(array_sum.metadata.query((metadata_base.ALL_ELEMENTS, 1))['name'], 'b')

    # is_float / is_int / is_numeric classify Python numeric types.
    def test_numeric(self):
        self.assertTrue(utils.is_float(type(1.0)))
        self.assertFalse(utils.is_float(type(1)))
        self.assertFalse(utils.is_int(type(1.0)))
        self.assertTrue(utils.is_int(type(1)))
        self.assertTrue(utils.is_numeric(type(1.0)))
        self.assertTrue(utils.is_numeric(type(1)))

    # NumPy scalars survive a yaml_dump / yaml_load round trip.
    def test_yaml_representers(self):
        self.assertEqual(utils.yaml_load(utils.yaml_dump(numpy.int32(1))), 1)
        self.assertEqual(utils.yaml_load(utils.yaml_dump(numpy.int64(1))), 1)
        self.assertEqual(utils.yaml_load(utils.yaml_dump(numpy.float32(1.0))), 1.0)
        self.assertEqual(utils.yaml_load(utils.yaml_dump(numpy.float64(1.0))), 1.0)

    # The python_type schema definition accepts a type name or a type object
    # and rejects other values.
    def test_json_schema_python_type(self):
        schemas = copy.copy(metadata_base.SCHEMAS)
        schemas['http://example.com/testing_python_type.json'] = {
            'id': 'http://example.com/testing_python_type.json',
            'properties': {
                'foobar': {
                    '$ref': 'https://metadata.datadrivendiscovery.org/schemas/v0/definitions.json#/definitions/python_type',
                },
            },
        }

        validator, = utils.load_schema_validators(schemas, ('testing_python_type.json',))

        validator.validate({'foobar': 'str'})
        validator.validate({'foobar': str})

        with self.assertRaisesRegex(jsonschema.exceptions.ValidationError, 'python-type'):
            validator.validate({'foobar': 1})

    # 'integer' accepts integral floats such as 1.0; neither type accepts strings.
    def test_json_schema_numeric(self):
        schemas = copy.copy(metadata_base.SCHEMAS)
        schemas['http://example.com/testing_numeric.json'] = {
            'id': 'http://example.com/testing_numeric.json',
            'properties': {
                'int': {
                    'type': 'integer',
                },
                'float': {
                    'type': 'number',
                },
            },
        }

        validator, = utils.load_schema_validators(schemas, ('testing_numeric.json',))

        validator.validate({'float': 0})
        validator.validate({'float': 1.0})
        validator.validate({'float': 1.2})

        with self.assertRaisesRegex(jsonschema.exceptions.ValidationError, 'float'):
            validator.validate({'float': '1.2'})

        validator.validate({'int': 0})
        validator.validate({'int': 1.0})

        with self.assertRaisesRegex(jsonschema.exceptions.ValidationError, 'int'):
            validator.validate({'int': 1.2})

        with self.assertRaisesRegex(jsonschema.exceptions.ValidationError, 'int'):
            validator.validate({'int': '1.0'})

    # compute_digest / compute_hash_id ignore the 'digest' / 'id' field and
    # treat 1.0 and 1 as the same value.
    def test_digest(self):
        self.assertEqual(utils.compute_digest({'a': 1.0, 'digest': 'xxx'}), utils.compute_digest({'a': 1.0}))
        self.assertEqual(utils.compute_hash_id({'a': 1.0, 'id': 'xxx'}), utils.compute_hash_id({'a': 1.0}))

        self.assertEqual(utils.compute_digest({'a': 1.0}), utils.compute_digest({'a': 1}))
        self.assertEqual(utils.compute_hash_id({'a': 1.0}), utils.compute_hash_id({'a': 1}))

    # json_structure_equals compares nested structures and can ignore given
    # keys at any depth.
    def test_json_equals(self):
        basic_cases = ['hello', 0, -2, 3.14, False, True, [1, 2, 3], {'a': 1}, {'z', 'y', 'x'}]
        for case in basic_cases:
            self.assertTrue(utils.json_structure_equals(case, case))

        self.assertFalse(utils.json_structure_equals({'extra_key': 'value'}, {}))
        self.assertFalse(utils.json_structure_equals({}, {'extra_key': 'value'}))
        self.assertTrue(utils.json_structure_equals({}, {'extra_key': 'value'}, ignore_keys={'extra_key'}))

        list1 = {'a': builtins.list('type')}
        list2 = {'a': builtins.list('typo')}
        self.assertFalse(utils.json_structure_equals(list1, list2))

        json1 = {
            'a': 1,
            'b': True,
            'c': 'hello',
            'd': -2.4,
            'e': {
                'a': 'world',
            },
            'f': [
                0,
                1,
                2,
            ],
            'ignore': {
                'a': False,
            },
            'deep': [
                {
                    'a': {},
                    'ignore': {},
                },
                {
                    'b': [],
                    'ignore': -1,
                },
            ],
        }
        json2 = {
            'a': 1,
            'b': True,
            'c': 'hello',
            'd': -2.4,
            'e': {
                'a': 'world',
            },
            'f': [
                0,
                1,
                2,
            ],
            'ignore': {
                'a': True,
            },
            'deep': [
                {
                    'a': {},
                    'ignore': {
                        'not_empty': 'hello world',
                    },
                },
                {
                    'b': [],
                    'ignore': 1,
                },
            ],
        }

        self.assertTrue(utils.json_structure_equals(json1, json2, ignore_keys={'ignore'}))
        self.assertFalse(utils.json_structure_equals(json1, json2))

    # to_/from_reversible_json_structure round-trip values (including types,
    # bytes, NaN and arrays) through plain JSON; non-string dict keys raise TypeError.
    def test_reversible_json(self):
        for obj in [
            1,
            "foobar",
            b"foobar",
            [1, 2, 3],
            [1, [2], 3],
            1.2,
            type(None),
            int,
            str,
            numpy.ndarray,
            {'foo': 'bar'},
            {'encoding': 'something', 'value': 'else'},
            metadata_base.NO_VALUE,
            metadata_base.ALL_ELEMENTS,
        ]:
            self.assertEqual(utils.from_reversible_json_structure(json.loads(json.dumps(utils.to_reversible_json_structure(obj)))), obj, str(obj))

        self.assertTrue(numpy.isnan(utils.from_reversible_json_structure(json.loads(json.dumps(utils.to_reversible_json_structure(float('nan')))))))

        self.assertEqual(utils.from_reversible_json_structure(json.loads(json.dumps(utils.to_reversible_json_structure(numpy.array([1, 2, 3]))))).tolist(), [1, 2, 3])

        with self.assertRaises(TypeError):
            utils.to_reversible_json_structure({1: 2})

    # global_randomness_warning logs exactly one warning per use of a
    # global/unseeded random source.
    def test_global_randomness_warning(self):
        with self.assertLogs(logger=utils.logger, level=logging.DEBUG) as cm:
            with utils.global_randomness_warning():
                random.randint(0, 10)

        self.assertEqual(len(cm.records), 1)
        self.assertEqual(cm.records[0].message, "Using global/shared random source using 'random.randint' can make execution not reproducible.")

        with self.assertLogs(logger=utils.logger, level=logging.DEBUG) as cm:
            with utils.global_randomness_warning():
                numpy.random.randint(0, 10)

        self.assertEqual(len(cm.records), 1)
        self.assertEqual(cm.records[0].message, "Using global/shared random source using 'numpy.random.randint' can make execution not reproducible.")

        # default_rng only exists in newer NumPy versions.
        if hasattr(numpy.random, 'default_rng'):
            with self.assertLogs(logger=utils.logger, level=logging.DEBUG) as cm:
                with utils.global_randomness_warning():
                    numpy.random.default_rng()

            self.assertEqual(len(cm.records), 1)
            self.assertEqual(cm.records[0].message, "Using 'numpy.random.default_rng' without a seed can make execution not reproducible.")

    # yaml_load parses these float literals to the same values json.loads does,
    # including exponent notation without a decimal point.
    def test_yaml_float_parsing(self):
        self.assertEqual(json.loads('1000.0'), 1000)
        self.assertEqual(utils.yaml_load('1000.0'), 1000)

        self.assertEqual(json.loads('1e+3'), 1000)
        self.assertEqual(utils.yaml_load('1e+3'), 1000)
-
-
# Running this file directly executes every test case defined in it.
if __name__ == '__main__':
    unittest.main(module='__main__')
|