Browse Source

fix meshgrid with complex step & stop <= start

tags/v1.2.0-rc1
huangmengxi 4 years ago
parent
commit
e17dffbd76
1 changed files with 19 additions and 1 deletions
  1. +19
    -1
      mindspore/numpy/array_creations.py

+ 19
- 1
mindspore/numpy/array_creations.py View File

@@ -1291,6 +1291,12 @@ def meshgrid(*xi, sparse=False, indexing='xy'):
if indexing not in ('xy', 'ij'): if indexing not in ('xy', 'ij'):
_raise_type_error("Valid values for `indexing` are 'xy' and 'ij'.") _raise_type_error("Valid values for `indexing` are 'xy' and 'ij'.")


shape_out = ()
for x in xi:
shape_out += (x.size,)
if _is_shape_empty(shape_out):
return ones(shape_out)

grids = [] grids = []
for x in xi: for x in xi:
if F.rank(x) == 1: if F.rank(x) == 1:
@@ -1351,7 +1357,7 @@ class nd_grid:
else: else:
step = 1 step = 1
if isinstance(step, complex): if isinstance(step, complex):
v = linspace(k.start, k.stop, int(abs(step.imag)))
v = linspace(k.start, k.stop, int(abs(step)))
else: else:
v = arange(k.start, k.stop, step) v = arange(k.start, k.stop, step)
xi.append(v) xi.append(v)
@@ -1362,6 +1368,8 @@ class nd_grid:
if self.sparse: if self.sparse:
return grids return grids


if isinstance(grids, Tensor_):
return grids
expanded = [] expanded = []
for grid in grids: for grid in grids:
expanded.append(F.expand_dims(grid, 0)) expanded.append(F.expand_dims(grid, 0))
@@ -1380,6 +1388,11 @@ class mGridClass(nd_grid):
as specifying the number of points to create between the start and as specifying the number of points to create between the start and
stop values, where the stop value is inclusive. stop values, where the stop value is inclusive.


Note:
Unlike Numpy, if the step length is a complex number with a real
component, the step length is handled as equivalent to
``int(abs(step))``.

Returns: Returns:
Tensor or tuple of tensor, a meshgrid. Tensor or tuple of tensor, a meshgrid.


@@ -1422,6 +1435,11 @@ class oGridClass(nd_grid):
as specifying the number of points to create between the start and as specifying the number of points to create between the start and
stop values, where the stop value is inclusive. stop values, where the stop value is inclusive.


Note:
Unlike Numpy, if the step length is a complex number with a real
component, the step length is handled as equivalent to
``int(abs(step))``.

Raises: Raises:
TypeError: if slicing indices are not integers. TypeError: if slicing indices are not integers.




Loading…
Cancel
Save