| @@ -0,0 +1,128 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/bce_with_logits_loss_impl.cuh" | |||
// Selects the source coordinate for one axis when broadcasting: a source axis
// whose extent is 1 always maps to coordinate 0, otherwise the destination
// coordinate is used unchanged.
__device__ __forceinline__ size_t Index(const size_t &index, const size_t &dim) {
  if (dim == 1) {
    return 0;
  }
  return index;
}
// Copies `size` elements from `src` to `dst` one-to-one (no broadcasting).
// Each thread handles the elements pos, pos + stride, pos + 2 * stride, ...
// so any 1-D launch configuration covers the whole range.
template <typename T>
__global__ void FillWithoutBroadcast(const size_t size, const T *src, T *dst) {
  // Widen to size_t before multiplying; the built-in index variables are 32-bit
  // and their product would otherwise wrap before being stored in a size_t.
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  for (size_t pos = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; pos < size; pos += stride) {
    dst[pos] = src[pos];
  }
}
// Materialises `src` (shape `src_shape`) into `dst` (shape `dst_shape`) with
// broadcasting along every axis where the source extent is 1.
//
//   size        number of elements in dst; must equal the product of dst_shape
//   shape_size  rank shared by src_shape and dst_shape
//   src_shape   device pointer to the source extents (length shape_size)
//   dst_shape   device pointer to the destination extents (length shape_size)
//
// The flat destination offset is decomposed into per-axis coordinates from the
// last axis to the first, and the source offset is accumulated on the fly, so
// no per-thread coordinate array is needed.
template <typename T>
__global__ void FillAndBroadcast(const size_t size, const size_t shape_size, const size_t *src_shape,
                                 const size_t *dst_shape, const T *src, T *dst) {
  // Widen to size_t before multiplying; the built-in index variables are 32-bit.
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  for (size_t pos = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; pos < size; pos += stride) {
    size_t remaining = pos;  // part of the flat offset not yet decomposed
    size_t src_pos = 0;      // flat offset into src
    size_t src_stride = 1;   // product of the source extents after `axis`
    for (size_t i = shape_size; i > 0; --i) {
      const size_t axis = i - 1;
      const size_t dst_dim = dst_shape[axis];
      const size_t src_dim = src_shape[axis];
      const size_t dst_index = remaining % dst_dim;
      remaining /= dst_dim;
      src_pos += Index(dst_index, src_dim) * src_stride;
      src_stride *= src_dim;
    }
    dst[pos] = src[src_pos];
  }
}
// Computes the element-wise (unreduced, not yet weight-scaled) loss:
//
//   max_value  = max(-predict, 0)
//   log_weight = (pos_weight - 1) * target + 1
//   output     = (1 - target) * predict
//                + log_weight * (log(exp(-max_value) + exp(-predict - max_value)) + max_value)
//
// `shape_broadcasted` supplies the pos_weight value for every element, so it
// must hold `size` elements, as must predict, target and output.
template <typename T>
__global__ void BCEWithLogitsLossMain(size_t size, const T *predict, const T *target, const T *shape_broadcasted,
                                      T *output) {
  const T zero = static_cast<T>(0);
  const T one = static_cast<T>(1);
  // Widen to size_t before multiplying; the built-in index variables are 32-bit.
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  for (size_t pos = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; pos < size; pos += stride) {
    const T logit = predict[pos];
    const T label = target[pos];
    const T neg_logit = -logit;
    // Subtracting max_value inside both exponentials keeps their arguments <= 0.
    const T max_value = neg_logit > zero ? neg_logit : zero;
    const T log_weight = (shape_broadcasted[pos] - one) * label + one;
    output[pos] =
      (one - label) * logit + log_weight * (log(exp(-max_value) + exp(-logit - max_value)) + max_value);
  }
}
// Half-precision specialisation of BCEWithLogitsLossMain. It evaluates the same
// expression as the generic kernel but uses the half intrinsics hlog/hexp:
//
//   max_value  = max(-predict, 0)
//   log_weight = (pos_weight - 1) * target + 1
//   output     = (1 - target) * predict
//                + log_weight * (hlog(hexp(-max_value) + hexp(-predict - max_value)) + max_value)
template <>
__global__ void BCEWithLogitsLossMain(size_t size, const half *predict, const half *target,
                                      const half *shape_broadcasted, half *output) {
  const half zero = static_cast<half>(0);
  const half one = static_cast<half>(1);
  // Widen to size_t before multiplying; the built-in index variables are 32-bit.
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  for (size_t pos = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; pos < size; pos += stride) {
    const half logit = predict[pos];
    const half label = target[pos];
    const half neg_logit = -logit;
    const half max_value = neg_logit > zero ? neg_logit : zero;
    const half log_weight = (shape_broadcasted[pos] - one) * label + one;
    output[pos] =
      (one - label) * logit + log_weight * (hlog(hexp(-max_value) + hexp(-logit - max_value)) + max_value);
  }
}
// In-place element-wise product: rhs[i] = rhs[i] * lhs[i] for i in [0, size).
template <typename T>
__global__ void Mul(size_t size, const T *lhs, T *rhs) {
  // Widen to size_t before multiplying; the built-in index variables are 32-bit.
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  for (size_t pos = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; pos < size; pos += stride) {
    rhs[pos] *= lhs[pos];
  }
}
// Host entry point: enqueues the kernels that compute the unreduced
// BCE-with-logits loss on `cuda_stream`.
//
//   input_size         number of elements in predict / target / output
//   input_shape        device pointer to the extents of predict (length shape_size)
//   weight_shape       device pointer to the extents of weight (length shape_size)
//   pos_weight_shape   device pointer to the extents of pos_weight (length shape_size)
//   *_need_broadcast   whether the corresponding tensor must be expanded to
//                      the shape of predict before use
//   shape_broadcasted  device scratch buffer holding input_size elements; it is
//                      only written when a tensor actually needs broadcasting
//   output             receives the element-wise loss
//
// All launches go to the same stream, so the scratch buffer can safely be
// reused for weight after the loss kernel has consumed pos_weight from it.
template <typename T>
void CalBCEWithLogitsLoss(const size_t input_size, const T *predict, const T *target, const size_t *input_shape,
                          const size_t shape_size, const T *weight, const size_t *weight_shape,
                          const bool weight_need_broadcast, const T *pos_weight, const size_t *pos_weight_shape,
                          const bool pos_weight_need_broadcast, T *shape_broadcasted, T *output,
                          cudaStream_t cuda_stream) {
  // When no broadcast is required the tensor already has input_size elements,
  // so it is consumed in place instead of being copied into the scratch buffer.
  const T *pos_weight_full = pos_weight;
  if (pos_weight_need_broadcast) {
    FillAndBroadcast<<<GET_BLOCKS(input_size), GET_THREADS, 0, cuda_stream>>>(
      input_size, shape_size, pos_weight_shape, input_shape, pos_weight, shape_broadcasted);
    pos_weight_full = shape_broadcasted;
  }
  BCEWithLogitsLossMain<<<GET_BLOCKS(input_size), GET_THREADS, 0, cuda_stream>>>(input_size, predict, target,
                                                                                 pos_weight_full, output);

  const T *weight_full = weight;
  if (weight_need_broadcast) {
    FillAndBroadcast<<<GET_BLOCKS(input_size), GET_THREADS, 0, cuda_stream>>>(input_size, shape_size, weight_shape,
                                                                              input_shape, weight, shape_broadcasted);
    weight_full = shape_broadcasted;
  }
  // Scale the loss by the per-element weight.
  Mul<<<GET_BLOCKS(input_size), GET_THREADS, 0, cuda_stream>>>(input_size, weight_full, output);
}
// Explicit instantiations for the element types this file provides: float and half.
template void CalBCEWithLogitsLoss<float>(const size_t input_size, const float *predict, const float *target,
                                          const size_t *input_shape, const size_t shape_size,
                                          const float *weight, const size_t *weight_shape,
                                          const bool weight_need_broadcast, const float *pos_weight,
                                          const size_t *pos_weight_shape, const bool pos_weight_need_broadcast,
                                          float *shape_broadcasted, float *output, cudaStream_t cuda_stream);
