From: @liu_xiao_93 (tags/v1.2.0-rc1)
| @@ -32,6 +32,7 @@ | |||
| #include "backend/optimizer/ascend/ir_fission/layer_norm_grad_split.h" | |||
| #include "backend/optimizer/ascend/ir_fission/unsorted_segment_sum_fission.h" | |||
| #include "backend/optimizer/ascend/ir_fission/gather_v2_ds_fission.h" | |||
| #include "backend/optimizer/ascend/ir_fission/bce_with_logits_loss_fission.h" | |||
| #include "backend/optimizer/pass/communication_op_fusion.h" | |||
| #include "backend/optimizer/ascend/ir_fusion/square_sum_fusion.h" | |||
| #include "backend/optimizer/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.h" | |||
| @@ -191,6 +192,7 @@ void AddAscendIRFusionPass(PassManager *ir_fusion_pm) { | |||
| ir_fusion_pm->AddPass(std::make_shared<ReduceMinFission>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<UnsortSegmentSumFission>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<GatherV2DsFission>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<BCEWithLogitsLossFission>()); | |||
| } | |||
| } // namespace | |||
| void AscendGraphKernelCommonProcess(const std::shared_ptr<session::KernelGraph> &kernel_graph) { | |||
| @@ -333,6 +335,7 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne | |||
| ir_fusion_pm->AddPass(std::make_shared<InsertPlaceholderForDynamicGRUV2>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<DynamicRnnGradFissionV2>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<EraseVisitAttr>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<BCEWithLogitsLossFission>()); | |||
| optimizer->AddPassManager(ir_fusion_pm); | |||
| (void)optimizer->Optimize(kernel_graph); | |||
| @@ -0,0 +1,100 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/optimizer/ascend/ir_fission/bce_with_logits_loss_fission.h" | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <algorithm> | |||
| #include "utils/utils.h" | |||
| #include "utils/ms_context.h" | |||
| #include "backend/optimizer/common/helper.h" | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "utils/trace_base.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| AnfNodePtr AddReduceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| // Create a new BCEWithLogitsLoss node; its output shape matches the input, since the reduction is split out | |||
| std::vector<AnfNodePtr> new_bce_inputs = { | |||
| NewValueNode(std::make_shared<Primitive>(prim::kPrimBCEWithLogitsLoss->name()))}; | |||
| new_bce_inputs.insert(new_bce_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end()); | |||
| CNodePtr new_cnode = func_graph->NewCNode(new_bce_inputs); | |||
| MS_EXCEPTION_IF_NULL(new_cnode); | |||
| auto predict_input = cnode->inputs()[1]; | |||
| std::vector<TypeId> new_node_dtype = {AnfAlgo::GetOutputInferDataType(predict_input, 0)}; | |||
| std::vector<std::vector<size_t>> new_node_shape = {AnfAlgo::GetOutputInferShape(predict_input, 0)}; | |||
| AnfAlgo::SetOutputInferTypeAndShape(new_node_dtype, new_node_shape, new_cnode.get()); | |||
| // Add reduce node | |||
| std::string reduction = AnfAlgo::GetNodeAttr<std::string>(node, kAttrReduction); | |||
| MS_LOG(INFO) << "Create reduce node, reduction attr is: " << reduction; | |||
| std::vector<AnfNodePtr> reduce_inputs; | |||
| if (reduction == "sum") { | |||
| reduce_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())), new_cnode}; | |||
| } else if (reduction == "mean") { | |||
| reduce_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceMean->name())), new_cnode}; | |||
| } else { | |||
| MS_LOG(INFO) << "Reduction attr is not mean or sum, can not do fission."; | |||
| return nullptr; | |||
| } | |||
| auto reduce_node = func_graph->NewCNode(reduce_inputs); | |||
| MS_EXCEPTION_IF_NULL(reduce_node); | |||
| auto type = AnfAlgo::GetOutputInferDataType(node, 0); | |||
| if (type == kNumberTypeFloat16) { | |||
| type = kNumberTypeFloat32; | |||
| } | |||
| std::vector<std::vector<size_t>> shape = {AnfAlgo::GetOutputInferShape(node, 0)}; | |||
| AnfAlgo::SetOutputInferTypeAndShape({type}, shape, reduce_node.get()); | |||
| AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int64_t>{}), reduce_node); | |||
| AnfAlgo::SetNodeAttr("keep_dims", MakeValue(false), reduce_node); | |||
| AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reduce_node); | |||
| reduce_node->set_scope(cnode->scope()); | |||
| return reduce_node; | |||
| } | |||
| } // namespace | |||
| const BaseRef BCEWithLogitsLossFission::DefinePattern() const { | |||
| VarPtr Xs = std::make_shared<SeqVar>(); | |||
| MS_EXCEPTION_IF_NULL(Xs); | |||
| return VectorRef({prim::kPrimBCEWithLogitsLoss, Xs}); | |||
| } | |||
| const AnfNodePtr BCEWithLogitsLossFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const EquivPtr &) const { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (GetBoolAttr(cnode, kAttrVisited)) { | |||
| return nullptr; | |||
| } | |||
| AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); | |||
| if (cnode->inputs().size() == 0) { | |||
| return nullptr; | |||
| } | |||
| if (!AnfAlgo::HasNodeAttr(kAttrReduction, cnode)) { | |||
| MS_LOG(INFO) << "Node has no reduction attr, skip fission."; | |||
| return nullptr; | |||
| } | |||
| return AddReduceNode(func_graph, node); | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
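Review note: the net effect of this pass is to rewrite a `BCEWithLogitsLoss` whose `reduction` is `mean` or `sum` into an unreduced `BCEWithLogitsLoss` followed by a `ReduceMean`/`ReduceSum` over all axes (empty `axis`, `keep_dims=false`). A minimal NumPy sketch of that equivalence (NumPy stands in for the Ascend kernels; the helper names are illustrative, not part of the patch):

```python
import numpy as np

def bce_with_logits(x, t):
    # Numerically stable elementwise BCE-with-logits, i.e. reduction="none".
    return np.maximum(x, 0) - x * t + np.log1p(np.exp(-np.abs(x)))

def fused(x, t, reduction):
    # What the single op computes before the fission.
    loss = bce_with_logits(x, t)
    return {"none": loss, "mean": loss.mean(), "sum": loss.sum()}[reduction]

def fissioned(x, t, reduction):
    # What the graph computes after the pass: unreduced op + Reduce node.
    loss = bce_with_logits(x, t)  # new BCEWithLogitsLoss, same shape as predict
    if reduction == "mean":
        return loss.mean()        # ReduceMean, axis=[], keep_dims=False
    if reduction == "sum":
        return loss.sum()         # ReduceSum, axis=[], keep_dims=False
    return loss                   # the pass bails out for other reduction values

x = np.random.randn(2, 3).astype(np.float32)
t = np.random.rand(2, 3).astype(np.float32)
for r in ("mean", "sum"):
    assert np.allclose(fused(x, t, r), fissioned(x, t, r))
```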
| @@ -0,0 +1,34 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_BCE_WITH_LOGITS_LOSS_FISSION_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_BCE_WITH_LOGITS_LOSS_FISSION_H_ | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| #include "backend/optimizer/common/helper.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class BCEWithLogitsLossFission : public PatternProcessPass { | |||
| public: | |||
| explicit BCEWithLogitsLossFission(bool multigraph = true) | |||
| : PatternProcessPass("bce_with_logits_loss_fission", multigraph) {} | |||
| ~BCEWithLogitsLossFission() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_BCE_WITH_LOGITS_LOSS_FISSION_H_ | |||
| @@ -344,6 +344,7 @@ constexpr auto kAttrWaitEventStream = "wait_event_stream"; | |||
| constexpr auto kAttrIndex = "index"; | |||
| constexpr auto kAttrSplitDim = "split_dim"; | |||
| constexpr auto kAttrNumSplit = "num_split"; | |||
| constexpr auto kAttrReduction = "reduction"; | |||
| constexpr auto kAttrOutputNum = "output_num"; | |||
| constexpr auto kAttrSizeSplits = "size_splits"; | |||
| constexpr auto kAttrOutputDefault = "output_default"; | |||
| @@ -282,6 +282,7 @@ inline const PrimitivePtr kPrimSparseApplyProximalAdagrad = std::make_shared<Pri | |||
| inline const PrimitivePtr kPrimFusedAdam = std::make_shared<Primitive>("FusedAdam"); | |||
| inline const PrimitivePtr kPrimFusedAdamWeightDecay = std::make_shared<Primitive>("FusedAdamWeightDecay"); | |||
| inline const PrimitivePtr kPrimSGD = std::make_shared<Primitive>("SGD"); | |||
| inline const PrimitivePtr kPrimBCEWithLogitsLoss = std::make_shared<Primitive>("BCEWithLogitsLoss"); | |||
| inline const PrimitivePtr kPrimClipByNormNoDivSum = std::make_shared<Primitive>("ClipByNormNoDivSum"); | |||
| inline const PrimitivePtr kPrimTensorMove = std::make_shared<Primitive>("TensorMove"); | |||
| inline const PrimitivePtr kPrimL2Normalize = std::make_shared<Primitive>("L2Normalize"); | |||
| @@ -21,8 +21,8 @@ It shows how well the model works on a dataset and the optimization target which | |||
| from .loss import L1Loss, MSELoss, SmoothL1Loss, \ | |||
| SoftmaxCrossEntropyWithLogits, BCELoss, CosineEmbeddingLoss, \ | |||
| SampledSoftmaxLoss, DiceLoss | |||
| SampledSoftmaxLoss, DiceLoss, BCEWithLogitsLoss | |||
| __all__ = ['L1Loss', 'MSELoss', 'SmoothL1Loss', | |||
| 'SoftmaxCrossEntropyWithLogits', 'BCELoss', | |||
| 'SoftmaxCrossEntropyWithLogits', 'BCELoss', 'BCEWithLogitsLoss', | |||
| 'CosineEmbeddingLoss', 'SampledSoftmaxLoss', 'DiceLoss'] | |||
| @@ -15,6 +15,7 @@ | |||
| """loss""" | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore.common.parameter import Parameter | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops import functional as F | |||
| from mindspore.ops.primitive import constexpr | |||
| @@ -739,3 +740,86 @@ class CosineEmbeddingLoss(_Loss): | |||
| output_unreduced = pos_part + neg_part | |||
| return self.get_loss(output_unreduced) | |||
| class BCEWithLogitsLoss(_Loss): | |||
| r""" | |||
| Applies the sigmoid activation function to the input `predict`, then computes the binary cross | |||
| entropy between the result and the target. | |||
| Sets input predict as `X`, input target as `Y`, output as `L`. Then, | |||
| .. math:: | |||
| p_{ij} = sigmoid(X_{ij}) = \frac{1}{1 + e^{-X_{ij}}} | |||
| .. math:: | |||
| L_{ij} = -[Y_{ij} * ln(p_{ij}) + (1 - Y_{ij})ln(1 - p_{ij})] | |||
| Then, | |||
| .. math:: | |||
| \ell(x, y) = \begin{cases} | |||
| L, & \text{if reduction} = \text{`none';}\\ | |||
| \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ | |||
| \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} | |||
| \end{cases} | |||
| Args: | |||
| reduction (str): Type of reduction to be applied to the loss. The optional values are "mean", "sum", and "none". | |||
| If "none", do not perform reduction. Default: "mean". | |||
| weight (Tensor, optional): A rescaling weight applied to the loss of each batch element. | |||
| If not None, it must be broadcastable to a tensor with the shape of `predict`, | |||
| and its data type must be float16 or float32. Default: None. | |||
| pos_weight (Tensor, optional): A weight of positive examples. Must be a vector with length equal to the | |||
| number of classes. If not None, it must be broadcastable to a tensor with the shape of `predict`, | |||
| and its data type must be float16 or float32. Default: None. | |||
| Inputs: | |||
| - **predict** (Tensor) - Input logits. The data type must be float16 or float32. | |||
| - **target** (Tensor) - Ground truth label. Has the same data type and shape as `predict`. | |||
| Outputs: | |||
| Tensor or Scalar. If `reduction` is "none", the output is a tensor with the same shape and type as `predict`; otherwise, a scalar. | |||
| Raises: | |||
| TypeError: If data type of `predict` or `target` is neither float16 nor float32. | |||
| TypeError: If `weight` or `pos_weight` is Parameter. | |||
| TypeError: If data type of `weight` or `pos_weight` is neither float16 nor float32. | |||
| ValueError: If `weight` or `pos_weight` cannot be broadcast to a tensor with the shape of `predict`. | |||
| ValueError: If `reduction` is not one of 'none', 'mean', 'sum'. | |||
| Supported Platforms: | |||
| ``Ascend`` | |||
| Examples: | |||
| >>> predict = Tensor(np.array([[-0.8, 1.2, 0.7], [-0.1, -0.4, 0.7]]).astype(np.float32)) | |||
| >>> target = Tensor(np.array([[0.3, 0.8, 1.2], [-0.6, 0.1, 2.2]]).astype(np.float32)) | |||
| >>> loss = nn.BCEWithLogitsLoss() | |||
| >>> output = loss(predict, target) | |||
| >>> print(output) | |||
| 0.3463612 | |||
| """ | |||
| def __init__(self, reduction='mean', weight=None, pos_weight=None): | |||
| super(BCEWithLogitsLoss, self).__init__() | |||
| self.bce_with_logits_loss = P.BCEWithLogitsLoss(reduction=reduction) | |||
| if isinstance(weight, Parameter): | |||
| raise TypeError(f"For {self.cls_name}, weight can not be Parameter.") | |||
| if isinstance(pos_weight, Parameter): | |||
| raise TypeError(f"For {self.cls_name}, pos_weight can not be Parameter.") | |||
| self.weight = weight | |||
| self.pos_weight = pos_weight | |||
| self.ones = P.OnesLike() | |||
| def construct(self, predict, target): | |||
| ones_input = self.ones(predict) | |||
| if self.weight is not None: | |||
| weight = self.weight | |||
| else: | |||
| weight = ones_input | |||
| if self.pos_weight is not None: | |||
| pos_weight = self.pos_weight | |||
| else: | |||
| pos_weight = ones_input | |||
| loss = self.bce_with_logits_loss(predict, target, weight, pos_weight) | |||
| return loss | |||
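The docstring example can be cross-checked against a plain NumPy reference of the formula above (a sketch of the math, not the operator's actual kernel; `bce_with_logits_ref` is a name invented here):

```python
import numpy as np

def bce_with_logits_ref(predict, target, weight=None, pos_weight=None, reduction="mean"):
    # L_ij = -weight * [pos_weight * Y_ij * ln(p_ij) + (1 - Y_ij) * ln(1 - p_ij)], p = sigmoid(X)
    p = 1.0 / (1.0 + np.exp(-predict))
    w = np.ones_like(predict) if weight is None else weight
    pw = np.ones_like(predict) if pos_weight is None else pos_weight
    loss = -w * (pw * target * np.log(p) + (1 - target) * np.log(1 - p))
    return {"none": loss, "mean": loss.mean(), "sum": loss.sum()}[reduction]

predict = np.array([[-0.8, 1.2, 0.7], [-0.1, -0.4, 0.7]], np.float32)
target = np.array([[0.3, 0.8, 1.2], [-0.6, 0.1, 2.2]], np.float32)
print(bce_with_logits_ref(predict, target))  # ~0.3463612, matching the example output
```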
| @@ -1212,6 +1212,32 @@ def get_bprop_binary_cross_entropy(self): | |||
| return bprop | |||
| @bprop_getters.register(P.BCEWithLogitsLoss) | |||
| def get_bprop_ce_with_logits_loss(self): | |||
| """Grad definition for `BCEWithLogitsLoss` operation.""" | |||
| reduction = self.reduction | |||
| mul = P.Mul() | |||
| sigmoid = P.Sigmoid() | |||
| add = P.TensorAdd() | |||
| sub = P.Sub() | |||
| size = P.Size() | |||
| def bprop(predict, target, weight, pos_weight, out, dout): | |||
| sigmoid_input = sigmoid(predict) | |||
| if pos_weight is not None: | |||
| t = mul(target, pos_weight) | |||
| dx = mul(sub(mul(sub(add(t, 1), target), sigmoid_input), t), dout) | |||
| else: | |||
| dx = mul((sigmoid_input - target), dout) | |||
| if weight is not None: | |||
| dx = mul(dx, weight) | |||
| if reduction == 'mean': | |||
| dx = dx / size(dx) | |||
| return dx, zeros_like(target), zeros_like(weight), zeros_like(pos_weight) | |||
| return bprop | |||
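Derivation check for the closed form used above: with s = sigmoid(x) and t* = pos_weight * target, differentiating the per-element loss gives dL/dx = (t* + 1 - target) * s - t*, which collapses to s - target when pos_weight is all ones. A quick NumPy finite-difference sketch of that identity (illustrative only, not part of the patch):

```python
import numpy as np

def loss_elem(x, t, pw):
    # Per-element loss with pos_weight (weight omitted; it only scales dx).
    p = 1.0 / (1.0 + np.exp(-x))
    return -(pw * t * np.log(p) + (1 - t) * np.log(1 - p))

def dx_closed_form(x, t, pw):
    # Mirrors the bprop: ((t*pw + 1 - t) * sigmoid(x) - t*pw) * dout, with dout = 1.
    s = 1.0 / (1.0 + np.exp(-x))
    t_star = t * pw
    return (t_star + 1 - t) * s - t_star

x = np.array([-0.8, 1.2, 0.7])
t = np.array([0.3, 0.8, 1.2])
pw = np.array([1.5, 0.5, 2.0])
eps = 1e-6
numeric = (loss_elem(x + eps, t, pw) - loss_elem(x - eps, t, pw)) / (2 * eps)
assert np.allclose(numeric, dx_closed_form(x, t, pw), atol=1e-5)
```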
| @bprop_getters.register(P.KLDivLoss) | |||
| def get_bprop_kl_div_loss(self): | |||
| """Grad definition for `KLDivLoss` operation.""" | |||
| @@ -254,6 +254,7 @@ from .prelu import _prelu_tbe | |||
| from .prelu_grad import _prelu_grad_tbe | |||
| from .binary_cross_entropy import _binary_cross_entropy_tbe | |||
| from .binary_cross_entropy_grad import _binary_cross_entropy_grad_tbe | |||
| from .bce_with_logits_loss import _bce_with_logits_loss_op_tbe | |||
| from .sin import _sin_tbe | |||
| from .cos import _cos_tbe | |||
| from .tan import _tan_tbe | |||
| @@ -0,0 +1,41 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """BCEWithLogitsLoss op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| bce_with_logits_loss_op_info = TBERegOp("BCEWithLogitsLoss") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("sigmoid_cross_entropy_with_logits_v2.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("sigmoid_cross_entropy_with_logits_v2") \ | |||
| .partial_flag(True) \ | |||
| .op_pattern("dynamicFormat") \ | |||
| .attr("reduction", "optional", "str", "all", "mean") \ | |||
| .input(0, "predict", False, "required", "all") \ | |||
| .input(1, "target", False, "required", "all") \ | |||
| .input(2, "weight", False, "optional", "all") \ | |||
| .input(3, "pos_weight", False, "optional", "all") \ | |||
| .output(0, "loss", False, "required", "all") \ | |||
| .dtype_format(DataType.None_None, DataType.None_None, DataType.None_None, DataType.None_None, | |||
| DataType.None_None) \ | |||
| .get_op_info() | |||
| @op_info_register(bce_with_logits_loss_op_info) | |||
| def _bce_with_logits_loss_op_tbe(): | |||
| """BCEWithLogitsLoss TBE register""" | |||
| return | |||
| @@ -74,7 +74,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Adam | |||
| AvgPool, Conv2DBackpropInput, ComputeAccidentalHits, | |||
| MaxPoolWithArgmax, OneHot, Pad, MirrorPad, Mish, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid, | |||
| ResizeBilinear, Sigmoid, SeLU, | |||
| SigmoidCrossEntropyWithLogits, NLLLoss, | |||
| SigmoidCrossEntropyWithLogits, NLLLoss, BCEWithLogitsLoss, | |||
| SmoothL1Loss, Softmax, Softsign, Softplus, LRN, RNNTLoss, DynamicRNN, DynamicGRUV2, | |||
| SoftmaxCrossEntropyWithLogits, ROIAlign, | |||
| SparseSoftmaxCrossEntropyWithLogits, Tanh, | |||
| @@ -149,6 +149,7 @@ __all__ = [ | |||
| 'Softsign', | |||
| 'LogSoftmax', | |||
| 'SoftmaxCrossEntropyWithLogits', | |||
| 'BCEWithLogitsLoss', | |||
| 'ROIAlign', | |||
| 'SparseSoftmaxCrossEntropyWithLogits', | |||
| 'NLLLoss', | |||
| @@ -20,7 +20,6 @@ import operator | |||
| from functools import reduce, partial | |||
| from mindspore import log as logger | |||
| from mindspore._checkparam import _check_3d_int_or_tuple | |||
| from mindspore import log as logger | |||
| import numpy as np | |||
| from ... import context | |||
| from .. import signature as sig | |||
| @@ -3701,6 +3700,99 @@ class SigmoidCrossEntropyWithLogits(PrimitiveWithInfer): | |||
| return x_dtype | |||
| class BCEWithLogitsLoss(PrimitiveWithInfer): | |||
| r""" | |||
| Applies the sigmoid activation function to the input `predict`, then computes the binary cross | |||
| entropy between the result and the target. | |||
| Sets input predict as `X`, input target as `Y`, output as `L`. Then, | |||
| .. math:: | |||
| p_{ij} = sigmoid(X_{ij}) = \frac{1}{1 + e^{-X_{ij}}} | |||
| .. math:: | |||
| L_{ij} = -[Y_{ij} * log(p_{ij}) + (1 - Y_{ij})log(1 - p_{ij})] | |||
| Then, | |||
| .. math:: | |||
| \ell(x, y) = \begin{cases} | |||
| L, & \text{if reduction} = \text{`none';}\\ | |||
| \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ | |||
| \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} | |||
| \end{cases} | |||
| Args: | |||
| reduction (str): Type of reduction to be applied to the loss. The optional values are "mean", "sum", and "none". | |||
| If "none", do not perform reduction. Default: "mean". | |||
| Inputs: | |||
| - **predict** (Tensor) - Input logits. Data type must be float16 or float32. | |||
| - **target** (Tensor) - Ground truth label. Has the same shape as `predict`. | |||
| Data type must be float16 or float32. | |||
| - **weight** (Tensor) - A rescaling weight applied to the loss of each batch element. It must be | |||
| broadcastable to a tensor with the shape of `predict`. Data type must be float16 or float32. | |||
| - **pos_weight** (Tensor) - A weight of positive examples. Must be a vector with length equal to the | |||
| number of classes. It must be broadcastable to a tensor with the shape of `predict`. | |||
| Data type must be float16 or float32. | |||
| Outputs: | |||
| Tensor or Scalar. If `reduction` is "none", the output is a tensor with the same shape and type as `predict`; otherwise, a scalar. | |||
| Raises: | |||
| TypeError: If data type of any input is neither float16 nor float32. | |||
| ValueError: If `weight` or `pos_weight` cannot be broadcast to a tensor with the shape of `predict`. | |||
| ValueError: If `reduction` is not one of 'none', 'mean', 'sum'. | |||
| Supported Platforms: | |||
| ``Ascend`` | |||
| Examples: | |||
| >>> predict = Tensor(np.array([[-0.8, 1.2, 0.7], [-0.1, -0.4, 0.7]]).astype(np.float32)) | |||
| >>> target = Tensor(np.array([[0.3, 0.8, 1.2], [-0.6, 0.1, 2.2]]).astype(np.float32)) | |||
| >>> weight = Tensor(np.array([1.0, 1.0, 1.0]).astype(np.float32)) | |||
| >>> pos_weight = Tensor(np.array([1.0, 1.0, 1.0]).astype(np.float32)) | |||
| >>> loss = ops.BCEWithLogitsLoss() | |||
| >>> output = loss(predict, target, weight, pos_weight) | |||
| >>> print(output) | |||
| 0.3463612 | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, reduction='mean'): | |||
| """Initialize BCEWithLogitsLoss""" | |||
| self.reduction = validator.check_string(reduction, ['none', 'sum', 'mean'], 'reduction', self.name) | |||
| def infer_shape(self, predict, target, weight, pos_weight): | |||
| validator.check('predict_shape', predict, 'target_shape', target, Rel.EQ, self.name) | |||
| reversed_weight_shape = tuple(reversed(weight)) | |||
| reversed_predict = tuple(reversed(predict)) | |||
| for i, v in enumerate(reversed_weight_shape): | |||
| if v not in (reversed_predict[i], 1): | |||
| raise ValueError(f"For {self.name}, shapes cannot be broadcast. " | |||
| f"predict: {tuple(predict)}, weight shape: {tuple(weight)}.") | |||
| reversed_pos_shape = tuple(reversed(pos_weight)) | |||
| for i, v in enumerate(reversed_pos_shape): | |||
| if v not in (reversed_predict[i], 1): | |||
| raise ValueError(f"For {self.name}, shapes cannot be broadcast. " | |||
| f"predict: {tuple(predict)}, pos_weight shape: {tuple(pos_weight)}.") | |||
| if self.reduction in ('mean', 'sum'): | |||
| shape = [] | |||
| else: | |||
| shape = predict | |||
| return shape | |||
| def infer_dtype(self, predict, target, weight, pos_weight): | |||
| validator.check_tensor_dtype_valid('predict dtype', predict, [mstype.float16, mstype.float32], self.name) | |||
| validator.check_tensor_dtype_valid('target dtype', target, [mstype.float16, mstype.float32], self.name) | |||
| validator.check_tensor_dtype_valid('weight dtype', weight, [mstype.float16, mstype.float32], self.name) | |||
| validator.check_tensor_dtype_valid('pos_weight dtype', pos_weight, [mstype.float16, mstype.float32], self.name) | |||
| return predict | |||
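The rule `infer_shape` enforces is the standard right-aligned broadcast check: walking the `weight`/`pos_weight` dims from the trailing axis, each must equal the matching `predict` dim or be 1. A standalone sketch of that check (hypothetical helper, not part of this patch; like the code above, it assumes the weight shape has no more dims than `predict`):

```python
def check_broadcastable(small, big, name="weight"):
    # Right-align the two shapes; every trailing dim of `small`
    # must equal the matching dim of `big` or be 1.
    rev_big = tuple(reversed(big))
    for i, v in enumerate(reversed(small)):
        if v not in (rev_big[i], 1):
            raise ValueError(f"{name} shape {small} cannot broadcast to predict shape {big}")

check_broadcastable((3,), (2, 3))    # ok: per-class weight vector
check_broadcastable((1, 3), (2, 3))  # ok: leading dim of 1 broadcasts
# check_broadcastable((2,), (2, 3))  # would raise: trailing dim 2 != 3 and != 1
```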
| class Pad(PrimitiveWithInfer): | |||
| """ | |||
| Pads the input tensor according to the paddings. | |||
| @@ -2058,6 +2058,10 @@ test_case_nn_ops = [ | |||
| 'block': P.L2Loss(), | |||
| 'desc_inputs': [Tensor(np.array([[1, 1], [2, 2], [3, 3], [4, 4]]), mstype.float16)], | |||
| 'desc_bprop': []}), | |||
| ('BCEWithLogitsLoss', { | |||
| 'block': P.BCEWithLogitsLoss(), | |||
| 'desc_inputs': [[3, 3], [3, 3], [3, 3], [3, 3]], | |||
| 'desc_bprop': []}), | |||
| ('ResizeBilinear', { | |||
| 'block': P.ResizeBilinear((5, 5)), | |||
| 'desc_inputs': [Tensor([[[[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]]], mstype.float16)], | |||