| @@ -37,7 +37,7 @@ do | |||
| cp -r ./src ./eval_$i | |||
| cd ./eval_$i || exit | |||
| export RANK_ID=$i | |||
| echo "start training for rank $i, device $DEVICE_ID" | |||
| echo "start infering for rank $i, device $DEVICE_ID" | |||
| env > env.log | |||
| python eval.py \ | |||
| --data_dir=$DATASET \ | |||
| @@ -141,7 +141,6 @@ def classification_dataset(data_dir, image_size, per_batch_size, max_epoch, rank | |||
| dataset = TxtDataset(root, data_dir) | |||
| sampler = DistributedSampler(dataset, rank, group_size, shuffle=shuffle) | |||
| de_dataset = de.GeneratorDataset(dataset, ["image", "label"], sampler=sampler) | |||
| de_dataset.set_dataset_size(len(sampler)) | |||
| de_dataset = de_dataset.map(input_columns="image", num_parallel_workers=8, operations=transform_img) | |||
| de_dataset = de_dataset.map(input_columns="label", num_parallel_workers=8, operations=transform_label) | |||
| @@ -0,0 +1,36 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # less required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import argparse | |||
| import numpy as np | |||
| from mindspore import Tensor | |||
| from mindspore.train.serialization import export, load_checkpoint, load_param_into_net | |||
| from src.unet.unet_model import UNet | |||
| parser = argparse.ArgumentParser(description='Export ckpt to air') | |||
| parser.add_argument('--ckpt_file', type=str, default="ckpt_unet_medical_adam-1_600.ckpt", | |||
| help='The path of input ckpt file') | |||
| parser.add_argument('--air_file', type=str, default="unet_medical_adam-1_600.air", help='The path of output air file') | |||
| args = parser.parse_args() | |||
| net = UNet(n_channels=1, n_classes=2) | |||
| # return a parameter dict for model | |||
| param_dict = load_checkpoint(args.ckpt_file) | |||
| # load the parameter into net | |||
| load_param_into_net(net, param_dict) | |||
| input_data = np.random.uniform(0.0, 1.0, size=[1, 1, 572, 572]).astype(np.float32) | |||
| export(net, Tensor(input_data), file_name=args.air_file, file_format='AIR') | |||
| @@ -69,7 +69,7 @@ After installing MindSpore via the official website, you can start training and | |||
| ``` | |||
| # The darknet53_backbone.ckpt in the follow script is got from darknet53 training like paper. | |||
| # The parameter of pretrained_backbone is not necessary. | |||
| # pretrained_backbone can use src/convert_weight.py, convert darknet53.conv.74 to mindspore ckpt, darknet53.conv.74 can get from `https://pjreddie.com/media/files/darknet53.conv.74` . | |||
| # The parameter of training_shape define image shape for network, default is "". | |||
| # It means use 10 kinds of shape as input shape, or it can be set some kind of shape. | |||
| # run training example(1p) by python command. | |||
| @@ -0,0 +1,14 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| @@ -0,0 +1,80 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Convert weight to mindspore ckpt.""" | |||
| import os | |||
| import argparse | |||
| import numpy as np | |||
| from mindspore.train.serialization import save_checkpoint | |||
| from mindspore import Tensor | |||
| from src.yolo import YOLOV3DarkNet53 | |||
| def load_weight(weights_file): | |||
| """Loads pre-trained weights.""" | |||
| if not os.path.isfile(weights_file): | |||
| raise ValueError(f'"{weights_file}" is not a valid weight file.') | |||
| with open(weights_file, 'rb') as fp: | |||
| np.fromfile(fp, dtype=np.int32, count=5) | |||
| return np.fromfile(fp, dtype=np.float32) | |||
| def build_network(): | |||
| """Build YOLOv3 network.""" | |||
| network = YOLOV3DarkNet53(is_training=True) | |||
| params = network.get_parameters() | |||
| params = [p for p in params if 'backbone' in p.name] | |||
| return params | |||
| def convert(weights_file, output_file): | |||
| """Conver weight to mindspore ckpt.""" | |||
| params = build_network() | |||
| weights = load_weight(weights_file) | |||
| index = 0 | |||
| param_list = [] | |||
| for i in range(0, len(params), 5): | |||
| weight = params[i] | |||
| mean = params[i+1] | |||
| var = params[i+2] | |||
| gamma = params[i+3] | |||
| beta = params[i+4] | |||
| beta_data = weights[index: index+beta.size()].reshape(beta.shape) | |||
| index += beta.size() | |||
| gamma_data = weights[index: index+gamma.size()].reshape(gamma.shape) | |||
| index += gamma.size() | |||
| mean_data = weights[index: index+mean.size()].reshape(mean.shape) | |||
| index += mean.size() | |||
| var_data = weights[index: index + var.size()].reshape(var.shape) | |||
| index += var.size() | |||
| weight_data = weights[index: index+weight.size()].reshape(weight.shape) | |||
| index += weight.size() | |||
| param_list.append({'name': weight.name, 'type': weight.dtype, 'shape': weight.shape, | |||
| 'data': Tensor(weight_data)}) | |||
| param_list.append({'name': mean.name, 'type': mean.dtype, 'shape': mean.shape, 'data': Tensor(mean_data)}) | |||
| param_list.append({'name': var.name, 'type': var.dtype, 'shape': var.shape, 'data': Tensor(var_data)}) | |||
| param_list.append({'name': gamma.name, 'type': gamma.dtype, 'shape': gamma.shape, 'data': Tensor(gamma_data)}) | |||
| param_list.append({'name': beta.name, 'type': beta.dtype, 'shape': beta.shape, 'data': Tensor(beta_data)}) | |||
| save_checkpoint(param_list, output_file) | |||
| if __name__ == "__main__": | |||
| parser = argparse.ArgumentParser(description="yolov3 weight convert.") | |||
| parser.add_argument("--input_file", type=str, default="./darknet53.conv.74", help="input file path.") | |||
| parser.add_argument("--output_file", type=str, default="./ackbone_darknet53.ckpt", help="output file path.") | |||
| args_opt = parser.parse_args() | |||
| convert(args_opt.input_file, args_opt.output_file) | |||
| @@ -115,39 +115,38 @@ class DarkNet(nn.Cell): | |||
| out_channels[0], | |||
| kernel_size=3, | |||
| stride=2) | |||
| self.conv2 = conv_block(in_channels[1], | |||
| out_channels[1], | |||
| kernel_size=3, | |||
| stride=2) | |||
| self.conv3 = conv_block(in_channels[2], | |||
| out_channels[2], | |||
| kernel_size=3, | |||
| stride=2) | |||
| self.conv4 = conv_block(in_channels[3], | |||
| out_channels[3], | |||
| kernel_size=3, | |||
| stride=2) | |||
| self.conv5 = conv_block(in_channels[4], | |||
| out_channels[4], | |||
| kernel_size=3, | |||
| stride=2) | |||
| self.layer1 = self._make_layer(block, | |||
| layer_nums[0], | |||
| in_channel=out_channels[0], | |||
| out_channel=out_channels[0]) | |||
| self.conv2 = conv_block(in_channels[1], | |||
| out_channels[1], | |||
| kernel_size=3, | |||
| stride=2) | |||
| self.layer2 = self._make_layer(block, | |||
| layer_nums[1], | |||
| in_channel=out_channels[1], | |||
| out_channel=out_channels[1]) | |||
| self.conv3 = conv_block(in_channels[2], | |||
| out_channels[2], | |||
| kernel_size=3, | |||
| stride=2) | |||
| self.layer3 = self._make_layer(block, | |||
| layer_nums[2], | |||
| in_channel=out_channels[2], | |||
| out_channel=out_channels[2]) | |||
| self.conv4 = conv_block(in_channels[3], | |||
| out_channels[3], | |||
| kernel_size=3, | |||
| stride=2) | |||
| self.layer4 = self._make_layer(block, | |||
| layer_nums[3], | |||
| in_channel=out_channels[3], | |||
| out_channel=out_channels[3]) | |||
| self.conv5 = conv_block(in_channels[4], | |||
| out_channels[4], | |||
| kernel_size=3, | |||
| stride=2) | |||
| self.layer5 = self._make_layer(block, | |||
| layer_nums[4], | |||
| in_channel=out_channels[4], | |||