From 80b8124d60f8bd118aed6ce11f4e200f8e2236a9 Mon Sep 17 00:00:00 2001 From: bai-yangfan Date: Wed, 16 Dec 2020 11:16:04 +0800 Subject: [PATCH] nn_notes --- mindspore/nn/metrics/error.py | 2 ++ mindspore/nn/metrics/fbeta.py | 5 +++-- mindspore/nn/metrics/precision.py | 3 +++ mindspore/nn/metrics/recall.py | 2 ++ mindspore/nn/wrap/cell_wrapper.py | 5 +++-- 5 files changed, 13 insertions(+), 4 deletions(-) diff --git a/mindspore/nn/metrics/error.py b/mindspore/nn/metrics/error.py index 4b517d9d82..5abae9430d 100644 --- a/mindspore/nn/metrics/error.py +++ b/mindspore/nn/metrics/error.py @@ -39,6 +39,8 @@ class MAE(Metric): >>> error.clear() >>> error.update(x, y) >>> result = error.eval() + >>> print(result) + 0.037499990314245224 """ def __init__(self): super(MAE, self).__init__() diff --git a/mindspore/nn/metrics/fbeta.py b/mindspore/nn/metrics/fbeta.py index d5a7087288..f2f3b435dc 100755 --- a/mindspore/nn/metrics/fbeta.py +++ b/mindspore/nn/metrics/fbeta.py @@ -36,10 +36,11 @@ class Fbeta(Metric): >>> x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]])) >>> y = Tensor(np.array([1, 0, 1])) >>> metric = nn.Fbeta(1) + >>> metric.clear() >>> metric.update(x, y) >>> fbeta = metric.eval() >>> print(fbeta) - [0.66666667 0.66666667] + [0.66666667 0.66666667] """ def __init__(self, beta): super(Fbeta, self).__init__() @@ -133,7 +134,7 @@ class F1(Fbeta): >>> metric.update(x, y) >>> result = metric.eval() >>> print(result) - [0.66666667 0.66666667] + [0.66666667 0.66666667] """ def __init__(self): super(F1, self).__init__(1.0) diff --git a/mindspore/nn/metrics/precision.py b/mindspore/nn/metrics/precision.py index 096d936a87..a0a4c727d7 100644 --- a/mindspore/nn/metrics/precision.py +++ b/mindspore/nn/metrics/precision.py @@ -47,6 +47,9 @@ class Precision(EvaluationBase): >>> metric.clear() >>> metric.update(x, y) >>> precision = metric.eval() + >>> print(precision) + [0.5 1. ] + """ def __init__(self, eval_type='classification'): super(Precision, self).__init__(eval_type) diff --git a/mindspore/nn/metrics/recall.py b/mindspore/nn/metrics/recall.py index 32657995b3..f3f8d89b3c 100644 --- a/mindspore/nn/metrics/recall.py +++ b/mindspore/nn/metrics/recall.py @@ -47,6 +47,8 @@ class Recall(EvaluationBase): >>> metric.clear() >>> metric.update(x, y) >>> recall = metric.eval() + >>> print(recall) + [1. 0.5] """ def __init__(self, eval_type='classification'): super(Recall, self).__init__(eval_type) diff --git a/mindspore/nn/wrap/cell_wrapper.py b/mindspore/nn/wrap/cell_wrapper.py index 57d5cc7c6b..0b5c874ca2 100644 --- a/mindspore/nn/wrap/cell_wrapper.py +++ b/mindspore/nn/wrap/cell_wrapper.py @@ -276,7 +276,7 @@ class GetNextSingleOp(Cell): >>> relu = P.ReLU() >>> result = relu(data).asnumpy() >>> print(result.shape) - >>> (32, 1, 32, 32) + (32, 1, 32, 32) """ def __init__(self, dataset_types, dataset_shapes, queue_name): @@ -356,6 +356,7 @@ class WithEvalCell(Cell): Args: network (Cell): The network Cell. loss_fn (Cell): The loss Cell. + add_cast_fp32 (bool): Adjust the data type to float32. Inputs: - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`. @@ -410,7 +411,7 @@ class ParameterUpdate(Cell): >>> param = network.parameters_dict()['weight'] >>> update = nn.ParameterUpdate(param) >>> update.phase = "update_param" - >>> weight = Tensor(np.arrange(12).reshape((4, 3)), mindspore.float32) + >>> weight = Tensor(np.arange(12).reshape((4, 3)), mindspore.float32) >>> network_updata = update(weight) """