@@ -514,60 +514,6 @@ double ArithmeticCost::GetBackwardCommCost(const std::vector<TensorInfo>& inputs
   return result;
 }
-double L2NormalizeCost::GetBackwardCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&,
-                                            const int32_t& stage_id) const {
-  double result = 0.0;
-  if (is_parameter_[0]) {
-    TensorInfo input_tensor_info = inputs[0];
-    CheckGlobalDeviceManager();
-    MS_EXCEPTION_IF_NULL(g_device_manager);
-    auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
-    Shape input_shape = input_tensor_info.shape();
-    Shape input_slice_shape = input_tensor_info.slice_shape();
-    int32_t used_device_num = 1;
-    for (size_t i = 0; i < input_shape.size(); ++i) {
-      used_device_num *= input_shape[i] / input_slice_shape[i];
-    }
-    if (total_device_num != IntToSize(used_device_num))
-      result += ListProduct(input_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
-  }
-  return result;
-}
-double L2NormalizeCost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&,
-                                                  const int32_t&) const {
-  TensorInfo input0_info = inputs[0];
-  Shape input0_slice_shape = input0_info.slice_shape();
-  return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
-}
-double L2NormalizeCost::GetBackwardComputationCost(const std::vector<TensorInfo>& inputs,
-                                                   const std::vector<TensorInfo>&, const int32_t& stage_id) const {
-  double result = 0.0;
-  if (is_parameter_[0]) {
-    TensorInfo input_tensor_info = inputs[0];
-    CheckGlobalDeviceManager();
-    MS_EXCEPTION_IF_NULL(g_device_manager);
-    auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
-    Shape input_shape = input_tensor_info.shape();
-    Shape input_slice_shape = input_tensor_info.slice_shape();
-    int32_t used_device_num = 1;
-    for (size_t i = 0; i < input_shape.size(); ++i) {
-      used_device_num *= input_shape[i] / input_slice_shape[i];
-    }
-    if (total_device_num != IntToSize(used_device_num))
-      result += ListProduct(input_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
-  }
-  return result;
-}
 bool IsDataParallel(const Shape& shape, const Shape& slice_shape, const int32_t& stage_id) {
   CheckGlobalDeviceManager();
   MS_EXCEPTION_IF_NULL(g_device_manager);
@@ -132,6 +132,8 @@ class ActivationCost : public OperatorCost {
 };
 using ActivationCostPtr = std::shared_ptr<ActivationCost>;
+using TransposeCost = ActivationCost;
+using TransposeCostPtr = std::shared_ptr<TransposeCost>;
 class SoftmaxCost : public OperatorCost {
  public:
@@ -415,32 +417,8 @@ class ArithmeticCost : public OperatorCost {
                              const int32_t& stage_id) const override;
 };
 using ArithmeticCostPtr = std::shared_ptr<ArithmeticCost>;
-class L2NormalizeCost : public OperatorCost {
- public:
-  L2NormalizeCost() = default;
-  ~L2NormalizeCost() override = default;
-  double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
-                     const int32_t& stage_id) const override {
-    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
-  }
-  double GetForwardCommCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&,
-                            const int32_t&) const override {
-    return 0.0;
-  }
-  double GetBackwardCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
-                             const int32_t& stage_id) const override;
-  double GetComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
-                            const int32_t& stage_id) const override {
-    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
-  }
-  double GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
-                                   const int32_t& stage_id) const override;
-  double GetBackwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
-                                    const int32_t& stage_id) const override;
-};
-using L2NormalizeCostPtr = std::shared_ptr<L2NormalizeCost>;
+using BiasAddCost = ArithmeticCost;
+using BiasAddCostPtr = std::shared_ptr<BiasAddCost>;
 class ReduceMethodCost : public OperatorCost {
  public:
@@ -32,8 +32,8 @@ namespace parallel {
 class ActivationBase : public OperatorInfo {
  public:
   ActivationBase(const std::string& operator_name, const Shapes& inputs_shape, const Shapes& outputs_shape,
-                 const PrimitiveAttrs& attrs)
-      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs) {}
+                 const PrimitiveAttrs& attrs, OperatorCostPtr cost)
+      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, cost) {}
   ~ActivationBase() override = default;
   Status Init(const StrategyPtr& strategy) override;
@@ -51,19 +51,13 @@ class Activation : public ActivationBase {
  public:
   Activation(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
              const PrimitiveAttrs& attrs)
-      : ActivationBase(name, inputs_shape, outputs_shape, attrs) {
-    ac_cost_ptr_ = std::make_shared<ActivationCost>();
-  }
+      : ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ActivationCost>()) {}
   ~Activation() override = default;
   Status GenerateStrategies(int32_t stage_id) override;
   Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
-  OperatorCostPtr GetOperatorCost() const override { return ac_cost_ptr_; }
  protected:
   Status CheckStrategy(const StrategyPtr& strategy) override;
- private:
-  ActivationCostPtr ac_cost_ptr_;
 };
 class ActivationInfo : public Activation {
@@ -108,13 +102,10 @@ class Softmax : public ActivationBase {
  public:
   explicit Softmax(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
                    const PrimitiveAttrs& attrs)
-      : ActivationBase(name, inputs_shape, outputs_shape, attrs) {
-    sm_cost_ptr_ = std::make_shared<SoftmaxCost>();
-  }
+      : ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftmaxCost>()) {}
   ~Softmax() override = default;
   Status GenerateStrategies(int32_t stage_id) override;
   Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
-  OperatorCostPtr GetOperatorCost() const override { return sm_cost_ptr_; }
  protected:
   Status CheckStrategy(const StrategyPtr& strategy) override;
@@ -122,7 +113,6 @@ class Softmax : public ActivationBase {
  private:
   std::vector<int32_t> axis_;
-  SoftmaxCostPtr sm_cost_ptr_;
 };
 class SoftmaxInfo : public Softmax {
@@ -33,15 +33,12 @@ class ArithmeticBase : public OperatorInfo {
  public:
   ArithmeticBase(const std::string& operator_name, const Shapes& inputs_shape, const Shapes& outputs_shape,
                  const PrimitiveAttrs& attrs)
-      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs) {
-    arithmeticcost_ptr_ = std::make_shared<ArithmeticCost>();
-  }
+      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>()) {}
   ~ArithmeticBase() override = default;
   Status Init(const StrategyPtr& strategy) override;
   Status InitForCostModel(const StrategyPtr& strategy) override;
   Status GenerateStrategies(int32_t) override;
   Status SetCostUnderStrategy(const StrategyPtr&) override;
-  OperatorCostPtr GetOperatorCost() const override { return arithmeticcost_ptr_; }
   void ReComputeBatchSplitFlagList() override;
  protected:
@@ -54,7 +51,6 @@ class ArithmeticBase : public OperatorInfo {
   Status InferTensorMap() override;
   Status InferTensorLayout(TensorLayouts* inputs_layout, TensorLayouts* outputs_layout, const Shape& dev_matrix_array);
   Shapes InferExpendShape();
-  ArithmeticCostPtr arithmeticcost_ptr_;
 };
 class SubInfo : public ArithmeticBase {
@@ -31,16 +31,13 @@ class BatchParallelInfo : public OperatorInfo {
  public:
   BatchParallelInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
                     const PrimitiveAttrs& attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs), dev_num_(1) {
-    bp_cost_ptr_ = std::make_shared<BatchParallelCost>();
-  }
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>()), dev_num_(1) {}
   ~BatchParallelInfo() override = default;
   Status Init(const StrategyPtr& strategy) override;
   Status InitForCostModel(const StrategyPtr& strategy) override;
   Status GenerateStrategies(int32_t stage_id) override;
   Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
-  OperatorCostPtr GetOperatorCost() const override { return bp_cost_ptr_; }
  protected:
   Status CheckStrategy(const StrategyPtr& strategy) override;
@@ -55,7 +52,6 @@ class BatchParallelInfo : public OperatorInfo {
  private:
   int32_t dev_num_;
-  BatchParallelCostPtr bp_cost_ptr_;
 };
 class SparseSoftmaxCrossEntropyWithLogitsInfo : public BatchParallelInfo {
@@ -34,16 +34,13 @@ class BiasAddInfo : public OperatorInfo {
  public:
   BiasAddInfo(const std::string& operator_name, const Shapes& inputs_shape, const Shapes& outputs_shape,
               const PrimitiveAttrs& attrs)
-      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs) {
-    biasaddcost_ptr_ = std::make_shared<ArithmeticCost>();
-  }
+      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<BiasAddCost>()) {}
   ~BiasAddInfo() override = default;
   Status Init(const StrategyPtr& strategy) override;
   Status InitForCostModel(const StrategyPtr& strategy) override;
   Status GenerateStrategies(int32_t) override;
   Status SetCostUnderStrategy(const StrategyPtr&) override;
-  OperatorCostPtr GetOperatorCost() const override { return biasaddcost_ptr_; }
   void ReComputeBatchSplitFlagList() override;
  protected:
@@ -55,7 +52,6 @@ class BiasAddInfo : public OperatorInfo {
   Status InferDevMatrixShape() override;
   Status InferTensorMap() override;
   Status InferTensorLayout(TensorLayouts* inputs_layout, TensorLayouts* outputs_layout, const Shape& dev_matrix_array);
-  ArithmeticCostPtr biasaddcost_ptr_;
 };
 } // namespace parallel
 } // namespace mindspore
@@ -33,15 +33,12 @@ class DropoutDoMaskInfo : public OperatorInfo {
  public:
   DropoutDoMaskInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
                     const PrimitiveAttrs& attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs) {
-    bpcost_ptr_ = std::make_shared<BatchParallelCost>();
-  }
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>()) {}
   ~DropoutDoMaskInfo() override = default;
   Status Init(const StrategyPtr& strategy) override;
   Status GenerateStrategies(int32_t stage_id) override;
   Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
-  OperatorCostPtr GetOperatorCost() const override { return bpcost_ptr_; }
   Status InitForCostModel(const StrategyPtr& strategy) override;
   std::shared_ptr<std::vector<std::vector<int32_t>>> GenerateBatchStrategies() override;
@@ -53,9 +50,6 @@ class DropoutDoMaskInfo : public OperatorInfo {
   Status GetAttrs() override { return SUCCESS; }
   Status InferTensorInfo() override;
   Status InferDevMatrixShape() override;
- private:
-  BatchParallelCostPtr bpcost_ptr_;
 };
 } // namespace parallel
 } // namespace mindspore
@@ -32,15 +32,12 @@ class GeneratorBase : public OperatorInfo {
  public:
   GeneratorBase(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                 const PrimitiveAttrs &attrs)
-      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs) {
-    generatorbasecost_ptr_ = std::make_shared<GeneratorBaseCost>();
-  }
+      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<GeneratorBaseCost>()) {}
   ~GeneratorBase() override = default;
   Status Init(const StrategyPtr &strategy) override;
   Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
-  OperatorCostPtr GetOperatorCost() const override { return generatorbasecost_ptr_; }
   Status InitForCostModel(const StrategyPtr &strategy) override;
  protected:
@@ -52,7 +49,6 @@ class GeneratorBase : public OperatorInfo {
   Status InferMirrorOps() override { return SUCCESS; }
   Status InferForwardCommunication() override { return SUCCESS; }
   virtual Status InferReplaceOps(const StrategyPtr &strategy) = 0;
-  GeneratorBaseCostPtr generatorbasecost_ptr_;
 };
 class DropoutGenMaskInfo : public GeneratorBase {
@@ -32,14 +32,11 @@ class GetNextInfo : public OperatorInfo {
  public:
   GetNextInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
               const PrimitiveAttrs &attrs)
-      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs) {
-    getnextcost_ptr_ = std::make_shared<GetNextCost>();
-  }
+      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<GetNextCost>()) {}
   ~GetNextInfo() override = default;
   Status Init(const StrategyPtr &strategy) override;
   Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
-  OperatorCostPtr GetOperatorCost() const override { return getnextcost_ptr_; }
   Status InitForCostModel(const StrategyPtr &strategy) override;
   Status GenerateStrategies(int32_t stage_id) override;
@@ -65,7 +62,6 @@ class GetNextInfo : public OperatorInfo {
   Shapes shapes_;
   int32_t output_num_ = 0;
   std::string shared_name_;
-  GetNextCostPtr getnextcost_ptr_;
 };
 } // namespace parallel
 } // namespace mindspore
@@ -33,12 +33,9 @@ class L2NormalizeInfo : public Activation {
  public:
   L2NormalizeInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
                   const PrimitiveAttrs& attrs)
-      : Activation(name, inputs_shape, outputs_shape, attrs) {
-    l2normalizecost_ptr_ = std::make_shared<L2NormalizeCost>();
-  }
+      : Activation(name, inputs_shape, outputs_shape, attrs) {}
   ~L2NormalizeInfo() override = default;
   Status GenerateStrategies(int32_t stage_id) override;
-  OperatorCostPtr GetOperatorCost() const override { return l2normalizecost_ptr_; }
  protected:
   Status GetAttrs() override;
@@ -47,7 +44,6 @@ class L2NormalizeInfo : public Activation {
  private:
   int32_t axis_ = 0;  // Default value = 0
-  L2NormalizeCostPtr l2normalizecost_ptr_;
 };
 } // namespace parallel
 } // namespace mindspore
@@ -36,16 +36,13 @@ class SoftmaxCrossEntropyWithLogitsInfo : public OperatorInfo {
  public:
   SoftmaxCrossEntropyWithLogitsInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
                                     const PrimitiveAttrs& attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs) {
-    softmax_loss_cost_ptr_ = std::make_shared<SoftmaxCrossEntropyWithLogitsCost>();
-  }
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftmaxCrossEntropyWithLogitsCost>()) {}
   ~SoftmaxCrossEntropyWithLogitsInfo() override = default;
   Status Init(const StrategyPtr& strategy) override;
   Status InitForCostModel(const StrategyPtr& strategy) override;
   Status GenerateStrategies(int32_t stage_id) override;
   Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
-  OperatorCostPtr GetOperatorCost() const override { return softmax_loss_cost_ptr_; }
   void ReComputeBatchSplitFlagList() override;
  protected:
@@ -59,7 +56,6 @@ class SoftmaxCrossEntropyWithLogitsInfo : public OperatorInfo {
   // There are two outputs for SoftmaxCrossEntropyWithLogits, and outputs[1] is used for grad and overload
   // the InferAsLossDivisor.
   Status InferAsLossDivisor() override;
-  SoftmaxCrossEntropyWithLogitsCostPtr softmax_loss_cost_ptr_;
  private:
   int32_t axis_ = -1;  // default -1
@@ -593,11 +593,11 @@ Status MatMulBase::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr&
   // Here, we use the origin outputs_, because we only use the slice size of the output tensor.
   // It does not matter whether the output tensor is transposed or not.
   double computation_cost =
-    matmulcost_ptr->GetForwardComputationCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
-  double communication_cost = matmulcost_ptr->GetCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
+    cost()->GetForwardComputationCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
+  double communication_cost = cost()->GetCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
   std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost);
   result->communication_without_parameter_ =
-    matmulcost_ptr->GetForwardCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
+    cost()->GetForwardCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
   result->communication_with_partial_para_ =
     result->communication_without_parameter_ +
     COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_);
@@ -34,9 +34,7 @@ class MatMulBase : public OperatorInfo {
  public:
   MatMulBase(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
              const PrimitiveAttrs& attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs) {
-    matmulcost_ptr = std::make_shared<MatMulCost>();
-  }
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<MatMulCost>()) {}
   ~MatMulBase() override = default;
   Status Init(const StrategyPtr& strategy) override;
@@ -48,7 +46,6 @@ class MatMulBase : public OperatorInfo {
   Status PrepareStrategy(int32_t stage_id, size_t dev_num, Dimensions combined_partitions, size_t input0_shape_size,
                          size_t input1_shape_size, StrategyPtr* sp);
-  OperatorCostPtr GetOperatorCost() const override { return matmulcost_ptr; }
   Status SwapLastTwoElements(Shape* shape);
  protected:
@@ -66,8 +63,6 @@ class MatMulBase : public OperatorInfo {
   bool transpose_b_ = false;
   size_t mat_a_dimension_ = 0;
   size_t mat_b_dimension_ = 0;
-  MatMulCostPtr matmulcost_ptr;
 };
 class MatMul : public MatMulBase {
@@ -33,16 +33,13 @@ class OneHotInfo : public OperatorInfo {
  public:
   OneHotInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
              const PrimitiveAttrs& attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs) {
-    onehot_cost_ptr_ = std::make_shared<OneHotCost>();
-  }
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<OneHotCost>()) {}
   ~OneHotInfo() override = default;
   Status Init(const StrategyPtr& strategy) override;
   Status InitForCostModel(const StrategyPtr& strategy) override;
   Status GenerateStrategies(int32_t stage_id) override;
   Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
-  OperatorCostPtr GetOperatorCost() const override { return onehot_cost_ptr_; }
   ReplaceGraphPtr replace_graph(const CNodePtr& cnode) override;
   std::shared_ptr<std::vector<std::vector<int32_t>>> GenerateBatchStrategies() override;
@@ -60,7 +57,6 @@ class OneHotInfo : public OperatorInfo {
   Status ComputeReplaceGraph(const CNodePtr& cnode);
   int axis_ = -1;
-  OneHotCostPtr onehot_cost_ptr_;
   int32_t rank_ = 0;
   int32_t total_class_number_ = 1;
   int32_t classes_each_device_ = 1;
@@ -1034,12 +1034,11 @@ Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr& strategy) {
     return FAILED;
   }
   int32_t stage_id = strategy->GetInputStage();
-  double computation_cost =
-    GetOperatorCost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
-  double communication_cost = GetOperatorCost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
+  double computation_cost = cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
+  double communication_cost = cost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
   std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost);
   result->communication_without_parameter_ =
-    GetOperatorCost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
+    cost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
   result->communication_with_partial_para_ =
     result->communication_without_parameter_ +
     COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_);
@@ -1096,7 +1095,7 @@ Status OperatorInfo::set_is_parameter(const std::vector<bool>& is_parameter) {
     return FAILED;
   }
   is_parameter_ = is_parameter;
-  GetOperatorCost()->set_is_parameter(is_parameter);
+  cost()->set_is_parameter(is_parameter);
   return SUCCESS;
 }
@@ -1193,7 +1192,7 @@ Status OperatorInfo::SetInputAndOutputTypeLength(const std::vector<size_t>& inpu
   }
   inputs_type_lengths_ = input_lengths;
   outputs_type_lengths_ = output_lengths;
-  GetOperatorCost()->SetInputAndOutputTypeLength(input_lengths, output_lengths);
+  cost()->SetInputAndOutputTypeLength(input_lengths, output_lengths);
   return SUCCESS;
 }
@@ -1211,7 +1210,7 @@ void OperatorInfo::BreakingTiesForPerferringDataParallel(const StrategyPtr& stra
 }
 double OperatorInfo::GetForwardMemoryCostFromCNode() {
-  return GetOperatorCost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, 0);
+  return cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, 0);
 }
 } // namespace parallel
@@ -53,12 +53,13 @@ class Edge;
 class OperatorInfo {
  public:
-  OperatorInfo(std::string name, Shapes inputs_shape, Shapes outputs_shape, PrimitiveAttrs attrs)
+  OperatorInfo(std::string name, Shapes inputs_shape, Shapes outputs_shape, PrimitiveAttrs attrs, OperatorCostPtr cost)
       : name_(std::move(name)),
         inputs_shape_(std::move(inputs_shape)),
        outputs_shape_(std::move(outputs_shape)),
        attrs_(std::move(attrs)),
-        is_alive_(true) {
+        is_alive_(true),
+        cost_(cost) {
     std::vector<bool> not_parameteter(inputs_shape_.size(), false);
     is_parameter_ = not_parameteter;
     refkey_parameter_name_ = "";
@@ -75,7 +76,8 @@ class OperatorInfo {
   // Given the stage_id (which indicates the number of devices),
   // generate all strategies for this operator
   virtual Status GenerateStrategies(int32_t stage_id) = 0;
-  virtual OperatorCostPtr GetOperatorCost() const = 0;
+  const OperatorCostPtr& cost() const { return cost_; }
+  void set_cost(const OperatorCostPtr& cost) { cost_ = cost; }
   virtual Status SetCostUnderStrategy(const StrategyPtr& strategy) = 0;
   virtual std::shared_ptr<std::vector<std::vector<int32_t>>> GenerateBatchStrategies();
@@ -115,7 +117,7 @@ class OperatorInfo {
   void ReplaceSuccEdge(const std::shared_ptr<OperatorInfo>& op, const std::shared_ptr<Edge>& new_edge);
   void ReplacePreEdges(const std::shared_ptr<OperatorInfo>& op, const std::shared_ptr<Edge>& new_edge);
   void ReplaceSuccEdges(const std::shared_ptr<OperatorInfo>& op, const std::shared_ptr<Edge>& new_edge);
-  std::vector<size_t> GetOutputTypeLengths() const { return GetOperatorCost()->outputs_type_lengths(); }
+  std::vector<size_t> GetOutputTypeLengths() const { return cost()->outputs_type_lengths(); }
   void SetSelectedStrategyAndCost(const StrategyPtr& s_strategy, const CostPtr& cost) {
     selected_strategy_ = s_strategy;
     selected_cost_ = cost;
@@ -221,6 +223,9 @@ class OperatorInfo {
   std::string refkey_parameter_name_;
   CNodePtr cnode_;
   int32_t used_devices_ = -1;
+ private:
+  OperatorCostPtr cost_;
 };
 Shape GetSliceShape(const Shape& tensor_shape, const Dimensions& strategy);
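The hunks above are the core of the change: each operator used to carry its own typed cost member plus a virtual GetOperatorCost() override, and the cost object is now injected once through the OperatorInfo base constructor and exposed through cost()/set_cost(). The following is a minimal standalone sketch of that constructor-injection pattern; the stub types here (OperatorCost, ActivationCost, ReduceMeanCost and the simplified OperatorInfo hierarchy) are illustrative stand-ins written for this note, not the real MindSpore declarations.

// Simplified, self-contained sketch of the pattern (not the actual MindSpore classes).
#include <iostream>
#include <memory>
#include <string>

struct OperatorCost {
  virtual ~OperatorCost() = default;
  virtual std::string Name() const { return "OperatorCost"; }
};
using OperatorCostPtr = std::shared_ptr<OperatorCost>;

struct ActivationCost : OperatorCost {
  std::string Name() const override { return "ActivationCost"; }
};

struct ReduceMeanCost : OperatorCost {
  std::string Name() const override { return "ReduceMeanCost"; }
};

class OperatorInfo {
 public:
  // The cost object is injected once, at construction time.
  OperatorInfo(std::string name, OperatorCostPtr cost) : name_(std::move(name)), cost_(std::move(cost)) {}
  virtual ~OperatorInfo() = default;
  const OperatorCostPtr& cost() const { return cost_; }
  void set_cost(const OperatorCostPtr& cost) { cost_ = cost; }
  const std::string& name() const { return name_; }

 private:
  std::string name_;
  OperatorCostPtr cost_;  // owned by the base class; no per-derived-class cost members
};

// A derived operator supplies its cost type through the base constructor.
class ActivationInfo : public OperatorInfo {
 public:
  explicit ActivationInfo(const std::string& name) : OperatorInfo(name, std::make_shared<ActivationCost>()) {}
};

// A derived operator can still replace the default cost later, as ReduceMeanInfo does with set_cost().
class ReduceMeanInfo : public OperatorInfo {
 public:
  explicit ReduceMeanInfo(const std::string& name) : OperatorInfo(name, std::make_shared<ActivationCost>()) {
    set_cost(std::make_shared<ReduceMeanCost>());
  }
};

int main() {
  ActivationInfo relu("ReLU");
  ReduceMeanInfo mean("ReduceMean");
  // Callers no longer need a virtual GetOperatorCost(); they read the injected cost directly.
  std::cout << relu.name() << " uses " << relu.cost()->Name() << std::endl;
  std::cout << mean.name() << " uses " << mean.cost()->Name() << std::endl;
  return 0;
}

Moving the pointer into the base class removes one pure-virtual accessor and a dozen near-identical per-operator members; the trade-off is that code needing the concrete cost type must now down-cast, as the ReduceMethod::GetAttrs hunk below shows.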
@@ -35,15 +35,12 @@ class PReLUInfo : public OperatorInfo {
  public:
   PReLUInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
             const PrimitiveAttrs& attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs) {
-    prelucost_ptr = std::make_shared<PReLUCost>();
-  }
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<PReLUCost>()) {}
   ~PReLUInfo() override = default;
   Status Init(const StrategyPtr& strategy) override;
   Status InitForCostModel(const StrategyPtr& strategy) override;
   Status GenerateStrategies(int32_t stage_id) override;
-  OperatorCostPtr GetOperatorCost() const override { return prelucost_ptr; }
   Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
  protected:
@@ -59,7 +56,6 @@ class PReLUInfo : public OperatorInfo {
  private:
   Dimensions input_strategy_;
-  PReLUCostPtr prelucost_ptr;
 };
 } // namespace parallel
 } // namespace mindspore
@@ -109,8 +109,12 @@ Status ReduceMethod::GetAttrs() {
     }
     cross_batch_ = cross_batch_iter->second->cast<BoolImmPtr>()->value();
   }
-  reducemethodcost_ptr_->set_cross_batch(cross_batch_);
+  auto reducemethodcost = std::dynamic_pointer_cast<ReduceMethodCost>(cost());
+  if (reducemethodcost == nullptr) {
+    MS_LOG(ERROR) << "Cost cast to ReduceMethodCostPtr failed!";
+    return FAILED;
+  }
+  reducemethodcost->set_cross_batch(cross_batch_);
   return SUCCESS;
 }
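Because only the base OperatorCostPtr is stored after this refactor, GetAttrs can no longer call the ReduceMethodCost-specific setter directly; it down-casts with std::dynamic_pointer_cast and fails cleanly if the stored cost is not a ReduceMethodCost. A small self-contained sketch of that checked-cast idiom follows; the stub types and the ApplyCrossBatch helper are hypothetical, written only to illustrate the pattern, and are not MindSpore code.

// Minimal sketch of the checked down-cast used above (hypothetical stubs).
#include <iostream>
#include <memory>

struct OperatorCost {
  virtual ~OperatorCost() = default;  // polymorphic base, required for dynamic_pointer_cast
};
using OperatorCostPtr = std::shared_ptr<OperatorCost>;

struct ReduceMethodCost : OperatorCost {
  void set_cross_batch(bool flag) { cross_batch_ = flag; }
  bool cross_batch_ = false;
};

bool ApplyCrossBatch(const OperatorCostPtr& cost, bool cross_batch) {
  auto reduce_cost = std::dynamic_pointer_cast<ReduceMethodCost>(cost);
  if (reduce_cost == nullptr) {
    std::cerr << "Cost cast to ReduceMethodCost failed!" << std::endl;
    return false;  // mirrors the FAILED return in GetAttrs()
  }
  reduce_cost->set_cross_batch(cross_batch);
  return true;
}

int main() {
  OperatorCostPtr ok = std::make_shared<ReduceMethodCost>();
  OperatorCostPtr wrong = std::make_shared<OperatorCost>();
  std::cout << std::boolalpha << ApplyCrossBatch(ok, true) << " " << ApplyCrossBatch(wrong, true) << std::endl;
  return 0;
}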
@@ -34,9 +34,7 @@ class ReduceMethod : public OperatorInfo {
  public:
   ReduceMethod(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                const PrimitiveAttrs &attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs) {
-    reducemethodcost_ptr_ = std::make_shared<ReduceMethodCost>();
-  }
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReduceMethodCost>()) {}
   ~ReduceMethod() override = default;
   Status Init(const StrategyPtr &strategy) override;
@@ -44,13 +42,11 @@ class ReduceMethod : public OperatorInfo {
   Status GenerateStrategies(int32_t stage_id) override;
   Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
-  OperatorCostPtr GetOperatorCost() const override { return reducemethodcost_ptr_; }
  protected:
   std::string reduce_method_;
   bool keepdims_ = false;
   bool cross_batch_ = false;
-  ReduceMethodCostPtr reducemethodcost_ptr_;
   Status CheckStrategy(const StrategyPtr &strategy) override;
   Status GetAttrs() override;
   Dimensions InferOutputStrategy();
@@ -110,7 +106,7 @@ class ReduceMeanInfo : public ReduceMethod {
   ReduceMeanInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                  const PrimitiveAttrs &attrs)
       : ReduceMethod(name, inputs_shape, outputs_shape, attrs) {
-    reducemethodcost_ptr_ = std::make_shared<ReduceMeanCost>();
+    set_cost(std::make_shared<ReduceMeanCost>());
   }
   ~ReduceMeanInfo() override = default;
@@ -36,12 +36,10 @@ class ReshapeInfo : public OperatorInfo {
  public:
   ReshapeInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
               const PrimitiveAttrs& attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs),
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReshapeCost>()),
         dev_num_(0),
         input_layout_set_flag_(false),
-        output_layout_set_flag_(false) {
-    reshape_cost_ptr_ = std::make_shared<ReshapeCost>();
-  }
+        output_layout_set_flag_(false) {}
   ~ReshapeInfo() override = default;
   Status Init(const StrategyPtr& strategy) override;
   void SetInputLayout(const TensorLayout& input_layout) {
@@ -55,7 +53,6 @@ class ReshapeInfo : public OperatorInfo {
   Status InitForCostModel(const StrategyPtr& strategy) override;
   Status GenerateStrategies(int32_t stage_id) override;
   Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
-  OperatorCostPtr GetOperatorCost() const override { return reshape_cost_ptr_; }
  protected:
   Status CheckStrategy(const StrategyPtr& strategy) override;
@@ -67,7 +64,6 @@ class ReshapeInfo : public OperatorInfo {
   Status InferTensorLayout(TensorLayouts* inputs_layout, TensorLayouts* outputs_layout);
   Status GetAttrs() override;
   Strategys GetOutputsStrategy();
-  ReshapeCostPtr reshape_cost_ptr_;
  private:
   Status GetParameterInput();
@@ -34,9 +34,7 @@ class TmpIdentityInfo : public OperatorInfo {
  public:
   TmpIdentityInfo(const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs,
                   const std::string& name = IDENTITY_INFO)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs) {
-    id_cost_ptr_ = std::make_shared<TmpIdentityCost>();
-  }
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<TmpIdentityCost>()) {}
   ~TmpIdentityInfo() override = default;
   Status Init(const StrategyPtr& strategy) override;
@@ -44,7 +42,6 @@ class TmpIdentityInfo : public OperatorInfo {
   Status GenerateStrategies(int32_t stage_id) override;
   Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
-  OperatorCostPtr GetOperatorCost() const override { return id_cost_ptr_; }
  protected:
   Status CheckStrategy(const StrategyPtr& strategy) override;
@@ -54,9 +51,6 @@ class TmpIdentityInfo : public OperatorInfo {
   Status InferTensorInfo() override;
   Status InferDevMatrixShape() override;
   Status InferTensorMap() override;
- private:
-  TmpIdentityCostPtr id_cost_ptr_;
 };
 } // namespace parallel
 } // namespace mindspore
@@ -35,15 +35,12 @@ class TransposeInfo : public OperatorInfo {
  public:
   TransposeInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
                 const PrimitiveAttrs& attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs) {
-    transpose_cost_ptr_ = std::make_shared<ActivationCost>();
-  }
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<TransposeCost>()) {}
   ~TransposeInfo() override = default;
   Status Init(const StrategyPtr& strategy) override;
   Status InitForCostModel(const StrategyPtr& strategy) override;
   Status GenerateStrategies(int32_t stage_id) override;
   Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
-  OperatorCostPtr GetOperatorCost() const override { return transpose_cost_ptr_; }
  protected:
   Status CheckStrategy(const StrategyPtr& strategy) override;
@@ -60,7 +57,6 @@ class TransposeInfo : public OperatorInfo {
   Status ComputeAxis();
   std::vector<int32_t> axis_v_;
   Dimensions input_strategy_;
-  ActivationCostPtr transpose_cost_ptr_;
 };
 } // namespace parallel
 } // namespace mindspore
@@ -32,16 +32,13 @@ class VirtualDatasetInfo : public OperatorInfo {
  public:
   VirtualDatasetInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
                      const PrimitiveAttrs& attrs)
-      : OperatorInfo(name, inputs_shape, outputs_shape, attrs) {
-    vd_cost_ptr_ = std::make_shared<VirtualDatasetCost>();
-  }
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<VirtualDatasetCost>()) {}
   ~VirtualDatasetInfo() override = default;
   Status Init(const StrategyPtr& strategy) override;
   Status InitForCostModel(const StrategyPtr& strategy) override;
   Status GenerateStrategies(int32_t stage_id) override;
   Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
-  OperatorCostPtr GetOperatorCost() const override { return vd_cost_ptr_; }
   void ReComputeBatchSplitFlagList() override;
  protected:
@@ -53,9 +50,6 @@ class VirtualDatasetInfo : public OperatorInfo {
   Status InferTensorMap() override;
   Status GetAttrs() override;
   Status InferAsLossDivisor() override;
- private:
-  VirtualDatasetCostPtr vd_cost_ptr_;
 };
 } // namespace parallel
@@ -84,9 +84,9 @@ TEST_F(TestActivation, test_activation_strategies) {
     act_ptr_->InitForCostModel(sp);
     std::vector<TensorInfo> inputs_info = act_ptr_->inputs_tensor_info();
     std::vector<TensorInfo> outputs_info = act_ptr_->outputs_tensor_info();
-    ASSERT_DOUBLE_EQ(act_ptr_->GetOperatorCost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
+    ASSERT_DOUBLE_EQ(act_ptr_->cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
                      cost.computation_cost_);
-    ASSERT_DOUBLE_EQ(act_ptr_->GetOperatorCost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
+    ASSERT_DOUBLE_EQ(act_ptr_->cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
                      cost.communication_cost_);
   }
 }
@@ -109,9 +109,9 @@ TEST_F(TestActivation, test_softmax_strategies) {
     soft_ptr_->InitForCostModel(sp);
     std::vector<TensorInfo> inputs_info = soft_ptr_->inputs_tensor_info();
     std::vector<TensorInfo> outputs_info = soft_ptr_->outputs_tensor_info();
-    ASSERT_DOUBLE_EQ(soft_ptr_->GetOperatorCost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
+    ASSERT_DOUBLE_EQ(soft_ptr_->cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
                      cost.computation_cost_);
-    ASSERT_DOUBLE_EQ(soft_ptr_->GetOperatorCost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
+    ASSERT_DOUBLE_EQ(soft_ptr_->cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
                      cost.communication_cost_);
   }
 }
@@ -569,7 +569,7 @@ TEST_F(TestMatmulInfo, test_GenerateStrategies1) {
     matmul1->InitForCostModel(sp);
     std::vector<TensorInfo> inputs_info = matmul1->inputs_tensor_info();
     std::vector<TensorInfo> outputs_info = matmul1->outputs_tensor_info();
-    ASSERT_DOUBLE_EQ(matmul1->GetOperatorCost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
+    ASSERT_DOUBLE_EQ(matmul1->cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
                      cost.computation_cost_);
     break;
   }
@@ -599,7 +599,7 @@ TEST_F(TestMatmulInfo, test_GenerateStrategies2) {
     TensorInfo replica_input1_info(tly, input1_shape, input1_slice_shape);
     replica_inputs_info.push_back(replica_input1_info);
-    ASSERT_DOUBLE_EQ(matmul3->GetOperatorCost()->GetComputationCost(replica_inputs_info, outputs_info, sp->GetInputStage()),
+    ASSERT_DOUBLE_EQ(matmul3->cost()->GetComputationCost(replica_inputs_info, outputs_info, sp->GetInputStage()),
                      cost.computation_cost_);
     break;
   }
@@ -188,11 +188,11 @@ TEST_F(TestTensorAddInfo, GenerateStrategies) {
     tensor_add->InitForCostModel(sp);
     std::vector<TensorInfo> inputs_info = tensor_add->inputs_tensor_info();
     std::vector<TensorInfo> outputs_info = tensor_add->outputs_tensor_info();
-    double memory_cost0 = tensor_add->GetOperatorCost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage());
+    double memory_cost0 = tensor_add->cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage());
     double memory_cost1 = cost.computation_cost_;
     bool memory = memory_cost0 - memory_cost1 <= 1.0;
-    double comm_cost0 = tensor_add->GetOperatorCost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage());
+    double comm_cost0 = tensor_add->cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage());
     double comm_cost1 = cost.communication_cost_;
     bool comm = comm_cost0 - comm_cost1 <= 1.0;
@@ -210,11 +210,11 @@ TEST_F(TestTensorAddInfo, GenerateStrategies1) {
     tensor_add1->InitForCostModel(sp);
     std::vector<TensorInfo> inputs_info = tensor_add1->inputs_tensor_info();
     std::vector<TensorInfo> outputs_info = tensor_add1->outputs_tensor_info();
-    double memory_cost0 = tensor_add1->GetOperatorCost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage());
+    double memory_cost0 = tensor_add1->cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage());
     double memory_cost1 = cost.computation_cost_;
     bool memory = memory_cost0 - memory_cost1 <= 1.0;
-    double comm_cost0 = tensor_add1->GetOperatorCost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage());
+    double comm_cost0 = tensor_add1->cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage());
     double comm_cost1 = cost.communication_cost_;
     bool comm = comm_cost0 - comm_cost1 <= 1.0;
@@ -145,9 +145,9 @@ TEST_F(TestTmpIdentityInfo, test_generate_strategies) {
     identity_ptr->Init(sp);
     std::vector<TensorInfo> inputs_info = identity_ptr->inputs_tensor_info();
    std::vector<TensorInfo> outputs_info = identity_ptr->outputs_tensor_info();
-    ASSERT_DOUBLE_EQ(identity_ptr->GetOperatorCost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
+    ASSERT_DOUBLE_EQ(identity_ptr->cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
                      cost.computation_cost_);
-    ASSERT_DOUBLE_EQ(identity_ptr->GetOperatorCost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
+    ASSERT_DOUBLE_EQ(identity_ptr->cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
                      cost.communication_cost_);
   }
 }