|
|
|
@@ -17,7 +17,7 @@ import argparse |
|
|
|
import numpy as np |
|
|
|
|
|
|
|
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export |
|
|
|
|
|
|
|
from eval import BuildEvalNetwork |
|
|
|
from src.nets import net_factory |
|
|
|
|
|
|
|
parser = argparse.ArgumentParser(description='checkpoint export') |
|
|
|
@@ -43,6 +43,7 @@ if __name__ == '__main__': |
|
|
|
network = net_factory.nets_map['deeplab_v3_s16']('eval', args.num_classes, 16, True) |
|
|
|
else: |
|
|
|
network = net_factory.nets_map['deeplab_v3_s8']('eval', args.num_classes, 8, True) |
|
|
|
network = BuildEvalNetwork(network) |
|
|
|
param_dict = load_checkpoint(args.ckpt_file) |
|
|
|
|
|
|
|
# load the parameter into net |
|
|
|
|