Browse Source

add se block for resnet50

tags/v0.7.0-beta
qujianwei 5 years ago
parent
commit
b079e34e73
9 changed files with 286 additions and 72 deletions
  1. +5
    -2
      model_zoo/official/cv/resnet/eval.py
  2. +8
    -3
      model_zoo/official/cv/resnet/scripts/run_distribute_train.sh
  3. +8
    -3
      model_zoo/official/cv/resnet/scripts/run_eval.sh
  4. +8
    -3
      model_zoo/official/cv/resnet/scripts/run_standalone_train.sh
  5. +25
    -3
      model_zoo/official/cv/resnet/src/config.py
  6. +53
    -1
      model_zoo/official/cv/resnet/src/dataset.py
  7. +12
    -0
      model_zoo/official/cv/resnet/src/lr_generator.py
  8. +156
    -46
      model_zoo/official/cv/resnet/src/resnet.py
  9. +11
    -11
      model_zoo/official/cv/resnet/train.py

+ 5
- 2
model_zoo/official/cv/resnet/eval.py View File

@@ -38,17 +38,20 @@ de.config.set_seed(1)

if args_opt.net == "resnet50":
from src.resnet import resnet50 as resnet

if args_opt.dataset == "cifar10":
from src.config import config1 as config
from src.dataset import create_dataset1 as create_dataset
else:
from src.config import config2 as config
from src.dataset import create_dataset2 as create_dataset
else:
elif args_opt.net == "resnet101":
from src.resnet import resnet101 as resnet
from src.config import config3 as config
from src.dataset import create_dataset3 as create_dataset
else:
from src.resnet import se_resnet50 as resnet
from src.config import config4 as config
from src.dataset import create_dataset4 as create_dataset

if __name__ == '__main__':
target = args_opt.device_target


+ 8
- 3
model_zoo/official/cv/resnet/scripts/run_distribute_train.sh View File

@@ -16,13 +16,13 @@

