From 3d16a4d111ba8ef2db701ebe6b47d67a13f61062 Mon Sep 17 00:00:00 2001 From: liuxiao93 Date: Mon, 11 Jan 2021 15:18:40 +0800 Subject: [PATCH] Update nn.Range for GPU backend. --- mindspore/nn/layer/math.py | 15 +++++++++++++-- mindspore/ops/operations/array_ops.py | 2 +- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/mindspore/nn/layer/math.py b/mindspore/nn/layer/math.py index 3904a495c8..6dca3575d9 100644 --- a/mindspore/nn/layer/math.py +++ b/mindspore/nn/layer/math.py @@ -15,6 +15,7 @@ """math""" import math import numpy as np +import mindspore.context as context from mindspore.ops import operations as P from mindspore.ops.operations import _inner_ops as inner from mindspore.common.tensor import Tensor @@ -116,7 +117,7 @@ class Range(Cell): Tensor, the dtype is int if the dtype of `start`, `limit` and `delta` all are int. Otherwise, dtype is float. Supported Platforms: - ``Ascend`` ``CPU`` + ``Ascend`` ``GPU`` ``CPU`` Examples: >>> net = nn.Range(1, 8, 2) @@ -127,6 +128,7 @@ class Range(Cell): def __init__(self, start, limit=None, delta=1): super(Range, self).__init__() + self.is_gpu = context.get_context("device_target") == "GPU" validator.check_value_type("start", start, [int, float], self.cls_name) validator.check_value_type("delta", delta, [int, float], self.cls_name) if delta == 0: @@ -155,8 +157,17 @@ class Range(Cell): length_input = math.ceil((limit - start) / delta) self.input_tensor = Tensor(list(range(length_input)), self.dtype) + if self.is_gpu: + self.start = Tensor(start, self.dtype) + self.limit = Tensor(limit, self.dtype) + self.delta = Tensor(delta, self.dtype) + self.range_gpu = P.Range(length_input) + def construct(self): - range_out = self.range_x(self.input_tensor) + if self.is_gpu: + range_out = self.range_gpu(self.start, self.limit, self.delta) + else: + range_out = self.range_x(self.input_tensor) return range_out diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index da97966a65..9949878dbb 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -4761,7 +4761,7 @@ class Range(PrimitiveWithCheck): [0, 4, 8] Supported Platforms: - ``Ascend`` ``GPU`` + ``GPU`` """ @prim_attr_register