Browse Source

Add input check for axes which is float type or out of bound

pull/14227/head
w00535372 4 years ago
parent
commit
3fc2c16fea
1 changed files with 10 additions and 8 deletions
  1. +10
    -8
      mindspore/ops/composite/math_ops.py

+ 10
- 8
mindspore/ops/composite/math_ops.py View File

@@ -365,14 +365,16 @@ def _check_axes_for_batch_dot(x1_shape, x2_shape, axes):
raise ValueError("Require two axes inputs, given less")
if isinstance(axes, tuple):
axes = list(axes)
for sub_axes in axes:
if isinstance(sub_axes, (list, tuple)):
raise ValueError("Require dimension to be in any of those: None, int, (int, int).")
validator.check_value_type('axes[0]', axes[0], [int], 'batch_dot')
validator.check_value_type('axes[1]', axes[1], [int], 'batch_dot')
# Reverse if axis < 0
if axes[0] < 0:
axes[0] += len(x1_shape)
if axes[1] < 0:
axes[1] += len(x2_shape)
if axes[0] > len(x1_shape) or axes[1] > len(x2_shape):
raise ValueError(
"Axes value too high for given input arrays dimensions.")
elif isinstance(axes, int):
if axes == 0:
raise ValueError("Batch dim cannot be used as in axes.")
@@ -429,6 +431,9 @@ def batch_dot(x1, x2, axes=None):
"""
Computation of batch dot product between samples in two tensors containing batch dims.

.. math::
output = x1[batch, :] * x2[batch, :]

Inputs:
- **x1** (Tensor) - First tensor in Batch Dot op with datatype float32
- **x2** (Tensor) - Second tensor in Batch Dot op with datatype float32. x2's datatype should
@@ -439,13 +444,10 @@ def batch_dot(x1, x2, axes=None):

Outputs:
Tensor, batch dot product of x1 and x2. The Shape of output for input shapes (batch, d1, axes, d2) and
(batch, d3, axes, d4) is (batch, d1, d2, d3, d4)

.. math::
output = x1[batch, :] * x2[batch, :]
(batch, d3, axes, d4) is (batch, d1, d2, d3, d4)

Raises:
TypeError: If shapes of x1 and x2 are not the same.
TypeError: If type of x1 and x2 are not the same.
ValueError: If rank of x1 or x2 less than 2.
ValueError: If batch dim used in axes.
ValueError: If dtype of x1 or x2 is not float32.


Loading…
Cancel
Save