Merge pull request !2177 from jiangjinsheng/vm_lin_space (tags/v0.5.0-beta)
@@ -112,6 +112,7 @@ static std::map<string, string> tbe_func_adapter_map = {
   {"square_sum_all", "square_sum_all"},
   {"cum_sum", "cumsum_d"},
   {"range", "range_d"},
+  {"lin_space", "lin_space_d"},
   {"inv_grad", "inv_grad"},
   {"apply_rms_prop", "apply_rms_prop_d"},
   {"cum_prod", "cumprod_d"},
@@ -20,8 +20,11 @@ from mindspore.common.tensor import Tensor
 from ..cell import Cell
 from ...common import dtype as mstype
 from ..._checkparam import Validator as validator
+from ..._checkparam import Rel
 
-__all__ = ['ReduceLogSumExp', 'Range']
+__all__ = ['ReduceLogSumExp', 'Range', 'LinSpace']
 
 
 class ReduceLogSumExp(Cell):
     r"""
@@ -125,3 +128,48 @@ class Range(Cell):
     def construct(self):
         range_out = self.range_x(self.input_tensor)
         return range_out
+
+
+class LinSpace(Cell):
+    r"""
+    Generates values in an interval and returns the corresponding interpolation
+    according to the assist tensor.
+
+    Args:
+        start (Union[int, float]): The start of the interval, with a shape of 0-D.
+        stop (Union[int, float]): The end of the interval, with a shape of 0-D.
+        num (int): Number of ticks in the interval; the ticks include the start and
+            stop values. With a shape of 0-D.
+
+    Outputs:
+        Tensor, with the same dtype as `start`. The shape is 1-D with length `num`.
+
+    Examples:
+        >>> linspace = nn.LinSpace(1, 10, 5)
+        >>> output = linspace()
+        [1, 3.25, 5.5, 7.75, 10]
+    """
+
+    def __init__(self, start, stop, num):
+        super(LinSpace, self).__init__()
+        validator.check_value_type("start", start, [int, float], self.cls_name)
+        validator.check_value_type("stop", stop, [int, float], self.cls_name)
+        validator.check_value_type("num", num, [int], self.cls_name)
+        validator.check_integer("num", num, 0, Rel.GT, self.cls_name)
+
+        self.is_single = bool(num == 1)
+        self.lin_space = inner.LinSpace()
+        self.start = Tensor(start, mstype.float32)
+        self.stop = Tensor(stop, mstype.float32)
+        self.assist = Tensor(list(range(num)), mstype.float32)
+        self.num = Tensor(num, mstype.int32)
+        self.start_array = Tensor([start], mstype.float32)
+
+    def construct(self):
+        if self.is_single:
+            return self.start_array
+        lin_space_out = self.lin_space(self.assist, self.start, self.stop, self.num)
+        return lin_space_out
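For reference, a minimal NumPy sketch of the interpolation nn.LinSpace is expected to perform. It assumes the TBE kernel evaluates start + assist * (stop - start) / (num - 1) with assist = [0, 1, ..., num - 1], matching the Tensor(list(range(num))) built above; the helper name is hypothetical and this is not the kernel implementation itself, but it reproduces the docstring example.

import numpy as np

def lin_space_sketch(start, stop, num):
    """Plain-NumPy sketch of the values nn.LinSpace is expected to produce."""
    if num == 1:
        # Mirrors the is_single shortcut in construct(): only the start value.
        return np.array([start], dtype=np.float32)
    assist = np.arange(num, dtype=np.float32)            # assumed assist = range(num)
    return start + assist * (stop - start) / (num - 1)   # assumed kernel math

print(lin_space_sketch(1, 10, 5))  # [ 1.    3.25  5.5   7.75 10.  ]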
@@ -21,6 +21,7 @@ from mindspore.ops import _selected_grad_ops as SG
 from .. import functional as F
 from .. import operations as P
 from ..operations import _grad_ops as G
+from ..operations import _inner_ops as inner
 from ..composite.multitype_ops.zeros_like_impl import zeros_like
 from ..functional import broadcast_gradient_args, reduced_shape, tuple_div
 from .grad_base import bprop_getters
@@ -1049,3 +1050,13 @@ def get_bprop_inv(self):
         dx = inv_grad(out, dout)
         return (dx,)
     return bprop
+
+
+@bprop_getters.register(inner.LinSpace)
+def get_bprop_lin_space(self):
+    """Grad definition for `LinSpace` operation."""
+
+    def bprop(assist, start, stop, num, out, dout):
+        return zeros_like(assist), zeros_like(start), zeros_like(stop), zeros_like(num)
+
+    return bprop
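The bprop above treats LinSpace as a constant generator: the incoming gradient is discarded and every input receives a zero gradient. A small NumPy sketch of that contract (shapes and dtypes only; it stands in for MindSpore's zeros_like, not the real graph machinery, and the helper name is illustrative):

import numpy as np

def lin_space_bprop_sketch(assist, start, stop, num, out, dout):
    # dout is ignored on purpose: none of the inputs influence the output
    # in a differentiable way, so each one gets zeros of its own shape/dtype.
    return (np.zeros_like(assist), np.zeros_like(start),
            np.zeros_like(stop), np.zeros_like(num))

grads = lin_space_bprop_sketch(np.arange(5, dtype=np.float32),
                               np.float32(1.0), np.float32(10.0), np.int32(5),
                               out=None, dout=np.ones(5, dtype=np.float32))
print([g.tolist() for g in grads])  # [[0.0, 0.0, 0.0, 0.0, 0.0], 0.0, 0.0, 0]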
@@ -262,3 +262,4 @@ from .tensor_scatter_update import _tensor_scatter_update_tbe
 from .inplace_update import _inplace_update_tbe
 from .splitv import _split_v_tbe
 from .in_top_k import _in_top_k_tbe
+from .lin_space import _lin_space_tbe
@@ -0,0 +1,40 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""LinSpace op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+lin_space_op_info = TBERegOp("LinSpace") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("lin_space.so") \
+    .compute_cost(10) \
+    .kernel_name("lin_space") \
+    .partial_flag(True) \
+    .op_pattern("broadcast") \
+    .input(0, "assist", False, "required", "all") \
+    .input(1, "start", False, "required", "all") \
+    .input(2, "stop", False, "required", "all") \
+    .input(3, "num", False, "required", "all") \
+    .output(0, "output", False, "required", "all") \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.I32_Default,
+                  DataType.F32_Default) \
+    .get_op_info()
+
+
+@op_info_register(lin_space_op_info)
+def _lin_space_tbe():
+    """LinSpace TBE register"""
+    return
@@ -328,3 +328,42 @@ class EmbeddingLookup(PrimitiveWithInfer):
                'dtype': params['dtype'],
                'value': None}
        return out
+
+
+class LinSpace(PrimitiveWithInfer):
+    r"""
+    Generates values in an interval and returns the corresponding interpolation
+    according to `assist`.
+
+    Inputs:
+        - **assist** (Tensor[float32]) - The assist value, with a shape of 0-D or 1-D.
+        - **start** (Tensor[float32]) - The start of the interval, with a shape of 0-D.
+        - **stop** (Tensor[float32]) - The end of the interval, with a shape of 0-D.
+        - **num** (Tensor[int32]) - Number of ticks in the interval; the ticks include
+          the start and stop values. With a shape of 0-D.
+
+    Outputs:
+        Tensor, has the same shape as `assist`.
+
+    Examples:
+        >>> linspace = inner.LinSpace()
+        >>> assist = Tensor([5, 5.5], mindspore.float32)
+        >>> start = Tensor(1, mindspore.float32)
+        >>> stop = Tensor(10, mindspore.float32)
+        >>> num = Tensor(5, mindspore.int32)
+        >>> output = linspace(assist, start, stop, num)
+        [12.25, 13.375]
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        pass
+
+    def infer_shape(self, assist, start, stop, num):
+        return assist
+
+    def infer_dtype(self, assist, start, stop, num):
+        args = {"num": num}
+        validator.check_tensor_type_same(args, (mstype.int32,), self.name)
+        args = {"assist": assist, "start": start, "stop": stop}
+        validator.check_tensor_type_same(args, (mstype.float32,), self.name)
+        return assist
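For reference, the [12.25, 13.375] result in the docstring can be reproduced with plain NumPy under the assumption that the kernel computes start + assist * (stop - start) / (num - 1) elementwise, i.e. each assist entry acts as a (possibly fractional) index into the num-point line from start to stop. The helper below is a hypothetical sketch, not the TBE implementation.

import numpy as np

def lin_space_primitive_sketch(assist, start, stop, num):
    # Assumed kernel math; verified against the docstring example only.
    step = (stop - start) / (num - 1)
    return start + np.asarray(assist, dtype=np.float32) * step

print(lin_space_primitive_sketch([5, 5.5], 1.0, 10.0, 5))  # [12.25  13.375]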
@@ -1599,6 +1599,14 @@ test_case_array_ops = [
         'desc_inputs': [Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(np.float32)),
                         Tensor(np.array([1, 2, 3]).astype(np.int32))],
         'desc_bprop': [[3, 3]]}),
+    ('LinSpace', {
+        'block': inner.LinSpace(),
+        'desc_inputs': [Tensor([5, 5.5], mstype.float32),
+                        Tensor(1, mstype.float32),
+                        Tensor(10, mstype.float32),
+                        Tensor(5, mstype.int32)],
+        'skip': ['backward'],
+    }),
 ]
 
 test_case_other_ops = [