@@ -360,14 +360,25 @@ class Module(metaclass=ABCMeta):
loaded, skipped = self._load_state_dict_with_closure(closure)
unused = set(unused) - loaded
if strict and len(unused) != 0:
raise KeyError(
"Unused params violate `strict=True`, unused={}".format(unused)
)
if strict and len(skipped) != 0:
raise KeyError(
"Missing params violate `strict=True`, missing={}".format(skipped)
)
if len(unused) != 0:
if strict:
raise KeyError(
"Unused params violate `strict=True`, unused={}".format(unused)
)
else:
logger.warning(
"Unused params in `strict=False` mode, unused={}".format(unused)
)
if len(skipped) != 0:
if strict:
raise KeyError(
"Missing params violate `strict=True`, missing={}".format(skipped)
)
else:
logger.warning(
"Missing params in `strict=False` mode, missing={}".format(skipped)
)
def _load_state_dict_with_closure(self, closure):
"""Advance state_dict load through callable `closure` whose signature is
@@ -383,7 +394,6 @@ class Module(metaclass=ABCMeta):
for k, var in local_state_dict.items():
to_be_load = closure(k, var)
if to_be_load is None:
logger.warning("skip loading param `%s`", k)
skipped.append(k)
continue
assert isinstance(