Browse Source

use param name as the key of strategy checkpoint

tags/v0.3.0-alpha
yao_yf 6 years ago
parent
commit
5a6540450e
6 changed files with 25 additions and 17 deletions
  1. +2
    -0
      mindspore/ccsrc/parallel/ops_info/ops_utils.h
  2. +1
    -2
      mindspore/ccsrc/parallel/step_auto_parallel.cc
  3. +13
    -13
      mindspore/ccsrc/parallel/step_parallel.cc
  4. +1
    -1
      mindspore/ccsrc/parallel/step_parallel.h
  5. +0
    -1
      mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h
  6. +8
    -0
      tests/ut/python/parallel/test_strategy_checkpoint.py

+ 2
- 0
mindspore/ccsrc/parallel/ops_info/ops_utils.h View File

@@ -61,6 +61,8 @@ constexpr char CROSS_BATCH[] = "cross_batch";
constexpr char STEP_PARALLEL_BEGIN[] = "step_parallel_begin";
constexpr char STEP_PARALLEL_END[] = "step_parallel_end";
constexpr char STEP_AUTO_PARALLEL_BEGIN[] = "step_auto_parallel_begin.dot";
constexpr char REQUIRES_GRAD[] = "requires_grad";
constexpr char PARAM_NAME[] = "name";

constexpr char RELU_TYPE[] = "relu";
constexpr char RELU6_TYPE[] = "relu6";


+ 1
- 2
mindspore/ccsrc/parallel/step_auto_parallel.cc View File

@@ -387,8 +387,7 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
operator_info->set_outputs_dtype(cnode->Type());
operator_info->set_cnode(cnode);
// key of strategy map
std::string instance_name = prim->instance_name();
std::string strategy_key_name = cnode->scope()->name() + std::string(CONNSYMBOL) + instance_name;
std::string strategy_key_name = NodeParameterName(cnode);
bool load_strategy_from_ckpt =
StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map->find(strategy_key_name) != stra_map->end();
// If no strategy has been configured for this operator, then candidate strategies are generated for


+ 13
- 13
mindspore/ccsrc/parallel/step_parallel.cc View File

@@ -1423,11 +1423,9 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
}
// load strategy checkpoint
// key of strategy map
std::string instance_name = prim->instance_name();
std::string strategy_key_name = cnode->scope()->name() + std::string(CONNSYMBOL) + instance_name;
std::string strategy_key_name = NodeParameterName(cnode);
bool load_strategy_from_ckpt =
StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end();

if (!StrategyFound(attrs) && !load_strategy_from_ckpt) {
MS_LOG(INFO) << "ExtractInformation: the strategy of node " << node->ToString() << " prim " << prim->name()
<< " is empty, using batch parallel";
@@ -2038,17 +2036,20 @@ void HandleSymbolicKeyInstance(const FuncGraphPtr &root, const std::vector<AnfNo
}
}