if [ $# != 4 ] && [ $# != 5 ]
then
echo "Usage: sh run_distribute_train.sh [resnet50|resnet101] [cifar10|imagenet2012] [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
echo "Usage: sh run_distribute_train.sh [resnet50|resnet101|se-resnet50] [cifar10|imagenet2012] [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
exit 1
fi

if [ $1 != "resnet50" ] && [ $1 != "resnet101" ]
if [ $1 != "resnet50" ] && [ $1 != "resnet101" ] && [ $1 != "se-resnet50" ]
then
echo "error: the selected net is neither resnet50 nor resnet101"
echo "error: the selected net is neither resnet50 nor resnet101 and se-resnet50"
exit 1
fi

@@ -38,6 +38,11 @@ then
exit 1
fi

if [ $1 == "se-resnet50" ] && [ $2 == "cifar10" ]
then
echo "error: evaluating se-resnet50 with cifar10 dataset is unsupported now!"
exit 1
fi

get_real_path(){
if [ "${1:0:1}" == "/" ]; then


+ 8
- 3
model_zoo/official/cv/resnet/scripts/run_eval.sh View File

@@ -16,13 +16,13 @@

if [ $# != 4 ]
then
echo "Usage: sh run_eval.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]"
echo "Usage: sh run_eval.sh [resnet50|resnet101|se-resnet50] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]"
exit 1
fi

if [ $1 != "resnet50" ] && [ $1 != "resnet101" ]
if [ $1 != "resnet50" ] && [ $1 != "resnet101" ] && [ $1 != "se-resnet50" ]
then
echo "error: the selected net is neither resnet50 nor resnet101"
echo "error: the selected net is neither resnet50 nor resnet101 nor se-resnet50"
exit 1
fi

@@ -38,6 +38,11 @@ then
exit 1
fi

if [ $1 == "se-resnet50" ] && [ $2 == "cifar10" ]
then
echo "error: evaluating se-resnet50 with cifar10 dataset is unsupported now!"
exit 1
fi

get_real_path(){
if [ "${1:0:1}" == "/" ]; then


+ 8
- 3
model_zoo/official/cv/resnet/scripts/run_standalone_train.sh View File

@@ -16,13 +16,13 @@

if [ $# != 3 ] && [ $# != 4 ]
then
echo "Usage: sh run_standalone_train.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
echo "Usage: sh run_standalone_train.sh [resnet50|resnet101|se-resnet50] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
exit 1
fi

if [ $1 != "resnet50" ] && [ $1 != "resnet101" ]
if [ $1 != "resnet50" ] && [ $1 != "resnet101" ] && [ $1 != "se-resnet50" ]
then
echo "error: the selected net is neither resnet50 nor resnet101"
echo "error: the selected net is neither resnet50 nor resnet101 and se-resnet50"
exit 1
fi

@@ -38,6 +38,11 @@ then
exit 1
fi

if [ $1 == "se-resnet50" ] && [ $2 == "cifar10" ]
then
echo "error: evaluating se-resnet50 with cifar10 dataset is unsupported now!"
exit 1
fi

get_real_path(){
if [ "${1:0:1}" == "/" ]; then


+ 25
- 3
model_zoo/official/cv/resnet/src/config.py View File

@@ -50,12 +50,12 @@ config2 = ed({
"keep_checkpoint_max": 10,
"save_checkpoint_path": "./",
"warmup_epochs": 0,
"lr_decay_mode": "cosine",
"lr_decay_mode": "linear",
"use_label_smooth": True,
"label_smooth_factor": 0.1,
"lr_init": 0,
"lr_max": 0.1
"lr_max": 0.1,
"lr_end": 0.0
})

# config for resent101, imagenet2012
@@ -77,3 +77,25 @@ config3 = ed({
"label_smooth_factor": 0.1,
"lr": 0.1
})

# config for se-resnet50, imagenet2012
config4 = ed({
"class_num": 1001,
"batch_size": 32,
"loss_scale": 1024,
"momentum": 0.9,
"weight_decay": 1e-4,
"epoch_size": 28,
"pretrain_epoch_size": 1,
"save_checkpoint": True,
"save_checkpoint_epochs": 4,
"keep_checkpoint_max": 10,
"save_checkpoint_path": "./",
"warmup_epochs": 3,
"lr_decay_mode": "cosine",
"use_label_smooth": True,
"label_smooth_factor": 0.1,
"lr_init": 0.0,
"lr_max": 0.3,
"lr_end": 0.0001
})

+ 53
- 1
model_zoo/official/cv/resnet/src/dataset.py View File

@@ -22,7 +22,6 @@ import mindspore.dataset.transforms.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2
from mindspore.communication.management import init, get_rank, get_group_size


def create_dataset1(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"):
"""
create a train or evaluate cifar10 dataset for resnet50
@@ -191,6 +190,59 @@ def create_dataset3(dataset_path, do_train, repeat_num=1, batch_size=32, target=

return ds

def create_dataset4(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"):
"""
create a train or eval imagenet2012 dataset for se-resnet50

Args:
dataset_path(string): the path of dataset.
do_train(bool): whether dataset is used for train or eval.
repeat_num(int): the repeat times of dataset. Default: 1
batch_size(int): the batch size of dataset. Default: 32
target(str): the device target. Default: Ascend

Returns:
dataset
"""
if target == "Ascend":
device_num, rank_id = _get_rank_info()
if device_num == 1:
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=12, shuffle=True)
else:
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=12, shuffle=True,
num_shards=device_num, shard_id=rank_id)
image_size = 224
mean = [123.68, 116.78, 103.94]
std = [1.0, 1.0, 1.0]

# define map operations
if do_train:
trans = [
C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
C.RandomHorizontalFlip(prob=0.5),
C.Normalize(mean=mean, std=std),
C.HWC2CHW()
]
else:
trans = [
C.Decode(),
C.Resize(292),
C.CenterCrop(256),
C.Normalize(mean=mean, std=std),
C.HWC2CHW()
]

type_cast_op = C2.TypeCast(mstype.int32)
ds = ds.map(input_columns="image", num_parallel_workers=12, operations=trans)
ds = ds.map(input_columns="label", num_parallel_workers=12, operations=type_cast_op)

# apply batch operations
ds = ds.batch(batch_size, drop_remainder=True)

# apply dataset repeat operation
ds = ds.repeat(repeat_num)

return ds

def _get_rank_info():
"""


+ 12
- 0
model_zoo/official/cv/resnet/src/lr_generator.py View File

@@ -62,6 +62,18 @@ def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch
if lr < 0.0:
lr = 0.0
lr_each_step.append(lr)
elif lr_decay_mode == 'cosine':
decay_steps = total_steps - warmup_steps
for i in range(total_steps):
if i < warmup_steps:
lr_inc = (float(lr_max) - float(lr_init)) / float(warmup_steps)
lr = float(lr_init) + lr_inc * (i + 1)
else:
linear_decay = (total_steps - i) / decay_steps
cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * 0.47 * i / decay_steps))
decayed = linear_decay * cosine_decay + 0.00001
lr = lr_max * decayed
lr_each_step.append(lr)
else:
for i in range(total_steps):
if i < warmup_steps:


+ 156
- 46
model_zoo/official/cv/resnet/src/resnet.py View File

@@ -15,32 +15,53 @@
"""ResNet."""
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.tensor import Tensor

from scipy.stats import truncnorm

def _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size):
fan_in = in_channel * kernel_size * kernel_size
scale = 1.0
scale /= max(1., fan_in)
stddev = (scale ** 0.5) / .87962566103423978
mu, sigma = 0, stddev
weight = truncnorm(-2, 2, loc=mu, scale=sigma).rvs(out_channel * in_channel * kernel_size * kernel_size)
weight = np.reshape(weight, (out_channel, in_channel, kernel_size, kernel_size))
return Tensor(weight, dtype=mstype.float32)

def _weight_variable(shape, factor=0.01):
init_value = np.random.randn(*shape).astype(np.float32) * factor
return Tensor(init_value)


def _conv3x3(in_channel, out_channel, stride=1):
weight_shape = (out_channel, in_channel, 3, 3)
weight = _weight_variable(weight_shape)
def _conv3x3(in_channel, out_channel, stride=1, use_se=False):
if use_se:
weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=3)
else:
weight_shape = (out_channel, in_channel, 3, 3)
weight = _weight_variable(weight_shape)
return nn.Conv2d(in_channel, out_channel,
kernel_size=3, stride=stride, padding=0, pad_mode='same', weight_init=weight)


def _conv1x1(in_channel, out_channel, stride=1):
weight_shape = (out_channel, in_channel, 1, 1)
weight = _weight_variable(weight_shape)
def _conv1x1(in_channel, out_channel, stride=1, use_se=False):
if use_se:
weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=1)
else:
weight_shape = (out_channel, in_channel, 1, 1)
weight = _weight_variable(weight_shape)
return nn.Conv2d(in_channel, out_channel,
kernel_size=1, stride=stride, padding=0, pad_mode='same', weight_init=weight)


def _conv7x7(in_channel, out_channel, stride=1):
weight_shape = (out_channel, in_channel, 7, 7)
weight = _weight_variable(weight_shape)
def _conv7x7(in_channel, out_channel, stride=1, use_se=False):
if use_se:
weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=7)
else:
weight_shape = (out_channel, in_channel, 7, 7)
weight = _weight_variable(weight_shape)
return nn.Conv2d(in_channel, out_channel,
kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight)

@@ -55,9 +76,13 @@ def _bn_last(channel):
gamma_init=0, beta_init=0, moving_mean_init=0, moving_var_init=1)


