Browse Source

fix checkpoint evaliaction.

tags/v0.6.0-beta
chenzomi 5 years ago
parent
commit
694a1c8067
1 changed files with 4 additions and 3 deletions
  1. +4
    -3
      mindspore/train/serialization.py

+ 4
- 3
mindspore/train/serialization.py View File

@@ -186,9 +186,10 @@ def load_checkpoint(ckpt_file_name, model_type="normal", net=None):
raise ValueError(e.__str__())

parameter_dict = {}
if model_type != checkpoint_list.model_type:
raise KeyError("Checkpoint file model type({}) is not equal to input model type({}).".format(
checkpoint_list.model_type, model_type))
if checkpoint_list.model_type:
if model_type != checkpoint_list.model_type:
raise KeyError("Checkpoint file model type({}) is not equal to input model type({}).".format(
checkpoint_list.model_type, model_type))
try:
for element in checkpoint_list.value:
data = element.tensor.tensor_content


Loading…
Cancel
Save