| @@ -61,6 +61,8 @@ constexpr char CROSS_BATCH[] = "cross_batch"; | |||
| constexpr char STEP_PARALLEL_BEGIN[] = "step_parallel_begin"; | |||
| constexpr char STEP_PARALLEL_END[] = "step_parallel_end"; | |||
| constexpr char STEP_AUTO_PARALLEL_BEGIN[] = "step_auto_parallel_begin.dot"; | |||
| constexpr char REQUIRES_GRAD[] = "requires_grad"; | |||
| constexpr char PARAM_NAME[] = "name"; | |||
| constexpr char RELU_TYPE[] = "relu"; | |||
| constexpr char RELU6_TYPE[] = "relu6"; | |||
| @@ -387,8 +387,7 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr & | |||
| operator_info->set_outputs_dtype(cnode->Type()); | |||
| operator_info->set_cnode(cnode); | |||
| // key of strategy map | |||
| std::string instance_name = prim->instance_name(); | |||
| std::string strategy_key_name = cnode->scope()->name() + std::string(CONNSYMBOL) + instance_name; | |||
| std::string strategy_key_name = NodeParameterName(cnode); | |||
| bool load_strategy_from_ckpt = | |||
| StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map->find(strategy_key_name) != stra_map->end(); | |||
| // If no strategy has been configured for this operator, then candidate strategies are generated for | |||
| @@ -1423,11 +1423,9 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) { | |||
| } | |||
| // load strategy checkpoint | |||
| // key of strategy map | |||
| std::string instance_name = prim->instance_name(); | |||
| std::string strategy_key_name = cnode->scope()->name() + std::string(CONNSYMBOL) + instance_name; | |||
| std::string strategy_key_name = NodeParameterName(cnode); | |||
| bool load_strategy_from_ckpt = | |||
| StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end(); | |||
| if (!StrategyFound(attrs) && !load_strategy_from_ckpt) { | |||
| MS_LOG(INFO) << "ExtractInformation: the strategy of node " << node->ToString() << " prim " << prim->name() | |||
| << " is empty, using batch parallel"; | |||
| @@ -2038,17 +2036,20 @@ void HandleSymbolicKeyInstance(const FuncGraphPtr &root, const std::vector<AnfNo | |||
| } | |||
| } | |||
| bool NodeWithParameter(const CNodePtr &node) { | |||
| std::string NodeParameterName(const CNodePtr &node) { | |||
| std::vector<AnfNodePtr> node_inputs{node->inputs()}; | |||
| for (auto input : node_inputs) { | |||
| if (input->isa<Parameter>()) { | |||
| auto input_parameter = input->cast<ParameterPtr>(); | |||
| if (input_parameter->has_default()) { | |||
| return py::cast<bool>(parse::python_adapter::GetPyObjAttr(input_parameter->default_param(), "requires_grad")); | |||
| if (py::cast<bool>(parse::python_adapter::GetPyObjAttr(input_parameter->default_param(), REQUIRES_GRAD))) { | |||
| return py::cast<std::string>( | |||
| parse::python_adapter::GetPyObjAttr(input_parameter->default_param(), PARAM_NAME)); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| return false; | |||
| return ""; | |||
| } | |||
| void CheckpointStrategy(const FuncGraphPtr &func_graph) { | |||
| @@ -2060,21 +2061,20 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) { | |||
| for (auto &node : all_nodes) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0)) || !NodeWithParameter(cnode)) { | |||
| if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) { | |||
| continue; | |||
| } | |||
| std::string param_name = NodeParameterName(cnode); | |||
| if (param_name.empty()) { | |||
| continue; | |||
| } | |||
| PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0)); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| OperatorInfoPtr operator_info = cnode->operator_info(); | |||
| if (operator_info) { | |||
| if (prim->instance_name().empty()) { | |||
| MS_LOG(EXCEPTION) << "Node with parameter to checkpoint strategy needs instance name"; | |||
| } | |||
| std::string instance_name = prim->instance_name(); | |||
| StrategyPtr strategyPtr = operator_info->strategy(); | |||
| MS_EXCEPTION_IF_NULL(node->scope()); | |||
| std::string node_name = node->scope()->name() + std::string(CONNSYMBOL) + instance_name; | |||
| stra_map[node_name] = strategyPtr; | |||
| stra_map[param_name] = strategyPtr; | |||
| } | |||
| } | |||
| if (StrategyCheckpoint::GetInstance().Save(stra_map) != SUCCESS) { | |||
| @@ -135,7 +135,7 @@ void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes); | |||
| void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes, | |||
| const FuncGraphManagerPtr &manager); | |||
| bool NodeWithParameter(const CNodePtr &node); | |||
| std::string NodeParameterName(const CNodePtr &node); | |||
| void CheckpointStrategy(const FuncGraphPtr &func_graph); | |||
| @@ -25,7 +25,6 @@ | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| using StrategyMap = std::unordered_map<std::string, StrategyPtr>; | |||
| class StrategyCheckpoint { | |||
| public: | |||
| @@ -59,6 +59,7 @@ def test_six_matmul_save(): | |||
| self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3") | |||
| self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4") | |||
| self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5") | |||
| self.weight6 = Parameter(Tensor(np.ones([32, 128]), dtype=ms.float32), name="weight6") | |||
| def construct(self, x1, x6): | |||
| out = self.matmul1(x1, self.weight1) | |||
| @@ -66,6 +67,7 @@ def test_six_matmul_save(): | |||
| out = self.matmul3(out, self.weight3) | |||
| out = self.matmul4(out, self.weight4) | |||
| out = self.matmul5(out, self.weight5) | |||
| out = out + self.weight6 | |||
| out = self.matmul6(out, x6) | |||
| return out | |||
| @@ -118,12 +120,14 @@ def test_six_matmul_load(): | |||
| self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3") | |||
| self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4") | |||
| self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5") | |||
| self.weight6 = Parameter(Tensor(np.ones([32, 128]), dtype=ms.float32), name="weight6") | |||
| def construct(self, x1, x6, x7): | |||
| out = self.matmul1(x1, self.weight1) | |||
| out = self.matmul3(out, self.weight3) | |||
| out = self.matmul4(out, self.weight4) | |||
| out = self.matmul5(out, self.weight5) | |||
| out = out + self.weight6 | |||
| out = self.matmul6(out, x6) | |||
| out = self.matmul7(out, x7) | |||
| return out | |||
| @@ -179,6 +183,7 @@ def test_six_matmul_save_auto(): | |||
| self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3") | |||
| self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4") | |||
| self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5") | |||
| self.weight6 = Parameter(Tensor(np.ones([32, 128]), dtype=ms.float32), name="weight6") | |||
| def construct(self, x1, x6): | |||
| out = self.matmul1(x1, self.weight1) | |||
| @@ -186,6 +191,7 @@ def test_six_matmul_save_auto(): | |||
| out = self.matmul3(out, self.weight3) | |||
| out = self.matmul4(out, self.weight4) | |||
| out = self.matmul5(out, self.weight5) | |||
| out = out + self.weight6 | |||
| out = self.matmul6(out, x6) | |||
| return out | |||
| @@ -232,12 +238,14 @@ def test_six_matmul_load_auto(): | |||
| self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3") | |||
| self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4") | |||
| self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5") | |||
| self.weight6 = Parameter(Tensor(np.ones([32, 128]), dtype=ms.float32), name="weight6") | |||
| def construct(self, x1, x6, x7): | |||
| out = self.matmul1(x1, self.weight1) | |||
| out = self.matmul3(out, self.weight3) | |||
| out = self.matmul4(out, self.weight4) | |||
| out = self.matmul5(out, self.weight5) | |||
| out = out + self.weight6 | |||
| out = self.matmul6(out, x6) | |||
| out = self.matmul7(out, x7) | |||
| return out | |||