def _fc(in_channel, out_channel):
weight_shape = (out_channel, in_channel)
weight = _weight_variable(weight_shape)
def _fc(in_channel, out_channel, use_se=False):
if use_se:
weight = np.random.normal(loc=0, scale=0.01, size=out_channel*in_channel)
weight = Tensor(np.reshape(weight, (out_channel, in_channel)), dtype=mstype.float32)
else:
weight_shape = (out_channel, in_channel)
weight = _weight_variable(weight_shape)
return nn.Dense(in_channel, out_channel, has_bias=True, weight_init=weight, bias_init=0)


@@ -69,6 +94,8 @@ class ResidualBlock(nn.Cell):
in_channel (int): Input channel.
out_channel (int): Output channel.
stride (int): Stride size for the first convolutional layer. Default: 1.
use_se (bool): enable SE-ResNet50 net. Default: False.
se_block(bool): use se block in SE-ResNet50 net. Default: False.

Returns:
Tensor, output tensor.
@@ -81,19 +108,30 @@ class ResidualBlock(nn.Cell):
def __init__(self,
in_channel,
out_channel,
stride=1):
stride=1,
use_se=False, se_block=False):
super(ResidualBlock, self).__init__()

self.stride = stride
self.use_se = use_se
self.se_block = se_block
channel = out_channel // self.expansion
self.conv1 = _conv1x1(in_channel, channel, stride=1)
self.conv1 = _conv1x1(in_channel, channel, stride=1, use_se=self.use_se)
self.bn1 = _bn(channel)

self.conv2 = _conv3x3(channel, channel, stride=stride)
self.bn2 = _bn(channel)

