Merge pull request !2280 from chenjianping/host_reducetags/v0.5.0-beta
| @@ -179,8 +179,8 @@ bool MPIAdapter::ReduceScatter(const float *input, float *output, const std::vec | |||||
| return result; | return result; | ||||
| } | } | ||||
| bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector<int> &ranks_group, size_t data_num, | |||||
| const std::string &op_type, float *output) { | |||||
| bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector<int> &ranks_group, size_t input_data_num, | |||||
| size_t output_size, const std::string &op_type, float *output) { | |||||
| int scatter_index = GetScatterIndex(rank_id_, ranks_group); | int scatter_index = GetScatterIndex(rank_id_, ranks_group); | ||||
| auto group = AddGroup(ranks_group); | auto group = AddGroup(ranks_group); | ||||
| if (group == MPI_GROUP_NULL) { | if (group == MPI_GROUP_NULL) { | ||||
| @@ -193,7 +193,7 @@ bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector<int | |||||
| } | } | ||||
| MPI_Win window; | MPI_Win window; | ||||
| auto ret = MPI_Win_create(input, data_num * sizeof(float), sizeof(float), MPI_INFO_NULL, comm, &window); | |||||
| auto ret = MPI_Win_create(input, input_data_num * sizeof(float), sizeof(float), MPI_INFO_NULL, comm, &window); | |||||
| if (ret != MPI_SUCCESS) { | if (ret != MPI_SUCCESS) { | ||||
| MS_LOG(ERROR) << "mpi window create fail! ret = " << ret; | MS_LOG(ERROR) << "mpi window create fail! ret = " << ret; | ||||
| return false; | return false; | ||||
| @@ -205,18 +205,21 @@ bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector<int | |||||
| continue; | continue; | ||||
| } | } | ||||
| auto op = GetMpiOp(op_type); | auto op = GetMpiOp(op_type); | ||||
| ret = MPI_Accumulate(input + i * data_num, data_num, MPI_FLOAT, remote_rank, i * data_num, data_num, MPI_FLOAT, op, | |||||
| window); | |||||
| ret = MPI_Accumulate(input + i * input_data_num, input_data_num, MPI_FLOAT, remote_rank, i * input_data_num, | |||||
| input_data_num, MPI_FLOAT, op, window); | |||||
| if (ret != MPI_SUCCESS) { | if (ret != MPI_SUCCESS) { | ||||
| MS_LOG(EXCEPTION) << "mpi accumulate " << op_type << " fail!ret = " << ret; | MS_LOG(EXCEPTION) << "mpi accumulate " << op_type << " fail!ret = " << ret; | ||||
| } | } | ||||
| } | } | ||||
| MPI_Win_fence(0, window); | MPI_Win_fence(0, window); | ||||
| if (output != nullptr) { | if (output != nullptr) { | ||||
| auto data_size = data_num * sizeof(float); | |||||
| auto copy_ret = memcpy_s(output, data_size, input + scatter_index * data_num, data_size); | |||||
| auto data_size = input_data_num * sizeof(float); | |||||
| if (output_size < data_size) { | |||||
| MS_LOG(EXCEPTION) << "output buffer size " << output_size << " < input size " << data_size; | |||||
| } | |||||
| auto copy_ret = memcpy_s(output, output_size, input + scatter_index * input_data_num, data_size); | |||||
| if (copy_ret != 0) { | if (copy_ret != 0) { | ||||
| MS_LOG(EXCEPTION) << "copy output memory fail!"; | |||||
| MS_LOG(EXCEPTION) << "copy output memory fail!ret = " << copy_ret; | |||||
| } | } | ||||
| } | } | ||||
| MPI_Win_free(&window); | MPI_Win_free(&window); | ||||
| @@ -224,7 +227,7 @@ bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector<int | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool MPIAdapter::AllGather(float *input, float *output, const std::vector<int> &ranks_group, size_t data_num) { | |||||
| bool MPIAdapter::AllGather(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num) { | |||||
| if (ranks_group.empty()) { | if (ranks_group.empty()) { | ||||
| MS_LOG(ERROR) << "input rank group is empty!"; | MS_LOG(ERROR) << "input rank group is empty!"; | ||||
| return false; | return false; | ||||
| @@ -34,9 +34,10 @@ class MPIAdapter { | |||||
| int GetRankId() const; | int GetRankId() const; | ||||
| bool ReduceScatter(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num, | bool ReduceScatter(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num, | ||||
| const std::string &op_type = kOpTypeSum); | const std::string &op_type = kOpTypeSum); | ||||
| bool ReduceScatterOverwriteInput(float *input, const std::vector<int> &ranks_group, size_t data_num, | |||||
| const std::string &op_type = kOpTypeSum, float *output = nullptr); | |||||
| bool AllGather(float *input, float *output, const std::vector<int> &ranks_group, size_t data_num); | |||||
| bool ReduceScatterOverwriteInput(float *input, const std::vector<int> &ranks_group, size_t input_data_num, | |||||
| size_t output_size, const std::string &op_type = kOpTypeSum, | |||||
| float *output = nullptr); | |||||
| bool AllGather(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num); | |||||
| private: | private: | ||||
| MPIAdapter(); | MPIAdapter(); | ||||
| @@ -26,21 +26,11 @@ constexpr auto kRanksGroup = "group"; | |||||
| constexpr auto kAllGatherInputNum = 1; | constexpr auto kAllGatherInputNum = 1; | ||||
| } // namespace | } // namespace | ||||
| AllGatherCPUKernel::AllGatherCPUKernel() : input_data_number_(0) {} | |||||
| void AllGatherCPUKernel::InitKernel(const CNodePtr &kernel_node) { | void AllGatherCPUKernel::InitKernel(const CNodePtr &kernel_node) { | ||||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | ||||
| if (input_num != kAllGatherInputNum) { | if (input_num != kAllGatherInputNum) { | ||||
| MS_LOG(EXCEPTION) << "allgather input num:" << input_num; | MS_LOG(EXCEPTION) << "allgather input num:" << input_num; | ||||
| } | } | ||||
| for (size_t i = 0; i < input_num; ++i) { | |||||
| auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i); | |||||
| size_t count = 1; | |||||
| for (size_t j = 0; j < shape.size(); j++) { | |||||
| count *= IntToSize(shape[j]); | |||||
| } | |||||
| input_data_number_ += count; | |||||
| } | |||||
| auto ranks_group = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(kRanksGroup); | auto ranks_group = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(kRanksGroup); | ||||
| if (ranks_group != nullptr) { | if (ranks_group != nullptr) { | ||||
| @@ -55,8 +45,9 @@ bool AllGatherCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||||
| const std::vector<kernel::AddressPtr> &outputs) { | const std::vector<kernel::AddressPtr> &outputs) { | ||||
| auto input_addr = reinterpret_cast<float *>(inputs[0]->addr); | auto input_addr = reinterpret_cast<float *>(inputs[0]->addr); | ||||
| auto output_addr = reinterpret_cast<float *>(outputs[0]->addr); | auto output_addr = reinterpret_cast<float *>(outputs[0]->addr); | ||||
| auto input_data_num = inputs[0]->size / sizeof(float); | |||||
| return device::cpu::MPIAdapter::Instance().AllGather(input_addr, output_addr, ranks_group_, input_data_number_); | |||||
| return device::cpu::MPIAdapter::Instance().AllGather(input_addr, output_addr, ranks_group_, input_data_num); | |||||
| } | } | ||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -24,7 +24,7 @@ namespace mindspore { | |||||
| namespace kernel { | namespace kernel { | ||||
| class AllGatherCPUKernel : public CPUKernel { | class AllGatherCPUKernel : public CPUKernel { | ||||
| public: | public: | ||||
| AllGatherCPUKernel(); | |||||
| AllGatherCPUKernel() = default; | |||||
| ~AllGatherCPUKernel() override = default; | ~AllGatherCPUKernel() override = default; | ||||
| void InitKernel(const CNodePtr &kernel_node) override; | void InitKernel(const CNodePtr &kernel_node) override; | ||||
| @@ -33,7 +33,6 @@ class AllGatherCPUKernel : public CPUKernel { | |||||
| const std::vector<AddressPtr> &outputs) override; | const std::vector<AddressPtr> &outputs) override; | ||||
| private: | private: | ||||
| size_t input_data_number_; | |||||
| std::vector<int> ranks_group_; | std::vector<int> ranks_group_; | ||||
| }; | }; | ||||
| @@ -24,18 +24,9 @@ namespace { | |||||
| constexpr auto kRanksGroup = "group"; | constexpr auto kRanksGroup = "group"; | ||||
| } // namespace | } // namespace | ||||
| ReduceScatterCPUKernel::ReduceScatterCPUKernel() : output_data_number_(0), op_type_(device::cpu::kOpTypeSum) {} | |||||
| ReduceScatterCPUKernel::ReduceScatterCPUKernel() : op_type_(device::cpu::kOpTypeSum) {} | |||||
| void ReduceScatterCPUKernel::InitKernel(const CNodePtr &kernel_node) { | void ReduceScatterCPUKernel::InitKernel(const CNodePtr &kernel_node) { | ||||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||||
| for (size_t i = 0; i < output_num; ++i) { | |||||
| auto shape = AnfAlgo::GetOutputInferShape(kernel_node, i); | |||||
| size_t size = 1; | |||||
| for (size_t j = 0; j < shape.size(); j++) { | |||||
| size *= IntToSize(shape[j]); | |||||
| } | |||||
| output_data_number_ += size; | |||||
| } | |||||
| auto op = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("op"); | auto op = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("op"); | ||||
| if (op != nullptr) { | if (op != nullptr) { | ||||
| op_type_ = GetValue<std::string>(op); | op_type_ = GetValue<std::string>(op); | ||||
| @@ -54,8 +45,9 @@ bool ReduceScatterCPUKernel::Launch(const std::vector<kernel::AddressPtr> &input | |||||
| const std::vector<kernel::AddressPtr> &outputs) { | const std::vector<kernel::AddressPtr> &outputs) { | ||||
| auto input_addr = reinterpret_cast<float *>(inputs[0]->addr); | auto input_addr = reinterpret_cast<float *>(inputs[0]->addr); | ||||
| auto output_addr = reinterpret_cast<float *>(outputs[0]->addr); | auto output_addr = reinterpret_cast<float *>(outputs[0]->addr); | ||||
| auto output_data_num = outputs[0]->size / sizeof(float); | |||||
| return device::cpu::MPIAdapter::Instance().ReduceScatter(input_addr, output_addr, ranks_group_, output_data_number_, | |||||
| return device::cpu::MPIAdapter::Instance().ReduceScatter(input_addr, output_addr, ranks_group_, output_data_num, | |||||
| op_type_); | op_type_); | ||||
| } | } | ||||
| } // namespace kernel | } // namespace kernel | ||||
| @@ -33,7 +33,6 @@ class ReduceScatterCPUKernel : public CPUKernel { | |||||
| const std::vector<AddressPtr> &outputs) override; | const std::vector<AddressPtr> &outputs) override; | ||||
| private: | private: | ||||
| size_t output_data_number_; | |||||
| std::string op_type_; | std::string op_type_; | ||||
| std::vector<int> ranks_group_; | std::vector<int> ranks_group_; | ||||
| }; | }; | ||||