From: @xiaoda_zh
Reviewed-by: @kisnwang, @stsuteng
Signed-off-by: @stsuteng
tags/v1.2.0-rc1
@@ -27,6 +27,7 @@ void OperatorCost::set_is_parameter(const std::vector<bool> &is_parameter) { is_
void OperatorCost::set_is_parameter_involve(const std::vector<bool> &is_parameter_inv) {
  is_parameter_involve_ = is_parameter_inv;
  is_inputs_should_in_memory_ = std::vector<bool>(is_parameter_involve_.size(), false);
}
void OperatorCost::set_output_parameter_involve(int64_t output_para) { output_parameter_involve_ = output_para; }
@@ -41,27 +42,28 @@ void OperatorCost::set_output_critical(int64_t critical) { is_outputs_critical_
double OperatorCost::GetMemoryCost(const std::vector<TensorInfo> &inputs,
                                   const std::vector<TensorInfo> &outputs) const {
  return GetInputMemoryCost(inputs, outputs) + GetOutputMemoryCost(inputs, outputs);
}
double OperatorCost::GetInputMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &) const {
  double result = 0.0;
  for (size_t i = 0; i < inputs.size(); ++i) {
    if (is_inputs_should_in_memory_[i]) {
      result += ListProduct(inputs[i].slice_shape()) * static_cast<double>(inputs_type_lengths_[i]);
    }
  }
  return result;
}
double OperatorCost::GetOutputMemoryCost(const std::vector<TensorInfo> &inputs,
                                         const std::vector<TensorInfo> &outputs) const {
  double result = 0.0;
  if (is_output_should_in_memory_) {
    // When this operator has multiple outputs, they all contribute to the memory.
    for (size_t i = 0; i < outputs.size(); ++i) {
      result += ListProduct(outputs[i].slice_shape()) * static_cast<double>(outputs_type_lengths_[i]);
    }
  }
  return result;
}
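For orientation, every term in the two helpers above reduces to one formula: bytes = (product of the tensor's slice-shape entries) x (per-element type length). A minimal self-contained sketch of that arithmetic; the Shape alias mirrors the header, ListProduct is re-stated here so the sketch compiles on its own, and the shape and type length are invented for illustration:

#include <cstdint>
#include <iostream>
#include <vector>

using Shape = std::vector<int64_t>;

// Product of all entries, matching how ListProduct is used above.
template <typename T>
double ListProduct(std::vector<T> vec) {
  double result = 1.0;
  for (const auto &v : vec) {
    result *= static_cast<double>(v);
  }
  return result;
}

int main() {
  Shape slice_shape = {32, 128};  // a [32, 128] slice held on one device
  size_t type_length = 4;         // e.g. 4 bytes per float32 element
  double bytes = ListProduct(slice_shape) * static_cast<double>(type_length);
  std::cout << bytes << std::endl;  // 32 * 128 * 4 = 16384 bytes
  return 0;
}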
@@ -166,16 +168,43 @@ double MatMulCost::GetBackwardComputationCost(const std::vector<TensorInfo> &inp
  return result;
}
// Not taking account of output
void MatMulCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
// Taking account of input
void MatMulCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  if (is_parameter_[0]) {
    is_inputs_should_in_memory_[0] = true;
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  } else if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }
  if (is_parameter_[1]) {
    is_inputs_should_in_memory_[1] = true;
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
  } else if (is_parameter_involve_[1]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
  }
}
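Every CalculateInputsInMemory override in this patch repeats one test: input i must stay in this operator's memory unless the previous operator already keeps the corresponding tensor in memory. A hypothetical helper (IsNotKeptByPrev is not part of the patch) that names the pattern:

#include <cstddef>
#include <map>

// Hypothetical helper, equivalent to the repeated condition
// (prev_output_in_mem.find(i) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(i)):
// true when the predecessor does not keep tensor 'i' in memory.
bool IsNotKeptByPrev(const std::map<size_t, bool> &prev_output_in_mem, size_t i) {
  auto it = prev_output_in_mem.find(i);
  return (it == prev_output_in_mem.end()) || (!it->second);
}

With it, each guarded pair of lines above would reduce to a single assignment such as is_inputs_should_in_memory_[1] = IsNotKeptByPrev(prev_output_in_mem, 1).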
// Return the per device communication cost in the forward phase.
double CastCost::GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const {
  // Cast is an element-wise operator, thus it does not need communication in the forward phase
  return 0.0;
}
// Return the per device communication cost in the backward phase.
double CastCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                     int64_t stage_id) const {
  double result = 0.0;
  if (is_parameter_[0]) {
    TensorInfo input1 = inputs[0];
@@ -196,8 +225,8 @@ double ActivationCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs
// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double CastCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                           int64_t) const {
  TensorInfo input0 = inputs[0];
  Shape input0_slice_shape = input0.slice_shape();
  return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
@@ -205,11 +234,33 @@ double ActivationCost::GetForwardComputationCost(const std::vector<TensorInfo> &
// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
// this operator uses
double CastCost::GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
                                            int64_t) const {
  return 0.0;
}
// Not taking account of output
void CastCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
// Not taking account of input
void CastCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  is_inputs_should_in_memory_[0] = is_parameter_[0];
}
// Taking account of output
void SqrtCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; }
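A note on why Sqrt keeps its output rather than its input: for z = sqrt(x), dz/dx = 1/(2 * sqrt(x)) = 1/(2z), so the backward pass can be computed from the forward result alone. That is presumably why the output, not 'x', must stay in memory whenever the input is parameter-involved.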
// Taking account of input
void GeLUCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  if (is_parameter_[0]) {
    is_inputs_should_in_memory_[0] = true;
  } else if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
  }
}
// Return the per device communication cost in the forward phase.
double SoftmaxCost::GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
                                       int64_t) const {
@@ -259,6 +310,81 @@ double SoftmaxCost::GetBackwardComputationCost(const std::vector<mindspore::para
  return 0.0;
}
// Taking account of output
void SoftmaxCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; }
// Not taking account of input
void SoftmaxCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  is_inputs_should_in_memory_[0] = is_parameter_[0];
}
// Not taking account of output
void PackCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
// Not taking account of input
void PackCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  is_inputs_should_in_memory_[0] = is_parameter_[0];
}
// Not taking account of output
void TileCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
// Taking account of input
void TileCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of 'y'
  if (is_parameter_[0]) {
    is_inputs_should_in_memory_[0] = true;
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  } else if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }
  if (!is_inputs_should_in_memory_[1]) {
    is_inputs_should_in_memory_[1] = is_parameter_[1];
  }
}
// Not taking account of output
void BroadcastToCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
void BroadcastToCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  is_inputs_should_in_memory_[0] = is_parameter_[0];
}
// Taking account of input
void ReLU6Cost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  if (is_parameter_[0]) {
    is_inputs_should_in_memory_[0] = true;
  } else if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
  }
}
// Taking account of input
void TransposeCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of 'y'
  if (is_parameter_[0]) {
    is_inputs_should_in_memory_[0] = true;
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  } else if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }
  if (!is_inputs_should_in_memory_[1]) {
    is_inputs_should_in_memory_[1] = is_parameter_[1];
  }
}
// return the per device communication cost in the forward phase.
double TmpIdentityCost::GetForwardCommCost(const std::vector<mindspore::parallel::TensorInfo> &,
                                           const std::vector<mindspore::parallel::TensorInfo> &, int64_t) const {
@@ -288,9 +414,12 @@ double TmpIdentityCost::GetBackwardComputationCost(const std::vector<mindspore::
  return 0.0;
}
// Not taking account of output
void TmpIdentityCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
// Not taking account of input
void TmpIdentityCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  is_inputs_should_in_memory_[0] = is_parameter_[0];
}
double BatchParallelCost::GetForwardComputationCost(const std::vector<mindspore::parallel::TensorInfo> &inputs,
@@ -334,6 +463,42 @@ double BatchParallelCost::GetBackwardCommCost(const std::vector<TensorInfo> &inp
  return result;
}
void BatchParallelCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
void BatchParallelCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  if (is_parameter_[0]) {
    is_inputs_should_in_memory_[0] = true;
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  } else if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }
  if (is_parameter_[1]) {
    is_inputs_should_in_memory_[1] = true;
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
  } else if (is_parameter_involve_[1]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
  }
}
void SparseSoftmaxCrossEntropyWithLogitsCost::CalculateOutputInMemory() {
  is_output_should_in_memory_ = is_parameter_involve_[0];
}
void SparseSoftmaxCrossEntropyWithLogitsCost::CalculateInputsInMemory(
  const std::map<size_t, bool> &prev_output_in_mem) {
  is_inputs_should_in_memory_[0] = is_parameter_[0];
  is_inputs_should_in_memory_[1] = is_parameter_[1];
}
// return the per device communication cost in the forward phase.
double PReLUCost::GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const {
  // prelu does not need communication in the forward phase
@@ -401,6 +566,21 @@ double PReLUCost::GetBackwardComputationCost(const std::vector<mindspore::parall
  return result;
}
void PReLUCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
void PReLUCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of both 'x' and 'y';
  // when calculating 'dy', taking account of both 'x' and 'y'
  if (is_parameter_involve_[0] || is_parameter_involve_[1]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }
}
// return the per device communication cost in the forward phase.
double OneHotCost::GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const {
  // onehot does not need communication in the forward phase
@@ -430,6 +610,17 @@ double OneHotCost::GetBackwardComputationCost(const std::vector<TensorInfo> &, c
  return 0.0;
}
// Not taking account of output
void OneHotCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
// Not taking account of input
void OneHotCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  is_inputs_should_in_memory_[0] = is_parameter_[0];
  is_inputs_should_in_memory_[1] = is_parameter_[1];
  is_inputs_should_in_memory_[2] = is_parameter_[2];
  is_inputs_should_in_memory_[3] = is_parameter_[3];
}
// return the per device communication cost in the forward phase.
double SoftmaxCrossEntropyWithLogitsCost::GetForwardCommCost(const std::vector<TensorInfo> &,
                                                             const std::vector<TensorInfo> &, int64_t) const {
@@ -463,6 +654,16 @@ double SoftmaxCrossEntropyWithLogitsCost::GetBackwardComputationCost(const std::
  return 0.0;
}
// Taking account of output
void SoftmaxCrossEntropyWithLogitsCost::CalculateOutputInMemory() {
  is_output_should_in_memory_ = is_parameter_involve_[0];
}
void SoftmaxCrossEntropyWithLogitsCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  is_inputs_should_in_memory_[0] = is_parameter_[0];
  is_inputs_should_in_memory_[1] = is_parameter_[1];
}
// return the per device communication cost in the forward phase.
double ReshapeCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                       int64_t stage_id) const {
@@ -524,16 +725,23 @@ double ReshapeCost::GetBackwardComputationCost(const std::vector<mindspore::para
  return 0.0;
}
void ReshapeCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
void ReshapeCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  is_inputs_should_in_memory_[0] = is_parameter_[0];
  is_inputs_should_in_memory_[1] = is_parameter_[1];
}
double SubCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                          int64_t) const {
  double result;
  result = ListProduct(inputs[0].slice_shape()) * static_cast<double>(inputs_type_lengths_[0]) +
           ListProduct(inputs[1].slice_shape()) * static_cast<double>(inputs_type_lengths_[1]);
  return result;
}
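As a concrete reading of this cost model: with both inputs sliced to [64, 128] in float32 (4-byte elements), the forward computation cost evaluates to 64 * 128 * 4 + 64 * 128 * 4 = 65536, i.e. the byte sizes of both operands. The model counts bytes touched, not FLOPs.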
double SubCost::GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                           int64_t stage_id) const {
  double result = 0.0;
  CheckGlobalDeviceManager();
  MS_EXCEPTION_IF_NULL(g_device_manager);
@@ -567,8 +775,8 @@ double ArithmeticCost::GetBackwardComputationCost(const std::vector<TensorInfo>
  return result;
}
double SubCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                    int64_t stage_id) const {
  double result = 0.0;
  CheckGlobalDeviceManager();
  MS_EXCEPTION_IF_NULL(g_device_manager);
@@ -603,6 +811,273 @@ double ArithmeticCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs
  return result;
}
// Not taking account of output
void SubCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
// Not taking account of input
void SubCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  is_inputs_should_in_memory_[0] = is_parameter_[0];
  is_inputs_should_in_memory_[1] = is_parameter_[1];
}
// Taking account of input
void MulCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  if (is_parameter_[0]) {
    // 'x' is a parameter, so it should be in memory.
    is_inputs_should_in_memory_[0] = true;
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      // In this case, 'y' is not kept in memory by the previous operator, so it must be counted here.
      is_inputs_should_in_memory_[1] = true;
    }
  } else if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }
  if (is_parameter_[1]) {
    is_inputs_should_in_memory_[1] = true;
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
  } else if (is_parameter_involve_[1]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
  }
}
// Taking account of output
void DivCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[1]; }
// Taking account of input
void DivCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of 'y'
  if (is_parameter_[0]) {
    // 'x' is a parameter, so it should be in memory.
    is_inputs_should_in_memory_[0] = true;
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      // In this case, 'y' is not kept in memory by the previous operator, so it must be counted here.
      is_inputs_should_in_memory_[1] = true;
    }
  } else if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }
  // When calculating 'dy', taking account of 'y'
  if (is_parameter_involve_[1]) {
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }
}
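The asymmetry in these two branches follows from the quotient's derivatives: for z = x / y, dz/dx = 1/y needs only 'y', while dz/dy = -x/y^2 = -z/y needs 'y' plus the forward output z. That is consistent with DivCost::CalculateOutputInMemory above keeping the output exactly when input 1 is parameter-involved, and with 'x' itself never being required.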
// Taking account of input
void ModCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', not taking account of 'x' and 'y'
  is_inputs_should_in_memory_[0] = is_parameter_[0];
  // When calculating 'dy', taking account of 'x' and 'y'
  if (is_parameter_involve_[1]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }
}
void PowCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[1]; }
void PowCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of both 'x' and 'power'
  if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }
  // When calculating 'dpower', taking account of 'x'
  if (is_parameter_[1]) {
    is_inputs_should_in_memory_[1] = true;
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
  } else if (is_parameter_involve_[1]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
  }
}
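Likewise for Pow: with z = x^power, dz/dx = power * x^(power-1) uses both 'x' and 'power', while dz/dpower = x^power * ln(x) = z * ln(x) uses 'x' together with the forward output z, which matches the is_parameter_involve_[1] condition in PowCost::CalculateOutputInMemory above.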
void AssignCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of 'x'
  if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
  }
  // When calculating 'dy', not taking account of 'x' and 'y'
  is_inputs_should_in_memory_[1] = is_parameter_[1];
}
void SigmoidCrossEntropyWithLogitsCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of both 'x' and 'y'
  if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }
  // When calculating 'dy', not taking account of 'x' and 'y'
  if (!is_inputs_should_in_memory_[1]) {
    is_inputs_should_in_memory_[1] = is_parameter_[1];
  }
}
void Atan2Cost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of both 'x' and 'y'; when calculating 'dy', taking account of both 'x' and
  // 'y'
  if (is_parameter_involve_[0] || is_parameter_involve_[1]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }
}
void DivNoNanCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[1]; }
void DivNoNanCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of 'y'
  if (is_parameter_[0]) {
    // 'x' is a parameter, so it should be in memory.
    is_inputs_should_in_memory_[0] = true;
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      // In this case, 'y' is not kept in memory by the previous operator, so it must be counted here.
      is_inputs_should_in_memory_[1] = true;
    }
  } else if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }
  // When calculating 'dy', taking account of 'y'
  if (is_parameter_involve_[1]) {
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }
}
void MaximumCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of both 'x' and 'y';
  // when calculating 'dy', taking account of both 'x' and 'y'
  if (is_parameter_involve_[0] || is_parameter_involve_[1]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }
}
void SliceCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of 'y' and 'z'
  if (is_parameter_[0]) {
    is_inputs_should_in_memory_[0] = true;
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
    if ((prev_output_in_mem.find(2) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(2))) {
      is_inputs_should_in_memory_[2] = true;
    }
  } else if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
    if ((prev_output_in_mem.find(2) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(2))) {
      is_inputs_should_in_memory_[2] = true;
    }
  }
  if (!is_inputs_should_in_memory_[1]) {
    is_inputs_should_in_memory_[1] = is_parameter_[1];
  }
  if (!is_inputs_should_in_memory_[2]) {
    is_inputs_should_in_memory_[2] = is_parameter_[2];
  }
}
void StridedSliceCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of 'y', 'z' and 'w'
  if (is_parameter_[0]) {
    is_inputs_should_in_memory_[0] = true;
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
    if ((prev_output_in_mem.find(2) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(2))) {
      is_inputs_should_in_memory_[2] = true;
    }
    if ((prev_output_in_mem.find(3) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(3))) {
      is_inputs_should_in_memory_[3] = true;
    }
  } else if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
    if ((prev_output_in_mem.find(2) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(2))) {
      is_inputs_should_in_memory_[2] = true;
    }
    if ((prev_output_in_mem.find(3) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(3))) {
      is_inputs_should_in_memory_[3] = true;
    }
  }
  if (!is_inputs_should_in_memory_[1]) {
    is_inputs_should_in_memory_[1] = is_parameter_[1];
  }
  if (!is_inputs_should_in_memory_[2]) {
    is_inputs_should_in_memory_[2] = is_parameter_[2];
  }
  if (!is_inputs_should_in_memory_[3]) {
    is_inputs_should_in_memory_[3] = is_parameter_[3];
  }
}
void DropOutDoMaskCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
void DropOutDoMaskCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of 'y'
  if (is_parameter_[0]) {
    // 'x' is a parameter, so it should be in memory.
    is_inputs_should_in_memory_[0] = true;
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      // In this case, 'y' is not kept in memory by the previous operator, so it must be counted here.
      is_inputs_should_in_memory_[1] = true;
    }
  } else if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }
  if (!is_inputs_should_in_memory_[1]) {
    is_inputs_should_in_memory_[1] = is_parameter_[1];
  }
  is_inputs_should_in_memory_[2] = is_parameter_[2];
}
bool IsDataParallel(const Shape &shape, const Shape &slice_shape, int64_t stage_id) {
  CheckGlobalDeviceManager();
  MS_EXCEPTION_IF_NULL(g_device_manager);
@@ -612,8 +1087,8 @@ bool IsDataParallel(const Shape &shape, const Shape &slice_shape, int64_t stage_
  return (total_device_num == LongToSize(strategy0));
}
double ReduceSumCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                         int64_t stage_id) const {
  double result = 0.0;
  TensorInfo input0 = inputs[0];
  TensorInfo output0 = outputs[0];
@@ -634,8 +1109,8 @@ double ReduceMethodCost::GetForwardCommCost(const std::vector<TensorInfo> &input
  return result;
}
double ReduceSumCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                          int64_t stage_id) const {
  double result = 0.0;
  if (is_parameter_[0]) {
    TensorInfo input_tensor_info = inputs[0];
@@ -657,8 +1132,8 @@ double ReduceMethodCost::GetBackwardCommCost(const std::vector<TensorInfo> &inpu
  return result;
}
double ReduceSumCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
                                                const std::vector<TensorInfo> &outputs, int64_t stage_id) const {
  double result = 0.0;
  TensorInfo input0 = inputs[0];
  TensorInfo output0 = outputs[0];
@@ -679,6 +1154,30 @@ double ReduceMethodCost::GetForwardComputationCost(const std::vector<TensorInfo>
  return result;
}
// Not taking account of output
void ReduceSumCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
void ReduceSumCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of 'y'
  if (is_parameter_[0]) {
    // 'x' is a parameter, so it should be in memory.
    is_inputs_should_in_memory_[0] = true;
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      // In this case, 'y' is not kept in memory by the previous operator, so it must be counted here.
      is_inputs_should_in_memory_[1] = true;
    }
  } else if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }
  // Not taking account of 'y'
  if (!is_inputs_should_in_memory_[1]) {
    is_inputs_should_in_memory_[1] = is_parameter_[1];
  }
}
double ReduceMeanCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
                                                 const std::vector<TensorInfo> &outputs, int64_t stage_id) const {
  double result = 0.0;
@@ -701,6 +1200,42 @@ double ReduceMeanCost::GetForwardComputationCost(const std::vector<TensorInfo> &
  return result;
}
void ReduceMinCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; }
void ReduceMinCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of 'y'
  if (is_parameter_[0]) {
    // 'x' is a parameter, so it should be in memory.
    is_inputs_should_in_memory_[0] = true;
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      // In this case, 'y' is not kept in memory by the previous operator, so it must be counted here.
      is_inputs_should_in_memory_[1] = true;
    }
  } else if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }
  // Not taking account of 'y'
  if (!is_inputs_should_in_memory_[1]) {
    is_inputs_should_in_memory_[1] = is_parameter_[1];
  }
}
void ArgMaxWithValueCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; }
void ArgMaxWithValueCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of 'x'
  if (is_parameter_[0]) {
    is_inputs_should_in_memory_[0] = true;
  } else if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
  }
}
double DropOutCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                              int64_t) const {
  if (inputs.empty()) {
@@ -760,6 +1295,52 @@ double GatherV2Cost::GetBackwardComputationCost(const std::vector<TensorInfo> &,
  return 0.0;
}
// Not taking account of output
void GatherV2Cost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
void GatherV2Cost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of 'y' and 'z'
  if (is_parameter_[0]) {
    // 'x' is a parameter, so it should be in memory.
    is_inputs_should_in_memory_[0] = true;
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
    if ((prev_output_in_mem.find(2) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(2))) {
      is_inputs_should_in_memory_[2] = true;
    }
  } else if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
    if ((prev_output_in_mem.find(2) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(2))) {
      is_inputs_should_in_memory_[2] = true;
    }
  }
  if (!is_inputs_should_in_memory_[1]) {
    is_inputs_should_in_memory_[1] = is_parameter_[1];
  }
  if (!is_inputs_should_in_memory_[2]) {
    is_inputs_should_in_memory_[2] = is_parameter_[2];
  }
}
void GetNextCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
void GetNextCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  if (is_inputs_should_in_memory_.size() == 0) {
    return;
  }
  is_inputs_should_in_memory_[0] = is_parameter_[0];
}
void UniqueCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; }
void UniqueCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  is_inputs_should_in_memory_[0] = is_parameter_[0];
}
double LayerNormCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                          int64_t stage_id) const {
  double result = 0.0;
@@ -808,6 +1389,24 @@ double LayerNormCost::GetForwardComputationCost(const std::vector<TensorInfo> &i
  return result;
}
void LayerNormCost::CalculateOutputInMemory() {
  is_output_should_in_memory_ = is_parameter_involve_[0] || is_parameter_involve_[1] || is_parameter_involve_[2];
}
void LayerNormCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of both 'x' and 'y'
  // When calculating 'dy', taking account of both 'x' and 'y'
  if (is_parameter_involve_[0] || is_parameter_involve_[1]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }
  is_inputs_should_in_memory_[2] = is_parameter_[2];
}
double UniqueCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                      int64_t stage_id) const {
  return 0.0;
@@ -924,6 +1523,12 @@ double UniformCandidateSamplerCost::GetForwardComputationCost(const std::vector<
  return result;
}
void UniformCandidateSamplerCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
void UniformCandidateSamplerCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  is_inputs_should_in_memory_[0] = is_parameter_[0];
}
double GatherV2PCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
                                                const std::vector<TensorInfo> &outputs, int64_t stage_id) const {
  double result = 0.0;
@@ -1019,6 +1624,29 @@ double UnsortedSegmentSumCost::GetForwardComputationCost(const std::vector<Tenso
  return result;
}
// Not taking account of output
void UnsortedSegmentSumCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
// Taking account of input
void UnsortedSegmentSumCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of 'y'
  if (is_parameter_[0]) {
    is_inputs_should_in_memory_[0] = true;
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  } else if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
  }
  if (!is_inputs_should_in_memory_[1]) {
    is_inputs_should_in_memory_[1] = is_parameter_[1];
  }
  is_inputs_should_in_memory_[2] = is_parameter_[2];
}
double UnsortedSegmentMinCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs,
                                                  const std::vector<TensorInfo> &outputs, int64_t stage_id) const {
  TensorInfo input0 = inputs[0];
@@ -1078,5 +1706,40 @@ double UnsortedSegmentMinCost::GetForwardComputationCost(const std::vector<Tenso
           ListProduct(output_slice_shape) * static_cast<double>(outputs_type_lengths_[0]);  // ReduceMin
  return result;
}
// Taking account of output
void UnsortedSegmentMinCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; }
// Taking account of input
void UnsortedSegmentMinCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // When calculating 'dx', taking account of 'x', 'y' and 'z'
  if (is_parameter_involve_[0]) {
    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
      is_inputs_should_in_memory_[0] = true;
    }
    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
      is_inputs_should_in_memory_[1] = true;
    }
    if ((prev_output_in_mem.find(2) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(2))) {
      is_inputs_should_in_memory_[2] = true;
    }
  }
  if (!is_inputs_should_in_memory_[1]) {
    is_inputs_should_in_memory_[1] = is_parameter_[1];
  }
  if (!is_inputs_should_in_memory_[2]) {
    is_inputs_should_in_memory_[2] = is_parameter_[2];
  }
}
// Not taking account of output
void VirtualDatasetCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
// Not taking account of input
void VirtualDatasetCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  for (size_t i = 0; i < is_inputs_should_in_memory_.size(); ++i) {
    is_inputs_should_in_memory_[i] = is_parameter_[i];
  }
}
}  // namespace parallel
}  // namespace mindspore
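Taken together, the new hooks appear to be meant for this calling sequence: configure the parameter flags, let the operator decide output residency, feed the predecessors' residency into CalculateInputsInMemory, then query GetMemoryCost. A compressed sketch under that assumption; CastCost is just a convenient concrete class, and every configuration value below is invented:

#include <map>
#include <memory>
#include <vector>
#include "operator_costmodel.h"  // the header diffed below; path abbreviated

using mindspore::parallel::CastCost;
using mindspore::parallel::TensorInfo;

double SketchPeakMemory(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs) {
  auto cost = std::make_shared<CastCost>();
  cost->set_is_parameter({true});             // assume the single input is a parameter
  cost->set_is_parameter_involve({true});     // ...and therefore parameter-involved
  cost->CalculateOutputInMemory();            // decide output residency first
  std::map<size_t, bool> prev_output_in_mem;  // assume predecessors keep nothing
  cost->CalculateInputsInMemory(prev_output_in_mem);
  return cost->GetMemoryCost(inputs, outputs);  // peak bytes per device
}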
@@ -19,6 +19,7 @@
#include <memory>
#include <vector>
#include <map>
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/tensor_layout/tensor_info.h"
@@ -47,16 +48,7 @@ double ListProduct(std::vector<T> vec) {
// entries times the length of each entry's data type
class OperatorCost {
 public:
  OperatorCost() {
    // this is only for the case when set_is_parameter() and SetInputAndOutputTypeLength() are not invoked
    for (size_t i = 0; i < MAXIMUM_INPUT_NUMBER; ++i) {
      is_parameter_.push_back(false);
@@ -89,10 +81,17 @@ class OperatorCost {
                                            const std::vector<TensorInfo> &outputs, int64_t stage_id) const = 0;
  virtual double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs,
                                            const std::vector<TensorInfo> &outputs, int64_t stage_id) const = 0;
  virtual void CalculateOutputInMemory() = 0;
  virtual void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) = 0;
  bool is_output_in_memory() const { return is_output_should_in_memory_; }
  // per device PEAK memory cost in a training iteration
  // Typically, the PEAK memory cost contributed by an operator is its output (if the output is parameter-involved),
  // plus necessary inputs.
  virtual double GetMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs) const;
  // Contributing the input part for 'GetMemoryCost'
  double GetInputMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs) const;
  // Contributing the output part for 'GetMemoryCost'
  double GetOutputMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs) const;
  // per device memory cost in an inference phase
  double GetMemoryCostForInference(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &) const;
@@ -101,25 +100,25 @@ class OperatorCost {
  // pre-operator that has parameters as input.
  std::vector<bool> is_parameter_involve_;
  int64_t output_parameter_involve_ = -1;  // -1: unset; 0: not parameter_involved; 1: parameter_involved
  // For each input in 'inputs_', there is a bool variable indicating whether the corresponding input is a parameter
  std::vector<bool> is_parameter_;
  // Whether the input should be kept in memory in the training phase. It depends on the operator and the operator's
  // previous operators.
  std::vector<bool> is_inputs_should_in_memory_;
  // Whether the output should be kept in memory in the training phase. It depends on 'is_parameter_involve_' and the
  // operator.
  bool is_output_should_in_memory_ = false;
  // For each input and output, the following vectors record the number of bytes of each element
  std::vector<size_t> inputs_type_lengths_;
  std::vector<size_t> outputs_type_lengths_;
  // Whether the output is critical, which means that this output is included in calculating peak memory cost
  // in the inference phase.
  int64_t is_outputs_critical_ = -1;
};
using OperatorCostPtr = std::shared_ptr<OperatorCost>;
class MatMulCost : public OperatorCost {
 public:
  MatMulCost() : OperatorCost() {}
  ~MatMulCost() override = default;
  // per device communication cost
@@ -141,14 +140,15 @@ class MatMulCost : public OperatorCost {
                                    int64_t stage_id) const override;
  double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                    int64_t stage_id) const override;
  void CalculateOutputInMemory() override;
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using MatMulCostPtr = std::shared_ptr<MatMulCost>;
using TensorDotCost = MatMulCost;
class CastCost : public OperatorCost {
 public:
  CastCost() : OperatorCost() {}
  ~CastCost() override = default;
  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                     int64_t stage_id) const override {
@@ -166,21 +166,95 @@ class ActivationCost : public OperatorCost {
                                    int64_t stage_id) const override;
  double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                    int64_t stage_id) const override;
  // Not taking account of output
  void CalculateOutputInMemory() override;
  // Not taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using RepeatElementsCost = CastCost;
using NegCost = CastCost;
using ExpandDimsCost = CastCost;
using SqueezeCost = CastCost;
using ConcatCost = CastCost;
using LogicalNotCost = CastCost;
using SignCost = CastCost;
using FloorCost = CastCost;
using RoundCost = CastCost;
using CeilCost = CastCost;
using ZerosLikeCost = CastCost;
using OnesLikeCost = CastCost;
using RangeCost = CastCost;
using SplitCost = CastCost;
class SqrtCost : public CastCost {
 public:
  SqrtCost() : CastCost() {}
  ~SqrtCost() override = default;
  // Taking account of output, not taking account of input
  void CalculateOutputInMemory() override;
};
using TanhCost = SqrtCost;
using EluCost = SqrtCost;
using ReLUCost = SqrtCost;
using SigmoidCost = SqrtCost;
using ReciprocalCost =
  SqrtCost;  // The derivative of 'Reciprocal' is different on 'Ascend' and 'GPU'. Here, 'Ascend' is chosen
using InvCost = SqrtCost;
using RsqrtCost = SqrtCost;
using AsinhCost = SqrtCost;
using AcoshCost = SqrtCost;
using ReLUV2Cost = SqrtCost;
class ReLU6Cost : public CastCost {
 public:
  ReLU6Cost() : CastCost() {}
  ~ReLU6Cost() override = default;
  // Taking account of input, not taking account of output
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using SoftsignCost = ReLU6Cost;
using SoftplusCost = ReLU6Cost;
using SquareCost = ReLU6Cost;
using ExpCost = ReLU6Cost;
using LogCost = ReLU6Cost;
using CosCost = ReLU6Cost;
using ACosCost = ReLU6Cost;
using AbsCost = ReLU6Cost;
using TanCost = ReLU6Cost;
using SinCost = ReLU6Cost;
using SinhCost = ReLU6Cost;
using Log1pCost = ReLU6Cost;
using Expm1Cost = ReLU6Cost;
using CoshCost = ReLU6Cost;
using AtanhCost = ReLU6Cost;
using AtanCost = ReLU6Cost;
using AsinCost = ReLU6Cost;
using ErfCost = ReLU6Cost;
using ErfcCost = ReLU6Cost;
using ActivationInfoCost = ReLU6Cost;
class TransposeCost : public CastCost {
 public:
  TransposeCost() : CastCost() {}
  ~TransposeCost() override = default;
  // Taking account of input, not taking account of output
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
class GeLUCost : public SqrtCost {
 public:
  GeLUCost() : SqrtCost() {}
  ~GeLUCost() override = default;
  // Taking account of input and output
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using BesselI0eCost = GeLUCost;
using BesselI1eCost = GeLUCost;
using L2NormalizeCost = GeLUCost;
class SoftmaxCost : public OperatorCost {
 public:
  SoftmaxCost() : OperatorCost() {}
  ~SoftmaxCost() override = default;
  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
@@ -199,21 +273,45 @@ class SoftmaxCost : public OperatorCost {
                                    int64_t stage_id) const override;
  double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                    int64_t) const override;
  // Taking account of output
  void CalculateOutputInMemory() override;
  // Not taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
class TileCost : public SoftmaxCost {
 public:
  TileCost() : SoftmaxCost() {}
  ~TileCost() override = default;
  // Not taking account of output
  void CalculateOutputInMemory() override;
  // Taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
class PackCost : public SoftmaxCost {
 public:
  PackCost() : SoftmaxCost() {}
  ~PackCost() override = default;
  // Not taking account of output
  void CalculateOutputInMemory() override;
  // Not taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
class BroadcastToCost : public SoftmaxCost {
 public:
  BroadcastToCost() : SoftmaxCost() {}
  ~BroadcastToCost() override = default;
  // Not taking account of output
  void CalculateOutputInMemory() override;
  // Not taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using SoftmaxCostPtr = std::shared_ptr<SoftmaxCost>;
using TileCostPtr = std::shared_ptr<TileCost>;
using PackCostPtr = std::shared_ptr<PackCost>;
using ConcatCostPtr = std::shared_ptr<ConcatCost>;
using BroadcastToCostPtr = std::shared_ptr<BroadcastToCost>;
class TmpIdentityCost : public OperatorCost {
 public:
  TmpIdentityCost() : OperatorCost() {}
  ~TmpIdentityCost() override = default;
  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
@@ -232,15 +330,16 @@ class TmpIdentityCost : public OperatorCost {
                                    int64_t stage_id) const override;
  double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                    int64_t stage_id) const override;
  // Not taking account of output
  void CalculateOutputInMemory() override;
  // Not taking account of input
  void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override;
};
using TmpIdentityCostPtr = std::shared_ptr<TmpIdentityCost>;
| class BatchParallelCost : public OperatorCost { | |||
| public: | |||
| explicit BatchParallelCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||
| BatchParallelCost() : OperatorCost(false) {} | |||
| BatchParallelCost() : OperatorCost() {} | |||
| ~BatchParallelCost() override = default; | |||
| double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | |||
| @@ -259,13 +358,25 @@ class BatchParallelCost : public OperatorCost { | |||
| int64_t stage_id) const override; | |||
| double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | |||
| int64_t stage_id) const override; | |||
| // Not taking account of output | |||
| void CalculateOutputInMemory() override; | |||
| // Taking account of input | |||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||
| }; | |||
| class SparseSoftmaxCrossEntropyWithLogitsCost : public BatchParallelCost { | |||
| public: | |||
| SparseSoftmaxCrossEntropyWithLogitsCost() : BatchParallelCost() {} | |||
| ~SparseSoftmaxCrossEntropyWithLogitsCost() override = default; | |||
| // Taking account of output | |||
| void CalculateOutputInMemory() override; | |||
| // Not taking account of input | |||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||
| }; | |||
| using BatchParallelCostPtr = std::shared_ptr<BatchParallelCost>; | |||
| class VirtualDatasetCost : public OperatorCost { | |||
| public: | |||
| explicit VirtualDatasetCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||
| VirtualDatasetCost() : OperatorCost(false) {} | |||
| VirtualDatasetCost() : OperatorCost() {} | |||
| ~VirtualDatasetCost() override = default; | |||
| double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | |||
| @@ -290,17 +401,15 @@ class VirtualDatasetCost : public OperatorCost { | |||
| int64_t) const override { | |||
| return 0.0; | |||
| } | |||
| // per device PEAK memory cost in a training iteration | |||
| double GetMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs) const override { | |||
| return 0.0; | |||
| } | |||
| // Not taking account of output | |||
| void CalculateOutputInMemory() override; | |||
| // Not taking account of input | |||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||
| }; | |||
| using VirtualDatasetCostPtr = std::shared_ptr<VirtualDatasetCost>; | |||
| class GeneratorBaseCost : public OperatorCost { | |||
| public: | |||
| explicit GeneratorBaseCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||
| GeneratorBaseCost() : OperatorCost(false) {} | |||
| GeneratorBaseCost() : OperatorCost() {} | |||
| ~GeneratorBaseCost() override = default; | |||
| double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | |||
| @@ -332,8 +441,7 @@ using GeneratorBaseCostPtr = std::shared_ptr<GeneratorBaseCost>; | |||
| class PReLUCost : public OperatorCost { | |||
| public: | |||
| explicit PReLUCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||
| PReLUCost() : OperatorCost(true) {} | |||
| PReLUCost() : OperatorCost() {} | |||
| ~PReLUCost() override = default; | |||
| // per device communication cost | |||
| @@ -355,13 +463,16 @@ class PReLUCost : public OperatorCost { | |||
| int64_t stage_id) const override; | |||
| double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | |||
| int64_t stage_id) const override; | |||
| // Not taking account of output | |||
| void CalculateOutputInMemory() override; | |||
| // Taking account of input | |||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||
| }; | |||
| using PReLUCostPtr = std::shared_ptr<PReLUCost>; | |||
| class OneHotCost : public OperatorCost { | |||
| public: | |||
| explicit OneHotCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||
| OneHotCost() : OperatorCost(true) {} | |||
| OneHotCost() : OperatorCost() {} | |||
| ~OneHotCost() override = default; | |||
| // per device communication cost | |||
| @@ -383,13 +494,16 @@ class OneHotCost : public OperatorCost { | |||
| int64_t stage_id) const override; | |||
| double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | |||
| int64_t stage_id) const override; | |||
| // Not taking account of output | |||
| void CalculateOutputInMemory() override; | |||
| // Not taking account of input | |||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||
| }; | |||
| using OneHotCostPtr = std::shared_ptr<OneHotCost>; | |||
| class SoftmaxCrossEntropyWithLogitsCost : public OperatorCost { | |||
| public: | |||
| explicit SoftmaxCrossEntropyWithLogitsCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||
| SoftmaxCrossEntropyWithLogitsCost() : OperatorCost(false) {} | |||
| SoftmaxCrossEntropyWithLogitsCost() : OperatorCost() {} | |||
| ~SoftmaxCrossEntropyWithLogitsCost() override = default; | |||
| // per device communication cost | |||
| @@ -411,13 +525,15 @@ class SoftmaxCrossEntropyWithLogitsCost : public OperatorCost { | |||
| int64_t stage_id) const override; | |||
| double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | |||
| int64_t stage_id) const override; | |||
| // Taking account of output | |||
| void CalculateOutputInMemory() override; | |||
| // Not taking account of input | |||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||
| }; | |||
| using SoftmaxCrossEntropyWithLogitsCostPtr = std::shared_ptr<SoftmaxCrossEntropyWithLogitsCost>; | |||
| class ReshapeCost : public OperatorCost { | |||
| public: | |||
| explicit ReshapeCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||
| ReshapeCost() : OperatorCost(true) {} | |||
| ReshapeCost() : OperatorCost() {} | |||
| ~ReshapeCost() override = default; | |||
| @@ -444,14 +560,17 @@ class ReshapeCost : public OperatorCost { | |||
| double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | |||
| int64_t stage_id) const override; | |||
| // Not taking account of output | |||
| void CalculateOutputInMemory() override; | |||
| // Not taking account of input | |||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||
| }; | |||
| using ReshapeCostPtr = std::shared_ptr<ReshapeCost>; | |||
| class ArithmeticCost : public OperatorCost { | |||
| class SubCost : public OperatorCost { | |||
| public: | |||
| explicit ArithmeticCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||
| ArithmeticCost() : OperatorCost(false) {} | |||
| ~ArithmeticCost() override = default; | |||
| SubCost() : OperatorCost() {} | |||
| ~SubCost() override = default; | |||
| double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | |||
| int64_t stage_id) const override { | |||
| @@ -470,16 +589,127 @@ class ArithmeticCost : public OperatorCost { | |||
| int64_t stage_id) const override; | |||
| double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | |||
| int64_t stage_id) const override; | |||
| // Not taking account of output | |||
| void CalculateOutputInMemory() override; | |||
| // Not taking account of input | |||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||
| }; | |||
| using TensorAddCost = SubCost; | |||
| using FloorDivCost = SubCost; | |||
| using AssignSubCost = SubCost; | |||
| using AssignAddCost = SubCost; | |||
| using LogicalAndCost = SubCost; | |||
| using LogicalOrCost = SubCost; | |||
| using BiasAddCost = SubCost; | |||
| using EqualCost = SubCost; | |||
| using ApproximateEqualCost = SubCost; | |||
| using NotEqualCost = SubCost; | |||
| using GreaterCost = SubCost; | |||
| using GreaterEqualCost = SubCost; | |||
| using LessCost = SubCost; | |||
| using LessEqualCost = SubCost; | |||
| class MulCost : public SubCost { | |||
| public: | |||
| MulCost() : SubCost() {} | |||
| ~MulCost() override = default; | |||
| // Taking account of input, not taking account of output | |||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||
| }; | |||
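A sketch of what `MulCost`'s override might look like, modeled on the two-operand pattern used elsewhere in this file (an assumption, not the patch's code): each operand's gradient reads the *other* operand, so an input is pinned when it is a parameter, or when its side is parameter-involved and no upstream operator already holds the other tensor.

```cpp
// Hypothetical sketch for a two-input elementwise multiply:
// d(x*y)/dx = y and d(x*y)/dy = x.
void MulCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  for (size_t i = 0; i < 2; ++i) {
    const size_t other = 1 - i;
    if (is_parameter_[i]) {
      // Parameters are always kept for their own gradient update.
      is_inputs_should_in_memory_[i] = true;
    }
    if (is_parameter_involve_[i]) {
      // Input i's gradient needs the other operand, unless a previous
      // operator already keeps that tensor in memory.
      auto it = prev_output_in_mem.find(other);
      if (it == prev_output_in_mem.end() || !it->second) {
        is_inputs_should_in_memory_[other] = true;
      }
    }
  }
}
```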
| using ArithmeticCostPtr = std::shared_ptr<ArithmeticCost>; | |||
| using BiasAddCost = ArithmeticCost; | |||
| using BiasAddCostPtr = std::shared_ptr<BiasAddCost>; | |||
| class ReduceMethodCost : public OperatorCost { | |||
| class DivCost : public SubCost { | |||
| public: | |||
| explicit ReduceMethodCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||
| ReduceMethodCost() : OperatorCost(true) {} | |||
| ~ReduceMethodCost() override = default; | |||
| DivCost() : SubCost() {} | |||
| ~DivCost() override = default; | |||
| // Taking account of output | |||
| void CalculateOutputInMemory() override; | |||
| // Taking account of input | |||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||
| }; | |||
| using RealDivCost = DivCost; | |||
| class ModCost : public SubCost { | |||
| public: | |||
| ModCost() : SubCost() {} | |||
| ~ModCost() override = default; | |||
| // Taking account of input, not taking account of output | |||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||
| }; | |||
| using FloorModCost = ModCost; | |||
| class PowCost : public SubCost { | |||
| public: | |||
| PowCost() : SubCost() {} | |||
| ~PowCost() override = default; | |||
| // Taking account of output | |||
| void CalculateOutputInMemory() override; | |||
| // Taking account of input | |||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||
| }; | |||
| class AssignCost : public SubCost { | |||
| public: | |||
| AssignCost() : SubCost() {} | |||
| ~AssignCost() override = default; | |||
| // Taking account of input, not taking account of output | |||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||
| }; | |||
| class SigmoidCrossEntropyWithLogitsCost : public SubCost { | |||
| public: | |||
| SigmoidCrossEntropyWithLogitsCost() : SubCost() {} | |||
| ~SigmoidCrossEntropyWithLogitsCost() override = default; | |||
| // Taking account of input, not taking account of output | |||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||
| }; | |||
| class Atan2Cost : public SubCost { | |||
| public: | |||
| Atan2Cost() : SubCost() {} | |||
| ~Atan2Cost() override = default; | |||
| // Taking account of input, not taking account of output | |||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||
| }; | |||
| class DivNoNanCost : public SubCost { | |||
| public: | |||
| DivNoNanCost() : SubCost() {} | |||
| ~DivNoNanCost() override = default; | |||
| // Taking account of output | |||
| void CalculateOutputInMemory() override; | |||
| // Taking account of input | |||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||
| }; | |||
| class MaximumCost : public SubCost { | |||
| public: | |||
| MaximumCost() : SubCost() {} | |||
| ~MaximumCost() override = default; | |||
| // Taking account of input, not taking account of output | |||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||
| }; | |||
| using MinimumCost = MaximumCost; | |||
| class SliceCost : public CastCost { | |||
| public: | |||
| SliceCost() : CastCost() {} | |||
| ~SliceCost() override = default; | |||
| // Not taking account of output, taking account of input | |||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||
| }; | |||
| class StridedSliceCost : public CastCost { | |||
| public: | |||
| StridedSliceCost() : CastCost() {} | |||
| ~StridedSliceCost() override = default; | |||
| // Not taking account of output, taking account of input | |||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||
| }; | |||
| class ReduceSumCost : public OperatorCost { | |||
| public: | |||
| ReduceSumCost() : OperatorCost() {} | |||
| ~ReduceSumCost() override = default; | |||
| double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | |||
| int64_t stage_id) const override { | |||
| @@ -500,27 +730,50 @@ class ReduceMethodCost : public OperatorCost { | |||
| return 0.0; | |||
| } | |||
| void set_cross_batch(bool cb) { cross_batch_ = cb; } | |||
| // Not taking account of output | |||
| void CalculateOutputInMemory() override; | |||
| // Taking account of input | |||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||
| protected: | |||
| bool cross_batch_ = false; | |||
| }; | |||
| using ReduceMethodCostPtr = std::shared_ptr<ReduceMethodCost>; | |||
| using ReduceMethodCost = ReduceSumCost; | |||
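The `cross_batch_` switch above is set per operator instance. A hypothetical call site (names as declared in this class, usage assumed):

```cpp
// Reductions that fold the batch dimension flag their cost model so the
// planner can discount the corresponding communication.
auto reduce_cost = std::make_shared<ReduceSumCost>();
reduce_cost->set_cross_batch(true);
```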
| class ReduceMeanCost : public ReduceMethodCost { | |||
| class ReduceMeanCost : public ReduceSumCost { | |||
| public: | |||
| explicit ReduceMeanCost(bool is_inputs_related) : ReduceMethodCost(is_inputs_related) {} | |||
| ReduceMeanCost() : ReduceMethodCost(true) {} | |||
| ReduceMeanCost() : ReduceSumCost() {} | |||
| ~ReduceMeanCost() override = default; | |||
| double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | |||
| int64_t stage_id) const override; | |||
| }; | |||
| using ReduceMeanCostPtr = std::shared_ptr<ReduceMeanCost>; | |||
| class ReduceMinCost : public ReduceSumCost { | |||
| public: | |||
| ReduceMinCost() : ReduceSumCost() {} | |||
| ~ReduceMinCost() override = default; | |||
| // Taking account of output | |||
| void CalculateOutputInMemory() override; | |||
| // Taking account of input | |||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||
| }; | |||
| using ReduceMaxCost = ReduceMinCost; | |||
| class ArgMaxWithValueCost : public ReduceSumCost { | |||
| public: | |||
| ArgMaxWithValueCost() : ReduceSumCost() {} | |||
| ~ArgMaxWithValueCost() override = default; | |||
| // Taking account of output | |||
| void CalculateOutputInMemory() override; | |||
| // Taking account of input | |||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||
| }; | |||
| using ArgMinWithValueCost = ArgMaxWithValueCost; | |||
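For the min/max-style reductions above, "taking account of output *and* input" plausibly reflects that backward rebuilds the selection mask by comparing the input against the reduced output. A hedged sketch, not the patch's implementation:

```cpp
// Sketch (assumption): keep the output, and keep the input unless an
// upstream operator already holds it in memory.
void ReduceMinCost::CalculateOutputInMemory() { is_output_should_in_memory_ = true; }

void ReduceMinCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
    is_inputs_should_in_memory_[0] = true;
  }
}
```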
| class GetNextCost : public OperatorCost { | |||
| public: | |||
| explicit GetNextCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||
| GetNextCost() : OperatorCost(false) {} | |||
| GetNextCost() : OperatorCost() {} | |||
| ~GetNextCost() override = default; | |||
| double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | |||
| @@ -547,13 +800,17 @@ class GetNextCost : public OperatorCost { | |||
| int64_t) const override { | |||
| return 0.0; | |||
| } | |||
| // Not taking account of output | |||
| void CalculateOutputInMemory() override; | |||
| // Not taking account of input | |||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||
| }; | |||
| using GetNextCostPtr = std::shared_ptr<GetNextCost>; | |||
| class DropOutCost : public OperatorCost { | |||
| // For memory cost, taking account of output, not taking account of input | |||
| class DropOutCost : public SqrtCost { | |||
| public: | |||
| explicit DropOutCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||
| DropOutCost() : OperatorCost(true) {} | |||
| DropOutCost() : SqrtCost() {} | |||
| ~DropOutCost() override = default; | |||
| double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | |||
| @@ -578,12 +835,19 @@ class DropOutCost : public OperatorCost { | |||
| } | |||
| }; | |||
| using DropOutCostPtr = std::shared_ptr<DropOutCost>; | |||
| class DropOutDoMaskCost : public DropOutCost { | |||
| public: | |||
| DropOutDoMaskCost() : DropOutCost() {} | |||
| ~DropOutDoMaskCost() override = default; | |||
| // Not taking account of output | |||
| void CalculateOutputInMemory() override; | |||
| // Taking account of input | |||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||
| }; | |||
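A sketch of the split declared above (assumption based on the op's gradient, not the patch's body): the backward of DropOutDoMask re-applies the mask (input 1) to the incoming gradient, so the mask is pinned while the output is not.

```cpp
// Hypothetical implementation sketch:
void DropOutDoMaskCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }

void DropOutDoMaskCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
  // Keep the mask unless a previous operator already holds it.
  if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
    is_inputs_should_in_memory_[1] = true;
  }
}
```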
| class UnsortedSegmentSumCost : public OperatorCost { | |||
| public: | |||
| explicit UnsortedSegmentSumCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||
| UnsortedSegmentSumCost() : OperatorCost(true) {} | |||
| UnsortedSegmentSumCost() : OperatorCost() {} | |||
| ~UnsortedSegmentSumCost() override = default; | |||
| double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | |||
| @@ -602,14 +866,15 @@ class UnsortedSegmentSumCost : public OperatorCost { | |||
| int64_t) const override { | |||
| return 0.0; | |||
| } | |||
| // Not taking account of output | |||
| void CalculateOutputInMemory() override; | |||
| // Taking account of input | |||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||
| }; | |||
| using UnsortedSegmentSumCostPtr = std::shared_ptr<UnsortedSegmentSumCost>; | |||
| class UnsortedSegmentMinCost : public OperatorCost { | |||
| public: | |||
| explicit UnsortedSegmentMinCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||
| UnsortedSegmentMinCost() : OperatorCost(true) {} | |||
| UnsortedSegmentMinCost() : OperatorCost() {} | |||
| ~UnsortedSegmentMinCost() override = default; | |||
| double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | |||
| @@ -628,14 +893,16 @@ class UnsortedSegmentMinCost : public OperatorCost { | |||
| int64_t) const override { | |||
| return 0.0; | |||
| } | |||
| // Taking account of output | |||
| void CalculateOutputInMemory() override; | |||
| // Taking account of input | |||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||
| }; | |||
| using UnsortedSegmentMinCostPtr = std::shared_ptr<UnsortedSegmentMinCost>; | |||
| using UnsortedSegmentMaxCost = UnsortedSegmentMinCost; | |||
| class LayerNormCost : public OperatorCost { | |||
| public: | |||
| explicit LayerNormCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||
| LayerNormCost() : OperatorCost(true) {} | |||
| LayerNormCost() : OperatorCost() {} | |||
| ~LayerNormCost() override = default; | |||
| double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | |||
| @@ -656,14 +923,15 @@ class LayerNormCost : public OperatorCost { | |||
| int64_t) const override { | |||
| return 0.0; | |||
| } | |||
| // Taking account of output | |||
| void CalculateOutputInMemory() override; | |||
| // Taking account of input | |||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||
| }; | |||
| using LayerNormCostPtr = std::shared_ptr<LayerNormCost>; | |||
| class UniqueCost : public OperatorCost { | |||
| public: | |||
| explicit UniqueCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||
| UniqueCost() : OperatorCost(true) {} | |||
| UniqueCost() : OperatorCost() {} | |||
| ~UniqueCost() override = default; | |||
| double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | |||
| @@ -682,14 +950,15 @@ class UniqueCost : public OperatorCost { | |||
| int64_t stage_id) const override; | |||
| double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | |||
| int64_t) const override; | |||
| // Taking account of output | |||
| void CalculateOutputInMemory() override; | |||
| // Not taking account of input | |||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||
| }; | |||
| using UniqueCostPtr = std::shared_ptr<UniqueCost>; | |||
| class UniformCandidateSamplerCost : public OperatorCost { | |||
| public: | |||
| explicit UniformCandidateSamplerCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||
| UniformCandidateSamplerCost() : OperatorCost(false) {} | |||
| UniformCandidateSamplerCost() : OperatorCost() {} | |||
| ~UniformCandidateSamplerCost() override = default; | |||
| double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | |||
| @@ -714,14 +983,15 @@ class UniformCandidateSamplerCost : public OperatorCost { | |||
| int64_t) const override { | |||
| return 0.0; | |||
| } | |||
| // Not taking account of output | |||
| void CalculateOutputInMemory() override; | |||
| // Not taking account of input | |||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||
| }; | |||
| using UniformCandidateSamplerCostPtr = std::shared_ptr<UniformCandidateSamplerCost>; | |||
| class GatherV2Cost : public OperatorCost { | |||
| public: | |||
| explicit GatherV2Cost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||
| GatherV2Cost() : OperatorCost(true) {} | |||
| GatherV2Cost() : OperatorCost() {} | |||
| ~GatherV2Cost() override = default; | |||
| double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | |||
| @@ -740,14 +1010,15 @@ class GatherV2Cost : public OperatorCost { | |||
| int64_t stage_id) const override; | |||
| double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | |||
| int64_t) const override; | |||
| // Not taking account of output | |||
| void CalculateOutputInMemory() override; | |||
| // Taking account of input | |||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||
| }; | |||
| using GatherV2CostPtr = std::shared_ptr<GatherV2Cost>; | |||
| class GatherV2PCost : public OperatorCost { | |||
| class GatherV2PCost : public GatherV2Cost { | |||
| public: | |||
| explicit GatherV2PCost(bool is_inputs_related) : OperatorCost(is_inputs_related), axis_(0) {} | |||
| GatherV2PCost() : OperatorCost(true), axis_(0) {} | |||
| GatherV2PCost() : GatherV2Cost(), axis_(0) {} | |||
| ~GatherV2PCost() override = default; | |||
| double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | |||
| @@ -773,8 +1044,6 @@ class GatherV2PCost : public OperatorCost { | |||
| int64_t axis_; | |||
| Shape strategy_; | |||
| }; | |||
| using GatherV2PCostPtr = std::shared_ptr<GatherV2PCost>; | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| #endif // PARALLEL_AUTO_PARALLEL_OPERATOR_COSTMODEL_H_ | |||
| @@ -50,8 +50,8 @@ class ActivationBase : public OperatorInfo { | |||
| class Activation : public ActivationBase { | |||
| public: | |||
| Activation(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ActivationCost>(false)) {} | |||
| const PrimitiveAttrs &attrs, OperatorCostPtr cost) | |||
| : ActivationBase(name, inputs_shape, outputs_shape, attrs, cost) {} | |||
| ~Activation() override = default; | |||
| Status GenerateStrategies(int64_t stage_id) override; | |||
| Status SetCostUnderStrategy(const StrategyPtr &strategy) override; | |||
| @@ -64,7 +64,7 @@ class ActivationInfo : public Activation { | |||
| public: | |||
| ActivationInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : Activation(name, inputs_shape, outputs_shape, attrs) {} | |||
| : Activation(name, inputs_shape, outputs_shape, attrs, std::make_shared<ActivationInfoCost>()) {} | |||
| ~ActivationInfo() override = default; | |||
| protected: | |||
| @@ -74,8 +74,8 @@ class ActivationInfo : public Activation { | |||
| class ActivationOther : public Activation { | |||
| public: | |||
| ActivationOther(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : Activation(name, inputs_shape, outputs_shape, attrs) {} | |||
| const PrimitiveAttrs &attrs, OperatorCostPtr cost) | |||
| : Activation(name, inputs_shape, outputs_shape, attrs, cost) {} | |||
| ~ActivationOther() override = default; | |||
| protected: | |||
| @@ -86,7 +86,7 @@ class GeluInfo : public ActivationOther { | |||
| public: | |||
| GeluInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<GeLUCost>()) {} | |||
| ~GeluInfo() override = default; | |||
| }; | |||
| @@ -94,7 +94,7 @@ class TanhInfo : public ActivationOther { | |||
| public: | |||
| TanhInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<TanhCost>()) {} | |||
| ~TanhInfo() override = default; | |||
| }; | |||
| @@ -102,7 +102,7 @@ class Softmax : public ActivationBase { | |||
| public: | |||
| explicit Softmax(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftmaxCost>(false)) {} | |||
| : ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftmaxCost>()) {} | |||
| ~Softmax() override = default; | |||
| Status GenerateStrategies(int64_t stage_id) override; | |||
| Status SetCostUnderStrategy(const StrategyPtr &strategy) override; | |||
| @@ -134,7 +134,7 @@ class LogSoftmaxInfo : public Softmax { | |||
| class EluInfo : public ActivationOther { | |||
| public: | |||
| EluInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<EluCost>()) {} | |||
| ~EluInfo() override = default; | |||
| }; | |||
| @@ -142,7 +142,7 @@ class ReLUInfo : public ActivationOther { | |||
| public: | |||
| ReLUInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReLUCost>()) {} | |||
| ~ReLUInfo() override = default; | |||
| }; | |||
| @@ -150,7 +150,7 @@ class RepeatElementsInfo : public ActivationOther { | |||
| public: | |||
| RepeatElementsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<RepeatElementsCost>()) {} | |||
| ~RepeatElementsInfo() override = default; | |||
| }; | |||
| @@ -158,7 +158,7 @@ class ReLU6Info : public ActivationOther { | |||
| public: | |||
| ReLU6Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReLU6Cost>()) {} | |||
| ~ReLU6Info() override = default; | |||
| }; | |||
| @@ -166,7 +166,7 @@ class SoftsignInfo : public ActivationOther { | |||
| public: | |||
| SoftsignInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftsignCost>()) {} | |||
| ~SoftsignInfo() override = default; | |||
| }; | |||
| @@ -174,7 +174,7 @@ class SoftplusInfo : public ActivationOther { | |||
| public: | |||
| SoftplusInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftplusCost>()) {} | |||
| ~SoftplusInfo() override = default; | |||
| }; | |||
| @@ -182,7 +182,7 @@ class CastInfo : public ActivationOther { | |||
| public: | |||
| CastInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<CastCost>()) {} | |||
| ~CastInfo() override = default; | |||
| protected: | |||
| @@ -193,14 +193,14 @@ class SqrtInfo : public ActivationOther { | |||
| public: | |||
| SqrtInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<SqrtCost>()) {} | |||
| ~SqrtInfo() override = default; | |||
| }; | |||
| class NegInfo : public ActivationOther { | |||
| public: | |||
| NegInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<NegCost>()) {} | |||
| ~NegInfo() override = default; | |||
| }; | |||
| @@ -208,7 +208,7 @@ class ExpandDimsInfo : public ActivationOther { | |||
| public: | |||
| ExpandDimsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<ExpandDimsCost>()) {} | |||
| ~ExpandDimsInfo() override = default; | |||
| protected: | |||
| @@ -228,7 +228,7 @@ class SqueezeInfo : public ActivationOther { | |||
| public: | |||
| SqueezeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<SqueezeCost>()) {} | |||
| ~SqueezeInfo() override = default; | |||
| protected: | |||
| @@ -247,7 +247,7 @@ class SquareInfo : public ActivationOther { | |||
| public: | |||
| SquareInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<SquareCost>()) {} | |||
| ~SquareInfo() override = default; | |||
| }; | |||
| @@ -255,7 +255,7 @@ class SigmoidInfo : public ActivationOther { | |||
| public: | |||
| SigmoidInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<SigmoidCost>()) {} | |||
| ~SigmoidInfo() override = default; | |||
| }; | |||
| @@ -263,7 +263,7 @@ class DropoutInfo : public ActivationOther { | |||
| public: | |||
| DropoutInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<DropOutCost>()) {} | |||
| ~DropoutInfo() override = default; | |||
| Status GenerateStrategies(int64_t stage_id) override; | |||
| @@ -56,7 +56,7 @@ class ArithmeticBase : public OperatorInfo { | |||
| class SubInfo : public ArithmeticBase { | |||
| public: | |||
| SubInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {} | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<SubCost>()) {} | |||
| ~SubInfo() override = default; | |||
| }; | |||
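With the cost object now injected through the constructor, registering a new arithmetic operator is a one-liner plus an alias. A hypothetical example (`XdivyInfo` and `XdivyCost` are illustrative names, not part of this patch):

```cpp
// Assumed alias in the cost header, reusing the SubCost accounting:
using XdivyCost = SubCost;

class XdivyInfo : public ArithmeticBase {
 public:
  XdivyInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
            const PrimitiveAttrs &attrs)
      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<XdivyCost>()) {}
  ~XdivyInfo() override = default;
};
```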
| @@ -64,28 +64,28 @@ class TensorAddInfo : public ArithmeticBase { | |||
| public: | |||
| TensorAddInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {} | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<TensorAddCost>()) {} | |||
| ~TensorAddInfo() override = default; | |||
| }; | |||
| class MulInfo : public ArithmeticBase { | |||
| public: | |||
| MulInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {} | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<MulCost>()) {} | |||
| ~MulInfo() override = default; | |||
| }; | |||
| class DivInfo : public ArithmeticBase { | |||
| public: | |||
| DivInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {} | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<DivCost>()) {} | |||
| ~DivInfo() override = default; | |||
| }; | |||
| class ModInfo : public ArithmeticBase { | |||
| public: | |||
| ModInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {} | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ModCost>()) {} | |||
| ~ModInfo() override = default; | |||
| }; | |||
| @@ -93,7 +93,7 @@ class RealDivInfo : public ArithmeticBase { | |||
| public: | |||
| RealDivInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {} | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<RealDivCost>()) {} | |||
| ~RealDivInfo() override = default; | |||
| }; | |||
| @@ -101,7 +101,7 @@ class FloorDivInfo : public ArithmeticBase { | |||
| public: | |||
| FloorDivInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {} | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<FloorDivCost>()) {} | |||
| ~FloorDivInfo() override = default; | |||
| }; | |||
| @@ -109,14 +109,14 @@ class FloorModInfo : public ArithmeticBase { | |||
| public: | |||
| FloorModInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {} | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<FloorModCost>()) {} | |||
| ~FloorModInfo() override = default; | |||
| }; | |||
| class PowInfo : public ArithmeticBase { | |||
| public: | |||
| PowInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {} | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<PowCost>()) {} | |||
| ~PowInfo() override = default; | |||
| }; | |||
| @@ -124,7 +124,7 @@ class AssignSubInfo : public ArithmeticBase { | |||
| public: | |||
| AssignSubInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {} | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<AssignSubCost>()) {} | |||
| ~AssignSubInfo() override = default; | |||
| }; | |||
| @@ -132,7 +132,7 @@ class AssignInfo : public ArithmeticBase { | |||
| public: | |||
| AssignInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {} | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<AssignCost>()) {} | |||
| ~AssignInfo() override = default; | |||
| }; | |||
| @@ -140,7 +140,7 @@ class AssignAddInfo : public ArithmeticBase { | |||
| public: | |||
| AssignAddInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {} | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<AssignAddCost>()) {} | |||
| ~AssignAddInfo() override = default; | |||
| }; | |||
| @@ -149,7 +149,8 @@ class SigmoidCrossEntropyWithLogitsInfo : public ArithmeticBase { | |||
| public: | |||
| SigmoidCrossEntropyWithLogitsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {} | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, | |||
| std::make_shared<SigmoidCrossEntropyWithLogitsCost>()) {} | |||
| ~SigmoidCrossEntropyWithLogitsInfo() override = default; | |||
| }; | |||
| @@ -157,7 +158,7 @@ class Atan2Info : public ArithmeticBase { | |||
| public: | |||
| Atan2Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {} | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<Atan2Cost>()) {} | |||
| ~Atan2Info() override = default; | |||
| }; | |||
| @@ -165,7 +166,7 @@ class DivNoNanInfo : public ArithmeticBase { | |||
| public: | |||
| DivNoNanInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {} | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<DivNoNanCost>()) {} | |||
| ~DivNoNanInfo() override = default; | |||
| }; | |||
| @@ -173,7 +174,7 @@ class LogicalAndInfo : public ArithmeticBase { | |||
| public: | |||
| LogicalAndInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {} | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<LogicalAndCost>()) {} | |||
| ~LogicalAndInfo() override = default; | |||
| }; | |||
| @@ -181,7 +182,7 @@ class LogicalOrInfo : public ArithmeticBase { | |||
| public: | |||
| LogicalOrInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {} | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<LogicalOrCost>()) {} | |||
| ~LogicalOrInfo() override = default; | |||
| }; | |||
| } // namespace parallel | |||
| @@ -34,8 +34,7 @@ class BatchParallelInfo : public OperatorInfo { | |||
| : OperatorInfo(name, inputs_shape, outputs_shape, attrs, cost), dev_num_(1) {} | |||
| BatchParallelInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>(false)), | |||
| dev_num_(1) {} | |||
| : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>()), dev_num_(1) {} | |||
| ~BatchParallelInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| @@ -62,7 +61,8 @@ class SparseSoftmaxCrossEntropyWithLogitsInfo : public BatchParallelInfo { | |||
| public: | |||
| SparseSoftmaxCrossEntropyWithLogitsInfo(const std::string &name, const Shapes &inputs_shape, | |||
| const Shapes &outputs_shape, const PrimitiveAttrs &attrs) | |||
| : BatchParallelInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>(true)) {} | |||
| : BatchParallelInfo(name, inputs_shape, outputs_shape, attrs, | |||
| std::make_shared<SparseSoftmaxCrossEntropyWithLogitsCost>()) {} | |||
| ~SparseSoftmaxCrossEntropyWithLogitsInfo() override = default; | |||
| void ReComputeBatchSplitFlagList() override; | |||
| }; | |||
| @@ -34,7 +34,7 @@ class BiasAddInfo : public OperatorInfo { | |||
| public: | |||
| BiasAddInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<BiasAddCost>(false)) {} | |||
| : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<BiasAddCost>()) {} | |||
| ~BiasAddInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| @@ -36,7 +36,7 @@ class BroadcastToInfo : public OperatorInfo { | |||
| public: | |||
| BroadcastToInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<BroadcastToCost>(false)) {} | |||
| : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<BroadcastToCost>()) {} | |||
| ~BroadcastToInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| @@ -32,7 +32,7 @@ class EqualInfo : public ArithmeticBase { | |||
| public: | |||
| EqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {} | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<EqualCost>()) {} | |||
| ~EqualInfo() override = default; | |||
| }; | |||
| @@ -40,7 +40,7 @@ class ApproximateEqualInfo : public ArithmeticBase { | |||
| public: | |||
| ApproximateEqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {} | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ApproximateEqualCost>()) {} | |||
| ~ApproximateEqualInfo() override = default; | |||
| }; | |||
| @@ -48,7 +48,7 @@ class NotEqualInfo : public ArithmeticBase { | |||
| public: | |||
| NotEqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {} | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<NotEqualCost>()) {} | |||
| ~NotEqualInfo() override = default; | |||
| }; | |||
| @@ -56,7 +56,7 @@ class MaximumInfo : public ArithmeticBase { | |||
| public: | |||
| MaximumInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {} | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<MaximumCost>()) {} | |||
| ~MaximumInfo() override = default; | |||
| }; | |||
| @@ -64,7 +64,7 @@ class MinimumInfo : public ArithmeticBase { | |||
| public: | |||
| MinimumInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {} | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<MinimumCost>()) {} | |||
| ~MinimumInfo() override = default; | |||
| }; | |||
| @@ -72,7 +72,7 @@ class GreaterInfo : public ArithmeticBase { | |||
| public: | |||
| GreaterInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {} | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<GreaterCost>()) {} | |||
| ~GreaterInfo() override = default; | |||
| }; | |||
| @@ -80,7 +80,7 @@ class GreaterEqualInfo : public ArithmeticBase { | |||
| public: | |||
| GreaterEqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {} | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<GreaterEqualCost>()) {} | |||
| ~GreaterEqualInfo() override = default; | |||
| }; | |||
| @@ -88,7 +88,7 @@ class LessInfo : public ArithmeticBase { | |||
| public: | |||
| LessInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {} | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<LessCost>()) {} | |||
| ~LessInfo() override = default; | |||
| }; | |||
| @@ -96,7 +96,7 @@ class LessEqualInfo : public ArithmeticBase { | |||
| public: | |||
| LessEqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {} | |||
| : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<LessEqualCost>()) {} | |||
| ~LessEqualInfo() override = default; | |||
| }; | |||
| } // namespace parallel | |||
| @@ -33,7 +33,7 @@ class ConcatInfo : public OperatorInfo { | |||
| public: | |||
| ConcatInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<ConcatCost>(false)) {} | |||
| : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<ConcatCost>()) {} | |||
| ~ConcatInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| @@ -33,7 +33,7 @@ class DropoutDoMaskInfo : public OperatorInfo { | |||
| public: | |||
| DropoutDoMaskInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<DropOutCost>(true)) {} | |||
| : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<DropOutDoMaskCost>()) {} | |||
| ~DropoutDoMaskInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| @@ -20,6 +20,7 @@ | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "ir/value.h" | |||
| #include "frontend/parallel/auto_parallel/operator_costmodel.h" | |||
| #include "frontend/parallel/ops_info/activation_info.h" | |||
| @@ -30,21 +31,21 @@ namespace parallel { | |||
| class ExpInfo : public ActivationOther { | |||
| public: | |||
| ExpInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<ExpCost>()) {} | |||
| ~ExpInfo() override = default; | |||
| }; | |||
| class LogInfo : public ActivationOther { | |||
| public: | |||
| LogInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<LogCost>()) {} | |||
| ~LogInfo() override = default; | |||
| }; | |||
| class CosInfo : public ActivationOther { | |||
| public: | |||
| CosInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<CosCost>()) {} | |||
| ~CosInfo() override = default; | |||
| }; | |||
| @@ -52,7 +53,7 @@ class ACosInfo : public ActivationOther { | |||
| public: | |||
| ACosInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<ACosCost>()) {} | |||
| ~ACosInfo() override = default; | |||
| }; | |||
| @@ -60,14 +61,14 @@ class LogicalNotInfo : public ActivationOther { | |||
| public: | |||
| LogicalNotInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<LogicalNotCost>()) {} | |||
| ~LogicalNotInfo() override = default; | |||
| }; | |||
| class AbsInfo : public ActivationOther { | |||
| public: | |||
| AbsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<AbsCost>()) {} | |||
| ~AbsInfo() override = default; | |||
| }; | |||
| @@ -75,7 +76,7 @@ class SignInfo : public ActivationOther { | |||
| public: | |||
| SignInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<SignCost>()) {} | |||
| ~SignInfo() override = default; | |||
| }; | |||
| @@ -83,7 +84,7 @@ class FloorInfo : public ActivationOther { | |||
| public: | |||
| FloorInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<FloorCost>()) {} | |||
| ~FloorInfo() override = default; | |||
| }; | |||
| @@ -91,7 +92,7 @@ class RoundInfo : public ActivationOther { | |||
| public: | |||
| RoundInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<RoundCost>()) {} | |||
| ~RoundInfo() override = default; | |||
| }; | |||
| @@ -99,14 +100,14 @@ class ReciprocalInfo : public ActivationOther { | |||
| public: | |||
| ReciprocalInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReciprocalCost>()) {} | |||
| ~ReciprocalInfo() override = default; | |||
| }; | |||
| class InvInfo : public ActivationOther { | |||
| public: | |||
| InvInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<InvCost>()) {} | |||
| ~InvInfo() override = default; | |||
| }; | |||
| @@ -114,21 +115,21 @@ class RsqrtInfo : public ActivationOther { | |||
| public: | |||
| RsqrtInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<RsqrtCost>()) {} | |||
| ~RsqrtInfo() override = default; | |||
| }; | |||
| class TanInfo : public ActivationOther { | |||
| public: | |||
| TanInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<TanCost>()) {} | |||
| ~TanInfo() override = default; | |||
| }; | |||
| class SinInfo : public ActivationOther { | |||
| public: | |||
| SinInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<SinCost>()) {} | |||
| ~SinInfo() override = default; | |||
| }; | |||
| @@ -136,7 +137,7 @@ class SinhInfo : public ActivationOther { | |||
| public: | |||
| SinhInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<SinhCost>()) {} | |||
| ~SinhInfo() override = default; | |||
| }; | |||
| @@ -144,7 +145,7 @@ class Log1pInfo : public ActivationOther { | |||
| public: | |||
| Log1pInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<Log1pCost>()) {} | |||
| ~Log1pInfo() override = default; | |||
| }; | |||
| @@ -152,7 +153,7 @@ class Expm1Info : public ActivationOther { | |||
| public: | |||
| Expm1Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<Expm1Cost>()) {} | |||
| ~Expm1Info() override = default; | |||
| }; | |||
| @@ -160,7 +161,7 @@ class CoshInfo : public ActivationOther { | |||
| public: | |||
| CoshInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<CoshCost>()) {} | |||
| ~CoshInfo() override = default; | |||
| }; | |||
| @@ -168,7 +169,7 @@ class CeilInfo : public ActivationOther { | |||
| public: | |||
| CeilInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<CeilCost>()) {} | |||
| ~CeilInfo() override = default; | |||
| }; | |||
| @@ -176,7 +177,7 @@ class AtanhInfo : public ActivationOther { | |||
| public: | |||
| AtanhInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<AtanhCost>()) {} | |||
| ~AtanhInfo() override = default; | |||
| }; | |||
| @@ -184,7 +185,7 @@ class AtanInfo : public ActivationOther { | |||
| public: | |||
| AtanInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<AtanCost>()) {} | |||
| ~AtanInfo() override = default; | |||
| }; | |||
| @@ -192,7 +193,7 @@ class AsinInfo : public ActivationOther { | |||
| public: | |||
| AsinInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<AsinCost>()) {} | |||
| ~AsinInfo() override = default; | |||
| }; | |||
| @@ -200,7 +201,7 @@ class AsinhInfo : public ActivationOther { | |||
| public: | |||
| AsinhInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<AsinhCost>()) {} | |||
| ~AsinhInfo() override = default; | |||
| }; | |||
| @@ -208,14 +209,14 @@ class AcoshInfo : public ActivationOther { | |||
| public: | |||
| AcoshInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<AcoshCost>()) {} | |||
| ~AcoshInfo() override = default; | |||
| }; | |||
| class ErfInfo : public ActivationOther { | |||
| public: | |||
| ErfInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<ErfCost>()) {} | |||
| ~ErfInfo() override = default; | |||
| }; | |||
| @@ -223,7 +224,7 @@ class ErfcInfo : public ActivationOther { | |||
| public: | |||
| ErfcInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<ErfcCost>()) {} | |||
| ~ErfcInfo() override = default; | |||
| }; | |||
| @@ -231,7 +232,7 @@ class ZerosLikeInfo : public ActivationOther { | |||
| public: | |||
| ZerosLikeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<ZerosLikeCost>()) {} | |||
| ~ZerosLikeInfo() override = default; | |||
| }; | |||
| @@ -239,7 +240,7 @@ class OnesLikeInfo : public ActivationOther { | |||
| public: | |||
| OnesLikeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<OnesLikeCost>()) {} | |||
| ~OnesLikeInfo() override = default; | |||
| }; | |||
| @@ -247,7 +248,7 @@ class BesselI0eInfo : public ActivationOther { | |||
| public: | |||
| BesselI0eInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<BesselI0eCost>()) {} | |||
| ~BesselI0eInfo() override = default; | |||
| }; | |||
| @@ -255,7 +256,7 @@ class BesselI1eInfo : public ActivationOther { | |||
| public: | |||
| BesselI1eInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<BesselI1eCost>()) {} | |||
| ~BesselI1eInfo() override = default; | |||
| }; | |||
| } // namespace parallel | |||
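The refactor above repeats one pattern: every Info class now injects its own cost object into the base constructor, and the old boolean argument disappears because the property it carried moves into the cost classes themselves. A minimal sketch of what one of the new per-operator cost types might look like; 'ExampleElementwiseCost' is a hypothetical name, and setting 'inputs_related_' in the constructor is an assumption about what the dropped boolean used to configure:

    // Hypothetical per-operator cost type. The dropped constructor boolean is
    // assumed to have configured 'inputs_related_'; each cost class now sets
    // it internally, so Info-class call sites no longer need to pass it.
    class ExampleElementwiseCost : public ActivationCost {
     public:
      ExampleElementwiseCost() { inputs_related_ = true; }
      ~ExampleElementwiseCost() override = default;
    };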
| @@ -32,7 +32,7 @@ class GetNextInfo : public OperatorInfo { | |||
| public: | |||
| GetNextInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<GetNextCost>(false)) {} | |||
| : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<GetNextCost>()) {} | |||
| ~GetNextInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| @@ -33,7 +33,7 @@ class L2NormalizeInfo : public Activation { | |||
| public: | |||
| L2NormalizeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : Activation(name, inputs_shape, outputs_shape, attrs) {} | |||
| : Activation(name, inputs_shape, outputs_shape, attrs, std::make_shared<L2NormalizeCost>()) {} | |||
| ~L2NormalizeInfo() override = default; | |||
| Status GenerateStrategies(int64_t stage_id) override; | |||
| @@ -40,7 +40,7 @@ class LayerNormInfo : public OperatorInfo { | |||
| public: | |||
| LayerNormInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<LayerNormCost>(true)), | |||
| : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<LayerNormCost>()), | |||
| begin_norm_axis_(0) {} | |||
| ~LayerNormInfo() override = default; | |||
| @@ -36,8 +36,7 @@ class SoftmaxCrossEntropyWithLogitsInfo : public OperatorInfo { | |||
| public: | |||
| SoftmaxCrossEntropyWithLogitsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : OperatorInfo(name, inputs_shape, outputs_shape, attrs, | |||
| std::make_shared<SoftmaxCrossEntropyWithLogitsCost>(false)) {} | |||
| : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftmaxCrossEntropyWithLogitsCost>()) {} | |||
| ~SoftmaxCrossEntropyWithLogitsInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| @@ -34,7 +34,7 @@ class MatMulBase : public OperatorInfo { | |||
| public: | |||
| MatMulBase(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<MatMulCost>(true)) {} | |||
| : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<MatMulCost>()) {} | |||
| ~MatMulBase() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| @@ -33,7 +33,7 @@ class OneHotInfo : public OperatorInfo { | |||
| public: | |||
| OneHotInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<OneHotCost>(false)) {} | |||
| : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<OneHotCost>()) {} | |||
| ~OneHotInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| @@ -1204,6 +1204,20 @@ int64_t OperatorInfo::ComputeOpAndPrevEdgeParameterInvolved() { | |||
| } else { | |||
| is_output_parameter_involve_ = 0; | |||
| } | |||
| // Set 'is_parameter_involve_' and 'is_output_parameter_involve_' into OperatorCost; they are used when | |||
| // calculating 'inputs_in_memory' and 'output_in_memory', respectively. | |||
| operator_cost()->set_is_parameter_involve(is_parameter_involve_); | |||
| operator_cost()->set_output_parameter_involve(is_output_parameter_involve_); | |||
| // Calculating 'output_in_memory' | |||
| operator_cost()->CalculateOutputInMemory(); | |||
| // Calculating 'inputs_in_memory' | |||
| std::map<size_t, bool> input_in_memory; | |||
| for (auto &p_edge : prev_edges) { | |||
| auto input_index = p_edge->next_op_input_index(); | |||
| auto is_in_mem = p_edge->prev_operator()->operator_cost()->is_output_in_memory(); | |||
| input_in_memory.emplace(input_index, is_in_mem); | |||
| } | |||
| operator_cost()->CalculateInputsInMemory(input_in_memory); | |||
| return is_output_parameter_involve_; | |||
| } | |||
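To make the new flow concrete: 'CalculateOutputInMemory' decides whether this operator's output must stay resident, and 'CalculateInputsInMemory' consumes the map built above, keyed by this operator's input index. A sketch of how a single-input cost class might implement the two hooks — the member names follow this patch, but the retention policy is an assumption, not the behavior of any real cost class:

    // Hypothetical single-input cost class implementing the two hooks invoked
    // by ComputeOpAndPrevEdgeParameterInvolved().
    void ExampleUnaryCost::CalculateOutputInMemory() {
      // Assumed policy: keep the output if it lies on a parameter's gradient path.
      is_output_should_in_memory_ = (output_parameter_involve_ == 1);
    }

    void ExampleUnaryCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
      // Assumed policy: keep input 0 only if no upstream operator already
      // holds that tensor in memory.
      auto it = prev_output_in_mem.find(0);
      bool held_upstream = (it != prev_output_in_mem.end()) && it->second;
      is_inputs_should_in_memory_[0] = is_parameter_involve_[0] && !held_upstream;
    }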
| @@ -1220,14 +1234,10 @@ Status OperatorInfo::set_is_parameter(const std::vector<bool> &is_parameter) { | |||
| } | |||
| Status OperatorInfo::CalculateMemoryCost() { | |||
| // First, set the 'is_parameter_involve_' and 'is_output_parameter_involve_' into OperatorCost, which are necessary to | |||
| // calculate memory cost. | |||
| if (is_parameter_involve_.size() != is_parameter_.size()) { | |||
| MS_LOG(ERROR) << "'is_parameter_' does not have the same number of input size of 'is_parameter_involve_'."; | |||
| return FAILED; | |||
| } | |||
| operator_cost()->set_is_parameter_involve(is_parameter_involve_); | |||
| operator_cost()->set_output_parameter_involve(is_output_parameter_involve_); | |||
| // Set the memory cost in the 'strategy_cost_' | |||
| for (auto &swc : strategy_cost_) { | |||
| auto mem_cost = operator_cost()->GetMemoryCost(swc->inputs_ptr, swc->outputs_ptr); | |||
| @@ -33,7 +33,7 @@ class PackInfo : public OperatorInfo { | |||
| public: | |||
| PackInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<PackCost>(false)) {} | |||
| : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<PackCost>()) {} | |||
| ~PackInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| @@ -35,7 +35,7 @@ class PReLUInfo : public OperatorInfo { | |||
| public: | |||
| PReLUInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<PReLUCost>(true)) {} | |||
| : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<PReLUCost>()) {} | |||
| ~PReLUInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| @@ -39,7 +39,7 @@ class RangeInfo : public OperatorInfo { | |||
| public: | |||
| RangeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ActivationCost>(true)) {} | |||
| : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<RangeCost>()) {} | |||
| ~RangeInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| @@ -33,8 +33,8 @@ namespace parallel { | |||
| class ReduceMethod : public OperatorInfo { | |||
| public: | |||
| ReduceMethod(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReduceMethodCost>(true)) {} | |||
| const PrimitiveAttrs &attrs, OperatorCostPtr cost) | |||
| : OperatorInfo(name, inputs_shape, outputs_shape, attrs, cost) {} | |||
| ~ReduceMethod() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| @@ -62,7 +62,7 @@ class ReduceMaxInfo : public ReduceMethod { | |||
| public: | |||
| ReduceMaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ReduceMethod(name, inputs_shape, outputs_shape, attrs) { | |||
| : ReduceMethod(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReduceMaxCost>()) { | |||
| reduce_method_ = REDUCE_OP_MAX; | |||
| } | |||
| @@ -73,7 +73,7 @@ class ArgMaxWithValueInfo : public ReduceMethod { | |||
| public: | |||
| ArgMaxWithValueInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ReduceMethod(name, inputs_shape, outputs_shape, attrs) { | |||
| : ReduceMethod(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArgMaxWithValueCost>()) { | |||
| reduce_method_ = REDUCE_OP_MAX; | |||
| } | |||
| @@ -105,9 +105,7 @@ class ReduceMeanInfo : public ReduceMethod { | |||
| public: | |||
| ReduceMeanInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ReduceMethod(name, inputs_shape, outputs_shape, attrs) { | |||
| set_cost(std::make_shared<ReduceMeanCost>()); | |||
| } | |||
| : ReduceMethod(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReduceMeanCost>()) {} | |||
| ~ReduceMeanInfo() override = default; | |||
| @@ -119,7 +117,7 @@ class ReduceSumInfo : public ReduceMethod { | |||
| public: | |||
| ReduceSumInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ReduceMethod(name, inputs_shape, outputs_shape, attrs) { | |||
| : ReduceMethod(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReduceSumCost>()) { | |||
| reduce_method_ = REDUCE_OP_SUM; | |||
| } | |||
| @@ -130,7 +128,7 @@ class ReduceMinInfo : public ReduceMethod { | |||
| public: | |||
| ReduceMinInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ReduceMethod(name, inputs_shape, outputs_shape, attrs) { | |||
| : ReduceMethod(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReduceMinCost>()) { | |||
| reduce_method_ = REDUCE_OP_MIN; | |||
| } | |||
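With the ReduceMethod base now taking the cost explicitly, adding a further reduce variant is a one-liner of plumbing. A sketch under assumed names — 'ReduceProdInfo', 'ReduceProdCost', and 'REDUCE_OP_PROD' are hypothetical and used only for illustration:

    // Hypothetical new reduce operator under the refactored base-class
    // signature: the cost object is forwarded instead of being hard-coded.
    class ReduceProdInfo : public ReduceMethod {
     public:
      ReduceProdInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                     const PrimitiveAttrs &attrs)
          : ReduceMethod(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReduceProdCost>()) {
        reduce_method_ = REDUCE_OP_PROD;
      }
      ~ReduceProdInfo() override = default;
    };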
| @@ -37,7 +37,7 @@ class ReLUV2Info : public OperatorInfo { | |||
| public: | |||
| ReLUV2Info(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<ActivationCost>(false)) {} | |||
| : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<ReLUV2Cost>()) {} | |||
| ~ReLUV2Info() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| @@ -36,7 +36,7 @@ class ReshapeInfo : public OperatorInfo { | |||
| public: | |||
| ReshapeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReshapeCost>(false)), | |||
| : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReshapeCost>()), | |||
| dev_num_(0), | |||
| pre_operator_index_(0), | |||
| next_operator_index_(0), | |||
| @@ -34,7 +34,7 @@ class SliceInfo : public OperatorInfo { | |||
| public: | |||
| SliceInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<SliceCost>(false)), | |||
| : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<SliceCost>()), | |||
| slice_axis_(-1) {} | |||
| ~SliceInfo() override = default; | |||
| @@ -31,7 +31,7 @@ class SplitInfo : public OperatorInfo { | |||
| public: | |||
| SplitInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<ConcatCost>(false)) {} | |||
| : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<SplitCost>()) {} | |||
| ~SplitInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| @@ -273,6 +273,8 @@ Status StridedSliceInfo::GenerateStrategies(int64_t stage_id) { | |||
| PrintStrategy(sp); | |||
| } | |||
| } | |||
| MS_LOG(INFO) << name() << ", finishing GenerateStrategies()."; | |||
| return SUCCESS; | |||
| } | |||
| @@ -34,7 +34,7 @@ class StridedSliceInfo : public OperatorInfo { | |||
| public: | |||
| StridedSliceInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<StridedSliceCost>(false)) {} | |||
| : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<StridedSliceCost>()) {} | |||
| ~StridedSliceInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| @@ -41,7 +41,7 @@ class TensorDotInfo : public OperatorInfo { | |||
| public: | |||
| TensorDotInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<MatMulCost>(true)) {} | |||
| : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<TensorDotCost>()) {} | |||
| ~TensorDotInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| @@ -34,7 +34,7 @@ class TileInfo : public OperatorInfo { | |||
| public: | |||
| TileInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<TileCost>(false)) {} | |||
| : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<TileCost>()) {} | |||
| ~TileInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| @@ -34,7 +34,7 @@ class TmpIdentityInfo : public OperatorInfo { | |||
| public: | |||
| TmpIdentityInfo(const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs, | |||
| const std::string &name = IDENTITY_INFO) | |||
| : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<TmpIdentityCost>(false)) {} | |||
| : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<TmpIdentityCost>()) {} | |||
| ~TmpIdentityInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| @@ -35,7 +35,7 @@ class TransposeInfo : public OperatorInfo { | |||
| public: | |||
| TransposeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<TransposeCost>(false)) {} | |||
| : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<TransposeCost>()) {} | |||
| ~TransposeInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| @@ -32,7 +32,7 @@ class UniqueInfo : public OperatorInfo { | |||
| public: | |||
| UniqueInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<GetNextCost>(false)) {} | |||
| : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<UniqueCost>()) {} | |||
| ~UniqueInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| @@ -82,7 +82,7 @@ class UnsortedSegmentMaxInfo : public UnsortedSegmentOpInfo { | |||
| public: | |||
| UnsortedSegmentMaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : UnsortedSegmentOpInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<UnsortedSegmentMinCost>()) {} | |||
| : UnsortedSegmentOpInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<UnsortedSegmentMaxCost>()) {} | |||
| ~UnsortedSegmentMaxInfo() override = default; | |||
| ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override; | |||
| @@ -32,7 +32,7 @@ class VirtualDatasetInfo : public OperatorInfo { | |||
| public: | |||
| VirtualDatasetInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<VirtualDatasetCost>(false)) {} | |||
| : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<VirtualDatasetCost>()) {} | |||
| ~VirtualDatasetInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| @@ -85,11 +85,11 @@ class TestActivationCost : public UT::Common { | |||
| TestActivationCost() {} | |||
| void SetUp(); | |||
| void TearDown(); | |||
| ActivationCost ac_cost_; | |||
| ActivationInfoCost ac_cost_; | |||
| }; | |||
| void TestActivationCost::SetUp() { | |||
| ac_cost_ = ActivationCost(); | |||
| ac_cost_ = ActivationInfoCost(); | |||
| RankList dev_list; | |||
| for (int32_t i = 0; i < 1050; i++) { | |||