Merge pull request !3950 from ZPaC/ci-add-ps-cases (tags/v0.7.0-beta)
| @@ -25,9 +25,7 @@ void EmbeddingLookUpProxyKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| auto indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); | |||
| auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); | |||
| size_t axis = kShape4dDims - input_shape.size(); | |||
| CPUKernelUtils::ExpandDimsTo4(&input_shape); | |||
| CPUKernelUtils::ExpandDimsTo4(&output_shape); | |||
| size_t axis = kShape2dDims - input_shape.size(); | |||
| for (auto dim : input_shape) { | |||
| input_dims_ *= dim; | |||
| } | |||
| @@ -40,6 +38,8 @@ void EmbeddingLookUpProxyKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| values.insert(values.end(), input_shape.begin(), input_shape.end()); | |||
| values.insert(values.end(), indices_shape.begin(), indices_shape.end()); | |||
| values.insert(values.end(), output_shape.begin(), output_shape.end()); | |||
| MS_LOG(INFO) << "Init embedding lookup proxy kernel, input shape:" << input_shape | |||
| << ", indices_shape:" << indices_shape << ", output_shape:" << output_shape; | |||
| std::vector<int> lens{SizeToInt(input_shape.size()), SizeToInt(indices_shape.size()), SizeToInt(output_shape.size())}; | |||
| const char *env_role = getenv(mindspore::parallel::ps::kEnvRole); | |||
| if (env_role != nullptr && strcmp(env_role, mindspore::parallel::ps::kEnvRoleOfWorker) == 0) { | |||
| @@ -25,11 +25,15 @@ namespace mindspore { | |||
| namespace kernel { | |||
| namespace ps { | |||
| using mindspore::parallel::ps::Util; | |||
| constexpr int kAxis = 2; | |||
| constexpr int kAxis = 0; | |||
| void EmbeddingLookUpPSKernel::InitKernel( | |||
| const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes) { | |||
| const std::vector<std::shared_ptr<std::vector<size_t>>> &shape_vec = *shapes; | |||
| input_shape_ = *(shape_vec[0]); | |||
| first_dim_size_ = input_shape_[0]; | |||
| for (size_t i = 1; i < input_shape_.size(); ++i) { | |||
| outer_dim_size_ *= input_shape_[i]; | |||
| } | |||
| auto indices_shape = *(shape_vec[1]); | |||
| indices_lens_ = 1; | |||
| for (auto shape : indices_shape) { | |||
| @@ -49,7 +53,6 @@ void EmbeddingLookUpPSKernel::InitKernel( | |||
| size_t output_size = | |||
| std::accumulate(output_shape.begin(), output_shape.end(), sizeof(float), std::multiplies<size_t>()); | |||
| output_size_list_.emplace_back(output_size); | |||
| CPUKernelUtils::ExpandDimsTo4(&input_shape_); | |||
| } | |||
| void EmbeddingLookUpPSKernel::ReInit(const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes) { | |||
| @@ -77,7 +77,7 @@ OptimizerInfo *SparseAdamOptimInfoBuilder::BuildInputs(const WeightPtr &weight, | |||
| size_t worker_num) { | |||
| AddressPtr weight_addr = std::make_shared<kernel::Address>(); | |||
| weight_addr->addr = weight->data(); | |||
| weight_addr->size = weight->size(); | |||
| weight_addr->size = weight->size() * sizeof(float); | |||
| AddressPtr m = std::make_shared<kernel::Address>(); | |||
| m->addr = new float[weight->size()]; | |||
| m->size = weight->size() * sizeof(float); | |||
| @@ -156,7 +156,7 @@ OptimizerInfo *SparseFtrlOptimInfoBuilder::BuildInputs(const WeightPtr &weight, | |||
| size_t worker_num) { | |||
| AddressPtr weight_addr = std::make_shared<kernel::Address>(); | |||
| weight_addr->addr = weight->data(); | |||
| weight_addr->size = weight->size(); | |||
| weight_addr->size = weight->size() * sizeof(float); | |||
| AddressPtr accum = std::make_shared<kernel::Address>(); | |||
| accum->addr = new float[weight->size()]; | |||
| accum->size = weight->size() * sizeof(float); | |||
| @@ -166,7 +166,7 @@ OptimizerInfo *SparseFtrlOptimInfoBuilder::BuildInputs(const WeightPtr &weight, | |||
| } | |||
| AddressPtr linear = std::make_shared<kernel::Address>(); | |||
| linear->addr = new float[weight->size()]; | |||
| auto ret = memset_s(linear->addr, weight->size() * sizeof(float), 0x00, weight->size() * sizeof(float)); | |||
| int ret = memset_s(linear->addr, weight->size() * sizeof(float), 0x00, weight->size() * sizeof(float)); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "memset_s error, errorno(" << ret << ")"; | |||
| } | |||
| @@ -176,9 +176,9 @@ OptimizerInfo *SparseFtrlOptimInfoBuilder::BuildInputs(const WeightPtr &weight, | |||
| size_t total_grad_size = std::accumulate((*grad_shape).begin(), (*grad_shape).end(), 1, std::multiplies<size_t>()); | |||
| AddressPtr grad = std::make_shared<kernel::Address>(); | |||
| grad->addr = new float[total_grad_size * worker_num]; | |||
| auto ret1 = memcpy_s(grad->addr, lens[0] * sizeof(float), values.data(), lens[0] * sizeof(float)); | |||
| if (ret1 != 0) { | |||
| MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret1 << ")"; | |||
| ret = memcpy_s(grad->addr, lens[0] * sizeof(float), values.data(), lens[0] * sizeof(float)); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; | |||
| } | |||
| grad->size = lens[0] * sizeof(float); | |||
| @@ -187,10 +187,10 @@ OptimizerInfo *SparseFtrlOptimInfoBuilder::BuildInputs(const WeightPtr &weight, | |||
| std::accumulate((*indices_shape).begin(), (*indices_shape).end(), 1, std::multiplies<size_t>()); | |||
| AddressPtr indices = std::make_shared<kernel::Address>(); | |||
| indices->addr = new float[total_indice_size * worker_num]; | |||
| auto ret2 = memcpy_s(indices->addr, lens[1] * sizeof(float), reinterpret_cast<float *>(values.data()) + lens[0], | |||
| lens[1] * sizeof(float)); | |||
| if (ret2 != 0) { | |||
| MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret2 << ")"; | |||
| ret = memcpy_s(indices->addr, lens[1] * sizeof(float), reinterpret_cast<float *>(values.data()) + lens[0], | |||
| lens[1] * sizeof(float)); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; | |||
| } | |||
| indices->size = lens[1] * sizeof(int); | |||
| @@ -0,0 +1,75 @@ | |||
| #!/bin/bash | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
# Launch the local nodes of one parameter-server role for a cluster run.
#
# Usage:
#   bash run_parameter_server_train_cluster.sh RANK_SIZE EPOCHS DATASET RANK_TABLE_FILE \
#        LOCAL_WORKER_NUM LOCAL_SERVER_NUM SERVER_NUM \
#        SCHED_HOST SCHED_PORT ROLE
execute_path=$(pwd)
script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}")

export RANK_SIZE=$1
export EPOCH_SIZE=$2
export DATASET=$3
export RANK_TABLE_FILE=$4
export MS_COMM_TYPE=zmq
export MS_SCHED_NUM=1
export MS_WORKER_NUM=$RANK_SIZE
export LOCAL_WORKER_NUM=$5
export LOCAL_SERVER_NUM=$6
export MS_SERVER_NUM=$7
export MS_SCHED_HOST=$8
export MS_SCHED_PORT=$9
export MS_ROLE=${10}
echo "=====Role is $MS_ROLE======"

# launch_nodes PREFIX COUNT
# Recreates ${execute_path}/PREFIX_<i>/ for i in [0, COUNT), enters it, exports
# RANK_ID/DEVICE_ID=<i> and starts the training script in the background with
# its output redirected to PREFIX_<i>.log inside that directory.
launch_nodes() {
  local prefix=$1
  local count=$2
  local i
  local work_dir
  for ((i = 0; i < count; i++)); do
    work_dir="${execute_path}/${prefix}_${i}/"
    # Paths are quoted so a working directory containing spaces or glob
    # characters cannot make rm -rf act on unintended targets.
    rm -rf "${work_dir}"
    mkdir "${work_dir}"
    cd "${work_dir}" || exit
    export RANK_ID=$i
    export DEVICE_ID=$i
    python -s "${self_path}/../train_and_eval_parameter_server.py" --data_path="$DATASET" --epochs="$EPOCH_SIZE" --parameter_server=1 >"${prefix}_${i}.log" 2>&1 &
  done
}

# Only the role named by the tenth argument is started; any other value
# launches nothing.
case "$MS_ROLE" in
  MS_SCHED)
    launch_nodes sched 1
    ;;
  MS_PSERVER)
    launch_nodes server "$LOCAL_SERVER_NUM"
    ;;
  MS_WORKER)
    launch_nodes worker "$LOCAL_WORKER_NUM"
    ;;
esac
| @@ -0,0 +1,55 @@ | |||
| #!/bin/bash | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
# Start one scheduler, MS_SERVER_NUM servers and MS_WORKER_NUM workers of
# test_cmp_sparse_embedding.py, each in its own freshly created directory.
#
# Usage: bash shell_run_test.sh DEVICE_TARGET WORKER_NUM SERVER_NUM SCHED_HOST SCHED_PORT
execute_path=$(pwd)
# NOTE(review): script_self is never assigned in this script, so dirname
# receives an empty string and self_path becomes ".". The "../" path used below
# is therefore resolved relative to the per-role directory entered by cd, i.e.
# it points into ${execute_path}. Confirm the intended file layout before
# changing this.
self_path=$(dirname "${script_self}")
export MS_COMM_TYPE=zmq
export MS_SCHED_NUM=1
# NOTE(review): DEVICE_TARGET is read here but not forwarded to the Python script.
DEVICE_TARGET=$1
export MS_WORKER_NUM=$2
export MS_SERVER_NUM=$3
export MS_SCHED_HOST=$4
export MS_SCHED_PORT=$5

# launch_role ROLE PREFIX COUNT
# Exports MS_ROLE=ROLE, then for i in [0, COUNT) recreates
# ${execute_path}/PREFIX_<i>/, enters it and starts the test in the background.
launch_role() {
  local role=$1
  local prefix=$2
  local count=$3
  local i
  local work_dir
  export MS_ROLE=$role
  for ((i = 0; i < count; i++)); do
    work_dir="${execute_path}/${prefix}_${i}/"
    # Quoted so paths with spaces or glob characters are handled as one word.
    rm -rf "${work_dir}"
    mkdir "${work_dir}"
    cd "${work_dir}" || exit
    python "${self_path}/../test_cmp_sparse_embedding.py" &
  done
}

launch_role MS_SCHED sched 1
launch_role MS_PSERVER server "$MS_SERVER_NUM"
launch_role MS_WORKER worker "$MS_WORKER_NUM"

# Only the most recently started background process is waited for; its exit
# status becomes the script's exit status.
wait $!
exit $?
| @@ -0,0 +1,106 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import os | |||
| import argparse | |||
| import numpy as np | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.common import dtype as mstype | |||
| from mindspore.nn import TrainOneStepCell, WithLossCell | |||
| from mindspore.nn.optim import Adam | |||
| from mindspore.ops import operations as P | |||
| from mindspore.common.initializer import TruncatedNormal, initializer | |||
| from mindspore import Parameter | |||
# Command-line handling: only --device_target is recognised; unknown arguments
# are tolerated because parse_known_args() is used.
parser = argparse.ArgumentParser(description="test_sparse_embedding")
parser.add_argument("--device_target", type=str, default="Ascend")
args, _ = parser.parse_known_args()
device_target = args.device_target

# Graph mode on the selected device, with sparse support switched on.
context.set_context(mode=context.GRAPH_MODE, device_target=device_target, enable_sparse=True)
def fc_with_initialize(input_channels, out_channels):
    """Create an ``nn.Dense`` layer whose weight and bias both use ``weight_variable()``.

    Args:
        input_channels: first positional argument forwarded to ``nn.Dense``.
        out_channels: second positional argument forwarded to ``nn.Dense``.

    Returns:
        The constructed ``nn.Dense`` cell.
    """
    return nn.Dense(
        input_channels,
        out_channels,
        weight_variable(),  # weight initializer
        weight_variable(),  # bias initializer
    )
def weight_variable():
    """Return a fresh ``TruncatedNormal`` initializer constructed with 0.02."""
    initializer_arg = 0.02
    return TruncatedNormal(initializer_arg)
class LeNet5(nn.Cell):
    """Cell that casts its input to int32, looks it up in an embedding table,
    flattens the result and applies one dense layer.

    Args:
        num_class: output size passed to the dense layer.
    """

    def __init__(self, num_class=10):
        super().__init__()
        self.cast = P.Cast()
        self.flatten = nn.Flatten()
        # (16, 4) float32 table, normally initialised, registered as "embedding_table".
        table_init = initializer("normal", (16, 4), mstype.float32)
        self.embedding_table = Parameter(table_init, name="embedding_table")
        self.embedding = nn.EmbeddingLookup()
        # Created here but not used by construct().
        self.relu = nn.ReLU()
        self.fc = fc_with_initialize(12, num_class)

    def construct(self, x):
        """Run cast -> embedding lookup -> flatten -> dense on ``x``."""
        indices = self.cast(x, mstype.int32)
        looked_up = self.embedding(self.embedding_table, indices)
        flattened = self.flatten(looked_up)
        return self.fc(flattened)
def do_sparse_embedding(ps=False):
    """Train ``LeNet5`` for ten steps on random data and return the losses.

    Args:
        ps: when true, ``set_param_ps()`` is called on the network's embedding
            table before the optimizer is built.

    Returns:
        A list with one loss value (as returned by ``asnumpy()``) per step.
    """
    num_steps = 10
    network = LeNet5(10)
    if ps:
        network.embedding_table.set_param_ps()

    trainable = filter(lambda param: param.requires_grad, network.get_parameters())
    adam = Adam(trainable)
    adam.sparse_opt.add_prim_attr("primitive_target", "CPU")
    loss_fn = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
    step_cell = TrainOneStepCell(WithLossCell(network, loss_fn), adam)
    step_cell.set_train()

    step_losses = []
    for _ in range(num_steps):
        # Random int32 batch of shape (32, 3) and random int32 labels of length 32.
        batch = Tensor(np.random.randint(0, 15, (32, 3), np.int32))
        target = Tensor(np.random.randint(0, 9, (32), np.int32))
        step_losses.append(step_cell(batch, target).asnumpy())
    print(step_losses)
    return step_losses
envs = os.environ

if __name__ == "__main__":
    # First run: embedding table marked via set_param_ps(), seeded for reproducibility.
    np.random.seed(0)
    ps_loss = do_sparse_embedding(True)

    running_as_worker = envs.get("MS_ROLE") == "MS_WORKER"
    if running_as_worker:
        # Blank MS_ROLE for the second run and restore it afterwards
        # (presumably so the comparison run is not treated as a PS worker — TODO confirm).
        envs["MS_ROLE"] = ""
        np.random.seed(0)
        no_ps_loss = do_sparse_embedding()
        envs["MS_ROLE"] = "MS_WORKER"

        # Both runs used the same seed, so their losses are expected to match.
        assert np.allclose(ps_loss, no_ps_loss, rtol=1.0e-6, atol=1.0e-6)
| @@ -0,0 +1,25 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import os | |||
| import pytest | |||
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_cmp_sparse_embedding():
    """Run shell_run_test.sh through the shell and require a zero status from os.system."""
    command = "bash shell_run_test.sh Ascend 1 1 127.0.0.1 8081"
    exit_status = os.system(command)
    assert exit_status == 0
| @@ -0,0 +1,56 @@ | |||
| #!/bin/bash | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
# Start one scheduler, MS_SERVER_NUM servers and MS_WORKER_NUM workers of
# test_full_ps_lenet.py, each in its own freshly created directory.
#
# Usage: bash shell_run_test.sh DEVICE_TARGET DATASET_PATH WORKER_NUM SERVER_NUM SCHED_HOST SCHED_PORT
execute_path=$(pwd)
# NOTE(review): script_self is never assigned in this script, so dirname
# receives an empty string and self_path becomes ".". The "../" path used below
# is therefore resolved relative to the per-role directory entered by cd, i.e.
# it points into ${execute_path}. Confirm the intended file layout before
# changing this.
self_path=$(dirname "${script_self}")
export MS_COMM_TYPE=zmq
export MS_SCHED_NUM=1
DEVICE_TARGET=$1
DATASET_PATH=$2
export MS_WORKER_NUM=$3
export MS_SERVER_NUM=$4
export MS_SCHED_HOST=$5
export MS_SCHED_PORT=$6

# launch_role ROLE PREFIX COUNT
# Exports MS_ROLE=ROLE, then for i in [0, COUNT) recreates
# ${execute_path}/PREFIX_<i>/, enters it and starts the test in the background.
launch_role() {
  local role=$1
  local prefix=$2
  local count=$3
  local i
  local work_dir
  export MS_ROLE=$role
  for ((i = 0; i < count; i++)); do
    work_dir="${execute_path}/${prefix}_${i}/"
    # Quoted so paths with spaces or glob characters are handled as one word.
    rm -rf "${work_dir}"
    mkdir "${work_dir}"
    cd "${work_dir}" || exit
    python "${self_path}/../test_full_ps_lenet.py" --device_target="$DEVICE_TARGET" --dataset_path="$DATASET_PATH" &
  done
}

launch_role MS_SCHED sched 1
launch_role MS_PSERVER server "$MS_SERVER_NUM"
launch_role MS_WORKER worker "$MS_WORKER_NUM"

# Only the most recently started background process is waited for; its exit
# status becomes the script's exit status.
wait $!
exit $?
| @@ -0,0 +1,27 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import os | |||
| import pytest | |||
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_full_ps_ascend_lenet():
    """Run shell_run_test.sh through the shell and require a zero status from os.system."""
    command = "bash shell_run_test.sh Ascend /home/workspace/mindspore_dataset/mnist 1 1 127.0.0.1 8082"
    exit_status = os.system(command)
    assert exit_status == 0
| @@ -17,14 +17,16 @@ import os | |||
| # @pytest.mark.level0 | |||
| # @pytest.mark.platform_arm_ascend_training | |||
| # @pytest.mark.platform_x86_ascend_training | |||
| # @pytest.mark.env_onecard | |||
| def test_full_ps_ascend_lenet(): | |||
| return_code = os.system("bash run_full_ps_lenet.sh Ascend 1 1 127.0.0.1 8088") | |||
| # @pytest.mark.env_single | |||
| def test_multi_worker_full_ps_ascend_lenet(): | |||
| return_code = os.system("bash shell_run_test.sh Ascend 8 1 127.0.0.1 8088") | |||
| assert return_code == 0 | |||
| # @pytest.mark.level0 | |||
| # @pytest.mark.platform_x86_gpu_training | |||
| # @pytest.mark.platform_arm_ascend_training | |||
| # @pytest.mark.platform_x86_ascend_training | |||
| # @pytest.mark.env_onecard | |||
| def test_full_ps_gpu_lenet(): | |||
| return_code = os.system("bash run_full_ps_lenet.sh GPU 1 1 127.0.0.1 8088") | |||
| def test_full_ps_ascend_lenet(): | |||
| return_code = os.system("bash shell_run_test.sh Ascend 1 1 127.0.0.1 8088") | |||
| assert return_code == 0 | |||
| @@ -32,7 +32,7 @@ do | |||
| cd ${execute_path}/sched_$i/ || exit | |||
| export RANK_ID=$i | |||
| export DEVICE_ID=$i | |||
| python -s ${self_path}/../test_full_ps_lenet.py --device_target=$DEVICE_TARGET & | |||
| python ${self_path}/../test_multi_worker_full_ps_lenet.py --device_target=$DEVICE_TARGET & | |||
| done | |||
| export MS_ROLE=MS_PSERVER | |||
| @@ -43,7 +43,7 @@ do | |||
| cd ${execute_path}/server_$i/ || exit | |||
| export RANK_ID=$i | |||
| export DEVICE_ID=$i | |||
| python -s ${self_path}/../test_full_ps_lenet.py --device_target=$DEVICE_TARGET & | |||
| python ${self_path}/../test_multi_worker_full_ps_lenet.py --device_target=$DEVICE_TARGET & | |||
| done | |||
| export MS_ROLE=MS_WORKER | |||
| @@ -54,7 +54,7 @@ do | |||
| cd ${execute_path}/worker_$i/ || exit | |||
| export RANK_ID=$i | |||
| export DEVICE_ID=$i | |||
| python -s ${self_path}/../test_full_ps_lenet.py --device_target=$DEVICE_TARGET & | |||
| python ${self_path}/../test_multi_worker_full_ps_lenet.py --device_target=$DEVICE_TARGET & | |||
| done | |||
| wait $! | |||
| @@ -0,0 +1,107 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import argparse | |||
| import numpy as np | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore.common.initializer import TruncatedNormal | |||
| from mindspore import Tensor | |||
| from mindspore.nn import TrainOneStepCell, WithLossCell | |||
# Command-line handling: only --device_target is recognised; unknown arguments
# are tolerated because parse_known_args() is used.
parser = argparse.ArgumentParser(description="test_ps_lenet")
parser.add_argument("--device_target", type=str, default="Ascend")
args, _ = parser.parse_known_args()
device_target = args.device_target

# Graph mode on the selected device.
context.set_context(
    mode=context.GRAPH_MODE,
    device_target=device_target,
)
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """Create a bias-free ``nn.Conv2d`` with ``pad_mode="valid"``.

    Args:
        in_channels: forwarded to ``nn.Conv2d``.
        out_channels: forwarded to ``nn.Conv2d``.
        kernel_size: forwarded as ``kernel_size``.
        stride: forwarded as ``stride`` (default 1).
        padding: forwarded as ``padding`` (default 0).

    Returns:
        The constructed ``nn.Conv2d`` cell, with weights initialised by ``weight_variable()``.
    """
    conv_options = dict(
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        weight_init=weight_variable(),
        has_bias=False,
        pad_mode="valid",
    )
    return nn.Conv2d(in_channels, out_channels, **conv_options)
def fc_with_initialize(input_channels, out_channels):
    """Create an ``nn.Dense`` layer whose weight and bias both use ``weight_variable()``.

    Args:
        input_channels: first positional argument forwarded to ``nn.Dense``.
        out_channels: second positional argument forwarded to ``nn.Dense``.

    Returns:
        The constructed ``nn.Dense`` cell.
    """
    return nn.Dense(
        input_channels,
        out_channels,
        weight_variable(),  # weight initializer
        weight_variable(),  # bias initializer
    )
def weight_variable():
    """Return a fresh ``TruncatedNormal`` initializer constructed with 0.02."""
    initializer_arg = 0.02
    return TruncatedNormal(initializer_arg)
class LeNet5(nn.Cell):
    """Two conv/ReLU/max-pool stages followed by three dense layers.

    Args:
        num_class: output size of the last dense layer.
        channel: input channel count given to the first convolution.
    """

    def __init__(self, num_class=10, channel=3):
        super().__init__()
        self.num_class = num_class
        # Layers are created in the same order as they are applied in construct().
        self.conv1 = conv(channel, 6, 5)
        self.conv2 = conv(6, 16, 5)
        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
        self.fc2 = fc_with_initialize(120, 84)
        self.fc3 = fc_with_initialize(84, self.num_class)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        """Apply the convolutional stages, flatten, then the dense stack."""
        # Convolutional feature extractor.
        stage1 = self.max_pool2d(self.relu(self.conv1(x)))
        stage2 = self.max_pool2d(self.relu(self.conv2(stage1)))
        # Dense classifier head; no activation after the final layer.
        features = self.flatten(stage2)
        hidden1 = self.relu(self.fc1(features))
        hidden2 = self.relu(self.fc2(hidden1))
        return self.fc3(hidden2)
if __name__ == "__main__":
    epoch = 5
    np.random.seed(0)

    # Network with set_param_ps() applied to the whole cell.
    network = LeNet5(10)
    network.set_param_ps()

    criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
    net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
    net_with_criterion = WithLossCell(network, criterion)
    train_network = TrainOneStepCell(net_with_criterion, net_opt)
    train_network.set_train()

    # One training step per iteration on random float32 images of shape
    # (32, 3, 32, 32) and random int32 labels of length 32.
    losses = []
    for _ in range(epoch):
        data = Tensor(np.random.rand(32, 3, 32, 32).astype(np.float32))
        label = Tensor(np.random.randint(0, 9, (32)).astype(np.int32))
        losses.append(train_network(data, label).asnumpy())
    print(losses)