#! /usr/bin/python # -*- coding: utf-8 -*- from __future__ import absolute_import, division, print_function import itertools import mindspore as ms import mindspore.ops as P from mindspore import context from mindspore.nn.cell import Cell from mindspore._checkparam import Rel from mindspore.ops import functional as F from mindspore.communication import management from mindspore.ops.operations import _inner_ops as inner from mindspore._extends import cell_attr_register from mindspore.ops._grad.grad_base import bprop_getters from mindspore._checkparam import Validator as validator from mindspore.communication.management import get_group_size, get_rank def padding_format(padding): """ Checks that the padding format correspond format. Parameters ---------- padding : str Must be one of the following:"same", "SAME", "VALID", "valid" Returns ------- str "SAME" or "VALID" """ if padding in ["SAME", "same"]: padding = "same" elif padding in ["VALID", "valid"]: padding = "valid" elif padding == None: padding = None else: raise Exception("Unsupported padding: " + str(padding)) return padding def preprocess_1d_format(data_format, padding): """ Checks that the 1-D dataformat format correspond format. Parameters ---------- data_format : str Must be one of the following:"channels_last","NWC","NCW","channels_first" padding : str Must be one of the following:"same","valid","SAME","VALID" Returns ------- str "NWC" or "NCW" and "SAME" or "VALID" """ if data_format in ["channels_last", "NWC"]: data_format = "NWC" elif data_format in ["channels_first", "NCW"]: data_format = "NCW" elif data_format == None: data_format = None else: raise Exception("Unsupported data format: " + str(data_format)) padding = padding_format(padding) return data_format, padding def preprocess_2d_format(data_format, padding): """ Checks that the 2-D dataformat format correspond format. Parameters ---------- data_format : str Must be one of the following:"channels_last","NHWC","NCHW","channels_first" padding : str Must be one of the following:"same","valid","SAME","VALID" Returns ------- str "NHWC" or "NCHW" and "SAME" or "VALID" """ if data_format in ["channels_last", "NHWC", "nhwc"]: data_format = "NHWC" elif data_format in ["channels_first", "NCHW", "nchw"]: data_format = "NCHW" elif data_format == None: data_format = None else: raise Exception("Unsupported data format: " + str(data_format)) padding = padding_format(padding) return data_format, padding def preprocess_3d_format(data_format, padding): """ Checks that the 3-D dataformat format correspond format. Parameters ---------- data_format : str Must be one of the following:"channels_last","NDHWC","NCDHW","channels_first" padding : str Must be one of the following:"same","valid","SAME","VALID" Returns ------- str "NDHWC" or "NCDHW" and "SAME" or "VALID" """ if data_format in ['channels_last', 'NDHWC']: data_format = 'NDHWC' elif data_format in ['channels_first', 'NCDHW']: data_format = 'NCDHW' elif data_format == None: data_format = None else: raise Exception("Unsupported data format: " + str(data_format)) padding = padding_format(padding) return data_format, padding def nchw_to_nhwc(x): """ Channels first to channels last Parameters ---------- x : tensor channels first tensor data Returns ------- channels last tensor data """ if len(P.Shape()(x)) == 3: x = P.Transpose()(x, (0, 2, 1)) elif len(P.Shape()(x)) == 4: x = P.Transpose()(x, (0, 2, 3, 1)) elif len(P.Shape()(x)) == 5: x = P.Transpose()(x, (0, 2, 3, 4, 1)) # else: # raise Exception("Unsupported dimensions") return x def nhwc_to_nchw(x): """ Channles last to channels first Parameters ---------- x : tensor channels last tensor data Returns ------- channels first tensor data """ if len(P.Shape()(x)) == 3: x = P.Transpose()(x, (0, 2, 1)) elif len(P.Shape()(x)) == 4: x = P.Transpose()(x, (0, 3, 1, 2)) elif len(P.Shape()(x)) == 5: x = P.Transpose()(x, (0, 4, 1, 2, 3)) # else: # raise Exception("Unsupported dimensions") return x class ReLU(Cell): def __init__(self): super(ReLU, self).__init__() self.relu = P.ReLU() def construct(self, x): return self.relu(x) def relu(x): """ Computes rectified linear: max(features, 0). Parameters ---------- x : tensor Must be one of the following types: float32, float64, int32, uint8, int16, int8, int64, bfloat16, uint16, half, uint32, uint64, qint8. Returns ------- A Tensor. Has the same type as features. """ outputs = P.ReLU() return outputs(x) class ReLU6(Cell): def __init__(self): super(ReLU6, self).__init__() self.relu6 = P.ReLU6() def construct(self, x): return self.relu6(x) def relu6(x): """ Computes Rectified Linear 6: min(max(features, 0), 6). Parameters ---------- x : tensor Must be one of the following types: float32, float64, int32, uint8, int16, int8, int64, bfloat16, uint16, half, uint32, uint64, qint8. Returns ------- A Tensor with the same type as features. """ outputs = P.ReLU6() return outputs(x) class LeakyReLU(Cell): def __init__(self, alpha=0.2): super(LeakyReLU, self).__init__() self.leakyrelu = ms.nn.LeakyReLU(alpha=alpha) def construct(self, x): return self.leakyrelu(x) def leaky_relu(x, alpha=0.2): """ Compute the Leaky ReLU activation function. Parameters ---------- x : tensor representing preactivation values. Must be one of the following types: float16, float32, float64, int32, int64. Returns ------- The activation value. """ leaky_relu = LeakyReLU(alpha=alpha) output = leaky_relu(x) return leaky_relu class Softplus(Cell): def __init__(self): super(Softplus, self).__init__() self.softplus = P.Softplus() def construct(self, x): return self.softplus(x) def softplus(x): """ Computes softplus: log(exp(features) + 1). Parameters ---------- x : tensor Must be one of the following types: half, bfloat16, float32, float64. Returns ------- A Tensor. Has the same type as features. """ obj = Softplus() return obj(x) class Tanh(Cell): def __init__(self): super(Tanh, self).__init__() self.tanh = P.Tanh() def construct(self, x): return self.tanh(x) def tanh(x): """ Computes hyperbolic tangent of x element-wise. Parameters ---------- x : tensor Must be one of the following types: bfloat16, half, float32, float64, complex64, complex128. Returns ------- A Tensor. Has the same type as x. """ _tanh = Tanh() return _tanh(x) class Sigmoid(Cell): def __init__(self): super(Sigmoid, self).__init__() self.sigmoid = P.Sigmoid() def construct(self, x): return self.sigmoid(x) def sigmoid(x): """ Computes sigmoid of x element-wise. Parameters ---------- x : tensor A Tensor with type float16, float32, float64, complex64, or complex128. Returns ------- A Tensor with the same type as x. """ outputs = P.Sigmoid() return outputs(x) class Softmax(Cell): def __init__(self): super(Softmax, self).__init__() self.softmax = P.Softmax() def construct(self, x): return self.softmax(x) def softmax(logits, axis=None): """ Computes softmax activations. Parameters ---------- logits : tensor Must be one of the following types: half, float32, float64. axis : int The dimension softmax would be performed on. The default is -1 which indicates the last dimension. Returns ------- A Tensor. Has the same type and shape as logits. """ outputs = P.Softmax(axis) return outputs(logits) class Dropout(Cell): def __init__(self, keep, seed=0): super(Dropout, self).__init__() self.dropout = P.Dropout(keep_prob=keep) self.is_gpu = context.get_context('device_target') in ["GPU"] self.get_shape = P.Shape() self.dropout_gen_mask = P.DropoutGenMask(Seed0=seed, Seed1=0) self.dropout_do_mask = P.DropoutDoMask() self.cast = P.Cast() self.keep_prob = keep # ms.Tensor(keep, dtype=ms.float32) # print(self.keep_prob, type(self.keep_prob)) def construct(self, inputs): if self.is_gpu: outputs, _ = self.dropout(inputs) return outputs if self.keep_prob == 1: return inputs shape = self.get_shape(inputs) dtype = P.DType()(inputs) if self._is_float_dtype(dtype): keep_prob = self.cast(self.keep_prob, dtype=dtype) else: keep_prob = self.cast(self.keep_prob, ms.float16) output = self.dropout_gen_mask(shape, keep_prob) return self.dropout_do_mask(inputs, output, keep_prob) def _is_float_dtype(dtype): if dtype in [ms.float32, ms.float16]: return True return False class BiasAdd(Cell): """ Adds bias to value. Parameters ---------- x : tensor A Tensor with type float, double, int64, int32, uint8, int16, int8, complex64, or complex128. bias : tensor Must be the same type as value unless value is a quantized type, in which case a different quantized type may be used. Returns ------- A Tensor with the same type as value. """ def __init__(self, data_format='channels_first'): super(BiasAdd, self).__init__() self.bias_add = P.BiasAdd() if data_format in ['channels_first', 'NCW', 'NCHW', 'NCDHW']: self.data_format = 'channels_first' elif data_format in ['channels_last', 'NWC', 'NHWC', 'NDHWC']: self.data_format = 'channels_last' else: raise ("Unsupported data format: " + str(data_format)) def construct(self, x, bias): if self.data_format == 'channels_last': x = nhwc_to_nchw(x) outputs = self.bias_add(x, bias) if self.data_format == 'channels_last': outputs = nchw_to_nhwc(outputs) return outputs def bias_add(x, bias): """ Adds bias to value. Parameters ---------- x : tensor A Tensor with type float, double, int64, int32, uint8, int16, int8, complex64, or complex128. bias : tensor Must be the same type as value unless value is a quantized type, in which case a different quantized type may be used. data_format : A string. 'N...C' and 'NC...' are supported. name : str A name for the operation (optional). Returns ------- A Tensor with the same type as value. """ raise NotImplementedError class Conv1D(Cell): def __init__(self, stride, padding, data_format='NWC', dilations=None, out_channel=None, k_size=None): super(Conv1D, self).__init__() self.data_format, self.padding = preprocess_1d_format(data_format, padding) self.stride = (1, stride) self.dilations = (1, dilations) self.k_size = (1, k_size) self.out_channel = out_channel self.conv2d = P.Conv2D( out_channel=self.out_channel, kernel_size=self.k_size, pad_mode=self.padding, stride=self.stride, dilation=self.dilations, mode=1, group=1 ) self.expand_dims = P.ExpandDims() self.squeeze = P.Squeeze(2) def construct(self, x, filters): if self.data_format == 'NWC': x = nhwc_to_nchw(x) x = self.expand_dims(x, 2) filters = self.expand_dims(filters, 2) output = self.conv2d(x, filters) output = self.squeeze(output) if self.data_format == 'NWC': output = nchw_to_nhwc(output) return output def conv1d(input, filters, stride, padding, data_format='NWC', dilations=None, name=None): """ Computes a 1-D convolution given 3-D input and filter tensors. Parameters ---------- input : tensor A 3D Tensor. Must be of type float16, float32, or float64 filters : tensor A 3D Tensor. Must have the same type as input. stride : int of list An int or list of ints that has length 1 or 3. The number of entries by which the filter is moved right at each step. padding : string 'SAME' or 'VALID' data_format : string An optional string from "NWC", "NCW". Defaults to "NWC", the data is stored in the order of [batch, in_width, in_channels]. The "NCW" format stores data as [batch, in_channels, in_width]. dilations : int or list An int or list of ints that has length 1 or 3 which defaults to 1. The dilation factor for each dimension of input. If set to k > 1, there will be k-1 skipped cells between each filter element on that dimension. Dilations in the batch and depth dimensions must be 1. name : string A name for the operation (optional). Returns ------- A Tensor. Has the same type as input. """ pass class Conv2D(Cell): def __init__(self, strides, padding, data_format='NHWC', dilations=None, out_channel=None, k_size=None): super(Conv2D, self).__init__() self.data_format, self.padding = preprocess_2d_format(data_format, padding) if self.data_format is 'NHWC': self.ms_stride = strides[1] self.ms_dilation = dilations[1] elif self.data_format is 'NCHW': self.ms_stride = strides[2] self.ms_dilation = dilations[2] self.conv2d = P.Conv2D( out_channel=out_channel, kernel_size=k_size, pad_mode=self.padding, stride=self.ms_stride, dilation=self.ms_dilation, mode=1, group=1, data_format=self.data_format ) def construct(self, inputs, filters): outputs = self.conv2d(inputs, filters) return outputs def conv2d(input, filters, strides, padding, data_format='NCHW', dilations=None): """ Computes a 2-D convolution given 4-D input and filters tensors. Parameters ---------- input : tensor Must be one of the following types: half, bfloat16, float32, float64. A 4-D tensor. The dimension order is interpreted according to the value of data_format, see below for details. filters : tensor Must have the same type as input. A 4-D tensor of shape [filter_height, filter_width, in_channels, out_channels] strides : int of list The stride of the sliding window for each dimension of input. If a single value is given it is replicated in the H and W dimension. By default the N and C dimensions are set to 1. The dimension order is determined by the value of data_format, see below for details. padding : string "SAME" or "VALID" data_format : string "NHWC", "NCHW". Defaults to "NCHW". dilations : list or ints list of ints that has length 1, 2 or 4, defaults to 1. The dilation factor for each dimension ofinput. Returns ------- A Tensor. Has the same type as input. """ raise NotImplementedError class Conv3D(Cell): def __init__(self, strides, padding, data_format='NDHWC', dilations=None, out_channel=None, k_size=None): super(Conv3D, self).__init__() self.data_format, self.padding = preprocess_3d_format(data_format, padding) if self.data_format is 'NDHWC': self.ms_stride = strides[1] self.ms_dilation = dilations[1] raise NotImplementedError("The optional value for data format. Currently only support “NCDHW”.") elif self.data_format is 'NCDHW': self.ms_stride = strides[2] self.ms_dilation = dilations[2] self.conv3d = P.Conv3D( out_channel=out_channel, kernel_size=k_size, pad_mode=self.padding, stride=self.ms_stride, dilation=self.ms_dilation, data_format=data_format ) def construct(self, input, filters): outputs = self.conv3d(input, filters) return outputs def conv3d(input, filters, strides, padding, data_format='NDHWC', dilations=None, name=None): """ Computes a 3-D convolution given 5-D input and filters tensors. Parameters ---------- input : tensor Must be one of the following types: half, bfloat16, float32, float64. Shape [batch, in_depth, in_height, in_width, in_channels]. filters : tensor Must have the same type as input. Shape [filter_depth, filter_height, filter_width, in_channels, out_channels]. in_channels must match between input and filters. strides : list of ints A list of ints that has length >= 5. 1-D tensor of length 5. The stride of the sliding window for each dimension of input. Must have strides[0] = strides[4] = 1. padding : string A string from: "SAME", "VALID". The type of padding algorithm to use. data_format : string An optional string from: "NDHWC", "NCDHW". Defaults to "NDHWC". The data format of the input and output data. With the default format "NDHWC", the data is stored in the order of: [batch, in_depth, in_height, in_width, in_channels]. Alternatively, the format could be "NCDHW", the data storage order is: [batch, in_channels, in_depth, in_height, in_width]. dilations : list of ints Defaults to [1, 1, 1, 1, 1]. 1-D tensor of length 5. The dilation factor for each dimension of input. If set to k > 1, there will be k-1 skipped cells between each filter element on that dimension. The dimension order is determined by the value of data_format, see above for details. Dilations in the batch and depth dimensions must be 1. name : string A name for the operation (optional). Returns ------- A Tensor. Has the same type as input. """ raise NotImplementedError def lrn(inputs, depth_radius, bias, alpha, beta): """ Local Response Normalization. Parameters ---------- inputs : tensor Must be one of the following types: half, bfloat16, float32. 4-D. depth_radius : int Defaults to 5. 0-D. Half-width of the 1-D normalization window. bias : float Defaults to 1. An offset (usually positive to avoid dividing by 0). alpha : float Defaults to 1. A scale factor, usually positive. beta : float Defaults to 0.5. An exponent. Returns ------- A Tensor. Has the same type as input. """ pass def moments(x, axes, shift=None, keepdims=False): """ Calculates the mean and variance of x. Parameters ---------- x : tensor A Tensor axes : ints Axes along which to compute mean and variance. shift : int Not used in the current implementation. keepdims : bool produce moments with the same dimensionality as the input. Returns ------- Two Tensor objects: mean and variance. """ pass class MaxPool1d(Cell): def __init__(self, ksize, strides, padding, data_format=None): super(MaxPool1d, self).__init__() self.data_format, padding = preprocess_1d_format(data_format=data_format, padding=padding) self.expand = P.ExpandDims() _strides = (1, strides[0]) _ksize = (1, ksize[0]) if self.data_format == 'NWC': self.squeeze = P.Squeeze(1) _data_format = 'NHWC' if self.data_format == 'NCW': self.squeeze = P.Squeeze(2) _data_format = 'NCHW' self.max_pool = P.MaxPool(kernel_size=_ksize, strides=_strides, pad_mode=padding, data_format=_data_format) def construct(self, inputs): if self.data_format == 'NWC': x = self.expand(inputs, 1) if self.data_format == 'NCW': x = self.expand(inputs, 2) output = self.max_pool(x) output = self.squeeze(output) return output class MaxPool(Cell): def __init__(self, ksize, strides, padding, data_format=None): super(MaxPool, self).__init__() data_format, padding = preprocess_2d_format(data_format=data_format, padding=padding) if data_format == 'NHWC': _strides = (strides[1], strides[2]) if data_format == 'NCHW': _strides = (strides[2], strides[3]) self.maxpool = P.MaxPool(kernel_size=ksize, strides=_strides, pad_mode=padding, data_format=data_format) def construct(self, inputs): outputs = self.maxpool(inputs) return outputs def max_pool(input, ksize, strides, padding, data_format=None): """ Performs the max pooling on the input. Parameters ---------- input : tensor Tensor of rank N+2, of shape [batch_size] + input_spatial_shape + [num_channels] if data_format does not start with "NC" (default), or [batch_size, num_channels] + input_spatial_shape if data_format starts with "NC". Pooling happens over the spatial dimensions only. ksize : int or list of ints An int or list of ints that has length 1, N or N+2. The size of the window for each dimension of the input tensor. strides : list or list of ints An int or list of ints that has length 1, N or N+2. The stride of the sliding window for each dimension of the input tensor. padding : string 'VALID' or 'SAME'. The padding algorithm. See the "returns" section of tf.ops.convolution for details. Returns ------- A Tensor of format specified by data_format. The max pooled output tensor. """ data_format, padding = preprocess_2d_format(data_format=data_format, padding=padding) if data_format == 'NHWC': _strides = (strides[1], strides[2]) if data_format == 'NCHW': _strides = (strides[2], strides[3]) outputs = P.MaxPool(kernel_size=ksize, strides=_strides, pad_mode=padding, data_format=data_format)(input) return outputs class AvgPool1d(Cell): def __init__(self, ksize, strides, padding, data_format=None): super(AvgPool1d, self).__init__() self.data_format, self.padding = preprocess_1d_format(data_format=data_format, padding=padding) self.kernel_size = (1, ksize[0]) self.stride = (1, strides[0]) if self.data_format == 'NWC': _data_format = 'NHWC' self.squeeze = P.Squeeze(1) if self.data_format == 'NCW': _data_format = 'NCHW' self.squeeze = P.Squeeze(2) self.avg_pool = P.AvgPool( kernel_size=self.kernel_size, strides=self.stride, pad_mode=self.padding, data_format=_data_format ) self.reduce_mean = P.ReduceMean(keep_dims=True) self.slice = P.Slice() self.expand = P.ExpandDims() self.shape = P.Shape() def construct(self, inputs): x = inputs batch, channel, width = self.shape(inputs) if width == self.kernel_size[1]: x = self.reduce_mean(x, 2) elif width - self.kernel_size[1] < self.stride[1]: x = self.slice(x, (0, 0, 0), (batch, channel, self.kernel_size[1])) x = self.reduce_mean(x, 2) else: if self.data_format == 'NCW': x = self.expand(x, 2) if self.data_format == 'NWC': x = self.expand(x, 1) x = self.avg_pool(x) x = self.squeeze(x) return x class AvgPool(Cell): def __init__(self, ksize, strides, padding, data_format=None): super(AvgPool, self).__init__() self.data_format, self.padding = preprocess_2d_format(data_format=data_format, padding=padding) ms_ksize = ksize[1] ms_strides = strides[1] self.avgpool = P.AvgPool(ksize=ms_ksize, strides=ms_strides, padding=padding, data_format=self.data_format) def construct(self, inputs): outputs = self.avgpool(inputs) return outputs def avg_pool(input, ksize, strides, padding): """ Performs the avg pooling on the input. Parameters ---------- input : tensor Tensor of rank N+2, of shape [batch_size] + input_spatial_shape + [num_channels] if data_format does not start with "NC" (default), or [batch_size, num_channels] + input_spatial_shape if data_format starts with "NC". Pooling happens over the spatial dimensions only. ksize : int or list of ints An int or list of ints that has length 1, N or N+2. The size of the window for each dimension of the input tensor. strides : int or list of ints An int or list of ints that has length 1, N or N+2. The stride of the sliding window for each dimension of the input tensor. padding : string 'VALID' or 'SAME'. The padding algorithm. See the "returns" section of tf.ops.convolution for details. Returns ------- A Tensor of format specified by data_format. The average pooled output tensor. """ padding = padding_format(padding) ms_ksize = ksize[0] ms_strides = strides[1] outputs = P.AvgPool(ksize=ms_ksize, strides=ms_strides, padding=padding) return outputs(input) class MaxPool3d(Cell): def __init__(self, ksize, strides, padding, data_format=None): super(MaxPool3d, self).__init__() self.data_format, self.padding = preprocess_3d_format(data_format, padding) if data_format == 'NDHWC': _strides = (strides[1], strides[2], strides[3]) if data_format == 'NCDHW': _strides = (strides[2], strides[3], strides[4]) self.max_pool3d = P.MaxPool3D( kernel_size=ksize, strides=_strides, padding=padding, data_format=self.data_format ) def __call__(self, inputs): outputs = self.max_pool3d(inputs) return outputs def max_pool3d(input, ksize, strides, padding, data_format=None, name=None): """ Performs the max pooling on the input. Parameters ---------- input : tensor A 5-D Tensor of the format specified by data_format. ksize : int or list of ints An int or list of ints that has length 1, 3 or 5. The size of the window for each dimension of the input tensor. strides : int or list of ints An int or list of ints that has length 1, 3 or 5. The stride of the sliding window for each dimension of the input tensor. padding : string 'VALID' or 'SAME'. The padding algorithm. See the "returns" section of tf.ops.convolution for details. data_format : string "NDHWC", "NCDHW". Defaults to "NDHWC". The data format of the input and output data. With the default format "NDHWC", the data is stored in the order of: [batch, in_depth, in_height, in_width, in_channels]. Alternatively, the format could be "NCDHW", the data storage order is: [batch, in_channels, in_depth, in_height, in_width]. name : string A name for the operation (optional). Returns ------- A Tensor of format specified by data_format. The max pooled output tensor. """ pass class AvgPool3d(Cell): def __init__(self, ksize, strides, padding, data_format=None): super(AvgPool3d, self).__init__() self.data_format, self.padding = preprocess_3d_format(data_format, padding) if data_format == 'NDHWC': _strides = (strides[1], strides[2], strides[3]) if data_format == 'NCDHW': _strides = (strides[2], strides[3], strides[4]) raise NotImplementedError def __call__(self, inputs): pass def avg_pool3d(input, ksize, strides, padding, data_format=None, name=None): """ Performs the average pooling on the input. Parameters ---------- input : tensor A 5-D Tensor of shape [batch, height, width, channels] and type float32, float64, qint8, quint8, or qint32. ksize : int or list of ints An int or list of ints that has length 1, 3 or 5. The size of the window for each dimension of the input tensor. strides : int or list of ints An int or list of ints that has length 1, 3 or 5. The stride of the sliding window for each dimension of the input tensor. padding : string 'VALID' or 'SAME'. The padding algorithm. See the "returns" section of tf.ops.convolution for details. data_format : string 'NDHWC' and 'NCDHW' are supported. name : string Optional name for the operation. Returns ------- A Tensor with the same type as value. The average pooled output tensor. """ pass def pool(input, window_shape, pooling_type, strides=None, padding='VALID', data_format=None, dilations=None, name=None): """ Performs an N-D pooling operation. Parameters ---------- input : tensor Tensor of rank N+2, of shape [batch_size] + input_spatial_shape + [num_channels] if data_format does not start with "NC" (default), or [batch_size, num_channels] + input_spatial_shape if data_format starts with "NC". Pooling happens over the spatial dimensions only. window_shape : int Sequence of N ints >= 1. pooling_type : string Specifies pooling operation, must be "AVG" or "MAX". strides : ints Sequence of N ints >= 1. Defaults to [1]*N. If any value of strides is > 1, then all values of dilation_rate must be 1. padding : string The padding algorithm, must be "SAME" or "VALID". Defaults to "SAME". See the "returns" section of tf.ops.convolution for details. data_format : string Specifies whether the channel dimension of the input and output is the last dimension (default, or if data_format does not start with "NC"), or the second dimension (if data_format starts with "NC"). For N=1, the valid values are "NWC" (default) and "NCW". For N=2, the valid values are "NHWC" (default) and "NCHW". For N=3, the valid values are "NDHWC" (default) and "NCDHW". dilations : list of ints Dilation rate. List of N ints >= 1. Defaults to [1]*N. If any value of dilation_rate is > 1, then all values of strides must be 1. name : string Optional. Name of the op. Returns ------- Tensor of rank N+2, of shape [batch_size] + output_spatial_shape + [num_channels] """ pass class DepthwiseConv2d(Cell): def __init__(self, strides, padding, data_format=None, dilations=None, ksize=None, channel_multiplier=1): super(DepthwiseConv2d, self).__init__() self.data_format, self.padding = preprocess_2d_format(data_format, padding) self.ms_stride = strides[1] self.ms_dilation = dilations[1] self.depthwise_conv2d = P.DepthwiseConv2dNative( channel_multiplier=channel_multiplier, kernel_size=ksize, stride=self.ms_stride, dilation=self.ms_dilation ) def construct(self, input, filter): if self.data_format == 'NHWC': input = nhwc_to_nchw(input) outputs = self.depthwise_conv2d(input, filter) if self.data_format == 'NHWC': outputs = nchw_to_nhwc(outputs) return outputs def depthwise_conv2d(input, filter, strides, padding, data_format=None, dilations=None, name=None): """ Depthwise 2-D convolution. Parameters ---------- input : tensor 4-D with shape according to data_format. filter : tensor 4-D with shape [filter_height, filter_width, in_channels, channel_multiplier]. strides : list 1-D of size 4. The stride of the sliding window for each dimension of input. padding : string 'VALID' or 'SAME'. The padding algorithm. See the "returns" section of tf.ops.convolution for details. data_format : string The data format for input. Either "NHWC" (default) or "NCHW". dilations : list 1-D of size 2. The dilation rate in which we sample input values across the height and width dimensions in atrous convolution. If it is greater than 1, then all values of strides must be 1. name : string A name for this operation (optional). Returns ------- A 4-D Tensor with shape according to data_format. E.g., for "NHWC" format, shape is [batch, out_height, out_width, in_channels * channel_multiplier]. """ pass class Conv1d_transpose(Cell): def __init__(self, stride, padding, data_format, dilations=None, out_channel=None, k_size=None, in_channels=None): super(Conv1d_transpose, self).__init__() self.data_format, self.padding = preprocess_1d_format(data_format, padding) self.in_channels = in_channels self.out_channel = out_channel self.stride = (1, stride) self.dilations = (1, dilations) self.k_size = (1, k_size) if self.data_format == 'NWC': self.data_format = 'NHWC' self.h_axis = 1 else: self.data_format = 'NCHW' self.h_axis = 2 self.conv2d_transpose = P.Conv2DBackpropInput( out_channel=self.in_channels, kernel_size=self.k_size, pad_mode=self.padding, stride=self.stride, dilation=self.dilations, mode=1, group=1, data_format=self.data_format ) self.shape = P.Shape() self.expand_dims = P.ExpandDims() self.squeeze = P.Squeeze(self.h_axis) def _deconv_output_length(self, input_length, filter_size, stride_size, dilation_size): length = 0 filter_size = filter_size + (filter_size - 1) * (dilation_size - 1) if self.padding == 'same': length = input_length * stride_size elif self.padding == 'valid': length = input_length * stride_size + max(filter_size - stride_size, 0) return length def construct(self, x, filters): x = self.expand_dims(x, self.h_axis) filters = self.expand_dims(filters, self.h_axis) if self.data_format == 'NCHW': n, _, h, w = self.shape(x) else: n, h, w, _ = self.shape(x) h_out = self._deconv_output_length(h, self.k_size[0], self.stride[0], self.dilations[0]) w_out = self._deconv_output_length(w, self.k_size[1], self.stride[1], self.dilations[1]) if self.data_format == 'NCHW': output_size = (n, self.out_channel, h_out, w_out) else: output_size = (n, h_out, w_out, self.out_channel) output = self.conv2d_transpose(x, filters, output_size) output = self.squeeze(output) return output def conv1d_transpose( input, filters, output_shape, strides, padding='SAME', data_format='NWC', dilations=None, name=None ): """ The transpose of conv1d. Parameters ---------- input : tensor A 3-D Tensor of type float and shape [batch, in_width, in_channels] for NWC data format or [batch, in_channels, in_width] for NCW data format. filters : tensor A 3-D Tensor with the same type as value and shape [filter_width, output_channels, in_channels]. filter's in_channels dimension must match that of value. output_shape : tensor A 1-D Tensor, containing three elements, representing the output shape of the deconvolution op. strides : list An int or list of ints that has length 1 or 3. The number of entries by which the filter is moved right at each step. padding : string 'VALID' or 'SAME'. The padding algorithm. See the "returns" section of tf.ops.convolution for details. data_format : string 'NWC' and 'NCW' are supported. dilations : list An int or list of ints that has length 1 or 3 which defaults to 1. The dilation factor for each dimension of input. If set to k > 1, there will be k-1 skipped cells between each filter element on that dimension. Dilations in the batch and depth dimensions must be 1. name : string Optional name for the returned tensor. Returns ------- A Tensor with the same type as value. """ pass class Conv2d_transpose(Cell): def __init__(self, strides, padding, data_format, dilations=None, out_channel=None, k_size=None, in_channels=None): super(Conv2d_transpose, self).__init__() self.data_format, self.padding = preprocess_2d_format(data_format, padding) self.in_channels = in_channels self.out_channel = out_channel self.k_size = k_size self.strides = strides self.dilations = dilations self.conv2d_transpose = P.Conv2DBackpropInput( out_channel=self.in_channels, kernel_size=self.k_size, pad_mode=self.padding, stride=self.strides, dilation=self.dilations, mode=1, group=1, data_format=self.data_format ) self.shape = P.Shape() def _deconv_output_length(self, input_length, filter_size, stride_size, dilation_size): length = 0 filter_size = filter_size + (filter_size - 1) * (dilation_size - 1) if self.padding == 'same': length = input_length * stride_size elif self.padding == 'valid': length = input_length * stride_size + max(filter_size - stride_size, 0) return length def construct(self, x, filters): if self.data_format == 'NHWC': h_axis, w_axis = 1, 2 n, h, w, _ = self.shape(x) else: h_axis, w_axis = 2, 3 n, _, h, w = self.shape(x) if isinstance(self.strides, int): strides_h = self.strides strides_w = self.strides else: strides_list = list(self.strides) if len(strides_list) == 2: strides_h = strides_list[0] strides_w = strides_list[1] elif len(strides_list) == 4: strides_h = strides_list[h_axis] strides_w = strides_list[w_axis] if self.dilations is not None: if isinstance(self.dilations, int): dilations_h = self.dilations dilations_w = self.dilations else: dilations_list = list(self.dilations) if len(dilations_list) == 2: dilations_h = dilations_list[0] dilations_w = dilations_list[1] elif len(dilations_list) == 4: dilations_h = dilations_list[h_axis] dilations_w = dilations_list[w_axis] h_out = self._deconv_output_length(h, self.k_size[0], strides_h, dilations_h) w_out = self._deconv_output_length(w, self.k_size[1], strides_w, dilations_w) if self.data_format == 'NCHW': output_size = (n, self.out_channel, h_out, w_out) else: output_size = (n, h_out, w_out, self.out_channel) output = self.conv2d_transpose(x, filters, output_size) return output def conv2d_transpose( input, filters, output_shape, strides, padding='SAME', data_format='NHWC', dilations=None, name=None ): """ The transpose of conv2d. Parameters ---------- input : tensor A 4-D Tensor of type float and shape [batch, height, width, in_channels] for NHWC data format or [batch, in_channels, height, width] for NCHW data format. filters : tensor A 4-D Tensor with the same type as input and shape [height, width, output_channels, in_channels]. filter's in_channels dimension must match that of input. output_shape : tensor A 1-D Tensor representing the output shape of the deconvolution op. strides : list An int or list of ints that has length 1, 2 or 4. The stride of the sliding window for each dimension of input. If a single value is given it is replicated in the H and W dimension. By default the N and C dimensions are set to 0. The dimension order is determined by the value of data_format, see below for details. padding : string 'VALID' or 'SAME'. The padding algorithm. See the "returns" section of tf.ops.convolution for details. data_format : string 'NHWC' and 'NCHW' are supported. dilations : list An int or list of ints that has length 1, 2 or 4, defaults to 1. name : string Optional name for the returned tensor. Returns ------- A Tensor with the same type as input. """ pass class Conv3d_transpose(Cell): def __init__( self, strides, padding, data_format='NDHWC', dilations=None, name=None, out_channel=None, k_size=None, in_channels=None ): super(Conv3d_transpose, self).__init__() self.data_format, self.padding = preprocess_3d_format(data_format, padding) self.conv3d_transpose = P.Conv3DTranspose( in_channel=in_channels, out_channel=out_channel, kernel_size=k_size, mode=1, pad_mode=self.padding, stride=strides, dilation=dilations, data_format=self.data_format ) def construct(self, input, filters): output = self.conv3d_transpose(input, filters) return output def conv3d_transpose( input, filters, output_shape, strides, padding='SAME', data_format='NDHWC', dilations=None, name=None ): """ The transpose of conv3d. Parameters ---------- input : tensor A 5-D Tensor of type float and shape [batch, height, width, in_channels] for NHWC data format or [batch, in_channels, height, width] for NCHW data format. filters : tensor A 5-D Tensor with the same type as value and shape [height, width, output_channels, in_channels]. filter's in_channels dimension must match that of value. output_shape : tensor A 1-D Tensor representing the output shape of the deconvolution op. strides : list An int or list of ints that has length 1, 3 or 5. padding : string 'VALID' or 'SAME'. The padding algorithm. See the "returns" section of tf.ops.convolution for details. data_format : string 'NDHWC' and 'NCDHW' are supported. dilations : list of ints An int or list of ints that has length 1, 3 or 5, defaults to 1. name : string Optional name for the returned tensor. Returns ------- A Tensor with the same type as value. """ pass class BatchNorm(Cell): """Batch Normalization base class.""" @cell_attr_register def __init__( self, num_features, epsilon=1e-5, decay=0.9, gamma=None, beta=None, moving_mean=None, moving_var=None, is_train=None, device_num_each_group=1, process_groups=0, data_format='NCHW' ): super(BatchNorm, self).__init__() if data_format in ["channels_last", "NHWC", "nhwc"]: data_format = "NHWC" elif data_format in ["channels_first", "NCHW", "nchw"]: data_format = "NCHW" validator.check_value_type('num_features', num_features, [int], self.cls_name) if num_features < 1: raise ValueError("num_features must be at least 1") if decay < 0 or decay > 1: raise ValueError("momentum should be a number in range [0, 1], but got {}".format(decay)) self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name) if context.get_context("device_target") != "GPU" and self.format == "NHWC": raise ValueError("NHWC format only support in GPU target.") self.use_batch_statistics = is_train self.num_features = num_features self.eps = epsilon self.moving_mean = moving_mean self.moving_variance = moving_var self.gamma = gamma self.beta = beta self.group_device_num = validator.check_positive_int(device_num_each_group) self.process_groups = process_groups self.is_global = False self.parallel_mode = context.get_auto_parallel_context("parallel_mode") global SYNC_BN_GROUP_NAME # for GlobalBatchNorm if self.group_device_num != 1: self.rank_id = get_rank() self.rank_size = get_group_size() self.device_list = [i for i in range(0, self.rank_size)] self.rank_list = self.list_group(self.device_list, self.group_device_num) self.rank_list_idx = len(self.rank_list) for i in range(self.rank_list_idx): if self.rank_id in self.rank_list[i]: self.is_global = True if SYNC_BN_GROUP_NAME == "": SYNC_BN_GROUP_NAME = "sync_bn_group" + str(i) management.create_group(SYNC_BN_GROUP_NAME, self.rank_list[i]) # for SyncBatchNorm if self.process_groups != 0: self.rank_id = get_rank() self.rank_size = get_group_size() if self.process_groups is not None: validator.check_isinstance("process_groups", self.process_groups, list) self._check_rank_ids(self.process_groups, self.rank_size) for i in range(len(self.process_groups)): validator.check_isinstance("process_groups[" + str(i) + "]", self.process_groups[i], list) self.group_device_num = len(self.process_groups[i]) if self.rank_id in self.process_groups[i] and self.group_device_num > 1: self.is_global = True if SYNC_BN_GROUP_NAME == "": SYNC_BN_GROUP_NAME = "sync_bn_group" + str(i) management.create_group(SYNC_BN_GROUP_NAME, self.process_groups[i]) elif self.rank_size > 1: self.is_global = True self.group_device_num = self.rank_size self.device_list = [i for i in range(0, self.rank_size)] if SYNC_BN_GROUP_NAME == "": SYNC_BN_GROUP_NAME = "sync_bn_group0" management.create_group(SYNC_BN_GROUP_NAME, self.device_list) self.shape = P.Shape() self.reduce_mean = P.ReduceMean(keep_dims=True) self.square = P.Square() self.sqrt = P.Sqrt() self.cast = P.Cast() self.dtype = P.DType() self.reshape = P.Reshape() self._target = context.get_context("device_target") self.is_graph_mode = context.get_context("mode") == context.GRAPH_MODE self.momentum = 1.0 - decay if context.get_context("enable_ge"): self.is_ge_backend = True else: self.is_ge_backend = False self.bn_train = P.BatchNorm(is_training=True, epsilon=self.eps, momentum=self.momentum, data_format=self.format) if self.is_global: self.bn_train = inner.SyncBatchNorm( epsilon=self.eps, momentum=self.momentum, group=SYNC_BN_GROUP_NAME, device_num=self.group_device_num ) self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps, data_format=self.format) data_parallel_strategy = ((1, ), (1, )) data_parallel_strategy_one = ((1, ), ()) self.sub_mean = P.Sub().shard(data_parallel_strategy) self.sub_var = P.Sub().shard(data_parallel_strategy) self.mul_mean = P.Mul().shard(data_parallel_strategy_one) self.mul_var = P.Mul().shard(data_parallel_strategy_one) self.assign_sub_mean = P.AssignSub().shard(data_parallel_strategy) self.assign_sub_var = P.AssignSub().shard(data_parallel_strategy) def list_group(self, world_rank, group_size): if group_size > get_group_size(): raise ValueError( "group size can not be greater than local rank size, group size is {}, " "local_rank_size is {}".format(group_size, get_group_size()) ) if len(world_rank) % group_size != 0: raise ValueError("please make your group size correct.") world_rank_list = zip(*(iter(world_rank), ) * group_size) group_list = [list(i) for i in world_rank_list] return group_list def _check_rank_ids(self, process_groups, rank_size): seen = set() for rid in itertools.chain(*process_groups): validator.check_int_range(rid, 0, rank_size, Rel.INC_LEFT, "rank id in process_groups") if rid in seen: raise ValueError("rank id in process_groups should not be duplicated.") seen.add(rid) def construct(self, inputs): x_shape = F.shape(inputs) if len(x_shape) == 5: inputs = self.reshape(inputs, (x_shape[0], x_shape[1], x_shape[2] * x_shape[3], x_shape[4])) flag = self.use_batch_statistics if flag: output = self.bn_train(inputs, self.gamma, self.beta, self.moving_mean, self.moving_variance)[0] if len(x_shape) == 5: output = self.reshape(output, x_shape) return output output = self.bn_infer(inputs, self.gamma, self.beta, self.moving_mean, self.moving_variance)[0] if len(x_shape) == 5: output = self.reshape(output, x_shape) return output def extend_repr(self): return 'num_features={}, eps={}, momentum={}, gamma={}, beta={}, moving_mean={}, moving_variance={}'.format( self.num_features, self.eps, self.momentum, self.gamma, self.beta, self.moving_mean, self.moving_variance ) class GroupConv2D(Cell): def __init__(self, strides, padding, data_format, dilations, out_channel, k_size, groups): super(GroupConv2D, self).__init__() self.data_format, self.padding = preprocess_2d_format(data_format, padding) if self.data_format is 'NHWC': self.ms_stride = strides[1] self.ms_dilation = dilations[1] elif self.data_format is 'NCHW': self.ms_stride = strides[2] self.ms_dilation = dilations[2] self.conv2d = P.Conv2D( out_channel=out_channel, kernel_size=k_size, pad_mode=self.padding, stride=self.ms_stride, dilation=self.ms_dilation, mode=1, group=groups, data_format=self.data_format ) def construct(self, inputs, filters): outputs = self.conv2d(inputs, filters) return outputs class SeparableConv1D(Cell): def __init__(self, stride, padding, data_format, dilations, out_channel, k_size, in_channel, depth_multiplier): super(SeparableConv1D, self).__init__() self.data_format, self.padding = preprocess_1d_format(data_format, padding) self.stride = (1, stride) self.dilations = (1, dilations) self.k_size = (1, k_size) self.out_channel = out_channel self.in_channel = in_channel self.depth_multiplier = depth_multiplier self.depthwise_conv = P.Conv2D( out_channel=self.in_channel * self.depth_multiplier, kernel_size=self.k_size, pad_mode=self.padding, stride=self.stride, dilation=self.dilations, mode=1, group=self.in_channel ) self.pointwise_conv = P.Conv2D( out_channel=self.out_channel, kernel_size=(1, 1), pad_mode=self.padding, stride=(1, 1), dilation=(1, 1), mode=1, group=1 ) self.expand_dims = P.ExpandDims() self.squeeze = P.Squeeze(2) def construct(self, x, depthwise_filters, pointwise_filters): if self.data_format == 'NWC': x = nhwc_to_nchw(x) x = self.expand_dims(x, 2) depthwise_filters = self.expand_dims(depthwise_filters, 2) pointwise_filters = self.expand_dims(pointwise_filters, 2) outputs = self.depthwise_conv(x, depthwise_filters) outputs = self.pointwise_conv(outputs, pointwise_filters) outputs = self.squeeze(outputs) if self.data_format == 'NWC': outputs = nchw_to_nhwc(outputs) return outputs class SeparableConv2D(Cell): def __init__(self, strides, padding, data_format, dilations, out_channel, k_size, in_channel, depth_multiplier): super(SeparableConv2D, self).__init__() self.data_format, self.padding = preprocess_2d_format(data_format, padding) self.k_size = k_size self.out_channel = out_channel self.in_channel = in_channel self.depth_multiplier = depth_multiplier if self.data_format is 'NHWC': self.ms_stride = strides[1] self.ms_dilation = dilations[1] elif self.data_format is 'NCHW': self.ms_stride = strides[2] self.ms_dilation = dilations[2] self.depthwise_conv = P.Conv2D( out_channel=self.in_channel * self.depth_multiplier, kernel_size=self.k_size, pad_mode=self.padding, stride=self.ms_stride, dilation=self.ms_dilation, mode=1, group=self.in_channel, data_format=self.data_format ) self.pointwise_conv = P.Conv2D( out_channel=self.out_channel, kernel_size=(1, 1), pad_mode=self.padding, stride=(1, 1), dilation=(1, 1), mode=1, group=1, data_format=self.data_format ) def construct(self, x, depthwise_filters, pointwise_filters): outputs = self.depthwise_conv(x, depthwise_filters) outputs = self.pointwise_conv(outputs, pointwise_filters) return outputs class AdaptiveMeanPool1D(Cell): def __init__(self, output_size, data_format): super(AdaptiveMeanPool1D, self).__init__() self.data_format, _ = preprocess_1d_format(data_format, None) self.output_size = output_size if self.data_format == 'NWC': self.data_format = 'NHWC' self.h_axis = 1 else: self.data_format = 'NCHW' self.h_axis = 2 self.expand_dims = P.ExpandDims() self.squeeze = P.Squeeze(self.h_axis) self.shape = P.Shape() def construct(self, inputs): if self.data_format == 'NHWC': n, w, c = self.shape(inputs) else: n, c, w = self.shape(inputs) inputs = self.expand_dims(inputs, self.h_axis) stride = (1, w // self.output_size) kernel = (1, w - (self.output_size - 1) * stride[1]) outputs = P.AvgPool(kernel_size=kernel, strides=stride, pad_mode='VALID', data_format=self.data_format)(inputs) outputs = self.squeeze(outputs) return outputs class AdaptiveMeanPool2D(Cell): def __init__(self, output_size, data_format): super(AdaptiveMeanPool2D, self).__init__() self.data_format, _ = preprocess_2d_format(data_format, None) self.output_size = output_size if self.data_format == 'NHWC': self.h_axis = 1 else: self.h_axis = 2 self.shape = P.Shape() def construct(self, inputs): if self.data_format == 'NHWC': n, h, w, c = self.shape(inputs) else: n, c, h, w = self.shape(inputs) out_h, out_w = self.output_size stride_h = h // out_h kernel_h = h - (out_h - 1) * stride_h stride_w = w // out_w kernel_w = w - (out_w - 1) * stride_w outputs = P.AvgPool( kernel_size=(kernel_h, kernel_w), strides=(stride_h, stride_w), pad_mode='VALID', data_format=self.data_format )(inputs) return outputs class AdaptiveMeanPool3D(Cell): def __init__(self, output_size, data_format): pass def __call__(self, inputs): raise NotImplementedError class AdaptiveMaxPool1D(Cell): def __init__(self, output_size, data_format): super(AdaptiveMaxPool1D, self).__init__() self.data_format, _ = preprocess_1d_format(data_format, None) self.output_size = output_size if self.data_format == 'NWC': self.data_format = 'NHWC' self.h_axis = 1 else: self.data_format = 'NCHW' self.h_axis = 2 self.expand_dims = P.ExpandDims() self.squeeze = P.Squeeze(self.h_axis) self.shape = P.Shape() def construct(self, inputs): if self.data_format == 'NHWC': n, w, c = self.shape(inputs) else: n, c, w = self.shape(inputs) inputs = self.expand_dims(inputs, self.h_axis) stride = (1, w // self.output_size) kernel = (1, w - (self.output_size - 1) * stride[1]) outputs = P.MaxPool(kernel_size=kernel, strides=stride, pad_mode='VALID', data_format=self.data_format)(inputs) outputs = self.squeeze(outputs) return outputs class AdaptiveMaxPool2D(Cell): def __init__(self, output_size, data_format): super(AdaptiveMaxPool2D, self).__init__() self.data_format, _ = preprocess_2d_format(data_format, None) self.output_size = output_size if self.data_format == 'NHWC': self.h_axis = 1 else: self.h_axis = 2 self.shape = P.Shape() def construct(self, inputs): if self.data_format == 'NHWC': n, h, w, c = self.shape(inputs) else: n, c, h, w = self.shape(inputs) out_h, out_w = self.output_size stride_h = h // out_h kernel_h = h - (out_h - 1) * stride_h stride_w = w // out_w kernel_w = w - (out_w - 1) * stride_w outputs = P.MaxPool( kernel_size=(kernel_h, kernel_w), strides=(stride_h, stride_w), pad_mode='VALID', data_format=self.data_format )(inputs) return outputs class AdaptiveMaxPool3D(Cell): def __init__(self, output_size, data_format): pass def __call__(self, inputs): raise NotImplementedError class BinaryConv2D(Cell): def __init__(self, strides, padding, data_format, dilations, out_channel, k_size, in_channel): super(BinaryConv2D, self).__init__() self.data_format, self.padding = preprocess_2d_format(data_format, padding) if self.data_format is 'NHWC': self.ms_stride = strides[1] self.ms_dilation = dilations[1] elif self.data_format is 'NCHW': self.ms_stride = strides[2] self.ms_dilation = dilations[2] self.conv2d = P.Conv2D( out_channel=out_channel, kernel_size=k_size, pad_mode=self.padding, stride=self.ms_stride, dilation=self.ms_dilation, mode=1, group=1, data_format=self.data_format ) @bprop_getters.register(P.Sign) def get_bprop_Sign(self): def bprop(x, out, dout): grad = P.clip_by_value(dout, -1, 1) return (grad, ) return bprop self.sign = P.Sign() def construct(self, inputs, filters): filters = self.sign(filters) outputs = self.conv2d(inputs, filters) return outputs class DorefaConv2D(Cell): def __init__(self, bitW, bitA, strides, padding, data_format, dilations, out_channel, k_size, in_channel): super(DorefaConv2D, self).__init__() self.data_format, self.padding = preprocess_2d_format(data_format, padding) self.bitW = ms.Tensor(bitW) self.bitA = ms.Tensor(bitA) if self.data_format is 'NHWC': self.ms_stride = strides[1] self.ms_dilation = dilations[1] # self.transpose = P.Transpose() elif self.data_format is 'NCHW': self.ms_stride = strides[2] self.ms_dilation = dilations[2] self.conv2d = P.Conv2D( out_channel=out_channel, kernel_size=k_size, pad_mode=self.padding, stride=self.ms_stride, dilation=self.ms_dilation, mode=1, group=1 ) @bprop_getters.register(P.Round) def get_bprop_Round(self): def bprop(x, out, dout): return (dout, ) return bprop @bprop_getters.register(P.Sign) def get_bprop_Sign(self): def bprop(x, out, dout): return (dout, ) return bprop self.mimimum = P.Minimum() self.abs = P.Abs() self.round = P.Round() self.reducemean = P.ReduceMean() self.sign = P.Sign() self.pow = P.Pow() self.sub = P.Sub() self.oneslike = P.OnesLike() def cabs(self, inputs): a = P.stop_gradient(self.oneslike(inputs)) return self.mimimum(self.abs(inputs), a) def _quantize_dorefa(self, x, k): n = self.sub(self.pow(2.0, k), 1) return self.round(x * n) / n def quantize_active(self, x, bitA): if bitA == 32: return x return self._quantize_dorefa(x, bitA) def quantize_weight(self, x, bitW, force_quantization=False): if bitW == 32 and not force_quantization: return x if bitW == 1: E = P.stop_gradient(self.reducemean(self.abs(x))) return self.sign(x / E) * E x = P.clip_by_value(x * 0.5 + 0.5, 0.0, 1.0) return 2 * self._quantize_dorefa(x, bitW) - 1 def construct(self, inputs, filters): if self.data_format == 'NHWC': inputs = nhwc_to_nchw(inputs) inputs = self.quantize_active(self.cabs(inputs), self.bitA) filters = self.quantize_weight(filters, self.bitW) outputs = self.conv2d(inputs, filters) if self.data_format == 'NHWC': outputs = nchw_to_nhwc(outputs) return outputs class rnncell(Cell): def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh, act): super(rnncell, self).__init__() self.weight_ih = weight_ih self.weight_hh = weight_hh self.bias_ih = bias_ih self.bias_hh = bias_hh self.act_fn = P.ReLU() if act == 'relu' else P.Tanh() self.transpose = P.Transpose() def construct(self, input, h): self.weight_ih = self.transpose(self.weight_ih, (1, 0)) i2h = P.matmul(input, self.weight_ih) if self.bias_ih is not None: i2h += self.bias_ih self.weight_hh = self.transpose(self.weight_hh, (1, 0)) h2h = P.matmul(h, self.weight_hh) if self.bias_hh is not None: h2h += self.bias_hh h = self.act_fn(i2h + h2h) return h, h class lstmcell(Cell): def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh): super(lstmcell, self).__init__() self.weight_ih = weight_ih self.weight_hh = weight_hh self.bias_ih = bias_ih self.bias_hh = bias_hh self.gate_act_fn = P.Sigmoid() self.act_fn = P.Tanh() self.transpose = P.Transpose() self.split = P.Split(axis=-1, output_num=4) def construct(self, input, h, c): self.weight_ih = self.transpose(self.weight_ih, (1, 0)) gates = P.matmul(input, self.weight_ih) if self.bias_ih is not None: gates += self.bias_ih self.weight_hh = self.transpose(self.weight_hh, (1, 0)) gates += P.matmul(h, self.weight_hh) if self.bias_hh is not None: gates += self.bias_hh gate_slices = self.split(gates) i = self.gate_act_fn(gate_slices[0]) f = self.gate_act_fn(gate_slices[1]) o = self.gate_act_fn(gate_slices[3]) c = f * c + i * self.act_fn(gate_slices[2]) h = o * self.act_fn(c) return h, h, c class grucell(Cell): def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh): super(grucell, self).__init__() self.weight_ih = weight_ih self.weight_hh = weight_hh self.bias_ih = bias_ih self.bias_hh = bias_hh self.gate_act_fn = P.Sigmoid() self.act_fn = P.Tanh() self.transpose = P.Transpose() self.split = P.Split(axis=-1, output_num=3) def construct(self, input, h): self.weight_ih = self.transpose(self.weight_ih, (1, 0)) x_gates = P.matmul(input, self.weight_ih) if self.bias_ih is not None: x_gates += self.bias_ih self.weight_hh = self.transpose(self.weight_hh, (1, 0)) h_gates = P.matmul(h, self.weight_hh) if self.bias_hh is not None: h_gates += self.bias_hh x_r, x_z, x_c = self.split(x_gates) h_r, h_z, h_c = self.split(h_gates) r = self.gate_act_fn(x_r + h_r) z = self.gate_act_fn(x_r + h_z) c = self.act_fn(x_c + r * h_c) h = (h - c) * z + c return h, h class rnnbase(Cell): def __init__( self, mode, input_size, hidden_size, num_layers, bias, batch_first, dropout, bidirectional, is_train, ): super(rnnbase, self).__init__() self.mode = mode self.input_size = input_size self.hidden_size = hidden_size self.num_layers = num_layers self.bidirect = 2 if bidirectional else 1 self.batch_first = batch_first if mode == 'LSTM': self.lstm = ms.nn.LSTM( input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, has_bias=bias, batch_first=batch_first, dropout=dropout, bidirectional=bidirectional ) elif mode == 'GRU': raise NotImplementedError elif mode == 'RNN_TANH': raise NotImplementedError elif mode == 'RNN_RELU': raise NotImplementedError self.zeros = P.Zeros() def construct(self, input, states): input_shape = input.shape input_dtype = input.dtype if self.mode == 'LSTM': if self.batch_first: batch_size = input_shape[0] else: batch_size = input_shape[1] if states is None: h = self.zeros((self.bidirect * self.num_layers, batch_size, self.hidden_size), input_dtype) c = self.zeros((self.bidirect * self.num_layers, batch_size, self.hidden_size), input_dtype) states = (h, c) output, (h, c) = self.lstm(input, states) return output, (h, c)