Merge pull request !2177 from jiangjinsheng/vm_lin_spacetags/v0.5.0-beta
| @@ -112,6 +112,7 @@ static std::map<string, string> tbe_func_adapter_map = { | |||
| {"square_sum_all", "square_sum_all"}, | |||
| {"cum_sum", "cumsum_d"}, | |||
| {"range", "range_d"}, | |||
| {"lin_space", "lin_space_d"}, | |||
| {"inv_grad", "inv_grad"}, | |||
| {"apply_rms_prop", "apply_rms_prop_d"}, | |||
| {"cum_prod", "cumprod_d"}, | |||
| @@ -20,8 +20,11 @@ from mindspore.common.tensor import Tensor | |||
| from ..cell import Cell | |||
| from ...common import dtype as mstype | |||
| from ..._checkparam import Validator as validator | |||
| from ..._checkparam import Rel | |||
| __all__ = ['ReduceLogSumExp', 'Range', 'LinSpace'] | |||
| __all__ = ['ReduceLogSumExp', 'Range'] | |||
| class ReduceLogSumExp(Cell): | |||
| r""" | |||
| @@ -125,3 +128,48 @@ class Range(Cell): | |||
| def construct(self): | |||
| range_out = self.range_x(self.input_tensor) | |||
| return range_out | |||
| class LinSpace(Cell): | |||
| r""" | |||
| Generates values in an interval. And return the corresponding interpolation accroding to assist. | |||
| Args: | |||
| - **start** (Union[int, float]) - The start of interval, With shape of 0-D. | |||
| - **stop** (Union[int, float]) - The end of interval, With shape of 0-D. | |||
| - **num** (int) - ticks number in the interval, the ticks include start and stop value. | |||
| With shape of 0-D. | |||
| Outputs: | |||
| Tensor, With type same as `start`. The shape is 1-D with length of `num`. | |||
| Examples: | |||
| >>> linspace = nn.LinSpace() | |||
| >>> start = Tensor(1, mindspore.float32) | |||
| >>> stop = Tensor(10, mindspore.float32) | |||
| >>> num = Tensor(5, mindspore.int32) | |||
| >>> output = linspace(start, stop, num) | |||
| [1, 3.25, 5.5, 7.75, 10] | |||
| """ | |||
| def __init__(self, start, stop, num): | |||
| super(LinSpace, self).__init__() | |||
| validator.check_value_type("start", start, [int, float], self.cls_name) | |||
| validator.check_value_type("stop", stop, [int, float], self.cls_name) | |||
| validator.check_value_type("num", num, [int], self.cls_name) | |||
| validator.check_integer("num", num, 0, Rel.GT, self.cls_name) | |||
| self.is_single = bool(num == 1) | |||
| self.lin_space = inner.LinSpace() | |||
| self.start = Tensor(start, mstype.float32) | |||
| self.stop = Tensor(stop, mstype.float32) | |||
| self.assist = Tensor(list(range(num)), mstype.float32) | |||
| self.num = Tensor(num, mstype.int32) | |||
| self.start_array = Tensor([start], mstype.float32) | |||
| def construct(self): | |||
| if self.is_single: | |||
| return self.start_array | |||
| lin_space_out = self.lin_space(self.assist, self.start, self.stop, self.num) | |||
| return lin_space_out | |||
| @@ -21,6 +21,7 @@ from mindspore.ops import _selected_grad_ops as SG | |||
| from .. import functional as F | |||
| from .. import operations as P | |||
| from ..operations import _grad_ops as G | |||
| from ..operations import _inner_ops as inner | |||
| from ..composite.multitype_ops.zeros_like_impl import zeros_like | |||
| from ..functional import broadcast_gradient_args, reduced_shape, tuple_div | |||
| from .grad_base import bprop_getters | |||
| @@ -1049,3 +1050,13 @@ def get_bprop_inv(self): | |||
| dx = inv_grad(out, dout) | |||
| return (dx,) | |||
| return bprop | |||
| @bprop_getters.register(inner.LinSpace) | |||
| def get_bprop_lin_space(self): | |||
| """Grad definition for `LinSpace` operation.""" | |||
| def bprop(assist, start, stop, num, out, dout): | |||
| return zeros_like(assist), zeros_like(start), zeros_like(stop), zeros_like(num) | |||
| return bprop | |||
| @@ -262,3 +262,4 @@ from .tensor_scatter_update import _tensor_scatter_update_tbe | |||
| from .inplace_update import _inplace_update_tbe | |||
| from .splitv import _split_v_tbe | |||
| from .in_top_k import _in_top_k_tbe | |||
| from .lin_space import _lin_space_tbe | |||
| @@ -0,0 +1,40 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """LinSpace op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| lin_space_op_info = TBERegOp("LinSpace") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("lin_space.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("lin_space") \ | |||
| .partial_flag(True) \ | |||
| .op_pattern("broadcast") \ | |||
| .input(0, "assist", False, "required", "all") \ | |||
| .input(1, "start", False, "required", "all") \ | |||
| .input(2, "stop", False, "required", "all") \ | |||
| .input(3, "num", False, "required", "all") \ | |||
| .output(0, "output", False, "required", "all") \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.I32_Default, | |||
| DataType.F32_Default,) \ | |||
| .get_op_info() | |||
| @op_info_register(lin_space_op_info) | |||
| def _lin_space_tbe(): | |||
| """LinSpace TBE register""" | |||
| return | |||
| @@ -328,3 +328,42 @@ class EmbeddingLookup(PrimitiveWithInfer): | |||
| 'dtype': params['dtype'], | |||
| 'value': None} | |||
| return out | |||
| class LinSpace(PrimitiveWithInfer): | |||
| r""" | |||
| Generates values in an interval. And return the corresponding interpolation accroding to assist. | |||
| Inputs: | |||
| - **assist** (Tensor[float32]) - The assist value, With shape of 0-D or 1-D. | |||
| - **start** (Tensor[float32]) - The start of interval, With shape of 0-D. | |||
| - **stop** (Tensor[float32]) - The end of interval, With shape of 0-D. | |||
| - **num** (Tensor[int32]) - ticks number in the interval, the ticks include start and stop value. | |||
| With shape of 0-D. | |||
| Outputs: | |||
| Tensor, has the same shape as `assist`. | |||
| Examples: | |||
| >>> linspace = P.LinSpace() | |||
| >>> assist = Tensor([5, 5.5], mindspore.float32) | |||
| >>> start = Tensor(1, mindspore.float32) | |||
| >>> stop = Tensor(10, mindspore.float32) | |||
| >>> num = Tensor(5, mindspore.int32) | |||
| >>> output = linspace(assist, start, stop, num) | |||
| [12.25, 13.375] | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| pass | |||
| def infer_shape(self, assist, start, stop, num): | |||
| return assist | |||
| def infer_dtype(self, assist, start, stop, num): | |||
| args = {"num": num} | |||
| validator.check_tensor_type_same(args, (mstype.int32,), self.name) | |||
| args = {"assist": assist, "start": start, "stop": stop} | |||
| validator.check_tensor_type_same(args, (mstype.float32,), self.name) | |||
| return assist | |||
| @@ -1599,6 +1599,14 @@ test_case_array_ops = [ | |||
| 'desc_inputs': [Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(np.float32)), | |||
| Tensor(np.array([1, 2, 3]).astype(np.int32))], | |||
| 'desc_bprop': [[3, 3]]}), | |||
| ('LinSpace', { | |||
| 'block': inner.LinSpace(), | |||
| 'desc_inputs': [Tensor([5, 5.5], mstype.float32), | |||
| Tensor(1, mstype.float32), | |||
| Tensor(10, mstype.float32), | |||
| Tensor(5, mstype.int32)], | |||
| 'skip': ['backward'], | |||
| }), | |||
| ] | |||
| test_case_other_ops = [ | |||