|
|
|
@@ -33,6 +33,19 @@ def test_root_mean_square_distance(): |
|
|
|
assert math.isclose(distance, 1.0000000000000002, abs_tol=0.001) |
|
|
|
|
|
|
|
|
|
|
|
def test_root_mean_square_distance_indexes_awareness(): |
|
|
|
"""A indexes aware version of test_root_mean_square_distance""" |
|
|
|
x = Tensor(np.array([[3, 0, 1], [1, 3, 0], [1, 0, 2]])) |
|
|
|
y = Tensor(np.array([[0, 2, 1], [1, 2, 1], [0, 0, 1]])) |
|
|
|
y2 = Tensor(np.array([[0, 0, 1], [0, 2, 1], [2, 0, 1]])) |
|
|
|
metric = get_metric_fn('root_mean_square_distance').set_indexes([0, 2, 3]) |
|
|
|
metric.clear() |
|
|
|
metric.update(x, y, y2, 0) |
|
|
|
distance = metric.eval() |
|
|
|
|
|
|
|
assert math.isclose(distance, 0.6666666666666669, abs_tol=0.001) |
|
|
|
|
|
|
|
|
|
|
|
def test_root_mean_square_distance_update1(): |
|
|
|
x = Tensor(np.array([[0.2, 0.5, 0.7], [0.3, 0.1, 0.2], [0.9, 0.6, 0.5]])) |
|
|
|
metric = RootMeanSquareDistance() |
|
|
|
|