Merge pull request !2280 from chenjianping/host_reducetags/v0.5.0-beta
| @@ -179,8 +179,8 @@ bool MPIAdapter::ReduceScatter(const float *input, float *output, const std::vec | |||
| return result; | |||
| } | |||
| bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector<int> &ranks_group, size_t data_num, | |||
| const std::string &op_type, float *output) { | |||
| bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector<int> &ranks_group, size_t input_data_num, | |||
| size_t output_size, const std::string &op_type, float *output) { | |||
| int scatter_index = GetScatterIndex(rank_id_, ranks_group); | |||
| auto group = AddGroup(ranks_group); | |||
| if (group == MPI_GROUP_NULL) { | |||
| @@ -193,7 +193,7 @@ bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector<int | |||
| } | |||
| MPI_Win window; | |||
| auto ret = MPI_Win_create(input, data_num * sizeof(float), sizeof(float), MPI_INFO_NULL, comm, &window); | |||
| auto ret = MPI_Win_create(input, input_data_num * sizeof(float), sizeof(float), MPI_INFO_NULL, comm, &window); | |||
| if (ret != MPI_SUCCESS) { | |||
| MS_LOG(ERROR) << "mpi window create fail! ret = " << ret; | |||
| return false; | |||
| @@ -205,18 +205,21 @@ bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector<int | |||
| continue; | |||
| } | |||
| auto op = GetMpiOp(op_type); | |||
| ret = MPI_Accumulate(input + i * data_num, data_num, MPI_FLOAT, remote_rank, i * data_num, data_num, MPI_FLOAT, op, | |||
| window); | |||
| ret = MPI_Accumulate(input + i * input_data_num, input_data_num, MPI_FLOAT, remote_rank, i * input_data_num, | |||
| input_data_num, MPI_FLOAT, op, window); | |||
| if (ret != MPI_SUCCESS) { | |||
| MS_LOG(EXCEPTION) << "mpi accumulate " << op_type << " fail!ret = " << ret; | |||
| } | |||
| } | |||
| MPI_Win_fence(0, window); | |||
| if (output != nullptr) { | |||
| auto data_size = data_num * sizeof(float); | |||
| auto copy_ret = memcpy_s(output, data_size, input + scatter_index * data_num, data_size); | |||
| auto data_size = input_data_num * sizeof(float); | |||
| if (output_size < data_size) { | |||
| MS_LOG(EXCEPTION) << "output buffer size " << output_size << " < input size " << data_size; | |||
| } | |||
| auto copy_ret = memcpy_s(output, output_size, input + scatter_index * input_data_num, data_size); | |||
| if (copy_ret != 0) { | |||
| MS_LOG(EXCEPTION) << "copy output memory fail!"; | |||
| MS_LOG(EXCEPTION) << "copy output memory fail!ret = " << copy_ret; | |||
| } | |||
| } | |||
| MPI_Win_free(&window); | |||
| @@ -224,7 +227,7 @@ bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector<int | |||
| return true; | |||
| } | |||
| bool MPIAdapter::AllGather(float *input, float *output, const std::vector<int> &ranks_group, size_t data_num) { | |||
| bool MPIAdapter::AllGather(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num) { | |||
| if (ranks_group.empty()) { | |||
| MS_LOG(ERROR) << "input rank group is empty!"; | |||
| return false; | |||
| @@ -34,9 +34,10 @@ class MPIAdapter { | |||
| int GetRankId() const; | |||
| bool ReduceScatter(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num, | |||
| const std::string &op_type = kOpTypeSum); | |||
| bool ReduceScatterOverwriteInput(float *input, const std::vector<int> &ranks_group, size_t data_num, | |||
| const std::string &op_type = kOpTypeSum, float *output = nullptr); | |||
| bool AllGather(float *input, float *output, const std::vector<int> &ranks_group, size_t data_num); | |||
| bool ReduceScatterOverwriteInput(float *input, const std::vector<int> &ranks_group, size_t input_data_num, | |||
| size_t output_size, const std::string &op_type = kOpTypeSum, | |||
| float *output = nullptr); | |||
| bool AllGather(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num); | |||
| private: | |||
| MPIAdapter(); | |||
| @@ -26,21 +26,11 @@ constexpr auto kRanksGroup = "group"; | |||
| constexpr auto kAllGatherInputNum = 1; | |||
| } // namespace | |||
| AllGatherCPUKernel::AllGatherCPUKernel() : input_data_number_(0) {} | |||
| void AllGatherCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != kAllGatherInputNum) { | |||
| MS_LOG(EXCEPTION) << "allgather input num:" << input_num; | |||
| } | |||
| for (size_t i = 0; i < input_num; ++i) { | |||
| auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i); | |||
| size_t count = 1; | |||
| for (size_t j = 0; j < shape.size(); j++) { | |||
| count *= IntToSize(shape[j]); | |||
| } | |||
| input_data_number_ += count; | |||
| } | |||
| auto ranks_group = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(kRanksGroup); | |||
| if (ranks_group != nullptr) { | |||
| @@ -55,8 +45,9 @@ bool AllGatherCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| auto input_addr = reinterpret_cast<float *>(inputs[0]->addr); | |||
| auto output_addr = reinterpret_cast<float *>(outputs[0]->addr); | |||
| auto input_data_num = inputs[0]->size / sizeof(float); | |||
| return device::cpu::MPIAdapter::Instance().AllGather(input_addr, output_addr, ranks_group_, input_data_number_); | |||
| return device::cpu::MPIAdapter::Instance().AllGather(input_addr, output_addr, ranks_group_, input_data_num); | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -24,7 +24,7 @@ namespace mindspore { | |||
| namespace kernel { | |||
| class AllGatherCPUKernel : public CPUKernel { | |||
| public: | |||
| AllGatherCPUKernel(); | |||
| AllGatherCPUKernel() = default; | |||
| ~AllGatherCPUKernel() override = default; | |||
| void InitKernel(const CNodePtr &kernel_node) override; | |||
| @@ -33,7 +33,6 @@ class AllGatherCPUKernel : public CPUKernel { | |||
| const std::vector<AddressPtr> &outputs) override; | |||
| private: | |||
| size_t input_data_number_; | |||
| std::vector<int> ranks_group_; | |||
| }; | |||
| @@ -24,18 +24,9 @@ namespace { | |||
| constexpr auto kRanksGroup = "group"; | |||
| } // namespace | |||
| ReduceScatterCPUKernel::ReduceScatterCPUKernel() : output_data_number_(0), op_type_(device::cpu::kOpTypeSum) {} | |||
| ReduceScatterCPUKernel::ReduceScatterCPUKernel() : op_type_(device::cpu::kOpTypeSum) {} | |||
| void ReduceScatterCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| for (size_t i = 0; i < output_num; ++i) { | |||
| auto shape = AnfAlgo::GetOutputInferShape(kernel_node, i); | |||
| size_t size = 1; | |||
| for (size_t j = 0; j < shape.size(); j++) { | |||
| size *= IntToSize(shape[j]); | |||
| } | |||
| output_data_number_ += size; | |||
| } | |||
| auto op = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("op"); | |||
| if (op != nullptr) { | |||
| op_type_ = GetValue<std::string>(op); | |||
| @@ -54,8 +45,9 @@ bool ReduceScatterCPUKernel::Launch(const std::vector<kernel::AddressPtr> &input | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| auto input_addr = reinterpret_cast<float *>(inputs[0]->addr); | |||
| auto output_addr = reinterpret_cast<float *>(outputs[0]->addr); | |||
| auto output_data_num = outputs[0]->size / sizeof(float); | |||
| return device::cpu::MPIAdapter::Instance().ReduceScatter(input_addr, output_addr, ranks_group_, output_data_number_, | |||
| return device::cpu::MPIAdapter::Instance().ReduceScatter(input_addr, output_addr, ranks_group_, output_data_num, | |||
| op_type_); | |||
| } | |||
| } // namespace kernel | |||
| @@ -33,7 +33,6 @@ class ReduceScatterCPUKernel : public CPUKernel { | |||
| const std::vector<AddressPtr> &outputs) override; | |||
| private: | |||
| size_t output_data_number_; | |||
| std::string op_type_; | |||
| std::vector<int> ranks_group_; | |||
| }; | |||