| @@ -37,7 +37,7 @@ do | |||||
| cp -r ./src ./eval_$i | cp -r ./src ./eval_$i | ||||
| cd ./eval_$i || exit | cd ./eval_$i || exit | ||||
| export RANK_ID=$i | export RANK_ID=$i | ||||
| echo "start training for rank $i, device $DEVICE_ID" | |||||
| echo "start infering for rank $i, device $DEVICE_ID" | |||||
| env > env.log | env > env.log | ||||
| python eval.py \ | python eval.py \ | ||||
| --data_dir=$DATASET \ | --data_dir=$DATASET \ | ||||
| @@ -141,7 +141,6 @@ def classification_dataset(data_dir, image_size, per_batch_size, max_epoch, rank | |||||
| dataset = TxtDataset(root, data_dir) | dataset = TxtDataset(root, data_dir) | ||||
| sampler = DistributedSampler(dataset, rank, group_size, shuffle=shuffle) | sampler = DistributedSampler(dataset, rank, group_size, shuffle=shuffle) | ||||
| de_dataset = de.GeneratorDataset(dataset, ["image", "label"], sampler=sampler) | de_dataset = de.GeneratorDataset(dataset, ["image", "label"], sampler=sampler) | ||||
| de_dataset.set_dataset_size(len(sampler)) | |||||
| de_dataset = de_dataset.map(input_columns="image", num_parallel_workers=8, operations=transform_img) | de_dataset = de_dataset.map(input_columns="image", num_parallel_workers=8, operations=transform_img) | ||||
| de_dataset = de_dataset.map(input_columns="label", num_parallel_workers=8, operations=transform_label) | de_dataset = de_dataset.map(input_columns="label", num_parallel_workers=8, operations=transform_label) | ||||
| @@ -0,0 +1,36 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # less required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import argparse | |||||
| import numpy as np | |||||
| from mindspore import Tensor | |||||
| from mindspore.train.serialization import export, load_checkpoint, load_param_into_net | |||||
| from src.unet.unet_model import UNet | |||||
| parser = argparse.ArgumentParser(description='Export ckpt to air') | |||||
| parser.add_argument('--ckpt_file', type=str, default="ckpt_unet_medical_adam-1_600.ckpt", | |||||
| help='The path of input ckpt file') | |||||
| parser.add_argument('--air_file', type=str, default="unet_medical_adam-1_600.air", help='The path of output air file') | |||||
| args = parser.parse_args() | |||||
| net = UNet(n_channels=1, n_classes=2) | |||||
| # return a parameter dict for model | |||||
| param_dict = load_checkpoint(args.ckpt_file) | |||||
| # load the parameter into net | |||||
| load_param_into_net(net, param_dict) | |||||
| input_data = np.random.uniform(0.0, 1.0, size=[1, 1, 572, 572]).astype(np.float32) | |||||
| export(net, Tensor(input_data), file_name=args.air_file, file_format='AIR') | |||||
| @@ -69,7 +69,7 @@ After installing MindSpore via the official website, you can start training and | |||||
| ``` | ``` | ||||
| # The darknet53_backbone.ckpt in the follow script is got from darknet53 training like paper. | # The darknet53_backbone.ckpt in the follow script is got from darknet53 training like paper. | ||||
| # The parameter of pretrained_backbone is not necessary. | |||||
| # pretrained_backbone can use src/convert_weight.py, convert darknet53.conv.74 to mindspore ckpt, darknet53.conv.74 can get from `https://pjreddie.com/media/files/darknet53.conv.74` . | |||||
| # The parameter of training_shape define image shape for network, default is "". | # The parameter of training_shape define image shape for network, default is "". | ||||
| # It means use 10 kinds of shape as input shape, or it can be set some kind of shape. | # It means use 10 kinds of shape as input shape, or it can be set some kind of shape. | ||||
| # run training example(1p) by python command. | # run training example(1p) by python command. | ||||
| @@ -0,0 +1,14 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| @@ -0,0 +1,80 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Convert weight to mindspore ckpt.""" | |||||
| import os | |||||
| import argparse | |||||
| import numpy as np | |||||
| from mindspore.train.serialization import save_checkpoint | |||||
| from mindspore import Tensor | |||||
| from src.yolo import YOLOV3DarkNet53 | |||||
| def load_weight(weights_file): | |||||
| """Loads pre-trained weights.""" | |||||
| if not os.path.isfile(weights_file): | |||||
| raise ValueError(f'"{weights_file}" is not a valid weight file.') | |||||
| with open(weights_file, 'rb') as fp: | |||||
| np.fromfile(fp, dtype=np.int32, count=5) | |||||
| return np.fromfile(fp, dtype=np.float32) | |||||
| def build_network(): | |||||
| """Build YOLOv3 network.""" | |||||
| network = YOLOV3DarkNet53(is_training=True) | |||||
| params = network.get_parameters() | |||||
| params = [p for p in params if 'backbone' in p.name] | |||||
| return params | |||||
| def convert(weights_file, output_file): | |||||
| """Conver weight to mindspore ckpt.""" | |||||
| params = build_network() | |||||
| weights = load_weight(weights_file) | |||||
| index = 0 | |||||
| param_list = [] | |||||
| for i in range(0, len(params), 5): | |||||
| weight = params[i] | |||||
| mean = params[i+1] | |||||
| var = params[i+2] | |||||
| gamma = params[i+3] | |||||
| beta = params[i+4] | |||||
| beta_data = weights[index: index+beta.size()].reshape(beta.shape) | |||||
| index += beta.size() | |||||
| gamma_data = weights[index: index+gamma.size()].reshape(gamma.shape) | |||||
| index += gamma.size() | |||||
| mean_data = weights[index: index+mean.size()].reshape(mean.shape) | |||||
| index += mean.size() | |||||
| var_data = weights[index: index + var.size()].reshape(var.shape) | |||||
| index += var.size() | |||||
| weight_data = weights[index: index+weight.size()].reshape(weight.shape) | |||||
| index += weight.size() | |||||
| param_list.append({'name': weight.name, 'type': weight.dtype, 'shape': weight.shape, | |||||
| 'data': Tensor(weight_data)}) | |||||
| param_list.append({'name': mean.name, 'type': mean.dtype, 'shape': mean.shape, 'data': Tensor(mean_data)}) | |||||
| param_list.append({'name': var.name, 'type': var.dtype, 'shape': var.shape, 'data': Tensor(var_data)}) | |||||
| param_list.append({'name': gamma.name, 'type': gamma.dtype, 'shape': gamma.shape, 'data': Tensor(gamma_data)}) | |||||
| param_list.append({'name': beta.name, 'type': beta.dtype, 'shape': beta.shape, 'data': Tensor(beta_data)}) | |||||
| save_checkpoint(param_list, output_file) | |||||
| if __name__ == "__main__": | |||||
| parser = argparse.ArgumentParser(description="yolov3 weight convert.") | |||||
| parser.add_argument("--input_file", type=str, default="./darknet53.conv.74", help="input file path.") | |||||
| parser.add_argument("--output_file", type=str, default="./ackbone_darknet53.ckpt", help="output file path.") | |||||
| args_opt = parser.parse_args() | |||||
| convert(args_opt.input_file, args_opt.output_file) | |||||
| @@ -115,39 +115,38 @@ class DarkNet(nn.Cell): | |||||
| out_channels[0], | out_channels[0], | ||||
| kernel_size=3, | kernel_size=3, | ||||
| stride=2) | stride=2) | ||||
| self.conv2 = conv_block(in_channels[1], | |||||
| out_channels[1], | |||||
| kernel_size=3, | |||||
| stride=2) | |||||
| self.conv3 = conv_block(in_channels[2], | |||||
| out_channels[2], | |||||
| kernel_size=3, | |||||
| stride=2) | |||||
| self.conv4 = conv_block(in_channels[3], | |||||
| out_channels[3], | |||||
| kernel_size=3, | |||||
| stride=2) | |||||
| self.conv5 = conv_block(in_channels[4], | |||||
| out_channels[4], | |||||
| kernel_size=3, | |||||
| stride=2) | |||||
| self.layer1 = self._make_layer(block, | self.layer1 = self._make_layer(block, | ||||
| layer_nums[0], | layer_nums[0], | ||||
| in_channel=out_channels[0], | in_channel=out_channels[0], | ||||
| out_channel=out_channels[0]) | out_channel=out_channels[0]) | ||||
| self.conv2 = conv_block(in_channels[1], | |||||
| out_channels[1], | |||||
| kernel_size=3, | |||||
| stride=2) | |||||
| self.layer2 = self._make_layer(block, | self.layer2 = self._make_layer(block, | ||||
| layer_nums[1], | layer_nums[1], | ||||
| in_channel=out_channels[1], | in_channel=out_channels[1], | ||||
| out_channel=out_channels[1]) | out_channel=out_channels[1]) | ||||
| self.conv3 = conv_block(in_channels[2], | |||||
| out_channels[2], | |||||
| kernel_size=3, | |||||
| stride=2) | |||||
| self.layer3 = self._make_layer(block, | self.layer3 = self._make_layer(block, | ||||
| layer_nums[2], | layer_nums[2], | ||||
| in_channel=out_channels[2], | in_channel=out_channels[2], | ||||
| out_channel=out_channels[2]) | out_channel=out_channels[2]) | ||||
| self.conv4 = conv_block(in_channels[3], | |||||
| out_channels[3], | |||||
| kernel_size=3, | |||||
| stride=2) | |||||
| self.layer4 = self._make_layer(block, | self.layer4 = self._make_layer(block, | ||||
| layer_nums[3], | layer_nums[3], | ||||
| in_channel=out_channels[3], | in_channel=out_channels[3], | ||||
| out_channel=out_channels[3]) | out_channel=out_channels[3]) | ||||
| self.conv5 = conv_block(in_channels[4], | |||||
| out_channels[4], | |||||
| kernel_size=3, | |||||
| stride=2) | |||||
| self.layer5 = self._make_layer(block, | self.layer5 = self._make_layer(block, | ||||
| layer_nums[4], | layer_nums[4], | ||||
| in_channel=out_channels[4], | in_channel=out_channels[4], | ||||