# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
#
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""The names of the functional part are summarized here."""

from mindspore.common._register_for_tensor import tensor_operator_registry
from mindspore.common import ms_function
from mindspore.common import Tensor
from mindspore.nn.grad.cell_grad import _JvpInner
from mindspore.nn.grad.cell_grad import _VjpInner
from mindspore.ops import _constants
from mindspore.ops.primitive import constexpr
from .primitive import Primitive
from . import operations as P
from .operations import _grad_ops
from .composite import _Grad
from .._c_expression import security

typeof = Primitive('typeof')
hastype = Primitive('hastype')
cast = P.Cast()
dtype = P.DType()
isconstant = Primitive('is_constant')
isconstant.set_const_prim(True)

issubclass_ = P.IsSubClass()
isinstance_ = P.IsInstance()
eye = P.Eye()
fill = P.Fill()
tile = P.Tile()
select = P.Select()
size = P.Size()
ones_like = P.OnesLike()
shape = P.Shape()
dyn_shape = P.DynamicShape()
rank = P.Rank()
reshape = P.Reshape()

merge = P.Merge()
geswitch = P.GeSwitch()
addn = P.AddN()
absolute = P.Abs()
tensor_add = P.Add()
add = tensor_add
neg_tensor = P.Neg()
tensor_lt = P.Less()
less = tensor_lt
tensor_le = P.LessEqual()
le = tensor_le
tensor_gt = P.Greater()
gt = tensor_gt
tensor_ge = P.GreaterEqual()
ge = tensor_ge
tensor_sub = P.Sub()
sub = tensor_sub
tensor_mul = P.Mul()
mul = tensor_mul
tensor_div = P.RealDiv()
div = tensor_div
tensor_floordiv = P.FloorDiv()
floordiv = tensor_floordiv
tensor_pow = P.Pow()
pows = tensor_pow
tensor_mod = P.FloorMod()
floormod = tensor_mod
tensor_exp = P.Exp()
exp = tensor_exp
tensor_expm1 = P.Expm1()
tensor_slice = P.Slice()
strided_slice = P.StridedSlice()
same_type_shape = P.SameTypeShape()
check_bprop = P.CheckBprop()
equal = P.Equal()
not_equal = P.NotEqual()
isfinite = P.IsFinite()
isnan = P.IsNan()
assign_sub = P.AssignSub()
assign_add = P.AssignAdd()
assign = P.Assign()
square = P.Square()
sqrt = P.Sqrt()
log = P.Log()
reduce_sum = P.ReduceSum()
reduce_max = P.ReduceMax()
reduce_min = P.ReduceMin()
reduce_mean = P.ReduceMean()
reduce_prod = P.ReduceProd()
maximum = P.Maximum()
minimum = P.Minimum()
floor = P.Floor()
logical_not = P.LogicalNot()
logical_or = P.LogicalOr()
logical_and = P.LogicalAnd()
sin = P.Sin()
cos = P.Cos()
tan = P.Tan()
asin = P.Asin()
acos = P.ACos()
atan = P.Atan()
sinh = P.Sinh()
cosh = P.Cosh()
tanh = P.Tanh()
asinh = P.Asinh()
acosh = P.Acosh()
atanh = P.Atanh()
atan2 = P.Atan2()
bitwise_and = P.BitwiseAnd()
bitwise_or = P.BitwiseOr()
bitwise_xor = P.BitwiseXor()
invert = P.Invert()
erf = P.Erf()
erfc = P.Erfc()
sort = P.Sort()
tensor_range = P.Range()

scalar_to_array = P.ScalarToArray()
scalar_to_tensor = P.ScalarToTensor()
tuple_to_array = P.TupleToArray()
scalar_cast = P.ScalarCast()
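
# A hedged usage sketch (illustrative comment only, nothing is executed here):
# the module-level instances above are callable like plain functions, e.g.
#
#   import numpy as np
#   from mindspore import Tensor
#   from mindspore.ops import functional as F
#   x = Tensor(np.array([1.0, 2.0], np.float32))
#   y = Tensor(np.array([3.0, 4.0], np.float32))
#   out = F.tensor_add(x, y)  # equivalent to P.Add()(x, y)
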
if not security.enable_security():
    print_ = P.Print()
expand_dims = P.ExpandDims()
transpose = P.Transpose()
squeeze = P.Squeeze()
scatter_nd = P.ScatterNd()
gather = P.Gather()
gather_d = P.GatherD()
gather_nd = P.GatherNd()
scatter_update = P.ScatterUpdate()
tensor_scatter_update = P.TensorScatterUpdate()
scatter_nd_update = P.ScatterNdUpdate()
stack = P.Stack()


def pack(x):
    """Call stack in this pack function."""
    print("WARNING: 'pack' is deprecated from version 1.1 and will be removed in a future version, "
          "use 'stack' instead.")
    return stack(x)


partial = P.Partial()
# depend: mount a node to another node
depend = P.Depend()
identity = P.identity()


@constexpr
def _convert_grad_position_type(grad_position):
    """Check and convert the type and size of grad position index."""
    if isinstance(grad_position, tuple):
        for gp in grad_position:
            if not isinstance(gp, int):
                raise TypeError(f"For 'F.grad', the element in 'grad_position' should be int, "
                                f"but got {type(gp).__name__}")
            if gp < 0:
                raise ValueError("The element in grad_position must be >= 0.")
    elif isinstance(grad_position, int):
        if grad_position < 0:
            raise ValueError("grad_position must be >= 0.")
        grad_position = (grad_position,)
    else:
        raise TypeError(f"For 'F.grad', the 'grad_position' should be int or tuple, "
                        f"but got {type(grad_position).__name__}")
    return grad_position


grad_by_position = _Grad(get_by_list=False, sens_param=False, get_by_position=True)
grad_by_position_with_sens = _Grad(get_by_list=False, sens_param=True, get_by_position=True)


def grad(fn, grad_position=0, sens_param=False):
    r"""
    A wrapper function to generate the gradient function for the input function.

    Args:
        fn (Union(Cell, function)): Function to do GradOperation.
        grad_position (Union(int, tuple[int])): If int, get the gradient with respect to a single input.
            If tuple, get the gradients with respect to the selected inputs. 'grad_position' begins with 0.
            Default: 0.
        sens_param (bool): Whether to append sensitivity (gradient with respect to output) as input.
            If sens_param is False, a 'ones_like(outputs)' sensitivity will be attached automatically.
            Default: False.

    Returns:
        Function, the gradient function for the input function or cell.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore.nn as nn
        >>> import mindspore.context as context
        >>> from mindspore import Tensor
        >>> from mindspore.ops.functional import grad
        >>> context.set_context(mode=context.GRAPH_MODE)
        >>> class Net(nn.Cell):
        ...     def construct(self, x, y, z):
        ...         return x*y*z
        >>> x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
        >>> y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32))
        >>> z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
        >>> net = Net()
        >>> output = grad(net, grad_position=(1, 2))(x, y, z)
        >>> print(output)
        (Tensor(shape=[2, 2], dtype=Float32, value=
        [[ 0.00000000e+00,  6.00000000e+00],
         [ 1.50000000e+01, -4.00000000e+00]]), Tensor(shape=[2, 2], dtype=Float32, value=
        [[-2.00000000e+00,  6.00000000e+00],
         [-3.00000000e+00,  8.00000000e+00]]))
    """
    grad_position = _convert_grad_position_type(grad_position)
    if sens_param:
        return grad_by_position_with_sens(fn, None, grad_position)
    return grad_by_position(fn, None, grad_position)
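
# A hedged sketch of the `sens_param=True` path (reusing the names from the
# docstring example above; illustrative comment only): the sensitivity, i.e.
# the gradient with respect to the network output, is appended as one extra
# trailing input of the returned gradient function.
#
#   grad_fn = grad(net, grad_position=0, sens_param=True)
#   sens = Tensor(np.ones([2, 2]).astype(np.float32))  # shape/dtype of net(x, y, z)
#   dx = grad_fn(x, y, z, sens)
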

def jvp(fn, inputs, v):
    """
    Compute the Jacobian-vector product of the given network.

    Args:
        fn (Function or Cell): The function or net that takes Tensor inputs and returns a tensor
            or a tuple of Tensors.
        inputs (Tensor or tuple or list): The inputs to `fn`.
        v (Tensor or tuple or list): The shape and type of `v` should be the same as `inputs`.

    Returns:
        Tuple, tuple of output and jvp.

        - **netout** (Tensors or Tuple of Tensors) - The output of "fn(inputs)".
        - **jvp** (Tensors or Tuple of Tensors) - The result of the dot product.

    Raises:
        TypeError: If the input is not a tensor or tuple or list of tensors.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore.nn as nn
        >>> from mindspore.ops import functional as F
        >>> from mindspore import Tensor
        >>> class Net(nn.Cell):
        ...     def construct(self, x, y):
        ...         return x**3 + y
        >>> x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
        >>> y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
        >>> v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
        >>> output = F.jvp(Net(), (x, y), (v, v))
        >>> print(output[0])
        [[ 2. 10.]
         [30. 68.]]
        >>> print(output[1])
        [[ 4. 13.]
         [28. 49.]]
    """
    jvp_inner = _JvpInner()

    @ms_function
    def _wrap_container(*arg):
        args = arg[1:]
        vectors = arg[0]
        return jvp_inner(fn, vectors, *args)

    if not isinstance(inputs, (Tensor, tuple, list)):
        _raise_type_error()
    if isinstance(inputs, (tuple, list)):
        return _wrap_container(v, *inputs)
    return _wrap_container(v, inputs)
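
# Note: `jvp` pairs `v` with the *inputs* (forward mode, computing J @ v), while
# `vjp` below pairs `v` with the *outputs* (reverse mode, computing v^T @ J).
# A hedged sketch of the single-Tensor (non-tuple) input path handled above,
# using a hypothetical one-input net `CubeNet` (illustrative only):
#
#   x = Tensor(np.array([1.0, 2.0]).astype(np.float32))
#   v = Tensor(np.array([1.0, 1.0]).astype(np.float32))
#   out, jv = F.jvp(CubeNet(), x, v)  # jv == 3 * x**2 * v for out = x**3
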

def vjp(fn, inputs, v):
    """
    Compute the vector-Jacobian product of the given network.

    Args:
        fn (Function or Cell): The function or net that takes Tensor inputs and returns a tensor
            or a tuple of Tensors.
        inputs (Tensor or tuple or list): The inputs to `fn`.
        v (Tensor or tuple or list): The shape and type of `v` should be the same as the outputs.

    Returns:
        Tuple, tuple of output and vjp.

        - **netout** (Tensors or Tuple of Tensors) - The output of "fn(inputs)".
        - **vjp** (Tensors or Tuple of Tensors) - The result of the dot product.

    Raises:
        TypeError: If the input is not a tensor or tuple or list of tensors.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore.nn as nn
        >>> from mindspore.ops import functional as F
        >>> from mindspore import Tensor
        >>> class Net(nn.Cell):
        ...     def construct(self, x, y):
        ...         return x**3 + y
        >>> x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
        >>> y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
        >>> v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
        >>> output = F.vjp(Net(), (x, y), v)
        >>> print(output[0])
        [[ 2. 10.]
         [30. 68.]]
        >>> print(output[1])
        (Tensor(shape=[2, 2], dtype=Float32, value=
        [[ 3.00000000e+00,  1.20000000e+01],
         [ 2.70000000e+01,  4.80000000e+01]]), Tensor(shape=[2, 2], dtype=Float32, value=
        [[ 1.00000000e+00,  1.00000000e+00],
         [ 1.00000000e+00,  1.00000000e+00]]))
    """
    vjp_inner = _VjpInner()

    @ms_function
    def wrap_container(*arg):
        args = arg[:-1]
        vectors = arg[-1]
        return vjp_inner(fn, *args, vectors)

    if not isinstance(inputs, (Tensor, tuple, list)):
        _raise_type_error()
    if isinstance(inputs, (tuple, list)):
        return wrap_container(*inputs, v)
    return wrap_container(inputs, v)


@constexpr
def _raise_type_error():
    raise TypeError("The inputs type should be a Tensor, or a tuple or list of Tensors.")


tuple_setitem = Primitive('tuple_setitem')
tuple_getitem = Primitive(_constants.kTupleGetItem)
list_getitem = Primitive('list_getitem')
list_setitem = Primitive('list_setitem')
dict_getitem = Primitive('dict_getitem')
dict_setitem = Primitive('dict_setitem')
tuple_div = Primitive("tuple_div")
tuple_len = Primitive("tuple_len")
list_len = Primitive("list_len")
tuple_reversed = Primitive("tuple_reversed")
make_range = Primitive("make_range")
make_tuple = Primitive('MakeTuple')
make_dict = Primitive('make_dict')
make_list = Primitive('make_list')
make_slice = Primitive('make_slice')
# a primitive to compare between tuples
tuple_equal = Primitive("tuple_equal")
list_equal = Primitive("list_equal")
make_ref = Primitive("make_ref")

scalar_add = Primitive(_constants.kScalarAdd)
scalar_mul = Primitive(_constants.kScalarMul)
scalar_sub = Primitive(_constants.kScalarSub)
scalar_div = Primitive(_constants.kScalarDiv)
scalar_floordiv = Primitive(_constants.kScalarFloordiv)
scalar_log = Primitive('scalar_log')
scalar_pow = Primitive(_constants.kScalarPow)
scalar_gt = Primitive('scalar_gt')
scalar_ge = Primitive('scalar_ge')
scalar_le = Primitive('scalar_le')
scalar_lt = Primitive('scalar_lt')
scalar_eq = Primitive('scalar_eq')
scalar_ne = Primitive('scalar_ne')
scalar_uadd = Primitive(_constants.kScalarUadd)
scalar_usub = Primitive(_constants.kScalarUsub)
scalar_mod = Primitive(_constants.kScalarMod)
string_eq = Primitive('string_equal')
string_concat = Primitive('string_concat')
bool_not = Primitive("bool_not")
bool_or = Primitive("bool_or")
bool_and = Primitive("bool_and")
bool_eq = Primitive("bool_eq")
cumsum = P.CumSum()
cumprod = P.CumProd()
tensor_scatter_add = P.TensorScatterAdd()
array_to_scalar = Primitive('array_to_scalar')
is_ = Primitive("is_")
is_not = Primitive("is_not")
in_dict = Primitive("in_dict")
not_in_dict = Primitive("not_in_dict")
mixed_precision_cast = Primitive("mixed_precision_cast")
broadcast_gradient_args = Primitive('BroadcastGradientArgs')
array_reduce = Primitive('array_reduce')
zeros_like = P.ZerosLike()
distribute = Primitive('distribute')
embed = Primitive('embed')
ref_to_embed = _grad_ops.RefToEmbed()
env_setitem = Primitive('env_setitem')
env_getitem = Primitive('env_getitem')
env_add = Primitive('env_add')
J = Primitive('J')
SliceGetItem = Primitive("SliceGetItem")
switch = Primitive('Switch')
switch_layer = Primitive('switch_layer')
# for sum bprop
reduced_shape = Primitive("reduced_shape")
# shape_mul: the input must be a shape; multiplies the elements of tuple(shape)
shape_mul = Primitive("shape_mul")
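
# A hedged note on `stop_gradient` just below (the usage pattern is assumed
# from its name and its role in the compiler; illustrative only): applied
# inside a Cell's construct, it cuts the backward graph at that node, e.g.
#
#   out = self.dense(x)
#   out = F.stop_gradient(out)  # no gradients flow back through `out`
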
stop_gradient = Primitive("stop_gradient")

make_row_tensor = Primitive('MakeRowTensor')
row_tensor_get_values = Primitive('RowTensorGetValues')
row_tensor_get_indices = Primitive('RowTensorGetIndices')
row_tensor_get_dense_shape = Primitive('RowTensorGetDenseShape')
row_tensor_add = Primitive('RowTensorAdd')

make_sparse_tensor = Primitive('MakeSparseTensor')
sparse_tensor_get_values = Primitive('SparseTensorGetValues')
sparse_tensor_get_indices = Primitive('SparseTensorGetIndices')
sparse_tensor_get_dense_shape = Primitive('SparseTensorGetDenseShape')

make_csr_tensor = Primitive('MakeCSRTensor')
csr_tensor_get_values = Primitive('CSRTensorGetValues')
csr_tensor_get_indices = Primitive('CSRTensorGetIndices')
csr_tensor_get_indptr = Primitive('CSRTensorGetIndptr')
csr_tensor_get_shape = Primitive('CSRTensorGetDenseShape')

tensor_operator_registry.register('all', P.ReduceAll)
tensor_operator_registry.register('any', P.ReduceAny)
tensor_operator_registry.register('abs', P.Abs)
tensor_operator_registry.register('mean', P.ReduceMean)
tensor_operator_registry.register('reshape', P.Reshape)
tensor_operator_registry.register('transpose', P.Transpose)
tensor_operator_registry.register('broadcast_to', P.BroadcastTo)
tensor_operator_registry.register('matmul', P.MatMul)
tensor_operator_registry.register('argmax', P.Argmax)
tensor_operator_registry.register('cumsum', P.CumSum)
tensor_operator_registry.register('reduce_max', P.ReduceMax)
tensor_operator_registry.register('reduce_min', P.ReduceMin)
tensor_operator_registry.register('maximum', P.Maximum)
tensor_operator_registry.register('minimum', P.Minimum)
tensor_operator_registry.register('fill', P.Fill)
tensor_operator_registry.register('tile', P.Tile)
tensor_operator_registry.register('logical_not', P.LogicalNot)
tensor_operator_registry.register('sum', P.ReduceSum)
tensor_operator_registry.register('split', P.Split)
# MindSpore cannot support Tensor(True) comparison
tensor_operator_registry.register('__eq__', equal)
tensor_operator_registry.register('__ne__', not_equal)
tensor_operator_registry.register('__neg__', neg_tensor)
tensor_operator_registry.register('__lt__', tensor_lt)
tensor_operator_registry.register('__le__', tensor_le)
tensor_operator_registry.register('__gt__', tensor_gt)
tensor_operator_registry.register('__ge__', tensor_ge)
tensor_operator_registry.register('__logical_not__', logical_not)
tensor_operator_registry.register('shape', shape)
tensor_operator_registry.register('squeeze', squeeze)
# Support the GE backend for non-comparison operators
tensor_operator_registry.register('cast', cast)
tensor_operator_registry.register('shape_mul', shape_mul)
tensor_operator_registry.register('fill', fill)
tensor_operator_registry.register('concatenate', P.Concat)
tensor_operator_registry.register('eye', eye)
tensor_operator_registry.register('reduce_sum', reduce_sum)
tensor_operator_registry.register('tensor_slice', tensor_slice)
tensor_operator_registry.register('select', select)
tensor_operator_registry.register('gather_d', gather_d)
tensor_operator_registry.register('gather_nd', gather_nd)
tensor_operator_registry.register('stack', P.Stack)
tensor_operator_registry.register('log', log)
tensor_operator_registry.register('floor', floor)

__all__ = [name for name in dir() if name[0] != "_"]
__all__.remove('Primitive')
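
# A hedged sketch of how the registrations above are consumed (the accessor
# `tensor_operator_registry.get` follows the Registry convention in
# mindspore.common; treat the exact call pattern as an assumption): Tensor
# methods look up their backing operator by name at call time, roughly
#
#   def tensor_abs(self):
#       # operator classes such as P.Abs are instantiated, then applied
#       return tensor_operator_registry.get('abs')()(self)
#
# so registering P.Abs here is what lets `x.abs()` work on a Tensor.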