self.conv3 = _conv1x1(channel, out_channel, stride=1)
if self.use_se and self.stride != 1:
self.e2 = nn.SequentialCell([_conv3x3(channel, channel, stride=1, use_se=True), _bn(channel),
nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='same')])
else:
self.conv2 = _conv3x3(channel, channel, stride=stride, use_se=self.use_se)
self.bn2 = _bn(channel)

self.conv3 = _conv1x1(channel, out_channel, stride=1, use_se=self.use_se)
self.bn3 = _bn_last(out_channel)

if self.se_block:
self.se_global_pool = P.ReduceMean(keep_dims=False)
self.se_dense_0 = _fc(out_channel, int(out_channel/4), use_se=self.use_se)
self.se_dense_1 = _fc(int(out_channel/4), out_channel, use_se=self.use_se)
self.se_sigmoid = nn.Sigmoid()
self.se_mul = P.Mul()
self.relu = nn.ReLU()

self.down_sample = False
@@ -103,8 +141,17 @@ class ResidualBlock(nn.Cell):
self.down_sample_layer = None

if self.down_sample:
self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride),
_bn(out_channel)])
if self.use_se:
if stride == 1:
self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel,
stride, use_se=self.use_se), _bn(out_channel)])
else:
self.down_sample_layer = nn.SequentialCell([nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='same'),
_conv1x1(in_channel, out_channel, 1,
use_se=self.use_se), _bn(out_channel)])
else:
self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride,
use_se=self.use_se), _bn(out_channel)])
self.add = P.TensorAdd()

def construct(self, x):
@@ -113,13 +160,23 @@ class ResidualBlock(nn.Cell):
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)

out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)

if self.use_se and self.stride != 1:
out = self.e2(out)
else:
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.se_block:
out_se = out
out = self.se_global_pool(out, (2, 3))
out = self.se_dense_0(out)
out = self.relu(out)
out = self.se_dense_1(out)
out = self.se_sigmoid(out)
out = F.reshape(out, F.shape(out) + (1, 1))
out = self.se_mul(out, out_se)

if self.down_sample:
identity = self.down_sample_layer(identity)
@@ -141,6 +198,8 @@ class ResNet(nn.Cell):
out_channels (list): Output channel in each layer.
strides (list): Stride size in each layer.
num_classes (int): The number of classes that the training images are belonging to.
use_se (bool): enable SE-ResNet50 net. Default: False.
se_block(bool): use se block in SE-ResNet50 net in layer 3 and layer 4. Default: False.
Returns:
Tensor, output tensor.

@@ -159,43 +218,60 @@ class ResNet(nn.Cell):
in_channels,
out_channels,
strides,
num_classes):
num_classes,
use_se=False):
super(ResNet, self).__init__()

if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!")

self.conv1 = _conv7x7(3, 64, stride=2)
self.use_se = use_se
self.se_block = False
if self.use_se:
self.se_block = True

if self.use_se:
self.conv1_0 = _conv3x3(3, 32, stride=2, use_se=self.use_se)
self.bn1_0 = _bn(32)
self.conv1_1 = _conv3x3(32, 32, stride=1, use_se=self.use_se)
self.bn1_1 = _bn(32)
self.conv1_2 = _conv3x3(32, 64, stride=1, use_se=self.use_se)
else:
self.conv1 = _conv7x7(3, 64, stride=2)
self.bn1 = _bn(64)
self.relu = P.ReLU()
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")

self.layer1 = self._make_layer(block,
layer_nums[0],
in_channel=in_channels[0],
out_channel=out_channels[0],
stride=strides[0])
stride=strides[0],
use_se=self.use_se)
self.layer2 = self._make_layer(block,
layer_nums[1],
in_channel=in_channels[1],
out_channel=out_channels[1],
stride=strides[1])
stride=strides[1],
use_se=self.use_se)
self.layer3 = self._make_layer(block,
layer_nums[2],
in_channel=in_channels[2],
out_channel=out_channels[2],
stride=strides[2])
stride=strides[2],
use_se=self.use_se,
se_block=self.se_block)
self.layer4 = self._make_layer(block,
layer_nums[3],
in_channel=in_channels[3],
out_channel=out_channels[3],
stride=strides[3])
stride=strides[3],
use_se=self.use_se,
se_block=self.se_block)

self.mean = P.ReduceMean(keep_dims=True)
self.flatten = nn.Flatten()
self.end_point = _fc(out_channels[3], num_classes)
self.end_point = _fc(out_channels[3], num_classes, use_se=self.use_se)

