Browse Source

!29819 Add GlobalNorm Search

Merge pull request !29819 from huangxinjing/add_global_norm_search
feature/build-system-rewrite
i-robot Gitee 4 years ago
parent
commit
0f24b679ec
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
11 changed files with 498 additions and 26 deletions
  1. +16
    -0
      mindspore/ccsrc/frontend/parallel/device_manager.cc
  2. +1
    -0
      mindspore/ccsrc/frontend/parallel/device_manager.h
  3. +14
    -0
      mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc
  4. +1
    -0
      mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h
  5. +8
    -0
      mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
  6. +2
    -7
      mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc
  7. +195
    -10
      mindspore/ccsrc/frontend/parallel/step_parallel.cc
  8. +53
    -7
      mindspore/ccsrc/frontend/parallel/step_parallel_utils.cc
  9. +7
    -1
      mindspore/ccsrc/frontend/parallel/step_parallel_utils.h
  10. +6
    -1
      mindspore/python/mindspore/ops/composite/clip_ops.py
  11. +195
    -0
      tests/ut/python/parallel/test_global_norm.py

+ 16
- 0
mindspore/ccsrc/frontend/parallel/device_manager.cc View File

@@ -228,6 +228,22 @@ Status DeviceManager::Init(const RankList &devices, int64_t global_device_rank,

RankList DeviceManager::GetDeviceListInThisStage() const { return GetDeviceListByStageId(stage_id_); }

// Collect, for the calling rank, the ranks occupying the same relative
// position in every pipeline stage (one rank per stage, including itself).
RankList DeviceManager::GetDeviceListBetweenStage() const {
  auto global_rank = g_device_manager->global_rank();
  auto cur_stage = g_device_manager->stage_id();
  auto total_stages = g_device_manager->stage_num();
  if (total_stages < 1) {
    MS_LOG(EXCEPTION) << "Stage num got " << total_stages << ", expected a positive integer.";
  }
  auto device_num = parallel::ParallelContext::GetInstance()->device_num();
  auto ranks_per_stage = device_num / total_stages;
  std::vector<int64_t> peer_ranks;
  peer_ranks.reserve(static_cast<size_t>(total_stages));
  for (int64_t stage = 0; stage < total_stages; ++stage) {
    peer_ranks.push_back(global_rank + ranks_per_stage * (stage - cur_stage));
  }
  return peer_ranks;
}

RankList DeviceManager::GetDeviceListByStageId(int64_t stage_id) const {
if (LongToSize(stage_id) >= stage_devices_.size())
MS_LOG(ERROR) << "the 'stage_id': " << stage_id


+ 1
- 0
mindspore/ccsrc/frontend/parallel/device_manager.h View File

@@ -66,6 +66,7 @@ class DeviceManager {
static DeviceManager &GetInstance();
RankList GetDeviceListByStageId(int64_t stage_id) const;
RankList GetDeviceListInThisStage() const;
RankList GetDeviceListBetweenStage() const;

Device CreateNewDeviceByRank(int64_t rank) const;
std::vector<Device> CreateDeviceListByRankList(RankList ranks);


+ 14
- 0
mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc View File

@@ -319,6 +319,20 @@ Operator CreateVirtualDivOp(int64_t div_num) {
return op;
}

// Build a RealDiv operator whose divisor input "Y" is a scalar tensor holding
// `scale`, wired in as operator parameter position 2 (the second input).
Operator CreateDivOp(float scale) {
  const size_t divisor_pos = 2;
  auto divisor_tensor = std::make_shared<mindspore::tensor::Tensor>(scale);
  ValuePtr divisor_value = MakeValue(divisor_tensor);
  OperatorParams params;
  params.push_back(std::make_pair(std::make_pair(Y, divisor_value), divisor_pos));
  OperatorArgs args = std::make_pair(OperatorAttrs(), params);
  return std::make_pair(OperatorName(REAL_DIV), args);
}

static OperatorArgs CreateReduceCommunicationOpArgs(const std::string &reduce_op, const std::string &group) {
ValuePtr attr0_value = MakeValue(reduce_op);
ValuePtr attr1_value = MakeValue(group);


+ 1
- 0
mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h View File

@@ -324,6 +324,7 @@ Operator CreateAllReduceOp(const std::string &reduce_op, const std::string &grou
Operator CreateReduceScatterOp(const std::string &reduce_op, const std::string &group);
Operator CreateAllGatherOp(const std::string &group);
Operator CreateCastOp(TypePtr type);
Operator CreateDivOp(float scale);
Operator CreateMiniStepAllGatherOp(const std::string &group);
void AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr &param_node);
void AddCommOpMirrorFlag(const CNodePtr &comm_node, bool do_mirror);


+ 8
- 0
mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h View File

@@ -71,6 +71,7 @@ constexpr size_t SoftmaxCrossEntropyWithLogitsOutputsSize = 2;
constexpr size_t UNIQUE_INPUTS_SIZE = 1;
constexpr size_t UNIQUE_INPUT_SIZE = 1;
constexpr size_t UNIQUE_OUTPUTS_SIZE = 2;
constexpr size_t RESHAPE_INPUT_SIZE = 3;
constexpr size_t TRANSFER_PERMUTE_ARGS_SIZE = 5;
constexpr size_t TRANSFER_PERMUTE_SPLIT_COUNT_INDEX = 0;
constexpr size_t TRANSFER_PERMUTE_SPLIT_DIM_INDEX = 1;
@@ -82,6 +83,7 @@ constexpr size_t TRANSFER_CONCAT_TENSOR_DIM_INDEX = 0;
constexpr size_t TRANSFER_CONCAT_DEV_DIM_INDEX = 1;
constexpr size_t TRANSFER_CONCAT_SPLIT_COUNT_INDEX = 2;
constexpr size_t TRANSFER_SPLIT_ARGS_SIZE = 3;
constexpr size_t TUPLE_GETITEM_INDEX_POS = 2;
constexpr size_t MATMUL_DDS_INPUTS_SIZE = 4;
constexpr size_t MATMUL_DDS_OUTPUTS_SIZE = 2;
constexpr size_t MATMUL_DDS_STRATEGY_SIZE = 4;
@@ -216,11 +218,16 @@ constexpr char REDISTRIBUTION_OP[] = "redistribution_op";
constexpr char DARA_PARALLEL[] = "data_parallel";
constexpr char FORWARD_REDUCE_SCATTER[] = "forward_reduce_scatter";
constexpr char FIELD_SIZE[] = "field_size";
constexpr char Y[] = "Y";
constexpr char OPTIMIZER_SUB_STRING[] = "optimizer";
constexpr char DEVICE[] = "Device";
constexpr char PARALLEL_OPTIMIZER_ALLGATHER[] = "parallel_optimizer_allgather";
constexpr char PARALLEL_OPTIMIZER_ALLGATHER_NOT_COMPUTE[] = "parallel_optimizer_allgather_not_recompute";
constexpr char PARALLEL_OPTIMIZER_COMM_OP[] = "parallel_optimizer_comm_op";
constexpr char PARALLEL_GLOBALNORM[] = "PARALLEL_GLOBALNORM_IN_STAGES";
constexpr char PARALLEL_GLOBALNORM_BETWEEN[] = "PARALLEL_GLOBALNORM_BETWEEN_STAGES";
constexpr char PARALLEL_GLOBALNORM_DIV[] = "PARALLEL_GLOBALNORM_DIV";
constexpr char GRAD_SCALE[] = "grad_scale";
constexpr char CELLLIST_KEYWORD_PATTERN[] = "-CellList/(\\d+)-";

constexpr char OUT_CHANNEL[] = "out_channel";
@@ -296,6 +303,7 @@ constexpr char L2_NORMALIZE[] = "L2Normalize";
constexpr char TRANSPOSE[] = "Transpose";
constexpr char RESHAPE[] = "Reshape";
constexpr char ADD[] = "Add";
constexpr char ADDN[] = "AddN";
constexpr char BIAS_ADD[] = "BiasAdd";
constexpr char SUB[] = "Sub";
constexpr char MUL[] = "Mul";


+ 2
- 7
mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc View File

@@ -93,6 +93,7 @@ bool PipelineTransformer::NeedGrad(const CNodePtr &cnode, const CNodePtr &graph_
auto temp = input;
while (IsPrimitiveCNode(temp, prim::kPrimLoad) || IsPrimitiveCNode(temp, prim::kPrimCast)) {
auto input_cnode = temp->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(input_cnode);
temp = input_cnode->input(1);
}
if (temp->isa<Parameter>()) {
@@ -177,13 +178,7 @@ void PipelineTransformer::LabelMicroBatch() {
}

void PipelineTransformer::CreateForwardGroup() {
std::vector<int64_t> rank_list;
auto rank_id = g_device_manager->global_rank();
auto stage_id = g_device_manager->stage_id();
auto stage_num = g_device_manager->stage_num();
for (int64_t i = 0; i < stage_num; ++i) {
rank_list.push_back(rank_id + per_stage_rank_num_ * (i - stage_id));
}
std::vector<int64_t> rank_list = g_device_manager->GetDeviceListBetweenStage();
auto dev_list = g_device_manager->CreateDeviceListByRankList(rank_list);
auto g = g_device_manager->CreateGroup(rank_list);
auto g_back_name = g.name() + BACKWARD;


+ 195
- 10
mindspore/ccsrc/frontend/parallel/step_parallel.cc View File

@@ -48,7 +48,6 @@
#include "utils/ms_context.h"
#include "utils/symbolic.h"
#include "mindspore/core/utils/parallel_node_check.h"
#include "mindspore/ccsrc/pybind_api/ir/primitive_py.h"
#if ((defined ENABLE_CPU) && (!defined _WIN32))
#include "ps/util.h"
#include "ps/ps_context.h"
@@ -61,9 +60,12 @@ namespace parallel {
static const std::set<std::string> COMMUNICATION_OPS = {ALL_REDUCE, ALL_GATHER, ALL_TO_ALL, REDUCE_SCATTER};
static const std::set<std::string> INVALID_LOSS_OPS = {GET_NEXT, VIRTUALLOSS, LOAD, UPDATESTATE};
static const std::set<std::string> NO_INPUT_TENSOR_OPS = {UNIFORM_REAL};
static const std::vector<std::pair<const std::string, int64_t>> REDUCE_SUM_MATCH_PATTERN = {
std::make_pair(MAKE_TUPLE, 1), std::make_pair(ADDN, 1), std::make_pair(SQRT, 1)};
// g_RefMap, for CNode B input i is a RefKey[Parameter C],
// it will be one item in map with key: C, and value: (B, i)
std::map<AnfNodePtr, std::pair<AnfNodePtr, int64_t>> g_RefMap;
const uint32_t MAX_BFS_DEPTH = 7;

void SetMiniStepOpDoMirrorLabel(std::vector<AnfNodePtr> new_node_input, bool do_mirror, bool accu_flag) {
if (new_node_input.empty()) {
@@ -421,12 +423,6 @@ TensorLayout GetTensorInLayout(const CNodePtr &middle_node, const PrimitivePtr &
return tensorinfo_in.tensor_layout();
}

std::string GetPrimName(const CNodePtr &node) {
auto prim = GetCNodePrimitive(node);
MS_EXCEPTION_IF_NULL(prim);
return prim->name();
}

OperatorInfoPtr GetDistributeOperator(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!IsParallelCareNode(node)) {
@@ -850,11 +846,11 @@ int64_t GetTupleGetItemIndex(const CNodePtr &cnode) {
MS_LOG(EXCEPTION) << cnode->ToString() << " size( " << cnode->inputs().size() << " ) is not 3";
}

if (!cnode->input(2)->isa<ValueNode>()) {
if (!cnode->input(TUPLE_GETITEM_INDEX_POS)->isa<ValueNode>()) {
MS_LOG(EXCEPTION) << "The index of tuple getitem is not a value node";
}

ValuePtr tuple_index_value = GetValueNode(cnode->input(2));
ValuePtr tuple_index_value = GetValueNode(cnode->input(TUPLE_GETITEM_INDEX_POS));
MS_EXCEPTION_IF_NULL(tuple_index_value);
if (!tuple_index_value->isa<Int64Imm>()) {
MS_LOG(EXCEPTION) << "The index of tuple getitem is not int32";
@@ -892,6 +888,51 @@ void InsertVirtualDivOp(const VirtualDivOp &virtual_div_op, const CNodePtr &node
}
}

// Insert a RealDiv(scale) in front of every tensor input of `node`, so each
// local squared-gradient contribution is rescaled before being accumulated
// into the global norm. `scale` is the mirror operator's device-group size;
// zero is rejected up front since it would mean a divide-by-zero.
void InsertRealDivOpToNodeInput(const CNodePtr &node, int64_t scale, const string &instance_name) {
MS_EXCEPTION_IF_NULL(node);
if (scale == 0) {
MS_LOG(EXCEPTION) << "Find the scale value is 0, you should check the mirror operators's group size.";
}
size_t node_size = node->inputs().size();
FuncGraphPtr func_graph = node->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
// Instance the RealDiv operator. NOTE(review): int64_t scale is implicitly
// converted to float here (CreateDivOp takes float) — confirm the precision
// loss is acceptable for large device numbers.
Operator div_op = CreateDivOp(scale);

// Insert it in front of each eligible input of the node (input 0 is the
// primitive itself, hence the loop starts at 1).
for (size_t index = 1; index < node_size; ++index) {
AnfNodePtr input = node->input(index);
MS_EXCEPTION_IF_NULL(input);
// Skip inputs that are not tensors (neither CNode nor Parameter) and monad
// inputs that only carry side-effect ordering.
if ((!input->isa<CNode>() && !input->isa<Parameter>()) || HasAbstractMonad(input)) {
continue;
}
InsertNode(div_op, node, index, node->input(index), func_graph, instance_name);
}
}

// Insert an AllReduce(SUM) over `group` in front of every tensor input of
// `node`, so the locally computed norm value is summed across the devices of
// that communication group.
void InsertAllReduceToNodeInput(const CNodePtr &node, const std::string &group, const std::string &instance_name) {
MS_EXCEPTION_IF_NULL(node);
size_t node_size = node->inputs().size();
FuncGraphPtr func_graph = node->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
// Instance the AllReduce(SUM) operator over the given group.
CheckGlobalDeviceManager();
Operator allreduce_op = CreateAllReduceOp(REDUCE_OP_SUM, group);

// Insert it in front of each eligible input of the node (input 0 is the
// primitive itself, hence the loop starts at 1).
for (size_t index = 1; index < node_size; ++index) {
AnfNodePtr input = node->input(index);
MS_EXCEPTION_IF_NULL(input);
// Skip inputs that are not tensors (neither CNode nor Parameter) and monad
// inputs.
if ((!input->isa<CNode>() && !input->isa<Parameter>()) || HasAbstractMonad(input)) {
continue;
}

InsertNode(allreduce_op, node, index, node->input(index), func_graph, instance_name);
}
}

void InsertVirtualOutput(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
vector<std::string> last_forward_node_ids;
vector<size_t> last_indexs;
@@ -2192,7 +2233,7 @@ void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes) {
if (StrategyFound(attrs)) {
MS_LOG(EXCEPTION) << "Setting strategy for Reshape goes for nothing!";
}
MS_ASSERT(cnode->inputs().size() == 3);
MS_ASSERT(cnode->inputs().size() == RESHAPE_INPUT_SIZE);
auto prev_layout_ptr = FindPrevLayout(cnode->input(1));
if (prev_layout_ptr) {
auto reshape_info_ptr = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
@@ -3092,6 +3133,148 @@ static void PipelinePostProcess(const FuncGraphPtr &root, const std::vector<AnfN
}
}

// Given the grad_scale-flagged ExpandDims node (`res_node`), follow the
// MakeTuple -> AddN -> Sqrt chain of the global-norm reduction and insert
// AllReduce(SUM) on the Sqrt's input: once over the current stage's devices
// and, when pipeline parallel is enabled, once more across stages.
static void InsertAllReduceForNormValue(const AnfNodePtr &res_node) {
auto cnode = res_node->cast<CNodePtr>();
auto graphs = res_node->func_graph();
MS_EXCEPTION_IF_NULL(graphs);
auto manager = graphs->manager();
MS_EXCEPTION_IF_NULL(manager);
auto node_user_map = manager->node_users();
// `res_node` must be the ExpandDims produced by clip_by_global_norm; anything
// else means the upstream pattern search went wrong.
if (!IsSomePrimitive(cnode, EXPAND_DIMS)) {
MS_LOG(ERROR) << "Expected the operator expand_dims, but found the " << GetPrimName(cnode)
<< "This may cause the calculation of the global norm incorrect";
return;
}
auto pipeline_stages = ParallelContext::GetInstance()->pipeline_stage_split_num();
auto expand_dims_node = node_user_map.at(res_node).front().first;
// Match MakeTuple(input 1) -> AddN(input 1) -> Sqrt(input 1); returns the
// Sqrt node, or nullptr when the chain is not found.
auto sqrt_node = MatchPattern(expand_dims_node, node_user_map, REDUCE_SUM_MATCH_PATTERN);
// A previously inserted AllReduce breaks the pattern, so this also makes the
// insertion idempotent across multiple parameters.
if (!sqrt_node) return;
auto cur_stage_rank_list = g_device_manager->GetDeviceListInThisStage();
Group cur_stage_device_list = g_device_manager->CreateGroup(cur_stage_rank_list);
InsertAllReduceToNodeInput(sqrt_node->cast<CNodePtr>(), cur_stage_device_list.name(), PARALLEL_GLOBALNORM);
MS_LOG(INFO) << "Insert the AllReduce for global norm value in stages succeed.";
if (pipeline_stages > 1) {
MS_LOG(INFO) << "Insert the AllReduce for global norm value between stages succeed.";
auto ranks_between_stages = g_device_manager->GetDeviceListBetweenStage();
Group group_between_stages = g_device_manager->CreateGroup(ranks_between_stages);
InsertAllReduceToNodeInput(sqrt_node->cast<CNodePtr>(), group_between_stages.name(), PARALLEL_GLOBALNORM_BETWEEN);
}
}

// BFS downstream from `node_ptr` through the listed gradient-side operators
// looking for an ExpandDims carrying a truthy GRAD_SCALE attribute — the
// marker planted by clip_by_global_norm. Returns nullptr when none is found
// within `limits` BFS levels.
// NOTE(review): the name keeps the original "Atrribute" spelling because
// callers elsewhere reference it.
AnfNodePtr FindPrimitiveWithAtrribute(const AnfNodePtr &node_ptr, const NodeUsersMap &node_users_map, uint32_t limits) {
std::queue<AnfNodePtr> visited;
AnfNodePtr queue_node = nullptr;
CNodePtr cnode = nullptr;
// `last_node` marks the end of the current BFS level; used for depth counting.
AnfNodePtr last_node = nullptr;
uint32_t depth = 0;
if (!node_ptr) {
return nullptr;
}
visited.push(node_ptr);
while (!visited.empty()) {
queue_node = visited.front();
visited.pop();
cnode = queue_node->cast<CNodePtr>();
// Found a candidate ExpandDims: accept it only when its GRAD_SCALE
// attribute is present and true.
if (IsSomePrimitive(cnode, EXPAND_DIMS)) {
auto value = GetAttrsFromAnfNode(queue_node, GRAD_SCALE);
if (!value || !GetValue<bool>(value)) {
continue;
}
return queue_node;
}
// Only traverse through operators that appear on the grad-norm path; any
// other node ends this branch of the search.
if (!IsSomePrimitiveList(cnode, {ENVIRONGET, MUL, SQUARE, REDUCE_SUM, EXPAND_DIMS, DEPEND, CAST, REF_TO_EMBED})) {
continue;
}
auto node_set = node_users_map.at(queue_node);
for (auto &node_user : node_set) {
visited.push(node_user.first);
}
// Depth bookkeeping: when the previous level's marker is dequeued, bump the
// depth and remember the last node of the freshly enqueued level.
// NOTE(review): visited.back() is undefined behavior if the queue is empty
// at this point — confirm traversed nodes always have users here.
if (!last_node || last_node == queue_node) {
if (++depth == limits) {
break;
}
last_node = visited.back();
}
}
return nullptr;
}

// For one trainable parameter, find the grad_scale-flagged ExpandDims that
// consumes its squared-gradient contribution (searched through the backward
// users of the parameter) and insert RealDiv(dev_num) in front of it, then
// insert the AllReduce(s) for the reduced global-norm value.
//
// `dev_num` is the mirror group's size: dividing each squared contribution by
// it compensates for the gradient accumulation across devices so the global
// norm keeps the single-device value.
//
// Cleanup vs. original: removed the unused local `prefix_node` and moved
// `expand_dims_node` into the loop (it was declared outside and reset at the
// top of every iteration).
static void InsertDivAndAllReduceForNorm(const NodeUsersMap &node_user_map, const AnfNodePtr &parameter,
                                         uint32_t dev_num) {
  auto params_user_set = node_user_map.at(parameter);
  for (auto &param_pair : params_user_set) {
    auto cnode = param_pair.first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    // Only the backward-graph users matter for the norm computation.
    if (cnode->in_forward_flag()) {
      continue;
    }
    AnfNodePtr expand_dims_node = FindPrimitiveWithAtrribute(cnode, node_user_map, MAX_BFS_DEPTH);
    if (!expand_dims_node) {
      continue;
    }
    auto value = GetAttrsFromAnfNode(expand_dims_node, GRAD_SCALE);
    if (!value || !GetValue<bool>(value)) {
      continue;
    }
    InsertRealDivOpToNodeInput(expand_dims_node->cast<CNodePtr>(), dev_num, PARALLEL_GLOBALNORM_DIV);
    MS_LOG(INFO) << "Insert the realdiv with " << dev_num << " for the parameter " << parameter->DebugString()
                 << "succeed!";
    // If already inserted allreduce, the pattern will not be matched and thus no allreduce will be inserted.
    InsertAllReduceForNormValue(expand_dims_node);
  }
}

// Find the Mirror-family operator (Mirror / MirrorMicroStep / MirrorMiniStep)
// attached to `parameter` in the forward graph. Looks at the parameter's
// direct users and, when a user is a Load or another trivial node, also at
// that user's own users. Returns nullptr when no mirror is found.
static AnfNodePtr GetMirrorOp(const NodeUsersMap &node_user_map, const AnfNodePtr &parameter) {
auto params_user_set = node_user_map.at(parameter);
for (auto &param_pair : params_user_set) {
auto cnode = param_pair.first->cast<CNodePtr>();
// NOTE(review): cnode is dereferenced below without a null check — confirm a
// parameter's users are always cnodes here.
std::vector<AnfNodePtr> candidate = {cnode};
if (!cnode->in_forward_flag()) {
continue;
}
// The mirror may sit behind a Load (or a similar trivial node): widen the
// candidate set with that node's users.
if (IsInTrivialNodeList(cnode) || IsSomePrimitive(cnode, LOAD)) {
auto load_users = node_user_map.at(param_pair.first);
std::transform(load_users.begin(), load_users.end(), std::back_inserter(candidate),
[](const auto &v) { return v.first; });
}
for (auto &node : candidate) {
auto local_cnode = node->cast<CNodePtr>();
if (!IsPrimitiveCNode(local_cnode, prim::kPrimMirror) &&
!IsPrimitiveCNode(local_cnode, prim::kPrimMirrorMicroStep) &&
!IsPrimitiveCNode(local_cnode, prim::kPrimMirrorMiniStep)) {
continue;
}
return node;
}
}
return nullptr;
}

// Entry point for the global-norm rewrite: for every trainable parameter that
// carries a mirror operator, insert RealDiv(dev_num) on its squared-norm
// contribution and AllReduce on the reduced norm, so clip_by_global_norm
// computes the true global norm under data/pipeline parallelism.
// NOTE(review): the name keeps the original "Handl" spelling because the call
// site references it; `all_nodes` is currently unused.
static void HandlGlobalNormScale(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
const FuncGraphManagerPtr &manager) {
auto parameters = root->parameters();
auto node_user_map = manager->node_users();
MS_LOG(INFO) << "Start to process the global norm";
for (auto &parameter : parameters) {
if (!ParameterRequireGrad(parameter)) continue;
auto mirror_node = GetMirrorOp(node_user_map, parameter);
if (!mirror_node) continue;
// The mirror's DEV_NUM attribute gives the size of the gradient group; it
// becomes the RealDiv scale.
auto device_num_ptr = GetAttrsFromAnfNode(mirror_node, DEV_NUM);
if (!device_num_ptr) {
MS_LOG(ERROR) << "The mirror operator is excepted to have device number attribute, but found none. This "
"will cause the global norm calculation with wrong precision.";
continue;
}
auto dev_num = device_num_ptr->cast<Int64ImmPtr>()->value();
if (dev_num == 0) continue;
InsertDivAndAllReduceForNorm(node_user_map, parameter, dev_num);
}
}

bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) {
#if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
if (ps::PSContext::instance()->is_server() || ps::PSContext::instance()->is_scheduler()) {
@@ -3201,6 +3384,8 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
// handle full split parammeters in grad accumulation, do not contain optimizer-sharding's parameter
HandleFullySplitParameters(root);

HandlGlobalNormScale(root, all_nodes, manager);

DumpGraph(root, std::string(STEP_PARALLEL_END));

// step parallel only run once


+ 53
- 7
mindspore/ccsrc/frontend/parallel/step_parallel_utils.cc View File

@@ -58,6 +58,17 @@ bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name) {
return (prim->name() == name);
}

bool IsSomePrimitiveList(const CNodePtr &cnode, const std::set<string> &check_list) {
return std::any_of(check_list.begin(), check_list.end(),
[cnode](const string &in) { return IsSomePrimitive(cnode, in); });
}

// Return the name of the primitive held by `node`; raises (via
// MS_EXCEPTION_IF_NULL) when the cnode does not carry a primitive.
std::string GetPrimName(const CNodePtr &node) {
auto prim = GetCNodePrimitive(node);
MS_EXCEPTION_IF_NULL(prim);
return prim->name();
}

TensorInfo GetInputsTensorInfo(const std::pair<AnfNodePtr, int64_t> &param_info) {
auto user_cnode = param_info.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(user_cnode);
@@ -108,11 +119,6 @@ AnfNodePtr CheckMakeTupleSplit(const AnfNodePtr &node, const FuncGraphManagerPtr
return first_node;
}

bool IsInNodeList(const CNodePtr &cnode, const std::set<string> &check_list) {
return std::any_of(check_list.begin(), check_list.end(),
[cnode](const string &in) { return IsSomePrimitive(cnode, in); });
}

bool IsParallelCareNode(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
ValueNodePtr prim_node = cnode->input(0)->cast<ValueNodePtr>();
@@ -394,9 +400,9 @@ AnfNodePtr GetChildCastNode(const AnfNodePtr &node_ptr, const NodeUsersMap &node
visited.pop();
cnode = queue_node->cast<CNodePtr>();
// MAKE_TUPLE will not appear after the load in the forward graph
if (IsInNodeList(cnode, {MAKE_TUPLE})) {
if (IsSomePrimitive(cnode, MAKE_TUPLE)) {
continue;
} else if (IsInAllGatherNodeList(cnode) || IsInNodeList(cnode, {LOAD, RESHAPE})) {
} else if (IsInAllGatherNodeList(cnode) || IsSomePrimitiveList(cnode, {LOAD, RESHAPE})) {
auto node_set = node_users_map.at(queue_node);
for (auto &node_user : node_set) {
visited.push(node_user.first);
@@ -473,6 +479,7 @@ AnfNodePtr CreateFP16Cast(const CNodePtr &node, const AnfNodePtr &pre_node, cons
type_node->set_abstract(compute_node_type->ToAbstract());
auto new_node = node->func_graph()->NewCNode({NewValueNode(prim), pre_node, type_node});
new_node->set_abstract(node->abstract());
new_node->set_in_forward_flag(true);
return new_node;
}

@@ -504,5 +511,44 @@ void SetCastForParamNotRecompute(const std::vector<AnfNodePtr> &all_nodes) {
}
}
}

std::shared_ptr<Value> GetAttrsFromAnfNode(const std::shared_ptr<AnfNode> &node, const string &key) {
if (!node) return nullptr;
auto cnode = node->cast<CNodePtr>();
auto prim = GetCNodePrimitive(cnode);
if (prim && prim->HasAttr(key)) {
return prim->GetAttr(key);
}
return nullptr;
}

// Walk forward from `node` along the user chain described by `match_pattern`,
// a list of (primitive name, input index) pairs: `node` itself must match
// pattern[0], and each subsequent entry must be a user consuming the previous
// node at the given input position. Returns the node matching the last
// pattern entry, or nullptr when the chain breaks.
AnfNodePtr MatchPattern(const AnfNodePtr &node, const NodeUsersMap &user_map,
const std::vector<std::pair<const std::string, int64_t>> &match_pattern) {
AnfNodePtr start_node = node;
bool find = false;
for (uint32_t i = 0; i < match_pattern.size(); ++i) {
find = false;
if (!IsSomePrimitive(start_node->cast<CNodePtr>(), {match_pattern[i].first})) {
break;
} else if (i == match_pattern.size() - 1) {
// Every entry matched; start_node now holds the final node of the chain.
find = true;
break;
}

// Advance to the user that matches the next pattern entry (same primitive
// name, consuming start_node at the expected input index).
auto next_node_users = user_map.at(start_node);
for (auto &next_node : next_node_users) {
if (i + 1 < match_pattern.size() &&
IsSomePrimitive(next_node.first->cast<CNodePtr>(), {match_pattern[i + 1].first}) &&
next_node.second == match_pattern[i + 1].second) {
start_node = next_node.first;
break;
}
}
// NOTE(review): if no user matched, start_node is left unchanged and the
// next iteration's primitive check ends the loop with find == false (this
// relies on adjacent pattern entries having distinct primitive names).
}
if (!find) {
start_node = nullptr;
}
return start_node;
}
} // namespace parallel
} // namespace mindspore

+ 7
- 1
mindspore/ccsrc/frontend/parallel/step_parallel_utils.h View File

@@ -20,6 +20,8 @@
#include <vector>
#include <string>
#include <utility>
#include <set>
#include <memory>
#include "base/base.h"
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
@@ -30,20 +32,24 @@ const int64_t TWO_INPUT_SIZE = 2;

// common method
bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name);
bool IsSomePrimitiveList(const CNodePtr &cnode, const std::set<string> &check_list);
bool IsParallelCareNode(const CNodePtr &cnode);
Shapes GetNodeShape(const AnfNodePtr &node);
std::string GetPrimName(const CNodePtr &node);
std::shared_ptr<Value> GetAttrsFromAnfNode(const std::shared_ptr<AnfNode> &node, const string &key);
std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::string &instance_name,
const CNodePtr &node);
std::string CreateInstanceName(const CNodePtr &node, size_t index);
TensorInfo GetInputsTensorInfo(const std::pair<AnfNodePtr, int64_t> &param_info);
AnfNodePtr CheckMakeTupleSplit(const AnfNodePtr &node, const FuncGraphManagerPtr &manager);
AnfNodePtr MatchPattern(const AnfNodePtr &node, const NodeUsersMap &user_map,
const std::vector<std::pair<const std::string, int64_t>> &match_pattern);

// for specific scenarios
RankList FindCommonMirrorGroup(const FuncGraphPtr &root);
void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input);
void SetStridedSliceSplitStrategy(const std::vector<AnfNodePtr> &all_nodes);
AnfNodePtr CreateFP16Cast(const CNodePtr &node, const AnfNodePtr &pre_node, const TypePtr &compute_node_type);
AnfNodePtr GetChildCastNode(const CNodePtr &cnode_ptr, const NodeUsersMap &node_users_map);
TypePtr FindChildCastWithFP32ToFP16(const CNodePtr &cnode_ptr, const NodeUsersMap &node_users_map);
void LabelGenMaskMicro(const FuncGraphPtr &root);
void SetCastForParamNotRecompute(const std::vector<AnfNodePtr> &all_nodes);


+ 6
- 1
mindspore/python/mindspore/ops/composite/clip_ops.py View File

@@ -128,11 +128,16 @@ def clip_by_value(x, clip_value_min=None, clip_value_max=None):
return x_max


# The attribute grad_scale is needed for enabling the parallel mode.
# If this is removed, C.clip_by_global_norm will have precision error in
# semi/auto parallel mode: the parallel compilation pass looks for this
# flagged ExpandDims to insert the global-norm scaling operators.
expand_dims = P.ExpandDims().add_prim_attr("grad_scale", True)


get_square_sum = C.MultitypeFuncGraph("get_square_sum")
@get_square_sum.register("Tensor")
def _get_square_sum(x):
    """Return sum(x ** 2) over all axes as a float32 tensor of shape (1,)."""
    norm = P.ReduceSum(False)(F.square(x), ())
    norm = expand_dims(F.cast(norm, mstype.float32), 0)
    return norm




+ 195
- 0
tests/ut/python/parallel/test_global_norm.py View File

@@ -0,0 +1,195 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test global norm test """
import re
import os
import shutil
import glob
import numpy as np

import mindspore as ms
import mindspore.nn as nn
import mindspore.dataset as ds
from mindspore import Tensor, Parameter, Model
from mindspore.train import DynamicLossScaleManager
from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell, MicroBatchInterleaved, PipelineCell
from mindspore.nn.optim import AdamWeightDecay
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore import context

class Net(nn.Cell):
    """Two-layer matmul net computing cast(x) @ p1 @ p2 - y in float16.

    p1 participates in optimizer sharding; p2 opts out via
    parallel_optimizer=False.
    """
    def __init__(self, param_type, strategy1, strategy2):
        super(Net, self).__init__()
        # Shard strategies drive the semi-auto parallel splitting of the matmuls.
        self.fc1 = P.MatMul().shard(strategy1)
        self.fc2 = P.MatMul().shard(strategy2)
        self.p1 = Parameter(Tensor(np.ones([48, 64]).astype(param_type)), name="weight1")
        self.p2 = Parameter(Tensor(np.ones([64, 16]).astype(param_type)), name="weight2", parallel_optimizer=False)
        self.sub = P.Sub()

    def construct(self, x, y):
        # Compute in float16 regardless of the parameter dtype.
        x = P.Cast()(x, ms.float16)
        p1 = P.Cast()(self.p1, ms.float16)
        p2 = P.Cast()(self.p2, ms.float16)
        x = self.fc1(x, p1)
        x = self.fc2(x, p2)
        return self.sub(x, y)


class Net2(nn.Cell):
    """Two copies of Net assigned to pipeline stages 0 and 1; output is their difference."""
    def __init__(self, param_type, strategy1, strategy2):
        super(Net2, self).__init__()
        self.net1 = Net(param_type, strategy1, strategy2)
        self.net2 = Net(param_type, strategy1, strategy2)
        # Explicit stage placement for pipeline-parallel tests.
        self.net1.pipeline_stage = 0
        self.net2.pipeline_stage = 1
        self.sub = P.Sub()

    def construct(self, x, y):
        out1 = self.net1(x, y)
        out2 = self.net2(x, y)
        return self.sub(out1, out2)


def get_dataset():
    """Build a 10-step GeneratorDataset of constant (inputs, label) pairs."""
    features = np.ones([64, 48]).astype(np.float32)
    targets = np.zeros([64, 16]).astype(np.float32)

    def generate_samples():
        for _ in range(10):
            yield features, targets

    return ds.GeneratorDataset(generate_samples, column_names=["inputs", "label"])


class CustomOptimizer(AdamWeightDecay):
    """AdamWeightDecay that clips gradients by global norm before applying them."""
    def __init__(self, params):
        super(CustomOptimizer, self).__init__(params)
        # Keep a direct handle to the parent's construct so the overriding
        # construct below can delegate to it after clipping.
        self.optimizer = super(CustomOptimizer, self).construct

    def construct(self, gradients):
        grads = C.clip_by_global_norm(gradients)
        return self.optimizer(grads)


def auto_parallel_compile_net(mode, dev_num, net, strategy1=None, strategy2=None,
                              interleaved_batch=2, stages=1, micro_size=1, param_type=np.float32,
                              loss_scale_manager=None):
    """Compile and run one training epoch of `net` under the given parallel setup.

    Args:
        mode: parallel mode string, e.g. "semi_auto_parallel".
        dev_num: total device number for the auto-parallel context.
        net: Net/Net2 class (not an instance) taking (param_type, strategy1, strategy2).
        strategy1, strategy2: shard strategies for the two MatMuls.
        interleaved_batch: MicroBatchInterleaved split count.
        stages, micro_size: pipeline stage count and micro batch count (stages > 1
            wraps the net in PipelineCell).
        param_type: numpy dtype of the parameters.
        loss_scale_manager: optional loss scale manager forwarded to Model.
    """
    context.set_context(mode=context.GRAPH_MODE)
    context.set_auto_parallel_context(parallel_mode=mode, device_num=dev_num, enable_parallel_optimizer=True,
                                      pipeline_stages=stages)

    net = MicroBatchInterleaved(net(param_type, strategy1, strategy2), interleaved_batch)
    if stages > 1:
        net = PipelineCell(net, micro_size=micro_size)
    net = _VirtualDatasetCell(net).set_comm_fusion(4)
    # Pipeline mode only trains the parameters belonging to this stage.
    parameters = net.trainable_params() if stages == 1 else net.infer_param_pipeline_stage()
    optimizer = CustomOptimizer(parameters)
    if loss_scale_manager:
        model = Model(net, optimizer=optimizer, loss_scale_manager=loss_scale_manager)
    else:
        model = Model(net, optimizer=optimizer)
    dataset = get_dataset()
    model.train(1, dataset)


class TestGlobalNormInserted:
    """Compile networks with IR dumping enabled and verify, by grepping the
    step_parallel_end IR, that the expected RealDiv / AllReduce operators for
    clip_by_global_norm were inserted."""

    def setup_method(self):
        # Unique IR dump directory per test instance.
        self.output_path = './graphs' + self.__str__()
        context.set_context(save_graphs=True,
                            save_graphs_path=self.output_path)

    def teardown_method(self):
        shutil.rmtree(self.output_path)

    def run_count_check(self, target_count, pattern):
        """
        Count the IR lines matching `pattern` and compare with the golden count.
        :param target_count: The golden number of matching lines in the IR file.
        :param pattern: The regex keyword expected in the IR file.

        """
        # Find the step_parallel_end IR dump (exactly one is expected).
        ir_files = glob.glob(os.path.join(self.output_path, 'rank_0', 'step_parallel_end*.ir'))
        assert len(ir_files) == 1
        appear_count = 0
        with open(ir_files[0], 'r') as fp:
            for line in fp:
                res = re.findall(pattern, line)
                if len(res) >= 1:
                    appear_count += 1
        assert appear_count == target_count

    def test_nonpipeline_global_norm(self):
        """
        Feature: Parallel ClipByGlobalNorm
        Description: Test the global norm when running in semi auto parallel mode, scale for data parallel should be 8
        Expectation: Exactly 1 RealDiv line with scale 8 and 2 lines tagged PARALLEL_GLOBALNORM appear in the IR
        """
        auto_parallel_compile_net("semi_auto_parallel", 8, Net, ((8, 1), (1, 1)), ((8, 1), (1, 1)),
                                  interleaved_batch=1, param_type=np.float32)
        self.run_count_check(target_count=1, pattern=r"=8.*PARALLEL_GLOBALNORM_DIV")
        self.run_count_check(target_count=2, pattern=r"PARALLEL_GLOBALNORM")

    def test_pipeline_global_norm(self):
        """
        Feature: Parallel ClipByGlobalNorm
        Description: Test the global norm when running in pipeline mode, scale for data parallel should be 16
        Expectation: Exactly 1 RealDiv line with scale 16 and 3 lines tagged PARALLEL_GLOBALNORM appear in the IR
        """
        auto_parallel_compile_net("semi_auto_parallel", 32, Net2, ((8, 1), (1, 1)), ((8, 1), (1, 1)),
                                  interleaved_batch=1, stages=2, micro_size=2, param_type=np.float32)
        self.run_count_check(target_count=1, pattern=r"=16.*PARALLEL_GLOBALNORM_DIV")
        self.run_count_check(target_count=3, pattern=r"PARALLEL_GLOBALNORM")

    def test_pipeline_global_norm_loss_scale(self):
        """
        Feature: Parallel ClipByGlobalNorm
        Description: Test the global norm in pipeline mode with a dynamic loss scale, scale should be 16
        Expectation: Exactly 1 RealDiv line with scale 16 and 3 lines tagged PARALLEL_GLOBALNORM appear in the IR
        """
        auto_parallel_compile_net("semi_auto_parallel", 32, Net2, ((8, 1), (1, 1)), ((8, 1), (1, 1)),
                                  interleaved_batch=1, stages=2, micro_size=2, param_type=np.float32,
                                  loss_scale_manager=DynamicLossScaleManager())
        self.run_count_check(target_count=1, pattern=r"=16.*PARALLEL_GLOBALNORM_DIV")
        self.run_count_check(target_count=3, pattern=r"PARALLEL_GLOBALNORM")


    def test_pipeline_global_norm_fp16(self):
        """
        Feature: Parallel ClipByGlobalNorm
        Description: Test the global norm in pipeline mode with float16 parameters, scale should be 16
        Expectation: Exactly 1 RealDiv line with scale 16 and 3 lines tagged PARALLEL_GLOBALNORM appear in the IR
        """
        auto_parallel_compile_net("semi_auto_parallel", 32, Net2, ((8, 1), (1, 1)), ((8, 1), (1, 1)),
                                  interleaved_batch=1, stages=2, micro_size=2, param_type=np.float16)
        self.run_count_check(target_count=1, pattern=r"=16.*PARALLEL_GLOBALNORM_DIV")
        self.run_count_check(target_count=3, pattern=r"PARALLEL_GLOBALNORM")

    def test_pipeline_global_norm_loss_scale_fp16(self):
        """
        Feature: Parallel ClipByGlobalNorm
        Description: Test the global norm in pipeline mode with float16 parameters and a dynamic loss scale
        Expectation: Exactly 1 RealDiv line with scale 16 and 3 lines tagged PARALLEL_GLOBALNORM appear in the IR
        """
        auto_parallel_compile_net("semi_auto_parallel", 32, Net2, ((8, 1), (1, 1)), ((8, 1), (1, 1)),
                                  interleaved_batch=1, stages=2, micro_size=2, param_type=np.float16,
                                  loss_scale_manager=DynamicLossScaleManager())
        self.run_count_check(target_count=1, pattern=r"=16.*PARALLEL_GLOBALNORM_DIV")
        self.run_count_check(target_count=3, pattern=r"PARALLEL_GLOBALNORM")

Loading…
Cancel
Save