Browse Source

Update nn.Range for GPU backend.

tags/v1.2.0-rc1
liuxiao93 4 years ago
parent
commit
3d16a4d111
2 changed files with 14 additions and 3 deletions
  1. +13
    -2
      mindspore/nn/layer/math.py
  2. +1
    -1
      mindspore/ops/operations/array_ops.py

+ 13
- 2
mindspore/nn/layer/math.py View File

@@ -15,6 +15,7 @@
"""math"""
import math
import numpy as np
import mindspore.context as context
from mindspore.ops import operations as P
from mindspore.ops.operations import _inner_ops as inner
from mindspore.common.tensor import Tensor
@@ -116,7 +117,7 @@ class Range(Cell):
Tensor, the dtype is int if the dtype of `start`, `limit` and `delta` all are int. Otherwise, dtype is float.

Supported Platforms:
``Ascend`` ``CPU``
``Ascend`` ``GPU`` ``CPU``

Examples:
>>> net = nn.Range(1, 8, 2)
@@ -127,6 +128,7 @@ class Range(Cell):

def __init__(self, start, limit=None, delta=1):
super(Range, self).__init__()
self.is_gpu = context.get_context("device_target") == "GPU"
validator.check_value_type("start", start, [int, float], self.cls_name)
validator.check_value_type("delta", delta, [int, float], self.cls_name)
if delta == 0:
@@ -155,8 +157,17 @@ class Range(Cell):
length_input = math.ceil((limit - start) / delta)
self.input_tensor = Tensor(list(range(length_input)), self.dtype)

if self.is_gpu:
self.start = Tensor(start, self.dtype)
self.limit = Tensor(limit, self.dtype)
self.delta = Tensor(delta, self.dtype)
self.range_gpu = P.Range(length_input)

def construct(self):
range_out = self.range_x(self.input_tensor)
if self.is_gpu:
range_out = self.range_gpu(self.start, self.limit, self.delta)
else:
range_out = self.range_x(self.input_tensor)
return range_out




+ 1
- 1
mindspore/ops/operations/array_ops.py View File

@@ -4761,7 +4761,7 @@ class Range(PrimitiveWithCheck):
[0, 4, 8]

Supported Platforms:
``Ascend`` ``GPU``
``GPU``
"""

@prim_attr_register


Loading…
Cancel
Save