From: @xiaoda_zh
Reviewed-by: @kisnwang, @stsuteng
Signed-off-by: @stsuteng
tags/v1.2.0-rc1
@@ -27,6 +27,7 @@ void OperatorCost::set_is_parameter(const std::vector<bool> &is_parameter) { is_
 void OperatorCost::set_is_parameter_involve(const std::vector<bool> &is_parameter_inv) {
   is_parameter_involve_ = is_parameter_inv;
+  is_inputs_should_in_memory_ = std::vector<bool>(is_parameter_involve_.size(), false);
 }

 void OperatorCost::set_output_parameter_involve(int64_t output_para) { output_parameter_involve_ = output_para; }
@@ -41,27 +42,28 @@ void OperatorCost::set_output_critical(int64_t critical) { is_outputs_critical_
 double OperatorCost::GetMemoryCost(const std::vector<TensorInfo> &inputs,
                                    const std::vector<TensorInfo> &outputs) const {
+  return GetInputMemoryCost(inputs, outputs) + GetOutputMemoryCost(inputs, outputs);
+}
+
+double OperatorCost::GetInputMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &) const {
+  double result = 0.0;
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    if (is_inputs_should_in_memory_[i]) {
+      result += ListProduct(inputs[i].slice_shape()) * static_cast<double>(inputs_type_lengths_[i]);
+    }
+  }
+  return result;
+}
+
+double OperatorCost::GetOutputMemoryCost(const std::vector<TensorInfo> &inputs,
+                                         const std::vector<TensorInfo> &outputs) const {
   double result = 0.0;
-  if (output_parameter_involve_ == 1) {
+  if (is_output_should_in_memory_) {
     // When this operator has multiple outputs, they all contribute to the memory.
     for (size_t i = 0; i < outputs.size(); ++i) {
       result += ListProduct(outputs[i].slice_shape()) * static_cast<double>(outputs_type_lengths_[i]);
     }
-    bool is_any_para_inv =
-      std::any_of(is_parameter_involve_.begin(), is_parameter_involve_.end(), [](bool value) { return value; });
-    if (is_any_para_inv) {
-      for (size_t i = 0; i < inputs.size(); ++i) {
-        if (is_parameter_[i]) {
-          result += ListProduct(inputs[i].slice_shape()) * static_cast<double>(inputs_type_lengths_[i]);
-        } else if (inputs_related_ && (!is_parameter_involve_[i])) {
-          // When the inputs of this operator are related, and they are not parameter-involved, then they are included
-          // in the memory cost.
-          result += ListProduct(inputs[i].slice_shape()) * static_cast<double>(inputs_type_lengths_[i]);
-        }
-      }
-    }
   }
   return result;
 }
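Reviewer note: a minimal sketch of how the new pieces compose. The driver below is hypothetical (not part of this patch) and assumes this file's existing includes; it shows that GetMemoryCost() is only meaningful after the two Calculate* passes have populated the retention flags.

// Hypothetical driver, for illustration only.
double PeakMemoryOfOperator(const OperatorCostPtr &op, const std::vector<TensorInfo> &inputs,
                            const std::vector<TensorInfo> &outputs,
                            const std::map<size_t, bool> &prev_output_in_mem) {
  op->CalculateOutputInMemory();                    // fills is_output_should_in_memory_
  op->CalculateInputsInMemory(prev_output_in_mem);  // fills is_inputs_should_in_memory_[i]
  return op->GetMemoryCost(inputs, outputs);        // input part + output part
}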
@@ -166,16 +168,43 @@ double MatMulCost::GetBackwardComputationCost(const std::vector<TensorInfo> &inp
   return result;
 }
+// Not taking account of output
+void MatMulCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
+
+// Taking account of input
+void MatMulCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  if (is_parameter_[0]) {
+    is_inputs_should_in_memory_[0] = true;
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      is_inputs_should_in_memory_[1] = true;
+    }
+  } else if (is_parameter_involve_[0]) {
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      is_inputs_should_in_memory_[1] = true;
+    }
+  }
+  if (is_parameter_[1]) {
+    is_inputs_should_in_memory_[1] = true;
+    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
+      is_inputs_should_in_memory_[0] = true;
+    }
+  } else if (is_parameter_involve_[1]) {
+    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
+      is_inputs_should_in_memory_[0] = true;
+    }
+  }
+}
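Reviewer note on the idiom repeated above (and throughout this patch): the two-step lookup on prev_output_in_mem asks whether input i is already resident because its producer keeps it. A hypothetical helper, not in the patch, stating the intent:

static bool InputNotHeldUpstream(const std::map<size_t, bool> &prev_output_in_mem, size_t i) {
  // True when input 'i' is NOT already kept in memory by the operator producing it,
  // so the current operator must retain it for its own backward pass.
  auto it = prev_output_in_mem.find(i);
  return (it == prev_output_in_mem.end()) || (!it->second);
}

The symmetric structure for MatMul is expected: for y = x * w, dx = dy * w^T needs 'w' and dw = x^T * dy needs 'x', so a parameter(-involved) input forces the opposite input into memory unless its producer already holds it.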
 // Return the per device communication cost in the forward phase.
-double ActivationCost::GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
-                                          int64_t) const {
+double CastCost::GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const {
   // ReLU is an element-wise operator, thus it does not need communication in the forward phase
   return 0.0;
 }

 // Return the per device communication cost in the backward phase.
-double ActivationCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
-                                           int64_t stage_id) const {
+double CastCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
+                                     int64_t stage_id) const {
   double result = 0.0;
   if (is_parameter_[0]) {
     TensorInfo input1 = inputs[0];
@@ -196,8 +225,8 @@ double ActivationCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs
 // Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
 // this operator uses
-double ActivationCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
-                                                 int64_t) const {
+double CastCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
+                                           int64_t) const {
   TensorInfo input0 = inputs[0];
   Shape input0_slice_shape = input0.slice_shape();
   return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
@@ -205,11 +234,33 @@ double ActivationCost::GetForwardComputationCost(const std::vector<TensorInfo> &
 // Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
 // this operator uses
-double ActivationCost::GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
-                                                  int64_t) const {
+double CastCost::GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
+                                            int64_t) const {
   return 0.0;
 }
+
+// Not taking account of output
+void CastCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
+
+// Not taking account of input
+void CastCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  is_inputs_should_in_memory_[0] = is_parameter_[0];
+}
+// Taking account of output: Sqrt's gradient reuses the forward result, since d(sqrt(x))/dx = 1 / (2 * sqrt(x)).
+void SqrtCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; }
+
+// Taking account of input
+void GeLUCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  if (is_parameter_[0]) {
+    is_inputs_should_in_memory_[0] = true;
+  } else if (is_parameter_involve_[0]) {
+    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
+      is_inputs_should_in_memory_[0] = true;
+    }
+  }
+}
 // Return the per device communication cost in the forward phase.
 double SoftmaxCost::GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
                                        int64_t) const {
@@ -259,6 +310,81 @@ double SoftmaxCost::GetBackwardComputationCost(const std::vector<mindspore::para
   return 0.0;
 }
+// Taking account of output
+void SoftmaxCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; }
+
+// Not taking account of input
+void SoftmaxCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  is_inputs_should_in_memory_[0] = is_parameter_[0];
+}
+
+// Not taking account of output
+void PackCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
+
+// Not taking account of input
+void PackCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  is_inputs_should_in_memory_[0] = is_parameter_[0];
+}
+
+// Not taking account of output
+void TileCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
+
+// Taking account of input
+void TileCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  // When calculating 'dx', taking account of 'y'
+  if (is_parameter_[0]) {
+    is_inputs_should_in_memory_[0] = true;
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      is_inputs_should_in_memory_[1] = true;
+    }
+  } else if (is_parameter_involve_[0]) {
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      is_inputs_should_in_memory_[1] = true;
+    }
+  }
+  if (!is_inputs_should_in_memory_[1]) {
+    is_inputs_should_in_memory_[1] = is_parameter_[1];
+  }
+}
+
+// Not taking account of output
+void BroadcastToCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
+
+// Not taking account of input
+void BroadcastToCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  is_inputs_should_in_memory_[0] = is_parameter_[0];
+}
+
+// Taking account of input
+void ReLU6Cost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  if (is_parameter_[0]) {
+    is_inputs_should_in_memory_[0] = true;
+  } else if (is_parameter_involve_[0]) {
+    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
+      is_inputs_should_in_memory_[0] = true;
+    }
+  }
+}
+
+// Taking account of input
+void TransposeCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  // When calculating 'dx', taking account of 'y'
+  if (is_parameter_[0]) {
+    is_inputs_should_in_memory_[0] = true;
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      is_inputs_should_in_memory_[1] = true;
+    }
+  } else if (is_parameter_involve_[0]) {
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      is_inputs_should_in_memory_[1] = true;
+    }
+  }
+  if (!is_inputs_should_in_memory_[1]) {
+    is_inputs_should_in_memory_[1] = is_parameter_[1];
+  }
+}
 // return the per device communication cost in the forward phase.
 double TmpIdentityCost::GetForwardCommCost(const std::vector<mindspore::parallel::TensorInfo> &,
                                            const std::vector<mindspore::parallel::TensorInfo> &, int64_t) const {
@@ -288,9 +414,12 @@ double TmpIdentityCost::GetBackwardComputationCost(const std::vector<mindspore::
   return 0.0;
 }
-// Return the per device PEAK memory cost contributed by this operator in a training iteration.
-double TmpIdentityCost::GetMemoryCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &) const {
-  return 0.0;
+// Not taking account of output
+void TmpIdentityCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
+
+// Not taking account of input
+void TmpIdentityCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  is_inputs_should_in_memory_[0] = is_parameter_[0];
 }
 double BatchParallelCost::GetForwardComputationCost(const std::vector<mindspore::parallel::TensorInfo> &inputs,
@@ -334,6 +463,42 @@ double BatchParallelCost::GetBackwardCommCost(const std::vector<TensorInfo> &inp
   return result;
 }
+// Not taking account of output
+void BatchParallelCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
+
+// Taking account of input
+void BatchParallelCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  if (is_parameter_[0]) {
+    is_inputs_should_in_memory_[0] = true;
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      is_inputs_should_in_memory_[1] = true;
+    }
+  } else if (is_parameter_involve_[0]) {
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      is_inputs_should_in_memory_[1] = true;
+    }
+  }
+  if (is_parameter_[1]) {
+    is_inputs_should_in_memory_[1] = true;
+    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
+      is_inputs_should_in_memory_[0] = true;
+    }
+  } else if (is_parameter_involve_[1]) {
+    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
+      is_inputs_should_in_memory_[0] = true;
+    }
+  }
+}
+
+// Taking account of output
+void SparseSoftmaxCrossEntropyWithLogitsCost::CalculateOutputInMemory() {
+  is_output_should_in_memory_ = is_parameter_involve_[0];
+}
+
+// Not taking account of input
+void SparseSoftmaxCrossEntropyWithLogitsCost::CalculateInputsInMemory(
+    const std::map<size_t, bool> &prev_output_in_mem) {
+  is_inputs_should_in_memory_[0] = is_parameter_[0];
+  is_inputs_should_in_memory_[1] = is_parameter_[1];
+}
 // return the per device communication cost in the forward phase.
 double PReLUCost::GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const {
   // prelu does not need communication in the forward phase
@@ -401,6 +566,21 @@ double PReLUCost::GetBackwardComputationCost(const std::vector<mindspore::parall
   return result;
 }
+// Not taking account of output
+void PReLUCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
+
+// Taking account of input
+void PReLUCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  // When calculating 'dx', taking account of both 'x' and 'y';
+  // when calculating 'dy', taking account of both 'x' and 'y'
+  if (is_parameter_involve_[0] || is_parameter_involve_[1]) {
+    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
+      is_inputs_should_in_memory_[0] = true;
+    }
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      is_inputs_should_in_memory_[1] = true;
+    }
+  }
+}
 // return the per device communication cost in the forward phase.
 double OneHotCost::GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int64_t) const {
   // onehot does not need communication in the forward phase
@@ -430,6 +610,17 @@ double OneHotCost::GetBackwardComputationCost(const std::vector<TensorInfo> &, c
   return 0.0;
 }
+// Not taking account of output
+void OneHotCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
+
+// Not taking account of input
+void OneHotCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  is_inputs_should_in_memory_[0] = is_parameter_[0];
+  is_inputs_should_in_memory_[1] = is_parameter_[1];
+  is_inputs_should_in_memory_[2] = is_parameter_[2];
+  is_inputs_should_in_memory_[3] = is_parameter_[3];
+}
 // return the per device communication cost in the forward phase.
 double SoftmaxCrossEntropyWithLogitsCost::GetForwardCommCost(const std::vector<TensorInfo> &,
                                                              const std::vector<TensorInfo> &, int64_t) const {
@@ -463,6 +654,16 @@ double SoftmaxCrossEntropyWithLogitsCost::GetBackwardComputationCost(const std::
   return 0.0;
 }
+// Taking account of output
+void SoftmaxCrossEntropyWithLogitsCost::CalculateOutputInMemory() {
+  is_output_should_in_memory_ = is_parameter_involve_[0];
+}
+
+// Not taking account of input
+void SoftmaxCrossEntropyWithLogitsCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  is_inputs_should_in_memory_[0] = is_parameter_[0];
+  is_inputs_should_in_memory_[1] = is_parameter_[1];
+}
 // return the per device communication cost in the forward phase.
 double ReshapeCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                        int64_t stage_id) const {
@@ -524,16 +725,23 @@ double ReshapeCost::GetBackwardComputationCost(const std::vector<mindspore::para
   return 0.0;
 }
-double ArithmeticCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
-                                                 int64_t) const {
+// Not taking account of output
+void ReshapeCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
+
+// Not taking account of input
+void ReshapeCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  is_inputs_should_in_memory_[0] = is_parameter_[0];
+  is_inputs_should_in_memory_[1] = is_parameter_[1];
+}
+
+double SubCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
+                                          int64_t) const {
   double result;
   result = ListProduct(inputs[0].slice_shape()) * static_cast<double>(inputs_type_lengths_[0]) +
            ListProduct(inputs[1].slice_shape()) * static_cast<double>(inputs_type_lengths_[1]);
   return result;
 }
-double ArithmeticCost::GetBackwardComputationCost(const std::vector<TensorInfo> &inputs,
-                                                  const std::vector<TensorInfo> &, int64_t stage_id) const {
+double SubCost::GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
+                                           int64_t stage_id) const {
   double result = 0.0;
   CheckGlobalDeviceManager();
   MS_EXCEPTION_IF_NULL(g_device_manager);
@@ -567,8 +775,8 @@ double ArithmeticCost::GetBackwardComputationCost(const std::vector<TensorInfo> 
   return result;
 }
-double ArithmeticCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
-                                           int64_t stage_id) const {
+double SubCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
+                                    int64_t stage_id) const {
   double result = 0.0;
   CheckGlobalDeviceManager();
   MS_EXCEPTION_IF_NULL(g_device_manager);
@@ -603,6 +811,273 @@ double ArithmeticCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs
   return result;
 }
+// Not taking account of output
+void SubCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
+
+// Not taking account of input
+void SubCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  is_inputs_should_in_memory_[0] = is_parameter_[0];
+  is_inputs_should_in_memory_[1] = is_parameter_[1];
+}
+
+// Taking account of input
+void MulCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  if (is_parameter_[0]) {
+    // 'x' is a parameter, so it should be in memory.
+    is_inputs_should_in_memory_[0] = true;
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      // If 'y' is not already kept in memory by the previous operator, it must be kept here.
+      is_inputs_should_in_memory_[1] = true;
+    }
+  } else if (is_parameter_involve_[0]) {
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      is_inputs_should_in_memory_[1] = true;
+    }
+  }
+  if (is_parameter_[1]) {
+    is_inputs_should_in_memory_[1] = true;
+    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
+      is_inputs_should_in_memory_[0] = true;
+    }
+  } else if (is_parameter_involve_[1]) {
+    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
+      is_inputs_should_in_memory_[0] = true;
+    }
+  }
+}
+// Taking account of output
+void DivCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[1]; }
+
+// Taking account of input
+void DivCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  // When calculating 'dx', taking account of 'y'
+  if (is_parameter_[0]) {
+    // 'x' is a parameter, so it should be in memory.
+    is_inputs_should_in_memory_[0] = true;
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      // If 'y' is not already kept in memory by the previous operator, it must be kept here.
+      is_inputs_should_in_memory_[1] = true;
+    }
+  } else if (is_parameter_involve_[0]) {
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      is_inputs_should_in_memory_[1] = true;
+    }
+  }
+  // When calculating 'dy', taking account of 'y'
+  if (is_parameter_involve_[1]) {
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      is_inputs_should_in_memory_[1] = true;
+    }
+  }
+}
+// Taking account of input
+void ModCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  // When calculating 'dx', not taking account of 'x' and 'y'
+  is_inputs_should_in_memory_[0] = is_parameter_[0];
+  // When calculating 'dy', taking account of 'x' and 'y'
+  if (is_parameter_involve_[1]) {
+    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
+      is_inputs_should_in_memory_[0] = true;
+    }
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      is_inputs_should_in_memory_[1] = true;
+    }
+  }
+}
+// Taking account of output
+void PowCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[1]; }
+
+// Taking account of input
+void PowCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  // When calculating 'dx', taking account of both 'x' and 'power'
+  if (is_parameter_involve_[0]) {
+    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
+      is_inputs_should_in_memory_[0] = true;
+    }
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      is_inputs_should_in_memory_[1] = true;
+    }
+  }
+  // When calculating 'dpower', taking account of 'x'
+  if (is_parameter_[1]) {
+    is_inputs_should_in_memory_[1] = true;
+    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
+      is_inputs_should_in_memory_[0] = true;
+    }
+  } else if (is_parameter_involve_[1]) {
+    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
+      is_inputs_should_in_memory_[0] = true;
+    }
+  }
+}
+// Taking account of input
+void AssignCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  // When calculating 'dx', taking account of 'x'
+  if (is_parameter_involve_[0]) {
+    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
+      is_inputs_should_in_memory_[0] = true;
+    }
+  }
+  // When calculating 'dy', not taking account of 'x' and 'y'
+  is_inputs_should_in_memory_[1] = is_parameter_[1];
+}
+
+// Taking account of input
+void SigmoidCrossEntropyWithLogitsCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  // When calculating 'dx', taking account of both 'x' and 'y'
+  if (is_parameter_involve_[0]) {
+    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
+      is_inputs_should_in_memory_[0] = true;
+    }
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      is_inputs_should_in_memory_[1] = true;
+    }
+  }
+  // When calculating 'dy', not taking account of 'x' and 'y'
+  if (!is_inputs_should_in_memory_[1]) {
+    is_inputs_should_in_memory_[1] = is_parameter_[1];
+  }
+}
+
+// Taking account of input
+void Atan2Cost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  // When calculating 'dx', taking account of both 'x' and 'y';
+  // when calculating 'dy', taking account of both 'x' and 'y'
+  if (is_parameter_involve_[0] || is_parameter_involve_[1]) {
+    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
+      is_inputs_should_in_memory_[0] = true;
+    }
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      is_inputs_should_in_memory_[1] = true;
+    }
+  }
+}
+// Taking account of output
+void DivNoNanCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[1]; }
+
+// Taking account of input
+void DivNoNanCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  // When calculating 'dx', taking account of 'y'
+  if (is_parameter_[0]) {
+    // 'x' is a parameter, so it should be in memory.
+    is_inputs_should_in_memory_[0] = true;
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      // If 'y' is not already kept in memory by the previous operator, it must be kept here.
+      is_inputs_should_in_memory_[1] = true;
+    }
+  } else if (is_parameter_involve_[0]) {
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      is_inputs_should_in_memory_[1] = true;
+    }
+  }
+  // When calculating 'dy', taking account of 'y'
+  if (is_parameter_involve_[1]) {
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      is_inputs_should_in_memory_[1] = true;
+    }
+  }
+}
+
+// Taking account of input
+void MaximumCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  // When calculating 'dx', taking account of both 'x' and 'y';
+  // when calculating 'dy', taking account of both 'x' and 'y'
+  if (is_parameter_involve_[0] || is_parameter_involve_[1]) {
+    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
+      is_inputs_should_in_memory_[0] = true;
+    }
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      is_inputs_should_in_memory_[1] = true;
+    }
+  }
+}
+// Taking account of input
+void SliceCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  // When calculating 'dx', taking account of 'y' and 'z'
+  if (is_parameter_[0]) {
+    is_inputs_should_in_memory_[0] = true;
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      is_inputs_should_in_memory_[1] = true;
+    }
+    if ((prev_output_in_mem.find(2) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(2))) {
+      is_inputs_should_in_memory_[2] = true;
+    }
+  } else if (is_parameter_involve_[0]) {
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      is_inputs_should_in_memory_[1] = true;
+    }
+    if ((prev_output_in_mem.find(2) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(2))) {
+      is_inputs_should_in_memory_[2] = true;
+    }
+  }
+  if (!is_inputs_should_in_memory_[1]) {
+    is_inputs_should_in_memory_[1] = is_parameter_[1];
+  }
+  if (!is_inputs_should_in_memory_[2]) {
+    is_inputs_should_in_memory_[2] = is_parameter_[2];
+  }
+}
+
+// Taking account of input
+void StridedSliceCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  // When calculating 'dx', taking account of 'y', 'z' and 'w'
+  if (is_parameter_[0]) {
+    is_inputs_should_in_memory_[0] = true;
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      is_inputs_should_in_memory_[1] = true;
+    }
+    if ((prev_output_in_mem.find(2) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(2))) {
+      is_inputs_should_in_memory_[2] = true;
+    }
+    if ((prev_output_in_mem.find(3) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(3))) {
+      is_inputs_should_in_memory_[3] = true;
+    }
+  } else if (is_parameter_involve_[0]) {
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      is_inputs_should_in_memory_[1] = true;
+    }
+    if ((prev_output_in_mem.find(2) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(2))) {
+      is_inputs_should_in_memory_[2] = true;
+    }
+    if ((prev_output_in_mem.find(3) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(3))) {
+      is_inputs_should_in_memory_[3] = true;
+    }
+  }
+  if (!is_inputs_should_in_memory_[1]) {
+    is_inputs_should_in_memory_[1] = is_parameter_[1];
+  }
+  if (!is_inputs_should_in_memory_[2]) {
+    is_inputs_should_in_memory_[2] = is_parameter_[2];
+  }
+  if (!is_inputs_should_in_memory_[3]) {
+    is_inputs_should_in_memory_[3] = is_parameter_[3];
+  }
+}
+// Not taking account of output
+void DropOutDoMaskCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
+
+// Taking account of input
+void DropOutDoMaskCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  // When calculating 'dx', taking account of 'y'
+  if (is_parameter_[0]) {
+    // 'x' is a parameter, so it should be in memory.
+    is_inputs_should_in_memory_[0] = true;
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      // If 'y' is not already kept in memory by the previous operator, it must be kept here.
+      is_inputs_should_in_memory_[1] = true;
+    }
+  } else if (is_parameter_involve_[0]) {
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      is_inputs_should_in_memory_[1] = true;
+    }
+  }
+  if (!is_inputs_should_in_memory_[1]) {
+    is_inputs_should_in_memory_[1] = is_parameter_[1];
+  }
+  is_inputs_should_in_memory_[2] = is_parameter_[2];
+}
 bool IsDataParallel(const Shape &shape, const Shape &slice_shape, int64_t stage_id) {
   CheckGlobalDeviceManager();
   MS_EXCEPTION_IF_NULL(g_device_manager);
@@ -612,8 +1087,8 @@ bool IsDataParallel(const Shape &shape, const Shape &slice_shape, int64_t stage_
   return (total_device_num == LongToSize(strategy0));
 }
-double ReduceMethodCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs,
-                                            const std::vector<TensorInfo> &outputs, int64_t stage_id) const {
+double ReduceSumCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
+                                         int64_t stage_id) const {
   double result = 0.0;
   TensorInfo input0 = inputs[0];
   TensorInfo output0 = outputs[0];
@@ -634,8 +1109,8 @@ double ReduceMethodCost::GetForwardCommCost(const std::vector<TensorInfo> &input
   return result;
 }

-double ReduceMethodCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
-                                             int64_t stage_id) const {
+double ReduceSumCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
+                                          int64_t stage_id) const {
   double result = 0.0;
   if (is_parameter_[0]) {
     TensorInfo input_tensor_info = inputs[0];
@@ -657,8 +1132,8 @@ double ReduceMethodCost::GetBackwardCommCost(const std::vector<TensorInfo> &inpu
   return result;
 }

-double ReduceMethodCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
-                                                   const std::vector<TensorInfo> &outputs, int64_t stage_id) const {
+double ReduceSumCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
+                                                const std::vector<TensorInfo> &outputs, int64_t stage_id) const {
   double result = 0.0;
   TensorInfo input0 = inputs[0];
   TensorInfo output0 = outputs[0];
@@ -679,6 +1154,30 @@ double ReduceMethodCost::GetForwardComputationCost(const std::vector<TensorInfo> 
   return result;
 }
+// Not taking account of output
+void ReduceSumCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
+
+// Taking account of input
+void ReduceSumCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  // When calculating 'dx', taking account of 'y'
+  if (is_parameter_[0]) {
+    // 'x' is a parameter, so it should be in memory.
+    is_inputs_should_in_memory_[0] = true;
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      // If 'y' is not already kept in memory by the previous operator, it must be kept here.
+      is_inputs_should_in_memory_[1] = true;
+    }
+  } else if (is_parameter_involve_[0]) {
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      is_inputs_should_in_memory_[1] = true;
+    }
+  }
+  // Otherwise, 'y' is kept in memory only if it is itself a parameter.
+  if (!is_inputs_should_in_memory_[1]) {
+    is_inputs_should_in_memory_[1] = is_parameter_[1];
+  }
+}
 double ReduceMeanCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
                                                  const std::vector<TensorInfo> &outputs, int64_t stage_id) const {
   double result = 0.0;
@@ -701,6 +1200,42 @@ double ReduceMeanCost::GetForwardComputationCost(const std::vector<TensorInfo> &
   return result;
 }
+// Taking account of output: the gradient of a min reduction compares the forward output with the input
+// to locate the selected elements.
+void ReduceMinCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; }
+
+// Taking account of input
+void ReduceMinCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  // When calculating 'dx', taking account of 'y'
+  if (is_parameter_[0]) {
+    // 'x' is a parameter, so it should be in memory.
+    is_inputs_should_in_memory_[0] = true;
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      // If 'y' is not already kept in memory by the previous operator, it must be kept here.
+      is_inputs_should_in_memory_[1] = true;
+    }
+  } else if (is_parameter_involve_[0]) {
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      is_inputs_should_in_memory_[1] = true;
+    }
+  }
+  // Otherwise, 'y' is kept in memory only if it is itself a parameter.
+  if (!is_inputs_should_in_memory_[1]) {
+    is_inputs_should_in_memory_[1] = is_parameter_[1];
+  }
+}
+
+// Taking account of output
+void ArgMaxWithValueCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; }
+
+// Taking account of input
+void ArgMaxWithValueCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  // When calculating 'dx', taking account of 'x'
+  if (is_parameter_[0]) {
+    is_inputs_should_in_memory_[0] = true;
+  } else if (is_parameter_involve_[0]) {
+    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
+      is_inputs_should_in_memory_[0] = true;
+    }
+  }
+}
 double DropOutCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                               int64_t) const {
   if (inputs.empty()) {
@@ -760,6 +1295,52 @@ double GatherV2Cost::GetBackwardComputationCost(const std::vector<TensorInfo> &,
   return 0.0;
 }
+// Not taking account of output
+void GatherV2Cost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
+
+// Taking account of input
+void GatherV2Cost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  // When calculating 'dx', taking account of 'y' and 'z'
+  if (is_parameter_[0]) {
+    // 'x' is a parameter, so it should be in memory.
+    is_inputs_should_in_memory_[0] = true;
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      is_inputs_should_in_memory_[1] = true;
+    }
+    if ((prev_output_in_mem.find(2) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(2))) {
+      is_inputs_should_in_memory_[2] = true;
+    }
+  } else if (is_parameter_involve_[0]) {
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      is_inputs_should_in_memory_[1] = true;
+    }
+    if ((prev_output_in_mem.find(2) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(2))) {
+      is_inputs_should_in_memory_[2] = true;
+    }
+  }
+  if (!is_inputs_should_in_memory_[1]) {
+    is_inputs_should_in_memory_[1] = is_parameter_[1];
+  }
+  if (!is_inputs_should_in_memory_[2]) {
+    is_inputs_should_in_memory_[2] = is_parameter_[2];
+  }
+}
+
+// Not taking account of output
+void GetNextCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
+
+// Not taking account of input
+void GetNextCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  // GetNext may have no tensor inputs; guard the empty case.
+  if (is_inputs_should_in_memory_.empty()) {
+    return;
+  }
+  is_inputs_should_in_memory_[0] = is_parameter_[0];
+}
+
+// Taking account of output
+void UniqueCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; }
+
+// Not taking account of input
+void UniqueCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  is_inputs_should_in_memory_[0] = is_parameter_[0];
+}
 double LayerNormCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
                                           int64_t stage_id) const {
   double result = 0.0;
@@ -808,6 +1389,24 @@ double LayerNormCost::GetForwardComputationCost(const std::vector<TensorInfo> &i
   return result;
 }
+// Taking account of output
+void LayerNormCost::CalculateOutputInMemory() {
+  is_output_should_in_memory_ = is_parameter_involve_[0] || is_parameter_involve_[1] || is_parameter_involve_[2];
+}
+
+// Taking account of input
+void LayerNormCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  // When calculating 'dx', taking account of both 'x' and 'y';
+  // when calculating 'dy', taking account of both 'x' and 'y'
+  if (is_parameter_involve_[0] || is_parameter_involve_[1]) {
+    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
+      is_inputs_should_in_memory_[0] = true;
+    }
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      is_inputs_should_in_memory_[1] = true;
+    }
+  }
+  is_inputs_should_in_memory_[2] = is_parameter_[2];
+}
 double UniqueCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
                                       int64_t stage_id) const {
   return 0.0;
@@ -924,6 +1523,12 @@ double UniformCandidateSamplerCost::GetForwardComputationCost(const std::vector<
   return result;
 }
+// Not taking account of output
+void UniformCandidateSamplerCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
+
+// Not taking account of input
+void UniformCandidateSamplerCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  is_inputs_should_in_memory_[0] = is_parameter_[0];
+}
 double GatherV2PCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
                                                 const std::vector<TensorInfo> &outputs, int64_t stage_id) const {
   double result = 0.0;
@@ -1019,6 +1624,29 @@ double UnsortedSegmentSumCost::GetForwardComputationCost(const std::vector<Tenso
   return result;
 }
+// Not taking account of output
+void UnsortedSegmentSumCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
+
+// Taking account of input
+void UnsortedSegmentSumCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  // When calculating 'dx', taking account of 'y'
+  if (is_parameter_[0]) {
+    is_inputs_should_in_memory_[0] = true;
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      is_inputs_should_in_memory_[1] = true;
+    }
+  } else if (is_parameter_involve_[0]) {
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      is_inputs_should_in_memory_[1] = true;
+    }
+  }
+  if (!is_inputs_should_in_memory_[1]) {
+    is_inputs_should_in_memory_[1] = is_parameter_[1];
+  }
+  is_inputs_should_in_memory_[2] = is_parameter_[2];
+}
 double UnsortedSegmentMinCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs,
                                                   const std::vector<TensorInfo> &outputs, int64_t stage_id) const {
   TensorInfo input0 = inputs[0];
@@ -1078,5 +1706,40 @@ double UnsortedSegmentMinCost::GetForwardComputationCost(const std::vector<Tenso
                 ListProduct(output_slice_shape) * static_cast<double>(outputs_type_lengths_[0]);  // ReduceMin
   return result;
 }
+// Taking account of output
+void UnsortedSegmentMinCost::CalculateOutputInMemory() { is_output_should_in_memory_ = is_parameter_involve_[0]; }
+
+// Taking account of input
+void UnsortedSegmentMinCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  // When calculating 'dx', taking account of 'x', 'y' and 'z'
+  if (is_parameter_involve_[0]) {
+    if ((prev_output_in_mem.find(0) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(0))) {
+      is_inputs_should_in_memory_[0] = true;
+    }
+    if ((prev_output_in_mem.find(1) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(1))) {
+      is_inputs_should_in_memory_[1] = true;
+    }
+    if ((prev_output_in_mem.find(2) == prev_output_in_mem.end()) || (!prev_output_in_mem.at(2))) {
+      is_inputs_should_in_memory_[2] = true;
+    }
+  }
+  if (!is_inputs_should_in_memory_[1]) {
+    is_inputs_should_in_memory_[1] = is_parameter_[1];
+  }
+  if (!is_inputs_should_in_memory_[2]) {
+    is_inputs_should_in_memory_[2] = is_parameter_[2];
+  }
+}
+
+// Not taking account of output
+void VirtualDatasetCost::CalculateOutputInMemory() { is_output_should_in_memory_ = false; }
+
+// Not taking account of input
+void VirtualDatasetCost::CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) {
+  for (size_t i = 0; i < is_inputs_should_in_memory_.size(); ++i) {
+    is_inputs_should_in_memory_[i] = is_parameter_[i];
+  }
+}
 }  // namespace parallel
 }  // namespace mindspore
@@ -19,6 +19,7 @@
 #include <memory>
 #include <vector>
+#include <map>
 #include "frontend/parallel/device_manager.h"
 #include "frontend/parallel/tensor_layout/tensor_info.h"
@@ -47,16 +48,7 @@ double ListProduct(std::vector<T> vec) {
 // entries times the length of each entry's data type
 class OperatorCost {
  public:
-  explicit OperatorCost(bool is_inputs_related) : inputs_related_(is_inputs_related) {
-    // this is only for the case when set_is_parameter() and SetInputAndOutputTypeLength() are not invoked
-    for (size_t i = 0; i < MAXIMUM_INPUT_NUMBER; ++i) {
-      is_parameter_.push_back(false);
-      is_parameter_involve_.push_back(false);
-      inputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH);
-      outputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH);
-    }
-  }
-  OperatorCost() : inputs_related_(false) {
+  OperatorCost() {
     // this is only for the case when set_is_parameter() and SetInputAndOutputTypeLength() are not invoked
     for (size_t i = 0; i < MAXIMUM_INPUT_NUMBER; ++i) {
       is_parameter_.push_back(false);
@@ -89,10 +81,17 @@ class OperatorCost {
                     const std::vector<TensorInfo> &outputs, int64_t stage_id) const = 0;
   virtual double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs,
                                             const std::vector<TensorInfo> &outputs, int64_t stage_id) const = 0;
+  virtual void CalculateOutputInMemory() = 0;
+  virtual void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) = 0;
+  bool is_output_in_memory() const { return is_output_should_in_memory_; }
   // per device PEAK memory cost in a training iteration
-  // Typically, the PEAK memory cost contributed by an operator is its output (if the output is parameter-invovled),
+  // Typically, the PEAK memory cost contributed by an operator is its output (if the output is parameter-involved),
   // plus necessary inputs.
   virtual double GetMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs) const;
+  // The input part of 'GetMemoryCost'
+  double GetInputMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs) const;
+  // The output part of 'GetMemoryCost'
+  double GetOutputMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs) const;
   // per device memory cost in an inference phase
   double GetMemoryCostForInference(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &) const;
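Reviewer note, with hypothetical shapes: every tensor one of these methods counts contributes ListProduct(slice_shape) * type_length bytes. For example (assuming this header's Shape and ListProduct):

// Hypothetical example, not part of the patch.
Shape slice_shape = {512, 1024};  // per-device slice of a tensor
size_t fp16_bytes = 2;            // bytes per element for float16
double bytes = ListProduct(slice_shape) * static_cast<double>(fp16_bytes);
// bytes == 512 * 1024 * 2 == 1048576 (1 MiB)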
@@ -101,25 +100,25 @@ class OperatorCost {
   // pre-operator that has parameters as input.
   std::vector<bool> is_parameter_involve_;
   int64_t output_parameter_involve_ = -1;  // -1: unset; 0: not parameter_involved; 1: parameter_involved
-  // Whether the inputs are related or not? For example, TensorAdd's two inputs are independent (not related), while
-  // Mul's two inputs are dependent (related).
-  bool inputs_related_;
-  // for each input in 'inputs_', there is a bool variable indicating whether that the corresponding input is parameter
+  // For each input in 'inputs_', there is a bool variable indicating whether the corresponding input is a parameter
   std::vector<bool> is_parameter_;
-  // for each input and output, the followings record the number of bytes of each element
+  // Whether each input should be kept in memory during the training phase. This depends on the operator and on the
+  // operator's previous operators.
+  std::vector<bool> is_inputs_should_in_memory_;
+  // Whether the output should be kept in memory during the training phase. This depends on 'is_parameter_involve_'
+  // and the operator.
+  bool is_output_should_in_memory_ = false;
+  // For each input and output, the following record the number of bytes of each element
   std::vector<size_t> inputs_type_lengths_;
   std::vector<size_t> outputs_type_lengths_;
   // Whether the output is critical, which means that this output is included in calculating peak memory cost
   // in the inference phase.
   int64_t is_outputs_critical_ = -1;
 };
 using OperatorCostPtr = std::shared_ptr<OperatorCost>;
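Reviewer sketch of how a graph pass might feed producers' retention flags into a consumer; 'producers' and 'consumer' are hypothetical stand-ins, not names from this patch:

// Each producers[i] is assumed to be the OperatorCostPtr of the operator feeding
// input 'i' of 'consumer', with CalculateOutputInMemory() already run on it.
std::map<size_t, bool> prev_output_in_mem;
for (size_t i = 0; i < producers.size(); ++i) {
  prev_output_in_mem[i] = producers[i]->is_output_in_memory();
}
consumer->CalculateInputsInMemory(prev_output_in_mem);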
| class MatMulCost : public OperatorCost { | class MatMulCost : public OperatorCost { | ||||
| public: | public: | ||||
| explicit MatMulCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||||
| MatMulCost() : OperatorCost(true) {} | |||||
| MatMulCost() : OperatorCost() {} | |||||
| ~MatMulCost() override = default; | ~MatMulCost() override = default; | ||||
| // per device communication cost | // per device communication cost | ||||
| @@ -141,14 +140,15 @@ class MatMulCost : public OperatorCost { | |||||
| int64_t stage_id) const override; | int64_t stage_id) const override; | ||||
| double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | ||||
| int64_t stage_id) const override; | int64_t stage_id) const override; | ||||
| void CalculateOutputInMemory() override; | |||||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||||
| }; | }; | ||||
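MatMulCost adds both hooks. For C = A*B the backward pass re-reads both operands (dA = dC*B^T, dB = A^T*dC) but never C itself, so a natural pair of bodies flags the inputs and leaves the output unflagged. A sketch under that assumption (the real bodies live in the .cc side of this change, which is not shown here):

    void MatMulCost::CalculateOutputInMemory() {
      // dC alone suffices on the output side; C is not re-read in backward.
      is_output_should_in_memory_ = false;
    }

    void MatMulCost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
      // dA = dC * B^T and dB = A^T * dC each re-read a forward operand.
      is_inputs_should_in_memory_.assign(is_inputs_should_in_memory_.size(), true);
    }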
| using MatMulCostPtr = std::shared_ptr<MatMulCost>; | |||||
| using TensorDotCost = MatMulCost; | |||||
| class ActivationCost : public OperatorCost { | |||||
| class CastCost : public OperatorCost { | |||||
| public: | public: | ||||
| explicit ActivationCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||||
| ActivationCost() : OperatorCost(false) {} | |||||
| ~ActivationCost() override = default; | |||||
| CastCost() : OperatorCost() {} | |||||
| ~CastCost() override = default; | |||||
| double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | ||||
| int64_t stage_id) const override { | int64_t stage_id) const override { | ||||
| @@ -166,21 +166,95 @@ class ActivationCost : public OperatorCost { | |||||
| int64_t stage_id) const override; | int64_t stage_id) const override; | ||||
| double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | ||||
| int64_t stage_id) const override; | int64_t stage_id) const override; | ||||
| // Not taking account of output | |||||
| void CalculateOutputInMemory() override; | |||||
| // Not taking account of input | |||||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||||
| }; | |||||
| using RepeatElementsCost = CastCost; | |||||
| using NegCost = CastCost; | |||||
| using ExpandDimsCost = CastCost; | |||||
| using SqueezeCost = CastCost; | |||||
| using ConcatCost = CastCost; | |||||
| using LogicalNotCost = CastCost; | |||||
| using SignCost = CastCost; | |||||
| using FloorCost = CastCost; | |||||
| using RoundCost = CastCost; | |||||
| using CeilCost = CastCost; | |||||
| using ZerosLikeCost = CastCost; | |||||
| using OnesLikeCost = CastCost; | |||||
| using RangeCost = CastCost; | |||||
| using SplitCost = CastCost; | |||||
| class SqrtCost : public CastCost { | |||||
| public: | |||||
| SqrtCost() : CastCost() {} | |||||
| ~SqrtCost() override = default; | |||||
| // Taking account of output, not taking account of input | |||||
| void CalculateOutputInMemory() override; | |||||
| }; | }; | ||||
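SqrtCost is the keep-the-output flavor: the gradient of y = sqrt(x) is dy / (2 * y), computed from the operator's own output, so the forward result must survive to the backward pass while the input may be dropped. A plausible body, assuming the flag convention sketched above:

    void SqrtCost::CalculateOutputInMemory() {
      // Backward of y = sqrt(x) is dy / (2 * y): it re-reads y, not x.
      is_output_should_in_memory_ = true;
    }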
| using ActivationCostPtr = std::shared_ptr<ActivationCost>; | |||||
| using TransposeCost = ActivationCost; | |||||
| using TransposeCostPtr = std::shared_ptr<TransposeCost>; | |||||
| using StridedSliceCost = ActivationCost; | |||||
| using StridedSliceCostPtr = std::shared_ptr<StridedSliceCost>; | |||||
| using SliceCost = ActivationCost; | |||||
| using SliceCostPtr = std::shared_ptr<SliceCost>; | |||||
| using SplitCost = ActivationCost; | |||||
| using SplitCostPtr = std::shared_ptr<SplitCost>; | |||||
| using TanhCost = SqrtCost; | |||||
| using EluCost = SqrtCost; | |||||
| using ReLUCost = SqrtCost; | |||||
| using SigmoidCost = SqrtCost; | |||||
| using ReciprocalCost = SqrtCost;  // The derivative of 'Reciprocal' differs between 'Ascend' and 'GPU'; 'Ascend' is chosen here (see the note after these aliases) | |||||
| using InvCost = SqrtCost; | |||||
| using RsqrtCost = SqrtCost; | |||||
| using AsinhCost = SqrtCost; | |||||
| using AcoshCost = SqrtCost; | |||||
| using ReLUV2Cost = SqrtCost; | |||||
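Note on the Reciprocal alias above: for y = 1/x the derivative can be evaluated two ways,

    dy/dx = -1/x^2   (re-reads the input)
          = -y^2     (re-reads the output)

The source comment implies Ascend's backward kernel uses the -y^2 form, which is why ReciprocalCost reuses SqrtCost's keep-the-output accounting; a GPU-style -1/x^2 kernel would want the input resident instead. (The kernel detail is inferred from the comment, not stated elsewhere in this patch.)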
| class ReLU6Cost : public CastCost { | |||||
| public: | |||||
| ReLU6Cost() : CastCost() {} | |||||
| ~ReLU6Cost() override = default; | |||||
| // Taking account of input, not taking account of output | |||||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||||
| }; | |||||
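ReLU6Cost is the mirror case: the gradient is dy masked by the indicator of 0 < x < 6, so the backward pass re-reads the forward input and the output can be freed immediately. A plausible body, assuming input 0 is the data tensor; the aliases that follow reuse this behavior for other input-reading unary ops:

    void ReLU6Cost::CalculateInputsInMemory(const std::map<size_t, bool> &) {
      // Backward masks dy with (0 < x) && (x < 6): the forward input is re-read.
      is_inputs_should_in_memory_[0] = true;
    }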
| using SoftsignCost = ReLU6Cost; | |||||
| using SoftplusCost = ReLU6Cost; | |||||
| using SquareCost = ReLU6Cost; | |||||
| using ExpCost = ReLU6Cost; | |||||
| using LogCost = ReLU6Cost; | |||||
| using CosCost = ReLU6Cost; | |||||
| using ACosCost = ReLU6Cost; | |||||
| using AbsCost = ReLU6Cost; | |||||
| using TanCost = ReLU6Cost; | |||||
| using SinCost = ReLU6Cost; | |||||
| using SinhCost = ReLU6Cost; | |||||
| using Log1pCost = ReLU6Cost; | |||||
| using Expm1Cost = ReLU6Cost; | |||||
| using CoshCost = ReLU6Cost; | |||||
| using AtanhCost = ReLU6Cost; | |||||
| using AtanCost = ReLU6Cost; | |||||
| using AsinCost = ReLU6Cost; | |||||
| using ErfCost = ReLU6Cost; | |||||
| using ErfcCost = ReLU6Cost; | |||||
| using ActivationInfoCost = ReLU6Cost; | |||||
| class TransposeCost : public CastCost { | |||||
| public: | |||||
| TransposeCost() : CastCost() {} | |||||
| ~TransposeCost() override = default; | |||||
| // Taking account of input, not taking account of output | |||||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||||
| }; | |||||
| class GeLUCost : public SqrtCost { | |||||
| public: | |||||
| GeLUCost() : SqrtCost() {} | |||||
| ~GeLUCost() override = default; | |||||
| // Taking account of input and output | |||||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||||
| }; | |||||
| using BesselI0eCost = GeLUCost; | |||||
| using BesselI1eCost = GeLUCost; | |||||
| using L2NormalizeCost = GeLUCost; | |||||
| class SoftmaxCost : public OperatorCost { | class SoftmaxCost : public OperatorCost { | ||||
| public: | public: | ||||
| explicit SoftmaxCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||||
| SoftmaxCost() : OperatorCost(false) {} | |||||
| SoftmaxCost() : OperatorCost() {} | |||||
| ~SoftmaxCost() override = default; | ~SoftmaxCost() override = default; | ||||
| double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | ||||
| @@ -199,21 +273,45 @@ class SoftmaxCost : public OperatorCost { | |||||
| int64_t stage_id) const override; | int64_t stage_id) const override; | ||||
| double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | ||||
| int64_t) const override; | int64_t) const override; | ||||
| // Taking account of output | |||||
| void CalculateOutputInMemory() override; | |||||
| // Not taking account of input | |||||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||||
| }; | |||||
| class TileCost : public SoftmaxCost { | |||||
| public: | |||||
| TileCost() : SoftmaxCost() {} | |||||
| ~TileCost() override = default; | |||||
| // Not taking account of output | |||||
| void CalculateOutputInMemory() override; | |||||
| // Taking account of input | |||||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||||
| }; | |||||
| class PackCost : public SoftmaxCost { | |||||
| public: | |||||
| PackCost() : SoftmaxCost() {} | |||||
| ~PackCost() override = default; | |||||
| // Not taking account of output | |||||
| void CalculateOutputInMemory() override; | |||||
| // Not taking account of input | |||||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||||
| }; | |||||
| class BroadcastToCost : public SoftmaxCost { | |||||
| public: | |||||
| BroadcastToCost() : SoftmaxCost() {} | |||||
| ~BroadcastToCost() override = default; | |||||
| // Not taking account of output | |||||
| void CalculateOutputInMemory() override; | |||||
| // Not taking account of input | |||||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||||
| }; | }; | ||||
| using SoftmaxCostPtr = std::shared_ptr<SoftmaxCost>; | |||||
| using TileCost = SoftmaxCost; | |||||
| using TileCostPtr = std::shared_ptr<TileCost>; | |||||
| using PackCost = TileCost; | |||||
| using PackCostPtr = std::shared_ptr<PackCost>; | |||||
| using ConcatCost = TileCost; | |||||
| using ConcatCostPtr = std::shared_ptr<ConcatCost>; | |||||
| using BroadcastToCost = SoftmaxCost; | |||||
| using BroadcastToCostPtr = std::shared_ptr<BroadcastToCost>; | |||||
| class TmpIdentityCost : public OperatorCost { | class TmpIdentityCost : public OperatorCost { | ||||
| public: | public: | ||||
| explicit TmpIdentityCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||||
| TmpIdentityCost() : OperatorCost(false) {} | |||||
| TmpIdentityCost() : OperatorCost() {} | |||||
| ~TmpIdentityCost() override = default; | ~TmpIdentityCost() override = default; | ||||
| double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | ||||
| @@ -232,15 +330,16 @@ class TmpIdentityCost : public OperatorCost { | |||||
| int64_t stage_id) const override; | int64_t stage_id) const override; | ||||
| double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | ||||
| int64_t stage_id) const override; | int64_t stage_id) const override; | ||||
| // per device PEAK memory cost in a training iteration | |||||
| double GetMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs) const override; | |||||
| // Not taking account of output | |||||
| void CalculateOutputInMemory() override; | |||||
| // Not taking account of input | |||||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||||
| }; | }; | ||||
| using TmpIdentityCostPtr = std::shared_ptr<TmpIdentityCost>; | using TmpIdentityCostPtr = std::shared_ptr<TmpIdentityCost>; | ||||
| class BatchParallelCost : public OperatorCost { | class BatchParallelCost : public OperatorCost { | ||||
| public: | public: | ||||
| explicit BatchParallelCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||||
| BatchParallelCost() : OperatorCost(false) {} | |||||
| BatchParallelCost() : OperatorCost() {} | |||||
| ~BatchParallelCost() override = default; | ~BatchParallelCost() override = default; | ||||
| double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | ||||
| @@ -259,13 +358,25 @@ class BatchParallelCost : public OperatorCost { | |||||
| int64_t stage_id) const override; | int64_t stage_id) const override; | ||||
| double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | ||||
| int64_t stage_id) const override; | int64_t stage_id) const override; | ||||
| // Not taking account of output | |||||
| void CalculateOutputInMemory() override; | |||||
| // Taking account of input | |||||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||||
| }; | |||||
| class SparseSoftmaxCrossEntropyWithLogitsCost : public BatchParallelCost { | |||||
| public: | |||||
| SparseSoftmaxCrossEntropyWithLogitsCost() : BatchParallelCost() {} | |||||
| ~SparseSoftmaxCrossEntropyWithLogitsCost() override = default; | |||||
| // Taking account of output | |||||
| void CalculateOutputInMemory() override; | |||||
| // Not taking account of input | |||||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||||
| }; | }; | ||||
| using BatchParallelCostPtr = std::shared_ptr<BatchParallelCost>; | |||||
| class VirtualDatasetCost : public OperatorCost { | class VirtualDatasetCost : public OperatorCost { | ||||
| public: | public: | ||||
| explicit VirtualDatasetCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||||
| VirtualDatasetCost() : OperatorCost(false) {} | |||||
| VirtualDatasetCost() : OperatorCost() {} | |||||
| ~VirtualDatasetCost() override = default; | ~VirtualDatasetCost() override = default; | ||||
| double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | ||||
| @@ -290,17 +401,15 @@ class VirtualDatasetCost : public OperatorCost { | |||||
| int64_t) const override { | int64_t) const override { | ||||
| return 0.0; | return 0.0; | ||||
| } | } | ||||
| // per device PEAK memory cost in a training iteration | |||||
| double GetMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs) const override { | |||||
| return 0.0; | |||||
| } | |||||
| // Not taking account of output | |||||
| void CalculateOutputInMemory() override; | |||||
| // Not taking account of input | |||||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||||
| }; | }; | ||||
| using VirtualDatasetCostPtr = std::shared_ptr<VirtualDatasetCost>; | |||||
| class GeneratorBaseCost : public OperatorCost { | class GeneratorBaseCost : public OperatorCost { | ||||
| public: | public: | ||||
| explicit GeneratorBaseCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||||
| GeneratorBaseCost() : OperatorCost(false) {} | |||||
| GeneratorBaseCost() : OperatorCost() {} | |||||
| ~GeneratorBaseCost() override = default; | ~GeneratorBaseCost() override = default; | ||||
| double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | ||||
| @@ -332,8 +441,7 @@ using GeneratorBaseCostPtr = std::shared_ptr<GeneratorBaseCost>; | |||||
| class PReLUCost : public OperatorCost { | class PReLUCost : public OperatorCost { | ||||
| public: | public: | ||||
| explicit PReLUCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||||
| PReLUCost() : OperatorCost(true) {} | |||||
| PReLUCost() : OperatorCost() {} | |||||
| ~PReLUCost() override = default; | ~PReLUCost() override = default; | ||||
| // per device communication cost | // per device communication cost | ||||
| @@ -355,13 +463,16 @@ class PReLUCost : public OperatorCost { | |||||
| int64_t stage_id) const override; | int64_t stage_id) const override; | ||||
| double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | ||||
| int64_t stage_id) const override; | int64_t stage_id) const override; | ||||
| // Not taking account of output | |||||
| void CalculateOutputInMemory() override; | |||||
| // Taking account of input | |||||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||||
| }; | }; | ||||
| using PReLUCostPtr = std::shared_ptr<PReLUCost>; | using PReLUCostPtr = std::shared_ptr<PReLUCost>; | ||||
| class OneHotCost : public OperatorCost { | class OneHotCost : public OperatorCost { | ||||
| public: | public: | ||||
| explicit OneHotCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||||
| OneHotCost() : OperatorCost(true) {} | |||||
| OneHotCost() : OperatorCost() {} | |||||
| ~OneHotCost() override = default; | ~OneHotCost() override = default; | ||||
| // per device communication cost | // per device communication cost | ||||
| @@ -383,13 +494,16 @@ class OneHotCost : public OperatorCost { | |||||
| int64_t stage_id) const override; | int64_t stage_id) const override; | ||||
| double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | ||||
| int64_t stage_id) const override; | int64_t stage_id) const override; | ||||
| // Not taking account of output | |||||
| void CalculateOutputInMemory() override; | |||||
| // Not taking account of input | |||||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||||
| }; | }; | ||||
| using OneHotCostPtr = std::shared_ptr<OneHotCost>; | using OneHotCostPtr = std::shared_ptr<OneHotCost>; | ||||
| class SoftmaxCrossEntropyWithLogitsCost : public OperatorCost { | class SoftmaxCrossEntropyWithLogitsCost : public OperatorCost { | ||||
| public: | public: | ||||
| explicit SoftmaxCrossEntropyWithLogitsCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||||
| SoftmaxCrossEntropyWithLogitsCost() : OperatorCost(false) {} | |||||
| SoftmaxCrossEntropyWithLogitsCost() : OperatorCost() {} | |||||
| ~SoftmaxCrossEntropyWithLogitsCost() override = default; | ~SoftmaxCrossEntropyWithLogitsCost() override = default; | ||||
| // per device communication cost | // per device communication cost | ||||
| @@ -411,13 +525,15 @@ class SoftmaxCrossEntropyWithLogitsCost : public OperatorCost { | |||||
| int64_t stage_id) const override; | int64_t stage_id) const override; | ||||
| double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | ||||
| int64_t stage_id) const override; | int64_t stage_id) const override; | ||||
| // Taking account of output | |||||
| void CalculateOutputInMemory() override; | |||||
| // Not taking account of input | |||||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||||
| }; | }; | ||||
| using SoftmaxCrossEntropyWithLogitsCostPtr = std::shared_ptr<SoftmaxCrossEntropyWithLogitsCost>; | |||||
| class ReshapeCost : public OperatorCost { | class ReshapeCost : public OperatorCost { | ||||
| public: | public: | ||||
| explicit ReshapeCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||||
| ReshapeCost() : OperatorCost(true) {} | |||||
| ReshapeCost() : OperatorCost() {} | |||||
| ~ReshapeCost() override = default; | ~ReshapeCost() override = default; | ||||
| @@ -444,14 +560,17 @@ class ReshapeCost : public OperatorCost { | |||||
| double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | ||||
| int64_t stage_id) const override; | int64_t stage_id) const override; | ||||
| // Not taking account of output | |||||
| void CalculateOutputInMemory() override; | |||||
| // Not taking account of input | |||||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||||
| }; | }; | ||||
| using ReshapeCostPtr = std::shared_ptr<ReshapeCost>; | using ReshapeCostPtr = std::shared_ptr<ReshapeCost>; | ||||
| class ArithmeticCost : public OperatorCost { | |||||
| class SubCost : public OperatorCost { | |||||
| public: | public: | ||||
| explicit ArithmeticCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||||
| ArithmeticCost() : OperatorCost(false) {} | |||||
| ~ArithmeticCost() override = default; | |||||
| SubCost() : OperatorCost() {} | |||||
| ~SubCost() override = default; | |||||
| double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | ||||
| int64_t stage_id) const override { | int64_t stage_id) const override { | ||||
| @@ -470,16 +589,127 @@ class ArithmeticCost : public OperatorCost { | |||||
| int64_t stage_id) const override; | int64_t stage_id) const override; | ||||
| double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | ||||
| int64_t stage_id) const override; | int64_t stage_id) const override; | ||||
| // Not taking account of output | |||||
| void CalculateOutputInMemory() override; | |||||
| // Not taking account of input | |||||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||||
| }; | |||||
| using TensorAddCost = SubCost; | |||||
| using FloorDivCost = SubCost; | |||||
| using AssignSubCost = SubCost; | |||||
| using AssignAddCost = SubCost; | |||||
| using LogicalAndCost = SubCost; | |||||
| using LogicalOrCost = SubCost; | |||||
| using BiasAddCost = SubCost; | |||||
| using EqualCost = SubCost; | |||||
| using ApproximateEqualCost = SubCost; | |||||
| using NotEqualCost = SubCost; | |||||
| using GreaterCost = SubCost; | |||||
| using GreaterEqualCost = SubCost; | |||||
| using LessCost = SubCost; | |||||
| using LessEqualCost = SubCost; | |||||
| class MulCost : public SubCost { | |||||
| public: | |||||
| MulCost() : SubCost() {} | |||||
| ~MulCost() override = default; | |||||
| // Taking account of input, not taking account of output | |||||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||||
| }; | }; | ||||
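The Sub/Mul split encodes the gradient algebra:

    y = a - b :  dy/da = 1,  dy/db = -1    (no operand re-read)
    y = a * b :  dy/da = b,  dy/db = a     (both operands re-read)

which is why SubCost (and its comparison-style aliases) flags nothing, while MulCost overrides only the input side.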
| using ArithmeticCostPtr = std::shared_ptr<ArithmeticCost>; | |||||
| using BiasAddCost = ArithmeticCost; | |||||
| using BiasAddCostPtr = std::shared_ptr<BiasAddCost>; | |||||
| class ReduceMethodCost : public OperatorCost { | |||||
| class DivCost : public SubCost { | |||||
| public: | public: | ||||
| explicit ReduceMethodCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||||
| ReduceMethodCost() : OperatorCost(true) {} | |||||
| ~ReduceMethodCost() override = default; | |||||
| DivCost() : SubCost() {} | |||||
| ~DivCost() override = default; | |||||
| // Taking account of output | |||||
| void CalculateOutputInMemory() override; | |||||
| // Taking account of input | |||||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||||
| }; | |||||
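DivCost accounts for both sides, matching the derivatives of y = a/b:

    dy/da = 1/b             (re-reads input b)
    dy/db = -a/b^2 = -y/b   (computable from the output y plus b)

Either formulation keeps b resident along with one of a or y, so both the inputs and the output are priced; the RealDiv alias below inherits the same accounting.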
| using ReadDivCost = DivCost; | |||||
| class ModCost : public SubCost { | |||||
| public: | |||||
| ModCost() : SubCost() {} | |||||
| ~ModCost() override = default; | |||||
| // Taking account of input, not taking account of output | |||||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||||
| }; | |||||
| using FloorModCost = ModCost; | |||||
| class PowCost : public SubCost { | |||||
| public: | |||||
| PowCost() : SubCost() {} | |||||
| ~PowCost() override = default; | |||||
| // Taking account of output | |||||
| void CalculateOutputInMemory() override; | |||||
| // Taking account of input | |||||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||||
| }; | |||||
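PowCost likewise keeps both flags, following from y = a^b:

    dy/da = b * a^(b-1)   (re-reads both inputs)
    dy/db = y * ln(a)     (re-reads the output and input a)

so neither the inputs nor the output can be dropped before backward.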
| class AssignCost : public SubCost { | |||||
| public: | |||||
| AssignCost() : SubCost() {} | |||||
| ~AssignCost() override = default; | |||||
| // Taking account of input, not taking account of output | |||||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||||
| }; | |||||
| class SigmoidCrossEntropyWithLogitsCost : public SubCost { | |||||
| public: | |||||
| SigmoidCrossEntropyWithLogitsCost() : SubCost() {} | |||||
| ~SigmoidCrossEntropyWithLogitsCost() override = default; | |||||
| // Taking account of input, not taking account of output | |||||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||||
| }; | |||||
| class Atan2Cost : public SubCost { | |||||
| public: | |||||
| Atan2Cost() : SubCost() {} | |||||
| ~Atan2Cost() override = default; | |||||
| // Taking account of input, not taking account of output | |||||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||||
| }; | |||||
| class DivNoNanCost : public SubCost { | |||||
| public: | |||||
| DivNoNanCost() : SubCost() {} | |||||
| ~DivNoNanCost() override = default; | |||||
| // Taking account of output | |||||
| void CalculateOutputInMemory() override; | |||||
| // Taking account of input | |||||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||||
| }; | |||||
| class MaximumCost : public SubCost { | |||||
| public: | |||||
| MaximumCost() : SubCost() {} | |||||
| ~MaximumCost() override = default; | |||||
| // Taking account of input, not taking account of output | |||||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||||
| }; | |||||
| using MinimumCost = MaximumCost; | |||||
| class SliceCost : public CastCost { | |||||
| public: | |||||
| SliceCost() : CastCost() {} | |||||
| ~SliceCost() override = default; | |||||
| // Not taking account of output, taking account of input | |||||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||||
| }; | |||||
| class StridedSliceCost : public CastCost { | |||||
| public: | |||||
| StridedSliceCost() : CastCost() {} | |||||
| ~StridedSliceCost() override = default; | |||||
| // Not taking account of output, taking account of input | |||||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||||
| }; | |||||
| class ReduceSumCost : public OperatorCost { | |||||
| public: | |||||
| ReduceSumCost() : OperatorCost() {} | |||||
| ~ReduceSumCost() override = default; | |||||
| double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | ||||
| int64_t stage_id) const override { | int64_t stage_id) const override { | ||||
| @@ -500,27 +730,50 @@ class ReduceMethodCost : public OperatorCost { | |||||
| return 0.0; | return 0.0; | ||||
| } | } | ||||
| void set_cross_batch(bool cb) { cross_batch_ = cb; } | void set_cross_batch(bool cb) { cross_batch_ = cb; } | ||||
| // Not taking account of output | |||||
| void CalculateOutputInMemory() override; | |||||
| // Taking account of input | |||||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||||
| protected: | protected: | ||||
| bool cross_batch_ = false; | bool cross_batch_ = false; | ||||
| }; | }; | ||||
| using ReduceMethodCostPtr = std::shared_ptr<ReduceMethodCost>; | |||||
| using ReduceMethodCost = ReduceSumCost; | |||||
| class ReduceMeanCost : public ReduceMethodCost { | |||||
| class ReduceMeanCost : public ReduceSumCost { | |||||
| public: | public: | ||||
| explicit ReduceMeanCost(bool is_inputs_related) : ReduceMethodCost(is_inputs_related) {} | |||||
| ReduceMeanCost() : ReduceMethodCost(true) {} | |||||
| ReduceMeanCost() : ReduceSumCost() {} | |||||
| ~ReduceMeanCost() override = default; | ~ReduceMeanCost() override = default; | ||||
| double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | ||||
| int64_t stage_id) const override; | int64_t stage_id) const override; | ||||
| }; | }; | ||||
| using ReduceMeanCostPtr = std::shared_ptr<ReduceMeanCost>; | |||||
| class ReduceMinCost : public ReduceSumCost { | |||||
| public: | |||||
| ReduceMinCost() : ReduceSumCost() {} | |||||
| ~ReduceMinCost() override = default; | |||||
| // Taking account of output | |||||
| void CalculateOutputInMemory() override; | |||||
| // Taking account of input | |||||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||||
| }; | |||||
| using ReduceMaxCost = ReduceMinCost; | |||||
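ReduceMin/Max account for both output and input because the backward pass routes dy only to the winning elements, which are typically recovered by comparing the forward input against the broadcast forward output, roughly grad_x = dy * (x == y); both tensors are therefore re-read. ArgMax/ArgMinWithValue below follow the same pattern.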
| class ArgMaxWithValueCost : public ReduceSumCost { | |||||
| public: | |||||
| ArgMaxWithValueCost() : ReduceSumCost() {} | |||||
| ~ArgMaxWithValueCost() override = default; | |||||
| // Taking account of output | |||||
| void CalculateOutputInMemory() override; | |||||
| // Taking account of input | |||||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||||
| }; | |||||
| using ArgMinWithValueCost = ArgMaxWithValueCost; | |||||
| class GetNextCost : public OperatorCost { | class GetNextCost : public OperatorCost { | ||||
| public: | public: | ||||
| explicit GetNextCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||||
| GetNextCost() : OperatorCost(false) {} | |||||
| GetNextCost() : OperatorCost() {} | |||||
| ~GetNextCost() override = default; | ~GetNextCost() override = default; | ||||
| double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | ||||
| @@ -547,13 +800,17 @@ class GetNextCost : public OperatorCost { | |||||
| int64_t) const override { | int64_t) const override { | ||||
| return 0.0; | return 0.0; | ||||
| } | } | ||||
| // Not taking account of output | |||||
| void CalculateOutputInMemory() override; | |||||
| // Not taking account of input | |||||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||||
| }; | }; | ||||
| using GetNextCostPtr = std::shared_ptr<GetNextCost>; | using GetNextCostPtr = std::shared_ptr<GetNextCost>; | ||||
| class DropOutCost : public OperatorCost { | |||||
| // For memory cost, taking account of output, not taking account of input | |||||
| class DropOutCost : public SqrtCost { | |||||
| public: | public: | ||||
| explicit DropOutCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||||
| DropOutCost() : OperatorCost(true) {} | |||||
| DropOutCost() : SqrtCost() {} | |||||
| ~DropOutCost() override = default; | ~DropOutCost() override = default; | ||||
| double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | ||||
| @@ -578,12 +835,19 @@ class DropOutCost : public OperatorCost { | |||||
| } | } | ||||
| }; | }; | ||||
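DropOutCost now inherits SqrtCost's keep-the-output accounting (see the comment above the class): dropout's backward multiplies dy by the saved mask, and since the mask is produced alongside the output, keeping the outputs covers it while the forward input is never re-read. DropOutDoMaskCost below inverts this, presumably because there the mask arrives as an explicit input to the operator.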
| using DropOutCostPtr = std::shared_ptr<DropOutCost>; | |||||
| class DropOutDoMaskCost : public DropOutCost { | |||||
| public: | |||||
| DropOutDoMaskCost() : DropOutCost() {} | |||||
| ~DropOutDoMaskCost() override = default; | |||||
| // Not taking account of output | |||||
| void CalculateOutputInMemory() override; | |||||
| // Taking account of input | |||||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||||
| }; | |||||
| class UnsortedSegmentSumCost : public OperatorCost { | class UnsortedSegmentSumCost : public OperatorCost { | ||||
| public: | public: | ||||
| explicit UnsortedSegmentSumCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||||
| UnsortedSegmentSumCost() : OperatorCost(true) {} | |||||
| UnsortedSegmentSumCost() : OperatorCost() {} | |||||
| ~UnsortedSegmentSumCost() override = default; | ~UnsortedSegmentSumCost() override = default; | ||||
| double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | ||||
| @@ -602,14 +866,15 @@ class UnsortedSegmentSumCost : public OperatorCost { | |||||
| int64_t) const override { | int64_t) const override { | ||||
| return 0.0; | return 0.0; | ||||
| } | } | ||||
| // Not taking account of output | |||||
| void CalculateOutputInMemory() override; | |||||
| // Taking account of input | |||||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||||
| }; | }; | ||||
| using UnsortedSegmentSumCostPtr = std::shared_ptr<UnsortedSegmentSumCost>; | |||||
| class UnsortedSegmentMinCost : public OperatorCost { | class UnsortedSegmentMinCost : public OperatorCost { | ||||
| public: | public: | ||||
| explicit UnsortedSegmentMinCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||||
| UnsortedSegmentMinCost() : OperatorCost(true) {} | |||||
| UnsortedSegmentMinCost() : OperatorCost() {} | |||||
| ~UnsortedSegmentMinCost() override = default; | ~UnsortedSegmentMinCost() override = default; | ||||
| double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | ||||
| @@ -628,14 +893,16 @@ class UnsortedSegmentMinCost : public OperatorCost { | |||||
| int64_t) const override { | int64_t) const override { | ||||
| return 0.0; | return 0.0; | ||||
| } | } | ||||
| // Taking account of output | |||||
| void CalculateOutputInMemory() override; | |||||
| // Taking account of input | |||||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||||
| }; | }; | ||||
| using UnsortedSegmentMinCostPtr = std::shared_ptr<UnsortedSegmentMinCost>; | |||||
| using UnsortedSegmentMaxCost = UnsortedSegmentMinCost; | |||||
| class LayerNormCost : public OperatorCost { | class LayerNormCost : public OperatorCost { | ||||
| public: | public: | ||||
| explicit LayerNormCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||||
| LayerNormCost() : OperatorCost(true) {} | |||||
| LayerNormCost() : OperatorCost() {} | |||||
| ~LayerNormCost() override = default; | ~LayerNormCost() override = default; | ||||
| double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | ||||
| @@ -656,14 +923,15 @@ class LayerNormCost : public OperatorCost { | |||||
| int64_t) const override { | int64_t) const override { | ||||
| return 0.0; | return 0.0; | ||||
| } | } | ||||
| // Taking account of output | |||||
| void CalculateOutputInMemory() override; | |||||
| // Taking account of input | |||||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||||
| }; | }; | ||||
| using DropOutCostPtr = std::shared_ptr<DropOutCost>; | |||||
| class UniqueCost : public OperatorCost { | class UniqueCost : public OperatorCost { | ||||
| public: | public: | ||||
| explicit UniqueCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||||
| UniqueCost() : OperatorCost(true) {} | |||||
| UniqueCost() : OperatorCost() {} | |||||
| ~UniqueCost() override = default; | ~UniqueCost() override = default; | ||||
| double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | ||||
| @@ -682,14 +950,15 @@ class UniqueCost : public OperatorCost { | |||||
| int64_t stage_id) const override; | int64_t stage_id) const override; | ||||
| double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | ||||
| int64_t) const override; | int64_t) const override; | ||||
| // Taking account of output | |||||
| void CalculateOutputInMemory() override; | |||||
| // Not taking account of input | |||||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||||
| }; | }; | ||||
| using UniqueCostPtr = std::shared_ptr<UniqueCost>; | |||||
| class UniformCandidateSamplerCost : public OperatorCost { | class UniformCandidateSamplerCost : public OperatorCost { | ||||
| public: | public: | ||||
| explicit UniformCandidateSamplerCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||||
| UniformCandidateSamplerCost() : OperatorCost(false) {} | |||||
| UniformCandidateSamplerCost() : OperatorCost() {} | |||||
| ~UniformCandidateSamplerCost() override = default; | ~UniformCandidateSamplerCost() override = default; | ||||
| double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | ||||
| @@ -714,14 +983,15 @@ class UniformCandidateSamplerCost : public OperatorCost { | |||||
| int64_t) const override { | int64_t) const override { | ||||
| return 0.0; | return 0.0; | ||||
| } | } | ||||
| // Not taking account of output | |||||
| void CalculateOutputInMemory() override; | |||||
| // Not taking account of input | |||||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||||
| }; | }; | ||||
| using UniformCandidateSamplerCostPtr = std::shared_ptr<UniformCandidateSamplerCost>; | |||||
| class GatherV2Cost : public OperatorCost { | class GatherV2Cost : public OperatorCost { | ||||
| public: | public: | ||||
| explicit GatherV2Cost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||||
| GatherV2Cost() : OperatorCost(true) {} | |||||
| GatherV2Cost() : OperatorCost() {} | |||||
| ~GatherV2Cost() override = default; | ~GatherV2Cost() override = default; | ||||
| double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | ||||
| @@ -740,14 +1010,15 @@ class GatherV2Cost : public OperatorCost { | |||||
| int64_t stage_id) const override; | int64_t stage_id) const override; | ||||
| double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | ||||
| int64_t) const override; | int64_t) const override; | ||||
| // Not taking account of output | |||||
| void CalculateOutputInMemory() override; | |||||
| // Taking account of input | |||||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||||
| }; | }; | ||||
| using GatherV2CostPtr = std::shared_ptr<GatherV2Cost>; | |||||
| class GatherV2PCost : public OperatorCost { | |||||
| class GatherV2PCost : public GatherV2Cost { | |||||
| public: | public: | ||||
| explicit GatherV2PCost(bool is_inputs_related) : OperatorCost(is_inputs_related), axis_(0) {} | |||||
| GatherV2PCost() : OperatorCost(true), axis_(0) {} | |||||
| GatherV2PCost() : GatherV2Cost(), axis_(0) {} | |||||
| ~GatherV2PCost() override = default; | ~GatherV2PCost() override = default; | ||||
| double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | ||||
| @@ -773,8 +1044,6 @@ class GatherV2PCost : public OperatorCost { | |||||
| int64_t axis_; | int64_t axis_; | ||||
| Shape strategy_; | Shape strategy_; | ||||
| }; | }; | ||||
| using GatherV2PCostPtr = std::shared_ptr<GatherV2PCost>; | |||||
| } // namespace parallel | } // namespace parallel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // PARALLEL_AUTO_PARALLEL_OPERATOR_COSTMODEL_H_ | #endif // PARALLEL_AUTO_PARALLEL_OPERATOR_COSTMODEL_H_ | ||||
| @@ -50,8 +50,8 @@ class ActivationBase : public OperatorInfo { | |||||
| class Activation : public ActivationBase { | class Activation : public ActivationBase { | ||||
| public: | public: | ||||
| Activation(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | Activation(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | ||||
| const PrimitiveAttrs &attrs) | |||||
| : ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ActivationCost>(false)) {} | |||||
| const PrimitiveAttrs &attrs, OperatorCostPtr cost) | |||||
| : ActivationBase(name, inputs_shape, outputs_shape, attrs, cost) {} | |||||
| ~Activation() override = default; | ~Activation() override = default; | ||||
| Status GenerateStrategies(int64_t stage_id) override; | Status GenerateStrategies(int64_t stage_id) override; | ||||
| Status SetCostUnderStrategy(const StrategyPtr &strategy) override; | Status SetCostUnderStrategy(const StrategyPtr &strategy) override; | ||||
| @@ -64,7 +64,7 @@ class ActivationInfo : public Activation { | |||||
| public: | public: | ||||
| ActivationInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | ActivationInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | ||||
| const PrimitiveAttrs &attrs) | const PrimitiveAttrs &attrs) | ||||
| : Activation(name, inputs_shape, outputs_shape, attrs) {} | |||||
| : Activation(name, inputs_shape, outputs_shape, attrs, std::make_shared<ActivationInfoCost>()) {} | |||||
| ~ActivationInfo() override = default; | ~ActivationInfo() override = default; | ||||
| protected: | protected: | ||||
| @@ -74,8 +74,8 @@ class ActivationInfo : public Activation { | |||||
| class ActivationOther : public Activation { | class ActivationOther : public Activation { | ||||
| public: | public: | ||||
| ActivationOther(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | ActivationOther(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | ||||
| const PrimitiveAttrs &attrs) | |||||
| : Activation(name, inputs_shape, outputs_shape, attrs) {} | |||||
| const PrimitiveAttrs &attrs, OperatorCostPtr cost) | |||||
| : Activation(name, inputs_shape, outputs_shape, attrs, cost) {} | |||||
| ~ActivationOther() override = default; | ~ActivationOther() override = default; | ||||
| protected: | protected: | ||||
| @@ -86,7 +86,7 @@ class GeluInfo : public ActivationOther { | |||||
| public: | public: | ||||
| GeluInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | GeluInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | ||||
| const PrimitiveAttrs &attrs) | const PrimitiveAttrs &attrs) | ||||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<GeLUCost>()) {} | |||||
| ~GeluInfo() override = default; | ~GeluInfo() override = default; | ||||
| }; | }; | ||||
| @@ -94,7 +94,7 @@ class TanhInfo : public ActivationOther { | |||||
| public: | public: | ||||
| TanhInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | TanhInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | ||||
| const PrimitiveAttrs &attrs) | const PrimitiveAttrs &attrs) | ||||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<TanhCost>()) {} | |||||
| ~TanhInfo() override = default; | ~TanhInfo() override = default; | ||||
| }; | }; | ||||
| @@ -102,7 +102,7 @@ class Softmax : public ActivationBase { | |||||
| public: | public: | ||||
| explicit Softmax(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | explicit Softmax(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | ||||
| const PrimitiveAttrs &attrs) | const PrimitiveAttrs &attrs) | ||||
| : ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftmaxCost>(false)) {} | |||||
| : ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftmaxCost>()) {} | |||||
| ~Softmax() override = default; | ~Softmax() override = default; | ||||
| Status GenerateStrategies(int64_t stage_id) override; | Status GenerateStrategies(int64_t stage_id) override; | ||||
| Status SetCostUnderStrategy(const StrategyPtr &strategy) override; | Status SetCostUnderStrategy(const StrategyPtr &strategy) override; | ||||
| @@ -134,7 +134,7 @@ class LogSoftmaxInfo : public Softmax { | |||||
| class EluInfo : public ActivationOther { | class EluInfo : public ActivationOther { | ||||
| public: | public: | ||||
| EluInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) | EluInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) | ||||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<EluCost>()) {} | |||||
| ~EluInfo() override = default; | ~EluInfo() override = default; | ||||
| }; | }; | ||||
| @@ -142,7 +142,7 @@ class ReLUInfo : public ActivationOther { | |||||
| public: | public: | ||||
| ReLUInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | ReLUInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | ||||
| const PrimitiveAttrs &attrs) | const PrimitiveAttrs &attrs) | ||||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReLUCost>()) {} | |||||
| ~ReLUInfo() override = default; | ~ReLUInfo() override = default; | ||||
| }; | }; | ||||
| @@ -150,7 +150,7 @@ class RepeatElementsInfo : public ActivationOther { | |||||
| public: | public: | ||||
| RepeatElementsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | RepeatElementsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | ||||
| const PrimitiveAttrs &attrs) | const PrimitiveAttrs &attrs) | ||||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<RepeatElementsCost>()) {} | |||||
| ~RepeatElementsInfo() override = default; | ~RepeatElementsInfo() override = default; | ||||
| }; | }; | ||||
| @@ -158,7 +158,7 @@ class ReLU6Info : public ActivationOther { | |||||
| public: | public: | ||||
| ReLU6Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | ReLU6Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | ||||
| const PrimitiveAttrs &attrs) | const PrimitiveAttrs &attrs) | ||||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReLU6Cost>()) {} | |||||
| ~ReLU6Info() override = default; | ~ReLU6Info() override = default; | ||||
| }; | }; | ||||
| @@ -166,7 +166,7 @@ class SoftsignInfo : public ActivationOther { | |||||
| public: | public: | ||||
| SoftsignInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | SoftsignInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | ||||
| const PrimitiveAttrs &attrs) | const PrimitiveAttrs &attrs) | ||||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftsignCost>()) {} | |||||
| ~SoftsignInfo() override = default; | ~SoftsignInfo() override = default; | ||||
| }; | }; | ||||
| @@ -174,7 +174,7 @@ class SoftplusInfo : public ActivationOther { | |||||
| public: | public: | ||||
| SoftplusInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | SoftplusInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | ||||
| const PrimitiveAttrs &attrs) | const PrimitiveAttrs &attrs) | ||||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftplusCost>()) {} | |||||
| ~SoftplusInfo() override = default; | ~SoftplusInfo() override = default; | ||||
| }; | }; | ||||
| @@ -182,7 +182,7 @@ class CastInfo : public ActivationOther { | |||||
| public: | public: | ||||
| CastInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | CastInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | ||||
| const PrimitiveAttrs &attrs) | const PrimitiveAttrs &attrs) | ||||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<CastCost>()) {} | |||||
| ~CastInfo() override = default; | ~CastInfo() override = default; | ||||
| protected: | protected: | ||||
| @@ -193,14 +193,14 @@ class SqrtInfo : public ActivationOther { | |||||
| public: | public: | ||||
| SqrtInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | SqrtInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | ||||
| const PrimitiveAttrs &attrs) | const PrimitiveAttrs &attrs) | ||||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<SqrtCost>()) {} | |||||
| ~SqrtInfo() override = default; | ~SqrtInfo() override = default; | ||||
| }; | }; | ||||
| class NegInfo : public ActivationOther { | class NegInfo : public ActivationOther { | ||||
| public: | public: | ||||
| NegInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) | NegInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) | ||||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<NegCost>()) {} | |||||
| ~NegInfo() override = default; | ~NegInfo() override = default; | ||||
| }; | }; | ||||
| @@ -208,7 +208,7 @@ class ExpandDimsInfo : public ActivationOther { | |||||
| public: | public: | ||||
| ExpandDimsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | ExpandDimsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | ||||
| const PrimitiveAttrs &attrs) | const PrimitiveAttrs &attrs) | ||||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<ExpandDimsCost>()) {} | |||||
| ~ExpandDimsInfo() override = default; | ~ExpandDimsInfo() override = default; | ||||
| protected: | protected: | ||||
| @@ -228,7 +228,7 @@ class SqueezeInfo : public ActivationOther { | |||||
| public: | public: | ||||
| SqueezeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | SqueezeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | ||||
| const PrimitiveAttrs &attrs) | const PrimitiveAttrs &attrs) | ||||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<SqueezeCost>()) {} | |||||
| ~SqueezeInfo() override = default; | ~SqueezeInfo() override = default; | ||||
| protected: | protected: | ||||
| @@ -247,7 +247,7 @@ class SquareInfo : public ActivationOther { | |||||
| public: | public: | ||||
| SquareInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | SquareInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | ||||
| const PrimitiveAttrs &attrs) | const PrimitiveAttrs &attrs) | ||||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<SquareCost>()) {} | |||||
| ~SquareInfo() override = default; | ~SquareInfo() override = default; | ||||
| }; | }; | ||||
| @@ -255,7 +255,7 @@ class SigmoidInfo : public ActivationOther { | |||||
| public: | public: | ||||
| SigmoidInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | SigmoidInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | ||||
| const PrimitiveAttrs &attrs) | const PrimitiveAttrs &attrs) | ||||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<SigmoidCost>()) {} | |||||
| ~SigmoidInfo() override = default; | ~SigmoidInfo() override = default; | ||||
| }; | }; | ||||
| @@ -263,7 +263,7 @@ class DropoutInfo : public ActivationOther { | |||||
| public: | public: | ||||
| DropoutInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | DropoutInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | ||||
| const PrimitiveAttrs &attrs) | const PrimitiveAttrs &attrs) | ||||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} | |||||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<DropOutCost>()) {} | |||||
| ~DropoutInfo() override = default; | ~DropoutInfo() override = default; | ||||
| Status GenerateStrategies(int64_t stage_id) override; | Status GenerateStrategies(int64_t stage_id) override; | ||||
@@ -56,7 +56,7 @@ class ArithmeticBase : public OperatorInfo {
 class SubInfo : public ArithmeticBase {
  public:
   SubInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
-      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
+      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<SubCost>()) {}
   ~SubInfo() override = default;
 };
@@ -64,28 +64,28 @@ class TensorAddInfo : public ArithmeticBase {
  public:
   TensorAddInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                 const PrimitiveAttrs &attrs)
-      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
+      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<TensorAddCost>()) {}
   ~TensorAddInfo() override = default;
 };
 class MulInfo : public ArithmeticBase {
  public:
   MulInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
-      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
+      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<MulCost>()) {}
   ~MulInfo() override = default;
 };
 class DivInfo : public ArithmeticBase {
  public:
   DivInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
-      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
+      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<DivCost>()) {}
   ~DivInfo() override = default;
 };
 class ModInfo : public ArithmeticBase {
  public:
   ModInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
-      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
+      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ModCost>()) {}
   ~ModInfo() override = default;
 };
@@ -93,7 +93,7 @@ class RealDivInfo : public ArithmeticBase {
  public:
   RealDivInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
               const PrimitiveAttrs &attrs)
-      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
+      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReadDivCost>()) {}
   ~RealDivInfo() override = default;
 };
@@ -101,7 +101,7 @@ class FloorDivInfo : public ArithmeticBase {
  public:
   FloorDivInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                const PrimitiveAttrs &attrs)
-      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
+      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<FloorDivCost>()) {}
   ~FloorDivInfo() override = default;
 };
@@ -109,14 +109,14 @@ class FloorModInfo : public ArithmeticBase {
  public:
   FloorModInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                const PrimitiveAttrs &attrs)
-      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
+      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<FloorModCost>()) {}
   ~FloorModInfo() override = default;
 };
 class PowInfo : public ArithmeticBase {
  public:
   PowInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
-      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
+      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<PowCost>()) {}
   ~PowInfo() override = default;
 };
@@ -124,7 +124,7 @@ class AssignSubInfo : public ArithmeticBase {
  public:
   AssignSubInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                 const PrimitiveAttrs &attrs)
-      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
+      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<AssignSubCost>()) {}
   ~AssignSubInfo() override = default;
 };
@@ -132,7 +132,7 @@ class AssignInfo : public ArithmeticBase {
  public:
   AssignInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
              const PrimitiveAttrs &attrs)
-      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
+      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<AssignCost>()) {}
   ~AssignInfo() override = default;
 };
@@ -140,7 +140,7 @@ class AssignAddInfo : public ArithmeticBase {
  public:
   AssignAddInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                 const PrimitiveAttrs &attrs)
-      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
+      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<AssignAddCost>()) {}
   ~AssignAddInfo() override = default;
 };
@@ -149,7 +149,8 @@ class SigmoidCrossEntropyWithLogitsInfo : public ArithmeticBase {
  public:
   SigmoidCrossEntropyWithLogitsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                                     const PrimitiveAttrs &attrs)
-      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
+      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs,
+                       std::make_shared<SigmoidCrossEntropyWithLogitsCost>()) {}
   ~SigmoidCrossEntropyWithLogitsInfo() override = default;
 };
@@ -157,7 +158,7 @@ class Atan2Info : public ArithmeticBase {
  public:
   Atan2Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
             const PrimitiveAttrs &attrs)
-      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
+      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<Atan2Cost>()) {}
   ~Atan2Info() override = default;
 };
@@ -165,7 +166,7 @@ class DivNoNanInfo : public ArithmeticBase {
  public:
   DivNoNanInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                const PrimitiveAttrs &attrs)
-      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
+      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<DivNoNanCost>()) {}
   ~DivNoNanInfo() override = default;
 };
@@ -173,7 +174,7 @@ class LogicalAndInfo : public ArithmeticBase {
  public:
   LogicalAndInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                  const PrimitiveAttrs &attrs)
-      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
+      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<LogicalAndCost>()) {}
   ~LogicalAndInfo() override = default;
 };
@@ -181,7 +182,7 @@ class LogicalOrInfo : public ArithmeticBase {
  public:
   LogicalOrInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                 const PrimitiveAttrs &attrs)
-      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
+      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<LogicalOrCost>()) {}
   ~LogicalOrInfo() override = default;
 };
 } // namespace parallel
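Every constructor change in this file follows the same pattern: the shared ArithmeticCost(bool) is replaced by one dedicated cost class per operator, so the boolean flag that used to encode memory behaviour moves into the cost type itself. A minimal self-contained sketch of the idea; the names and members below are illustrative toys, not the actual operator_costmodel.h declarations:

#include <memory>

// Toy base standing in for the real OperatorCost hierarchy.
class ToyOperatorCost {
 public:
  virtual ~ToyOperatorCost() = default;
  // Per-operator hook: decide whether the output must be kept in memory
  // for the backward pass (mirrors CalculateOutputInMemory in this patch).
  virtual void CalculateOutputInMemory() = 0;
 protected:
  bool is_output_should_in_memory_ = false;
};

// One cost class per operator replaces the old shared ArithmeticCost(bool);
// each class encodes its own memory behaviour.
class ToySubCost : public ToyOperatorCost {
 public:
  void CalculateOutputInMemory() override { is_output_should_in_memory_ = false; }
};

class ToyMulCost : public ToyOperatorCost {
 public:
  // Mul is assumed here, purely for illustration, to need its output kept.
  void CalculateOutputInMemory() override { is_output_should_in_memory_ = true; }
};

The payoff is that hooks like CalculateOutputInMemory (wired up in the operator_info.cc hunk further down) can be overridden per operator instead of being switched on a constructor flag.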
@@ -34,8 +34,7 @@ class BatchParallelInfo : public OperatorInfo {
       : OperatorInfo(name, inputs_shape, outputs_shape, attrs, cost), dev_num_(1) {}
   BatchParallelInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                     const PrimitiveAttrs &attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>(false)),
-        dev_num_(1) {}
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>()), dev_num_(1) {}
   ~BatchParallelInfo() override = default;
   Status Init(const StrategyPtr &strategy) override;
@@ -62,7 +61,8 @@ class SparseSoftmaxCrossEntropyWithLogitsInfo : public BatchParallelInfo {
  public:
   SparseSoftmaxCrossEntropyWithLogitsInfo(const std::string &name, const Shapes &inputs_shape,
                                           const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
-      : BatchParallelInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>(true)) {}
+      : BatchParallelInfo(name, inputs_shape, outputs_shape, attrs,
+                          std::make_shared<SparseSoftmaxCrossEntropyWithLogitsCost>()) {}
   ~SparseSoftmaxCrossEntropyWithLogitsInfo() override = default;
   void ReComputeBatchSplitFlagList() override;
 };
@@ -34,7 +34,7 @@ class BiasAddInfo : public OperatorInfo {
  public:
   BiasAddInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
               const PrimitiveAttrs &attrs)
-      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<BiasAddCost>(false)) {}
+      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<BiasAddCost>()) {}
   ~BiasAddInfo() override = default;
   Status Init(const StrategyPtr &strategy) override;
@@ -36,7 +36,7 @@ class BroadcastToInfo : public OperatorInfo {
  public:
   BroadcastToInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                   const PrimitiveAttrs &attrs)
-      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<BroadcastToCost>(false)) {}
+      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<BroadcastToCost>()) {}
   ~BroadcastToInfo() override = default;
   Status Init(const StrategyPtr &strategy) override;
@@ -32,7 +32,7 @@ class EqualInfo : public ArithmeticBase {
  public:
   EqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
             const PrimitiveAttrs &attrs)
-      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
+      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<EqualCost>()) {}
   ~EqualInfo() override = default;
 };
@@ -40,7 +40,7 @@ class ApproximateEqualInfo : public ArithmeticBase {
  public:
   ApproximateEqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                        const PrimitiveAttrs &attrs)
-      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
+      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ApproximateEqualCost>()) {}
   ~ApproximateEqualInfo() override = default;
 };
@@ -48,7 +48,7 @@ class NotEqualInfo : public ArithmeticBase {
  public:
   NotEqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                const PrimitiveAttrs &attrs)
-      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
+      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<NotEqualCost>()) {}
   ~NotEqualInfo() override = default;
 };
@@ -56,7 +56,7 @@ class MaximumInfo : public ArithmeticBase {
  public:
   MaximumInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
               const PrimitiveAttrs &attrs)
-      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
+      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<MaximumCost>()) {}
   ~MaximumInfo() override = default;
 };
@@ -64,7 +64,7 @@ class MinimumInfo : public ArithmeticBase {
  public:
   MinimumInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
               const PrimitiveAttrs &attrs)
-      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
+      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<MinimumCost>()) {}
   ~MinimumInfo() override = default;
 };
@@ -72,7 +72,7 @@ class GreaterInfo : public ArithmeticBase {
  public:
   GreaterInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
               const PrimitiveAttrs &attrs)
-      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
+      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<GreaterCost>()) {}
   ~GreaterInfo() override = default;
 };
@@ -80,7 +80,7 @@ class GreaterEqualInfo : public ArithmeticBase {
  public:
   GreaterEqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                    const PrimitiveAttrs &attrs)
-      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
+      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<GreaterEqualCost>()) {}
   ~GreaterEqualInfo() override = default;
 };
@@ -88,7 +88,7 @@ class LessInfo : public ArithmeticBase {
  public:
   LessInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
            const PrimitiveAttrs &attrs)
-      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
+      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<LessCost>()) {}
   ~LessInfo() override = default;
 };
@@ -96,7 +96,7 @@ class LessEqualInfo : public ArithmeticBase {
  public:
   LessEqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                 const PrimitiveAttrs &attrs)
-      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
+      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<LessEqualCost>()) {}
   ~LessEqualInfo() override = default;
 };
 } // namespace parallel
@@ -33,7 +33,7 @@ class ConcatInfo : public OperatorInfo {
  public:
   ConcatInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
              const PrimitiveAttrs &attrs)
-      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<ConcatCost>(false)) {}
+      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<ConcatCost>()) {}
   ~ConcatInfo() override = default;
   Status Init(const StrategyPtr &strategy) override;
@@ -33,7 +33,7 @@ class DropoutDoMaskInfo : public OperatorInfo {
  public:
   DropoutDoMaskInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                     const PrimitiveAttrs &attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<DropOutCost>(true)) {}
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<DropOutDoMaskCost>()) {}
   ~DropoutDoMaskInfo() override = default;
   Status Init(const StrategyPtr &strategy) override;
@@ -20,6 +20,7 @@
 #include <string>
 #include <unordered_map>
 #include <vector>
+#include <memory>
 #include "ir/value.h"
 #include "frontend/parallel/auto_parallel/operator_costmodel.h"
 #include "frontend/parallel/ops_info/activation_info.h"
@@ -30,21 +31,21 @@ namespace parallel {
 class ExpInfo : public ActivationOther {
  public:
   ExpInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
-      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
+      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<ExpCost>()) {}
   ~ExpInfo() override = default;
 };
 class LogInfo : public ActivationOther {
  public:
   LogInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
-      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
+      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<LogCost>()) {}
   ~LogInfo() override = default;
 };
 class CosInfo : public ActivationOther {
  public:
   CosInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
-      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
+      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<CosCost>()) {}
   ~CosInfo() override = default;
 };
@@ -52,7 +53,7 @@ class ACosInfo : public ActivationOther {
  public:
   ACosInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
            const PrimitiveAttrs &attrs)
-      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
+      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<ACosCost>()) {}
   ~ACosInfo() override = default;
 };
@@ -60,14 +61,14 @@ class LogicalNotInfo : public ActivationOther {
  public:
   LogicalNotInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                  const PrimitiveAttrs &attrs)
-      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
+      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<LogicalNotCost>()) {}
   ~LogicalNotInfo() override = default;
 };
 class AbsInfo : public ActivationOther {
  public:
   AbsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
-      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
+      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<AbsCost>()) {}
   ~AbsInfo() override = default;
 };
@@ -75,7 +76,7 @@ class SignInfo : public ActivationOther {
  public:
   SignInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
            const PrimitiveAttrs &attrs)
-      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
+      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<SignCost>()) {}
   ~SignInfo() override = default;
 };
@@ -83,7 +84,7 @@ class FloorInfo : public ActivationOther {
  public:
   FloorInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
             const PrimitiveAttrs &attrs)
-      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
+      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<FloorCost>()) {}
   ~FloorInfo() override = default;
 };
@@ -91,7 +92,7 @@ class RoundInfo : public ActivationOther {
  public:
   RoundInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
             const PrimitiveAttrs &attrs)
-      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
+      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<RoundCost>()) {}
   ~RoundInfo() override = default;
 };
@@ -99,14 +100,14 @@ class ReciprocalInfo : public ActivationOther {
  public:
   ReciprocalInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                  const PrimitiveAttrs &attrs)
-      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
+      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReciprocalCost>()) {}
   ~ReciprocalInfo() override = default;
 };
 class InvInfo : public ActivationOther {
  public:
   InvInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
-      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
+      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<InvCost>()) {}
   ~InvInfo() override = default;
 };
@@ -114,21 +115,21 @@ class RsqrtInfo : public ActivationOther {
  public:
   RsqrtInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
             const PrimitiveAttrs &attrs)
-      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
+      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<RsqrtCost>()) {}
   ~RsqrtInfo() override = default;
 };
 class TanInfo : public ActivationOther {
  public:
   TanInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
-      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
+      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<TanCost>()) {}
   ~TanInfo() override = default;
 };
 class SinInfo : public ActivationOther {
  public:
   SinInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
-      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
+      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<SinCost>()) {}
   ~SinInfo() override = default;
 };
@@ -136,7 +137,7 @@ class SinhInfo : public ActivationOther {
  public:
   SinhInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
            const PrimitiveAttrs &attrs)
-      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
+      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<SinhCost>()) {}
   ~SinhInfo() override = default;
 };
@@ -144,7 +145,7 @@ class Log1pInfo : public ActivationOther {
  public:
   Log1pInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
             const PrimitiveAttrs &attrs)
-      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
+      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<Log1pCost>()) {}
   ~Log1pInfo() override = default;
 };
@@ -152,7 +153,7 @@ class Expm1Info : public ActivationOther {
  public:
   Expm1Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
             const PrimitiveAttrs &attrs)
-      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
+      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<Expm1Cost>()) {}
   ~Expm1Info() override = default;
 };
@@ -160,7 +161,7 @@ class CoshInfo : public ActivationOther {
  public:
   CoshInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
            const PrimitiveAttrs &attrs)
-      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
+      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<CoshCost>()) {}
   ~CoshInfo() override = default;
 };
@@ -168,7 +169,7 @@ class CeilInfo : public ActivationOther {
  public:
   CeilInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
            const PrimitiveAttrs &attrs)
-      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
+      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<CeilCost>()) {}
   ~CeilInfo() override = default;
 };
@@ -176,7 +177,7 @@ class AtanhInfo : public ActivationOther {
  public:
   AtanhInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
             const PrimitiveAttrs &attrs)
-      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
+      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<AtanhCost>()) {}
   ~AtanhInfo() override = default;
 };
@@ -184,7 +185,7 @@ class AtanInfo : public ActivationOther {
  public:
   AtanInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
            const PrimitiveAttrs &attrs)
-      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
+      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<AtanCost>()) {}
   ~AtanInfo() override = default;
 };
@@ -192,7 +193,7 @@ class AsinInfo : public ActivationOther {
  public:
   AsinInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
            const PrimitiveAttrs &attrs)
-      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
+      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<AsinCost>()) {}
   ~AsinInfo() override = default;
 };
@@ -200,7 +201,7 @@ class AsinhInfo : public ActivationOther {
  public:
   AsinhInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
             const PrimitiveAttrs &attrs)
-      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
+      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<AsinhCost>()) {}
   ~AsinhInfo() override = default;
 };
@@ -208,14 +209,14 @@ class AcoshInfo : public ActivationOther {
  public:
   AcoshInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
             const PrimitiveAttrs &attrs)
-      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
+      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<AcoshCost>()) {}
   ~AcoshInfo() override = default;
 };
 class ErfInfo : public ActivationOther {
  public:
   ErfInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
-      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
+      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<ErfCost>()) {}
   ~ErfInfo() override = default;
 };
@@ -223,7 +224,7 @@ class ErfcInfo : public ActivationOther {
  public:
   ErfcInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
            const PrimitiveAttrs &attrs)
-      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
+      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<ErfcCost>()) {}
   ~ErfcInfo() override = default;
 };
@@ -231,7 +232,7 @@ class ZerosLikeInfo : public ActivationOther {
  public:
   ZerosLikeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                 const PrimitiveAttrs &attrs)
-      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
+      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<ZerosLikeCost>()) {}
   ~ZerosLikeInfo() override = default;
 };
@@ -239,7 +240,7 @@ class OnesLikeInfo : public ActivationOther {
  public:
   OnesLikeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                const PrimitiveAttrs &attrs)
-      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
+      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<OnesLikeCost>()) {}
   ~OnesLikeInfo() override = default;
 };
@@ -247,7 +248,7 @@ class BesselI0eInfo : public ActivationOther {
  public:
   BesselI0eInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                 const PrimitiveAttrs &attrs)
-      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
+      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<BesselI0eCost>()) {}
   ~BesselI0eInfo() override = default;
 };
@@ -255,7 +256,7 @@ class BesselI1eInfo : public ActivationOther {
  public:
   BesselI1eInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                 const PrimitiveAttrs &attrs)
-      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
+      : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<BesselI1eCost>()) {}
   ~BesselI1eInfo() override = default;
 };
 } // namespace parallel
@@ -32,7 +32,7 @@ class GetNextInfo : public OperatorInfo {
  public:
   GetNextInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
               const PrimitiveAttrs &attrs)
-      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<GetNextCost>(false)) {}
+      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<GetNextCost>()) {}
   ~GetNextInfo() override = default;
   Status Init(const StrategyPtr &strategy) override;
@@ -33,7 +33,7 @@ class L2NormalizeInfo : public Activation {
  public:
   L2NormalizeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                   const PrimitiveAttrs &attrs)
-      : Activation(name, inputs_shape, outputs_shape, attrs) {}
+      : Activation(name, inputs_shape, outputs_shape, attrs, std::make_shared<L2NormalizeCost>()) {}
   ~L2NormalizeInfo() override = default;
   Status GenerateStrategies(int64_t stage_id) override;
@@ -40,7 +40,7 @@ class LayerNormInfo : public OperatorInfo {
  public:
   LayerNormInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                 const PrimitiveAttrs &attrs)
-      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<LayerNormCost>(true)),
+      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<LayerNormCost>()),
         begin_norm_axis_(0) {}
   ~LayerNormInfo() override = default;
@@ -36,8 +36,7 @@ class SoftmaxCrossEntropyWithLogitsInfo : public OperatorInfo {
  public:
   SoftmaxCrossEntropyWithLogitsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                                     const PrimitiveAttrs &attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs,
-                     std::make_shared<SoftmaxCrossEntropyWithLogitsCost>(false)) {}
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftmaxCrossEntropyWithLogitsCost>()) {}
   ~SoftmaxCrossEntropyWithLogitsInfo() override = default;
   Status Init(const StrategyPtr &strategy) override;
   Status InitForCostModel(const StrategyPtr &strategy) override;
@@ -34,7 +34,7 @@ class MatMulBase : public OperatorInfo {
  public:
   MatMulBase(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
              const PrimitiveAttrs &attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<MatMulCost>(true)) {}
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<MatMulCost>()) {}
   ~MatMulBase() override = default;
   Status Init(const StrategyPtr &strategy) override;
@@ -33,7 +33,7 @@ class OneHotInfo : public OperatorInfo {
  public:
   OneHotInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
              const PrimitiveAttrs &attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<OneHotCost>(false)) {}
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<OneHotCost>()) {}
   ~OneHotInfo() override = default;
   Status Init(const StrategyPtr &strategy) override;
   Status InitForCostModel(const StrategyPtr &strategy) override;
@@ -1204,6 +1204,20 @@ int64_t OperatorInfo::ComputeOpAndPrevEdgeParameterInvolved() {
   } else {
     is_output_parameter_involve_ = 0;
   }
+  // Set 'is_parameter_involve_' and 'is_output_parameter_involve_' into operatorCost, which are used in
+  // calculating 'inputs_in_memory' and 'output_in_memory', respectively.
+  operator_cost()->set_is_parameter_involve(is_parameter_involve_);
+  operator_cost()->set_output_parameter_involve(is_output_parameter_involve_);
+  // Calculating 'output_in_memory'
+  operator_cost()->CalculateOutputInMemory();
+  // Calculating 'inputs_in_memory'
+  std::map<size_t, bool> input_in_memory;
+  for (auto &p_edge : prev_edges) {
+    auto input_index = p_edge->next_op_input_index();
+    auto is_in_mem = p_edge->prev_operator()->operator_cost()->is_output_in_memory();
+    input_in_memory.emplace(std::make_pair(input_index, is_in_mem));
+  }
+  operator_cost()->CalculateInputsInMemory(input_in_memory);
   return is_output_parameter_involve_;
 }
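The hunk above wires the cost model into ComputeOpAndPrevEdgeParameterInvolved: after recording parameter involvement, it asks the cost to decide whether its output stays in memory, then collects, per input index, whether the producing operator keeps its output in memory. The patch does not show CalculateInputsInMemory itself; the sketch below illustrates one plausible semantics, assuming an input must stay in memory either when it is parameter-involved or when its producer keeps its output in memory (an assumption, not the actual implementation):

#include <cstddef>
#include <map>
#include <vector>

// Sketch only: combine per-input parameter involvement with the predecessors'
// output-in-memory flags, as collected in the hunk above.
std::vector<bool> InputsInMemory(const std::vector<bool> &is_parameter_involve,
                                 const std::map<size_t, bool> &prev_output_in_mem) {
  std::vector<bool> in_memory(is_parameter_involve.size(), false);
  for (size_t i = 0; i < is_parameter_involve.size(); ++i) {
    auto it = prev_output_in_mem.find(i);
    bool prev_in_mem = (it != prev_output_in_mem.end()) && it->second;
    in_memory[i] = is_parameter_involve[i] || prev_in_mem;
  }
  return in_memory;
}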
@@ -1220,14 +1234,10 @@ Status OperatorInfo::set_is_parameter(const std::vector<bool> &is_parameter) {
 }
 Status OperatorInfo::CalculateMemoryCost() {
-  // First, set the 'is_parameter_involve_' and 'is_output_parameter_involve_' into OperatorCost, which are necessary to
-  // calculate memory cost.
   if (is_parameter_involve_.size() != is_parameter_.size()) {
     MS_LOG(ERROR) << "'is_parameter_' does not have the same number of input size of 'is_parameter_involve_'.";
     return FAILED;
   }
-  operator_cost()->set_is_parameter_involve(is_parameter_involve_);
-  operator_cost()->set_output_parameter_involve(is_output_parameter_involve_);
   // Set the memory cost in the 'strategy_cost_'
   for (auto &swc : strategy_cost_) {
     auto mem_cost = operator_cost()->GetMemoryCost(swc->inputs_ptr, swc->outputs_ptr);
@@ -33,7 +33,7 @@ class PackInfo : public OperatorInfo {
  public:
   PackInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
            const PrimitiveAttrs &attrs)
-      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<PackCost>(false)) {}
+      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<PackCost>()) {}
   ~PackInfo() override = default;
   Status Init(const StrategyPtr &strategy) override;
@@ -35,7 +35,7 @@ class PReLUInfo : public OperatorInfo {
  public:
   PReLUInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
             const PrimitiveAttrs &attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<PReLUCost>(true)) {}
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<PReLUCost>()) {}
   ~PReLUInfo() override = default;
   Status Init(const StrategyPtr &strategy) override;
   Status InitForCostModel(const StrategyPtr &strategy) override;
@@ -39,7 +39,7 @@ class RangeInfo : public OperatorInfo {
  public:
   RangeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
             const PrimitiveAttrs &attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ActivationCost>(true)) {}
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<RangeCost>()) {}
   ~RangeInfo() override = default;
   Status Init(const StrategyPtr &strategy) override;
@@ -33,8 +33,8 @@ namespace parallel {
 class ReduceMethod : public OperatorInfo {
  public:
   ReduceMethod(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
-               const PrimitiveAttrs &attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReduceMethodCost>(true)) {}
+               const PrimitiveAttrs &attrs, OperatorCostPtr cost)
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, cost) {}
   ~ReduceMethod() override = default;
   Status Init(const StrategyPtr &strategy) override;
@@ -62,7 +62,7 @@ class ReduceMaxInfo : public ReduceMethod {
  public:
   ReduceMaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                 const PrimitiveAttrs &attrs)
-      : ReduceMethod(name, inputs_shape, outputs_shape, attrs) {
+      : ReduceMethod(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReduceMaxCost>()) {
     reduce_method_ = REDUCE_OP_MAX;
   }
@@ -73,7 +73,7 @@ class ArgMaxWithValueInfo : public ReduceMethod {
  public:
   ArgMaxWithValueInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                       const PrimitiveAttrs &attrs)
-      : ReduceMethod(name, inputs_shape, outputs_shape, attrs) {
+      : ReduceMethod(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArgMaxWithValueCost>()) {
     reduce_method_ = REDUCE_OP_MAX;
   }
@@ -105,9 +105,7 @@ class ReduceMeanInfo : public ReduceMethod {
  public:
   ReduceMeanInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                  const PrimitiveAttrs &attrs)
-      : ReduceMethod(name, inputs_shape, outputs_shape, attrs) {
-    set_cost(std::make_shared<ReduceMeanCost>());
-  }
+      : ReduceMethod(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReduceMeanCost>()) {}
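ReduceMeanInfo previously constructed the base with a default cost and patched it afterwards via set_cost(); threading the cost through the new ReduceMethod constructor parameter means it is already in place while the base class is being constructed. A toy illustration of the ordering difference (standalone types, not the MindSpore classes):

#include <iostream>
#include <memory>
#include <utility>

struct Cost { virtual ~Cost() = default; };

struct Base {
  explicit Base(std::shared_ptr<Cost> cost) : cost_(std::move(cost)) {
    // Safe: cost_ is already valid here. With the old set_cost() pattern it
    // would still be null at this point in the base constructor.
    std::cout << "cost set at base construction: " << (cost_ != nullptr) << '\n';
  }
  std::shared_ptr<Cost> cost_;
};

struct Derived : Base {
  // Mirrors the patch: the derived class supplies its own cost up front.
  Derived() : Base(std::make_shared<Cost>()) {}
};

int main() { Derived d; }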
   ~ReduceMeanInfo() override = default;
@@ -119,7 +117,7 @@ class ReduceSumInfo : public ReduceMethod {
  public:
   ReduceSumInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                 const PrimitiveAttrs &attrs)
-      : ReduceMethod(name, inputs_shape, outputs_shape, attrs) {
+      : ReduceMethod(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReduceSumCost>()) {
     reduce_method_ = REDUCE_OP_SUM;
   }
@@ -130,7 +128,7 @@ class ReduceMinInfo : public ReduceMethod {
  public:
   ReduceMinInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                 const PrimitiveAttrs &attrs)
-      : ReduceMethod(name, inputs_shape, outputs_shape, attrs) {
+      : ReduceMethod(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReduceMinCost>()) {
     reduce_method_ = REDUCE_OP_MIN;
   }
@@ -37,7 +37,7 @@ class ReLUV2Info : public OperatorInfo {
  public:
   ReLUV2Info(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
              const PrimitiveAttrs &attrs)
-      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<ActivationCost>(false)) {}
+      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<ReLUV2Cost>()) {}
   ~ReLUV2Info() override = default;
   Status Init(const StrategyPtr &strategy) override;
@@ -36,7 +36,7 @@ class ReshapeInfo : public OperatorInfo {
  public:
   ReshapeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
               const PrimitiveAttrs &attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReshapeCost>(false)),
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReshapeCost>()),
         dev_num_(0),
         pre_operator_index_(0),
         next_operator_index_(0),
@@ -34,7 +34,7 @@ class SliceInfo : public OperatorInfo {
  public:
   SliceInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
             const PrimitiveAttrs &attrs)
-      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<SliceCost>(false)),
+      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<SliceCost>()),
         slice_axis_(-1) {}
   ~SliceInfo() override = default;
@@ -31,7 +31,7 @@ class SplitInfo : public OperatorInfo {
  public:
   SplitInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
             const PrimitiveAttrs &attrs)
-      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<ConcatCost>(false)) {}
+      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<SplitCost>()) {}
   ~SplitInfo() override = default;
   Status Init(const StrategyPtr &strategy) override;
@@ -273,6 +273,8 @@ Status StridedSliceInfo::GenerateStrategies(int64_t stage_id) {
       PrintStrategy(sp);
     }
   }
+  MS_LOG(INFO) << name() << ", finishing GenerateStrategies().";
   return SUCCESS;
 }
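The GenerateStrategies hunk above adds a completion log before the success return. A self-contained sketch of that exit-log pattern, with std::cout standing in for the MS_LOG(INFO) macro:

// Sketch only: std::cout stands in for MS_LOG(INFO).
#include <iostream>
#include <string>

enum Status { SUCCESS, FAILED };

Status GenerateStrategiesSketch(const std::string &op_name) {
  // ... enumerate and print candidate strategies here ...
  // Log on the way out so strategy generation can be traced per operator.
  std::cout << op_name << ", finishing GenerateStrategies()." << std::endl;
  return SUCCESS;
}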
@@ -34,7 +34,7 @@ class StridedSliceInfo : public OperatorInfo {
  public:
   StridedSliceInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                    const PrimitiveAttrs &attrs)
-      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<StridedSliceCost>(false)) {}
+      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<StridedSliceCost>()) {}
   ~StridedSliceInfo() override = default;
   Status Init(const StrategyPtr &strategy) override;
@@ -41,7 +41,7 @@ class TensorDotInfo : public OperatorInfo {
  public:
   TensorDotInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                 const PrimitiveAttrs &attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<MatMulCost>(true)) {}
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<TensorDotCost>()) {}
   ~TensorDotInfo() override = default;
   Status Init(const StrategyPtr &strategy) override;
@@ -34,7 +34,7 @@ class TileInfo : public OperatorInfo {
  public:
   TileInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
            const PrimitiveAttrs &attrs)
-      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<TileCost>(false)) {}
+      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<TileCost>()) {}
   ~TileInfo() override = default;
   Status Init(const StrategyPtr &strategy) override;
@@ -34,7 +34,7 @@ class TmpIdentityInfo : public OperatorInfo {
  public:
   TmpIdentityInfo(const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs,
                   const std::string &name = IDENTITY_INFO)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<TmpIdentityCost>(false)) {}
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<TmpIdentityCost>()) {}
   ~TmpIdentityInfo() override = default;
   Status Init(const StrategyPtr &strategy) override;
@@ -35,7 +35,7 @@ class TransposeInfo : public OperatorInfo {
  public:
   TransposeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                 const PrimitiveAttrs &attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<TransposeCost>(false)) {}
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<TransposeCost>()) {}
   ~TransposeInfo() override = default;
   Status Init(const StrategyPtr &strategy) override;
   Status InitForCostModel(const StrategyPtr &strategy) override;
@@ -32,7 +32,7 @@ class UniqueInfo : public OperatorInfo {
  public:
   UniqueInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
              const PrimitiveAttrs &attrs)
-      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<GetNextCost>(false)) {}
+      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<UniqueCost>()) {}
   ~UniqueInfo() override = default;
   Status Init(const StrategyPtr &strategy) override;
@@ -82,7 +82,7 @@ class UnsortedSegmentMaxInfo : public UnsortedSegmentOpInfo {
  public:
   UnsortedSegmentMaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                          const PrimitiveAttrs &attrs)
-      : UnsortedSegmentOpInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<UnsortedSegmentMinCost>()) {}
+      : UnsortedSegmentOpInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<UnsortedSegmentMaxCost>()) {}
   ~UnsortedSegmentMaxInfo() override = default;
   ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
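The UnsortedSegmentMaxInfo hunk above is the one outright bug fix in this batch: the Max operator had been wired to UnsortedSegmentMinCost. A hedged sketch (not in the MindSpore codebase) of one way to make the Info-to-Cost pairing a compile-time property, so that kind of copy-paste slip cannot build:

// Sketch only: a trait maps each Info type to its one legal Cost type.
#include <memory>

struct UnsortedSegmentMinCost {};
struct UnsortedSegmentMaxCost {};
class UnsortedSegmentMinInfo;  // forward declarations suffice as trait keys
class UnsortedSegmentMaxInfo;

template <typename Info>
struct CostOf;  // primary template left undefined: unmapped pairs do not compile
template <> struct CostOf<UnsortedSegmentMinInfo> { using type = UnsortedSegmentMinCost; };
template <> struct CostOf<UnsortedSegmentMaxInfo> { using type = UnsortedSegmentMaxCost; };

template <typename Info>
std::shared_ptr<typename CostOf<Info>::type> MakeCostFor() {
  return std::make_shared<typename CostOf<Info>::type>();
}

// Usage inside a constructor initializer list:
//   : UnsortedSegmentOpInfo(name, ..., MakeCostFor<UnsortedSegmentMaxInfo>()) {}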
@@ -32,7 +32,7 @@ class VirtualDatasetInfo : public OperatorInfo {
  public:
   VirtualDatasetInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                      const PrimitiveAttrs &attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<VirtualDatasetCost>(false)) {}
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<VirtualDatasetCost>()) {}
   ~VirtualDatasetInfo() override = default;
   Status Init(const StrategyPtr &strategy) override;
   Status InitForCostModel(const StrategyPtr &strategy) override;
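The final hunk renames the test fixture's member from ActivationCost to ActivationInfoCost, tracking the cost-class rename on the implementation side. A self-contained sketch of the same fixture shape, with plain GoogleTest standing in for the UT::Common scaffolding and a stand-in cost type:

// Sketch only: GoogleTest stands in for UT::Common; the cost type is a stub.
#include <gtest/gtest.h>

struct ActivationInfoCostSketch {
  double memory_cost = 0.0;  // stand-in for the real cost accessors
};

class TestActivationInfoCostSketch : public ::testing::Test {
 protected:
  void SetUp() override { ac_cost_ = ActivationInfoCostSketch(); }
  ActivationInfoCostSketch ac_cost_;  // renamed member, mirroring the hunk below
};

TEST_F(TestActivationInfoCostSketch, DefaultMemoryCostIsZero) {
  EXPECT_DOUBLE_EQ(0.0, ac_cost_.memory_cost);
}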
@@ -85,11 +85,11 @@ class TestActivationCost : public UT::Common {
   TestActivationCost() {}
   void SetUp();
   void TearDown();
-  ActivationCost ac_cost_;
+  ActivationInfoCost ac_cost_;
 };
 void TestActivationCost::SetUp() {
-  ac_cost_ = ActivationCost();
+  ac_cost_ = ActivationInfoCost();
   RankList dev_list;
   for (int32_t i = 0; i < 1050; i++) {