From: @zhangxiaoxiao16 Reviewed-by: @oacjiewen,@c_34 Signed-off-by: @c_34pull/14164/MERGE
| @@ -21,6 +21,7 @@ import cv2 | |||
| from mindspore import Tensor | |||
| import mindspore.common.dtype as mstype | |||
| import mindspore.nn as nn | |||
| import mindspore.ops as ops | |||
| from mindspore import context | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from src.nets import net_factory | |||
| @@ -47,6 +48,8 @@ def parse_args(): | |||
| parser.add_argument('--model', type=str, default='deeplab_v3_s16', help='select model') | |||
| parser.add_argument('--freeze_bn', action='store_true', default=False, help='freeze bn') | |||
| parser.add_argument('--ckpt_path', type=str, default='', help='model to evaluate') | |||
| parser.add_argument("--input_format", type=str, choices=["NCHW", "NHWC"], default="NCHW", | |||
| help="NCHW or NHWC") | |||
| args, _ = parser.parse_known_args() | |||
| return args | |||
| @@ -70,12 +73,16 @@ def resize_long(img, long_size=513): | |||
| class BuildEvalNetwork(nn.Cell): | |||
| def __init__(self, network): | |||
| def __init__(self, network, input_format="NCHW"): | |||
| super(BuildEvalNetwork, self).__init__() | |||
| self.network = network | |||
| self.softmax = nn.Softmax(axis=1) | |||
| self.transpose = ops.Transpose() | |||
| self.format = input_format | |||
| def construct(self, input_data): | |||
| if self.format == "NHWC": | |||
| input_data = self.transpose(input_data, (0, 3, 1, 2)) | |||
| output = self.network(input_data) | |||
| output = self.softmax(output) | |||
| return output | |||
| @@ -96,7 +103,6 @@ def pre_process(args, img_, crop_size=513): | |||
| pad_w = crop_size - img_.shape[1] | |||
| if pad_h > 0 or pad_w > 0: | |||
| img_ = cv2.copyMakeBorder(img_, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0) | |||
| # hwc to chw | |||
| img_ = img_.transpose((2, 0, 1)) | |||
| return img_, resize_h, resize_w | |||
| @@ -162,7 +168,7 @@ def net_eval(): | |||
| else: | |||
| raise NotImplementedError('model [{:s}] not recognized'.format(args.model)) | |||
| eval_net = BuildEvalNetwork(network) | |||
| eval_net = BuildEvalNetwork(network, args.input_format) | |||
| # load model | |||
| param_dict = load_checkpoint(args.ckpt_path) | |||
| @@ -32,6 +32,8 @@ parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU" | |||
| parser.add_argument('--model', type=str.lower, default='deeplab_v3_s8', choices=['deeplab_v3_s16', 'deeplab_v3_s8'], | |||
| help='Select model structure (Default: deeplab_v3_s8)') | |||
| parser.add_argument('--num_classes', type=int, default=21, help='the number of classes (Default: 21)') | |||
| parser.add_argument("--input_format", type=str, choices=["NCHW", "NHWC"], default="NCHW", | |||
| help="NCHW or NHWC") | |||
| args = parser.parse_args() | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) | |||
| @@ -43,10 +45,13 @@ if __name__ == '__main__': | |||
| network = net_factory.nets_map['deeplab_v3_s16']('eval', args.num_classes, 16, True) | |||
| else: | |||
| network = net_factory.nets_map['deeplab_v3_s8']('eval', args.num_classes, 8, True) | |||
| network = BuildEvalNetwork(network) | |||
| network = BuildEvalNetwork(network, args.input_format) | |||
| param_dict = load_checkpoint(args.ckpt_file) | |||
| # load the parameter into net | |||
| load_param_into_net(network, param_dict) | |||
| input_data = Tensor(np.ones([args.batch_size, 3, args.input_size, args.input_size]).astype(np.float32)) | |||
| if args.input_format == "NHWC": | |||
| input_data = Tensor(np.ones([args.batch_size, args.input_size, args.input_size, 3]).astype(np.float32)) | |||
| else: | |||
| input_data = Tensor(np.ones([args.batch_size, 3, args.input_size, args.input_size]).astype(np.float32)) | |||
| export(network, input_data, file_name=args.file_name, file_format=args.file_format) | |||
| @@ -39,10 +39,9 @@ using mindspore::dataset::vision::CenterCrop; | |||
| using mindspore::dataset::vision::Normalize; | |||
| using mindspore::dataset::vision::HWC2CHW; | |||
| using mindspore::dataset::TensorTransform; | |||
| using mindspore::GlobalContext; | |||
| using mindspore::Context; | |||
| using mindspore::Serialization; | |||
| using mindspore::Model; | |||
| using mindspore::ModelContext; | |||
| using mindspore::Status; | |||
| using mindspore::ModelType; | |||
| using mindspore::GraphCell; | |||
| @@ -62,14 +61,15 @@ int main(int argc, char **argv) { | |||
| return 1; | |||
| } | |||
| auto context = std::make_shared<Context>(); | |||
| auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>(); | |||
| ascend310->SetDeviceID(FLAGS_device_id); | |||
| context->MutableDeviceInfo().push_back(ascend310); | |||
| mindspore::Graph graph; | |||
| Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph); | |||
| Model model; | |||
| Status ret = model.Build(GraphCell(graph), context); | |||
| GlobalContext::SetGlobalDeviceTarget(mindspore::kDeviceTypeAscend310); | |||
| GlobalContext::SetGlobalDeviceID(FLAGS_device_id); | |||
| auto graph = Serialization::LoadModel(FLAGS_mindir_path, ModelType::kMindIR); | |||
| auto model_context = std::make_shared<mindspore::ModelContext>(); | |||
| Model model(GraphCell(graph), model_context); | |||
| Status ret = model.Build(); | |||
| if (ret != kSuccess) { | |||
| std::cout << "ERROR: Build failed." << std::endl; | |||
| return 1; | |||
| @@ -66,6 +66,7 @@ int main(int argc, char **argv) { | |||
| auto context = std::make_shared<Context>(); | |||
| auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>(); | |||
| ascend310->SetDeviceID(FLAGS_device_id); | |||
| ascend310->SetBufferOptimizeMode("off_optimize"); | |||
| context->MutableDeviceInfo().push_back(ascend310); | |||
| mindspore::Graph graph; | |||
| Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph); | |||