bool NodeWithParameter(const CNodePtr &node) {
std::string NodeParameterName(const CNodePtr &node) {
std::vector<AnfNodePtr> node_inputs{node->inputs()};
for (auto input : node_inputs) {
if (input->isa<Parameter>()) {
auto input_parameter = input->cast<ParameterPtr>();
if (input_parameter->has_default()) {
return py::cast<bool>(parse::python_adapter::GetPyObjAttr(input_parameter->default_param(), "requires_grad"));
if (py::cast<bool>(parse::python_adapter::GetPyObjAttr(input_parameter->default_param(), REQUIRES_GRAD))) {
return py::cast<std::string>(
parse::python_adapter::GetPyObjAttr(input_parameter->default_param(), PARAM_NAME));
}
}
}
}
return false;
return "";
}

void CheckpointStrategy(const FuncGraphPtr &func_graph) {
@@ -2060,21 +2061,20 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) {
for (auto &node : all_nodes) {
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0)) || !NodeWithParameter(cnode)) {
if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
continue;
}
std::string param_name = NodeParameterName(cnode);
if (param_name.empty()) {
continue;
}
PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
MS_EXCEPTION_IF_NULL(prim);
OperatorInfoPtr operator_info = cnode->operator_info();
if (operator_info) {
if (prim->instance_name().empty()) {
MS_LOG(EXCEPTION) << "Node with parameter to checkpoint strategy needs instance name";
}
std::string instance_name = prim->instance_name();
StrategyPtr strategyPtr = operator_info->strategy();
MS_EXCEPTION_IF_NULL(node->scope());
std::string node_name = node->scope()->name() + std::string(CONNSYMBOL) + instance_name;
stra_map[node_name] = strategyPtr;
stra_map[param_name] = strategyPtr;
}
}
if (StrategyCheckpoint::GetInstance().Save(stra_map) != SUCCESS) {


+ 1
- 1
mindspore/ccsrc/parallel/step_parallel.h View File

@@ -135,7 +135,7 @@ void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes);
void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
const FuncGraphManagerPtr &manager);

bool NodeWithParameter(const CNodePtr &node);
std::string NodeParameterName(const CNodePtr &node);

void CheckpointStrategy(const FuncGraphPtr &func_graph);



+ 0
- 1
mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h View File

@@ -25,7 +25,6 @@

namespace mindspore {
namespace parallel {

using StrategyMap = std::unordered_map<std::string, StrategyPtr>;
class StrategyCheckpoint {
public:


+ 8
- 0
tests/ut/python/parallel/test_strategy_checkpoint.py View File

@@ -59,6 +59,7 @@ def test_six_matmul_save():
self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3")
self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4")
self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5")
self.weight6 = Parameter(Tensor(np.ones([32, 128]), dtype=ms.float32), name="weight6")

def construct(self, x1, x6):
out = self.matmul1(x1, self.weight1)
@@ -66,6 +67,7 @@ def test_six_matmul_save():
out = self.matmul3(out, self.weight3)
out = self.matmul4(out, self.weight4)
out = self.matmul5(out, self.weight5)
out = out + self.weight6
out = self.matmul6(out, x6)
return out

@@ -118,12 +120,14 @@ def test_six_matmul_load():
self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3")
self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4")
self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5")
self.weight6 = Parameter(Tensor(np.ones([32, 128]), dtype=ms.float32), name="weight6")

def construct(self, x1, x6, x7):
out = self.matmul1(x1, self.weight1)
out = self.matmul3(out, self.weight3)
out = self.matmul4(out, self.weight4)
out = self.matmul5(out, self.weight5)
out = out + self.weight6
out = self.matmul6(out, x6)
out = self.matmul7(out, x7)
return out
@@ -179,6 +183,7 @@ def test_six_matmul_save_auto():
self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3")
self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4")
self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5")
self.weight6 = Parameter(Tensor(np.ones([32, 128]), dtype=ms.float32), name="weight6")

def construct(self, x1, x6):
out = self.matmul1(x1, self.weight1)
@@ -186,6 +191,7 @@ def test_six_matmul_save_auto():
out = self.matmul3(out, self.weight3)
out = self.matmul4(out, self.weight4)
out = self.matmul5(out, self.weight5)
out = out + self.weight6
out = self.matmul6(out, x6)
return out

@@ -232,12 +238,14 @@ def test_six_matmul_load_auto():
self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3")
self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4")
self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5")
self.weight6 = Parameter(Tensor(np.ones([32, 128]), dtype=ms.float32), name="weight6")

def construct(self, x1, x6, x7):
out = self.matmul1(x1, self.weight1)
out = self.matmul3(out, self.weight3)
out = self.matmul4(out, self.weight4)
out = self.matmul5(out, self.weight5)
out = out + self.weight6
out = self.matmul6(out, x6)
out = self.matmul7(out, x7)
return out


Loading…
Cancel
Save