|
- # This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
- #
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """standard_method"""
- from dataclasses import dataclass
- from mindspore.common import dtype as mstype
- from ...ops import functional as F
- from ...ops import operations as P
- from ...ops.primitive import constexpr
- from ...ops.composite import tail, core, MultitypeFuncGraph, env_get, hyper_add, \
- zeros_like, ones_like
- from ...ops.composite.base import _append
-
# Public names of this module; all of them are imported above from `...ops.composite`.
__all__ = [
    'MultitypeFuncGraph',
    'env_get',
    'hyper_add',
    'zeros_like',
    'ones_like',
]

# Module-level Transpose primitive instance, used by `transpose` below.
trans = P.Transpose()
-
-
def transpose(x):
    """Implementation of `transpose`.

    Builds the index range ``0 .. len(shape(x))``, reverses it, and applies
    the module-level `trans` primitive to `x` with that reversed tuple as
    the permutation.
    """
    rank = F.tuple_len(F.shape(x))
    reversed_axes = F.tuple_reversed(F.make_range(0, rank))
    return trans(x, reversed_axes)
-
-
def getitem(data, item):
    """Implementation of `getitem`.

    Forwards to ``data.__getitem__(item)`` and returns whatever it returns.
    """
    # NOTE(review): the explicit dunder call (instead of ``data[item]``) looks
    # deliberate for the graph frontend -- confirm before changing it.
    element = data.__getitem__(item)
    return element
-
-
def setitem(data, item, value):
    """Implementation of `setitem`.

    Forwards to ``data.__setitem__(item, value)`` and returns its result.
    """
    outcome = data.__setitem__(item, value)
    return outcome
-
-
def ms_iter(xs):
    """Implementation of `iter`.

    Returns the result of calling ``xs.__ms_iter__()``.
    """
    iterator = xs.__ms_iter__()
    return iterator
-
-
def ms_next(it):
    """Implementation of `next`.

    Returns the result of calling ``it.__ms_next__()``.
    """
    step = it.__ms_next__()
    return step
-
-
def hasnext(it):
    """Implementation of `hasnext`.

    Returns the result of calling ``it.__ms_hasnext__()``.
    """
    more = it.__ms_hasnext__()
    return more
-
-
def ms_len(data):
    """Implementation of `len`.

    Returns the result of calling ``data.__len__()``.
    """
    size = data.__len__()
    return size
-
-
def floor(x):
    """Implementation of `floor`.

    Returns the result of calling ``x.__floor__()``.
    """
    floored = x.__floor__()
    return floored
-
-
def trunc(x):
    """Implementation of `trunc`.

    Returns the result of calling ``x.__trunc__()``.
    """
    truncated = x.__trunc__()
    return truncated
-
-
def uadd(x):
    """Implementation of `uadd` (unary ``+``).

    Returns the result of calling ``x.__pos__()``.
    """
    positive = x.__pos__()
    return positive
-
-
def usub(x):
    """Implementation of `usub` (unary ``-``).

    Returns the result of calling ``x.__neg__()``.
    """
    negated = x.__neg__()
    return negated
-
-
def scalar_truediv(x, y):
    """Implementation of `scalar_truediv`.

    Returns the result of calling ``x.__truediv__(y)``.
    """
    quotient = x.__truediv__(y)
    return quotient
-
-
def scalar_floordiv(x, y):
    """Implementation of `scalar_floordiv`.

    Returns the result of calling ``x.__floordiv__(y)``.
    """
    quotient = x.__floordiv__(y)
    return quotient
-
-
def bool_(x):
    """Implementation of `bool`.

    Returns the result of calling ``x.__bool__()``.
    """
    truth = x.__bool__()
    return truth
-
-
def enumerate_(x, start=0):
    """Enumerate list or tuple.

    Checks that the type of `x` is a list or tuple and that `start` is a
    constant int (both checks raise TypeError on failure), then zips the
    range ``start .. start + len(x)`` with `x`. An empty tuple is the
    fallback result when the checks do not both hold.
    """
    op_name = "enumerate"
    pairs = ()
    seq_type = F.typeof(x)
    if (check_is_tuple_or_list(seq_type, op_name, "first input")
            and check_is_const_int(start, op_name, "start")):
        stop = start + len(x)
        pairs = zip(range(start, stop), x)
    return pairs
