Merge pull request !1187 from yangzhenzhang/ckpt-and-restore-parameter-shape (tags/v0.3.0-alpha)
| @@ -22,12 +22,15 @@ | |||
| #include <memory> | |||
| #include <numeric> | |||
| #include <utility> | |||
| #include <map> | |||
| #include "common/utils.h" | |||
| #include "parallel/device_manager.h" | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| static std::map<std::string, std::vector<int>> param_shapes; | |||
| std::vector<std::string> PARALLEL_MODE_LIST = {STAND_ALONE, DATA_PARALLEL, HYBRID_PARALLEL, SEMI_AUTO_PARALLEL, | |||
| AUTO_PARALLEL}; | |||
| std::vector<std::string> STRATEGY_SEARCH_MODE_LIST = {DYNAMIC_PROGRAMMING, RECURSIVE_PROGRAMMING}; | |||
| @@ -136,5 +139,56 @@ const std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitSizes(const | |||
| } | |||
| return {}; | |||
| } | |||
| // Clear param_shapes before training in auto-parallel or semi-auto-parallel mode | |||
| void ParallelParameterContextInit(const FuncGraphPtr &func_graph) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| if (!func_graph->has_flag(AUTO_PARALLEL) || !func_graph->has_flag(TRAINING)) { | |||
| return; | |||
| } | |||
| param_shapes.clear(); | |||
| } | |||
| // Restore the parameters' shape for evaluation/prediction in auto-parallel or semi-auto-parallel mode | |||
| void ParallelParameterContextRestoreInNoTraining(const FuncGraphPtr &func_graph, const ParameterPtr ¶m_node, | |||
| AbstractBasePtr ptr) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(param_node); | |||
| MS_EXCEPTION_IF_NULL(ptr); | |||
| if (!func_graph->has_flag(AUTO_PARALLEL) || (func_graph->flags().count(TRAINING) == 0) || | |||
| func_graph->flags()[TRAINING]) { | |||
| return; | |||
| } | |||
| auto iter = param_shapes.find(param_node->name()); | |||
| if (iter == param_shapes.end()) { | |||
| MS_LOG(WARNING) << "Can not found the shape for parameter " << param_node->name(); | |||
| return; | |||
| } | |||
| std::vector<int> shape = iter->second; | |||
| std::shared_ptr<abstract::BaseShape> base_shape = std::make_shared<abstract::Shape>(shape); | |||
| ptr->set_shape(base_shape); | |||
| MS_LOG(DEBUG) << "The parameter name is " << param_node->name() << ", the shape is " << shape; | |||
| } | |||
| // Checkpoint the parameters' shape for training in auto-parallel or semi-auto-parallel mode | |||
| void ParallelParameterContextCkptInTraining(const FuncGraphPtr &func_graph, const ParameterPtr ¶m_node, | |||
| const AbstractBasePtr &ptr) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(param_node); | |||
| MS_EXCEPTION_IF_NULL(ptr); | |||
| if (!func_graph->has_flag(AUTO_PARALLEL) || !func_graph->has_flag(TRAINING)) { | |||
| return; | |||
| } | |||
| std::vector<int> shape = dyn_cast<abstract::Shape>(ptr->GetShapeTrack())->shape(); | |||
| auto ret = param_shapes.try_emplace(param_node->name(), shape); | |||
| if (!ret.second) { | |||
| MS_LOG(EXCEPTION) << "The shape for parameter name " << param_node->name() << " is existed"; | |||
| return; | |||
| } | |||
| MS_LOG(DEBUG) << "The parameter name is " << param_node->name() << ", the shape is " << shape; | |||
| } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -26,6 +26,9 @@ | |||
| #include "parallel/ops_info/ops_utils.h" | |||
| #include "parallel/status.h" | |||
| #include "utils/convert_utils.h" | |||
| #include "ir/anf.h" | |||
| #include "ir/func_graph.h" | |||
| #include "debug/info.h" | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| @@ -38,6 +41,8 @@ constexpr char SEMI_AUTO_PARALLEL[] = "semi_auto_parallel"; | |||
| constexpr char DYNAMIC_PROGRAMMING[] = "dynamic_programming"; | |||
| constexpr char RECURSIVE_PROGRAMMING[] = "recursive_programming"; | |||
| constexpr char TRAINING[] = "training"; | |||
| class ParallelContext { | |||
| public: | |||
| ~ParallelContext() = default; | |||
| @@ -114,6 +119,12 @@ class ParallelContext { | |||
| std::string strategy_ckpt_load_file_; | |||
| std::string strategy_ckpt_save_file_; | |||
| }; | |||
// Clear the recorded parameter shapes before a training compilation
// (auto-parallel / semi-auto-parallel only).
void ParallelParameterContextInit(const FuncGraphPtr &func_graph);
// Restore a parameter's recorded full shape into `ptr` when compiling for
// evaluation/prediction (TRAINING flag present and false).
void ParallelParameterContextRestoreInNoTraining(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
                                                 AbstractBasePtr ptr);
// Record a parameter's shape during a training compilation so it can be
// restored later for evaluation.
void ParallelParameterContextCkptInTraining(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
                                            const AbstractBasePtr &ptr);
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -25,6 +25,7 @@ | |||
| #include "ir/func_graph_cloner.h" | |||
| #include "parallel/costmodel_context.h" | |||
| #include "parallel/context.h" | |||
| #include "pipeline/pass.h" | |||
| #include "pipeline/parse/parse_base.h" | |||
| #include "pipeline/parse/data_converter.h" | |||
| @@ -217,6 +218,8 @@ bool AbstractSpecializeAction(const ResourcePtr &res) { | |||
| FuncGraphPtr func_graph = res->func_graph(); | |||
| abstract::AbstractBasePtrList args_spec = res->args_spec(); | |||
| parallel::ParallelParameterContextInit(func_graph); | |||
| // suppose that there is not KeywordArgument for the top graph | |||
| // get the hyper parameter | |||
| for (const auto ¶m : func_graph->parameters()) { | |||
| @@ -224,7 +227,10 @@ bool AbstractSpecializeAction(const ResourcePtr &res) { | |||
| if (param_node->has_default()) { | |||
| AbstractBasePtr ptr = | |||
| abstract::FromValue(parse::data_converter::PyDataToValue(param_node->default_param()), true); | |||
| parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, ptr); | |||
| args_spec.push_back(ptr); | |||
| parallel::ParallelParameterContextCkptInTraining(func_graph, param_node, ptr); | |||
| } | |||
| } | |||
| // Analyze | |||
| @@ -379,7 +379,7 @@ class _Executor: | |||
| self._params_init_data(obj, params) | |||
| if not enable_debug_runtime or enable_ge: | |||
| if auto_parallel_mode: | |||
| if auto_parallel_mode and "train" in phase: | |||
| obj.parameter_layout_dict = self._executor.get_parameter_layout(phase) | |||
| obj.load_parameter_slice(params) | |||
| @@ -47,7 +47,7 @@ def test_get_parameter_layout(): | |||
| net = Net(strategy1, strategy2, weight) | |||
| net.set_auto_parallel() | |||
| exe = me._executor | |||
| exe.compile(net, x, auto_parallel_mode=True) | |||
| exe.compile(net, x, phase='train', auto_parallel_mode=True) | |||
| x_layout = [[2, 4], [1, -1], [16, 32]] # device_arrangement = [2, 4], tensor_map = [1, -1] | |||
| weight_layout = [[2, 4], [0, -1], [16, 32]] # device_arrangement = [2, 4], tensor_map = [0, -1] | |||
| expect_dict = {'x': x_layout, 'w1': weight_layout} | |||
| @@ -0,0 +1,68 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import mindspore as ms | |||
| from mindspore import context, Tensor, Parameter | |||
| from mindspore.nn import Cell, TrainOneStepCell, Momentum | |||
| from mindspore.ops import operations as P | |||
| from mindspore.common.api import _executor | |||
class Net(Cell):
    """Element-wise multiply the input by a weight parameter, then negate."""

    def __init__(self, mul_weight, strategy1=None, strategy2=None):
        super().__init__()
        # Parallel slicing strategies are attached per-op via set_strategy.
        self.mul = P.Mul().set_strategy(strategy1)
        self.neg = P.Neg().set_strategy(strategy2)
        self.mul_weight = Parameter(mul_weight, "w1")

    def construct(self, x, b):
        return self.neg(self.mul(x, self.mul_weight))
class EvalNet(Cell):
    """Wrap a network and apply ReLU to its output for evaluation."""

    def __init__(self, network, strategy2=None):
        super().__init__()
        self.network = network
        self.relu = P.ReLU().set_strategy(strategy2)

    def construct(self, x, b):
        return self.relu(self.network(x, b))
# Shared test fixtures: 8x8 all-ones float32 tensors for input, weight and bias.
_x = Tensor(np.ones([8, 8]), dtype=ms.float32)
_w1 = Tensor(np.ones([8, 8]), dtype=ms.float32)
_b = Tensor(np.ones([8, 8]), dtype=ms.float32)
def test_train_and_eval():
    """Compile the same parameter-sharing network twice under semi-auto
    parallel mode: once in the 'train' phase, then (wrapped in EvalNet)
    in the 'eval' phase, exercising the shape checkpoint/restore path."""
    context.set_context(save_graphs=True, mode=0)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16)
    mul_strategy = ((4, 4), (4, 4))
    elem_strategy = ((4, 4), )
    train_net = Net(_w1, mul_strategy, elem_strategy)
    eval_net = EvalNet(train_net, strategy2=elem_strategy)

    # Train-phase compilation records the parameters' full shapes.
    train_net.set_train()
    train_net.set_auto_parallel()
    _executor.compile(train_net, _x, _b, phase='train', auto_parallel_mode=True)

    # Eval-phase compilation must restore those shapes.
    eval_net.set_train(mode=False)
    eval_net.set_auto_parallel()
    _executor.compile(eval_net, _x, _b, phase='eval', auto_parallel_mode=True)

    context.reset_auto_parallel_context()