|
|
|
@@ -186,9 +186,10 @@ def load_checkpoint(ckpt_file_name, model_type="normal", net=None): |
|
|
|
raise ValueError(e.__str__()) |
|
|
|
|
|
|
|
parameter_dict = {} |
|
|
|
if model_type != checkpoint_list.model_type: |
|
|
|
raise KeyError("Checkpoint file model type({}) is not equal to input model type({}).".format( |
|
|
|
checkpoint_list.model_type, model_type)) |
|
|
|
if checkpoint_list.model_type: |
|
|
|
if model_type != checkpoint_list.model_type: |
|
|
|
raise KeyError("Checkpoint file model type({}) is not equal to input model type({}).".format( |
|
|
|
checkpoint_list.model_type, model_type)) |
|
|
|
try: |
|
|
|
for element in checkpoint_list.value: |
|
|
|
data = element.tensor.tensor_content |
|
|
|
|