|
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """Generate vm_impl function for array ops"""
import functools

import numpy as np

import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G
from mindspore.ops.vm_impl_registry import vm_impl_registry as vm_impl_getters
from .vm_interface import vm
-
- # pylint: disable=unused-argument
@vm_impl_getters.register(P.Assign)
def vm_impl_assign(self):
    """Build the VM implementation of the Assign primitive.

    The returned callable calls ``x.assign_value(value)`` and returns ``x``
    itself. The trailing ``u`` argument is accepted but ignored
    (presumably a state/monad token — confirm).
    """
    def vm_impl(x, value, u=None):
        target = x
        target.assign_value(value)
        return target

    return vm_impl
-
@vm_impl_getters.register(P.ExpandDims)
def vm_impl_expand_dims(self):
    """Build the VM implementation of the ExpandDims primitive.

    The returned callable converts ``x`` to numpy, delegates to
    ``vm.expand_dims`` with ``axis`` and wraps the result in a Tensor.
    A bare Python float is first wrapped as a one-element Tensor.
    """

    def vm_impl(x, axis):
        source = Tensor(np.array([x])) if isinstance(x, float) else x
        expanded = vm.expand_dims(source.asnumpy(), axis)
        return Tensor(expanded)

    return vm_impl
-
-
@vm_impl_getters.register(P.DType)
def vm_impl_dType(self):
    """Build the VM implementation of the DType primitive.

    The returned callable simply reports the ``dtype`` attribute of its input.
    """

    def vm_impl(x):
        element_dtype = x.dtype
        return element_dtype

    return vm_impl
-
-
@vm_impl_getters.register(P.Cast)
def vm_impl_cast(self):
    """Build the VM implementation of the Cast primitive.

    The returned callable converts ``x`` to numpy and casts it to the numpy
    equivalent of the MindSpore type ``t``. If ``t`` is a tensor type, its
    element type is used as the cast target.
    """

    def vm_impl(x, t):
        target_type = t.element_type() if isinstance(t, type(mstype.tensor)) else t
        array = x.asnumpy()
        return Tensor(array.astype(mstype.dtype_to_nptype(target_type)))

    return vm_impl
-
-
@vm_impl_getters.register(P.Reshape)
def vm_impl_reshape(self):
    """Build the VM implementation of the Reshape primitive.

    The returned callable reshapes the numpy view of ``x`` to ``shp`` via
    ``vm.reshape`` and wraps the result in a Tensor.
    """

    def vm_impl(x, shp):
        return Tensor(vm.reshape(x.asnumpy(), shp))

    return vm_impl
-
-
@vm_impl_getters.register(P.Shape)
def vm_impl_shape(self):
    """Build the VM implementation of the Shape primitive.

    The returned callable returns whatever ``vm.shape`` reports for the numpy
    view of ``x``. The result is not wrapped in a Tensor.
    """

    def vm_impl(x):
        return vm.shape(x.asnumpy())

    return vm_impl
-
-
@vm_impl_getters.register(P.Squeeze)
def vm_impl_squeeze(self):
    """Build the VM implementation of the Squeeze primitive.

    The returned callable squeezes ``x`` along the primitive's ``axis``
    attribute via ``vm.squeeze`` and wraps the result in a Tensor.
    """

    def vm_impl(x):
        squeezed = vm.squeeze(x.asnumpy(), self.axis)
        return Tensor(squeezed)

    return vm_impl
-
-
@vm_impl_getters.register(P.Transpose)
def vm_impl_transpose(self):
    """Build the VM implementation of the Transpose primitive.

    The returned callable permutes the axes of ``x`` by ``perm`` via
    ``vm.transpose``. When ``perm`` is omitted, the axis order is fully
    reversed (``[ndim-1, ..., 1, 0]``).
    """

    def vm_impl(x, perm=None):
        arr = x.asnumpy()
        if perm is None:
            perm = list(range(len(arr.shape)))[::-1]
        return Tensor(vm.transpose(arr, perm))

    return vm_impl
-
-
@vm_impl_getters.register(P.Split)
def vm_impl_split(self):
    """Generate vm_impl function for Split.

    If the primitive carries a legacy ``pos`` attribute, the old behaviour is
    kept: ``x`` is cut into two pieces at index ``pos`` along axis 0.
    Otherwise ``x`` is divided into ``self.output_num`` equal pieces along
    ``self.axis`` (assumes P.Split exposes ``axis`` and ``output_num``
    attributes — confirm against the primitive definition). ``np.split``
    raises ValueError if the axis length is not divisible by ``output_num``.

    Returns:
        A tuple of Tensors, one per piece.
    """

    def vm_impl(x):
        x = x.asnumpy()
        if hasattr(self, "pos"):
            # Legacy two-way split kept for backward compatibility.
            output = np.array_split(x, (self.pos,))
            return Tensor(output[0]), Tensor(output[1])
        pieces = np.split(x, self.output_num, axis=self.axis)
        return tuple(Tensor(piece) for piece in pieces)

    return vm_impl
