|
|
|
@@ -315,7 +315,7 @@ class ScalarGradChecker(_GradChecker): |
|
|
|
output_selector=None, |
|
|
|
sampling_times=-1, |
|
|
|
reduce_output=False) -> None: |
|
|
|
grad_op = GradOperation('grad', get_all=True, sens_param=True) |
|
|
|
grad_op = GradOperation(get_all=True, sens_param=True) |
|
|
|
super(ScalarGradChecker, self).__init__(fn, grad_op, args, delta, max_error, input_selector, \ |
|
|
|
output_selector, sampling_times, reduce_output) |
|
|
|
|
|
|
|
@@ -358,7 +358,7 @@ class OperationGradChecker(_GradChecker): |
|
|
|
output_selector=None, |
|
|
|
sampling_times=-1, |
|
|
|
reduce_output=False) -> None: |
|
|
|
grad_op = GradOperation('grad', get_all=True, sens_param=True) |
|
|
|
grad_op = GradOperation(get_all=True, sens_param=True) |
|
|
|
super(OperationGradChecker, self).__init__(fn, grad_op, args, delta, max_error, input_selector, \ |
|
|
|
output_selector, sampling_times, reduce_output) |
|
|
|
|
|
|
|
@@ -390,7 +390,7 @@ class NNGradChecker(_GradChecker): |
|
|
|
output_selector=None, |
|
|
|
sampling_times=-1, |
|
|
|
reduce_output=False) -> None: |
|
|
|
grad_op = GradOperation('grad', get_by_list=True, sens_param=True) |
|
|
|
grad_op = GradOperation(get_by_list=True, sens_param=True) |
|
|
|
self.params = ParameterTuple(fn.trainable_params()) |
|
|
|
super(NNGradChecker, self).__init__(fn, grad_op, args, delta, max_error, input_selector, \ |
|
|
|
output_selector, sampling_times, reduce_output) |
|
|
|
|