From: @jiangzg001
Reviewed-by: @liangchenghui, @wuxuejian
Signed-off-by: @liangchenghui
Tag: tags/v1.2.0-rc1
@@ -147,6 +147,8 @@ constexpr const char kNameCumSum[] = "CumSum";
constexpr const char kNameHuberLossGrad[] = "HuberLossGrad";
constexpr const char kNameSparseSoftmaxCrossEntropy[] = "SparseSoftmaxCrossEntropy";
constexpr const char kNameSparseSoftmaxCrossEntropyGrad[] = "SparseSoftmaxCrossEntropyGrad";
constexpr const char kNameNLLLoss[] = "NLLLoss";
constexpr const char kNameNLLLossGrad[] = "NLLLossGrad";
constexpr const char kNameTopK[] = "TopK";
constexpr const char kNameSoftmaxGrad[] = "SoftmaxGrad";
constexpr const char kNameMaxPool[] = "MaxPool";
@@ -0,0 +1,35 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "transform/graph_ir/op_declare/math_ops_declare.h"

namespace mindspore::transform {
// NLLLoss
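// Note: in these adapter maps, INPUT_MAP indices are 1-based while OUTPUT_MAP
// indices are 0-based, matching the convention of the surrounding op_declare files.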
INPUT_MAP(NLLLoss) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(target)}, {3, INPUT_DESC(weight)}};
ATTR_MAP(NLLLoss) = {{"reduction", ATTR_DESC(reduction, AnyTraits<std::string>())}};
OUTPUT_MAP(NLLLoss) = {{0, OUTPUT_DESC(y)}, {1, OUTPUT_DESC(total_weight)}};
REG_ADPT_DESC(NLLLoss, kNameNLLLoss, ADPT_DESC(NLLLoss))

// NLLLossGrad
INPUT_MAP(NLLLossGrad) = {{1, INPUT_DESC(x)},
                          {2, INPUT_DESC(y_grad)},
                          {3, INPUT_DESC(target)},
                          {4, INPUT_DESC(weight)},
                          {5, INPUT_DESC(total_weight)}};
ATTR_MAP(NLLLossGrad) = {{"reduction", ATTR_DESC(reduction, AnyTraits<std::string>())}};
OUTPUT_MAP(NLLLossGrad) = {{0, OUTPUT_DESC(x_grad)}};
REG_ADPT_DESC(NLLLossGrad, kNameNLLLossGrad, ADPT_DESC(NLLLossGrad))
}  // namespace mindspore::transform
@@ -0,0 +1,31 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_MATH_OPS_DECLARE_H_
#define MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_MATH_OPS_DECLARE_H_

#include <string>
#include <unordered_map>
#include "transform/graph_ir/op_declare/op_declare_macro.h"
#include "ops/math_ops.h"

namespace mindspore::transform {
DECLARE_OP_ADAPTER(NLLLoss)
DECLARE_OP_USE_OUTPUT(NLLLoss)

DECLARE_OP_ADAPTER(NLLLossGrad)
DECLARE_OP_USE_OUTPUT(NLLLossGrad)
}  // namespace mindspore::transform
#endif  // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_MATH_OPS_DECLARE_H_
@@ -761,6 +761,20 @@ def get_bprop_softmax_cross_entropy_with_logits(self):
    return bprop


@bprop_getters.register(P.NLLLoss)
def get_bprop_nll_loss(self):
    """Grad definition for `NLLLoss` operation."""
    nll_loss_grad = G.NLLLossGrad(reduction=self.reduction)

    def bprop(x, target, weight, out, dout):
        total_weight = out[1]
        dout_x = dout[0]
        dx = nll_loss_grad(x, dout_x, target, weight, total_weight)
        return dx, zeros_like(target), zeros_like(weight)

    return bprop
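For intuition about what this bprop wires together: the second forward output (`total_weight`) is saved and passed back into `NLLLossGrad` along with `dout[0]`. Below is a minimal NumPy sketch of the gradient that grad op computes, under my reading of the op's semantics; the helper name is hypothetical and this is an illustration, not the Ascend TBE kernel:

```python
import numpy as np

def nll_loss_grad_ref(x, dout, target, weight, total_weight, reduction="mean"):
    """Illustrative reference for the NLLLoss gradient w.r.t. x."""
    dx = np.zeros_like(x)
    for i in range(x.shape[0]):
        # Upstream gradient: per-sample for 'none', a single scalar otherwise.
        g = dout[i] if reduction == "none" else dout
        if reduction == "mean":
            g = g / total_weight  # 'mean' divides by the summed sample weights
        # Only the target column of each row receives gradient.
        dx[i, target[i]] = -weight[target[i]] * g
    return dx
```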
@bprop_getters.register(P.SparseSoftmaxCrossEntropyWithLogits)
def get_bprop_sparse_softmax_cross_entropy_with_logits(self):
    """Grad definition for `SparseSoftmaxCrossEntropyWithLogits` operation."""
@@ -353,3 +353,5 @@ from .conv3d_backprop_filter import _conv3d_backprop_filter_tbe
from .conv3d_transpose import _conv3d_transpose_tbe
from .lamb_apply_optimizer_assign import _lamb_apply_optimizer_assign_tbe
from .lamb_apply_weight_assign import _lamb_apply_weight_assign_tbe
from .nll_loss import _nll_loss_tbe
from .nll_loss_grad import _nll_loss_grad_tbe
| @@ -0,0 +1,40 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """NLLLoss op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
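# Note: the dtype_format row below lists supported dtypes positionally over
# inputs then outputs: (x, target, weight, y, total_weight).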
nll_loss_op_info = TBERegOp("NLLLoss") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("nll_loss.so") \
    .compute_cost(10) \
    .kernel_name("nll_loss") \
    .partial_flag(True) \
    .attr("reduction", "optional", "str", "all") \
    .input(0, "x", False, "required", "all") \
    .input(1, "target", False, "required", "all") \
    .input(2, "weight", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .output(1, "total_weight", False, "optional", "all") \
    .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()


@op_info_register(nll_loss_op_info)
def _nll_loss_tbe():
    """NLLLoss TBE register"""
    return
@@ -0,0 +1,41 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""NLLLossGrad op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

nll_loss_grad_op_info = TBERegOp("NLLLossGrad") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("nll_loss_grad.so") \
    .compute_cost(10) \
    .kernel_name("nll_loss_grad") \
    .partial_flag(True) \
    .attr("reduction", "optional", "str", "all") \
    .input(0, "x", False, "required", "all") \
    .input(1, "y_grad", False, "required", "all") \
    .input(2, "target", False, "required", "all") \
    .input(3, "weight", False, "required", "all") \
    .input(4, "total_weight", False, "required", "all") \
    .output(0, "x_grad", False, "required", "all") \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.I32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()


@op_info_register(nll_loss_grad_op_info)
def _nll_loss_grad_tbe():
    """NLLLossGrad TBE register"""
    return
@@ -72,7 +72,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Adam
                     AvgPool, Conv2DBackpropInput, ComputeAccidentalHits,
                     MaxPoolWithArgmax, OneHot, Pad, MirrorPad, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid,
                     ResizeBilinear, Sigmoid,
-                    SigmoidCrossEntropyWithLogits,
+                    SigmoidCrossEntropyWithLogits, NLLLoss,
                     SmoothL1Loss, Softmax, Softsign, Softplus, LRN, RNNTLoss, DynamicRNN, DynamicGRUV2,
                     SoftmaxCrossEntropyWithLogits, ROIAlign,
                     SparseSoftmaxCrossEntropyWithLogits, Tanh,
@@ -147,6 +147,7 @@ __all__ = [
    'SoftmaxCrossEntropyWithLogits',
    'ROIAlign',
    'SparseSoftmaxCrossEntropyWithLogits',
    'NLLLoss',
    'SGD',
    'ApplyMomentum',
    'ExpandDims',
@@ -1746,6 +1746,35 @@ class SliceGrad(PrimitiveWithInfer):
                'value': None}


class NLLLossGrad(PrimitiveWithInfer):
    """Computes the gradients of `NLLLoss`."""

    @prim_attr_register
    def __init__(self, reduction="mean"):
        """Initialize NLLLossGrad"""
        self.init_prim_io_names(inputs=['x', 'y_grad', 'target', 'weight', 'total_weight'], outputs=['x_grad'])
        self.reduction = validator.check_string(reduction, ['none', 'sum', 'mean'], 'reduction', self.name)
        self.add_prim_attr('reduction', self.reduction)

    def infer_shape(self, x_shape, y_grad_shape, t_shape, w_shape, tw_shape):
        validator.check_int(len(x_shape), [1, 2], Rel.IN, "x rank", self.name)
        validator.check_int(len(t_shape), 1, Rel.EQ, "target rank", self.name)
        validator.check_int(len(w_shape), 1, Rel.EQ, "weight rank", self.name)
        validator.check("input_shape[0]", x_shape[0], "target_shape", t_shape[0], Rel.EQ, self.name)
        validator.check("input_shape[1]", x_shape[1], "weight_shape", w_shape[0], Rel.EQ, self.name)
        return x_shape

    def infer_dtype(self, x_dtype, y_grad_dtype, t_dtype, w_dtype, tw_dtype):
        valid_dtypes = (mstype.float16, mstype.float32)
        validator.check_tensor_dtype_valid("x_dtype", x_dtype, valid_dtypes, self.name)
        validator.check_tensor_dtype_valid("y_grad_dtype", y_grad_dtype, valid_dtypes, self.name)
        validator.check_tensor_dtype_valid("t_dtype", t_dtype, mstype.int32, self.name)
        validator.check_tensor_dtype_valid("w_dtype", w_dtype, valid_dtypes, self.name)
        validator.check_tensor_dtype_valid("tw_dtype", tw_dtype, valid_dtypes, self.name)
        validator.check('tw_shape_dtype', tw_dtype, 'w_shape_dtype', w_dtype, Rel.EQ, self.name)
        return x_dtype


class SmoothL1LossGrad(PrimitiveWithInfer):
    """Computes gradient for prediction on SmoothL1Loss."""
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -1917,6 +1917,66 @@ class TopK(PrimitiveWithInfer):
                'value': None}


class NLLLoss(PrimitiveWithInfer):
    r"""
    Gets the negative log likelihood loss between logits and labels.

    Args:
        reduction (string): Apply a specific reduction method to the output: 'none', 'mean' or 'sum'.
            Default: "mean".

    Inputs:
        - **input** (Tensor) - Input logits, with shape :math:`(N, C)`. Data type must be float32 or float16.
        - **target** (Tensor) - Ground truth labels, with shape :math:`(N)`. Data type must be int32.
        - **weight** (Tensor) - The rescaling weight for each class, with shape :math:`(C)`. Data type must be
          float32 or float16.

    Outputs:
        Tuple of 2 tensors, `loss` and `total_weight`. When `reduction` is `none` and `input` is a 2-D
        tensor, the shape of `loss` is `(N,)`; otherwise `loss` is a scalar. `total_weight` is always a
        scalar. The data types of `loss` and `total_weight` are the same as those of `input` and `weight`,
        respectively.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> input = Tensor(np.array([[0.5488135, 0.71518934],
        ...                          [0.60276335, 0.5448832],
        ...                          [0.4236548, 0.6458941]]).astype(np.float32))
        >>> target = Tensor(np.array([0, 0, 0]).astype(np.int32))
        >>> weight = Tensor(np.array([0.3834415, 0.79172504]).astype(np.float32))
        >>> nll_loss = ops.NLLLoss(reduction="mean")
        >>> loss, total_weight = nll_loss(input, target, weight)
        >>> print(loss)
        -0.52507716
        >>> print(total_weight)
        1.1503246
    """
    @prim_attr_register
    def __init__(self, reduction="mean"):
        """Initialize NLLLoss"""
        self.init_prim_io_names(inputs=['x', 'target', 'weight'], outputs=['loss', 'total_weight'])
        self.reduction = validator.check_string(reduction.lower(), ['none', 'sum', 'mean'], 'reduction', self.name)
        self.add_prim_attr('reduction', self.reduction)

    def infer_shape(self, x_shape, t_shape, w_shape):
        validator.check_int(len(x_shape), [1, 2], Rel.IN, "x rank", self.name)
        validator.check_int(len(t_shape), 1, Rel.EQ, "target rank", self.name)
        validator.check_int(len(w_shape), 1, Rel.EQ, "weight rank", self.name)
        validator.check("input_shape[0]", x_shape[0], "target_shape", t_shape[0], Rel.EQ, self.name)
        validator.check("input_shape[1]", x_shape[1], "weight_shape", w_shape[0], Rel.EQ, self.name)
        if self.reduction == "none":
            return t_shape, ()
        return (), ()

    def infer_dtype(self, x_dtype, t_dtype, w_dtype):
        valid_dtypes = (mstype.float16, mstype.float32)
        validator.check_tensor_dtype_valid("x_dtype", x_dtype, valid_dtypes, self.name)
        validator.check_tensor_dtype_valid("t_dtype", t_dtype, mstype.int32, self.name)
        validator.check_tensor_dtype_valid("w_dtype", w_dtype, valid_dtypes, self.name)
        return x_dtype, w_dtype
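As a sanity check on the numbers in the docstring example, here is a small NumPy sketch of the forward computation as specified above (the helper name is mine; this is illustrative, not the Ascend kernel):

```python
import numpy as np

def nll_loss_ref(x, target, weight, reduction="mean"):
    """Illustrative reference for the NLLLoss forward pass."""
    w = weight[target]                                # per-sample weights, shape (N,)
    losses = -w * x[np.arange(x.shape[0]), target]    # shape (N,)
    total_weight = w.sum()
    if reduction == "none":
        return losses, total_weight
    if reduction == "sum":
        return losses.sum(), total_weight
    return losses.sum() / total_weight, total_weight  # 'mean'

x = np.array([[0.5488135, 0.71518934],
              [0.60276335, 0.5448832],
              [0.4236548, 0.6458941]], np.float32)
target = np.array([0, 0, 0], np.int32)
weight = np.array([0.3834415, 0.79172504], np.float32)
print(nll_loss_ref(x, target, weight))  # ≈ (-0.52507716, 1.1503246)
```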
class SoftmaxCrossEntropyWithLogits(PrimitiveWithInfer):
    r"""
    Gets the softmax cross-entropy value between logits and labels with one-hot encoding.
@@ -242,6 +242,18 @@ class BatchNorm3d(nn.Cell):
        return bn3d_out


class NLLLoss(nn.Cell):
    """NLLLoss net definition"""

    def __init__(self, reduction):
        super(NLLLoss, self).__init__()
        self.nll_loss = P.NLLLoss(reduction=reduction)

    def construct(self, input_x, target, weight):
        loss = self.nll_loss(input_x, target, weight)
        return loss


class ClipByNorm(nn.Cell):
    """ClipByNorm net definition"""
@@ -1253,6 +1265,12 @@ test_case_math_ops = [
        'block': Moments(axis=(), keep_dims=False),
        'desc_inputs': [Tensor(np.random.rand(3, 16, 5, 4).astype(np.float32))],
        'skip': ['backward']}),
    ('NLLLoss', {
        'block': NLLLoss(reduction="mean"),
        'desc_inputs': [Tensor(np.random.rand(3, 16), mstype.float32),
                        Tensor(np.random.rand(3), mstype.int32),
                        Tensor(np.random.rand(16), mstype.float32)],
        'desc_bprop': [(Tensor(np.random.rand(1), mstype.float32), Tensor(np.random.rand(1), mstype.float32))]}),
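    # Note: NLLLoss has two outputs, so desc_bprop supplies a
    # (d_loss, d_total_weight) pair rather than a single tensor.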
    ('BatchNorm3d', {
        'block': BatchNorm3d(num_features=3),
        'desc_inputs': [Tensor(np.random.rand(3, 3, 3, 5, 4).astype(np.float32))],