-
-
@vm_impl_getters.register(P.Fill)
def vm_impl_fill(self):
    """Build the VM implementation of the Fill primitive.

    The returned callable builds an array of shape ``dims`` filled with ``x``.
    An ``int`` fill value (including ``bool``, an ``int`` subclass) yields
    int32; any other value yields float32.
    """

    def vm_impl(dims, x):
        fill_dtype = np.int32 if isinstance(x, int) else np.float32
        return Tensor(np.full(dims, x, fill_dtype))

    return vm_impl
-
-
@vm_impl_getters.register(P.Eye)
def vm_impl_eye(self):
    """Build the VM implementation of the Eye primitive.

    The returned callable builds an ``n`` x ``m`` ``np.eye`` matrix whose
    dtype is the numpy equivalent of the MindSpore type ``t``.
    """

    def vm_impl(n, m, t):
        return Tensor(np.eye(n, m, dtype=mstype.dtype_to_nptype(t)))

    return vm_impl
-
-
@vm_impl_getters.register(P.InvertPermutation)
def vm_impl_invert_permutation(self):
    """Build the VM implementation of the InvertPermutation primitive.

    The returned callable forwards ``x`` unchanged to
    ``vm.invert_permutation`` and returns its result as-is (no Tensor wrap).
    """

    def vm_impl(x):
        return vm.invert_permutation(x)

    return vm_impl
-
-
@vm_impl_getters.register(P.Argmax)
def vm_impl_argmax(self):
    """Generate vm_impl function for Argmax.

    The returned callable computes ``np.argmax`` of ``x`` along the
    primitive's ``axis`` attribute. The result keeps the reduced shape, except
    that a scalar result (1-D input) is returned with shape ``(1,)``.
    """

    def vm_impl(x):
        output = np.argmax(x.asnumpy(), axis=self.axis)
        # The old ``.ravel()`` flattened multi-dimensional results (rank >= 3
        # inputs). atleast_1d only promotes the scalar case, matching the old
        # behaviour for 1-D/2-D inputs without losing the shape otherwise.
        return Tensor(np.atleast_1d(output))

    return vm_impl
-
-
@vm_impl_getters.register(P.Tile)
def vm_impl_tile(self):
    """Build the VM implementation of the Tile primitive.

    The returned callable repeats the numpy view of ``x`` according to
    ``multiples`` using ``np.tile``.
    """

    def vm_impl(x, multiples):
        return Tensor(np.tile(x.asnumpy(), multiples))

    return vm_impl
-
-
@vm_impl_getters.register(P.ReduceAll)
def vm_impl_all(self):
    """Build the VM implementation of the ReduceAll primitive.

    The returned callable delegates to ``vm.all`` with the given ``axis`` and
    the primitive's ``keep_dims`` attribute, and wraps the result in a Tensor.
    """

    def vm_impl(x, axis):
        reduced = vm.all(x.asnumpy(), axis, self.keep_dims)
        return Tensor(reduced)

    return vm_impl
-
-
@vm_impl_getters.register(P.ReduceAny)
def vm_impl_any(self):
    """Build the VM implementation of the ReduceAny primitive.

    The returned callable delegates to ``vm.any`` with the given ``axis`` and
    the primitive's ``keep_dims`` attribute, and wraps the result in a Tensor.
    """

    def vm_impl(x, axis):
        reduced = vm.any(x.asnumpy(), axis, self.keep_dims)
        return Tensor(reduced)

    return vm_impl
-
-
@vm_impl_getters.register(P.Concat)
def vm_impl_concatV2(self):
    """Build the VM implementation of the Concat primitive.

    The returned callable converts ``x`` to numpy and delegates to
    ``vm.Concat`` along the primitive's ``axis`` attribute.
    """

    def vm_impl(x):
        # NOTE(review): assumes ``x`` is a single Tensor exposing ``asnumpy``;
        # if Concat is fed a tuple of tensors this will fail — confirm.
        merged = vm.Concat(x.asnumpy(), self.axis)
        return Tensor(merged)

    return vm_impl
