From: @bai-yangfan Reviewed-by: @kingxian,@c_34 Signed-off-by: @kingxiantags/v1.1.0
| @@ -39,6 +39,8 @@ class MAE(Metric): | |||||
| >>> error.clear() | >>> error.clear() | ||||
| >>> error.update(x, y) | >>> error.update(x, y) | ||||
| >>> result = error.eval() | >>> result = error.eval() | ||||
| >>> print(result) | |||||
| 0.037499990314245224 | |||||
| """ | """ | ||||
| def __init__(self): | def __init__(self): | ||||
| super(MAE, self).__init__() | super(MAE, self).__init__() | ||||
| @@ -36,10 +36,11 @@ class Fbeta(Metric): | |||||
| >>> x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]])) | >>> x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]])) | ||||
| >>> y = Tensor(np.array([1, 0, 1])) | >>> y = Tensor(np.array([1, 0, 1])) | ||||
| >>> metric = nn.Fbeta(1) | >>> metric = nn.Fbeta(1) | ||||
| >>> metric.clear() | |||||
| >>> metric.update(x, y) | >>> metric.update(x, y) | ||||
| >>> fbeta = metric.eval() | >>> fbeta = metric.eval() | ||||
| >>> print(fbeta) | >>> print(fbeta) | ||||
| [0.66666667 0.66666667] | |||||
| [0.66666667 0.66666667] | |||||
| """ | """ | ||||
| def __init__(self, beta): | def __init__(self, beta): | ||||
| super(Fbeta, self).__init__() | super(Fbeta, self).__init__() | ||||
| @@ -133,7 +134,7 @@ class F1(Fbeta): | |||||
| >>> metric.update(x, y) | >>> metric.update(x, y) | ||||
| >>> result = metric.eval() | >>> result = metric.eval() | ||||
| >>> print(result) | >>> print(result) | ||||
| [0.66666667 0.66666667] | |||||
| [0.66666667 0.66666667] | |||||
| """ | """ | ||||
| def __init__(self): | def __init__(self): | ||||
| super(F1, self).__init__(1.0) | super(F1, self).__init__(1.0) | ||||
| @@ -47,6 +47,9 @@ class Precision(EvaluationBase): | |||||
| >>> metric.clear() | >>> metric.clear() | ||||
| >>> metric.update(x, y) | >>> metric.update(x, y) | ||||
| >>> precision = metric.eval() | >>> precision = metric.eval() | ||||
| >>> print(precision) | |||||
| [0.5 1. ] | |||||
| """ | """ | ||||
| def __init__(self, eval_type='classification'): | def __init__(self, eval_type='classification'): | ||||
| super(Precision, self).__init__(eval_type) | super(Precision, self).__init__(eval_type) | ||||
| @@ -47,6 +47,8 @@ class Recall(EvaluationBase): | |||||
| >>> metric.clear() | >>> metric.clear() | ||||
| >>> metric.update(x, y) | >>> metric.update(x, y) | ||||
| >>> recall = metric.eval() | >>> recall = metric.eval() | ||||
| >>> print(recall) | |||||
| [1. 0.5] | |||||
| """ | """ | ||||
| def __init__(self, eval_type='classification'): | def __init__(self, eval_type='classification'): | ||||
| super(Recall, self).__init__(eval_type) | super(Recall, self).__init__(eval_type) | ||||
| @@ -276,7 +276,7 @@ class GetNextSingleOp(Cell): | |||||
| >>> relu = P.ReLU() | >>> relu = P.ReLU() | ||||
| >>> result = relu(data).asnumpy() | >>> result = relu(data).asnumpy() | ||||
| >>> print(result.shape) | >>> print(result.shape) | ||||
| >>> (32, 1, 32, 32) | |||||
| (32, 1, 32, 32) | |||||
| """ | """ | ||||
| def __init__(self, dataset_types, dataset_shapes, queue_name): | def __init__(self, dataset_types, dataset_shapes, queue_name): | ||||
| @@ -356,6 +356,7 @@ class WithEvalCell(Cell): | |||||
| Args: | Args: | ||||
| network (Cell): The network Cell. | network (Cell): The network Cell. | ||||
| loss_fn (Cell): The loss Cell. | loss_fn (Cell): The loss Cell. | ||||
| add_cast_fp32 (bool): Adjust the data type to float32. | |||||
| Inputs: | Inputs: | ||||
| - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`. | - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`. | ||||
| @@ -410,7 +411,7 @@ class ParameterUpdate(Cell): | |||||
| >>> param = network.parameters_dict()['weight'] | >>> param = network.parameters_dict()['weight'] | ||||
| >>> update = nn.ParameterUpdate(param) | >>> update = nn.ParameterUpdate(param) | ||||
| >>> update.phase = "update_param" | >>> update.phase = "update_param" | ||||
| >>> weight = Tensor(np.arrange(12).reshape((4, 3)), mindspore.float32) | |||||
| >>> weight = Tensor(np.arange(12).reshape((4, 3)), mindspore.float32) | |||||
| >>> network_updata = update(weight) | >>> network_updata = update(weight) | ||||
| """ | """ | ||||