Browse Source

!11230 Modify nn.Range for GPU.

From: @liu_xiao_93
Reviewed-by: @liangchenghui,@wuxuejian
Signed-off-by: @liangchenghui
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 5 years ago
parent
commit
d71c97dc2f
1 changed files with 1 additions and 12 deletions
  1. +1
    -12
      mindspore/nn/layer/math.py

+ 1
- 12
mindspore/nn/layer/math.py View File

@@ -15,7 +15,6 @@
"""math"""
import math
import numpy as np
import mindspore.context as context
from mindspore.ops import operations as P
from mindspore.ops.operations import _inner_ops as inner
from mindspore.common.tensor import Tensor
@@ -128,7 +127,6 @@ class Range(Cell):

def __init__(self, start, limit=None, delta=1):
super(Range, self).__init__()
self.is_gpu = context.get_context("device_target") == "GPU"
validator.check_value_type("start", start, [int, float], self.cls_name)
validator.check_value_type("delta", delta, [int, float], self.cls_name)
if delta == 0:
@@ -157,17 +155,8 @@ class Range(Cell):
length_input = math.ceil((limit - start) / delta)
self.input_tensor = Tensor(list(range(length_input)), self.dtype)

if self.is_gpu:
self.start = Tensor(start, self.dtype)
self.limit = Tensor(limit, self.dtype)
self.delta = Tensor(delta, self.dtype)
self.range_gpu = P.Range(length_input)

def construct(self):
if self.is_gpu:
range_out = self.range_gpu(self.start, self.limit, self.delta)
else:
range_out = self.range_x(self.input_tensor)
range_out = self.range_x(self.input_tensor)
return range_out




Loading…
Cancel
Save