From: @zhangxiaoxiao16
Reviewed-by: @c_34, @wuxuejian
Signed-off-by: @c_34
pull/13830/MERGE
@@ -13,11 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-if [ ! -d out ]; then
-    mkdir out
+if [ -d out ]; then
+    rm -rf out
 fi
-cd out
+mkdir out && cd out
 if [ -f "Makefile" ]; then
     make clean
@@ -1,12 +1,12 @@
 /**
  * Copyright 2021 Huawei Technologies Co., Ltd
- *
+ *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  * You may obtain a copy of the License at
- *
+ *
  * http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS,
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -31,10 +31,9 @@
 #include "include/minddata/dataset/include/execute.h"
 #include "../inc/utils.h"
-using mindspore::GlobalContext;
+using mindspore::Context;
 using mindspore::Serialization;
 using mindspore::Model;
-using mindspore::ModelContext;
 using mindspore::Status;
 using mindspore::ModelType;
 using mindspore::GraphCell;
@@ -127,15 +126,18 @@ int main(int argc, char **argv) {
     std::cout << "Invalid fusion switch path" << std::endl;
     return 1;
   }
-  GlobalContext::SetGlobalDeviceTarget(mindspore::kDeviceTypeAscend310);
-  GlobalContext::SetGlobalDeviceID(FLAGS_device_id);
-  auto graph = Serialization::LoadModel(FLAGS_mindir_path, ModelType::kMindIR);
-  auto model_context = std::make_shared<mindspore::Context>();
+  auto context = std::make_shared<Context>();
+  auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>();
+  ascend310->SetDeviceID(FLAGS_device_id);
+  context->MutableDeviceInfo().push_back(ascend310);
+  mindspore::Graph graph;
+  Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph);
   if (!FLAGS_fusion_switch_path.empty()) {
-    ModelContext::SetFusionSwitchConfigPath(model_context, FLAGS_fusion_switch_path);
+    ascend310->SetFusionSwitchConfigPath(FLAGS_fusion_switch_path);
   }
-  Model model(GraphCell(graph), model_context);
-  Status ret = model.Build();
+  Model model;
+  Status ret = model.Build(GraphCell(graph), context);
   if (ret != kSuccess) {
     std::cout << "ERROR: Build failed." << std::endl;
     return 1;
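Assembled in one place, the new-style flow from this hunk looks roughly as follows. This is a minimal sketch for readers migrating similar ascend310_infer code: the header paths follow the MindSpore C++ API headers these files already use, and the `RunInference` wrapper is illustrative, not part of this PR.

```cpp
#include <iostream>
#include <memory>
#include <string>
#include "include/api/context.h"
#include "include/api/model.h"
#include "include/api/serialization.h"

int RunInference(const std::string &mindir_path, uint32_t device_id) {
  // Device target and ID now live on a device-info object inside the context,
  // replacing GlobalContext::SetGlobalDeviceTarget / SetGlobalDeviceID.
  auto context = std::make_shared<mindspore::Context>();
  auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>();
  ascend310->SetDeviceID(device_id);
  context->MutableDeviceInfo().push_back(ascend310);

  // Serialization::LoadModel is replaced by an out-parameter Load().
  mindspore::Graph graph;
  mindspore::Serialization::Load(mindir_path, mindspore::ModelType::kMindIR, &graph);

  // The context moves from the Model constructor into Build().
  mindspore::Model model;
  mindspore::Status ret = model.Build(mindspore::GraphCell(graph), context);
  if (ret != mindspore::kSuccess) {
    std::cout << "ERROR: Build failed." << std::endl;
    return 1;
  }
  return 0;
}
```

The same pattern repeats in the second main.cc below; only the per-model configuration (fusion switch here, AIPP there) differs.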
@@ -15,7 +15,7 @@
 # ============================================================================
 if [[ $# -lt 4 || $# -gt 5 ]]; then
-    echo "Usage: sh run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DATA_ROOT] [DATA_LIST] [DEVICE_ID]
+    echo "Usage: bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DATA_ROOT] [DATA_LIST] [DEVICE_ID]
     DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero"
     exit 1
 fi
@@ -406,7 +406,7 @@ The ckpt_file parameter is required,
 ### Infer on Ascend310
-Before performing inference, the mindir file must bu exported by export script on the 910 environment. We only provide an example of inference using MINDIR model.
+Before performing inference, the MindIR file must be exported by the `export.py` script. We only provide an example of inference using a MINDIR model.
 Currently batch_size can only be set to 1. The precision calculation needs about 70 GB of memory; otherwise the process will be killed for exceeding the memory limit.
 ```shell
@@ -17,6 +17,10 @@
 - [Evaluation Process](#评估过程)
     - [Evaluation on Ascend](#ascend处理器环境评估)
     - [Evaluation on GPU](#gpu处理器环境评估)
+- [Inference Process](#推理过程)
+    - [Export MindIR](#导出mindir)
+    - [Infer on Ascend310](#在ascend310执行推理)
+    - [Result](#结果)
 - [Model Description](#模型描述)
     - [Performance](#性能)
         - [Evaluation Performance](#评估性能)

@@ -312,6 +316,49 @@ Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.686
 mAP: 0.2244936111705981
 ```
+
+## Inference Process
+
+### [Export MindIR](#contents)
+
+```shell
+python export.py --ckpt_file [CKPT_PATH] --file_name [FILE_NAME] --file_format [FILE_FORMAT]
+```
+
+The ckpt_file parameter is required,
+`EXPORT_FORMAT` must be chosen from ["AIR", "MINDIR"].
+
+### Infer on Ascend310
+
+Before performing inference, the MindIR file must be exported by the `export.py` script. The following is an example of running inference with a MINDIR model.
+Currently only batch_size 1 is supported. The precision calculation needs about 70 GB of memory; otherwise the process will be killed for exceeding the memory limit.
+
+```shell
+# Ascend310 inference
+bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DVPP] [DEVICE_ID]
+```
+
+- `DVPP` is mandatory and must be chosen from ["DVPP", "CPU"]; it is case-insensitive. Note that ssd_vgg16 runs inference on [300, 300] images, while the DVPP hardware requires the width to be divisible by 16 and the height by 2, so this network has to preprocess the images with CPU operators instead (see the short sketch after this section).
+- `DEVICE_ID` is optional; the default value is 0.
+
+### Result
+
+The inference results are saved in the directory where the script is executed; the following accuracy results can be found in acc.log.
+
+```bash
+Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.339
+Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.521
+Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.370
+Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.168
+Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.386
+Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.461
+Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.310
+Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.481
+Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.515
+Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.293
+Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.659
+mAP: 0.33880018942412393
+```
+
 # Model Description

 ## Performance
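A small illustration of the DVPP size constraint called out above; `DvppCanHandle` is a hypothetical helper, and the 16/2 divisibility rule is the one stated in the README.

```cpp
// Hypothetical check mirroring the DVPP input constraint described above:
// width must be divisible by 16 and height by 2.
bool DvppCanHandle(int width, int height) {
  return width % 16 == 0 && height % 2 == 0;
}
// ssd_vgg16 infers on 300x300 images: 300 % 16 == 12, so DVPP cannot take
// them directly and preprocessing falls back to CPU operators.
```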
@@ -13,11 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-if [ ! -d out ]; then
-    mkdir out
+if [ -d out ]; then
+    rm -rf out
 fi
-cd out
+mkdir out && cd out
 if [ -f "Makefile" ]; then
     make clean
@@ -32,10 +32,9 @@
 #include "include/minddata/dataset/include/vision.h"
 #include "inc/utils.h"
-using mindspore::GlobalContext;
+using mindspore::Context;
 using mindspore::Serialization;
 using mindspore::Model;
-using mindspore::ModelContext;
 using mindspore::Status;
 using mindspore::ModelType;
 using mindspore::GraphCell;
@@ -64,21 +63,23 @@ int main(int argc, char **argv) {
     return 1;
   }
-  GlobalContext::SetGlobalDeviceTarget(mindspore::kDeviceTypeAscend310);
-  GlobalContext::SetGlobalDeviceID(FLAGS_device_id);
-  auto graph = Serialization::LoadModel(FLAGS_mindir_path, ModelType::kMindIR);
-  auto model_context = std::make_shared<mindspore::ModelContext>();
+  auto context = std::make_shared<Context>();
+  auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>();
+  ascend310->SetDeviceID(FLAGS_device_id);
+  context->MutableDeviceInfo().push_back(ascend310);
+  mindspore::Graph graph;
+  Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph);
   if (FLAGS_cpu_dvpp == "DVPP") {
     if (RealPath(FLAGS_aipp_path).empty()) {
       std::cout << "Invalid aipp path" << std::endl;
       return 1;
     } else {
-      ModelContext::SetInsertOpConfigPath(model_context, FLAGS_aipp_path);
+      ascend310->SetInsertOpConfigPath(FLAGS_aipp_path);
     }
   }
-  Model model(GraphCell(graph), model_context);
-  Status ret = model.Build();
+  Model model;
+  Status ret = model.Build(GraphCell(graph), context);
   if (ret != kSuccess) {
     std::cout << "ERROR: Build failed." << std::endl;
     return 1;
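The file-specific piece here is the AIPP insert-op configuration, which moves from the static `ModelContext::SetInsertOpConfigPath` to a method on the device-info object. A minimal sketch under the same assumptions as the earlier one; `MakeContext` and its parameters are hypothetical, not part of this PR.

```cpp
#include <memory>
#include <string>
#include "include/api/context.h"

// Hypothetical helper: build the device context, attaching the AIPP insert-op
// config only when DVPP preprocessing is requested (mirrors the hunk above).
std::shared_ptr<mindspore::Context> MakeContext(uint32_t device_id,
                                                const std::string &cpu_dvpp,
                                                const std::string &aipp_path) {
  auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>();
  ascend310->SetDeviceID(device_id);
  if (cpu_dvpp == "DVPP") {
    // Previously: ModelContext::SetInsertOpConfigPath(model_context, aipp_path);
    ascend310->SetInsertOpConfigPath(aipp_path);
  }
  auto context = std::make_shared<mindspore::Context>();
  context->MutableDeviceInfo().push_back(ascend310);
  return context;
}
```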
@@ -15,7 +15,7 @@
 # ============================================================================
 if [[ $# -lt 3 || $# -gt 4 ]]; then
-    echo "Usage: sh run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DVPP] [DEVICE_ID]
+    echo "Usage: bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DVPP] [DEVICE_ID]
     DVPP is mandatory, and must be chosen from [DVPP|CPU]; it's case-insensitive
     DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero"
     exit 1
@@ -59,7 +59,7 @@ fi
 function compile_app()
 {
     cd ../ascend310_infer
-    sh build.sh &> build.log
+    bash build.sh &> build.log
 }

 function infer()
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -21,9 +21,9 @@ import os
 import numpy as np
-from mindspore import Tensor
+from mindspore import Tensor, context
 from mindspore import export, load_checkpoint, load_param_into_net
-from src.config import lstm_cfg as cfg
+from src.config import lstm_cfg, lstm_cfg_ascend
 from src.lstm import SentimentNet

 if __name__ == '__main__':
@@ -31,9 +31,32 @@ if __name__ == '__main__':
     parser.add_argument('--preprocess_path', type=str, default='./preprocess',
                         help='path where the pre-process data is stored.')
     parser.add_argument('--ckpt_file', type=str, required=True, help='lstm ckpt file.')
+    parser.add_argument("--device_id", type=int, default=0, help="Device id")
+    parser.add_argument("--file_name", type=str, default="lstm", help="output file name.")
+    parser.add_argument('--file_format', type=str, choices=["AIR", "MINDIR"], default='AIR', help='file format')
+    parser.add_argument('--device_target', type=str, default="Ascend", choices=['GPU', 'CPU', 'Ascend'],
+                        help='the target device to run, supports "GPU", "CPU" and "Ascend". Default: "Ascend".')
     args = parser.parse_args()
+
+    context.set_context(
+        mode=context.GRAPH_MODE,
+        save_graphs=False,
+        device_target=args.device_target,
+        device_id=args.device_id)
+
+    if args.device_target == 'Ascend':
+        cfg = lstm_cfg_ascend
+    else:
+        cfg = lstm_cfg
+
     embedding_table = np.loadtxt(os.path.join(args.preprocess_path, "weight.txt")).astype(np.float32)
+    if args.device_target == 'Ascend':
+        pad_num = int(np.ceil(cfg.embed_size / 16) * 16 - cfg.embed_size)
+        if pad_num > 0:
+            embedding_table = np.pad(embedding_table, [(0, 0), (0, pad_num)], 'constant')
+        cfg.embed_size = int(np.ceil(cfg.embed_size / 16) * 16)
     network = SentimentNet(vocab_size=embedding_table.shape[0],
                            embed_size=cfg.embed_size,
                            num_hiddens=cfg.num_hiddens,
@@ -46,5 +69,5 @@ if __name__ == '__main__':
     param_dict = load_checkpoint(args.ckpt_file)
     load_param_into_net(network, param_dict)
-    input_arr = Tensor(np.random.uniform(0.0, 1e5, size=[64, 500]).astype(np.int32))
-    export(network, input_arr, file_name="lstm", file_format="MINDIR")
+    input_arr = Tensor(np.random.uniform(0.0, 1e5, size=[cfg.batch_size, 500]).astype(np.int32))
+    export(network, input_arr, file_name=args.file_name, file_format=args.file_format)
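The Ascend branch above pads the embedding width up to the next multiple of 16, presumably an Ascend alignment requirement; the script itself does not state the reason. A worked version of that arithmetic, with `PadTo16` as an illustrative name:

```cpp
#include <cstdio>

// Round embed_size up to the next multiple of 16, matching the script's
// int(np.ceil(cfg.embed_size / 16) * 16).
int PadTo16(int embed_size) {
  return ((embed_size + 15) / 16) * 16;
}

int main() {
  // e.g. an embedding width of 300 gets 4 zero columns appended -> 304
  std::printf("%d -> %d\n", 300, PadTo16(300));
  return 0;
}
```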