From: @big_hair_big_hair Reviewed-by: @kisnwang,@stsuteng Signed-off-by: @stsutengtags/v1.1.0
| @@ -295,12 +295,13 @@ class ParallelStrategySearchFactory: | |||
| newest_ckpt_file = find_newest_ckpt_file(ckpt_path) | |||
| return load_checkpoint(newest_ckpt_file) | |||
| def mindspore_auto_parallel_impl(self, dataset, epoch, device_num): | |||
| def mindspore_auto_parallel_impl(self, dataset, epoch, device_num, auto_parallel_search_mode="dynamic_programming"): | |||
| parallel_mode_net = self.parallel_mode_net | |||
| set_algo_parameters(fully_use_devices=False) | |||
| context.reset_auto_parallel_context() | |||
| context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, | |||
| device_num=device_num) | |||
| device_num=device_num, | |||
| auto_parallel_search_mode=auto_parallel_search_mode) | |||
| self.parallel_ckpt = self._model_train_and_save_ckpt(net=parallel_mode_net, | |||
| dataset=dataset, epoch=epoch) | |||
| context.reset_auto_parallel_context() | |||
| @@ -352,3 +353,30 @@ def test_auto_parallel_strategy_search_axis_1_basic(): | |||
| fact.mindspore_auto_parallel_impl(dataset=parallel_dataset, | |||
| epoch=2, device_num=8) | |||
| fact.checkpoint_cmp(inputs_np=inputs_np) | |||
| def test_auto_parallel_recursive_strategy_search_axis_1_basic(): | |||
| inputs_np = np.random.randn(32, 3, 224, 224).astype(np.float32) | |||
| standalone_mode_net = ParallelStrategySearchNet(in_channel=3, | |||
| out_channel=8, axis=1, input_shape=(32, 4, 110, -1), | |||
| mul_size=(32, 1, 220, 220), test_size=(32, 4, 110, 880), | |||
| prelu_size=(1,), transpose_b=True, matmul_size=(1, 12), | |||
| num_class=12) | |||
| context.reset_auto_parallel_context() | |||
| context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL) | |||
| parallel_mode_net = ParallelStrategySearchNet(in_channel=3, | |||
| out_channel=8, axis=1, input_shape=(32, 4, 110, -1), | |||
| mul_size=(32, 1, 220, 220), test_size=(32, 4, 110, 880), | |||
| prelu_size=(1,), transpose_b=True, matmul_size=(1, 12), | |||
| num_class=12) | |||
| standalone_dataset = FakeData(size=128, batch_size=32, | |||
| image_size=(3, 224, 224), num_classes=12) | |||
| fact = ParallelStrategySearchFactory(standalone_mode_net=standalone_mode_net, | |||
| parallel_mode_net=parallel_mode_net) | |||
| fact.mindspore_standalone_impl(dataset=standalone_dataset, epoch=2) | |||
| parallel_dataset = FakeData(size=128, batch_size=4, | |||
| image_size=(3, 224, 224), use_parallel=True, | |||
| num_classes=12) | |||
| fact.mindspore_auto_parallel_impl(dataset=parallel_dataset, | |||
| epoch=2, device_num=8, auto_parallel_search_mode="recursive_programming") | |||
| fact.checkpoint_cmp(inputs_np=inputs_np) | |||
| @@ -0,0 +1,52 @@ | |||
| #!/bin/bash | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| set -e | |||
| BASE_PATH=$(cd "$(dirname $0)"; pwd) | |||
| CONFIG_PATH=/home/workspace/mindspore_config | |||
| export DEVICE_NUM=8 | |||
| export RANK_SIZE=$DEVICE_NUM | |||
| source ${BASE_PATH}/env.sh | |||
| unset SLOG_PRINT_TO_STDOUT | |||
| export MINDSPORE_HCCL_CONFIG_PATH=$CONFIG_PATH/hccl/rank_table_${DEVICE_NUM}p.json | |||
| export LD_LIBRARY_PATH=/usr/local/Ascend/opp/op_impl/built-in/ai_core/tbe/op_tiling:${LD_LIBRARY_PATH} | |||
| export ASCEND_OPP_PATH=/usr/local/Ascend/opp/ | |||
| process_pid=() | |||
| for((i=0; i<$DEVICE_NUM; i++)); do | |||
| rm -rf ${BASE_PATH}/parallel_recursive_strategy_search${i} | |||
| mkdir ${BASE_PATH}/parallel_recursive_strategy_search${i} | |||
| cp -r ${BASE_PATH}/parallel_strategy_search.py ${BASE_PATH}/parallel_recursive_strategy_search${i}/ | |||
| cd ${BASE_PATH}/parallel_recursive_strategy_search${i} | |||
| export RANK_ID=${i} | |||
| export DEVICE_ID=${i} | |||
| echo "start training for device $i" | |||
| env > env$i.log | |||
| pytest -s -v parallel_strategy_search.py::test_auto_parallel_recursive_strategy_search_axis_1_basic > parallel_recursive_strategy_search$i.log 2>&1 & | |||
| process_pid[${i}]=`echo $!` | |||
| done | |||
| for((i=0; i<${DEVICE_NUM}; i++)); do | |||
| wait ${process_pid[i]} | |||
| status=`echo $?` | |||
| if [ "${status}" != "0" ]; then | |||
| echo "[ERROR] test_parallel_recursive_strategy_search failed. status: ${status}" | |||
| exit 1 | |||
| else | |||
| echo "[INFO] test_parallel_recursive_strategy_search success." | |||
| fi | |||
| done | |||
| exit 0 | |||
| @@ -34,7 +34,7 @@ for((i=0; i<$DEVICE_NUM; i++)); do | |||
| export DEVICE_ID=${i} | |||
| echo "start training for device $i" | |||
| env > env$i.log | |||
| pytest -s -v parallel_strategy_search.py > parallel_strategy_search$i.log 2>&1 & | |||
| pytest -s -v parallel_strategy_search.py::test_auto_parallel_strategy_search_axis_1_basic > parallel_strategy_search$i.log 2>&1 & | |||
| process_pid[${i}]=`echo $!` | |||
| done | |||
| @@ -0,0 +1,29 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import os | |||
| import pytest | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.env_single | |||
| def test_sit_parallel_recursive_strategy_search(): | |||
| sh_path = os.path.split(os.path.realpath(__file__))[0] | |||
| ret = os.system(f"sh {sh_path}/run_parallel_recursive_strategy_search.sh") | |||
| os.system( | |||
| f"grep -E 'ERROR|error' " | |||
| f"{sh_path}/parallel_recursive_strategy_search*/parallel_recursive_strategy_search*log -C 3") | |||
| assert ret == 0 | |||
| @@ -20,7 +20,7 @@ import pytest | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.env_single | |||
| def test_parallel_strategy_search(): | |||
| def test_sit_parallel_strategy_search(): | |||
| sh_path = os.path.split(os.path.realpath(__file__))[0] | |||
| ret = os.system(f"sh {sh_path}/run_parallel_strategy_search.sh") | |||
| os.system(f"grep -E 'ERROR|error' {sh_path}/parallel_strategy_search*/parallel_strategy_search*log -C 3") | |||