-
-
def isinstance_(x, base_type):
    """Determine whether x is an instance of base_type.

    Passes the type of `x` (from `F.typeof`) and `base_type` to
    `check_type_same` and returns its result.
    """
    return check_type_same(F.typeof(x), base_type)
-
-
def while_cond(x):
    """For while condition, if the condition is a tensor, the loop will not be unrolled.

    When the type of `x` is a tensor type and its shape passes
    `check_is_tensor_bool_cond`, `x` is cast to `mstype.bool_`; otherwise
    `x` is returned unchanged.
    """
    x_is_tensor = F.issubclass_(F.typeof(x), F.typeof(mstype.tensor))
    if x_is_tensor:
        shape_ok = check_is_tensor_bool_cond(F.shape(x))
        if shape_ok:
            return F.cast(x, mstype.bool_)
    return x
-
-
@constexpr
def check_type_same(x_type, base_type):
    """Check x_type is same as base_type.

    Returns True when ``mstype.issubclass_(x_type, base_type)`` holds.

    Raises:
        TypeError: If `x_type` is not a subclass of `base_type`.
    """
    if not mstype.issubclass_(x_type, base_type):
        raise TypeError(f"The arg 'x' should be a {base_type}, but got {x_type}.")
    return True
-
-
@constexpr
def check_is_tuple_or_list(x, op_name, arg_name):
    """Check whether x is list or tuple.

    Returns True when `x` is an instance of `mstype.list_type` or
    `mstype.tuple_type`.

    Raises:
        TypeError: Otherwise; the message names `op_name` and `arg_name`.
    """
    accepted = (mstype.list_type, mstype.tuple_type)
    if not isinstance(x, accepted):
        raise TypeError(f"For '{op_name}', the '{arg_name}' should be tuple or list, but got {x}.")
    return True
-
-
@constexpr
def check_is_const_int(x, op_name, arg_name):
    """Check whether x is const int.

    Returns True when `x` is an `int`.

    Raises:
        TypeError: If `x` is None (reported as "not const") or is not an int.
    """
    if x is None:
        message = f"For '{op_name}', the '{arg_name}' should be a const int number, but got not const."
    elif not isinstance(x, int):
        message = f"For '{op_name}', the '{arg_name}' should be a const int number, but got {x}."
    else:
        return True
    raise TypeError(message)
-
-
-
@constexpr
def check_is_tensor_bool_cond(shp):
    """Check if a tensor shape is usable as a bool condition.

    Args:
        shp: The tensor shape to check.

    Returns:
        True when `shp` is ``()`` or ``(1,)``.

    Raises:
        ValueError: If `shp` is any other shape.
    """
    if shp in ((), (1,)):
        return True
    # Format the shape into the message; passing it as a second exception
    # argument made the error render as a tuple instead of a sentence.
    raise ValueError(f"tensor as bool condition, its shape should be () or (1,), but got {shp}")
-
-
@constexpr
def const_tensor_to_bool(x):
    """Convert a constant bool tensor to a Python bool condition.

    Args:
        x: Object exposing ``asnumpy()``, or None.

    Returns:
        ``bool`` of the single element of ``x.asnumpy()``.

    Raises:
        ValueError: If `x` is None, or if the array shape is neither ``()``
            nor ``(1,)``.
    """
    if x is None:
        raise ValueError("Only constant tensor bool can be converted to bool")
    array = x.asnumpy()
    if array.shape not in ((), (1,)):
        # Single formatted message (previously the shape was passed as a
        # second exception argument), and "(1,)" matches the check above.
        raise ValueError(f"Tensor to bool should input shape () or (1,), but got {array.shape}")
    scalar = array if array.shape == () else array[0]
    return bool(scalar)
-
-
def tensor_bool(x):
    """Tensor as condition; if it is constant, return an immediate bool value.

    The shape of `x` is validated with `check_is_tensor_bool_cond`. When
    `F.isconstant(x)` holds, the result comes from `const_tensor_to_bool`;
    otherwise `x` is cast to `mstype.bool_`.
    """
    shape_ok = check_is_tensor_bool_cond(F.shape(x))
    if shape_ok and F.isconstant(x):
        return const_tensor_to_bool(x)
    return F.cast(x, mstype.bool_)
-
-
def and_(x, y):
    """Implementation of `and` (`&`).

    Returns the result of calling ``x.__and__(y)``.
    """
    combined = x.__and__(y)
    return combined