def _make_layer(self, block, layer_num, in_channel, out_channel, stride):
def _make_layer(self, block, layer_num, in_channel, out_channel, stride, use_se=False, se_block=False):
"""
Make stage network of ResNet.

@@ -205,7 +281,7 @@ class ResNet(nn.Cell):
in_channel (int): Input channel.
out_channel (int): Output channel.
stride (int): Stride size for the first convolutional layer.
se_block(bool): use se block in SE-ResNet50 net. Default: False.
Returns:
SequentialCell, the output layer.

@@ -214,17 +290,31 @@ class ResNet(nn.Cell):
"""
layers = []

resnet_block = block(in_channel, out_channel, stride=stride)
resnet_block = block(in_channel, out_channel, stride=stride, use_se=use_se)
layers.append(resnet_block)

for _ in range(1, layer_num):
resnet_block = block(out_channel, out_channel, stride=1)
if se_block:
for _ in range(1, layer_num - 1):
resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se)
layers.append(resnet_block)
resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se, se_block=se_block)
layers.append(resnet_block)

else:
for _ in range(1, layer_num):
resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se)
layers.append(resnet_block)
return nn.SequentialCell(layers)

def construct(self, x):
x = self.conv1(x)
if self.use_se:
x = self.conv1_0(x)
x = self.bn1_0(x)
x = self.relu(x)
x = self.conv1_1(x)
x = self.bn1_1(x)
x = self.relu(x)
x = self.conv1_2(x)
else:
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
c1 = self.maxpool(x)
@@ -261,6 +351,26 @@ def resnet50(class_num=10):
[1, 2, 2, 2],
class_num)

def se_resnet50(class_num=1001):
"""
Get SE-ResNet50 neural network.

Args:
class_num (int): Class number.

Returns:
Cell, cell instance of SE-ResNet50 neural network.

Examples:
>>> net = se-resnet50(1001)
"""
return ResNet(ResidualBlock,
[3, 4, 6, 3],
[64, 256, 512, 1024],
[256, 512, 1024, 2048],
[1, 2, 2, 2],
class_num,
use_se=True)

def resnet101(class_num=1001):
"""


+ 11
- 11
model_zoo/official/cv/resnet/train.py View File

@@ -50,17 +50,21 @@ de.config.set_seed(1)

if args_opt.net == "resnet50":
from src.resnet import resnet50 as resnet

if args_opt.dataset == "cifar10":
from src.config import config1 as config
from src.dataset import create_dataset1 as create_dataset
else:
from src.config import config2 as config
from src.dataset import create_dataset2 as create_dataset
else:
elif args_opt.net == "resnet101":
from src.resnet import resnet101 as resnet
from src.config import config3 as config
from src.dataset import create_dataset3 as create_dataset
else:
from src.resnet import se_resnet50 as resnet
from src.config import config4 as config
from src.dataset import create_dataset4 as create_dataset


if __name__ == '__main__':
target = args_opt.device_target
@@ -74,7 +78,7 @@ if __name__ == '__main__':
context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
mirror_mean=True)
if args_opt.net == "resnet50":
if args_opt.net == "resnet50" or args_opt.net == "se-resnet50":
auto_parallel_context().set_all_reduce_fusion_split_indices([85, 160])
else:
auto_parallel_context().set_all_reduce_fusion_split_indices([180, 313])
@@ -112,14 +116,10 @@ if __name__ == '__main__':
cell.weight.dtype)

# init lr
if args_opt.net == "resnet50":
if args_opt.dataset == "cifar10":
lr = get_lr(lr_init=config.lr_init, lr_end=config.lr_end, lr_max=config.lr_max,
warmup_epochs=config.warmup_epochs, total_epochs=config.epoch_size, steps_per_epoch=step_size,
lr_decay_mode='poly')
else:
lr = get_lr(lr_init=config.lr_init, lr_end=0.0, lr_max=config.lr_max, warmup_epochs=config.warmup_epochs,
total_epochs=config.epoch_size, steps_per_epoch=step_size, lr_decay_mode='cosine')
if args_opt.net == "resnet50" or args_opt.net == "se-resnet50":
lr = get_lr(lr_init=config.lr_init, lr_end=config.lr_end, lr_max=config.lr_max,
warmup_epochs=config.warmup_epochs, total_epochs=config.epoch_size, steps_per_epoch=step_size,
lr_decay_mode=config.lr_decay_mode)
else:
lr = warmup_cosine_annealing_lr(config.lr, step_size, config.warmup_epochs, config.epoch_size,
config.pretrain_epoch_size * step_size)


Loading…
Cancel
Save