You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_config.py 3.0 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import argparse
  3. import os.path as osp
  4. import tempfile
  5. import unittest
  6. from pathlib import Path
  7. from maas_lib.fileio import dump, load
  8. from maas_lib.utils.config import Config
  9. obj = {'a': 1, 'b': {'c': [1, 2, 3], 'd': 'dd'}}
  10. class ConfigTest(unittest.TestCase):
  11. def test_json(self):
  12. config_file = 'configs/examples/config.json'
  13. cfg = Config.from_file(config_file)
  14. self.assertEqual(cfg.a, 1)
  15. self.assertEqual(cfg.b, obj['b'])
  16. def test_yaml(self):
  17. config_file = 'configs/examples/config.yaml'
  18. cfg = Config.from_file(config_file)
  19. self.assertEqual(cfg.a, 1)
  20. self.assertEqual(cfg.b, obj['b'])
  21. def test_py(self):
  22. config_file = 'configs/examples/config.py'
  23. cfg = Config.from_file(config_file)
  24. self.assertEqual(cfg.a, 1)
  25. self.assertEqual(cfg.b, obj['b'])
  26. def test_dump(self):
  27. config_file = 'configs/examples/config.py'
  28. cfg = Config.from_file(config_file)
  29. self.assertEqual(cfg.a, 1)
  30. self.assertEqual(cfg.b, obj['b'])
  31. pretty_text = 'a = 1\n'
  32. pretty_text += "b = dict(c=[1, 2, 3], d='dd')\n"
  33. json_str = '{"a": 1, "b": {"c": [1, 2, 3], "d": "dd"}}'
  34. yaml_str = 'a: 1\nb:\n c:\n - 1\n - 2\n - 3\n d: dd\n'
  35. with tempfile.NamedTemporaryFile(suffix='.json') as ofile:
  36. self.assertEqual(pretty_text, cfg.dump())
  37. cfg.dump(ofile.name)
  38. with open(ofile.name, 'r') as infile:
  39. self.assertEqual(json_str, infile.read())
  40. with tempfile.NamedTemporaryFile(suffix='.yaml') as ofile:
  41. cfg.dump(ofile.name)
  42. with open(ofile.name, 'r') as infile:
  43. self.assertEqual(yaml_str, infile.read())
  44. def test_to_dict(self):
  45. config_file = 'configs/examples/config.json'
  46. cfg = Config.from_file(config_file)
  47. d = cfg.to_dict()
  48. print(d)
  49. self.assertTrue(isinstance(d, dict))
  50. def test_to_args(self):
  51. def parse_fn(args):
  52. parser = argparse.ArgumentParser(prog='PROG')
  53. parser.add_argument('--model-dir', default='')
  54. parser.add_argument('--lr', type=float, default=0.001)
  55. parser.add_argument('--optimizer', default='')
  56. parser.add_argument('--weight-decay', type=float, default=1e-7)
  57. parser.add_argument(
  58. '--save-checkpoint-epochs', type=int, default=30)
  59. return parser.parse_args(args)
  60. cfg = Config.from_file('configs/examples/plain_args.yaml')
  61. args = cfg.to_args(parse_fn)
  62. self.assertEqual(args.model_dir, 'path/to/model')
  63. self.assertAlmostEqual(args.lr, 0.01)
  64. self.assertAlmostEqual(args.weight_decay, 1e-6)
  65. self.assertEqual(args.optimizer, 'Adam')
  66. self.assertEqual(args.save_checkpoint_epochs, 20)
  67. if __name__ == '__main__':
  68. unittest.main()

致力于通过开放的社区合作，开源AI模型以及相关创新技术，推动基于模型即服务的生态繁荣发展