|
- # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
- import unittest
- from collections import OrderedDict
- import torch
- from torch import nn
-
- from detectron2.checkpoint.c2_model_loading import align_and_update_state_dicts
- from detectron2.utils.logger import setup_logger
-
-
class TestCheckpointer(unittest.TestCase):
    """Tests for ``align_and_update_state_dicts`` on a small nested model."""

    def setUp(self):
        setup_logger()

    def create_complex_model(self):
        """Build a small nested model plus a checkpoint-style state dict for it.

        The checkpoint keys deliberately do not match the model's parameter
        names exactly:

        * ``layer1.*`` lacks the ``block1.`` prefix that the model uses, and
        * ``layer2.*`` is a suffix of both the model's ``layer2.*`` and
          ``res.layer2.*`` entries.

        The checkpoint entries are listed in the same order that the model
        registers its submodules (``block1``, ``layer2``, ``res``). The test
        below compares values positionally and relies on this order.

        Returns:
            tuple: ``(model, state_dict)``. ``model`` is an ``nn.Module``, and
            ``state_dict`` is an ``OrderedDict`` of random tensors whose shapes
            match the corresponding ``nn.Linear`` parameters.
        """
        m = nn.Module()
        m.block1 = nn.Module()
        m.block1.layer1 = nn.Linear(2, 3)
        m.layer2 = nn.Linear(3, 2)
        m.res = nn.Module()
        m.res.layer2 = nn.Linear(3, 2)

        state_dict = OrderedDict()
        state_dict["layer1.weight"] = torch.rand(3, 2)
        state_dict["layer1.bias"] = torch.rand(3)
        state_dict["layer2.weight"] = torch.rand(2, 3)
        state_dict["layer2.bias"] = torch.rand(2)
        state_dict["res.layer2.weight"] = torch.rand(2, 3)
        state_dict["res.layer2.bias"] = torch.rand(2)
        return m, state_dict

    def test_complex_model_loaded(self):
        """Check that checkpoint values end up in the model state dict as copies.

        The check runs on the plain model and again on the model wrapped in
        ``nn.DataParallel``. Each configuration is a separate sub-test, so a
        failure reports which one broke.
        """
        for add_data_parallel in (False, True):
            with self.subTest(add_data_parallel=add_data_parallel):
                model, state_dict = self.create_complex_model()
                if add_data_parallel:
                    model = nn.DataParallel(model)
                model_sd = model.state_dict()

                # The assertions below inspect model_sd itself. This relies on
                # align_and_update_state_dicts updating it in place.
                align_and_update_state_dicts(model_sd, state_dict)

                # zip() silently stops at the shorter input. Check the sizes
                # first so a missing or extra entry cannot go unnoticed.
                self.assertEqual(len(model_sd), len(state_dict))
                for loaded, stored in zip(model_sd.values(), state_dict.values()):
                    # The model must hold its own tensor, not the checkpoint's object...
                    self.assertIsNot(loaded, stored)
                    # ...with exactly the same contents.
                    self.assertTrue(torch.equal(loaded, stored))
-
if __name__ == "__main__":
    # Run the tests in this module when it is executed as a script.
    unittest.main(module="__main__")
|