-
-
@vm_impl_getters.register(P.Slice)
def vm_impl_slice(self):
    """Build the VM implementation of the Slice primitive.

    The returned callable converts ``x``, ``begin`` and ``size`` to numpy (in
    that order), delegates to ``vm.Slice`` and wraps the result in a Tensor.
    """

    def vm_impl(x, begin, size):
        arr, start, extent = x.asnumpy(), begin.asnumpy(), size.asnumpy()
        return Tensor(vm.Slice(arr, start, extent))

    return vm_impl
-
-
@vm_impl_getters.register(G.ConcatOffset)
def vm_impl_concatOffset(self):
    """Build the VM implementation of the ConcatOffset grad primitive.

    The returned callable passes ``x`` straight to ``vm.ConcatOffset`` and
    returns its result unwrapped (the original author notes it is a tuple).
    """

    def vm_impl(x):
        return vm.ConcatOffset(x)

    return vm_impl
-
-
@vm_impl_getters.register(P.ReduceSum)
def vm_impl_sum(self):
    """Generate vm_impl function for Sum.

    The returned callable sums ``x`` over ``axis``. An empty ``axis`` (``()``
    or ``[]``) means "reduce over every dimension". The primitive's
    ``keep_dims`` attribute (default False if absent) is forwarded to numpy's
    ``keepdims``; previously it was ignored.
    """

    def vm_impl(x, axis):
        x = x.asnumpy()
        # np.sum rejects list axes, so normalise to a tuple first.
        if isinstance(axis, list):
            axis = tuple(axis)
        # numpy treats axis=() as "reduce nothing", so map it to None
        # (reduce everything).
        if axis == ():
            axis = None
        keep_dims = getattr(self, "keep_dims", False)
        out = np.sum(x, axis=axis, keepdims=keep_dims)
        return Tensor(np.array(out))

    return vm_impl
-
-
@vm_impl_getters.register(P.Select)
def vm_impl_select(self):
    """Build the VM implementation of the Select primitive."""

    def vm_impl(cond, x, y):
        """
        Convert all three inputs to numpy and delegate to ``vm.select``.

        Args:
            cond: A `Tensor` of type `bool`
            x: A Tensor which may have the same shape as `condition`.
            y: A `Tensor` with the same shape and type as `x`.
        """
        cond_np, x_np, y_np = (t.asnumpy() for t in (cond, x, y))
        return Tensor(vm.select(cond_np, x_np, y_np))

    return vm_impl
-
-
@vm_impl_getters.register(P.Square)
def vm_impl_square(self):
    """Build the VM implementation of the Square primitive.

    The returned callable multiplies the numpy view of ``x`` by itself
    element-wise.
    """

    def vm_impl(x):
        arr = x.asnumpy()
        squared = arr * arr
        return Tensor(squared)

    return vm_impl
-
-
@vm_impl_getters.register(P.ZerosLike)
def vm_impl_zeros_like(self):
    """Generate vm_impl function for ZerosLike.

    The returned callable produces a Tensor of zeros with the shape and dtype
    of the numpy view of ``x`` (via ``np.zeros_like``).
    """

    def vm_impl(x):
        return Tensor(np.zeros_like(x.asnumpy()))

    # This return was missing, so the getter returned None and the
    # registered implementation could never be called.
    return vm_impl
-
-
@vm_impl_getters.register(P.Partial)
def vm_impl_partial(self):
    """Generate vm_impl function for Partial.

    The returned callable binds ``args[1:]`` as leading positional arguments
    to the ``__call__`` of ``args[0]`` and returns the resulting
    ``functools.partial`` object. ``functools`` was previously used here
    without being imported, which raised NameError on every call.
    """

    def vm_impl(*args):
        func = args[0].__call__
        return functools.partial(func, *args[1:])

    return vm_impl
-
-
@vm_impl_getters.register(P.Depend)
def vm_impl_depend(self):
    """Build the VM implementation of the Depend primitive.

    The returned callable passes ``value`` through unchanged; ``expr`` is
    accepted but not used.
    """

    def vm_impl(value, expr):
        passthrough = value
        return passthrough

    return vm_impl
-
-
@vm_impl_getters.register(P.UpdateState)
def vm_impl_updatestate(self):
    """Build the VM implementation of the UpdateState primitive.

    The returned callable passes ``monad`` through unchanged; ``expr`` is
    accepted but not used.
    """

    def vm_impl(monad, expr):
        state = monad
        return state

    return vm_impl
-
-
@vm_impl_getters.register(P.Load)
def vm_impl_load(self):
    """Build the VM implementation of the Load primitive.

    The returned callable passes ``value`` through unchanged; the optional
    ``u`` argument is accepted but not used.
    """

    def vm_impl(value, u=None):
        loaded = value
        return loaded

    return vm_impl
|