| @@ -1442,6 +1442,39 @@ protected: | |||
| void backward_check_exec(const TensorLayout& src, const TensorLayout& dst); | |||
| }; | |||
| class LAMBUpdate : public OperatorBase { | |||
| DEF_OPR_PARAM(LAMBUpdate); | |||
| // inputs: (m_{t-1}, v_{t-1}, lamb_param, grad); outputs: (m_t, v_t, new_param) | |||
| DEF_OPR_IMPL(LAMBUpdate, OperatorBase, 4, 3); | |||
| public: | |||
| virtual void exec( | |||
| _megdnn_tensor_in m_t_1, _megdnn_tensor_in v_t_1, | |||
| _megdnn_tensor_in lamb_param, _megdnn_tensor_in grad, | |||
| _megdnn_tensor_out m_t, _megdnn_tensor_out v_t, | |||
| _megdnn_tensor_out new_param, _megdnn_workspace workspace) = 0; | |||
| virtual size_t get_workspace_in_bytes( | |||
| const TensorLayout& m_t_1, const TensorLayout& v_t_1, | |||
| const TensorLayout& lamb_param, const TensorLayout& grad, | |||
| const TensorLayout& m_t, const TensorLayout& v_t, | |||
| const TensorLayout& new_param) = 0; | |||
| void deduce_layout( | |||
| const TensorLayout& m_t_1, const TensorLayout& v_t_1, | |||
| const TensorLayout& lamb_param, const TensorLayout& grad, TensorLayout& m_t, | |||
| TensorLayout& v_t, TensorLayout& new_param); | |||
| protected: | |||
| void check_exec( | |||
| const TensorLayout& m_t_1, const TensorLayout& v_t_1, | |||
| const TensorLayout& lamb_param, const TensorLayout& grad, | |||
| const TensorLayout& m_t, const TensorLayout& v_t, | |||
| const TensorLayout& new_param, size_t workspace_in_bytes); | |||
| }; | |||
| using LAMB = LAMBUpdate; | |||
| } // namespace megdnn | |||
| #include "megdnn/internal/opr_header_epilogue.h" | |||
| @@ -36,13 +36,13 @@ pdef('Axis').add_fields('int32', 'axis', 0) | |||
| add_enum(Doc('Format', 'convolution data/filter/output format; see ' | |||
| ':class:`RelayoutFormat` for more details'), | |||
| 'NCHW = 0', 'NHWC = 1', 'NHWCD4 = 2', 'NCHW4 = 3', 'NCHW8 = 4', 'NCHW32 = 5', 'NCHW88 = 6', | |||
| 'NCHW44 = 7','NCHW44_DOT = 8', | |||
| 'NCHW44 = 7','NCHW44_DOT = 8', | |||
| Doc('NCHW_WINOGRAD = 9', 'NCHW layout with weights transformed by winograd'), | |||
| Doc('NCHW88_WINOGRAD = 10', 'NCHW88 layout with weights transformed by winograd'), | |||
| Doc('NCHW44_WINOGRAD = 11', 'NCHW44 layout with weights transformed by winograd'), | |||
| Doc('NCHW4_NCHW32 = 12', 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'), | |||
| Doc('NCHW32_NCHW4 = 13', 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'), | |||
| Doc('NCHW4_NCHW = 14', 'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'), | |||
| Doc('NCHW4_NCHW32 = 12', 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'), | |||
| Doc('NCHW32_NCHW4 = 13', 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'), | |||
| Doc('NCHW4_NCHW = 14', 'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'), | |||
| Doc('NHWC_NCHW = 15', 'NHWC_NCHW means input tensors are nhwc layout, ' | |||
| 'output tensor is nchw layout'), | |||
| Doc('NHWC_NCHW4_IC_SMALL = 16', 'NHWC_NCHW4_IC_SMALL means input tensors are nhwc(c < 4) layout, ' | |||
| @@ -96,9 +96,9 @@ pdef('Axis').add_fields('int32', 'axis', 0) | |||
| add_enum(Doc('Format', 'convolution data/filter/output format; see ' | |||
| ':class:`RelayoutFormat` for more details'), | |||
| 'NCHW = 0', 'NHWC = 1', 'NHWCD4 = 2', 'NCHW4 = 3', 'NCHW8 = 4', 'NCHW32 = 5', 'NCHW88 = 6', | |||
| 'NCHW44 = 7','NCHW44_DOT = 8', | |||
| Doc('NCHW4_NCHW32 = 9', 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'), | |||
| Doc('NCHW32_NCHW4 = 10', 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'), | |||
| 'NCHW44 = 7','NCHW44_DOT = 8', | |||
| Doc('NCHW4_NCHW32 = 9', 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'), | |||
| Doc('NCHW32_NCHW4 = 10', 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'), | |||
| Doc('NCHW4_NCHW = 11', 'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'), | |||
| Doc('NHWC_NCHW = 12', 'NHWC_NCHW means input tensors are nhwc layout, ' | |||
| 'output tensor is nchw layout'), | |||
| @@ -107,9 +107,9 @@ pdef('Axis').add_fields('int32', 'axis', 0) | |||
| Doc('NCHW_NCHW4_IC_SMALL = 14', 'NCHW_NCHW4_IC_SMALL means input tensors are nchw(c < 4) layout, ' | |||
| 'output tensor is nchw4 layout, padding c=4'), | |||
| Doc('CHWN4 = 15', 'CHWN4 is currently only used on Nvidia platform for fast implementation ' | |||
| 'of convolution using CUDA/SASS. The channels are split into groups of 4 channels.'), | |||
| 'of convolution using CUDA/SASS. The channels are split into groups of 4 channels.'), | |||
| Doc('NCHW64 = 16', 'NCHW64 is designed for convolution implementations to utilize TensorCore ' | |||
| 'instructions for 4-bit integers on Nvidia platforms'), | |||
| 'instructions for 4-bit integers on Nvidia platforms'), | |||
| Doc('NCHW4_NHWC = 17', 'NCHW4_NHWC means input tensors are nchw4 layout, output tensor is nhwc layout')). | |||
| add_enum_alias('ComputeMode', 'ConvolutionV1',name_field='compute_mode') | |||
| ) | |||
| @@ -1038,10 +1038,10 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o | |||
| 'NCHW_NCHW4 = 24', | |||
| 'NCHW4_NCHW = 25', | |||
| 'NCHW_NCHW4_WEIGHT = 26', | |||
| 'NCHW_NCHW64 = 27', | |||
| 'NCHW64_NCHW = 28', | |||
| 'NCHW_NHWC = 29', | |||
| 'NHWC_NCHW = 30', | |||
| 'NCHW_NCHW64 = 27', | |||
| 'NCHW64_NCHW = 28', | |||
| 'NCHW_NHWC = 29', | |||
| 'NHWC_NCHW = 30', | |||
| 'NHWCD4I_NHWC = 31', | |||
| ) | |||
| ) | |||
| @@ -1264,3 +1264,14 @@ PADDING_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'), | |||
| add_fields('float32', Doc('dropout', 'if non-zero, introduces a Dropout layer on the outputs of each LSTM layer'), '0.f'). | |||
| add_enum_alias('FwdMode', 'BN', name_field='fwd_mode') | |||
| ) | |||
| (pdef('LAMBUpdate'). | |||
| add_fields('float32', Doc('beta_1', 'beta_1 parameter of LAMB'), '1.f'). | |||
| add_fields('float32', Doc('beta_2', 'beta_2 parameter of LAMB'), '1.f'). | |||
| add_fields('float32', Doc('step', 'training step'), '1.f'). | |||
| add_fields('float32', Doc('lr', 'learning rate'), '1.f'). | |||
| add_fields('float32', Doc('weight_decay', 'weight decay (L2 penalty)'), '1.f'). | |||
| add_fields('float32', Doc('eps', 'term added to the denominator for numerical stability'), '1.f'). | |||
| add_fields('bool', Doc('bias_correction', 'whether to apply bias correction (divide by 1 - beta ** step)'), 'true'). | |||
| add_fields('bool', Doc('always_adapt', 'apply the adaptive lr (trust ratio) even when weight_decay is 0.0'), 'false') | |||
| ) | |||
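On the Python side these fields surface as the positional arguments of the builtin op, in the same order as above. A minimal sketch of driving the op directly (mirroring the imperative test further down in this diff; shapes and hyperparameters are illustrative only):

    import numpy as np
    import megengine as mge
    from megengine.core._imperative_rt.core2 import apply
    from megengine.core.ops.builtin import LAMBUpdate

    # (beta_1, beta_2, step, lr, weight_decay, eps, bias_correction, always_adapt)
    op = LAMBUpdate(0.9, 0.999, 1, 6.25e-5, 0.0, 1e-8, True, False)
    m, v, p, g = (mge.tensor(np.random.uniform(size=(4,)), dtype=np.float32) for _ in range(4))
    new_m, new_v, new_p = apply(op, m, v, p, g)  # updated moments and parameter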
| @@ -209,6 +209,7 @@ private: | |||
| cb(RNN) \ | |||
| cb(RNNBackward) \ | |||
| cb(LSTM) \ | |||
| cb(LAMBUpdate) \ | |||
| cb(LSTMBackward) \ | |||
| cb(SoftmaxForward) \ | |||
| cb(SoftmaxBackward) | |||
| @@ -0,0 +1,25 @@ | |||
| #include "megdnn/oprs.h" | |||
| #include "src/common/utils.h" | |||
| namespace megdnn { | |||
| void LAMBUpdate::deduce_layout( | |||
| const TensorLayout& m_t_1, const TensorLayout& v_t_1, | |||
| const TensorLayout& lamb_param, const TensorLayout& grad, TensorLayout& m_t, | |||
| TensorLayout& v_t, TensorLayout& new_param) { | |||
| m_t = TensorLayout(m_t_1); | |||
| v_t = TensorLayout(v_t_1); | |||
| new_param = TensorLayout(lamb_param); | |||
| MEGDNN_MARK_USED_VAR(grad); | |||
| } | |||
| void LAMBUpdate::check_exec( | |||
| const TensorLayout& m_t_1, const TensorLayout& v_t_1, | |||
| const TensorLayout& lamb_param, const TensorLayout& grad, | |||
| const TensorLayout& m_t, const TensorLayout& v_t, const TensorLayout& new_param, | |||
| size_t workspace_in_bytes) { | |||
| auto required_workspace_in_bytes = | |||
| get_workspace_in_bytes(m_t_1, v_t_1, lamb_param, grad, m_t, v_t, new_param); | |||
| megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||
| } | |||
| } // namespace megdnn | |||
| @@ -127,6 +127,7 @@ DEF(LSQBackward, 7, true, false); | |||
| DEF(Fill, 1, true, false); | |||
| DEF(LayerNormForward, 6, true, true); | |||
| DEF(LayerNormBackward, 8, true, true); | |||
| DEF(LAMBUpdate, 7, true, true); | |||
| DEF(DropoutForward, 3, true, true); | |||
| DEF(DropoutBackward, 3, true, true); | |||
| DEF(RNNCellForward, 7, true, true); | |||
| @@ -35,6 +35,7 @@ | |||
| #include "src/cuda/images2neibs/opr_impl.h" | |||
| #include "src/cuda/indexing_multi_axis_vec/opr_impl.h" | |||
| #include "src/cuda/indexing_one_hot/opr_impl.h" | |||
| #include "src/cuda/lamb/opr_impl.h" | |||
| #include "src/cuda/layer_norm/opr_impl.h" | |||
| #include "src/cuda/linspace/opr_impl.h" | |||
| #include "src/cuda/local/opr_impl.h" | |||
| @@ -210,6 +211,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(PaddingForward); | |||
| MEGDNN_SPECIALIZE_CREATE_OPERATOR(PaddingBackward); | |||
| MEGDNN_SPECIALIZE_CREATE_OPERATOR(LayerNormForward); | |||
| MEGDNN_SPECIALIZE_CREATE_OPERATOR(LayerNormBackward); | |||
| MEGDNN_SPECIALIZE_CREATE_OPERATOR(LAMBUpdate); | |||
| MEGDNN_SPECIALIZE_CREATE_OPERATOR(DropoutForward); | |||
| MEGDNN_SPECIALIZE_CREATE_OPERATOR(DropoutBackward); | |||
| MEGDNN_SPECIALIZE_CREATE_OPERATOR(SoftmaxForward); | |||
| @@ -0,0 +1,102 @@ | |||
| #include <thrust/device_vector.h> | |||
| #include <thrust/pair.h> | |||
| #include <thrust/transform_reduce.h> | |||
| #include <thrust/tuple.h> | |||
| #include <cfloat> | |||
| #include "megdnn/arch.h" | |||
| #include "megdnn/dtype.h" | |||
| #include "src/cuda/cuda_shfl_compat.cuh" | |||
| #include "src/cuda/lamb/lamb_cuda.cuh" | |||
| #include "src/cuda/utils.cuh" | |||
| namespace megdnn { | |||
| namespace cuda { | |||
| namespace lamb { | |||
| template <typename T> | |||
| struct square { | |||
| __host__ __device__ T operator()(const T& x) const { return x * x; } | |||
| }; | |||
| template <typename T, typename T_ACC> | |||
| __global__ void update_kernel_1( | |||
| T_ACC* m_t_1, T_ACC* v_t_1, T_ACC* lamb_param, T* grad, T_ACC* m_t, T_ACC* v_t, | |||
| T_ACC* new_param, T_ACC* rt, float beta_1, float beta_2, float step, float lr, | |||
| float weight_decay, float eps, bool bias_correction, bool always_adapt, | |||
| size_t total_nr_elem) { | |||
| size_t idx = threadIdx.x + blockIdx.x * blockDim.x; | |||
| T_ACC bc_1 = bias_correction ? 1 - pow(beta_1, step) : 1, | |||
| bc_2 = bias_correction ? 1 - pow(beta_2, step) : 1; | |||
| if (idx < total_nr_elem) { | |||
| m_t[idx] = beta_1 * m_t_1[idx] + (1 - beta_1) * static_cast<T_ACC>(grad[idx]); | |||
| v_t[idx] = beta_2 * v_t_1[idx] + | |||
| (1 - beta_2) * std::pow(static_cast<T_ACC>(grad[idx]), 2); | |||
| rt[idx] = (m_t[idx] / bc_1) / (std::sqrt(v_t[idx] / bc_2) + eps); | |||
| if (weight_decay != 0) { | |||
| rt[idx] += lamb_param[idx] * weight_decay; | |||
| } | |||
| } | |||
| } | |||
| template <typename T, typename T_ACC> | |||
| __global__ void update_kernel_2( | |||
| T_ACC* m_t_1, T_ACC* v_t_1, T_ACC* lamb_param, T* grad, T_ACC* m_t, T_ACC* v_t, | |||
| T_ACC* new_param, T_ACC* rt, float beta_1, float beta_2, float step, float lr, | |||
| float weight_decay, float eps, bool bias_correction, bool always_adapt, | |||
| size_t total_nr_elem, T_ACC trust_ratio) { | |||
| size_t idx = threadIdx.x + blockIdx.x * blockDim.x; | |||
| T_ACC bc_1 = bias_correction ? 1 - pow(beta_1, step) : 1, | |||
| bc_2 = bias_correction ? 1 - pow(beta_2, step) : 1; | |||
| if (idx < total_nr_elem) { | |||
| rt[idx] = (m_t[idx] / bc_1) / (std::sqrt(v_t[idx] / bc_2) + eps); | |||
| if (weight_decay != 0) { | |||
| rt[idx] += lamb_param[idx] * weight_decay; | |||
| } | |||
| new_param[idx] = lamb_param[idx] - lr * trust_ratio * rt[idx]; | |||
| } | |||
| } | |||
| template <typename T, typename T_ACC> | |||
| void update( | |||
| T_ACC* m_t_1, T_ACC* v_t_1, T_ACC* lamb_param, T* grad, T_ACC* m_t, T_ACC* v_t, | |||
| T_ACC* new_param, T_ACC* rt, float beta_1, float beta_2, float step, float lr, | |||
| float weight_decay, float eps, bool bias_correction, bool always_adapt, | |||
| size_t total_nr_elem, cudaStream_t stream) { | |||
| size_t NR_BLOCKS = DIVUP(total_nr_elem, NR_THREADS); | |||
| update_kernel_1<T, T_ACC><<<NR_BLOCKS, NR_THREADS, 0, stream>>>( | |||
| m_t_1, v_t_1, lamb_param, grad, m_t, v_t, new_param, rt, beta_1, beta_2, | |||
| step, lr, weight_decay, eps, bias_correction, always_adapt, total_nr_elem); | |||
| after_kernel_launch(); | |||
| thrust::device_ptr<T_ACC> lamb_param_ptr(lamb_param); | |||
| thrust::device_ptr<T_ACC> rt_ptr(rt); | |||
| square<T_ACC> unary_op; | |||
| thrust::plus<T_ACC> binary_op; | |||
| T_ACC p_norm = std::sqrt(thrust::transform_reduce( | |||
| lamb_param_ptr, lamb_param_ptr + total_nr_elem, unary_op, 0.f, binary_op)); | |||
| T_ACC d_norm = std::sqrt(thrust::transform_reduce( | |||
| rt_ptr, rt_ptr + total_nr_elem, unary_op, 0.f, binary_op)); | |||
| T_ACC trust_ratio = 1; | |||
| if ((always_adapt || weight_decay > 0) && p_norm > 0 && d_norm > 0) { | |||
| trust_ratio = p_norm / d_norm; | |||
| } | |||
| update_kernel_2<T, T_ACC><<<NR_BLOCKS, NR_THREADS, 0, stream>>>( | |||
| m_t_1, v_t_1, lamb_param, grad, m_t, v_t, new_param, rt, beta_1, beta_2, | |||
| step, lr, weight_decay, eps, bias_correction, always_adapt, total_nr_elem, | |||
| trust_ratio); | |||
| after_kernel_launch(); | |||
| } | |||
| #define INST(T, T_ACC) \ | |||
| template void update<T, T_ACC>( \ | |||
| T_ACC*, T_ACC*, T_ACC*, T*, T_ACC*, T_ACC*, T_ACC*, T_ACC*, float, float, \ | |||
| float, float, float, float, bool, bool, size_t, cudaStream_t); | |||
| INST(dt_float32, dt_float32) | |||
| INST(dt_float16, dt_float32) | |||
| INST(dt_bfloat16, dt_float32) | |||
| #undef INST | |||
| } // namespace lamb | |||
| } // namespace cuda | |||
| } // namespace megdnn | |||
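The GPU path splits the work into three stages: update_kernel_1 writes m_t, v_t and the intermediate r_t (stored in the workspace), thrust::transform_reduce then returns the parameter and update L2 norms to the host where the trust ratio is computed, and update_kernel_2 recomputes r_t and applies the final parameter update. A minimal NumPy sketch of the same math (illustrative only, not the device code; it assumes bias_correction is enabled, otherwise the two 1 - beta ** step terms become 1):

    import numpy as np

    def lamb_reference(m, v, p, g, beta_1, beta_2, step, lr,
                       weight_decay, eps, always_adapt=False):
        m_t = beta_1 * m + (1 - beta_1) * g                      # stage 1: moments
        v_t = beta_2 * v + (1 - beta_2) * g * g
        r_t = (m_t / (1 - beta_1 ** step)) / (np.sqrt(v_t / (1 - beta_2 ** step)) + eps)
        r_t = r_t + weight_decay * p
        p_norm, d_norm = np.linalg.norm(p), np.linalg.norm(r_t)  # stage 2: norms
        adapt = (always_adapt or weight_decay > 0) and p_norm > 0 and d_norm > 0
        trust_ratio = p_norm / d_norm if adapt else 1.0
        return m_t, v_t, p - lr * trust_ratio * r_t              # stage 3: apply update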
| @@ -0,0 +1,17 @@ | |||
| #pragma once | |||
| #include <cuda_runtime_api.h> | |||
| namespace megdnn { | |||
| namespace cuda { | |||
| namespace lamb { | |||
| template <typename T, typename T_ACC> | |||
| void update( | |||
| T_ACC* m_t_1, T_ACC* v_t_1, T_ACC* lamb_param, T* grad, T_ACC* m_t, T_ACC* v_t, | |||
| T_ACC* new_param, T_ACC* rt, float beta_1, float beta_2, float step, float lr, | |||
| float weight_decay, float eps, bool bias_correction, bool always_adapt, | |||
| size_t total_nr_elem, cudaStream_t stream); | |||
| } // namespace lamb | |||
| } // namespace cuda | |||
| } // namespace megdnn | |||
| @@ -0,0 +1,45 @@ | |||
| #include "src/cuda/lamb/opr_impl.h" | |||
| #include "./lamb_cuda.cuh" | |||
| #include "src/cuda/utils.h" | |||
| #include <cmath> | |||
| #include <functional> | |||
| #include <numeric> | |||
| namespace megdnn { | |||
| namespace cuda { | |||
| void LAMBUpdateImpl::exec( | |||
| _megdnn_tensor_in m_t_1, _megdnn_tensor_in v_t_1, _megdnn_tensor_in lamb_param, | |||
| _megdnn_tensor_in grad, _megdnn_tensor_out m_t, _megdnn_tensor_out v_t, | |||
| _megdnn_tensor_out new_param, _megdnn_workspace workspace) { | |||
| auto p = param(); | |||
| float beta_1 = p.beta_1; | |||
| float beta_2 = p.beta_2; | |||
| float step = p.step; | |||
| float lr = p.lr; | |||
| float weight_decay = p.weight_decay; | |||
| float eps = p.eps; | |||
| bool bias_correction = p.bias_correction; | |||
| bool always_adapt = p.always_adapt; | |||
| size_t total_elem = lamb_param.layout.total_nr_elems(); | |||
| auto stream = cuda_stream(handle()); | |||
| using namespace ::megdnn::cuda::lamb; | |||
| #define cb(DType) \ | |||
| if (grad.layout.dtype == DType()) { \ | |||
| using T = typename DTypeTrait<DType>::ctype; \ | |||
| using T_ACC = float; \ | |||
| update<T, T_ACC>( \ | |||
| m_t_1.ptr<T_ACC>(), v_t_1.ptr<T_ACC>(), lamb_param.ptr<T_ACC>(), \ | |||
| grad.ptr<T>(), m_t.ptr<T_ACC>(), v_t.ptr<T_ACC>(), \ | |||
| new_param.ptr<T_ACC>(), workspace.ptr<T_ACC>(), beta_1, beta_2, step, \ | |||
| lr, weight_decay, eps, bias_correction, always_adapt, total_elem, \ | |||
| stream); \ | |||
| return; \ | |||
| } | |||
| MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) | |||
| #undef cb | |||
| megdnn_throw("bad dtype"); | |||
| } | |||
| } // namespace cuda | |||
| } // namespace megdnn | |||
| @@ -0,0 +1,25 @@ | |||
| #pragma once | |||
| #include "megdnn/oprs.h" | |||
| #include "src/cuda/cudnn_wrapper.h" | |||
| namespace megdnn { | |||
| namespace cuda { | |||
| class LAMBUpdateImpl final : public LAMBUpdate { | |||
| public: | |||
| using LAMBUpdate::LAMBUpdate; | |||
| void exec( | |||
| _megdnn_tensor_in m_t_1, _megdnn_tensor_in v_t_1, | |||
| _megdnn_tensor_in lamb_param, _megdnn_tensor_in grad, | |||
| _megdnn_tensor_out m_t, _megdnn_tensor_out v_t, | |||
| _megdnn_tensor_out new_param, _megdnn_workspace workspace) override; | |||
| size_t get_workspace_in_bytes( | |||
| const TensorLayout& m_t_1, const TensorLayout& v_t_1, | |||
| const TensorLayout& lamb_param, const TensorLayout& grad, | |||
| const TensorLayout& m_t, const TensorLayout& v_t, | |||
| const TensorLayout& new_param) override { | |||
| return m_t.access_bytes(); | |||
| }; | |||
| }; | |||
| } // namespace cuda | |||
| } // namespace megdnn | |||
| @@ -37,6 +37,7 @@ | |||
| #include "src/naive/images2neibs/opr_impl.h" | |||
| #include "src/naive/indexing_multi_axis_vec/opr_impl.h" | |||
| #include "src/naive/indexing_one_hot/opr_impl.h" | |||
| #include "src/naive/lamb/opr_impl.h" | |||
| #include "src/naive/layer_norm/opr_impl.h" | |||
| #include "src/naive/linspace/opr_impl.h" | |||
| #include "src/naive/local/opr_impl.h" | |||
| @@ -0,0 +1,89 @@ | |||
| #include "src/naive/lamb/opr_impl.h" | |||
| #include <cmath> | |||
| #include <functional> | |||
| #include <numeric> | |||
| #include "src/common/utils.h" | |||
| #include "src/naive/handle.h" | |||
| using namespace megdnn; | |||
| using namespace naive; | |||
| namespace { | |||
| using Param = megdnn::LAMBUpdate::Param; | |||
| template <typename T, typename T_ACC = float> | |||
| void update( | |||
| _megdnn_tensor_in m_t_1, _megdnn_tensor_in v_t_1, _megdnn_tensor_in lamb_param, | |||
| _megdnn_tensor_in grad, _megdnn_tensor_out m_t, _megdnn_tensor_out v_t, | |||
| _megdnn_tensor_out new_param, const Param& param) { | |||
| float beta_1 = param.beta_1; | |||
| float beta_2 = param.beta_2; | |||
| float step = param.step; | |||
| float lr = param.lr; | |||
| float weight_decay = param.weight_decay; | |||
| float eps = param.eps; | |||
| bool bias_correction = param.bias_correction; | |||
| bool always_adapt = param.always_adapt; | |||
| size_t total_elem = lamb_param.layout.total_nr_elems(); | |||
| T_ACC mt, vt, bc_1, bc_2, rt, d_norm = 0; | |||
| bc_1 = bias_correction ? 1 - pow(beta_1, step) : 1; | |||
| bc_2 = bias_correction ? 1 - pow(beta_2, step) : 1; | |||
| for (size_t i = 0; i < total_elem; i++) { | |||
| mt = m_t.ptr<T_ACC>()[i] = beta_1 * m_t_1.ptr<T_ACC>()[i] + | |||
| (1 - beta_1) * static_cast<T_ACC>(grad.ptr<T>()[i]); | |||
| vt = v_t.ptr<T_ACC>()[i] = | |||
| beta_2 * v_t_1.ptr<T_ACC>()[i] + | |||
| (1 - beta_2) * std::pow(static_cast<T_ACC>(grad.ptr<T>()[i]), 2); | |||
| rt = (mt / bc_1) / (sqrt(vt / bc_2) + eps); | |||
| if (weight_decay != 0) { | |||
| rt += lamb_param.ptr<T_ACC>()[i] * weight_decay; | |||
| } | |||
| d_norm += rt * rt; | |||
| } | |||
| d_norm = sqrt(d_norm); | |||
| auto get_norm = [=](_megdnn_tensor_in norm) -> T_ACC { | |||
| return sqrt(std::accumulate( | |||
| norm.ptr<T_ACC>(), norm.ptr<T_ACC>() + total_elem, static_cast<T_ACC>(0), | |||
| [](T_ACC t1, T_ACC t2) -> T_ACC { return t1 + t2 * t2; })); | |||
| }; | |||
| T_ACC p_norm = get_norm(lamb_param), trust_ratio = 1; | |||
| if ((always_adapt || weight_decay > 0) && p_norm > 0 && d_norm > 0) { | |||
| trust_ratio = p_norm / d_norm; | |||
| } | |||
| for (size_t i = 0; i < total_elem; i++) { | |||
| mt = m_t.ptr<T_ACC>()[i]; | |||
| vt = v_t.ptr<T_ACC>()[i]; | |||
| rt = (mt / bc_1) / (sqrt(vt / bc_2) + eps); | |||
| if (weight_decay != 0) { | |||
| rt += lamb_param.ptr<T_ACC>()[i] * weight_decay; | |||
| } | |||
| new_param.ptr<T_ACC>()[i] = lamb_param.ptr<T_ACC>()[i] - lr * trust_ratio * rt; | |||
| } | |||
| } | |||
| } // namespace | |||
| namespace megdnn { | |||
| namespace naive { | |||
| void LAMBUpdateImpl::exec( | |||
| _megdnn_tensor_in m_t_1, _megdnn_tensor_in v_t_1, _megdnn_tensor_in lamb_param, | |||
| _megdnn_tensor_in grad, _megdnn_tensor_out m_t, _megdnn_tensor_out v_t, | |||
| _megdnn_tensor_out new_param, _megdnn_workspace workspace) { | |||
| check_exec( | |||
| m_t_1.layout, v_t_1.layout, lamb_param.layout, grad.layout, m_t.layout, | |||
| v_t.layout, new_param.layout, workspace.size); | |||
| #define cb(DType) \ | |||
| if (grad.layout.dtype == DType()) { \ | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR(update<typename DTypeTrait<DType>::ctype>( \ | |||
| m_t_1, v_t_1, lamb_param, grad, m_t, v_t, new_param, param())); \ | |||
| return; \ | |||
| } | |||
| MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) | |||
| #undef cb | |||
| megdnn_throw("bad dtype"); | |||
| } | |||
| } // namespace naive | |||
| } // namespace megdnn | |||
| @@ -0,0 +1,34 @@ | |||
| #pragma once | |||
| #include "megdnn/oprs.h" | |||
| #include "src/common/utils.h" | |||
| namespace megdnn { | |||
| namespace naive { | |||
| class LAMBUpdateImpl final : public LAMBUpdate { | |||
| public: | |||
| using LAMBUpdate::LAMBUpdate; | |||
| void exec( | |||
| _megdnn_tensor_in m_t_1, _megdnn_tensor_in v_t_1, | |||
| _megdnn_tensor_in lamb_param, _megdnn_tensor_in grad, | |||
| _megdnn_tensor_out m_t, _megdnn_tensor_out v_t, | |||
| _megdnn_tensor_out new_param, _megdnn_workspace workspace) override; | |||
| size_t get_workspace_in_bytes( | |||
| const TensorLayout& m_t_1, const TensorLayout& v_t_1, | |||
| const TensorLayout& lamb_param, const TensorLayout& grad, | |||
| const TensorLayout& m_t, const TensorLayout& v_t, | |||
| const TensorLayout& new_param) override { | |||
| MEGDNN_MARK_USED_VAR(m_t_1); | |||
| MEGDNN_MARK_USED_VAR(v_t_1); | |||
| MEGDNN_MARK_USED_VAR(lamb_param); | |||
| MEGDNN_MARK_USED_VAR(grad); | |||
| MEGDNN_MARK_USED_VAR(m_t); | |||
| MEGDNN_MARK_USED_VAR(v_t); | |||
| MEGDNN_MARK_USED_VAR(new_param); | |||
| return 0; | |||
| }; | |||
| }; | |||
| } // namespace naive | |||
| } // namespace megdnn | |||
| @@ -0,0 +1,36 @@ | |||
| #pragma once | |||
| #include "megdnn/basic_types.h" | |||
| #include "megdnn/opr_param_defs.h" | |||
| namespace megdnn { | |||
| namespace test { | |||
| namespace lamb { | |||
| struct TestArg { | |||
| param::LAMBUpdate param; | |||
| TensorShape src; | |||
| TestArg(param::LAMBUpdate param, TensorShape src) : param(param), src(src) {} | |||
| }; | |||
| inline std::vector<TestArg> get_args() { | |||
| std::vector<TestArg> args; | |||
| param::LAMBUpdate cur_param; | |||
| cur_param.beta_1 = 0.9; | |||
| cur_param.beta_2 = 0.999; | |||
| cur_param.eps = 1e-8; | |||
| cur_param.weight_decay = 0; | |||
| cur_param.lr = 6.25e-5; | |||
| cur_param.bias_correction = true; | |||
| cur_param.always_adapt = false; | |||
| args.emplace_back( | |||
| cur_param, TensorShape{ | |||
| 1280, | |||
| }); | |||
| args.emplace_back(cur_param, TensorShape{1280, 1280}); | |||
| args.emplace_back(cur_param, TensorShape{1280, 3, 224, 224}); | |||
| return args; | |||
| } | |||
| } // namespace lamb | |||
| } // namespace test | |||
| } // namespace megdnn | |||
| @@ -0,0 +1,44 @@ | |||
| #include "test/cuda/fixture.h" | |||
| #include "test/common/checker.h" | |||
| #include "test/common/rng.h" | |||
| namespace megdnn { | |||
| namespace test { | |||
| TEST_F(CUDA, LAMBUpdate) { | |||
| LAMBUpdate::Param param; | |||
| param.beta_1 = 0.9; | |||
| param.beta_2 = 0.999; | |||
| param.eps = 1e-5; | |||
| param.weight_decay = 0.4; | |||
| param.lr = 1e-3; | |||
| param.step = 1; | |||
| param.bias_correction = true; | |||
| param.always_adapt = false; | |||
| Checker<LAMBUpdate> checker(handle_cuda()); | |||
| checker.set_epsilon(1e-3); | |||
| UniformFloatRNG rng0(0, 1); | |||
| auto run = [&](DType d) { | |||
| checker.set_param(param) | |||
| .set_rng(0, &rng0) | |||
| .set_rng(1, &rng0) | |||
| .set_dtype(0, dtype::Float32()) | |||
| .set_dtype(1, dtype::Float32()) | |||
| .set_dtype(2, dtype::Float32()) | |||
| .set_dtype(3, d) | |||
| .set_dtype(4, dtype::Float32()) | |||
| .set_dtype(5, dtype::Float32()) | |||
| .set_dtype(6, dtype::Float32()) | |||
| .execs({{2}, {2}, {2}, {2}, {}, {}, {}}); | |||
| }; | |||
| run(dtype::Float32()); | |||
| run(dtype::Float16()); | |||
| run(dtype::BFloat16()); | |||
| } | |||
| } // namespace test | |||
| } // namespace megdnn | |||
| @@ -0,0 +1,33 @@ | |||
| #include "test/common/lamb.h" | |||
| #include "megdnn/dtype.h" | |||
| #include "megdnn/oprs.h" | |||
| #include "test/common/checker.h" | |||
| #include "test/naive/fixture.h" | |||
| using namespace megdnn; | |||
| using namespace test; | |||
| TEST_F(NAIVE, LAMBUpdate) { | |||
| Checker<LAMBUpdate> checker(handle(), false); | |||
| LAMBUpdate::Param param; | |||
| param.beta_1 = 0; | |||
| param.beta_2 = 0; | |||
| param.eps = 0; | |||
| param.weight_decay = 0; | |||
| param.lr = 1; | |||
| param.step = 1; | |||
| param.bias_correction = true; | |||
| param.always_adapt = false; | |||
| TensorND m_t_1 = TensorValue({2}, dtype::Float32(), {1, 1}); | |||
| TensorND v_t_1 = TensorValue({2}, dtype::Float32(), {1, 1}); | |||
| TensorND param_lamb = TensorValue({2}, dtype::Float32(), {1, 1}); | |||
| TensorND grad = TensorValue({2}, dtype::Float16(), {1, 1}); | |||
| TensorND m_t = TensorValue({2}, dtype::Float32(), {1, 1}); | |||
| TensorND v_t = TensorValue({2}, dtype::Float32(), {1, 1}); | |||
| TensorND new_param = TensorValue({2}, dtype::Float32(), {0, 0}); | |||
| checker.set_param(param).exect( | |||
| Testcase{m_t_1, v_t_1, param_lamb, grad, {}, {}, {}}, | |||
| Testcase{{}, {}, {}, {}, m_t, v_t, new_param}); | |||
| } | |||
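The expected values in this test can be verified by hand; with beta_1 = beta_2 = 0, eps = 0, weight_decay = 0 and lr = 1 the update collapses as in this standalone sanity check (not part of the test itself):

    beta_1 = beta_2 = 0.0; eps = 0.0; weight_decay = 0.0; lr = 1.0; step = 1
    m_prev = v_prev = param = grad = 1.0
    m_t = beta_1 * m_prev + (1 - beta_1) * grad           # = 1
    v_t = beta_2 * v_prev + (1 - beta_2) * grad ** 2      # = 1
    bc_1, bc_2 = 1 - beta_1 ** step, 1 - beta_2 ** step   # = 1, 1 (bias correction is a no-op here)
    r_t = (m_t / bc_1) / ((v_t / bc_2) ** 0.5 + eps)      # = 1; weight_decay = 0 adds nothing
    trust_ratio = 1.0                                     # weight_decay = 0 and always_adapt is false
    new_param = param - lr * trust_ratio * r_t            # = 0, matching new_param = {0, 0} above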
| @@ -4,6 +4,7 @@ from .adagrad import Adagrad | |||
| from .adam import Adam | |||
| from .adamw import AdamW | |||
| from .clip_grad import * | |||
| from .lamb import LAMB, LAMBFp16 | |||
| from .lr_scheduler import LRScheduler | |||
| from .multi_step_lr import MultiStepLR | |||
| from .optimizer import Optimizer | |||
| @@ -0,0 +1,160 @@ | |||
| # Copyright (c) 2020 Ross Wightman | |||
| # This file has been modified by Megvii ("Megvii Modifications"). | |||
| # All Megvii Modifications are Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| """LAMB optimizer | |||
| References: https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/lamb.py | |||
| """ | |||
| import os | |||
| from typing import Iterable, Tuple, Union | |||
| from megengine.core._imperative_rt.core2 import apply | |||
| from megengine.core.ops.builtin import LAMBUpdate | |||
| from .. import Parameter, tensor | |||
| from ..functional import sum | |||
| from ..functional.inplace import _inplace_add_ | |||
| from .optimizer import Optimizer | |||
| class LAMB(Optimizer): | |||
| r"""Implements LAMB algorithm. | |||
| LAMB is proposed in `"Large Batch Optimization for Deep Learning: Training BERT in 76 minutes" | |||
| <https://arxiv.org/abs/1904.00962>`_. | |||
| Args: | |||
| params: iterable of parameters to optimize or dicts defining parameter groups. | |||
| lr: learning rate. | |||
| betas: coefficients used for computing running averages of gradient and its square. | |||
| Default: ``(0.9, 0.999)`` | |||
| eps: term added to the denominator to improve numerical stability. Default: ``1e-8`` | |||
| bias_correction: enables bias correction by ``1 - beta ** step``. Default: ``True`` | |||
| weight_decay: weight decay (L2 penalty). Default: ``0.0`` | |||
| always_adapt: apply the adaptive lr (trust ratio) even when ``weight_decay`` is ``0.0``. Default: ``False`` | |||
| """ | |||
| def __init__( | |||
| self, | |||
| params: Union[Iterable[Parameter], dict], | |||
| lr: float, | |||
| betas: Tuple[float, float] = (0.9, 0.999), | |||
| eps: float = 1e-8, | |||
| bias_correction: bool = True, | |||
| weight_decay: float = 0.0, | |||
| always_adapt: bool = False, | |||
| ): | |||
| if lr < 0.0: | |||
| raise ValueError("Invalid learning rate: {}".format(lr)) | |||
| if weight_decay < 0.0: | |||
| raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) | |||
| if not 0.0 <= betas[0] < 1.0: | |||
| raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) | |||
| if not 0.0 <= betas[1] < 1.0: | |||
| raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) | |||
| defaults = dict(lr=lr, weight_decay=weight_decay, betas=betas, eps=eps) | |||
| super().__init__(params, defaults) | |||
| self.bias_correction = bias_correction | |||
| self.always_adapt = always_adapt | |||
| self._disable_type_convert = True | |||
| def _create_state(self, param_group): | |||
| for param in param_group["params"]: | |||
| self._add_state(param, "exp_avg") | |||
| self._add_state(param, "exp_avg_sq") | |||
| self._add_state(param, "step", initializer=0.0, dtype="float32") | |||
| def _updates(self, param_group): | |||
| lr = param_group["lr"] | |||
| weight_decay = param_group["weight_decay"] | |||
| eps = param_group["eps"] | |||
| beta0, beta1 = param_group["betas"] | |||
| # since `convert_inputs` is disabled for param updates, | |||
| # scalars should be explicitly converted to tensors | |||
| c1 = tensor(1.0) | |||
| for param in param_group["params"]: | |||
| if param.grad is None: | |||
| continue | |||
| grad = param.grad | |||
| states = self._state[param] | |||
| step, exp_avg, exp_avg_sq = ( | |||
| states["step"], | |||
| states["exp_avg"], | |||
| states["exp_avg_sq"], | |||
| ) | |||
| step += c1 | |||
| op = LAMBUpdate( | |||
| beta0, | |||
| beta1, | |||
| int(step), | |||
| lr, | |||
| weight_decay, | |||
| eps, | |||
| self.bias_correction, | |||
| self.always_adapt, | |||
| ) | |||
| new_exp_avg, new_exp_avg_sq, new_param = apply( | |||
| op, exp_avg, exp_avg_sq, param, grad | |||
| ) | |||
| param._reset(new_param) | |||
| exp_avg._reset(new_exp_avg) | |||
| exp_avg_sq._reset(new_exp_avg_sq) | |||
| class LAMBFp16(LAMB): | |||
| def _create_state(self, param_group): | |||
| for param in param_group["params"]: | |||
| self._add_state(param, "exp_avg", dtype="float32") | |||
| self._add_state(param, "exp_avg_sq", dtype="float32") | |||
| self._add_state(param, "step", initializer=0.0, dtype="float32") | |||
| self._state[param]["param_fp32"] = param.astype("float32") | |||
| def _updates(self, param_group): | |||
| lr = param_group["lr"] | |||
| weight_decay = param_group["weight_decay"] | |||
| eps = param_group["eps"] | |||
| beta0, beta1 = param_group["betas"] | |||
| c1 = tensor(1.0) | |||
| for param in param_group["params"]: | |||
| if param.grad is None: | |||
| continue | |||
| grad = param.grad | |||
| states = self._state[param] | |||
| step, exp_avg, exp_avg_sq = ( | |||
| states["step"], | |||
| states["exp_avg"], | |||
| states["exp_avg_sq"], | |||
| ) | |||
| step += c1 | |||
| fp32_param = states["param_fp32"] | |||
| op = LAMBUpdate( | |||
| beta0, | |||
| beta1, | |||
| step, | |||
| lr, | |||
| weight_decay, | |||
| eps, | |||
| self.bias_correction, | |||
| self.always_adapt, | |||
| ) | |||
| new_exp_avg, new_exp_avg_sq, new_param = apply( | |||
| op, exp_avg, exp_avg_sq, fp32_param, grad | |||
| ) | |||
| fp32_param._reset(new_param) | |||
| param._reset(new_param.astype("float16")) | |||
| exp_avg._reset(new_exp_avg) | |||
| exp_avg_sq._reset(new_exp_avg_sq) | |||
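A minimal usage sketch of the new optimizer (the module, data and squared-error loss are placeholders, not part of this diff; the training-loop API follows the existing MegEngine optimizers). LAMBFp16 is used the same way for float16 parameters and keeps a float32 master copy of each parameter internally:

    import numpy as np
    import megengine as mge
    import megengine.autodiff as ad
    import megengine.functional as F
    import megengine.module as M
    import megengine.optimizer as optim

    net = M.Linear(16, 4)                                  # placeholder toy module
    gm = ad.GradManager().attach(net.parameters())
    opt = optim.LAMB(net.parameters(), lr=6.25e-5, weight_decay=0.01)

    x = mge.tensor(np.random.randn(8, 16), dtype=np.float32)
    with gm:
        loss = F.mean(net(x) ** 2)                         # placeholder loss
        gm.backward(loss)
    opt.step()
    opt.clear_grad()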
| @@ -0,0 +1,85 @@ | |||
| import numpy as np | |||
| import megengine as mge | |||
| import megengine.autodiff as ad | |||
| import megengine.functional as F | |||
| import megengine.module as M | |||
| import megengine.optimizer as optim | |||
| from megengine import tensor | |||
| from megengine.core._imperative_rt.core2 import apply | |||
| from megengine.core.ops.builtin import LAMBUpdate | |||
| def lamb_update( | |||
| param_group, step, exp_avg, exp_avg_sq, param, grad, bias_correction, always_adapt | |||
| ): | |||
| lr = param_group["lr"] | |||
| weight_decay = param_group["weight_decay"] | |||
| eps = param_group["eps"] | |||
| beta0, beta1 = param_group["betas"] | |||
| # since `convert_inputs` is disabled for param updates, | |||
| # scalars should be explicitly converted to tensors | |||
| _lr, _neg_lr = map(tensor, (lr, -lr)) | |||
| _weight_decay = tensor(weight_decay) | |||
| _eps = tensor(eps) | |||
| _beta0, _beta1 = map(tensor, (beta0, beta1)) | |||
| c1, c05, c0 = map(tensor, (1.0, 0.5, 0.0)) | |||
| def norm(vec): | |||
| return F.sum(vec * vec) ** c05 | |||
| p_norm = norm(param.flatten()) | |||
| # step = step + c1 | |||
| step += c1 | |||
| # exp_avg = _beta0 * exp_avg + grad * (c1 - _beta0) | |||
| exp_avg *= _beta0 | |||
| exp_avg += grad * (c1 - _beta0) | |||
| # exp_avg_sq = _beta1 * exp_avg_sq + (c1 - _beta1) * (grad * grad) | |||
| exp_avg_sq *= _beta1 | |||
| exp_avg_sq += (c1 - _beta1) * (grad * grad) | |||
| bias_correction1 = c1 - _beta0 ** step if bias_correction else c1 | |||
| bias_correction2 = c1 - _beta1 ** step if bias_correction else c1 | |||
| delta = (exp_avg / bias_correction1) / ( | |||
| (exp_avg_sq / bias_correction2) ** c05 + _eps | |||
| ) | |||
| if weight_decay != 0.0: | |||
| delta += param * _weight_decay | |||
| d_norm = norm(delta.flatten()) | |||
| trust_ratio = ( | |||
| p_norm / d_norm | |||
| if (always_adapt or weight_decay > 0) and p_norm > c0 and d_norm > c0 | |||
| else c1 | |||
| ) | |||
| new_param = param - _lr * trust_ratio * delta | |||
| return exp_avg, exp_avg_sq, new_param | |||
| def test_lamb(): | |||
| op = LAMBUpdate(0.9, 0.999, 1, 1e-3, 0.4, 1e-8, True, False) | |||
| m_t_1 = mge.tensor(np.random.uniform(size=(256, 256)), dtype=np.float32) | |||
| v_t_1 = mge.tensor(np.random.uniform(size=(256, 256)), dtype=np.float32) | |||
| params = mge.tensor(np.random.uniform(size=(256, 256)), dtype=np.float32) | |||
| grad = mge.tensor(np.random.uniform(size=(256, 256)), dtype=np.float16) | |||
| (new_m_t, new_v_t, new_param) = apply(op, m_t_1, v_t_1, params, grad) | |||
| param_group = { | |||
| "betas": (0.9, 0.999), | |||
| "step": 1, | |||
| "lr": 1e-3, | |||
| "weight_decay": 0.4, | |||
| "eps": 1e-8, | |||
| } | |||
| gt_m_t, gt_v_t, gt_new_param = lamb_update( | |||
| param_group, 1, m_t_1, v_t_1, params, grad, True, False | |||
| ) | |||
| np.testing.assert_allclose(new_m_t.numpy(), gt_m_t.numpy(), atol=1e-2) | |||
| np.testing.assert_allclose(new_v_t.numpy(), gt_v_t.numpy(), atol=1e-2) | |||
| np.testing.assert_allclose(new_param.numpy(), gt_new_param.numpy(), atol=1e-2) | |||
| @@ -0,0 +1,82 @@ | |||
| #include "megbrain/imperative/opr_utility.h" | |||
| #include "megbrain/imperative/ops/autogen.h" | |||
| #include "megbrain/opr/basic_arith.h" | |||
| #include "megbrain/opr/utility.h" | |||
| #include "../blob_manager_impl.h" | |||
| #include "../dnn_op_helper.h" | |||
| #include "../op_trait.h" | |||
| namespace mgb { | |||
| namespace imperative { | |||
| namespace { | |||
| namespace lamb { | |||
| SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint( | |||
| const OpDef& def, const SmallVector<TensorPtr>& inputs) { | |||
| SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size()); | |||
| return layout_checker; | |||
| } | |||
| std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
| const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) { | |||
| mgb_assert(input_descs.size() == 4, "LAMBUpdate expects 4 inputs"); | |||
| auto comp_node = input_descs[0].comp_node; | |||
| auto comp_node1 = input_descs[1].comp_node; | |||
| auto comp_node2 = input_descs[2].comp_node; | |||
| TensorLayout m_t_1 = input_descs[0].layout, v_t_1 = input_descs[1].layout, | |||
| lamb_param = input_descs[2].layout, grad = input_descs[3].layout; | |||
| TensorLayout new_param = lamb_param, m_t = m_t_1, v_t = v_t_1; | |||
| return {{{m_t, comp_node}, {v_t, comp_node1}, {new_param, comp_node2}}, true}; | |||
| } | |||
| SmallVector<TensorPtr> apply_on_physical_tensor( | |||
| const OpDef& def, const SmallVector<TensorPtr>& inputs, | |||
| SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) { | |||
| auto&& op = def.cast_final_safe<LAMBUpdate>(); | |||
| auto&& m_t_1 = inputs[0]; | |||
| auto&& v_t_1 = inputs[1]; | |||
| auto&& lamb_param = inputs[2]; | |||
| auto&& grad = inputs[3]; | |||
| TensorLayout m_t_1_layout{m_t_1->layout()}; | |||
| TensorLayout v_t_1_layout{v_t_1->layout()}; | |||
| TensorLayout lamb_param_layout{lamb_param->layout()}; | |||
| DeviceTensorND m_t = BlobManager::inst()->alloc_workspace_with_defrag( | |||
| m_t_1->comp_node(), m_t_1_layout); | |||
| DeviceTensorND v_t = BlobManager::inst()->alloc_workspace_with_defrag( | |||
| v_t_1->comp_node(), v_t_1_layout); | |||
| DeviceTensorND new_param = BlobManager::inst()->alloc_workspace_with_defrag( | |||
| lamb_param->comp_node(), lamb_param_layout); | |||
| DnnOprCaller<megdnn::LAMBUpdate> caller{lamb_param->comp_node()}; | |||
| TensorLayout m_layout( | |||
| {caller.op->get_workspace_in_bytes( | |||
| m_t_1->layout(), v_t_1->layout(), lamb_param->layout(), | |||
| grad->layout(), m_t.layout(), v_t.layout(), new_param.layout())}, | |||
| dtype::Byte()); | |||
| auto dnn_workspace = caller.create_workspace(m_layout); | |||
| caller.op->param() = op.param(); | |||
| caller.op->exec( | |||
| m_t_1->dev_tensor().as_megdnn(), v_t_1->dev_tensor().as_megdnn(), | |||
| lamb_param->dev_tensor().as_megdnn(), grad->dev_tensor().as_megdnn(), | |||
| m_t.as_megdnn(), v_t.as_megdnn(), new_param.as_megdnn(), dnn_workspace); | |||
| return {Tensor::make(m_t), Tensor::make(v_t), Tensor::make(new_param)}; | |||
| } | |||
| OP_TRAIT_REG(LAMBUpdate, LAMBUpdate) | |||
| .infer_output_attrs_fallible(infer_output_attrs_fallible) | |||
| .apply_on_physical_tensor(apply_on_physical_tensor) | |||
| .get_input_layout_constraint(get_input_layout_constraint) | |||
| .fallback(); | |||
| } // namespace lamb | |||
| } // namespace | |||
| } // namespace imperative | |||
| } // namespace mgb | |||
| @@ -477,6 +477,9 @@ def Padding: MgbHashableOp<"Padding", [PaddingParam]>; | |||
| def LRN: MgbHashableOp<"LRN", [LRNParam]>; | |||
| def LayerNorm: MgbHashableOp<"LayerNorm", [LayerNormParam]>; | |||
| def LAMBUpdate: MgbHashableOp<"LAMBUpdate", [LAMBUpdateParam]>; | |||
| def RNNCell: MgbHashableOp<"RNNCell", [RNNCellParam]>; | |||
| def LSTMCell: MgbHashableOp<"LSTMCell", [EmptyParam]>; | |||