Browse Source

fix(mge/module): add the more warnings in load_state_dict

In the before, the information of miss matching and unused operators won't be printed in the non-trict mode, this commit add the information.

GitOrigin-RevId: b2543eb832
tags/v0.4.0
Megvii Engine Team Xinran Xu 5 years ago
parent
commit
9a42c63641
1 changed files with 19 additions and 9 deletions
  1. +19
    -9
      python_module/megengine/module/module.py

+ 19
- 9
python_module/megengine/module/module.py View File

@@ -360,14 +360,25 @@ class Module(metaclass=ABCMeta):
loaded, skipped = self._load_state_dict_with_closure(closure)
unused = set(unused) - loaded

if strict and len(unused) != 0:
raise KeyError(
"Unused params violate `strict=True`, unused={}".format(unused)
)
if strict and len(skipped) != 0:
raise KeyError(
"Missing params violate `strict=True`, missing={}".format(skipped)
)
if len(unused) != 0:
if strict:
raise KeyError(
"Unused params violate `strict=True`, unused={}".format(unused)
)
else:
logger.warning(
"Unused params in `strict=False` mode, unused={}".format(unused)
)

if len(skipped) != 0:
if strict:
raise KeyError(
"Missing params violate `strict=True`, missing={}".format(skipped)
)
else:
logger.warning(
"Missing params in `strict=False` mode, missing={}".format(skipped)
)

def _load_state_dict_with_closure(self, closure):
"""Advance state_dict load through callable `closure` whose signature is
@@ -383,7 +394,6 @@ class Module(metaclass=ABCMeta):
for k, var in local_state_dict.items():
to_be_load = closure(k, var)
if to_be_load is None:
logger.warning("skip loading param `%s`", k)
skipped.append(k)
continue
assert isinstance(


Loading…
Cancel
Save