-
-
def or_(x, y):
    """Implementation of `or` (`|`).

    Returns the result of calling ``x.__or__(y)``.
    """
    combined = x.__or__(y)
    return combined
-
-
def matmul(x, y):
    """Implementation of `matmul` (`@`).

    Returns the result of calling ``x.__matmul__(y)``.
    """
    product = x.__matmul__(y)
    return product
-
-
def float_bool(x):
    """Implementation of `float_bool`.

    Returns the result of comparing `x` with ``0.0`` using ``!=``.
    """
    is_nonzero = x != 0.0
    return is_nonzero
-
-
def int_bool(x):
    """Implementation of `int_bool`.

    Returns the result of comparing `x` with ``0`` using ``!=``.
    """
    is_nonzero = x != 0
    return is_nonzero
-
-
def str_bool(x):
    """Implementation of `str_bool`.

    Returns False when `x` equals the empty string, True otherwise.
    """
    if x == "":
        result = False
    else:
        result = True
    return result
-
-
def list_bool(x):
    """Implementation of `list_bool`.

    Returns whether ``len(x)`` differs from zero.
    """
    size = len(x)
    return size != 0
-
-
def tuple_bool(x):
    """Implementation of `tuple_bool`.

    Returns whether ``len(x)`` differs from zero.
    """
    size = len(x)
    return size != 0
-
-
def dict_bool(x):
    """Implementation of `dict_bool`.

    Returns whether ``len(x)`` differs from zero.
    """
    size = len(x)
    return size != 0
-
-
def none_bool(x):
    """Implementation of `none_bool`.

    Always returns False; the argument `x` is not inspected.
    """
    result = False
    return result
-
-
def float_floordiv(x, y):
    """Implementation of `float_floordiv`.

    Divides `x` by `y` with ``/`` and passes the quotient to this module's
    `floor` function.
    """
    quotient = x / y
    return floor(quotient)
-
-
- #############
- # Iteration #
- #############
-
-
@dataclass(frozen=True)
class SequenceIterator:
    """
    SequenceIterator is a util dataclass for iterating sequence object.

    Iterator to use for sequences like List, Array. Instances are frozen;
    advancing produces a new iterator instead of mutating this one.
    """

    # Position of the element that the next `__ms_next__` call returns.
    idx: int
    # The sequence being iterated.
    seq: list

    @core(ignore_values=True)
    def __ms_hasnext__(self):
        """Whether the index is still below the length of the sequence."""
        position = self.idx
        return position < ms_len(self.seq)

    @core(ignore_values=True)
    def __ms_next__(self):
        """Return the element at `idx` and a new iterator positioned at `idx + 1`."""
        current = self.seq[self.idx]
        advanced = SequenceIterator(self.idx + 1, self.seq)
        return current, advanced
-
-
def list_iter(xs):
    """Iterator for List.

    Returns a `SequenceIterator` over `xs` starting at index 0.
    """
    iterator = SequenceIterator(0, xs)
    return iterator
-
-
def array_iter(xs):
    """Iterator for Array.

    Returns a `SequenceIterator` over `xs` starting at index 0.
    """
    iterator = SequenceIterator(0, xs)
    return iterator
-
-
def tuple_next(xs):
    """Next tuple.

    Returns the pair ``(xs[0], tail(xs))``.
    """
    head = xs[0]
    rest = tail(xs)
    return head, rest
-
-
def tuple_hasnext(xs):
    """Whether the tuple still has elements, i.e. ``len(xs)`` is greater than zero."""
    remaining = len(xs)
    return remaining > 0
-
-
def list_next(xs):
    """Next list.

    Returns the pair ``(xs[0], tail(xs))``.
    """
    head = xs[0]
    rest = tail(xs)
    return head, rest
-
-
def list_hasnext(xs):
    """Whether the list still has elements, i.e. ``len(xs)`` is greater than zero."""
    remaining = len(xs)
    return remaining > 0
-
-
def list_append(self_, item):
    """Implementation of list `append`.

    Returns the result of ``_append(self_, item)``.
    """
    appended = _append(self_, item)
    return appended
-
-
- #################
- # Array methods #
- #################
-
-
def to_array(x):
    """Implementation of `to_array`.

    Returns the result of calling ``x.__ms_to_array__()``.
    """
    array = x.__ms_to_array__()
    return array
|