
pipeline_opt_detection

feature/build-system-rewrite
wangshengnan12@huawei.com 4 years ago
parent commit acbefd80ea
4 changed files with 485 additions and 2 deletions
  1. +87  -0  mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc
  2. +4   -1  mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.h
  3. +1   -1  mindspore/ccsrc/pipeline/jit/pipeline_split.cc
  4. +393 -0  tests/ut/python/parallel/test_pipeline_opt_detection.py

+87 -0  mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc

@@ -25,6 +25,7 @@
#include "frontend/parallel/auto_parallel/graph_costmodel.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "frontend/parallel/group_manager.h"
#include "frontend/parallel/parameter_manager.h"
#include "frontend/parallel/context.h"
#include "frontend/parallel/step_parallel.h"
#include "frontend/parallel/node_check.h"
@@ -782,6 +783,7 @@ AnfNodePtr PipelineTransformer::HandleParameterGraph(const AnfNodePtr &node, con
    manager_->SetEdge(use_node, SizeToInt(pos), recv);
    return nullptr;
  }
  parameter_color_map[argument].insert(user_stage);
  return InsertReceive(main_graph_, argument, use_node, SizeToInt(pos), user_stage, stage, micro, parameter);
}
  // insert send
@@ -974,7 +976,92 @@ void PipelineTransformer::CoverSensShape() {
  manager_->Replace(sens_cnode, new_sens_node);
}

void PipelineTransformer::RedundancyNode(const AnfNodePtr &node,
                                         mindspore::HashMap<CNodePtr, std::vector<AnfNodePtr>> *make_tuple_map) {
  auto node_users = manager_->node_users()[node];
  for (auto &node_user_pair : node_users) {
    auto cnode = node_user_pair.first->cast<CNodePtr>();
    // node->UpdateState: replace the node with the U monad.
    auto fg = cnode->func_graph();
    MS_EXCEPTION_IF_NULL(fg);
    if (fg->stage() != -1) {
      continue;
    }
    if (IsPrimitiveCNode(cnode, prim::kPrimUpdateState)) {
      auto u_node = NewValueNode(kUMonad);
      manager_->SetEdge(cnode, node_user_pair.second, u_node);
      continue;
    }
    // node->MakeTuple: record it in a map; the entries are deleted together later.
    if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) {
      if (make_tuple_map->find(cnode) == (*make_tuple_map).end()) {
        (*make_tuple_map)[cnode] = {node};
      } else {
        (*make_tuple_map)[cnode].push_back(node);
      }
    } else {
      RedundancyNode(node_user_pair.first, make_tuple_map);
    }
  }
}
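
RedundancyNode walks forward through every user of a redundant node: an UpdateState user has its edge rewired to a fresh U monad, a MakeTuple user is recorded in make_tuple_map for batched cleanup, and any other user is recursed into. A minimal Python sketch of that dispatch, assuming a hypothetical dict-based node model ("prim" and "users" keys) standing in for CNode and the manager's node-users map:

def redundancy_node(node, make_tuple_map):
    # Each node is a dict: {"name": str, "prim": str, "users": [node, ...]}.
    for user in node["users"]:
        if user["prim"] == "UpdateState":
            continue  # the C++ pass rewires this edge to a U monad instead
        if user["prim"] == "MakeTuple":
            # Record for unified deletion later, as in the C++ map.
            make_tuple_map.setdefault(user["name"], []).append(node["name"])
        else:
            redundancy_node(user, make_tuple_map)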

bool PipelineTransformer::IsRedundancyParameter(const AnfNodePtr &parameter) {
  // A redundant parameter is one that belongs only to other stages; this also covers its cloned parameters.
  auto parameters = root_->parameters();
  auto param_ptr = parameter->cast<ParameterPtr>();
  MS_EXCEPTION_IF_NULL(param_ptr);
  if (!param_ptr->has_default()) {
    return false;
  }
  auto param_name = param_ptr->name();
  for (auto &param : parameters) {
    if (ParameterIsCloned(param)) {
      continue;
    }
    auto non_cloned_param = param->cast<ParameterPtr>();
    if (param_name.find(non_cloned_param->name()) == std::string::npos) {
      continue;
    }
    auto stage_set = parameter_color_map.at(param);
    if (stage_set.empty()) {
      return false;
    }
    return !stage_set.count(stage_);
  }
  return false;
}
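
Once the matching non-cloned parameter is found, the decision reduces to a set-membership test on parameter_color_map. A hedged sketch of just that test (is_redundant is a hypothetical name; stage_set mirrors the C++ variable):

def is_redundant(stage_set, stage):
    # No coloring information: conservatively keep the parameter.
    if not stage_set:
        return False
    # Redundant when the current stage never uses the parameter.
    return stage not in stage_set

For example, with stage_set = {1} a rank in stage 0 reports the parameter as redundant, which is exactly what the stage-0 tests below assert.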

void PipelineTransformer::ElimParameter() {
  auto parameters = root_->parameters();
  mindspore::HashMap<CNodePtr, std::vector<AnfNodePtr>> make_tuple_map;
  for (auto &parameter : parameters) {
    if (!IsRedundancyParameter(parameter)) {
      continue;
    }
    RedundancyNode(parameter, &make_tuple_map);
  }
  for (auto &temp : make_tuple_map) {
    auto make_tuple = temp.first;
    auto fg = make_tuple->func_graph();
    MS_EXCEPTION_IF_NULL(fg);
    auto remove_vector = temp.second;
    if (remove_vector.empty()) {
      continue;
    }
    auto make_tuple_inputs = make_tuple->inputs();
    std::vector<AnfNodePtr> new_inputs;
    for (auto &input : make_tuple_inputs) {
      if (std::find(remove_vector.begin(), remove_vector.end(), input) == remove_vector.end()) {
        new_inputs.push_back(input);
      }
    }
    auto new_make_tuple = fg->NewCNode(new_inputs);
    manager_->Replace(make_tuple, new_make_tuple);
  }
}
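
The MakeTuple rewrite in ElimParameter keeps every input that was not collected by RedundancyNode and rebuilds the node from the survivors. The filtering is equivalent to this sketch (plain strings stand in for AnfNodePtr inputs; illustrative only):

make_tuple_inputs = ["MakeTuple", "param_a", "param_b", "param_c"]
remove_vector = ["param_b"]  # redundant inputs collected by RedundancyNode
new_inputs = [x for x in make_tuple_inputs if x not in remove_vector]
assert new_inputs == ["MakeTuple", "param_a", "param_c"]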

void PipelineTransformer::ModifyParameterList() {
  ElimParameter();
  auto parameters = root_->parameters();
  std::vector<AnfNodePtr> parameter_list;
  for (auto &parameter : parameters) {


+4 -1  mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.h

@@ -58,7 +58,7 @@ class PipelineTransformer {
  void ParameterColoring();
  void CoverSensShape();
  void ElimGraphStage();
  void ElimParameter();
  void ModifyParameterList();

 private:
  void CreateForwardGroup();
@@ -87,6 +87,9 @@ class PipelineTransformer {
  CNodePtr GraphOutNode(const AnfNodePtr &node, int tuple_index);
  bool IsPipelineCareNode(const CNodePtr &cnode);
  std::pair<CNodePtr, FuncGraphPtr> FindSensNode();
  void RedundancyNode(const AnfNodePtr &node, mindspore::HashMap<CNodePtr, std::vector<AnfNodePtr>> *make_tuple_map);
  bool IsRedundancyParameter(const AnfNodePtr &parameter);
  void ElimParameter();
  FuncGraphManagerPtr manager_;
  int64_t stage_;
  FuncGraphPtr root_;


+1 -1  mindspore/ccsrc/pipeline/jit/pipeline_split.cc

@@ -263,8 +263,8 @@ bool PipelineSplit(const ResourcePtr &res) {
    transformer->CoverSensShape();
  }
  // step6: eliminate graph stages and unused parameters
  transformer->ModifyParameterList();
  transformer->ElimGraphStage();
  transformer->ElimParameter();
  return true;
}
}  // namespace pipeline
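
Taken together with the header change above, ElimParameter is now private and runs as the first step of ModifyParameterList, so PipelineSplit no longer calls it directly. In hedged Python-syntax pseudocode (names copied from the diffs; not an actual API):

def step6(transformer):
    transformer.ModifyParameterList()  # internally calls ElimParameter() first
    transformer.ElimGraphStage()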


+393 -0  tests/ut/python/parallel/test_pipeline_opt_detection.py

@@ -0,0 +1,393 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import context
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.train.model import Model
from mindspore.nn.wrap.cell_wrapper import PipelineCell, MicroBatchInterleaved


class DatasetLenet():
    def __init__(self, data, label, length=3):
        self.data = data
        self.label = label
        self.index = 1
        self.length = length

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        return self.data, self.label

    def reset(self):
        self.index = 0

    @staticmethod
    def get_dataset_size():
        return 32

    @staticmethod
    def get_repeat_count():
        return 1

    @staticmethod
    def get_batch_size():
        return 32

    def create_tuple_iterator(self, num_epochs=1, do_copy=True):
        return self
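
DatasetLenet is a minimal in-memory stand-in for a MindSpore dataset: it exposes just the iterator, reset, and size hooks that Model.train touches in non-sink mode. A usage sketch, relying on the module-level imports above (shapes copied from the tests below):

data = Tensor(np.ones([32, 64]), dtype=ms.float32)
label = Tensor(np.ones([64, 64]), dtype=ms.float32)
dataset = DatasetLenet(data, label, length=3)
for d, l in dataset:  # yields (data, label) pairs until index reaches length
    assert d.shape == (32, 64) and l.shape == (64, 64)
dataset.reset()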


class MatMulCell(nn.Cell):
    def __init__(self, strategy1, strategy2, param=None, dtype=ms.float32):
        super().__init__()
        self.param = Parameter(initializer("zeros", [64, 64]), name="param")
        if param is not None:
            self.param = param
        self.param1 = Parameter(initializer("zeros", [64, 64]), name="param1")
        self.matmul = P.MatMul().shard(strategy1)
        self.matmul1 = P.MatMul().shard(strategy2)
        self.cast = P.Cast()
        self.dtype = dtype

    def construct(self, x):
        out = self.matmul(self.cast(x, self.dtype), self.cast(self.param, self.dtype))
        out = self.matmul1(out, self.cast(self.param1, self.dtype))
        return out


class Net(nn.Cell):
    def __init__(self, strategy1, strategy2, param=None, dtype=ms.float32):
        super().__init__()
        self.block = nn.CellList()
        for i in range(2):
            cell = MatMulCell(strategy1, strategy2, param, dtype)
            cell.pipeline_stage = i
            self.block.append(cell)

    def construct(self, x):
        for i in range(2):
            x = self.block[i](x)
        return x


class PipelineSplit(nn.Cell):
    def __init__(self, strategy1, strategy2, dtype=ms.float32):
        super().__init__()
        self.cell = Net(strategy1, strategy2, dtype=dtype)

    def construct(self, x, label):
        x = self.cell(x)
        return x


class PipelineSplitSharedParam(nn.Cell):
    def __init__(self, strategy1, strategy2, dtype=ms.float32):
        super().__init__()
        self.param = Parameter(initializer("zeros", [64, 64]), name="param")
        self.cell = Net(strategy1, strategy2, self.param, dtype)

    def construct(self, x, label):
        x = self.cell(x)
        return x


def test_pipeline_split_stage0():
    """
    Feature: pipeline stage0 + opt detection
    Description: pipeline opt detection
    Expectation: success
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0, pipeline_stages=2)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    data = Tensor(np.ones([32, 64]), dtype=ms.float32)
    label = Tensor(np.ones([64, 64]), dtype=ms.float32)
    strategy1 = ((4, 1), (1, 1))
    strategy2 = ((2, 1), (1, 1))
    net = PipelineCell(PipelineSplit(strategy1, strategy2), 4)
    params = net.trainable_params()
    dataset = DatasetLenet(data, label, 3)
    optimizer = nn.Lamb(params, learning_rate=0.01)
    model = Model(net, optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)
    for _, param in model._train_network.parameters_and_names():
        assert param.name != "cell.block.1.param"
        assert param.name != "cell.block.1.param1"


def test_pipeline_split_stage1():
    """
    Feature: pipeline stage1 + opt detection
    Description: pipeline opt detection
    Expectation: success
    """
    context.set_auto_parallel_context(device_num=8, global_rank=4, pipeline_stages=2)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    data = Tensor(np.ones([32, 64]), dtype=ms.float32)
    label = Tensor(np.ones([64, 64]), dtype=ms.float32)
    strategy1 = ((4, 1), (1, 1))
    strategy2 = ((2, 1), (1, 1))
    net = PipelineCell(PipelineSplit(strategy1, strategy2), 4)
    params = net.trainable_params()
    dataset = DatasetLenet(data, label, 4)
    optimizer = nn.Lamb(params, learning_rate=0.001)
    model = Model(net, optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)
    for _, param in model._train_network.parameters_and_names():
        assert param.name != "cell.block.0.param"
        assert param.name != "cell.block.0.param1"


def test_pipeline_split_shared_parameter_stage0():
    """
    Feature: pipeline stage0 + opt detection + shared parameter
    Description: pipeline opt detection
    Expectation: success
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0, pipeline_stages=2)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    data = Tensor(np.ones([32, 64]), dtype=ms.float32)
    label = Tensor(np.ones([64, 64]), dtype=ms.float32)
    strategy1 = ((4, 1), (1, 1))
    strategy2 = ((2, 1), (1, 1))
    net = PipelineCell(PipelineSplitSharedParam(strategy1, strategy2), 4)
    params = net.trainable_params()
    dataset = DatasetLenet(data, label, 6)
    optimizer = nn.Lamb(params, learning_rate=0.03)
    model = Model(net, optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)


def test_pipeline_split_shared_parameter_stage1():
    """
    Feature: pipeline stage1 + opt detection + shared parameter
    Description: pipeline opt detection
    Expectation: success
    """
    context.set_auto_parallel_context(device_num=8, global_rank=4, pipeline_stages=2)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    data = Tensor(np.ones([32, 64]), dtype=ms.float32)
    label = Tensor(np.ones([64, 64]), dtype=ms.float32)
    strategy1 = ((4, 1), (1, 1))
    strategy2 = ((2, 1), (1, 1))
    net = PipelineCell(PipelineSplitSharedParam(strategy1, strategy2), 4)
    params = net.trainable_params()
    dataset = DatasetLenet(data, label, 7)
    optimizer = nn.Lamb(params, learning_rate=0.04)
    model = Model(net, optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)


def test_pipeline_split_stage0_opt_shard():
    """
    Feature: pipeline stage0 + opt detection + opt shard
    Description: pipeline opt detection
    Expectation: success
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0, pipeline_stages=2, enable_parallel_optimizer=True)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    data = Tensor(np.ones([32, 64]), dtype=ms.float32)
    label = Tensor(np.ones([64, 64]), dtype=ms.float32)
    strategy1 = ((4, 1), (1, 1))
    strategy2 = ((2, 1), (1, 1))
    net = PipelineCell(PipelineSplit(strategy1, strategy2), 4)
    params = net.trainable_params()
    dataset = DatasetLenet(data, label, 6)
    optimizer = nn.Lamb(params, learning_rate=0.02)
    model = Model(net, optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)
    for _, param in model._train_network.parameters_and_names():
        assert param.name != "cell.block.1.param"
        assert param.name != "cell.block.1.param1"


def test_pipeline_split_stage1_opt_shard():
    """
    Feature: pipeline stage1 + opt detection + opt shard
    Description: pipeline opt detection
    Expectation: success
    """
    context.set_auto_parallel_context(device_num=8, global_rank=4, pipeline_stages=2, enable_parallel_optimizer=True)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    data = Tensor(np.ones([32, 64]), dtype=ms.float32)
    label = Tensor(np.ones([64, 64]), dtype=ms.float32)
    strategy1 = ((4, 1), (1, 1))
    strategy2 = ((2, 1), (1, 1))
    net = PipelineCell(PipelineSplit(strategy1, strategy2), 4)
    params = net.trainable_params()
    dataset = DatasetLenet(data, label, 8)
    optimizer = nn.Lamb(params, learning_rate=0.04)
    model = Model(net, optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)
    for _, param in model._train_network.parameters_and_names():
        assert param.name != "cell.block.0.param"
        assert param.name != "cell.block.0.param1"


def test_pipeline_split_shared_parameter_stage0_opt_shard():
    """
    Feature: pipeline stage0 + opt detection + opt shard + shared parameter
    Description: pipeline opt detection
    Expectation: success
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0, pipeline_stages=2, enable_parallel_optimizer=True)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    data = Tensor(np.ones([32, 64]), dtype=ms.float32)
    label = Tensor(np.ones([64, 64]), dtype=ms.float32)
    strategy1 = ((4, 1), (1, 1))
    strategy2 = ((2, 1), (1, 1))
    net = PipelineCell(PipelineSplitSharedParam(strategy1, strategy2), 4)
    params = net.trainable_params()
    dataset = DatasetLenet(data, label, 2)
    optimizer = nn.Lamb(params, learning_rate=0.06)
    model = Model(net, optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)


def test_pipeline_split_shared_parameter_stage1_opt_shard():
    """
    Feature: pipeline stage1 + opt detection + opt shard + shared parameter
    Description: pipeline opt detection
    Expectation: success
    """
    context.set_auto_parallel_context(device_num=8, global_rank=4, pipeline_stages=2, enable_parallel_optimizer=True)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    data = Tensor(np.ones([32, 64]), dtype=ms.float32)
    label = Tensor(np.ones([64, 64]), dtype=ms.float32)
    strategy1 = ((4, 1), (1, 1))
    strategy2 = ((2, 1), (1, 1))
    net = PipelineCell(PipelineSplitSharedParam(strategy1, strategy2), 4)
    params = net.trainable_params()
    dataset = DatasetLenet(data, label, 9)
    optimizer = nn.Lamb(params, learning_rate=0.06)
    model = Model(net, optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)


def test_pipeline_split_with_micro_batch_interleaved_stage0():
    """
    Feature: test PipelineSplit with MicroBatchInterleaved in auto parallel.
    Description: net with MicroBatchInterleaved in semi auto parallel.
    Expectation: success.
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0, pipeline_stages=2)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    data = Tensor(np.ones([32, 64]), dtype=ms.float32)
    label = Tensor(np.ones([64, 64]), dtype=ms.float32)
    strategy1 = ((4, 1), (1, 1))
    strategy2 = ((2, 1), (1, 1))
    micro_batch_interleaved = 2
    net = PipelineCell(MicroBatchInterleaved(PipelineSplit(strategy1, strategy2), micro_batch_interleaved), 4)
    params = net.trainable_params()
    dataset = DatasetLenet(data, label, 3)
    optimizer = nn.Lamb(params, learning_rate=0.07)
    model = Model(net, optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)
    for _, param in model._train_network.parameters_and_names():
        assert param.name != "cell.block.1.param"
        assert param.name != "cell.block.1.param1"


def test_pipeline_split_with_micro_batch_interleaved_stage1():
    """
    Feature: test PipelineSplit with MicroBatchInterleaved in auto parallel.
    Description: net with MicroBatchInterleaved in semi auto parallel.
    Expectation: success.
    """
    context.set_auto_parallel_context(device_num=8, global_rank=4, pipeline_stages=2)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    data = Tensor(np.ones([32, 64]), dtype=ms.float32)
    label = Tensor(np.ones([64, 64]), dtype=ms.float32)
    strategy1 = ((4, 1), (1, 1))
    strategy2 = ((2, 1), (1, 1))
    micro_batch_interleaved = 2
    net = PipelineCell(MicroBatchInterleaved(PipelineSplit(strategy1, strategy2), micro_batch_interleaved), 4)
    params = net.trainable_params()
    dataset = DatasetLenet(data, label, 3)
    optimizer = nn.Lamb(params, learning_rate=0.08)
    model = Model(net, optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)
    for _, param in model._train_network.parameters_and_names():
        assert param.name != "cell.block.0.param"
        assert param.name != "cell.block.0.param1"


def test_pipeline_split_shared_parameter_with_micro_batch_interleaved_stage0_opt_shard():
    """
    Feature: test PipelineSplitSharedParameter with MicroBatchInterleaved in auto parallel.
    Description: net with MicroBatchInterleaved in semi auto parallel.
    Expectation: success.
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0, pipeline_stages=2, enable_parallel_optimizer=True)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    data = Tensor(np.ones([32, 64]), dtype=ms.float32)
    label = Tensor(np.ones([64, 64]), dtype=ms.float32)
    strategy1 = ((4, 1), (1, 1))
    strategy2 = ((2, 1), (1, 1))
    micro_batch_interleaved = 2
    net = PipelineCell(MicroBatchInterleaved(PipelineSplitSharedParam(strategy1, strategy2),
                                             micro_batch_interleaved), 4)
    params = net.trainable_params()
    dataset = DatasetLenet(data, label, 5)
    optimizer = nn.Lamb(params, learning_rate=0.06)
    model = Model(net, optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)


def test_pipeline_split_shared_parameter_with_micro_batch_interleaved_stage1_opt_shard():
    """
    Feature: test PipelineSplitSharedParameter with MicroBatchInterleaved in auto parallel.
    Description: net with MicroBatchInterleaved in semi auto parallel.
    Expectation: success.
    """
    context.set_auto_parallel_context(device_num=8, global_rank=4, pipeline_stages=2, enable_parallel_optimizer=True)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    data = Tensor(np.ones([32, 64]), dtype=ms.float32)
    label = Tensor(np.ones([64, 64]), dtype=ms.float32)
    strategy1 = ((4, 1), (1, 1))
    strategy2 = ((2, 1), (1, 1))
    micro_batch_interleaved = 2
    net = PipelineCell(MicroBatchInterleaved(PipelineSplitSharedParam(strategy1, strategy2),
                                             micro_batch_interleaved), 4)
    params = net.trainable_params()
    dataset = DatasetLenet(data, label, 4)
    optimizer = nn.Lamb(params, learning_rate=0.02)
    model = Model(net, optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)


def run_pipeline_split_function(pipeline_net, micro_batch_interleaved=1):
    """Helper: wrap pipeline_net in MicroBatchInterleaved + PipelineCell and train it for two epochs."""
    data = Tensor(np.ones([32, 64]), dtype=ms.float32)
    label = Tensor(np.ones([64, 64]), dtype=ms.float32)
    net = PipelineCell(MicroBatchInterleaved(pipeline_net, micro_batch_interleaved), 4)
    params = net.trainable_params()
    dataset = DatasetLenet(data, label, 3)
    optimizer = nn.Lamb(params, learning_rate=0.01)
    model = Model(net, optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)
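
The helper above is not called in the hunk shown here (the file is truncated at this point); based on the tests earlier in the file, it would be driven roughly like this (context values copied from the stage-0 tests; hedged sketch):

context.set_auto_parallel_context(device_num=8, global_rank=0, pipeline_stages=2)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
run_pipeline_split_function(PipelineSplit(((4, 1), (1, 1)), ((2, 1), (1, 1))),
                            micro_batch_interleaved=2)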
