
!29087 Optimize sharding propagation & add UT on PanGu
Merge pull request !29087 from bichaoyang/master
Branch: feature/build-system-rewrite
i-robot (Gitee) committed 4 years ago
Commit: 55ba926a04
10 changed files with 815 additions and 12 deletions
  1. mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.cc (+25 -7)
  2. mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc (+8 -1)
  3. mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.h (+1 -1)
  4. mindspore/ccsrc/frontend/parallel/graph_util/node_info.cc (+7 -3)
  5. mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc (+10 -0)
  6. mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h (+1 -0)
  7. mindspore/ccsrc/frontend/parallel/parameter_manager.cc (+17 -0)
  8. mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.cc (+15 -0)
  9. mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.h (+2 -0)
  10. tests/ut/python/parallel/test_auto_parallel_pangu_alpha_shard_propagation.py (+729 -0)

mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.cc (+25 -7)

@@ -48,7 +48,9 @@ Status Edge::InitEdgeCost() {
for (auto &target_input : next_op_input_) {
auto target_input_lyt = target_input.second[next_op_input_index_].tensor_layout();
auto target_input_str = target_input.first;
if (target_output_lyt == target_input_lyt) {
// for identity_info ops, no need to compare device_matrix
if ((target_output_lyt == target_input_lyt) || (target_output_lyt.IsSameWithoutSplit(target_input_lyt) &&
edge_name().find(IDENTITY_INFO) != std::string::npos)) {
CostPtrKey ck = {target_output_str, target_input_str};
CostPtr cost = std::make_shared<Cost>(0.0, 0.0);
MS_EXCEPTION_IF_NULL(cost);
@@ -383,10 +385,18 @@ StrategyPtr Edge::GetNextOpStrategyByPrevOpStrategyWithMiniComm(const StrategyPt
MS_LOG(INFO) << "There are multiple strategies for edge: " << edge_name_
<< " with zero communication cost, choose the one with minimum computation costs.";
}
auto next_op = next_op_;
std::sort(next_op_stras.begin(), next_op_stras.end(),
[this](const std::pair<StrategyPtr, double> &a, const std::pair<StrategyPtr, double> &b) {
return !IsDoubleEqual(a.second, b.second) ? a.second < b.second
: a.first->PartitionNum() > b.first->PartitionNum();
[this, &next_op](const std::pair<StrategyPtr, double> &a, const std::pair<StrategyPtr, double> &b) {
if (!IsDoubleEqual(a.second, b.second)) {
return a.second < b.second;
}
auto cost_a = next_op->GetCostByStrategyPtr(a.first)[0]->communication_cost_;
auto cost_b = next_op->GetCostByStrategyPtr(b.first)[0]->communication_cost_;
if (!IsDoubleEqual(cost_a, cost_b)) {
return cost_a < cost_b;
}
return a.first->PartitionNum() > b.first->PartitionNum();
});
return next_op_stras[0].first;
}
@@ -425,10 +435,18 @@ StrategyPtr Edge::GetPrevOpStrategyByNextOpStrategyWithMiniComm(const StrategyPt
MS_LOG(INFO) << "There are multiple strategies for edge: " << edge_name_
<< " with zero communication costs, choose the one with minimum computation costs.";
}
auto prev_op = prev_op_;
std::sort(prev_op_stras.begin(), prev_op_stras.end(),
[this](const std::pair<StrategyPtr, double> &a, const std::pair<StrategyPtr, double> &b) {
return !IsDoubleEqual(a.second, b.second) ? a.second < b.second
: a.first->PartitionNum() > b.first->PartitionNum();
[this, &prev_op](const std::pair<StrategyPtr, double> &a, const std::pair<StrategyPtr, double> &b) {
if (!IsDoubleEqual(a.second, b.second)) {
return a.second < b.second;
}
auto cost_a = prev_op->GetCostByStrategyPtr(a.first)[0]->communication_cost_;
auto cost_b = prev_op->GetCostByStrategyPtr(b.first)[0]->communication_cost_;
if (!IsDoubleEqual(cost_a, cost_b)) {
return cost_a < cost_b;
}
return a.first->PartitionNum() > b.first->PartitionNum();
});
return prev_op_stras[0].first;
}
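
For readers of the two hunks above: the new comparator orders candidate strategies by edge cost first, then by the operator's own communication cost under that strategy (via the new GetCostByStrategyPtr), and only then prefers the strategy with the larger partition number. A minimal Python sketch of that ordering, with made-up numbers standing in for the real StrategyWithCost entries and the IsDoubleEqual tolerance omitted:

# Hypothetical (name, edge_cost, op_communication_cost, partition_num) tuples.
candidates = [
    ("stra_a", 0.0, 4.0, 8),
    ("stra_b", 0.0, 2.0, 4),
    ("stra_c", 0.0, 2.0, 8),
]

# Sort by edge cost, then by the operator's communication cost for that
# strategy, and finally prefer the strategy with more partitions.
candidates.sort(key=lambda c: (c[1], c[2], -c[3]))

print([name for name, *_ in candidates])  # ['stra_c', 'stra_b', 'stra_a']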


mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc (+8 -1)

@@ -99,6 +99,13 @@ void CostGraph::StrategyPropagate(const std::map<OperatorInfoPtr, StrategyPtr, O
for (auto &op_stra : ops_stras) {
BFS(op_stra.first, op_stra.second, ops_stras, &visited);
}

// GetNext, as an isolated op, cannot be propagated
for (auto &op : entire_costgraph->GetOperators()) {
if ((op->name().find(GET_NEXT) != std::string::npos) && (op->selected_strategy() == nullptr)) {
op->SetSelectedStrategy(op->strategy_cost()[0]->strategy_ptr, 0);
}
}
}

void CheckVisitedEdgeConsistency(const EdgePtr &edge) {
@@ -173,7 +180,7 @@ void CheckConfiguredPrevEdgeConsistency(const EdgePtr &edge,
}

void CostGraph::BFS(const OperatorInfoPtr &op, const StrategyPtr &op_stra,
const std::map<OperatorInfoPtr, StrategyPtr, OpsPtrCompare> configured_ops,
const std::map<OperatorInfoPtr, StrategyPtr, OpsPtrCompare> &configured_ops,
std::map<OperatorInfoPtr, bool> *visited) {
std::queue<std::pair<std::pair<OperatorInfoPtr, std::pair<StrategyPtr, int64_t>>, int64_t>> next_level;
(void)next_level.emplace(std::make_pair(op, std::make_pair(op_stra, -1)), 0);
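
GetNext has no cost-graph edges to other operators, so the BFS propagation above never assigns it a strategy; the added loop falls back to its first candidate strategy. A rough Python analogue of the fallback, using a hypothetical Op record in place of OperatorInfo:

from dataclasses import dataclass

@dataclass
class Op:
    # Hypothetical stand-in for OperatorInfo: a name, candidate strategies,
    # and the strategy chosen (or not) by propagation.
    name: str
    candidate_strategies: list
    selected_strategy: object = None

ops = [
    Op("Default/GetNext-op0", [("dp1",), ("dp2",)]),
    Op("Default/MatMul-op1", [("mp8",)], selected_strategy=("mp8",)),
]

# Isolated GetNext ops are never reached by propagation, so give them their
# first candidate strategy as a fallback.
for op in ops:
    if "GetNext" in op.name and op.selected_strategy is None:
        op.selected_strategy = op.candidate_strategies[0]

print(ops[0].selected_strategy)  # ('dp1',)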


mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.h (+1 -1)

@@ -53,7 +53,7 @@ class CostGraph {
void RemoveOperator(const OperatorInfoPtr &op);
bool IsOperatorInCostGraph(const OperatorInfoPtr &op);
void StrategyPropagate(const std::map<OperatorInfoPtr, StrategyPtr, OpsPtrCompare> &);
void BFS(const OperatorInfoPtr &, const StrategyPtr &, const std::map<OperatorInfoPtr, StrategyPtr, OpsPtrCompare>,
void BFS(const OperatorInfoPtr &, const StrategyPtr &, const std::map<OperatorInfoPtr, StrategyPtr, OpsPtrCompare> &,
std::map<OperatorInfoPtr, bool> *);
// the edge is in the form: u --> v
void AddEdge(OperatorInfoPtr u_node, OperatorInfoPtr v_node, const EdgePtr &edge);


mindspore/ccsrc/frontend/parallel/graph_util/node_info.cc (+7 -3)

@@ -51,10 +51,14 @@ bool ParameterRequireGrad(const AnfNodePtr &node_ptr) {
}

AnfNodePtr GetRealInput(const AnfNodePtr &input) {
if (IsPrimitiveCNode(input, prim::kPrimLoad)) {
return input->cast<CNodePtr>()->input(1);
auto res = input;
while (IsPrimitiveCNode(res, prim::kPrimLoad) || IsPrimitiveCNode(res, prim::kPrimDepend)) {
res = res->cast<CNodePtr>()->input(1);
if (!res->isa<CNode>()) {
return res;
}
}
return input;
return res;
}

// Given the node, return whether each input is a parameter or a output of a operator.
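
GetRealInput previously peeled off a single Load; it now walks an arbitrary chain of Load/Depend wrappers and stops as soon as the unwrapped input is no longer a CNode. A small Python sketch of the same traversal, with a hypothetical Node class in place of AnfNodePtr:

class Node:
    # Hypothetical stand-in for an ANF node: a primitive name and its inputs.
    def __init__(self, prim, inputs=()):
        self.prim = prim
        self.inputs = list(inputs)

def get_real_input(node):
    # Follow Load/Depend wrappers down to the underlying input.
    res = node
    while res.prim in ("Load", "Depend"):
        res = res.inputs[0]   # input(1) in the C++ code: the wrapped value
        if not res.inputs:    # no longer a CNode (e.g. a Parameter): stop here
            return res
    return res

param = Node("Parameter")
wrapped = Node("Load", [Node("Depend", [Node("Load", [param])])])
assert get_real_input(wrapped) is param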


mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc (+10 -0)

@@ -1390,6 +1390,16 @@ Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr &strategy) {
return SUCCESS;
}

CostPtrList OperatorInfo::GetCostByStrategyPtr(const StrategyPtr &strategy) {
auto target = std::find_if(
strategy_cost_.begin(), strategy_cost_.end(),
[&](const std::shared_ptr<StrategyWithCost> &stra_cost) { return stra_cost->strategy_ptr == strategy; });
if (target == strategy_cost_.end()) {
MS_LOG(ERROR) << "There is no StrategyWithCost matching the given strategy";
}
return (*target)->cost_list;
}

TensorLayout OperatorInfo::GetInputLayoutFromSWCByStrategy(const StrategyPtr &stra, size_t input_index) {
auto is_target = [&](const std::shared_ptr<StrategyWithCost> &swc) { return swc->strategy_ptr->IsEqual(stra); };
auto it = std::find_if(strategy_cost_.begin(), strategy_cost_.end(), is_target);
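
The new GetCostByStrategyPtr is a pointer-identity lookup into strategy_cost_. A rough Python equivalent of the lookup (names are illustrative, and it raises instead of only logging an error):

def get_cost_by_strategy(strategy_cost, strategy):
    # strategy_cost: list of (strategy, cost_list) pairs; match the exact
    # strategy object, mirroring the pointer comparison in the C++ code.
    for stra, cost_list in strategy_cost:
        if stra is strategy:
            return cost_list
    raise LookupError("no StrategyWithCost for the given strategy")

s1, s2 = object(), object()
table = [(s1, [0.0]), (s2, [3.5])]
print(get_cost_by_strategy(table, s2))  # [3.5]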


mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h (+1 -0)

@@ -101,6 +101,7 @@ class OperatorInfo {
// This is a common method for setting operator cost for a given strategy, in which the validity of this strategy
// is checked
Status SetCostUnderStrategyBase(const StrategyPtr &strategy);
CostPtrList GetCostByStrategyPtr(const StrategyPtr &strategy);
std::vector<std::shared_ptr<StrategyWithCost>> GetStrategyCost() { return strategy_cost_; }
void SetStrategyCost(const std::vector<std::shared_ptr<StrategyWithCost>> &);
// In the training phase, when the input of a operator contains WEIGHT or a output from other operators involving


mindspore/ccsrc/frontend/parallel/parameter_manager.cc (+17 -0)

@@ -111,6 +111,23 @@ static ParameterUsersInfo FindParameterNodeUsers(const AnfNodePtr &node) {
auto load_node_users = node->func_graph()->manager()->node_users()[candidate_node];
for (auto &node_user : load_node_users) {
auto cnode = node_user.first->cast<CNodePtr>();
if (IsSomePrimitive(cnode, DEPEND)) {
auto depend_node_users = node->func_graph()->manager()->node_users()[node_user.first];
for (auto depend_user : depend_node_users) {
if (IsPrimitiveCNode(depend_user.first, prim::kPrimLoad)) {
auto local_load_node_users = node->func_graph()->manager()->node_users()[depend_user.first];
for (auto local_load_user : local_load_node_users) {
auto local_cnode = local_load_user.first->cast<CNodePtr>();
if (local_cnode == nullptr || !local_cnode->has_user_data<OperatorInfo>() ||
IsSomePrimitive(local_cnode, RECEIVE)) {
continue;
}
parameter_user_info.second.second.insert(local_load_user);
}
}
}
}

if (cnode == nullptr || !cnode->has_user_data<OperatorInfo>() || IsSomePrimitive(cnode, RECEIVE)) {
continue;
}
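
The added branch widens parameter-user discovery: when a user of the Load is a Depend, any Load nodes hanging off that Depend are expanded once more and their users are recorded as parameter users too. A compact Python sketch over a hypothetical node-users map:

# Hypothetical user map: node name -> list of (user_name, user_prim) pairs.
node_users = {
    "load0":   [("depend0", "Depend"), ("matmul0", "MatMul")],
    "depend0": [("load1", "Load")],
    "load1":   [("add0", "Add")],
}

users = set()
for user, prim in node_users["load0"]:
    if prim == "Depend":
        # New branch: a Depend user of the Load is expanded one level further,
        # collecting the users of any Load node hanging off that Depend.
        for dep_user, dep_prim in node_users[user]:
            if dep_prim == "Load":
                users.update(u for u, _ in node_users[dep_user])
        continue  # the Depend node itself is not an operator user
    users.add(user)

print(sorted(users))  # ['add0', 'matmul0']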


mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.cc (+15 -0)

@@ -362,6 +362,21 @@ bool TensorLayout::operator!=(const TensorLayout &t1) const {
return !(IsSameDeviceArrangement(t1) && IsSameTensorMap(t1) && IsSameTensorShape(t1));
}

bool TensorLayout::IsSameWithoutSplit(const TensorLayout &t1) const {
if (!IsSameTensorMap(t1) || !IsSameTensorShape(t1)) {
return false;
}
auto first_array = tensor_map().array();
auto second_array = t1.tensor_map().array();
auto first_pos = std::find_if(first_array.begin(), first_array.end(), [&](const int64_t ele) { return ele != -1; });
auto second_pos =
std::find_if(second_array.begin(), second_array.end(), [&](const int64_t ele) { return ele != -1; });
if (first_pos != first_array.end() || second_pos != second_array.end()) {
return false;
}
return true;
}

/*
* remove elements equal to 1 in tensor_shape, if all elements are 1, squeeze the tensor_shape to [ 1 ]
* example 1:
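
IsSameWithoutSplit is what lets edge_costmodel.cc treat the two ends of an identity-like edge as equivalent when neither side actually splits any dimension, so differing device matrices no longer matter. A Python rendering of the check, with layouts reduced to plain tensor-map and tensor-shape lists:

def is_same_without_split(map_a, shape_a, map_b, shape_b):
    # Layouts qualify only when tensor map and tensor shape match and neither
    # layout splits any dimension (every map entry is -1), so the device
    # matrices do not need to be compared.
    if map_a != map_b or shape_a != shape_b:
        return False
    return all(e == -1 for e in map_a) and all(e == -1 for e in map_b)

# A fully replicated 2-D layout under two different device matrices:
print(is_same_without_split([-1, -1], [1024, 2560], [-1, -1], [1024, 2560]))  # True
# A layout that splits its first dimension does not qualify:
print(is_same_without_split([1, -1], [1024, 2560], [1, -1], [1024, 2560]))    # False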


mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.h (+2 -0)

@@ -84,6 +84,8 @@ class TensorLayout {

bool operator!=(const TensorLayout &t1) const;

bool IsSameWithoutSplit(const TensorLayout &t1) const;

bool TensorShapeCanBeExpanded(const Arrangement &expanded_shape) const;

std::shared_ptr<Arrangement> ComputeExpandedTensorShape(const Arrangement &expand_shape) const;


tests/ut/python/parallel/test_auto_parallel_pangu_alpha_shard_propagation.py (+729 -0)

@@ -0,0 +1,729 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import math
import re
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.common.initializer import initializer
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.nn import Cell
from mindspore import context
from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
from mindspore.parallel import set_algo_parameters
from mindspore.nn.layer.activation import get_activation
from mindspore.train.model import Model
from mindspore.common.api import _cell_graph_executor
from tests.dataset_mock import MindData

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


class Dataset(MindData):
def __init__(self, *inputs, length=3):
super(Dataset, self).__init__(size=length)
self.inputs = inputs
self.index = 0
self.length = length

def __iter__(self):
return self

def __next__(self):
if self.index >= self.length:
raise StopIteration
self.index += 1
return self.inputs

def reset(self):
self.index = 0


def set_parallel_configure_for_layer(network, layer_id, offset, layers):
pp_dis = max(int((layers + 1) / 1), 1)
pp_id = min((layer_id + offset) // pp_dis, 0)
network.pipeline_stage = pp_id
dis = max(int((layers + 1) / 4), 1)
network.set_comm_fusion(int((layer_id + offset) / dis) + 1)


class _LayerNorm(Cell):
def __init__(self, normalized_shape, eps=1e-5, param_init_type=mstype.float32):
super(_LayerNorm, self).__init__()
if normalized_shape[0] <= 1024:
self.layer_norm = P.LayerNorm(begin_norm_axis=-1,
begin_params_axis=-1,
epsilon=eps)
self.is_self_defined = normalized_shape[0] > 1024
self.gamma = Parameter(initializer('ones', normalized_shape, param_init_type), name="gamma",
parallel_optimizer=False)
self.beta = Parameter(initializer('zeros', normalized_shape, param_init_type), name="beta",
parallel_optimizer=False)
self.mean = P.ReduceMean(keep_dims=True)
self.square = P.Square()
self.sqrt = P.Sqrt()
self.sub1 = P.Sub()
self.sub2 = P.Sub()
self.add = P.Add()
self.eps = eps
self.mul = P.Mul()
self.add2 = P.Add()
self.real_div = P.RealDiv()

def construct(self, x):
if self.is_self_defined:
mean = self.mean(x, -1)
diff = self.sub1(x, mean)
variance = self.mean(self.square(diff), -1)
variance_eps = self.sqrt(self.add(variance, self.eps))
output = self.real_div(diff, variance_eps)
output = self.add2(self.mul(output, self.gamma), self.beta)
else:
output, _, _ = self.layer_norm(x, self.gamma, self.beta)
return output

def shard(self, strategy):
if self.is_self_defined:
self.mean.shard(strategy)
self.square.shard(strategy)
self.sqrt.shard(strategy)
self.sub1.shard((strategy[0], strategy[0]))
self.sub2.shard((strategy[0], strategy[0]))
self.add.shard((strategy[0], ()))
self.mul.shard((strategy[0], (1,)))
self.add2.shard((strategy[0], (1,)))
self.real_div.shard((strategy[0], strategy[0]))
else:
self.layer_norm.shard((strategy[0], (1,), (1,)))
return self


class _Linear(Cell):
def __init__(self,
in_channels,
out_channels,
weight_init='normal',
bias_init='zeros',
has_bias=True,
activation=None,
transpose_b=True,
param_init_type=mstype.float32,
compute_dtype=mstype.float16):
super(_Linear, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
weight_shape = [out_channels, in_channels] if transpose_b else [in_channels, out_channels]
self.expert_flag = False
self.weight = Parameter(initializer(weight_init, weight_shape, param_init_type), name="weight")
self.matmul = P.MatMul(transpose_b=transpose_b)
self.bias = None
self.has_bias = has_bias
if self.has_bias:
self.bias = Parameter(initializer(bias_init, [out_channels], param_init_type), name="bias")
self.bias_add = P.Add()
self.act_name = activation
self.activation = get_activation(activation) if isinstance(activation, str) else activation
self.activation_flag = self.activation is not None
self.dtype = compute_dtype
self.cast = P.Cast()

def construct(self, x):
out_shape = P.Shape()(x)[:-1] + (self.out_channels,)
x = P.Reshape()(x, (-1, self.in_channels))
weight = self.cast(self.weight, self.dtype)
x = self.matmul(x, weight)
if self.has_bias:
x = self.bias_add(x, self.cast(self.bias, self.dtype))
if self.activation_flag:
x = self.activation(x)
output = P.Reshape()(x, out_shape)
return output

def shard(self, strategy_matmul, strategy_bias=None, strategy_activation=None):
self.matmul.shard(strategy_matmul)
if self.has_bias:
self.bias_add.shard(strategy_bias)
if self.activation_flag:
if self.act_name.lower() == "leakyrelu":
self.activation.select_op.shard((strategy_activation[0], strategy_activation[0]))
elif self.act_name.lower() == "logsigmoid":
self.activation.mul.shard((strategy_activation[0], ()))
self.activation.exp.shard(strategy_activation)
self.activation.add.shard((strategy_activation[0], ()))
self.activation.rec.shard(strategy_activation)
self.activation.log.shard(strategy_activation)
else:
getattr(self.activation, self.act_name).shard(strategy_activation)
return self


class FeedForward(Cell):
def __init__(self, hidden_size,
ffn_hidden_size,
dropout_rate,
hidden_act='gelu',
param_init_type=mstype.float32):
super(FeedForward, self).__init__()
dp = 2
mp = 8
input_size = hidden_size
output_size = ffn_hidden_size
self.mapping = _Linear(in_channels=input_size,
out_channels=output_size,
activation=hidden_act,
transpose_b=False,
param_init_type=param_init_type)
self.projection = _Linear(in_channels=output_size,
out_channels=input_size,
transpose_b=False,
param_init_type=param_init_type)
self.projection.shard(strategy_matmul=((dp, mp), (mp, 1)))
self.projection.bias.parallel_optimizer = False
self.dropout = nn.Dropout(1 - dropout_rate)
self.dropout_3d = nn.Dropout(1 - dropout_rate)
self.cast = P.Cast()

def construct(self, x):
x = self.cast(x, mstype.float16)
hidden = self.mapping(x)
output = self.projection(hidden)
if len(F.shape(output)) == 3:
output = self.dropout_3d(output)
else:
output = self.dropout(output)
return output


class MultiHeadAttention(Cell):
def __init__(self, batch_size,
src_seq_length,
tgt_seq_length,
hidden_size,
num_heads,
hidden_dropout_rate=0.1,
attention_dropout_rate=0.1,
compute_dtype=mstype.float16,
softmax_compute_type=mstype.float32,
param_init_type=mstype.float32):
super(MultiHeadAttention, self).__init__()
self.src_seq_length = src_seq_length
self.tgt_seq_length = tgt_seq_length
self.hidden_size = hidden_size
self.batch_size = batch_size
self.projection = _Linear(in_channels=hidden_size,
out_channels=hidden_size,
transpose_b=False,
param_init_type=param_init_type).to_float(compute_dtype)
self.projection.shard(strategy_matmul=((2, 8), (8, 1)))
self.projection.bias.parallel_optimizer = False
self.transpose = P.Transpose()
self.merger_head_transpose = P.Transpose()
self.reshape = P.Reshape()
self.n_head = num_heads
self.size_per_head = hidden_size // self.n_head
self.concat_k = P.Concat(axis=3)
self.concat_v = P.Concat(axis=2)
self.multiply_data = Tensor([
-10000.0,
], dtype=softmax_compute_type)
self.batch_matmul = P.BatchMatMul()
self.real_div = P.RealDiv()
self.sub = P.Sub()
self.mul = P.Mul()
self.add = P.Add()
self.scale_factor = Tensor(math.sqrt(self.size_per_head))
self.dropout = nn.Dropout(1 - hidden_dropout_rate)
self.prob_dropout = nn.Dropout(1 - attention_dropout_rate)
self.softmax = nn.Softmax().to_float(softmax_compute_type)
self.expand_dims = P.ExpandDims()
# Query
self.dense1 = _Linear(hidden_size,
hidden_size,
param_init_type=param_init_type).to_float(compute_dtype)
# Key
self.dense2 = _Linear(hidden_size,
hidden_size,
param_init_type=param_init_type).to_float(compute_dtype)
# Value
self.dense3 = _Linear(hidden_size,
hidden_size,
param_init_type=param_init_type).to_float(compute_dtype)
self.dtype = compute_dtype
self.softmax_dtype = softmax_compute_type

def construct(self, query_tensor, key_tensor, value_tensor, attention_mask, key_past=None,
value_past=None, batch_valid_length=None):
query_tensor, key_tensor, value_tensor, batch_size, ori_shape = self._convert_to_2d_tensor(query_tensor,
key_tensor,
value_tensor,
attention_mask)
query = self.dense1(query_tensor)
key = self.dense2(key_tensor)
value = self.dense3(value_tensor)
query = self.transpose(
F.reshape(
query,
(batch_size, -1, self.n_head, self.size_per_head)),
(0, 2, 1, 3))
key = self.transpose(
F.reshape(
key, (batch_size, -1, self.n_head, self.size_per_head)),
(0, 2, 3, 1))
value = self.transpose(
F.reshape(
value,
(batch_size, -1, self.n_head, self.size_per_head)),
(0, 2, 1, 3))
if len(F.shape(attention_mask)) == 3:
attention_mask = self.expand_dims(attention_mask, 1)
key_present = key
value_present = value
layer_present = (key_present, value_present)
attention = self._attn(query, key, value, attention_mask)
output = self.projection(attention)
output = self.dropout(output)
output = F.reshape(output, ori_shape)
return output, layer_present

def _convert_to_2d_tensor(self, query_tensor, key_tensor, value_tensor, attention_mask):
query_shape = F.shape(query_tensor)
query_tensor = F.reshape(query_tensor, (-1, query_shape[-1]))
key_shape = F.shape(key_tensor)
key_tensor = F.reshape(key_tensor, (-1, key_shape[-1]))
value_shape = F.shape(value_tensor)
value_tensor = F.reshape(value_tensor, (-1, value_shape[-1]))
return query_tensor, key_tensor, value_tensor, F.shape(attention_mask)[0], query_shape

def _merge_heads(self, x):
x = self.merger_head_transpose(
x, (0, 2, 1, 3))
x_shape = P.Shape()(x)
new_shape = (-1, x_shape[-2] * x_shape[-1])
x_merge = self.reshape(x, new_shape)
return x_merge

def _attn(self, query, key, value, attention_mask):
score = self.batch_matmul(query, key)
score = self.real_div(
score,
P.Cast()(self.scale_factor, P.DType()(score)))

ori_dtype = P.DType()(score)
score = P.Cast()(score, self.softmax_dtype)

multiply_out = self.sub(
P.Cast()(F.tuple_to_array((1.0,)), P.DType()(score)),
P.Cast()(attention_mask, P.DType()(score)))
adder = self.mul(multiply_out, self.multiply_data)
attention_scores = self.add(adder, score)
shape = F.shape(attention_scores)
attention_probs = self.softmax(
F.reshape(attention_scores,
(shape[0], -1, shape[-1])))
attention_probs = P.Cast()(attention_probs, ori_dtype)
attention_probs = F.reshape(attention_probs, shape)

attention_probs = self.prob_dropout(attention_probs)
weighted_values = self.batch_matmul(attention_probs, value)
attention_merge = self._merge_heads(weighted_values)
return attention_merge


class TransformerEncoderLayer(Cell):
def __init__(self,
batch_size,
hidden_size,
ffn_hidden_size,
num_heads,
seq_length,
attention_dropout_rate=0.1,
hidden_dropout_rate=0.1,
layernorm_compute_type=mstype.float32,
softmax_compute_type=mstype.float32,
param_init_type=mstype.float32,
hidden_act='gelu'):
super(TransformerEncoderLayer, self).__init__()
self.seq_length = seq_length
self.hidden_size = hidden_size
self.batch_size = batch_size
self.layernorm1 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type)
self.layernorm2 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type)
self.attention = MultiHeadAttention(batch_size=batch_size,
src_seq_length=seq_length,
tgt_seq_length=seq_length,
hidden_size=hidden_size,
num_heads=num_heads,
hidden_dropout_rate=hidden_dropout_rate,
attention_dropout_rate=attention_dropout_rate,
softmax_compute_type=softmax_compute_type,
param_init_type=param_init_type)
self.output = FeedForward(hidden_size=hidden_size,
dropout_rate=hidden_dropout_rate,
ffn_hidden_size=ffn_hidden_size,
param_init_type=param_init_type,
hidden_act=hidden_act)
self.add = P.Add()
self.add_3d = P.Add()
self.dtype = mstype.float16
self.key_past = None
self.value_past = None

def construct(self, x, input_mask, init_reset=True, batch_valid_length=None):
x_shape = F.shape(x)
x = F.reshape(x, (-1, x_shape[-1]))
input_x = self.layernorm1(x)
input_x = F.cast(input_x, self.dtype)
attention, layer_present = self.attention(input_x, input_x, input_x, input_mask,
self.key_past, self.value_past, batch_valid_length)
x = self.add(x, attention)
output_x = self.layernorm2(x)
output_x = F.cast(output_x, self.dtype)
mlp_logit = self.output(output_x)
value_update = None
key_update = None
mlp_logit = F.depend(mlp_logit, value_update)
mlp_logit = F.depend(mlp_logit, key_update)
if len(x_shape) == 3:
output_x = P.Reshape()(output_x, x_shape)
mlp_logit = P.Reshape()(mlp_logit, x_shape)
x = P.Reshape()(x, x_shape)
output = self.add_3d(x, mlp_logit)
else:
output = self.add(x, mlp_logit)
output = F.reshape(output, x_shape)
return output, layer_present


class TransformerEncoder(Cell):
def __init__(self,
batch_size,
num_layers,
hidden_size,
ffn_hidden_size,
seq_length,
num_heads,
attention_dropout_rate=0.1,
hidden_dropout_rate=0.1,
hidden_act='gelu',
layernorm_compute_type=mstype.float32,
softmax_compute_type=mstype.float32,
param_init_type=mstype.float32,
lambda_func=None,
offset=0):
super(TransformerEncoder, self).__init__()
self.add = P.Add().shard(((), ()))
self.aux_loss = Tensor(0.0, mstype.float32)
self.num_layers = num_layers
self.blocks = nn.CellList()
for i in range(num_layers):
block = TransformerEncoderLayer(hidden_size=hidden_size,
batch_size=batch_size,
ffn_hidden_size=ffn_hidden_size,
seq_length=seq_length,
attention_dropout_rate=attention_dropout_rate,
hidden_dropout_rate=hidden_dropout_rate,
layernorm_compute_type=layernorm_compute_type,
softmax_compute_type=softmax_compute_type,
num_heads=num_heads,
hidden_act=hidden_act,
param_init_type=param_init_type)
lambda_func(block, layer_id=i, layers=num_layers,
offset=offset)
self.blocks.append(block)

def construct(self, hidden_states, attention_mask, init_reset=True, batch_valid_length=None):
present_layer = ()
for i in range(self.num_layers):
hidden_states, present = self.blocks[i](hidden_states,
attention_mask,
init_reset,
batch_valid_length)
present_layer = present_layer + (present,)
return hidden_states, present_layer


class VocabEmbedding(Cell):
def __init__(self, vocab_size, embedding_size, param_init='normal'):
super(VocabEmbedding, self).__init__()
self.embedding_table = Parameter(initializer(param_init, [vocab_size, embedding_size]),
name='embedding_table', parallel_optimizer=False)
self.gather = P.GatherV2().shard(((1, 1), (2, 1)))

def construct(self, input_ids):
output = self.gather(self.embedding_table, input_ids, 0)
return output, self.embedding_table


class EmbeddingLayer(nn.Cell):
def __init__(self):
super(EmbeddingLayer, self).__init__()
self.word_embedding = VocabEmbedding(vocab_size=40000, embedding_size=2560)
self.position_embedding = VocabEmbedding(vocab_size=40000, embedding_size=2560)
self.add = P.Add()
self.dropout = nn.Dropout(0.9)

def construct(self, input_ids, input_position, init_reset, batch_valid_length):
word_embedding, word_table = self.word_embedding(input_ids)
position_embedding, _ = self.position_embedding(input_position)
embed = self.add(word_embedding, position_embedding)
embed = self.dropout(embed)
return embed, word_table


class QueryLayer(TransformerEncoderLayer):
def __init__(self, batch_size,
hidden_size,
ffn_hidden_size,
num_heads,
seq_length,
attention_dropout_rate=0.1,
hidden_dropout_rate=0.1,
param_init_type=mstype.float32,
hidden_act='gelu',
softmax_compute_type=mstype.float32):
super(QueryLayer, self).__init__(batch_size=batch_size,
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_heads=num_heads,
seq_length=seq_length,
attention_dropout_rate=attention_dropout_rate,
hidden_dropout_rate=hidden_dropout_rate,
param_init_type=param_init_type,
hidden_act=hidden_act,
softmax_compute_type=softmax_compute_type)

def construct(self, x, query_vector, input_mask, init_reset=True, batch_valid_length=None):
input_x = self.layernorm1(x)
input_x = F.cast(input_x, self.dtype)
attention, layer_present = self.attention(query_vector, input_x, input_x, input_mask,
self.key_past, self.value_past, batch_valid_length)
x = self.add(x, attention)
output_x = self.layernorm2(x)
output_x = F.cast(output_x, self.dtype)
mlp_logit = self.output(output_x)

value_update = None
key_update = None
mlp_logit = F.depend(mlp_logit, value_update)
mlp_logit = F.depend(mlp_logit, key_update)
output = self.add(x, mlp_logit)
return output, layer_present


class PanGuHead(Cell):
def __init__(self,
hidden_size,
compute_type=mstype.float16):
super(PanGuHead, self).__init__()
self.matmul = P.MatMul(transpose_b=True).shard(((2, 1), (1, 1)))
self.hidden_size = hidden_size
self.dtype = compute_type
self.cast = P.Cast()

def construct(self, state, embed):
state = P.Reshape()(state, (-1, self.hidden_size))
logits = self.matmul(self.cast(state, self.dtype), self.cast(embed, self.dtype))
return logits


class PanguAlphaRawModel(Cell):
def __init__(self):
super(PanguAlphaRawModel, self).__init__()
self.embedding = EmbeddingLayer()
self.layernorm = _LayerNorm((2560,)).to_float(mstype.float32)
self.layernorm.set_comm_fusion(4)
self.embedding.pipeline_stage = 0
self.blocks = TransformerEncoder(num_layers=1,
batch_size=32,
hidden_size=2560,
ffn_hidden_size=10240,
num_heads=32,
seq_length=1024,
attention_dropout_rate=0.1,
hidden_dropout_rate=0.1,
lambda_func=set_parallel_configure_for_layer,
param_init_type=mstype.float32,
softmax_compute_type=mstype.float16).blocks
self.top_query_embedding = VocabEmbedding(vocab_size=1024,
embedding_size=2560)
self.top_query_embedding.set_comm_fusion(4)
self.top_query_layer = QueryLayer(batch_size=32,
hidden_size=2560,
ffn_hidden_size=10240,
num_heads=32,
seq_length=1024,
attention_dropout_rate=0.1,
hidden_dropout_rate=0.1,
hidden_act="fast_gelu",
param_init_type=mstype.float32)
self.top_query_layer.set_comm_fusion(4)
self.dtype = mstype.float16

def construct(self, input_ids,
input_position,
encoder_masks,
init_reset=True,
batch_valid_length=None):
embed, word_table = self.embedding(input_ids, input_position, init_reset, batch_valid_length)
hidden_state = P.Cast()(embed, self.dtype)
hidden_state = self.reshape_to_2d(hidden_state)
hidden_state, _ = self.blocks[0](hidden_state, encoder_masks, init_reset, batch_valid_length)
hidden_state = self.reshape_to_2d(hidden_state)
encoder_output = self.layernorm(hidden_state)
encoder_output = P.Cast()(encoder_output, self.dtype)
top_query_hidden_states, _ = self.top_query_embedding(input_position)
top_query_hidden_states = self.reshape_to_2d(top_query_hidden_states)
encoder_output, _ = self.top_query_layer(encoder_output, top_query_hidden_states,
encoder_masks, init_reset, batch_valid_length)
return encoder_output, word_table

def reshape_to_2d(self, x):
shape = F.shape(x)
if len(shape) <= 2:
return x
x = F.reshape(x, (-1, shape[-1]))
return x


class PanguAlphaModel(nn.Cell):
def __init__(self):
super(PanguAlphaModel, self).__init__()
self.head = PanGuHead(hidden_size=2560)
self.backbone = PanguAlphaRawModel()
self.backbone.embedding.word_embedding.embedding_table.add_pipeline_stage(0)

def construct(self, input_ids, input_position, attention_mask,
init_reset=True, batch_valid_length=None):
output_states, word_table = self.backbone(input_ids, input_position, attention_mask,
init_reset, batch_valid_length)
logits = self.head(output_states, word_table)
return logits


class CrossEntropyLoss(Cell):
def __init__(self):
super(CrossEntropyLoss, self).__init__()
dp = 2
mp = 8
self.sum = P.ReduceSum()
self.onehot = P.OneHot()
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
self.max = P.ArgMaxWithValue(axis=-1, keep_dims=True).shard(
((dp, mp),))
self.eps_const = Tensor(1e-24, mstype.float32)
self.sub = P.Sub()
self.exp = P.Exp()
self.div = P.RealDiv()
self.log = P.Log()
self.add = P.Add()
self.mul = P.Mul()
self.neg = P.Neg()
self.sum2 = P.ReduceSum()

self.mul2 = P.Mul()
self.add2 = P.Add()
self.div2 = P.RealDiv()

def construct(self, logits, label, input_mask):
logits = F.cast(logits, mstype.float32)
_, logit_max = self.max(logits)
logit_sub = self.sub(logits, logit_max)
logit_exp = self.exp(logit_sub)
exp_sum = self.sum(logit_exp, -1)
exp_sum = P.Reshape()(exp_sum, (F.shape(exp_sum)[0], 1))
softmax_result = self.div(logit_exp, exp_sum)
log_softmax_result = self.log(self.add(softmax_result, self.eps_const))
label = P.Reshape()(label, (-1,))
one_hot_label = self.onehot(label, F.shape(logits)[-1], self.on_value,
self.off_value)
loss = self.mul(log_softmax_result, one_hot_label)
loss_unsum = self.neg(loss)
loss_reduce = self.sum(loss_unsum, -1)
input_mask = P.Reshape()(input_mask, (-1,))
numerator = self.sum2(self.mul2(loss_reduce, input_mask))
denominator = self.add2(
self.sum2(input_mask),
P.Cast()(F.tuple_to_array((1e-5,)), mstype.float32))
loss = self.div2(numerator, denominator)
return loss


class PanGUAlphaWithLoss(Cell):
def __init__(self, network, loss):
super(PanGUAlphaWithLoss, self).__init__(auto_prefix=False)
self.batch_size = 32
self.seq_length = 1024
self.network = network
self.eod_token = True
self.loss = loss

self.slice = P.StridedSlice()
self.not_equal = P.NotEqual()
self.len = self.seq_length
self.slice2 = P.StridedSlice()
self.micro_batch_step = 1

def construct(self, input_ids, input_position=None, attention_mask=None):
tokens = self.slice(input_ids, (0, 0), (self.batch_size, -1), (1, 1))
input_position = self.slice(input_position, (0, 0), (self.batch_size, self.len), (1, 1))
decoder_attention_masks = self.slice2(attention_mask, (0, 0, 0), (self.batch_size, self.len, self.len),
(1, 1, 1))
input_mask = F.cast(self.not_equal(tokens, self.eod_token),
mstype.float32)

logits = self.network(tokens,
input_position,
decoder_attention_masks)
labels = self.slice(input_ids, (0, 1), (self.batch_size, self.len + 1),
(1, 1))
labels = P.Reshape()(labels, (-1,))
input_mask = P.Reshape()(input_mask, (-1,))
output = self.loss(logits, labels, input_mask)
return output


def test_pangu_alpha_shard_propagation():
'''
Feature: sharding propagation
Description: Propagate strategies on pangu_alpha with only a few ops having configured strategies
Expectation: Get the expected strategies when checking several key ops
'''
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16,
search_mode="sharding_propagation")
set_algo_parameters(elementwise_op_strategy_follow=False, fully_use_devices=False)
pangu_alpha = PanguAlphaModel()
loss = CrossEntropyLoss()
pangu_alpha_loss = PanGUAlphaWithLoss(pangu_alpha, loss)
net = _VirtualDatasetCell(pangu_alpha_loss)
input_ids = Tensor(np.ones((2, 1025)), mstype.int32)
input_position = Tensor(np.ones((2, 1024)), mstype.int32)
attention_mask = Tensor(np.ones((2, 1024, 1024)), mstype.float16)
dataset = Dataset(input_ids, input_position, attention_mask)
model = Model(net)
model.train(1, dataset, dataset_sink_mode=False)
stras = _cell_graph_executor._get_shard_strategy(model._train_network)
for (k, v) in stras.items():
if re.search("MultiHeadAttention/Add", k) is not None:
assert v == [[2, 1, 1, 1], [2, 8, 1, 1]]
elif re.search("ReduceMean", k) is not None:
assert v == [[2, 1]]
elif re.search("BatchMatmul", k) is not None:
assert v == [[2, 8, 1, 1], [2, 8, 1, 1]]
context.reset_auto_parallel_context()
