Browse Source

!10040 nn_notes

From: @bai-yangfan
Reviewed-by: @kingxian,@c_34
Signed-off-by: @kingxian
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
51ed499f3d
5 changed files with 13 additions and 4 deletions
  1. +2
    -0
      mindspore/nn/metrics/error.py
  2. +3
    -2
      mindspore/nn/metrics/fbeta.py
  3. +3
    -0
      mindspore/nn/metrics/precision.py
  4. +2
    -0
      mindspore/nn/metrics/recall.py
  5. +3
    -2
      mindspore/nn/wrap/cell_wrapper.py

+ 2
- 0
mindspore/nn/metrics/error.py View File

@@ -39,6 +39,8 @@ class MAE(Metric):
>>> error.clear() >>> error.clear()
>>> error.update(x, y) >>> error.update(x, y)
>>> result = error.eval() >>> result = error.eval()
>>> print(result)
0.037499990314245224
""" """
def __init__(self): def __init__(self):
super(MAE, self).__init__() super(MAE, self).__init__()


+ 3
- 2
mindspore/nn/metrics/fbeta.py View File

@@ -36,10 +36,11 @@ class Fbeta(Metric):
>>> x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]])) >>> x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
>>> y = Tensor(np.array([1, 0, 1])) >>> y = Tensor(np.array([1, 0, 1]))
>>> metric = nn.Fbeta(1) >>> metric = nn.Fbeta(1)
>>> metric.clear()
>>> metric.update(x, y) >>> metric.update(x, y)
>>> fbeta = metric.eval() >>> fbeta = metric.eval()
>>> print(fbeta) >>> print(fbeta)
[0.66666667 0.66666667]
[0.66666667 0.66666667]
""" """
def __init__(self, beta): def __init__(self, beta):
super(Fbeta, self).__init__() super(Fbeta, self).__init__()
@@ -133,7 +134,7 @@ class F1(Fbeta):
>>> metric.update(x, y) >>> metric.update(x, y)
>>> result = metric.eval() >>> result = metric.eval()
>>> print(result) >>> print(result)
[0.66666667 0.66666667]
[0.66666667 0.66666667]
""" """
def __init__(self): def __init__(self):
super(F1, self).__init__(1.0) super(F1, self).__init__(1.0)

+ 3
- 0
mindspore/nn/metrics/precision.py View File

@@ -47,6 +47,9 @@ class Precision(EvaluationBase):
>>> metric.clear() >>> metric.clear()
>>> metric.update(x, y) >>> metric.update(x, y)
>>> precision = metric.eval() >>> precision = metric.eval()
>>> print(precision)
[0.5 1. ]

""" """
def __init__(self, eval_type='classification'): def __init__(self, eval_type='classification'):
super(Precision, self).__init__(eval_type) super(Precision, self).__init__(eval_type)


+ 2
- 0
mindspore/nn/metrics/recall.py View File

@@ -47,6 +47,8 @@ class Recall(EvaluationBase):
>>> metric.clear() >>> metric.clear()
>>> metric.update(x, y) >>> metric.update(x, y)
>>> recall = metric.eval() >>> recall = metric.eval()
>>> print(recall)
[1. 0.5]
""" """
def __init__(self, eval_type='classification'): def __init__(self, eval_type='classification'):
super(Recall, self).__init__(eval_type) super(Recall, self).__init__(eval_type)


+ 3
- 2
mindspore/nn/wrap/cell_wrapper.py View File

@@ -276,7 +276,7 @@ class GetNextSingleOp(Cell):
>>> relu = P.ReLU() >>> relu = P.ReLU()
>>> result = relu(data).asnumpy() >>> result = relu(data).asnumpy()
>>> print(result.shape) >>> print(result.shape)
>>> (32, 1, 32, 32)
(32, 1, 32, 32)
""" """


def __init__(self, dataset_types, dataset_shapes, queue_name): def __init__(self, dataset_types, dataset_shapes, queue_name):
@@ -356,6 +356,7 @@ class WithEvalCell(Cell):
Args: Args:
network (Cell): The network Cell. network (Cell): The network Cell.
loss_fn (Cell): The loss Cell. loss_fn (Cell): The loss Cell.
add_cast_fp32 (bool): Adjust the data type to float32.


Inputs: Inputs:
- **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`. - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
@@ -410,7 +411,7 @@ class ParameterUpdate(Cell):
>>> param = network.parameters_dict()['weight'] >>> param = network.parameters_dict()['weight']
>>> update = nn.ParameterUpdate(param) >>> update = nn.ParameterUpdate(param)
>>> update.phase = "update_param" >>> update.phase = "update_param"
>>> weight = Tensor(np.arrange(12).reshape((4, 3)), mindspore.float32)
>>> weight = Tensor(np.arange(12).reshape((4, 3)), mindspore.float32)
>>> network_updata = update(weight) >>> network_updata = update(weight)
""" """




Loading…
Cancel
Save