|
|
|
@@ -13,6 +13,7 @@ |
|
|
|
# limitations under the License. |
|
|
|
# ============================================================================ |
|
|
|
"""Faithfulness.""" |
|
|
|
from decimal import Decimal |
|
|
|
from typing import Callable, Optional, Union |
|
|
|
|
|
|
|
import numpy as np |
|
|
|
@@ -147,8 +148,8 @@ class NaiveFaithfulness(_FaithfulnessHelper): |
|
|
|
- faithfulness (np.ndarray): faithfulness score |
|
|
|
|
|
|
|
""" |
|
|
|
if not np.count_nonzero(saliency): |
|
|
|
log.warning("The saliency map is zero everywhere. The correlation will be set to zero.") |
|
|
|
if Decimal(str(saliency.max())) == Decimal(str(saliency.min())): |
|
|
|
log.warning("The saliency map is uniform everywhere. The correlation will be set to zero.") |
|
|
|
correlation = 0 |
|
|
|
return np.array([correlation], np.float) |
|
|
|
|
|
|
|
@@ -163,6 +164,11 @@ class NaiveFaithfulness(_FaithfulnessHelper): |
|
|
|
predictions = model(perturbations)[:, targets].asnumpy() |
|
|
|
predictions = predictions.reshape(*feature_importance.shape) |
|
|
|
|
|
|
|
if Decimal(str(predictions.max())) == Decimal(str(predictions.min())): |
|
|
|
log.warning("The perturbations do not affect the predictions. The correlation will be set to zero.") |
|
|
|
correlation = 0 |
|
|
|
return np.array([correlation], np.float) |
|
|
|
|
|
|
|
faithfulness = -np.corrcoef(feature_importance, predictions) |
|
|
|
faithfulness = np.diag(faithfulness[:batch_size, batch_size:]) |
|
|
|
return faithfulness |
|
|
|
|