template void CalBCEWithLogitsLoss<half>(const size_t input_size, const half *predict, const half *target,
                                         const size_t *input_shape, const size_t shape_size,
                                         const half *weight, const size_t *weight_shape,
                                         const bool weight_need_broadcast, const half *pos_weight,
                                         const size_t *pos_weight_shape, const bool pos_weight_need_broadcast,
                                         half *shape_broadcasted, half *output, cudaStream_t cuda_stream);
| @@ -0,0 +1,30 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_BCE_WITH_LOGITS_LOSS_IMPL_CUH_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_BCE_WITH_LOGITS_LOSS_IMPL_CUH_ | |||
| #define MAX_LOGITS_DIMENSION 100 | |||
| #include "runtime/device/gpu/cuda_common.h" | |||
// Enqueues the element-wise BCE-with-logits loss computation on `cuda_stream`.
// predict, target, weight, pos_weight, shape_broadcasted, output and the three
// *_shape arrays are passed straight to device kernels. shape_size is the
// number of entries read from each *_shape array.
template <typename T>
void CalBCEWithLogitsLoss(const size_t input_size,
                          const T *predict,
                          const T *target,
                          const size_t *input_shape,
                          const size_t shape_size,
                          const T *weight,
                          const size_t *weight_shape,
                          const bool weight_need_broadcast,
                          const T *pos_weight,
                          const size_t *pos_weight_shape,
                          const bool pos_weight_need_broadcast,
                          T *shape_broadcasted,
                          T *output,
                          cudaStream_t cuda_stream);
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_BCE_WITH_LOGITS_LOSS_IMPL_CUH_ | |||
| @@ -0,0 +1,38 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/gpu/nn/bce_with_logits_loss_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_ONE(BCEWithLogitsLoss, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| BCEWithLogitsLossKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(BCEWithLogitsLoss, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddOutputAttr(kNumberTypeFloat16), | |||
| BCEWithLogitsLossKernel, half) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,167 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BCE_WITH_LOGITS_LOSS_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BCE_WITH_LOGITS_LOSS_KERNEL_H_ | |||
| #include <vector> | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/bce_with_logits_loss_impl.cuh" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| class BCEWithLogitsLossKernel : public GpuKernel { | |||
| public: | |||
| BCEWithLogitsLossKernel() { ResetResource(); } | |||
| ~BCEWithLogitsLossKernel() override = default; | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||
| T *predict = GetDeviceAddress<T>(inputs, 0); | |||
| T *target = GetDeviceAddress<T>(inputs, 1); | |||
| T *weight = GetDeviceAddress<T>(inputs, 2); | |||
| T *pos_weight = GetDeviceAddress<T>(inputs, 3); | |||
| size_t *input_shape = GetDeviceAddress<size_t>(workspace, 0); | |||
| size_t *weight_shape = GetDeviceAddress<size_t>(workspace, 1); | |||
| size_t *pos_weight_shape = GetDeviceAddress<size_t>(workspace, 2); | |||
| T *shape_broadcasted = GetDeviceAddress<T>(workspace, 3); | |||
| T *output = GetDeviceAddress<T>(outputs, 0); | |||
| CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, | |||
| cudaMemcpyAsync(input_shape, &input_shape_[0], input_shape_.size() * sizeof(size_t), | |||
| cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)), | |||
| "cudaMemcpyAsync input_shape_ failed"); | |||
| CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, | |||
| cudaMemcpyAsync(weight_shape, &weight_shape_[0], weight_shape_.size() * sizeof(size_t), | |||
| cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)), | |||
| "cudaMemcpyAsync weight_shape_ failed"); | |||
| CHECK_CUDA_RET_WITH_EXCEPT( | |||
| kernel_node_, | |||
| cudaMemcpyAsync(pos_weight_shape, &pos_weight_shape_[0], pos_weight_shape_.size() * sizeof(size_t), | |||
| cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)), | |||
| "cudaMemcpyAsync pos_weight_shape_ failed"); | |||
| CalBCEWithLogitsLoss(input_size_, predict, target, input_shape, input_shape_.size(), weight, weight_shape, | |||
| weight_need_broadcast_, pos_weight, pos_weight_shape, pos_weight_need_broadcast_, | |||
| shape_broadcasted, output, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| kernel_node_ = kernel_node; | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 4) { | |||
| MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but BCEWithLogitsLoss needs 4 inputs."; | |||
| return false; | |||
| } | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| if (output_num != 1) { | |||
| MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but BCEWithLogitsLoss has 1 output."; | |||
| return false; | |||
| } | |||
| input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| input_size_ = 1; | |||
| if (input_shape_.size() > MAX_LOGITS_DIMENSION) { | |||
| MS_LOG(EXCEPTION) << "Input dimension is " << input_shape_.size() | |||
| << ", but BCEWithLogitsLoss can only support up to " << MAX_LOGITS_DIMENSION << "-D."; | |||
| return false; | |||
| } | |||
| for (size_t i = 0; i < input_shape_.size(); i++) { | |||
| input_size_ *= input_shape_[i]; | |||
| } | |||
| // weight shape | |||
| weight_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); | |||
| weight_size_ = 1; | |||
| for (size_t i = 0; i < weight_shape_.size(); i++) { | |||
| weight_size_ *= weight_shape_[i]; | |||
| } | |||
| weight_need_broadcast_ = NeedBroadcast(&weight_shape_, input_shape_); | |||
| // pos_weight shape | |||
| pos_weight_size_ = 1; | |||
| pos_weight_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3); | |||
| for (size_t i = 0; i < pos_weight_shape_.size(); i++) { | |||
| pos_weight_size_ *= pos_weight_shape_[i]; | |||
| } | |||
| pos_weight_need_broadcast_ = NeedBroadcast(&pos_weight_shape_, input_shape_); | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| void ResetResource() noexcept override { | |||
| input_size_ = 1; | |||
| weight_size_ = 1; | |||
| pos_weight_size_ = 1; | |||
| weight_need_broadcast_ = false; | |||
| pos_weight_need_broadcast_ = false; | |||
| input_shape_.clear(); | |||
| weight_shape_.clear(); | |||
| pos_weight_shape_.clear(); | |||
| input_size_list_.clear(); | |||
| output_size_list_.clear(); | |||
| workspace_size_list_.clear(); | |||
| } | |||
| protected: | |||
| void InitSizeLists() override { | |||
| input_size_list_.push_back(input_size_ * sizeof(T)); | |||
| input_size_list_.push_back(input_size_ * sizeof(T)); | |||
| input_size_list_.push_back(weight_size_ * sizeof(T)); | |||
| input_size_list_.push_back(pos_weight_size_ * sizeof(T)); | |||
| workspace_size_list_.push_back(input_shape_.size() * sizeof(size_t)); | |||
| workspace_size_list_.push_back(weight_shape_.size() * sizeof(size_t)); | |||
| workspace_size_list_.push_back(pos_weight_shape_.size() * sizeof(size_t)); | |||
| // extra space for holding extra array shape of input, for broadcasted | |||
| // weight and pos_weight | |||
| workspace_size_list_.push_back(input_size_ * sizeof(T)); | |||
| output_size_list_.push_back(input_size_ * sizeof(T)); | |||
| } | |||
| private: | |||
| bool NeedBroadcast(std::vector<size_t> *shape, const std::vector<size_t> &result_shape) { | |||
| // result_shape is larger that shape | |||
| // and shape is able to broadcasted to result_shape | |||
| if (shape->size() != result_shape.size()) { | |||
| size_t fill_size = result_shape.size() - shape->size(); | |||
| (void)shape->insert(shape->begin(), fill_size, 1); | |||
| return true; | |||
| } | |||
| for (size_t i = 0; i < result_shape.size(); i++) { | |||
| if (shape->at(i) != result_shape[i]) { | |||
| return true; | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| size_t input_size_; | |||
| size_t weight_size_; | |||
| size_t pos_weight_size_; | |||
| bool weight_need_broadcast_; | |||
| bool pos_weight_need_broadcast_; | |||
| std::vector<size_t> input_shape_; | |||
| std::vector<size_t> weight_shape_; | |||
| std::vector<size_t> pos_weight_shape_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BCE_WITH_LOGITS_LOSS_KERNEL_H_ | |||
| @@ -0,0 +1,93 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/optimizer/gpu/bce_with_logits_loss_fusion.h" | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <string> | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "ir/primitive.h" | |||
| #include "utils/utils.h" | |||
| #include "backend/optimizer/common/helper.h" | |||
| #include "runtime/device/gpu/kernel_info_setter.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| AnfNodePtr AddReduceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| std::vector<AnfNodePtr> node_inputs = { | |||
| NewValueNode(std::make_shared<Primitive>(prim::kPrimBCEWithLogitsLoss->name()))}; | |||
| (void)node_inputs.insert(node_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end()); | |||
| CNodePtr new_cnode = func_graph->NewCNode(node_inputs); | |||
| MS_EXCEPTION_IF_NULL(new_cnode); | |||
| auto predict_input = cnode->inputs()[1]; | |||
| auto new_node_dtype = {AnfAlgo::GetOutputInferDataType(predict_input, 0)}; | |||
| auto new_node_shape = {AnfAlgo::GetOutputInferShape(predict_input, 0)}; | |||
| AnfAlgo::SetOutputInferTypeAndShape(new_node_dtype, new_node_shape, new_cnode.get()); | |||
| // Add reduce node | |||
| string reduction = AnfAlgo::GetNodeAttr<std::string>(node, kAttrReduction); | |||
| MS_LOG(INFO) << "Create reduce node for BCEWithLogitsLoss, reduction attr is: " << reduction; | |||
| std::vector<AnfNodePtr> reduce_inputs; | |||
| if (reduction == "sum") { | |||
| reduce_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())), new_cnode}; | |||
| } else if (reduction == "mean") { | |||
| reduce_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceMean->name())), new_cnode}; | |||
| } else { | |||
| MS_LOG(INFO) << "Reduction is none, no optimization on current BCEWithLogitsLoss."; | |||
| return nullptr; | |||
| } | |||
| auto reduce_node = func_graph->NewCNode(reduce_inputs); | |||
| MS_EXCEPTION_IF_NULL(reduce_node); | |||
| auto type = AnfAlgo::GetOutputInferDataType(node, 0); | |||
| auto shape = {AnfAlgo::GetOutputInferShape(node, 0)}; | |||
| AnfAlgo::SetOutputInferTypeAndShape({type}, shape, reduce_node.get()); | |||
| AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int64_t>{}), reduce_node); | |||
| AnfAlgo::SetNodeAttr("keep_dims", MakeValue(false), reduce_node); | |||
| reduce_node->set_scope(cnode->scope()); | |||
| return reduce_node; | |||
| } | |||
| const BaseRef BCEWithLogitsLossFusion::DefinePattern() const { | |||
| VarPtr Xs = std::make_shared<SeqVar>(); | |||
| MS_EXCEPTION_IF_NULL(Xs); | |||
| return VectorRef({prim::kPrimBCEWithLogitsLoss, Xs}); | |||
| } | |||
| const AnfNodePtr BCEWithLogitsLossFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const EquivPtr &) const { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (GetBoolAttr(cnode, kAttrVisited)) { | |||
| return nullptr; | |||
| } | |||
| AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); | |||
| if (cnode->inputs().size() == 0) { | |||
| return nullptr; | |||
| } | |||
| if (!AnfAlgo::HasNodeAttr("reduction", cnode)) { | |||
| MS_LOG(INFO) << "Primitive BCEWithLogitsLoss doesn't not have reduction attr."; | |||
| return nullptr; | |||
| } | |||
| return AddReduceNode(func_graph, node); | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,34 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_BCE_WITH_LOGITS_FUSION_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_BCE_WITH_LOGITS_FUSION_H_ | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class BCEWithLogitsLossFusion : public PatternProcessPass { | |||
| public: | |||
| explicit BCEWithLogitsLossFusion(bool multigraph = true) | |||
| : PatternProcessPass("bce_with_logits_loss_fusion", multigraph) {} | |||
| ~BCEWithLogitsLossFusion() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_BCE_WITH_LOGITS_FUSION_H_ | |||
| @@ -38,6 +38,7 @@ | |||
| #include "backend/optimizer/gpu/replace_momentum_cast_fusion.h" | |||
| #include "backend/optimizer/gpu/replace_addn_fusion.h" | |||
| #include "backend/optimizer/gpu/print_reduce_fusion.h" | |||
| #include "backend/optimizer/gpu/bce_with_logits_loss_fusion.h" | |||
| #include "backend/optimizer/gpu/remove_format_transform_pair.h" | |||
| #include "backend/optimizer/gpu/remove_redundant_format_transform.h" | |||
| #include "backend/optimizer/gpu/reduce_precision_fusion.h" | |||
| @@ -143,6 +144,7 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) { | |||
| pm->AddPass(std::make_shared<opt::ReplaceMomentumCastFusion>()); | |||
| pm->AddPass(std::make_shared<opt::ReplaceAddNFusion>()); | |||
| pm->AddPass(std::make_shared<opt::PrintReduceFusion>("print_reduce")); | |||
| pm->AddPass(std::make_shared<opt::BCEWithLogitsLossFusion>()); | |||
| optimizer->AddPassManager(pm); | |||
| (void)optimizer->Optimize(kernel_graph); | |||
| kernel_graph->SetExecOrderByDefault(); | |||
| @@ -0,0 +1,102 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import math | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
class Net(nn.Cell):
    """Cell that applies P.BCEWithLogitsLoss with the given reduction mode."""

    def __init__(self, reduction):
        super(Net, self).__init__()
        # The primitive is built once and stored on the cell.
        self.loss = P.BCEWithLogitsLoss(reduction=reduction)

    def construct(self, predict, target, weight, pos_weight):
        loss_value = self.loss(predict, target, weight, pos_weight)
        return loss_value
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_reduction_none_testcases():
    """reduction='none': fp32 in graph and pynative mode, then fp16 in graph mode."""
    predict_values = [[-0.8, 1.2, 0.7], [-0.1, -0.4, 0.7]]
    target_values = [[0.3, 0.8, 1.2], [-0.6, 0.1, 2.2]]
    unit_weights = [1.0, 1.0, 1.0]
    fp32_expected = [[0.6111006, 0.5032824, 0.26318598],
                     [0.58439666, 0.55301523, -0.436814]]

    def run_loss(dtype):
        # Build the net and its inputs under the currently active context.
        net = Net("none")
        predict = Tensor(np.array(predict_values).astype(dtype))
        target = Tensor(np.array(target_values).astype(dtype))
        weight = Tensor(np.array(unit_weights).astype(dtype))
        pos_weight = Tensor(np.array(unit_weights).astype(dtype))
        return net(predict, target, weight, pos_weight)

    # fp32, graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    output = run_loss(np.float32)
    expected = np.array(fp32_expected).astype(np.float32)
    np.testing.assert_almost_equal(expected, output.asnumpy())

    # fp32, pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    output = run_loss(np.float32)
    expected = np.array(fp32_expected)
    np.testing.assert_almost_equal(expected, output.asnumpy())

    # fp16, graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    output = run_loss(np.float16)
    expected = np.array([[0.611, 0.503, 0.2627],
                         [0.584, 0.5527, -0.437]]).astype(np.float16)
    np.testing.assert_almost_equal(expected, output.asnumpy(), decimal=3)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_reduction_mean_testcases():
    """reduction='mean' in graph mode with broadcast weight and pos_weight (fp32)."""
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    loss = Net("mean")
    predict = Tensor(np.arange(6).reshape(2, 3).astype(np.float32))
    target = Tensor(np.arange(34, 40).reshape(2, 3).astype(np.float32))
    weight = Tensor(np.array([2, 3, 1]).astype(np.float32))
    pos_weight = Tensor(np.array([6, 3, 4]).astype(np.float32))
    output = loss(predict, target, weight, pos_weight)
    expected = -113.55404
    # assert scalar
    # Adjacent float32 values near 113.55 are 2**-17 (~7.6e-6) apart, so an
    # absolute tolerance of 1e-5 alone accepts only about one ulp of error.
    # The relative tolerance leaves room for rounding in the reduction.
    assert math.isclose(output.asnumpy().tolist(), expected, rel_tol=1e-6, abs_tol=0.00001)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_reduction_sum_testcases():
    """reduction='sum' in graph mode with broadcast weight and pos_weight (fp32)."""
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    loss = Net("sum")
    predict = Tensor(np.arange(6, 12).reshape(2, 3).astype(np.float32))
    target = Tensor(np.arange(6).reshape(2, 3).astype(np.float32))
    weight = Tensor(np.array([3, 3, 4]).astype(np.float32))
    pos_weight = Tensor(np.array([6, 3, 4]).astype(np.float32))
    output = loss(predict, target, weight, pos_weight)
    expected = -333.96677
    # assert scalar
    # Adjacent float32 values near 333.97 are 2**-15 (~3.05e-5) apart, which is
    # wider than an absolute tolerance of 1e-5: with abs_tol alone the check
    # passes only for a bit-exact result. The relative tolerance allows a few
    # ulps of rounding in the reduction.
    assert math.isclose(output.asnumpy().tolist(), expected, rel_tol=1e-6, abs_tol=0.00001)