diff --git a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc index ecd42db6bb..9ea583293b 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc @@ -65,7 +65,7 @@ double OperatorCost::GetMemoryCost(const std::vector& inputs, // return the per device communication cost in the forward phase. double MatMulCost::GetForwardCommCost(const std::vector& inputs, const std::vector& outputs, - const int32_t&) const { + int32_t) const { TensorInfo input0 = inputs[0]; TensorInfo output0 = outputs[0]; Shape input0_shape = input0.shape(); @@ -81,7 +81,7 @@ double MatMulCost::GetForwardCommCost(const std::vector& inputs, con // return the per device communication cost in the forward phase. double MatMulCost::GetBackwardCommCost(const std::vector& inputs, const std::vector&, - const int32_t& stage_id) const { + int32_t stage_id) const { // In backward phase, the communication cost is incurred only when tensor B is a Parameter and tensor B does not // fully utilize all devices double result = 0.0; @@ -108,7 +108,7 @@ double MatMulCost::GetBackwardCommCost(const std::vector& inputs, co // Return the per device computation cost in the forward phase. The cost is calculated according to the bytes // this operator uses double MatMulCost::GetForwardComputationCost(const std::vector& inputs, - const std::vector& outputs, const int32_t&) const { + const std::vector& outputs, int32_t) const { // In forward phase, the compuatation cost = slice(A) + slice(B) + (0 or 1) allreduce(slice(C)) double result = 0.0; TensorInfo output0 = outputs[0]; @@ -127,7 +127,7 @@ double MatMulCost::GetForwardComputationCost(const std::vector& inpu // Return the per device computation cost in the forward phase. The cost is calculated according to the bytes // this operator uses double MatMulCost::GetBackwardComputationCost(const std::vector& inputs, const std::vector&, - const int32_t& stage_id) const { + int32_t stage_id) const { // In backward phase, the computation cost = (0 or 1) allreduce(slice(B)) double result = 0.0; if (is_parameter_[1]) { @@ -152,14 +152,14 @@ double MatMulCost::GetBackwardComputationCost(const std::vector& inp // Return the per device communication cost in the forward phase. double ActivationCost::GetForwardCommCost(const std::vector&, const std::vector&, - const int32_t&) const { + int32_t) const { // ReLU is the element-wise operator, thus it does not need communication in the forward phase return 0.0; } // Return the per device communication cost in the backward phase. double ActivationCost::GetBackwardCommCost(const std::vector& inputs, const std::vector&, - const int32_t& stage_id) const { + int32_t stage_id) const { double result = 0.0; if (is_parameter_[0]) { TensorInfo input1 = inputs[0]; @@ -181,7 +181,7 @@ double ActivationCost::GetBackwardCommCost(const std::vector& inputs // Return the per device computation cost in the forward phase. The cost is calculated according to the bytes // this operator uses double ActivationCost::GetForwardComputationCost(const std::vector& inputs, const std::vector&, - const int32_t&) const { + int32_t) const { TensorInfo input0_info = inputs[0]; Shape input0_slice_shape = input0_info.slice_shape(); return ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]); @@ -190,20 +190,19 @@ double ActivationCost::GetForwardComputationCost(const std::vector& // Return the per device computation cost in the forward phase. The cost is calculated according to the bytes // this operator uses double ActivationCost::GetBackwardComputationCost(const std::vector&, const std::vector&, - const int32_t&) const { + int32_t) const { return 0.0; } // Return the per device communication cost in the forward phase. -double SoftmaxCost::GetForwardCommCost(const std::vector&, const std::vector&, - const int32_t&) const { +double SoftmaxCost::GetForwardCommCost(const std::vector&, const std::vector&, int32_t) const { // In the forward phase, the communication cost = 0 return 0.0; } // Return the per device communication cost in the backward phase. double SoftmaxCost::GetBackwardCommCost(const std::vector& inputs, const std::vector&, - const int32_t& stage_id) const { + int32_t stage_id) const { double result = 0.0; if (is_parameter_[0]) { TensorInfo input1 = inputs[0]; @@ -225,7 +224,7 @@ double SoftmaxCost::GetBackwardCommCost(const std::vector& inputs, c // Return the per device computation cost in the forward phase. The cost is calculated according to the bytes // this operator uses double SoftmaxCost::GetForwardComputationCost(const std::vector& inputs, const std::vector&, - const int32_t&) const { + int32_t) const { // In the forward phase, the computation cost = slice(A) TensorInfo input0 = inputs[0]; Shape input0_slice_shape = input0.slice_shape(); @@ -235,21 +234,20 @@ double SoftmaxCost::GetForwardComputationCost(const std::vector& inp // Return the per device computation cost in the forward phase. The cost is calculated according to the bytes // this operator uses double SoftmaxCost::GetBackwardComputationCost(const std::vector&, - const std::vector&, - const int32_t&) const { + const std::vector&, int32_t) const { return 0.0; } // return the per device communication cost in the forward phase. double TmpIdentityCost::GetForwardCommCost(const std::vector&, - const std::vector&, const int32_t&) const { + const std::vector&, int32_t) const { // Identity is the element-wise operator, thus it does not need communication in the forward phase return 0.0; } // return the per device communication cost in the backward phase. double TmpIdentityCost::GetBackwardCommCost(const std::vector&, - const std::vector&, const int32_t&) const { + const std::vector&, int32_t) const { // Identity is the element-wise operator, thus it does not need communication in the backward phase return 0.0; } @@ -257,16 +255,14 @@ double TmpIdentityCost::GetBackwardCommCost(const std::vector&, - const std::vector&, - const int32_t&) const { + const std::vector&, int32_t) const { return 0.0; } // Return the per device computation cost in the backward phase. The cost is calculated according to the bytes // this operator uses double TmpIdentityCost::GetBackwardComputationCost(const std::vector&, - const std::vector&, - const int32_t&) const { + const std::vector&, int32_t) const { return 0.0; } @@ -277,7 +273,7 @@ double TmpIdentityCost::GetMemoryCost(const std::vector&, const std: double BatchParallelCost::GetForwardComputationCost(const std::vector& inputs, const std::vector&, - const int32_t&) const { + int32_t) const { double cost = 0.0; for (size_t i = 0; i < inputs.size(); ++i) { cost += ListProduct(inputs[i].slice_shape()) * static_cast(inputs_type_lengths_[i]); @@ -287,20 +283,19 @@ double BatchParallelCost::GetForwardComputationCost(const std::vector&, const std::vector&, - const int32_t&) const { + int32_t) const { return 0.0; } // return the per device communication cost in the forward phase. -double PReLUCost::GetForwardCommCost(const std::vector&, const std::vector&, - const int32_t&) const { +double PReLUCost::GetForwardCommCost(const std::vector&, const std::vector&, int32_t) const { // prelu does not need communication in the forward phase return 0.0; } // return the per device communication cost in the backward phase. double PReLUCost::GetBackwardCommCost(const std::vector& inputs, const std::vector&, - const int32_t& stage_id) const { + int32_t stage_id) const { double result = 0.0; if (is_parameter_[1]) { TensorInfo input1 = inputs[1]; @@ -323,7 +318,7 @@ double PReLUCost::GetBackwardCommCost(const std::vector& inputs, con // Return the per device computation cost in the forward phase. The cost is calculated according to the bytes // this operator uses double PReLUCost::GetForwardComputationCost(const std::vector& inputs, const std::vector&, - const int32_t&) const { + int32_t) const { // In forward phase, the computation cost = slice(A) + slice(B) Shape input0_slice_shape = inputs[0].slice_shape(); Shape input1_slice_shape = inputs[1].slice_shape(); @@ -336,7 +331,7 @@ double PReLUCost::GetForwardComputationCost(const std::vector& input // this operator uses double PReLUCost::GetBackwardComputationCost(const std::vector& inputs, const std::vector&, - const int32_t& stage_id) const { + int32_t stage_id) const { // In backward phase, the computation cost = (0 or 1) allreduce(slice(B)) double result = 0.0; if (is_parameter_[1]) { @@ -360,15 +355,13 @@ double PReLUCost::GetBackwardComputationCost(const std::vector&, const std::vector&, - const int32_t&) const { +double OneHotCost::GetForwardCommCost(const std::vector&, const std::vector&, int32_t) const { // onehot does not need communication in the forward phase return 0.0; } // return the per device communication cost in the backward phase. -double OneHotCost::GetBackwardCommCost(const std::vector&, const std::vector&, - const int32_t&) const { +double OneHotCost::GetBackwardCommCost(const std::vector&, const std::vector&, int32_t) const { // onehot does not need communication in the backward phase return 0.0; } @@ -376,7 +369,7 @@ double OneHotCost::GetBackwardCommCost(const std::vector&, const std // Return the per device computation cost in the forward phase. The cost is calculated according to the bytes // this operator uses double OneHotCost::GetForwardComputationCost(const std::vector& inputs, const std::vector&, - const int32_t&) const { + int32_t) const { // In onehot's forward phase, the computation cost = slice(A) Shape input0_slice_shape = inputs[0].slice_shape(); return ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]); @@ -385,20 +378,20 @@ double OneHotCost::GetForwardComputationCost(const std::vector& inpu // Return the per device computation cost in the backward phase. The cost is calculated according to the bytes // this operator uses double OneHotCost::GetBackwardComputationCost(const std::vector&, const std::vector&, - const int32_t&) const { + int32_t) const { return 0.0; } // return the per device communication cost in the forward phase. double SoftmaxCrossEntropyWithLogitsCost::GetForwardCommCost(const std::vector&, - const std::vector&, const int32_t&) const { + const std::vector&, int32_t) const { // SoftmaxCrossEntropyWithLogitsCost does not need communication in the forward phase return 0.0; } // return the per device communication cost in the backward phase. double SoftmaxCrossEntropyWithLogitsCost::GetBackwardCommCost(const std::vector&, - const std::vector&, const int32_t&) const { + const std::vector&, int32_t) const { // SoftmaxCrossEntropyWithLogitsCost does not need communication in the backward phase return 0.0; } @@ -406,8 +399,7 @@ double SoftmaxCrossEntropyWithLogitsCost::GetBackwardCommCost(const std::vector< // Return the per device computation cost in the forward phase. The cost is calculated according to the bytes // this operator uses double SoftmaxCrossEntropyWithLogitsCost::GetForwardComputationCost(const std::vector& inputs, - const std::vector&, - const int32_t&) const { + const std::vector&, int32_t) const { // In forward phase, the computation cost = slice(A) + slice(B) Shape input0_slice_shape = inputs[0].slice_shape(); Shape input1_slice_shape = inputs[1].slice_shape(); @@ -419,14 +411,13 @@ double SoftmaxCrossEntropyWithLogitsCost::GetForwardComputationCost(const std::v // Return the per device computation cost in the backward phase. The cost is calculated according to the bytes // this operator uses double SoftmaxCrossEntropyWithLogitsCost::GetBackwardComputationCost(const std::vector&, - const std::vector&, - const int32_t&) const { + const std::vector&, int32_t) const { return 0.0; } // return the per device communication cost in the forward phase. double ReshapeCost::GetForwardCommCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const { + int32_t stage_id) const { CheckGlobalDeviceManager(); MS_EXCEPTION_IF_NULL(g_device_manager); RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id); @@ -441,15 +432,14 @@ double ReshapeCost::GetForwardCommCost(const std::vector& inputs, co } // return the per device communication cost in the backward phase. -double ReshapeCost::GetBackwardCommCost(const std::vector&, const std::vector&, - const int32_t&) const { +double ReshapeCost::GetBackwardCommCost(const std::vector&, const std::vector&, int32_t) const { return 0.0; } // Return the per device computation cost in the forward phase. The cost is calculated according to the bytes // this operator uses double ReshapeCost::GetForwardComputationCost(const std::vector& inputs, - const std::vector& outputs, const int32_t& stage_id) const { + const std::vector& outputs, int32_t stage_id) const { CheckGlobalDeviceManager(); MS_EXCEPTION_IF_NULL(g_device_manager); RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id); @@ -466,13 +456,12 @@ double ReshapeCost::GetForwardComputationCost(const std::vector& inp // Return the per device computation cost in the backward phase. The cost is calculated according to the bytes // this operator uses double ReshapeCost::GetBackwardComputationCost(const std::vector&, - const std::vector&, - const int32_t&) const { + const std::vector&, int32_t) const { return 0.0; } double ArithmeticCost::GetForwardComputationCost(const std::vector& inputs, const std::vector&, - const int32_t&) const { + int32_t) const { double result; result = ListProduct(inputs[0].slice_shape()) * static_cast(inputs_type_lengths_[0]) + ListProduct(inputs[1].slice_shape()) * static_cast(inputs_type_lengths_[1]); @@ -480,7 +469,7 @@ double ArithmeticCost::GetForwardComputationCost(const std::vector& } double ArithmeticCost::GetBackwardComputationCost(const std::vector& inputs, const std::vector&, - const int32_t& stage_id) const { + int32_t stage_id) const { double result = 0.0; CheckGlobalDeviceManager(); MS_EXCEPTION_IF_NULL(g_device_manager); @@ -515,7 +504,7 @@ double ArithmeticCost::GetBackwardComputationCost(const std::vector& } double ArithmeticCost::GetBackwardCommCost(const std::vector& inputs, const std::vector&, - const int32_t& stage_id) const { + int32_t stage_id) const { double result = 0.0; CheckGlobalDeviceManager(); MS_EXCEPTION_IF_NULL(g_device_manager); @@ -550,7 +539,7 @@ double ArithmeticCost::GetBackwardCommCost(const std::vector& inputs return result; } -bool IsDataParallel(const Shape& shape, const Shape& slice_shape, const int32_t& stage_id) { +bool IsDataParallel(const Shape& shape, const Shape& slice_shape, int32_t stage_id) { CheckGlobalDeviceManager(); MS_EXCEPTION_IF_NULL(g_device_manager); auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); @@ -560,7 +549,7 @@ bool IsDataParallel(const Shape& shape, const Shape& slice_shape, const int32_t& } double ReduceMethodCost::GetForwardCommCost(const std::vector& inputs, - const std::vector& outputs, const int32_t& stage_id) const { + const std::vector& outputs, int32_t stage_id) const { double result = 0.0; TensorInfo input0 = inputs[0]; TensorInfo output0 = outputs[0]; @@ -571,7 +560,7 @@ double ReduceMethodCost::GetForwardCommCost(const std::vector& input } std::vector dim_list = input0.reduce_dim(); std::vector::iterator pos; - pos = std::find_if(dim_list.begin(), dim_list.end(), [input0_shape, input0_slice_shape](const int32_t& index) { + pos = std::find_if(dim_list.begin(), dim_list.end(), [input0_shape, input0_slice_shape](int32_t index) { return input0_shape[IntToSize(index)] != input0_slice_shape[IntToSize(index)]; }); if (pos != dim_list.end()) { @@ -582,7 +571,7 @@ double ReduceMethodCost::GetForwardCommCost(const std::vector& input } double ReduceMethodCost::GetBackwardCommCost(const std::vector& inputs, const std::vector&, - const int32_t& stage_id) const { + int32_t stage_id) const { double result = 0.0; if (is_parameter_[0]) { TensorInfo input_tensor_info = inputs[0]; @@ -605,8 +594,7 @@ double ReduceMethodCost::GetBackwardCommCost(const std::vector& inpu } double ReduceMethodCost::GetForwardComputationCost(const std::vector& inputs, - const std::vector& outputs, - const int32_t& stage_id) const { + const std::vector& outputs, int32_t stage_id) const { double result = 0.0; TensorInfo input0 = inputs[0]; TensorInfo output0 = outputs[0]; @@ -615,7 +603,7 @@ double ReduceMethodCost::GetForwardComputationCost(const std::vector Shape input0_shape = input0.shape(); if (!cross_batch_ || !IsDataParallel(input0_shape, input0_slice_shape, stage_id)) { std::vector::iterator pos; - pos = std::find_if(dim_list.begin(), dim_list.end(), [input0_shape, input0_slice_shape](const int32_t& index) { + pos = std::find_if(dim_list.begin(), dim_list.end(), [input0_shape, input0_slice_shape](int32_t index) { return input0_shape[IntToSize(index)] != input0_slice_shape[IntToSize(index)]; }); if (pos != dim_list.end()) { @@ -628,8 +616,7 @@ double ReduceMethodCost::GetForwardComputationCost(const std::vector } double ReduceMeanCost::GetForwardComputationCost(const std::vector& inputs, - const std::vector& outputs, - const int32_t& stage_id) const { + const std::vector& outputs, int32_t stage_id) const { double result = 0.0; TensorInfo input0 = inputs[0]; TensorInfo output0 = outputs[0]; @@ -638,7 +625,7 @@ double ReduceMeanCost::GetForwardComputationCost(const std::vector& Shape input0_shape = input0.shape(); if (!cross_batch_ || !IsDataParallel(input0_shape, input0_slice_shape, stage_id)) { std::vector::iterator pos; - pos = std::find_if(dim_list.begin(), dim_list.end(), [input0_shape, input0_slice_shape](const int32_t& index) { + pos = std::find_if(dim_list.begin(), dim_list.end(), [input0_shape, input0_slice_shape](int32_t index) { return input0_shape[IntToSize(index)] != input0_slice_shape[IntToSize(index)]; }); if (pos != dim_list.end()) { @@ -651,7 +638,7 @@ double ReduceMeanCost::GetForwardComputationCost(const std::vector& } double DropOutCost::GetForwardComputationCost(const std::vector& inputs, const std::vector&, - const int32_t&) const { + int32_t) const { if (inputs.empty()) { return 0.0; } @@ -661,21 +648,20 @@ double DropOutCost::GetForwardComputationCost(const std::vector& inp } // return the per device communication cost in the forward phase. -double GatherV2Cost::GetForwardCommCost(const std::vector&, const std::vector&, - const int32_t&) const { +double GatherV2Cost::GetForwardCommCost(const std::vector&, const std::vector&, int32_t) const { // GatherV2Cost does not need communication in the forward phase return 0.0; } // return the per device communication cost in the backward phase. double GatherV2Cost::GetBackwardCommCost(const std::vector&, const std::vector&, - const int32_t&) const { + int32_t) const { // GatherV2Cost does not need communication in the backward phase return 0.0; } double GatherV2Cost::GetForwardComputationCost(const std::vector& inputs, const std::vector&, - const int32_t&) const { + int32_t) const { // In forward phase, the computation cost = slice(A) + slice(B) Shape input0_slice_shape = inputs[0].slice_shape(); Shape input1_slice_shape = inputs[1].slice_shape(); @@ -685,8 +671,56 @@ double GatherV2Cost::GetForwardComputationCost(const std::vector& in } double GatherV2Cost::GetBackwardComputationCost(const std::vector&, const std::vector&, - const int32_t&) const { + int32_t) const { return 0.0; } + +double LayerNormCost::GetBackwardCommCost(const std::vector& inputs, const std::vector&, + int32_t stage_id) const { + double result = 0.0; + if (is_parameter_.size() != inputs.size()) { + MS_LOG(EXCEPTION) << "Invalid parameter size " << is_parameter_.size() << " for layer norm cost"; + } + if (inputs_type_lengths_.size() != inputs.size()) { + MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for layer norm cost"; + } + + MS_EXCEPTION_IF_NULL(g_device_manager); + auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); + + for (size_t index = 0; index < inputs.size(); ++index) { + if (is_parameter_[index]) { + TensorInfo tensor_info = inputs[index]; + Shape shape = tensor_info.shape(); + Shape slice_shape = tensor_info.slice_shape(); + int32_t used_device_num = 1; + for (size_t i = 0; i < shape.size(); ++i) { + if (slice_shape[i] == 0) { + MS_LOG(EXCEPTION) << "Invalid slice shape " << ShapeToString(slice_shape); + } + used_device_num *= shape[i] / slice_shape[i]; + } + if (total_device_num != IntToSize(used_device_num)) { + result += ListProduct(slice_shape) * static_cast(inputs_type_lengths_[index]); + } + } + } + return result; +} + +double LayerNormCost::GetForwardComputationCost(const std::vector& inputs, const std::vector&, + int32_t) const { + double result = 0.0; + if (inputs_type_lengths_.size() != inputs.size()) { + MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for layer norm cost"; + } + + for (size_t index = 0; index < inputs.size(); ++index) { + TensorInfo tensor_info = inputs[index]; + Shape slice_shape = tensor_info.slice_shape(); + result += ListProduct(slice_shape) * static_cast(inputs_type_lengths_[index]); + } + return result; +} } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h index 7dc45bae71..f16dfa21fc 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h +++ b/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h @@ -72,18 +72,18 @@ class OperatorCost { // per device communication cost virtual double GetCommCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const = 0; + int32_t stage_id) const = 0; virtual double GetForwardCommCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const = 0; + int32_t stage_id) const = 0; virtual double GetBackwardCommCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const = 0; + int32_t stage_id) const = 0; // per device computation cost virtual double GetComputationCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const = 0; + int32_t stage_id) const = 0; virtual double GetForwardComputationCost(const std::vector& inputs, - const std::vector& outputs, const int32_t& stage_id) const = 0; + const std::vector& outputs, int32_t stage_id) const = 0; virtual double GetBackwardComputationCost(const std::vector& inputs, - const std::vector& outputs, const int32_t& stage_id) const = 0; + const std::vector& outputs, int32_t stage_id) const = 0; // per device PEAK memory cost in a training iteration // Typically, the PEAK memory cost contributed by an operator is its output (if the output is parameter-invovled), // plus necessary inputs. @@ -114,23 +114,23 @@ class MatMulCost : public OperatorCost { // per device communication cost double GetCommCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override { + int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } double GetForwardCommCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + int32_t stage_id) const override; double GetBackwardCommCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + int32_t stage_id) const override; // per device computation cost double GetComputationCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override { + int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + int32_t stage_id) const override; double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + int32_t stage_id) const override; }; using MatMulCostPtr = std::shared_ptr; @@ -141,21 +141,21 @@ class ActivationCost : public OperatorCost { ~ActivationCost() override = default; double GetCommCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override { + int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } double GetForwardCommCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + int32_t stage_id) const override; double GetBackwardCommCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + int32_t stage_id) const override; double GetComputationCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override { + int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + int32_t stage_id) const override; double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + int32_t stage_id) const override; }; using ActivationCostPtr = std::shared_ptr; using TransposeCost = ActivationCost; @@ -168,21 +168,21 @@ class SoftmaxCost : public OperatorCost { ~SoftmaxCost() override = default; double GetCommCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override { + int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } double GetForwardCommCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + int32_t stage_id) const override; double GetBackwardCommCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + int32_t stage_id) const override; double GetComputationCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override { + int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + int32_t stage_id) const override; double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, - const int32_t&) const override; + int32_t) const override; }; using SoftmaxCostPtr = std::shared_ptr; @@ -193,21 +193,21 @@ class TmpIdentityCost : public OperatorCost { ~TmpIdentityCost() override = default; double GetCommCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override { + int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } double GetForwardCommCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + int32_t stage_id) const override; double GetBackwardCommCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + int32_t stage_id) const override; double GetComputationCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override { + int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + int32_t stage_id) const override; double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + int32_t stage_id) const override; // per device PEAK memory cost in a training iteration double GetMemoryCost(const std::vector& inputs, const std::vector& outputs) const override; }; @@ -220,25 +220,23 @@ class BatchParallelCost : public OperatorCost { ~BatchParallelCost() override = default; double GetCommCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override { + int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector&, const std::vector&, - const int32_t&) const override { + double GetForwardCommCost(const std::vector&, const std::vector&, int32_t) const override { return 0.0; } - double GetBackwardCommCost(const std::vector&, const std::vector&, - const int32_t&) const override { + double GetBackwardCommCost(const std::vector&, const std::vector&, int32_t) const override { return 0.0; } double GetComputationCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override { + int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + int32_t stage_id) const override; double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + int32_t stage_id) const override; }; using BatchParallelCostPtr = std::shared_ptr; @@ -249,27 +247,25 @@ class VirtualDatasetCost : public OperatorCost { ~VirtualDatasetCost() override = default; double GetCommCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override { + int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector&, const std::vector&, - const int32_t&) const override { + double GetForwardCommCost(const std::vector&, const std::vector&, int32_t) const override { return 0.0; } - double GetBackwardCommCost(const std::vector&, const std::vector&, - const int32_t&) const override { + double GetBackwardCommCost(const std::vector&, const std::vector&, int32_t) const override { return 0.0; } double GetComputationCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override { + int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } double GetForwardComputationCost(const std::vector&, const std::vector&, - const int32_t&) const override { + int32_t) const override { return 0.0; } double GetBackwardComputationCost(const std::vector&, const std::vector&, - const int32_t&) const override { + int32_t) const override { return 0.0; } // per device PEAK memory cost in a training iteration @@ -286,29 +282,27 @@ class GeneratorBaseCost : public OperatorCost { ~GeneratorBaseCost() override = default; double GetCommCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override { + int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector&, const std::vector&, - const int32_t&) const override { + double GetForwardCommCost(const std::vector&, const std::vector&, int32_t) const override { return 0.0; } - double GetBackwardCommCost(const std::vector&, const std::vector&, - const int32_t&) const override { + double GetBackwardCommCost(const std::vector&, const std::vector&, int32_t) const override { return 0.0; } double GetComputationCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override { + int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } // Inputs vector is empty for generator ops. double GetForwardComputationCost(const std::vector&, const std::vector&, - const int32_t&) const override { + int32_t) const override { return 0.0; } // Generator ops don't have backward steps. double GetBackwardComputationCost(const std::vector&, const std::vector&, - const int32_t&) const override { + int32_t) const override { return 0.0; } }; @@ -322,23 +316,23 @@ class PReLUCost : public OperatorCost { // per device communication cost double GetCommCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override { + int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } double GetForwardCommCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + int32_t stage_id) const override; double GetBackwardCommCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + int32_t stage_id) const override; // per device computation cost double GetComputationCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override { + int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + int32_t stage_id) const override; double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + int32_t stage_id) const override; }; using PReLUCostPtr = std::shared_ptr; @@ -350,23 +344,23 @@ class OneHotCost : public OperatorCost { // per device communication cost double GetCommCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override { + int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } double GetForwardCommCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + int32_t stage_id) const override; double GetBackwardCommCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + int32_t stage_id) const override; // per device computation cost double GetComputationCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override { + int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + int32_t stage_id) const override; double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + int32_t stage_id) const override; }; using OneHotCostPtr = std::shared_ptr; @@ -378,23 +372,23 @@ class SoftmaxCrossEntropyWithLogitsCost : public OperatorCost { // per device communication cost double GetCommCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override { + int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } double GetForwardCommCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + int32_t stage_id) const override; double GetBackwardCommCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + int32_t stage_id) const override; // per device computation cost double GetComputationCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override { + int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + int32_t stage_id) const override; double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + int32_t stage_id) const override; }; using SoftmaxCrossEntropyWithLogitsCostPtr = std::shared_ptr; @@ -407,27 +401,27 @@ class ReshapeCost : public OperatorCost { // per device communication cost double GetCommCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override { + int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } double GetForwardCommCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + int32_t stage_id) const override; double GetBackwardCommCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + int32_t stage_id) const override; // per device computation cost double GetComputationCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override { + int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + int32_t stage_id) const override; double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + int32_t stage_id) const override; }; using ReshapeCostPtr = std::shared_ptr; @@ -438,24 +432,22 @@ class ArithmeticCost : public OperatorCost { ~ArithmeticCost() override = default; double GetCommCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override { + int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector&, const std::vector&, - const int32_t&) const override { + double GetForwardCommCost(const std::vector&, const std::vector&, int32_t) const override { return 0.0; } - double GetBackwardCommCost(const std::vector&, const std::vector&, - const int32_t&) const override; + double GetBackwardCommCost(const std::vector&, const std::vector&, int32_t) const override; double GetComputationCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override { + int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + int32_t stage_id) const override; double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + int32_t stage_id) const override; }; using ArithmeticCostPtr = std::shared_ptr; using BiasAddCost = ArithmeticCost; @@ -468,21 +460,21 @@ class ReduceMethodCost : public OperatorCost { ~ReduceMethodCost() override = default; double GetCommCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override { + int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } double GetForwardCommCost(const std::vector&, const std::vector&, - const int32_t& stage_id) const override; + int32_t stage_id) const override; double GetBackwardCommCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + int32_t stage_id) const override; double GetComputationCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override { + int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + int32_t stage_id) const override; double GetBackwardComputationCost(const std::vector&, const std::vector&, - const int32_t&) const override { + int32_t) const override { return 0.0; } void set_cross_batch(bool cb) { cross_batch_ = cb; } @@ -499,7 +491,7 @@ class ReduceMeanCost : public ReduceMethodCost { ~ReduceMeanCost() override = default; double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + int32_t stage_id) const override; }; using ReduceMeanCostPtr = std::shared_ptr; @@ -510,29 +502,27 @@ class GetNextCost : public OperatorCost { ~GetNextCost() override = default; double GetCommCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override { + int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector&, const std::vector&, - const int32_t&) const override { + double GetForwardCommCost(const std::vector&, const std::vector&, int32_t) const override { return 0.0; } - double GetBackwardCommCost(const std::vector&, const std::vector&, - const int32_t&) const override { + double GetBackwardCommCost(const std::vector&, const std::vector&, int32_t) const override { return 0.0; } double GetComputationCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override { + int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } // Inputs vector is empty for generator ops. double GetForwardComputationCost(const std::vector&, const std::vector&, - const int32_t&) const override { + int32_t) const override { return 0.0; } // Generator ops don't have backward steps. double GetBackwardComputationCost(const std::vector&, const std::vector&, - const int32_t&) const override { + int32_t) const override { return 0.0; } }; @@ -545,25 +535,51 @@ class DropOutCost : public OperatorCost { ~DropOutCost() override = default; double GetCommCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override { + int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector&, const std::vector&, - const int32_t&) const override { + double GetForwardCommCost(const std::vector&, const std::vector&, int32_t) const override { + return 0.0; + } + double GetBackwardCommCost(const std::vector&, const std::vector&, int32_t) const override { + return 0.0; + } + double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + int32_t stage_id) const override { + return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); + } + double GetForwardComputationCost(const std::vector&, const std::vector&, + int32_t) const override; + double GetBackwardComputationCost(const std::vector&, const std::vector&, + int32_t) const override { return 0.0; } - double GetBackwardCommCost(const std::vector&, const std::vector&, - const int32_t&) const override { +}; + +using DropOutCostPtr = std::shared_ptr; + +class LayerNormCost : public OperatorCost { + public: + explicit LayerNormCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} + LayerNormCost() : OperatorCost(true) {} + ~LayerNormCost() override = default; + + double GetCommCost(const std::vector& inputs, const std::vector& outputs, + int32_t stage_id) const override { + return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); + } + double GetForwardCommCost(const std::vector&, const std::vector&, int32_t) const override { return 0.0; } + double GetBackwardCommCost(const std::vector&, const std::vector&, int32_t) const override; double GetComputationCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override { + int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } double GetForwardComputationCost(const std::vector&, const std::vector&, - const int32_t&) const override; + int32_t) const override; double GetBackwardComputationCost(const std::vector&, const std::vector&, - const int32_t&) const override { + int32_t) const override { return 0.0; } }; @@ -577,21 +593,21 @@ class GatherV2Cost : public OperatorCost { ~GatherV2Cost() override = default; double GetCommCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override { + int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } double GetForwardCommCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + int32_t stage_id) const override; double GetBackwardCommCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + int32_t stage_id) const override; double GetComputationCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override { + int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, - const int32_t& stage_id) const override; + int32_t stage_id) const override; double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, - const int32_t&) const override; + int32_t) const override; }; using GatherV2CostPtr = std::shared_ptr; diff --git a/mindspore/ccsrc/parallel/dynamic_creator.h b/mindspore/ccsrc/parallel/dynamic_creator.h index 1b864cd8bf..953380fb32 100644 --- a/mindspore/ccsrc/parallel/dynamic_creator.h +++ b/mindspore/ccsrc/parallel/dynamic_creator.h @@ -101,6 +101,7 @@ REGISTER(CosInfo); REGISTER(ACosInfo); REGISTER(LogicalNotInfo); REGISTER(L2NormalizeInfo); +REGISTER(LayerNormInfo); REGISTER(ReduceMaxInfo); REGISTER(ArgMaxWithValueInfo); REGISTER(ArgMinWithValueInfo); diff --git a/mindspore/ccsrc/parallel/ops_info/activation_info.cc b/mindspore/ccsrc/parallel/ops_info/activation_info.cc index c11db56082..e659759de2 100644 --- a/mindspore/ccsrc/parallel/ops_info/activation_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/activation_info.cc @@ -195,8 +195,8 @@ Status Softmax::GetAttrs() { // for example: tensor dimension is 4, then axis range [-4, 3] int32_t dim = SizeToInt(inputs_shape_.at(0).size()); - auto it = std::find_if(axis_.begin(), axis_.end(), - [dim](const int32_t& element) { return ((element >= dim) || (element < -dim)); }); + auto it = + std::find_if(axis_.begin(), axis_.end(), [dim](int32_t element) { return ((element >= dim) || (element < -dim)); }); if (it != axis_.end()) { MS_LOG(ERROR) << name_ << " : The axis(" << *it << ") is out of range[" << -dim << ", " << dim - 1 << "]."; return FAILED; diff --git a/mindspore/ccsrc/parallel/ops_info/layer_norm_info.cc b/mindspore/ccsrc/parallel/ops_info/layer_norm_info.cc new file mode 100644 index 0000000000..3abfc3d2ed --- /dev/null +++ b/mindspore/ccsrc/parallel/ops_info/layer_norm_info.cc @@ -0,0 +1,324 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "parallel/ops_info/layer_norm_info.h" +#include +#include +#include "parallel/device_matrix.h" +#include "parallel/strategy.h" + +namespace mindspore { +namespace parallel { +Status LayerNormInfo::GetAttrs() { + auto iter = attrs_.find(BEGIN_NORM_AXIS); + if (iter == attrs_.end()) { + MS_LOG(ERROR) << name_ << ": Can not find the attr of begin norm axis"; + return FAILED; + } + if ((iter->second == nullptr) || !iter->second->isa()) { + MS_LOG(ERROR) << name_ << ": The axis type is not int"; + return FAILED; + } + + int32_t dim = SizeToInt(input_shape_.size()); + auto axis = GetValue(iter->second); + if ((axis >= dim) || (axis < -dim)) { + MS_LOG(ERROR) << name_ << ": The axis(" << axis << ") is out of range[" << -dim << ", " << dim - 1 << "]"; + return FAILED; + } + + if (axis < 0) { + axis = axis + dim; + } + begin_norm_axis_ = IntToSize(axis); + return SUCCESS; +} + +Status LayerNormInfo::CheckStrategy(const StrategyPtr &strategy) { + MS_EXCEPTION_IF_NULL(strategy); + std::vector stra = strategy->GetInputDim(); + if (stra.size() != LAYER_NORM_INPUT_SIZE) { + MS_LOG(ERROR) << name_ << ": Invalid strategy size " << stra.size(); + return FAILED; + } + + if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Invalid strategy value"; + return FAILED; + } + + Dimensions input_strategy = stra[LAYER_NORM_INPUT_INDEX]; + Dimensions gamma_strategy = stra[LAYER_NORM_GAMMA_INDEX]; + Dimensions beta_strategy = stra[LAYER_NORM_BETA_INDEX]; + if (begin_norm_axis_ >= input_strategy.size()) { + MS_LOG(ERROR) << name_ << ": Invalid begin norm axis " << begin_norm_axis_; + return FAILED; + } + // check input strategy + for (size_t i = begin_norm_axis_; i < input_strategy.size(); ++i) { + if (input_strategy[begin_norm_axis_] != NO_SPLIT_STRATEGY) { + MS_LOG(ERROR) << name_ << ": Invalid input strategy " << ShapeToString(input_strategy); + return FAILED; + } + } + + // check gamma and beta strategy + if ((gamma_strategy.size() > input_strategy.size()) || (beta_strategy.size() > input_strategy.size())) { + MS_LOG(ERROR) << name_ << " : The strategy size of gamma or beta is lager than input strategy"; + return FAILED; + } + + size_t gamma_diff = input_strategy.size() - gamma_strategy.size(); + for (size_t j = 0; j < gamma_strategy.size(); ++j) { + if (gamma_strategy[j] != input_strategy[gamma_diff + j]) { + MS_LOG(ERROR) << name_ << ": Invalid gamma strategy " << ShapeToString(gamma_strategy); + return FAILED; + } + } + + size_t beta_diff = input_strategy.size() - beta_strategy.size(); + for (size_t k = 0; k < beta_strategy.size(); ++k) { + if (beta_strategy[k] != input_strategy[beta_diff + k]) { + MS_LOG(ERROR) << name_ << ": Invalid beta strategy " << ShapeToString(beta_strategy); + return FAILED; + } + } + return SUCCESS; +} + +Status LayerNormInfo::InferDevMatrixShape() { + if (strategy_ == nullptr) { + MS_LOG(ERROR) << name_ << ": The strategy is null"; + return FAILED; + } + std::vector stra = strategy_->GetInputDim(); + if (stra.empty()) { + MS_LOG(ERROR) << name_ << ": The strategy is empty"; + return FAILED; + } + dev_matrix_shape_ = stra[0]; + return SUCCESS; +} + +Status LayerNormInfo::CreateTensorMap(size_t input_index) { + if (inputs_shape_.size() <= input_index) { + MS_LOG(ERROR) << name_ << ": Invalid index" << input_index; + return FAILED; + } + Shape shape = inputs_shape_[input_index]; + Shape tensor_map; + for (size_t i = 0; i < shape.size(); ++i) { + tensor_map.push_back(SizeToInt(shape.size() - i - 1)); + } + inputs_tensor_map_.push_back(tensor_map); + outputs_tensor_map_.push_back(tensor_map); + return SUCCESS; +} + +Status LayerNormInfo::InferTensorMap() { + if ((CreateTensorMap(LAYER_NORM_INPUT_INDEX) != SUCCESS) || (CreateTensorMap(LAYER_NORM_GAMMA_INDEX) != SUCCESS) || + (CreateTensorMap(LAYER_NORM_BETA_INDEX) != SUCCESS)) { + MS_LOG(ERROR) << name_ << ": Create tensor map failed"; + return FAILED; + } + return SUCCESS; +} + +Status LayerNormInfo::CreateMirrorOp(size_t input_index) { + if (inputs_tensor_map_.size() <= input_index) { + MS_LOG(ERROR) << name_ << ": Invalid index " << input_index; + return FAILED; + } + Shape tensor_map = inputs_tensor_map_[input_index]; + std::vector group; + if (CreateGroupByTensorMap(tensor_map, &group) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Create group for input " << input_index << " failed"; + return FAILED; + } + OperatorVector mirror_op; + if (!group.empty()) { + mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum()); + MS_LOG(INFO) << name_ << " : Create the mirror ops for input " << input_index << " success, group is " + << group[0].name(); + } + mirror_ops_.push_back(mirror_op); + return SUCCESS; +} + +Status LayerNormInfo::InferMirrorOps() { + if ((CreateMirrorOp(LAYER_NORM_INPUT_INDEX) != SUCCESS) || (CreateMirrorOp(LAYER_NORM_GAMMA_INDEX) != SUCCESS) || + (CreateMirrorOp(LAYER_NORM_BETA_INDEX) != SUCCESS)) { + MS_LOG(ERROR) << name_ << ": Create mirror op failed"; + return FAILED; + } + return SUCCESS; +} + +Status LayerNormInfo::CreateTensorInfo(size_t input_index) { + if ((inputs_shape_.size() <= input_index) || (inputs_tensor_map_.size() <= input_index)) { + MS_LOG(ERROR) << name_ << ": Invalid input index" << input_index; + return FAILED; + } + Shape tensor_map = inputs_tensor_map_[input_index]; + Shape shape = inputs_shape_[input_index]; + TensorLayout tensor_layout; + if (tensor_layout.InitFromVector(dev_matrix_shape_, tensor_map, shape) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Init tensor layout for input " << input_index << " failed"; + return FAILED; + } + + TensorInfo tensor_info(tensor_layout); + inputs_tensor_info_.push_back(tensor_info); + outputs_tensor_info_.push_back(tensor_info); + return SUCCESS; +} + +Status LayerNormInfo::InferTensorInfo() { + if ((CreateTensorInfo(LAYER_NORM_INPUT_INDEX) != SUCCESS) || (CreateTensorInfo(LAYER_NORM_GAMMA_INDEX) != SUCCESS) || + (CreateTensorInfo(LAYER_NORM_BETA_INDEX) != SUCCESS)) { + MS_LOG(ERROR) << name_ << ": Create tensor info failed"; + return FAILED; + } + return SUCCESS; +} + +Status LayerNormInfo::InferAsLossDivisor() { + if (outputs_tensor_map_.size() != LAYER_NORM_INPUT_SIZE) { + MS_LOG(ERROR) << name_ << ": The size of outputs tensor map " << outputs_tensor_map_.size() << " is error"; + return FAILED; + } + as_loss_divisor_ = ComputeRepeatDeviceNumByTensorMap(dev_matrix_shape_, outputs_tensor_map_[0]); + MS_LOG(INFO) << name_ << " : The dev matrix shape is " << ShapeToString(dev_matrix_shape_) + << ", the output[0]'s tensor map is " << ShapeToString(outputs_tensor_map_[0]) + << ", as_loss_divisor_ is " << as_loss_divisor_; + return SUCCESS; +} + +Status LayerNormInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { + if (SetCostUnderStrategyBase(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Set cost failed"; + return FAILED; + } + return SUCCESS; +} + +Status LayerNormInfo::GenerateGammaAndBetaStrategies(const std::vector &sp_vector) { + if ((gamma_shape_.size() > input_shape_.size()) || (beta_shape_.size() > input_shape_.size())) { + MS_LOG(ERROR) << name_ << ": The dimension of gamma or beta is lager than input"; + return FAILED; + } + + size_t gamma_diff = input_shape_.size() - gamma_shape_.size(); + size_t beta_diff = input_shape_.size() - beta_shape_.size(); + for (auto &sp : sp_vector) { + if ((sp == nullptr) || sp->GetInputDim().empty()) { + MS_LOG(ERROR) << name_ << ": Invalid strategy"; + return FAILED; + } + std::vector tmp_strategy; + Dimensions input_strategy = sp->GetInputDim()[0]; + Dimensions gamma_strategy = input_strategy; + (void)gamma_strategy.erase(gamma_strategy.begin(), + gamma_strategy.begin() + static_cast(gamma_diff)); + Dimensions beta_strategy = input_strategy; + (void)beta_strategy.erase(beta_strategy.begin(), beta_strategy.begin() + static_cast(beta_diff)); + + // reset the strategy + tmp_strategy.push_back(input_strategy); + tmp_strategy.push_back(gamma_strategy); + tmp_strategy.push_back(beta_strategy); + sp->ResetInputs(tmp_strategy); + } + return SUCCESS; +} + +Status LayerNormInfo::GenerateStrategies(int32_t stage_id) { + if (InitShapes() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Init shapes failed"; + return FAILED; + } + if (GetAttrs() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Get attrs failed"; + return FAILED; + } + Shape input_split(input_shape_.size(), SPLIT_FLAG); + if (begin_norm_axis_ >= input_split.size()) { + MS_LOG(ERROR) << name_ << ": Invalid begin norm axis " << begin_norm_axis_; + return FAILED; + } + + // Can not split the dimensions from begin norm axis + for (size_t i = begin_norm_axis_; i < input_split.size(); ++i) { + input_split[i] = NO_SPLIT_FLAG; + } + + // Generate strategy for input + Shapes splittable_inputs = {input_split}; + Shapes tmp_inputs_shape = {input_shape_}; + std::vector sp_vector; + is_auto_parallel_ = true; + if (GenerateStrategiesForIndependentInputs(stage_id, tmp_inputs_shape, splittable_inputs, &sp_vector) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Generate input strategy failed"; + return FAILED; + } + + // Generate the strategies for gamma and beta + if (GenerateGammaAndBetaStrategies(sp_vector) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Generate gamma and beta strategies failed"; + return FAILED; + } + + size_t success = 0; + for (auto &sp : sp_vector) { + if (SetCostUnderStrategy(sp) == SUCCESS) { + success++; + MS_LOG(DEBUG) << name_ << ": Successfully generated " << success << " strategy"; + } + } + return SUCCESS; +} + +Status LayerNormInfo::InitShapes() { + if (inputs_shape_.size() != LAYER_NORM_INPUT_SIZE) { + MS_LOG(ERROR) << name_ << ": Invalid inputs size"; + return FAILED; + } + input_shape_ = inputs_shape_[LAYER_NORM_INPUT_INDEX]; + gamma_shape_ = inputs_shape_[LAYER_NORM_GAMMA_INDEX]; + beta_shape_ = inputs_shape_[LAYER_NORM_BETA_INDEX]; + return SUCCESS; +} + +Status LayerNormInfo::Init(const StrategyPtr &strategy) { + if ((InitShapes() != SUCCESS) || (InitWithAutoRepeatCalc(strategy)) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Init failed"; + return FAILED; + } + MS_LOG(INFO) << name_ << ": Init success"; + return SUCCESS; +} + +Status LayerNormInfo::InitForCostModel(const StrategyPtr &strategy) { + if ((InitShapes() != SUCCESS) || (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS)) { + MS_LOG(ERROR) << name_ << ": Init for cost model failed"; + return FAILED; + } + + MS_LOG(INFO) << name_ << ": Init for cost model success"; + return SUCCESS; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/layer_norm_info.h b/mindspore/ccsrc/parallel/ops_info/layer_norm_info.h new file mode 100644 index 0000000000..c52645ade2 --- /dev/null +++ b/mindspore/ccsrc/parallel/ops_info/layer_norm_info.h @@ -0,0 +1,76 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_LAYER_NORM_INFO_H_ +#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_LAYER_NORM_INFO_H_ + +#include +#include +#include +#include +#include "ir/value.h" +#include "parallel/auto_parallel/operator_costmodel.h" +#include "parallel/ops_info/operator_info.h" +#include "parallel/strategy.h" + +namespace mindspore { +namespace parallel { +constexpr size_t LAYER_NORM_INPUT_SIZE = 3; +constexpr size_t LAYER_NORM_INPUT_INDEX = 0; +constexpr size_t LAYER_NORM_GAMMA_INDEX = 1; +constexpr size_t LAYER_NORM_BETA_INDEX = 2; +constexpr char BEGIN_NORM_AXIS[] = "begin_norm_axis"; + +// The dimensions of input tensor starting from begin norm axis cannot be split. Other dimensions can be split +// arbitrarily. Gamma and beta should match input to meet the broadcast requirements of mul and add. +class LayerNormInfo : public OperatorInfo { + public: + LayerNormInfo(const std::string& operator_name, const Shapes& inputs_shape, const Shapes& outputs_shape, + const PrimitiveAttrs& attrs) + : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared(true)), + begin_norm_axis_(0) {} + ~LayerNormInfo() override = default; + + Status Init(const StrategyPtr& strategy) override; + Status InitForCostModel(const StrategyPtr& strategy) override; + Status GenerateStrategies(int32_t) override; + Status SetCostUnderStrategy(const StrategyPtr&) override; + + protected: + Status GetAttrs() override; + Status CheckStrategy(const StrategyPtr& strategy) override; + Status InferMirrorOps() override; + Status InferForwardCommunication() override { return SUCCESS; } + Status InferTensorInfo() override; + Status InferDevMatrixShape() override; + Status InferTensorMap() override; + Status InferAsLossDivisor() override; + Status CreateTensorMap(size_t input_index); + Status CreateTensorInfo(size_t input_index); + Status CreateMirrorOp(size_t input_index); + Status GenerateGammaAndBetaStrategies(const std::vector& sp_vector); + Status InitShapes(); + + private: + size_t begin_norm_axis_; + Shape input_shape_; + Shape gamma_shape_; + Shape beta_shape_; +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_LAYER_NORM_INFO_H_ diff --git a/mindspore/ccsrc/parallel/ops_info/ops_info_head_files.h b/mindspore/ccsrc/parallel/ops_info/ops_info_head_files.h index 27b434ecca..aec25f7f41 100644 --- a/mindspore/ccsrc/parallel/ops_info/ops_info_head_files.h +++ b/mindspore/ccsrc/parallel/ops_info/ops_info_head_files.h @@ -27,6 +27,7 @@ #include "parallel/ops_info/gather_v2_info.h" #include "parallel/ops_info/get_next_info.h" #include "parallel/ops_info/l2_normalize_info.h" +#include "parallel/ops_info/layer_norm_info.h" #include "parallel/ops_info/loss_info.h" #include "parallel/ops_info/matmul_info.h" #include "parallel/ops_info/onehot_info.h" diff --git a/mindspore/ccsrc/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/parallel/ops_info/ops_utils.h index 88377d237b..50920e5954 100644 --- a/mindspore/ccsrc/parallel/ops_info/ops_utils.h +++ b/mindspore/ccsrc/parallel/ops_info/ops_utils.h @@ -26,6 +26,8 @@ constexpr int32_t PRELU_CHANNEL_INDEX = 1; constexpr int32_t PRELU_CHANNEL_STRATEGY = 1; constexpr int32_t NO_SPLIT_MAP = -1; constexpr int32_t NO_SPLIT_STRATEGY = 1; +constexpr int32_t SPLIT_FLAG = 1; +constexpr int32_t NO_SPLIT_FLAG = 0; constexpr size_t MATMUL_ATTRS_SIZE = 2; constexpr size_t MATMUL_INPUTS_SIZE = 2; constexpr size_t MATMUL_OUTPUTS_SIZE = 1; @@ -173,6 +175,7 @@ constexpr char ARGMINWITHVALUE[] = "ArgMinWithValue"; constexpr char CONV2D[] = "Conv2D"; constexpr char FUSE_BATCH_NORM[] = "FusedBatchNorm"; constexpr char BATCH_NORM[] = "BatchNorm"; +constexpr char LAYER_NORM[] = "LayerNorm"; constexpr char POOLING[] = "Pooling"; constexpr char CAST[] = "Cast"; constexpr char MAX_POOL_WITH_ARGMAX[] = "MaxPoolWithArgmax"; diff --git a/mindspore/ccsrc/parallel/step_auto_parallel.cc b/mindspore/ccsrc/parallel/step_auto_parallel.cc index 33097bf2b7..5caf6573f2 100644 --- a/mindspore/ccsrc/parallel/step_auto_parallel.cc +++ b/mindspore/ccsrc/parallel/step_auto_parallel.cc @@ -82,6 +82,7 @@ std::vector splittable_op_ = {MATMUL, SIMPLE_MEAN, FLATTEN, BATCH_NORM, + LAYER_NORM, BIAS_ADD, ASSIGN_SUB, COS, diff --git a/tests/ut/cpp/parallel/tensor_layout/redistribution_layout_transfer_test.cc b/tests/ut/cpp/parallel/tensor_layout/redistribution_layout_transfer_test.cc index 4e34847582..5291e2f48d 100644 --- a/tests/ut/cpp/parallel/tensor_layout/redistribution_layout_transfer_test.cc +++ b/tests/ut/cpp/parallel/tensor_layout/redistribution_layout_transfer_test.cc @@ -245,8 +245,8 @@ void ValidRedistributionLayoutCheck(const DeviceArrangement& in_device_arrangeme unified_out_tensor_map, unified_tensor_shape); } -void ValidRedistributionLayoutCheckAll(const int32_t& device_pow_size, const int32_t& tensor_pow_size, - const int32_t& max_device_dim, const int32_t& max_shape_dim) { +void ValidRedistributionLayoutCheckAll(int32_t device_pow_size, int32_t tensor_pow_size, + int32_t max_device_dim, int32_t max_shape_dim) { std::vector> layout_list; GenerateValidLayoutByDeviceSizeAndTensorSize(device_pow_size, tensor_pow_size, max_device_dim, max_shape_dim, &layout_list); diff --git a/tests/ut/cpp/parallel/tensor_layout/reshape_layout_transfer_test.cc b/tests/ut/cpp/parallel/tensor_layout/reshape_layout_transfer_test.cc index 36b89684f6..9d6152721e 100644 --- a/tests/ut/cpp/parallel/tensor_layout/reshape_layout_transfer_test.cc +++ b/tests/ut/cpp/parallel/tensor_layout/reshape_layout_transfer_test.cc @@ -260,8 +260,8 @@ TEST_F(TestReshapeLayoutTransfer, ValidInferUnifiedLayoutCheck11) { ValidUnifiedLayoutCheck(device_arrangement, in_tensor_map, in_tensor_shape, out_tensor_map, out_tensor_shape); } -void ValidInferUnifiedLayoutCheckAll(const int32_t& device_pow_size, const int32_t& tensor_pow_size, - const int32_t& max_device_dim, const int32_t& max_shape_dim) { +void ValidInferUnifiedLayoutCheckAll(int32_t device_pow_size, int32_t tensor_pow_size, + int32_t max_device_dim, int32_t max_shape_dim) { std::vector> layout_list; GenerateValidLayoutByDeviceSizeAndTensorSize(device_pow_size, tensor_pow_size, max_device_dim, max_shape_dim, &layout_list); diff --git a/tests/ut/cpp/parallel/tensor_layout/util_layout_gen_test.cc b/tests/ut/cpp/parallel/tensor_layout/util_layout_gen_test.cc index 07d270c95c..93147c486b 100644 --- a/tests/ut/cpp/parallel/tensor_layout/util_layout_gen_test.cc +++ b/tests/ut/cpp/parallel/tensor_layout/util_layout_gen_test.cc @@ -51,7 +51,7 @@ std::vector> combine(const std::vector& in, int32_ return output; } -void GenerateValidShapeBySizeAndDim(const int32_t& pow_size, const int32_t& dim, +void GenerateValidShapeBySizeAndDim(int32_t pow_size, int32_t dim, std::vector>* out) { out->clear(); std::vector in; @@ -78,7 +78,7 @@ void GenerateValidShapeBySizeAndDim(const int32_t& pow_size, const int32_t& dim, return; } -void GenerateValidShapeBySize(const int32_t& pow_size, std::vector>* out) { +void GenerateValidShapeBySize(int32_t pow_size, std::vector>* out) { out->clear(); for (int32_t dim = 1; dim <= pow_size; dim++) { std::vector> combine_result; @@ -148,8 +148,8 @@ void GenerateValidTensorMap(const std::vector& device_arrangement, cons } void GenerateValidLayoutByDeviceSizeAndTensorSize( - const int32_t& device_pow_size, const int32_t& tensor_pow_size, const int32_t& max_device_dim, - const int32_t& max_shape_dim, + int32_t device_pow_size, int32_t tensor_pow_size, int32_t max_device_dim, + int32_t max_shape_dim, std::vector, std::vector, std::vector>>* layout_list) { layout_list->clear(); std::vector> device_arrangement_list; diff --git a/tests/ut/cpp/parallel/tensor_layout/util_layout_gen_test.h b/tests/ut/cpp/parallel/tensor_layout/util_layout_gen_test.h index e14556378f..a359cadbea 100644 --- a/tests/ut/cpp/parallel/tensor_layout/util_layout_gen_test.h +++ b/tests/ut/cpp/parallel/tensor_layout/util_layout_gen_test.h @@ -27,10 +27,10 @@ namespace parallel { std::vector> combine(const std::vector& in, int32_t target); -void GenerateValidShapeBySizeAndDim(const int32_t& pow_size, const int32_t& dim, +void GenerateValidShapeBySizeAndDim(int32_t pow_size, int32_t dim, std::vector>* out); -void GenerateValidShapeBySize(const int32_t& pow_size, std::vector>* out); +void GenerateValidShapeBySize(int32_t pow_size, std::vector>* out); std::vector GenerateTensorMap(const uint32_t& map_size, const std::vector& pos_index, const std::vector& pos_value); @@ -39,8 +39,8 @@ void GenerateValidTensorMap(const std::vector& device_arrangement, cons std::vector>* tensor_map_list); void GenerateValidLayoutByDeviceSizeAndTensorSize( - const int32_t& device_pow_size, const int32_t& tensor_pow_size, const int32_t& max_device_dim, - const int32_t& max_shape_dim, + int32_t device_pow_size, int32_t tensor_pow_size, int32_t max_device_dim, + int32_t max_shape_dim, std::vector, std::vector, std::vector>>* layout_list); uint32_t ComputeNoneNumber(const std::vector& tensor_map); diff --git a/tests/ut/python/parallel/test_layer_norm.py b/tests/ut/python/parallel/test_layer_norm.py new file mode 100644 index 0000000000..c65ee5fc8e --- /dev/null +++ b/tests/ut/python/parallel/test_layer_norm.py @@ -0,0 +1,96 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import mindspore as ms +from mindspore import context, Tensor, Parameter +from mindspore.nn import Cell, TrainOneStepCell, Momentum +from mindspore.ops import operations as P +from mindspore.common.api import _executor +from mindspore.common.initializer import initializer + + +class Net(Cell): + def __init__(self, mul_weight, strategy1=None, strategy2=None, strategy3=None): + super().__init__() + self.begin_norm_axis = -1 + self.begin_params_axis = 1 + self.mul = P.Mul().set_strategy(strategy1) + self.layer_norm = P.LayerNorm(self.begin_norm_axis, self.begin_params_axis).set_strategy(strategy2) + self.mul2 = P.Mul().set_strategy(strategy3) + self.mul_weight = Parameter(mul_weight, "w1") + self.normalized_shape = [64, 32, 16] + self.gamma = Parameter(initializer('ones', self.normalized_shape), name="gamma") + self.beta = Parameter(initializer('zeros', self.normalized_shape), name="beta") + + def construct(self, x, b): + out = self.mul(x, self.mul_weight) + out, _, _ = self.layer_norm(out, self.gamma, self.beta) + out = self.mul2(out, b) + return out + + +_x = Tensor(np.ones([128, 64, 32, 16]), dtype=ms.float32) +_w = Tensor(np.ones([128, 64, 32, 16]), dtype=ms.float32) +_b = Tensor(np.ones([128, 64, 32, 16]), dtype=ms.float32) + + +def compile(net): + optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) + train_net = TrainOneStepCell(net, optimizer) + _executor.compile(train_net, _x, _b) + context.reset_auto_parallel_context() + + +def test_layer_norm_data_parallel(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) + strategy1 = ((16, 1, 1, 1), (16, 1, 1, 1)) + strategy2 = ((16, 1, 1, 1), (1, 1, 1), (1, 1, 1)) + strategy3 = ((16, 1, 1, 1), (16, 1, 1, 1)) + net = Net(_w, strategy1, strategy2, strategy3) + compile(net) + + +def test_layer_norm_model_parallel(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) + strategy1 = ((1, 1, 16, 1), (1, 1, 16, 1)) + strategy2 = ((1, 1, 16, 1), (1, 16, 1), (1, 16, 1)) + strategy3 = ((1, 1, 16, 1), (1, 1, 16, 1)) + net = Net(_w, strategy1, strategy2, strategy3) + compile(net) + + +def test_layer_norm_hybrid_parallel(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) + strategy1 = ((2, 2, 4, 1), (2, 2, 4, 1)) + strategy2 = ((2, 2, 4, 1), (2, 4, 1), (2, 4, 1)) + strategy3 = ((2, 2, 4, 1), (2, 2, 4, 1)) + net = Net(_w, strategy1, strategy2, strategy3) + compile(net) + + +def test_layer_norm_auto_parallel(): + context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0) + net = Net(_w) + compile(net) + + +def test_layer_norm_repeat_calc(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) + strategy1 = ((2, 2, 4, 1), (2, 2, 4, 1)) + strategy2 = ((1, 2, 2, 1), (2, 2, 1), (2, 2, 1)) + strategy3 = ((2, 2, 4, 1), (2, 2, 4, 1)) + net = Net(_w, strategy1, strategy2, strategy3) + compile(net) +