|
|
|
@@ -27,7 +27,6 @@ from ..._checkparam import Validator as validator |
|
|
|
|
|
|
|
__all__ = ['ReduceLogSumExp', |
|
|
|
'Range', |
|
|
|
'LinSpace', |
|
|
|
'LGamma', |
|
|
|
'DiGamma', |
|
|
|
'IGamma', |
|
|
|
@@ -157,57 +156,6 @@ class Range(Cell): |
|
|
|
return range_out |
|
|
|
|
|
|
|
|
|
|
|
class LinSpace(Cell): |
|
|
|
r""" |
|
|
|
Generates values in an interval. |
|
|
|
|
|
|
|
Args: |
|
|
|
start (Union[int, float]): The start of interval. With shape of 0-D. |
|
|
|
stop (Union[int, float]): The end of interval. With shape of 0-D. |
|
|
|
num (int): ticks number in the interval, the ticks include start and stop value. With shape of 0-D. |
|
|
|
|
|
|
|
Outputs: |
|
|
|
Tensor, With type same as `start`. The shape is 1-D with length of `num`. |
|
|
|
|
|
|
|
Supported Platforms: |
|
|
|
``Ascend`` |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> linspace = nn.LinSpace(1, 10, 5) |
|
|
|
>>> output = linspace() |
|
|
|
>>> print(output) |
|
|
|
[ 1. 3.25 5.5 7.75 10. ] |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, start, stop, num): |
|
|
|
super(LinSpace, self).__init__() |
|
|
|
validator.check_value_type("start", start, [int, float], self.cls_name) |
|
|
|
validator.check_value_type("stop", stop, [int, float], self.cls_name) |
|
|
|
validator.check_value_type("num", num, [int], self.cls_name) |
|
|
|
validator.check_positive_int(num, "num", self.cls_name) |
|
|
|
|
|
|
|
self.is_single = bool(num == 1) |
|
|
|
self.lin_space = P.LinSpace() |
|
|
|
self.start = Tensor(start, mstype.float32) |
|
|
|
self.stop = Tensor(stop, mstype.float32) |
|
|
|
self.num = num |
|
|
|
self.start_array = Tensor([start], mstype.float32) |
|
|
|
|
|
|
|
def construct(self): |
|
|
|
if self.is_single: |
|
|
|
return self.start_array |
|
|
|
|
|
|
|
lin_space_out = self.lin_space(self.start, self.stop, self.num) |
|
|
|
return lin_space_out |
|
|
|
|
|
|
|
@constexpr |
|
|
|
def check_tensors_dtype_same(data_dtype, value_dtype, op_name): |
|
|
|
"""Check tensors data type same.""" |
|
|
|
if data_dtype in value_dtype: |
|
|
|
return True |
|
|
|
raise TypeError(f"For '{op_name}', the value data type '{value_dtype}' " |
|
|
|
f"is not consistent with assigned tensor data type {data_dtype}.") |
|
|
|
|
|
|
|
class LGamma(Cell): |
|
|
|
r""" |
|
|
|
Calculate LGamma using Lanczos' approximation refering to "A Precision Approximationof the Gamma Function". |
|
|
|
|