Browse Source

!12530 add meshgrid support for start <= stop & complex step with real component

From: @jachua
Reviewed-by: @guoqi1024,@liangchenghui
Signed-off-by: @liangchenghui
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
842ca43df3
1 changed files with 19 additions and 1 deletions
  1. +19
    -1
      mindspore/numpy/array_creations.py

+ 19
- 1
mindspore/numpy/array_creations.py View File

@@ -1291,6 +1291,12 @@ def meshgrid(*xi, sparse=False, indexing='xy'):
if indexing not in ('xy', 'ij'): if indexing not in ('xy', 'ij'):
_raise_type_error("Valid values for `indexing` are 'xy' and 'ij'.") _raise_type_error("Valid values for `indexing` are 'xy' and 'ij'.")


shape_out = ()
for x in xi:
shape_out += (x.size,)
if _is_shape_empty(shape_out):
return ones(shape_out)

grids = [] grids = []
for x in xi: for x in xi:
if F.rank(x) == 1: if F.rank(x) == 1:
@@ -1351,7 +1357,7 @@ class nd_grid:
else: else:
step = 1 step = 1
if isinstance(step, complex): if isinstance(step, complex):
v = linspace(k.start, k.stop, int(abs(step.imag)))
v = linspace(k.start, k.stop, int(abs(step)))
else: else:
v = arange(k.start, k.stop, step) v = arange(k.start, k.stop, step)
xi.append(v) xi.append(v)
@@ -1362,6 +1368,8 @@ class nd_grid:
if self.sparse: if self.sparse:
return grids return grids


if isinstance(grids, Tensor_):
return grids
expanded = [] expanded = []
for grid in grids: for grid in grids:
expanded.append(F.expand_dims(grid, 0)) expanded.append(F.expand_dims(grid, 0))
@@ -1380,6 +1388,11 @@ class mGridClass(nd_grid):
as specifying the number of points to create between the start and as specifying the number of points to create between the start and
stop values, where the stop value is inclusive. stop values, where the stop value is inclusive.


Note:
Unlike Numpy, if the step length is a complex number with a real
component, the step length is handled as equivalent to
``int(abs(step))``.

Returns: Returns:
Tensor or tuple of tensor, a meshgrid. Tensor or tuple of tensor, a meshgrid.


@@ -1422,6 +1435,11 @@ class oGridClass(nd_grid):
as specifying the number of points to create between the start and as specifying the number of points to create between the start and
stop values, where the stop value is inclusive. stop values, where the stop value is inclusive.


Note:
Unlike Numpy, if the step length is a complex number with a real
component, the step length is handled as equivalent to
``int(abs(step))``.

Raises: Raises:
TypeError: if slicing indices are not integers. TypeError: if slicing indices are not integers.




Loading…
Cancel
Save