| @@ -27,7 +27,6 @@ from ..._checkparam import Validator as validator | |||||
| __all__ = ['ReduceLogSumExp', | __all__ = ['ReduceLogSumExp', | ||||
| 'Range', | 'Range', | ||||
| 'LinSpace', | |||||
| 'LGamma', | 'LGamma', | ||||
| 'DiGamma', | 'DiGamma', | ||||
| 'IGamma', | 'IGamma', | ||||
| @@ -157,57 +156,6 @@ class Range(Cell): | |||||
| return range_out | return range_out | ||||
| class LinSpace(Cell): | |||||
| r""" | |||||
| Generates values in an interval. | |||||
| Args: | |||||
| start (Union[int, float]): The start of interval. With shape of 0-D. | |||||
| stop (Union[int, float]): The end of interval. With shape of 0-D. | |||||
| num (int): ticks number in the interval, the ticks include start and stop value. With shape of 0-D. | |||||
| Outputs: | |||||
| Tensor, With type same as `start`. The shape is 1-D with length of `num`. | |||||
| Supported Platforms: | |||||
| ``Ascend`` | |||||
| Examples: | |||||
| >>> linspace = nn.LinSpace(1, 10, 5) | |||||
| >>> output = linspace() | |||||
| >>> print(output) | |||||
| [ 1. 3.25 5.5 7.75 10. ] | |||||
| """ | |||||
| def __init__(self, start, stop, num): | |||||
| super(LinSpace, self).__init__() | |||||
| validator.check_value_type("start", start, [int, float], self.cls_name) | |||||
| validator.check_value_type("stop", stop, [int, float], self.cls_name) | |||||
| validator.check_value_type("num", num, [int], self.cls_name) | |||||
| validator.check_positive_int(num, "num", self.cls_name) | |||||
| self.is_single = bool(num == 1) | |||||
| self.lin_space = P.LinSpace() | |||||
| self.start = Tensor(start, mstype.float32) | |||||
| self.stop = Tensor(stop, mstype.float32) | |||||
| self.num = num | |||||
| self.start_array = Tensor([start], mstype.float32) | |||||
| def construct(self): | |||||
| if self.is_single: | |||||
| return self.start_array | |||||
| lin_space_out = self.lin_space(self.start, self.stop, self.num) | |||||
| return lin_space_out | |||||
| @constexpr | |||||
| def check_tensors_dtype_same(data_dtype, value_dtype, op_name): | |||||
| """Check tensors data type same.""" | |||||
| if data_dtype in value_dtype: | |||||
| return True | |||||
| raise TypeError(f"For '{op_name}', the value data type '{value_dtype}' " | |||||
| f"is not consistent with assigned tensor data type {data_dtype}.") | |||||
| class LGamma(Cell): | class LGamma(Cell): | ||||
| r""" | r""" | ||||
| Calculate LGamma using Lanczos' approximation refering to "A Precision Approximationof the Gamma Function". | Calculate LGamma using Lanczos' approximation refering to "A Precision Approximationof the Gamma Function". | ||||